PURIFY
Next-generation radio interferometric imaging
algo_factory.cc
Go to the documentation of this file.
1 
2 #include "catch2/catch_all.hpp"
3 
4 #include "purify/config.h"
5 #include "purify/logging.h"
6 
7 #include "purify/types.h"
8 #include "purify/directories.h"
9 #include "purify/pfitsio.h"
10 #include "purify/utilities.h"
11 
15 
16 #ifdef PURIFY_ONNXRT
17 #include <sopt/onnx_differentiable_func.h>
18 #endif
19 
20 #ifdef PURIFY_H5
21 #include "purify/h5reader.h"
22 #endif
23 
24 #include <sopt/gradient_utils.h>
25 #include <sopt/power_method.h>
26 
27 #include "purify/test_data.h"
28 
29 using namespace purify;
30 
// Regression test for factory::padmm_factory.
// Builds a serial measurement operator and a wavelet dictionary, runs the
// ImagingProximalADMM solver returned by the factory, and compares both the
// reconstructed image and the residual image with the FITS files stored under
// expected/padmm/.
31 TEST_CASE("padmm_factory") {
    // Fixture locations: input visibilities plus the reference solution/residual.
32  const std::string &test_dir = "expected/padmm/";
33  const std::string &input_data_path = data_filename(test_dir + "input_data.vis");
34  const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
35  const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");
36 
37  const auto solution = pfitsio::read2d(expected_solution_path);
38  const auto residual = pfitsio::read2d(expected_residual_path);
39 
40  auto uv_data = utilities::read_visibility(input_data_path, false);
41  uv_data.units = utilities::vis_units::radians;
42  CAPTURE(uv_data.vis.head(5));
    // Stop immediately if the fixture does not hold the expected number of visibilities.
43  REQUIRE(uv_data.size() == 13107);
44 
45  t_uint const imsizey = 128;
46  t_uint const imsizex = 128;
    // All-ones vector of image size, used as the starting vector for the power method.
47  Vector<t_complex> const init = Vector<t_complex>::Ones(imsizex * imsizey);
48  auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
49  factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2,
50  kernels::kernel_from_string.at("kb"), 4, 4);
    // The first element of the power-method result is stored on the operator as its norm.
    // NOTE(review): 1000 and 1e-5 are presumably the iteration cap and tolerance -- confirm
    // against the sopt::algorithm::power_method signature.
51  auto const power_method_stuff =
52  sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init);
53  const t_real op_norm = std::get<0>(power_method_stuff);
54  measurements_transform->set_norm(op_norm);
55 
    // Nine wavelet bases (Dirac and DB1..DB8), each paired with the value 3
    // (presumably the number of decomposition levels -- confirm).
56  std::vector<std::tuple<std::string, t_uint>> const sara{
57  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
58  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
59  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
60  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
61  factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
62  t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
    // NOTE(review): the trailing positional arguments are solver settings whose meaning is
    // defined by factory::padmm_factory; check its declaration before changing any of them.
63  auto const padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
64  factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, imsizey,
65  imsizex, sara.size(), 300, true, true, false, 1e-2, 1e-3, 50, 1);
66 
    // Run the solver and view its solution vector as an imsizey x imsizex image.
67  auto const diagnostic = (*padmm)();
68  const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
    // Disabled write; presumably used to regenerate the reference solution file.
69  // pfitsio::write2d(image.real(), expected_solution_path);
    // Record the first ten real values of expected, actual and their ratio for failure reports.
70  CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
71  CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
72  CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
73  CHECK(image.isApprox(solution, 1e-4));
74 
    // Residual image: adjoint of the operator applied to (data - operator * solution).
75  const Vector<t_complex> residuals = measurements_transform->adjoint() *
76  (uv_data.vis - ((*measurements_transform) * diagnostic.x));
77  const Image<t_complex> residual_image = Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
78  // pfitsio::write2d(residual_image.real(), expected_residual_path);
79  CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
80  CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
    // Only the real parts of the residual images are compared.
