PURIFY
Next-generation radio interferometric imaging
Classes | Functions
stochastic_algorithm.cc File Reference
#include "purify/config.h"
#include "purify/types.h"
#include <array>
#include <random>
#include <benchmark/benchmark.h>
#include "benchmarks/utilities.h"
#include "purify/algorithm_factory.h"
#include "purify/directories.h"
#include "purify/measurement_operator_factory.h"
#include "purify/mpi_utilities.h"
#include "purify/operators.h"
#include "purify/utilities.h"
#include "purify/uvw_utilities.h"
#include "purify/wavelet_operator_factory.h"
#include <sopt/imaging_padmm.h>
#include <sopt/mpi/communicator.h>
#include <sopt/mpi/session.h>
#include <sopt/power_method.h>
#include <sopt/relative_variation.h>
#include <sopt/utilities.h>
#include <sopt/wavelets.h>
#include <sopt/wavelets/sara.h>
+ Include dependency graph for stochastic_algorithm.cc:

Go to the source code of this file.

Classes

class  StochasticAlgoFixture
 

Functions

 BENCHMARK_DEFINE_F (StochasticAlgoFixture, ForwardBackward)(benchmark::State &state)
 
 BENCHMARK_DEFINE_F (StochasticAlgoFixture, ForwardBackwardApproxNorm)(benchmark::State &state)
 
 Args({128, 10000, 10})->UseManualTime()->MinTime(60.0)->MinWarmUpTime(5.0)->Repetitions(3)->Unit(benchmark::kMillisecond)
 

Function Documentation

◆ Args()

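Args({128, 10000, 10})->UseManualTime()->MinTime(60.0)->MinWarmUpTime(5.0)->Repetitions(3)->Unit(benchmark::kMillisecond)

Doxygen lists this as a function named `Args()`, but it is actually the tail of a Google Benchmark registration chain. Each run receives the argument tuple {128, 10000, 10}; the benchmarks below read the third value as `state.range(2)` and use it as the iteration limit (`itermax`), while the first two values are not read in the listings on this page. Timing is taken from the manually measured `SetIterationTime` values and reported in milliseconds. Each benchmark warms up for at least 5 s, runs for at least 60 s, and is repeated three times.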

◆ BENCHMARK_DEFINE_F() [1/2]

BENCHMARK_DEFINE_F(StochasticAlgoFixture, ForwardBackward)(benchmark::State &state)

Definition at line 68 of file stochastic_algorithm.cc.

68  {
69  // This functor would be defined in Purify
70  std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater =
71  [this]() {
72  H5::H5Handler h5file(m_input_data_path, m_world);
73  utilities::vis_params uv_data = H5::stochread_visibility(h5file, m_N, false);
74  uv_data.units = utilities::vis_units::radians;
75  auto phi = factory::measurement_operator_factory<Vector<t_complex>>(
76  factory::distributed_measurement_operator::mpi_distribute_image, uv_data, m_imsizex,
77  m_imsizey, 1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
78 
79  auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
80  *phi, 1000, 1e-5,
81  m_world.broadcast(Vector<t_complex>::Ones(m_imsizex * m_imsizey).eval()));
82 
83  const t_real op_norm = std::get<0>(power_method_stuff);
84  phi->set_norm(op_norm);
85 
86  return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data.vis, phi);
87  };
88 
89  // wavelets
90  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
91  factory::distributed_wavelet_operator::serial, m_sara, m_imsizey, m_imsizex);
92 
93  // algorithm
94  sopt::algorithm::ImagingForwardBackward<t_complex> fb(random_updater);
95  fb.itermax(state.range(2))
96  .step_size(m_beta * sqrt(2))
97  .sigma(m_sigma * sqrt(2))
98  .regulariser_strength(m_gamma)
99  .relative_variation(1e-3)
100  .residual_tolerance(0)
101  .tight_frame(true)
102  .obj_comm(m_world);
103 
104  auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(false);
105  gp->l1_proximal_tolerance(1e-4)
106  .l1_proximal_nu(1)
107  .l1_proximal_itermax(50)
108  .l1_proximal_positivity_constraint(true)
109  .l1_proximal_real_constraint(true)
110  .Psi(*wavelets);
111  fb.g_function(gp);
112 
113  PURIFY_INFO("Start iteration loop");
114 
115  while (state.KeepRunning()) {
116  auto start = std::chrono::high_resolution_clock::now();
117  fb();
118  auto end = std::chrono::high_resolution_clock::now();
119  state.SetIterationTime(b_utilities::duration(start, end, m_world));
120  }
121 }
purify::H5::H5Handler — Purify interface class to handle HDF5 input files.
Definition: h5reader.h:48
#define PURIFY_INFO(...)
Definition: logging.h:195
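double duration(std::chrono::high_resolution_clock::time_point start, std::chrono::high_resolution_clock::time_point end)
(Note: the listings on this page call a three-argument form, `b_utilities::duration(start, end, m_world)`, which additionally takes the fixture's `m_world` communicator; Doxygen links it to the two-argument signature shown here.)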
Definition: utilities.cc:26
utilities::vis_params stochread_visibility(H5Handler &file, const size_t N, const bool w_term)
Stochastically reads dataset slices from the supplied HDF5-file handler and constructs a vis_params object.
Definition: h5reader.h:206
const std::map< std::string, kernel > kernel_from_string
Definition: kernels.h:16
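std::function< bool()> random_updater(const sopt::mpi::Communicator &comm, const t_int total, const t_int update_size, const std::shared_ptr< bool > update_pointer, const std::string &update_name)
(Note: this tooltip describes the library function purify::random_updater. The `random_updater` in the listings on this page is a different entity: a local lambda of type std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()>. Doxygen cross-linked the two by name only.)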
Vector< t_complex > vis
Definition: uvw_utilities.h:22

References b_utilities::duration(), purify::kernels::kernel_from_string, purify::factory::mpi_distribute_image, PURIFY_INFO, purify::utilities::radians, purify::random_updater::random_updater(), purify::factory::serial, purify::H5::stochread_visibility(), purify::utilities::vis_params::units, and purify::utilities::vis_params::vis.

◆ BENCHMARK_DEFINE_F() [2/2]

BENCHMARK_DEFINE_F(StochasticAlgoFixture, ForwardBackwardApproxNorm)(benchmark::State &state)

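Definition at line 123 of file stochastic_algorithm.cc.

This benchmark is identical to ForwardBackward except for how the operator norm is obtained. In ForwardBackward, every call of the updater lambda runs the power method on the newly built measurement operator. Here, `power_method_stuff` and `op_norm` are function-local `static` variables inside the lambda, so the power method runs only on the first call. The resulting norm is then reused for every subsequently sampled operator, hence "ApproxNorm". Because the values are static, they persist for the lifetime of the process, including across benchmark repetitions.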

123  {
124  // This functor would be defined in Purify
125  std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater =
126  [this]() {
127  H5::H5Handler h5file(m_input_data_path, m_world);
128  utilities::vis_params uv_data = H5::stochread_visibility(h5file, m_N, false);
129  uv_data.units = utilities::vis_units::radians;
130  auto phi = factory::measurement_operator_factory<Vector<t_complex>>(
131  factory::distributed_measurement_operator::mpi_distribute_image, uv_data, m_imsizex,
132  m_imsizey, 1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
133 
134  // declaration of static variables to avoid recalculating the normalisation
135  static auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
136  *phi, 1000, 1e-5,
137  m_world.broadcast(Vector<t_complex>::Ones(m_imsizex * m_imsizey).eval()));
138 
139  static const t_real op_norm = std::get<0>(power_method_stuff);
140 
141  // set the normalisation of the new phi
142  phi->set_norm(op_norm);
143 
144  return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data.vis, phi);
145  };
146 
147  // wavelets
148  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
149  factory::distributed_wavelet_operator::serial, m_sara, m_imsizey, m_imsizex);
150 
151  // algorithm
152  sopt::algorithm::ImagingForwardBackward<t_complex> fb(random_updater);
153  fb.itermax(state.range(2))
154  .step_size(m_beta * sqrt(2))
155  .sigma(m_sigma * sqrt(2))
156  .regulariser_strength(m_gamma)
157  .relative_variation(1e-3)
158  .residual_tolerance(0)
159  .tight_frame(true)
160  .obj_comm(m_world);
161 
162  auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(false);
163  gp->l1_proximal_tolerance(1e-4)
164  .l1_proximal_nu(1)
165  .l1_proximal_itermax(50)
166  .l1_proximal_positivity_constraint(true)
167  .l1_proximal_real_constraint(true)
168  .Psi(*wavelets);
169  fb.g_function(gp);
170 
171  PURIFY_INFO("Start iteration loop");
172 
173  while (state.KeepRunning()) {
174  auto start = std::chrono::high_resolution_clock::now();
175  fb();
176  auto end = std::chrono::high_resolution_clock::now();
177  state.SetIterationTime(b_utilities::duration(start, end, m_world));
178  }
179 }

References b_utilities::duration(), purify::kernels::kernel_from_string, purify::factory::mpi_distribute_image, PURIFY_INFO, purify::utilities::radians, purify::random_updater::random_updater(), purify::factory::serial, purify::H5::stochread_visibility(), purify::utilities::vis_params::units, and purify::utilities::vis_params::vis.