SOPT
Sparse OPTimisation
Classes | Functions
reweighted.cc File Reference
#include <catch2/catch_all.hpp>
#include <random>
#include <Eigen/Dense>
#include "sopt/imaging_padmm.h"
#include "sopt/reweighted.h"
Include dependency graph for reweighted.cc (graph image not reproduced in this text version).

Go to the source code of this file.

Classes

struct  DummyAlgorithm
 Minimum set of functions and type aliases needed by the reweighting algorithm.
 
struct  DummyAlgorithm::DiagnosticAndResult
 

Functions

 TEST_CASE ("L0-Approximation")
 

Function Documentation

◆ TEST_CASE()

TEST_CASE ( "L0-Approximation"  )

Definition at line 59 of file reweighted.cc.

59  {
60  auto constexpr N = 6;
61  DummyAlgorithm::t_Vector const input = DummyAlgorithm::t_Vector::Random(N);
62 
65 
70  DummyAlgorithm::DiagnosticAndResult::x = DummyAlgorithm::t_Vector::Zero(0);
71  DummyAlgorithm::weights = DummyAlgorithm::t_Vector::Zero(0);
72 
73  GIVEN("The maximum number of iteration is zero") {
74  l0algo.itermax(0);
75  WHEN("The reweighting algorithm is called") {
76  auto const result = l0algo(input);
77  THEN("The algorithm exited at the first iteration") {
78  CHECK(result.niters == 0);
79  CHECK(result.good == true);
80  }
81  THEN("The weights is set to 1") {
82  CHECK(result.weights.size() == 1);
83  CHECK(std::abs(result.weights(0) - 1) < 1e-12);
84  }
85  THEN("The inner algorithm was called once") {
88  CHECK(result.algo.x.array().isApprox(input.array() + 0.1));
89  }
90  }
91  }
92 
93  GIVEN("The maximum number of iterations is one") {
94  l0algo.itermax(1);
95  WHEN("The reweighting algorithm is called") {
96  auto const result = l0algo(input);
97  THEN("The algorithm exited at the second iteration") {
98  CHECK(result.niters == 1);
99  CHECK(result.good == true);
100  }
101  THEN("The weights are not one") {
102  CHECK(result.weights.size() == input.size());
103  // standard deviation of Ψ^T x, with x the output of the first call to the inner algorithm
104  Vector<> const PsiT_x = DummyAlgorithm::reweightee({}, input.array() + 0.1);
105  auto delta = standard_deviation(PsiT_x);
106  CHECK(result.weights.array().isApprox(delta / (delta + PsiT_x.array().abs())));
107  }
108  THEN("The inner algorithm was called twice") {
109  CHECK(DummyAlgorithm::called_with_x == 1);
111  CHECK(result.algo.x.array().isApprox(input.array() + 0.2));
112  }
113  }
114  }
115 }
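constexpr auto N
Definition: wavelets.cc:57 (cross-reference only; the test above uses its own local constant `N = 6`, declared at line 60 of the listing).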
Reweighted< ALGORITHM > reweighted(ALGORITHM const &algo, typename Reweighted< ALGORITHM >::t_SetWeights const &set_weights, typename Reweighted< ALGORITHM >::t_Reweightee const &reweightee)
Factory function to create an l0-approximation by reweighting an l1 norm.
Definition: reweighted.h:238
real_type< typename T::Scalar >::type standard_deviation(Eigen::ArrayBase< T > const &x)
Computes the standard deviation of a vector.
Definition: maths.h:16
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
static t_Vector x
Expected by the reweighted algorithm.
Definition: reweighted.cc:20
Minimum set of functions and type aliases needed by reweighting.
Definition: reweighted.cc:13
static int called_with_warm
Definition: reweighted.cc:47
static int called_with_x
Definition: reweighted.cc:46
static t_Vector weights
Definition: reweighted.cc:45
static int called_reweightee
Definition: reweighted.cc:48
static int called_weights
Definition: reweighted.cc:49
static t_Vector reweightee(DummyAlgorithm const &, t_Vector const &x)
Applies Ψ^T * x.
Definition: reweighted.cc:35
static void set_weights(DummyAlgorithm &, t_Vector const &weights)
Sets the weights.
Definition: reweighted.cc:40
Vector< Scalar > t_Vector
Definition: reweighted.cc:15

References DummyAlgorithm::called_reweightee, DummyAlgorithm::called_weights, DummyAlgorithm::called_with_warm, DummyAlgorithm::called_with_x, N, sopt::algorithm::reweighted(), DummyAlgorithm::reweightee(), DummyAlgorithm::set_weights(), sopt::standard_deviation(), DummyAlgorithm::weights, and DummyAlgorithm::DiagnosticAndResult::x.