81  CHECK(residual_image.real().isApprox(residual.real(), 1e-4));
82 }
83 
// Regression test for factory::primaldual_factory.
// Runs the ImagingPrimalDual solver on the expected/primal_dual/ fixture and
// checks that the RMS difference (real part) between the reconstruction and
// the stored solution is at most 5% of the solution's peak absolute value.
84 TEST_CASE("primal_dual_factory") {
85  const std::string &test_dir = "expected/primal_dual/";
86  const std::string &input_data_path = data_filename(test_dir + "input_data.vis");
87  const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
88  const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");
    // Only referenced by the disabled write2d call further down.
89  const std::string &result_path = data_filename(test_dir + "pd_result.fits");
90 
91  const auto solution = pfitsio::read2d(expected_solution_path);
    // The residual is loaded but not compared anywhere in this test.
92  const auto residual = pfitsio::read2d(expected_residual_path);
93 
94  auto uv_data = utilities::read_visibility(input_data_path, false);
95  uv_data.units = utilities::vis_units::radians;
96  CAPTURE(uv_data.vis.head(5));
    // Stop immediately if the fixture does not hold the expected number of visibilities.
97  REQUIRE(uv_data.size() == 13107);
98 
99  t_uint const imsizey = 128;
100  t_uint const imsizex = 128;
101 
     // All-ones starting vector for the power method below.
102  Vector<t_complex> const init = Vector<t_complex>::Ones(imsizex * imsizey);
103  auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
104  factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2,
105  kernels::kernel_from_string.at("kb"), 4, 4);
     // The first element of the power-method result is stored on the operator as its norm.
106  auto const power_method_stuff =
107  sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init);
108  const t_real op_norm = std::get<0>(power_method_stuff);
109  measurements_transform->set_norm(op_norm);
110 
     // Wavelet dictionary: Dirac plus DB1..DB8, each paired with the value 3.
