s2wav is currently in open beta; please provide feedback on GitHub.

Wavelet transform (PyTorch)


Note that we currently only provide precompute support for PyTorch, so these transforms only work up to a bandlimit of around \(L\sim1024\), beyond which the precomputed matrices become too large to store. Support for recursive, or so-called on-the-fly, algorithms is already provided in JAX and should reach PyTorch soon.
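To get a feel for why a precompute-only approach runs out of room around \(L\sim1024\), here is a back-of-the-envelope sketch. Purely for illustration, it assumes a dense real float64 kernel with one entry per (θ sample, degree ℓ, order m), i.e. \(L \times L \times (2L-1)\) entries. The matrices s2wav actually stores differ in layout and number, so treat the output as an order-of-magnitude guide only.

```python
import numpy as np


def dense_kernel_gib(L, bytes_per_entry=8):
    """Rough memory (GiB) of a hypothetical dense kernel with L * L * (2L - 1) entries.

    Assumption: one real float64 entry per (theta sample, degree ell, order m),
    with L theta samples as for MW sampling. Illustrative model only, not the
    exact layout used by s2wav or s2fft.
    """
    n_entries = L * L * (2 * L - 1)
    return n_entries * bytes_per_entry / 2**30


for L_demo in (64, 256, 1024, 2048):
    print(f"L = {L_demo:5d}: ~{dense_kernel_gib(L_demo):8.2f} GiB")
```

Under this model memory grows as \(\mathcal{O}(L^3)\): doubling L multiplies it by roughly 8, from about 16 GiB at L = 1024 to about 128 GiB at L = 2048.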

If needed, install s2wav first (pip's output is suppressed)

# Install s2wav (POSIX redirection, so this also works when ! runs under /bin/sh)
!pip install s2wav > /dev/null 2>&1

Let’s start by importing some packages that we’ll be using in this notebook

import torch       # Differentiable programming ecosystem
import s2wav       # Wavelet transforms on the sphere and rotation group
import s2fft       # Spherical harmonic and Wigner transforms
import numpy as np
JAX is not using 64-bit precision. This will dramatically affect numerical precision at even moderate L.
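The warning above comes from JAX, which these packages depend on. If you also use the JAX backends, 64-bit mode can be enabled with `jax.config.update("jax_enable_x64", True)` before any arrays are created. Separately, `torch.from_numpy` preserves numpy's dtype, so the signal built below stays in the double precision numpy produces. The numpy-only sketch here (not part of s2wav) shows why precision matters: applying a transform and then its exact inverse loses far more accuracy in 32-bit than in 64-bit arithmetic.

```python
import numpy as np


def orthogonal_round_trip_error(n, dtype):
    """Mean absolute error after applying a random orthogonal matrix and then its transpose."""
    rng = np.random.default_rng(0)
    q, _ = np.linalg.qr(rng.standard_normal((n, n)))  # orthogonal in float64
    q = q.astype(dtype)                                 # rounding makes it only approximately orthogonal
    x = rng.standard_normal(n).astype(dtype)
    x_back = q.T @ (q @ x)                              # q.T inverts q in exact arithmetic
    return float(np.mean(np.abs(x_back.astype(np.float64) - x.astype(np.float64))))


err32 = orthogonal_round_trip_error(512, np.float32)
err64 = orthogonal_round_trip_error(512, np.float64)
print(f"float32: {err32:.1e}   float64: {err64:.1e}")
```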

Now we’ll define the parameters of the problem and generate some random data for this example

L = 16            # Spherical harmonic bandlimit
N = 3             # Azimuthal (directional) bandlimit

# Generate a random bandlimited signal to work with
rng = np.random.default_rng(12346161)
flm = s2fft.utils.signal_generator.generate_flm(rng, L)
f = s2fft.inverse(flm, L)

# Convert the numpy array into a torch.Tensor (torch.from_numpy shares memory and keeps the dtype)
f_torch = torch.from_numpy(f)
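The call to `generate_flm` above produces harmonic coefficients \(f_{\ell m}\) that vanish for \(\ell \geq L\). As we understand s2fft's convention, they are stored in a 2D array of shape \((L, 2L-1)\), with order \(m\) in column \(L-1+m\) and entries with \(|m|>\ell\) set to zero. The hypothetical helper `random_bandlimited_flm` below builds a random complex array with that assumed layout so you can see which entries are populated; unlike s2fft's generator, it imposes no reality or spin constraints.

```python
import numpy as np


def random_bandlimited_flm(rng, L):
    """Hypothetical helper: random complex coefficients with |m| <= ell < L.

    Assumed layout: shape (L, 2L - 1), with order m stored in column L - 1 + m.
    """
    flm = np.zeros((L, 2 * L - 1), dtype=np.complex128)
    for ell in range(L):
        m = np.arange(-ell, ell + 1)
        flm[ell, L - 1 + m] = rng.standard_normal(2 * ell + 1) + 1j * rng.standard_normal(2 * ell + 1)
    return flm


flm_toy = random_bandlimited_flm(np.random.default_rng(0), L=16)
print(flm_toy.shape, np.count_nonzero(flm_toy))
```

A bandlimit of L gives \(\sum_{\ell=0}^{L-1}(2\ell+1)=L^2\) populated coefficients, i.e. 256 for L = 16.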

We can compute the wavelet and scaling coefficients by first building a bank of wavelet filters, then precomputing and caching all the matrices involved in the core transforms.

filter_bank = s2wav.filters.filters_directional_vectorised(L, N, using_torch=True)
analysis_matrices = s2wav.construct.generate_full_precomputes(L, N, using_torch=True, forward=False)
synthesis_matrices = s2wav.construct.generate_full_precomputes(L, N, using_torch=True, forward=True)
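The filter bank tiles harmonic space: the scaling function captures the lowest degrees, and each wavelet scale j covers a band of degrees growing roughly like \(\lambda^j\). The numpy sketch below builds axisymmetric toy filters in the spirit of scale-discretised wavelets from a smooth decreasing profile \(\kappa\), with \(\kappa(t)=1\) for \(t\le 1/\lambda\) and \(\kappa(t)=0\) for \(t\ge 1\), then checks the partition of unity \(\Phi_\ell^2 + \sum_j \Psi_{j\ell}^2 = 1\). The smooth step, `toy_axisymmetric_filters`, and its parameters are our own illustrative choices; s2wav's kernels use a different profile and include directional components.

```python
import numpy as np


def smooth_step(x):
    """C-infinity step: 0 for x <= 0, 1 for x >= 1, smooth in between."""
    x = np.clip(x, 0.0, 1.0)

    def bump(t):
        # exp(-1/t) for t > 0 and 0 otherwise (safe_t avoids dividing by zero)
        safe_t = np.where(t > 0, t, 1.0)
        return np.where(t > 0, np.exp(-1.0 / safe_t), 0.0)

    a, b = bump(x), bump(1.0 - x)
    return a / (a + b)  # a + b > 0 everywhere on [0, 1]


def kappa(t, lam):
    """Smooth decreasing profile: 1 for t <= 1/lam, 0 for t >= 1."""
    return 1.0 - smooth_step((t - 1.0 / lam) / (1.0 - 1.0 / lam))


def toy_axisymmetric_filters(L, lam=2.0, J_min=0):
    """Hypothetical toy filters: scaling phi[ell] and wavelets psi[j - J_min, ell]."""
    ell = np.arange(L, dtype=np.float64)
    J_max = int(np.ceil(np.log(L) / np.log(lam)))  # ceil(log_lam L), so lam**J_max >= L
    phi = np.sqrt(kappa(ell / lam**J_min, lam))
    psi = np.stack([
        # clip guards against tiny negative differences from rounding
        np.sqrt(np.clip(kappa(ell / lam ** (j + 1), lam) - kappa(ell / lam**j, lam), 0.0, None))
        for j in range(J_min, J_max + 1)
    ])
    return phi, psi


phi, psi = toy_axisymmetric_filters(L=16, lam=2.0)
total = phi**2 + np.sum(psi**2, axis=0)
print(psi.shape, np.max(np.abs(total - 1.0)))
```

The check works because the sum telescopes: \(\Phi_\ell^2 + \sum_{j=J_{\min}}^{J} \Psi_{j\ell}^2 = \kappa(\ell/\lambda^{J+1})\), which equals 1 for every \(\ell < L\) once \(\lambda^{J}\ge L\).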

With the matrices precomputed, the transforms are straightforward linear algebra. We compute the wavelet and scaling coefficients by running

wavelet_coeffs, scaling_coeffs = s2wav.analysis_precomp_torch(
    f_torch, L, N, filters=filter_bank, precomps=analysis_matrices
)
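Once a kernel is precomputed, applying a transform is a tensor contraction. As a numpy-only illustration of one ingredient, the Fourier transform over the azimuthal angle \(\varphi\), the sketch below precomputes \(e^{-im\varphi_p}\) for \(2L-1\) equispaced samples (as in MW sampling), applies it with `np.einsum`, and checks the result against `np.fft.fft`. The kernel layout is our own; it is not s2wav's internal precompute format.

```python
import numpy as np

L_toy = 16
n_phi = 2 * L_toy - 1                               # assumed number of phi samples (MW-style)
phi_p = 2 * np.pi * np.arange(n_phi) / n_phi        # equispaced azimuthal samples
m = np.fft.fftfreq(n_phi, d=1.0 / n_phi)            # integer orders, in FFT ordering

# Precompute once: kernel[k, p] = exp(-i * m_k * phi_p)
kernel = np.exp(-1j * np.outer(m, phi_p))

# Toy samples on a (theta ring, phi) grid
rng = np.random.default_rng(1)
f_grid = rng.standard_normal((L_toy, n_phi))

# Apply: one contraction over the phi axis, i.e. one matrix product per theta ring
f_m = np.einsum("kp,tp->tk", kernel, f_grid)

print(np.allclose(f_m, np.fft.fft(f_grid, axis=-1)))  # True
```

The trade-off is memory for simplicity: this kernel alone costs \((2L-1)^2\) complex numbers, whereas an FFT stores no kernel, but the contraction maps directly onto batched matrix products that run efficiently on GPUs.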

When an exact sampling theorem is used, we can recover the original signal to machine precision by running

f_check = s2wav.synthesis_precomp_torch(
    wavelet_coeffs, scaling_coeffs, L, N, filters=filter_bank, precomps=synthesis_matrices
)
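Perfect reconstruction relies on two ingredients: an exact sampling theorem, so that the underlying harmonic transforms invert exactly, and filters that satisfy an admissibility (partition-of-unity) condition. The numpy toy below isolates the second ingredient in a simplified, axisymmetric form: random non-negative filters normalized so that \(\sum_j \Psi_{j\ell}^2 = 1\) (counting any scaling filter as one of the rows), applied to harmonic coefficients and then removed again. It omits the Wigner transforms and the normalization constants of s2wav's actual transforms.

```python
import numpy as np

L_toy, n_filters = 16, 5
rng = np.random.default_rng(3)

# Toy bandlimited coefficients in an (ell, m) layout of shape (L, 2L - 1); zero where |m| > ell
ell = np.arange(L_toy)[:, None]
m = np.arange(-(L_toy - 1), L_toy)[None, :]
flm_toy = (rng.standard_normal((L_toy, 2 * L_toy - 1))
           + 1j * rng.standard_normal((L_toy, 2 * L_toy - 1))) * (np.abs(m) <= ell)

# Random non-negative axisymmetric filters, normalized so that sum_j psi[j, ell]**2 == 1
raw = rng.uniform(0.1, 1.0, size=(n_filters, L_toy))
psi = raw / np.sqrt(np.sum(raw**2, axis=0, keepdims=True))

# "Analysis": one filtered copy of the coefficients per filter (filters depend on ell only)
w_lm = psi[:, :, None] * flm_toy[None, :, :]

# "Synthesis": filter again and sum over filters; the partition of unity makes this exact
flm_rec = np.sum(psi[:, :, None] * w_lm, axis=0)

print(np.max(np.abs(flm_rec - flm_toy)))
```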

The first call can be quite slow, but subsequent calls should be much faster. Let’s double-check that we actually recovered the signal to machine precision!

# resolve_conj() materializes any lazily conjugated tensor so that .numpy() can be called
print(f"Mean absolute error = {np.nanmean(np.abs(f_check.resolve_conj().numpy() - f))}")
Mean absolute error = 2.0514116979479282e-14
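The mean absolute error above, about \(2\times10^{-14}\), is roughly 90 times float64 machine epsilon (\(\approx 2.2\times10^{-16}\)). For more than one number, the hypothetical helper `reconstruction_report` below reports the mean and maximum absolute errors and the relative \(\ell_2\) error, skipping NaN samples as `np.nanmean` does. The toy arrays stand in for `f` and the converted `f_check`.

```python
import numpy as np


def reconstruction_report(f_ref, f_rec):
    """Hypothetical helper: absolute and relative reconstruction errors, ignoring NaN samples."""
    f_ref = np.asarray(f_ref)
    diff = np.abs(np.asarray(f_rec) - f_ref)
    mean_abs = float(np.nanmean(diff))
    return {
        "mean_abs": mean_abs,
        "max_abs": float(np.nanmax(diff)),
        "rel_l2": float(np.sqrt(np.nansum(diff**2) / np.nansum(np.abs(f_ref) ** 2))),
        "mean_abs_in_eps": mean_abs / np.finfo(np.float64).eps,
    }


# Toy stand-ins for f and the converted f_check
rng = np.random.default_rng(2)
f_ref = rng.standard_normal((16, 31)) + 1j * rng.standard_normal((16, 31))
f_rec = f_ref + 1e-15 * rng.standard_normal((16, 31))
print(reconstruction_report(f_ref, f_rec))
```

In this notebook you could call `reconstruction_report(f, f_check.resolve_conj().numpy())`.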