1 #ifndef SOPT_CREDIBLE_REGION_H
2 #define SOPT_CREDIBLE_REGION_H
4 #include "sopt/config.h"
10 #include <type_traits>
20 const t_real &alpha,
const Eigen::MatrixBase<T> &solution,
21 const std::function<
t_real(
typename T::PlainObject)> &objective_function);
26 const std::tuple<t_uint, t_uint, t_uint, t_uint> ®ion,
27 const std::function<
t_real(
typename T::PlainObject)> &objective_function,
28 const t_real &energy_upperbound);
30 template <
typename T,
typename K>
31 typename std::enable_if<is_complex<K>::value or std::is_arithmetic<K>::value,
34 const t_uint &grid_pixel_size,
35 const std::function<
t_real(
typename T::PlainObject)> &objective_function,
36 const t_real &energy_upperbound);
38 template <
typename T,
typename K>
39 typename std::enable_if<is_complex<K>::value or std::is_arithmetic<K>::value,
42 const std::tuple<t_uint, t_uint> &grid_pixel_size,
43 const std::function<
t_real(
typename T::PlainObject)> &objective_function,
44 const t_real &energy_upperbound);
45 template <
typename T,
typename K>
46 typename std::enable_if<is_complex<K>::value or std::is_arithmetic<K>::value,
49 const t_uint &grid_pixel_size,
50 const std::function<
t_real(
typename T::PlainObject)> &objective_function,
52 template <
typename T,
typename K>
53 typename std::enable_if<is_complex<K>::value or std::is_arithmetic<K>::value,
56 const std::tuple<t_uint, t_uint> &grid_pixel_size,
57 const std::function<
t_real(
typename T::PlainObject)> &objective_function,
62 const t_real &alpha,
const Eigen::MatrixBase<T> &solution,
63 const std::function<
t_real(
typename T::PlainObject)> &objective_function) {
64 if (alpha <= 0)
SOPT_THROW(
"α must positive.");
65 if (alpha >= 1)
SOPT_THROW(
"α must less than 1.");
66 const t_real N = solution.size();
67 const t_real energy = objective_function(solution);
68 auto const gamma = energy +
N * (std::sqrt(16 * std::log(3 / (1 - alpha)) /
N) + 1);
77 const std::tuple<t_uint, t_uint, t_uint, t_uint> ®ion,
78 const std::function<
t_real(
typename T::PlainObject)> &objective_function,
79 const t_real &energy_upperbound) {
80 using Derived =
typename T::PlainObject;
81 assert(energy_upperbound > 0);
82 if (solution.size() !=
cols *
rows)
SOPT_THROW(
"Solution is wrong size for credible interval.");
83 if ((std::get<2>(region) >
rows) or (std::get<3>(region) >
cols))
85 if (energy_upperbound <= 0)
86 SOPT_THROW(
"Energy upper bound is not positive when calculating credible interval.");
88 const std::shared_ptr<Matrix<typename T::Scalar>> varried_solution =
89 std::make_shared<Matrix<typename T::Scalar>>(solution);
91 const t_real mean = varried_solution
92 ->block(std::get<0>(region), std::get<1>(region), std::get<2>(region),
98 ? solution.cwiseAbs().maxCoeff() * 3
99 : std::max(solution.stableNorm(),
static_cast<t_real>(solution.size()));
102 ->block(std::get<0>(region), std::get<1>(region), std::get<2>(region), std::get<3>(region))
104 return objective_function(
108 const t_real bound_lower =
110 const t_real bound_upper =
112 return std::make_tuple(bound_lower, mean, bound_upper);
114 template <
typename T,
typename K>
115 typename std::enable_if<is_complex<K>::value or std::is_arithmetic<K>::value,
118 const t_uint &grid_pixel_size,
119 const std::function<
t_real(
typename T::PlainObject)> &objective_function,
120 const t_real &energy_upperbound) {
121 const std::tuple<t_uint, t_uint> grid_pixel_size_2d =
122 std::make_tuple(std::min(grid_pixel_size,
rows), std::min(grid_pixel_size,
cols));
123 return credible_interval_grid<typename T::PlainObject, K>(
124 solution,
rows,
cols, grid_pixel_size_2d, objective_function, energy_upperbound);
127 template <
typename T,
typename K>
128 typename std::enable_if<is_complex<K>::value or std::is_arithmetic<K>::value,
131 const std::tuple<t_uint, t_uint> &grid_pixel_size,
132 const std::function<
t_real(
typename T::PlainObject)> &objective_function,
133 const t_real &energy_upperbound) {
134 if ((std::get<0>(grid_pixel_size) >
rows) or (std::get<1>(grid_pixel_size) >
cols))
136 using Derived =
typename T::PlainObject;
137 const t_uint drow = std::get<0>(grid_pixel_size);
138 const t_uint dcol = std::get<1>(grid_pixel_size);
144 SOPT_LOW_LOG(
"Starting calculation of credible interval: {} x {} grid.", grid_rows, grid_cols);
145 for (
t_uint i = 0; i < grid_rows; i++) {
146 for (
t_uint j = 0; j < grid_cols; j++) {
147 const t_uint start_row = i * drow;
148 const t_uint start_col = j * dcol;
149 if (
static_cast<t_int>(
rows - start_row - drow) < 0)
150 SOPT_THROW(
"Interval grid calculation going out of bounds.");
151 if (
static_cast<t_int>(
cols - start_col - dcol) < 0)
152 SOPT_THROW(
"Interval grid calculation going out of bounds.");
154 ((drow > (
rows - start_row - drow)) and ((
rows - start_row - drow) > 0))
155 ?
rows - start_row - drow
158 ((dcol > (
cols - start_col - dcol)) and ((
cols - start_col - dcol) > 0))
159 ?
cols - start_col - dcol
161 SOPT_LOW_LOG(
"Grid pixel ({}, {}): [{}, {}) x [{}, {})", i, j, start_row,
162 start_row + delta_row, start_col, start_col + delta_col);
163 const auto region = std::make_tuple(start_row, start_col, delta_row, delta_col);
165 solution,
rows,
cols, region, objective_function, energy_upperbound);
166 SOPT_LOW_LOG(
"η- = {}, mean = {}, η+ = {}", std::get<0>(bounds), std::get<1>(bounds),
167 std::get<2>(bounds));
168 credible_grid_lower_bound(i, j) = std::get<0>(bounds);
169 credible_grid_mean(i, j) = std::get<1>(bounds);
170 credible_grid_upper_bound(i, j) = std::get<2>(bounds);
173 return std::make_tuple(credible_grid_lower_bound, credible_grid_mean, credible_grid_upper_bound);
175 template <
typename T,
typename K>
176 typename std::enable_if<is_complex<K>::value or std::is_arithmetic<K>::value,
179 const std::tuple<t_uint, t_uint> &grid_pixel_size,
180 const std::function<
t_real(
typename T::PlainObject)> &objective_function,
183 return credible_interval_grid<typename T::PlainObject, K>(solution,
rows,
cols, grid_pixel_size,
184 objective_function, energy_upperbound);
186 template <
typename T,
typename K>
187 typename std::enable_if<is_complex<K>::value or std::is_arithmetic<K>::value,
190 const t_uint &grid_pixel_size,
191 const std::function<
t_real(
typename T::PlainObject)> &objective_function,
194 return credible_interval_grid<typename T::PlainObject, K>(solution,
rows,
cols, grid_pixel_size,
195 objective_function, energy_upperbound);
#define SOPT_LOW_LOG(...)
Low priority message.
#define SOPT_MEDIUM_LOG(...)
Medium priority message.
std::enable_if< is_complex< K >::value or std::is_arithmetic< K >::value, std::tuple< Image< K >, Image< K >, Image< K > > >::type credible_interval_grid(const Eigen::MatrixBase< T > &solution, const t_uint &rows, const t_uint &cols, const t_uint &grid_pixel_size, const std::function< t_real(typename T::PlainObject)> &objective_function, const t_real &energy_upperbound)
std::tuple< t_real, t_real, t_real > find_credible_interval(const Eigen::MatrixBase< T > &solution, const t_uint &rows, const t_uint &cols, const std::tuple< t_uint, t_uint, t_uint, t_uint > &region, const std::function< t_real(typename T::PlainObject)> &objective_function, const t_real &energy_upperbound)
t_real compute_energy_upper_bound(const t_real &alpha, const Eigen::MatrixBase< T > &solution, const std::function< t_real(typename T::PlainObject)> &objective_function)
std::enable_if< is_complex< K >::value or std::is_arithmetic< K >::value, std::tuple< Image< K >, Image< K >, Image< K > > >::type credible_interval(const Eigen::MatrixBase< T > &solution, const t_uint &rows, const t_uint &cols, const t_uint &grid_pixel_size, const std::function< t_real(typename T::PlainObject)> &objective_function, const t_real &alpha)
int t_int
Root of the type hierarchy for signed integers.
double t_real
Root of the type hierarchy for real numbers.
size_t t_uint
Root of the type hierarchy for unsigned integers.
std::enable_if< std::is_same< t_real, K >::value, K >::type bisection_method(const K &function_value, const std::function< K(K)> &func, const K &a, const K &b, const t_real &rel_convergence=1e-4)
Find a root of a function within an interval.
Eigen::Array< T, Eigen::Dynamic, Eigen::Dynamic > Image
A 2-dimensional list of elements of given type.
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > Matrix
A matrix of a given type.