SOPT
Sparse OPTimisation
padmm.cc
Go to the documentation of this file.
1 #include <catch2/catch_all.hpp>
2 #include <random>
3 
4 #include <Eigen/Dense>
5 
6 #include "sopt/imaging_padmm.h"
7 #include "sopt/padmm.h"
8 #include "sopt/proximal.h"
9 #include "sopt/types.h"
10 
12  extern std::unique_ptr<std::mt19937_64> mersenne;
13  std::uniform_int_distribution<sopt::t_int> uniform_dist(min, max);
14  return uniform_dist(*mersenne);
15 };
16 
20 
// Problem size used by the tests below: length of the random target vectors
// and order of the square operator handed to the solver.
constexpr auto N = 5;
22 
23 TEST_CASE("Proximal ADMM with ||x - x0||_2 functions", "[padmm][integration]") {
24  using namespace sopt;
25  t_Vector const target0 = t_Vector::Random(N);
26  t_Vector const target1 = t_Vector::Random(N) * 4;
27  auto const g0 = proximal::translate(proximal::EuclidianNorm(), -target0);
28  auto const g1 = proximal::translate(proximal::EuclidianNorm(), -target1);
29 
30  t_Matrix const mId = -t_Matrix::Identity(N, N);
31 
32  t_Vector const translation = t_Vector::Ones(N) * 5;
33  auto const padmm =
34  algorithm::ProximalADMM<Scalar>(g0, g1, t_Vector::Zero(N)).Phi(mId).itermax(3000).regulariser_strength(0.01);
35  auto const result = padmm();
36 
37  t_Vector const segment = (target1 - target0).normalized();
38  t_real const alpha = (result.x - target0).transpose() * segment;
39 
40  CHECK((target1 - target0).transpose() * segment >= alpha);
41  CHECK(alpha >= 0e0);
42  CAPTURE(segment.transpose());
43  CAPTURE((result.x - target0).transpose());
44  CAPTURE((result.x - target1).transpose());
45  CHECK((result.x - target0 - alpha * segment).stableNorm() < 1e-8);
46 }
47 
48 template <typename T>
50  : public std::is_same<sopt::algorithm::ImagingProximalADMM<double> &, T> {};
// Checks the return type of the ImagingProximalADMM setters: every CHECK below
// feeds decltype(<setter call>) to is_imaging_proximal_ref, which derives from
// std::is_same<sopt::algorithm::ImagingProximalADMM<double> &, T>. A setter
// therefore passes only if it returns a non-const lvalue reference to
// ImagingProximalADMM<double>.
// NOTE(review): the declaration of `admm` is not visible in this listing
// (original source line 55 is missing) -- confirm against the real file.
51 TEST_CASE("Check type returned on setting variables") {
52  // Yeah, could be static asserts
53  using namespace sopt;
54  using namespace sopt::algorithm;
// Scalar-valued and boolean setters.
56  CHECK(is_imaging_proximal_ref<decltype(admm.itermax(500))>::value);
57  CHECK(is_imaging_proximal_ref<decltype(admm.regulariser_strength(1e-1))>::value);
58  CHECK(is_imaging_proximal_ref<decltype(admm.relative_variation(5e-4))>::value);
59  CHECK(is_imaging_proximal_ref<decltype(admm.l2ball_proximal_epsilon(1e-4))>::value);
60  CHECK(is_imaging_proximal_ref<decltype(admm.tight_frame(false))>::value);
61  CHECK(is_imaging_proximal_ref<decltype(admm.l1_proximal_tolerance(1e-2))>::value);
62  CHECK(is_imaging_proximal_ref<decltype(admm.l1_proximal_nu(1))>::value);
63  CHECK(is_imaging_proximal_ref<decltype(admm.l1_proximal_itermax(50))>::value);
64  CHECK(is_imaging_proximal_ref<decltype(admm.l1_proximal_positivity_constraint(true))>::value);
65  CHECK(is_imaging_proximal_ref<decltype(admm.l1_proximal_real_constraint(true))>::value);
66  CHECK(is_imaging_proximal_ref<decltype(admm.residual_convergence(1.001))>::value);
67  CHECK(is_imaging_proximal_ref<decltype(admm.lagrange_update_scale(0.9))>::value);
68  CHECK(is_imaging_proximal_ref<decltype(admm.target(Vector<double>::Zero(0)))>::value);
// is_converged is exercised with each value category of the convergence
// functor: prvalue, lvalue reference, rvalue reference, const lvalue reference.
69  using ConvFunc = ConvergenceFunction<double>;
70  CHECK(is_imaging_proximal_ref<decltype(admm.is_converged(std::declval<ConvFunc>()))>::value);
71  CHECK(is_imaging_proximal_ref<decltype(admm.is_converged(std::declval<ConvFunc &>()))>::value);
72  CHECK(is_imaging_proximal_ref<decltype(admm.is_converged(std::declval<ConvFunc &&>()))>::value);
73  CHECK(is_imaging_proximal_ref<decltype(
74  admm.is_converged(std::declval<ConvFunc const &>()))>::value);
// Phi and Psi get the same treatment: the result of
// linear_transform_identity<double>() plus each value category of a
// LinearTransform.
75  using LinTrans = LinearTransform<Vector<double>>;
76  CHECK(is_imaging_proximal_ref<decltype(admm.Phi(linear_transform_identity<double>()))>::value);
77  CHECK(is_imaging_proximal_ref<decltype(admm.Phi(std::declval<LinTrans>()))>::value);
78  CHECK(is_imaging_proximal_ref<decltype(admm.Phi(std::declval<LinTrans &&>()))>::value);
79  CHECK(is_imaging_proximal_ref<decltype(admm.Phi(std::declval<LinTrans &>()))>::value);
80  CHECK(is_imaging_proximal_ref<decltype(admm.Phi(std::declval<LinTrans const &>()))>::value);
81  CHECK(is_imaging_proximal_ref<decltype(admm.Psi(linear_transform_identity<double>()))>::value);
82  CHECK(is_imaging_proximal_ref<decltype(admm.Psi(std::declval<LinTrans>()))>::value);
83  CHECK(is_imaging_proximal_ref<decltype(admm.Psi(std::declval<LinTrans &&>()))>::value);
84  CHECK(is_imaging_proximal_ref<decltype(admm.Psi(std::declval<LinTrans &>()))>::value);
85  CHECK(is_imaging_proximal_ref<decltype(admm.Psi(std::declval<LinTrans const &>()))>::value);
86 }
sopt::Vector< Scalar > t_Vector
sopt::t_real Scalar
sopt::Matrix< Scalar > t_Matrix
Joins together direct and indirect operators.
ImagingProximalADMM< Scalar > & is_converged(std::function< bool(t_Vector const &x)> const &func)
Convergence function that takes only the output as argument.
t_Vector const & target() const
Vector of target measurements.
ImagingProximalADMM< Scalar > & residual_convergence(Real const &tolerance)
Helper function to set-up default residual convergence function.
ImagingProximalADMM &::type Phi(ARGS &&... args)
t_LinearTransform const & Psi() const
Analysis operator Ψ
Proximal Alternating Direction Method of Multipliers.
Definition: padmm.h:19
ProximalADMM &::type Phi(ARGS &&... args)
Definition: padmm.h:174
Proximal of the Euclidean norm.
Definition: proximal.h:18
std::unique_ptr< std::mt19937_64 > mersenne(new std::mt19937_64(0))
Translation< FUNCTION, VECTOR > translate(FUNCTION const &func, VECTOR const &translation)
Translates given proximal by given vector.
Definition: proximal.h:362
int t_int
Root of the type hierarchy for signed integers.
Definition: types.h:13
double t_real
Root of the type hierarchy for real numbers.
Definition: types.h:17
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
std::function< bool(Vector< SCALAR > const &)> ConvergenceFunction
Typical function signature for convergence.
Definition: types.h:52
Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > Matrix
A matrix of a given type.
Definition: types.h:29
sopt::t_int random_integer(sopt::t_int min, sopt::t_int max)
Definition: padmm.cc:11
constexpr auto N
Definition: padmm.cc:21
TEST_CASE("Proximal ADMM with ||x - x0||_2 functions", "[padmm][integration]")
Definition: padmm.cc:23