SOPT
Sparse OPTimisation
sara.h
Go to the documentation of this file.
1 #ifndef SOPT_WAVELETS_SARA_H
2 #define SOPT_WAVELETS_SARA_H
3 
#include "sopt/config.h"
#include <algorithm> // for std::min<>
#include <cmath>
#include <initializer_list>
#include <stdexcept> // for std::length_error, std::runtime_error
#include <string> // for std::string
#include <tuple>
#include <type_traits> // for std::enable_if, std::is_convertible
#include <utility>     // for std::declval
#include <vector>
#include "sopt/logging.h"
#include "sopt/wavelets/wavelets.h"
#ifdef SOPT_MPI
#include "sopt/mpi/communicator.h"
#endif
16 
17 namespace sopt::wavelets {
18 
20 class SARA : public std::vector<Wavelet> {
21  public:
22 #ifndef SOPT_HAS_NOT_USING
23  // Constructors
24  using std::vector<Wavelet>::vector;
25 #else
27  SARA() : std::vector<Wavelet>(){};
28 #endif
30  SARA(std::initializer_list<std::tuple<std::string, t_uint>> const &init)
31  : SARA(init.begin(), init.end()) {}
33  template <typename ITERATOR,
34  class T = typename std::enable_if<
35  std::is_convertible<decltype(std::get<0>(*std::declval<ITERATOR>())),
36  std::string>::value and
37  std::is_convertible<decltype(std::get<1>(*std::declval<ITERATOR>())),
38  t_uint>::value>::type>
39  SARA(ITERATOR first, ITERATOR last) {
40  for (; first != last; ++first) emplace_back(std::get<0>(*first), std::get<1>(*first));
41  }
42 
43  SARA(const_iterator first, const_iterator last) : std::vector<Wavelet>(first, last) {}
45  virtual ~SARA() {}
46 
54  template <typename T0>
55  typename T0::PlainObject direct(Eigen::ArrayBase<T0> const &signal) const;
64  template <typename T0, typename T1>
65  void direct(Eigen::ArrayBase<T1> &coefficients, Eigen::ArrayBase<T0> const &signal) const;
76  template <typename T0, typename T1>
77  void direct(Eigen::ArrayBase<T1> &&coefficients, Eigen::ArrayBase<T0> const &signal) const {
78  direct(coefficients, signal);
79  }
85  template <typename T0>
86  typename T0::PlainObject indirect(Eigen::ArrayBase<T0> const &coeffs) const;
92  template <typename T0, typename T1>
93  void indirect(Eigen::ArrayBase<T1> const &coefficients, Eigen::ArrayBase<T0> &signal) const;
101  template <typename T0, typename T1>
102  void indirect(Eigen::ArrayBase<T1> const &coeffs, Eigen::ArrayBase<T0> &&signal) const {
103  indirect(coeffs, signal);
104  }
105 
107  t_uint max_levels() const {
108  if (size() == 0) return 0;
109  auto cmp = [](Wavelet const &a, Wavelet const &b) { return a.levels() < b.levels(); };
110  return std::max_element(begin(), end(), cmp)->levels();
111  }
112 
114  void emplace_back(std::string const &name, t_uint nlevels) {
115  std::vector<Wavelet>::emplace_back(factory(name, nlevels));
116  }
117 };
118 
//! Checks that INPUT's shape is compatible with the deepest decomposition of this SARA.
//!
//! Must be expanded inside a SARA member function, since it calls max_levels(). The number of rows
//! must be divisible by 2^max_levels(); so must the number of columns, unless INPUT has exactly one
//! column. Throws std::length_error naming the offending dimension. Wrapped in do/while so that it
//! behaves as a single statement (safe after a bare `if`/`else`, requires a trailing semicolon).
#define SOPT_WAVELET_ERROR_MACRO(INPUT)                                             \
  do {                                                                              \
    if ((INPUT).rows() % (1u << max_levels()) != 0)                                 \
      throw std::length_error("Inconsistent number of rows and wavelet levels");    \
    if ((INPUT).cols() != 1 and (INPUT).cols() % (1u << max_levels()) != 0)         \
      throw std::length_error("Inconsistent number of columns and wavelet levels"); \
  } while (false)
124 
125 template <typename T0, typename T1>
126 void SARA::direct(Eigen::ArrayBase<T1> &coeffs, Eigen::ArrayBase<T0> const &signal) const {
127  SOPT_WAVELET_ERROR_MACRO(signal);
128  if (coeffs.rows() != signal.rows() or coeffs.cols() != signal.cols() * static_cast<t_int>(size()))
129  coeffs.derived().resize(signal.rows(), signal.cols() * size());
130  if (coeffs.rows() != signal.rows() or coeffs.cols() != signal.cols() * static_cast<t_int>(size()))
131  throw std::length_error("Incorrect size for output matrix(or could not resize)");
132  if (size() == 0) return;
133  auto const Ncols = signal.cols();
134  {
135 #ifndef SOPT_OPENMP
136  SOPT_TRACE("Calling direct sara without threads");
137 #endif
138  if (Ncols == 1)
139  for (size_type i = 0; i < size(); ++i) at(i).direct(coeffs.col(i), signal.col(0));
140  else
141  for (size_type i = 0; i < size(); ++i)
142  at(i).direct(coeffs.leftCols((i + 1) * Ncols).rightCols(Ncols), signal);
143  }
144 
145  coeffs /= std::sqrt(size());
146 }
147 
148 template <typename T0, typename T1>
149 void SARA::indirect(Eigen::ArrayBase<T1> const &coeffs, Eigen::ArrayBase<T0> &signal) const {
150  if (size() == 0) throw std::runtime_error("Empty wavelets: adjoint operation undefined");
151  if (signal.cols() == 1) {
152  if (coeffs.rows() % (1u << max_levels()) != 0)
153  throw std::length_error("Inconsistent number of columns and wavelet levels");
154  } else
155  SOPT_WAVELET_ERROR_MACRO(coeffs);
156  if (coeffs.cols() % size() != 0)
157  throw std::length_error(
158  "Columns of coefficient matrix and number of wavelets are inconsistent");
159  if (coeffs.rows() != signal.rows() or coeffs.cols() != signal.cols() * static_cast<t_int>(size()))
160  signal.derived().resize(coeffs.rows(), coeffs.cols() / size());
161  if (coeffs.rows() != signal.rows() or coeffs.cols() != signal.cols() * static_cast<t_int>(size()))
162  throw std::length_error("Incorrect size for output matrix(or could not resize)");
163  auto const Ncols = signal.cols();
164  {
165 #ifndef SOPT_OPENMP
166  SOPT_TRACE("Calling indirect sara without threads");
167 #endif
168  signal = T0::Zero(signal.rows(), signal.cols());
169  if (Ncols == 1)
170  for (size_type i = 0; i < size(); ++i) signal.col(0) += at(i).indirect(coeffs.col(i));
171  else
172  for (size_type i = 0; i < size(); ++i)
173  signal += at(i).indirect(coeffs.leftCols((i + 1) * Ncols).rightCols(Ncols));
174  }
175  signal /= std::sqrt(size());
176 }
177 
178 #undef SOPT_WAVELET_ERROR_MACRO
179 
180 template <typename T0>
181 typename T0::PlainObject SARA::indirect(Eigen::ArrayBase<T0> const &coeffs) const {
182  using t_Output = decltype(this->front().indirect(coeffs));
183  t_Output signal = t_Output::Zero(coeffs.rows(), coeffs.cols() / size());
184  (*this).indirect(coeffs, signal);
185  return signal;
186 }
187 
188 template <typename T0>
189 typename T0::PlainObject SARA::direct(Eigen::ArrayBase<T0> const &signal) const {
190  using t_Output = decltype(this->front().direct(signal));
191  t_Output result = t_Output::Zero(signal.rows(), signal.cols() * size());
192  (*this).direct(result, signal);
193  return result;
194 }
195 
196 #ifdef SOPT_MPI
200 template <typename T>
201 T distribute_sara(T const &sara, t_uint size, t_uint rank) {
202  auto const start = [](t_uint size, t_uint ncomms, t_uint rank) {
203  return std::min(size, rank * (size / ncomms) + std::min(rank, size % ncomms));
204  };
205  auto const startw = start(sara.size(), size, rank);
206  auto const endw = start(sara.size(), size, rank + 1);
207  return T(sara.begin() + startw, sara.begin() + endw);
208 }
209 template <typename T>
210 T distribute_sara(T const &all_wavelets, mpi::Communicator const &comm) {
211  return distribute_sara<T>(all_wavelets, comm.size(), comm.rank());
212 }
213 #endif
214 } // namespace sopt::wavelets
215 #endif
constexpr Scalar b
constexpr Scalar a
Sparsity Averaging Reweighted Analysis.
Definition: sara.h:20
t_uint max_levels() const
Number of levels over which to do transform.
Definition: sara.h:107
void emplace_back(std::string const &name, t_uint nlevels)
Adds a wavelet of specific type.
Definition: sara.h:114
void direct(Eigen::ArrayBase< T1 > &&coefficients, Eigen::ArrayBase< T0 > const &signal) const
Direct transform.
Definition: sara.h:77
SARA(std::initializer_list< std::tuple< std::string, t_uint >> const &init)
Easy constructor.
Definition: sara.h:30
SARA(const_iterator first, const_iterator last)
Definition: sara.h:43
virtual ~SARA()
Destructor.
Definition: sara.h:45
T0::PlainObject indirect(Eigen::ArrayBase< T0 > const &coeffs) const
Indirect transform.
Definition: sara.h:181
T0::PlainObject direct(Eigen::ArrayBase< T0 > const &signal) const
Direct transform.
Definition: sara.h:189
SARA(ITERATOR first, ITERATOR last)
Construct from any iterator over (std::string, t_uint) tuples.
Definition: sara.h:39
void indirect(Eigen::ArrayBase< T1 > const &coeffs, Eigen::ArrayBase< T0 > &&signal) const
Indirect transform.
Definition: sara.h:102
Performs direct and indirect wavelet transforms.
Definition: wavelets.h:21
#define SOPT_TRACE(...)
Definition: logging.h:220
std::shared_ptr< details::initializer > init(int argc, const char **argv)
Definition: session.cc:27
Wavelet factory(const std::string &name, t_uint nlevels)
Creates a wavelet transform object.
Definition: wavelets.cc:8
int t_int
Root of the type hierarchy for signed integers.
Definition: types.h:13
size_t t_uint
Root of the type hierarchy for unsigned integers.
Definition: types.h:15
#define SOPT_WAVELET_ERROR_MACRO(INPUT)
Definition: sara.h:119