1 #include <catch2/catch_all.hpp>
// Problem size shared by the scenario below: the identity matrix is N x N and
// every random vector (targets and input) has N entries.
14 auto constexpr
N = 30;
15 SCENARIO(
"SDMM with warm start",
"[sdmm][integration]") {
18 GIVEN(
"An SDMM instance with its input") {
19 t_Matrix const Id = t_Matrix::Identity(
N,
N).eval();
20 t_Vector const target0 = t_Vector::Random(
N);
21 t_Vector target1 = t_Vector::Random(
N) * 4;
23 auto convergence = [&target1, &target0](
t_Vector const &x) ->
bool {
24 t_Vector const segment = (target1 - target0).normalized();
25 t_real const alpha = (x - target0).transpose() * segment;
26 return alpha >= 0e0 and (target1 - target0).transpose() * segment >= alpha and
27 (x - target0 - alpha * segment).stableNorm() < 1e-8;
34 .conjugate_gradient(std::numeric_limits<t_uint>::max(), 1e-12)
37 t_Vector const input = t_Vector::Random(
N);
39 WHEN(
"the algorithms runs") {
40 auto const full = sdmm(input);
41 THEN(
"it converges") {
42 CHECK(full.niters > 20);
46 WHEN(
"It is set to stop before convergence") {
47 auto const first_half = sdmm.itermax(full.niters - 5)(input);
48 THEN(
"It is not converged") { CHECK(not first_half.good); }
50 WHEN(
"A warm restart is attempted") {
51 auto const second_half = sdmm.itermax(5000)(first_half);
52 THEN(
"The warm restart is validated by the fast convergence") {
53 CHECK(second_half.niters < 10);
sopt::Vector< Scalar > t_Vector
sopt::Matrix< Scalar > t_Matrix
bool is_converged(t_Vector const &x) const
Forwards to convergence function parameter.
Proximal of the Euclidean norm.
Translation< FUNCTION, VECTOR > translate(FUNCTION const &func, VECTOR const &translation)
Translates given proximal by given vector.
double t_real
Root of the type hierarchy for real numbers.
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > Matrix
A matrix of a given type.
SCENARIO("SDMM with warm start", "[sdmm][integration]")