PURIFY
Next-generation radio interferometric imaging
padmm_mpi_real_data.cc
Go to the documentation of this file.
1 #include "purify/types.h"
2 #include <array>
3 #include <memory>
4 #include <random>
5 #include <boost/filesystem.hpp>
6 #include <boost/math/special_functions/erf.hpp>
7 #include "purify/directories.h"
8 #include "purify/distribute.h"
9 #include "purify/logging.h"
10 #include "purify/mpi_utilities.h"
11 #include "purify/operators.h"
12 #include "purify/pfitsio.h"
14 #include "purify/utilities.h"
15 #include "purify/uvfits.h"
16 #include <sopt/imaging_padmm.h>
17 #include <sopt/mpi/communicator.h>
18 #include <sopt/mpi/session.h>
19 #include <sopt/power_method.h>
20 #include <sopt/relative_variation.h>
21 #include <sopt/utilities.h>
22 #include <sopt/wavelets.h>
23 #include <sopt/wavelets/sara.h>
24 
25 #ifdef PURIFY_GPU
26 #include "purify/operators_gpu.h"
27 #endif
28 
29 #ifndef PURIFY_PADMM_ALGORITHM
30 #define PURIFY_PADMM_ALGORITHM 2
31 #endif
32 
33 using namespace purify;
34 
35 utilities::vis_params dirty_visibilities(const std::vector<std::string> &names) {
36  return utilities::read_visibility(names, true);
37 }
38 
39 utilities::vis_params dirty_visibilities(const std::vector<std::string> &names,
40  sopt::mpi::Communicator const &comm) {
41  if (comm.size() == 1) return dirty_visibilities(names);
42  if (comm.is_root()) {
43  auto result = dirty_visibilities(names);
44  auto const order = distribute::distribute_measurements(result, comm, distribute::plan::w_term);
45  return utilities::regroup_and_scatter(result, order, comm);
46  }
47  auto result = utilities::scatter_visibilities(comm);
48  return result;
49 }
50 
51 std::shared_ptr<sopt::algorithm::ImagingProximalADMM<t_complex>> padmm_factory(
52  std::shared_ptr<sopt::LinearTransform<Vector<t_complex>> const> const &measurements,
53  t_real const sigma, const sopt::wavelets::SARA &sara, const utilities::vis_params &uv_data,
54  const sopt::mpi::Communicator &comm, const t_uint &imsizex, const t_uint &imsizey) {
55  auto const Psi = sopt::linear_transform<t_complex>(sara, imsizey, imsizex, comm);
56 
57 #if PURIFY_PADMM_ALGORITHM == 2
58  auto const epsilon = 3 * std::sqrt(comm.all_sum_all(std::pow(sigma, 2))) *
59  std::sqrt(2 * comm.all_sum_all(uv_data.size()));
60 #elif PURIFY_PADMM_ALGORITHM == 3 || PURIFY_PADMM_ALGORITHM == 1
61  auto const epsilon = 3 * std::sqrt(2 * uv_data.size()) * sigma;
62 #endif
63  const t_real regulariser_strength =
64  utilities::step_size(uv_data.vis, measurements,
65  std::make_shared<sopt::LinearTransform<Vector<t_complex>> const>(Psi),
66  sara.size()) *
67  1e-3;
68  PURIFY_MEDIUM_LOG("Epsilon {}", epsilon);
69  PURIFY_MEDIUM_LOG("Regulariser_Strength {}", regulariser_strength);
70 
71  // shared pointer because the convergence function need access to some data that we would rather
72  // not reproduce. E.g. padmm definition is self-referential.
73  auto padmm = std::make_shared<sopt::algorithm::ImagingProximalADMM<t_complex>>(uv_data.vis);
74  padmm->itermax(50)
75  .regulariser_strength(comm.all_reduce<t_real>(regulariser_strength, MPI_MAX))
76  .relative_variation(1e-3)
77  .l2ball_proximal_epsilon(epsilon)
78 #if PURIFY_PADMM_ALGORITHM == 2
79  // communicator ensuring l2 norm in l2ball proximal is global
80  .l2ball_proximal_communicator(comm)
81 #endif
82  // communicator ensuring l1 norm in l1 proximal is global
83  .l1_proximal_adjoint_space_comm(comm)
84  .tight_frame(false)
85  .l1_proximal_tolerance(1e-2)
86  .l1_proximal_nu(1)
87  .l1_proximal_itermax(50)
88  .l1_proximal_positivity_constraint(true)
89  .l1_proximal_real_constraint(true)
90  .residual_tolerance(epsilon)
91  .lagrange_update_scale(0.9)
92  .Psi(Psi)
93  .Phi(*measurements);
94  sopt::ScalarRelativeVariation<t_complex> conv(padmm->relative_variation(),
95  padmm->relative_variation(), "Objective function");
96  std::weak_ptr<decltype(padmm)::element_type> const padmm_weak(padmm);
97  padmm->residual_convergence([padmm_weak, conv, comm](
98  Vector<t_complex> const &x,
99  Vector<t_complex> const &residual) mutable -> bool {
100  auto const padmm = padmm_weak.lock();
101 #if PURIFY_PADMM_ALGORITHM == 2
102  auto const residual_norm = sopt::mpi::l2_norm(residual, padmm->l2ball_proximal_weights(), comm);
103  auto const result = residual_norm < padmm->residual_tolerance();
104 #elif PURIFY_PADMM_ALGORITHM == 3 || PURIFY_PADMM_ALGORITHM == 1
105  auto const residual_norm = sopt::l2_norm(residual, padmm->l2ball_proximal_weights());
106  auto const result =
107  comm.all_reduce<int8_t>(residual_norm < padmm->residual_tolerance(), MPI_LAND);
108 #endif
109  SOPT_LOW_LOG(" - [PADMM] Residuals: {} <? {}", residual_norm, padmm->residual_tolerance());
110  return result;
111  });
112 
113  padmm->objective_convergence([padmm_weak, conv, comm](Vector<t_complex> const &x,
114  Vector<t_complex> const &) mutable -> bool {
115  auto const padmm = padmm_weak.lock();
116 #if PURIFY_PADMM_ALGORITHM == 2
117  return conv(sopt::mpi::l1_norm(padmm->Psi().adjoint() * x, padmm->l1_proximal_weights(), comm));
118 #elif PURIFY_PADMM_ALGORITHM == 3 || PURIFY_PADMM_ALGORITHM == 1
119  return comm.all_reduce<uint8_t>(
120  conv(sopt::l1_norm(padmm->Psi().adjoint() * x, padmm->l1_proximal_weights())), MPI_LAND);
121 #endif
122  });
123 
124  auto convergence_function = [](const Vector<t_complex> &x) { return true; };
125  const std::shared_ptr<t_uint> iter = std::make_shared<t_uint>(0);
126  const auto algo_update = [uv_data, imsizex, imsizey, padmm_weak, iter,
127  comm](const Vector<t_complex> &x) -> bool {
128  auto padmm = padmm_weak.lock();
129  if (comm.is_root()) PURIFY_MEDIUM_LOG("Step size γ {}", padmm->regulariser_strength());
130  *iter = *iter + 1;
131  Vector<t_complex> const alpha = padmm->Psi().adjoint() * x;
132  const t_real new_regulariser_strength =
133  comm.all_reduce(alpha.real().cwiseAbs().maxCoeff(), MPI_MAX) * 1e-3;
134  if (comm.is_root()) PURIFY_MEDIUM_LOG("Step size γ update {}", new_regulariser_strength);
135  padmm->regulariser_strength(
136  ((std::abs(padmm->regulariser_strength() - new_regulariser_strength) > 0.2) and *iter < 200)
137  ? new_regulariser_strength
138  : padmm->regulariser_strength());
139  // updating parameter
140 
141  Vector<t_complex> const residual = padmm->Phi().adjoint() * (uv_data.vis - padmm->Phi() * x);
142 
143  if (comm.is_root()) {
144  pfitsio::write2d(x, imsizey, imsizex, "mpi_solution_update.fits");
145  pfitsio::write2d(residual, imsizey, imsizex, "mpi_residual_update.fits");
146  }
147  return true;
148  };
149  auto lambda = [convergence_function, algo_update](Vector<t_complex> const &x) {
150  return convergence_function(x) and algo_update(x);
151  };
152  padmm->is_converged(lambda);
153  return padmm;
154 }
155 
156 int main(int nargs, char const **args) {
157  sopt::logging::set_level("debug");
159  auto const session = sopt::mpi::init(nargs, args);
160  auto const world = sopt::mpi::Communicator::World();
161 
162  const std::string name = "realdata";
163  const std::string filename_base = vla_filename("../mwa/uvdump_");
164  const std::vector<std::string> filenames = {filename_base +
165  "01.vis"}; //, filename_base + "02.vis"};
166  auto const kernel = kernels::kernel::kb;
167  std::string kernel_name = "kb";
168  const bool w_term = false;
169 
170  const t_real cellsize = 20; // arcsec
171  const t_uint imsizex = 1024;
172  const t_uint imsizey = 1024;
173 
174  // Generating random uv(w) coverage
175  utilities::vis_params data = dirty_visibilities(filenames, world);
176 
177  t_real const sigma =
178  data.weights.norm() / std::sqrt(world.all_sum_all(data.weights.size())) * 0.5;
179  data.vis = (data.vis.array() * data.weights.array()) /
180  world.all_reduce(data.weights.array().cwiseAbs().maxCoeff(), MPI_MAX);
181 #if PURIFY_PADMM_ALGORITHM == 2 || PURIFY_PADMM_ALGORITHM == 3
182 #ifndef PURIFY_GPU
183  auto const measurements = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
185  world, data, imsizey, imsizex, cellsize, cellsize, 2, kernel, 4, 4, w_term),
186  100, 1e-4, world.broadcast(Vector<t_complex>::Random(imsizex * imsizey).eval())));
187 #else
188  af::setDevice(0);
189  auto const measurements = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
190  gpu::measurementoperator::init_degrid_operator_2d(world, data, imsizey, imsizex, cellsize,
191  cellsize, 2, kernel, 4, 4, w_term),
192  100, 1e-4, world.broadcast(Vector<t_complex>::Random(imsizex * imsizey).eval())));
193 
194 #endif
195 #elif PURIFY_PADMM_ALGORITHM == 1
196 #ifndef PURIFY_GPU
197  auto const measurements = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
198  measurementoperator::init_degrid_operator_2d_mpi<Vector<t_complex>>(
199  world, data, imsizey, imsizex, cellsize, cellsize, 2, kernel, 4, 4, w_term),
200  100, 1e-4, world.broadcast(Vector<t_complex>::Random(imsizex * imsizey).eval())));
201 
202 #else
203  af::setDevice(0);
204  auto const measurements = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
205  gpu::measurementoperator::init_degrid_operator_2d_mpi(world, data, imsizey, imsizex, cellsize,
206  cellsize, 2, kernel, 4, 4, w_term),
207  100, 1e-4, world.broadcast(Vector<t_complex>::Random(imsizex * imsizey).eval())));
208 
209 #endif
210 #endif
211  auto const sara = sopt::wavelets::distribute_sara(
212  sopt::wavelets::SARA{
213  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
214  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
215  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)},
216  world);
217 
218  Vector<t_real> const dirty_image = (measurements->adjoint() * (data.vis)).real();
219 
220  if (world.is_root()) {
221  // then writes stuff to files
222  boost::filesystem::path const path(output_filename(name));
223 #if PURIFY_PADMM_ALGORITHM == 3
224  auto const pb_path = path / kernel_name / "local_epsilon_replicated_grids";
225 #elif PURIFY_PADMM_ALGORITHM == 2
226  auto const pb_path = path / kernel_name / "global_epsilon_replicated_grids";
227 #elif PURIFY_PADMM_ALGORITHM == 1
228  auto const pb_path = path / kernel_name / "local_epsilon_distributed_grids";
229 #else
230 #error Unknown or unimplemented algorithm
231 #endif
232  mkdir_recursive(pb_path);
233 
234  pfitsio::write2d(dirty_image, imsizey, imsizex, (pb_path / "dirty.fits").native());
235  }
236 
237  // Create the padmm solver
238  auto const padmm = padmm_factory(measurements, sigma, sara, data, world, imsizey, imsizex);
239  // calls padmm
240  auto const diagnostic = (*padmm)();
241 
242  // makes sure we set things up correctly
243  assert(world.broadcast(diagnostic.x).isApprox(diagnostic.x));
244 
245  Vector<t_real> const residual_image = (measurements->adjoint() * diagnostic.residual).real();
246  if (world.is_root()) {
247  // then writes stuff to files
248  boost::filesystem::path const path(output_filename(name));
249 #if PURIFY_PADMM_ALGORITHM == 3
250  auto const pb_path = path / kernel_name / "local_epsilon_replicated_grids";
251 #elif PURIFY_PADMM_ALGORITHM == 2
252  auto const pb_path = path / kernel_name / "global_epsilon_replicated_grids";
253 #elif PURIFY_PADMM_ALGORITHM == 1
254  auto const pb_path = path / kernel_name / "local_epsilon_distributed_grids";
255 #else
256 #error Unknown or unimplemented algorithm
257 #endif
258  mkdir_recursive(pb_path);
259 
260  pfitsio::write2d(dirty_image, imsizey, imsizex, (pb_path / "dirty.fits").native());
261  pfitsio::write2d(diagnostic.x.real(), imsizey, imsizex, (pb_path / "solution.fits").native());
262  pfitsio::write2d(residual_image, imsizey, imsizex, (pb_path / "residual.fits").native());
263  }
264  return 0;
265 }
#define PURIFY_MEDIUM_LOG(...)
Medium priority message.
Definition: logging.h:205
std::vector< t_int > distribute_measurements(Vector< t_real > const &u, Vector< t_real > const &v, Vector< t_real > const &w, t_int const number_of_nodes, distribute::plan const distribution_plan, t_int const &grid_size)
Distribute visibilities into groups.
Definition: distribute.cc:6
void set_level(const std::string &level)
Method to set the logging level of the default Log object.
Definition: logging.h:137
std::shared_ptr< sopt::LinearTransform< T > > init_degrid_operator_2d(const Vector< t_real > &u, const Vector< t_real > &v, const Vector< t_real > &w, const Vector< t_complex > &weights, const t_uint &imsizey, const t_uint &imsizex, const t_real &oversample_ratio=2, const kernels::kernel kernel=kernels::kernel::kb, const t_uint Ju=4, const t_uint Jv=4, const bool w_stacking=false, const t_real &cellx=1, const t_real &celly=1)
Returns linear transform that is the standard degridding operator.
Definition: operators.h:608
void write2d(const Image< t_real > &eigen_image, const pfitsio::header_params &header, const bool &overwrite)
Write image to fits file.
Definition: pfitsio.cc:30
t_real step_size(T const &vis, const std::shared_ptr< sopt::LinearTransform< T > const > &measurements, const std::shared_ptr< sopt::LinearTransform< T > const > &wavelets, const t_uint sara_size)
Calculate step size using MPI (does not include factor of 1e-3)
Definition: mpi_utilities.h:79
vis_params scatter_visibilities(vis_params const &params, std::vector< t_int > const &sizes, sopt::mpi::Communicator const &comm)
vis_params regroup_and_scatter(vis_params const &params, std::vector< t_int > const &groups, sopt::mpi::Communicator const &comm)
utilities::vis_params read_visibility(const std::vector< std::string > &names, const bool w_term)
Read visibility files from name of vector.
void mkdir_recursive(const std::string &path)
recursively create directories when they do not exist
std::string output_filename(std::string const &filename)
Test output file.
std::string vla_filename(std::string const &filename)
Specific vla data.
std::shared_ptr< sopt::algorithm::ImagingProximalADMM< t_complex > > padmm_factory(std::shared_ptr< sopt::LinearTransform< Vector< t_complex >> const > const &measurements, t_real const sigma, const sopt::wavelets::SARA &sara, const utilities::vis_params &uv_data, const sopt::mpi::Communicator &comm, const t_uint &imsizex, const t_uint &imsizey)
int main(int nargs, char const **args)
utilities::vis_params dirty_visibilities(const std::vector< std::string > &names)
void padmm(const std::string &name, const Image< t_complex > &M31, const std::string &kernel, const t_int J, const utilities::vis_params &uv_data, const t_real sigma, const std::tuple< bool, t_real > &w_term)
Vector< t_complex > vis
Definition: uvw_utilities.h:22
t_uint size() const
return number of measurements
Definition: uvw_utilities.h:54
Vector< t_complex > weights
Definition: uvw_utilities.h:23