1 #ifndef SOPT_IMAGING_FORWARD_BACKWARD_H
2 #define SOPT_IMAGING_FORWARD_BACKWARD_H
4 #include "sopt/config.h"
30 template <
typename SCALAR>
33 using FB = ForwardBackward<SCALAR>;
38 using Real =
typename FB::Real;
51 : FB::
Diagnostic(niters, good, std::move(residual)) {}
70 : g_function_(nullptr),
72 random_updater_(nullptr),
74 residual_tolerance_(0.),
75 relative_variation_(1e-4),
76 residual_convergence_(nullptr),
77 objective_convergence_(nullptr),
78 itermax_(std::numeric_limits<
t_uint>::max()),
79 regulariser_strength_(1e-8),
85 std::shared_ptr<t_LinearTransform> Id = std::make_shared<t_LinearTransform>(linear_transform_identity<Scalar>());
86 problem_state = std::make_shared<IterationState<t_Vector>>(
target, Id);
90 : g_function_(nullptr),
92 random_updater_(updater),
94 residual_tolerance_(0.),
95 relative_variation_(1e-4),
96 residual_convergence_(nullptr),
97 objective_convergence_(nullptr),
98 itermax_(std::numeric_limits<
t_uint>::max()),
99 regulariser_strength_(1e-8),
108 problem_state = random_updater_();
112 throw std::runtime_error(
"Attempted to construct ImagingForwardBackward class with a null random updater. To run without random updates supply a target vector instead.");
119 #define SOPT_MACRO(NAME, TYPE) \
120 TYPE const &NAME() const { return NAME##_; } \
121 ImagingForwardBackward<SCALAR> &NAME(TYPE const &(NAME)) { \
161 problem_state->Phi(
Phi);
168 SOPT_MACRO(adjoint_space_comm, mpi::Communicator);
176 std::shared_ptr<NonDifferentiableFunc<SCALAR>>
g_function() {
return g_function_; }
185 std::shared_ptr<DifferentiableFunc<SCALAR>>
f_function() {
return f_function_; }
194 random_updater_ = new_updater;
200 return (g_function_) ? g_function_->Psi() : linear_transform_identity<Scalar>();
219 problem_state->target(
target);
233 return operator()(out, std::get<0>(guess), std::get<1>(guess));
239 std::tuple<t_Vector const &, t_Vector const &>
const &guess) {
240 return operator()(out, std::get<0>(guess), std::get<1>(guess));
245 return operator()(std::tie(std::get<0>(guess), std::get<1>(guess)));
250 std::tuple<t_Vector const &, t_Vector const &>
const &guess) {
252 static_cast<Diagnostic &
>(result) =
operator()(result.
x, guess);
259 static_cast<Diagnostic &
>(result) =
operator()(
260 result.
x, ForwardBackward<SCALAR>::initial_guess(
target(),
Phi()));
266 static_cast<Diagnostic &
>(result) =
operator()(result.
x, warmstart.
x, warmstart.residual);
271 template <
typename... ARGS>
295 std::shared_ptr<NonDifferentiableFunc<SCALAR>> g_function_;
296 std::shared_ptr<DifferentiableFunc<SCALAR>> f_function_;
299 std::shared_ptr<IterationState<t_Vector>> problem_state;
302 mutable Real objmin_;
328 template <
typename SCALAR>
333 throw std::runtime_error(
"Non-differentiable function `g` has not been set. You must set it with `g_function()` before calling the algorithm.");
335 g_function_->log_message();
337 auto const g_proximal = g_function_->proximal_operator();
338 t_Gradient f_gradient;
339 Real gradient_step_size;
342 f_gradient = f_function_->gradient();
343 gradient_step_size = f_function_->get_step_size();
347 SOPT_MEDIUM_LOG(
"Gradient function has not been set; using default (gaussian likelihood) gradient. (To set a custom gradient set_gradient() must be called before the algorithm is run.)");
349 output = Phi.adjoint() * (residual / (this->
sigma() * this->
sigma()));
353 ScalarRelativeVariation<Scalar> scalvar(relative_variation(), relative_variation(),
354 "Objective function");
355 auto const convergence = [
this, &scalvar](
t_Vector const &x,
t_Vector const &residual)
mutable {
356 const bool result = this->is_converged(scalvar, x, residual);
357 this->objmin_ = std::real(scalvar.previous());
360 auto fb = ForwardBackward<SCALAR>(f_gradient, g_proximal,
target())
362 .step_size(gradient_step_size)
363 .regulariser_strength(regulariser_strength())
366 .is_converged(convergence)
367 .random_updater(random_updater_)
368 .set_problem_state(problem_state);
369 static_cast<typename ForwardBackward<SCALAR>::Diagnostic &
>(result) =
370 fb(out, std::tie(guess, res));
374 template <
typename SCALAR>
377 if (
static_cast<bool>(residual_convergence()))
return residual_convergence()(x, residual);
378 if (residual_tolerance() <= 0e0)
return true;
380 SOPT_LOW_LOG(
" - [FB] Residuals: {} <? {}", residual_norm, residual_tolerance());
381 return residual_norm < residual_tolerance();
384 template <
typename SCALAR>
388 if (
static_cast<bool>(objective_convergence()))
return objective_convergence()(x, residual);
389 if (scalvar.relative_tolerance() <= 0e0)
return true;
390 auto const current = ((regulariser_strength() > 0) ? g_function_->function(x)
391 * regulariser_strength() : 0) + \
393 return scalvar(current);
397 template <
typename SCALAR>
399 ScalarRelativeVariation<Scalar> &scalvar,
402 if (
static_cast<bool>(objective_convergence()))
return objective_convergence()(x, residual);
403 if (scalvar.relative_tolerance() <= 0e0)
return true;
404 auto const current = obj_comm.all_sum_all<
t_real>(
405 ((regulariser_strength() > 0) ? g_function_->function(x)
406 * regulariser_strength() : 0) + std::pow(
sopt::
l2_norm(residual), 2) / (2 * sigma_ * sigma_));
407 return scalvar(current);
411 template <
typename SCALAR>
415 auto const user =
static_cast<bool>(is_converged()) ==
false or is_converged()(x, residual);
416 auto const res = residual_convergence(x, residual);
418 auto const obj = objective_convergence(obj_comm(), scalvar, x, residual);
420 auto const obj = objective_convergence(scalvar, x, residual);
424 return user and res and obj;
sopt::Vector< Scalar > t_Vector
Real objmin() const
Minimum of objective_function.
typename FB::t_Gradient t_Gradient
typename FB::Scalar Scalar
typename FB::t_LinearTransform t_LinearTransform
Diagnostic operator()(t_Vector &out) const
Calls Forward Backward.
ImagingForwardBackward< SCALAR > & random_updater(t_randomUpdater &new_updater)
std::shared_ptr< NonDifferentiableFunc< SCALAR > > g_function()
t_LinearTransform const & Psi() const
ImagingForwardBackward(t_randomUpdater &updater)
ImagingForwardBackward(t_Vector const &target)
std::shared_ptr< DifferentiableFunc< SCALAR > > f_function()
typename FB::t_Proximal t_Proximal
SOPT_MACRO(sigma, Real)
γ parameter.
ImagingForwardBackward< SCALAR > & Phi(t_LinearTransform const &(Phi))
t_LinearTransform const & Phi() const
Measurement operator.
typename FB::value_type value_type
DiagnosticAndResult operator()(std::tuple< t_Vector, t_Vector > const &guess)
Calls Forward Backward.
ImagingForwardBackward< Scalar > & target(t_Vector const &target)
Sets the vector of target measurements.
ImagingForwardBackward< Scalar > & is_converged(std::function< bool(t_Vector const &x)> const &func)
Convergence function that takes only the output as argument.
typename FB::t_Vector t_Vector
t_randomUpdater & random_updater()
DiagnosticAndResult operator()(std::tuple< t_Vector const &, t_Vector const & > const &guess)
Calls Forward Backward.
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector const &, t_Vector const & > const &guess)
Calls Forward Backward.
ImagingForwardBackward< SCALAR > & g_function(std::shared_ptr< NonDifferentiableFunc< SCALAR >> g_function)
SOPT_MACRO(relative_variation, Real)
Convergence of the relative variation of the objective functions.
ImagingForwardBackward< Scalar > & residual_convergence(Real const &tolerance)
Helper function to set-up default residual convergence function.
SOPT_MACRO(step_size, Real)
γ parameter.
DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart)
Makes it simple to chain different calls to FB.
typename FB::t_IsConverged t_IsConverged
ImagingForwardBackward< Scalar > & objective_convergence(Real const &tolerance)
Helper function to set up the default objective convergence function.
SOPT_MACRO(regulariser_strength, Real)
γ parameter.
ImagingForwardBackward &::type Phi(ARGS &&... args)
typename std::function< void(t_Vector &, const t_Vector &)> t_l2Gradient
ImagingForwardBackward< SCALAR > & f_function(std::shared_ptr< DifferentiableFunc< SCALAR >> f_function)
DiagnosticAndResult operator()()
Calls Forward Backward.
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector, t_Vector > const &guess)
Calls Forward Backward.
SOPT_MACRO(itermax, t_uint)
Maximum number of iterations.
SOPT_MACRO(objective_convergence, t_IsConverged)
Convergence of the objective function.
SOPT_MACRO(fista, bool)
Flag for the FISTA Forward-Backward algorithm. True by default but should be false when using a learne...
typename FB::t_randomUpdater t_randomUpdater
SOPT_MACRO(residual_convergence, t_IsConverged)
Convergence of the residuals.
SOPT_MACRO(is_converged, t_IsConverged)
A function verifying convergence.
t_Vector const & target() const
Vector of target measurements.
virtual ~ImagingForwardBackward()
SOPT_MACRO(tight_frame, bool)
Whether Ψ is a tight-frame or not.
SOPT_MACRO(residual_tolerance, Real)
Convergence tolerance on the norm of the residuals.
sopt::LinearTransform< t_Vector > t_LinearTransform
#define SOPT_LOW_LOG(...)
Low priority message.
#define SOPT_MEDIUM_LOG(...)
Medium priority message.
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
size_t t_uint
Root of the type hierarchy for unsigned integers.
Vector< T > target(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
real_type< T >::type sigma(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
real_type< typename T0::Scalar >::type l2_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L2 norm.
Holds result vector as well.
Values indicating how the algorithm ran.
Diagnostic(t_uint niters=0u, bool good=false)
Diagnostic(t_uint niters, bool good, t_Vector &&residual)