SOPT
Sparse OPTimisation
Classes | Public Types | Public Member Functions | List of all members
sopt::algorithm::SDMM< SCALAR > Class Template Reference

Simultaneous-direction method of multipliers. More...

#include <sdmm.h>

Classes

struct  Diagnostic
 Values indicating how the algorithm ran. More...
 
struct  DiagnosticAndResult
 

Public Types

using value_type = SCALAR
 Scalar type. More...
 
using Scalar = value_type
 Scalar type. More...
 
using Real = typename real_type< Scalar >::type
 Real type. More...
 
using t_Vector = Vector< SCALAR >
 Type of the underlying vectors. More...
 
using t_LinearTransform = LinearTransform< t_Vector >
 Type of the A and A^H operations. More...
 
using t_Proximal = ProximalFunction< SCALAR >
 Type of the proximal functions. More...
 
using t_IsConverged = ConvergenceFunction< SCALAR >
 Type of the convergence function. More...
 

Public Member Functions

 SDMM ()
 
virtual ~SDMM ()
 
 SOPT_MACRO (itermax, t_uint)
 Maximum number of iterations. More...
 
 SOPT_MACRO (gamma, Real)
 Gamma. More...
 
 SOPT_MACRO (conjugate_gradient, ConjugateGradient)
 Conjugate gradient. More...
 
 SOPT_MACRO (is_converged, t_IsConverged)
 A function verifying convergence. More...
 
SDMM< SCALAR > & conjugate_gradient (t_uint itermax, t_real tolerance)
 Helps setup conjugate gradient. More...
 
template<typename PROXIMAL , typename T >
SDMM< SCALAR > & append (PROXIMAL proximal, T args)
 Appends a proximal and linear transform. More...
 
template<typename PROXIMAL >
SDMM< SCALAR > & append (PROXIMAL proximal)
 Appends a proximal with identity as the linear transform. More...
 
template<typename PROXIMAL , typename L , typename LADJOINT >
SDMM< SCALAR > & append (PROXIMAL proximal, L l, LADJOINT ladjoint)
 Appends a proximal with the linear transform as pair of functions. More...
 
template<typename PROXIMAL , typename L , typename LADJOINT >
SDMM< SCALAR > & append (PROXIMAL proximal, L l, LADJOINT ladjoint, std::array< t_int, 3 > sizes)
 Appends a proximal with the linear transform as pair of functions. More...
 
template<typename PROXIMAL , typename L , typename LADJOINT >
SDMM< SCALAR > & append (PROXIMAL proximal, L l, std::array< t_int, 3 > dsizes, LADJOINT ladjoint, std::array< t_int, 3 > isizes)
 Appends a proximal with the linear transform as pair of functions. More...
 
Diagnostic operator() (t_Vector &out, t_Vector const &input) const
 Implements SDMM. More...
 
DiagnosticAndResult operator() (t_Vector const &input) const
 
DiagnosticAndResult operator() (DiagnosticAndResult const &warmstart) const
 Makes it simple to chain different calls to SDMM. More...
 
std::vector< t_LinearTransform > const & transforms () const
 Linear transforms associated with each objective function. More...
 
std::vector< t_LinearTransform > & transforms ()
 Linear transforms associated with each objective function. More...
 
t_LinearTransform const & transforms (t_uint i) const
 Linear transform associated with a given objective function. More...
 
t_LinearTransform & transforms (t_uint i)
 Linear transform associated with a given objective function. More...
 
std::vector< t_Proximal > const & proximals () const
 Proximal of each objective function. More...
 
std::vector< t_Proximal > & proximals ()
 Proximal of each objective function. More...
 
t_Proximal const & proximals (t_uint i) const
 Proximal associated with a given objective function. More...
 
t_Proximal & proximals (t_uint i)
 Proximal associated with a given objective function. More...
 
template<typename T0 >
proximal::ProximalExpression< t_Proximal const &, T0 > proximals (t_uint i, Eigen::MatrixBase< T0 > const &x) const
 Lazy call to specific proximal function. More...
 
