# numpyro.infer.reparam — NumPyro documentation (original) (raw)

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
import math
from typing import Iterable

import numpy as np

import jax
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.distributions import biject_to, constraints
from numpyro.distributions.util import (
    is_identically_one,
    safe_normalize,
    sum_rightmost,
)
from numpyro.infer.autoguide import AutoContinuous
from numpyro.util import not_jax_tracer

class Reparam(ABC):
    """
    Base class for reparameterizers.

    Subclasses implement :meth:`__call__`; the ``_unwrap`` / ``_wrap`` helpers
    strip and re-apply ``Independent`` / ``ExpandedDistribution`` wrappers
    around a distribution.
    """

    @abstractmethod
    def __call__(self, name, fn, obs):
        """
        :param str name: A sample site name.
        :param ~numpyro.distributions.Distribution fn: A distribution.
        :param numpy.ndarray obs: Observed value or None.
        :return: A pair (``new_fn``, ``value``).
        """
        return fn, obs

    def _unwrap(self, fn):
        """
        Unwrap Independent(...) and ExpandedDistribution(...) distributions.
        We can recover the input `fn` from the result triple `(fn, expand_shape, event_dim)`
        with `fn.expand(expand_shape).to_event(event_dim - fn.event_dim)`.

        :param fn: A (possibly wrapped) distribution.
        :return: A triple ``(base_fn, expand_shape, event_dim)`` where
            ``event_dim`` is the event dimension of the *input* distribution.
        """
        # Record the outer shape and event_dim before peeling any wrappers.
        shape = fn.shape()
        event_dim = fn.event_dim
        while isinstance(fn, (dist.Independent, dist.ExpandedDistribution)):
            fn = fn.base_dist
        # Everything left of the innermost distribution's event dims is the
        # shape the base distribution has to be expanded to.
        expand_shape = shape[: len(shape) - fn.event_dim]
        return fn, expand_shape, event_dim

    def _wrap(self, fn, expand_shape, event_dim):
        """
        Wrap in Independent and ExpandedDistribution distributions.

        :param fn: A distribution whose ``event_dim`` is at most ``event_dim``.
        :param expand_shape: Shape passed to ``fn.expand``.
        :param int event_dim: Event dimension the result must have.
        :return: The expanded distribution with exactly ``event_dim`` event
            dimensions.
        """
        # Match batch_shape.
        assert fn.event_dim <= event_dim
        fn = fn.expand(expand_shape)  # no-op if expand_shape == fn.batch_shape

        # Match event_dim.
        if fn.event_dim < event_dim:
            fn = fn.to_event(event_dim - fn.event_dim)
        assert fn.event_dim == event_dim
        return fn

class LocScaleReparam(Reparam):
    """
    Generic decentering reparameterizer [1] for latent variables parameterized
    by ``loc`` and ``scale`` (and possibly additional ``shape_params``).

    This reparameterization works only for latent variables, not likelihoods.

    **References:**

    1. *Automatic Reparameterisation of Probabilistic Programs*,
       Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)

    :param float centered: optional centered parameter. If None (default) learn
        a per-site per-element centering parameter in ``[0,1]`` initialized at value 0.5.
        To sample the parameter, consider using :class:`~numpyro.handlers.lift` handler with a
        prior like ``Uniform(0, 1)`` to cast the parameter to a latent variable. If 0, fully
        decenter the distribution; if 1, preserve the centered distribution
        unchanged.
    :param shape_params: list of additional parameter names to copy unchanged from
        the centered to decentered distribution.
    :type shape_params: tuple or list
    """

    def __init__(self, centered=None, shape_params=()):
        assert centered is None or isinstance(
            centered, (int, float, np.generic, np.ndarray, jnp.ndarray, jax.core.Tracer)
        )
        assert isinstance(shape_params, (tuple, list))
        assert all(isinstance(name, str) for name in shape_params)
        if centered is not None:
            is_valid = constraints.unit_interval.check(centered)
            # Under tracing the check result is abstract and cannot be
            # inspected, so validation is only performed on concrete values.
            if not_jax_tracer(is_valid):
                if not np.all(is_valid):
                    raise ValueError(
                        "`centered` argument does not satisfy `0 <= centered <= 1`."
                    )

        self.centered = centered
        self.shape_params = shape_params

    def __call__(self, name, fn, obs):
        """
        Replace site ``name`` by a ``{name}_decentered`` sample site and map
        its draw back to the original parameterization.

        :param str name: A sample site name.
        :param fn: A distribution with real support that exposes ``loc`` and
            ``scale`` attributes (plus every name in ``shape_params``).
        :param obs: Must be None; observed sites are not supported.
        :return: ``(fn, obs)`` unchanged when ``centered`` is identically one,
            otherwise ``(None, value)`` with the transformed value.
        :raises ValueError: if the (base) support of ``fn`` is not
            ``constraints.real``.
        """
        assert obs is None, "LocScaleReparam does not support observe statements"
        support = fn.support
        if isinstance(support, constraints.independent):
            support = fn.support.base_constraint
        if support is not constraints.real:
            raise ValueError(
                "LocScaleReparam only supports distributions with real "
                f"support, but got {support} support at site {name}."
            )

        centered = self.centered
        if is_identically_one(centered):
            return fn, obs
        # Capture the event shape before unwrapping replaces ``fn``.
        event_shape = fn.event_shape
        fn, expand_shape, event_dim = self._unwrap(fn)

        # Apply a partial decentering transform.
        params = {key: getattr(fn, key) for key in self.shape_params}
        if self.centered is None:
            centered = numpyro.param(
                "{}_centered".format(name),
                jnp.full(event_shape, 0.5),
                constraint=constraints.unit_interval,
            )
        if isinstance(centered, (int, float, np.generic)) and centered == 0.0:
            # Fully decentered: the noise distribution has zero loc, unit scale.
            params["loc"] = jnp.zeros_like(fn.loc)
            params["scale"] = jnp.ones_like(fn.scale)
        else:
            params["loc"] = fn.loc * centered
            params["scale"] = fn.scale**centered
        decentered_fn = self._wrap(type(fn)(**params), expand_shape, event_dim)

        # Draw decentered noise.
        decentered_value = numpyro.sample("{}_decentered".format(name), decentered_fn)

        # Differentiably transform.
        delta = decentered_value - centered * fn.loc
        value = fn.loc + jnp.power(fn.scale, 1 - centered) * delta

        # Simulate a pyro.deterministic() site.
        return None, value

