1 #ifndef ONNX_DIFFERENTIABLE_FUNC
2 #define ONNX_DIFFERENTIABLE_FUNC
11 template<
typename SCALAR>
21 const std::string& gradient_model_path,
25 const LinearTransform& Phi,
26 const std::vector<int64_t> dimensions = {}): LT(Phi),
sigma(sigma), mu(mu), lambda(lambda),
27 function_model(function_model_path),
28 gradient_model(gradient_model_path)
31 if(dimensions.empty()) infer_square_dimensions =
true;
34 L_CRR = gradient_model.
retrieve<
double>(
"L_CRR");
35 this->step_size = 0.98 / (1/(
sigma*
sigma) + mu * lambda * L_CRR);
39 catch(
const std::exception &e)
42 "Failed to find a Lipschitz constant for the current model. Please ensure that the "
43 "Lipschitz constant is included in the gradient model meta-data with the key "
44 "\"L_CRR\". Setting step size to 1 by default.");
45 SOPT_HIGH_LOG(
"Exception message retrieving L_CRR: {}", e.what());
52 SOPT_HIGH_LOG(
"Using ONNX model differentiable function f(x)");
55 void gradient(Vector &output,
const Vector &image,
const Vector &residual,
56 const LinearTransform &Phi)
override
60 output = Phi.adjoint() * (residual / (
sigma *
sigma));
61 Vector scaled_image = image * mu;
63 Vector ANN_gradient = utilities::floatToImage<SCALAR>(gradient_model.
compute(float_image, dimensions));
64 output += (ANN_gradient * lambda);
69 set_dimensions({1,
static_cast<int64_t
>(sqrt(image_size)),
static_cast<int64_t
>(sqrt(image_size))});
70 if(dimensions[1] * dimensions[2] != image_size)
72 throw std::runtime_error(
"Image dimensions are not provided and image size is not compatible with a square image.");
74 infer_square_dimensions =
false;
82 Real
function(Vector
const &image, Vector
const &y, LinearTransform
const &Phi)
override
85 Real Likelihood = 0.5 * ((Phi*image) - y).squaredNorm() / (
sigma *
sigma);
86 Vector scaled_image = image * mu;
88 Real Prior = (lambda / mu) * (function_model.
compute(float_image, dimensions)[0]);
89 return Likelihood + Prior;
98 std::vector<int64_t> dimensions;
100 bool infer_square_dimensions =
false;
typename FB::t_LinearTransform t_LinearTransform
typename FB::t_Gradient t_Gradient
typename FB::t_Vector t_Vector
ONNXDifferentiableFunc(const std::string &function_model_path, const std::string &gradient_model_path, const Real sigma, const Real mu, const Real lambda, const LinearTransform &Phi, const std::vector< int64_t > dimensions={})
void gradient(Vector &output, const Vector &image, const Vector &residual, const LinearTransform &Phi) override
void log_message() const override
void infer_dimensions(const size_t image_size)
void set_dimensions(const std::vector< int64_t > &dims)
Sopt interface class to hold an ONNXrt session.
const T retrieve(const std::string &key) const
std::vector< float > compute(std::vector< float > &inputs, const std::vector< int64_t > &inDims) const
#define SOPT_HIGH_LOG(...)
High priority message.
#define SOPT_MEDIUM_LOG(...)
Medium priority message.
std::vector< float > imageToFloat(sopt::Vector< T > const &image)
real_type< T >::type sigma(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)