SOPT
Sparse OPTimisation
imaging_primal_dual.h
Go to the documentation of this file.
1 #ifndef SOPT_L1_PRIMAL_DUAL_H
2 #define SOPT_L1_PRIMAL_DUAL_H
3 
4 #include "sopt/config.h"
5 #include <limits> // for std::numeric_limits<>
6 #include <numeric>
7 #include <tuple>
8 #include <utility>
9 #include "sopt/exception.h"
10 #include "sopt/l1_proximal.h"
11 #include "sopt/linear_transform.h"
12 #include "sopt/logging.h"
13 #include "sopt/primal_dual.h"
14 #include "sopt/proximal.h"
16 #include "sopt/types.h"
17 
18 namespace sopt::algorithm {
19 template <typename SCALAR>
22  using PD = PrimalDual<SCALAR>;
23 
 public:
  //! Scalar type.
  using value_type = typename PD::value_type;
  //! Scalar type.
  using Scalar = typename PD::Scalar;
  //! Real type.
  using Real = typename PD::Real;
  //! Type of the underlying vectors.
  using t_Vector = typename PD::t_Vector;
  //! Type of the proximal functions, called as (out, strength-or-weights, x) and writing the
  //! result into `out`.
  template <typename T>
  using t_Proximal = std::function<void(t_Vector &, const T &, const t_Vector &)>;
  //! Type of the convergence function, called as (x, residual).
  using t_IsConverged = typename PD::t_IsConverged;
  //! Type of the constraint function, called as (out, x) and writing the constrained x into `out`.
  using t_Constraint = typename PD::t_Constraint;
35 
37  struct Diagnostic : public PD::Diagnostic {
38  Diagnostic(t_uint niters = 0u, bool good = false) : PD::Diagnostic(niters, good) {}
40  : PD::Diagnostic(niters, good, std::move(residual)) {}
41  };
43  struct DiagnosticAndResult : public Diagnostic {
46  };
47 
51  template <typename DERIVED>
52  ImagingPrimalDual(Eigen::MatrixBase<DERIVED> const &target)
53  : l1_proximal_([](t_Vector &out, const Real &regulariser_strength, const t_Vector &x) {
54  proximal::l1_norm<t_Vector, t_Vector>(out, regulariser_strength, x);
55  }),
56  l1_proximal_weighted_([](t_Vector &out, const Vector<Real> &regulariser_strength, const t_Vector &x) {
57  proximal::l1_norm<t_Vector, t_Vector, Vector<Real>>(out, regulariser_strength, x);
58  }),
59  l1_proximal_weights_(Vector<Real>::Ones(1)),
60  l2ball_proximal_(1e0),
61  residual_tolerance_(1e-4),
62  relative_variation_(1e-4),
63  residual_convergence_(nullptr),
64  objective_convergence_(nullptr),
65  itermax_(std::numeric_limits<t_uint>::max()),
66  sigma_(1),
67  tau_(0.5),
68  regulariser_strength_(0.5),
69  update_scale_(1),
70  precondition_stepsize_(0.5),
71  precondition_weights_(t_Vector::Ones(target.size())),
72  precondition_iters_(0),
73  xi_(1),
74  rho_(1),
75  is_converged_(),
76  Phi_(linear_transform_identity<Scalar>()),
77  Psi_(linear_transform_identity<Scalar>()),
78  random_measurement_updater_([]() { return true; }),
79  random_wavelet_updater_([]() { return true; }),
80  target_(target) {}
81  virtual ~ImagingPrimalDual() {}
82 
// Macro helping define chainable properties. For each (NAME, TYPE) it generates
//   - a const getter `TYPE const &NAME() const`,
//   - a setter `ImagingPrimalDual<SCALAR> &NAME(TYPE const &)` returning *this so calls chain,
//   - a protected member `TYPE NAME##_`, declared without an initialiser (so it should be given
//     a value in the constructor's initialiser list).
// Properties can then be set as in
// auto pd = ImagingPrimalDual<float>(target).prop0(value).prop1(value);
#define SOPT_MACRO(NAME, TYPE) \
  TYPE const &NAME() const { return NAME##_; } \
  ImagingPrimalDual<SCALAR> &NAME(TYPE const &(NAME)) { \
    NAME##_ = NAME; \
    return *this; \
  } \
  \
 protected: \
  TYPE NAME##_; \
  \
 public:
  //! The l1 prox with weights functioning as f; used when more than one weight is set.
  SOPT_MACRO(l1_proximal_weighted, t_Proximal<Vector<Real>>);
  //! Weights of the l1 prox. A single weight is applied as a scalar multiplier of the
  //! regulariser strength.
  SOPT_MACRO(l1_proximal_weights, Vector<Real>);
  //! Tolerance on the weighted l2 norm of the residuals, used by the default residual
  //! convergence check. A value <= 0 makes that check always succeed.
  SOPT_MACRO(residual_tolerance, Real);
  //! Tolerance on the relative variation of the objective function, used by the default
  //! objective convergence check.
  SOPT_MACRO(relative_variation, Real);
  //! Maximum number of iterations.
  SOPT_MACRO(itermax, t_uint);
  //! regulariser_strength parameter, forwarded to the underlying PrimalDual.
  SOPT_MACRO(regulariser_strength, Real);
  //! update_scale parameter, forwarded to the underlying PrimalDual.
  SOPT_MACRO(update_scale, Real);
  //! Apply positivity constraint (projection onto the positive quadrant).
  SOPT_MACRO(positivity_constraint, bool);
  //! Apply real constraint (keep only the real part).
  SOPT_MACRO(real_constraint, bool);
  //! Step size of the preconditioning iterations inside the L2-ball proximal.
  SOPT_MACRO(precondition_stepsize, Real);
  //! Element-wise preconditioning weights used inside the L2-ball proximal.
  SOPT_MACRO(precondition_weights, t_Vector);
  //! Number of preconditioning iterations; 0 disables preconditioning.
  SOPT_MACRO(precondition_iters, t_uint);
  //! Lambda that determines whether to update the measurements.
  SOPT_MACRO(random_measurement_updater, t_Random_Updater);
  //! Lambda that determines whether to update the wavelets.
  SOPT_MACRO(random_wavelet_updater, t_Random_Updater);
#ifdef SOPT_MPI
  //! Communicator forwarded to PrimalDual::v_all_sum_all_comm.
  SOPT_MACRO(v_all_sum_all_comm, mpi::Communicator);
  //! Communicator forwarded to PrimalDual::u_all_sum_all_comm.
  SOPT_MACRO(u_all_sum_all_comm, mpi::Communicator);
#endif

#undef SOPT_MACRO
159  t_Vector const &target() const { return target_; }
161  template <typename DERIVED>
162  ImagingPrimalDual<Scalar> &target(Eigen::MatrixBase<DERIVED> const &target) {
163  target_ = target;
164  return *this;
165  }
166 
170  return operator()(out, PD::initial_guess(target(), Phi()));
171  }
175  Diagnostic operator()(t_Vector &out, std::tuple<t_Vector, t_Vector> const &guess) const {
176  return operator()(out, std::get<0>(guess), std::get<1>(guess));
177  }
182  std::tuple<t_Vector const &, t_Vector const &> const &guess) const {
183  return operator()(out, std::get<0>(guess), std::get<1>(guess));
184  }
187  DiagnosticAndResult operator()(std::tuple<t_Vector, t_Vector> const &guess) const {
188  return operator()(std::tie(std::get<0>(guess), std::get<1>(guess)));
189  }
193  std::tuple<t_Vector const &, t_Vector const &> const &guess) const {
194  DiagnosticAndResult result;
195  static_cast<Diagnostic &>(result) = operator()(result.x, guess);
196  return result;
197  }
201  DiagnosticAndResult result;
202  static_cast<Diagnostic &>(result) = operator()(result.x,
204  return result;
205  }
208  DiagnosticAndResult result = warmstart;
209  static_cast<Diagnostic &>(result) = operator()(result.x, warmstart.x, warmstart.residual);
210  return result;
211  }
212 
214  template <typename... ARGS>
215  typename std::enable_if<sizeof...(ARGS) >= 1, ImagingPrimalDual &>::type Phi(ARGS &&... args) {
216  Phi_ = linear_transform(std::forward<ARGS>(args)...);
217  return *this;
218  }
219 
222  proximal::WeightedL2Ball<Scalar> &l2ball_proximal() { return l2ball_proximal_; }
223 
225  template <typename... ARGS>
226  typename std::enable_if<sizeof...(ARGS) >= 1, ImagingPrimalDual &>::type Psi(ARGS &&... args) {
227  Psi_ = linear_transform(std::forward<ARGS>(args)...);
228  return *this;
229  }
230 
// Forwards get/setters to the L2-ball proximal.
// For each (VAR, NAME, PROXIMAL) the macro generates
//   - a const getter `NAME_proximal_VAR()` returning `NAME_proximal().VAR()`, and
//   - a chainable setter `NAME_proximal_VAR(value)` calling `NAME_proximal().VAR(value)`,
// with the value type deduced from the const getter of `proximal::PROXIMAL<Scalar>`.
// E.g.: `pd.l2ball_proximal_epsilon(1e-2).l2ball_proximal_weights(weights)`.
#define SOPT_MACRO(VAR, NAME, PROXIMAL) \
  \
  decltype(std::declval<proximal::PROXIMAL<Scalar> const>().VAR()) NAME##_proximal_##VAR() const { \
    return NAME##_proximal().VAR(); \
  } \
  \
  ImagingPrimalDual<Scalar> &NAME##_proximal_##VAR( \
      decltype(std::declval<proximal::PROXIMAL<Scalar> const>().VAR()) (VAR)) { \
    NAME##_proximal().VAR(VAR); \
    return *this; \
  }
  // Radius of the L2 ball.
  SOPT_MACRO(epsilon, l2ball, WeightedL2Ball);
  // Weights of the weighted L2 ball.
  SOPT_MACRO(weights, l2ball, WeightedL2Ball);
