4 #include <catch2/catch_all.hpp> 
    7 #include "purify/directories.h" 
   13 #include <sopt/logging.h> 
   14 #include <sopt/mpi/communicator.h> 
   15 #include <sopt/mpi/utilities.h> 
   16 #include <sopt/power_method.h> 
   17 #include <sopt/wavelets.h> 
   33                                          sopt::mpi::Communicator 
const &comm) {
 
   45   auto const world = sopt::mpi::Communicator::World();
 
   47   const std::string &
test_dir = 
"expected/padmm/";
 
   52   if (world.is_root()) {
 
   53     CAPTURE(uv_data.vis.head(5));
 
   55   REQUIRE(world.all_sum_all(uv_data.size()) == 13107);
 
   57   t_uint 
const imsizey = 128;
 
   58   t_uint 
const imsizex = 128;
 
   60   auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
 
   63   Vector<t_complex> 
const init = Vector<t_complex>::Ones(imsizex * imsizey).eval();
 
   64   auto const power_method_stuff =
 
   65       sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init);
 
   66   const t_real op_norm = std::get<0>(power_method_stuff);
 
   67   measurements_transform->set_norm(op_norm);
 
   69   std::vector<std::tuple<std::string, t_uint>> 
const sara{
 
   70       std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
 
   71       std::make_tuple(
"DB3", 3u),   std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
 
   72       std::make_tuple(
"DB6", 3u),   std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)};
 
   73   auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
 
   76       world.broadcast(0.016820222945913496) * std::sqrt(2);  
 
   78     auto const padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
 
   80         imsizey, imsizex, sara.size(), 300, 
true, 
true, 
false, 1e-2, 1e-3, 50, 1);
 
   82     auto const diagnostic = (*padmm)();
 
   83     CHECK(diagnostic.niters == 10);
 
   91     const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
 
   92     CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
 
   93     CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
 
   94     CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
 
   95     CHECK(image.isApprox(solution, 1e-4));
 
   97     const Vector<t_complex> residuals = measurements_transform->adjoint() *
 
   98                                         (uv_data.vis - ((*measurements_transform) * diagnostic.x));
 
   99     const Image<t_complex> residual_image =
 
  100         Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
 
  101     CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
 
  102     CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
 
  103     CHECK(residual_image.real().isApprox(residual.real(), 1e-4));
 
  106     auto const padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
 
  108         sigma, imsizey, imsizex, sara.size(), 500, 
