1 #ifndef SOPT_L1_PROXIMAL_H
2 #define SOPT_L1_PROXIMAL_H
4 #include "sopt/config.h"
25 template <
typename SCALAR>
39 L1TightFrame(mpi::Communicator
const &direct_comm = mpi::Communicator(),
40 mpi::Communicator
const &adjoint_comm = mpi::Communicator())
43 direct_space_comm_(direct_comm),
44 adjoint_space_comm_(adjoint_comm),
51 #define SOPT_MACRO(NAME, TYPE) \
52 TYPE const &NAME() const { return NAME##_; } \
53 L1TightFrame<Scalar> &NAME(TYPE const &(NAME)) { \
68 SOPT_MACRO(direct_space_comm, mpi::Communicator);
70 SOPT_MACRO(adjoint_space_comm, mpi::Communicator);
78 if ((w.array() < 0e0).any())
SOPT_THROW(
"Weights cannot be negative");
79 if (w.stableNorm() < 1e-12)
SOPT_THROW(
"Weights cannot be null");
86 if (value <= 0e0)
SOPT_THROW(
"Weight cannot be negative or null");
92 template <
typename... ARGS>
93 typename std::enable_if<
sizeof...(ARGS) >= 1,
L1TightFrame &>::type
Psi(ARGS &&... args) {
99 template <
typename T0,
typename T1>
102 operator()(Eigen::MatrixBase<T0> &out,
Real gamma, Eigen::MatrixBase<T1>
const &x)
const;
105 template <
typename T0>
107 Real const &gamma, Eigen::MatrixBase<T0>
const &x)
const {
108 return {*
this, gamma, x};
112 template <
typename T0,
typename T1>
116 objective(Eigen::MatrixBase<T0>
const &x, Eigen::MatrixBase<T1>
const &z,
117 Real const &gamma)
const;
124 template <
typename SCALAR>
125 template <
typename T0,
typename T1>
129 Eigen::MatrixBase<T1>
const &x)
const {
131 if (weights().size() == 1)
133 Psi() * (
soft_threshhold(psit_x, nu() * gamma * weights()(0)) - psit_x)) /
138 Psi() * (
soft_threshhold(psit_x, nu() * gamma * weights()) - psit_x)) /
141 SOPT_LOW_LOG(
"Prox L1: objective = {}", objective(x, out, gamma));
144 template <
typename SCALAR>
145 template <
typename T0,
typename T1>
150 Real const &gamma)
const {
152 auto const adj = gamma *
sopt::mpi::l1_norm(
static_cast<T1
>(Psi().adjoint() * z), weights(),
153 adjoint_space_comm());
154 auto const dir = direct_space_comm().all_sum_all(0.5 * (x - z).squaredNorm());
157 return 0.5 * (x - z).squaredNorm() +
158 gamma *
sopt::l1_norm(
static_cast<T1
>(Psi().adjoint() * z), weights());
168 template <
typename SCALAR>
214 template <
typename T0>
218 apply_constraints(out, x);
219 return Diagnostic(0, 0, 0.5 * (out - x).squaredNorm(),
true);
229 template <
typename T0>
237 L1(mpi::Communicator
const &direct_comm = mpi::Communicator(),
238 mpi::Communicator
const &adjoint_comm = mpi::Communicator())
242 positivity_constraint_(false),
243 real_constraint_(false),
244 fista_mixing_(true) {}
250 positivity_constraint_(false),
251 real_constraint_(false),
252 fista_mixing_(true) {}
255 #define SOPT_MACRO(NAME, TYPE) \
256 TYPE const &NAME() const { return NAME##_; } \
257 L1<Scalar> &NAME(TYPE const &(NAME)) { \
282 template <
typename T>
304 template <
typename... ARGS>
305 typename std::enable_if<
sizeof...(ARGS) >= 1,
L1<Scalar> &>::type
Psi(ARGS &&... args) {
312 template <
typename... T>
320 template <
typename T1>
321 Vector<SCALAR> apply_soft_threshhold(
Real gamma, Eigen::MatrixBase<T1>
const &x)
const;
323 template <
typename T0,
typename T1>
324 void apply_constraints(Eigen::MatrixBase<T0> &out, Eigen::MatrixBase<T1>
const &x)
const;
327 template <
typename T0,
typename MIXING>
329 MIXING mixing)
const;
333 template <
typename SCALAR>
334 template <
typename T0,
typename MIXING>
337 MIXING mixing)
const {
342 Breaker breaker(objective(out, x, gamma), tolerance(),
false);
343 SOPT_LOW_LOG(
" - [ProxL1] iter {}, prox_fval = {}", niters, breaker.current());
345 Vector<Scalar> u_l1 = 1e0 / nu() * (res - apply_soft_threshhold(gamma, res));
346 apply_constraints(out, x - Psi() * u_l1);
349 for (++niters; niters < itermax() or itermax() == 0; ++niters) {
350 auto const do_break = breaker(objective(x, out, gamma));
351 SOPT_LOW_LOG(
" - [ProxL1] iter {}, prox_fval = {}, rel_fval = {}", niters, breaker.current(),
352 breaker.relative_variation());
356 mixing(u_l1, 1e0 / nu() * (res - apply_soft_threshhold(gamma, res)), niters);
357 apply_constraints(out, x - Psi() * u_l1);
360 if (breaker.two_cycle())
SOPT_WARN(
"Two-cycle detected when computing L1");
362 if (breaker.converged()) {
363 SOPT_LOW_LOG(
"Proximal L1 operator converged at {} in {} iterations", breaker.current(),
366 SOPT_ERROR(
"Proximal L1 operator did not converge after {} iterations", niters);
367 return {niters, breaker.relative_variation(), breaker.current(), breaker.converged()};
370 template <
typename SCALAR>
371 template <
typename T1>
372 Vector<SCALAR> L1<SCALAR>::apply_soft_threshhold(Real gamma, Eigen::MatrixBase<T1>
const &x)
const {
373 if (weights().size() == 1)
379 template <
typename SCALAR>
380 template <
typename T0,
typename T1>
381 void L1<SCALAR>::apply_constraints(Eigen::MatrixBase<T0> &out,
382 Eigen::MatrixBase<T1>
const &x)
const {
383 if (positivity_constraint())
385 else if (real_constraint())
386 out = x.real().template cast<SCALAR>();
391 template <
typename SCALAR>
396 template <
typename T1>
403 if (iter <= 1) t = next(1);
404 auto const prior_t = t;
406 auto const alpha = (prior_t - 1) / t;
407 previous = (1e0 + alpha) * unmixed.derived() - alpha * previous;
409 static Real next(
Real t) {
return 0.5 + 0.5 * std::sqrt(1e0 + 4e0 * t * t); }
415 template <
typename SCALAR>
418 template <
typename T1>
424 template <
typename SCALAR>
434 : tolerance_(tolerance),
437 do_two_cycle(do_two_cycle) {}
441 objectives = {{
objective, objectives[0], objectives[1], objectives[2]}};
442 return converged() or two_cycle();
453 return do_two_cycle and iter > 3 and std::abs(objectives[0] - objectives[2]) < tolerance() and
454 std::abs(objectives[1] - objectives[3]) < tolerance();
461 if (std::abs(current() * 1000) < tolerance())
return std::abs(previous() * 1000) < tolerance();
462 return relative_variation() < tolerance();
475 std::array<Real, 4> objectives;
Computes inner-most element type.
L1 proximal, including linear transform.
L1TightFrame< Scalar > & weights(Eigen::MatrixBase< T > const &w)
Weights of the l1 norm.
Vector< Real > const & weights() const
Weights of the l1 norm.
SOPT_MACRO(nu, Real)
Bound on the squared norm of the operator Ψ
std::enable_if< is_complex< Scalar >::value==is_complex< typename T0::Scalar >::value and is_complex< Scalar >::value==is_complex< typename T1::Scalar >::value >::type operator()(Eigen::MatrixBase< T0 > &out, Real gamma, Eigen::MatrixBase< T1 > const &x) const
Computes proximal for given γ
typename real_type< Scalar >::type Real
Underlying real scalar type.
ProximalExpression< L1TightFrame< Scalar > const &, T0 > operator()(Real const &gamma, Eigen::MatrixBase< T0 > const &x) const
Lazy version.
L1TightFrame< Scalar > & weights(Real const &value)
Set weights to a single value.
std::enable_if< is_complex< Scalar >::value==is_complex< typename T0::Scalar >::value and is_complex< Scalar >::value==is_complex< typename T1::Scalar >::value, Real >::type objective(Eigen::MatrixBase< T0 > const &x, Eigen::MatrixBase< T1 > const &z, Real const &gamma) const
SOPT_MACRO(Psi, LinearTransform< Vector< Scalar >>)
Linear transform applied to input prior to L1 norm.
L1TightFrame &::type Psi(ARGS &&... args)
SCALAR Scalar
Underlying scalar type.
Real previous() const
Previous objective.
Real tolerance() const
Tolerance criteria.
bool converged() const
True if relative variation smaller than tolerance.
bool two_cycle() const
Whether we have a cycle of period two.
bool operator()(Real objective)
True if we should break out of loop.
L1< SCALAR >::Breaker & tolerance(Real tol) const
Tolerance criteria.
typename real_type< SCALAR >::type Real
Breaker(Real objective, Real tolerance=1e-8, bool do_two_cycle=true)
Real current() const
Current objective.
Real relative_variation() const
Variation in the objective function.
void operator()(Vector< SCALAR > &previous, Eigen::MatrixBase< T1 > const &unmixed, t_uint iter)
typename real_type< SCALAR >::type Real
void operator()(Vector< SCALAR > &previous, Eigen::MatrixBase< T1 > const &unmixed, t_uint)
L1 proximal, including linear transform.
auto tight_frame(T &&... args) const -> decltype(this->L1TightFrame< Scalar >::operator()(std::forward< T >(args)...))
Special case if Ψ is a tight frame.
DiagnosticAndResult operator()(Real const &gamma, Eigen::MatrixBase< T0 > const &x) const
Lazy version.
SOPT_MACRO(tolerance, Real)
Tolerance criteria.
SOPT_MACRO(fista_mixing, bool)
Whether to do fista mixing or not.
L1< Scalar > & nu(Real const &nu)
Sets the bound on the squared norm of the operator Ψ
SOPT_MACRO(real_constraint, bool)
Whether the output should be constrained to be real.
L1< Scalar > &::type Psi(ARGS &&... args)
Diagnostic operator()(Eigen::MatrixBase< T0 > &out, Real gamma, Vector< Scalar > const &x) const
Computes proximal for given γ
L1< Scalar > & weights(Eigen::MatrixBase< T > const &w)
Set weights to an array of values.
typename L1TightFrame< SCALAR >::Real Real
Underlying real scalar type.
Real nu() const
Bounds on the squared norm of the operator Ψ
Vector< Real > const & weights() const
Weights of the l1 norm.
SOPT_MACRO(itermax, t_uint)
Maximum number of iterations before bailing out.
SOPT_MACRO(positivity_constraint, bool)
Whether to apply positivity constraints.
LinearTransform< Vector< Scalar > > const & Psi() const
Linear transform applied to input prior to L1 norm.
L1< Scalar > & weights(Real const &w)
Set weights to a single value.
Expression referencing a lazy proximal function call.
#define SOPT_LOW_LOG(...)
Low priority message.
#define SOPT_ERROR(...)
\macro Something is definitely wrong, algorithm exits
#define SOPT_WARN(...)
\macro Something might be going wrong
#define SOPT_MEDIUM_LOG(...)
Medium priority message.
Holds some standard proximals.
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
Eigen::CwiseUnaryOp< const details::ProjectPositiveQuadrant< typename T::Scalar >, const T > positive_quadrant(Eigen::DenseBase< T > const &input)
Expression to create projection onto positive quadrant.
std::enable_if< std::is_arithmetic< SCALAR >::value or is_complex< SCALAR >::value, SCALAR >::type soft_threshhold(SCALAR const &x, typename real_type< SCALAR >::type const &threshhold)
abs(x) < threshhold ? 0: x - sgn(x) * threshhold
size_t t_uint
Root of the type hierarchy for unsigned integers.
LinearTransform< Vector< SCALAR > > linear_transform_identity()
Helper function to create a linear transform that's just the identity.
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
real_type< typename T0::Scalar >::type l1_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L1 norm.
True if underlying type is complex.
Vector< SCALAR > proximal
The proximal value.
t_uint niters
Number of iterations.
Real objective
Value of the objective function.
Diagnostic(t_uint niters=0, Real relative_variation=0, Real objective=0, bool good=false)
Real relative_variation
Relative variation of the objective function.
bool good
Whether convergence was achieved.
sopt::Vector< Scalar > Vector