jax.numpy.identity — JAX documentation
jax.numpy.identity
jax.numpy.identity(n, dtype=None)
Create a square identity matrix.
JAX implementation of numpy.identity().
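Because this function mirrors numpy.identity(), the two should produce the same values. The sketch below compares them, and also checks against jnp.eye, which for a square matrix with no diagonal offset is expected to give the same result. The size n = 4 and the variable names are illustrative choices, not part of the API.

```python
import numpy as np
import jax.numpy as jnp

# Build the same 4x4 identity matrix with NumPy and with JAX.
n = 4
np_eye = np.identity(n)
jax_eye = jnp.identity(n)

# The values agree with NumPy, and jnp.identity(n) matches jnp.eye(n).
print(np.array_equal(np_eye, np.asarray(jax_eye)))
print(bool(jnp.array_equal(jax_eye, jnp.eye(n))))

# Default dtypes may differ: NumPy uses float64, while JAX uses float32
# unless 64-bit mode is enabled.
print(np_eye.dtype, jax_eye.dtype)
```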
Parameters:
- n (DimSize) – integer specifying the size of each array dimension.
- dtype (DTypeLike | None) – optional dtype; defaults to JAX's default floating-point dtype (float32 unless 64-bit mode is enabled).
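A short sketch of the dtype parameter: when it is omitted, the result is floating point, and an explicit dtype such as bool or complex64 is expected to be honored. The particular dtypes chosen here are illustrative.

```python
import jax.numpy as jnp

# With dtype=None the result uses JAX's default floating-point dtype.
default_eye = jnp.identity(3)
print(default_eye.dtype)

# An explicit dtype is honored: here a boolean and a complex identity.
bool_eye = jnp.identity(3, dtype=bool)
complex_eye = jnp.identity(3, dtype=jnp.complex64)
print(bool_eye)
print(complex_eye.dtype)
```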
Returns:
Identity array of shape (n, n).
Return type:
Array
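Since the result has shape (n, n), it acts as the multiplicative identity for any matrix with a matching dimension. The sketch below uses a hypothetical 3x2 matrix A drawn from a fixed PRNG key. It compares with jnp.allclose rather than exact equality because matmul precision can vary by backend.

```python
import jax
import jax.numpy as jnp

# Hypothetical example matrix A of shape (3, 2), from a fixed PRNG key.
A = jax.random.normal(jax.random.PRNGKey(0), (3, 2))

# Left-multiplying by identity(3) and right-multiplying by identity(2)
# should both leave A unchanged, since each identity is (n, n).
left = jnp.identity(3) @ A
right = A @ jnp.identity(2)
print(jnp.allclose(left, A), jnp.allclose(right, A))
```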
Examples
A simple 3x3 identity matrix:
>>> import jax.numpy as jnp
>>> jnp.identity(3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)
A 2x2 integer identity matrix:
>>> jnp.identity(2, dtype=int)
Array([[1, 0],
       [0, 1]], dtype=int32)
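Because n determines the output shape, it generally has to be known at trace time. Inside jax.jit, one way to satisfy this is to mark n as a static argument. The helper scaled_identity below is hypothetical and exists only to illustrate the pattern.

```python
from functools import partial

import jax
import jax.numpy as jnp

# n sets the output shape, so under jax.jit it is marked static
# (a plain Python int) while `scale` remains a traced value.
@partial(jax.jit, static_argnums=0)
def scaled_identity(n, scale):
    # Hypothetical helper: returns scale * I_n.
    return scale * jnp.identity(n)

print(scaled_identity(3, 2.0))
```

Each distinct value of n triggers a separate compilation, which is the usual trade-off of static arguments.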