1 #include <catch2/catch_all.hpp>
14 extern std::unique_ptr<std::mt19937_64>
mersenne;
15 std::vector<size_t>
cols(j);
16 std::iota(
cols.begin(),
cols.end(), 0);
35 CHECK(x.isApprox(out / 0.5 * x.stableNorm()));
36 ball.
epsilon(x.stableNorm() * 1.001);
38 CHECK(x.isApprox(out));
50 ball((x.array() * weights.array()).matrix()).array() / weights.array();
52 CHECK(actual.isApprox(expected));
54 wball.
epsilon((x.array() * weights.array()).matrix().stableNorm() * 1.001);
55 CHECK(x.isApprox(wball(x)));
65 eucl(out, x.stableNorm() * 1.001, x);
69 CHECK(out.isApprox(x * (1e0 - 0.1 / x.stableNorm())));
80 translated(out, 0, x);
81 CHECK(out.isApprox(x));
84 out = translated(0, x);
86 CHECK(out.isApprox(expected));
89 TEST_CASE(
"Tight-Frame L1 proximal",
"[l1][proximal]") {
95 auto const mini = l1.objective(x, p, gamma);
96 auto constexpr eps = 1e-4;
101 CHECK(l1.objective(x, p_plus, gamma) >= mini);
109 SECTION(
"Scalar weights") {
112 check_is_minimum(input, 0.664);
116 SECTION(
"vector weights") {
120 check_is_minimum(input, 0.664);
123 SECTION(
"vector weights with random values") {
125 check_is_minimum(input, 0.235);
128 SECTION(
"Psi is a concatenation of permutations") {
129 auto const psi = concatenated_permutations<t_complex>(input.size(), input.size() * 10);
130 l1.Psi(psi).weights(1e0);
131 check_is_minimum(input, 0.235);
134 #ifdef CATCH_HAS_THROWS_AS
135 SECTION(
"Weights cannot be negative") {
136 CHECK_THROWS_AS(l1.weights(-1e0),
Exception);
139 CHECK_THROWS_AS(l1.weights(weights),
Exception);
145 using namespace sopt;
152 SECTION(
"No Mixing") {
154 CHECK(output.isApprox(2.1 * input));
156 CHECK(output.isApprox(4.1 * input));
159 SECTION(
"Fista Mixing") {
162 fista(output, 2.1 * input, 0);
163 CHECK(output.isApprox(2.1 * input));
165 fista(output, 3.1 * input, 1);
166 auto const alpha = (fista.
next(1) - 1) / fista.
next(fista.
next(1));
167 Vector<Scalar> const first = (1e0 + alpha) * 3.1 * input - alpha * 2.1 * input;
168 CHECK(output.isApprox(first));
170 fista(output, 4.1 * input, 1);
172 Vector<Scalar> const second = (1e0 + alpha) * 4.1 * input - alpha * first;
173 CHECK(output.isApprox(second));
179 SECTION(
"Finds convergence") {
180 std::vector<t_real> objectives = {
181 1.0, 0.9, 0.5, 0.6, 0.4, 0.4 + 0.41 * 1e-8, 0.3, 0.3 + 0.29 * 1e-8};
182 for (
size_t i(0); i < objectives.size() - 1; ++i) {
183 CHECK(not breaker(objectives[i]));
186 CHECK(breaker(objectives.back()));
191 SECTION(
"Find cycle") {
192 std::vector<t_real> objectives = {1.0, 0.9, 0.5, 0.6, 0.4, 0.3, 0.4, 0.3};
193 for (
size_t i(0); i < objectives.size() - 1; ++i) {
194 CHECK(not breaker(objectives[i]));
197 CHECK(breaker(objectives.back()));
205 using namespace sopt;
211 SECTION(
"Check against tight-frame") {
212 l1.fista_mixing(
false);
213 SECTION(
"Scalar weights") {
214 auto const result = l1(1, input);
216 CHECK(result.niters > 0);
217 CHECK(l1.itermax() == 0);
220 SECTION(
"Vector weights and more complex Psi") {
221 auto const Psi = concatenated_permutations<Scalar>(input.size(), input.size() * 10);
223 auto const gamma = 1e-1 / Psi.array().abs().sum();
224 l1.Psi(Psi).weights(weights).tolerance(1e-12);
225 auto const result = l1(gamma, input);
227 CHECK(result.niters > 0);
228 auto const expected = l1.tight_frame(gamma, input).eval();
229 CHECK(result.objective == Approx(l1.objective(input, expected, gamma)));
230 CAPTURE((result.proximal - expected).array().abs().transpose());
231 CHECK(result.proximal.isApprox(expected));
235 SECTION(
"General case") {
239 auto const mini = l1.objective(input, proximal, gamma);
240 auto constexpr eps = 1e-3;
246 if (l1.positivity_constraint())
248 else if (l1.real_constraint())
249 p_plus = p_plus.real().cast<
Scalar>();
250 auto const rel_var = std::abs((l1.objective(input, p_plus, gamma) - mini) / mini);
251 CHECK((l1.objective(input, p_plus, gamma) > mini or rel_var < l1.tolerance() * 10));
255 for (
size_t i(0); i < 10; ++i) {
256 Vector<Scalar> p_plus = proximal + proximal.Random(proximal.size()) * eps;
257 if (l1.positivity_constraint())
259 else if (l1.real_constraint())
260 p_plus = p_plus.real().cast<
Scalar>();
261 auto const rel_var = std::abs((l1.objective(input, p_plus, gamma) - mini) / mini);
262 CHECK((l1.objective(input, p_plus, gamma) > mini or rel_var < l1.tolerance() * 10));
268 auto const gamma = 1e-1 / Psi.array().abs().sum();
270 l1.Psi(Psi).weights(weights).fista_mixing(
true).tolerance(1e-10).itermax(5000);
272 SECTION(
"No constraints") {
273 CHECK(not l1.positivity_constraint());
274 CHECK(not l1.real_constraint());
275 auto const result = l1(gamma, input);
277 check_is_minimum(gamma, result.proximal);
279 SECTION(
"Positivity constraints") {
280 l1.positivity_constraint(
true);
281 CHECK(l1.positivity_constraint());
282 CHECK(not l1.real_constraint());
283 auto const result = l1(gamma, input);
285 check_is_minimum(gamma, result.proximal);
287 SECTION(
"Real constraints") {
288 l1.real_constraint(
true);
289 CHECK(l1.real_constraint());
290 CHECK(not l1.positivity_constraint());
291 auto const result = l1(gamma, input);
293 check_is_minimum(gamma, result.proximal);
Computes inner-most element type.
Proximal of the Euclidean norm.
L1 proximal, including linear transform.
bool converged() const
True if the relative variation is smaller than the tolerance.
bool two_cycle() const
Whether we have a cycle of period two.
Real current() const
Current objective.
Proximal for the indicator function of the L2 ball.
Real epsilon() const
Size of the ball.
Real epsilon() const
Size of the ball.
std::unique_ptr< std::mt19937_64 > mersenne(new std::mt19937_64(0))
Translation< FUNCTION, VECTOR > translate(FUNCTION const &func, VECTOR const &translation)
Translates given proximal by given vector.
void l1_norm(Eigen::DenseBase< T0 > &out, typename real_type< typename T0::Scalar >::type gamma, Eigen::DenseBase< T1 > const &x)
Proximal of the l1 norm.
Eigen::CwiseUnaryOp< const details::ProjectPositiveQuadrant< typename T::Scalar >, const T > positive_quadrant(Eigen::DenseBase< T > const &input)
Expression to create projection onto positive quadrant.
double t_real
Root of the type hierarchy for real numbers.
size_t t_uint
Root of the type hierarchy for unsigned integers.
real_type< T >::type epsilon(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
std::complex< t_real > t_complex
Root of the type hierarchy for (real) complex numbers.
Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > Matrix
A matrix of a given type.
TEST_CASE("L2Ball", "[proximal]")
sopt::Matrix< T > concatenated_permutations(sopt::t_uint i, sopt::t_uint j)