SOPT
Sparse OPTimisation
tf_non_diff_function.h
Go to the documentation of this file.
1 #ifndef SOPT_TF_NON_DIFF_FUNCTION_H
2 #define SOPT_TF_NON_DIFF_FUNCTION_H
3 // TODO: Clean up unnecessary includes
#include "sopt/config.h"
#include "sopt/exception.h"
#include "sopt/logging.h"
#include "sopt/ort_session.h"
#include "sopt/proximal.h"
#include "sopt/types.h"

#include <cmath>
#include <numeric>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
19 
20 namespace sopt::algorithm {
21 
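// Non-differentiable function g(x) backed by a trained neural network.
// g(x) itself is evaluated as the unit-weight l1 norm.
// The "proximal" operator is implemented by running the network (a denoiser).
// Despite the TensorFlow naming, the model is executed through a
// sopt::ORTsession (ONNX Runtime session).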
25 template <typename SCALAR>
26 class TFGProximal : public NonDifferentiableFunc<SCALAR> {
27 
28 public:
29  using FB = ForwardBackward<SCALAR>;
30  using Real = typename FB::Real;
31  using Scalar = typename FB::Scalar;
32  using t_Vector = typename FB::t_Vector;
33  using t_Proximal = typename FB::t_Proximal;
35 
// The constructor loads the saved model stored at `path` into an
// ONNX Runtime session (sopt::ORTsession).
38  TFGProximal(const std::string& path)
39  : model_(path),
42 
43  // Print log message with the correct norms
44  void log_message() const override {
45  SOPT_HIGH_LOG("Performing Forward Backward TensorFlow model");
46  }
47 
48  // Return the L1 norm of x with unit weights
49  Real function(t_Vector const &x) const override {
50  auto weights = Vector<Real>::Ones(x.size());
51  return sopt::l1_norm(static_cast<t_Vector>(x), weights);
52  }
53 
54  // Return g_proximal as a lambda function. Used in operator() in base class.
55  t_Proximal proximal_operator() const override {
56  return [this](t_Vector &out, Real gamma, t_Vector const &x) {
57  this -> call_model(out, x);
58  };
59  }
60 
62  // Psi is not implemented in this class, return an identity transform.
63  t_LinearTransform const &Psi() const override {
64  return Psi_;
65  }
66 
67 protected:
68 
69  t_LinearTransform Psi_;
70  sopt::ORTsession model_;
71 
72  void call_model(t_Vector &image_out, t_Vector const &image_in) const {
73 
74  // Set dimensions
75  int const image_size = image_in.size();
76  int nrows = sqrt(image_size), ncols = sqrt(image_size);
77 
78  // Scale to [0,1] in reals
79  Vector<float> real_image(image_size);
80  for (size_t i = 0; i < image_size; i++) {
81  if constexpr (std::is_same<SCALAR, t_complex>::value) {
82  real_image[i] = image_in[i].real();
83  } else {
84  real_image[i] = image_in[i];
85  }
86  }
87  auto min = *(std::min_element(real_image.begin(), real_image.end()));
88  auto max = *(std::max_element(real_image.begin(), real_image.end()));
89  for(size_t i = 0; i < image_size; i++)
90  {
91  real_image[i] = (real_image[i] - min)/max;
92  }
93 
94  // Call model
95  Vector<float> computed_image = model_.compute(real_image, {1,nrows,ncols,1});
96 
97  // rescale back
98  for(size_t i = 0; i < image_size; i++)
99  {
100  if constexpr (std::is_same<SCALAR, t_complex>::value) {
101  image_out[i] = t_complex(computed_image[i] * max + min, 0);
102  } else {
103  image_out[i] = computed_image[i] * max + min;
104  }
105  }
106  }
107 
108 };
109 } // namespace sopt::algorithm
110 #endif
sopt::Vector< Scalar > t_Vector
sopt::t_real Scalar
typename FB::t_Proximal t_Proximal
typename FB::t_Vector t_Vector
typename FB::t_LinearTransform t_LinearTransform
sopt::algorithm::ForwardBackward< SCALAR > FB
Sopt interface class to hold a ONNXrt session.
Definition: ort_session.h:18
std::vector< float > compute(std::vector< float > &inputs, const std::vector< int64_t > &inDims) const
Definition: ort_session.h:50
t_Proximal proximal_operator() const override
TFGProximal(const std::string &path)
t_LinearTransform const & Psi() const override
Analysis operator Ψ
void log_message() const override
sopt::LinearTransform< t_Vector > t_LinearTransform
#define SOPT_HIGH_LOG(...)
High priority message.
Definition: logging.h:223
LinearTransform< Vector< SCALAR > > linear_transform_identity()
Helper function to create a linear transform that's just the identity.
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
std::complex< t_real > t_complex
Root of the type hierarchy for (real) complex numbers.
Definition: types.h:19
real_type< typename T0::Scalar >::type l1_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L1 norm.
Definition: maths.h:116
sopt::Vector< Scalar > Vector
Definition: inpainting.cc:28