PURIFY
Next-generation radio interferometric imaging
Enumerations | Functions | Variables
purify::factory Namespace Reference

Enumerations

enum class  algorithm { padmm , primal_dual , sdmm , forward_backward }
 
enum class  algo_distribution { serial , mpi_serial , mpi_distributed , mpi_random_updates }
 
enum class  ConvergenceType { mpi_local , mpi_global }
 
enum class  distributed_measurement_operator {
  serial , mpi_distribute_image , mpi_distribute_grid , mpi_distribute_all_to_all ,
  gpu_serial , gpu_mpi_distribute_image , gpu_mpi_distribute_grid , gpu_mpi_distribute_all_to_all
}
 determine type of distribute for mpi measurement operator More...
 
enum class  distributed_wavelet_operator { serial , mpi_sara }
 

Functions

template<class Algorithm , class... ARGS>
std::shared_ptr< Algorithm > algorithm_factory (const factory::algorithm algo, ARGS &&...args)
 return chosen algorithm given parameters More...
 
template<class Algorithm >
std::enable_if< std::is_same< Algorithm, sopt::algorithm::ImagingProximalADMM< t_complex > >::value, std::shared_ptr< Algorithm > >::type padmm_factory (const algo_distribution dist, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &measurements, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &wavelets, const utilities::vis_params &uv_data, const t_real sigma, const t_uint imsizey, const t_uint imsizex, const t_uint sara_size, const t_uint max_iterations=500, const bool real_constraint=true, const bool positive_constraint=true, const bool tight_frame=false, const t_real relative_variation=1e-3, const t_real l1_proximal_tolerance=1e-2, const t_uint maximum_proximal_iterations=50, const t_real residual_tolerance_scaling=1)
 return shared pointer to padmm object More...
 
template<class Algorithm >
std::enable_if< std::is_same< Algorithm, sopt::algorithm::ImagingForwardBackward< t_complex > >::value, std::shared_ptr< Algorithm > >::type fb_factory (const algo_distribution dist, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &measurements, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &wavelets, const utilities::vis_params &uv_data, const t_real sigma, const t_real step_size, const t_real reg_parameter, const t_uint imsizey, const t_uint imsizex, const t_uint sara_size, const t_uint max_iterations=500, const bool real_constraint=true, const bool positive_constraint=true, const bool tight_frame=false, const t_real relative_variation=1e-3, const t_real l1_proximal_tolerance=1e-2, const t_uint maximum_proximal_iterations=50, const std::string model_path="", const nondiff_func_type g_proximal=nondiff_func_type::L1Norm, std::shared_ptr< DifferentiableFunc< typename Algorithm::Scalar >> f_function=nullptr)
 return shared pointer to forward backward object More...
 
template<class Algorithm >
std::enable_if< std::is_same< Algorithm, sopt::algorithm::ImagingPrimalDual< t_complex > >::value, std::shared_ptr< Algorithm > >::type primaldual_factory (const algo_distribution dist, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &measurements, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &wavelets, const utilities::vis_params &uv_data, const t_real sigma, const t_uint imsizey, const t_uint imsizex, const t_uint sara_size, const t_uint max_iterations=500, const bool real_constraint=true, const bool positive_constraint=true, const t_real relative_variation=1e-3, const t_real residual_tolerance_scaling=1)
 return shared pointer to primal dual object More...
 
template<class T , class... ARGS>
std::shared_ptr< sopt::LinearTransform< T > > all_to_all_measurement_operator_factory (const distributed_measurement_operator distribute, const std::vector< t_int > &image_stacks, const std::vector< t_real > &w_stacks, ARGS &&...args)
 distributed measurement operator factory More...
 
template<class T , class... ARGS>
std::shared_ptr< sopt::LinearTransform< T > > measurement_operator_factory (const distributed_measurement_operator distribute, ARGS &&...args)
 distributed measurement operator factory More...
 
template<class T , class Algo >
void add_updater (std::weak_ptr< Algo > const algo_weak, const t_real step_size_scale, const t_real update_tol, const t_uint update_iters, const pfitsio::header_params &update_solution_header, const pfitsio::header_params &update_residual_header, const t_uint imsizey, const t_uint imsizex, const t_uint sara_size, const bool using_mpi, const t_real beam_units=1)
 
template<class T >
std::shared_ptr< sopt::LinearTransform< T > const > wavelet_operator_factory (const distributed_wavelet_operator distribute, const std::vector< std::tuple< std::string, t_uint >> &wavelets, const t_uint imsizey, const t_uint imsizex, t_uint &sara_size)
 construct sara wavelet operator More...
 
template<class T >
std::shared_ptr< sopt::LinearTransform< T > const > wavelet_operator_factory (const distributed_wavelet_operator distribute, const std::vector< std::tuple< std::string, t_uint >> &wavelets, const t_uint imsizey, const t_uint imsizex)
 

Variables

