4 #include <boost/math/special_functions/erf.hpp>
5 #include "purify/directories.h"
8 #include <sopt/credible_region.h>
9 #include <sopt/imaging_padmm.h>
10 #include <sopt/relative_variation.h>
11 #include <sopt/utilities.h>
12 #include <sopt/wavelets.h>
13 #include <sopt/wavelets/sara.h>
// Runs a proximal-ADMM (sopt::algorithm::ImagingProximalADMM) image
// reconstruction on the visibilities in `uv_data`, writes the dirty image and
// PSF to FITS files, and shows the solution and residual in a CImg display.
//
// NOTE(review): the leading integers on these lines look like line numbers
// carried over from an extracted source listing; gaps in that numbering mean
// some original lines (including part of the parameter list) are not shown.
//
// Visible parameters:
//   name     - prefix used to build the output file names.
//   imsizex  - image width in pixels (columns of the mapped images).
//   imsizey  - image height in pixels (rows of the mapped images).
//   sigma    - noise level; scales the l2-ball radius `epsilon` below.
//   w_term   - (bool, t_real) tuple; element 1 is passed twice to the
//              measurement-operator factory (the caller in this file stores
//              its `cellsize` there).
// `uv_data` is used throughout the body but its declaration is not visible
// in this listing.
25 void padmm(
const std::string &name,
const t_uint &imsizex,
const t_uint &imsizey,
27 const t_real sigma,
const std::tuple<bool, t_real> &w_term) {
// Output file names derived from `name` via output_filename().
28 std::string
const outfile_fits =
output_filename(name +
"_solution.fits");
29 std::string
const residual_fits =
output_filename(name +
"_residual.fits");
30 std::string
const dirty_image_fits =
output_filename(name +
"_dirty.fits");
// Value handed to the operator factory as `over_sample`.
33 t_real
const over_sample = 2;
// Measurement operator built from the visibilities; the trailing arguments of
// this call are not shown in this listing.
34 std::shared_ptr<sopt::LinearTransform<Vector<t_complex>>
const> measurements_transform =
35 measurementoperator::init_degrid_operator_2d<Vector<t_complex>>(
36 uv_data, imsizey, imsizex, std::get<1>(w_term), std::get<1>(w_term), over_sample,
// M = number of measurements, N = number of image pixels.
38 t_uint
const M = uv_data.
size();
39 t_uint
const N = imsizex * imsizey;
// SARA dictionary made of "Dirac" and "DB1".."DB8" (presumably Daubechies
// wavelets - confirm), each paired with 3u (presumably the number of
// decomposition levels - confirm against sopt).
40 sopt::wavelets::SARA
const sara{
41 std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
42 std::make_tuple(
"DB3", 3u), std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
43 std::make_tuple(
"DB6", 3u), std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)};
// Linear transform wrapping the dictionary for an imsizey x imsizex image.
45 auto const Psi = sopt::linear_transform<t_complex>(sara, imsizey, imsizex);
// Dirty image: real part of the adjoint operator applied to the visibilities.
46 const Vector<> dimage = (measurements_transform->adjoint() * uv_data.
vis).real();
// Image that is zero everywhere except a 1 at (imsizey / 2, imsizex / 2).
47 Matrix<t_complex> point = Matrix<t_complex>::Zero(imsizey, imsizex);
48 point(
int(imsizey / 2),
int(imsizex / 2)) = 1.;
// Forward operator followed by its adjoint, applied to the point image. The
// left-hand side of this expression is not shown here (`psf` is used below).
50 (measurements_transform->adjoint() *
51 (*measurements_transform * Vector<t_complex>::Map(point.data(), point.size())).eval())
// All-zero starting vector with the same length as the dirty image.
53 Vector<t_complex> initial_estimate = Vector<t_complex>::Zero(dimage.size());
// Write the dirty image and PSF as imsizey x imsizex images.
54 pfitsio::write2d(Image<t_real>::Map(dimage.data(), imsizey, imsizex), dirty_image_fits);
55 pfitsio::write2d(Image<t_real>::Map(psf.data(), imsizey, imsizex), psf_image_fits);
// l2-ball radius / residual threshold: 3 * sqrt(2 * #measurements) * sigma.
56 auto const epsilon = 3 * std::sqrt(2 * uv_data.
size()) * sigma;
// Initial regulariser strength: 1e-3 times the largest dirty-image value.
57 auto const regulariser_strength =
58 (measurements_transform->adjoint() * uv_data.
vis).real().maxCoeff() * 1e-3;
// Display created from a zero vector sized 2 * imsizex by imsizey.
61 auto const canvas = std::make_shared<CDisplay>(
62 cimg::make_display(Vector<t_real>::Zero(2 * imsizex * imsizey), 2 * imsizex, imsizey));
// Callback that, while the window is open, displays the real part of the
// estimate `x` and of the back-projected residual, each resized to 512 x 512.
// Captures locals by reference and `measurements_transform` by value.
64 auto const show_image = [&, measurements_transform](
const Vector<t_complex> &x) ->
bool {
65 if (!canvas->is_closed()) {
66 const Vector<t_complex> res =
67 (measurements_transform->adjoint() * (uv_data.
vis - (*measurements_transform * x)));
68 const auto img1 = cimg::make_image(x.real().eval(), imsizey, imsizex)
70 .get_resize(512, 512);
71 const auto img2 = cimg::make_image(res.real().eval(), imsizey, imsizex)
73 .get_resize(512, 512);
74 const auto results = CImageList<t_real>(img1, img2);
75 canvas->display(results);
// Solver constructed from the visibilities. The object on which the
// following setter chain is invoked is not shown in this listing.
81 auto padmm = std::make_shared<sopt::algorithm::ImagingProximalADMM<t_complex>>(uv_data.
vis);
83 .regulariser_strength(regulariser_strength)
84 .relative_variation(1e-3)
85 .l2ball_proximal_epsilon(epsilon)
87 .l1_proximal_tolerance(1e-2)
89 .l1_proximal_itermax(50)
90 .l1_proximal_positivity_constraint(
true)
91 .l1_proximal_real_constraint(
true)
92 .residual_convergence(epsilon)
93 .lagrange_update_scale(0.9)
95 .Phi(*measurements_transform);
// Placeholder convergence test: ignores `x` and always returns true.
97 auto convergence_function = [](
const Vector<t_complex> &x) {
return true; };
// Shared counter read by algo_update; starts at 0.
98 const std::shared_ptr<t_uint> iter = std::make_shared<t_uint>(0);
// Non-owning handle to the solver for use inside algo_update (presumably to
// avoid the solver keeping itself alive through a stored callback - confirm).
100 std::weak_ptr<decltype(
padmm)::element_type>
const padmm_weak(
padmm);
// Per-iteration update of the regulariser strength.
// NOTE(review): `uv_data` is captured by value, which copies it into the
// closure - confirm this is intended.
101 const auto algo_update = [uv_data, imsizex, imsizey, padmm_weak,
102 iter](
const Vector<t_complex> &x) ->
bool {
// NOTE(review): no null check on the lock() result is visible in this listing.
103 auto padmm = padmm_weak.lock();
// Coefficients of `x` under the adjoint of the solver's Psi.
106 Vector<t_complex>
const alpha =
padmm->Psi().adjoint() * x;
// Candidate strength: 1e-3 times the largest absolute real coefficient.
108 const t_real new_regulariser_strength = alpha.real().cwiseAbs().maxCoeff() * 1e-3;
// Adopt the candidate only if it differs from the current value by more than
// 0.2 and *iter is below 200; otherwise re-set the current value.
110 padmm->regulariser_strength(
111 ((std::abs(
padmm->regulariser_strength() - new_regulariser_strength) > 0.2) and *iter < 200)
112 ? new_regulariser_strength
113 :
padmm->regulariser_strength());
// Back-projected residual for the current estimate.
115 Vector<t_complex>
const residual =
padmm->Phi().adjoint() * (uv_data.
vis -
padmm->Phi() * x);
// Wrapper around convergence_function (the rest of its body is not shown),
// passed to the solver's is_converged (presumably registering it as the
// convergence test - confirm against sopt).
121 auto lambda = [=](Vector<t_complex>
const &x) {
122 return convergence_function(x)
128 padmm->is_converged(lambda);
// Invoke the solver; the result exposes the solution as `diagnostic.x`.
129 auto const diagnostic = (*padmm)();
130 Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
// Back-projected residual of the final solution, also viewed as an image.
132 Vector<t_complex> residuals = measurements_transform->adjoint() *
133 (uv_data.
vis - ((*measurements_transform) * diagnostic.x));
134 Image<t_complex> residual_image = Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
// Show the real parts of solution and residual, each resized to 512 x 512.
137 const auto results = CImageList<t_real>(
138 cimg::make_image(diagnostic.x.real().eval(), imsizey, imsizex).get_resize(512, 512),
139 cimg::make_image(residuals.real().eval(), imsizey, imsizex).get_resize(512, 512));
140 canvas->display(results);
// Graph of the residual's real part, titled "Residual Histogram".
141 cimg::make_image(residuals.real().eval(), imsizey, imsizex)
143 .display_graph(
"Residual Histogram", 2);
// Keep waiting on the display until its window is closed.
144 while (!canvas->is_closed()) canvas->wait();
// Driver fragment: fixes the run configuration and calls padmm().
// NOTE(review): the enclosing function header (presumably `main`) and the
// statements that load `uv_data` are not shown in this listing.
// `name` is a const reference bound to a temporary std::string; the
// temporary's lifetime is extended to that of the reference.
151 const std::string &name =
"real_data";
// Flag and value packed into the tuple handed to padmm() below.
152 const bool w_term =
false;
// Units of `cellsize` are not stated here - confirm against the operator.
153 const t_real cellsize = 20;
// Reconstructed image size in pixels.
154 const t_uint imsizex = 1024;
155 const t_uint imsizey = 1024;
// Kernel name forwarded to padmm().
156 const std::string
kernel =
"kb";
// Input uvfits path(s), resolved through vla_filename().
157 const std::vector<std::string> inputfiles = {
vla_filename(
"../mwa/uvdump_01.uvfits"),
// Noise estimate: l2 norm of the weights divided by sqrt(count), times 0.05.
161 t_real
const sigma = uv_data.weights.norm() / std::sqrt(uv_data.weights.size()) * 0.05;
// Visibilities multiplied element-wise by the weights and divided by the
// largest absolute weight; the assignment target is not shown in this listing.
163 uv_data.vis.array() * uv_data.weights.array() / uv_data.weights.array().cwiseAbs().maxCoeff();
// Run the reconstruction with the literal 4 as the fifth argument and
// (w_term, cellsize) as the final tuple.
164 padmm(name, imsizex, imsizey,
kernel, 4, uv_data, sigma, std::make_tuple(w_term, cellsize));
#define PURIFY_HIGH_LOG(...)
High priority message.
#define PURIFY_MEDIUM_LOG(...)
Medium priority message.
const std::map< std::string, kernel > kernel_from_string
void set_level(const std::string &level)
Method to set the logging level of the default Log object.
void write2d(const Image< t_real > &eigen_image, const pfitsio::header_params &header, const bool &overwrite)
Write image to fits file.
utilities::vis_params read_uvfits(const std::vector< std::string > &names, const bool flag, const stokes pol)
Read uvfits files from a vector of file names.
std::string output_filename(std::string const &filename)
Test output file.
std::string vla_filename(std::string const &filename)
Specific vla data.
void padmm(const std::string &name, const Image< t_complex > &M31, const std::string &kernel, const t_int J, const utilities::vis_params &uv_data, const t_real sigma, const std::tuple< bool, t_real > &w_term)
t_uint size() const
Return the number of measurements.