SOPT
Sparse OPTimisation
primal_dual.cc
Go to the documentation of this file.
1 #include <catch2/catch_all.hpp>
2 #include <random>
3 
4 #include <cassert>
5 
6 #include <Eigen/Dense>
7 
9 #include "sopt/primal_dual.h"
10 #include "sopt/proximal.h"
11 #include "sopt/types.h"
12 
16 using Catch::Approx;
17 
18 auto constexpr N = 5;
19 
// Imaging primal-dual test on a random target of size N with identity Phi and Psi:
// runs `primaldual`, checks the solution against the target, then installs custom
// l1 proximals and expects the next run to throw.
// NOTE(review): listing lines 27, 31 and 40 are absent from this view, so the statement
// that declares `primaldual` and the end of its setter chain are not visible here.
20 TEST_CASE("Primal Dual Imaging", "[primaldual]") {
21  using namespace sopt;
22 
 // N x N identity matrix, passed below as both Phi and Psi.
23  t_Matrix const mId = t_Matrix::Identity(N, N);
24 
25  t_Vector target = t_Vector::Random(N);
26 
28 
 // Half of the target's norm; passed to l2ball_proximal_epsilon and reused as the bound
 // in the first CHECK.
29  auto const epsilon = target.stableNorm() / 2;
30 
 // Setter chain configuring the solver (its first line is not visible in this listing).
32  .l1_proximal_weights(t_Vector::Ones(target.size()))
33  .Phi(mId)
34  .Psi(mId)
35  .itermax(5000)
36  .tau(0.1)
37  .regulariser_strength(0.4)
38  .l2ball_proximal_epsilon(epsilon)
39  .relative_variation(1e-4)
41 
42  auto const result = primaldual();
 // The solution must be within epsilon of the target (absolute margin 1e-10) and the
 // run must be flagged as good.
43  CHECK((result.x - target).stableNorm() <= Approx(epsilon).margin(1e-10));
44  CHECK(result.good);
 // Replace both l1 proximals with plain scalings of the input.
45  primaldual
46  .l1_proximal([](t_Vector &output, const t_real &regulariser_strength, const t_Vector &input) {
47  output = regulariser_strength * input;
48  })
49  .l1_proximal_weighted(
50  [](t_Vector &output, const Vector<t_real> &regulariser_strength, const t_Vector &input) {
51  output = 10 * regulariser_strength.array() * input.array();
52  });
 // With those proximals installed the run is expected to throw.
 // NOTE(review): the reason it throws is not visible in this file; confirm against the solver.
53  CHECK_THROWS(primaldual());
54 }
55 TEST_CASE("Primal Dual with 0.5 * ||x - x0||_2^2 function", "[primaldual]") {
56  using namespace sopt;
57  t_Vector const target0 = t_Vector::Random(N);
58  auto const f = [](t_Vector &out, const t_real regulariser_strength, const t_Vector &x) {
59  proximal::id(out, regulariser_strength, x);
60  };
61  auto const g = proximal::L2Norm<Scalar>();
62  const t_Vector x_guess = t_Vector::Random(target0.size());
63  const t_Vector res = x_guess - target0;
64  auto const convergence = [&target0](const t_Vector &x, const t_Vector &res) -> bool {
65  return x.isApprox(target0, 1e-9);
66  };
67  CAPTURE(target0);
68  CAPTURE(x_guess);
69  CAPTURE(res);
70  auto const pd = algorithm::PrimalDual<Scalar>(f, g, target0)
71  .itermax(3000)
72  .regulariser_strength(0.9)
73  .rho(0.5)
74  .update_scale(0.5)
75  .is_converged(convergence);
76  auto const result = pd(std::make_tuple(x_guess, res));
77  CAPTURE(result.niters);
78  CAPTURE(result.x);
79  CAPTURE(result.residual);
80  CHECK(result.x.isApprox(target0, 1e-9));
81  CHECK(result.good);
82  CHECK(result.niters < 200);
83 }
84 
85 template <typename T>
86 struct is_primal_dual_ref : public std::is_same<sopt::algorithm::ImagingPrimalDual<double> &, T> {};
// Checks, through is_primal_dual_ref, that each setter called below has return type
// `sopt::algorithm::ImagingPrimalDual<double> &`. The calls sit inside decltype, so they
// are never evaluated.
// NOTE(review): listing line 91, which presumably declares `pd`, is absent from this view.
87 TEST_CASE("Check type returned on setting variables") {
88  // These are compile-time properties, so static_assert would work as well
89  using namespace sopt;
90  using namespace sopt::algorithm;
92  CHECK(is_primal_dual_ref<decltype(pd.itermax(500))>::value);
93  CHECK(is_primal_dual_ref<decltype(pd.sigma(1))>::value);
94  CHECK(is_primal_dual_ref<decltype(pd.tau(1))>::value);
95  CHECK(is_primal_dual_ref<decltype(pd.rho(1))>::value);
96  CHECK(is_primal_dual_ref<decltype(pd.xi(1))>::value);
97  CHECK(is_primal_dual_ref<decltype(pd.regulariser_strength(1e0))>::value);
98  CHECK(is_primal_dual_ref<decltype(pd.update_scale(1e0))>::value);
99  CHECK(is_primal_dual_ref<decltype(pd.positivity_constraint(true))>::value);
100  CHECK(is_primal_dual_ref<decltype(pd.real_constraint(true))>::value);
101 }
sopt::Vector< Scalar > t_Vector
sopt::t_real Scalar
sopt::Matrix< Scalar > t_Matrix
ImagingPrimalDual &::type Psi(ARGS &&... args)
ImagingPrimalDual< Scalar > & residual_convergence(Real const &tolerance)
Helper function to set-up default residual convergence function.
ImagingPrimalDual &::type Phi(ARGS &&... args)
Primal Dual Algorithm.
Definition: primal_dual.h:24
PrimalDual< Scalar > & is_converged(std::function< bool(t_Vector const &x)> const &func)
Convergence function that takes only the output as argument.
Definition: primal_dual.h:156
Proximal for the L2 norm.
Definition: proximal.h:157
void id(Eigen::DenseBase< T0 > &out, typename real_type< typename T0::Scalar >::type gamma, Eigen::DenseBase< T1 > const &x)
Proximal of a function that is always zero, the identity.
Definition: proximal.h:135
Eigen::CwiseUnaryOp< const details::ProjectPositiveQuadrant< typename T::Scalar >, const T > positive_quadrant(Eigen::DenseBase< T > const &input)
Expression to create projection onto positive quadrant.
Definition: maths.h:60
double t_real
Root of the type hierarchy for real numbers.
Definition: types.h:17
Vector< T > target(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:12
real_type< T >::type epsilon(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:38
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > Matrix
A matrix of a given type.
Definition: types.h:29
TEST_CASE("Primal Dual Imaging", "[primaldual]")
Definition: primal_dual.cc:20
constexpr auto N
Definition: primal_dual.cc:18