PURIFY
Next-generation radio interferometric imaging
Functions
mpi_measurement_factory.cc File Reference
#include "catch2/catch_all.hpp"
#include "purify/distribute.h"
#include "purify/logging.h"
#include "purify/measurement_operator_factory.h"
#include "purify/mpi_utilities.h"
#include "purify/operators.h"
#include "purify/utilities.h"
#include <sopt/mpi/communicator.h>
#include <sopt/power_method.h>
+ Include dependency graph for mpi_measurement_factory.cc:

Go to the source code of this file.

Functions

 TEST_CASE ("Serial vs Distributed Operator")
 
 TEST_CASE ("GPU Serial vs Distributed Operator")
 

Function Documentation

◆ TEST_CASE() [1/2]

TEST_CASE ( "GPU Serial vs Distributed Operator"  )

Definition at line 166 of file mpi_measurement_factory.cc.

166  {
168  auto const world = sopt::mpi::Communicator::World();
169 
170  auto const N = 100;
171  auto uv_serial = utilities::random_sample_density(N, 0, constant::pi / 3);
172  uv_serial.u = world.broadcast(uv_serial.u);
173  uv_serial.v = world.broadcast(uv_serial.v);
174  uv_serial.w = world.broadcast(uv_serial.w);
175  uv_serial.units = utilities::vis_units::radians;
176  uv_serial.vis = world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(uv_serial.u.size()));
177  uv_serial.weights =
178  world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(uv_serial.u.size()));
179 
180  utilities::vis_params uv_mpi;
181  if (world.is_root()) {
182  auto const order =
183  distribute::distribute_measurements(uv_serial, world, distribute::plan::radial);
184  uv_mpi = utilities::regroup_and_scatter(uv_serial, order, world);
185  } else
186  uv_mpi = utilities::scatter_visibilities(world);
187 
188  auto const over_sample = 2;
189  auto const J = 4;
190  auto const kernel = kernels::kernel::kb;
191  auto const width = 128;
192  auto const height = 128;
193  const Vector<t_complex> power_init =
194  world.broadcast(Vector<t_complex>::Random(height * width).eval());
195  const auto op_serial = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
197  uv_serial.u, uv_serial.v, uv_serial.w, uv_serial.weights, height, width, over_sample),
198  100, 1e-4, power_init));
199  CAPTURE(world.size());
200 
201  for (auto method : {factory::distributed_measurement_operator::gpu_mpi_distribute_image,
202  factory::distributed_measurement_operator::gpu_mpi_distribute_grid}) {
203 #ifndef PURIFY_ARRAYFIRE
204  REQUIRE_THROWS(factory::measurement_operator_factory<Vector<t_complex>>(
205  method, uv_mpi.u, uv_mpi.v, uv_mpi.w, uv_mpi.weights, height, width, over_sample));
206 #else
207  const auto op = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
208  factory::measurement_operator_factory<Vector<t_complex>>(
209  method, uv_mpi.u, uv_mpi.v, uv_mpi.w, uv_mpi.weights, height, width, over_sample),
210  100, 1e-4, power_init));
211 
212  if (uv_serial.u.size() == uv_mpi.u.size()) {
213  REQUIRE(uv_serial.u.isApprox(uv_mpi.u));
214  CHECK(uv_serial.v.isApprox(uv_mpi.v));
215  CHECK(uv_serial.weights.isApprox(uv_mpi.weights));
216  }
217  SECTION("Degridding") {
218  Vector<t_complex> const image =
219  world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(width * height));
220 
221  auto uv_degrid = uv_serial;
222  if (world.is_root()) {
223  uv_degrid.vis = *op_serial * image;
224  auto const order =
225  distribute::distribute_measurements(uv_degrid, world, distribute::plan::radial);
226  uv_degrid = utilities::regroup_and_scatter(uv_degrid, order, world);
227  } else
228  uv_degrid = utilities::scatter_visibilities(world);
229  Vector<t_complex> const degridded = *op * image;
230  REQUIRE(degridded.size() == uv_degrid.vis.size());
231  REQUIRE(degridded.isApprox(uv_degrid.vis, 1e-4));
232  }
233  SECTION("Gridding") {
234  Vector<t_complex> const gridded = op->adjoint() * uv_mpi.vis;
235  Vector<t_complex> const gridded_serial = op_serial->adjoint() * uv_serial.vis;
236  REQUIRE(gridded.size() == gridded_serial.size());
237  REQUIRE(gridded.isApprox(gridded_serial, 1e-4));
238  }
239 #endif
240  }
241 }
#define CHECK(CONDITION, ERROR)
Definition: casa.cc:6
const t_real pi
mathematical constant
Definition: types.h:70
std::vector< t_int > distribute_measurements(Vector< t_real > const &u, Vector< t_real > const &v, Vector< t_real > const &w, t_int const number_of_nodes, distribute::plan const distribution_plan, t_int const &grid_size)
Distribute visibilities into groups.
Definition: distribute.cc:6
std::shared_ptr< sopt::LinearTransform< T > > measurement_operator_factory(const distributed_measurement_operator distribute, ARGS &&...args)
distributed measurement operator factory
void set_level(const std::string &level)
Method to set the logging level of the default Log object.
Definition: logging.h:137
std::shared_ptr< sopt::LinearTransform< T > > init_degrid_operator_2d(const Vector< t_real > &u, const Vector< t_real > &v, const Vector< t_real > &w, const Vector< t_complex > &weights, const t_uint &imsizey, const t_uint &imsizex, const t_real &oversample_ratio=2, const kernels::kernel kernel=kernels::kernel::kb, const t_uint Ju=4, const t_uint Jv=4, const bool w_stacking=false, const t_real &cellx=1, const t_real &celly=1)
Returns linear transform that is the standard degridding operator.
Definition: operators.h:608
vis_params scatter_visibilities(vis_params const &params, std::vector< t_int > const &sizes, sopt::mpi::Communicator const &comm)
vis_params regroup_and_scatter(vis_params const &params, std::vector< t_int > const &groups, sopt::mpi::Communicator const &comm)
utilities::vis_params random_sample_density(const t_int vis_num, const t_real mean, const t_real standard_deviation, const t_real rms_w)
Generates a random visibility coverage.

