PURIFY
Next-generation radio interferometric imaging
algorithms.cc
Go to the documentation of this file.
1 #include "purify/config.h"
2 #include "purify/types.h"
3 #include <array>
4 #include <benchmark/benchmark.h>
5 #include "benchmarks/utilities.h"
7 #include "purify/directories.h"
9 #include "purify/operators.h"
10 #include "purify/utilities.h"
12 #include <sopt/imaging_padmm.h>
13 #include <sopt/relative_variation.h>
14 #include <sopt/utilities.h>
15 #include <sopt/wavelets.h>
16 #include <sopt/wavelets/sara.h>
17 
18 using namespace purify;
19 
20 class AlgoFixture : public ::benchmark::Fixture {
21  public:
22  void SetUp(const ::benchmark::State &state) {
23  // Reading image from file and update related quantities
24  bool newImage = b_utilities::updateImage(state.range(0), m_image, m_imsizex, m_imsizey);
25 
26  // Generating random uv(w) coverage
27  bool newMeasurements =
28  b_utilities::updateMeasurements(state.range(1), m_uv_data, m_epsilon, newImage, m_image);
29 
30  bool newKernel = m_kernel != state.range(2);
31 
32  m_kernel = state.range(2);
33  // creating the measurement operator
34  const t_real FoV = 1; // deg
35  const t_real cellsize = FoV / m_imsizex * 60. * 60.;
36  const bool w_term = false;
37  m_measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
38  factory::distributed_measurement_operator::serial, m_uv_data, m_imsizey, m_imsizex,
39  cellsize, cellsize, 2, kernels::kernel::kb, m_kernel, m_kernel, w_term);
40 
41  t_real const m_sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
42  }
43 
44  void TearDown(const ::benchmark::State &state) {}
45 
46  t_real m_epsilon;
47  t_uint m_counter;
48  t_real m_sigma;
49  std::vector<std::tuple<std::string, t_uint>> const m_sara{
50  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
51  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
52  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
53 
54  Image<t_complex> m_image;
55  t_uint m_imsizex;
56  t_uint m_imsizey;
57 
59 
60  t_uint m_kernel;
61  std::shared_ptr<sopt::LinearTransform<Vector<t_complex>> const> m_measurements_transform;
62  std::shared_ptr<sopt::algorithm::ImagingProximalADMM<t_complex>> m_padmm;
63  std::shared_ptr<sopt::algorithm::ImagingForwardBackward<t_complex>> m_fb;
64 };
65 
66 BENCHMARK_DEFINE_F(AlgoFixture, Padmm)(benchmark::State &state) {
67  // Benchmark the application of the algorithm
68  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
69  factory::distributed_wavelet_operator::serial, m_sara, m_imsizey, m_imsizex);
70 
71  m_padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
72  factory::algo_distribution::serial, m_measurements_transform, wavelets, m_uv_data, m_sigma,
73  m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true, false, 1e-3, 1e-2, 50);
74 
75  while (state.KeepRunning()) {
76  auto start = std::chrono::high_resolution_clock::now();
77  (*m_padmm)();
78  auto end = std::chrono::high_resolution_clock::now();
79  state.SetIterationTime(b_utilities::duration(start, end));
80  }
81 }
82 
83 BENCHMARK_DEFINE_F(AlgoFixture, ForwardBackward)(benchmark::State &state) {
84  // Benchmark the application of the algorithm
85  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
86  factory::distributed_wavelet_operator::serial, m_sara, m_imsizey, m_imsizex);
87 
88  t_real const beta = m_sigma * m_sigma;
89  t_real const gamma = 0.0001;
90 
91  m_fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
92  factory::algo_distribution::serial, m_measurements_transform, wavelets, m_uv_data, m_sigma,
93  beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true, false, 1e-3,
94  1e-2, 50);
95 
96  while (state.KeepRunning()) {
97  auto start = std::chrono::high_resolution_clock::now();
98  (*m_fb)();
99  auto end = std::chrono::high_resolution_clock::now();
100  state.SetIterationTime(b_utilities::duration(start, end));
101  }
102 }
103 
#ifdef PURIFY_ONNXRT
// Forward-backward benchmark where the non-differentiable term is an ONNX
// denoiser model loaded from the models directory. Only the call to the
// algorithm is timed (manual timing). Compiled only when PURIFY_ONNXRT is defined.
BENCHMARK_DEFINE_F(AlgoFixture, ForwardBackwardOnnx)(benchmark::State &state) {
  using clock = std::chrono::high_resolution_clock;

  // Sparsity operator built from the fixture's SARA dictionary.
  auto const psi = factory::wavelet_operator_factory<Vector<t_complex>>(
      factory::distributed_wavelet_operator::serial, m_sara, m_imsizey, m_imsizex);

  t_real const beta = m_sigma * m_sigma;
  t_real const gamma = 0.0001;
  std::string const tf_model_path = purify::models_directory() + "/snr_15_model_dynamic.onnx";

  m_fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
      factory::algo_distribution::serial, m_measurements_transform, psi, m_uv_data, m_sigma,
      beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true, false,
      1e-3, 1e-2, 50, tf_model_path, nondiff_func_type::Denoiser);

  for (auto _ : state) {
    const auto t_begin = clock::now();
    (*m_fb)();
    const auto t_end = clock::now();
    state.SetIterationTime(b_utilities::duration(t_begin, t_end));
  }
}

