PURIFY
Next-generation radio interferometric imaging
update_factory.h
Go to the documentation of this file.
1 #ifndef UPDATE_FACTORY_H
2 #define UPDATE_FACTORY_H
3 #include "purify/config.h"
4 
5 #include "purify/logging.h"
6 
7 #ifdef PURIFY_MPI
8 #include <sopt/mpi/communicator.h>
9 #endif
10 
12 #include "purify/pfitsio.h"
13 
14 #include "purify/cimg.h"
15 
16 namespace purify {
17 namespace factory {
18 
19 template <class T, class Algo>
20 void add_updater(std::weak_ptr<Algo> const algo_weak, const t_real step_size_scale,
21  const t_real update_tol, const t_uint update_iters,
22  const pfitsio::header_params &update_solution_header,
23  const pfitsio::header_params &update_residual_header, const t_uint imsizey,
24  const t_uint imsizex, const t_uint sara_size, const bool using_mpi,
25  const t_real beam_units = 1) {
26  if (update_tol < 0)
27  throw std::runtime_error("Step size update tolerance must be greater than zero.");
28  if (step_size_scale < 0)
29  throw std::runtime_error("Step size update scale must be greater than zero.");
30 #ifdef PURIFY_CImg
31  auto const canvas =
32  std::make_shared<CDisplay>(cimg::make_display(Vector<t_real>::Zero(1024 * 512), 1024, 512));
33 #endif
34  const std::shared_ptr<pfitsio::header_params> update_header_sol =
35  std::make_shared<pfitsio::header_params>(update_solution_header);
36  const std::shared_ptr<pfitsio::header_params> update_header_res =
37  std::make_shared<pfitsio::header_params>(update_residual_header);
38  if (using_mpi) {
39 #ifdef PURIFY_MPI
40  const auto comm = sopt::mpi::Communicator::World();
41  const std::shared_ptr<t_int> iter = std::make_shared<t_int>(0);
42  const auto updater = [update_tol, update_iters, imsizex, imsizey, algo_weak, iter,
43  step_size_scale, update_header_sol, update_header_res, sara_size, comm,
44  beam_units](const Vector<T> &x, const Vector<T> &res) -> bool {
45  auto algo = algo_weak.lock();
46  if (comm.is_root()) PURIFY_MEDIUM_LOG("Step size γ {}", algo->regulariser_strength());
47  if (algo->regulariser_strength() > 0) {
48  Vector<t_complex> const alpha = algo->Psi().adjoint() * x;
49  const t_real new_gamma =
50  comm.all_reduce((sara_size > 0) ? alpha.real().cwiseAbs().maxCoeff() : 0., MPI_MAX) *
51  step_size_scale;
52  if (comm.is_root()) PURIFY_MEDIUM_LOG("Step size γ update {}", new_gamma);
53  // updating parameter
54  algo->regulariser_strength(
55  ((std::abs(algo->regulariser_strength() - new_gamma) > update_tol) and
56  *iter < update_iters)
57  ? new_gamma
58  : algo->regulariser_strength());
59  }
60  Vector<t_complex> const residual = algo->Phi().adjoint() * (res / beam_units).eval();
61  PURIFY_MEDIUM_LOG("RMS of residual map in Jy/beam {}",
62  residual.norm() / std::sqrt(residual.size()));
63  if (comm.is_root()) {
64  update_header_sol->niters = *iter;
65  update_header_res->niters = *iter;
66  pfitsio::write2d(Image<T>::Map(x.data(), imsizey, imsizex).real(), *update_header_sol,
67  true);
68  pfitsio::write2d(Image<T>::Map(residual.data(), imsizey, imsizex).real(),
69  *update_header_res, true);
70  }
71  *iter = *iter + 1;
72  return true;
73  };
74  auto algo = algo_weak.lock();
75  algo->is_converged(updater);
76 
77 #else
78  throw std::runtime_error(
79  "Trying to use algorithm step size update with MPI, but you did not compile with MPI.");
80 #endif
81  } else {
82  const std::shared_ptr<t_int> iter = std::make_shared<t_int>(0);
83  const auto updater = [update_tol, update_iters, imsizex, imsizey, algo_weak, iter,
84  step_size_scale, update_header_sol, update_header_res, beam_units
85 #ifdef PURIFY_CImg
86  ,
87  canvas
88 #endif
89  ](const Vector<T> &x, const Vector<T> &res) -> bool {
90  auto algo = algo_weak.lock();
91  if (algo->regulariser_strength() > 0) {
92  PURIFY_MEDIUM_LOG("Step size γ {}", algo->regulariser_strength());
93  Vector<T> const alpha = algo->Psi().adjoint() * x;
94  const t_real new_gamma = alpha.real().cwiseAbs().maxCoeff() * step_size_scale;
95  PURIFY_MEDIUM_LOG("Step size γ update {}", new_gamma);
96  // updating parameter
97  algo->regulariser_strength(((std::abs((algo->regulariser_strength() - new_gamma) /
98  algo->regulariser_strength()) > update_tol) and
99  *iter < update_iters)
100  ? new_gamma
101  : algo->regulariser_strength());
102  }
103  Vector<t_complex> const residual = algo->Phi().adjoint() * (res / beam_units).eval();
104  PURIFY_MEDIUM_LOG("RMS of residual map in Jy/beam {}",
105  residual.norm() / std::sqrt(residual.size()));
106 #ifdef PURIFY_CImg
107  const auto img1 = cimg::make_image(x.real().eval(), imsizey, imsizex)
108  .get_normalize(0, 1)
109  .get_resize(512, 512);
110  const auto img2 = cimg::make_image(residual.real().eval(), imsizey, imsizex)
111  .get_normalize(0, 1)
112  .get_resize(512, 512);
113  const auto results =
114  CImageList<t_real>(img1.get_equalize(256, 0.05, 1.), img2.get_equalize(256, 0.1, 1.));
115  canvas->display(results);
116  canvas->resize(true);
117 #endif
118  update_header_sol->niters = *iter;
119  update_header_res->niters = *iter;
120  pfitsio::write2d(Image<T>::Map(x.data(), imsizey, imsizex).real(), *update_header_sol, true);
121  pfitsio::write2d(Image<T>::Map(residual.data(), imsizey, imsizex).real(), *update_header_res,
122  true);
123  *iter = *iter + 1;
124  return true;
125  };
126  auto algo = algo_weak.lock();
127  algo->is_converged(updater);
128  }
129 }
130 } // namespace factory
131 } // namespace purify
132 #endif
#define PURIFY_MEDIUM_LOG(...)
Medium priority message.
Definition: logging.h:205
void add_updater(std::weak_ptr< Algo > const algo_weak, const t_real step_size_scale, const t_real update_tol, const t_uint update_iters, const pfitsio::header_params &update_solution_header, const pfitsio::header_params &update_residual_header, const t_uint imsizey, const t_uint imsizex, const t_uint sara_size, const bool using_mpi, const t_real beam_units=1)
void write2d(const Image< t_real > &eigen_image, const pfitsio::header_params &header, const bool &overwrite)
Write image to fits file.
Definition: pfitsio.cc:30