24 extern std::unique_ptr<std::mt19937_64>
mersenne;
26 auto const world = mpi::Communicator::World();
28 auto const split_comm = world.split(
static_cast<t_int>(world.is_root()));
29 if (world.size() < 2)
return;
36 std::string
const input =
"cameraman256";
38 Image const image = world.is_root()
40 : world.broadcast<
Image>();
45 auto indices = world.is_root()
47 : world.broadcast<std::vector<t_uint>>();
48 if (split_comm.size() > 1) {
49 auto const copy = indices;
50 auto const N = indices.size() / split_comm.size() +
51 (split_comm.rank() < indices.size() % split_comm.size() ? 1 : 0);
52 auto const start = split_comm.rank() * (indices.size() / split_comm.size()) +
53 std::min(indices.size() % split_comm.size(), split_comm.rank());
55 std::copy(copy.begin() + start, copy.begin() +
N + start, indices.begin());
57 auto const sampling = sopt::linear_transform<Scalar>(
sopt::Sampling(image.size(), indices));
61 auto const psi = sopt::linear_transform<Scalar>(wavelet, image.rows(), image.cols());
64 Vector const y0 = sampling * Vector::Map(image.data(), image.size());
65 CHECK(y0.size() == indices.size());
66 auto constexpr snr = 30.0;
67 auto const sigma = y0.stableNorm() / std::sqrt(y0.size()) * std::pow(10.0, -(snr / 20.0));
68 auto const epsilon = world.broadcast(std::sqrt(nmeasure + 2 * std::sqrt(y0.size())) *
sigma);
71 std::normal_distribution<> gaussian_dist(0,
sigma);
72 Vector y = world.is_root() ? y0 : world.broadcast<
Vector>();
73 if (world.is_root()) {
77 if (split_comm.size() > 1) {
79 y.size() / split_comm.size() + (split_comm.rank() < y.size() % split_comm.size() ? 1 : 0);
80 auto const start = split_comm.rank() * (y.size() / split_comm.size()) +
81 std::min(y.size() % split_comm.size(), split_comm.rank());
82 y = y.segment(start,
N).eval();
84 CHECK(y.size() == indices.size());
90 .regulariser_strength(1e-1)
91 .relative_variation(5e-4)
94 .l1_proximal_tolerance(1e-2)
96 .l1_proximal_itermax(50)
97 .l1_proximal_positivity_constraint(
true)
98 .l1_proximal_real_constraint(
true)
99 .lagrange_update_scale(0.9)
102 [&sampling](
Vector &out,
Vector const &input) { out = sampling * input; }, sampling.sizes(),
103 [&sampling, split_comm](
Vector &out,
Vector const &input) {
104 out = sampling.adjoint() * input;
105 split_comm.all_sum_all(out);
107 sampling.adjoint().sizes());
108 padmm.Phi(parallel_sampling);
109 padmm.residual_convergence([&padmm, split_comm, world](
Vector const &,
110 Vector const &residual)
mutable ->
bool {
111 auto const residual_norm =
113 SOPT_LOW_LOG(
" - [PADMM] Residuals: {} <? {}", residual_norm, padmm.residual_tolerance());
114 CHECK(residual_norm == Approx(world.broadcast(residual_norm, world.root_id())));
115 return residual_norm < padmm.residual_tolerance();
118 "Objective function");
119 padmm.objective_convergence(
120 [&padmm, split_comm, conv, world](
Vector const &x,
Vector const &)
mutable ->
bool {
121 auto const result = conv(
122 sopt::mpi::l1_norm(padmm.Psi().adjoint() * x, padmm.l1_proximal_weights(), split_comm));
123 CHECK(result == (world.broadcast<
int>(result, world.root_id()) != 0));
126 padmm.is_converged([world](
Vector const &image,
Vector const &) ->
bool {
127 auto const from_root = world.broadcast(image);
128 CHECK(from_root.isApprox(image, 1e-12));
132 auto const diagnostic = padmm();
133 CHECK(diagnostic.good == (world.broadcast<
int>(diagnostic.good, world.root_id()) != 0));
134 CHECK(diagnostic.x.isApprox(world.broadcast(diagnostic.x)));
An operator that samples a set of measurements.
std::vector< t_uint > const & indices() const
Indices of sampled points.
proximal::WeightedL2Ball< Scalar > & l2ball_proximal()
Proximal of the L2 ball.
std::unique_ptr< std::mt19937_64 > mersenne(new std::mt19937_64(0))
#define SOPT_LOW_LOG(...)
Low priority message.
Wavelet factory(const std::string &name, t_uint nlevels)
Creates a wavelet transform object.
int t_int
Root of the type hierarchy for signed integers.
size_t t_uint
Root of the type hierarchy for unsigned integers.
Eigen::Array< T, Eigen::Dynamic, Eigen::Dynamic > Image
A 2-dimensional list of elements of given type.
real_type< T >::type epsilon(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
real_type< T >::type sigma(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
real_type< typename T0::Scalar >::type l1_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L1 norm.
real_type< typename T0::Scalar >::type l2_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L2 norm.