1 #include <catch2/catch_all.hpp>
14 for (t_iVector::Index i(0); i < x.size(); i += 2) result(i / 2) = x(i);
19 for (t_iVector::Index i(1); i < x.size(); i += 2) result(i / 2) = x(i);
23 Eigen::Array<typename T::Scalar, T::RowsAtCompileTime, T::ColsAtCompileTime>
upsample(
24 Eigen::ArrayBase<T>
const &input) {
25 using Matrix = Eigen::Array<typename T::Scalar, T::RowsAtCompileTime, T::ColsAtCompileTime>;
26 Matrix result(input.size() * 2);
27 for (t_iVector::Index i(0); i < input.size(); ++i) {
28 result(2 * i) = input(i);
29 result(2 * i + 1) = 0;
35 extern std::unique_ptr<std::mt19937_64>
mersenne;
36 std::uniform_int_distribution<sopt::t_int> uniform_dist(min, max);
40 extern std::unique_ptr<std::mt19937_64>
mersenne;
42 std::uniform_int_distribution<sopt::t_int> uniform_dist(min, max);
43 for (t_iVector::Index i(0); i < result.size(); ++i) result(i) = uniform_dist(*
mersenne);
48 template <
typename T0>
51 auto const input = input_.eval();
58 CHECK(input.isApprox(actual, 1e-14));
71 TEST_CASE(
"Wavelet transform innards with integer data",
"[wavelet]") {
77 large << 4, 5, 6, 7, 8, 9;
79 SECTION(
"Periodic scalar product") {
81 CHECK(periodic_scalar_product(large, small, 0) == 1 * 4 + 2 * 5 + 3 * 6);
82 CHECK(periodic_scalar_product(large, small, 1) == 1 * 5 + 2 * 6 + 3 * 7);
83 CHECK(periodic_scalar_product(large, small, 3) == 1 * 7 + 2 * 8 + 3 * 9);
86 CHECK(periodic_scalar_product(large, small, 4) == 1 * 8 + 2 * 9 + 3 * 4);
88 CHECK(periodic_scalar_product(large, small.reverse(), 4) == 3 * 8 + 2 * 9 + 1 * 4);
90 CHECK(periodic_scalar_product(large, small, 4 + large.size()) == 1 * 8 + 2 * 9 + 3 * 4);
91 CHECK(periodic_scalar_product(large, small, 4 - 3 * large.size()) == 1 * 8 + 2 * 9 + 3 * 4);
94 CHECK(periodic_scalar_product(small, large.head(4), 1) == 4 * 2 + 5 * 3 + 6 * 1 + 7 * 2);
100 convolve(result, large, small);
102 CHECK(result(0) == 1 * 4 + 2 * 5 + 3 * 6);
103 CHECK(result(1) == 1 * 5 + 2 * 6 + 3 * 7);
104 CHECK(result(3) == 1 * 7 + 2 * 8 + 3 * 9);
105 CHECK(result(4) == 1 * 8 + 2 * 9 + 3 * 4);
108 SECTION(
"Convolve and sum") {
113 convolve_sum(result, large, small, large, 0 * small);
114 convolve(noOffset, large, small);
115 CHECK(result(small.size() - 1) == noOffset(0));
116 CHECK(result(0) == noOffset(result.size() - small.size() + 1));
119 convolve_sum(result, large, 0 * small, large, small);
120 CHECK(result(small.size() - 1) == noOffset(0));
121 CHECK(result(0) == noOffset(result.size() - small.size() + 1));
124 auto const trial = [&small, &large](
int a,
int b,
int c,
int d) {
126 convolve_sum(result,
a * large,
b * small, c * large, d * small);
131 CHECK((trial(0, 1, 3, 1) == trial(0, 1, 1, 3)).all());
132 CHECK((trial(5, 1, 3, 1) == trial(3, 1, 5, 1)).all());
133 CHECK((trial(1, 5, 3, 1) == trial(3, 1, 5, 1)).all());
134 CHECK((trial(1, 3, 5, 1) == trial(3, 1, 5, 1)).all());
135 CHECK((trial(1, 3, 1, 5) == trial(3, 1, 5, 1)).all());
136 CHECK((trial(1, 0, 4, 2) == trial(3, 1, 5, 1)).all());
137 CHECK((trial(1, -1, 1, 1) == trial(0, 1, 0, 1)).all());
138 CHECK((trial(4, -3, 2, 6) == trial(0, 1, 0, 1)).all());
141 SECTION(
"Convolve and Down-sample simultaneously") {
143 convolve(expected, large, small);
145 down_convolve(actual, large, small);
146 for (
size_t i(0); i < static_cast<size_t>(actual.size()); ++i)
147 CHECK(expected(i * 2) == actual(i));
150 SECTION(
"Convolve output to expression") {
153 convolve(actual.head(large.size()), large, small);
154 convolve(expected, large, small);
155 CHECK((actual.head(large.size()) == expected).all());
158 SECTION(
"Copy does copy") {
159 auto result = copy(large);
160 CHECK(large.data() != result.data());
162 auto actual = copy(large.head(3));
163 CHECK(large.data() != actual.data());
164 CHECK(large.data() == large.head(3).data());
167 SECTION(
"Convolve, Sum and Up-sample simultaneously") {
171 auto const Nhead = Ncoeffs / 2;
172 auto const Ntail = Ncoeffs - Nhead;
181 up_convolve_sum(actual, coeffs,
even(low),
odd(low),
even(high),
odd(high));
183 convolve_sum(expected,
upsample(coeffs.head(Nhead)), low,
upsample(coeffs.tail(Ntail)), high);
184 CHECK((actual == expected).all());
189 TEST_CASE(
"1D wavelet transform with floating point data",
"[wavelet]") {
190 using namespace sopt;
197 REQUIRE((data.rows() % 2 == 0 and (data.cols() == 1 or data.cols() % 2 == 0)));
199 SECTION(
"Direct transform == two downsample + convolution") {
203 down_convolve(high, data.row(0), wavelet.direct_filter.high);
204 down_convolve(low, data.row(0), wavelet.direct_filter.low);
205 CHECK(low.transpose().isApprox(actual.head(data.row(0).size() / 2)));
206 CHECK(high.transpose().isApprox(actual.tail(data.row(0).size() / 2)));
209 SECTION(
"Indirect transform == two upsample + convolution") {
211 auto const low =
upsample(data.row(0).transpose().head(data.rows() / 2));
212 auto const high =
upsample(data.row(0).transpose().tail(data.rows() / 2));
213 auto expected = copy(data.row(0).transpose());
214 convolve_sum(expected, low, wavelet.direct_filter.low.reverse(), high,
215 wavelet.direct_filter.high.reverse());
216 CAPTURE(expected.transpose());
217 CAPTURE(actual.transpose());
218 CHECK(expected.isApprox(actual));
221 SECTION(
"Round-trip test for single level") {
222 for (
t_int i(0); i < 20; ++i) {
227 SECTION(
"Round-trip test for two levels") {
234 t_uint constexpr nlevels = 5;
235 SECTION(
"Round-trip test for multiple levels") {
236 for (
t_int i(0); i < 10; ++i) {
244 TEST_CASE(
"1D wavelet transform with complex data",
"[wavelet]") {
245 using namespace sopt;
247 SECTION(
"Round-trip test for complex data") {
251 CHECK(input.isApprox(actual, 1e-14));
256 TEST_CASE(
"2D wavelet transform with real data",
"[wavelet]") {
257 using namespace sopt;
259 SECTION(
"Single level round-trip test for square matrix") {
263 SECTION(
"Single level round-trip test for non-square matrix") {
265 auto Ny = Nx + 5 * 2;
268 SECTION(
"Round-trip test for multiple levels") {
269 for (
t_int i(0); i < 10; ++i) {
279 using namespace sopt;
282 SECTION(
"Normal instances") {
283 auto const transform = wavelet.direct(input);
285 CHECK(input.isApprox(wavelet.indirect(transform)));
287 SECTION(
"Expression instances") {
289 wavelet.direct(output.row(0).transpose(), input.row(0).transpose());
290 wavelet.indirect(output.row(0).transpose(), output.row(1).transpose());
291 CHECK(input.row(0).isApprox(output.row(1)));
296 using namespace sopt;
300 wavelet.direct(output, input);
301 CHECK(output.rows() == input.rows());
302 CHECK(output.cols() == input.cols());
305 wavelet.indirect(input, output);
306 CHECK(output.rows() == input.rows());
307 CHECK(output.cols() == input.cols());
311 using namespace sopt;
316 wavelet.direct(output, input);
317 CHECK(output.isApprox(input));
320 wavelet.indirect(input, output);
321 CHECK(output.isApprox(input));
std::unique_ptr< std::mt19937_64 > mersenne(new std::mt19937_64(0))
std::enable_if< T1::IsVectorAtCompileTime, void >::type direct_transform(Eigen::ArrayBase< T0 > &coeffs, Eigen::ArrayBase< T1 > const &signal, t_uint levels, WaveletData const &wavelet)
N-level 1D direct wavelet transform.
Wavelet factory(const std::string &name, t_uint nlevels)
Creates a wavelet transform object.
std::enable_if< T1::IsVectorAtCompileTime, void >::type indirect_transform(Eigen::ArrayBase< T0 > const &coeffs, Eigen::ArrayBase< T1 > &signal, t_uint levels, WaveletData const &wavelet)
N-level 1D indirect (inverse) wavelet transform.
WaveletData const & daubechies_data(t_uint n)
Factory function returning the data for a specific Daubechies wavelet.
int t_int
Root of the type hierarchy for signed integers.
size_t t_uint
Root of the type hierarchy for unsigned integers.
Eigen::Array< T, Eigen::Dynamic, Eigen::Dynamic > Image
A 2-dimensional list of elements of a given type.
Eigen::Array< T, Eigen::Dynamic, 1 > Array
A 1-dimensional list of elements of a given type.
sopt::Matrix< Scalar > Matrix
sopt::t_int random_integer(sopt::t_int min, sopt::t_int max)
sopt::Array< sopt::t_uint > t_iVector
t_iVector random_ivector(sopt::t_int size, sopt::t_int min, sopt::t_int max)
void check_round_trip(Eigen::ArrayBase< T0 > const &input_, sopt::t_uint db, sopt::t_uint nlevels=1)
Eigen::Array< typename T::Scalar, T::RowsAtCompileTime, T::ColsAtCompileTime > upsample(Eigen::ArrayBase< T > const &input)
t_iVector even(t_iVector const &x)
t_iVector odd(t_iVector const &x)
TEST_CASE("wavelet data")