SOPT
Sparse OPTimisation
Functions
serial_vs_parallel_padmm.cc File Reference
#include <catch2/catch_all.hpp>
#include <numeric>
#include <random>
#include <utility>
#include "sopt/imaging_padmm.h"
#include "sopt/logging.h"
#include "sopt/maths.h"
#include "sopt/mpi/communicator.h"
#include "sopt/mpi/utilities.h"
#include "sopt/relative_variation.h"
#include "sopt/sampling.h"
#include "sopt/types.h"
#include "sopt/utilities.h"
#include "sopt/wavelets.h"
#include "tools_for_tests/directories.h"
#include "tools_for_tests/tiffwrappers.h"
Include dependency graph for serial_vs_parallel_padmm.cc (the graph image is not reproduced in this text rendering).

Go to the source code of this file.

Functions

 TEST_CASE ("Parallel vs serial inpainting")
 

Function Documentation

◆ TEST_CASE()

TEST_CASE ( "Parallel vs serial inpainting"  )

Definition at line 23 of file serial_vs_parallel_padmm.cc.

23  {
24  extern std::unique_ptr<std::mt19937_64> mersenne; // RNG defined in another translation unit; below, only the world root draws from it
25  using namespace sopt;
26  auto const world = mpi::Communicator::World();
27  // split into serial and parallel: the root rank gets a communicator of its own (serial run), all other ranks share another one (parallel run)
28  auto const split_comm = world.split(static_cast<t_int>(world.is_root()));
29  if (world.size() < 2) return; // nothing to compare with a single process (split() above is still executed by every rank)
30 
31  // Some type aliases for simplicity
32  using Scalar = double;
34  using Image = sopt::Image<Scalar>; // NOTE(review): listing line 33 is missing from this rendering; it presumably declares the `Vector` alias used below -- TODO confirm against the source file
35 
36  std::string const input = "cameraman256";
37  // Read input file on the world root and broadcast it to every rank
38  Image const image = world.is_root()
39  ? world.broadcast(sopt::tools::read_standard_tiff(input))
40  : world.broadcast<Image>();
41 
42  // Initializing sensing operator
43  // The sampling indices are drawn by the world root, broadcast to all ranks, then split into contiguous chunks across the procs in split_comm
44  sopt::t_uint const nmeasure = 0.33 * image.size(); // sample 33% of the pixels (fraction truncated on conversion to t_uint)
45  auto indices = world.is_root()
46  ? world.broadcast(sopt::Sampling(image.size(), nmeasure, *mersenne).indices())
47  : world.broadcast<std::vector<t_uint>>();
48  if (split_comm.size() > 1) { // only the multi-rank (parallel) communicator partitions the data
49  auto const copy = indices;
50  auto const N = indices.size() / split_comm.size() + // chunk length: the first (size % nprocs) ranks take one extra element
51  (split_comm.rank() < indices.size() % split_comm.size() ? 1 : 0);
52  auto const start = split_comm.rank() * (indices.size() / split_comm.size()) + // offset of this rank's chunk
53  std::min(indices.size() % split_comm.size(), split_comm.rank());
54  indices.resize(N);
55  std::copy(copy.begin() + start, copy.begin() + N + start, indices.begin());
56  }
57  auto const sampling = sopt::linear_transform<Scalar>(sopt::Sampling(image.size(), indices)); // local operator over this rank's indices only
58 
59  // Initializing wavelets
60  auto const wavelet = sopt::wavelets::factory("DB4", 4);
61  auto const psi = sopt::linear_transform<Scalar>(wavelet, image.rows(), image.cols());
62 
63  // Computing proximal-ADMM parameters
64  Vector const y0 = sampling * Vector::Map(image.data(), image.size()); // noise-free measurements for this rank's indices
65  CHECK(y0.size() == indices.size());
66  auto constexpr snr = 30.0; // target signal-to-noise ratio, used in dB form (10^(-snr/20)) on the next line
67  auto const sigma = y0.stableNorm() / std::sqrt(y0.size()) * std::pow(10.0, -(snr / 20.0)); // standard deviation of the added Gaussian noise
68  auto const epsilon = world.broadcast(std::sqrt(nmeasure + 2 * std::sqrt(y0.size())) * sigma); // NOTE(review): not referenced in the visible listing; presumably passed to the solver on missing line 92 -- TODO confirm
69 
70  // Create dirty vector: the root adds Gaussian noise to its measurements and broadcasts the noisy vector to every rank
71  std::normal_distribution<> gaussian_dist(0, sigma);
72  Vector y = world.is_root() ? y0 : world.broadcast<Vector>();
73  if (world.is_root()) {
74  for (sopt::t_int i = 0; i < y0.size(); i++) y(i) += gaussian_dist(*mersenne);
75  world.broadcast(y); // pairs with the broadcast<Vector>() receive executed by non-root ranks above
76  }
77  if (split_comm.size() > 1) { // partition y exactly like the indices so measurements and operator stay aligned
78  auto const N =
79  y.size() / split_comm.size() + (split_comm.rank() < y.size() % split_comm.size() ? 1 : 0);
80  auto const start = split_comm.rank() * (y.size() / split_comm.size()) +
81  std::min(y.size() % split_comm.size(), split_comm.rank());
82  y = y.segment(start, N).eval();
83  }
84  CHECK(y.size() == indices.size());
85 
86  // Creating proximal-ADMM Functor
87  auto padmm = // NOTE(review): listing line 88 (the expression the setters below are chained onto) is missing from this rendering -- TODO confirm against the source file
89  .itermax(4)
90  .regulariser_strength(1e-1)
91  .relative_variation(5e-4) // NOTE(review): listing line 92 is also missing from this rendering
93  .tight_frame(false)
94  .l1_proximal_tolerance(1e-2)
95  .l1_proximal_nu(1)
96  .l1_proximal_itermax(50)
97  .l1_proximal_positivity_constraint(true)
98  .l1_proximal_real_constraint(true)
99  .lagrange_update_scale(0.9)
100  .Psi(psi);
101  LinearTransform<Vector> const parallel_sampling( // distributed Phi: forward op is purely local, adjoint sums the partial images over split_comm
102  [&sampling](Vector &out, Vector const &input) { out = sampling * input; }, sampling.sizes(),
103  [&sampling, split_comm](Vector &out, Vector const &input) {
104  out = sampling.adjoint() * input;
105  split_comm.all_sum_all(out); // every rank ends up with the full back-projected image
106  },
107  sampling.adjoint().sizes());
108  padmm.Phi(parallel_sampling);
109  padmm.residual_convergence([&padmm, split_comm, world](Vector const &, // residual norm is reduced over split_comm, then compared with root's value
110  Vector const &residual) mutable -> bool {
111  auto const residual_norm =
112  sopt::mpi::l2_norm(residual, padmm.l2ball_proximal_weights(), split_comm);
113  SOPT_LOW_LOG(" - [PADMM] Residuals: {} <? {}", residual_norm, padmm.residual_tolerance());
114  CHECK(residual_norm == Approx(world.broadcast(residual_norm, world.root_id()))); // serial (root) and parallel runs must agree
115  return residual_norm < padmm.residual_tolerance();
116  });
117  sopt::ScalarRelativeVariation<Scalar> conv(padmm.relative_variation(), padmm.relative_variation(),
118  "Objective function");
119  padmm.objective_convergence( // L1 norm of Psi^T x reduced over split_comm; the decision must match root's
120  [&padmm, split_comm, conv, world](Vector const &x, Vector const &) mutable -> bool {
121  auto const result = conv(
122  sopt::mpi::l1_norm(padmm.Psi().adjoint() * x, padmm.l1_proximal_weights(), split_comm));
123  CHECK(result == (world.broadcast<int>(result, world.root_id()) != 0));
124  return result;
125  });
126  padmm.is_converged([world](Vector const &image, Vector const &) -> bool { // compares every rank's iterate with root's; always returns true
127  auto const from_root = world.broadcast(image);
128  CHECK(from_root.isApprox(image, 1e-12));
129  return true;
130  });
131 
132  auto const diagnostic = padmm();
133  CHECK(diagnostic.good == (world.broadcast<int>(diagnostic.good, world.root_id()) != 0));
134  CHECK(diagnostic.x.isApprox(world.broadcast(diagnostic.x))); // final image must match the root's (serial) result
135 }
constexpr auto N
Definition: wavelets.cc:57
sopt::t_real Scalar
Joins together direct and indirect operators.
An operator that samples a set of measurements.
Definition: sampling.h:17
std::vector< t_uint > const & indices() const
Indices of sampled points.
Definition: sampling.h:51
proximal::WeightedL2Ball< Scalar > & l2ball_proximal()
Proximal of the L2 ball.
std::unique_ptr< std::mt19937_64 > mersenne(new std::mt19937_64(0))
#define SOPT_LOW_LOG(...)
Low priority message.
Definition: logging.h:227
Image read_standard_tiff(std::string const &name)
Reads tiff image from sopt data directory if it exists.
Definition: tiffwrappers.cc:9
Wavelet factory(const std::string &name, t_uint nlevels)
Creates a wavelet transform object.
Definition: wavelets.cc:8
int t_int
Root of the type hierarchy for signed integers.
Definition: types.h:13
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
Eigen::Array< T, Eigen::Dynamic, Eigen::Dynamic > Image
A 2-dimensional list of elements of given type.
Definition: types.h:39
real_type< T >::type epsilon(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:38
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
real_type< T >::type sigma(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:17
real_type< typename T0::Scalar >::type l1_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L1 norm.
Definition: maths.h:116
real_type< typename T0::Scalar >::type l2_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L2 norm.
Definition: maths.h:140

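References sopt::epsilon(), sopt::wavelets::factory(), sopt::Sampling::indices(), sopt::l1_norm(), sopt::l2_norm(), sopt::algorithm::ImagingProximalADMM< SCALAR >::l2ball_proximal(), mersenne(), N, sopt::tools::read_standard_tiff(), sopt::sigma(), and SOPT_LOW_LOG.

Note: the documentation generator resolves some of these cross-references by name alone. In this test, `N`, `sigma` and `epsilon` are local variables. They are not the `wavelets.cc` constant or the `inpainting.h` functions of the same names. Likewise, the test calls the three-argument `sopt::mpi::l1_norm` / `sopt::mpi::l2_norm` rather than `sopt::l1_norm` / `sopt::l2_norm`.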