111  std::vector<std::tuple<std::string, t_uint>> const sara{
112  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
113  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
114  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
115  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
116  factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
117  t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
     // NOTE(review): the trailing positional arguments are solver settings defined by
     // factory::primaldual_factory; check its declaration before changing any of them.
118  auto const primaldual =
119  factory::primaldual_factory<sopt::algorithm::ImagingPrimalDual<t_complex>>(
120  factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma,
121  imsizey, imsizex, sara.size(), 1000, true, true, 1e-3, 1);
122 
123  auto const diagnostic = (*primaldual)();
124 
     // Image view of the solution; only consumed by the disabled write below.
125  const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
126  // pfitsio::write2d(image.real(), result_path);
127 
     // Peak absolute real value of the reference solution sets the error scale.
128  double brightness = solution.real().cwiseAbs().maxCoeff();
     // Mean squared error over all pixels, using the real part of the difference only.
129  double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
130  .real()
131  .squaredNorm() /
132  solution.size();
133  double rms = sqrt(mse);
134  CHECK(rms <= brightness * 5e-2);
135 }
136 
137 TEST_CASE("fb_factory") {
138  const std::string &test_dir = "expected/fb/";
139  const std::string &input_data_path = data_filename(test_dir + "input_data.vis");
140  const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
141  const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");
142  const std::string &result_path = data_filename(test_dir + "fb_result.fits");
143 
144  const auto solution = pfitsio::read2d(expected_solution_path);
145  const auto residual = pfitsio::read2d(expected_residual_path);
146 
147  auto uv_data = utilities::read_visibility(input_data_path, false);
148  uv_data.units = utilities::vis_units::radians;
149  CAPTURE(uv_data.vis.head(5));
150  REQUIRE(uv_data.size() == 13107);
151 
152  t_uint const imsizey = 128;
153  t_uint const imsizex = 128;
154 
155  Vector<t_complex> const init = Vector<t_complex>::Ones(imsizex * imsizey);
156  auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
157  factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2,
158  kernels::kernel_from_string.at("kb"), 4, 4);
159  auto const power_method_stuff =
160  sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init);
161  const t_real op_norm = std::get<0>(power_method_stuff);
162  measurements_transform->set_norm(op_norm);
163 
164  std::vector<std::tuple<std::string, t_uint>> const sara{
165  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
166  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
167  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
168  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
169  factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
170 
171  t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
172  t_real const beta = sigma * sigma;
173  t_real const gamma = 0.0001;
174 
175  auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
176  factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta,
177  gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50);
178 
179  auto const diagnostic = (*fb)();
180  const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
181  pfitsio::write2d(image.real(), result_path);
182  // pfitsio::write2d(residual_image.real(), expected_residual_path);
183 
184  double brightness = solution.real().cwiseAbs().maxCoeff();
185  double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
186  .real()
187  .squaredNorm() /
188  solution.size();
189  double rms = sqrt(mse);
190  CHECK(rms <= brightness * 5e-2);
191 }
192 
193 #ifdef PURIFY_H5
// Test of ImagingForwardBackward driven by a stochastic updater.
// Instead of a fixed measurement operator, the solver is constructed from a
// callback that returns a fresh sopt::IterationState built from a random
// subset of N visibilities; the result is compared with the expected/fb/
// reference solution.
194 TEST_CASE("fb_factory_stochastic") {
195  const std::string &test_dir = "expected/fb/";
196  const std::string &input_data_path = data_filename(test_dir + "input_data.vis");
197  const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
     // expected_residual_path is unused; result_path appears only in a disabled write below.
198  const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");
199  const std::string &result_path = data_filename(test_dir + "fb_result_stochastic.fits");
200 
     // This copy of the data is only used for the CAPTURE/REQUIRE sanity check;
     // the solver obtains its data through random_updater.
201  auto uv_data = utilities::read_visibility(input_data_path, false);
202  uv_data.units = utilities::vis_units::radians;
203  CAPTURE(uv_data.vis.head(5));
204  REQUIRE(uv_data.size() == 13107);
205 
206  t_uint const imsizey = 128;
207  t_uint const imsizex = 128;
208 
209  // This functor would be defined in Purify
     // Fixed seed so the sequence of random subsets is reproducible between runs.
210  std::mt19937 rng(0);
     // Number of visibilities in each random subset.
211  const size_t N = 1000;
     // rng, N and input_data_path are captured by reference; they are locals of this
     // test and are declared before fb, so they are still alive whenever fb invokes the callback.
212  std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater =
213  [&input_data_path, imsizex, imsizey, &rng, &N]() {
     // The visibility file is re-read on every invocation of the callback.
     // NOTE(review): the embedded numbering jumps from 214 to 216, so a statement may be
     // missing from this listing here -- compare with the original source file.
214  utilities::vis_params uv_data = utilities::read_visibility(input_data_path, false);
216 
217  // Get random subset
     // Fill indices with 0..size-1, shuffle, then take the first N entries.
218  std::vector<size_t> indices(uv_data.size());
219  size_t i = 0;
220  for (auto &x : indices) {
221  x = i++;
222  }
223 
224  std::shuffle(indices.begin(), indices.end(), rng);
225  Vector<t_real> u_fragment(N);
226  Vector<t_real> v_fragment(N);
227  Vector<t_real> w_fragment(N);
228  Vector<t_complex> vis_fragment(N);
229  Vector<t_complex> weights_fragment(N);
230  for (i = 0; i < N; i++) {
231  size_t j = indices[i];
232  u_fragment[i] = uv_data.u[j];
233  v_fragment[i] = uv_data.v[j];
234  w_fragment[i] = uv_data.w[j];
235  vis_fragment[i] = uv_data.vis[j];
236  weights_fragment[i] = uv_data.weights[j];
237  }
     // The fragment inherits units, ra, dec and average_frequency from the full data set.
238  utilities::vis_params uv_data_fragment(u_fragment, v_fragment, w_fragment, vis_fragment,
239  weights_fragment, uv_data.units, uv_data.ra,
240  uv_data.dec, uv_data.average_frequency);
241 
     // Measurement operator for the fragment, built with the same settings as the
     // other tests in this file.
242  auto phi = factory::measurement_operator_factory<Vector<t_complex>>(
243  factory::distributed_measurement_operator::serial, uv_data_fragment, imsizey, imsizex,
244  1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
245 
     // The first element of the power-method result is stored on the operator as its norm.
246  Vector<t_complex> const init = Vector<t_complex>::Ones(imsizex * imsizey);
247  auto const power_method_stuff =
248  sopt::algorithm::power_method<Vector<t_complex>>(*phi, 1000, 1e-5, init);
249  const t_real op_norm = std::get<0>(power_method_stuff);
250  phi->set_norm(op_norm);
251 
     // Hand the solver the fragment's visibilities together with the matching operator.
252  return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data_fragment.vis, phi);
253  };
254 
255  const auto solution = pfitsio::read2d(expected_solution_path);
256 
257  // wavelets
258  std::vector<std::tuple<std::string, t_uint>> const sara{
259  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
260  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
261  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
262  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
263  factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
264 
265  // algorithm
266  t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
267  t_real const beta = sigma * sigma;
268  t_real const gamma = 0.0001;
269 
     // The solver is configured by hand here rather than through factory::fb_factory.
     // Note that sigma already contains a factor sqrt(2) and is multiplied by sqrt(2) again below.
270  sopt::algorithm::ImagingForwardBackward<t_complex> fb(random_updater);
271  fb.itermax(1000)
272  .step_size(beta * sqrt(2))
273  .sigma(sigma * sqrt(2))
274  .regulariser_strength(gamma)
275  .relative_variation(1e-3)
276  .residual_tolerance(0)
277  .tight_frame(true);
278 
     // L1 proximal term using the wavelet dictionary as Psi, with positivity and
     // real-valued constraints switched on.
279  auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(false);
280  gp->l1_proximal_tolerance(1e-4)
281  .l1_proximal_nu(1)
282  .l1_proximal_itermax(50)
283  .l1_proximal_positivity_constraint(true)
284  .l1_proximal_real_constraint(true)
285  .Psi(*wavelets);
286  fb.g_function(gp);
287 
288  auto const diagnostic = fb();
289  const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
290  // pfitsio::write2d(image.real(), result_path);
291  // pfitsio::write2d(residual_image.real(), expected_residual_path);
292 
     // Flattened view of the reference solution for element-wise comparison.
293  auto soln_flat = Vector<t_complex>::Map(solution.data(), solution.size());
294  double brightness = soln_flat.real().cwiseAbs().maxCoeff();
295  double mse = (soln_flat - diagnostic.x).real().squaredNorm() / solution.size();
296  SOPT_HIGH_LOG("MSE = {}", mse);
     // NOTE(review): this compares the mean squared error itself with the brightness bound,
     // whereas the non-stochastic tests in this file compare its square root (rms) --
     // confirm whether the looser criterion is intentional.
297  CHECK(mse <= brightness * 5e-2);
298 }
299 #endif
300 
301 #ifdef PURIFY_ONNXRT
// Test of factory::fb_factory when a model file path (an .onnx file from
// purify::models_directory()) is passed after the usual solver settings.
// Passes when the RMS difference (real part) to the expected/fb/ reference
// solution is at most 5% of the solution's peak absolute value.
// NOTE(review): the embedded numbering of this listing skips 313 and 344, so two
// source lines are missing here; see the notes at the affected places.
302 TEST_CASE("tf_fb_factory") {
303  const std::string &test_dir = "expected/fb/";
304  const std::string &input_data_path = data_filename(test_dir + "input_data.vis");
305  const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
306  const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");
     // Only referenced by the disabled write2d call further down.
307  const std::string &result_path = data_filename(test_dir + "tf_result.fits");
308 
309  const auto solution = pfitsio::read2d(expected_solution_path);
     // The residual is loaded but not compared anywhere in this test.
310  const auto residual = pfitsio::read2d(expected_residual_path);
311 
     // NOTE(review): the other tests in this file set uv_data.units to radians right after
     // reading; the skipped line 313 presumably does the same -- confirm against the source.
312  auto uv_data = utilities::read_visibility(input_data_path, false);
314  CAPTURE(uv_data.vis.head(5));
315  REQUIRE(uv_data.size() == 13107);
316 
317  t_uint const imsizey = 128;
318  t_uint const imsizex = 128;
319 
     // All-ones starting vector for the power method below.
320  Vector<t_complex> const init = Vector<t_complex>::Ones(imsizex * imsizey);
321  auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
322  factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2,
323  kernels::kernel_from_string.at("kb"), 4, 4);
     // The first element of the power-method result is stored on the operator as its norm.
324  auto const power_method_stuff =
325  sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init);
326  const t_real op_norm = std::get<0>(power_method_stuff);
327  measurements_transform->set_norm(op_norm);
328 
     // Wavelet dictionary: Dirac plus DB1..DB8, each paired with the value 3.
