- class harmonic.flows.RealNVP(n_features: int, n_scaled_layers: int = 2, n_unscaled_layers: int = 4, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Real-valued non-volume-preserving (RealNVP) flow implemented with flax and tfp-jax.
- Parameters:
n_features (int) – Number of features in the data.
n_scaled_layers (int, optional) – Number of scaled affine coupling layers in the flow; must be positive. Defaults to 2.
n_unscaled_layers (int, optional) – Number of unscaled affine coupling layers in the flow. Defaults to 4.
- log_prob(x: array, temperature: float = 1.0) → array
Evaluate the log probability of the flow for a batched input.
- Parameters:
x (jnp.ndarray (batch_size, ndim)) – Sample for which to predict posterior values.
temperature (float, optional) – Factor by which the unit covariance matrix of the base Gaussian is scaled. Should be between 0 and 1 for use in evidence estimation. Defaults to 1.
- Returns:
Predicted log_e posterior value.
- Return type:
jnp.ndarray (batch_size,)
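To make the role of temperature concrete, here is a minimal NumPy sketch (not harmonic's implementation) of a tempered log density computed with the change-of-variables formula. It assumes the base distribution is N(0, T·I), where T is the temperature, and uses a hypothetical toy bijection f(z) = 2z + 1 in place of a trained flow; `tempered_gaussian_log_prob`, `flow_log_prob` and `toy_inverse` are illustrative names only.

```python
import numpy as np


def tempered_gaussian_log_prob(z: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """Log density of N(0, temperature * I) for z of shape (batch_size, ndim)."""
    ndim = z.shape[-1]
    # Scaling the unit covariance by T gives variance T in every dimension.
    return (-0.5 * ndim * np.log(2.0 * np.pi * temperature)
            - 0.5 * np.sum(z**2, axis=-1) / temperature)


def flow_log_prob(x, inverse, temperature=1.0):
    """Change of variables: log p(x) = log p_z(f^{-1}(x)) + log|det J_{f^{-1}}(x)|.

    `inverse` is a hypothetical callable returning (z, log|det| of the inverse Jacobian).
    """
    z, log_det = inverse(x)
    return tempered_gaussian_log_prob(z, temperature) + log_det


def toy_inverse(x):
    """Inverse of the toy bijection f(z) = 2 z + 1, standing in for a trained flow."""
    z = (x - 1.0) / 2.0
    # Each dimension is divided by 2, so log|det| = -ndim * log(2) for every sample.
    log_det = np.full(x.shape[0], -x.shape[-1] * np.log(2.0))
    return z, log_det


x = np.array([[1.0, 1.0], [3.0, -1.0]])  # shape (batch_size, ndim) = (2, 2)
lp_t1 = flow_log_prob(x, toy_inverse, temperature=1.0)   # shape (2,)
lp_t05 = flow_log_prob(x, toy_inverse, temperature=0.5)
```

Lowering the temperature below 1 narrows the base Gaussian, so points near the mode gain log probability while points in the tails lose it.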
- make_flow(temperature: float = 1.0)
Make a tfp-jax distribution object containing the RealNVP flow.
- Parameters:
temperature (float, optional) – Factor by which the unit covariance matrix of the base Gaussian is scaled. Should be between 0 and 1 for use in evidence estimation. Defaults to 1.
- Returns:
Base Gaussian transformed by the scaled affine coupling layers contained in the scaled_layers attribute, followed by the unscaled affine coupling layers contained in the unscaled_layers attribute.
- Return type:
tfd.Distribution
- Raises:
ValueError – If n_scaled_layers is not positive.
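The layer stack described above can be sketched in NumPy as follows. This is a toy under stated assumptions, not harmonic's tfp-jax code: a random linear map stands in for each coupling layer's conditioner network, "unscaled" layers are assumed to apply only a shift (so their log-determinant is zero), and the feature order is reversed after each layer so both halves get transformed. `make_coupling` and `make_stack` are hypothetical helpers.

```python
import numpy as np


def make_coupling(n_features, rng, scaled=True):
    """Return (forward, inverse) for a toy affine coupling layer.

    The first half of the features conditions a scale and shift applied to the
    second half. If scaled=False, only the shift is applied (log|det| = 0).
    """
    d = n_features // 2
    # Hypothetical conditioner weights; a real flow would use a neural network.
    w_s = 0.1 * rng.standard_normal((d, n_features - d))
    w_t = rng.standard_normal((d, n_features - d))

    def conditioner(x1):
        # tanh keeps the log-scale bounded, a common stabilising choice.
        if scaled:
            log_scale = np.tanh(x1 @ w_s)
        else:
            log_scale = np.zeros((x1.shape[0], n_features - d))
        return log_scale, x1 @ w_t

    def forward(z):
        z1, z2 = z[:, :d], z[:, d:]
        log_scale, shift = conditioner(z1)
        x2 = z2 * np.exp(log_scale) + shift
        return np.concatenate([z1, x2], axis=1), np.sum(log_scale, axis=1)

    def inverse(x):
        x1, x2 = x[:, :d], x[:, d:]
        log_scale, shift = conditioner(x1)
        z2 = (x2 - shift) * np.exp(-log_scale)
        return np.concatenate([x1, z2], axis=1), -np.sum(log_scale, axis=1)

    return forward, inverse


def make_stack(n_features, n_scaled_layers=2, n_unscaled_layers=4, seed=0):
    """Scaled coupling layers followed by shift-only layers, as in make_flow."""
    if n_scaled_layers <= 0:
        raise ValueError("n_scaled_layers must be positive")
    rng = np.random.default_rng(seed)
    layers = [make_coupling(n_features, rng, scaled=True) for _ in range(n_scaled_layers)]
    layers += [make_coupling(n_features, rng, scaled=False) for _ in range(n_unscaled_layers)]

    def forward(z):
        log_det = np.zeros(z.shape[0])
        for fwd, _ in layers:
            z, ld = fwd(z)
            z = z[:, ::-1]  # reverse features so the other half is transformed next
            log_det += ld
        return z, log_det

    def inverse(x):
        log_det = np.zeros(x.shape[0])
        for _, inv in reversed(layers):
            x = x[:, ::-1]  # undo the reversal applied after this layer
            x, ld = inv(x)
            log_det += ld
        return x, log_det

    return forward, inverse
```

Because each coupling layer only rescales half of the features, its Jacobian is triangular and the log-determinant is just the sum of the log-scales, which is what makes exact log probabilities cheap.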
- sample(rng: PRNGKey, num_samples: int, temperature: float = 1.0) → array
Sample from the flow.
- Parameters:
rng (Union[Array, PRNGKeyArray]) – Key used in the random number generation process.
num_samples (int) – Number of samples generated.
temperature (float, optional) – Factor by which the unit covariance matrix of the base Gaussian is scaled. Should be between 0 and 1 for use in evidence estimation. Defaults to 1.
- Returns:
Samples from fitted distribution.
- Return type:
jnp.ndarray (num_samples, ndim)
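Sampling amounts to drawing from the tempered base Gaussian and pushing the draws through the flow's forward map. The sketch below assumes a NumPy `Generator` in place of the JAX PRNG key (in JAX the draw would be `jax.random.normal(rng, (num_samples, ndim))`) and a hypothetical `forward` callable returning `(x, log_det)`; it illustrates the idea rather than reproducing harmonic's method.

```python
import numpy as np


def sample_tempered(rng, num_samples, ndim, forward, temperature=1.0):
    """Draw z ~ N(0, temperature * I) and map it through the flow's forward pass."""
    # Standard normal draws scaled by sqrt(T) have covariance T * I.
    z = np.sqrt(temperature) * rng.standard_normal((num_samples, ndim))
    x, _ = forward(z)  # the log-determinant is not needed for sampling
    return x  # shape (num_samples, ndim)


def identity_forward(z):
    """Hypothetical stand-in for a trained flow: the identity map."""
    return z, np.zeros(z.shape[0])


rng = np.random.default_rng(0)
samples = sample_tempered(rng, num_samples=100_000, ndim=3,
                          forward=identity_forward, temperature=0.5)
```

With the identity map the samples are simply the tempered base draws, so their per-dimension variance sits close to the temperature.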
- setup()
Initializes a Module lazily (similar to a lazy __init__).
setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.
This can happen in three cases:
1. Immediately when invoking apply(), init() or init_and_output().
2. Once the module is given a name by being assigned to an attribute of another module inside the other module's setup method (see __setattr__()):
>>> class MyModule(nn.Module):
...   def setup(self):
...     submodule = nn.Conv(...)
...
...     # Accessing `submodule` attributes does not yet work here.
...
...     # The following line invokes `self.__setattr__`, which gives
...     # `submodule` the name "conv1".
...     self.conv1 = submodule
...
...     # Accessing `submodule` attributes or methods is now safe and
...     # either causes setup() to be called once.
3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.
- class harmonic.flows.RQSpline(n_features: int, num_layers: int, hidden_size: ~typing.Sequence[int], num_bins: int, spline_range: ~typing.Sequence[float] = (-10.0, 10.0), parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Rational quadratic spline normalizing flow model using distrax.
- Parameters:
n_features (int) – Number of features in the data.
num_layers (int) – Number of layers in the flow.
hidden_size (Sequence[int]) – Sizes of the hidden layers in the conditioner.
num_bins (int) – Number of bins in the spline.
spline_range (Sequence[float], optional) – Lower and upper bounds of the interval on which the spline is defined. Defaults to (-10.0, 10.0).
Note
Adapted from github.com/kazewong/flowMC
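Each layer of this flow applies a monotonic rational quadratic spline to its inputs. The sketch below evaluates one such spline elementwise in NumPy, following the per-bin formula of Durkan et al. (Neural Spline Flows); distrax provides a production version as `distrax.RationalQuadraticSpline`. The knot values are made up for illustration, spline_range is assumed to be the interval covered by the knots, and inputs outside it are assumed to pass through unchanged, with unit slopes at the edges so the map and its derivative stay continuous.

```python
import numpy as np


def rq_spline_forward(x, x_knots, y_knots, derivs):
    """Monotonic rational quadratic spline applied elementwise.

    x_knots, y_knots: increasing knot positions, shape (num_bins + 1,).
    derivs: positive derivatives at the knots, shape (num_bins + 1,).
    Inputs outside [x_knots[0], x_knots[-1]] are returned unchanged.
    Returns (y, dy/dx).
    """
    x = np.asarray(x, dtype=float)
    inside = (x >= x_knots[0]) & (x <= x_knots[-1])
    # Evaluate on a clipped copy so out-of-range inputs cannot break the formula.
    xc = np.clip(x, x_knots[0], x_knots[-1])
    # Index k of the bin containing each point (clipped so the right edge is valid).
    k = np.clip(np.searchsorted(x_knots, xc, side="right") - 1, 0, len(x_knots) - 2)
    w = x_knots[k + 1] - x_knots[k]   # bin width
    h = y_knots[k + 1] - y_knots[k]   # bin height
    s = h / w                         # average slope in the bin
    xi = (xc - x_knots[k]) / w        # position within the bin, in [0, 1]
    d0, d1 = derivs[k], derivs[k + 1]
    denom = s + (d1 + d0 - 2.0 * s) * xi * (1.0 - xi)
    y = y_knots[k] + h * (s * xi**2 + d0 * xi * (1.0 - xi)) / denom
    dy = s**2 * (d1 * xi**2 + 2.0 * s * xi * (1.0 - xi) + d0 * (1.0 - xi) ** 2) / denom**2
    return np.where(inside, y, x), np.where(inside, dy, 1.0)


x_knots = np.linspace(-10.0, 10.0, 5)              # 4 bins over spline_range = (-10, 10)
y_knots = np.array([-10.0, -7.0, 0.0, 2.0, 10.0])  # increasing, same endpoints as x_knots
derivs = np.array([1.0, 0.5, 2.0, 1.5, 1.0])       # positive, unit slope at both edges

y, dy = rq_spline_forward(np.array([-12.0, -10.0, -5.0, 3.7, 10.0, 12.0]),
                          x_knots, y_knots, derivs)
```

Because every derivative and bin slope is positive, the spline is strictly increasing and therefore invertible, and dy/dx supplies the Jacobian term needed for log_prob.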
- log_prob(x: array, temperature: float = 1.0) → array
Evaluate the log probability of the flow for a batched input.
- Parameters:
x (jnp.ndarray (batch_size, ndim)) – Sample for which to predict posterior values.
temperature (float, optional) – Factor by which the unit covariance matrix of the base Gaussian is scaled. Should be between 0 and 1 for use in evidence estimation. Defaults to 1.
- Returns:
Predicted log_e posterior value.
- Return type:
jnp.ndarray (batch_size,)
- make_flow(temperature: float = 1.0)
Make a distrax distribution containing the rational quadratic spline flow.
- Parameters:
temperature (float, optional) – Factor by which the unit covariance matrix of the base Gaussian is scaled. Should be between 0 and 1 for use in evidence estimation. Defaults to 1.
- Returns:
Base Gaussian transformed by rational quadratic spline flow.
- sample(rng: PRNGKey, num_samples: int, temperature: float = 1.0) → array
Sample from the flow.
- Parameters:
rng (Union[Array, PRNGKeyArray]) – Key used in the random number generation process.
num_samples (int) – Number of samples generated.
temperature (float, optional) – Factor by which the unit covariance matrix of the base Gaussian is scaled. Should be between 0 and 1 for use in evidence estimation. Defaults to 1.
- Returns:
Samples from fitted distribution.
- Return type:
jnp.ndarray (num_samples, ndim)
- setup()
Initializes a Module lazily (similar to a lazy __init__). This docstring is inherited from flax.linen.Module.setup; see RealNVP.setup() above for the three cases in which it is called.