t_uint size () const
 Number of terms. More...
 
template<typename T0 , typename... T>
auto conjugate_gradient (T0 &&t0, T &&... args) const -> decltype(this->conjugate_gradient()(std::forward< T0 >(t0), std::forward< T >(args)...))
 Forwards to internal conjugate gradient object. More...
 
bool is_converged (t_Vector const &x) const
 Forwards to convergence function parameter. More...
 

Detailed Description

template<typename SCALAR>
class sopt::algorithm::SDMM< SCALAR >

Simultaneous-direction method of multipliers.

The algorithm is detailed in (doi) 10.1093/mnras/stu202.

Definition at line 23 of file sdmm.h.

Member Typedef Documentation

◆ Real

template<typename SCALAR >
using sopt::algorithm::SDMM< SCALAR >::Real = typename real_type<Scalar>::type

Real type.

Definition at line 43 of file sdmm.h.

◆ Scalar

template<typename SCALAR >
using sopt::algorithm::SDMM< SCALAR >::Scalar = value_type

Scalar type.

Definition at line 41 of file sdmm.h.

◆ t_IsConverged

template<typename SCALAR >
using sopt::algorithm::SDMM< SCALAR >::t_IsConverged = ConvergenceFunction<SCALAR>

Type of the convergence function.

Definition at line 51 of file sdmm.h.

◆ t_LinearTransform

template<typename SCALAR >
using sopt::algorithm::SDMM< SCALAR >::t_LinearTransform = LinearTransform<t_Vector>

Type of the A and A^H operations.

Definition at line 47 of file sdmm.h.

◆ t_Proximal

template<typename SCALAR >
using sopt::algorithm::SDMM< SCALAR >::t_Proximal = ProximalFunction<SCALAR>

Type of the proximal functions.

Definition at line 49 of file sdmm.h.

◆ t_Vector

template<typename SCALAR >
using sopt::algorithm::SDMM< SCALAR >::t_Vector = Vector<SCALAR>

Type of the underlying vectors.

Definition at line 45 of file sdmm.h.

◆ value_type

template<typename SCALAR >
using sopt::algorithm::SDMM< SCALAR >::value_type = SCALAR

Scalar type.

Definition at line 39 of file sdmm.h.

Constructor & Destructor Documentation

◆ SDMM()

template<typename SCALAR >
sopt::algorithm::SDMM< SCALAR >::SDMM ( )
inline

Definition at line 53 of file sdmm.h.

54  : itermax_(std::numeric_limits<t_uint>::max()),
55  gamma_(1e-8),
56  conjugate_gradient_(std::numeric_limits<t_uint>::max(), 1e-6),
57  is_converged_([](t_Vector const &) { return false; }) {}
sopt::Vector< Scalar > t_Vector

◆ ~SDMM()

template<typename SCALAR >
virtual sopt::algorithm::SDMM< SCALAR >::~SDMM ( )
inline virtual

Definition at line 58 of file sdmm.h.

58 {}

Member Function Documentation

◆ append() [1/5]

template<typename SCALAR >
template<typename PROXIMAL >
SDMM<SCALAR>& sopt::algorithm::SDMM< SCALAR >::append ( PROXIMAL  proximal)
inline

Appends a proximal with identity as the linear transform.

Definition at line 97 of file sdmm.h.

97  {
98  return append(proximal, linear_transform_identity<Scalar>());
99  }
SDMM< SCALAR > & append(PROXIMAL proximal, T args)
Appends a proximal and linear transform.
Definition: sdmm.h:90

References sopt::algorithm::SDMM< SCALAR >::append().

◆ append() [2/5]

template<typename SCALAR >
template<typename PROXIMAL , typename L , typename LADJOINT >
SDMM<SCALAR>& sopt::algorithm::SDMM< SCALAR >::append ( PROXIMAL  proximal,
L  l,
LADJOINT  ladjoint 
)
inline

Appends a proximal with the linear transform as pair of functions.

Definition at line 102 of file sdmm.h.

102  {
103  return append(proximal, linear_transform<t_Vector>(l, ladjoint));
104  }

