SOPT
Sparse OPTimisation
mpi_proximals.cc
Go to the documentation of this file.
1 #include <catch2/catch_all.hpp>
2 #include <numeric>
3 #include <random>
4 #include <utility>
5 
6 #include "sopt/proximal.h"
7 #include "sopt/types.h"
8 
9 TEST_CASE("Parallel Euclidian norm", "[proximal]") {
10  using namespace sopt;
11  auto const world = mpi::Communicator::World();
12  proximal::EuclidianNorm const eucl(world);
13 
14  Vector<t_real> out(5);
15  Vector<t_real> x(5);
16  x << 1, 2, 3, 4, 5;
17  eucl(out, std::sqrt(world.all_sum_all(x.squaredNorm())) * 1.001, x);
18  CHECK(out.isApprox(Vector<t_real>::Zero(x.size())));
19 
20  out = eucl(0.1, x);
21  CHECK(out.isApprox(x * (1e0 - 0.1 / std::sqrt(world.all_sum_all(x.squaredNorm())))));
22 }
23 
24 TEST_CASE("Parallel L2 ball", "[proximal]") {
25  using namespace sopt;
26  auto const world = mpi::Communicator::World();
27  proximal::L2Ball<t_real> ball(0.5, world);
28  Vector<t_real> out;
29  Vector<t_real> x(5);
30  x << 1, 2, 3, 4, 5;
31  x *= world.rank();
32 
33  out = ball(0, x);
34  CHECK(x.isApprox(out / 0.5 * std::sqrt(world.all_sum_all(x.squaredNorm()))));
35  ball.epsilon(std::sqrt(world.all_sum_all(x.squaredNorm())) * 1.001);
36  out = ball(0, x);
37  CHECK(x.isApprox(out));
38 }
39 
40 TEST_CASE("Parallel WeightedL2Ball", "[proximal]") {
41  using namespace sopt;
42  auto const world = mpi::Communicator::World();
43  Vector<t_real> const weights = 0.01 * Vector<t_real>::Random(5).array() + 1e0;
44  Vector<t_real> x(5);
45  x << 1, 2, 3, 4, 5;
46  x *= world.rank();
47  proximal::WeightedL2Ball<t_real> wball(0.5, weights, world);
48  proximal::L2Ball<t_real> const ball(0.5, world);
49 
50  Vector<t_real> const expected =
51  ball((x.array() * weights.array()).matrix()).array() / weights.array();
52  Vector<t_real> const actual = wball(x);
53  CHECK(actual.isApprox(expected));
54 
55  wball.epsilon(std::sqrt(world.all_sum_all((x.array() * weights.array()).matrix().squaredNorm())) *
56  1.001);
57  CHECK(x.isApprox(wball(x)));
58 }
Proximal of the Euclidean norm.
Definition: proximal.h:18
Proximal for the indicator function of the L2 ball.
Definition: proximal.h:182
Real epsilon() const
Size of the ball.
Definition: proximal.h:222
Real epsilon() const
Size of the ball.
Definition: proximal.h:312
TEST_CASE("Parallel Euclidian norm", "[proximal]")
Definition: mpi_proximals.cc:9
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24