s2ball is currently in open alpha; please provide feedback on GitHub.

Spherical harmonic transform#

Let's start by importing the packages we'll be using in this notebook.

[2]:
# Let's set the precision.
import jax
jax.config.update("jax_enable_x64", True)

# Import math libraries.
import numpy as np
import jax.numpy as jnp

# Check which devices we're running on.
print(jax.default_backend())

# Import the s2ball library.
import s2ball
from s2ball.transform import harmonic
cpu

Generate a random complex bandlimited field#

Here we generate random harmonic coefficients flm, which we then convert into a bandlimited signal f on \(\mathbb{S}^2\).

[3]:
L = 64    # Harmonic bandlimit of the problem.
spin = 2  # Spin of the field under consideration.

# Define a random seed.
rng = np.random.default_rng(193412341234)

# Use s2ball functions to generate a random signal.
flm = s2ball.utils.generate_flm(rng, L, spin)
f = harmonic.inverse(flm, L, spin=spin)
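To make the coefficient layout concrete, the following is a minimal NumPy sketch of how a random spin-weighted coefficient array of shape \((L, 2L-1)\) could be filled. The helper `random_flm_sketch` is hypothetical and is not s2ball's implementation; it assumes that order \(m\) of degree \(\ell\) is stored at column \(L-1+m\), and that entries with \(\ell < |\text{spin}|\) or \(|m| > \ell\) are left at zero.

```python
import numpy as np


def random_flm_sketch(rng, L, spin=0):
    """Hypothetical stand-in for a random coefficient generator (not s2ball code)."""
    flm = np.zeros((L, 2 * L - 1), dtype=np.complex128)
    for el in range(abs(spin), L):
        n = 2 * el + 1
        # Assumed layout: orders m = -el..el occupy columns L-1-el..L-1+el.
        flm[el, L - 1 - el : L + el] = (
            rng.standard_normal(n) + 1j * rng.standard_normal(n)
        )
    return flm


rng = np.random.default_rng(0)
flm_demo = random_flm_sketch(rng, L=8, spin=2)
print(flm_demo.shape)              # (8, 15)
print(np.count_nonzero(flm_demo))  # 8**2 - 2**2 = 60 populated entries
```

Under these assumptions only the triangular region of the array carries information, which is why the storage is described as oversampled.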

Load/construct relevant associated Legendre matrices#

Load or construct the associated Legendre matrices used to evaluate the spherical harmonic transform. If these matrices have already been computed, the function will attempt to locate them inside the hidden .matrices directory. You can specify a directory of your choice; .matrices is simply the default.

[4]:
matrices = s2ball.construct.matrix.generate_matrices("spherical_harmonic", L, spin=spin)
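The load-or-construct behaviour described above follows a common caching pattern, sketched here in plain NumPy. The helper `load_or_build`, its file naming scheme, and the `build_fn` argument are all hypothetical and are used only for illustration; they do not reflect how s2ball names or stores its files.

```python
import os
import tempfile

import numpy as np


def load_or_build(name, L, build_fn, directory=".matrices"):
    """Hypothetical cache helper: load a saved matrix, or build and save it."""
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, f"{name}_L{L}.npy")
    if os.path.exists(path):
        return np.load(path), True   # Found on disk.
    matrix = build_fn(L)             # Potentially expensive construction.
    np.save(path, matrix)
    return matrix, False


with tempfile.TemporaryDirectory() as tmp:
    build = lambda L: np.eye(L)  # Placeholder for a costly matrix construction.
    first, first_cached = load_or_build("demo", 4, build, directory=tmp)
    second, second_cached = load_or_build("demo", 4, build, directory=tmp)

print(first_cached, second_cached)  # False True
```

The first call pays the construction cost and writes to disk; later calls with the same name and bandlimit only read the file.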

Forward transform#

Shape: \((L, 2L-1) \rightarrow (L, 2L-1)\), producing triangularly oversampled spherical harmonic coefficients.

NumPy CPU implementation#

[5]:
flm_numpy = harmonic.forward_transform(f, matrices)
%timeit harmonic.forward_transform(f, matrices)
747 µs ± 8.12 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

JAX GPU implementation#

[6]:
flm_jax = harmonic.forward_transform_jax(f, matrices)
%timeit harmonic.forward_transform_jax(f, matrices)
923 µs ± 20.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
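The `%timeit` magic is only available in IPython and Jupyter. Outside a notebook, a comparable measurement can be sketched with the standard-library `timeit` module. The helper `time_call` below is hypothetical, and a plain matrix product stands in for the transform so that the example runs without s2ball.

```python
import timeit

import numpy as np


def time_call(fn, repeat=7, number=100):
    """Return the mean and standard deviation of the per-call time in seconds."""
    totals = timeit.repeat(fn, repeat=repeat, number=number)
    per_call = np.array(totals) / number
    return per_call.mean(), per_call.std()


# Stand-in workload: a matrix product of roughly the size used in this notebook.
a = np.random.default_rng(0).standard_normal((64, 127))
b = np.random.default_rng(1).standard_normal((127, 64))

# For a JAX function, call .block_until_ready() on the result inside the lambda,
# since JAX dispatches work asynchronously and may return before it finishes.
mean_s, std_s = time_call(lambda: a @ b)
print(f"{mean_s * 1e6:.1f} µs ± {std_s * 1e6:.1f} µs per call")
```

Without the synchronisation mentioned in the comment, timings of JAX code can reflect dispatch overhead rather than compute time, so the NumPy and JAX figures may not be directly comparable.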

Evaluate transform error#

[7]:
print("Numpy: Forward mean absolute error = {}".format(np.nanmean(np.abs(flm_numpy - flm))))
print("JAX: Forward mean absolute error = {}".format(np.nanmean(np.abs(flm_jax - flm))))
Numpy: Forward mean absolute error = 1.6464202973778054e-15
JAX: Forward mean absolute error = 1.6285429530404877e-15
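A mean absolute error can hide a small number of large deviations, so it can be useful to report the maximum error alongside it. The sketch below uses a hypothetical helper `error_summary` with synthetic data; the tolerance of `1e-12` is an illustrative choice, not a value taken from s2ball.

```python
import numpy as np


def error_summary(estimate, reference, atol=1e-12):
    """Summarise the elementwise absolute error, ignoring NaN entries."""
    diff = np.abs(estimate - reference)
    max_abs = float(np.nanmax(diff))
    return {
        "mean_abs": float(np.nanmean(diff)),
        "max_abs": max_abs,
        "within_tol": max_abs <= atol,
    }


# Synthetic data: one NaN entry, which nanmean and nanmax skip.
reference = np.array([1.0 + 1.0j, 2.0, np.nan, 4.0])
estimate = reference + np.array([1e-15, -2e-15, 0.0, 3e-15])

summary = error_summary(estimate, reference)
print(summary)
```

Using `np.nanmean` rather than `np.mean` keeps a single undefined entry from turning the whole summary into NaN.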

Inverse transform#

Shape: \((L, 2L-1) \rightarrow (L, 2L-1)\)

NumPy CPU implementation#

[8]:
f_numpy = harmonic.inverse_transform(flm_numpy, matrices)
%timeit harmonic.inverse_transform(flm_numpy, matrices)
2.19 ms ± 60.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

JAX GPU implementation#

[9]:
f_jax = harmonic.inverse_transform_jax(flm_jax, matrices)
%timeit harmonic.inverse_transform_jax(flm_jax, matrices)
662 µs ± 8.19 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Evaluate transform error#

[10]:
print("Numpy: Inverse mean absolute error = {}".format(np.nanmean(np.abs(f_numpy - f))))
print("JAX: Inverse mean absolute error = {}".format(np.nanmean(np.abs(f_jax - f))))
Numpy: Inverse mean absolute error = 4.9853753961577024e-14
JAX: Inverse mean absolute error = 4.86174558520892e-14