PURIFY
Next-generation radio interferometric imaging
stochastic_algorithm.cc
Go to the documentation of this file.
1 #include "purify/config.h"
2 #include "purify/types.h"
3 #include <array>
4 #include <random>
5 #include <benchmark/benchmark.h>
6 #include "benchmarks/utilities.h"
8 #include "purify/directories.h"
10 #include "purify/mpi_utilities.h"
11 #include "purify/operators.h"
12 #include "purify/utilities.h"
13 #include "purify/uvw_utilities.h"
15 #include <sopt/imaging_padmm.h>
16 #include <sopt/mpi/communicator.h>
17 #include <sopt/mpi/session.h>
18 #include <sopt/power_method.h>
19 #include <sopt/relative_variation.h>
20 #include <sopt/utilities.h>
21 #include <sopt/wavelets.h>
22 #include <sopt/wavelets/sara.h>
23 
24 #ifdef PURIFY_H5
25 #include "purify/h5reader.h"
26 #endif
27 
28 using namespace purify;
29 
30 class StochasticAlgoFixture : public ::benchmark::Fixture {
31  public:
32  void SetUp(const ::benchmark::State &state) {
33  m_imsizex = state.range(0);
34  m_imsizey = state.range(0);
35 
36  m_sigma = 0.016820222945913496 * std::sqrt(2);
37  m_beta = m_sigma * m_sigma;
38  m_gamma = 0.0001;
39 
40  m_N = state.range(1);
41 
42  m_input_data_path = data_filename("expected/fb/input_data.h5");
43 
44  m_world = sopt::mpi::Communicator::World();
45  }
46 
47  void TearDown(const ::benchmark::State &state) {}
48 
49  sopt::mpi::Communicator m_world;
50 
51  std::string m_input_data_path;
52 
53  t_uint m_imsizey;
54  t_uint m_imsizex;
55 
56  t_real m_sigma;
57  t_real m_beta;
58  t_real m_gamma;
59 
60  size_t m_N;
61 
62  std::vector<std::tuple<std::string, t_uint>> const m_sara{
63  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
64  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
65  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
66 };
67 
// Benchmark: stochastic forward-backward imaging (sopt::algorithm::ImagingForwardBackward)
// driven by an update functor. Every call of the functor opens the HDF5 file at
// m_input_data_path, reads visibilities with H5::stochread_visibility(h5file, m_N, false)
// (third argument is w_term, so no w-term), builds a measurement operator and sets its
// norm from a fresh power-method run. How often fb() invokes the functor is decided
// inside sopt and is not visible here.
// Arguments: range(0) image size and range(1) N are read in SetUp; range(2) is the
// forward-backward iteration limit (itermax below).
// NOTE(review): the H5:: names come from purify/h5reader.h, which this file only includes
// under PURIFY_H5; this benchmark presumably only builds with HDF5 enabled — confirm.
68 BENCHMARK_DEFINE_F(StochasticAlgoFixture, ForwardBackward)(benchmark::State &state) {
69  // This functor would be defined in Purify
// Returns the next IterationState (visibilities + normalised operator) for fb.
70  std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater =
71  [this]() {
72  H5::H5Handler h5file(m_input_data_path, m_world);
73  utilities::vis_params uv_data = H5::stochread_visibility(h5file, m_N, false);
// NOTE(review): this listing skips from line 75 to 77, so the leading arguments of this
// call (presumably including uv_data) are not visible here.
75  auto phi = factory::measurement_operator_factory<Vector<t_complex>>(
77  m_imsizey, 1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
78 
// The norm is recomputed for every new operator (contrast ForwardBackwardApproxNorm,
// which caches it), so each timed fb() call includes this power-method cost whenever
// the functor is invoked. 1000 / 1e-5 are presumably the iteration cap and tolerance;
// the start vector is all ones, broadcast over m_world.
79  auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
80  *phi, 1000, 1e-5,
81  m_world.broadcast(Vector<t_complex>::Ones(m_imsizex * m_imsizey).eval()));
82 
// The first element of the power-method result is used as the operator norm.
83  const t_real op_norm = std::get<0>(power_method_stuff);
84  phi->set_norm(op_norm);
85 
86  return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data.vis, phi);
87  };
88 
89  // wavelets
90  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
91  factory::distributed_wavelet_operator::serial, m_sara, m_imsizey, m_imsizex);
92 
93  // algorithm
94  sopt::algorithm::ImagingForwardBackward<t_complex> fb(random_updater);
// NOTE(review): m_sigma already carries a sqrt(2) factor from SetUp, so sigma here is
// 0.0168... * 2 and step_size is m_sigma^2 * sqrt(2); confirm the double scaling is intended.
95  fb.itermax(state.range(2))
96  .step_size(m_beta * sqrt(2))
97  .sigma(m_sigma * sqrt(2))
98  .regulariser_strength(m_gamma)
99  .relative_variation(1e-3)
100  .residual_tolerance(0)
101  .tight_frame(true)
102  .obj_comm(m_world);
103 
// L1 proximal over the SARA wavelet dictionary, used as fb's g function, with positivity
// and real-valuedness constraints enabled. The meaning of the constructor flag (false)
// is not visible here.
104  auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(false);
105  gp->l1_proximal_tolerance(1e-4)
106  .l1_proximal_nu(1)
107  .l1_proximal_itermax(50)
108  .l1_proximal_positivity_constraint(true)
109  .l1_proximal_real_constraint(true)
110  .Psi(*wavelets);
111  fb.g_function(gp);
112 
113  PURIFY_INFO("Start iteration loop");
114 
// Manual timing: only the fb() call is timed. b_utilities::duration receives m_world;
// how it combines per-rank times is not visible here.
115  while (state.KeepRunning()) {
116  auto start = std::chrono::high_resolution_clock::now();
117  fb();
118  auto end = std::chrono::high_resolution_clock::now();
119  state.SetIterationTime(b_utilities::duration(start, end, m_world));
120  }
121 }
122 
// Benchmark: same stochastic forward-backward setup as ForwardBackward, except that the
// measurement-operator norm is computed only once and reused for every later operator
// (an approximation that avoids re-running the power method on each update).
// Arguments: range(0) image size and range(1) N are read in SetUp; range(2) is itermax.
// NOTE(review): the H5:: names come from purify/h5reader.h, which this file only includes
// under PURIFY_H5; this benchmark presumably only builds with HDF5 enabled — confirm.
123 BENCHMARK_DEFINE_F(StochasticAlgoFixture, ForwardBackwardApproxNorm)(benchmark::State &state) {
124  // This functor would be defined in Purify
// Returns the next IterationState (visibilities + operator with the cached norm) for fb.
125  std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater =
126  [this]() {
127  H5::H5Handler h5file(m_input_data_path, m_world);
128  utilities::vis_params uv_data = H5::stochread_visibility(h5file, m_N, false);
// NOTE(review): this listing skips from line 130 to 132, so the leading arguments of this
// call (presumably including uv_data) are not visible here.
130  auto phi = factory::measurement_operator_factory<Vector<t_complex>>(
132  m_imsizey, 1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
133 
134  // declaration of static variables to avoid recalculating the normalisation
// Function-local statics are initialised on the first call only. The norm therefore comes
// from the operator built for the first stochastic sample. Because every closure created
// from this lambda shares the same statics, it also persists across benchmark repetitions
// and any other Args registered for this benchmark in the same process.
// NOTE(review): if another image size were registered, this cached norm (and start-vector
// size) would belong to the first size — keep a single Args set or key the cache.
135  static auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
136  *phi, 1000, 1e-5,
137  m_world.broadcast(Vector<t_complex>::Ones(m_imsizex * m_imsizey).eval()));
138 
139  static const t_real op_norm = std::get<0>(power_method_stuff);
140 
141  // set the normalisation of the new phi
142  phi->set_norm(op_norm);
143 
144  return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data.vis, phi);
145  };
146 
147  // wavelets
148  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
149  factory::distributed_wavelet_operator::serial, m_sara, m_imsizey, m_imsizex);
150 
151  // algorithm
152  sopt::algorithm::ImagingForwardBackward<t_complex> fb(random_updater);
// NOTE(review): m_sigma already carries a sqrt(2) factor from SetUp, so sigma here is
// 0.0168... * 2 and step_size is m_sigma^2 * sqrt(2); confirm the double scaling is intended.
153  fb.itermax(state.range(2))
154  .step_size(m_beta * sqrt(2))
155  .sigma(m_sigma * sqrt(2))
156  .regulariser_strength(m_gamma)
157  .relative_variation(1e-3)
158  .residual_tolerance(0)
159  .tight_frame(true)
160  .obj_comm(m_world);
161 
// L1 proximal over the SARA wavelet dictionary, used as fb's g function, with positivity
// and real-valuedness constraints enabled. The meaning of the constructor flag (false)
// is not visible here.
162  auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(false);
163  gp->l1_proximal_tolerance(1e-4)
164  .l1_proximal_nu(1)
165  .l1_proximal_itermax(50)
166  .l1_proximal_positivity_constraint(true)
167  .l1_proximal_real_constraint(true)
168  .Psi(*wavelets);
169  fb.g_function(gp);
170 
171  PURIFY_INFO("Start iteration loop");
172 
// Manual timing: only the fb() call is timed. b_utilities::duration receives m_world;
// how it combines per-rank times is not visible here.
173  while (state.KeepRunning()) {
174  auto start = std::chrono::high_resolution_clock::now();
175  fb();
176  auto end = std::chrono::high_resolution_clock::now();
177  state.SetIterationTime(b_utilities::duration(start, end, m_world));
178  }
179 }
180 
181 BENCHMARK_REGISTER_F(StochasticAlgoFixture, ForwardBackward)
182  ->Args({128, 10000, 10})
183  ->UseManualTime()
184  ->MinTime(60.0)
185  ->MinWarmUpTime(5.0)
186  ->Repetitions(3) //->ReportAggregatesOnly(true)
187  ->Unit(benchmark::kMillisecond);
188 
189 BENCHMARK_REGISTER_F(StochasticAlgoFixture, ForwardBackwardApproxNorm)
190  ->Args({128, 10000, 10})
191  ->UseManualTime()
192  ->MinTime(60.0)
193  ->MinWarmUpTime(5.0)
194  ->Repetitions(3) //->ReportAggregatesOnly(true)
195  ->Unit(benchmark::kMillisecond);
void SetUp(const ::benchmark::State &state)
void TearDown(const ::benchmark::State &state)
sopt::mpi::Communicator m_world
Purify interface class to handle HDF5 input files.
Definition: h5reader.h:48
#define PURIFY_INFO(...)
Definition: logging.h:195
double duration(std::chrono::high_resolution_clock::time_point start, std::chrono::high_resolution_clock::time_point end)
Definition: utilities.cc:26
utilities::vis_params stochread_visibility(H5Handler &file, const size_t N, const bool w_term)
Stochastically reads dataset slices from the supplied HDF5-file handler and constructs a vis_params object.
Definition: h5reader.h:206
const std::map< std::string, kernel > kernel_from_string
Definition: kernels.h:16
std::function< bool()> random_updater(const sopt::mpi::Communicator &comm, const t_int total, const t_int update_size, const std::shared_ptr< bool > update_pointer, const std::string &update_name)
std::string data_filename(std::string const &filename)
Holds data and such.
BENCHMARK_DEFINE_F(StochasticAlgoFixture, ForwardBackward)(benchmark
Vector< t_complex > vis
Definition: uvw_utilities.h:22