1 #ifndef SOPT_JOINT_MAP_H
2 #define SOPT_JOINT_MAP_H
4 #include "sopt/config.h"
18 template <
typename ALGORITHM>
21 using t_Reg_Term =
typename std::function<
t_real (
const t_Vector &)>;
22 using ResultType =
typename ALGORITHM::DiagnosticAndResult;
24 using t_IsConverged = std::function<bool (
const t_Vector &,
const t_Vector &,
const t_real)>;
36 JointMAP(
const std::shared_ptr<ALGORITHM> &algo_ptr,
const t_Reg_Term ®_term,
37 const t_uint number_of_wavelet_coeffs)
38 : algo_ptr_(algo_ptr),
43 number_of_wavelet_coeffs_(number_of_wavelet_coeffs),
44 is_converged_([](t_Vector const &, t_Vector const &,
t_real const) {
return true; }),
45 relative_variation_(1e-3),
46 objective_variation_(1e-3),
47 itermax_(std::numeric_limits<t_uint>::max()){}
49 #define SOPT_MACRO(NAME, TYPE) \
50 TYPE const &NAME() const { return NAME##_; } \
51 JointMAP<ALGORITHM> &NAME(TYPE const &(NAME)) { \
85 void sanity_check(
t_real const ®ulariser_strength,
t_real const beta,
t_real const alpha)
const {
86 if (regulariser_strength < 0)
SOPT_THROW(
"Starting regularisation parameter not positive.");
87 if (alpha < 0)
SOPT_THROW(
"Alpha parameter not positive.");
88 if (beta <= 0)
SOPT_THROW(
"Beta not positive.");
94 template <
typename... ARGS>
99 "Regularisation Parameter");
101 "Joint Objective Function");
102 sanity_check(this->algo_ptr_->regulariser_strength(), beta(), alpha());
104 bool converged =
false;
105 using ResultType =
typename ALGORITHM::DiagnosticAndResult;
106 ResultType result = (*(this->algo_ptr_))(std::forward<ARGS>(args)...);
107 t_real regulariser_strength = 0;
109 t_uint algo_iters(result.niters);
110 for (; (not converged) && (niters < itermax()); ++niters) {
111 SOPT_LOW_LOG(
" - [JMAP] Iteration {}/{}", niters, itermax());
112 regulariser_strength = (
static_cast<t_real>(number_of_wavelet_coeffs()) / k() + alpha()) /
113 (this->reg_term()(result.x) + beta());
114 SOPT_LOW_LOG(
" - [JMAP] Regularisation Parameter Value {}", regulariser_strength);
115 algo_ptr_->regulariser_strength(regulariser_strength);
116 result = (*algo_ptr_)(result);
117 converged = result.good and scalvar(regulariser_strength) and objvar(algo_ptr_->objmin()) and
118 this->is_converged()(result.x, result.residual, regulariser_strength);
119 algo_iters += result.niters;
123 SOPT_MEDIUM_LOG(
" - [JMAP] converged in {} of {} iterations", niters, itermax());
126 SOPT_ERROR(
" - [JMAP] did not converge within {} iterations", itermax());
130 static_cast<ResultType &
>(diagnostic) = result;
131 diagnostic.niters = algo_iters;
134 diagnostic.
reg_term = regulariser_strength;
sopt::Vector< Scalar > t_Vector
JointMAP(const std::shared_ptr< ALGORITHM > &algo_ptr, const t_Reg_Term &reg_term, const t_uint number_of_wavelet_coeffs)
SOPT_MACRO(objective_variation, t_real)
Relative variation of the objective function.
SOPT_MACRO(alpha, t_real)
Alpha parameter.
SOPT_MACRO(beta, t_real)
Beta parameter.
SOPT_MACRO(relative_variation, t_real)
Relative variation of the regularisation parameter.
SOPT_MACRO(reg_term, t_Reg_Term)
Regularisation term.
SOPT_MACRO(algo_ptr, std::shared_ptr< ALGORITHM >)
Shared ptr with algorithm.
DiagnosticAndResultReg operator()(ARGS &&... args) const
Calls Joint MAP estimation.
SOPT_MACRO(itermax, t_uint)
Maximum number of iterations.
SOPT_MACRO(is_converged, t_IsConverged)
A function verifying convergence.
SOPT_MACRO(number_of_wavelet_coeffs, t_uint)
Number of wavelet coefficients.
#define SOPT_LOW_LOG(...)
Low priority message.
#define SOPT_HIGH_LOG(...)
High priority message.
#define SOPT_ERROR(...)
Something is definitely wrong; the algorithm exits.
#define SOPT_MEDIUM_LOG(...)
Medium priority message.
double t_real
Root of the type hierarchy for real numbers.
size_t t_uint
Root of the type hierarchy for unsigned integers.
Holds the results and the regularisation parameter.
bool reg_good
Whether convergence was achieved.
t_uint reg_niters
Number of iterations.