SOPT
Sparse OPTimisation
l1_proximal.h
Go to the documentation of this file.
1 #ifndef SOPT_L1_PROXIMAL_H
2 #define SOPT_L1_PROXIMAL_H
3 
4 #include "sopt/config.h"
5 #include <array>
6 #include <type_traits>
7 #include <utility> // for std::forward<>
8 #include <Eigen/Core>
10 #include "sopt/maths.h"
12 #ifdef SOPT_MPI
13 #include "sopt/mpi/communicator.h"
14 #include "sopt/mpi/utilities.h"
15 #endif
16 
17 namespace sopt::proximal {
18 
25 template <typename SCALAR>
26 class L1TightFrame {
27  public:
29  using Scalar = SCALAR;
31  using Real = typename real_type<Scalar>::type;
32 
33 #ifdef SOPT_MPI
39  L1TightFrame(mpi::Communicator const &direct_comm = mpi::Communicator(),
40  mpi::Communicator const &adjoint_comm = mpi::Communicator())
42  nu_(1e0),
43  direct_space_comm_(direct_comm),
44  adjoint_space_comm_(adjoint_comm),
45  weights_(Vector<Real>::Ones(1)) {}
46 #else
48  : Psi_(linear_transform_identity<Scalar>()), nu_(1e0), weights_(Vector<Real>::Ones(1)) {}
49 #endif
50 
// Helper macro generating, for a member called NAME of type TYPE:
//   - a const getter `NAME()` returning a const reference to the stored value,
//   - a setter `NAME(value)` returning *this so that calls can be chained,
//   - the protected storage `NAME_`.
// The expansion ends on `public:` so that whatever follows stays public.
51 #define SOPT_MACRO(NAME, TYPE) \
52  TYPE const &NAME() const { return NAME##_; } \
53  L1TightFrame<Scalar> &NAME(TYPE const &(NAME)) { \
54  NAME##_ = NAME; \
55  return *this; \
56  } \
57  \
58  protected: \
59  TYPE NAME##_; \
60  \
61  public:
66 #ifdef SOPT_MPI
// Communicator members; they are only declared when SOPT_MPI is defined.
68  SOPT_MACRO(direct_space_comm, mpi::Communicator);
70  SOPT_MACRO(adjoint_space_comm, mpi::Communicator);
71 #endif
// The helper is local to this class: undefine it once the members are declared.
72 #undef SOPT_MACRO
74  Vector<Real> const &weights() const { return weights_; }
76  template <typename T>
77  L1TightFrame<Scalar> &weights(Eigen::MatrixBase<T> const &w) {
78  if ((w.array() < 0e0).any()) SOPT_THROW("Weights cannot be negative");
79  if (w.stableNorm() < 1e-12) SOPT_THROW("Weights cannot be null");
80  weights_ = w;
81  return *this;
82  }
83 
86  if (value <= 0e0) SOPT_THROW("Weight cannot be negative or null");
87  weights_ = Vector<Real>::Ones(1) * value;
88  return *this;
89  }
90 
92  template <typename... ARGS>
93  typename std::enable_if<sizeof...(ARGS) >= 1, L1TightFrame &>::type Psi(ARGS &&... args) {
94  Psi_ = linear_transform(std::forward<ARGS>(args)...);
95  return *this;
96  }
97 
99  template <typename T0, typename T1>
100  typename std::enable_if<is_complex<Scalar>::value == is_complex<typename T0::Scalar>::value and
102  operator()(Eigen::MatrixBase<T0> &out, Real gamma, Eigen::MatrixBase<T1> const &x) const;
103 
105  template <typename T0>
107  Real const &gamma, Eigen::MatrixBase<T0> const &x) const {
108  return {*this, gamma, x};
109  }
110 
112  template <typename T0, typename T1>
113  typename std::enable_if<is_complex<Scalar>::value == is_complex<typename T0::Scalar>::value and
115  Real>::type
116  objective(Eigen::MatrixBase<T0> const &x, Eigen::MatrixBase<T1> const &z,
117  Real const &gamma) const;
118 
119  protected:
121  Vector<Real> weights_;
122 };
123 
124 template <typename SCALAR>
125 template <typename T0, typename T1>
126 typename std::enable_if<is_complex<SCALAR>::value == is_complex<typename T0::Scalar>::value and
128 L1TightFrame<SCALAR>::operator()(Eigen::MatrixBase<T0> &out, Real gamma,
129  Eigen::MatrixBase<T1> const &x) const {
130  Vector<Scalar> const psit_x = Psi().adjoint() * x;
131  if (weights().size() == 1)
132  out = static_cast<Vector<Scalar>>(
133  Psi() * (soft_threshhold(psit_x, nu() * gamma * weights()(0)) - psit_x)) /
134  nu() +
135  x;
136  else
137  out = static_cast<Vector<Scalar>>(
138  Psi() * (soft_threshhold(psit_x, nu() * gamma * weights()) - psit_x)) /
139  nu() +
140  x;
141  SOPT_LOW_LOG("Prox L1: objective = {}", objective(x, out, gamma));
142 }
143 
144 template <typename SCALAR>
145 template <typename T0, typename T1>
146 typename std::enable_if<is_complex<SCALAR>::value == is_complex<typename T0::Scalar>::value and
148  typename real_type<SCALAR>::type>::type
149 L1TightFrame<SCALAR>::objective(Eigen::MatrixBase<T0> const &x, Eigen::MatrixBase<T1> const &z,
150  Real const &gamma) const {
151 #ifdef SOPT_MPI
152  auto const adj = gamma * sopt::mpi::l1_norm(static_cast<T1>(Psi().adjoint() * z), weights(),
153  adjoint_space_comm());
154  auto const dir = direct_space_comm().all_sum_all(0.5 * (x - z).squaredNorm());
155  return adj + dir;
156 #else
157  return 0.5 * (x - z).squaredNorm() +
158  gamma * sopt::l1_norm(static_cast<T1>(Psi().adjoint() * z), weights());
159 #endif
160 }
161 
168 template <typename SCALAR>
169 class L1 : protected L1TightFrame<SCALAR> {
170  public:
172  class FistaMixing;
174  class NoMixing;
176  class Breaker;
177 
179 #ifdef SOPT_MPI
182 #endif
183 
188 
190  struct Diagnostic {
198  bool good;
200  bool good = false)
201  : niters(niters),
204  good(good) {}
205  };
206 
208  struct DiagnosticAndResult : public Diagnostic {
211  };
212 
214  template <typename T0>
215  Diagnostic operator()(Eigen::MatrixBase<T0> &out, Real gamma, Vector<Scalar> const &x) const {
216  // Note that we *must* call eval on x, in case it is an expression involving out
217  if (gamma <= 0) {
218  apply_constraints(out, x);
219  return Diagnostic(0, 0, 0.5 * (out - x).squaredNorm(), true);
220  }
221 
222  if (fista_mixing())
223  return operator()(out, gamma, x, FistaMixing());
224  else
225  return operator()(out, gamma, x, NoMixing());
226  }
227 
229  template <typename T0>
230  DiagnosticAndResult operator()(Real const &gamma, Eigen::MatrixBase<T0> const &x) const {
231  DiagnosticAndResult result;
232  static_cast<Diagnostic &>(result) = operator()(result.proximal, gamma, x);
233  return result;
234  }
235 
236 #ifdef SOPT_MPI
237  L1(mpi::Communicator const &direct_comm = mpi::Communicator(),
238  mpi::Communicator const &adjoint_comm = mpi::Communicator())
239  : L1TightFrame<SCALAR>(direct_comm, adjoint_comm),
240  itermax_(0),
241  tolerance_(1e-8),
242  positivity_constraint_(false),
243  real_constraint_(false),
244  fista_mixing_(true) {}
245 #else
246  L1()
247  : L1TightFrame<SCALAR>(),
248  itermax_(0),
249  tolerance_(1e-8),
250  positivity_constraint_(false),
251  real_constraint_(false),
252  fista_mixing_(true) {}
253 #endif
254 
// Helper macro generating, for a member called NAME of type TYPE:
//   - a const getter `NAME()` returning a const reference to the stored value,
//   - a setter `NAME(value)` returning *this (as L1<Scalar>&) so that calls can be chained,
//   - the protected storage `NAME_`.
// The expansion ends on `public:` so that whatever follows stays public.
255 #define SOPT_MACRO(NAME, TYPE) \
256  TYPE const &NAME() const { return NAME##_; } \
257  L1<Scalar> &NAME(TYPE const &(NAME)) { \
258  NAME##_ = NAME; \
259  return *this; \
260  } \
261  \
262  protected: \
263  TYPE NAME##_; \
264  \
265  public:
// Maximum number of iterations before bailing out (the iteration loop treats 0 as "no limit").
268  SOPT_MACRO(itermax, t_uint);
// Tolerance criteria.
270  SOPT_MACRO(tolerance, Real);
// Whether to apply positivity constraints.
272  SOPT_MACRO(positivity_constraint, bool);
// Whether the output should be constrained to be real.
274  SOPT_MACRO(real_constraint, bool);
// Whether to do fista mixing or not.
276  SOPT_MACRO(fista_mixing, bool);
// The helper is local to this class: undefine it once the members are declared.
277 #undef SOPT_MACRO
278 
282  template <typename T>
283  L1<Scalar> &weights(Eigen::MatrixBase<T> const &w) {
285  return *this;
286  }
288  L1<Scalar> &weights(Real const &w) {
290  return this;
291  }
292 
294  Real nu() const { return L1TightFrame<Scalar>::nu(); }
296  L1<Scalar> &nu(Real const &nu) {
298  return *this;
299  }
300 
304  template <typename... ARGS>
305  typename std::enable_if<sizeof...(ARGS) >= 1, L1<Scalar> &>::type Psi(ARGS &&... args) {
306  L1TightFrame<Scalar>::Psi(std::forward<ARGS>(args)...);
307  return *this;
308  }
309 
312  template <typename... T>
313  auto tight_frame(T &&... args) const
314  -> decltype(this->L1TightFrame<Scalar>::operator()(std::forward<T>(args)...)) {
315  return this->L1TightFrame<Scalar>::operator()(std::forward<T>(args)...);
316  }
317 
318  protected:
320  template <typename T1>
321  Vector<SCALAR> apply_soft_threshhold(Real gamma, Eigen::MatrixBase<T1> const &x) const;
323  template <typename T0, typename T1>
324  void apply_constraints(Eigen::MatrixBase<T0> &out, Eigen::MatrixBase<T1> const &x) const;
325 
327  template <typename T0, typename MIXING>
328  Diagnostic operator()(Eigen::MatrixBase<T0> &out, Real gamma, Vector<Scalar> const &x,
329  MIXING mixing) const;
330 };
331 
333 template <typename SCALAR>
334 template <typename T0, typename MIXING>
335 typename L1<SCALAR>::Diagnostic L1<SCALAR>::operator()(Eigen::MatrixBase<T0> &out, Real gamma,
336  Vector<Scalar> const &x,
337  MIXING mixing) const {
338  SOPT_MEDIUM_LOG("Starting Proximal L1 operator:");
339  t_uint niters = 0;
340  out = x;
341 
342  Breaker breaker(objective(out, x, gamma), tolerance(), false); // not fista_mixing());
343  SOPT_LOW_LOG(" - [ProxL1] iter {}, prox_fval = {}", niters, breaker.current());
344  Vector<Scalar> const res = Psi().adjoint() * out;
345  Vector<Scalar> u_l1 = 1e0 / nu() * (res - apply_soft_threshhold(gamma, res));
346  apply_constraints(out, x - Psi() * u_l1);
347 
348  // Move on to other iterations
349  for (++niters; niters < itermax() or itermax() == 0; ++niters) {
350  auto const do_break = breaker(objective(x, out, gamma));
351  SOPT_LOW_LOG(" - [ProxL1] iter {}, prox_fval = {}, rel_fval = {}", niters, breaker.current(),
352  breaker.relative_variation());
353  if (do_break) break;
354 
355  Vector<Scalar> const res = u_l1 * nu() + Psi().adjoint() * out;
356  mixing(u_l1, 1e0 / nu() * (res - apply_soft_threshhold(gamma, res)), niters);
357  apply_constraints(out, x - Psi() * u_l1);
358  }
359 
360  if (breaker.two_cycle()) SOPT_WARN("Two-cycle detected when computing L1");
361 
362  if (breaker.converged()) {
363  SOPT_LOW_LOG("Proximal L1 operator converged at {} in {} iterations", breaker.current(),
364  niters);
365  } else
366  SOPT_ERROR("Proximal L1 operator did not converge after {} iterations", niters);
367  return {niters, breaker.relative_variation(), breaker.current(), breaker.converged()};
368 }
369 
370 template <typename SCALAR>
371 template <typename T1>
372 Vector<SCALAR> L1<SCALAR>::apply_soft_threshhold(Real gamma, Eigen::MatrixBase<T1> const &x) const {
373  if (weights().size() == 1)
374  return soft_threshhold(x, gamma * weights()(0));
375  else
376  return soft_threshhold(x, gamma * weights());
377 }
378 
379 template <typename SCALAR>
380 template <typename T0, typename T1>
381 void L1<SCALAR>::apply_constraints(Eigen::MatrixBase<T0> &out,
382  Eigen::MatrixBase<T1> const &x) const {
383  if (positivity_constraint())
384  out = sopt::positive_quadrant(x);
385  else if (real_constraint())
386  out = x.real().template cast<SCALAR>();
387  else
388  out = x;
389 }
390 
391 template <typename SCALAR>
392 class L1<SCALAR>::FistaMixing {
393  public:
394  using Real = typename real_type<SCALAR>::type;
395  FistaMixing() : t(1){};
396  template <typename T1>
397  void operator()(Vector<SCALAR> &previous, Eigen::MatrixBase<T1> const &unmixed, t_uint iter) {
398  // reset
399  if (iter == 0) {
400  previous = unmixed;
401  return;
402  }
403  if (iter <= 1) t = next(1);
404  auto const prior_t = t;
405  t = next(t);
406  auto const alpha = (prior_t - 1) / t;
407  previous = (1e0 + alpha) * unmixed.derived() - alpha * previous;
408  }
409  static Real next(Real t) { return 0.5 + 0.5 * std::sqrt(1e0 + 4e0 * t * t); }
410 
411  private:
412  Real t;
413 };
414 
415 template <typename SCALAR>
416 class L1<SCALAR>::NoMixing {
417  public:
418  template <typename T1>
419  void operator()(Vector<SCALAR> &previous, Eigen::MatrixBase<T1> const &unmixed, t_uint) {
420  previous = unmixed;
421  }
422 };
423 
424 template <typename SCALAR>
425 class L1<SCALAR>::Breaker {
426  public:
427  using Real = typename real_type<SCALAR>::type;
433  Breaker(Real objective, Real tolerance = 1e-8, bool do_two_cycle = true)
434  : tolerance_(tolerance),
435  iter(0),
436  objectives({{objective, 0, 0, 0}}),
437  do_two_cycle(do_two_cycle) {}
440  ++iter;
441  objectives = {{objective, objectives[0], objectives[1], objectives[2]}};
442  return converged() or two_cycle();
443  }
445  Real current() const { return objectives[0]; }
447  Real previous() const { return objectives[1]; }
449  Real relative_variation() const { return std::abs((current() - previous()) / current()); }
452  bool two_cycle() const {
453  return do_two_cycle and iter > 3 and std::abs(objectives[0] - objectives[2]) < tolerance() and
454  std::abs(objectives[1] - objectives[3]) < tolerance();
455  }
456 
458  bool converged() const {
459  // If current ~ 0, then defaults to absolute convergence
460  // This is mainly to avoid a division by zero
461  if (std::abs(current() * 1000) < tolerance()) return std::abs(previous() * 1000) < tolerance();
462  return relative_variation() < tolerance();
463  }
465  Real tolerance() const { return tolerance_; }
468  tolerance_ = tol;
469  return *this;
470  }
471 
472  protected:
473  Real tolerance_;
474  t_uint iter;
475  std::array<Real, 4> objectives;
476  bool do_two_cycle;
477 };
478 } // namespace sopt::proximal
479 
480 #endif
constexpr Scalar tol
Joins together direct and indirect operators.
Computes inner-most element type.
Definition: real_type.h:42
L1 proximal, including linear transform.
Definition: l1_proximal.h:26
L1TightFrame< Scalar > & weights(Eigen::MatrixBase< T > const &w)
Weights of the l1 norm.
Definition: l1_proximal.h:77
Vector< Real > const & weights() const
Weights of the l1 norm.
Definition: l1_proximal.h:74
SOPT_MACRO(nu, Real)
Bound on the squared norm of the operator Ψ
std::enable_if< is_complex< Scalar >::value==is_complex< typename T0::Scalar >::value and is_complex< Scalar >::value==is_complex< typename T1::Scalar >::value >::type operator()(Eigen::MatrixBase< T0 > &out, Real gamma, Eigen::MatrixBase< T1 > const &x) const
Computes proximal for given γ
Definition: l1_proximal.h:128
typename real_type< Scalar >::type Real
Underlying real scalar type.
Definition: l1_proximal.h:31
ProximalExpression< L1TightFrame< Scalar > const &, T0 > operator()(Real const &gamma, Eigen::MatrixBase< T0 > const &x) const
Lazy version.
Definition: l1_proximal.h:106
L1TightFrame< Scalar > & weights(Real const &value)
Set weights to a single value.
Definition: l1_proximal.h:85
std::enable_if< is_complex< Scalar >::value==is_complex< typename T0::Scalar >::value and is_complex< Scalar >::value==is_complex< typename T1::Scalar >::value, Real >::type objective(Eigen::MatrixBase< T0 > const &x, Eigen::MatrixBase< T1 > const &z, Real const &gamma) const
Definition: l1_proximal.h:149
SOPT_MACRO(Psi, LinearTransform< Vector< Scalar >>)
Linear transform applied to input prior to L1 norm.
L1TightFrame &::type Psi(ARGS &&... args)
Definition: l1_proximal.h:93
SCALAR Scalar
Underlying scalar type.
Definition: l1_proximal.h:29
Real previous() const
Previous objective.
Definition: l1_proximal.h:447
Real tolerance() const
Tolerance criteria.
Definition: l1_proximal.h:465
bool converged() const
True if relative variation smaller than tolerance.
Definition: l1_proximal.h:458
bool two_cycle() const
Whether we have a cycle of period two.
Definition: l1_proximal.h:452
bool operator()(Real objective)
True if we should break out of loop.
Definition: l1_proximal.h:439
L1< SCALAR >::Breaker & tolerance(Real tol) const
Tolerance criteria.
Definition: l1_proximal.h:467
typename real_type< SCALAR >::type Real
Definition: l1_proximal.h:427
Breaker(Real objective, Real tolerance=1e-8, bool do_two_cycle=true)
Definition: l1_proximal.h:433
Real current() const
Current objective.
Definition: l1_proximal.h:445
Real relative_variation() const
Variation in the objective function.
Definition: l1_proximal.h:449
void operator()(Vector< SCALAR > &previous, Eigen::MatrixBase< T1 > const &unmixed, t_uint iter)
Definition: l1_proximal.h:397
static Real next(Real t)
Definition: l1_proximal.h:409
typename real_type< SCALAR >::type Real
Definition: l1_proximal.h:394
void operator()(Vector< SCALAR > &previous, Eigen::MatrixBase< T1 > const &unmixed, t_uint)
Definition: l1_proximal.h:419
L1 proximal, including linear transform.
Definition: l1_proximal.h:169
auto tight_frame(T &&... args) const -> decltype(this->L1TightFrame< Scalar >::operator()(std::forward< T >(args)...))
Special case if Ψ is a tight frame.
Definition: l1_proximal.h:313
DiagnosticAndResult operator()(Real const &gamma, Eigen::MatrixBase< T0 > const &x) const
Lazy version.
Definition: l1_proximal.h:230
SOPT_MACRO(tolerance, Real)
Tolerance criteria.
SOPT_MACRO(fista_mixing, bool)
Whether to do fista mixing or not.
L1< Scalar > & nu(Real const &nu)
Sets the bound on the squared norm of the operator Ψ
Definition: l1_proximal.h:296
SOPT_MACRO(real_constraint, bool)
Whether the output should be constrained to be real.
L1< Scalar > &::type Psi(ARGS &&... args)
Definition: l1_proximal.h:305
Diagnostic operator()(Eigen::MatrixBase< T0 > &out, Real gamma, Vector< Scalar > const &x) const
Computes proximal for given γ
Definition: l1_proximal.h:215
L1< Scalar > & weights(Eigen::MatrixBase< T > const &w)
Set weights to an array of values.
Definition: l1_proximal.h:283
typename L1TightFrame< SCALAR >::Real Real
Underlying real scalar type.
Definition: l1_proximal.h:187
Real nu() const
Bounds on the squared norm of the operator Ψ
Definition: l1_proximal.h:294
Vector< Real > const & weights() const
Weights of the l1 norm.
Definition: l1_proximal.h:280
SOPT_MACRO(itermax, t_uint)
Maximum number of iterations before bailing out.
SOPT_MACRO(positivity_constraint, bool)
Whether to apply positivity constraints.
LinearTransform< Vector< Scalar > > const & Psi() const
Linear transform applied to input prior to L1 norm.
Definition: l1_proximal.h:302
L1< Scalar > & weights(Real const &w)
Set weights to a single value.
Definition: l1_proximal.h:288
Expression referencing a lazy proximal function call.
#define SOPT_THROW(MSG)
Definition: exception.h:46
#define SOPT_LOW_LOG(...)
Low priority message.
Definition: logging.h:227
#define SOPT_ERROR(...)
\macro Something is definitely wrong, algorithm exits
Definition: logging.h:211
#define SOPT_WARN(...)
\macro Something might be going wrong
Definition: logging.h:213
#define SOPT_MEDIUM_LOG(...)
Medium priority message.
Definition: logging.h:225
Holds some standard proximals.
Definition: l1_proximal.h:17
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
Eigen::CwiseUnaryOp< const details::ProjectPositiveQuadrant< typename T::Scalar >, const T > positive_quadrant(Eigen::DenseBase< T > const &input)
Expression to create projection onto positive quadrant.
Definition: maths.h:60
std::enable_if< std::is_arithmetic< SCALAR >::value or is_complex< SCALAR >::value, SCALAR >::type soft_threshhold(SCALAR const &x, typename real_type< SCALAR >::type const &threshhold)
abs(x) < threshhold ? 0: x - sgn(x) * threshhold
Definition: maths.h:29
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
LinearTransform< Vector< SCALAR > > linear_transform_identity()
Helper function to create a linear transform that's just the identity.
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
real_type< typename T0::Scalar >::type l1_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L1 norm.
Definition: maths.h:116
True if underlying type is complex.
Definition: real_type.h:59
Vector< SCALAR > proximal
The proximal value.
Definition: l1_proximal.h:210
How did calling L1 go?
Definition: l1_proximal.h:190
t_uint niters
Number of iterations.
Definition: l1_proximal.h:192
Real objective
Value of the objective function.
Definition: l1_proximal.h:196
Diagnostic(t_uint niters=0, Real relative_variation=0, Real objective=0, bool good=false)
Definition: l1_proximal.h:199
Real relative_variation
Relative variation of the objective function.
Definition: l1_proximal.h:194
bool good
Whether convergence was achieved.
Definition: l1_proximal.h:198
sopt::Vector< Scalar > Vector
Definition: inpainting.cc:28