SOPT
Sparse OPTimisation
wavelets.h
Go to the documentation of this file.
1 #ifndef SOPT_WAVELETS_H
2 #define SOPT_WAVELETS_H
3 
4 #include <iostream>
5 // Convenience header to include wavelets headers and additional utilities
6 #include "sopt/config.h"
8 #ifdef SOPT_MPI
10 #endif
11 #include "sopt/wavelets/sara.h"
12 #include "sopt/wavelets/wavelets.h"
13 
14 namespace sopt {
15 namespace details {
16 namespace {
18 template <typename T, typename OP>
21  [op](Vector<T> &out, Vector<T> const &x) { op.indirect(x.array(), out.array()); },
22  [op](Vector<T> &out, Vector<T> const &x) { op.direct(out.array(), x.array()); });
23 }
31 template <typename T, typename OP>
33  t_uint factor = 1) {
35  [op, rows, cols, factor](Vector<T> &out, Vector<T> const &x) {
36  assert(static_cast<t_uint>(x.size()) == rows * cols * factor);
37  out.resize(rows * cols);
38  auto signal = Image<T>::Map(out.data(), rows, cols);
39  auto const coeffs = Image<T>::Map(x.data(), rows, cols * factor);
40  op.indirect(coeffs, signal);
41  },
42  {{0, 1, static_cast<t_int>(rows * cols)}},
43  [op, rows, cols, factor](Vector<T> &out, Vector<T> const &x) {
44  assert(static_cast<t_uint>(x.size()) == rows * cols);
45  out.resize(rows * cols * factor);
46  auto const signal = Image<T>::Map(x.data(), rows, cols);
47  auto coeffs = Image<T>::Map(out.data(), rows, cols * factor);
48  op.direct(coeffs, signal);
49  },
50  {{0, 1, static_cast<t_int>(factor * rows * cols)}});
51 }
52 } // namespace
53 } // namespace details
54 
55 namespace utilities {
//! Return the wavelet basis coefficients of one basis in a dictionary: the `size` entries of
//! `coeffs` starting at offset basis_index * size.
//! NOTE(review): every utility below is declared to return Vector<T>&, but the definitions in
//! this header return Eigen block expressions (segment/head/tail) or a Map over a local
//! Matrix. None of those is a Vector<T> lvalue that outlives the call, so instantiating them
//! looks like it will not compile (or would dangle) — TODO confirm and switch the whole
//! family to return by value or Eigen::Ref in one coordinated change.
template <typename T>
Vector<T> &get_wavelet_basis_coefficients(Vector<T> &coeffs, const t_uint basis_index,
                                          const t_uint size);
//! Return wavelet basis coefficients for a given level and below (1d case).
template <typename T>
Vector<T> &get_wavelet_levels_1d(Vector<T> &coeffs, const t_uint level, const t_uint size);
//! Return wavelet basis coefficients for a given level and below (2d case): the top-left
//! (rows >> level) x (cols >> level) corner of `coeffs` viewed as a rows x cols matrix.
//! NOTE(review): the definition Maps .data() of a topLeftCorner block as if it were contiguous;
//! with column-major storage it is only contiguous when level == 0 — TODO confirm.
template <typename T>
Vector<T> &get_wavelet_levels(Vector<T> &coeffs, const t_uint level, const t_uint rows,
                              const t_uint cols);
//! Return wavelet basis coefficients low pass (rows) and high pass (cols) for a given level
//! (top-right quarter of the level block).
template <typename T>
Vector<T> &get_wavelet_low_high_pass(Vector<T> &coeffs, const t_uint level, const t_uint rows,
                                     const t_uint cols);
//! Return wavelet basis coefficients high pass (rows) and high pass (cols) for a given level
//! (bottom-right quarter of the level block).
template <typename T>
Vector<T> &get_wavelet_high_high_pass(Vector<T> &coeffs, const t_uint level, const t_uint rows,
                                      const t_uint cols);
//! Return wavelet basis coefficients high pass (rows) and low pass (cols) for a given level
//! (bottom-left quarter of the level block).
template <typename T>
Vector<T> &get_wavelet_high_low_pass(Vector<T> &coeffs, const t_uint level, const t_uint rows,
                                     const t_uint cols);
//! Return wavelet basis coefficients low pass (rows) and low pass (cols) for a given level
//! (top-left quarter of the level block).
template <typename T>
Vector<T> &get_wavelet_low_low_pass(Vector<T> &coeffs, const t_uint level, const t_uint rows,
                                    const t_uint cols);
//! Return the 1d high pass coefficients for a given level of a wavelet.
template <typename T>
Vector<T> &get_wavelet_high_pass_1d(Vector<T> &coeffs, const t_uint level, const t_uint size);
// Macro to add a dictionary version of a 2d utility.
// The generated overload takes an extra `basis_index` argument. It first selects that basis's
// coefficients with get_wavelet_basis_coefficients, then forwards them to the single-basis
// overload NAME(coeffs, level, rows, cols).
// Each basis holds rows * cols coefficients. The dictionary transforms map the coefficient
// vector as a rows x (cols * factor) image; assuming Eigen's default column-major storage,
// basis i is one contiguous run of rows * cols entries. Passing coeffs.size() as the
// per-basis size instead would select a segment as long as the whole dictionary and run past
// the end for any basis_index > 0.
// No trailing ';' after the function body, so invocations may be written with or without one.
#define SOPT_WAVELET_MACRO(NAME)                                                               \
  template <typename T>                                                                        \
  Vector<T> &NAME(Vector<T> &coeffs, const t_uint basis_index, const t_uint level,             \
                  const t_uint rows, const t_uint cols);                                       \
  template <typename T>                                                                        \
  Vector<T> &NAME(Vector<T> &coeffs, const t_uint basis_index, const t_uint level,             \
                  const t_uint rows, const t_uint cols) {                                      \
    return NAME(get_wavelet_basis_coefficients(coeffs, basis_index, rows * cols), level, rows, \
                cols);                                                                         \
  }
102 #undef SOPT_WAVELET_MACRO
// implementations
104 template <typename T>
106  const t_uint size) {
107  assert(coeffs.size() > basis_index * size);
108  return coeffs.segment(basis_index * size, size);
109 }
//! Return the coefficients of level `level` and below for a 1d wavelet decomposition:
//! the first coeffs.size() / 2^level entries of `coeffs`.
//! NOTE(review): `size` is not used; the length is derived from coeffs.size() instead.
//! NOTE(review): coeffs.head(N) is an Eigen block expression, not a Vector<T> lvalue, so it
//! cannot bind to the Vector<T>& return type. Instantiating this template looks like it will
//! fail to compile — TODO confirm, and fix together with the matching declaration (return by
//! value or Eigen::Ref).
template <typename T>
Vector<T> &get_wavelet_levels_1d(Vector<T> &coeffs, const t_uint level, const t_uint size) {
  auto const N = static_cast<t_uint>(coeffs.size()) >> level; // bitshift to divide by 2^level
  return coeffs.head(N);
}
//! Return the 1d high pass coefficients for a given level of a wavelet: the last N / 2
//! entries of the level block, where N = coeffs.size() / 2^level.
//! NOTE(review): this calls get_wavelet_levels with three arguments. The 2d overload declared
//! in this header takes four (and any SOPT_WAVELET_MACRO-generated overload takes five), so
//! no viable overload appears to exist. get_wavelet_levels_1d looks like the intended callee
//! — TODO confirm.
//! NOTE(review): .tail(N / 2) yields an Eigen block expression, which cannot bind to the
//! Vector<T>& return type (same issue as get_wavelet_levels_1d).
template <typename T>
Vector<T> &get_wavelet_high_pass_1d(Vector<T> &coeffs, const t_uint level, const t_uint size) {
  auto const N = static_cast<t_uint>(coeffs.size()) >> level; // bitshift to divide by 2^level
  return get_wavelet_levels(coeffs, level, size).tail(N / 2);
}
120 template <typename T>
122  const t_uint cols) {
123  const Matrix<T> signal = Matrix<T>::Map(coeffs.data(), rows, cols);
124  auto const Nx = static_cast<t_uint>(signal.rows()) >> level; // bitshift to divide by 2^level
125  auto const Ny = static_cast<t_uint>(signal.cols()) >> level;
126  return Vector<T>::Map(signal.topLeftCorner(Nx, Ny).data(), Nx * Ny);
127 }
128 template <typename T>
130  const t_uint cols) {
131  auto const Nx = rows >> level; // bitshift to divide by 2^level
132  auto const Ny = cols >> level;
133  const Matrix<T> signal =
134  Matrix<T>::Map(get_wavelet_levels(coeffs, level, rows, cols).data(), Nx, Ny);
135  return Vector<T>::Map(signal.topRightCorner(signal.rows() / 2, signal.cols() / 2).data(),
136  signal.size() / 4);
137 }
138 template <typename T>
140  const t_uint cols) {
141  auto const Nx = rows >> level; // bitshift to divide by 2^level
142  auto const Ny = cols >> level;
143  const Matrix<T> signal =
144  Matrix<T>::Map(get_wavelet_levels(coeffs, level, rows, cols).data(), Nx, Ny);
145  return Vector<T>::Map(signal.bottomRightCorner(signal.rows() / 2, signal.cols() / 2).data(),
146  signal.size() / 4);
147 }
148 template <typename T>
150  const t_uint cols) {
151  auto const Nx = rows >> level; // bitshift to divide by 2^level
152  auto const Ny = cols >> level;
153  const Matrix<T> signal =
154  Matrix<T>::Map(get_wavelet_levels(coeffs, level, rows, cols).data(), Nx, Ny);
155  return Vector<T>::Map(signal.bottomLeftCorner(signal.rows() / 2, signal.cols() / 2).data(),
156  signal.size() / 4);
157 }
158 template <typename T>
160  const t_uint cols) {
161  auto const Nx = rows >> level; // bitshift to divide by 2^level
162  auto const Ny = cols >> level;
163  const Matrix<T> signal =
164  Matrix<T>::Map(get_wavelet_levels(coeffs, level, rows, cols).data(), Nx, Ny);
165  return Vector<T>::Map(signal.topLeftCorner(signal.rows() / 2, signal.cols() / 2).data(),
166  signal.size() / 4);
167 }
168 
169 } // namespace utilities
170 
174 template <typename T>
176  return details::linear_transform<T, wavelets::Wavelet>(wavelet);
177 }
178 
182 template <typename T>
184  return details::linear_transform<T, wavelets::SARA>(sara);
185 }
186 
190 template <typename T>
192  t_uint cols = 1) {
193  return details::linear_transform<T, wavelets::Wavelet>(wavelet, rows, cols);
194 }
201 template <typename T>
203  t_uint cols = 1) {
204  return details::linear_transform<T, wavelets::SARA>(sara, rows, cols, sara.size());
205 }
206 #ifdef SOPT_MPI
215 template <typename T>
216 LinearTransform<Vector<T>> linear_transform(wavelets::SARA const &sara, t_uint rows, t_uint cols,
217  sopt::mpi::Communicator const &comm) {
218  auto const factor = sara.size();
219  auto const normalization = std::sqrt(sara.size()) / std::sqrt(comm.all_sum_all(sara.size()));
221  [sara, rows, cols, factor, comm, normalization](Vector<T> &out, Vector<T> const &x) {
222  assert(static_cast<t_uint>(x.size()) == rows * cols * factor);
223  out.resize(rows * cols);
224  if (sara.empty())
225  out.fill(0);
226  else {
227  auto signal = Image<T>::Map(out.data(), rows, cols);
228  auto const coeffs = Image<T>::Map(x.data(), rows, cols * factor);
229  sara.indirect(coeffs, signal);
230  out *= normalization;
231  }
232  comm.all_sum_all(out);
233  },
234  {{0, 1, static_cast<t_int>(rows * cols)}},
235  [sara, rows, cols, factor, normalization](Vector<T> &out, Vector<T> const &x) {
236  assert(static_cast<t_uint>(x.size()) == rows * cols);
237  out.resize(rows * cols * factor);
238  auto const signal = Image<T>::Map(x.data(), rows, cols);
239  auto coeffs = Image<T>::Map(out.data(), rows, cols * factor);
240  sara.direct(coeffs, signal);
241  out *= normalization;
242  },
243  {{0, 1, static_cast<t_int>(factor * rows * cols)}});
244 }
245 #endif
246 
247 } // namespace sopt
248 
249 #endif
constexpr auto N
Definition: wavelets.cc:57
Joins together direct and indirect operators.
Sparsity Averaging Reweighted Analysis.
Definition: sara.h:20
Performs direct and indirect wavelet transforms.
Definition: wavelets.h:21
t_uint rows
t_uint cols
Vector< T > & get_wavelet_low_low_pass(Vector< T > &coeffs, const t_uint level, const t_uint rows, const t_uint cols)
return wavelet basis coefficients low pass (rows) and low pass (cols) for a given level
Definition: wavelets.h:159
Vector< T > & get_wavelet_high_pass_1d(Vector< T > &coeffs, const t_uint level, const t_uint size)
return 1d high pass filter for a given level of a wavelet
Definition: wavelets.h:116
Vector< T > & get_wavelet_low_high_pass(Vector< T > &coeffs, const t_uint level, const t_uint rows, const t_uint cols)
return wavelet basis coefficients low pass (rows) and high pass (cols) for a given level
Definition: wavelets.h:129
Vector< T > & get_wavelet_high_low_pass(Vector< T > &coeffs, const t_uint level, const t_uint rows, const t_uint cols)
return wavelet basis coefficients high pass (rows) and low pass (cols) for a given level
Definition: wavelets.h:149
Vector< T > & get_wavelet_basis_coefficients(Vector< T > &coeffs, const t_uint basis_index, const t_uint size)
return wavelet basis coefficients from a dictionary
Definition: wavelets.h:105
Vector< T > & get_wavelet_levels_1d(Vector< T > &coeffs, const t_uint level, const t_uint size)
return wavelet basis coefficients for a given level and below (1d case)
Definition: wavelets.h:111
Vector< T > & get_wavelet_levels(Vector< T > &coeffs, const t_uint level, const t_uint rows, const t_uint cols)
return wavelet basis coefficients for a given level and below (2d case)
Definition: wavelets.h:121
Vector< T > & get_wavelet_high_high_pass(Vector< T > &coeffs, const t_uint level, const t_uint rows, const t_uint cols)
return wavelet basis coefficients high pass (rows) and high pass (cols) for a given level
Definition: wavelets.h:139
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
int t_int
Root of the type hierarchy for signed integers.
Definition: types.h:13
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > Matrix
A matrix of a given type.
Definition: types.h:29
sopt::Vector< Scalar > Vector
Definition: inpainting.cc:28
sopt::Image< Scalar > Image
Definition: inpainting.cc:30
#define SOPT_WAVELET_MACRO(NAME)
Definition: wavelets.h:87