SOPT
Sparse OPTimisation
Classes | Typedefs | Functions | Variables
forward_backward.cc File Reference
#include <catch2/catch_all.hpp>
#include <random>
#include <vector>
#include <Eigen/Dense>
#include "sopt/imaging_forward_backward.h"
#include "sopt/l1_non_diff_function.h"
#include "sopt/logging.h"
#include "sopt/maths.h"
#include "sopt/proximal.h"
#include "sopt/types.h"
#include "tools_for_tests/directories.h"
#include "tools_for_tests/tiffwrappers.h"
Include dependency graph for forward_backward.cc: (graph image not available in this text rendering)

Go to the source code of this file.

Classes

struct  is_imaging_proximal_ref< T >
 
struct  is_l1_g_proximal_ref< T >
 

Typedefs

using Scalar = sopt::t_real
 
using t_Vector = sopt::Vector< Scalar >
 
using t_LinearTransform = sopt::LinearTransform< t_Vector >
 
using t_real = sopt::t_real
 

Functions

sopt::t_int random_integer (sopt::t_int min, sopt::t_int max)
 
 TEST_CASE ("Forward Backward with ||x - x0||_2^2 function", "[fb]")
 
 TEST_CASE ("Check type returned on setting variables")
 

Variables

constexpr auto N = 5
 

Typedef Documentation

◆ Scalar

Definition at line 25 of file forward_backward.cc.

◆ t_LinearTransform

Definition at line 27 of file forward_backward.cc.

◆ t_real

Definition at line 28 of file forward_backward.cc.

◆ t_Vector

Definition at line 26 of file forward_backward.cc.

Function Documentation

◆ random_integer()

sopt::t_int random_integer ( sopt::t_int  min,
sopt::t_int  max 
)

Definition at line 19 of file forward_backward.cc.

19  {
20  extern std::unique_ptr<std::mt19937_64> mersenne;
21  std::uniform_int_distribution<sopt::t_int> uniform_dist(min, max);
22  return uniform_dist(*mersenne);
23 };
std::unique_ptr< std::mt19937_64 > mersenne(new std::mt19937_64(0))

References mersenne (a global std::unique_ptr<std::mt19937_64> variable, not a function).

◆ TEST_CASE() [1/2]

TEST_CASE ( "Check type returned on setting variables"  )

Definition at line 69 of file forward_backward.cc.

69  {
70  // Yeah, could be static asserts
71  using namespace sopt;
72  using namespace sopt::algorithm;
74  CHECK(is_imaging_proximal_ref<decltype(fb.itermax(500))>::value);
75  CHECK(is_imaging_proximal_ref<decltype(fb.step_size(1e-1))>::value);
76  CHECK(is_imaging_proximal_ref<decltype(fb.regulariser_strength(1e-1))>::value);
77  CHECK(is_imaging_proximal_ref<decltype(fb.sigma(1e-1))>::value);
78  CHECK(is_imaging_proximal_ref<decltype(fb.residual_convergence(1.001))>::value);
79  CHECK(is_imaging_proximal_ref<decltype(fb.target(Vector<double>::Zero(0)))>::value);
80  using ConvFunc = ConvergenceFunction<double>;
81  CHECK(is_imaging_proximal_ref<decltype(fb.is_converged(std::declval<ConvFunc>()))>::value);
82  CHECK(is_imaging_proximal_ref<decltype(fb.is_converged(std::declval<ConvFunc &>()))>::value);
83  CHECK(is_imaging_proximal_ref<decltype(fb.is_converged(std::declval<ConvFunc &&>()))>::value);
84  CHECK(is_imaging_proximal_ref<decltype(fb.is_converged(std::declval<ConvFunc const &>()))>::value);
85  CHECK(is_imaging_proximal_ref<decltype(fb.relative_variation(5e-4))>::value);
86  CHECK(is_imaging_proximal_ref<decltype(fb.tight_frame(false))>::value);
87 
88  // Test the types of the l1 g_proximal object separately
89  auto gp = std::make_shared<sopt::algorithm::L1GProximal<Scalar>>(false);
90  CHECK(is_l1_g_proximal_ref<decltype(gp->l1_proximal_tolerance(1e-2))>::value);
91  CHECK(is_l1_g_proximal_ref<decltype(gp->l1_proximal_nu(1))>::value);
92  CHECK(is_l1_g_proximal_ref<decltype(gp->l1_proximal_itermax(50))>::value);
93  CHECK(is_l1_g_proximal_ref<decltype(gp->l1_proximal_positivity_constraint(true))>::value);
94  CHECK(is_l1_g_proximal_ref<decltype(gp->l1_proximal_real_constraint(true))>::value);
95  using LinTrans = LinearTransform<Vector<double>>;
96  CHECK(is_l1_g_proximal_ref<decltype(gp->Psi(linear_transform_identity<double>()))>::value);
97  CHECK(is_l1_g_proximal_ref<decltype(gp->Psi(std::declval<LinTrans>()))>::value);
98  CHECK(is_l1_g_proximal_ref<decltype(gp->Psi(std::declval<LinTrans &&>()))>::value);
99  CHECK(is_l1_g_proximal_ref<decltype(gp->Psi(std::declval<LinTrans &>()))>::value);
100  CHECK(is_l1_g_proximal_ref<decltype(gp->Psi(std::declval<LinTrans const &>()))>::value);
101 }
sopt::LinearTransform
Joins together direct and indirect operators.
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
std::function< bool(Vector< SCALAR > const &)> ConvergenceFunction
Typical function signature for convergence.
Definition: types.h:52

