SOPT
Sparse OPTimisation
ort_session.h
Go to the documentation of this file.
1 #ifndef SOPT_ORT_SESSION_H
2 #define SOPT_ORT_SESSION_H
3 
#include "onnxruntime_cxx_api.h"
#include "sopt/logging.h"
#include "sopt/utilities.h"
#include "sopt/types.h"

#include <cstdlib>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>
14 
15 namespace sopt {
16 
18 class ORTsession {
19 
20  public:
21 
22  ORTsession() = delete;
23 
25  ORTsession(const std::string& filename, const std::string& runname = "soptONNXrt") {
26 
27  // Set-up ONNXrt session
28  _env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, runname.c_str());
29 
30  // Load the model
31  Ort::SessionOptions sessionopts;
32 
33  // Allow the number of threads used by the ONNX runtime to be set by an
34  // environment variable. If unset it will use all available threads by default
35  char* env_num_threads = std::getenv("ORT_NUM_THREADS");
36  if(env_num_threads) {
37  const int num_threads = std::stoi(env_num_threads);
38  sessionopts.SetIntraOpNumThreads(num_threads);
39  SOPT_INFO("ONNXRT using {} IntraOpThreads", num_threads);
40  }
41 
42  _session = std::make_unique<Ort::Session>(*_env, filename.c_str(), sessionopts);
43 
44  // Store model hyperparameters (input/output shape etc.)
45  _retrieveNetworkInfo();
46  }
47 
50  std::vector<float> compute(std::vector<float>& inputs, const std::vector<int64_t>& inDims) const {
51 
52  if (inputs.empty()) {
53  throw std::length_error("Input vector is empty!");
54  }
55  if (inDims.size() != _inShape) {
56  throw std::length_error("Input tensor has incorrect shape! Expected "+std::to_string(_inShape)+" dimensions.");
57  }
58 
59  // reshape flat input vector as tensor and run the model
60  auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
61  auto input_tensor = Ort::Value::CreateTensor<float>(memory_info, inputs.data(), inputs.size(),
62  inDims.data(), inDims.size());
63  auto output_tensors = _session->Run(Ort::RunOptions{nullptr}, _inNames.data(),
64  &input_tensor, _inNames.size(),
65  _outNames.data(), _outNames.size());
66 
67  // retrieve the ouput tensor and return flattened version
68  auto outputInfo = output_tensors[0].GetTensorTypeAndShapeInfo();
69  // Fix negative shape values - this appears to be an artefact of batch size issues
70  int64_t outLen = 1;
71  for (auto& dim : outputInfo.GetShape()) {
72  if (dim < 0) dim = abs(dim);
73  outLen *= dim;
74  }
75  if (outLen == 0) {
76  throw std::length_error("Invalid network structure: Output node with 0-length tensor encountered!");
77  }
78  float* floatarr = output_tensors.front().GetTensorMutableData<float>();
79  std::vector<float> outputs;
80  outputs.assign(floatarr, floatarr + outLen);
81  return outputs;
82  }
83 
85  template<typename T = t_real>
86  Vector<T> compute(const Vector<T>& input, const std::vector<int64_t>& inDims) const {
87  // ONNXrt requires floats as input
88  std::vector<float> flat_input(input.size());
89  for (size_t i = 0; i < input.size(); ++i) {
90  if constexpr(std::is_same<T, t_complex>::value) {
91  flat_input[i] = input[i].real();
92  } else {
93  flat_input[i] = input[i];
94  }
95  }
96  std::vector<float> flat_output = compute(flat_input, inDims);
97  Vector<T> rtn(flat_output.size());
98  for (size_t i = 0; i < flat_output.size(); ++i) {
99  if constexpr(std::is_same<T, t_complex>::value)
100  {
101  rtn[i] = t_complex(flat_output[i], 0);
102  }
103  else
104  {
105  rtn[i] = flat_output[i];
106  }
107  }
108  return rtn;
109  }
110 
112  template<typename T = t_real>
113  Image<T> compute(const Image<T>& input, std::vector<int64_t> inDims = {}) const {
114 
115  // require an output node of the form {1, nRows, nCols}
116  // in order to be able to map this onto a 2D tensor
117  if (inDims.size() && inDims.size() != _inShape) {
118  throw std::length_error("Input tensor has incorrect shape! Expected "+std::to_string(_inShape)+" dimensions.");
119  }
120 
121  // ONNXrt requires floats as input
122  const int nrows = input.rows();
123  const int ncols = input.cols();
124  std::vector<float> flat_input(nrows*ncols);
125  for (int i = 0; i < nrows; ++i) {
126  for (int j = 0; j < ncols; ++j) {
127  flat_input[j*ncols+i] = input(i,j);
128  }
129  }
130  if (inDims.empty()) {
131  while (inDims.size() < _inShape-2) inDims.push_back(1);
132  inDims.push_back(nrows);
133  inDims.push_back(ncols);
134  }
135  std::vector<float> flat_output = compute(flat_input, inDims);
136 
137  std::vector<T> tResults(flat_output.begin(), flat_output.end());
138  Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>> rtn(tResults.data(), nrows, ncols);
139  return rtn;
140  }
141 
143  const bool hasKey(const std::string& key) const {
144  Ort::AllocatorWithDefaultOptions allocator;
145  return (bool)_metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
146  }
147 
150  template <typename T>
151  const T retrieve(const std::string& key) const {
152  Ort::AllocatorWithDefaultOptions allocator;
153  Ort::AllocatedStringPtr res = _metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
154  if (!res) {
155  throw std::runtime_error("Key '"+key+"' not found in network metadata!");
156  }
157  if constexpr (std::is_same<T, std::string>::value) {
158  return res.get();
159  }
160  else {
161  return utilities::lexical_cast<T>(res.get());
162  }
163  }
164 
165 
168  template <typename T>
169  const T retrieve(const std::string& key, const T& defaultreturn) const {
170  try {
171  return retrieve<T>(key);
172  } catch (std::exception& e) {
173  return defaultreturn;
174  }
175  }
176 
177 
178  private:
179 
180  void _retrieveNetworkInfo() {
181 
182  Ort::AllocatorWithDefaultOptions allocator;
183 
184  // Retrieve network metadata
185  _metadata = std::make_unique<Ort::ModelMetadata>(_session->GetModelMetadata());
186 
187  // find out how many input nodes the model expects
188  const size_t num_input_nodes = _session->GetInputCount();
189  if (num_input_nodes == 0) {
190  throw std::length_error("Invalid network structure! Expected at least one input node.");
191  }
192  _inShape = _session->GetInputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape().size();
193  _inNames.reserve(num_input_nodes);
194  _inNamesPtr.reserve(num_input_nodes);
195  SOPT_DEBUG("ORT input nodes = {}", num_input_nodes);
196  for (size_t i = 0; i < num_input_nodes; ++i) {
197  // query input node names
198  auto input_name = _session->GetInputNameAllocated(i, allocator);
199  _inNames.push_back(input_name.get());
200  _inNamesPtr.push_back(std::move(input_name));
201  SOPT_DEBUG("ORT input node {} is called {}", i, _inNames[_inNames.size()-1]);
202  }
203 
204  // find out how many output nodes the model provides
205  const size_t num_output_nodes = _session->GetOutputCount();
206  if (num_output_nodes == 0) {
207  throw std::length_error("Invalid network structure! Expected at least one output node.");
208  }
209  _outNames.reserve(num_output_nodes);
210  _outNamesPtr.reserve(num_output_nodes);
211  SOPT_DEBUG("ORT output nodes = {}", num_output_nodes);
212  for (size_t i = 0; i < num_output_nodes; ++i) {
213  // query input node names
214  auto output_name = _session->GetOutputNameAllocated(i, allocator);
215  _outNames.push_back(output_name.get());
216  _outNamesPtr.push_back(std::move(output_name));
217  SOPT_DEBUG("ORT output node {} is called {}", i, _outNames[_outNames.size()-1]);
218 
219  }
220  }
221 
222  private:
223 
225  std::unique_ptr<Ort::Env> _env;
226 
228  std::unique_ptr<Ort::Session> _session;
229 
231  std::unique_ptr<Ort::ModelMetadata> _metadata;
232 
234  std::vector<Ort::AllocatedStringPtr> _inNamesPtr, _outNamesPtr;
235 
237  std::vector<const char*> _inNames, _outNames;
238 
240  size_t _inShape;
241 
242 };
243 
244 } // end of namespace sopt
245 
246 #endif
SOPT interface class to hold an ONNXrt session.
Definition: ort_session.h:18
ORTsession(const std::string &filename, const std::string &runname="soptONNXrt")
Constructor.
Definition: ort_session.h:25
ORTsession()=delete
const T retrieve(const std::string &key) const
Definition: ort_session.h:151
std::vector< float > compute(std::vector< float > &inputs, const std::vector< int64_t > &inDims) const
Definition: ort_session.h:50
Vector< T > compute(const Vector< T > &input, const std::vector< int64_t > &inDims) const
Variant of compute() using input/output Eigen vectors.
Definition: ort_session.h:86
const T retrieve(const std::string &key, const T &defaultreturn) const
Definition: ort_session.h:169
const bool hasKey(const std::string &key) const
Method to check if key exists in network metadata.
Definition: ort_session.h:143
Image< T > compute(const Image< T > &input, std::vector< int64_t > inDims={}) const
Variant of compute() using input/output Image.
Definition: ort_session.h:113
#define SOPT_INFO(...)
Verbose informational message about a normal condition.
Definition: logging.h:215
#define SOPT_DEBUG(...)
Output a debugging message.
Definition: logging.h:217
Eigen::Array< T, Eigen::Dynamic, Eigen::Dynamic > Image
A 2-dimensional array of elements of a given type.
Definition: types.h:39
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
std::complex< t_real > t_complex
Root of the type hierarchy for (real) complex numbers.
Definition: types.h:19