SOPT
Sparse OPTimisation
ort_model.cc
Go to the documentation of this file.
1 #include <catch2/catch_all.hpp>
2 #include <Eigen/Core>
3 #include "sopt/ort_session.h"
4 #include "sopt/utilities.h"
5 #include "tools_for_tests/directories.h"
6 
7 using namespace sopt;
8 
9 TEST_CASE("Load an example ORT model", "[ONNXrt]") {
10 
11  const std::string path(sopt::tools::models_directory());
12  ORTsession model(path+"/example_grad_dynamic_CRR_sigma_5_t_5.onnx");
13 
14  CHECK(true);
15 }
16 
17 TEST_CASE("Check metadata of an example ORT model", "[ONNXrt]") {
18 
19  const std::string path(sopt::tools::models_directory());
20  ORTsession model(path+"/example_grad_dynamic_CRR_sigma_5_t_5.onnx");
21 
22  const double L = model.retrieve<double>("L_CRR");
23  const double L_ref = 0.769605;
24 
25  CHECK(abs(L_ref - L) < 1e-6*L_ref);
26 }
27 
28 TEST_CASE("Check forward folding of an example ORT model using std::vectors", "[ONNXrt]") {
29 
30  const std::string path(sopt::tools::models_directory());
31  ORTsession model(path+"/example_grad_dynamic_CRR_sigma_5_t_5.onnx");
32 
33  const size_t nROWS = 256, nCOLS = 256;
34  const size_t input_size = 1 * nROWS * nCOLS;
35  // This network metadata contains an example
36  // input vector and reference output vector
37  std::vector<float> input_values, expected_values;
38  input_values.reserve(input_size);
39  expected_values.reserve(input_size);
40  input_values = utilities::split<float>(model.retrieve<std::string>("x_init"), ",");
41  expected_values = utilities::split<float>(model.retrieve<std::string>("x_exp"), ",");
42 
43  // forward fold
44  std::vector<float> output_values = model.compute(input_values, {1,nROWS,nCOLS});
45 
46  // calculate mean squared error sum_i ( ( x_true(i) - x_est(i) ) **2 )
47  // check this is less than the number of pixels * 0.01
48  double mse = 0.0;
49  for (size_t i = 0; i < output_values.size(); ++i) {
50  double diff = output_values[i] - expected_values[i];
51  mse += diff*diff;
52  }
53  mse /= output_values.size();
54 
55  CAPTURE(mse);
56  CHECK(mse < 0.01);
57 }
58 
59 TEST_CASE("Check forward folding of an example ORT model using sopt::Vectors", "[ONNXrt]") {
60 
61  const std::string path(sopt::tools::models_directory());
62  ORTsession model(path+"/example_grad_dynamic_CRR_sigma_5_t_5.onnx");
63 
64  const size_t nROWS = 256, nCOLS = 256;
65  const size_t input_size = 1 * nROWS * nCOLS;
66  // This network metadata contains an example
67  // input vector and reference output vector
68  std::vector<float> input_values, expected_values;
69  input_values.reserve(input_size);
70  expected_values.reserve(input_size);
71  input_values = utilities::split<float>(model.retrieve<std::string>("x_init"), ",");
72  expected_values = utilities::split<float>(model.retrieve<std::string>("x_exp"), ",");
73 
74  // convert flat vectors to sopt::Vector
75  Vector<double> inputT(input_values.size());
76  Vector<double> refT(expected_values.size());
77  for (size_t i = 0; i < input_values.size(); ++i) {
78  inputT[i] = input_values[i];
79  refT[i] = expected_values[i];
80  }
81 
82  // forward fold using sopt::Vector
83  Vector<double> outputT = model.compute(inputT,{1,nROWS,nCOLS});
84 
85  // compare output tensor to reference tensor
86  // calculate mean squared error sum_i ( ( x_true(i) - x_est(i) ) **2 )
87  // check this is less than the number of pixels * 0.01
88  auto mse = (outputT - refT).squaredNorm() / outputT.size();
89  CAPTURE(mse);
90  CHECK(mse < 0.01);
91 }
Sopt interface class to hold an ONNXrt session.
Definition: ort_session.h:18
const T retrieve(const std::string &key) const
Definition: ort_session.h:151
std::vector< float > compute(std::vector< float > &inputs, const std::vector< int64_t > &inDims) const
Definition: ort_session.h:50
Vector< T > diff(const Vector< T > &x)
Numerical derivative of 1d vector.
std::string models_directory()
Machine-learning models.
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
TEST_CASE("Load an example ORT model", "[ONNXrt]")
Definition: ort_model.cc:9