4 #include <catch2/catch_all.hpp>
7 #include "purify/directories.h"
13 #include <sopt/logging.h>
14 #include <sopt/mpi/communicator.h>
15 #include <sopt/mpi/utilities.h>
16 #include <sopt/power_method.h>
17 #include <sopt/wavelets.h>
33 sopt::mpi::Communicator
const &comm) {
45 auto const world = sopt::mpi::Communicator::World();
47 const std::string &
test_dir =
"expected/padmm/";
52 if (world.is_root()) {
53 CAPTURE(uv_data.vis.head(5));
55 REQUIRE(world.all_sum_all(uv_data.size()) == 13107);
57 t_uint
const imsizey = 128;
58 t_uint
const imsizex = 128;
60 auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
63 Vector<t_complex>
const init = Vector<t_complex>::Ones(imsizex * imsizey).eval();
64 auto const power_method_stuff =
65 sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init);
66 const t_real op_norm = std::get<0>(power_method_stuff);
67 measurements_transform->set_norm(op_norm);
69 std::vector<std::tuple<std::string, t_uint>>
const sara{
70 std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
71 std::make_tuple(
"DB3", 3u), std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
72 std::make_tuple(
"DB6", 3u), std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)};
73 auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
76 world.broadcast(0.016820222945913496) * std::sqrt(2);
78 auto const padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
80 imsizey, imsizex, sara.size(), 300,
true,
true,
false, 1e-2, 1e-3, 50, 1);
82 auto const diagnostic = (*padmm)();
83 CHECK(diagnostic.niters == 10);
91 const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
92 CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
93 CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
94 CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
95 CHECK(image.isApprox(solution, 1e-4));
97 const Vector<t_complex> residuals = measurements_transform->adjoint() *
98 (uv_data.vis - ((*measurements_transform) * diagnostic.x));
99 const Image<t_complex> residual_image =
100 Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
101 CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
102 CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
103 CHECK(residual_image.real().isApprox(residual.real(), 1e-4));
106 auto const padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
108 sigma, imsizey, imsizex, sara.size(), 500,
true,
true,
false, 1e-2, 1e-3, 50, 1);
110 auto const diagnostic = (*padmm)();
112 world.broadcast(sigma));
113 CHECK(sopt::mpi::l2_norm(diagnostic.residual,
padmm->l2ball_proximal_weights(), world) <
117 if (world.size() > 2 or world.size() == 0)
return;
119 const std::string &expected_solution_path = (world.size() == 2)
122 const std::string &expected_residual_path = (world.size() == 2)
125 if (world.size() == 1)
CHECK(diagnostic.niters == 10);
126 if (world.size() == 2)
CHECK(diagnostic.niters == 11);
128 const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
131 CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
132 CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
133 CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
134 CHECK(image.isApprox(solution, 1e-4));
136 const Vector<t_complex> residuals = measurements_transform->adjoint() *
137 (uv_data.vis - ((*measurements_transform) * diagnostic.x));
138 const Image<t_complex> residual_image =
139 Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
142 CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
143 CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
144 CHECK(residual_image.real().isApprox(residual.real(), 1e-4));
148 TEST_CASE(
"Serial vs. Serial with MPI Primal Dual",
"[!shouldfail]") {
149 auto const world = sopt::mpi::Communicator::World();
151 const std::string &
test_dir =
"expected/primal_dual/";
156 if (world.is_root()) {
157 CAPTURE(uv_data.vis.head(5));
159 REQUIRE(world.all_sum_all(uv_data.size()) == 13107);
161 t_uint
const imsizey = 128;
162 t_uint
const imsizex = 128;
164 auto const measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
167 auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
168 *measurements_transform, 1000, 1e-5,
169 world.broadcast(Vector<t_complex>::Ones(imsizex * imsizey).eval()));
170 const t_real op_norm = std::get<0>(power_method_stuff);
171 measurements_transform->set_norm(op_norm);
173 std::vector<std::tuple<std::string, t_uint>>
const sara{
174 std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
175 std::make_tuple(
"DB3", 3u), std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
176 std::make_tuple(
"DB6", 3u), std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)};
177 auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
180 world.broadcast(0.016820222945913496) * std::sqrt(2);
182 auto const primaldual =
183 factory::primaldual_factory<sopt::algorithm::ImagingPrimalDual<t_complex>>(
185 sigma, imsizey, imsizex, sara.size(), 500,
true,
true, 1e-2, 1);
187 auto const diagnostic = (*primaldual)();
188 CHECK(diagnostic.niters == 16);
196 const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
197 CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
198 CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
199 CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
200 CHECK(image.isApprox(solution, 1e-4));
202 const Vector<t_complex> residuals = measurements_transform->adjoint() *
203 (uv_data.vis - ((*measurements_transform) * diagnostic.x));
204 const Image<t_complex> residual_image =
205 Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
206 CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
207 CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
208 CHECK(residual_image.real().isApprox(residual.real(), 1e-4));
211 auto const primaldual =
212 factory::primaldual_factory<sopt::algorithm::ImagingPrimalDual<t_complex>>(
214 sigma, imsizey, imsizex, sara.size(), 500,
true,
true, 1e-2, 1);
216 auto const diagnostic = (*primaldual)();
218 world.broadcast(sigma));
219 CHECK(sopt::mpi::l2_norm(diagnostic.residual, primaldual->l2ball_proximal_weights(), world) <
223 if (world.size() > 2 or world.size() == 0)
return;
225 const std::string &expected_solution_path = (world.size() == 2)
228 const std::string &expected_residual_path = (world.size() == 2)
231 if (world.size() == 1)
CHECK(diagnostic.niters == 16);
232 if (world.size() == 2)
CHECK(diagnostic.niters == 18);
236 const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
238 CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
239 CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
240 CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
241 CHECK(image.isApprox(solution, 1e-4));
243 const Vector<t_complex> residuals = measurements_transform->adjoint() *
244 (uv_data.vis - ((*measurements_transform) * diagnostic.x));
245 const Image<t_complex> residual_image =
246 Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
248 CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
249 CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
250 CHECK(residual_image.real().isApprox(residual.real(), 1e-4));
252 SECTION(
"random update") {
253 auto const measurements_transform_serial =
254 factory::measurement_operator_factory<Vector<t_complex>>(
257 auto const power_method_stuff = sopt::algorithm::all_sum_all_power_method<Vector<t_complex>>(
258 world, *measurements_transform, 1000, 1e-5,
259 world.broadcast(Vector<t_complex>::Ones(imsizex * imsizey).eval()));
260 const t_real op_norm = std::get<0>(power_method_stuff);
261 measurements_transform->set_norm(op_norm);
263 auto sara_dist = sopt::wavelets::distribute_sara(sara, world);
264 auto const wavelets_serial = factory::wavelet_operator_factory<Vector<t_complex>>(
267 auto const primaldual =
268 factory::primaldual_factory<sopt::algorithm::ImagingPrimalDual<t_complex>>(
270 wavelets_serial, uv_data, sigma, imsizey, imsizex, sara_dist.size(), 500,
true,
true,
273 auto const diagnostic = (*primaldual)();
275 world.broadcast(sigma));
276 CHECK(sopt::mpi::l2_norm(diagnostic.residual, primaldual->l2ball_proximal_weights(), world) <
278 if (world.size() > 1)
return;
281 if (world.size() == 0)
283 else if (world.size() == 2 or world.size() == 1) {
285 const std::string &expected_solution_path =
288 const std::string &expected_residual_path =
291 if (world.size() == 1)
CHECK(diagnostic.niters == 16);
292 if (world.size() == 2)
CHECK(diagnostic.niters < 100);
297 const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
299 CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
300 CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
302 Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
303 CHECK(image.isApprox(solution, 1e-3));
305 const Vector<t_complex> residuals =
306 measurements_transform->adjoint() *
307 (uv_data.vis - ((*measurements_transform) * diagnostic.x));
308 const Image<t_complex> residual_image =
309 Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
311 CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
312 CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
313 CHECK(residual_image.real().isApprox(residual.real(), 1e-3));
319 TEST_CASE(
"Serial vs. Serial with MPI Forward Backward") {
320 auto const world = sopt::mpi::Communicator::World();
322 const std::string &
test_dir =
"expected/fb/";
328 if (world.is_root()) {
329 CAPTURE(uv_data.vis.head(5));
331 REQUIRE(world.all_sum_all(uv_data.size()) == 13107);
333 t_uint
const imsizey = 128;
334 t_uint
const imsizex = 128;
336 auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
339 auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
340 *measurements_transform, 1000, 1e-5,
341 world.broadcast(Vector<t_complex>::Ones(imsizex * imsizey).eval()));
342 const t_real op_norm = std::get<0>(power_method_stuff);
343 measurements_transform->set_norm(op_norm);
345 std::vector<std::tuple<std::string, t_uint>>
const sara{
346 std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
347 std::make_tuple(
"DB3", 3u), std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
348 std::make_tuple(
"DB6", 3u), std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)};
349 auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
352 world.broadcast(0.016820222945913496) * std::sqrt(2);
353 t_real
const beta = sigma * sigma;
354 t_real
const gamma = 0.0001;
355 auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
357 beta, gamma, imsizey, imsizex, sara.size(), 1000,
true,
true,
false, 1e-2, 1e-3, 50);
359 auto const diagnostic = (*fb)();
360 const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
361 if (world.is_root()) {
372 double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size();
373 SOPT_HIGH_LOG(
"Average intensity = {}", average_intensity);
374 double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
378 SOPT_HIGH_LOG(
"MSE = {}", mse);
379 CHECK(mse <= average_intensity * 1e-3);
384 auto const world = sopt::mpi::Communicator::World();
385 const size_t N = 13107;
387 const std::string &
test_dir =
"expected/fb/";
394 if (world.is_root()) {
395 CAPTURE(uv_data.vis.head(5));
399 t_uint
const imsizey = 128;
400 t_uint
const imsizex = 128;
402 auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
405 auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
406 *measurements_transform, 1000, 1e-5,
407 world.broadcast(Vector<t_complex>::Ones(imsizex * imsizey).eval()));
408 const t_real op_norm = std::get<0>(power_method_stuff);
409 measurements_transform->set_norm(op_norm);
411 std::vector<std::tuple<std::string, t_uint>>
const sara{
412 std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
413 std::make_tuple(
"DB3", 3u), std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
414 std::make_tuple(
"DB6", 3u), std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)};
415 auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
418 world.broadcast(0.016820222945913496) * std::sqrt(2);
419 t_real
const beta = sigma * sigma;
420 t_real
const gamma = 0.0001;
421 auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
423 beta, gamma, imsizey, imsizex, sara.size(), 1000,
true,
true,
false, 1e-2, 1e-3, 50);
425 auto const diagnostic = (*fb)();
426 const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
438 double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size();
439 SOPT_HIGH_LOG(
"Average intensity = {}", average_intensity);
440 double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
444 SOPT_HIGH_LOG(
"MSE = {}", mse);
445 CHECK(mse <= average_intensity * 1e-3);
449 const std::string &
test_dir =
"expected/fb/";
455 auto const comm = sopt::mpi::Communicator::World();
456 const size_t N = 2000;
458 using t_complexVec = Vector<t_complex>;
461 std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()>
random_updater =
462 [&h5file, &N, &comm]() {
466 auto phi = factory::measurement_operator_factory<t_complexVec>(
470 auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
471 *phi, 1000, 1e-5, comm.broadcast(Vector<t_complex>::Ones(128 * 128).eval()));
472 const t_real op_norm = std::get<0>(power_method_stuff);
473 phi->set_norm(op_norm);
475 return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data.
vis, phi);
480 t_uint
const imsizey = 128;
481 t_uint
const imsizex = 128;
484 std::vector<std::tuple<std::string, t_uint>>
const sara{
485 std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
486 std::make_tuple(
"DB3", 3u), std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
487 std::make_tuple(
"DB6", 3u), std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)};
488 auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
492 t_real
const sigma = 0.016820222945913496 * std::sqrt(2);
493 t_real
const beta = sigma * sigma;
494 t_real
const gamma = 0.0001;
496 sopt::algorithm::ImagingForwardBackward<t_complex> fb(
random_updater);
498 .step_size(beta * sqrt(2))
499 .sigma(sigma * sqrt(2))
500 .regulariser_strength(gamma)
501 .relative_variation(1e-3)
502 .residual_tolerance(0)
506 auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(
false);
507 gp->l1_proximal_tolerance(1e-4)
509 .l1_proximal_itermax(50)
510 .l1_proximal_positivity_constraint(
true)
511 .l1_proximal_real_constraint(
true)
515 auto const diagnostic = fb();
516 const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
522 auto soln_flat = Vector<t_complex>::Map(solution.data(), solution.size());
523 double average_intensity = soln_flat.real().sum() / soln_flat.size();
524 SOPT_HIGH_LOG(
"Average intensity = {}", average_intensity);
525 double mse = (soln_flat - diagnostic.x).real().squaredNorm() / solution.size();
526 SOPT_HIGH_LOG(
"MSE = {}", mse);
527 CHECK(mse <= average_intensity * 1e-3);
#define CHECK(CONDITION, ERROR)
Purify interface class to handle HDF5 input files.
TEST_CASE("Serial vs. Serial with MPI PADMM")
utilities::vis_params dirty_visibilities(const std::vector< std::string > &names)
const std::string test_dir
utilities::vis_params stochread_visibility(H5Handler &file, const size_t N, const bool w_term)
Stochastically reads dataset slices from the supplied HDF5-file handler, constructs a vis_params obje...
std::vector< t_int > distribute_measurements(Vector< t_real > const &u, Vector< t_real > const &v, Vector< t_real > const &w, t_int const number_of_nodes, distribute::plan const distribution_plan, t_int const &grid_size)
Distribute visiblities into groups.
const std::map< std::string, kernel > kernel_from_string
void write2d(const Image< t_real > &eigen_image, const pfitsio::header_params &header, const bool &overwrite)
Write image to fits file.
Image< t_complex > read2d(const std::string &fits_name)
Read image from fits file.
std::function< bool()> random_updater(const sopt::mpi::Communicator &comm, const t_int total, const t_int update_size, const std::shared_ptr< bool > update_pointer, const std::string &update_name)
vis_params scatter_visibilities(vis_params const ¶ms, std::vector< t_int > const &sizes, sopt::mpi::Communicator const &comm)
vis_params regroup_and_scatter(vis_params const ¶ms, std::vector< t_int > const &groups, sopt::mpi::Communicator const &comm)
utilities::vis_params read_visibility(const std::vector< std::string > &names, const bool w_term)
Read visibility files from name of vector.
t_real calculate_l2_radius(const t_uint y_size, const t_real &sigma, const t_real &n_sigma, const std::string distirbution)
A function that calculates the l2 ball radius for sopt.
std::string data_filename(std::string const &filename)
Holds data and such.
void padmm(const std::string &name, const Image< t_complex > &M31, const std::string &kernel, const t_int J, const utilities::vis_params &uv_data, const t_real sigma, const std::tuple< bool, t_real > &w_term)