# Torch frontend guide


[1]:
import sys
IN_COLAB = 'google.colab' in sys.modules

# Install s2fft and data if running on google colab.
if IN_COLAB:
    !pip install s2fft &> /dev/null

This minimal tutorial demonstrates how to use the torch frontend for S2FFT to compute spherical harmonic transforms. Though S2FFT is primarily designed for JAX, the torch functionality is fully unit tested (including gradients) and can be used straightforwardly as a learnable layer within existing models. As the torch functions wrap the JAX implementations, we need to configure JAX to use 64-bit floating point precision by default to ensure sufficient accuracy for the transforms; S2FFT will emit a warning if this has not been done.

[2]:
import jax
jax.config.update("jax_enable_x64", True)
import torch
import numpy as np
from s2fft.transforms.spherical import inverse, forward
from s2fft.precompute_transforms.spherical import (
    inverse as precompute_inverse, forward as precompute_forward
)
from s2fft.precompute_transforms.construct import spin_spherical_kernel_torch
from s2fft.utils import signal_generator

Let's set up a mock problem by specifying a bandlimit $L$ and generating some arbitrary harmonic coefficients.

[3]:
L = 64
rng = np.random.default_rng(1234951510)
flm = torch.from_numpy(signal_generator.generate_flm(rng, L))

Now let's calculate the signal on the sphere by applying the inverse spherical harmonic transform.

[4]:
f = inverse(flm, L, method="torch")
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

To calculate the corresponding spherical harmonic representation, execute

[5]:
flm_check = forward(f, L, method="torch")

Let's check that the round-trip error is at the level expected for 64-bit floating point arithmetic.

[6]:
print(f"Mean absolute error = {np.nanmean(np.abs(flm_check - flm))}")
Mean absolute error = 2.8915048238993476e-14
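
Because the torch functions support gradients (as noted above, these are unit tested), we can also differentiate through the transforms with torch autograd. The following unexecuted sketch illustrates this; the scalar loss here is an arbitrary choice for illustration, not part of S2FFT.

[ ]:
# Illustrative sketch: differentiate through the torch frontend with autograd.
flm_learn = torch.from_numpy(signal_generator.generate_flm(rng, L)).requires_grad_()
f_learn = inverse(flm_learn, L, method="torch")
loss = torch.abs(f_learn).sum()  # arbitrary real-valued scalar loss
loss.backward()
print(flm_learn.grad.shape)  # gradient with respect to the harmonic coefficients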

For the fully precomputed transforms we must also generate the precompute kernels, which we store as torch tensors.

[7]:
inverse_kernel = spin_spherical_kernel_torch(L, forward=False)
forward_kernel = spin_spherical_kernel_torch(L, forward=True)

We then pass these kernels as additional arguments to the transform functions.

[8]:
precompute_f = precompute_inverse(flm, L, kernel=inverse_kernel, method="torch")
precompute_flm_check = precompute_forward(f, L, kernel=forward_kernel, method="torch")

Again, we check that the round-trip error is as expected.

[9]:
print(f"Mean absolute error = {np.nanmean(np.abs(precompute_flm_check - flm))}")
Mean absolute error = 2.904741595325594e-14
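
Finally, as mentioned at the start of this guide, the torch frontend can be embedded as a learnable layer within existing models. The sketch below is purely illustrative (the SphericalSignal module and its loss are our own assumptions, not part of S2FFT): it exposes learnable harmonic coefficients and registers the fixed precompute kernel as a non-learnable buffer, so the kernel moves with the module but receives no gradients.

[ ]:
# Illustrative sketch: a hypothetical module with learnable harmonic
# coefficients, decoded to the sphere via the precompute inverse transform.
class SphericalSignal(torch.nn.Module):
    def __init__(self, L):
        super().__init__()
        self.L = L
        # Learnable harmonic coefficients, initialised with random values.
        self.flm = torch.nn.Parameter(
            torch.from_numpy(signal_generator.generate_flm(rng, L))
        )
        # The kernel is fixed data, so register it as a non-learnable buffer.
        self.register_buffer("kernel", spin_spherical_kernel_torch(L, forward=False))

    def forward(self):
        return precompute_inverse(self.flm, self.L, kernel=self.kernel, method="torch")

model = SphericalSignal(L)
optimiser = torch.optim.SGD(model.parameters(), lr=1e-2)
loss = torch.abs(model()).sum()  # arbitrary illustrative loss
loss.backward()
optimiser.step()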