Effect Handlers — NumPyro documentation

This module provides a small set of effect handlers in NumPyro that are modeled after Pyro’s poutine module. For a tutorial on effect handlers more generally, readers are encouraged to read Poutine: A Guide to Programming with Effect Handlers in Pyro. These simple effect handlers can be composed together, or new ones can be added, to enable implementation of custom inference utilities and algorithms.

When a handler, such as handlers.seed, is applied to a model in NumPyro, e.g., seeded_model = handlers.seed(model, rng_seed=0), it creates a callable object with stateful attributes. These attributes can interfere with JAX primitives, such as jax.jit, jax.vmap, and jax.grad. To ensure proper composition with JAX primitives, handlers should be applied locally within the function or context where the model is used, rather than globally. For example:

Good: can be used in a jitted function

def seeded_model(data):
    return handlers.seed(model, rng_seed=0)(data)

Bad: might create tracer-leaks when used in a jitted function

seeded_model = handlers.seed(model, rng_seed=0)

Example

As an example, we use the seed, trace and substitute handlers to define the log_likelihood function below. We first create a logistic regression model and sample from the posterior distribution over the regression parameters using MCMC(). The log_likelihood function uses effect handlers to run the model by substituting sample sites with values from the posterior distribution and computes the log density for a single data point. The log_predictive_density function computes the log likelihood for each draw from the joint posterior and aggregates the results for all the data points, but does so by using JAX’s auto-vectorize transform called vmap so that we do not need to loop over all the data points.

>>> import jax.numpy as jnp
>>> from jax import random, vmap
>>> from jax.scipy.special import logsumexp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro import handlers
>>> from numpyro.infer import MCMC, NUTS

>>> N, D = 3000, 3
>>> def logistic_regression(data, labels):
...     coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(D), jnp.ones(D)))
...     intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
...     logits = jnp.sum(coefs * data + intercept, axis=-1)
...     return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

>>> data = random.normal(random.PRNGKey(0), (N, D))
>>> true_coefs = jnp.arange(1., D + 1.)
>>> logits = jnp.sum(true_coefs * data, axis=-1)
>>> labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

>>> num_warmup, num_samples = 1000, 1000
>>> mcmc = MCMC(NUTS(model=logistic_regression), num_warmup=num_warmup, num_samples=num_samples)
>>> mcmc.run(random.PRNGKey(2), data, labels)
sample: 100%|██████████| 1000/1000 [00:00<00:00, 1252.39it/s, 1 steps of size 5.83e-01. acc. prob=0.85]
>>> mcmc.print_summary()

               mean         sd       5.5%      94.5%      n_eff       Rhat
coefs[0]       0.96       0.07       0.85       1.07     455.35       1.01
coefs[1]       2.05       0.09       1.91       2.20     332.00       1.01
coefs[2]       3.18       0.13       2.96       3.37     320.27       1.00
intercept     -0.03       0.02      -0.06       0.00     402.53       1.00

>>> def log_likelihood(rng_key, params, model, *args, **kwargs):
...     model = handlers.substitute(handlers.seed(model, rng_key), params)
...     model_trace = handlers.trace(model).get_trace(*args, **kwargs)
...     obs_node = model_trace['obs']
...     return obs_node['fn'].log_prob(obs_node['value'])

>>> def log_predictive_density(rng_key, params, model, *args, **kwargs):
...     n = list(params.values())[0].shape[0]
...     log_lk_fn = vmap(lambda rng_key, params: log_likelihood(rng_key, params, model, *args, **kwargs))
...     log_lk_vals = log_lk_fn(random.split(rng_key, n), params)
...     return jnp.sum(logsumexp(log_lk_vals, 0) - jnp.log(n))

>>> print(log_predictive_density(random.PRNGKey(2), mcmc.get_samples(),
...       logistic_regression, data, labels))
-874.89813

block

class block(fn: Callable | None = None, hide_fn: Callable | None = None, hide: list[str] | None = None, expose_types: list[str] | None = None, expose: list[str] | None = None)

Bases: Messenger

