5 #include <boost/filesystem.hpp> 
    6 #include <boost/math/special_functions/erf.hpp> 
    7 #include "purify/directories.h" 
   16 #include <sopt/imaging_padmm.h> 
   17 #include <sopt/mpi/communicator.h> 
   18 #include <sopt/mpi/session.h> 
   19 #include <sopt/power_method.h> 
   20 #include <sopt/relative_variation.h> 
   21 #include <sopt/utilities.h> 
   22 #include <sopt/wavelets.h> 
   23 #include <sopt/wavelets/sara.h> 
   29 #ifndef PURIFY_PADMM_ALGORITHM 
   30 #define PURIFY_PADMM_ALGORITHM 2 
   40                                          sopt::mpi::Communicator 
const &comm) {
 
   51 std::shared_ptr<sopt::algorithm::ImagingProximalADMM<t_complex>> 
padmm_factory(
 
   52     std::shared_ptr<sopt::LinearTransform<Vector<t_complex>> 
const> 
const &measurements,
 
   54     const sopt::mpi::Communicator &comm, 
const t_uint &imsizex, 
const t_uint &imsizey) {
 
   55   auto const Psi = sopt::linear_transform<t_complex>(sara, imsizey, imsizex, comm);
 
   57 #if PURIFY_PADMM_ALGORITHM == 2 
   58   auto const epsilon = 3 * std::sqrt(comm.all_sum_all(std::pow(sigma, 2))) *
 
   59                        std::sqrt(2 * comm.all_sum_all(uv_data.
size()));
 
   60 #elif PURIFY_PADMM_ALGORITHM == 3 || PURIFY_PADMM_ALGORITHM == 1 
   61   auto const epsilon = 3 * std::sqrt(2 * uv_data.
size()) * sigma;
 
   63   const t_real regulariser_strength =
 
   65                            std::make_shared<sopt::LinearTransform<Vector<t_complex>> 
const>(Psi),
 
   73   auto padmm = std::make_shared<sopt::algorithm::ImagingProximalADMM<t_complex>>(uv_data.
vis);
 
   75       .regulariser_strength(comm.all_reduce<t_real>(regulariser_strength, MPI_MAX))
 
   76       .relative_variation(1e-3)
 
   77       .l2ball_proximal_epsilon(epsilon)
 
   78 #if PURIFY_PADMM_ALGORITHM == 2 
   80       .l2ball_proximal_communicator(comm)
 
   83       .l1_proximal_adjoint_space_comm(comm)
 
   85       .l1_proximal_tolerance(1e-2)
 
   87       .l1_proximal_itermax(50)
 
   88       .l1_proximal_positivity_constraint(
true)
 
   89       .l1_proximal_real_constraint(
true)
 
   90       .residual_tolerance(epsilon)
 
   91       .lagrange_update_scale(0.9)
 
   94   sopt::ScalarRelativeVariation<t_complex> conv(
padmm->relative_variation(),
 
   95                                                 padmm->relative_variation(), 
"Objective function");
 
   96   std::weak_ptr<decltype(
padmm)::element_type> 
const padmm_weak(
padmm);
 
   97   padmm->residual_convergence([padmm_weak, conv, comm](
 
   98                                   Vector<t_complex> 
const &x,
 
   99                                   Vector<t_complex> 
const &residual) 
mutable -> 
bool {
 
  100     auto const padmm = padmm_weak.lock();
 
  101 #if PURIFY_PADMM_ALGORITHM == 2 
  102     auto const residual_norm = sopt::mpi::l2_norm(residual, 
padmm->l2ball_proximal_weights(), comm);
 
  103     auto const result = residual_norm < 
padmm->residual_tolerance();
 
  104 #elif PURIFY_PADMM_ALGORITHM == 3 || PURIFY_PADMM_ALGORITHM == 1 
  105     auto const residual_norm = sopt::l2_norm(residual, 
padmm->l2ball_proximal_weights());
 
  107         comm.all_reduce<int8_t>(residual_norm < 
padmm->residual_tolerance(), MPI_LAND);
 
  109     SOPT_LOW_LOG(
"    - [PADMM] Residuals: {} <? {}", residual_norm, 
padmm->residual_tolerance());
 
  113   padmm->objective_convergence([padmm_weak, conv, comm](Vector<t_complex> 
const &x,
 
  114                                                         Vector<t_complex> 
const &) 
mutable -> 
bool {
 
  115     auto const padmm = padmm_weak.lock();
 
  116 #if PURIFY_PADMM_ALGORITHM == 2 
  117     return conv(sopt::mpi::l1_norm(
padmm->Psi().adjoint() * x, 
padmm->l1_proximal_weights(), comm));
 
  118 #elif PURIFY_PADMM_ALGORITHM == 3 || PURIFY_PADMM_ALGORITHM == 1 
  119     return comm.all_reduce<uint8_t>(
 
  120         conv(sopt::l1_norm(
padmm->Psi().adjoint() * x, 
padmm->l1_proximal_weights())), MPI_LAND);
 
  124   auto convergence_function = [](
const Vector<t_complex> &x) { 
return true; };
 
  125   const std::shared_ptr<t_uint> iter = std::make_shared<t_uint>(0);
 
  126   const auto algo_update = [uv_data, imsizex, imsizey, padmm_weak, iter,
 
  127                             comm](
const Vector<t_complex> &x) -> 
bool {
 
  128     auto padmm = padmm_weak.lock();
 
  131     Vector<t_complex> 
const alpha = 
padmm->Psi().adjoint() * x;
 
  132     const t_real new_regulariser_strength =
 
  133         comm.all_reduce(alpha.real().cwiseAbs().maxCoeff(), MPI_MAX) * 1e-3;
 
  134     if (comm.is_root()) 
PURIFY_MEDIUM_LOG(
"Step size γ update {}", new_regulariser_strength);
 
  135     padmm->regulariser_strength(
 
  136         ((std::abs(
padmm->regulariser_strength() - new_regulariser_strength) > 0.2) and *iter < 200)
 
  137             ? new_regulariser_strength
 
  138             : 
padmm->regulariser_strength());
 
  141     Vector<t_complex> 
const residual = 
padmm->Phi().adjoint() * (uv_data.
vis - 
padmm->Phi() * x);
 
  143     if (comm.is_root()) {
 
  149   auto lambda = [convergence_function, algo_update](Vector<t_complex> 
const &x) {
 
  150     return convergence_function(x) and algo_update(x);
 
  152   padmm->is_converged(lambda);
 
  156 int main(
int nargs, 
char const **args) {
 
  159   auto const session = sopt::mpi::init(nargs, args);
 
  160   auto const world = sopt::mpi::Communicator::World();
 
  162   const std::string name = 
"realdata";
 
  163   const std::string filename_base = 
vla_filename(
"../mwa/uvdump_");
 
  164   const std::vector<std::string> filenames = {filename_base +
 
  167   std::string kernel_name = 
"kb";
 
  168   const bool w_term = 
false;
 
  170   const t_real cellsize = 20;  
 
  171   const t_uint imsizex = 1024;
 
  172   const t_uint imsizey = 1024;
 
  178       data.
weights.norm() / std::sqrt(world.all_sum_all(data.
weights.size())) * 0.5;
 
  180              world.all_reduce(data.
weights.array().cwiseAbs().maxCoeff(), MPI_MAX);
 
  181 #if PURIFY_PADMM_ALGORITHM == 2 || PURIFY_PADMM_ALGORITHM == 3 
  183   auto const measurements = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
 
  185           world, data, imsizey, imsizex, cellsize, cellsize, 2, 
kernel, 4, 4, w_term),
 
  186       100, 1e-4, world.broadcast(Vector<t_complex>::Random(imsizex * imsizey).eval())));
 
  189   auto const measurements = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
 
  191                                                         cellsize, 2, 
kernel, 4, 4, w_term),
 
  192       100, 1e-4, world.broadcast(Vector<t_complex>::Random(imsizex * imsizey).eval())));
 
  195 #elif PURIFY_PADMM_ALGORITHM == 1 
  197   auto const measurements = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
 
  198       measurementoperator::init_degrid_operator_2d_mpi<Vector<t_complex>>(
 
  199           world, data, imsizey, imsizex, cellsize, cellsize, 2, 
kernel, 4, 4, w_term),
 
  200       100, 1e-4, world.broadcast(Vector<t_complex>::Random(imsizex * imsizey).eval())));
 
  204   auto const measurements = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
 
  205       gpu::measurementoperator::init_degrid_operator_2d_mpi(world, data, imsizey, imsizex, cellsize,
 
  206                                                             cellsize, 2, 
kernel, 4, 4, w_term),
 
  207       100, 1e-4, world.broadcast(Vector<t_complex>::Random(imsizex * imsizey).eval())));
 
  211   auto const sara = sopt::wavelets::distribute_sara(
 
  212       sopt::wavelets::SARA{
 
  213           std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
 
  214           std::make_tuple(
"DB3", 3u), std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
 
  215           std::make_tuple(
"DB6", 3u), std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)},
 
  218   Vector<t_real> 
const dirty_image = (measurements->adjoint() * (data.
vis)).real();
 
  220   if (world.is_root()) {
 
  223 #if PURIFY_PADMM_ALGORITHM == 3 
  224     auto const pb_path = path / kernel_name / 
"local_epsilon_replicated_grids";
 
  225 #elif PURIFY_PADMM_ALGORITHM == 2 
  226     auto const pb_path = path / kernel_name / 
"global_epsilon_replicated_grids";
 
  227 #elif PURIFY_PADMM_ALGORITHM == 1 
  228     auto const pb_path = path / kernel_name / 
"local_epsilon_distributed_grids";
 
  230 #error Unknown or unimplemented algorithm 
  234     pfitsio::write2d(dirty_image, imsizey, imsizex, (pb_path / 
"dirty.fits").native());
 
  238   auto const padmm = 
padmm_factory(measurements, sigma, sara, data, world, imsizey, imsizex);
 
  240   auto const diagnostic = (*padmm)();
 
  243   assert(world.broadcast(diagnostic.x).isApprox(diagnostic.x));
 
  245   Vector<t_real> 
const residual_image = (measurements->adjoint() * diagnostic.residual).real();
 
  246   if (world.is_root()) {
 
  249 #if PURIFY_PADMM_ALGORITHM == 3 
  250     auto const pb_path = path / kernel_name / 
"local_epsilon_replicated_grids";
 
  251 #elif PURIFY_PADMM_ALGORITHM == 2 
  252     auto const pb_path = path / kernel_name / 
"global_epsilon_replicated_grids";
 
  253 #elif PURIFY_PADMM_ALGORITHM == 1 
  254     auto const pb_path = path / kernel_name / 
"local_epsilon_distributed_grids";
 
  256 #error Unknown or unimplemented algorithm 
  260     pfitsio::write2d(dirty_image, imsizey, imsizex, (pb_path / 
"dirty.fits").native());
 
  261     pfitsio::write2d(diagnostic.x.real(), imsizey, imsizex, (pb_path / 
"solution.fits").native());
 
  262     pfitsio::write2d(residual_image, imsizey, imsizex, (pb_path / 
"residual.fits").native());
 
#define PURIFY_MEDIUM_LOG(...)
Medium priority message.
 
std::vector< t_int > distribute_measurements(Vector< t_real > const &u, Vector< t_real > const &v, Vector< t_real > const &w, t_int const number_of_nodes, distribute::plan const distribution_plan, t_int const &grid_size)
Distribute visiblities into groups.
 
void set_level(const std::string &level)
Method to set the logging level of the default Log object.
 
std::shared_ptr< sopt::LinearTransform< T > > init_degrid_operator_2d(const Vector< t_real > &u, const Vector< t_real > &v, const Vector< t_real > &w, const Vector< t_complex > &weights, const t_uint &imsizey, const t_uint &imsizex, const t_real &oversample_ratio=2, const kernels::kernel kernel=kernels::kernel::kb, const t_uint Ju=4, const t_uint Jv=4, const bool w_stacking=false, const t_real &cellx=1, const t_real &celly=1)
Returns linear transform that is the standard degridding operator.
 
void write2d(const Image< t_real > &eigen_image, const pfitsio::header_params &header, const bool &overwrite)
Write image to fits file.
 
t_real step_size(T const &vis, const std::shared_ptr< sopt::LinearTransform< T > const > &measurements, const std::shared_ptr< sopt::LinearTransform< T > const > &wavelets, const t_uint sara_size)
Calculate step size using MPI (does not include factor of 1e-3)
 
vis_params scatter_visibilities(vis_params const ¶ms, std::vector< t_int > const &sizes, sopt::mpi::Communicator const &comm)
 
vis_params regroup_and_scatter(vis_params const ¶ms, std::vector< t_int > const &groups, sopt::mpi::Communicator const &comm)
 
utilities::vis_params read_visibility(const std::vector< std::string > &names, const bool w_term)
Read visibility files from name of vector.
 
void mkdir_recursive(const std::string &path)
recursively create directories when they do not exist
 
std::string output_filename(std::string const &filename)
Test output file.
 
std::string vla_filename(std::string const &filename)
Specific vla data.
 
std::shared_ptr< sopt::algorithm::ImagingProximalADMM< t_complex > > padmm_factory(std::shared_ptr< sopt::LinearTransform< Vector< t_complex >> const > const &measurements, t_real const sigma, const sopt::wavelets::SARA &sara, const utilities::vis_params &uv_data, const sopt::mpi::Communicator &comm, const t_uint &imsizex, const t_uint &imsizey)
 
int main(int nargs, char const **args)
 
utilities::vis_params dirty_visibilities(const std::vector< std::string > &names)
 
void padmm(const std::string &name, const Image< t_complex > &M31, const std::string &kernel, const t_int J, const utilities::vis_params &uv_data, const t_real sigma, const std::tuple< bool, t_real > &w_term)
 
t_uint size() const
return number of measurements
 
Vector< t_complex > weights