Generation (Step-by-Step)#

This tutorial is a basic overview of how one may use the scattering covariances as a statistical generative model.

Generative AI models typically require an abundance of realistic training data. In many (often high-dimensional) application domains, such as the sciences, such training data does not exist, which limits generative AI approaches.

One may instead construct an expressive statistical representation from which, provided at least a single fiducial realisation, many realisations may be drawn. This concept is actually very familiar, particularly in cosmology where it is typical to draw Gaussian realisations from a known power spectrum. However, this generative model does not capture complex non-linear structural information.
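
To make the familiar Gaussian case concrete, here is a minimal NumPy sketch of drawing harmonic coefficients from a power spectrum \(C_\ell\). The helper `draw_gaussian_alm` and the toy spectrum `cl` are illustrative assumptions, not part of S2SCAT or S2FFT; only the \(m \geq 0\) coefficients of a real field are stored.

```python
import numpy as np

def draw_gaussian_alm(cl, rng):
    """Draw a_{lm} (m >= 0) of a real Gaussian field with power spectrum cl.

    Hypothetical helper for illustration: row ell, column m holds a_{ell m};
    entries with m > ell are left at zero.
    """
    L = len(cl)
    alm = np.zeros((L, L), dtype=np.complex128)
    for ell in range(L):
        # m = 0 is purely real with variance C_ell.
        alm[ell, 0] = rng.normal(scale=np.sqrt(cl[ell]))
        # For m > 0 the real and imaginary parts each carry half the variance.
        scale = np.sqrt(cl[ell] / 2)
        alm[ell, 1:ell + 1] = (
            rng.normal(scale=scale, size=ell)
            + 1j * rng.normal(scale=scale, size=ell)
        )
    return alm

rng = np.random.default_rng(0)
ells = np.arange(64)
cl = 1.0 / (1.0 + ells) ** 2   # Toy power spectrum, chosen only for this demo.
alm = draw_gaussian_alm(cl, rng)
```

Every draw shares the same two-point statistics by construction, which is exactly why such a model cannot encode the non-Gaussian structure discussed next.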

Here we will instead use the scattering covariances \(\Phi(x)\) as our statistical representation. Given \(\Phi\) is a non-linear function of the data \(x\), generating new realisations isn’t quite so straightforward. In fact, to do so we’ll need to minimise the loss function:

\[\mathcal{L}(x) = ||\Phi(x) - \Phi(x_t)||^2_2\]

where \(\Phi(x_t)\) are the target covariances computed from the signal \(x_t\) we are aiming to emulate. To solve this optimisation problem with gradient-based methods we need to be able to differentiate through \(\Phi\), which is a complex function involving wavelet transforms, non-linearities, and spherical harmonic and Wigner transforms.

As S2SCAT is a JAX package, we can readily access these gradients, so let’s see exactly how this works!
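
Before turning to the real transform, the idea can be sketched on a toy problem. In the snippet below `phi` is a stand-in made of three simple non-linear summary statistics, not the scattering covariances, and the signals are arbitrary; the point is only that `jax.grad` differentiates the loss \(||\Phi(x) - \Phi(x_t)||^2_2\) with respect to the signal itself.

```python
import jax
import jax.numpy as jnp

def phi(x):
    """Toy non-linear statistics (a stand-in for the scattering covariances)."""
    return jnp.array([jnp.mean(x), jnp.mean(x**2), jnp.mean(jnp.abs(x))])

def loss(x, phi_target):
    """Squared l2 distance between the statistics of x and the target statistics."""
    return jnp.sum((phi(x) - phi_target) ** 2)

# Target statistics from an arbitrary "target" signal.
x_t = jnp.linspace(-1.0, 2.0, 16)
phi_t = phi(x_t)

# Arbitrary starting signal and the gradient of the loss with respect to it.
x = jnp.linspace(-0.5, 0.5, 16)
g = jax.grad(loss)(x, phi_t)

# One small gradient-descent step moves x towards matching the target statistics.
x_new = x - 0.1 * g
```

The gradient has the same shape as the signal, so each entry says how to nudge one sample of `x`; the full tutorial does the same thing with harmonic coefficients and a far richer \(\Phi\).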

Import the package#

Let’s first import S2SCAT and some basic plotting functions. The target map has been stored as a NumPy array, just to save you some time, so we’ll load it with NumPy.

[ ]:
import sys
IN_COLAB = 'google.colab' in sys.modules

# Install a spherical plotting package.
!pip install cartopy &> /dev/null

# Install s2scat and download the target data if running on Google Colab.
if IN_COLAB:
    !pip install s2scat &> /dev/null
    !pip install numpy==1.23.5 &> /dev/null
    !mkdir data/
    !wget https://github.com/astro-informatics/s2scat/raw/main/notebooks/data/target_map_lss.npy -P data/ &> /dev/null
[ ]:
import jax
jax.config.update("jax_enable_x64", True)

from matplotlib import pyplot as plt
import numpy as np
import cartopy.crs as ccrs
import s2scat, s2fft

Configure the problem#

Let’s set up the target field we are aiming to emulate, and the hyperparameters of the scattering covariance representation we will work with.

[ ]:
L = 256                # Spherical harmonic bandlimit.
N = 3                  # Azimuthal bandlimit (directionality).
J_min = 2              # Minimum wavelet scale.
reality = True         # Input signal is real.
recursive = False      # Use the fully precomputed transform.

# Let's load in the spherical field we wish to emulate and its harmonic coefficients.
x_t = np.load('data/target_map_lss.npy')
xlm_t = s2fft.forward_jax(x_t, L, reality=reality)[:,L-1:]

Before calling the scattering transform you need to run the configuration step, which generates any precomputed arrays and caches them. When running the recursive transform this shouldn’t take much memory at all. However, the fully precomputed transform, which is much faster, can be extremely memory hungry at L ~ 512 and above!

[ ]:
# Configure the representation e.g. load wavelet filters and Wigner matrices.
config = s2scat.configure(L, N, J_min, reality, recursive)

# Calculate normalisation and target latent vector.
norm = s2scat.compute_norm(xlm_t, L, N, J_min, reality, config, recursive)
targets = s2scat.scatter(xlm_t, L, N, J_min, reality, config, norm, recursive)

Define a loss function#

Let’s define a simple \(\ell_2\) loss function that computes the mean squared error between the scattering covariances at our current iterate and those of the target. In practice any loss could be used here; we’ll take the most straightforward choice for this demonstration.

[ ]:
def loss_func(xlm):
    predicts = s2scat.scatter(xlm, L, N, J_min, reality, config, norm, recursive)
    return s2scat.optimisation.l2_covariance_loss(predicts, targets)
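
For intuition, a loss of this kind can be sketched in plain NumPy. The function below assumes the covariances arrive as a list of (possibly complex) arrays and simply sums the mean squared absolute difference of each group; `l2_loss_sketch` is a hypothetical name, and the actual `l2_covariance_loss` in S2SCAT may weight or normalise the terms differently.

```python
import numpy as np

def l2_loss_sketch(predicts, targets):
    """Sum over groups of statistics of the mean squared absolute difference."""
    return sum(np.mean(np.abs(p - t) ** 2) for p, t in zip(predicts, targets))

# Two toy groups of statistics, one real and one complex.
predicts = [np.array([1.0, 2.0]), np.array([1j])]
targets = [np.array([1.0, 0.0]), np.array([0j])]
value = l2_loss_sketch(predicts, targets)
```

Taking the absolute value before squaring keeps the loss real even when the statistics are complex.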

Generate an initial estimate#

We need to choose a set of harmonic coefficients \(x_{\ell m}\) from which to start our optimisation. Strictly speaking, we should start from a Gaussian-distributed random signal to ensure we form a macro-canonical model of our target field, and we will do precisely this. In practice, however, it may be better to start from, e.g., a Gaussian random field generated from a fiducial power spectrum, as this may reduce the number of iterations required for convergence.

In any case, let’s generate a starting signal.

[ ]:
# Compute the standard deviation of the target's non-zero harmonic coefficient magnitudes.
sigma_bar = np.std(np.abs(xlm_t)[xlm_t!=0])

# Generate Gaussian random harmonic coefficients with the correct variance.
xlm = np.random.randn(L, L) * sigma_bar + 1j*np.random.randn(L, L) * sigma_bar

# Save the starting noise signal for posterity and plotting!
xlm_start = s2scat.operators.spherical.make_flm_full(xlm, L)
x_start = s2fft.inverse(xlm_start, L, reality=reality, method="jax")
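
The optimisation works on an `(L, L)` array holding only the \(m \geq 0\) coefficients, which is what the `[:, L-1:]` slice applied to the target extracts, while the inverse transform expects the full `(L, 2L-1)` array. For a real signal the missing half follows from the symmetry \(x_{\ell,-m} = (-1)^m x^*_{\ell m}\). The sketch below illustrates that expansion in NumPy, assuming the layout in which column `L-1+m` holds order \(m\); `fill_negative_m` is a hypothetical helper and not the implementation of `make_flm_full`.

```python
import numpy as np

def fill_negative_m(flm_pos):
    """Expand (L, L) coefficients with m >= 0 into a full (L, 2L-1) array.

    Uses the real-signal symmetry f_{l,-m} = (-1)^m conj(f_{l,m}).
    """
    L = flm_pos.shape[0]
    flm = np.zeros((L, 2 * L - 1), dtype=np.complex128)
    # Columns L-1 .. 2L-2 hold m = 0 .. L-1.
    flm[:, L - 1:] = flm_pos
    # Column L-1-m holds order -m, for m = 1 .. L-1.
    m = np.arange(1, L)
    flm[:, L - 1 - m] = (-1.0) ** m * np.conj(flm_pos[:, m])
    return flm

L_demo = 4
rng = np.random.default_rng(1)
flm_pos = rng.normal(size=(L_demo, L_demo)) + 1j * rng.normal(size=(L_demo, L_demo))
flm_full = fill_negative_m(flm_pos)
```

Slicing `flm_full[:, L_demo-1:]` recovers the input, so the two layouts carry the same information for a real field.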

Minimise the objective#

Now we can pass all these components to Optax, which we have internally configured to use the Adam optimiser, to minimise the loss and return a synthetic realisation that should approximate the statistics of the target field.

[ ]:
# Run the optimisation to generate a new realisation xlm.
xlm_end, _ = s2scat.optimisation.fit_optax(xlm, loss_func, niter=400, learning_rate=1e-3, verbose=True, track_history=True)

# Convert the synthetic harmonic coefficients into a pixel-space image.
xlm_end = s2scat.operators.spherical.make_flm_full(xlm_end, L)
x_end = s2fft.inverse_jax(xlm_end, L, reality=reality)
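
The internals of `fit_optax` are not shown in this tutorial. As a rough picture of what an Adam update does, here is a plain NumPy sketch applied to a toy quadratic loss; `adam_minimise`, the toy loss, and the hyperparameter values are illustrative assumptions, not S2SCAT or Optax code.

```python
import numpy as np

def adam_minimise(grad_fn, x0, niter=400, learning_rate=0.05,
                  b1=0.9, b2=0.999, eps=1e-8):
    """Minimal Adam loop: rescale each step by running gradient moments."""
    x = np.array(x0, dtype=float)
    m = np.zeros_like(x)   # Running mean of the gradients.
    v = np.zeros_like(x)   # Running mean of the squared gradients.
    for t in range(1, niter + 1):
        g = grad_fn(x)
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g**2
        # Correct the bias from initialising both moments at zero.
        m_hat = m / (1 - b1**t)
        v_hat = v / (1 - b2**t)
        x = x - learning_rate * m_hat / (np.sqrt(v_hat) + eps)
    return x

# Toy problem: recover a known vector by minimising a quadratic loss.
target = np.array([1.0, -2.0, 3.0])
loss = lambda x: np.sum((x - target) ** 2)
grad = lambda x: 2.0 * (x - target)

x0 = np.zeros(3)
x_fit = adam_minimise(grad, x0)
```

In the tutorial the gradient comes from differentiating `loss_func` with JAX rather than from a hand-written formula, but the update rule has the same structure.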

Check the synthesis#

Finally, let’s check how our starting and ending realisations shape up against the target field!

[ ]:
fields = [x_t, x_start, x_end]
titles = ["Target", "Start", "Emulation"]
fig, axs = plt.subplots(1, 3, subplot_kw={'projection': ccrs.Mollweide()}, figsize=(30,10))
mx, mn = 3, -1
for i in range(3):
    axs[i].imshow(fields[i], transform=ccrs.PlateCarree(), cmap='viridis', vmax=mx, vmin=mn)
    axs[i].set_title(titles[i])
    axs[i].axis('off')
plt.show()