Reparameterizers - NumPyro documentation
The numpyro.infer.reparam module contains reparameterization strategies for the numpyro.handlers.reparam effect. These are useful for altering the geometry of a poorly conditioned parameter space so that the posterior is better shaped. They can be used with a variety of inference algorithms, e.g. Auto*Normal guides and MCMC.
class Reparam[source]
Bases: ABC
Base class for reparameterizers.
Loc-Scale Decentering
class LocScaleReparam(centered=None, shape_params=())[source]
Bases: Reparam
Generic decentering reparameterizer [1] for latent variables parameterized by loc and scale (and possibly additional shape_params).
This reparameterization works only for latent variables, not likelihoods.
References:
[1] Automatic Reparameterisation of Probabilistic Programs, Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)
Parameters:
- centered (float) – optional centered parameter. If None (default), learn a per-site per-element centering parameter in [0,1] initialized at value 0.5. To sample the parameter, consider using the lift handler with a prior like Uniform(0, 1) to cast the parameter to a latent variable. If 0, fully decenter the distribution; if 1, preserve the centered distribution unchanged.
- shape_params (tuple or list) – list of additional parameter names to copy unchanged from the centered to decentered distribution.
__call__(name, fn, obs)[source]
Parameters:
- name (str) – A sample site name.
- fn (Distribution) – A distribution.
- obs (numpy.ndarray) – Observed value or None.
Returns:
A pair (new_fn, value).
Neural Transport
class NeuTraReparam(guide, params)[source]
Bases: Reparam
Neural Transport reparameterizer [1] of multiple latent variables.
This uses a trained AutoContinuous guide to alter the geometry of a model, typically for use in MCMC. Example usage:
Step 1. Train a guide
guide = AutoIAFNormal(model)
svi = SVI(model, guide, ...)
...train the guide and keep its trained parameters as params...
Step 2. Use trained guide in NeuTra MCMC
neutra = NeuTraReparam(guide, params)
model = neutra.reparam(model)
nuts = NUTS(model)
...now use the model in HMC or NUTS...
This reparameterization works only for latent variables, not likelihoods. Note that all sites must share a single common NeuTraReparam instance, and that the model must have static structure.
[1] Hoffman, M. et al. (2019) “NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport” https://arxiv.org/abs/1903.03704
Parameters:
- guide (AutoContinuous) – A guide.
- params – trained parameters of the guide.
__call__(name, fn, obs)[source]
Parameters:
- name (str) – A sample site name.
- fn (Distribution) – A distribution.
- obs (numpy.ndarray) – Observed value or None.
Returns:
A pair (new_fn, value).
transform_sample(latent)[source]
Given latent samples from the warped posterior (with possible batch dimensions), return a dict of samples from the latent sites in the model.
Parameters:
latent – sample from the warped posterior (possibly batched).
Returns:
a dict of samples keyed by latent sites in the model.
Return type:
dict
Transformed Distributions
class TransformReparam[source]
Bases: Reparam
Reparameterizer for TransformedDistribution latent variables.
This is useful for transformed distributions with complex, geometry-changing transforms, where the posterior has simple shape in the space of base_dist.
This reparameterization works only for latent variables, not likelihoods.
__call__(name, fn, obs)[source]
Parameters:
- name (str) – A sample site name.
- fn (Distribution) – A distribution.
- obs (numpy.ndarray) – Observed value or None.
Returns:
A pair (new_fn, value).
Projected Normal Distributions
class ProjectedNormalReparam[source]
Bases: Reparam
Reparameterizer for ProjectedNormal latent variables.
This reparameterization works only for latent variables, not likelihoods.
__call__(name, fn, obs)[source]
Parameters:
- name (str) – A sample site name.
- fn (Distribution) – A distribution.
- obs (numpy.ndarray) – Observed value or None.
Returns:
A pair (new_fn, value).
Circular Distributions
class CircularReparam[source]
Bases: Reparam
Reparameterizer for VonMises latent variables.
__call__(name, fn, obs)[source]
Parameters:
- name (str) – A sample site name.
- fn (Distribution) – A distribution.
- obs (numpy.ndarray) – Observed value or None.
Returns:
A pair (new_fn, value).
Explicit Reparameterization
class ExplicitReparam(transform)[source]
Bases: Reparam
Explicit reparameterizer of a latent variable x to a transformed space y = transform(x) with more amenable geometry. This reparameterizer is similar to TransformReparam but allows reparameterizations to be decoupled from the model declaration.
Parameters:
transform – Bijective transform to the reparameterized space.
Example:
from jax import random
from jax import numpy as jnp
import numpyro
from numpyro import handlers, distributions as dist
from numpyro.infer import MCMC, NUTS
from numpyro.infer.reparam import ExplicitReparam

def model():
    numpyro.sample("x", dist.Gamma(4, 4))

# Sample in unconstrained space using a soft-plus instead of exp transform.
reparam = ExplicitReparam(dist.transforms.SoftplusTransform().inv)
reparametrized = handlers.reparam(model, {"x": reparam})
kernel = NUTS(model=reparametrized)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=1)
mcmc.run(random.PRNGKey(2))
sample: 100%|██████████| 2000/2000 [00:00<00:00, 2306.47it/s, 3 steps of size 9.65e-01. acc. prob=0.93]
__call__(name, fn, obs)[source]
Parameters:
- name (str) – A sample site name.
- fn (Distribution) – A distribution.
- obs (numpy.ndarray) – Observed value or None.
Returns:
A pair (new_fn, value).