torch.lu — PyTorch 2.7 documentation

torch.lu(*args, **kwargs)

Computes the LU factorization of a matrix or batch of matrices A. Returns a tuple containing the LU factorization and pivots of A. Pivoting is done if pivot is set to True.
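The factorization and pivots returned here are a packed representation. A minimal sketch of how to check that they really encode A, assuming a random float64 batch as input and using torch.lu_unpack (a PyTorch function not described on this page) to expand the result into P, L and U with A = P @ L @ U:

```python
import torch

# Illustrative input (an assumption, not from the text): a batch of two random
# 3x3 matrices in float64 so the round-trip check can use a tight tolerance.
torch.manual_seed(0)
A = torch.randn(2, 3, 3, dtype=torch.float64)

# Packed LU factors plus row-pivot indices, as returned by torch.lu.
A_LU, pivots = torch.lu(A)

# Expand the packed result into a permutation matrix P, a unit
# lower-triangular L and an upper-triangular U such that A = P @ L @ U.
P, L, U = torch.lu_unpack(A_LU, pivots)
reconstructed = P @ L @ U

print(torch.allclose(reconstructed, A))  # True
```

The same unpacking should apply to the output of torch.linalg.lu_factor(), which uses the same packed layout.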

Warning

torch.lu() is deprecated in favor of torch.linalg.lu_factor() and torch.linalg.lu_factor_ex(). torch.lu() will be removed in a future PyTorch release.

LU, pivots = torch.lu(A, compute_pivots) should be replaced with

LU, pivots = torch.linalg.lu_factor(A, pivot=compute_pivots)

LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True) should be replaced with

LU, pivots, info = torch.linalg.lu_factor_ex(A, pivot=compute_pivots)
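A small side-by-side check of the migration described in the warning, as a sketch: the input matrix and seed are arbitrary assumptions, and compute_pivots is just the placeholder flag name from the warning. The flag is passed to the torch.linalg functions by keyword as pivot, which works whether or not that parameter is keyword-only.

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 4)
compute_pivots = True  # stands in for the flag named in the warning above

# Deprecated forms.
LU_old, pivots_old = torch.lu(A, compute_pivots)
LU_old_i, pivots_old_i, info_old = torch.lu(A, compute_pivots, get_infos=True)

# Suggested replacements; in torch.linalg the flag is the argument `pivot`.
LU_new, pivots_new = torch.linalg.lu_factor(A, pivot=compute_pivots)
LU_new_i, pivots_new_i, info_new = torch.linalg.lu_factor_ex(A, pivot=compute_pivots)

# Both pairs should describe the same factorization.
print(torch.allclose(LU_old, LU_new), torch.equal(pivots_old, pivots_new))
print(torch.equal(info_old, info_new))
```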

Warning

The gradients of this function will only be finite when A is full rank. This is because the LU decomposition is differentiable only at full-rank matrices. Furthermore, if A is close to not being full rank, the gradient will be numerically unstable, as it depends on the computation of $L^{-1}$ and $U^{-1}$.
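A minimal sketch of backpropagating through the factorization for a full-rank input. The helper lu_sum_grad and the choice of LU.sum() as a scalar loss are hypothetical, for illustration only; the input is assumed to be a random matrix shifted by 4*I, which is very likely to make it full rank and well conditioned. The condition number printed at the end is only a rough diagnostic for how close A is to rank deficiency; this sketch does not reproduce the instability itself.

```python
import torch

def lu_sum_grad(A: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: gradient of LU.sum() with respect to A."""
    A = A.detach().clone().requires_grad_(True)
    A_LU, _ = torch.lu(A)  # pivots are integers and carry no gradient
    A_LU.sum().backward()
    return A.grad

torch.manual_seed(0)
# Assumed input: a random matrix shifted by 4*I to keep it away from singularity.
A_full = torch.randn(3, 3, dtype=torch.float64) + 4 * torch.eye(3, dtype=torch.float64)

grad = lu_sum_grad(A_full)
# A very large condition number would warn of the instability described above.
print(torch.linalg.cond(A_full).item())
print(torch.isfinite(grad).all().item())
```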

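Parameters

A (Tensor) – the matrix or batch of matrices to factor
pivot (bool, optional) – controls whether pivoting is done. Default: True
get_infos (bool, optional) – if set to True, also returns an info IntTensor. Default: False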

Returns

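A tuple of tensors containing the LU factorization of A, the pivots, and (only when get_infos is True) an info tensor in which a nonzero entry means the factorization of the corresponding matrix did not succeed.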

Return type

(Tensor, IntTensor, IntTensor (optional))
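To make the return type concrete, the sketch below inspects the shapes and dtypes of the three outputs for an assumed batch of two 3x3 matrices. The check that pivots lie in [1, 3] reflects that the pivots are 1-based row indices (LAPACK convention), which is consistent with the value 3 appearing for 3x3 matrices in the example below.

```python
import torch

torch.manual_seed(0)
A = torch.randn(2, 3, 3)  # assumed input: a batch of two 3x3 matrices

A_LU, pivots, infos = torch.lu(A, get_infos=True)

print(A_LU.shape)                  # torch.Size([2, 3, 3]), same as A
print(pivots.shape, pivots.dtype)  # torch.Size([2, 3]) torch.int32
print(infos.shape, infos.dtype)    # torch.Size([2]) torch.int32
# Pivot entries are 1-based row indices, so they lie in [1, 3] here.
print(int(pivots.min()) >= 1 and int(pivots.max()) <= 3)  # True
```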

Example:

>>> import torch
>>> A = torch.randn(2, 3, 3)
>>> A_LU, pivots = torch.lu(A)
>>> A_LU
tensor([[[ 1.3506,  2.5558, -0.0816],
         [ 0.1684,  1.1551,  0.1940],
         [ 0.1193,  0.6189, -0.5497]],

        [[ 0.4526,  1.2526, -0.3285],
         [-0.7988,  0.7175, -0.9701],
         [ 0.2634, -0.9255, -0.3459]]])
>>> pivots
tensor([[ 3,  3,  3],
        [ 3,  3,  3]], dtype=torch.int32)
>>> A_LU, pivots, info = torch.lu(A, get_infos=True)
>>> if info.nonzero().size(0) == 0:
...     print('LU factorization succeeded for all samples!')
LU factorization succeeded for all samples!
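The example above only shows the success path. As a sketch of the failure path, assuming a hypothetical batch that pairs an invertible matrix with an all-zero (singular) one, the info tensor should flag only the singular sample; with get_infos=True the status is reported through info rather than checked by the function itself.

```python
import torch

# Hypothetical batch: an invertible matrix and an all-zero (singular) one.
A = torch.stack([2.0 * torch.eye(3), torch.zeros(3, 3)])

A_LU, pivots, info = torch.lu(A, get_infos=True)

# Indices of batch entries whose factorization reported a problem.
failed = info.nonzero().flatten()
if failed.numel() == 0:
    print('LU factorization succeeded for all samples!')
else:
    print('LU factorization failed for samples:', failed.tolist())
```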