jax.numpy.trace - JAX documentation
jax.numpy.trace
jax.numpy.trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None)
Calculate the sum of the diagonal of the input along the given axes.
JAX implementation of numpy.trace().
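As a quick consistency check, the sketch below compares jnp.trace with numpy.trace on a small 3-D array for a few argument combinations. The array, the chosen cases, and the names a_np, a_jax, cases and results are illustrative only. Values are compared with np.testing.assert_array_equal, which ignores dtype differences (NumPy typically produces int64 here, while JAX produces int32 unless 64-bit mode is enabled).

```python
import numpy as np
import jax.numpy as jnp

a_np = np.arange(24).reshape(2, 3, 4)
a_jax = jnp.asarray(a_np)

# Illustrative argument combinations, including a negative offset and a
# non-default axis pair; each JAX result is checked against NumPy.
cases = [
    {},
    {"offset": 1},
    {"offset": -1, "axis1": 1, "axis2": 2},
    {"axis1": 2, "axis2": 0},
]

results = []
for kwargs in cases:
    expected = np.trace(a_np, **kwargs)
    actual = np.asarray(jnp.trace(a_jax, **kwargs))
    # Raises AssertionError if any element differs.
    np.testing.assert_array_equal(actual, expected)
    results.append(actual)
```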
Parameters:
- a (ArrayLike) – input array. Must have a.ndim >= 2.
- offset (int | ArrayLike) – optional, int, default=0. Offset of the diagonal from the main diagonal. Positive values select diagonals above the main diagonal; negative values select diagonals below it.
- axis1 (int) – optional, default=0. The first axis of the 2-D sub-arrays whose diagonals are summed. Must be a static integer value.
- axis2 (int) – optional, default=1. The second axis of the 2-D sub-arrays whose diagonals are summed. Must be a static integer value.
- dtype (DTypeLike | None) – optional. The dtype of the output array. Should be provided as a static argument under JIT compilation.
- out (None) – Not used by JAX.
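The following sketch illustrates the offset sign convention, the dtype argument, and one way to keep axis1 and axis2 static when they are forwarded through a jitted function. The wrapper trace_along is hypothetical and not part of the JAX API; listing its axis arguments in static_argnames is an assumption about how such a helper might be structured.

```python
from functools import partial

import jax
import jax.numpy as jnp

m = jnp.arange(9).reshape(3, 3)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

main = jnp.trace(m)              # 0 + 4 + 8 = 12
above = jnp.trace(m, offset=1)   # diagonal above the main one: 1 + 5 = 6
below = jnp.trace(m, offset=-1)  # diagonal below the main one: 3 + 7 = 10
as_float = jnp.trace(m, dtype=jnp.float32)  # same sum, float32 result

# Hypothetical helper: axis1/axis2 are marked static so they reach
# jnp.trace as plain Python ints rather than traced values.
@partial(jax.jit, static_argnames=("axis1", "axis2"))
def trace_along(x, axis1, axis2):
    return jnp.trace(x, axis1=axis1, axis2=axis2)

batched = trace_along(jnp.arange(1, 9).reshape(2, 2, 2), axis1=1, axis2=2)
```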
Returns:
An array of dimension a.ndim - 2 containing the sums of the diagonal elements along axes (axis1, axis2).
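To illustrate the output shape, the sketch below takes a trace over axes 1 and 3 of a 4-D array and compares it with summing jnp.diagonal over its last axis. The equivalence shown is a property of the two functions' results, not a claim about how jnp.trace is implemented internally; the array shape is arbitrary.

```python
import jax.numpy as jnp

a = jnp.arange(3 * 4 * 5 * 6, dtype=jnp.float32).reshape(3, 4, 5, 6)

# Axes 1 and 3 are consumed; axes 0 and 2 remain, so t.ndim == a.ndim - 2.
t = jnp.trace(a, offset=1, axis1=1, axis2=3)
print(t.shape)  # (3, 5)

# jnp.diagonal places the selected diagonal on the last axis;
# summing that axis reproduces the trace.
d = jnp.diagonal(a, offset=1, axis1=1, axis2=3)
print(d.shape)  # (3, 5, 4)
same = jnp.allclose(t, d.sum(axis=-1))
```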
Return type:
Array
See also
- jax.numpy.diag(): Returns the specified diagonal or constructs a diagonal array.
- jax.numpy.diagonal(): Returns the specified diagonal of an array.
- jax.numpy.diagflat(): Returns a 2-D array with the flattened input array laid out on the diagonal.
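The relationship between trace and the related functions listed above can be checked on small inputs: the trace of jnp.diagflat(v) is the sum of v, and for a 2-D array, summing jnp.diag(m, k) matches jnp.trace(m, offset=k). This is a minimal sketch with made-up inputs.

```python
import jax.numpy as jnp

v = jnp.array([1, 2, 3])
m = jnp.arange(9).reshape(3, 3)

# diagflat places v on the main diagonal, so its trace is v.sum().
t_flat = jnp.trace(jnp.diagflat(v))

# For 2-D input, diag(m, k) extracts the k-th diagonal; its sum
# should equal trace(m, offset=k) for every valid k.
matches = [
    int(jnp.diag(m, k=k).sum()) == int(jnp.trace(m, offset=k))
    for k in (-2, -1, 0, 1, 2)
]
```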
Examples
>>> import jax.numpy as jnp
>>> x = jnp.arange(1, 9).reshape(2, 2, 2)
>>> x
Array([[[1, 2],
        [3, 4]],

       [[5, 6],
        [7, 8]]], dtype=int32)
>>> jnp.trace(x)
Array([ 8, 10], dtype=int32)
>>> jnp.trace(x, offset=1)
Array([3, 4], dtype=int32)
>>> jnp.trace(x, axis1=1, axis2=2)
Array([ 5, 13], dtype=int32)
>>> jnp.trace(x, offset=1, axis1=1, axis2=2)
Array([2, 6], dtype=int32)
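A batched trace over the leading axis can also be written with jax.vmap, which should agree with the axis1=1, axis2=2 example above. This is a sketch of an alternative formulation, not a recommendation of one form over the other.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(1, 9).reshape(2, 2, 2)

# vmap maps jnp.trace over the leading axis, tracing each 2x2 slice.
per_slice = jax.vmap(jnp.trace)(x)

# The same reduction without vmap: use axes 1 and 2 as the diagonal axes.
direct = jnp.trace(x, axis1=1, axis2=2)
```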