149 auto const world = sopt::mpi::Communicator::World();
151 const std::string &
test_dir =
"expected/primal_dual/";
155 uv_data.
units = utilities::vis_units::radians;
156 if (world.is_root()) {
157 CAPTURE(uv_data.vis.head(5));
159 REQUIRE(world.all_sum_all(uv_data.size()) == 13107);
161 t_uint
const imsizey = 128;
162 t_uint
const imsizex = 128;
164 auto const measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
165 factory::distributed_measurement_operator::mpi_distribute_image, uv_data, imsizey, imsizex, 1,
167 auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
168 *measurements_transform, 1000, 1e-5,
169 world.broadcast(Vector<t_complex>::Ones(imsizex * imsizey).eval()));
170 const t_real op_norm = std::get<0>(power_method_stuff);
171 measurements_transform->set_norm(op_norm);
173 std::vector<std::tuple<std::string, t_uint>>
const sara{
174 std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
175 std::make_tuple(
"DB3", 3u), std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
176 std::make_tuple(
"DB6", 3u), std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)};
177 auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
178 factory::distributed_wavelet_operator::mpi_sara, sara, imsizey, imsizex);
180 world.broadcast(0.016820222945913496) * std::sqrt(2);
182 auto const primaldual =
183 factory::primaldual_factory<sopt::algorithm::ImagingPrimalDual<t_complex>>(
184 factory::algo_distribution::mpi_serial, measurements_transform, wavelets, uv_data,
185 sigma, imsizey, imsizex, sara.size(), 500,
true,
true, 1e-2, 1);
187 auto const diagnostic = (*primaldual)();
188 CHECK(diagnostic.niters == 16);
196 const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
197 CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
198 CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
199 CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
200 CHECK(image.isApprox(solution, 1e-4));
202 const Vector<t_complex> residuals = measurements_transform->adjoint() *
203 (uv_data.vis - ((*measurements_transform) * diagnostic.x));
204 const Image<t_complex> residual_image =
205 Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
206 CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
207 CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
208 CHECK(residual_image.real().isApprox(residual.real(), 1e-4));
211 auto const primaldual =
212 factory::primaldual_factory<sopt::algorithm::ImagingPrimalDual<t_complex>>(
213 factory::algo_distribution::mpi_distributed, measurements_transform, wavelets, uv_data,
214 sigma, imsizey, imsizex, sara.size(), 500,
true,
true, 1e-2, 1);
216 auto const diagnostic = (*primaldual)();
218 world.broadcast(sigma));
219 CHECK(sopt::mpi::l2_norm(diagnostic.residual, primaldual->l2ball_proximal_weights(), world) <
223 if (world.size() > 2 or world.size() == 0)
return;
225 const std::string &expected_solution_path = (world.size() == 2)
228 const std::string &expected_residual_path = (world.size() == 2)
231 if (world.size() == 1)
CHECK(diagnostic.niters == 16);
232 if (world.size() == 2)
CHECK(diagnostic.niters == 18);
236 const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
238 CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
239 CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
240 CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
241 CHECK(image.isApprox(solution, 1e-4));
243 const Vector<t_complex> residuals = measurements_transform->adjoint() *
244 (uv_data.vis - ((*measurements_transform) * diagnostic.x));
245 const Image<t_complex> residual_image =
246 Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
248 CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
249 CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
250 CHECK(residual_image.real().isApprox(residual.real(), 1e-4));
252 SECTION(
"random update") {
253 auto const measurements_transform_serial =
254 factory::measurement_operator_factory<Vector<t_complex>>(
255 factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2,
257 auto const power_method_stuff = sopt::algorithm::all_sum_all_power_method<Vector<t_complex>>(
258 world, *measurements_transform, 1000, 1e-5,
259 world.broadcast(Vector<t_complex>::Ones(imsizex * imsizey).eval()));
260 const t_real op_norm = std::get<0>(power_method_stuff);
261 measurements_transform->set_norm(op_norm);
263 auto sara_dist = sopt::wavelets::distribute_sara(sara, world);
264 auto const wavelets_serial = factory::wavelet_operator_factory<Vector<t_complex>>(
265 factory::distributed_wavelet_operator::serial, sara_dist, imsizey, imsizex);
267 auto const primaldual =
268 factory::primaldual_factory<sopt::algorithm::ImagingPrimalDual<t_complex>>(
269 factory::algo_distribution::mpi_random_updates, measurements_transform_serial,
270 wavelets_serial, uv_data, sigma, imsizey, imsizex, sara_dist.size(), 500,
true,
true,
273 auto const diagnostic = (*primaldual)();
275 world.broadcast(sigma));
276 CHECK(sopt::mpi::l2_norm(diagnostic.residual, primaldual->l2ball_proximal_weights(), world) <
278 if (world.size() > 1)
return;
281 if (world.size() == 0)
283 else if (world.size() == 2 or world.size() == 1) {
285 const std::string &expected_solution_path =
288 const std::string &expected_residual_path =
291 if (world.size() == 1)
CHECK(diagnostic.niters == 16);
292 if (world.size() == 2)
CHECK(diagnostic.niters < 100);
297 const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
299 CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
300 CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
302 Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
303 CHECK(image.isApprox(solution, 1e-3));
305 const Vector<t_complex> residuals =
306 measurements_transform->adjoint() *
307 (uv_data.vis - ((*measurements_transform) * diagnostic.x));
308 const Image<t_complex> residual_image =
309 Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
311 CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
312 CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
313 CHECK(residual_image.real().isApprox(residual.real(), 1e-3));