1 #ifndef SOPT_PROXIMAL_ADMM_H
2 #define SOPT_PROXIMAL_ADMM_H
4 #include "sopt/config.h"
18 template <
typename SCALAR>
59 template <
typename DERIVED>
61 Eigen::MatrixBase<DERIVED>
const &
target)
62 : itermax_(std::numeric_limits<
t_uint>::max()),
63 regulariser_strength_(1e-8),
64 lagrange_update_scale_(0.9),
74 #define SOPT_MACRO(NAME, TYPE) \
75 TYPE const &NAME() const { return NAME##_; } \
76 ProximalADMM<SCALAR> &NAME(TYPE const &(NAME)) { \
119 template <
typename DERIVED>
137 return operator()(out, std::get<0>(guess), std::get<1>(guess));
143 std::tuple<t_Vector const &, t_Vector const &>
const &guess)
const {
144 return operator()(out, std::get<0>(guess), std::get<1>(guess));
149 return operator()(std::tie(std::get<0>(guess), std::get<1>(guess)));
154 std::tuple<t_Vector const &, t_Vector const &>
const &guess)
const {
156 static_cast<Diagnostic &
>(result) =
operator()(result.
x, guess);
173 template <
typename... ARGS>
174 typename std::enable_if<
sizeof...(ARGS) >= 1,
ProximalADMM &>::type
Phi(ARGS &&... args) {
195 std::tuple<t_Vector, t_Vector> guess;
197 std::get<1>(guess) = phi * std::get<0>(guess) -
target;
205 void sanity_check(
t_Vector const &x_guess,
t_Vector const &res_guess)
const {
206 if ((
Phi().adjoint() *
target()).size() != x_guess.size())
207 SOPT_THROW(
"target, adjoint measurement operator and input vector have inconsistent sizes");
208 if (
target().size() != res_guess.size())
209 SOPT_THROW(
"target and residual vector have inconsistent sizes");
210 if ((
Phi() * x_guess).size() !=
target().size())
211 SOPT_THROW(
"target, measurement operator and input vector have inconsistent sizes");
213 SOPT_WARN(
"No convergence function was provided: algorithm will run for {} steps", itermax());
226 template <
typename SCALAR>
229 g_proximal(z, regulariser_strength(), -lambda - residual);
230 f_proximal(out, regulariser_strength() / Phi().sq_norm(),
231 out -
static_cast<t_Vector>(Phi().adjoint() * (residual + lambda + z)) / Phi().sq_norm());
233 lambda += lagrange_update_scale() * (residual + z);
236 template <
typename SCALAR>
240 sanity_check(x_guess, res_guess);
248 bool converged =
false;
249 for (; (not converged) && (niters < itermax()); ++niters) {
250 SOPT_LOW_LOG(
" - [PADMM] Iteration {}/{}", niters, itermax());
251 iteration_step(out, residual, lambda, z);
252 SOPT_LOW_LOG(
" - [PADMM] Sum of residuals: {}", residual.array().abs().sum());
253 converged = is_converged(out, residual);
257 SOPT_MEDIUM_LOG(
" - [PADMM] converged in {} of {} iterations", niters, itermax());
258 }
else if (
static_cast<bool>(is_converged())) {
260 SOPT_ERROR(
" - [PADMM] did not converge within {} iterations", itermax());
262 return {niters, converged, std::move(residual)};
sopt::Vector< Scalar > t_Vector
Proximal Alternating Direction Method of Multipliers.
ProximalADMM(t_Proximal const &f_proximal, t_Proximal const &g_proximal, Eigen::MatrixBase< DERIVED > const &target)
void g_proximal(t_Vector &out, Real regulariser_strength, t_Vector const &x) const
Simplifies calling the proximal of g.
SOPT_MACRO(regulariser_strength, Real)
γ parameter.
SOPT_MACRO(itermax, t_uint)
Maximum number of iterations.
SCALAR value_type
Scalar type.
t_Vector const & target() const
Vector of target measurements.
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Proximal ADMM.
ProximalADMM< Scalar > & is_converged(std::function< bool(t_Vector const &x)> const &func)
Convergence function that takes only the output as argument.
Vector< Scalar > t_Vector
Type of the underlying vectors.
Diagnostic operator()(t_Vector &out) const
Calls Proximal ADMM.
ProximalADMM< Scalar > & target(Eigen::MatrixBase< DERIVED > const &target)
Sets the vector of target measurements.
value_type Scalar
Scalar type.
ProximalFunction< Scalar > t_Proximal
Type of the proximal functions.
SOPT_MACRO(is_converged, t_IsConverged)
A function verifying convergence.
SOPT_MACRO(g_proximal, t_Proximal)
Second proximal.
SOPT_MACRO(lagrange_update_scale, Real)
Lagrange update scale β
std::tuple< t_Vector, t_Vector > initial_guess() const
Computes initial guess for x and the residual using the targets.
ProximalADMM &::type Phi(ARGS &&... args)
DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart) const
Makes it simple to chain different calls to PADMM.
SOPT_MACRO(f_proximal, t_Proximal)
First proximal.
DiagnosticAndResult operator()(std::tuple< t_Vector, t_Vector > const &guess) const
Calls Proximal ADMM.
bool is_converged(t_Vector const &x, t_Vector const &residual) const
Facilitates call to user-provided convergence function.
DiagnosticAndResult operator()(std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Proximal ADMM.
std::function< bool(const t_Vector &, const t_Vector &)> t_IsConverged
Type of the convergence function.
static std::tuple< t_Vector, t_Vector > initial_guess(t_Vector const &target, t_LinearTransform const &phi)
Computes initial guess for x and the residual using the targets.
void f_proximal(t_Vector &out, Real regulariser_strength, t_Vector const &x) const
Simplifies calling the proximal of f.
SOPT_MACRO(Phi, t_LinearTransform)
Measurement operator.
typename real_type< Scalar >::type Real
Real type.
DiagnosticAndResult operator()() const
Calls Proximal ADMM.
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector, t_Vector > const &guess) const
Calls Proximal ADMM.
Computes inner-most element type.
#define SOPT_LOW_LOG(...)
Low priority message.
#define SOPT_HIGH_LOG(...)
High priority message.
#define SOPT_ERROR(...)
Something is definitely wrong; the algorithm exits.
#define SOPT_WARN(...)
Something might be going wrong.
#define SOPT_MEDIUM_LOG(...)
Medium priority message.
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
size_t t_uint
Root of the type hierarchy for unsigned integers.
Vector< T > target(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
LinearTransform< Vector< SCALAR > > linear_transform_identity()
Helper function to create a linear transform that's just the identity.
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
std::function< void(Vector< SCALAR > &output, typename real_type< SCALAR >::type const weight, Vector< SCALAR > const &input)> ProximalFunction
Typical function signature for calls to proximal.
Holds result vector as well.
Values indicating how the algorithm ran.
bool good
Whether convergence was achieved.
Diagnostic(t_uint niters=0u, bool good=false)
t_uint niters
Number of iterations.
Diagnostic(t_uint niters, bool good, t_Vector &&residual)
t_Vector residual
The residual from the last iteration.