const std::map< std::string, algo_distribution > algo_distribution_string
 

Enumeration Type Documentation

◆ algo_distribution

Enumerator
serial 
mpi_serial 
mpi_distributed 
mpi_random_updates 

Definition at line 37 of file algorithm_factory.h.

◆ algorithm

Enumerator
padmm 
primal_dual 
sdmm 
forward_backward 

Definition at line 36 of file algorithm_factory.h.

void padmm(const std::string &name, const Image< t_complex > &M31, const std::string &kernel, const t_int J, const utilities::vis_params &uv_data, const t_real sigma, const std::tuple< bool, t_real > &w_term)

◆ ConvergenceType

Enumerator
mpi_local 
mpi_global 

Definition at line 12 of file convergence_factory.h.

◆ distributed_measurement_operator

determine type of distribute for mpi measurement operator

Enumerator
serial 
mpi_distribute_image 
mpi_distribute_grid 
mpi_distribute_all_to_all 
gpu_serial 
gpu_mpi_distribute_image 
gpu_mpi_distribute_grid 
gpu_mpi_distribute_all_to_all 

Definition at line 22 of file measurement_operator_factory.h.

◆ distributed_wavelet_operator

Enumerator
serial 
mpi_sara 

Definition at line 18 of file wavelet_operator_factory.h.

Function Documentation

◆ add_updater()

template<class T , class Algo >
void purify::factory::add_updater ( std::weak_ptr< Algo > const  algo_weak,
const t_real  step_size_scale,
const t_real  update_tol,
const t_uint  update_iters,
const pfitsio::header_params update_solution_header,
const pfitsio::header_params update_residual_header,
const t_uint  imsizey,
const t_uint  imsizex,
const t_uint  sara_size,
const bool  using_mpi,
const t_real  beam_units = 1 
)

Definition at line 20 of file update_factory.h.

25  {
26  if (update_tol < 0)
27  throw std::runtime_error("Step size update tolerance must be greater than zero.");
28  if (step_size_scale < 0)
29  throw std::runtime_error("Step size update scale must be greater than zero.");
30 #ifdef PURIFY_CImg
31  auto const canvas =
32  std::make_shared<CDisplay>(cimg::make_display(Vector<t_real>::Zero(1024 * 512), 1024, 512));
33 #endif
34  const std::shared_ptr<pfitsio::header_params> update_header_sol =
35  std::make_shared<pfitsio::header_params>(update_solution_header);
36  const std::shared_ptr<pfitsio::header_params> update_header_res =
37  std::make_shared<pfitsio::header_params>(update_residual_header);
38  if (using_mpi) {
39 #ifdef PURIFY_MPI
40  const auto comm = sopt::mpi::Communicator::World();
41  const std::shared_ptr<t_int> iter = std::make_shared<t_int>(0);
42  const auto updater = [update_tol, update_iters, imsizex, imsizey, algo_weak, iter,
43  step_size_scale, update_header_sol, update_header_res, sara_size, comm,
44  beam_units](const Vector<T> &x, const Vector<T> &res) -> bool {
45  auto algo = algo_weak.lock();
46  if (comm.is_root()) PURIFY_MEDIUM_LOG("Step size γ {}", algo->regulariser_strength());
47  if (algo->regulariser_strength() > 0) {
48  Vector<t_complex> const alpha = algo->Psi().adjoint() * x;
49  const t_real new_gamma =
50  comm.all_reduce((sara_size > 0) ? alpha.real().cwiseAbs().maxCoeff() : 0., MPI_MAX) *
51  step_size_scale;
52  if (comm.is_root()) PURIFY_MEDIUM_LOG("Step size γ update {}", new_gamma);
53  // updating parameter
54  algo->regulariser_strength(
55  ((std::abs(algo->regulariser_strength() - new_gamma) > update_tol) and
56  *iter < update_iters)
57  ? new_gamma
58  : algo->regulariser_strength());
59  }
60  Vector<t_complex> const residual = algo->Phi().adjoint() * (res / beam_units).eval();
61  PURIFY_MEDIUM_LOG("RMS of residual map in Jy/beam {}",
62  residual.norm() / std::sqrt(residual.size()));
63  if (comm.is_root()) {
64  update_header_sol->niters = *iter;
65  update_header_res->niters = *iter;
66  pfitsio::write2d(Image<T>::Map(x.data(), imsizey, imsizex).real(), *update_header_sol,
67  true);
68  pfitsio::write2d(Image<T>::Map(residual.data(), imsizey, imsizex).real(),
69  *update_header_res, true);
70  }
71  *iter = *iter + 1;
72  return true;
73  };
74  auto algo = algo_weak.lock();
75  algo->is_converged(updater);
76 
77 #else
78  throw std::runtime_error(
79  "Trying to use algorithm step size update with MPI, but you did not compile with MPI.");
80 #endif
81  } else {
82  const std::shared_ptr<t_int> iter = std::make_shared<t_int>(0);
83  const auto updater = [update_tol, update_iters, imsizex, imsizey, algo_weak, iter,
84  step_size_scale, update_header_sol, update_header_res, beam_units
85 #ifdef PURIFY_CImg
86  ,
87  canvas
88 #endif
89  ](const Vector<T> &x, const Vector<T> &res) -> bool {
90  auto algo = algo_weak.lock();
91  if (algo->regulariser_strength() > 0) {
92  PURIFY_MEDIUM_LOG("Step size γ {}", algo->regulariser_strength());
93  Vector<T> const alpha = algo->Psi().adjoint() * x;
94  const t_real new_gamma = alpha.real().cwiseAbs().maxCoeff() * step_size_scale;
95  PURIFY_MEDIUM_LOG("Step size γ update {}", new_gamma);
96  // updating parameter
97  algo->regulariser_strength(((std::abs((algo->regulariser_strength() - new_gamma) /
98  algo->regulariser_strength()) > update_tol) and
99  *iter < update_iters)
100  ? new_gamma
101  : algo->regulariser_strength());
102  }
103  Vector<t_complex> const residual = algo->Phi().adjoint() * (res / beam_units).eval();
104  PURIFY_MEDIUM_LOG("RMS of residual map in Jy/beam {}",
105  residual.norm() / std::sqrt(residual.size()));
106 #ifdef PURIFY_CImg
107  const auto img1 = cimg::make_image(x.real().eval(), imsizey, imsizex)
108  .get_normalize(0, 1)
109  .get_resize(512, 512);
110  const auto img2 = cimg::make_image(residual.real().eval(), imsizey, imsizex)
111  .get_normalize(0, 1)
112  .get_resize(512, 512);
113  const auto results =
114  CImageList<t_real>(img1.get_equalize(256, 0.05, 1.), img2.get_equalize(256, 0.1, 1.));
115  canvas->display(results);
116  canvas->resize(true);
117 #endif
118  update_header_sol->niters = *iter;
119  update_header_res->niters = *iter;
120  pfitsio::write2d(Image<T>::Map(x.data(), imsizey, imsizex).real(), *update_header_sol, true);
121  pfitsio::write2d(Image<T>::Map(residual.data(), imsizey, imsizex).real(), *update_header_res,
122  true);
123  *iter = *iter + 1;
124  return true;
125  };
126  auto algo = algo_weak.lock();
127  algo->is_converged(updater);
128  }
129 }
#define PURIFY_MEDIUM_LOG(...)
Medium priority message.
Definition: logging.h:205
void write2d(const Image< t_real > &eigen_image, const pfitsio::header_params &header, const bool &overwrite)
Write image to fits file.
Definition: pfitsio.cc:30

