# Wigner transform
[1]:
import sys
IN_COLAB = 'google.colab' in sys.modules
# Install s2fft if running on Google Colab.
if IN_COLAB:
    !pip install s2fft &> /dev/null
This tutorial demonstrates how to use S2FFT to compute Wigner transforms, i.e. Fourier transforms on the rotation group SO(3). Specifically, we will adopt the sampling scheme of McEwen et al. (2015).
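For orientation, the sampling grid can be sketched in plain NumPy. The angle formulas below are an assumption based on the equiangular scheme described by McEwen et al. (2015), and the helper name `mw_so3_grid` is hypothetical (it is not part of S2FFT); treat this as a rough illustration of the grid size rather than the library's implementation.

```python
import numpy as np

def mw_so3_grid(L, N):
    """Hypothetical helper: Euler-angle sample positions (assumed MW-type scheme)."""
    # Assumed: 2L-1 equispaced samples in alpha over [0, 2*pi).
    alpha = 2 * np.pi * np.arange(2 * L - 1) / (2 * L - 1)
    # Assumed: L samples in beta over (0, pi], the last one falling on beta = pi.
    beta = np.pi * (2 * np.arange(L) + 1) / (2 * L - 1)
    # Assumed: 2N-1 equispaced samples in gamma over [0, 2*pi).
    gamma = 2 * np.pi * np.arange(2 * N - 1) / (2 * N - 1)
    return alpha, beta, gamma

alpha, beta, gamma = mw_so3_grid(L=128, N=3)
n_samples = alpha.size * beta.size * gamma.size
print(alpha.size, beta.size, gamma.size, n_samples)  # 255 128 5 163200
```

Under these assumptions the number of samples grows as roughly 4 L² N, which is why the azimuthal band-limit N is usually kept small.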
To demonstrate the S2FFT Wigner transforms, we first need an input signal sampled on the rotation group with this sampling scheme. For demonstration purposes we’ll simply generate a random test signal in harmonic space, i.e. a set of random Wigner coefficients with harmonic band-limit L and azimuthal band-limit N.
[2]:
import jax
jax.config.update("jax_enable_x64", True)
import numpy as np
import s2fft
L = 128
N = 3
reality = True
rng = np.random.default_rng(83459)
flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality)
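To get a feel for the size of the coefficient array, the following sketch counts the index triples (n, el, m) that satisfy |n| ≤ el < L, |n| < N and |m| ≤ el, and compares this with a dense array of shape (2N-1, L, 2L-1). The dense layout is an assumption about how the coefficients are stored, and both helper names are hypothetical rather than S2FFT functions.

```python
def flmn_dense_shape(L, N):
    """Hypothetical helper: assumed dense storage layout with axes (n, el, m)."""
    return (2 * N - 1, L, 2 * L - 1)

def count_valid_coefficients(L, N):
    """Count triples (n, el, m) with |n| <= el < L, |n| < N and |m| <= el."""
    total = 0
    for n in range(-(N - 1), N):
        for el in range(abs(n), L):
            total += 2 * el + 1  # m runs over -el, ..., el
    return total

L, N = 128, 3
shape = flmn_dense_shape(L, N)
dense_size = shape[0] * shape[1] * shape[2]
n_valid = count_valid_coefficients(L, N)
print(shape, dense_size, n_valid)  # (5, 128, 255) 163200 81910
```

Under this assumed layout about half of the dense array holds index combinations that cannot carry a coefficient, so those entries act as padding.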
## Computing the inverse Wigner transform
Let’s run the JAX function to compute the inverse Wigner transform of this random signal.
[3]:
f = s2fft.wigner.inverse_jax(flmn, L, N, reality=reality)
If you plan to apply this transform many times (e.g. during training of a model), we recommend precomputing and storing some small arrays that are reused on every call. To do this, generate the precomputes once and pass them to the transform through the precomps argument.
[4]:
precomps = s2fft.generate_precomputes_wigner_jax(L, N, forward=False, reality=reality)
f_pre = s2fft.wigner.inverse_jax(flmn, L, N, reality=reality, precomps=precomps)
## Computing the forward Wigner transform
Let’s run the JAX function to compute the forward Wigner transform, which takes us back to the random Wigner coefficients.
[5]:
flmn_recov = s2fft.wigner.forward_jax(f, L, N, reality=reality)
Again, if you plan to apply this transform many times (e.g. during training of a model), we recommend precomputing and storing some small arrays that are reused on every call. Note that the precomputes must be generated for the forward direction (forward=True) before being passed through the precomps argument.
[6]:
precomps = s2fft.generate_precomputes_wigner_jax(L, N, forward=True, reality=reality)
flmn_recov_pre = s2fft.wigner.forward_jax(f_pre, L, N, reality=reality, precomps=precomps)
## Computing the error
Let’s check the round-trip error. Because the sampling theorem used is exact for band-limited signals, the error should be close to machine precision.
[7]:
print(f"Mean absolute error = {np.nanmean(np.abs(flmn_recov - flmn))}")
Mean absolute error = 5.1348799839254916e-14
[8]:
print(f"Mean absolute error using precomputes = {np.nanmean(np.abs(flmn_recov_pre - flmn))}")
Mean absolute error using precomputes = 5.1348799839254916e-14
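The mean absolute error can hide a few large outliers, so it may be worth reporting the maximum absolute error alongside it. The sketch below is a minimal illustration using stand-in data (a random complex array with a tiny perturbation, not actual S2FFT output); the helper name `roundtrip_errors` is hypothetical.

```python
import numpy as np

def roundtrip_errors(recovered, reference):
    """Return (mean, max) absolute error, ignoring any NaN entries."""
    diff = np.abs(np.asarray(recovered) - np.asarray(reference))
    return float(np.nanmean(diff)), float(np.nanmax(diff))

# Stand-in data (NOT S2FFT output): complex coefficients plus a tiny perturbation.
rng = np.random.default_rng(0)
reference = rng.standard_normal((5, 8, 15)) + 1j * rng.standard_normal((5, 8, 15))
recovered = reference + 1e-14 * (1 + 1j)

mean_err, max_err = roundtrip_errors(recovered, reference)
print(f"mean = {mean_err:.2e}, max = {max_err:.2e}")
```

With the arrays from this tutorial, the same helper could be called as `roundtrip_errors(flmn_recov, flmn)`.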