PURIFY
Next-generation radio interferometric imaging
Functions
main.cc File Reference
#include "purify/config.h"
#include "purify/types.h"
#include <array>
#include <cstddef>
#include <ctime>
#include <memory>
#include <random>
#include "purify/algorithm_factory.h"
#include "purify/cimg.h"
#include "purify/logging.h"
#include "purify/measurement_operator_factory.h"
#include "purify/pfitsio.h"
#include "purify/read_measurements.h"
#include "purify/setup_utils.h"
#include "purify/update_factory.h"
#include "purify/wavelet_operator_factory.h"
#include "purify/wide_field_utilities.h"
#include "purify/yaml-parser.h"
#include <sopt/imaging_padmm.h>
#include <sopt/positive_quadrant.h>
#include <sopt/power_method.h>
#include <sopt/relative_variation.h>
#include <sopt/reweighted.h>
+ Include dependency graph for main.cc:

Go to the source code of this file.

Functions

int main (int argc, const char **argv)
 

Function Documentation

◆ main()

int main ( int  argc,
const char **  argv 
)

Definition at line 32 of file main.cc.

32  {
33  std::srand(static_cast<t_uint>(std::time(0)));
34  std::mt19937 mersnne(std::time(0));
35 
36  // Read config file path from command line
37  if (argc == 1) {
38  PURIFY_HIGH_LOG("Specify the config file full path. Aborting.");
39  return 1;
40  }
41 
42  std::string file_path = argv[1];
43  YamlParser params = YamlParser(file_path);
44  if (params.version() != purify::version())
45  throw std::runtime_error(
46  "Using purify version " + purify::version() +
47  " but the configuration file expects version " + params.version() +
48  ". Please updated the config version manually to be compatable with the new version.");
49 
50 #ifdef PURIFY_MPI
51  auto const session = sopt::mpi::init(argc, argv);
52 #endif
53 
54  const auto [mop_algo, wop_algo, using_mpi] = selectOperators(params);
55 
56  sopt::logging::set_level(params.logging());
57  purify::logging::set_level(params.logging());
58 
59  // Read or generate input data
60  auto [uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks] =
61  getInputData(params, mop_algo, wop_algo, using_mpi);
62 
63  // create measurement operator
64  auto measurements_transform =
65  createMeasurementOperator(params, mop_algo, wop_algo, using_mpi, image_index, w_stacks,
66  uv_data, measurement_op_eigen_vector);
67 
68  // create wavelet operator
69  const waveletInfo wavelets = createWaveletOperator(params, wop_algo);
70 
71  PURIFY_LOW_LOG("Value of operator norm is {}", measurements_transform->norm());
72  t_real const flux_scale = 1.;
73  uv_data.vis = uv_data.vis.array() * uv_data.weights.array() / flux_scale;
74 
75  // Save some things before applying the algorithm
76  // the config yaml file - this also generates the output directory and the timestamp
78 
79  // Creating header for saving output images during iterations
80  const auto [update_header_sol, update_header_res, def_header] = genHeaders(params, uv_data);
81 
82  // the eigenvector
83  saveMeasurementEigenVector(params, measurement_op_eigen_vector);
84 
85  // the psf
86  t_real beam_units = 1.0;
87  if (params.mpiAlgorithm() != factory::algo_distribution::serial) {
88 #ifdef PURIFY_MPI
89  auto const world = sopt::mpi::Communicator::World();
90  beam_units = world.all_sum_all(uv_data.size()) / flux_scale / flux_scale;
91 #else
92  throw std::runtime_error("Compile with MPI if you want to use MPI algorithm");
93 #endif
94  } else {
95  beam_units = uv_data.size() / flux_scale / flux_scale;
96  }
97 
98  savePSF(params, def_header, measurements_transform, uv_data, flux_scale, sigma, beam_units);
99 
100  // the dirty image
101  saveDirtyImage(params, def_header, measurements_transform, uv_data, beam_units);
102 
103  // Create algorithm
104  std::shared_ptr<sopt::algorithm::ImagingProximalADMM<t_complex>> padmm;
105  std::shared_ptr<sopt::algorithm::ImagingForwardBackward<t_complex>> fb;
106  std::shared_ptr<sopt::algorithm::ImagingPrimalDual<t_complex>> primaldual;
107  if (params.algorithm() == "padmm")
108  padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
109  params.mpiAlgorithm(), measurements_transform, wavelets.transform, uv_data,
110  sigma * params.epsilonScaling() / flux_scale, params.height(), params.width(),
111  wavelets.sara_size, params.iterations(), params.realValueConstraint(),
112  params.positiveValueConstraint(),
113  (params.wavelet_basis().size() < 2) and (not params.realValueConstraint()) and
114  (not params.positiveValueConstraint()),
115  params.relVarianceConvergence(), params.dualFBVarianceConvergence(), 50,
116  params.epsilonConvergenceScaling());
117  if (params.algorithm() == "fb") {
118  std::shared_ptr<DifferentiableFunc<t_complex>> f;
119  if (params.diffFuncType() == diff_func_type::L2Norm_with_CRR) {
120 #ifdef PURIFY_ONNXRT
121  f = std::make_shared<sopt::ONNXDifferentiableFunc<t_complex>>(
122  params.CRR_function_model_path(), params.CRR_gradient_model_path(), sigma,
123  params.CRR_mu(), params.CRR_lambda(), *measurements_transform);
124 #else
125  throw std::runtime_error("CRR approach cannot be used with ONNXRT off");
126 #endif
127  }
128 
129  fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
130  params.mpiAlgorithm(), measurements_transform, wavelets.transform, uv_data,
131  sigma * params.epsilonScaling() / flux_scale,
132  params.stepsize() * std::pow(sigma * params.epsilonScaling() / flux_scale, 2),
133  params.regularisation_parameter(), params.height(), params.width(), wavelets.sara_size,
134  params.iterations(), params.realValueConstraint(), params.positiveValueConstraint(),
135  (params.wavelet_basis().size() < 2) and (not params.realValueConstraint()) and
136  (not params.positiveValueConstraint()),
137  params.relVarianceConvergence(), params.dualFBVarianceConvergence(), 50,
138  params.model_path(), params.nondiffFuncType(), f);
139  }
140  if (params.algorithm() == "primaldual")
141  primaldual = factory::primaldual_factory<sopt::algorithm::ImagingPrimalDual<t_complex>>(
142  params.mpiAlgorithm(), measurements_transform, wavelets.transform, uv_data,
143  sigma * params.epsilonScaling() / flux_scale, params.height(), params.width(),
144  wavelets.sara_size, params.iterations(), params.realValueConstraint(),
145  params.positiveValueConstraint(), params.relVarianceConvergence(),
146  params.epsilonConvergenceScaling());
147  // Add primal dual preconditioning
148  if (params.algorithm() == "primaldual" and params.precondition_iters() > 0) {
150  "Using visibility sampling density to precondtion primal dual with {} "
151  "subiterations",
152  params.precondition_iters());
153  primaldual->precondition_iters(params.precondition_iters());
154 #ifdef PURIFY_MPI
155  if (using_mpi) {
156  const auto world = sopt::mpi::Communicator::World();
157  primaldual->precondition_weights(widefield::sample_density_weights(
158  uv_data.u, uv_data.v, params.cellsizex(), params.cellsizey(), params.width(),
159  params.height(), params.oversampling(), 0.5, world));
160  } else
161 #endif
162  primaldual->precondition_weights(widefield::sample_density_weights(
163  uv_data.u, uv_data.v, params.cellsizex(), params.cellsizey(), params.width(),
164  params.height(), params.oversampling(), 0.5));
165  }
166 
167  if (params.algorithm() == "padmm") {
168  const std::weak_ptr<sopt::algorithm::ImagingProximalADMM<t_complex>> algo_weak(padmm);
169  // Adding step size update to algorithm
170  factory::add_updater<t_complex, sopt::algorithm::ImagingProximalADMM<t_complex>>(
171  algo_weak, 1e-3, params.update_tolerance(), params.update_iters(), update_header_sol,
172  update_header_res, params.height(), params.width(), wavelets.sara_size, using_mpi,
173  beam_units);
174  }
175  if (params.algorithm() == "primaldual") {
176  const std::weak_ptr<sopt::algorithm::ImagingPrimalDual<t_complex>> algo_weak(primaldual);
177  // Adding step size update to algorithm
178  factory::add_updater<t_complex, sopt::algorithm::ImagingPrimalDual<t_complex>>(
179  algo_weak, 1e-3, params.update_tolerance(), params.update_iters(), update_header_sol,
180  update_header_res, params.height(), params.width(), wavelets.sara_size, using_mpi,
181  beam_units);
182  }
183  if (params.algorithm() == "fb") {
184  const std::weak_ptr<sopt::algorithm::ImagingForwardBackward<t_complex>> algo_weak(fb);
185  // Adding step size update to algorithm
186  factory::add_updater<t_complex, sopt::algorithm::ImagingForwardBackward<t_complex>>(
187  algo_weak, 0, params.update_tolerance(), 0, update_header_sol, update_header_res,
188  params.height(), params.width(), wavelets.sara_size, using_mpi, beam_units);
189  }
190 
191  PURIFY_HIGH_LOG("Starting sopt!");
192  Image<t_real> image;
193  Image<t_real> residual_image;
194  pfitsio::header_params purified_header = def_header;
195  purified_header.fits_name = params.output_path() + "/purified.fits";
196  const Vector<t_complex> estimate_image =
197  (params.warm_start() != "")
198  ? Vector<t_complex>::Map(pfitsio::read2d(params.warm_start()).data(),
199  params.height() * params.width())
200  .eval()
201  : Vector<t_complex>::Zero(params.height() * params.width()).eval();
202  const Vector<t_complex> estimate_res =
203  (*measurements_transform * estimate_image).eval() - uv_data.vis;
204  if (params.algorithm() == "padmm") {
205  // Apply algorithm
206  auto const diagnostic = (*padmm)(std::make_tuple(estimate_image.eval(), estimate_res.eval()));
207 
208  // Save the rest of the output
209  image = Image<t_complex>::Map(diagnostic.x.data(), params.height(), params.width()).real();
210  const Vector<t_complex> residuals =
211  measurements_transform->adjoint() * (diagnostic.residual / beam_units);
212  residual_image =
213  Image<t_complex>::Map(residuals.data(), params.height(), params.width()).real();
214  purified_header.hasconverged = diagnostic.good;
215  purified_header.niters = diagnostic.niters;
216  }
217  if (params.algorithm() == "fb") {
218  // Apply algorithm
219  auto const diagnostic = (*fb)(std::make_tuple(estimate_image.eval(), estimate_res.eval()));
220 
221  // Save the rest of the output
222  // the clean image
223  image = Image<t_complex>::Map(diagnostic.x.data(), params.height(), params.width()).real();
224  const Vector<t_complex> residuals =
225  measurements_transform->adjoint() * (diagnostic.residual / beam_units);
226  residual_image =
227  Image<t_complex>::Map(residuals.data(), params.height(), params.width()).real();
228  purified_header.hasconverged = diagnostic.good;
229  purified_header.niters = diagnostic.niters;
230  }
231  if (params.algorithm() == "primaldual") {
232  // Apply algorithm
233  auto const diagnostic =
234  (*primaldual)(std::make_tuple(estimate_image.eval(), estimate_res.eval()));
235 
236  // Save the rest of the output
237  image = Image<t_complex>::Map(diagnostic.x.data(), params.height(), params.width()).real();
238  const Vector<t_complex> residuals =
239  measurements_transform->adjoint() * (diagnostic.residual / beam_units);
240  residual_image =
241  Image<t_complex>::Map(residuals.data(), params.height(), params.width()).real();
242  purified_header.hasconverged = diagnostic.good;
243  purified_header.niters = diagnostic.niters;
244  }
245  if (params.mpiAlgorithm() != factory::algo_distribution::serial) {
246 #ifdef PURIFY_MPI
247  auto const world = sopt::mpi::Communicator::World();
248  if (world.is_root())
249 #else
250  throw std::runtime_error("Compile with MPI if you want to use MPI algorithm");
251 #endif
252  pfitsio::write2d(image, purified_header, true);
253  } else {
254  pfitsio::write2d(image, purified_header, true);
255  }
256  // the residuals
257  pfitsio::header_params residuals_header = purified_header;
258  residuals_header.fits_name = params.output_path() + "/residuals.fits";
259  residuals_header.pix_units = "Jy/Beam";
260  if (params.mpiAlgorithm() != factory::algo_distribution::serial) {
261 #ifdef PURIFY_MPI
262  auto const world = sopt::mpi::Communicator::World();
263  if (world.is_root())
264 #else
265  throw std::runtime_error("Compile with MPI if you want to use MPI algorithm");
266 #endif
267  pfitsio::write2d(residual_image, residuals_header, true);
268  } else {
269  pfitsio::write2d(residual_image, residuals_header, true);
270  }
271 
272  return 0;
273 }
std::string output_path() const
Definition: yaml-parser.h:152
#define PURIFY_LOW_LOG(...)
Low priority message.
Definition: logging.h:207
#define PURIFY_HIGH_LOG(...)
High priority message.
Definition: logging.h:203
std::enable_if< std::is_same< Algorithm, sopt::algorithm::ImagingPrimalDual< t_complex > >::value, std::shared_ptr< Algorithm > >::type primaldual_factory(const algo_distribution dist, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &measurements, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &wavelets, const utilities::vis_params &uv_data, const t_real sigma, const t_uint imsizey, const t_uint imsizex, const t_uint sara_size, const t_uint max_iterations=500, const bool real_constraint=true, const bool positive_constraint=true, const t_real relative_variation=1e-3, const t_real residual_tolerance_scaling=1)
Returns a shared pointer to a primal-dual algorithm object.
void set_level(const std::string &level)
Method to set the logging level of the default Log object.
Definition: logging.h:137
void write2d(const Image< t_real > &eigen_image, const pfitsio::header_params &header, const bool &overwrite)
Writes an image to a FITS file.
Definition: pfitsio.cc:30
Image< t_complex > read2d(const std::string &fits_name)
Reads an image from a FITS file.
Definition: pfitsio.cc:109
Vector< t_complex > sample_density_weights(const Vector< t_real > &u, const Vector< t_real > &v, const t_real cellx, const t_real celly, const t_uint imsizex, const t_uint imsizey, const t_real oversample_ratio, const t_real scale)
Creates sample-density weights for a given field of view (uniform weighting).
std::string version()
Returns library version.
Definition: config.in.h:40
void padmm(const std::string &name, const Image< t_complex > &M31, const std::string &kernel, const t_int J, const utilities::vis_params &uv_data, const t_real sigma, const std::tuple< bool, t_real > &w_term)
void saveMeasurementEigenVector(const YamlParser &params, const Vector< t_complex > &measurement_op_eigen_vector)
Definition: setup_utils.cc:380
void saveDirtyImage(const YamlParser &params, const pfitsio::header_params &def_header, const std::shared_ptr< sopt::LinearTransform< Vector< t_complex >>> &measurements_transform, const utilities::vis_params &uv_data, const t_real beam_units)
Definition: setup_utils.cc:441
void savePSF(const YamlParser &params, const pfitsio::header_params &def_header, const std::shared_ptr< sopt::LinearTransform< Vector< t_complex >>> &measurements_transform, const utilities::vis_params &uv_data, const t_real flux_scale, const t_real sigma, const t_real beam_units)
Definition: setup_utils.cc:403
std::shared_ptr< sopt::LinearTransform< Vector< t_complex > > > createMeasurementOperator(const YamlParser &params, const factory::distributed_measurement_operator mop_algo, const factory::distributed_wavelet_operator wop_algo, const bool using_mpi, const std::vector< t_int > &image_index, const std::vector< t_real > &w_stacks, const utilities::vis_params &uv_data, Vector< t_complex > &measurement_op_eigen_vector)
Definition: setup_utils.cc:250
waveletInfo createWaveletOperator(YamlParser &params, const factory::distributed_wavelet_operator &wop_algo)
Definition: setup_utils.cc:20
OperatorsInfo selectOperators(YamlParser &params)
Definition: setup_utils.cc:38
void initOutDirectoryWithConfig(YamlParser &params)
Definition: setup_utils.cc:350
Headers genHeaders(const YamlParser &params, const utilities::vis_params &uv_data)
Definition: setup_utils.cc:364
inputData getInputData(const YamlParser &params, const factory::distributed_measurement_operator mop_algo, const factory::distributed_wavelet_operator wop_algo, const bool using_mpi)
Definition: setup_utils.cc:66
std::shared_ptr< const sopt::LinearTransform< Eigen::VectorXcd > > transform
Definition: setup_utils.h:17
t_uint sara_size
Definition: setup_utils.h:18

References createMeasurementOperator(), createWaveletOperator(), purify::pfitsio::header_params::fits_name, genHeaders(), getInputData(), purify::pfitsio::header_params::hasconverged, initOutDirectoryWithConfig(), purify::L2Norm_with_CRR, purify::pfitsio::header_params::niters, purify::YamlParser::output_path(), padmm(), purify::pfitsio::header_params::pix_units, purify::factory::primaldual_factory(), PURIFY_HIGH_LOG, PURIFY_LOW_LOG, purify::pfitsio::read2d(), purify::widefield::sample_density_weights(), waveletInfo::sara_size, saveDirtyImage(), saveMeasurementEigenVector(), savePSF(), selectOperators(), purify::factory::serial, purify::logging::set_level(), waveletInfo::transform, purify::version(), and purify::pfitsio::write2d().