jax.numpy.linalg.vector_norm (JAX documentation)
jax.numpy.linalg.vector_norm
jax.numpy.linalg.vector_norm(x, /, *, axis=None, keepdims=False, ord=2)
Compute the vector norm of a vector or batch of vectors.
JAX implementation of numpy.linalg.vector_norm().
Parameters:
- x (ArrayLike) – N-dimensional array for which to take the norm.
- axis (int | tuple[int, ...] | None) – optional axis or axes along which to compute the vector norm. If None (default), x is flattened and the norm is taken over all values.
- keepdims (bool) – if True, keep the reduced dimensions in the output as size-1 axes. Default is False.
- ord (int | str) – A string or int specifying the type of norm; default is the 2-norm. See numpy.linalg.norm() for details on available options.
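A short sketch of how the three keyword parameters interact, assuming a recent JAX release. Passing `ord=jnp.inf` follows the `numpy.linalg.norm()` convention (maximum absolute value) even though the annotated type is `int | str`. The matrix `x` and all variable names are illustrative.

```python
import jax.numpy as jnp

# A 2x3 matrix; with axis=None it is flattened to 6 values before the norm.
x = jnp.array([[3., 0., -4.],
               [0., 0.,  0.]])

# axis=None (default): 2-norm over all elements -> sqrt(9 + 16) = 5
flat = jnp.linalg.vector_norm(x)

# Same result as taking the norm of the explicitly flattened array.
flat_ref = jnp.linalg.vector_norm(jnp.ravel(x))

# keepdims=True keeps every reduced axis as a size-1 dimension: shape (1, 1)
flat_kept = jnp.linalg.vector_norm(x, keepdims=True)

# Per-row norms with different orders.
l2 = jnp.linalg.vector_norm(x, axis=1)                 # sqrt of sum of squares: [5., 0.]
l1 = jnp.linalg.vector_norm(x, axis=1, ord=1)          # sum of |x|: [7., 0.]
linf = jnp.linalg.vector_norm(x, axis=1, ord=jnp.inf)  # max of |x|: [4., 0.]

print(flat, flat_kept.shape, l2, l1, linf)
```

Note that `keepdims=True` with `axis=None` keeps the full rank of the input, which makes the result broadcast cleanly against `x` (for example, to normalize it).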
Returns:
array containing the norm of x.
Return type:
Array
See also
- jax.numpy.linalg.matrix_norm(): Norm of a matrix or stack of matrices.
- jax.numpy.linalg.norm(): More general matrix or vector norm.
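The sketch below illustrates how `vector_norm` relates to the two functions listed above, assuming the default `ord='fro'` of `jax.numpy.linalg.matrix_norm()`. For a 2-D input, the 2-norm of the flattened array coincides with the Frobenius norm, while other orders such as `ord=1` diverge because matrix norms are not computed elementwise. Variable names are illustrative.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 7.]])

# Row-wise vector norms: vector_norm and norm agree when a single axis is given.
rows_vn = jnp.linalg.vector_norm(x, axis=1)
rows_n = jnp.linalg.norm(x, axis=1)

# Over all elements (axis=None), the 2-norm of the flattened array equals the
# Frobenius norm, which is matrix_norm's default ord='fro'.
all_vn = jnp.linalg.vector_norm(x)
fro = jnp.linalg.matrix_norm(x)

# Other orders differ: ord=1 is the sum of |x| for the flattened vector,
# but the maximum absolute column sum for a matrix.
vec_l1 = jnp.linalg.vector_norm(x, ord=1)   # 1+2+3+4+5+7 = 22
mat_l1 = jnp.linalg.matrix_norm(x, ord=1)   # max(5, 7, 10) = 10
```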
Examples
Norm of a single vector:
>>> import jax.numpy as jnp
>>> x = jnp.array([1., 2., 3.])
>>> jnp.linalg.vector_norm(x)
Array(3.7416575, dtype=float32)
Norm of a batch of vectors:
>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 7.]])
>>> jnp.linalg.vector_norm(x, axis=1)
Array([3.7416575, 9.486833 ], dtype=float32)
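Because `vector_norm` is an ordinary JAX function, it should compose with transformations such as `jax.vmap`, `jax.jit`, and `jax.grad`. The following is a minimal sketch (variable names are illustrative) checking that mapping the single-vector norm over rows matches `axis=1`, and that the gradient of the 2-norm at a nonzero vector `v` is `v / ||v||`.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 7.]])

# vmap maps over axis 0, so each call sees one row and uses axis=None.
row_norms = jax.vmap(jnp.linalg.vector_norm)(x)

# Under jit, the keyword argument axis=1 is closed over by the lambda,
# so it stays a static Python value rather than a traced array.
jit_norms = jax.jit(lambda a: jnp.linalg.vector_norm(a, axis=1))(x)

# For a nonzero vector v, d||v||/dv = v / ||v||; here ||v|| = 5.
v = jnp.array([3., 4.])
g = jax.grad(jnp.linalg.vector_norm)(v)  # approximately [0.6, 0.8]
```

At the zero vector the 2-norm is not differentiable, so gradients evaluated there should not be relied on.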