1 #ifndef SOPT_ORT_SESSION_H
2 #define SOPT_ORT_SESSION_H
4 #include "onnxruntime_cxx_api.h"
25 ORTsession(
const std::string& filename,
const std::string& runname =
"soptONNXrt") {
28 _env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, runname.c_str());
31 Ort::SessionOptions sessionopts;
35 char* env_num_threads = std::getenv(
"ORT_NUM_THREADS");
37 const int num_threads = std::stoi(env_num_threads);
38 sessionopts.SetIntraOpNumThreads(num_threads);
39 SOPT_INFO(
"ONNXRT using {} IntraOpThreads", num_threads);
42 _session = std::make_unique<Ort::Session>(*_env, filename.c_str(), sessionopts);
45 _retrieveNetworkInfo();
50 std::vector<float>
compute(std::vector<float>& inputs,
const std::vector<int64_t>& inDims)
const {
53 throw std::length_error(
"Input vector is empty!");
55 if (inDims.size() != _inShape) {
56 throw std::length_error(
"Input tensor has incorrect shape! Expected "+std::to_string(_inShape)+
" dimensions.");
60 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
61 auto input_tensor = Ort::Value::CreateTensor<float>(memory_info, inputs.data(), inputs.size(),
62 inDims.data(), inDims.size());
63 auto output_tensors = _session->Run(Ort::RunOptions{
nullptr}, _inNames.data(),
64 &input_tensor, _inNames.size(),
65 _outNames.data(), _outNames.size());
68 auto outputInfo = output_tensors[0].GetTensorTypeAndShapeInfo();
71 for (
auto& dim : outputInfo.GetShape()) {
72 if (dim < 0) dim = abs(dim);
76 throw std::length_error(
"Invalid network structure: Output node with 0-length tensor encountered!");
78 float* floatarr = output_tensors.front().GetTensorMutableData<
float>();
79 std::vector<float> outputs;
80 outputs.assign(floatarr, floatarr + outLen);
85 template<
typename T = t_real>
88 std::vector<float> flat_input(input.size());
89 for (
size_t i = 0; i < input.size(); ++i) {
90 if constexpr(std::is_same<T, t_complex>::value) {
91 flat_input[i] = input[i].real();
93 flat_input[i] = input[i];
96 std::vector<float> flat_output =
compute(flat_input, inDims);
98 for (
size_t i = 0; i < flat_output.size(); ++i) {
99 if constexpr(std::is_same<T, t_complex>::value)
105 rtn[i] = flat_output[i];
112 template<
typename T = t_real>
117 if (inDims.size() && inDims.size() != _inShape) {
118 throw std::length_error(
"Input tensor has incorrect shape! Expected "+std::to_string(_inShape)+
" dimensions.");
122 const int nrows = input.rows();
123 const int ncols = input.cols();
124 std::vector<float> flat_input(nrows*ncols);
125 for (
int i = 0; i < nrows; ++i) {
126 for (
int j = 0; j < ncols; ++j) {
127 flat_input[j*ncols+i] = input(i,j);
130 if (inDims.empty()) {
131 while (inDims.size() < _inShape-2) inDims.push_back(1);
132 inDims.push_back(nrows);
133 inDims.push_back(ncols);
135 std::vector<float> flat_output =
compute(flat_input, inDims);
137 std::vector<T> tResults(flat_output.begin(), flat_output.end());
138 Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>> rtn(tResults.data(), nrows, ncols);
143 const bool hasKey(
const std::string& key)
const {
144 Ort::AllocatorWithDefaultOptions allocator;
145 return (
bool)_metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
150 template <
typename T>
152 Ort::AllocatorWithDefaultOptions allocator;
153 Ort::AllocatedStringPtr res = _metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
155 throw std::runtime_error(
"Key '"+key+
"' not found in network metadata!");
157 if constexpr (std::is_same<T, std::string>::value) {
161 return utilities::lexical_cast<T>(res.get());
168 template <
typename T>
169 const T
retrieve(
const std::string& key,
const T& defaultreturn)
const {
171 return retrieve<T>(key);
172 }
catch (std::exception& e) {
173 return defaultreturn;
180 void _retrieveNetworkInfo() {
182 Ort::AllocatorWithDefaultOptions allocator;
185 _metadata = std::make_unique<Ort::ModelMetadata>(_session->GetModelMetadata());
188 const size_t num_input_nodes = _session->GetInputCount();
189 if (num_input_nodes == 0) {
190 throw std::length_error(
"Invalid network structure! Expected at least one input node.");
192 _inShape = _session->GetInputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape().size();
193 _inNames.reserve(num_input_nodes);
194 _inNamesPtr.reserve(num_input_nodes);
195 SOPT_DEBUG(
"ORT input nodes = {}", num_input_nodes);
196 for (
size_t i = 0; i < num_input_nodes; ++i) {
198 auto input_name = _session->GetInputNameAllocated(i, allocator);
199 _inNames.push_back(input_name.get());
200 _inNamesPtr.push_back(std::move(input_name));
201 SOPT_DEBUG(
"ORT input node {} is called {}", i, _inNames[_inNames.size()-1]);
205 const size_t num_output_nodes = _session->GetOutputCount();
206 if (num_output_nodes == 0) {
207 throw std::length_error(
"Invalid network structure! Expected at least one output node.");
209 _outNames.reserve(num_output_nodes);
210 _outNamesPtr.reserve(num_output_nodes);
211 SOPT_DEBUG(
"ORT output nodes = {}", num_output_nodes);
212 for (
size_t i = 0; i < num_output_nodes; ++i) {
214 auto output_name = _session->GetOutputNameAllocated(i, allocator);
215 _outNames.push_back(output_name.get());
216 _outNamesPtr.push_back(std::move(output_name));
217 SOPT_DEBUG(
"ORT output node {} is called {}", i, _outNames[_outNames.size()-1]);
225 std::unique_ptr<Ort::Env> _env;
228 std::unique_ptr<Ort::Session> _session;
231 std::unique_ptr<Ort::ModelMetadata> _metadata;
234 std::vector<Ort::AllocatedStringPtr> _inNamesPtr, _outNamesPtr;
237 std::vector<const char*> _inNames, _outNames;
Sopt interface class to hold a ONNXrt session.
ORTsession(const std::string &filename, const std::string &runname="soptONNXrt")
Constructor.
const T retrieve(const std::string &key) const
std::vector< float > compute(std::vector< float > &inputs, const std::vector< int64_t > &inDims) const
Vector< T > compute(const Vector< T > &input, const std::vector< int64_t > &inDims) const
Variant of compute() using input/output Eigen arrays.
const T retrieve(const std::string &key, const T &defaultreturn) const
const bool hasKey(const std::string &key) const
Method to check if key exists in network metadata.
Image< T > compute(const Image< T > &input, std::vector< int64_t > inDims={}) const
Variant of compute() using input/output Image.
#define SOPT_INFO(...)
\macro Verbose informational message about normal condition
#define SOPT_DEBUG(...)
\macro Output some debugging
Eigen::Array< T, Eigen::Dynamic, Eigen::Dynamic > Image
A 2-dimensional list of elements of given type.
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
std::complex< t_real > t_complex
Root of the type hierarchy for (real) complex numbers.