Sopt interface class to hold an ONNXrt session.
More...
#include <ort_session.h>
|
| ORTsession ()=delete |
|
| ORTsession (const std::string &filename, const std::string &runname="soptONNXrt") |
| Constructor. More...
|
|
std::vector< float > | compute (std::vector< float > &inputs, const std::vector< int64_t > &inDims) const |
|
template<typename T = t_real> |
Vector< T > | compute (const Vector< T > &input, const std::vector< int64_t > &inDims) const |
| Variant of compute() using input/output Eigen arrays. More...
|
|
template<typename T = t_real> |
Image< T > | compute (const Image< T > &input, std::vector< int64_t > inDims={}) const |
| Variant of compute() using input/output Image. More...
|
|
const bool | hasKey (const std::string &key) const |
| Method to check if key exists in network metadata. More...
|
|
template<typename T > |
const T | retrieve (const std::string &key) const |
|
template<typename T > |
const T | retrieve (const std::string &key, const T &defaultreturn) const |
|
Sopt interface class to hold an ONNXrt session.
Definition at line 18 of file ort_session.h.
◆ ORTsession() [1/2]
sopt::ORTsession::ORTsession |
( |
| ) |
|
|
delete |
◆ ORTsession() [2/2]
sopt::ORTsession::ORTsession |
( |
const std::string & |
filename, |
|
|
const std::string & |
runname = "soptONNXrt" |
|
) |
| |
|
inline |
Constructor.
Definition at line 25 of file ort_session.h.
28 _env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, runname.c_str());
31 Ort::SessionOptions sessionopts;
35 char* env_num_threads = std::getenv(
"ORT_NUM_THREADS");
37 const int num_threads = std::stoi(env_num_threads);
38 sessionopts.SetIntraOpNumThreads(num_threads);
39 SOPT_INFO(
"ONNXRT using {} IntraOpThreads", num_threads);
42 _session = std::make_unique<Ort::Session>(*_env, filename.c_str(), sessionopts);
45 _retrieveNetworkInfo();
#define SOPT_INFO(...)
Macro: verbose informational message about a normal condition.
References SOPT_INFO.
◆ compute() [1/3]
template<typename T = t_real>
Image<T> sopt::ORTsession::compute |
( |
const Image< T > & |
input, |
|
|
std::vector< int64_t > |
inDims = {} |
|
) |
| const |
|
inline |
Variant of compute() using input/output Image.
Definition at line 113 of file ort_session.h.
117 if (inDims.size() && inDims.size() != _inShape) {
118 throw std::length_error(
"Input tensor has incorrect shape! Expected "+std::to_string(_inShape)+
" dimensions.");
122 const int nrows = input.rows();
123 const int ncols = input.cols();
124 std::vector<float> flat_input(nrows*ncols);
125 for (
int i = 0; i < nrows; ++i) {
126 for (
int j = 0; j < ncols; ++j) {
127 flat_input[j*ncols+i] = input(i,j);
130 if (inDims.empty()) {
131 while (inDims.size() < _inShape-2) inDims.push_back(1);
132 inDims.push_back(nrows);
133 inDims.push_back(ncols);
135 std::vector<float> flat_output =
compute(flat_input, inDims);
137 std::vector<T> tResults(flat_output.begin(), flat_output.end());
138 Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>> rtn(tResults.data(), nrows, ncols);
std::vector< float > compute(std::vector< float > &inputs, const std::vector< int64_t > &inDims) const
◆ compute() [2/3]
template<typename T = t_real>
Vector<T> sopt::ORTsession::compute |
( |
const Vector< T > & |
input, |
|
|
const std::vector< int64_t > & |
inDims |
|
) |
| const |
|
inline |
Variant of compute() using input/output Eigen arrays.
Definition at line 86 of file ort_session.h.
88 std::vector<float> flat_input(input.size());
89 for (
size_t i = 0; i < input.size(); ++i) {
90 if constexpr(std::is_same<T, t_complex>::value) {
91 flat_input[i] = input[i].real();
93 flat_input[i] = input[i];
96 std::vector<float> flat_output =
compute(flat_input, inDims);
98 for (
size_t i = 0; i < flat_output.size(); ++i) {
99 if constexpr(std::is_same<T, t_complex>::value)
105 rtn[i] = flat_output[i];
std::complex< t_real > t_complex
Root of the type hierarchy for (real) complex numbers.
sopt::Vector< Scalar > Vector
References compute().
◆ compute() [3/3]
std::vector<float> sopt::ORTsession::compute |
( |
std::vector< float > & |
inputs, |
|
|
const std::vector< int64_t > & |
inDims |
|
) |
| const |
|
inline |
Evaluates the network in the forward direction using a flattened tensor as input and returns the first output tensor as a flattened vector.
Definition at line 50 of file ort_session.h.
53 throw std::length_error(
"Input vector is empty!");
55 if (inDims.size() != _inShape) {
56 throw std::length_error(
"Input tensor has incorrect shape! Expected "+std::to_string(_inShape)+
" dimensions.");
60 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
61 auto input_tensor = Ort::Value::CreateTensor<float>(memory_info, inputs.data(), inputs.size(),
62 inDims.data(), inDims.size());
63 auto output_tensors = _session->Run(Ort::RunOptions{
nullptr}, _inNames.data(),
64 &input_tensor, _inNames.size(),
65 _outNames.data(), _outNames.size());
68 auto outputInfo = output_tensors[0].GetTensorTypeAndShapeInfo();
71 for (
auto& dim : outputInfo.GetShape()) {
72 if (dim < 0) dim = abs(dim);
76 throw std::length_error(
"Invalid network structure: Output node with 0-length tensor encountered!");
78 float* floatarr = output_tensors.front().GetTensorMutableData<
float>();
79 std::vector<float> outputs;
80 outputs.assign(floatarr, floatarr + outLen);
Referenced by compute(), sopt::ONNXDifferentiableFunc< SCALAR >::function(), sopt::ONNXDifferentiableFunc< SCALAR >::gradient(), and TEST_CASE().
◆ hasKey()
const bool sopt::ORTsession::hasKey |
( |
const std::string & |
key | ) |
const |
|
inline |
Method to check if key exists in network metadata.
Definition at line 143 of file ort_session.h.
144 Ort::AllocatorWithDefaultOptions allocator;
145 return (
bool)_metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
◆ retrieve() [1/2]
template<typename T >
const T sopt::ORTsession::retrieve |
( |
const std::string & |
key | ) |
const |
|
inline |
Method to retrieve the value associated with key from the network metadata and return it as type T. Throws std::runtime_error if the key is not found.
Definition at line 151 of file ort_session.h.
152 Ort::AllocatorWithDefaultOptions allocator;
153 Ort::AllocatedStringPtr res = _metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
155 throw std::runtime_error(
"Key '"+key+
"' not found in network metadata!");
157 if constexpr (std::is_same<T, std::string>::value) {
161 return utilities::lexical_cast<T>(res.get());
Referenced by TEST_CASE().
◆ retrieve() [2/2]
template<typename T >
const T sopt::ORTsession::retrieve |
( |
const std::string & |
key, |
|
|
const T & |
defaultreturn |
|
) |
| const |
|
inline |
Variation of the retrieve method that falls back to defaultreturn if the key cannot be found.
Definition at line 169 of file ort_session.h.
171 return retrieve<T>(key);
172 }
catch (std::exception& e) {
173 return defaultreturn;
The documentation for this class was generated from the following file: