jax.random.split — JAX documentation

jax.random.split

jax.random.split(key, num=2)

Splits a PRNG key into num new keys, stacked along a new leading axis, so the result has shape (num,) + key.shape.
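To show what "adding a leading axis" means in practice, here is a minimal sketch comparing the two key representations. It assumes the default threefry PRNG implementation (where a legacy jax.random.PRNGKey is a uint32 array of shape (2,)) and a recent JAX release that provides jax.random.key for typed, scalar-shaped keys.

```python
import jax

# Legacy raw key: under the default threefry PRNG this is a uint32 array of shape (2,).
raw_key = jax.random.PRNGKey(0)
raw_keys = jax.random.split(raw_key, 4)
print(raw_key.shape, raw_keys.shape)  # (2,) (4, 2)

# Typed key (assumes jax.random.key is available): a scalar key array of shape ().
typed_key = jax.random.key(0)
typed_keys = jax.random.split(typed_key, 4)
print(typed_key.shape, typed_keys.shape)  # () (4,)
```

In both cases the new axis of length num is prepended to the shape of the input key; only the trailing shape differs between the two representations.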

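Parameters:

key: a PRNG key (from PRNGKey, split, fold_in).

num: optional, a positive integer indicating the number of keys to produce (default 2).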
Returns:

An array-like object of num new PRNG keys.
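Because the returned keys are stacked along axis 0, they can be indexed individually or mapped over. The sketch below (assuming a JAX release that provides jax.random.key) draws one independent sample per key with jax.vmap.

```python
import jax

key = jax.random.key(42)

# Split once into 3 keys stacked along a new leading axis.
keys = jax.random.split(key, 3)

# Index into the leading axis to use a single key...
x = jax.random.uniform(keys[0])

# ...or map over that axis to draw one sample per key.
samples = jax.vmap(lambda k: jax.random.normal(k))(keys)
print(samples.shape)  # (3,)
```

Mapping over the leading axis avoids a Python loop when many independent random streams are needed at once.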

Return type:

Array
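With the default num=2, split is commonly used to thread randomness through sequential code: one output key is kept for future splits and the other is consumed by a sampler. The sketch below uses a hypothetical helper, draw_samples, purely for illustration; it uses jax.random.PRNGKey so it does not depend on the newer typed-key API.

```python
import jax
import jax.numpy as jnp

def draw_samples(key, n):
    """Hypothetical helper: draw n scalars, splitting the key before each draw."""
    out = []
    for _ in range(n):
        # Default num=2: keep `key` for later splits, consume `subkey` now.
        key, subkey = jax.random.split(key)
        out.append(jax.random.normal(subkey))
    # Return the samples and the updated key so the caller can keep splitting.
    return jnp.stack(out), key

samples, key = draw_samples(jax.random.PRNGKey(0), 5)
print(samples.shape)  # (5,)
```

Since splitting is deterministic, calling the helper again with the same starting key reproduces the same samples; reusing a key that has already been consumed would instead produce correlated draws, which is why the updated key is returned.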