PURIFY
Next-generation radio interferometric imaging
measurement_operator_factory.h
Go to the documentation of this file.
1 #ifndef MEASUREMENT_OPERATOR_FACTORY_H
2 #define MEASUREMENT_OPERATOR_FACTORY_H
3 
4 #include "purify/config.h"
5 
6 #include "purify/types.h"
7 #include "purify/logging.h"
8 
9 #include "purify/operators.h"
10 #include "purify/operators_gpu.h"
11 #include "purify/wproj_operators.h"
13 
14 #ifdef PURIFY_MPI
15 #include <sopt/mpi/communicator.h>
16 #include <sopt/mpi/session.h>
17 #endif
18 
19 namespace purify {
20 namespace factory {
// distributed_measurement_operator: selects which measurement-operator construction the
// factories in this header perform (they switch on a value of this type).
// NOTE(review): the enum's doc comment and opening line (listing lines 21-22) and the
// enumerators at listing lines 24-26 and 28-30 are missing from this extract; only `serial`
// and `gpu_serial` survive here — confirm the full list against the original header.
23  serial,
27  gpu_serial,
31 };
32 
33 namespace {
34 template <class T>
35 void check_complex_for_gpu() {
36  if (!std::is_same<Vector<t_complex>, T>::value)
37  throw std::runtime_error("Arrayfire will only use complex type with Eigen.");
38 }
39 } // namespace
40 
//! @brief Factory for the MPI all-to-all distributed measurement operator.
//!
//! Switches on `distribute`. The MPI branch exists only when PURIFY_MPI is defined. It:
//!   - logs a low-priority message,
//!   - takes the MPI world communicator,
//!   - forwards it, together with `image_stacks`, `w_stacks` and `args...`, unchanged to
//!     `measurementoperator::init_degrid_operator_2d_all_to_all<T>`.
//!
//! @param distribute    Distribution strategy requested by the caller.
//! @param image_stacks  Passed through unchanged; its meaning is defined by
//!                      init_degrid_operator_2d_all_to_all (not shown here).
//! @param w_stacks      Passed through unchanged; see the note on image_stacks.
//! @param args          Remaining constructor arguments, perfectly forwarded.
//! @return The operator built by init_degrid_operator_2d_all_to_all<T>.
//! @throws std::runtime_error for any `distribute` value without a branch. Without
//!         PURIFY_MPI that is every value, because the MPI branch is compiled out.
42 template <class T, class... ARGS>
43 std::shared_ptr<sopt::LinearTransform<T>> all_to_all_measurement_operator_factory(
44  const distributed_measurement_operator distribute, const std::vector<t_int> &image_stacks,
45  const std::vector<t_real> &w_stacks, ARGS &&...args) {
46  switch (distribute) {
47 #ifdef PURIFY_MPI
    // NOTE(review): the `case` label for this branch (listing line 48) is missing from this
    // extract. Presumably it is the all-to-all enumerator of distributed_measurement_operator —
    // confirm against the original header.
49  PURIFY_LOW_LOG("Using MPI all to all measurement operator.");
50  auto const world = sopt::mpi::Communicator::World();
51  return measurementoperator::init_degrid_operator_2d_all_to_all<T>(world, image_stacks, w_stacks,
52  std::forward<ARGS>(args)...);
53  }
54 #endif
55  default:
56  throw std::runtime_error(
57  "Distributed method not found for Measurement Operator. Are you sure you compiled with "
58  "MPI?");
59  }
60 }
//! @brief Factory returning a measurement operator for the requested distribution strategy.
//!
//! Each branch forwards `args...` unchanged to one constructor. The branches are:
//!   - serial CPU: `measurementoperator::init_degrid_operator_2d<T>`;
//!   - serial GPU: `gpu::measurementoperator::init_degrid_operator_2d` (needs PURIFY_ARRAYFIRE);
//!   - with PURIFY_MPI only: distributed-image and distributed-grid variants, each on CPU or GPU.
//!     These receive the MPI world communicator as their first argument.
//!
//! @param distribute  Distribution strategy requested by the caller.
//! @param args        Constructor arguments, perfectly forwarded to the selected constructor.
//! @return The operator built by the selected constructor.
//! @throws std::runtime_error in any of these cases:
//!   - a GPU branch is selected in a build without PURIFY_ARRAYFIRE;
//!   - a GPU branch is selected and T is not `Vector<t_complex>` (via check_complex_for_gpu);
//!   - `distribute` matches no compiled-in branch. This includes the MPI branches when
//!     PURIFY_MPI is not defined.
62 template <class T, class... ARGS>
63 std::shared_ptr<sopt::LinearTransform<T>> measurement_operator_factory(
64  const distributed_measurement_operator distribute, ARGS &&...args) {
65  switch (distribute) {
    // NOTE(review): case label (listing line 66) missing from this extract. The log text
    // suggests the `serial` enumerator — confirm.
67  PURIFY_LOW_LOG("Using serial measurement operator.");
68  return measurementoperator::init_degrid_operator_2d<T>(std::forward<ARGS>(args)...);
69  }
    // NOTE(review): case label (listing line 70) missing. Presumably `gpu_serial` — confirm.
71 #ifndef PURIFY_ARRAYFIRE
72  throw std::runtime_error("Tried to use GPU operator but you did not build with ArrayFire.");
73 #else
    // GPU operators accept only Vector<t_complex>; reject other T before touching the device.
74  check_complex_for_gpu<T>();
75  PURIFY_LOW_LOG("Using serial measurement operator with Arrayfire.");
    // NOTE(review): the ArrayFire device index is hard-coded to 0 in every GPU branch.
76  af::setDevice(0);
77  return gpu::measurementoperator::init_degrid_operator_2d(std::forward<ARGS>(args)...);
78 #endif
79  }
80 #ifdef PURIFY_MPI
    // NOTE(review): case label (listing line 81) missing. The log text suggests the MPI
    // distributed-image enumerator — confirm.
