1 #include <catch2/catch_all.hpp>
11 extern std::unique_ptr<std::mt19937_64>
mersenne;
12 std::uniform_int_distribution<sopt::t_int> uniform_dist(min, max);
28 using sopt::algorithm::SDMM<Scalar>::t_Vectors;
34 t_Vector const translation = t_Vector::Ones(
N) * 5;
37 t_Vector const input = t_Vector::Random(
N).array() + 1e0;
38 CHECK(g(0.1, input).isApprox((1e0 - 0.1 / input.stableNorm()) * input));
39 auto const gamma = input.stableNorm() * 0.5;
40 CHECK(g(gamma, input).isApprox((1e0 - gamma / input.stableNorm()) * input));
41 CHECK(g(gamma * 2 + 1, input).isApprox(input.Zero(
N)));
43 .isApprox((1e0 - 0.1 / (input - translation).stableNorm()) * (input - translation) +
49 TEST_CASE(
"Introspect SDMM with L_i = Identity and Euclidian objectives",
"[sdmm]") {
52 t_Matrix const Id = t_Matrix::Identity(
N,
N).eval();
53 t_Vector const target0 = t_Vector::Zero(
N);
54 t_Vector const target1 = t_Vector::Random(
N);
58 t_Vector const input = 10 * t_Vector::Random(
N);
67 SECTION(
"Step by Step") {
68 INFO(
"Initialization");
70 IntrospectSDMM::t_Vectors y(sdmm.
transforms().size(), t_Vector::Zero(out.size()));
71 IntrospectSDMM::t_Vectors z(sdmm.
transforms().size(), t_Vector::Zero(out.size()));
72 sdmm.initialization(y, z, out);
73 CHECK(y[0].isApprox(input));
74 CHECK(y[1].isApprox(input));
76 INFO(
"\nThen solve for conjugate gradient");
77 auto const diagnostic0 = sdmm.solve_for_xn(out, y, z);
78 CHECK(diagnostic0.good);
79 CAPTURE(out.transpose());
80 CAPTURE(input.transpose());
81 CAPTURE(0.5 * (y[0] + y[1]).transpose());
82 CHECK(out.isApprox(0.5 * (y[0] + y[1]), 1e-8));
83 CHECK(out.isApprox(input, 1e-8));
85 INFO(
"\nWe move on to first iteration!");
86 INFO(
"- updates y and z");
87 sdmm.update_directions(y, z, out);
88 CHECK(y[0].isApprox(g0(sdmm.gamma(), input)));
89 CHECK(y[1].isApprox(g1(sdmm.gamma(), input)));
90 CHECK(z[0].isApprox(input - y[0]));
91 CHECK(z[1].isApprox(input - y[1]));
93 INFO(
"- solve for conjugate gradient");
94 auto const diagnostic1 = sdmm.solve_for_xn(out, y, z);
95 CHECK(diagnostic1.good);
96 CAPTURE(out.transpose());
97 CAPTURE((0.5 * (y[0] - z[0] + y[1] - z[1])).transpose());
98 CHECK(out.isApprox(0.5 * (y[0] - z[0] + y[1] - z[1])));
99 t_Vector const x1 = g0(sdmm.gamma(), input) + g1(sdmm.gamma(), input) - input;
100 CHECK(out.isApprox(x1));
102 INFO(
"\nWe move on to second iteration!");
103 INFO(
"- updates y and z");
104 sdmm.update_directions(y, z, out);
105 CHECK(y[0].isApprox(g0(sdmm.gamma(), g1(sdmm.gamma(), input))));
106 CHECK(y[1].isApprox(g1(sdmm.gamma(), g0(sdmm.gamma(), input))));
107 CHECK(z[0].isApprox(g1(sdmm.gamma(), input) - y[0]));
108 CHECK(z[1].isApprox(g0(sdmm.gamma(), input) - y[1]));
110 INFO(
"- solve for conjugate gradient");
111 auto const diagnostic2 = sdmm.solve_for_xn(out, y, z);
112 CHECK(diagnostic2.good);
113 CHECK(out.isApprox(0.5 * (y[0] - z[0] + y[1] - z[1])));
114 t_Vector const x2 = g0(sdmm.gamma(), g1(sdmm.gamma(), input)) +
115 g1(sdmm.gamma(), g0(sdmm.gamma(), input)) - 0.5 * g1(sdmm.gamma(), input) -
116 0.5 * g0(sdmm.gamma(), input);
117 CHECK(out.isApprox(x2));
120 SECTION(
"Iteration by Iteration") {
122 SECTION(
"First Iteration") {
124 auto const diagnostic = sdmm(out, input);
125 CHECK(not diagnostic.good);
126 CHECK(diagnostic.niters == 1);
127 CHECK(out.isApprox(g0(sdmm.gamma(), input) + g1(sdmm.gamma(), input) - input));
129 SECTION(
"Second Iteration") {
131 auto const diagnostic = sdmm(out, input);
132 CHECK(not diagnostic.good);
133 CHECK(diagnostic.niters == 2);
134 t_Vector const x2 = g0(sdmm.gamma(), g1(sdmm.gamma(), input)) +
135 g1(sdmm.gamma(), g0(sdmm.gamma(), input)) -
136 0.5 * g1(sdmm.gamma(), input) - 0.5 * g0(sdmm.gamma(), input);
137 CHECK(out.isApprox(x2));
140 SECTION(
"Nth Iterations") {
142 for (
t_uint itermax(0); itermax < 10; ++itermax) {
145 t_Vector z[2] = {t_Vector::Zero(
N).eval(), t_Vector::Zero(
N).eval()};
146 for (
t_uint i(0); i < itermax; ++i) {
147 y[0] = g0(sdmm.gamma(), x + z[0]);
148 y[1] = g1(sdmm.gamma(), x + z[1]);
149 z[0] += x - g0(sdmm.gamma(), x + z[0]);
150 z[1] += x - g1(sdmm.gamma(), x + z[1]);
151 x = 0.5 * (y[0] - z[0] + y[1] - z[1]);
154 sdmm.itermax(itermax);
155 auto const diagnostic = sdmm(out, input);
156 CHECK(out.isApprox(x, 1e-8));
157 CHECK(not diagnostic.good);
158 CHECK(diagnostic.niters == itermax);
164 TEST_CASE(
"SDMM with ||x - x0||_2 functions",
"[sdmm][integration]") {
165 using namespace sopt;
167 t_Matrix const Id = t_Matrix::Identity(
N,
N).eval();
168 t_Vector const target0 = t_Vector::Random(
N);
169 t_Vector target1 = t_Vector::Random(
N) * 4;
180 SECTION(
"Just two operators") {
181 auto const diagnostic = sdmm(result, t_Vector::Random(
N));
182 CHECK(not diagnostic.good);
183 CHECK(diagnostic.niters == sdmm.itermax());
184 t_Vector const segment = (target1 - target0).normalized();
185 t_real const alpha = (result - target0).transpose() * segment;
186 CAPTURE(target0.transpose());
187 CAPTURE(target1.transpose());
188 CHECK((target1 - target0).transpose() * segment >= alpha);
190 CHECK((result - target0 - alpha * segment).stableNorm() < 1e-8);
193 SECTION(
"Three operators") {
194 t_Vector const target2 = t_Vector::Random(
N) * 8;
196 auto const diagnostic = sdmm(result, t_Vector::Random(
N));
197 CHECK(not diagnostic.good);
198 CHECK(diagnostic.niters == sdmm.itermax());
199 CAPTURE(result.transpose());
200 auto const func = [&target0, &target1, &target2](
t_Vector const &x) {
201 return (x - target0).stableNorm() + (x - target1).stableNorm() + (x - target2).stableNorm();
203 for (
int i(0); i <
N; ++i) {
206 CHECK(func(result) < func(result +
epsilon));
207 CHECK(func(result) < func(result -
epsilon));
211 SECTION(
"With different L") {
212 t_Matrix const L0 = t_Matrix::Random(
N,
N) * 2;
213 t_Matrix const L1 = t_Matrix::Random(
N,
N) * 4;
214 REQUIRE(std::abs((L0.transpose() * L0 + L1.transpose() * L1).determinant()) > 1e-8);
218 auto const diagnostic = sdmm(result, t_Vector::Random(
N));
219 CHECK(not diagnostic.good);
220 CHECK(diagnostic.niters == sdmm.itermax());
221 CAPTURE(result.transpose());
222 auto const func = [&target0, &target1, &L0, &L1](
t_Vector const &x) {
223 return (L0 * x - target0).stableNorm() + (L1 * x - target1).stableNorm();
225 for (
int i(0); i <
N; ++i) {
229 CHECK(func(result) <= func(result +
epsilon));
230 CHECK(func(result) <= func(result -
epsilon));
sopt::Vector< Scalar > t_Vector
sopt::Matrix< Scalar > t_Matrix
Simultaneous-direction method of multipliers (SDMM).
std::vector< t_LinearTransform > const & transforms() const
Linear transforms associated with each objective function.
Vector< SCALAR > t_Vector
Type of the underlying vectors.
SDMM< SCALAR > & conjugate_gradient(t_uint itermax, t_real tolerance)
Helps setup conjugate gradient.
SDMM< SCALAR > & append(PROXIMAL proximal, T args)
Appends a proximal and linear transform.
Proximal of the Euclidean norm.
std::unique_ptr< std::mt19937_64 > mersenne(new std::mt19937_64(0))
Translation< FUNCTION, VECTOR > translate(FUNCTION const &func, VECTOR const &translation)
Translates given proximal by given vector.
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
int t_int
Root of the type hierarchy for signed integers.
double t_real
Root of the type hierarchy for real numbers.
size_t t_uint
Root of the type hierarchy for unsigned integers.
real_type< T >::type epsilon(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > Matrix
A matrix of a given type.
sopt::t_int random_integer(sopt::t_int min, sopt::t_int max)
TEST_CASE("Proximal translation", "[proximal]")