SOPT
Sparse OPTimisation
Public Member Functions | List of all members
sopt::ONNXDifferentiableFunc< SCALAR > Class Template Reference

#include <onnx_differentiable_func.h>

+ Inheritance diagram for sopt::ONNXDifferentiableFunc< SCALAR >:
+ Collaboration diagram for sopt::ONNXDifferentiableFunc< SCALAR >:

Public Member Functions

 ONNXDifferentiableFunc (const std::string &function_model_path, const std::string &gradient_model_path, const Real sigma, const Real mu, const Real lambda, const LinearTransform &Phi, const std::vector< int64_t > dimensions={})
 
void log_message () const override
 
void gradient (Vector &output, const Vector &image, const Vector &residual, const LinearTransform &Phi) override
 
void infer_dimensions (const size_t image_size)
 
void set_dimensions (const std::vector< int64_t > &dims)
 
Real function (Vector const &image, Vector const &y, LinearTransform const &Phi) override
 
- Public Member Functions inherited from DifferentiableFunc< SCALAR >
virtual t_Gradient gradient ()
 
Real get_step_size () const
 

Additional Inherited Members

- Public Types inherited from DifferentiableFunc< SCALAR >
using FB = sopt::algorithm::ForwardBackward< SCALAR >
 
using Real = typename FB::Real
 
using t_Vector = typename FB::t_Vector
 
using t_Gradient = typename FB::t_Gradient
 
using t_LinearTransform = typename FB::t_LinearTransform
 

Detailed Description

template<typename SCALAR>
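class sopt::ONNXDifferentiableFunc< SCALAR >

Differentiable term for forward-backward that combines an L2 data-fidelity term with a learned regulariser evaluated through ONNX models:

$$ f(x) = \frac{\|\Phi x - y\|_2^2}{2\sigma^2} + \frac{\lambda}{\mu}\, R(\mu x). $$

Here R is evaluated by the model loaded from function_model_path. The gradient() member returns

$$ \frac{1}{\sigma^2}\,\Phi^\dagger r + \lambda\, G(\mu x), $$

where r is the residual passed to gradient() and G is the output of the model loaded from gradient_model_path.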

Definition at line 12 of file onnx_differentiable_func.h.

Constructor & Destructor Documentation

◆ ONNXDifferentiableFunc()

template<typename SCALAR >
sopt::ONNXDifferentiableFunc< SCALAR >::ONNXDifferentiableFunc ( const std::string &  function_model_path,
const std::string &  gradient_model_path,
const Real  sigma,
const Real  mu,
const Real  lambda,
const LinearTransform &  Phi,
const std::vector< int64_t >  dimensions = {} 
)
inline

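Definition at line 20 of file onnx_differentiable_func.h.

If dimensions is empty, a square input shape {1, √n, √n} is inferred from the image size n on the first call to function() or gradient().

The constructor reads the Lipschitz constant L_CRR from the gradient model's metadata (key "L_CRR"). It sets the step size to 0.98 / (1/σ² + μ λ L_CRR). If the constant cannot be retrieved, a message is logged and the step size defaults to 1.

NOTE(review): the listing below does not show a non-empty dimensions argument being stored in the member. When passing explicit dimensions, confirm they are actually applied (e.g. via set_dimensions()).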

26  {}): LT(Phi), sigma(sigma), mu(mu), lambda(lambda),
27  function_model(function_model_path),
28  gradient_model(gradient_model_path)
29  {
30  Real L_CRR; // Lipschitz constant
31  if(dimensions.empty()) infer_square_dimensions = true;
32  try
33  {
34  L_CRR = gradient_model.retrieve<double>("L_CRR");
35  this->step_size = 0.98 / (1/(sigma*sigma) + mu * lambda * L_CRR);
36  SOPT_MEDIUM_LOG("Lipschitz Constant for CRR = {}", L_CRR);
37  SOPT_MEDIUM_LOG("Step size for CRR = {}", this->step_size);
38  }
39  catch(const std::exception &e)
40  {
42  "Failed to find a Lipschitz constant for the current model. Please ensure that the "
43  "Lipschitz constant is included in the gradient model meta-data with the key "
44  "\"L_CRR\". Setting step size to 1 by default.");
45  SOPT_HIGH_LOG("Exception message retrieving L_CRR: {}", e.what());
46  this->step_size = 1;
47  }
48  }
const T retrieve(const std::string &key) const
Definition: ort_session.h:151
#define SOPT_HIGH_LOG(...)
High priority message.
Definition: logging.h:223
#define SOPT_MEDIUM_LOG(...)
Medium priority message.
Definition: logging.h:225

Member Function Documentation

◆ function()

template<typename SCALAR >
Real sopt::ONNXDifferentiableFunc< SCALAR >::function ( Vector const &  image,
Vector const &  y,
LinearTransform const &  Phi 
)
inline override virtual

Implements DifferentiableFunc< SCALAR >.

Definition at line 82 of file onnx_differentiable_func.h.

83  {
84  if(infer_square_dimensions) infer_dimensions(image.size());
85  Real Likelihood = 0.5 * ((Phi*image) - y).squaredNorm() / (sigma * sigma);
86  Vector scaled_image = image * mu;
87  std::vector<float> float_image = utilities::imageToFloat(scaled_image);
88  Real Prior = (lambda / mu) * (function_model.compute(float_image, dimensions)[0]);
89  return Likelihood + Prior;
90  }
void infer_dimensions(const size_t image_size)
std::vector< float > compute(std::vector< float > &inputs, const std::vector< int64_t > &inDims) const
Definition: ort_session.h:50
std::vector< float > imageToFloat(sopt::Vector< T > const &image)
Definition: utilities.h:67
sopt::Vector< Scalar > Vector
Definition: inpainting.cc:28

References sopt::ORTsession::compute(), sopt::utilities::imageToFloat(), sopt::ONNXDifferentiableFunc< SCALAR >::infer_dimensions(), and sopt::sigma().

◆ gradient()

template<typename SCALAR >
void sopt::ONNXDifferentiableFunc< SCALAR >::gradient ( Vector &  output,
const Vector &  image,
const Vector &  residual,
const LinearTransform &  Phi 
)
inline override virtual

Implements DifferentiableFunc< SCALAR >.

Definition at line 55 of file onnx_differentiable_func.h.

57  {
58  if(infer_square_dimensions) infer_dimensions(image.size());
59 
60  output = Phi.adjoint() * (residual / (sigma * sigma)); // L2 norm
61  Vector scaled_image = image * mu;
62  std::vector<float> float_image = utilities::imageToFloat(scaled_image);
63  Vector ANN_gradient = utilities::floatToImage<SCALAR>(gradient_model.compute(float_image, dimensions)); // regulariser
64  output += (ANN_gradient * lambda);
65  }
LinearTransform< VECTOR > adjoint() const
Indirect transform.

References sopt::ORTsession::compute(), sopt::utilities::imageToFloat(), sopt::ONNXDifferentiableFunc< SCALAR >::infer_dimensions(), and sopt::sigma().

◆ infer_dimensions()

template<typename SCALAR >
void sopt::ONNXDifferentiableFunc< SCALAR >::infer_dimensions ( const size_t  image_size)
inline

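Definition at line 67 of file onnx_differentiable_func.h.

Sets the model input dimensions to {1, √n, √n} for image size n and clears the infer-square-dimensions flag. It throws std::runtime_error if √n · √n (after truncation to an integer) does not equal n, i.e. if the image cannot be square.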

68  {
69  set_dimensions({1, static_cast<int64_t>(sqrt(image_size)), static_cast<int64_t>(sqrt(image_size))});
70  if(dimensions[1] * dimensions[2] != image_size)
71  {
72  throw std::runtime_error("Image dimensions are not provided and image size is not compatible with a square image.");
73  }
74  infer_square_dimensions = false;
75  }
void set_dimensions(const std::vector< int64_t > &dims)

References sopt::ONNXDifferentiableFunc< SCALAR >::set_dimensions().

Referenced by sopt::ONNXDifferentiableFunc< SCALAR >::function(), and sopt::ONNXDifferentiableFunc< SCALAR >::gradient().

◆ log_message()

template<typename SCALAR >
void sopt::ONNXDifferentiableFunc< SCALAR >::log_message ( ) const
inline override virtual

Implements DifferentiableFunc< SCALAR >.

Definition at line 50 of file onnx_differentiable_func.h.

51  {
52  SOPT_HIGH_LOG("Using ONNX model differentiable function f(x)");
53  }

References SOPT_HIGH_LOG.

◆ set_dimensions()

template<typename SCALAR >
void sopt::ONNXDifferentiableFunc< SCALAR >::set_dimensions ( const std::vector< int64_t > &  dims)
inline

Definition at line 77 of file onnx_differentiable_func.h.

78  {
79  dimensions = dims;
80  }

Referenced by sopt::ONNXDifferentiableFunc< SCALAR >::infer_dimensions().


The documentation for this class was generated from the following file: onnx_differentiable_func.h