SOPT
Sparse OPTimisation
l2_forward_backward.h
Go to the documentation of this file.
1 #ifndef SOPT_L2_FORWARD_BACKWARD_H
2 #define SOPT_L2_FORWARD_BACKWARD_H
3 
4 #include "sopt/config.h"
5 #include <limits> // for std::numeric_limits<>
6 #include <numeric>
7 #include <tuple>
8 #include <utility>
9 #include "sopt/exception.h"
10 #include "sopt/forward_backward.h"
11 #include "sopt/linear_transform.h"
12 #include "sopt/logging.h"
13 #include "sopt/proximal.h"
15 #include "sopt/types.h"
16 
17 #ifdef SOPT_MPI
18 #include "sopt/mpi/communicator.h"
19 #include "sopt/mpi/utilities.h"
20 #endif
21 
22 namespace sopt::algorithm {
23 template <typename SCALAR>
26  using FB = ForwardBackward<SCALAR>;
27 
28  public:
29  using value_type = typename FB::value_type;
30  using Scalar = typename FB::Scalar;
31  using Real = typename FB::Real;
32  using t_Vector = typename FB::t_Vector;
34  template <typename T>
35  using t_Proximal = std::function<void(t_Vector &, const T &, const t_Vector &)>;
36  using t_Gradient = typename FB::t_Gradient;
37  using t_IsConverged = typename FB::t_IsConverged;
38 
40  struct Diagnostic : public FB::Diagnostic {
41  Diagnostic(t_uint niters = 0u, bool good = false) : FB::Diagnostic(niters, good) {}
42  Diagnostic(t_uint niters, bool good, t_Vector &&residual)
43  : FB::Diagnostic(niters, good, std::move(residual)) {}
44  };
46  struct DiagnosticAndResult : public Diagnostic {
49  };
50 
54  template <typename DERIVED>
55  L2ForwardBackward(Eigen::MatrixBase<DERIVED> const &target)
56  : l2_proximal_([](t_Vector &output, const t_real &regulariser_strength, const t_Vector &x) -> void {
57  proximal::l2_norm(output, regulariser_strength, x);
58  }),
59  l2_proximal_weighted_(
60  [](t_Vector &output, const Vector<Real> &regulariser_strength, const t_Vector &x) -> void {
61  proximal::l2_norm(output, regulariser_strength, x);
62  }),
63  l2_proximal_weights_(Vector<Real>::Ones(1)),
64  l2_gradient_([](t_Vector &output, t_Vector const &image, const t_Vector &residual, const t_LinearTransform &Phi) -> void {
65  output = Phi.adjoint()*residual;
66  }), // gradient of 1/2 * r^2 = r;
67  tight_frame_(false),
68  residual_tolerance_(0.),
69  relative_variation_(1e-4),
70  residual_convergence_(nullptr),
71  objective_convergence_(nullptr),
72  itermax_(std::numeric_limits<t_uint>::max()),
73  regulariser_strength_(1e-8),
74  step_size_(1),
75  sigma_(1),
76  is_converged_(),
77  Phi_(linear_transform_identity<Scalar>()),
78  target_(target) {}
79  virtual ~L2ForwardBackward() {}
80 
// SOPT_MACRO(NAME, TYPE) declares a protected data member NAME_ of type TYPE together with a
// const getter NAME() and a chainable setter NAME(value) returning *this, so that properties
// can be configured in one expression, e.g.
//   auto fb = L2ForwardBackward<float>(target).prop0(value).prop1(value);
// The macro leaves the class in its public section.
#define SOPT_MACRO(NAME, TYPE)                             \
  TYPE const &NAME() const { return NAME##_; }             \
  L2ForwardBackward<SCALAR> &NAME(TYPE const &new_value) { \
    NAME##_ = new_value;                                   \
    return *this;                                          \
  }                                                        \
                                                           \
 protected:                                                \
  TYPE NAME##_;                                            \
                                                           \
 public:
94 
98  SOPT_MACRO(l2_proximal_weighted, t_Proximal<Vector<Real>>);
100  SOPT_MACRO(l2_proximal_weights, Vector<Real>);
102  SOPT_MACRO(l2_gradient, t_Gradient);
104  SOPT_MACRO(tight_frame, bool);
107  SOPT_MACRO(residual_tolerance, Real);
110  SOPT_MACRO(relative_variation, Real);
113  SOPT_MACRO(residual_convergence, t_IsConverged);
116  SOPT_MACRO(objective_convergence, t_IsConverged);
118  SOPT_MACRO(itermax, t_uint);
120  SOPT_MACRO(regulariser_strength, Real);
122  SOPT_MACRO(step_size, Real);
126  SOPT_MACRO(is_converged, t_IsConverged);
129 #ifdef SOPT_MPI
131  SOPT_MACRO(obj_comm, mpi::Communicator);
132 #endif
133 
134 #undef SOPT_MACRO
136  t_Vector const &target() const { return target_; }
138  Real objmin() const { return objmin_; }
140  template <typename DERIVED>
141  L2ForwardBackward<Scalar> &target(Eigen::MatrixBase<DERIVED> const &target) {
142  target_ = target;
143  return *this;
144  }
145 
149  return operator()(out, ForwardBackward<SCALAR>::initial_guess(target(), Phi()));
150  }
154  Diagnostic operator()(t_Vector &out, std::tuple<t_Vector, t_Vector> const &guess) const {
155  return operator()(out, std::get<0>(guess), std::get<1>(guess));
156  }
161  std::tuple<t_Vector const &, t_Vector const &> const &guess) const {
162  return operator()(out, std::get<0>(guess), std::get<1>(guess));
163  }
166  DiagnosticAndResult operator()(std::tuple<t_Vector, t_Vector> const &guess) const {
167  return operator()(std::tie(std::get<0>(guess), std::get<1>(guess)));
168  }
172  std::tuple<t_Vector const &, t_Vector const &> const &guess) const {
173  DiagnosticAndResult result;
174  static_cast<Diagnostic &>(result) = operator()(result.x, guess);
175  return result;
176  }
180  DiagnosticAndResult result;
181  static_cast<Diagnostic &>(result) = operator()(
182  result.x, ForwardBackward<SCALAR>::initial_guess(target(), Phi()));
183  return result;
184  }
187  DiagnosticAndResult result = warmstart;
188  static_cast<Diagnostic &>(result) = operator()(result.x, warmstart.x, warmstart.residual);
189  return result;
190  }
191 
193  template <typename... ARGS>
194  typename std::enable_if<sizeof...(ARGS) >= 1, L2ForwardBackward &>::type Phi(
195  ARGS &&... args) {
196  Phi_ = linear_transform(std::forward<ARGS>(args)...);
197  return *this;
198  }
199 
202  t_Proximal<Real> &l2_proximal() { return l2_proximal_; }
203  t_Proximal<Vector<Real>> &l2_proximal_weighted() { return l2_proximal_weighted_; }
206  t_Gradient &l2_gradient() { return l2_gradient_; }
207 
210  return residual_convergence(nullptr).residual_tolerance(tolerance);
211  }
214  return objective_convergence(nullptr).relative_variation(tolerance);
215  }
217  L2ForwardBackward<Scalar> &is_converged(std::function<bool(t_Vector const &x)> const &func) {
218  return is_converged([func](t_Vector const &x, t_Vector const &) { return func(x); });
219  }
220 
221  protected:
223  t_Vector target_;
225  mutable Real objmin_;
226 
231  Diagnostic operator()(t_Vector &out, t_Vector const &guess, t_Vector const &res) const;
232 
234  bool residual_convergence(t_Vector const &x, t_Vector const &residual) const;
235 
237  bool objective_convergence(ScalarRelativeVariation<Scalar> &scalvar, t_Vector const &x,
238  t_Vector const &residual) const;
239 #ifdef SOPT_MPI
241  bool objective_convergence(mpi::Communicator const &obj_comm,
242  ScalarRelativeVariation<Scalar> &scalvar, t_Vector const &x,
243  t_Vector const &residual) const;
244 #endif
245 
247  bool is_converged(ScalarRelativeVariation<Scalar> &scalvar, t_Vector const &x,
248  t_Vector const &residual) const;
249 };
250 
251 template <typename SCALAR>
252 typename L2ForwardBackward<SCALAR>::Diagnostic L2ForwardBackward<SCALAR>::operator()(
253  t_Vector &out, t_Vector const &guess, t_Vector const &res) const {
254  SOPT_HIGH_LOG("Performing Forward Backward with L2 and L2 norms");
255  // The f proximal is an L2 proximal
256  Diagnostic result;
257  auto const g_proximal = [this](t_Vector &out, Real regulariser_strength, t_Vector const &x) {
258  if (this->l2_proximal_weights().size() > 1)
259  this->l2_proximal_weighted()(out, this->l2_proximal_weights() * regulariser_strength, x);
260  else
261  this->l2_proximal()(out, this->l2_proximal_weights()(0) * regulariser_strength, x);
262  };
263  const Real sigma_factor = sigma() * sigma();
264  const t_Gradient f_gradient = [sigma_factor](t_Vector &out, t_Vector const &image, t_Vector const &res, t_LinearTransform const &Phi) {
265  t_Vector temp;
266  temp = res / sigma_factor;
267  out = Phi.adjoint() * temp;
268  };
269  ScalarRelativeVariation<Scalar> scalvar(relative_variation(), relative_variation(),
270  "Objective function");
271  auto const convergence = [this, &scalvar](t_Vector const &x, t_Vector const &residual) mutable {
272  const bool result = this->is_converged(scalvar, x, residual);
273  this->objmin_ = std::real(scalvar.previous());
274  return result;
275  };
276  auto fb = ForwardBackward<SCALAR>(f_gradient, g_proximal, target())
277  .itermax(itermax())
278  .step_size(step_size())
279  .regulariser_strength(regulariser_strength())
280  .Phi(Phi())
281  .is_converged(convergence);
282  static_cast<typename ForwardBackward<SCALAR>::Diagnostic &>(result) =
283  fb(out, std::tie(guess, res));
284  return result;
285 }
286 
287 template <typename SCALAR>
288 bool L2ForwardBackward<SCALAR>::residual_convergence(t_Vector const &x,
289  t_Vector const &residual) const {
290  if (static_cast<bool>(residual_convergence())) return residual_convergence()(x, residual);
291  if (residual_tolerance() <= 0e0) return true;
292  auto const residual_norm = sopt::l2_norm(residual);
293  SOPT_LOW_LOG(" - [FB] Residuals: {} <? {}", residual_norm, residual_tolerance());
294  return residual_norm < residual_tolerance();
295 }
296 
297 template <typename SCALAR>
298 bool L2ForwardBackward<SCALAR>::objective_convergence(ScalarRelativeVariation<Scalar> &scalvar,
299  t_Vector const &x,
300  t_Vector const &residual) const {
301  if (static_cast<bool>(objective_convergence())) return objective_convergence()(x, residual);
302  if (scalvar.relative_tolerance() <= 0e0) return true;
303  auto const current = ((regulariser_strength() > 0) ? sopt::l2_norm(x, l2_proximal_weights()) * regulariser_strength() : 0) +
304  std::pow(sopt::l2_norm(residual), 2) / (2 * sigma() * sigma());
305  return scalvar(current);
306 }
307 
#ifdef SOPT_MPI
//! MPI objective criterion: same as the serial overload, except that the local objective
//! contribution is combined across obj_comm with all_sum_all before the relative-variation test.
template <typename SCALAR>
bool L2ForwardBackward<SCALAR>::objective_convergence(mpi::Communicator const &obj_comm,
                                                      ScalarRelativeVariation<Scalar> &scalvar,
                                                      t_Vector const &x,
                                                      t_Vector const &residual) const {
  if (static_cast<bool>(objective_convergence())) return objective_convergence()(x, residual);
  if (scalvar.relative_tolerance() <= 0e0) return true;
  // local objective = regulariser_strength * ||x||_w + ||residual||^2 / (2 sigma^2)
  Real const regulariser_term = (regulariser_strength() > 0)
                                    ? sopt::l2_norm(x, l2_proximal_weights()) * regulariser_strength()
                                    : Real(0);
  Real const residual_norm = sopt::l2_norm(residual);
  // square by multiplication rather than std::pow(., 2)
  Real const local_objective =
      regulariser_term + residual_norm * residual_norm / (2 * sigma() * sigma());
  auto const current = obj_comm.all_sum_all<t_real>(local_objective);
  return scalvar(current);
}
#endif
322 
323 template <typename SCALAR>
324 bool L2ForwardBackward<SCALAR>::is_converged(ScalarRelativeVariation<Scalar> &scalvar,
325  t_Vector const &x,
326  t_Vector const &residual) const {
327  auto const user = static_cast<bool>(is_converged()) == false or is_converged()(x, residual);
328  auto const res = residual_convergence(x, residual);
329 #ifdef SOPT_MPI
330  auto const obj = objective_convergence(obj_comm(), scalvar, x, residual);
331 #else
332  auto const obj = objective_convergence(scalvar, x, residual);
333 #endif
334  // beware of short-circuiting!
335  // better evaluate each convergence function everytime, especially with mpi
336  return user and res and obj;
337 }
338 } // namespace sopt::algorithm
339 #endif
sopt::Vector< Scalar > t_Vector
sopt::t_real Scalar
SOPT_MACRO(residual_tolerance, Real)
Tolerance on the L2 norm of the residuals (a non-positive value disables this criterion).
SOPT_MACRO(Phi, t_LinearTransform)
Measurement operator.
SOPT_MACRO(sigma, Real)
σ parameter: the data-fidelity term is ||residual||² / (2σ²).
DiagnosticAndResult operator()(std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Forward Backward.
SOPT_MACRO(is_converged, t_IsConverged)
A function verifying convergence.
L2ForwardBackward(Eigen::MatrixBase< DERIVED > const &target)
t_Proximal< Real > & l2_proximal()
L2 proximal used during calculation.
t_Gradient & l2_gradient()
Gradient of the l2 norm.
Real objmin() const
Minimum of the objective function.
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector, t_Vector > const &guess) const
Calls Forward Backward.
DiagnosticAndResult operator()(std::tuple< t_Vector, t_Vector > const &guess) const
Calls Forward Backward.
SOPT_MACRO(tight_frame, bool)
Whether Ψ is a tight-frame or not.
typename FB::t_LinearTransform t_LinearTransform
L2ForwardBackward< Scalar > & objective_convergence(Real const &tolerance)
Helper function to set up the default objective convergence function.
DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart) const
Makes it simple to chain different calls to FB.
std::function< void(t_Vector &, const T &, const t_Vector &)> t_Proximal
DiagnosticAndResult operator()() const
Calls Forward Backward.
L2ForwardBackward &::type Phi(ARGS &&... args)
SOPT_MACRO(itermax, t_uint)
Maximum number of iterations.
SOPT_MACRO(step_size, Real)
γ parameter.
L2ForwardBackward< Scalar > & target(Eigen::MatrixBase< DERIVED > const &target)
Sets the vector of target measurements.
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Forward Backward.
SOPT_MACRO(objective_convergence, t_IsConverged)
Convergence of the objective function.
L2ForwardBackward< Scalar > & is_converged(std::function< bool(t_Vector const &x)> const &func)
Convergence function that takes only the output as argument.
SOPT_MACRO(l2_proximal, t_Proximal< Real >)
l2 proximal for regularization
SOPT_MACRO(l2_gradient, t_Gradient)
Gradient of the l2 norm.
SOPT_MACRO(residual_convergence, t_IsConverged)
Convergence of the residuals.
L2ForwardBackward< Scalar > & residual_convergence(Real const &tolerance)
Helper function to set-up default residual convergence function.
SOPT_MACRO(l2_proximal_weighted, t_Proximal< Vector< Real >>)
l2 proximal for regularization with weights
Diagnostic operator()(t_Vector &out) const
Calls Forward Backward.
SOPT_MACRO(relative_variation, Real)
Convergence of the relative variation of the objective functions.
t_Proximal< Vector< Real > > & l2_proximal_weighted()
t_Vector const & target() const
Vector of target measurements.
typename FB::t_IsConverged t_IsConverged
SOPT_MACRO(regulariser_strength, Real)
Regulariser strength λ (multiplies the l2 regularisation term).
SOPT_MACRO(l2_proximal_weights, Vector< Real >)
l2 proximal weights
sopt::LinearTransform< t_Vector > t_LinearTransform
sopt::t_real t_real
#define SOPT_MACRO(NAME, TYPE)
#define SOPT_LOW_LOG(...)
Low priority message.
Definition: logging.h:227
#define SOPT_HIGH_LOG(...)
High priority message.
Definition: logging.h:223
void l2_norm(Eigen::DenseBase< T0 > &out, typename real_type< typename T0::Scalar >::type gamma, Eigen::DenseBase< T1 > const &x)
Proximal of the l2 norm (note this is different from the l2 ball indicator function)
Definition: proximal.h:84
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
double t_real
Root of the type hierarchy for real numbers.
Definition: types.h:17
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
Vector< T > target(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:12
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
real_type< T >::type sigma(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:17
real_type< typename T0::Scalar >::type l2_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L2 norm.
Definition: maths.h:140
Values indicating how the algorithm ran.
Diagnostic(t_uint niters=0u, bool good=false)
Diagnostic(t_uint niters, bool good, t_Vector &&residual)