.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/JAX_HEALPix_frontend.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_JAX_HEALPix_frontend.py:


JAX HEALPix Frontend
====================

This short tutorial demonstrates how to use the custom ``JAX`` frontend that ``S2FFT``
provides for the `HEALPix `_ C++ library.

.. image:: https://colab.research.google.com/assets/colab-badge.svg
    :align: center
    :alt: Open in Google Colab
    :target: https://colab.research.google.com/github/astro-informatics/s2fft/tree/gh-pages/_colab_notebooks/JAX_HEALPix_frontend.ipynb

If you are working on this notebook in Google Colab, you will need to install
``s2fft`` and ``healpy``. You can do this by adding a cell to the top of the
notebook with the following content:

.. code-block:: bash

    !pip install s2fft healpy &> /dev/null

and then running that cell.

.. GENERATED FROM PYTHON SOURCE LINES 23-25

``S2FFT``'s support for the `HEALPix `_ C++ library avoids the long JIT compile
times that HEALPix transforms can otherwise incur when running on CPU. As with the
other introductions, let's import some packages and define an arbitrary bandlimited
signal to work with.

.. GENERATED FROM PYTHON SOURCE LINES 25-40

.. code-block:: Python


    import jax
    import numpy as np

    import s2fft

    # Use double precision throughout
    jax.config.update("jax_enable_x64", True)

    L = 128  # harmonic bandlimit
    nside = 64  # HEALPix resolution parameter
    method = "jax_healpy"  # JAX frontend to the HEALPix C++ library
    sampling = "healpix"

    # Random bandlimited harmonic coefficients and the corresponding HEALPix map
    rng = np.random.default_rng(23457801234570)
    flm = s2fft.utils.signal_generator.generate_flm(rng, L)
    f = s2fft.inverse(flm, L, nside=nside, sampling=sampling, method=method)


.. GENERATED FROM PYTHON SOURCE LINES 41-43

Calling forward HEALPix C++ function from JAX
---------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 43-46

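Before calling the forward transform, it may help to know the array sizes involved.
The standalone sketch below uses plain Python only (no ``S2FFT`` or ``healpy`` calls)
and relies on two facts: a HEALPix map at resolution ``nside`` has ``12 * nside**2``
pixels, all of equal solid angle ``4 * pi / (12 * nside**2)``, and ``S2FFT`` stores the
harmonic coefficients ``flm`` as a 2D array of shape ``(L, 2L - 1)``. The helper names
``healpix_npix``, ``healpix_pixel_area`` and ``flm_shape`` are hypothetical and exist
only for this illustration.

.. code-block:: python

    import math


    def healpix_npix(nside: int) -> int:
        # Hypothetical helper: a HEALPix map has 12 * nside^2 pixels.
        return 12 * nside**2


    def healpix_pixel_area(nside: int) -> float:
        # Every HEALPix pixel covers the same solid angle (steradians).
        return 4.0 * math.pi / healpix_npix(nside)


    def flm_shape(L: int) -> tuple:
        # Hypothetical helper: rows index ell, columns index m in [-(L-1), L-1].
        return (L, 2 * L - 1)


    L, nside = 128, 64  # same values as in the setup cell
    print("map pixels:", healpix_npix(nside))  # 49152
    print("pixel area [sr]:", healpix_pixel_area(nside))
    print("flm shape:", flm_shape(L))  # (128, 255)

With these settings we would expect the map ``f`` to hold 49152 pixel values, which
the forward transform maps to a ``(128, 255)`` array of coefficients.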
.. code-block:: Python


    flm = s2fft.forward(f, L, nside=nside, sampling=sampling, method=method)


.. GENERATED FROM PYTHON SOURCE LINES 47-49

Calling inverse HEALPix C++ function from JAX
---------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 49-52

.. code-block:: Python


    f_recov = s2fft.inverse(flm, L, nside=nside, sampling=sampling, method=method)


.. GENERATED FROM PYTHON SOURCE LINES 53-58

Computing the roundtrip error
-----------------------------

Let's check the associated error. It should be around ``1e-5``, since HEALPix is not
an exact sampling of the sphere. Note that increasing ``iters`` will reduce the
numerical error here slightly, at the cost of compute time that grows linearly with
the number of iterations.

.. GENERATED FROM PYTHON SOURCE LINES 58-61

.. code-block:: Python


    print(f"Mean absolute error = {np.nanmean(np.abs(f_recov - f))}")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Mean absolute error = 1.3855672275957913e-05


.. GENERATED FROM PYTHON SOURCE LINES 62-66

Differentiating through HEALPix C++ functions
---------------------------------------------

So far, all this does is provide an interface between ``JAX`` and ``HEALPix``; the
real novelty comes when we differentiate through the C++ library.

.. GENERATED FROM PYTHON SOURCE LINES 66-80

.. code-block:: Python


    # Define an arbitrary scalar loss: the mean power of the inverse-transformed map
    def differentiable_test(flm) -> float:
        f = s2fft.inverse(flm, L, nside=nside, sampling=sampling, method=method)
        return jax.numpy.nanmean(jax.numpy.abs(f) ** 2)


    # Create the JAX reverse mode gradient function
    gradient_func = jax.grad(differentiable_test)

    # Compute the gradient automatically
    gradient = gradient_func(flm)


.. GENERATED FROM PYTHON SOURCE LINES 81-86

Validating these gradients
--------------------------

This is all well and good, but how do we know these gradients are correct?
Thankfully, ``JAX`` provides a simple function, ``jax.test_util.check_grads``, which
checks gradients from automatic differentiation against finite differences.

.. GENERATED FROM PYTHON SOURCE LINES 86-90

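The idea behind such a check can be sketched with plain NumPy on a toy problem. Since
the inverse transform is linear in ``flm``, a fixed random real matrix ``A`` stands in
for it here; this is an illustrative assumption, and no ``S2FFT`` or ``HEALPix`` calls
are involved. The loss mirrors ``differentiable_test``, the analytic gradient is
projected onto a random direction ``v``, and the result is compared with a central
finite difference along ``v``. This is in the same spirit as ``check_grads``, not a
reproduction of its internals, and the helper names are illustrative.

.. code-block:: python

    import numpy as np

    # Toy stand-in (assumption): a fixed random linear map A plays the role of the
    # inverse transform, which is linear in its harmonic coefficients.
    rng = np.random.default_rng(0)
    A = rng.standard_normal((50, 20))


    def loss(x):
        # Mirrors differentiable_test: mean power of the synthesised "map" A @ x.
        return np.mean((A @ x) ** 2)


    def analytic_grad(x):
        # d/dx mean((A x)^2) = 2 A^T (A x) / n_rows
        return 2.0 * A.T @ (A @ x) / A.shape[0]


    def central_difference(f, x, v, eps=1e-5):
        # Finite-difference estimate of the derivative of f at x along direction v.
        return (f(x + eps * v) - f(x - eps * v)) / (2.0 * eps)


    x = rng.standard_normal(20)
    v = rng.standard_normal(20)  # random probe direction

    fd = central_difference(loss, x, v)
    ad = analytic_grad(x) @ v
    print(f"finite difference: {fd:.10f}")
    print(f"gradient . v:      {ad:.10f}")

Because the toy loss is quadratic, the central difference matches the directional
derivative up to floating-point rounding; for general smooth losses the agreement is
only accurate to second order in ``eps``, which is why such checks use tolerances.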
.. code-block:: Python


    from jax.test_util import check_grads

    # Compare reverse-mode gradients with finite-difference estimates;
    # raises an AssertionError if they disagree beyond tolerance.
    check_grads(differentiable_test, (flm,), order=1, modes=("rev",))


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 1.623 seconds)


.. _sphx_glr_download_tutorials_JAX_HEALPix_frontend.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: JAX_HEALPix_frontend.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: JAX_HEALPix_frontend.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: JAX_HEALPix_frontend.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_