SOPT
Sparse OPTimisation
Functions
sopt::credible_region Namespace Reference

Functions

template<typename T >
t_real compute_energy_upper_bound (const t_real &alpha, const Eigen::MatrixBase< T > &solution, const std::function< t_real(typename T::PlainObject)> &objective_function)
 
template<typename T >
std::tuple< t_real, t_real, t_real > find_credible_interval (const Eigen::MatrixBase< T > &solution, const t_uint &rows, const t_uint &cols, const std::tuple< t_uint, t_uint, t_uint, t_uint > &region, const std::function< t_real(typename T::PlainObject)> &objective_function, const t_real &energy_upperbound)
 
template<typename T , typename K >
std::enable_if< is_complex< K >::value or std::is_arithmetic< K >::value, std::tuple< Image< K >, Image< K >, Image< K > > >::type credible_interval_grid (const Eigen::MatrixBase< T > &solution, const t_uint &rows, const t_uint &cols, const t_uint &grid_pixel_size, const std::function< t_real(typename T::PlainObject)> &objective_function, const t_real &energy_upperbound)
 
template<typename T , typename K >
std::enable_if< is_complex< K >::value or std::is_arithmetic< K >::value, std::tuple< Image< K >, Image< K >, Image< K > > >::type credible_interval_grid (const Eigen::MatrixBase< T > &solution, const t_uint &rows, const t_uint &cols, const std::tuple< t_uint, t_uint > &grid_pixel_size, const std::function< t_real(typename T::PlainObject)> &objective_function, const t_real &energy_upperbound)
 
template<typename T , typename K >
std::enable_if< is_complex< K >::value or std::is_arithmetic< K >::value, std::tuple< Image< K >, Image< K >, Image< K > > >::type credible_interval (const Eigen::MatrixBase< T > &solution, const t_uint &rows, const t_uint &cols, const t_uint &grid_pixel_size, const std::function< t_real(typename T::PlainObject)> &objective_function, const t_real &alpha)
 
template<typename T , typename K >
std::enable_if< is_complex< K >::value or std::is_arithmetic< K >::value, std::tuple< Image< K >, Image< K >, Image< K > > >::type credible_interval (const Eigen::MatrixBase< T > &solution, const t_uint &rows, const t_uint &cols, const std::tuple< t_uint, t_uint > &grid_pixel_size, const std::function< t_real(typename T::PlainObject)> &objective_function, const t_real &alpha)
 

Function Documentation

◆ compute_energy_upper_bound()

template<typename T >
t_real sopt::credible_region::compute_energy_upper_bound ( const t_real &  alpha,
const Eigen::MatrixBase< T > &  solution,
const std::function< t_real(typename T::PlainObject)> &  objective_function 
)

Definition at line 61 of file credible_region.h.

63  {
64  if (alpha <= 0) SOPT_THROW("α must positive.");
65  if (alpha >= 1) SOPT_THROW("α must less than 1.");
66  const t_real N = solution.size();
67  const t_real energy = objective_function(solution);
68  auto const gamma = energy + N * (std::sqrt(16 * std::log(3 / (1 - alpha)) / N) + 1);
69  SOPT_MEDIUM_LOG("Confidence interval: %{}", 100 * alpha);
70  SOPT_MEDIUM_LOG("γ = {}, g(x_s) = {}", gamma, energy);
71  return gamma;
72 }
constexpr auto N
Definition: wavelets.cc:57
#define SOPT_THROW(MSG)
Definition: exception.h:46
sopt::t_real t_real
#define SOPT_MEDIUM_LOG(...)
Medium priority message.
Definition: logging.h:225

References N, SOPT_MEDIUM_LOG, and SOPT_THROW.

Referenced by credible_interval(), and TEST_CASE().

◆ credible_interval() [1/2]

template<typename T , typename K >
std::enable_if< is_complex< K >::value or std::is_arithmetic< K >::value, std::tuple< Image< K >, Image< K >, Image< K > > >::type sopt::credible_region::credible_interval ( const Eigen::MatrixBase< T > &  solution,
const t_uint &  rows,
const t_uint &  cols,
const std::tuple< t_uint, t_uint > &  grid_pixel_size,
const std::function< t_real(typename T::PlainObject)> &  objective_function,
const t_real &  alpha 
)

Definition at line 178 of file credible_region.h.

181  {
182  const t_real energy_upperbound = compute_energy_upper_bound(alpha, solution, objective_function);
183  return credible_interval_grid<typename T::PlainObject, K>(solution, rows, cols, grid_pixel_size,
184  objective_function, energy_upperbound);
185 }
t_uint rows
t_uint cols
t_real compute_energy_upper_bound(const t_real &alpha, const Eigen::MatrixBase< T > &solution, const std::function< t_real(typename T::PlainObject)> &objective_function)