#ifdef SOPT_MPI
  SOPT_MACRO(communicator, l2ball, WeightedL2Ball);
#endif
#undef SOPT_MACRO
253 
256  return residual_convergence(nullptr).residual_tolerance(tolerance);
257  }
260  return objective_convergence(nullptr).relative_variation(tolerance);
261  }
263  ImagingPrimalDual<Scalar> &is_converged(std::function<bool(t_Vector const &x)> const &func) {
264  return is_converged([func](t_Vector const &x, t_Vector const &) { return func(x); });
265  }
266 
 protected:
  //! Vector of target measurements.
  t_Vector target_;
270 
  //! Calls Primal Dual with explicit initial guesses for the solution and the residual.
  Diagnostic operator()(t_Vector &out, t_Vector const &guess, t_Vector const &res) const;

  //! Residual convergence: uses residual_convergence() when set, otherwise compares the
  //! weighted l2 norm of the residual against residual_tolerance().
  bool residual_convergence(t_Vector const &x, t_Vector const &residual) const;
279 
282  t_Vector const &residual) const;
283 
  //! Combined convergence: the user-provided is_converged() (if set), the residual check and the
  //! objective check must all succeed; the residual and objective checks run on every call.
  bool is_converged(ScalarRelativeVariation<Scalar> &scalvar, t_Vector const &x,
                    t_Vector const &residual) const;
288  bool check_l1_weight_proximal(const t_Proximal<Real> &no_weights,
289  const t_Proximal<Vector<Real>> &with_weights) const {
290  const Vector<SCALAR> x = Vector<SCALAR>::Ones(this->l1_proximal_weights().size());
291  Vector<SCALAR> output = Vector<SCALAR>::Zero(this->l1_proximal_weights().size());
292  Vector<SCALAR> outputw = Vector<SCALAR>::Zero(this->l1_proximal_weights().size());
293  no_weights(output, 1, x);
294  with_weights(outputw, Vector<Real>::Ones(this->l1_proximal_weights().size()), x);
295  return output.isApprox(outputw);
296  }
297 };
298 
299 template <typename SCALAR>
300 typename ImagingPrimalDual<SCALAR>::Diagnostic ImagingPrimalDual<SCALAR>::operator()(
301  t_Vector &out, t_Vector const &guess, t_Vector const &res) const {
302  SOPT_HIGH_LOG("Performing Primal Dual with L1 and L2 operators");
303  // The f proximal is an L1 proximal that stores some diagnostic result
304  if (not check_l1_weight_proximal(l1_proximal(), l1_proximal_weighted()))
305  SOPT_THROW(
306  "l1 proximal and weighted l1 proximal appear to be different functions. Please make sure "
307  "both are the same function.");
308  auto const f_proximal = [this](t_Vector &out, Real regulariser_strength, t_Vector const &x) {
309  if (this->l1_proximal_weights().size() > 1)
310  this->l1_proximal_weighted()(out, this->l1_proximal_weights() * regulariser_strength, x);
311  else
312  this->l1_proximal()(out, this->l1_proximal_weights()(0) * regulariser_strength, x);
313  };
314  auto const g_proximal = [this](t_Vector &out, Real regulariser_strength, t_Vector const &x) {
315  this->l2ball_proximal()(out, regulariser_strength, x);
316  // applying preconditioning
317  for (t_int i = 0; i < this->precondition_iters(); i++)
318  this->l2ball_proximal()(
319  out, regulariser_strength,
320  out - this->precondition_stepsize() *
321  (out.array() * this->precondition_weights().array() - x.array()).matrix());
322 
323  if (this->precondition_iters() > 0) out = out.array() * this->precondition_weights().array();
324  };
325  ScalarRelativeVariation<Scalar> scalvar(relative_variation(), relative_variation(),
326  "Objective function");
327  auto const convergence = [this, scalvar](t_Vector const &x, t_Vector const &residual) mutable {
328  return this->is_converged(scalvar, x, residual);
329  };
330  const bool positive = positivity_constraint();
331  const bool real = real_constraint();
332  t_Constraint constraint = [real, positive](t_Vector &out, const t_Vector &x) {
333  if (real) out = x.real();
334  if (positive) out = sopt::positive_quadrant(x);
335  if (not real and not positive) out = x;
336  };
337  auto const pd = PD(f_proximal, g_proximal, target())
338  .itermax(itermax())
339  .constraint(constraint)
340  .sigma(sigma())
341  .tau(tau())
342  .regulariser_strength(regulariser_strength())
343  .update_scale(update_scale())
344  .xi(xi())
345  .rho(rho())
346  .regulariser_strength(regulariser_strength())
347  .Phi(Phi())
348  .Psi(Psi())
349  .random_measurement_updater(random_measurement_updater())
350  .random_wavelet_updater(random_wavelet_updater())
351 #ifdef SOPT_MPI
352  .v_all_sum_all_comm(v_all_sum_all_comm())
353  .u_all_sum_all_comm(u_all_sum_all_comm())
354 #endif
355  .is_converged(convergence);
356  Diagnostic result;
357  static_cast<typename PD::Diagnostic &>(result) = pd(out, std::tie(guess, res));
358  return result;
359 }
360 
361 template <typename SCALAR>
363  t_Vector const &residual) const {
364  if (static_cast<bool>(residual_convergence())) return residual_convergence()(x, residual);
365  if (residual_tolerance() <= 0e0) return true;
366  auto const residual_norm = sopt::l2_norm(residual, l2ball_proximal_weights());
367  SOPT_LOW_LOG(" - [Primal Dual] Residuals: {} <? {}", residual_norm, residual_tolerance());
368  return residual_norm < residual_tolerance();
369 }
370 
371 template <typename SCALAR>
372 bool ImagingPrimalDual<SCALAR>::objective_convergence(ScalarRelativeVariation<Scalar> &scalvar,
373  t_Vector const &x,
374  t_Vector const &residual) const {
375  if (static_cast<bool>(objective_convergence())) return objective_convergence()(x, residual);
376  if (scalvar.relative_tolerance() <= 0e0) return true;
377  auto const current =
378  sopt::l1_norm(static_cast<t_Vector>(Psi().adjoint() * x), l1_proximal_weights());
379  return scalvar(current);
380 }
381 
382 template <typename SCALAR>
383 bool ImagingPrimalDual<SCALAR>::is_converged(ScalarRelativeVariation<Scalar> &scalvar,
384  t_Vector const &x, t_Vector const &residual) const {
385  auto const user = static_cast<bool>(is_converged()) == false or is_converged()(x, residual);
386  auto const res = residual_convergence(x, residual);
387  auto const obj = objective_convergence(scalvar, x, residual);
388  // beware of short-circuiting!
389  // better evaluate each convergence function everytime, especially with mpi
390  return user and res and obj;
391 }
392 } // namespace sopt::algorithm
393 #endif
sopt::Vector< Scalar > t_Vector
SOPT_MACRO(Psi, t_LinearTransform)
Wavelet operator.
std::function< void(t_Vector &, const T &, const t_Vector &)> t_Proximal
SOPT_MACRO(residual_tolerance, Real)
Tolerance on the weighted l2 norm of the residuals, used by the default residual convergence check.
ImagingPrimalDual &::type Psi(ARGS &&... args)
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector, t_Vector > const &guess) const
Calls Primal Dual.
DiagnosticAndResult operator()(std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Primal Dual.
SOPT_MACRO(l1_proximal, t_Proximal< Real >)
The l1 prox functioning as f.
SOPT_MACRO(real_constraint, bool)
Apply real constraint.
SOPT_MACRO(Phi, t_LinearTransform)
Measurement operator.
SOPT_MACRO(itermax, t_uint)
Maximum number of iterations.
SOPT_MACRO(residual_convergence, t_IsConverged)
Convergence of the residuals.
SOPT_MACRO(l2ball_proximal, proximal::WeightedL2Ball< Scalar >)
The weighted L2 proximal functioning as g.
SOPT_MACRO(random_measurement_updater, t_Random_Updater)
Lambda that determines whether to update the measurements.
SOPT_MACRO(epsilon, l2ball, WeightedL2Ball)
DiagnosticAndResult operator()(std::tuple< t_Vector, t_Vector > const &guess) const
Calls Primal Dual.
SOPT_MACRO(weights, l2ball, WeightedL2Ball)
SOPT_MACRO(update_scale, Real)
update parameter
SOPT_MACRO(l1_proximal_weights, Vector< Real >)
The weights of the l1 prox.
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Primal Dual.
ImagingPrimalDual< Scalar > & objective_convergence(Real const &tolerance)
Helper function to set up the default objective convergence function.
typename PD::t_LinearTransform t_LinearTransform
SOPT_MACRO(relative_variation, Real)
Convergence of the relative variation of the objective functions.
typename PD::t_Constraint t_Constraint
DiagnosticAndResult operator()() const
Calls Primal Dual.
SOPT_MACRO(rho, Real)
rho parameter
t_Vector const & target() const
Vector of target measurements.
SOPT_MACRO(precondition_stepsize, Real)
Precondition step size parameter.
SOPT_MACRO(positivity_constraint, bool)
Apply positivity constraint.
SOPT_MACRO(sigma, Real)
sigma parameter
ImagingPrimalDual< Scalar > & is_converged(std::function< bool(t_Vector const &x)> const &func)
Convergence function that takes only the output as argument.
SOPT_MACRO(l1_proximal_weighted, t_Proximal< Vector< Real >>)
The l1 prox with weights functioning as f.
SOPT_MACRO(is_converged, t_IsConverged)
A function verifying convergence.
ImagingPrimalDual(Eigen::MatrixBase< DERIVED > const &target)
SOPT_MACRO(precondition_iters, t_uint)
precondition iterations parameter
ImagingPrimalDual< Scalar > & residual_convergence(Real const &tolerance)
Helper function to set up the default residual convergence function.
proximal::WeightedL2Ball< Scalar > & l2ball_proximal()
Proximal of the L2 ball.
ImagingPrimalDual< Scalar > & target(Eigen::MatrixBase< DERIVED > const &target)
Sets the vector of target measurements.
Diagnostic operator()(t_Vector &out) const
Calls Primal Dual.
SOPT_MACRO(precondition_weights, t_Vector)
precondition weights parameter
SOPT_MACRO(xi, Real)
xi parameter
SOPT_MACRO(objective_convergence, t_IsConverged)
Convergence of the objective function.
SOPT_MACRO(tau, Real)
tau parameter
ImagingPrimalDual &::type Phi(ARGS &&... args)
DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart) const
Makes it simple to chain different calls to PD.
typename PD::t_Random_Updater t_Random_Updater
SOPT_MACRO(random_wavelet_updater, t_Random_Updater)
Lambda that determines whether to update the wavelets.
SOPT_MACRO(regulariser_strength, Real)
regulariser_strength parameter
typename PD::t_IsConverged t_IsConverged
Primal Dual Algorithm.
Definition: primal_dual.h:24
SCALAR value_type
Scalar type.
Definition: primal_dual.h:27
std::function< bool()> t_Random_Updater
Type of random update function.
Definition: primal_dual.h:41
Vector< Scalar > t_Vector
Type of the underlying vectors.
Definition: primal_dual.h:33
LinearTransform< t_Vector > t_LinearTransform
Type of the Ψ and Ψ^H operations, as well as Φ and Φ^H.
Definition: primal_dual.h:35
value_type Scalar
Scalar type.
Definition: primal_dual.h:29
std::function< void(t_Vector &, const t_Vector &)> t_Constraint
Type of the constraint function.
Definition: primal_dual.h:39
std::function< bool(const t_Vector &, const t_Vector &)> t_IsConverged
Type of the convergence function.
Definition: primal_dual.h:37
std::tuple< t_Vector, t_Vector > initial_guess() const
Computes initial guess for x and the residual using the targets.
Definition: primal_dual.h:233
typename real_type< Scalar >::type Real
Real type.
Definition: primal_dual.h:31
#define SOPT_THROW(MSG)
Definition: exception.h:46
#define SOPT_LOW_LOG(...)
Low priority message.
Definition: logging.h:227
#define SOPT_HIGH_LOG(...)
High priority message.
Definition: logging.h:223
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
Eigen::CwiseUnaryOp< const details::ProjectPositiveQuadrant< typename T::Scalar >, const T > positive_quadrant(Eigen::DenseBase< T > const &input)
Expression to create projection onto positive quadrant.
Definition: maths.h:60
int t_int
Root of the type hierarchy for signed integers.
Definition: types.h:13
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
Vector< T > target(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:12
real_type< T >::type epsilon(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:38
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
real_type< T >::type sigma(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:17
real_type< typename T0::Scalar >::type l1_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L1 norm.
Definition: maths.h:116
real_type< typename T0::Scalar >::type l2_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L2 norm.
Definition: maths.h:140
Values indicating how the algorithm ran.
Diagnostic(t_uint niters, bool good, t_Vector &&residual)
Diagnostic(t_uint niters=0u, bool good=false)
Values indicating how the algorithm ran.
Definition: primal_dual.h:46
t_Vector residual
the residual from the last iteration
Definition: primal_dual.h:52
t_uint niters
Number of iterations.
Definition: primal_dual.h:48
bool good
Whether convergence was achieved.
Definition: primal_dual.h:50