fuse_modules — PyTorch 2.7 documentation

torch.ao.quantization.fuse_modules.fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=<function fuse_known_modules>, fuse_custom_config_dict=None)

Fuse a list of modules into a single module.

Fuses only the following sequences of modules:

    conv, bn
    conv, bn, relu
    conv, relu
    linear, relu
    bn, relu

All other sequences are left unchanged. For these sequences, the first module in the list is replaced with the fused module, and the remaining modules are replaced with identity.
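As a minimal sketch of the replacement rule described above, the following fuses a conv, bn, relu sequence and then inspects what sits in each slot afterwards. TinyNet and its layer sizes are hypothetical and not part of the original documentation. The model is put in eval mode, as in the examples on this page. The class name printed for the first slot (ConvReLU2d in recent releases) is an observation about current behavior, not something this page states.

```python
import torch
import torch.nn as nn
import torch.ao.quantization


class TinyNet(nn.Module):
    """Hypothetical model used only for illustration."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(8)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        return self.relu1(self.bn1(self.conv1(x)))


torch.manual_seed(0)
net = TinyNet().eval()

# Give batch norm non-trivial running statistics so that folding it
# into the convolution is actually exercised.
with torch.no_grad():
    net.bn1.running_mean.uniform_(-0.5, 0.5)
    net.bn1.running_var.uniform_(0.5, 1.5)

# inplace defaults to False, so `net` itself is left as it was.
fused = torch.ao.quantization.fuse_modules(net, ["conv1", "bn1", "relu1"])

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    max_diff = (net(x) - fused(x)).abs().max().item()

print(type(fused.conv1).__name__)   # the fused module occupies the first slot
print(type(fused.bn1).__name__)     # remaining slots become Identity
print(type(fused.relu1).__name__)
print(max_diff)                     # small floating-point difference expected
```

Because the trailing slots are replaced with identity rather than removed, the model's forward method keeps working without modification.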

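Parameters

    model – Model containing the modules to be fused.
    modules_to_fuse – List of lists of module names to fuse. Can also be a single list of strings if there is only one group of modules to fuse.
    inplace – bool specifying whether fusion happens in place on the model; by default a new model is returned.
    fuser_func – Function that takes a list of modules and returns a list of fused modules of the same length. For example, fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]. Defaults to torch.ao.quantization.fuse_known_modules.
    fuse_custom_config_dict – Custom configuration for fusion.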

Example of fuse_custom_config_dict

    fuse_custom_config_dict = {
        # Additional fuser_method mapping
        "additional_fuser_method_mapping": {
            (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
        },
    }

Returns

Model with fused modules. A new copy is created if inplace=False (the default); if inplace=True, the input model is modified and returned.
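The sketch below checks how the returned object relates to the input model for both values of inplace. The make_model helper and the two-layer nn.Sequential are hypothetical. The sketch relies on the children of an nn.Sequential being addressable by their index names "0" and "1".

```python
import torch
import torch.nn as nn
import torch.ao.quantization


def make_model():
    # Hypothetical two-layer model; Sequential children are named "0" and "1".
    return nn.Sequential(nn.Linear(4, 4), nn.ReLU()).eval()


original = make_model()

# Default (inplace=False): fusion is applied to a copy.
copy_fused = torch.ao.quantization.fuse_modules(original, ["0", "1"])
copy_is_new = copy_fused is not original
original_untouched = isinstance(original[1], nn.ReLU)

# inplace=True: the input model itself is modified and returned.
same_obj = torch.ao.quantization.fuse_modules(original, ["0", "1"], inplace=True)
inplace_returns_same = same_obj is original
original_now_fused = isinstance(original[1], nn.Identity)

print(copy_is_new, original_untouched)
print(inplace_returns_same, original_now_fused)
```

Using inplace=True avoids the copy, which may matter for large models, at the cost of losing the unfused original.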

Examples:

    m = M().eval()
    # m is a module containing the sub-modules below
    modules_to_fuse = [['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
    fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
    output = fused_m(input)

    m = M().eval()
    # Alternatively, provide a single list of modules to fuse
    modules_to_fuse = ['conv1', 'bn1', 'relu1']
    fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
    output = fused_m(input)
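The examples above leave M and input undefined. A self-contained version might look like the following sketch. The definitions of M and Sub, the layer sizes, and the example_input tensor are assumptions chosen to match the module names used above. Nested modules are addressed with dotted names such as 'submodule.conv'.

```python
import torch
import torch.nn as nn
import torch.ao.quantization


class Sub(nn.Module):
    """Hypothetical sub-module holding a conv, relu pair."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(8, 8, kernel_size=3)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


class M(nn.Module):
    """Hypothetical model with the attribute names used in the examples."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(8)
        self.relu1 = nn.ReLU()
        self.submodule = Sub()

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        return self.submodule(x)


torch.manual_seed(0)
m = M().eval()

# Two groups: one at the top level, one inside `submodule`.
modules_to_fuse = [
    ["conv1", "bn1", "relu1"],
    ["submodule.conv", "submodule.relu"],
]
fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)

example_input = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    output = fused_m(example_input)
    reference = m(example_input)  # unfused model, for comparison

print(output.shape)
print(torch.allclose(output, reference, atol=1e-4))
```

Comparing against the unfused model is a cheap sanity check before moving on to later quantization steps.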