SOPT
Sparse OPTimisation
Typedefs | Functions
stochastic_update.cc File Reference
#include <algorithm>
#include <exception>
#include <functional>
#include <iostream>
#include <random>
#include <vector>
#include <ctime>
#include <catch2/catch_all.hpp>
#include "sopt/imaging_forward_backward.h"
#include "sopt/l1_non_diff_function.h"
#include "sopt/logging.h"
#include "sopt/maths.h"
#include "sopt/relative_variation.h"
#include "sopt/sampling.h"
#include "sopt/types.h"
#include "sopt/utilities.h"
#include "sopt/wavelets.h"
#include "sopt/gradient_utils.h"
#include "tools_for_tests/directories.h"
#include "tools_for_tests/tiffwrappers.h"
+ Include dependency graph for stochastic_update.cc:

Go to the source code of this file.

Typedefs

using Scalar = double
 
using Vector = sopt::Vector< Scalar >
 
using Matrix = sopt::Matrix< Scalar >
 
using Image = sopt::Image< Scalar >
 

Functions

 TEST_CASE ("Inpainting")
 

Typedef Documentation

◆ Image

Definition at line 31 of file stochastic_update.cc.

◆ Matrix

Definition at line 30 of file stochastic_update.cc.

◆ Scalar

using Scalar = double

Definition at line 28 of file stochastic_update.cc.

◆ Vector

Definition at line 29 of file stochastic_update.cc.

Function Documentation

◆ TEST_CASE()

TEST_CASE ( "Inpainting"  )

Definition at line 33 of file stochastic_update.cc.

33  {  // Inpainting test: runs ImagingForwardBackward with a randomised measurement updater, then checks convergence and reconstruction error.
34  extern std::unique_ptr<std::mt19937_64> mersenne;  // RNG declared extern; its definition lives outside this function
35  std::string const input = "cameraman256";
36 
37  Image const image = sopt::tools::read_standard_tiff(input);  // ground-truth image the test tries to recover
38 
39  auto const wavelet = sopt::wavelets::factory("DB8", 4);  // wavelet named "DB8" with nlevels = 4
40 
41  auto const psi = sopt::linear_transform<Scalar>(wavelet, image.rows(), image.cols());
42  size_t nmeasure = static_cast<size_t>(image.size() * 0.5);  // keep half as many measurements as there are pixels
43 
44  double constexpr snr = 30.0;  // used below to derive sigma
45  std::shared_ptr<sopt::LinearTransform<Vector>> Phi =
46  std::make_shared<sopt::LinearTransform<Vector>>(
47  sopt::linear_transform<Scalar>(sopt::Sampling(image.size(), nmeasure, *mersenne)));
48  Vector y = (*Phi) * Vector::Map(image.data(), image.size());  // noiseless measurements; only used to compute sigma below
49 
50  auto sigma = y.stableNorm() / std::sqrt(y.size()) * std::pow(10.0, -(snr / 20.0));  // ||y|| / sqrt(len(y)) * 10^(-snr/20)
51  sopt::t_real constexpr regulariser_strength = 18;
52  sopt::t_real const beta = sigma*sigma*0.5;  // sigma^2 / 2, passed as the step size below
53 
54  // Define a stochastic target/operator updater!
55  std::unique_ptr<std::mt19937_64> *m = &mersenne;  // pointer to the RNG handle so the lambda can capture it by value
56  std::function<std::shared_ptr<sopt::IterationState<Vector>>()> random_updater = [&image, m, sigma, nmeasure](){  // image is captured by reference, so it must outlive random_updater
57  double constexpr snr = 30.0;  // NOTE(review): not referenced anywhere inside this lambda
58  std::shared_ptr<sopt::LinearTransform<Vector>> Phi =
59  std::make_shared<sopt::LinearTransform<Vector>>(sopt::linear_transform<Scalar>(sopt::Sampling(image.size(), nmeasure, **m)));  // fresh Sampling built on every call
60  Vector y = (*Phi) * Vector::Map(image.data(), image.size());
61 
62  std::normal_distribution<> gaussian_dist(0, sigma);  // zero-mean noise with standard deviation sigma
63  for (sopt::t_int i = 0; i < y.size(); i++) y(i) = y(i) + gaussian_dist(*mersenne);  // NOTE(review): uses the global mersenne directly while the Sampling above goes through m; both name the same generator
64 
65  return std::make_shared<sopt::IterationState<Vector>>(y, Phi);  // noisy target paired with the operator that produced it
66  };
67 
68  auto fb = sopt::algorithm::ImagingForwardBackward<Scalar>(random_updater);
69  fb.itermax(1000)
70  .step_size(beta) // stepsize
71  .sigma(sigma) // sigma
72  .regulariser_strength(regulariser_strength) // regularisation parameter
73  .relative_variation(1e-3)
74  .residual_tolerance(0)
75  .tight_frame(true);
76 
77  // Create a shared pointer to an instance of the L1GProximal class
78  // and set its properties
79  auto gp = std::make_shared<sopt::algorithm::L1GProximal<Scalar>>(false);
80  gp->l1_proximal_tolerance(1e-4)
81  .l1_proximal_nu(1)
82  .l1_proximal_itermax(50)
83  .l1_proximal_positivity_constraint(true)
84  .l1_proximal_real_constraint(true)
85  .Psi(psi);
86 
87  // Once the properties are set, inject it into the ImagingForwardBackward object
88  fb.g_function(gp);
89 
90  auto const diagnostic = fb();  // run the solver
91 
92  CHECK(diagnostic.good);
93  CHECK(diagnostic.niters < 500);  // stricter than the itermax(1000) cap set above
94 
95  // compare input image to cleaned output image
96  // calculate mean squared error sum_i ( ( x_true(i) - x_est(i) ) **2 )
97  // check this is less than the number of pixels * 0.01
98 
99  Eigen::Map<const Eigen::VectorXd> flat_image(image.data(), image.size());  // view the image buffer as a flat vector without copying
100  auto mse = (flat_image - diagnostic.x).array().square().sum() / image.size();  // squared-error sum divided by the pixel count
101  CAPTURE(mse);
102  CHECK(mse < 0.01);
103 }
An operator that samples a set of measurements.
Definition: sampling.h:17
std::unique_ptr< std::mt19937_64 > mersenne(new std::mt19937_64(0))
Image read_standard_tiff(std::string const &name)
Reads tiff image from sopt data directory if it exists.
Definition: tiffwrappers.cc:9
Wavelet factory(const std::string &name, t_uint nlevels)
Creates a wavelet transform object.
Definition: wavelets.cc:8
int t_int
Root of the type hierarchy for signed integers.
Definition: types.h:13
double t_real
Root of the type hierarchy for real numbers.
Definition: types.h:17
real_type< T >::type sigma(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:17
sopt::Vector< Scalar > Vector
Definition: inpainting.cc:28
sopt::Image< Scalar > Image
Definition: inpainting.cc:30

References sopt::wavelets::factory(), mersenne(), sopt::tools::read_standard_tiff(), and sopt::sigma().