class TransformReparam(Reparam):
    """
    Reparameterizer for
    :class:`~numpyro.distributions.TransformedDistribution` latent variables.

    This is useful for transformed distributions with complex,
    geometry-changing transforms, where the posterior has simple shape in
    the space of ``base_dist``.

    This reparameterization works only for latent variables, not likelihoods.
    """

    def __call__(self, name, fn, obs):
        """
        Sample ``{name}_base`` from the base distribution of ``fn`` and push
        the draw through the transforms of ``fn``.

        :param str name: A sample site name.
        :param fn: A (possibly expanded / independent)
            ``TransformedDistribution``.
        :param obs: Must be None; observed sites are not supported.
        :return: ``(None, value)`` with the transformed value.
        :raises ValueError: if the unwrapped ``fn`` is not a
            ``TransformedDistribution``.
        """
        assert obs is None, "TransformReparam does not support observe statements"
        fn, expand_shape, event_dim = self._unwrap(fn)
        if not isinstance(fn, dist.TransformedDistribution):
            raise ValueError(
                "TransformReparam does not automatically work with {}"
                " distribution anymore. Please explicitly using"
                " TransformedDistribution(base_dist, AffineTransform(...)) pattern"
                " with TransformReparam.".format(type(fn).__name__)
            )

        # Draw noise from the base distribution.
        # Walk the transforms backwards to find the event_dim the base
        # distribution must have for the output to have ``event_dim``.
        base_event_dim = event_dim
        for t in reversed(fn.transforms):
            base_event_dim += t.domain.event_dim - t.codomain.event_dim
        x = numpyro.sample(
            "{}_base".format(name),
            self._wrap(fn.base_dist, expand_shape, base_event_dim),
        )

        # Differentiably transform.
        for t in fn.transforms:
            x = t(x)

        # Simulate a pyro.deterministic() site.
        return None, x

class ProjectedNormalReparam(Reparam):
    """
    Reparametrizer for :class:`~numpyro.distributions.ProjectedNormal` latent
    variables.

    This reparameterization works only for latent variables, not likelihoods.
    """

    def __call__(self, name, fn, obs):
        """
        Sample ``{name}_normal`` from a unit-scale, zero-mean normal and
        normalize the draw shifted by ``fn.concentration``.

        :param str name: A sample site name.
        :param fn: A (possibly expanded / independent) ``ProjectedNormal``.
        :param obs: Must be None; observed sites are not supported.
        :return: ``(None, value)`` with the normalized value.
        """
        assert (
            obs is None
        ), "ProjectedNormalReparam does not support observe statements"
        fn, expand_shape, event_dim = self._unwrap(fn)
        assert isinstance(fn, dist.ProjectedNormal)

        # Draw parameter-free noise.
        new_fn = dist.Normal(jnp.zeros(fn.concentration.shape), 1).to_event(1)
        x = numpyro.sample(
            "{}_normal".format(name), self._wrap(new_fn, expand_shape, event_dim)
        )

        # Differentiably transform.
        value = safe_normalize(x + fn.concentration)

        # Simulate a pyro.deterministic() site.
        return None, value

