.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/JAX_SSHT_frontend.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_JAX_SSHT_frontend.py:


JAX SSHT frontend
=================

This short tutorial demonstrates how to use the custom ``JAX`` frontend that
``S2FFT`` provides for the `SSHT `_ C library.

.. image:: https://colab.research.google.com/assets/colab-badge.svg
    :align: center
    :alt: Open in Google Colab
    :target: https://colab.research.google.com/github/astro-informatics/s2fft/tree/gh-pages/_colab_notebooks/JAX_SSHT_frontend.ipynb

If you are working on this notebook in Google Colab, you will need to install
``s2fft`` and ``pyssht``. You can do this by adding a cell to the top of the
notebook with the following content:

.. code-block:: bash

    !pip install s2fft pyssht &> /dev/null

and then running that cell.

.. GENERATED FROM PYTHON SOURCE LINES 23-24

As with the other introductions, let's import some packages and define an
arbitrary bandlimited signal to work with.

.. GENERATED FROM PYTHON SOURCE LINES 24-38

.. code-block:: Python

    import jax
    import numpy as np

    import s2fft

    # Use double precision so that roundtrip errors reach machine precision
    jax.config.update("jax_enable_x64", True)

    L = 128
    method = "jax_ssht"
    rng = np.random.default_rng(23457801234570)
    flm = s2fft.utils.signal_generator.generate_flm(rng, L)
    f = s2fft.inverse(flm, L, method=method)


.. GENERATED FROM PYTHON SOURCE LINES 39-41

Calling the forward SSHT C function from JAX
--------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 41-44

.. code-block:: Python

    flm = s2fft.forward(f, L, method=method)


.. GENERATED FROM PYTHON SOURCE LINES 45-47

Calling the inverse SSHT C function from JAX
--------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 47-50

.. code-block:: Python

    f_recov = s2fft.inverse(flm, L, method=method)


.. GENERATED FROM PYTHON SOURCE LINES 51-55

Computing the roundtrip error
-----------------------------
Let's check the associated error, which should be close to machine precision
for the sampling scheme used.

.. GENERATED FROM PYTHON SOURCE LINES 55-58

.. code-block:: Python

    print(f"Mean absolute error = {np.nanmean(np.abs(f_recov - f))}")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Mean absolute error = 8.736021141250402e-13


.. GENERATED FROM PYTHON SOURCE LINES 59-63

Differentiating through SSHT C functions
----------------------------------------
So far, all this does is provide an interface between ``JAX`` and ``SSHT``;
the real novelty comes when we differentiate through the C library.

.. GENERATED FROM PYTHON SOURCE LINES 63-77

.. code-block:: Python

    # Define an arbitrary JAX function: the mean squared magnitude of the
    # signal synthesised from flm by the SSHT-backed inverse transform.
    # It returns a real scalar, which is what jax.grad requires.
    def differentiable_test(flm) -> jax.Array:
        f = s2fft.inverse(flm, L, method=method)
        return jax.numpy.nanmean(jax.numpy.abs(f) ** 2)


    # Create the JAX reverse mode gradient function
    gradient_func = jax.grad(differentiable_test)

    # Compute the gradient automatically
    gradient = gradient_func(flm)


.. GENERATED FROM PYTHON SOURCE LINES 78-83

Validating these gradients
--------------------------
This is all well and good, but how do we know these gradients are correct?
Fortunately, ``JAX`` provides a simple function for this, ``check_grads``,
which compares the automatically computed derivatives against
finite-difference estimates and raises an error if they disagree.

.. GENERATED FROM PYTHON SOURCE LINES 83-87

.. code-block:: Python

    from jax.test_util import check_grads

    # Check the reverse-mode (VJP) gradients to first order;
    # note that modes expects a tuple, hence the trailing comma
    check_grads(differentiable_test, (flm,), order=1, modes=("rev",))


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 1.031 seconds)

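The tutorial does not show how the SSHT calls are made differentiable, but the
general pattern can be sketched with ``JAX`` alone. The example below is a
minimal sketch under stated assumptions, not ``S2FFT``'s implementation: a
host-side NumPy cumulative sum stands in for an external C transform,
``jax.pure_callback`` lets ``JAX`` call it, and ``jax.custom_vjp`` supplies the
reverse-mode rule as a second host call to the adjoint. The names
``_host_cumsum``, ``_host_cumsum_adjoint``, ``external_cumsum`` and ``loss``
are hypothetical. The resulting gradient is compared with one computed through
``JAX``'s own ``jnp.cumsum`` and then checked with ``check_grads``, as in the
tutorial.

.. code-block:: Python

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.test_util import check_grads

    jax.config.update("jax_enable_x64", True)


    def _host_cumsum(x):
        # Hypothetical stand-in for an external (e.g. C) routine: it runs on the
        # host in NumPy, so JAX cannot trace or differentiate it on its own.
        return np.cumsum(np.asarray(x))


    def _host_cumsum_adjoint(ct):
        # Adjoint (transpose) of cumsum: a reversed cumulative sum, also on the host.
        ct = np.asarray(ct)
        return np.ascontiguousarray(np.cumsum(ct[::-1])[::-1])


    @jax.custom_vjp
    def external_cumsum(x):
        # pure_callback needs the output shape and dtype declared up front.
        out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
        return jax.pure_callback(_host_cumsum, out_spec, x)


    def _external_cumsum_fwd(x):
        # The map is linear, so no residuals need to be saved for the backward pass.
        return external_cumsum(x), None


    def _external_cumsum_bwd(_, ct):
        # Pull the output cotangent back through the adjoint, another host call.
        out_spec = jax.ShapeDtypeStruct(ct.shape, ct.dtype)
        return (jax.pure_callback(_host_cumsum_adjoint, out_spec, ct),)


    external_cumsum.defvjp(_external_cumsum_fwd, _external_cumsum_bwd)


    def loss(x):
        # Analogous to differentiable_test: a real scalar of the transformed signal.
        return jnp.mean(external_cumsum(x) ** 2)


    x = jnp.asarray(np.random.default_rng(0).standard_normal(16))
    grad = jax.grad(loss)(x)

    # Reference gradient from the same loss written with JAX's traceable cumsum.
    reference = jax.grad(lambda x: jnp.mean(jnp.cumsum(x) ** 2))(x)
    print(np.allclose(grad, reference))  # True

    # Finite-difference check of the hand-written reverse-mode rule.
    check_grads(loss, (x,), order=1, modes=("rev",))

Because the backward rule is itself a host call, ``jax.grad`` never has to
trace the external code; the price is that the adjoint must be written by hand,
which is exactly the kind of mistake ``check_grads`` is there to catch.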