SOPT
Sparse OPTimisation
forward_backward.h
Go to the documentation of this file.
1 #ifndef SOPT_FORWARD_BACKWARD_H
2 #define SOPT_FORWARD_BACKWARD_H
3 
#include "sopt/config.h"
#include <cmath>        // for std::sqrt
#include <functional>
#include <limits>
#include <memory>       // for std::shared_ptr, std::make_shared
#include <tuple> // for tuple<>
#include <type_traits>  // for std::enable_if
#include <utility> // for std::move<>
#include "sopt/exception.h"
#include "sopt/linear_transform.h"
#include "sopt/logging.h"
#include "sopt/types.h"
13 
14 #include "sopt/gradient_utils.h"
15 
16 namespace sopt::algorithm {
17 
26 template <typename SCALAR>
27 class ForwardBackward {
28  public:
30  using value_type = SCALAR;
32  using Scalar = value_type;
34  using Real = typename real_type<Scalar>::type;
36  using t_Vector = Vector<Scalar>;
40  using t_IsConverged = std::function<bool(const t_Vector &, const t_Vector &)>;
42  using t_Proximal = ProximalFunction<Scalar>;
44  // The first argument is the output vector, the others are inputs
45  using t_Gradient = std::function<void(t_Vector &gradient, const t_Vector &image, const t_Vector &residual, const t_LinearTransform& Phi)>;
46  using t_randomUpdater = std::function<std::shared_ptr<IterationState<t_Vector>>()>;
47 
49  struct Diagnostic {
51  t_uint niters;
53  bool good;
55  t_Vector residual;
56 
57  Diagnostic(t_uint niters = 0u, bool good = false)
58  : niters(niters), good(good), residual(t_Vector::Zero(0)) {}
59  Diagnostic(t_uint niters, bool good, t_Vector &&residual)
60  : niters(niters), good(good), residual(std::move(residual)) {}
61  };
63  struct DiagnosticAndResult : public Diagnostic {
65  t_Vector x;
66  };
67 
71  ForwardBackward(t_Gradient const &f_gradient, t_Proximal const &g_proximal,
72  t_Vector const &target)
73  : itermax_(std::numeric_limits<t_uint>::max()),
74  regulariser_strength_(1e-8),
75  step_size_(1),
76  is_converged_(),
77  fista_(true),
78  f_gradient_(f_gradient),
79  g_proximal_(g_proximal)
80  {
81  std::shared_ptr<t_LinearTransform> Id = std::make_shared<t_LinearTransform>(linear_transform_identity<Scalar>());
82  problem_state = std::make_shared<IterationState<t_Vector>>(target, Id);
83  }
84  virtual ~ForwardBackward() {}
85 
86 // Macro helps define properties that can be initialized as in
87 // auto sdmm = ForwardBackward<float>().prop0(value).prop1(value);
88 #define SOPT_MACRO(NAME, TYPE) \
89  TYPE const &NAME() const { return NAME##_; } \
90  ForwardBackward<SCALAR> &NAME(TYPE const &(NAME)) { \
91  NAME##_ = NAME; \
92  return *this; \
93  } \
94  \
95  protected: \
96  TYPE NAME##_; \
97  \
98  public:
99 
101  SOPT_MACRO(itermax, t_uint);
103  SOPT_MACRO(regulariser_strength, Real);
105  SOPT_MACRO(step_size, Real);
107  SOPT_MACRO(fista, bool);
110  SOPT_MACRO(is_converged, t_IsConverged);
112  SOPT_MACRO(f_gradient, t_Gradient);
114  SOPT_MACRO(g_proximal, t_Proximal);
115 
117  t_LinearTransform const &Phi() const { return problem_state->Phi(); }
118  ForwardBackward<SCALAR> &Phi(t_LinearTransform const &new_phi) {
119  problem_state->Phi(new_phi);
120  return *this;
121  }
122 
123  ForwardBackward<SCALAR> &random_updater(t_randomUpdater &rU)
124  {
125  random_updater_ = rU;
126  return *this;
127  }
128 
129  ForwardBackward<SCALAR> &set_problem_state(std::shared_ptr<IterationState<t_Vector>> pS)
130  {
131  problem_state = pS;
132  return *this;
133  }
134 
135 #undef SOPT_MACRO
137  void f_gradient(t_Vector &out, t_Vector const &x, t_Vector const &res, t_LinearTransform const &Phi) const { f_gradient()(out, x, res, Phi); }
139  void g_proximal(t_Vector &out, Real regulariser_strength, t_Vector const &x) const {
140  g_proximal()(out, regulariser_strength, x);
141  }
142 
144  ForwardBackward<Scalar> &is_converged(std::function<bool(t_Vector const &x)> const &func) {
145  return is_converged([func](t_Vector const &x, t_Vector const &) { return func(x); });
146  }
147 
149  t_Vector const &target() const { return problem_state->target(); }
151  ForwardBackward<Scalar> &target(t_Vector const &target) {
152  problem_state->target(target);
153  return *this;
154  }
155 
157  bool is_converged(t_Vector const &x, t_Vector const &residual) const {
158  return static_cast<bool>(is_converged()) and is_converged()(x, residual);
159  }
160 
163  Diagnostic operator()(t_Vector &out) { return operator()(out, initial_guess()); }
167  Diagnostic operator()(t_Vector &out, std::tuple<t_Vector, t_Vector> const &guess) {
168  return operator()(out, std::get<0>(guess), std::get<1>(guess));
169  }
173  Diagnostic operator()(t_Vector &out,
174  std::tuple<t_Vector const &, t_Vector const &> const &guess) {
175  return operator()(out, std::get<0>(guess), std::get<1>(guess));
176  }
179  DiagnosticAndResult operator()(std::tuple<t_Vector, t_Vector> const &guess) {
180  return operator()(std::tie(std::get<0>(guess), std::get<1>(guess)));
181  }
184  DiagnosticAndResult operator()(
185  std::tuple<t_Vector const &, t_Vector const &> const &guess) {
186  DiagnosticAndResult result;
187  static_cast<Diagnostic &>(result) = operator()(result.x, guess);
188  return result;
189  }
192  DiagnosticAndResult operator()() {
193  DiagnosticAndResult result;
194  static_cast<Diagnostic &>(result) = operator()(result.x, initial_guess());
195  return result;
196  }
198  DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart) {
199  DiagnosticAndResult result = warmstart;
200  static_cast<Diagnostic &>(result) = operator()(result.x, warmstart.x, warmstart.residual);
201  return result;
202  }
204  template <typename... ARGS>
205  typename std::enable_if<sizeof...(ARGS) >= 1, ForwardBackward &>::type Phi(ARGS &&... args) {
206  problem_state->Phi(linear_transform(std::forward<ARGS>(args)...));
207  return *this;
208  }
209 
214  std::tuple<t_Vector, t_Vector> initial_guess() const {
215  return ForwardBackward<SCALAR>::initial_guess(target(), Phi());
216  }
217 
224  static std::tuple<t_Vector, t_Vector> initial_guess(t_Vector const &target,
225  t_LinearTransform const &phi) {
226  std::tuple<t_Vector, t_Vector> guess;
227  std::get<0>(guess) = static_cast<t_Vector>(phi.adjoint() * target) / phi.sq_norm();
228  std::get<1>(guess) = phi * std::get<0>(guess) - target;
229  return guess;
230  }
231 
232  protected:
233  void iteration_step(t_Vector &out, t_Vector &residual, t_Vector &p, t_Vector &z,
234  const t_real lambda);
235 
237  void sanity_check(t_Vector const &x_guess, t_Vector const &res_guess) const {
238  if ((Phi().adjoint() * target()).size() != x_guess.size())
239  SOPT_THROW("target, adjoint measurement operator and input vector have inconsistent sizes");
240  if (target().size() != res_guess.size())
241  SOPT_THROW("target and residual vector have inconsistent sizes");
242  if ((Phi() * x_guess).size() != target().size())
243  SOPT_THROW("target, measurement operator and input vector have inconsistent sizes");
244  if (not static_cast<bool>(is_converged()))
245  SOPT_WARN("No convergence function was provided: algorithm will run for {} steps", itermax());
246  }
247 
252  Diagnostic operator()(t_Vector &out, t_Vector const &guess, t_Vector const &res);
253 
255  std::shared_ptr<IterationState<t_Vector>> problem_state;
256  t_randomUpdater random_updater_;
257 };
258 
273 template <typename SCALAR>
274 void ForwardBackward<SCALAR>::iteration_step(t_Vector &image, t_Vector &residual, t_Vector &auxilliary_image,
275  t_Vector &gradient_current, const t_real FISTA_step) {
276  t_Vector prev_image = image;
277  f_gradient(gradient_current, auxilliary_image, residual, Phi()); // assigns gradient_current (non normalised)
278  t_Vector auxilliary_with_step = auxilliary_image - step_size() / Phi().sq_norm() * gradient_current; // step to new image using gradient
279  const Real weight = regulariser_strength() * step_size();
280  g_proximal(image, weight, auxilliary_with_step); // apply proximal operator to new image
281  auxilliary_image = image + FISTA_step * (image - prev_image); // update auxilliary vector with FISTA acceleration step
282 
283  // set up next iteration
284  if(random_updater_)
285  {
286  problem_state = random_updater_();
287  }
288  residual = (Phi() * auxilliary_image) - target(); // updates the residual for the NEXT iteration (new image).
289 }
290 
291 template <typename SCALAR>
292 typename ForwardBackward<SCALAR>::Diagnostic ForwardBackward<SCALAR>::operator()(
293  t_Vector &out, t_Vector const &x_guess, t_Vector const &res_guess) {
294  SOPT_HIGH_LOG("Performing Forward Backward Splitting");
295  if (fista()) {
296  SOPT_HIGH_LOG("Using FISTA algorithm");
297  } else {
298  SOPT_HIGH_LOG("Using standard FB algorithm");
299  }
300  sanity_check(x_guess, res_guess);
301 
302  const size_t image_size = x_guess.size();
303 
304  t_Vector auxilliary_image = x_guess;
305  t_Vector residual = res_guess;
306  t_Vector gradient_current = t_Vector::Zero(image_size);
307  out = x_guess;
308 
309  t_uint niters(0);
310  bool converged = false;
311  Real theta = 1.0;
312  Real theta_new = 1.0;
313  Real FISTA_step = 0.0;
314  for (; (not converged) && (niters < itermax()); ++niters) {
315  SOPT_MEDIUM_LOG(" - [FB] Iteration {}/{}", niters, itermax());
316  if (fista()) {
317  theta_new = (1 + std::sqrt(1 + 4 * theta * theta)) / 2.;
318  FISTA_step = (theta - 1) / (theta_new);
319  theta = theta_new;
320  }
321  SOPT_LOW_LOG(" - Call iteration step");
322  iteration_step(out, residual, auxilliary_image, gradient_current, FISTA_step);
323  SOPT_LOW_LOG(" - [FB] Sum of residuals: {}", residual.array().abs().sum());
324  converged = is_converged(out, residual);
325  }
326 
327  if (converged) {
328  SOPT_MEDIUM_LOG(" - [FB] converged in {} of {} iterations", niters, itermax());
329  } else if (static_cast<bool>(is_converged())) {
330  // not meaningful if not convergence function
331  SOPT_ERROR(" - [FB] did not converge within {} iterations", itermax());
332  }
333  return {niters, converged, std::move(residual)};
334 }
335 } // namespace sopt::algorithm
336 #endif
sopt::Vector< Scalar > t_Vector
sopt::t_real Scalar
sopt::t_real sq_norm() const
LinearTransform< VECTOR > adjoint() const
Indirect transform.
Computes inner-most element type.
Definition: real_type.h:42
#define SOPT_THROW(MSG)
Definition: exception.h:46
#define SOPT_MACRO(NAME, TYPE)
#define SOPT_LOW_LOG(...)
Low priority message.
Definition: logging.h:227
#define SOPT_HIGH_LOG(...)
High priority message.
Definition: logging.h:223
#define SOPT_ERROR(...)
Something is definitely wrong; the algorithm exits.
Definition: logging.h:211
#define SOPT_WARN(...)
Something might be going wrong.
Definition: logging.h:213
#define SOPT_MEDIUM_LOG(...)
Medium priority message.
Definition: logging.h:225
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
double t_real
Root of the type hierarchy for real numbers.
Definition: types.h:17
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
Vector< T > target(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:12
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
std::function< void(Vector< SCALAR > &output, typename real_type< SCALAR >::type const weight, Vector< SCALAR > const &input)> ProximalFunction
Typical function signature for calls to proximal.
Definition: types.h:48