SOPT
Sparse OPTimisation
sara.cc
Go to the documentation of this file.
1 #include <catch2/catch_all.hpp>
2 #include <random>
3 #include <string>
4 #include <tuple>
5 
6 #include "sopt/wavelets.h"
7 #include "sopt/wavelets/sara.h"
8 
10  extern std::unique_ptr<std::mt19937_64> mersenne;
11  std::uniform_int_distribution<sopt::t_int> uniform_dist(min, max);
12  return uniform_dist(*mersenne);
13 };
14 
15 TEST_CASE("Check SARA implementation mechanically", "[wavelet]") {
16  using namespace sopt::wavelets;
17  using namespace sopt;
18 
19  using t_i = std::tuple<std::string, sopt::t_uint>;
20  SARA const sara{t_i{std::string{"DB3"}, 1u}, t_i{std::string{"DB1"}, 2u},
21  t_i{std::string{"DB1"}, 3u}};
22  SECTION("Construction and vector functionality") {
23  CHECK(sara.size() == 3);
24  CHECK(sara[0].levels() == 1);
25  CHECK(sara[1].levels() == 2);
26  CHECK(sara[2].levels() == 3);
27  CHECK(sara.max_levels() == 3);
28  CHECK(sara[0].coefficients.isApprox(factory("DB3", 1).coefficients));
29  CHECK(sara[1].coefficients.isApprox(factory("DB1", 1).coefficients));
30  CHECK(sara[2].coefficients.isApprox(factory("DB1", 1).coefficients));
31  }
32 
33  Image<> const input = Image<>::Random((1u << sara.max_levels()) * 3, (1u << sara.max_levels()));
34  Image<> coeffs;
35  sara.direct(coeffs, input);
36 
37  SECTION("Direct transform") {
38  Image<> const first = sara[0].direct(input) / std::sqrt(sara.size());
39  Image<> const second = sara[1].direct(input) / std::sqrt(sara.size());
40  Image<> const third = sara[2].direct(input) / std::sqrt(sara.size());
41 
42  auto const N = input.cols();
43  CAPTURE(coeffs.leftCols(N));
44  CAPTURE(first);
45  CHECK(coeffs.leftCols(N).isApprox(first));
46  CHECK(coeffs.leftCols(2 * N).rightCols(N).isApprox(second));
47  CHECK(coeffs.rightCols(N).isApprox(third));
48  }
49 
50  SECTION("Indirect transform") {
51  auto const output = sara.indirect(coeffs);
52  CHECK(output.isApprox(input));
53  }
54 }
55 
56 TEST_CASE("Linear-transform wrapper", "[wavelet]") {
57  using namespace sopt::wavelets;
58  using namespace sopt;
59  SARA const sara{std::make_tuple(std::string{"DB3"}, 1u), std::make_tuple(std::string{"DB1"}, 2u),
60  std::make_tuple(std::string{"DB1"}, 3u)};
61  SECTION("1d") {
62  auto constexpr rows = 256, cols = 1;
63  auto const Psi = linear_transform<t_real>(sara, rows, cols);
64  SECTION("Indirect transform") {
65  Image<> const image = Image<>::Random(rows, cols);
66  Image<> const expected = sara.direct(image);
67  // The linear transform expects a column vector as input
68  auto const as_vector = Vector<>::Map(image.data(), image.size());
69  // And it returns a column vector as well
70  Vector<> const actual = Psi.adjoint() * as_vector;
71  CHECK(actual.size() == expected.size());
72  auto const coeffs = Image<>::Map(actual.data(), image.rows(), image.cols() * sara.size());
73  CHECK(expected.rows() == coeffs.rows());
74  CHECK(expected.cols() == coeffs.cols());
75  CHECK(coeffs.isApprox(expected, 1e-8));
76  }
77  SECTION("direct transform") {
78  Image<> const coeffs = Image<>::Random(rows, cols * sara.size());
79  Image<> const expected = sara.indirect(coeffs);
80  // The linear transform expects a column vector as input
81  auto const as_vector = Vector<>::Map(coeffs.data(), coeffs.size());
82  // And it returns a column vector as well
83  Vector<> const actual = Psi * as_vector;
84  CHECK(actual.size() == expected.size());
85  CHECK(coeffs.cols() % sara.size() == 0);
86  auto const image = Image<>::Map(actual.data(), coeffs.rows(), coeffs.cols() / sara.size());
87  CHECK(expected.rows() == image.rows());
88  CHECK(expected.cols() == image.cols());
89  CHECK(image.isApprox(expected, 1e-8));
90  }
91  }
92  SECTION("2d") {
93  auto constexpr rows = 256, cols = 256;
94  auto const Psi = linear_transform<t_real>(sara, rows, cols);
95  SECTION("Indirect transform") {
96  Image<> const image = Image<>::Random(rows, cols);
97  Image<> const expected = sara.direct(image);
98  // The linear transform expects a column vector as input
99  auto const as_vector = Vector<>::Map(image.data(), image.size());
100  // And it returns a column vector as well
101  Vector<> const actual = Psi.adjoint() * as_vector;
102  CHECK(actual.size() == expected.size());
103  auto const coeffs = Image<>::Map(actual.data(), image.rows(), image.cols() * sara.size());
104  CHECK(expected.rows() == coeffs.rows());
105  CHECK(expected.cols() == coeffs.cols());
106  CHECK(coeffs.isApprox(expected, 1e-8));
107  }
108  SECTION("direct transform") {
109  Image<> const coeffs = Image<>::Random(rows, cols * sara.size());
110  Image<> const expected = sara.indirect(coeffs);
111  // The linear transform expects a column vector as input
112  auto const as_vector = Vector<>::Map(coeffs.data(), coeffs.size());
113  // And it returns a column vector as well
114  Vector<> const actual = Psi * as_vector;
115  CHECK(actual.size() == expected.size());
116  CHECK(coeffs.cols() % sara.size() == 0);
117  auto const image = Image<>::Map(actual.data(), coeffs.rows(), coeffs.cols() / sara.size());
118  CHECK(expected.rows() == image.rows());
119  CHECK(expected.cols() == image.cols());
120  CHECK(image.isApprox(expected, 1e-8));
121  }
122  }
123 }
constexpr auto N
Definition: wavelets.cc:57
Sparsity Averaging Reweighted Analysis.
Definition: sara.h:20
std::unique_ptr< std::mt19937_64 > mersenne(new std::mt19937_64(0))
t_uint rows
t_uint cols
Wavelet factory(const std::string &name, t_uint nlevels)
Creates a wavelet transform object.
Definition: wavelets.cc:8
int t_int
Root of the type hierarchy for signed integers.
Definition: types.h:13
Eigen::Array< T, Eigen::Dynamic, Eigen::Dynamic > Image
A 2-dimensional list of elements of given type.
Definition: types.h:39
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
sopt::t_int random_integer(sopt::t_int min, sopt::t_int max)
Definition: sara.cc:9
TEST_CASE("Check SARA implementation mechanically", "[wavelet]")
Definition: sara.cc:15