1 #ifndef SOPT_PRIMAL_DUAL_H
2 #define SOPT_PRIMAL_DUAL_H
4 #include "sopt/config.h"
23 template <
typename SCALAR>
68 template <
typename DERIVED>
70 Eigen::MatrixBase<DERIVED>
const &
target)
71 : itermax_(std::numeric_limits<
t_uint>::max()),
74 regulariser_strength_(0.5),
80 Phi_(linear_transform_identity<Scalar>()),
81 Psi_(linear_transform_identity<Scalar>()),
84 random_measurement_updater_([]() {
return true; }),
85 random_wavelet_updater_([]() {
return true; }),
87 v_all_sum_all_comm_(mpi::Communicator()),
88 u_all_sum_all_comm_(mpi::Communicator()),
96 #define SOPT_MACRO(NAME, TYPE) \
97 TYPE const &NAME() const { return NAME##_; } \
98 PrimalDual<SCALAR> &NAME(TYPE const &(NAME)) { \
141 SOPT_MACRO(v_all_sum_all_comm, mpi::Communicator);
143 SOPT_MACRO(u_all_sum_all_comm, mpi::Communicator);
163 template <
typename DERIVED>
181 return operator()(out, std::get<0>(guess), std::get<1>(guess));
187 std::tuple<t_Vector const &, t_Vector const &>
const &guess)
const {
188 return operator()(out, std::get<0>(guess), std::get<1>(guess));
193 return operator()(std::tie(std::get<0>(guess), std::get<1>(guess)));
198 std::tuple<t_Vector const &, t_Vector const &>
const &guess)
const {
200 static_cast<Diagnostic &
>(result) =
operator()(result.
x, guess);
217 template <
typename... ARGS>
218 typename std::enable_if<
sizeof...(ARGS) >= 1,
PrimalDual &>::type
Phi(ARGS &&... args) {
223 template <
typename... ARGS>
224 typename std::enable_if<
sizeof...(ARGS) >= 1,
PrimalDual &>::type
Psi(ARGS &&... args) {
245 std::tuple<t_Vector, t_Vector> guess;
247 std::get<1>(guess) =
target;
254 bool &random_measurement_update,
bool &random_wavelet_update,
258 void sanity_check(
t_Vector const &x_guess,
t_Vector const &res_guess)
const {
259 if ((
Phi().adjoint() *
target()).size() != x_guess.size())
260 SOPT_THROW(
"target, adjoint measurement operator and input vector have inconsistent sizes");
261 if (
target().size() != res_guess.size())
262 SOPT_THROW(
"target and residual vector have inconsistent sizes");
263 if ((
Phi() * x_guess).size() !=
target().size())
264 SOPT_THROW(
"target, measurement operator and input vector have inconsistent sizes");
266 SOPT_WARN(
"No convergence function was provided: algorithm will run for {} steps", itermax());
279 template <
typename SCALAR>
283 bool &random_measurement_update,
284 bool &random_wavelet_update,
t_Vector &u_update,
287 if (random_measurement_update) {
288 g_proximal(v_hold, rho(), v + residual);
289 v_hold = v + residual - v_hold;
290 v = v + update_scale() * (v_hold - v);
291 v_update =
static_cast<t_Vector>(Phi().adjoint() * v);
294 if (random_wavelet_update) {
295 q =
static_cast<t_Vector>(Psi().adjoint() * out_hold) *
sigma();
296 f_proximal(u_hold, regulariser_strength(), (u + q));
297 u_hold = u + q - u_hold;
298 u = u + update_scale() * (u_hold - u);
299 u_update =
static_cast<t_Vector>(Psi() * u);
304 if (v_all_sum_all_comm().size() > 0 and u_all_sum_all_comm().size() > 0)
307 r - tau() * (u_all_sum_all_comm().all_sum_all(
static_cast<const t_Vector>(u_update)) +
308 v_all_sum_all_comm().all_sum_all(
static_cast<const t_Vector>(v_update))));
311 constraint()(out_hold, r - tau() * (u_update + v_update));
312 out = r + update_scale() * (out_hold - r);
313 out_hold = 2 * out_hold - r;
314 random_measurement_update = random_measurement_updater_();
315 random_wavelet_update = random_wavelet_updater_();
317 if (random_measurement_update)
318 residual =
static_cast<t_Vector>(Phi() * out_hold) * xi() -
target();
321 template <
typename SCALAR>
325 sanity_check(x_guess, res_guess);
326 bool random_measurement_update = random_measurement_updater_();
327 bool random_wavelet_update = random_wavelet_updater_();
341 bool converged =
false;
342 for (; (not converged) && (niters < itermax()); ++niters) {
343 SOPT_LOW_LOG(
" - [Primal Dual] Iteration {}/{}", niters, itermax());
344 iteration_step(out, out_hold, u, u_hold, v, v_hold, residual, q, r, random_measurement_update,
345 random_wavelet_update, u_update, v_update);
347 static_cast<t_Vector>(residual).array().abs().sum());
348 converged = is_converged(out, residual);
352 SOPT_MEDIUM_LOG(
" - [Primal Dual] converged in {} of {} iterations", niters, itermax());
353 }
else if (
static_cast<bool>(is_converged())) {
355 SOPT_ERROR(
" - [Primal Dual] did not converge within {} iterations", itermax());
357 return {niters, converged, std::move(residual)};
sopt::Vector< Scalar > t_Vector
ProximalFunction< Scalar > t_Proximal
Type of the proximal function.
PrimalDual< Scalar > & target(Eigen::MatrixBase< DERIVED > const &target)
Sets the vector of target measurements.
SCALAR value_type
Scalar type.
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector, t_Vector > const &guess) const
Calls Primal Dual.
SOPT_MACRO(f_proximal, t_Proximal)
First proximal.
t_Vector const & target() const
Vector of target measurements.
DiagnosticAndResult operator()() const
Calls Primal Dual.
SOPT_MACRO(random_wavelet_updater, t_Random_Updater)
Lambda that determines whether to update the wavelets.
std::function< bool()> t_Random_Updater
Type of random update function.
Vector< Scalar > t_Vector
Type of the underlying vectors.
Diagnostic operator()(t_Vector &out) const
Calls Primal Dual.
Diagnostic operator()(t_Vector &out, std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Primal Dual.
SOPT_MACRO(xi, Real)
xi parameter
bool is_converged(t_Vector const &x, t_Vector const &residual) const
Facilitates call to user-provided convergence function.
PrimalDual &::type Psi(ARGS &&... args)
SOPT_MACRO(Phi, t_LinearTransform)
Measurement operator.
SOPT_MACRO(is_converged, t_IsConverged)
A function verifying convergence.
value_type Scalar
Scalar type.
SOPT_MACRO(rho, Real)
rho parameter
std::function< void(t_Vector &, const t_Vector &)> t_Constraint
Type of the constraint function.
SOPT_MACRO(constraint, t_Constraint)
A function applying a simple constraint.
PrimalDual &::type Phi(ARGS &&... args)
std::function< bool(const t_Vector &, const t_Vector &)> t_IsConverged
Type of the convergence function.
SOPT_MACRO(regulariser_strength, Real)
γ parameter.
SOPT_MACRO(g_proximal, t_Proximal)
Second proximal.
void f_proximal(t_Vector &out, Real regulariser_strength, t_Vector const &x) const
Simplifies calling the proximal of f.
SOPT_MACRO(tau, Real)
tau parameter
PrimalDual< Scalar > & is_converged(std::function< bool(t_Vector const &x)> const &func)
Convergence function that takes only the output as argument.
void g_proximal(t_Vector &out, Real regulariser_strength, t_Vector const &x) const
Simplifies calling the proximal of g.
DiagnosticAndResult operator()(std::tuple< t_Vector const &, t_Vector const & > const &guess) const
Calls Primal Dual.
PrimalDual(t_Proximal const &f_proximal, t_Proximal const &g_proximal, Eigen::MatrixBase< DERIVED > const &target)
SOPT_MACRO(Psi, t_LinearTransform)
Wavelet operator.
SOPT_MACRO(itermax, t_uint)
Maximum number of iterations.
SOPT_MACRO(random_measurement_updater, t_Random_Updater)
Lambda that determines whether to update the measurements.
DiagnosticAndResult operator()(std::tuple< t_Vector, t_Vector > const &guess) const
Calls Primal Dual.
std::tuple< t_Vector, t_Vector > initial_guess() const
Computes initial guess for x and the residual using the targets.
DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart) const
Makes it simple to chain different calls to PD.
typename real_type< Scalar >::type Real
Real type.
static std::tuple< t_Vector, t_Vector > initial_guess(t_Vector const &target, t_LinearTransform const &phi)
Computes initial guess for x and the residual using the targets.
SOPT_MACRO(sigma, Real)
sigma parameter
SOPT_MACRO(update_scale, Real)
Update parameter.
Computes inner-most element type.
#define SOPT_MPI
Whether or not to include mpi.
#define SOPT_LOW_LOG(...)
Low priority message.
#define SOPT_HIGH_LOG(...)
High priority message.
#define SOPT_ERROR(...)
Something is definitely wrong; the algorithm exits.
#define SOPT_WARN(...)
Something might be going wrong.
#define SOPT_MEDIUM_LOG(...)
Medium priority message.
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
size_t t_uint
Root of the type hierarchy for unsigned integers.
Vector< T > target(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
std::function< void(Vector< SCALAR > &output, typename real_type< SCALAR >::type const weight, Vector< SCALAR > const &input)> ProximalFunction
Typical function signature for calls to proximal.
real_type< T >::type sigma(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Holds result vector as well.
Values indicating how the algorithm ran.
Diagnostic(t_uint niters, bool good, t_Vector &&residual)
t_Vector residual
the residual from the last iteration
t_uint niters
Number of iterations.
Diagnostic(t_uint niters=0u, bool good=false)
bool good
Whether convergence was achieved.