jax.random.PRNGKey — JAX documentation

jax.random.PRNGKey

jax.random.PRNGKey(seed, *, impl=None)

Create a legacy PRNG key given an integer seed.

This function produces legacy PRNG keys, which are arrays of dtype uint32. For more, see the note in the PRNG keys section. When possible, use jax.random.key() instead.
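The difference between the two key styles can be seen by inspecting their dtypes. The following is a minimal sketch, assuming a JAX version recent enough to provide jax.random.key(), jax.random.key_data() and jax.dtypes.prng_key; the variable names are illustrative only.

```python
import jax
import jax.dtypes
import jax.numpy as jnp

# Legacy key: a plain uint32 array.
legacy_key = jax.random.PRNGKey(0)
print(legacy_key.dtype)

# New-style typed key: its dtype is a dedicated PRNG key dtype.
typed_key = jax.random.key(0)
print(jnp.issubdtype(typed_key.dtype, jax.dtypes.prng_key))

# The raw uint32 data behind a typed key can be recovered explicitly.
raw_bits = jax.random.key_data(typed_key)
print(raw_bits.dtype, raw_bits.shape)
```

Under the same implementation, the raw data of the typed key has the same layout as the legacy key; the typed key additionally records which implementation produced it.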

The resulting key does not carry a PRNG implementation. The returned key matches the implementation given by the optional impl argument or, if impl is None, the one determined by the jax_default_prng_impl config flag. Callers must ensure that the same implementation is set as the default when passing this key as an argument to other functions (such as jax.random.split and jax.random.normal).
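Because a legacy key is only raw uint32 data, the implementation has to be tracked by the caller. The sketch below is one way to handle a non-default implementation: it assumes that "rbg" is an available implementation name in the installed JAX version and uses jax.random.wrap_key_data() to re-attach the implementation, rather than changing the global default.

```python
import jax

# Build a legacy key for an explicitly chosen implementation.
# "rbg" is assumed to be available in this JAX installation.
raw_rbg = jax.random.PRNGKey(0, impl="rbg")
print(raw_rbg.dtype, raw_rbg.shape)

# The raw array does not remember that it came from "rbg".
# Wrapping it into a typed key records the implementation explicitly,
# so downstream functions do not depend on the default setting.
typed_rbg = jax.random.wrap_key_data(raw_rbg, impl="rbg")
sample = jax.random.normal(typed_rbg, (3,))
print(sample.shape)
```

Passing raw_rbg directly to jax.random.normal would instead interpret it under whatever implementation jax_default_prng_impl currently names, which is the mismatch the paragraph above warns about.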

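Parameters:

seed – an integer used as the value of the key.

impl – optional; the PRNG implementation to use. If None, the implementation is determined by the jax_default_prng_impl config flag.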

Returns:

A PRNG key, consumable by random functions as well as split and fold_in.

Return type:

Array
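A short usage sketch of the returned key, assuming the default PRNG implementation is left unchanged; the seed, shapes and the fold_in value 7 are arbitrary choices for illustration.

```python
import jax

# Create a legacy key from an integer seed.
key = jax.random.PRNGKey(42)

# Split into a carried-forward key and a subkey to consume.
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (2,))

# Derive a new key from existing key data plus an integer,
# for example a loop step index.
step_key = jax.random.fold_in(key, 7)
y = jax.random.uniform(step_key, (2,))

print(x.shape, y.shape)
```

Each key is consumed by one random function only; fresh keys come from split or fold_in rather than from reusing the same key.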