1 #ifndef SOPT_TF_NON_DIFF_FUNCTION_H
2 #define SOPT_TF_NON_DIFF_FUNCTION_H
4 #include "sopt/config.h"
25 template <
typename SCALAR>
29 using FB = ForwardBackward<SCALAR>;
30 using Real =
typename FB::Real;
45 SOPT_HIGH_LOG(
"Performing Forward Backward TensorFlow model");
57 this -> call_model(out, x);
75 int const image_size = image_in.size();
76 int nrows = sqrt(image_size), ncols = sqrt(image_size);
80 for (
size_t i = 0; i < image_size; i++) {
81 if constexpr (std::is_same<SCALAR, t_complex>::value) {
82 real_image[i] = image_in[i].real();
84 real_image[i] = image_in[i];
87 auto min = *(std::min_element(real_image.begin(), real_image.end()));
88 auto max = *(std::max_element(real_image.begin(), real_image.end()));
89 for(
size_t i = 0; i < image_size; i++)
91 real_image[i] = (real_image[i] - min)/max;
98 for(
size_t i = 0; i < image_size; i++)
100 if constexpr (std::is_same<SCALAR, t_complex>::value) {
101 image_out[i] =
t_complex(computed_image[i] * max + min, 0);
103 image_out[i] = computed_image[i] * max + min;
sopt::Vector< Scalar > t_Vector
typename FB::t_Proximal t_Proximal
typename FB::t_Vector t_Vector
typename FB::t_LinearTransform t_LinearTransform
sopt::algorithm::ForwardBackward< SCALAR > FB
Sopt interface class to hold a ONNXrt session.
std::vector< float > compute(std::vector< float > &inputs, const std::vector< int64_t > &inDims) const
typename FB::Scalar Scalar
t_Proximal proximal_operator() const override
TFGProximal(const std::string &path)
t_LinearTransform const & Psi() const override
Analysis operator Ψ
void log_message() const override
sopt::LinearTransform< t_Vector > t_LinearTransform
#define SOPT_HIGH_LOG(...)
High priority message.
LinearTransform< Vector< SCALAR > > linear_transform_identity()
Helper function to create a linear transform that's just the identity.
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
std::complex< t_real > t_complex
Root of the type hierarchy for (real) complex numbers.
real_type< typename T0::Scalar >::type l1_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L1 norm.
sopt::Vector< Scalar > Vector