1 #include "catch2/catch_all.hpp" 
    8 #include <sopt/mpi/communicator.h> 
    9 #include <sopt/power_method.h> 
   14   auto const world = sopt::mpi::Communicator::World();
 
   18   uv_serial.u = world.broadcast(uv_serial.u);
 
   19   uv_serial.v = world.broadcast(uv_serial.v);
 
   20   uv_serial.w = world.broadcast(uv_serial.w);
 
   22   uv_serial.vis = world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(uv_serial.u.size()));
 
   24       world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(uv_serial.u.size()));
 
   27   if (world.is_root()) {
 
   34   auto const over_sample = 2;
 
   37   auto const width = 128;
 
   38   auto const height = 128;
 
   39   const Vector<t_complex> power_init =
 
   40       world.broadcast(Vector<t_complex>::Random(height * width).eval());
 
   41   const auto op_serial = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
 
   43           uv_serial.u, uv_serial.v, uv_serial.w, uv_serial.weights, height, width, over_sample),
 
   44       100, 1e-4, power_init));
 
   45   CAPTURE(world.size());
 
   49     const auto op = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
 
   51             method, uv_mpi.u, uv_mpi.v, uv_mpi.w, uv_mpi.weights, height, width, over_sample),
 
   52         100, 1e-4, power_init));
 
   53     if (uv_serial.u.size() == uv_mpi.u.size()) {
 
   54       REQUIRE(uv_serial.u.isApprox(uv_mpi.u));
 
   55       CHECK(uv_serial.v.isApprox(uv_mpi.v));
 
   56       CHECK(uv_serial.weights.isApprox(uv_mpi.weights));
 
   58     SECTION(
"Degridding") {
 
   59       Vector<t_complex> 
const image =
 
   60           world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(width * height));
 
   62       auto uv_degrid = uv_serial;
 
   63       if (world.is_root()) {
 
   64         uv_degrid.vis = *op_serial * image;
 
   70       Vector<t_complex> 
const degridded = *op * image;
 
   71       REQUIRE(degridded.size() == uv_degrid.vis.size());
 
   72       REQUIRE(degridded.isApprox(uv_degrid.vis, 1e-4));
 
   75       Vector<t_complex> 
const gridded = op->adjoint() * uv_mpi.vis;
 
   76       Vector<t_complex> 
const gridded_serial = op_serial->adjoint() * uv_serial.vis;
 
   77       REQUIRE(gridded.size() == gridded_serial.size());
 
   78       REQUIRE(gridded.isApprox(gridded_serial, 1e-4));
 
   81   SECTION(
"All to All") {
 
   82     t_real 
const cell_size = 1;
 
   84     const std::vector<t_int> image_index = std::get<0>(kmeans);
 
   85     const std::vector<t_real> w_stacks = std::get<1>(kmeans);
 
   89     const auto op_wproj = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
 
   91             world, uv_stacks, height, width, cell_size, cell_size, over_sample, 
kernel, J, J, 
true),
 
   92         100, 1e-4, power_init));
 
   94     const auto op_wproj_all = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
 
   97             w_stacks, uv_mpi, height, width, cell_size, cell_size, over_sample, 
