1 #ifndef SOPT_L1_PRIMAL_DUAL_H
2 #define SOPT_L1_PRIMAL_DUAL_H
4 #include "sopt/config.h"
19 template <
typename SCALAR>
51 template <
typename DERIVED>
54 proximal::l1_norm<t_Vector, t_Vector>(out, regulariser_strength, x);
57 proximal::l1_norm<t_Vector, t_Vector, Vector<Real>>(out, regulariser_strength, x);
60 l2ball_proximal_(1e0),
61 residual_tolerance_(1e-4),
62 relative_variation_(1e-4),
63 residual_convergence_(
nullptr),
64 objective_convergence_(
nullptr),
65 itermax_(std::numeric_limits<t_uint>::max()),
68 regulariser_strength_(0.5),
70 precondition_stepsize_(0.5),
71 precondition_weights_(t_Vector::Ones(
target.size())),
72 precondition_iters_(0),
76 Phi_(linear_transform_identity<Scalar>()),
77 Psi_(linear_transform_identity<Scalar>()),
78 random_measurement_updater_([]() {
return true; }),
79 random_wavelet_updater_([]() {
return true; }),
85 #define SOPT_MACRO(NAME, TYPE) \
86 TYPE const &NAME() const { return NAME##_; } \
87 ImagingPrimalDual<SCALAR> &NAME(TYPE const &(NAME)) { \
152 SOPT_MACRO(v_all_sum_all_comm, mpi::Communicator);
154 SOPT_MACRO(u_all_sum_all_comm, mpi::Communicator);
161 template <
typename DERIVED>
176 return operator()(out, std::get<0>(guess), std::get<1>(guess));
182 std::tuple<t_Vector const &, t_Vector const &>
const &guess)
const {
183 return operator()(out, std::get<0>(guess), std::get<1>(guess));
188 return operator()(std::tie(std::get<0>(guess), std::get<1>(guess)));
193 std::tuple<t_Vector const &, t_Vector const &>
const &guess)
const {
195 static_cast<Diagnostic &
>(result) =
operator()(result.
x, guess);
202 static_cast<Diagnostic &
>(result) =
operator()(result.
x,
214 template <
typename... ARGS>
225 template <
typename... ARGS>
236 #define SOPT_MACRO(VAR, NAME, PROXIMAL) \
238 decltype(std::declval<proximal::PROXIMAL<Scalar> const>().VAR()) NAME##_proximal_##VAR() const { \
239 return NAME##_proximal().VAR(); \
242 ImagingPrimalDual<Scalar> &NAME##_proximal_##VAR( \
243 decltype(std::declval<proximal::PROXIMAL<Scalar> const>().VAR()) (VAR)) { \
244 NAME##_proximal().VAR(VAR); \
250 SOPT_MACRO(communicator, l2ball, WeightedL2Ball);
288 bool check_l1_weight_proximal(
const t_Proximal<Real> &no_weights,
293 no_weights(output, 1, x);
295 return output.isApprox(outputw);
299 template <
typename SCALAR>
302 SOPT_HIGH_LOG(
"Performing Primal Dual with L1 and L2 operators");
304 if (not check_l1_weight_proximal(l1_proximal(), l1_proximal_weighted()))
306 "l1 proximal and weighted l1 proximal appear to be different functions. Please make sure "
307 "both are the same function.");
308 auto const f_proximal = [
this](
t_Vector &out, Real regulariser_strength,
t_Vector const &x) {
309 if (this->l1_proximal_weights().size() > 1)
310 this->l1_proximal_weighted()(out, this->l1_proximal_weights() * regulariser_strength, x);
312 this->l1_proximal()(out, this->l1_proximal_weights()(0) * regulariser_strength, x);
314 auto const g_proximal = [
this](
t_Vector &out, Real regulariser_strength,
t_Vector const &x) {
315 this->l2ball_proximal()(out, regulariser_strength, x);
317 for (
t_int i = 0; i < this->precondition_iters(); i++)
318 this->l2ball_proximal()(
319 out, regulariser_strength,
320 out - this->precondition_stepsize() *
321 (out.array() * this->precondition_weights().array() - x.array()).matrix());
323 if (this->precondition_iters() > 0) out = out.array() * this->precondition_weights().array();
325 ScalarRelativeVariation<Scalar> scalvar(relative_variation(), relative_variation(),
326 "Objective function");
327 auto const convergence = [
this, scalvar](
t_Vector const &x,
t_Vector const &residual)
mutable {
328 return this->is_converged(scalvar, x, residual);
330 const bool positive = positivity_constraint();
331 const bool real = real_constraint();
332 t_Constraint constraint = [real, positive](
t_Vector &out,
const t_Vector &x) {
333 if (real) out = x.real();
335 if (not real and not positive) out = x;
337 auto const pd = PD(f_proximal, g_proximal,
target())
339 .constraint(constraint)
342 .regulariser_strength(regulariser_strength())
343 .update_scale(update_scale())
346 .regulariser_strength(regulariser_strength())
349 .random_measurement_updater(random_measurement_updater())
350 .random_wavelet_updater(random_wavelet_updater())
352 .v_all_sum_all_comm(v_all_sum_all_comm())
353 .u_all_sum_all_comm(u_all_sum_all_comm())
355 .is_converged(convergence);
357 static_cast<typename PD::Diagnostic &
>(result) = pd(out, std::tie(guess, res));
361 template <
typename SCALAR>
364 if (
static_cast<bool>(residual_convergence()))
return residual_convergence()(x, residual);
365 if (residual_tolerance() <= 0e0)
return true;
366 auto const residual_norm =
sopt::l2_norm(residual, l2ball_proximal_weights());
367 SOPT_LOW_LOG(
" - [Primal Dual] Residuals: {} <? {}", residual_norm, residual_tolerance());
368 return residual_norm < residual_tolerance();
371 template <
typename SCALAR>
375 if (
static_cast<bool>(objective_convergence()))
return objective_convergence()(x, residual);
376 if (scalvar.relative_tolerance() <= 0e0)
return true;
379 return scalvar(current);
382 template <
typename SCALAR>
385 auto const user =
static_cast<bool>(is_converged()) ==
false or is_converged()(x, residual);
386 auto const res = residual_convergence(x, residual);
387 auto const obj = objective_convergence(scalvar, x, residual);
390 return user and res and obj;
sopt::Vector< Scalar > t_Vector
SOPT_MACRO(Psi, t_LinearTransform)
Wavelet operator.
std::function< void(t_Vector &, const T &, const t_Vector &)> t_Proximal
SOPT_MACRO(residual_tolerance, Real)
Tolerance on the weighted l2 norm of the residuals; a non-positive value disables this criterion.
ImagingPrimalDual &::type Psi(ARGS &&... args)
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector, t_Vector > const &guess) const
Calls Primal Dual.
DiagnosticAndResult operator()(std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Primal Dual.
SOPT_MACRO(l1_proximal, t_Proximal< Real >)
The l1 prox functioning as f.
SOPT_MACRO(real_constraint, bool)
Apply real constraint.
SOPT_MACRO(Phi, t_LinearTransform)
Measurement operator.
SOPT_MACRO(itermax, t_uint)
Maximum number of iterations.
SOPT_MACRO(residual_convergence, t_IsConverged)
Convergence of the residuals.
typename PD::t_Vector t_Vector
SOPT_MACRO(l2ball_proximal, proximal::WeightedL2Ball< Scalar >)
The weighted L2 proximal functioning as g.
typename PD::Scalar Scalar
virtual ~ImagingPrimalDual()
SOPT_MACRO(random_measurement_updater, t_Random_Updater)
Lambda that determines whether to update the measurements.
SOPT_MACRO(epsilon, l2ball, WeightedL2Ball)
DiagnosticAndResult operator()(std::tuple< t_Vector, t_Vector > const &guess) const
Calls Primal Dual.
SOPT_MACRO(weights, l2ball, WeightedL2Ball)
SOPT_MACRO(update_scale, Real)
update parameter
SOPT_MACRO(l1_proximal_weights, Vector< Real >)
Weights of the l1 proximal; with a single weight the unweighted l1 proximal is used, otherwise the weighted l1 proximal.
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Primal Dual.
ImagingPrimalDual< Scalar > & objective_convergence(Real const &tolerance)
Helper function to set up the default objective convergence function.
typename PD::t_LinearTransform t_LinearTransform
SOPT_MACRO(relative_variation, Real)
Convergence of the relative variation of the objective function.
typename PD::t_Constraint t_Constraint
DiagnosticAndResult operator()() const
Calls Primal Dual.
SOPT_MACRO(rho, Real)
rho parameter
t_Vector const & target() const
Vector of target measurements.
SOPT_MACRO(precondition_stepsize, Real)
Preconditioning step size parameter.
SOPT_MACRO(positivity_constraint, bool)
Apply positivity constraint.
SOPT_MACRO(sigma, Real)
sigma parameter
typename PD::value_type value_type
ImagingPrimalDual< Scalar > & is_converged(std::function< bool(t_Vector const &x)> const &func)
Convergence function that takes only the output as argument.
SOPT_MACRO(l1_proximal_weighted, t_Proximal< Vector< Real >>)
The l1 prox with weights functioning as f.
SOPT_MACRO(is_converged, t_IsConverged)
A function verifying convergence.
ImagingPrimalDual(Eigen::MatrixBase< DERIVED > const &target)
SOPT_MACRO(precondition_iters, t_uint)
Number of preconditioning iterations.
ImagingPrimalDual< Scalar > & residual_convergence(Real const &tolerance)
Helper function to set up the default residual convergence function.
proximal::WeightedL2Ball< Scalar > & l2ball_proximal()
Proximal of the L2 ball.
ImagingPrimalDual< Scalar > & target(Eigen::MatrixBase< DERIVED > const &target)
Sets the vector of target measurements.
Diagnostic operator()(t_Vector &out) const
Calls Primal Dual.
SOPT_MACRO(precondition_weights, t_Vector)
precondition weights parameter
SOPT_MACRO(xi, Real)
xi parameter
SOPT_MACRO(objective_convergence, t_IsConverged)
Convergence of the objective function.
SOPT_MACRO(tau, Real)
tau parameter
ImagingPrimalDual &::type Phi(ARGS &&... args)
DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart) const
Makes it simple to chain different calls to PD.
typename PD::t_Random_Updater t_Random_Updater
SOPT_MACRO(random_wavelet_updater, t_Random_Updater)
Lambda that determines whether to update the wavelets.
SOPT_MACRO(regulariser_strength, Real)
regulariser_strength parameter
typename PD::t_IsConverged t_IsConverged
SCALAR value_type
Scalar type.
std::function< bool()> t_Random_Updater
Type of random update function.
Vector< Scalar > t_Vector
Type of the underlying vectors.
LinearTransform< t_Vector > t_LinearTransform
Type of the Ψ and Ψ^H operations, as well as Φ and Φ^H.
value_type Scalar
Scalar type.
std::function< void(t_Vector &, const t_Vector &)> t_Constraint
Type of the constraint function.
std::function< bool(const t_Vector &, const t_Vector &)> t_IsConverged
Type of the convergence function.
std::tuple< t_Vector, t_Vector > initial_guess() const
Computes initial guess for x and the residual using the targets.
typename real_type< Scalar >::type Real
Real type.
#define SOPT_LOW_LOG(...)
Low priority message.
#define SOPT_HIGH_LOG(...)
High priority message.
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
Eigen::CwiseUnaryOp< const details::ProjectPositiveQuadrant< typename T::Scalar >, const T > positive_quadrant(Eigen::DenseBase< T > const &input)
Expression to create projection onto positive quadrant.
int t_int
Root of the type hierarchy for signed integers.
size_t t_uint
Root of the type hierarchy for unsigned integers.
Vector< T > target(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
real_type< T >::type epsilon(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
real_type< T >::type sigma(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
real_type< typename T0::Scalar >::type l1_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L1 norm.
real_type< typename T0::Scalar >::type l2_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L2 norm.
Holds result vector as well.
Values indicating how the algorithm ran.
Diagnostic(t_uint niters, bool good, t_Vector &&residual)
Diagnostic(t_uint niters=0u, bool good=false)
Values indicating how the algorithm ran.
t_Vector residual
the residual from the last iteration
t_uint niters
Number of iterations.
bool good
Whether convergence was achieved.