SOPT
Sparse OPTimisation
credible_region.cc
Go to the documentation of this file.
#include "sopt/credible_region.h"

#include <cmath>
#include <functional>
#include <iostream>
#include <tuple>

#include "catch2/catch_all.hpp"

#include "sopt/logging.h"
#include "sopt/types.h"
6 
7 using namespace sopt;
8 using Scalar = t_complex;
11 using Catch::Approx;
12 t_uint rows = 128;
13 t_uint cols = 128;
15 
16 TEST_CASE("calculating gamma") {
17  logging::set_level("debug");
18  const std::function<t_real(t_Vector)> energy_function = [](const t_Vector &input) -> t_real {
19  return 0.;
20  };
21  const t_Vector x = t_Vector::Random(N);
22  CHECK(0 == energy_function(x));
23  for (t_uint i = 1; i < 10; i++) {
24  const t_real alpha = 0.9 + i * 0.01;
25  const t_real gamma = credible_region::compute_energy_upper_bound(alpha, x, energy_function);
26  CHECK(gamma == Approx(N * (std::sqrt(16 * std::log(3 / (1 - alpha)) / N) + 1)));
27  }
28 }
29 TEST_CASE("caculating upper and lower interval") {
30  const t_Vector x = t_Vector::Constant(N, 0.5);
31  const std::function<t_real(t_Vector)> energy_function = [](const t_Vector &input) -> t_real {
32  return (input.array()).cwiseAbs().maxCoeff();
33  };
34  constexpr t_real gamma = 1.;
35  std::tuple<t_uint, t_uint, t_uint, t_uint> const region = std::make_tuple(0, 0, rows, cols);
36  CAPTURE(gamma);
37  t_real lower = 0;
38  t_real upper = 0;
39  t_real mean = 0;
40  std::tie(lower, mean, upper) =
41  credible_region::find_credible_interval(x, rows, cols, region, energy_function, gamma);
42  CHECK(std::abs(lower + 1.5) <= 1e-2);
43  CHECK(std::abs(mean - 0.5) <= 1e-2);
44  CHECK(std::abs(upper - 0.5) <= 1e-2);
45  CAPTURE(lower);
46  CAPTURE(mean);
47  CAPTURE(upper);
48  std::tie(lower, mean, upper) = credible_region::find_credible_interval(
49  x, rows, cols,
50  std::make_tuple(std::floor(rows * 0.25), std::floor(cols * 0.25), std::floor(rows * 0.5),
51  std::floor(cols * 0.5)),
52  energy_function, gamma);
53  CHECK(std::abs(lower + 1.5) <= 1e-2);
54  CHECK(std::abs(upper - 0.5) <= 1e-2);
55  CHECK(std::abs(mean - 0.5) <= 1e-2);
56  CAPTURE(lower);
57  CAPTURE(mean);
58  CAPTURE(upper);
59 }
60 
61 TEST_CASE("calculating upper and lower interval grid") {
62  constexpr t_uint pix_size = 16;
63  const t_uint grid_cols = std::floor(cols / pix_size);
64  const t_uint grid_rows = std::floor(rows / pix_size);
65  constexpr t_real gamma = 1.;
66  t_Image image = t_Image::Constant(rows, cols, 0);
67  const Image<t_real> expected_lower = Image<t_real>::Constant(grid_rows, grid_cols, -gamma);
68  const Image<t_real> expected_mean = Image<t_real>::Constant(grid_rows, grid_cols, 0);
69  const Image<t_real> expected_upper = Image<t_real>::Constant(grid_rows, grid_cols, gamma);
70  const t_Vector x = t_Vector::Map(image.data(), image.size());
71  const std::function<t_real(t_Vector)> energy_function = [&](const t_Vector &input) -> t_real {
72  return input.cwiseAbs().maxCoeff();
73  };
77  std::tie(lower, mean, upper) = credible_region::credible_interval_grid<t_Vector, t_real>(
78  x, rows, cols, pix_size, energy_function, gamma);
79  CHECK(expected_lower.isApprox(lower, 1e-2));
80  CHECK(expected_mean.isApprox(mean, 1e-2));
81  CHECK(expected_upper.isApprox(upper, 1e-2));
82 }
83 
84 TEST_CASE("calculating upper and lower interval grid non const") {
85  constexpr t_uint pix_size = 16;
86  rows = 145;
87  cols = 153;
88  N = rows * cols;
89  const t_uint grid_cols = std::ceil(cols / pix_size);
90  const t_uint grid_rows = std::ceil(rows / pix_size);
91  t_Image image = t_Image::Constant(rows, cols, 0);
92  const t_Vector x = t_Vector::Map(image.data(), image.size());
93  const std::function<t_real(t_Vector)> energy_function = [&](const t_Vector &input) -> t_real {
94  return input.cwiseAbs().maxCoeff();
95  };
96  constexpr t_real gamma = 1.;
100  std::tie(lower, mean, upper) = credible_region::credible_interval_grid<t_Vector, t_real>(
101  x, rows, cols, pix_size, energy_function, gamma);
102  Image<t_real> const expected_lower = Image<t_real>::Constant(grid_rows, grid_cols, -gamma);
103  Image<t_real> const expected_mean = Image<t_real>::Constant(grid_rows, grid_cols, 0);
104  Image<t_real> const expected_upper = Image<t_real>::Constant(grid_rows, grid_cols, gamma);
105  CHECK(expected_lower.isApprox(lower, 1e-2));
106  CHECK(expected_mean.isApprox(mean, 1e-2));
107  CHECK(expected_upper.isApprox(upper, 1e-2));
108 }
sopt::Vector< Scalar > t_Vector
sopt::t_real Scalar
t_uint rows
Image< Scalar > t_Image
t_uint cols
t_uint N
TEST_CASE("calculating gamma")
std::tuple< t_real, t_real, t_real > find_credible_interval(const Eigen::MatrixBase< T > &solution, const t_uint &rows, const t_uint &cols, const std::tuple< t_uint, t_uint, t_uint, t_uint > &region, const std::function< t_real(typename T::PlainObject)> &objective_function, const t_real &energy_upperbound)
t_real compute_energy_upper_bound(const t_real &alpha, const Eigen::MatrixBase< T > &solution, const std::function< t_real(typename T::PlainObject)> &objective_function)
void set_level(const std::string &level)
Method to set the logging level of the default Log object.
Definition: logging.h:154
double t_real
Root of the type hierarchy for real numbers.
Definition: types.h:17
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
Eigen::Array< T, Eigen::Dynamic, Eigen::Dynamic > Image
A 2-dimensional list of elements of given type.
Definition: types.h:39
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
std::complex< t_real > t_complex
Root of the type hierarchy for (real) complex numbers.
Definition: types.h:19