1 #include <catch2/catch_all.hpp>
5 #include "tools_for_tests/directories.h"
// Test case "Load an example ORT model" (tag [ONNXrt]).
// The only statement visible here constructs an ORTsession from the example
// ONNX file; no explicit assertion is visible in this excerpt.
// NOTE(review): the listing appears truncated — the end of this test case is
// not shown, so confirm against the full file.
9 TEST_CASE(
"Load an example ORT model",
"[ONNXrt]") {
// `path` is defined outside this excerpt; the model file name is appended to it.
12 ORTsession model(path+
"/example_grad_dynamic_CRR_sigma_5_t_5.onnx");
// Test case "Check metadata of an example ORT model" (tag [ONNXrt]).
// Constructs a session from the example ONNX file, reads the value stored
// under the key "L_CRR" as a double, and compares it with a hard-coded
// reference value.
// NOTE(review): the listing appears truncated — the end of this test case is
// not shown.
17 TEST_CASE(
"Check metadata of an example ORT model",
"[ONNXrt]") {
// `path` is defined outside this excerpt.
20 ORTsession model(path+
"/example_grad_dynamic_CRR_sigma_5_t_5.onnx");
// Look up the "L_CRR" entry through ORTsession::retrieve<double>.
22 const double L = model.
retrieve<
double>(
"L_CRR");
// Reference value the retrieved entry is compared against.
23 const double L_ref = 0.769605;
// Relative comparison: the absolute difference must be below 1e-6 * L_ref.
// NOTE(review): `abs` is unqualified and applied to doubles — confirm it
// resolves to a floating-point overload (e.g. std::abs from <cmath>) and not
// to int abs(int), which would truncate the difference to 0 and make this
// check pass unconditionally.
25 CHECK(abs(L_ref - L) < 1e-6*L_ref);
// Test case "Check forward folding of an example ORT model using std::vectors"
// (tag [ONNXrt]).
// Reads an input and an expected output from the model's "x_init" and "x_exp"
// entries, runs ORTsession::compute on the input, and compares the result
// with the expected values element by element.
// NOTE(review): the listing appears truncated — the declaration and
// accumulation of `mse`, the end of the loop, and the final check are not
// visible in this excerpt.
28 TEST_CASE(
"Check forward folding of an example ORT model using std::vectors",
"[ONNXrt]") {
// `path` is defined outside this excerpt.
31 ORTsession model(path+
"/example_grad_dynamic_CRR_sigma_5_t_5.onnx");
// Dimensions passed to compute() below; input_size = 1 * nROWS * nCOLS.
33 const size_t nROWS = 256, nCOLS = 256;
34 const size_t input_size = 1 * nROWS * nCOLS;
37 std::vector<float> input_values, expected_values;
// NOTE(review): both vectors are assigned from utilities::split right after
// these reserve() calls — the reservations are presumably redundant; confirm.
38 input_values.reserve(input_size);
39 expected_values.reserve(input_size);
// The "x_init" entry is retrieved as a std::string and split on "," into floats.
40 input_values = utilities::split<float>(model.
retrieve<std::string>(
"x_init"),
",");
// Same parsing for the expected output stored under "x_exp".
41 expected_values = utilities::split<float>(model.
retrieve<std::string>(
"x_exp"),
",");
// Run the model on the input with dimensions {1, nROWS, nCOLS}.
44 std::vector<float> output_values = model.
compute(input_values, {1,nROWS,nCOLS});
// Walk over every output element.
// NOTE(review): expected_values is indexed with the output's index range —
// assumes it holds at least output_values.size() elements; confirm.
49 for (
size_t i = 0; i < output_values.size(); ++i) {
// Per-element difference between computed and expected value.
50 double diff = output_values[i] - expected_values[i];
// Divide the accumulated value by the number of output elements
// (presumably yielding a mean squared error; the accumulation is not shown).
53 mse /= output_values.size();
// Test case "Check forward folding of an example ORT model using sopt::Vectors"
// (tag [ONNXrt]).
// Reads an input and an expected output from the model's "x_init" and "x_exp"
// entries, copies them element by element into `inputT` and `refT`, and
// computes a normalised squared difference between `outputT` and `refT`.
// NOTE(review): the listing appears truncated — the declarations of `inputT`,
// `refT` and `outputT`, the end of the loop, and the final check are not
// visible in this excerpt.
59 TEST_CASE(
"Check forward folding of an example ORT model using sopt::Vectors",
"[ONNXrt]") {
// `path` is defined outside this excerpt.
62 ORTsession model(path+
"/example_grad_dynamic_CRR_sigma_5_t_5.onnx");
// input_size = 1 * nROWS * nCOLS; used for the reserve() calls below.
64 const size_t nROWS = 256, nCOLS = 256;
65 const size_t input_size = 1 * nROWS * nCOLS;
68 std::vector<float> input_values, expected_values;
// NOTE(review): both vectors are assigned from utilities::split right after
// these reserve() calls — the reservations are presumably redundant; confirm.
69 input_values.reserve(input_size);
70 expected_values.reserve(input_size);
// The "x_init" entry is retrieved as a std::string and split on "," into floats.
71 input_values = utilities::split<float>(model.
retrieve<std::string>(
"x_init"),
",");
// Same parsing for the expected output stored under "x_exp".
72 expected_values = utilities::split<float>(model.
retrieve<std::string>(
"x_exp"),
",");
// Copy the parsed values into inputT / refT, one element at a time.
// NOTE(review): expected_values is indexed with input_values' index range —
// assumes both hold the same number of elements; confirm.
77 for (
size_t i = 0; i < input_values.size(); ++i) {
78 inputT[i] = input_values[i];
79 refT[i] = expected_values[i];
// Squared norm of (outputT - refT) divided by the number of output elements.
88 auto mse = (outputT - refT).squaredNorm() / outputT.size();
Sopt interface class to hold an ONNXrt session.
const T retrieve(const std::string &key) const
std::vector< float > compute(std::vector< float > &inputs, const std::vector< int64_t > &inDims) const
Vector< T > diff(const Vector< T > &x)
Numerical derivative of 1d vector.
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
TEST_CASE("Load an example ORT model", "[ONNXrt]")