Given a callable fn, return another callable that selectively hides primitive sites from other effect handlers on the stack. In the absence of parameters, all primitive sites are blocked. hide_fn takes precedence over hide, which has higher priority than expose_types followed by expose. Only the parameter with the highest precedence is considered.

Parameters:

Returns:

Python callable with NumPyro primitives.

Example:

>>> from jax import random
>>> import numpyro
>>> from numpyro.handlers import block, seed, trace
>>> import numpyro.distributions as dist

>>> def model():
...     a = numpyro.sample('a', dist.Normal(0., 1.))
...     return numpyro.sample('b', dist.Normal(a, 1.))

>>> model = seed(model, random.PRNGKey(0))
>>> block_all = block(model)
>>> block_a = block(model, lambda site: site['name'] == 'a')
>>> trace_block_all = trace(block_all).get_trace()
>>> assert not {'a', 'b'}.intersection(trace_block_all.keys())
>>> trace_block_a = trace(block_a).get_trace()
>>> assert 'a' not in trace_block_a
>>> assert 'b' in trace_block_a

process_message(msg: dict[str, Any]) → None

To be implemented by subclasses.

collapse

class collapse(*args, **kwargs)

Bases: trace

EXPERIMENTAL Collapses all sites in the context by lazily sampling and attempting to use conjugacy relations. If no conjugacy is known this will fail. Code using the results of sample sites must be written to accept Funsors rather than Tensors. This requires funsor to be installed.

process_message(msg: dict[str, Any]) → None

To be implemented by subclasses.

condition

class condition(fn: Callable | None = None, data: dict[str, Array | ndarray | bool_ | number | bool | int | float | complex] | None = None, condition_fn: Callable | None = None)

Bases: Messenger

Conditions unobserved sample sites to values from data or condition_fn. Similar to substitute except that it only affects sample sites and changes the is_observed property to True.

Parameters:

Example:

>>> from jax import random
>>> import numpyro
>>> from numpyro.handlers import condition, seed, substitute, trace
>>> import numpyro.distributions as dist

>>> def model():
...     numpyro.sample('a', dist.Normal(0., 1.))

>>> model = seed(model, random.PRNGKey(0))
>>> exec_trace = trace(condition(model, {'a': -1})).get_trace()
>>> assert exec_trace['a']['value'] == -1
>>> assert exec_trace['a']['is_observed']

process_message(msg: dict[str, Any]) → None

To be implemented by subclasses.

do

class do(fn: Callable | None = None, data: dict[str, Array | ndarray | bool_ | number | bool | int | float | complex] | None = None)

Bases: Messenger

Given a stochastic function with some sample statements and a dictionary of values at names, set the return values of those sites equal to the values as if they were hard-coded to those values and introduce fresh sample sites with the same names whose values do not propagate.

Composes freely with condition() to represent counterfactual distributions over potential outcomes. See Single World Intervention Graphs [1] for additional details and theory.

This is equivalent to replacing z = numpyro.sample("z", ...) with z = 1. and introducing a fresh sample site numpyro.sample("z", ...) whose value is not used elsewhere.

References:

  1. Single World Intervention Graphs: A Primer, Thomas Richardson, James Robins

Parameters:

Example:

>>> import jax.numpy as jnp
>>> import numpyro
>>> from numpyro.handlers import do, trace, seed
>>> import numpyro.distributions as dist
>>> def model(x):
...     s = numpyro.sample("s", dist.LogNormal())
...     z = numpyro.sample("z", dist.Normal(x, s))
...     return z ** 2
>>> intervened_model = do(model, data={"z": 1.})
>>> with trace() as exec_trace:
...     z_square = seed(intervened_model, 0)(1)
>>> assert exec_trace['z']['value'] != 1.
>>> assert not exec_trace['z']['is_observed']
>>> assert not exec_trace['z'].get('stop', None)
>>> assert z_square == 1

process_message(msg: dict[str, Any]) → None

To be implemented by subclasses.

infer_config

class infer_config(fn: Callable | None = None, config_fn: Callable | None = None)

Bases: Messenger

