jax.Array.trace
abstract Array.trace(offset=0, axis1=0, axis2=1, dtype=None, out=None)
Return the sum along the diagonal.
Refer to jax.numpy.trace() for full documentation.
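Parameters:
    offset: offset of the diagonal from the main diagonal; positive values select diagonals above it, negative values below it. Default 0.
    axis1, axis2: the two axes that define the 2-D sub-arrays whose diagonals are summed. Defaults 0 and 1.
    dtype: optional dtype of the result. Default None.
    out: not used by JAX; leave as None.
Return type:
    Array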
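The method form can be illustrated with a short sketch. The array names (`x`, `y`) and their values are made up for illustration and do not come from the JAX documentation; the calls themselves use only the signature shown above.

```python
import jax.numpy as jnp

# A 3x3 matrix with entries 0..8 (illustrative data).
x = jnp.arange(9.0).reshape(3, 3)

main_diag_sum = x.trace()            # sums x[0,0], x[1,1], x[2,2]
upper_diag_sum = x.trace(offset=1)   # sums x[0,1], x[1,2]
lower_diag_sum = x.trace(offset=-1)  # sums x[1,0], x[2,1]

# The method is expected to agree with the function form.
same_as_function = jnp.trace(x)

# For a 3-D array, axis1/axis2 choose which two axes form the matrices.
y = jnp.arange(8).reshape(2, 2, 2)
per_batch = y.trace(axis1=1, axis2=2)  # one trace per leading index
default_axes = y.trace()               # diagonals taken over axes 0 and 1
```

With the default axes, a 3-D input yields one value per index of the remaining axis, so choosing `axis1` and `axis2` explicitly is usually clearer for batched matrices.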