1 #include "purify/config.h"
20 #include <sopt/imaging_padmm.h>
21 #include <sopt/positive_quadrant.h>
22 #include <sopt/power_method.h>
23 #include <sopt/relative_variation.h>
24 #include <sopt/reweighted.h>
27 #include <sopt/onnx_differentiable_func.h>
32 int main(
int argc,
const char **argv) {
33 std::srand(
static_cast<t_uint
>(std::time(0)));
34 std::mt19937 mersnne(std::time(0));
42 std::string file_path = argv[1];
45 throw std::runtime_error(
47 " but the configuration file expects version " + params.version() +
48 ". Please updated the config version manually to be compatable with the new version.");
51 auto const session = sopt::mpi::init(argc, argv);
60 auto [uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks] =
64 auto measurements_transform =
66 uv_data, measurement_op_eigen_vector);
71 PURIFY_LOW_LOG(
"Value of operator norm is {}", measurements_transform->norm());
72 t_real
const flux_scale = 1.;
73 uv_data.vis = uv_data.vis.array() * uv_data.weights.array() / flux_scale;
80 const auto [update_header_sol, update_header_res, def_header] =
genHeaders(params, uv_data);
86 t_real beam_units = 1.0;
89 auto const world = sopt::mpi::Communicator::World();
90 beam_units = world.all_sum_all(uv_data.size()) / flux_scale / flux_scale;
92 throw std::runtime_error(
"Compile with MPI if you want to use MPI algorithm");
95 beam_units = uv_data.size() / flux_scale / flux_scale;
98 savePSF(params, def_header, measurements_transform, uv_data, flux_scale, sigma, beam_units);
101 saveDirtyImage(params, def_header, measurements_transform, uv_data, beam_units);
104 std::shared_ptr<sopt::algorithm::ImagingProximalADMM<t_complex>>
padmm;
105 std::shared_ptr<sopt::algorithm::ImagingForwardBackward<t_complex>> fb;
106 std::shared_ptr<sopt::algorithm::ImagingPrimalDual<t_complex>> primaldual;
107 if (params.algorithm() ==
"padmm")
108 padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
109 params.mpiAlgorithm(), measurements_transform, wavelets.
transform, uv_data,
110 sigma * params.epsilonScaling() / flux_scale, params.height(), params.width(),
111 wavelets.
sara_size, params.iterations(), params.realValueConstraint(),
112 params.positiveValueConstraint(),
113 (params.wavelet_basis().size() < 2) and (not params.realValueConstraint()) and
114 (not params.positiveValueConstraint()),
115 params.relVarianceConvergence(), params.dualFBVarianceConvergence(), 50,
116 params.epsilonConvergenceScaling());
117 if (params.algorithm() ==
"fb") {
118 std::shared_ptr<DifferentiableFunc<t_complex>> f;
121 f = std::make_shared<sopt::ONNXDifferentiableFunc<t_complex>>(
122 params.CRR_function_model_path(), params.CRR_gradient_model_path(), sigma,
123 params.CRR_mu(), params.CRR_lambda(), *measurements_transform);
125 throw std::runtime_error(
"CRR approach cannot be used with ONNXRT off");
129 fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
130 params.mpiAlgorithm(), measurements_transform, wavelets.
transform, uv_data,
131 sigma * params.epsilonScaling() / flux_scale,
132 params.stepsize() * std::pow(sigma * params.epsilonScaling() / flux_scale, 2),
133 params.regularisation_parameter(), params.height(), params.width(), wavelets.
sara_size,
134 params.iterations(), params.realValueConstraint(), params.positiveValueConstraint(),
135 (params.wavelet_basis().size() < 2) and (not params.realValueConstraint()) and
136 (not params.positiveValueConstraint()),
137 params.relVarianceConvergence(), params.dualFBVarianceConvergence(), 50,
138 params.model_path(), params.nondiffFuncType(), f);
140 if (params.algorithm() ==
"primaldual")
142 params.mpiAlgorithm(), measurements_transform, wavelets.
transform, uv_data,
143 sigma * params.epsilonScaling() / flux_scale, params.height(), params.width(),
144 wavelets.
sara_size, params.iterations(), params.realValueConstraint(),
145 params.positiveValueConstraint(), params.relVarianceConvergence(),
146 params.epsilonConvergenceScaling());
148 if (params.algorithm() ==
"primaldual" and params.precondition_iters() > 0) {
150 "Using visibility sampling density to precondtion primal dual with {} "
152 params.precondition_iters());
153 primaldual->precondition_iters(params.precondition_iters());
156 const auto world = sopt::mpi::Communicator::World();
158 uv_data.u, uv_data.v, params.cellsizex(), params.cellsizey(), params.width(),
159 params.height(), params.oversampling(), 0.5, world));
163 uv_data.u, uv_data.v, params.cellsizex(), params.cellsizey(), params.width(),
164 params.height(), params.oversampling(), 0.5));
167 if (params.algorithm() ==
"padmm") {
168 const std::weak_ptr<sopt::algorithm::ImagingProximalADMM<t_complex>> algo_weak(
padmm);
170 factory::add_updater<t_complex, sopt::algorithm::ImagingProximalADMM<t_complex>>(
171 algo_weak, 1e-3, params.update_tolerance(), params.update_iters(), update_header_sol,
172 update_header_res, params.height(), params.width(), wavelets.
sara_size, using_mpi,
175 if (params.algorithm() ==
"primaldual") {
176 const std::weak_ptr<sopt::algorithm::ImagingPrimalDual<t_complex>> algo_weak(primaldual);
178 factory::add_updater<t_complex, sopt::algorithm::ImagingPrimalDual<t_complex>>(
179 algo_weak, 1e-3, params.update_tolerance(), params.update_iters(), update_header_sol,
180 update_header_res, params.height(), params.width(), wavelets.
sara_size, using_mpi,
183 if (params.algorithm() ==
"fb") {
184 const std::weak_ptr<sopt::algorithm::ImagingForwardBackward<t_complex>> algo_weak(fb);
186 factory::add_updater<t_complex, sopt::algorithm::ImagingForwardBackward<t_complex>>(
187 algo_weak, 0, params.update_tolerance(), 0, update_header_sol, update_header_res,
188 params.height(), params.width(), wavelets.
sara_size, using_mpi, beam_units);
193 Image<t_real> residual_image;
196 const Vector<t_complex> estimate_image =
197 (params.warm_start() !=
"")
199 params.height() * params.width())
201 : Vector<t_complex>::Zero(params.height() * params.width()).eval();
202 const Vector<t_complex> estimate_res =
203 (*measurements_transform * estimate_image).eval() - uv_data.vis;
204 if (params.algorithm() ==
"padmm") {
206 auto const diagnostic = (*padmm)(std::make_tuple(estimate_image.eval(), estimate_res.eval()));
209 image = Image<t_complex>::Map(diagnostic.x.data(), params.height(), params.width()).real();
210 const Vector<t_complex> residuals =
211 measurements_transform->adjoint() * (diagnostic.residual / beam_units);
213 Image<t_complex>::Map(residuals.data(), params.height(), params.width()).real();
215 purified_header.
niters = diagnostic.niters;
217 if (params.algorithm() ==
"fb") {
219 auto const diagnostic = (*fb)(std::make_tuple(estimate_image.eval(), estimate_res.eval()));
223 image = Image<t_complex>::Map(diagnostic.x.data(), params.height(), params.width()).real();
224 const Vector<t_complex> residuals =
225 measurements_transform->adjoint() * (diagnostic.residual / beam_units);
227 Image<t_complex>::Map(residuals.data(), params.height(), params.width()).real();
229 purified_header.
niters = diagnostic.niters;
231 if (params.algorithm() ==
"primaldual") {
233 auto const diagnostic =
234 (*primaldual)(std::make_tuple(estimate_image.eval(), estimate_res.eval()));
237 image = Image<t_complex>::Map(diagnostic.x.data(), params.height(), params.width()).real();
238 const Vector<t_complex> residuals =
239 measurements_transform->adjoint() * (diagnostic.residual / beam_units);
241 Image<t_complex>::Map(residuals.data(), params.height(), params.width()).real();
243 purified_header.
niters = diagnostic.niters;
247 auto const world = sopt::mpi::Communicator::World();
250 throw std::runtime_error(
"Compile with MPI if you want to use MPI algorithm");
262 auto const world = sopt::mpi::Communicator::World();
265 throw std::runtime_error(
"Compile with MPI if you want to use MPI algorithm");
int main(int argc, char const **argv)
std::string output_path() const
#define PURIFY_LOW_LOG(...)
Low priority message.
#define PURIFY_HIGH_LOG(...)
High priority message.
std::enable_if< std::is_same< Algorithm, sopt::algorithm::ImagingPrimalDual< t_complex > >::value, std::shared_ptr< Algorithm > >::type primaldual_factory(const algo_distribution dist, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &measurements, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &wavelets, const utilities::vis_params &uv_data, const t_real sigma, const t_uint imsizey, const t_uint imsizex, const t_uint sara_size, const t_uint max_iterations=500, const bool real_constraint=true, const bool positive_constraint=true, const t_real relative_variation=1e-3, const t_real residual_tolerance_scaling=1)
return shared pointer to primal dual object
void set_level(const std::string &level)
Method to set the logging level of the default Log object.
void write2d(const Image< t_real > &eigen_image, const pfitsio::header_params &header, const bool &overwrite)
Write image to fits file.
Image< t_complex > read2d(const std::string &fits_name)
Read image from fits file.
Vector< t_complex > sample_density_weights(const Vector< t_real > &u, const Vector< t_real > &v, const t_real cellx, const t_real celly, const t_uint imsizex, const t_uint imsizey, const t_real oversample_ratio, const t_real scale)
create sample density weights for a given field of view, uniform weighting
std::string version()
Returns library version.
void padmm(const std::string &name, const Image< t_complex > &M31, const std::string &kernel, const t_int J, const utilities::vis_params &uv_data, const t_real sigma, const std::tuple< bool, t_real > &w_term)
void saveMeasurementEigenVector(const YamlParser ¶ms, const Vector< t_complex > &measurement_op_eigen_vector)
void saveDirtyImage(const YamlParser ¶ms, const pfitsio::header_params &def_header, const std::shared_ptr< sopt::LinearTransform< Vector< t_complex >>> &measurements_transform, const utilities::vis_params &uv_data, const t_real beam_units)
void savePSF(const YamlParser ¶ms, const pfitsio::header_params &def_header, const std::shared_ptr< sopt::LinearTransform< Vector< t_complex >>> &measurements_transform, const utilities::vis_params &uv_data, const t_real flux_scale, const t_real sigma, const t_real beam_units)
std::shared_ptr< sopt::LinearTransform< Vector< t_complex > > > createMeasurementOperator(const YamlParser ¶ms, const factory::distributed_measurement_operator mop_algo, const factory::distributed_wavelet_operator wop_algo, const bool using_mpi, const std::vector< t_int > &image_index, const std::vector< t_real > &w_stacks, const utilities::vis_params &uv_data, Vector< t_complex > &measurement_op_eigen_vector)
waveletInfo createWaveletOperator(YamlParser ¶ms, const factory::distributed_wavelet_operator &wop_algo)
OperatorsInfo selectOperators(YamlParser ¶ms)
void initOutDirectoryWithConfig(YamlParser ¶ms)
Headers genHeaders(const YamlParser ¶ms, const utilities::vis_params &uv_data)
inputData getInputData(const YamlParser ¶ms, const factory::distributed_measurement_operator mop_algo, const factory::distributed_wavelet_operator wop_algo, const bool using_mpi)
std::shared_ptr< const sopt::LinearTransform< Eigen::VectorXcd > > transform