tfp.mcmc.MetropolisHastings | TensorFlow Probability

Runs one step of the Metropolis-Hastings algorithm.

Inherits From: TransitionKernel

tfp.mcmc.MetropolisHastings(
    inner_kernel, name=None
)

The Metropolis-Hastings algorithm is a Markov chain Monte Carlo (MCMC) technique that uses a proposal distribution to generate a Markov chain whose stationary distribution is the target distribution.

The Metropolis-Hastings log acceptance ratio, from which the acceptance probability min(1, exp(log_accept_ratio)) follows, is computed as:

log_accept_ratio = (current_kernel_results.target_log_prob
                    - previous_kernel_results.target_log_prob
                    + current_kernel_results.log_acceptance_correction)

If current_kernel_results.log_acceptance_correction does not exist, it is presumed to be 0 (i.e., the proposal distribution is presumed symmetric).
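The accept/reject rule implied by this ratio can be sketched in plain Python (standard library only; this is an illustration, not TFP's implementation, and the function names are hypothetical). A proposal is accepted when log(u) <= log_accept_ratio for u uniform on (0, 1], which happens with probability min(1, exp(log_accept_ratio)).

```python
import math
import random


def log_accept_ratio(proposed_target_log_prob, current_target_log_prob,
                     log_acceptance_correction=0.0):
    """Mirrors the formula above; the correction defaults to 0 (symmetric proposal)."""
    return (proposed_target_log_prob
            - current_target_log_prob
            + log_acceptance_correction)


def mh_accept(log_ratio, rng):
    """Returns True with probability min(1, exp(log_ratio))."""
    # 1 - rng.random() lies in (0, 1], so the log is always finite.
    u = 1.0 - rng.random()
    return math.log(u) <= log_ratio


# Usage: a proposal with higher target log-probability is always accepted.
rng = random.Random(0)
ratio = log_accept_ratio(-1.0, -1.5)  # no correction supplied, so it is 0
accepted = mh_accept(ratio, rng)
```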

The most common use case for log_acceptance_correction is to correct for an asymmetric proposal in the Metropolis-Hastings acceptance probability, i.e.,

accept_prob(x' | x) = min(1, [p(x') / p(x)] * [g(x | x') / g(x' | x)])

where,
  p  represents the target distribution,
  g  represents the proposal (conditional) distribution,
  x' is the proposed state, and
  x  is the current state.

The log of the proposal ratio, log g(x | x') - log g(x' | x), is the log_acceptance_correction.
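As an illustration of a non-zero correction, consider a multiplicative random-walk proposal x' = x * exp(s * eps), eps ~ N(0, 1), for a positive state. Then g(x' | x) is a log-normal density, the exponential factors cancel in the ratio, and log g(x | x') - log g(x' | x) = log x' - log x. The plain-Python sketch below checks that identity numerically and uses it in a small sampler; the function names, the Gamma(shape=3, rate=1) target, and the scale s = 0.5 are illustrative assumptions.

```python
import math
import random


def lognormal_log_density(y, mu, sigma):
    """Log density of LogNormal(mu, sigma) evaluated at y > 0."""
    z = (math.log(y) - mu) / sigma
    return -math.log(y * sigma * math.sqrt(2.0 * math.pi)) - 0.5 * z * z


def log_acceptance_correction(x, x_new, s):
    """log g(x | x_new) - log g(x_new | x) for the multiplicative proposal."""
    return (lognormal_log_density(x, math.log(x_new), s)
            - lognormal_log_density(x_new, math.log(x), s))


def target_log_prob(x):
    """Unnormalized log density of Gamma(shape=3, rate=1): 2*log(x) - x."""
    return 2.0 * math.log(x) - x


def sample(num_samples, s=0.5, x=1.0, seed=0):
    """Metropolis-Hastings with the asymmetric proposal and its correction."""
    rng = random.Random(seed)
    samples = []
    for _ in range(num_samples):
        x_new = x * math.exp(s * rng.gauss(0.0, 1.0))
        log_ratio = (target_log_prob(x_new) - target_log_prob(x)
                     + log_acceptance_correction(x, x_new, s))
        if math.log(1.0 - rng.random()) <= log_ratio:
            x = x_new  # accept; otherwise keep the current state
        samples.append(x)
    return samples


draws = sample(20000)
```

Omitting the correction here would bias the chain, because the proposal moves up and down in x with unequal probability densities.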

The log_acceptance_correction need not correspond to a ratio of proposal densities; e.g., in Hamiltonian Monte Carlo it accounts for the change in kinetic energy of the auxiliary momentum variables.
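For the Hamiltonian Monte Carlo case, a minimal plain-Python sketch for a scalar state, assuming a unit-mass kinetic energy K(p) = p**2 / 2 and the target log density -x - x**2 used in the Examples section, looks as follows. Taking the correction as K(p) - K(p') makes the log acceptance ratio equal to the negative change in total energy H = -log p(x) + K(p). This is an illustration only, not TFP's UncalibratedHamiltonianMonteCarlo; all names are hypothetical.

```python
def target_log_prob(x):
    # Target from the Examples section: log p(x) = -x - x**2 (up to a constant).
    return -x - x ** 2


def grad_target_log_prob(x):
    return -1.0 - 2.0 * x


def kinetic_energy(p):
    # Unit-mass kinetic energy (an assumption of this sketch).
    return 0.5 * p * p


def leapfrog(x, p, step_size, num_leapfrog_steps):
    """Leapfrog integrator for H(x, p) = -log p(x) + p**2 / 2."""
    p = p + 0.5 * step_size * grad_target_log_prob(x)
    for i in range(num_leapfrog_steps):
        x = x + step_size * p
        if i < num_leapfrog_steps - 1:
            p = p + step_size * grad_target_log_prob(x)
    p = p + 0.5 * step_size * grad_target_log_prob(x)
    return x, p


def hmc_log_accept_ratio(x, p, step_size=0.1, num_leapfrog_steps=3):
    """Returns (log_accept_ratio, proposed_x, proposed_p)."""
    x_new, p_new = leapfrog(x, p, step_size, num_leapfrog_steps)
    # HMC-style correction: current minus proposed kinetic energy.
    log_acceptance_correction = kinetic_energy(p) - kinetic_energy(p_new)
    ratio = (target_log_prob(x_new) - target_log_prob(x)
             + log_acceptance_correction)
    return ratio, x_new, p_new
```

Because leapfrog nearly conserves H for small step sizes, the ratio stays close to 0 and most proposals are accepted.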

Examples

import tensorflow_probability as tfp
hmc = tfp.mcmc.MetropolisHastings(
    tfp.mcmc.UncalibratedHamiltonianMonteCarlo(
        target_log_prob_fn=lambda x: -x - x**2,
        step_size=0.1,
        num_leapfrog_steps=3))
# ==> functionally equivalent to:
# hmc = tfp.mcmc.HamiltonianMonteCarlo(
#     target_log_prob_fn=lambda x: -x - x**2,
#     step_size=0.1,
#     num_leapfrog_steps=3)
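The wrapper pattern in this example can be sketched without TensorFlow: an uncalibrated inner kernel only proposes a state and reports target_log_prob and log_acceptance_correction, and the wrapper applies the accept/reject rule. The class names, the result fields, and the use of Python's random module below are hypothetical simplifications, not TFP's API; a Gaussian random walk stands in for the HMC inner kernel.

```python
import collections
import math
import random

# Hypothetical result containers, loosely modeled on the members named above.
InnerResults = collections.namedtuple(
    "InnerResults", ["target_log_prob", "log_acceptance_correction"])
MHResults = collections.namedtuple(
    "MHResults", ["accepted_results", "is_accepted"])


class UncalibratedRandomWalk:
    """Proposes x' = x + scale * eps; never accepts or rejects on its own."""

    def __init__(self, target_log_prob_fn, scale=1.0):
        self.target_log_prob_fn = target_log_prob_fn
        self.scale = scale

    def bootstrap_results(self, state):
        return InnerResults(self.target_log_prob_fn(state), 0.0)

    def one_step(self, state, previous_results, rng):
        proposed = state + self.scale * rng.gauss(0.0, 1.0)
        # Symmetric proposal, so the correction is 0.
        return proposed, InnerResults(self.target_log_prob_fn(proposed), 0.0)


class MetropolisHastingsSketch:
    """Wraps an uncalibrated kernel and adds the accept/reject step."""

    def __init__(self, inner_kernel):
        self.inner_kernel = inner_kernel

    def bootstrap_results(self, state):
        return MHResults(self.inner_kernel.bootstrap_results(state), True)

    def one_step(self, state, previous_results, rng):
        current = previous_results.accepted_results
        proposed_state, proposed = self.inner_kernel.one_step(state, current, rng)
        log_accept_ratio = (proposed.target_log_prob
                            - current.target_log_prob
                            + proposed.log_acceptance_correction)
        if math.log(1.0 - rng.random()) <= log_accept_ratio:
            return proposed_state, MHResults(proposed, True)
        return state, MHResults(current, False)


# Usage: thread the kernel results through repeated one_step calls.
kernel = MetropolisHastingsSketch(
    UncalibratedRandomWalk(lambda x: -x - x ** 2, scale=0.5))
rng = random.Random(0)
state = 0.0
results = kernel.bootstrap_results(state)
draws = []
for _ in range(20000):
    state, results = kernel.one_step(state, results, rng)
    draws.append(state)
```

The target -x - x**2 is a normal density with mean -0.5 and variance 0.5, which gives a simple check on the draws.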
Args
inner_kernel TransitionKernel-like object whose kernel_results is a collections.namedtuple containing a target_log_prob member and, optionally, a log_acceptance_correction member.
name Python str name prefixed to Ops created by this function. Default value: None (i.e., "mh_kernel").
Attributes
experimental_shard_axis_names The shard axis names for members of the state.
inner_kernel
is_calibrated Returns True if the Markov chain converges to the specified distribution. TransitionKernels which are "uncalibrated" are often calibrated by composing them with the tfp.mcmc.MetropolisHastings TransitionKernel.
name
parameters Return dict of __init__ arguments and their values.

Methods

bootstrap_results


bootstrap_results(
    init_state
)

Returns an object with the same type as returned by one_step.

Args
init_state Tensor or Python list of Tensors representing the initial state(s) of the Markov chain(s).
Returns
kernel_results A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function.
Raises
ValueError if the inner_kernel results do not contain the member "target_log_prob".

copy


copy(
    **override_parameter_kwargs
)

Non-destructively creates a deep copy of the kernel.

Args
**override_parameter_kwargs Python String/value dictionary of initialization arguments to override with new values.
Returns
new_kernel TransitionKernel object of the same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs).
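The override rule for copy is ordinary dictionary merging: any key present in override_parameter_kwargs wins. A minimal sketch with a hypothetical kernel class that records its __init__ arguments in a parameters dict follows; the class and its constructor are illustrative only and not part of TFP.

```python
class ToyKernel:
    """Hypothetical kernel that records its constructor arguments."""

    def __init__(self, step_size, num_leapfrog_steps, name=None):
        self._parameters = dict(step_size=step_size,
                                num_leapfrog_steps=num_leapfrog_steps,
                                name=name)

    @property
    def parameters(self):
        # Return a fresh dict so callers cannot mutate the kernel's state.
        return dict(self._parameters)

    def copy(self, **override_parameter_kwargs):
        # Keys in override_parameter_kwargs replace those in self.parameters.
        return type(self)(**dict(self.parameters, **override_parameter_kwargs))


base = ToyKernel(step_size=0.1, num_leapfrog_steps=3)
smaller = base.copy(step_size=0.05)  # base itself is left unchanged
```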

experimental_with_shard_axes


experimental_with_shard_axes(
    shard_axis_names
)

Returns a copy of the kernel with the provided shard axis names.

Args
shard_axis_names a structure of strings indicating the shard axis names for each component of this kernel's state.
Returns
A copy of the current kernel with the shard axis information.

one_step


one_step(
    current_state, previous_kernel_results, seed=None
)

Takes one step of the TransitionKernel.

Args
current_state Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s).
previous_kernel_results A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within the previous call to this function (or as returned by bootstrap_results).
seed PRNG seed; see tfp.random.sanitize_seed for details.
Returns
next_state Tensor or Python list of Tensors representing the next state(s) of the Markov chain(s).
kernel_results A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function.
Raises
ValueError if the inner_kernel results do not contain the member "target_log_prob".