82  auto const world = sopt::mpi::Communicator::World();
83  PURIFY_LOW_LOG("Using distributed image MPI measurement operator.");
84  return measurementoperator::init_degrid_operator_2d<T>(world, std::forward<ARGS>(args)...);
85  }
    // NOTE(review): case label (listing line 86) missing. The log text suggests the MPI
    // distributed-grid enumerator — confirm.
87  auto const world = sopt::mpi::Communicator::World();
88  PURIFY_LOW_LOG("Using distributed grid MPI measurement operator.");
89  return measurementoperator::init_degrid_operator_2d_mpi<T>(world, std::forward<ARGS>(args)...);
90  }
    // NOTE(review): case label (listing line 91) missing. The log text suggests the GPU
    // distributed-image enumerator — confirm.
92 #ifndef PURIFY_ARRAYFIRE
93  throw std::runtime_error("Tried to use GPU operator but you did not build with ArrayFire.");
94 #else
95  check_complex_for_gpu<T>();
96  auto const world = sopt::mpi::Communicator::World();
97  PURIFY_LOW_LOG("Using distributed image MPI + Arrayfire measurement operator.");
98  af::setDevice(0);
99  return gpu::measurementoperator::init_degrid_operator_2d(world, std::forward<ARGS>(args)...);
100 #endif
101  }
    // NOTE(review): case label (listing line 102) missing. The log text suggests the GPU
    // distributed-grid enumerator — confirm.
103 #ifndef PURIFY_ARRAYFIRE
104  throw std::runtime_error("Tried to use GPU operator but you did not build with ArrayFire.");
105 #else
106  check_complex_for_gpu<T>();
107  auto const world = sopt::mpi::Communicator::World();
108  PURIFY_LOW_LOG("Using distributed grid MPI + Arrayfire measurement operator.");
109  af::setDevice(0);
110  return gpu::measurementoperator::init_degrid_operator_2d_mpi(world,
111  std::forward<ARGS>(args)...);
112 #endif
113  }
114 #endif
115  default:
116  throw std::runtime_error(
117  "Distributed method not found for Measurement Operator. Are you sure you compiled with "
118  "MPI?");
119  }
120 }
121 
122 } // namespace factory
123 } // namespace purify
124 #endif
#define PURIFY_LOW_LOG(...)
Low priority message.
Definition: logging.h:207
distributed_measurement_operator
Determines the type of distribution used by the MPI measurement operator.
std::shared_ptr< sopt::LinearTransform< T > > measurement_operator_factory(const distributed_measurement_operator distribute, ARGS &&...args)
distributed measurement operator factory
std::shared_ptr< sopt::LinearTransform< T > > all_to_all_measurement_operator_factory(const distributed_measurement_operator distribute, const std::vector< t_int > &image_stacks, const std::vector< t_real > &w_stacks, ARGS &&...args)
All-to-all distributed measurement operator factory
std::shared_ptr< sopt::LinearTransform< T > > init_degrid_operator_2d(const Vector< t_real > &u, const Vector< t_real > &v, const Vector< t_real > &w, const Vector< t_complex > &weights, const t_uint &imsizey, const t_uint &imsizex, const t_real &oversample_ratio=2, const kernels::kernel kernel=kernels::kernel::kb, const t_uint Ju=4, const t_uint Jv=4, const bool w_stacking=false, const t_real &cellx=1, const t_real &celly=1)
Returns a linear transform that is the standard degridding operator.
Definition: operators.h:608