1 #ifndef UPDATE_FACTORY_H 
    2 #define UPDATE_FACTORY_H 
    3 #include "purify/config.h" 
    8 #include <sopt/mpi/communicator.h> 
   19 template <
class T, 
class Algo>
 
   20 void add_updater(std::weak_ptr<Algo> 
const algo_weak, 
const t_real step_size_scale,
 
   21                  const t_real update_tol, 
const t_uint update_iters,
 
   24                  const t_uint imsizex, 
const t_uint sara_size, 
const bool using_mpi,
 
   25                  const t_real beam_units = 1) {
 
   27     throw std::runtime_error(
"Step size update tolerance must be greater than zero.");
 
   28   if (step_size_scale < 0)
 
   29     throw std::runtime_error(
"Step size update scale must be greater than zero.");
 
   32       std::make_shared<CDisplay>(cimg::make_display(Vector<t_real>::Zero(1024 * 512), 1024, 512));
 
   34   const std::shared_ptr<pfitsio::header_params> update_header_sol =
 
   35       std::make_shared<pfitsio::header_params>(update_solution_header);
 
   36   const std::shared_ptr<pfitsio::header_params> update_header_res =
 
   37       std::make_shared<pfitsio::header_params>(update_residual_header);
 
   40     const auto comm = sopt::mpi::Communicator::World();
 
   41     const std::shared_ptr<t_int> iter = std::make_shared<t_int>(0);
 
   42     const auto updater = [update_tol, update_iters, imsizex, imsizey, algo_weak, iter,
 
   43                           step_size_scale, update_header_sol, update_header_res, sara_size, comm,
 
   44                           beam_units](
const Vector<T> &x, 
const Vector<T> &res) -> 
bool {
 
   45       auto algo = algo_weak.lock();
 
   46       if (comm.is_root()) 
PURIFY_MEDIUM_LOG(
"Step size γ {}", algo->regulariser_strength());
 
   47       if (algo->regulariser_strength() > 0) {
 
   48         Vector<t_complex> 
const alpha = algo->Psi().adjoint() * x;
 
   49         const t_real new_gamma =
 
   50             comm.all_reduce((sara_size > 0) ? alpha.real().cwiseAbs().maxCoeff() : 0., MPI_MAX) *
 
   54         algo->regulariser_strength(
 
   55             ((std::abs(algo->regulariser_strength() - new_gamma) > update_tol) and
 
   58                 : algo->regulariser_strength());
 
   60       Vector<t_complex> 
const residual = algo->Phi().adjoint() * (res / beam_units).eval();
 
   62                         residual.norm() / std::sqrt(residual.size()));
 
   64         update_header_sol->niters = *iter;
 
   65         update_header_res->niters = *iter;
 
   66         pfitsio::write2d(Image<T>::Map(x.data(), imsizey, imsizex).real(), *update_header_sol,
 
   69                          *update_header_res, 
true);
 
   74     auto algo = algo_weak.lock();
 
   75     algo->is_converged(updater);
 
   78     throw std::runtime_error(
 
   79         "Trying to use algorithm step size update with MPI, but you did not compile with MPI.");
 
   82     const std::shared_ptr<t_int> iter = std::make_shared<t_int>(0);
 
   83     const auto updater = [update_tol, update_iters, imsizex, imsizey, algo_weak, iter,
 
   84                           step_size_scale, update_header_sol, update_header_res, beam_units
 
   89     ](
const Vector<T> &x, 
const Vector<T> &res) -> 
bool {
 
   90       auto algo = algo_weak.lock();
 
   91       if (algo->regulariser_strength() > 0) {
 
   93         Vector<T> 
const alpha = algo->Psi().adjoint() * x;
 
   94         const t_real new_gamma = alpha.real().cwiseAbs().maxCoeff() * step_size_scale;
 
   97         algo->regulariser_strength(((std::abs((algo->regulariser_strength() - new_gamma) /
 
   98                                               algo->regulariser_strength()) > update_tol) and
 
  101                                        : algo->regulariser_strength());
 
  103       Vector<t_complex> 
const residual = algo->Phi().adjoint() * (res / beam_units).eval();
 
  105                         residual.norm() / std::sqrt(residual.size()));
 
  107       const auto img1 = cimg::make_image(x.real().eval(), imsizey, imsizex)
 
  109                             .get_resize(512, 512);
 
  110       const auto img2 = cimg::make_image(residual.real().eval(), imsizey, imsizex)
 
  112                             .get_resize(512, 512);
 
  114           CImageList<t_real>(img1.get_equalize(256, 0.05, 1.), img2.get_equalize(256, 0.1, 1.));
 
  115       canvas->display(results);
 
  116       canvas->resize(
true);
 
  118       update_header_sol->niters = *iter;
 
  119       update_header_res->niters = *iter;
 
  120       pfitsio::write2d(Image<T>::Map(x.data(), imsizey, imsizex).real(), *update_header_sol, 
true);
 
  121       pfitsio::write2d(Image<T>::Map(residual.data(), imsizey, imsizex).real(), *update_header_res,
 
  126     auto algo = algo_weak.lock();
 
  127     algo->is_converged(updater);
 
#define PURIFY_MEDIUM_LOG(...)
Medium priority message.
 
void add_updater(std::weak_ptr< Algo > const algo_weak, const t_real step_size_scale, const t_real update_tol, const t_uint update_iters, const pfitsio::header_params &update_solution_header, const pfitsio::header_params &update_residual_header, const t_uint imsizey, const t_uint imsizex, const t_uint sara_size, const bool using_mpi, const t_real beam_units=1)
 
void write2d(const Image< t_real > &eigen_image, const pfitsio::header_params &header, const bool &overwrite)
Write image to FITS file.