torch.sparse.sampled_addmm — PyTorch 2.7 documentation

torch.sparse.sampled_addmm(input, mat1, mat2, *, beta=1., alpha=1., out=None) → Tensor

Performs a matrix multiplication of the dense matrices mat1 and mat2 at the locations specified by the sparsity pattern of input. The matrix input, scaled by beta, is added to the final result.

Mathematically this performs the following operation:

$$\text{out} = \alpha\ (\text{mat1} \mathbin{@} \text{mat2}) * \text{spy}(\text{input}) + \beta\ \text{input}$$

where $*$ denotes element-wise multiplication, $\text{spy}(\text{input})$ is the sparsity pattern matrix of input, and $\alpha$ and $\beta$ are the scaling factors. $\text{spy}(\text{input})$ has value 1 at the positions where input has non-zero values, and 0 elsewhere.
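To make the formula concrete, here is a minimal sketch of the same computation written with ordinary tensor operations on the CSR components of input (a sampled dense-dense matrix multiplication). It is an illustration, not PyTorch's kernel: the name sampled_addmm_reference is hypothetical, it assumes the unbatched 2-D case, and it treats the stored elements of input as the pattern, which matches spy(input) as long as input has no explicitly stored zeros.

```python
import torch


def sampled_addmm_reference(input, mat1, mat2, *, beta=1.0, alpha=1.0):
    """Hypothetical 2-D reference: alpha * (mat1 @ mat2) at input's pattern + beta * input."""
    crow = input.crow_indices()  # compressed row pointers, length m + 1
    col = input.col_indices()    # column index of each stored element, length nnz
    vals = input.values()        # stored values of input, length nnz
    m = input.shape[0]
    # Expand the compressed row pointers into one row index per stored element.
    rows = torch.repeat_interleave(torch.arange(m, device=crow.device), crow.diff())
    # For each stored element e: dot(row rows[e] of mat1, column col[e] of mat2).
    sampled = (mat1[rows] * mat2[:, col].T).sum(dim=-1)
    new_vals = alpha * sampled + beta * vals
    # The result reuses input's sparsity pattern (same crow/col indices).
    return torch.sparse_csr_tensor(crow, col, new_vals, size=input.shape)
```

Only nnz dot products of length k are evaluated, so the work scales with the number of stored elements of input rather than with the full m-by-n product.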

Note

input must be a sparse CSR tensor. mat1 and mat2 must be dense tensors.

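Parameters

input (Tensor): a sparse CSR matrix of shape (m, n) to be added and used to compute the sampled matrix multiplication
mat1 (Tensor): a dense matrix of shape (m, k) to be multiplied
mat2 (Tensor): a dense matrix of shape (k, n) to be multiplied

Keyword Arguments

beta (Number, optional): multiplier for input ($\beta$). Default: 1.
alpha (Number, optional): multiplier for mat1 @ mat2 ($\alpha$). Default: 1.
out (Tensor, optional): output tensor. Ignored if None. Default: None.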
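The layout requirements in the Note, together with the shape rule implied by mat1 @ mat2 (input is (m, n), mat1 is (m, k), mat2 is (k, n)), can be checked up front. The following is a hypothetical helper, check_sampled_addmm_inputs, written as a sketch for the unbatched 2-D case only; it is not part of PyTorch and does not reproduce PyTorch's own error messages.

```python
import torch


def check_sampled_addmm_inputs(input: torch.Tensor, mat1: torch.Tensor, mat2: torch.Tensor) -> None:
    """Hypothetical pre-flight check for the unbatched 2-D case."""
    # Layout rules: input is sparse CSR, mat1 and mat2 are dense (strided).
    if input.layout != torch.sparse_csr:
        raise TypeError(f"input must be a sparse CSR tensor, got {input.layout}")
    for name, t in (("mat1", mat1), ("mat2", mat2)):
        if t.layout != torch.strided:
            raise TypeError(f"{name} must be a dense tensor, got {t.layout}")
    # Shape rules: input is (m, n), mat1 is (m, k), mat2 is (k, n).
    if input.dim() != 2 or mat1.dim() != 2 or mat2.dim() != 2:
        raise ValueError("this sketch only handles 2-D inputs")
    m, n = input.shape
    (m1, k1), (k2, n2) = mat1.shape, mat2.shape
    if (m1, n2) != (m, n) or k1 != k2:
        raise ValueError(
            f"incompatible shapes: input {tuple(input.shape)}, "
            f"mat1 {tuple(mat1.shape)}, mat2 {tuple(mat2.shape)}"
        )
```

The sketch deliberately rejects tensors with more than two dimensions, since it only encodes the 2-D shapes (m, n), (m, k) and (k, n).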

Examples:

>>> import torch
>>> input = torch.eye(3, device='cuda').to_sparse_csr()
>>> mat1 = torch.randn(3, 5, device='cuda')
>>> mat2 = torch.randn(5, 3, device='cuda')
>>> torch.sparse.sampled_addmm(input, mat1, mat2)
tensor(crow_indices=tensor([0, 1, 2, 3]),
       col_indices=tensor([0, 1, 2]),
       values=tensor([ 0.2847, -0.7805, -0.1900]), device='cuda:0',
       size=(3, 3), nnz=3, layout=torch.sparse_csr)
>>> torch.sparse.sampled_addmm(input, mat1, mat2).to_dense()
tensor([[ 0.2847,  0.0000,  0.0000],
        [ 0.0000, -0.7805,  0.0000],
        [ 0.0000,  0.0000, -0.1900]], device='cuda:0')
>>> torch.sparse.sampled_addmm(input, mat1, mat2, beta=0.5, alpha=0.5)
tensor(crow_indices=tensor([0, 1, 2, 3]),
       col_indices=tensor([0, 1, 2]),
       values=tensor([ 0.1423, -0.3903, -0.0950]), device='cuda:0',
       size=(3, 3), nnz=3, layout=torch.sparse_csr)
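Because input in the example is the 3 by 3 identity, spy(input) keeps only the diagonal and input contributes 1 at each stored position, so each stored value is $\alpha\,(\text{mat1} \mathbin{@} \text{mat2})_{ii} + \beta$. A worked check using the printed (rounded) values of the first diagonal entry:

```latex
\begin{aligned}
\text{out}_{ii} &= \alpha\,(\text{mat1} \mathbin{@} \text{mat2})_{ii} + \beta \cdot 1 \\
\alpha = \beta = 1:\quad (\text{mat1} \mathbin{@} \text{mat2})_{00} &= 0.2847 - 1 = -0.7153 \\
\alpha = \beta = 0.5:\quad \text{out}_{00} &= 0.5\,(-0.7153) + 0.5 = 0.14235 \approx 0.1423
\end{aligned}
```

More generally, with $\alpha = \beta = 0.5$ every stored value is $0.5\,[(\text{mat1} \mathbin{@} \text{mat2})_{ii} + 1]$, exactly half of the $\alpha = \beta = 1$ value, which is why the last call prints 0.1423, -0.3903, -0.0950 against 0.2847, -0.7805, -0.1900 (up to rounding of the printed digits).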