class NeuTraReparam(Reparam):
    """
    Neural Transport reparameterizer [1] of multiple latent variables.

    This uses a trained :class:`~numpyro.infer.autoguide.AutoContinuous`
    guide to alter the geometry of a model, typically for use e.g. in MCMC.
    Example usage::

        # Step 1. Train a guide
        guide = AutoIAFNormal(model)
        svi = SVI(model, guide, ...)
        # ...train the guide and keep its trained parameters as `params`...

        # Step 2. Use trained guide in NeuTra MCMC
        neutra = NeuTraReparam(guide, params)
        model = neutra.reparam(model)
        nuts = NUTS(model)
        # ...now use the model in HMC or NUTS...

    This reparameterization works only for latent variables, not likelihoods.
    Note that all sites must share a single common :class:`NeuTraReparam`
    instance, and that the model must have static structure.

    [1] Hoffman, M. et al. (2019)
        "NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport"
        https://arxiv.org/abs/1903.03704

    :param ~numpyro.infer.autoguide.AutoContinuous guide: A guide.
    :param params: trained parameters of the guide.
    """

    def __init__(self, guide, params):
        if not isinstance(guide, AutoContinuous):
            raise TypeError(
                "NeuTraReparam expected an AutoContinuous guide, but got {}".format(
                    type(guide)
                )
            )
        self.guide = guide
        self.params = params
        try:
            self.transform = self.guide.get_transform(params)
        except (NotImplementedError, TypeError) as e:
            raise ValueError(
                "NeuTraReparam only supports guides that implement "
                "`get_transform` method that does not depend on the "
                "model's `*args, **kwargs`"
            ) from e
        # Per-site unconstrained values unpacked from the shared latent;
        # filled on the first site of a run and drained by ``__call__``.
        self._x_unconstrained = {}

    def _reparam_config(self, site):
        """
        Config callback for ``numpyro.handlers.reparam``.

        :param dict site: A site message; only ``site["name"]`` is read.
        :return: ``self`` for sites that appear unobserved in the guide's
            prototype trace, otherwise None.
        """
        if site["name"] in self.guide.prototype_trace:
            # We only reparam if this is an unobserved site in the guide
            # prototype trace.
            guide_site = self.guide.prototype_trace[site["name"]]
            if not guide_site.get("is_observed", False):
                return self
        return None

    def reparam(self, fn=None):
        """
        Wrap ``fn`` in a ``numpyro.handlers.reparam`` handler configured with
        this instance.

        :param fn: A callable to wrap (passed through to the handler).
        :return: The ``numpyro.handlers.reparam`` handler.
        """
        return numpyro.handlers.reparam(fn, config=self._reparam_config)

    def __call__(self, name, fn, obs):
        """
        Produce the value of site ``name`` from the guide's shared latent.

        On the first site encountered (when no unpacked values are cached) a
        ``{prefix}_shared_latent`` site is sampled, transformed by the guide's
        transform and unpacked into per-site unconstrained values. Each call
        then pops its own site value, constrains it to ``fn.support`` and
        registers a ``_{name}_log_prob`` factor.

        :param str name: A sample site name.
        :param fn: The distribution at the site.
        :param obs: Must be None for sites in the guide's prototype trace.
        :return: ``(fn, obs)`` unchanged if ``name`` is not in the guide's
            prototype trace, otherwise ``(None, value)``.
        """
        if name not in self.guide.prototype_trace:
            return fn, obs
        assert obs is None, "NeuTraReparam does not support observe statements"

        log_density = 0.0
        compute_density = numpyro.get_mask() is not False
        if not self._x_unconstrained:  # On first sample site.
            # Sample a shared latent.
            model_plates = {
                msg["name"]
                for msg in self.guide.prototype_trace.values()
                if msg["type"] == "plate"
            }
            z_unconstrained = numpyro.sample(
                "{}_shared_latent".format(self.guide.prefix),
                self.guide.get_base_dist().mask(False),
                infer={"block_plates": model_plates},
            )

            # Differentiably transform.
            x_unconstrained = self.transform(z_unconstrained)
            if compute_density:
                # The Jacobian term of the shared transform is attributed to
                # the first site's factor only.
                log_density = self.transform.log_abs_det_jacobian(
                    z_unconstrained, x_unconstrained
                )
            self._x_unconstrained = self.guide._unpack_latent(x_unconstrained)

        # Extract a single site's value from the shared latent.
        unconstrained_value = self._x_unconstrained.pop(name)
        transform = biject_to(fn.support)
        value = transform(unconstrained_value)
        if compute_density:
            logdet = transform.log_abs_det_jacobian(unconstrained_value, value)
            logdet = sum_rightmost(
                logdet, jnp.ndim(logdet) - jnp.ndim(value) + len(fn.event_shape)
            )
            log_density = log_density + fn.log_prob(value) + logdet
        numpyro.factor("_{}_log_prob".format(name), log_density)
        return None, value

    def transform_sample(self, latent):
        """
        Given latent samples from the warped posterior (with possible batch
        dimensions), return a dict of samples from the latent sites in the
        model.

        :param latent: sample from the warped posterior (possibly batched).
        :return: a `dict` of samples keyed by latent sites in the model.
        :rtype: dict
        """
        x_unconstrained = self.transform(latent)
        return self.guide._unpack_and_constrain(x_unconstrained, self.params)

