jax.numpy.tri — JAX documentation

jax.numpy.tri

jax.numpy.tri(N, M=None, k=0, dtype=None)

Return an array with ones on and below the diagonal and zeros elsewhere.

JAX implementation of numpy.tri()

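Parameters:

N (int): number of rows of the returned array.

M (int | None): optional, number of columns of the returned array. If not specified, M = N.

k (int): optional, default 0. The diagonal on and below which the array is filled with ones. k=0 is the main diagonal, k<0 is a diagonal below it, and k>0 is a diagonal above it.

dtype (DTypeLike | None): optional, data type of the returned array. The examples below show float32 when it is not given.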

Returns:

An array of shape (N, M) in which the elements on and below the diagonal specified by k are set to one and all other elements are set to zero.

Return type:

Array
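The rule in the Returns description can be written as an index comparison: element (i, j) is one when j <= i + k. The sketch below is only an illustration of that rule, not the library's implementation; the helper name tri_reference is hypothetical.

```python
import jax.numpy as jnp

def tri_reference(N, M=None, k=0, dtype=jnp.float32):
    """Hypothetical helper: rebuild the tri pattern from row/column indices."""
    M = N if M is None else M
    rows = jnp.arange(N)[:, None]   # column vector of row indices i
    cols = jnp.arange(M)[None, :]   # row vector of column indices j
    # One where the column index is on or below the k-th diagonal.
    return (cols <= rows + k).astype(dtype)

# Compare against jnp.tri for a non-square shape and a negative k.
same = bool(jnp.array_equal(tri_reference(3, 4, k=-1), jnp.tri(3, 4, k=-1)))
```

If the index rule is right, same is True for any choice of N, M, and k.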

See also

Examples

>>> import jax.numpy as jnp
>>> jnp.tri(3)
Array([[1., 0., 0.],
       [1., 1., 0.],
       [1., 1., 1.]], dtype=float32)

When M is not equal to N:

>>> jnp.tri(3, 4)
Array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.]], dtype=float32)

When k>0:

>>> jnp.tri(3, k=1)
Array([[1., 1., 0.],
       [1., 1., 1.],
       [1., 1., 1.]], dtype=float32)

When k<0:

>>> jnp.tri(3, 4, k=-1)
Array([[0., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 1., 0., 0.]], dtype=float32)
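The examples above all use the default dtype. As a further sketch, the dtype argument can request integer or boolean output, and a boolean result can serve as a mask. The input matrix x and the masking step are illustrative assumptions, not part of the original page.

```python
import jax.numpy as jnp

# Integer output instead of the default float32.
int_tri = jnp.tri(3, dtype=jnp.int32)

# Boolean output, used here as a mask (illustrative use only).
mask = jnp.tri(3, dtype=bool)
x = jnp.arange(1.0, 10.0).reshape(3, 3)   # hypothetical 3x3 input matrix
lower = jnp.where(mask, x, 0.0)           # keep entries on and below the main diagonal

# For this 2-D input the result should agree with jnp.tril.
agrees = bool(jnp.array_equal(lower, jnp.tril(x)))
```

Masking with jnp.where keeps the dtype of x, whereas multiplying by a float-valued jnp.tri would rely on type promotion.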