1 #ifndef SOPT_REWEIGHTED_H
2 #define SOPT_REWEIGHTED_H
13 template <
typename ALGORITHM>
17 template <
typename ALGORITHM>
18 Reweighted<ALGORITHM>
reweighted(ALGORITHM
const &algo,
33 template <
typename ALGORITHM>
65 typename Algorithm::DiagnosticAndResult
algo;
72 setweights_(setweights),
74 itermax_(std::numeric_limits<
t_uint>::max()),
77 update_delta_([](
Real delta) {
return 1e-1 * delta; }) {}
90 algo_ = std::move(algo);
98 setweights_ = setweights;
132 is_converged_ = convergence;
139 template <
typename INPUT>
140 typename std::enable_if<not(std::is_same<INPUT, typename Algorithm::DiagnosticAndResult>::value or
141 std::is_same<INPUT, ReweightedResult>::value),
142 ReweightedResult>::type
148 ReweightedResult
operator()(
typename Algorithm::DiagnosticAndResult
const &warm)
const;
150 ReweightedResult
operator()(ReweightedResult
const &warm)
const;
177 template <
typename ALGORITHM>
178 template <
typename INPUT>
179 typename std::enable_if<
180 not(std::is_same<INPUT, typename ALGORITHM::DiagnosticAndResult>::value or
181 std::is_same<INPUT,
typename Reweighted<ALGORITHM>::ReweightedResult>::value),
182 typename Reweighted<ALGORITHM>::ReweightedResult>::type
185 set_weights(algo, WeightVector::Ones(1));
186 return operator()(algo(input));
189 template <
typename ALGORITHM>
192 set_weights(algo, WeightVector::Ones(1));
193 return operator()(algo());
196 template <
typename ALGORITHM>
198 typename Algorithm::DiagnosticAndResult
const &warm)
const {
201 result.
weights = WeightVector::Ones(1);
202 return operator()(result);
205 template <
typename ALGORITHM>
218 result.
weights = delta / (delta + reweightee(result.
algo.x).array().abs());
219 set_weights(algo, result.
weights);
221 if (is_converged(result.
algo.x)) {
226 delta = std::max(min_delta(), update_delta(delta));
229 if (not is_converged())
231 else if (not result.
good)
232 SOPT_ERROR(
"Reweighting scheme did *not* converge in {} iterations", itermax());
237 template <
typename ALGORITHM>
241 return {algo, set_weights, reweightee};
244 template <
typename SCALAR>
245 class ImagingProximalADMM;
246 template <
typename ALGORITHM>
247 class PositiveQuadrant;
248 template <
typename T>
249 Eigen::CwiseUnaryOp<const details::ProjectPositiveQuadrant<typename T::Scalar>,
const T>
252 template <
typename SCALAR>
256 using Algorithm =
typename std::remove_const<decltype(posq)>::type;
258 auto const reweightee =
259 [](Algorithm
const &posq,
typename RW::XVector
const &x) ->
typename RW::XVector {
260 return posq.
algorithm().Psi().adjoint() * x;
262 auto const set_weights = [](Algorithm &posq,
typename RW::WeightVector
const &weights) ->
void {
263 posq.algorithm().l1_proximal_weights(weights);
265 return {posq, set_weights, reweightee};
268 template <
typename SCALAR>
270 template <
typename ALGORITHM>
271 class PositiveQuadrant;
272 template <
typename T>
273 Eigen::CwiseUnaryOp<const details::ProjectPositiveQuadrant<typename T::Scalar>,
const T>
276 template <
typename SCALAR>
279 using Algorithm =
typename std::remove_const<decltype(posq)>::type;
281 auto const reweightee =
282 [](Algorithm
const &posq,
typename RW::XVector
const &x) ->
typename RW::XVector {
283 return posq.
algorithm().Psi().adjoint() * x;
285 auto const set_weights = [](Algorithm &posq,
typename RW::WeightVector
const &weights) ->
void {
286 posq.algorithm().l1_proximal_weights(weights);
288 return {posq, set_weights, reweightee};
sopt::Vector< Scalar > t_Vector
L0-approximation algorithm, through reweighting.
Reweighted & min_delta(Real min_delta)
void set_weights(Algorithm &algo, WeightVector const &weights) const
Sets the weights on the underlying algorithm.
t_uint itermax() const
Maximum number of reweighted iterations.
ALGORITHM Algorithm
Inner-loop algorithm.
Reweighted & itermax(t_uint i)
XVector reweightee(XVector const &x) const
Forwards to the reweightee function.
t_Reweightee const & reweightee() const
Function that needs to be reweighted.
Algorithm & algorithm()
Underlying "inner-loop" algorithm.
bool is_converged(XVector const &x) const
t_SetWeights const & set_weights() const
Function to reset the weights in the algorithm.
t_IsConverged const & is_converged() const
Checks convergence of the reweighting scheme.
std::function< Real(Real)> t_DeltaUpdate
Function to update delta at each turn.
Reweighted(Algorithm const &algo, t_SetWeights const &setweights, t_Reweightee const &reweightee)
std::function< void(Algorithm &, const WeightVector &)> t_SetWeights
Type of the function to set weights.
Reweighted & is_converged(t_IsConverged const &convergence)
Reweighted< Algorithm > & set_weights(t_SetWeights const &setweights) const
Function to reset the weights in the algorithm.
ReweightedResult operator()() const
Performs reweighting.
Vector< Real > WeightVector
Weight vector type.
t_DeltaUpdate const & update_delta() const
Updates delta.
std::function< XVector(const Algorithm &, const XVector &)> t_Reweightee
Type of the function that is subject to reweighting.
Reweighted< Algorithm > update_delta(t_DeltaUpdate const &ud) const
Updates delta.
Algorithm const & algorithm() const
Underlying "inner-loop" algorithm.
Real min_delta() const
Lower limit for delta.
typename real_type< Scalar >::type Real
Real type.
Real update_delta(Real delta) const
Updates delta.
Reweighted< Algorithm > & reweightee(t_Reweightee const &rw)
typename Algorithm::t_Vector XVector
Type of the underlying vectors.
Reweighted< Algorithm > & algorithm(Algorithm const &algo)
Sets the underlying "inner-loop" algorithm.
Reweighted< Algorithm > & algorithm(Algorithm &&algo)
Sets the underlying "inner-loop" algorithm.
typename Algorithm::Scalar Scalar
Scalar type.
ConvergenceFunction< Scalar > t_IsConverged
Type of the convergence function.
Computes inner-most element type.
#define SOPT_LOW_LOG(...)
Low priority message.
#define SOPT_HIGH_LOG(...)
High priority message.
#define SOPT_ERROR(...)
Something is definitely wrong; the algorithm exits.
#define SOPT_MEDIUM_LOG(...)
Medium priority message.
PositiveQuadrant< ALGORITHM > positive_quadrant(ALGORITHM const &algo)
Extended algorithm where the solution is projected on the positive quadrant.
Reweighted< ALGORITHM > reweighted(ALGORITHM const &algo, typename Reweighted< ALGORITHM >::t_SetWeights const &set_weights, typename Reweighted< ALGORITHM >::t_Reweightee const &reweightee)
Factory function to create an l0-approximation by reweighting an l1 norm.
real_type< typename T::Scalar >::type standard_deviation(Eigen::ArrayBase< T > const &x)
Computes the standard deviation of a vector.
size_t t_uint
Root of the type hierarchy for unsigned integers.
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
std::function< bool(Vector< SCALAR > const &)> ConvergenceFunction
Typical function signature for convergence.
Output from running the reweighting scheme.
t_uint niters
Number of iterations (outer loop)
ReweightedResult()
Default construction.
bool good
Whether convergence was achieved.
Algorithm::DiagnosticAndResult algo
Result from last inner loop.
WeightVector weights
Weights at last iteration.