jax.numpy.linalg.vector_norm — JAX documentation

jax.numpy.linalg.vector_norm

jax.numpy.linalg.vector_norm(x, /, *, axis=None, keepdims=False, ord=2)

Compute the vector norm of a vector or batch of vectors.

JAX implementation of numpy.linalg.vector_norm().

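Parameters:

x: N-dimensional array for which to take the norm.

axis: optional axis (or axes) along which to compute the vector norm. If None (the default), x is flattened and the norm is taken over all of its values.

keepdims: if True, keep the reduced dimensions in the output with size 1. Defaults to False.

ord: the order of the norm to compute. Defaults to 2, the Euclidean norm.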

Returns:

array containing the norm of x.

Return type:

Array

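See also

jax.numpy.linalg.matrix_norm(): compute matrix norms.

jax.numpy.linalg.norm(): general norm function for vectors and matrices.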

Examples

Norm of a single vector:

>>> x = jnp.array([1., 2., 3.])
>>> jnp.linalg.vector_norm(x)
Array(3.7416575, dtype=float32)
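
With the default ord=2, the result is the Euclidean norm, the square root of the sum of squared entries. The following is a minimal sketch that checks this against a manual computation. It assumes jax is installed; the names norm and manual are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([1., 2., 3.])

# Default ord=2: Euclidean norm, here sqrt(1 + 4 + 9) = sqrt(14).
norm = jnp.linalg.vector_norm(x)

# The same quantity computed by hand, for comparison.
manual = jnp.sqrt(jnp.sum(x ** 2))

# The two values should agree up to floating-point tolerance.
agree = bool(jnp.allclose(norm, manual))
print(agree)
```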

Norm of a batch of vectors:

>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 7.]])
>>> jnp.linalg.vector_norm(x, axis=1)
Array([3.7416575, 9.486833 ], dtype=float32)
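
The remaining arguments in the signature (axis=None, keepdims, and ord) can be illustrated in the same way. The sketch below is not part of the original examples. It assumes jax is installed, and the variable names (total, rows, unit_rows, l1, linf) are chosen here for illustration.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 7.]])

# axis=None (the default) flattens x, so this is a single scalar norm
# over all six values: sqrt(1 + 4 + 9 + 16 + 25 + 49) = sqrt(104).
total = jnp.linalg.vector_norm(x)

# keepdims=True keeps the reduced axis with size 1, giving shape (2, 1),
# which broadcasts against x to normalize each row to unit length.
rows = jnp.linalg.vector_norm(x, axis=1, keepdims=True)
unit_rows = x / rows

# ord=1 sums absolute values along the axis; ord=jnp.inf takes the
# largest absolute value along the axis.
l1 = jnp.linalg.vector_norm(x, axis=1, ord=1)
linf = jnp.linalg.vector_norm(x, axis=1, ord=jnp.inf)

print(total.shape, rows.shape, l1, linf)
```

Keeping the reduced dimension is what makes the division x / rows broadcast row by row without an explicit reshape.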