14 auto const world = sopt::mpi::Communicator::World();
18 uv_serial.u = world.broadcast(uv_serial.u);
19 uv_serial.v = world.broadcast(uv_serial.v);
20 uv_serial.w = world.broadcast(uv_serial.w);
21 uv_serial.units = utilities::vis_units::radians;
22 uv_serial.vis = world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(uv_serial.u.size()));
24 world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(uv_serial.u.size()));
27 if (world.is_root()) {
34 auto const over_sample = 2;
36 auto const kernel = kernels::kernel::kb;
37 auto const width = 128;
38 auto const height = 128;
39 const Vector<t_complex> power_init =
40 world.broadcast(Vector<t_complex>::Random(height * width).eval());
41 const auto op_serial = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
43 uv_serial.u, uv_serial.v, uv_serial.w, uv_serial.weights, height, width, over_sample),
44 100, 1e-4, power_init));
45 CAPTURE(world.size());
47 for (
auto method : {factory::distributed_measurement_operator::mpi_distribute_image,
48 factory::distributed_measurement_operator::mpi_distribute_grid}) {
49 const auto op = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
51 method, uv_mpi.u, uv_mpi.v, uv_mpi.w, uv_mpi.weights, height, width, over_sample),
52 100, 1e-4, power_init));
53 if (uv_serial.u.size() == uv_mpi.u.size()) {
54 REQUIRE(uv_serial.u.isApprox(uv_mpi.u));
55 CHECK(uv_serial.v.isApprox(uv_mpi.v));
56 CHECK(uv_serial.weights.isApprox(uv_mpi.weights));
58 SECTION(
"Degridding") {
59 Vector<t_complex>
const image =
60 world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(width * height));
62 auto uv_degrid = uv_serial;
63 if (world.is_root()) {
64 uv_degrid.vis = *op_serial * image;
70 Vector<t_complex>
const degridded = *op * image;
71 REQUIRE(degridded.size() == uv_degrid.vis.size());
72 REQUIRE(degridded.isApprox(uv_degrid.vis, 1e-4));
75 Vector<t_complex>
const gridded = op->adjoint() * uv_mpi.vis;
76 Vector<t_complex>
const gridded_serial = op_serial->adjoint() * uv_serial.vis;
77 REQUIRE(gridded.size() == gridded_serial.size());
78 REQUIRE(gridded.isApprox(gridded_serial, 1e-4));
81 SECTION(
"All to All") {
82 t_real
const cell_size = 1;
84 const std::vector<t_int> image_index = std::get<0>(kmeans);
85 const std::vector<t_real> w_stacks = std::get<1>(kmeans);
89 const auto op_wproj = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
91 world, uv_stacks, height, width, cell_size, cell_size, over_sample,
kernel, J, J,
true),
92 100, 1e-4, power_init));
94 const auto op_wproj_all = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
96 factory::distributed_measurement_operator::mpi_distribute_all_to_all, image_index,
97 w_stacks, uv_mpi, height, width, cell_size, cell_size, over_sample,
kernel, J, J,
true),
98 100, 1e-4, power_init));
99 if (world.size() == 1) {
100 REQUIRE(uv_serial.u.isApprox(uv_mpi.u));
101 CHECK(uv_serial.v.isApprox(uv_mpi.v));
102 CHECK(uv_serial.weights.isApprox(uv_mpi.weights));
104 SECTION(
"Degridding") {
105 Vector<t_complex>
const image =
106 world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(width * height));
108 const Vector<t_complex> degridded = *op_wproj * image;
109 auto uv_degrid = uv_mpi;
110 uv_degrid.vis = *op_wproj_all * image;
112 REQUIRE(degridded.size() == uv_degrid.vis.size());
113 REQUIRE(degridded.isApprox(uv_degrid.vis, 1e-4));
115 SECTION(
"Gridding") {
116 Vector<t_complex>
const gridded = op_wproj_all->adjoint() * uv_mpi.vis;
117 Vector<t_complex>
const gridded_serial = op_wproj->adjoint() * uv_stacks.vis;
118 REQUIRE(gridded.size() == gridded_serial.size());
119 REQUIRE(gridded.isApprox(gridded_serial, 1e-4));
122 SECTION(
"All to All wproj") {
123 t_real
const cell_size = 1;
125 const std::vector<t_int> image_index = std::get<0>(kmeans);
126 const std::vector<t_real> w_stacks = std::get<1>(kmeans);
130 const auto op_wproj = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
132 world, uv_stacks, height, width, cell_size, cell_size, over_sample,
kernel, J, 10,
true,
133 1e-8, 1e-8, dde_type::wkernel_radial),
134 100, 1e-4, power_init));
136 const auto op_wproj_all = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
138 factory::distributed_measurement_operator::mpi_distribute_all_to_all, image_index,
139 w_stacks, uv_mpi, height, width, cell_size, cell_size, over_sample,
kernel, J, 100,
140 true, 1e-8, 1e-8, dde_type::wkernel_radial),
141 100, 1e-4, power_init));
142 if (world.size() == 1) {
143 REQUIRE(uv_serial.u.isApprox(uv_mpi.u));
144 CHECK(uv_serial.v.isApprox(uv_mpi.v));
145 CHECK(uv_serial.weights.isApprox(uv_mpi.weights));
147 SECTION(
"Degridding") {
148 Vector<t_complex>
const image =
149 world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(width * height));
151 const Vector<t_complex> degridded = *op_wproj * image;
152 auto uv_degrid = uv_mpi;
153 uv_degrid.vis = *op_wproj_all * image;
155 REQUIRE(degridded.size() == uv_degrid.vis.size());
156 REQUIRE(degridded.isApprox(uv_degrid.vis, 1e-4));
158 SECTION(
"Gridding") {
159 Vector<t_complex>
const gridded = op_wproj_all->adjoint() * uv_mpi.vis;
160 Vector<t_complex>
const gridded_serial = op_wproj->adjoint() * uv_stacks.vis;
161 REQUIRE(gridded.size() == gridded_serial.size());
162 REQUIRE(gridded.isApprox(gridded_serial, 1e-4));
std::tuple< std::vector< t_int >, std::vector< t_real > > kmeans_algo(const Vector< t_real > &w, const t_int number_of_nodes, const t_int iters, const std::function< t_real(t_real)> &cost, const t_real rel_diff)
Partition w terms using k-means.
std::shared_ptr< sopt::LinearTransform< T > > all_to_all_measurement_operator_factory(const distributed_measurement_operator distribute, const std::vector< t_int > &image_stacks, const std::vector< t_real > &w_stacks, ARGS &&...args)
distributed measurement operator factory
std::tuple< vis_params, std::vector< t_int > > regroup_and_all_to_all(vis_params const &params, const std::vector< t_int > &image_index, std::vector< t_int > const &groups, sopt::mpi::Communicator const &comm)