jax.numpy.linalg.svd — JAX documentation

jax.numpy.linalg.svd

jax.numpy.linalg.svd(a, full_matrices=True, compute_uv=True, hermitian=False, subset_by_index=None)

Compute the singular value decomposition.

JAX implementation of numpy.linalg.svd(), implemented in terms of jax.lax.linalg.svd().
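Because jnp.linalg.svd mirrors numpy.linalg.svd(), one quick sanity check is to compare the singular values both libraries return for the same input. This is a minimal sketch; the tolerance is an assumption chosen for float32 precision. Only singular values are compared, because singular vectors are determined only up to sign (or phase, for complex input) and may legitimately differ between implementations.

```python
import numpy as np
import jax.numpy as jnp

x = np.array([[1., 2., 3.],
              [6., 5., 4.]], dtype=np.float32)

# Singular values only (compute_uv=False) from both libraries.
s_np = np.linalg.svd(x, compute_uv=False)
s_jnp = jnp.linalg.svd(jnp.asarray(x), compute_uv=False)

# Singular values are unique and sorted in descending order, so they can be
# compared elementwise; singular vectors may differ by sign and are skipped.
print(np.allclose(s_np, np.asarray(s_jnp), rtol=1e-4, atol=1e-5))
```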

The SVD of a matrix A is given by

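\[A = U\Sigma V^H\]

where \(U\) and \(V\) are unitary, \(V^H\) is the conjugate transpose of \(V\), and \(\Sigma\) is a rectangular diagonal matrix holding the non-negative singular values of \(A\) in descending order.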
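For an N×M input with full_matrices=True (the default), U is N×N, V^H is M×M, and Σ in the formula is a rectangular N×M matrix whose main diagonal holds the K = min(N, M) singular values. The sketch below builds that Σ explicitly from s and checks the factorization on a 2×3 example; the variable names are illustrative only.

```python
import jax.numpy as jnp

a = jnp.array([[1., 2., 3.],
               [6., 5., 4.]])          # N = 2, M = 3
n, m = a.shape

# With full_matrices=True (the default), u is N x N and vh is M x M.
u, s, vh = jnp.linalg.svd(a)

# Build the rectangular N x M Sigma: the K = min(N, M) singular values go on
# the main diagonal, everything else is zero.
k = min(n, m)
sigma = jnp.zeros((n, m), dtype=s.dtype).at[jnp.arange(k), jnp.arange(k)].set(s)

# A = U Sigma V^H; vh already holds V^H (the conjugate transpose of V).
a_rebuilt = u @ sigma @ vh
print(jnp.allclose(a_rebuilt, a, atol=1e-4))
```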

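Parameters:

a (ArrayLike): input array, of shape (..., N, M).

full_matrices (bool): if True (default), compute the full matrices, i.e. u and vh have shapes (..., N, N) and (..., M, M). If False, the shapes are (..., N, K) and (..., K, M).

compute_uv (bool): if True (default), return the full SVD (u, s, vh). If False, return only the singular values s.

hermitian (bool): if True, assume the matrix is Hermitian, which allows a more efficient implementation (default False).

subset_by_index (tuple[int, int] | None): optional 2-tuple [start, end] giving the range of indices of singular values to compute. For example, [n-2, n] computes the two largest singular values and their singular vectors. Only compatible with full_matrices=False.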

Returns:

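A tuple of arrays (u, s, vh) if compute_uv is True, otherwise the array s.

u: left singular vectors, of shape (..., N, N) if full_matrices is True, otherwise (..., N, K).

s: singular values in descending order, of shape (..., K).

vh: conjugate-transposed right singular vectors, of shape (..., M, M) if full_matrices is True, otherwise (..., K, M).

where K = min(N, M).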
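The shapes listed above can be checked directly on a stacked input. A minimal sketch, assuming a batch of four 5×3 matrices (so N = 5, M = 3, K = 3); the arange-based data is arbitrary and only the shapes matter here.

```python
import jax.numpy as jnp

# A batch of 4 matrices, each N x M = 5 x 3, so K = min(N, M) = 3.
batch = jnp.arange(4 * 5 * 3, dtype=jnp.float32).reshape(4, 5, 3)

# full_matrices=True (default): u is (..., N, N), vh is (..., M, M).
u_full, s_full, vh_full = jnp.linalg.svd(batch)

# full_matrices=False: u is (..., N, K), vh is (..., K, M).
u_red, s_red, vh_red = jnp.linalg.svd(batch, full_matrices=False)

print(u_full.shape, s_full.shape, vh_full.shape)  # (4, 5, 5) (4, 3) (4, 3, 3)
print(u_red.shape, s_red.shape, vh_red.shape)     # (4, 5, 3) (4, 3) (4, 3, 3)
```

Leading dimensions are treated as batch dimensions; only the last two axes are factorized.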

Return type:

Array | SVDResult
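The two branches of the return type correspond to compute_uv: the default returns a tuple-like result holding (u, s, vh), while compute_uv=False returns a single Array of singular values. The sketch below also passes hermitian=True on a real symmetric (hence Hermitian) matrix and, as an independent check, compares against the absolute eigenvalues from jnp.linalg.eigvalsh, since for a Hermitian matrix the singular values are the absolute values of its eigenvalues. The test matrix is an assumption chosen so the expected values (3 and 1) are easy to verify by hand.

```python
import jax
import jax.numpy as jnp

# A real symmetric matrix is Hermitian; its eigenvalues are 3 and -1.
h = jnp.array([[1., 2.],
               [2., 1.]])

# compute_uv=True (default): a tuple-like result holding (u, s, vh).
full = jnp.linalg.svd(h)

# compute_uv=False: a single Array holding only the singular values.
s_default = jnp.linalg.svd(h, compute_uv=False)

# hermitian=True lets the implementation assume h is Hermitian.
s_herm = jnp.linalg.svd(h, compute_uv=False, hermitian=True)

# Independent check: for a Hermitian matrix the singular values are the
# absolute values of its eigenvalues, sorted in descending order.
s_eig = jnp.sort(jnp.abs(jnp.linalg.eigvalsh(h)))[::-1]

print(isinstance(full, tuple), isinstance(s_default, jax.Array))
print(s_default, s_herm, s_eig)
```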

See also

jax.scipy.linalg.svd(): SciPy-style SVD API.

jax.lax.linalg.svd(): XLA-style SVD API.

Examples

Consider the SVD of a small real-valued array:

>>> import jax.numpy as jnp
>>> x = jnp.array([[1., 2., 3.],
...                [6., 5., 4.]])
>>> u, s, vt = jnp.linalg.svd(x, full_matrices=False)
>>> s
Array([9.361919 , 1.8315067], dtype=float32)

The left and right singular vectors are the columns of u and of v = vt.T, respectively. Each set is orthonormal, which can be demonstrated by checking that u.T @ u and v.T @ v are close to the identity matrix:

>>> jnp.allclose(u.T @ u, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)
>>> v = vt.T
>>> jnp.allclose(v.T @ v, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)
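With the default full_matrices=True, the same 2×3 input yields a 3×3 vh, so v is a complete orthonormal basis and is orthogonal in both directions. A sketch under that assumption; the tolerances are chosen loosely for float32. Because only K = 2 singular values exist, the third right singular vector has no partner in s, and x maps it to (numerically) zero.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [6., 5., 4.]])

# Default full_matrices=True: u is 2 x 2 and vt is 3 x 3.
u, s, vt = jnp.linalg.svd(x)
v = vt.T

# Both factors are square, so v.T @ v and v @ v.T are both the identity.
print(jnp.allclose(u.T @ u, jnp.eye(2), atol=1e-4))
print(jnp.allclose(v.T @ v, jnp.eye(3), atol=1e-4))
print(jnp.allclose(v @ v.T, jnp.eye(3), atol=1e-4))

# The third column of v has no matching singular value, so x sends it to ~0.
print(jnp.allclose(x @ v[:, 2], 0.0, atol=1e-4))
```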

Given the SVD, x can be reconstructed via matrix multiplication:

>>> x_reconstructed = u @ jnp.diag(s) @ vt
>>> jnp.allclose(x_reconstructed, x)
Array(True, dtype=bool)
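The jnp.diag(s) reconstruction relies on s being 1-D. For a stacked input of shape (..., N, M), s is at least 2-D, and jnp.diag would then extract a diagonal rather than build a diagonal matrix. One alternative, sketched here assuming full_matrices=False, is to scale the columns of u by broadcasting, which equals u @ diag(s) for every matrix in the batch. The three example matrices are arbitrary.

```python
import jax.numpy as jnp

# A batch of three 2 x 3 matrices, shape (3, 2, 3).
xb = jnp.stack([
    jnp.array([[1., 2., 3.], [6., 5., 4.]]),
    jnp.array([[0., 1., 0.], [2., 0., 1.]]),
    jnp.array([[3., 0., 0.], [0., 4., 0.]]),
])

# Shapes: u (3, 2, 2), s (3, 2), vt (3, 2, 3).
u, s, vt = jnp.linalg.svd(xb, full_matrices=False)

# s[..., None, :] has shape (3, 1, 2), so column j of each u is scaled by
# s[..., j]; this matches u @ jnp.diag(s) applied matrix by matrix.
xb_rebuilt = (u * s[..., None, :]) @ vt

print(jnp.allclose(xb_rebuilt, xb, atol=1e-4))
```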