torch.inner — PyTorch 2.7 documentation
torch.inner(input, other, *, out=None) → Tensor
Computes the dot product for 1D tensors. For higher dimensions, sums the product of elements from input and other along their last dimension.
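To make "sums the product of elements along their last dimension" concrete, here is a minimal sketch that compares torch.inner against an explicit index formula written with torch.einsum. The shapes, the seed, and the variable names are illustrative choices, not part of the API.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only so the run is repeatable

a = torch.randn(2, 3)      # last dimension has size 3
b = torch.randn(2, 4, 3)   # last dimension has size 3 as well

result = torch.inner(a, b)

# Explicit form: result[i, j, k] = sum over l of a[i, l] * b[j, k, l]
explicit = torch.einsum("il,jkl->ijk", a, b)

# For 1D tensors the same operation reduces to the ordinary dot product.
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([0.0, 2.0, 1.0])
dot_matches = torch.allclose(torch.inner(x, y), torch.dot(x, y))
```

In this sketch the last dimension of each operand is contracted away, and the remaining dimensions of a are followed by the remaining dimensions of b, so result has shape (2, 2, 4).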
Note
If either input or other is a scalar, the result is equivalent to torch.mul(input, other).
If both input and other are non-scalars, the size of their last dimension must match and the result is equivalent to torch.tensordot(input, other, dims=([-1], [-1])).
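The two equivalences in this note can be checked directly. The following is a small sketch using integer-valued float tensors, so the comparisons are exact; the tensor contents and variable names are arbitrary. It also assumes that a last-dimension mismatch surfaces as a RuntimeError, which is how PyTorch usually reports shape errors.

```python
import torch

a = torch.arange(6, dtype=torch.float32).reshape(2, 3)
b = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)
s = torch.tensor(2.0)  # 0-dimensional (scalar) tensor

# Scalar case: behaves like elementwise multiplication.
scalar_case = torch.equal(torch.inner(a, s), torch.mul(a, s))

# Non-scalar case: contraction over the last dimension of both operands.
tensordot_case = torch.allclose(
    torch.inner(a, b),
    torch.tensordot(a, b, dims=([-1], [-1])),
)

# Mismatched last dimensions (3 vs 4) are rejected.
try:
    torch.inner(torch.ones(2, 3), torch.ones(2, 4))
    mismatch_rejected = False
except RuntimeError:
    mismatch_rejected = True
```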
Parameters
input (Tensor) – First input tensor.
other (Tensor) – Second input tensor.
Keyword Arguments
out (Tensor, optional) – Optional output tensor to write result into. The output shape is input.shape[:-1] + other.shape[:-1].
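As an illustration of the stated output shape, the sketch below preallocates an out tensor of shape input.shape[:-1] + other.shape[:-1] and lets torch.inner write into it. The shapes and the name expected_shape are chosen only for this example.

```python
import torch

a = torch.ones(2, 3)
b = torch.ones(5, 4, 3)

# Drop the last dimension of each operand and concatenate what remains.
expected_shape = tuple(a.shape[:-1]) + tuple(b.shape[:-1])  # (2, 5, 4)

# Preallocate the destination and write the result into it.
out = torch.empty(expected_shape)
torch.inner(a, b, out=out)
```

Because both inputs are filled with ones and the contracted dimension has size 3, every entry of out is 3.0 here.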
Example:
# Dot product
>>> torch.inner(torch.tensor([1, 2, 3]), torch.tensor([0, 2, 1]))
tensor(7)
# Multidimensional input tensors
>>> a = torch.randn(2, 3)
>>> a
tensor([[0.8173, 1.0874, 1.1784],
        [0.3279, 0.1234, 2.7894]])
>>> b = torch.randn(2, 4, 3)
>>> b
tensor([[[-0.4682, -0.7159,  0.1506],
         [ 0.4034, -0.3657,  1.0387],
         [ 0.9892, -0.6684,  0.1774],
         [ 0.9482,  1.3261,  0.3917]],

        [[ 0.4537,  0.7493,  1.1724],
         [ 0.2291,  0.5749, -0.2267],
         [-0.7920,  0.3607, -0.3701],
         [ 1.3666, -0.5850, -1.7242]]])
>>> torch.inner(a, b)
tensor([[[-0.9837,  1.1560,  0.2907,  2.6785],
         [ 2.5671,  0.5452, -0.6912, -1.5509]],

        [[ 0.1782,  2.9843,  0.7366,  1.5672],
         [ 3.5115, -0.4864, -1.2476, -4.4337]]])
# Scalar input
>>> torch.inner(a, torch.tensor(2))
tensor([[1.6347, 2.1748, 2.3567],
        [0.6558, 0.2469, 5.5787]])