torch.bmm — PyTorch 2.7 documentation
torch.bmm(input, mat2, *, out=None) → Tensor
Performs a batch matrix-matrix product of matrices stored in input and mat2.
input and mat2 must be 3-D tensors each containing the same number of matrices.
If input is a $(b \times n \times m)$ tensor and mat2 is a $(b \times m \times p)$ tensor, out will be a $(b \times n \times p)$ tensor.
$$\text{out}_i = \text{input}_i \mathbin{@} \text{mat2}_i$$
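The per-batch definition above can be checked with a short sketch: multiplying each pair of matrices separately and stacking the results should agree with a single torch.bmm call. The variable names (inp, reference, and so on) are illustrative only. The last part assumes that torch.bmm does not broadcast the batch dimension, so mismatched batch sizes are rejected with a RuntimeError.

```python
import torch

torch.manual_seed(0)  # fixed seed so the sketch is reproducible

b, n, m, p = 10, 3, 4, 5
inp = torch.randn(b, n, m)
mat2 = torch.randn(b, m, p)

# Batched product: one call multiplies all b matrix pairs.
out = torch.bmm(inp, mat2)

# Reference: multiply each pair separately, then stack along the batch dim.
reference = torch.stack([inp[i] @ mat2[i] for i in range(b)])

shapes_match = out.shape == (b, n, p)
values_match = torch.allclose(out, reference, atol=1e-5)

# Batch sizes must agree exactly; a batch of 1 is not broadcast to b.
try:
    torch.bmm(inp, torch.randn(1, m, p))
    mismatch_rejected = False
except RuntimeError:
    mismatch_rejected = True
```

When the batch dimensions need to broadcast, torch.matmul is the usual alternative.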
This operator supports TensorFloat32.
On certain ROCm devices, when using float16 inputs this module will use different precision for backward.
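As a rough illustration of the TensorFloat32 note, the sketch below toggles the global float32 matmul precision around a torch.bmm call. It assumes torch.set_float32_matmul_precision is the knob you want; whether TF32 is actually used depends on the device and backend, and on hardware without TF32 support the setting is not expected to change the result.

```python
import torch

# Remember the current setting ("highest" keeps full float32 precision).
previous = torch.get_float32_matmul_precision()

# "high" permits reduced-precision modes such as TF32 where supported.
torch.set_float32_matmul_precision("high")

a = torch.randn(8, 64, 32)
b = torch.randn(8, 32, 16)
out = torch.bmm(a, b)  # may use TF32 on supporting GPUs

# Restore the earlier setting so the rest of the program is unaffected.
torch.set_float32_matmul_precision(previous)
```

Restoring the previous value keeps the change local, which matters because the setting is global to the process.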
Parameters
- input (Tensor) – the first batch of matrices to be multiplied
- mat2 (Tensor) – the second batch of matrices to be multiplied
Keyword Arguments
- out (Tensor, optional) – the output tensor.
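A minimal sketch of the out keyword argument, assuming a preallocated buffer of the expected (b×n×p) shape and the same dtype as the inputs. The names inp and buf are illustrative only.

```python
import torch

inp = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)

# Preallocate the (b x n x p) result buffer once; it can be reused later.
buf = torch.empty(10, 3, 5)

# The product is written into buf, and the returned tensor shares its storage.
res = torch.bmm(inp, mat2, out=buf)

same_storage = res.data_ptr() == buf.data_ptr()
same_values = torch.allclose(buf, torch.bmm(inp, mat2))
```

Reusing one buffer across calls can avoid repeated allocations in a tight loop, though whether that helps in practice depends on the workload.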
Example:
>>> import torch
>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])