5 #include "catch2/catch_all.hpp"
6 #include <sopt/mpi/communicator.h>
// Fragment of a DistributeSparseVector scatter/gather test.
// NOTE(review): the numeric prefixes on these lines and the gaps between them
// (no declaration of `output`, no TEST_CASE/SECTION headers, unbalanced braces)
// look like extraction residue; this fragment does not compile as it stands.
10 auto const world = sopt::mpi::Communicator::World();
// Reference grid with max(2 * world.size(), world.size() + 2) random entries,
// broadcast so that every rank holds the same values.
11 Vector<t_int>
const grid = world.broadcast<Vector<t_int>>(
12 Vector<t_int>::Random(std::max(world.size() * 2, world.size() + 2)));
// Each rank asks for two grid entries: index world.rank() (distinct per rank)
// and index world.size() + 1 (requested by every rank). Both are in range
// because the grid has at least world.size() + 2 entries.
13 std::vector<t_int>
const indices = {
static_cast<t_int
>(world.rank()),
14 static_cast<t_int
>(world.size() + 1)};
15 DistributeSparseVector distributor(indices, grid.size(), world);
// scatter: each rank receives the grid values at its requested indices, in
// the order the indices were listed.
19 distributor.scatter(grid, output);
20 REQUIRE(output.size() == 2);
21 CHECK(output(0) == grid(world.rank()));
22 CHECK(output(1) == grid(world.size() + 1));
// gather: rank r contributes {r, world.size() + r} for its two indices.
26 Vector<t_int> local(2);
27 local << world.rank(), world.size() + world.rank();
29 if (world.is_root()) {
30 distributor.gather(local, output);
31 CHECK(output.size() == grid.size());
// Index i < world.size() was requested only by rank i, which sent value i.
32 for (decltype(world.size()) i(0); i < world.size(); ++i)
CHECK(output(i) == i);
// Index world.size() was requested by no rank and is expected to stay 0.
33 CHECK(output(world.size()) == 0);
// Index world.size() + 1 was requested by every rank; the expected value is
// the sum of (world.size() + r) over all ranks r, i.e. contributions to a
// shared index are accumulated.
34 CHECK(output(world.size() + 1) ==
35 world.size() * world.size() + (world.size() * (world.size() - 1)) / 2);
// Guards the tail length below against going negative; every entry after
// index world.size() + 1 was requested by no rank and must be 0.
36 CHECK(grid.size() - 2 >= world.size());
37 CHECK((output.tail(grid.size() - world.size() - 2).array() == 0).all());
// Presumably the non-root branch (its `} else {` line is not visible here):
// the gather overload without an output vector.
39 distributor.gather(local);
// Fragment of an AllToAllSparseVector recv_grid/send_grid test.
// NOTE(review): numeric line prefixes and missing lines (no declaration of
// `output`, no TEST_CASE/SECTION headers, unbalanced braces) look like
// extraction residue; this fragment does not compile as it stands.
44 auto const world = sopt::mpi::Communicator::World();
// Reference grid with max(2 * world.size(), world.size() + 2) random entries,
// broadcast so that every rank holds the same values.
45 Vector<t_int>
const grid = world.broadcast<Vector<t_int>>(
46 Vector<t_int>::Random(std::max(world.size() * 2, world.size() + 2)));
// Requested indices: world.rank() and world.size() + 1, both below grid.size().
47 std::vector<t_int>
const indices = {
static_cast<t_int
>(world.rank()),
48 static_cast<t_int
>(world.size() + 1)};
// NOTE(review): the third argument, grid.size() * world.rank(), looks like the
// start of this rank's slice of the global index space. If so, every index
// requested above lies in rank 0's slice. Confirm against the constructor.
49 AllToAllSparseVector<t_int> distributor(indices, grid.size(), grid.size() * world.rank(), world);
// recv_grid: output holds the grid values at the requested indices, in order.
53 distributor.recv_grid(grid, output);
54 CAPTURE(world.rank());
55 REQUIRE(output.size() == 2);
56 CHECK(output(0) == grid(world.rank()));
57 CHECK(output(1) == grid(world.size() + 1));
// send_grid: rank r contributes {r, world.size() + r}; every rank gets back a
// vector of length grid.size().
61 Vector<t_int> local(2);
62 local << world.rank(), world.size() + world.rank();
64 distributor.send_grid(local, output);
65 CHECK(output.size() == grid.size());
66 if (world.is_root()) {
// Index i < world.size() was requested only by rank i, which sent value i.
67 for (decltype(world.size()) i(0); i < world.size(); ++i)
CHECK(output(i) == i);
// Index world.size() was requested by no rank and is expected to stay 0.
68 CHECK(output(world.size()) == 0);
// Every rank requested world.size() + 1, so the expected value is the sum of
// (world.size() + r) over all ranks r.
69 CHECK(output(world.size() + 1) ==
70 world.size() * world.size() + (world.size() * (world.size() - 1)) / 2);
// Guards the tail length against going negative; the tail must be all zeros.
71 CHECK(grid.size() - 2 >= world.size());
72 CHECK((output.tail(grid.size() - world.size() - 2).array() == 0).all());
// Presumably the non-root branch (its `} else {` line is not visible here):
// no requested index is expected to land in this rank's slice, so
// everything it receives is 0.
74 for (t_int i = 0; i < output.size(); i++)
CHECK(output(i) == 0);
// Fragment of an AllToAllSparseVector test whose requested indices reach into
// slices other than rank 0's.
// NOTE(review): numeric line prefixes and missing lines (TEST_CASE header, the
// CHECK_THROWS opener, `output` declaration in the recv part, `else` lines,
// closing braces) look like extraction residue; this fragment does not compile
// as it stands.
80 auto const world = sopt::mpi::Communicator::World();
81 CAPTURE(world.rank());
// Reference grid with max(2 * world.size(), world.size() + 2) random entries,
// broadcast so that every rank holds the same values.
82 Vector<t_int>
const grid = world.broadcast<Vector<t_int>>(
83 Vector<t_int>::Random(std::max(world.size() * 2, world.size() + 2)));
84 SECTION(
"Check throw when index is not ordered by node") {
// Indices listed out of order. The first is at or beyond
// grid.size() * (world.rank() + 1); the second, world.rank(), is below
// grid.size(). Constructing the distributor is expected to reject this. The
// call is presumably wrapped in a CHECK_THROWS whose opening line is not
// visible here.
85 std::vector<t_int>
const indices = {
86 static_cast<t_int
>(grid.size() * (world.rank() + 1) + world.size() + 1),
87 static_cast<t_int
>(world.rank())};
89 AllToAllSparseVector<t_int>(indices, grid.size(), grid.size() * world.rank(), world));
// Ordered indices: world.rank() (below grid.size()) and
// grid.size() * world.rank() + world.size() + 1, which is offset
// world.size() + 1 past this rank's constructor offset.
91 std::vector<t_int>
const indices = {
92 static_cast<t_int
>(world.rank()),
93 static_cast<t_int
>(grid.size() * world.rank() + world.size() + 1)};
94 AllToAllSparseVector<t_int> distributor(indices, grid.size(), grid.size() * world.rank(), world);
// recv_grid: every rank passes the same grid. The second CHECK expects
// grid(world.size() + 1), i.e. the value is read at position
// index - grid.size() * world.rank() of the owning rank's grid.
98 distributor.recv_grid(grid, output);
99 REQUIRE(output.size() == 2);
100 CHECK(output(0) == grid(world.rank()));
101 CHECK(output(1) == grid(world.size() + 1));
// send_grid: every rank sends 1 for both of its indices.
105 Vector<t_int> local = Vector<t_int>::Ones(2);
106 Vector<t_int> output;
107 distributor.send_grid(local, output);
108 CHECK(output.size() == grid.size());
109 CAPTURE(world.rank());
111 if (world.is_root()) {
// Index i < world.size() was requested only by rank i.
112 for (decltype(world.size()) i(0); i < world.size(); ++i)
CHECK(output(i) == 1);
// Index world.size() was requested by no rank.
113 CHECK(output(world.size()) == 0);
// Index world.size() + 1 is rank 0's own second index, requested only by the
// root.
114 CHECK(output(world.size() + 1) == 1);
115 CHECK(grid.size() - 2 >= world.size());
116 CHECK((output.tail(grid.size() - world.size() - 2).array() == 0).all());
// Presumably the non-root branch (its `} else {` line is not visible here):
// only position world.size() + 1 (this rank's own second index) receives 1.
118 for (t_int i = 0; i < output.size(); i++) {
119 if (i != world.size() + 1)
120 CHECK(output(i) == 0);
// Presumably the `else` of the if above (that line is not visible here).
122 CHECK(output(i) == 1);
// Fuzz all_to_all_recv_sizes<t_int> with 10 sorted random global indices, for
// several node counts and per-node sizes N = imsize * imsize.
// NOTE(review): numeric line prefixes and missing lines look like extraction
// residue; this fragment does not compile as it stands. Missing lines include
// the `else` between the CHECK_THROWS and the unconditional call below, and
// the closing braces.
129 for (t_int nodes : {1, 2, 5, 10, 20, 50, 100, 1000}) {
130 for (t_int imsize : {128, 1024, 2048, 4096, 8192, 16384, 32768}) {
// Largest imsize is 32768, so N = 2^30 still fits in std::int32_t.
131 const std::int32_t N = imsize * imsize;
135 std::random_device rnd_device;
137 std::mt19937_64 mersenne_engine(rnd_device());
// NOTE(review): `nodes * N` is evaluated in the common type of t_int and
// std::int32_t, *before* the overflow guard below runs.
// - If t_int is 32-bit, this overflows (UB) whenever N * nodes exceeds
//   t_int's max, which is exactly the case the guard tests for. It may also
//   hand the distribution a negative upper bound.
// - Suggested fix: compute the bound as std::int64_t and clamp it to
//   std::numeric_limits<t_int>::max().
// - The bound is inclusive, so nodes * N itself can be drawn. If valid global
//   indices are [0, nodes * N), that value is one past the end; confirm.
138 std::uniform_int_distribution<t_int> dist(0, nodes * N);
139 auto gen = [&dist, &mersenne_engine]() {
return dist(mersenne_engine); };
140 std::vector<t_int> local_indices(10);
141 std::generate(local_indices.begin(), local_indices.end(), gen);
// The explicit comparator gives the same ordering as the default operator<.
142 std::sort(local_indices.begin(), local_indices.end(),
143 [](t_int a, t_int b) { return (a < b); });
144 CAPTURE(local_indices);
// When N * nodes exceeds t_int's range, the call is expected to throw.
146 if (
static_cast<std::int64_t
>(N) *
static_cast<std::int64_t
>(nodes) >
147 std::numeric_limits<t_int>::max())
148 CHECK_THROWS(all_to_all_recv_sizes<t_int>(local_indices, nodes, N));
// Presumably the non-throwing branch (its `else {` line is not visible here).
// Every per-node count is non-negative, and the counts sum to the number of
// local indices.
150 std::vector<t_int> recv = all_to_all_recv_sizes<t_int>(local_indices, nodes, N);
151 for (
const auto& a : recv)
CHECK(a >= 0);
// NOTE(review): compares a size_t with an int accumulator (-Wsign-compare).
// Casting one side, e.g. static_cast<t_int>(local_indices.size()), avoids it.
153 CHECK(local_indices.size() == std::accumulate(recv.begin(), recv.end(), 0));
// 64-bit variant: fuzz all_to_all_recv_sizes<std::int64_t> with 10 sorted
// random global indices drawn from [0, nodes * N] (inclusive upper bound).
// NOTE(review): numeric line prefixes and missing lines (closing braces) look
// like extraction residue; this fragment does not compile as it stands.
159 for (std::int64_t nodes : {1, 2, 5, 10, 20, 50, 100, 1000}) {
160 for (std::int64_t imsize : {128, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072}) {
// Up to 131072^2 = 2^34, computed in std::int64_t so it cannot overflow.
161 const std::int64_t N = imsize * imsize;
164 std::random_device rnd_device;
166 std::mt19937_64 mersenne_engine(rnd_device());
// Upper bound nodes * N is computed in 64-bit; at most 1000 * 2^34.
167 std::uniform_int_distribution<std::int64_t> dist(
168 0,
static_cast<std::int64_t
>(nodes) *
static_cast<std::int64_t
>(N));
169 auto gen = [&dist, &mersenne_engine]() -> std::int64_t {
return dist(mersenne_engine); };
170 std::vector<std::int64_t> local_indices(10);
171 std::generate(local_indices.begin(), local_indices.end(), gen);
// The explicit comparator gives the same ordering as the default operator<.
172 std::sort(local_indices.begin(), local_indices.end(),
173 [](std::int64_t a, std::int64_t b) { return (a < b); });
174 CAPTURE(local_indices);
// Per-node counts are stored as t_int even though the indices are 64-bit.
// Each count must be non-negative, and the counts must sum to the number of
// local indices.
176 std::vector<t_int> recv = all_to_all_recv_sizes<std::int64_t>(local_indices, nodes, N);
177 for (
const auto& a : recv)
CHECK(a >= 0);
// NOTE(review): compares a size_t with an int accumulator (-Wsign-compare).
// Casting one side, e.g. static_cast<t_int>(local_indices.size()), avoids it.
178 CHECK(local_indices.size() == std::accumulate(recv.begin(), recv.end(), 0));
// NOTE(review): these trailing lines look like extraction residue rather than
// real test code.
// - The two empty macros would redefine Catch2's CHECK and CHECK_THROWS,
//   which are presumably already defined by catch2/catch_all.hpp. Redefining
//   a macro with a different replacement list is ill-formed unless it is
//   #undef'd first.
// - The TEST_CASE below has no body.
// - Its name "Distribe" looks like a typo for "Distribute". Confirm before
//   changing it, since test names may be used as CI filters.
#define CHECK(CONDITION, ERROR)
#define CHECK_THROWS(STATEMENT, ERROR)
TEST_CASE("Distribe Sparse Vector")