jax.random module — JAX documentation

jax.random module

Utilities for pseudo-random number generation.

The jax.random package provides a number of routines for deterministic generation of sequences of pseudorandom numbers.

Basic usage

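```python
import jax

seed = 1701
num_steps = 100
key = jax.random.key(seed)
for i in range(num_steps):
  # Split off a fresh subkey for this step; never reuse a consumed key.
  key, subkey = jax.random.split(key)
  # compiled_update, params and batches are defined elsewhere by the user.
  params = compiled_update(subkey, params, next(batches))
```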
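For a version of the loop above that runs end to end, the undefined pieces can be filled with toy stand-ins. In this sketch `compiled_update` is a hypothetical noisy gradient step on a quadratic loss and `batches` is a hypothetical stream that repeats one target batch; neither comes from the JAX API.

```python
import itertools

import jax
import jax.numpy as jnp

# Hypothetical stand-in for the user's `compiled_update`: one noisy
# gradient-descent step on the toy loss mean((params - batch) ** 2).
@jax.jit
def compiled_update(key, params, batch):
  grad = jax.grad(lambda p: jnp.mean((p - batch) ** 2))(params)
  noise = 0.01 * jax.random.normal(key, params.shape)  # consumes the subkey
  return params - 0.1 * grad + noise

# Hypothetical data stream: the same target batch, repeated forever.
batches = itertools.repeat(jnp.ones(3))
params = jnp.zeros(3)

seed = 1701
num_steps = 100
key = jax.random.key(seed)
for i in range(num_steps):
  key, subkey = jax.random.split(key)
  params = compiled_update(subkey, params, next(batches))
```

Because the key is threaded explicitly, rerunning the script with the same seed reproduces the same sequence of `params`.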

PRNG keys

Unlike the stateful pseudorandom number generators (PRNGs) that users of NumPy and SciPy may be accustomed to, JAX random functions all require an explicit PRNG state to be passed as the first argument. The random state is described by a special array element type that we call a key, usually generated by the jax.random.key() function:

```python
>>> from jax import random
>>> key = random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
```

This key can then be used in any of JAX’s random number generation routines:

```python
>>> random.uniform(key)
Array(0.947667, dtype=float32)
```

Note that using a key does not modify it, so reusing the same key will lead to the same result:

```python
>>> random.uniform(key)
Array(0.947667, dtype=float32)
```
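The sketch below makes the "using a key does not modify it" point checkable: it samples twice with one key and compares the key's raw bits (via jax.random.key_data) before and after.

```python
import jax.numpy as jnp
from jax import random

key = random.key(0)
bits_before = random.key_data(key)

x1 = random.uniform(key)
x2 = random.uniform(key)  # same key, same call: reproduces x1 exactly

# Consuming a key does not mutate it; its underlying bits are unchanged.
bits_after = random.key_data(key)
```

Silently reusing a key therefore yields correlated (here, identical) samples, which is why the next step is to split.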

If you need a new random number, you can use jax.random.split() to generate new subkeys:

```python
>>> key, subkey = random.split(key)
>>> random.uniform(subkey)
Array(0.00729382, dtype=float32)
```
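jax.random.split also accepts a `num` argument, which returns several new keys stacked along a leading axis. A minimal sketch drawing one value per subkey:

```python
from jax import random

key = random.key(0)
# split(key, num) returns `num` new keys stacked along a leading axis.
subkeys = random.split(key, 3)
print(subkeys.shape)  # (3,)

# Each subkey drives its own draw.
draws = [float(random.uniform(k)) for k in subkeys]
```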

Note

Typed key arrays, with element types such as key<fry> above, were introduced in JAX v0.4.16. Before then, keys were conventionally represented in uint32 arrays, whose final dimension held the key's bit-level representation.

Both forms of key array can still be created and used with the jax.random module. New-style typed key arrays are made with jax.random.key(). Legacy uint32 key arrays are made with jax.random.PRNGKey().

To convert between the two, use jax.random.key_data() and jax.random.wrap_key_data(). The legacy key format may be needed when interfacing with systems outside of JAX (e.g. exporting arrays to a serializable format), or when passing keys to JAX-based libraries that assume the legacy format.
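A round trip between the two formats can be sketched as follows. It assumes the default threefry2x32 implementation is active, in which case the raw key data is a length-2 uint32 array.

```python
import jax
import jax.numpy as jnp

typed_key = jax.random.key(42)

# Extract the raw uint32 bits underlying the typed key.
raw = jax.random.key_data(typed_key)
print(raw.dtype, raw.shape)  # uint32 (2,) under the default threefry2x32 impl

# Wrap the bits back into a typed key (uses the default implementation).
rewrapped = jax.random.wrap_key_data(raw)

# A legacy key made from the same seed is a plain uint32 array with the same bits.
legacy_key = jax.random.PRNGKey(42)
```

Since both key forms carry the same bits, passing either one to a sampler should produce the same value.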

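Otherwise, typed keys are recommended. Caveats of legacy keys relative to typed ones include:

  1. They have an extra trailing dimension.
  2. They have a numeric dtype (uint32), allowing for operations that are typically not meant to be carried out over keys, such as integer arithmetic.
  3. They do not carry information about the RNG implementation. When legacy keys are passed to jax.random functions, a global configuration setting determines the RNG implementation (see "Advanced RNG configuration" below).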

To learn more about this upgrade and the design of key types, see JEP 9263.
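The shape and dtype differences between the two key formats can be inspected directly. This sketch assumes the default threefry2x32 implementation and uses `jax.dtypes.prng_key` to test for a key dtype; the attempted `typed_key + 1` is expected to raise, so it is caught broadly rather than assuming a specific exception class.

```python
import jax
import jax.numpy as jnp

typed_key = jax.random.key(0)
legacy_key = jax.random.PRNGKey(0)

# Typed keys are scalars of a key dtype; legacy keys add a trailing bits axis.
print(typed_key.shape, legacy_key.shape)  # () (2,)

# Only the typed key is recognized as a key by its dtype.
is_typed_key = jnp.issubdtype(typed_key.dtype, jax.dtypes.prng_key)
is_legacy_key = jnp.issubdtype(legacy_key.dtype, jax.dtypes.prng_key)

# Legacy keys are plain uint32 arrays, so arithmetic on them is silently allowed.
shifted_legacy = legacy_key + 1

# JAX rejects arithmetic on key dtypes.
try:
  typed_key + 1
  typed_arithmetic_allowed = True
except Exception:
  typed_arithmetic_allowed = False
```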

Advanced

Design and background

TLDR: JAX PRNG = Threefry counter PRNG + a functional array-oriented splitting model

See docs/jep/263-prng.md for more details.
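Alongside split, the functional model includes jax.random.fold_in (listed in the API reference below), which deterministically mixes an integer into a key. A minimal sketch deriving per-step keys from one base key without threading state through a loop; the step indices are illustrative.

```python
from jax import random

base_key = random.key(0)

# fold_in(key, data) derives a new key from (key, data); no state is mutated.
step_keys = [random.fold_in(base_key, step) for step in range(3)]
values = [float(random.uniform(k)) for k in step_keys]

# Recomputing with the same (key, data) pair reproduces the same key and value.
again = float(random.uniform(random.fold_in(base_key, 1)))
```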

To summarize, among other requirements, the JAX PRNG aims to:

  1. ensure reproducibility,
  2. parallelize well, both in terms of vectorization (generating array values) and multi-replica, multi-core computation. In particular, it should not impose sequencing constraints between random function calls.
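Because each sampler call depends only on its key, a batch of keys can be processed with jax.vmap instead of a sequential loop. The sketch below compares the vectorized result with a per-key loop; with the default Threefry implementation the two are expected to agree.

```python
import jax
import jax.numpy as jnp
from jax import random

keys = random.split(random.key(0), 4)

# Vectorize a sampler over a batch of keys: one row of 3 normals per key.
batched = jax.vmap(lambda k: random.normal(k, (3,)))(keys)

# Equivalent per-key loop, for comparison.
looped = jnp.stack([random.normal(k, (3,)) for k in keys])
```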

Advanced RNG configuration

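JAX provides several PRNG implementations. A specific one can be selected with the optional impl keyword argument to jax.random.key. When no impl option is passed to the key constructor, the implementation is determined by the global jax_default_prng_impl configuration flag. The string names of available implementations are:

  1. "threefry2x32" (default): a counter-based PRNG based on a variant of the Threefry hash function.
  2. "rbg" and "unsafe_rbg": PRNGs built atop XLA's Random Bit Generator (RBG) algorithm. "rbg" uses Threefry for splitting (and folding in) and XLA RBG for data generation. "unsafe_rbg" exists only for demonstration purposes, using RBG both for splitting (with an untested, made-up algorithm) and for generation. The random streams generated by these experimental implementations have not been subject to empirical randomness testing, and the bits they generate may change between JAX versions.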

Reasons to use an alternative to the default (Threefry) RNG include:

  1. It may be slow to compile for TPUs.
  2. It is relatively slower to execute on TPUs.
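Selecting an alternative implementation for a single key can be sketched with the impl keyword. The example assumes the "rbg" implementation is available on the current backend (it is built on XLA's RngBitGenerator) and that its keys carry four uint32 words of data, compared with two for threefry2x32.

```python
from jax import random

# Select "rbg" for this key only; the global default is left unchanged.
# (To change the default for all keys instead, one would set the
# jax_default_prng_impl configuration flag.)
rbg_key = random.key(0, impl="rbg")

x = random.uniform(rbg_key, (2,))

# rbg keys hold more raw bits than threefry2x32 keys.
rbg_bits = random.key_data(rbg_key)
```

The key records its implementation, so later split or sampler calls on `rbg_key` keep using rbg without further configuration.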

Automatic partitioning:

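In order for jax.jit to efficiently auto-partition functions that generate sharded random number arrays (or key arrays), all PRNG implementations require extra flags:

  1. For threefry2x32 and rbg key arrays, set the flag jax_threefry_partitionable=True.
  2. For unsafe_rbg and rbg key arrays, set the XLA flag --xla_tpu_spmd_rng_bit_generator_unsafe=1.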

The XLA flag can be set using the XLA_FLAGS environment variable, e.g. as XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1.

For more about jax_threefry_partitionable, see https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
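A minimal sketch of turning on jax_threefry_partitionable and asking jax.jit for a sharded random output. It assumes the number of available devices divides 8 (a single CPU device works). The XLA flag for the rbg variants is a process environment setting, so it is only mentioned in a comment and not set here.

```python
from functools import partial

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Opt in to partitionable Threefry for all subsequent key operations.
# (For rbg/unsafe_rbg one would instead launch the process with
# XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1.)
jax.config.update("jax_threefry_partitionable", True)

# A 1-D mesh over all available devices, sharding rows along axis "x".
mesh = Mesh(np.array(jax.devices()), ("x",))
row_sharding = NamedSharding(mesh, P("x"))

@partial(jax.jit, out_shardings=row_sharding)
def sample(key):
  # Rows of the output are split across the devices of the mesh.
  return jax.random.normal(key, (8, 4))

out = sample(jax.random.key(0))
```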

Summary:

Property Threefry Threefry* rbg unsafe_rbg rbg** unsafe_rbg**
Fastest on TPU
efficiently shardable (w/ pjit)
identical across shardings
identical across CPU/GPU/TPU
exact jax.vmap over keys

(*): with jax_threefry_partitionable=1 set

(**): with XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1 set

API Reference

Key Creation & Manipulation

key(seed, *[, impl]) Create a pseudo-random number generator (PRNG) key given an integer seed.
key_data(keys) Recover the bits of key data underlying a PRNG key array.
wrap_key_data(key_bits_array, *[, impl]) Wrap an array of key data bits into a PRNG key array.
fold_in(key, data) Folds in data to a PRNG key to form a new PRNG key.
split(key[, num]) Splits a PRNG key into num new keys by adding a leading axis.
clone(key) Clone a key for reuse.
PRNGKey(seed, *[, impl]) Create a legacy PRNG key given an integer seed.

Random Samplers

ball(key, d[, p, shape, dtype]) Sample uniformly from the unit Lp ball.
bernoulli(key[, p, shape, mode]) Sample Bernoulli random values with given shape and mean.
beta(key, a, b[, shape, dtype]) Sample Beta random values with given shape and float dtype.
binomial(key, n, p[, shape, dtype]) Sample Binomial random values with given shape and float dtype.
bits(key[, shape, dtype, out_sharding]) Sample uniform bits in the form of unsigned integers.
categorical(key, logits[, axis, shape, replace]) Sample random values from categorical distributions.
cauchy(key[, shape, dtype]) Sample Cauchy random values with given shape and float dtype.
chisquare(key, df[, shape, dtype]) Sample Chisquare random values with given shape and float dtype.
choice(key, a[, shape, replace, p, axis]) Generates a random sample from a given array.
dirichlet(key, alpha[, shape, dtype]) Sample Dirichlet random values with given shape and float dtype.
double_sided_maxwell(key, loc, scale[, ...]) Sample from a double sided Maxwell distribution.
exponential(key[, shape, dtype]) Sample Exponential random values with given shape and float dtype.
f(key, dfnum, dfden[, shape, dtype]) Sample F-distribution random values with given shape and float dtype.
gamma(key, a[, shape, dtype]) Sample Gamma random values with given shape and float dtype.
generalized_normal(key, p[, shape, dtype]) Sample from the generalized normal distribution.
geometric(key, p[, shape, dtype]) Sample Geometric random values with given shape and float dtype.
gumbel(key[, shape, dtype, mode]) Sample Gumbel random values with given shape and float dtype.
laplace(key[, shape, dtype]) Sample Laplace random values with given shape and float dtype.
loggamma(key, a[, shape, dtype]) Sample log-gamma random values with given shape and float dtype.
logistic(key[, shape, dtype]) Sample logistic random values with given shape and float dtype.
lognormal(key[, sigma, shape, dtype]) Sample lognormal random values with given shape and float dtype.
maxwell(key[, shape, dtype]) Sample from a one sided Maxwell distribution.
multinomial(key, n, p, *[, shape, dtype, unroll]) Sample from a multinomial distribution.
multivariate_normal(key, mean, cov[, shape, ...]) Sample multivariate normal random values with given mean and covariance.
normal(key[, shape, dtype, out_sharding]) Sample standard normal random values with given shape and float dtype.
orthogonal(key, n[, shape, dtype, m]) Sample uniformly from the orthogonal group O(n).
pareto(key, b[, shape, dtype]) Sample Pareto random values with given shape and float dtype.
permutation(key, x[, axis, independent, ...]) Returns a randomly permuted array or range.
poisson(key, lam[, shape, dtype]) Sample Poisson random values with given shape and integer dtype.
rademacher(key[, shape, dtype]) Sample from a Rademacher distribution.
randint(key, shape, minval, maxval[, dtype, ...]) Sample uniform random values in [minval, maxval) with given shape/dtype.
rayleigh(key, scale[, shape, dtype]) Sample Rayleigh random values with given shape and float dtype.
t(key, df[, shape, dtype]) Sample Student's t random values with given shape and float dtype.
triangular(key, left, mode, right[, shape, ...]) Sample Triangular random values with given shape and float dtype.
truncated_normal(key, lower, upper[, shape, ...]) Sample truncated standard normal random values with given shape and dtype.
uniform(key[, shape, dtype, minval, maxval, ...]) Sample uniform random values in [minval, maxval) with given shape/dtype.
wald(key, mean[, shape, dtype]) Sample Wald random values with given shape and float dtype.
weibull_min(key, scale, concentration[, ...]) Sample from a Weibull distribution.
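As a usage sketch for a few of the samplers listed above, each call gets its own subkey from one split. The shapes and parameters are arbitrary, and the float32 result assumes the default configuration with 64-bit mode disabled.

```python
import jax.numpy as jnp
from jax import random

key = random.key(0)
k1, k2, k3, k4, k5 = random.split(key, 5)

normals = random.normal(k1, (2, 3))                     # float32 by default
ints = random.randint(k2, (5,), minval=0, maxval=10)    # integers in [0, 10)
coins = random.bernoulli(k3, p=0.3, shape=(4,))         # boolean array
perm = random.permutation(k4, 6)                        # shuffled arange(6)
picks = random.choice(k5, jnp.array([10, 20, 30]), shape=(2,), replace=False)
```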