s2fft is currently in an open beta, please provide feedback on GitHub

# JAX HEALPix frontend


```python
import sys
IN_COLAB = 'google.colab' in sys.modules

# Install s2fft if running on Google Colab.
if IN_COLAB:
    !pip install s2fft &> /dev/null
```

This short tutorial demonstrates how to use the custom JAX frontend that S2FFT provides for the [HEALPix](https://healpix.jpl.nasa.gov) C++ library. Calling the C++ library from JAX sidesteps the long JIT compile times that HEALPix transforms otherwise incur when running on CPU.
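S2FFT's own wrapper code is not shown in this tutorial. As a rough, hypothetical illustration of the general idea, the sketch below uses `jax.pure_callback` to call a host-side NumPy function from inside a jitted JAX function. `host_scale` is a made-up stand-in for a C++ routine, and this is not necessarily how S2FFT implements its frontend. The callback body is never traced by JAX, so only a thin call site is compiled, which is one way such a wrapper can avoid long compile times.

```python
import numpy as np
import jax
import jax.numpy as jnp

def host_scale(x):
    # Hypothetical stand-in for an external (e.g. C++) routine: it runs eagerly
    # on the host with NumPy arrays, outside of JAX tracing and compilation.
    return np.asarray(2.0 * np.asarray(x), dtype=np.float32)

@jax.jit
def scale_via_callback(x):
    # Declare the result's shape/dtype so JAX can trace without running the callback.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_scale, out_spec, x)

x = jnp.arange(4.0, dtype=jnp.float32)
y = scale_via_callback(x)
print(y)  # [0. 2. 4. 6.]
```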

As with the other introductions, let’s import some packages and define an arbitrary bandlimited signal to work with.

```python
import jax
jax.config.update("jax_enable_x64", True)

import numpy as np
import s2fft

L = 128
nside = 64
method = "jax_healpy"
sampling = "healpix"
rng = np.random.default_rng(23457801234570)
flm = s2fft.utils.signal_generator.generate_flm(rng, L)
f = s2fft.inverse(flm, L, nside=nside, sampling=sampling, method=method)
```
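To make the arrays in the cell above more concrete, the sketch below builds a random bandlimited signal by hand. It assumes the 2D harmonic layout S2FFT uses: a complex array of shape (L, 2L-1) indexed as `flm[el, L - 1 + m]`, with zeros wherever |m| > el. `random_bandlimited_flm` is a hypothetical helper, not part of S2FFT and not how `generate_flm` is implemented. The sketch also computes the size of a HEALPix map, which is 12·nside² pixels.

```python
import numpy as np

def random_bandlimited_flm(rng, L):
    """Hypothetical helper (not part of S2FFT): random complex coefficients
    stored as flm[el, L - 1 + m], with entries where |m| > el left at zero."""
    el = np.arange(L)[:, None]            # degree index, one row per el
    m = np.arange(-(L - 1), L)[None, :]   # order index, one column per m
    mask = np.abs(m) <= el                # valid (el, m) pairs
    flm = np.zeros((L, 2 * L - 1), dtype=np.complex128)
    n_valid = int(mask.sum())
    flm[mask] = rng.normal(size=n_valid) + 1j * rng.normal(size=n_valid)
    return flm

L, nside = 128, 64
flm_demo = random_bandlimited_flm(np.random.default_rng(0), L)

npix = 12 * nside**2                           # HEALPix pixel count for this nside
n_coeffs = int((np.abs(flm_demo) > 0).sum())   # L**2 coefficients for bandlimit L
print(flm_demo.shape, npix, n_coeffs)          # (128, 255) 49152 16384
```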

## Calling the forward HEALPix C++ function from JAX

```python
flm = s2fft.forward(f, L, nside=nside, sampling=sampling, method=method)
```

## Calling the inverse HEALPix C++ function from JAX

```python
f_recov = s2fft.inverse(flm, L, nside=nside, sampling=sampling, method=method)
```

## Computing the roundtrip error

Let’s check the associated error. Because HEALPix is not an exact sampling of the sphere, the roundtrip error is nonzero, but it should be small, around 1e-5 or below. Increasing the number of iterations in the forward transform reduces this error slightly, at the cost of compute that grows linearly with the number of iterations.

```python
print(f"Mean absolute error = {np.nanmean(np.abs(f_recov - f))}")
```

Mean absolute error = 2.5921182352491347e-06
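The remark about iterations can be illustrated with a toy linear-algebra model. The sketch below does not call HEALPix. It assumes a made-up synthesis matrix `Y` and a cheap approximate analysis operator `A`, loosely standing in for an inverse transform and a quadrature-based forward transform on a sampling without exact quadrature. It then applies residual-based iterative refinement, which is roughly the idea behind asking HEALPix for extra iterations. Each extra iteration costs one synthesis and one analysis, so compute grows linearly while the error shrinks.

```python
import numpy as np

rng = np.random.default_rng(1)
npix, n_coeffs = 200, 50

# Toy "synthesis" operator Y (coefficients -> pixels): orthonormal columns plus a
# small perturbation, so the cheap analysis A = Q.T is only an approximate inverse.
Q, _ = np.linalg.qr(rng.normal(size=(npix, n_coeffs)))
Y = Q + 0.05 * rng.normal(size=(npix, n_coeffs)) / np.sqrt(npix)
A = Q.T

a_true = rng.normal(size=n_coeffs)
f_map = Y @ a_true

def analysis(f_map, iters):
    a = A @ f_map                       # one-shot approximate analysis
    for _ in range(iters):
        a = a + A @ (f_map - Y @ a)     # residual correction: 1 synthesis + 1 analysis
    return a

errors = [np.mean(np.abs(analysis(f_map, k) - a_true)) for k in range(4)]
for k, err in enumerate(errors):
    print(f"iters={k}: mean absolute error = {err:.2e}")
```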

## Differentiating through HEALPix C++ functions

So far, all this does is provide an interface between JAX and HEALPix; the real novelty comes when we differentiate through the C++ library.
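The key ingredient is a custom reverse-mode rule. The hedged sketch below is not S2FFT's actual code. It wraps a hypothetical host-side linear operator `M` with `jax.pure_callback` and attaches a `jax.custom_vjp` whose backward pass applies the adjoint `M.T` through a second callback. Spherical harmonic transforms are linear in their inputs, so their vector-Jacobian product is the adjoint transform, and in principle the backward pass needs only another call into the external library.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical external linear operator M and its adjoint M.T, evaluated on the host.
M = np.array([[1.0, 2.0, 0.0],
              [0.0, 1.0, 3.0]], dtype=np.float32)

def host_apply(x):
    return (M @ np.asarray(x)).astype(np.float32)

def host_apply_adjoint(y):
    return (M.T @ np.asarray(y)).astype(np.float32)

@jax.custom_vjp
def external_op(x):
    out_spec = jax.ShapeDtypeStruct((M.shape[0],), jnp.float32)
    return jax.pure_callback(host_apply, out_spec, x)

def external_op_fwd(x):
    # No residuals are needed because the operator is linear.
    return external_op(x), None

def external_op_bwd(_, g):
    # The VJP of a linear map is its adjoint applied to the cotangent g,
    # which is just another call into the external code.
    in_spec = jax.ShapeDtypeStruct((M.shape[1],), jnp.float32)
    return (jax.pure_callback(host_apply_adjoint, in_spec, g),)

external_op.defvjp(external_op_fwd, external_op_bwd)

def loss(x):
    return jnp.sum(external_op(x) ** 2)

x = jnp.array([1.0, 0.5, -1.0], dtype=jnp.float32)
g_x = jax.grad(loss)(x)
print(g_x)
```

For this `x`, the result should match the analytic gradient 2·Mᵀ·M·x = [4, 3, -15].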

```python
# Define an arbitrary JAX function of the harmonic coefficients
def differentiable_test(flm) -> jax.Array:
    f = s2fft.inverse(flm, L, nside=nside, sampling=sampling, method=method)
    return jax.numpy.nanmean(jax.numpy.abs(f)**2)

# Create the JAX reverse-mode gradient function
gradient_func = jax.grad(differentiable_test)

# Compute the gradient automatically
gradient = gradient_func(flm)
```
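Because `flm` is complex, it helps to know how JAX defines `jax.grad` for a real-valued function of complex inputs. The small example below is not specific to S2FFT and uses |z|² as a test function. For z = x + iy, JAX returns ∂f/∂x - i·∂f/∂y, which is the complex conjugate of the steepest-ascent direction. A gradient-descent update on `flm` would therefore typically step along the conjugate of `gradient`.

```python
import jax
import jax.numpy as jnp

def squared_modulus(z):
    # Real-valued function of a complex input: |z|^2 = x^2 + y^2.
    return jnp.real(z) ** 2 + jnp.imag(z) ** 2

z = jnp.array(3.0 + 4.0j)
g = jax.grad(squared_modulus)(z)   # 2x - 2iy = 6 - 8j, i.e. 2 * conj(z)
ascent_direction = jnp.conj(g)     # 6 + 8j: the direction that increases |z|^2
print(g, ascent_direction)
```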

## Validating these gradients

This is all well and good, but how do we know these gradients are correct? Fortunately, JAX provides a simple function, `check_grads`, that compares them against numerical finite differences.

```python
from jax.test_util import check_grads

# Compare reverse-mode gradients against numerical finite differences
check_grads(differentiable_test, (flm,), order=1, modes=("rev",))
```
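The idea behind `check_grads` can be shown with a hand-written central-difference check. The NumPy-only sketch below uses a toy function with a known gradient. It illustrates the principle and is not JAX's implementation, which roughly checks random directional derivatives with its own tolerances rather than every coordinate. `central_difference_grad` is a hypothetical helper.

```python
import numpy as np

def f(x):
    # Toy test function with a known gradient: sum(sin(x) * x^2).
    return np.sum(np.sin(x) * x**2)

def analytic_grad(x):
    return np.cos(x) * x**2 + 2 * x * np.sin(x)

def central_difference_grad(func, x, eps=1e-6):
    """Approximate each partial derivative as (f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps)."""
    g = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        g[i] = (func(x + step) - func(x - step)) / (2 * eps)
    return g

x = np.random.default_rng(0).normal(size=5)
fd = central_difference_grad(f, x)
ad = analytic_grad(x)
max_err = np.max(np.abs(fd - ad))
print(f"max |finite difference - analytic| = {max_err:.1e}")
```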