PURIFY
Next-generation radio interferometric imaging
Functions
algo_factory.cc File Reference
#include "catch2/catch_all.hpp"
#include "purify/config.h"
#include "purify/logging.h"
#include "purify/types.h"
#include "purify/directories.h"
#include "purify/pfitsio.h"
#include "purify/utilities.h"
#include "purify/algorithm_factory.h"
#include "purify/measurement_operator_factory.h"
#include "purify/wavelet_operator_factory.h"
#include <sopt/gradient_utils.h>
#include <sopt/power_method.h>
#include "purify/test_data.h"
Include dependency graph for algo_factory.cc (the graph image is not reproduced in this text version).

Go to the source code of this file.

Functions

 TEST_CASE ("padmm_factory")
 
 TEST_CASE ("primal_dual_factory")
 
 TEST_CASE ("fb_factory")
 
 TEST_CASE ("joint_map_factory")
 

Function Documentation

◆ TEST_CASE() [1/4]

TEST_CASE ( "fb_factory"  )

Definition at line 137 of file algo_factory.cc.

137  {
138  const std::string &test_dir = "expected/fb/";
139  const std::string &input_data_path = data_filename(test_dir + "input_data.vis");
140  const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
141  const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");
142  const std::string &result_path = data_filename(test_dir + "fb_result.fits");
143 
144  const auto solution = pfitsio::read2d(expected_solution_path);
145  const auto residual = pfitsio::read2d(expected_residual_path);
146 
147  auto uv_data = utilities::read_visibility(input_data_path, false);
148  uv_data.units = utilities::vis_units::radians;
149  CAPTURE(uv_data.vis.head(5));
150  REQUIRE(uv_data.size() == 13107);
151 
152  t_uint const imsizey = 128;
153  t_uint const imsizex = 128;
154 
155  Vector<t_complex> const init = Vector<t_complex>::Ones(imsizex * imsizey);
156  auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
157  factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2,
158  kernels::kernel_from_string.at("kb"), 4, 4);
159  auto const power_method_stuff =
160  sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init);
161  const t_real op_norm = std::get<0>(power_method_stuff);
162  measurements_transform->set_norm(op_norm);
163 
164  std::vector<std::tuple<std::string, t_uint>> const sara{
165  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
166  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
167  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
168  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
169  factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
170 
171  t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
172  t_real const beta = sigma * sigma;
173  t_real const gamma = 0.0001;
174 
175  auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
176  factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta,
177  gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50);
178 
179  auto const diagnostic = (*fb)();
180  const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
181  pfitsio::write2d(image.real(), result_path);
182  // pfitsio::write2d(residual_image.real(), expected_residual_path);
183 
184  double brightness = solution.real().cwiseAbs().maxCoeff();
185  double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
186  .real()
187  .squaredNorm() /
188  solution.size();
189  double rms = sqrt(mse);
190  CHECK(rms <= brightness * 5e-2);
191 }
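#define CHECK(CONDITION, ERROR)
Definition: casa.cc:6
Note: Doxygen linked CHECK to this two-argument macro from casa.cc, but that is not the macro used here. The single-argument CHECK(...) calls in this file are Catch2 assertions, which come from catch2/catch_all.hpp.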
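const std::string test_dir
Definition: operators.cc:16
Note: Doxygen linked this name to operators_test::test_dir in operators.cc. Each test case in this file declares its own local test_dir (for example "expected/fb/"), and that local variable is the one the listings use.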
utilities::vis_params read_visibility(const std::string &vis_name, const bool w_term)
Reads an HDF5 file with u,v visibilities, constructs a vis_params object and returns it.
Definition: h5reader.h:166
const std::map< std::string, kernel > kernel_from_string
Definition: kernels.h:16
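void write2d(const Image< t_real > &eigen_image, const pfitsio::header_params &header, const bool &overwrite)
Write image to fits file.
Definition: pfitsio.cc:30
Note: the call in this file, write2d(image.real(), result_path), passes a file-name string rather than a header_params object. It therefore presumably resolves to a different write2d overload than the one shown here.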
Image< t_complex > read2d(const std::string &fits_name)
Read image from fits file.
Definition: pfitsio.cc:109
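std::string data_filename(std::string const &filename)
Returns the full path of the named file in the test data directory. Here it locates the input visibilities and the expected and result FITS images.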

References CHECK, purify::data_filename(), purify::kernels::kernel_from_string, purify::utilities::radians, purify::pfitsio::read2d(), purify::utilities::read_visibility(), purify::factory::serial, operators_test::test_dir, and purify::pfitsio::write2d().

◆ TEST_CASE() [2/4]

TEST_CASE ( "joint_map_factory"  )

Definition at line 424 of file algo_factory.cc.

424  {
425  const std::string &test_dir = "expected/joint_map/";
426  const std::string &input_data_path = data_filename(test_dir + "input_data.vis");
427  const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
428  const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");
429 
430  const auto solution = pfitsio::read2d(expected_solution_path);
431  const auto residual = pfitsio::read2d(expected_residual_path);
432 
433  auto uv_data = utilities::read_visibility(input_data_path, false);
434  uv_data.units = utilities::vis_units::radians;
435  CAPTURE(uv_data.vis.head(5));
436  REQUIRE(uv_data.size() == 13107);
437 
438  t_uint const imsizey = 128;
439  t_uint const imsizex = 128;
440 
441  Vector<t_complex> const init = Vector<t_complex>::Ones(imsizex * imsizey);
442  auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
443  factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2,
444  kernels::kernel_from_string.at("kb"), 4, 4);
445  auto const power_method_stuff =
446  sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init);
447  const t_real op_norm = std::get<0>(power_method_stuff);
448  measurements_transform->set_norm(op_norm);
449 
450  std::vector<std::tuple<std::string, t_uint>> const sara{
451  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
452  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
453  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
454  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
455  factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
456  t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
457  t_real const beta = sigma * sigma;
458  t_real const gamma = 1;
459  auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
460  factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta,
461  gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50);
462  auto const l1_norm = [wavelets](const Vector<t_complex> &x) {
463  auto val = sopt::l1_norm(wavelets->adjoint() * x);
464  return val;
465  };
466  auto const joint_map =
467  sopt::algorithm::JointMAP<sopt::algorithm::ImagingForwardBackward<t_complex>>(
468  fb, l1_norm, imsizex * imsizey * sara.size())
469  .relative_variation(1e-3)
470  .objective_variation(1e-3)
471  .beta(1.)
472  .alpha(1.);
473  auto const diagnostic = joint_map();
474  const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
475  // CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
476  // CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
477  // CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(),
478  // image.size()).real().head(10));
479  // CHECK(image.isApprox(solution, 1e-6));
480 
481  const Vector<t_complex> residuals = measurements_transform->adjoint() *
482  (uv_data.vis - ((*measurements_transform) * diagnostic.x));
483  const Image<t_complex> residual_image = Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
484  // CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
485  // CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
486  // CHECK(residual_image.real().isApprox(residual.real(), 1e-6));
487 }

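References purify::data_filename(), purify::kernels::kernel_from_string, purify::utilities::radians, purify::pfitsio::read2d(), purify::utilities::read_visibility(), purify::factory::serial, and operators_test::test_dir.

Note: every CHECK in this test body is commented out (lines 479 and 486). Apart from the REQUIRE on the number of visibilities (line 436), the test only confirms that JointMAP runs without error. It does not compare the result with the expected solution or residual images it reads.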

◆ TEST_CASE() [3/4]

TEST_CASE ( "padmm_factory"  )

Definition at line 31 of file algo_factory.cc.

31  {
32  const std::string &test_dir = "expected/padmm/";
33  const std::string &input_data_path = data_filename(test_dir + "input_data.vis");
34  const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
35  const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");
36 
37  const auto solution = pfitsio::read2d(expected_solution_path);
38  const auto residual = pfitsio::read2d(expected_residual_path);
39 
40  auto uv_data = utilities::read_visibility(input_data_path, false);
41  uv_data.units = utilities::vis_units::radians;
42  CAPTURE(uv_data.vis.head(5));
43  REQUIRE(uv_data.size() == 13107);
44 
45  t_uint const imsizey = 128;
46  t_uint const imsizex = 128;
47  Vector<t_complex> const init = Vector<t_complex>::Ones(imsizex * imsizey);
48  auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
49  factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2,
50  kernels::kernel_from_string.at("kb"), 4, 4);
51  auto const power_method_stuff =
52  sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init);
53  const t_real op_norm = std::get<0>(power_method_stuff);
54  measurements_transform->set_norm(op_norm);
55 
56  std::vector<std::tuple<std::string, t_uint>> const sara{
57  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
58  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
59  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
60  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
61  factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
62  t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
63  auto const padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
64  factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, imsizey,
65  imsizex, sara.size(), 300, true, true, false, 1e-2, 1e-3, 50, 1);
66 
67  auto const diagnostic = (*padmm)();
68  const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
69  // pfitsio::write2d(image.real(), expected_solution_path);
70  CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
71  CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
72  CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
73  CHECK(image.isApprox(solution, 1e-4));
74 
75  const Vector<t_complex> residuals = measurements_transform->adjoint() *
76  (uv_data.vis - ((*measurements_transform) * diagnostic.x));
77  const Image<t_complex> residual_image = Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
78  // pfitsio::write2d(residual_image.real(), expected_residual_path);
79  CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
80  CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
81  CHECK(residual_image.real().isApprox(residual.real(), 1e-4));
82 }
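void padmm(const std::string &name, const Image< t_complex > &M31, const std::string &kernel, const t_int J, const utilities::vis_params &uv_data, const t_real sigma, const std::tuple< bool, t_real > &w_term)
Note: Doxygen linked the local variable padmm to this unrelated free function. That variable is the algorithm returned by factory::padmm_factory at line 63 and invoked as (*padmm)() at line 67. The padmm_factory test never calls this function, and the padmm() entry in the References list below comes from the same mislink.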

References CHECK, purify::data_filename(), purify::kernels::kernel_from_string, padmm(), purify::utilities::radians, purify::pfitsio::read2d(), purify::utilities::read_visibility(), purify::factory::serial, and operators_test::test_dir.

◆ TEST_CASE() [4/4]

TEST_CASE ( "primal_dual_factory"  )

Definition at line 84 of file algo_factory.cc.

84  {
85  const std::string &test_dir = "expected/primal_dual/";
86  const std::string &input_data_path = data_filename(test_dir + "input_data.vis");
87  const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
88  const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");
89  const std::string &result_path = data_filename(test_dir + "pd_result.fits");
90 
91  const auto solution = pfitsio::read2d(expected_solution_path);
92  const auto residual = pfitsio::read2d(expected_residual_path);
93 
94  auto uv_data = utilities::read_visibility(input_data_path, false);
95  uv_data.units = utilities::vis_units::radians;
96  CAPTURE(uv_data.vis.head(5));
97  REQUIRE(uv_data.size() == 13107);
98 
99  t_uint const imsizey = 128;
100  t_uint const imsizex = 128;
101 
102  Vector<t_complex> const init = Vector<t_complex>::Ones(imsizex * imsizey);
103  auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
104  factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2,
105  kernels::kernel_from_string.at("kb"), 4, 4);
106  auto const power_method_stuff =
107  sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init);
108  const t_real op_norm = std::get<0>(power_method_stuff);
109  measurements_transform->set_norm(op_norm);
110 
111  std::vector<std::tuple<std::string, t_uint>> const sara{
112  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
113  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
114  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
115  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
116  factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
117  t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
118  auto const primaldual =
119  factory::primaldual_factory<sopt::algorithm::ImagingPrimalDual<t_complex>>(
120  factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma,
121  imsizey, imsizex, sara.size(), 1000, true, true, 1e-3, 1);
122 
123  auto const diagnostic = (*primaldual)();
124 
125  const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
126  // pfitsio::write2d(image.real(), result_path);
127 
128  double brightness = solution.real().cwiseAbs().maxCoeff();
129  double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
130  .real()
131  .squaredNorm() /
132  solution.size();
133  double rms = sqrt(mse);
134  CHECK(rms <= brightness * 5e-2);
135 }

References CHECK, purify::data_filename(), purify::kernels::kernel_from_string, purify::utilities::radians, purify::pfitsio::read2d(), purify::utilities::read_visibility(), purify::factory::serial, and operators_test::test_dir.