329  std::vector<std::tuple<std::string, t_uint>> const sara{
330  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
331  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
332  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
333  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
334  factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
335  t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
336  t_real const beta = sigma * sigma;
337  t_real const gamma = 0.0001;
338 
339  std::string tf_model_path = purify::models_directory() + "/snr_15_model_dynamic.onnx";
340 
     // NOTE(review): this call is cut off after tf_model_path in the listing (line 344 is
     // missing, so the final argument(s) and the closing parenthesis are not shown).
     // Restore them from the original source file before compiling.
341  auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
342  factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta,
343  gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, tf_model_path,
345 
346  auto const diagnostic = (*fb)();
347  const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
348  // pfitsio::write2d(image.real(), result_path);
349  // pfitsio::write2d(residual_image.real(), expected_residual_path);
350 
     // Peak absolute real value of the reference solution sets the error scale.
351  double brightness = solution.real().cwiseAbs().maxCoeff();
     // Mean squared error over all pixels, using the real part of the difference only.
352  double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
353  .real()
354  .squaredNorm() /
355  solution.size();
356  double rms = sqrt(mse);
357  CHECK(rms <= brightness * 5e-2);
358 }
359 
// Test of factory::fb_factory with a differentiable function loaded from ONNX files.
// A sopt::ONNXDifferentiableFunc is built from two model files (named as a cost and a
// gradient model) and handed to the factory together with
// nondiff_func_type::RealIndicator. Passes when the RMS difference (real part) to the
// expected/fb/ reference solution is at most 5% of the solution's peak absolute value.
360 TEST_CASE("onnx_fb_factory") {
361  const std::string &test_dir = "expected/fb/";
362  const std::string &input_data_path = data_filename(test_dir + "input_data.vis");
363  const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
364  const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");
     // Only referenced by the disabled write2d call further down.
365  const std::string &result_path = data_filename(test_dir + "onnx_result.fits");
366  const auto solution = pfitsio::read2d(expected_solution_path);
     // The residual is loaded but not compared anywhere in this test.
367  const auto residual = pfitsio::read2d(expected_residual_path);
368 
369  auto uv_data = utilities::read_visibility(input_data_path, false);
370  uv_data.units = utilities::vis_units::radians;
371  CAPTURE(uv_data.vis.head(5));
     // Stop immediately if the fixture does not hold the expected number of visibilities.
372  REQUIRE(uv_data.size() == 13107);
373 
374  t_uint const imsizey = 128;
375  t_uint const imsizex = 128;
376 
     // All-ones starting vector for the power method below.
377  Vector<t_complex> const init = Vector<t_complex>::Ones(imsizex * imsizey);
378  auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
379  factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2,
380  kernels::kernel_from_string.at("kb"), 4, 4);
     // The first element of the power-method result is stored on the operator as its norm.
