SOPT
Sparse OPTimisation
sopt::algorithm::ImagingForwardBackward< SCALAR > Class Template Reference

#include <imaging_forward_backward.h>

Classes

struct Diagnostic
 Values indicating how the algorithm ran.
 
struct DiagnosticAndResult
 Holds result vector as well.
 

Public Types

using value_type = typename FB::value_type
 
using Scalar = typename FB::Scalar
 
using Real = typename FB::Real
 
using t_Vector = typename FB::t_Vector
 
using t_LinearTransform = typename FB::t_LinearTransform
 
using t_Proximal = typename FB::t_Proximal
 
using t_Gradient = typename FB::t_Gradient
 
using t_l2Gradient = typename std::function< void(t_Vector &, const t_Vector &)>
 
using t_IsConverged = typename FB::t_IsConverged
 
using t_randomUpdater = typename FB::t_randomUpdater
 

Public Member Functions

 ImagingForwardBackward (t_Vector const &target)
 
 ImagingForwardBackward (t_randomUpdater &updater)
 
virtual ~ImagingForwardBackward ()
 
 SOPT_MACRO (tight_frame, bool)
 Whether Ψ is a tight-frame or not.
 
 SOPT_MACRO (residual_tolerance, Real)
 Tolerance for the convergence of the residuals.
 
 SOPT_MACRO (relative_variation, Real)
 Tolerance for the relative variation of the objective function.
 
 SOPT_MACRO (residual_convergence, t_IsConverged)
 Convergence of the residuals.
 
 SOPT_MACRO (objective_convergence, t_IsConverged)
 Convergence of the objective function.
 
 SOPT_MACRO (itermax, t_uint)
 Maximum number of iterations.
 
 SOPT_MACRO (regulariser_strength, Real)
 Regulariser strength (γ parameter).
 
 SOPT_MACRO (step_size, Real)
 Step size of the Forward-Backward iteration.
 
 SOPT_MACRO (sigma, Real)
 σ parameter.
 
 SOPT_MACRO (fista, bool)
 Flag enabling the FISTA variant of the Forward-Backward algorithm. True by default, but should be set to false when using a learned g_function.
 
 SOPT_MACRO (is_converged, t_IsConverged)
 A function verifying convergence.
 
t_LinearTransform const & Phi () const
 Measurement operator.
 
ImagingForwardBackward< SCALAR > & Phi (t_LinearTransform const &(Phi))
 
std::shared_ptr< NonDifferentiableFunc< SCALAR > > g_function ()
 
ImagingForwardBackward< SCALAR > & g_function (std::shared_ptr< NonDifferentiableFunc< SCALAR >> g_function)
 
std::shared_ptr< DifferentiableFunc< SCALAR > > f_function ()
 
ImagingForwardBackward< SCALAR > & f_function (std::shared_ptr< DifferentiableFunc< SCALAR >> f_function)
 
t_randomUpdater & random_updater ()
 
ImagingForwardBackward< SCALAR > & random_updater (t_randomUpdater &new_updater)
 
t_LinearTransform const & Psi () const
 Sparsifying transform Ψ.
 
t_Vector const & target () const
 Vector of target measurements.
 
Real objmin () const
 Minimum of objective_function.
 
ImagingForwardBackward< Scalar > & target (t_Vector const &target)
 Sets the vector of target measurements.
 
Diagnostic operator() (t_Vector &out) const
 Calls Forward Backward.
 
Diagnostic operator() (t_Vector &out, std::tuple< t_Vector, t_Vector > const &guess)
 Calls Forward Backward.
 
Diagnostic operator() (t_Vector &out, std::tuple< t_Vector const &, t_Vector const & > const &guess)
 Calls Forward Backward.
 
DiagnosticAndResult operator() (std::tuple< t_Vector, t_Vector > const &guess)
 Calls Forward Backward.
 
DiagnosticAndResult operator() (std::tuple< t_Vector const &, t_Vector const & > const &guess)
 Calls Forward Backward.
 
DiagnosticAndResult operator() ()
 Calls Forward Backward.
 
DiagnosticAndResult operator() (DiagnosticAndResult const &warmstart)
 Makes it simple to chain different calls to FB.
 
ImagingForwardBackward &::type Phi (ARGS &&... args)
 
ImagingForwardBackward< Scalar > & residual_convergence (Real const &tolerance)
 Helper function to set up the default residual convergence function.
 
ImagingForwardBackward< Scalar > & objective_convergence (Real const &tolerance)
 Helper function to set up the default objective convergence function.
 
ImagingForwardBackward< Scalar > & is_converged (std::function< bool(t_Vector const &x)> const &func)
 Convergence function that takes only the output as argument.
 

Detailed Description

template<typename SCALAR>
class sopt::algorithm::ImagingForwardBackward< SCALAR >

Definition at line 31 of file imaging_forward_backward.h.
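The class delegates the iterations themselves to ForwardBackward. For orientation only, the self-contained sketch below runs a plain (non-FISTA) forward-backward loop on a hypothetical denoising problem, min_x ½‖x − y‖² + γ‖x‖₁. It assumes Φ and Ψ are the identity, an ℓ2 data-fidelity term in place of f_function and an ℓ1 term in place of g_function; `soft_threshold` and `forward_backward_denoise` are illustrative names, not SOPT API, and `beta`/`gamma` only loosely correspond to step_size and regulariser_strength.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Soft thresholding: the proximal operator of t * ||x||_1.
std::vector<double> soft_threshold(std::vector<double> const &v, double t) {
  std::vector<double> out(v.size());
  for (std::size_t i = 0; i < v.size(); ++i) {
    double const magnitude = std::abs(v[i]) - t;
    out[i] = magnitude > 0.0 ? std::copysign(magnitude, v[i]) : 0.0;
  }
  return out;
}

// Hypothetical plain forward-backward loop for
//   min_x 0.5 * ||x - y||^2 + gamma * ||x||_1   (Phi = identity, Psi = identity).
std::vector<double> forward_backward_denoise(std::vector<double> const &y, double gamma,
                                             double beta, unsigned itermax) {
  std::vector<double> x(y.size(), 0.0);  // zero initial guess
  std::vector<double> z(y.size());
  for (unsigned k = 0; k < itermax; ++k) {
    // Forward step: gradient step on f(x) = 0.5 * ||x - y||^2, whose gradient is x - y.
    for (std::size_t i = 0; i < x.size(); ++i) z[i] = x[i] - beta * (x[i] - y[i]);
    // Backward step: proximal operator of beta * gamma * ||.||_1.
    x = soft_threshold(z, beta * gamma);
  }
  return x;
}
```

With `beta = 1` the forward step lands exactly on `y`, so one iteration already returns the soft-thresholded data; smaller steps converge geometrically to the same point for this particular problem.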

Member Typedef Documentation

◆ Real

template<typename SCALAR >
using sopt::algorithm::ImagingForwardBackward< SCALAR >::Real = typename FB::Real

Definition at line 38 of file imaging_forward_backward.h.

◆ Scalar

template<typename SCALAR >
using sopt::algorithm::ImagingForwardBackward< SCALAR >::Scalar = typename FB::Scalar

Definition at line 37 of file imaging_forward_backward.h.

◆ t_Gradient

template<typename SCALAR >
using sopt::algorithm::ImagingForwardBackward< SCALAR >::t_Gradient = typename FB::t_Gradient

Definition at line 42 of file imaging_forward_backward.h.

◆ t_IsConverged

template<typename SCALAR >
using sopt::algorithm::ImagingForwardBackward< SCALAR >::t_IsConverged = typename FB::t_IsConverged

Definition at line 44 of file imaging_forward_backward.h.

◆ t_l2Gradient

template<typename SCALAR >
using sopt::algorithm::ImagingForwardBackward< SCALAR >::t_l2Gradient = typename std::function<void(t_Vector &, const t_Vector &)>

Definition at line 43 of file imaging_forward_backward.h.

◆ t_LinearTransform

template<typename SCALAR >
using sopt::algorithm::ImagingForwardBackward< SCALAR >::t_LinearTransform = typename FB::t_LinearTransform

Definition at line 40 of file imaging_forward_backward.h.

◆ t_Proximal

template<typename SCALAR >
using sopt::algorithm::ImagingForwardBackward< SCALAR >::t_Proximal = typename FB::t_Proximal

Definition at line 41 of file imaging_forward_backward.h.

◆ t_randomUpdater

template<typename SCALAR >
using sopt::algorithm::ImagingForwardBackward< SCALAR >::t_randomUpdater = typename FB::t_randomUpdater

Definition at line 45 of file imaging_forward_backward.h.

◆ t_Vector

template<typename SCALAR >
using sopt::algorithm::ImagingForwardBackward< SCALAR >::t_Vector = typename FB::t_Vector

Definition at line 39 of file imaging_forward_backward.h.

◆ value_type

template<typename SCALAR >
using sopt::algorithm::ImagingForwardBackward< SCALAR >::value_type = typename FB::value_type

Definition at line 36 of file imaging_forward_backward.h.

Constructor & Destructor Documentation

◆ ImagingForwardBackward() [1/2]

template<typename SCALAR >
sopt::algorithm::ImagingForwardBackward< SCALAR >::ImagingForwardBackward ( t_Vector const &  target)
inline

Sets up the imaging wrapper for ForwardBackward. g_function_ is set to null so that this class does not depend on any particular implementation of g_function; the code that instantiates this class must inject the appropriate implementation. The measurement operator Φ is initialised to the identity.

Parameters
[in] target: Vector of target measurements

Definition at line 69 of file imaging_forward_backward.h.

70  : g_function_(nullptr),
71  f_function_(nullptr),
72  random_updater_(nullptr),
73  tight_frame_(false),
74  residual_tolerance_(0.),
75  relative_variation_(1e-4),
76  residual_convergence_(nullptr),
77  objective_convergence_(nullptr),
78  itermax_(std::numeric_limits<t_uint>::max()),
79  regulariser_strength_(1e-8),
80  step_size_(1),
81  sigma_(1),
82  fista_(true),
83  is_converged_()
84  {
85  std::shared_ptr<t_LinearTransform> Id = std::make_shared<t_LinearTransform>(linear_transform_identity<Scalar>());
86  problem_state = std::make_shared<IterationState<t_Vector>>(target, Id);
87  }

References sopt::algorithm::ImagingForwardBackward< SCALAR >::target().

◆ ImagingForwardBackward() [2/2]

template<typename SCALAR >
sopt::algorithm::ImagingForwardBackward< SCALAR >::ImagingForwardBackward ( t_randomUpdater & updater)
inline

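Sets up the imaging wrapper for data sets that are refreshed by a random updater. Since the target and Φ are not known ahead of time, the initial problem state is obtained by calling updater; a null updater throws std::runtime_error.

Definition at line 89 of file imaging_forward_backward.h.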

90  : g_function_(nullptr),
91  f_function_(nullptr),
92  random_updater_(updater),
93  tight_frame_(false),
94  residual_tolerance_(0.),
95  relative_variation_(1e-4),
96  residual_convergence_(nullptr),
97  objective_convergence_(nullptr),
98  itermax_(std::numeric_limits<t_uint>::max()),
99  regulariser_strength_(1e-8),
100  step_size_(1),
101  sigma_(1),
102  fista_(true),
103  is_converged_()
104  {
105  if(random_updater_)
106  {
107  // target and Phi are not known ahead of time for random data sets so need to be initialised
108  problem_state = random_updater_();
109  }
110  else
111  {
112  throw std::runtime_error("Attempted to construct ImagingForwardBackward class with a null random updater. To run without random updates supply a target vector instead.");
113  }
114  }
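The constructor above validates the updater before using it to build the first problem state. A minimal sketch of that pattern, assuming a `std::function` updater that returns a `std::shared_ptr` to a state object; `ProblemStateSketch`, `RandomUpdaterSketch` and `StochasticSolverSketch` are hypothetical names, not SOPT types.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical stand-in for the problem state (here only the target vector y).
struct ProblemStateSketch {
  std::vector<double> target;
};

// Hypothetical updater type: returns a fresh problem state, e.g. for a new random data batch.
using RandomUpdaterSketch = std::function<std::shared_ptr<ProblemStateSketch>()>;

class StochasticSolverSketch {
 public:
  explicit StochasticSolverSketch(RandomUpdaterSketch updater)
      : random_updater_(std::move(updater)) {
    // Reject an empty updater up front, as the documented constructor does.
    if (!random_updater_)
      throw std::runtime_error("null random updater: supply a target vector instead");
    // The target is not known ahead of time, so the first state comes from the updater.
    problem_state_ = random_updater_();
  }

  std::vector<double> const &target() const { return problem_state_->target; }

 private:
  RandomUpdaterSketch random_updater_;
  std::shared_ptr<ProblemStateSketch> problem_state_;
};
```

Failing in the constructor means a solver object never exists in a state without data, rather than failing later inside the first iteration.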

◆ ~ImagingForwardBackward()

template<typename SCALAR >
virtual sopt::algorithm::ImagingForwardBackward< SCALAR >::~ImagingForwardBackward ( )
inline virtual

Definition at line 116 of file imaging_forward_backward.h.

116 {}

Member Function Documentation

◆ f_function() [1/2]

template<typename SCALAR >
std::shared_ptr<DifferentiableFunc<SCALAR> > sopt::algorithm::ImagingForwardBackward< SCALAR >::f_function ( )
inline

Definition at line 185 of file imaging_forward_backward.h.

185 { return f_function_; }

Referenced by sopt::algorithm::ImagingForwardBackward< SCALAR >::f_function().

◆ f_function() [2/2]

template<typename SCALAR >
ImagingForwardBackward<SCALAR>& sopt::algorithm::ImagingForwardBackward< SCALAR >::f_function ( std::shared_ptr< DifferentiableFunc< SCALAR >>  f_function)
inline

Definition at line 186 of file imaging_forward_backward.h.

186  {
187  f_function_ = std::move(f_function);
188  return *this;
189  }

References sopt::algorithm::ImagingForwardBackward< SCALAR >::f_function().

◆ g_function() [1/2]

template<typename SCALAR >
std::shared_ptr<NonDifferentiableFunc<SCALAR> > sopt::algorithm::ImagingForwardBackward< SCALAR >::g_function ( )
inline

Definition at line 176 of file imaging_forward_backward.h.

176 { return g_function_; }

Referenced by sopt::algorithm::ImagingForwardBackward< SCALAR >::g_function().

◆ g_function() [2/2]

template<typename SCALAR >
ImagingForwardBackward<SCALAR>& sopt::algorithm::ImagingForwardBackward< SCALAR >::g_function ( std::shared_ptr< NonDifferentiableFunc< SCALAR >>  g_function)
inline

Definition at line 177 of file imaging_forward_backward.h.

177  {
178  g_function_ = std::move(g_function);
179  return *this;
180  }

References sopt::algorithm::ImagingForwardBackward< SCALAR >::g_function().

◆ is_converged()

template<typename SCALAR >
ImagingForwardBackward<Scalar>& sopt::algorithm::ImagingForwardBackward< SCALAR >::is_converged ( std::function< bool(t_Vector const &x)> const &  func)
inline

Convergence function that takes only the output as argument.

Definition at line 287 of file imaging_forward_backward.h.

287  {
288  return is_converged([func](t_Vector const &x, t_Vector const &) { return func(x); });
289  }

Referenced by TEST_CASE().
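The overload above adapts a one-argument predicate to the two-argument convergence signature by capturing it in a lambda that ignores its second argument. A standalone sketch of the same adapter, assuming (for illustration) that the second argument is the residual; `IsConvergedSketch` and `ignore_residual` are hypothetical names.

```cpp
#include <cassert>
#include <functional>
#include <vector>

using Vec = std::vector<double>;

// Assumed two-argument form: (solution, residual) -> converged?
using IsConvergedSketch = std::function<bool(Vec const &, Vec const &)>;

// Wraps a criterion that only inspects the solution; the residual argument is ignored.
IsConvergedSketch ignore_residual(std::function<bool(Vec const &)> const &func) {
  return [func](Vec const &x, Vec const &) { return func(x); };
}
```

The predicate is captured by value, so the adapter stays valid after the caller's `std::function` goes out of scope.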

◆ objective_convergence()

template<typename SCALAR >
ImagingForwardBackward<Scalar>& sopt::algorithm::ImagingForwardBackward< SCALAR >::objective_convergence ( Real const &  tolerance)
inline

Helper function to set up the default objective convergence function: it resets objective_convergence to null and sets relative_variation to tolerance.

Definition at line 283 of file imaging_forward_backward.h.

283  {
284  return objective_convergence(nullptr).relative_variation(tolerance);
285  }

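objective_convergence(tolerance) sets relative_variation, whose default is 1e-4. The exact formula SOPT applies is not given on this page; the sketch below shows one common form of a relative-variation test on successive objective values. `relative_variation_converged` is a hypothetical name, and treating a negative tolerance as "always converged" is this sketch's own reading of the "If negative, this convergence criterion is disabled" note.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical relative-variation test: |f_k - f_{k-1}| / |f_k| < tolerance.
// Assumption: a negative tolerance disables the criterion, which then always reports
// "converged" so it never blocks termination when several criteria are combined.
bool relative_variation_converged(double previous_objective, double current_objective,
                                  double tolerance) {
  if (tolerance < 0.0) return true;
  double const scale = std::abs(current_objective) > 0.0 ? std::abs(current_objective) : 1.0;
  return std::abs(current_objective - previous_objective) / scale < tolerance;
}
```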
◆ objmin()

template<typename SCALAR >
Real sopt::algorithm::ImagingForwardBackward< SCALAR >::objmin ( ) const
inline

Minimum of objective_function.

Definition at line 216 of file imaging_forward_backward.h.

216 { return objmin_; }

◆ operator()() [1/7]

template<typename SCALAR >
DiagnosticAndResult sopt::algorithm::ImagingForwardBackward< SCALAR >::operator() ( )
inline

Calls Forward Backward.

Uses ForwardBackward<SCALAR>::initial_guess(target(), Phi()) as the initial guess.

Definition at line 257 of file imaging_forward_backward.h.

257  {
258  DiagnosticAndResult result;
259  static_cast<Diagnostic &>(result) = operator()(
260  result.x, ForwardBackward<SCALAR>::initial_guess(target(), Phi()));
261  return result;
262  }

References sopt::algorithm::ImagingForwardBackward< SCALAR >::Phi(), sopt::algorithm::ImagingForwardBackward< SCALAR >::target(), and sopt::algorithm::ImagingForwardBackward< SCALAR >::DiagnosticAndResult::x.

Referenced by sopt::algorithm::ImagingForwardBackward< SCALAR >::operator()().

◆ operator()() [2/7]

template<typename SCALAR >
DiagnosticAndResult sopt::algorithm::ImagingForwardBackward< SCALAR >::operator() ( DiagnosticAndResult const &  warmstart)
inline

Makes it simple to chain different calls to FB.

Definition at line 264 of file imaging_forward_backward.h.

264  {
265  DiagnosticAndResult result = warmstart;
266  static_cast<Diagnostic &>(result) = operator()(result.x, warmstart.x, warmstart.residual);
267  return result;
268  }

References sopt::algorithm::ImagingForwardBackward< SCALAR >::DiagnosticAndResult::x.
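The warm-start overload copies the previous result and restarts from its x (and residual), so two short runs can stand in for one long run. The hypothetical sketch below illustrates that property with a simple contraction standing in for a forward-backward step; `ResultSketch` and `run_from` are illustrative names, not SOPT types.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical result type: iteration count plus the solution vector.
struct ResultSketch {
  unsigned niters = 0;
  std::vector<double> x;
};

// Runs `itermax` iterations of x <- 0.5 * x + 0.5 * y (a stand-in for one
// forward-backward step), starting from warmstart.x instead of a fresh guess.
ResultSketch run_from(std::vector<double> const &y, unsigned itermax,
                      ResultSketch const &warmstart) {
  ResultSketch result = warmstart;
  if (result.x.empty()) result.x.assign(y.size(), 0.0);  // cold start
  for (unsigned k = 0; k < itermax; ++k)
    for (std::size_t i = 0; i < y.size(); ++i) result.x[i] = 0.5 * result.x[i] + 0.5 * y[i];
  result.niters = itermax;  // iterations performed by this call only
  return result;
}
```

Because the state is carried entirely in the result object, 5 + 5 chained iterations reproduce the same x as a single run of 10.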

◆ operator()() [3/7]

template<typename SCALAR >
DiagnosticAndResult sopt::algorithm::ImagingForwardBackward< SCALAR >::operator() ( std::tuple< t_Vector const &, t_Vector const & > const &  guess)
inline

Calls Forward Backward.

Parameters
[in] guess: Initial guess

Definition at line 249 of file imaging_forward_backward.h.

250  {
251  DiagnosticAndResult result;
252  static_cast<Diagnostic &>(result) = operator()(result.x, guess);
253  return result;
254  }

References sopt::algorithm::ImagingForwardBackward< SCALAR >::DiagnosticAndResult::x.

◆ operator()() [4/7]

template<typename SCALAR >
DiagnosticAndResult sopt::algorithm::ImagingForwardBackward< SCALAR >::operator() ( std::tuple< t_Vector, t_Vector > const &  guess)
inline

Calls Forward Backward.

Parameters
[in] guess: Initial guess

Definition at line 244 of file imaging_forward_backward.h.

244  {
245  return operator()(std::tie(std::get<0>(guess), std::get<1>(guess)));
246  }

References sopt::algorithm::ImagingForwardBackward< SCALAR >::operator()().
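The by-value tuple overload re-wraps its elements with `std::tie`, which yields a tuple of const references, so the by-reference overload sees the original vectors rather than copies. A standalone sketch of that forwarding; `first_address` is a hypothetical function used only to make the aliasing observable.

```cpp
#include <cassert>
#include <tuple>
#include <vector>

using Vec = std::vector<double>;

// Reference-tuple overload: reports the address of the first vector it actually sees.
Vec const *first_address(std::tuple<Vec const &, Vec const &> const &guess) {
  return &std::get<0>(guess);
}

// Value-tuple overload: std::tie on the const elements produces
// std::tuple<Vec const &, Vec const &>, which selects the overload above without copying.
Vec const *first_address(std::tuple<Vec, Vec> const &guess) {
  return first_address(std::tie(std::get<0>(guess), std::get<1>(guess)));
}
```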

◆ operator()() [5/7]

template<typename SCALAR >
Diagnostic sopt::algorithm::ImagingForwardBackward< SCALAR >::operator() ( t_Vector & out) const
inline

Calls Forward Backward.

Parameters
[out] out: Output vector x

Definition at line 225 of file imaging_forward_backward.h.

225  {
226  return operator()(out, ForwardBackward<SCALAR>::initial_guess(target(), Phi()));
227 
228  }

References sopt::algorithm::ImagingForwardBackward< SCALAR >::operator()(), sopt::algorithm::ImagingForwardBackward< SCALAR >::Phi(), and sopt::algorithm::ImagingForwardBackward< SCALAR >::target().

◆ operator()() [6/7]

template<typename SCALAR >
Diagnostic sopt::algorithm::ImagingForwardBackward< SCALAR >::operator() ( t_Vector & out,
std::tuple< t_Vector const &, t_Vector const & > const &  guess 
)
inline

Calls Forward Backward.

Parameters
[out] out: Output vector x
[in] guess: Initial guess

Definition at line 238 of file imaging_forward_backward.h.

239  {
240  return operator()(out, std::get<0>(guess), std::get<1>(guess));
241  }

References sopt::algorithm::ImagingForwardBackward< SCALAR >::operator()().

◆ operator()() [7/7]

template<typename SCALAR >
Diagnostic sopt::algorithm::ImagingForwardBackward< SCALAR >::operator() ( t_Vector & out,
std::tuple< t_Vector, t_Vector > const &  guess 
)
inline

Calls Forward Backward.

Parameters
[out] out: Output vector x
[in] guess: Initial guess

Definition at line 232 of file imaging_forward_backward.h.

232  {
233  return operator()(out, std::get<0>(guess), std::get<1>(guess));
234  }

References sopt::algorithm::ImagingForwardBackward< SCALAR >::operator()().

◆ Phi() [1/3]

template<typename SCALAR >
t_LinearTransform const& sopt::algorithm::ImagingForwardBackward< SCALAR >::Phi ( ) const
inline

Measurement operator.

Definition at line 159 of file imaging_forward_backward.h.

159 { return problem_state->Phi(); }

Referenced by main(), sopt::algorithm::ImagingForwardBackward< SCALAR >::operator()(), and sopt::algorithm::ImagingForwardBackward< SCALAR >::Phi().

◆ Phi() [2/3]

template<typename SCALAR >
ImagingForwardBackward& ::type sopt::algorithm::ImagingForwardBackward< SCALAR >::Phi ( ARGS &&...  args)
inline

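Sets the measurement operator by forwarding args to sopt::linear_transform(), and returns a reference to this object so that calls can be chained.

Definition at line 272 of file imaging_forward_backward.h.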

273  {
274  problem_state->Phi(linear_transform(std::forward<ARGS>(args)...));
275  return *this;
276  }

References sopt::linear_transform().
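The return type of this overload appears on this page only as `ImagingForwardBackward &::type`, which suggests (an assumption, not confirmed here) a `std::enable_if` guard on the parameter pack. The hypothetical sketch below shows why such a guard matters: without it, a zero-argument call on a non-const object would select the template over the const getter. `TransformSketch`, `make_transform` and `ProblemSketch` are illustrative names, not SOPT API.

```cpp
#include <cassert>
#include <functional>
#include <type_traits>
#include <utility>
#include <vector>

using Vec = std::vector<double>;
using OpFn = std::function<Vec(Vec const &)>;

// Hypothetical minimal linear transform: a direct map and its adjoint.
struct TransformSketch {
  OpFn direct;
  OpFn adjoint;
};

// Hypothetical factory standing in for a linear_transform(...)-style builder.
TransformSketch make_transform(OpFn direct, OpFn adjoint) {
  return TransformSketch{std::move(direct), std::move(adjoint)};
}

class ProblemSketch {
 public:
  // Getter: returns the stored operator.
  TransformSketch const &Phi() const { return phi_; }

  // Variadic setter: perfect-forwards its arguments to the factory. The enable_if guard
  // removes it for zero arguments; otherwise, on a non-const object, the template would be
  // a better match than the const getter for a call like Phi().
  template <typename... ARGS>
  typename std::enable_if<(sizeof...(ARGS) >= 1), ProblemSketch &>::type Phi(ARGS &&...args) {
    phi_ = make_transform(std::forward<ARGS>(args)...);
    return *this;
  }

 private:
  TransformSketch phi_;
};
```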

◆ Phi() [3/3]

template<typename SCALAR >
ImagingForwardBackward<SCALAR>& sopt::algorithm::ImagingForwardBackward< SCALAR >::Phi ( t_LinearTransform const &  Phi)
inline

Sets the measurement operator Φ in the problem state and returns a reference to this object.

Definition at line 160 of file imaging_forward_backward.h.

160  {
161  problem_state->Phi(Phi);
162  return *this;
163  }

References sopt::algorithm::ImagingForwardBackward< SCALAR >::Phi().

◆ Psi()

template<typename SCALAR >
t_LinearTransform const& sopt::algorithm::ImagingForwardBackward< SCALAR >::Psi ( ) const
inline

Sparsifying transform Ψ: the one held by g_function if it has been set, the identity otherwise.

Definition at line 198 of file imaging_forward_backward.h.

199  {
200  return (g_function_) ? g_function_->Psi() : linear_transform_identity<Scalar>();
201  }
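The fallback above can be sketched on its own, assuming a polymorphic regulariser that exposes its transform as a callable; `RegulariserSketch`, `ScaledRegulariser` and `psi_or_identity` are hypothetical names, and the toy transform simply scales by 2.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <vector>

using Vec = std::vector<double>;
using LinearOp = std::function<Vec(Vec const &)>;

// Hypothetical non-differentiable term that carries its own sparsifying transform Psi.
struct RegulariserSketch {
  virtual ~RegulariserSketch() = default;
  virtual LinearOp Psi() const = 0;
};

// Toy regulariser whose "transform" scales by 2 (illustration only).
struct ScaledRegulariser : RegulariserSketch {
  LinearOp Psi() const override {
    return [](Vec const &x) {
      Vec y(x);
      for (double &v : y) v *= 2.0;
      return y;
    };
  }
};

// Psi from g_function when one is set, identity otherwise. Returned by value.
LinearOp psi_or_identity(std::shared_ptr<RegulariserSketch> const &g_function) {
  return g_function ? g_function->Psi() : LinearOp([](Vec const &x) { return x; });
}
```

The sketch returns the operator by value: when one branch of a conditional expression is a temporary, the whole expression is a prvalue, so a function returning `const &` from it would hand back a reference to a temporary.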

◆ random_updater() [1/2]

template<typename SCALAR >
t_randomUpdater& sopt::algorithm::ImagingForwardBackward< SCALAR >::random_updater ( )
inline

Definition at line 192 of file imaging_forward_backward.h.

192 { return random_updater_; }

◆ random_updater() [2/2]

template<typename SCALAR >
ImagingForwardBackward<SCALAR>& sopt::algorithm::ImagingForwardBackward< SCALAR >::random_updater ( t_randomUpdater & new_updater)
inline

Definition at line 193 of file imaging_forward_backward.h.

193  {
194  random_updater_ = new_updater; // may change this to a move if we don't need to keep it
195  return *this;
196  }

◆ residual_convergence()

template<typename SCALAR >
ImagingForwardBackward<Scalar>& sopt::algorithm::ImagingForwardBackward< SCALAR >::residual_convergence ( Real const &  tolerance)
inline

Helper function to set up the default residual convergence function: it resets residual_convergence to null and sets residual_tolerance to tolerance.

Definition at line 279 of file imaging_forward_backward.h.

279  {
280  return residual_convergence(nullptr).residual_tolerance(tolerance);
281  }

Referenced by TEST_CASE().
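residual_convergence(tolerance) sets residual_tolerance. This page does not state which residual norm SOPT compares against it; the sketch below assumes the ℓ2 norm of y − Φx and takes Φx precomputed so that it stays independent of any operator type. `residual_converged` is a hypothetical name, and the negative-tolerance convention is this sketch's own reading of the "If negative" note.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical residual test: converged once ||y - Phi x||_2 < tolerance.
bool residual_converged(std::vector<double> const &y, std::vector<double> const &phi_x,
                        double tolerance) {
  if (tolerance < 0.0) return true;  // assumption: a negative tolerance disables the check
  double sum_sq = 0.0;
  for (std::size_t i = 0; i < y.size(); ++i) {
    double const r = y[i] - phi_x[i];
    sum_sq += r * r;
  }
  return std::sqrt(sum_sq) < tolerance;
}
```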

◆ SOPT_MACRO() [1/11]

template<typename SCALAR >
sopt::algorithm::ImagingForwardBackward< SCALAR >::SOPT_MACRO ( fista  ,
bool   
)

Flag enabling the FISTA variant of the Forward-Backward algorithm. True by default, but should be set to false when using a learned g_function.

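FISTA adds a momentum (extrapolation) step between forward-backward iterations. The sketch below shows the textbook Beck and Teboulle update; this page does not say whether the fista flag uses exactly this t-sequence, so treat it as an illustration. `fista_extrapolate` is a hypothetical name.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// One FISTA extrapolation: given the two latest iterates and the momentum scalar t_k,
// returns x_new + ((t_k - 1) / t_{k+1}) * (x_new - x_old) and updates t to t_{k+1}.
std::vector<double> fista_extrapolate(std::vector<double> const &x_new,
                                      std::vector<double> const &x_old, double &t) {
  double const t_next = 0.5 * (1.0 + std::sqrt(1.0 + 4.0 * t * t));
  double const momentum = (t - 1.0) / t_next;
  std::vector<double> y(x_new.size());
  for (std::size_t i = 0; i < x_new.size(); ++i)
    y[i] = x_new[i] + momentum * (x_new[i] - x_old[i]);
  t = t_next;
  return y;
}
```

Starting from t = 1 the first extrapolation has zero momentum; later steps overshoot along the direction of travel, which is what accelerates convergence for convex problems.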
◆ SOPT_MACRO() [2/11]

template<typename SCALAR >
sopt::algorithm::ImagingForwardBackward< SCALAR >::SOPT_MACRO ( is_converged  ,
t_IsConverged   
)

A function verifying convergence.

◆ SOPT_MACRO() [3/11]

template<typename SCALAR >
sopt::algorithm::ImagingForwardBackward< SCALAR >::SOPT_MACRO ( itermax  ,
t_uint   
)

Maximum number of iterations.

◆ SOPT_MACRO() [4/11]

template<typename SCALAR >
sopt::algorithm::ImagingForwardBackward< SCALAR >::SOPT_MACRO ( objective_convergence  ,
t_IsConverged   
)

Convergence of the objective function.

If negative, this convergence criterion is disabled.

◆ SOPT_MACRO() [5/11]

template<typename SCALAR >
sopt::algorithm::ImagingForwardBackward< SCALAR >::SOPT_MACRO ( regulariser_strength  ,
Real   
)

Regulariser strength (γ parameter).

◆ SOPT_MACRO() [6/11]

template<typename SCALAR >
sopt::algorithm::ImagingForwardBackward< SCALAR >::SOPT_MACRO ( relative_variation  ,
Real   
)

Convergence of the relative variation of the objective functions.

If negative, this convergence criterion is disabled.

◆ SOPT_MACRO() [7/11]

template<typename SCALAR >
sopt::algorithm::ImagingForwardBackward< SCALAR >::SOPT_MACRO ( residual_convergence  ,
t_IsConverged   
)

Convergence of the residuals.

If negative, this convergence criterion is disabled.

◆ SOPT_MACRO() [8/11]

template<typename SCALAR >
sopt::algorithm::ImagingForwardBackward< SCALAR >::SOPT_MACRO ( residual_tolerance  ,
Real   
)

Tolerance for the convergence of the residuals.

If negative, this convergence criterion is disabled.

◆ SOPT_MACRO() [9/11]

template<typename SCALAR >
sopt::algorithm::ImagingForwardBackward< SCALAR >::SOPT_MACRO ( sigma  ,
Real   
)

σ parameter.

◆ SOPT_MACRO() [10/11]

template<typename SCALAR >
sopt::algorithm::ImagingForwardBackward< SCALAR >::SOPT_MACRO ( step_size  ,
Real   
)

Step size of the Forward-Backward iteration.

◆ SOPT_MACRO() [11/11]

template<typename SCALAR >
sopt::algorithm::ImagingForwardBackward< SCALAR >::SOPT_MACRO ( tight_frame  ,
bool   
)

Whether Ψ is a tight-frame or not.
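The helpers on this page call, for example, residual_convergence(nullptr).residual_tolerance(tolerance), so each SOPT_MACRO(name, type) entry evidently provides a getter and a chainable setter. The sketch below is a hypothetical macro with that behaviour, written for illustration only; SOPT's own SOPT_MACRO definition may differ in its details, and `FLUENT_PROPERTY` and `SettingsSketch` are invented names.

```cpp
#include <cassert>

// Hypothetical accessor-generating macro (illustration only, not SOPT's definition).
#define FLUENT_PROPERTY(NAME, TYPE)                \
 public:                                           \
  TYPE const &NAME() const { return NAME##_; }     \
  this_type &NAME(TYPE const &value) {             \
    NAME##_ = value;                               \
    return *this;                                  \
  }                                                \
                                                   \
 private:                                          \
  TYPE NAME##_;

// Hypothetical settings class using the macro for three of the documented parameters.
class SettingsSketch {
  using this_type = SettingsSketch;
  FLUENT_PROPERTY(itermax, unsigned)
  FLUENT_PROPERTY(step_size, double)
  FLUENT_PROPERTY(fista, bool)

 public:
  SettingsSketch() : itermax_(100), step_size_(1.0), fista_(true) {}
};
```

Returning `*this` from every setter is what allows several parameters to be configured in a single chained expression.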

◆ target() [1/2]

template<typename SCALAR >
t_Vector const& sopt::algorithm::ImagingForwardBackward< SCALAR >::target ( ) const
inline

Vector of target measurements.

◆ target() [2/2]

template<typename SCALAR >
ImagingForwardBackward<Scalar>& sopt::algorithm::ImagingForwardBackward< SCALAR >::target ( t_Vector const &  target)
inline

Sets the vector of target measurements.

Definition at line 218 of file imaging_forward_backward.h.

218  {
219  problem_state->target(target);
220  return *this;
221  }

References sopt::algorithm::ImagingForwardBackward< SCALAR >::target().


The documentation for this class was generated from the following file: