SOPT
Sparse OPTimisation
padmm.h
Go to the documentation of this file.
1 #ifndef SOPT_PROXIMAL_ADMM_H
2 #define SOPT_PROXIMAL_ADMM_H
3 
4 #include "sopt/config.h"
5 #include <functional>
6 #include <limits>
7 #include <tuple> // for std::tuple<>
8 #include <utility> // for std::move<>
9 #include "sopt/exception.h"
10 #include "sopt/linear_transform.h"
11 #include "sopt/logging.h"
12 #include "sopt/types.h"
13 
14 namespace sopt::algorithm {
15 
// ProximalADMM<SCALAR>: Proximal Alternating Direction Method of Multipliers.
// Stores two proximal operators (f_proximal, g_proximal), a measurement
// operator Phi and the target measurements; operator() runs the iterations.
// NOTE(review): this is a Doxygen source listing -- the leading numbers are
// padmm.h line numbers. Doxygen dropped the doc-comment lines and several
// declarations that the symbol index at the end of this page still lists:
// t_Vector (padmm.h:28), t_Proximal (:34), Diagnostic::niters (:39) and
// Diagnostic::residual (:43), the head of the rvalue Diagnostic constructor
// (:47), the head of the ProximalADMM constructor (:60), and the SOPT_MACRO
// properties is_converged, Phi, f_proximal and g_proximal. Edit padmm.h
// itself rather than this listing.
18 template <typename SCALAR>
19 class ProximalADMM {
20  public:
22  using value_type = SCALAR;
24  using Scalar = value_type;
26  using Real = typename real_type<Scalar>::type;
// Convergence predicate: called with the current solution and the current residual.
32  using t_IsConverged = std::function<bool (const t_Vector &, const t_Vector &)>;
35 
// Values describing how the algorithm ran: number of iterations (niters),
// convergence flag (good) and the residual of the last iteration (residual).
37  struct Diagnostic {
// Whether convergence was achieved; operator() stores the outcome of the
// user-supplied convergence function here (false when none was set).
41  bool good;
44 
// Default values describe a run that has not happened yet:
// 0 iterations, not converged, empty residual.
45  Diagnostic(t_uint niters = 0u, bool good = false)
46  : niters(niters), good(good), residual(t_Vector::Zero(0)) {}
// Initialiser list of Diagnostic(t_uint niters, bool good, t_Vector &&residual),
// which takes ownership of the residual; its first line (padmm.h:47) is not shown here.
48  : niters(niters), good(good), residual(std::move(residual)) {}
49  };
// Diagnostic plus the result vector x (declared on a line not shown in this
// listing; the operator() overloads below write the solution into result.x).
51  struct DiagnosticAndResult : public Diagnostic {
54  };
55 
// Constructor taking the f and g proximal functions and the target
// measurements; its first line (padmm.h:60) is not shown in this listing.
// Defaults: itermax = largest t_uint, regulariser_strength (gamma) = 1e-8,
// lagrange_update_scale (beta) = 0.9, and no convergence function.
59  template <typename DERIVED>
61  Eigen::MatrixBase<DERIVED> const &target)
62  : itermax_(std::numeric_limits<t_uint>::max()),
63  regulariser_strength_(1e-8),
64  lagrange_update_scale_(0.9),
65  is_converged_(),
67  f_proximal_(f_proximal),
68  g_proximal_(g_proximal),
69  target_(target) {}
70  virtual ~ProximalADMM() {}
71 
72 // Macro helps define properties that can be initialized as in
73 // auto sdmm = ProximalADMM<float>().prop0(value).prop1(value);
// Each SOPT_MACRO(NAME, TYPE) expands to a const getter NAME(), a setter
// NAME(value) that returns *this (so calls can be chained), and a protected
// data member NAME##_.
74 #define SOPT_MACRO(NAME, TYPE) \
75  TYPE const &NAME() const { return NAME##_; } \
76  ProximalADMM<SCALAR> &NAME(TYPE const &(NAME)) { \
77  NAME##_ = NAME; \
78  return *this; \
79  } \
80  \
81  protected: \
82  TYPE NAME##_; \
83  \
84  public:
85 
// Maximum number of iterations.
87  SOPT_MACRO(itermax, t_uint);
// gamma: weight handed to the proximal operators (divided by Phi().sq_norm() for f).
89  SOPT_MACRO(regulariser_strength, Real);
// beta: scale applied to the Lagrange multiplier update.
91  SOPT_MACRO(lagrange_update_scale, Real);
101 #undef SOPT_MACRO
// Forwards (out, regulariser_strength, x) to the stored f proximal function.
103  void f_proximal(t_Vector &out, Real regulariser_strength, t_Vector const &x) const {
104  f_proximal()(out, regulariser_strength, x);
105  }
// Forwards (out, regulariser_strength, x) to the stored g proximal function.
107  void g_proximal(t_Vector &out, Real regulariser_strength, t_Vector const &x) const {
108  g_proximal()(out, regulariser_strength, x);
109  }
110 
// Accepts a convergence function of the solution only; it is wrapped so the
// residual argument is ignored.
112  ProximalADMM<Scalar> &is_converged(std::function<bool(t_Vector const &x)> const &func) {
113  return is_converged([func](t_Vector const &x, t_Vector const &) { return func(x); });
114  }
115 
// Vector of target measurements.
117  t_Vector const &target() const { return target_; }
// Sets the target measurements from any Eigen matrix/vector expression.
119  template <typename DERIVED>
120  ProximalADMM<Scalar> &target(Eigen::MatrixBase<DERIVED> const &target) {
121  target_ = target;
122  return *this;
123  }
124 
// Returns false when no convergence function is set; otherwise forwards to it.
126  bool is_converged(t_Vector const &x, t_Vector const &residual) const {
127  return static_cast<bool>(is_converged()) and is_converged()(x, residual);
128  }
129 
// Runs PADMM starting from initial_guess(), writing the solution into out.
132  Diagnostic operator()(t_Vector &out) const { return operator()(out, initial_guess()); }
// Runs PADMM from an (x, residual) guess held in a tuple of vectors.
136  Diagnostic operator()(t_Vector &out, std::tuple<t_Vector, t_Vector> const &guess) const {
137  return operator()(out, std::get<0>(guess), std::get<1>(guess));
138  }
// Same, with the guess given as a tuple of const references
// (signature head at padmm.h:142 is not shown in this listing).
143  std::tuple<t_Vector const &, t_Vector const &> const &guess) const {
144  return operator()(out, std::get<0>(guess), std::get<1>(guess));
145  }
// Runs PADMM from an (x, residual) guess and returns diagnostics plus the solution.
148  DiagnosticAndResult operator()(std::tuple<t_Vector, t_Vector> const &guess) const {
149  return operator()(std::tie(std::get<0>(guess), std::get<1>(guess)));
150  }
// Same, with a tuple of const references (signature head at padmm.h:153 is not shown).
154  std::tuple<t_Vector const &, t_Vector const &> const &guess) const {
155  DiagnosticAndResult result;
156  static_cast<Diagnostic &>(result) = operator()(result.x, guess);
157  return result;
158  }
// Body of operator()() const (padmm.h:161): runs PADMM from initial_guess().
162  DiagnosticAndResult result;
163  static_cast<Diagnostic &>(result) = operator()(result.x, initial_guess());
164  return result;
165  }
// Body of operator()(DiagnosticAndResult const &warmstart) const (padmm.h:167):
// restarts from warmstart.x and warmstart.residual; the returned object starts
// as a copy of warmstart and its Diagnostic part is then overwritten.
168  DiagnosticAndResult result = warmstart;
169  static_cast<Diagnostic &>(result) = operator()(result.x, warmstart.x, warmstart.residual);
170  return result;
171  }
// Sets the measurement operator Phi; the arguments are forwarded to
// linear_transform(). Only participates in overload resolution with at least one argument.
173  template <typename... ARGS>
174  typename std::enable_if<sizeof...(ARGS) >= 1, ProximalADMM &>::type Phi(ARGS &&... args) {
175  Phi_ = linear_transform(std::forward<ARGS>(args)...);
176  return *this;
177  }
178 
// Initial guess for x and the residual computed from target() and Phi().
183  std::tuple<t_Vector, t_Vector> initial_guess() const {
// NOTE(review): the body line (padmm.h:184) is not shown in this listing.
185  }
186 
// Initial guess from explicit inputs:
//   x0        = (phi^adjoint * target) / phi.sq_norm()
//   residual0 = phi * x0 - target
193  static std::tuple<t_Vector, t_Vector> initial_guess(t_Vector const &target,
194  t_LinearTransform const &phi) {
195  std::tuple<t_Vector, t_Vector> guess;
196  std::get<0>(guess) = static_cast<t_Vector>(phi.adjoint() * target) / phi.sq_norm();
197  std::get<1>(guess) = phi * std::get<0>(guess) - target;
198  return guess;
199  }
200 
201  protected:
// Performs one PADMM iteration, updating out, residual, lambda and z in place.
202  void iteration_step(t_Vector &out, t_Vector &residual, t_Vector &lambda, t_Vector &z) const;
203 
// Checks that target, the guesses and Phi have consistent sizes (throws via
// SOPT_THROW otherwise) and warns when no convergence function was set.
205  void sanity_check(t_Vector const &x_guess, t_Vector const &res_guess) const {
206  if ((Phi().adjoint() * target()).size() != x_guess.size())
207  SOPT_THROW("target, adjoint measurement operator and input vector have inconsistent sizes");
208  if (target().size() != res_guess.size())
209  SOPT_THROW("target and residual vector have inconsistent sizes");
210  if ((Phi() * x_guess).size() != target().size())
211  SOPT_THROW("target, measurement operator and input vector have inconsistent sizes");
212  if (not static_cast<bool>(is_converged()))
213  SOPT_WARN("No convergence function was provided: algorithm will run for {} steps", itermax());
214  }
215 
// Core implementation: runs PADMM from explicit x and residual guesses.
220  Diagnostic operator()(t_Vector &out, t_Vector const &guess, t_Vector const &res) const;
221 
// Vector of target measurements.
223  t_Vector target_;
224 };
225 
226 template <typename SCALAR>
227 void ProximalADMM<SCALAR>::iteration_step(t_Vector &out, t_Vector &residual, t_Vector &lambda,
228  t_Vector &z) const {
229  g_proximal(z, regulariser_strength(), -lambda - residual);
230  f_proximal(out, regulariser_strength() / Phi().sq_norm(),
231  out - static_cast<t_Vector>(Phi().adjoint() * (residual + lambda + z)) / Phi().sq_norm());
232  residual = static_cast<t_Vector>(Phi() * out - target());
233  lambda += lagrange_update_scale() * (residual + z);
234 }
235 
236 template <typename SCALAR>
237 typename ProximalADMM<SCALAR>::Diagnostic ProximalADMM<SCALAR>::operator()(
238  t_Vector &out, t_Vector const &x_guess, t_Vector const &res_guess) const {
239  SOPT_HIGH_LOG("Performing Proximal ADMM");
240  sanity_check(x_guess, res_guess);
241 
242  t_Vector lambda = t_Vector::Zero(target().size());
243  t_Vector z = t_Vector::Zero(target().size());
244  t_Vector residual = res_guess;
245  out = x_guess;
246 
247  t_uint niters(0);
248  bool converged = false;
249  for (; (not converged) && (niters < itermax()); ++niters) {
250  SOPT_LOW_LOG(" - [PADMM] Iteration {}/{}", niters, itermax());
251  iteration_step(out, residual, lambda, z);
252  SOPT_LOW_LOG(" - [PADMM] Sum of residuals: {}", residual.array().abs().sum());
253  converged = is_converged(out, residual);
254  }
255 
256  if (converged) {
257  SOPT_MEDIUM_LOG(" - [PADMM] converged in {} of {} iterations", niters, itermax());
258  } else if (static_cast<bool>(is_converged())) {
259  // not meaningful if not convergence function
260  SOPT_ERROR(" - [PADMM] did not converge within {} iterations", itermax());
261  }
262  return {niters, converged, std::move(residual)};
263 }
264 } // namespace sopt::algorithm
265 #endif
sopt::Vector< Scalar > t_Vector
sopt::t_real sq_norm() const
LinearTransform< VECTOR > adjoint() const
Indirect transform.
Proximal Alternating Direction Method of Multipliers.
Definition: padmm.h:19
ProximalADMM(t_Proximal const &f_proximal, t_Proximal const &g_proximal, Eigen::MatrixBase< DERIVED > const &target)
Definition: padmm.h:60
void g_proximal(t_Vector &out, Real regulariser_strength, t_Vector const &x) const
Simplifies calling the proximal of g.
Definition: padmm.h:107
SOPT_MACRO(regulariser_strength, Real)
γ parameter.
SOPT_MACRO(itermax, t_uint)
Maximum number of iterations.
SCALAR value_type
Scalar type.
Definition: padmm.h:22
t_Vector const & target() const
Vector of target measurements.
Definition: padmm.h:117
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Proximal ADMM.
Definition: padmm.h:142
ProximalADMM< Scalar > & is_converged(std::function< bool(t_Vector const &x)> const &func)
Convergence function that takes only the output as argument.
Definition: padmm.h:112
Vector< Scalar > t_Vector
Type of the underlying vectors.
Definition: padmm.h:28
Diagnostic operator()(t_Vector &out) const
Calls Proximal ADMM.
Definition: padmm.h:132
ProximalADMM< Scalar > & target(Eigen::MatrixBase< DERIVED > const &target)
Sets the vector of target measurements.
Definition: padmm.h:120
value_type Scalar
Scalar type.
Definition: padmm.h:24
ProximalFunction< Scalar > t_Proximal
Type of the proximal functions.
Definition: padmm.h:34
SOPT_MACRO(is_converged, t_IsConverged)
A function verifying convergence.
SOPT_MACRO(g_proximal, t_Proximal)
Second proximal.
SOPT_MACRO(lagrange_update_scale, Real)
Lagrange update scale β
std::tuple< t_Vector, t_Vector > initial_guess() const
Computes initial guess for x and the residual using the targets.
Definition: padmm.h:183
std::enable_if< sizeof...(ARGS) >= 1, ProximalADMM & >::type Phi(ARGS &&... args)
Definition: padmm.h:174
DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart) const
Makes it simple to chain different calls to PADMM.
Definition: padmm.h:167
SOPT_MACRO(f_proximal, t_Proximal)
First proximal.
DiagnosticAndResult operator()(std::tuple< t_Vector, t_Vector > const &guess) const
Calls Proximal ADMM.
Definition: padmm.h:148
bool is_converged(t_Vector const &x, t_Vector const &residual) const
Facilitates call to user-provided convergence function.
Definition: padmm.h:126
DiagnosticAndResult operator()(std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Proximal ADMM.
Definition: padmm.h:153
std::function< bool(const t_Vector &, const t_Vector &)> t_IsConverged
Type of the convergence function.
Definition: padmm.h:32
static std::tuple< t_Vector, t_Vector > initial_guess(t_Vector const &target, t_LinearTransform const &phi)
Computes initial guess for x and the residual using the targets.
Definition: padmm.h:193
void f_proximal(t_Vector &out, Real regulariser_strength, t_Vector const &x) const
Simplifies calling the proximal of f.
Definition: padmm.h:103
SOPT_MACRO(Phi, t_LinearTransform)
Measurement operator.
typename real_type< Scalar >::type Real
Real type.
Definition: padmm.h:26
DiagnosticAndResult operator()() const
Calls Proximal ADMM.
Definition: padmm.h:161
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector, t_Vector > const &guess) const
Calls Proximal ADMM.
Definition: padmm.h:136
Computes inner-most element type.
Definition: real_type.h:42
#define SOPT_THROW(MSG)
Definition: exception.h:46
#define SOPT_LOW_LOG(...)
Low priority message.
Definition: logging.h:227
#define SOPT_HIGH_LOG(...)
High priority message.
Definition: logging.h:223
#define SOPT_ERROR(...)
Something is definitely wrong; the algorithm exits.
Definition: logging.h:211
#define SOPT_WARN(...)
Something might be going wrong.
Definition: logging.h:213
#define SOPT_MEDIUM_LOG(...)
Medium priority message.
Definition: logging.h:225
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
Vector< T > target(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:12
LinearTransform< Vector< SCALAR > > linear_transform_identity()
Helper function to create a linear transform that's just the identity.
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
std::function< void(Vector< SCALAR > &output, typename real_type< SCALAR >::type const weight, Vector< SCALAR > const &input)> ProximalFunction
Typical function signature for calls to proximal.
Definition: types.h:48
Holds result vector as well.
Definition: padmm.h:51
Values indicating how the algorithm ran.
Definition: padmm.h:37
bool good
Whether convergence was achieved.
Definition: padmm.h:41
Diagnostic(t_uint niters=0u, bool good=false)
Definition: padmm.h:45
t_uint niters
Number of iterations.
Definition: padmm.h:39
Diagnostic(t_uint niters, bool good, t_Vector &&residual)
Definition: padmm.h:47
t_Vector residual
The residual from the last iteration.
Definition: padmm.h:43