jax.numpy.diagonal - JAX documentation

jax.numpy.diagonal

jax.numpy.diagonal(a, offset=0, axis1=0, axis2=1)

Returns the specified diagonal of an array.

JAX implementation of numpy.diagonal().
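Because this is the JAX implementation of numpy.diagonal(), its values should agree with NumPy's. Below is a minimal parity check, assuming both NumPy and JAX are installed. The 3x4 test matrix and the set of offsets are arbitrary choices made for illustration.

```python
import numpy as np
import jax.numpy as jnp

# A non-square 3x4 matrix, so positive and negative offsets behave asymmetrically.
a = np.arange(12).reshape(3, 4)

for k in (-2, -1, 0, 1, 2, 3):
    expected = np.diagonal(a, offset=k)   # NumPy reference result
    result = jnp.diagonal(a, offset=k)    # JAX result (a new Array)
    # Same shape and same values; dtypes may differ (int64 vs int32 by default).
    assert result.shape == expected.shape
    assert np.array_equal(np.asarray(result), expected)
```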

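Unlike numpy.diagonal(), which returns a read-only view, the JAX version always returns a new array holding a copy of the diagonal elements. When it is called inside a JIT-compiled function, however, the compiler may avoid making the copy.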
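Since JAX arrays are immutable, whether a copy is made has no effect on the values a program sees; it only affects memory and performance. The sketch below illustrates this under that assumption: an "update" of the returned diagonal via .at[] leaves the input untouched, and the same call works inside jax.jit, where any copy elision is left to the compiler and is not observable from Python. The helper name trace_via_diagonal is hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def trace_via_diagonal(m):
    # Hypothetical helper: inside jit, the compiler may fuse the diagonal
    # extraction into the reduction instead of materializing a separate copy.
    return jnp.diagonal(m).sum()

x = jnp.arange(9.0).reshape(3, 3)
d = jnp.diagonal(x)

# Functional update: produces a new array; neither d nor x is modified.
d2 = d.at[0].set(100.0)
assert x[0, 0] == 0.0
assert d[0] == 0.0
assert d2[0] == 100.0

# The jitted result matches jnp.trace on the same input.
assert trace_via_diagonal(x) == jnp.trace(x)
```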

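Parameters:

a – input array. Must be at least 2-dimensional.
offset – offset of the diagonal from the main diagonal (default 0). Positive values select diagonals above the main diagonal; negative values select diagonals below it.
axis1 – first axis of the 2-D sub-arrays from which the diagonals are taken (default 0).
axis2 – second axis of the 2-D sub-arrays from which the diagonals are taken (default 1).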

Return type:

Array
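For inputs with more than two dimensions, axis1 and axis2 select the pair of axes that form each 2-D sub-array. Following NumPy's convention, those two axes are removed and the diagonal is appended as the new last axis. A short sketch, using an arbitrary batch of two 3x3 matrices as the assumed input:

```python
import numpy as np
import jax.numpy as jnp

# Two 3x3 matrices stacked along a leading batch axis: shape (2, 3, 3).
batch = jnp.arange(18).reshape(2, 3, 3)

# Diagonal of each 3x3 matrix: axes 1 and 2 are removed and the diagonal
# becomes the new last axis, giving shape (2, 3).
per_matrix = jnp.diagonal(batch, axis1=1, axis2=2)
assert per_matrix.shape == (2, 3)

# Same values as taking the 2-D diagonal of each matrix separately.
for i in range(batch.shape[0]):
    assert np.array_equal(np.asarray(per_matrix[i]),
                          np.asarray(jnp.diagonal(batch[i])))

# With the defaults (axis1=0, axis2=1) the diagonal runs over the first two
# axes (sizes 2 and 3, so length 2); the remaining axis of size 3 comes first.
assert jnp.diagonal(batch).shape == (3, 2)
```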

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9]])
>>> jnp.diagonal(x)
Array([1, 5, 9], dtype=int32)
>>> jnp.diagonal(x, offset=1)
Array([2, 6], dtype=int32)
>>> jnp.diagonal(x, offset=-1)
Array([4, 8], dtype=int32)
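To make the offset rule concrete, here is a hypothetical helper, diagonal_2d_sketch, that re-derives the 2-D case with explicit index arithmetic. It is an illustrative sketch that only handles 2-D input and is not how JAX implements diagonal(); it is checked against jnp.diagonal on the example matrix above.

```python
import jax.numpy as jnp

def diagonal_2d_sketch(m, offset=0):
    """Hypothetical 2-D re-derivation of jnp.diagonal(m, offset) for illustration."""
    n_rows, n_cols = m.shape
    # A positive offset starts at column `offset` of row 0 (above the main diagonal);
    # a negative offset starts at row `-offset` of column 0 (below it).
    row0 = max(-offset, 0)
    col0 = max(offset, 0)
    # The diagonal ends at whichever matrix edge it reaches first.
    length = max(min(n_rows - row0, n_cols - col0), 0)
    idx = jnp.arange(length)
    # Gather the elements (row0 + i, col0 + i) for i in 0..length-1.
    return m[row0 + idx, col0 + idx]

x = jnp.array([[1, 2, 3],
               [4, 5, 6],
               [7, 8, 9]])
for k in (-2, -1, 0, 1, 2):
    assert jnp.array_equal(diagonal_2d_sketch(x, k), jnp.diagonal(x, offset=k))
```

The diagonal length is min(rows - row0, cols - col0), clamped at zero, which is why offset=1 and offset=-1 on a 3x3 matrix each yield two elements.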