Poutine (Effect handlers) — Pyro documentation

Beneath the built-in inference algorithms, Pyro has a library of composable effect handlers for creating new inference algorithms and working with probabilistic programs. Pyro’s inference algorithms are all built by applying these handlers to stochastic functions. For a general understanding of what effect handlers are and what problem they solve, read An Introduction to Algebraic Effects and Handlers by Matija Pretnar.

Handlers

Poutine is a library of composable effect handlers for recording and modifying the behavior of Pyro programs. These lower-level ingredients simplify the implementation of new inference algorithms and behavior.

Handlers can be used as higher-order functions, decorators, or context managers to modify the behavior of functions or blocks of code:

For example, consider the following Pyro program:

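```python
>>> import torch
>>> import pyro
>>> import pyro.distributions as dist
>>> from pyro import poutine
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
```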

We can mark sample sites as observed using condition, which returns a callable with the same input and output signatures as model:

conditioned_model = poutine.condition(model, data={"z": 1.0})

We can also use handlers as decorators:

```python
>>> @pyro.condition(data={"z": 1.0})
... def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
```

Or as context managers:

```python
>>> with pyro.condition(data={"z": 1.0}):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(0., s))
...     y = z ** 2
```

Handlers compose freely:

```python
>>> conditioned_model = poutine.condition(model, data={"z": 1.0})
>>> traced_model = poutine.trace(conditioned_model)
```

Many inference algorithms or algorithmic components can be implemented in just a few lines of code:

```python
>>> guide_tr = poutine.trace(guide).get_trace(...)
>>> model_tr = poutine.trace(poutine.replay(conditioned_model, trace=guide_tr)).get_trace(...)
>>> monte_carlo_elbo = model_tr.log_prob_sum() - guide_tr.log_prob_sum()
```

block(fn: None = None, hide_fn: Optional[Callable[[Message], Optional[bool]]] = None, expose_fn: Optional[Callable[[Message], Optional[bool]]] = None, hide_all: bool = True, expose_all: bool = False, hide: Optional[List[str]] = None, expose: Optional[List[str]] = None, hide_types: Optional[List[str]] = None, expose_types: Optional[List[str]] = None) → pyro.poutine.block_messenger.BlockMessenger

block(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], hide_fn: Optional[Callable[[Message], Optional[bool]]] = None, expose_fn: Optional[Callable[[Message], Optional[bool]]] = None, hide_all: bool = True, expose_all: bool = False, hide: Optional[List[str]] = None, expose: Optional[List[str]] = None, hide_types: Optional[List[str]] = None, expose_types: Optional[List[str]] = None) → Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

Convenient wrapper of BlockMessenger

This handler selectively hides Pyro primitive sites from the outside world. Default behavior: block everything.

A site is hidden if at least one of the following holds:

  1. hide_fn(msg) is True or (not expose_fn(msg)) is True
  2. msg["name"] in hide
  3. msg["type"] in hide_types
  4. msg["name"] not in expose and msg["type"] not in expose_types
  5. hide, hide_types, and expose_types are all None

For example, suppose the stochastic function fn has two sample sites “a” and “b”. Then any effect handler outside of BlockMessenger(fn, hide=["a"]) will not be applied to site “a” and will only see site “b”:

```python
>>> def fn():
...     a = pyro.sample("a", dist.Normal(0., 1.))
...     return pyro.sample("b", dist.Normal(a, 1.))
>>> fn_inner = pyro.poutine.trace(fn)
>>> fn_outer = pyro.poutine.trace(pyro.poutine.block(fn_inner, hide=["a"]))
>>> trace_inner = fn_inner.get_trace()
>>> trace_outer = fn_outer.get_trace()
>>> "a" in trace_inner
True
>>> "a" in trace_outer
False
>>> "b" in trace_inner
True
>>> "b" in trace_outer
True
```

Parameters

Returns

stochastic function decorated with a BlockMessenger

broadcast(fn: None = None) → pyro.poutine.broadcast_messenger.BroadcastMessenger

broadcast(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]) → Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

Convenient wrapper of BroadcastMessenger

Automatically broadcasts the batch shape of the stochastic function at a sample site when inside a single or nested plate context. The existing batch_shape must be broadcastable with the size of the plate contexts installed in the cond_indep_stack.

Notice how model_automatic_broadcast below automates the expansion of distribution batch shapes. This makes it easy to modularize a Pyro model, as the sub-components are agnostic of the wrapping plate contexts.

```python
>>> def model_broadcast_by_hand():
...     with IndepMessenger("batch", 100, dim=-2):
...         with IndepMessenger("components", 3, dim=-1):
...             sample = pyro.sample("sample", dist.Bernoulli(torch.ones(3) * 0.5)
...                                  .expand_by(100))
...             assert sample.shape == torch.Size((100, 3))
...     return sample
```

```python
>>> @poutine.broadcast
... def model_automatic_broadcast():
...     with IndepMessenger("batch", 100, dim=-2):
...         with IndepMessenger("components", 3, dim=-1):
...             sample = pyro.sample("sample", dist.Bernoulli(torch.tensor(0.5)))
...             assert sample.shape == torch.Size((100, 3))
...     return sample
```

collapse(fn: None = None, *args: Any, **kwargs: Any) → pyro.poutine.collapse_messenger.CollapseMessenger

collapse(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], *args: Any, **kwargs: Any) → Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

Convenient wrapper of CollapseMessenger

EXPERIMENTAL: Collapses all sites in the context by lazily sampling and attempting to use conjugacy relations. If no conjugacy is known this will fail. Code using the results of sample sites must be written to accept Funsors rather than Tensors. This requires funsor to be installed.

Warning

This is not compatible with automatic guessing of max_plate_nesting. If any plates appear within the collapsed context, you should manually declare max_plate_nesting to your inference algorithm (e.g. Trace_ELBO(max_plate_nesting=1)).

condition(data: Union[Dict[str, torch.Tensor], Trace]) → pyro.poutine.condition_messenger.ConditionMessenger

condition(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Union[Dict[str, torch.Tensor], Trace]) → Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

Convenient wrapper of ConditionMessenger

Given a stochastic function with some sample statements and a dictionary of observations at names, change the sample statements at those names into observes with those values.

Consider the following Pyro program:

```python
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
```

To observe a value for site z, we can write

conditioned_model = pyro.poutine.condition(model, data={"z": torch.tensor(1.)})

This is equivalent to adding obs=value as a keyword argument to pyro.sample("z", ...) in model.

Parameters

Returns

stochastic function decorated with a ConditionMessenger

do(data: Dict[str, Union[torch.Tensor, numbers.Number]]) → pyro.poutine.do_messenger.DoMessenger

do(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Dict[str, Union[torch.Tensor, numbers.Number]]) → Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

Convenient wrapper of DoMessenger

Given a stochastic function with some sample statements and a dictionary of values at names, set the return values of those sites equal to the values as if they were hard-coded to those values and introduce fresh sample sites with the same names whose values do not propagate.

Composes freely with condition() to represent counterfactual distributions over potential outcomes. See Single World Intervention Graphs [1] for additional details and theory.

Consider the following Pyro program:

```python
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
```

To intervene with a value for site z, we can write

intervened_model = pyro.poutine.do(model, data={"z": torch.tensor(1.)})

This is equivalent to replacing z = pyro.sample("z", ...) with z = torch.tensor(1.) and introducing a fresh sample site pyro.sample("z", ...) whose value is not used elsewhere.

References

[1] Single World Intervention Graphs: A Primer, Thomas Richardson, James Robins

Parameters

Returns

stochastic function decorated with a DoMessenger

enum(fn: None = None, first_available_dim: Optional[int] = None) → pyro.poutine.enum_messenger.EnumMessenger

enum(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], first_available_dim: Optional[int] = None) → Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

Convenient wrapper of EnumMessenger

Enumerates in parallel over discrete sample sites marked infer={"enumerate": "parallel"}.

Parameters

first_available_dim (int) – The first tensor dimension (counting from the right) that is available for parallel enumeration. This dimension and all dimensions to its left may be used internally by Pyro. This should be a negative integer or None.

escape(escape_fn: Callable[[Message], bool]) → pyro.poutine.escape_messenger.EscapeMessenger

escape(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], escape_fn: Callable[[Message], bool]) → Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

Convenient wrapper of EscapeMessenger

Messenger that does a nonlocal exit by raising a util.NonlocalExit exception.

infer_config(config_fn: Callable[[Message], InferDict]) → pyro.poutine.infer_config_messenger.InferConfigMessenger

infer_config(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], config_fn: Callable[[Message], InferDict]) → Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

Convenient wrapper of InferConfigMessenger

Given a callable fn that contains Pyro primitive calls and a callable config_fn taking a trace site and returning a dictionary, updates the value of the infer kwarg at a sample site to config_fn(site).

Parameters

Returns

stochastic function decorated with InferConfigMessenger

lift(prior: Union[Callable, Distribution, Dict[str, Union[Distribution, Callable]]]) → pyro.poutine.lift_messenger.LiftMessenger

lift(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], prior: Union[Callable, Distribution, Dict[str, Union[Distribution, Callable]]]) → Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

Convenient wrapper of LiftMessenger

Given a stochastic function with param calls and a prior distribution, create a stochastic function where all param calls are replaced by sampling from prior. Prior should be a callable or a dict of names to callables.

Consider the following Pyro program:

```python
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
>>> lifted_model = pyro.poutine.lift(model, prior={"s": dist.Exponential(0.3)})
```

lift makes param statements behave like sample statements using the distributions in prior. In this example, site s will now behave as if it were replaced with s = pyro.sample("s", dist.Exponential(0.3)):

```python
>>> tr = pyro.poutine.trace(lifted_model).get_trace(0.0)
>>> tr.nodes["s"]["type"] == "sample"
True
>>> tr2 = pyro.poutine.trace(lifted_model).get_trace(0.0)
>>> bool((tr2.nodes["s"]["value"] == tr.nodes["s"]["value"]).all())
False
```

Parameters

Returns

fn decorated with a LiftMessenger

markov(fn: None = None, history: int = 1, keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None) → pyro.poutine.markov_messenger.MarkovMessenger

markov(fn: Iterable[int] = None, history: int = 1, keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None) → pyro.poutine.markov_messenger.MarkovMessenger

markov(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T] = None, history: int = 1, keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None) → Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

Markov dependency declaration.

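This can be used in a variety of ways:

  1. as a context manager
  2. as a decorator for recursive functions
  3. as an iterator for Markov chains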

Parameters

mask(mask: Union[bool, torch.BoolTensor]) → pyro.poutine.mask_messenger.MaskMessenger

mask(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], mask: Union[bool, torch.BoolTensor]) → Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

Convenient wrapper of MaskMessenger

Given a stochastic function with some batched sample statements and masking tensor, mask out some of the sample statements elementwise.

Parameters

Returns

stochastic function decorated with a MaskMessenger

queue(fn=None, queue=None, max_tries=None, extend_fn=None, escape_fn=None, num_samples=None)

Used in sequential enumeration over discrete variables.

Given a stochastic function and a queue, return a return value from a complete trace in the queue.

Parameters

Returns

stochastic function decorated with poutine logic

reparam(config: Union[Dict[str, Reparam], Callable[[Message], Optional[Reparam]]]) → pyro.poutine.reparam_messenger.ReparamMessenger

reparam(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], config: Union[Dict[str, Reparam], Callable[[Message], Optional[Reparam]]]) → pyro.poutine.reparam_messenger.ReparamHandler[pyro.poutine.handlers._P, pyro.poutine.handlers._T]

Convenient wrapper of ReparamMessenger

Reparametrizes each affected sample site into one or more auxiliary sample sites followed by a deterministic transformation [1].

To specify reparameterizers, pass a config dict or callable to the constructor. See the pyro.infer.reparam module for available reparameterizers.

Note that some reparameterizers can examine the *args, **kwargs inputs of functions they affect; these reparameterizers require using poutine.reparam as a decorator rather than as a context manager.

[1] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019), “Automatic Reparameterisation of Probabilistic Programs”, https://arxiv.org/pdf/1906.03028.pdf

Parameters

config (dict or callable) – Configuration, either a dict mapping site name to Reparameterizer, or a function mapping site to Reparam or None. See pyro.infer.reparam.strategies for built-in configuration strategies.

replay(fn: None = None, trace: Optional[Trace] = None, params: Optional[Dict[str, torch.Tensor]] = None) → pyro.poutine.replay_messenger.ReplayMessenger

replay(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], trace: Optional[Trace] = None, params: Optional[Dict[str, torch.Tensor]] = None) → Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

Convenient wrapper of ReplayMessenger

Given a callable that contains Pyro primitive calls, return a callable that runs the original, reusing the values recorded at sites in trace for the sites with the same names in the new trace.

Consider the following Pyro program:

```python
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
```

replay makes sample statements behave as if they had sampled the values at the corresponding sites in the trace:

```python
>>> old_trace = pyro.poutine.trace(model).get_trace(1.0)
>>> replayed_model = pyro.poutine.replay(model, trace=old_trace)
>>> bool(replayed_model(0.0) == old_trace.nodes["_RETURN"]["value"])
True
```

Parameters

Returns

a stochastic function decorated with a ReplayMessenger

scale(scale: Union[float, torch.Tensor]) → pyro.poutine.scale_messenger.ScaleMessenger

scale(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], scale: Union[float, torch.Tensor]) → Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

Convenient wrapper of ScaleMessenger

Given a stochastic function with some sample statements and a positive scale factor, scale the score of all sample and observe sites in the function.

Consider the following Pyro program:

```python
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     pyro.sample("z", dist.Normal(x, s), obs=torch.tensor(1.0))
```

scale multiplicatively scales the log-probabilities of sample sites:

```python
>>> scaled_model = pyro.poutine.scale(model, scale=0.5)
>>> scaled_tr = pyro.poutine.trace(scaled_model).get_trace(0.0)
>>> unscaled_tr = pyro.poutine.trace(model).get_trace(0.0)
>>> bool((scaled_tr.log_prob_sum() == 0.5 * unscaled_tr.log_prob_sum()).all())
True
```

Parameters

Returns

stochastic function decorated with a ScaleMessenger

seed(rng_seed: int) → pyro.poutine.seed_messenger.SeedMessenger

seed(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], rng_seed: int) → Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

Convenient wrapper of SeedMessenger

Handler to set the random number generator to a pre-defined state by setting its seed. This is the same as calling pyro.set_rng_seed() before the call to fn. This handler has no additional effect on primitive statements on the standard Pyro backend, but it might intercept pyro.sample calls in other backends, e.g. the NumPy backend.

Parameters

substitute(data: Dict[str, torch.Tensor]) → pyro.poutine.substitute_messenger.SubstituteMessenger

substitute(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Dict[str, torch.Tensor]) → Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

Convenient wrapper of SubstituteMessenger

Given a stochastic function with param calls and a set of parameter values, create a stochastic function where all param calls are substituted with the fixed values. data should be a dict of names to values. Consider the following Pyro program:

```python
>>> def model(x):
...     a = pyro.param("a", torch.tensor(0.5))
...     x = pyro.sample("x", dist.Bernoulli(probs=a))
...     return x
>>> substituted_model = pyro.poutine.substitute(model, data={"a": torch.tensor(0.3)})
```

In this example, site a will now have value torch.tensor(0.3).

Parameters

data (dict) – dictionary of values keyed by site names.

Returns

fn decorated with a SubstituteMessenger

trace(fn: None = None, graph_type: Optional[Literal['flat', 'dense']] = None, param_only: Optional[bool] = None) → pyro.poutine.trace_messenger.TraceMessenger

trace(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], graph_type: Optional[Literal['flat', 'dense']] = None, param_only: Optional[bool] = None) → pyro.poutine.trace_messenger.TraceHandler[pyro.poutine.handlers._P, pyro.poutine.handlers._T]

Convenient wrapper of TraceMessenger

Return a handler that records the inputs and outputs of primitive calls and their dependencies.

Consider the following Pyro program:

```python
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
```

We can record its execution using trace and use the resulting data structure to compute the log-joint probability of all of the sample sites in the execution or extract all parameters.

```python
>>> trace = pyro.poutine.trace(model).get_trace(0.0)
>>> logp = trace.log_prob_sum()
>>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]
```

Parameters

Returns

stochastic function decorated with a TraceMessenger

uncondition(fn: None = None) → pyro.poutine.uncondition_messenger.UnconditionMessenger

uncondition(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T] = None) → Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

Convenient wrapper of UnconditionMessenger

Messenger to force the value of observed nodes to be sampled from their distribution, ignoring observations.

config_enumerate(guide=None, default='parallel', expand=False, num_samples=None, tmc='diagonal')

Configures enumeration for all relevant sites in a guide. This is mainly used in conjunction with TraceEnum_ELBO.

When configuring for exhaustive enumeration of discrete variables, this configures all sample sites whose distribution satisfies .has_enumerate_support == True. When configuring for local parallel Monte Carlo sampling via default="parallel", num_samples=n, this configures all sample sites. This does not overwrite existing annotations infer={"enumerate": ...}.

This can be used as either a function:

guide = config_enumerate(guide)

or as a decorator:

```python
@config_enumerate
def guide1(*args, **kwargs):
    ...

@config_enumerate(default="sequential", expand=True)
def guide2(*args, **kwargs):
    ...
```

Parameters

Returns

an annotated guide

Return type

callable

Trace

class Trace(graph_type: Literal['flat', 'dense'] = 'flat')

Bases: object

Graph data structure denoting the relationships amongst different pyro primitives in the execution trace.

An execution trace of a Pyro program is a record of every call to pyro.sample() and pyro.param() in a single execution of that program. Traces are directed graphs whose nodes represent primitive calls or input/output, and whose edges represent conditional dependence relationships between those primitive calls. They are created and populated by poutine.trace.

Each node (or site) in a trace contains the name, input and output value of the site, as well as additional metadata added by inference algorithms or user annotation. In the case of pyro.sample, the trace also includes the stochastic function at the site, and any observed data added by users.

Consider the following Pyro program:

```python
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
```

We can record its execution using pyro.poutine.trace and use the resulting data structure to compute the log-joint probability of all of the sample sites in the execution or extract all parameters.

```python
>>> trace = pyro.poutine.trace(model).get_trace(0.0)
>>> logp = trace.log_prob_sum()
>>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]
```

We can also inspect or manipulate individual nodes in the trace. trace.nodes contains a collections.OrderedDict of site names and metadata corresponding to x, s, z, and the return value:

list(name for name in trace.nodes.keys())
["_INPUT", "s", "z", "_RETURN"]

Values of trace.nodes are dictionaries of node metadata:

trace.nodes["z"]
{'type': 'sample', 'name': 'z', 'is_observed': False, 'fn': Normal(), 'value': tensor(0.6480), 'args': (), 'kwargs': {}, 'infer': {}, 'scale': 1.0, 'cond_indep_stack': (), 'done': True, 'stop': False, 'continuation': None}

'infer' is a dictionary of user- or algorithm-specified metadata. 'args' and 'kwargs' are the arguments passed via pyro.sample to fn.__call__ or fn.log_prob. 'scale' is used to scale the log-probability of the site when computing the log-joint. 'cond_indep_stack' contains data structures corresponding to pyro.plate contexts appearing in the execution. 'done', 'stop', and 'continuation' are only used by Pyro’s internals.

Parameters

graph_type (string) – string specifying the kind of trace graph to construct

add_edge(site1: str, site2: str) → None

add_node(site_name: str, **kwargs: Any) → None

Parameters

site_name (string) – the name of the site to be added

Adds a site to the trace.

Raises an error when attempting to add a duplicate node instead of silently overwriting.

compute_log_prob(site_filter: Callable[[str, Message], bool] = <function allow_all_sites>) → None

Compute the site-wise log probabilities of the trace. Each log_prob has shape equal to the corresponding batch_shape. Each log_prob_sum is a scalar. Both computations are memoized.

compute_score_parts() → None

Compute the batched local score parts at each site of the trace. Each log_prob has shape equal to the corresponding batch_shape. Each log_prob_sum is a scalar. All computations are memoized.

copy() → pyro.poutine.trace_struct.Trace

Makes a shallow copy of self with nodes and edges preserved.

detach_() → None

Detach values (in-place) at each sample site of the trace.

property edges: Iterable[Tuple[str, str]]

format_shapes(title: str = 'Trace Shapes:', last_site: Optional[str] = None) → str

Returns a string showing a table of the shapes of all sites in the trace.

iter_stochastic_nodes() → Iterator[Tuple[str, Message]]

Returns

an iterator over stochastic nodes in the trace.

log_prob_sum(site_filter: Callable[[str, Message], bool] = <function allow_all_sites>) → Union[torch.Tensor, float]

Compute the site-wise log probabilities of the trace. Each log_prob has shape equal to the corresponding batch_shape. Each log_prob_sum is a scalar. The computation of log_prob_sum is memoized.

Returns

total log probability.

Return type

torch.Tensor

property nonreparam_stochastic_nodes: List[str]

A list of names of sample sites whose stochastic functions are not reparameterizable primitive distributions.

property observation_nodes: List[str]

A list of names of observe sites.

pack_tensors(plate_to_symbol: Optional[Dict[str, str]] = None) → None

Computes packed representations of tensors in the trace. This should be called after compute_log_prob() or compute_score_parts().

property param_nodes: List[str]

A list of names of param sites.

predecessors(site_name: str) → Set[str]

remove_node(site_name: str) → None

property reparameterized_nodes: List[str]

A list of names of sample sites whose stochastic functions are reparameterizable primitive distributions.

property stochastic_nodes: List[str]

A list of names of latent (unobserved) sample sites.

successors(site_name: str) → Set[str]

symbolize_dims(plate_to_symbol: Optional[Dict[str, str]] = None) → None

Assign unique symbols to all tensor dimensions.

topological_sort(reverse: bool = False) → List[str]

Return a list of nodes (site names) in topologically sorted order.

Parameters

reverse (bool) – Return the list in reverse order.

Returns

list of topologically sorted nodes (site names).

Runtime

class InferDict

Bases: typing_extensions.TypedDict

A dictionary that contains information about inference.

This can be used to configure per-site inference strategies, e.g.:

```python
pyro.sample(
    "x",
    dist.Bernoulli(0.5),
    infer={"enumerate": "parallel"},
)
```

Keys:

enumerate (str):

If one of the strings “sequential” or “parallel”, enables enumeration. Parallel enumeration is generally faster but requires broadcasting-safe operations and static structure.

expand (bool):

Whether to expand the distribution during enumeration. Defaults to False if missing.

is_auxiliary (bool):

Whether the sample site is auxiliary, e.g. for use in guides that deterministically transform auxiliary variables. Defaults to False if missing.

is_observed (bool):

Whether the sample site is observed (i.e. not latent). Defaults to False if missing.

num_samples (int):

The number of samples to draw. Defaults to 1 if missing.

obs (optional torch.Tensor):

The observed value, or None for latent variables. Defaults to None if missing.

prior (optional torch.distributions.Distribution):

(internal) For use in GuideMessenger to store the model’s prior distribution (conditioned on upstream sites).

tmc (str):

Whether to use the diagonal or mixture approximation for Tensor Monte Carlo in TraceTMC_ELBO.

was_observed (bool):

(internal) Whether the sample site was originally observed, in the context of inference via Reweighted Wake Sleep or Compiled Sequential Importance Sampling.

enumerate: typing_extensions.Literal['sequential', 'parallel']

expand: bool

is_auxiliary: bool

is_observed: bool

num_samples: int

obs: Optional[torch.Tensor]

prior: TorchDistributionMixin

tmc: typing_extensions.Literal['diagonal', 'mixture']

was_observed: bool

class Message

Bases: typing_extensions.TypedDict, Generic[pyro.poutine.runtime._P, pyro.poutine.runtime._T]

Pyro’s internal message type for effect handling.

Messages are stored in trace objects, e.g.:

trace.nodes["my_site_name"] # This is a Message.

Keys:

type (str):

The message type, typically one of the strings “sample”, “param”, “plate”, or “markov”, but possibly custom.

name (str):

The site name, typically naming a sample or parameter.

fn (callable):

The distribution or function used to generate the sample.

is_observed (bool):

A flag to indicate whether the value is observed.

args (tuple):

Positional arguments to the distribution or function.

kwargs (dict):

Keyword arguments to the distribution or function.

value (torch.Tensor):

The value of the sample (either observed or sampled).

scale (torch.Tensor):

A scaling factor for the log probability.

mask (bool or torch.Tensor):

A bool or tensor to mask the log probability.

cond_indep_stack (tuple):

The site’s local stack of conditional independence metadata. Immutable.

done (bool):

A flag to indicate whether the message has been handled.

stop (bool):

A flag to stop further processing of the message.

continuation (callable):

A function to call after processing the message.

infer (optional InferDict):

A dictionary of inference parameters.

obs (torch.Tensor):

The observed value.

log_prob (torch.Tensor):

The log probability of the sample.

log_prob_sum (torch.Tensor):

The sum of the log probability.

unscaled_log_prob (torch.Tensor):

The unscaled log probability.

score_parts (pyro.distributions.ScoreParts):

A collection of score parts.

packed (Message):

A packed message, used during enumeration.

args: Tuple

cond_indep_stack: Tuple[CondIndepStackFrame, ...]

continuation: Optional[Callable[[Message], None]]

done: bool

fn: Callable[[pyro.poutine.runtime._P], pyro.poutine.runtime._T]

infer: Optional[pyro.poutine.runtime.InferDict]

is_observed: bool

kwargs: Dict

log_prob: torch.Tensor

log_prob_sum: torch.Tensor

mask: Optional[Union[bool, torch.Tensor]]

name: Optional[str]

obs: Optional[torch.Tensor]

packed: Message

scale: Union[torch.Tensor, float]

score_parts: ScoreParts

stop: bool

type: str

unscaled_log_prob: torch.Tensor

value: Optional[pyro.poutine.runtime._T]

exception NonlocalExit(site: pyro.poutine.runtime.Message, *args, **kwargs)

Bases: Exception

Exception for exiting nonlocally from poutine execution.

Used by poutine.EscapeMessenger to return site information.

reset_stack() → None

Reset the state of the frames remaining in the stack. Necessary for multiple re-executions in poutine.queue.

am_i_wrapped() → bool

Checks whether the current computation is wrapped in a poutine.

Returns

bool

apply_stack(initial_msg: pyro.poutine.runtime.Message) → None

Execute the effect stack at a single site according to the following scheme:

  1. For each Messenger in the stack from bottom to top, execute Messenger._process_message with the message; if the message field “stop” is True, stop; otherwise, continue
  2. Apply default behavior (default_process_message) to finish remaining site execution
  3. For each Messenger in the stack from top to bottom, execute _postprocess_message to update the message and internal messenger state with the site results
  4. If the message field “continuation” is not None, call it with the message

Parameters

initial_msg (dict) – the starting version of the trace site

Returns

None
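To make the four steps concrete, here is a toy, dependency-free re-implementation of the scheme. ToyMessenger, toy_apply_stack, and toy_default_process_message are hypothetical names, not Pyro API, and the default step is simplified. The toy assumes the stack is stored outermost handler first, as happens when handlers are entered as nested context managers. Step 1 therefore walks the stack from innermost to outermost, and step 3 walks the visited handlers back in the opposite direction.

```python
from typing import Any, Dict, List


class ToyMessenger:
    """Hypothetical handler exposing the two hooks the scheme relies on."""

    def __init__(self, name: str, log: List[str], stop: bool = False) -> None:
        self.name = name
        self.log = log
        self.stop = stop

    def process_message(self, msg: Dict[str, Any]) -> None:
        self.log.append(f"process {self.name}")
        if self.stop:
            msg["stop"] = True  # remaining (outer) handlers are skipped in step 1

    def postprocess_message(self, msg: Dict[str, Any]) -> None:
        self.log.append(f"postprocess {self.name}")


def toy_default_process_message(msg: Dict[str, Any]) -> None:
    # Step 2 (simplified): compute a value only if no handler supplied one.
    if msg["value"] is None:
        msg["value"] = msg["fn"](*msg["args"])
    msg["done"] = True


def toy_apply_stack(stack: List[ToyMessenger], msg: Dict[str, Any]) -> None:
    # `stack` lists handlers outermost first, innermost last.
    pointer = 0
    # Step 1: innermost -> outermost, stopping once a handler sets msg["stop"].
    for frame in reversed(stack):
        pointer += 1
        frame.process_message(msg)
        if msg["stop"]:
            break
    toy_default_process_message(msg)
    # Step 3: postprocess only the handlers visited in step 1, outermost first.
    for frame in stack[-pointer:]:
        frame.postprocess_message(msg)
    # Step 4: call the continuation, if any.
    if msg["continuation"] is not None:
        msg["continuation"](msg)


def make_msg(fn) -> Dict[str, Any]:
    return {"fn": fn, "args": (), "value": None, "stop": False,
            "done": False, "continuation": None}


log: List[str] = []
msg = make_msg(lambda: 42)
toy_apply_stack([ToyMessenger("outer", log), ToyMessenger("inner", log)], msg)

stop_log: List[str] = []
stop_msg = make_msg(lambda: 0)
stop_msg["continuation"] = lambda m: stop_log.append("continuation")
toy_apply_stack([ToyMessenger("outer", stop_log),
                 ToyMessenger("inner", stop_log, stop=True)], stop_msg)
```

In the second call the inner handler sets "stop", so the outer handler never sees the message in either step 1 or step 3.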

default_process_message(msg: pyro.poutine.runtime.Message) → None

Default method for processing messages in inference.

Parameters

msg – a message to be processed

Returns

None

effectful(fn: None = None, type: Optional[str] = None) → Callable[[Callable[[pyro.poutine.runtime._P], pyro.poutine.runtime._T]], Callable[[...], pyro.poutine.runtime._T]]

effectful(fn: Callable[[pyro.poutine.runtime._P], pyro.poutine.runtime._T] = None, type: Optional[str] = None) → Callable[[...], pyro.poutine.runtime._T]

Wrapper for calling apply_stack() to apply any active effects.

Parameters

get_mask() → Optional[Union[bool, torch.Tensor]][source]

Records the effects of enclosing poutine.mask handlers.

This is useful for avoiding expensive pyro.factor() computations during prediction, when the log density need not be computed, e.g.:

    def model():
        # ...
        if poutine.get_mask() is not False:
            log_density = my_expensive_computation()
            pyro.factor("foo", log_density)
        # ...

Returns

The mask.

Return type

None, bool, or torch.Tensor

get_plates() → Tuple[CondIndepStackFrame, ...]

Records the effects of enclosing pyro.plate contexts.

Returns

A tuple of pyro.poutine.indep_messenger.CondIndepStackFrame objects.

Return type

tuple

Utilities

all_escape(trace: Trace, msg: Message) → bool

Parameters

trace (Trace) – a partial trace

msg (Message) – the message at a Pyro primitive site

Returns

boolean decision value

Utility function that checks if a site is not already in a trace.

Used by EscapeMessenger to decide whether to do a nonlocal exit at a site. Subroutine for approximately integrating out variables for variance reduction.

discrete_escape(trace: Trace, msg: Message) → bool

Parameters

trace (Trace) – a partial trace

msg (Message) – the message at a Pyro primitive site

Returns

boolean decision value

Utility function that checks if a sample site is discrete and not already in a trace.

Used by EscapeMessenger to decide whether to do a nonlocal exit at a site. Subroutine for integrating out discrete variables for variance reduction.

enable_validation(is_validate: bool) → None

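enum_extend(trace: Trace, msg: Message, num_samples: Optional[int] = None) → List[Trace]

Parameters

trace (Trace) – a partial trace

msg (Message) – the message at a Pyro primitive site

num_samples (int) – optional limit on the number of extended traces to return

Returns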

a list of traces, copies of input trace with one extra site

Utility function to copy and extend a trace with sites based on the input site whose values are enumerated from the support of the input site’s distribution.

Used for exact inference and integrating out discrete variables.
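The copy-and-extend pattern can be sketched with plain dicts. This assumes a trace is a dict from site name to value and that the site's support is given as a list; toy_enum_extend is a hypothetical name, and Pyro's version works on Trace objects and the distribution's own support, so details such as how num_samples is applied may differ.

```python
import copy
from typing import Any, Dict, List, Optional, Sequence

ToyTrace = Dict[str, Any]


def toy_enum_extend(trace: ToyTrace, name: str, support: Sequence[Any],
                    num_samples: Optional[int] = None) -> List[ToyTrace]:
    """Return one copy of `trace` per support value, each with site `name` added."""
    values = list(support) if num_samples is None else list(support)[:num_samples]
    extended = []
    for value in values:
        trace_copy = copy.copy(trace)  # shallow copy; the input trace is left untouched
        trace_copy[name] = value
        extended.append(trace_copy)
    return extended


partial_trace = {"x": 0.5}
traces = toy_enum_extend(partial_trace, "z", support=[0, 1, 2])
print(traces)         # [{'x': 0.5, 'z': 0}, {'x': 0.5, 'z': 1}, {'x': 0.5, 'z': 2}]
print(partial_trace)  # {'x': 0.5}
```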

is_validation_enabled() → bool

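mc_extend(trace: Trace, msg: Message, num_samples: Optional[int] = None) → List[Trace]

Parameters

trace (Trace) – a partial trace

msg (Message) – the message at a Pyro primitive site

num_samples (int) – optional number of sampled traces to return

Returns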

a list of traces, copies of input trace with one extra site

Utility function to copy and extend a trace with sites based on the input site whose values are sampled from the input site’s function.

Used for Monte Carlo marginalization of individual sample sites.
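A toy version of the Monte Carlo variant, assuming traces are plain dicts and the site's function is a zero-argument sampler. toy_mc_extend is a hypothetical name; unlike enumeration, each copy receives an independent random draw rather than a distinct support value.

```python
import copy
import random
from typing import Any, Callable, Dict, List

ToyTrace = Dict[str, Any]


def toy_mc_extend(trace: ToyTrace, name: str, sampler: Callable[[], Any],
                  num_samples: int = 1) -> List[ToyTrace]:
    """Return `num_samples` copies of `trace`, each extended with a fresh draw at `name`."""
    extended = []
    for _ in range(num_samples):
        trace_copy = copy.copy(trace)
        trace_copy[name] = sampler()
        extended.append(trace_copy)
    return extended


rng = random.Random(0)  # seeded for reproducibility
traces = toy_mc_extend({"x": 0.5}, "z", sampler=lambda: rng.gauss(0.0, 1.0), num_samples=3)
print(len(traces))  # 3
```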

prune_subsample_sites(trace: Trace) → Trace

Returns a copy of trace with all subsample sites removed; the input trace is not modified.

site_is_factor(site: Message) → bool

Determines whether a trace site originated from a factor statement.

site_is_subsample(site: Message) → bool

Determines whether a trace site originated from a subsample statement inside a plate.

Messengers

Messenger objects contain the implementations of the effects exposed by handlers. Advanced users may modify the implementations of messengers behind existing handlers or write new messengers that implement new effects and compose correctly with the rest of the library.

Messenger

class Messenger

Bases: object

Context manager class that modifies behavior and adds side effects to stochastic functions, i.e. callables containing Pyro primitive statements.

This is the base Messenger class. It implements the default behavior for all Pyro primitives, so that the joint distribution induced by a stochastic function fn is identical to the joint distribution induced by Messenger()(fn).

Class of transformers for messages passed during inference. Most inference operations are implemented in subclasses of this.

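classmethod register(fn: Optional[Callable] = None, type: Optional[str] = None, post: Optional[bool] = None) → Callable

Parameters

fn – function implementing the operation

type (str) – name of the operation (also passed to effectful())

post (bool) – if True, use this operation as a postprocessing step

Dynamically add operations to an effect. Useful for generating wrappers for libraries.

Example:

    @SomeMessengerClass.register
    def some_function(msg):
        # ...do something with msg...
        return msg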

classmethod unregister(fn: Optional[Callable] = None, type: Optional[str] = None) → Optional[Callable]

Parameters

fn – function implementing the operation

type (str) – name of the operation (also passed to effectful())

Dynamically remove operations from an effect. Useful for removing wrappers from libraries.

Example:

SomeMessengerClass.unregister(some_function, "name")

block_messengers(predicate: Callable[[pyro.poutine.messenger.Messenger], bool]) → Iterator[List[pyro.poutine.messenger.Messenger]]

EXPERIMENTAL Context manager to temporarily remove matching messengers from the _PYRO_STACK. Note this does not call the .__exit__() and .__enter__() methods.

This is useful to selectively block enclosing handlers.

Parameters

predicate (callable) – A predicate mapping messenger instance to boolean. This mutes all messengers m for which bool(predicate(m)) is True.

Yields

A list of matched messengers that are blocked.
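The blocking behavior of block_messengers can be sketched with a plain list standing in for _PYRO_STACK and strings standing in for messengers. TOY_STACK and toy_block_messengers are hypothetical; matching entries are only removed from the list and restored afterwards, without calling any __enter__ or __exit__ methods.

```python
from contextlib import contextmanager
from typing import Callable, Iterator, List

# Hypothetical stand-in for Pyro's global handler stack.
TOY_STACK: List[str] = ["trace", "scale", "condition"]


@contextmanager
def toy_block_messengers(predicate: Callable[[str], bool]) -> Iterator[List[str]]:
    """Temporarily remove matching entries from TOY_STACK, restoring them on exit."""
    saved = TOY_STACK[:]
    blocked = [m for m in TOY_STACK if predicate(m)]
    TOY_STACK[:] = [m for m in TOY_STACK if not predicate(m)]
    try:
        yield blocked
    finally:
        TOY_STACK[:] = saved  # restore the original order, even if the body raised


with toy_block_messengers(lambda m: m == "scale") as blocked:
    inside = TOY_STACK[:]

print(blocked)    # ['scale']
print(inside)     # ['trace', 'condition']
print(TOY_STACK)  # ['trace', 'scale', 'condition']
```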

unwrap(fn: Callable) → Callable

Recursively unwraps poutines.

BlockMessenger

class BlockMessenger(hide_fn: Optional[Callable[[Message], Optional[bool]]] = None, expose_fn: Optional[Callable[[Message], Optional[bool]]] = None, hide_all: bool = True, expose_all: bool = False, hide: Optional[List[str]] = None, expose: Optional[List[str]] = None, hide_types: Optional[List[str]] = None, expose_types: Optional[List[str]] = None)

Bases: pyro.poutine.messenger.Messenger

This handler selectively hides Pyro primitive sites from the outside world. Default behavior: block everything.

A site is hidden if at least one of the following holds:

  1. hide_fn(msg) is True or (not expose_fn(msg)) is True
  2. msg["name"] in hide
  3. msg["type"] in hide_types
  4. msg["name"] not in expose and msg["type"] not in expose_types
  5. hide, hide_types, and expose_types are all None

For example, suppose the stochastic function fn has two sample sites “a” and “b”. Then any effect outside of BlockMessenger(fn, hide=["a"]) will not be applied to site “a” and will only see site “b”:

>>> def fn():
...     a = pyro.sample("a", dist.Normal(0., 1.))
...     return pyro.sample("b", dist.Normal(a, 1.))
>>> fn_inner = pyro.poutine.trace(fn)
>>> fn_outer = pyro.poutine.trace(pyro.poutine.block(fn_inner, hide=["a"]))
>>> trace_inner = fn_inner.get_trace()
>>> trace_outer = fn_outer.get_trace()
>>> "a" in trace_inner
True
>>> "a" in trace_outer
False
>>> "b" in trace_inner
True
>>> "b" in trace_outer
True

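Parameters

fn – a stochastic function (callable containing Pyro primitive calls)

hide_fn – function that takes a site and returns True to hide it or False to expose it

expose_fn – function that takes a site and returns True to expose it or False to hide it

hide_all (bool) – hide all sites

expose_all (bool) – expose all sites normally

hide (list) – list of site names to hide

expose (list) – list of site names to expose while all others are hidden

hide_types (list) – list of site types to hide

expose_types (list) – list of site types to expose while all others are hidden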

Returns

stochastic function decorated with a BlockMessenger

BroadcastMessenger

class BroadcastMessenger

Bases: pyro.poutine.messenger.Messenger

Automatically broadcasts the batch shape of the stochastic function at a sample site when inside a single or nested plate context. The existing batch_shape must be broadcastable with the size of the plate contexts installed in the cond_indep_stack.

Notice how model_automatic_broadcast below automates the expansion of distribution batch shapes. This makes it easy to modularize a Pyro model, as the sub-components are agnostic of the wrapping plate contexts.

>>> def model_broadcast_by_hand():
...     with IndepMessenger("batch", 100, dim=-2):
...         with IndepMessenger("components", 3, dim=-1):
...             sample = pyro.sample("sample", dist.Bernoulli(torch.ones(3) * 0.5)
...                                  .expand_by(100))
...             assert sample.shape == torch.Size((100, 3))
...     return sample

>>> @poutine.broadcast
... def model_automatic_broadcast():
...     with IndepMessenger("batch", 100, dim=-2):
...         with IndepMessenger("components", 3, dim=-1):
...             sample = pyro.sample("sample", dist.Bernoulli(torch.tensor(0.5)))
...             assert sample.shape == torch.Size((100, 3))
...     return sample
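The shape rule behind automatic broadcasting can be sketched in plain Python. This assumes each plate frame is a (dim, size) pair with a concrete negative dim and that the distribution's existing batch shape is right-aligned against the plate dims; toy_broadcast_shape is a hypothetical helper, not part of Pyro, and it ignores plates without a dim.

```python
from typing import List, Optional, Sequence, Tuple

# Each plate frame is (dim, size) with a negative dim, as in the examples above.
Frame = Tuple[int, int]


def toy_broadcast_shape(batch_shape: Sequence[int], frames: Sequence[Frame]) -> Tuple[int, ...]:
    """Compute the batch shape a distribution is expanded to inside the given plates."""
    rank = max([len(batch_shape)] + [-dim for dim, _ in frames])
    target: List[Optional[int]] = [None] * rank
    for dim, size in frames:
        target[dim] = size
    # Right-align the existing batch shape; size-1 dims broadcast, others must agree.
    padded = [1] * (rank - len(batch_shape)) + list(batch_shape)
    for i, actual in enumerate(padded):
        if target[i] is None:
            target[i] = actual
        elif actual not in (1, target[i]):
            raise ValueError(f"batch dim {i - rank} has size {actual}, plate needs {target[i]}")
    return tuple(target)


frames = [(-2, 100), (-1, 3)]
print(toy_broadcast_shape((), frames))    # (100, 3)
print(toy_broadcast_shape((3,), frames))  # (100, 3)
```

The first call corresponds to dist.Bernoulli(torch.tensor(0.5)) in model_automatic_broadcast, and the second to a distribution that already carries the size-3 component dim.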

CollapseMessenger

class CollapseMessenger(*args: Any, **kwargs: Any)

Bases: pyro.poutine.trace_messenger.TraceMessenger

EXPERIMENTAL Collapses all sites in the context by lazily sampling and attempting to use conjugacy relations. If no conjugacy is known this will fail. Code using the results of sample sites must be written to accept Funsors rather than Tensors. This requires funsor to be installed.

Warning

This is not compatible with automatic guessing of max_plate_nesting. If any plates appear within the collapsed context, you should manually declare max_plate_nesting to your inference algorithm (e.g. Trace_ELBO(max_plate_nesting=1)).

ConditionMessenger

class ConditionMessenger(data: Union[Dict[str, torch.Tensor], pyro.poutine.trace_struct.Trace])

Bases: pyro.poutine.messenger.Messenger

Given a stochastic function with some sample statements and a dictionary of observations at names, change the sample statements at those names into observes with those values.

Consider the following Pyro program:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

To observe a value for site z, we can write

conditioned_model = pyro.poutine.condition(model, data={"z": torch.tensor(1.)})

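This is equivalent to adding obs=value as a keyword argument to pyro.sample("z", …) in model.

Parameters

fn – a stochastic function (callable containing Pyro primitive calls)

data – a dict mapping site names to observed values, or a Trace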

Returns

stochastic function decorated with a ConditionMessenger

DoMessenger

class DoMessenger(data: Dict[str, Union[torch.Tensor, numbers.Number]])

Bases: pyro.poutine.messenger.Messenger

Given a stochastic function with some sample statements and a dictionary of values at names, set the return values of those sites equal to the values as if they were hard-coded to those values and introduce fresh sample sites with the same names whose values do not propagate.

Composes freely with condition() to represent counterfactual distributions over potential outcomes. See Single World Intervention Graphs [1] for additional details and theory.

Consider the following Pyro program:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

To intervene with a value for site z, we can write

intervened_model = pyro.poutine.do(model, data={"z": torch.tensor(1.)})

This is equivalent to replacing z = pyro.sample("z", …) with z = torch.tensor(1.) and introducing a fresh sample site pyro.sample("z", …) whose value is not used elsewhere.

References

[1] Thomas Richardson, James Robins. “Single World Intervention Graphs: A Primer”.

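Parameters

fn – a stochastic function (callable containing Pyro primitive calls)

data – a dict mapping sample site names to interventions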

Returns

stochastic function decorated with a DoMessenger
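The difference between condition and do can be illustrated with a toy model that follows the prose description above rather than Pyro's internal bookkeeping (for example, how Pyro names the intervened site in a trace may differ). toy_model and its mode and data arguments are hypothetical: under "condition" the site is recorded as observed with the given value, while under "do" the site is still sampled freshly but downstream code sees the intervened value.

```python
import random
from typing import Dict, List, Optional, Tuple

# A toy trace entry: (site name, recorded value, is_observed).
Site = Tuple[str, float, bool]


def toy_model(rng: random.Random, mode: Optional[str] = None,
              data: Optional[Dict[str, float]] = None) -> Tuple[float, List[Site]]:
    """Hypothetical model z ~ Normal(0, 1) returning z ** 2 and its recorded sites."""
    data = data or {}
    sites: List[Site] = []
    if mode == "condition" and "z" in data:
        z = data["z"]                  # observed: recorded with the given value
        sites.append(("z", z, True))
    elif mode == "do" and "z" in data:
        fresh = rng.gauss(0.0, 1.0)    # intervened: the site is still sampled...
        sites.append(("z", fresh, False))
        z = data["z"]                  # ...but downstream code sees the intervention
    else:
        z = rng.gauss(0.0, 1.0)
        sites.append(("z", z, False))
    return z ** 2, sites


rng = random.Random(0)
ret_cond, sites_cond = toy_model(rng, "condition", {"z": 1.0})
ret_do, sites_do = toy_model(rng, "do", {"z": 1.0})
print(ret_cond, sites_cond)  # 1.0 [('z', 1.0, True)]
print(ret_do)                # 1.0
```

Both variants return 1.0, but only the conditioned run records z as observed at 1.0; the intervened run records an unrelated random draw that nothing downstream uses.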

EnumMessenger

class EnumMessenger(first_available_dim: Optional[int] = None)

Bases: pyro.poutine.messenger.Messenger

Enumerates in parallel over discrete sample sites marked infer={"enumerate": "parallel"}.

Parameters

first_available_dim (int) – The first tensor dimension (counting from the right) that is available for parallel enumeration. This dimension and all dimensions left may be used internally by Pyro. This should be a negative integer or None.
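A sketch of how parallel enumeration lays out tensor dims, assuming each enumerated site receives a fresh dim to the left of the previous one, starting at first_available_dim, and that its support values lie along that dim with size-1 dims to its right. toy_enum_shapes is hypothetical and ignores dim reuse (for example under markov).

```python
from typing import List, Tuple


def toy_enum_shapes(first_available_dim: int, support_sizes: List[int]) -> List[Tuple[int, ...]]:
    """Shape of each site's enumerated values when every site gets a fresh dim."""
    shapes = []
    for k, size in enumerate(support_sizes):
        dim = first_available_dim - k  # e.g. -2, -3, -4, ... for first_available_dim=-2
        shapes.append((size,) + (1,) * (-dim - 1))
    return shapes


# One plate on the right (dim -1) leaves dim -2 and everything left of it free.
print(toy_enum_shapes(-2, [2, 3]))  # [(2, 1), (3, 1, 1)]
```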

enumerate_site(msg: pyro.poutine.runtime.Message) → torch.Tensor

EscapeMessenger

class EscapeMessenger(escape_fn: Callable[[pyro.poutine.runtime.Message], bool])

Bases: pyro.poutine.messenger.Messenger

Messenger that does a nonlocal exit by raising a util.NonlocalExit exception at sites for which escape_fn(msg) is True.
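The nonlocal-exit control flow can be mimicked with an ordinary exception, assuming sites are plain dicts and escape_fn is a predicate in the spirit of discrete_escape. ToyNonlocalExit, toy_sample and toy_model are hypothetical names; the exception carries the triggering site so the caller can inspect it.

```python
from typing import Any, Callable, Dict, Optional

Message = Dict[str, Any]


class ToyNonlocalExit(Exception):
    """Hypothetical stand-in for NonlocalExit that carries the triggering site."""

    def __init__(self, site: Message) -> None:
        super().__init__(site["name"])
        self.site = site


def toy_sample(msg: Message, escape_fn: Callable[[Message], bool]) -> Any:
    # The escape check runs before the site executes and may abort the whole program.
    if escape_fn(msg):
        raise ToyNonlocalExit(msg)
    return msg["fn"]()


def toy_model(escape_fn: Callable[[Message], bool]) -> int:
    a = toy_sample({"name": "a", "fn": lambda: 1, "is_discrete": False}, escape_fn)
    b = toy_sample({"name": "b", "fn": lambda: 2, "is_discrete": True}, escape_fn)
    return a + b


escaped_at: Optional[str] = None
try:
    toy_model(escape_fn=lambda msg: msg["is_discrete"])  # escape at discrete sites
except ToyNonlocalExit as exc:
    escaped_at = exc.site["name"]

print(escaped_at)                              # b
print(toy_model(escape_fn=lambda msg: False))  # 3
```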

IndepMessenger

class CondIndepStackFrame(name, dim, size, counter, full_size)

Bases: tuple

counter: int

Alias for field number 3

dim: Optional[int]

Alias for field number 1

full_size: Optional[int]

Alias for field number 4

name: str

Alias for field number 0

size: int

Alias for field number 2

property vectorized: bool
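The "Alias for field number N" entries mean each field can be read either by name or by tuple position. The stand-in below copies the field order listed above; ToyFrame is hypothetical, and defining vectorized as "the frame has a tensor dim" is an assumption not stated on this page.

```python
from typing import NamedTuple, Optional


class ToyFrame(NamedTuple):
    """Stand-in with the same field order as CondIndepStackFrame (illustration only)."""
    name: str
    dim: Optional[int]
    size: int
    counter: int
    full_size: Optional[int]

    @property
    def vectorized(self) -> bool:
        # Assumption: a frame is vectorized when it occupies a tensor dim.
        return self.dim is not None


frame = ToyFrame(name="data", dim=-1, size=10, counter=0, full_size=100)
print(frame[2], frame.size)  # 10 10  (field number 2 is size)
print(frame.vectorized)      # True
print(ToyFrame("seq", None, 5, 0, None).vectorized)  # False
```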

class IndepMessenger(name: str, size: int, dim: Optional[int] = None, device: Optional[str] = None)

Bases: pyro.poutine.messenger.Messenger

This messenger keeps track of the stack of independence information declared by nested plate contexts. This information is stored in a cond_indep_stack at each sample/observe site for consumption by TraceMessenger.

Example:

    import torch
    import pyro.distributions as dist
    from pyro import sample
    from pyro.poutine.indep_messenger import IndepMessenger

    loc, scale = torch.tensor(0.), torch.tensor(1.)
    x_axis = IndepMessenger('outer', 320, dim=-1)
    y_axis = IndepMessenger('inner', 200, dim=-2)
    with x_axis:
        x_noise = sample("x_noise", dist.Normal(loc, scale).expand_by([320]))
    with y_axis:
        y_noise = sample("y_noise", dist.Normal(loc, scale).expand_by([200, 1]))
    with x_axis, y_axis:
        xy_noise = sample("xy_noise", dist.Normal(loc, scale).expand_by([200, 320]))

property indices: torch.Tensor

next_context() → None

Increments the counter.

InferConfigMessenger

class InferConfigMessenger(config_fn: Callable[[Message], InferDict])

Bases: pyro.poutine.messenger.Messenger

Given a callable fn that contains Pyro primitive calls and a callable config_fn taking a trace site and returning a dictionary, updates the value of the infer kwarg at a sample site to config_fn(site).

Parameters

fn – a stochastic function (callable containing Pyro primitive calls)

config_fn – a callable taking a trace site and returning an infer dict

Returns

stochastic function decorated with InferConfigMessenger

LiftMessenger

class LiftMessenger(prior: Union[Callable, pyro.distributions.distribution.Distribution, Dict[str, Union[pyro.distributions.distribution.Distribution, Callable]]])

Bases: pyro.poutine.messenger.Messenger

Given a stochastic function with param calls and a prior distribution, create a stochastic function where all param calls are replaced by sampling from prior. Prior should be a callable or a dict of names to callables.

Consider the following Pyro program:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
>>> lifted_model = pyro.poutine.lift(model, prior={"s": dist.Exponential(0.3)})

lift makes param statements behave like sample statements using the distributions in prior. In this example, site s will now behave as if it was replaced with s = pyro.sample("s", dist.Exponential(0.3)):

>>> tr = pyro.poutine.trace(lifted_model).get_trace(0.0)
>>> tr.nodes["s"]["type"] == "sample"
True
>>> tr2 = pyro.poutine.trace(lifted_model).get_trace(0.0)
>>> bool((tr2.nodes["s"]["value"] == tr.nodes["s"]["value"]).all())
False

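Parameters

fn – a stochastic function (callable containing Pyro primitive calls)

prior – a distribution or callable, or a dict mapping param site names to distributions or callables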

Returns

fn decorated with a LiftMessenger

MarkovMessenger

class MarkovMessenger(history: int = 1, keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None)

Bases: pyro.poutine.reentrant_messenger.ReentrantMessenger

Markov dependency declaration.

This is a statistical equivalent of a memory management arena.
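One reading of the arena analogy, stated here as an assumption rather than a description of Pyro's actual allocator: inside a markov context, enumerated variables more than history steps apart do not depend on each other directly, so they can share tensor dims, and a chain needs only history + 1 distinct enumeration dims. The hypothetical helper toy_markov_dims assigns dims round-robin under that assumption, starting from an assumed first_available_dim of -2.

```python
from typing import List


def toy_markov_dims(num_steps: int, history: int = 1, first_available_dim: int = -2) -> List[int]:
    """Assign each time step's enumerated variable a dim, recycling after history + 1 steps."""
    return [first_available_dim - (t % (history + 1)) for t in range(num_steps)]


print(toy_markov_dims(5))             # [-2, -3, -2, -3, -2]
print(toy_markov_dims(5, history=2))  # [-2, -3, -4, -2, -3]
```

Without recycling, the same five-step chain would need five distinct dims, one per step.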

Parameters

generator(iterable: Iterable[int]) → typing_extensions.Self

MaskMessenger

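class MaskMessenger(mask: Union[bool, torch.BoolTensor])

Bases: pyro.poutine.messenger.Messenger

Given a stochastic function with some batched sample statements and a masking tensor, mask out some of the sample statements elementwise.

Parameters

fn – a stochastic function (callable containing Pyro primitive calls)

mask (bool or torch.BoolTensor) – a boolean or a boolean-valued tensor; True keeps a site (or element) in the log density, False masks it out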

Returns

stochastic function decorated with a MaskMessenger
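A scalar sketch of elementwise masking, assuming the log-probabilities of a batched site are a flat list of floats and the mask a same-length list of bools. toy_masked_log_prob is hypothetical; masked-out entries simply drop out of the score, and sampling itself is not modeled here.

```python
from typing import Sequence


def toy_masked_log_prob(log_probs: Sequence[float], mask: Sequence[bool]) -> float:
    """Sum elementwise log-probabilities, dropping entries where mask is False."""
    if len(log_probs) != len(mask):
        raise ValueError("mask must match the batch of log-probabilities")
    return sum(lp for lp, keep in zip(log_probs, mask) if keep)


log_probs = [-1.0, -2.5, -0.5]
print(toy_masked_log_prob(log_probs, [True, False, True]))  # -1.5
```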

PlateMessenger

class PlateMessenger(name: str, size: Optional[int] = None, subsample_size: Optional[int] = None, subsample: Optional[torch.Tensor] = None, dim: Optional[int] = None, use_cuda: Optional[bool] = None, device: Optional[str] = None)

Bases: pyro.poutine.subsample_messenger.SubsampleMessenger

Swiss army knife of broadcasting amazingness: combines shape inference, independence annotation, and subsampling.

block_plate(name: Optional[str] = None, dim: Optional[int] = None, *, strict: bool = True) → Iterator[None]

EXPERIMENTAL Context manager to temporarily block a single enclosing plate.

This is useful for sampling auxiliary variables or lazily sampling global variables that are needed in a plated context. For example, the following models are equivalent:

Example:

    def model_1(data):
        loc = pyro.sample("loc", dist.Normal(0, 1))
        with pyro.plate("data", len(data)):
            with block_plate("data"):
                scale = pyro.sample("scale", dist.LogNormal(0, 1))
            pyro.sample("x", dist.Normal(loc, scale))

    def model_2(data):
        loc = pyro.sample("loc", dist.Normal(0, 1))
        scale = pyro.sample("scale", dist.LogNormal(0, 1))
        with pyro.plate("data", len(data)):
            pyro.sample("x", dist.Normal(loc, scale))

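Parameters

name (str) – optional name of the plate to block

dim (int) – optional dim of the plate to block

strict (bool) – whether to raise an error if no matching plate is found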

Raises

ValueError if no enclosing plate was found and strict=True.

ReentrantMessenger

class ReentrantMessenger

Bases: pyro.poutine.messenger.Messenger

ReparamMessenger

class ReparamHandler(msngr, fn: Callable[[pyro.poutine.reparam_messenger._P], pyro.poutine.reparam_messenger._T])

Bases: Generic[pyro.poutine.reparam_messenger._P, pyro.poutine.reparam_messenger._T]

Reparameterization poutine.

class ReparamMessenger(config: Union[Dict[str, Reparam], Callable[[Message], Optional[Reparam]]])

Bases: pyro.poutine.messenger.Messenger

Reparametrizes each affected sample site into one or more auxiliary sample sites followed by a deterministic transformation [1].

To specify reparameterizers, pass a config dict or callable to the constructor. See the pyro.infer.reparam module for available reparameterizers.

Note some reparameterizers can examine the *args, **kwargs inputs of functions they affect; these reparameterizers require using poutine.reparam as a decorator rather than as a context manager.

[1] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019). “Automatic Reparameterisation of Probabilistic Programs”. https://arxiv.org/pdf/1906.03028.pdf

Parameters

config (dict or callable) – Configuration, either a dict mapping site name to Reparameterizer, or a function mapping site to Reparam or None. See pyro.infer.reparam.strategies for built-in configuration strategies.

ReplayMessenger

class ReplayMessenger(trace: Optional[Trace] = None, params: Optional[Dict[str, torch.Tensor]] = None)

Bases: pyro.poutine.messenger.Messenger

Given a callable that contains Pyro primitive calls, return a callable that runs the original but, at every site that also appears in trace, reuses the value recorded there.

Consider the following Pyro program:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

replay makes sample statements behave as if they had sampled the values at the corresponding sites in the trace:

>>> old_trace = pyro.poutine.trace(model).get_trace(1.0)
>>> replayed_model = pyro.poutine.replay(model, trace=old_trace)
>>> bool(replayed_model(0.0) == old_trace.nodes["_RETURN"]["value"])
True

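Parameters

fn – a stochastic function (callable containing Pyro primitive calls)

trace – a Trace data structure to replay against

params – a dict mapping param site names to constrained values to replay against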

Returns

a stochastic function decorated with a ReplayMessenger

ScaleMessenger

class ScaleMessenger(scale: Union[float, torch.Tensor])

Bases: pyro.poutine.messenger.Messenger

Given a stochastic function with some sample statements and a positive scale factor, scale the score of all sample and observe sites in the function.

Consider the following Pyro program:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     pyro.sample("z", dist.Normal(x, s), obs=torch.tensor(1.0))

scale multiplicatively scales the log-probabilities of sample sites:

>>> scaled_model = pyro.poutine.scale(model, scale=0.5)
>>> scaled_tr = pyro.poutine.trace(scaled_model).get_trace(0.0)
>>> unscaled_tr = pyro.poutine.trace(model).get_trace(0.0)
>>> bool((scaled_tr.log_prob_sum() == 0.5 * unscaled_tr.log_prob_sum()).all())
True

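Parameters

fn – a stochastic function (callable containing Pyro primitive calls)

scale (float or torch.Tensor) – a positive scaling factor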

Returns

stochastic function decorated with a ScaleMessenger
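A scalar sketch of the scaling rule, assuming each enclosing scale handler contributes one positive factor and that nested factors multiply together. toy_scaled_log_prob is a hypothetical helper that works on Python floats rather than on Pyro traces.

```python
from typing import Sequence


def toy_scaled_log_prob(log_prob: float, scales: Sequence[float]) -> float:
    """Apply a stack of positive scale factors to one site's log-probability."""
    total = 1.0
    for s in scales:
        if s <= 0:
            raise ValueError("scale factors must be positive")
        total *= s  # nested scale handlers multiply together
    return total * log_prob


print(toy_scaled_log_prob(-2.0, [0.5]))       # -1.0
print(toy_scaled_log_prob(-2.0, [0.5, 4.0]))  # -4.0
```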

SeedMessenger

class SeedMessenger(rng_seed: int)

Bases: pyro.poutine.messenger.Messenger

Handler to set the random number generator to a pre-defined state by setting its seed. This is the same as calling pyro.set_rng_seed() before the call to fn. This handler has no additional effect on primitive statements on the standard Pyro backend, but it might intercept pyro.sample calls in other backends, e.g. the NumPy backend.

Parameters

fn – a stochastic function (callable containing Pyro primitive calls)

rng_seed (int) – the random seed to set before calling fn

SubsampleMessenger

class SubsampleMessenger(name: str, size: Optional[int] = None, subsample_size: Optional[int] = None, subsample: Optional[torch.Tensor] = None, dim: Optional[int] = None, use_cuda: Optional[bool] = None, device: Optional[str] = None)

Bases: pyro.poutine.indep_messenger.IndepMessenger

Extension of IndepMessenger that includes subsampling.
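A plain-Python sketch of subsampling, assuming the usual minibatch convention that the log-likelihood of a subsample is rescaled by size / subsample_size so that it estimates the full-data sum without bias. toy_subsample is hypothetical; it draws indices with random.sample and sorts them for readability, which a real implementation need not do.

```python
import random
from typing import List, Optional, Tuple


def toy_subsample(size: int, subsample_size: Optional[int],
                  rng: random.Random) -> Tuple[List[int], float]:
    """Pick a random index subset and the factor that rescales its log-likelihood."""
    if subsample_size is None or subsample_size >= size:
        return list(range(size)), 1.0
    indices = sorted(rng.sample(range(size), subsample_size))
    return indices, size / subsample_size


indices, scale = toy_subsample(size=1000, subsample_size=10, rng=random.Random(0))
print(len(indices), scale)  # 10 100.0
```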

SubstituteMessenger

class SubstituteMessenger(data: Dict[str, torch.Tensor])

Bases: pyro.poutine.messenger.Messenger

Given a stochastic function with param calls and a set of parameter values, create a stochastic function where all param calls are substituted with the fixed values. data should be a dict of names to values. Consider the following Pyro program:

>>> def model(x):
...     a = pyro.param("a", torch.tensor(0.5))
...     x = pyro.sample("x", dist.Bernoulli(probs=a))
...     return x
>>> substituted_model = pyro.poutine.substitute(model, data={"a": torch.tensor(0.3)})

In this example, site a will now have value torch.tensor(0.3).

Parameters

data – dictionary of values keyed by site names

Returns

fn decorated with a SubstituteMessenger

TraceMessenger

class TraceHandler(msngr: pyro.poutine.trace_messenger.TraceMessenger, fn: Callable[[pyro.poutine.trace_messenger._P], pyro.poutine.trace_messenger._T])

Bases: Generic[pyro.poutine.trace_messenger._P, pyro.poutine.trace_messenger._T]

Execution trace poutine.

A TraceHandler records the input and output to every Pyro primitive and stores them as a site in a Trace(). This should, in theory, be sufficient information for every inference algorithm (along with the implicit computational graph in the Variables?)

We can also use this for visualization.

get_trace(*args, **kwargs) → pyro.poutine.trace_struct.Trace

Returns

data structure

Return type

pyro.poutine.Trace

Helper method for a very common use case. Calls this poutine and returns its trace instead of the function’s return value.

property trace: pyro.poutine.trace_struct.Trace

class TraceMessenger(graph_type: Optional[Literal['flat', 'dense']] = None, param_only: Optional[bool] = None)

Bases: pyro.poutine.messenger.Messenger

Return a handler that records the inputs and outputs of primitive calls and their dependencies.

Consider the following Pyro program:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

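We can record its execution using trace and use the resulting data structure to compute the log-joint probability of all of the sample sites in the execution or extract all parameters.

>>> trace = pyro.poutine.trace(model).get_trace(0.0)
>>> logp = trace.log_prob_sum()
>>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]

Parameters

fn – a stochastic function (callable containing Pyro primitive calls)

graph_type (str) – the kind of graph to construct, "flat" or "dense"

param_only (bool) – if True, only records params and not samples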

Returns

stochastic function decorated with a TraceMessenger

get_trace() → pyro.poutine.trace_struct.Trace

Returns

data structure

Return type

pyro.poutine.Trace

Helper method for a very common use case. Returns a shallow copy of self.trace.

identify_dense_edges(trace: pyro.poutine.trace_struct.Trace) → None

Modifies a trace in-place by adding all edges based on the cond_indep_stack information stored at each site.

UnconditionMessenger

class UnconditionMessenger

Bases: pyro.poutine.messenger.Messenger

Messenger to force the value of observed nodes to be sampled from their distribution, ignoring observations.

GuideMessenger

class GuideMessenger(model: Callable)

Bases: pyro.poutine.trace_messenger.TraceMessenger, abc.ABC

Abstract base class for effect-based guides.

Derived classes must implement the get_posterior() method.

property model: Callable

__call__(*args, **kwargs) → Dict[str, torch.Tensor]

Draws posterior samples from the guide and replays the model against those samples.

Returns

A dict mapping sample site name to sample value. This includes latent, deterministic, and observed values.

Return type

dict

abstract get_posterior(name: str, prior: TorchDistributionMixin) → Union[TorchDistributionMixin, torch.Tensor]

Abstract method to compute a posterior distribution or sample a posterior value given a prior distribution conditioned on upstream posterior samples.

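Implementations may use pyro.param and pyro.sample inside this function, but pyro.sample statements should set infer={"is_auxiliary": True}.

Implementations may access further information for computations, for example self.upstream_value(name), which returns the value of an upstream sample or deterministic site, and self.trace, the trace of upstream sites.

Parameters

name (str) – the name of the sample site to sample

prior (Distribution) – the prior distribution of this sample site, conditioned on upstream samples from the posterior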

Returns

A posterior distribution or sample from the posterior distribution.

Return type

Distribution or torch.Tensor

upstream_value(name: str) → Optional[torch.Tensor]

For use in get_posterior().

Returns

The value of an upstream sample or deterministic site

Return type

torch.Tensor

get_traces() → Tuple[pyro.poutine.trace_struct.Trace, pyro.poutine.trace_struct.Trace]

This can be called after running __call__() to extract a pair of traces.

In contrast to the trace-replay pattern of generating a pair of traces, GuideMessenger interleaves model and guide computations, so only a single guide(*args, **kwargs) call is needed to create both traces. This function merely extracts the relevant information from this guide’s .trace attribute.

Returns

a pair (model_trace, guide_trace)

Return type

tuple