torch.lu_solve - PyTorch 2.7 documentation

torch.lu_solve(b, LU_data, LU_pivots, *, out=None) → Tensor

Returns the solution x of the linear system Ax = b, computed from the partially pivoted LU factorization of A returned by torch.linalg.lu_factor().

This function supports float, double, cfloat and cdouble dtypes for input.
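A short sketch with a complex (cdouble) system. It assumes the default torch.allclose tolerances are adequate for a small, randomly generated double-precision problem. The solution keeps the complex dtype of the inputs.

```python
import torch

torch.manual_seed(0)
# torch.randn accepts complex dtypes; real and imaginary parts are drawn independently.
A = torch.randn(4, 4, dtype=torch.cdouble)
b = torch.randn(4, 2, dtype=torch.cdouble)   # two right-hand sides, no batch dimension

LU, pivots = torch.linalg.lu_factor(A)
x = torch.lu_solve(b, LU, pivots)

print(x.dtype)                    # torch.complex128 (torch.cdouble)
print(torch.allclose(A @ x, b))   # residual check
```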

Warning

torch.lu_solve() is deprecated in favor of torch.linalg.lu_solve() and will be removed in a future PyTorch release. Note that the argument order differs, with the right-hand side moving from first to last:

X = torch.lu_solve(B, LU, pivots)

should be replaced with

X = torch.linalg.lu_solve(LU, pivots, B)
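A minimal check of this migration, using only the two documented calls: both forms return the same solution, and only the position of the right-hand side changes. The deprecated call is wrapped in a warnings filter in case it emits a deprecation warning.

```python
import warnings
import torch

torch.manual_seed(0)
A = torch.randn(3, 3, dtype=torch.float64)
B = torch.randn(3, 2, dtype=torch.float64)
LU, pivots = torch.linalg.lu_factor(A)

# Deprecated form: right-hand side first. Suppress any deprecation warning it may emit.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    X_old = torch.lu_solve(B, LU, pivots)

# Replacement form: right-hand side last.
X_new = torch.linalg.lu_solve(LU, pivots, B)

print(torch.allclose(X_old, X_new))
```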

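Parameters

b (Tensor) – the RHS tensor of size (*, m, k), where * is zero or more batch dimensions.
LU_data (Tensor) – the pivoted LU factorization of A from torch.linalg.lu_factor() of size (*, m, m), where * is zero or more batch dimensions.
LU_pivots (IntTensor) – the pivots of the LU factorization from torch.linalg.lu_factor() of size (*, m), where * is zero or more batch dimensions. The batch dimensions of LU_pivots must be equal to the batch dimensions of LU_data.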

Keyword Arguments

out (Tensor, optional) – the output tensor.
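A sketch of the batched shapes and the out keyword. It assumes the (*, m, k) right-hand side, (*, m, m) factors, and (*, m) pivots layout produced by torch.linalg.lu_factor(). The example solves a batch of 4 systems of size 5x5, each with 3 right-hand sides, into a preallocated tensor.

```python
import torch

torch.manual_seed(0)
batch, m, k = 4, 5, 3
A = torch.randn(batch, m, m, dtype=torch.float64)
b = torch.randn(batch, m, k, dtype=torch.float64)   # (*, m, k): k right-hand sides per system

LU, pivots = torch.linalg.lu_factor(A)
print(LU.shape)       # torch.Size([4, 5, 5]), i.e. (*, m, m)
print(pivots.shape)   # torch.Size([4, 5]), i.e. (*, m)
print(pivots.dtype)   # torch.int32

# Preallocate the result with the solution's shape and write into it via out=.
x = torch.empty(batch, m, k, dtype=torch.float64)
torch.lu_solve(b, LU, pivots, out=x)
print(torch.allclose(A @ x, b))   # residual check for every system in the batch
```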

Example:

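>>> import torch
>>> A = torch.randn(2, 3, 3)
>>> b = torch.randn(2, 3, 1)
>>> LU, pivots = torch.linalg.lu_factor(A)
>>> x = torch.lu_solve(b, LU, pivots)
>>> torch.dist(A @ x, b)  # residual norm; small, exact value varies with the random A and b
tensor(1.00000e-07 * 2.8312)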