torch.Tensor.sparse_mask
Returns a new sparse tensor with values from a strided tensor self, filtered by the indices of the sparse tensor mask. The values of the sparse tensor mask are ignored. The self and mask tensors must have the same shape.

>>> import torch
>>> nse = 5
>>> dims = (5, 5, 2, 2)
>>> I = torch.cat([torch.randint(0, dims[0], size=(nse,)),
...                torch.randint(0, dims[1], size=(nse,))], 0).reshape(2, nse)
>>> V = torch.randn(nse, dims[2], dims[3])
>>> S = torch.sparse_coo_tensor(I, V, dims).coalesce()
>>> D = torch.randn(dims)
>>> D.sparse_mask(S)
tensor(indices=tensor([[0, 0, 0, 2],
                       [0, 1, 4, 3]]),
       values=tensor([[[ 1.6550,  0.2397],
                       [-0.1611, -0.0779]],

                      [[ 0.2326, -1.0558],
                       [ 1.4711,  1.9678]],

                      [[-0.5138, -0.0411],
                       [ 1.9417,  0.5158]],

                      [[ 0.0793,  0.0036],
                       [-0.2569, -0.1055]]]),
       size=(5, 5, 2, 2), nnz=4, layout=torch.sparse_coo)
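
Because the example above uses random data, it can be hard to see which values are selected. The following is a minimal deterministic sketch, not part of the original documentation, that shows the same behavior on a small 2-D tensor. The tensor contents and the placeholder mask value 100.0 are illustrative assumptions; the result is coalesced explicitly so that indices() and values() can be read safely.

```python
import torch

# Dense (strided) source tensor with easily recognizable values 0..11.
D = torch.arange(12, dtype=torch.float32).reshape(3, 4)

# Sparse COO mask. Only its indices matter; its values (all 100.0 here,
# an arbitrary placeholder) are ignored by sparse_mask.
indices = torch.tensor([[0, 1, 2],
                        [1, 3, 0]])
mask = torch.sparse_coo_tensor(indices, torch.full((3,), 100.0), D.shape).coalesce()

# Keep only the entries of D located at the mask's indices.
out = D.sparse_mask(mask).coalesce()

print(out.indices())   # same positions as the mask: (0, 1), (1, 3), (2, 0)
print(out.values())    # values taken from D at those positions, not from the mask
print(out.to_dense())  # D at the masked positions, zero everywhere else
```

Here the selected values come from D[0, 1], D[1, 3], and D[2, 0], which confirms that the mask contributes positions only.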