4 #include "sopt/config.h"
22 template <
typename SCALAR>
54 : itermax_(std::numeric_limits<
t_uint>::max()),
56 conjugate_gradient_(std::numeric_limits<
t_uint>::max(), 1e-6),
57 is_converged_([](
t_Vector const &) {
return false; }) {}
62 #define SOPT_MACRO(NAME, TYPE) \
63 TYPE const &NAME() const { return NAME##_; } \
64 SDMM<SCALAR> &NAME(TYPE const &(NAME)) { \
84 conjugate_gradient_.itermax(itermax);
85 conjugate_gradient_.tolerance(tolerance);
89 template <
typename PROXIMAL,
typename T>
96 template <
typename PROXIMAL>
98 return append(proximal, linear_transform_identity<Scalar>());
101 template <
typename PROXIMAL,
typename L,
typename LADJOINT>
103 return append(proximal, linear_transform<t_Vector>(l, ladjoint));
106 template <
typename PROXIMAL,
typename L,
typename LADJOINT>
108 return append(proximal, linear_transform<t_Vector>(l, ladjoint, sizes));
111 template <
typename PROXIMAL,
typename L,
typename LADJOINT>
113 std::array<t_int, 3> isizes) {
114 return append(proximal, linear_transform<t_Vector>(l, dsizes, ladjoint, isizes));
124 static_cast<Diagnostic &
>(result) =
operator()(result.
x, input);
130 static_cast<Diagnostic &
>(result) =
operator()(result.
x, warmstart.
x);
// Read-only access to the linear transforms, one associated with each objective function.
135 std::vector<t_LinearTransform>
const &
transforms()
const {
return transforms_; }
// Mutable access to the same container of linear transforms (one per objective function).
137 std::vector<t_LinearTransform> &
transforms() {
return transforms_; }
// Read-only access to the proximal functions, one for each objective function.
// Its size is expected to match transforms(); sanity_check reports a mismatch as an error.
144 std::vector<t_Proximal>
const &
proximals()
const {
return proximals_; }
// Mutable access to the same container of proximal functions (one per objective function).
146 std::vector<t_Proximal> &
proximals() {
return proximals_; }
152 template <
typename T0>
154 t_uint i, Eigen::MatrixBase<T0>
const &x)
const {
165 template <
typename T0,
typename... T>
167 -> decltype(this->
conjugate_gradient()(std::forward<T0>(t0), std::forward<T>(args)...)) {
176 std::vector<t_LinearTransform> transforms_;
178 std::vector<t_Proximal> proximals_;
181 using t_Vectors = std::vector<t_Vector>;
184 t_Vectors
const &z)
const;
186 virtual void update_directions(t_Vectors &y, t_Vectors &z,
t_Vector const &x)
const;
189 virtual void initialization(t_Vectors &y, t_Vectors &z,
t_Vector const &x)
const;
192 virtual void sanity_check(
t_Vector const &input)
const;
195 template <
typename SCALAR>
199 bool convergence =
false;
202 auto const has_finished = [&convergence, &niters,
this](
t_Vector const &out) {
203 convergence = is_converged(out);
204 return niters >= itermax() or convergence;
209 t_Vectors y(transforms().size());
210 t_Vectors z(transforms().size());
213 initialization(y, z, input);
214 auto cg_diagnostic = solve_for_xn(out, y, z);
216 while (not has_finished(out)) {
219 update_directions(y, z, out);
221 std::accumulate(z.begin(), z.end(),
Scalar(0e0),
224 cg_diagnostic = solve_for_xn(out, y, z);
225 SOPT_LOW_LOG(
" - CG Residual = {} in {}/{} iterations", cg_diagnostic.residual,
226 cg_diagnostic.niters, conjugate_gradient().itermax());
230 return {niters, convergence, cg_diagnostic};
233 template <
typename SCALAR>
235 t_Vectors
const &z)
const {
236 assert(z.size() == transforms().size());
237 assert(y.size() == transforms().size());
242 for (
t_uint i(0); i < transforms().size(); ++i)
b += transforms(i).adjoint() * (y[i] - z[i]);
243 if (
b.stableNorm() < 1e-12) {
250 out = out.Zero(input.size());
251 for (
auto const &transform : this->transforms())
252 out += transform.adjoint() *
static_cast<t_Vector>(transform * input);
256 auto const diagnostic = this->conjugate_gradient(out, A,
b);
257 if (not diagnostic.good) {
258 SOPT_ERROR(
"CG error - iterations: {}/{} - residuals {}\n", diagnostic.niters,
259 conjugate_gradient().itermax(), diagnostic.residual);
260 SOPT_THROW(
"Conjugate gradient failed to converge");
266 template <
typename SCALAR>
267 void SDMM<SCALAR>::update_directions(t_Vectors &y, t_Vectors &z,
t_Vector const &x)
const {
269 for (
t_uint i(0); i < transforms().size(); ++i) {
270 z[i] += transforms(i) * x;
271 y[i] = proximals(i, z[i]);
276 template <
typename SCALAR>
277 void SDMM<SCALAR>::initialization(t_Vectors &y, t_Vectors &z,
t_Vector const &x)
const {
279 for (
t_uint i(0); i < transforms().size(); i++) {
280 y[i] = transforms(i) * x;
281 z[i].resize(y[i].size());
283 assert(z[i].size() == y[i].size());
284 SOPT_TRACE(
" - transform {}: {}", i, y[i].transpose());
288 template <
typename SCALAR>
289 void SDMM<SCALAR>::sanity_check(
t_Vector const &x)
const {
291 if (proximals().size() != transforms().size()) {
292 SOPT_ERROR(
"Internal error: number of proximals and transforms do not match");
295 if (x.size() == 0)
SOPT_WARN(
"Input vector has zero size");
296 if (size() == 0)
SOPT_WARN(
"No operators - SDMM is empty");
297 for (
t_uint i(0); i < size(); ++i) {
298 auto const xdual = t_Vector::Zero((transforms(i) * x).size());
299 auto const r = (transforms(i).adjoint() * xdual).size();
301 SOPT_ERROR(
"Output size of transform {} and input do not match: {} vs {}", i, r, x.size());
305 if (doexit)
SOPT_THROW(
"Input to SDMM is inconsistent");
sopt::Vector< Scalar > t_Vector
Solves $Ax = b$ for $x$, given $A$ and $b$.
Simultaneous-direction method of multipliers.
SDMM< SCALAR > & append(PROXIMAL proximal, L l, std::array< t_int, 3 > dsizes, LADJOINT ladjoint, std::array< t_int, 3 > isizes)
Appends a proximal with the linear transform as pair of functions.
auto conjugate_gradient(T0 &&t0, T &&... args) const -> decltype(this->conjugate_gradient()(std::forward< T0 >(t0), std::forward< T >(args)...))
Forwards to internal conjugate gradient object.
std::vector< t_LinearTransform > const & transforms() const
Linear transforms associated with each objective function.
bool is_converged(t_Vector const &x) const
Forwards to convergence function parameter.
SDMM< SCALAR > & append(PROXIMAL proximal, L l, LADJOINT ladjoint)
Appends a proximal with the linear transform as pair of functions.
proximal::ProximalExpression< t_Proximal const &, T0 > proximals(t_uint i, Eigen::MatrixBase< T0 > const &x) const
Lazy call to specific proximal function.
SDMM< SCALAR > & append(PROXIMAL proximal, L l, LADJOINT ladjoint, std::array< t_int, 3 > sizes)
Appends a proximal with the linear transform as pair of functions.
SDMM< SCALAR > & append(PROXIMAL proximal)
Appends a proximal with identity as the linear transform.
std::vector< t_Proximal > & proximals()
Proximal of each objective function.
Diagnostic operator()(t_Vector &out, t_Vector const &input) const
Implements SDMM.
t_Proximal & proximals(t_uint i)
Proximal associated with a given objective function.
Vector< SCALAR > t_Vector
Type of the underlying vectors.
value_type Scalar
Scalar type.
SDMM< SCALAR > & conjugate_gradient(t_uint itermax, t_real tolerance)
Helps setup conjugate gradient.
t_LinearTransform & transforms(t_uint i)
Linear transform associated with a given objective function.
typename real_type< Scalar >::type Real
Real type.
t_Proximal const & proximals(t_uint i) const
Proximal associated with a given objective function.
std::vector< t_Proximal > const & proximals() const
Proximal of each objective function.
t_LinearTransform const & transforms(t_uint i) const
Linear transform associated with a given objective function.
SOPT_MACRO(gamma, Real)
Gamma.
ProximalFunction< SCALAR > t_Proximal
Type of the proximal functions.
SDMM< SCALAR > & append(PROXIMAL proximal, T args)
Appends a proximal and linear transform.
SCALAR value_type
Scalar type.
std::vector< t_LinearTransform > & transforms()
Linear transforms associated with each objective function.
ConvergenceFunction< SCALAR > t_IsConverged
Type of the convergence function.
DiagnosticAndResult operator()(t_Vector const &input) const
t_uint size() const
Number of terms.
DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart) const
Makes it simple to chain different calls to SDMM.
SOPT_MACRO(itermax, t_uint)
Maximum number of iterations.
SOPT_MACRO(conjugate_gradient, ConjugateGradient)
Conjugate gradient.
SOPT_MACRO(is_converged, t_IsConverged)
A function verifying convergence.
Computes inner-most element type.
Expression referencing a lazy proximal function call.
#define SOPT_LOW_LOG(...)
Low priority message.
#define SOPT_HIGH_LOG(...)
High priority message.
#define SOPT_ERROR(...)
\macro Something is definitely wrong, algorithm exits
#define SOPT_WARN(...)
\macro Something might be going wrong
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
double t_real
Root of the type hierarchy for real numbers.
size_t t_uint
Root of the type hierarchy for unsigned integers.
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
std::function< void(Vector< SCALAR > &output, typename real_type< SCALAR >::type const weight, Vector< SCALAR > const &input)> ProximalFunction
Typical function signature for calls to proximal.
std::function< bool(Vector< SCALAR > const &)> ConvergenceFunction
Typical function signature for convergence.
Values indicating how the algorithm ran.
Vector< SCALAR > x
Vector which minimizes the sum of functions.
Values indicating how the algorithm ran.
t_uint niters
Number of iterations.
ConjugateGradient::Diagnostic cg_diagnostic
Conjugate gradient result.
bool good
Whether convergence was achieved.