PURIFY
Next-generation radio interferometric imaging
DistributeSparseVector.h
Go to the documentation of this file.
1 #ifndef PURIFY_DISTRIBUTE_OPERATOR_H
2 #define PURIFY_DISTRIBUTE_OPERATOR_H
3 
4 #include "purify/config.h"
5 #ifdef PURIFY_MPI
6 #include "purify/types.h"
7 #include <numeric>
8 #include "purify/IndexMapping.h"
9 #include "sopt/mpi/communicator.h"
10 
11 namespace purify {
15 class DistributeSparseVector {
16  DistributeSparseVector(const IndexMapping<t_int> &_mapping, const std::vector<t_int> &_sizes,
17  const t_int _local_size, const sopt::mpi::Communicator &_comm)
18  : mapping(_mapping), sizes(_sizes), local_size(_local_size), comm(_comm) {}
19  DistributeSparseVector(const std::vector<t_int> &local_indices, std::vector<t_int> const &_sizes,
20  t_int global_size, const sopt::mpi::Communicator &_comm)
21  : DistributeSparseVector(
22  IndexMapping<t_int>(_comm.gather(local_indices, _sizes), global_size), _sizes,
23  static_cast<t_int>(local_indices.size()), _comm) {}
24 
25  public:
30  DistributeSparseVector(const std::vector<t_int> &local_indices, t_int global_size,
31  const sopt::mpi::Communicator &_comm)
32  : DistributeSparseVector(local_indices,
33  _comm.gather(static_cast<t_int>(local_indices.size())), global_size,
34  _comm) {}
35  DistributeSparseVector(const std::set<t_int> &local_indices, t_int global_size,
36  const sopt::mpi::Communicator &_comm)
37  : DistributeSparseVector(std::vector<t_int>(local_indices.begin(), local_indices.end()),
38  global_size, _comm) {}
39  template <class T0>
40  DistributeSparseVector(Eigen::SparseMatrixBase<T0> const &sparse,
41  const sopt::mpi::Communicator &_comm)
42  : DistributeSparseVector(non_empty_outers<T0, t_int>(sparse), sparse.cols(), _comm) {}
43 
44  template <class T0, class T1>
45  void scatter(Eigen::MatrixBase<T0> const &input, Eigen::MatrixBase<T1> const &output) const {
46  assert(input.cols() == 1);
47  if (not comm.is_root()) return scatter(output);
48  Vector<typename T0::Scalar> buffer;
49  mapping(input, buffer);
50  assert(buffer.size() == std::accumulate(sizes.begin(), sizes.end(), 0));
51  output.const_cast_derived() = comm.scatterv(buffer, sizes);
52  }
53 
54  template <class T1>
55  void scatter(Eigen::MatrixBase<T1> const &output) const {
56  if (comm.is_root()) throw std::runtime_error("This function should not be called by root");
57  output.const_cast_derived() = comm.scatterv<typename T1::Scalar>(local_size);
58  }
59 
60  template <class T0, class T1>
61  void gather(Eigen::MatrixBase<T0> const &input, Eigen::MatrixBase<T1> const &output) const {
62  assert(input.cols() == 1);
63  if (not comm.is_root()) return gather(input);
64  auto const buffer = comm.gather<typename T0::Scalar>(input, sizes);
65  mapping.adjoint(buffer, output.derived());
66  }
67 
68  template <class T1>
69  void gather(Eigen::MatrixBase<T1> const &input) const {
70  if (comm.is_root()) throw std::runtime_error("This function should not be called by root");
71  comm.gather<typename T1::Scalar>(input);
72  }
73 
74  private:
75  IndexMapping<t_int> mapping;
76  std::vector<t_int> sizes;
77  t_int local_size;
78  sopt::mpi::Communicator comm;
79 };
80 } // namespace purify
81 #endif
82 #endif
std::set< STORAGE_INDEX_TYPE > non_empty_outers(Eigen::SparseMatrixBase< T0 > const &matrix)
Outer indices (columns of a column-major matrix, rows of a row-major one) that hold at least one stored entry.
Definition: IndexMapping.h:61