pymc.sample — PyMC 5.22.0 documentation
pymc.sample(draws=1000, *, tune=1000, chains=None, cores=None, random_seed=None, progressbar=True, progressbar_theme=None, step=None, var_names=None, nuts_sampler='pymc', initvals=None, init='auto', jitter_max_retries=10, n_init=200000, trace=None, discard_tuned_samples=True, compute_convergence_checks=True, keep_warning_stat=False, return_inferencedata=True, idata_kwargs=None, nuts_sampler_kwargs=None, callback=None, mp_ctx=None, blas_cores='auto', model=None, compile_kwargs=None, **kwargs)
Draw samples from the posterior using the given step methods.
Multiple step methods are supported via compound step methods.
Parameters:
draws: int
The number of samples to draw. Defaults to 1000. Tuning samples are discarded by default; see discard_tuned_samples.
tune: int
Number of iterations to tune. Defaults to 1000. Samplers adjust step sizes, scalings, or similar during tuning. Tuning samples are drawn in addition to the number specified in the draws argument, and are discarded unless discard_tuned_samples is set to False.
chains: int
The number of chains to sample. Running independent chains is important for some convergence statistics and can also reveal multiple modes in the posterior. If None, then set to either cores or 2, whichever is larger.
cores: int
The number of chains to run in parallel. If None, set to the number of CPUs in the system, but at most 4.
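The default rules for cores and chains described above can be sketched in a few lines. This is only an illustration of the documented behavior, not PyMC's internal code: the helper name resolve_cores_and_chains is hypothetical, and os.cpu_count() is assumed as the way to count CPUs.

```python
import os


def resolve_cores_and_chains(cores=None, chains=None):
    """Hypothetical helper mirroring the documented defaults."""
    if cores is None:
        # Number of CPUs in the system, but at most 4.
        cores = min(4, os.cpu_count() or 1)
    if chains is None:
        # Either cores or 2, whichever is larger.
        chains = max(2, cores)
    return cores, chains


print(resolve_cores_and_chains())          # depends on the machine
print(resolve_cores_and_chains(cores=1))   # one core still yields two chains
```

Under this reading, a single-core run still samples two chains (one after the other), which keeps multi-chain convergence statistics available.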
random_seed: int, array_like of int, or Generator, optional
Random seed(s) used by the sampling steps. Each step creates its own Generator object to make its random draws in a way that is independent of all other steppers and all other chains. A TypeError is raised if a legacy RandomState object is passed. RandomState objects are no longer supported because their seeding mechanism does not allow easy spawning of the new independent random streams that the step methods need.
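To illustrate what "spawning independent random streams" means, here is a minimal NumPy sketch. It does not show how PyMC derives its generators internally; the helper spawn_generators is hypothetical and only demonstrates the SeedSequence mechanism that Generator objects support and legacy RandomState objects lack.

```python
import numpy as np


def spawn_generators(seed, n_streams):
    """Hypothetical helper: derive independent Generators from one seed."""
    # SeedSequence.spawn returns child sequences with independent streams.
    children = np.random.SeedSequence(seed).spawn(n_streams)
    return [np.random.default_rng(child) for child in children]


# One generator per chain, all derived reproducibly from the same seed.
chain_rngs = spawn_generators(seed=42, n_streams=4)
first_draws = [rng.normal() for rng in chain_rngs]
print(first_draws)
```

Re-running spawn_generators with the same seed reproduces the same streams, while each stream differs from the others.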
progressbar: bool or ProgressType, optional
How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask for one of the following:
- “combined”: A single progress bar that displays the total progress across all chains. Only timing information is shown.
- “split”: A separate progress bar for each chain. Only timing information is shown.
- “combined+stats” or “stats+combined”: A single progress bar displaying the total progress across all chains. Aggregate sample statistics are also displayed.
- “split+stats” or “stats+split”: A separate progress bar for each chain. Sample statistics for each chain are also displayed.
If True, the default “split+stats” is used.
step: function or iterable of functions
A step function or collection of functions. If there are variables without step methods, step methods for those variables will be assigned automatically. By default the NUTS step method will be used, if appropriate to the model.
var_names: list of str, optional
Names of variables to be stored in the trace. Defaults to all free variables and deterministics.
nuts_sampler: str
Which NUTS implementation to run. One of [“pymc”, “nutpie”, “blackjax”, “numpyro”]. This requires the chosen sampler to be installed. All samplers, except “pymc”, require the full model to be continuous.
blas_cores: int or “auto” or None, default = “auto”
The total number of threads blas and openmp functions should use during sampling. Setting it to “auto” will ensure that the total number of active blas threads is the same as the cores argument. If set to an integer, the sampler will try to use that total number of blas threads. If blas_cores is not divisible by cores, it might get rounded down. If set to None, this will keep the default behavior of whatever blas implementation is used at runtime.
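One way to read the rounding rule for blas_cores is as an integer division of the total thread budget across the parallel chains. The sketch below is an assumption about that arithmetic, not PyMC's implementation; the helper blas_threads_per_chain is hypothetical and assumes blas_cores is at least as large as cores.

```python
def blas_threads_per_chain(blas_cores, cores):
    """Hypothetical helper: per-chain BLAS thread budget under one reading."""
    if blas_cores is None:
        # Leave the BLAS implementation's own default untouched.
        return None
    if blas_cores == "auto":
        # Total BLAS threads equal to the cores argument.
        blas_cores = cores
    # Integer division: a non-divisible total is rounded down.
    return blas_cores // cores


for total in ("auto", 8, 6, None):
    print(total, blas_threads_per_chain(total, cores=4))
```

With cores=4, a total of 6 leaves one thread per chain under this reading, so only 4 of the 6 requested threads would be used.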
initvals: optional, dict, array of dict
Dict or list of dicts with initial value strategies to use instead of the defaults from Model.initial_values. The keys should be names of transformed random variables. Initialization methods for NUTS (see init keyword) can overwrite the default.
init: str
Initialization method to use for auto-assigned NUTS samplers. See pm.init_nuts for a list of all options. This argument is ignored when manually passing the NUTS step method. Only applicable to the PyMC NUTS sampler.
jitter_max_retries: int
Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform jitter that yields a finite probability. This applies to jitter+adapt_diag and jitter+adapt_full init methods.
n_init: int
Number of iterations of the initializer. Only works for ‘ADVI’ init methods.
trace: backend, optional
A backend instance or None. If None, a MultiTrace object with underlying NDArray trace objects is used. If trace is a ZarrTrace instance, the drawn samples will be written onto the desired storage while sampling is ongoing. This means that sampling runs that, for whatever reason, die in the middle of their execution will still have written their partial results onto the storage. If the storage persists on disk, these results should be available even after a server crash. See ZarrTrace for more information.
discard_tuned_samples: bool
Whether to discard posterior samples of the tune interval.
compute_convergence_checks: bool, default=True
Whether to compute sampler statistics like Gelman-Rubin and effective_n.
keep_warning_stat: bool
If True, the “warning” stat emitted by, for example, HMC samplers will be kept in the returned idata.sample_stats group. This leads to the idata not supporting .to_netcdf() or .to_zarr(), and should only be set to True if you intend to use the “warning” objects right away. Defaults to False, such that pm.drop_warning_stat is applied automatically, making the InferenceData compatible with saving.
return_inferencedata: bool
Whether to return the trace as an arviz.InferenceData (True) object or a MultiTrace (False). Defaults to True.
idata_kwargs: dict, optional
Keyword arguments for pymc.to_inference_data().
nuts_sampler_kwargs: dict, optional
Keyword arguments for the sampling library that implements NUTS. Only used when an external sampler is specified via the nuts_sampler kwarg.
callback: function, default=None
A function which gets called for every sample from the trace of a chain. The function is called with the trace and the current draw, and the trace will contain all samples for a single chain. The draw.chain attribute can be used to determine which of the active chains the sample is drawn from. Sampling can be interrupted by throwing a KeyboardInterrupt in the callback.
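A callback along these lines might look like the sketch below. It assumes the callback receives its two inputs under the names trace and draw, and it relies only on the draw.chain attribute mentioned above. The factory make_callback and its max_calls limit are hypothetical, and the SimpleNamespace objects are stand-ins so the function can be exercised without running a sampler.

```python
from collections import Counter
from types import SimpleNamespace


def make_callback(max_calls):
    """Hypothetical factory: count calls per chain, stop after max_calls."""
    counts = Counter()

    def callback(trace, draw):
        # draw.chain identifies the chain this sample came from.
        counts[draw.chain] += 1
        if sum(counts.values()) >= max_calls:
            # Raising KeyboardInterrupt asks the sampler to stop.
            raise KeyboardInterrupt

    return callback, counts


# Dry run with stand-in draw objects; no sampler is involved here.
callback, counts = make_callback(max_calls=1000)
for chain in (0, 0, 1):
    callback(trace=None, draw=SimpleNamespace(chain=chain))
print(dict(counts))
```

In real use the function would be passed as pm.sample(callback=callback), and counts could be inspected after sampling ends.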
mp_ctx: multiprocessing.context.BaseContext
A multiprocessing context for parallel sampling. See multiprocessing documentation for details.
model: Model (optional if in with context)
Model to sample from. The model needs to have free random variables.
compile_kwargs: dict, optional
Dictionary with keyword arguments to pass to the functions compiled by the step methods.
Returns:
trace: pymc.backends.base.MultiTrace | pymc.backends.zarr.ZarrTrace | arviz.InferenceData
A MultiTrace, InferenceData, or ZarrTrace object that contains the samples. A ZarrTrace is only returned if the supplied trace argument is a ZarrTrace instance. Refer to ZarrTrace for the benefits this backend provides.
Notes
Optional keyword arguments can be passed to sample to be delivered to the step_methods used during sampling.
For example:
- target_accept to NUTS: nuts={'target_accept': 0.9}
- transit_p to BinaryGibbsMetropolis: binary_gibbs_metropolis={'transit_p': .7}
Note that available step names are: nuts, hmc, metropolis, binary_metropolis, binary_gibbs_metropolis, categorical_gibbs_metropolis, DEMetropolis, DEMetropolisZ, slice
The NUTS step method has several options including:
- target_accept : float in [0, 1]. The step size is tuned such that we approximate this acceptance rate. Higher values like 0.9 or 0.95 often work better for problematic posteriors. This argument can be passed directly to sample.
- max_treedepth : The maximum depth of the trajectory tree
- step_scale : float, default 0.25. The initial guess for the step size scaled down by \(1/n^{1/4}\), where n is the dimensionality of the parameter space.
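The step_scale rule is simple arithmetic, sketched below. The helper initial_step_size is hypothetical and only evaluates the formula \(\text{step\_scale} / n^{1/4}\) stated above; it does not model how tuning subsequently adapts the step size.

```python
def initial_step_size(n_dims, step_scale=0.25):
    """Hypothetical helper: step_scale scaled down by 1 / n**(1/4)."""
    return step_scale / n_dims ** 0.25


# The initial guess shrinks slowly as the parameter space grows.
for n_dims in (1, 16, 10_000):
    print(n_dims, initial_step_size(n_dims))
```

For a 16-dimensional model the fourth root of n is 2, so the default initial guess is 0.25 / 2 = 0.125.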
Alternatively, if you manually declare the step_methods within the step kwarg, then you can address the step_method kwargs directly. For example, for a CompoundStep comprising NUTS and BinaryGibbsMetropolis, you could send

    step = [
        pm.NUTS([freeRV1, freeRV2], target_accept=0.9),
        pm.BinaryGibbsMetropolis([freeRV3], transit_p=0.7),
    ]

You can find a full list of arguments in the docstring of the step methods.
Examples
In [1]: import arviz as az
   ...: import pymc as pm
   ...: n = 100
   ...: h = 61
   ...: alpha = 2
   ...: beta = 2

In [2]: with pm.Model() as model:  # context management
   ...:     p = pm.Beta("p", alpha=alpha, beta=beta)
   ...:     y = pm.Binomial("y", n=n, p=p, observed=h)
   ...:     idata = pm.sample()

In [3]: az.summary(idata, kind="stats")
Out[3]:
    mean     sd  hdi_3%  hdi_97%
p  0.609  0.047   0.528    0.699
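Because a Beta prior is conjugate to the Binomial likelihood, the sampled summary above can be cross-checked analytically: the posterior is Beta(alpha + h, beta + n - h). The sketch below uses only the standard library and the values from the example; the variable names post_a and post_b are illustrative.

```python
from math import sqrt

n, h = 100, 61
alpha, beta = 2, 2

# Conjugate update: Beta(alpha + successes, beta + failures).
post_a = alpha + h
post_b = beta + (n - h)

# Mean and standard deviation of a Beta(post_a, post_b) distribution.
total = post_a + post_b
post_mean = post_a / total
post_sd = sqrt(post_a * post_b / (total**2 * (total + 1)))

print(round(post_mean, 3), round(post_sd, 3))
```

The analytic mean (63/104, about 0.606) and standard deviation (about 0.048) agree with the sampled values up to Monte Carlo error.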