PURIFY
Next-generation radio interferometric imaging
mpi_algo_factory.cc
Go to the documentation of this file.
1 #include <numeric>
2 #include <random>
3 #include <utility>
4 #include <catch2/catch_all.hpp>
5 
6 #include "purify/types.h"
7 #include "purify/directories.h"
8 #include "purify/distribute.h"
9 #include "purify/logging.h"
10 #include "purify/mpi_utilities.h"
11 #include "purify/pfitsio.h"
12 #include "purify/utilities.h"
13 #include <sopt/logging.h>
14 #include <sopt/mpi/communicator.h>
15 #include <sopt/mpi/utilities.h>
16 #include <sopt/power_method.h>
17 #include <sopt/wavelets.h>
18 
19 #ifdef PURIFY_H5
20 #include "purify/h5reader.h"
21 #endif
22 
26 
27 using namespace purify;
28 utilities::vis_params dirty_visibilities(const std::vector<std::string> &names) {
29  return utilities::read_visibility(names, false);
30 }
31 
32 utilities::vis_params dirty_visibilities(const std::vector<std::string> &names,
33  sopt::mpi::Communicator const &comm) {
34  if (comm.size() == 1) return dirty_visibilities(names);
35  if (comm.is_root()) {
36  auto result = dirty_visibilities(names);
37  auto const order = distribute::distribute_measurements(result, comm, distribute::plan::none);
38  return utilities::regroup_and_scatter(result, order, comm);
39  }
40  auto result = utilities::scatter_visibilities(comm);
41  return result;
42 }
43 
// Runs the ImagingProximalADMM factory on the expected/padmm/ data set over the world
// communicator and compares the result with stored reference images.
// Catch2 re-runs the setup below from the top for each SECTION.
44 TEST_CASE("Serial vs. Serial with MPI PADMM") {
45  auto const world = sopt::mpi::Communicator::World();
46 
47  const std::string &test_dir = "expected/padmm/";
48  const std::string &input_data_path = data_filename(test_dir + "input_data.vis");
49 
 // Root reads and scatters; every rank ends up with its share (see dirty_visibilities).
 // NOTE(review): source line 51 is missing from this listing.
50  auto uv_data = dirty_visibilities({input_data_path}, world);
52  if (world.is_root()) {
53  CAPTURE(uv_data.vis.head(5));
54  }
 // The per-rank shares must add up to the full data set.
55  REQUIRE(world.all_sum_all(uv_data.size()) == 13107);
56 
57  t_uint const imsizey = 128;
58  t_uint const imsizex = 128;
59 
 // NOTE(review): the listing skips source line 61, so the leading arguments of this
 // factory call are not visible in this copy.
60  auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
62  1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
 // Estimate the operator norm with sopt's power method and store it on the operator.
63  Vector<t_complex> const init = Vector<t_complex>::Ones(imsizex * imsizey).eval();
64  auto const power_method_stuff =
65  sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init);
66  const t_real op_norm = std::get<0>(power_method_stuff);
67  measurements_transform->set_norm(op_norm);
68 
 // SARA dictionary: Dirac basis plus Daubechies DB1..DB8, each with 3 levels.
69  std::vector<std::tuple<std::string, t_uint>> const sara{
70  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
71  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
72  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
73  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
74  factory::distributed_wavelet_operator::mpi_sara, sara, imsizey, imsizex);
 // broadcast keeps sigma identical on every rank; it feeds the algorithm factories and
 // calculate_l2_radius below.
75  t_real const sigma =
76  world.broadcast(0.016820222945913496) * std::sqrt(2); // see test_parameters file
 // mpi_serial distribution: the result must match the serial reference images (tol 1e-4).
77  SECTION("global") {
78  auto const padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
79  factory::algo_distribution::mpi_serial, measurements_transform, wavelets, uv_data, sigma,
80  imsizey, imsizex, sara.size(), 300, true, true, false, 1e-2, 1e-3, 50, 1);
81 
82  auto const diagnostic = (*padmm)();
83  CHECK(diagnostic.niters == 10);
84 
85  const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
86  const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");
87 
88  const auto solution = pfitsio::read2d(expected_solution_path);
89  const auto residual = pfitsio::read2d(expected_residual_path);
90 
91  const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
92  CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
93  CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
94  CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
95  CHECK(image.isApprox(solution, 1e-4));
96 
 // Residual image: adjoint of the measurement operator applied to (data - model).