References PURIFY_MEDIUM_LOG, and purify::pfitsio::write2d().

◆ algorithm_factory()

template<class Algorithm , class... ARGS>
std::shared_ptr< Algorithm > purify::factory::algorithm_factory ( const factory::algorithm  algo,
ARGS &&...  args 
)

return chosen algorithm given parameters

Definition at line 365 of file algorithm_factory.h.

365  {
366  switch (algo) {
367  case algorithm::padmm:
368  return padmm_factory<Algorithm>(std::forward<ARGS>(args)...);
369  break;
370  /*
371  case algorithm::primal_dual:
372  return pd_factory(std::forward<ARGS>(args)...);
373  break;
374  case algorithm::sdmm:
375  return sdmm_factory(std::forward<ARGS>(args)...);
376  break;
377  case algorithm::forward_backward:
378  return fb_factory(std::forward<ARGS>(args)...);
379  break;
380  */
381  default:
382  throw std::runtime_error("Algorithm not implimented.");
383  }
384 }

References padmm.

◆ all_to_all_measurement_operator_factory()

template<class T , class... ARGS>
std::shared_ptr<sopt::LinearTransform<T> > purify::factory::all_to_all_measurement_operator_factory ( const distributed_measurement_operator  distribute,
const std::vector< t_int > &  image_stacks,
const std::vector< t_real > &  w_stacks,
ARGS &&...  args 
)

distributed measurement operator factory

Definition at line 43 of file measurement_operator_factory.h.

45  {
46  switch (distribute) {
47 #ifdef PURIFY_MPI
48  case (distributed_measurement_operator::mpi_distribute_all_to_all): {
49  PURIFY_LOW_LOG("Using MPI all to all measurement operator.");
50  auto const world = sopt::mpi::Communicator::World();
51  return measurementoperator::init_degrid_operator_2d_all_to_all<T>(world, image_stacks, w_stacks,
52  std::forward<ARGS>(args)...);
53  }
54 #endif
55  default:
56  throw std::runtime_error(
57  "Distributed method not found for Measurement Operator. Are you sure you compiled with "
58  "MPI?");
59  }
60 }
#define PURIFY_LOW_LOG(...)
Low priority message.
Definition: logging.h:207

