# JAX SSHT frontend

```python
import sys
IN_COLAB = 'google.colab' in sys.modules

# Install s2fft if running on Google Colab.
if IN_COLAB:
    !pip install s2fft &> /dev/null
```

This short tutorial demonstrates how to use the custom JAX frontend that S2FFT provides for the SSHT C library (astro-informatics/ssht).

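As in the other tutorials, let’s import some packages and define an arbitrary bandlimited signal to work with: we draw random harmonic coefficients `flm` up to bandlimit `L` and map them to a signal `f` on the sphere with an inverse transform.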

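```python
import jax
# SSHT works in double precision, so enable 64-bit types in JAX.
jax.config.update("jax_enable_x64", True)

import numpy as np
import s2fft

L = 128                # harmonic bandlimit
method = "jax_ssht"    # JAX frontend backed by the SSHT C library
rng = np.random.default_rng(23457801234570)
flm = s2fft.utils.signal_generator.generate_flm(rng, L)  # random bandlimited harmonic coefficients
f = s2fft.inverse(flm, L, method=method)                  # corresponding signal on the sphere
```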
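
For reference, `flm` is a 2D array of harmonic coefficients. The sketch below uses a hypothetical helper, `random_bandlimited_flm` (not S2FFT's `generate_flm`), to build a random bandlimited signal in what we understand to be S2FFT's layout: shape `(L, 2L - 1)`, with degree `el` as the row and order `m` stored at column `L - 1 + m`. Only entries with `|m| <= el` are used, so a signal with bandlimit `L` has `L**2` coefficients. The toy names (`L_toy`, `rng_toy`, `flm_toy`) are chosen so they do not overwrite the notebook's variables.

```python
import numpy as np

def random_bandlimited_flm(rng, L):
    """Hypothetical helper: random complex coefficients with bandlimit L.

    Assumes a (L, 2L - 1) layout where flm[el, L - 1 + m] holds degree el,
    order m. Entries with |m| > el are unused and stay zero.
    """
    flm = np.zeros((L, 2 * L - 1), dtype=np.complex128)
    for el in range(L):
        # Orders m = -el, ..., el live in columns L-1-el, ..., L-1+el.
        cols = np.arange(L - 1 - el, L + el)
        flm[el, cols] = rng.normal(size=cols.size) + 1j * rng.normal(size=cols.size)
    return flm

L_toy = 8
rng_toy = np.random.default_rng(0)
flm_toy = random_bandlimited_flm(rng_toy, L_toy)

print(flm_toy.shape)              # (L_toy, 2 * L_toy - 1)
print(np.count_nonzero(flm_toy))  # L_toy**2 populated coefficients
```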

## Calling the forward SSHT C function from JAX

```python
flm = s2fft.forward(f, L, method=method)
```

## Calling the inverse SSHT C function from JAX

```python
f_recov = s2fft.inverse(flm, L, method=method)
```

## Computing the round-trip error

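Let’s check the round-trip error. Because the default sampling scheme (McEwen & Wiaux, `"mw"`) admits an exact sampling theorem, forward followed by inverse should recover `f` to within floating-point rounding, i.e. close to machine precision.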

```python
print(f"Mean absolute error = {np.nanmean(np.abs(f_recov - f))}")
```

```
Mean absolute error = 7.784372519411174e-13
```
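
Why should the error be this small? A loose 1D analogue on the circle (not the spherical transform itself) shows the idea: a signal bandlimited to orders `|m| < L_circ` is fully determined by `2 * L_circ - 1` equispaced samples, so its Fourier coefficients can be recovered exactly from those samples and the round trip is limited only by floating-point rounding. (As we understand it, the MW grid likewise has `2L - 1` samples in longitude.) All names below are hypothetical and chosen so they do not overwrite the notebook's `L`, `f` or `flm`.

```python
import numpy as np

# Hypothetical 1D analogue: a bandlimited signal on the circle.
L_circ = 16
n_samples = 2 * L_circ - 1
rng_circ = np.random.default_rng(1)

# Random complex Fourier coefficients c_m for m = -(L_circ - 1), ..., L_circ - 1.
m = np.arange(-(L_circ - 1), L_circ)
coeffs = rng_circ.normal(size=m.size) + 1j * rng_circ.normal(size=m.size)

# Synthesize the signal at phi_k = 2*pi*k / n_samples.
phi = 2 * np.pi * np.arange(n_samples) / n_samples
signal = np.exp(1j * np.outer(phi, m)) @ coeffs

# "Forward transform": recover c_m from the samples with an FFT. With
# n_samples = 2 * L_circ - 1, each order m lands in its own bin (m mod n_samples).
coeffs_recov = np.fft.fft(signal)[m % n_samples] / n_samples

# "Inverse transform": resynthesize the signal from the recovered coefficients.
signal_recov = np.exp(1j * np.outer(phi, m)) @ coeffs_recov

print(f"Max coefficient error = {np.max(np.abs(coeffs_recov - coeffs))}")
print(f"Mean absolute error   = {np.mean(np.abs(signal_recov - signal))}")
```

With fewer than `2 * L_circ - 1` samples, distinct orders would share an FFT bin (aliasing) and exact recovery would no longer be possible.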

## Differentiating through the SSHT C functions

So far this only provides an interface between JAX and SSHT; the real novelty comes when we differentiate through the C library.

```python
# Define an arbitrary real-valued JAX function of the harmonic coefficients
def differentiable_test(flm):
    f = s2fft.inverse(flm, L, method=method)
    # Mean power of the signal on the sphere (a real scalar)
    return jax.numpy.nanmean(jax.numpy.abs(f)**2)

# Create the JAX reverse-mode gradient function
gradient_func = jax.grad(differentiable_test)

# Compute the gradient automatically
gradient = gradient_func(flm)
```
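
Note that `flm` is complex while `differentiable_test` returns a real scalar. For such functions, JAX's convention (described in its autodiff cookbook) is that `jax.grad` returns the complex conjugate of ∂f/∂x + i ∂f/∂y, where z = x + iy, so a steepest-descent update would step along `-jnp.conj(gradient)`. The toy below (hypothetical names, not S2FFT code) makes the convention concrete for f(z) = Σ|z|², whose ascent direction is 2z.

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def toy_power(z):
    # Real-valued function of a complex input: sum of |z|^2.
    return jnp.sum(jnp.real(z) ** 2 + jnp.imag(z) ** 2)

z_toy = jnp.array([3.0 + 4.0j, -1.0 + 2.0j])
g_toy = jax.grad(toy_power)(z_toy)

# JAX returns the conjugate of the ascent direction 2 * z, i.e. 2 * conj(z).
print(g_toy)
print(2 * jnp.conj(z_toy))
```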

## Validating these gradients

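This is all well and good, but how do we know these gradients are correct? JAX provides `check_grads` (in `jax.test_util`), which compares the automatically computed derivatives against finite-difference estimates and raises an error if they disagree.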

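```python
from jax.test_util import check_grads

# Check first-order reverse-mode derivatives against finite differences;
# raises an AssertionError if they disagree beyond tolerance.
check_grads(differentiable_test, (flm,), order=1, modes=("rev",))
```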
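
To see what such a check amounts to, the sketch below compares a reverse-mode directional derivative with a central finite-difference estimate along one random direction, on a hypothetical smooth toy objective (not S2FFT code). As we understand it, `check_grads` in reverse mode performs a comparable test using random directions and its own step size and tolerances.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.test_util import check_grads

jax.config.update("jax_enable_x64", True)

def toy_objective(x):
    # Hypothetical smooth real-valued objective.
    return jnp.mean(jnp.sin(x) * x ** 2)

x_fd = jnp.linspace(0.1, 1.0, 5)
rng_fd = np.random.default_rng(3)
probe = jnp.asarray(rng_fd.standard_normal(x_fd.shape))  # random direction

# Directional derivative from reverse-mode AD: <grad f(x), probe>.
ad_dir = jnp.vdot(jax.grad(toy_objective)(x_fd), probe)

# Central finite-difference estimate of the same directional derivative.
eps = 1e-6
fd_dir = (toy_objective(x_fd + eps * probe) - toy_objective(x_fd - eps * probe)) / (2 * eps)

print(f"AD: {float(ad_dir):.10f}  FD: {float(fd_dir):.10f}")

# The library check; raises an AssertionError on mismatch.
check_grads(toy_objective, (x_fd,), order=1, modes=("rev",))
```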