SOPT
Sparse OPTimisation
primal_dual.h
Go to the documentation of this file.
1 #ifndef SOPT_PRIMAL_DUAL_H
2 #define SOPT_PRIMAL_DUAL_H
3 
4 #include "sopt/config.h"
5 #include <functional>
6 #include <limits>
7 #include <tuple> // for std::tuple<>
8 #include <utility> // for std::move<>
9 #include "sopt/exception.h"
10 #include "sopt/linear_transform.h"
11 #include "sopt/logging.h"
12 #include "sopt/types.h"
13 
14 #ifdef SOPT_MPI
15 #include "sopt/mpi/communicator.h"
16 #include "sopt/mpi/utilities.h"
17 #endif
18 
19 namespace sopt::algorithm {
20 
23 template <typename SCALAR>
24 class PrimalDual {
25  public:
27  using value_type = SCALAR;
29  using Scalar = value_type;
31  using Real = typename real_type<Scalar>::type;
37  using t_IsConverged = std::function<bool (const t_Vector &, const t_Vector &)>;
39  using t_Constraint = std::function<void (t_Vector &, const t_Vector &)>;
41  using t_Random_Updater = std::function<bool ()>;
44 
46  struct Diagnostic {
50  bool good;
53 
54  Diagnostic(t_uint niters = 0u, bool good = false)
55  : niters(niters), good(good), residual(t_Vector::Zero(0)) {}
57  : niters(niters), good(good), residual(std::move(residual)) {}
58  };
60  struct DiagnosticAndResult : public Diagnostic {
63  };
64 
68  template <typename DERIVED>
70  Eigen::MatrixBase<DERIVED> const &target)
71  : itermax_(std::numeric_limits<t_uint>::max()),
72  sigma_(1),
73  tau_(0.5),
74  regulariser_strength_(0.5),
75  update_scale_(1),
76  xi_(1),
77  rho_(1),
78  is_converged_(),
79  constraint_([](t_Vector &out, t_Vector const &x) { out = x; }),
80  Phi_(linear_transform_identity<Scalar>()),
81  Psi_(linear_transform_identity<Scalar>()),
82  f_proximal_(f_proximal),
83  g_proximal_(g_proximal),
84  random_measurement_updater_([]() { return true; }),
85  random_wavelet_updater_([]() { return true; }),
86 #ifdef SOPT_MPI
87  v_all_sum_all_comm_(mpi::Communicator()),
88  u_all_sum_all_comm_(mpi::Communicator()),
89 #endif
90  target_(target) {
91  }
92  virtual ~PrimalDual() {}
93 
94 // Macro helps define properties that can be initialized as in
95 // auto sdmm = PrimalDual<float>().prop0(value).prop1(value);
96 #define SOPT_MACRO(NAME, TYPE) \
97  TYPE const &NAME() const { return NAME##_; } \
98  PrimalDual<SCALAR> &NAME(TYPE const &(NAME)) { \
99  NAME##_ = NAME; \
100  return *this; \
101  } \
102  \
103  protected: \
104  TYPE NAME##_; \
105  \
106  public:
107 
109  SOPT_MACRO(itermax, t_uint);
111  SOPT_MACRO(update_scale, Real);
113  SOPT_MACRO(regulariser_strength, Real);
126  SOPT_MACRO(constraint, t_Constraint);
136  SOPT_MACRO(random_measurement_updater, t_Random_Updater);
138  SOPT_MACRO(random_wavelet_updater, t_Random_Updater);
139 #ifdef SOPT_MPI
141  SOPT_MACRO(v_all_sum_all_comm, mpi::Communicator);
143  SOPT_MACRO(u_all_sum_all_comm, mpi::Communicator);
144 #endif
145 #undef SOPT_MACRO
147  void f_proximal(t_Vector &out, Real regulariser_strength, t_Vector const &x) const {
148  f_proximal()(out, regulariser_strength, x);
149  }
151  void g_proximal(t_Vector &out, Real regulariser_strength, t_Vector const &x) const {
152  g_proximal()(out, regulariser_strength, x);
153  }
154 
156  PrimalDual<Scalar> &is_converged(std::function<bool(t_Vector const &x)> const &func) {
157  return is_converged([func](t_Vector const &x, t_Vector const &) { return func(x); });
158  }
159 
161  t_Vector const &target() const { return target_; }
163  template <typename DERIVED>
164  PrimalDual<Scalar> &target(Eigen::MatrixBase<DERIVED> const &target) {
165  target_ = target;
166  return *this;
167  }
168 
170  bool is_converged(t_Vector const &x, t_Vector const &residual) const {
171  return static_cast<bool>(is_converged()) and is_converged()(x, residual);
172  }
173 
176  Diagnostic operator()(t_Vector &out) const { return operator()(out, initial_guess()); }
180  Diagnostic operator()(t_Vector &out, std::tuple<t_Vector, t_Vector> const &guess) const {
181  return operator()(out, std::get<0>(guess), std::get<1>(guess));
182  }
187  std::tuple<t_Vector const &, t_Vector const &> const &guess) const {
188  return operator()(out, std::get<0>(guess), std::get<1>(guess));
189  }
192  DiagnosticAndResult operator()(std::tuple<t_Vector, t_Vector> const &guess) const {
193  return operator()(std::tie(std::get<0>(guess), std::get<1>(guess)));
194  }
198  std::tuple<t_Vector const &, t_Vector const &> const &guess) const {
199  DiagnosticAndResult result;
200  static_cast<Diagnostic &>(result) = operator()(result.x, guess);
201  return result;
202  }
206  DiagnosticAndResult result;
207  static_cast<Diagnostic &>(result) = operator()(result.x, initial_guess());
208  return result;
209  }
212  DiagnosticAndResult result = warmstart;
213  static_cast<Diagnostic &>(result) = operator()(result.x, warmstart.x, warmstart.residual);
214  return result;
215  }
217  template <typename... ARGS>
218  typename std::enable_if<sizeof...(ARGS) >= 1, PrimalDual &>::type Phi(ARGS &&... args) {
219  Phi_ = linear_transform(std::forward<ARGS>(args)...);
220  return *this;
221  }
223  template <typename... ARGS>
224  typename std::enable_if<sizeof...(ARGS) >= 1, PrimalDual &>::type Psi(ARGS &&... args) {
225  Psi_ = linear_transform(std::forward<ARGS>(args)...);
226  return *this;
227  }
228 
233  std::tuple<t_Vector, t_Vector> initial_guess() const {
235  }
236 
243  static std::tuple<t_Vector, t_Vector> initial_guess(t_Vector const &target,
244  t_LinearTransform const &phi) {
245  std::tuple<t_Vector, t_Vector> guess;
246  std::get<0>(guess) = static_cast<t_Vector>(phi.adjoint() * target) / phi.sq_norm();
247  std::get<1>(guess) = target;
248  return guess;
249  }
250 
251  protected:
252  void iteration_step(t_Vector &out, t_Vector &out_hold, t_Vector &u, t_Vector &u_hold, t_Vector &v,
253  t_Vector &v_hold, t_Vector &residual, t_Vector &q, t_Vector &r,
254  bool &random_measurement_update, bool &random_wavelet_update,
255  t_Vector &u_update, t_Vector &v_update) const;
256 
258  void sanity_check(t_Vector const &x_guess, t_Vector const &res_guess) const {
259  if ((Phi().adjoint() * target()).size() != x_guess.size())
260  SOPT_THROW("target, adjoint measurement operator and input vector have inconsistent sizes");
261  if (target().size() != res_guess.size())
262  SOPT_THROW("target and residual vector have inconsistent sizes");
263  if ((Phi() * x_guess).size() != target().size())
264  SOPT_THROW("target, measurement operator and input vector have inconsistent sizes");
265  if (not static_cast<bool>(is_converged()))
266  SOPT_WARN("No convergence function was provided: algorithm will run for {} steps", itermax());
267  }
268 
273  Diagnostic operator()(t_Vector &out, t_Vector const &guess, t_Vector const &res) const;
274 
276  t_Vector target_;
277 };
278 
279 template <typename SCALAR>
280 void PrimalDual<SCALAR>::iteration_step(t_Vector &out, t_Vector &out_hold, t_Vector &u,
281  t_Vector &u_hold, t_Vector &v, t_Vector &v_hold,
282  t_Vector &residual, t_Vector &q, t_Vector &r,
283  bool &random_measurement_update,
284  bool &random_wavelet_update, t_Vector &u_update,
285  t_Vector &v_update) const {
286  // dual calculations for measurements
287  if (random_measurement_update) {
288  g_proximal(v_hold, rho(), v + residual);
289  v_hold = v + residual - v_hold;
290  v = v + update_scale() * (v_hold - v);
291  v_update = static_cast<t_Vector>(Phi().adjoint() * v);
292  }
293  // dual calculations for wavelet
294  if (random_wavelet_update) {
295  q = static_cast<t_Vector>(Psi().adjoint() * out_hold) * sigma();
296  f_proximal(u_hold, regulariser_strength(), (u + q));
297  u_hold = u + q - u_hold;
298  u = u + update_scale() * (u_hold - u);
299  u_update = static_cast<t_Vector>(Psi() * u);
300  }
301  // primal calculations
302  r = out;
303 #ifdef SOPT_MPI
304  if (v_all_sum_all_comm().size() > 0 and u_all_sum_all_comm().size() > 0)
305  constraint()(
306  out_hold,
307  r - tau() * (u_all_sum_all_comm().all_sum_all(static_cast<const t_Vector>(u_update)) +
308  v_all_sum_all_comm().all_sum_all(static_cast<const t_Vector>(v_update))));
309  else
310 #endif
311  constraint()(out_hold, r - tau() * (u_update + v_update));
312  out = r + update_scale() * (out_hold - r);
313  out_hold = 2 * out_hold - r;
314  random_measurement_update = random_measurement_updater_();
315  random_wavelet_update = random_wavelet_updater_();
316  // update residual
317  if (random_measurement_update)
318  residual = static_cast<t_Vector>(Phi() * out_hold) * xi() - target();
319 }
320 
321 template <typename SCALAR>
322 typename PrimalDual<SCALAR>::Diagnostic PrimalDual<SCALAR>::operator()(
323  t_Vector &out, t_Vector const &x_guess, t_Vector const &res_guess) const {
324  SOPT_HIGH_LOG("Performing Primal Dual");
325  sanity_check(x_guess, res_guess);
326  bool random_measurement_update = random_measurement_updater_();
327  bool random_wavelet_update = random_wavelet_updater_();
328  t_Vector residual = res_guess;
329  out = x_guess;
330  t_Vector out_hold = x_guess;
331  t_Vector r = x_guess;
332  t_Vector v = residual;
333  t_Vector v_hold = residual;
334  t_Vector v_update = x_guess;
335  t_Vector u = Psi().adjoint() * out;
336  t_Vector u_hold = u;
337  t_Vector u_update = out;
338  t_Vector q = u;
339 
340  t_uint niters(0);
341  bool converged = false;
342  for (; (not converged) && (niters < itermax()); ++niters) {
343  SOPT_LOW_LOG(" - [Primal Dual] Iteration {}/{}", niters, itermax());
344  iteration_step(out, out_hold, u, u_hold, v, v_hold, residual, q, r, random_measurement_update,
345  random_wavelet_update, u_update, v_update);
346  SOPT_LOW_LOG(" - [Primal Dual] Sum of residuals: {}",
347  static_cast<t_Vector>(residual).array().abs().sum());
348  converged = is_converged(out, residual);
349  }
350 
351  if (converged) {
352  SOPT_MEDIUM_LOG(" - [Primal Dual] converged in {} of {} iterations", niters, itermax());
353  } else if (static_cast<bool>(is_converged())) {
354  // not meaningful if not convergence function
355  SOPT_ERROR(" - [Primal Dual] did not converge within {} iterations", itermax());
356  }
357  return {niters, converged, std::move(residual)};
358 }
359 } // namespace sopt::algorithm
360 #endif
sopt::Vector< Scalar > t_Vector
sopt::t_real sq_norm() const
LinearTransform< VECTOR > adjoint() const
Indirect transform.
Primal Dual Algorithm.
Definition: primal_dual.h:24
ProximalFunction< Scalar > t_Proximal
Type of the proximal functions.
Definition: primal_dual.h:43
PrimalDual< Scalar > & target(Eigen::MatrixBase< DERIVED > const &target)
Sets the vector of target measurements.
Definition: primal_dual.h:164
SCALAR value_type
Scalar type.
Definition: primal_dual.h:27
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector, t_Vector > const &guess) const
Calls Primal Dual.
Definition: primal_dual.h:180
SOPT_MACRO(f_proximal, t_Proximal)
First proximal.
t_Vector const & target() const
Vector of target measurements.
Definition: primal_dual.h:161
DiagnosticAndResult operator()() const
Calls Primal Dual.
Definition: primal_dual.h:205
SOPT_MACRO(random_wavelet_updater, t_Random_Updater)
lambda that determines if to update wavelets
std::function< bool()> t_Random_Updater
Type of random update function.
Definition: primal_dual.h:41
Vector< Scalar > t_Vector
Type of the underlying vectors.
Definition: primal_dual.h:33
Diagnostic operator()(t_Vector &out) const
Calls Primal Dual.
Definition: primal_dual.h:176
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Primal Dual.
Definition: primal_dual.h:186
SOPT_MACRO(xi, Real)
xi parameter
bool is_converged(t_Vector const &x, t_Vector const &residual) const
Facilitates call to user-provided convergence function.
Definition: primal_dual.h:170
PrimalDual &::type Psi(ARGS &&... args)
Definition: primal_dual.h:224
SOPT_MACRO(Phi, t_LinearTransform)
Measurement operator.
SOPT_MACRO(is_converged, t_IsConverged)
A function verifying convergence.
value_type Scalar
Scalar type.
Definition: primal_dual.h:29
SOPT_MACRO(rho, Real)
rho parameter
std::function< void(t_Vector &, const t_Vector &)> t_Constraint
Type of the constraint function.
Definition: primal_dual.h:39
SOPT_MACRO(constraint, t_Constraint)
A function applying a simple constraint.
PrimalDual &::type Phi(ARGS &&... args)
Definition: primal_dual.h:218
std::function< bool(const t_Vector &, const t_Vector &)> t_IsConverged
Type of the convergence function.
Definition: primal_dual.h:37
SOPT_MACRO(regulariser_strength, Real)
γ parameter.
SOPT_MACRO(g_proximal, t_Proximal)
Second proximal.
void f_proximal(t_Vector &out, Real regulariser_strength, t_Vector const &x) const
Simplifies calling the proximal of f.
Definition: primal_dual.h:147
SOPT_MACRO(tau, Real)
tau parameter
PrimalDual< Scalar > & is_converged(std::function< bool(t_Vector const &x)> const &func)
Convergence function that takes only the output as argument.
Definition: primal_dual.h:156
void g_proximal(t_Vector &out, Real regulariser_strength, t_Vector const &x) const
Simplifies calling the proximal of g.
Definition: primal_dual.h:151
DiagnosticAndResult operator()(std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Primal Dual.
Definition: primal_dual.h:197
PrimalDual(t_Proximal const &f_proximal, t_Proximal const &g_proximal, Eigen::MatrixBase< DERIVED > const &target)
Definition: primal_dual.h:69
SOPT_MACRO(Psi, t_LinearTransform)
Wavelet operator.
SOPT_MACRO(itermax, t_uint)
Maximum number of iterations.
SOPT_MACRO(random_measurement_updater, t_Random_Updater)
lambda that determines if to update measurements
DiagnosticAndResult operator()(std::tuple< t_Vector, t_Vector > const &guess) const
Calls Primal Dual.
Definition: primal_dual.h:192
std::tuple< t_Vector, t_Vector > initial_guess() const
Computes initial guess for x and the residual using the targets.
Definition: primal_dual.h:233
DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart) const
Makes it simple to chain different calls to PD.
Definition: primal_dual.h:211
typename real_type< Scalar >::type Real
Real type.
Definition: primal_dual.h:31
static std::tuple< t_Vector, t_Vector > initial_guess(t_Vector const &target, t_LinearTransform const &phi)
Computes initial guess for x and the residual using the targets.
Definition: primal_dual.h:243
SOPT_MACRO(sigma, Real)
sigma parameter
SOPT_MACRO(update_scale, Real)
Update parameter.
Computes inner-most element type.
Definition: real_type.h:42
#define SOPT_MPI
Whether or not to include mpi.
Definition: config.in.h:17
#define SOPT_THROW(MSG)
Definition: exception.h:46
#define SOPT_LOW_LOG(...)
Low priority message.
Definition: logging.h:227
#define SOPT_HIGH_LOG(...)
High priority message.
Definition: logging.h:223
#define SOPT_ERROR(...)
Something is definitely wrong; the algorithm exits.
Definition: logging.h:211
#define SOPT_WARN(...)
Something might be going wrong.
Definition: logging.h:213
#define SOPT_MEDIUM_LOG(...)
Medium priority message.
Definition: logging.h:225
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
Vector< T > target(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:12
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
std::function< void(Vector< SCALAR > &output, typename real_type< SCALAR >::type const weight, Vector< SCALAR > const &input)> ProximalFunction
Typical function signature for calls to proximal.
Definition: types.h:48
real_type< T >::type sigma(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:17
Values indicating how the algorithm ran.
Definition: primal_dual.h:46
Diagnostic(t_uint niters, bool good, t_Vector &&residual)
Definition: primal_dual.h:56
t_Vector residual
the residual from the last iteration
Definition: primal_dual.h:52
t_uint niters
Number of iterations.
Definition: primal_dual.h:48
Diagnostic(t_uint niters=0u, bool good=false)
Definition: primal_dual.h:54
bool good
Whether convergence was achieved.
Definition: primal_dual.h:50