SOPT
Sparse OPTimisation
credible_region.h
Go to the documentation of this file.
1 #ifndef SOPT_CREDIBLE_REGION_H
2 #define SOPT_CREDIBLE_REGION_H
3 
4 #include "sopt/config.h"
5 #include <algorithm> // for std::min()
6 #include <functional>
7 #include <iostream>
8 #include <memory> // for make_shared<>
9 #include <tuple> // for tuple<>
10 #include <type_traits>
11 #include "sopt/bisection_method.h"
12 #include "sopt/exception.h"
13 #include "sopt/logging.h"
14 #include "sopt/types.h"
15 
17 
18 template <typename T>
20  const t_real &alpha, const Eigen::MatrixBase<T> &solution,
21  const std::function<t_real(typename T::PlainObject)> &objective_function);
22 
23 template <typename T>
24 std::tuple<t_real, t_real, t_real> find_credible_interval(
25  const Eigen::MatrixBase<T> &solution, const t_uint &rows, const t_uint &cols,
26  const std::tuple<t_uint, t_uint, t_uint, t_uint> &region,
27  const std::function<t_real(typename T::PlainObject)> &objective_function,
28  const t_real &energy_upperbound);
29 
30 template <typename T, typename K>
31 typename std::enable_if<is_complex<K>::value or std::is_arithmetic<K>::value,
32  std::tuple<Image<K>, Image<K>, Image<K>>>::type
33 credible_interval_grid(const Eigen::MatrixBase<T> &solution, const t_uint &rows, const t_uint &cols,
34  const t_uint &grid_pixel_size,
35  const std::function<t_real(typename T::PlainObject)> &objective_function,
36  const t_real &energy_upperbound);
37 
38 template <typename T, typename K>
39 typename std::enable_if<is_complex<K>::value or std::is_arithmetic<K>::value,
40  std::tuple<Image<K>, Image<K>, Image<K>>>::type
41 credible_interval_grid(const Eigen::MatrixBase<T> &solution, const t_uint &rows, const t_uint &cols,
42  const std::tuple<t_uint, t_uint> &grid_pixel_size,
43  const std::function<t_real(typename T::PlainObject)> &objective_function,
44  const t_real &energy_upperbound);
45 template <typename T, typename K>
46 typename std::enable_if<is_complex<K>::value or std::is_arithmetic<K>::value,
47  std::tuple<Image<K>, Image<K>, Image<K>>>::type
48 credible_interval(const Eigen::MatrixBase<T> &solution, const t_uint &rows, const t_uint &cols,
49  const t_uint &grid_pixel_size,
50  const std::function<t_real(typename T::PlainObject)> &objective_function,
51  const t_real &alpha);
52 template <typename T, typename K>
53 typename std::enable_if<is_complex<K>::value or std::is_arithmetic<K>::value,
54  std::tuple<Image<K>, Image<K>, Image<K>>>::type
55 credible_interval(const Eigen::MatrixBase<T> &solution, const t_uint &rows, const t_uint &cols,
56  const std::tuple<t_uint, t_uint> &grid_pixel_size,
57  const std::function<t_real(typename T::PlainObject)> &objective_function,
58  const t_real &alpha);
59 
60 template <typename T>
62  const t_real &alpha, const Eigen::MatrixBase<T> &solution,
63  const std::function<t_real(typename T::PlainObject)> &objective_function) {
64  if (alpha <= 0) SOPT_THROW("α must positive.");
65  if (alpha >= 1) SOPT_THROW("α must less than 1.");
66  const t_real N = solution.size();
67  const t_real energy = objective_function(solution);
68  auto const gamma = energy + N * (std::sqrt(16 * std::log(3 / (1 - alpha)) / N) + 1);
69  SOPT_MEDIUM_LOG("Confidence interval: %{}", 100 * alpha);
70  SOPT_MEDIUM_LOG("γ = {}, g(x_s) = {}", gamma, energy);
71  return gamma;
72 }
73 
74 template <typename T>
75 std::tuple<t_real, t_real, t_real> find_credible_interval(
76  const Eigen::MatrixBase<T> &solution, const t_uint &rows, const t_uint &cols,
77  const std::tuple<t_uint, t_uint, t_uint, t_uint> &region,
78  const std::function<t_real(typename T::PlainObject)> &objective_function,
79  const t_real &energy_upperbound) {
80  using Derived = typename T::PlainObject;
81  assert(energy_upperbound > 0);
82  if (solution.size() != cols * rows) SOPT_THROW("Solution is wrong size for credible interval.");
83  if ((std::get<2>(region) > rows) or (std::get<3>(region) > cols))
84  SOPT_THROW("Region is out of bounds.");
85  if (energy_upperbound <= 0)
86  SOPT_THROW("Energy upper bound is not positive when calculating credible interval.");
87 
88  const std::shared_ptr<Matrix<typename T::Scalar>> varried_solution =
89  std::make_shared<Matrix<typename T::Scalar>>(solution);
90  *varried_solution = Matrix<typename T::Scalar>::Map(varried_solution->data(), rows, cols);
91  const t_real mean = varried_solution
92  ->block(std::get<0>(region), std::get<1>(region), std::get<2>(region),
93  std::get<3>(region))
94  .array()
95  .real()
96  .mean();
97  const t_real b = (mean > 0)
98  ? solution.cwiseAbs().maxCoeff() * 3
99  : std::max(solution.stableNorm(), static_cast<t_real>(solution.size()));
100  std::function<t_real(t_real)> const bound_estimater = [=](const t_real &x) -> t_real {
101  varried_solution
102  ->block(std::get<0>(region), std::get<1>(region), std::get<2>(region), std::get<3>(region))
103  .fill(mean + x);
104  return objective_function(
105  Vector<typename T::Scalar>::Map(varried_solution->data(), varried_solution->size()));
106  };
107 
108  const t_real bound_lower =
109  sopt::bisection_method(energy_upperbound, bound_estimater, -b, 0., 1e-3);
110  const t_real bound_upper =
111  sopt::bisection_method(energy_upperbound, bound_estimater, 0., b, 1e-3);
112  return std::make_tuple(bound_lower, mean, bound_upper);
113 }
114 template <typename T, typename K>
115 typename std::enable_if<is_complex<K>::value or std::is_arithmetic<K>::value,
116  std::tuple<Image<K>, Image<K>, Image<K>>>::type
117 credible_interval_grid(const Eigen::MatrixBase<T> &solution, const t_uint &rows, const t_uint &cols,
118  const t_uint &grid_pixel_size,
119  const std::function<t_real(typename T::PlainObject)> &objective_function,
120  const t_real &energy_upperbound) {
121  const std::tuple<t_uint, t_uint> grid_pixel_size_2d =
122  std::make_tuple(std::min(grid_pixel_size, rows), std::min(grid_pixel_size, cols));
123  return credible_interval_grid<typename T::PlainObject, K>(
124  solution, rows, cols, grid_pixel_size_2d, objective_function, energy_upperbound);
125 }
126 
127 template <typename T, typename K>
128 typename std::enable_if<is_complex<K>::value or std::is_arithmetic<K>::value,
129  std::tuple<Image<K>, Image<K>, Image<K>>>::type
130 credible_interval_grid(const Eigen::MatrixBase<T> &solution, const t_uint &rows, const t_uint &cols,
131  const std::tuple<t_uint, t_uint> &grid_pixel_size,
132  const std::function<t_real(typename T::PlainObject)> &objective_function,
133  const t_real &energy_upperbound) {
134  if ((std::get<0>(grid_pixel_size) > rows) or (std::get<1>(grid_pixel_size) > cols))
135  SOPT_THROW("Grid pixel size too big.");
136  using Derived = typename T::PlainObject;
137  const t_uint drow = std::get<0>(grid_pixel_size);
138  const t_uint dcol = std::get<1>(grid_pixel_size);
139  const t_uint grid_rows = std::floor(static_cast<t_real>(rows) / drow);
140  const t_uint grid_cols = std::floor(static_cast<t_real>(cols) / dcol);
141  Image<K> credible_grid_lower_bound = Image<K>::Zero(grid_rows, grid_cols);
142  Image<K> credible_grid_upper_bound = Image<K>::Zero(grid_rows, grid_cols);
143  Image<K> credible_grid_mean = Image<K>::Zero(grid_rows, grid_cols);
144  SOPT_LOW_LOG("Starting calculation of credible interval: {} x {} grid.", grid_rows, grid_cols);
145  for (t_uint i = 0; i < grid_rows; i++) {
146  for (t_uint j = 0; j < grid_cols; j++) {
147  const t_uint start_row = i * drow;
148  const t_uint start_col = j * dcol;
149  if (static_cast<t_int>(rows - start_row - drow) < 0)
150  SOPT_THROW("Interval grid calculation going out of bounds.");
151  if (static_cast<t_int>(cols - start_col - dcol) < 0)
152  SOPT_THROW("Interval grid calculation going out of bounds.");
153  const t_uint delta_row =
154  ((drow > (rows - start_row - drow)) and ((rows - start_row - drow) > 0))
155  ? rows - start_row - drow
156  : drow;
157  const t_uint delta_col =
158  ((dcol > (cols - start_col - dcol)) and ((cols - start_col - dcol) > 0))
159  ? cols - start_col - dcol
160  : dcol;
161  SOPT_LOW_LOG("Grid pixel ({}, {}): [{}, {}) x [{}, {})", i, j, start_row,
162  start_row + delta_row, start_col, start_col + delta_col);
163  const auto region = std::make_tuple(start_row, start_col, delta_row, delta_col);
164  const std::tuple<t_real, t_real, t_real> bounds = find_credible_interval(
165  solution, rows, cols, region, objective_function, energy_upperbound);
166  SOPT_LOW_LOG("η- = {}, mean = {}, η+ = {}", std::get<0>(bounds), std::get<1>(bounds),
167  std::get<2>(bounds));
168  credible_grid_lower_bound(i, j) = std::get<0>(bounds);
169  credible_grid_mean(i, j) = std::get<1>(bounds);
170  credible_grid_upper_bound(i, j) = std::get<2>(bounds);
171  }
172  }
173  return std::make_tuple(credible_grid_lower_bound, credible_grid_mean, credible_grid_upper_bound);
174 }
175 template <typename T, typename K>
176 typename std::enable_if<is_complex<K>::value or std::is_arithmetic<K>::value,
177  std::tuple<Image<K>, Image<K>, Image<K>>>::type
178 credible_interval(const Eigen::MatrixBase<T> &solution, const t_uint &rows, const t_uint &cols,
179  const std::tuple<t_uint, t_uint> &grid_pixel_size,
180  const std::function<t_real(typename T::PlainObject)> &objective_function,
181  const t_real &alpha) {
182  const t_real energy_upperbound = compute_energy_upper_bound(alpha, solution, objective_function);
183  return credible_interval_grid<typename T::PlainObject, K>(solution, rows, cols, grid_pixel_size,
184  objective_function, energy_upperbound);
185 }
186 template <typename T, typename K>
187 typename std::enable_if<is_complex<K>::value or std::is_arithmetic<K>::value,
188  std::tuple<Image<K>, Image<K>, Image<K>>>::type
189 credible_interval(const Eigen::MatrixBase<T> &solution, const t_uint &rows, const t_uint &cols,
190  const t_uint &grid_pixel_size,
191  const std::function<t_real(typename T::PlainObject)> &objective_function,
192  const t_real &alpha) {
193  const t_real energy_upperbound = compute_energy_upper_bound(alpha, solution, objective_function);
194  return credible_interval_grid<typename T::PlainObject, K>(solution, rows, cols, grid_pixel_size,
195  objective_function, energy_upperbound);
196 }
197 } // namespace sopt::credible_region
198 
199 #endif
constexpr auto N
Definition: wavelets.cc:57
constexpr Scalar b
t_uint rows
t_uint cols
#define SOPT_THROW(MSG)
Definition: exception.h:46
#define SOPT_LOW_LOG(...)
Low priority message.
Definition: logging.h:227
#define SOPT_MEDIUM_LOG(...)
Medium priority message.
Definition: logging.h:225
std::enable_if< is_complex< K >::value or std::is_arithmetic< K >::value, std::tuple< Image< K >, Image< K >, Image< K > > >::type credible_interval_grid(const Eigen::MatrixBase< T > &solution, const t_uint &rows, const t_uint &cols, const t_uint &grid_pixel_size, const std::function< t_real(typename T::PlainObject)> &objective_function, const t_real &energy_upperbound)
std::tuple< t_real, t_real, t_real > find_credible_interval(const Eigen::MatrixBase< T > &solution, const t_uint &rows, const t_uint &cols, const std::tuple< t_uint, t_uint, t_uint, t_uint > &region, const std::function< t_real(typename T::PlainObject)> &objective_function, const t_real &energy_upperbound)
t_real compute_energy_upper_bound(const t_real &alpha, const Eigen::MatrixBase< T > &solution, const std::function< t_real(typename T::PlainObject)> &objective_function)
std::enable_if< is_complex< K >::value or std::is_arithmetic< K >::value, std::tuple< Image< K >, Image< K >, Image< K > > >::type credible_interval(const Eigen::MatrixBase< T > &solution, const t_uint &rows, const t_uint &cols, const t_uint &grid_pixel_size, const std::function< t_real(typename T::PlainObject)> &objective_function, const t_real &alpha)
int t_int
Root of the type hierarchy for signed integers.
Definition: types.h:13
double t_real
Root of the type hierarchy for real numbers.
Definition: types.h:17
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
std::enable_if< std::is_same< t_real, K >::value, K >::type bisection_method(const K &function_value, const std::function< K(K)> &func, const K &a, const K &b, const t_real &rel_convergence=1e-4)
Find root to a function within an interval.
Eigen::Array< T, Eigen::Dynamic, Eigen::Dynamic > Image
A 2-dimensional list of elements of given type.
Definition: types.h:39
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > Matrix
A matrix of a given type.
Definition: types.h:29