Easy Custom Guides — Pyro documentation
EasyGuide
class EasyGuide(model)[source]
Bases: pyro.nn.module.PyroModule
Base class for “easy guides”, which are more flexible than AutoGuides but easier to write than raw Pyro guides.
Derived classes should define a guide() method. This guide() method can combine ordinary guide statements (e.g. pyro.sample and pyro.param) with the following special statements:
- group = self.group(...) selects multiple pyro.sample sites in the model. See Group for subsequent methods.
- with self.plate(...): ... should be used instead of pyro.plate.
- self.map_estimate(...) uses a Delta guide for a single site.
Derived classes may also override the init() method to provide custom initialization for model sites.
Parameters
model (callable) – A Pyro model.
property model
abstract guide(*args, **kargs)[source]
Guide implementation, to be overridden by user.
init(site)[source]
Model initialization method, may be overridden by user.
This should input a site and output a valid sample from that site. The default behavior is to draw a random sample:
    return site["fn"]()
For other possible initialization functions see http://docs.pyro.ai/en/stable/infer.autoguide.html#module-pyro.infer.autoguide.initialization
forward(*args, **kwargs)[source]
Runs the guide. This is typically used by inference algorithms.
Note
This method is used internally by Module. Users should instead use __call__().
plate(name, size=None, subsample_size=None, subsample=None, *args, **kwargs)[source]
A wrapper around pyro.plate to allow EasyGuide to automatically construct plates. You should use this rather than pyro.plate inside your guide() implementation.
group(match='.*')[source]
Select a Group of model sites for joint guidance.
Parameters
match (str) – A regex string matching names of model sample sites.
Returns
A group of model sites.
Return type
Group
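The effect of the match pattern can be previewed without running a model. The sketch below uses only the standard library and assumes that site names are tested with re.match, so the pattern is anchored at the start of the name but not at the end; that assumption concerns Pyro's internals and is worth checking against the source. The site names and the select_sites helper are hypothetical.

```python
import re

# Hypothetical site names, as they might appear in a model trace.
site_names = ["drift", "state_0", "state_1", "state_10", "obs_noise"]


def select_sites(pattern, names):
    """Return the names that `pattern` matches, anchored at the start of each name.

    Assumption: this mirrors prefix matching via re.match; it is not Pyro code.
    """
    return [name for name in names if re.match(pattern, name)]


# A pattern targeting the numbered local states.
print(select_sites("state_[0-9]+", site_names))
# A bare prefix selects the same sites, because the end is not anchored.
print(select_sites("state", site_names))
# A pattern that only occurs mid-name selects nothing under prefix matching.
print(select_sites("noise", site_names))
```

Under this assumption a pattern such as "noise" selects no sites even though "obs_noise" contains it, so patterns are best written from the start of the site name.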
map_estimate(name)[source]
Construct a maximum a posteriori (MAP) guide using Delta distributions.
Parameters
name (str) – The name of a model sample site.
Returns
A sampled value.
Return type
torch.Tensor
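A Delta guide places all of its mass on a single learnable point, so fitting it amounts to maximizing the log joint density at that point. The library-free sketch below illustrates this for a toy Normal prior and Normal likelihood. It is not how map_estimate is implemented, and every number and helper name in it is made up for illustration.

```python
# Toy problem (all values hypothetical): z ~ Normal(mu0, s0), x_i ~ Normal(z, s).
mu0, s0, s = 0.0, 10.0, 1.0
data = [1.2, 0.8, 1.1, 0.9, 1.0]


def grad_log_joint(z):
    """Derivative of log p(z) + sum_i log p(x_i | z) with respect to z."""
    return -(z - mu0) / s0**2 + sum(x - z for x in data) / s**2


def fit_point_estimate(z=0.0, lr=0.1, steps=200):
    """Gradient ascent on the log joint; the result is the MAP point."""
    for _ in range(steps):
        z += lr * grad_log_joint(z)
    return z


z_hat = fit_point_estimate()

# Closed-form posterior mode for this conjugate model, used as a check.
closed_form = (mu0 / s0**2 + sum(data) / s**2) / (1 / s0**2 + len(data) / s**2)
print(z_hat, closed_form)
```

The point estimate agrees with the closed-form posterior mode, which is the value a Delta-based guide would converge toward on this toy model.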
easy_guide
easy_guide(model)[source]
Convenience decorator to create an EasyGuide. The following are equivalent:
Version 1. Decorate a function.
    @easy_guide(model)
    def guide(self, foo, bar):
        return my_guide(foo, bar)
Version 2. Create and instantiate a subclass of EasyGuide.
    class Guide(EasyGuide):
        def guide(self, foo, bar):
            return my_guide(foo, bar)

    guide = Guide(model)
Note
@easy_guide wrappers cannot be pickled; to build a guide that can be pickled, instead subclass from EasyGuide.
Parameters
model (callable) – a Pyro model.
Group
class Group(guide, sites)[source]
Bases: object
An autoguide helper to match a group of model sites.
Variables
- event_shape (torch.Size) – The total flattened concatenated shape of all matching sample sites in the model.
- prototype_sites (list) – A list of all matching sample sites in a prototype trace of the model.
Parameters
- guide (EasyGuide) – An easyguide instance.
- sites (list) – A list of model sites.
property guide
sample(guide_name, fn, infer=None)[source]
Wrapper around pyro.sample() to create a single auxiliary sample site and then unpack to multiple sample sites for model replay.
Parameters
- guide_name (str) – The name of the auxiliary guide site.
- fn (callable) – A distribution with shape self.event_shape.
- infer (dict) – Optional inference configuration dict.
Returns
A pair (guide_z, model_zs) where guide_z is the single concatenated blob and model_zs is a dict mapping site name to constrained model sample.
Return type
tuple
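The relationship between the flattened event_shape and the unpacking step can be sketched without Pyro. The example below uses plain Python lists and hypothetical site shapes to show how several sites might be concatenated into one flat vector and then split back per site. It deliberately omits the transform from unconstrained to constrained space that the real method applies, and the helper names are illustrative only.

```python
from math import prod

# Hypothetical per-site shapes for three model sites (illustrative only).
site_shapes = {"loc": (), "weights": (3,), "cov": (2, 2)}


def flat_size(shape):
    """Number of elements in a site; prod(()) is 1, so a scalar counts as 1."""
    return prod(shape)


def group_event_shape(shapes):
    """Total flattened, concatenated shape across all sites."""
    return (sum(flat_size(shape) for shape in shapes.values()),)


def reshape(values, shape):
    """Rebuild a nested list of the given shape from a flat list."""
    if not shape:
        return values[0]
    if len(shape) == 1:
        return list(values)
    step = prod(shape[1:])
    return [reshape(values[i * step:(i + 1) * step], shape[1:])
            for i in range(shape[0])]


def unpack(flat, shapes):
    """Split one flat blob into a dict mapping site name to its reshaped value."""
    out, pos = {}, 0
    for name, shape in shapes.items():
        n = flat_size(shape)
        out[name] = reshape(flat[pos:pos + n], shape)
        pos += n
    return out


event_shape = group_event_shape(site_shapes)
guide_z = list(range(event_shape[0]))  # stand-in for the single auxiliary sample
model_zs = unpack(guide_z, site_shapes)
print(event_shape, model_zs)
```

In this sketch the distribution passed as fn would need an event of size 8, which is the role self.event_shape plays when constructing the joint distribution for a group.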
map_estimate()[source]
Construct a maximum a posteriori (MAP) guide using Delta distributions.
Returns
A dict mapping model site name to sampled value.
Return type
dict