SOPT
Sparse OPTimisation
onnx_differentiable_func.h
Go to the documentation of this file.
1 #ifndef ONNX_DIFFERENTIABLE_FUNC
2 #define ONNX_DIFFERENTIABLE_FUNC
3 
#include "sopt/ort_session.h"
#include <array>
#include <cmath>
#include <exception>
#include <stdexcept>
#include <string>
#include <vector>
9 namespace sopt
10 {
11 template<typename SCALAR>
13 {
15  using Real = typename DifferentiableFunc<SCALAR>::Real;
16  using Vector = typename DifferentiableFunc<SCALAR>::t_Vector;
17  using LinearTransform = typename DifferentiableFunc<SCALAR>::t_LinearTransform;
18 
19  public:
20  ONNXDifferentiableFunc(const std::string& function_model_path,
21  const std::string& gradient_model_path,
22  const Real sigma,
23  const Real mu,
24  const Real lambda,
25  const LinearTransform& Phi,
26  const std::vector<int64_t> dimensions = {}): LT(Phi), sigma(sigma), mu(mu), lambda(lambda),
27  function_model(function_model_path),
28  gradient_model(gradient_model_path)
29  {
30  Real L_CRR; // Lipschitz constant
31  if(dimensions.empty()) infer_square_dimensions = true;
32  try
33  {
34  L_CRR = gradient_model.retrieve<double>("L_CRR");
35  this->step_size = 0.98 / (1/(sigma*sigma) + mu * lambda * L_CRR);
36  SOPT_MEDIUM_LOG("Lipschitz Constant for CRR = {}", L_CRR);
37  SOPT_MEDIUM_LOG("Step size for CRR = {}", this->step_size);
38  }
39  catch(const std::exception &e)
40  {
42  "Failed to find a Lipschitz constant for the current model. Please ensure that the "
43  "Lipschitz constant is included in the gradient model meta-data with the key "
44  "\"L_CRR\". Setting step size to 1 by default.");
45  SOPT_HIGH_LOG("Exception message retrieving L_CRR: {}", e.what());
46  this->step_size = 1;
47  }
48  }
49 
50  void log_message() const override
51  {
52  SOPT_HIGH_LOG("Using ONNX model differentiable function f(x)");
53  }
54 
55  void gradient(Vector &output, const Vector &image, const Vector &residual,
56  const LinearTransform &Phi) override
57  {
58  if(infer_square_dimensions) infer_dimensions(image.size());
59 
60  output = Phi.adjoint() * (residual / (sigma * sigma)); // L2 norm
61  Vector scaled_image = image * mu;
62  std::vector<float> float_image = utilities::imageToFloat(scaled_image);
63  Vector ANN_gradient = utilities::floatToImage<SCALAR>(gradient_model.compute(float_image, dimensions)); // regulariser
64  output += (ANN_gradient * lambda);
65  }
66 
67  void infer_dimensions(const size_t image_size)
68  {
69  set_dimensions({1, static_cast<int64_t>(sqrt(image_size)), static_cast<int64_t>(sqrt(image_size))});
70  if(dimensions[1] * dimensions[2] != image_size)
71  {
72  throw std::runtime_error("Image dimensions are not provided and image size is not compatible with a square image.");
73  }
74  infer_square_dimensions = false;
75  }
76 
77  void set_dimensions(const std::vector<int64_t> &dims)
78  {
79  dimensions = dims;
80  }
81 
82  Real function(Vector const &image, Vector const &y, LinearTransform const &Phi) override
83  {
84  if(infer_square_dimensions) infer_dimensions(image.size());
85  Real Likelihood = 0.5 * ((Phi*image) - y).squaredNorm() / (sigma * sigma);
86  Vector scaled_image = image * mu;
87  std::vector<float> float_image = utilities::imageToFloat(scaled_image);
88  Real Prior = (lambda / mu) * (function_model.compute(float_image, dimensions)[0]);
89  return Likelihood + Prior;
90  }
91 
92  private:
93  ORTsession function_model;
94  ORTsession gradient_model;
95  Real sigma;
96  Real mu;
97  Real lambda;
98  std::vector<int64_t> dimensions;
99  const LinearTransform &LT;
100  bool infer_square_dimensions = false;
101 };
102 
103 } // namespace sopt
104 #endif
typename FB::t_LinearTransform t_LinearTransform
typename FB::t_Gradient t_Gradient
typename FB::Real Real
typename FB::t_Vector t_Vector
Joins together direct and indirect operators.
ONNXDifferentiableFunc(const std::string &function_model_path, const std::string &gradient_model_path, const Real sigma, const Real mu, const Real lambda, const LinearTransform &Phi, const std::vector< int64_t > dimensions={})
void gradient(Vector &output, const Vector &image, const Vector &residual, const LinearTransform &Phi) override
void infer_dimensions(const size_t image_size)
void set_dimensions(const std::vector< int64_t > &dims)
Sopt interface class to hold an ONNXrt session.
Definition: ort_session.h:18
const T retrieve(const std::string &key) const
Definition: ort_session.h:151
std::vector< float > compute(std::vector< float > &inputs, const std::vector< int64_t > &inDims) const
Definition: ort_session.h:50
#define SOPT_HIGH_LOG(...)
High priority message.
Definition: logging.h:223
#define SOPT_MEDIUM_LOG(...)
Medium priority message.
Definition: logging.h:225
std::vector< float > imageToFloat(sopt::Vector< T > const &image)
Definition: utilities.h:67
real_type< T >::type sigma(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:17