1 #ifndef SOPT_L1_PROXIMAL_ADMM_H
2 #define SOPT_L1_PROXIMAL_ADMM_H
4 #include "sopt/config.h"
19 template <
typename SCALAR>
54 template <
typename DERIVED>
57 l2ball_proximal_(1e0),
59 residual_tolerance_(1e-4),
60 relative_variation_(1e-4),
61 residual_convergence_(nullptr),
62 objective_convergence_(nullptr),
63 itermax_(std::numeric_limits<
t_uint>::max()),
64 regulariser_strength_(1e-8),
65 lagrange_update_scale_(0.9),
73 #define SOPT_MACRO(NAME, TYPE) \
74 TYPE const &NAME() const { return NAME##_; } \
75 ImagingProximalADMM<SCALAR> &NAME(TYPE const &(NAME)) { \
123 template <
typename DERIVED>
138 return operator()(out, std::get<0>(guess), std::get<1>(guess));
144 std::tuple<t_Vector const &, t_Vector const &>
const &guess)
const {
145 return operator()(out, std::get<0>(guess), std::get<1>(guess));
150 return operator()(std::tie(std::get<0>(guess), std::get<1>(guess)));
155 std::tuple<t_Vector const &, t_Vector const &>
const &guess)
const {
157 static_cast<Diagnostic &
>(result) =
operator()(result.
x, guess);
164 static_cast<Diagnostic &
>(result) =
operator()(result.
x,
176 template <
typename... ARGS>
193 template <
typename... ARGS>
205 #define SOPT_MACRO(VAR, NAME, PROXIMAL) \
207 decltype(std::declval<proximal::PROXIMAL<Scalar> const>().VAR()) NAME##_proximal_##VAR() const { \
208 return NAME##_proximal().VAR(); \
211 ImagingProximalADMM<Scalar> &NAME##_proximal_##VAR( \
212 decltype(std::declval<proximal::PROXIMAL<Scalar> const>().VAR()) (VAR)) { \
213 NAME##_proximal().VAR(VAR); \
226 SOPT_MACRO(communicator, l2ball, WeightedL2Ball);
256 template <
typename T0,
typename T1>
258 Eigen::MatrixBase<T1>
const &x)
const {
259 return l1_proximal_real_constraint()
260 ? call_l1_proximal(out, regulariser_strength, x.real().template cast<typename T1::Scalar>())
261 : call_l1_proximal(out, regulariser_strength, x);
265 template <
typename T0,
typename T1>
267 Eigen::MatrixBase<T1>
const &x)
const {
270 return {0, 0,
l1_proximal().objective(x, out, regulariser_strength),
true};
272 return l1_proximal()(out, regulariser_strength, x);
287 template <
typename SCALAR>
290 SOPT_HIGH_LOG(
"Performing Proximal ADMM with L1 and L2 operators");
293 auto const f_proximal = [
this, &result](
t_Vector &out, Real regulariser_strength,
t_Vector const &x) {
294 result.l1_diagnostic = this->l1_proximal(out, regulariser_strength, x);
296 auto const g_proximal = [
this](
t_Vector &out, Real regulariser_strength,
t_Vector const &x) {
297 this->l2ball_proximal()(out, regulariser_strength, x);
299 ScalarRelativeVariation<Scalar> scalvar(relative_variation(), relative_variation(),
300 "Objective function");
301 auto const convergence = [
this, scalvar](
t_Vector const &x,
t_Vector const &residual)
mutable {
302 return this->is_converged(scalvar, x, residual);
304 auto const padmm = PADMM(f_proximal, g_proximal,
target())
306 .regulariser_strength(regulariser_strength())
307 .lagrange_update_scale(lagrange_update_scale())
309 .is_converged(convergence);
310 static_cast<typename PADMM::Diagnostic &
>(result) = padmm(out, std::tie(guess, res));
314 template <
typename SCALAR>
317 if (
static_cast<bool>(residual_convergence()))
return residual_convergence()(x, residual);
318 if (residual_tolerance() <= 0e0)
return true;
319 auto const residual_norm =
sopt::l2_norm(residual, l2ball_proximal_weights());
320 SOPT_LOW_LOG(
" - [PADMM] Residuals: {} <? {}", residual_norm, residual_tolerance());
321 return residual_norm < residual_tolerance();
324 template <
typename SCALAR>
328 if (
static_cast<bool>(objective_convergence()))
return objective_convergence()(x, residual);
329 if (scalvar.relative_tolerance() <= 0e0)
return true;
332 return scalvar(current);
335 template <
typename SCALAR>
338 auto const user =
static_cast<bool>(is_converged()) ==
false or is_converged()(x, residual);
339 auto const res = residual_convergence(x, residual);
340 auto const obj = objective_convergence(scalvar, x, residual);
343 return user and res and obj;
sopt::Vector< Scalar > t_Vector
ImagingProximalADMM(Eigen::MatrixBase< DERIVED > const &target)
SOPT_MACRO(is_converged, t_IsConverged)
A function verifying convergence.
SOPT_MACRO(weights, l2ball, WeightedL2Ball)
SOPT_MACRO(epsilon, l2ball, WeightedL2Ball)
SOPT_MACRO(residual_tolerance, Real)
Tolerance on the weighted L2 norm of the residuals; the default residual test passes when the norm falls below it, and a value <= 0 disables the test.
SOPT_MACRO(tolerance, l1, L1)
proximal::L1< Scalar > & l1_proximal()
L1 proximal used during calculation.
typename PADMM::t_LinearTransform t_LinearTransform
SOPT_MACRO(fista_mixing, l1, L1)
ImagingProximalADMM< Scalar > &::type Psi(ARGS &&... args)
typename PADMM::Scalar Scalar
DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart) const
Makes it simple to chain different calls to PADMM.
DiagnosticAndResult operator()(std::tuple< t_Vector, t_Vector > const &guess) const
Calls Proximal ADMM.
DiagnosticAndResult operator()(std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Proximal ADMM.
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector, t_Vector > const &guess) const
Calls Proximal ADMM.
SOPT_MACRO(objective_convergence, t_IsConverged)
Convergence of the objective function.
ImagingProximalADMM< Scalar > & is_converged(std::function< bool(t_Vector const &x)> const &func)
Convergence function that takes only the output as argument.
ImagingProximalADMM< Scalar > & objective_convergence(Real const &tolerance)
Helper function to set up the default objective convergence function.
typename PADMM::Real Real
t_Vector const & target() const
Vector of target measurements.
typename PADMM::t_Vector t_Vector
virtual ~ImagingProximalADMM()
typename PADMM::t_IsConverged t_IsConverged
SOPT_MACRO(regulariser_strength, Real)
γ parameter.
SOPT_MACRO(itermax, l1, L1)
SOPT_MACRO(tight_frame, bool)
Whether Ψ is a tight-frame or not.
SOPT_MACRO(weights, l1, L1)
proximal::L1< Scalar > * g_proximal()
DiagnosticAndResult operator()() const
Calls Proximal ADMM.
ImagingProximalADMM< Scalar > & residual_convergence(Real const &tolerance)
Helper function to set-up default residual convergence function.
ImagingProximalADMM< Scalar > & target(Eigen::MatrixBase< DERIVED > const &target)
Sets the vector of target measurements.
SOPT_MACRO(Phi, t_LinearTransform)
Measurement operator.
SOPT_MACRO(lagrange_update_scale, Real)
Lagrange update scale β
SOPT_MACRO(l1_proximal, proximal::L1< Scalar >)
The L1 proximal functioning as f.
SOPT_MACRO(positivity_constraint, l1, L1)
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Proximal ADMM.
typename PADMM::value_type value_type
SOPT_MACRO(residual_convergence, t_IsConverged)
Convergence of the residuals.
SOPT_MACRO(relative_variation, Real)
Convergence of the relative variation of the objective functions.
Diagnostic operator()(t_Vector &out) const
Calls Proximal ADMM.
SOPT_MACRO(real_constraint, l1, L1)
ImagingProximalADMM &::type Phi(ARGS &&... args)
SOPT_MACRO(itermax, t_uint)
Maximum number of iterations.
t_LinearTransform const & Psi() const
Analysis operator Ψ
typename PADMM::t_Proximal t_Proximal
proximal::WeightedL2Ball< Scalar > & l2ball_proximal()
Proximal of the L2 ball.
SOPT_MACRO(l2ball_proximal, proximal::WeightedL2Ball< Scalar >)
The weighted L2 proximal functioning as g.
Proximal Alternating Direction Method of Multipliers.
SCALAR value_type
Scalar type.
LinearTransform< t_Vector > t_LinearTransform
Type of the Ψ and Ψ^H operations, as well as Φ and Φ^H.
Vector< Scalar > t_Vector
Type of the underlying vectors.
value_type Scalar
Scalar type.
ProximalFunction< Scalar > t_Proximal
Type of the proximal functions.
std::tuple< t_Vector, t_Vector > initial_guess() const
Computes initial guess for x and the residual using the targets.
std::function< bool(const t_Vector &, const t_Vector &)> t_IsConverged
Type of the convergence function.
typename real_type< Scalar >::type Real
Real type.
auto tight_frame(T &&... args) const -> decltype(this->L1TightFrame< Scalar >::operator()(std::forward< T >(args)...))
Special case if Ψ is a tight frame.
LinearTransform< Vector< Scalar > > const & Psi() const
Linear transform applied to input prior to L1 norm.
#define SOPT_LOW_LOG(...)
Low priority message.
#define SOPT_HIGH_LOG(...)
High priority message.
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
size_t t_uint
Root of the type hierarchy for unsigned integers.
Vector< T > target(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
real_type< T >::type epsilon(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
LinearTransform< Vector< SCALAR > > linear_transform_identity()
Helper function to create a linear transform that's just the identity.
real_type< typename T0::Scalar >::type l1_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L1 norm.
real_type< typename T0::Scalar >::type l2_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L2 norm.
Holds result vector as well.
Values indicating how the algorithm ran.
proximal::L1< Scalar >::Diagnostic l1_diagnostic
Diagnostic from calling L1 proximal.
Diagnostic(t_uint niters, bool good, typename proximal::L1< Scalar >::Diagnostic const &l1diag, t_Vector &&residual)
Diagnostic(t_uint niters=0u, bool good=false, typename proximal::L1< Scalar >::Diagnostic const &l1diag=typename proximal::L1< Scalar >::Diagnostic())
Values indicating how the algorithm ran.
bool good
Whether convergence was achieved.
t_uint niters
Number of iterations.
t_Vector residual
the residual from the last iteration