Optimisation
- s2scat.optimisation.fit_optax(params: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], loss_func, niter: int = 10, learning_rate: float32 = 0.0001, loss_history: list = None, print_iters: int = 10, apply_jit: bool = False, verbose: bool = False, track_history: bool = False) → Tuple[Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], List]
Minimises the given loss function, starting from params, using the optax Adam optimiser.
- Parameters:
params (jnp.ndarray) – Initial estimate (signal).
loss_func (function) – Loss function to minimise.
niter (int, optional) – Maximum number of iterations. Defaults to 10.
learning_rate (jnp.float32, optional) – Adam learning rate for optax. Defaults to 1e-4.
loss_history (list, optional) – A list in which to store the loss history. Defaults to None.
print_iters (int, optional) – Interval, in iterations, at which the loss is reported during training. Defaults to 10.
apply_jit (bool, optional) – Whether to jit the training step. Defaults to False.
verbose (bool, optional) – Whether to print the loss during optimisation. Defaults to False.
track_history (bool, optional) – Whether to record the loss history during optimisation. Defaults to False.
- Returns:
Optimised solution and loss history.
- Return type:
Tuple[optax.Params, List]
- s2scat.optimisation.get_P00prime(flm: Array, filter_linear: Tuple[Array], normalisation: List[Array] = None) → Tuple[Array]
Computes P00prime, the averaged power within each wavelet scale.
- Parameters:
flm (jnp.ndarray) – Spherical harmonic coefficients of signal.
filter_linear (Tuple[jnp.ndarray]) – Linearised wavelet filters.
normalisation (List[jnp.ndarray], optional) – Normalisation terms. Defaults to None.
- Returns:
Tuple of the power and averaged power over wavelet scales.
- Return type:
Tuple[jnp.ndarray]
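For intuition, the averaged power within a scale can be computed directly from harmonic coefficients via Parseval's theorem: with orthonormal harmonics, the mean of |g|² over the sphere equals Σ|g_lm|² / (4π). The sketch below assumes axisymmetric filters given as harmonic-space multipliers w_j(ℓ) of shape (J, L) and an flm array of shape (L, 2L−1) with m stored at column m + L − 1. `per_scale_mean_power` is a hypothetical helper; it does not reproduce s2scat's directional filters, the `normalisation` argument, or the two-element return tuple.

```python
import jax.numpy as jnp


def per_scale_mean_power(flm, filters):
    """Mean of |W_j f|^2 over the sphere for each axisymmetric filter j.

    flm:     (L, 2L-1) complex harmonic coefficients (assumed layout).
    filters: (J, L) harmonic-space filter profiles w_j(l) (assumed form).
    """
    coeff_power = jnp.abs(flm) ** 2                       # |f_lm|^2, shape (L, 2L-1)
    weights = jnp.abs(filters) ** 2                       # |w_j(l)|^2, shape (J, L)
    total = jnp.einsum("jl,lm->j", weights, coeff_power)  # sum over l and m
    # Parseval with orthonormal harmonics: spherical mean = total / (4 pi).
    return total / (4.0 * jnp.pi)


# Usage: a constant signal f(x) = 1 has f_00 = sqrt(4 pi) and no other terms.
L = 4
flm = jnp.zeros((L, 2 * L - 1), dtype=jnp.complex64)
flm = flm.at[0, L - 1].set(jnp.sqrt(4.0 * jnp.pi))
filters = jnp.array([[1.0, 0.0, 0.0, 0.0],   # passes only l = 0
                     [0.0, 1.0, 1.0, 1.0]])  # passes l >= 1
power = per_scale_mean_power(flm, filters)
```

All of the constant signal's power falls in the first filter, so the sketch reports a mean power of 1 for that scale and 0 for the other.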
- s2scat.optimisation.l2_covariance_loss(predicts, targets) → float64
L2 loss wrapper for a list of scattering covariances.
- Parameters:
predicts (List[jnp.ndarray]) – Predicted scattering covariances.
targets (List[jnp.ndarray]) – Target scattering covariances.
- Returns:
L2 loss.
- Return type:
jnp.float64
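A wrapper of this kind can be sketched as a loop over paired statistics. This assumes each term is the sum of squared moduli of the residual and that the per-statistic terms are simply added; the reduction s2scat actually uses may differ, and `l2_covariance_loss_sketch` is a hypothetical name.

```python
import jax.numpy as jnp


def l2_covariance_loss_sketch(predicts, targets):
    """Add up squared-error terms over matching lists of statistics."""
    if len(predicts) != len(targets):
        raise ValueError("predicts and targets must have the same length")
    loss = 0.0
    for predict, target in zip(predicts, targets):
        # Statistics may be complex, so take the modulus before squaring.
        loss += jnp.sum(jnp.abs(predict - target) ** 2)
    return loss


# Usage: one real and one complex statistic.
targets = [jnp.array([1.0, 2.0]), jnp.array([1.0 + 1.0j])]
predicts = [jnp.array([1.0, 0.0]), jnp.array([0.0 + 1.0j])]
loss = l2_covariance_loss_sketch(predicts, targets)  # 2**2 + 1**2 = 5.0
```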
- s2scat.optimisation.l2_loss(predict, target) → float64
L2 loss for a single scattering covariance.
- Parameters:
predict (jnp.ndarray) – Predicted scattering covariance.
target (jnp.ndarray) – Target scattering covariance.
- Returns:
L2 loss.
- Return type:
jnp.float64
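Because the loss is built from differentiable JAX operations, it can serve as the objective that fit_optax differentiates. The sketch below assumes the loss is the sum of squared moduli of the residual (the actual reduction, sum or mean, is not stated here); `l2_loss_sketch` is a hypothetical name. For real inputs its gradient with respect to `predict` is 2 (predict − target).

```python
import jax
import jax.numpy as jnp


def l2_loss_sketch(predict, target):
    """Sum of squared moduli of the residual (assumed reduction)."""
    return jnp.sum(jnp.abs(predict - target) ** 2)


target = jnp.array([0.5, -1.0, 2.0])
predict = jnp.array([1.0, 1.0, 1.0])

loss = l2_loss_sketch(predict, target)             # 0.25 + 4 + 1 = 5.25
grad = jax.grad(l2_loss_sketch)(predict, target)   # 2 * (predict - target)
```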