References mpi_distribute_all_to_all, and PURIFY_LOW_LOG.

Referenced by createMeasurementOperator(), getInputData(), and TEST_CASE().

◆ fb_factory()

template<class Algorithm >
std::enable_if< std::is_same<Algorithm, sopt::algorithm::ImagingForwardBackward<t_complex> >::value, std::shared_ptr<Algorithm> >::type purify::factory::fb_factory ( const algo_distribution  dist,
std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &  measurements,
std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &  wavelets,
const utilities::vis_params uv_data,
const t_real  sigma,
const t_real  step_size,
const t_real  reg_parameter,
const t_uint  imsizey,
const t_uint  imsizex,
const t_uint  sara_size,
const t_uint  max_iterations = 500,
const bool  real_constraint = true,
const bool  positive_constraint = true,
const bool  tight_frame = false,
const t_real  relative_variation = 1e-3,
const t_real  l1_proximal_tolerance = 1e-2,
const t_uint  maximum_proximal_iterations = 50,
const std::string  model_path = "",
const nondiff_func_type  g_proximal = nondiff_func_type::L1Norm,
std::shared_ptr< DifferentiableFunc< typename Algorithm::Scalar >>  f_function = nullptr 
)

return shared pointer to forward backward object

Definition at line 153 of file algorithm_factory.h.

166  {
167  typedef typename Algorithm::Scalar t_scalar;
168  if (sara_size > 1 and tight_frame)
169  throw std::runtime_error(
170  "l1 proximal not consistent: You say you are using a tight frame, but you have more than "
171  "one wavelet basis.");
172  PURIFY_INFO("Constructing Forward Backward algorithm");
173  auto fb = std::make_shared<Algorithm>(uv_data.vis);
174  fb->itermax(max_iterations)
175  .regulariser_strength(reg_parameter)
176  .sigma(sigma * std::sqrt(2))
177  .step_size(step_size * std::sqrt(2))
178  .relative_variation(relative_variation)
179  .tight_frame(tight_frame)
180  .Phi(*measurements);
181 
182  if (f_function) fb->f_function(f_function); // only override f_function default if non-null
183  std::shared_ptr<NonDifferentiableFunc<t_scalar>> g;
184 
185  switch (g_proximal) {
186  case (nondiff_func_type::L1Norm): {
187  // Create a shared pointer to an instance of the L1GProximal class
188  // and set its properties
189  auto l1_gp = std::make_shared<sopt::algorithm::L1GProximal<t_scalar>>(false);
190  l1_gp->l1_proximal_tolerance(l1_proximal_tolerance)
191  .l1_proximal_nu(1.)
192  .l1_proximal_itermax(maximum_proximal_iterations)
193  .l1_proximal_positivity_constraint(positive_constraint)
194  .l1_proximal_real_constraint(real_constraint)
195  .Psi(*wavelets);
196 #ifdef PURIFY_MPI
197  if (dist == algo_distribution::mpi_serial) {
198  auto const comm = sopt::mpi::Communicator::World();
199  l1_gp->l1_proximal_adjoint_space_comm(comm);
200  l1_gp->l1_proximal_direct_space_comm(comm);
201  }
202 #endif
203  g = l1_gp;
204  break;
205  }
206  case (nondiff_func_type::Denoiser): {
207 #ifdef PURIFY_ONNXRT
208  // Create a shared pointer to an instance of the TFGProximal class
209  g = std::make_shared<sopt::algorithm::TFGProximal<t_scalar>>(model_path);
210  break;
211 #else
212  throw std::runtime_error(
213  "Type TFGProximal not recognized because purify was built with onnxrt=off");
214 #endif
215  }
216 
217  case (nondiff_func_type::RealIndicator): {
218  g = std::make_shared<sopt::algorithm::RealIndicator<t_scalar>>();
219  break;
220  }
221  default: {
222  throw std::runtime_error("Type of g_proximal operator not recognised.");
223  }
224  }
225 
226  fb->g_function(g);
227 
228  switch (dist) {
229  case (algo_distribution::serial): {
230  break;
231  }
232 #ifdef PURIFY_MPI
233  case (algo_distribution::mpi_serial): {
234  auto const comm = sopt::mpi::Communicator::World();
235  fb->adjoint_space_comm(comm);
236  fb->obj_comm(comm);
237  break;
238  }
239 #endif
240  default:
241  throw std::runtime_error(
242  "Type of distributed Forward Backward algorithm not recognised. You might not have "
243  "compiled "
244  "with MPI.");
245  }
246  return fb;
247 }
#define PURIFY_INFO(...)
Definition: logging.h:195
t_real step_size(T const &vis, const std::shared_ptr< sopt::LinearTransform< T > const > &measurements, const std::shared_ptr< sopt::LinearTransform< T > const > &wavelets, const t_uint sara_size)
Calculate step size using MPI (does not include factor of 1e-3)
Definition: mpi_utilities.h:79

