SOPT
Sparse OPTimisation
communicator.cc
Go to the documentation of this file.
1 #include <iostream>
2 #include <numeric>
3 #include "catch2/catch_all.hpp"
4 
5 #include "sopt/config.h"
7 
8 using namespace sopt;
9 
10 #ifdef SOPT_MPI
11 TEST_CASE("Creates an mpi communicator") {
12  int rank;
13  int size;
14  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
15  MPI_Comm_size(MPI_COMM_WORLD, &size);
16 
17  auto const world = mpi::Communicator::World();
18 
19  SECTION("General stuff") {
20  REQUIRE(*world == MPI_COMM_WORLD);
21  REQUIRE(static_cast<t_int>(world.rank()) == rank);
22  REQUIRE(static_cast<t_int>(world.size()) == size);
23 
24  mpi::Communicator const shallow = world;
25  CHECK(*shallow == *world);
26  }
27 
28  SECTION("Duplicate") {
29  mpi::Communicator const dup = world.duplicate();
30  CHECK(*dup != *world);
31  }
32 
33  SECTION("Scatter") {
34  if (world.rank() == world.root_id()) {
35  std::vector<t_int> scattered(world.size());
36  std::iota(scattered.begin(), scattered.end(), 2);
37  auto const result = world.scatter_one(scattered);
38  CHECK(result == world.rank() + 2);
39  } else {
40  auto const result = world.scatter_one<t_int>();
41  CHECK(result == world.rank() + 2);
42  }
43  }
44 
45  SECTION("ScatterV") {
46  std::vector<t_int> sizes(world.size());
47  std::vector<t_int> displs(world.size());
48  for (t_uint i(0); i < world.rank(); ++i) sizes[i] = world.rank() * 2 + i;
49  for (t_uint i(1); i < world.rank(); ++i) displs[i] = displs[i - 1] + sizes[i - 1];
50  Vector<t_int> const sendee =
51  Vector<t_int>::Random(std::accumulate(sizes.begin(), sizes.end(), 0));
52  auto const result = world.rank() == world.root_id()
53  ? world.scatterv(sendee, sizes)
54  : world.scatterv<t_int>(sizes[world.rank()]);
55  CHECK(result.isApprox(sendee.segment(displs[world.rank()], sizes[world.rank()])));
56  }
57 
58  SECTION("Gather a single item") {
59  if (world.rank() == world.root_id()) {
60  std::vector<t_int> scattered(world.size());
61  std::iota(scattered.begin(), scattered.end(), 2);
62  auto const result = world.scatter_one(scattered);
63  REQUIRE(result == world.rank() + 2);
64  auto const gathered = world.gather(result);
65  for (decltype(gathered)::size_type i = 0; i < gathered.size(); i++)
66  CHECK(gathered[i] == scattered[i]);
67  } else {
68  auto const result = world.scatter_one<t_int>();
69  REQUIRE(result == world.rank() + 2);
70  auto const gather = world.gather(result);
71  CHECK(gather.empty());
72  }
73  }
74 
75  SECTION("Gather an eigen vector") {
76  auto const size = [](t_int n) { return n * 2 + 10; };
77  auto const totsize = [](t_int n) { return std::max<t_int>(0, n * (9 + n)); };
78  Vector<t_int> const sendee = Vector<t_int>::Constant(size(world.rank()), world.rank());
79  std::vector<t_int> sizes(world.size());
80  int n(0);
81  std::generate(sizes.begin(), sizes.end(), [&n, &size]() { return size(n++); });
82 
83  auto const result = world.is_root() ? world.gather(sendee, sizes) : world.gather(sendee);
84  if (world.rank() == world.root_id()) {
85  CHECK(result.size() == totsize(world.size()));
86  for (decltype(world.size()) i(0); i < world.size(); ++i)
87  CHECK(result.segment(totsize(i), size(i)) == Vector<t_int>::Constant(size(i), i));
88  } else
89  CHECK(result.size() == 0);
90  }
91 
92  SECTION("Gather an std::set") {
93  std::set<t_int> const input{static_cast<t_int>(world.size()), static_cast<t_int>(world.rank())};
94  auto const result = world.gather(input, world.gather<t_int>(input.size()));
95  if (world.is_root()) {
96  CHECK(result.size() == world.size() + 1);
97  for (decltype(world.size()) i(0); i <= world.size(); ++i) CHECK(result.count(i) == 1);
98  } else
99  CHECK(result.empty());
100  }
101 
102  SECTION("Gather an std::vector") {
103  std::vector<t_int> const input{static_cast<t_int>(world.size()),
104  static_cast<t_int>(world.rank())};
105  auto const result = world.gather(input, world.gather<t_int>(input.size()));
106  if (world.is_root()) {
107  CHECK(result.size() == world.size() * 2);
108  for (decltype(world.size()) i(0); i < world.size(); ++i) {
109  CHECK(result[2 * i] == world.size());
110  CHECK(result[2 * i + 1] == i);
111  }
112  } else
113  CHECK(result.empty());
114  }
115 
116  SECTION("All sum all over image") {
117  Image<t_int> image(2, 2);
118  image.fill(world.rank());
119  world.all_sum_all(image);
120  CHECK((2 * image == world.size() * (world.size() - 1)).all());
121  }
122 
123  SECTION("Broadcast") {
124  SECTION("integer") {
125  auto const result = world.broadcast(world.root_id() == world.rank() ? 5 : 2, world.root_id());
126  CHECK(result == 5);
127  }
128 
129  SECTION("Eigen vector") {
130  Vector<t_int> y0(3);
131  y0 << 3, 2, 1;
132  auto const y =
133  world.rank() == world.root_id() ? world.broadcast(y0) : world.broadcast<Vector<t_int>>();
134  CHECK(y == y0);
135 
136  std::vector<t_int> v0 = {3, 2, 1};
137  auto const v = world.rank() == world.root_id() ? world.broadcast(v0)
138  : world.broadcast<std::vector<t_int>>();
139  CHECK(std::equal(v.begin(), v.end(), v0.begin()));
140  }
141 
142  SECTION("Eigen image - and check for correct size initialization") {
143  Image<t_int> image0(2, 2);
144  image0 << 3, 2, 1, 0;
145  auto const image = world.rank() == world.root_id() ? world.broadcast(image0)
146  : world.broadcast<Image<t_int>>();
147  CHECK(image.matrix() == image0.matrix());
148 
149  Image<t_int> const image1 = world.is_root() ? image0 : Image<t_int>();
150  CHECK(world.broadcast(image1).matrix() == image0.matrix());
151  }
152 
153  SECTION("std::string") {
154  const auto *const expected = "Hello World!";
155  std::string const input = world.is_root() ? expected : "";
156  CHECK(world.broadcast(input) == expected);
157  }
158  SECTION("all_to_allv") {
159  //
160  std::vector<t_int> sizes(world.size(), world.rank());
161  const Vector<t_int> sendee =
162  Vector<t_int>::Constant(std::accumulate(sizes.begin(), sizes.end(), 0), world.rank());
163  const Vector<t_int> output = world.all_to_allv(sendee, sizes);
164  t_int sum = 0;
165  for (t_int i = 0; i < world.size() - 1; i++) {
166  const Vector<t_int> expected = Vector<t_int>::Constant(i + 1, i + 1);
167  CAPTURE(sum);
168  CAPTURE(i);
169  CAPTURE(output.segment(sum, i + 1));
170  CAPTURE(expected);
171  REQUIRE(output.segment(sum, i + 1).isApprox(expected));
172  sum += i + 1;
173  }
174  }
175  }
176 }
177 #endif
constexpr auto n
Definition: wavelets.cc:56
TEST_CASE("Bisection x^3")
int t_int
Root of the type hierarchy for signed integers.
Definition: types.h:13
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
Eigen::Array< T, Eigen::Dynamic, Eigen::Dynamic > Image
A 2-dimensional list of elements of given type.
Definition: types.h:39
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24