jax.numpy.trace (JAX documentation)

jax.numpy.trace

jax.numpy.trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None)

Calculate the sum of the diagonal elements of the input along the given axes.

JAX implementation of numpy.trace().
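One way to read the definition is as "take the diagonal, then sum it". The sketch below checks that reading against jnp.trace for the same input used in the Examples section. It assumes, as in numpy.diagonal, that jnp.diagonal places the extracted diagonal on the last axis, so summing over axis=-1 should reproduce the trace for every (offset, axis1, axis2) combination tried here.

```python
import jax.numpy as jnp

# Same 3-D input as in the Examples section: shape (2, 2, 2).
x = jnp.arange(1, 9).reshape(2, 2, 2)

# jnp.diagonal removes axis1 and axis2 and appends the diagonal as the last
# axis, so summing over axis=-1 gives the trace along (axis1, axis2).
cases = [(0, 0, 1), (1, 0, 1), (0, 1, 2), (1, 1, 2)]
for offset, axis1, axis2 in cases:
    direct = jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2)
    via_diagonal = jnp.diagonal(
        x, offset=offset, axis1=axis1, axis2=axis2
    ).sum(axis=-1)
    assert jnp.array_equal(direct, via_diagonal)
```

This is only an equivalence check for illustration; it does not describe how JAX implements jnp.trace internally.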

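Parameters:

- a: input array. Must have at least two dimensions (a.ndim >= 2).
- offset: int, default 0. Offset of the diagonal from the main diagonal. Positive values select diagonals above the main diagonal, negative values select diagonals below it.
- axis1: int, default 0. First axis of the 2-D sub-arrays whose diagonals are summed.
- axis2: int, default 1. Second axis of the 2-D sub-arrays whose diagonals are summed.
- dtype: optional. Data type of the returned array and of the accumulator used for the sum. If None, it is derived from the dtype of a.
- out: not used by JAX; leave it as None.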
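A minimal sketch of the offset sign convention and the dtype argument on a small 2-D matrix (the names m, main, upper, lower and as_float are only local variables for illustration):

```python
import jax.numpy as jnp

m = jnp.arange(9).reshape(3, 3)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

main = jnp.trace(m)              # main diagonal: 0 + 4 + 8 = 12
upper = jnp.trace(m, offset=1)   # above the main diagonal: 1 + 5 = 6
lower = jnp.trace(m, offset=-1)  # below the main diagonal: 3 + 7 = 10

# Request a floating-point result even though m holds integers.
as_float = jnp.trace(m, dtype=jnp.float32)
```

Because m is 2-D, each call returns a 0-d array rather than a vector.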

Returns:

An array with a.ndim - 2 dimensions containing the sums of the diagonal elements along axes (axis1, axis2).

Return type:

Array
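The shape rule in the return description can be checked directly. In this sketch (the array shapes are arbitrary choices for illustration), axis1 and axis2 are removed from the input shape and the remaining axes keep their original order; a 2-D input therefore yields a 0-d array.

```python
import jax.numpy as jnp

a = jnp.zeros((2, 3, 4, 5))

# Axes 1 and 3 are reduced away; axes 0 and 2 remain, in that order.
# The (3, 5) sub-arrays need not be square.
t = jnp.trace(a, axis1=1, axis2=3)
print(t.shape)  # (2, 4)

# For a 2-D input, a.ndim - 2 == 0, so the result is a scalar (0-d) array.
s = jnp.trace(jnp.eye(3))
print(s.shape, float(s))  # () 3.0
```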

See also

Examples

>>> import jax.numpy as jnp
>>> x = jnp.arange(1, 9).reshape(2, 2, 2)
>>> x
Array([[[1, 2],
        [3, 4]],

       [[5, 6],
        [7, 8]]], dtype=int32)
>>> jnp.trace(x)
Array([ 8, 10], dtype=int32)
>>> jnp.trace(x, offset=1)
Array([3, 4], dtype=int32)
>>> jnp.trace(x, axis1=1, axis2=2)
Array([ 5, 13], dtype=int32)
>>> jnp.trace(x, offset=1, axis1=1, axis2=2)
Array([2, 6], dtype=int32)