SOPT
Sparse OPTimisation
power_method.cc
Go to the documentation of this file.
1 #include <numeric>
2 #include <random>
3 #include <Eigen/Eigenvalues>
4 #include "catch2/catch_all.hpp"
5 
6 #include "sopt/power_method.h"
7 
8 using Catch::Approx;
9 
10 TEST_CASE("Power Method") {
11  using namespace sopt;
12  using Scalar = t_real;
13  auto constexpr N = 10;
14 
15  Eigen::EigenSolver<Matrix<Scalar>> es;
16  Matrix<Scalar> A(N, N);
17  std::iota(A.data(), A.data() + A.size(), 0);
18  es.compute(A.adjoint() * A, true);
19 
20  auto const eigenvalues = es.eigenvalues();
21  auto const eigenvectors = es.eigenvectors();
22  Eigen::DenseIndex index;
23  (eigenvalues.transpose() * eigenvalues).real().maxCoeff(&index);
24  auto const eigenvalue = eigenvalues(index);
25  Vector<t_complex> const eigenvector = eigenvectors.col(index);
26  // Create input vector close to solution
27  Vector<t_complex> const input = eigenvector * 1e-4 + Vector<t_complex>::Random(N);
28  auto const pm = algorithm::PowerMethod<t_complex>().tolerance(1e-12);
29 
30  SECTION("AtA") {
31  auto const lt = linear_transform(A.cast<t_complex>());
32  auto const result = pm.AtA(lt, input);
33  CHECK(result.good);
34  CAPTURE(eigenvalue);
35  CAPTURE(result.magnitude);
36  CAPTURE(result.eigenvector.transpose() * eigenvector);
37  CHECK(std::abs(result.magnitude - std::abs(eigenvalue)) < 1e-8);
38  }
39 
40  SECTION("A") {
41  auto const result = pm((A.adjoint() * A).cast<t_complex>(), input);
42  CHECK(result.good);
43  CAPTURE(eigenvalue);
44  CAPTURE(result.magnitude);
45  CAPTURE(result.eigenvector.transpose() * eigenvector);
46  CHECK(std::abs(result.magnitude - std::abs(eigenvalue)) < 1e-8);
47  }
48 }
49 
50 TEST_CASE("Power Method (from Purify)") {
51  using namespace sopt;
52  using Scalar = t_real;
53  auto constexpr N = 10;
54  constexpr t_uint power_iters = 100000;
55  constexpr t_real power_tol = 1e-6;
56  Eigen::EigenSolver<Matrix<Scalar>> es;
57  Matrix<Scalar> A(N, N);
58  std::iota(A.data(), A.data() + A.size(), 0);
59  es.compute(A.adjoint() * A, true);
60 
61  auto const eigenvalues = es.eigenvalues();
62  auto const eigenvectors = es.eigenvectors();
63  Eigen::DenseIndex index;
64  (eigenvalues.transpose() * eigenvalues).real().maxCoeff(&index);
65  auto const eigenvalue = eigenvalues(index);
66  Vector<t_complex> const eigenvector = eigenvectors.col(index);
67  // Create input vector close to solution
68  Vector<t_complex> const input = eigenvector * 1e-4 + Vector<t_complex>::Random(N);
69 
70  const auto forward = [=](Vector<t_complex> &out, const Vector<t_complex> &in) { out = A * in; };
71  const auto backward = [=](Vector<t_complex> &out, const Vector<t_complex> &in) {
72  out = A.adjoint() * in;
73  };
74 
75  SECTION("Power Method") {
76  const sopt::LinearTransform<Vector<t_complex>> op = {forward, backward};
77  auto const result =
78  algorithm::power_method<Vector<t_complex>>(op, power_iters, power_tol, input);
79  const t_real op_norm = std::get<0>(result);
80  const Vector<t_complex> op_eigen_vector_c = std::get<1>(result);
81  CHECK(op_eigen_vector_c.unaryExpr([](t_complex x) { return std::arg(x); })
82  .isApprox(Vector<t_complex>::Constant(op_eigen_vector_c.size(),
83  std::arg(op_eigen_vector_c(0))),
84  power_tol));
85  const Vector<t_complex> op_eigen_vector =
86  op_eigen_vector_c * std::polar<t_real>(1, -std::arg(op_eigen_vector_c(0)));
87  CAPTURE(eigenvalue);
88  CAPTURE(op_norm * op_norm);
89  CAPTURE(op_eigen_vector);
90  CAPTURE(eigenvector);
91  CHECK(op_norm == Approx(std::sqrt(std::abs(eigenvalue))).epsilon(power_tol));
92  CHECK(op_eigen_vector.isApprox(eigenvector, power_tol));
93  auto const norm_operator_result =
94  algorithm::normalise_operator<Vector<t_complex>>(op, power_iters, power_tol, input);
95  CHECK(std::get<0>(norm_operator_result) == Approx(op_norm).epsilon(1e-12));
96  CHECK(std::get<1>(norm_operator_result).isApprox(op_eigen_vector_c, 1e-12));
97  CHECK(((op * input) / op_norm)
98  .eval()
99  .isApprox((std::get<2>(norm_operator_result) * input).eval(), 1e-12));
100  CHECK(((op.adjoint() * input) / op_norm)
101  .eval()
102  .isApprox((std::get<2>(norm_operator_result).adjoint() * input).eval(), 1e-12));
103  }
104 }
constexpr auto N
Definition: wavelets.cc:57
sopt::t_real Scalar
Joins together direct and indirect operators.
LinearTransform< VECTOR > adjoint() const
Indirect transform.
Eigenvalue and eigenvector for eigenvalue with largest magnitude.
Definition: power_method.h:137
LinearTransform< VECTOR > linear_transform(OperatorFunction< VECTOR > const &direct, OperatorFunction< VECTOR > const &indirect, std::array< t_int, 3 > const &sizes={{1, 1, 0}})
double t_real
Root of the type hierarchy for real numbers.
Definition: types.h:17
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
real_type< T >::type epsilon(sopt::LinearTransform< Vector< T >> const &sampling, sopt::Image< T > const &image)
Definition: inpainting.h:38
Eigen::Matrix< T, Eigen::Dynamic, 1 > Vector
A vector of a given type.
Definition: types.h:24
std::complex< t_real > t_complex
Root of the type hierarchy for (real) complex numbers.
Definition: types.h:19
Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > Matrix
A matrix of a given type.
Definition: types.h:29
TEST_CASE("Power Method")
Definition: power_method.cc:10