jax.random.split — JAX documentation
jax.random.split
jax.random.split(key, num=2)
Splits a PRNG key into num new keys by adding a leading axis.
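The following is a minimal sketch of the usual pattern. It assumes a recent JAX release that provides jax.random.key (typed keys). The current key is split, one half is consumed for sampling, and the other half is carried forward, so no key is used twice. The seed and the sample sizes are arbitrary illustrative choices.

```python
import jax
import jax.numpy as jnp

# Create a typed PRNG key from an arbitrary integer seed.
key = jax.random.key(0)

# Split into two keys: `key` is carried forward, `subkey` is consumed now.
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (3,))

# Split the carried key again to get a fresh subkey for the next draw,
# rather than reusing `subkey`.
key, subkey2 = jax.random.split(key)
y = jax.random.normal(subkey2, (3,))
```

Splitting is deterministic: the same input key always produces the same output keys. Reproducibility therefore comes from threading keys through the program explicitly rather than from hidden global state.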
Parameters:
- key (ArrayLike) – a PRNG key (from key, split, fold_in).
- num (int | tuple[int, ...]) – optional, a positive integer (or tuple of integers) indicating the number (or shape) of keys to produce. Defaults to 2.
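The sketch below shows how num determines the output shape. It assumes a JAX version in which num accepts a tuple and the default threefry2x32 PRNG implementation is in use. The legacy jax.random.PRNGKey path is included because those keys are raw uint32 arrays with an extra trailing axis of key data. The jax.vmap line shows one way to split every key in a batch, not the only one.

```python
import jax

key = jax.random.key(42)  # typed key; key.shape == ()

# num as an int: a leading axis of length num is added.
keys = jax.random.split(key, 4)
print(keys.shape)  # (4,)

# num as a tuple: the new keys are laid out with that shape.
grid = jax.random.split(key, (2, 3))
print(grid.shape)  # (2, 3)

# Legacy raw keys (threefry2x32 default) carry a trailing axis of
# length 2 holding the uint32 key data.
legacy = jax.random.PRNGKey(42)
legacy_keys = jax.random.split(legacy, 4)
print(legacy_keys.shape)  # (4, 2)

# Splitting each key in a batch: map split over the leading axis.
batched = jax.vmap(jax.random.split)(keys)
print(batched.shape)  # (4, 2)
```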
Returns:
An array of num new PRNG keys; when num is a tuple, the keys are arranged with that shape.
Return type: