SOPT
Sparse OPTimisation
tv_primal_dual.h
Go to the documentation of this file.
1 #ifndef SOPT_TV_PRIMAL_DUAL_H
2 #define SOPT_TV_PRIMAL_DUAL_H
3 
4 #include "sopt/config.h"
5 #include <limits> // for std::numeric_limits<>
6 #include <numeric>
7 #include <tuple>
8 #include <utility>
9 #include "sopt/exception.h"
10 #include "sopt/linear_transform.h"
11 #include "sopt/logging.h"
12 #include "sopt/primal_dual.h"
13 #include "sopt/proximal.h"
15 #include "sopt/types.h"
16 
17 namespace sopt::algorithm {
18 template <typename SCALAR>
19 class TVPrimalDual {
21  using PD = PrimalDual<SCALAR>;
22 
23  public:
24  using value_type = typename PD::value_type;
25  using Scalar = typename PD::Scalar;
26  using Real = typename PD::Real;
27  using t_Vector = typename PD::t_Vector;
29  template <typename T>
30  using t_Proximal = std::function<void(t_Vector &, const T &, const t_Vector &)>;
31  using t_IsConverged = typename PD::t_IsConverged;
32  using t_Constraint = typename PD::t_Constraint;
34 
36  struct Diagnostic : public PD::Diagnostic {
37  Diagnostic(t_uint niters = 0u, bool good = false) : PD::Diagnostic(niters, good) {}
39  : PD::Diagnostic(niters, good, std::move(residual)) {}
40  };
42  struct DiagnosticAndResult : public Diagnostic {
45  };
46 
50  template <typename DERIVED>
51  TVPrimalDual(Eigen::MatrixBase<DERIVED> const &target)
52  : tv_proximal_([](t_Vector &out, const Real &regulariser_strength, const t_Vector &x) {
53  proximal::tv_norm<t_Vector, t_Vector>(out, regulariser_strength, x);
54  }),
55  tv_proximal_weighted_([](t_Vector &out, const Vector<Real> &regulariser_strength, const t_Vector &x) {
56  proximal::tv_norm<t_Vector, t_Vector, Vector<Real>>(out, regulariser_strength, x);
57  }),
58  tv_proximal_weights_(Vector<Real>::Ones(1)),
59  l2ball_proximal_(1e0),
60  residual_tolerance_(1e-4),
61  relative_variation_(1e-4),
62  residual_convergence_(nullptr),
63  objective_convergence_(nullptr),
64  itermax_(std::numeric_limits<t_uint>::max()),
65  sigma_(1),
66  tau_(0.5),
67  regulariser_strength_(0.5),
68  update_scale_(1),
69  precondition_stepsize_(0.5),
70  precondition_weights_(t_Vector::Ones(target.size())),
71  precondition_iters_(0),
72  xi_(1),
73  rho_(1),
74  is_converged_(),
75  Phi_(linear_transform_identity<Scalar>()),
76  Psi_(linear_transform_identity<Scalar>()),
77  random_measurement_updater_([]() { return true; }),
78  random_wavelet_updater_([]() { return true; }),
79  target_(target) {}
80  virtual ~TVPrimalDual() {}
81 
// Each invocation generates one chainable property: a const getter NAME(), a setter
// NAME(value) that returns *this, and the protected storage NAME##_. This allows
//   auto pd = TVPrimalDual<float>(target).prop0(value).prop1(value);
#define SOPT_MACRO(NAME, TYPE)                         \
  TYPE const &NAME() const { return this->NAME##_; }   \
  TVPrimalDual<SCALAR> &NAME(TYPE const &new_value) {  \
    this->NAME##_ = new_value;                         \
    return *this;                                      \
  }                                                    \
                                                       \
 protected:                                            \
  TYPE NAME##_;                                        \
                                                       \
 public:
98  SOPT_MACRO(tv_proximal_weighted, t_Proximal<Vector<Real>>);
100  SOPT_MACRO(tv_proximal_weights, Vector<Real>);
105  SOPT_MACRO(residual_tolerance, Real);
108  SOPT_MACRO(relative_variation, Real);
116  SOPT_MACRO(itermax, t_uint);
118  SOPT_MACRO(regulariser_strength, Real);
120  SOPT_MACRO(update_scale, Real);
122  SOPT_MACRO(positivity_constraint, bool);
124  SOPT_MACRO(real_constraint, bool);
134  SOPT_MACRO(precondition_stepsize, Real);
136  SOPT_MACRO(precondition_weights, t_Vector);
138  SOPT_MACRO(precondition_iters, t_uint);
146  SOPT_MACRO(random_measurement_updater, t_Random_Updater);
148  SOPT_MACRO(random_wavelet_updater, t_Random_Updater);
149 #ifdef SOPT_MPI
151  SOPT_MACRO(v_all_sum_all_comm, mpi::Communicator);
153  SOPT_MACRO(u_all_sum_all_comm, mpi::Communicator);
154 #endif
155 
156 #undef SOPT_MACRO
158  t_Vector const &target() const { return target_; }
160  template <typename DERIVED>
161  TVPrimalDual<Scalar> &target(Eigen::MatrixBase<DERIVED> const &target) {
162  target_ = target;
163  return *this;
164  }
165 
169  return operator()(out, PD::initial_guess(target(), Phi()));
170  }
174  Diagnostic operator()(t_Vector &out, std::tuple<t_Vector, t_Vector> const &guess) const {
175  return operator()(out, std::get<0>(guess), std::get<1>(guess));
176  }
181  std::tuple<t_Vector const &, t_Vector const &> const &guess) const {
182  return operator()(out, std::get<0>(guess), std::get<1>(guess));
183  }
186  DiagnosticAndResult operator()(std::tuple<t_Vector, t_Vector> const &guess) const {
187  return operator()(std::tie(std::get<0>(guess), std::get<1>(guess)));
188  }
192  std::tuple<t_Vector const &, t_Vector const &> const &guess) const {
193  DiagnosticAndResult result;
194  static_cast<Diagnostic &>(result) = operator()(result.x, guess);
195  return result;
196  }
200  DiagnosticAndResult result;
201  static_cast<Diagnostic &>(result) = operator()(result.x,
203  return result;
204  }
207  DiagnosticAndResult result = warmstart;
208  static_cast<Diagnostic &>(result) = operator()(result.x, warmstart.x, warmstart.residual);
209  return result;
210  }
211 
213  template <typename... ARGS>
214  typename std::enable_if<sizeof...(ARGS) >= 1, TVPrimalDual &>::type Phi(ARGS &&... args) {
215  Phi_ = linear_transform(std::forward<ARGS>(args)...);
216  return *this;
217  }
218 
221  proximal::WeightedL2Ball<Scalar> &l2ball_proximal() { return l2ball_proximal_; }
222 
224  template <typename... ARGS>
225  typename std::enable_if<sizeof...(ARGS) >= 1, TVPrimalDual &>::type Psi(ARGS &&... args) {
226  Psi_ = linear_transform(std::forward<ARGS>(args)...);
227  return *this;
228  }
229 
// Forwards getters and setters to a proximal operator owned by this class.
// Each invocation generates NAME_proximal_VAR() and NAME_proximal_VAR(value), which read or
// write the VAR property of the object returned by NAME_proximal(); the setter returns
// *this so that calls can be chained.
// E.g.: `pd.l2ball_proximal_epsilon(1e-2)`.
#define SOPT_MACRO(VAR, NAME, PROXIMAL)                                                    \
                                                                                           \
  decltype(std::declval<proximal::PROXIMAL<Scalar> const>().VAR()) NAME##_proximal_##VAR() \
      const {                                                                              \
    return this->NAME##_proximal().VAR();                                                  \
  }                                                                                        \
                                                                                           \
  TVPrimalDual<Scalar> &NAME##_proximal_##VAR(                                             \
      decltype(std::declval<proximal::PROXIMAL<Scalar> const>().VAR()) new_value) {        \
    this->NAME##_proximal().VAR(new_value);                                                \
    return *this;                                                                          \
  }
246  SOPT_MACRO(epsilon, l2ball, WeightedL2Ball);
247  SOPT_MACRO(weights, l2ball, WeightedL2Ball);
248 #ifdef SOPT_MPI
249  SOPT_MACRO(communicator, l2ball, WeightedL2Ball);
250 #endif
251 #undef SOPT_MACRO
252 
255  return residual_convergence(nullptr).residual_tolerance(tolerance);
256  }
259  return objective_convergence(nullptr).relative_variation(tolerance);
260  }
262  TVPrimalDual<Scalar> &is_converged(std::function<bool(t_Vector const &x)> const &func) {
263  return is_converged([func](t_Vector const &x, t_Vector const &) { return func(x); });
264  }
265 
266  protected:
268  t_Vector target_;
269 
274  Diagnostic operator()(t_Vector &out, t_Vector const &guess, t_Vector const &res) const;
275 
277  bool residual_convergence(t_Vector const &x, t_Vector const &residual) const;
278 
281  t_Vector const &residual) const;
282 
284  bool is_converged(ScalarRelativeVariation<Scalar> &scalvar, t_Vector const &x,
285  t_Vector const &residual) const;
287  bool check_tv_weight_proximal(const t_Proximal<Real> &no_weights,
288  const t_Proximal<Vector<Real>> &with_weights) const {
289  const Vector<SCALAR> x = Vector<SCALAR>::Ones(this->tv_proximal_weights().size());
290  Vector<SCALAR> output = Vector<SCALAR>::Zero(this->tv_proximal_weights().size());
291  Vector<SCALAR> outputw = Vector<SCALAR>::Zero(this->tv_proximal_weights().size());
292  no_weights(output, 1, x);
293  with_weights(outputw, Vector<Real>::Ones(this->tv_proximal_weights().size()), x);
294  return output.isApprox(outputw);
295  };
296 };
297 
298 template <typename SCALAR>
299 typename TVPrimalDual<SCALAR>::Diagnostic TVPrimalDual<SCALAR>::operator()(
300  t_Vector &out, t_Vector const &guess, t_Vector const &res) const {
301  SOPT_HIGH_LOG("Performing Primal Dual with TV and L2 operators");
302  // The f proximal is an L1 proximal that stores some diagnostic result
303  if (not check_tv_weight_proximal(tv_proximal(), tv_proximal_weighted()))
304  SOPT_THROW(
305  "tv proximal and weighted tv proximal appear to be different functions. Please make sure "
306  "both are the same function.");
307  auto const f_proximal = [this](t_Vector &out, Real regulariser_strength, t_Vector const &x) {
308  if (this->tv_proximal_weights().size() > 1)
309  this->tv_proximal_weighted()(out, this->tv_proximal_weights() * regulariser_strength, x);
310  else
311  this->tv_proximal()(out, this->tv_proximal_weights()(0) * regulariser_strength, x);
312  };
313  auto const g_proximal = [this](t_Vector &out, Real regulariser_strength, t_Vector const &x) {
314  this->l2ball_proximal()(out, regulariser_strength, x);
315  // applying preconditioning
316  for (t_int i = 0; i < this->precondition_iters(); i++)
317  this->l2ball_proximal()(
318  out, regulariser_strength,
319  out - this->precondition_stepsize() *
320  (out.array() * this->precondition_weights().array() - x.array()).matrix());
321 
322  if (this->precondition_iters() > 0) out = out.array() * this->precondition_weights().array();
323  };
324  ScalarRelativeVariation<Scalar> scalvar(relative_variation(), relative_variation(),
325  "Objective function");
326  auto const convergence = [this, scalvar](t_Vector const &x, t_Vector const &residual) mutable {
327  return this->is_converged(scalvar, x, residual);
328  };
329  const bool positive = positivity_constraint();
330  const bool real = real_constraint();
331  t_Constraint constraint = [real, positive](t_Vector &out, const t_Vector &x) {
332  if (real) out.real() = x.real();
333  if (positive) out = sopt::positive_quadrant(x);
334  if (not real and not positive) out = x;
335  };
336  auto const pd = PD(f_proximal, g_proximal, target())
337  .itermax(itermax())
338  .constraint(constraint)
339  .sigma(sigma())
340  .tau(tau())
341  .regulariser_strength(regulariser_strength())
342  .update_scale(update_scale())
343  .xi(xi())
344  .rho(rho())
345  .regulariser_strength(regulariser_strength())
346  .Phi(Phi())
347  .Psi(Psi())
348  .random_measurement_updater(random_measurement_updater())
349  .random_wavelet_updater(random_wavelet_updater())
350 #ifdef SOPT_MPI
351  .v_all_sum_all_comm(v_all_sum_all_comm())
352  .u_all_sum_all_comm(u_all_sum_all_comm())
353 #endif
354  .is_converged(convergence);
355  Diagnostic result;
356  static_cast<typename PD::Diagnostic &>(result) = pd(out, std::tie(guess, res));
357  return result;
358 }
359 
360 template <typename SCALAR>
361 bool TVPrimalDual<SCALAR>::residual_convergence(t_Vector const &x, t_Vector const &residual) const {
362  if (static_cast<bool>(residual_convergence())) return residual_convergence()(x, residual);
363  if (residual_tolerance() <= 0e0) return true;
364  auto const residual_norm = sopt::l2_norm(residual, l2ball_proximal_weights());
365  SOPT_LOW_LOG(" - [Primal Dual] Residuals: {} <? {}", residual_norm, residual_tolerance());
366  return residual_norm < residual_tolerance();
367 };
368 
369 template <typename SCALAR>
370 bool TVPrimalDual<SCALAR>::objective_convergence(ScalarRelativeVariation<Scalar> &scalvar,
371  t_Vector const &x,
372  t_Vector const &residual) const {
373  if (static_cast<bool>(objective_convergence())) return objective_convergence()(x, residual);
374  if (scalvar.relative_tolerance() <= 0e0) return true;
375  auto const current =
376  sopt::tv_norm(static_cast<t_Vector>(Psi().adjoint() * x), tv_proximal_weights());
377  return scalvar(current);
378 };
379 
380 template <typename SCALAR>
381 bool TVPrimalDual<SCALAR>::is_converged(ScalarRelativeVariation<Scalar> &scalvar, t_Vector const &x,
382  t_Vector const &residual) const {
383  auto const user = static_cast<bool>(is_converged()) == false or is_converged()(x, residual);
384  auto const res = residual_convergence(x, residual);
385  auto const obj = objective_convergence(scalvar, x, residual);
386  // beware of short-circuiting!
387  // better evaluate each convergence function everytime, especially with mpi
388  return user and res and obj;
389 }
390 } // namespace sopt::algorithm
391 #endif
sopt::Vector< Scalar > t_Vector
Primal Dual Algorithm.
Definition: primal_dual.h:24
SCALAR value_type
Scalar type.
Definition: primal_dual.h:27
std::function< bool()> t_Random_Updater
Type of random update function.
Definition: primal_dual.h:41
Vector< Scalar > t_Vector
Type of the underlying vectors.
Definition: primal_dual.h:33
LinearTransform< t_Vector > t_LinearTransform
Type of the Ψ and Ψ^H operations, as well as Φ and Φ^H.
Definition: primal_dual.h:35
value_type Scalar
Scalar type.
Definition: primal_dual.h:29
std::function< void(t_Vector &, const t_Vector &)> t_Constraint
Type of the constraint function.
Definition: primal_dual.h:39
std::function< bool(const t_Vector &, const t_Vector &)> t_IsConverged
Type of the convergence function.
Definition: primal_dual.h:37
std::tuple< t_Vector, t_Vector > initial_guess() const
Computes initial guess for x and the residual using the targets.
Definition: primal_dual.h:233
typename real_type< Scalar >::type Real
Real type.
Definition: primal_dual.h:31
TVPrimalDual< Scalar > & target(Eigen::MatrixBase< DERIVED > const &target)
Sets the vector of target measurements.
SOPT_MACRO(l2ball_proximal, proximal::WeightedL2Ball< Scalar >)
The weighted L2 proximal functioning as g.
TVPrimalDual(Eigen::MatrixBase< DERIVED > const &target)
SOPT_MACRO(precondition_stepsize, Real)
precondition step size parameter
SOPT_MACRO(epsilon, l2ball, WeightedL2Ball)
typename PD::t_Random_Updater t_Random_Updater
SOPT_MACRO(relative_variation, Real)
Convergence of the relative variation of the objective functions.
SOPT_MACRO(objective_convergence, t_IsConverged)
User-supplied convergence function for the objective; when set, it replaces the default relative-variation check.
Diagnostic operator()(t_Vector &out) const
Calls Primal Dual.
SOPT_MACRO(random_measurement_updater, t_Random_Updater)
lambda that determines if to update measurements
SOPT_MACRO(update_scale, Real)
update parameter
SOPT_MACRO(itermax, t_uint)
Maximum number of iterations.
SOPT_MACRO(xi, Real)
xi parameter
typename PD::Scalar Scalar
TVPrimalDual &::type Phi(ARGS &&... args)
SOPT_MACRO(Phi, t_LinearTransform)
Measurement operator.
SOPT_MACRO(sigma, Real)
sigma parameter
SOPT_MACRO(weights, l2ball, WeightedL2Ball)
SOPT_MACRO(precondition_weights, t_Vector)
precondition weights parameter
typename PD::value_type value_type
SOPT_MACRO(precondition_iters, t_uint)
precondition iterations parameter
t_Vector const & target() const
Vector of target measurements.
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector, t_Vector > const &guess) const
Calls Primal Dual.
typename PD::t_IsConverged t_IsConverged
SOPT_MACRO(is_converged, t_IsConverged)
A function verifying convergence.
SOPT_MACRO(residual_tolerance, Real)
Tolerance on the weighted L2 norm of the residual; the residual criterion is disabled when it is not positive.
TVPrimalDual< Scalar > & residual_convergence(Real const &tolerance)
Helper function to set-up default residual convergence function.
SOPT_MACRO(tv_proximal_weights, Vector< Real >)
The tv prox weights functioning.
SOPT_MACRO(real_constraint, bool)
Apply real constraint.
SOPT_MACRO(tau, Real)
tau parameter
DiagnosticAndResult operator()(std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Primal Dual.
SOPT_MACRO(rho, Real)
rho parameter
typename PD::t_Vector t_Vector
SOPT_MACRO(residual_convergence, t_IsConverged)
Convergence of the residuals.
DiagnosticAndResult operator()() const
Calls Primal Dual.
SOPT_MACRO(tv_proximal_weighted, t_Proximal< Vector< Real >>)
The tv prox with weights functioning as f.
typename PD::t_LinearTransform t_LinearTransform
SOPT_MACRO(positivity_constraint, bool)
Apply positivity constraint.
std::function< void(t_Vector &, const T &, const t_Vector &)> t_Proximal
TVPrimalDual< Scalar > & is_converged(std::function< bool(t_Vector const &x)> const &func)
Convergence function that takes only the output as argument.
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Primal Dual.
SOPT_MACRO(Psi, t_LinearTransform)
Wavelet operator.
SOPT_MACRO(tv_proximal, t_Proximal< Real >)
The tv prox functioning as f.
typename PD::t_Constraint t_Constraint
DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart) const
Makes it simple to chain different calls to PD.
SOPT_MACRO(regulariser_strength, Real)
regulariser_strength parameter
proximal::WeightedL2Ball< Scalar > & l2ball_proximal()
Proximal of the L2 ball.
DiagnosticAndResult operator()(std::tuple< t_Vector, t_Vector > const &guess) const
Calls Primal Dual.
TVPrimalDual &::type Psi(ARGS &&... args)
SOPT_MACRO(random_wavelet_updater, t_Random_Updater)
lambda that determines if to update wavelets
TVPrimalDual< Scalar > & objective_convergence(Real const &tolerance)
Helper function to set-up default objective convergence function.
#define SOPT_THROW(MSG)
Definition: exception.h:46
#define SOPT_LOW_LOG(...)
Low priority message.
Definition: logging.h:227
#define SOPT_HIGH_LOG(...)
High priority message.
Definition: logging.h:223
real_type< typename T0::Scalar >::type tv_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted TV norm.
Definition: maths.h:168
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
Eigen::CwiseUnaryOp< const details::ProjectPositiveQuadrant< typename T::Scalar >, const T > positive_quadrant(Eigen::DenseBase< T > const &input)
Expression to create projection onto positive quadrant.
Definition: maths.h:60
int t_int
Root of the type hierarchy for signed integers.
Definition: types.h:13
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
Vector< T > target(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:12
real_type< T >::type epsilon(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:38
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
real_type< T >::type sigma(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:17
real_type< typename T0::Scalar >::type l2_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L2 norm.
Definition: maths.h:140
Values indicating how the algorithm ran.
Definition: primal_dual.h:46
t_Vector residual
the residual from the last iteration
Definition: primal_dual.h:52
t_uint niters
Number of iterations.
Definition: primal_dual.h:48
bool good
Whether convergence was achieved.
Definition: primal_dual.h:50
Values indicating how the algorithm ran.
Diagnostic(t_uint niters, bool good, t_Vector &&residual)
Diagnostic(t_uint niters=0u, bool good=false)