References sopt::algorithm::ImagingForwardBackward< SCALAR >::is_converged(), sopt::algorithm::ImagingForwardBackward< SCALAR >::residual_convergence(), and sopt::algorithm::ImagingForwardBackward< SCALAR >::target().

◆ TEST_CASE() [2/2]

TEST_CASE ( "Forward Backward with ||x - x0||_2^2 function", "[fb]" )

Definition at line 31 of file forward_backward.cc.

31  {
32  using namespace sopt;
33  t_Vector const target0 = t_Vector::Random(N);
34  t_real constexpr beta = 0.2;
35  t_real constexpr regulariser_strength = 0.1;
36  int constexpr itermax = 300;
37  auto const g0 = [](t_Vector &out, const t_real regulariser_strength, const t_Vector &x) {
38  proximal::id(out, regulariser_strength, x);
39  };
40  auto const grad = [](t_Vector &out, const t_Vector image, const t_Vector &res,
41  const t_LinearTransform &Phi) { out = Phi.adjoint() * res; };
42  const t_Vector x_guess = t_Vector::Random(target0.size());
43  const t_Vector res = x_guess - target0;
44  auto const convergence = [&target0](const t_Vector &x, const t_Vector &res) -> bool {
45  return x.isApprox(target0, 1e-9);
46  };
47  CAPTURE(target0);
48  CAPTURE(x_guess);
49  CAPTURE(res);
50  auto fb = algorithm::ForwardBackward<Scalar>(grad, g0, target0)
51  .itermax(itermax)
52  .regulariser_strength(regulariser_strength)
53  .step_size(beta)
54  .is_converged(convergence);
55  auto const result = fb(std::make_tuple(x_guess, res));
56  CAPTURE(result.niters);
57  CAPTURE(result.x);
58  CAPTURE(result.residual);
59  CHECK(result.x.isApprox(target0, 1e-9));
60  CHECK(result.good);
61  CHECK(result.niters < itermax);
62 }
sopt::Vector< Scalar > t_Vector
constexpr auto N
void id(Eigen::DenseBase< T0 > &out, typename real_type< typename T0::Scalar >::type gamma, Eigen::DenseBase< T1 > const &x)
Proximal of a function that is always zero, the identity.
Definition: proximal.h:135
double t_real
Root of the type hierarchy for real numbers.
Definition: types.h:17

References sopt::proximal::id(), and N.

Variable Documentation

◆ N

constexpr auto N = 5
constexpr

Definition at line 29 of file forward_backward.cc.

Referenced by TEST_CASE().