References cols, compute_energy_upper_bound(), and rows.

◆ credible_interval() [2/2]

template<typename T , typename K >
std::enable_if< is_complex< K >::value or std::is_arithmetic< K >::value, std::tuple< Image< K >, Image< K >, Image< K > > >::type sopt::credible_region::credible_interval ( const Eigen::MatrixBase< T > &  solution,
const t_uint &  rows,
const t_uint &  cols,
const t_uint &  grid_pixel_size,
const std::function< t_real(typename T::PlainObject)> &  objective_function,
const t_real &  alpha 
)

Definition at line 189 of file credible_region.h.

192  {
193  const t_real energy_upperbound = compute_energy_upper_bound(alpha, solution, objective_function);
194  return credible_interval_grid<typename T::PlainObject, K>(solution, rows, cols, grid_pixel_size,
195  objective_function, energy_upperbound);
196 }

References cols, compute_energy_upper_bound(), and rows.

◆ credible_interval_grid() [1/2]

template<typename T , typename K >
std::enable_if< is_complex< K >::value or std::is_arithmetic< K >::value, std::tuple< Image< K >, Image< K >, Image< K > > >::type sopt::credible_region::credible_interval_grid ( const Eigen::MatrixBase< T > &  solution,
const t_uint &  rows,
const t_uint &  cols,
const std::tuple< t_uint, t_uint > &  grid_pixel_size,
const std::function< t_real(typename T::PlainObject)> &  objective_function,
const t_real &  energy_upperbound 
)

Definition at line 130 of file credible_region.h.

133  {
134  if ((std::get<0>(grid_pixel_size) > rows) or (std::get<1>(grid_pixel_size) > cols))
135  SOPT_THROW("Grid pixel size too big.");
136  using Derived = typename T::PlainObject;
137  const t_uint drow = std::get<0>(grid_pixel_size);
138  const t_uint dcol = std::get<1>(grid_pixel_size);
139  const t_uint grid_rows = std::floor(static_cast<t_real>(rows) / drow);
140  const t_uint grid_cols = std::floor(static_cast<t_real>(cols) / dcol);
141  Image<K> credible_grid_lower_bound = Image<K>::Zero(grid_rows, grid_cols);
142  Image<K> credible_grid_upper_bound = Image<K>::Zero(grid_rows, grid_cols);
143  Image<K> credible_grid_mean = Image<K>::Zero(grid_rows, grid_cols);
144  SOPT_LOW_LOG("Starting calculation of credible interval: {} x {} grid.", grid_rows, grid_cols);
145  for (t_uint i = 0; i < grid_rows; i++) {
146  for (t_uint j = 0; j < grid_cols; j++) {
147  const t_uint start_row = i * drow;
148  const t_uint start_col = j * dcol;
149  if (static_cast<t_int>(rows - start_row - drow) < 0)
150  SOPT_THROW("Interval grid calculation going out of bounds.");
151  if (static_cast<t_int>(cols - start_col - dcol) < 0)
152  SOPT_THROW("Interval grid calculation going out of bounds.");
153  const t_uint delta_row =
154  ((drow > (rows - start_row - drow)) and ((rows - start_row - drow) > 0))
155  ? rows - start_row - drow
156  : drow;
157  const t_uint delta_col =
158  ((dcol > (cols - start_col - dcol)) and ((cols - start_col - dcol) > 0))
159  ? cols - start_col - dcol
160  : dcol;
161  SOPT_LOW_LOG("Grid pixel ({}, {}): [{}, {}) x [{}, {})", i, j, start_row,
162  start_row + delta_row, start_col, start_col + delta_col);
163  const auto region = std::make_tuple(start_row, start_col, delta_row, delta_col);
164  const std::tuple<t_real, t_real, t_real> bounds = find_credible_interval(
165  solution, rows, cols, region, objective_function, energy_upperbound);
166  SOPT_LOW_LOG("η- = {}, mean = {}, η+ = {}", std::get<0>(bounds), std::get<1>(bounds),
167  std::get<2>(bounds));
168  credible_grid_lower_bound(i, j) = std::get<0>(bounds);
169  credible_grid_mean(i, j) = std::get<1>(bounds);
170  credible_grid_upper_bound(i, j) = std::get<2>(bounds);
171  }
172  }
173  return std::make_tuple(credible_grid_lower_bound, credible_grid_mean, credible_grid_upper_bound);
174 }
#define SOPT_LOW_LOG(...)
Low priority message.
Definition: logging.h:227
std::tuple< t_real, t_real, t_real > find_credible_interval(const Eigen::MatrixBase< T > &solution, const t_uint &rows, const t_uint &cols, const std::tuple< t_uint, t_uint, t_uint, t_uint > &region, const std::function< t_real(typename T::PlainObject)> &objective_function, const t_real &energy_upperbound)
int t_int
Root of the type hierarchy for signed integers.
Definition: types.h:13
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
sopt::Image< Scalar > Image
Definition: inpainting.cc:30

