Wavelet transform (PyTorch)
Note that we currently provide only precompute support for PyTorch, so these transforms will only work up to a bandlimit of around \(L\sim1024\). Support for recursive, or so-called on-the-fly, algorithms is already provided in JAX and should reach PyTorch soon.
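To get a rough feel for why a precompute approach runs out of headroom near this bandlimit, the sketch below estimates the memory of a dense kernel. It assumes, purely for illustration, a real float64 array with \(L^3\) entries for each of the \(2N-1\) azimuthal modes; the actual layout of the s2wav precomputes may differ, and `precompute_gib` is a hypothetical helper, not part of any library.

```python
def precompute_gib(L: int, N: int, bytes_per_entry: int = 8) -> float:
    """Estimate the memory (in GiB) of an ASSUMED dense (2N-1) x L^3 kernel."""
    entries = (2 * N - 1) * L**3
    return entries * bytes_per_entry / 2**30


# Print the estimate for a few bandlimits at the directionality used below.
for L in (16, 128, 512, 1024):
    print(f"L = {L:5d}: ~{precompute_gib(L, N=3):.3g} GiB")
```

Under these assumptions the estimate grows by a factor of eight each time L doubles, which is why recursive algorithms become attractive at high bandlimits.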
[ ]:
# Install s2wav
!pip install s2wav &> /dev/null
Let’s start by importing some packages which we’ll be using in this notebook.
[1]:
import torch # Differentiable programming ecosystem
import s2wav # Wavelet transforms on the sphere and rotation group
import s2fft # Spherical harmonic and Wigner transforms
import numpy as np
JAX is not using 64-bit precision. This will dramatically affect numerical precision at even moderate L.
Now we’ll define the constraints of the problem and generate some random data just for this example.
[2]:
L = 16 # Spherical harmonic bandlimit
N = 3 # Azimuthal (directional) bandlimit
# Generate a random bandlimited signal to work with
rng = np.random.default_rng(12346161)
flm = s2fft.utils.signal_generator.generate_flm(rng, L)
f = s2fft.inverse(flm, L)
# We'll need to convert this numpy array into a torch.tensor
f_torch = torch.from_numpy(f)
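The conversion step can be looked at in isolation. The sketch below uses a small made-up complex array (`f_demo`, not the signal from this notebook) to show that `torch.from_numpy` keeps the NumPy dtype, and why `resolve_conj()` may be needed before converting back with `.numpy()`, as in the final check of this notebook.

```python
import numpy as np
import torch

# A small complex array standing in for the signal (illustrative values only).
f_demo = np.array([1.0 + 2.0j, 3.0 - 1.0j])

# from_numpy shares memory with the NumPy array and keeps its dtype.
t = torch.from_numpy(f_demo)
print(t.dtype)

# conj() returns a lazy view with the conjugate bit set; resolve_conj()
# materialises it so that .numpy() can be called on the result.
back = t.conj().resolve_conj().numpy()
print(np.allclose(back, np.conj(f_demo)))
```

Because the tensor shares memory with the array, in-place changes to one are visible in the other; copy first if that is not wanted.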
We can calculate the wavelet and scaling coefficients by first building a bank of wavelet filters, then precomputing and caching all matrices involved in the core transforms.
[3]:
filter_bank = s2wav.filters.filters_directional_vectorised(L, N, using_torch=True)
analysis_matrices = s2wav.construct.generate_full_precomputes(L, N, using_torch=True, forward=False)
synthesis_matrices = s2wav.construct.generate_full_precomputes(L, N, using_torch=True, forward=True)
Now we can apply the transforms, which are straightforward linear algebra, by running
[4]:
wavelet_coeffs, scaling_coeffs = s2wav.analysis_precomp_torch(
    f_torch, L, N, filters=filter_bank, precomps=analysis_matrices
)
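The idea of precomputing matrices so that each transform reduces to linear algebra can be illustrated with a much simpler stand-in. The sketch below is not the s2wav algorithm: it precomputes dense forward and inverse discrete Fourier transform matrices once, then applies each transform as a single matrix-vector product. All names in it are illustrative.

```python
import numpy as np

n = 8
k = np.arange(n)

# Precompute once: dense forward and inverse DFT matrices (a stand-in kernel).
forward_matrix = np.exp(-2j * np.pi * np.outer(k, k) / n)
inverse_matrix = np.conj(forward_matrix).T / n

# A random complex test signal.
rng = np.random.default_rng(0)
x = rng.standard_normal(n) + 1j * rng.standard_normal(n)

# Each transform is now a single matrix-vector product.
coeffs = forward_matrix @ x
x_check = inverse_matrix @ coeffs

# Round-trip error of the precomputed transform pair.
print(np.max(np.abs(x_check - x)))
```

The trade-off is the same as in the precompute transforms above: the matrices cost memory up front, and every later call is a cheap product against the cached kernels.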
When an exact sampling theorem is chosen, we can recover the original signal to machine precision by running
[5]:
f_check = s2wav.synthesis_precomp_torch(
    wavelet_coeffs, scaling_coeffs, L, N, filters=filter_bank, precomps=synthesis_matrices
)
Again, this first call is quite slow, but subsequent calls should be much faster. Let’s double check that we actually got machine precision!
[7]:
print(f"Mean absolute error = {np.nanmean(np.abs(f_check.resolve_conj().numpy() - f))}")
Mean absolute error = 2.0514116979479282e-14
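A mean can hide a few large outliers, so it can be useful to report the maximum and a relative error alongside it. The sketch below does this on synthetic stand-in arrays (`reference` and `recovered` are made up here, with a perturbation chosen at roughly the error level reported above); the tolerance of `1e-12` is an assumption, not a value taken from s2wav.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the original signal and its round-trip reconstruction.
reference = rng.standard_normal((4, 7)) + 1j * rng.standard_normal((4, 7))
recovered = reference + 1e-14 * rng.standard_normal(reference.shape)

abs_err = np.abs(recovered - reference)
rel_err = np.linalg.norm(recovered - reference) / np.linalg.norm(reference)

print(f"Mean absolute error = {abs_err.mean():.3e}")
print(f"Max absolute error  = {abs_err.max():.3e}")
print(f"Relative L2 error   = {rel_err:.3e}")

# Fail loudly if the round trip is worse than the assumed tolerance.
assert np.allclose(recovered, reference, rtol=0.0, atol=1e-12)
```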