SOPT
Sparse OPTimisation
Public Member Functions | List of all members
sopt::ORTsession Class Reference

Sopt interface class that holds an ONNX Runtime (ONNXrt) session. More...

#include <ort_session.h>

Public Member Functions

 ORTsession ()=delete
 
 ORTsession (const std::string &filename, const std::string &runname="soptONNXrt")
 Constructor. More...
 
std::vector< float > compute (std::vector< float > &inputs, const std::vector< int64_t > &inDims) const
 
template<typename T = t_real>
Vector< T > compute (const Vector< T > &input, const std::vector< int64_t > &inDims) const
 Variant of compute() using input/output Eigen arrays. More...
 
template<typename T = t_real>
Image< T > compute (const Image< T > &input, std::vector< int64_t > inDims={}) const
 Variant of compute() using input/output Image. More...
 
const bool hasKey (const std::string &key) const
 Method to check if key exists in network metadata. More...
 
template<typename T >
const T retrieve (const std::string &key) const
 
template<typename T >
const T retrieve (const std::string &key, const T &defaultreturn) const
 

Detailed Description

Sopt interface class that holds an ONNX Runtime (ONNXrt) session.

Definition at line 18 of file ort_session.h.

Constructor & Destructor Documentation

◆ ORTsession() [1/2]

sopt::ORTsession::ORTsession ( )
delete

◆ ORTsession() [2/2]

sopt::ORTsession::ORTsession ( const std::string &  filename,
const std::string &  runname = "soptONNXrt" 
)
inline

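Constructor. Creates the ONNXrt environment (logging at warning level, tagged with runname), loads the model from filename, and stores the network's input/output information. If the ORT_NUM_THREADS environment variable is set, its value is used as the number of intra-op threads; otherwise the ONNXrt default is used.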

Definition at line 25 of file ort_session.h.

25  {
26 
27  // Set-up ONNXrt session
28  _env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, runname.c_str());
29 
30  // Load the model
31  Ort::SessionOptions sessionopts;
32 
33  // Allow the number of threads used by the ONNX runtime to be set by an
34  // environment variable. If unset it will use all available threads by default
35  char* env_num_threads = std::getenv("ORT_NUM_THREADS");
36  if(env_num_threads) {
37  const int num_threads = std::stoi(env_num_threads);
38  sessionopts.SetIntraOpNumThreads(num_threads);
39  SOPT_INFO("ONNXRT using {} IntraOpThreads", num_threads);
40  }
41 
42  _session = std::make_unique<Ort::Session>(*_env, filename.c_str(), sessionopts);
43 
44  // Store model hyperparameters (input/output shape etc.)
45  _retrieveNetworkInfo();
46  }
#define SOPT_INFO(...)
Macro for verbose informational messages about normal conditions.
Definition: logging.h:215

References SOPT_INFO.

Member Function Documentation

◆ compute() [1/3]

template<typename T = t_real>
Image<T> sopt::ORTsession::compute ( const Image< T > &  input,
std::vector< int64_t >  inDims = {} 
) const
inline

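Variant of compute() that takes and returns an Image. If inDims is empty, the input shape defaults to {1, ..., 1, rows, cols}, padded with leading 1s to the number of dimensions the network expects. A non-empty inDims must have exactly that many dimensions. The output is mapped back to an image with the same number of rows and columns as the input.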

Definition at line 113 of file ort_session.h.

113  {}) const {
114 
115  // require an output node of the form {1, nRows, nCols}
116  // in order to be able to map this onto a 2D tensor
117  if (inDims.size() && inDims.size() != _inShape) {
118  throw std::length_error("Input tensor has incorrect shape! Expected "+std::to_string(_inShape)+" dimensions.");
119  }
120 
121  // ONNXrt requires floats as input
122  const int nrows = input.rows();
123  const int ncols = input.cols();
124  std::vector<float> flat_input(nrows*ncols);
125  for (int i = 0; i < nrows; ++i) {
126  for (int j = 0; j < ncols; ++j) {
127  flat_input[j*ncols+i] = input(i,j);
128  }
129  }
130  if (inDims.empty()) {
131  while (inDims.size() < _inShape-2) inDims.push_back(1);
132  inDims.push_back(nrows);
133  inDims.push_back(ncols);
134  }
135  std::vector<float> flat_output = compute(flat_input, inDims);
136 
137  std::vector<T> tResults(flat_output.begin(), flat_output.end());
138  Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>> rtn(tResults.data(), nrows, ncols);
139  return rtn;
140  }
std::vector< float > compute(std::vector< float > &inputs, const std::vector< int64_t > &inDims) const
Definition: ort_session.h:50

◆ compute() [2/3]

template<typename T = t_real>
Vector<T> sopt::ORTsession::compute ( const Vector< T > &  input,
const std::vector< int64_t > &  inDims 
) const
inline

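Variant of compute() using input/output Eigen vectors. For complex T, only the real part of the input is passed to the network, and the returned values have zero imaginary part.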

Definition at line 86 of file ort_session.h.

86  {
87  // ONNXrt requires floats as input
88  std::vector<float> flat_input(input.size());
89  for (size_t i = 0; i < input.size(); ++i) {
90  if constexpr(std::is_same<T, t_complex>::value) {
91  flat_input[i] = input[i].real();
92  } else {
93  flat_input[i] = input[i];
94  }
95  }
96  std::vector<float> flat_output = compute(flat_input, inDims);
97  Vector<T> rtn(flat_output.size());
98  for (size_t i = 0; i < flat_output.size(); ++i) {
99  if constexpr(std::is_same<T, t_complex>::value)
100  {
101  rtn[i] = t_complex(flat_output[i], 0);
102  }
103  else
104  {
105  rtn[i] = flat_output[i];
106  }
107  }
108  return rtn;
109  }
std::complex< t_real > t_complex
Root of the type hierarchy for (real) complex numbers.
Definition: types.h:19
sopt::Vector< Scalar > Vector
Definition: inpainting.cc:28

References compute().

◆ compute() [3/3]

std::vector<float> sopt::ORTsession::compute ( std::vector< float > &  inputs,
const std::vector< int64_t > &  inDims 
) const
inline

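Evaluates the network in the forward direction using a flattened tensor as input. inDims gives the shape of the input tensor and must have exactly as many dimensions as the network expects. Returns the flattened output tensor. Throws std::length_error if inputs is empty, if inDims has the wrong number of dimensions, or if the output tensor has zero length.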

Definition at line 50 of file ort_session.h.

50  {
51 
52  if (inputs.empty()) {
53  throw std::length_error("Input vector is empty!");
54  }
55  if (inDims.size() != _inShape) {
56  throw std::length_error("Input tensor has incorrect shape! Expected "+std::to_string(_inShape)+" dimensions.");
57  }
58 
59  // reshape flat input vector as tensor and run the model
60  auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
61  auto input_tensor = Ort::Value::CreateTensor<float>(memory_info, inputs.data(), inputs.size(),
62  inDims.data(), inDims.size());
63  auto output_tensors = _session->Run(Ort::RunOptions{nullptr}, _inNames.data(),
64  &input_tensor, _inNames.size(),
65  _outNames.data(), _outNames.size());
66 
67  // retrieve the ouput tensor and return flattened version
68  auto outputInfo = output_tensors[0].GetTensorTypeAndShapeInfo();
69  // Fix negative shape values - this appears to be an artefact of batch size issues
70  int64_t outLen = 1;
71  for (auto& dim : outputInfo.GetShape()) {
72  if (dim < 0) dim = abs(dim);
73  outLen *= dim;
74  }
75  if (outLen == 0) {
76  throw std::length_error("Invalid network structure: Output node with 0-length tensor encountered!");
77  }
78  float* floatarr = output_tensors.front().GetTensorMutableData<float>();
79  std::vector<float> outputs;
80  outputs.assign(floatarr, floatarr + outLen);
81  return outputs;
82  }

Referenced by compute(), sopt::ONNXDifferentiableFunc< SCALAR >::function(), sopt::ONNXDifferentiableFunc< SCALAR >::gradient(), and TEST_CASE().

◆ hasKey()

const bool sopt::ORTsession::hasKey ( const std::string &  key) const
inline

Method to check whether key exists in the network's custom metadata.

Definition at line 143 of file ort_session.h.

143  {
144  Ort::AllocatorWithDefaultOptions allocator;
145  return (bool)_metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
146  }

◆ retrieve() [1/2]

template<typename T >
const T sopt::ORTsession::retrieve ( const std::string &  key) const
inline

Method to retrieve the value associated with key from the network metadata and return it as type T. Throws std::runtime_error if the key is not found.

Definition at line 151 of file ort_session.h.

151  {
152  Ort::AllocatorWithDefaultOptions allocator;
153  Ort::AllocatedStringPtr res = _metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
154  if (!res) {
155  throw std::runtime_error("Key '"+key+"' not found in network metadata!");
156  }
157  if constexpr (std::is_same<T, std::string>::value) {
158  return res.get();
159  }
160  else {
161  return utilities::lexical_cast<T>(res.get());
162  }
163  }

Referenced by TEST_CASE().

◆ retrieve() [2/2]

template<typename T >
const T sopt::ORTsession::retrieve ( const std::string &  key,
const T &  defaultreturn 
) const
inline

Variation of the retrieve method that returns defaultreturn if the key cannot be found, or if retrieving or converting the value throws any other std::exception.

Definition at line 169 of file ort_session.h.

169  {
170  try {
171  return retrieve<T>(key);
172  } catch (std::exception& e) {
173  return defaultreturn;
174  }
175  }

The documentation for this class was generated from the following file: