PURIFY
Next-generation radio interferometric imaging
mpi_utilities.cc
Go to the documentation of this file.
1 #include "purify/mpi_utilities.h"
2 #include "purify/config.h"
3 #include <iostream>
4 #include <type_traits>
5 #include "purify/distribute.h"
6 
7 namespace purify::utilities {
8 
9 void regroup(vis_params &uv_params, std::vector<t_int> const &groups_, const t_int max_groups) {
10  std::vector<t_int> image_index(uv_params.size(), 0);
11  regroup(uv_params, image_index, groups_, max_groups);
12 }
13 
14 void regroup(vis_params &uv_params, std::vector<t_int> &image_index,
15  std::vector<t_int> const &groups_, const t_int max_groups) {
16  std::vector<t_int> groups = groups_;
17  // Figure out size of each group
18  std::map<t_int, t_int> sizes;
19  for (auto g = 0; g < max_groups; g++) sizes[g] = 0;
20  for (auto const &group : groups) ++sizes[group];
21 
22  std::map<t_int, t_int> indices, ends;
23  auto i = 0;
24  for (auto const &item : sizes) {
25  indices[item.first] = i;
26  i += item.second;
27  ends[item.first] = i;
28  }
29  const auto minmax = std::minmax_element(ends.begin(), ends.end());
30  if (std::get<1>(*std::get<0>(minmax)) < 0)
31  throw std::runtime_error("segment end " + std::to_string(std::get<1>(*std::get<0>(minmax))) +
32  " less than 0. Not a valid end.");
33  if (std::get<1>(*std::get<1>(minmax)) > uv_params.size())
34  throw std::runtime_error("segment end " + std::to_string(std::get<1>(*std::get<1>(minmax))) +
35  " larger than data vector " +
36  std::to_string(static_cast<t_int>(uv_params.size())) +
37  ". End not valid.");
38 
39  auto const expected = [&ends](t_int i) {
40  t_int j = 0;
41  for (auto const end : ends) {
42  if (i < end.second) return j;
43  ++j;
44  }
45  return j;
46  };
47 
48  i = 0;
49  while (i < uv_params.u.size()) {
50  auto const expected_proc = expected(i);
51  if (groups[i] == expected_proc) {
52  ++i;
53  continue;
54  }
55  auto &swapper = indices[groups[i]];
56  if (groups[swapper] == expected(swapper)) {
57  ++swapper;
58  continue;
59  }
60  if (swapper >= uv_params.u.size())
61  throw std::runtime_error("regroup (groups, " + std::to_string(groups[swapper]) + ", " +
62  std::to_string(groups[i]) + ") index out of bounds " +
63  std::to_string(i) + " " + std::to_string(swapper) +
64  " >= " + std::to_string(uv_params.u.size()));
65  std::swap(groups[i], groups[swapper]);
66  std::swap(uv_params.u(i), uv_params.u(swapper));
67  std::swap(uv_params.v(i), uv_params.v(swapper));
68  std::swap(uv_params.w(i), uv_params.w(swapper));
69  std::swap(uv_params.vis(i), uv_params.vis(swapper));
70  std::swap(uv_params.weights(i), uv_params.weights(swapper));
71  std::swap(image_index[i], image_index[swapper]);
72 
73  ++swapper;
74  }
75 }
76 
77 vis_params regroup_and_scatter(vis_params const &params, std::vector<t_int> const &groups,
78  sopt::mpi::Communicator const &comm) {
79  if (comm.size() == 1) return params;
80  if (comm.rank() != comm.root_id()) return scatter_visibilities(comm);
81 
82  std::vector<t_int> sizes(comm.size());
83  std::fill(sizes.begin(), sizes.end(), 0);
84  for (auto const &group : groups) {
85  if (group > static_cast<t_int>(comm.size()))
86  throw std::out_of_range("groups should go from 0 to comm.size()");
87  ++sizes[group];
88  }
89 
90  vis_params copy = params;
91  regroup(copy, groups, comm.size());
92  return scatter_visibilities(copy, sizes, comm);
93 }
94 std::tuple<vis_params, std::vector<t_int>> regroup_and_all_to_all(
95  vis_params const &params, const std::vector<t_int> &image_index,
96  std::vector<t_int> const &groups, sopt::mpi::Communicator const &comm) {
97  if (comm.size() == 1) return std::make_tuple(params, image_index);
98  vis_params copy = params;
99  std::vector<t_int> index_copy(image_index.size());
100  std::copy(image_index.begin(), image_index.end(), index_copy.begin());
101  regroup(copy, index_copy, groups, comm.size());
102 
103  std::vector<t_int> sizes(comm.size());
104  std::fill(sizes.begin(), sizes.end(), 0);
105  for (const t_int &group : groups) {
106  if (group >= static_cast<t_int>(comm.size())) {
107  throw std::out_of_range("groups should go from 0 to comm.size()-1");
108  }
109  ++sizes[group];
110  }
111 
112  return std::make_tuple(all_to_all_visibilities(copy, sizes, comm),
113  comm.all_to_allv(index_copy, sizes));
114 }
115 
116 vis_params regroup_and_all_to_all(vis_params const &params, std::vector<t_int> const &groups,
117  sopt::mpi::Communicator const &comm) {
118  return std::get<0>(
119  regroup_and_all_to_all(params, std::vector<t_int>(params.size(), 0), groups, comm));
120 }
121 
122 vis_params all_to_all_visibilities(vis_params const &params, std::vector<t_int> const &sizes,
123  sopt::mpi::Communicator const &comm) {
124  if (comm.size() == 1) return params;
125  vis_params result;
126  result.u = comm.all_to_allv(params.u, sizes);
127  result.v = comm.all_to_allv(params.v, sizes);
128  result.w = comm.all_to_allv(params.w, sizes);
129  result.vis = comm.all_to_allv(params.vis, sizes);
130  result.weights = comm.all_to_allv(params.weights, sizes);
131  result.units = static_cast<utilities::vis_units>(comm.broadcast(static_cast<int>(params.units)));
132  result.ra = comm.broadcast(params.ra);
133  result.dec = comm.broadcast(params.dec);
134  result.average_frequency = comm.broadcast(params.average_frequency);
135  return result;
136 }
137 vis_params scatter_visibilities(vis_params const &params, std::vector<t_int> const &sizes,
138  sopt::mpi::Communicator const &comm) {
139  if (comm.size() == 1) return params;
140  if (not comm.is_root()) return scatter_visibilities(comm);
141 
142  comm.scatter_one(sizes);
143  vis_params result;
144  result.u = comm.scatterv(params.u, sizes);
145  result.v = comm.scatterv(params.v, sizes);
146  result.w = comm.scatterv(params.w, sizes);
147  result.vis = comm.scatterv(params.vis, sizes);
148  result.weights = comm.scatterv(params.weights, sizes);
149  result.units = static_cast<utilities::vis_units>(comm.broadcast(static_cast<int>(params.units)));
150  result.ra = comm.broadcast(params.ra);
151  result.dec = comm.broadcast(params.dec);
152  result.average_frequency = comm.broadcast(params.average_frequency);
153  return result;
154 }
155 
156 vis_params scatter_visibilities(sopt::mpi::Communicator const &comm) {
157  if (comm.is_root())
158  throw std::runtime_error("The root node should call the *other* scatter_visibilities function");
159 
160  auto const local_size = comm.scatter_one<t_int>();
161  vis_params result;
162  result.u = comm.scatterv<decltype(result.u)::Scalar>(local_size);
163  result.v = comm.scatterv<decltype(result.v)::Scalar>(local_size);
164  result.w = comm.scatterv<decltype(result.w)::Scalar>(local_size);
165  result.vis = comm.scatterv<decltype(result.vis)::Scalar>(local_size);
166  result.weights = comm.scatterv<decltype(result.weights)::Scalar>(local_size);
167  result.units = static_cast<utilities::vis_units>(
168  comm.broadcast<std::remove_const<decltype(static_cast<int>(result.units))>::type>());
169  result.ra = comm.broadcast<std::remove_const<decltype(result.ra)>::type>();
170  result.dec = comm.broadcast<std::remove_const<decltype(result.dec)>::type>();
171  result.average_frequency =
172  comm.broadcast<std::remove_const<decltype(result.average_frequency)>::type>();
173  return result;
174 }
175 
177  sopt::mpi::Communicator const &comm) {
178  if (comm.is_root() and comm.size() > 1) {
179  auto const order = distribute::distribute_measurements(params, comm, distribute::plan::radial);
180  return utilities::regroup_and_scatter(params, order, comm);
181  } else if (comm.size() > 1)
182  return utilities::scatter_visibilities(comm);
183  return params;
184 }
185 
186 utilities::vis_params set_cell_size(const sopt::mpi::Communicator &comm,
187  utilities::vis_params const &uv_vis, const t_real &cell_x,
188  const t_real &cell_y) {
189  if (comm.size() == 1) return utilities::set_cell_size(uv_vis, cell_x, cell_y);
190 
191  const t_real max_u = comm.all_reduce<t_real>(uv_vis.u.array().cwiseAbs().maxCoeff(), MPI_MAX);
192  const t_real max_v = comm.all_reduce<t_real>(uv_vis.v.array().cwiseAbs().maxCoeff(), MPI_MAX);
193  return utilities::set_cell_size(uv_vis, max_u, max_v, cell_x, cell_y);
194 }
196  sopt::mpi::Communicator const &comm, const t_int iters,
197  const std::function<t_real(t_real)> &cost,
198  const t_real k_means_rel_diff) {
199  const std::vector<t_int> image_index = std::get<0>(
200  distribute::kmeans_algo(params.w, comm.size(), iters, comm, cost, k_means_rel_diff));
201  return utilities::regroup_and_all_to_all(params, image_index, comm);
202 }
203 std::tuple<utilities::vis_params, std::vector<t_int>, std::vector<t_real>>
204 w_stacking_with_all_to_all(utilities::vis_params const &params, const t_real du,
205  const t_int min_support, const t_int max_support,
206  sopt::mpi::Communicator const &comm, const t_int iters,
207  const t_real fill_relaxation, const std::function<t_real(t_real)> &cost,
208  const t_real k_means_rel_diff) {
209  const auto kmeans =
210  distribute::kmeans_algo(params.w, comm.size(), iters, comm, cost, k_means_rel_diff);
211  const std::vector<t_real> &w_stacks = std::get<1>(kmeans);
212  const std::vector<t_int> groups = distribute::w_support(
213  params.w, std::get<0>(kmeans), w_stacks, du, min_support, max_support, fill_relaxation, comm);
214  utilities::vis_params outdata;
215  std::vector<t_int> image_index;
216  std::tie(outdata, image_index) =
217  utilities::regroup_and_all_to_all(params, std::get<0>(kmeans), groups, comm);
218  return std::tuple<utilities::vis_params, std::vector<t_int>, std::vector<t_real>>(
219  outdata, image_index, w_stacks);
220 }
221 } // namespace purify::utilities
std::tuple< std::vector< t_int >, std::vector< t_real > > kmeans_algo(const Vector< t_real > &w, const t_int number_of_nodes, const t_int iters, const std::function< t_real(t_real)> &cost, const t_real rel_diff)
partition w terms using k-means
Definition: distribute.cc:103
std::vector< t_int > distribute_measurements(Vector< t_real > const &u, Vector< t_real > const &v, Vector< t_real > const &w, t_int const number_of_nodes, distribute::plan const distribution_plan, t_int const &grid_size)
Distribute visibilities into groups.
Definition: distribute.cc:6
std::tuple< vis_params, std::vector< t_int > > regroup_and_all_to_all(vis_params const &params, const std::vector< t_int > &image_index, std::vector< t_int > const &groups, sopt::mpi::Communicator const &comm)
utilities::vis_params distribute_params(utilities::vis_params const &params, sopt::mpi::Communicator const &comm)
utilities::vis_params w_stacking(utilities::vis_params const &params, sopt::mpi::Communicator const &comm, const t_int iters, const std::function< t_real(t_real)> &cost, const t_real k_means_rel_diff)
void regroup(vis_params &uv_params, std::vector< t_int > const &groups_, const t_int max_groups)
Definition: mpi_utilities.cc:9
vis_params all_to_all_visibilities(vis_params const &params, std::vector< t_int > const &sizes, sopt::mpi::Communicator const &comm)
std::tuple< utilities::vis_params, std::vector< t_int >, std::vector< t_real > > w_stacking_with_all_to_all(utilities::vis_params const &params, const t_real du, const t_int min_support, const t_int max_support, sopt::mpi::Communicator const &comm, const t_int iters, const t_real fill_relaxation, const std::function< t_real(t_real)> &cost, const t_real k_means_rel_diff)
vis_params scatter_visibilities(vis_params const &params, std::vector< t_int > const &sizes, sopt::mpi::Communicator const &comm)
vis_params regroup_and_scatter(vis_params const &params, std::vector< t_int > const &groups, sopt::mpi::Communicator const &comm)
utilities::vis_params set_cell_size(const sopt::mpi::Communicator &comm, utilities::vis_params const &uv_vis, const t_real &cell_x, const t_real &cell_y)
t_int w_support(const t_real w, const t_real du, const t_int min, const t_int max)
estimate support size of w given u resolution du
Vector< t_complex > vis
Definition: uvw_utilities.h:22
t_uint size() const
return number of measurements
Definition: uvw_utilities.h:54
Vector< t_complex > weights
Definition: uvw_utilities.h:23