2 #include "catch2/catch_all.hpp"
4 #include "purify/config.h"
8 #include "purify/directories.h"
17 #include <sopt/onnx_differentiable_func.h>
24 #include <sopt/gradient_utils.h>
25 #include <sopt/power_method.h>
27 #include "purify/test_data.h"
// Fragment of the PADMM regression test. The listing is partial: `uv_data`,
// `solution` and `residual` are defined on lines that are not visible here.
32 const std::string &
test_dir =
"expected/padmm/";
42 CAPTURE(uv_data.vis.head(5));
// The test only proceeds when exactly 13107 measurements are present.
43 REQUIRE(uv_data.size() == 13107);
// Image dimensions used for every Map/reshape below.
45 t_uint
const imsizey = 128;
46 t_uint
const imsizex = 128;
// All-ones vector of imsizex * imsizey entries, passed to the power method below.
47 Vector<t_complex>
const init = Vector<t_complex>::Ones(imsizex * imsizey);
// NOTE(review): the factory arguments are on lines missing from this view.
48 auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
// Run the power method on the measurement operator (arguments 1000 and 1e-5),
// then store the first element of its result as the operator's norm.
51 auto const power_method_stuff =
52 sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init);
53 const t_real op_norm = std::get<0>(power_method_stuff);
54 measurements_transform->set_norm(op_norm);
// Nine (name, 3u) tuples: "Dirac" and "DB1" through "DB8".
// Presumably the second element is the wavelet level count -- TODO confirm.
56 std::vector<std::tuple<std::string, t_uint>>
const sara{
57 std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
58 std::make_tuple(
"DB3", 3u), std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
59 std::make_tuple(
"DB6", 3u), std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)};
// NOTE(review): the wavelet factory arguments are on lines missing from this view.
60 auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
62 t_real
const sigma = 0.016820222945913496 * std::sqrt(2);
// NOTE(review): the leading padmm_factory arguments are not visible; only the
// trailing ones (imsizex, sara.size(), 300, ...) appear below.
63 auto const padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
65 imsizex, sara.size(), 300,
true,
true,
false, 1e-2, 1e-3, 50, 1);
// Invoke the algorithm object and view its result vector as an imsizey x imsizex image.
67 auto const diagnostic = (*padmm)();
68 const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
// Record the first ten real parts of the expected image, the computed image and
// their element-wise ratio, so they are reported with the assertion below.
70 CAPTURE(Vector<t_complex>::Map(solution.data(), solution.size()).real().head(10));
71 CAPTURE(Vector<t_complex>::Map(image.data(), image.size()).real().head(10));
72 CAPTURE(Vector<t_complex>::Map((image / solution).eval().data(), image.size()).real().head(10));
73 CHECK(image.isApprox(solution, 1e-4));
// Residual: adjoint operator applied to (visibilities - operator * result).
75 const Vector<t_complex> residuals = measurements_transform->adjoint() *
76 (uv_data.vis - ((*measurements_transform) * diagnostic.x));
77 const Image<t_complex> residual_image = Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
79 CAPTURE(Vector<t_complex>::Map(residual.data(), residual.size()).real().head(10));
80 CAPTURE(Vector<t_complex>::Map(residuals.data(), residuals.size()).real().head(10));
// Only the real parts of the residual images are compared.
81 CHECK(residual_image.real().isApprox(residual.real(), 1e-4));
// Fragment of the primal-dual regression test. The listing is partial:
// `uv_data` and `solution` are defined on lines that are not visible here.
85 const std::string &
test_dir =
"expected/primal_dual/";
96 CAPTURE(uv_data.vis.head(5));
// The test only proceeds when exactly 13107 measurements are present.
97 REQUIRE(uv_data.size() == 13107);
// Image dimensions used for the Map/reshape below.
99 t_uint
const imsizey = 128;
100 t_uint
const imsizex = 128;
// All-ones vector of imsizex * imsizey entries, passed to the power method below.
102 Vector<t_complex>
const init = Vector<t_complex>::Ones(imsizex * imsizey);
// NOTE(review): the factory arguments are on lines missing from this view.
103 auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
// Run the power method on the measurement operator and store the first element
// of its result as the operator's norm.
106 auto const power_method_stuff =
107 sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init);
108 const t_real op_norm = std::get<0>(power_method_stuff);
109 measurements_transform->set_norm(op_norm);
// Nine (name, 3u) tuples: "Dirac" and "DB1" through "DB8".
111 std::vector<std::tuple<std::string, t_uint>>
const sara{
112 std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
113 std::make_tuple(
"DB3", 3u), std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
114 std::make_tuple(
"DB6", 3u), std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)};
// NOTE(review): the wavelet factory arguments are on lines missing from this view.
115 auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
117 t_real
const sigma = 0.016820222945913496 * std::sqrt(2);
// NOTE(review): the leading primaldual_factory arguments are not visible here.
118 auto const primaldual =
119 factory::primaldual_factory<sopt::algorithm::ImagingPrimalDual<t_complex>>(
121 imsizey, imsizex, sara.size(), 1000,
true,
true, 1e-3, 1);
// Invoke the algorithm object and view its result vector as an imsizey x imsizex image.
123 auto const diagnostic = (*primaldual)();
125 const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
// Largest absolute real value of the expected image; it scales the tolerance below.
128 double brightness = solution.real().cwiseAbs().maxCoeff();
// NOTE(review): the remainder of the mse expression is on lines missing from this view.
129 double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
133 double rms = sqrt(mse);
// Accept when rms is at most 5e-2 times the brightness.
134 CHECK(rms <= brightness * 5e-2);
// Fragment of a forward-backward regression test. The listing is partial:
// `uv_data` and `solution` are defined on lines that are not visible here.
138 const std::string &
test_dir =
"expected/fb/";
149 CAPTURE(uv_data.vis.head(5));
// The test only proceeds when exactly 13107 measurements are present.
150 REQUIRE(uv_data.size() == 13107);
// Image dimensions used for the Map/reshape below.
152 t_uint
const imsizey = 128;
153 t_uint
const imsizex = 128;
// All-ones vector of imsizex * imsizey entries, passed to the power method below.
155 Vector<t_complex>
const init = Vector<t_complex>::Ones(imsizex * imsizey);
// NOTE(review): the factory arguments are on lines missing from this view.
156 auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
// Run the power method on the measurement operator and store the first element
// of its result as the operator's norm.
159 auto const power_method_stuff =
160 sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init);
161 const t_real op_norm = std::get<0>(power_method_stuff);
162 measurements_transform->set_norm(op_norm);
// Nine (name, 3u) tuples: "Dirac" and "DB1" through "DB8".
164 std::vector<std::tuple<std::string, t_uint>>
const sara{
165 std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
166 std::make_tuple(
"DB3", 3u), std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
167 std::make_tuple(
"DB6", 3u), std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)};
// NOTE(review): the wavelet factory arguments are on lines missing from this view.
168 auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
// sigma, beta = sigma^2 and gamma; only gamma appears in the visible part of
// the fb_factory call below.
171 t_real
const sigma = 0.016820222945913496 * std::sqrt(2);
172 t_real
const beta = sigma * sigma;
173 t_real
const gamma = 0.0001;
// NOTE(review): the leading fb_factory arguments are not visible here.
175 auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
177 gamma, imsizey, imsizex, sara.size(), 1000,
true,
true,
false, 1e-2, 1e-3, 50);
// Invoke the algorithm object and view its result vector as an imsizey x imsizex image.
179 auto const diagnostic = (*fb)();
180 const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
// Largest absolute real value of the expected image; it scales the tolerance below.
184 double brightness = solution.real().cwiseAbs().maxCoeff();
// NOTE(review): the remainder of the mse expression is on lines missing from this view.
185 double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
189 double rms = sqrt(mse);
// Accept when rms is at most 5e-2 times the brightness.
190 CHECK(rms <= brightness * 5e-2);
// Fragment of a forward-backward test whose algorithm is constructed from a
// `random_updater` callback. The listing is partial: `uv_data`, `solution`,
// `rng`, `input_data_path`, `i` and `uv_data_fragment` are declared on lines
// that are not visible here.
195 const std::string &
test_dir =
"expected/fb/";
203 CAPTURE(uv_data.vis.head(5));
// The test only proceeds when exactly 13107 measurements are present.
204 REQUIRE(uv_data.size() == 13107);
206 t_uint
const imsizey = 128;
207 t_uint
const imsizex = 128;
// Number of measurements copied into each fragment built by the callback.
211 const size_t N = 1000;
// Callback returning a fresh IterationState. It captures `input_data_path`,
// `rng` and `N` by reference and the image sizes by value.
212 std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()>
random_updater =
213 [&input_data_path, imsizex, imsizey, &rng, &N]() {
// One index per measurement. NOTE(review): the loop body that fills `indices`
// is on a line missing from this view.
218 std::vector<size_t> indices(uv_data.
size());
220 for (
auto &x : indices) {
// Shuffle the indices; the first N of them select the entries copied below.
224 std::shuffle(indices.begin(), indices.end(), rng);
225 Vector<t_real> u_fragment(N);
226 Vector<t_real> v_fragment(N);
227 Vector<t_real> w_fragment(N);
228 Vector<t_complex> vis_fragment(N);
229 Vector<t_complex> weights_fragment(N);
// Copy u, v, w, vis and weights at each selected index into the fragment vectors.
230 for (i = 0; i < N; i++) {
231 size_t j = indices[i];
232 u_fragment[i] = uv_data.
u[j];
233 v_fragment[i] = uv_data.
v[j];
234 w_fragment[i] = uv_data.
w[j];
235 vis_fragment[i] = uv_data.
vis[j];
236 weights_fragment[i] = uv_data.
weights[j];
// NOTE(review): these are trailing arguments of a call whose start is not
// visible; presumably it constructs `uv_data_fragment` -- confirm.
239 weights_fragment, uv_data.
units, uv_data.
ra,
// NOTE(review): the factory arguments are on lines missing from this view.
242 auto phi = factory::measurement_operator_factory<Vector<t_complex>>(
// Run the power method on `phi` from an all-ones vector and store the first
// element of its result as the operator's norm.
246 Vector<t_complex>
const init = Vector<t_complex>::Ones(imsizex * imsizey);
247 auto const power_method_stuff =
248 sopt::algorithm::power_method<Vector<t_complex>>(*phi, 1000, 1e-5, init);
249 const t_real op_norm = std::get<0>(power_method_stuff);
250 phi->set_norm(op_norm);
// The returned state is built from the fragment's visibilities and its operator.
252 return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data_fragment.vis, phi);
// Nine (name, 3u) tuples: "Dirac" and "DB1" through "DB8".
258 std::vector<std::tuple<std::string, t_uint>>
const sara{
259 std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
260 std::make_tuple(
"DB3", 3u), std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
261 std::make_tuple(
"DB6", 3u), std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)};
// NOTE(review): the wavelet factory arguments are on lines missing from this view.
262 auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
266 t_real
const sigma = 0.016820222945913496 * std::sqrt(2);
267 t_real
const beta = sigma * sigma;
268 t_real
const gamma = 0.0001;
// The algorithm is constructed directly from the callback (no factory here).
// NOTE(review): the start of the setter chain below is on a line missing from this view.
270 sopt::algorithm::ImagingForwardBackward<t_complex> fb(
random_updater);
272 .step_size(beta * sqrt(2))
273 .sigma(sigma * sqrt(2))
274 .regulariser_strength(gamma)
275 .relative_variation(1e-3)
276 .residual_tolerance(0)
// L1 proximal object configured with tolerance 1e-4, at most 50 iterations,
// and the positivity and real constraints both set to true.
279 auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(
false);
280 gp->l1_proximal_tolerance(1e-4)
282 .l1_proximal_itermax(50)
283 .l1_proximal_positivity_constraint(
true)
284 .l1_proximal_real_constraint(
true)
// Invoke the algorithm and view its result vector as an imsizey x imsizex image.
288 auto const diagnostic = fb();
289 const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
// Flattened view of the expected image, used for both brightness and error.
293 auto soln_flat = Vector<t_complex>::Map(solution.data(), solution.size());
294 double brightness = soln_flat.real().cwiseAbs().maxCoeff();
// Mean squared error over the real part of the difference.
295 double mse = (soln_flat - diagnostic.x).real().squaredNorm() / solution.size();
296 SOPT_HIGH_LOG(
"MSE = {}", mse);
// NOTE(review): this compares mse, not sqrt(mse), against brightness * 5e-2,
// unlike the other forward-backward fragments in this listing, which compare
// rms. Confirm whether that is intentional.
297 CHECK(mse <= brightness * 5e-2);
// Fragment of a forward-backward test that passes `tf_model_path` to the
// factory. The listing is partial: `uv_data`, `solution` and `tf_model_path`
// are defined on lines that are not visible here.
303 const std::string &
test_dir =
"expected/fb/";
314 CAPTURE(uv_data.
vis.head(5));
// The test only proceeds when exactly 13107 measurements are present.
315 REQUIRE(uv_data.
size() == 13107);
// Image dimensions used for the Map/reshape below.
317 t_uint
const imsizey = 128;
318 t_uint
const imsizex = 128;
// All-ones vector of imsizex * imsizey entries, passed to the power method below.
320 Vector<t_complex>
const init = Vector<t_complex>::Ones(imsizex * imsizey);
// NOTE(review): the factory arguments are on lines missing from this view.
321 auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
// Run the power method on the measurement operator and store the first element
// of its result as the operator's norm.
324 auto const power_method_stuff =
325 sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init);
326 const t_real op_norm = std::get<0>(power_method_stuff);
327 measurements_transform->set_norm(op_norm);
// Nine (name, 3u) tuples: "Dirac" and "DB1" through "DB8".
329 std::vector<std::tuple<std::string, t_uint>>
const sara{
330 std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
331 std::make_tuple(
"DB3", 3u), std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
332 std::make_tuple(
"DB6", 3u), std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)};
// NOTE(review): the wavelet factory arguments are on lines missing from this view.
333 auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
335 t_real
const sigma = 0.016820222945913496 * std::sqrt(2);
336 t_real
const beta = sigma * sigma;
337 t_real
const gamma = 0.0001;
// NOTE(review): both the leading arguments and the arguments after
// `tf_model_path` are on lines missing from this view.
341 auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
343 gamma, imsizey, imsizex, sara.size(), 1000,
true,
true,
false, 1e-2, 1e-3, 50, tf_model_path,
// Invoke the algorithm object and view its result vector as an imsizey x imsizex image.
346 auto const diagnostic = (*fb)();
347 const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
// Largest absolute real value of the expected image; it scales the tolerance below.
351 double brightness = solution.real().cwiseAbs().maxCoeff();
// NOTE(review): the remainder of the mse expression is on lines missing from this view.
352 double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
356 double rms = sqrt(mse);
// Accept when rms is at most 5e-2 times the brightness.
357 CHECK(rms <= brightness * 5e-2);
// Fragment of a forward-backward test that builds an ONNXDifferentiableFunc.
// The listing is partial: `uv_data` and `solution` are defined on lines that
// are not visible here.
361 const std::string &
test_dir =
"expected/fb/";
371 CAPTURE(uv_data.vis.head(5));
// The test only proceeds when exactly 13107 measurements are present.
372 REQUIRE(uv_data.size() == 13107);
// Image dimensions used for the Map/reshape below.
374 t_uint
const imsizey = 128;
375 t_uint
const imsizex = 128;
// All-ones vector of imsizex * imsizey entries, passed to the power method below.
377 Vector<t_complex>
const init = Vector<t_complex>::Ones(imsizex * imsizey);
// NOTE(review): the factory arguments are on lines missing from this view.
378 auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
// Run the power method on the measurement operator and store the first element
// of its result as the operator's norm.
381 auto const power_method_stuff =
382 sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init);
383 const t_real op_norm = std::get<0>(power_method_stuff);
384 measurements_transform->set_norm(op_norm);
// Nine (name, 3u) tuples: "Dirac" and "DB1" through "DB8".
386 std::vector<std::tuple<std::string, t_uint>>
const sara{
387 std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
388 std::make_tuple(
"DB3", 3u), std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
389 std::make_tuple(
"DB6", 3u), std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)};
// NOTE(review): the wavelet factory arguments are on lines missing from this view.
390 auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
392 t_real
const sigma = 0.016820222945913496 * std::sqrt(2);
393 t_real
const beta = sigma * sigma;
394 t_real
const gamma = 0.0001;
// NOTE(review): the initialisers of both path strings are on lines missing
// from this view.
396 std::string
const prior_path =
398 std::string
const prior_gradient_path =
// Differentiable function constructed from the two paths, sigma, the constants
// 20 and 5e4, and the measurement operator.
// NOTE(review): the meaning of 20 and 5e4 is not shown here -- confirm against
// the ONNXDifferentiableFunc constructor.
400 std::shared_ptr<sopt::ONNXDifferentiableFunc<t_complex>> diff_function =
401 std::make_shared<sopt::ONNXDifferentiableFunc<t_complex>>(
402 prior_path, prior_gradient_path, sigma, 20, 5e4, *measurements_transform);
// NOTE(review): the leading arguments and the arguments after the empty string
// are on lines missing from this view.
404 auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
406 gamma, imsizey, imsizex, sara.size(), 1000,
true,
true,
false, 1e-3, 1e-3, 50,
"",
// Invoke the algorithm object and view its result vector as an imsizey x imsizex image.
409 auto const diagnostic = (*fb)();
410 const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
// Largest absolute real value of the expected image; it scales the tolerance below.
414 double brightness = solution.real().cwiseAbs().maxCoeff();
// NOTE(review): the remainder of the mse expression is on lines missing from this view.
415 double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
419 double rms = sqrt(mse);
// Accept when rms is at most 5e-2 times the brightness.
420 CHECK(rms <= brightness * 5e-2);
// Fragment of the joint MAP test, which wraps a forward-backward algorithm in
// sopt::algorithm::JointMAP. The listing is partial: `uv_data` is defined on
// lines that are not visible here, and no assertion on the result is visible.
425 const std::string &
test_dir =
"expected/joint_map/";
435 CAPTURE(uv_data.vis.head(5));
// The test only proceeds when exactly 13107 measurements are present.
436 REQUIRE(uv_data.size() == 13107);
// Image dimensions used for the Map/reshape below.
438 t_uint
const imsizey = 128;
439 t_uint
const imsizex = 128;
// All-ones vector of imsizex * imsizey entries, passed to the power method below.
441 Vector<t_complex>
const init = Vector<t_complex>::Ones(imsizex * imsizey);
// NOTE(review): the factory arguments are on lines missing from this view.
442 auto measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
// Run the power method on the measurement operator and store the first element
// of its result as the operator's norm.
445 auto const power_method_stuff =
446 sopt::algorithm::power_method<Vector<t_complex>>(*measurements_transform, 1000, 1e-5, init);
447 const t_real op_norm = std::get<0>(power_method_stuff);
448 measurements_transform->set_norm(op_norm);
// Nine (name, 3u) tuples: "Dirac" and "DB1" through "DB8".
450 std::vector<std::tuple<std::string, t_uint>>
const sara{
451 std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
452 std::make_tuple(
"DB3", 3u), std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
453 std::make_tuple(
"DB6", 3u), std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)};
// NOTE(review): the wavelet factory arguments are on lines missing from this view.
454 auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
456 t_real
const sigma = 0.016820222945913496 * std::sqrt(2);
457 t_real
const beta = sigma * sigma;
// gamma is 1 here, whereas the other forward-backward fragments in this
// listing use 0.0001.
458 t_real
const gamma = 1;
// NOTE(review): the leading fb_factory arguments are not visible here.
459 auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
461 gamma, imsizey, imsizex, sara.size(), 1000,
true,
true,
false, 1e-2, 1e-3, 50);
// Callback computing the l1 norm of the adjoint wavelet operator applied to x.
// `wavelets` is captured by value. NOTE(review): the rest of the lambda body
// is on lines missing from this view.
462 auto const l1_norm = [wavelets](
const Vector<t_complex> &x) {
463 auto val = sopt::l1_norm(wavelets->adjoint() * x);
// JointMAP is built from the forward-backward algorithm, the l1 callback and
// imsizex * imsizey * sara.size(); both variation tolerances are set to 1e-3.
// NOTE(review): any further setters in the chain are not visible here.
466 auto const joint_map =
467 sopt::algorithm::JointMAP<sopt::algorithm::ImagingForwardBackward<t_complex>>(
468 fb, l1_norm, imsizex * imsizey * sara.size())
469 .relative_variation(1e-3)
470 .objective_variation(1e-3)
// Invoke the wrapper and view its result vector as an imsizey x imsizex image.
473 auto const diagnostic = joint_map();
474 const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
// Residual: adjoint operator applied to (visibilities - operator * result).
481 const Vector<t_complex> residuals = measurements_transform->adjoint() *
482 (uv_data.vis - ((*measurements_transform) * diagnostic.x));
483 const Image<t_complex> residual_image = Image<t_complex>::Map(residuals.data(), imsizey, imsizex);
TEST_CASE("padmm_factory")
#define CHECK(CONDITION, ERROR)
const std::string test_dir
const std::map< std::string, kernel > kernel_from_string
void write2d(const Image< t_real > &eigen_image, const pfitsio::header_params &header, const bool &overwrite)
Write image to fits file.
Image< t_complex > read2d(const std::string &fits_name)
Read image from fits file.
std::function< bool()> random_updater(const sopt::mpi::Communicator &comm, const t_int total, const t_int update_size, const std::shared_ptr< bool > update_pointer, const std::string &update_name)
utilities::vis_params read_visibility(const std::vector< std::string > &names, const bool w_term)
Read visibility files from a vector of file names.
std::string models_directory()
Holds TF models.
std::string data_filename(std::string const &filename)
Holds data and such.
void padmm(const std::string &name, const Image< t_complex > &M31, const std::string &kernel, const t_int J, const utilities::vis_params &uv_data, const t_real sigma, const std::tuple< bool, t_real > &w_term)
t_uint size() const
return number of measurements
Vector< t_complex > weights