381  auto const power_method_stuff =
382  sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init);
383  const t_real op_norm = std::get<0>(power_method_stuff);
384  measurements_transform->set_norm(op_norm);
385 
     // Wavelet dictionary: Dirac plus DB1..DB8, each paired with the value 3.
386  std::vector<std::tuple<std::string, t_uint>> const sara{
387  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
388  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
389  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
390  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
391  factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
392  t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
393  t_real const beta = sigma * sigma;
394  t_real const gamma = 0.0001;
395 
     // Paths of the two ONNX models, both located in purify::models_directory().
396  std::string const prior_path =
397  purify::models_directory() + "/example_cost_dynamic_CRR_sigma_5_t_5.onnx";
398  std::string const prior_gradient_path =
399  purify::models_directory() + "/example_grad_dynamic_CRR_sigma_5_t_5.onnx";
     // NOTE(review): the meaning of the numeric arguments 20 and 5e4 is defined by the
     // sopt::ONNXDifferentiableFunc constructor -- check its declaration before changing them.
400  std::shared_ptr<sopt::ONNXDifferentiableFunc<t_complex>> diff_function =
401  std::make_shared<sopt::ONNXDifferentiableFunc<t_complex>>(
402  prior_path, prior_gradient_path, sigma, 20, 5e4, *measurements_transform);
403 
     // An empty string is passed in the model-path position; the differentiable function
     // is supplied as the last argument instead.
404  auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
405  factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta,
406  gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-3, 1e-3, 50, "",
407  nondiff_func_type::RealIndicator, diff_function);
408 
409  auto const diagnostic = (*fb)();
410  const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
411  // pfitsio::write2d(image.real(), result_path);
412  // pfitsio::write2d(residual_image.real(), expected_residual_path);
413 
     // Peak absolute real value of the reference solution sets the error scale.
414  double brightness = solution.real().cwiseAbs().maxCoeff();
     // Mean squared error over all pixels, using the real part of the difference only.
