torch.Tensor.sparse_mask

Returns a new sparse tensor that holds the values of the strided (dense) tensor self at the indices of the sparse tensor mask. Only the indices of mask are used; its values are ignored. self and mask must have the same shape.
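A smaller, deterministic sketch of the same behavior may help. It assumes a 2-D float32 tensor and a hand-written COO mask; the names dense, mask and result are illustrative only. The mask's values (-99.0, 123.0, 0.0) are deliberately arbitrary to show that they have no effect: each stored value of the result comes from dense at the corresponding mask index. coalesce() is called on the result only so that indices() can be used safely.

```python
import torch

# Dense (strided) source tensor: 0..11 laid out as a 3x4 matrix.
dense = torch.arange(12, dtype=torch.float32).reshape(3, 4)

# Sparse COO mask with entries at (0, 1), (2, 3) and (1, 0).
# Its values are arbitrary; sparse_mask only looks at the indices.
indices = torch.tensor([[0, 2, 1],
                        [1, 3, 0]])
mask_values = torch.tensor([-99.0, 123.0, 0.0])
mask = torch.sparse_coo_tensor(indices, mask_values, dense.shape).coalesce()

# Keep only the entries of `dense` at the mask's indices.
result = dense.sparse_mask(mask).coalesce()

# Coalesced indices are sorted: (0, 1), (1, 0), (2, 3),
# so the values are dense[0, 1], dense[1, 0], dense[2, 3], i.e. 1, 4, 11.
print(result.indices())
print(result.values())
```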

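Example:

```python
>>> import torch
>>> nse = 5
>>> dims = (5, 5, 2, 2)  # 2 sparse dims followed by 2 dense dims
>>> I = torch.cat([torch.randint(0, dims[0], size=(nse,)),
...                torch.randint(0, dims[1], size=(nse,))], 0).reshape(2, nse)
>>> V = torch.randn(nse, dims[2], dims[3])
>>> S = torch.sparse_coo_tensor(I, V, dims).coalesce()
>>> D = torch.randn(dims)
>>> D.sparse_mask(S)
tensor(indices=tensor([[0, 0, 0, 2],
                       [0, 1, 4, 3]]),
       values=tensor([[[ 1.6550,  0.2397],
                       [-0.1611, -0.0779]],

                      [[ 0.2326, -1.0558],
                       [ 1.4711,  1.9678]],

                      [[-0.5138, -0.0411],
                       [ 1.9417,  0.5158]],

                      [[ 0.0793,  0.0036],
                       [-0.2569, -0.1055]]]),
       size=(5, 5, 2, 2), nnz=4, layout=torch.sparse_coo)
```

The indices and values are random, so the output differs between runs. Here nnz=4 rather than nse=5 because randint produced a repeated index pair, which coalesce() merged; the result always has exactly the indices of S, and each value is the 2 x 2 block of D at that index.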
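One way to check that description is to rerun the hybrid example with a fixed seed (an addition; the original does not seed, so the numbers will not match the output shown) and compare the result with plain advanced indexing. D[tuple(S.indices())] gathers the dense 2 x 2 block of D at each sparse index, which should equal the result's values. I is built with torch.stack here, which is equivalent to the cat/reshape above; the names R, same_indices and same_values are illustrative.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is repeatable; any seed works

nse = 5
dims = (5, 5, 2, 2)  # 2 sparse dims followed by 2 dense dims
I = torch.stack([torch.randint(0, dims[0], (nse,)),
                 torch.randint(0, dims[1], (nse,))])
V = torch.randn(nse, dims[2], dims[3])
S = torch.sparse_coo_tensor(I, V, dims).coalesce()
D = torch.randn(dims)

# coalesce() only to make indices() safe to call on the result.
R = D.sparse_mask(S).coalesce()

# The result keeps exactly the mask's indices ...
same_indices = torch.equal(R.indices(), S.indices())
# ... and each value is the dense slice D[i, j] at that index.
expected_values = D[tuple(S.indices())]
same_values = torch.equal(R.values(), expected_values)

print(R.sparse_dim(), R.dense_dim())  # 2 2
print(same_indices, same_values)      # True True
```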