1 #include "catch2/catch_all.hpp"
8 #include <sopt/mpi/communicator.h>
9 #include <sopt/power_method.h>
14 auto const world = sopt::mpi::Communicator::World();
18 uv_serial.u = world.broadcast(uv_serial.u);
19 uv_serial.v = world.broadcast(uv_serial.v);
20 uv_serial.w = world.broadcast(uv_serial.w);
22 uv_serial.vis = world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(uv_serial.u.size()));
24 world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(uv_serial.u.size()));
27 if (world.is_root()) {
34 auto const over_sample = 2;
37 auto const width = 128;
38 auto const height = 128;
39 const Vector<t_complex> power_init =
40 world.broadcast(Vector<t_complex>::Random(height * width).eval());
41 const auto op_serial = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
43 uv_serial.u, uv_serial.v, uv_serial.w, uv_serial.weights, height, width, over_sample),
44 100, 1e-4, power_init));
45 CAPTURE(world.size());
49 const auto op = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
51 method, uv_mpi.u, uv_mpi.v, uv_mpi.w, uv_mpi.weights, height, width, over_sample),
52 100, 1e-4, power_init));
53 if (uv_serial.u.size() == uv_mpi.u.size()) {
54 REQUIRE(uv_serial.u.isApprox(uv_mpi.u));
55 CHECK(uv_serial.v.isApprox(uv_mpi.v));
56 CHECK(uv_serial.weights.isApprox(uv_mpi.weights));
58 SECTION(
"Degridding") {
59 Vector<t_complex>
const image =
60 world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(width * height));
62 auto uv_degrid = uv_serial;
63 if (world.is_root()) {
64 uv_degrid.vis = *op_serial * image;
70 Vector<t_complex>
const degridded = *op * image;
71 REQUIRE(degridded.size() == uv_degrid.vis.size());
72 REQUIRE(degridded.isApprox(uv_degrid.vis, 1e-4));
75 Vector<t_complex>
const gridded = op->adjoint() * uv_mpi.vis;
76 Vector<t_complex>
const gridded_serial = op_serial->adjoint() * uv_serial.vis;
77 REQUIRE(gridded.size() == gridded_serial.size());
78 REQUIRE(gridded.isApprox(gridded_serial, 1e-4));
81 SECTION(
"All to All") {
82 t_real
const cell_size = 1;
84 const std::vector<t_int> image_index = std::get<0>(kmeans);
85 const std::vector<t_real> w_stacks = std::get<1>(kmeans);
89 const auto op_wproj = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
91 world, uv_stacks, height, width, cell_size, cell_size, over_sample,
kernel, J, J,
true),
92 100, 1e-4, power_init));
94 const auto op_wproj_all = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
97 w_stacks, uv_mpi, height, width, cell_size, cell_size, over_sample,
kernel, J, J,
true),
98 100, 1e-4, power_init));
99 if (world.size() == 1) {
100 REQUIRE(uv_serial.u.isApprox(uv_mpi.u));
101 CHECK(uv_serial.v.isApprox(uv_mpi.v));
102 CHECK(uv_serial.weights.isApprox(uv_mpi.weights));
104 SECTION(
"Degridding") {
105 Vector<t_complex>
const image =
106 world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(width * height));
108 const Vector<t_complex> degridded = *op_wproj * image;
109 auto uv_degrid = uv_mpi;
110 uv_degrid.vis = *op_wproj_all * image;
112 REQUIRE(degridded.size() == uv_degrid.vis.size());
113 REQUIRE(degridded.isApprox(uv_degrid.vis, 1e-4));
115 SECTION(
"Gridding") {
116 Vector<t_complex>
const gridded = op_wproj_all->adjoint() * uv_mpi.vis;
117 Vector<t_complex>
const gridded_serial = op_wproj->adjoint() * uv_stacks.vis;
118 REQUIRE(gridded.size() == gridded_serial.size());
119 REQUIRE(gridded.isApprox(gridded_serial, 1e-4));
122 SECTION(
"All to All wproj") {
123 t_real
const cell_size = 1;
125 const std::vector<t_int> image_index = std::get<0>(kmeans);
126 const std::vector<t_real> w_stacks = std::get<1>(kmeans);
130 const auto op_wproj = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
132 world, uv_stacks, height, width, cell_size, cell_size, over_sample,
kernel, J, 10,
true,
134 100, 1e-4, power_init));
136 const auto op_wproj_all = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
139 w_stacks, uv_mpi, height, width, cell_size, cell_size, over_sample,
kernel, J, 100,
141 100, 1e-4, power_init));
142 if (world.size() == 1) {
143 REQUIRE(uv_serial.u.isApprox(uv_mpi.u));
144 CHECK(uv_serial.v.isApprox(uv_mpi.v));
145 CHECK(uv_serial.weights.isApprox(uv_mpi.weights));
147 SECTION(
"Degridding") {
148 Vector<t_complex>
const image =
149 world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(width * height));
151 const Vector<t_complex> degridded = *op_wproj * image;
152 auto uv_degrid = uv_mpi;
153 uv_degrid.vis = *op_wproj_all * image;
155 REQUIRE(degridded.size() == uv_degrid.vis.size());
156 REQUIRE(degridded.isApprox(uv_degrid.vis, 1e-4));
158 SECTION(
"Gridding") {
159 Vector<t_complex>
const gridded = op_wproj_all->adjoint() * uv_mpi.vis;
160 Vector<t_complex>
const gridded_serial = op_wproj->adjoint() * uv_stacks.vis;
161 REQUIRE(gridded.size() == gridded_serial.size());
162 REQUIRE(gridded.isApprox(gridded_serial, 1e-4));
168 auto const world = sopt::mpi::Communicator::World();
172 uv_serial.u = world.broadcast(uv_serial.u);
173 uv_serial.v = world.broadcast(uv_serial.v);
174 uv_serial.w = world.broadcast(uv_serial.w);
176 uv_serial.vis = world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(uv_serial.u.size()));
178 world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(uv_serial.u.size()));
181 if (world.is_root()) {
188 auto const over_sample = 2;
191 auto const width = 128;
192 auto const height = 128;
193 const Vector<t_complex> power_init =
194 world.broadcast(Vector<t_complex>::Random(height * width).eval());
195 const auto op_serial = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
197 uv_serial.u, uv_serial.v, uv_serial.w, uv_serial.weights, height, width, over_sample),
198 100, 1e-4, power_init));
199 CAPTURE(world.size());
203 #ifndef PURIFY_ARRAYFIRE
205 method, uv_mpi.u, uv_mpi.v, uv_mpi.w, uv_mpi.weights, height, width, over_sample));
207 const auto op = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
209 method, uv_mpi.u, uv_mpi.v, uv_mpi.w, uv_mpi.weights, height, width, over_sample),
210 100, 1e-4, power_init));
212 if (uv_serial.u.size() == uv_mpi.u.size()) {
213 REQUIRE(uv_serial.u.isApprox(uv_mpi.u));
214 CHECK(uv_serial.v.isApprox(uv_mpi.v));
215 CHECK(uv_serial.weights.isApprox(uv_mpi.weights));
217 SECTION(
"Degridding") {
218 Vector<t_complex>
const image =
219 world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(width * height));
221 auto uv_degrid = uv_serial;
222 if (world.is_root()) {
223 uv_degrid.vis = *op_serial * image;
229 Vector<t_complex>
const degridded = *op * image;
230 REQUIRE(degridded.size() == uv_degrid.vis.size());
231 REQUIRE(degridded.isApprox(uv_degrid.vis, 1e-4));
233 SECTION(
"Gridding") {
234 Vector<t_complex>
const gridded = op->adjoint() * uv_mpi.vis;
235 Vector<t_complex>
const gridded_serial = op_serial->adjoint() * uv_serial.vis;
236 REQUIRE(gridded.size() == gridded_serial.size());
237 REQUIRE(gridded.isApprox(gridded_serial, 1e-4));
#define CHECK(CONDITION, ERROR)
TEST_CASE("Serial vs Distributed Operator")
const t_real pi
mathematical constant
std::tuple< std::vector< t_int >, std::vector< t_real > > kmeans_algo(const Vector< t_real > &w, const t_int number_of_nodes, const t_int iters, const std::function< t_real(t_real)> &cost, const t_real rel_diff)
Partition w terms using k-means.
std::vector< t_int > distribute_measurements(Vector< t_real > const &u, Vector< t_real > const &v, Vector< t_real > const &w, t_int const number_of_nodes, distribute::plan const distribution_plan, t_int const &grid_size)
Distribute visibilities into groups.
@ gpu_mpi_distribute_image
@ gpu_mpi_distribute_grid
@ mpi_distribute_all_to_all
std::shared_ptr< sopt::LinearTransform< T > > measurement_operator_factory(const distributed_measurement_operator distribute, ARGS &&...args)
distributed measurement operator factory
std::shared_ptr< sopt::LinearTransform< T > > all_to_all_measurement_operator_factory(const distributed_measurement_operator distribute, const std::vector< t_int > &image_stacks, const std::vector< t_real > &w_stacks, ARGS &&...args)
distributed measurement operator factory
void set_level(const std::string &level)
Method to set the logging level of the default Log object.
std::shared_ptr< sopt::LinearTransform< T > > init_degrid_operator_2d(const Vector< t_real > &u, const Vector< t_real > &v, const Vector< t_real > &w, const Vector< t_complex > &weights, const t_uint &imsizey, const t_uint &imsizex, const t_real &oversample_ratio=2, const kernels::kernel kernel=kernels::kernel::kb, const t_uint Ju=4, const t_uint Jv=4, const bool w_stacking=false, const t_real &cellx=1, const t_real &celly=1)
Returns a linear transform that is the standard degridding operator.
std::tuple< vis_params, std::vector< t_int > > regroup_and_all_to_all(vis_params const &params, const std::vector< t_int > &image_index, std::vector< t_int > const &groups, sopt::mpi::Communicator const &comm)
vis_params scatter_visibilities(vis_params const &params, std::vector< t_int > const &sizes, sopt::mpi::Communicator const &comm)
vis_params regroup_and_scatter(vis_params const &params, std::vector< t_int > const &groups, sopt::mpi::Communicator const &comm)
utilities::vis_params random_sample_density(const t_int vis_num, const t_real mean, const t_real standard_deviation, const t_real rms_w)
Generates a random visibility coverage.