SOPT
Sparse OPTimisation
reweighted.cc
Go to the documentation of this file.
1 #include <catch2/catch_all.hpp>
2 #include <random>
3 
4 #include <Eigen/Dense>
5 
6 #include "sopt/imaging_padmm.h"
7 #include "sopt/reweighted.h"
8 
9 using namespace sopt;
10 
14  using Scalar = t_real;
17 
20  static t_Vector x;
21  };
22 
24  ++called_with_x;
25  DiagnosticAndResult::x = x.array() + 0.1;
26  return {};
27  }
29  ++called_with_warm;
30  DiagnosticAndResult::x = warm.x.array() + 0.1;
31  return {};
32  }
33 
35  static t_Vector reweightee(DummyAlgorithm const &, t_Vector const &x) {
37  return x * 2;
38  }
40  static void set_weights(DummyAlgorithm &, t_Vector const &weights) {
42  DummyAlgorithm::weights = weights;
43  }
44 
45  static t_Vector weights;
46  static int called_with_x;
47  static int called_with_warm;
48  static int called_reweightee;
49  static int called_weights;
50 };
51 
58 
59 TEST_CASE("L0-Approximation") {
60  auto constexpr N = 6;
61  DummyAlgorithm::t_Vector const input = DummyAlgorithm::t_Vector::Random(N);
62 
65 
70  DummyAlgorithm::DiagnosticAndResult::x = DummyAlgorithm::t_Vector::Zero(0);
71  DummyAlgorithm::weights = DummyAlgorithm::t_Vector::Zero(0);
72 
73  GIVEN("The maximum number of iteration is zero") {
74  l0algo.itermax(0);
75  WHEN("The reweighting algorithm is called") {
76  auto const result = l0algo(input);
77  THEN("The algorithm exited at the first iteration") {
78  CHECK(result.niters == 0);
79  CHECK(result.good == true);
80  }
81  THEN("The weights is set to 1") {
82  CHECK(result.weights.size() == 1);
83  CHECK(std::abs(result.weights(0) - 1) < 1e-12);
84  }
85  THEN("The inner algorithm was called once") {
88  CHECK(result.algo.x.array().isApprox(input.array() + 0.1));
89  }
90  }
91  }
92 
93  GIVEN("The maximum number of iterations is one") {
94  l0algo.itermax(1);
95  WHEN("The reweighting algorithm is called") {
96  auto const result = l0algo(input);
97  THEN("The algorithm exited at the second iteration") {
98  CHECK(result.niters == 1);
99  CHECK(result.good == true);
100  }
101  THEN("The weights are not one") {
102  CHECK(result.weights.size() == input.size());
103  // standard deviation of Ψ^T x, with x the output of the first call to the inner algorithm
104  Vector<> const PsiT_x = DummyAlgorithm::reweightee({}, input.array() + 0.1);
105  auto delta = standard_deviation(PsiT_x);
106  CHECK(result.weights.array().isApprox(delta / (delta + PsiT_x.array().abs())));
107  }
108  THEN("The inner algorithm was called twice") {
109  CHECK(DummyAlgorithm::called_with_x == 1);
111  CHECK(result.algo.x.array().isApprox(input.array() + 0.2));
112  }
113  }
114  }
115 }
constexpr auto N
Definition: wavelets.cc:57
Reweighted< ALGORITHM > reweighted(ALGORITHM const &algo, typename Reweighted< ALGORITHM >::t_SetWeights const &set_weights, typename Reweighted< ALGORITHM >::t_Reweightee const &reweightee)
Factory function to create an l0-approximation by reweighting an l1 norm.
Definition: reweighted.h:238
real_type< typename T::Scalar >::type standard_deviation(Eigen::ArrayBase< T > const &x)
Computes the standard deviation of a vector.
Definition: maths.h:16
double t_real
Root of the type hierarchy for real numbers.
Definition: types.h:17
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
std::function< bool(Vector< SCALAR > const &)> ConvergenceFunction
Typical function signature for convergence.
Definition: types.h:52
static t_Vector x
Expected by the reweighted algorithm.
Definition: reweighted.cc:20
Minimum set of functions and type aliases needed by the reweighting algorithm.
Definition: reweighted.cc:13
static int called_with_warm
Definition: reweighted.cc:47
static int called_with_x
Definition: reweighted.cc:46
static t_Vector weights
Definition: reweighted.cc:45
static int called_reweightee
Definition: reweighted.cc:48
static int called_weights
Definition: reweighted.cc:49
DiagnosticAndResult operator()(DiagnosticAndResult const &warm) const
Definition: reweighted.cc:28
static t_Vector reweightee(DummyAlgorithm const &, t_Vector const &x)
Applies Ψ^T * x.
Definition: reweighted.cc:35
ConvergenceFunction< Scalar > t_IsConverged
Definition: reweighted.cc:16
static void set_weights(DummyAlgorithm &, t_Vector const &weights)
Sets the weights.
Definition: reweighted.cc:40
DiagnosticAndResult operator()(t_Vector const &x) const
Definition: reweighted.cc:23
Vector< Scalar > t_Vector
Definition: reweighted.cc:15
TEST_CASE("L0-Approximation")
Definition: reweighted.cc:59