1 #ifndef SOPT_L2_PRIMAL_DUAL_H
2 #define SOPT_L2_PRIMAL_DUAL_H
4 #include "sopt/config.h"
18 template <
typename SCALAR>
19 class ImagingPrimalDual {
21 using PD = PrimalDual<SCALAR>;
42 struct DiagnosticAndResult :
public Diagnostic {
50 template <
typename DERIVED>
59 l2ball_proximal_(1e0),
60 residual_tolerance_(1e-4),
61 relative_variation_(1e-4),
62 residual_convergence_(
nullptr),
63 objective_convergence_(
nullptr),
64 itermax_(std::numeric_limits<t_uint>::max()),
69 precondition_stepsize_(0.5),
70 precondition_weights_(t_Vector::Ones(
target.size())),
71 precondition_iters_(0),
75 Phi_(linear_transform_identity<Scalar>()),
76 Psi_(linear_transform_identity<Scalar>()),
77 random_measurement_updater_([]() {
return true; }),
78 random_wavelet_updater_([]() {
return true; }),
84 #define SOPT_MACRO(NAME, TYPE) \
85 TYPE const &NAME() const { return NAME##_; } \
86 ImagingPrimalDual<SCALAR> &NAME(TYPE const &(NAME)) { \
151 SOPT_MACRO(v_all_sum_all_comm, mpi::Communicator);
153 SOPT_MACRO(u_all_sum_all_comm, mpi::Communicator);
160 template <
typename DERIVED>
175 return operator()(out, std::get<0>(guess), std::get<1>(guess));
181 std::tuple<t_Vector const &, t_Vector const &>
const &guess)
const {
182 return operator()(out, std::get<0>(guess), std::get<1>(guess));
187 return operator()(std::tie(std::get<0>(guess), std::get<1>(guess)));
192 std::tuple<t_Vector const &, t_Vector const &>
const &guess)
const {
194 static_cast<Diagnostic &
>(result) =
operator()(result.x, guess);
201 static_cast<Diagnostic &
>(result) =
operator()(result.x,
208 static_cast<Diagnostic &
>(result) =
operator()(result.x, warmstart.x, warmstart.residual);
213 template <
typename... ARGS>
224 template <
typename... ARGS>
235 #define SOPT_MACRO(VAR, NAME, PROXIMAL) \
237 decltype(std::declval<proximal::PROXIMAL<Scalar> const>().VAR()) NAME##_proximal_##VAR() const { \
238 return NAME##_proximal().VAR(); \
241 ImagingPrimalDual<Scalar> &NAME##_proximal_##VAR( \
242 decltype(std::declval<proximal::PROXIMAL<Scalar> const>().VAR()) (VAR)) { \
243 NAME##_proximal().VAR(VAR); \
249 SOPT_MACRO(communicator, l2ball, WeightedL2Ball);
287 bool check_l2_weight_proximal(
const t_Proximal<Real> &no_weights,
293 no_weights(output, 1, x);
295 return output.isApprox(outputw);
299 template <
typename SCALAR>
302 SOPT_HIGH_LOG(
"Performing Primal Dual with L1 and L2 operators");
304 if (not check_l2_weight_proximal(l2_proximal(), l2_proximal_weighted()))
306 "l1 proximal and weighted l1 proximal appear to be different functions. Please make sure "
307 "both are the same function.");
308 auto const f_proximal = [
this](
t_Vector &out, Real gamma,
t_Vector const &x) {
309 if (this->l2_proximal_weights().size() > 1)
310 this->l2_proximal_weighted()(out, this->l2_proximal_weights() * gamma, x);
312 this->l2_proximal()(out, this->l2_proximal_weights()(0) * gamma, x);
314 auto const g_proximal = [
this](
t_Vector &out, Real gamma,
t_Vector const &x) {
315 this->l2ball_proximal()(out, gamma, x);
317 for (
t_int i = 0; i < this->precondition_iters(); i++)
318 this->l2ball_proximal()(
320 out - this->precondition_stepsize() *
321 (out.array() * this->precondition_weights().array() - x.array()).matrix());
323 if (this->precondition_iters() > 0) out = out.array() * this->precondition_weights().array();
325 ScalarRelativeVariation<Scalar> scalvar(relative_variation(), relative_variation(),
326 "Objective function");
327 auto const convergence = [
this, scalvar](
t_Vector const &x,
t_Vector const &residual)
mutable {
328 return this->is_converged(scalvar, x, residual);
330 const bool positive = positivity_constraint();
331 const bool real = real_constraint();
332 t_Constraint constraint = [real, positive](
t_Vector &out,
const t_Vector &x) {
333 if (real) out.real() = x.real();
335 if (not real and not positive) out = x;
337 auto const pd = PD(f_proximal, g_proximal,
target())
339 .constraint(constraint)
343 .update_scale(update_scale())
349 .random_measurement_updater(random_measurement_updater())
350 .random_wavelet_updater(random_wavelet_updater())
352 .v_all_sum_all_comm(v_all_sum_all_comm())
353 .u_all_sum_all_comm(u_all_sum_all_comm())
355 .is_converged(convergence);
357 static_cast<typename PD::Diagnostic &
>(result) = pd(out, std::tie(guess, res));
361 template <
typename SCALAR>
364 if (
static_cast<bool>(residual_convergence()))
return residual_convergence()(x, residual);
365 if (residual_tolerance() <= 0e0)
return true;
366 auto const residual_norm =
sopt::l2_norm(residual, l2ball_proximal_weights());
367 SOPT_LOW_LOG(
" - [Primal Dual] Residuals: {} <? {}", residual_norm, residual_tolerance());
368 return residual_norm < residual_tolerance();
371 template <
typename SCALAR>
375 if (
static_cast<bool>(objective_convergence()))
return objective_convergence()(x, residual);
376 if (scalvar.relative_tolerance() <= 0e0)
return true;
378 (l2_proximal_weights().size() > 1)
379 ?
sopt::l2_norm(l2_proximal_weights().array() * (Psi().adjoint() * x).array())
380 :
sopt::l2_norm(l2_proximal_weights()(0) * (Psi().adjoint() * x));
381 return scalvar(current);
384 template <
typename SCALAR>
387 auto const user =
static_cast<bool>(is_converged()) ==
false or is_converged()(x, residual);
388 auto const res = residual_convergence(x, residual);
389 auto const obj = objective_convergence(scalvar, x, residual);
392 return user and res and obj;
sopt::Vector< Scalar > t_Vector
SOPT_MACRO(Psi, t_LinearTransform)
Wavelet operator.
std::function< void(t_Vector &, const T &, const t_Vector &)> t_Proximal
SOPT_MACRO(residual_tolerance, Real)
Tolerance on the weighted L2 norm of the residuals.
ImagingPrimalDual &::type Psi(ARGS &&... args)
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector, t_Vector > const &guess) const
Calls Primal Dual.
DiagnosticAndResult operator()(std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Primal Dual.
SOPT_MACRO(l1_proximal, t_Proximal< Real >)
The l1 prox functioning as f.
SOPT_MACRO(real_constraint, bool)
Apply real constraint.
SOPT_MACRO(gamma, Real)
gamma parameter
SOPT_MACRO(Phi, t_LinearTransform)
Measurement operator.
SOPT_MACRO(l2_proximal_weighted, t_Proximal< Vector< Real >>)
The l2 prox with weights functioning as f.
SOPT_MACRO(itermax, t_uint)
Maximum number of iterations.
SOPT_MACRO(residual_convergence, t_IsConverged)
Convergence of the residuals.
typename PD::t_Vector t_Vector
SOPT_MACRO(l2ball_proximal, proximal::WeightedL2Ball< Scalar >)
The weighted L2 proximal functioning as g.
typename PD::Scalar Scalar
virtual ~ImagingPrimalDual()
SOPT_MACRO(random_measurement_updater, t_Random_Updater)
Function that determines whether to update the measurements.
SOPT_MACRO(epsilon, l2ball, WeightedL2Ball)
DiagnosticAndResult operator()(std::tuple< t_Vector, t_Vector > const &guess) const
Calls Primal Dual.
SOPT_MACRO(weights, l2ball, WeightedL2Ball)
SOPT_MACRO(update_scale, Real)
update parameter
SOPT_MACRO(l2_proximal_weights, Vector< Real >)
Weights of the l2 prox functioning as f (a single weight is used when the vector has one element).
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Primal Dual.
ImagingPrimalDual< Scalar > & objective_convergence(Real const &tolerance)
Helper function to set-up default objective convergence function.
typename PD::t_LinearTransform t_LinearTransform
SOPT_MACRO(relative_variation, Real)
Tolerance on the relative variation of the objective function.
typename PD::t_Constraint t_Constraint
DiagnosticAndResult operator()() const
Calls Primal Dual.
SOPT_MACRO(rho, Real)
rho parameter
t_Vector const & target() const
Vector of target measurements.
SOPT_MACRO(precondition_stepsize, Real)
preconditioning step size parameter
SOPT_MACRO(positivity_constraint, bool)
Apply positivity constraint.
SOPT_MACRO(sigma, Real)
sigma parameter
typename PD::value_type value_type
ImagingPrimalDual< Scalar > & is_converged(std::function< bool(t_Vector const &x)> const &func)
Convergence function that takes only the output as argument.
SOPT_MACRO(l2_proximal, t_Proximal< Real >)
The l2 prox functioning as f.
SOPT_MACRO(is_converged, t_IsConverged)
A function verifying convergence.
ImagingPrimalDual(Eigen::MatrixBase< DERIVED > const &target)
SOPT_MACRO(precondition_iters, t_uint)
precondition iterations parameter
ImagingPrimalDual< Scalar > & residual_convergence(Real const &tolerance)
Helper function to set-up default residual convergence function.
proximal::WeightedL2Ball< Scalar > & l2ball_proximal()
Proximal of the L2 ball.
ImagingPrimalDual< Scalar > & target(Eigen::MatrixBase< DERIVED > const &target)
Sets the vector of target measurements.
Diagnostic operator()(t_Vector &out) const
Calls Primal Dual.
SOPT_MACRO(precondition_weights, t_Vector)
precondition weights parameter
SOPT_MACRO(xi, Real)
xi parameter
SOPT_MACRO(objective_convergence, t_IsConverged)
Convergence of the objective function.
SOPT_MACRO(tau, Real)
tau parameter
ImagingPrimalDual &::type Phi(ARGS &&... args)
DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart) const
Makes it simple to chain different calls to PD.
typename PD::t_Random_Updater t_Random_Updater
SOPT_MACRO(random_wavelet_updater, t_Random_Updater)
Function that determines whether to update the wavelets.
typename PD::t_IsConverged t_IsConverged
SCALAR value_type
Scalar type.
std::function< bool()> t_Random_Updater
Type of random update function.
Vector< Scalar > t_Vector
Type of the underlying vectors.
LinearTransform< t_Vector > t_LinearTransform
Type of the Ψ and Ψ^H operations, as well as Φ and Φ^H.
value_type Scalar
Scalar type.
std::function< void(t_Vector &, const t_Vector &)> t_Constraint
Type of the constraint function.
std::function< bool(const t_Vector &, const t_Vector &)> t_IsConverged
Type of the convergence function.
std::tuple< t_Vector, t_Vector > initial_guess() const
Computes initial guess for x and the residual using the targets.
typename real_type< Scalar >::type Real
Real type.
#define SOPT_LOW_LOG(...)
Low priority message.
#define SOPT_HIGH_LOG(...)
High priority message.
void l2_norm(Eigen::DenseBase< T0 > &out, typename real_type< typename T0::Scalar >::type gamma, Eigen::DenseBase< T1 > const &x)
Proximal of the l2 norm (note this is different from the l2 ball indicator function)
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
Eigen::CwiseUnaryOp< const details::ProjectPositiveQuadrant< typename T::Scalar >, const T > positive_quadrant(Eigen::DenseBase< T > const &input)
Expression to create projection onto positive quadrant.
int t_int
Root of the type hierarchy for signed integers.
size_t t_uint
Root of the type hierarchy for unsigned integers.
Vector< T > target(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
real_type< T >::type epsilon(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
real_type< T >::type sigma(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
real_type< typename T0::Scalar >::type l2_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L2 norm.
Holds result vector as well.
Values indicating how the algorithm ran.
Diagnostic(t_uint niters, bool good, t_Vector &&residual)
Diagnostic(t_uint niters=0u, bool good=false)
Values indicating how the algorithm ran.
t_Vector residual
the residual from the last iteration
t_uint niters
Number of iterations.
bool good
Whether convergence was achieved.