1 #ifndef PURIFY_DISTRIBUTE_OPERATOR_H 
    2 #define PURIFY_DISTRIBUTE_OPERATOR_H 
    4 #include "purify/config.h" 
    9 #include "sopt/mpi/communicator.h" 
   15 class DistributeSparseVector {
 
   16   DistributeSparseVector(
const IndexMapping<t_int> &_mapping, 
const std::vector<t_int> &_sizes,
 
   17                          const t_int _local_size, 
const sopt::mpi::Communicator &_comm)
 
   18       : mapping(_mapping), sizes(_sizes), local_size(_local_size), comm(_comm) {}
 
   19   DistributeSparseVector(
const std::vector<t_int> &local_indices, std::vector<t_int> 
const &_sizes,
 
   20                          t_int global_size, 
const sopt::mpi::Communicator &_comm)
 
   21       : DistributeSparseVector(
 
   22             IndexMapping<t_int>(_comm.gather(local_indices, _sizes), global_size), _sizes,
 
   23             static_cast<t_int>(local_indices.size()), _comm) {}
 
   // Constructor taking only the local indices, the global size and the communicator.
   // It delegates to another constructor of this class; the sizes argument it forwards
   // is the result of _comm.gather(local_indices.size()).
   // NOTE(review): the delegating call breaks off after `global_size,` in this listing;
   // its remaining arguments and the constructor body are not visible here.
   30   DistributeSparseVector(
const std::vector<t_int> &local_indices, t_int global_size,
 
   31                          const sopt::mpi::Communicator &_comm)
 
   32       : DistributeSparseVector(local_indices,
 
   33                                _comm.gather(static_cast<t_int>(local_indices.size())), global_size,
 
   35   DistributeSparseVector(
const std::set<t_int> &local_indices, t_int global_size,
 
   36                          const sopt::mpi::Communicator &_comm)
 
   37       : DistributeSparseVector(std::vector<t_int>(local_indices.begin(), local_indices.end()),
 
   38                                global_size, _comm) {}
 
   // Constructor taking a sparse matrix. It delegates to another constructor of this
   // class with non_empty_outers<T0, t_int>(sparse) as the local indices and
   // sparse.cols() as the global size.
   // NOTE(review): T0 is not declared in the visible lines -- presumably a template
   // parameter introduced on a line missing from this listing; confirm.
   40   DistributeSparseVector(Eigen::SparseMatrixBase<T0> 
const &sparse,
 
   41                          const sopt::mpi::Communicator &_comm)
 
   42       : DistributeSparseVector(
non_empty_outers<T0, t_int>(sparse), sparse.cols(), _comm) {}
 
   // Scatter entry point callable on every process.
   //  - input:  must have exactly one column (checked with assert); only used on root.
   //  - output: receives the result of the scatter. It is taken by const reference and
   //            written through const_cast_derived().
   // Non-root processes forward to the single-argument scatter(output).
   // NOTE(review): the closing brace of this function is not visible in this listing.
   44   template <
class T0, 
class T1>
 
   45   void scatter(Eigen::MatrixBase<T0> 
const &input, Eigen::MatrixBase<T1> 
const &output)
 const {
 
   46     assert(input.cols() == 1);
 
          // Non-root processes take no part in building the buffer below.
   47     if (not comm.is_root()) 
return scatter(output);
 
          // Root only: mapping(input, buffer) fills buffer from input
          // (IndexMapping::operator() is defined elsewhere).
   48     Vector<typename T0::Scalar> buffer;
 
   49     mapping(input, buffer);
 
          // The buffer length must equal the sum of the entries of sizes.
   50     assert(buffer.size() == std::accumulate(sizes.begin(), sizes.end(), 0));
 
   51     output.const_cast_derived() = comm.scatterv(buffer, sizes);
 
   // Scatter overload for non-root processes: assigns the result of
   // comm.scatterv<T1::Scalar>(local_size) to output, which is written through
   // const_cast_derived() although it is taken by const reference.
   // Throws std::runtime_error when called on the root process.
   // NOTE(review): the declaration of T1 and the closing brace of this function are not
   // visible in this listing.
   55   void scatter(Eigen::MatrixBase<T1> 
const &output)
 const {
 
   56     if (comm.is_root()) 
throw std::runtime_error(
"This function should not be called by root");
 
   57     output.const_cast_derived() = comm.scatterv<
typename T1::Scalar>(local_size);
 
   // Gather entry point callable on every process.
   //  - input:  must have exactly one column (checked with assert).
   //  - output: only used on root, where it is passed (via derived()) to mapping.adjoint.
   // Non-root processes forward to the single-argument gather(input).
   // NOTE(review): the closing brace of this function is not visible in this listing.
   60   template <
class T0, 
class T1>
 
   61   void gather(Eigen::MatrixBase<T0> 
const &input, Eigen::MatrixBase<T1> 
const &output)
 const {
 
   62     assert(input.cols() == 1);
 
          // Non-root processes only contribute their input.
   63     if (not comm.is_root()) 
return gather(input);
 
          // Root only: gather using sizes, then hand the gathered buffer to
          // mapping.adjoint (IndexMapping::adjoint is defined elsewhere).
   64     auto const buffer = comm.gather<
typename T0::Scalar>(input, sizes);
 
   65     mapping.adjoint(buffer, output.derived());
 
   // Gather overload for non-root processes: calls comm.gather<T1::Scalar>(input) and
   // discards whatever that call returns.
   // Throws std::runtime_error when called on the root process.
   // NOTE(review): the declaration of T1 and the closing brace of this function are not
   // visible in this listing.
   69   void gather(Eigen::MatrixBase<T1> 
const &input)
 const {
 
   70     if (comm.is_root()) 
throw std::runtime_error(
"This function should not be called by root");
 
   71     comm.gather<
typename T1::Scalar>(input);
 
   75   IndexMapping<t_int> mapping;
 
   76   std::vector<t_int> sizes;
 
   78   sopt::mpi::Communicator comm;
 
std::set< STORAGE_INDEX_TYPE > non_empty_outers(Eigen::SparseMatrixBase< T0 > const &matrix)
Returns the set of outer indices of `matrix` that are non-empty.