class harmonic.flows.RealNVP(n_features: int, n_scaled_layers: int = 2, n_unscaled_layers: int = 4, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Real-valued non-volume-preserving (RealNVP) flow implemented with flax and tfp-jax.

Parameters:
  • n_features (int) – Number of features in the data.

  • n_scaled_layers (int, optional) – Number of scaled affine coupling layers in the flow; must be positive. Defaults to 2.

  • n_unscaled_layers (int, optional) – Number of unscaled affine coupling layers in the flow. Defaults to 4.
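A single affine coupling step shows how scaled and unscaled layers differ. The pure-Python sketch below is not harmonic's implementation. It assumes, for illustration only, that an "unscaled" layer applies only a shift, so it adds nothing to the log-Jacobian determinant. The hypothetical `toy_shift` and `toy_log_scale` functions stand in for the learned conditioner networks.

```python
import math
from typing import Callable, List, Optional, Tuple

Vector = List[float]
Conditioner = Callable[[Vector], Vector]


def affine_coupling_forward(
    x: Vector, shift_fn: Conditioner, log_scale_fn: Optional[Conditioner]
) -> Tuple[Vector, float]:
    """One affine coupling step; returns (y, log|det J|).

    The first half of x is copied unchanged and conditions the update of the
    second half. log_scale_fn=None gives a shift-only ("unscaled") layer
    (an assumption made for this sketch).
    """
    d = len(x) // 2
    x_a, x_b = x[:d], x[d:]
    shift = shift_fn(x_a)
    log_scale = log_scale_fn(x_a) if log_scale_fn is not None else [0.0] * len(x_b)
    y_b = [xb * math.exp(s) + t for xb, s, t in zip(x_b, log_scale, shift)]
    # The Jacobian is triangular, so log|det J| is the sum of the log-scales.
    return x_a + y_b, sum(log_scale)


def affine_coupling_inverse(
    y: Vector, shift_fn: Conditioner, log_scale_fn: Optional[Conditioner]
) -> Vector:
    """Exact inverse: y_a equals x_a, so the same conditioner outputs are recovered."""
    d = len(y) // 2
    y_a, y_b = y[:d], y[d:]
    shift = shift_fn(y_a)
    log_scale = log_scale_fn(y_a) if log_scale_fn is not None else [0.0] * len(y_b)
    return y_a + [(yb - t) * math.exp(-s) for yb, s, t in zip(y_b, log_scale, shift)]


# Hypothetical stand-ins for the learned conditioner networks.
def toy_shift(a: Vector) -> Vector:
    return [a[0] + a[1], a[0] - a[1]]


def toy_log_scale(a: Vector) -> Vector:
    return [0.1 * a[0], -0.2 * a[1]]
```

The first half of the input passes through unchanged. The inverse can therefore recompute the same shift and log-scale, which makes coupling layers cheap to invert in either direction.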

log_prob(x: array, temperature: float = 1.0) → array

Evaluate the log probability of the flow for a batched input.

Parameters:
  • x (jnp.ndarray (batch_size, ndim)) – Sample for which to predict posterior values.

  • temperature (float, optional) – Factor by which the base Gaussian's unit covariance matrix is scaled. Should be between 0 and 1 for use in evidence estimation. Defaults to 1.

Returns:

Predicted log_e posterior values, one per sample.

Return type:

jnp.ndarray (batch_size,)
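`temperature` scales the unit covariance, so the base density is N(0, T·I). In d dimensions its log density is -(d/2) log(2πT) - ||z||²/(2T). The flow then adds the log-Jacobian of the inverse map by the change-of-variables formula. The single-sample sketch below illustrates this with a hypothetical elementwise bijection f(z) = 2z + 1 (`toy_inverse`). It does not call harmonic, and the documented method evaluates a whole batch at once.

```python
import math
from typing import Callable, List, Sequence, Tuple


def base_log_prob(z: Sequence[float], temperature: float = 1.0) -> float:
    """Log density of the base Gaussian N(0, temperature * I) at z."""
    if temperature <= 0.0:
        raise ValueError("temperature must be positive")
    d = len(z)
    sq_norm = sum(zi * zi for zi in z)
    return -0.5 * d * math.log(2.0 * math.pi * temperature) - 0.5 * sq_norm / temperature


def flow_log_prob(
    x: Sequence[float],
    inverse_fn: Callable[[Sequence[float]], Tuple[List[float], float]],
    temperature: float = 1.0,
) -> float:
    """Change of variables for one sample:
    log p(x) = log N(f^{-1}(x); 0, T I) + log|det J_{f^{-1}}(x)|.
    """
    z, log_det_inv = inverse_fn(x)
    return base_log_prob(z, temperature) + log_det_inv


# Hypothetical bijection f(z) = 2 z + 1 (elementwise), supplied via its inverse.
def toy_inverse(x: Sequence[float]) -> Tuple[List[float], float]:
    z = [(xi - 1.0) / 2.0 for xi in x]
    return z, -len(x) * math.log(2.0)  # each coordinate is scaled by 1/2
```

With T < 1 the base Gaussian is narrower. It assigns higher density near the origin and lower density in the tails than at T = 1.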

make_flow(temperature: float = 1.0)

Make a tfp-jax distribution object containing the RealNVP flow.

Parameters:

temperature (float, optional) – Factor by which the base Gaussian's unit covariance matrix is scaled. Should be between 0 and 1 for use in evidence estimation. Defaults to 1.

Returns:

Base Gaussian transformed by the scaled affine coupling layers contained in the scaled_layers attribute, followed by the unscaled affine coupling layers contained in the unscaled_layers attribute.

Return type:

tfd.Distribution

Raises:

ValueError – If n_scaled_layers is not positive.
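The ordering described above starts from the base sample, applies the scaled layers, and then applies the unscaled layers. It can be sketched as a plain function composition in which the per-layer log-determinants add. The hypothetical `compose_flow`, `double` and `shift_by_one` helpers are illustrative only. harmonic returns a tfp-jax distribution rather than a Python closure, and this sketch only mirrors the documented positivity check on the number of scaled layers.

```python
import math
from typing import Callable, List, Sequence, Tuple

# Each hypothetical layer maps z -> (new z, log|det J| of that step).
Layer = Callable[[List[float]], Tuple[List[float], float]]


def compose_flow(scaled_layers: Sequence[Layer], unscaled_layers: Sequence[Layer]) -> Layer:
    """Chain layers starting from the base sample: scaled layers first, then unscaled."""
    if len(scaled_layers) < 1:
        raise ValueError("n_scaled_layers must be positive")
    layers = list(scaled_layers) + list(unscaled_layers)

    def forward(z: List[float]) -> Tuple[List[float], float]:
        total_log_det = 0.0
        for layer in layers:
            z, log_det = layer(z)
            total_log_det += log_det  # log-determinants add under composition
        return z, total_log_det

    return forward


def double(z: List[float]) -> Tuple[List[float], float]:
    """Hypothetical scaled layer: z -> 2 z."""
    return [2.0 * zi for zi in z], len(z) * math.log(2.0)


def shift_by_one(z: List[float]) -> Tuple[List[float], float]:
    """Hypothetical unscaled layer: z -> z + 1 (volume preserving)."""
    return [zi + 1.0 for zi in z], 0.0
```

Order matters. Applying `shift_by_one` before `double` would map [1.0, -1.0] to [4.0, 0.0] instead of [3.0, -1.0].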

sample(rng: PRNGKey, num_samples: int, temperature: float = 1.0) → array

Sample from the flow.

Parameters:
  • rng (Union[Array, PRNGKeyArray]) – Key used in random number generation process.

  • num_samples (int) – Number of samples generated.

  • temperature (float, optional) – Factor by which the base Gaussian's unit covariance matrix is scaled. Should be between 0 and 1 for use in evidence estimation. Defaults to 1.

Returns:

Samples from the fitted distribution.

Return type:

jnp.ndarray (num_samples, ndim)
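Sampling runs the flow in the forward direction. It draws z from N(0, T·I) by scaling standard normal draws by √T and then maps z through the flow. The sketch below assumes a hypothetical `forward_fn` and uses Python's `random.Random` in place of the JAX PRNG key that the documented `sample` method takes. It returns a list of lists rather than a jnp.ndarray.

```python
import math
import random
from typing import Callable, List


def sample_sketch(
    rng: random.Random,
    num_samples: int,
    ndim: int,
    forward_fn: Callable[[List[float]], List[float]],
    temperature: float = 1.0,
) -> List[List[float]]:
    """Draw z ~ N(0, temperature * I) and push each draw through the flow."""
    std = math.sqrt(temperature)  # covariance T * I  =>  per-dimension std sqrt(T)
    samples = []
    for _ in range(num_samples):
        z = [std * rng.gauss(0.0, 1.0) for _ in range(ndim)]
        samples.append(forward_fn(z))
    return samples
```

With an identity `forward_fn`, each coordinate of the samples should have variance close to `temperature`. This gives a quick sanity check on how the base covariance is scaled.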

setup()

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_and_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(...)
    ...
    ...     # Accessing `submodule` attributes does not yet work here.
    ...
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    ...
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # causes setup() to be called once.
    
  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

class harmonic.flows.RQSpline(n_features: int, num_layers: int, hidden_size: ~typing.Sequence[int], num_bins: int, spline_range: ~typing.Sequence[float] = (-10.0, 10.0), parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Rational quadratic spline normalizing flow model using distrax.

Parameters:
  • n_features (int) – Number of features in the data.

  • num_layers (int) – Number of layers in the flow.

  • hidden_size (Sequence[int]) – Sizes of the hidden layers in the conditioner network.

  • num_bins (int) – Number of bins in the spline.

  • spline_range (Sequence[float], optional) – Range of the spline. Defaults to (-10.0, 10.0).

Note

Adapted from github.com/kazewong/flowMC
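`num_bins` and `spline_range` are easier to understand with a scalar example of the monotonic rational-quadratic spline transform from Durkan et al. (2019), "Neural Spline Flows". The sketch below is not the distrax implementation. It assumes explicitly supplied knots and knot derivatives, which in the flow would come from the conditioner network. It treats the first and last knots as the spline range and uses identity tails outside that range.

```python
import bisect
import math
from typing import Sequence, Tuple


def rq_spline_forward(
    x: float, xs: Sequence[float], ys: Sequence[float], derivs: Sequence[float]
) -> Tuple[float, float]:
    """Monotonic rational-quadratic spline for a scalar x; returns (y, log dy/dx).

    xs, ys: strictly increasing knot positions (num_bins + 1 each).
    derivs: positive slopes at the knots.
    Outside [xs[0], xs[-1]] the map is the identity (assumed tails); for a
    continuous, smooth join use ys[0] == xs[0], ys[-1] == xs[-1] and unit
    boundary slopes.
    """
    if x <= xs[0] or x >= xs[-1]:
        return x, 0.0
    k = bisect.bisect_right(xs, x) - 1  # index of the bin containing x
    w = xs[k + 1] - xs[k]  # bin width
    h = ys[k + 1] - ys[k]  # bin height
    s = h / w  # average slope in the bin
    xi = (x - xs[k]) / w  # relative position in [0, 1)
    d0, d1 = derivs[k], derivs[k + 1]
    denom = s + (d1 + d0 - 2.0 * s) * xi * (1.0 - xi)
    y = ys[k] + h * (s * xi ** 2 + d0 * xi * (1.0 - xi)) / denom
    dy_dx = s ** 2 * (d1 * xi ** 2 + 2.0 * s * xi * (1.0 - xi) + d0 * (1.0 - xi) ** 2) / denom ** 2
    return y, math.log(dy_dx)
```

When all knot slopes equal the bin slopes, for example `ys == xs` with unit derivatives, the spline reduces to the identity. More bins give the transform more freedom within the same range.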

log_prob(x: array, temperature: float = 1.0) → array

Evaluate the log probability of the flow for a batched input.

Parameters:
  • x (jnp.ndarray (batch_size, ndim)) – Sample for which to predict posterior values.

  • temperature (float, optional) – Factor by which the base Gaussian's unit covariance matrix is scaled. Should be between 0 and 1 for use in evidence estimation. Defaults to 1.

Returns:

Predicted log_e posterior values, one per sample.

Return type:

jnp.ndarray (batch_size,)

make_flow(temperature: float = 1.0)

Make a distrax distribution containing the rational quadratic spline flow.

Parameters:

temperature (float, optional) – Factor by which the base Gaussian's unit covariance matrix is scaled. Should be between 0 and 1 for use in evidence estimation. Defaults to 1.

Returns:

Base Gaussian transformed by the rational quadratic spline flow.

sample(rng: PRNGKey, num_samples: int, temperature: float = 1.0) → array

Sample from the flow.

Parameters:
  • rng (Union[Array, PRNGKeyArray]) – Key used in random number generation process.

  • num_samples (int) – Number of samples generated.

  • temperature (float, optional) – Factor by which the base Gaussian's unit covariance matrix is scaled. Should be between 0 and 1 for use in evidence estimation. Defaults to 1.

Returns:

Samples from the fitted distribution.

Return type:

jnp.ndarray (num_samples, ndim)

setup()

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_and_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(...)
    ...
    ...     # Accessing `submodule` attributes does not yet work here.
    ...
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    ...
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # causes setup() to be called once.
    
  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.