97  const Vector<t_complex> residuals = measurements_transform->adjoint() *
98  (uv_data.vis - ((*measurements_transform) * diagnostic.x));
99  const Image<t_complex> residual_image =
100  Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
101  CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
102  CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
103  CHECK(residual_image.real().isApprox(residual.real(), 1e-4));
104  }
 // mpi_distributed distribution: the l2 residual bound is checked for any world size;
 // exact image comparisons only run on 1 rank (serial references) or 2 ranks (mpi_* files).
105  SECTION("local") {
106  auto const padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
107  factory::algo_distribution::mpi_distributed, measurements_transform, wavelets, uv_data,
108  sigma, imsizey, imsizex, sara.size(), 500, true, true, false, 1e-2, 1e-3, 50, 1);
109 
110  auto const diagnostic = (*padmm)();
111  t_real const epsilon = utilities::calculate_l2_radius(world.all_sum_all(uv_data.vis.size()),
112  world.broadcast(sigma));
113  CHECK(sopt::mpi::l2_norm(diagnostic.residual, padmm->l2ball_proximal_weights(), world) <
114  epsilon);
115  // the algorithm depends on nodes, so other than a basic bounds check,
116  // it is hard to know exact precision (might depend on probability theory...)
 // NOTE(review): a communicator presumably never has size 0, so that clause looks redundant.
117  if (world.size() > 2 or world.size() == 0) return;
118  // testing the cases where there are exactly one or two nodes.
119  const std::string &expected_solution_path = (world.size() == 2)
120  ? data_filename(test_dir + "mpi_solution.fits")
121  : data_filename(test_dir + "solution.fits");
122  const std::string &expected_residual_path = (world.size() == 2)
123  ? data_filename(test_dir + "mpi_residual.fits")
124  : data_filename(test_dir + "residual.fits");
125  if (world.size() == 1) CHECK(diagnostic.niters == 10);
126  if (world.size() == 2) CHECK(diagnostic.niters == 11);
127 
128  const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
129  // if (world.is_root()) pfitsio::write2d(image.real(), expected_solution_path);
130  const auto solution = pfitsio::read2d(expected_solution_path);
131  CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
132  CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
133  CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
134  CHECK(image.isApprox(solution, 1e-4));
135 
136  const Vector<t_complex> residuals = measurements_transform->adjoint() *
137  (uv_data.vis - ((*measurements_transform) * diagnostic.x));
138  const Image<t_complex> residual_image =
139  Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
140  // if (world.is_root()) pfitsio::write2d(residual_image.real(), expected_residual_path);
141  const auto residual = pfitsio::read2d(expected_residual_path);
142  CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
143  CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
144  CHECK(residual_image.real().isApprox(residual.real(), 1e-4));
145  }
146 }
147 
// Runs the ImagingPrimalDual factory on expected/primal_dual/ over the world communicator
// in three SECTIONs (global, local, random update) and compares with reference images.
// Tagged "[!shouldfail]": Catch2 inverts the outcome, so this test is reported as passing
// only when one of its assertions fails.
148 TEST_CASE("Serial vs. Serial with MPI Primal Dual", "[!shouldfail]") {
149  auto const world = sopt::mpi::Communicator::World();
150 
151  const std::string &test_dir = "expected/primal_dual/";
152  const std::string &input_data_path = data_filename(test_dir + "input_data.vis");
153 
 // NOTE(review): source line 155 is missing from this listing.
154  auto uv_data = dirty_visibilities({input_data_path}, world);
156  if (world.is_root()) {
157  CAPTURE(uv_data.vis.head(5));
158  }
 // The per-rank shares must add up to the full data set.
159  REQUIRE(world.all_sum_all(uv_data.size()) == 13107);
160 
161  t_uint const imsizey = 128;
162  t_uint const imsizex = 128;
163 
 // NOTE(review): the listing skips source line 165 (leading factory arguments).
164  auto const measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
166  1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
 // Estimate the operator norm with sopt's power method from a broadcast all-ones start
 // vector and store it on the operator.
167  auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
168  *measurements_transform, 1000, 1e-5,
169  world.broadcast(Vector<t_complex>::Ones(imsizex * imsizey).eval()));
170  const t_real op_norm = std::get<0>(power_method_stuff);
171  measurements_transform->set_norm(op_norm);
172 
 // SARA dictionary: Dirac basis plus Daubechies DB1..DB8, each with 3 levels.
