SOPT
Sparse OPTimisation
Typedefs | Functions
pd_inpainting.cc File Reference
#include <algorithm>
#include <exception>
#include <functional>
#include <iostream>
#include <random>
#include <vector>
#include <ctime>
#include <catch2/catch_all.hpp>
#include "sopt/imaging_primal_dual.h"
#include "sopt/l1_non_diff_function.h"
#include "sopt/logging.h"
#include "sopt/maths.h"
#include "sopt/relative_variation.h"
#include "sopt/sampling.h"
#include "sopt/types.h"
#include "sopt/utilities.h"
#include "sopt/wavelets.h"
#include "tools_for_tests/directories.h"
#include "tools_for_tests/tiffwrappers.h"
+ Include dependency graph for pd_inpainting.cc:

Go to the source code of this file.

Typedefs

using Scalar = double
 
using Vector = sopt::Vector< Scalar >
 
using Matrix = sopt::Matrix< Scalar >
 
using Image = sopt::Image< Scalar >
 

Functions

 TEST_CASE ("Primal Dual Inpainting")
 

Typedef Documentation

◆ Image

using Image = sopt::Image< Scalar >

Definition at line 30 of file pd_inpainting.cc.

◆ Matrix

using Matrix = sopt::Matrix< Scalar >

Definition at line 29 of file pd_inpainting.cc.

◆ Scalar

using Scalar = double

Definition at line 27 of file pd_inpainting.cc.

◆ Vector

using Vector = sopt::Vector< Scalar >

Definition at line 28 of file pd_inpainting.cc.

Function Documentation

◆ TEST_CASE()

TEST_CASE ( "Primal Dual Inpainting"  )

Definition at line 32 of file pd_inpainting.cc.

32  {
33  extern std::unique_ptr<std::mt19937_64> mersenne;
34  std::string const input = "cameraman256";
35 
36  Image const image = sopt::tools::read_standard_tiff(input);
37 
38  sopt::t_uint const nmeasure = std::floor(0.5 * image.size());
39  sopt::LinearTransform<Vector> const sampling =
40  sopt::linear_transform<Scalar>(sopt::Sampling(image.size(), nmeasure, *mersenne));
41 
42  auto const wavelet = sopt::wavelets::factory("DB8", 4);
43 
44  auto const psi = sopt::linear_transform<Scalar>(wavelet, image.rows(), image.cols());
45 
46  Vector const y0 = sampling * Vector::Map(image.data(), image.size());
47  auto constexpr snr = 30.0;
48  auto const sigma = y0.stableNorm() / std::sqrt(y0.size()) * std::pow(10.0, -(snr / 20.0));
49  auto const epsilon = std::sqrt(nmeasure + 2 * std::sqrt(y0.size())) * sigma;
50 
51  std::normal_distribution<> gaussian_dist(0, sigma);
52  Vector y(y0.size());
53  for (sopt::t_int i = 0; i < y0.size(); i++) y(i) = y0(i) + gaussian_dist(*mersenne);
54 
55  Eigen::VectorXd dirty_image = sampling.adjoint() * y;
56 
57  sopt::t_real const regulariser_strength = (psi.adjoint() * (sampling.adjoint() * y)).real().maxCoeff() * 1e-2;
58 
59  auto const pd = sopt::algorithm::ImagingPrimalDual<Scalar>(y)
60  .Phi(sampling)
61  .Psi(psi)
62  .itermax(500)
63  .tau(0.5)
64  .regulariser_strength(regulariser_strength)
65  .l2ball_proximal_epsilon(epsilon)
66  .relative_variation(5e-4)
67  .residual_convergence(epsilon * 1.001)
68  .positivity_constraint(true);
69 
70  auto const diagnostic = pd();
71 
72  CHECK(diagnostic.good);
73  CHECK(diagnostic.niters < 500);
74 
75  // compare input image to cleaned output image
76  // calculate mean squared error sum_i ( ( x_true(i) - x_est(i) ) **2 )
77  // check this is less than the number of pixels * 0.01
78  sopt::utilities::write_tiff(Matrix::Map(diagnostic.x.data(), image.rows(), image.cols()),
79  "pd_reconstruction.tiff");
80  sopt::utilities::write_tiff(Matrix::Map(dirty_image.data(), image.rows(), image.cols()),
81  "pd_dirty.tiff");
82  Eigen::Map<const Eigen::VectorXd> flat_image(image.data(), image.size());
83  auto mse = (flat_image - diagnostic.x).array().square().sum() / image.size();
84  CAPTURE(mse);
85  SOPT_HIGH_LOG("MSE: {}", mse);
86  CHECK(mse < 0.01);
87 }
Joins together direct and indirect operators.
LinearTransform< VECTOR > adjoint() const
Indirect transform.
An operator that samples a set of measurements.
Definition: sampling.h:17
ImagingPrimalDual &::type Psi(ARGS &&... args)
ImagingPrimalDual< Scalar > & residual_convergence(Real const &tolerance)
Helper function to set-up default residual convergence function.
ImagingPrimalDual &::type Phi(ARGS &&... args)
std::unique_ptr< std::mt19937_64 > mersenne(new std::mt19937_64(0))
#define SOPT_HIGH_LOG(...)
High priority message.
Definition: logging.h:223
Image read_standard_tiff(std::string const &name)
Reads tiff image from sopt data directory if it exists.
Definition: tiffwrappers.cc:9
void write_tiff(Image<> const &image, std::string const &filename)
Writes a tiff greyscale file.
Definition: utilities.cc:68
Wavelet factory(const std::string &name, t_uint nlevels)
Creates a wavelet transform object.
Definition: wavelets.cc:8
int t_int
Root of the type hierarchy for signed integers.
Definition: types.h:13
double t_real
Root of the type hierarchy for real numbers.
Definition: types.h:17
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
real_type< T >::type epsilon(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:38
real_type< T >::type sigma(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:17
sopt::Vector< Scalar > Vector
Definition: inpainting.cc:28
sopt::Image< Scalar > Image
Definition: inpainting.cc:30

References sopt::LinearTransform< VECTOR >::adjoint(), sopt::epsilon(), sopt::wavelets::factory(), mersenne(), sopt::algorithm::ImagingPrimalDual< SCALAR >::Phi(), sopt::algorithm::ImagingPrimalDual< SCALAR >::Psi(), sopt::tools::read_standard_tiff(), sopt::algorithm::ImagingPrimalDual< SCALAR >::residual_convergence(), sopt::sigma(), SOPT_HIGH_LOG, and sopt::utilities::write_tiff().