SOPT
Sparse OPTimisation
sdmm.h
Go to the documentation of this file.
1 #ifndef SOPT_SDMM_H
2 #define SOPT_SDMM_H
3 
4 #include "sopt/config.h"
5 #include <limits>
6 #include <numeric>
7 #include <utility> // for std::forward<>
8 #include <vector>
10 #include "sopt/exception.h"
11 #include "sopt/linear_transform.h"
12 #include "sopt/logging.h"
13 #include "sopt/proximal.h"
15 #include "sopt/types.h"
16 #include "sopt/wrapper.h"
17 
18 namespace sopt::algorithm {
19 
22 template <typename SCALAR>
23 class SDMM {
24  public:
26  struct Diagnostic {
30  bool good;
33  };
34  struct DiagnosticAndResult : public Diagnostic {
37  };
39  using value_type = SCALAR;
41  using Scalar = value_type;
43  using Real = typename real_type<Scalar>::type;
52 
53  SDMM()
54  : itermax_(std::numeric_limits<t_uint>::max()),
55  gamma_(1e-8),
56  conjugate_gradient_(std::numeric_limits<t_uint>::max(), 1e-6),
57  is_converged_([](t_Vector const &) { return false; }) {}
58  virtual ~SDMM() {}
59 
60 // Macro helps define properties that can be initialized as in
61 // auto sdmm = SDMM<float>().prop0(value).prop1(value);
62 #define SOPT_MACRO(NAME, TYPE) \
63  TYPE const &NAME() const { return NAME##_; } \
64  SDMM<SCALAR> &NAME(TYPE const &(NAME)) { \
65  NAME##_ = NAME; \
66  return *this; \
67  } \
68  \
69  protected: \
70  TYPE NAME##_; \
71  \
72  public:
74  SOPT_MACRO(itermax, t_uint);
76  SOPT_MACRO(gamma, Real);
81 #undef SOPT_MACRO
84  conjugate_gradient_.itermax(itermax);
85  conjugate_gradient_.tolerance(tolerance);
86  return *this;
87  }
89  template <typename PROXIMAL, typename T>
90  SDMM<SCALAR> &append(PROXIMAL proximal, T args) {
91  proximals().emplace_back(proximal);
92  transforms().emplace_back(linear_transform(args));
93  return *this;
94  }
96  template <typename PROXIMAL>
97  SDMM<SCALAR> &append(PROXIMAL proximal) {
98  return append(proximal, linear_transform_identity<Scalar>());
99  }
101  template <typename PROXIMAL, typename L, typename LADJOINT>
102  SDMM<SCALAR> &append(PROXIMAL proximal, L l, LADJOINT ladjoint) {
103  return append(proximal, linear_transform<t_Vector>(l, ladjoint));
104  }
106  template <typename PROXIMAL, typename L, typename LADJOINT>
107  SDMM<SCALAR> &append(PROXIMAL proximal, L l, LADJOINT ladjoint, std::array<t_int, 3> sizes) {
108  return append(proximal, linear_transform<t_Vector>(l, ladjoint, sizes));
109  }
111  template <typename PROXIMAL, typename L, typename LADJOINT>
112  SDMM<SCALAR> &append(PROXIMAL proximal, L l, std::array<t_int, 3> dsizes, LADJOINT ladjoint,
113  std::array<t_int, 3> isizes) {
114  return append(proximal, linear_transform<t_Vector>(l, dsizes, ladjoint, isizes));
115  }
116 
121  Diagnostic operator()(t_Vector &out, t_Vector const &input) const;
123  DiagnosticAndResult result;
124  static_cast<Diagnostic &>(result) = operator()(result.x, input);
125  return result;
126  }
129  DiagnosticAndResult result;
130  static_cast<Diagnostic &>(result) = operator()(result.x, warmstart.x);
131  return result;
132  }
133 
135  std::vector<t_LinearTransform> const &transforms() const { return transforms_; }
137  std::vector<t_LinearTransform> &transforms() { return transforms_; }
139  t_LinearTransform const &transforms(t_uint i) const { return transforms_[i]; }
141  t_LinearTransform &transforms(t_uint i) { return transforms_[i]; }
142 
144  std::vector<t_Proximal> const &proximals() const { return proximals_; }
146  std::vector<t_Proximal> &proximals() { return proximals_; }
148  t_Proximal const &proximals(t_uint i) const { return proximals_[i]; }
150  t_Proximal &proximals(t_uint i) { return proximals_[i]; }
152  template <typename T0>
154  t_uint i, Eigen::MatrixBase<T0> const &x) const {
155  return {proximals()[i], gamma(), x};
156  }
157 
159  t_uint size() const { return proximals().size(); }
160 
161  // We must declare the first argument explicitly so that this function never
162  // matches the getter with the same name.
165  template <typename T0, typename... T>
166  auto conjugate_gradient(T0 &&t0, T &&... args) const
167  -> decltype(this->conjugate_gradient()(std::forward<T0>(t0), std::forward<T>(args)...)) {
168  return conjugate_gradient()(std::forward<T0>(t0), std::forward<T>(args)...);
169  }
170 
172  bool is_converged(t_Vector const &x) const { return is_converged()(x); }
173 
174  protected:
176  std::vector<t_LinearTransform> transforms_;
178  std::vector<t_Proximal> proximals_;
179 
181  using t_Vectors = std::vector<t_Vector>;
183  virtual ConjugateGradient::Diagnostic solve_for_xn(t_Vector &out, t_Vectors const &y,
184  t_Vectors const &z) const;
186  virtual void update_directions(t_Vectors &y, t_Vectors &z, t_Vector const &x) const;
187 
189  virtual void initialization(t_Vectors &y, t_Vectors &z, t_Vector const &x) const;
190 
192  virtual void sanity_check(t_Vector const &input) const;
193 };
194 
195 template <typename SCALAR>
197  t_Vector const &input) const {
198  sanity_check(input);
199  bool convergence = false;
200  t_uint niters(0);
201  // Figures out whether itermax or convergence has been reached
202  auto const has_finished = [&convergence, &niters, this](t_Vector const &out) {
203  convergence = is_converged(out);
204  return niters >= itermax() or convergence;
205  };
206 
207  SOPT_HIGH_LOG("Performing SDMM ");
208  out = input;
209  t_Vectors y(transforms().size());
210  t_Vectors z(transforms().size());
211 
212  // Initial step replaces iteration update with initialization
213  initialization(y, z, input);
214  auto cg_diagnostic = solve_for_xn(out, y, z);
215 
216  while (not has_finished(out)) {
217  SOPT_LOW_LOG("Iteration {}/{}. ", niters, itermax());
218  // computes y and z from out and transforms
219  update_directions(y, z, out);
220  SOPT_LOW_LOG(" - sum z_ij = {}",
221  std::accumulate(z.begin(), z.end(), Scalar(0e0),
222  [](Scalar const &a, t_Vector const &z) { return a + z.sum(); }));
223  // computes x = L^-1 y
224  cg_diagnostic = solve_for_xn(out, y, z);
225  SOPT_LOW_LOG(" - CG Residual = {} in {}/{} iterations", cg_diagnostic.residual,
226  cg_diagnostic.niters, conjugate_gradient().itermax());
227 
228  ++niters;
229  }
230  return {niters, convergence, cg_diagnostic};
231 }
232 
233 template <typename SCALAR>
235  t_Vectors const &z) const {
236  assert(z.size() == transforms().size());
237  assert(y.size() == transforms().size());
238  SOPT_TRACE("Solving for x_n");
239 
240  // Initialize b of A x = b = sum_i L_i^H(y_i - z_i)
241  t_Vector b = out.Zero(out.size());
242  for (t_uint i(0); i < transforms().size(); ++i) b += transforms(i).adjoint() * (y[i] - z[i]);
243  if (b.stableNorm() < 1e-12) {
244  out.fill(0e0);
245  return {0, 0, true};
246  }
247 
248  // Then create operator A
249  auto A = [this](t_Vector &out, t_Vector const &input) {
250  out = out.Zero(input.size());
251  for (auto const &transform : this->transforms())
252  out += transform.adjoint() * static_cast<t_Vector>(transform * input);
253  };
254 
255  // Call conjugate gradient
256  auto const diagnostic = this->conjugate_gradient(out, A, b);
257  if (not diagnostic.good) {
258  SOPT_ERROR("CG error - iterations: {}/{} - residuals {}\n", diagnostic.niters,
259  conjugate_gradient().itermax(), diagnostic.residual);
260  SOPT_THROW("Conjugate gradient failed to converge");
261  }
262 
263  return diagnostic;
264 }
265 
266 template <typename SCALAR>
267 void SDMM<SCALAR>::update_directions(t_Vectors &y, t_Vectors &z, t_Vector const &x) const {
268  SOPT_TRACE("Updating directions");
269  for (t_uint i(0); i < transforms().size(); ++i) {
270  z[i] += transforms(i) * x;
271  y[i] = proximals(i, z[i]);
272  z[i] -= y[i];
273  }
274 }
275 
276 template <typename SCALAR>
277 void SDMM<SCALAR>::initialization(t_Vectors &y, t_Vectors &z, t_Vector const &x) const {
278  SOPT_TRACE("Initializing SDMM");
279  for (t_uint i(0); i < transforms().size(); i++) {
280  y[i] = transforms(i) * x;
281  z[i].resize(y[i].size());
282  z[i].fill(0);
283  assert(z[i].size() == y[i].size());
284  SOPT_TRACE(" - transform {}: {}", i, y[i].transpose());
285  }
286 }
287 
288 template <typename SCALAR>
289 void SDMM<SCALAR>::sanity_check(t_Vector const &x) const {
290  bool doexit = false;
291  if (proximals().size() != transforms().size()) {
292  SOPT_ERROR("Internal error: number of proximals and transforms do not match");
293  doexit = true;
294  }
295  if (x.size() == 0) SOPT_WARN("Input vector has zero size");
296  if (size() == 0) SOPT_WARN("No operators - SDMM is empty");
297  for (t_uint i(0); i < size(); ++i) {
298  auto const xdual = t_Vector::Zero((transforms(i) * x).size());
299  auto const r = (transforms(i).adjoint() * xdual).size();
300  if (r != x.size()) {
301  SOPT_ERROR("Output size of transform {} and input do not match: {} vs {}", i, r, x.size());
302  doexit = true;
303  }
304  }
305  if (doexit) SOPT_THROW("Input to SDMM is inconsistent");
306 }
307 } // namespace sopt::algorithm
308 #endif
sopt::Vector< Scalar > t_Vector
constexpr Scalar b
constexpr Scalar a
Solves $Ax = b$ for $x$, given $A$ and $b$.
Simultaneous-direction method of the multipliers.
Definition: sdmm.h:23
SDMM< SCALAR > & append(PROXIMAL proximal, L l, std::array< t_int, 3 > dsizes, LADJOINT ladjoint, std::array< t_int, 3 > isizes)
Appends a proximal with the linear transform as pair of functions.
Definition: sdmm.h:112
auto conjugate_gradient(T0 &&t0, T &&... args) const -> decltype(this->conjugate_gradient()(std::forward< T0 >(t0), std::forward< T >(args)...))
Forwards to internal conjugate gradient object.
Definition: sdmm.h:166
std::vector< t_LinearTransform > const & transforms() const
Linear transforms associated with each objective function.
Definition: sdmm.h:135
bool is_converged(t_Vector const &x) const
Forwards to convergence function parameter.
Definition: sdmm.h:172
virtual ~SDMM()
Definition: sdmm.h:58
SDMM< SCALAR > & append(PROXIMAL proximal, L l, LADJOINT ladjoint)
Appends a proximal with the linear transform as pair of functions.
Definition: sdmm.h:102
proximal::ProximalExpression< t_Proximal const &, T0 > proximals(t_uint i, Eigen::MatrixBase< T0 > const &x) const
Lazy call to specific proximal function.
Definition: sdmm.h:153
SDMM< SCALAR > & append(PROXIMAL proximal, L l, LADJOINT ladjoint, std::array< t_int, 3 > sizes)
Appends a proximal with the linear transform as pair of functions.
Definition: sdmm.h:107
SDMM< SCALAR > & append(PROXIMAL proximal)
Appends a proximal with identity as the linear transform.
Definition: sdmm.h:97
std::vector< t_Proximal > & proximals()
Proximal of each objective function.
Definition: sdmm.h:146
Diagnostic operator()(t_Vector &out, t_Vector const &input) const
Implements SDMM.
Definition: sdmm.h:196
t_Proximal & proximals(t_uint i)
Proximal associated with a given objective function.
Definition: sdmm.h:150
Vector< SCALAR > t_Vector
Type of the underlying vectors.
Definition: sdmm.h:45
value_type Scalar
Scalar type.
Definition: sdmm.h:41
SDMM< SCALAR > & conjugate_gradient(t_uint itermax, t_real tolerance)
Helps setup conjugate gradient.
Definition: sdmm.h:83
t_LinearTransform & transforms(t_uint i)
Linear transform associated with a given objective function.
Definition: sdmm.h:141
typename real_type< Scalar >::type Real
Real type.
Definition: sdmm.h:43
t_Proximal const & proximals(t_uint i) const
Proximal associated with a given objective function.
Definition: sdmm.h:148
std::vector< t_Proximal > const & proximals() const
Proximal of each objective function.
Definition: sdmm.h:144
t_LinearTransform const & transforms(t_uint i) const
Linear transform associated with a given objective function.
Definition: sdmm.h:139
SOPT_MACRO(gamma, Real)
Gamma.
ProximalFunction< SCALAR > t_Proximal
Type of the proximal functions.
Definition: sdmm.h:49
SDMM< SCALAR > & append(PROXIMAL proximal, T args)
Appends a proximal and linear transform.
Definition: sdmm.h:90
SCALAR value_type
Scalar type.
Definition: sdmm.h:39
std::vector< t_LinearTransform > & transforms()
Linear transforms associated with each objective function.
Definition: sdmm.h:137
ConvergenceFunction< SCALAR > t_IsConverged
Type of the convergence function.
Definition: sdmm.h:51
DiagnosticAndResult operator()(t_Vector const &input) const
Definition: sdmm.h:122
t_uint size() const
Number of terms.
Definition: sdmm.h:159
DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart) const
Makes it simple to chain different calls to SDMM.
Definition: sdmm.h:128
SOPT_MACRO(itermax, t_uint)
Maximum number of iterations.
SOPT_MACRO(conjugate_gradient, ConjugateGradient)
Conjugate gradient.
SOPT_MACRO(is_converged, t_IsConverged)
A function verifying convergence.
Computes inner-most element type.
Definition: real_type.h:42
Expression referencing a lazy proximal function call.
#define SOPT_THROW(MSG)
Definition: exception.h:46
#define SOPT_LOW_LOG(...)
Low priority message.
Definition: logging.h:227
#define SOPT_HIGH_LOG(...)
High priority message.
Definition: logging.h:223
#define SOPT_ERROR(...)
Something is definitely wrong, algorithm exits
Definition: logging.h:211
#define SOPT_TRACE(...)
Definition: logging.h:220
#define SOPT_WARN(...)
Something might be going wrong
Definition: logging.h:213
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
double t_real
Root of the type hierarchy for real numbers.
Definition: types.h:17
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
std::function< void(Vector< SCALAR > &output, typename real_type< SCALAR >::type const weight, Vector< SCALAR > const &input)> ProximalFunction
Typical function signature for calls to proximal.
Definition: types.h:48
std::function< bool(Vector< SCALAR > const &)> ConvergenceFunction
Typical function signature for convergence.
Definition: types.h:52
Values indicating how the algorithm ran.
Vector< SCALAR > x
Vector which minimizes the sum of functions.
Definition: sdmm.h:36
Values indicating how the algorithm ran.
Definition: sdmm.h:26
t_uint niters
Number of iterations.
Definition: sdmm.h:28
ConjugateGradient::Diagnostic cg_diagnostic
Conjugate gradient result.
Definition: sdmm.h:32
bool good
Whether convergence was achieved.
Definition: sdmm.h:30