References CHECK, purify::distribute::distribute_measurements(), purify::factory::gpu_mpi_distribute_grid, purify::factory::gpu_mpi_distribute_image, purify::measurementoperator::init_degrid_operator_2d(), purify::kernels::kb, purify::factory::measurement_operator_factory(), purify::constant::pi, purify::distribute::radial, purify::utilities::radians, purify::utilities::random_sample_density(), purify::utilities::regroup_and_scatter(), purify::utilities::scatter_visibilities(), and purify::logging::set_level().

◆ TEST_CASE() [2/2]

TEST_CASE ( "Serial vs Distributed Operator"  )

Definition at line 12 of file mpi_measurement_factory.cc.

12  {
14  auto const world = sopt::mpi::Communicator::World();
15 
16  auto const N = 100;
17  auto uv_serial = utilities::random_sample_density(N, 0, constant::pi / 3, 100);
18  uv_serial.u = world.broadcast(uv_serial.u);
19  uv_serial.v = world.broadcast(uv_serial.v);
20  uv_serial.w = world.broadcast(uv_serial.w);
21  uv_serial.units = utilities::vis_units::radians;
22  uv_serial.vis = world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(uv_serial.u.size()));
23  uv_serial.weights =
24  world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(uv_serial.u.size()));
25 
26  utilities::vis_params uv_mpi;
27  if (world.is_root()) {
28  auto const order =
29  distribute::distribute_measurements(uv_serial, world, distribute::plan::radial);
30  uv_mpi = utilities::regroup_and_scatter(uv_serial, order, world);
31  } else
32  uv_mpi = utilities::scatter_visibilities(world);
33 
34  auto const over_sample = 2;
35  auto const J = 4;
36  auto const kernel = kernels::kernel::kb;
37  auto const width = 128;
38  auto const height = 128;
39  const Vector<t_complex> power_init =
40  world.broadcast(Vector<t_complex>::Random(height * width).eval());
41  const auto op_serial = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
43  uv_serial.u, uv_serial.v, uv_serial.w, uv_serial.weights, height, width, over_sample),
44  100, 1e-4, power_init));
45  CAPTURE(world.size());
46 
47  for (auto method : {factory::distributed_measurement_operator::mpi_distribute_image,
48  factory::distributed_measurement_operator::mpi_distribute_grid}) {
49  const auto op = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
50  factory::measurement_operator_factory<Vector<t_complex>>(
51  method, uv_mpi.u, uv_mpi.v, uv_mpi.w, uv_mpi.weights, height, width, over_sample),
52  100, 1e-4, power_init));
53  if (uv_serial.u.size() == uv_mpi.u.size()) {
54  REQUIRE(uv_serial.u.isApprox(uv_mpi.u));
55  CHECK(uv_serial.v.isApprox(uv_mpi.v));
56  CHECK(uv_serial.weights.isApprox(uv_mpi.weights));
57  }
58  SECTION("Degridding") {
59  Vector<t_complex> const image =
60  world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(width * height));
61 
62  auto uv_degrid = uv_serial;
63  if (world.is_root()) {
64  uv_degrid.vis = *op_serial * image;
65  auto const order =
66  distribute::distribute_measurements(uv_degrid, world, distribute::plan::radial);
67  uv_degrid = utilities::regroup_and_scatter(uv_degrid, order, world);
68  } else
69  uv_degrid = utilities::scatter_visibilities(world);
70  Vector<t_complex> const degridded = *op * image;
71  REQUIRE(degridded.size() == uv_degrid.vis.size());
72  REQUIRE(degridded.isApprox(uv_degrid.vis, 1e-4));
73  }
74  SECTION("Gridding") {
75  Vector<t_complex> const gridded = op->adjoint() * uv_mpi.vis;
76  Vector<t_complex> const gridded_serial = op_serial->adjoint() * uv_serial.vis;
77  REQUIRE(gridded.size() == gridded_serial.size());
78  REQUIRE(gridded.isApprox(gridded_serial, 1e-4));
79  }
80  }
81  SECTION("All to All") {
82  t_real const cell_size = 1;
83  const auto kmeans = distribute::kmeans_algo(uv_mpi.w, world.size(), 100, world);
84  const std::vector<t_int> image_index = std::get<0>(kmeans);
85  const std::vector<t_real> w_stacks = std::get<1>(kmeans);
86 
87  const auto uv_stacks = utilities::regroup_and_all_to_all(uv_mpi, image_index, world);
88  // standard operator
89  const auto op_wproj = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
91  world, uv_stacks, height, width, cell_size, cell_size, over_sample, kernel, J, J, true),
92  100, 1e-4, power_init));
93  // all to all operator
94  const auto op_wproj_all = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
96  factory::distributed_measurement_operator::mpi_distribute_all_to_all, image_index,
97  w_stacks, uv_mpi, height, width, cell_size, cell_size, over_sample, kernel, J, J, true),
98  100, 1e-4, power_init));
99  if (world.size() == 1) {
100  REQUIRE(uv_serial.u.isApprox(uv_mpi.u));
101  CHECK(uv_serial.v.isApprox(uv_mpi.v));
102  CHECK(uv_serial.weights.isApprox(uv_mpi.weights));
103  }
104  SECTION("Degridding") {
105  Vector<t_complex> const image =
106  world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(width * height));
107 
108  const Vector<t_complex> degridded = *op_wproj * image;
109  auto uv_degrid = uv_mpi;
110  uv_degrid.vis = *op_wproj_all * image;
111  uv_degrid = utilities::regroup_and_all_to_all(uv_degrid, image_index, world);
112  REQUIRE(degridded.size() == uv_degrid.vis.size());
113  REQUIRE(degridded.isApprox(uv_degrid.vis, 1e-4));
114  }
115  SECTION("Gridding") {
116  Vector<t_complex> const gridded = op_wproj_all->adjoint() * uv_mpi.vis;
117  Vector<t_complex> const gridded_serial = op_wproj->adjoint() * uv_stacks.vis;
118  REQUIRE(gridded.size() == gridded_serial.size());
119  REQUIRE(gridded.isApprox(gridded_serial, 1e-4));
120  }
121  }
122  SECTION("All to All wproj") {
123  t_real const cell_size = 1;
124  const auto kmeans = distribute::kmeans_algo(uv_mpi.w, world.size(), 100, world);
125  const std::vector<t_int> image_index = std::get<0>(kmeans);
126  const std::vector<t_real> w_stacks = std::get<1>(kmeans);
127 
128  const auto uv_stacks = utilities::regroup_and_all_to_all(uv_mpi, image_index, world);
129  // standard operator
130  const auto op_wproj = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
132  world, uv_stacks, height, width, cell_size, cell_size, over_sample, kernel, J, 10, true,
133  1e-8, 1e-8, dde_type::wkernel_radial),
134  100, 1e-4, power_init));
135  // all to all operator
136  const auto op_wproj_all = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
138  factory::distributed_measurement_operator::mpi_distribute_all_to_all, image_index,
139  w_stacks, uv_mpi, height, width, cell_size, cell_size, over_sample, kernel, J, 100,
140  true, 1e-8, 1e-8, dde_type::wkernel_radial),
141  100, 1e-4, power_init));
142  if (world.size() == 1) {
143  REQUIRE(uv_serial.u.isApprox(uv_mpi.u));
144  CHECK(uv_serial.v.isApprox(uv_mpi.v));
145  CHECK(uv_serial.weights.isApprox(uv_mpi.weights));
146  }
147  SECTION("Degridding") {
148  Vector<t_complex> const image =
149  world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(width * height));
150 
151  const Vector<t_complex> degridded = *op_wproj * image;
152  auto uv_degrid = uv_mpi;
153  uv_degrid.vis = *op_wproj_all * image;
154  uv_degrid = utilities::regroup_and_all_to_all(uv_degrid, image_index, world);
155  REQUIRE(degridded.size() == uv_degrid.vis.size());
156  REQUIRE(degridded.isApprox(uv_degrid.vis, 1e-4));
157  }
158  SECTION("Gridding") {
159  Vector<t_complex> const gridded = op_wproj_all->adjoint() * uv_mpi.vis;
160  Vector<t_complex> const gridded_serial = op_wproj->adjoint() * uv_stacks.vis;
161  REQUIRE(gridded.size() == gridded_serial.size());
162  REQUIRE(gridded.isApprox(gridded_serial, 1e-4));
163  }
164  }
165 }
std::tuple< std::vector< t_int >, std::vector< t_real > > kmeans_algo(const Vector< t_real > &w, const t_int number_of_nodes, const t_int iters, const std::function< t_real(t_real)> &cost, const t_real rel_diff)
partition w terms using k-means
Definition: distribute.cc:103
std::shared_ptr< sopt::LinearTransform< T > > all_to_all_measurement_operator_factory(const distributed_measurement_operator distribute, const std::vector< t_int > &image_stacks, const std::vector< t_real > &w_stacks, ARGS &&...args)
distributed measurement operator factory
std::tuple< vis_params, std::vector< t_int > > regroup_and_all_to_all(vis_params const &params, const std::vector< t_int > &image_index, std::vector< t_int > const &groups, sopt::mpi::Communicator const &comm)

References purify::factory::all_to_all_measurement_operator_factory(), CHECK, purify::distribute::distribute_measurements(), purify::measurementoperator::init_degrid_operator_2d(), purify::kernels::kb, purify::distribute::kmeans_algo(), purify::factory::measurement_operator_factory(), purify::factory::mpi_distribute_all_to_all, purify::factory::mpi_distribute_grid, purify::factory::mpi_distribute_image, purify::constant::pi, purify::distribute::radial, purify::utilities::radians, purify::utilities::random_sample_density(), purify::utilities::regroup_and_all_to_all(), purify::utilities::regroup_and_scatter(), purify::utilities::scatter_visibilities(), purify::logging::set_level(), and purify::wkernel_radial.