5 #include <boost/filesystem.hpp>
6 #include <boost/math/special_functions/erf.hpp>
7 #include "purify/directories.h"
15 #include <sopt/imaging_padmm.h>
16 #include <sopt/mpi/communicator.h>
17 #include <sopt/mpi/session.h>
18 #include <sopt/power_method.h>
19 #include <sopt/relative_variation.h>
20 #include <sopt/utilities.h>
21 #include <sopt/wavelets.h>
22 #include <sopt/wavelets/sara.h>
28 #ifndef PURIFY_PADMM_ALGORITHM
29 #define PURIFY_PADMM_ALGORITHM 2
35 Image<t_complex>
const &ground_truth_image, t_uint number_of_vis, t_real snr,
36 const std::tuple<bool, t_real> &w_term) {
41 uv_data.u.size() / ground_truth_image.size());
43 auto const sky_measurements = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
45 uv_data, ground_truth_image.rows(), ground_truth_image.cols(), std::get<1>(w_term),
47 100, 1e-4, Vector<t_complex>::Random(ground_truth_image.size())));
49 uv_data.vis = (*sky_measurements) *
50 Image<t_complex>::Map(ground_truth_image.data(), ground_truth_image.size(), 1);
56 return std::make_tuple(uv_data, sigma);
60 Image<t_complex>
const &ground_truth_image, t_uint number_of_vis, t_real snr,
61 const std::tuple<bool, t_real> &w_term, sopt::mpi::Communicator
const &comm) {
62 if (comm.size() == 1)
return dirty_visibilities(ground_truth_image, number_of_vis, snr, w_term);
65 comm.broadcast(std::get<1>(result));
71 auto const sigma = comm.broadcast<t_real>();
75 std::shared_ptr<sopt::algorithm::ImagingProximalADMM<t_complex>>
padmm_factory(
76 std::shared_ptr<sopt::LinearTransform<Vector<t_complex>>
const>
const &measurements,
77 const sopt::wavelets::SARA &sara,
const Image<t_complex> &ground_truth_image,
79 auto const Psi = sopt::linear_transform<t_complex>(sara, ground_truth_image.rows(),
80 ground_truth_image.cols(), comm);
82 #if PURIFY_PADMM_ALGORITHM == 2
83 auto const epsilon = std::sqrt(
85 #elif PURIFY_PADMM_ALGORITHM == 3 || PURIFY_PADMM_ALGORITHM == 1
88 PURIFY_LOW_LOG(
"SARA Size = {}, Rank = {}", sara.size(), comm.rank());
89 const t_real regulariser_strength =
91 std::make_shared<sopt::LinearTransform<Vector<t_complex>>
const>(Psi),
95 PURIFY_LOW_LOG(
"Regulariser_Strength {}, SARA Size = {}, Rank = {}", regulariser_strength,
96 sara.size(), comm.rank());
100 auto padmm = std::make_shared<sopt::algorithm::ImagingProximalADMM<t_complex>>(uv_data.
vis);
102 .regulariser_strength(comm.all_reduce(regulariser_strength, MPI_MAX))
103 .relative_variation(1e-3)
104 .l2ball_proximal_epsilon(epsilon)
105 #if PURIFY_PADMM_ALGORITHM == 2
107 .l2ball_proximal_communicator(comm)
110 .l1_proximal_adjoint_space_comm(comm)
112 .l1_proximal_tolerance(1e-2)
114 .l1_proximal_itermax(50)
115 .l1_proximal_positivity_constraint(
true)
116 .l1_proximal_real_constraint(
true)
117 .residual_tolerance(epsilon)
118 .lagrange_update_scale(0.9)
121 sopt::ScalarRelativeVariation<t_complex> conv(
padmm->relative_variation(),
122 padmm->relative_variation(),
"Objective function");
123 std::weak_ptr<decltype(
padmm)::element_type>
const padmm_weak(
padmm);
124 padmm->residual_convergence([padmm_weak, conv, comm](
125 Vector<t_complex>
const &x,
126 Vector<t_complex>
const &residual)
mutable ->
bool {
127 auto const padmm = padmm_weak.lock();
128 #if PURIFY_PADMM_ALGORITHM == 2
129 auto const residual_norm = sopt::mpi::l2_norm(residual,
padmm->l2ball_proximal_weights(), comm);
130 auto const result = residual_norm <
padmm->residual_tolerance();
131 #elif PURIFY_PADMM_ALGORITHM == 3 || PURIFY_PADMM_ALGORITHM == 1
132 auto const residual_norm = sopt::l2_norm(residual,
padmm->l2ball_proximal_weights());
134 comm.all_reduce<int8_t>(residual_norm <
padmm->residual_tolerance(), MPI_LAND);
136 SOPT_LOW_LOG(
" - [PADMM] Residuals: {} <? {}", residual_norm,
padmm->residual_tolerance());
140 padmm->objective_convergence([padmm_weak, conv, comm](Vector<t_complex>
const &x,
141 Vector<t_complex>
const &)
mutable ->
bool {
142 auto const padmm = padmm_weak.lock();
143 #if PURIFY_PADMM_ALGORITHM == 2
144 return conv(sopt::mpi::l1_norm(
padmm->Psi().adjoint() * x,
padmm->l1_proximal_weights(), comm));
145 #elif PURIFY_PADMM_ALGORITHM == 3 || PURIFY_PADMM_ALGORITHM == 1
146 return comm.all_reduce<uint8_t>(
147 conv(sopt::l1_norm(
padmm->Psi().adjoint() * x,
padmm->l1_proximal_weights())), MPI_LAND);
154 int main(
int nargs,
char const **args) {
157 auto const session = sopt::mpi::init(nargs, args);
158 auto const world = sopt::mpi::Communicator::World();
160 const t_real FoV = 1;
161 const t_real max_w = 100;
162 const std::string name =
"M31";
163 const t_real snr = 30;
165 const bool w_term =
true;
168 ground_truth_image /= ground_truth_image.array().abs().maxCoeff();
170 const t_real cellsize = FoV / ground_truth_image.cols() * 60. * 60.;
172 t_int
const number_of_pixels = ground_truth_image.size();
173 t_int
const number_of_vis = std::floor(number_of_pixels * 0.1);
177 std::make_tuple(w_term, cellsize), world);
178 #if PURIFY_PADMM_ALGORITHM == 2 || PURIFY_PADMM_ALGORITHM == 3
180 auto const measurements = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
182 world, std::get<0>(data), ground_truth_image.rows(), ground_truth_image.cols(), cellsize,
186 Vector<t_complex>::Random(ground_truth_image.rows() * ground_truth_image.cols())
191 auto const measurements = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
193 world, std::get<0>(data), ground_truth_image.rows(), ground_truth_image.cols(), cellsize,
197 Vector<t_complex>::Random(ground_truth_image.rows() * ground_truth_image.cols())
201 #elif PURIFY_PADMM_ALGORITHM == 1
203 auto const measurements = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
204 measurementoperator::init_degrid_operator_2d_mpi<Vector<t_complex>>(
205 world, std::get<0>(data), ground_truth_image.rows(), ground_truth_image.cols(), cellsize,
209 Vector<t_complex>::Random(ground_truth_image.rows() * ground_truth_image.cols())
214 auto const measurements = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
215 gpu::measurementoperator::init_degrid_operator_2d_mpi(
216 world, std::get<0>(data), ground_truth_image.rows(), ground_truth_image.cols(), cellsize,
220 Vector<t_complex>::Random(ground_truth_image.rows() * ground_truth_image.cols())
225 auto const sara = sopt::wavelets::distribute_sara(
226 sopt::wavelets::SARA{
227 std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
228 std::make_tuple(
"DB3", 3u), std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
229 std::make_tuple(
"DB6", 3u), std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)},
233 auto const padmm =
padmm_factory(measurements, sara, ground_truth_image, std::get<0>(data),
234 std::get<1>(data), world);
236 auto const diagnostic = (*padmm)();
239 assert(diagnostic.x.size() == ground_truth_image.size());
240 assert(world.broadcast(diagnostic.x).isApprox(diagnostic.x));
243 auto const residual_image = (measurements->adjoint() * diagnostic.residual).real();
244 auto const dirty_image = (measurements->adjoint() * std::get<0>(data).vis).real();
245 if (world.is_root()) {
247 #if PURIFY_PADMM_ALGORITHM == 3
248 auto const pb_path = path /
kernel /
"local_epsilon_replicated_grids";
249 #elif PURIFY_PADMM_ALGORITHM == 2
250 auto const pb_path = path /
kernel /
"global_epsilon_replicated_grids";
251 #elif PURIFY_PADMM_ALGORITHM == 1
252 auto const pb_path = path /
kernel /
"local_epsilon_distributed_grids";
254 #error Unknown or unimplemented algorithm
258 pfitsio::write2d(ground_truth_image.real(), (path /
"input.fits").native());
259 pfitsio::write2d(dirty_image, ground_truth_image.rows(), ground_truth_image.cols(),
260 (pb_path /
"dirty.fits").native());
261 pfitsio::write2d(diagnostic.x.real(), ground_truth_image.rows(), ground_truth_image.cols(),
262 (pb_path /
"solution.fits").native());
263 pfitsio::write2d(residual_image, ground_truth_image.rows(), ground_truth_image.cols(),
264 (pb_path /
"residual.fits").native());
#define PURIFY_LOW_LOG(...)
Low priority message.
#define PURIFY_HIGH_LOG(...)
High priority message.
const t_real pi
Mathematical constant π.
std::vector< t_int > distribute_measurements(Vector< t_real > const &u, Vector< t_real > const &v, Vector< t_real > const &w, t_int const number_of_nodes, distribute::plan const distribution_plan, t_int const &grid_size)
Distribute visibilities into groups.
const std::map< std::string, kernel > kernel_from_string
void set_level(const std::string &level)
Method to set the logging level of the default Log object.
std::shared_ptr< sopt::LinearTransform< T > > init_degrid_operator_2d(const Vector< t_real > &u, const Vector< t_real > &v, const Vector< t_real > &w, const Vector< t_complex > &weights, const t_uint &imsizey, const t_uint &imsizex, const t_real &oversample_ratio=2, const kernels::kernel kernel=kernels::kernel::kb, const t_uint Ju=4, const t_uint Jv=4, const bool w_stacking=false, const t_real &cellx=1, const t_real &celly=1)
Returns linear transform that is the standard degridding operator.
void write2d(const Image< t_real > &eigen_image, const pfitsio::header_params &header, const bool &overwrite)
Write image to fits file.
Image< t_complex > read2d(const std::string &fits_name)
Read image from fits file.
t_real step_size(T const &vis, const std::shared_ptr< sopt::LinearTransform< T > const > &measurements, const std::shared_ptr< sopt::LinearTransform< T > const > &wavelets, const t_uint sara_size)
Calculate step size using MPI (does not include factor of 1e-3)
t_real SNR_to_standard_deviation(const Vector< t_complex > &y0, const t_real &SNR)
Converts SNR to RMS noise.
Vector< t_complex > add_noise(const Vector< t_complex > &y0, const t_complex &mean, const t_real &standard_deviation)
Add Gaussian noise to a vector.
vis_params scatter_visibilities(vis_params const &params, std::vector< t_int > const &sizes, sopt::mpi::Communicator const &comm)
vis_params regroup_and_scatter(vis_params const &params, std::vector< t_int > const &groups, sopt::mpi::Communicator const &comm)
utilities::vis_params random_sample_density(const t_int vis_num, const t_real mean, const t_real standard_deviation, const t_real rms_w)
Generates a random visibility coverage.
t_real calculate_l2_radius(const t_uint y_size, const t_real &sigma, const t_real &n_sigma, const std::string distirbution)
A function that calculates the l2 ball radius for sopt.
void mkdir_recursive(const std::string &path)
Recursively create directories when they do not exist.
std::string output_filename(std::string const &filename)
Test output file.
std::string image_filename(std::string const &filename)
Image filename.
std::shared_ptr< sopt::algorithm::ImagingProximalADMM< t_complex > > padmm_factory(std::shared_ptr< sopt::LinearTransform< Vector< t_complex >> const > const &measurements, const sopt::wavelets::SARA &sara, const Image< t_complex > &ground_truth_image, const utilities::vis_params &uv_data, const t_real sigma, const sopt::mpi::Communicator &comm)
std::tuple< utilities::vis_params, t_real > dirty_visibilities(Image< t_complex > const &ground_truth_image, t_uint number_of_vis, t_real snr, const std::tuple< bool, t_real > &w_term)
int main(int nargs, char const **args)
void padmm(const std::string &name, const Image< t_complex > &M31, const std::string &kernel, const t_int J, const utilities::vis_params &uv_data, const t_real sigma, const std::tuple< bool, t_real > &w_term)