2 #include <benchmark/benchmark.h>
5 #include <sopt/mpi/communicator.h>
6 #include <sopt/mpi/session.h>
54 void SetUp(const ::benchmark::State& state) {
55 m_imsizex = state.range(0);
56 m_imsizey = state.range(0);
57 b_utilities::update_comm(m_world);
58 sopt::wavelets::SARA m_sara{
59 std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
60 std::make_tuple(
"DB3", 3u), std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
61 std::make_tuple(
"DB6", 3u), std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)};
63 sopt::wavelets::SARA
const saraDistr = sopt::wavelets::distribute_sara(m_sara, m_world);
66 n_wave_coeff = saraDistr.size() * m_imsizey * m_imsizex;
67 m_Psi = sopt::linear_transform<t_complex>(saraDistr, m_imsizey, m_imsizex, m_world);
70 void TearDown(const ::benchmark::State& state) {}
76 sopt::LinearTransform<Vector<t_complex>> m_Psi = sopt::linear_transform_identity<t_complex>();
87 Vector<t_complex> image = Vector<t_complex>::Random(m_imsizey * m_imsizex);
88 Vector<t_complex>
const wavelet_coeff = Vector<t_complex>::Ones(n_wave_coeff);
90 while (m_world.broadcast<
int>(state.KeepRunning())) {
91 auto start = std::chrono::high_resolution_clock::now();
92 image = m_Psi * wavelet_coeff;
93 auto end = std::chrono::high_resolution_clock::now();
100 Vector<t_complex>
const image = Vector<t_complex>::Ones(m_imsizey * m_imsizex);
101 Vector<t_complex> wavelet_coeff = Vector<t_complex>::Zero(n_wave_coeff);
103 while (m_world.broadcast<
int>(state.KeepRunning())) {
104 auto start = std::chrono::high_resolution_clock::now();
105 wavelet_coeff = m_Psi.adjoint() * image;
106 auto end = std::chrono::high_resolution_clock::now();
114 ->Range(128, 128 << 3)
117 ->ReportAggregatesOnly(
true)
118 ->Unit(benchmark::kMillisecond);
123 ->Range(128, 128 << 3)
126 ->ReportAggregatesOnly(
true)
127 ->Unit(benchmark::kMillisecond);
void SetUp(const ::benchmark::State &state)
void TearDown(const ::benchmark::State &state)
sopt::mpi::Communicator m_world
WaveletOperatorMPIFixture()
double duration(std::chrono::high_resolution_clock::time_point start, std::chrono::high_resolution_clock::time_point end)
BENCHMARK_DEFINE_F(WaveletOperatorMPIFixture, Forward)(benchmark