1 #ifndef MEASUREMENT_OPERATOR_FACTORY_H
2 #define MEASUREMENT_OPERATOR_FACTORY_H
4 #include "purify/config.h"
15 #include <sopt/mpi/communicator.h>
16 #include <sopt/mpi/session.h>
35 void check_complex_for_gpu() {
36 if (!std::is_same<Vector<t_complex>, T>::value)
37 throw std::runtime_error(
"Arrayfire will only use complex type with Eigen.");
42 template <
class T,
class... ARGS>
45 const std::vector<t_real> &w_stacks, ARGS &&...args) {
50 auto const world = sopt::mpi::Communicator::World();
51 return measurementoperator::init_degrid_operator_2d_all_to_all<T>(world, image_stacks, w_stacks,
52 std::forward<ARGS>(args)...);
56 throw std::runtime_error(
57 "Distributed method not found for Measurement Operator. Are you sure you compiled with "
62 template <
class T,
class... ARGS>
68 return measurementoperator::init_degrid_operator_2d<T>(std::forward<ARGS>(args)...);
71 #ifndef PURIFY_ARRAYFIRE
72 throw std::runtime_error(
"Tried to use GPU operator but you did not build with ArrayFire.");
74 check_complex_for_gpu<T>();
75 PURIFY_LOW_LOG(
"Using serial measurement operator with Arrayfire.");
82 auto const world = sopt::mpi::Communicator::World();
83 PURIFY_LOW_LOG(
"Using distributed image MPI measurement operator.");
84 return measurementoperator::init_degrid_operator_2d<T>(world, std::forward<ARGS>(args)...);
87 auto const world = sopt::mpi::Communicator::World();
88 PURIFY_LOW_LOG(
"Using distributed grid MPI measurement operator.");
89 return measurementoperator::init_degrid_operator_2d_mpi<T>(world, std::forward<ARGS>(args)...);
92 #ifndef PURIFY_ARRAYFIRE
93 throw std::runtime_error(
"Tried to use GPU operator but you did not build with ArrayFire.");
95 check_complex_for_gpu<T>();
96 auto const world = sopt::mpi::Communicator::World();
97 PURIFY_LOW_LOG(
"Using distributed image MPI + Arrayfire measurement operator.");
103 #ifndef PURIFY_ARRAYFIRE
104 throw std::runtime_error(
"Tried to use GPU operator but you did not build with ArrayFire.");
106 check_complex_for_gpu<T>();
107 auto const world = sopt::mpi::Communicator::World();
108 PURIFY_LOW_LOG(
"Using distributed grid MPI + Arrayfire measurement operator.");
110 return gpu::measurementoperator::init_degrid_operator_2d_mpi(world,
111 std::forward<ARGS>(args)...);
116 throw std::runtime_error(
117 "Distributed method not found for Measurement Operator. Are you sure you compiled with "
#define PURIFY_LOW_LOG(...)
Low priority message.
distributed_measurement_operator
Determines the distribution type used by the MPI measurement operator.
@ gpu_mpi_distribute_image
@ gpu_mpi_distribute_grid
@ gpu_mpi_distribute_all_to_all
@ mpi_distribute_all_to_all
std::shared_ptr< sopt::LinearTransform< T > > measurement_operator_factory(const distributed_measurement_operator distribute, ARGS &&...args)
Distributed measurement operator factory: returns the measurement operator for the requested distribution method.
std::shared_ptr< sopt::LinearTransform< T > > all_to_all_measurement_operator_factory(const distributed_measurement_operator distribute, const std::vector< t_int > &image_stacks, const std::vector< t_real > &w_stacks, ARGS &&...args)
Distributed measurement operator factory for the all-to-all distribution; additionally takes the image stacks and w-stacks.
std::shared_ptr< sopt::LinearTransform< T > > init_degrid_operator_2d(const Vector< t_real > &u, const Vector< t_real > &v, const Vector< t_real > &w, const Vector< t_complex > &weights, const t_uint &imsizey, const t_uint &imsizex, const t_real &oversample_ratio=2, const kernels::kernel kernel=kernels::kernel::kb, const t_uint Ju=4, const t_uint Jv=4, const bool w_stacking=false, const t_real &cellx=1, const t_real &celly=1)
Returns a linear transform that is the standard degridding operator.