Easy Custom Guides (Pyro documentation)

EasyGuide

class EasyGuide(model)

Bases: pyro.nn.module.PyroModule

Base class for “easy guides”, which are more flexible than AutoGuides but easier to write than raw Pyro guides.

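Derived classes should define a guide() method. This guide() method can combine ordinary guide statements (e.g. pyro.sample and pyro.param) with the following special statements:

- group = self.group(...) selects multiple pyro.sample sites in the model. See Group for subsequent methods.
- with self.plate(...): ... should be used instead of pyro.plate.
- self.map_estimate(...) uses a Delta guide for a single site.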

Derived classes may also override the init() method to provide custom initialization for model sites.

Parameters

model (callable) – A Pyro model.

property model

abstract guide(*args, **kargs)

Guide implementation, to be overridden by the user.

init(site)

Model initialization method; may be overridden by the user.

This should take a site as input and return a valid sample from that site. The default behavior is to draw a random sample from the site's distribution, i.e. return site["fn"]().

For other possible initialization functions see http://docs.pyro.ai/en/stable/infer.autoguide.html#module-pyro.infer.autoguide.initialization

forward(*args, **kwargs)

Runs the guide. This is typically used by inference algorithms.

Note

This method is used internally by torch.nn.Module. Users should instead use __call__().

plate(name, size=None, subsample_size=None, subsample=None, *args, **kwargs)

A wrapper around pyro.plate to allow EasyGuide to automatically construct plates. You should use this rather than pyro.plate inside your guide() implementation.

group(match='.*')

Select a Group of model sites for joint guidance.

Parameters

match (str) – A regex string matching names of model sample sites.

Returns

A group of model sites.

Return type

Group

map_estimate(name)

Construct a maximum a posteriori (MAP) guide using Delta distributions.

Parameters

name (str) – The name of a model sample site.

Returns

A sampled value.

Return type

torch.Tensor

training: bool

easy_guide

easy_guide(model)

Convenience decorator to create an EasyGuide. The following are equivalent:

Version 1. Decorate a function.

```python
from pyro.contrib.easyguide import easy_guide

@easy_guide(model)
def guide(self, foo, bar):
    return my_guide(foo, bar)
```

Version 2. Create and instantiate a subclass of EasyGuide.

```python
from pyro.contrib.easyguide import EasyGuide

class Guide(EasyGuide):
    def guide(self, foo, bar):
        return my_guide(foo, bar)

guide = Guide(model)
```

Note

@easy_guide wrappers cannot be pickled; to build a guide that can be pickled, subclass EasyGuide instead.

Parameters

model (callable) – A Pyro model.

Group

class Group(guide, sites)

Bases: object

An autoguide helper to match a group of model sites.

Variables

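Parameters

guide (EasyGuide) – The EasyGuide instance that owns this group.

sites (list) – A list of model sites.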

property guide

sample(guide_name, fn, infer=None)

Wrapper around pyro.sample() to create a single auxiliary sample site and then unpack to multiple sample sites for model replay.

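Parameters

guide_name (str) – The name of the single auxiliary sample site created in the guide.

fn (Distribution) – The distribution of that auxiliary site, whose value is the concatenated blob of all sites in the group.

infer (dict) – Optional inference configuration for the auxiliary site, as in pyro.sample().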

Returns

A pair (guide_z, model_zs) where guide_z is the single concatenated blob and model_zs is a dict mapping site name to constrained model sample.

Return type

tuple

map_estimate()

Construct a maximum a posteriori (MAP) guide using Delta distributions.

Returns

A dict mapping model site name to sampled value.

Return type

dict