173  std::vector<std::tuple<std::string, t_uint>> const sara{
174  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
175  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
176  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
177  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
178  factory::distributed_wavelet_operator::mpi_sara, sara, imsizey, imsizex);
179  t_real const sigma =
180  world.broadcast(0.016820222945913496) * std::sqrt(2); // see test_parameters file
 // mpi_serial distribution: must match the serial reference images (tol 1e-4).
181  SECTION("global") {
182  auto const primaldual =
183  factory::primaldual_factory<sopt::algorithm::ImagingPrimalDual<t_complex>>(
184  factory::algo_distribution::mpi_serial, measurements_transform, wavelets, uv_data,
185  sigma, imsizey, imsizex, sara.size(), 500, true, true, 1e-2, 1);
186 
187  auto const diagnostic = (*primaldual)();
188  CHECK(diagnostic.niters == 16);
189 
190  const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
191  const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");
192 
193  const auto solution = pfitsio::read2d(expected_solution_path);
194  const auto residual = pfitsio::read2d(expected_residual_path);
195 
196  const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
197  CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
198  CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
199  CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
200  CHECK(image.isApprox(solution, 1e-4));
201 
 // Residual image: adjoint of the measurement operator applied to (data - model).
202  const Vector<t_complex> residuals = measurements_transform->adjoint() *
203  (uv_data.vis - ((*measurements_transform) * diagnostic.x));
204  const Image<t_complex> residual_image =
205  Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
206  CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
207  CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
208  CHECK(residual_image.real().isApprox(residual.real(), 1e-4));
209  }
 // mpi_distributed distribution: l2 residual bound for any world size; exact image
 // comparisons only on 1 rank (serial references) or 2 ranks (mpi_* references).
210  SECTION("local") {
211  auto const primaldual =
212  factory::primaldual_factory<sopt::algorithm::ImagingPrimalDual<t_complex>>(
213  factory::algo_distribution::mpi_distributed, measurements_transform, wavelets, uv_data,
214  sigma, imsizey, imsizex, sara.size(), 500, true, true, 1e-2, 1);
215 
216  auto const diagnostic = (*primaldual)();
217  t_real const epsilon = utilities::calculate_l2_radius(world.all_sum_all(uv_data.vis.size()),
218  world.broadcast(sigma));
219  CHECK(sopt::mpi::l2_norm(diagnostic.residual, primaldual->l2ball_proximal_weights(), world) <
220  epsilon);
221  // the algorithm depends on nodes, so other than a basic bounds check,
222  // it is hard to know exact precision (might depend on probability theory...)
223  if (world.size() > 2 or world.size() == 0) return;
224  // testing the cases where there are exactly one or two nodes.
225  const std::string &expected_solution_path = (world.size() == 2)
226  ? data_filename(test_dir + "mpi_solution.fits")
227  : data_filename(test_dir + "solution.fits");
228  const std::string &expected_residual_path = (world.size() == 2)
229  ? data_filename(test_dir + "mpi_residual.fits")
230  : data_filename(test_dir + "residual.fits");
231  if (world.size() == 1) CHECK(diagnostic.niters == 16);
232  if (world.size() == 2) CHECK(diagnostic.niters == 18);
233  const auto solution = pfitsio::read2d(expected_solution_path);
234  const auto residual = pfitsio::read2d(expected_residual_path);
235 
236  const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
237  // if (world.is_root()) pfitsio::write2d(image.real(), expected_solution_path);
238  CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
239  CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
240  CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
241  CHECK(image.isApprox(solution, 1e-4));
242 
243  const Vector<t_complex> residuals = measurements_transform->adjoint() *
244  (uv_data.vis - ((*measurements_transform) * diagnostic.x));
245  const Image<t_complex> residual_image =
246  Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
247  // if (world.is_root()) pfitsio::write2d(residual_image.real(), expected_residual_path);
248  CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
249  CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
250  CHECK(residual_image.real().isApprox(residual.real(), 1e-4));
251  }
 // mpi_random_updates distribution with a serial measurement operator and a per-rank
 // share of the SARA dictionary (sopt::wavelets::distribute_sara).
