torch.bmm — PyTorch 2.7 documentation

torch.bmm(input, mat2, *, out=None) → Tensor

Performs a batch matrix-matrix product of matrices stored in input and mat2.

input and mat2 must be 3-D tensors each containing the same number of matrices.
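Because bmm pairs the i-th matrix of input with the i-th matrix of mat2, the two batch sizes must match. The sketch below assumes only that torch is installed. It shows that a mismatched batch raises a RuntimeError instead of being broadcast. It also shows one way to multiply every matrix in a batch by a single shared matrix: expand that matrix along a new batch dimension first. The helper name shared_rhs_bmm is hypothetical.

```python
import torch

def shared_rhs_bmm(batch: torch.Tensor, mat: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: multiply each (n x m) matrix in `batch` (b x n x m)
    by the same (m x p) matrix `mat`, giving a (b x n x p) result."""
    b = batch.size(0)
    # unsqueeze(0) gives (1 x m x p); expand() turns it into a (b x m x p)
    # view without copying, so the batch sizes line up as torch.bmm requires.
    return torch.bmm(batch, mat.unsqueeze(0).expand(b, -1, -1))

x = torch.randn(10, 3, 4)
w = torch.randn(4, 5)
y = shared_rhs_bmm(x, w)
print(y.shape)  # torch.Size([10, 3, 5])

# Mismatched batch sizes (10 vs. 2) are rejected rather than broadcast.
try:
    torch.bmm(x, torch.randn(2, 4, 5))
except RuntimeError as err:
    print("RuntimeError:", err)
```

Expanding creates a view rather than a copy, so the shared matrix is not duplicated in memory.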

If input is a $(b \times n \times m)$ tensor and mat2 is a $(b \times m \times p)$ tensor, out will be a $(b \times n \times p)$ tensor.

$$\text{out}_i = \text{input}_i \mathbin{@} \text{mat2}_i$$
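The formula states that each output matrix is an ordinary matrix product of the matching pair of input matrices. As a sketch, assuming float32 inputs on the CPU, the batched result can be checked against an explicit Python loop over the batch. The loop is for illustration only and is much slower than a single bmm call.

```python
import torch

b, n, m, p = 10, 3, 4, 5
inp = torch.randn(b, n, m)
mat2 = torch.randn(b, m, p)

out = torch.bmm(inp, mat2)  # shape (b, n, p)

# Reference: out_i = input_i @ mat2_i, computed one matrix at a time.
ref = torch.stack([inp[i] @ mat2[i] for i in range(b)])

print(out.shape)                            # torch.Size([10, 3, 5])
print(torch.allclose(out, ref, atol=1e-5))  # True
```

A small tolerance is used because the batched kernel and the per-matrix products may accumulate in a different order, which can change the last bits of float32 results.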

This operator supports TensorFloat32 (TF32) on NVIDIA Ampere and later GPUs.
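TF32 trades some float32 precision for speed in matrix products on supporting CUDA GPUs. The following is a minimal sketch that assumes you want full float32 precision for one particular call. It sets the global switch torch.set_float32_matmul_precision("highest"), which disables TF32 for float32 matrix products, and restores the previous setting afterwards. The helper name bmm_highest_precision is hypothetical. The code falls back to the CPU when no GPU is present.

```python
import torch

def bmm_highest_precision(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: run torch.bmm with full float32 internal precision."""
    previous = torch.get_float32_matmul_precision()
    torch.set_float32_matmul_precision("highest")  # no TF32 for float32 matmuls
    try:
        return torch.bmm(a, b)
    finally:
        # Restore whatever the global setting was before the call.
        torch.set_float32_matmul_precision(previous)

device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn(8, 64, 32, device=device)
b = torch.randn(8, 32, 16, device=device)
res = bmm_highest_precision(a, b)
print(res.shape)  # torch.Size([8, 64, 16])
```

Because the setting is process-wide, the try/finally keeps the change from leaking into unrelated code.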

On certain ROCm devices, when float16 inputs are used, this operator uses a different precision for the backward pass.

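Parameters

input (Tensor) – the first batch of matrices to be multiplied

mat2 (Tensor) – the second batch of matrices to be multiplied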

Keyword Arguments

out (Tensor, optional) – the output tensor.
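A sketch of the out keyword, assuming a preallocated buffer that is reused across several calls, for example inside a loop. The buffer is created with the result's $(b \times n \times p)$ shape and the inputs' dtype (float32 here), so torch.bmm can write the product into it directly instead of allocating a new tensor on each call.

```python
import torch

b, n, m, p = 4, 3, 2, 5
buf = torch.empty(b, n, p)  # preallocated output, float32 like the inputs

for step in range(3):
    x = torch.randn(b, n, m)
    y = torch.randn(b, m, p)
    # Writes the product into buf instead of allocating a new tensor.
    torch.bmm(x, y, out=buf)
    assert torch.allclose(buf, x @ y, atol=1e-5)

print(buf.shape)  # torch.Size([4, 3, 5])
```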

Example:

```python
>>> import torch
>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])
```