References cols, find_credible_interval(), rows, SOPT_LOW_LOG, and SOPT_THROW.

◆ credible_interval_grid() [2/2]

template<typename T , typename K >
std::enable_if< is_complex< K >::value or std::is_arithmetic< K >::value, std::tuple< Image< K >, Image< K >, Image< K > > >::type sopt::credible_region::credible_interval_grid ( const Eigen::MatrixBase< T > &  solution,
const t_uint &  rows,
const t_uint &  cols,
const t_uint &  grid_pixel_size,
const std::function< t_real(typename T::PlainObject)> &  objective_function,
const t_real &  energy_upperbound 
)

Definition at line 117 of file credible_region.h.

120  {
121  const std::tuple<t_uint, t_uint> grid_pixel_size_2d =
122  std::make_tuple(std::min(grid_pixel_size, rows), std::min(grid_pixel_size, cols));
123  return credible_interval_grid<typename T::PlainObject, K>(
124  solution, rows, cols, grid_pixel_size_2d, objective_function, energy_upperbound);
125 }

References cols, and rows.

◆ find_credible_interval()

template<typename T >
std::tuple< t_real, t_real, t_real > sopt::credible_region::find_credible_interval ( const Eigen::MatrixBase< T > &  solution,
const t_uint &  rows,
const t_uint &  cols,
const std::tuple< t_uint, t_uint, t_uint, t_uint > &  region,
const std::function< t_real(typename T::PlainObject)> &  objective_function,
const t_real &  energy_upperbound 
)

Definition at line 75 of file credible_region.h.

79  {
80  using Derived = typename T::PlainObject;
81  assert(energy_upperbound > 0);
82  if (solution.size() != cols * rows) SOPT_THROW("Solution is wrong size for credible interval.");
83  if ((std::get<2>(region) > rows) or (std::get<3>(region) > cols))
84  SOPT_THROW("Region is out of bounds.");
85  if (energy_upperbound <= 0)
86  SOPT_THROW("Energy upper bound is not positive when calculating credible interval.");
87 
88  const std::shared_ptr<Matrix<typename T::Scalar>> varried_solution =
89  std::make_shared<Matrix<typename T::Scalar>>(solution);
90  *varried_solution = Matrix<typename T::Scalar>::Map(varried_solution->data(), rows, cols);
91  const t_real mean = varried_solution
92  ->block(std::get<0>(region), std::get<1>(region), std::get<2>(region),
93  std::get<3>(region))
94  .array()
95  .real()
96  .mean();
97  const t_real b = (mean > 0)
98  ? solution.cwiseAbs().maxCoeff() * 3
99  : std::max(solution.stableNorm(), static_cast<t_real>(solution.size()));
100  std::function<t_real(t_real)> const bound_estimater = [=](const t_real &x) -> t_real {
101  varried_solution
102  ->block(std::get<0>(region), std::get<1>(region), std::get<2>(region), std::get<3>(region))
103  .fill(mean + x);
104  return objective_function(
105  Vector<typename T::Scalar>::Map(varried_solution->data(), varried_solution->size()));
106  };
107 
108  const t_real bound_lower =
109  sopt::bisection_method(energy_upperbound, bound_estimater, -b, 0., 1e-3);
110  const t_real bound_upper =
111  sopt::bisection_method(energy_upperbound, bound_estimater, 0., b, 1e-3);
112  return std::make_tuple(bound_lower, mean, bound_upper);
113 }
constexpr Scalar b
std::enable_if< std::is_same< t_real, K >::value, K >::type bisection_method(const K &function_value, const std::function< K(K)> &func, const K &a, const K &b, const t_real &rel_convergence=1e-4)
Find root to a function within an interval.
sopt::Vector< Scalar > Vector
Definition: inpainting.cc:28
sopt::Matrix< Scalar > Matrix
Definition: inpainting.cc:29

References b, sopt::bisection_method(), cols, rows, and SOPT_THROW.

Referenced by credible_interval_grid(), and TEST_CASE().