SOPT
Sparse OPTimisation
inpainting_credible_interval.cc
Go to the documentation of this file.
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <exception>
#include <functional>
#include <iostream>
#include <memory>
#include <random>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>

#include "sopt/credible_region.h"
#include "sopt/imaging_forward_backward.h"
#include "sopt/l1_g_proximal.h"
#include "sopt/logging.h"
#include "sopt/maths.h"
#include "sopt/sampling.h"
#include "sopt/types.h"
#include "sopt/utilities.h"
#include "sopt/wavelets.h"
// This header is not part of the installed sopt interface
// It is only present in tests
#include "tools_for_tests/directories.h"
#include "tools_for_tests/tiffwrappers.h"
23 
24 // \min_{x} ||\Psi^Tx||_1 \quad \mbox{s.t.} \quad ||y - Ax||_2 < \epsilon and x \geq 0
25 int main(int argc, char const **argv) {
26  // Some type aliases for simplicity
27  using Scalar = sopt::t_real;
28  // Column vector - linear algebra - A * x is a matrix-vector multiplication
29  // type expected by Forward Backward
31  // Matrix - linear algebra - A * x is a matrix-vector multiplication
32  // type expected by Forward Backward
34  // Image - 2D array - A * x is a coefficient-wise multiplication
35  // Type expected by wavelets and image write/read functions
36  using Image = sopt::Image<Scalar>;
37 
38  std::string const input = argc >= 2 ? argv[1] : "cameraman256";
39  std::string const output = argc == 3 ? argv[2] : "none";
40  if (argc > 3) {
41  std::cout << "Usage:\n"
42  "$ "
43  << argv[0]
44  << " [input [output]]\n\n"
45  "- input: path to the image to clean (or name of standard SOPT image)\n"
46  "- output: filename pattern for output image\n";
47  exit(0);
48  }
49  // Set up random numbers for C and C++
50  auto const seed = std::time(nullptr);
51  std::srand(static_cast<unsigned int>(seed));
52  std::mt19937 mersenne(std::time(nullptr));
53 
54  // See set_level function for levels.
55  sopt::logging::set_level("debug");
56  SOPT_HIGH_LOG("Read input file {}", input);
57  const Image image = sopt::tools::read_standard_tiff(input) /
58  sopt::tools::read_standard_tiff(input).cwiseAbs().maxCoeff();
59  SOPT_HIGH_LOG("Image size: {} x {} = {}", image.cols(), image.rows(), image.size());
60 
61  SOPT_HIGH_LOG("Initializing sensing operator");
62  sopt::t_uint const nmeasure = std::floor(0.33 * image.size());
63  sopt::LinearTransform<Vector> const sampling =
64  sopt::linear_transform<Scalar>(sopt::Sampling(image.size(), nmeasure, mersenne));
65  SOPT_HIGH_LOG("Initializing wavelets");
66  auto const wavelet = sopt::wavelets::factory("DB8", 4);
67 
68  // sopt::wavelets::SARA const wavelet{std::make_tuple("db1", 4u), std::make_tuple("db2", 4u),
69  // std::make_tuple("db3", 4u), std::make_tuple("db4", 4u)};
70 
71  auto const psi = sopt::linear_transform<Scalar>(wavelet, image.rows(), image.cols());
72  SOPT_LOW_LOG("Wavelet coefficients: {}", (psi.adjoint() * image).size());
73 
74  SOPT_HIGH_LOG("Computing Forward Backward parameters");
75  Vector const y0 = sampling * Vector::Map(image.data(), image.size());
76  auto constexpr snr = 30.0;
77  auto const sigma = y0.stableNorm() / std::sqrt(y0.size()) * std::pow(10.0, -(snr / 20.0));
78  auto const epsilon = std::sqrt(nmeasure + 2 * std::sqrt(nmeasure)) * sigma;
79 
80  SOPT_HIGH_LOG("Create dirty vector");
81  std::normal_distribution<> gaussian_dist(0, sigma);
82  Vector y(y0.size());
83  for (sopt::t_int i = 0; i < y0.size(); i++) y(i) = y0(i) + gaussian_dist(mersenne);
84  // Write dirty imagte to file
85  if (output != "none") {
86  Vector const dirty = sampling.adjoint() * y;
87  sopt::utilities::write_tiff(Matrix::Map(dirty.data(), image.rows(), image.cols()),
88  "dirty_" + output + ".tiff");
89  }
90 
91  sopt::t_real constexpr regulariser_strength = 18;
92  sopt::t_real const beta = sigma * sigma;
93  SOPT_HIGH_LOG("Creating Foward Backward Functor");
95  .itermax(500)
96  .step_size(beta)
97  .sigma(sigma)
98  .regulariser_strength(regulariser_strength)
99  .relative_variation(5e-4)
100  .residual_tolerance(0)
101  .tight_frame(true)
102  .Phi(sampling);
103 
104  // Create a shared pointer to an instance of the L1GProximal class
105  // and set its properties
106  auto gp = std::make_shared<sopt::algorithm::L1GProximal<Scalar>>(false);
107  gp->l1_proximal_tolerance(1e-4)
108  .l1_proximal_nu(1)
109  .l1_proximal_itermax(50)
110  .l1_proximal_positivity_constraint(true)
111  .l1_proximal_real_constraint(true)
112  .Psi(psi);
113 
114  // Once the properties are set, inject it into the ImagingForwardBackward object
115  fb.g_function(gp);
116 
117  SOPT_HIGH_LOG("Starting Forward Backward");
118  // Alternatively, forward-backward can be called with a tuple (x, residual) as argument
119  // Here, we default to (Φ^Ty/ν, ΦΦ^Ty/ν - y)
120  auto const diagnostic = fb();
121  SOPT_HIGH_LOG("Forward backward returned {}", diagnostic.good);
122 
123  if (output != "none")
124  sopt::utilities::write_tiff(Matrix::Map(diagnostic.x.data(), image.rows(), image.cols()),
125  output + ".tiff");
126  // diagnostic should tell us the function converged
127  // it also contains diagnostic.niters - the number of iterations, and cg_diagnostic - the
128  // diagnostic from the last call to the conjugate gradient.
129  if (not diagnostic.good) throw std::runtime_error("Did not converge!");
130 
131  SOPT_HIGH_LOG("SOPT-Forward Backward converged in {} iterations", diagnostic.niters);
132 
133  constexpr sopt::t_real alpha = 0.99;
134  const sopt::t_uint grid_pixel_size = image.rows() / 16;
135  SOPT_HIGH_LOG("Finding credible interval");
136  const std::function<Scalar(Vector)> objective_function = [regulariser_strength, sigma, &y, &sampling,
137  &psi](const Vector &x) {
138  return sopt::l1_norm(psi.adjoint() * x) * regulariser_strength +
139  0.5 * std::pow(sopt::l2_norm(sampling * x - y), 2) / (sigma * sigma);
140  };
141 
142  sopt::Image<sopt::t_real> lower_error;
143  sopt::Image<sopt::t_real> upper_error;
144  sopt::Image<sopt::t_real> mean_solution;
145  std::tie(lower_error, mean_solution, upper_error) =
146  sopt::credible_region::credible_interval<sopt::Vector<sopt::t_real>, sopt::t_real>(
147  diagnostic.x, image.rows(), image.cols(), grid_pixel_size, objective_function, alpha);
148  if (output != "none") {
150  Matrix::Map(upper_error.data(), upper_error.rows(), upper_error.cols()),
151  output + "_upper_error.tiff");
153  Matrix::Map(mean_solution.data(), mean_solution.rows(), mean_solution.cols()),
154  output + "_mean_solution.tiff");
156  Matrix::Map(lower_error.data(), lower_error.rows(), lower_error.cols()),
157  output + "_lower_error.tiff");
158  }
159  return 0;
160 }
sopt::t_real Scalar
Joins together direct and indirect operators.
LinearTransform< VECTOR > adjoint() const
Indirect transform.
An operator that samples a set of measurements.
Definition: sampling.h:17
t_LinearTransform const & Phi() const
Measurement operator.
std::unique_ptr< std::mt19937_64 > mersenne(new std::mt19937_64(0))
int main(int argc, char const **argv)
#define SOPT_LOW_LOG(...)
Low priority message.
Definition: logging.h:227
#define SOPT_HIGH_LOG(...)
High priority message.
Definition: logging.h:223
void set_level(const std::string &level)
Method to set the logging level of the default Log object.
Definition: logging.h:154
Image read_standard_tiff(std::string const &name)
Reads tiff image from sopt data directory if it exists.
Definition: tiffwrappers.cc:9
void write_tiff(Image<> const &image, std::string const &filename)
Writes a tiff greyscale file.
Definition: utilities.cc:68
Wavelet factory(const std::string &name, t_uint nlevels)
Creates a wavelet transform object.
Definition: wavelets.cc:8
int t_int
Root of the type hierarchy for signed integers.
Definition: types.h:13
Vector< T > dirty(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image, RANDOM &mersenne)
Definition: inpainting.h:25
double t_real
Root of the type hierarchy for real numbers.
Definition: types.h:17
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
Eigen::Array< T, Eigen::Dynamic, Eigen::Dynamic > Image
A 2-dimensional list of elements of given type.
Definition: types.h:39
real_type< T >::type epsilon(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:38
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
real_type< T >::type sigma(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:17
real_type< typename T0::Scalar >::type l1_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L1 norm.
Definition: maths.h:116
Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > Matrix
A matrix of a given type.
Definition: types.h:29
real_type< typename T0::Scalar >::type l2_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L2 norm.
Definition: maths.h:140
sopt::Vector< Scalar > Vector
Definition: inpainting.cc:28
sopt::Matrix< Scalar > Matrix
Definition: inpainting.cc:29
sopt::Image< Scalar > Image
Definition: inpainting.cc:30