Dispatch (Pyro API 0.0 documentation)

Pyro API

Dispatching lets you set a backend dynamically with pyro_backend() and register new backends with register_backend(). The easiest way to see how these work is by example:

```python
from pyroapi import distributions as dist
from pyroapi import infer, ops, optim, pyro, pyro_backend
```

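The following model and guide are backend-agnostic: they refer only to the generic pyro, dist, and ops modules imported above, never to a concrete library.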

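```python
def model():
    # Assumes a module-level 1-D tensor `data` of observations.
    locs = pyro.param("locs", ops.tensor([0.2, 0.3, 0.5]))
    p = ops.tensor([0.2, 0.3, 0.5])
    with pyro.plate("plate", len(data), dim=-1):
        # One discrete component assignment per observation.
        x = pyro.sample("x", dist.Categorical(p))
        pyro.sample("obs", dist.Normal(locs[x], 1.), obs=data)
```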

```python
def guide():
    # Learnable weights for the variational distribution over x.
    p = pyro.param("p", ops.tensor([0.5, 0.3, 0.2]))
    with pyro.plate("plate", len(data), dim=-1):
        pyro.sample("x", dist.Categorical(p))
```

We can now set a backend at inference time.

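```python
with pyro_backend("numpy"):
    # Hypothetical toy observations; substitute your own 1-D data.
    data = ops.tensor([0.1, 0.4, 0.2, 0.6, 0.3])

    elbo = infer.Trace_ELBO(ignore_jit_warnings=True)
    adam = optim.Adam({"lr": 1e-6})
    inference = infer.SVI(model, guide, adam, elbo)
    for step in range(10):
        # model() and guide() take no arguments; they read the global `data`.
        loss = inference.step()
        print("step {} loss = {}".format(step, loss))
```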

pyro_backend(*aliases, **new_backends)

Context manager to set a custom backend for Pyro models.

Backends can be specified either by name (for standard backends or backends registered through register_backend()) or by passing keyword arguments that map a generic module name, such as infer or ops, to the name of the module to use in its place. Standard backends include: pyro, minipyro, funsor, and numpy.

register_backend(alias, new_backends)

Register a new backend alias. For example:

```python
from pyroapi import register_backend

register_backend("minipyro", {
    "infer": "pyro.contrib.minipyro",
    "optim": "pyro.contrib.minipyro",
    "pyro": "pyro.contrib.minipyro",
})
```

Parameters:
    alias (str) – The name of the new backend.
    new_backends (dict) – A dict mapping each standard module name (str) to the name of the module that replaces it (str). It needs to include only the modules your backend overrides; for example, if your backend uses torch ops, you need not override ops.

Generic Modules