Wavelet transform
Let's start by importing some packages.
[1]:
# Enable double (64-bit) precision in JAX.
import jax
jax.config.update("jax_enable_x64", True)
# Import math libraries.
import numpy as np
import jax.numpy as jnp
# Check which backend (cpu, gpu, ...) JAX is running on.
print(jax.default_backend())
# Import the s2ball library.
import s2ball
from s2ball.transform import ball_wavelet, laguerre
cpu
Generate a random complex bandlimited field
Here we generate random Spherical-Laguerre coefficients flmp, which we then convert into a bandlimited signal f on \(\mathbb{B}^3=\mathbb{R}^+\times\mathbb{S}^2\). Later we also compute some matrices, which are cached and passed to their associated functions at run time.
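For a harmonic bandlimit L and a radial bandlimit P, the coefficients \(f_{\ell m p}\) are indexed by \(0 \le \ell < L\), \(|m| \le \ell\) and \(0 \le p < P\), so there are \(L^2 P\) of them (32,768 for L = P = 32). The sketch below generates such coefficients with NumPy. It is not the implementation of s2ball.utils.generate_flmp: the helper name and the (P, L, 2L - 1) array layout, with m stored at column m + L - 1, are assumptions made for illustration.

```python
import numpy as np

def random_bandlimited_flmp(rng, L, P):
    """Hypothetical sketch: random complex f_{lmp} with l < L, |m| <= l, p < P.

    Assumed layout: shape (P, L, 2L - 1), with m stored at column index m + L - 1.
    Entries with |m| > l lie outside the harmonic bandlimit and stay zero.
    """
    flmp = np.zeros((P, L, 2 * L - 1), dtype=np.complex128)
    for el in range(L):
        m_cols = slice(L - 1 - el, L + el)  # columns holding m = -el, ..., el
        shape = (P, 2 * el + 1)
        flmp[:, el, m_cols] = rng.normal(size=shape) + 1j * rng.normal(size=shape)
    return flmp

flmp_demo = random_bandlimited_flmp(np.random.default_rng(0), L=4, P=3)
print(flmp_demo.shape)                # (3, 4, 7)
print(np.count_nonzero(flmp_demo))    # 48 = L**2 * P
```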
[2]:
L = 32 # Harmonic bandlimit of the problem.
P = 32 # Radial bandlimit of the problem.
N = 3 # Azimuthal (directional) bandlimit of the problem.
# Define a random seed.
rng = np.random.default_rng(193412341234)
# Use s2ball functions to generate a random signal.
flmp = s2ball.utils.generate_flmp(rng, L, P)
f = laguerre.inverse(flmp, L, P)
# Note: f currently has to be bandlimited explicitly (a forward/inverse wavelet
# round trip through flmnp), as generate_flmp does not yet enforce the
# bandlimiting symmetries directly.
w = ball_wavelet.forward(f, L, N, P)
f = ball_wavelet.inverse(w, L, N, P)
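Why does a forward/inverse round trip bandlimit the signal? If the inverse only synthesises from in-band coefficients, composing forward and inverse acts as a projection: a second round trip changes nothing. The sketch below shows this property with a 1D Fourier analogue built on NumPy's FFT. It is swapped in purely for illustration and is not the ball wavelet transform; bandlimit_1d is a hypothetical helper.

```python
import numpy as np

def bandlimit_1d(x, K):
    """Keep only the Fourier modes of x with integer frequency |k| < K (1D analogue)."""
    X = np.fft.fft(x)
    k = np.fft.fftfreq(x.size, d=1.0 / x.size)  # integer frequencies 0, 1, ..., -1
    X[np.abs(k) >= K] = 0.0  # discard out-of-band modes
    return np.fft.ifft(X)

rng = np.random.default_rng(1)
x = rng.standard_normal(64)   # generic signal, not bandlimited
y = bandlimit_1d(x, K=5)      # one forward/inverse round trip
z = bandlimit_1d(y, K=5)      # a second round trip
print(np.allclose(y, z))      # True: the round trip is a projection
print(np.allclose(x, y))      # False: the first round trip changed x
```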
Load/construct relevant associated Legendre matrices
Load the precomputed associated Legendre matrices used to evaluate the spherical harmonic transform. If these matrices have already been computed, generate_matrices will attempt to locate them inside the hidden .matrices directory; otherwise they are constructed. Note that you can specify a directory of your choice; .matrices is simply the default.
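The load-or-construct behaviour described above can be sketched as a compute-if-missing disk cache. This is not s2ball's implementation: load_or_compute, the one-.npy-file-per-matrix naming and the demo matrix are hypothetical, chosen only to show the pattern.

```python
import os
import tempfile
import numpy as np

def load_or_compute(name, compute, directory=".matrices"):
    """Hypothetical helper: return directory/name.npy if it exists, else compute and store it."""
    path = os.path.join(directory, name + ".npy")
    if os.path.exists(path):
        return np.load(path)  # cache hit: reuse the stored matrix
    array = compute()  # cache miss: build the matrix once
    os.makedirs(directory, exist_ok=True)
    np.save(path, array)
    return array

calls = []

def expensive_matrix():
    calls.append(1)  # record how often the expensive step actually runs
    return np.eye(3)

with tempfile.TemporaryDirectory() as tmp:
    first = load_or_compute("demo_L32_N3_P32", expensive_matrix, directory=tmp)
    second = load_or_compute("demo_L32_N3_P32", expensive_matrix, directory=tmp)

print(len(calls))  # 1: the second call is served from disk
```

Keying the file name on the bandlimits matters: matrices computed for one (L, N, P) cannot be reused for another.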
[3]:
matrices = s2ball.construct.matrix.generate_matrices("wavelet", L, N, P)
Forward transform
NumPy CPU implementation
[4]:
w_numpy = ball_wavelet.forward_transform(f, L, N, P, matrices)
%timeit ball_wavelet.forward_transform(f, L, N, P, matrices)
465 ms ± 2.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
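JAX implementation
The JAX functions can also run on a GPU, but cell [1] reported the cpu backend, so the JAX timings in this notebook were measured on CPU.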
[5]:
w_jax = ball_wavelet.forward_transform_jax(f, L, N, P, matrices)
%timeit ball_wavelet.forward_transform_jax(f, L, N, P, matrices)
64.5 ms ± 798 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
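Outside IPython, %timeit-style numbers can be approximated with the standard-library timeit module. Because JAX dispatches work asynchronously, a fair timing should wait for results (via the arrays' block_until_ready method) and exclude the first call, which may include JIT compilation. time_call is a hypothetical helper, not part of s2ball, and whether the timings above were synchronised this way is not stated.

```python
import timeit
import numpy as np

def _wait(out):
    """Block until (possibly nested) JAX results are ready; NumPy arrays pass through."""
    if isinstance(out, (list, tuple)):
        for item in out:
            _wait(item)
    elif hasattr(out, "block_until_ready"):
        out.block_until_ready()

def time_call(fn, *args, repeat=7, number=1):
    """Return (mean, std) seconds per call over `repeat` runs of `number` calls each."""
    _wait(fn(*args))  # warm-up: excludes one-off costs such as JIT compilation

    def run():
        _wait(fn(*args))

    per_call = np.array(timeit.repeat(run, repeat=repeat, number=number)) / number
    return per_call.mean(), per_call.std()

a = np.random.default_rng(0).standard_normal((100, 100))
mean_s, std_s = time_call(np.linalg.svd, a)
print(f"{1e3 * mean_s:.3g} ms ± {1e3 * std_s:.3g} ms per loop")
```

In this notebook the same helper would be called as time_call(ball_wavelet.forward_transform_jax, f, L, N, P, matrices).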
Inverse transform
NumPy CPU implementation
[6]:
f_numpy = ball_wavelet.inverse_transform(w_numpy, L, N, P, matrices)
%timeit ball_wavelet.inverse_transform(w_numpy, L, N, P, matrices)
354 ms ± 2.19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAX implementation
[7]:
f_jax = ball_wavelet.inverse_transform_jax(w_jax, L, N, P, matrices)
%timeit ball_wavelet.inverse_transform_jax(w_jax, L, N, P, matrices)
60.9 ms ± 932 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
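The mean timings reported above translate into the following speedups of the JAX implementation over the NumPy one. They are single-machine numbers from this notebook's run and will vary with hardware and backend.

```python
# Mean timings (ms) copied from the %timeit outputs in this notebook.
timings_ms = {
    "forward": {"numpy": 465.0, "jax": 64.5},
    "inverse": {"numpy": 354.0, "jax": 60.9},
}
speedups = {name: t["numpy"] / t["jax"] for name, t in timings_ms.items()}
for name, s in speedups.items():
    print(f"{name}: {s:.1f}x")
# forward: 7.2x
# inverse: 5.8x
```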
Evaluate transform error
[8]:
print("NumPy vs JAX: Mean absolute difference = {}".format(np.nanmean(np.abs(f_numpy - f_jax))))
print("NumPy: Mean absolute error = {}".format(np.nanmean(np.abs(f_numpy - f))))
print("JAX: Mean absolute error = {}".format(np.nanmean(np.abs(f_jax - f))))
NumPy vs JAX: Mean absolute difference = 3.997665540708201e-15
NumPy: Mean absolute error = 3.062883484046315e-13
JAX: Mean absolute error = 3.062058563688262e-13
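A mean absolute error can hide a few large, localised deviations, so it may help to report the maximum absolute error and a relative L2 error alongside it. error_summary is a hypothetical helper; like the cell above, it ignores NaN entries. In this notebook it would be called as error_summary(f, f_jax).

```python
import numpy as np

def error_summary(f_ref, f_rec):
    """Mean, max and relative L2 errors between a reference and a reconstructed field, ignoring NaNs."""
    diff = np.abs(np.asarray(f_rec) - np.asarray(f_ref))
    mean_abs = np.nanmean(diff)
    max_abs = np.nanmax(diff)
    rel_l2 = np.sqrt(np.nansum(diff**2)) / np.sqrt(np.nansum(np.abs(f_ref) ** 2))
    return {"mean_abs": mean_abs, "max_abs": max_abs, "rel_l2": rel_l2}

f_ref = np.ones(4, dtype=np.complex128)
f_rec = f_ref.copy()
f_rec[0] += 1e-3  # a single localised error
summary = error_summary(f_ref, f_rec)
print(summary)  # mean_abs ~ 2.5e-4, max_abs ~ 1e-3, rel_l2 ~ 5e-4
```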