kernel, J, J, 
true),
 
   98         100, 1e-4, power_init));
 
   99     if (world.size() == 1) {
 
  100       REQUIRE(uv_serial.u.isApprox(uv_mpi.u));
 
  101       CHECK(uv_serial.v.isApprox(uv_mpi.v));
 
  102       CHECK(uv_serial.weights.isApprox(uv_mpi.weights));
 
  104     SECTION(
"Degridding") {
 
  105       Vector<t_complex> 
const image =
 
  106           world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(width * height));
 
  108       const Vector<t_complex> degridded = *op_wproj * image;
 
  109       auto uv_degrid = uv_mpi;
 
  110       uv_degrid.vis = *op_wproj_all * image;
 
  112       REQUIRE(degridded.size() == uv_degrid.vis.size());
 
  113       REQUIRE(degridded.isApprox(uv_degrid.vis, 1e-4));
 
  115     SECTION(
"Gridding") {
 
  116       Vector<t_complex> 
const gridded = op_wproj_all->adjoint() * uv_mpi.vis;
 
  117       Vector<t_complex> 
const gridded_serial = op_wproj->adjoint() * uv_stacks.vis;
 
  118       REQUIRE(gridded.size() == gridded_serial.size());
 
  119       REQUIRE(gridded.isApprox(gridded_serial, 1e-4));
 
  122   SECTION(
"All to All wproj") {
 
  123     t_real 
const cell_size = 1;
 
  125     const std::vector<t_int> image_index = std::get<0>(kmeans);
 
  126     const std::vector<t_real> w_stacks = std::get<1>(kmeans);
 
  130     const auto op_wproj = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
 
  132             world, uv_stacks, height, width, cell_size, cell_size, over_sample, 
kernel, J, 10, 
true,
 
  134         100, 1e-4, power_init));
 
  136     const auto op_wproj_all = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
 
  139             w_stacks, uv_mpi, height, width, cell_size, cell_size, over_sample, 
kernel, J, 100,
 
  141         100, 1e-4, power_init));
 
  142     if (world.size() == 1) {
 
  143       REQUIRE(uv_serial.u.isApprox(uv_mpi.u));
 
  144       CHECK(uv_serial.v.isApprox(uv_mpi.v));
 
  145       CHECK(uv_serial.weights.isApprox(uv_mpi.weights));
 
  147     SECTION(
"Degridding") {
 
  148       Vector<t_complex> 
const image =
 
  149           world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(width * height));
 
  151       const Vector<t_complex> degridded = *op_wproj * image;
 
  152       auto uv_degrid = uv_mpi;
 
  153       uv_degrid.vis = *op_wproj_all * image;
 
  155       REQUIRE(degridded.size() == uv_degrid.vis.size());
 
  156       REQUIRE(degridded.isApprox(uv_degrid.vis, 1e-4));
 
  158     SECTION(
"Gridding") {
 
  159       Vector<t_complex> 
const gridded = op_wproj_all->adjoint() * uv_mpi.vis;
 
  160       Vector<t_complex> 
const gridded_serial = op_wproj->adjoint() * uv_stacks.vis;
 
  161       REQUIRE(gridded.size() == gridded_serial.size());
 
  162       REQUIRE(gridded.isApprox(gridded_serial, 1e-4));
 
  168   auto const world = sopt::mpi::Communicator::World();
 
  172   uv_serial.u = world.broadcast(uv_serial.u);
 
  173   uv_serial.v = world.broadcast(uv_serial.v);
 
  174   uv_serial.w = world.broadcast(uv_serial.w);
 
  176   uv_serial.vis = world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(uv_serial.u.size()));
 
  178       world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(uv_serial.u.size()));
 
  181   if (world.is_root()) {
 
  188   auto const over_sample = 2;
 
  191   auto const width = 128;
 
  192   auto const height = 128;
 
  193   const Vector<t_complex> power_init =
 
  194       world.broadcast(Vector<t_complex>::Random(height * width).eval());
 
  195   const auto op_serial = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
 
  197           uv_serial.u, uv_serial.v, uv_serial.w, uv_serial.weights, height, width, over_sample),
 
  198       100, 1e-4, power_init));
 
  199   CAPTURE(world.size());
 
  203 #ifndef PURIFY_ARRAYFIRE 
  205         method, uv_mpi.u, uv_mpi.v, uv_mpi.w, uv_mpi.weights, height, width, over_sample));
 
  207     const auto op = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
 
  209             method, uv_mpi.u, uv_mpi.v, uv_mpi.w, uv_mpi.weights, height, width, over_sample),
 
  210         100, 1e-4, power_init));
 
  212     if (uv_serial.u.size() == uv_mpi.u.size()) {
 
  213       REQUIRE(uv_serial.u.isApprox(uv_mpi.u));
 
  214       CHECK(uv_serial.v.isApprox(uv_mpi.v));
 
  215       CHECK(uv_serial.weights.isApprox(uv_mpi.weights));
 
  217     SECTION(
"Degridding") {
 
  218       Vector<t_complex> 
const image =
 
  219           world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(width * height));
 
  221       auto uv_degrid = uv_serial;
 
  222       if (world.is_root()) {
 
  223         uv_degrid.vis = *op_serial * image;
 
  229       Vector<t_complex> 
const degridded = *op * image;
 
  230       REQUIRE(degridded.size() == uv_degrid.vis.size());
 
  231       REQUIRE(degridded.isApprox(uv_degrid.vis, 1e-4));
 
  233     SECTION(
"Gridding") {
 
  234       Vector<t_complex> 
const gridded = op->adjoint() * uv_mpi.vis;
 
  235       Vector<t_complex> 
const gridded_serial = op_serial->adjoint() * uv_serial.vis;
 
  236       REQUIRE(gridded.size() == gridded_serial.size());
 
  237       REQUIRE(gridded.isApprox(gridded_serial, 1e-4));
 
#define CHECK(CONDITION, ERROR)
 
TEST_CASE("Serial vs Distributed Operator")
 
const t_real pi
mathematical constant
 
std::tuple< std::vector< t_int >, std::vector< t_real > > kmeans_algo(const Vector< t_real > &w, const t_int number_of_nodes, const t_int iters, const std::function< t_real(t_real)> &cost, const t_real rel_diff)
Partition the w terms using k-means.
 
std::vector< t_int > distribute_measurements(Vector< t_real > const &u, Vector< t_real > const &v, Vector< t_real > const &w, t_int const number_of_nodes, distribute::plan const distribution_plan, t_int const &grid_size)
Distribute visibilities into groups.
 
@ gpu_mpi_distribute_image
 
@ gpu_mpi_distribute_grid
 
@ mpi_distribute_all_to_all
 
std::shared_ptr< sopt::LinearTransform< T > > measurement_operator_factory(const distributed_measurement_operator distribute, ARGS &&...args)
distributed measurement operator factory
 
std::shared_ptr< sopt::LinearTransform< T > > all_to_all_measurement_operator_factory(const distributed_measurement_operator distribute, const std::vector< t_int > &image_stacks, const std::vector< t_real > &w_stacks, ARGS &&...args)
distributed measurement operator factory
 
void set_level(const std::string &level)
Method to set the logging level of the default Log object.
 
std::shared_ptr< sopt::LinearTransform< T > > init_degrid_operator_2d(const Vector< t_real > &u, const Vector< t_real > &v, const Vector< t_real > &w, const Vector< t_complex > &weights, const t_uint &imsizey, const t_uint &imsizex, const t_real &oversample_ratio=2, const kernels::kernel kernel=kernels::kernel::kb, const t_uint Ju=4, const t_uint Jv=4, const bool w_stacking=false, const t_real &cellx=1, const t_real &celly=1)
Returns linear transform that is the standard degridding operator.
 
std::tuple< vis_params, std::vector< t_int > > regroup_and_all_to_all(vis_params const &params, const std::vector< t_int > &image_index, std::vector< t_int > const &groups, sopt::mpi::Communicator const &comm)
 
vis_params scatter_visibilities(vis_params const &params, std::vector< t_int > const &sizes, sopt::mpi::Communicator const &comm)
 
vis_params regroup_and_scatter(vis_params const &params, std::vector< t_int > const &groups, sopt::mpi::Communicator const &comm)
 
utilities::vis_params random_sample_density(const t_int vis_num, const t_real mean, const t_real standard_deviation, const t_real rms_w)
Generates a random visibility coverage.