PURIFY
Next-generation radio interferometric imaging
Macros | Functions
kmeans.cc File Reference
#include "catch2/catch_all.hpp"
#include "purify/config.h"
#include "purify/types.h"
#include "purify/distribute.h"
#include "purify/mpi_utilities.h"
#include "purify/utilities.h"
#include <sopt/mpi/communicator.h>
Include dependency graph for kmeans.cc:

Go to the source code of this file.

Macros

#define TEST_MACRO(param)
 

Functions

 TEST_CASE ("k-means")
 
 TEST_CASE ("distribute w")
 

Macro Definition Documentation

◆ TEST_MACRO

#define TEST_MACRO (   param)
Value:
{ /* order-independent check: take |param| from both distributions, sort, then compare */ \
Vector<t_real> a = uv_dist_all.param.cwiseAbs(); \
Vector<t_real> b = uv_dist_scatter.param.cwiseAbs(); \
std::sort(a.data(), a.data() + a.size()); \
std::sort(b.data(), b.data() + b.size()); \
CAPTURE(a.head(5)); /* first five sorted values are reported if the REQUIRE fails */ \
CAPTURE(b.head(5)); \
REQUIRE(a.isApprox(b)); \
}

Function Documentation

◆ TEST_CASE() [1/2]

TEST_CASE ( "distribute w"  )

Definition at line 77 of file kmeans.cc.

77  {  // NOTE(review): listing line 78 is missing here; the References list names logging::set_level(), which may be that line -- confirm against kmeans.cc
79  const t_uint M = 1000;  // number of visibilities (vis_num) to simulate
80  const t_int min_support = 4;  // support bounds passed to w_support and w_stacking_with_all_to_all
81  const t_int max_support = 100;
82  const t_real du = 1;  // u resolution used by both the manual pipeline and w_stacking_with_all_to_all
83  auto const comm = sopt::mpi::Communicator::World();
84 
85  const auto params = utilities::random_sample_density(M, 0, constant::pi / 3, 100);  // mean 0, standard deviation pi/3, rms_w 100
86  const auto kmeans = distribute::kmeans_algo(
87  params.w, comm.size(), 100, comm, [](t_real x) { return x * x; }, 1e-5);  // comm.size() groups, 100 iterations, quadratic cost, 1e-5 relative tolerance
88  const std::vector<t_int> image_index = std::get<0>(kmeans);  // group index assigned to each visibility
89  const std::vector<t_real> w_stacks = std::get<1>(kmeans);  // per-group w values produced by k-means
90  const std::vector<t_int> groups =
91  distribute::w_support(params.w, image_index, w_stacks, du, min_support, max_support, 0, comm);
92  auto sorted = utilities::regroup_and_all_to_all(params, image_index, groups, comm);  // manual pipeline: regroup and exchange across ranks
93  const auto data = utilities::w_stacking_with_all_to_all(
94  params, du, min_support, max_support, comm, 100, 0, [](t_real x) { return x * x; }, 1e-5);  // combined helper: iters 100, fill_relaxation 0, same cost and tolerance
95  CHECK(std::get<0>(sorted).u.isApprox(std::get<0>(data).u));  // manual and combined pipelines must produce the same result
96  CHECK(std::get<0>(sorted).v.isApprox(std::get<0>(data).v));
97  CHECK(std::get<0>(sorted).w.isApprox(std::get<0>(data).w));
98  CHECK(std::get<1>(sorted) == std::get<1>(data));
99  CHECK(w_stacks == std::get<2>(data));
100 }
#define CHECK(CONDITION, ERROR)
Definition: casa.cc:6
const std::vector< t_real > u
Data for the u coordinate.
Definition: operators.cc:18
const std::vector< t_real > v
Data for the v coordinate.
Definition: operators.cc:20
const t_real pi
Mathematical constant π.
Definition: types.h:70
std::tuple< std::vector< t_int >, std::vector< t_real > > kmeans_algo(const Vector< t_real > &w, const t_int number_of_nodes, const t_int iters, const std::function< t_real(t_real)> &cost, const t_real rel_diff)
Partition w terms using k-means.
Definition: distribute.cc:103
void set_level(const std::string &level)
Method to set the logging level of the default Log object.
Definition: logging.h:137
std::tuple< vis_params, std::vector< t_int > > regroup_and_all_to_all(vis_params const &params, const std::vector< t_int > &image_index, std::vector< t_int > const &groups, sopt::mpi::Communicator const &comm)
std::tuple< utilities::vis_params, std::vector< t_int >, std::vector< t_real > > w_stacking_with_all_to_all(utilities::vis_params const &params, const t_real du, const t_int min_support, const t_int max_support, sopt::mpi::Communicator const &comm, const t_int iters, const t_real fill_relaxation, const std::function< t_real(t_real)> &cost, const t_real k_means_rel_diff)
utilities::vis_params random_sample_density(const t_int vis_num, const t_real mean, const t_real standard_deviation, const t_real rms_w)
Generates a random visibility coverage.
t_int w_support(const t_real w, const t_real du, const t_int min, const t_int max)
Estimate the support size of w given the u resolution du.

