s2ball is currently in open alpha; please provide feedback on GitHub.

Wigner-Laguerre transform

Let's start by importing some packages.

[1]:
# Enable 64-bit (double) precision in JAX.
import jax
jax.config.update("jax_enable_x64", True)

# Import math libraries.
import numpy as np
import jax.numpy as jnp

# Check which device backend JAX is running on.
print(jax.default_backend())

# Import the s2ball library.
import s2ball
from s2ball.transform import wigner_laguerre
cpu
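The transform errors reported later in this notebook are around 1e-14 to 1e-13, which is only reachable in double precision; JAX uses 32-bit floats by default unless `jax_enable_x64` is set. The minimal check below is independent of s2ball and only uses the standard JAX configuration flag.

```python
import jax
import jax.numpy as jnp

# Enable double precision at startup, before creating arrays or compiling functions.
jax.config.update("jax_enable_x64", True)

x = jnp.ones(3)
print(x.dtype)  # float64

# Machine epsilon bounds the round-trip error achievable in each precision.
print(jnp.finfo(jnp.float32).eps, jnp.finfo(jnp.float64).eps)
```

In single precision, machine epsilon is roughly 1e-7, so errors at the 1e-14 level would not be attainable.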

Generate a random complex bandlimited field

Here we generate random Wigner-Laguerre coefficients flmnp, which we then convert into a bandlimited signal f on \(\mathbb{H}^4=\mathbb{R}^+\times \text{SO}(3)\). In the following step we also compute the matrices used by the transforms; these are cached and passed to their associated functions at run time.

[2]:
L = 32        # Harmonic bandlimit of the problem.
P = 32        # Radial bandlimit of the problem.
N = 3         # Azimuthal (directional) bandlimit of the problem.

# Define a random seed.
rng = np.random.default_rng(193412341234)

# Use s2ball functions to generate a random signal.
flmnp = s2ball.utils.generate_flmnp(rng, L, N, P)

# generate_flmnp does not yet enforce the bandlimiting symmetries directly,
# so round-trip through the transforms to explicitly bandlimit flmnp.
f = wigner_laguerre.inverse(flmnp, L, N, P)
flmnp = wigner_laguerre.forward(f, L, N, P)
f = wigner_laguerre.inverse(flmnp, L, N, P)
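The inverse/forward/inverse round trip in the cell above is a workaround for generate_flmnp not enforcing the bandlimit itself. As a hedged illustration of what it achieves, the sketch below zeros invalid coefficients directly with a hypothetical helper, `wigner_laguerre_mask`, which is not part of s2ball. It assumes the \((P, 2N-1, L, 2L-1)\) layout stores \(n\) at index `N-1+n` and \(m\) at index `L-1+m`, and that a Wigner coefficient is valid only when \(|m| \le l\) and \(|n| \le l\). Toy sizes and names are used so the notebook's own `L`, `N`, `P` and `flmnp` are not overwritten.

```python
import numpy as np

def wigner_laguerre_mask(L: int, N: int, P: int) -> np.ndarray:
    """Hypothetical helper: boolean mask of valid (p, n, l, m) entries.

    Assumes n is stored at index N-1+n and m at index L-1+m.
    """
    n = np.arange(-(N - 1), N)[:, None, None]  # shape (2N-1, 1, 1)
    l = np.arange(L)[None, :, None]            # shape (1, L, 1)
    m = np.arange(-(L - 1), L)[None, None, :]  # shape (1, 1, 2L-1)
    valid = (np.abs(m) <= l) & (np.abs(n) <= l)  # broadcasts to (2N-1, L, 2L-1)
    return np.broadcast_to(valid, (P,) + valid.shape)  # every radial index p is valid

def bandlimit(coeffs: np.ndarray, L: int, N: int, P: int) -> np.ndarray:
    """Zero every coefficient that lies outside the valid index set."""
    return np.where(wigner_laguerre_mask(L, N, P), coeffs, 0.0)

# Toy usage with small, separately named sizes.
L_toy, N_toy, P_toy = 4, 2, 2
shape_toy = (P_toy, 2 * N_toy - 1, L_toy, 2 * L_toy - 1)
rng_toy = np.random.default_rng(0)
coeffs_toy = rng_toy.standard_normal(shape_toy) + 1j * rng_toy.standard_normal(shape_toy)
coeffs_toy = bandlimit(coeffs_toy, L_toy, N_toy, P_toy)
print(wigner_laguerre_mask(L_toy, N_toy, P_toy).sum())  # 92
```

If the transforms simply ignore the invalid entries, masking and the round trip should give the same coefficients up to floating-point error, but that is an assumption about s2ball's internals rather than something shown here.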

Load/construct the relevant precomputed matrices

Load the precomputed matrices used to evaluate the Wigner-Laguerre transform, that is, the Wigner transform on SO(3) and the Laguerre transform along the radial direction. If these matrices have already been computed, generate_matrices will first attempt to locate them inside the hidden .matrices directory; otherwise they are constructed. You can specify a directory of your choice; .matrices is simply the default.
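The load-or-construct behaviour described above is a disk-cache pattern. The sketch below is a generic illustration only: `load_or_compute` and `expensive_kernel` are hypothetical names, and s2ball's actual file names and storage format are not shown here. It writes to a temporary directory so the real .matrices cache is left untouched.

```python
import os
import tempfile
import numpy as np

def load_or_compute(name, compute, directory=".matrices"):
    """Hypothetical disk cache: reuse directory/name.npy if it exists, else compute and save it."""
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, f"{name}.npy")
    if os.path.exists(path):
        return np.load(path)
    array = compute()
    np.save(path, array)
    return array

calls = []

def expensive_kernel():
    calls.append(1)                       # record each real computation
    return np.arange(12.0).reshape(3, 4)  # stand-in for a precomputed matrix

with tempfile.TemporaryDirectory() as tmp:
    first = load_or_compute("toy_kernel", expensive_kernel, directory=tmp)
    second = load_or_compute("toy_kernel", expensive_kernel, directory=tmp)

print(len(calls))  # 1
```

The second call reads the saved file instead of recomputing, which is why the first run for a given (L, N, P) is slower than later runs.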

[3]:
matrices = s2ball.construct.matrix.generate_matrices("wigner_laguerre", L, N, P)

Forward transform

Shape: \((P, 2N-1, L, 2L-1) \rightarrow (P, 2N-1, L, 2L-1)\), mapping signal samples to triangularly oversampled Wigner-Laguerre coefficients.
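Because each Wigner-Laguerre basis function is a product of a radial Laguerre function and a Wigner function on SO(3), the transform separates into a radial stage along the leading \(P\) axis and an SO(3) stage over the remaining three axes. The toy NumPy sketch below illustrates only that structure: a random \((P, P)\) matrix stands in for the radial Laguerre kernel and an FFT over the last axis stands in for the Wigner transform, and neither is what s2ball actually computes. Since the two stages act on different axes, they can be applied in either order.

```python
import numpy as np

def radial_stage(x, radial_matrix):
    """Contract a (P, P) matrix with the leading radial axis of x."""
    return np.einsum("qp,p...->q...", radial_matrix, x)

def angular_stage(x):
    """Stand-in for the SO(3) stage: an FFT over the last axis of every radial slice."""
    return np.fft.fft(x, axis=-1)

# Toy sizes, named so the notebook's own L, N, P and f are not overwritten.
rng_sep = np.random.default_rng(0)
P_sep, N_sep, L_sep = 4, 2, 3
shape_sep = (P_sep, 2 * N_sep - 1, L_sep, 2 * L_sep - 1)
x_sep = rng_sep.standard_normal(shape_sep) + 1j * rng_sep.standard_normal(shape_sep)
radial_matrix = rng_sep.standard_normal((P_sep, P_sep))

radial_first = angular_stage(radial_stage(x_sep, radial_matrix))
angular_first = radial_stage(angular_stage(x_sep), radial_matrix)
print(radial_first.shape)                        # (4, 3, 3, 5)
print(np.allclose(radial_first, angular_first))  # True
```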

NumPy CPU implementation

[4]:
flmnp_numpy = wigner_laguerre.forward_transform(f, matrices, L, N)
%timeit wigner_laguerre.forward_transform(f, matrices, L, N)
56.4 ms ± 981 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

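JAX implementation

The JAX implementation runs on whichever backend JAX detects; the JAX timings in this notebook were recorded on the cpu backend reported by the first cell.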

[5]:
flmnp_jax = wigner_laguerre.forward_transform_jax(f, matrices, L, N)
%timeit wigner_laguerre.forward_transform_jax(f, matrices, L, N)
7.69 ms ± 138 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
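JAX dispatches work asynchronously, so timing a call without waiting for its result can partly measure dispatch rather than computation. A common pattern is to call `.block_until_ready()` on the output; here that would read `%timeit wigner_laguerre.forward_transform_jax(f, matrices, L, N).block_until_ready()`, assuming the function returns a single JAX array. The self-contained sketch below uses a hypothetical jitted `toy_transform` in place of the s2ball function.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def toy_transform(x):
    # Hypothetical stand-in for a jitted transform: a dense matrix product.
    return x @ x.T

def benchmark(fn, *args, repeats=10):
    """Mean wall-clock seconds per call, waiting for the result each time."""
    fn(*args).block_until_ready()  # warm-up call triggers compilation
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()  # include the computation, not just dispatch
    return (time.perf_counter() - start) / repeats

y_toy = jnp.ones((256, 256))
print(f"{benchmark(toy_transform, y_toy) * 1e3:.3f} ms per call")
```

The warm-up call matters as well: the first invocation of a jitted function includes compilation time, which the `flmnp_jax = ...` line above already absorbs before `%timeit` runs.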

Evaluate transform error

[6]:
print("Numpy: Forward mean absolute error = {}".format(np.nanmean(np.abs(flmnp_numpy - flmnp))))
print("JAX: Forward mean absolute error = {}".format(np.nanmean(np.abs(flmnp_jax - flmnp))))
Numpy: Forward mean absolute error = 2.826225319563e-14
JAX: Forward mean absolute error = 2.826701487097508e-14
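The cells above report np.nanmean of the absolute error, which skips NaN entries. A mean alone can hide a few large outliers, so a hedged sketch of a slightly fuller summary follows; `error_summary` is a hypothetical helper, not an s2ball function. Because it converts its inputs with np.asarray, it should also accept the JAX outputs, e.g. `error_summary(flmnp_jax, flmnp)`.

```python
import numpy as np

def error_summary(estimate, reference):
    """Mean/max absolute error and relative L2 error, ignoring NaN entries."""
    estimate = np.asarray(estimate)
    reference = np.asarray(reference)
    diff = np.abs(estimate - reference)
    finite = ~np.isnan(diff)          # drop entries where either input is NaN
    d = diff[finite]
    ref = np.abs(reference)[finite]
    return {
        "mean_abs": d.mean(),
        "max_abs": d.max(),
        "rel_l2": np.linalg.norm(d) / np.linalg.norm(ref),
    }

summary = error_summary(np.array([1.0, 2.0, np.nan]), np.array([1.0, 2.5, 3.0]))
print(summary["mean_abs"])  # 0.25
```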

Inverse transform

Shape: \((P, 2N-1, L, 2L-1) \rightarrow (P, 2N-1, L, 2L-1)\), mapping Wigner-Laguerre coefficients back to signal samples.

NumPy CPU implementation

[7]:
f_numpy = wigner_laguerre.inverse_transform(flmnp_numpy, matrices, L)
%timeit wigner_laguerre.inverse_transform(flmnp_numpy, matrices, L)
71.5 ms ± 118 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

JAX implementation

[8]:
f_jax = wigner_laguerre.inverse_transform_jax(flmnp_jax, matrices, L)
%timeit wigner_laguerre.inverse_transform_jax(flmnp_jax, matrices, L)
7.8 ms ± 79.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Evaluate transform error

[9]:
print("Numpy: Inverse mean absolute error = {}".format(np.nanmean(np.abs(f_numpy - f))))
print("JAX: Inverse mean absolute error = {}".format(np.nanmean(np.abs(f_jax - f))))
Numpy: Inverse mean absolute error = 1.4534176996105977e-13
JAX: Inverse mean absolute error = 1.4530986305103866e-13