PURIFY
Next-generation radio interferometric imaging
padmm_real_data.cc
Go to the documentation of this file.
1 #include <array>
2 #include <memory>
3 #include <random>
4 #include <boost/math/special_functions/erf.hpp>
5 #include "purify/directories.h"
6 #include "purify/logging.h"
7 #include "purify/operators.h"
8 #include <sopt/credible_region.h>
9 #include <sopt/imaging_padmm.h>
10 #include <sopt/relative_variation.h>
11 #include <sopt/utilities.h>
12 #include <sopt/wavelets.h>
13 #include <sopt/wavelets/sara.h>
14 #ifdef PURIFY_GPU
15 #include "purify/operators_gpu.h"
16 #endif
17 #include "purify/types.h"
18 #include "purify/cimg.h"
19 #include "purify/pfitsio.h"
20 #include "purify/utilities.h"
21 #include "purify/uvfits.h"
22 #include "purify/wproj_utilities.h"
23 using namespace purify;
24 
25 void padmm(const std::string &name, const t_uint &imsizex, const t_uint &imsizey,
26  const std::string &kernel, const t_int J, const utilities::vis_params &uv_data,
27  const t_real sigma, const std::tuple<bool, t_real> &w_term) {
28  std::string const outfile_fits = output_filename(name + "_solution.fits");
29  std::string const residual_fits = output_filename(name + "_residual.fits");
30  std::string const dirty_image_fits = output_filename(name + "_dirty.fits");
31  std::string const psf_image_fits = output_filename(name + "_psf.fits");
32 
33  t_real const over_sample = 2;
34  std::shared_ptr<sopt::LinearTransform<Vector<t_complex>> const> measurements_transform =
35  measurementoperator::init_degrid_operator_2d<Vector<t_complex>>(
36  uv_data, imsizey, imsizex, std::get<1>(w_term), std::get<1>(w_term), over_sample,
37  kernels::kernel_from_string.at(kernel), J, J, std::get<0>(w_term));
38  t_uint const M = uv_data.size();
39  t_uint const N = imsizex * imsizey;
40  sopt::wavelets::SARA const sara{
41  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
42  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
43  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
44 
45  auto const Psi = sopt::linear_transform<t_complex>(sara, imsizey, imsizex);
46  const Vector<> dimage = (measurements_transform->adjoint() * uv_data.vis).real();
47  Matrix<t_complex> point = Matrix<t_complex>::Zero(imsizey, imsizex);
48  point(int(imsizey / 2), int(imsizex / 2)) = 1.;
49  const Vector<> psf =
50  (measurements_transform->adjoint() *
51  (*measurements_transform * Vector<t_complex>::Map(point.data(), point.size())).eval())
52  .real();
53  Vector<t_complex> initial_estimate = Vector<t_complex>::Zero(dimage.size());
54  pfitsio::write2d(Image<t_real>::Map(dimage.data(), imsizey, imsizex), dirty_image_fits);
55  pfitsio::write2d(Image<t_real>::Map(psf.data(), imsizey, imsizex), psf_image_fits);
56  auto const epsilon = 3 * std::sqrt(2 * uv_data.size()) * sigma;
57  auto const regulariser_strength =
58  (measurements_transform->adjoint() * uv_data.vis).real().maxCoeff() * 1e-3;
59  PURIFY_HIGH_LOG("Using epsilon of {}", epsilon);
60 #ifdef PURIFY_CImg
61  auto const canvas = std::make_shared<CDisplay>(
62  cimg::make_display(Vector<t_real>::Zero(2 * imsizex * imsizey), 2 * imsizex, imsizey));
63  canvas->resize(true);
64  auto const show_image = [&, measurements_transform](const Vector<t_complex> &x) -> bool {
65  if (!canvas->is_closed()) {
66  const Vector<t_complex> res =
67  (measurements_transform->adjoint() * (uv_data.vis - (*measurements_transform * x)));
68  const auto img1 = cimg::make_image(x.real().eval(), imsizey, imsizex)
69  .get_normalize(0, 1)
70  .get_resize(512, 512);
71  const auto img2 = cimg::make_image(res.real().eval(), imsizey, imsizex)
72  .get_normalize(0, 1)
73  .get_resize(512, 512);
74  const auto results = CImageList<t_real>(img1, img2);
75  canvas->display(results);
76  canvas->resize(true);
77  }
78  return true;
79  };
80 #endif
81  auto padmm = std::make_shared<sopt::algorithm::ImagingProximalADMM<t_complex>>(uv_data.vis);
82  padmm->itermax(500)
83  .regulariser_strength(regulariser_strength)
84  .relative_variation(1e-3)
85  .l2ball_proximal_epsilon(epsilon)
86  .tight_frame(false)
87  .l1_proximal_tolerance(1e-2)
88  .l1_proximal_nu(1.)
89  .l1_proximal_itermax(50)
90  .l1_proximal_positivity_constraint(true)
91  .l1_proximal_real_constraint(true)
92  .residual_convergence(epsilon)
93  .lagrange_update_scale(0.9)
94  .Psi(Psi)
95  .Phi(*measurements_transform);
96 
97  auto convergence_function = [](const Vector<t_complex> &x) { return true; };
98  const std::shared_ptr<t_uint> iter = std::make_shared<t_uint>(0);
99 
100  std::weak_ptr<decltype(padmm)::element_type> const padmm_weak(padmm);
101  const auto algo_update = [uv_data, imsizex, imsizey, padmm_weak,
102  iter](const Vector<t_complex> &x) -> bool {
103  auto padmm = padmm_weak.lock();
104  PURIFY_MEDIUM_LOG("Step size γ {}", padmm->regulariser_strength());
105  *iter = *iter + 1;
106  Vector<t_complex> const alpha = padmm->Psi().adjoint() * x;
107  // updating parameter
108  const t_real new_regulariser_strength = alpha.real().cwiseAbs().maxCoeff() * 1e-3;
109  PURIFY_MEDIUM_LOG("Step size γ update {}", new_regulariser_strength);
110  padmm->regulariser_strength(
111  ((std::abs(padmm->regulariser_strength() - new_regulariser_strength) > 0.2) and *iter < 200)
112  ? new_regulariser_strength
113  : padmm->regulariser_strength());
114 
115  Vector<t_complex> const residual = padmm->Phi().adjoint() * (uv_data.vis - padmm->Phi() * x);
116 
117  pfitsio::write2d(x, imsizey, imsizex, "solution_update.fits");
118  pfitsio::write2d(residual, imsizey, imsizex, "residual_update.fits");
119  return true;
120  };
121  auto lambda = [=](Vector<t_complex> const &x) {
122  return convergence_function(x)
123 #ifdef PURIFY_CImg
124  and show_image(x)
125 #endif
126  and algo_update(x);
127  };
128  padmm->is_converged(lambda);
129  auto const diagnostic = (*padmm)();
130  Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
131  pfitsio::write2d(image.real(), outfile_fits);
132  Vector<t_complex> residuals = measurements_transform->adjoint() *
133  (uv_data.vis - ((*measurements_transform) * diagnostic.x));
134  Image<t_complex> residual_image = Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
135  pfitsio::write2d(residual_image.real(), residual_fits);
136 #ifdef PURIFY_CImg
137  const auto results = CImageList<t_real>(
138  cimg::make_image(diagnostic.x.real().eval(), imsizey, imsizex).get_resize(512, 512),
139  cimg::make_image(residuals.real().eval(), imsizey, imsizex).get_resize(512, 512));
140  canvas->display(results);
141  cimg::make_image(residuals.real().eval(), imsizey, imsizex)
142  .histogram(256)
143  .display_graph("Residual Histogram", 2);
144  while (!canvas->is_closed()) canvas->wait();
145 #endif
146 }
147 
148 int main(int, char **) {
149  sopt::logging::set_level("debug");
151  const std::string &name = "real_data";
152  const bool w_term = false;
153  const t_real cellsize = 20;
154  const t_uint imsizex = 1024;
155  const t_uint imsizey = 1024;
156  const std::string kernel = "kb";
157  const std::vector<std::string> inputfiles = {vla_filename("../mwa/uvdump_01.uvfits"),
158  vla_filename("../mwa/uvdump_02.uvfits")};
159 
160  auto uv_data = pfitsio::read_uvfits(inputfiles);
161  t_real const sigma = uv_data.weights.norm() / std::sqrt(uv_data.weights.size()) * 0.05;
162  uv_data.vis =
163  uv_data.vis.array() * uv_data.weights.array() / uv_data.weights.array().cwiseAbs().maxCoeff();
164  padmm(name, imsizex, imsizey, kernel, 4, uv_data, sigma, std::make_tuple(w_term, cellsize));
165  return 0;
166 }
#define PURIFY_HIGH_LOG(...)
High priority message.
Definition: logging.h:203
#define PURIFY_MEDIUM_LOG(...)
Medium priority message.
Definition: logging.h:205
const std::map< std::string, kernel > kernel_from_string
Definition: kernels.h:16
void set_level(const std::string &level)
Method to set the logging level of the default Log object.
Definition: logging.h:137
void write2d(const Image< t_real > &eigen_image, const pfitsio::header_params &header, const bool &overwrite)
Write image to fits file.
Definition: pfitsio.cc:30
utilities::vis_params read_uvfits(const std::vector< std::string > &names, const bool flag, const stokes pol)
Read uvfits files from name of vector.
Definition: uvfits.cc:12
std::string output_filename(std::string const &filename)
Test output file.
std::string vla_filename(std::string const &filename)
Specific vla data.
void padmm(const std::string &name, const Image< t_complex > &M31, const std::string &kernel, const t_int J, const utilities::vis_params &uv_data, const t_real sigma, const std::tuple< bool, t_real > &w_term)
int main(int, char **)
Vector< t_complex > vis
Definition: uvw_utilities.h:22
t_uint size() const
return number of measurements
Definition: uvw_utilities.h:54