SOPT
Sparse OPTimisation
wavelets.cc
Go to the documentation of this file.
1 #include <catch2/catch_all.hpp>
2 #include <memory>
3 #include <random>
4 
5 #include "sopt/types.h"
6 #include "sopt/wavelets/direct.h"
10 
13  t_iVector result((x.size() + 1) / 2);
14  for (t_iVector::Index i(0); i < x.size(); i += 2) result(i / 2) = x(i);
15  return result;
16 };
17 t_iVector odd(t_iVector const &x) {
18  t_iVector result(x.size() / 2);
19  for (t_iVector::Index i(1); i < x.size(); i += 2) result(i / 2) = x(i);
20  return result;
21 };
22 template <typename T>
23 Eigen::Array<typename T::Scalar, T::RowsAtCompileTime, T::ColsAtCompileTime> upsample(
24  Eigen::ArrayBase<T> const &input) {
25  using Matrix = Eigen::Array<typename T::Scalar, T::RowsAtCompileTime, T::ColsAtCompileTime>;
26  Matrix result(input.size() * 2);
27  for (t_iVector::Index i(0); i < input.size(); ++i) {
28  result(2 * i) = input(i);
29  result(2 * i + 1) = 0;
30  }
31  return result;
32 };
33 
35  extern std::unique_ptr<std::mt19937_64> mersenne;
36  std::uniform_int_distribution<sopt::t_int> uniform_dist(min, max);
37  return uniform_dist(*mersenne);
38 };
40  extern std::unique_ptr<std::mt19937_64> mersenne;
41  t_iVector result(size);
42  std::uniform_int_distribution<sopt::t_int> uniform_dist(min, max);
43  for (t_iVector::Index i(0); i < result.size(); ++i) result(i) = uniform_dist(*mersenne);
44  return result;
45 };
46 
47 // Checks round trip operation
48 template <typename T0>
49 void check_round_trip(Eigen::ArrayBase<T0> const &input_, sopt::t_uint db,
50  sopt::t_uint nlevels = 1) {
51  auto const input = input_.eval();
52  auto const &dbwave = sopt::wavelets::daubechies_data(db);
53  auto const transform = sopt::wavelets::direct_transform(input, nlevels, dbwave);
54  auto const actual = sopt::wavelets::indirect_transform(transform, nlevels, dbwave);
55  CAPTURE(actual);
56  CAPTURE(input);
57  CAPTURE(transform);
58  CHECK(input.isApprox(actual, 1e-14));
59  CHECK(not transform.isApprox(sopt::wavelets::direct_transform(input, nlevels - 1, dbwave), 1e-4));
60 }
61 
62 TEST_CASE("wavelet data") {
63  for (sopt::t_int num = 1; num < 100; num++) {
64  if (num < 39)
65  REQUIRE(sopt::wavelets::daubechies_data(num).coefficients.size() == 2 * num);
66  else
67  REQUIRE_THROWS(sopt::wavelets::daubechies_data(num));
68  }
69 }
70 
71 TEST_CASE("Wavelet transform innards with integer data", "[wavelet]") {
72  using namespace sopt::wavelets;
73 
74  t_iVector small(3);
75  small << 1, 2, 3;
76  t_iVector large(6);
77  large << 4, 5, 6, 7, 8, 9;
78 
79  SECTION("Periodic scalar product") {
80  // no wrapping
81  CHECK(periodic_scalar_product(large, small, 0) == 1 * 4 + 2 * 5 + 3 * 6);
82  CHECK(periodic_scalar_product(large, small, 1) == 1 * 5 + 2 * 6 + 3 * 7);
83  CHECK(periodic_scalar_product(large, small, 3) == 1 * 7 + 2 * 8 + 3 * 9);
84 
85  // with wrapping
86  CHECK(periodic_scalar_product(large, small, 4) == 1 * 8 + 2 * 9 + 3 * 4);
87  // with wrapping and expression
88  CHECK(periodic_scalar_product(large, small.reverse(), 4) == 3 * 8 + 2 * 9 + 1 * 4);
89  // wrapping works with offset as well
90  CHECK(periodic_scalar_product(large, small, 4 + large.size()) == 1 * 8 + 2 * 9 + 3 * 4);
91  CHECK(periodic_scalar_product(large, small, 4 - 3 * large.size()) == 1 * 8 + 2 * 9 + 3 * 4);
92 
93  // signal smaller than filter
94  CHECK(periodic_scalar_product(small, large.head(4), 1) == 4 * 2 + 5 * 3 + 6 * 1 + 7 * 2);
95  }
96 
97  SECTION("Convolve") {
98  t_iVector result(large.size());
99 
100  convolve(result, large, small);
101 
102  CHECK(result(0) == 1 * 4 + 2 * 5 + 3 * 6);
103  CHECK(result(1) == 1 * 5 + 2 * 6 + 3 * 7);
104  CHECK(result(3) == 1 * 7 + 2 * 8 + 3 * 9);
105  CHECK(result(4) == 1 * 8 + 2 * 9 + 3 * 4);
106  }
107 
108  SECTION("Convolve and sum") {
109  t_iVector result(large.size());
110  t_iVector noOffset(large.size());
111 
112  // Check that if high pass is zero, then this is an offseted convolution
113  convolve_sum(result, large, small, large, 0 * small);
114  convolve(noOffset, large, small);
115  CHECK(result(small.size() - 1) == noOffset(0));
116  CHECK(result(0) == noOffset(result.size() - small.size() + 1));
117 
118  // Check same for low pass
119  convolve_sum(result, large, 0 * small, large, small);
120  CHECK(result(small.size() - 1) == noOffset(0));
121  CHECK(result(0) == noOffset(result.size() - small.size() + 1));
122 
123  // Check symmetry relationships
124  auto const trial = [&small, &large](int a, int b, int c, int d) {
125  t_iVector result(large.size());
126  convolve_sum(result, a * large, b * small, c * large, d * small);
127  return result;
128  };
129 
130  // should all be ok as long as arguments sum: (a * b) + (c * d) == (a' * b') + (c' * d')
131  CHECK((trial(0, 1, 3, 1) == trial(0, 1, 1, 3)).all());
132  CHECK((trial(5, 1, 3, 1) == trial(3, 1, 5, 1)).all());
133  CHECK((trial(1, 5, 3, 1) == trial(3, 1, 5, 1)).all());
134  CHECK((trial(1, 3, 5, 1) == trial(3, 1, 5, 1)).all());
135  CHECK((trial(1, 3, 1, 5) == trial(3, 1, 5, 1)).all());
136  CHECK((trial(1, 0, 4, 2) == trial(3, 1, 5, 1)).all());
137  CHECK((trial(1, -1, 1, 1) == trial(0, 1, 0, 1)).all());
138  CHECK((trial(4, -3, 2, 6) == trial(0, 1, 0, 1)).all());
139  }
140 
141  SECTION("Convolve and Down-sample simultaneously") {
142  t_iVector expected(large.size());
143  convolve(expected, large, small);
144  t_iVector actual(large.size() / 2);
145  down_convolve(actual, large, small);
146  for (size_t i(0); i < static_cast<size_t>(actual.size()); ++i)
147  CHECK(expected(i * 2) == actual(i));
148  }
149 
150  SECTION("Convolve output to expression") {
151  t_iVector actual(large.size() * 2);
152  t_iVector expected(large.size());
153  convolve(actual.head(large.size()), large, small);
154  convolve(expected, large, small);
155  CHECK((actual.head(large.size()) == expected).all());
156  }
157 
158  SECTION("Copy does copy") {
159  auto result = copy(large);
160  CHECK(large.data() != result.data());
161 
162  auto actual = copy(large.head(3));
163  CHECK(large.data() != actual.data());
164  CHECK(large.data() == large.head(3).data());
165  }
166 
167  SECTION("Convolve, Sum and Up-sample simultaneously") {
168  for (sopt::t_int i(0); i < 100; ++i) {
169  auto const Ncoeffs = random_integer(2, 10) * 2;
170  auto const Nfilters = random_integer(2, 5);
171  auto const Nhead = Ncoeffs / 2;
172  auto const Ntail = Ncoeffs - Nhead;
173 
174  auto const coeffs = random_ivector(Ncoeffs, -10, 10);
175  auto const low = random_ivector(Nfilters, -10, 10);
176  auto const high = random_ivector(Nfilters, -10, 10);
177 
178  t_iVector actual(Ncoeffs);
179  t_iVector expected(Ncoeffs);
180  // does all in go, more complicated but compuationally less intensive
181  up_convolve_sum(actual, coeffs, even(low), odd(low), even(high), odd(high));
182  // first up-samples, then does convolve: conceptually simpler but does unnecessary operations
183  convolve_sum(expected, upsample(coeffs.head(Nhead)), low, upsample(coeffs.tail(Ntail)), high);
184  CHECK((actual == expected).all());
185  }
186  }
187 }
188 
189 TEST_CASE("1D wavelet transform with floating point data", "[wavelet]") {
190  using namespace sopt;
191  using namespace sopt::wavelets;
192 
193  Image<> const data = Image<>::Random(16, 16);
194  auto const &wavelet = daubechies_data(4);
195 
196  // Condition on input fixture data
197  REQUIRE((data.rows() % 2 == 0 and (data.cols() == 1 or data.cols() % 2 == 0)));
198 
199  SECTION("Direct transform == two downsample + convolution") {
200  auto const actual = direct_transform(data.row(0), 1, wavelet);
201  Array<> high(data.cols() / 2);
202  Array<> low(data.cols() / 2);
203  down_convolve(high, data.row(0), wavelet.direct_filter.high);
204  down_convolve(low, data.row(0), wavelet.direct_filter.low);
205  CHECK(low.transpose().isApprox(actual.head(data.row(0).size() / 2)));
206  CHECK(high.transpose().isApprox(actual.tail(data.row(0).size() / 2)));
207  }
208 
209  SECTION("Indirect transform == two upsample + convolution") {
210  auto const actual = indirect_transform(data.row(0).transpose(), 1, wavelet);
211  auto const low = upsample(data.row(0).transpose().head(data.rows() / 2));
212  auto const high = upsample(data.row(0).transpose().tail(data.rows() / 2));
213  auto expected = copy(data.row(0).transpose());
214  convolve_sum(expected, low, wavelet.direct_filter.low.reverse(), high,
215  wavelet.direct_filter.high.reverse());
216  CAPTURE(expected.transpose());
217  CAPTURE(actual.transpose());
218  CHECK(expected.isApprox(actual));
219  }
220 
221  SECTION("Round-trip test for single level") {
222  for (t_int i(0); i < 20; ++i) {
224  }
225  }
226 
227  SECTION("Round-trip test for two levels") {
231  check_round_trip(Array<>::Random(52), 10, 2);
232  }
233 
234  t_uint constexpr nlevels = 5;
235  SECTION("Round-trip test for multiple levels") {
236  for (t_int i(0); i < 10; ++i) {
237  auto const n = random_integer(2, nlevels);
239  n);
240  }
241  }
242 }
243 
244 TEST_CASE("1D wavelet transform with complex data", "[wavelet]") {
245  using namespace sopt;
246  using namespace sopt::wavelets;
247  SECTION("Round-trip test for complex data") {
248  auto input = Array<t_complex>::Random(random_integer(2, 100) * 2).eval();
249  auto const &dbwave = daubechies_data(random_integer(1, 38));
250  auto const actual = indirect_transform(direct_transform(input, 1, dbwave), 1, dbwave);
251  CHECK(input.isApprox(actual, 1e-14));
252  CHECK(not input.isApprox(direct_transform(input, 1, dbwave), 1e-4));
253  }
254 }
255 
256 TEST_CASE("2D wavelet transform with real data", "[wavelet]") {
257  using namespace sopt;
258  using namespace sopt::wavelets;
259  SECTION("Single level round-trip test for square matrix") {
260  auto N = random_integer(2, 100) * 2;
262  }
263  SECTION("Single level round-trip test for non-square matrix") {
264  auto Nx = random_integer(2, 5) * 2;
265  auto Ny = Nx + 5 * 2;
266  check_round_trip(Image<>::Random(Nx, Ny), random_integer(1, 38), 1);
267  }
268  SECTION("Round-trip test for multiple levels") {
269  for (t_int i(0); i < 10; ++i) {
270  auto const n = random_integer(2, 5);
271  auto const Nx = random_integer(2, 5) * (1u << n);
272  auto const Ny = random_integer(2, 5) * (1u << n);
274  }
275  }
276 }
277 
278 TEST_CASE("Functor implementation", "[wavelet]") {
279  using namespace sopt;
280  auto const wavelet = wavelets::factory("DB3", 4);
281  auto const input = Image<t_complex>::Random(256, 128).eval();
282  SECTION("Normal instances") {
283  auto const transform = wavelet.direct(input);
284  CHECK(transform.isApprox(wavelets::direct_transform(input, wavelet.levels(), wavelet)));
285  CHECK(input.isApprox(wavelet.indirect(transform)));
286  }
287  SECTION("Expression instances") {
288  Image<t_complex> output(2, input.cols());
289  wavelet.direct(output.row(0).transpose(), input.row(0).transpose());
290  wavelet.indirect(output.row(0).transpose(), output.row(1).transpose());
291  CHECK(input.row(0).isApprox(output.row(1)));
292  }
293 }
294 
295 TEST_CASE("Automatic input resizing", "[wavelet]") {
296  using namespace sopt;
297  auto const wavelet = wavelets::factory("DB3", 4);
298  auto const input = Image<t_complex>::Random(256, 128).eval();
299  Image<t_complex> output(1, 1);
300  wavelet.direct(output, input);
301  CHECK(output.rows() == input.rows());
302  CHECK(output.cols() == input.cols());
303 
304  output.resize(1, 1);
305  wavelet.indirect(input, output);
306  CHECK(output.rows() == input.rows());
307  CHECK(output.cols() == input.cols());
308 }
309 
310 TEST_CASE("Dirac wavelets") {
311  using namespace sopt;
312  auto const wavelet = wavelets::factory("Dirac");
313  Image<t_complex> const input = Image<t_complex>::Random(256, 128);
314  Image<t_complex> output(1, 1);
315 
316  wavelet.direct(output, input);
317  CHECK(output.isApprox(input));
318 
319  output = Image<t_complex>::Zero(1, 1);
320  wavelet.indirect(input, output);
321  CHECK(output.isApprox(input));
322 }
constexpr auto n
Definition: wavelets.cc:56
constexpr auto N
Definition: wavelets.cc:57
constexpr Scalar b
constexpr Scalar a
std::unique_ptr< std::mt19937_64 > mersenne(new std::mt19937_64(0))
std::enable_if< T1::IsVectorAtCompileTime, void >::type direct_transform(Eigen::ArrayBase< T0 > &coeffs, Eigen::ArrayBase< T1 > const &signal, t_uint levels, WaveletData const &wavelet)
N-level 1D direct transform.
Definition: direct.h:61
Wavelet factory(const std::string &name, t_uint nlevels)
Creates a wavelet transform object.
Definition: wavelets.cc:8
std::enable_if< T1::IsVectorAtCompileTime, void >::type indirect_transform(Eigen::ArrayBase< T0 > const &coeffs, Eigen::ArrayBase< T1 > &signal, t_uint levels, WaveletData const &wavelet)
N-level 1D indirect transform.
Definition: indirect.h:58
WaveletData const & daubechies_data(t_uint n)
Factory function returning the data for a specific Daubechies wavelet.
int t_int
Root of the type hierarchy for signed integers.
Definition: types.h:13
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
Eigen::Array< T, Eigen::Dynamic, Eigen::Dynamic > Image
A 2-dimensional list of elements of given type.
Definition: types.h:39
Eigen::Array< T, Eigen::Dynamic, 1 > Array
A 1-dimensional list of elements of given type.
Definition: types.h:34
sopt::Matrix< Scalar > Matrix
Definition: inpainting.cc:29
sopt::t_int random_integer(sopt::t_int min, sopt::t_int max)
Definition: wavelets.cc:34
sopt::Array< sopt::t_uint > t_iVector
Definition: wavelets.cc:11
t_iVector random_ivector(sopt::t_int size, sopt::t_int min, sopt::t_int max)
Definition: wavelets.cc:39
void check_round_trip(Eigen::ArrayBase< T0 > const &input_, sopt::t_uint db, sopt::t_uint nlevels=1)
Definition: wavelets.cc:49
Eigen::Array< typename T::Scalar, T::RowsAtCompileTime, T::ColsAtCompileTime > upsample(Eigen::ArrayBase< T > const &input)
Definition: wavelets.cc:23
t_iVector even(t_iVector const &x)
Definition: wavelets.cc:12
t_iVector odd(t_iVector const &x)
Definition: wavelets.cc:17
TEST_CASE("wavelet data")
Definition: wavelets.cc:62