References purify::Denoiser, purify::L1Norm, mpi_serial, PURIFY_INFO, purify::RealIndicator, serial, purify::utilities::step_size(), and purify::utilities::vis_params::vis.

◆ measurement_operator_factory()

template<class T , class... ARGS>
std::shared_ptr<sopt::LinearTransform<T> > purify::factory::measurement_operator_factory ( const distributed_measurement_operator  distribute,
ARGS &&...  args 
)

distributed measurement operator factory

Definition at line 63 of file measurement_operator_factory.h.

64  {
65  switch (distribute) {
66  case (distributed_measurement_operator::serial): {
67  PURIFY_LOW_LOG("Using serial measurement operator.");
68  return measurementoperator::init_degrid_operator_2d<T>(std::forward<ARGS>(args)...);
69  }
70  case (distributed_measurement_operator::gpu_serial): {
71 #ifndef PURIFY_ARRAYFIRE
72  throw std::runtime_error("Tried to use GPU operator but you did not build with ArrayFire.");
73 #else
74  check_complex_for_gpu<T>();
75  PURIFY_LOW_LOG("Using serial measurement operator with Arrayfire.");
76  af::setDevice(0);
77  return gpu::measurementoperator::init_degrid_operator_2d(std::forward<ARGS>(args)...);
78 #endif
79  }
80 #ifdef PURIFY_MPI
81  case (distributed_measurement_operator::mpi_distribute_image): {
82  auto const world = sopt::mpi::Communicator::World();
83  PURIFY_LOW_LOG("Using distributed image MPI measurement operator.");
84  return measurementoperator::init_degrid_operator_2d<T>(world, std::forward<ARGS>(args)...);
85  }
86  case (distributed_measurement_operator::mpi_distribute_grid): {
87  auto const world = sopt::mpi::Communicator::World();
88  PURIFY_LOW_LOG("Using distributed grid MPI measurement operator.");
89  return measurementoperator::init_degrid_operator_2d_mpi<T>(world, std::forward<ARGS>(args)...);
90  }
91  case (distributed_measurement_operator::gpu_mpi_distribute_image): {
92 #ifndef PURIFY_ARRAYFIRE
93  throw std::runtime_error("Tried to use GPU operator but you did not build with ArrayFire.");
94 #else
95  check_complex_for_gpu<T>();
96  auto const world = sopt::mpi::Communicator::World();
97  PURIFY_LOW_LOG("Using distributed image MPI + Arrayfire measurement operator.");
98  af::setDevice(0);
99  return gpu::measurementoperator::init_degrid_operator_2d(world, std::forward<ARGS>(args)...);
100 #endif
101  }
102  case (distributed_measurement_operator::gpu_mpi_distribute_grid): {
103 #ifndef PURIFY_ARRAYFIRE
104  throw std::runtime_error("Tried to use GPU operator but you did not build with ArrayFire.");
105 #else
106  check_complex_for_gpu<T>();
107  auto const world = sopt::mpi::Communicator::World();
108  PURIFY_LOW_LOG("Using distributed grid MPI + Arrayfire measurement operator.");
109  af::setDevice(0);
110  return gpu::measurementoperator::init_degrid_operator_2d_mpi(world,
111  std::forward<ARGS>(args)...);
112 #endif
113  }
114 #endif
115  default:
116  throw std::runtime_error(
117  "Distributed method not found for Measurement Operator. Are you sure you compiled with "
118  "MPI?");
119  }
120 }
std::shared_ptr< sopt::LinearTransform< T > > init_degrid_operator_2d(const Vector< t_real > &u, const Vector< t_real > &v, const Vector< t_real > &w, const Vector< t_complex > &weights, const t_uint &imsizey, const t_uint &imsizex, const t_real &oversample_ratio=2, const kernels::kernel kernel=kernels::kernel::kb, const t_uint Ju=4, const t_uint Jv=4, const bool w_stacking=false, const t_real &cellx=1, const t_real &celly=1)
Returns linear transform that is the standard degridding operator.
Definition: operators.h:608

References gpu_mpi_distribute_grid, gpu_mpi_distribute_image, gpu_serial, purify::measurementoperator::init_degrid_operator_2d(), mpi_distribute_grid, mpi_distribute_image, PURIFY_LOW_LOG, and serial.

Referenced by createMeasurementOperator(), getInputData(), and TEST_CASE().

◆ padmm_factory()

template<class Algorithm >
std::enable_if< std::is_same<Algorithm, sopt::algorithm::ImagingProximalADMM<t_complex> >::value, std::shared_ptr<Algorithm> >::type purify::factory::padmm_factory ( const algo_distribution  dist,
std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &  measurements,
std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &  wavelets,
const utilities::vis_params uv_data,
const t_real  sigma,
const t_uint  imsizey,
const t_uint  imsizex,
const t_uint  sara_size,
const t_uint  max_iterations = 500,
const bool  real_constraint = true,
const bool  positive_constraint = true,
const bool  tight_frame = false,
const t_real  relative_variation = 1e-3,
const t_real  l1_proximal_tolerance = 1e-2,
const t_uint  maximum_proximal_iterations = 50,
const t_real  residual_tolerance_scaling = 1 
)

