SOPT
Sparse OPTimisation
differentiable_func.h
Go to the documentation of this file.
1 #ifndef DIFFERENTIABLE_FUNC_H
2 #define DIFFERENTIABLE_FUNC_H
3 
5 
6 // Abstract base class providing the interface for differentiable functions f(x)
7 // with a defined gradient.
8 template <typename SCALAR> class DifferentiableFunc
9 {
10 
11 public:
12  using FB = sopt::algorithm::ForwardBackward<SCALAR>;
13  using Real = typename FB::Real;
14  using t_Vector = typename FB::t_Vector;
15  using t_Gradient = typename FB::t_Gradient;
17 
18  // A function that prints a log message
19  virtual void log_message() const = 0;
20 
21  // Return a function representing the proximal operator for this function.
22  // Function must be of type t_Proximal, that is
23  // void proximal_operator(Vector, real, Vector)
24  virtual t_Gradient gradient()
25  {
26  return [this](t_Vector &output, const t_Vector &image, const t_Vector &residual,
27  const t_LinearTransform &Phi) -> void { this->gradient(output, image, residual, Phi); };
28  }
29 
30  // Calculate the gradient directly
31  virtual void gradient(t_Vector &output, const t_Vector &image, const t_Vector &residual,
32  const t_LinearTransform &Phi) = 0;
33 
34  // Calculate the function directly
35  virtual Real function(t_Vector const &image, t_Vector const &y, t_LinearTransform const &Phi) = 0;
36 
37  // Get appropriate gradient step-size for FISTA algorithms
39  {
40  return step_size;
41  }
42 
43  protected:
44 
45  Real step_size;
46 
47  // Transforms input image to a different basis.
48  // Return linear_transform_identity() if transform not necessary.
49  //virtual const t_LinearTransform &Phi() const = 0;
50 
51 };
52 
53 #endif
sopt::Vector< Scalar > t_Vector
typename FB::t_LinearTransform t_LinearTransform
typename FB::t_Gradient t_Gradient
typename FB::Real Real
virtual void log_message() const =0
Real get_step_size() const
sopt::algorithm::ForwardBackward< SCALAR > FB
virtual t_Gradient gradient()
typename FB::t_Vector t_Vector
virtual void gradient(t_Vector &output, const t_Vector &image, const t_Vector &residual, const t_LinearTransform &Phi)=0
sopt::LinearTransform< t_Vector > t_LinearTransform