SOPT
Sparse OPTimisation
l1_non_diff_function.h
Go to the documentation of this file.
1 #ifndef SOPT_L1_NON_DIFF_FUNCTION_H
2 #define SOPT_L1_NON_DIFF_FUNCTION_H
3 // TODO: Clean up unnecessary includes
4 #include "sopt/config.h"
5 #include <numeric>
6 #include <tuple>
7 #include <utility>
8 #include "sopt/exception.h"
10 #include "sopt/linear_transform.h"
11 #include "sopt/logging.h"
12 #include "sopt/proximal.h"
14 #include "sopt/types.h"
16 #include "sopt/l1_proximal.h"
17 
18 #ifdef SOPT_MPI
19  #include "sopt/mpi/communicator.h"
20 #endif
21 
22 namespace sopt::algorithm {
23 
24 // Implementation of non differentiable function g(x) = l1 norm.
25 // Proximal operator is implemented using the private l1_proximal object.
26 template <typename SCALAR>
27 class L1GProximal : public NonDifferentiableFunc<SCALAR> {
28 
29 public:
30  using FB = ForwardBackward<SCALAR>;
31  using Real = typename FB::Real;
32  using Scalar = typename FB::Scalar;
33  using t_Vector = typename FB::t_Vector;
34  using t_Proximal = typename FB::t_Proximal;
36 
37  // In the constructor we need to construct the private l1_proximal_
38  // object that contains the real implementation details. The tight_frame
39  // parameter is required for internal logic in l1_proximal
40  L1GProximal(bool tight_frame = false)
41  : tight_frame_ (tight_frame),
42  l1_proximal_() {}
44 
45 // Implements the interface in GProximal
46 
47  // Print log message with the correct norms
48  void log_message() const override {
49  SOPT_HIGH_LOG("Performing Forward Backward with L1 and L2 norms");
50  }
51 
52  // Return the norm associated with this implementation
53  Real function(t_Vector const &x) const override {
54  auto &weights = l1_proximal_weights();
55  auto input = static_cast<t_Vector>(Psi().adjoint() * x);
56  return sopt::l1_norm(input, weights);
57  }
58 
59  // Return g_proximal as a lambda function. Used in operator() in base class.
60  t_Proximal proximal_operator() const override {
61  return [this](t_Vector &out, Real gamma, t_Vector const &x) {
62  this -> l1_proximal(out, gamma, x);
63  };
64  }
65 
68  t_LinearTransform const &Psi() const override {
69  return l1_proximal().Psi();
70  }
71 
72 
73 // All the public properties below are specific to the l1 proximal
74 // and therefore not part of the interface
75 
78  proximal::L1<Scalar> &l1_proximal() { return l1_proximal_; }
79  proximal::L1<Scalar> const &l1_proximal() const { return l1_proximal_; }
81  l1_proximal_ = arg;
82  return *this;
83  }
84 
85 // This macro creates get/setters that point to l1_proximal
86 // In practice, we end up with a bunch of functions that make it simpler to set or get values
87 // associated with the two proximal operators.
88 // E.g.: `paddm.l1_proximal_itermax(100).l1_proximal_tolerance(1e-4)`.
89 // ~~~
90 #define SOPT_MACRO(VAR, TYPE) \
91  \
92  TYPE const &l1_proximal_##VAR() const { return l1_proximal().VAR(); } \
93  \
94  L1GProximal<SCALAR> &l1_proximal_##VAR(TYPE const ARG) { \
95  l1_proximal().VAR(ARG); \
96  return *this; \
97  }
98 
99  SOPT_MACRO(itermax, t_uint);
100  SOPT_MACRO(tolerance, Real);
101  SOPT_MACRO(positivity_constraint, bool);
102  SOPT_MACRO(real_constraint, bool);
105 #ifdef SOPT_MPI
106  SOPT_MACRO(adjoint_space_comm, mpi::Communicator);
107  SOPT_MACRO(direct_space_comm, mpi::Communicator);
108 #endif
109 #undef SOPT_MACRO
110 
112  template <typename... ARGS>
113  typename std::enable_if<sizeof...(ARGS) >= 1, L1GProximal<SCALAR> &>::type Psi(
114  ARGS &&... args) {
115  l1_proximal().Psi(std::forward<ARGS>(args)...);
116  return *this;
117  }
118 
119 protected:
120 
121  bool tight_frame_;
122  proximal::L1<Scalar> l1_proximal_;
123 
124  // Helper functions for calling l1_proximal
126  template <typename T0, typename T1>
127  typename proximal::L1<Scalar>::Diagnostic l1_proximal(Eigen::MatrixBase<T0> &out, Real gamma,
128  Eigen::MatrixBase<T1> const &x) const {
129  return l1_proximal_real_constraint()
130  ? call_l1_proximal(out, gamma, x.real().template cast<typename T1::Scalar>())
131  : call_l1_proximal(out, gamma, x);
132  }
133 
135  template <typename T0, typename T1>
136  typename proximal::L1<Scalar>::Diagnostic call_l1_proximal(Eigen::MatrixBase<T0> &out, Real gamma,
137  Eigen::MatrixBase<T1> const &x) const {
138  if (tight_frame_) {
139  l1_proximal().tight_frame(out, gamma, x);
140  return {0, 0, l1_proximal().objective(x, out, gamma), true};
141  }
142  return l1_proximal()(out, gamma, x);
143  }
144 
145 };
146 } // namespace sopt::algorithm
147 #endif
sopt::Vector< Scalar > t_Vector
sopt::t_real Scalar
L1GProximal< SCALAR > &::type Psi(ARGS &&... args)
L1GProximal(bool tight_frame=false)
typename FB::t_Proximal t_Proximal
L1GProximal< SCALAR > & l1_proximal(proximal::L1< Scalar > const &arg)
SOPT_MACRO(positivity_constraint, bool)
typename FB::t_LinearTransform t_LinearTransform
void log_message() const override
proximal::L1< Scalar > const & l1_proximal() const
typename FB::t_Vector t_Vector
proximal::L1< Scalar > & l1_proximal()
L1 proximal used during calculation.
SOPT_MACRO(tolerance, Real)
ForwardBackward< SCALAR > FB
SOPT_MACRO(weights, Vector< t_real >)
t_Proximal proximal_operator() const override
SOPT_MACRO(real_constraint, bool)
t_LinearTransform const & Psi() const override
Analysis operator Ψ
SOPT_MACRO(itermax, t_uint)
auto tight_frame(T &&... args) const -> decltype(this->L1TightFrame< Scalar >::operator()(std::forward< T >(args)...))
Special case if Ψ is a tight frame.
Definition: l1_proximal.h:313
LinearTransform< Vector< Scalar > > const & Psi() const
Linear transform applied to input prior to L1 norm.
Definition: l1_proximal.h:302
sopt::LinearTransform< t_Vector > t_LinearTransform
#define SOPT_HIGH_LOG(...)
High priority message.
Definition: logging.h:223
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
real_type< typename T0::Scalar >::type l1_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L1 norm.
Definition: maths.h:116
How did calling L1 go?
Definition: l1_proximal.h:190