2 #include "purify/config.h"
9 void regroup(
vis_params &uv_params, std::vector<t_int>
const &groups_,
const t_int max_groups) {
10 std::vector<t_int> image_index(uv_params.
size(), 0);
11 regroup(uv_params, image_index, groups_, max_groups);
15 std::vector<t_int>
const &groups_,
const t_int max_groups) {
16 std::vector<t_int> groups = groups_;
18 std::map<t_int, t_int> sizes;
19 for (
auto g = 0; g < max_groups; g++) sizes[g] = 0;
20 for (
auto const &group : groups) ++sizes[group];
22 std::map<t_int, t_int> indices, ends;
24 for (
auto const &item : sizes) {
25 indices[item.first] = i;
29 const auto minmax = std::minmax_element(ends.begin(), ends.end());
30 if (std::get<1>(*std::get<0>(minmax)) < 0)
31 throw std::runtime_error(
"segment end " + std::to_string(std::get<1>(*std::get<0>(minmax))) +
32 " less than 0. Not a valid end.");
33 if (std::get<1>(*std::get<1>(minmax)) > uv_params.
size())
34 throw std::runtime_error(
"segment end " + std::to_string(std::get<1>(*std::get<1>(minmax))) +
35 " larger than data vector " +
36 std::to_string(
static_cast<t_int
>(uv_params.
size())) +
39 auto const expected = [&ends](t_int i) {
41 for (
auto const end : ends) {
42 if (i < end.second)
return j;
49 while (i < uv_params.
u.size()) {
50 auto const expected_proc = expected(i);
51 if (groups[i] == expected_proc) {
55 auto &swapper = indices[groups[i]];
56 if (groups[swapper] == expected(swapper)) {
60 if (swapper >= uv_params.
u.size())
61 throw std::runtime_error(
"regroup (groups, " + std::to_string(groups[swapper]) +
", " +
62 std::to_string(groups[i]) +
") index out of bounds " +
63 std::to_string(i) +
" " + std::to_string(swapper) +
64 " >= " + std::to_string(uv_params.
u.size()));
65 std::swap(groups[i], groups[swapper]);
66 std::swap(uv_params.
u(i), uv_params.
u(swapper));
67 std::swap(uv_params.
v(i), uv_params.
v(swapper));
68 std::swap(uv_params.
w(i), uv_params.
w(swapper));
69 std::swap(uv_params.
vis(i), uv_params.
vis(swapper));
71 std::swap(image_index[i], image_index[swapper]);
78 sopt::mpi::Communicator
const &comm) {
79 if (comm.size() == 1)
return params;
82 std::vector<t_int> sizes(comm.size());
83 std::fill(sizes.begin(), sizes.end(), 0);
84 for (
auto const &group : groups) {
85 if (group >
static_cast<t_int
>(comm.size()))
86 throw std::out_of_range(
"groups should go from 0 to comm.size()");
95 vis_params const ¶ms,
const std::vector<t_int> &image_index,
96 std::vector<t_int>
const &groups, sopt::mpi::Communicator
const &comm) {
97 if (comm.size() == 1)
return std::make_tuple(params, image_index);
99 std::vector<t_int> index_copy(image_index.size());
100 std::copy(image_index.begin(), image_index.end(), index_copy.begin());
103 std::vector<t_int> sizes(comm.size());
104 std::fill(sizes.begin(), sizes.end(), 0);
105 for (
const t_int &group : groups) {
106 if (group >=
static_cast<t_int
>(comm.size())) {
107 throw std::out_of_range(
"groups should go from 0 to comm.size()-1");
113 comm.all_to_allv(index_copy, sizes));
117 sopt::mpi::Communicator
const &comm) {
123 sopt::mpi::Communicator
const &comm) {
124 if (comm.size() == 1)
return params;
126 result.
u = comm.all_to_allv(params.
u, sizes);
127 result.
v = comm.all_to_allv(params.
v, sizes);
128 result.
w = comm.all_to_allv(params.
w, sizes);
129 result.
vis = comm.all_to_allv(params.
vis, sizes);
132 result.
ra = comm.broadcast(params.
ra);
133 result.
dec = comm.broadcast(params.
dec);
138 sopt::mpi::Communicator
const &comm) {
139 if (comm.size() == 1)
return params;
142 comm.scatter_one(sizes);
144 result.
u = comm.scatterv(params.
u, sizes);
145 result.
v = comm.scatterv(params.
v, sizes);
146 result.
w = comm.scatterv(params.
w, sizes);
147 result.
vis = comm.scatterv(params.
vis, sizes);
150 result.
ra = comm.broadcast(params.
ra);
151 result.
dec = comm.broadcast(params.
dec);
158 throw std::runtime_error(
"The root node should call the *other* scatter_visibilities function");
160 auto const local_size = comm.scatter_one<t_int>();
162 result.
u = comm.scatterv<decltype(result.u)::Scalar>(local_size);
163 result.v = comm.scatterv<decltype(result.v)::Scalar>(local_size);
164 result.w = comm.scatterv<decltype(result.w)::Scalar>(local_size);
165 result.vis = comm.scatterv<decltype(result.vis)::Scalar>(local_size);
166 result.weights = comm.scatterv<decltype(result.weights)::Scalar>(local_size);
168 comm.broadcast<std::remove_const<decltype(static_cast<int>(result.units))>::type>());
169 result.ra = comm.broadcast<std::remove_const<decltype(result.ra)>::type>();
170 result.dec = comm.broadcast<std::remove_const<decltype(result.dec)>::type>();
171 result.average_frequency =
172 comm.broadcast<std::remove_const<decltype(result.average_frequency)>::type>();
177 sopt::mpi::Communicator
const &comm) {
178 if (comm.is_root() and comm.size() > 1) {
181 }
else if (comm.size() > 1)
188 const t_real &cell_y) {
191 const t_real max_u = comm.all_reduce<t_real>(uv_vis.
u.array().cwiseAbs().maxCoeff(), MPI_MAX);
192 const t_real max_v = comm.all_reduce<t_real>(uv_vis.
v.array().cwiseAbs().maxCoeff(), MPI_MAX);
196 sopt::mpi::Communicator
const &comm,
const t_int iters,
197 const std::function<t_real(t_real)> &cost,
198 const t_real k_means_rel_diff) {
199 const std::vector<t_int> image_index = std::get<0>(
203 std::tuple<utilities::vis_params, std::vector<t_int>, std::vector<t_real>>
205 const t_int min_support,
const t_int max_support,
206 sopt::mpi::Communicator
const &comm,
const t_int iters,
207 const t_real fill_relaxation,
const std::function<t_real(t_real)> &cost,
208 const t_real k_means_rel_diff) {
211 const std::vector<t_real> &w_stacks = std::get<1>(kmeans);
213 params.
w, std::get<0>(kmeans), w_stacks, du, min_support, max_support, fill_relaxation, comm);
215 std::vector<t_int> image_index;
216 std::tie(outdata, image_index) =
218 return std::tuple<utilities::vis_params, std::vector<t_int>, std::vector<t_real>>(
219 outdata, image_index, w_stacks);
std::tuple< std::vector< t_int >, std::vector< t_real > > kmeans_algo(const Vector< t_real > &w, const t_int number_of_nodes, const t_int iters, const std::function< t_real(t_real)> &cost, const t_real rel_diff)
Partition w terms using k-means.
std::vector< t_int > distribute_measurements(Vector< t_real > const &u, Vector< t_real > const &v, Vector< t_real > const &w, t_int const number_of_nodes, distribute::plan const distribution_plan, t_int const &grid_size)
Distribute visibilities into groups.
std::tuple< vis_params, std::vector< t_int > > regroup_and_all_to_all(vis_params const &params, const std::vector< t_int > &image_index, std::vector< t_int > const &groups, sopt::mpi::Communicator const &comm)
utilities::vis_params distribute_params(utilities::vis_params const &params, sopt::mpi::Communicator const &comm)
utilities::vis_params w_stacking(utilities::vis_params const &params, sopt::mpi::Communicator const &comm, const t_int iters, const std::function< t_real(t_real)> &cost, const t_real k_means_rel_diff)
void regroup(vis_params &uv_params, std::vector< t_int > const &groups_, const t_int max_groups)
vis_params all_to_all_visibilities(vis_params const &params, std::vector< t_int > const &sizes, sopt::mpi::Communicator const &comm)
std::tuple< utilities::vis_params, std::vector< t_int >, std::vector< t_real > > w_stacking_with_all_to_all(utilities::vis_params const &params, const t_real du, const t_int min_support, const t_int max_support, sopt::mpi::Communicator const &comm, const t_int iters, const t_real fill_relaxation, const std::function< t_real(t_real)> &cost, const t_real k_means_rel_diff)
vis_params scatter_visibilities(vis_params const &params, std::vector< t_int > const &sizes, sopt::mpi::Communicator const &comm)
vis_params regroup_and_scatter(vis_params const &params, std::vector< t_int > const &groups, sopt::mpi::Communicator const &comm)
utilities::vis_params set_cell_size(const sopt::mpi::Communicator &comm, utilities::vis_params const &uv_vis, const t_real &cell_x, const t_real &cell_y)
t_int w_support(const t_real w, const t_real du, const t_int min, const t_int max)
Estimate the support size of w, given the u resolution du.
t_uint size() const
Return the number of measurements.
Vector< t_complex > weights