return shared pointer to padmm object

Definition at line 52 of file algorithm_factory.h.

63  {
64  typedef typename Algorithm::Scalar t_scalar;
65  if (sara_size > 1 and tight_frame)
66  throw std::runtime_error(
67  "l1 proximal not consistent: You say you are using a tight frame, but you have more than "
68  "one wavelet basis.");
69  PURIFY_INFO("Constructing PADMM algorithm");
70  auto epsilon = std::sqrt(2 * uv_data.size() + 2 * std::sqrt(4 * uv_data.size())) * sigma;
71  auto padmm = std::make_shared<Algorithm>(uv_data.vis);
72  padmm->itermax(max_iterations)
73  .relative_variation(relative_variation)
74  .tight_frame(tight_frame)
75  .l1_proximal_tolerance(l1_proximal_tolerance)
76  .l1_proximal_nu(1)
77  .l1_proximal_itermax(maximum_proximal_iterations)
78  .l1_proximal_positivity_constraint(positive_constraint)
79  .l1_proximal_real_constraint(real_constraint)
80  .lagrange_update_scale(0.9)
81  .Psi(*wavelets)
82  .Phi(*measurements);
83 #ifdef PURIFY_MPI
84  ConvergenceType obj_conv = ConvergenceType::mpi_global;
85  ConvergenceType rel_conv = ConvergenceType::mpi_global;
86 #endif
87  switch (dist) {
88  case (algo_distribution::serial):
89  padmm
90  ->regulariser_strength(
91  (wavelets->adjoint() * (measurements->adjoint() * uv_data.vis).eval())
92  .cwiseAbs()
93  .maxCoeff() *
94  1e-3)
95  .l2ball_proximal_epsilon(epsilon)
96  .residual_tolerance(epsilon * residual_tolerance_scaling);
97  return padmm;
98  break;
99 
100 #ifdef PURIFY_MPI
101  case (algo_distribution::mpi_serial): {
102  obj_conv = ConvergenceType::mpi_global;
103  rel_conv = ConvergenceType::mpi_global;
104  auto const comm = sopt::mpi::Communicator::World();
105  epsilon = std::sqrt(2 * comm.all_sum_all(uv_data.size()) +
106  2 * std::sqrt(4 * comm.all_sum_all(uv_data.size()))) *
107  sigma;
108  // communicator ensuring l2 norm in l2ball proximal is global
109  padmm->l2ball_proximal_communicator(comm);
110  } break;
111 
112  case (algo_distribution::mpi_distributed): {
113  obj_conv = ConvergenceType::mpi_local;
114  rel_conv = ConvergenceType::mpi_local;
115  auto const comm = sopt::mpi::Communicator::World();
116  epsilon = std::sqrt(2 * uv_data.size() + 2 * std::sqrt(4 * uv_data.size()) *
117  std::sqrt(comm.all_sum_all(4 * uv_data.size())) /
118  comm.all_sum_all(std::sqrt(4 * uv_data.size()))) *
119  sigma;
120  } break;
121 
122 #endif
123  default:
124  throw std::runtime_error(
125  "Type of distributed proximal ADMM algorithm not recognised. You might not have compiled "
126  "with MPI.");
127  }
128 #ifdef PURIFY_MPI
129  auto const comm = sopt::mpi::Communicator::World();
130  std::weak_ptr<Algorithm> const padmm_weak(padmm);
131  // set epsilon
132  padmm->residual_tolerance(epsilon * residual_tolerance_scaling).l2ball_proximal_epsilon(epsilon);
133  // set regulariser_strength
134  padmm->regulariser_strength(comm.all_reduce(
135  utilities::step_size<Vector<t_complex>>(uv_data.vis, measurements, wavelets, sara_size) *
136  1e-3,
137  MPI_MAX));
138  // communicator ensuring l1 norm in l1 proximal is global
139  padmm->l1_proximal_adjoint_space_comm(comm);
140  padmm->residual_convergence(
141  purify::factory::l2_convergence_factory<typename Algorithm::Scalar>(rel_conv, padmm_weak));
142  padmm->objective_convergence(
143  purify::factory::l1_convergence_factory<typename Algorithm::Scalar>(obj_conv, padmm_weak));
144 #endif
145  return padmm;
146 }

References mpi_distributed, mpi_global, mpi_local, mpi_serial, padmm(), PURIFY_INFO, serial, purify::utilities::vis_params::size(), purify::utilities::step_size(), and purify::utilities::vis_params::vis.

◆ primaldual_factory()

