1 #ifndef SOPT_WAVELETS_H
2 #define SOPT_WAVELETS_H
6 #include "sopt/config.h"
18 template <
typename T,
typename OP>
31 template <
typename T,
typename OP>
40 op.indirect(coeffs, signal);
48 op.direct(coeffs, signal);
87 #define SOPT_WAVELET_MACRO(NAME) \
88 template <typename T> \
89 Vector<T> &NAME(Vector<T> &coeffs, const t_uint basis_index, const t_uint level, \
90 const t_uint rows, const t_uint cols); \
91 template <typename T> \
92 Vector<T> &NAME(Vector<T> &coeffs, const t_uint basis_index, const t_uint level, \
93 const t_uint rows, const t_uint cols) { \
94 return NAME(get_wavelet_basis_coefficients(coeffs, basis_index, coeffs.size()), level, rows, \
102 #undef SOPT_WAVELET_MACRO
104 template <
typename T>
107 assert(coeffs.size() > basis_index * size);
108 return coeffs.segment(basis_index * size, size);
110 template <
typename T>
112 auto const N =
static_cast<t_uint>(coeffs.size()) >> level;
113 return coeffs.head(
N);
115 template <
typename T>
117 auto const N =
static_cast<t_uint>(coeffs.size()) >> level;
120 template <
typename T>
124 auto const Nx =
static_cast<t_uint>(signal.rows()) >> level;
125 auto const Ny =
static_cast<t_uint>(signal.cols()) >> level;
126 return Vector<T>::Map(signal.topLeftCorner(Nx, Ny).data(), Nx * Ny);
128 template <
typename T>
131 auto const Nx =
rows >> level;
132 auto const Ny =
cols >> level;
135 return Vector<T>::Map(signal.topRightCorner(signal.rows() / 2, signal.cols() / 2).data(),
138 template <
typename T>
141 auto const Nx =
rows >> level;
142 auto const Ny =
cols >> level;
145 return Vector<T>::Map(signal.bottomRightCorner(signal.rows() / 2, signal.cols() / 2).data(),
148 template <
typename T>
151 auto const Nx =
rows >> level;
152 auto const Ny =
cols >> level;
155 return Vector<T>::Map(signal.bottomLeftCorner(signal.rows() / 2, signal.cols() / 2).data(),
158 template <
typename T>
161 auto const Nx =
rows >> level;
162 auto const Ny =
cols >> level;
165 return Vector<T>::Map(signal.topLeftCorner(signal.rows() / 2, signal.cols() / 2).data(),
174 template <
typename T>
176 return details::linear_transform<T, wavelets::Wavelet>(wavelet);
182 template <
typename T>
184 return details::linear_transform<T, wavelets::SARA>(sara);
190 template <
typename T>
193 return details::linear_transform<T, wavelets::Wavelet>(wavelet,
rows,
cols);
201 template <
typename T>
204 return details::linear_transform<T, wavelets::SARA>(sara,
rows,
cols, sara.size());
215 template <
typename T>
217 sopt::mpi::Communicator
const &comm) {
218 auto const factor = sara.size();
219 auto const normalization = std::sqrt(sara.size()) / std::sqrt(comm.all_sum_all(sara.size()));
229 sara.indirect(coeffs, signal);
230 out *= normalization;
232 comm.all_sum_all(out);
240 sara.direct(coeffs, signal);
241 out *= normalization;
Sparsity Averaging Reweighted Analysis.
Performs direct and indirect wavelet transforms.
Vector< T > & get_wavelet_low_low_pass(Vector< T > &coeffs, const t_uint level, const t_uint rows, const t_uint cols)
return wavelet basis coefficients low pass (rows) and low pass (cols) for a given level
Vector< T > & get_wavelet_high_pass_1d(Vector< T > &coeffs, const t_uint level, const t_uint size)
return 1d high pass wavelet coefficients for a given level
Vector< T > & get_wavelet_low_high_pass(Vector< T > &coeffs, const t_uint level, const t_uint rows, const t_uint cols)
return wavelet basis coefficients low pass (rows) and high pass (cols) for a given level
Vector< T > & get_wavelet_high_low_pass(Vector< T > &coeffs, const t_uint level, const t_uint rows, const t_uint cols)
return wavelet basis coefficients high pass (rows) and low pass (cols) for a given level
Vector< T > & get_wavelet_basis_coefficients(Vector< T > &coeffs, const t_uint basis_index, const t_uint size)
return wavelet basis coefficients from a dictionary
Vector< T > & get_wavelet_levels_1d(Vector< T > &coeffs, const t_uint level, const t_uint size)
return wavelet basis coefficients for a given level and below (1d case)
Vector< T > & get_wavelet_levels(Vector< T > &coeffs, const t_uint level, const t_uint rows, const t_uint cols)
return wavelet basis coefficients for a given level and below (2d case)
Vector< T > & get_wavelet_high_high_pass(Vector< T > &coeffs, const t_uint level, const t_uint rows, const t_uint cols)
return wavelet basis coefficients high pass (rows) and high pass (cols) for a given level
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
int t_int
Root of the type hierarchy for signed integers.
size_t t_uint
Root of the type hierarchy for unsigned integers.
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > Matrix
A matrix of a given type.
sopt::Vector< Scalar > Vector
sopt::Image< Scalar > Image
#define SOPT_WAVELET_MACRO(NAME)