Wavelet transform (JAX)
[ ]:
# Install s2wav
!pip install s2wav &> /dev/null
Let’s start by importing the packages we’ll be using in this notebook.
[2]:
# Make sure we configure 64 bit precision.
# 32 bit can be faster but you will be (potentially much) less precise.
import jax
jax.config.update("jax_enable_x64", True)
import s2wav # Wavelet transforms on the sphere and rotation group
import s2fft # Spherical harmonic and Wigner transforms
import numpy as np
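To give a feel for why 64 bit precision matters here, the following minimal sketch compares single and double precision using NumPy only. It is an illustration, not part of s2wav; the variable names are made up for this example.

```python
import numpy as np

# Machine epsilon: the gap between 1.0 and the next representable number.
eps32 = np.finfo(np.float32).eps
eps64 = np.finfo(np.float64).eps
print(f"float32 eps = {eps32:.3e}")
print(f"float64 eps = {eps64:.3e}")

# A perturbation of 1e-10 is rounded away in single precision
# but survives in double precision.
lost_in_32 = bool(np.float32(1.0) + np.float32(1e-10) == np.float32(1.0))
kept_in_64 = bool(np.float64(1.0) + np.float64(1e-10) != np.float64(1.0))
print(f"lost in float32: {lost_in_32}, kept in float64: {kept_in_64}")
```

Round-trip errors near 1e-14, as reported at the end of this notebook, are only reachable in double precision.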
Now we’ll define the parameters of the problem and generate some random data just for this example.
[3]:
L = 16 # Spherical harmonic bandlimit
N = 3 # Azimuthal (directional) bandlimit
sampling = "mw" # Sampling scheme
# Generate a random bandlimited signal to work with
rng = np.random.default_rng(12346161)
flm = s2fft.utils.signal_generator.generate_flm(rng, L)
f = s2fft.inverse(flm, L)
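As a quick sanity check on the arrays created above, the sketch below computes the shapes one would expect for "mw" (McEwen-Wiaux) sampling. It assumes the s2fft convention of L samples in colatitude and 2L - 1 in longitude, with harmonic coefficients stored as an (L, 2L - 1) array. The helper `mw_shapes` is hypothetical and written only for this illustration; compare its output against `f.shape` and `flm.shape` rather than relying on it.

```python
def mw_shapes(L):
    """Expected array shapes for "mw" sampling at bandlimit L (assumed convention)."""
    f_shape = (L, 2 * L - 1)    # (n_theta, n_phi) pixel grid
    flm_shape = (L, 2 * L - 1)  # (ell, m) with m running from -(L-1) to L-1
    n_coeffs = L * L            # number of (ell, m) pairs with |m| <= ell < L
    return f_shape, flm_shape, n_coeffs


f_shape, flm_shape, n_coeffs = mw_shapes(16)
print(f_shape, flm_shape, n_coeffs)
```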
We can calculate the wavelet and scaling coefficients by first building a bank of wavelet filters and then running the analysis transform.
[5]:
filter_bank = s2wav.filters.filters_directional_vectorised(L, N)
wavelet_coeffs, scaling_coeffs = s2wav.analysis(f, L, N, filters=filter_bank)
You’ll notice that this first pass is very slow. That’s because it is JIT compiling the function, so future calls to s2wav.analysis
will be much faster! When an exact sampling theorem is chosen we can recover the original signal to machine precision by running
[6]:
f_check = s2wav.synthesis(wavelet_coeffs, scaling_coeffs, L, N, filters=filter_bank)
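The cost of JIT compilation can be measured by timing several consecutive calls. The sketch below uses a hypothetical helper, `time_calls`, together with a toy function that fakes a one-off setup cost, so that it runs without JAX installed. Neither name comes from s2wav.

```python
import time


def time_calls(fn, *args, repeats=3, sync=lambda out: out, **kwargs):
    """Call fn `repeats` times and return the wall-clock duration of each call.

    `sync` is applied to the output before the clock stops. For JAX, which
    dispatches work asynchronously, pass a function that waits for the result.
    """
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        sync(fn(*args, **kwargs))
        durations.append(time.perf_counter() - start)
    return durations


# Toy stand-in for a JIT-compiled function: only the first call is slow.
_state = {"compiled": False}


def slow_first_call(x):
    if not _state["compiled"]:
        time.sleep(0.05)  # pretend to compile
        _state["compiled"] = True
    return x * 2


durations = time_calls(slow_first_call, 21)
print([f"{d:.4f}s" for d in durations])
```

In this notebook the same helper could plausibly be used as `time_calls(s2wav.analysis, f, L, N, filters=filter_bank, sync=jax.block_until_ready)`, so that each timing includes the actual computation rather than just the asynchronous dispatch.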
Again, this first call is quite slow, but subsequent calls should be much faster. Let’s double check that we actually got machine precision!
[7]:
print(f"Mean absolute error = {np.nanmean(np.abs(f_check - f))}")
Mean absolute error = 2.068390707329961e-14
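The mean absolute error can hide isolated outliers, so it can be useful to look at the maximum and relative errors as well. The sketch below defines a hypothetical helper, `reconstruction_errors`, and exercises it on synthetic stand-in arrays. The arrays `f_demo` and `f_demo_check` are made up for this illustration and are not outputs of s2wav.

```python
import numpy as np


def reconstruction_errors(f, f_check):
    """Return mean, maximum and relative L2 error between two arrays."""
    f = np.asarray(f)
    diff = np.abs(np.asarray(f_check) - f)
    return {
        "mean_abs": float(np.mean(diff)),
        "max_abs": float(np.max(diff)),
        "rel_l2": float(np.linalg.norm(diff) / np.linalg.norm(f)),
    }


# Synthetic stand-ins: a complex signal and a copy perturbed at the 1e-14 level.
rng = np.random.default_rng(0)
f_demo = rng.standard_normal((16, 31)) + 1j * rng.standard_normal((16, 31))
f_demo_check = f_demo + 1e-14 * rng.standard_normal((16, 31))

errors = reconstruction_errors(f_demo, f_demo_check)
for name, value in errors.items():
    print(f"{name} = {value:.3e}")
```

Calling `reconstruction_errors(f, f_check)` with the arrays from this notebook should give a mean absolute error that matches the value printed above.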