SOPT
Sparse OPTimisation
tf_model.cc
Go to the documentation of this file.
1 #include <iostream>
2 #include <vector>
3 #include <catch2/catch_all.hpp>
4 
5 #include "sopt/logging.h"
6 #include "sopt/types.h"
7 #include "sopt/utilities.h"
8 #include "sopt/ort_session.h"
9 
10 // This header is not part of the installed sopt interface
11 // It is only present in tests
12 #include "tools_for_tests/directories.h"
14 
15 
16 using Scalar = double;
19 
20 
21 TEST_CASE("Cppflow Model"){
22 
23  // read in image
24  const std::string input_image = "cameraman256";
25  const Image image = sopt::tools::read_standard_tiff(input_image);
26 
27  const int image_rows = image.rows();
28  const int image_cols = image.cols();
29 
30  // Read in model
31  const std::string path(sopt::tools::models_directory() + "/snr_15_model_dynamic.onnx");
32  sopt::ORTsession model(path);
33 
34  // Run model on image
35  const Image output_image = model.compute(image, {1,image_rows,image_cols,1});
36 
37  // compare input image to cleaned output image
38  // calculate mean squared error sum_i ( ( x_true(i) - x_est(i) ) **2 )
39  // check this is less than the number of pixels * 0.01
40 
41  auto mse = (image - output_image).square().sum() / image.size();
42  CAPTURE(mse);
43  CHECK(mse < 0.01);
44 
45 }
sopt::t_real Scalar
Sopt interface class to hold an ONNX Runtime (ONNXrt) session.
Definition: ort_session.h:18
std::vector< float > compute(std::vector< float > &inputs, const std::vector< int64_t > &inDims) const
Definition: ort_session.h:50
Image read_standard_tiff(std::string const &name)
Reads tiff image from sopt data directory if it exists.
Definition: tiffwrappers.cc:9
std::string models_directory()
Directory containing the machine-learning models.
Eigen::Array< T, Eigen::Dynamic, Eigen::Dynamic > Image
A 2-dimensional list of elements of given type.
Definition: types.h:39
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
sopt::Vector< Scalar > Vector
Definition: inpainting.cc:28
sopt::Image< Scalar > Image
Definition: inpainting.cc:30
TEST_CASE("Cppflow Model")
Definition: tf_model.cc:21