torch.nn.utils.fuse_conv_bn_weights (PyTorch 2.7 documentation)
torch.nn.utils.fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=False)
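Fuse the parameters of a convolutional module and the BatchNorm module that follows it into new convolutional module parameters. The fused convolution reproduces the convolution followed by BatchNorm evaluated with its running statistics (that is, BatchNorm in eval mode).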
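Written out per output channel c, the fusion folds the BatchNorm affine map into the convolution. The symbols here are introduced for illustration: W and b stand for conv_w and conv_b, μ and σ² for bn_rm and bn_rv, ε for bn_eps, γ and β for bn_w and bn_b, and * denotes convolution.

$$
\begin{aligned}
\mathrm{BN}(W_c * x + b_c)
  &= \gamma_c \, \frac{(W_c * x + b_c) - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c \\
  &= \underbrace{\frac{\gamma_c}{\sqrt{\sigma_c^2 + \epsilon}}\, W_c}_{\hat{W}_c} * x
   \;+\; \underbrace{\frac{\gamma_c\,(b_c - \mu_c)}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c}_{\hat{b}_c}
\end{aligned}
$$

The fused weight is therefore the original weight scaled per output channel, and the fused bias absorbs the running mean and the BatchNorm shift.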
Parameters
- conv_w (torch.Tensor) – Convolutional weight.
- conv_b (Optional[torch.Tensor]) – Convolutional bias. If None, a zero bias is used.
- bn_rm (torch.Tensor) – BatchNorm running mean.
- bn_rv (torch.Tensor) – BatchNorm running variance.
- bn_eps (float) – BatchNorm epsilon, added to the running variance before taking the square root.
- bn_w (Optional[torch.Tensor]) – BatchNorm weight (per-channel scale). If None, ones are used.
- bn_b (Optional[torch.Tensor]) – BatchNorm bias (per-channel shift). If None, zeros are used.
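- transpose (bool, optional) – Set to True when conv_w belongs to a transposed convolution, whose output channels lie along dimension 1 rather than dimension 0; the per-channel scale is then applied along dimension 1. Defaults to False.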
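To make the None defaults and the transpose flag concrete, here is a minimal re-implementation sketch of the per-channel fusion, checked against torch.nn.utils.fuse_conv_bn_weights on random tensors. `reference_fuse` is a hypothetical helper written only for this illustration, not part of PyTorch. It assumes the standard weight layouts: (out_channels, in_channels/groups, kH, kW) for Conv2d and (in_channels, out_channels/groups, kH, kW) for ConvTranspose2d.

```python
import torch


def reference_fuse(conv_w, conv_b, bn_rm, bn_rv, bn_eps,
                   bn_w=None, bn_b=None, transpose=False):
    """Hypothetical helper: illustrative sketch of conv/BN weight fusion."""
    # Substitute identity values for missing optional tensors.
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    # One scale factor per output channel: gamma / sqrt(var + eps).
    scale = bn_w * torch.rsqrt(bn_rv + bn_eps)
    # Output channels sit on dim 0 for Conv, dim 1 for ConvTranspose.
    shape = [1, -1] if transpose else [-1, 1]
    shape += [1] * (conv_w.dim() - 2)
    fused_w = conv_w * scale.reshape(shape)
    fused_b = (conv_b - bn_rm) * scale + bn_b
    return fused_w, fused_b


torch.manual_seed(0)
out_ch = 4
bn_rm = torch.randn(out_ch)
bn_rv = torch.rand(out_ch) + 0.5  # keep variances positive

# Regular Conv2d weight: output channels on dim 0.
w = torch.randn(out_ch, 3, 3, 3)
lib_w, lib_b = torch.nn.utils.fuse_conv_bn_weights(
    w, None, bn_rm, bn_rv, 1e-5, None, None)
ref_w, ref_b = reference_fuse(w, None, bn_rm, bn_rv, 1e-5)

# ConvTranspose2d weight: output channels on dim 1.
wt = torch.randn(3, out_ch, 3, 3)
lib_wt, lib_bt = torch.nn.utils.fuse_conv_bn_weights(
    wt, None, bn_rm, bn_rv, 1e-5, None, None, transpose=True)
ref_wt, ref_bt = reference_fuse(wt, None, bn_rm, bn_rv, 1e-5, transpose=True)

print(torch.allclose(lib_w, ref_w), torch.allclose(lib_wt, ref_wt))
```

Only the axis along which the scale is broadcast changes with transpose; the bias computation is identical in both cases.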
Returns
Fused convolutional weight (same shape as conv_w) and fused bias (one entry per output channel).
Return type
Tuple[torch.nn.Parameter, torch.nn.Parameter]
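A minimal usage sketch, assuming a bias-free nn.Conv2d followed by nn.BatchNorm2d. The layer sizes, the random affine initialization, and the warm-up loop that populates the running statistics are illustrative choices, not part of the API. The returned Parameters are assigned to a fresh convolution created with bias=True so it can hold the fused bias, and its output is compared with the unfused pair in eval mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
nn.init.uniform_(bn.weight, 0.5, 1.5)  # non-trivial gamma
nn.init.uniform_(bn.bias, -0.5, 0.5)   # non-trivial beta

# Populate running_mean / running_var with a few training-mode batches.
bn.train()
with torch.no_grad():
    for _ in range(5):
        bn(conv(torch.randn(4, 3, 16, 16)))

# Fusion is only exact when BatchNorm uses its frozen running statistics.
conv.eval()
bn.eval()

fused_w, fused_b = torch.nn.utils.fuse_conv_bn_weights(
    conv.weight, conv.bias,  # conv.bias is None here
    bn.running_mean, bn.running_var, bn.eps,
    bn.weight, bn.bias,
)

# A new convolution with a bias slot receives the fused parameters.
fused = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
fused.weight = fused_w
fused.bias = fused_b
fused.eval()

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    ref = bn(conv(x))
    out = fused(x)
max_err = (ref - out).abs().max().item()
print(f"max abs difference: {max_err:.2e}")
```

Because the running statistics are baked into the fused parameters, the fused convolution no longer tracks batch statistics; it matches the original pair only in eval mode.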