Reparameterizers

The numpyro.infer.reparam module contains reparameterization strategies for the numpyro.handlers.reparam effect handler. These are useful for altering the geometry of a poorly conditioned parameter space so that the posterior is better shaped. They can be used with a variety of inference algorithms, e.g. Auto*Normal guides and MCMC.

class Reparam

Bases: ABC

Base class for reparameterizers.

Loc-Scale Decentering

class LocScaleReparam(centered=None, shape_params=())

Bases: Reparam

Generic decentering reparameterizer [1] for latent variables parameterized by loc and scale (and possibly additional shape_params).

This reparameterization works only for latent variables, not likelihoods.

References:

[1] Gorinova, M. I., Moore, D., and Hoffman, M. D. (2019). “Automatic Reparameterisation of Probabilistic Programs”.

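Parameters:

centered – optional centered parameter. If None (default), learn a per-site, per-element centering parameter in [0, 1]. If 0, fully decenter the distribution; if 1, preserve the centered distribution unchanged.
shape_params – list of additional parameter names to copy unchanged from the centered to the decentered distribution.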

__call__(name, fn, obs)

Parameters:

name – A sample site name.
fn – A distribution.
obs – Observed value or None.

Returns:

A pair (new_fn, value).

Neural Transport

class NeuTraReparam(guide, params)

Bases: Reparam

Neural Transport reparameterizer [1] of multiple latent variables.

This uses a trained AutoContinuous guide to alter the geometry of a model, typically for use in MCMC. Example usage:

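```python
from numpyro.infer import NUTS, SVI
from numpyro.infer.autoguide import AutoIAFNormal
from numpyro.infer.reparam import NeuTraReparam

# Step 1. Train a guide
guide = AutoIAFNormal(model)
svi = SVI(model, guide, ...)
# ...train the guide, e.g. svi_result = svi.run(...)...

# Step 2. Use the trained guide in NeuTra MCMC
neutra = NeuTraReparam(guide, svi_result.params)
model = neutra.reparam(model)
nuts = NUTS(model)
# ...now use the model in HMC or NUTS...
```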

This reparameterization works only for latent variables, not likelihoods. Note that all sites must share a single common NeuTraReparam instance, and that the model must have static structure.

[1] Hoffman, M. et al. (2019). “NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport”. https://arxiv.org/abs/1903.03704

Parameters:

guide – A trained AutoContinuous guide.
params – Trained parameters of the guide.

reparam(fn=None)

__call__(name, fn, obs)

Parameters:

name – A sample site name.
fn – A distribution.
obs – Observed value or None.

Returns:

A pair (new_fn, value).

transform_sample(latent)

Given latent samples from the warped posterior (with possible batch dimensions), return a dict of samples from the latent sites in the model.

Parameters:

latent – sample from the warped posterior (possibly batched).

Returns:

a dict of samples keyed by latent sites in the model.

Return type:

dict

Transformed Distributions

class TransformReparam

Bases: Reparam

Reparameterizer for TransformedDistribution latent variables.

This is useful for transformed distributions with complex, geometry-changing transforms, where the posterior has a simple shape in the space of base_dist.

This reparameterization works only for latent variables, not likelihoods.

__call__(name, fn, obs)

Parameters:

name – A sample site name.
fn – A distribution.
obs – Observed value or None.

Returns:

A pair (new_fn, value).

Projected Normal Distributions

class ProjectedNormalReparam

Bases: Reparam

Reparameterizer for ProjectedNormal latent variables.

This reparameterization works only for latent variables, not likelihoods.

__call__(name, fn, obs)

Parameters:

name – A sample site name.
fn – A distribution.
obs – Observed value or None.

Returns:

A pair (new_fn, value).

Circular Distributions

class CircularReparam

Bases: Reparam

Reparameterizer for VonMises latent variables.

__call__(name, fn, obs)

Parameters:

name – A sample site name.
fn – A distribution.
obs – Observed value or None.

Returns:

A pair (new_fn, value).

Explicit Reparameterization

class ExplicitReparam(transform)

Bases: Reparam

Explicit reparameterizer of a latent variable x to a transformed space y = transform(x) with more amenable geometry. This reparameterizer is similar to TransformReparam but allows reparameterizations to be decoupled from the model declaration.

Parameters:

transform – Bijective transform to the reparameterized space.

Example:

```python
from jax import random
import numpyro
from numpyro import handlers, distributions as dist
from numpyro.infer import MCMC, NUTS
from numpyro.infer.reparam import ExplicitReparam

def model():
    numpyro.sample("x", dist.Gamma(4, 4))

# Sample in unconstrained space using a softplus instead of an exp transform.
reparam = ExplicitReparam(dist.transforms.SoftplusTransform().inv)
reparametrized = handlers.reparam(model, {"x": reparam})
kernel = NUTS(model=reparametrized)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=1)
mcmc.run(random.PRNGKey(2))
# sample: 100%|██████████| 2000/2000 [00:00<00:00, 2306.47it/s, 3 steps of size 9.65e-01. acc. prob=0.93]
```

__call__(name, fn, obs)

Parameters:

name – A sample site name.
fn – A distribution.
obs – Observed value or None.

Returns:

A pair (new_fn, value).