torch.lu_solve — PyTorch 2.7 documentation
torch.lu_solve(b, LU_data, LU_pivots, *, out=None) → Tensor
Returns the LU solve of the linear system Ax = b using the partially pivoted LU factorization of A from lu_factor().
This function supports float, double, cfloat and cdouble dtypes for input.
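As a minimal sketch of the dtype support listed above, the following solves a small complex double-precision system. It assumes torch.linalg.lu_factor() and torch.linalg.lu_solve() (the replacement named in the deprecation warning on this page) accept the same dtypes; the matrix size and seed are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

# One 3x3 complex double (cdouble) matrix and one right-hand side column.
A = torch.randn(3, 3, dtype=torch.cdouble)
b = torch.randn(3, 1, dtype=torch.cdouble)

LU, pivots = torch.linalg.lu_factor(A)
x = torch.linalg.lu_solve(LU, pivots, b)

print(x.dtype)                              # solution dtype follows the inputs
print(torch.allclose(A @ x, b, atol=1e-6))  # residual check
```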
Warning
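torch.lu_solve() is deprecated in favor of torch.linalg.lu_solve(). torch.lu_solve() will be removed in a future PyTorch release.
X = torch.lu_solve(B, LU, pivots)
should be replaced with
X = torch.linalg.lu_solve(LU, pivots, B)
Note that the argument order changes: the right-hand side B moves from the first position to the last.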
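The migration can be sketched as follows. This is an illustration, not an official migration recipe: the deprecated call appears only in a comment so the snippet keeps working after torch.lu_solve() is removed, and the cross-check against torch.linalg.solve() is an added sanity test rather than part of the original documentation.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
A = torch.randn(2, 3, 3, dtype=torch.double)
B = torch.randn(2, 3, 1, dtype=torch.double)

LU, pivots = torch.linalg.lu_factor(A)

# Deprecated form, right-hand side first:
#   X = torch.lu_solve(B, LU, pivots)
# Replacement, right-hand side last:
X = torch.linalg.lu_solve(LU, pivots, B)

# Sanity check against a direct solve of the same batched systems.
X_ref = torch.linalg.solve(A, B)
print(torch.allclose(X, X_ref, atol=1e-6))
```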
Parameters
- b (Tensor) – the RHS tensor of size (*, m, k), where * is zero or more batch dimensions.
- LU_data (Tensor) – the pivoted LU factorization of A from lu_factor() of size (*, m, m), where * is zero or more batch dimensions.
- LU_pivots (IntTensor) – the pivots of the LU factorization from lu_factor() of size (*, m), where * is zero or more batch dimensions. The batch dimensions of LU_pivots must be equal to the batch dimensions of LU_data.
Keyword Arguments
- out (Tensor, optional) – the output tensor.
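The shape requirements above can be checked with a short sketch. The values of batch, m, and k are arbitrary, and the factorization is produced with torch.linalg.lu_factor(), which is assumed here to be the lu_factor() the parameter descriptions refer to.

```python
import torch

batch, m, k = 4, 3, 2  # illustrative sizes, not prescribed by the docs
A = torch.randn(batch, m, m)
b = torch.randn(batch, m, k)

LU_data, LU_pivots = torch.linalg.lu_factor(A)

print(LU_data.shape)    # (*, m, m), here with a single batch dimension
print(LU_pivots.shape)  # (*, m), same batch dimensions as LU_data
print(LU_pivots.dtype)  # an integer dtype

x = torch.linalg.lu_solve(LU_data, LU_pivots, b)
print(x.shape)          # (*, m, k), matching b
```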
Example:
>>> A = torch.randn(2, 3, 3)
>>> b = torch.randn(2, 3, 1)
>>> LU, pivots = torch.linalg.lu_factor(A)
>>> x = torch.lu_solve(b, LU, pivots)
>>> torch.dist(A @ x, b)
tensor(1.00000e-07 * 2.8312)
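A common reason to keep the factorization separate from the solve is to reuse it across several right-hand sides. The sketch below factors A once and solves three times. It uses torch.linalg.lu_solve() rather than the deprecated function, and the loop count and sizes are arbitrary.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
A = torch.randn(3, 3, dtype=torch.double)

# Factor once; the result can be reused for every solve below.
LU, pivots = torch.linalg.lu_factor(A)

residuals = []
for _ in range(3):
    b = torch.randn(3, 2, dtype=torch.double)  # new right-hand side, k = 2
    x = torch.linalg.lu_solve(LU, pivots, b)   # no refactorization of A
    residuals.append(torch.dist(A @ x, b).item())

print(residuals)  # each entry should be close to zero
```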