Given a callable fn that contains NumPyro primitive calls and a callable config_fn taking a trace site and returning a dictionary, updates the value of the infer kwarg at a sample site to config_fn(site).

Parameters:

process_message(msg: dict[str, Any]) → None

To be implemented by subclasses.

lift

class lift(fn: Callable | None = None, prior: DistributionT | dict[str, DistributionT] | None = None)

Bases: Messenger

Given a stochastic function with param calls and a prior distribution, create a stochastic function where all param calls are replaced by sampling from prior. Prior should be a distribution or a dict of names to distributions.

Consider the following NumPyro program:

>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.handlers import lift

>>> def model(x):
...     s = numpyro.param("s", 0.5)
...     z = numpyro.sample("z", dist.Normal(x, s))
...     return z ** 2
>>> lifted_model = lift(model, prior={"s": dist.Exponential(0.3)})

lift makes param statements behave like sample statements using the distributions in prior. In this example, site s will now behave as if it was replaced with s = numpyro.sample("s", dist.Exponential(0.3)).

Parameters:

process_message(msg: dict[str, Any]) → None

To be implemented by subclasses.

mask

class mask(fn: Callable | None = None, mask: Array | ndarray | bool_ | number | bool | int | float | complex | None = True)

Bases: Messenger

This messenger masks out some of the sample statements elementwise.

Parameters:

mask – a boolean or a boolean-valued array for masking elementwise log probability of sample sites (True includes a site, False excludes a site).

process_message(msg: dict[str, Any]) → None

To be implemented by subclasses.

reparam

class reparam(fn: Callable | None = None, config: dict | Callable | None = None)

Bases: Messenger

Reparametrizes each affected sample site into one or more auxiliary sample sites followed by a deterministic transformation [1].

To specify reparameterizers, pass a config dict or callable to the constructor. See the numpyro.infer.reparam module for available reparameterizers.

Note that some reparameterizers can examine the *args, **kwargs inputs of functions they affect; these reparameterizers require using handlers.reparam as a decorator rather than as a context manager.

[1] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019), “Automatic Reparameterisation of Probabilistic Programs”, https://arxiv.org/pdf/1906.03028.pdf

Parameters:

config (dict or callable) – Configuration, either a dict mapping site name to Reparam, or a function mapping site to Reparam or None.

process_message(msg: dict[str, Any]) → None

To be implemented by subclasses.

replay

class replay(fn: Callable | None = None, trace: OrderedDict[str, dict[str, Any]] | None = None)

Bases: Messenger

Given a callable fn and an execution trace trace, return a callable which substitutes sample calls in fn with values from the corresponding site names in trace.

Parameters:

Example:

>>> from jax import random
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.handlers import replay, seed, trace

>>> def model():
...     numpyro.sample('a', dist.Normal(0., 1.))

>>> exec_trace = trace(seed(model, random.PRNGKey(0))).get_trace()
>>> print(exec_trace['a']['value'])
-0.20584235
>>> replayed_trace = trace(replay(model, exec_trace)).get_trace()
>>> print(exec_trace['a']['value'])
-0.20584235
>>> assert replayed_trace['a']['value'] == exec_trace['a']['value']

process_message(msg: dict[str, Any]) → None

To be implemented by subclasses.

scale

class scale(fn: Callable | None = None, scale: Array | ndarray | bool_ | number | bool | int | float | complex = 1.0)

Bases: Messenger

This messenger rescales the log probability score.

This is typically used for data subsampling or for stratified sampling of data (e.g. in fraud detection where negatives vastly outnumber positives).

Parameters:

scale (float or numpy.ndarray) – a positive scaling factor that is broadcastable to the shape of log probability.

process_message(msg: dict[str, Any]) → None

To be implemented by subclasses.

scope

class scope(fn: Callable | None = None, prefix: str = '', divider: str = '/', *, hide_types: list[str] | None = None)

Bases: Messenger

This handler prepends a prefix followed by a divider to the name of sample sites.

Example:

>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.handlers import scope, seed, trace

