// Fixture data shared by the sections below.
// NOTE(review): the leading integers on these statements ("52", "53", ...) look
// like line numbers from a rendered source listing, not C++ tokens. This text
// will not compile as-is, and the lines between them (e.g. 55-57) are missing.
// Regenerate from the original test source rather than editing this copy.
// N x N identity matrix, evaluated eagerly. It is not referenced in the visible
// lines; presumably it is a linear transform handed to the solver in the elided
// setup (TODO confirm).
52 t_Matrix const Id = t_Matrix::Identity(
N,
N).eval();
// A zero and a random length-N vector. Presumably these are the translations
// that define the proximals g0 / g1 in the elided setup (TODO confirm).
53 t_Vector const target0 = t_Vector::Zero(
N);
54 t_Vector const target1 = t_Vector::Random(
N);
// Input passed to sdmm(out, input) below: a random length-N vector scaled by 10.
58 t_Vector const input = 10 * t_Vector::Random(
N);
// Drives one SDMM run by hand, phase by phase, and checks every intermediate
// y, z and x (held in `out`) against closed-form expressions in g0 / g1.
// NOTE(review): this listing drops the section's closing brace and some
// original lines. `sdmm`, `out`, `g0` and `g1` come from elided setup code.
// Presumably `out` holds `input` on entry, since the first checks compare y
// against input (TODO confirm).
67 SECTION(
"Step by Step") {
// Phase 0: y and z hold one vector per linear transform, each sized like
// `out`. After initialization, y_0 and y_1 are expected to equal the input.
68 INFO(
"Initialization");
70 IntrospectSDMM::t_Vectors y(sdmm.
transforms().size(), t_Vector::Zero(out.size()));
71 IntrospectSDMM::t_Vectors z(sdmm.
transforms().size(), t_Vector::Zero(out.size()));
72 sdmm.initialization(y, z, out);
73 CHECK(y[0].isApprox(input));
74 CHECK(y[1].isApprox(input));
// Phase 1: the conjugate-gradient solve should give x = 0.5 * (y_0 + y_1).
// This form has no z term; presumably z is still zero after initialization.
// Given the checks above, that x equals the input. The looser 1e-8 tolerance
// presumably allows for the iterative CG solve (TODO confirm).
76 INFO(
"\nThen solve for conjugate gradient");
77 auto const diagnostic0 = sdmm.solve_for_xn(out, y, z);
78 CHECK(diagnostic0.good);
79 CAPTURE(out.transpose());
80 CAPTURE(input.transpose());
81 CAPTURE(0.5 * (y[0] + y[1]).transpose());
82 CHECK(out.isApprox(0.5 * (y[0] + y[1]), 1e-8));
83 CHECK(out.isApprox(input, 1e-8));
// Iteration 1, direction update. The checks encode
// y_i = g_i(gamma, input) and z_i = input - y_i. This is consistent with
// y_i = g_i(gamma, x + z_i), z_i += x - y_i, with x = input and z_i = 0.
85 INFO(
"\nWe move on to first iteration!");
86 INFO(
"- updates y and z");
87 sdmm.update_directions(y, z, out);
88 CHECK(y[0].isApprox(g0(sdmm.gamma(), input)));
89 CHECK(y[1].isApprox(g1(sdmm.gamma(), input)));
90 CHECK(z[0].isApprox(input - y[0]));
91 CHECK(z[1].isApprox(input - y[1]));
// Iteration 1, x update: x = 0.5 * sum_i (y_i - z_i).
93 INFO(
"- solve for conjugate gradient");
94 auto const diagnostic1 = sdmm.solve_for_xn(out, y, z);
95 CHECK(diagnostic1.good);
96 CAPTURE(out.transpose());
97 CAPTURE((0.5 * (y[0] - z[0] + y[1] - z[1])).transpose());
98 CHECK(out.isApprox(0.5 * (y[0] - z[0] + y[1] - z[1])));
// Closed form: y_i - z_i = 2 g_i(input) - input, so
// x1 = 0.5 * (2 g0 + 2 g1 - 2 input) = g0(input) + g1(input) - input.
99 t_Vector const x1 = g0(sdmm.gamma(), input) + g1(sdmm.gamma(), input) - input;
100 CHECK(out.isApprox(x1));
// Iteration 2, direction update. x1 + z_0 = (g0 + g1 - input) + (input - g0)
// = g1(input), hence y_0 = g0(g1(input)) and
// z_0 = (input - g0) + x1 - y_0 = g1(input) - y_0. Index 1 is symmetric.
102 INFO(
"\nWe move on to second iteration!");
103 INFO(
"- updates y and z");
104 sdmm.update_directions(y, z, out);
105 CHECK(y[0].isApprox(g0(sdmm.gamma(), g1(sdmm.gamma(), input))));
106 CHECK(y[1].isApprox(g1(sdmm.gamma(), g0(sdmm.gamma(), input))));
107 CHECK(z[0].isApprox(g1(sdmm.gamma(), input) - y[0]));
108 CHECK(z[1].isApprox(g0(sdmm.gamma(), input) - y[1]));
// Iteration 2, x update: x = 0.5 * sum_i (y_i - z_i).
110 INFO(
"- solve for conjugate gradient");
111 auto const diagnostic2 = sdmm.solve_for_xn(out, y, z);
112 CHECK(diagnostic2.good);
113 CHECK(out.isApprox(0.5 * (y[0] - z[0] + y[1] - z[1])));
// Closed form: y_0 - z_0 = 2 y_0 - g1(input) and y_1 - z_1 = 2 y_1 - g0(input),
// so x2 = y_0 + y_1 - 0.5 g1(input) - 0.5 g0(input).
114 t_Vector const x2 = g0(sdmm.gamma(), g1(sdmm.gamma(), input)) +
115 g1(sdmm.gamma(), g0(sdmm.gamma(), input)) - 0.5 * g1(sdmm.gamma(), input) -
116 0.5 * g0(sdmm.gamma(), input);
117 CHECK(out.isApprox(x2));
// Runs the full solver via sdmm(out, input) with a bounded iteration count.
// Its output is compared with hand-derived iterates (first/second) and with a
// reference recurrence (Nth).
// NOTE(review): the lines that presumably set the iteration cap for the first
// two subsections, and the closing braces of these sections, are missing from
// this listing (TODO confirm).
120 SECTION(
"Iteration by Iteration") {
122 SECTION(
"First Iteration") {
// After exactly one iteration x should be g0(input) + g1(input) - input.
// `good` is expected to be false, presumably because the run stopped on the
// iteration cap rather than converging.
124 auto const diagnostic = sdmm(out, input);
125 CHECK(not diagnostic.good);
126 CHECK(diagnostic.niters == 1);
127 CHECK(out.isApprox(g0(sdmm.gamma(), input) + g1(sdmm.gamma(), input) - input));
129 SECTION(
"Second Iteration") {
// After exactly two iterations x should match the closed-form second iterate
// y_0 + y_1 - 0.5 g1(input) - 0.5 g0(input), where y_0 = g0(g1(input)) and
// y_1 = g1(g0(input)).
131 auto const diagnostic = sdmm(out, input);
132 CHECK(not diagnostic.good);
133 CHECK(diagnostic.niters == 2);
134 t_Vector const x2 = g0(sdmm.gamma(), g1(sdmm.gamma(), input)) +
135 g1(sdmm.gamma(), g0(sdmm.gamma(), input)) -
136 0.5 * g1(sdmm.gamma(), input) - 0.5 * g0(sdmm.gamma(), input);
137 CHECK(out.isApprox(x2));
140 SECTION(
"Nth Iterations") {
// For each cap 0..9, recompute the expected iterate with a plain reference
// recurrence and compare the solver output against it.
142 for (
t_uint itermax(0); itermax < 10; ++itermax) {
// Reference state. z starts at zero. `x` and `y` are declared in elided lines
// (presumably x starts at input) -- TODO confirm.
145 t_Vector z[2] = {t_Vector::Zero(
N).eval(), t_Vector::Zero(
N).eval()};
146 for (
t_uint i(0); i < itermax; ++i) {
// y_i = g_i(gamma, x + z_i).
147 y[0] = g0(sdmm.gamma(), x + z[0]);
148 y[1] = g1(sdmm.gamma(), x + z[1]);
// z_i += x - g_i(gamma, x + z_i). Neither x nor z_i has changed since y_i was
// computed, so (for a deterministic g_i) this is z_i += x - y_i.
149 z[0] += x - g0(sdmm.gamma(), x + z[0]);
150 z[1] += x - g1(sdmm.gamma(), x + z[1]);
// x = 0.5 * sum_i (y_i - z_i).
151 x = 0.5 * (y[0] - z[0] + y[1] - z[1]);
// Run the solver with the same cap. It should stop after exactly `itermax`
// iterations, not report `good`, and reproduce the reference x to within 1e-8.
154 sdmm.itermax(itermax);
155 auto const diagnostic = sdmm(out, input);
156 CHECK(out.isApprox(x, 1e-8));
157 CHECK(not diagnostic.good);
158 CHECK(diagnostic.niters == itermax);
sopt::Vector< Scalar > t_Vector
sopt::Matrix< Scalar > t_Matrix
std::vector< t_LinearTransform > const & transforms() const
Linear transforms associated with each objective function.
SDMM< SCALAR > & conjugate_gradient(t_uint itermax, t_real tolerance)
Helps set up the conjugate gradient.
SDMM< SCALAR > & append(PROXIMAL proximal, T args)
Appends a proximal and linear transform.
Proximal of the Euclidean norm.
Translation< FUNCTION, VECTOR > translate(FUNCTION const &func, VECTOR const &translation)
Translates given proximal by given vector.
size_t t_uint
Root of the type hierarchy for unsigned integers.