1 #include "purify/config.h" 
    7 #include <boost/math/special_functions/erf.hpp> 
    8 #include "purify/directories.h" 
   13 #include <sopt/imaging_padmm.h> 
   14 #include <sopt/power_method.h> 
   15 #include <sopt/relative_variation.h> 
   16 #include <sopt/utilities.h> 
   17 #include <sopt/wavelets.h> 
   18 #include <sopt/wavelets/sara.h> 
   20 int main(
int nargs, 
char const **args) {
 
   22     std::cerr << 
" Wrong number of arguments! " << 
'\n';
 
   30   std::string 
const test_type = args[1];
 
   31   const std::string 
kernel = args[2];
 
   32   t_real 
const over_sample = std::stod(
static_cast<std::string
>(args[3]));
 
   33   t_int 
const J = 
static_cast<t_int
>(std::stod(
static_cast<std::string
>(args[4])));
 
   34   t_real 
const m_over_n = std::stod(
static_cast<std::string
>(args[5]));
 
   35   std::string 
const test_number = 
static_cast<std::string
>(args[6]);
 
   36   t_real 
const ISNR = std::stod(
static_cast<std::string
>(args[7]));
 
   37   std::string 
const name = 
static_cast<std::string
>(args[8]);
 
   41   std::string 
const dirty_image_fits =
 
   43   std::string 
const results =
 
   47   auto sky_model_max = sky_model.array().abs().maxCoeff();
 
   48   sky_model = sky_model / sky_model_max;
 
   49   t_int 
const number_of_vis = std::floor(m_over_n * sky_model.size());
 
   56   auto measurements_sky = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
 
   58           uv_data.u, uv_data.v, uv_data.w, uv_data.weights, sky_model.cols(), sky_model.rows(),
 
   60       100, 1e-4, Vector<t_complex>::Random(sky_model.size())));
 
   61   uv_data.vis = measurements_sky * Vector<t_complex>::Map(sky_model.data(), sky_model.size());
 
   62   auto measurements_transform = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
 
   64           uv_data.u, uv_data.v, uv_data.w, uv_data.weights, sky_model.cols(), sky_model.rows(),
 
   66       100, 1e-4, Vector<t_complex>::Random(sky_model.size())));
 
   68   std::vector<std::tuple<std::string, t_uint>> wavelets;
 
   70   if (test_type == 
"clean") wavelets.push_back(std::make_tuple(
"Dirac", 3u));
 
   71   if (test_type == 
"ms_clean") wavelets.push_back(std::make_tuple(
"DB4", 3u));
 
   72   if (test_type == 
"padmm") {
 
   73     wavelets.push_back(std::make_tuple(
"Dirac", 3u));
 
   74     wavelets.push_back(std::make_tuple(
"DB1", 3u));
 
   75     wavelets.push_back(std::make_tuple(
"DB2", 3u));
 
   76     wavelets.push_back(std::make_tuple(
"DB3", 3u));
 
   77     wavelets.push_back(std::make_tuple(
"DB4", 3u));
 
   78     wavelets.push_back(std::make_tuple(
"DB5", 3u));
 
   79     wavelets.push_back(std::make_tuple(
"DB6", 3u));
 
   80     wavelets.push_back(std::make_tuple(
"DB7", 3u));
 
   81     wavelets.push_back(std::make_tuple(
"DB8", 3u));
 
   83   sopt::wavelets::SARA 
const sara(wavelets.begin(), wavelets.end());
 
   84   auto const Psi = sopt::linear_transform<t_complex>(sara, sky_model.rows(), sky_model.cols());
 
   91   Vector<> dimage = (measurements_transform.adjoint() * uv_data.vis).real();
 
   92   t_real 
const max_val = dimage.array().abs().maxCoeff();
 
   93   dimage = dimage / max_val;
 
   94   Vector<t_complex> initial_estimate = Vector<t_complex>::Zero(dimage.size());
 
   97   auto const purify_regulariser_strength =
 
   98       (Psi.adjoint() * (measurements_transform.adjoint() * uv_data.vis).eval()).real().maxCoeff() *
 
  101   auto convergence_function = [&iters](
const Vector<t_complex> &x) {
 
  108   auto const padmm = sopt::algorithm::ImagingProximalADMM<t_complex>(uv_data.vis)
 
  109                          .regulariser_strength(purify_regulariser_strength)
 
  110                          .relative_variation(1e-3)
 
  111                          .l2ball_proximal_epsilon(epsilon * 1.001)
 
  113                          .l1_proximal_tolerance(1e-2)
 
  115                          .l1_proximal_itermax(50)
 
  116                          .l1_proximal_positivity_constraint(
true)
 
  117                          .l1_proximal_real_constraint(
true)
 
  118                          .residual_convergence(epsilon * 1.001)
 
  119                          .lagrange_update_scale(0.9)
 
  122                          .is_converged(convergence_function)
 
  123                          .Phi(measurements_transform);
 
  126   std::clock_t c_start = std::clock();
 
  127   auto const diagnostic = 
padmm();
 
  128   std::clock_t c_end = std::clock();
 
  132   if (diagnostic.good) {
 
  135   const t_uint maxiters = iters;
 
  137   Image<t_complex> image =
 
  138       Image<t_complex>::Map(diagnostic.x.data(), sky_model.rows(), sky_model.cols());
 
  140   Vector<t_complex> original = Vector<t_complex>::Map(sky_model.data(), sky_model.size(), 1);
 
  141   Image<t_complex> res = sky_model - image;
 
  142   Vector<t_complex> residual = Vector<t_complex>::Map(res.data(), image.size(), 1);
 
  144   auto snr = 20. * std::log10(original.norm() / residual.norm());  
 
  145   auto total_time = (c_end - c_start) / CLOCKS_PER_SEC;  
 
  147   std::ofstream out(results);
 
  149   out << snr << 
" " << total_time << 
" " << converged << 
" " << maxiters;
 
#define PURIFY_HIGH_LOG(...)
High priority message.
 
#define PURIFY_MEDIUM_LOG(...)
Medium priority message.
 
const t_real pi
Mathematical constant π.
 
const std::map< std::string, kernel > kernel_from_string
 
void set_level(const std::string &level)
Method to set the logging level of the default Log object.
 
std::shared_ptr< sopt::LinearTransform< T > > init_degrid_operator_2d(const Vector< t_real > &u, const Vector< t_real > &v, const Vector< t_real > &w, const Vector< t_complex > &weights, const t_uint &imsizey, const t_uint &imsizex, const t_real &oversample_ratio=2, const kernels::kernel kernel=kernels::kernel::kb, const t_uint Ju=4, const t_uint Jv=4, const bool w_stacking=false, const t_real &cellx=1, const t_real &celly=1)
Returns a linear transform that is the standard degridding operator.
 
Image< t_complex > read2d(const std::string &fits_name)
Read image from fits file.
 
t_real SNR_to_standard_deviation(const Vector< t_complex > &y0, const t_real &SNR)
Converts SNR to RMS noise.
 
Vector< t_complex > add_noise(const Vector< t_complex > &y0, const t_complex &mean, const t_real &standard_deviation)
Add Gaussian noise to a vector.
 
utilities::vis_params random_sample_density(const t_int vis_num, const t_real mean, const t_real standard_deviation, const t_real rms_w)
Generates a random visibility coverage.
 
t_real calculate_l2_radius(const t_uint y_size, const t_real &sigma, const t_real &n_sigma, const std::string distirbution)
Calculates the l2-ball radius used by sopt.
 
std::string output_filename(std::string const &filename)
Test output filename.
 
std::string image_filename(std::string const &filename)
Image filename.
 
void padmm(const std::string &name, const Image< t_complex > &M31, const std::string &kernel, const t_int J, const utilities::vis_params &uv_data, const t_real sigma, const std::tuple< bool, t_real > &w_term)
 
int main(int nargs, char const **args)