PURIFY
Next-generation radio interferometric imaging
padmm_mpi_random_coverage.cc
Go to the documentation of this file.
1 #include "purify/types.h"
2 #include <array>
3 #include <memory>
4 #include <random>
5 #include <boost/filesystem.hpp>
6 #include <boost/math/special_functions/erf.hpp>
7 #include "purify/directories.h"
8 #include "purify/distribute.h"
9 #include "purify/logging.h"
10 #include "purify/mpi_utilities.h"
11 #include "purify/operators.h"
12 #include "purify/pfitsio.h"
14 #include "purify/utilities.h"
15 #include <sopt/imaging_padmm.h>
16 #include <sopt/mpi/communicator.h>
17 #include <sopt/mpi/session.h>
18 #include <sopt/power_method.h>
19 #include <sopt/relative_variation.h>
20 #include <sopt/utilities.h>
21 #include <sopt/wavelets.h>
22 #include <sopt/wavelets/sara.h>
23 
24 #ifdef PURIFY_GPU
25 #include "purify/operators_gpu.h"
26 #endif
27 
28 #ifndef PURIFY_PADMM_ALGORITHM
29 #define PURIFY_PADMM_ALGORITHM 2
30 #endif
31 
32 using namespace purify;
33 
34 std::tuple<utilities::vis_params, t_real> dirty_visibilities(
35  Image<t_complex> const &ground_truth_image, t_uint number_of_vis, t_real snr,
36  const std::tuple<bool, t_real> &w_term) {
37  auto uv_data =
38  utilities::random_sample_density(number_of_vis, 0, constant::pi / 3, std::get<0>(w_term));
39  uv_data.units = utilities::vis_units::radians;
40  PURIFY_HIGH_LOG("Number of measurements / number of pixels: {}",
41  uv_data.u.size() / ground_truth_image.size());
42  // creating operator to generate measurements
43  auto const sky_measurements = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
45  uv_data, ground_truth_image.rows(), ground_truth_image.cols(), std::get<1>(w_term),
46  std::get<1>(w_term), 2, kernels::kernel::kb, 8, 8, std::get<0>(w_term)),
47  100, 1e-4, Vector<t_complex>::Random(ground_truth_image.size())));
48  // Generates measurements from image
49  uv_data.vis = (*sky_measurements) *
50  Image<t_complex>::Map(ground_truth_image.data(), ground_truth_image.size(), 1);
51 
52  // working out value of signal given SNR of 30
53  auto const sigma = utilities::SNR_to_standard_deviation(uv_data.vis, snr);
54  // adding noise to visibilities
55  uv_data.vis = utilities::add_noise(uv_data.vis, 0., sigma);
56  return std::make_tuple(uv_data, sigma);
57 }
58 
59 std::tuple<utilities::vis_params, t_real> dirty_visibilities(
60  Image<t_complex> const &ground_truth_image, t_uint number_of_vis, t_real snr,
61  const std::tuple<bool, t_real> &w_term, sopt::mpi::Communicator const &comm) {
62  if (comm.size() == 1) return dirty_visibilities(ground_truth_image, number_of_vis, snr, w_term);
63  if (comm.is_root()) {
64  auto result = dirty_visibilities(ground_truth_image, number_of_vis, snr, w_term);
65  comm.broadcast(std::get<1>(result));
66  auto const order =
68  std::get<0>(result) = utilities::regroup_and_scatter(std::get<0>(result), order, comm);
69  return result;
70  }
71  auto const sigma = comm.broadcast<t_real>();
72  return std::make_tuple(utilities::scatter_visibilities(comm), sigma);
73 }
74 
75 std::shared_ptr<sopt::algorithm::ImagingProximalADMM<t_complex>> padmm_factory(
76  std::shared_ptr<sopt::LinearTransform<Vector<t_complex>> const> const &measurements,
77  const sopt::wavelets::SARA &sara, const Image<t_complex> &ground_truth_image,
78  const utilities::vis_params &uv_data, const t_real sigma, const sopt::mpi::Communicator &comm) {
79  auto const Psi = sopt::linear_transform<t_complex>(sara, ground_truth_image.rows(),
80  ground_truth_image.cols(), comm);
81 
82 #if PURIFY_PADMM_ALGORITHM == 2
83  auto const epsilon = std::sqrt(
84  comm.all_sum_all(std::pow(utilities::calculate_l2_radius(uv_data.vis.size(), sigma), 2)));
85 #elif PURIFY_PADMM_ALGORITHM == 3 || PURIFY_PADMM_ALGORITHM == 1
86  auto const epsilon = utilities::calculate_l2_radius(uv_data.vis.size(), sigma);
87 #endif
88  PURIFY_LOW_LOG("SARA Size = {}, Rank = {}", sara.size(), comm.rank());
89  const t_real regulariser_strength =
90  utilities::step_size(uv_data.vis, measurements,
91  std::make_shared<sopt::LinearTransform<Vector<t_complex>> const>(Psi),
92  sara.size()) *
93  1e-3;
94  PURIFY_LOW_LOG("Epsilon {}, Rank = {}", epsilon, comm.rank());
95  PURIFY_LOW_LOG("Regulariser_Strength {}, SARA Size = {}, Rank = {}", regulariser_strength,
96  sara.size(), comm.rank());
97 
98  // shared pointer because the convergence function need access to some data that we would rather
99  // not reproduce. E.g. padmm definition is self-referential.
100  auto padmm = std::make_shared<sopt::algorithm::ImagingProximalADMM<t_complex>>(uv_data.vis);
101  padmm->itermax(50)
102  .regulariser_strength(comm.all_reduce(regulariser_strength, MPI_MAX))
103  .relative_variation(1e-3)
104  .l2ball_proximal_epsilon(epsilon)
105 #if PURIFY_PADMM_ALGORITHM == 2
106  // communicator ensuring l2 norm in l2ball proximal is global
107  .l2ball_proximal_communicator(comm)
108 #endif
109  // communicator ensuring l1 norm in l1 proximal is global
110  .l1_proximal_adjoint_space_comm(comm)
111  .tight_frame(false)
112  .l1_proximal_tolerance(1e-2)
113  .l1_proximal_nu(1)
114  .l1_proximal_itermax(50)
115  .l1_proximal_positivity_constraint(true)
116  .l1_proximal_real_constraint(true)
117  .residual_tolerance(epsilon)
118  .lagrange_update_scale(0.9)
119  .Psi(Psi)
120  .Phi(*measurements);
121  sopt::ScalarRelativeVariation<t_complex> conv(padmm->relative_variation(),
122  padmm->relative_variation(), "Objective function");
123  std::weak_ptr<decltype(padmm)::element_type> const padmm_weak(padmm);
124  padmm->residual_convergence([padmm_weak, conv, comm](
125  Vector<t_complex> const &x,
126  Vector<t_complex> const &residual) mutable -> bool {
127  auto const padmm = padmm_weak.lock();
128 #if PURIFY_PADMM_ALGORITHM == 2
129  auto const residual_norm = sopt::mpi::l2_norm(residual, padmm->l2ball_proximal_weights(), comm);
130  auto const result = residual_norm < padmm->residual_tolerance();
131 #elif PURIFY_PADMM_ALGORITHM == 3 || PURIFY_PADMM_ALGORITHM == 1
132  auto const residual_norm = sopt::l2_norm(residual, padmm->l2ball_proximal_weights());
133  auto const result =
134  comm.all_reduce<int8_t>(residual_norm < padmm->residual_tolerance(), MPI_LAND);
135 #endif
136  SOPT_LOW_LOG(" - [PADMM] Residuals: {} <? {}", residual_norm, padmm->residual_tolerance());
137  return result;
138  });
139 
140  padmm->objective_convergence([padmm_weak, conv, comm](Vector<t_complex> const &x,
141  Vector<t_complex> const &) mutable -> bool {
142  auto const padmm = padmm_weak.lock();
143 #if PURIFY_PADMM_ALGORITHM == 2
144  return conv(sopt::mpi::l1_norm(padmm->Psi().adjoint() * x, padmm->l1_proximal_weights(), comm));
145 #elif PURIFY_PADMM_ALGORITHM == 3 || PURIFY_PADMM_ALGORITHM == 1
146  return comm.all_reduce<uint8_t>(
147  conv(sopt::l1_norm(padmm->Psi().adjoint() * x, padmm->l1_proximal_weights())), MPI_LAND);
148 #endif
149  });
150 
151  return padmm;
152 }
153 
154 int main(int nargs, char const **args) {
155  sopt::logging::set_level("debug");
157  auto const session = sopt::mpi::init(nargs, args);
158  auto const world = sopt::mpi::Communicator::World();
159 
160  const t_real FoV = 1; // deg
161  const t_real max_w = 100; // lambda
162  const std::string name = "M31";
163  const t_real snr = 30;
164  auto const kernel = "kb";
165  const bool w_term = true;
166  // string of fits file of image to reconstruct
167  auto ground_truth_image = pfitsio::read2d(image_filename(name + ".fits"));
168  ground_truth_image /= ground_truth_image.array().abs().maxCoeff();
169 
170  const t_real cellsize = FoV / ground_truth_image.cols() * 60. * 60.;
171  // determine amount of visibilities to simulate
172  t_int const number_of_pixels = ground_truth_image.size();
173  t_int const number_of_vis = std::floor(number_of_pixels * 0.1);
174 
175  // Generating random uv(w) coverage
176  auto const data = dirty_visibilities(ground_truth_image, number_of_vis, snr,
177  std::make_tuple(w_term, cellsize), world);
178 #if PURIFY_PADMM_ALGORITHM == 2 || PURIFY_PADMM_ALGORITHM == 3
179 #ifndef PURIFY_GPU
180  auto const measurements = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
182  world, std::get<0>(data), ground_truth_image.rows(), ground_truth_image.cols(), cellsize,
183  cellsize, 2, kernels::kernel_from_string.at(kernel), 8, 8, w_term),
184  100, 1e-4,
185  world.broadcast(
186  Vector<t_complex>::Random(ground_truth_image.rows() * ground_truth_image.cols())
187  .eval())));
188 
189 #else
190  af::setDevice(0);
191  auto const measurements = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
193  world, std::get<0>(data), ground_truth_image.rows(), ground_truth_image.cols(), cellsize,
194  cellsize, 2, kernels::kernel_from_string.at(kernel), 8, 8, w_term),
195  100, 1e-4,
196  world.broadcast(
197  Vector<t_complex>::Random(ground_truth_image.rows() * ground_truth_image.cols())
198  .eval())));
199 
200 #endif
201 #elif PURIFY_PADMM_ALGORITHM == 1
202 #ifndef PURIFY_GPU
203  auto const measurements = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
204  measurementoperator::init_degrid_operator_2d_mpi<Vector<t_complex>>(
205  world, std::get<0>(data), ground_truth_image.rows(), ground_truth_image.cols(), cellsize,
206  cellsize, 2, kernels::kernel_from_string.at(kernel), 8, 8, w_term),
207  100, 1e-4,
208  world.broadcast(
209  Vector<t_complex>::Random(ground_truth_image.rows() * ground_truth_image.cols())
210  .eval())));
211 
212 #else
213  af::setDevice(0);
214  auto const measurements = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
215  gpu::measurementoperator::init_degrid_operator_2d_mpi(
216  world, std::get<0>(data), ground_truth_image.rows(), ground_truth_image.cols(), cellsize,
217  cellsize, 2, kernels::kernel_from_string.at(kernel), 8, 8, w_term),
218  100, 1e-4,
219  world.broadcast(
220  Vector<t_complex>::Random(ground_truth_image.rows() * ground_truth_image.cols())
221  .eval())));
222 
223 #endif
224 #endif
225  auto const sara = sopt::wavelets::distribute_sara(
226  sopt::wavelets::SARA{
227  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
228  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
229  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)},
230  world);
231 
232  // Create the padmm solver
233  auto const padmm = padmm_factory(measurements, sara, ground_truth_image, std::get<0>(data),
234  std::get<1>(data), world);
235  // calls padmm
236  auto const diagnostic = (*padmm)();
237 
238  // makes sure we set things up correctly
239  assert(diagnostic.x.size() == ground_truth_image.size());
240  assert(world.broadcast(diagnostic.x).isApprox(diagnostic.x));
241 
242  // then writes stuff to files
243  auto const residual_image = (measurements->adjoint() * diagnostic.residual).real();
244  auto const dirty_image = (measurements->adjoint() * std::get<0>(data).vis).real();
245  if (world.is_root()) {
246  boost::filesystem::path const path(output_filename(name));
247 #if PURIFY_PADMM_ALGORITHM == 3
248  auto const pb_path = path / kernel / "local_epsilon_replicated_grids";
249 #elif PURIFY_PADMM_ALGORITHM == 2
250  auto const pb_path = path / kernel / "global_epsilon_replicated_grids";
251 #elif PURIFY_PADMM_ALGORITHM == 1
252  auto const pb_path = path / kernel / "local_epsilon_distributed_grids";
253 #else
254 #error Unknown or unimplemented algorithm
255 #endif
256  mkdir_recursive(pb_path);
257 
258  pfitsio::write2d(ground_truth_image.real(), (path / "input.fits").native());
259  pfitsio::write2d(dirty_image, ground_truth_image.rows(), ground_truth_image.cols(),
260  (pb_path / "dirty.fits").native());
261  pfitsio::write2d(diagnostic.x.real(), ground_truth_image.rows(), ground_truth_image.cols(),
262  (pb_path / "solution.fits").native());
263  pfitsio::write2d(residual_image, ground_truth_image.rows(), ground_truth_image.cols(),
264  (pb_path / "residual.fits").native());
265  }
266  return 0;
267 }
#define PURIFY_LOW_LOG(...)
Low priority message.
Definition: logging.h:207
#define PURIFY_HIGH_LOG(...)
High priority message.
Definition: logging.h:203
const t_real pi
mathematical constant
Definition: types.h:70
std::vector< t_int > distribute_measurements(Vector< t_real > const &u, Vector< t_real > const &v, Vector< t_real > const &w, t_int const number_of_nodes, distribute::plan const distribution_plan, t_int const &grid_size)
Distribute visibilities into groups.
Definition: distribute.cc:6
const std::map< std::string, kernel > kernel_from_string
Definition: kernels.h:16
void set_level(const std::string &level)
Method to set the logging level of the default Log object.
Definition: logging.h:137
std::shared_ptr< sopt::LinearTransform< T > > init_degrid_operator_2d(const Vector< t_real > &u, const Vector< t_real > &v, const Vector< t_real > &w, const Vector< t_complex > &weights, const t_uint &imsizey, const t_uint &imsizex, const t_real &oversample_ratio=2, const kernels::kernel kernel=kernels::kernel::kb, const t_uint Ju=4, const t_uint Jv=4, const bool w_stacking=false, const t_real &cellx=1, const t_real &celly=1)
Returns linear transform that is the standard degridding operator.
Definition: operators.h:608
void write2d(const Image< t_real > &eigen_image, const pfitsio::header_params &header, const bool &overwrite)
Write image to fits file.
Definition: pfitsio.cc:30
Image< t_complex > read2d(const std::string &fits_name)
Read image from fits file.
Definition: pfitsio.cc:109
t_real step_size(T const &vis, const std::shared_ptr< sopt::LinearTransform< T > const > &measurements, const std::shared_ptr< sopt::LinearTransform< T > const > &wavelets, const t_uint sara_size)
Calculate step size using MPI (does not include factor of 1e-3)
Definition: mpi_utilities.h:79
t_real SNR_to_standard_deviation(const Vector< t_complex > &y0, const t_real &SNR)
Converts SNR to RMS noise.
Definition: utilities.cc:101
Vector< t_complex > add_noise(const Vector< t_complex > &y0, const t_complex &mean, const t_real &standard_deviation)
Add Gaussian noise to vector.
Definition: utilities.cc:113
vis_params scatter_visibilities(vis_params const &params, std::vector< t_int > const &sizes, sopt::mpi::Communicator const &comm)
vis_params regroup_and_scatter(vis_params const &params, std::vector< t_int > const &groups, sopt::mpi::Communicator const &comm)
utilities::vis_params random_sample_density(const t_int vis_num, const t_real mean, const t_real standard_deviation, const t_real rms_w)
Generates a random visibility coverage.
t_real calculate_l2_radius(const t_uint y_size, const t_real &sigma, const t_real &n_sigma, const std::string distirbution)
A function that calculates the l2 ball radius for sopt.
Definition: utilities.cc:75
void mkdir_recursive(const std::string &path)
recursively create directories when they do not exist
std::string output_filename(std::string const &filename)
Test output file.
std::string image_filename(std::string const &filename)
Image filename.
std::shared_ptr< sopt::algorithm::ImagingProximalADMM< t_complex > > padmm_factory(std::shared_ptr< sopt::LinearTransform< Vector< t_complex >> const > const &measurements, const sopt::wavelets::SARA &sara, const Image< t_complex > &ground_truth_image, const utilities::vis_params &uv_data, const t_real sigma, const sopt::mpi::Communicator &comm)
std::tuple< utilities::vis_params, t_real > dirty_visibilities(Image< t_complex > const &ground_truth_image, t_uint number_of_vis, t_real snr, const std::tuple< bool, t_real > &w_term)
int main(int nargs, char const **args)
void padmm(const std::string &name, const Image< t_complex > &M31, const std::string &kernel, const t_int J, const utilities::vis_params &uv_data, const t_real sigma, const std::tuple< bool, t_real > &w_term)
Vector< t_complex > vis
Definition: uvw_utilities.h:22