PURIFY
Next-generation radio interferometric imaging
algorithm_factory.h
Go to the documentation of this file.
1 #ifndef ALGORITHM_FACTORY_H
2 #define ALGORITHM_FACTORY_H
3 
4 #include "purify/config.h"
5 
6 #include "purify/types.h"
8 #include "purify/logging.h"
10 #include "purify/utilities.h"
11 #include "purify/uvw_utilities.h"
12 
13 #ifdef PURIFY_MPI
14 #include "purify/mpi_utilities.h"
15 #include <sopt/mpi/communicator.h>
16 #endif
17 
18 #include <sopt/differentiable_func.h>
19 #include <sopt/imaging_forward_backward.h>
20 #include <sopt/imaging_padmm.h>
21 #include <sopt/imaging_primal_dual.h>
22 #include <sopt/joint_map.h>
23 #include <sopt/l1_non_diff_function.h>
24 #include <sopt/non_differentiable_func.h>
25 #include <sopt/real_indicator.h>
26 #include <sopt/relative_variation.h>
27 #include <sopt/utilities.h>
28 #include <sopt/wavelets.h>
29 #include <sopt/wavelets/sara.h>
30 #ifdef PURIFY_ONNXRT
31 #include <sopt/tf_non_diff_function.h>
32 #endif
33 
34 namespace purify {
35 namespace factory {
38 const std::map<std::string, algo_distribution> algo_distribution_string = {
39  {"none", algo_distribution::serial},
40  {"serial-equivalent", algo_distribution::mpi_serial},
41  {"random-updates", algo_distribution::mpi_random_updates},
42  {"fully-distributed", algo_distribution::mpi_distributed}};
43 
//! Forward declaration of the generic dispatcher defined near the end of this header.
//! Returns a shared pointer to the algorithm selected by `algo`; `args` are forwarded
//! unchanged to the factory that builds it.
45 template <class Algorithm, class... ARGS>
46 std::shared_ptr<Algorithm> algorithm_factory(const factory::algorithm algo, ARGS &&...args);
//! Build and configure a proximal ADMM imaging solver and return it as a shared pointer.
//!
//! The overload only exists when Algorithm is sopt::algorithm::ImagingProximalADMM<t_complex>
//! (enforced through std::enable_if on the return type).
//!
//! NOTE(review): this listing has dropped several source lines (the listing numbers jump
//! 51->53, 83->86, 87->89, 100->102 and 111->113). The function name with its first parameter,
//! the declarations of obj_conv/rel_conv and the case labels of the `dist` switch are therefore
//! not visible here; the body below reads `dist`, `obj_conv` and `rel_conv` as if they exist.
//!
//! Visible parameters:
//!  - measurements / wavelets: linear operators handed to the solver as Phi and Psi.
//!  - uv_data: visibility data; uv_data.vis is the solver target and uv_data.size() feeds epsilon.
//!  - sigma: scales the l2-ball radius epsilon.
//!  - imsizey / imsizex: not referenced anywhere in the visible body.
//!  - sara_size: number of wavelet bases; more than one is rejected together with tight_frame.
//!  - max_iterations, relative_variation, tight_frame, l1_proximal_tolerance,
//!    maximum_proximal_iterations, real_constraint, positive_constraint: passed straight to the
//!    solver setters of matching name.
//!  - residual_tolerance_scaling: the residual tolerance is epsilon times this factor.
//!
//! Throws std::runtime_error when sara_size > 1 is combined with tight_frame, or when `dist`
//! reaches the default branch of the switch.
48 template <class Algorithm>
49 typename std::enable_if<
50  std::is_same<Algorithm, sopt::algorithm::ImagingProximalADMM<t_complex>>::value,
51  std::shared_ptr<Algorithm>>::type
53  std::shared_ptr<sopt::LinearTransform<Vector<typename Algorithm::Scalar>> const> const
54  &measurements,
55  std::shared_ptr<sopt::LinearTransform<Vector<typename Algorithm::Scalar>> const> const
56  &wavelets,
57  const utilities::vis_params &uv_data, const t_real sigma, const t_uint imsizey,
58  const t_uint imsizex, const t_uint sara_size, const t_uint max_iterations = 500,
59  const bool real_constraint = true, const bool positive_constraint = true,
60  const bool tight_frame = false, const t_real relative_variation = 1e-3,
61  const t_real l1_proximal_tolerance = 1e-2,
62  const t_uint maximum_proximal_iterations = 50,
63  const t_real residual_tolerance_scaling = 1) {
64  typedef typename Algorithm::Scalar t_scalar;
// A tight frame is only accepted with a single wavelet basis.
65  if (sara_size > 1 and tight_frame)
66  throw std::runtime_error(
67  "l1 proximal not consistent: You say you are using a tight frame, but you have more than "
68  "one wavelet basis.");
69  PURIFY_INFO("Constructing PADMM algorithm");
// Local l2-ball radius: sqrt(2M + 2 sqrt(4M)) * sigma with M = uv_data.size().
// The MPI branches below overwrite this value.
70  auto epsilon = std::sqrt(2 * uv_data.size() + 2 * std::sqrt(4 * uv_data.size())) * sigma;
71  auto padmm = std::make_shared<Algorithm>(uv_data.vis);
// Settings shared by every distribution mode.
72  padmm->itermax(max_iterations)
73  .relative_variation(relative_variation)
74  .tight_frame(tight_frame)
75  .l1_proximal_tolerance(l1_proximal_tolerance)
76  .l1_proximal_nu(1)
77  .l1_proximal_itermax(maximum_proximal_iterations)
78  .l1_proximal_positivity_constraint(positive_constraint)
79  .l1_proximal_real_constraint(real_constraint)
80  .lagrange_update_scale(0.9)
81  .Psi(*wavelets)
82  .Phi(*measurements);
// NOTE(review): the two lines inside this #ifdef are missing from the listing; presumably they
// declare obj_conv and rel_conv (of type ConvergenceType), which are assigned further down.
83 #ifdef PURIFY_MPI
86 #endif
87  switch (dist) {
// First branch (case label not visible; presumably the serial case): regulariser strength is
// 1e-3 * max|Psi^H Phi^H y|, tolerances come from the local epsilon, and the function returns
// here, so nothing after the switch runs on this path. The `break` after `return` is unreachable.
89  padmm
90  ->regulariser_strength(
91  (wavelets->adjoint() * (measurements->adjoint() * uv_data.vis).eval())
92  .cwiseAbs()
93  .maxCoeff() *
94  1e-3)
95  .l2ball_proximal_epsilon(epsilon)
96  .residual_tolerance(epsilon * residual_tolerance_scaling);
97  return padmm;
98  break;
99 
100 #ifdef PURIFY_MPI
// Branch with global convergence checks (case label not visible): epsilon is recomputed from
// comm.all_sum_all(uv_data.size()) - presumably the measurement count summed over all ranks.
102  obj_conv = ConvergenceType::mpi_global;
103  rel_conv = ConvergenceType::mpi_global;
104  auto const comm = sopt::mpi::Communicator::World();
105  epsilon = std::sqrt(2 * comm.all_sum_all(uv_data.size()) +
106  2 * std::sqrt(4 * comm.all_sum_all(uv_data.size()))) *
107  sigma;
108  // communicator ensuring l2 norm in l2ball proximal is global
109  padmm->l2ball_proximal_communicator(comm);
110  } break;
111 
// Branch with local convergence checks (case label not visible): epsilon is
// sqrt(2M + 2 sqrt(4M) * sqrt(all_sum_all(4M)) / all_sum_all(sqrt(4M))) * sigma, M = uv_data.size().
// No l2-ball communicator is set in this branch.
113  obj_conv = ConvergenceType::mpi_local;
114  rel_conv = ConvergenceType::mpi_local;
115  auto const comm = sopt::mpi::Communicator::World();
116  epsilon = std::sqrt(2 * uv_data.size() + 2 * std::sqrt(4 * uv_data.size()) *
117  std::sqrt(comm.all_sum_all(4 * uv_data.size())) /
118  comm.all_sum_all(std::sqrt(4 * uv_data.size()))) *
119  sigma;
120  } break;
121 
122 #endif
123  default:
124  throw std::runtime_error(
125  "Type of distributed proximal ADMM algorithm not recognised. You might not have compiled "
126  "with MPI.");
127  }
// Reached only by the MPI branches above (the first branch returns, the default throws).
128 #ifdef PURIFY_MPI
129  auto const comm = sopt::mpi::Communicator::World();
// The convergence factories receive a weak_ptr to the solver rather than the shared_ptr itself.
130  std::weak_ptr<Algorithm> const padmm_weak(padmm);
131  // set epsilon
132  padmm->residual_tolerance(epsilon * residual_tolerance_scaling).l2ball_proximal_epsilon(epsilon);
133  // set regulariser_strength
// utilities::step_size(...) * 1e-3, reduced over the communicator with MPI_MAX.
134  padmm->regulariser_strength(comm.all_reduce(
135  utilities::step_size<Vector<t_complex>>(uv_data.vis, measurements, wavelets, sara_size) *
136  1e-3,
137  MPI_MAX));
138  // communicator ensuring l1 norm in l1 proximal is global
139  padmm->l1_proximal_adjoint_space_comm(comm);
// Residual and objective convergence tests built from the mode chosen in the switch.
140  padmm->residual_convergence(
141  purify::factory::l2_convergence_factory<typename Algorithm::Scalar>(rel_conv, padmm_weak));
142  padmm->objective_convergence(
143  purify::factory::l1_convergence_factory<typename Algorithm::Scalar>(obj_conv, padmm_weak));
144 #endif
145  return padmm;
146 }
147 
//! Build and configure a forward-backward imaging solver and return it as a shared pointer.
//!
//! The overload only exists when Algorithm is sopt::algorithm::ImagingForwardBackward<t_complex>
//! (enforced through std::enable_if on the return type).
//!
//! NOTE(review): this listing has dropped several source lines (the listing numbers jump
//! 152->154, 205->207, 216->218 and 232->234). The function name with its first parameter and
//! three case labels are therefore not visible; the body reads `dist` as if it were a parameter.
//!
//! Visible parameters:
//!  - measurements: linear operator handed to the solver as Phi.
//!  - wavelets: only used as Psi of the L1 proximal operator.
//!  - uv_data: uv_data.vis is the solver target.
//!  - sigma, step_size: both are multiplied by sqrt(2) before being passed to the solver.
//!  - reg_parameter: regulariser strength.
//!  - imsizey / imsizex: not referenced anywhere in the visible body.
//!  - sara_size: number of wavelet bases; more than one is rejected together with tight_frame.
//!  - real_constraint, positive_constraint, l1_proximal_tolerance, maximum_proximal_iterations:
//!    only used when g_proximal is nondiff_func_type::L1Norm.
//!  - model_path: only used to construct the TFGProximal operator (PURIFY_ONNXRT builds).
//!  - g_proximal: selects the non-differentiable term g.
//!  - f_function: optional differentiable term; the solver default is kept when it is null.
//!
//! Throws std::runtime_error when sara_size > 1 is combined with tight_frame, when g_proximal or
//! `dist` reaches a default branch, or when the TFGProximal branch is taken without PURIFY_ONNXRT.
149 template <class Algorithm>
150 typename std::enable_if<
151  std::is_same<Algorithm, sopt::algorithm::ImagingForwardBackward<t_complex>>::value,
152  std::shared_ptr<Algorithm>>::type
154  std::shared_ptr<sopt::LinearTransform<Vector<typename Algorithm::Scalar>> const> const
155  &measurements,
156  std::shared_ptr<sopt::LinearTransform<Vector<typename Algorithm::Scalar>> const> const
157  &wavelets,
158  const utilities::vis_params &uv_data, const t_real sigma, const t_real step_size,
159  const t_real reg_parameter, const t_uint imsizey, const t_uint imsizex,
160  const t_uint sara_size, const t_uint max_iterations = 500,
161  const bool real_constraint = true, const bool positive_constraint = true,
162  const bool tight_frame = false, const t_real relative_variation = 1e-3,
163  const t_real l1_proximal_tolerance = 1e-2, const t_uint maximum_proximal_iterations = 50,
164  const std::string model_path = "",
165  const nondiff_func_type g_proximal = nondiff_func_type::L1Norm,
166  std::shared_ptr<DifferentiableFunc<typename Algorithm::Scalar>> f_function = nullptr) {
167  typedef typename Algorithm::Scalar t_scalar;
// A tight frame is only accepted with a single wavelet basis.
168  if (sara_size > 1 and tight_frame)
169  throw std::runtime_error(
170  "l1 proximal not consistent: You say you are using a tight frame, but you have more than "
171  "one wavelet basis.");
172  PURIFY_INFO("Constructing Forward Backward algorithm");
173  auto fb = std::make_shared<Algorithm>(uv_data.vis);
// NOTE(review): the reason for the sqrt(2) factor on sigma and step_size is not stated here.
174  fb->itermax(max_iterations)
175  .regulariser_strength(reg_parameter)
176  .sigma(sigma * std::sqrt(2))
177  .step_size(step_size * std::sqrt(2))
178  .relative_variation(relative_variation)
179  .tight_frame(tight_frame)
180  .Phi(*measurements);
181 
182  if (f_function) fb->f_function(f_function); // only override f_function default if non-null
// g is filled in by exactly one branch of the switch below, then installed on the solver.
183  std::shared_ptr<NonDifferentiableFunc<t_scalar>> g;
184 
185  switch (g_proximal) {
186  case (nondiff_func_type::L1Norm): {
187  // Create a shared pointer to an instance of the L1GProximal class
188  // and set its properties
189  auto l1_gp = std::make_shared<sopt::algorithm::L1GProximal<t_scalar>>(false);
190  l1_gp->l1_proximal_tolerance(l1_proximal_tolerance)
191  .l1_proximal_nu(1.)
192  .l1_proximal_itermax(maximum_proximal_iterations)
193  .l1_proximal_positivity_constraint(positive_constraint)
194  .l1_proximal_real_constraint(real_constraint)
195  .Psi(*wavelets);
196 #ifdef PURIFY_MPI
// Only the mpi_serial mode attaches the world communicator to the L1 proximal operator.
197  if (dist == algo_distribution::mpi_serial) {
198  auto const comm = sopt::mpi::Communicator::World();
199  l1_gp->l1_proximal_adjoint_space_comm(comm);
200  l1_gp->l1_proximal_direct_space_comm(comm);
201  }
202 #endif
203  g = l1_gp;
204  break;
205  }
// Learned-proximal branch (case label not visible in this listing): available only when built
// with PURIFY_ONNXRT, otherwise it throws.
207 #ifdef PURIFY_ONNXRT
208  // Create a shared pointer to an instance of the TFGProximal class
209  g = std::make_shared<sopt::algorithm::TFGProximal<t_scalar>>(model_path);
210  break;
211 #else
212  throw std::runtime_error(
213  "Type TFGProximal not recognized because purify was built with onnxrt=off");
214 #endif
215  }
216 
// Real-indicator branch (case label not visible in this listing).
218  g = std::make_shared<sopt::algorithm::RealIndicator<t_scalar>>();
219  break;
220  }
221  default: {
222  throw std::runtime_error("Type of g_proximal operator not recognised.");
223  }
224  }
225 
226  fb->g_function(g);
227 
// Distribution-specific wiring; serial needs none.
228  switch (dist) {
229  case (algo_distribution::serial): {
230  break;
231  }
232 #ifdef PURIFY_MPI
// MPI branch (case label not visible in this listing): the world communicator is used for both
// the adjoint space and the objective.
234  auto const comm = sopt::mpi::Communicator::World();
235  fb->adjoint_space_comm(comm);
236  fb->obj_comm(comm);
237  break;
238  }
239 #endif
240  default:
241  throw std::runtime_error(
242  "Type of distributed Forward Backward algorithm not recognised. You might not have "
243  "compiled "
244  "with MPI.");
245  }
246  return fb;
247 }
248 
//! Build and configure a primal-dual imaging solver and return it as a shared pointer.
//!
//! The overload only exists when Algorithm is sopt::algorithm::ImagingPrimalDual<t_complex>
//! (enforced through std::enable_if on the return type).
//!
//! NOTE(review): this listing has dropped several source lines (the listing numbers jump
//! 253->255, 277->280, 293->295, 303->305 and 312->314). The line carrying the function name,
//! the declarations of obj_conv/rel_conv and the case labels of the three MPI branches are
//! therefore not visible here.
//!
//! Parameters:
//!  - dist: distribution strategy; selects the branch of the switch below.
//!  - measurements / wavelets: linear operators handed to the solver as Phi and Psi.
//!  - uv_data: visibility data; uv_data.vis is the solver target and uv_data.size() feeds epsilon.
//!  - sigma: scales the l2-ball radius epsilon.
//!  - imsizey / imsizex: not referenced anywhere in the visible body.
//!  - sara_size: used for the regulariser strength and the wavelet random updater (MPI only).
//!  - max_iterations, real_constraint, positive_constraint, relative_variation: passed straight
//!    to the solver setters.
//!  - residual_tolerance_scaling: the residual tolerance is epsilon times this factor.
//!
//! Throws std::runtime_error when `dist` reaches the default branch of the switch.
250 template <class Algorithm>
251 typename std::enable_if<
252  std::is_same<Algorithm, sopt::algorithm::ImagingPrimalDual<t_complex>>::value,
253  std::shared_ptr<Algorithm>>::type
255  const algo_distribution dist,
256  std::shared_ptr<sopt::LinearTransform<Vector<typename Algorithm::Scalar>> const> const
257  &measurements,
258  std::shared_ptr<sopt::LinearTransform<Vector<typename Algorithm::Scalar>> const> const
259  &wavelets,
260  const utilities::vis_params &uv_data, const t_real sigma, const t_uint imsizey,
261  const t_uint imsizex, const t_uint sara_size, const t_uint max_iterations = 500,
262  const bool real_constraint = true, const bool positive_constraint = true,
263  const t_real relative_variation = 1e-3, const t_real residual_tolerance_scaling = 1) {
264  typedef typename Algorithm::Scalar t_scalar;
265  PURIFY_INFO("Constructing Primal Dual algorithm");
// Local l2-ball radius: sqrt(2M + 2 sqrt(4M)) * sigma with M = uv_data.size().
// The MPI branches below overwrite this value.
266  auto epsilon = std::sqrt(2 * uv_data.size() + 2 * std::sqrt(4 * uv_data.size())) * sigma;
267  auto primaldual = std::make_shared<Algorithm>(uv_data.vis);
// Settings shared by every distribution mode; tau is 0.5 / (measurements->sq_norm() + 1).
268  primaldual->itermax(max_iterations)
269  .relative_variation(relative_variation)
270  .real_constraint(real_constraint)
271  .positivity_constraint(positive_constraint)
272  .Psi(*wavelets)
273  .Phi(*measurements)
274  .tau(0.5 / (measurements->sq_norm() + 1))
275  .xi(1.)
276  .sigma(1.);
// NOTE(review): the two lines inside this #ifdef are missing from the listing; presumably they
// declare obj_conv and rel_conv (of type ConvergenceType), which are assigned further down.
277 #ifdef PURIFY_MPI
280 #endif
281  switch (dist) {
// Serial: regulariser strength is 1e-3 * max|Psi^H Phi^H y|, tolerances come from the local
// epsilon, and the function returns here, so nothing after the switch runs on this path.
282  case (algo_distribution::serial): {
283  primaldual
284  ->regulariser_strength(
285  (wavelets->adjoint() * (measurements->adjoint() * uv_data.vis).eval())
286  .cwiseAbs()
287  .maxCoeff() *
288  1e-3)
289  .l2ball_proximal_epsilon(epsilon)
290  .residual_tolerance(epsilon * residual_tolerance_scaling);
291  return primaldual;
292  }
293 #ifdef PURIFY_MPI
// Branch with global convergence checks (case label not visible): epsilon is recomputed from
// comm.all_sum_all(uv_data.size()) - presumably the measurement count summed over all ranks.
295  obj_conv = ConvergenceType::mpi_global;
296  rel_conv = ConvergenceType::mpi_global;
297  auto const comm = sopt::mpi::Communicator::World();
298  epsilon = std::sqrt(2 * comm.all_sum_all(uv_data.size()) +
299  2 * std::sqrt(4 * comm.all_sum_all(uv_data.size()))) *
300  sigma;
301  // communicator ensuring l2 norm in l2ball proximal is global
302  primaldual->l2ball_proximal_communicator(comm);
303  } break;
// Branch with local convergence checks (case label not visible): epsilon is
// sqrt(2M + 2 sqrt(4M) * sqrt(all_sum_all(4M)) / all_sum_all(sqrt(4M))) * sigma, M = uv_data.size().
305  obj_conv = ConvergenceType::mpi_local;
306  rel_conv = ConvergenceType::mpi_local;
307  auto const comm = sopt::mpi::Communicator::World();
308  epsilon = std::sqrt(2 * uv_data.size() + 2 * std::sqrt(4 * uv_data.size()) *
309  std::sqrt(comm.all_sum_all(4 * uv_data.size())) /
310  comm.all_sum_all(std::sqrt(4 * uv_data.size()))) *
311  sigma;
312  } break;
// Branch that installs random updaters (case label not visible): same epsilon formula and local
// convergence checks as the previous branch, plus randomised measurement/wavelet updates.
314  auto const comm = sopt::mpi::Communicator::World();
315  epsilon = std::sqrt(2 * uv_data.size() + 2 * std::sqrt(4 * uv_data.size()) *
316  std::sqrt(comm.all_sum_all(4 * uv_data.size())) /
317  comm.all_sum_all(std::sqrt(4 * uv_data.size()))) *
318  sigma;
319  obj_conv = ConvergenceType::mpi_local;
320  rel_conv = ConvergenceType::mpi_local;
321  std::shared_ptr<bool> random_measurement_update_ptr = std::make_shared<bool>(true);
322  std::shared_ptr<bool> random_wavelet_update_ptr = std::make_shared<bool>(true);
// update_size is half the communicator size, rounded down, but never less than 1.
323  const t_int update_size = std::max<t_int>(std::floor(0.5 * comm.size()), 1);
324  auto random_measurement_updater = random_updater::random_updater(
325  comm, comm.size(), update_size, random_measurement_update_ptr, "measurements");
// NOTE(review): the wavelet updater below is given random_measurement_update_ptr, while
// random_wavelet_update_ptr (declared above) is never used. This looks like a copy-paste slip:
// both updaters end up sharing one flag. Confirm against random_updater before changing.
326  auto random_wavelet_updater = random_updater::random_updater(
327  comm, std::min<t_int>(comm.size(), std::floor(comm.all_sum_all(sara_size))),
328  std::min<t_int>(update_size, std::floor(0.5 * comm.all_sum_all(sara_size))),
329  random_measurement_update_ptr, "wavelets");
330 
331  primaldual->random_measurement_updater(random_measurement_updater)
332  .random_wavelet_updater(random_wavelet_updater)
333  .v_all_sum_all_comm(comm)
334  .u_all_sum_all_comm(comm);
335  } break;
336 #endif
337  default:
338  throw std::runtime_error(
339  "Type of distributed primal dual algorithm not recognised. You might not have compiled "
340  "with MPI.");
341  }
// Reached only by the MPI branches above (the serial case returns, the default throws).
342 #ifdef PURIFY_MPI
343  auto const comm = sopt::mpi::Communicator::World();
// The convergence factories receive a weak_ptr to the solver rather than the shared_ptr itself.
344  std::weak_ptr<Algorithm> const primaldual_weak(primaldual);
345  // set epsilon
346  primaldual->residual_tolerance(epsilon * residual_tolerance_scaling)
347  .l2ball_proximal_epsilon(epsilon);
348  // set regulariser_strength
// utilities::step_size(...) * 1e-3, reduced over the communicator with MPI_MAX.
349  primaldual->regulariser_strength(comm.all_reduce(
350  utilities::step_size<Vector<t_complex>>(uv_data.vis, measurements, wavelets, sara_size) *
351  1e-3,
352  MPI_MAX));
353  // residual and objective convergence tests built from the mode chosen in the switch
354  primaldual->residual_convergence(
355  purify::factory::l2_convergence_factory<typename Algorithm::Scalar>(rel_conv,
356  primaldual_weak));
357  primaldual->objective_convergence(
358  purify::factory::l1_convergence_factory<typename Algorithm::Scalar>(obj_conv,
359  primaldual_weak));
360 #endif
361  return primaldual;
362 }
363 
364 template <class Algorithm, class... ARGS>
365 std::shared_ptr<Algorithm> algorithm_factory(const factory::algorithm algo, ARGS &&...args) {
366  switch (algo) {
367  case algorithm::padmm:
368  return padmm_factory<Algorithm>(std::forward<ARGS>(args)...);
369  break;
370  /*
371  case algorithm::primal_dual:
372  return pd_factory(std::forward<ARGS>(args)...);
373  break;
374  case algorithm::sdmm:
375  return sdmm_factory(std::forward<ARGS>(args)...);
376  break;
377  case algorithm::forward_backward:
378  return fb_factory(std::forward<ARGS>(args)...);
379  break;
380  */
381  default:
382  throw std::runtime_error("Algorithm not implimented.");
383  }
384 }
385 
386 } // namespace factory
387 } // namespace purify
388 #endif
#define PURIFY_INFO(...)
Definition: logging.h:195
std::enable_if< std::is_same< Algorithm, sopt::algorithm::ImagingProximalADMM< t_complex > >::value, std::shared_ptr< Algorithm > >::type padmm_factory(const algo_distribution dist, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &measurements, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &wavelets, const utilities::vis_params &uv_data, const t_real sigma, const t_uint imsizey, const t_uint imsizex, const t_uint sara_size, const t_uint max_iterations=500, const bool real_constraint=true, const bool positive_constraint=true, const bool tight_frame=false, const t_real relative_variation=1e-3, const t_real l1_proximal_tolerance=1e-2, const t_uint maximum_proximal_iterations=50, const t_real residual_tolerance_scaling=1)
return shared pointer to padmm object
const std::map< std::string, algo_distribution > algo_distribution_string
std::enable_if< std::is_same< Algorithm, sopt::algorithm::ImagingPrimalDual< t_complex > >::value, std::shared_ptr< Algorithm > >::type primaldual_factory(const algo_distribution dist, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &measurements, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &wavelets, const utilities::vis_params &uv_data, const t_real sigma, const t_uint imsizey, const t_uint imsizex, const t_uint sara_size, const t_uint max_iterations=500, const bool real_constraint=true, const bool positive_constraint=true, const t_real relative_variation=1e-3, const t_real residual_tolerance_scaling=1)
return shared pointer to primal dual object
std::shared_ptr< Algorithm > algorithm_factory(const factory::algorithm algo, ARGS &&...args)
return chosen algorithm given parameters
std::enable_if< std::is_same< Algorithm, sopt::algorithm::ImagingForwardBackward< t_complex > >::value, std::shared_ptr< Algorithm > >::type fb_factory(const algo_distribution dist, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &measurements, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &wavelets, const utilities::vis_params &uv_data, const t_real sigma, const t_real step_size, const t_real reg_parameter, const t_uint imsizey, const t_uint imsizex, const t_uint sara_size, const t_uint max_iterations=500, const bool real_constraint=true, const bool positive_constraint=true, const bool tight_frame=false, const t_real relative_variation=1e-3, const t_real l1_proximal_tolerance=1e-2, const t_uint maximum_proximal_iterations=50, const std::string model_path="", const nondiff_func_type g_proximal=nondiff_func_type::L1Norm, std::shared_ptr< DifferentiableFunc< typename Algorithm::Scalar >> f_function=nullptr)
return shared pointer to forward backward object
std::function< bool()> random_updater(const sopt::mpi::Communicator &comm, const t_int total, const t_int update_size, const std::shared_ptr< bool > update_pointer, const std::string &update_name)
t_real step_size(T const &vis, const std::shared_ptr< sopt::LinearTransform< T > const > &measurements, const std::shared_ptr< sopt::LinearTransform< T > const > &wavelets, const t_uint sara_size)
Calculate step size using MPI (does not include factor of 1e-3)
Definition: mpi_utilities.h:79
nondiff_func_type
Definition: types.h:31
void padmm(const std::string &name, const Image< t_complex > &M31, const std::string &kernel, const t_int J, const utilities::vis_params &uv_data, const t_real sigma, const std::tuple< bool, t_real > &w_term)
Vector< t_complex > vis
Definition: uvw_utilities.h:22
t_uint size() const
return number of measurements
Definition: uvw_utilities.h:54