// Single configuration {range(0), range(1), range(2), range(3)}; the fixture
// documents how each argument is consumed. Alternative: ->Apply(b_utilities::Arguments)
BENCHMARK_REGISTER_F(AlgoFixture, ForwardBackwardOnnx)
    ->Args({128, 10000, 4, 10})
    ->Unit(benchmark::kMillisecond)
    ->UseManualTime()
    ->MinWarmUpTime(5.0)
    ->MinTime(10.0)
    ->Repetitions(3);  // optionally ->ReportAggregatesOnly(true)
#endif
136 
137 BENCHMARK_REGISTER_F(AlgoFixture, Padmm)
138  //->Apply(b_utilities::Arguments)
139  ->Args({128, 10000, 4, 10})
140  ->UseManualTime()
141  ->MinTime(10.0)
142  ->MinWarmUpTime(5.0)
143  ->Repetitions(3) //->ReportAggregatesOnly(true)
144  ->Unit(benchmark::kMillisecond);
145 
146 BENCHMARK_REGISTER_F(AlgoFixture, ForwardBackward)
147  //->Apply(b_utilities::Arguments)
148  ->Args({128, 10000, 4, 10})
149  ->UseManualTime()
150  ->MinTime(10.0)
151  ->MinWarmUpTime(5.0)
152  ->Repetitions(3) //->ReportAggregatesOnly(true)
153  ->Unit(benchmark::kMillisecond);
154 
BENCHMARK_MAIN()
BENCHMARK_DEFINE_F(AlgoFixture, Padmm)(benchmark::State &state)
Definition: algorithms.cc:66
t_uint m_kernel
Definition: algorithms.cc:60
t_uint m_imsizex
Definition: algorithms.cc:55
void SetUp(const ::benchmark::State &state)
Definition: algorithms.cc:22
t_uint m_counter
Definition: algorithms.cc:47
utilities::vis_params m_uv_data
Definition: algorithms.cc:58
Image< t_complex > m_image
Definition: algorithms.cc:54
std::shared_ptr< sopt::algorithm::ImagingProximalADMM< t_complex > > m_padmm
Definition: algorithms.cc:62
std::shared_ptr< sopt::LinearTransform< Vector< t_complex > > const > m_measurements_transform
Definition: algorithms.cc:61
t_real m_sigma
Definition: algorithms.cc:48
t_real m_epsilon
Definition: algorithms.cc:46
void TearDown(const ::benchmark::State &state)
Definition: algorithms.cc:44
std::shared_ptr< sopt::algorithm::ImagingForwardBackward< t_complex > > m_fb
Definition: algorithms.cc:63
t_uint m_imsizey
Definition: algorithms.cc:56
bool updateMeasurements(t_uint newSize, utilities::vis_params &data)
Definition: utilities.cc:54
bool updateImage(t_uint newSize, Image< t_complex > &image, t_uint &sizex, t_uint &sizey)
Definition: utilities.cc:32
double duration(std::chrono::high_resolution_clock::time_point start, std::chrono::high_resolution_clock::time_point end)
Definition: utilities.cc:26
std::string models_directory()
Holds TF models.