References sopt::algorithm::SDMM< SCALAR >::append().

◆ append() [3/5]

template<typename SCALAR >
template<typename PROXIMAL , typename L , typename LADJOINT >
SDMM<SCALAR>& sopt::algorithm::SDMM< SCALAR >::append ( PROXIMAL  proximal,
L  l,
LADJOINT  ladjoint,
std::array< t_int, 3 >  sizes 
)
inline

Appends a proximal with the linear transform as pair of functions.

Definition at line 107 of file sdmm.h.

107  {
108  return append(proximal, linear_transform<t_Vector>(l, ladjoint, sizes));
109  }

References sopt::algorithm::SDMM< SCALAR >::append().

◆ append() [4/5]

template<typename SCALAR >
template<typename PROXIMAL , typename L , typename LADJOINT >
SDMM<SCALAR>& sopt::algorithm::SDMM< SCALAR >::append ( PROXIMAL  proximal,
L  l,
std::array< t_int, 3 >  dsizes,
LADJOINT  ladjoint,
std::array< t_int, 3 >  isizes 
)
inline

Appends a proximal with the linear transform as pair of functions.

Definition at line 112 of file sdmm.h.

113  {
114  return append(proximal, linear_transform<t_Vector>(l, dsizes, ladjoint, isizes));
115  }

References sopt::algorithm::SDMM< SCALAR >::append().

◆ append() [5/5]

template<typename SCALAR >
template<typename PROXIMAL , typename T >
SDMM<SCALAR>& sopt::algorithm::SDMM< SCALAR >::append ( PROXIMAL  proximal,
T  args 
)
inline

Appends a proximal and linear transform.

Definition at line 90 of file sdmm.h.

90  {
91  proximals().emplace_back(proximal);
92  transforms().emplace_back(linear_transform(args));
93  return *this;
94  }
std::vector< t_LinearTransform > const & transforms() const
Linear transforms associated with each objective function.
Definition: sdmm.h:135
std::vector< t_Proximal > const & proximals() const
Proximal of each objective function.
Definition: sdmm.h:144
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})

References sopt::linear_transform(), sopt::algorithm::SDMM< SCALAR >::proximals(), and sopt::algorithm::SDMM< SCALAR >::transforms().

Referenced by sopt::algorithm::SDMM< SCALAR >::append(), and TEST_CASE().

◆ conjugate_gradient() [1/2]

template<typename SCALAR >
template<typename T0 , typename... T>
auto sopt::algorithm::SDMM< SCALAR >::conjugate_gradient ( T0 &&  t0,
T &&...  args 
) const -> decltype(this->conjugate_gradient()(std::forward<T0>(t0), std::forward<T>(args)...))
inline

Forwards to internal conjugate gradient object.

Removes the need for ugly extra brackets.

Definition at line 166 of file sdmm.h.

167  {
168  return conjugate_gradient()(std::forward<T0>(t0), std::forward<T>(args)...);
169  }
SDMM< SCALAR > & conjugate_gradient(t_uint itermax, t_real tolerance)
Helps setup conjugate gradient.
Definition: sdmm.h:83

References sopt::algorithm::SDMM< SCALAR >::conjugate_gradient().

◆ conjugate_gradient() [2/2]

template<typename SCALAR >
SDMM<SCALAR>& sopt::algorithm::SDMM< SCALAR >::conjugate_gradient ( t_uint  itermax,
t_real  tolerance 
)
inline

Helps setup conjugate gradient.

Definition at line 83 of file sdmm.h.

83  {
84  conjugate_gradient_.itermax(itermax);
85  conjugate_gradient_.tolerance(tolerance);
86  return *this;
87  }

Referenced by sopt::algorithm::SDMM< SCALAR >::conjugate_gradient(), main(), and TEST_CASE().

◆ is_converged()

template<typename SCALAR >
bool sopt::algorithm::SDMM< SCALAR >::is_converged ( t_Vector const &  x) const
inline

Forwards to convergence function parameter.

Definition at line 172 of file sdmm.h.