template<class Algorithm >
std::enable_if< std::is_same<Algorithm, sopt::algorithm::ImagingPrimalDual<t_complex> >::value, std::shared_ptr<Algorithm> >::type purify::factory::primaldual_factory ( const algo_distribution  dist,
std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &  measurements,
std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &  wavelets,
const utilities::vis_params uv_data,
const t_real  sigma,
const t_uint  imsizey,
const t_uint  imsizex,
const t_uint  sara_size,
const t_uint  max_iterations = 500,
const bool  real_constraint = true,
const bool  positive_constraint = true,
const t_real  relative_variation = 1e-3,
const t_real  residual_tolerance_scaling = 1 
)

return shared pointer to primal dual object

Definition at line 254 of file algorithm_factory.h.

263  {
264  typedef typename Algorithm::Scalar t_scalar;
265  PURIFY_INFO("Constructing Primal Dual algorithm");
266  auto epsilon = std::sqrt(2 * uv_data.size() + 2 * std::sqrt(4 * uv_data.size())) * sigma;
267  auto primaldual = std::make_shared<Algorithm>(uv_data.vis);
268  primaldual->itermax(max_iterations)
269  .relative_variation(relative_variation)
270  .real_constraint(real_constraint)
271  .positivity_constraint(positive_constraint)
272  .Psi(*wavelets)
273  .Phi(*measurements)
274  .tau(0.5 / (measurements->sq_norm() + 1))
275  .xi(1.)
276  .sigma(1.);
277 #ifdef PURIFY_MPI
278  ConvergenceType obj_conv = ConvergenceType::mpi_global;
279  ConvergenceType rel_conv = ConvergenceType::mpi_global;
280 #endif
281  switch (dist) {
282  case (algo_distribution::serial): {
283  primaldual
284  ->regulariser_strength(
285  (wavelets->adjoint() * (measurements->adjoint() * uv_data.vis).eval())
286  .cwiseAbs()
287  .maxCoeff() *
288  1e-3)
289  .l2ball_proximal_epsilon(epsilon)
290  .residual_tolerance(epsilon * residual_tolerance_scaling);
291  return primaldual;
292  }
293 #ifdef PURIFY_MPI
294  case (algo_distribution::mpi_serial): {
295  obj_conv = ConvergenceType::mpi_global;
296  rel_conv = ConvergenceType::mpi_global;
297  auto const comm = sopt::mpi::Communicator::World();
298  epsilon = std::sqrt(2 * comm.all_sum_all(uv_data.size()) +
299  2 * std::sqrt(4 * comm.all_sum_all(uv_data.size()))) *
300  sigma;
301  // communicator ensuring l2 norm in l2ball proximal is global
302  primaldual->l2ball_proximal_communicator(comm);
303  } break;
304  case (algo_distribution::mpi_distributed): {
305  obj_conv = ConvergenceType::mpi_local;
306  rel_conv = ConvergenceType::mpi_local;
307  auto const comm = sopt::mpi::Communicator::World();
308  epsilon = std::sqrt(2 * uv_data.size() + 2 * std::sqrt(4 * uv_data.size()) *
309  std::sqrt(comm.all_sum_all(4 * uv_data.size())) /
310  comm.all_sum_all(std::sqrt(4 * uv_data.size()))) *
311  sigma;
312  } break;
313  case (algo_distribution::mpi_random_updates): {
314  auto const comm = sopt::mpi::Communicator::World();
315  epsilon = std::sqrt(2 * uv_data.size() + 2 * std::sqrt(4 * uv_data.size()) *
316  std::sqrt(comm.all_sum_all(4 * uv_data.size())) /
317  comm.all_sum_all(std::sqrt(4 * uv_data.size()))) *
318  sigma;
319  obj_conv = ConvergenceType::mpi_local;
320  rel_conv = ConvergenceType::mpi_local;
321  std::shared_ptr<bool> random_measurement_update_ptr = std::make_shared<bool>(true);
322  std::shared_ptr<bool> random_wavelet_update_ptr = std::make_shared<bool>(true);
323  const t_int update_size = std::max<t_int>(std::floor(0.5 * comm.size()), 1);
324  auto random_measurement_updater = random_updater::random_updater(
325  comm, comm.size(), update_size, random_measurement_update_ptr, "measurements");
326  auto random_wavelet_updater = random_updater::random_updater(
327  comm, std::min<t_int>(comm.size(), std::floor(comm.all_sum_all(sara_size))),
328  std::min<t_int>(update_size, std::floor(0.5 * comm.all_sum_all(sara_size))),
329  random_measurement_update_ptr, "wavelets");
330 
331  primaldual->random_measurement_updater(random_measurement_updater)
332  .random_wavelet_updater(random_wavelet_updater)
333  .v_all_sum_all_comm(comm)
334  .u_all_sum_all_comm(comm);
335  } break;
336 #endif
337  default:
338  throw std::runtime_error(
339  "Type of distributed primal dual algorithm not recognised. You might not have compiled "
340  "with MPI.");
341  }
342 #ifdef PURIFY_MPI
343  auto const comm = sopt::mpi::Communicator::World();
344  std::weak_ptr<Algorithm> const primaldual_weak(primaldual);
345  // set epsilon
346  primaldual->residual_tolerance(epsilon * residual_tolerance_scaling)
347  .l2ball_proximal_epsilon(epsilon);
348  // set regulariser_strength
349  primaldual->regulariser_strength(comm.all_reduce(
350  utilities::step_size<Vector<t_complex>>(uv_data.vis, measurements, wavelets, sara_size) *
351  1e-3,
352  MPI_MAX));
353  // communicator ensuring l1 norm in l1 proximal is global
354  primaldual->residual_convergence(
355  purify::factory::l2_convergence_factory<typename Algorithm::Scalar>(rel_conv,
356  primaldual_weak));
357  primaldual->objective_convergence(
358  purify::factory::l1_convergence_factory<typename Algorithm::Scalar>(obj_conv,
359  primaldual_weak));
360 #endif
361  return primaldual;
362 }
std::function< bool()> random_updater(const sopt::mpi::Communicator &comm, const t_int total, const t_int update_size, const std::shared_ptr< bool > update_pointer, const std::string &update_name)