415  double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
416  .real()
417  .squaredNorm() /
418  solution.size();
419  double rms = sqrt(mse);
420  CHECK(rms <= brightness * 5e-2);
421 }
422 #endif
423 
// Test that wraps a forward-backward solver from factory::fb_factory in
// sopt::algorithm::JointMAP and runs it on the expected/joint_map/ fixture.
// All comparisons against the reference solution and residual are commented out,
// so the only active assertion is the REQUIRE on the number of visibilities;
// beyond that the test just exercises the pipeline end to end.
424 TEST_CASE("joint_map_factory") {
425  const std::string &test_dir = "expected/joint_map/";
426  const std::string &input_data_path = data_filename(test_dir + "input_data.vis");
427  const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
428  const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");
429 
     // Both reference images are still loaded although the checks using them are disabled.
430  const auto solution = pfitsio::read2d(expected_solution_path);
431  const auto residual = pfitsio::read2d(expected_residual_path);
432 
433  auto uv_data = utilities::read_visibility(input_data_path, false);
434  uv_data.units = utilities::vis_units::radians;
435  CAPTURE(uv_data.vis.head(5));
436  REQUIRE(uv_data.size() == 13107);
437 
438  t_uint const imsizey = 128;
439  t_uint const imsizex = 128;
440 
     // All-ones starting vector for the power method below.
441  Vector<t_complex> const init = Vector<t_complex>::Ones(imsizex * imsizey);
442  auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
443  factory::distributed_measurement_operator::serial, uv_data, imsizey, imsizex, 1, 1, 2,
444  kernels::kernel_from_string.at("kb"), 4, 4);
     // The first element of the power-method result is stored on the operator as its norm.
445  auto const power_method_stuff =
446  sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init);
447  const t_real op_norm = std::get<0>(power_method_stuff);
448  measurements_transform->set_norm(op_norm);
449 
     // Wavelet dictionary: Dirac plus DB1..DB8, each paired with the value 3.
450  std::vector<std::tuple<std::string, t_uint>> const sara{
451  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
452  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
453  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
454  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
455  factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);
456  t_real const sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
457  t_real const beta = sigma * sigma;
     // Unlike the plain forward-backward tests in this file, gamma starts at 1 here.
458  t_real const gamma = 1;
459  auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
460  factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta,
461  gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50);
     // Callback handed to JointMAP: l1 norm of the wavelet adjoint applied to x.
     // The wavelets handle is captured by value.
462  auto const l1_norm = [wavelets](const Vector<t_complex> &x) {
463  auto val = sopt::l1_norm(wavelets->adjoint() * x);
464  return val;
465  };
     // The third constructor argument is the pixel count times the number of wavelet bases.
466  auto const joint_map =
467  sopt::algorithm::JointMAP<sopt::algorithm::ImagingForwardBackward<t_complex>>(
468  fb, l1_norm, imsizex * imsizey * sara.size())
469  .relative_variation(1e-3)
470  .objective_variation(1e-3)
471  .beta(1.)
472  .alpha(1.);
473  auto const diagnostic = joint_map();
474  const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
475  // CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
476  // CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
477  // CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(),
478  // image.size()).real().head(10));
479  // CHECK(image.isApprox(solution, 1e-6));
480 
     // Residual image: adjoint of the operator applied to (data - operator * solution).
     // It is computed but, with the check below disabled, not compared to anything.
481  const Vector<t_complex> residuals = measurements_transform->adjoint() *
482  (uv_data.vis - ((*measurements_transform) * diagnostic.x));
483  const Image<t_complex> residual_image = Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
484  // CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
485  // CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
486  // CHECK(residual_image.real().isApprox(residual.real(), 1e-6));
487 }
TEST_CASE("padmm_factory")
Definition: algo_factory.cc:31
#define CHECK(CONDITION, ERROR)
Definition: casa.cc:6
const std::string test_dir
Definition: operators.cc:16
const std::map< std::string, kernel > kernel_from_string
Definition: kernels.h:16
void write2d(const Image< t_real > &eigen_image, const pfitsio::header_params &header, const bool &overwrite)
Write image to fits file.
Definition: pfitsio.cc:30
Image< t_complex > read2d(const std::string &fits_name)
Read image from fits file.
Definition: pfitsio.cc:109
std::function< bool()> random_updater(const sopt::mpi::Communicator &comm, const t_int total, const t_int update_size, const std::shared_ptr< bool > update_pointer, const std::string &update_name)
utilities::vis_params read_visibility(const std::vector< std::string > &names, const bool w_term)
Read visibility files from a vector of file names.
std::string models_directory()
Holds TF models.
std::string data_filename(std::string const &filename)
Holds data and such.
void padmm(const std::string &name, const Image< t_complex > &M31, const std::string &kernel, const t_int J, const utilities::vis_params &uv_data, const t_real sigma, const std::tuple< bool, t_real > &w_term)
Vector< t_complex > vis
Definition: uvw_utilities.h:22
t_uint size() const
return number of measurements
Definition: uvw_utilities.h:54
Vector< t_complex > weights
Definition: uvw_utilities.h:23