true, 
true, 
false, 1e-2, 1e-3, 50, 1);
 
  110     auto const diagnostic = (*padmm)();
 
  112                                                           world.broadcast(sigma));
 
  113     CHECK(sopt::mpi::l2_norm(diagnostic.residual, 
padmm->l2ball_proximal_weights(), world) <
 
  117     if (world.size() > 2 or world.size() == 0) 
return;
 
  119     const std::string &expected_solution_path = (world.size() == 2)
 
  122     const std::string &expected_residual_path = (world.size() == 2)
 
  125     if (world.size() == 1) 
CHECK(diagnostic.niters == 10);
 
  126     if (world.size() == 2) 
CHECK(diagnostic.niters == 11);
 
  128     const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
 
  131     CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
 
  132     CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
 
  133     CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
 
  134     CHECK(image.isApprox(solution, 1e-4));
 
  136     const Vector<t_complex> residuals = measurements_transform->adjoint() *
 
  137                                         (uv_data.vis - ((*measurements_transform) * diagnostic.x));
 
  138     const Image<t_complex> residual_image =
 
  139         Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
 
  142     CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
 
  143     CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
 
  144     CHECK(residual_image.real().isApprox(residual.real(), 1e-4));
 
  148 TEST_CASE(
"Serial vs. Serial with MPI Primal Dual", 
"[!shouldfail]") {
 
  149   auto const world = sopt::mpi::Communicator::World();
 
  151   const std::string &
test_dir = 
"expected/primal_dual/";
 
  156   if (world.is_root()) {
 
  157     CAPTURE(uv_data.vis.head(5));
 
  159   REQUIRE(world.all_sum_all(uv_data.size()) == 13107);
 
  161   t_uint 
const imsizey = 128;
 
  162   t_uint 
const imsizex = 128;
 
  164   auto const measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
 
  167   auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
 
  168       *measurements_transform, 1000, 1e-5,
 
  169       world.broadcast(Vector<t_complex>::Ones(imsizex * imsizey).eval()));
 
  170   const t_real op_norm = std::get<0>(power_method_stuff);
 
  171   measurements_transform->set_norm(op_norm);
 
  173   std::vector<std::tuple<std::string, t_uint>> 
const sara{
 
  174       std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
 
  175       std::make_tuple(
"DB3", 3u),   std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
 
  176       std::make_tuple(
"DB6", 3u),   std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)};
 
  177   auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
 
  180       world.broadcast(0.016820222945913496) * std::sqrt(2);  
 
  182     auto const primaldual =
 
  183         factory::primaldual_factory<sopt::algorithm::ImagingPrimalDual<t_complex>>(
 
  185             sigma, imsizey, imsizex, sara.size(), 500, 
true, 
true, 1e-2, 1);
 
  187     auto const diagnostic = (*primaldual)();
 
  188     CHECK(diagnostic.niters == 16);
 
  196     const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
 
  197     CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
 
  198     CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
 
  199     CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
 
  200     CHECK(image.isApprox(solution, 1e-4));
 
  202     const Vector<t_complex> residuals = measurements_transform->adjoint() *
 
  203                                         (uv_data.vis - ((*measurements_transform) * diagnostic.x));
 
  204     const Image<t_complex> residual_image =
 
  205         Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
 
  206     CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
 
  207     CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
 
  208     CHECK(residual_image.real().isApprox(residual.real(), 1e-4));
 
  211     auto const primaldual =
 
  212         factory::primaldual_factory<sopt::algorithm::ImagingPrimalDual<t_complex>>(
 
  214             sigma, imsizey, imsizex, sara.size(), 500, 
true, 
true, 1e-2, 1);
 
  216     auto const diagnostic = (*primaldual)();
 
  218                                                           world.broadcast(sigma));
 
  219     CHECK(sopt::mpi::l2_norm(diagnostic.residual, primaldual->l2ball_proximal_weights(), world) <
 
  223     if (world.size() > 2 or world.size() == 0) 
return;
 
  225     const std::string &expected_solution_path = (world.size() == 2)
 
  228     const std::string &expected_residual_path = (world.size() == 2)
 
  231     if (world.size() == 1) 
CHECK(diagnostic.niters == 16);
 
  232     if (world.size() == 2) 
CHECK(diagnostic.niters == 18);
 
  236     const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
 
  238     CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
 
  239     CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
 
  240     CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
 
  241     CHECK(image.isApprox(solution, 1e-4));
 
  243     const Vector<t_complex> residuals = measurements_transform->adjoint() *
 
  244                                         (uv_data.vis - ((*measurements_transform) * diagnostic.x));
 
  245     const Image<t_complex> residual_image =
 
  246         Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
 
  248     CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
 
  249     CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
 
  250     CHECK(residual_image.real().isApprox(residual.real(), 1e-4));
 
  252   SECTION(
"random update") {
 
  253     auto const measurements_transform_serial =
 
  254         factory::measurement_operator_factory<Vector<t_complex>>(
 
  257     auto const power_method_stuff = sopt::algorithm::all_sum_all_power_method<Vector<t_complex>>(
 
  258         world, *measurements_transform, 1000, 1e-5,
 
  259         world.broadcast(Vector<t_complex>::Ones(imsizex * imsizey).eval()));
 
  260     const t_real op_norm = std::get<0>(power_method_stuff);
 
  261     measurements_transform->set_norm(op_norm);
 
  263     auto sara_dist = sopt::wavelets::distribute_sara(sara, world);
 
  264     auto const wavelets_serial = factory::wavelet_operator_factory<Vector<t_complex>>(
 
  267     auto const primaldual =
 
  268         factory::primaldual_factory<sopt::algorithm::ImagingPrimalDual<t_complex>>(
 
  270             wavelets_serial, uv_data, sigma, imsizey, imsizex, sara_dist.size(), 500, 
true, 
true,
 
  273     auto const diagnostic = (*primaldual)();
 
  275                                                           world.broadcast(sigma));
 
  276     CHECK(sopt::mpi::l2_norm(diagnostic.residual, primaldual->l2ball_proximal_weights(), world) <
 
  278     if (world.size() > 1) 
return;
 
  281     if (world.size() == 0)
 
  283     else if (world.size() == 2 or world.size() == 1) {
 
  285       const std::string &expected_solution_path =
 
  288       const std::string &expected_residual_path =
 
  291       if (world.size() == 1) 
CHECK(diagnostic.niters == 16);
 
  292       if (world.size() == 2) 
CHECK(diagnostic.niters < 100);
 
  297       const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
 
  299       CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
 
  300       CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
 
  302           Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
 
  303       CHECK(image.isApprox(solution, 1e-3));
 
  305       const Vector<t_complex> residuals =
 
  306           measurements_transform->adjoint() *
 
  307           (uv_data.vis - ((*measurements_transform) * diagnostic.x));
 
  308       const Image<t_complex> residual_image =
 
  309           Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
 
  311       CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
 
  312       CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
 
  313       CHECK(residual_image.real().isApprox(residual.real(), 1e-3));
 
  319 TEST_CASE(
"Serial vs. Serial with MPI Forward Backward") {
 
  320   auto const world = sopt::mpi::Communicator::World();
 
  322   const std::string &
test_dir = 
"expected/fb/";
 
  328   if (world.is_root()) {
 
  329     CAPTURE(uv_data.vis.head(5));
 
  331   REQUIRE(world.all_sum_all(uv_data.size()) == 13107);
 
  333   t_uint 
const imsizey = 128;
 
  334   t_uint 
const imsizex = 128;
 
  336   auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
 
  339   auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
 
  340       *measurements_transform, 1000, 1e-5,
 
  341       world.broadcast(Vector<t_complex>::Ones(imsizex * imsizey).eval()));
 
  342   const t_real op_norm = std::get<0>(power_method_stuff);
 
  343   measurements_transform->set_norm(op_norm);
 
  345   std::vector<std::tuple<std::string, t_uint>> 
const sara{
 
  346       std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
 
  347       std::make_tuple(
"DB3", 3u),   std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
 
  348       std::make_tuple(
"DB6", 3u),   std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)};
 
  349   auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
 
  352       world.broadcast(0.016820222945913496) * std::sqrt(2);  
 
  353   t_real 
const beta = sigma * sigma;
 
  354   t_real 
const gamma = 0.0001;
 
  355   auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
 
  357       beta, gamma, imsizey, imsizex, sara.size(), 1000, 
true, 
true, 
false, 1e-2, 1e-3, 50);
 
  359   auto const diagnostic = (*fb)();
 
  360   const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
 
  361   if (world.is_root()) {
 
  372   double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size();
 
  373   SOPT_HIGH_LOG(
"Average intensity = {}", average_intensity);
 
  374   double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
 
  378   SOPT_HIGH_LOG(
"MSE = {}", mse);
 
  379   CHECK(mse <= average_intensity * 1e-3);
 
  384   auto const world = sopt::mpi::Communicator::World();
 
  385   const size_t N = 13107;
 
  387   const std::string &
test_dir = 
"expected/fb/";
 
  394   if (world.is_root()) {
 
  395     CAPTURE(uv_data.vis.head(5));
 
  399   t_uint 
const imsizey = 128;
 
  400   t_uint 
const imsizex = 128;
 
  402   auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
 
  405   auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
 
  406       *measurements_transform, 1000, 1e-5,
 
  407       world.broadcast(Vector<t_complex>::Ones(imsizex * imsizey).eval()));
 
  408   const t_real op_norm = std::get<0>(power_method_stuff);
 
  409   measurements_transform->set_norm(op_norm);
 
  411   std::vector<std::tuple<std::string, t_uint>> 
const sara{
 
  412       std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
 
  413       std::make_tuple(
"DB3", 3u),   std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
 
  414       std::make_tuple(
"DB6", 3u),   std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)};
 
  415   auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
 
  418       world.broadcast(0.016820222945913496) * std::sqrt(2);  
 
  419   t_real 
const beta = sigma * sigma;
 
  420   t_real 
const gamma = 0.0001;
 
  421   auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
 
  423       beta, gamma, imsizey, imsizex, sara.size(), 1000, 
true, 
true, 
false, 1e-2, 1e-3, 50);
 
  425   auto const diagnostic = (*fb)();
 
  426   const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
 
  438   double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size();
 
  439   SOPT_HIGH_LOG(
"Average intensity = {}", average_intensity);
 
  440   double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
 
  444   SOPT_HIGH_LOG(
"MSE = {}", mse);
 
  445   CHECK(mse <= average_intensity * 1e-3);
 
  449   const std::string &
test_dir = 
"expected/fb/";
 
  455   auto const comm = sopt::mpi::Communicator::World();
 
  456   const size_t N = 2000;
 
  458   using t_complexVec = Vector<t_complex>;
 
  461   std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> 
random_updater =
 
  462       [&h5file, &N, &comm]() {
 
  466         auto phi = factory::measurement_operator_factory<t_complexVec>(
 
  470         auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
 
  471             *phi, 1000, 1e-5, comm.broadcast(Vector<t_complex>::Ones(128 * 128).eval()));
 
  472         const t_real op_norm = std::get<0>(power_method_stuff);
 
  473         phi->set_norm(op_norm);
 
  475         return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data.
vis, phi);
 
  480   t_uint 
const imsizey = 128;
 
  481   t_uint 
const imsizex = 128;
 
  484   std::vector<std::tuple<std::string, t_uint>> 
const sara{
 
  485       std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
 
  486       std::make_tuple(
"DB3", 3u),   std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
 
  487       std::make_tuple(
"DB6", 3u),   std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)};
 
  488   auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
 
  492   t_real 
const sigma = 0.016820222945913496 * std::sqrt(2);  
 
  493   t_real 
const beta = sigma * sigma;
 
  494   t_real 
const gamma = 0.0001;
 
  496   sopt::algorithm::ImagingForwardBackward<t_complex> fb(
random_updater);
 
  498       .step_size(beta * sqrt(2))
 
  499       .sigma(sigma * sqrt(2))
 
  500       .regulariser_strength(gamma)
 
  501       .relative_variation(1e-3)
 
  502       .residual_tolerance(0)
 
  506   auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(
false);
 
  507   gp->l1_proximal_tolerance(1e-4)
 
  509       .l1_proximal_itermax(50)
 
  510       .l1_proximal_positivity_constraint(
true)
 
  511       .l1_proximal_real_constraint(
true)
 
  515   auto const diagnostic = fb();
 
  516   const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
 
  522   auto soln_flat = Vector<t_complex>::Map(solution.data(), solution.size());
 
  523   double average_intensity = soln_flat.real().sum() / soln_flat.size();
 
  524   SOPT_HIGH_LOG(
"Average intensity = {}", average_intensity);
 
  525   double mse = (soln_flat - diagnostic.x).real().squaredNorm() / solution.size();
 
  526   SOPT_HIGH_LOG(
"MSE = {}", mse);
 
  527   CHECK(mse <= average_intensity * 1e-3);
 
#define CHECK(CONDITION, ERROR)
 
Purify interface class to handle HDF5 input files.
 
TEST_CASE("Serial vs. Serial with MPI PADMM")
 
utilities::vis_params dirty_visibilities(const std::vector< std::string > &names)
 
const std::string test_dir
 
utilities::vis_params stochread_visibility(H5Handler &file, const size_t N, const bool w_term)
Stochastically reads dataset slices from the supplied HDF5-file handler, constructs a vis_params obje...
 
std::vector< t_int > distribute_measurements(Vector< t_real > const &u, Vector< t_real > const &v, Vector< t_real > const &w, t_int const number_of_nodes, distribute::plan const distribution_plan, t_int const &grid_size)
Distribute visibilities into groups.
 
const std::map< std::string, kernel > kernel_from_string
 
void write2d(const Image< t_real > &eigen_image, const pfitsio::header_params &header, const bool &overwrite)
Write image to fits file.
 
Image< t_complex > read2d(const std::string &fits_name)
Read image from fits file.
 
std::function< bool()> random_updater(const sopt::mpi::Communicator &comm, const t_int total, const t_int update_size, const std::shared_ptr< bool > update_pointer, const std::string &update_name)
 
vis_params scatter_visibilities(vis_params const &params, std::vector< t_int > const &sizes, sopt::mpi::Communicator const &comm)
 
vis_params regroup_and_scatter(vis_params const &params, std::vector< t_int > const &groups, sopt::mpi::Communicator const &comm)
 
utilities::vis_params read_visibility(const std::vector< std::string > &names, const bool w_term)
Read visibility files from a vector of file names.
 
t_real calculate_l2_radius(const t_uint y_size, const t_real &sigma, const t_real &n_sigma, const std::string distirbution)
A function that calculates the l2 ball radius for sopt.
 
std::string data_filename(std::string const &filename)
Holds data and such.
 
void padmm(const std::string &name, const Image< t_complex > &M31, const std::string &kernel, const t_int J, const utilities::vis_params &uv_data, const t_real sigma, const std::tuple< bool, t_real > &w_term)