1 #ifndef CONVERGENCE_FACTORY_H
2 #define CONVERGENCE_FACTORY_H
4 #include <sopt/imaging_padmm.h>
6 #include <sopt/mpi/communicator.h>
8 #include <sopt/relative_variation.h>
//! Builds a residual-based (l2) convergence criterion for an algorithm observed through a weak pointer.
//!
//! The returned callable has signature bool(Vector<T> const &x, Vector<T> const &residual).
//! It ignores `x` and compares the l2 norm of `residual`, weighted by
//! `algo->l2ball_proximal_weights()`, against `algo->residual_tolerance()`.
//! Two variants are returned from separate branches:
//!  - the norm is computed locally on each rank. The per-rank outcomes are then combined
//!    with an MPI all-reduce using MPI_LAND, so convergence is reported only if every rank passes.
//!  - the norm is computed with sopt::mpi::l2_norm over `comm` and compared directly.
//! @throws std::runtime_error for an unrecognized distributed-convergence selection.
16 template <
class T,
class Algo>
17 std::function<bool(Vector<T>
const &, Vector<T>
const &)> l2_convergence_factory(
// World communicator; captured by value in each returned closure.
19 const auto comm = sopt::mpi::Communicator::World();
// Local-norm variant: each rank checks its own residual, then the results are AND-reduced.
// NOTE(review): `algo` is captured here, but the body immediately declares its own `algo`
// from `algo_weak.lock()`, which shadows the capture. If the captured `algo` is a
// shared_ptr to the same algorithm, storing it in the closure keeps the algorithm alive
// and defeats the weak_ptr. Confirm whether this capture is intended.
22 return [algo, algo_weak, comm](Vector<T>
const &, Vector<T>
const &residual) {
// NOTE(review): the lock() result is dereferenced without a null check. If the algorithm
// has already been destroyed, this is undefined behavior; consider guarding.
23 auto const algo = algo_weak.lock();
24 auto const residual_norm = sopt::l2_norm(residual, algo->l2ball_proximal_weights());
25 SOPT_LOW_LOG(
" - [Algorithm] Residuals: {} <? {}", residual_norm,
26 algo->residual_tolerance());
// True only if every rank's local residual norm is below the tolerance (MPI_LAND).
// NOTE(review): this reduces as int8_t, whereas l1_convergence_factory uses uint8_t.
// Both are equivalent for 0/1 values, but the choice is inconsistent.
27 return static_cast<bool>(
28 comm.all_reduce<int8_t>(residual_norm < algo->residual_tolerance(), MPI_LAND));
// Distributed-norm variant: the norm is computed by sopt::mpi::l2_norm over `comm`
// (presumably the same global value on every rank; confirm). No further reduction is applied.
// NOTE(review): same shadowed `algo` capture as the variant above.
33 return [algo, algo_weak, comm](Vector<T>
const &, Vector<T>
const &residual) {
// NOTE(review): the lock() result is not null-checked before use.
34 auto const algo = algo_weak.lock();
35 auto const residual_norm =
36 sopt::mpi::l2_norm(residual, algo->l2ball_proximal_weights(), comm);
37 SOPT_LOW_LOG(
" - [Algorithm] Residuals: {} <? {}", residual_norm,
38 algo->residual_tolerance());
39 return static_cast<bool>(residual_norm < algo->residual_tolerance());
// Fallback for a distributed-convergence selection that neither variant handles.
44 throw std::runtime_error(
"Unknown type of distributed MPI convergence algorithm.");
//! Builds an objective-function (l1) convergence criterion for an algorithm observed through a weak pointer.
//!
//! The returned callable has signature bool(Vector<T> const &x, Vector<T> const &residual).
//! It ignores `residual` and takes the l1 norm of `algo->Psi().adjoint() * x`.
//! That value is passed to a sopt::ScalarRelativeVariation<T> check shared by the closure.
//! Two variants are returned from separate branches:
//!  - the l1 norm is computed locally and checked on each rank. The per-rank outcomes are
//!    combined with an MPI all-reduce using MPI_LAND.
//!  - the l1 norm is computed with sopt::mpi::l1_norm over `comm`, then checked.
//! @throws std::runtime_error for an unrecognized distributed-convergence selection.
48 template <
class T,
class Algo>
49 std::function<bool(Vector<T>
const &, Vector<T>
const &)> l1_convergence_factory(
// World communicator; captured by value in each returned closure.
51 auto const comm = sopt::mpi::Communicator::World();
// Temporary lock, used only to read relative_variation() while building the check.
// NOTE(review): not null-checked; this assumes the algorithm is still alive at construction time.
52 auto const algo_temp = algo_weak.lock();
// Held by shared_ptr and captured by value, so every copy of the returned std::function
// shares one ScalarRelativeVariation instance. Presumably that instance stores the previous
// objective value between calls; confirm.
// Constructor arguments: the algorithm's relative_variation(), 0, and the label
// "Objective function". The meaning of 0 is not visible here (presumably an absolute
// tolerance; confirm).
53 const std::shared_ptr<sopt::ScalarRelativeVariation<T>> conv =
54 std::make_shared<sopt::ScalarRelativeVariation<T>>(algo_temp->relative_variation(), 0,
55 "Objective function");
// Local-norm variant: each rank evaluates conv on its own l1 norm, then the results are AND-reduced.
58 return [algo_weak, comm, conv](Vector<T>
const &x, Vector<T>
const &) {
// NOTE(review): the lock() result is dereferenced without a null check.
59 auto const algo = algo_weak.lock();
// The uint8_t all-reduce result converts implicitly to the bool return type.
60 return comm.all_reduce<uint8_t>((*conv)(sopt::l1_norm(algo->Psi().adjoint() * x)), MPI_LAND);
// Distributed-norm variant: the l1 norm is computed by sopt::mpi::l1_norm over `comm`
// before being checked by conv.
65 return [algo_weak, comm, conv](Vector<T>
const &x, Vector<T>
const &) {
// NOTE(review): the lock() result is not null-checked before use.
66 auto const algo = algo_weak.lock();
67 return (*conv)(sopt::mpi::l1_norm(algo->Psi().adjoint() * x, comm));
// Fallback for a distributed-convergence selection that neither variant handles.
72 throw std::runtime_error(
"Unknown type of distributed MPI convergence algorithm.");