252  SECTION("random update") {
253  auto const measurements_transform_serial =
254  factory::measurement_operator_factory<Vector<t_complex>>(
255  factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2,
256  kernels::kernel_from_string.at("kb"), 4, 4);
 // NOTE(review): this norm is estimated for, and set on, `measurements_transform`, but
 // the algorithm below is given `measurements_transform_serial`, whose set_norm is not
 // called in this section — confirm this is intended.
257  auto const power_method_stuff = sopt::algorithm::all_sum_all_power_method<Vector<t_complex>>(
258  world, *measurements_transform, 1000, 1e-5,
259  world.broadcast(Vector<t_complex>::Ones(imsizex * imsizey).eval()));
260  const t_real op_norm = std::get<0>(power_method_stuff);
261  measurements_transform->set_norm(op_norm);
262 
263  auto sara_dist = sopt::wavelets::distribute_sara(sara, world);
264  auto const wavelets_serial = factory::wavelet_operator_factory<Vector<t_complex>>(
265  factory::distributed_wavelet_operator::serial, sara_dist, imsizey, imsizex);
266 
267  auto const primaldual =
268  factory::primaldual_factory<sopt::algorithm::ImagingPrimalDual<t_complex>>(
269  factory::algo_distribution::mpi_random_updates, measurements_transform_serial,
270  wavelets_serial, uv_data, sigma, imsizey, imsizex, sara_dist.size(), 500, true, true,
271  1e-2, 1);
272 
273  auto const diagnostic = (*primaldual)();
274  t_real const epsilon = utilities::calculate_l2_radius(world.all_sum_all(uv_data.vis.size()),
275  world.broadcast(sigma));
276  CHECK(sopt::mpi::l2_norm(diagnostic.residual, primaldual->l2ball_proximal_weights(), world) <
277  epsilon);
 // NOTE(review): this early return means everything below runs on a single rank only,
 // so the world.size() == 2 paths (mpi_random_* references, niters < 100) are unreachable.
278  if (world.size() > 1) return;
279  // the algorithm depends on nodes, so other than a basic bounds check,
280  // it is hard to know exact precision (might depend on probability theory...)
281  if (world.size() == 0)
282  return;
283  else if (world.size() == 2 or world.size() == 1) {
284  // testing the cases where there are exactly one or two nodes.
285  const std::string &expected_solution_path =
286  (world.size() == 2) ? data_filename(test_dir + "mpi_random_solution.fits")
287  : data_filename(test_dir + "solution.fits");
288  const std::string &expected_residual_path =
289  (world.size() == 2) ? data_filename(test_dir + "mpi_random_residual.fits")
290  : data_filename(test_dir + "residual.fits");
291  if (world.size() == 1) CHECK(diagnostic.niters == 16);
292  if (world.size() == 2) CHECK(diagnostic.niters < 100);
293 
294  const auto solution = pfitsio::read2d(expected_solution_path);
295  const auto residual = pfitsio::read2d(expected_residual_path);
296 
 // Looser tolerance (1e-3) than the other sections.
297  const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
298  // if (world.is_root()) pfitsio::write2d(image.real(), expected_solution_path);
299  CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
300  CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
301  CAPTURE(
302  Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
303  CHECK(image.isApprox(solution, 1e-3));
304 
 // NOTE(review): residuals use `measurements_transform`, not the
 // `measurements_transform_serial` operator the algorithm ran with — confirm.
305  const Vector<t_complex> residuals =
306  measurements_transform->adjoint() *
307  (uv_data.vis - ((*measurements_transform) * diagnostic.x));
308  const Image<t_complex> residual_image =
309  Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
310  // if (world.is_root()) pfitsio::write2d(residual_image.real(), expected_residual_path);
311  CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
312  CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
313  CHECK(residual_image.real().isApprox(residual.real(), 1e-3));
314  } else
315  return;
316  }
317 }
318 
// Runs the ImagingForwardBackward factory (mpi_serial) on expected/fb/ and checks that the
// mean squared error against expected/fb/solution.fits is at most 1e-3 times the mean real
// intensity of the reconstruction.
// Side effect: the root rank writes the reconstruction to result_path (mpi_fb_result.fits).
319 TEST_CASE("Serial vs. Serial with MPI Forward Backward") {
320  auto const world = sopt::mpi::Communicator::World();
321 
322  const std::string &test_dir = "expected/fb/";
323  const std::string &input_data_path = data_filename(test_dir + "input_data.vis");
324  const std::string &result_path = data_filename(test_dir + "mpi_fb_result.fits");
325 
 // NOTE(review): source line 327 is missing from this listing.
326  auto uv_data = dirty_visibilities({input_data_path}, world);
328  if (world.is_root()) {
329  CAPTURE(uv_data.vis.head(5));
330  }
 // The per-rank shares must add up to the full data set.
331  REQUIRE(world.all_sum_all(uv_data.size()) == 13107);
332 
333  t_uint const imsizey = 128;
334  t_uint const imsizex = 128;
335 
 // NOTE(review): the listing skips source line 337 (leading factory arguments).
336  auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
338  1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
 // Estimate the operator norm with sopt's power method and store it on the operator.
339  auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
340  *measurements_transform, 1000, 1e-5,
341  world.broadcast(Vector<t_complex>::Ones(imsizex * imsizey).eval()));
342  const t_real op_norm = std::get<0>(power_method_stuff);
343  measurements_transform->set_norm(op_norm);
344 
 // SARA dictionary: Dirac basis plus Daubechies DB1..DB8, each with 3 levels.