References mpi_distributed, mpi_global, mpi_local, mpi_random_updates, mpi_serial, PURIFY_INFO, purify::random_updater::random_updater(), serial, purify::utilities::vis_params::size(), purify::utilities::step_size(), and purify::utilities::vis_params::vis.

Referenced by main().

◆ wavelet_operator_factory() [1/2]

template<class T >
std::shared_ptr<sopt::LinearTransform<T> const> purify::factory::wavelet_operator_factory ( const distributed_wavelet_operator  distribute,
const std::vector< std::tuple< std::string, t_uint >> &  wavelets,
const t_uint  imsizey,
const t_uint  imsizex 
)

Definition at line 55 of file wavelet_operator_factory.h.

58  {
59  t_uint size = 0;
60  return wavelet_operator_factory<T>(distribute, wavelets, imsizey, imsizex, size);
61 }

◆ wavelet_operator_factory() [2/2]

template<class T >
std::shared_ptr<sopt::LinearTransform<T> const> purify::factory::wavelet_operator_factory ( const distributed_wavelet_operator  distribute,
const std::vector< std::tuple< std::string, t_uint >> &  wavelets,
const t_uint  imsizey,
const t_uint  imsizex,
t_uint &  sara_size 
)

construct sara wavelet operator

Definition at line 21 of file wavelet_operator_factory.h.

24  {
25  const auto sara = sopt::wavelets::SARA(wavelets.begin(), wavelets.end());
26  switch (distribute) {
27  case (distributed_wavelet_operator::serial): {
28  PURIFY_LOW_LOG("Using serial wavelet operator.");
29  sara_size = sara.size();
30  if (sara.size() == 0)
31  return std::make_shared<sopt::LinearTransform<T> const>(
32  [imsizex, imsizey](T& out, const T& x) { out = T::Zero(imsizey * imsizex); },
33  std::array<t_int, 3>{0, 1, static_cast<t_int>(imsizex * imsizey)},
34  [imsizex, imsizey](T& out, const T& x) { out = T::Zero(imsizey * imsizex); },
35  std::array<t_int, 3>{0, 1, static_cast<t_int>(imsizey * imsizex)});
36  return std::make_shared<sopt::LinearTransform<T>>(
37  sopt::linear_transform<typename T::Scalar>(sara, imsizey, imsizex));
38  }
39 #ifdef PURIFY_MPI
40  case (distributed_wavelet_operator::mpi_sara): {
41  auto const comm = sopt::mpi::Communicator::World();
42  PURIFY_LOW_LOG("Using distributed image MPI wavelet operator.");
43  const auto dsara = sopt::wavelets::distribute_sara(sara, comm);
44  sara_size = dsara.size();
45  return std::make_shared<sopt::LinearTransform<T>>(
46  sopt::linear_transform<typename T::Scalar>(dsara, imsizey, imsizex, comm));
47  }
48 #endif
49  default:
50  throw std::runtime_error(
51  "Distributed method not found for Wavelet Operator. Are you sure you compiled with MPI?");
52  }
53 }

References mpi_sara, PURIFY_LOW_LOG, and serial.

Variable Documentation

◆ algo_distribution_string

const std::map<std::string, algo_distribution> purify::factory::algo_distribution_string
Initial value:
= {
{"none", algo_distribution::serial},
{"serial-equivalent", algo_distribution::mpi_serial},
{"random-updates", algo_distribution::mpi_random_updates},
{"fully-distributed", algo_distribution::mpi_distributed}}

Definition at line 38 of file algorithm_factory.h.

Referenced by purify::YamlParser::parseAndSetAlgorithmOptions().