5 #include <boost/filesystem.hpp>
6 #include <boost/math/special_functions/erf.hpp>
7 #include "purify/directories.h"
16 #include <sopt/imaging_padmm.h>
17 #include <sopt/mpi/communicator.h>
18 #include <sopt/mpi/session.h>
19 #include <sopt/power_method.h>
20 #include <sopt/relative_variation.h>
21 #include <sopt/utilities.h>
22 #include <sopt/wavelets.h>
23 #include <sopt/wavelets/sara.h>
29 #ifndef PURIFY_PADMM_ALGORITHM
30 #define PURIFY_PADMM_ALGORITHM 2
40 sopt::mpi::Communicator
const &comm) {
51 std::shared_ptr<sopt::algorithm::ImagingProximalADMM<t_complex>>
padmm_factory(
52 std::shared_ptr<sopt::LinearTransform<Vector<t_complex>>
const>
const &measurements,
54 const sopt::mpi::Communicator &comm,
const t_uint &imsizex,
const t_uint &imsizey) {
55 auto const Psi = sopt::linear_transform<t_complex>(sara, imsizey, imsizex, comm);
57 #if PURIFY_PADMM_ALGORITHM == 2
58 auto const epsilon = 3 * std::sqrt(comm.all_sum_all(std::pow(sigma, 2))) *
59 std::sqrt(2 * comm.all_sum_all(uv_data.
size()));
60 #elif PURIFY_PADMM_ALGORITHM == 3 || PURIFY_PADMM_ALGORITHM == 1
61 auto const epsilon = 3 * std::sqrt(2 * uv_data.
size()) * sigma;
63 const t_real regulariser_strength =
65 std::make_shared<sopt::LinearTransform<Vector<t_complex>>
const>(Psi),
73 auto padmm = std::make_shared<sopt::algorithm::ImagingProximalADMM<t_complex>>(uv_data.
vis);
75 .regulariser_strength(comm.all_reduce<t_real>(regulariser_strength, MPI_MAX))
76 .relative_variation(1e-3)
77 .l2ball_proximal_epsilon(epsilon)
78 #if PURIFY_PADMM_ALGORITHM == 2
80 .l2ball_proximal_communicator(comm)
83 .l1_proximal_adjoint_space_comm(comm)
85 .l1_proximal_tolerance(1e-2)
87 .l1_proximal_itermax(50)
88 .l1_proximal_positivity_constraint(
true)
89 .l1_proximal_real_constraint(
true)
90 .residual_tolerance(epsilon)
91 .lagrange_update_scale(0.9)
94 sopt::ScalarRelativeVariation<t_complex> conv(
padmm->relative_variation(),
95 padmm->relative_variation(),
"Objective function");
96 std::weak_ptr<decltype(
padmm)::element_type>
const padmm_weak(
padmm);
97 padmm->residual_convergence([padmm_weak, conv, comm](
98 Vector<t_complex>
const &x,
99 Vector<t_complex>
const &residual)
mutable ->
bool {
100 auto const padmm = padmm_weak.lock();
101 #if PURIFY_PADMM_ALGORITHM == 2
102 auto const residual_norm = sopt::mpi::l2_norm(residual,
padmm->l2ball_proximal_weights(), comm);
103 auto const result = residual_norm <
padmm->residual_tolerance();
104 #elif PURIFY_PADMM_ALGORITHM == 3 || PURIFY_PADMM_ALGORITHM == 1
105 auto const residual_norm = sopt::l2_norm(residual,
padmm->l2ball_proximal_weights());
107 comm.all_reduce<int8_t>(residual_norm <
padmm->residual_tolerance(), MPI_LAND);
109 SOPT_LOW_LOG(
" - [PADMM] Residuals: {} <? {}", residual_norm,
padmm->residual_tolerance());
113 padmm->objective_convergence([padmm_weak, conv, comm](Vector<t_complex>
const &x,
114 Vector<t_complex>
const &)
mutable ->
bool {
115 auto const padmm = padmm_weak.lock();
116 #if PURIFY_PADMM_ALGORITHM == 2
117 return conv(sopt::mpi::l1_norm(
padmm->Psi().adjoint() * x,
padmm->l1_proximal_weights(), comm));
118 #elif PURIFY_PADMM_ALGORITHM == 3 || PURIFY_PADMM_ALGORITHM == 1
119 return comm.all_reduce<uint8_t>(
120 conv(sopt::l1_norm(
padmm->Psi().adjoint() * x,
padmm->l1_proximal_weights())), MPI_LAND);
124 auto convergence_function = [](
const Vector<t_complex> &x) {
return true; };
125 const std::shared_ptr<t_uint> iter = std::make_shared<t_uint>(0);
126 const auto algo_update = [uv_data, imsizex, imsizey, padmm_weak, iter,
127 comm](
const Vector<t_complex> &x) ->
bool {
128 auto padmm = padmm_weak.lock();
131 Vector<t_complex>
const alpha =
padmm->Psi().adjoint() * x;
132 const t_real new_regulariser_strength =
133 comm.all_reduce(alpha.real().cwiseAbs().maxCoeff(), MPI_MAX) * 1e-3;
134 if (comm.is_root())
PURIFY_MEDIUM_LOG(
"Step size γ update {}", new_regulariser_strength);
135 padmm->regulariser_strength(
136 ((std::abs(
padmm->regulariser_strength() - new_regulariser_strength) > 0.2) and *iter < 200)
137 ? new_regulariser_strength
138 :
padmm->regulariser_strength());
141 Vector<t_complex>
const residual =
padmm->Phi().adjoint() * (uv_data.
vis -
padmm->Phi() * x);
143 if (comm.is_root()) {
149 auto lambda = [convergence_function, algo_update](Vector<t_complex>
const &x) {
150 return convergence_function(x) and algo_update(x);
152 padmm->is_converged(lambda);
156 int main(
int nargs,
char const **args) {
159 auto const session = sopt::mpi::init(nargs, args);
160 auto const world = sopt::mpi::Communicator::World();
162 const std::string name =
"realdata";
163 const std::string filename_base =
vla_filename(
"../mwa/uvdump_");
164 const std::vector<std::string> filenames = {filename_base +
167 std::string kernel_name =
"kb";
168 const bool w_term =
false;
170 const t_real cellsize = 20;
171 const t_uint imsizex = 1024;
172 const t_uint imsizey = 1024;
178 data.
weights.norm() / std::sqrt(world.all_sum_all(data.
weights.size())) * 0.5;
180 world.all_reduce(data.
weights.array().cwiseAbs().maxCoeff(), MPI_MAX);
181 #if PURIFY_PADMM_ALGORITHM == 2 || PURIFY_PADMM_ALGORITHM == 3
183 auto const measurements = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
185 world, data, imsizey, imsizex, cellsize, cellsize, 2,
kernel, 4, 4, w_term),
186 100, 1e-4, world.broadcast(Vector<t_complex>::Random(imsizex * imsizey).eval())));
189 auto const measurements = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
191 cellsize, 2,
kernel, 4, 4, w_term),
192 100, 1e-4, world.broadcast(Vector<t_complex>::Random(imsizex * imsizey).eval())));
195 #elif PURIFY_PADMM_ALGORITHM == 1
197 auto const measurements = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
198 measurementoperator::init_degrid_operator_2d_mpi<Vector<t_complex>>(
199 world, data, imsizey, imsizex, cellsize, cellsize, 2,
kernel, 4, 4, w_term),
200 100, 1e-4, world.broadcast(Vector<t_complex>::Random(imsizex * imsizey).eval())));
204 auto const measurements = std::get<2>(sopt::algorithm::normalise_operator<Vector<t_complex>>(
205 gpu::measurementoperator::init_degrid_operator_2d_mpi(world, data, imsizey, imsizex, cellsize,
206 cellsize, 2,
kernel, 4, 4, w_term),
207 100, 1e-4, world.broadcast(Vector<t_complex>::Random(imsizex * imsizey).eval())));
211 auto const sara = sopt::wavelets::distribute_sara(
212 sopt::wavelets::SARA{
213 std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
214 std::make_tuple(
"DB3", 3u), std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
215 std::make_tuple(
"DB6", 3u), std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)},
218 Vector<t_real>
const dirty_image = (measurements->adjoint() * (data.
vis)).real();
220 if (world.is_root()) {
223 #if PURIFY_PADMM_ALGORITHM == 3
224 auto const pb_path = path / kernel_name /
"local_epsilon_replicated_grids";
225 #elif PURIFY_PADMM_ALGORITHM == 2
226 auto const pb_path = path / kernel_name /
"global_epsilon_replicated_grids";
227 #elif PURIFY_PADMM_ALGORITHM == 1
228 auto const pb_path = path / kernel_name /
"local_epsilon_distributed_grids";
230 #error Unknown or unimplemented algorithm
234 pfitsio::write2d(dirty_image, imsizey, imsizex, (pb_path /
"dirty.fits").native());
238 auto const padmm =
padmm_factory(measurements, sigma, sara, data, world, imsizey, imsizex);
240 auto const diagnostic = (*padmm)();
243 assert(world.broadcast(diagnostic.x).isApprox(diagnostic.x));
245 Vector<t_real>
const residual_image = (measurements->adjoint() * diagnostic.residual).real();
246 if (world.is_root()) {
249 #if PURIFY_PADMM_ALGORITHM == 3
250 auto const pb_path = path / kernel_name /
"local_epsilon_replicated_grids";
251 #elif PURIFY_PADMM_ALGORITHM == 2
252 auto const pb_path = path / kernel_name /
"global_epsilon_replicated_grids";
253 #elif PURIFY_PADMM_ALGORITHM == 1
254 auto const pb_path = path / kernel_name /
"local_epsilon_distributed_grids";
256 #error Unknown or unimplemented algorithm
260 pfitsio::write2d(dirty_image, imsizey, imsizex, (pb_path /
"dirty.fits").native());
261 pfitsio::write2d(diagnostic.x.real(), imsizey, imsizex, (pb_path /
"solution.fits").native());
262 pfitsio::write2d(residual_image, imsizey, imsizex, (pb_path /
"residual.fits").native());
#define PURIFY_MEDIUM_LOG(...)
Medium priority message.
std::vector< t_int > distribute_measurements(Vector< t_real > const &u, Vector< t_real > const &v, Vector< t_real > const &w, t_int const number_of_nodes, distribute::plan const distribution_plan, t_int const &grid_size)
Distribute visibilities into groups.
void set_level(const std::string &level)
Method to set the logging level of the default Log object.
std::shared_ptr< sopt::LinearTransform< T > > init_degrid_operator_2d(const Vector< t_real > &u, const Vector< t_real > &v, const Vector< t_real > &w, const Vector< t_complex > &weights, const t_uint &imsizey, const t_uint &imsizex, const t_real &oversample_ratio=2, const kernels::kernel kernel=kernels::kernel::kb, const t_uint Ju=4, const t_uint Jv=4, const bool w_stacking=false, const t_real &cellx=1, const t_real &celly=1)
Returns linear transform that is the standard degridding operator.
void write2d(const Image< t_real > &eigen_image, const pfitsio::header_params &header, const bool &overwrite)
Write image to fits file.
t_real step_size(T const &vis, const std::shared_ptr< sopt::LinearTransform< T > const > &measurements, const std::shared_ptr< sopt::LinearTransform< T > const > &wavelets, const t_uint sara_size)
Calculate step size using MPI (does not include factor of 1e-3)
vis_params scatter_visibilities(vis_params const &params, std::vector< t_int > const &sizes, sopt::mpi::Communicator const &comm)
vis_params regroup_and_scatter(vis_params const &params, std::vector< t_int > const &groups, sopt::mpi::Communicator const &comm)
utilities::vis_params read_visibility(const std::vector< std::string > &names, const bool w_term)
Read visibility files from name of vector.
void mkdir_recursive(const std::string &path)
recursively create directories when they do not exist
std::string output_filename(std::string const &filename)
Test output file.
std::string vla_filename(std::string const &filename)
Specific vla data.
std::shared_ptr< sopt::algorithm::ImagingProximalADMM< t_complex > > padmm_factory(std::shared_ptr< sopt::LinearTransform< Vector< t_complex >> const > const &measurements, t_real const sigma, const sopt::wavelets::SARA &sara, const utilities::vis_params &uv_data, const sopt::mpi::Communicator &comm, const t_uint &imsizex, const t_uint &imsizey)
int main(int nargs, char const **args)
utilities::vis_params dirty_visibilities(const std::vector< std::string > &names)
void padmm(const std::string &name, const Image< t_complex > &M31, const std::string &kernel, const t_int J, const utilities::vis_params &uv_data, const t_real sigma, const std::tuple< bool, t_real > &w_term)
t_uint size() const
return number of measurements
Vector< t_complex > weights