# JAX SSHT frontend
[1]:
import sys
IN_COLAB = 'google.colab' in sys.modules
# Install s2fft and data if running on Google Colab.
if IN_COLAB:
    !pip install s2fft &> /dev/null
This short tutorial demonstrates how to use the custom JAX frontend that S2FFT provides for the SSHT C library (astro-informatics/ssht).
As with the other introductions, let’s import some packages and define an arbitrary bandlimited signal to work with.
[2]:
import jax
jax.config.update("jax_enable_x64", True)
import numpy as np
import s2fft
L = 128
method = "jax_ssht"
rng = np.random.default_rng(23457801234570)
flm = s2fft.utils.signal_generator.generate_flm(rng, L)
f = s2fft.inverse(flm, L, method=method)
Calling forward SSHT C function from JAX.
[3]:
flm = s2fft.forward(f, L, method=method)
Calling inverse SSHT C function from JAX.
[4]:
f_recov = s2fft.inverse(flm, L, method=method)
Computing the roundtrip error
Let’s check the associated error, which should be close to machine precision for the sampling scheme used.
[5]:
print(f"Mean absolute error = {np.nanmean(np.abs(f_recov - f))}")
Mean absolute error = 7.784372519411174e-13
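The mean absolute error above scales with the amplitude of the signal, so a relative measure can be easier to compare against machine precision. The following is a minimal sketch of such a check. The helper `roundtrip_errors` is hypothetical (it is not part of S2FFT), and a 1D NumPy FFT pair stands in for `s2fft.forward` / `s2fft.inverse` so that the block runs without SSHT installed.

```python
import numpy as np


def roundtrip_errors(reference, recovered):
    """Return (mean absolute error, max relative error) between two arrays."""
    abs_err = np.abs(recovered - reference)
    mean_abs = float(np.nanmean(abs_err))
    # Normalise by the largest reference magnitude so the result is scale free
    max_rel = float(np.nanmax(abs_err) / np.nanmax(np.abs(reference)))
    return mean_abs, max_rel


# Stand-in roundtrip (assumption): a 1D FFT pair replaces the spherical transforms
rng = np.random.default_rng(0)
signal = rng.standard_normal(256) + 1j * rng.standard_normal(256)
recovered = np.fft.ifft(np.fft.fft(signal))

mean_abs, max_rel = roundtrip_errors(signal, recovered)
print(f"Mean absolute error = {mean_abs:.2e}, max relative error = {max_rel:.2e}")
```

In the notebook itself, the same helper could be called as `roundtrip_errors(f, f_recov)`.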
Differentiating through SSHT C functions.
So far, all this does is provide an interface between JAX and SSHT; the real novelty comes when we differentiate through the C library.
[6]:
# Define an arbitrary JAX function (returns a real scalar, not an int)
def differentiable_test(flm) -> jax.Array:
    f = s2fft.inverse(flm, L, method=method)
    return jax.numpy.nanmean(jax.numpy.abs(f)**2)
# Create the JAX reverse mode gradient function
gradient_func = jax.grad(differentiable_test)
# Compute the gradient automatically
gradient = gradient_func(flm)
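Because `flm` is complex while the loss is real, it helps to know how `jax.grad` reports gradients for complex inputs: it returns the conjugate of the direction of steepest ascent. The sketch below illustrates this on a toy function; `toy_loss` is a hypothetical stand-in that mimics the mean of |f|² without calling S2FFT, and the step size of 0.1 is arbitrary.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def toy_loss(z):
    # Real-valued loss of a complex input, mimicking mean(|f|^2)
    return jnp.mean(jnp.abs(z) ** 2)


z = jnp.array([3.0 + 4.0j, 1.0 - 2.0j])
g = jax.grad(toy_loss)(z)

# For mean(|z|^2), JAX's convention gives 2 * conj(z) / N
expected = 2.0 * jnp.conj(z) / z.size

# A gradient-descent step therefore uses the conjugate of the returned gradient
step = z - 0.1 * jnp.conj(g)
print(g, toy_loss(z), toy_loss(step))
```

The same convention would apply to `gradient` above if it were used in an optimisation loop over `flm`.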
Validating these gradients
This is all well and good, but how do we know these gradients are correct? Thankfully, JAX provides a simple function to check this…
[7]:
from jax.test_util import check_grads
check_grads(differentiable_test, (flm,), order=1, modes=("rev",))
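For intuition, a check of this kind compares the automatic gradient against a numerical finite-difference estimate. The sketch below does this by hand along one random direction. It is an illustration of the idea rather than a description of how `check_grads` is implemented, and `toy_loss`, the probe direction `v`, and the step `eps` are all assumptions chosen for the example.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)


def toy_loss(x):
    # Hypothetical real-valued test function (not part of s2fft)
    return jnp.mean(jnp.sin(x) ** 2)


rng = np.random.default_rng(0)
x = jnp.asarray(rng.standard_normal(16))
v = jnp.asarray(rng.standard_normal(16))  # random probe direction

eps = 1e-6
# Central finite difference of the loss along v
numerical = (toy_loss(x + eps * v) - toy_loss(x - eps * v)) / (2 * eps)
# The same directional derivative from the reverse-mode gradient
automatic = jnp.vdot(jax.grad(toy_loss)(x), v)

abs_err = float(abs(numerical - automatic))
print(f"absolute difference = {abs_err:.2e}")
```

Double precision matters here: with 32-bit floats, the rounding error in the finite difference would be far larger.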