33   std::srand(
static_cast<t_uint
>(std::time(0)));
 
   34   std::mt19937 mersnne(std::time(0));
 
   42   std::string file_path = argv[1];
 
   45     throw std::runtime_error(
 
   47         " but the configuration file expects version " + params.version() +
 
   48         ". Please updated the config version manually to be compatable with the new version.");
 
   51   auto const session = sopt::mpi::init(argc, argv);
 
   60   auto [uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks] =
 
   64   auto measurements_transform =
 
   66                                 uv_data, measurement_op_eigen_vector);
 
   71   PURIFY_LOW_LOG(
"Value of operator norm is {}", measurements_transform->norm());
 
   72   t_real 
const flux_scale = 1.;
 
   73   uv_data.vis = uv_data.vis.array() * uv_data.weights.array() / flux_scale;
 
   80   const auto [update_header_sol, update_header_res, def_header] = 
genHeaders(params, uv_data);
 
   86   t_real beam_units = 1.0;
 
   87   if (params.mpiAlgorithm() != factory::algo_distribution::serial) {
 
   89     auto const world = sopt::mpi::Communicator::World();
 
   90     beam_units = world.all_sum_all(uv_data.size()) / flux_scale / flux_scale;
 
   92     throw std::runtime_error(
"Compile with MPI if you want to use MPI algorithm");
 
   95     beam_units = uv_data.size() / flux_scale / flux_scale;
 
   98   savePSF(params, def_header, measurements_transform, uv_data, flux_scale, sigma, beam_units);
 
  101   saveDirtyImage(params, def_header, measurements_transform, uv_data, beam_units);
 
  104   std::shared_ptr<sopt::algorithm::ImagingProximalADMM<t_complex>> 
padmm;
 
  105   std::shared_ptr<sopt::algorithm::ImagingForwardBackward<t_complex>> fb;
 
  106   std::shared_ptr<sopt::algorithm::ImagingPrimalDual<t_complex>> primaldual;
 
  107   if (params.algorithm() == 
"padmm")
 
  108     padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
 
  109         params.mpiAlgorithm(), measurements_transform, wavelets.
transform, uv_data,
 
  110         sigma * params.epsilonScaling() / flux_scale, params.height(), params.width(),
 
  111         wavelets.
sara_size, params.iterations(), params.realValueConstraint(),
 
  112         params.positiveValueConstraint(),
 
  113         (params.wavelet_basis().size() < 2) and (not params.realValueConstraint()) and
 
  114             (not params.positiveValueConstraint()),
 
  115         params.relVarianceConvergence(), params.dualFBVarianceConvergence(), 50,
 
  116         params.epsilonConvergenceScaling());
 
  117   if (params.algorithm() == 
"fb") {
 
  118     std::shared_ptr<DifferentiableFunc<t_complex>> f;
 
  119     if (params.diffFuncType() == diff_func_type::L2Norm_with_CRR) {
 
  121       f = std::make_shared<sopt::ONNXDifferentiableFunc<t_complex>>(
 
  122           params.CRR_function_model_path(), params.CRR_gradient_model_path(), sigma,
 
  123           params.CRR_mu(), params.CRR_lambda(), *measurements_transform);
 
  125       throw std::runtime_error(
"CRR approach cannot be used with ONNXRT off");
 
  129     fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
 
  130         params.mpiAlgorithm(), measurements_transform, wavelets.
transform, uv_data,
 
  131         sigma * params.epsilonScaling() / flux_scale,
 
  132         params.stepsize() * std::pow(sigma * params.epsilonScaling() / flux_scale, 2),
 
  133         params.regularisation_parameter(), params.height(), params.width(), wavelets.
sara_size,
 
  134         params.iterations(), params.realValueConstraint(), params.positiveValueConstraint(),
 
  135         (params.wavelet_basis().size() < 2) and (not params.realValueConstraint()) and
 
  136             (not params.positiveValueConstraint()),
 
  137         params.relVarianceConvergence(), params.dualFBVarianceConvergence(), 50,
 
  138         params.model_path(), params.nondiffFuncType(), f);
 
  140   if (params.algorithm() == 
"primaldual")
 
  142         params.mpiAlgorithm(), measurements_transform, wavelets.
transform, uv_data,
 
  143         sigma * params.epsilonScaling() / flux_scale, params.height(), params.width(),
 
  144         wavelets.
sara_size, params.iterations(), params.realValueConstraint(),
 
  145         params.positiveValueConstraint(), params.relVarianceConvergence(),
 
  146         params.epsilonConvergenceScaling());
 
  148   if (params.algorithm() == 
"primaldual" and params.precondition_iters() > 0) {
 
  150         "Using visibility sampling density to precondtion primal dual with {} " 
  152         params.precondition_iters());
 
  153     primaldual->precondition_iters(params.precondition_iters());
 
  156       const auto world = sopt::mpi::Communicator::World();
 
  158           uv_data.u, uv_data.v, params.cellsizex(), params.cellsizey(), params.width(),
 
  159           params.height(), params.oversampling(), 0.5, world));
 
  163           uv_data.u, uv_data.v, params.cellsizex(), params.cellsizey(), params.width(),
 
  164           params.height(), params.oversampling(), 0.5));
 
  167   if (params.algorithm() == 
"padmm") {
 
  168     const std::weak_ptr<sopt::algorithm::ImagingProximalADMM<t_complex>> algo_weak(
padmm);
 
  170     factory::add_updater<t_complex, sopt::algorithm::ImagingProximalADMM<t_complex>>(
 
  171         algo_weak, 1e-3, params.update_tolerance(), params.update_iters(), update_header_sol,
 
  172         update_header_res, params.height(), params.width(), wavelets.
sara_size, using_mpi,
 
  175   if (params.algorithm() == 
"primaldual") {
 
  176     const std::weak_ptr<sopt::algorithm::ImagingPrimalDual<t_complex>> algo_weak(primaldual);
 
  178     factory::add_updater<t_complex, sopt::algorithm::ImagingPrimalDual<t_complex>>(
 
  179         algo_weak, 1e-3, params.update_tolerance(), params.update_iters(), update_header_sol,
 
  180         update_header_res, params.height(), params.width(), wavelets.
sara_size, using_mpi,
 
  183   if (params.algorithm() == 
"fb") {
 
  184     const std::weak_ptr<sopt::algorithm::ImagingForwardBackward<t_complex>> algo_weak(fb);
 
  186     factory::add_updater<t_complex, sopt::algorithm::ImagingForwardBackward<t_complex>>(
 
  187         algo_weak, 0, params.update_tolerance(), 0, update_header_sol, update_header_res,
 
  188         params.height(), params.width(), wavelets.
sara_size, using_mpi, beam_units);
 
  193   Image<t_real> residual_image;
 
  196   const Vector<t_complex> estimate_image =
 
  197       (params.warm_start() != 
"")
 
  199                                    params.height() * params.width())
 
  201           : Vector<t_complex>::Zero(params.height() * params.width()).eval();
 
  202   const Vector<t_complex> estimate_res =
 
  203       (*measurements_transform * estimate_image).eval() - uv_data.vis;
 
  204   if (params.algorithm() == 
"padmm") {
 
  206     auto const diagnostic = (*padmm)(std::make_tuple(estimate_image.eval(), estimate_res.eval()));
 
  209     image = Image<t_complex>::Map(diagnostic.x.data(), params.height(), params.width()).real();
 
  210     const Vector<t_complex> residuals =
 
  211         measurements_transform->adjoint() * (diagnostic.residual / beam_units);
 
  213         Image<t_complex>::Map(residuals.data(), params.height(), params.width()).real();
 
  215     purified_header.
niters = diagnostic.niters;
 
  217   if (params.algorithm() == 
"fb") {
 
  219     auto const diagnostic = (*fb)(std::make_tuple(estimate_image.eval(), estimate_res.eval()));
 
  223     image = Image<t_complex>::Map(diagnostic.x.data(), params.height(), params.width()).real();
 
  224     const Vector<t_complex> residuals =
 
  225         measurements_transform->adjoint() * (diagnostic.residual / beam_units);
 
  227         Image<t_complex>::Map(residuals.data(), params.height(), params.width()).real();
 
  229     purified_header.
niters = diagnostic.niters;
 
  231   if (params.algorithm() == 
"primaldual") {
 
  233     auto const diagnostic =
 
  234         (*primaldual)(std::make_tuple(estimate_image.eval(), estimate_res.eval()));
 
  237     image = Image<t_complex>::Map(diagnostic.x.data(), params.height(), params.width()).real();
 
  238     const Vector<t_complex> residuals =
 
  239         measurements_transform->adjoint() * (diagnostic.residual / beam_units);
 
  241         Image<t_complex>::Map(residuals.data(), params.height(), params.width()).real();
 
  243     purified_header.
niters = diagnostic.niters;
 
  245   if (params.mpiAlgorithm() != factory::algo_distribution::serial) {
 
  247     auto const world = sopt::mpi::Communicator::World();
 
  250     throw std::runtime_error(
"Compile with MPI if you want to use MPI algorithm");
 
  260   if (params.mpiAlgorithm() != factory::algo_distribution::serial) {
 
  262     auto const world = sopt::mpi::Communicator::World();
 
  265     throw std::runtime_error(
"Compile with MPI if you want to use MPI algorithm");
 
std::string output_path() const
 
#define PURIFY_LOW_LOG(...)
Low priority message.
 
#define PURIFY_HIGH_LOG(...)
High priority message.
 
std::enable_if< std::is_same< Algorithm, sopt::algorithm::ImagingPrimalDual< t_complex > >::value, std::shared_ptr< Algorithm > >::type primaldual_factory(const algo_distribution dist, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &measurements, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &wavelets, const utilities::vis_params &uv_data, const t_real sigma, const t_uint imsizey, const t_uint imsizex, const t_uint sara_size, const t_uint max_iterations=500, const bool real_constraint=true, const bool positive_constraint=true, const t_real relative_variation=1e-3, const t_real residual_tolerance_scaling=1)
Returns a shared pointer to a primal-dual algorithm object.
 
void set_level(const std::string &level)
Method to set the logging level of the default Log object.
 
void write2d(const Image< t_real > &eigen_image, const pfitsio::header_params &header, const bool &overwrite)
Writes an image to a FITS file.
 
Image< t_complex > read2d(const std::string &fits_name)
Reads an image from a FITS file.
 
Vector< t_complex > sample_density_weights(const Vector< t_real > &u, const Vector< t_real > &v, const t_real cellx, const t_real celly, const t_uint imsizex, const t_uint imsizey, const t_real oversample_ratio, const t_real scale)
Creates sample density weights for a given field of view (uniform weighting).
 
std::string version()
Returns library version.
 
void padmm(const std::string &name, const Image< t_complex > &M31, const std::string &kernel, const t_int J, const utilities::vis_params &uv_data, const t_real sigma, const std::tuple< bool, t_real > &w_term)
 
void saveMeasurementEigenVector(const YamlParser &params, const Vector< t_complex > &measurement_op_eigen_vector)
 
void saveDirtyImage(const YamlParser &params, const pfitsio::header_params &def_header, const std::shared_ptr< sopt::LinearTransform< Vector< t_complex >>> &measurements_transform, const utilities::vis_params &uv_data, const t_real beam_units)
 
void savePSF(const YamlParser &params, const pfitsio::header_params &def_header, const std::shared_ptr< sopt::LinearTransform< Vector< t_complex >>> &measurements_transform, const utilities::vis_params &uv_data, const t_real flux_scale, const t_real sigma, const t_real beam_units)
 
std::shared_ptr< sopt::LinearTransform< Vector< t_complex > > > createMeasurementOperator(const YamlParser &params, const factory::distributed_measurement_operator mop_algo, const factory::distributed_wavelet_operator wop_algo, const bool using_mpi, const std::vector< t_int > &image_index, const std::vector< t_real > &w_stacks, const utilities::vis_params &uv_data, Vector< t_complex > &measurement_op_eigen_vector)
 
waveletInfo createWaveletOperator(YamlParser &params, const factory::distributed_wavelet_operator &wop_algo)
 
OperatorsInfo selectOperators(YamlParser &params)
 
void initOutDirectoryWithConfig(YamlParser &params)
 
Headers genHeaders(const YamlParser &params, const utilities::vis_params &uv_data)
 
inputData getInputData(const YamlParser &params, const factory::distributed_measurement_operator mop_algo, const factory::distributed_wavelet_operator wop_algo, const bool using_mpi)
 
std::shared_ptr< const sopt::LinearTransform< Eigen::VectorXcd > > transform