SOPT
Sparse OPTimisation
imaging_padmm.h
Go to the documentation of this file.
1 #ifndef SOPT_L1_PROXIMAL_ADMM_H
2 #define SOPT_L1_PROXIMAL_ADMM_H
3 
4 #include "sopt/config.h"
5 #include <limits> // for std::numeric_limits<>
6 #include <numeric>
7 #include <tuple>
8 #include <utility>
9 #include "sopt/exception.h"
10 #include "sopt/l1_proximal.h"
11 #include "sopt/linear_transform.h"
12 #include "sopt/logging.h"
13 #include "sopt/padmm.h"
14 #include "sopt/proximal.h"
16 #include "sopt/types.h"
17 
18 namespace sopt::algorithm {
19 template <typename SCALAR>
23 
24  public:
//! Scalar type of the underlying PADMM algorithm.
25  using value_type = typename PADMM::value_type;
//! Scalar type (same as value_type).
26  using Scalar = typename PADMM::Scalar;
//! Real type associated with Scalar.
27  using Real = typename PADMM::Real;
//! Vector type used for the solution, the target measurements and the residuals.
28  using t_Vector = typename PADMM::t_Vector;
//! Type of the proximal functions used by the underlying PADMM algorithm.
30  using t_Proximal = typename PADMM::t_Proximal;
32 
34  struct Diagnostic : public PADMM::Diagnostic {
37  Diagnostic(t_uint niters = 0u, bool good = false,
38  typename proximal::L1<Scalar>::Diagnostic const &l1diag =
40  : PADMM::Diagnostic(niters, good), l1_diagnostic(l1diag) {}
43  : PADMM::Diagnostic(niters, good, std::move(residual)), l1_diagnostic(l1diag) {}
44  };
46  struct DiagnosticAndResult : public Diagnostic {
49  };
50 
54  template <typename DERIVED>
55  ImagingProximalADMM(Eigen::MatrixBase<DERIVED> const &target)
56  : l1_proximal_(),
57  l2ball_proximal_(1e0),
58  tight_frame_(false),
59  residual_tolerance_(1e-4),
60  relative_variation_(1e-4),
61  residual_convergence_(nullptr),
62  objective_convergence_(nullptr),
63  itermax_(std::numeric_limits<t_uint>::max()),
64  regulariser_strength_(1e-8),
65  lagrange_update_scale_(0.9),
66  is_converged_(),
68  target_(target) {}
69  virtual ~ImagingProximalADMM() {}
70 
71 // Defines a chainable property NAME of type TYPE: a const getter `NAME()` returning `TYPE const &`,
72 // a setter `NAME(value)` that stores the value and returns `*this`, and a protected member `NAME##_`.
// The expansion switches to `protected:` to declare that member, then back to `public:`.
// Declarations that follow an expansion are therefore public.
// Usage: auto padmm = ImagingProximalADMM<float>(target).prop0(value).prop1(value);
73 #define SOPT_MACRO(NAME, TYPE) \
74  TYPE const &NAME() const { return NAME##_; } \
75  ImagingProximalADMM<SCALAR> &NAME(TYPE const &(NAME)) { \
76  NAME##_ = NAME; \
77  return *this; \
78  } \
79  \
80  protected: \
81  TYPE NAME##_; \
82  \
83  public:
84 
86  {
87  return &l1_proximal_;
88  }
89 
// Whether Ψ is a tight frame. If so, the L1 proximal is evaluated through `tight_frame(...)`
// instead of the L1 proximal's call operator.
95  SOPT_MACRO(tight_frame, bool);
// Tolerance on the weighted L2 norm of the residual, used by the default residual criterion.
// A non-positive value makes that criterion always pass.
98  SOPT_MACRO(residual_tolerance, Real);
// Tolerance handed to the objective-function variation tracker built in operator().
// A non-positive value makes the default objective criterion always pass.
101  SOPT_MACRO(relative_variation, Real);
// Maximum number of iterations, forwarded to the underlying PADMM.
109  SOPT_MACRO(itermax, t_uint);
// Regulariser strength γ, forwarded to the underlying PADMM.
111  SOPT_MACRO(regulariser_strength, Real);
// Lagrange update scale β, forwarded to the underlying PADMM.
113  SOPT_MACRO(lagrange_update_scale, Real);
118 
// The property macro is no longer needed; SOPT_MACRO is redefined later for the proximal forwarders.
119 #undef SOPT_MACRO
121  t_Vector const &target() const { return target_; }
123  template <typename DERIVED>
124  ImagingProximalADMM<Scalar> &target(Eigen::MatrixBase<DERIVED> const &target) {
125  target_ = target;
126  return *this;
127  }
128 
132  return operator()(out, PADMM::initial_guess(target(), Phi()));
133  }
137  Diagnostic operator()(t_Vector &out, std::tuple<t_Vector, t_Vector> const &guess) const {
138  return operator()(out, std::get<0>(guess), std::get<1>(guess));
139  }
144  std::tuple<t_Vector const &, t_Vector const &> const &guess) const {
145  return operator()(out, std::get<0>(guess), std::get<1>(guess));
146  }
149  DiagnosticAndResult operator()(std::tuple<t_Vector, t_Vector> const &guess) const {
150  return operator()(std::tie(std::get<0>(guess), std::get<1>(guess)));
151  }
155  std::tuple<t_Vector const &, t_Vector const &> const &guess) const {
156  DiagnosticAndResult result;
157  static_cast<Diagnostic &>(result) = operator()(result.x, guess);
158  return result;
159  }
163  DiagnosticAndResult result;
164  static_cast<Diagnostic &>(result) = operator()(result.x,
166  return result;
167  }
170  DiagnosticAndResult result = warmstart;
171  static_cast<Diagnostic &>(result) = operator()(result.x, warmstart.x, warmstart.residual);
172  return result;
173  }
174 
176  template <typename... ARGS>
177  typename std::enable_if<sizeof...(ARGS) >= 1, ImagingProximalADMM &>::type Phi(ARGS &&... args) {
178  Phi_ = linear_transform(std::forward<ARGS>(args)...);
179  return *this;
180  }
181 
184  proximal::L1<Scalar> &l1_proximal() { return l1_proximal_; }
187  proximal::WeightedL2Ball<Scalar> &l2ball_proximal() { return l2ball_proximal_; }
188 
191  t_LinearTransform const &Psi() const { return l1_proximal().Psi(); }
193  template <typename... ARGS>
194  typename std::enable_if<sizeof...(ARGS) >= 1, ImagingProximalADMM<Scalar> &>::type Psi(
195  ARGS &&... args) {
196  l1_proximal().Psi(std::forward<ARGS>(args)...);
197  return *this;
198  }
199 
200 // Forwards getters/setters to the L1 and weighted L2-ball proximals, making it simpler to set or
201 // get values associated with the two proximal operators. Each SOPT_MACRO(VAR, NAME, PROXIMAL)
202 // generates a const getter `NAME_proximal_VAR()` and a chainable setter `NAME_proximal_VAR(value)`.
// Both use the return type of `proximal::PROXIMAL<Scalar>::VAR() const` as their value type.
// Both reach the underlying object through `NAME_proximal()`.
203 // E.g.: `padmm.l1_proximal_itermax(100).l2ball_proximal_epsilon(1e-2).l1_proximal_tolerance(1e-4)`.
204 // ~~~
205 #define SOPT_MACRO(VAR, NAME, PROXIMAL) \
206  \
207  decltype(std::declval<proximal::PROXIMAL<Scalar> const>().VAR()) NAME##_proximal_##VAR() const { \
208  return NAME##_proximal().VAR(); \
209  } \
210  \
211  ImagingProximalADMM<Scalar> &NAME##_proximal_##VAR( \
212  decltype(std::declval<proximal::PROXIMAL<Scalar> const>().VAR()) (VAR)) { \
213  NAME##_proximal().VAR(VAR); \
214  return *this; \
215  }
216  SOPT_MACRO(itermax, l1, L1);
217  SOPT_MACRO(tolerance, l1, L1);
218  SOPT_MACRO(positivity_constraint, l1, L1);
219  SOPT_MACRO(real_constraint, l1, L1);
220  SOPT_MACRO(fista_mixing, l1, L1);
221  SOPT_MACRO(nu, l1, L1);
222  SOPT_MACRO(weights, l1, L1);
223  SOPT_MACRO(epsilon, l2ball, WeightedL2Ball);
224  SOPT_MACRO(weights, l2ball, WeightedL2Ball);
// Communicator forwarders only exist in MPI builds.
225 #ifdef SOPT_MPI
226  SOPT_MACRO(communicator, l2ball, WeightedL2Ball);
227  SOPT_MACRO(direct_space_comm, l1, L1);
228  SOPT_MACRO(adjoint_space_comm, l1, L1);
229 #endif
230 #undef SOPT_MACRO
231 
234  return residual_convergence(nullptr).residual_tolerance(tolerance);
235  }
238  return objective_convergence(nullptr).relative_variation(tolerance);
239  }
241  ImagingProximalADMM<Scalar> &is_converged(std::function<bool(t_Vector const &x)> const &func) {
242  return is_converged([func](t_Vector const &x, t_Vector const &) { return func(x); });
243  }
244 
245  protected:
//! Vector of target measurements (set by the constructor or by `target(...)`).
247  t_Vector target_;
248 
//! Runs the underlying PADMM from initial guesses for x and the residual.
//! \param[out] out: receives the solution x
//! \param[in] guess: initial guess for x
//! \param[in] res: initial residual
//! \return the PADMM diagnostic, together with the diagnostic of the last L1-proximal call
253  Diagnostic operator()(t_Vector &out, t_Vector const &guess, t_Vector const &res) const;
254 
256  template <typename T0, typename T1>
257  typename proximal::L1<Scalar>::Diagnostic l1_proximal(Eigen::MatrixBase<T0> &out, Real regulariser_strength,
258  Eigen::MatrixBase<T1> const &x) const {
259  return l1_proximal_real_constraint()
260  ? call_l1_proximal(out, regulariser_strength, x.real().template cast<typename T1::Scalar>())
261  : call_l1_proximal(out, regulariser_strength, x);
262  }
263 
265  template <typename T0, typename T1>
266  typename proximal::L1<Scalar>::Diagnostic call_l1_proximal(Eigen::MatrixBase<T0> &out, Real regulariser_strength,
267  Eigen::MatrixBase<T1> const &x) const {
268  if (tight_frame()) {
269  l1_proximal().tight_frame(out, regulariser_strength, x);
270  return {0, 0, l1_proximal().objective(x, out, regulariser_strength), true};
271  }
272  return l1_proximal()(out, regulariser_strength, x);
273  }
274 
//! Residual criterion. Defers to the user `residual_convergence` function if one is set; otherwise
//! compares the weighted L2 norm of the residual with `residual_tolerance()`.
276  bool residual_convergence(t_Vector const &x, t_Vector const &residual) const;
277 
//! Objective criterion. Defers to the user `objective_convergence` function if one is set; otherwise
//! feeds the weighted L1 norm of Ψ^† x to `scalvar`.
//! `scalvar` is taken by non-const reference, presumably because it keeps the previous value.
279  bool objective_convergence(ScalarRelativeVariation<Scalar> &scalvar, t_Vector const &x,
280  t_Vector const &residual) const;
281 
//! Combined criterion: true only when the user, residual and objective criteria all pass.
//! All three criteria are evaluated on every call.
283  bool is_converged(ScalarRelativeVariation<Scalar> &scalvar, t_Vector const &x,
284  t_Vector const &residual) const;
285 };
286 
287 template <typename SCALAR>
288 typename ImagingProximalADMM<SCALAR>::Diagnostic ImagingProximalADMM<SCALAR>::operator()(
289  t_Vector &out, t_Vector const &guess, t_Vector const &res) const {
290  SOPT_HIGH_LOG("Performing Proximal ADMM with L1 and L2 operators");
291  // The f proximal is an L1 proximal that stores some diagnostic result
292  Diagnostic result;
293  auto const f_proximal = [this, &result](t_Vector &out, Real regulariser_strength, t_Vector const &x) {
294  result.l1_diagnostic = this->l1_proximal(out, regulariser_strength, x);
295  };
296  auto const g_proximal = [this](t_Vector &out, Real regulariser_strength, t_Vector const &x) {
297  this->l2ball_proximal()(out, regulariser_strength, x);
298  };
299  ScalarRelativeVariation<Scalar> scalvar(relative_variation(), relative_variation(),
300  "Objective function");
301  auto const convergence = [this, scalvar](t_Vector const &x, t_Vector const &residual) mutable {
302  return this->is_converged(scalvar, x, residual);
303  };
304  auto const padmm = PADMM(f_proximal, g_proximal, target())
305  .itermax(itermax())
306  .regulariser_strength(regulariser_strength())
307  .lagrange_update_scale(lagrange_update_scale())
308  .Phi(Phi())
309  .is_converged(convergence);
310  static_cast<typename PADMM::Diagnostic &>(result) = padmm(out, std::tie(guess, res));
311  return result;
312 }
313 
314 template <typename SCALAR>
316  t_Vector const &residual) const {
317  if (static_cast<bool>(residual_convergence())) return residual_convergence()(x, residual);
318  if (residual_tolerance() <= 0e0) return true;
319  auto const residual_norm = sopt::l2_norm(residual, l2ball_proximal_weights());
320  SOPT_LOW_LOG(" - [PADMM] Residuals: {} <? {}", residual_norm, residual_tolerance());
321  return residual_norm < residual_tolerance();
322 }
323 
324 template <typename SCALAR>
325 bool ImagingProximalADMM<SCALAR>::objective_convergence(ScalarRelativeVariation<Scalar> &scalvar,
326  t_Vector const &x,
327  t_Vector const &residual) const {
328  if (static_cast<bool>(objective_convergence())) return objective_convergence()(x, residual);
329  if (scalvar.relative_tolerance() <= 0e0) return true;
330  auto const current =
331  sopt::l1_norm(static_cast<t_Vector>(Psi().adjoint() * x), l1_proximal_weights());
332  return scalvar(current);
333 }
334 
335 template <typename SCALAR>
336 bool ImagingProximalADMM<SCALAR>::is_converged(ScalarRelativeVariation<Scalar> &scalvar,
337  t_Vector const &x, t_Vector const &residual) const {
338  auto const user = static_cast<bool>(is_converged()) == false or is_converged()(x, residual);
339  auto const res = residual_convergence(x, residual);
340  auto const obj = objective_convergence(scalvar, x, residual);
341  // beware of short-circuiting!
342  // better evaluate each convergence function everytime, especially with mpi
343  return user and res and obj;
344 }
345 } // namespace sopt::algorithm
346 #endif
sopt::Vector< Scalar > t_Vector
ImagingProximalADMM(Eigen::MatrixBase< DERIVED > const &target)
Definition: imaging_padmm.h:55
SOPT_MACRO(is_converged, t_IsConverged)
A function verifying convergence.
SOPT_MACRO(weights, l2ball, WeightedL2Ball)
SOPT_MACRO(epsilon, l2ball, WeightedL2Ball)
SOPT_MACRO(residual_tolerance, Real)
Tolerance on the weighted L2 norm of the residuals; if non-positive, the default residual criterion is always satisfied.
proximal::L1< Scalar > & l1_proximal()
L1 proximal used during calculation.
typename PADMM::t_LinearTransform t_LinearTransform
Definition: imaging_padmm.h:29
SOPT_MACRO(fista_mixing, l1, L1)
ImagingProximalADMM< Scalar > &::type Psi(ARGS &&... args)
DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart) const
Makes it simple to chain different calls to PADMM.
DiagnosticAndResult operator()(std::tuple< t_Vector, t_Vector > const &guess) const
Calls Proximal ADMM.
DiagnosticAndResult operator()(std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Proximal ADMM.
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector, t_Vector > const &guess) const
Calls Proximal ADMM.
SOPT_MACRO(objective_convergence, t_IsConverged)
Custom convergence criterion for the objective function; when set, it replaces the default relative-variation check.
ImagingProximalADMM< Scalar > & is_converged(std::function< bool(t_Vector const &x)> const &func)
Convergence function that takes only the output as argument.
ImagingProximalADMM< Scalar > & objective_convergence(Real const &tolerance)
Helper function to set up the default objective convergence function.
t_Vector const & target() const
Vector of target measurements.
typename PADMM::t_Vector t_Vector
Definition: imaging_padmm.h:28
typename PADMM::t_IsConverged t_IsConverged
Definition: imaging_padmm.h:31
SOPT_MACRO(regulariser_strength, Real)
γ parameter.
SOPT_MACRO(tight_frame, bool)
Whether Ψ is a tight-frame or not.
proximal::L1< Scalar > * g_proximal()
Definition: imaging_padmm.h:85
DiagnosticAndResult operator()() const
Calls Proximal ADMM.
ImagingProximalADMM< Scalar > & residual_convergence(Real const &tolerance)
Helper function to set up the default residual convergence function.
ImagingProximalADMM< Scalar > & target(Eigen::MatrixBase< DERIVED > const &target)
Sets the vector of target measurements.
SOPT_MACRO(Phi, t_LinearTransform)
Measurement operator.
SOPT_MACRO(lagrange_update_scale, Real)
Lagrange update scale β
SOPT_MACRO(l1_proximal, proximal::L1< Scalar >)
L1 proximal used during calculation.
SOPT_MACRO(positivity_constraint, l1, L1)
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Proximal ADMM.
typename PADMM::value_type value_type
Definition: imaging_padmm.h:25
SOPT_MACRO(residual_convergence, t_IsConverged)
Custom convergence criterion for the residuals; when set, it replaces the default tolerance check.
SOPT_MACRO(relative_variation, Real)
Convergence of the relative variation of the objective functions.
Diagnostic operator()(t_Vector &out) const
Calls Proximal ADMM.
SOPT_MACRO(real_constraint, l1, L1)
ImagingProximalADMM &::type Phi(ARGS &&... args)
SOPT_MACRO(itermax, t_uint)
Maximum number of iterations.
t_LinearTransform const & Psi() const
Analysis operator Ψ
typename PADMM::t_Proximal t_Proximal
Definition: imaging_padmm.h:30
proximal::WeightedL2Ball< Scalar > & l2ball_proximal()
Proximal of the L2 ball.
SOPT_MACRO(l2ball_proximal, proximal::WeightedL2Ball< Scalar >)
The weighted L2 proximal functioning as g.
Proximal Alternating Direction Method of Multipliers.
Definition: padmm.h:19
SCALAR value_type
Scalar type.
Definition: padmm.h:22
LinearTransform< t_Vector > t_LinearTransform
Type of the Ψ and Ψ^H operations, as well as Φ and Φ^H.
Definition: padmm.h:30
Vector< Scalar > t_Vector
Type of the underlying vectors.
Definition: padmm.h:28
value_type Scalar
Scalar type.
Definition: padmm.h:24
ProximalFunction< Scalar > t_Proximal
Type of the proximal functions.
Definition: padmm.h:34
std::tuple< t_Vector, t_Vector > initial_guess() const
Computes initial guess for x and the residual using the targets.
Definition: padmm.h:183
std::function< bool(const t_Vector &, const t_Vector &)> t_IsConverged
Type of the convergence function.
Definition: padmm.h:32
typename real_type< Scalar >::type Real
Real type.
Definition: padmm.h:26
auto tight_frame(T &&... args) const -> decltype(this->L1TightFrame< Scalar >::operator()(std::forward< T >(args)...))
Special case if Ψ is a tight frame.
Definition: l1_proximal.h:313
LinearTransform< Vector< Scalar > > const & Psi() const
Linear transform applied to input prior to L1 norm.
Definition: l1_proximal.h:302
#define SOPT_LOW_LOG(...)
Low priority message.
Definition: logging.h:227
#define SOPT_HIGH_LOG(...)
High priority message.
Definition: logging.h:223
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
Vector< T > target(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:12
real_type< T >::type epsilon(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:38
LinearTransform< Vector< SCALAR > > linear_transform_identity()
Helper function to create a linear transform that's just the identity.
real_type< typename T0::Scalar >::type l1_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L1 norm.
Definition: maths.h:116
real_type< typename T0::Scalar >::type l2_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L2 norm.
Definition: maths.h:140
Values indicating how the algorithm ran.
Definition: imaging_padmm.h:34
proximal::L1< Scalar >::Diagnostic l1_diagnostic
Diagnostic from calling L1 proximal.
Definition: imaging_padmm.h:36
Diagnostic(t_uint niters, bool good, typename proximal::L1< Scalar >::Diagnostic const &l1diag, t_Vector &&residual)
Definition: imaging_padmm.h:41
Diagnostic(t_uint niters=0u, bool good=false, typename proximal::L1< Scalar >::Diagnostic const &l1diag=typename proximal::L1< Scalar >::Diagnostic())
Definition: imaging_padmm.h:37
Values indicating how the algorithm ran.
Definition: padmm.h:37
bool good
Whether convergence was achieved.
Definition: padmm.h:41
t_uint niters
Number of iterations.
Definition: padmm.h:39
t_Vector residual
The residual from the last iteration.
Definition: padmm.h:43
How did calling L1 go?
Definition: l1_proximal.h:190