SOPT
Sparse OPTimisation
joint_map.h
Go to the documentation of this file.
1 #ifndef SOPT_JOINT_MAP_H
2 #define SOPT_JOINT_MAP_H
3 
4 #include "sopt/config.h"
5 #include <functional>
6 #include <limits>
7 #include <memory> // for std::shared_ptr<>
8 #include <utility> // for std::forward<>
9 #include "sopt/exception.h"
10 #include "sopt/forward_backward.h"
12 #include "sopt/linear_transform.h"
13 #include "sopt/logging.h"
14 #include "sopt/types.h"
15 
16 namespace sopt::algorithm {
17 
//! \brief Joint MAP estimation: alternates between running the wrapped algorithm and
//! re-estimating its regularisation parameter from the current solution.
//! \tparam ALGORITHM algorithm type. As used by operator(), it must provide t_Vector,
//!         DiagnosticAndResult, a regulariser_strength() getter/setter, objmin(), and be
//!         callable both with user arguments and with a previous DiagnosticAndResult.
18 template <typename ALGORITHM>
19 class JointMAP {
//! Vector type of the wrapped algorithm
20  using t_Vector = typename ALGORITHM::t_Vector;
//! Regularisation term: maps a solution vector to a real value
21  using t_Reg_Term = typename std::function<t_real (const t_Vector &)>;
//! Result/diagnostic type returned by the wrapped algorithm
22  using ResultType = typename ALGORITHM::DiagnosticAndResult;
//! Extra convergence predicate, called as (x, residual, regularisation parameter)
24  using t_IsConverged = std::function<bool (const t_Vector &, const t_Vector &, const t_real)>;
25 
26  public:
//! Holds the wrapped algorithm's result together with the JMAP diagnostics.
//! NOTE(review): operator() also assigns reg_term and reg_niters on this struct; their
//! declarations are not among the lines shown in this listing — confirm them in the header.
28  struct DiagnosticAndResultReg : public ResultType {
//! Whether the JMAP outer loop converged
31  bool reg_good;
34  };
35 
36  JointMAP(const std::shared_ptr<ALGORITHM> &algo_ptr, const t_Reg_Term &reg_term,
37  const t_uint number_of_wavelet_coeffs)
38  : algo_ptr_(algo_ptr),
39  reg_term_(reg_term),
40  alpha_(1),
41  beta_(1),
42  k_(1),
43  number_of_wavelet_coeffs_(number_of_wavelet_coeffs),
44  is_converged_([](t_Vector const &, t_Vector const &, t_real const) { return true; }),
45  relative_variation_(1e-3),
46  objective_variation_(1e-3),
47  itermax_(std::numeric_limits<t_uint>::max()){}
48 
//! For member NAME of type TYPE, generates a const getter NAME(), a chainable setter
//! NAME(value) returning *this, and the protected storage member NAME##_.
//! The expansion ends with `public:`, so declarations that follow remain public.
49 #define SOPT_MACRO(NAME, TYPE) \
50  TYPE const &NAME() const { return NAME##_; } \
51  JointMAP<ALGORITHM> &NAME(TYPE const &(NAME)) { \
52  NAME##_ = NAME; \
53  return *this; \
54  } \
55  \
56  protected: \
57  TYPE NAME##_; \
58  \
59  public:
60 
//! NOTE(review): beta() and k(), used by operator(), are not among the accessors listed in
//! this listing — confirm they are declared with SOPT_MACRO in the header.
//! Maximum number of iterations
62  SOPT_MACRO(itermax, t_uint);
//! Alpha parameter
64  SOPT_MACRO(alpha, t_real);
//! Number of wavelet coefficients
70  SOPT_MACRO(number_of_wavelet_coeffs, t_uint);
//! Shared pointer to the algorithm
72  SOPT_MACRO(algo_ptr, std::shared_ptr<ALGORITHM>);
//! Regularisation term
74  SOPT_MACRO(reg_term, t_Reg_Term);
//! Relative variation of the regularisation parameter
76  SOPT_MACRO(relative_variation, t_real);
//! Relative variation of the objective function
78  SOPT_MACRO(objective_variation, t_real);
//! A function verifying convergence
81  SOPT_MACRO(is_converged, t_IsConverged);
82 #undef SOPT_MACRO
83  protected:
85  void sanity_check(t_real const &regulariser_strength, t_real const beta, t_real const alpha) const {
86  if (regulariser_strength < 0) SOPT_THROW("Starting regularisation parameter not positive.");
87  if (alpha < 0) SOPT_THROW("Alpha parameter not positive.");
88  if (beta <= 0) SOPT_THROW("Beta not positive.");
89  }
90 
91  public:
94  template <typename... ARGS>
95  DiagnosticAndResultReg operator()(ARGS &&... args) const {
96  SOPT_HIGH_LOG("Performing Joint MAP estimation");
97 
98  ScalarRelativeVariation<t_real> scalvar(relative_variation(), relative_variation(),
99  "Regularisation Parameter");
100  ScalarRelativeVariation<t_real> objvar(objective_variation(), objective_variation(),
101  "Joint Objective Function");
102  sanity_check(this->algo_ptr_->regulariser_strength(), beta(), alpha());
103  t_uint niters(0);
104  bool converged = false;
105  using ResultType = typename ALGORITHM::DiagnosticAndResult;
106  ResultType result = (*(this->algo_ptr_))(std::forward<ARGS>(args)...);
107  t_real regulariser_strength = 0;
108  niters++;
109  t_uint algo_iters(result.niters);
110  for (; (not converged) && (niters < itermax()); ++niters) {
111  SOPT_LOW_LOG(" - [JMAP] Iteration {}/{}", niters, itermax());
112  regulariser_strength = (static_cast<t_real>(number_of_wavelet_coeffs()) / k() + alpha()) /
113  (this->reg_term()(result.x) + beta());
114  SOPT_LOW_LOG(" - [JMAP] Regularisation Parameter Value {}", regulariser_strength);
115  algo_ptr_->regulariser_strength(regulariser_strength);
116  result = (*algo_ptr_)(result);
117  converged = result.good and scalvar(regulariser_strength) and objvar(algo_ptr_->objmin()) and
118  this->is_converged()(result.x, result.residual, regulariser_strength);
119  algo_iters += result.niters;
120  }
121 
122  if (converged) {
123  SOPT_MEDIUM_LOG(" - [JMAP] converged in {} of {} iterations", niters, itermax());
124  } else {
125  // not meaningful if not convergence function
126  SOPT_ERROR(" - [JMAP] did not converge within {} iterations", itermax());
127  }
128  SOPT_MEDIUM_LOG(" - Total Algorithm iterations {} ", algo_iters);
129  DiagnosticAndResultReg diagnostic;
130  static_cast<ResultType &>(diagnostic) = result;
131  diagnostic.niters = algo_iters;
132  diagnostic.reg_good = converged;
133  diagnostic.reg_niters = niters;
134  diagnostic.reg_term = regulariser_strength;
135  return diagnostic;
136  }
137 };
138 
139 } // namespace sopt::algorithm
140 
141 #endif
sopt::Vector< Scalar > t_Vector
JointMAP(const std::shared_ptr< ALGORITHM > &algo_ptr, const t_Reg_Term &reg_term, const t_uint number_of_wavelet_coeffs)
Definition: joint_map.h:36
SOPT_MACRO(objective_variation, t_real)
Relative variation of the objective function.
SOPT_MACRO(alpha, t_real)
Alpha parameter.
SOPT_MACRO(beta, t_real)
Beta parameter.
SOPT_MACRO(relative_variation, t_real)
Relative variation of the regularisation parameter.
SOPT_MACRO(reg_term, t_Reg_Term)
Regularisation term.
SOPT_MACRO(algo_ptr, std::shared_ptr< ALGORITHM >)
Shared pointer to the algorithm.
DiagnosticAndResultReg operator()(ARGS &&... args) const
Calls Joint MAP estimation.
Definition: joint_map.h:95
SOPT_MACRO(itermax, t_uint)
Maximum number of iterations.
SOPT_MACRO(is_converged, t_IsConverged)
A function verifying convergence.
SOPT_MACRO(number_of_wavelet_coeffs, t_uint)
Number of wavelet coefficients.
#define SOPT_THROW(MSG)
Definition: exception.h:46
#define SOPT_LOW_LOG(...)
Low priority message.
Definition: logging.h:227
#define SOPT_HIGH_LOG(...)
High priority message.
Definition: logging.h:223
#define SOPT_ERROR(...)
Something is definitely wrong; the algorithm exits.
Definition: logging.h:211
#define SOPT_MEDIUM_LOG(...)
Medium priority message.
Definition: logging.h:225
double t_real
Root of the type hierarchy for real numbers.
Definition: types.h:17
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
Holds the algorithm results and the regularisation parameter.
Definition: joint_map.h:28
bool reg_good
Whether convergence was achieved.
Definition: joint_map.h:31
t_uint reg_niters
Number of iterations.
Definition: joint_map.h:33