7 #include "purify/directories.h"
15 #include <sopt/imaging_padmm.h>
16 #include <sopt/mpi/communicator.h>
17 #include <sopt/mpi/session.h>
18 #include <sopt/relative_variation.h>
19 #include <sopt/utilities.h>
20 #include <sopt/wavelets.h>
21 #include <sopt/wavelets/sara.h>
27 void SetUp(const ::benchmark::State &state) {
33 newImage, m_image, m_world);
35 bool newKernel = m_kernel != state.range(2);
37 m_kernel = state.range(2);
40 const t_real cellsize = FoV / m_imsizex * 60. * 60.;
41 const bool w_term =
false;
42 if (state.range(4) == 1) {
43 PURIFY_INFO(
"Using distributed image MPI algorithm");
44 m_measurements_distribute_image = factory::measurement_operator_factory<Vector<t_complex>>(
50 if (state.range(4) == 2) {
51 PURIFY_INFO(
"Using distributed grid MPI algorithm");
52 m_measurements_distribute_grid = factory::measurement_operator_factory<Vector<t_complex>>(
57 m_sigma = 0.016820222945913496 * std::sqrt(2);
// Fixture teardown hook. The body is intentionally empty: no cleanup is
// performed here and the `state` argument is unused.
60 void TearDown(const ::benchmark::State &state) {}
// Constant list of nine (name, t_uint) tuples: "Dirac" followed by "DB1"
// through "DB8", each paired with 3u. Its size() appears in the argument
// lists of the padmm_factory / fb_factory calls in the benchmark bodies.
// NOTE(review): the 3u is presumably the number of wavelet decomposition
// levels per basis - confirm against the wavelet operator factory.
64 std::vector<std::tuple<std::string, t_uint>>
const m_sara{
65 std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
66 std::make_tuple(
"DB3", 3u), std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
67 std::make_tuple(
"DB6", 3u), std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)};
// Shared handle to the proximal-ADMM imaging solver. The PADMM benchmark
// bodies assign it from factory::padmm_factory and then invoke it as
// (*m_padmm)() inside their timing loops.
80 std::shared_ptr<sopt::algorithm::ImagingProximalADMM<t_complex>>
m_padmm;
// Shared handle to the forward-backward imaging solver. The FB benchmark
// bodies assign it from factory::fb_factory and then invoke it as
// (*m_fb)() inside their timing loops.
81 std::shared_ptr<sopt::algorithm::ImagingForwardBackward<t_complex>>
m_fb;
87 auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
90 m_padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
92 m_uv_data, m_sigma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1,
true,
true,
93 false, 1e-3, 1e-2, 50);
96 while (state.KeepRunning()) {
97 auto start = std::chrono::high_resolution_clock::now();
98 auto result = (*m_padmm)();
99 auto end = std::chrono::high_resolution_clock::now();
100 std::cout <<
"Converged? " << result.good <<
" , niters = " << result.niters << std::endl;
108 auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
111 m_padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
113 m_uv_data, m_sigma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1,
true,
true,
114 false, 1e-3, 1e-2, 50);
117 while (state.KeepRunning()) {
118 auto start = std::chrono::high_resolution_clock::now();
119 auto result = (*m_padmm)();
120 auto end = std::chrono::high_resolution_clock::now();
121 std::cout <<
"Converged? " << result.good <<
" , niters = " << result.niters << std::endl;
129 auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
132 t_real
const beta = m_sigma * m_sigma;
133 t_real
const gamma = 0.0001;
135 m_fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
137 m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3),
true,
true,
false,
141 while (state.KeepRunning()) {
142 auto start = std::chrono::high_resolution_clock::now();
143 auto result = (*m_fb)();
144 auto end = std::chrono::high_resolution_clock::now();
145 std::cout <<
"Converged? " << result.good <<
" , niters = " << result.niters << std::endl;
153 auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
156 t_real
const beta = m_sigma * m_sigma;
157 t_real
const gamma = 0.0001;
159 m_fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
161 m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3),
true,
true,
false,
165 while (state.KeepRunning()) {
166 auto start = std::chrono::high_resolution_clock::now();
167 auto result = (*m_fb)();
168 auto end = std::chrono::high_resolution_clock::now();
169 std::cout <<
"Converged? " << result.good <<
" , niters = " << result.niters << std::endl;
180 auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
183 t_real
const beta = m_sigma * m_sigma;
184 t_real
const gamma = 0.0001;
188 m_fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
190 m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3),
true,
true,
false,
194 while (state.KeepRunning()) {
195 auto start = std::chrono::high_resolution_clock::now();
196 auto result = (*m_fb)();
197 auto end = std::chrono::high_resolution_clock::now();
198 std::cout <<
"Converged? " << result.good <<
" , niters = " << result.niters << std::endl;
208 auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
211 t_real
const beta = m_sigma * m_sigma;
212 t_real
const gamma = 0.0001;
216 m_fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
218 m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3),
true,
true,
false,
222 while (state.KeepRunning()) {
223 auto start = std::chrono::high_resolution_clock::now();
224 auto result = (*m_fb)();
225 auto end = std::chrono::high_resolution_clock::now();
226 std::cout <<
"Converged? " << result.good <<
" , niters = " << result.niters << std::endl;
233 ->Args({128, 10000, 4, 10, 1})
234 ->
Args({1024,
static_cast<t_int
>(1e6), 4, 10, 1})
235 ->
Args({1024,
static_cast<t_int
>(1e7), 4, 10, 1})
236 ->
Args({2048,
static_cast<t_int
>(1e6), 4, 10, 1})
237 ->
Args({2048,
static_cast<t_int
>(1e7), 4, 10, 1})
238 ->
Args({4096,
static_cast<t_int
>(1e6), 4, 10, 1})
239 ->
Args({4096,
static_cast<t_int
>(1e7), 4, 10, 1})
242 ->MinWarmUpTime(10.0)
244 ->Unit(benchmark::kMillisecond);
248 ->Args({128, 10000, 4, 10, 1})
249 ->
Args({1024,
static_cast<t_int
>(1e6), 4, 10, 2})
250 ->
Args({1024,
static_cast<t_int
>(1e7), 4, 10, 2})
251 ->
Args({2048,
static_cast<t_int
>(1e6), 4, 10, 2})
252 ->
Args({2048,
static_cast<t_int
>(1e7), 4, 10, 2})
253 ->
Args({4096,
static_cast<t_int
>(1e6), 4, 10, 2})
254 ->
Args({4096,
static_cast<t_int
>(1e7), 4, 10, 2})
259 ->Unit(benchmark::kMillisecond);
265 ->Args({128, 10000, 4, 10, 1})
266 ->
Args({1024,
static_cast<t_int
>(1e6), 4, 10, 1})
267 ->
Args({1024,
static_cast<t_int
>(1e7), 4, 10, 1})
268 ->
Args({2048,
static_cast<t_int
>(1e6), 4, 10, 1})
269 ->
Args({2048,
static_cast<t_int
>(1e7), 4, 10, 1})
270 ->
Args({4096,
static_cast<t_int
>(1e6), 4, 10, 1})
271 ->
Args({4096,
static_cast<t_int
>(1e7), 4, 10, 1})
274 ->MinWarmUpTime(10.0)
276 ->Unit(benchmark::kMillisecond);
280 ->Args({128, 10000, 4, 10, 2})
281 ->
Args({1024,
static_cast<t_int
>(1e6), 4, 10, 2})
282 ->
Args({1024,
static_cast<t_int
>(1e7), 4, 10, 2})
283 ->
Args({2048,
static_cast<t_int
>(1e6), 4, 10, 2})
284 ->
Args({2048,
static_cast<t_int
>(1e7), 4, 10, 2})
285 ->
Args({4096,
static_cast<t_int
>(1e6), 4, 10, 2})
286 ->
Args({4096,
static_cast<t_int
>(1e7), 4, 10, 2})
289 ->MinWarmUpTime(10.0)
291 ->Unit(benchmark::kMillisecond);
295 ->Args({128, 10000, 4, 10, 1})
296 ->
Args({1024,
static_cast<t_int
>(1e6), 4, 10, 1})
297 ->
Args({1024,
static_cast<t_int
>(1e7), 4, 10, 1})
298 ->
Args({2048,
static_cast<t_int
>(1e6), 4, 10, 1})
299 ->
Args({2048,
static_cast<t_int
>(1e7), 4, 10, 1})
300 ->
Args({4096,
static_cast<t_int
>(1e6), 4, 10, 1})
301 ->
Args({4096,
static_cast<t_int
>(1e7), 4, 10, 1})
304 ->MinWarmUpTime(10.0)
306 ->Unit(benchmark::kMillisecond);
310 ->Args({128, 10000, 4, 10, 2})
311 ->
Args({1024,
static_cast<t_int
>(1e6), 4, 10, 2})
312 ->
Args({1024,
static_cast<t_int
>(1e7), 4, 10, 2})
313 ->
Args({2048,
static_cast<t_int
>(1e6), 4, 10, 2})
314 ->
Args({2048,
static_cast<t_int
>(1e7), 4, 10, 2})
315 ->
Args({4096,
static_cast<t_int
>(1e6), 4, 10, 2})
316 ->
Args({4096,
static_cast<t_int
>(1e7), 4, 10, 2})
319 ->MinWarmUpTime(10.0)
321 ->Unit(benchmark::kMillisecond);
Args({128, 10000, 4, 10, 1}) -> Args({1024, static_cast< t_int >(1e6), 4, 10, 1}) ->Args({1024, static_cast< t_int >(1e7), 4, 10, 1}) ->Args({2048, static_cast< t_int >(1e6), 4, 10, 1}) ->Args({2048, static_cast< t_int >(1e7), 4, 10, 1}) ->Args({4096, static_cast< t_int >(1e6), 4, 10, 1}) ->Args({4096, static_cast< t_int >(1e7), 4, 10, 1}) ->UseManualTime() ->MinTime(60.0) ->MinWarmUpTime(10.0) ->Repetitions(3) ->Unit(benchmark::kMillisecond)
BENCHMARK_DEFINE_F(AlgoFixtureMPI, PadmmDistributeImage)(benchmark
std::shared_ptr< sopt::LinearTransform< Vector< t_complex > > const > m_measurements_distribute_image
std::shared_ptr< sopt::algorithm::ImagingForwardBackward< t_complex > > m_fb
std::shared_ptr< sopt::algorithm::ImagingProximalADMM< t_complex > > m_padmm
sopt::mpi::Communicator m_world
Image< t_complex > m_image
void SetUp(const ::benchmark::State &state)
utilities::vis_params m_uv_data
void TearDown(const ::benchmark::State &state)
std::shared_ptr< sopt::LinearTransform< Vector< t_complex > > const > m_measurements_distribute_grid
bool updateMeasurements(t_uint newSize, utilities::vis_params &data)
bool updateImage(t_uint newSize, Image< t_complex > &image, t_uint &sizex, t_uint &sizey)
double duration(std::chrono::high_resolution_clock::time_point start, std::chrono::high_resolution_clock::time_point end)
std::string models_directory()
Holds TF models.