SOPT
Sparse OPTimisation
imaging_forward_backward.h
Go to the documentation of this file.
1 #ifndef SOPT_IMAGING_FORWARD_BACKWARD_H
2 #define SOPT_IMAGING_FORWARD_BACKWARD_H
3 
4 #include "sopt/config.h"
5 #include <limits> // for std::numeric_limits<>
6 #include <memory> // for std::shared_ptr<>
7 #include <numeric>
8 #include <tuple>
9 #include <utility>
10 #include "sopt/exception.h"
11 #include "sopt/forward_backward.h"
12 #include "sopt/linear_transform.h"
13 #include "sopt/logging.h"
14 #include "sopt/proximal.h"
16 #include "sopt/types.h"
19 
20 #include <functional>
21 #include "sopt/gradient_utils.h"
22 #include <stdexcept>
23 
24 #ifdef SOPT_MPI
25 #include "sopt/mpi/communicator.h"
26 #include "sopt/mpi/utilities.h"
27 #endif
28 
29 namespace sopt::algorithm {
30 template <typename SCALAR>
33  using FB = ForwardBackward<SCALAR>;
34 
35  public:
36  using value_type = typename FB::value_type;
37  using Scalar = typename FB::Scalar;
38  using Real = typename FB::Real;
39  using t_Vector = typename FB::t_Vector;
41  using t_Proximal = typename FB::t_Proximal;
42  using t_Gradient = typename FB::t_Gradient;
43  using t_l2Gradient = typename std::function<void(t_Vector &, const t_Vector &)>;
44  using t_IsConverged = typename FB::t_IsConverged;
45  using t_randomUpdater = typename FB::t_randomUpdater;
46 
48  struct Diagnostic : public FB::Diagnostic {
49  Diagnostic(t_uint niters = 0u, bool good = false) : FB::Diagnostic(niters, good) {}
50  Diagnostic(t_uint niters, bool good, t_Vector &&residual)
51  : FB::Diagnostic(niters, good, std::move(residual)) {}
52  };
53 
55  struct DiagnosticAndResult : public Diagnostic {
58  };
59 
63  // Note: Using setter injection instead of constructor injection to follow the
64  // style in the rest of the class, although constructor might be more appropriate
65  // In this problem we assume an objective function \f$f(x, y, \Phi) + g(x)\f$ where
66  // \f$f\f$ is differentiable with a supplied gradient and \f$g\f$ is non-differentiable with a supplied proximal operator.
67  // Throughout this class we will use `f` and `g` in variables to refer to these two parts of the objective function.
70  : g_function_(nullptr),
71  f_function_(nullptr),
72  random_updater_(nullptr),
73  tight_frame_(false),
74  residual_tolerance_(0.),
75  relative_variation_(1e-4),
76  residual_convergence_(nullptr),
77  objective_convergence_(nullptr),
78  itermax_(std::numeric_limits<t_uint>::max()),
79  regulariser_strength_(1e-8),
80  step_size_(1),
81  sigma_(1),
82  fista_(true),
83  is_converged_()
84  {
85  std::shared_ptr<t_LinearTransform> Id = std::make_shared<t_LinearTransform>(linear_transform_identity<Scalar>());
86  problem_state = std::make_shared<IterationState<t_Vector>>(target, Id);
87  }
88 
90  : g_function_(nullptr),
91  f_function_(nullptr),
92  random_updater_(updater),
93  tight_frame_(false),
94  residual_tolerance_(0.),
95  relative_variation_(1e-4),
96  residual_convergence_(nullptr),
97  objective_convergence_(nullptr),
98  itermax_(std::numeric_limits<t_uint>::max()),
99  regulariser_strength_(1e-8),
100  step_size_(1),
101  sigma_(1),
102  fista_(true),
103  is_converged_()
104  {
105  if(random_updater_)
106  {
107  // target and Phi are not known ahead of time for random data sets so need to be initialised
108  problem_state = random_updater_();
109  }
110  else
111  {
112  throw std::runtime_error("Attempted to construct ImagingForwardBackward class with a null random updater. To run without random updates supply a target vector instead.");
113  }
114  }
115 
117 // Macro helps define properties that can be initialized as in
118 // auto padmm = ImagingForwardBackward<float>().prop0(value).prop1(value);
119 #define SOPT_MACRO(NAME, TYPE) \
120  TYPE const &NAME() const { return NAME##_; } \
121  ImagingForwardBackward<SCALAR> &NAME(TYPE const &(NAME)) { \
122  NAME##_ = NAME; \
123  return *this; \
124  } \
125  \
126  protected: \
127  TYPE NAME##_; \
128  \
129  public:
130 
132  SOPT_MACRO(tight_frame, bool);
135  SOPT_MACRO(residual_tolerance, Real);
138  SOPT_MACRO(relative_variation, Real);
146  SOPT_MACRO(itermax, t_uint);
148  SOPT_MACRO(regulariser_strength, Real);
150  SOPT_MACRO(step_size, Real);
154  SOPT_MACRO(fista, bool);
157 
159  t_LinearTransform const &Phi() const { return problem_state->Phi(); }
161  problem_state->Phi(Phi);
162  return *this;
163  }
164 
165 #ifdef SOPT_MPI
167  SOPT_MACRO(obj_comm, mpi::Communicator);
168  SOPT_MACRO(adjoint_space_comm, mpi::Communicator);
169 #endif
170 
171 #undef SOPT_MACRO
172 
173  // Getter and setter for the g_function object
174  // The getter of g_function can not return a const because it will be used
175  // to call setters of its internal properties
176  std::shared_ptr<NonDifferentiableFunc<SCALAR>> g_function() { return g_function_; }
178  g_function_ = std::move(g_function);
179  return *this;
180  }
181 
182  // Getter and setter for the f_function object
183  // The getter of f_function can not return a const because it will be used
184  // to call setters of its internal properties
185  std::shared_ptr<DifferentiableFunc<SCALAR>> f_function() { return f_function_; }
187  f_function_ = std::move(f_function);
188  return *this;
189  }
190 
191  // Getter and setter for the random updater object
192  t_randomUpdater &random_updater() { return random_updater_; }
194  random_updater_ = new_updater; // may change this to a move if we don't need to keep it
195  return *this;
196  }
197 
198  t_LinearTransform const &Psi() const
199  {
200  return (g_function_) ? g_function_->Psi() : linear_transform_identity<Scalar>();
201  }
202 
203  // Default f_gradient is gradient of l2-norm
204  // This gradient ignores x and is based only on residual. (x is required for other forms of gradient)
205  //t_Gradient f_gradient;
206 
207  //void set_f_gradient(const t_Gradient &fgrad)
208  //{
209  // f_gradient = fgrad;
210  //}
211 
213  t_Vector const &target() const { return problem_state->target(); }
214 
216  Real objmin() const { return objmin_; }
219  problem_state->target(target);
220  return *this;
221  }
222 
226  return operator()(out, ForwardBackward<SCALAR>::initial_guess(target(), Phi()));
227 
228  }
232  Diagnostic operator()(t_Vector &out, std::tuple<t_Vector, t_Vector> const &guess) {
233  return operator()(out, std::get<0>(guess), std::get<1>(guess));
234  }
239  std::tuple<t_Vector const &, t_Vector const &> const &guess) {
240  return operator()(out, std::get<0>(guess), std::get<1>(guess));
241  }
244  DiagnosticAndResult operator()(std::tuple<t_Vector, t_Vector> const &guess) {
245  return operator()(std::tie(std::get<0>(guess), std::get<1>(guess)));
246  }
250  std::tuple<t_Vector const &, t_Vector const &> const &guess) {
251  DiagnosticAndResult result;
252  static_cast<Diagnostic &>(result) = operator()(result.x, guess);
253  return result;
254  }
258  DiagnosticAndResult result;
259  static_cast<Diagnostic &>(result) = operator()(
260  result.x, ForwardBackward<SCALAR>::initial_guess(target(), Phi()));
261  return result;
262  }
265  DiagnosticAndResult result = warmstart;
266  static_cast<Diagnostic &>(result) = operator()(result.x, warmstart.x, warmstart.residual);
267  return result;
268  }
269 
271  template <typename... ARGS>
272  typename std::enable_if<sizeof...(ARGS) >= 1, ImagingForwardBackward &>::type Phi(
273  ARGS &&... args) {
274  problem_state->Phi(linear_transform(std::forward<ARGS>(args)...));
275  return *this;
276  }
277 
280  return residual_convergence(nullptr).residual_tolerance(tolerance);
281  }
284  return objective_convergence(nullptr).relative_variation(tolerance);
285  }
287  ImagingForwardBackward<Scalar> &is_converged(std::function<bool(t_Vector const &x)> const &func) {
288  return is_converged([func](t_Vector const &x, t_Vector const &) { return func(x); });
289  }
290 
291  protected:
292 
293  // Store a pointer of the abstract base classes DifferentiableFunc & NonDifferentiableFunction type for f and g
294  // These should point to an instance of a derived class (e.g. L1GProximal) once set up
295  std::shared_ptr<NonDifferentiableFunc<SCALAR>> g_function_;
296  std::shared_ptr<DifferentiableFunc<SCALAR>> f_function_;
297  t_randomUpdater random_updater_;
299  std::shared_ptr<IterationState<t_Vector>> problem_state;
300 
302  mutable Real objmin_;
303 
308  Diagnostic operator()(t_Vector &out, t_Vector const &guess, t_Vector const &res) ;
309 
311  bool residual_convergence(t_Vector const &x, t_Vector const &residual) const;
312 
315  t_Vector const &residual) const;
316 #ifdef SOPT_MPI
318  bool objective_convergence(mpi::Communicator const &obj_comm,
319  ScalarRelativeVariation<Scalar> &scalvar, t_Vector const &x,
320  t_Vector const &residual) const;
321 #endif
322 
324  bool is_converged(ScalarRelativeVariation<Scalar> &scalvar, t_Vector const &x,
325  t_Vector const &residual) const;
326 };
327 
328 template <typename SCALAR>
329 typename ImagingForwardBackward<SCALAR>::Diagnostic ImagingForwardBackward<SCALAR>::operator()(
330  t_Vector &out, t_Vector const &guess, t_Vector const &res) {
331  if(!g_function_)
332  {
333  throw std::runtime_error("Non-differentiable function `g` has not been set. You must set it with `g_function()` before calling the algorithm.");
334  }
335  g_function_->log_message();
336  Diagnostic result;
337  auto const g_proximal = g_function_->proximal_operator();
338  t_Gradient f_gradient;
339  Real gradient_step_size;
340  if(f_function_)
341  {
342  f_gradient = f_function_->gradient();
343  gradient_step_size = f_function_->get_step_size();
344  }
345  if(!f_gradient)
346  {
347  SOPT_MEDIUM_LOG("Gradient function has not been set; using default (gaussian likelihood) gradient. (To set a custom gradient set_gradient() must be called before the algorithm is run.)");
348  f_gradient = [this](t_Vector &output, t_Vector const &x, t_Vector const &residual, t_LinearTransform const &Phi) {
349  output = Phi.adjoint() * (residual / (this->sigma() * this->sigma()));
350  };
351  gradient_step_size = sigma()*sigma();
352  }
353  ScalarRelativeVariation<Scalar> scalvar(relative_variation(), relative_variation(),
354  "Objective function");
355  auto const convergence = [this, &scalvar](t_Vector const &x, t_Vector const &residual) mutable {
356  const bool result = this->is_converged(scalvar, x, residual);
357  this->objmin_ = std::real(scalvar.previous());
358  return result;
359  };
360  auto fb = ForwardBackward<SCALAR>(f_gradient, g_proximal, target())
361  .itermax(itermax())
362  .step_size(gradient_step_size)
363  .regulariser_strength(regulariser_strength())
364  .fista(fista())
365  .Phi(Phi())
366  .is_converged(convergence)
367  .random_updater(random_updater_)
368  .set_problem_state(problem_state);
369  static_cast<typename ForwardBackward<SCALAR>::Diagnostic &>(result) =
370  fb(out, std::tie(guess, res));
371  return result;
372 }
373 
374 template <typename SCALAR>
376  t_Vector const &residual) const {
377  if (static_cast<bool>(residual_convergence())) return residual_convergence()(x, residual);
378  if (residual_tolerance() <= 0e0) return true;
379  auto const residual_norm = sopt::l2_norm(residual);
380  SOPT_LOW_LOG(" - [FB] Residuals: {} <? {}", residual_norm, residual_tolerance());
381  return residual_norm < residual_tolerance();
382 }
383 
384 template <typename SCALAR>
385 bool ImagingForwardBackward<SCALAR>::objective_convergence(ScalarRelativeVariation<Scalar> &scalvar,
386  t_Vector const &x,
387  t_Vector const &residual) const {
388  if (static_cast<bool>(objective_convergence())) return objective_convergence()(x, residual);
389  if (scalvar.relative_tolerance() <= 0e0) return true;
390  auto const current = ((regulariser_strength() > 0) ? g_function_->function(x)
391  * regulariser_strength() : 0) + \
392  ((f_function_) ? f_function_->function(x, target(), Phi()) : std::pow(sopt::l2_norm(residual), 2) / (2 * sigma() * sigma()));
393  return scalvar(current);
394 }
395 
#ifdef SOPT_MPI
// Objective-based convergence criterion (MPI version).
// Same contract as the serial overload, except that the locally computed objective is summed
// over `obj_comm` before being passed to the relative-variation tracker.
// NOTE(review): unlike the serial overload, this always uses the default
// ||residual||^2 / (2 sigma^2) data term and never consults f_function_ -- confirm intended.
template <typename SCALAR>
bool ImagingForwardBackward<SCALAR>::objective_convergence(mpi::Communicator const &obj_comm,
                                                           ScalarRelativeVariation<Scalar> &scalvar,
                                                           t_Vector const &x,
                                                           t_Vector const &residual) const {
  if (objective_convergence()) return objective_convergence()(x, residual);
  // A non-positive relative tolerance switches this criterion off.
  if (scalvar.relative_tolerance() <= 0e0) return true;

  // Local contributions of this rank; they are reduced across obj_comm below.
  auto const regulariser_term =
      (regulariser_strength() > 0) ? g_function_->function(x) * regulariser_strength() : 0;
  auto const fidelity_term = std::pow(sopt::l2_norm(residual), 2) / (2 * sigma() * sigma());
  auto const current = obj_comm.all_sum_all<t_real>(regulariser_term + fidelity_term);
  return scalvar(current);
}
#endif
410 
411 template <typename SCALAR>
412 bool ImagingForwardBackward<SCALAR>::is_converged(ScalarRelativeVariation<Scalar> &scalvar,
413  t_Vector const &x,
414  t_Vector const &residual) const {
415  auto const user = static_cast<bool>(is_converged()) == false or is_converged()(x, residual);
416  auto const res = residual_convergence(x, residual);
417 #ifdef SOPT_MPI
418  auto const obj = objective_convergence(obj_comm(), scalvar, x, residual);
419 #else
420  auto const obj = objective_convergence(scalvar, x, residual);
421 #endif
422  // beware of short-circuiting!
423  // better evaluate each convergence function everytime, especially with mpi
424  return user and res and obj;
425 }
426 } // namespace sopt::algorithm
427 #endif
sopt::Vector< Scalar > t_Vector
sopt::t_real Scalar
Real objmin() const
Minimum of objective_function.
typename FB::t_LinearTransform t_LinearTransform
Diagnostic operator()(t_Vector &out) const
Calls Forward Backward.
ImagingForwardBackward< SCALAR > & random_updater(t_randomUpdater &new_updater)
std::shared_ptr< NonDifferentiableFunc< SCALAR > > g_function()
std::shared_ptr< DifferentiableFunc< SCALAR > > f_function()
SOPT_MACRO(sigma, Real)
σ parameter (noise standard deviation used by the default Gaussian likelihood).
ImagingForwardBackward< SCALAR > & Phi(t_LinearTransform const &(Phi))
t_LinearTransform const & Phi() const
Measurement operator.
DiagnosticAndResult operator()(std::tuple< t_Vector, t_Vector > const &guess)
Calls Forward Backward.
ImagingForwardBackward< Scalar > & target(t_Vector const &target)
Sets the vector of target measurements.
ImagingForwardBackward< Scalar > & is_converged(std::function< bool(t_Vector const &x)> const &func)
Convergence function that takes only the output as argument.
DiagnosticAndResult operator()(std::tuple< t_Vector const &, t_Vector const & > const &guess)
Calls Forward Backward.
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector const &, t_Vector const & > const &guess)
Calls Forward Backward.
ImagingForwardBackward< SCALAR > & g_function(std::shared_ptr< NonDifferentiableFunc< SCALAR >> g_function)
SOPT_MACRO(relative_variation, Real)
Convergence of the relative variation of the objective functions.
ImagingForwardBackward< Scalar > & residual_convergence(Real const &tolerance)
Helper function to set up the default residual convergence function.
SOPT_MACRO(step_size, Real)
γ parameter.
DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart)
Makes it simple to chain different calls to FB.
ImagingForwardBackward< Scalar > & objective_convergence(Real const &tolerance)
Helper function to set up the default objective convergence function.
SOPT_MACRO(regulariser_strength, Real)
Regularisation strength (weight applied to the non-differentiable term g).
ImagingForwardBackward &::type Phi(ARGS &&... args)
typename std::function< void(t_Vector &, const t_Vector &)> t_l2Gradient
ImagingForwardBackward< SCALAR > & f_function(std::shared_ptr< DifferentiableFunc< SCALAR >> f_function)
DiagnosticAndResult operator()()
Calls Forward Backward.
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector, t_Vector > const &guess)
Calls Forward Backward.
SOPT_MACRO(itermax, t_uint)
Maximum number of iterations.
SOPT_MACRO(objective_convergence, t_IsConverged)
Convergence of the objective function.
SOPT_MACRO(fista, bool)
Flag for the FISTA Forward-Backward algorithm. True by default but should be false when using a learne...
SOPT_MACRO(residual_convergence, t_IsConverged)
Convergence of the residuals.
SOPT_MACRO(is_converged, t_IsConverged)
A function verifying convergence.
t_Vector const & target() const
Vector of target measurements.
SOPT_MACRO(tight_frame, bool)
Whether Ψ is a tight-frame or not.
SOPT_MACRO(residual_tolerance, Real)
Convergence of the residuals (tolerance on the l2 norm of the residuals).
sopt::LinearTransform< t_Vector > t_LinearTransform
sopt::t_real t_real
#define SOPT_LOW_LOG(...)
Low priority message.
Definition: logging.h:227
#define SOPT_MEDIUM_LOG(...)
Medium priority message.
Definition: logging.h:225
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
Vector< T > target(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:12
real_type< T >::type sigma(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:17
real_type< typename T0::Scalar >::type l2_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L2 norm.
Definition: maths.h:140
Diagnostic(t_uint niters, bool good, t_Vector &&residual)