SOPT
Sparse OPTimisation
Functions
ort_model.cc File Reference
#include <catch2/catch_all.hpp>
#include <Eigen/Core>
#include "sopt/ort_session.h"
#include "sopt/utilities.h"
#include "tools_for_tests/directories.h"
Include dependency graph for ort_model.cc (the graph image is not included in this text rendering).

Go to the source code of this file.

Functions

 TEST_CASE ("Load an example ORT model", "[ONNXrt]")
 
 TEST_CASE ("Check metadata of an example ORT model", "[ONNXrt]")
 
 TEST_CASE ("Check forward folding of an example ORT model using std::vectors", "[ONNXrt]")
 
 TEST_CASE ("Check forward folding of an example ORT model using sopt::Vectors", "[ONNXrt]")
 

Function Documentation

◆ TEST_CASE() [1/4]

TEST_CASE ( "Check forward folding of an example ORT model using sopt::Vectors",
            "[ONNXrt]"
)

Definition at line 59 of file ort_model.cc.

59  {
60 
61  const std::string path(sopt::tools::models_directory());
62  ORTsession model(path+"/example_grad_dynamic_CRR_sigma_5_t_5.onnx");
63 
64  const size_t nROWS = 256, nCOLS = 256;
65  const size_t input_size = 1 * nROWS * nCOLS;
66  // This network metadata contains an example
67  // input vector and reference output vector
68  std::vector<float> input_values, expected_values;
69  input_values.reserve(input_size);
70  expected_values.reserve(input_size);
71  input_values = utilities::split<float>(model.retrieve<std::string>("x_init"), ",");
72  expected_values = utilities::split<float>(model.retrieve<std::string>("x_exp"), ",");
73 
74  // convert flat vectors to sopt::Vector
75  Vector<double> inputT(input_values.size());
76  Vector<double> refT(expected_values.size());
77  for (size_t i = 0; i < input_values.size(); ++i) {
78  inputT[i] = input_values[i];
79  refT[i] = expected_values[i];
80  }
81 
82  // forward fold using sopt::Vector
83  Vector<double> outputT = model.compute(inputT,{1,nROWS,nCOLS});
84 
85  // compare output tensor to reference tensor
86  // calculate mean squared error sum_i ( ( x_true(i) - x_est(i) ) **2 )
87  // check this is less than the number of pixels * 0.01
88  auto mse = (outputT - refT).squaredNorm() / outputT.size();
89  CAPTURE(mse);
90  CHECK(mse < 0.01);
91 }
SOPT interface class that holds an ONNXrt session.
Definition: ort_session.h:18
std::string models_directory()
Directory containing the machine-learning models.
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24

References sopt::ORTsession::compute(), sopt::tools::models_directory(), and sopt::ORTsession::retrieve().

◆ TEST_CASE() [2/4]

TEST_CASE ( "Check forward folding of an example ORT model using std::vectors",
            "[ONNXrt]"
)

Definition at line 28 of file ort_model.cc.

28  {
29 
30  const std::string path(sopt::tools::models_directory());
31  ORTsession model(path+"/example_grad_dynamic_CRR_sigma_5_t_5.onnx");
32 
33  const size_t nROWS = 256, nCOLS = 256;
34  const size_t input_size = 1 * nROWS * nCOLS;
35  // This network metadata contains an example
36  // input vector and reference output vector
37  std::vector<float> input_values, expected_values;
38  input_values.reserve(input_size);
39  expected_values.reserve(input_size);
40  input_values = utilities::split<float>(model.retrieve<std::string>("x_init"), ",");
41  expected_values = utilities::split<float>(model.retrieve<std::string>("x_exp"), ",");
42 
43  // forward fold
44  std::vector<float> output_values = model.compute(input_values, {1,nROWS,nCOLS});
45 
46  // calculate mean squared error sum_i ( ( x_true(i) - x_est(i) ) **2 )
47  // check this is less than the number of pixels * 0.01
48  double mse = 0.0;
49  for (size_t i = 0; i < output_values.size(); ++i) {
50  double diff = output_values[i] - expected_values[i];
51  mse += diff*diff;
52  }
53  mse /= output_values.size();
54 
55  CAPTURE(mse);
56  CHECK(mse < 0.01);
57 }
Vector< T > diff(const Vector< T > &x)
Numerical derivative of a 1D vector.

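References sopt::ORTsession::compute(), sopt::gradient_operator::diff(), sopt::tools::models_directory(), and sopt::ORTsession::retrieve(). (Note: the link to sopt::gradient_operator::diff() is spurious. The `diff` in this test is a local `double` holding the per-element difference, not that function.)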

◆ TEST_CASE() [3/4]

TEST_CASE ( "Check metadata of an example ORT model",
            "[ONNXrt]"
)

Definition at line 17 of file ort_model.cc.

17  {
18 
19  const std::string path(sopt::tools::models_directory());
20  ORTsession model(path+"/example_grad_dynamic_CRR_sigma_5_t_5.onnx");
21 
22  const double L = model.retrieve<double>("L_CRR");
23  const double L_ref = 0.769605;
24 
25  CHECK(abs(L_ref - L) < 1e-6*L_ref);
26 }

References sopt::tools::models_directory(), and sopt::ORTsession::retrieve().

◆ TEST_CASE() [4/4]

TEST_CASE ( "Load an example ORT model",
            "[ONNXrt]"
)

Definition at line 9 of file ort_model.cc.

9  {
10 
11  const std::string path(sopt::tools::models_directory());
12  ORTsession model(path+"/example_grad_dynamic_CRR_sigma_5_t_5.onnx");
13 
14  CHECK(true);
15 }

References sopt::tools::models_directory().