172 { return is_converged()(x); }
bool is_converged(t_Vector const &x) const
Forwards to convergence function parameter.
Definition: sdmm.h:172

References sopt::algorithm::SDMM< SCALAR >::is_converged().

Referenced by sopt::algorithm::SDMM< SCALAR >::is_converged(), main(), and SCENARIO().

◆ operator()() [1/3]

template<typename SCALAR >
DiagnosticAndResult sopt::algorithm::SDMM< SCALAR >::operator() ( DiagnosticAndResult const &  warmstart) const
inline

Makes it simple to chain different calls to SDMM.

Definition at line 128 of file sdmm.h.

128  {
129  DiagnosticAndResult result;
130  static_cast<Diagnostic &>(result) = operator()(result.x, warmstart.x);
131  return result;
132  }

References sopt::algorithm::SDMM< SCALAR >::DiagnosticAndResult::x.

◆ operator()() [2/3]

template<typename SCALAR >
SDMM< SCALAR >::Diagnostic sopt::algorithm::SDMM< SCALAR >::operator() ( t_Vector &  out,
t_Vector const &  input 
) const

Implements SDMM.

Follows Combettes and Pesquet, "Proximal Splitting Methods in Signal Processing", arXiv:0912.3522v4 [math.OC] (2010), equation 65. See therein for notation.

Definition at line 196 of file sdmm.h.

197  {
198  sanity_check(input);
199  bool convergence = false;
200  t_uint niters(0);
201  // Figures out where itermax or convergence reached
202  auto const has_finished = [&convergence, &niters, this](t_Vector const &out) {
203  convergence = is_converged(out);
204  return niters >= itermax() or convergence;
205  };
206 
207  SOPT_HIGH_LOG("Performing SDMM ");
208  out = input;
209  t_Vectors y(transforms().size());
210  t_Vectors z(transforms().size());
211 
212  // Initial step replaces iteration update with initialization
213  initialization(y, z, input);
214  auto cg_diagnostic = solve_for_xn(out, y, z);
215 
216  while (not has_finished(out)) {
217  SOPT_LOW_LOG("Iteration {}/{}. ", niters, itermax());
218  // computes y and z from out and transforms
219  update_directions(y, z, out);
220  SOPT_LOW_LOG(" - sum z_ij = {}",
221  std::accumulate(z.begin(), z.end(), Scalar(0e0),
222  [](Scalar const &a, t_Vector const &z) { return a + z.sum(); }));
223  // computes x = L^-1 y
224  cg_diagnostic = solve_for_xn(out, y, z);
225  SOPT_LOW_LOG(" - CG Residual = {} in {}/{} iterations", cg_diagnostic.residual,
226  cg_diagnostic.niters, conjugate_gradient().itermax());
227 
228  ++niters;
229  }
230  return {niters, convergence, cg_diagnostic};
231 }
sopt::t_real Scalar
constexpr Scalar a
t_uint size() const
Number of terms.
Definition: sdmm.h:159
#define SOPT_LOW_LOG(...)
Low priority message.
Definition: logging.h:227
#define SOPT_HIGH_LOG(...)
High priority message.
Definition: logging.h:223
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15

References a, SOPT_HIGH_LOG, and SOPT_LOW_LOG.

◆ operator()() [3/3]

template<typename SCALAR >
DiagnosticAndResult sopt::algorithm::SDMM< SCALAR >::operator() ( t_Vector const &  input) const
inline

Definition at line 122 of file sdmm.h.

122  {
123  DiagnosticAndResult result;
124  static_cast<Diagnostic &>(result) = operator()(result.x, input);
125  return result;
126  }

References sopt::algorithm::SDMM< SCALAR >::DiagnosticAndResult::x.

◆ proximals() [1/5]

template<typename SCALAR >
std::vector<t_Proximal>& sopt::algorithm::SDMM< SCALAR >::proximals ( )
inline

Proximal of each objective function.

Definition at line 146 of file sdmm.h.

146 { return proximals_; }

◆ proximals() [2/5]

