5 #include "catch2/catch_all.hpp" 
    6 #include <sopt/mpi/communicator.h> 
   // Fragment of a test body that exercises DistributeSparseVector::scatter and ::gather.
   // NOTE(review): the embedded line numbers jump (15 -> 19, 22 -> 26, 37 -> 39), so lines of
   // the original file are missing from this listing -- presumably the TEST_CASE/SECTION
   // headers, the declarations of `output`, an `else` and the closing braces. Confirm
   // against the real source before relying on the structure shown here.
   10   auto const world = sopt::mpi::Communicator::World();
 
   // Random integer grid of length max(2 * size, size + 2), passed through world.broadcast
   // (presumably so that every rank holds the same values -- confirm).
   11   Vector<t_int> 
const grid = world.broadcast<Vector<t_int>>(
 
   12       Vector<t_int>::Random(std::max(world.size() * 2, world.size() + 2)));
 
   // Each rank requests two grid entries: index `rank` and index `size + 1`.
   13   std::vector<t_int> 
const indices = {
static_cast<t_int
>(world.rank()),
 
   14                                       static_cast<t_int
>(world.size() + 1)};
 
   15   DistributeSparseVector distributor(indices, grid.size(), world);
 
   // Scatter check: exactly the two requested entries arrive, in the order of `indices`.
   19     distributor.scatter(grid, output);
 
   20     REQUIRE(output.size() == 2);
 
   21     CHECK(output(0) == grid(world.rank()));
 
   22     CHECK(output(1) == grid(world.size() + 1));
 
   // Per-rank contribution for the gather: {rank, size + rank}.
   26     Vector<t_int> local(2);
 
   27     local << world.rank(), world.size() + world.rank();
 
   29     if (world.is_root()) {
 
   30       distributor.gather(local, output);
 
   // The gathered vector has the full grid length; entry i (< size) holds the value i.
   31       CHECK(output.size() == grid.size());
 
   32       for (decltype(world.size()) i(0); i < world.size(); ++i) 
CHECK(output(i) == i);
 
   33       CHECK(output(world.size()) == 0);
 
   // size*size + size*(size-1)/2 is the sum of (size + r) over all ranks r, which is
   // consistent with every rank's second entry being added together at index size + 1.
   34       CHECK(output(world.size() + 1) ==
 
   35             world.size() * world.size() + (world.size() * (world.size() - 1)) / 2);
 
   // Ensures the tail() length used on the next check is not negative.
   36       CHECK(grid.size() - 2 >= world.size());
 
   37       CHECK((output.tail(grid.size() - world.size() - 2).array() == 0).all());
 
   // NOTE(review): original line 38 is missing from the listing; presumably this call sits
   // in the non-root branch, where no output vector is passed -- confirm.
   39       distributor.gather(local);
 
   // Fragment of a test body that exercises AllToAllSparseVector<t_int>::recv_grid and
   // ::send_grid.
   // NOTE(review): the embedded line numbers jump (49 -> 53, 57 -> 61, 62 -> 64, 72 -> 74), so
   // lines of the original file are missing from this listing -- presumably the
   // TEST_CASE/SECTION headers, the declarations of `output`, an `else` and closing braces.
   44   auto const world = sopt::mpi::Communicator::World();
 
   // Random integer grid of length max(2 * size, size + 2), passed through world.broadcast
   // (presumably so that every rank holds the same values -- confirm).
   45   Vector<t_int> 
const grid = world.broadcast<Vector<t_int>>(
 
   46       Vector<t_int>::Random(std::max(world.size() * 2, world.size() + 2)));
 
   // Each rank requests two entries: index `rank` and index `size + 1`.
   47   std::vector<t_int> 
const indices = {
static_cast<t_int
>(world.rank()),
 
   48                                       static_cast<t_int
>(world.size() + 1)};
 
   // NOTE(review): the third argument is grid.size() * rank; its meaning is defined by
   // AllToAllSparseVector, which is not visible here -- presumably a per-rank start offset.
   49   AllToAllSparseVector<t_int> distributor(indices, grid.size(), grid.size() * world.rank(), world);
 
   // Receive check: exactly the two requested entries arrive, in the order of `indices`.
   53     distributor.recv_grid(grid, output);
 
   54     CAPTURE(world.rank());
 
   55     REQUIRE(output.size() == 2);
 
   56     CHECK(output(0) == grid(world.rank()));
 
   57     CHECK(output(1) == grid(world.size() + 1));
 
   // Per-rank contribution for the send: {rank, size + rank}.
   61     Vector<t_int> local(2);
 
   62     local << world.rank(), world.size() + world.rank();
 
   // Unlike the gather-style call, send_grid is issued on every rank and every rank checks
   // that its output has the full grid length.
   64     distributor.send_grid(local, output);
 
   65     CHECK(output.size() == grid.size());
 
   66     if (world.is_root()) {
 
   67       for (decltype(world.size()) i(0); i < world.size(); ++i) 
CHECK(output(i) == i);
 
   68       CHECK(output(world.size()) == 0);
 
   // size*size + size*(size-1)/2 is the sum of (size + r) over all ranks r, which is
   // consistent with every rank's second entry being added together at index size + 1.
   69       CHECK(output(world.size() + 1) ==
 
   70             world.size() * world.size() + (world.size() * (world.size() - 1)) / 2);
 
   // Ensures the tail() length used on the next check is not negative.
   71       CHECK(grid.size() - 2 >= world.size());
 
   72       CHECK((output.tail(grid.size() - world.size() - 2).array() == 0).all());
 
   // NOTE(review): original line 73 is missing from the listing; presumably this loop is the
   // non-root branch, where the whole output is expected to stay zero -- confirm.
   74       for (t_int i = 0; i < output.size(); i++) 
CHECK(output(i) == 0);
 
   // Fragment of a test body for AllToAllSparseVector<t_int>: first a construction that is
   // expected to throw, then recv_grid/send_grid checks with per-rank offset indices.
   // NOTE(review): the embedded line numbers jump (87 -> 89, 89 -> 91, 94 -> 98, 101 -> 105,
   // 109 -> 111, 116 -> 118, 120 -> 122), so lines of the original file are missing from this
   // listing -- presumably the TEST_CASE/SECTION headers, a throw-checking macro call,
   // `else` keywords and closing braces. Confirm against the real source.
   80   auto const world = sopt::mpi::Communicator::World();
 
   81   CAPTURE(world.rank());
 
   // Random integer grid of length max(2 * size, size + 2), passed through world.broadcast
   // (presumably so that every rank holds the same values -- confirm).
   82   Vector<t_int> 
const grid = world.broadcast<Vector<t_int>>(
 
   83       Vector<t_int>::Random(std::max(world.size() * 2, world.size() + 2)));
 
   84   SECTION(
"Check throw when index is not ordered by node") {
 
   // The first index (grid.size() * (rank + 1) + size + 1) is always larger than the second
   // (rank), so this list is in descending order.
   85     std::vector<t_int> 
const indices = {
 
   86         static_cast<t_int
>(grid.size() * (world.rank() + 1) + world.size() + 1),
 
   87         static_cast<t_int
>(world.rank())};
 
   // NOTE(review): original line 88 is missing; the extra closing parenthesis suggests this
   // constructor call is the argument of a throw-checking macro -- confirm.
   89         AllToAllSparseVector<t_int>(indices, grid.size(), grid.size() * world.rank(), world));
 
   // Ascending index list: `rank`, then grid.size() * rank + size + 1.
   91   std::vector<t_int> 
const indices = {
 
   92       static_cast<t_int
>(world.rank()),
 
   93       static_cast<t_int
>(grid.size() * world.rank() + world.size() + 1)};
 
   // The third constructor argument is grid.size() * rank; subtracting it from the second
   // index leaves size + 1, which is the grid entry checked after recv_grid below.
   94   AllToAllSparseVector<t_int> distributor(indices, grid.size(), grid.size() * world.rank(), world);
 
   98     distributor.recv_grid(grid, output);
 
   99     REQUIRE(output.size() == 2);
 
  100     CHECK(output(0) == grid(world.rank()));
 
  101     CHECK(output(1) == grid(world.size() + 1));
 
   // Every rank sends {1, 1}.
  105     Vector<t_int> local = Vector<t_int>::Ones(2);
 
  106     Vector<t_int> output;
 
  107     distributor.send_grid(local, output);
 
  108     CHECK(output.size() == grid.size());
 
  109     CAPTURE(world.rank());
 
   // Root expects a 1 at each index below size and at index size + 1, and 0 elsewhere.
  111     if (world.is_root()) {
 
  112       for (decltype(world.size()) i(0); i < world.size(); ++i) 
CHECK(output(i) == 1);
 
  113       CHECK(output(world.size()) == 0);
 
  114       CHECK(output(world.size() + 1) == 1);
 
   // Ensures the tail() length used on the next check is not negative.
  115       CHECK(grid.size() - 2 >= world.size());
 
  116       CHECK((output.tail(grid.size() - world.size() - 2).array() == 0).all());
 
   // NOTE(review): original lines 117 and 121 are missing; presumably this loop is the
   // non-root branch and the final CHECK belongs to an `else`, so that only index size + 1
   // is expected to hold 1 -- confirm.
  118       for (t_int i = 0; i < output.size(); i++) {
 
  119         if (i != world.size() + 1)
 
  120           CHECK(output(i) == 0);
 
  122           CHECK(output(i) == 1);
 
  // Fragment of a test body that sweeps node counts and image sizes and checks
  // all_to_all_recv_sizes<t_int>: it must throw when nodes * N does not fit in t_int, and
  // otherwise return non-negative sizes that add up to the number of local indices.
  // NOTE(review): the embedded line numbers jump (131 -> 135, 135 -> 137, 144 -> 146,
  // 148 -> 150, 151 -> 153), so lines of the original file are missing from this listing --
  // presumably the TEST_CASE header, an `else` and closing braces. Confirm against the source.
  129   for (t_int nodes : {1, 2, 5, 10, 20, 50, 100, 1000}) {
 
  130     for (t_int imsize : {128, 1024, 2048, 4096, 8192, 16384, 32768}) {
 
  // The largest case, 32768 * 32768 = 2^30, still fits in a 32-bit signed integer.
  131       const std::int32_t N = imsize * imsize;
 
  135       std::random_device rnd_device;
 
  137       std::mt19937_64 mersenne_engine(rnd_device());  
 
  // NOTE(review): `nodes * N` is evaluated here, before the 64-bit overflow guard further
  // down. If t_int is a 32-bit signed type (not visible in this listing), the product
  // overflows for the larger combinations, which is undefined behaviour -- confirm and
  // consider clamping the upper bound.
  138       std::uniform_int_distribution<t_int> dist(0, nodes * N);
 
  139       auto gen = [&dist, &mersenne_engine]() { 
return dist(mersenne_engine); };
 
  // Ten random indices, sorted in ascending order.
  140       std::vector<t_int> local_indices(10);
 
  141       std::generate(local_indices.begin(), local_indices.end(), gen);
 
  142       std::sort(local_indices.begin(), local_indices.end(),
 
  143                 [](t_int a, t_int b) { return (a < b); });
 
  144       CAPTURE(local_indices);
 
  // The product is computed in 64 bits so that it can be compared with the t_int maximum.
  146       if (
static_cast<std::int64_t
>(N) * 
static_cast<std::int64_t
>(nodes) >
 
  147           std::numeric_limits<t_int>::max())
 
  148         CHECK_THROWS(all_to_all_recv_sizes<t_int>(local_indices, nodes, N));
 
  // NOTE(review): original line 149 is missing; presumably the statements below form the
  // `else` branch -- confirm.
  150         std::vector<t_int> recv = all_to_all_recv_sizes<t_int>(local_indices, nodes, N);
 
  151         for (
const auto& a : recv) 
CHECK(a >= 0);
 
  // The receive sizes must add up to the number of local indices.
  153         CHECK(local_indices.size() == std::accumulate(recv.begin(), recv.end(), 0));
 
  // Fragment of a test body that sweeps node counts and image sizes and checks
  // all_to_all_recv_sizes<std::int64_t>: the returned sizes must be non-negative and add up
  // to the number of local indices. No throwing case is checked in the lines shown.
  // NOTE(review): the embedded line numbers jump (161 -> 164, 164 -> 166, 174 -> 176), so
  // lines of the original file are missing from this listing -- presumably the TEST_CASE
  // header and closing braces. Confirm against the real source.
  159   for (std::int64_t nodes : {1, 2, 5, 10, 20, 50, 100, 1000}) {
 
  160     for (std::int64_t imsize : {128, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072}) {
 
  // 64-bit arithmetic: the largest case is 131072 * 131072 = 2^34.
  161       const std::int64_t N = imsize * imsize;
 
  164       std::random_device rnd_device;
 
  166       std::mt19937_64 mersenne_engine(rnd_device());  
 
  // Indices are drawn from the closed range [0, nodes * N], computed in 64 bits.
  167       std::uniform_int_distribution<std::int64_t> dist(
 
  168           0, 
static_cast<std::int64_t
>(nodes) * 
static_cast<std::int64_t
>(N));
 
  169       auto gen = [&dist, &mersenne_engine]() -> std::int64_t { 
return dist(mersenne_engine); };
 
  // Ten random indices, sorted in ascending order.
  170       std::vector<std::int64_t> local_indices(10);
 
  171       std::generate(local_indices.begin(), local_indices.end(), gen);
 
  172       std::sort(local_indices.begin(), local_indices.end(),
 
  173                 [](std::int64_t a, std::int64_t b) { return (a < b); });
 
  174       CAPTURE(local_indices);
 
  // The indices are 64-bit, but the receive sizes are stored as t_int.
  176       std::vector<t_int> recv = all_to_all_recv_sizes<std::int64_t>(local_indices, nodes, N);
 
  177       for (
const auto& a : recv) 
CHECK(a >= 0);
 
  // The receive sizes must add up to the number of local indices.
  178       CHECK(local_indices.size() == std::accumulate(recv.begin(), recv.end(), 0));
 
// NOTE(review): these two definitions take two parameters and expand to nothing, whereas
// every CHECK / CHECK_THROWS use earlier in this listing passes a single argument.
// Presumably they are index entries from a generated documentation page rather than code
// belonging to this test file -- confirm before keeping them.
#define CHECK(CONDITION, ERROR)
 
#define CHECK_THROWS(STATEMENT, ERROR)
 
TEST_CASE("Distribe Sparse Vector")