2 #include <sopt/differentiable_func.h>
3 #include <sopt/l1_non_diff_function.h>
4 #include <sopt/l2_differentiable_func.h>
5 #include <sopt/non_differentiable_func.h>
8 #include <sopt/onnx_differentiable_func.h>
11 #include <sopt/power_method.h>
12 #include <sopt/real_indicator.h>
15 #include <sopt/tf_non_diff_function.h>
22 std::vector<std::tuple<std::string, t_uint>> sara;
23 for (
size_t i = 0; i < params.wavelet_basis().size(); i++)
24 sara.push_back(std::make_tuple(params.wavelet_basis().at(i), params.wavelet_levels()));
28 auto const world = sopt::mpi::Communicator::World();
30 sara = sopt::wavelets::distribute_sara(sara, world);
33 auto const wavelets_transform = factory::wavelet_operator_factory<Vector<t_complex>>(
34 wop_algo, sara, params.height(), params.width(), sara_size);
35 return {wavelets_transform, sara_size};
43 bool using_mpi =
false;
46 throw std::runtime_error(
"Compile with MPI if you want to use MPI algorithm");
48 mop_algo = (not params.gpu())
51 if (params.mpi_all_to_all())
52 mop_algo = (not params.gpu())
63 return {mop_algo, wop_algo, using_mpi};
70 bool w_term = params.w_term();
72 std::vector<t_int> image_index = std::vector<t_int>();
73 std::vector<t_real> w_stacks = std::vector<t_real>();
75 Vector<t_complex> measurement_op_eigen_vector =
76 Vector<t_complex>::Ones(params.width() * params.height());
78 if (params.eigenvector_real() !=
"" and params.eigenvector_imag() !=
"") {
83 Vector<t_real> temp_real;
84 Vector<t_real> temp_imag;
85 pfitsio::read3d(params.eigenvector_real(), temp_real, rows, cols, chans, pols);
86 if (rows != params.height() or cols != params.width() or chans != 1 or pols != 1)
87 throw std::runtime_error(
"Image of measurement operator eigenvector is wrong size.");
88 pfitsio::read3d(params.eigenvector_imag(), temp_imag, rows, cols, chans, pols);
89 if (rows != params.height() or cols != params.width() or chans != 1 or pols != 1)
90 throw std::runtime_error(
"Image of measurement operator eigenvector is wrong size.");
91 measurement_op_eigen_vector.real() = temp_real;
92 measurement_op_eigen_vector.imag() = temp_imag;
96 for (
size_t i = 0; i < params.measurements().size(); i++)
98 sigma = params.measurements_sigma();
101 auto const world = sopt::mpi::Communicator::World();
104 params.measurements_units());
106 std::sqrt(world.all_sum_all(
107 (uv_data.
weights.real().array() * uv_data.
weights.real().array()).sum()) /
108 world.all_sum_all(uv_data.
size()));
117 params.measurements_units());
118 const t_real norm = std::sqrt(
119 (uv_data.
weights.real().array() * uv_data.
weights.real().array()).sum() / uv_data.
size());
127 if (params.mpi_wstacking() and
130 auto const world = sopt::mpi::Communicator::World();
131 const auto cost = [](t_real x) -> t_real {
return std::abs(x * x); };
135 uv_data, du, params.Jx(), params.Jw(), world, params.kmeans_iters(), 0, cost);
136 }
else if (params.mpi_wstacking()) {
137 auto const world = sopt::mpi::Communicator::World();
138 const auto cost = [](t_real x) -> t_real {
return std::abs(x * x); };
143 PURIFY_HIGH_LOG(
"Input visibilities will be generated for random coverage.");
146 if (params.height() != image.rows() || params.width() != image.cols())
147 throw std::runtime_error(
"Input image size (" + std::to_string(image.cols()) +
"x" +
148 std::to_string(image.rows()) +
") is not equal to the input one (" +
149 std::to_string(params.width()) +
"x" +
150 std::to_string(params.height()) +
").");
151 t_int
const number_of_pixels = image.size();
152 t_int
const number_of_vis = params.number_of_measurements();
154 const t_real rms_w = params.w_rms();
155 if (params.measurements().at(0) ==
"") {
158 uv_data.
weights = Vector<t_complex>::Ones(uv_data.
size());
162 auto const world = sopt::mpi::Communicator::World();
165 params.measurements_units());
169 params.measurements_units());
170 uv_data.
weights = Vector<t_complex>::Ones(uv_data.
weights.size());
174 if (params.mpi_wstacking() and
177 auto const world = sopt::mpi::Communicator::World();
178 const auto cost = [](t_real x) -> t_real {
return std::abs(x * x); };
182 uv_data, du, params.Jx(), params.Jw(), world, params.kmeans_iters(), 0, cost);
183 }
else if (params.mpi_wstacking()) {
184 auto const world = sopt::mpi::Communicator::World();
185 const auto cost = [](t_real x) -> t_real {
return std::abs(x * x); };
189 std::shared_ptr<sopt::LinearTransform<Vector<t_complex>>> sky_measurements;
193 (not params.wprojection())
195 mop_algo, uv_data, params.height(), params.width(), params.cellsizey(),
196 params.cellsizex(), params.oversampling(),
198 params.mpi_wstacking())
200 mop_algo, uv_data, params.height(), params.width(), params.cellsizey(),
201 params.cellsizex(), params.oversampling(),
206 (not params.wprojection())
208 mop_algo, image_index, w_stacks, uv_data, params.height(), params.width(),
209 params.cellsizey(), params.cellsizex(), params.oversampling(),
211 params.mpi_wstacking())
213 mop_algo, image_index, w_stacks, uv_data, params.height(), params.width(),
214 params.cellsizey(), params.cellsizex(), params.oversampling(),
218 ((*sky_measurements) * Vector<t_complex>::Map(image.data(), image.size())).eval().array();
223 params.width(), params.oversampling());
225 params.height(), params.oversampling());
228 auto const comm = sopt::mpi::Communicator::World();
230 comm.all_reduce<t_real>(uv_data.
u.cwiseAbs().maxCoeff(), MPI_MAX), params.width(),
231 params.oversampling());
233 comm.all_reduce<t_real>(uv_data.
v.cwiseAbs().maxCoeff(), MPI_MAX), params.height(),
234 params.oversampling());
238 "Using cell size {}\" x {}\", recommended from the uv coverage and field of view is "
240 params.cellsizey(), params.cellsizex(), ideal_cell_y, ideal_cell_x);
243 params.oversampling()),
245 params.oversampling()));
247 return {uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks};
253 const std::vector<t_int> &image_index,
const std::vector<t_real> &w_stacks,
255 std::shared_ptr<sopt::LinearTransform<Vector<t_complex>>> measurements_transform;
258 measurements_transform =
259 (not params.wprojection())
261 mop_algo, uv_data, params.height(), params.width(), params.cellsizey(),
262 params.cellsizex(), params.oversampling(),
264 params.mpi_wstacking())
266 mop_algo, uv_data, params.height(), params.width(), params.cellsizey(),
267 params.cellsizex(), params.oversampling(),
271 measurements_transform =
272 (not params.wprojection())
274 mop_algo, image_index, w_stacks, uv_data, params.height(), params.width(),
275 params.cellsizey(), params.cellsizex(), params.oversampling(),
277 params.mpi_wstacking())
279 mop_algo, image_index, w_stacks, uv_data, params.height(), params.width(),
280 params.cellsizey(), params.cellsizex(), params.oversampling(),
283 t_real operator_norm = 1.;
286 auto const comm = sopt::mpi::Communicator::World();
287 auto power_method_result =
289 ? sopt::algorithm::power_method<Vector<t_complex>>(
290 *measurements_transform, params.powMethod_iter(), params.powMethod_tolerance(),
291 comm.broadcast(measurement_op_eigen_vector).eval())
292 : sopt::algorithm::all_sum_all_power_method<Vector<t_complex>>(
293 comm, *measurements_transform, params.powMethod_iter(),
294 params.powMethod_tolerance(), comm.broadcast(measurement_op_eigen_vector).eval());
295 measurement_op_eigen_vector = std::get<1>(power_method_result);
296 operator_norm = std::get<0>(power_method_result);
297 measurements_transform->set_norm(operator_norm);
301 auto power_method_result = sopt::algorithm::power_method<Vector<t_complex>>(
302 *measurements_transform, params.powMethod_iter(), params.powMethod_tolerance(),
303 measurement_op_eigen_vector);
304 measurement_op_eigen_vector = std::get<1>(power_method_result);
305 operator_norm = std::get<0>(power_method_result);
306 measurements_transform->set_norm(operator_norm);
309 return measurements_transform;
313 std::unique_ptr<NonDifferentiableFunc<t_complex>> &g, t_real sigma,
314 sopt::LinearTransform<Vector<t_complex>> &Phi) {
315 switch (params.diffFuncType()) {
317 f = std::make_unique<sopt::L2DifferentiableFunc<t_complex>>(sigma, Phi);
321 f = std::make_unique<sopt::ONNXDifferentiableFunc<t_complex>>(
322 params.CRR_function_model_path(), params.CRR_gradient_model_path(), sigma, params.CRR_mu(),
323 params.CRR_lambda(), Phi);
325 throw std::runtime_error(
326 "To use the CRR you must compile with ONNX runtime turned on. (-Donnxrt=on)");
331 switch (params.nondiffFuncType()) {
333 g = std::make_unique<sopt::algorithm::L1GProximal<t_complex>>();
337 g = std::make_unique<sopt::algorithm::TFGProximal<t_complex>>(params.model_path());
340 throw std::runtime_error(
341 "To use the Denoiser you must compile with ONNX runtime turned on. (-Donnxrt=on)");
345 g = std::make_unique<sopt::algorithm::RealIndicator<t_complex>>();
353 auto const world = sopt::mpi::Communicator::World();
356 throw std::runtime_error(
"Compile with MPI if you want to use MPI algorithm");
367 uv_data.
dec, params.measurements_polarization(), params.cellsizex(),
371 uv_data.
dec, params.measurements_polarization(), params.cellsizex(),
374 "",
"Jy/Pixel", 1, uv_data.
ra, uv_data.
dec, params.measurements_polarization(),
375 params.cellsizex(), params.cellsizey(), uv_data.
average_frequency, 0, 0,
false, 0, 0, 0);
377 return {update_header_sol, update_header_res, def_header};
381 const Vector<t_complex> &measurement_op_eigen_vector) {
384 auto const world = sopt::mpi::Communicator::World();
387 throw std::runtime_error(
"Compile with MPI if you want to use MPI algorithm");
390 pfitsio::write2d(measurement_op_eigen_vector.real(), params.height(), params.width(),
391 params.
output_path() +
"/eigenvector_real.fits",
"pix",
true);
392 pfitsio::write2d(measurement_op_eigen_vector.imag(), params.height(), params.width(),
393 params.
output_path() +
"/eigenvector_imag.fits",
"pix",
true);
396 pfitsio::write2d(measurement_op_eigen_vector.real(), params.height(), params.width(),
397 params.
output_path() +
"/eigenvector_real.fits",
"pix",
true);
398 pfitsio::write2d(measurement_op_eigen_vector.imag(), params.height(), params.width(),
399 params.
output_path() +
"/eigenvector_imag.fits",
"pix",
true);
405 const std::shared_ptr<sopt::LinearTransform<Vector<t_complex>>> &measurements_transform,
407 const t_real beam_units) {
411 const Vector<t_complex> psf = measurements_transform->adjoint() * (uv_data.
weights / flux_scale);
412 const Image<t_real> psf_image =
413 Image<t_complex>::Map(psf.data(), params.height(), params.width()).real();
415 "Peak of PSF: {} (used to convert between Jy/Pixel and Jy/BEAM)",
416 psf_image(
static_cast<t_int
>(params.width() * 0.5 + params.height() * 0.5 * params.width())));
419 auto const world = sopt::mpi::Communicator::World();
421 "Expected image domain residual RMS is {} jy/beam",
422 sigma * params.epsilonScaling() * measurements_transform->norm() /
423 (std::sqrt(params.width() * params.height()) * world.all_sum_all(uv_data.
size())));
426 throw std::runtime_error(
"Compile with MPI if you want to use MPI algorithm");
430 PURIFY_LOW_LOG(
"Expected image domain residual RMS is {} jy/beam",
431 sigma * params.epsilonScaling() * measurements_transform->norm() /
432 (std::sqrt(params.width() * params.height()) * uv_data.
size()));
436 "Theoretical calculation for peak PSF: {} (used to convert between Jy/Pixel and Jy/BEAM)",
438 PURIFY_HIGH_LOG(
"Effective sigma is {} Jy", sigma * params.epsilonScaling());
443 const std::shared_ptr<sopt::LinearTransform<Vector<t_complex>>> &measurements_transform,
448 const Vector<t_complex> dimage = measurements_transform->adjoint() * uv_data.
vis;
449 const Image<t_real> dirty_image =
450 Image<t_complex>::Map(dimage.data(), params.height(), params.width()).real();
453 auto const world = sopt::mpi::Communicator::World();
456 throw std::runtime_error(
"Compile with MPI if you want to use MPI algorithm");
std::string output_path() const
#define PURIFY_LOW_LOG(...)
Low priority message.
#define PURIFY_HIGH_LOG(...)
High priority message.
const t_real pi
mathematical constant
distributed_wavelet_operator
distributed_measurement_operator
determine type of distribute for mpi measurement operator
@ gpu_mpi_distribute_image
@ gpu_mpi_distribute_all_to_all
@ mpi_distribute_all_to_all
std::shared_ptr< sopt::LinearTransform< T > > measurement_operator_factory(const distributed_measurement_operator distribute, ARGS &&...args)
distributed measurement operator factory
std::shared_ptr< sopt::LinearTransform< T > > all_to_all_measurement_operator_factory(const distributed_measurement_operator distribute, const std::vector< t_int > &image_stacks, const std::vector< t_real > &w_stacks, ARGS &&...args)
distributed measurement operator factory
const std::map< std::string, kernel > kernel_from_string
void write2d(const Image< t_real > &eigen_image, const pfitsio::header_params &header, const bool &overwrite)
Write image to fits file.
Image< t_complex > read2d(const std::string &fits_name)
Read image from fits file.
std::vector< Image< t_complex > > read3d(const std::string &fits_name)
Read cube from fits file.
utilities::vis_params read_measurements(const std::string &name, const bool w_term, const stokes pol, const utilities::vis_units units)
read in single measurement file
utilities::vis_params w_stacking(utilities::vis_params const ¶ms, sopt::mpi::Communicator const &comm, const t_int iters, const std::function< t_real(t_real)> &cost, const t_real k_means_rel_diff)
t_real SNR_to_standard_deviation(const Vector< t_complex > &y0, const t_real &SNR)
Converts SNR to RMS noise.
Vector< t_complex > add_noise(const Vector< t_complex > &y0, const t_complex &mean, const t_real &standard_deviation)
Add guassian noise to vector.
utilities::vis_params conjugate_w(const utilities::vis_params &uv_vis)
reflects visibilities into the w >= 0 domain
std::tuple< utilities::vis_params, std::vector< t_int >, std::vector< t_real > > w_stacking_with_all_to_all(utilities::vis_params const ¶ms, const t_real du, const t_int min_support, const t_int max_support, sopt::mpi::Communicator const &comm, const t_int iters, const t_real fill_relaxation, const std::function< t_real(t_real)> &cost, const t_real k_means_rel_diff)
utilities::vis_params random_sample_density(const t_int vis_num, const t_real mean, const t_real standard_deviation, const t_real rms_w)
Generates a random visibility coverage.
t_real equivalent_miriad_cell_size(const t_real cell, const t_uint imsize, const t_real oversample_ratio)
for a given purify cell size in arcsec provide the equivalent miriad cell size in arcsec
t_real pixel_to_lambda(const t_real cell, const t_uint imsize, const t_real oversample_ratio)
return factors to convert between arcsecond pixel size image space and lambda for uv space
t_real estimate_cell_size(const t_real max_u, const t_uint imsize, const t_real oversample_ratio)
return cell size from the bandwidth
void saveMeasurementEigenVector(const YamlParser ¶ms, const Vector< t_complex > &measurement_op_eigen_vector)
void saveDirtyImage(const YamlParser ¶ms, const pfitsio::header_params &def_header, const std::shared_ptr< sopt::LinearTransform< Vector< t_complex >>> &measurements_transform, const utilities::vis_params &uv_data, const t_real beam_units)
void savePSF(const YamlParser ¶ms, const pfitsio::header_params &def_header, const std::shared_ptr< sopt::LinearTransform< Vector< t_complex >>> &measurements_transform, const utilities::vis_params &uv_data, const t_real flux_scale, const t_real sigma, const t_real beam_units)
std::shared_ptr< sopt::LinearTransform< Vector< t_complex > > > createMeasurementOperator(const YamlParser ¶ms, const factory::distributed_measurement_operator mop_algo, const factory::distributed_wavelet_operator wop_algo, const bool using_mpi, const std::vector< t_int > &image_index, const std::vector< t_real > &w_stacks, const utilities::vis_params &uv_data, Vector< t_complex > &measurement_op_eigen_vector)
waveletInfo createWaveletOperator(YamlParser ¶ms, const factory::distributed_wavelet_operator &wop_algo)
OperatorsInfo selectOperators(YamlParser ¶ms)
void initOutDirectoryWithConfig(YamlParser ¶ms)
Headers genHeaders(const YamlParser ¶ms, const utilities::vis_params &uv_data)
inputData getInputData(const YamlParser ¶ms, const factory::distributed_measurement_operator mop_algo, const factory::distributed_wavelet_operator wop_algo, const bool using_mpi)
void setupCostFunctions(const YamlParser ¶ms, std::unique_ptr< DifferentiableFunc< t_complex >> &f, std::unique_ptr< NonDifferentiableFunc< t_complex >> &g, t_real sigma, sopt::LinearTransform< Vector< t_complex >> &Phi)
t_uint size() const
return number of measurements
Vector< t_complex > weights