1 #ifndef SOPT_TV_PRIMAL_DUAL_H
2 #define SOPT_TV_PRIMAL_DUAL_H
4 #include "sopt/config.h"
18 template <
typename SCALAR>
50 template <
typename DERIVED>
53 proximal::tv_norm<t_Vector, t_Vector>(out, regulariser_strength, x);
56 proximal::tv_norm<t_Vector, t_Vector, Vector<Real>>(out, regulariser_strength, x);
59 l2ball_proximal_(1e0),
60 residual_tolerance_(1e-4),
61 relative_variation_(1e-4),
62 residual_convergence_(
nullptr),
63 objective_convergence_(
nullptr),
64 itermax_(std::numeric_limits<t_uint>::max()),
67 regulariser_strength_(0.5),
69 precondition_stepsize_(0.5),
70 precondition_weights_(t_Vector::Ones(
target.size())),
71 precondition_iters_(0),
75 Phi_(linear_transform_identity<Scalar>()),
76 Psi_(linear_transform_identity<Scalar>()),
77 random_measurement_updater_([]() {
return true; }),
78 random_wavelet_updater_([]() {
return true; }),
84 #define SOPT_MACRO(NAME, TYPE) \
85 TYPE const &NAME() const { return NAME##_; } \
86 TVPrimalDual<SCALAR> &NAME(TYPE const &(NAME)) { \
151 SOPT_MACRO(v_all_sum_all_comm, mpi::Communicator);
153 SOPT_MACRO(u_all_sum_all_comm, mpi::Communicator);
160 template <
typename DERIVED>
175 return operator()(out, std::get<0>(guess), std::get<1>(guess));
181 std::tuple<t_Vector const &, t_Vector const &>
const &guess)
const {
182 return operator()(out, std::get<0>(guess), std::get<1>(guess));
187 return operator()(std::tie(std::get<0>(guess), std::get<1>(guess)));
192 std::tuple<t_Vector const &, t_Vector const &>
const &guess)
const {
194 static_cast<Diagnostic &
>(result) =
operator()(result.
x, guess);
201 static_cast<Diagnostic &
>(result) =
operator()(result.
x,
213 template <
typename... ARGS>
214 typename std::enable_if<
sizeof...(ARGS) >= 1,
TVPrimalDual &>::type
Phi(ARGS &&... args) {
224 template <
typename... ARGS>
225 typename std::enable_if<
sizeof...(ARGS) >= 1,
TVPrimalDual &>::type
Psi(ARGS &&... args) {
235 #define SOPT_MACRO(VAR, NAME, PROXIMAL) \
237 decltype(std::declval<proximal::PROXIMAL<Scalar> const>().VAR()) NAME##_proximal_##VAR() const { \
238 return NAME##_proximal().VAR(); \
241 TVPrimalDual<Scalar> &NAME##_proximal_##VAR( \
242 decltype(std::declval<proximal::PROXIMAL<Scalar> const>().VAR()) (VAR)) { \
243 NAME##_proximal().VAR(VAR); \
249 SOPT_MACRO(communicator, l2ball, WeightedL2Ball);
287 bool check_tv_weight_proximal(
const t_Proximal<Real> &no_weights,
292 no_weights(output, 1, x);
294 return output.isApprox(outputw);
298 template <
typename SCALAR>
301 SOPT_HIGH_LOG(
"Performing Primal Dual with TV and L2 operators");
303 if (not check_tv_weight_proximal(tv_proximal(), tv_proximal_weighted()))
305 "tv proximal and weighted tv proximal appear to be different functions. Please make sure "
306 "both are the same function.");
307 auto const f_proximal = [
this](
t_Vector &out, Real regulariser_strength,
t_Vector const &x) {
308 if (this->tv_proximal_weights().size() > 1)
309 this->tv_proximal_weighted()(out, this->tv_proximal_weights() * regulariser_strength, x);
311 this->tv_proximal()(out, this->tv_proximal_weights()(0) * regulariser_strength, x);
313 auto const g_proximal = [
this](
t_Vector &out, Real regulariser_strength,
t_Vector const &x) {
314 this->l2ball_proximal()(out, regulariser_strength, x);
316 for (
t_int i = 0; i < this->precondition_iters(); i++)
317 this->l2ball_proximal()(
318 out, regulariser_strength,
319 out - this->precondition_stepsize() *
320 (out.array() * this->precondition_weights().array() - x.array()).matrix());
322 if (this->precondition_iters() > 0) out = out.array() * this->precondition_weights().array();
324 ScalarRelativeVariation<Scalar> scalvar(relative_variation(), relative_variation(),
325 "Objective function");
326 auto const convergence = [
this, scalvar](
t_Vector const &x,
t_Vector const &residual)
mutable {
327 return this->is_converged(scalvar, x, residual);
329 const bool positive = positivity_constraint();
330 const bool real = real_constraint();
331 t_Constraint constraint = [real, positive](
t_Vector &out,
const t_Vector &x) {
332 if (real) out.real() = x.real();
334 if (not real and not positive) out = x;
336 auto const pd = PD(f_proximal, g_proximal,
target())
338 .constraint(constraint)
341 .regulariser_strength(regulariser_strength())
342 .update_scale(update_scale())
345 .regulariser_strength(regulariser_strength())
348 .random_measurement_updater(random_measurement_updater())
349 .random_wavelet_updater(random_wavelet_updater())
351 .v_all_sum_all_comm(v_all_sum_all_comm())
352 .u_all_sum_all_comm(u_all_sum_all_comm())
354 .is_converged(convergence);
356 static_cast<typename PD::Diagnostic &
>(result) = pd(out, std::tie(guess, res));
360 template <
typename SCALAR>
362 if (
static_cast<bool>(residual_convergence()))
return residual_convergence()(x, residual);
363 if (residual_tolerance() <= 0e0)
return true;
364 auto const residual_norm =
sopt::l2_norm(residual, l2ball_proximal_weights());
365 SOPT_LOW_LOG(
" - [Primal Dual] Residuals: {} <? {}", residual_norm, residual_tolerance());
366 return residual_norm < residual_tolerance();
369 template <
typename SCALAR>
373 if (
static_cast<bool>(objective_convergence()))
return objective_convergence()(x, residual);
374 if (scalvar.relative_tolerance() <= 0e0)
return true;
377 return scalvar(current);
380 template <
typename SCALAR>
383 auto const user =
static_cast<bool>(is_converged()) ==
false or is_converged()(x, residual);
384 auto const res = residual_convergence(x, residual);
385 auto const obj = objective_convergence(scalvar, x, residual);
388 return user and res and obj;
sopt::Vector< Scalar > t_Vector
SCALAR value_type
Scalar type.
std::function< bool()> t_Random_Updater
Type of random update function.
Vector< Scalar > t_Vector
Type of the underlying vectors.
LinearTransform< t_Vector > t_LinearTransform
Type of the Ψ and Ψ^H operations, as well as Φ and Φ^H.
value_type Scalar
Scalar type.
std::function< void(t_Vector &, const t_Vector &)> t_Constraint
Type of the constraint function.
std::function< bool(const t_Vector &, const t_Vector &)> t_IsConverged
Type of the convergence function.
std::tuple< t_Vector, t_Vector > initial_guess() const
Computes initial guess for x and the residual using the targets.
typename real_type< Scalar >::type Real
Real type.
TVPrimalDual< Scalar > & target(Eigen::MatrixBase< DERIVED > const &target)
Sets the vector of target measurements.
SOPT_MACRO(l2ball_proximal, proximal::WeightedL2Ball< Scalar >)
The weighted L2 proximal functioning as g.
TVPrimalDual(Eigen::MatrixBase< DERIVED > const &target)
SOPT_MACRO(precondition_stepsize, Real)
Precondition step size parameter.
SOPT_MACRO(epsilon, l2ball, WeightedL2Ball)
typename PD::t_Random_Updater t_Random_Updater
SOPT_MACRO(relative_variation, Real)
Convergence of the relative variation of the objective functions.
SOPT_MACRO(objective_convergence, t_IsConverged)
Convergence of the objective function.
Diagnostic operator()(t_Vector &out) const
Calls Primal Dual.
SOPT_MACRO(random_measurement_updater, t_Random_Updater)
Lambda that determines whether to update the measurements.
SOPT_MACRO(update_scale, Real)
update parameter
SOPT_MACRO(itermax, t_uint)
Maximum number of iterations.
SOPT_MACRO(xi, Real)
xi parameter
typename PD::Scalar Scalar
TVPrimalDual &::type Phi(ARGS &&... args)
SOPT_MACRO(Phi, t_LinearTransform)
Measurement operator.
SOPT_MACRO(sigma, Real)
sigma parameter
SOPT_MACRO(weights, l2ball, WeightedL2Ball)
SOPT_MACRO(precondition_weights, t_Vector)
precondition weights parameter
typename PD::value_type value_type
SOPT_MACRO(precondition_iters, t_uint)
precondition iterations parameter
t_Vector const & target() const
Vector of target measurements.
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector, t_Vector > const &guess) const
Calls Primal Dual.
typename PD::t_IsConverged t_IsConverged
SOPT_MACRO(is_converged, t_IsConverged)
A function verifying convergence.
SOPT_MACRO(residual_tolerance, Real)
Tolerance on the residual norm used by the default residual convergence test.
TVPrimalDual< Scalar > & residual_convergence(Real const &tolerance)
Helper function to set-up default residual convergence function.
SOPT_MACRO(tv_proximal_weights, Vector< Real >)
The weights of the TV proximal; they scale the regulariser strength passed to it.
SOPT_MACRO(real_constraint, bool)
Apply real constraint.
SOPT_MACRO(tau, Real)
tau parameter
DiagnosticAndResult operator()(std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Primal Dual.
SOPT_MACRO(rho, Real)
rho parameter
typename PD::t_Vector t_Vector
SOPT_MACRO(residual_convergence, t_IsConverged)
Convergence of the residuals.
DiagnosticAndResult operator()() const
Calls Primal Dual.
SOPT_MACRO(tv_proximal_weighted, t_Proximal< Vector< Real >>)
The tv prox with weights functioning as f.
typename PD::t_LinearTransform t_LinearTransform
SOPT_MACRO(positivity_constraint, bool)
Apply positivity constraint.
std::function< void(t_Vector &, const T &, const t_Vector &)> t_Proximal
TVPrimalDual< Scalar > & is_converged(std::function< bool(t_Vector const &x)> const &func)
Convergence function that takes only the output as argument.
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Primal Dual.
SOPT_MACRO(Psi, t_LinearTransform)
Wavelet operator.
SOPT_MACRO(tv_proximal, t_Proximal< Real >)
The tv prox functioning as f.
typename PD::t_Constraint t_Constraint
DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart) const
Makes it simple to chain different calls to PD.
SOPT_MACRO(regulariser_strength, Real)
regulariser_strength parameter
proximal::WeightedL2Ball< Scalar > & l2ball_proximal()
Proximal of the L2 ball.
DiagnosticAndResult operator()(std::tuple< t_Vector, t_Vector > const &guess) const
Calls Primal Dual.
TVPrimalDual &::type Psi(ARGS &&... args)
SOPT_MACRO(random_wavelet_updater, t_Random_Updater)
Lambda that determines whether to update the wavelets.
TVPrimalDual< Scalar > & objective_convergence(Real const &tolerance)
Helper function to set up the default objective convergence function.
#define SOPT_LOW_LOG(...)
Low priority message.
#define SOPT_HIGH_LOG(...)
High priority message.
real_type< typename T0::Scalar >::type tv_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted TV norm.
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
Eigen::CwiseUnaryOp< const details::ProjectPositiveQuadrant< typename T::Scalar >, const T > positive_quadrant(Eigen::DenseBase< T > const &input)
Expression to create projection onto positive quadrant.
int t_int
Root of the type hierarchy for signed integers.
size_t t_uint
Root of the type hierarchy for unsigned integers.
Vector< T > target(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
real_type< T >::type epsilon(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
real_type< T >::type sigma(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
real_type< typename T0::Scalar >::type l2_norm(Eigen::ArrayBase< T0 > const &input, Eigen::ArrayBase< T1 > const &weights)
Computes weighted L2 norm.
Values indicating how the algorithm ran.
t_Vector residual
the residual from the last iteration
t_uint niters
Number of iterations.
bool good
Whether convergence was achieved.
Holds result vector as well.
Values indicating how the algorithm ran.
Diagnostic(t_uint niters, bool good, t_Vector &&residual)
Diagnostic(t_uint niters=0u, bool good=false)