torch.tensordot — PyTorch 2.7 documentation

torch.tensordot(a, b, dims=2, out=None)

Returns a contraction of a and b over multiple dimensions.

tensordot implements a generalized matrix product.
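As a quick illustration of the "generalized matrix product" description: for two 2-D tensors, contracting a single dimension (dims=1) should reproduce the ordinary matrix product. This is a minimal sketch; the shapes and the variable names x, y, via_tensordot, and via_matmul are illustrative choices, not part of the original page.

```python
import torch

# Illustrative 2-D inputs (shapes chosen arbitrarily for this sketch).
x = torch.arange(6.).reshape(2, 3)
y = torch.arange(12.).reshape(3, 4)

# Contract the last dimension of x with the first dimension of y.
via_tensordot = torch.tensordot(x, y, dims=1)

# Ordinary matrix product for comparison.
via_matmul = x @ y

print(via_tensordot.shape)                        # torch.Size([2, 4])
print(torch.allclose(via_tensordot, via_matmul))  # True
```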

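Parameters

a (Tensor): left tensor to contract

b (Tensor): right tensor to contract

dims (int, or a pair of lists of ints): number of dimensions to contract, or explicit lists of dimensions for a and b, respectively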

When called with a non-negative integer argument dims = d, where a and b have m and n dimensions, respectively, tensordot() computes

$$r_{i_0,...,i_{m-d}, i_d,...,i_n} = \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \times b_{k_0,...,k_{d-1}, i_d,...,i_n}.$$
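The integer form can be checked numerically. The sketch below, under the assumption that d = 2 and both inputs are 3-D, compares tensordot against two other ways of writing the same sum: an einsum with explicit index labels, and a plain matrix product after flattening the contracted dimensions. The tensor shapes and the names r, r_einsum, and r_matmul are illustrative.

```python
import torch

a = torch.arange(60.).reshape(3, 4, 5)    # m = 3 dimensions
b = torch.arange(120.).reshape(4, 5, 6)   # n = 3 dimensions
d = 2

# Contract the last d dimensions of a with the first d dimensions of b.
r = torch.tensordot(a, b, dims=d)

# The same sum written with einsum: k and l are the summed indices.
r_einsum = torch.einsum("ikl,klj->ij", a, b)

# The same sum as a matrix product: the contracted dimensions
# (sizes 4 and 5) are flattened into one axis of length 20.
r_matmul = a.reshape(3, 20) @ b.reshape(20, 6)

print(r.shape)                       # torch.Size([3, 6])
print(torch.allclose(r, r_einsum))   # True
print(torch.allclose(r, r_matmul))   # True
```

The reshape-and-multiply version shows why the operation is described as a generalized matrix product: the free dimensions of a act as rows, the free dimensions of b act as columns, and the contracted dimensions play the role of the shared inner axis.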

When called with dims in list form, the given dimensions are contracted in place of the last d dimensions of a and the first d dimensions of b. The sizes in these dimensions must match, but tensordot() will deal with broadcasted dimensions.
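One way to read the list form is as the integer form applied after reordering dimensions. The sketch below assumes two small 3-D tensors and checks that contracting explicit dimension pairs matches (1) permuting a so the contracted dimensions come last and then using dims=2, and (2) an einsum with the pairing written out. The names a_moved, r_list, r_int, and r_einsum are illustrative.

```python
import torch

a = torch.arange(60.).reshape(3, 4, 5)
b = torch.arange(24.).reshape(4, 3, 2)

# List form: pair a's dim 1 with b's dim 0 (size 4),
# and a's dim 0 with b's dim 1 (size 3).
r_list = torch.tensordot(a, b, dims=([1, 0], [0, 1]))

# Integer form: move a's contracted dims to the end, in the order
# in which they pair with b's leading dims, then contract 2 dims.
a_moved = a.permute(2, 1, 0)          # shape (5, 4, 3)
r_int = torch.tensordot(a_moved, b, dims=2)

# einsum with the pairing spelled out: i and j are summed.
r_einsum = torch.einsum("jik,ijl->kl", a, b)

print(r_list.shape)                       # torch.Size([5, 2])
print(torch.allclose(r_list, r_int))      # True
print(torch.allclose(r_list, r_einsum))   # True
```

The order of entries within each list matters only through the pairing: the k-th entry of the first list is contracted with the k-th entry of the second.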

Examples:

>>> import torch
>>> a = torch.arange(60.).reshape(3, 4, 5)
>>> b = torch.arange(24.).reshape(4, 3, 2)
>>> torch.tensordot(a, b, dims=([1, 0], [0, 1]))
tensor([[4400., 4730.],
        [4532., 4874.],
        [4664., 5018.],
        [4796., 5162.],
        [4928., 5306.]])

>>> # Requires a CUDA device; inputs are random, so values differ between runs.
>>> a = torch.randn(3, 4, 5, device='cuda')
>>> b = torch.randn(4, 5, 6, device='cuda')
>>> c = torch.tensordot(a, b, dims=2).cpu()
>>> c
tensor([[ 8.3504, -2.5436,  6.2922,  2.7556, -1.0732,  3.2741],
        [ 3.3161,  0.0704,  5.0187, -0.4079, -4.3126,  4.8744],
        [ 0.8223,  3.9445,  3.2168, -0.2400,  3.4117,  1.7780]])

>>> a = torch.randn(3, 5, 4, 6)
>>> b = torch.randn(6, 4, 5, 3)
>>> torch.tensordot(a, b, dims=([2, 1, 3], [1, 2, 0]))
tensor([[  7.7193,  -2.4867, -10.3204],
        [  1.5513, -14.4737,  -6.5113],
        [ -0.2850,   4.2573,  -3.5997]])