SOPT
Sparse OPTimisation
Typedefs | Functions
onnx_inpainting.cc File Reference
#include <algorithm>
#include <exception>
#include <functional>
#include <iostream>
#include <random>
#include <vector>
#include <ctime>
#include <catch2/catch_all.hpp>
#include "sopt/imaging_forward_backward.h"
#include "sopt/real_indicator.h"
#include "sopt/logging.h"
#include "sopt/maths.h"
#include "sopt/relative_variation.h"
#include "sopt/sampling.h"
#include "sopt/types.h"
#include "sopt/utilities.h"
#include "sopt/ort_session.h"
#include "sopt/onnx_differentiable_func.h"
#include "tools_for_tests/directories.h"
#include "tools_for_tests/tiffwrappers.h"
Include dependency graph for onnx_inpainting.cc:

Go to the source code of this file.

Typedefs

using Scalar = double
 
using Vector = sopt::Vector< Scalar >
 
using Matrix = sopt::Matrix< Scalar >
 
using Image = sopt::Image< Scalar >
 
using LinearTransform = sopt::LinearTransform< Vector >
 

Functions

 TEST_CASE ("Inpainting")
 

Typedef Documentation

◆ Image

using Image = sopt::Image< Scalar >

Definition at line 32 of file onnx_inpainting.cc.

◆ LinearTransform

using LinearTransform = sopt::LinearTransform< Vector >

Definition at line 33 of file onnx_inpainting.cc.

◆ Matrix

using Matrix = sopt::Matrix< Scalar >

Definition at line 31 of file onnx_inpainting.cc.

◆ Scalar

using Scalar = double

Definition at line 29 of file onnx_inpainting.cc.

◆ Vector

using Vector = sopt::Vector< Scalar >

Definition at line 30 of file onnx_inpainting.cc.

Function Documentation

◆ TEST_CASE()

TEST_CASE ( "Inpainting"  )

Definition at line 35 of file onnx_inpainting.cc.

35  {
    // Inpainting test, as shown by this listing:
    //   1. read the "cameraman256" TIFF image,
    //   2. sample half of its pixels and add Gaussian noise to the samples,
    //   3. build a differentiable function from two ONNX model files,
    //   4. run the forward-backward solver `fb` with a real-indicator g-function,
    //   5. check that the mean squared error against the input image is below 0.01.
36 
37  // Tuning constants forwarded to ONNXDifferentiableFunc below. The source gives no
    // derivation for these values (original note: "black magic?").
38  double lambda = 5e4;
39  double mu = 20;
40 
    // Random number generator declared here as extern, i.e. defined in another translation unit.
41  extern std::unique_ptr<std::mt19937_64> mersenne;
42  std::string const input = "cameraman256";
43 
44 
45  Image const image = sopt::tools::read_standard_tiff(input);
46 
    // Number of measurements: half of the pixel count, rounded down.
47  sopt::t_uint nmeasure = std::floor(0.5 * image.size());
48  LinearTransform const sampling =
49  sopt::linear_transform<Scalar>(sopt::Sampling(image.size(), nmeasure, *mersenne));
50 
    // Noiseless measurements: the sampling operator applied to the flattened image.
51  Vector const y0 = sampling * Vector::Map(image.data(), image.size());
52  auto constexpr snr = 30.0;
    // Noise standard deviation: RMS of y0 scaled by 10^(-snr/20).
53  auto const sigma = y0.stableNorm() / std::sqrt(y0.size()) * std::pow(10.0, -(snr / 20.0));
    // NOTE(review): epsilon is not referenced again anywhere in the visible listing --
    // confirm whether it is still needed.
54  auto const epsilon = std::sqrt(nmeasure + 2 * std::sqrt(y0.size())) * sigma;
55 
56  // set the model function and gradient
    // Both ONNX files are looked up in sopt::tools::models_directory().
57  std::string const prior_path = std::string(sopt::tools::models_directory() + "/example_cost_dynamic_CRR_sigma_5_t_5.onnx");
58  std::string const prior_gradient_path = std::string(sopt::tools::models_directory() + "/example_grad_dynamic_CRR_sigma_5_t_5.onnx");
59  std::shared_ptr<sopt::ONNXDifferentiableFunc<Scalar>> diff_function = std::make_shared<sopt::ONNXDifferentiableFunc<Scalar>>(prior_path, prior_gradient_path, sigma, mu, lambda, sampling);
60 
    // Noisy measurements: y = y0 + zero-mean Gaussian noise with standard deviation sigma.
61  std::normal_distribution<> gaussian_dist(0, sigma);
62  Vector y(y0.size());
63  for (sopt::t_int i = 0; i < y0.size(); i++) y(i) = y0(i) + gaussian_dist(*mersenne);
64 
    // Back-project the noisy measurements with the adjoint operator, then log the
    // sum of absolute values of the resulting residual (sampling * dirty_image - y).
65  Eigen::VectorXd dirty_image = sampling.adjoint() * y;
66  Eigen::VectorXd init_res = ((sampling * dirty_image) - y);
67  auto init_res_norm = init_res.array().abs().sum();
68  SOPT_HIGH_LOG("Initial residual norm: {}", init_res_norm);
69 
70  sopt::t_real constexpr regulariser_strength = 18;
    // Used as the solver step size below.
71  sopt::t_real const beta = sigma * sigma * 0.5;
72 
73  // Arbitrary (absolute) tolerance level to produce a reasonable image which converges
    // NOTE(review): listing line 74 is missing from this extract. It presumably begins the
    // declaration of `scalvar` (captured by the lambda on line 76) whose argument list
    // ends on line 75 -- confirm against the source file.
75  "Convergence function");
    // Convergence predicate handed to the solver: delegates to scalvar on x only;
    // the residual argument is ignored.
