PURIFY
Next-generation radio interferometric imaging
padmm_reweighted_simulation.cc
Go to the documentation of this file.
1 #include "purify/config.h"
2 #include "purify/types.h"
3 #include <array>
4 #include <ctime>
5 #include <memory>
6 #include <random>
7 #include <boost/math/special_functions/erf.hpp>
8 #include "purify/directories.h"
9 #include "purify/logging.h"
10 #include "purify/operators.h"
11 #include "purify/pfitsio.h"
12 #include "purify/utilities.h"
13 #include <sopt/imaging_padmm.h>
14 #include <sopt/positive_quadrant.h>
15 #include <sopt/power_method.h>
16 #include <sopt/relative_variation.h>
17 #include <sopt/reweighted.h>
18 #include <sopt/utilities.h>
19 #include <sopt/wavelets.h>
20 #include <sopt/wavelets/sara.h>
21 
22 int main(int nargs, char const **args) {
23  if (nargs != 8) {
24  std::cerr << " Wrong number of arguments! " << '\n';
25  return 1;
26  }
27 
28  using namespace purify;
29  sopt::logging::set_level("debug");
30 
31  std::string const kernel = args[1];
32  t_real const over_sample = std::stod(static_cast<std::string>(args[2]));
33  t_int const J = static_cast<t_int>(std::stod(static_cast<std::string>(args[3])));
34  t_real const m_over_n = std::stod(static_cast<std::string>(args[4]));
35  std::string const test_number = static_cast<std::string>(args[5]);
36  t_real const ISNR = std::stod(static_cast<std::string>(args[6]));
37  std::string const name = static_cast<std::string>(args[7]);
38 
39  std::string const fitsfile = image_filename(name + ".fits");
40 
41  std::string const dirty_image_fits =
42  output_filename(name + "_dirty_" + kernel + "_" + test_number + ".fits");
43  std::string const results =
44  output_filename(name + "_results_" + kernel + "_" + test_number + ".txt");
45 
46  auto sky_model = pfitsio::read2d(fitsfile);
47  auto sky_model_max = sky_model.array().abs().maxCoeff();
48  sky_model = sky_model / sky_model_max;
49  t_int const number_of_vis = std::floor(m_over_n * sky_model.size());
50  t_real const sigma_m = constant::pi / 3;
51  auto uv_data = utilities::random_sample_density(number_of_vis, 0, sigma_m);
52  uv_data.units = utilities::vis_units::radians;
53  PURIFY_MEDIUM_LOG("Number of measurements: {}", uv_data.u.size());
54  auto simulate_measurements = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
56  uv_data.u, uv_data.v, uv_data.w, uv_data.weights, sky_model.cols(), sky_model.rows(),
57  over_sample, kernels::kernel_from_string.at("kb"), J, J),
58  100, 1e-4, Vector<t_complex>::Random(sky_model.size())));
59  uv_data.vis = simulate_measurements * sky_model;
60 
61  // putting measurement operator in a form that sopt can use
62  auto measurements_transform = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
64  uv_data.u, uv_data.v, uv_data.w, uv_data.weights, sky_model.cols(), sky_model.rows(),
65  over_sample, kernels::kernel_from_string.at(kernel), J, J),
66  100, 1e-4, Vector<t_complex>::Random(sky_model.size())));
67 
68  sopt::wavelets::SARA const sara{
69  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
70  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
71  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
72 
73  auto const Psi = sopt::linear_transform<t_complex>(sara, sky_model.rows(), sky_model.cols());
74 
75  // working out value of sigma given SNR of 30
76  t_real sigma = utilities::SNR_to_standard_deviation(uv_data.vis, ISNR);
77  // adding noise to visibilities
78  uv_data.vis = utilities::add_noise(uv_data.vis, 0., sigma);
79 
80  Vector<> dimage = (measurements_transform.adjoint() * uv_data.vis).real();
81  t_real const max_val = dimage.array().abs().maxCoeff();
82  dimage = dimage / max_val;
83  Vector<t_complex> initial_estimate = Vector<t_complex>::Zero(dimage.size());
84 
85  auto const epsilon = utilities::calculate_l2_radius(uv_data.vis.size(), sigma);
86  auto const purify_regulariser_strength =
87  (Psi.adjoint() * (measurements_transform.adjoint() * uv_data.vis).eval()).real().maxCoeff() *
88  1e-3;
89 
90  PURIFY_HIGH_LOG("Starting sopt!");
91  PURIFY_MEDIUM_LOG("Epsilon {}", epsilon);
92  PURIFY_MEDIUM_LOG("Regulariser_Strength {}", purify_regulariser_strength);
93  auto const padmm = sopt::algorithm::ImagingProximalADMM<t_complex>(uv_data.vis)
94  .regulariser_strength(purify_regulariser_strength)
95  .relative_variation(1e-3)
96  .l2ball_proximal_epsilon(epsilon * 1.001)
97  .tight_frame(false)
98  .l1_proximal_tolerance(1e-2)
99  .l1_proximal_nu(1)
100  .l1_proximal_itermax(50)
101  .l1_proximal_positivity_constraint(true)
102  .l1_proximal_real_constraint(true)
103  .residual_convergence(epsilon * 1.001)
104  .lagrange_update_scale(0.9)
105  .Psi(Psi)
106  .Phi(measurements_transform);
107  // Timing reconstruction
108  auto const posq = sopt::algorithm::positive_quadrant(padmm);
109  auto const min_delta = sigma * std::sqrt(uv_data.vis.size()) / std::sqrt(9 * sky_model.size());
110  // Sets weight after each padmm iteration.
111  // In practice, this means replacing the proximal of the l1 objective function.
112  auto const reweighted = sopt::algorithm::reweighted(padmm).min_delta(min_delta).is_converged(
113  sopt::RelativeVariation<std::complex<t_real>>(1e-3));
114  std::clock_t c_start = std::clock();
115  auto const diagnostic = reweighted();
116  std::clock_t c_end = std::clock();
117 
118  Image<t_complex> image =
119  Image<t_complex>::Map(diagnostic.algo.x.data(), sky_model.rows(), sky_model.cols());
120 
121  Vector<t_complex> original = Vector<t_complex>::Map(sky_model.data(), sky_model.size(), 1);
122  Image<t_complex> res = sky_model - image;
123  Vector<t_complex> residual = Vector<t_complex>::Map(res.data(), image.size(), 1);
124 
125  auto snr = 20. * std::log10(original.norm() / residual.norm()); // SNR of reconstruction
126  auto total_time = (c_end - c_start) / CLOCKS_PER_SEC; // total time for solver to run in seconds
127  // writing snr and time to a file
128  std::ofstream out(results);
129  out.precision(13);
130  out << snr << " " << total_time;
131  out.close();
132 
133  return 0;
134 }
#define PURIFY_HIGH_LOG(...)
High priority message.
Definition: logging.h:203
#define PURIFY_MEDIUM_LOG(...)
Medium priority message.
Definition: logging.h:205
const t_real pi
mathematical constant
Definition: types.h:70
const std::map< std::string, kernel > kernel_from_string
Definition: kernels.h:16
void set_level(const std::string &level)
Method to set the logging level of the default Log object.
Definition: logging.h:137
std::shared_ptr< sopt::LinearTransform< T > > init_degrid_operator_2d(const Vector< t_real > &u, const Vector< t_real > &v, const Vector< t_real > &w, const Vector< t_complex > &weights, const t_uint &imsizey, const t_uint &imsizex, const t_real &oversample_ratio=2, const kernels::kernel kernel=kernels::kernel::kb, const t_uint Ju=4, const t_uint Jv=4, const bool w_stacking=false, const t_real &cellx=1, const t_real &celly=1)
Returns linear transform that is the standard degridding operator.
Definition: operators.h:608
Image< t_complex > read2d(const std::string &fits_name)
Read image from fits file.
Definition: pfitsio.cc:109
t_real SNR_to_standard_deviation(const Vector< t_complex > &y0, const t_real &SNR)
Converts SNR to RMS noise.
Definition: utilities.cc:101
Vector< t_complex > add_noise(const Vector< t_complex > &y0, const t_complex &mean, const t_real &standard_deviation)
Add Gaussian noise to vector.
Definition: utilities.cc:113
utilities::vis_params random_sample_density(const t_int vis_num, const t_real mean, const t_real standard_deviation, const t_real rms_w)
Generates a random visibility coverage.
t_real calculate_l2_radius(const t_uint y_size, const t_real &sigma, const t_real &n_sigma, const std::string distirbution)
A function that calculates the l2 ball radius for sopt.
Definition: utilities.cc:75
std::string output_filename(std::string const &filename)
Test output file.
std::string image_filename(std::string const &filename)
Image filename.
void padmm(const std::string &name, const Image< t_complex > &M31, const std::string &kernel, const t_int J, const utilities::vis_params &uv_data, const t_real sigma, const std::tuple< bool, t_real > &w_term)
int main(int nargs, char const **args)