1 #ifndef PURIFY_DISTRIBUTE_ALLTOALL_OPERATOR_H
2 #define PURIFY_DISTRIBUTE_ALLTOALL_OPERATOR_H
4 #include "purify/config.h"
9 #include "sopt/mpi/communicator.h"
/// Count, for each of `nodes` groups (ranks), how many entries of `local_indices` fall in it.
///
/// Each index is assigned to group `index / N`, i.e. every group spans `N` consecutive global
/// indices. AllToAllSparseVector calls this with `N = ft_grid_size` and `nodes = comm.size()`,
/// and the resulting counts are the per-rank counts used when the indices are exchanged with
/// `all_to_allv`.
///
/// @tparam STORAGE_INDEX_TYPE integer type of the global indices.
/// @param local_indices global indices; must be ordered so that `index / N` never decreases.
/// @param nodes number of groups (ranks) to produce counts for.
/// @param N number of global indices per group.
/// @return one count per group; groups that received no index get 0 (presumably the returned
///         vector is `recv_sizes` — the return statement is not on the lines shown; confirm).
/// @throws std::runtime_error if `N * nodes` exceeds the range of STORAGE_INDEX_TYPE, or if
///         `local_indices` is not ordered by group.
15 template <
class STORAGE_INDEX_TYPE>
16 std::vector<t_int> all_to_all_recv_sizes(
const std::vector<STORAGE_INDEX_TYPE> &local_indices,
17 const t_int nodes,
const STORAGE_INDEX_TYPE N) {
18 std::vector<t_int> recv_sizes;
19 STORAGE_INDEX_TYPE group = 0;
// Guard: N * nodes (the total size of all groups) must be representable in
// STORAGE_INDEX_TYPE. The product is formed in 64 bits so the check itself does not
// overflow for narrower index types.
// NOTE(review): the message describes the symptom ("less than 0", i.e. a wrapped value)
// rather than the condition actually tested (exceeding the maximum).
21 if (
static_cast<std::int64_t
>(N) *
static_cast<std::int64_t
>(nodes) >
22 std::numeric_limits<STORAGE_INDEX_TYPE>::max())
23 throw std::runtime_error(
24 "Total number of pixels across FFT grids is less than 0. Please use index mapper with 64 "
26 "data types, i.e. long long int.");
// Single pass over the indices: consecutive indices in the same group are counted together;
// when the group changes, the finished count is emitted and skipped groups get a 0.
28 for (
const STORAGE_INDEX_TYPE &index : local_indices) {
29 const STORAGE_INDEX_TYPE index_group = index / N;
// Indices must arrive in non-decreasing group order; moving back to an earlier group is an error.
30 if (index_group < group)
31 throw std::runtime_error(
"local indices are out of order for columns of gridding matrix, " +
32 std::to_string(index_group) +
" < " + std::to_string(group) +
33 " for index " + std::to_string(index));
// NOTE(review): `count` and the updates to `group`/`count` are not on the lines shown;
// presumably `count` is the number of indices seen so far in `group` — confirm.
34 if (group != index_group) {
35 recv_sizes.push_back(count);
37 while (group < index_group) {
38 recv_sizes.push_back(0);
// After the loop: emit the count of the last group that received indices, then pad the
// remaining groups up to `nodes` with 0.
45 recv_sizes.push_back(count);
47 while (group < nodes) {
48 recv_sizes.push_back(0);
// Sanity checks: `group` must end exactly at `nodes`, and the counts must account for every
// local index.
// NOTE(review): local_indices.size() is unsigned while std::accumulate(..., 0) yields int, so
// this comparison mixes signedness (-Wsign-compare); consider accumulating with a size_t init.
51 assert(group == nodes);
52 assert(local_indices.size() == std::accumulate(recv_sizes.begin(), recv_sizes.end(), 0));
59 const sopt::mpi::Communicator &comm);
/// Redistributes a vector between the ranks of `comm` using an IndexMapping plus an
/// all-to-all-v exchange.
///
/// recv_grid() packs the entries of a local vector selected by `mapping` and exchanges them
/// using `send_sizes` as the per-rank send counts and `recv_sizes` as the per-rank receive
/// counts. send_grid() is its counterpart: it exchanges with the two count vectors swapped,
/// then applies `mapping.adjoint`. (This assumes `comm.all_to_allv(data, send_counts,
/// recv_counts)` follows MPI_Alltoallv semantics — confirm against sopt::mpi::Communicator.)
///
/// @tparam STORAGE_INDEX_TYPE integer type used for global indices (defaults to t_int).
61 template <
class STORAGE_INDEX_TYPE = t_
int>
62 class AllToAllSparseVector {
/// Construct from an existing mapping and explicit per-rank send and receive counts.
63 AllToAllSparseVector(
const IndexMapping<STORAGE_INDEX_TYPE> &_mapping,
64 const std::vector<t_int> &_send_sizes,
const std::vector<t_int> &_recv_sizes,
65 const sopt::mpi::Communicator &_comm)
66 : mapping(_mapping), send_sizes(_send_sizes), recv_sizes(_recv_sizes), comm(_comm) {}
/// Construct from a mapping and per-rank receive counts only.
/// NOTE(review): the remaining initializers are not on the lines shown; presumably
/// send_sizes is derived from _recv_sizes, e.g. via all_to_all_send_sizes(_recv_sizes, _comm)
/// — confirm.
67 AllToAllSparseVector(
const IndexMapping<STORAGE_INDEX_TYPE> &_mapping,
68 const std::vector<t_int> &_recv_sizes,
const sopt::mpi::Communicator &_comm)
70 recv_sizes(_recv_sizes),
/// Construct from this rank's global indices and precomputed per-rank counts.
/// The indices are first exchanged with all_to_allv, using _recv_sizes as the per-rank send
/// counts. The received indices then build the IndexMapping. ft_grid_size and start are
/// presumably forwarded to the IndexMapping constructor; that argument line is not shown,
/// so confirm.
73 AllToAllSparseVector(
const std::vector<STORAGE_INDEX_TYPE> &local_indices,
74 const std::vector<t_int> &_recv_sizes, STORAGE_INDEX_TYPE ft_grid_size,
75 const STORAGE_INDEX_TYPE start,
const sopt::mpi::Communicator &_comm)
76 : AllToAllSparseVector(IndexMapping<STORAGE_INDEX_TYPE>(
77 _comm.all_to_allv<STORAGE_INDEX_TYPE>(local_indices, _recv_sizes),
79 _recv_sizes, _comm) {}
/// Construct from this rank's global indices, computing the per-rank counts with
/// all_to_all_recv_sizes: rank r owns indices [r * ft_grid_size, (r + 1) * ft_grid_size).
/// local_indices must be ordered by index / ft_grid_size, otherwise that helper throws.
87 AllToAllSparseVector(
const std::vector<STORAGE_INDEX_TYPE> &local_indices,
88 STORAGE_INDEX_TYPE ft_grid_size, STORAGE_INDEX_TYPE start,
89 const sopt::mpi::Communicator &_comm)
90 : AllToAllSparseVector(
92 all_to_all_recv_sizes<STORAGE_INDEX_TYPE>(local_indices, _comm.size(), ft_grid_size),
93 ft_grid_size, start, _comm) {}
/// Convenience overload taking a std::set.
/// A std::set iterates in ascending order, so the copied vector is already sorted as the
/// std::vector overload requires.
94 AllToAllSparseVector(
const std::set<STORAGE_INDEX_TYPE> &local_indices,
95 STORAGE_INDEX_TYPE ft_grid_size, STORAGE_INDEX_TYPE start,
96 const sopt::mpi::Communicator &_comm)
97 : AllToAllSparseVector(
98 std::vector<STORAGE_INDEX_TYPE>(local_indices.begin(), local_indices.end()),
99 ft_grid_size, start, _comm) {}
/// Construct from a sparse matrix of type T0.
/// The global indices are the matrix's non-empty outer indices, as returned by
/// non_empty_outers().
101 AllToAllSparseVector(Eigen::SparseMatrixBase<T0>
const &sparse,
102 const STORAGE_INDEX_TYPE ft_grid_size,
const STORAGE_INDEX_TYPE start,
103 const sopt::mpi::Communicator &_comm)
104 : AllToAllSparseVector(
non_empty_outers<T0, STORAGE_INDEX_TYPE>(sparse), ft_grid_size, start,
/// Pack the entries of `input` selected by `mapping` and exchange them across `comm`.
/// The received values are written into `output`.
/// @param input column vector (asserted to have exactly one column).
/// @param output destination. It is written through const_cast_derived(), the Eigen idiom
///        for output parameters passed as `const MatrixBase&`.
107 template <
class T0,
class T1>
108 void recv_grid(Eigen::MatrixBase<T0>
const &input, Eigen::MatrixBase<T1>
const &output)
const {
109 assert(input.cols() == 1);
110 Vector<typename T0::Scalar> buffer;
111 mapping(input, buffer);
// The packed buffer must hold exactly as many entries as this rank sends in total.
112 assert(buffer.size() == std::accumulate(send_sizes.begin(), send_sizes.end(), 0));
113 output.const_cast_derived() =
114 comm.all_to_allv<
typename T0::Scalar>(buffer, send_sizes, recv_sizes);
/// Counterpart of recv_grid().
/// It exchanges `input` with the count vectors swapped (recv_sizes as send counts,
/// send_sizes as receive counts). The received values are then scattered into `output`
/// through mapping.adjoint().
/// @param input column vector (asserted to have exactly one column).
/// NOTE(review): output.derived() is a const reference here; presumably IndexMapping::adjoint
/// writes through it with a const cast, as recv_grid does. Confirm.
117 template <
class T0,
class T1>
118 void send_grid(Eigen::MatrixBase<T0>
const &input, Eigen::MatrixBase<T1>
const &output)
const {
119 assert(input.cols() == 1);
120 auto const buffer = comm.all_to_allv<
typename T0::Scalar>(input, recv_sizes, send_sizes);
121 mapping.adjoint(buffer, output.derived());
// Selects entries of the local vector into the packed exchange buffer (and back via adjoint).
125 IndexMapping<STORAGE_INDEX_TYPE> mapping;
// Per-rank counts sent by recv_grid() (received by send_grid()).
126 std::vector<t_int> send_sizes;
// Per-rank counts received by recv_grid() (sent by send_grid()).
127 std::vector<t_int> recv_sizes;
// Communicator used for every exchange; stored by value.
128 sopt::mpi::Communicator comm;
std::vector< t_int > all_to_all_send_sizes(const std::vector< t_int > &recv_sizes, const sopt::mpi::Communicator &comm)
std::set< STORAGE_INDEX_TYPE > non_empty_outers(Eigen::SparseMatrixBase< T0 > const &matrix)
Indices of the non-empty outer vectors of a sparse matrix (e.g. its columns when stored column-major).