345  std::vector<std::tuple<std::string, t_uint>> const sara{
346  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
347  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
348  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
349  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
350  factory::distributed_wavelet_operator::mpi_sara, sara, imsizey, imsizex);
351  t_real const sigma =
352  world.broadcast(0.016820222945913496) * std::sqrt(2); // see test_parameters file
 // beta and gamma are passed to fb_factory right after sigma.
353  t_real const beta = sigma * sigma;
354  t_real const gamma = 0.0001;
355  auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
356  factory::algo_distribution::mpi_serial, measurements_transform, wavelets, uv_data, sigma,
357  beta, gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50);
358 
359  auto const diagnostic = (*fb)();
360  const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
361  if (world.is_root()) {
362  pfitsio::write2d(image.real(), result_path);
363  // pfitsio::write2d(residual_image.real(), expected_residual_path);
364  }
365 
366  const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
367  const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");
368 
 // NOTE(review): `residual` is read but not used by any check below.
369  const auto solution = pfitsio::read2d(expected_solution_path);
370  const auto residual = pfitsio::read2d(expected_residual_path);
371 
 // Relative MSE criterion: threshold scales with the reconstruction's mean real intensity.
372  double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size();
373  SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
374  double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
375  .real()
376  .squaredNorm() /
377  solution.size();
378  SOPT_HIGH_LOG("MSE = {}", mse);
379  CHECK(mse <= average_intensity * 1e-3);
380 }
381 
382 #ifdef PURIFY_H5
// Same forward-backward MSE check as the .vis-based test, but the visibilities come from
// expected/fb/input_data.h5 through H5::stochread_visibility (6000 requested, no w-terms).
// NOTE(review): `N` and `result_path` are unused (the write is commented out), and the
// total-size REQUIRE is disabled — presumably because a stochastic read does not return
// all 13107 visibilities; confirm.
383 TEST_CASE("MPI_fb_factory_hdf5") {
384  auto const world = sopt::mpi::Communicator::World();
385  const size_t N = 13107;
386 
387  const std::string &test_dir = "expected/fb/";
388  const std::string &input_data_path = data_filename(test_dir + "input_data.h5");
389  const std::string &result_path = data_filename(test_dir + "mpi_fb_result_hdf5.fits");
390  H5::H5Handler h5file(input_data_path, world);
391 
392  auto uv_data = H5::stochread_visibility(h5file, 6000, false);
 // Units are set explicitly on the data read from HDF5.
393  uv_data.units = utilities::vis_units::radians;
394  if (world.is_root()) {
395  CAPTURE(uv_data.vis.head(5));
396  }
397  // REQUIRE(world.all_sum_all(uv_data.size()) == 13107);
398 
399  t_uint const imsizey = 128;
400  t_uint const imsizex = 128;
401 
 // NOTE(review): the listing skips source line 403 (leading factory arguments).
402  auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
404  1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
 // Estimate the operator norm with sopt's power method and store it on the operator.
405  auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
406  *measurements_transform, 1000, 1e-5,
407  world.broadcast(Vector<t_complex>::Ones(imsizex * imsizey).eval()));
408  const t_real op_norm = std::get<0>(power_method_stuff);
409  measurements_transform->set_norm(op_norm);
410 
 // SARA dictionary: Dirac basis plus Daubechies DB1..DB8, each with 3 levels.
