PURIFY
Next-generation radio interferometric imaging
wproj_operators_gpu.h
Go to the documentation of this file.
1 #ifndef PURIFY_WPROJ_OPERATORS_GPU_H
2 #define PURIFY_WPROJ_OPERATORS_GPU_H
3 
4 #include "purify/operators_gpu.h"
6 
7 #ifdef PURIFY_ARRAYFIRE
8 namespace purify {
9 namespace gpu {
10 namespace operators {
11 std::tuple<sopt::OperatorFunction<af::array>, sopt::OperatorFunction<af::array>>
12 af_base_padding_and_FFT_2d(const std::function<t_real(t_real)> &ftkerneluv, const t_uint imsizey,
13  const t_uint imsizex, const t_real oversample_ratio, const t_real w_mean,
14  const t_real cellx, const t_real celly) {
15  sopt::OperatorFunction<af::array> directZ, indirectZ;
16  sopt::OperatorFunction<af::array> directFFT, indirectFFT;
17  const Image<t_complex> S = purify::details::init_correction_radial_2d(
18  oversample_ratio, imsizey, imsizex, ftkerneluv, w_mean, cellx, celly);
19  PURIFY_LOW_LOG("Building GPU Measurement Operator: WGFZDB");
20  PURIFY_LOW_LOG("Constructing Zero Padding and Correction Operator: ZDB");
21  PURIFY_MEDIUM_LOG("Image size (width, height): {} x {}", imsizex, imsizey);
22  PURIFY_MEDIUM_LOG("Oversampling Factor: {}", oversample_ratio);
23  std::tie(directZ, indirectZ) =
24  purify::gpu::operators::init_af_zero_padding_2d(S.cast<t_complexf>(), oversample_ratio);
25  PURIFY_LOW_LOG("Constructing FFT operator: F");
26  std::tie(directFFT, indirectFFT) =
27  purify::gpu::operators::init_af_FFT_2d(imsizey, imsizex, oversample_ratio);
28  const auto direct = sopt::chained_operators<af::array>(directFFT, directZ);
29  const auto indirect = sopt::chained_operators<af::array>(indirectZ, indirectFFT);
30  return std::make_tuple(direct, indirect);
31 }
32 std::tuple<sopt::OperatorFunction<Vector<t_complex>>, sopt::OperatorFunction<Vector<t_complex>>>
33 base_degrid_operator_2d(const Vector<t_real> &u, const Vector<t_real> &v, const Vector<t_real> &w,
34  const Vector<t_complex> &weights, const t_uint &imsizey,
35  const t_uint &imsizex, const t_real oversample_ratio,
36  const kernels::kernel kernel, const t_uint Ju, const t_uint Jw,
37  const bool w_stacking, const t_real cellx, const t_real celly,
38  const t_real absolute_error, const t_real relative_error,
39  const dde_type dde, const t_uint idx = 0) {
40  sopt::OperatorFunction<af::array> directFZ, indirectFZ;
41  sopt::OperatorFunction<af::array> directG, indirectG;
42  t_real const w_mean = w_stacking ? w.array().mean() : 0;
43  switch (dde) {
44  case (dde_type::wkernel_radial): {
45  auto const kerneluvs = purify::create_radial_ftkernel(kernel, Ju, oversample_ratio);
46  std::tie(directFZ, indirectFZ) = af_base_padding_and_FFT_2d(
47  std::get<0>(kerneluvs), imsizey, imsizex, oversample_ratio, w_mean, cellx, celly);
48  PURIFY_MEDIUM_LOG("FoV (width, height): {} deg x {} deg", imsizex * cellx / (60. * 60.),
49  imsizey * celly / (60. * 60.));
50  PURIFY_LOW_LOG("Constructing Weighting and Gridding Operators: WG");
51  PURIFY_MEDIUM_LOG("Number of visibilities: {}", u.size());
52  PURIFY_MEDIUM_LOG("Mean, w: {}, +/- {}", w.array().mean(), (w.maxCoeff() - w.minCoeff()) * 0.5);
53  std::tie(directG, indirectG) =
54  init_af_gridding_matrix_2d(u, v, w.array() - w_mean, weights, imsizey, imsizex,
55  oversample_ratio, std::get<0>(kerneluvs), std::get<1>(kerneluvs),
56  Ju, Jw, cellx, celly, absolute_error, relative_error, dde);
57  break;
58  }
59  case (dde_type::wkernel_2d): {
60  std::function<t_real(t_real)> kernelu, kernelv, ftkernelu, ftkernelv;
61  std::tie(kernelu, kernelv, ftkernelu, ftkernelv) =
62  purify::create_kernels(kernel, Ju, Ju, imsizey, imsizex, oversample_ratio);
63  std::tie(directFZ, indirectFZ) = af_base_padding_and_FFT_2d(
64  ftkernelu, ftkernelv, imsizey, imsizex, oversample_ratio, w_mean, cellx, celly);
65  PURIFY_MEDIUM_LOG("FoV (width, height): {} deg x {} deg", imsizex * cellx / (60. * 60.),
66  imsizey * celly / (60. * 60.));
67  PURIFY_LOW_LOG("Constructing Weighting and Gridding Operators: WG");
68  PURIFY_MEDIUM_LOG("Number of visibilities: {}", u.size());
69  PURIFY_MEDIUM_LOG("Mean, w: {}, +/- {}", w.array().mean(), (w.maxCoeff() - w.minCoeff()) * 0.5);
70  auto const kerneluvs = purify::create_radial_ftkernel(kernel, Ju, oversample_ratio);
71  std::tie(directG, indirectG) =
72  init_af_gridding_matrix_2d(u, v, w.array() - w_mean, weights, imsizey, imsizex,
73  oversample_ratio, std::get<0>(kerneluvs), std::get<1>(kerneluvs),
74  Ju, Jw, cellx, celly, absolute_error, relative_error, dde);
75  break;
76  }
77  default:
78  throw std::runtime_error("DDE is not recognised.");
79  }
80  auto direct = gpu::host_wrapper(sopt::chained_operators<af::array>(directG, directFZ),
81  imsizey * imsizex, u.size(), idx);
82  auto indirect = gpu::host_wrapper(sopt::chained_operators<af::array>(indirectFZ, indirectG),
83  u.size(), imsizex * imsizey, idx);
84  PURIFY_LOW_LOG("Finished consturction of Φ.");
85  return std::make_tuple(direct, indirect);
86 }
87 } // namespace operators
88 
89 namespace measurementoperator {
90 
92 std::shared_ptr<sopt::LinearTransform<Vector<t_complex>>> init_degrid_operator_2d(
93  const Vector<t_real> &u, const Vector<t_real> &v, const Vector<t_real> &w,
94  const Vector<t_complex> &weights, const t_uint imsizey, const t_uint imsizex,
95  const t_real oversample_ratio, const kernels::kernel kernel, const t_uint Ju, const t_uint Jw,
96  const bool w_stacking, const t_real cellx, const t_real celly, const t_real absolute_error,
97  const t_real relative_error, const dde_type dde, const t_uint idx = 0) {
98  std::array<t_int, 3> N = {0, 1, static_cast<t_int>(imsizey * imsizex)};
99  std::array<t_int, 3> M = {0, 1, static_cast<t_int>(u.size())};
100  sopt::OperatorFunction<Vector<t_complex>> directDegrid, indirectDegrid;
101  std::tie(directDegrid, indirectDegrid) = gpu::operators::base_degrid_operator_2d(
102  u, v, w, weights, imsizey, imsizex, oversample_ratio, kernel, Ju, Jw, w_stacking, cellx,
103  celly, absolute_error, relative_error, dde, idx);
104  auto direct = directDegrid;
105  auto indirect = indirectDegrid;
106  return std::make_shared<sopt::LinearTransform<Vector<t_complex>>>(direct, M, indirect, N);
107 }
108 
109 std::shared_ptr<sopt::LinearTransform<Vector<t_complex>>> init_degrid_operator_2d(
110  const utilities::vis_params &uv_vis_input, const t_uint imsizey, const t_uint imsizex,
111  const t_real cell_x, const t_real cell_y, const t_real oversample_ratio,
112  const kernels::kernel kernel, const t_uint Ju, const t_uint Jw, const bool w_stacking,
113  const t_real absolute_error, const t_real relative_error, const dde_type dde,
114  const t_uint idx = 0) {
115  const auto uv_vis = utilities::convert_to_pixels(uv_vis_input, cell_x, cell_y, imsizex, imsizey,
116  oversample_ratio);
117  return init_degrid_operator_2d(uv_vis.u, uv_vis.v, uv_vis.w, uv_vis.weights, imsizey, imsizex,
118  oversample_ratio, kernel, Ju, Jw, w_stacking, cell_x, cell_y,
119  absolute_error, relative_error, dde, idx);
120 }
121 #ifdef PURIFY_MPI
123 std::shared_ptr<sopt::LinearTransform<Vector<t_complex>>> init_degrid_operator_2d(
124  const sopt::mpi::Communicator &comm, const Vector<t_real> &u, const Vector<t_real> &v,
125  const Vector<t_real> &w, const Vector<t_complex> &weights, const t_uint imsizey,
126  const t_uint imsizex, const t_real oversample_ratio, const kernels::kernel kernel,
127  const t_uint Ju, const t_uint Jw, const bool w_stacking, const t_real cellx, const t_real celly,
128  const t_real absolute_error, const t_real relative_error, const dde_type dde,
129  const t_uint idx = 0) {
130  std::array<t_int, 3> N = {0, 1, static_cast<t_int>(imsizey * imsizex)};
131  std::array<t_int, 3> M = {0, 1, static_cast<t_int>(u.size())};
132  sopt::OperatorFunction<Vector<t_complex>> directDegrid, indirectDegrid;
133  std::tie(directDegrid, indirectDegrid) = gpu::operators::base_degrid_operator_2d(
134  u, v, w, weights, imsizey, imsizex, oversample_ratio, kernel, Ju, Jw, w_stacking, cellx,
135  celly, absolute_error, relative_error, dde, idx);
136  const auto allsumall = purify::operators::init_all_sum_all<Vector<t_complex>>(comm);
137  auto direct = directDegrid;
138  auto indirect = sopt::chained_operators<Vector<t_complex>>(allsumall, indirectDegrid);
139  return std::make_shared<sopt::LinearTransform<Vector<t_complex>>>(direct, M, indirect, N);
140 }
141 
142 std::shared_ptr<sopt::LinearTransform<Vector<t_complex>>> init_degrid_operator_2d(
143  const sopt::mpi::Communicator &comm, const utilities::vis_params &uv_vis_input,
144  const t_uint imsizey, const t_uint imsizex, const t_real cell_x, const t_real cell_y,
145  const t_real oversample_ratio, const kernels::kernel kernel, const t_uint Ju, const t_uint Jw,
146  const bool w_stacking, const t_real absolute_error, const t_real relative_error,
147  const dde_type dde, const t_uint idx = 0) {
148  const auto uv_vis = utilities::convert_to_pixels(uv_vis_input, cell_x, cell_y, imsizex, imsizey,
149  oversample_ratio);
150  return init_degrid_operator_2d(comm, uv_vis.u, uv_vis.v, uv_vis.w, uv_vis.weights, imsizey,
151  imsizex, oversample_ratio, kernel, Ju, Jw, w_stacking, cell_x,
152  cell_y, absolute_error, relative_error, dde, idx);
153 }
154 
155 std::shared_ptr<sopt::LinearTransform<Vector<t_complex>>> init_degrid_operator_2d_mpi(
156  const sopt::mpi::Communicator &comm, const Vector<t_real> &u, const Vector<t_real> &v,
157  const Vector<t_real> &w, const Vector<t_complex> &weights, const t_uint imsizey,
158  const t_uint imsizex, const t_real oversample_ratio, const kernels::kernel kernel,
159  const t_uint Ju, const t_uint Jw, const bool w_stacking, const t_real cellx, const t_real celly,
160  const t_real absolute_error, const t_real relative_error, const dde_type dde,
161  const t_uint idx = 0) {
162  throw std::runtime_error("Under construction!");
163  return init_degrid_operator_2d(comm, u, v, w, weights, imsizey, imsizex, oversample_ratio, kernel,
164  Ju, Jw, w_stacking, cellx, celly, absolute_error, relative_error,
165  dde, idx);
166 }
167 
168 std::shared_ptr<sopt::LinearTransform<Vector<t_complex>>> init_degrid_operator_2d_mpi(
169  const sopt::mpi::Communicator &comm, const utilities::vis_params &uv_vis_input,
170  const t_uint imsizey, const t_uint imsizex, const t_real cell_x, const t_real cell_y,
171  const t_real oversample_ratio, const kernels::kernel kernel, const t_uint Ju, const t_uint Jw,
172  const bool w_stacking, const t_real absolute_error, const t_real relative_error,
173  const dde_type dde, const t_uint idx = 0) {
174  const auto uv_vis = utilities::convert_to_pixels(uv_vis_input, cell_x, cell_y, imsizex, imsizey,
175  oversample_ratio);
176  return init_degrid_operator_2d(comm, uv_vis.u, uv_vis.v, uv_vis.w, uv_vis.weights, imsizey,
177  imsizex, oversample_ratio, kernel, Ju, Jw, w_stacking, cell_x,
178  cell_y, absolute_error, relative_error, dde, idx);
179 }
180 #endif
181 } // namespace measurementoperator
182 } // namespace gpu
183 } // namespace purify
184 #endif
185 #endif
#define PURIFY_LOW_LOG(...)
Low priority message.
Definition: logging.h:207
#define PURIFY_MEDIUM_LOG(...)
Medium priority message.
Definition: logging.h:205
const std::vector< t_real > u
data for u coordinate
Definition: operators.cc:18
const std::vector< t_real > v
data for v coordinate
Definition: operators.cc:20
Image< t_complex > init_correction_radial_2d(const t_real oversample_ratio, const t_uint imsizey_, const t_uint imsizex_, const std::function< t_real(t_real)> &ftkerneluv, const t_real w_mean, const t_real cellx, const t_real celly)
std::shared_ptr< sopt::LinearTransform< T > > init_degrid_operator_2d(const Vector< t_real > &u, const Vector< t_real > &v, const Vector< t_real > &w, const Vector< t_complex > &weights, const t_uint &imsizey, const t_uint &imsizex, const t_real &oversample_ratio=2, const kernels::kernel kernel=kernels::kernel::kb, const t_uint Ju=4, const t_uint Jv=4, const bool w_stacking=false, const t_real &cellx=1, const t_real &celly=1)
Returns linear transform that is the standard degridding operator.
Definition: operators.h:608
std::tuple< sopt::OperatorFunction< T >, sopt::OperatorFunction< T > > base_degrid_operator_2d(const Vector< t_real > &u, const Vector< t_real > &v, const Vector< t_real > &w, const Vector< t_complex > &weights, const t_uint &imsizey, const t_uint &imsizex, const t_real &oversample_ratio=2, const kernels::kernel kernel=kernels::kernel::kb, const t_uint Ju=4, const t_uint Jv=4, const fftw_plan &ft_plan=fftw_plan::measure, const bool w_stacking=false, const t_real &cellx=1, const t_real &celly=1, const bool on_the_fly=true)
Definition: operators.h:490
utilities::vis_params w_stacking(utilities::vis_params const &params, sopt::mpi::Communicator const &comm, const t_int iters, const std::function< t_real(t_real)> &cost, const t_real k_means_rel_diff)
utilities::vis_params convert_to_pixels(const utilities::vis_params &uv_vis, const t_real cell_x, const t_real cell_y, const t_real imsizex, const t_real imsizey, const t_real oversample_ratio)
Converts u and v coordinates to units of pixels.
std::tuple< std::function< t_real(t_real)>, std::function< t_real(t_real)> > create_radial_ftkernel(const kernels::kernel kernel_name_, const t_uint Ju_, const t_real oversample_ratio)
Definition: kernels.cc:347
std::complex< float > t_complexf
Definition: types.h:21
std::tuple< std::function< t_real(t_real)>, std::function< t_real(t_real)>, std::function< t_real(t_real)>, std::function< t_real(t_real)> > create_kernels(const kernels::kernel kernel_name_, const t_uint Ju_, const t_uint Jv_, const t_real imsizey_, const t_real imsizex_, const t_real oversample_ratio)
Definition: kernels.cc:249
dde_type
Types of DDEs in purify.
Definition: types.h:59