1 #ifndef PURIFY_DISTRIBUTE_OPERATOR_H
2 #define PURIFY_DISTRIBUTE_OPERATOR_H
4 #include "purify/config.h"
9 #include "sopt/mpi/communicator.h"
15 class DistributeSparseVector {
16 DistributeSparseVector(
const IndexMapping<t_int> &_mapping,
const std::vector<t_int> &_sizes,
17 const t_int _local_size,
const sopt::mpi::Communicator &_comm)
18 : mapping(_mapping), sizes(_sizes), local_size(_local_size), comm(_comm) {}
19 DistributeSparseVector(
const std::vector<t_int> &local_indices, std::vector<t_int>
const &_sizes,
20 t_int global_size,
const sopt::mpi::Communicator &_comm)
21 : DistributeSparseVector(
22 IndexMapping<t_int>(_comm.gather(local_indices, _sizes), global_size), _sizes,
23 static_cast<t_int>(local_indices.size()), _comm) {}
//! Overload that works out the per-process sizes itself.
//! Each process' `local_indices.size()` is gathered via `_comm.gather` and used as the
//! `_sizes` argument of the (local_indices, sizes, global_size, comm) overload.
//! NOTE(review): the remainder of the delegating call is not visible in this extract.
30 DistributeSparseVector(
const std::vector<t_int> &local_indices, t_int global_size,
31 const sopt::mpi::Communicator &_comm)
32 : DistributeSparseVector(local_indices,
33 _comm.gather(static_cast<t_int>(local_indices.size())), global_size,
35 DistributeSparseVector(
const std::set<t_int> &local_indices, t_int global_size,
36 const sopt::mpi::Communicator &_comm)
37 : DistributeSparseVector(std::vector<t_int>(local_indices.begin(), local_indices.end()),
38 global_size, _comm) {}
//! Creates a distributor from a sparse matrix.
//! The local indices are the non-empty outer indices of `sparse` (as returned by
//! `non_empty_outers`, a std::set), and the global size is `sparse.cols()`.
//! NOTE(review): the `template <class T0>` declaration for this constructor is not visible
//! in this extract.
40 DistributeSparseVector(Eigen::SparseMatrixBase<T0>
const &sparse,
41 const sopt::mpi::Communicator &_comm)
42 : DistributeSparseVector(
non_empty_outers<T0, t_int>(sparse), sparse.cols(), _comm) {}
//! Scatters a column vector held on root so that each process receives its share in `output`.
//! On root: `mapping` packs `input` into a contiguous buffer, and `comm.scatterv` splits that
//! buffer according to `sizes`. On other processes this forwards to the receive-only overload,
//! and `input` is ignored.
//! \param[in] input Global column vector (only read on root; must have one column).
//! \param[out] output Receives this process' entries. It is written through
//!   `const_cast_derived()` (Eigen's idiom for writing through a `const MatrixBase&`).
44 template <
class T0,
class T1>
45 void scatter(Eigen::MatrixBase<T0>
const &input, Eigen::MatrixBase<T1>
const &output)
const {
46 assert(input.cols() == 1);
47 if (not comm.is_root())
return scatter(output);
// Root only from here on: pack the entries into per-process order.
48 Vector<typename T0::Scalar> buffer;
49 mapping(input, buffer);
// NOTE(review): the accumulate seed `0` is an int, so the sum is computed in int whatever
// t_int is. Harmless if t_int is int; consider `t_int(0)` otherwise.
50 assert(buffer.size() == std::accumulate(sizes.begin(), sizes.end(), 0));
51 output.const_cast_derived() = comm.scatterv(buffer, sizes);
//! Receive side of scatter for non-root processes.
//! Fills `output` with the `local_size` entries received through `comm.scatterv`.
//! \param[out] output Receives this process' entries (written via `const_cast_derived()`).
//! \throws std::runtime_error if called on root.
//! NOTE(review): the `template <class T1>` declaration is not visible in this extract.
55 void scatter(Eigen::MatrixBase<T1>
const &output)
const {
56 if (comm.is_root())
throw std::runtime_error(
"This function should not be called by root");
57 output.const_cast_derived() = comm.scatterv<
typename T1::Scalar>(local_size);
//! Gathers every process' column vector onto root and unpacks it into the global vector.
//! On root: `comm.gather` collects the pieces (counts given by `sizes`), and
//! `mapping.adjoint` writes the result into `output`. On other processes this forwards to the
//! send-only overload, and `output` is not touched.
//! \param[in] input This process' local column vector (must have one column).
//! \param[out] output Global result, only written on root.
//! NOTE(review): `output.derived()` yields a const reference, unlike scatter, which writes
//! through `const_cast_derived()`. This relies on `IndexMapping::adjoint` accepting a const
//! reference and writing through it. Confirm against IndexMapping.
60 template <
class T0,
class T1>
61 void gather(Eigen::MatrixBase<T0>
const &input, Eigen::MatrixBase<T1>
const &output)
const {
62 assert(input.cols() == 1);
63 if (not comm.is_root())
return gather(input);
64 auto const buffer = comm.gather<
typename T0::Scalar>(input, sizes);
65 mapping.adjoint(buffer, output.derived());
//! Send side of gather for non-root processes.
//! Passes `input` to `comm.gather`; the value returned by that call is discarded here.
//! \param[in] input This process' local column vector.
//! \throws std::runtime_error if called on root.
//! NOTE(review): the `template <class T1>` declaration is not visible in this extract.
69 void gather(Eigen::MatrixBase<T1>
const &input)
const {
70 if (comm.is_root())
throw std::runtime_error(
"This function should not be called by root");
71 comm.gather<
typename T1::Scalar>(input);
//! Packs the global vector on root for scatter; its adjoint unpacks gathered data.
//! NOTE(review): the `local_size` member initialised by the constructors is declared outside
//! this extract.
75 IndexMapping<t_int> mapping;
//! Per-process entry counts passed to scatterv/gather on root.
76 std::vector<t_int> sizes;
//! Communicator used for every transfer.
78 sopt::mpi::Communicator comm;
std::set< STORAGE_INDEX_TYPE > non_empty_outers(Eigen::SparseMatrixBase< T0 > const &matrix)
Indices of the non-empty outer vectors of `matrix` (its columns, for Eigen's default column-major storage).