torch.tensordot — PyTorch 2.7 documentation
torch.tensordot(a, b, dims=2, out=None)
Returns a contraction of a and b over multiple dimensions.
tensordot implements a generalized matrix product.
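As a minimal sketch of what "generalized matrix product" means here, the snippet below compares tensordot with two familiar special cases on small random matrices. The variable names and the tolerance are illustrative choices, not part of the API.

```python
import torch

torch.manual_seed(0)
A = torch.randn(3, 4)
B = torch.randn(4, 5)
C = torch.randn(3, 4)

# Contracting one dimension (the last of A with the first of B)
# reproduces the ordinary matrix product.
one_dim = torch.tensordot(A, B, dims=1)
same_as_matmul = torch.allclose(one_dim, A @ B, atol=1e-5)

# Contracting both dimensions of two equally shaped matrices leaves
# a 0-dimensional tensor: the sum of their elementwise products.
full = torch.tensordot(A, C, dims=2)
same_as_inner = torch.allclose(full, (A * C).sum(), atol=1e-5)
```

With dims=1 the result has shape (3, 5). With dims=2 every dimension is contracted, so the result is a scalar tensor.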
Parameters
- a (Tensor) – Left tensor to contract
- b (Tensor) – Right tensor to contract
- dims (int or Tuple[List[int], List[int]] or List[List[int]] containing two lists or Tensor) – number of dimensions to contract or explicit lists of dimensions for a and b respectively
When called with a non-negative integer argument dims = d, and the number of dimensions of a and b is m and n, respectively, tensordot() computes
$$r_{i_0,...,i_{m-d}, i_d,...,i_n} = \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \times b_{k_0,...,k_{d-1}, i_d,...,i_n}.$$
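The integer form can be checked numerically. The sketch below, with shapes chosen only for illustration, compares dims=2 against an equivalent torch.einsum expression and against an explicit loop over the output indices.

```python
import torch

torch.manual_seed(0)
a = torch.randn(3, 4, 5)   # m = 3 dimensions
b = torch.randn(4, 5, 6)   # n = 3 dimensions

# d = 2: contract the last two dimensions of a with the first two of b.
r = torch.tensordot(a, b, dims=2)

# The same contraction written with explicit index labels.
ref = torch.einsum("ikl,klj->ij", a, b)

# The sum from the formula, written out one output element at a time.
manual = torch.zeros(3, 6)
for i in range(3):
    for j in range(6):
        manual[i, j] = (a[i] * b[:, :, j]).sum()

matches_einsum = torch.allclose(r, ref, atol=1e-5)
matches_manual = torch.allclose(r, manual, atol=1e-5)
```

The result keeps the one uncontracted dimension of a followed by the one uncontracted dimension of b, giving shape (3, 6).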
When called with dims of the list form, the given dimensions will be contracted in place of the last d of a and the first d of b. The sizes in these dimensions must match, but tensordot() will deal with broadcasted dimensions.
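One way to read the list form is as a reordering followed by the integer form. The sketch below illustrates that equivalence on small arange tensors. It is not a description of how tensordot is implemented, and the permutation was worked out by hand for this particular shape.

```python
import torch

a = torch.arange(60.).reshape(3, 4, 5)
b = torch.arange(24.).reshape(4, 3, 2)

# Pair dimension 1 of a with dimension 0 of b (both size 4),
# and dimension 0 of a with dimension 1 of b (both size 3).
r = torch.tensordot(a, b, dims=([1, 0], [0, 1]))

# Equivalent view: move the contracted dimensions of a to the end,
# in the order that lines up with the leading dimensions of b,
# then contract with the integer form.
a_moved = a.permute(2, 1, 0)            # shape (5, 4, 3)
r_int = torch.tensordot(a_moved, b, dims=2)

# The same contraction with explicit index labels.
r_einsum = torch.einsum("ijk,jil->kl", a, b)

list_equals_int = torch.equal(r, r_int)
list_equals_einsum = torch.equal(r, r_einsum)
```

The remaining dimensions are the size-5 dimension of a and the size-2 dimension of b, so the result has shape (5, 2).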
Examples:
>>> a = torch.arange(60.).reshape(3, 4, 5)
>>> b = torch.arange(24.).reshape(4, 3, 2)
>>> torch.tensordot(a, b, dims=([1, 0], [0, 1]))
tensor([[4400., 4730.],
        [4532., 4874.],
        [4664., 5018.],
        [4796., 5162.],
        [4928., 5306.]])

>>> a = torch.randn(3, 4, 5, device='cuda')
>>> b = torch.randn(4, 5, 6, device='cuda')
>>> c = torch.tensordot(a, b, dims=2).cpu()
tensor([[ 8.3504, -2.5436,  6.2922,  2.7556, -1.0732,  3.2741],
        [ 3.3161,  0.0704,  5.0187, -0.4079, -4.3126,  4.8744],
        [ 0.8223,  3.9445,  3.2168, -0.2400,  3.4117,  1.7780]])

>>> a = torch.randn(3, 5, 4, 6)
>>> b = torch.randn(6, 4, 5, 3)
>>> torch.tensordot(a, b, dims=([2, 1, 3], [1, 2, 0]))
tensor([[  7.7193,  -2.4867, -10.3204],
        [  1.5513, -14.4737,  -6.5113],
        [ -0.2850,   4.2573,  -3.5997]])