1 #ifndef SOPT_FORWARD_BACKWARD_H
2 #define SOPT_FORWARD_BACKWARD_H
4 #include "sopt/config.h"
26 template <
typename SCALAR>
27 class ForwardBackward {
30 using value_type = SCALAR;
40 using t_IsConverged = std::function<bool(
const t_Vector &,
const t_Vector &)>;
46 using t_randomUpdater = std::function<std::shared_ptr<IterationState<t_Vector>>()>;
57 Diagnostic(
t_uint niters = 0u,
bool good =
false)
58 : niters(niters), good(good), residual(t_Vector::Zero(0)) {}
60 : niters(niters), good(good), residual(std::move(residual)) {}
63 struct DiagnosticAndResult :
public Diagnostic {
71 ForwardBackward(t_Gradient
const &f_gradient, t_Proximal
const &g_proximal,
73 : itermax_(std::numeric_limits<t_uint>::max()),
74 regulariser_strength_(1e-8),
78 f_gradient_(f_gradient),
79 g_proximal_(g_proximal)
81 std::shared_ptr<t_LinearTransform> Id = std::make_shared<t_LinearTransform>(linear_transform_identity<Scalar>());
82 problem_state = std::make_shared<IterationState<t_Vector>>(
target, Id);
84 virtual ~ForwardBackward() {}
88 #define SOPT_MACRO(NAME, TYPE) \
89 TYPE const &NAME() const { return NAME##_; } \
90 ForwardBackward<SCALAR> &NAME(TYPE const &(NAME)) { \
119 problem_state->Phi(new_phi);
123 ForwardBackward<SCALAR> &random_updater(t_randomUpdater &rU)
125 random_updater_ = rU;
139 void g_proximal(
t_Vector &out, Real regulariser_strength,
t_Vector const &x)
const {
140 g_proximal()(out, regulariser_strength, x);
144 ForwardBackward<Scalar> &is_converged(std::function<
bool(
t_Vector const &x)>
const &func) {
145 return is_converged([func](
t_Vector const &x,
t_Vector const &) {
return func(x); });
149 t_Vector const &
target()
const {
return problem_state->target(); }
152 problem_state->target(
target);
158 return static_cast<bool>(is_converged()) and is_converged()(x, residual);
163 Diagnostic operator()(
t_Vector &out) {
return operator()(out, initial_guess()); }
167 Diagnostic operator()(
t_Vector &out, std::tuple<t_Vector, t_Vector>
const &guess) {
168 return operator()(out, std::get<0>(guess), std::get<1>(guess));
173 Diagnostic operator()(
t_Vector &out,
174 std::tuple<t_Vector const &, t_Vector const &>
const &guess) {
175 return operator()(out, std::get<0>(guess), std::get<1>(guess));
179 DiagnosticAndResult operator()(std::tuple<t_Vector, t_Vector>
const &guess) {
180 return operator()(std::tie(std::get<0>(guess), std::get<1>(guess)));
184 DiagnosticAndResult operator()(
185 std::tuple<t_Vector const &, t_Vector const &>
const &guess) {
186 DiagnosticAndResult result;
187 static_cast<Diagnostic &
>(result) =
operator()(result.x, guess);
192 DiagnosticAndResult operator()() {
193 DiagnosticAndResult result;
194 static_cast<Diagnostic &
>(result) =
operator()(result.x, initial_guess());
198 DiagnosticAndResult operator()(DiagnosticAndResult
const &warmstart) {
199 DiagnosticAndResult result = warmstart;
200 static_cast<Diagnostic &
>(result) =
operator()(result.x, warmstart.x, warmstart.residual);
204 template <
typename... ARGS>
205 typename std::enable_if<
sizeof...(ARGS) >= 1, ForwardBackward &>::type Phi(ARGS &&... args) {
214 std::tuple<t_Vector, t_Vector> initial_guess()
const {
215 return ForwardBackward<SCALAR>::initial_guess(
target(), Phi());
224 static std::tuple<t_Vector, t_Vector> initial_guess(
t_Vector const &
target,
226 std::tuple<t_Vector, t_Vector> guess;
228 std::get<1>(guess) = phi * std::get<0>(guess) -
target;
237 void sanity_check(
t_Vector const &x_guess,
t_Vector const &res_guess)
const {
238 if ((Phi().adjoint() *
target()).size() != x_guess.size())
239 SOPT_THROW(
"target, adjoint measurement operator and input vector have inconsistent sizes");
240 if (
target().size() != res_guess.size())
241 SOPT_THROW(
"target and residual vector have inconsistent sizes");
242 if ((Phi() * x_guess).size() !=
target().size())
243 SOPT_THROW(
"target, measurement operator and input vector have inconsistent sizes");
244 if (not
static_cast<bool>(is_converged()))
245 SOPT_WARN(
"No convergence function was provided: algorithm will run for {} steps", itermax());
255 std::shared_ptr<IterationState<t_Vector>> problem_state;
256 t_randomUpdater random_updater_;
273 template <
typename SCALAR>
277 f_gradient(gradient_current, auxilliary_image, residual, Phi());
278 t_Vector auxilliary_with_step = auxilliary_image - step_size() / Phi().sq_norm() * gradient_current;
279 const Real weight = regulariser_strength() * step_size();
280 g_proximal(image, weight, auxilliary_with_step);
281 auxilliary_image = image + FISTA_step * (image - prev_image);
286 problem_state = random_updater_();
288 residual = (Phi() * auxilliary_image) -
target();
291 template <
typename SCALAR>
292 typename ForwardBackward<SCALAR>::Diagnostic ForwardBackward<SCALAR>::operator()(
300 sanity_check(x_guess, res_guess);
302 const size_t image_size = x_guess.size();
304 t_Vector auxilliary_image = x_guess;
306 t_Vector gradient_current = t_Vector::Zero(image_size);
310 bool converged =
false;
312 Real theta_new = 1.0;
313 Real FISTA_step = 0.0;
314 for (; (not converged) && (niters < itermax()); ++niters) {
317 theta_new = (1 + std::sqrt(1 + 4 * theta * theta)) / 2.;
318 FISTA_step = (theta - 1) / (theta_new);
322 iteration_step(out, residual, auxilliary_image, gradient_current, FISTA_step);
323 SOPT_LOW_LOG(
" - [FB] Sum of residuals: {}", residual.array().abs().sum());
324 converged = is_converged(out, residual);
328 SOPT_MEDIUM_LOG(
" - [FB] converged in {} of {} iterations", niters, itermax());
329 }
else if (
static_cast<bool>(is_converged())) {
331 SOPT_ERROR(
" - [FB] did not converge within {} iterations", itermax());
333 return {niters, converged, std::move(residual)};
sopt::Vector< Scalar > t_Vector
Computes inner-most element type.
#define SOPT_MACRO(NAME, TYPE)
#define SOPT_LOW_LOG(...)
Low priority message.
#define SOPT_HIGH_LOG(...)
High priority message.
#define SOPT_ERROR(...)
Something is definitely wrong; the algorithm exits.
#define SOPT_WARN(...)
Something might be going wrong.
#define SOPT_MEDIUM_LOG(...)
Medium priority message.
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
double t_real
Root of the type hierarchy for real numbers.
size_t t_uint
Root of the type hierarchy for unsigned integers.
Vector< T > target(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
std::function< void(Vector< SCALAR > &output, typename real_type< SCALAR >::type const weight, Vector< SCALAR > const &input)> ProximalFunction
Typical function signature for calls to proximal.