76  std::function<bool(const Vector &, const Vector &)> convergence = [&scalvar](const Vector &x, const Vector &residual)
77  {
78  return scalvar(x.array());
79  };
80 
    // NOTE(review): listing line 81 is missing from this extract. `fb` is presumably
    // constructed there (an ImagingForwardBackward, per the comment on line 98) --
    // confirm against the source file.
82  fb.itermax(500)
83  .step_size(beta) // stepsize
84  .sigma(sigma) // sigma
85  .regulariser_strength(regulariser_strength) // regularisation parameter
86  .relative_variation(1e-3)
87  .residual_tolerance(0)
88  .tight_frame(true)
89  .Phi(sampling)
90  .is_converged(convergence);
91 
92  //fb.set_f_gradient(f_gradient);
    // Use the ONNX-backed differentiable function as the solver's f-function.
93  fb.f_function(diff_function);
94 
95  // Create a shared pointer to the real indicator (non differentiable) function
96  auto non_diff_func = std::make_shared<sopt::algorithm::RealIndicator<Scalar>>();
97 
98  // Inject it into the ImagingForwardBackward object
99  fb.g_function(non_diff_func);
100 
    // Run the solver; the result exposes the reconstructed image as diagnostic.x.
101  auto const diagnostic = fb();
102 
    // The convergence checks below are disabled; only the MSE check at the end is active.
103  // CHECK(diagnostic.good);
104  // CHECK(diagnostic.niters < 500);
105 
106  // compare input image to cleaned output image
107  // calculate mean squared error sum_i ( ( x_true(i) - x_est(i) ) **2 )
108  // check this is less than the number of pixels * 0.01
109  //sopt::utilities::write_tiff(Matrix::Map(diagnostic.x.data(), image.rows(), image.cols()),
110  // "onnx_reconstruction.tiff");
111  //sopt::utilities::write_tiff(Matrix::Map(dirty_image.data(), image.rows(), image.cols()),
112  // "dirty.tiff");
113  Eigen::Map<const Eigen::VectorXd> flat_image(image.data(), image.size());
    // Sum of squared differences divided by the pixel count, so the 0.01 threshold
    // applies to the per-pixel mean.
114  auto mse = (flat_image - diagnostic.x).array().square().sum() / image.size();
115  CAPTURE(mse);
116  SOPT_HIGH_LOG("MSE: {}", mse);
117  CHECK(mse < 0.01);
118 }
Joins together direct and indirect operators.
LinearTransform< VECTOR > adjoint() const
Indirect transform.
An operator that samples a set of measurements.
Definition: sampling.h:17
std::unique_ptr< std::mt19937_64 > mersenne(new std::mt19937_64(0))
#define SOPT_HIGH_LOG(...)
High priority message.
Definition: logging.h:223
Image read_standard_tiff(std::string const &name)
Reads tiff image from sopt data directory if it exists.
Definition: tiffwrappers.cc:9
std::string models_directory()
Machine-learning models.
int t_int
Root of the type hierarchy for signed integers.
Definition: types.h:13
double t_real
Root of the type hierarchy for real numbers.
Definition: types.h:17
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
real_type< T >::type epsilon(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:38
real_type< T >::type sigma(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:17
sopt::Vector< Scalar > Vector
Definition: inpainting.cc:28
sopt::Image< Scalar > Image
Definition: inpainting.cc:30

References sopt::LinearTransform< VECTOR >::adjoint(), sopt::epsilon(), mersenne(), sopt::tools::models_directory(), sopt::tools::read_standard_tiff(), sopt::sigma(), and SOPT_HIGH_LOG.