1 #include "catch2/catch_all.hpp"
3 #include "purify/config.h"
9 #include <sopt/mpi/communicator.h>
15 auto const world = sopt::mpi::Communicator::World();
16 t_int
const number_of_groups = world.size();
19 CAPTURE(world.rank());
20 CAPTURE(world.size());
23 uv_data.
u = world.broadcast<Vector<t_real>>(Vector<t_real>::Random(N));
24 uv_data.
v = world.broadcast<Vector<t_real>>(Vector<t_real>::Random(N));
25 uv_data.
w = world.broadcast<Vector<t_real>>(Vector<t_real>::Random(N));
26 uv_data.
vis = world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(N));
27 uv_data.
weights = world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(N));
30 const std::vector<t_int> serial_index = std::get<0>(kmeans_result_serial);
31 const std::vector<t_real> serial_means = std::get<1>(kmeans_result_serial);
32 REQUIRE(serial_means.size() == number_of_groups);
34 t_int
const start = world.rank() * (N - (N % world.size())) / world.size();
35 t_int
const length = (N - (N % world.size())) / world.size() +
36 ((world.rank() == (world.size() - 1)) ? (N % world.size()) : 0);
39 CAPTURE(N % world.size());
40 const Vector<t_real> w_segment = uv_data.
w.segment(start, length);
41 CAPTURE(w_segment.size());
43 const std::vector<t_int> mpi_index = std::get<0>(kmeans_result_mpi);
44 const std::vector<t_real> mpi_means = std::get<1>(kmeans_result_mpi);
45 REQUIRE(mpi_means.size() == number_of_groups);
46 REQUIRE(w_segment.size() == length);
48 for (t_int j = 0; j < length; j++) {
51 REQUIRE(serial_index.at(start + j) == mpi_index.at(j));
53 for (t_int g = 0; g < number_of_groups; g++)
54 REQUIRE(mpi_means.at(g) == Approx(serial_means.at(g)));
60 #define TEST_MACRO(param) \
62 Vector<t_real> a = uv_dist_all.param.cwiseAbs(); \
63 Vector<t_real> b = uv_dist_scatter.param.cwiseAbs(); \
64 std::sort(a.data(), a.data() + a.size()); \
65 std::sort(b.data(), b.data() + b.size()); \
68 REQUIRE(a.isApprox(b)); \
79 const t_uint M = 1000;
80 const t_int min_support = 4;
81 const t_int max_support = 100;
83 auto const comm = sopt::mpi::Communicator::World();
87 params.w, comm.size(), 100, comm, [](t_real x) { return x * x; }, 1e-5);
88 const std::vector<t_int> image_index = std::get<0>(kmeans);
89 const std::vector<t_real> w_stacks = std::get<1>(kmeans);
90 const std::vector<t_int> groups =
94 params, du, min_support, max_support, comm, 100, 0, [](t_real x) {
return x * x; }, 1e-5);
95 CHECK(std::get<0>(sorted).
u.isApprox(std::get<0>(data).u));
96 CHECK(std::get<0>(sorted).
v.isApprox(std::get<0>(data).v));
97 CHECK(std::get<0>(sorted).w.isApprox(std::get<0>(data).w));
98 CHECK(std::get<1>(sorted) == std::get<1>(data));
99 CHECK(w_stacks == std::get<2>(data));
#define CHECK(CONDITION, ERROR)
#define TEST_MACRO(param)
const std::vector< t_real > u
data for u coordinate
const std::vector< t_real > v
data for v coordinate
const t_real pi
mathematical constant
std::tuple< std::vector< t_int >, std::vector< t_real > > kmeans_algo(const Vector< t_real > &w, const t_int number_of_nodes, const t_int iters, const std::function< t_real(t_real)> &cost, const t_real rel_diff)
Partition the w terms using k-means.
void set_level(const std::string &level)
Method to set the logging level of the default Log object.
std::tuple< vis_params, std::vector< t_int > > regroup_and_all_to_all(vis_params const &params, const std::vector< t_int > &image_index, std::vector< t_int > const &groups, sopt::mpi::Communicator const &comm)
std::tuple< utilities::vis_params, std::vector< t_int >, std::vector< t_real > > w_stacking_with_all_to_all(utilities::vis_params const &params, const t_real du, const t_int min_support, const t_int max_support, sopt::mpi::Communicator const &comm, const t_int iters, const t_real fill_relaxation, const std::function< t_real(t_real)> &cost, const t_real k_means_rel_diff)
vis_params regroup_and_scatter(vis_params const &params, std::vector< t_int > const &groups, sopt::mpi::Communicator const &comm)
utilities::vis_params random_sample_density(const t_int vis_num, const t_real mean, const t_real standard_deviation, const t_real rms_w)
Generates a random visibility coverage.
t_int w_support(const t_real w, const t_real du, const t_int min, const t_int max)
estimate support size of w given u resolution du
vis_params segment(const t_uint pos, const t_uint length) const
return subset of measurements
Vector< t_complex > weights