s2fft is currently in open beta; please provide feedback on GitHub.

# Torch frontend guide


```python
import sys
IN_COLAB = 'google.colab' in sys.modules

# Install s2fft if running on Google Colab.
if IN_COLAB:
    !pip install s2fft &> /dev/null
```

This minimal tutorial demonstrates how to use the torch frontend for S2FFT to compute spherical harmonic transforms. Though S2FFT is primarily designed for JAX, this torch functionality is fully unit tested (including gradients) and can be used straightforwardly as a learnable layer within existing models.

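```python
import torch
import numpy as np
from s2fft.precompute_transforms.spherical import inverse, forward
from s2fft.precompute_transforms.construct import spin_spherical_kernel
from s2fft.utils import signal_generator
```

```
JAX is not using 64-bit precision. This will dramatically affect numerical precision at even moderate L.
```

This warning refers to JAX's default single-precision mode. The round-trip check at the end of this guide shows that the torch transforms used here still reach double-precision accuracy.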

Let's set up a mock problem by specifying a bandlimit \(L\) and generating some arbitrary harmonic coefficients.

```python
L = 64                                                          # Spherical harmonic bandlimit
rng = np.random.default_rng(1234951510)                         # Seeded random number generator
flm = signal_generator.generate_flm(rng, L, using_torch=True)   # Random set of spherical harmonic coefficients
```
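The coefficients are stored in a dense, padded array. The sketch below assumes S2FFT's 2D layout of shape \((L, 2L-1)\), indexed as `flm[el, L - 1 + m]`; print `flm.shape` to confirm this. Under that assumption, only slots with \(|m| \le \ell\) hold actual coefficients, which gives \(L^2\) in total. The helper `valid_flm_mask` is hypothetical and written only for this illustration.

```python
import numpy as np


def valid_flm_mask(L: int) -> np.ndarray:
    """Boolean mask of the (el, m) slots that hold coefficients in an
    (L, 2L-1) array indexed as flm[el, L - 1 + m] (assumed layout)."""
    el = np.arange(L)[:, None]              # degrees 0..L-1, as a column
    m = np.arange(-(L - 1), L)[None, :]     # orders -(L-1)..L-1, as a row
    return np.abs(m) <= el                  # only |m| <= el is a valid harmonic


mask = valid_flm_mask(64)
print(mask.shape)        # (64, 127)
print(int(mask.sum()))   # 4096, i.e. L**2 coefficients
```

Slots outside the mask are padding rather than coefficients. Excluding them is useful when computing error statistics.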

For the fully precomputed transform, we must also generate the precomputed kernels, which we store as torch tensors.

```python
inverse_kernel = spin_spherical_kernel(L, using_torch=True, forward=False)
forward_kernel = spin_spherical_kernel(L, using_torch=True, forward=True)
```
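Fully precomputing the transform trades memory for speed. For a rough estimate, assume that each kernel holds one float64 value per \((\theta, \ell, m)\) triple, with \(L\) rings in \(\theta\) as for the default MW sampling. Its size then grows as \(\mathcal{O}(L^3)\). The function `kernel_bytes` is a hypothetical back-of-the-envelope helper. Check `inverse_kernel.shape` and `inverse_kernel.dtype` for the real figures.

```python
def kernel_bytes(L: int, bytes_per_entry: int = 8) -> int:
    """Rough size in bytes of a precomputed kernel with one entry per
    (theta, el, m) triple. Assumes L theta rings (MW sampling) and
    float64 entries; the real S2FFT kernel may differ."""
    n_theta = L            # assumed number of theta samples
    n_el = L               # degrees 0..L-1
    n_m = 2 * L - 1        # orders -(L-1)..L-1
    return n_theta * n_el * n_m * bytes_per_entry


for L in (64, 128, 256, 512):
    print(f"L = {L:4d}: ~{kernel_bytes(L) / 1e6:,.1f} MB per kernel")
```

Under these assumptions, each kernel is about 4.2 MB at \(L = 64\) but about 2.1 GB at \(L = 512\). This guide builds two such kernels, one forward and one inverse.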

Now let's compute the signal on the sphere by applying the inverse spherical harmonic transform:

```python
# Positional arguments: coefficients, bandlimit, spin (0), precomputed kernel
f = inverse(flm, L, 0, inverse_kernel, method="torch")
```
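For reference, the inverse transform synthesises the signal from its coefficients. Spin-weighted spherical harmonics separate into a function of \(\theta\) times \(e^{im\varphi}\). A fully precomputed transform can therefore evaluate the sum in two stages: first a contraction over \(\ell\) with a kernel, then a sum over \(m\), which is an inverse FFT on equiangular \(\varphi\) samples. The symbol \(K_{t\ell m}\) is a schematic stand-in for the precomputed kernel (Wigner-\(d\) values at \(\theta_t\) times normalisation), and the exact conventions inside S2FFT may differ. The lower limit \(\ell = |m|\) holds for the spin-0 case used here.

\[
f(\theta_t, \varphi_p)
= \sum_{\ell=0}^{L-1} \sum_{m=-\ell}^{\ell} f_{\ell m}\, {}_sY_{\ell m}(\theta_t, \varphi_p)
= \sum_{m=-(L-1)}^{L-1} e^{i m \varphi_p} \sum_{\ell=|m|}^{L-1} K_{t \ell m}\, f_{\ell m}
\]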

To recover the corresponding spherical harmonic coefficients, apply the forward transform:

```python
# Same argument order as above, mapping the sampled signal back to coefficients
flm_check = forward(f, L, 0, forward_kernel, method="torch")
```
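The forward transform is the corresponding projection onto the harmonics. In the same schematic notation, \(\tilde{K}_{t\ell m}\) absorbs the quadrature weights. The transform is then an FFT over \(\varphi\) followed by a contraction over the \(\theta\) rings. For a signal band-limited at \(L\), the sampling theorem behind the default McEwen-Wiaux (MW) sampling makes this quadrature exact up to floating-point round-off. That exactness is what the round-trip check tests.

\[
f_{\ell m} = \int_{S^2} f(\theta, \varphi)\, {}_sY^{*}_{\ell m}(\theta, \varphi)\, \mathrm{d}\Omega
= \sum_{t} \tilde{K}_{t \ell m} \sum_{p} f(\theta_t, \varphi_p)\, e^{-i m \varphi_p}
\]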

Finally, let's check that the round-trip error is close to 64-bit machine precision:

```python
print(f"Mean absolute error = {np.nanmean(np.abs(flm_check - flm))}")
```

```
Mean absolute error = 1.1866908936078849e-14
```
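The mean absolute error above averages over every slot of the array, including padding. The hypothetical helper `roundtrip_errors` below gives a slightly stricter check. It assumes the same `flm[el, L - 1 + m]` layout, restricts the comparison to valid \((\ell, m)\) slots, and also reports the worst-case and relative errors. Torch tensors that require gradients or live on a GPU need `.detach().cpu()` before being passed in.

```python
import numpy as np


def roundtrip_errors(flm, flm_check, L):
    """Compare two (L, 2L-1) coefficient arrays on valid (el, m) slots only.

    Hypothetical helper; assumes the flm[el, L - 1 + m] layout and accepts
    NumPy arrays or CPU torch tensors that do not require gradients.
    """
    flm = np.asarray(flm)
    flm_check = np.asarray(flm_check)
    el = np.arange(L)[:, None]                # degrees as a column
    m = np.arange(-(L - 1), L)[None, :]       # orders as a row
    valid = np.abs(m) <= el                   # skip padding where |m| > el
    diff = flm_check[valid] - flm[valid]
    return {
        "mean_abs": float(np.mean(np.abs(diff))),
        "max_abs": float(np.max(np.abs(diff))),
        "rel_l2": float(np.linalg.norm(diff) / np.linalg.norm(flm[valid])),
    }
```

With the tensors from this guide, the call would be `roundtrip_errors(flm, flm_check, L)`.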