SOPT
Sparse OPTimisation
forward_backward.cc
Go to the documentation of this file.
1 #include <catch2/catch_all.hpp>
2 #include <random>
3 #include <vector>
4 
5 #include <Eigen/Dense>
6 
9 #include "sopt/logging.h"
10 #include "sopt/maths.h"
11 #include "sopt/proximal.h"
12 #include "sopt/types.h"
13 
14 // This header is not part of the installed sopt interface
15 // It is only present in tests
16 #include "tools_for_tests/directories.h"
18 
20  extern std::unique_ptr<std::mt19937_64> mersenne;
21  std::uniform_int_distribution<sopt::t_int> uniform_dist(min, max);
22  return uniform_dist(*mersenne);
23 };
24 
// Number of entries in the random target vector built by the forward-backward test.
constexpr auto N = 5;
30 
31 TEST_CASE("Forward Backward with ||x - x0||_2^2 function", "[fb]") {
32  using namespace sopt;
33  t_Vector const target0 = t_Vector::Random(N);
34  t_real constexpr beta = 0.2;
35  t_real constexpr regulariser_strength = 0.1;
36  int constexpr itermax = 300;
37  auto const g0 = [](t_Vector &out, const t_real regulariser_strength, const t_Vector &x) {
38  proximal::id(out, regulariser_strength, x);
39  };
40  auto const grad = [](t_Vector &out, const t_Vector image, const t_Vector &res,
41  const t_LinearTransform &Phi) { out = Phi.adjoint() * res; };
42  const t_Vector x_guess = t_Vector::Random(target0.size());
43  const t_Vector res = x_guess - target0;
44  auto const convergence = [&target0](const t_Vector &x, const t_Vector &res) -> bool {
45  return x.isApprox(target0, 1e-9);
46  };
47  CAPTURE(target0);
48  CAPTURE(x_guess);
49  CAPTURE(res);
50  auto fb = algorithm::ForwardBackward<Scalar>(grad, g0, target0)
51  .itermax(itermax)
52  .regulariser_strength(regulariser_strength)
53  .step_size(beta)
54  .is_converged(convergence);
55  auto const result = fb(std::make_tuple(x_guess, res));
56  CAPTURE(result.niters);
57  CAPTURE(result.x);
58  CAPTURE(result.residual);
59  CHECK(result.x.isApprox(target0, 1e-9));
60  CHECK(result.good);
61  CHECK(result.niters < itermax);
62 }
63 
64 template <typename T> struct is_imaging_proximal_ref
65  : public std::is_same<sopt::algorithm::ImagingForwardBackward<double> &, T> {};
66 template <typename T> struct is_l1_g_proximal_ref
67  : public std::is_same<sopt::algorithm::L1GProximal<double> &, T> {};
68 
69 TEST_CASE("Check type returned on setting variables") {
70  // Yeah, could be static asserts
71  using namespace sopt;
72  using namespace sopt::algorithm;
74  CHECK(is_imaging_proximal_ref<decltype(fb.itermax(500))>::value);
75  CHECK(is_imaging_proximal_ref<decltype(fb.step_size(1e-1))>::value);
76  CHECK(is_imaging_proximal_ref<decltype(fb.regulariser_strength(1e-1))>::value);
77  CHECK(is_imaging_proximal_ref<decltype(fb.sigma(1e-1))>::value);
78  CHECK(is_imaging_proximal_ref<decltype(fb.residual_convergence(1.001))>::value);
79  CHECK(is_imaging_proximal_ref<decltype(fb.target(Vector<double>::Zero(0)))>::value);
80  using ConvFunc = ConvergenceFunction<double>;
81  CHECK(is_imaging_proximal_ref<decltype(fb.is_converged(std::declval<ConvFunc>()))>::value);
82  CHECK(is_imaging_proximal_ref<decltype(fb.is_converged(std::declval<ConvFunc &>()))>::value);
83  CHECK(is_imaging_proximal_ref<decltype(fb.is_converged(std::declval<ConvFunc &&>()))>::value);
84  CHECK(is_imaging_proximal_ref<decltype(fb.is_converged(std::declval<ConvFunc const &>()))>::value);
85  CHECK(is_imaging_proximal_ref<decltype(fb.relative_variation(5e-4))>::value);
86  CHECK(is_imaging_proximal_ref<decltype(fb.tight_frame(false))>::value);
87 
88  // Test the types of the l1 g_proximal object separately
89  auto gp = std::make_shared<sopt::algorithm::L1GProximal<Scalar>>(false);
90  CHECK(is_l1_g_proximal_ref<decltype(gp->l1_proximal_tolerance(1e-2))>::value);
91  CHECK(is_l1_g_proximal_ref<decltype(gp->l1_proximal_nu(1))>::value);
92  CHECK(is_l1_g_proximal_ref<decltype(gp->l1_proximal_itermax(50))>::value);
93  CHECK(is_l1_g_proximal_ref<decltype(gp->l1_proximal_positivity_constraint(true))>::value);
94  CHECK(is_l1_g_proximal_ref<decltype(gp->l1_proximal_real_constraint(true))>::value);
95  using LinTrans = LinearTransform<Vector<double>>;
96  CHECK(is_l1_g_proximal_ref<decltype(gp->Psi(linear_transform_identity<double>()))>::value);
97  CHECK(is_l1_g_proximal_ref<decltype(gp->Psi(std::declval<LinTrans>()))>::value);
98  CHECK(is_l1_g_proximal_ref<decltype(gp->Psi(std::declval<LinTrans &&>()))>::value);
99  CHECK(is_l1_g_proximal_ref<decltype(gp->Psi(std::declval<LinTrans &>()))>::value);
100  CHECK(is_l1_g_proximal_ref<decltype(gp->Psi(std::declval<LinTrans const &>()))>::value);
101 }
sopt::Vector< Scalar > t_Vector
sopt::t_real Scalar
ImagingForwardBackward< Scalar > & is_converged(std::function< bool(t_Vector const &x)> const &func)
Convergence function that takes only the output as argument.
ImagingForwardBackward< Scalar > & residual_convergence(Real const &tolerance)
Helper function to set-up default residual convergence function.
t_Vector const & target() const
Vector of target measurements.
std::unique_ptr< std::mt19937_64 > mersenne(new std::mt19937_64(0))
sopt::t_int random_integer(sopt::t_int min, sopt::t_int max)
constexpr auto N
TEST_CASE("Forward Backward with ||x - x0||_2^2 function", "[fb]")
sopt::t_real t_real
void id(Eigen::DenseBase< T0 > &out, typename real_type< typename T0::Scalar >::type gamma, Eigen::DenseBase< T1 > const &x)
Proximal of a function that is always zero, the identity.
Definition: proximal.h:135
int t_int
Root of the type hierarchy for signed integers.
Definition: types.h:13
double t_real
Root of the type hierarchy for real numbers.
Definition: types.h:17
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
std::function< bool(Vector< SCALAR > const &)> ConvergenceFunction
Typical function signature for convergence.
Definition: types.h:52