jax.numpy.identity - JAX documentation

jax.numpy.identity

jax.numpy.identity(n, dtype=None)

Create a square identity matrix.

JAX implementation of numpy.identity().
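To make the definition concrete, here is a minimal sketch that builds the same matrix from first principles: entry (i, j) is 1 exactly when the row index equals the column index. The helper `my_identity` is hypothetical and only illustrates the idea; it is not necessarily how JAX implements `jax.numpy.identity` internally. The result is checked against both `jnp.identity` and `numpy.identity`.

```python
import jax.numpy as jnp
import numpy as np

def my_identity(n, dtype=None):
    # Hypothetical helper, for illustration only.
    idx = jnp.arange(n)
    # Boolean (n, n) mask: True where row index == column index (the diagonal).
    mask = idx[:, None] == idx[None, :]
    # Cast to the requested dtype; with dtype=None, fall back to JAX's default float.
    return mask.astype(float if dtype is None else dtype)

# Compare with jnp.identity and with NumPy's numpy.identity.
assert jnp.array_equal(my_identity(4), jnp.identity(4))
assert np.array_equal(np.asarray(my_identity(4)), np.identity(4))
```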

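Parameters:

n (int): integer specifying the size of each array dimension.
dtype (optional): dtype of the output array; defaults to floating point.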
Returns:

Identity array of shape (n, n).
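Because `n` fixes the output shape, and JAX requires array shapes to be known when a function is traced, `n` generally has to be a concrete Python integer inside `jax.jit`. A minimal sketch, assuming you want a jitted wrapper: marking `n` as static via `static_argnums` lets it serve as a shape while the other arguments remain traced. The function name `scaled_identity` is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical wrapper: argument 0 (n) is static, so it can determine the shape.
@partial(jax.jit, static_argnums=0)
def scaled_identity(n, scale):
    # n sets the output shape (n, n); scale is an ordinary traced argument.
    return scale * jnp.identity(n)

out = scaled_identity(3, 2.0)
print(out.shape)  # (3, 3)
```

Each distinct value of `n` triggers a separate compilation, which is the usual trade-off for static arguments.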

Return type:

Array

Examples

A simple 3x3 identity matrix:

>>> jnp.identity(3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)
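The defining property of the identity is that multiplying a conformable matrix by it leaves the matrix unchanged. A quick sketch of that check, where `A` is arbitrary illustrative data:

```python
import jax.numpy as jnp

A = jnp.arange(6.0).reshape(2, 3)  # illustrative 2x3 matrix

# Right-multiplying by the 3x3 identity leaves A unchanged...
assert jnp.array_equal(A @ jnp.identity(3), A)
# ...and so does left-multiplying by the 2x2 identity.
assert jnp.array_equal(jnp.identity(2) @ A, A)
```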

A 2x2 integer identity matrix:

>>> jnp.identity(2, dtype=int)
Array([[1, 0],
       [0, 1]], dtype=int32)
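`dtype` is not limited to numeric types. The sketch below, an illustration beyond the original examples, builds a boolean identity and compares `jnp.identity(n)` with `jax.numpy.eye`, which covers the square case and additionally accepts a column count and a diagonal offset `k`.

```python
import jax.numpy as jnp

# Boolean identity: True on the diagonal, False elsewhere.
b = jnp.identity(3, dtype=bool)
print(b.dtype)  # bool

# For square matrices, jnp.identity(n) matches jnp.eye(n)...
assert jnp.array_equal(jnp.identity(4), jnp.eye(4))
# ...while jnp.eye also supports rectangular shapes and diagonal offsets.
print(jnp.eye(2, 3, k=1))
```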