SOPT
Sparse OPTimisation
sdmm.cc
Go to the documentation of this file.
1 #include <catch2/catch_all.hpp>
2 #include <random>
3 
4 #include <Eigen/Dense>
5 
6 #include "sopt/proximal.h"
7 #include "sopt/sdmm.h"
8 #include "sopt/types.h"
9 
11  extern std::unique_ptr<std::mt19937_64> mersenne;
12  std::uniform_int_distribution<sopt::t_int> uniform_dist(min, max);
13  return uniform_dist(*mersenne);
14 };
15 
19 
20 auto constexpr N = 4;
21 
22 // Makes members public so we can test one at a time
   // Test-only subclass of sopt::algorithm::SDMM<Scalar>. The using-declaration in the public
   // section re-exports the base class's t_Vectors type, so the tests can spell it as
   // IntrospectSDMM::t_Vectors.
   // NOTE(review): the listing numbering jumps 24 -> 28 -> 30, so further declarations may be
   // missing from this view (the tests call initialization, solve_for_xn and update_directions
   // on `sdmm`) - confirm against the original file.
23 class IntrospectSDMM : public sopt::algorithm::SDMM<Scalar> {
24  public:
28  using sopt::algorithm::SDMM<Scalar>::t_Vectors;
30 };
31 
32 TEST_CASE("Proximal translation", "[proximal]") {
    // Checks the proximal g = proximal::EuclidianNorm() and its translated version
    // gT = proximal::translate(g, -translation) against closed-form expressions.
33  using namespace sopt;
34  t_Vector const translation = t_Vector::Ones(N) * 5;
35  auto const g = proximal::EuclidianNorm();
36  auto const gT = proximal::translate(g, -translation);
    // Random input with every coefficient shifted by +1.
37  t_Vector const input = t_Vector::Random(N).array() + 1e0;
    // Expected: g(gamma, x) == (1 - gamma / ||x||) * x, here with gamma = 0.1.
38  CHECK(g(0.1, input).isApprox((1e0 - 0.1 / input.stableNorm()) * input));
    // Same formula with gamma equal to half the norm of the input.
39  auto const gamma = input.stableNorm() * 0.5;
40  CHECK(g(gamma, input).isApprox((1e0 - gamma / input.stableNorm()) * input));
    // gamma * 2 + 1 == ||input|| + 1, i.e. a threshold larger than the norm: the result is
    // expected to be the zero vector.
    // NOTE(review): Eigen documents isApprox as unsuitable for comparing against a zero
    // vector (it only succeeds on exact equality) - confirm that an exact zero is intended.
41  CHECK(g(gamma * 2 + 1, input).isApprox(input.Zero(N)));
    // Expected for the translated proximal, with t = translation:
    // gT(gamma, x) == (1 - gamma / ||x - t||) * (x - t) + t.
42  CHECK(gT(0.1, input)
43  .isApprox((1e0 - 0.1 / (input - translation).stableNorm()) * (input - translation) +
44  translation));
45 }
46 
47 // Iterate through the algorithm for the special case where the L_i are identities and the
48 // objective functions are simple Euclidean norms
49 TEST_CASE("Introspect SDMM with L_i = Identity and Euclidian objectives", "[sdmm]") {
    // Walks through the SDMM steps one at a time, then through the full functor, for two
    // translated Euclidian-norm proximals (g0, g1) that both use the identity matrix as
    // their linear transform.
50  using namespace sopt;
51 
52  t_Matrix const Id = t_Matrix::Identity(N, N).eval();
53  t_Vector const target0 = t_Vector::Zero(N);
54  t_Vector const target1 = t_Vector::Random(N);
55 
56  auto const g0 = proximal::translate(proximal::EuclidianNorm(), -target0);
57  auto const g1 = proximal::translate(proximal::EuclidianNorm(), -target1);
58  t_Vector const input = 10 * t_Vector::Random(N);
59 
    // NOTE(review): the declaration of `sdmm` (listing line 60) is not visible here;
    // presumably it is an IntrospectSDMM instance - confirm against the original file.
    // Configuration: 10 iterations, gamma = 0.01, conjugate gradient with an effectively
    // unbounded iteration count and tolerance 1e-12, then the two (proximal, transform) pairs.
61  sdmm.itermax(10)
62  .gamma(0.01)
63  .conjugate_gradient(std::numeric_limits<t_uint>::max(), 1e-12)
64  .append(g0, Id)
65  .append(g1, Id);
66 
67  SECTION("Step by Step") {
68  INFO("Initialization");
    // y and z hold one vector per transform, all starting at zero.
69  t_Vector out = input;
70  IntrospectSDMM::t_Vectors y(sdmm.transforms().size(), t_Vector::Zero(out.size()));
71  IntrospectSDMM::t_Vectors z(sdmm.transforms().size(), t_Vector::Zero(out.size()));
72  sdmm.initialization(y, z, out);
    // After initialization both y vectors are expected to equal the input.
73  CHECK(y[0].isApprox(input));
74  CHECK(y[1].isApprox(input));
75 
76  INFO("\nThen solve for conjugate gradient");
77  auto const diagnostic0 = sdmm.solve_for_xn(out, y, z);
78  CHECK(diagnostic0.good);
79  CAPTURE(out.transpose());
80  CAPTURE(input.transpose());
81  CAPTURE(0.5 * (y[0] + y[1]).transpose());
    // Expected solution: the mean of the y vectors, which equals the input here because
    // y[0] == y[1] == input. Both comparisons use a precision of 1e-8.
82  CHECK(out.isApprox(0.5 * (y[0] + y[1]), 1e-8));
83  CHECK(out.isApprox(input, 1e-8));
84 
85  INFO("\nWe move on to first iteration!");
86  INFO("- updates y and z");
87  sdmm.update_directions(y, z, out);
    // Expected after the update: y[i] == g_i(gamma, input) and z[i] == input - y[i].
88  CHECK(y[0].isApprox(g0(sdmm.gamma(), input)));
89  CHECK(y[1].isApprox(g1(sdmm.gamma(), input)));
90  CHECK(z[0].isApprox(input - y[0]));
91  CHECK(z[1].isApprox(input - y[1]));
92 
93  INFO("- solve for conjugate gradient");
94  auto const diagnostic1 = sdmm.solve_for_xn(out, y, z);
95  CHECK(diagnostic1.good);
96  CAPTURE(out.transpose());
97  CAPTURE((0.5 * (y[0] - z[0] + y[1] - z[1])).transpose());
    // Expected solution: 0.5 * sum_i (y[i] - z[i]). Substituting z[i] == input - y[i]
    // gives the closed form x1 == g0(gamma, input) + g1(gamma, input) - input.
98  CHECK(out.isApprox(0.5 * (y[0] - z[0] + y[1] - z[1])));
99  t_Vector const x1 = g0(sdmm.gamma(), input) + g1(sdmm.gamma(), input) - input;
100  CHECK(out.isApprox(x1));
101 
102  INFO("\nWe move on to second iteration!");
103  INFO("- updates y and z");
104  sdmm.update_directions(y, z, out);
     // Expected after the second update: each y[i] is g_i applied to the other proximal's
     // first-iteration output, and z[i] is that output minus y[i].
105  CHECK(y[0].isApprox(g0(sdmm.gamma(), g1(sdmm.gamma(), input))));
106  CHECK(y[1].isApprox(g1(sdmm.gamma(), g0(sdmm.gamma(), input))));
107  CHECK(z[0].isApprox(g1(sdmm.gamma(), input) - y[0]));
108  CHECK(z[1].isApprox(g0(sdmm.gamma(), input) - y[1]));
109 
110  INFO("- solve for conjugate gradient");
111  auto const diagnostic2 = sdmm.solve_for_xn(out, y, z);
112  CHECK(diagnostic2.good);
113  CHECK(out.isApprox(0.5 * (y[0] - z[0] + y[1] - z[1])));
     // Closed form of 0.5 * sum_i (y[i] - z[i]) using the y and z values checked above.
114  t_Vector const x2 = g0(sdmm.gamma(), g1(sdmm.gamma(), input)) +
115  g1(sdmm.gamma(), g0(sdmm.gamma(), input)) - 0.5 * g1(sdmm.gamma(), input) -
116  0.5 * g0(sdmm.gamma(), input);
117  CHECK(out.isApprox(x2));
118  }
119 
120  SECTION("Iteration by Iteration") {
     // Runs the whole algorithm through sdmm(out, input) with a capped iteration count and
     // compares against the closed forms derived in the "Step by Step" section.
121  t_Vector out;
122  SECTION("First Iteration") {
123  sdmm.itermax(1);
124  auto const diagnostic = sdmm(out, input);
     // The run is expected to report good == false with niters == 1.
     // NOTE(review): presumably because the iteration cap is hit before convergence - confirm.
125  CHECK(not diagnostic.good);
126  CHECK(diagnostic.niters == 1);
127  CHECK(out.isApprox(g0(sdmm.gamma(), input) + g1(sdmm.gamma(), input) - input));
128  }
129  SECTION("Second Iteration") {
130  sdmm.itermax(2);
131  auto const diagnostic = sdmm(out, input);
132  CHECK(not diagnostic.good);
133  CHECK(diagnostic.niters == 2);
134  t_Vector const x2 = g0(sdmm.gamma(), g1(sdmm.gamma(), input)) +
135  g1(sdmm.gamma(), g0(sdmm.gamma(), input)) -
136  0.5 * g1(sdmm.gamma(), input) - 0.5 * g0(sdmm.gamma(), input);
137  CHECK(out.isApprox(x2));
138  }
139 
140  SECTION("Nth Iterations") {
     // For itermax = 0..9, recomputes the iterates by hand (with gamma = 1) and compares
     // the hand-computed x with the output of the full functor.
141  sdmm.gamma(1);
142  for (t_uint itermax(0); itermax < 10; ++itermax) {
143  t_Vector x = input;
144  t_Vector y[2] = {x, x};
145  t_Vector z[2] = {t_Vector::Zero(N).eval(), t_Vector::Zero(N).eval()};
146  for (t_uint i(0); i < itermax; ++i) {
     // y is updated first; the z update re-evaluates g_i with the same argument, since
     // z[i] has not been modified yet at that point.
147  y[0] = g0(sdmm.gamma(), x + z[0]);
148  y[1] = g1(sdmm.gamma(), x + z[1]);
149  z[0] += x - g0(sdmm.gamma(), x + z[0]);
150  z[1] += x - g1(sdmm.gamma(), x + z[1]);
151  x = 0.5 * (y[0] - z[0] + y[1] - z[1]);
152  }
153 
154  sdmm.itermax(itermax);
155  auto const diagnostic = sdmm(out, input);
156  CHECK(out.isApprox(x, 1e-8));
157  CHECK(not diagnostic.good);
158  CHECK(diagnostic.niters == itermax);
159  }
160  }
161  }
162 }
163 
164 TEST_CASE("SDMM with ||x - x0||_2 functions", "[sdmm][integration]") {
     // Integration test: runs SDMM on sums of translated Euclidian norms and checks
     // geometric or local-minimum properties of the result.
165  using namespace sopt;
166 
167  t_Matrix const Id = t_Matrix::Identity(N, N).eval();
168  t_Vector const target0 = t_Vector::Random(N);
169  t_Vector target1 = t_Vector::Random(N) * 4;
170  // for(t_uint i(0); i < N; ++i) target1(i) = i + 1;
171 
     // NOTE(review): the builder chain below is cut off in this listing (numbers 176-177
     // are missing and there is no terminating semicolon); presumably the missing lines
     // append the proximals for target0 and target1 - confirm against the original file.
172  auto sdmm = algorithm::SDMM<Scalar>()
173  .itermax(5000)
174  .gamma(1)
175  .conjugate_gradient(std::numeric_limits<t_uint>::max(), 1e-12)
178 
179  t_Vector result;
180  SECTION("Just two operators") {
181  auto const diagnostic = sdmm(result, t_Vector::Random(N));
     // The run is expected to use every allowed iteration and to report good == false.
182  CHECK(not diagnostic.good);
183  CHECK(diagnostic.niters == sdmm.itermax());
     // segment is the unit vector from target0 towards target1; alpha is the projection of
     // (result - target0) onto it.
184  t_Vector const segment = (target1 - target0).normalized();
185  t_real const alpha = (result - target0).transpose() * segment;
186  CAPTURE(target0.transpose());
187  CAPTURE(target1.transpose());
     // The result must lie on the segment between the targets: 0 <= alpha <= the projection
     // of (target1 - target0), and the component off the segment must be below 1e-8.
188  CHECK((target1 - target0).transpose() * segment >= alpha);
189  CHECK(alpha >= 0e0);
190  CHECK((result - target0 - alpha * segment).stableNorm() < 1e-8);
191  }
192 
193  SECTION("Three operators") {
     // Adds a third translated Euclidian norm (target2) with the identity transform.
194  t_Vector const target2 = t_Vector::Random(N) * 8;
195  sdmm.append(proximal::translate(proximal::EuclidianNorm(), -target2), Id);
196  auto const diagnostic = sdmm(result, t_Vector::Random(N));
197  CHECK(not diagnostic.good);
198  CHECK(diagnostic.niters == sdmm.itermax());
199  CAPTURE(result.transpose());
     // Objective: sum of the distances from x to the three targets.
200  auto const func = [&target0, &target1, &target2](t_Vector const &x) {
201  return (x - target0).stableNorm() + (x - target1).stableNorm() + (x - target2).stableNorm();
202  };
     // Local-minimum check: stepping +/- 1e-6 along each coordinate axis must strictly
     // increase the objective.
203  for (int i(0); i < N; ++i) {
204  t_Vector epsilon = t_Vector::Zero(N);
205  epsilon(i) = 1e-6;
206  CHECK(func(result) < func(result + epsilon));
207  CHECK(func(result) < func(result - epsilon));
208  }
209  }
210 
211  SECTION("With different L") {
     // Replaces the two identity transforms by random matrices L0 and L1.
212  t_Matrix const L0 = t_Matrix::Random(N, N) * 2;
213  t_Matrix const L1 = t_Matrix::Random(N, N) * 4;
     // Abort the section unless |det(L0^T L0 + L1^T L1)| exceeds 1e-8.
214  REQUIRE(std::abs((L0.transpose() * L0 + L1.transpose() * L1).determinant()) > 1e-8);
215  sdmm.itermax(300);
216  sdmm.transforms(0) = linear_transform(L0);
217  sdmm.transforms(1) = linear_transform(L1);
218  auto const diagnostic = sdmm(result, t_Vector::Random(N));
219  CHECK(not diagnostic.good);
220  CHECK(diagnostic.niters == sdmm.itermax());
221  CAPTURE(result.transpose());
     // Objective: ||L0 x - target0|| + ||L1 x - target1||.
222  auto const func = [&target0, &target1, &L0, &L1](t_Vector const &x) {
223  return (L0 * x - target0).stableNorm() + (L1 * x - target1).stableNorm();
224  };
     // Local-minimum check with a coarser step (1e-3) and a non-strict comparison.
225  for (int i(0); i < N; ++i) {
226  t_Vector epsilon = t_Vector::Zero(N);
227  epsilon(i) = 1e-3;
228  CAPTURE(epsilon.transpose());
229  CHECK(func(result) <= func(result + epsilon));
230  CHECK(func(result) <= func(result - epsilon));
231  }
232  }
233 }
sopt::Vector< Scalar > t_Vector
sopt::t_real Scalar
sopt::Matrix< Scalar > t_Matrix
Simultaneous-direction method of the multipliers.
Definition: sdmm.h:23
std::vector< t_LinearTransform > const & transforms() const
Linear transforms associated with each objective function.
Definition: sdmm.h:135
Vector< SCALAR > t_Vector
Type of the underlying vectors.
Definition: sdmm.h:45
SDMM< SCALAR > & conjugate_gradient(t_uint itermax, t_real tolerance)
Helps setup conjugate gradient.
Definition: sdmm.h:83
SDMM< SCALAR > & append(PROXIMAL proximal, T args)
Appends a proximal and linear transform.
Definition: sdmm.h:90
Proximal of the Euclidean norm.
Definition: proximal.h:18
std::unique_ptr< std::mt19937_64 > mersenne(new std::mt19937_64(0))
Translation< FUNCTION, VECTOR > translate(FUNCTION const &func, VECTOR const &translation)
Translates given proximal by given vector.
Definition: proximal.h:362
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
int t_int
Root of the type hierarchy for signed integers.
Definition: types.h:13
double t_real
Root of the type hierarchy for real numbers.
Definition: types.h:17
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
real_type< T >::type epsilon(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:38
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > Matrix
A matrix of a given type.
Definition: types.h:29
sopt::t_int random_integer(sopt::t_int min, sopt::t_int max)
Definition: sdmm.cc:10
constexpr auto N
Definition: sdmm.cc:20
TEST_CASE("Proximal translation", "[proximal]")
Definition: sdmm.cc:32