3 #include "catch2/catch_all.hpp"
5 #include "sopt/config.h"
11 TEST_CASE(
"Creates an mpi communicator") {
14 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
15 MPI_Comm_size(MPI_COMM_WORLD, &size);
17 auto const world = mpi::Communicator::World();
19 SECTION(
"General stuff") {
20 REQUIRE(*world == MPI_COMM_WORLD);
21 REQUIRE(
static_cast<t_int>(world.rank()) == rank);
22 REQUIRE(
static_cast<t_int>(world.size()) == size);
24 mpi::Communicator
const shallow = world;
25 CHECK(*shallow == *world);
28 SECTION(
"Duplicate") {
29 mpi::Communicator
const dup = world.duplicate();
30 CHECK(*dup != *world);
34 if (world.rank() == world.root_id()) {
35 std::vector<t_int> scattered(world.size());
36 std::iota(scattered.begin(), scattered.end(), 2);
37 auto const result = world.scatter_one(scattered);
38 CHECK(result == world.rank() + 2);
40 auto const result = world.scatter_one<
t_int>();
41 CHECK(result == world.rank() + 2);
46 std::vector<t_int> sizes(world.size());
47 std::vector<t_int> displs(world.size());
48 for (
t_uint i(0); i < world.rank(); ++i) sizes[i] = world.rank() * 2 + i;
49 for (
t_uint i(1); i < world.rank(); ++i) displs[i] = displs[i - 1] + sizes[i - 1];
52 auto const result = world.rank() == world.root_id()
53 ? world.scatterv(sendee, sizes)
54 : world.scatterv<
t_int>(sizes[world.rank()]);
55 CHECK(result.isApprox(sendee.segment(displs[world.rank()], sizes[world.rank()])));
58 SECTION(
"Gather a single item") {
59 if (world.rank() == world.root_id()) {
60 std::vector<t_int> scattered(world.size());
61 std::iota(scattered.begin(), scattered.end(), 2);
62 auto const result = world.scatter_one(scattered);
63 REQUIRE(result == world.rank() + 2);
64 auto const gathered = world.gather(result);
65 for (decltype(gathered)::size_type i = 0; i < gathered.size(); i++)
66 CHECK(gathered[i] == scattered[i]);
68 auto const result = world.scatter_one<
t_int>();
69 REQUIRE(result == world.rank() + 2);
70 auto const gather = world.gather(result);
71 CHECK(gather.empty());
75 SECTION(
"Gather an eigen vector") {
76 auto const size = [](
t_int n) {
return n * 2 + 10; };
77 auto const totsize = [](
t_int n) {
return std::max<t_int>(0,
n * (9 +
n)); };
79 std::vector<t_int> sizes(world.size());
81 std::generate(sizes.begin(), sizes.end(), [&
n, &size]() { return size(n++); });
83 auto const result = world.is_root() ? world.gather(sendee, sizes) : world.gather(sendee);
84 if (world.rank() == world.root_id()) {
85 CHECK(result.size() == totsize(world.size()));
86 for (decltype(world.size()) i(0); i < world.size(); ++i)
89 CHECK(result.size() == 0);
92 SECTION(
"Gather an std::set") {
93 std::set<t_int>
const input{
static_cast<t_int>(world.size()),
static_cast<t_int>(world.rank())};
94 auto const result = world.gather(input, world.gather<
t_int>(input.size()));
95 if (world.is_root()) {
96 CHECK(result.size() == world.size() + 1);
97 for (decltype(world.size()) i(0); i <= world.size(); ++i) CHECK(result.count(i) == 1);
99 CHECK(result.empty());
102 SECTION(
"Gather an std::vector") {
103 std::vector<t_int>
const input{
static_cast<t_int>(world.size()),
104 static_cast<t_int>(world.rank())};
105 auto const result = world.gather(input, world.gather<
t_int>(input.size()));
106 if (world.is_root()) {
107 CHECK(result.size() == world.size() * 2);
108 for (decltype(world.size()) i(0); i < world.size(); ++i) {
109 CHECK(result[2 * i] == world.size());
110 CHECK(result[2 * i + 1] == i);
113 CHECK(result.empty());
116 SECTION(
"All sum all over image") {
118 image.fill(world.rank());
119 world.all_sum_all(image);
120 CHECK((2 * image == world.size() * (world.size() - 1)).all());
123 SECTION(
"Broadcast") {
125 auto const result = world.broadcast(world.root_id() == world.rank() ? 5 : 2, world.root_id());
129 SECTION(
"Eigen vector") {
133 world.rank() == world.root_id() ? world.broadcast(y0) : world.broadcast<
Vector<t_int>>();
136 std::vector<t_int> v0 = {3, 2, 1};
137 auto const v = world.rank() == world.root_id() ? world.broadcast(v0)
138 : world.broadcast<std::vector<t_int>>();
139 CHECK(std::equal(v.begin(), v.end(), v0.begin()));
142 SECTION(
"Eigen image - and check for correct size initialization") {
144 image0 << 3, 2, 1, 0;
145 auto const image = world.rank() == world.root_id() ? world.broadcast(image0)
147 CHECK(image.matrix() == image0.matrix());
150 CHECK(world.broadcast(image1).matrix() == image0.matrix());
153 SECTION(
"std::string") {
154 const auto *
const expected =
"Hello World!";
155 std::string
const input = world.is_root() ? expected :
"";
156 CHECK(world.broadcast(input) == expected);
158 SECTION(
"all_to_allv") {
160 std::vector<t_int> sizes(world.size(), world.rank());
163 const Vector<t_int> output = world.all_to_allv(sendee, sizes);
165 for (
t_int i = 0; i < world.size() - 1; i++) {
169 CAPTURE(output.segment(sum, i + 1));
171 REQUIRE(output.segment(sum, i + 1).isApprox(expected));
TEST_CASE("Bisection x^3")
int t_int
Root of the type hierarchy for signed integers.
size_t t_uint
Root of the type hierarchy for unsigned integers.
Eigen::Array< T, Eigen::Dynamic, Eigen::Dynamic > Image
A 2-dimensional list of elements of given type.
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.