jax.numpy.linalg.cross (JAX documentation)

jax.numpy.linalg.cross

jax.numpy.linalg.cross(x1, x2, /, *, axis=-1)

Compute the cross product of two 3D vectors, or of broadcast-compatible arrays of 3D vectors.

JAX implementation of numpy.linalg.cross().
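For reference, the cross product can be written out component by component as \(c = (a_1 b_2 - a_2 b_1,\ a_2 b_0 - a_0 b_2,\ a_0 b_1 - a_1 b_0)\). The sketch below spells out that textbook formula in a hypothetical helper, `cross_by_components` (not part of JAX), and checks it against `jnp.linalg.cross`. It assumes the vectors lie along the last axis and is only meant to illustrate the math, not to describe how JAX implements the function.

```python
import jax.numpy as jnp


def cross_by_components(a, b):
    """Hypothetical helper (not a JAX API): textbook cross product.

    Assumes the 3 vector components lie along the last axis of `a` and `b`;
    any leading axes broadcast as in ordinary elementwise arithmetic.
    """
    a0, a1, a2 = a[..., 0], a[..., 1], a[..., 2]
    b0, b1, b2 = b[..., 0], b[..., 1], b[..., 2]
    return jnp.stack([a1 * b2 - a2 * b1,   # x component
                      a2 * b0 - a0 * b2,   # y component
                      a0 * b1 - a1 * b0],  # z component
                     axis=-1)


a = jnp.array([1., 2., 3.])
b = jnp.array([4., 5., 6.])
print(cross_by_components(a, b))  # expected values: -3, 6, -3
print(jnp.allclose(cross_by_components(a, b), jnp.linalg.cross(a, b)))  # True
```

Because every component is built from elementwise products, the same helper also works on batches such as an array of shape (4, 3) against a single vector of shape (3,).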

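Parameters:

x1: N-dimensional array, with x1.shape[axis] == 3.

x2: N-dimensional array, with x2.shape[axis] == 3 and the remaining axes broadcast-compatible with x1.

axis: int, the axis of x1 and x2 along which the vectors lie and the cross product is taken (default: -1).

Returns: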

Array containing the cross product of x1 and x2, computed along axis.
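As a sketch of how `x1`, `x2`, and `axis` interact (the shapes here are arbitrary illustrative choices): the vectors may sit along any axis, not only the last one, and the remaining axes broadcast as in NumPy. The final line cross-checks against NumPy's classic `np.cross`, assuming its `axis` argument (which sets the vector axis of both inputs and of the output) is the right point of comparison.

```python
import numpy as np
import jax.numpy as jnp

# Four 3-vectors stored as columns: shape (3, 4), components along axis 0.
a = jnp.arange(12.).reshape(3, 4)
b = jnp.ones((3, 4))

c = jnp.linalg.cross(a, b, axis=0)
print(c.shape)  # (3, 4): the result keeps the vectors along axis 0

# Same computation with the vectors along the last axis, transposed back.
c_last = jnp.linalg.cross(a.T, b.T)  # default axis=-1, inputs of shape (4, 3)
print(jnp.allclose(c, c_last.T))     # True

# Broadcasting: a single vector against a batch of 5 vectors.
v = jnp.array([0., 0., 1.])
batch = jnp.ones((5, 3))
print(jnp.linalg.cross(v, batch).shape)  # (5, 3)

# Cross-check against NumPy's np.cross with the same axis.
print(np.allclose(c, np.cross(np.asarray(a), np.asarray(b), axis=0)))  # True
```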

Examples

Showing that \(\hat{x} \times \hat{y} = \hat{z}\):

>>> import jax.numpy as jnp
>>> x = jnp.array([1., 0., 0.])
>>> y = jnp.array([0., 1., 0.])
>>> jnp.linalg.cross(x, y)
Array([0., 0., 1.], dtype=float32)
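Two standard identities of the cross product can be checked in the same way: swapping the operands flips the sign, so \(\hat{y} \times \hat{x} = -\hat{z}\), and the result is orthogonal to both inputs. A minimal sketch, using random inputs from `jax.random` (the seed and the batch size of 8 are illustrative choices, not taken from the original page):

```python
import jax
import jax.numpy as jnp

x = jnp.array([1., 0., 0.])
y = jnp.array([0., 1., 0.])

# Anticommutativity: y x x = -(x x y) = -z.
print(jnp.linalg.cross(y, x))  # expected values: 0, 0, -1

# Orthogonality: a x b is perpendicular to both a and b.
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (8, 3))  # batch of 8 random 3-vectors
b = jax.random.normal(key_b, (8, 3))
c = jnp.linalg.cross(a, b)

# Row-wise dot products: zero up to float32 rounding.
print(jnp.abs(jnp.sum(a * c, axis=-1)).max())
print(jnp.abs(jnp.sum(b * c, axis=-1)).max())

# Swapping operands negates every batched result.
print(jnp.allclose(jnp.linalg.cross(b, a), -c, atol=1e-5))  # True
```

The dot products are compared against a small tolerance rather than exact zero, since float32 rounding leaves residuals on the order of 1e-6.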

Cross product of \(\hat{x}\) with all three standard unit vectors, via broadcasting:

>>> xyz = jnp.eye(3)
>>> jnp.linalg.cross(x, xyz, axis=-1)
Array([[ 0.,  0.,  0.],
       [ 0.,  0.,  1.],
       [ 0., -1.,  0.]], dtype=float32)
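Since `jnp.linalg.cross` is an ordinary `jax.numpy` function, the broadcasting call above should agree with an explicit `jax.vmap` over the rows of `xyz` and with a `jax.jit`-compiled call. The sketch below checks this, and also shows that storing the unit vectors as columns and passing `axis=0` yields the transpose of the row-wise result. The `in_axes=(None, 0)` choice (keep `x` fixed, map over `xyz`) is an illustrative assumption about how one would batch this by hand.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1., 0., 0.])
xyz = jnp.eye(3)

# Broadcasting version: x (shape (3,)) against each row of xyz (shape (3, 3)).
broadcast = jnp.linalg.cross(x, xyz, axis=-1)

# Explicit batching: keep x fixed, map over the rows of xyz.
mapped = jax.vmap(jnp.linalg.cross, in_axes=(None, 0))(x, xyz)

# The same call compiled with jit.
jitted = jax.jit(jnp.linalg.cross)(x, xyz)

print(jnp.array_equal(broadcast, mapped), jnp.array_equal(broadcast, jitted))

# Vectors as columns with axis=0: column j of the result holds x x e_j,
# so the result is the transpose of `broadcast`.
cols = jnp.linalg.cross(x[:, None], xyz, axis=0)
print(jnp.array_equal(cols, broadcast.T))  # True
```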