tfp.mcmc.TransformedTransitionKernel  |  TensorFlow Probability

TransformedTransitionKernel applies a bijector to the MCMC's state space.

Inherits From: TransitionKernel

tfp.mcmc.TransformedTransitionKernel(
    inner_kernel, bijector, name=None
)

Used in the tutorials:

- Multilevel Modeling Primer in TensorFlow Probability
- TensorFlow Probability Case Study: Covariance Estimation
- A Tour of TensorFlow Probability
- Bayesian Gaussian Mixture Model and Hamiltonian MCMC
- Bayesian Switchpoint Analysis

The TransformedTransitionKernel enables fitting a tfp.bijectors.Bijector which serves to decorrelate the Markov chain Monte Carlo (MCMC) event dimensions, thus making the chain mix faster. This is particularly useful when the geometry of the target distribution is unfavorable; in such cases it may take many evaluations of the target_log_prob_fn for the chain to mix between faraway states.

The idea of training an affine function to decorrelate chain event dims was presented in [Parno and Marzouk (2014)][1]. Used in conjunction with the HamiltonianMonteCarlo TransitionKernel, the [Parno and Marzouk (2014)][1] idea is an instance of Riemannian manifold HMC [(Girolami and Calderhead, 2011)][2].

The TransformedTransitionKernel enables arbitrary bijective transformations of arbitrary TransitionKernels, e.g., one could use bijectors tfp.bijectors.ScaleMatvecTriL, tfp.bijectors.RealNVP, etc. with transition kernels tfp.mcmc.HamiltonianMonteCarlo, tfp.mcmc.RandomWalkMetropolis, etc.
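A common use is letting a kernel that proposes in unconstrained space target a distribution with constrained support. Below is a minimal sketch; the Gamma target and the Softplus choice are illustrative assumptions, not part of this API:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Target with positive support; the chain itself walks in unconstrained
# space and Softplus.forward maps each proposal back onto (0, inf).
target = tfd.Gamma(concentration=2., rate=3.)

kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.RandomWalkMetropolis(
        target_log_prob_fn=target.log_prob),
    bijector=tfb.Softplus())

states = tfp.mcmc.sample_chain(
    num_results=500,
    current_state=tf.ones([]),
    kernel=kernel,
    num_burnin_steps=200,
    trace_fn=None)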

Transforming nested kernels

TransformedTransitionKernel can operate on multiply nested kernels, as in the following example:

tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.SimpleStepSizeAdaptation(
        inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
            ...  # doesn't matter
        ),
        num_adaptation_steps=9),
    bijector=tfb.Identity())

Upon construction, TransformedTransitionKernel searches the given inner_kernel and the "stack" of nested kernels in any inner_kernel fields thereof until it finds one with a field called target_log_prob_fn, and replaces this with the transformed function. If no inner_kernel has such a target log prob, a ValueError is raised.

Mathematical Details

TransformedTransitionKernel enables Markov chains which operate in "unconstrained space." Since we interpret the bijector as mapping "unconstrained space" to "user space", this means that the MCMC transformed target_log_prob is:

target_log_prob(bij.forward(x)) + bij.forward_log_det_jacobian(x)

Recall that tfp.distributions.TransformedDistribution uses the inverse to compute its log_prob. Despite this difference, the use of forward in TransformedTransitionKernel is perfectly consistent with TransformedDistribution, following the TFP convention of "sampling" being what defines semantics. The apparent difference is because TransformedDistribution.log_prob is derived from a user-provided distribution while in TransformedTransitionKernel samples are derived from target_log_prob_fn. That is, in TransformedDistribution we do:

x ~ NoiseDistribution()
y = bij.forward(x)
log_prob_y = NoiseDistribution().log_prob(bij.inverse(y))
             + bij.inverse_log_det_jacobian(y)

yet in TransformedTransitionKernel we do:

x ~ MCMC()
y = bij.forward(x)
log_prob_y = log_prob(y) + bij.forward_log_det_jacobian(x)

In other words (and in general), tfp.mcmc is derived from a log_prob, which is what induces the seeming change of direction convention. Aside from TFP convention, having Bijectors adhere to "sample first" semantics is important because it mitigates the pervasive necessity of tfp.bijectors.Invert in user code.
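As a concrete check of the formula above, here is a small sketch. The Exp/LogNormal pairing is an assumption chosen because LogNormal is exactly an Exp-transformed Normal, so the transformed target density in unconstrained space should reduce to the standard normal density:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

bij = tfb.Exp()
target = tfd.LogNormal(loc=0., scale=1.)  # support (0, inf)

x = tf.constant(0.3)  # a point in unconstrained space

# The density the chain actually samples from in unconstrained space:
transformed_log_prob = (target.log_prob(bij.forward(x))
                        + bij.forward_log_det_jacobian(x, event_ndims=0))

# Since LogNormal = Exp(Normal), this matches the Normal density at x.
tf.debugging.assert_near(
    transformed_log_prob, tfd.Normal(0., 1.).log_prob(x))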

Examples

RealNVP + HamiltonianMonteCarlo
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

def make_likelihood(true_variances):
  return tfd.MultivariateNormalDiag(
      scale_diag=tf.sqrt(true_variances))

dims = 10
dtype = np.float32
true_variances = tf.linspace(dtype(1), dtype(3), dims)
likelihood = make_likelihood(true_variances)

realnvp_hmc = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn=likelihood.log_prob,
      step_size=0.5,
      num_leapfrog_steps=2),
    bijector=tfb.RealNVP(
      num_masked=2,
      shift_and_log_scale_fn=tfb.real_nvp_default_template(
          hidden_layers=[512, 512])))

states, kernel_results = tfp.mcmc.sample_chain(
    num_results=1000,
    current_state=tf.zeros(dims),
    kernel=realnvp_hmc,
    num_burnin_steps=500)

# Compute sample stats.
sample_mean = tf.reduce_mean(states, axis=0)
sample_var = tf.reduce_mean(
    tf.math.squared_difference(states, sample_mean),
    axis=0)

References

[1]: Matthew Parno and Youssef Marzouk. Transport map accelerated Markov chain Monte Carlo. arXiv preprint arXiv:1412.5492, 2014. https://arxiv.org/abs/1412.5492

[2]: Mark Girolami and Ben Calderhead. Riemann manifold Langevin and Hamiltonian Monte Carlo methods. Journal of the Royal Statistical Society: Series B, 2011. https://doi.org/10.1111/j.1467-9868.2010.00765.x

Args
inner_kernel TransitionKernel-like object that either has a target_log_prob_fn argument, or wraps around another inner_kernel with said argument.
bijector tfp.bijectors.Bijector or list of tfp.bijectors.Bijectors (one per state part; see the sketch after this table). These bijectors use forward to map the inner_kernel state space to the state expected by inner_kernel.target_log_prob_fn.
name Python str name prefixed to Ops created by this function. Default value: None (i.e., "transformed_kernel").
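As referenced in the bijector row above, a list of bijectors pairs one bijector with each part of a multi-part state. A minimal sketch; the two-part Normal/Gamma target is an illustrative assumption:

import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

def target_log_prob_fn(loc, scale):
  # Two-part state: `loc` is unconstrained, `scale` must be positive.
  return (tfd.Normal(0., 1.).log_prob(loc)
          + tfd.Gamma(2., 2.).log_prob(scale))

kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob_fn,
        step_size=0.1,
        num_leapfrog_steps=3),
    # One bijector per state part: `loc` passes through unchanged,
    # `scale` is sampled in log space via Exp.forward.
    bijector=[tfb.Identity(), tfb.Exp()])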
Attributes
bijector
experimental_shard_axis_names The shard axis names for members of the state.
inner_kernel
is_calibrated Returns True if Markov chain converges to specified distribution. TransitionKernels which are "uncalibrated" are often calibrated by composing them with the tfp.mcmc.MetropolisHastings TransitionKernel (see the sketch after this table).
name
parameters Return dict of __init__ arguments and their values.
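As noted under is_calibrated, an uncalibrated inner kernel can be calibrated before being transformed. A sketch, assuming an illustrative target with positive support:

import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

target = tfd.Gamma(concentration=2., rate=3.)  # illustrative target

# UncalibratedLangevin alone is not calibrated; wrapping it in
# MetropolisHastings yields a calibrated kernel, which is then
# transformed onto the positive reals.
kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.MetropolisHastings(
        inner_kernel=tfp.mcmc.UncalibratedLangevin(
            target_log_prob_fn=target.log_prob,
            step_size=0.1)),
    bijector=tfb.Softplus())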

Methods

bootstrap_results

View source

bootstrap_results(
    init_state=None, transformed_init_state=None
)

Returns an object with the same type as returned by one_step.

Unlike other TransitionKernels, TransformedTransitionKernel.bootstrap_results has the option of initializing the TransformedTransitionKernelResults from either an initial state, e.g., requiring computing bijector.inverse(init_state), or directly from transformed_init_state, i.e., a Tensor or list of Tensors which is interpreted as the bijector.inverse-transformed state.

Args
init_state Tensor or Python list of Tensors representing the state(s) of the Markov chain(s). Must specify init_state or transformed_init_state but not both.
transformed_init_state Tensor or Python list of Tensors representing the state(s) of the Markov chain(s). Must specify init_state or transformed_init_state but not both.
Returns
kernel_results A (possibly nested) tuple, namedtuple or list of Tensors representing internal calculations made within this function.
Raises
ValueError if none of the nested inner_kernel results contain the member "target_log_prob".

Examples

To use transformed_init_state in the context of tfp.mcmc.sample_chain, you need to explicitly pass the previous_kernel_results, e.g.,

transformed_kernel = tfp.mcmc.TransformedTransitionKernel(...)
init_state = ...        # Doesn't matter.
transformed_init_state = ... # Does matter.
results = tfp.mcmc.sample_chain(
    num_results=...,
    current_state=init_state,
    previous_kernel_results=transformed_kernel.bootstrap_results(
        transformed_init_state=transformed_init_state),
    trace_fn=None,
    kernel=transformed_kernel)

copy

View source

copy(
    **override_parameter_kwargs
)

Non-destructively creates a deep copy of the kernel.

Args
**override_parameter_kwargs Python String/value dictionary of initialization arguments to override with new values.
Returns
new_kernel TransitionKernel object of same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs).
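For instance, one might derive a variant of an existing kernel by overriding a single constructor argument. A sketch; the names and the toy target are illustrative:

import tensorflow_probability as tfp
tfb = tfp.bijectors

base_kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.RandomWalkMetropolis(
        target_log_prob_fn=lambda x: -x**2.),
    bijector=tfb.Identity())

# Same inner kernel, different bijector; `base_kernel` is untouched.
softplus_kernel = base_kernel.copy(bijector=tfb.Softplus())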

experimental_with_shard_axes

View source

experimental_with_shard_axes(
    shard_axis_names
)

Returns a copy of the kernel with the provided shard axis names.

Args
shard_axis_names a structure of strings indicating the shard axis names for each component of this kernel's state.
Returns
A copy of the current kernel with the shard axis information.

one_step

View source

one_step(
    current_state, previous_kernel_results, seed=None
)

Runs one iteration of the Transformed Kernel.

Args
current_state Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s), after application of bijector.forward. The first r dimensions index independent chains, r = tf.rank(target_log_prob_fn(*current_state)). The inner_kernel.one_step does not actually use current_state; rather it takes as input previous_kernel_results.transformed_state (because TransformedTransitionKernel creates a copy of the input inner_kernel with a modified target_log_prob_fn which internally applies the bijector.forward).
previous_kernel_results collections.namedtuple containing Tensors representing values from previous calls to this function (or from the bootstrap_results function).
seed PRNG seed; see tfp.random.sanitize_seed for details.
Returns
next_state Tensor or Python list of Tensors representing the state(s) of the Markov chain(s) after taking exactly one step. Has same type and shape as current_state.
kernel_results collections.namedtuple of internal calculations used to advance the chain.
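To illustrate, here is a sketch of driving the kernel by hand rather than through tfp.mcmc.sample_chain; it reuses the realnvp_hmc kernel and dims from the example above:

import tensorflow as tf
import tensorflow_probability as tfp

state = tf.zeros(dims)
results = realnvp_hmc.bootstrap_results(state)

# Advance the chain three steps, threading kernel results through.
for step_seed in tfp.random.split_seed([1, 2], n=3):
  state, results = realnvp_hmc.one_step(state, results, seed=step_seed)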