PURIFY
Next-generation radio interferometric imaging
wavelet_operator_mpi.cc
Go to the documentation of this file.
1 #include <chrono>
2 #include <benchmark/benchmark.h>
3 #include "benchmarks/utilities.h"
5 #include <sopt/mpi/communicator.h>
6 #include <sopt/mpi/session.h>
7 
8 using namespace purify;
9 
10 // -------------- Constructor benchmark -------------------------//
11 
12 // void wavelet_operator_constructor_mpi(benchmark::State &state) {
13 
14 // // Image size
15 // t_uint m_imsizex = state.range(0);
16 // t_uint m_imsizey = state.range(0);
17 // // MPI communicator
18 // sopt::mpi::Communicator m_world = sopt::mpi::Communicator::World();
19 
20 // benchmark the creation of the wavelet operator
21 // while(state.KeepRunning()) {
22 
23 // auto start = std::chrono::high_resolution_clock::now();
24 
25 // const sopt::wavelets::SARA m_sara{std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u),
26 // std::make_tuple("DB2", 3u),
27 // std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
28 // std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
29 
30 // sopt::wavelets::SARA saraDistr = sopt::wavelets::distribute_sara(m_sara, m_world);
31 
32 // sopt::LinearTransform<Vector<t_complex>> Psi = sopt::linear_transform<t_complex>(saraDistr,
33 // m_imsizey, m_imsizex, m_world);
34 
35 // auto end = std::chrono::high_resolution_clock::now();
36 
37 // state.SetIterationTime(b_utilities::duration(start,end));
38 // }
39 // }
40 
41 // BENCHMARK(wavelet_operator_constructor_mpi)
42 // //->Apply(b_utilities::Arguments)
43 // //->Args({1024})
44 // ->RangeMultiplier(2)->Range(1024, 1024<<4)
45 // ->UseManualTime()
46 // ->Repetitions(1)->ReportAggregatesOnly(true)
47 // ->Unit(benchmark::kMillisecond);
48 
49 // ----------------- Application benchmarks -----------------------//
50 
51 class WaveletOperatorMPIFixture : public ::benchmark::Fixture {
52  public:
54  void SetUp(const ::benchmark::State& state) {
55  m_imsizex = state.range(0);
56  m_imsizey = state.range(0);
57  b_utilities::update_comm(m_world);
58  sopt::wavelets::SARA m_sara{
59  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
60  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
61  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
62 
63  sopt::wavelets::SARA const saraDistr = sopt::wavelets::distribute_sara(m_sara, m_world);
64 
65  // Get the number of wavelet coefs
66  n_wave_coeff = saraDistr.size() * m_imsizey * m_imsizex;
67  m_Psi = sopt::linear_transform<t_complex>(saraDistr, m_imsizey, m_imsizex, m_world);
68  }
69 
70  void TearDown(const ::benchmark::State& state) {}
71 
72  // A bunch of useful variables
73  t_uint m_counter;
74  // MPI communicator
75  sopt::mpi::Communicator m_world;
76  sopt::LinearTransform<Vector<t_complex>> m_Psi = sopt::linear_transform_identity<t_complex>();
77  t_uint m_imsizex;
78  t_uint m_imsizey;
79  // Get the number of wavelet coefs
80  t_uint n_wave_coeff;
81 };
82 
83 BENCHMARK_DEFINE_F(WaveletOperatorMPIFixture, Forward)(benchmark::State& state) {
84  // Benchmark the application of the operator
85 
86  // Apply Psi to a temporary vector
87  Vector<t_complex> image = Vector<t_complex>::Random(m_imsizey * m_imsizex);
88  Vector<t_complex> const wavelet_coeff = Vector<t_complex>::Ones(n_wave_coeff);
89 
90  while (m_world.broadcast<int>(state.KeepRunning())) {
91  auto start = std::chrono::high_resolution_clock::now();
92  image = m_Psi * wavelet_coeff;
93  auto end = std::chrono::high_resolution_clock::now();
94  state.SetIterationTime(b_utilities::duration(start, end, m_world));
95  }
96 }
97 
98 BENCHMARK_DEFINE_F(WaveletOperatorMPIFixture, Adjoint)(benchmark::State& state) {
99  // Apply Psi to a temporary vector
100  Vector<t_complex> const image = Vector<t_complex>::Ones(m_imsizey * m_imsizex);
101  Vector<t_complex> wavelet_coeff = Vector<t_complex>::Zero(n_wave_coeff);
102 
103  while (m_world.broadcast<int>(state.KeepRunning())) {
104  auto start = std::chrono::high_resolution_clock::now();
105  wavelet_coeff = m_Psi.adjoint() * image;
106  auto end = std::chrono::high_resolution_clock::now();
107  state.SetIterationTime(b_utilities::duration(start, end, m_world));
108  }
109 }
110 
111 BENCHMARK_REGISTER_F(WaveletOperatorMPIFixture, Forward)
112  // //->Apply(b_utilities::Arguments)
113  ->RangeMultiplier(2)
114  ->Range(128, 128 << 3)
115  ->UseManualTime()
116  ->Repetitions(5)
117  ->ReportAggregatesOnly(true)
118  ->Unit(benchmark::kMillisecond);
119 
120 BENCHMARK_REGISTER_F(WaveletOperatorMPIFixture, Adjoint)
121  // //->Apply(b_utilities::Arguments)
122  ->RangeMultiplier(2)
123  ->Range(128, 128 << 3)
124  ->UseManualTime()
125  ->Repetitions(5)
126  ->ReportAggregatesOnly(true)
127  ->Unit(benchmark::kMillisecond);
128 
129 // BENCHMARK_MAIN();
void SetUp(const ::benchmark::State &state)
void TearDown(const ::benchmark::State &state)
sopt::mpi::Communicator m_world
double duration(std::chrono::high_resolution_clock::time_point start, std::chrono::high_resolution_clock::time_point end)
Definition: utilities.cc:26
BENCHMARK_DEFINE_F(WaveletOperatorMPIFixture, Forward)(benchmark