SOPT
Sparse OPTimisation
reweighted.h
Go to the documentation of this file.
1 #ifndef SOPT_REWEIGHTED_H
2 #define SOPT_REWEIGHTED_H
3 
5 #include "sopt/types.h"
6 
7 #include <algorithm> // for std::max<>
8 #include <limits> // for std::numeric_limits<>
9 #include <utility> // for std::move<>
10 
11 namespace sopt::algorithm {
12 
13 template <typename ALGORITHM>
14 class Reweighted;
15 
17 template <typename ALGORITHM>
18 Reweighted<ALGORITHM> reweighted(ALGORITHM const &algo,
19  typename Reweighted<ALGORITHM>::t_SetWeights const &set_weights,
20  typename Reweighted<ALGORITHM>::t_Reweightee const &reweightee);
21 
33 template <typename ALGORITHM>
34 class Reweighted {
35  public:
37  using Algorithm = ALGORITHM;
39  using Scalar = typename Algorithm::Scalar;
41  using Real = typename real_type<Scalar>::type;
45  using XVector = typename Algorithm::t_Vector;
50  using t_Reweightee = std::function<XVector (const Algorithm &, const XVector &)>;
52  using t_SetWeights = std::function<void (Algorithm &, const WeightVector &)>;
54  using t_DeltaUpdate = std::function<Real (Real)>;
55 
61  bool good;
65  typename Algorithm::DiagnosticAndResult algo;
67  ReweightedResult() : niters(0), good(false), weights(WeightVector::Ones(1)), algo() {}
68  };
69 
70  Reweighted(Algorithm const &algo, t_SetWeights const &setweights, t_Reweightee const &reweightee)
71  : algo_(algo),
72  setweights_(setweights),
73  reweightee_(reweightee),
74  itermax_(std::numeric_limits<t_uint>::max()),
75  min_delta_(0e0),
76  is_converged_(),
77  update_delta_([](Real delta) { return 1e-1 * delta; }) {}
78 
80  Algorithm &algorithm() { return algo_; }
82  Algorithm const &algorithm() const { return algo_; }
85  algo_ = algo;
86  return *this;
87  }
90  algo_ = std::move(algo);
91  return *this;
92  }
93 
95  t_SetWeights const &set_weights() const { return setweights_; }
97  Reweighted<Algorithm> &set_weights(t_SetWeights const &setweights) const {
98  setweights_ = setweights;
99  return *this;
100  }
102  void set_weights(Algorithm &algo, WeightVector const &weights) const {
103  return set_weights()(algo, weights);
104  }
105 
109  reweightee_ = rw;
110  return *this;
111  }
113  t_Reweightee const &reweightee() const { return reweightee_; }
115  XVector reweightee(XVector const &x) const { return reweightee()(algorithm(), x); }
116 
118  t_uint itermax() const { return itermax_; }
120  itermax_ = i;
121  return *this;
122  }
124  Real min_delta() const { return min_delta_; }
126  min_delta_ = min_delta;
127  return *this;
128  }
130  t_IsConverged const &is_converged() const { return is_converged_; }
131  Reweighted &is_converged(t_IsConverged const &convergence) {
132  is_converged_ = convergence;
133  return *this;
134  }
135  bool is_converged(XVector const &x) const { return is_converged() ? is_converged()(x) : false; }
136 
139  template <typename INPUT>
140  typename std::enable_if<not(std::is_same<INPUT, typename Algorithm::DiagnosticAndResult>::value or
141  std::is_same<INPUT, ReweightedResult>::value),
142  ReweightedResult>::type
143  operator()(INPUT const &input) const;
146  ReweightedResult operator()() const;
148  ReweightedResult operator()(typename Algorithm::DiagnosticAndResult const &warm) const;
150  ReweightedResult operator()(ReweightedResult const &warm) const;
151 
153  Real update_delta(Real delta) const { return update_delta()(delta); }
155  t_DeltaUpdate const &update_delta() const { return update_delta_; }
157  Reweighted<Algorithm> update_delta(t_DeltaUpdate const &ud) const { return update_delta_ = ud; }
158 
159  protected:
161  Algorithm algo_;
163  t_SetWeights setweights_;
166  t_Reweightee reweightee_;
168  t_uint itermax_;
170  Real min_delta_;
172  t_IsConverged is_converged_;
174  t_DeltaUpdate update_delta_;
175 };
176 
177 template <typename ALGORITHM>
178 template <typename INPUT>
179 typename std::enable_if<
180  not(std::is_same<INPUT, typename ALGORITHM::DiagnosticAndResult>::value or
181  std::is_same<INPUT, typename Reweighted<ALGORITHM>::ReweightedResult>::value),
182  typename Reweighted<ALGORITHM>::ReweightedResult>::type
183 Reweighted<ALGORITHM>::operator()(INPUT const &input) const {
184  Algorithm algo = algorithm();
185  set_weights(algo, WeightVector::Ones(1));
186  return operator()(algo(input));
187 }
188 
189 template <typename ALGORITHM>
191  Algorithm algo = algorithm();
192  set_weights(algo, WeightVector::Ones(1));
193  return operator()(algo());
194 }
195 
196 template <typename ALGORITHM>
198  typename Algorithm::DiagnosticAndResult const &warm) const {
199  ReweightedResult result;
200  result.algo = warm;
201  result.weights = WeightVector::Ones(1);
202  return operator()(result);
203 }
204 
205 template <typename ALGORITHM>
207  ReweightedResult const &warm) const {
208  SOPT_HIGH_LOG("Starting reweighted scheme");
209  // Copies inner algorithm, so that operator() can be constant
210  Algorithm algo(algorithm());
211  ReweightedResult result(warm);
212 
213  auto delta = std::max(standard_deviation(reweightee(warm.algo.x)), min_delta());
214  SOPT_LOW_LOG("- Initial delta: {}", delta);
215  for (result.niters = 0; result.niters < itermax(); ++result.niters) {
216  SOPT_LOW_LOG("Reweigting iteration {}/{} ", result.niters, itermax());
217  SOPT_LOW_LOG(" - delta: {}", delta);
218  result.weights = delta / (delta + reweightee(result.algo.x).array().abs());
219  set_weights(algo, result.weights);
220  result.algo = algo(result.algo);
221  if (is_converged(result.algo.x)) {
222  SOPT_MEDIUM_LOG("Reweighting scheme did converge in {} iterations", result.niters);
223  result.good = true;
224  break;
225  }
226  delta = std::max(min_delta(), update_delta(delta));
227  }
228  // result is always good if no convergence function is defined
229  if (not is_converged())
230  result.good = true;
231  else if (not result.good)
232  SOPT_ERROR("Reweighting scheme did *not* converge in {} iterations", itermax());
233  return result;
234 }
235 
237 template <typename ALGORITHM>
238 Reweighted<ALGORITHM> reweighted(ALGORITHM const &algo,
239  typename Reweighted<ALGORITHM>::t_SetWeights const &set_weights,
240  typename Reweighted<ALGORITHM>::t_Reweightee const &reweightee) {
241  return {algo, set_weights, reweightee};
242 }
243 
244 template <typename SCALAR>
245 class ImagingProximalADMM;
246 template <typename ALGORITHM>
247 class PositiveQuadrant;
248 template <typename T>
249 Eigen::CwiseUnaryOp<const details::ProjectPositiveQuadrant<typename T::Scalar>, const T>
250 positive_quadrant(Eigen::DenseBase<T> const &input);
251 
252 template <typename SCALAR>
254  ImagingProximalADMM<SCALAR> const &algo) {
255  auto const posq = positive_quadrant(algo);
256  using Algorithm = typename std::remove_const<decltype(posq)>::type;
257  using RW = Reweighted<Algorithm>;
258  auto const reweightee =
259  [](Algorithm const &posq, typename RW::XVector const &x) -> typename RW::XVector {
260  return posq.algorithm().Psi().adjoint() * x;
261  };
262  auto const set_weights = [](Algorithm &posq, typename RW::WeightVector const &weights) -> void {
263  posq.algorithm().l1_proximal_weights(weights);
264  };
265  return {posq, set_weights, reweightee};
266 }
267 
268 template <typename SCALAR>
269 class PrimalDual;
270 template <typename ALGORITHM>
271 class PositiveQuadrant;
272 template <typename T>
273 Eigen::CwiseUnaryOp<const details::ProjectPositiveQuadrant<typename T::Scalar>, const T>
274 positive_quadrant(Eigen::DenseBase<T> const &input);
275 
276 template <typename SCALAR>
278  auto const posq = positive_quadrant(algo);
279  using Algorithm = typename std::remove_const<decltype(posq)>::type;
280  using RW = Reweighted<Algorithm>;
281  auto const reweightee =
282  [](Algorithm const &posq, typename RW::XVector const &x) -> typename RW::XVector {
283  return posq.algorithm().Psi().adjoint() * x;
284  };
285  auto const set_weights = [](Algorithm &posq, typename RW::WeightVector const &weights) -> void {
286  posq.algorithm().l1_proximal_weights(weights);
287  };
288  return {posq, set_weights, reweightee};
289 }
290 
291 } // namespace sopt::algorithm
292 #endif
sopt::Vector< Scalar > t_Vector
sopt::t_real Scalar
Primal Dual Algorithm.
Definition: primal_dual.h:24
L0-approximation algorithm, through reweighting.
Definition: reweighted.h:34
Reweighted & min_delta(Real min_delta)
Definition: reweighted.h:125
void set_weights(Algorithm &algo, WeightVector const &weights) const
Sets the weights on the underlying algorithm.
Definition: reweighted.h:102
t_uint itermax() const
Maximum number of reweighted iterations.
Definition: reweighted.h:118
ALGORITHM Algorithm
Inner-loop algorithm.
Definition: reweighted.h:37
Reweighted & itermax(t_uint i)
Definition: reweighted.h:119
XVector reweightee(XVector const &x) const
Forwards to the reweightee function.
Definition: reweighted.h:115
t_Reweightee const & reweightee() const
Function that needs to be reweighted.
Definition: reweighted.h:113
Algorithm & algorithm()
Underlying "inner-loop" algorithm.
Definition: reweighted.h:80
bool is_converged(XVector const &x) const
Definition: reweighted.h:135
t_SetWeights const & set_weights() const
Function to reset the weights in the algorithm.
Definition: reweighted.h:95
t_IsConverged const & is_converged() const
Checks convergence of the reweighting scheme.
Definition: reweighted.h:130
std::function< Real(Real)> t_DeltaUpdate
Function to update delta at each turn.
Definition: reweighted.h:54
Reweighted(Algorithm const &algo, t_SetWeights const &setweights, t_Reweightee const &reweightee)
Definition: reweighted.h:70
std::function< void(Algorithm &, const WeightVector &)> t_SetWeights
Type of the function to set weights.
Definition: reweighted.h:52
Reweighted & is_converged(t_IsConverged const &convergence)
Definition: reweighted.h:131
Reweighted< Algorithm > & set_weights(t_SetWeights const &setweights) const
Function to reset the weights in the algorithm.
Definition: reweighted.h:97
ReweightedResult operator()() const
Performs reweighting.
Definition: reweighted.h:190
Vector< Real > WeightVector
Weight vector type.
Definition: reweighted.h:43
t_DeltaUpdate const & update_delta() const
Function used to update delta after each reweighting iteration.
Definition: reweighted.h:155
std::function< XVector(const Algorithm &, const XVector &)> t_Reweightee
Type of the function that is subject to reweighting.
Definition: reweighted.h:50
Reweighted< Algorithm > update_delta(t_DeltaUpdate const &ud) const
Sets the function used to update delta.
Definition: reweighted.h:157
Algorithm const & algorithm() const
Underlying "inner-loop" algorithm.
Definition: reweighted.h:82
Real min_delta() const
Lower limit for delta.
Definition: reweighted.h:124
typename real_type< Scalar >::type Real
Real type.
Definition: reweighted.h:41
Real update_delta(Real delta) const
Updates delta.
Definition: reweighted.h:153
Reweighted< Algorithm > & reweightee(t_Reweightee const &rw)
Definition: reweighted.h:108
typename Algorithm::t_Vector XVector
Type of the underlying vectors.
Definition: reweighted.h:45
Reweighted< Algorithm > & algorithm(Algorithm const &algo)
Sets the underlying "inner-loop" algorithm.
Definition: reweighted.h:84
Reweighted< Algorithm > & algorithm(Algorithm &&algo)
Sets the underlying "inner-loop" algorithm.
Definition: reweighted.h:89
typename Algorithm::Scalar Scalar
Scalar type.
Definition: reweighted.h:39
ConvergenceFunction< Scalar > t_IsConverged
Type of the convergence function.
Definition: reweighted.h:47
Computes inner-most element type.
Definition: real_type.h:42
#define SOPT_LOW_LOG(...)
Low priority message.
Definition: logging.h:227
#define SOPT_HIGH_LOG(...)
High priority message.
Definition: logging.h:223
#define SOPT_ERROR(...)
\macro Something is definitely wrong, algorithm exits
Definition: logging.h:211
#define SOPT_MEDIUM_LOG(...)
Medium priority message.
Definition: logging.h:225
PositiveQuadrant< ALGORITHM > positive_quadrant(ALGORITHM const &algo)
Extended algorithm where the solution is projected on the positive quadrant.
Reweighted< ALGORITHM > reweighted(ALGORITHM const &algo, typename Reweighted< ALGORITHM >::t_SetWeights const &set_weights, typename Reweighted< ALGORITHM >::t_Reweightee const &reweightee)
Factory function to create an l0-approximation by reweighting an l1 norm.
Definition: reweighted.h:238
real_type< typename T::Scalar >::type standard_deviation(Eigen::ArrayBase< T > const &x)
Computes the standard deviation of a vector.
Definition: maths.h:16
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
std::function< bool(Vector< SCALAR > const &)> ConvergenceFunction
Typical function signature for convergence.
Definition: types.h:52
Output from running the reweighting scheme.
Definition: reweighted.h:57
t_uint niters
Number of iterations (outer loop)
Definition: reweighted.h:59
bool good
Whether convergence was achieved.
Definition: reweighted.h:61
Algorithm::DiagnosticAndResult algo
Result from last inner loop.
Definition: reweighted.h:65
WeightVector weights
Weights at last iteration.
Definition: reweighted.h:63