Optimisation

s2scat.optimisation.fit_optax(params: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], loss_func, niter: int = 10, learning_rate: float32 = 0.0001, loss_history: list = None, print_iters: int = 10, apply_jit: bool = False, verbose: bool = False, track_history: bool = False) → Tuple[Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], List]

Minimises loss_func, starting from params, using the Adam optimiser from optax.

Parameters:
  • params (jnp.ndarray) – Initial estimate of the signal.

  • loss_func (function) – Loss function to minimise.

  • niter (int, optional) – Maximum number of iterations. Defaults to 10.

  • learning_rate (jnp.float32, optional) – Adam learning rate for optax. Defaults to 1e-4.

  • loss_history (list, optional) – A list in which to store the loss history. Defaults to None.

  • print_iters (int, optional) – How often, in iterations, to print the loss during training. Defaults to 10.

  • apply_jit (bool, optional) – Whether to jit the training step. Defaults to False.

  • verbose (bool, optional) – Whether to print the loss every print_iters iterations during generation. Defaults to False.

  • track_history (bool, optional) – Whether to record the loss at each iteration in the returned loss history. Defaults to False.

Returns:

Optimised solution and loss history.

Return type:

Tuple[optax.Params, List]

s2scat.optimisation.get_P00prime(flm: Array, filter_linear: Tuple[Array], normalisation: List[Array] = None) → Tuple[Array]

Computes P00prime, the averaged power within each wavelet scale.

Parameters:
  • flm (jnp.ndarray) – Spherical harmonic coefficients of signal.

  • filter_linear (Tuple[jnp.ndarray]) – Linearised wavelet filters.

  • normalisation (List[jnp.ndarray], optional) – Normalisation arrays. Defaults to None.

Returns:

Tuple of the power and averaged power over wavelet scales.

Return type:

Tuple[jnp.ndarray]
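A minimal sketch of a per-scale power, under assumptions that are not taken from s2scat: flm uses an (L, 2L-1) array layout with column index m + L - 1 (as in s2fft), and each wavelet is reduced to a hypothetical real harmonic profile psi_j(ell) stored in a (J, L) array, not the library's filter_linear tuple. By Parseval's relation for orthonormal spherical harmonics, the mean of |g|^2 over the unit sphere equals sum |g_lm|^2 / (4 pi), so the sketch returns the per-coefficient power and that sphere-averaged power for each scale. The normalisation argument is omitted, and per_scale_power_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def per_scale_power_sketch(flm, filters):
    """Hypothetical per-scale power (not s2scat's get_P00prime).

    flm:     (L, 2L-1) harmonic coefficients; row ell, column m + L - 1.
    filters: (J, L) real harmonic profile psi_j(ell) of each axisymmetric wavelet.
    """
    # Weight every coefficient of degree ell by psi_j(ell): shape (J, L, 2L-1).
    filtered = filters[:, :, None] * flm[None, :, :]
    power_lm = jnp.abs(filtered) ** 2
    # Parseval: mean of |g|^2 over the unit sphere = sum_lm |g_lm|^2 / (4 pi).
    power = jnp.sum(power_lm, axis=(1, 2)) / (4.0 * jnp.pi)
    return power_lm, power


# Usage: a single (ell=1, m=0) coefficient seen through two toy wavelet profiles.
flm = jnp.zeros((3, 5), dtype=jnp.complex64).at[1, 2].set(2.0)
filters = jnp.array([[1.0, 1.0, 1.0], [0.0, 0.5, 1.0]])
power_lm, power = per_scale_power_sketch(flm, filters)
print(power)
```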

s2scat.optimisation.l2_covariance_loss(predicts, targets) → float64

L2 loss wrapper for a list of scattering covariances.

Parameters:
  • predicts (List[jnp.ndarray]) – Predicted scattering covariances.

  • targets (List[jnp.ndarray]) – Target scattering covariances.

Returns:

L2 loss.

Return type:

jnp.float64
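One plausible form of the wrapper, sketched under the assumption that each statistic's L2 loss is the mean squared modulus of its residual and that the wrapper sums these over matching list entries. Both reductions are assumptions rather than the s2scat source, and l2_covariance_loss_sketch is a hypothetical name; a different reduction (for example a sum instead of a mean per statistic) would only rescale the individual terms.

```python
import jax.numpy as jnp


def l2_covariance_loss_sketch(predicts, targets):
    """Hypothetical wrapper: sum of per-statistic L2 losses (not the s2scat source)."""
    if len(predicts) != len(targets):
        raise ValueError("predicts and targets must have the same length")
    total = 0.0
    for predict, target in zip(predicts, targets):
        # Mean squared modulus of the residual for one statistic.
        total = total + jnp.mean(jnp.abs(predict - target) ** 2)
    return total


# Usage: two statistics of different sizes contribute 1.0 and 4.0.
loss = l2_covariance_loss_sketch(
    [jnp.ones(2), jnp.zeros(3)],
    [jnp.zeros(2), 2.0 * jnp.ones(3)],
)
print(float(loss))
```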

s2scat.optimisation.l2_loss(predict, target) → float64

L2 loss for a single scattering covariance.

Parameters:
  • predict (jnp.ndarray) – Predicted scattering covariance.

  • target (jnp.ndarray) – Target scattering covariance.

Returns:

L2 loss.

Return type:

jnp.float64
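Scattering covariances can be complex-valued, so a sketch of a single-statistic L2 loss should square the modulus of the residual rather than the residual itself. The version below assumes a mean over entries; l2_loss_sketch is a hypothetical name and not the s2scat source.

```python
import jax.numpy as jnp


def l2_loss_sketch(predict, target):
    """Hypothetical L2 loss for one statistic: mean squared modulus of the residual."""
    residual = predict - target
    # |r|^2 = r * conj(r) is real and non-negative even when r is complex,
    # whereas a plain r**2 is not (for example (1j)**2 == -1).
    return jnp.mean(jnp.abs(residual) ** 2)


# Usage: residuals 1j and 2 give squared moduli 1 and 4.
loss = l2_loss_sketch(jnp.array([1 + 1j, 2 + 0j]), jnp.array([1 + 0j, 0j]))
print(float(loss))
```

Using jnp.mean(residual ** 2) on the same inputs would give (-1 + 4) / 2 = 1.5, which underestimates the mismatch, and a purely imaginary residual alone would even produce a negative value.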