1 #ifndef SOPT_L2_FORWARD_BACKWARD_H
2 #define SOPT_L2_FORWARD_BACKWARD_H
4 #include "sopt/config.h"
23 template <
typename SCALAR>
26 using FB = ForwardBackward<SCALAR>;
31 using Real =
typename FB::Real;
43 : FB::
Diagnostic(niters, good, std::move(residual)) {}
54 template <
typename DERIVED>
57 proximal::
l2_norm(output, regulariser_strength, x);
59 l2_proximal_weighted_(
65 output = Phi.adjoint()*residual;
68 residual_tolerance_(0.),
69 relative_variation_(1e-4),
70 residual_convergence_(
nullptr),
71 objective_convergence_(
nullptr),
72 itermax_(std::numeric_limits<t_uint>::max()),
73 regulariser_strength_(1e-8),
77 Phi_(linear_transform_identity<Scalar>()),
83 #define SOPT_MACRO(NAME, TYPE) \
84 TYPE const &NAME() const { return NAME##_; } \
85 L2ForwardBackward<SCALAR> &NAME(TYPE const &(NAME)) { \
140 template <
typename DERIVED>
149 return operator()(out, ForwardBackward<SCALAR>::initial_guess(
target(), Phi()));
155 return operator()(out, std::get<0>(guess), std::get<1>(guess));
161 std::tuple<t_Vector const &, t_Vector const &>
const &guess)
const {
162 return operator()(out, std::get<0>(guess), std::get<1>(guess));
167 return operator()(std::tie(std::get<0>(guess), std::get<1>(guess)));
172 std::tuple<t_Vector const &, t_Vector const &>
const &guess)
const {
174 static_cast<Diagnostic &
>(result) =
operator()(result.
x, guess);
181 static_cast<Diagnostic &
>(result) =
operator()(
182 result.
x, ForwardBackward<SCALAR>::initial_guess(
target(), Phi()));
188 static_cast<Diagnostic &
>(result) =
operator()(result.
x, warmstart.
x, warmstart.residual);
193 template <
typename... ARGS>
210 return residual_convergence(
nullptr).residual_tolerance(tolerance);
214 return objective_convergence(
nullptr).relative_variation(tolerance);
218 return is_converged([func](
t_Vector const &x,
t_Vector const &) {
return func(x); });
225 mutable Real objmin_;
234 bool residual_convergence(
t_Vector const &x,
t_Vector const &residual)
const;
241 bool objective_convergence(mpi::Communicator
const &obj_comm,
251 template <
typename SCALAR>
252 typename L2ForwardBackward<SCALAR>::Diagnostic L2ForwardBackward<SCALAR>::operator()(
254 SOPT_HIGH_LOG(
"Performing Forward Backward with L2 and L2 norms");
257 auto const g_proximal = [
this](
t_Vector &out, Real regulariser_strength,
t_Vector const &x) {
258 if (this->l2_proximal_weights().size() > 1)
259 this->l2_proximal_weighted()(out, this->l2_proximal_weights() * regulariser_strength, x);
261 this->l2_proximal()(out, this->l2_proximal_weights()(0) * regulariser_strength, x);
266 temp = res / sigma_factor;
267 out = Phi.adjoint() * temp;
269 ScalarRelativeVariation<Scalar> scalvar(relative_variation(), relative_variation(),
270 "Objective function");
271 auto const convergence = [
this, &scalvar](
t_Vector const &x,
t_Vector const &residual)
mutable {
272 const bool result = this->is_converged(scalvar, x, residual);
273 this->objmin_ = std::real(scalvar.previous());
276 auto fb = ForwardBackward<SCALAR>(f_gradient, g_proximal,
target())
278 .step_size(step_size())
279 .regulariser_strength(regulariser_strength())
281 .is_converged(convergence);
282 static_cast<typename ForwardBackward<SCALAR>::Diagnostic &
>(result) =
283 fb(out, std::tie(guess, res));
287 template <
typename SCALAR>
288 bool L2ForwardBackward<SCALAR>::residual_convergence(
t_Vector const &x,
290 if (
static_cast<bool>(residual_convergence()))
return residual_convergence()(x, residual);
291 if (residual_tolerance() <= 0e0)
return true;
293 SOPT_LOW_LOG(
" - [FB] Residuals: {} <? {}", residual_norm, residual_tolerance());
294 return residual_norm < residual_tolerance();
297 template <
typename SCALAR>
298 bool L2ForwardBackward<SCALAR>::objective_convergence(ScalarRelativeVariation<Scalar> &scalvar,
301 if (
static_cast<bool>(objective_convergence()))
return objective_convergence()(x, residual);
302 if (scalvar.relative_tolerance() <= 0e0)
return true;
303 auto const current = ((regulariser_strength() > 0) ?
sopt::l2_norm(x, l2_proximal_weights()) * regulariser_strength() : 0) +
305 return scalvar(current);
309 template <
typename SCALAR>
310 bool L2ForwardBackward<SCALAR>::objective_convergence(mpi::Communicator
const &obj_comm,
311 ScalarRelativeVariation<Scalar> &scalvar,
314 if (
static_cast<bool>(objective_convergence()))
return objective_convergence()(x, residual);
315 if (scalvar.relative_tolerance() <= 0e0)
return true;
316 auto const current = obj_comm.all_sum_all<
t_real>(
317 ((regulariser_strength() > 0) ?
sopt::l2_norm(x, l2_proximal_weights()) * regulariser_strength() : 0) +
319 return scalvar(current);
323 template <
typename SCALAR>
324 bool L2ForwardBackward<SCALAR>::is_converged(ScalarRelativeVariation<Scalar> &scalvar,
327 auto const user =
static_cast<bool>(is_converged()) ==
false or is_converged()(x, residual);
328 auto const res = residual_convergence(x, residual);
330 auto const obj = objective_convergence(obj_comm(), scalvar, x, residual);
332 auto const obj = objective_convergence(scalvar, x, residual);
336 return user and res and obj;
sopt::Vector< Scalar > t_Vector
SOPT_MACRO(residual_tolerance, Real)
Residual-norm threshold below which the residuals are considered converged.
SOPT_MACRO(Phi, t_LinearTransform)
Measurement operator.
SOPT_MACRO(sigma, Real)
σ parameter.
DiagnosticAndResult operator()(std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Forward Backward.
SOPT_MACRO(is_converged, t_IsConverged)
A function verifying convergence.
L2ForwardBackward(Eigen::MatrixBase< DERIVED > const &target)
t_Proximal< Real > & l2_proximal()
L2 proximal used during calculation.
t_Gradient & l2_gradient()
Gradient of the l2 norm.
Real objmin() const
Minimum of the objective function.
typename FB::Scalar Scalar
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector, t_Vector > const &guess) const
Calls Forward Backward.
DiagnosticAndResult operator()(std::tuple< t_Vector, t_Vector > const &guess) const
Calls Forward Backward.
SOPT_MACRO(tight_frame, bool)
Whether Ψ is a tight-frame or not.
typename FB::t_LinearTransform t_LinearTransform
L2ForwardBackward< Scalar > & objective_convergence(Real const &tolerance)
Helper function to set up the default objective convergence function.
typename FB::t_Gradient t_Gradient
DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart) const
Makes it simple to chain different calls to FB.
std::function< void(t_Vector &, const T &, const t_Vector &)> t_Proximal
typename FB::t_Vector t_Vector
DiagnosticAndResult operator()() const
Calls Forward Backward.
L2ForwardBackward &::type Phi(ARGS &&... args)
SOPT_MACRO(itermax, t_uint)
Maximum number of iterations.
SOPT_MACRO(step_size, Real)
γ parameter.
L2ForwardBackward< Scalar > & target(Eigen::MatrixBase< DERIVED > const &target)
Sets the vector of target measurements.
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Forward Backward.
SOPT_MACRO(objective_convergence, t_IsConverged)
Convergence of the objective function.
L2ForwardBackward< Scalar > & is_converged(std::function< bool(t_Vector const &x)> const &func)
Convergence function that takes only the output as argument.
SOPT_MACRO(l2_proximal, t_Proximal< Real >)
l2 proximal for regularization
SOPT_MACRO(l2_gradient, t_Gradient)
Gradient of the l2 norm.
SOPT_MACRO(residual_convergence, t_IsConverged)
Convergence of the residuals.
L2ForwardBackward< Scalar > & residual_convergence(Real const &tolerance)
Helper function to set-up default residual convergence function.
SOPT_MACRO(l2_proximal_weighted, t_Proximal< Vector< Real >>)
l2 proximal for regularization with weights
Diagnostic operator()(t_Vector &out) const
Calls Forward Backward.
typename FB::value_type value_type
SOPT_MACRO(relative_variation, Real)
Convergence of the relative variation of the objective functions.
t_Proximal< Vector< Real > > & l2_proximal_weighted()
t_Vector const & target() const
Vector of target measurements.
typename FB::t_IsConverged t_IsConverged
SOPT_MACRO(regulariser_strength, Real)
Regulariser strength (weight of the l2 term in the objective function).
SOPT_MACRO(l2_proximal_weights, Vector< Real >)
l2 proximal weights
virtual ~L2ForwardBackward()
sopt::LinearTransform< t_Vector > t_LinearTransform
#define SOPT_MACRO(NAME, TYPE)
#define SOPT_LOW_LOG(...)
Low priority message.
#define SOPT_HIGH_LOG(...)
High priority message.
void l2_norm(Eigen::DenseBase< T0 > &out, typename real_type< typename T0::Scalar >::type gamma, Eigen::DenseBase< T1 > const &x)
Proximal of the l2 norm (note this is different from the l2 ball indicator function)
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
double t_real
Root of the type hierarchy for real numbers.
size_t t_uint
Root of the type hierarchy for unsigned integers.
Vector< T > target(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
real_type< T >::type sigma(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
real_type< typename T0::Scalar >::type l2_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L2 norm.
Holds result vector as well.
Values indicating how the algorithm ran.
Diagnostic(t_uint niters=0u, bool good=false)
Diagnostic(t_uint niters, bool good, t_Vector &&residual)