torch.baddbmm — PyTorch 2.7 documentation

torch.baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) → Tensor

Performs a batch matrix-matrix product of matrices in batch1 and batch2. input is added to the final result.

batch1 and batch2 must be 3-D tensors each containing the same number of matrices.

If batch1 is a \(b \times n \times m\) tensor and batch2 is a \(b \times m \times p\) tensor, then input must be broadcastable with a \(b \times n \times p\) tensor and out will be a \(b \times n \times p\) tensor. Both alpha and beta mean the same as the scaling factors used in torch.addbmm().
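The shape rule can be illustrated with a short sketch. The sizes and the two bias tensors (`matrix_bias`, `per_batch_bias`) are hypothetical names chosen for this example; both are smaller than the output but broadcastable to its shape.

```python
import torch

torch.manual_seed(0)  # fixed seed so repeated runs use the same values

b, n, m, p = 10, 3, 4, 5  # arbitrary illustrative sizes
batch1 = torch.randn(b, n, m)
batch2 = torch.randn(b, m, p)

# Hypothetical inputs, each broadcastable to (b, n, p).
matrix_bias = torch.randn(n, p)        # one matrix shared by every batch entry
per_batch_bias = torch.randn(b, 1, p)  # one row per batch entry, shared across rows

out_matrix = torch.baddbmm(matrix_bias, batch1, batch2)
out_batch = torch.baddbmm(per_batch_bias, batch1, batch2)

print(out_matrix.shape, out_batch.shape)
```

In both calls the result has shape (b, n, p), regardless of the smaller shape of input.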

\[ \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) \]
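As a sanity check on the formula, the following sketch compares torch.baddbmm against the same expression written out with torch.bmm. The sizes and scaling factors are arbitrary values chosen for illustration, and the two results are expected to agree only up to floating-point rounding.

```python
import torch

torch.manual_seed(0)  # fixed seed so repeated runs use the same values

b, n, m, p = 4, 3, 5, 2  # arbitrary illustrative sizes
inp = torch.randn(b, n, p)
batch1 = torch.randn(b, n, m)
batch2 = torch.randn(b, m, p)
beta, alpha = 0.5, 2.0   # arbitrary scaling factors

# Fused call.
fused = torch.baddbmm(inp, batch1, batch2, beta=beta, alpha=alpha)

# The same formula spelled out: beta * input + alpha * (batch1 @ batch2).
manual = beta * inp + alpha * torch.bmm(batch1, batch2)

formula_matches = torch.allclose(fused, manual, atol=1e-4)
print(formula_matches)
```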

If beta is 0, then the content of input will be ignored, and nan and inf in it will not be propagated.
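A small sketch of this special case, assuming CPU float32 tensors: input is filled with nan, yet with beta=0 the result contains only the batch product. The naive expression `0 * inp + ...` is included for contrast, since multiplying nan by zero still yields nan.

```python
import torch

inp = torch.full((2, 3, 4), float("nan"))  # input consisting entirely of nan
batch1 = torch.ones(2, 3, 5)
batch2 = torch.ones(2, 5, 4)

# With beta=0 the contents of `inp` are ignored rather than multiplied by zero.
out_beta0 = torch.baddbmm(inp, batch1, batch2, beta=0)

# Writing the formula by hand propagates nan, because 0 * nan is nan.
naive = 0 * inp + torch.bmm(batch1, batch2)

print(torch.isnan(out_beta0).any().item())
print(torch.isnan(naive).all().item())
```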

For inputs of type FloatTensor or DoubleTensor, arguments beta and alpha must be real numbers, otherwise they should be integers.
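For integer tensors, a minimal sketch with integer scaling factors might look like the following. It assumes CPU tensors of dtype torch.int64; integer support on other devices is not covered here. The constant fill values are chosen only so the arithmetic is easy to follow by hand.

```python
import torch

inp = torch.ones(2, 2, 2, dtype=torch.int64)
batch1 = torch.full((2, 2, 3), 2, dtype=torch.int64)
batch2 = torch.full((2, 3, 2), 3, dtype=torch.int64)

# Integer tensors, so beta and alpha are passed as integers.
out_int = torch.baddbmm(inp, batch1, batch2, beta=2, alpha=3)

# Each product entry is 3 terms of 2 * 3 = 18, so each output entry is 2 * 1 + 3 * 18.
print(out_int.dtype)
```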

This operator supports TensorFloat32.

On certain ROCm devices, when using float16 inputs this module will use different precision for backward.

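Parameters

input (Tensor): the tensor to be added
batch1 (Tensor): the first batch of matrices to be multiplied
batch2 (Tensor): the second batch of matrices to be multiplied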

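Keyword Arguments

beta (Number, optional): multiplier for input (β). Default: 1
alpha (Number, optional): multiplier for batch1 @ batch2 (α). Default: 1
out (Tensor, optional): the output tensor. Default: None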

Example:

>>> import torch
>>> M = torch.randn(10, 3, 5)
>>> batch1 = torch.randn(10, 3, 4)
>>> batch2 = torch.randn(10, 4, 5)
>>> torch.baddbmm(M, batch1, batch2).size()
torch.Size([10, 3, 5])