References CHECK, purify::distribute::kmeans_algo(), purify::constant::pi, purify::utilities::random_sample_density(), purify::utilities::regroup_and_all_to_all(), purify::logging::set_level(), operators_test::u, operators_test::v, purify::utilities::w_stacking_with_all_to_all(), and purify::widefield::w_support().

◆ TEST_CASE() [2/2]

TEST_CASE ( "k-means"  )

Definition at line 14 of file kmeans.cc.

14  {
15  auto const world = sopt::mpi::Communicator::World();
16  t_int const number_of_groups = world.size();  // one k-means group per rank
17  t_int const N = 1e5;  // total number of visibilities
18 
19  CAPTURE(world.rank());
20  CAPTURE(world.size());
21  utilities::vis_params uv_data;
22 
23  uv_data.u = world.broadcast<Vector<t_real>>(Vector<t_real>::Random(N));  // broadcast so all ranks share one random data set (presumably the root's)
24  uv_data.v = world.broadcast<Vector<t_real>>(Vector<t_real>::Random(N));
25  uv_data.w = world.broadcast<Vector<t_real>>(Vector<t_real>::Random(N));
26  uv_data.vis = world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(N));
27  uv_data.weights = world.broadcast<Vector<t_complex>>(Vector<t_complex>::Random(N));
28 
29  const auto kmeans_result_serial = distribute::kmeans_algo(uv_data.w, number_of_groups, 100);  // serial k-means on the full w column (cost/tolerance presumably defaulted)
30  const std::vector<t_int> serial_index = std::get<0>(kmeans_result_serial);
31  const std::vector<t_real> serial_means = std::get<1>(kmeans_result_serial);
32  REQUIRE(serial_means.size() == number_of_groups);
33 
34  t_int const start = world.rank() * (N - (N % world.size())) / world.size();  // each rank starts at an equal-sized chunk boundary
35  t_int const length = (N - (N % world.size())) / world.size() +
36  ((world.rank() == (world.size() - 1)) ? (N % world.size()) : 0);  // last rank also takes the N % size remainder
37  CAPTURE(start);
38  CAPTURE(length);
39  CAPTURE(N % world.size());
40  const Vector<t_real> w_segment = uv_data.w.segment(start, length);
41  CAPTURE(w_segment.size());
42  const auto kmeans_result_mpi = distribute::kmeans_algo(w_segment, number_of_groups, 100, world);  // distributed k-means over this rank's segment
43  const std::vector<t_int> mpi_index = std::get<0>(kmeans_result_mpi);
44  const std::vector<t_real> mpi_means = std::get<1>(kmeans_result_mpi);
45  REQUIRE(mpi_means.size() == number_of_groups);
46  REQUIRE(w_segment.size() == length);
47  // Check that serial and mpi kmeans give the same results
48  for (t_int j = 0; j < length; j++) {
49  CAPTURE(j);
50  CAPTURE(start + j);
51  REQUIRE(serial_index.at(start + j) == mpi_index.at(j));
52  }
53  for (t_int g = 0; g < number_of_groups; g++)
54  REQUIRE(mpi_means.at(g) == Approx(serial_means.at(g)));  // NOTE(review): with Catch2 v3 (catch_all.hpp) this is Catch::Approx -- confirm a using-declaration exists
55  // Check that mean values are the same
56  auto uv_dist_all =  // MPI path: redistribute this rank's segment by its group indices
57  utilities::regroup_and_all_to_all(uv_data.segment(start, length), mpi_index, world);
58  auto uv_dist_scatter = utilities::regroup_and_scatter(uv_data, serial_index, world);  // serial path: scatter the full data set by the serial group indices
59  // check w values are distributed right between mpi and non mpi versions
60 #define TEST_MACRO(param) \
61  { \
62  Vector<t_real> a = uv_dist_all.param.cwiseAbs(); \
63  Vector<t_real> b = uv_dist_scatter.param.cwiseAbs(); \
64  std::sort(a.data(), a.data() + a.size()); \
65  std::sort(b.data(), b.data() + b.size()); \
66  CAPTURE(a.head(5)); \
67  CAPTURE(b.head(5)); \
68  REQUIRE(a.isApprox(b)); \
69  }
70  TEST_MACRO(w)
71  TEST_MACRO(v)
72  TEST_MACRO(u)
73  TEST_MACRO(vis)
74  TEST_MACRO(weights)
75 #undef TEST_MACRO
76 }
#define TEST_MACRO(param)
vis_params regroup_and_scatter(vis_params const &params, std::vector< t_int > const &groups, sopt::mpi::Communicator const &comm)
Vector< t_complex > vis
Definition: uvw_utilities.h:22
vis_params segment(const t_uint pos, const t_uint length) const
Return a subset of the measurements.
Definition: uvw_utilities.h:31
Vector< t_complex > weights
Definition: uvw_utilities.h:23

References purify::distribute::kmeans_algo(), purify::utilities::regroup_and_all_to_all(), purify::utilities::regroup_and_scatter(), purify::utilities::vis_params::segment(), TEST_MACRO, purify::utilities::vis_params::u, operators_test::u, purify::utilities::vis_params::v, operators_test::v, purify::utilities::vis_params::vis, purify::utilities::vis_params::w, and purify::utilities::vis_params::weights.