SOPT
Sparse OPTimisation
communicator.h
Go to the documentation of this file.
1 #ifndef SOPT_MPI_COMMUNICATOR_H
2 #define SOPT_MPI_COMMUNICATOR_H
3 
4 #include "sopt/config.h"
5 #include "sopt/mpi/session.h"
6 
7 #ifdef SOPT_MPI
8 
9 #include <algorithm> // for std::copy
10 #include <iostream>
11 #include <memory>
12 #include <mpi.h>
13 #include <set>
14 #include <string>
15 #include <type_traits>
16 #include <vector>
18 #include "sopt/types.h"
19 
20 #include <cxxabi.h>
21 #include <typeinfo>
22 
23 namespace sopt::mpi {
24 
32 class Communicator {
34  struct Impl {
36  MPI_Comm comm;
38  t_uint size;
40  t_uint rank;
41  };
42 
43  public:
45  Communicator() : impl(), session() {}
46 
47  static Communicator World() { return Communicator(MPI_COMM_WORLD); }
48  static Communicator Self() { return Communicator(MPI_COMM_SELF); }
49 
50  virtual ~Communicator(){}
51 
53  decltype(Impl::size) size() const { return impl ? impl->size : 1; }
55  decltype(Impl::rank) rank() const { return impl ? impl->rank : 0; }
57  decltype(Impl::comm) operator*() const {
58  if (not impl) throw std::runtime_error("Communicator was not set");
59  return impl->comm;
60  }
62  static constexpr t_uint root_id() { return 0; }
64  bool is_root() const { return rank() == root_id(); }
67  Communicator duplicate() const;
69  Communicator clone() const { return duplicate(); }
71  void abort(const std::string &reason) const;
72 
74  template <typename T>
75  typename std::enable_if<is_registered_type<T>::value, T>::type all_reduce(T const &value,
76  MPI_Op operation) const;
77 
79  template <typename T>
80  typename std::enable_if<is_registered_type<T>::value>::type all_reduce(Matrix<T> &image,
81  MPI_Op operation) const {
82  if (size() == 1) return;
83  assert(impl and image.size() and image.data());
84  MPI_Allreduce(MPI_IN_PLACE, image.data(), image.size(), registered_type(T(0)), operation,
85  **this);
86  }
87  template <typename T>
88  typename std::enable_if<is_registered_type<T>::value>::type all_reduce(Image<T> &image,
89  MPI_Op operation) const {
90  if (size() == 1) return;
91  assert(impl and image.size() and image.data());
92  MPI_Allreduce(MPI_IN_PLACE, image.data(), image.size(), registered_type(T(0)), operation,
93  **this);
94  }
95  template <typename T>
96  typename std::enable_if<is_registered_type<T>::value>::type all_reduce(Vector<T> &image,
97  MPI_Op operation) const {
98  if (size() == 1) return;
99  assert(impl and image.size() and image.data());
100  MPI_Allreduce(MPI_IN_PLACE, image.data(), image.size(), registered_type(T(0)), operation,
101  **this);
102  }
103 
105  template <typename T>
106  typename std::enable_if<is_registered_type<T>::value, T>::type all_sum_all(T const &value) const {
107  return all_reduce(value, MPI_SUM);
108  }
109  template <typename T>
110  typename std::enable_if<is_registered_type<typename T::Scalar>::value>::type all_sum_all(
111  T &image) const {
112  all_reduce(image, MPI_SUM);
113  }
114  template <typename T>
115  typename std::enable_if<is_registered_type<typename T::Scalar>::value, T>::type all_sum_all(
116  T const &image) const {
117  T result(image);
118  all_reduce(result, MPI_SUM);
119  return result;
120  }
121 
123  template <typename T>
124  typename std::enable_if<is_registered_type<T>::value, T>::type broadcast(
125  T const &value, t_uint const root = root_id()) const;
127  template <typename T>
128  typename std::enable_if<is_registered_type<T>::value, T>::type broadcast(
129  t_uint const root = root_id()) const;
130  template <typename T>
131  typename std::enable_if<is_registered_type<typename T::Scalar>::value, T>::type broadcast(
132  T const &vec, t_uint const root = root_id()) const;
133  template <typename T>
134  typename std::enable_if<is_registered_type<typename T::Scalar>::value, T>::type broadcast(
135  t_uint const root = root_id()) const;
136  template <typename T>
137  typename std::enable_if<is_registered_type<typename T::value_type>::value and
138  not std::is_base_of<Eigen::EigenBase<T>, T>::value,
139  T>::type
140  broadcast(T const &vec, t_uint const root = root_id()) const;
141  template <typename T>
142  typename std::enable_if<is_registered_type<typename T::value_type>::value and
143  not std::is_base_of<Eigen::EigenBase<T>, T>::value,
144  T>::type
145  broadcast(t_uint const root = root_id()) const;
146  std::string broadcast(std::string const &input, t_uint const root = root_id()) const;
147 
149  template <typename T>
150  typename std::enable_if<is_registered_type<T>::value, T>::type scatter_one(
151  std::vector<T> const &values, t_uint const root = root_id()) const;
153  template <typename T>
154  typename std::enable_if<is_registered_type<T>::value, T>::type scatter_one(
155  t_uint const root = root_id()) const;
156 
158  template <typename T>
159  typename std::enable_if<is_registered_type<T>::value, Vector<T>>::type scatterv(
160  Vector<T> const &vec, std::vector<t_int> const &sizes, t_uint const root = root_id()) const;
161  template <typename T>
162  typename std::enable_if<is_registered_type<T>::value, Vector<T>>::type scatterv(
163  t_int local_size, t_uint const root = root_id()) const;
164 
165  // Gather one object per proc
166  template <typename T>
167  typename std::enable_if<is_registered_type<T>::value, std::vector<T>>::type gather(
168  T const value, t_uint const root = root_id()) const;
169 
171  template <typename T>
172  typename std::enable_if<is_registered_type<T>::value, Vector<T>>::type gather(
173  Vector<T> const &vec, std::vector<t_int> const &sizes, t_uint const root = root_id()) const {
174  return gather_<Vector<T>, T>(vec, sizes, root);
175  }
176  template <typename T>
177  typename std::enable_if<is_registered_type<T>::value, Vector<T>>::type gather(
178  Vector<T> const &vec, t_uint const root = root_id()) const {
179  return gather_<Vector<T>, T>(vec, root);
180  }
181 
182  template <typename T>
183  typename std::enable_if<is_registered_type<T>::value, std::set<T>>::type gather(
184  std::set<T> const &set, std::vector<t_int> const &sizes, t_uint const root = root_id()) const;
185  template <typename T>
186  typename std::enable_if<is_registered_type<T>::value, std::set<T>>::type gather(
187  std::set<T> const &vec, t_uint const root = root_id()) const;
188  template <typename T>
189  typename std::enable_if<is_registered_type<T>::value, std::vector<T>>::type gather(
190  std::vector<T> const &vec, std::vector<t_int> const &sizes,
191  t_uint const root = root_id()) const {
192  return gather_<std::vector<T>, T>(vec, sizes, root);
193  }
194  template <typename T>
195  typename std::enable_if<is_registered_type<T>::value, std::vector<T>>::type gather(
196  std::vector<T> const &vec, t_uint const root = root_id()) const {
197  return gather_<std::vector<T>, T>(vec, root);
198  }
200  template <typename T>
201  typename std::enable_if<is_registered_type<T>::value, Vector<T>>::type all_to_allv(
202  const Vector<T> &vec, std::vector<t_int> const &send_sizes) const;
203  template <typename T>
204  typename std::enable_if<is_registered_type<T>::value, Vector<T>>::type all_to_allv(
205  const Vector<T> &vec, std::vector<t_int> const &send_sizes,
206  std::vector<t_int> const &rec_sizes) const;
207  template <typename T>
208  typename std::enable_if<is_registered_type<T>::value, std::vector<T>>::type all_to_allv(
209  const std::vector<T> &vec, std::vector<t_int> const &send_sizes,
210  std::vector<t_int> const &rec_sizes) const;
211  template <typename T>
212  typename std::enable_if<is_registered_type<T>::value, std::vector<T>>::type all_to_allv(
213  const std::vector<T> &vec, std::vector<t_int> const &send_sizes) const;
214 
216  Communicator split(t_int color) const { return split(color, rank()); }
218  Communicator split(t_int color, t_uint rank) const {
219  MPI_Comm comm;
220  MPI_Comm_split(**this, color, static_cast<t_int>(rank), &comm);
221  return comm;
222  }
223 
225  void barrier() const {
226  if (not impl) return;
227  if (MPI_Barrier(**this) != MPI_SUCCESS)
228  throw std::runtime_error("Encountered error in mpi barrier");
229  }
230 
231  private:
233  std::shared_ptr<Impl const> impl;
236  std::shared_ptr<mpi::details::initializer> session;
237 
239  static void delete_comm(Impl *impl);
240 
246  Communicator(MPI_Comm const &comm);
247 
249  template <typename CONTAINER, typename T>
250  CONTAINER gather_(CONTAINER const &vec, std::vector<t_int> const &sizes,
251  t_uint const root = root_id()) const;
252  template <typename CONTAINER, typename T>
253  CONTAINER gather_(CONTAINER const &vec, t_uint const root = root_id()) const;
254 };
255 
256 template <typename T>
257 typename std::enable_if<is_registered_type<T>::value, T>::type Communicator::all_reduce(
258  T const &value, MPI_Op operation) const {
259  if (size() == 1) return value;
260  assert(impl);
261  T result;
262  MPI_Allreduce(const_cast<void *>(reinterpret_cast<const void *>(&value)), &result, 1,
263  registered_type(value), operation, **this);
264  return result;
265 }
266 
267 template <typename T>
268 typename std::enable_if<is_registered_type<T>::value, T>::type Communicator::scatter_one(
269  std::vector<T> const &values, t_uint const root) const {
270  assert(root < size());
271  if (values.size() != size()) throw std::runtime_error("Expected a single object per process");
272  if (size() == 1) return values.at(0);
273  T result;
274  MPI_Scatter(const_cast<void *>(reinterpret_cast<const void *>(values.data())), 1,
275  registered_type(result), &result, 1, registered_type(result), root, **this);
276  return result;
277 }
279 template <typename T>
280 typename std::enable_if<is_registered_type<T>::value, T>::type Communicator::scatter_one(
281  t_uint const root) const {
282  T result;
283  MPI_Scatter(nullptr, 1, registered_type(result), &result, 1, registered_type(result), root,
284  **this);
285  return result;
286 }
287 
288 template <typename T>
289 typename std::enable_if<is_registered_type<T>::value, Vector<T>>::type Communicator::scatterv(
290  Vector<T> const &vec, std::vector<t_int> const &sizes, t_uint const root) const {
291  if (size() == 1) {
292  if (sizes.size() == 1 and vec.size() != sizes.front())
293  throw std::runtime_error("Input vector size and sizes are inconsistent on root");
294  return vec;
295  }
296  if (rank() != root) return scatterv<T>(sizes.at(rank()), root);
297  std::vector<int> sizes_;
298  std::vector<int> displs;
299  int i = 0;
300  for (auto const size : sizes) {
301  sizes_.push_back(static_cast<int>(size));
302  displs.push_back(i);
303  i += size;
304  }
305  if (vec.size() != i) throw std::runtime_error("Input vector size and sizes are inconsistent");
306 
307  Vector<T> result(sizes[rank()]);
308  if (not impl)
309  result = vec.head(sizes[rank()]);
310  else
311  MPI_Scatterv(const_cast<void *>(reinterpret_cast<const void *>(vec.data())), sizes_.data(),
312  displs.data(), registered_type(T(0)), result.data(), sizes_[rank()],
313  registered_type(T(0)), root, **this);
314  return result;
315 }
316 
317 template <typename T>
318 typename std::enable_if<is_registered_type<T>::value, Vector<T>>::type Communicator::scatterv(
319  t_int local_size, t_uint const root) const {
320  if (rank() == root) throw std::runtime_error("Root should call the *other* scatterv");
321  std::vector<int> sizes(size());
322  sizes[rank()] = local_size;
323  Vector<T> result(sizes[rank()]);
324  MPI_Scatterv(nullptr, sizes.data(), nullptr, registered_type(T(0)), result.data(), local_size,
325  registered_type(T(0)), root, **this);
326  return result;
327 }
328 
329 template <typename T>
330 typename std::enable_if<is_registered_type<T>::value, std::vector<T>>::type
331 Communicator::all_to_allv(const std::vector<T> &vec, std::vector<t_int> const &send_sizes) const {
332  if (size() == 1) {
333  if (send_sizes.size() == 1 and vec.size() != send_sizes.front())
334  throw std::runtime_error("Input vector size and sizes are inconsistent on root");
335  return vec;
336  }
337  std::vector<t_int> rec_sizes(send_sizes.size(), 0);
338  for (t_int i = 0; i < size(); i++) {
339  if (i == rank())
340  rec_sizes = gather<t_int>(send_sizes[i], i);
341  else
342  gather<t_int>(send_sizes[i], i);
343  }
344 
345  return all_to_allv<T>(vec, send_sizes, rec_sizes);
346 }
347 template <typename T>
348 typename std::enable_if<is_registered_type<T>::value, std::vector<T>>::type
349 Communicator::all_to_allv(const std::vector<T> &vec, std::vector<t_int> const &send_sizes,
350  std::vector<t_int> const &rec_sizes) const {
351  if (size() == 1) {
352  if (send_sizes.size() == 1 and vec.size() != send_sizes.front())
353  throw std::runtime_error("Input vector size and sizes are inconsistent on root");
354  return vec;
355  }
356 
357  int i = 0;
358  std::vector<int> ssizes_;
359  std::vector<int> sdispls;
360  for (auto const size : send_sizes) {
361  ssizes_.push_back(static_cast<int>(size));
362  sdispls.push_back(i);
363  i += size;
364  }
365  int total = 0;
366  std::vector<int> rsizes_;
367  std::vector<int> rdispls;
368  for (auto const size : rec_sizes) {
369  rsizes_.push_back(static_cast<int>(size));
370  rdispls.push_back(total);
371  total += size;
372  }
373  std::vector<T> output = std::vector<T>(total, 0);
374  MPI_Alltoallv(const_cast<void *>(reinterpret_cast<const void *>(vec.data())), ssizes_.data(),
375  sdispls.data(), registered_type(T(0)), output.data(), rsizes_.data(),
376  rdispls.data(), registered_type(T(0)), **this);
377  return output;
378 }
379 
380 template <typename T>
381 typename std::enable_if<is_registered_type<T>::value, Vector<T>>::type Communicator::all_to_allv(
382  const Vector<T> &vec, std::vector<t_int> const &send_sizes) const {
383  if (size() == 1) {
384  if (send_sizes.size() == 1 and vec.size() != send_sizes.front())
385  throw std::runtime_error("Input vector size and sizes are inconsistent on root");
386  return vec;
387  }
388  std::vector<t_int> rec_sizes(send_sizes.size(), 0);
389  for (t_int i = 0; i < size(); i++) {
390  if (i == rank())
391  rec_sizes = gather<t_int>(send_sizes[i], i);
392  else
393  gather<t_int>(send_sizes[i], i);
394  }
395 
396  return all_to_allv<T>(vec, send_sizes, rec_sizes);
397 }
398 
399 template <typename T>
400 typename std::enable_if<is_registered_type<T>::value, Vector<T>>::type Communicator::all_to_allv(
401  const Vector<T> &vec, std::vector<t_int> const &send_sizes,
402  std::vector<t_int> const &rec_sizes) const {
403  if (size() == 1) {
404  if (send_sizes.size() == 1 and vec.size() != send_sizes.front())
405  throw std::runtime_error("Input vector size and sizes are inconsistent on root");
406  return vec;
407  }
408 
409  int i = 0;
410  std::vector<int> ssizes_;
411  std::vector<int> sdispls;
412  for (auto const size : send_sizes) {
413  ssizes_.push_back(static_cast<int>(size));
414  sdispls.push_back(i);
415  i += size;
416  }
417  int total = 0;
418  std::vector<int> rsizes_;
419  std::vector<int> rdispls;
420  for (auto const size : rec_sizes) {
421  rsizes_.push_back(static_cast<int>(size));
422  rdispls.push_back(total);
423  total += size;
424  }
425  Vector<T> output = Vector<T>::Zero(total);
426  MPI_Alltoallv(const_cast<void *>(reinterpret_cast<const void *>(vec.data())), ssizes_.data(),
427  sdispls.data(), registered_type(T(0)), output.data(), rsizes_.data(),
428  rdispls.data(), registered_type(T(0)), **this);
429  return output;
430 }
431 
432 template <typename T>
433 typename std::enable_if<is_registered_type<T>::value, std::vector<T>>::type Communicator::gather(
434  T const value, t_uint const root) const {
435  assert(root < size());
436  if (size() == 1) return {value};
437  std::vector<T> result;
438  if (rank() == root) {
439  result.resize(size());
440  MPI_Gather(const_cast<void *>(reinterpret_cast<const void *>(&value)), 1,
441  registered_type(value), result.data(), 1, registered_type(value), root, **this);
442  } else
443  MPI_Gather(const_cast<void *>(reinterpret_cast<const void *>(&value)), 1,
444  registered_type(value), nullptr, 1, registered_type(value), root, **this);
445  return result;
446 }
447 
448 template <typename CONTAINER, typename T>
449 CONTAINER Communicator::gather_(CONTAINER const &vec, std::vector<t_int> const &sizes,
450  t_uint const root) const {
451  assert(root < size());
452  if (sizes.size() != size() and rank() == root)
453  throw std::runtime_error("Sizes and communicator size do not match on root");
454  else if (rank() != root and !sizes.empty() and sizes.size() != size())
455  throw std::runtime_error(
456  "Outside root, sizes should be either empty or match the number of procs");
457  else if (sizes.size() == size() and sizes[rank()] != static_cast<t_int>(vec.size()))
458  throw std::runtime_error("Sizes and input vector size do not match");
459 
460  if (size() == 1) return vec;
461 
462  if (rank() != root) return gather_<CONTAINER, T>(vec, root);
463 
464  std::vector<int> sizes_;
465  std::vector<int> displs;
466  int result_size = 0;
467  for (auto const size : sizes) {
468  sizes_.push_back(static_cast<int>(size));
469  displs.push_back(result_size);
470  result_size += size;
471  }
472  CONTAINER result(result_size);
473 
474  MPI_Gatherv(const_cast<void *>(reinterpret_cast<const void *>(vec.data())), sizes_[rank()],
475  mpi::Type<T>::value, result.data(), sizes_.data(), displs.data(), mpi::Type<T>::value,
476  root, **this);
477  return result;
478 }
479 
480 template <typename CONTAINER, typename T>
481 CONTAINER Communicator::gather_(CONTAINER const &vec, t_uint const root) const {
482  assert(root < size());
483  if (rank() == root) throw std::runtime_error("Root should call the *other* gather");
484 
485  MPI_Gatherv(const_cast<void *>(reinterpret_cast<const void *>(vec.data())), vec.size(),
486  mpi::Type<T>::value, nullptr, nullptr, nullptr, mpi::Type<T>::value, root, **this);
487  return CONTAINER();
488 }
489 
490 template <typename T>
491 typename std::enable_if<is_registered_type<T>::value, std::set<T>>::type Communicator::gather(
492  std::set<T> const &set, std::vector<t_int> const &sizes, t_uint const root) const {
493  assert(root < size());
494  if (rank() != root)
495  return gather(set, root);
496  else if (size() == 1)
497  return set;
498 
499  assert(sizes.size() == size());
500  assert(sizes[root] == set.size());
501  Vector<T> buffer(set.size());
502  std::copy(set.begin(), set.end(), buffer.data());
503  buffer = gather(buffer, sizes);
504  return std::set<T>(buffer.data(), buffer.data() + buffer.size());
505 }
506 
507 template <typename T>
508 typename std::enable_if<is_registered_type<T>::value, std::set<T>>::type Communicator::gather(
509  std::set<T> const &set, t_uint const root) const {
510  assert(root < size());
511  if (rank() == root) throw std::runtime_error("Root should call the *other* gather");
512 
513  Vector<T> buffer(set.size());
514  std::copy(set.begin(), set.end(), buffer.data());
515  gather(buffer);
516  return std::set<T>();
517 }
518 
519 template <typename T>
520 typename std::enable_if<is_registered_type<T>::value, T>::type Communicator::broadcast(
521  T const &value, t_uint const root) const {
522  assert(root < size());
523  if (size() == 1) return value;
524  if (not impl) return value;
525  auto result = value;
526  MPI_Bcast(&result, 1, registered_type(result), root, **this);
527  return result;
528 }
530 template <typename T>
531 typename std::enable_if<is_registered_type<T>::value, T>::type Communicator::broadcast(
532  t_uint const root) const {
533  assert(root < size());
534  if (root == rank())
535  throw std::runtime_error("Root process should call the *other* broadcasting function");
536  T result;
537  MPI_Bcast(&result, 1, registered_type(result), root, **this);
538  return result;
539 }
540 
541 template <typename T>
542 typename std::enable_if<is_registered_type<typename T::Scalar>::value, T>::type
543 Communicator::broadcast(T const &vec, t_uint const root) const {
544  if (size() == 1) return vec;
545  if (not impl) return vec;
546  if (rank() != root) return broadcast<T>(root);
547  assert(root < size());
548  auto const Nx = broadcast(vec.rows(), root);
549  auto const Ny = broadcast(vec.cols(), root);
550  MPI_Bcast(const_cast<typename T::Scalar *>(vec.data()), Nx * Ny, Type<typename T::Scalar>::value,
551  root, **this);
552  return vec;
553 }
554 template <typename T>
555 typename std::enable_if<is_registered_type<typename T::Scalar>::value, T>::type
556 Communicator::broadcast(t_uint const root) const {
557  assert(root < size());
558  if (root == rank())
559  throw std::runtime_error("Root process should call the *other* broadcasting function");
560  auto const Nx = broadcast(decltype(std::declval<T>().rows())(0), root);
561  auto const Ny = broadcast(decltype(std::declval<T>().cols())(0), root);
562  T result(Nx, Ny);
563  MPI_Bcast(result.data(), result.size(), Type<typename T::Scalar>::value, root, **this);
564  return result;
565 }
566 template <typename T>
567 typename std::enable_if<is_registered_type<typename T::value_type>::value and
568  not std::is_base_of<Eigen::EigenBase<T>, T>::value,
569  T>::type
570 Communicator::broadcast(T const &vec, t_uint const root) const {
571  assert(root < size());
572  if (size() == 1) return vec;
573  if (not impl) return vec;
574  if (rank() != root) return broadcast<T>(root);
575  auto const N = broadcast(vec.size(), root);
576  MPI_Bcast(const_cast<typename T::value_type *>(vec.data()), N,
577  Type<typename T::value_type>::value, root, **this);
578  return vec;
579 }
580 template <typename T>
581 typename std::enable_if<is_registered_type<typename T::value_type>::value and
582  not std::is_base_of<Eigen::EigenBase<T>, T>::value,
583  T>::type
584 Communicator::broadcast(t_uint const root) const {
585  assert(root < size());
586  if (root == rank())
587  throw std::runtime_error("Root process should call the *other* broadcasting function");
588  auto const N = broadcast(decltype(std::declval<T>().size())(0), root);
589  T result(N);
590  MPI_Bcast(result.data(), result.size(), Type<typename T::value_type>::value, root, **this);
591  return result;
592 }
593 } // namespace sopt::mpi
594 #endif /* ifdef SOPT_MPI */
595 #endif /* ifndef SOPT_MPI_COMMUNICATOR */
constexpr auto N
Definition: wavelets.cc:57
sopt::t_real Scalar
t_uint rows
t_uint cols
std::vector< T > split(std::string s, const std::string &sep)
Split a string on a specified delimiter with optional cast to another type.
Definition: utilities.h:38
int t_int
Root of the type hierarchy for signed integers.
Definition: types.h:13
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
sopt::Vector< Scalar > Vector
Definition: inpainting.cc:28
sopt::Matrix< Scalar > Matrix
Definition: inpainting.cc:29
sopt::Image< Scalar > Image
Definition: inpainting.cc:30