PURIFY
Next-generation radio interferometric imaging
convergence_factory.h
Go to the documentation of this file.
1 #ifndef CONVERGENCE_FACTORY_H
2 #define CONVERGENCE_FACTORY_H
3 
4 #include <sopt/imaging_padmm.h>
5 #ifdef PURIFY_MPI
6 #include <sopt/mpi/communicator.h>
7 #endif
8 #include <sopt/relative_variation.h>
9 
10 namespace purify {
11 namespace factory {
13 
14 #ifdef PURIFY_MPI
15 
16 template <class T, class Algo>
17 std::function<bool(Vector<T> const &, Vector<T> const &)> l2_convergence_factory(
18  const ConvergenceType algo, std::weak_ptr<Algo> const algo_weak) {
19  const auto comm = sopt::mpi::Communicator::World();
20  switch (algo) {
22  return [algo, algo_weak, comm](Vector<T> const &, Vector<T> const &residual) {
23  auto const algo = algo_weak.lock();
24  auto const residual_norm = sopt::l2_norm(residual, algo->l2ball_proximal_weights());
25  SOPT_LOW_LOG(" - [Algorithm] Residuals: {} <? {}", residual_norm,
26  algo->residual_tolerance());
27  return static_cast<bool>(
28  comm.all_reduce<int8_t>(residual_norm < algo->residual_tolerance(), MPI_LAND));
29  };
30  break;
31  }
33  return [algo, algo_weak, comm](Vector<T> const &, Vector<T> const &residual) {
34  auto const algo = algo_weak.lock();
35  auto const residual_norm =
36  sopt::mpi::l2_norm(residual, algo->l2ball_proximal_weights(), comm);
37  SOPT_LOW_LOG(" - [Algorithm] Residuals: {} <? {}", residual_norm,
38  algo->residual_tolerance());
39  return static_cast<bool>(residual_norm < algo->residual_tolerance());
40  };
41  break;
42  }
43  default:
44  throw std::runtime_error("Unknown type of distributed MPI convergence algorithm.");
45  }
46 }
47 
48 template <class T, class Algo>
49 std::function<bool(Vector<T> const &, Vector<T> const &)> l1_convergence_factory(
50  const ConvergenceType algo, std::weak_ptr<Algo> const algo_weak) {
51  auto const comm = sopt::mpi::Communicator::World();
52  auto const algo_temp = algo_weak.lock();
53  const std::shared_ptr<sopt::ScalarRelativeVariation<T>> conv =
54  std::make_shared<sopt::ScalarRelativeVariation<T>>(algo_temp->relative_variation(), 0,
55  "Objective function");
56  switch (algo) {
58  return [algo_weak, comm, conv](Vector<T> const &x, Vector<T> const &) {
59  auto const algo = algo_weak.lock();
60  return comm.all_reduce<uint8_t>((*conv)(sopt::l1_norm(algo->Psi().adjoint() * x)), MPI_LAND);
61  };
62  break;
63  }
65  return [algo_weak, comm, conv](Vector<T> const &x, Vector<T> const &) {
66  auto const algo = algo_weak.lock();
67  return (*conv)(sopt::mpi::l1_norm(algo->Psi().adjoint() * x, comm));
68  };
69  break;
70  }
71  default:
72  throw std::runtime_error("Unknown type of distributed MPI convergence algorithm.");
73  }
74 }
75 #endif
76 } // namespace factory
77 } // namespace purify
78 #endif