SOPT
Sparse OPTimisation
direct.h
Go to the documentation of this file.
1 #ifndef SOPT_WAVELETS_DIRECT_H
2 #define SOPT_WAVELETS_DIRECT_H
3 
4 #include "sopt/config.h"
5 #include <algorithm> // for std::copy<>
6 #include <type_traits>
7 #include "sopt/types.h"
9 
10 // Functions inside an anonymous namespace won't appear in the library
11 namespace sopt::wavelets {
12 
13 namespace {
18 template <typename T0, typename T1>
19 typename std::enable_if<T1::IsVectorAtCompileTime, void>::type direct_transform_impl(
20  Eigen::ArrayBase<T0> const &coeffs_, Eigen::ArrayBase<T1> const &signal,
21  WaveletData const &wavelet) {
22  Eigen::ArrayBase<T0> &coeffs = const_cast<Eigen::ArrayBase<T0> &>(coeffs_);
23  assert(coeffs.size() == signal.size());
24  assert(wavelet.direct_filter.low.size() == wavelet.direct_filter.high.size());
25 
26  auto const N = signal.size() / 2;
27  down_convolve(coeffs.head(N), signal, wavelet.direct_filter.low);
28  down_convolve(coeffs.tail(coeffs.size() - N), signal, wavelet.direct_filter.high);
29 }
30 
35 template <typename T0, typename T1>
36 typename std::enable_if<not T1::IsVectorAtCompileTime, void>::type direct_transform_impl(
37  Eigen::ArrayBase<T0> const &coeffs_, Eigen::ArrayBase<T1> const &signal_,
38  WaveletData const &wavelet) {
39  Eigen::ArrayBase<T0> &coeffs = const_cast<Eigen::ArrayBase<T0> &>(coeffs_);
40  Eigen::ArrayBase<T1> &signal = const_cast<Eigen::ArrayBase<T1> &>(signal_);
41  assert(coeffs.rows() == signal.rows());
42  assert(coeffs.cols() == signal.cols());
43  assert(wavelet.direct_filter.low.size() == wavelet.direct_filter.high.size());
44  for (t_uint i = 0; i < static_cast<t_uint>(coeffs.rows()); ++i)
45  direct_transform_impl(coeffs.row(i).transpose(), signal.row(i).transpose(), wavelet);
46 
47  for (t_uint i = 0; i < static_cast<t_uint>(coeffs.cols()); ++i) {
48  signal.col(i) = coeffs.col(i);
49  direct_transform_impl(coeffs.col(i), signal.col(i), wavelet);
50  }
51 }
52 } // namespace
53 
60 template <typename T0, typename T1>
61 typename std::enable_if<T1::IsVectorAtCompileTime, void>::type direct_transform(
62  Eigen::ArrayBase<T0> &coeffs, Eigen::ArrayBase<T1> const &signal, t_uint levels,
63  WaveletData const &wavelet) {
64  assert(coeffs.rows() == signal.rows());
65  assert(coeffs.cols() == signal.cols());
66 
67  auto input = copy(signal);
68  if (levels > 0) direct_transform_impl(coeffs, input, wavelet);
69  for (t_uint level(1); level < levels; ++level) {
70  auto const N = static_cast<t_uint>(signal.size()) >> level;
71  input.head(N) = coeffs.head(N);
72  direct_transform_impl(coeffs.head(N), input.head(N), wavelet);
73  }
74 }
80 template <typename T0, typename T1>
81 typename std::enable_if<not T1::IsVectorAtCompileTime, void>::type direct_transform(
82  Eigen::ArrayBase<T0> const &coeffs_, Eigen::ArrayBase<T1> const &signal, t_uint levels,
83  WaveletData const &wavelet) {
84  assert(coeffs_.rows() == signal.rows());
85  assert(coeffs_.cols() == signal.cols());
86  Eigen::ArrayBase<T0> &coeffs = const_cast<Eigen::ArrayBase<T0> &>(coeffs_);
87 
88  if (levels == 0) {
89  coeffs = signal;
90  return;
91  }
92 
93  auto input = copy(signal);
94  direct_transform_impl(coeffs, input, wavelet);
95  for (t_uint level(1); level < levels; ++level) {
96  auto const Nx = static_cast<t_uint>(signal.rows()) >> level;
97  auto const Ny = static_cast<t_uint>(signal.cols()) >> level;
98  input.topLeftCorner(Nx, Ny) = coeffs.topLeftCorner(Nx, Ny);
99  direct_transform_impl(coeffs.topLeftCorner(Nx, Ny), input.topLeftCorner(Nx, Ny), wavelet);
100  }
101 }
102 
106 template <typename T0>
107 auto direct_transform(Eigen::ArrayBase<T0> const &signal, t_uint levels, WaveletData const &wavelet)
108  -> decltype(copy(signal)) {
109  auto result = copy(signal);
110  direct_transform(result, signal, levels, wavelet);
111  return result;
112 }
113 } // namespace sopt::wavelets
114 #endif
constexpr auto N
Definition: wavelets.cc:57
std::enable_if< T1::IsVectorAtCompileTime, void >::type direct_transform(Eigen::ArrayBase< T0 > &coeffs, Eigen::ArrayBase< T1 > const &signal, t_uint levels, WaveletData const &wavelet)
N-level 1D direct transform.
Definition: direct.h:61
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
t_vector low
Low-pass filter for direct transform.
Definition: wavelet_data.h:22
t_vector high
High-pass filter for direct transform.
Definition: wavelet_data.h:24
Holds wavelet coefficients.
Definition: wavelet_data.h:11
struct sopt::wavelets::WaveletData::DirectFilter direct_filter