SOPT
Sparse OPTimisation
proximal.cc
Go to the documentation of this file.
1 #include <catch2/catch_all.hpp>
2 #include <numeric>
3 #include <random>
4 #include <utility>
5 
6 #include "sopt/l1_proximal.h"
7 #include "sopt/proximal.h"
8 #include "sopt/types.h"
9 
10 using Catch::Approx;
11 
12 template <typename T>
14  extern std::unique_ptr<std::mt19937_64> mersenne;
15  std::vector<size_t> cols(j);
16  std::iota(cols.begin(), cols.end(), 0);
17  std::shuffle(cols.begin(), cols.end(), *mersenne);
18 
19  assert(j % i == 0);
20  auto const N = j / i;
21  auto const elem = 1e0 / std::sqrt(static_cast<typename sopt::real_type<T>::type>(N));
22  sopt::Matrix<T> result = sopt::Matrix<T>::Zero(i, cols.size());
23  for (typename sopt::Matrix<T>::Index k(0); k < result.cols(); ++k) result(cols[k] / N, k) = elem;
24  return result;
25 }
26 
27 TEST_CASE("L2Ball", "[proximal]") {
28  using namespace sopt;
29  proximal::L2Ball<t_real> ball(0.5);
30  Vector<t_real> out;
31  Vector<t_real> x(5);
32  x << 1, 2, 3, 4, 5;
33 
34  out = ball(0, x);
35  CHECK(x.isApprox(out / 0.5 * x.stableNorm()));
36  ball.epsilon(x.stableNorm() * 1.001);
37  out = ball(0, x);
38  CHECK(x.isApprox(out));
39 }
40 
41 TEST_CASE("WeightedL2Ball", "[proximal]") {
42  using namespace sopt;
43  Vector<t_real> const weights = 0.01 * Vector<t_real>::Random(5).array() + 1e0;
44  Vector<t_real> x(5);
45  x << 1, 2, 3, 4, 5;
46  proximal::WeightedL2Ball<t_real> wball(0.5, weights);
47  proximal::L2Ball<t_real> const ball(0.5);
48 
49  Vector<t_real> const expected =
50  ball((x.array() * weights.array()).matrix()).array() / weights.array();
51  Vector<t_real> const actual = wball(x);
52  CHECK(actual.isApprox(expected));
53 
54  wball.epsilon((x.array() * weights.array()).matrix().stableNorm() * 1.001);
55  CHECK(x.isApprox(wball(x)));
56 }
57 
58 TEST_CASE("Euclidian norm", "[proximal]") {
59  using namespace sopt;
60  proximal::EuclidianNorm const eucl;
61 
62  Vector<t_real> out(5);
63  Vector<t_real> x(5);
64  x << 1, 2, 3, 4, 5;
65  eucl(out, x.stableNorm() * 1.001, x);
66  CHECK(out.isApprox(Vector<t_real>::Zero(x.size())));
67 
68  out = eucl(0.1, x);
69  CHECK(out.isApprox(x * (1e0 - 0.1 / x.stableNorm())));
70 }
71 
72 TEST_CASE("Translation", "[proximal]") {
73  using namespace sopt;
74  Vector<t_real> out(5);
75  Vector<t_real> x(5);
76  x << 1, 2, 3, 4, 5;
77  proximal::L2Ball<t_real> ball(5000);
78  // Pass in a reference, so we can modify ball.epsilon later in the test.
79  auto const translated = proximal::translate(std::ref(ball), -x * 0.5);
80  translated(out, 0, x);
81  CHECK(out.isApprox(x));
82 
83  ball.epsilon(0.125);
84  out = translated(0, x);
85  Vector<t_real> const expected = ball(1, x * 0.5) + x * 0.5;
86  CHECK(out.isApprox(expected));
87 }
88 
89 TEST_CASE("Tight-Frame L1 proximal", "[l1][proximal]") {
90  using namespace sopt;
92  auto check_is_minimum = [&l1](Vector<t_complex> const &x, t_real gamma = 1e0) {
93  using Scalar = t_complex;
94  Vector<t_complex> const p = l1(gamma, x);
95  auto const mini = l1.objective(x, p, gamma);
96  auto constexpr eps = 1e-4;
97  for (Vector<t_complex>::Index i(0); i < p.size(); ++i) {
98  for (auto const dir : {Scalar(eps, 0), Scalar(0, eps), Scalar(-eps, 0), Scalar(0, -eps)}) {
99  Vector<t_complex> p_plus = p;
100  p_plus[i] += dir;
101  CHECK(l1.objective(x, p_plus, gamma) >= mini);
102  }
103  }
104  };
105 
107 
108  // no weights
109  SECTION("Scalar weights") {
110  CHECK(l1(1, input).isApprox(proximal::l1_norm(1, input)));
111  CHECK(l1(0.3, input).isApprox(proximal::l1_norm(0.3, input)));
112  check_is_minimum(input, 0.664);
113  }
114 
115  // with weights == 1
116  SECTION("vector weights") {
117  l1.weights(Vector<t_real>::Ones(input.size()));
118  CHECK(l1(1, input).isApprox(proximal::l1_norm(1, input)));
119  CHECK(l1(0.2, input).isApprox(proximal::l1_norm(0.2, input)));
120  check_is_minimum(input, 0.664);
121  }
122 
123  SECTION("vector weights with random values") {
124  l1.weights(Vector<t_real>::Random(input.size()).array().abs().matrix());
125  check_is_minimum(input, 0.235);
126  }
127 
128  SECTION("Psi is a concatenation of permutations") {
129  auto const psi = concatenated_permutations<t_complex>(input.size(), input.size() * 10);
130  l1.Psi(psi).weights(1e0);
131  check_is_minimum(input, 0.235);
132  }
133 
134 #ifdef CATCH_HAS_THROWS_AS
135  SECTION("Weights cannot be negative") {
136  CHECK_THROWS_AS(l1.weights(-1e0), Exception);
137  Vector<t_real> weights = Vector<t_real>::Random(5).array().abs().matrix();
138  weights[2] = -1;
139  CHECK_THROWS_AS(l1.weights(weights), Exception);
140  }
141 #endif
142 }
143 
144 TEST_CASE("L1 proximal utilities", "[l1][utilities]") {
145  using namespace sopt;
146  using Scalar = t_complex;
147 
148  SECTION("Mixing") {
149  auto const input = Vector<Scalar>::Random(10).eval();
150  Vector<Scalar> output;
151 
152  SECTION("No Mixing") {
153  proximal::L1<Scalar>::NoMixing()(output, 2.1 * input, 0);
154  CHECK(output.isApprox(2.1 * input));
155  proximal::L1<Scalar>::NoMixing()(output, 4.1 * input, 10);
156  CHECK(output.isApprox(4.1 * input));
157  }
158 
159  SECTION("Fista Mixing") {
161  // step zero: no mixing yet
162  fista(output, 2.1 * input, 0);
163  CHECK(output.isApprox(2.1 * input));
164  // step one: first mixing
165  fista(output, 3.1 * input, 1);
166  auto const alpha = (fista.next(1) - 1) / fista.next(fista.next(1));
167  Vector<Scalar> const first = (1e0 + alpha) * 3.1 * input - alpha * 2.1 * input;
168  CHECK(output.isApprox(first));
169  // step two: second mixing
170  fista(output, 4.1 * input, 1);
171  auto const beta = (fista.next(fista.next(1)) - 1) / fista.next(fista.next(fista.next(1)));
172  Vector<Scalar> const second = (1e0 + alpha) * 4.1 * input - alpha * first;
173  CHECK(output.isApprox(second));
174  }
175  }
176 
177  SECTION("Breaker") {
178  proximal::L1<Scalar>::Breaker breaker(2e0);
179  SECTION("Finds convergence") {
180  std::vector<t_real> objectives = {
181  1.0, 0.9, 0.5, 0.6, 0.4, 0.4 + 0.41 * 1e-8, 0.3, 0.3 + 0.29 * 1e-8};
182  for (size_t i(0); i < objectives.size() - 1; ++i) {
183  CHECK(not breaker(objectives[i]));
184  CHECK(breaker.current() == Approx(objectives[i]).epsilon(1e-12));
185  }
186  CHECK(breaker(objectives.back()));
187  CHECK(not breaker.two_cycle());
188  CHECK(breaker.converged());
189  }
190 
191  SECTION("Find cycle") {
192  std::vector<t_real> objectives = {1.0, 0.9, 0.5, 0.6, 0.4, 0.3, 0.4, 0.3};
193  for (size_t i(0); i < objectives.size() - 1; ++i) {
194  CHECK(not breaker(objectives[i]));
195  CHECK(breaker.current() == Approx(objectives[i]).epsilon(1e-12));
196  }
197  CHECK(breaker(objectives.back()));
198  CHECK(breaker.two_cycle());
199  CHECK(not breaker.converged());
200  }
201  }
202 }
203 
204 TEST_CASE("L1 proximal", "[l1][proximal]") {
205  using namespace sopt;
206  using Scalar = t_complex;
207  auto l1 = proximal::L1<Scalar>().tolerance(1e-10);
208 
209  Vector<Scalar> const input = Vector<Scalar>::Random(4);
210 
211  SECTION("Check against tight-frame") {
212  l1.fista_mixing(false);
213  SECTION("Scalar weights") {
214  auto const result = l1(1, input);
215  CHECK(result.good);
216  CHECK(result.niters > 0);
217  CHECK(l1.itermax() == 0);
218  CHECK(result.proximal.isApprox(proximal::L1TightFrame<Scalar>()(1, input)));
219  }
220  SECTION("Vector weights and more complex Psi") {
221  auto const Psi = concatenated_permutations<Scalar>(input.size(), input.size() * 10);
222  auto const weights = Vector<t_real>::Random(Psi.cols()).array().abs().matrix().eval();
223  auto const gamma = 1e-1 / Psi.array().abs().sum();
224  l1.Psi(Psi).weights(weights).tolerance(1e-12);
225  auto const result = l1(gamma, input);
226  CHECK(result.good);
227  CHECK(result.niters > 0);
228  auto const expected = l1.tight_frame(gamma, input).eval();
229  CHECK(result.objective == Approx(l1.objective(input, expected, gamma)));
230  CAPTURE((result.proximal - expected).array().abs().transpose());
231  CHECK(result.proximal.isApprox(expected));
232  }
233  }
234 
235  SECTION("General case") {
236  auto check_is_minimum = [&l1, &input](t_real gamma, Vector<Scalar> const &proximal) {
237  // returns false if did not converge.
238  // Looks like computing the proximal does not always work...
239  auto const mini = l1.objective(input, proximal, gamma);
240  auto constexpr eps = 1e-3;
241  // check alongst specific directions
242  for (Vector<Scalar>::Index i(0); i < proximal.size(); ++i) {
243  for (auto const dir : {Scalar(eps, 0), Scalar(0, eps), Scalar(-eps, 0), Scalar(0, -eps)}) {
244  Vector<Scalar> p_plus = proximal;
245  p_plus[i] += dir;
246  if (l1.positivity_constraint())
247  p_plus = sopt::positive_quadrant(p_plus);
248  else if (l1.real_constraint())
249  p_plus = p_plus.real().cast<Scalar>();
250  auto const rel_var = std::abs((l1.objective(input, p_plus, gamma) - mini) / mini);
251  CHECK((l1.objective(input, p_plus, gamma) > mini or rel_var < l1.tolerance() * 10));
252  }
253  }
254  // check alongst non-specific directions
255  for (size_t i(0); i < 10; ++i) {
256  Vector<Scalar> p_plus = proximal + proximal.Random(proximal.size()) * eps;
257  if (l1.positivity_constraint())
258  p_plus = sopt::positive_quadrant(p_plus);
259  else if (l1.real_constraint())
260  p_plus = p_plus.real().cast<Scalar>();
261  auto const rel_var = std::abs((l1.objective(input, p_plus, gamma) - mini) / mini);
262  CHECK((l1.objective(input, p_plus, gamma) > mini or rel_var < l1.tolerance() * 10));
263  }
264  };
265 
266  auto const Psi = Matrix<Scalar>::Random(input.size(), input.size() * 10).eval();
267  auto const weights = Vector<t_real>::Random(Psi.cols()).array().abs().matrix().eval();
268  auto const gamma = 1e-1 / Psi.array().abs().sum();
269 
270  l1.Psi(Psi).weights(weights).fista_mixing(true).tolerance(1e-10).itermax(5000);
271 
272  SECTION("No constraints") {
273  CHECK(not l1.positivity_constraint());
274  CHECK(not l1.real_constraint());
275  auto const result = l1(gamma, input);
276  CHECK(result.good);
277  check_is_minimum(gamma, result.proximal);
278  }
279  SECTION("Positivity constraints") {
280  l1.positivity_constraint(true);
281  CHECK(l1.positivity_constraint());
282  CHECK(not l1.real_constraint());
283  auto const result = l1(gamma, input);
284  CHECK(result.good);
285  check_is_minimum(gamma, result.proximal);
286  }
287  SECTION("Real constraints") {
288  l1.real_constraint(true);
289  CHECK(l1.real_constraint());
290  CHECK(not l1.positivity_constraint());
291  auto const result = l1(gamma, input);
292  CHECK(result.good);
293  check_is_minimum(gamma, result.proximal);
294  }
295  }
296 }
constexpr auto N
Definition: wavelets.cc:57
sopt::t_real Scalar
Root exception for sopt.
Definition: exception.h:11
Computes inner-most element type.
Definition: real_type.h:42
Proximal of euclidian norm.
Definition: proximal.h:18
L1 proximal, including linear transform.
Definition: l1_proximal.h:26
bool converged() const
True if relative variation smaller than tolerance.
Definition: l1_proximal.h:458
bool two_cycle() const
Whether we have a cycle of period two.
Definition: l1_proximal.h:452
Real current() const
Current objective.
Definition: l1_proximal.h:445
static Real next(Real t)
Definition: l1_proximal.h:409
Proximal for indicator function of L2 ball.
Definition: proximal.h:182
Real epsilon() const
Size of the ball.
Definition: proximal.h:222
Real epsilon() const
Size of the ball.
Definition: proximal.h:312
std::unique_ptr< std::mt19937_64 > mersenne(new std::mt19937_64(0))
t_uint cols
Translation< FUNCTION, VECTOR > translate(FUNCTION const &func, VECTOR const &translation)
Translates given proximal by given vector.
Definition: proximal.h:362
void l1_norm(Eigen::DenseBase< T0 > &out, typename real_type< typename T0::Scalar >::type gamma, Eigen::DenseBase< T1 > const &x)
Proximal of the l1 norm.
Definition: proximal.h:64
Eigen::CwiseUnaryOp< const details::ProjectPositiveQuadrant< typename T::Scalar >, const T > positive_quadrant(Eigen::DenseBase< T > const &input)
Expression to create projection onto positive quadrant.
Definition: maths.h:60
double t_real
Root of the type hierarchy for real numbers.
Definition: types.h:17
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
real_type< T >::type epsilon(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:38
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
std::complex< t_real > t_complex
Root of the type hierarchy for (real) complex numbers.
Definition: types.h:19
Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > Matrix
A matrix of a given type.
Definition: types.h:29
TEST_CASE("L2Ball", "[proximal]")
Definition: proximal.cc:27
sopt::Matrix< T > concatenated_permutations(sopt::t_uint i, sopt::t_uint j)
Definition: proximal.cc:13