Wigner transform
Let's start by importing the packages we'll use in this notebook.
[10]:
# Enable 64-bit (double) precision in JAX.
import jax
jax.config.update("jax_enable_x64", True)
# Import math libraries.
import numpy as np
import jax.numpy as jnp
# Check which backend (e.g. cpu or gpu) JAX is running on.
print(jax.default_backend())
# Import the s2ball library.
import s2ball
from s2ball.transform import wigner
cpu
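Enabling jax_enable_x64 makes JAX create 64-bit arrays by default. Without it JAX uses 32-bit floats, and transform errors would sit closer to single-precision round-off (around 1e-7) than to the ~1e-16 reported below. A quick sanity check, independent of s2ball, is to inspect the default dtypes:

```python
import jax

# Set this at startup, before creating any arrays; arrays created earlier keep 32-bit dtypes.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.ones(3)
print(x.dtype)  # float64 (would be float32 without the flag)
z = jnp.ones(3, dtype=jnp.complex128)
print(z.dtype)  # complex128
```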
Generate a random complex bandlimited field
Here we generate random Wigner (harmonic) coefficients flmn, which we then convert into a bandlimited signal f on SO(3).
[11]:
L = 32 # Harmonic bandlimit of the problem.
N = 5 # Azimuthal (directional) bandlimit of the problem.
# Define a random seed.
rng = np.random.default_rng(193412341234)
# Use s2ball functions to generate a random signal.
flmn = s2ball.utils.generate_flmn(rng, L, N)
f = wigner.inverse(flmn, L, N)
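The coefficient array has shape (2N-1, L, 2L-1), but only entries with |m| <= l and |n| <= l correspond to genuine Wigner coefficients f^l_{mn}; the remaining entries are padding, which is presumably what "triangularly oversampled" refers to. The sketch below assumes the indexing [n + N - 1, l, m + L - 1] implied by that shape and uses a hypothetical random_flmn helper; it is not the implementation of s2ball.utils.generate_flmn.

```python
import numpy as np

def triangular_mask(L, N):
    """Boolean mask of entries in a (2N-1, L, 2L-1) array that hold a genuine
    coefficient f^l_{mn}, assuming indexing [n + N - 1, l, m + L - 1]."""
    n = np.arange(-(N - 1), N)[:, None, None]
    l = np.arange(L)[None, :, None]
    m = np.arange(-(L - 1), L)[None, None, :]
    return (np.abs(m) <= l) & (np.abs(n) <= l)

def random_flmn(rng, L, N):
    """Hypothetical stand-in for s2ball.utils.generate_flmn: random complex
    values on the valid triangle, zeros in the padding."""
    shape = (2 * N - 1, L, 2 * L - 1)
    values = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
    return np.where(triangular_mask(L, N), values, 0.0)

rng = np.random.default_rng(0)
flmn_demo = random_flmn(rng, L=4, N=2)
print(flmn_demo.shape)                          # (3, 4, 7)
print(np.count_nonzero(triangular_mask(4, 2)))  # 46
```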
Load/construct relevant Wigner kernels
Load the precomputed Wigner matrices used to evaluate the Wigner transform, constructing them if necessary. If these matrices have already been computed, generate_matrices will attempt to locate them inside the hidden .matrices directory. Note that you can specify a directory of your choice; .matrices is simply the default.
[12]:
matrices = s2ball.construct.matrix.generate_matrices("wigner", L, N)
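The caching behaviour described above (look for stored matrices in a directory, otherwise build them) follows a generic load-or-compute pattern. A minimal sketch, assuming a hypothetical load_or_compute helper, a made-up file-naming scheme and a stand-in kernel builder; s2ball's actual file format and names are not shown here:

```python
import os
import tempfile
import numpy as np

def load_or_compute(name, L, N, compute, save_dir=".matrices"):
    """Hypothetical helper: return a cached array if present, else build and save it."""
    os.makedirs(save_dir, exist_ok=True)
    path = os.path.join(save_dir, f"{name}_L{L}_N{N}.npy")
    if os.path.exists(path):
        return np.load(path), True   # cache hit
    arr = compute(L, N)
    np.save(path, arr)
    return arr, False                # cache miss: computed and saved

# Usage with a throwaway directory and a stand-in "kernel" builder.
calls = []
def fake_kernel(L, N):
    calls.append((L, N))
    return np.arange(L * N, dtype=np.float64).reshape(L, N)

with tempfile.TemporaryDirectory() as tmp:
    a, hit1 = load_or_compute("wigner", 4, 2, fake_kernel, save_dir=tmp)
    b, hit2 = load_or_compute("wigner", 4, 2, fake_kernel, save_dir=tmp)

print(hit1, hit2, len(calls))  # False True 1
```

The second call reads the file written by the first, so the expensive builder runs only once per (L, N).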
Forward transform
Shape: \((2N-1, L, 2L-1) \rightarrow (2N-1, L, 2L-1)\), i.e. signal samples to triangularly oversampled Wigner coefficients.
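For reference, in one common convention for bandlimited signals on SO(3), with Euler angles \(\rho = (\alpha, \beta, \gamma)\) and \(\mathrm{d}\rho = \sin\beta\, \mathrm{d}\alpha\, \mathrm{d}\beta\, \mathrm{d}\gamma\), the inverse and forward Wigner transforms read as follows; s2ball's normalisation may differ. The factorisation of \(D\) into exponentials in \(\alpha\) and \(\gamma\) and the small-\(d\) function in \(\beta\) suggests why precomputed matrices are a natural way to evaluate the transform, though this is an interpretation rather than a description of s2ball's internals.

\[
f(\alpha,\beta,\gamma) = \sum_{n=-N+1}^{N-1} \sum_{l=|n|}^{L-1} \sum_{m=-l}^{l} \frac{2l+1}{8\pi^2}\, f^{l}_{mn}\, D^{l*}_{mn}(\alpha,\beta,\gamma),
\qquad
f^{l}_{mn} = \int_{SO(3)} f(\rho)\, D^{l}_{mn}(\rho)\, \mathrm{d}\rho,
\]
\[
D^{l}_{mn}(\alpha,\beta,\gamma) = e^{-im\alpha}\, d^{l}_{mn}(\beta)\, e^{-in\gamma}.
\]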
NumPy CPU implementation
[13]:
flmn_numpy = wigner.forward_transform(f, matrices, L, N)
%timeit wigner.forward_transform(f, matrices, L, N)
957 µs ± 8.57 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
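%timeit is an IPython magic, so it is unavailable in plain Python scripts. A rough script equivalent using the standard timeit module, reporting the mean ± (population) standard deviation of the per-loop time over 7 runs, might look like this. The toy matrix product is only a placeholder workload, and unlike %timeit the loop count is fixed rather than chosen automatically.

```python
import timeit
import numpy as np

def summarize_timeit(stmt, number=1000, repeat=7):
    """Roughly mimic %timeit: mean and std of per-loop time over `repeat` runs."""
    runs = timeit.repeat(stmt, number=number, repeat=repeat)  # total seconds per run
    per_loop = np.array(runs) / number
    return per_loop.mean(), per_loop.std()

a = np.random.default_rng(0).standard_normal((64, 64))
mean, std = summarize_timeit(lambda: a @ a, number=100)
print(f"{mean * 1e6:.1f} µs ± {std * 1e6:.2f} µs per loop "
      f"(mean ± std. dev. of 7 runs, 100 loops each)")
```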
JAX implementation
[5]:
flmn_jax = wigner.forward_transform_jax(f, matrices, L, N)
%timeit wigner.forward_transform_jax(f, matrices, L, N)
10.7 ms ± 9.33 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
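The JAX timing above (slower than NumPy here) may include one-off costs: JAX dispatches work asynchronously, and code not wrapped in jax.jit runs op by op. Whether forward_transform_jax jits internally is not shown in this notebook. A common pattern for fairer JAX benchmarks is to jit, warm up once, and call block_until_ready(). The sketch below times a hypothetical stand-in contraction rather than s2ball itself.

```python
import timeit
import jax
import jax.numpy as jnp

def contract(kernel, x):
    # Stand-in for a matrix-based transform: a batched matrix-vector product.
    return jnp.einsum("bij,bj->bi", kernel, x)

contract_jit = jax.jit(contract)

key_k, key_x = jax.random.split(jax.random.PRNGKey(0))
kernel = jax.random.normal(key_k, (9, 64, 64))
x = jax.random.normal(key_x, (9, 64))

# The first call triggers tracing and XLA compilation; keep it out of the timing.
contract_jit(kernel, x).block_until_ready()

# block_until_ready() waits for the asynchronous computation to finish,
# so we time the work itself rather than only its dispatch.
t = min(timeit.repeat(lambda: contract_jit(kernel, x).block_until_ready(),
                      number=100, repeat=5)) / 100
print(f"{t * 1e6:.1f} µs per call")
```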
Evaluate transform error
[6]:
flmn_jax = np.array(flmn_jax)
print("Numpy: Forward mean absolute error = {}".format(np.nanmean(np.abs(flmn_numpy - flmn))))
print("JAX: Forward mean absolute error = {}".format(np.nanmean(np.abs(flmn_jax - flmn))))
Numpy: Forward mean absolute error = 4.051471247538891e-16
JAX: Forward mean absolute error = 4.009922861864243e-16
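The error cell uses np.nanmean over the full (2N-1, L, 2L-1) array. Assuming the layout [n + N - 1, l, m + L - 1] with genuine coefficients only where |m|, |n| <= l, padding entries that hold zeros in both arrays contribute zero error and pull the mean down, whereas NaN padding is skipped by np.nanmean. Whether s2ball pads with zeros or NaN is an assumption here; the hypothetical masked_mae helper below averages over the genuine entries only, so both conventions give the same number.

```python
import numpy as np

def triangular_mask(L, N):
    """Genuine-coefficient mask, assuming indexing [n + N - 1, l, m + L - 1]."""
    n = np.arange(-(N - 1), N)[:, None, None]
    l = np.arange(L)[None, :, None]
    m = np.arange(-(L - 1), L)[None, None, :]
    return (np.abs(m) <= l) & (np.abs(n) <= l)

def masked_mae(a, b, L, N):
    """Hypothetical helper: mean absolute error over genuine coefficients only."""
    return np.mean(np.abs(a - b)[triangular_mask(L, N)])

L_demo, N_demo = 3, 2
valid = triangular_mask(L_demo, N_demo)
exact = np.zeros(valid.shape, dtype=np.complex128)
approx = np.where(valid, exact + 1e-3, exact)     # error only on genuine entries

print(valid.sum(), valid.size)                    # 25 45
print(np.nanmean(np.abs(approx - exact)))         # diluted by the zero padding
print(masked_mae(approx, exact, L_demo, N_demo))  # approximately 1e-3
```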
Inverse transform
Shape: \((2N-1, L, 2L-1) \rightarrow (2N-1, L, 2L-1)\), i.e. triangularly oversampled Wigner coefficients to signal samples.
NumPy CPU implementation
[7]:
f_numpy = wigner.inverse_transform(flmn_numpy, matrices, L)
%timeit wigner.inverse_transform(flmn_numpy, matrices, L)
16.6 ms ± 45.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
JAX implementation
[8]:
f_jax = wigner.inverse_transform_jax(flmn_jax, matrices, L)
%timeit wigner.inverse_transform_jax(flmn_jax, matrices, L)
10.6 ms ± 26.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Evaluate transform error
[9]:
print("Numpy: Inverse mean absolute error = {}".format(np.nanmean(np.abs(f_numpy - f))))
print("JAX: Inverse mean absolute error = {}".format(np.nanmean(np.abs(f_jax - f))))
Numpy: Inverse mean absolute error = 1.9107959830973208e-14
JAX: Inverse mean absolute error = 1.9281520313991466e-14
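A mean absolute error scales with the amplitude of f, so a value like 1.9e-14 is easiest to interpret relative to the size of the signal. A scale-free complement is the relative L2 error. The helper below is a hypothetical addition, not part of s2ball, and skips NaN entries to mirror the use of np.nanmean above.

```python
import numpy as np

def relative_error(approx, exact):
    """Relative L2 error ||approx - exact|| / ||exact||, ignoring NaN entries."""
    approx = np.asarray(approx)
    exact = np.asarray(exact)
    keep = ~(np.isnan(approx) | np.isnan(exact))
    return np.linalg.norm((approx - exact)[keep]) / np.linalg.norm(exact[keep])

x = np.random.default_rng(1).standard_normal(100)
r1 = relative_error(1.001 * x, x)              # approximately 1e-3
r2 = relative_error(1e6 * 1.001 * x, 1e6 * x)  # same value: the metric is scale-free
print(r1, r2)
```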