1 #ifndef SOPT_MPI_COMMUNICATOR_H
2 #define SOPT_MPI_COMMUNICATOR_H
4 #include "sopt/config.h"
15 #include <type_traits>
45 Communicator() : impl(), session() {}
47 static Communicator World() {
return Communicator(MPI_COMM_WORLD); }
48 static Communicator Self() {
return Communicator(MPI_COMM_SELF); }
50 virtual ~Communicator(){}
53 decltype(Impl::size) size()
const {
return impl ? impl->size : 1; }
55 decltype(Impl::rank) rank()
const {
return impl ? impl->rank : 0; }
57 decltype(Impl::comm) operator*()
const {
58 if (not impl)
throw std::runtime_error(
"Communicator was not set");
62 static constexpr
t_uint root_id() {
return 0; }
64 bool is_root()
const {
return rank() == root_id(); }
67 Communicator duplicate()
const;
69 Communicator clone()
const {
return duplicate(); }
71 void abort(
const std::string &reason)
const;
75 typename std::enable_if<is_registered_type<T>::value, T>::type all_reduce(T
const &value,
76 MPI_Op operation)
const;
80 typename std::enable_if<is_registered_type<T>::value>::type all_reduce(
Matrix<T> &image,
81 MPI_Op operation)
const {
82 if (size() == 1)
return;
83 assert(impl and image.size() and image.data());
84 MPI_Allreduce(MPI_IN_PLACE, image.data(), image.size(), registered_type(T(0)), operation,
88 typename std::enable_if<is_registered_type<T>::value>::type all_reduce(
Image<T> &image,
89 MPI_Op operation)
const {
90 if (size() == 1)
return;
91 assert(impl and image.size() and image.data());
92 MPI_Allreduce(MPI_IN_PLACE, image.data(), image.size(), registered_type(T(0)), operation,
96 typename std::enable_if<is_registered_type<T>::value>::type all_reduce(
Vector<T> &image,
97 MPI_Op operation)
const {
98 if (size() == 1)
return;
99 assert(impl and image.size() and image.data());
100 MPI_Allreduce(MPI_IN_PLACE, image.data(), image.size(), registered_type(T(0)), operation,
105 template <
typename T>
106 typename std::enable_if<is_registered_type<T>::value, T>::type all_sum_all(T
const &value)
const {
107 return all_reduce(value, MPI_SUM);
109 template <
typename T>
110 typename std::enable_if<is_registered_type<typename T::Scalar>::value>::type all_sum_all(
112 all_reduce(image, MPI_SUM);
114 template <
typename T>
115 typename std::enable_if<is_registered_type<typename T::Scalar>::value, T>::type all_sum_all(
116 T
const &image)
const {
118 all_reduce(result, MPI_SUM);
123 template <
typename T>
124 typename std::enable_if<is_registered_type<T>::value, T>::type broadcast(
125 T
const &value,
t_uint const root = root_id())
const;
127 template <
typename T>
128 typename std::enable_if<is_registered_type<T>::value, T>::type broadcast(
129 t_uint const root = root_id())
const;
130 template <
typename T>
131 typename std::enable_if<is_registered_type<typename T::Scalar>::value, T>::type broadcast(
132 T
const &vec,
t_uint const root = root_id())
const;
133 template <
typename T>
134 typename std::enable_if<is_registered_type<typename T::Scalar>::value, T>::type broadcast(
135 t_uint const root = root_id())
const;
136 template <
typename T>
137 typename std::enable_if<is_registered_type<typename T::value_type>::value and
138 not std::is_base_of<Eigen::EigenBase<T>, T>::value,
140 broadcast(T
const &vec,
t_uint const root = root_id())
const;
141 template <
typename T>
142 typename std::enable_if<is_registered_type<typename T::value_type>::value and
143 not std::is_base_of<Eigen::EigenBase<T>, T>::value,
145 broadcast(
t_uint const root = root_id())
const;
146 std::string broadcast(std::string
const &input,
t_uint const root = root_id())
const;
149 template <
typename T>
150 typename std::enable_if<is_registered_type<T>::value, T>::type scatter_one(
151 std::vector<T>
const &values,
t_uint const root = root_id())
const;
153 template <
typename T>
154 typename std::enable_if<is_registered_type<T>::value, T>::type scatter_one(
155 t_uint const root = root_id())
const;
158 template <
typename T>
159 typename std::enable_if<is_registered_type<T>::value,
Vector<T>>::type scatterv(
160 Vector<T> const &vec, std::vector<t_int>
const &sizes,
t_uint const root = root_id())
const;
161 template <
typename T>
162 typename std::enable_if<is_registered_type<T>::value,
Vector<T>>::type scatterv(
163 t_int local_size,
t_uint const root = root_id())
const;
166 template <
typename T>
167 typename std::enable_if<is_registered_type<T>::value, std::vector<T>>::type gather(
168 T
const value,
t_uint const root = root_id())
const;
171 template <
typename T>
172 typename std::enable_if<is_registered_type<T>::value,
Vector<T>>::type gather(
173 Vector<T> const &vec, std::vector<t_int>
const &sizes,
t_uint const root = root_id())
const {
174 return gather_<Vector<T>, T>(vec, sizes, root);
176 template <
typename T>
177 typename std::enable_if<is_registered_type<T>::value,
Vector<T>>::type gather(
179 return gather_<Vector<T>, T>(vec, root);
182 template <
typename T>
183 typename std::enable_if<is_registered_type<T>::value, std::set<T>>::type gather(
184 std::set<T>
const &set, std::vector<t_int>
const &sizes,
t_uint const root = root_id())
const;
185 template <
typename T>
186 typename std::enable_if<is_registered_type<T>::value, std::set<T>>::type gather(
187 std::set<T>
const &vec,
t_uint const root = root_id())
const;
188 template <
typename T>
189 typename std::enable_if<is_registered_type<T>::value, std::vector<T>>::type gather(
190 std::vector<T>
const &vec, std::vector<t_int>
const &sizes,
191 t_uint const root = root_id())
const {
192 return gather_<std::vector<T>, T>(vec, sizes, root);
194 template <
typename T>
195 typename std::enable_if<is_registered_type<T>::value, std::vector<T>>::type gather(
196 std::vector<T>
const &vec,
t_uint const root = root_id())
const {
197 return gather_<std::vector<T>, T>(vec, root);
200 template <
typename T>
201 typename std::enable_if<is_registered_type<T>::value,
Vector<T>>::type all_to_allv(
202 const Vector<T> &vec, std::vector<t_int>
const &send_sizes)
const;
203 template <
typename T>
204 typename std::enable_if<is_registered_type<T>::value,
Vector<T>>::type all_to_allv(
205 const Vector<T> &vec, std::vector<t_int>
const &send_sizes,
206 std::vector<t_int>
const &rec_sizes)
const;
207 template <
typename T>
208 typename std::enable_if<is_registered_type<T>::value, std::vector<T>>::type all_to_allv(
209 const std::vector<T> &vec, std::vector<t_int>
const &send_sizes,
210 std::vector<t_int>
const &rec_sizes)
const;
211 template <
typename T>
212 typename std::enable_if<is_registered_type<T>::value, std::vector<T>>::type all_to_allv(
213 const std::vector<T> &vec, std::vector<t_int>
const &send_sizes)
const;
220 MPI_Comm_split(**
this, color,
static_cast<t_int>(rank), &comm);
225 void barrier()
const {
226 if (not impl)
return;
227 if (MPI_Barrier(**
this) != MPI_SUCCESS)
228 throw std::runtime_error(
"Encountered error in mpi barrier");
233 std::shared_ptr<Impl const> impl;
236 std::shared_ptr<mpi::details::initializer> session;
239 static void delete_comm(Impl *impl);
246 Communicator(MPI_Comm
const &comm);
249 template <
typename CONTAINER,
typename T>
250 CONTAINER gather_(CONTAINER
const &vec, std::vector<t_int>
const &sizes,
251 t_uint const root = root_id())
const;
252 template <
typename CONTAINER,
typename T>
253 CONTAINER gather_(CONTAINER
const &vec,
t_uint const root = root_id())
const;
256 template <
typename T>
257 typename std::enable_if<is_registered_type<T>::value, T>::type Communicator::all_reduce(
258 T
const &value, MPI_Op operation)
const {
259 if (size() == 1)
return value;
262 MPI_Allreduce(
const_cast<void *
>(
reinterpret_cast<const void *
>(&value)), &result, 1,
263 registered_type(value), operation, **
this);
267 template <
typename T>
268 typename std::enable_if<is_registered_type<T>::value, T>::type Communicator::scatter_one(
269 std::vector<T>
const &values,
t_uint const root)
const {
270 assert(root < size());
271 if (values.size() != size())
throw std::runtime_error(
"Expected a single object per process");
272 if (size() == 1)
return values.at(0);
274 MPI_Scatter(
const_cast<void *
>(
reinterpret_cast<const void *
>(values.data())), 1,
275 registered_type(result), &result, 1, registered_type(result), root, **
this);
279 template <
typename T>
280 typename std::enable_if<is_registered_type<T>::value, T>::type Communicator::scatter_one(
281 t_uint const root)
const {
283 MPI_Scatter(
nullptr, 1, registered_type(result), &result, 1, registered_type(result), root,
288 template <
typename T>
289 typename std::enable_if<is_registered_type<T>::value,
Vector<T>>::type Communicator::scatterv(
290 Vector<T> const &vec, std::vector<t_int>
const &sizes,
t_uint const root)
const {
292 if (sizes.size() == 1 and vec.size() != sizes.front())
293 throw std::runtime_error(
"Input vector size and sizes are inconsistent on root");
296 if (rank() != root)
return scatterv<T>(sizes.at(rank()), root);
297 std::vector<int> sizes_;
298 std::vector<int> displs;
300 for (
auto const size : sizes) {
301 sizes_.push_back(
static_cast<int>(size));
305 if (vec.size() != i)
throw std::runtime_error(
"Input vector size and sizes are inconsistent");
309 result = vec.head(sizes[rank()]);
311 MPI_Scatterv(
const_cast<void *
>(
reinterpret_cast<const void *
>(vec.data())), sizes_.data(),
312 displs.data(), registered_type(T(0)), result.data(), sizes_[rank()],
313 registered_type(T(0)), root, **
this);
317 template <
typename T>
318 typename std::enable_if<is_registered_type<T>::value,
Vector<T>>::type Communicator::scatterv(
320 if (rank() == root)
throw std::runtime_error(
"Root should call the *other* scatterv");
321 std::vector<int> sizes(size());
322 sizes[rank()] = local_size;
324 MPI_Scatterv(
nullptr, sizes.data(),
nullptr, registered_type(T(0)), result.data(), local_size,
325 registered_type(T(0)), root, **
this);
329 template <
typename T>
330 typename std::enable_if<is_registered_type<T>::value, std::vector<T>>::type
331 Communicator::all_to_allv(
const std::vector<T> &vec, std::vector<t_int>
const &send_sizes)
const {
333 if (send_sizes.size() == 1 and vec.size() != send_sizes.front())
334 throw std::runtime_error(
"Input vector size and sizes are inconsistent on root");
337 std::vector<t_int> rec_sizes(send_sizes.size(), 0);
338 for (
t_int i = 0; i < size(); i++) {
340 rec_sizes = gather<t_int>(send_sizes[i], i);
342 gather<t_int>(send_sizes[i], i);
345 return all_to_allv<T>(vec, send_sizes, rec_sizes);
347 template <
typename T>
348 typename std::enable_if<is_registered_type<T>::value, std::vector<T>>::type
349 Communicator::all_to_allv(
const std::vector<T> &vec, std::vector<t_int>
const &send_sizes,
350 std::vector<t_int>
const &rec_sizes)
const {
352 if (send_sizes.size() == 1 and vec.size() != send_sizes.front())
353 throw std::runtime_error(
"Input vector size and sizes are inconsistent on root");
358 std::vector<int> ssizes_;
359 std::vector<int> sdispls;
360 for (
auto const size : send_sizes) {
361 ssizes_.push_back(
static_cast<int>(size));
362 sdispls.push_back(i);
366 std::vector<int> rsizes_;
367 std::vector<int> rdispls;
368 for (
auto const size : rec_sizes) {
369 rsizes_.push_back(
static_cast<int>(size));
370 rdispls.push_back(total);
373 std::vector<T> output = std::vector<T>(total, 0);
374 MPI_Alltoallv(
const_cast<void *
>(
reinterpret_cast<const void *
>(vec.data())), ssizes_.data(),
375 sdispls.data(), registered_type(T(0)), output.data(), rsizes_.data(),
376 rdispls.data(), registered_type(T(0)), **
this);
380 template <
typename T>
381 typename std::enable_if<is_registered_type<T>::value,
Vector<T>>::type Communicator::all_to_allv(
382 const Vector<T> &vec, std::vector<t_int>
const &send_sizes)
const {
384 if (send_sizes.size() == 1 and vec.size() != send_sizes.front())
385 throw std::runtime_error(
"Input vector size and sizes are inconsistent on root");
388 std::vector<t_int> rec_sizes(send_sizes.size(), 0);
389 for (
t_int i = 0; i < size(); i++) {
391 rec_sizes = gather<t_int>(send_sizes[i], i);
393 gather<t_int>(send_sizes[i], i);
396 return all_to_allv<T>(vec, send_sizes, rec_sizes);
399 template <
typename T>
400 typename std::enable_if<is_registered_type<T>::value,
Vector<T>>::type Communicator::all_to_allv(
401 const Vector<T> &vec, std::vector<t_int>
const &send_sizes,
402 std::vector<t_int>
const &rec_sizes)
const {
404 if (send_sizes.size() == 1 and vec.size() != send_sizes.front())
405 throw std::runtime_error(
"Input vector size and sizes are inconsistent on root");
410 std::vector<int> ssizes_;
411 std::vector<int> sdispls;
412 for (
auto const size : send_sizes) {
413 ssizes_.push_back(
static_cast<int>(size));
414 sdispls.push_back(i);
418 std::vector<int> rsizes_;
419 std::vector<int> rdispls;
420 for (
auto const size : rec_sizes) {
421 rsizes_.push_back(
static_cast<int>(size));
422 rdispls.push_back(total);
426 MPI_Alltoallv(
const_cast<void *
>(
reinterpret_cast<const void *
>(vec.data())), ssizes_.data(),
427 sdispls.data(), registered_type(T(0)), output.data(), rsizes_.data(),
428 rdispls.data(), registered_type(T(0)), **
this);
432 template <
typename T>
433 typename std::enable_if<is_registered_type<T>::value, std::vector<T>>::type Communicator::gather(
434 T
const value,
t_uint const root)
const {
435 assert(root < size());
436 if (size() == 1)
return {value};
437 std::vector<T> result;
438 if (rank() == root) {
439 result.resize(size());
440 MPI_Gather(
const_cast<void *
>(
reinterpret_cast<const void *
>(&value)), 1,
441 registered_type(value), result.data(), 1, registered_type(value), root, **
this);
443 MPI_Gather(
const_cast<void *
>(
reinterpret_cast<const void *
>(&value)), 1,
444 registered_type(value),
nullptr, 1, registered_type(value), root, **
this);
448 template <
typename CONTAINER,
typename T>
449 CONTAINER Communicator::gather_(CONTAINER
const &vec, std::vector<t_int>
const &sizes,
450 t_uint const root)
const {
451 assert(root < size());
452 if (sizes.size() != size() and rank() == root)
453 throw std::runtime_error(
"Sizes and communicator size do not match on root");
454 else if (rank() != root and !sizes.empty() and sizes.size() != size())
455 throw std::runtime_error(
456 "Outside root, sizes should be either empty or match the number of procs");
457 else if (sizes.size() == size() and sizes[rank()] !=
static_cast<t_int>(vec.size()))
458 throw std::runtime_error(
"Sizes and input vector size do not match");
460 if (size() == 1)
return vec;
462 if (rank() != root)
return gather_<CONTAINER, T>(vec, root);
464 std::vector<int> sizes_;
465 std::vector<int> displs;
467 for (
auto const size : sizes) {
468 sizes_.push_back(
static_cast<int>(size));
469 displs.push_back(result_size);
472 CONTAINER result(result_size);
474 MPI_Gatherv(
const_cast<void *
>(
reinterpret_cast<const void *
>(vec.data())), sizes_[rank()],
475 mpi::Type<T>::value, result.data(), sizes_.data(), displs.data(), mpi::Type<T>::value,
480 template <
typename CONTAINER,
typename T>
481 CONTAINER Communicator::gather_(CONTAINER
const &vec,
t_uint const root)
const {
482 assert(root < size());
483 if (rank() == root)
throw std::runtime_error(
"Root should call the *other* gather");
485 MPI_Gatherv(
const_cast<void *
>(
reinterpret_cast<const void *
>(vec.data())), vec.size(),
486 mpi::Type<T>::value,
nullptr,
nullptr,
nullptr, mpi::Type<T>::value, root, **
this);
490 template <
typename T>
491 typename std::enable_if<is_registered_type<T>::value, std::set<T>>::type Communicator::gather(
492 std::set<T>
const &set, std::vector<t_int>
const &sizes,
t_uint const root)
const {
493 assert(root < size());
495 return gather(set, root);
496 else if (size() == 1)
499 assert(sizes.size() == size());
500 assert(sizes[root] == set.size());
502 std::copy(set.begin(), set.end(), buffer.data());
503 buffer = gather(buffer, sizes);
504 return std::set<T>(buffer.data(), buffer.data() + buffer.size());
507 template <
typename T>
508 typename std::enable_if<is_registered_type<T>::value, std::set<T>>::type Communicator::gather(
509 std::set<T>
const &set,
t_uint const root)
const {
510 assert(root < size());
511 if (rank() == root)
throw std::runtime_error(
"Root should call the *other* gather");
514 std::copy(set.begin(), set.end(), buffer.data());
516 return std::set<T>();
519 template <
typename T>
520 typename std::enable_if<is_registered_type<T>::value, T>::type Communicator::broadcast(
521 T
const &value,
t_uint const root)
const {
522 assert(root < size());
523 if (size() == 1)
return value;
524 if (not impl)
return value;
526 MPI_Bcast(&result, 1, registered_type(result), root, **
this);
530 template <
typename T>
531 typename std::enable_if<is_registered_type<T>::value, T>::type Communicator::broadcast(
532 t_uint const root)
const {
533 assert(root < size());
535 throw std::runtime_error(
"Root process should call the *other* broadcasting function");
537 MPI_Bcast(&result, 1, registered_type(result), root, **
this);
541 template <
typename T>
542 typename std::enable_if<is_registered_type<typename T::Scalar>::value, T>::type
543 Communicator::broadcast(T
const &vec,
t_uint const root)
const {
544 if (size() == 1)
return vec;
545 if (not impl)
return vec;
546 if (rank() != root)
return broadcast<T>(root);
547 assert(root < size());
548 auto const Nx = broadcast(vec.rows(), root);
549 auto const Ny = broadcast(vec.cols(), root);
550 MPI_Bcast(
const_cast<typename
T::Scalar *
>(vec.data()), Nx * Ny, Type<typename T::Scalar>::value,
554 template <
typename T>
555 typename std::enable_if<is_registered_type<typename T::Scalar>::value, T>::type
556 Communicator::broadcast(
t_uint const root)
const {
557 assert(root < size());
559 throw std::runtime_error(
"Root process should call the *other* broadcasting function");
560 auto const Nx = broadcast(decltype(std::declval<T>().
rows())(0), root);
561 auto const Ny = broadcast(decltype(std::declval<T>().
cols())(0), root);
563 MPI_Bcast(result.data(), result.size(), Type<typename T::Scalar>::value, root, **
this);
566 template <
typename T>
567 typename std::enable_if<is_registered_type<typename T::value_type>::value and
568 not std::is_base_of<Eigen::EigenBase<T>, T>::value,
570 Communicator::broadcast(T
const &vec,
t_uint const root)
const {
571 assert(root < size());
572 if (size() == 1)
return vec;
573 if (not impl)
return vec;
574 if (rank() != root)
return broadcast<T>(root);
575 auto const N = broadcast(vec.size(), root);
576 MPI_Bcast(
const_cast<typename T::value_type *
>(vec.data()),
N,
577 Type<typename T::value_type>::value, root, **
this);
580 template <
typename T>
581 typename std::enable_if<is_registered_type<typename T::value_type>::value and
582 not std::is_base_of<Eigen::EigenBase<T>, T>::value,
584 Communicator::broadcast(
t_uint const root)
const {
585 assert(root < size());
587 throw std::runtime_error(
"Root process should call the *other* broadcasting function");
588 auto const N = broadcast(decltype(std::declval<T>().size())(0), root);
590 MPI_Bcast(result.data(), result.size(), Type<typename T::value_type>::value, root, **
this);
std::vector< T > split(std::string s, const std::string &sep)
Split a string on a specified delimiter with optional cast to another type.
int t_int
Root of the type hierarchy for signed integers.
size_t t_uint
Root of the type hierarchy for unsigned integers.
sopt::Vector< Scalar > Vector
sopt::Matrix< Scalar > Matrix
sopt::Image< Scalar > Image