1 #ifndef UPDATE_FACTORY_H
2 #define UPDATE_FACTORY_H
3 #include "purify/config.h"
8 #include <sopt/mpi/communicator.h>
/// @brief Registers a convergence callback ("updater") on the algorithm referenced by
/// `algo_weak`. The callback may replace the algorithm's regulariser strength with a value
/// derived from the current solution `x`. It also writes the current solution and residual
/// images to FITS files.
///
/// @tparam T    Scalar type of the solution/residual vectors passed to the callback.
/// @tparam Algo Algorithm type. As used below, it must provide regulariser_strength() (getter
///              and setter), Psi(), Phi() and is_converged(callback).
/// @param algo_weak       Non-owning handle to the algorithm; lock()ed every time it is used.
/// @param step_size_scale Scale applied to max |Re(Psi^dagger x)| to form the candidate
///                        regulariser strength. This is explicit in the serial branch; in the
///                        MPI branch the multiplication is on an elided line (presumably the
///                        same — confirm).
/// @param update_tol      Threshold a candidate strength must exceed to be accepted. The MPI
///                        branch compares it with an absolute change, the serial branch with a
///                        relative change.
/// @param update_iters    Captured by both callbacks. NOTE(review): its use is on lines not
///                        visible here; presumably it bounds how many iterations perform the
///                        update — confirm.
/// @param imsizex         Image width used to reshape x / residual into 2-D images.
/// @param sara_size       Captured by the MPI callback. A rank with sara_size == 0 contributes
///                        0 to the MPI_MAX reduction.
/// @param using_mpi       Requests the MPI callback. Throws if MPI support was not compiled in.
/// @param beam_units      Divisor applied to `res` before Phi().adjoint() is applied.
/// The full signature also takes update_solution_header, update_residual_header and imsizey.
/// Those parameters are declared on lines not visible in this excerpt.
/// @throws std::runtime_error for an invalid update tolerance or step-size scale, or when an
///         MPI update is requested without MPI support.
19 template <
class T,
class Algo>
20 void add_updater(std::weak_ptr<Algo>
const algo_weak,
const t_real step_size_scale,
21 const t_real update_tol,
const t_uint update_iters,
24 const t_uint imsizex,
const t_uint sara_size,
const bool using_mpi,
25 const t_real beam_units = 1) {
// Reject an invalid update tolerance (the guarding condition is on an elided line).
27 throw std::runtime_error(
"Step size update tolerance must be greater than zero.");
// NOTE(review): `< 0` lets step_size_scale == 0 through, although the message says
// "greater than zero". Confirm whether `<= 0` was intended.
28 if (step_size_scale < 0)
29 throw std::runtime_error(
"Step size update scale must be greater than zero.");
// 1024x512 display for live images. The declaration head is on an elided line; presumably
// this is the `canvas` used by the serial callback below.
32 std::make_shared<CDisplay>(cimg::make_display(Vector<t_real>::Zero(1024 * 512), 1024, 512));
// Shared heap copies of the FITS headers. The callbacks capture these shared_ptrs by value
// and update `niters` on each call.
34 const std::shared_ptr<pfitsio::header_params> update_header_sol =
35 std::make_shared<pfitsio::header_params>(update_solution_header);
36 const std::shared_ptr<pfitsio::header_params> update_header_res =
37 std::make_shared<pfitsio::header_params>(update_residual_header);
// MPI variant (the enclosing branch condition is on an elided line).
40 const auto comm = sopt::mpi::Communicator::World();
// Iteration counter shared by every invocation of the callback (captured as a shared_ptr).
41 const std::shared_ptr<t_int> iter = std::make_shared<t_int>(0);
42 const auto updater = [update_tol, update_iters, imsizex, imsizey, algo_weak, iter,
43 step_size_scale, update_header_sol, update_header_res, sara_size, comm,
44 beam_units](
const Vector<T> &x,
const Vector<T> &res) ->
bool {
// NOTE(review): the result of lock() is not checked. If the algorithm has already been
// destroyed, `algo` is empty and the dereferences below are undefined behaviour.
45 auto algo = algo_weak.lock();
46 if (comm.is_root())
PURIFY_MEDIUM_LOG(
"Step size γ {}", algo->regulariser_strength());
47 if (algo->regulariser_strength() > 0) {
// Coefficients of x under Psi's adjoint (presumably the sparsifying dictionary).
48 Vector<t_complex>
const alpha = algo->Psi().adjoint() * x;
// Candidate strength: the largest |Re(alpha)| across all ranks (MPI_MAX reduction).
// The scaling continues on an elided line (presumably `* step_size_scale` — confirm).
49 const t_real new_gamma =
50 comm.all_reduce((sara_size > 0) ? alpha.real().cwiseAbs().maxCoeff() : 0., MPI_MAX) *
// Set the strength to the candidate only if it differs from the current value by more
// than update_tol. This is an ABSOLUTE difference here, whereas the serial branch uses a
// relative one. Further terms of the condition are on elided lines.
54 algo->regulariser_strength(
55 ((std::abs(algo->regulariser_strength() - new_gamma) > update_tol) and
58 : algo->regulariser_strength());
// Residual mapped back through Phi's adjoint after dividing by beam_units. Its RMS
// (norm / sqrt(size)) is passed to a call whose head is on an elided line.
60 Vector<t_complex>
const residual = algo->Phi().adjoint() * (res / beam_units).eval();
62 residual.norm() / std::sqrt(residual.size()));
// Stamp the iteration count into both headers, then write the solution (and, on elided
// lines, the residual) as imsizey x imsizex images. `true` is write2d's overwrite flag.
64 update_header_sol->niters = *iter;
65 update_header_res->niters = *iter;
66 pfitsio::write2d(Image<T>::Map(x.data(), imsizey, imsizex).real(), *update_header_sol,
69 *update_header_res,
true);
// Install the MPI callback. NOTE(review): lock() is not null-checked here either.
74 auto algo = algo_weak.lock();
75 algo->is_converged(updater);
// Reached when an MPI update is requested but MPI support was not compiled in
// (the preprocessor guard is on an elided line).
78 throw std::runtime_error(
79 "Trying to use algorithm step size update with MPI, but you did not compile with MPI.");
// Serial (non-MPI) variant of the callback.
82 const std::shared_ptr<t_int> iter = std::make_shared<t_int>(0);
83 const auto updater = [update_tol, update_iters, imsizex, imsizey, algo_weak, iter,
84 step_size_scale, update_header_sol, update_header_res, beam_units
89 ](
const Vector<T> &x,
const Vector<T> &res) ->
bool {
// NOTE(review): lock() result is dereferenced without a null check.
90 auto algo = algo_weak.lock();
91 if (algo->regulariser_strength() > 0) {
// NOTE(review): alpha is Vector<T> here but Vector<t_complex> in the MPI branch.
// Confirm that the difference is intentional.
93 Vector<T>
const alpha = algo->Psi().adjoint() * x;
94 const t_real new_gamma = alpha.real().cwiseAbs().maxCoeff() * step_size_scale;
// RELATIVE-change test. Dividing by regulariser_strength() is safe because the enclosing
// branch requires it to be > 0.
97 algo->regulariser_strength(((std::abs((algo->regulariser_strength() - new_gamma) /
98 algo->regulariser_strength()) > update_tol) and
101 : algo->regulariser_strength());
// Residual mapped back through Phi's adjoint. Its RMS is passed to a call on an elided line.
103 Vector<t_complex>
const residual = algo->Phi().adjoint() * (res / beam_units).eval();
105 residual.norm() / std::sqrt(residual.size()));
// Live display: the solution and residual images, each resized to 512x512, are
// histogram-equalized and shown together on the canvas.
107 const auto img1 = cimg::make_image(x.real().eval(), imsizey, imsizex)
109 .get_resize(512, 512);
110 const auto img2 = cimg::make_image(residual.real().eval(), imsizey, imsizex)
112 .get_resize(512, 512);
114 CImageList<t_real>(img1.get_equalize(256, 0.05, 1.), img2.get_equalize(256, 0.1, 1.));
115 canvas->display(results);
116 canvas->resize(
true);
// Stamp the iteration count into both headers and write solution and residual to FITS
// (overwrite = true).
118 update_header_sol->niters = *iter;
119 update_header_res->niters = *iter;
120 pfitsio::write2d(Image<T>::Map(x.data(), imsizey, imsizex).real(), *update_header_sol,
true);
121 pfitsio::write2d(Image<T>::Map(residual.data(), imsizey, imsizex).real(), *update_header_res,
// Install the serial callback. NOTE(review): lock() is not null-checked here either.
126 auto algo = algo_weak.lock();
127 algo->is_converged(updater);
#define PURIFY_MEDIUM_LOG(...)
Logs a medium-priority message.
void add_updater(std::weak_ptr< Algo > const algo_weak, const t_real step_size_scale, const t_real update_tol, const t_uint update_iters, const pfitsio::header_params &update_solution_header, const pfitsio::header_params &update_residual_header, const t_uint imsizey, const t_uint imsizex, const t_uint sara_size, const bool using_mpi, const t_real beam_units=1)
void write2d(const Image< t_real > &eigen_image, const pfitsio::header_params &header, const bool &overwrite)
Writes an image to a FITS file.