411  std::vector<std::tuple<std::string, t_uint>> const sara{
412  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
413  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
414  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
415  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
416  factory::distributed_wavelet_operator::mpi_sara, sara, imsizey, imsizex);
417  t_real const sigma =
418  world.broadcast(0.016820222945913496) * std::sqrt(2); // see test_parameters file
419  t_real const beta = sigma * sigma;
420  t_real const gamma = 0.0001;
421  auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
422  factory::algo_distribution::mpi_serial, measurements_transform, wavelets, uv_data, sigma,
423  beta, gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50);
424 
425  auto const diagnostic = (*fb)();
426  const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
427  // if (world.is_root())
428  //{
429  // pfitsio::write2d(image.real(), result_path);
430  //}
431 
432  const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
433  const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");
434 
 // NOTE(review): `residual` is read but not used by any check below.
435  const auto solution = pfitsio::read2d(expected_solution_path);
436  const auto residual = pfitsio::read2d(expected_residual_path);
437 
 // Relative MSE criterion: threshold scales with the reconstruction's mean real intensity.
438  double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size();
439  SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
440  double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
441  .real()
442  .squaredNorm() /
443  solution.size();
444  SOPT_HIGH_LOG("MSE = {}", mse);
445  CHECK(mse <= average_intensity * 1e-3);
446 }
447 
// Forward-backward with stochastic updates. The algorithm is built directly (not through
// fb_factory) from `random_updater`. Each call of `random_updater` requests N visibilities
// from the HDF5 file, builds and normalises a measurement operator for them, and returns
// both wrapped in a sopt::IterationState. The reconstruction is checked by relative MSE
// against expected/fb/solution.fits.
448 TEST_CASE("fb_factory_stochastic") {
449  const std::string &test_dir = "expected/fb/";
450  const std::string &input_data_path = data_filename(test_dir + "input_data.h5");
451  const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
 // NOTE(review): result_path is only referenced from commented-out code below.
452  const std::string &result_path = data_filename(test_dir + "fb_stochastic_result_mpi.fits");
453 
454  // HDF5
455  auto const comm = sopt::mpi::Communicator::World();
456  const size_t N = 2000;
457  H5::H5Handler h5file(input_data_path, comm); // length 13107
458  using t_complexVec = Vector<t_complex>;
459 
460  // This functor would be defined in Purify
 // Captures h5file, N and comm by reference; they are declared before `fb` and therefore
 // outlive it within this scope.
461  std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater =
462  [&h5file, &N, &comm]() {
463  utilities::vis_params uv_data =
464  H5::stochread_visibility(h5file, N, false); // no w-term in this data-set
 // NOTE(review): the listing skips source lines 465 and 467 (incl. leading factory args).
466  auto phi = factory::measurement_operator_factory<t_complexVec>(
468  1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
469 
 // Re-estimate and set the norm for the operator built from this batch.
470  auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
471  *phi, 1000, 1e-5, comm.broadcast(Vector<t_complex>::Ones(128 * 128).eval()));
472  const t_real op_norm = std::get<0>(power_method_stuff);
473  phi->set_norm(op_norm);
474 
475  return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data.vis, phi);
476  };
477 
478  const auto solution = pfitsio::read2d(expected_solution_path);
479 
480  t_uint const imsizey = 128;
481  t_uint const imsizex = 128;
482 
483  // wavelets
484  std::vector<std::tuple<std::string, t_uint>> const sara{
485  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
486  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
487  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
488  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
489  factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
490 
491  // algorithm
492  t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
493  t_real const beta = sigma * sigma;
494  t_real const gamma = 0.0001;
495 
 // NOTE(review): sigma already includes a sqrt(2) factor, and .sigma() multiplies by
 // sqrt(2) again (step_size likewise scales beta) — confirm this double scaling is intended.
496  sopt::algorithm::ImagingForwardBackward<t_complex> fb(random_updater);
497  fb.itermax(1000)
498  .step_size(beta * sqrt(2))
499  .sigma(sigma * sqrt(2))
500  .regulariser_strength(gamma)
501  .relative_variation(1e-3)
502  .residual_tolerance(0)
503  .tight_frame(true)
504  .obj_comm(comm);
505 
 // l1 regulariser over the serial SARA wavelets, with positivity and real constraints.
506  auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(false);
507  gp->l1_proximal_tolerance(1e-4)
508  .l1_proximal_nu(1)
509  .l1_proximal_itermax(50)
510  .l1_proximal_positivity_constraint(true)
511  .l1_proximal_real_constraint(true)
512  .Psi(*wavelets);
513  fb.g_function(gp);
514 
515  auto const diagnostic = fb();
516  const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
517  // if (comm.is_root())
518  //{
519  // //pfitsio::write2d(image.real(), result_path);
520  //}
521 
 // Unlike the other forward-backward tests, the MSE threshold here is scaled by the mean
 // real intensity of the reference solution, not of the reconstruction.
522  auto soln_flat = Vector<t_complex>::Map(solution.data(), solution.size());
523  double average_intensity = soln_flat.real().sum() / soln_flat.size();
524  SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
525  double mse = (soln_flat - diagnostic.x).real().squaredNorm() / solution.size();
526  SOPT_HIGH_LOG("MSE = {}", mse);
527  CHECK(mse <= average_intensity * 1e-3);
528 }
529 #endif
#define CHECK(CONDITION, ERROR)
Definition: casa.cc:6
Purify interface class to handle HDF5 input files.
Definition: h5reader.h:48
TEST_CASE("Serial vs. Serial with MPI PADMM")
utilities::vis_params dirty_visibilities(const std::vector< std::string > &names)
const std::string test_dir
Definition: operators.cc:16
utilities::vis_params stochread_visibility(H5Handler &file, const size_t N, const bool w_term)
Stochastically reads dataset slices from the supplied HDF5-file handler, constructs a vis_params obje...
Definition: h5reader.h:206
std::vector< t_int > distribute_measurements(Vector< t_real > const &u, Vector< t_real > const &v, Vector< t_real > const &w, t_int const number_of_nodes, distribute::plan const distribution_plan, t_int const &grid_size)
Distribute visibilities into groups.
Definition: distribute.cc:6
const std::map< std::string, kernel > kernel_from_string
Definition: kernels.h:16
void write2d(const Image< t_real > &eigen_image, const pfitsio::header_params &header, const bool &overwrite)
Write image to fits file.
Definition: pfitsio.cc:30
Image< t_complex > read2d(const std::string &fits_name)
Read image from fits file.
Definition: pfitsio.cc:109
std::function< bool()> random_updater(const sopt::mpi::Communicator &comm, const t_int total, const t_int update_size, const std::shared_ptr< bool > update_pointer, const std::string &update_name)
vis_params scatter_visibilities(vis_params const &params, std::vector< t_int > const &sizes, sopt::mpi::Communicator const &comm)
vis_params regroup_and_scatter(vis_params const &params, std::vector< t_int > const &groups, sopt::mpi::Communicator const &comm)
utilities::vis_params read_visibility(const std::vector< std::string > &names, const bool w_term)
Read visibility files from name of vector.
t_real calculate_l2_radius(const t_uint y_size, const t_real &sigma, const t_real &n_sigma, const std::string distirbution)
A function that calculates the l2 ball radius for sopt.
Definition: utilities.cc:75
std::string data_filename(std::string const &filename)
Holds data and such.
void padmm(const std::string &name, const Image< t_complex > &M31, const std::string &kernel, const t_int J, const utilities::vis_params &uv_data, const t_real sigma, const std::tuple< bool, t_real > &w_term)
Vector< t_complex > vis
Definition: uvw_utilities.h:22