.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/torch_frontend.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_torch_frontend.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_torch_frontend.py:


Torch frontend guide
====================

This minimal tutorial demonstrates how to use the ``torch`` frontend of ``S2FFT`` to compute spherical harmonic transforms.

.. image:: https://colab.research.google.com/assets/colab-badge.svg
    :align: center
    :alt: Open in Google Colab
    :target: https://colab.research.google.com/github/astro-informatics/s2fft/tree/gh-pages/_colab_notebooks/spherical_rotation.ipynb

If you are working on this notebook in Google Colab, you will first need to install ``s2fft``. You can do this by adding a cell with the following content to the top of the notebook

.. code-block:: bash

    !pip install s2fft &> /dev/null

and then running that cell.

.. GENERATED FROM PYTHON SOURCE LINES 23-25

Though ``S2FFT`` is primarily designed for ``JAX``, its ``torch`` functionality is fully unit tested (including gradients) and can be used straightforwardly as a learnable layer within existing models. Because the ``torch`` functions wrap the ``JAX`` implementations, we need to configure ``JAX`` to use 64-bit floating point types by default to ensure sufficient precision for the transforms; ``S2FFT`` will emit a warning if this has not been done.

.. GENERATED FROM PYTHON SOURCE LINES 25-35

.. code-block:: Python


    import jax
    import numpy as np
    import torch

    # Use 64-bit floating point types in JAX for sufficient transform precision.
    jax.config.update("jax_enable_x64", True)

    from s2fft.transforms.spherical import forward, inverse
    from s2fft.utils import signal_generator

.. GENERATED FROM PYTHON SOURCE LINES 36-37

Let's set up a mock problem by specifying a bandlimit :math:`L` and generating some arbitrary harmonic coefficients.

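
It may help to know how these coefficients are laid out before generating them. We assume here that ``signal_generator.generate_flm`` returns ``S2FFT``'s 2D layout. That is an array of shape ``(L, 2L - 1)``, in which the entry for degree :math:`\ell` and order :math:`m` is stored at index ``[el, L - 1 + m]``, and entries with :math:`|m| > \ell` are zero. The sketch below is a minimal illustration under that assumed layout. It uses only NumPy and does not call ``S2FFT``. It builds a boolean mask of the valid :math:`(\ell, m)` entries. The helper ``valid_flm_mask`` is hypothetical and introduced purely for illustration.

.. code-block:: Python

    import numpy as np


    def valid_flm_mask(L: int) -> np.ndarray:
        """Hypothetical helper: mask of the (el, m) entries that may be non-zero.

        Assumes a 2D coefficient array of shape (L, 2L - 1), with row index el
        (0 <= el < L) and column index L - 1 + m (-(L - 1) <= m <= L - 1), so
        that order m = 0 sits in the middle column.
        """
        el = np.arange(L)[:, None]           # degrees, shape (L, 1)
        m = np.arange(-(L - 1), L)[None, :]  # orders, shape (1, 2L - 1)
        return np.abs(m) <= el               # only |m| <= el is a valid mode


    L = 64
    mask = valid_flm_mask(L)
    print(mask.shape)       # (64, 127)
    print(int(mask.sum()))  # 4096

Each degree :math:`\ell` contributes :math:`2\ell + 1` orders. The mask therefore holds :math:`\sum_{\ell=0}^{L-1} (2\ell + 1) = L^2` valid coefficients, which is 4096 for :math:`L = 64`.
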

.. GENERATED FROM PYTHON SOURCE LINES 37-42

.. code-block:: Python


    L = 64
    rng = np.random.default_rng(1234951510)
    flm = torch.from_numpy(signal_generator.generate_flm(rng, L))

.. GENERATED FROM PYTHON SOURCE LINES 43-44

Now let's compute the signal on the sphere by applying the inverse spherical harmonic transform.

.. GENERATED FROM PYTHON SOURCE LINES 44-47

.. code-block:: Python


    f = inverse(flm, L, method="torch")

.. GENERATED FROM PYTHON SOURCE LINES 48-49

To recover the corresponding spherical harmonic coefficients, apply the forward transform:

.. GENERATED FROM PYTHON SOURCE LINES 49-52

.. code-block:: Python


    flm_check = forward(f, L, method="torch")

.. GENERATED FROM PYTHON SOURCE LINES 53-54

Finally, let's check that the round-trip error is at the level expected for 64-bit machine precision floating point arithmetic.

.. GENERATED FROM PYTHON SOURCE LINES 54-56

.. code-block:: Python


    print(f"Mean absolute error = {np.nanmean(np.abs(flm_check - flm))}")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Mean absolute error = 2.8915000666098304e-14


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 2.843 seconds)


.. _sphx_glr_download_tutorials_torch_frontend.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: torch_frontend.ipynb <torch_frontend.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: torch_frontend.py <torch_frontend.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: torch_frontend.zip <torch_frontend.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_