class CircularReparam(Reparam):
    """
    Reparametrizer for :class:`~numpyro.distributions.VonMises` latent
    variables.
    """

    def __call__(self, name, fn, obs):
        """
        Sample an unbounded ``{name}_unwrapped`` value, wrap it into
        ``[-pi, pi)`` and score it under ``fn`` via a ``{name}_factor`` site.

        :param str name: A sample site name.
        :param fn: A distribution whose (base) support is
            ``constraints.circular``.
        :param obs: Must be None; observed sites are not supported.
        :return: ``(None, value)`` with the wrapped value.
        """
        # Support must be circular
        support = fn.support
        if isinstance(support, constraints.independent):
            support = fn.support.base_constraint
        assert support is constraints.circular
        assert obs is None, "CircularReparam does not support observe statements"

        # Draw parameter-free noise.
        new_fn = dist.ImproperUniform(constraints.real, fn.batch_shape, fn.event_shape)
        value = numpyro.sample(
            f"{name}_unwrapped",
            new_fn,
            obs=obs,
        )

        # Differentiably transform.
        value = jnp.remainder(value + math.pi, 2 * math.pi) - math.pi

        # Score the wrapped value under the original distribution.
        numpyro.factor(f"{name}_factor", fn.log_prob(value))

        # Simulate a pyro.deterministic() site.
        return None, value

class ExplicitReparam(Reparam):
    """
    Explicit reparametrizer of a latent variable :code:`x` to a transformed
    space :code:`y = transform(x)` with more amenable geometry. This
    reparametrizer is similar to :class:`.TransformReparam` but allows
    reparametrizations to be decoupled from the model declaration.

    :param transform: Bijective transform to the reparameterized space.

    **Example:**

    .. doctest::

        >>> from jax import random
        >>> from jax import numpy as jnp
        >>> import numpyro
        >>> from numpyro import handlers, distributions as dist
        >>> from numpyro.infer import MCMC, NUTS
        >>> from numpyro.infer.reparam import ExplicitReparam
        >>>
        >>> def model():
        ...    numpyro.sample("x", dist.Gamma(4, 4))
        >>>
        >>> # Sample in unconstrained space using a soft-plus instead of exp transform.
        >>> reparam = ExplicitReparam(dist.transforms.SoftplusTransform().inv)
        >>> reparametrized = handlers.reparam(model, {"x": reparam})
        >>> kernel = NUTS(model=reparametrized)
        >>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=1)
        >>> mcmc.run(random.PRNGKey(2))  # doctest: +SKIP
        sample: 100%|██████████| 2000/2000 [00:00<00:00, 2306.47it/s, 3 steps of size 9.65e-01. acc. prob=0.93]
    """

    def __init__(self, transform):
        # An iterable made up solely of transforms is composed into a single
        # transform; anything else is stored as given.
        if isinstance(transform, Iterable) and all(
            isinstance(t, dist.transforms.Transform) for t in transform
        ):
            transform = dist.transforms.ComposeTransform(transform)
        self.transform = transform

    def __call__(self, name, fn, obs):
        """
        Sample ``{name}_base`` in the transformed space and map the draw back
        through the inverse transform.

        :param str name: A sample site name.
        :param fn: The distribution at the site.
        :param obs: Must be None; observed sites are not supported.
        :return: ``(None, value)`` with ``value = transform.inv(x)``.
        """
        assert obs is None, "ExplicitReparam does not support observe statements"
        transformed = dist.TransformedDistribution(fn, self.transform)
        x = numpyro.sample(f"{name}_base", transformed)
        return None, self.transform.inv(x)