template<typename SCALAR >
std::vector<t_Proximal> const& sopt::algorithm::SDMM< SCALAR >::proximals ( ) const
inline

Proximal of each objective function.

Definition at line 144 of file sdmm.h.

144 { return proximals_; }

Referenced by sopt::algorithm::SDMM< SCALAR >::append(), sopt::algorithm::SDMM< SCALAR >::proximals(), and sopt::algorithm::SDMM< SCALAR >::size().

◆ proximals() [3/5]

template<typename SCALAR >
t_Proximal& sopt::algorithm::SDMM< SCALAR >::proximals ( t_uint  i)
inline

Proximal associated with a given objective function.

Definition at line 150 of file sdmm.h.

150 { return proximals_[i]; }

◆ proximals() [4/5]

template<typename SCALAR >
t_Proximal const& sopt::algorithm::SDMM< SCALAR >::proximals ( t_uint  i) const
inline

Proximal associated with a given objective function.

Definition at line 148 of file sdmm.h.

148 { return proximals_[i]; }

◆ proximals() [5/5]

template<typename SCALAR >
template<typename T0 >
proximal::ProximalExpression<t_Proximal const &, T0> sopt::algorithm::SDMM< SCALAR >::proximals ( t_uint  i,
Eigen::MatrixBase< T0 > const &  x 
) const
inline

Lazy call to specific proximal function.

Definition at line 153 of file sdmm.h.

154  {
155  return {proximals()[i], gamma(), x};
156  }

References sopt::algorithm::SDMM< SCALAR >::proximals().

◆ size()

template<typename SCALAR >
t_uint sopt::algorithm::SDMM< SCALAR >::size ( ) const
inline

Number of terms.

Definition at line 159 of file sdmm.h.

159 { return proximals().size(); }

References sopt::algorithm::SDMM< SCALAR >::proximals().

◆ SOPT_MACRO() [1/4]

template<typename SCALAR >
sopt::algorithm::SDMM< SCALAR >::SOPT_MACRO ( conjugate_gradient  ,
ConjugateGradient   
)

Conjugate gradient.

◆ SOPT_MACRO() [2/4]

template<typename SCALAR >
sopt::algorithm::SDMM< SCALAR >::SOPT_MACRO ( gamma  ,
Real   
)

Gamma.

◆ SOPT_MACRO() [3/4]

template<typename SCALAR >
sopt::algorithm::SDMM< SCALAR >::SOPT_MACRO ( is_converged  ,
t_IsConverged   
)

A function verifying convergence.

◆ SOPT_MACRO() [4/4]

template<typename SCALAR >
sopt::algorithm::SDMM< SCALAR >::SOPT_MACRO ( itermax  ,
t_uint   
)

Maximum number of iterations.

◆ transforms() [1/4]

template<typename SCALAR >
std::vector<t_LinearTransform>& sopt::algorithm::SDMM< SCALAR >::transforms ( )
inline

Linear transforms associated with each objective function.

Definition at line 137 of file sdmm.h.

137 { return transforms_; }

◆ transforms() [2/4]

template<typename SCALAR >
std::vector<t_LinearTransform> const& sopt::algorithm::SDMM< SCALAR >::transforms ( ) const
inline

Linear transforms associated with each objective function.

Definition at line 135 of file sdmm.h.

135 { return transforms_; }

Referenced by sopt::algorithm::SDMM< SCALAR >::append(), and TEST_CASE().

◆ transforms() [3/4]

template<typename SCALAR >
t_LinearTransform& sopt::algorithm::SDMM< SCALAR >::transforms ( t_uint  i)
inline

Linear transform associated with a given objective function.

Definition at line 141 of file sdmm.h.

141 { return transforms_[i]; }

◆ transforms() [4/4]

template<typename SCALAR >
t_LinearTransform const& sopt::algorithm::SDMM< SCALAR >::transforms ( t_uint  i) const
inline

Linear transform associated with a given objective function.

Definition at line 139 of file sdmm.h.

139 { return transforms_[i]; }

The documentation for this class was generated from the following file: