1 #include "purify/config.h"
4 #include <benchmark/benchmark.h>
7 #include "purify/directories.h"
12 #include <sopt/imaging_padmm.h>
13 #include <sopt/relative_variation.h>
14 #include <sopt/utilities.h>
15 #include <sopt/wavelets.h>
16 #include <sopt/wavelets/sara.h>
22 void SetUp(const ::benchmark::State &state) {
27 bool newMeasurements =
30 bool newKernel = m_kernel != state.range(2);
32 m_kernel = state.range(2);
35 const t_real cellsize = FoV / m_imsizex * 60. * 60.;
36 const bool w_term =
false;
37 m_measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
41 t_real
const m_sigma = 0.016820222945913496 * std::sqrt(2);
44 void TearDown(const ::benchmark::State &state) {}
49 std::vector<std::tuple<std::string, t_uint>>
const m_sara{
50 std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
51 std::make_tuple(
"DB3", 3u), std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
52 std::make_tuple(
"DB6", 3u), std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)};
62 std::shared_ptr<sopt::algorithm::ImagingProximalADMM<t_complex>>
m_padmm;
63 std::shared_ptr<sopt::algorithm::ImagingForwardBackward<t_complex>>
m_fb;
68 auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
71 m_padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
73 m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1,
true,
true,
false, 1e-3, 1e-2, 50);
75 while (state.KeepRunning()) {
76 auto start = std::chrono::high_resolution_clock::now();
78 auto end = std::chrono::high_resolution_clock::now();
85 auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
88 t_real
const beta = m_sigma * m_sigma;
89 t_real
const gamma = 0.0001;
91 m_fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
93 beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1,
true,
true,
false, 1e-3,
96 while (state.KeepRunning()) {
97 auto start = std::chrono::high_resolution_clock::now();
99 auto end = std::chrono::high_resolution_clock::now();
107 auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
110 t_real
const beta = m_sigma * m_sigma;
111 t_real
const gamma = 0.0001;
114 m_fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
116 beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1,
true,
true,
false, 1e-3,
119 while (state.KeepRunning()) {
120 auto start = std::chrono::high_resolution_clock::now();
122 auto end = std::chrono::high_resolution_clock::now();
127 BENCHMARK_REGISTER_F(
AlgoFixture, ForwardBackwardOnnx)
129 ->Args({128, 10000, 4, 10})
134 ->Unit(benchmark::kMillisecond);
139 ->Args({128, 10000, 4, 10})
144 ->Unit(benchmark::kMillisecond);
148 ->Args({128, 10000, 4, 10})
153 ->Unit(benchmark::kMillisecond);
BENCHMARK_DEFINE_F(AlgoFixture, Padmm)(benchmark
void SetUp(const ::benchmark::State &state)
utilities::vis_params m_uv_data
Image< t_complex > m_image
std::shared_ptr< sopt::algorithm::ImagingProximalADMM< t_complex > > m_padmm
std::shared_ptr< sopt::LinearTransform< Vector< t_complex > > const > m_measurements_transform
void TearDown(const ::benchmark::State &state)
std::shared_ptr< sopt::algorithm::ImagingForwardBackward< t_complex > > m_fb
bool updateMeasurements(t_uint newSize, utilities::vis_params &data)
bool updateImage(t_uint newSize, Image< t_complex > &image, t_uint &sizex, t_uint &sizey)
double duration(std::chrono::high_resolution_clock::time_point start, std::chrono::high_resolution_clock::time_point end)
std::string models_directory()
Holds TF models.