torch.inner
Computes the dot product for 1D tensors. For higher dimensions, sums the product of elements from input and other along their last dimension.
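As an illustrative sketch (not part of the reference itself), the description can be checked by hand: for 1D tensors the result should agree with torch.dot, and for higher-dimensional tensors it should agree with an explicit broadcast, multiply, and sum over the last dimension. The variable names and the fixed seed are assumptions chosen for this example.

```python
import torch

# 1D case: the inner product is the ordinary dot product.
x = torch.tensor([1, 2, 3])
y = torch.tensor([0, 2, 1])
dot_result = torch.inner(x, y)  # 1*0 + 2*2 + 3*1
same_as_dot = torch.equal(dot_result, torch.dot(x, y))

# Higher-dimensional case: multiply matching last-dimension elements, then sum.
torch.manual_seed(0)  # fixed seed so the random inputs are reproducible
a = torch.randn(2, 3)
b = torch.randn(2, 4, 3)

# View a as (2, 1, 1, 3) and b as (1, 2, 4, 3) so they broadcast to
# (2, 2, 4, 3), then reduce over the shared last axis.
manual = (a[:, None, None, :] * b[None, :, :, :]).sum(dim=-1)
same_as_manual = torch.allclose(torch.inner(a, b), manual, atol=1e-6)
```

The result has the shape of input without its last dimension followed by the shape of other without its last dimension, here (2, 2, 4).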
Note
If either input or other is a scalar, the result is equivalent to torch.mul(input, other).
If both input and other are non-scalars, the size of their last dimension must match and the result is equivalent to torch.tensordot(input, other, dims=([-1], [-1])).
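The two equivalences in this note can be verified numerically. The following is a minimal sketch under assumed inputs (the shapes, seed, and variable names are illustrative only); it compares torch.inner against torch.tensordot for non-scalar inputs and against torch.mul for a 0-dimensional tensor.

```python
import torch

torch.manual_seed(0)  # fixed seed so the random inputs are reproducible

a = torch.randn(2, 3)
b = torch.randn(2, 4, 3)  # last dimension (3) matches a's last dimension

# Non-scalar case: contract the last dimension of each tensor.
inner_ab = torch.inner(a, b)
tensordot_ab = torch.tensordot(a, b, dims=([-1], [-1]))
shape_ok = inner_ab.shape == torch.Size([2, 2, 4])
matches_tensordot = torch.allclose(inner_ab, tensordot_ab, atol=1e-6)

# Scalar case: a 0-dimensional tensor reduces to elementwise multiplication.
s = torch.tensor(2.0)
matches_mul = torch.equal(torch.inner(a, s), torch.mul(a, s))
```

If the last dimensions differ (for example, shapes (2, 3) and (2, 4)), torch.inner is expected to raise an error rather than broadcast.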
Dot product
>>> torch.inner(torch.tensor([1, 2, 3]), torch.tensor([0, 2, 1]))
tensor(7)
Multidimensional input tensors
>>> a = torch.randn(2, 3)
>>> a
tensor([[0.8173, 1.0874, 1.1784],
        [0.3279, 0.1234, 2.7894]])
>>> b = torch.randn(2, 4, 3)
>>> b
tensor([[[-0.4682, -0.7159,  0.1506],
         [ 0.4034, -0.3657,  1.0387],
         [ 0.9892, -0.6684,  0.1774],
         [ 0.9482,  1.3261,  0.3917]],

        [[ 0.4537,  0.7493,  1.1724],
         [ 0.2291,  0.5749, -0.2267],
         [-0.7920,  0.3607, -0.3701],
         [ 1.3666, -0.5850, -1.7242]]])
>>> torch.inner(a, b)
tensor([[[-0.9837,  1.1560,  0.2907,  2.6785],
         [ 2.5671,  0.5452, -0.6912, -1.5509]],

        [[ 0.1782,  2.9843,  0.7366,  1.5672],
         [ 3.5115, -0.4864, -1.2476, -4.4337]]])
Scalar input
>>> torch.inner(a, torch.tensor(2))
tensor([[1.6347, 2.1748, 2.3567],
        [0.6558, 0.2469, 5.5787]])