SOPT
Sparse OPTimisation
reweighted.cc
Go to the documentation of this file.
1 #include <algorithm>
2 #include <exception>
3 #include <functional>
4 #include <iostream>
5 #include <random>
6 #include <vector>
7 #include <ctime>
8 
9 #include "sopt/logging.h"
10 #include "sopt/maths.h"
11 #include "sopt/positive_quadrant.h"
13 #include "sopt/reweighted.h"
14 #include "sopt/sampling.h"
15 #include "sopt/sdmm.h"
16 #include "sopt/types.h"
17 #include "sopt/utilities.h"
18 #include "sopt/wavelets.h"
19 // This header is not part of the installed sopt interface
20 // It is only present in tests
21 #include "tools_for_tests/directories.h"
23 
24 // \min_{x} ||W_j\Psi^Tx||_1 \quad \mbox{s.t.} \quad ||y - Ax||_2 < \epsilon and x \geq 0
25 // with W_j = ||\Psi^Tx_{j-1}||_1
26 // By iterating this algorithm, we can approximate L0 from L1.
27 int main(int argc, char const **argv) {
28  // Some type aliases for simplicity
29  using Scalar = double;
30  // Column vector - linear algebra - A * x is a matrix-vector multiplication
31  // type expected by SDMM
33  // Matrix - linear algebra - A * x is a matrix-vector multiplication
34  // type expected by SDMM
36  // Image - 2D array - A * x is a coefficient-wise multiplication
37  // Type expected by wavelets and image write/read functions
38  using Image = sopt::Image<Scalar>;
39 
40  std::string const input = argc >= 2 ? argv[1] : "cameraman256";
41  std::string const output = argc == 3 ? argv[2] : "none";
42  if (argc > 3) {
43  std::cout << "Usage:\n"
44  "$ "
45  << argv[0]
46  << " [input [output]]\n\n"
47  "- input: path to the image to clean (or name of standard SOPT image)\n"
48  "- output: filename pattern for output image\n";
49  exit(0);
50  }
51  // Set up random numbers for C and C++
52  auto const seed = std::time(nullptr);
53  std::srand(static_cast<unsigned int>(seed));
54  std::mt19937 mersenne(std::time(nullptr));
55 
56  SOPT_HIGH_LOG("Read input file {}", input);
57  Image const image = sopt::tools::read_standard_tiff(input);
58 
59  SOPT_HIGH_LOG("Initializing sensing operator");
60  sopt::t_uint const nmeasure = 0.33 * image.size();
61  auto const sampling =
62  sopt::linear_transform<Scalar>(sopt::Sampling(image.size(), nmeasure, mersenne));
63 
64  SOPT_HIGH_LOG("Initializing wavelets");
65  auto const wavelet = sopt::wavelets::factory("DB4", 4);
66  auto const psi = sopt::linear_transform<Scalar>(wavelet, image.rows(), image.cols());
67 
68  SOPT_HIGH_LOG("Computing sdmm parameters");
69  Vector const y0 = sampling * Vector::Map(image.data(), image.size());
70  auto constexpr snr = 30.0;
71  auto const sigma = y0.stableNorm() / std::sqrt(y0.size()) * std::pow(10.0, -(snr / 20.0));
72  auto const epsilon = std::sqrt(nmeasure + 2 * std::sqrt(y0.size())) * sigma;
73 
74  SOPT_HIGH_LOG("Create dirty vector");
75  std::normal_distribution<> gaussian_dist(0, sigma);
76  Vector y(y0.size());
77  for (sopt::t_int i = 0; i < y0.size(); i++) y(i) = y0(i) + gaussian_dist(mersenne);
78  // Write dirty image to file
79  if (output != "none") {
80  Vector const dirty = sampling.adjoint() * y;
81  sopt::utilities::write_tiff(Matrix::Map(dirty.data(), image.rows(), image.cols()),
82  "dirty_" + output + ".tiff");
83  }
84 
85  SOPT_HIGH_LOG("Initializing convergence function");
86  auto relvar = sopt::RelativeVariation<Scalar>(5e-2);
87  auto convergence = [&y, &sampling, &psi, &relvar](sopt::Vector<Scalar> const &x) -> bool {
88  SOPT_MEDIUM_LOG("||x - y||_2: {}", (y - sampling * x).stableNorm());
89  SOPT_MEDIUM_LOG("||Psi^Tx||_1: {}", sopt::l1_norm(psi.adjoint() * x));
90  SOPT_MEDIUM_LOG("||abs(x) - x||_2: {}", (x.array().abs().matrix() - x).stableNorm());
91  return relvar(x);
92  };
93 
94  SOPT_HIGH_LOG("Creating SDMM Functor");
95  auto const sdmm =
97  .itermax(3000)
98  .gamma(0.1)
99  .conjugate_gradient(200, 1e-8)
100  .is_converged(convergence)
101  // Any number of (proximal g_i, L_i) pairs can be added
102  // ||Psi^dagger x||_1
103  .append(sopt::proximal::l1_norm<Scalar>, psi.adjoint(), psi)
104  // ||y - A x|| < epsilon
106  // x in positive quadrant
107  .append(sopt::proximal::positive_quadrant<Scalar>);
108 
109  SOPT_HIGH_LOG("Creating the reweighted algorithm");
110  // positive_quadrant projects the result of SDMM on the positive quadrant.
111  // This follows the reweighted algorithm in the original C implementation.
112  auto const posq = positive_quadrant(sdmm);
113  using t_PosQuadSDMM = std::remove_const<decltype(posq)>::type;
114  auto const min_delta = sigma * std::sqrt(y.size()) / std::sqrt(8 * image.size());
115  // Sets weight after each sdmm iteration.
116  // In practice, this means replacing the proximal of the l1 objective function.
117  auto set_weights = [](t_PosQuadSDMM &sdmm, Vector const &weights) {
118  sdmm.algorithm().proximals(0) = [weights](Vector &out, Scalar gamma, Vector const &x) {
119  out = sopt::soft_threshhold(x, gamma * weights);
120  };
121  };
122  auto call_PsiT = [&psi](t_PosQuadSDMM const &, Vector const &x) -> Vector {
123  return psi.adjoint() * x;
124  };
125  auto const reweighted = sopt::algorithm::reweighted(posq, set_weights, call_PsiT)
126  .itermax(5)
127  .min_delta(min_delta)
128  .is_converged(sopt::RelativeVariation<Scalar>(1e-3));
129 
130  SOPT_HIGH_LOG("Computing warm-start SDMM");
131  auto warm_start = sdmm(Vector::Zero(image.size()));
132  warm_start.x = sopt::positive_quadrant(warm_start.x);
133  SOPT_HIGH_LOG("SDMM returned {}", warm_start.good);
134 
135  SOPT_HIGH_LOG("Computing warm-start SDMM");
136  auto const result = reweighted(warm_start);
137 
138  // result.good should tell us whether the function converged
139  // it also contains result.niters - the number of iterations, and cg_diagnostic - the
140  // result from the last call to the conjugate gradient.
141  if (not result.good) throw std::runtime_error("Did not converge!");
142 
143  SOPT_HIGH_LOG("SOPT-SDMM converged in {} iterations", result.niters);
144  if (output != "none")
145  sopt::utilities::write_tiff(Matrix::Map(result.algo.x.data(), image.rows(), image.cols()),
146  output + ".tiff");
147 
148  return 0;
149 }
sopt::t_real Scalar
An operator that samples a set of measurements.
Definition: sampling.h:17
bool is_converged(t_Vector const &x) const
Forwards to convergence function parameter.
Definition: sdmm.h:172
SDMM< SCALAR > & conjugate_gradient(t_uint itermax, t_real tolerance)
Helps setup conjugate gradient.
Definition: sdmm.h:83
Proximal for indicator function of L2 ball.
Definition: proximal.h:182
std::unique_ptr< std::mt19937_64 > mersenne(new std::mt19937_64(0))
int main(int argc, char const **argv)
Definition: reweighted.cc:26
#define SOPT_HIGH_LOG(...)
High priority message.
Definition: logging.h:223
#define SOPT_MEDIUM_LOG(...)
Medium priority message.
Definition: logging.h:225
Reweighted< ALGORITHM > reweighted(ALGORITHM const &algo, typename Reweighted< ALGORITHM >::t_SetWeights const &set_weights, typename Reweighted< ALGORITHM >::t_Reweightee const &reweightee)
Factory function to create an l0-approximation by reweighting an l1 norm.
Definition: reweighted.h:238
Translation< FUNCTION, VECTOR > translate(FUNCTION const &func, VECTOR const &translation)
Translates given proximal by given vector.
Definition: proximal.h:362
Image read_standard_tiff(std::string const &name)
Reads tiff image from sopt data directory if it exists.
Definition: tiffwrappers.cc:9
void write_tiff(Image<> const &image, std::string const &filename)
Writes a tiff greyscale file.
Definition: utilities.cc:68
Wavelet factory(const std::string &name, t_uint nlevels)
Creates a wavelet transform object.
Definition: wavelets.cc:8
Eigen::CwiseUnaryOp< const details::ProjectPositiveQuadrant< typename T::Scalar >, const T > positive_quadrant(Eigen::DenseBase< T > const &input)
Expression to create projection onto positive quadrant.
Definition: maths.h:60
int t_int
Root of the type hierarchy for signed integers.
Definition: types.h:13
Vector< T > dirty(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image, RANDOM &mersenne)
Definition: inpainting.h:25
std::enable_if< std::is_arithmetic< SCALAR >::value or is_complex< SCALAR >::value, SCALAR >::type soft_threshhold(SCALAR const &x, typename real_type< SCALAR >::type const &threshhold)
abs(x) < threshhold ? 0: x - sgn(x) * threshhold
Definition: maths.h:29
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
Eigen::Array< T, Eigen::Dynamic, Eigen::Dynamic > Image
A 2-dimensional list of elements of given type.
Definition: types.h:39
real_type< T >::type epsilon(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:38
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
real_type< T >::type sigma(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:17
real_type< typename T0::Scalar >::type l1_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L1 norm.
Definition: maths.h:116
Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > Matrix
A matrix of a given type.
Definition: types.h:29
sopt::Vector< Scalar > Vector
Definition: inpainting.cc:28
sopt::Matrix< Scalar > Matrix
Definition: inpainting.cc:29
sopt::Image< Scalar > Image
Definition: inpainting.cc:30