1 #ifndef SOPT_WAVELETS_SARA_H
2 #define SOPT_WAVELETS_SARA_H
4 #include "sopt/config.h"
7 #include <initializer_list>
//! Sparsity Averaging Reweighted Analysis (SARA): an ordered collection of wavelet
//! transforms. SARA is-a std::vector<Wavelet>, so wavelets are stored, indexed and counted
//! like vector elements; the direct/indirect members apply every held wavelet in turn.
20 class SARA :
public std::vector<Wavelet> {
//! Constructors. Unless SOPT_HAS_NOT_USING is defined, every std::vector<Wavelet>
//! constructor is inherited.
22 #ifndef SOPT_HAS_NOT_USING
24 using std::vector<Wavelet>::vector;
//! Default constructor: starts with no wavelets.
//! NOTE(review): the preprocessor line(s) between this constructor and the `using` above are
//! not visible in this listing; confirm which branch of the #ifndef this constructor is in.
27 SARA() : std::vector<Wavelet>(){};
//! Easy constructor from a list of (wavelet name, number of levels) tuples.
//! NOTE(review): the member initializer / body of this constructor is not visible here.
30 SARA(std::initializer_list<std::tuple<std::string, t_uint>>
const &
init)
//! Constructs from any iterator range over (name, levels) tuple-like elements. It is enabled
//! only when std::get<0> of an element converts to std::string (std::get<1> is subject to a
//! second convertibility check whose target type is not visible in this listing). Each
//! element is appended through emplace_back(name, nlevels), which builds the wavelet.
33 template <
typename ITERATOR,
34 class T =
typename std::enable_if<
35 std::is_convertible<decltype(std::get<0>(*std::declval<ITERATOR>())),
36 std::string>::value and
37 std::is_convertible<decltype(std::get<1>(*std::declval<ITERATOR>())),
39 SARA(ITERATOR first, ITERATOR last) {
40 for (; first != last; ++first)
emplace_back(std::get<0>(*first), std::get<1>(*first));
//! Copies a range of already-built wavelets, e.g. a slice of another SARA.
43 SARA(const_iterator first, const_iterator last) : std::vector<
Wavelet>(first, last) {}
//! Direct transform of `signal`, returning newly allocated coefficients
//! (defined outside the class).
54 template <
typename T0>
55 typename T0::PlainObject
direct(Eigen::ArrayBase<T0>
const &signal)
const;
//! Direct transform of `signal` into `coefficients`, which is resized when its shape does not
//! match (defined outside the class).
64 template <
typename T0,
typename T1>
65 void direct(Eigen::ArrayBase<T1> &coefficients, Eigen::ArrayBase<T0>
const &signal)
const;
//! Direct transform into a temporary (rvalue) output.
76 template <
typename T0,
typename T1>
77 void direct(Eigen::ArrayBase<T1> &&coefficients, Eigen::ArrayBase<T0>
const &signal)
const {
// `coefficients` is a named variable and therefore an lvalue: this selects the overload above.
78 direct(coefficients, signal);
//! Indirect transform of `coeffs`, returning a newly allocated signal
//! (defined outside the class).
85 template <
typename T0>
86 typename T0::PlainObject
indirect(Eigen::ArrayBase<T0>
const &coeffs)
const;
//! Indirect transform of `coefficients` into `signal`, which is resized when its shape does not
//! match (defined outside the class).
92 template <
typename T0,
typename T1>
93 void indirect(Eigen::ArrayBase<T1>
const &coefficients, Eigen::ArrayBase<T0> &signal)
const;
//! Indirect transform into a temporary (rvalue) output.
//! NOTE(review): the body is not visible in this listing; presumably it forwards to the lvalue
//! overload, mirroring the rvalue direct() above.
101 template <
typename T0,
typename T1>
102 void indirect(Eigen::ArrayBase<T1>
const &coeffs, Eigen::ArrayBase<T0> &&signal)
const {
// Body of the level-count query (its signature line is not visible here): returns the largest
// levels() among the held wavelets, or 0 when there are none.
108 if (size() == 0)
return 0;
109 auto cmp = [](
Wavelet const &
a,
Wavelet const &
b) {
return a.levels() <
b.levels(); };
110 return std::max_element(begin(), end(), cmp)->levels();
// Appends the wavelet built by factory(name, nlevels) to the underlying vector.
115 std::vector<Wavelet>::emplace_back(
factory(name, nlevels));
//! Validates that the shape of INPUT is compatible with the wavelet depth.
//! Intended for SARA member functions: it calls max_levels() unqualified. It throws
//! std::length_error when INPUT.rows() is not divisible by 2^max_levels(), or when INPUT has
//! more than one column and INPUT.cols() is not divisible by 2^max_levels(). Rows are checked
//! first, and each message names the dimension that actually failed.
//! NOTE: expands to an if / else-if statement that already ends in ';', so do not use it as the
//! unbraced body of an `if` that has an `else` branch.
#define SOPT_WAVELET_ERROR_MACRO(INPUT)                                              \
  if ((INPUT).rows() % (1u << max_levels()) != 0)                                    \
    throw std::length_error("Inconsistent number of rows and wavelet levels");       \
  else if ((INPUT).cols() != 1 and (INPUT).cols() % (1u << max_levels()) != 0)       \
    throw std::length_error("Inconsistent number of columns and wavelet levels");
//! Direct SARA transform of `signal` into `coeffs`.
//! Output layout: coeffs has signal.rows() rows and signal.cols() * size() columns, and wavelet i
//! writes the i-th consecutive block of signal.cols() columns. The result is scaled by
//! 1 / sqrt(size()).
//! Throws std::length_error when coeffs has the wrong shape and could not be resized.
125 template <
typename T0,
typename T1>
126 void SARA::direct(Eigen::ArrayBase<T1> &coeffs, Eigen::ArrayBase<T0>
const &signal)
const {
// Try to give the output the expected shape...
128 if (coeffs.rows() != signal.rows() or coeffs.cols() != signal.cols() *
static_cast<t_int>(size()))
129 coeffs.derived().resize(signal.rows(), signal.cols() * size());
// ...then re-check, because the resize may not have taken effect (the error message allows for
// outputs that cannot be resized).
130 if (coeffs.rows() != signal.rows() or coeffs.cols() != signal.cols() *
static_cast<t_int>(size()))
131 throw std::length_error(
"Incorrect size for output matrix(or could not resize)");
// Nothing to compute without wavelets; coeffs has zero columns at this point.
132 if (size() == 0)
return;
133 auto const Ncols = signal.cols();
136 SOPT_TRACE(
"Calling direct sara without threads");
// Single-column signal: wavelet i fills coefficient column i.
// NOTE(review): the condition choosing between this loop and the next one is not visible in
// this listing; presumably it is signal.cols() == 1.
139 for (size_type i = 0; i < size(); ++i) at(i).direct(coeffs.col(i), signal.col(0));
// Multi-column signal: wavelet i fills columns [i * Ncols, (i + 1) * Ncols).
141 for (size_type i = 0; i < size(); ++i)
142 at(i).direct(coeffs.leftCols((i + 1) * Ncols).rightCols(Ncols), signal);
// Scale the stacked transform by 1 / sqrt(number of wavelets).
145 coeffs /= std::sqrt(size());
//! Indirect (adjoint) SARA transform of `coeffs` into `signal`.
//! coeffs must hold size() consecutive blocks of columns, one per wavelet. signal is resized to
//! coeffs.rows() x (coeffs.cols() / size()) if needed, then set to the sum of every wavelet's
//! indirect transform of its own block, scaled by 1 / sqrt(size()).
//! Throws std::runtime_error when there are no wavelets, and std::length_error when the shapes
//! are inconsistent (see the checks below).
148 template <
typename T0,
typename T1>
149 void SARA::indirect(Eigen::ArrayBase<T1>
const &coeffs, Eigen::ArrayBase<T0> &signal)
const {
150 if (size() == 0)
throw std::runtime_error(
"Empty wavelets: adjoint operation undefined");
// NOTE(review): this tests the *output's* current column count, before signal is resized
// below. An output that has not been sized yet may skip this check; confirm this is intended.
151 if (signal.cols() == 1) {
// Coefficient rows must be divisible by 2^max_levels().
// NOTE(review): the check is on rows() but the message mentions columns; it looks swapped.
152 if (coeffs.rows() % (1u <<
max_levels()) != 0)
153 throw std::length_error(
"Inconsistent number of columns and wavelet levels");
// Every wavelet must own the same number of coefficient columns.
156 if (coeffs.cols() % size() != 0)
157 throw std::length_error(
158 "Columns of coefficient matrix and number of wavelets are inconsistent");
// Give the output the expected shape...
159 if (coeffs.rows() != signal.rows() or coeffs.cols() != signal.cols() *
static_cast<t_int>(size()))
160 signal.derived().resize(coeffs.rows(), coeffs.cols() / size());
// ...then re-check, because the resize may not have taken effect.
161 if (coeffs.rows() != signal.rows() or coeffs.cols() != signal.cols() *
static_cast<t_int>(size()))
162 throw std::length_error(
"Incorrect size for output matrix(or could not resize)");
163 auto const Ncols = signal.cols();
166 SOPT_TRACE(
"Calling indirect sara without threads");
// Contributions from all wavelets are accumulated, starting from zero.
168 signal = T0::Zero(signal.rows(), signal.cols());
// Single-column case: wavelet i reads coefficient column i.
// NOTE(review): the branch condition selecting this loop is not visible in this listing.
170 for (size_type i = 0; i < size(); ++i) signal.col(0) += at(i).indirect(coeffs.col(i));
// Multi-column case: wavelet i reads coefficient columns [i * Ncols, (i + 1) * Ncols).
172 for (size_type i = 0; i < size(); ++i)
173 signal += at(i).indirect(coeffs.leftCols((i + 1) * Ncols).rightCols(Ncols));
// Same 1 / sqrt(number of wavelets) scaling as the direct transform.
175 signal /= std::sqrt(size());
178 #undef SOPT_WAVELET_ERROR_MACRO
180 template <
typename T0>
181 typename T0::PlainObject
SARA::indirect(Eigen::ArrayBase<T0>
const &coeffs)
const {
182 using t_Output = decltype(this->front().
indirect(coeffs));
183 t_Output signal = t_Output::Zero(coeffs.rows(), coeffs.cols() / size());
184 (*this).indirect(coeffs, signal);
188 template <
typename T0>
189 typename T0::PlainObject
SARA::direct(Eigen::ArrayBase<T0>
const &signal)
const {
190 using t_Output = decltype(this->front().
direct(signal));
191 t_Output result = t_Output::Zero(signal.rows(), signal.cols() * size());
192 (*this).direct(result, signal);
200 template <
typename T>
201 T distribute_sara(T
const &sara,
t_uint size,
t_uint rank) {
203 return std::min(size, rank * (size / ncomms) + std::min(rank, size % ncomms));
205 auto const startw = start(sara.size(), size, rank);
206 auto const endw = start(sara.size(), size, rank + 1);
207 return T(sara.begin() + startw, sara.begin() + endw);
209 template <
typename T>
210 T distribute_sara(T
const &all_wavelets, mpi::Communicator
const &comm) {
211 return distribute_sara<T>(all_wavelets, comm.size(), comm.rank());
Sparsity Averaging Reweighted Analysis.
t_uint max_levels() const
Number of levels over which to do transform.
void emplace_back(std::string const &name, t_uint nlevels)
Adds a wavelet of specific type.
void direct(Eigen::ArrayBase< T1 > &&coefficients, Eigen::ArrayBase< T0 > const &signal) const
Direct transform.
SARA(std::initializer_list< std::tuple< std::string, t_uint >> const &init)
Easy constructor.
SARA(const_iterator first, const_iterator last)
virtual ~SARA()
Destructor.
T0::PlainObject indirect(Eigen::ArrayBase< T0 > const &coeffs) const
Indirect transform.
T0::PlainObject direct(Eigen::ArrayBase< T0 > const &signal) const
Direct transform.
SARA(ITERATOR first, ITERATOR last)
Construct from any iterator over a (std::string, t_uint) tuple.
void indirect(Eigen::ArrayBase< T1 > const &coeffs, Eigen::ArrayBase< T0 > &&signal) const
Indirect transform.
Performs direct and indirect wavelet transforms.
std::shared_ptr< details::initializer > init(int argc, const char **argv)
Wavelet factory(const std::string &name, t_uint nlevels)
Creates a wavelet transform object.
int t_int
Root of the type hierarchy for signed integers.
size_t t_uint
Root of the type hierarchy for unsigned integers.
#define SOPT_WAVELET_ERROR_MACRO(INPUT)