>>> def model():
...     with scope(prefix="a"):
...         with scope(prefix="b", divider="."):
...             return numpyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/b.x" in trace(seed(model, 0)).get_trace()

Parameters:

process_message(msg: dict[str, Any]) → None

To be implemented by subclasses.

seed

class seed(fn: Callable | None = None, rng_seed: Array | int | None = None, hide_types: list[str] | None = None)

Bases: Messenger

JAX uses a functional pseudo random number generator that requires passing in a seed PRNGKey() to every stochastic function. The seed handler allows us to initially seed a stochastic function with a PRNGKey(). Every call to the sample() primitive inside the function results in a splitting of this initial seed so that we use a fresh seed for each subsequent call without having to explicitly pass in a PRNGKey to each sample call.

Parameters:

Note

Unlike in Pyro, the numpyro.sample primitive cannot be used without wrapping it in the seed handler since there is no global random state. As such, users need to use seed as a context manager to generate samples from distributions or as a decorator for their model callable (see below).

Note

The seed handler has a mutable attribute rng_key which keeps changing after each sample call. Hence an instance of this class (e.g. seed(model, rng_seed=0)) might create tracer-leaks when jitted. A solution is to close the instance in a function, e.g., seeded_model = lambda *args: seed(model, rng_seed=0)(*args). This seeded_model can be jitted.

Example:

>>> from jax import random
>>> import numpyro
>>> from numpyro import handlers
>>> import numpyro.distributions as dist

>>> # as context manager
>>> with handlers.seed(rng_seed=1):
...     x = numpyro.sample('x', dist.Normal(0., 1.))

>>> def model():
...     return numpyro.sample('y', dist.Normal(0., 1.))

>>> # as function decorator (/modifier)
>>> y = handlers.seed(model, rng_seed=1)()
>>> assert x == y

stateful = False

process_message(msg: dict[str, Any]) → None

To be implemented by subclasses.

substitute

class substitute(fn: Callable | None = None, data: dict[str, Array] | None = None, substitute_fn: Callable | None = None)

Bases: Messenger

Given a callable fn and a dict data keyed by site names (alternatively, a callable substitute_fn), return a callable which substitutes all primitive calls in fn with values from data whose key matches the site name. If the site name is not present in data, there is no side effect.

If a substitute_fn is provided, then the value at the site is replaced by the value returned from the call to substitute_fn for the given site.

Note

This handler is mainly used for internal algorithms. For conditioning a generative model on observed data, please use the condition handler.

Parameters:

Example:

>>> from jax import random
>>> import numpyro
>>> from numpyro.handlers import seed, substitute, trace
>>> import numpyro.distributions as dist

>>> def model():
...     numpyro.sample('a', dist.Normal(0., 1.))

>>> model = seed(model, random.PRNGKey(0))
>>> exec_trace = trace(substitute(model, {'a': -1})).get_trace()
>>> assert exec_trace['a']['value'] == -1

process_message(msg: dict[str, Any]) → None

To be implemented by subclasses.

trace

class trace(fn: Callable | None = None)

Bases: Messenger

Returns a handler that records the inputs and outputs at primitive calls inside fn.

Example:

>>> from jax import random
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.handlers import seed, trace
>>> import pprint as pp

>>> def model():
...     numpyro.sample('a', dist.Normal(0., 1.))

>>> exec_trace = trace(seed(model, random.PRNGKey(0))).get_trace()
>>> pp.pprint(exec_trace)
OrderedDict([('a',
              {'args': (),
               'fn': <numpyro.distributions.continuous.Normal object at 0x7f9e689b1eb8>,
               'is_observed': False,
               'kwargs': {'rng_key': Array([0, 0], dtype=uint32)},
               'name': 'a',
               'type': 'sample',
               'value': Array(-0.20584235, dtype=float32)})])

postprocess_message(msg: dict[str, Any]) → None

To be implemented by subclasses.

get_trace(*args, **kwargs) → OrderedDict[str, dict[str, Any]]

Run the wrapped callable and return the recorded trace.

Parameters:

Returns:

OrderedDict containing the execution trace.