torch.ao.quantization.quantize — PyTorch 2.7 documentation

# mypy: allow-untyped-defs

import copy
import inspect
import itertools
import warnings

import torch
import torch.ao.nn.quantized as nnq
import torch.nn as nn
from torch.ao.nn.intrinsic import _FusedModule
from torch.ao.quantization.observer import _is_activation_post_process
from torch.ao.quantization.qconfig import (
    _activation_is_memoryless,
    _add_module_to_qconfig_obs_ctr,
    default_dynamic_qconfig,
    float16_dynamic_qconfig,
    float_qparams_weight_only_qconfig,
    float_qparams_weight_only_qconfig_4bit,
)
from torch.ao.quantization.quantization_mappings import (
    _get_special_act_post_process,
    _has_special_act_post_process,
    get_default_dynamic_quant_module_mappings,
    get_default_qat_module_mappings,
    get_default_qconfig_propagation_list,
    get_default_static_quant_module_mappings,
    get_default_static_quant_reference_module_mappings,
    no_observer_set,
)
from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper
from torch.nn.utils.parametrize import type_before_parametrizations

from .utils import get_qparam_dict, has_no_children_ignoring_parametrizations

__all__ = [
    "get_default_custom_config_dict",
    "propagate_qconfig_",
    "add_quant_dequant",
    "prepare",
    "quantize",
    "quantize_dynamic",
    "prepare_qat",
    "quantize_qat",
    "convert",
    "swap_module",
]

# TODO: remove this once BC is no longer required to avoid a SEV

is_activation_post_process = _is_activation_post_process

_DEFAULT_CUSTOM_CONFIG_DICT = {
    "float_to_observed_custom_module_class": {
        nn.LSTM: nn.quantizable.LSTM,
        nn.MultiheadAttention: nn.quantizable.MultiheadAttention,
    },
    "observed_to_quantized_custom_module_class": {
        nn.quantizable.LSTM: nn.quantized.LSTM,
        nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention,
    },
}

def get_default_custom_config_dict():
    r"""Defines the default custom config dict."""
    return _DEFAULT_CUSTOM_CONFIG_DICT

def _propagate_qconfig_helper(
    module,
    qconfig_dict,
    qconfig_parent=None,
    prefix="",
    prepare_custom_config_dict=None,
):
    r"""This is a helper function for propagate_qconfig_

Args:
    module: input module
    qconfig_dict: dictionary that maps from name of submodule to quantization
                 configuration
    qconfig_parent: quantization config of parent module, we will fallback to
                   this config when there is no specified config for current
                   module
    prefix: corresponding prefix of the current module, used as key in
            qconfig_dict
    prepare_custom_config_dict: dictionary for custom handling of modules
                                see docs for :func:`~torch.ao.quantization.prepare_fx`

Return:
    None, module is modified inplace with qconfig attached
"""

module_qconfig = qconfig_dict.get(
    type_before_parametrizations(module), qconfig_parent
)
module_qconfig = qconfig_dict.get(prefix, module_qconfig)
module_qconfig = getattr(module, "qconfig", module_qconfig)

torch.ao.quantization.qconfig._assert_valid_qconfig(module_qconfig, module)

qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(module_qconfig, module)
module.qconfig = qconfig_with_device_check

for name, child in module.named_children():
    module_prefix = prefix + "." + name if prefix else name
    # do not propagate qconfig to child if child is non-traceable
    if prepare_custom_config_dict is None or not (
        name in prepare_custom_config_dict.get("non_traceable_module_name", [])
        or type(child)
        in prepare_custom_config_dict.get("non_traceable_module_class", [])
    ):
        _propagate_qconfig_helper(
            child, qconfig_dict, qconfig_with_device_check, module_prefix
        )

def propagate_qconfig_(module, qconfig_dict=None, prepare_custom_config_dict=None):
    r"""Propagate qconfig through the module hierarchy and assign the `qconfig`
    attribute on each leaf module

Args:
    module: input module
    qconfig_dict: dictionary that maps from name or type of submodule to
        quantization configuration, qconfig applies to all submodules of a
        given module unless qconfig for the submodules are specified (when
        the submodule already has qconfig attribute)
    prepare_custom_config_dict: dictionary for custom handling of modules
        see docs for :func:`~torch.ao.quantization.prepare_fx`

Return:
    None, module is modified inplace with qconfig attached
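
A minimal sketch of how `qconfig_dict` keys are matched (`model` and the
submodule name `"head"` are placeholders, not part of this API):

.. code-block:: python

   qconfig_dict = {
       # by type: applies to every nn.Linear submodule
       torch.nn.Linear: torch.ao.quantization.default_qconfig,
       # by name: applies to the submodule registered as "head";
       # name-based entries take precedence over type-based ones
       "head": torch.ao.quantization.default_qconfig,
   }
   propagate_qconfig_(model, qconfig_dict)
   # every submodule now carries a `.qconfig` attribute (None where nothing applied)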
"""
if qconfig_dict is None:
    qconfig_dict = {}
if prepare_custom_config_dict is None:
    prepare_custom_config_dict = {}
_propagate_qconfig_helper(
    module, qconfig_dict, prepare_custom_config_dict=prepare_custom_config_dict
)

def _observer_forward_hook(self, input, output):
    r"""Forward hook that calls observer on the output"""
    return self.activation_post_process(output)

def _observer_forward_pre_hook(self, input):
    r"""Forward pre hook that calls observer on the input"""
    return self.activation_post_process(input[0])

def _register_activation_post_process_hook(module, pre_hook=False):
    assert hasattr(
        module, "activation_post_process"
    ), "Expect activation_post_process attribute already attached to the module"
    if pre_hook:
        module.register_forward_pre_hook(_observer_forward_pre_hook, prepend=True)
    else:
        module.register_forward_hook(_observer_forward_hook, prepend=True)

def _add_observer_(
    module,
    qconfig_propagation_list=None,
    non_leaf_module_list=None,
    device=None,
    custom_module_class_mapping=None,
):
    r"""Add observer for the leaf child of the module.

This function inserts an observer module into every leaf child module that
has a valid qconfig attribute.

Args:
    module: input module with qconfig attributes for all the leaf modules that we want to quantize
    qconfig_propagation_list: a list of quantizable modules that will have observers added to them
        if they are leaf nodes
    device: parent device, if any
    non_leaf_module_list: list of non-leaf modules we want to add observers to

Return:
    None, module is modified inplace with added observer modules and forward_hooks
"""
if qconfig_propagation_list is None:
    qconfig_propagation_list = get_default_qconfig_propagation_list()

if custom_module_class_mapping is None:
    custom_module_class_mapping = {}

# respect device affinity when adding observers
if device is None:
    devices = _get_unique_devices_(module)
    assert (
        len(devices) <= 1
    ), f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}"
    device = next(iter(devices)) if len(devices) > 0 else None

def get_activation_post_process(qconfig, device, special_act_post_process=None):
    activation = (
        qconfig.activation()
        if special_act_post_process is None
        else special_act_post_process()
    )
    if device is not None:
        activation.to(device)
    return activation

def needs_observation(m):
    return hasattr(m, "qconfig") and m.qconfig is not None

def insert_activation_post_process(m, special_act_post_process=None):
    """Adds an activation post process module and register
    a pre or post hook that calls the module
    """
    # We don't insert observer/fake_quantize for DeQuantStub
    if needs_observation(m) and not isinstance(m, DeQuantStub):
        # observer and hook will be gone after we swap the module
        m.add_module(
            "activation_post_process",
            get_activation_post_process(
                m.qconfig, device, special_act_post_process
            ),
        )
        # Register observer as the first entry in the hook list
        # All post forward hooks are preserved and will be executed after the observer before convert
        _register_activation_post_process_hook(
            m, pre_hook=_activation_is_memoryless(m.qconfig)
        )

for name, child in module.named_children():
    # TODO: remove the Dropout special case once the codebase is stable
    if type_before_parametrizations(child) in [nn.Dropout]:
        continue
    elif issubclass(
        type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional)
    ):
        if needs_observation(child):
            assert hasattr(
                child, "activation_post_process"
            ), f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`"
            child.activation_post_process = get_activation_post_process(
                child.qconfig, device
            )
    elif isinstance(child, _FusedModule):
        # activation_post_process is now added directly to nn.Sequential/_FusedModule
        if needs_observation(child):
            insert_activation_post_process(child)
    elif (
        non_leaf_module_list is not None
        and type_before_parametrizations(child) in non_leaf_module_list
    ):
        if needs_observation(child):
            insert_activation_post_process(child)
    elif _has_special_act_post_process(child):
        special_act_post_process = _get_special_act_post_process(child)
        insert_activation_post_process(child, special_act_post_process)
    elif (
        needs_observation(child)
        and type_before_parametrizations(child) in custom_module_class_mapping
    ):
        observed_class = custom_module_class_mapping[
            type_before_parametrizations(child)
        ]
        observed_child = observed_class.from_float(child)
        setattr(module, name, observed_child)
        # TODO: These are the modules that cannot be observed
        #       Once there are more, we should move them to a separate list
        if not issubclass(observed_class, tuple(no_observer_set())):
            insert_activation_post_process(observed_child)
    else:
        _add_observer_(
            child,
            qconfig_propagation_list,
            non_leaf_module_list,
            device,
            custom_module_class_mapping,
        )

# Insert observers only for leaf nodes. Note that this observer is for the
# output of the module; inputs are observed by QuantStub.
if (
    has_no_children_ignoring_parametrizations(module)
    and not isinstance(module, torch.nn.Sequential)
    and type_before_parametrizations(module) in qconfig_propagation_list
):
    insert_activation_post_process(module)
# This is a special case for AdaRound eager mode:
# AdaRound attaches a weight_fake_quant that needs to be propagated from the API to convert.
# The leaf-node check based on the number of children is a naive assumption that blocks
# this case, so we add an exception for AdaRound.
if (
    hasattr(module, "weight_fake_quant")
    and not isinstance(module, torch.nn.Sequential)
    and type_before_parametrizations(module) in qconfig_propagation_list
):
    insert_activation_post_process(module)

def _get_unique_devices_(module):
    return {p.device for p in module.parameters() if p.device.type != "meta"} | {
        p.device for p in module.buffers() if p.device.type != "meta"
    }

def add_quant_dequant(module):
    r"""Wrap the leaf child modules in QuantWrapper if they have a valid qconfig.

    Note that this function will modify the children of the module in place, and it
    can also return a new module which wraps the input module.

Args:
    module: input module with qconfig attributes for all the leaf modules
    that we want to quantize

Return:
    Either the inplace modified module with submodules wrapped in
    `QuantWrapper` based on qconfig or a new `QuantWrapper` module which
    wraps the input module, the latter case only happens when the input
    module is a leaf module and we want to quantize it.
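
A minimal sketch (`float_model` is a user-defined module; the qconfig used here
is just the stock default):

.. code-block:: python

   float_model.qconfig = torch.ao.quantization.default_qconfig
   propagate_qconfig_(float_model)
   # leaf submodules that carry a qconfig are wrapped in QuantWrapper,
   # which adds a QuantStub/DeQuantStub pair around them
   wrapped_model = add_quant_dequant(float_model)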
"""
if (
    has_no_children_ignoring_parametrizations(module)
    and hasattr(module, "qconfig")
    and module.qconfig
):
    return QuantWrapper(module)

for name, child in module.named_children():
    module._modules[name] = add_quant_dequant(child)
return module

def prepare(
    model,
    inplace=False,
    allow_list=None,
    observer_non_leaf_module_list=None,
    prepare_custom_config_dict=None,
):
    r"""Prepares a copy of the model for quantization calibration or
    quantization-aware training.

Quantization configuration should be assigned preemptively
to individual submodules in `.qconfig` attribute.

The model will be attached with observer or fake quant modules, and qconfig
will be propagated.

Args:
    `model`: input model to be modified in-place
    `inplace`: carry out model transformations in-place, the original module is mutated
    `allow_list`: list of quantizable modules
    `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer
    `prepare_custom_config_dict`: customization configuration dictionary for prepare function

.. code-block:: python

   # Example of prepare_custom_config_dict:
   prepare_custom_config_dict = {
       # user will manually define the corresponding observed
       # module class which has a from_float class method that converts
       # float custom module to observed custom module
       "float_to_observed_custom_module_class": {
           CustomModule: ObservedCustomModule
       }
   }
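
A typical eager-mode calibration flow around `prepare` (a sketch; `MyModel` and
`calibration_batches` are user-defined placeholders, and the float model is
assumed to contain QuantStub/DeQuantStub around the regions to be quantized):

.. code-block:: python

   model = MyModel().eval()
   model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
   prepared = prepare(model)
   for batch in calibration_batches:  # run representative data for calibration
       prepared(batch)
   quantized = convert(prepared)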

"""
torch._C._log_api_usage_once("quantization_api.quantize.prepare")
if prepare_custom_config_dict is None:
    prepare_custom_config_dict = get_default_custom_config_dict()
custom_module_class_mapping = prepare_custom_config_dict.get(
    "float_to_observed_custom_module_class", {}
)

if not inplace:
    model = copy.deepcopy(model)

# TODO: remove allow_list
qconfig_propagation_list = allow_list
if allow_list is None:
    qconfig_propagation_list = get_default_qconfig_propagation_list()
propagate_qconfig_(model, qconfig_dict=None)

# sanity check common API misusage
if not any(hasattr(m, "qconfig") and m.qconfig for m in model.modules()):
    warnings.warn(
        "None of the submodules got a qconfig applied. Make sure you "
        "passed the correct configuration through `qconfig_dict` or "
        "by assigning the `.qconfig` attribute directly to submodules"
    )

_add_observer_(
    model,
    qconfig_propagation_list,
    observer_non_leaf_module_list,
    custom_module_class_mapping=custom_module_class_mapping,
)
return model

def _remove_activation_post_process(module):
    # TODO: maybe we should change activation_post_process to _activation_post_process
    # to prevent it from being used by user
    if hasattr(module, "activation_post_process") and _is_activation_post_process(
        module.activation_post_process
    ):
        delattr(module, "activation_post_process")

# remove activation_post_process pre and post hooks
def remove_hooks(pre_hook=False):
    hook_map = module._forward_pre_hooks if pre_hook else module._forward_hooks
    observer_hook = (
        _observer_forward_pre_hook if pre_hook else _observer_forward_hook
    )
    handle_ids_to_remove = set()
    for handle_id, hook_fn in hook_map.items():
        if hook_fn is observer_hook:
            handle_ids_to_remove.add(handle_id)
    for handle_id in handle_ids_to_remove:
        hook_map.pop(handle_id)

remove_hooks(pre_hook=True)
remove_hooks(pre_hook=False)

# TODO: rename this to something more general

def _remove_qconfig(module):
    r"""Clean up the qconfig left in the module so that a new qconfig can be propagated.

Args:
    module: module to be cleaned up
"""
for child in module.children():
    _remove_qconfig(child)

if hasattr(module, "qconfig"):
    del module.qconfig

_remove_activation_post_process(module)

def quantize(model, run_fn, run_args, mapping=None, inplace=False):
    r"""Quantize the input float model with post training static quantization.

First it prepares the model for calibration, then it calls
`run_fn` to run the calibration step, and finally it converts
the model to a quantized model.

Args:
    model: input float model
    run_fn: a calibration function for calibrating the prepared model
    run_args: positional arguments for `run_fn`
    inplace: carry out model transformations in-place, the original module is mutated
    mapping: correspondence between original module types and quantized counterparts

Return:
    Quantized model.
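
A minimal sketch (`float_model`, `calibrate_fn` and `data_loader` are
user-defined placeholders; the float model is assumed to contain
QuantStub/DeQuantStub around the regions to be quantized):

.. code-block:: python

   float_model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
   # calibrate_fn(model, data_loader) should run representative data through the model
   quantized_model = quantize(float_model, calibrate_fn, [data_loader])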
"""
torch._C._log_api_usage_once("quantization_api.quantize.quantize")
if mapping is None:
    mapping = get_default_static_quant_module_mappings()
if not inplace:
    model = copy.deepcopy(model)
model.eval()
prepare(model, inplace=True)
run_fn(model, *run_args)
convert(model, mapping, inplace=True)
return model

def quantize_dynamic(
    model, qconfig_spec=None, dtype=torch.qint8, mapping=None, inplace=False
):
    r"""Converts a float model to a dynamic (i.e. weights-only) quantized model.

Replaces specified modules with dynamic weight-only quantized versions and outputs the quantized model.

For the simplest usage, provide the `dtype` argument, which can be float16 or qint8. Weight-only
quantization is performed by default for layers with large weights - i.e. Linear and RNN variants.

Fine-grained control is possible with `qconfig_spec` and `mapping`, which act similarly to `quantize()`.
If `qconfig_spec` is provided as a dictionary, the `dtype` argument is ignored.

Args:
    model: input model
    qconfig_spec: Either:

        - A dictionary that maps from name or type of submodule to quantization
          configuration, qconfig applies to all submodules of a given
          module unless qconfig for the submodules are specified (when the
          submodule already has qconfig attribute). Entries in the dictionary
          need to be QConfig instances.

        - A set of types and/or submodule names to apply dynamic quantization to,
          in which case the `dtype` argument is used to specify the bit-width

    inplace: carry out model transformations in-place, the original module is mutated
    mapping: maps type of a submodule to a type of corresponding dynamically quantized version
        with which the submodule needs to be replaced
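
A minimal sketch (`float_model` is a user-defined float module):

.. code-block:: python

   import torch.nn as nn

   # dynamically quantize only the Linear and LSTM submodules, with int8 weights
   dq_model = quantize_dynamic(float_model, {nn.Linear, nn.LSTM}, dtype=torch.qint8)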

"""
torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic")
if qconfig_spec is None:
    if dtype == torch.qint8:
        qconfig_spec = {
            nn.Linear: default_dynamic_qconfig,
            nn.LSTM: default_dynamic_qconfig,
            nn.GRU: default_dynamic_qconfig,
            nn.LSTMCell: default_dynamic_qconfig,
            nn.RNNCell: default_dynamic_qconfig,
            nn.GRUCell: default_dynamic_qconfig,
        }
    elif dtype == torch.float16:
        qconfig_spec = {
            nn.Linear: float16_dynamic_qconfig,
            nn.LSTM: float16_dynamic_qconfig,
            nn.GRU: float16_dynamic_qconfig,
            nn.LSTMCell: float16_dynamic_qconfig,
            nn.RNNCell: float16_dynamic_qconfig,
            nn.GRUCell: float16_dynamic_qconfig,
        }
    elif dtype == torch.quint8:
        qconfig_spec = {
            nn.EmbeddingBag: float_qparams_weight_only_qconfig,
            nn.Embedding: float_qparams_weight_only_qconfig,
        }
    elif dtype == torch.quint4x2:
        qconfig_spec = {
            nn.EmbeddingBag: float_qparams_weight_only_qconfig_4bit,
        }
    else:
        raise ValueError(
            f"Don't know how to quantize with default settings for {dtype}. Provide full qconfig please"
        )
elif isinstance(qconfig_spec, set):
    if dtype is torch.qint8:
        default_qconfig = default_dynamic_qconfig
    elif dtype is torch.float16:
        default_qconfig = float16_dynamic_qconfig
    elif dtype is torch.quint8:
        default_qconfig = float_qparams_weight_only_qconfig
    elif dtype is torch.quint4x2:
        default_qconfig = float_qparams_weight_only_qconfig_4bit
    else:
        raise RuntimeError(
            "Unknown dtype specified for quantize_dynamic: ", str(dtype)
        )
    qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))

if mapping is None:
    mapping = get_default_dynamic_quant_module_mappings()

if not inplace:
    model = copy.deepcopy(model)
model.eval()
propagate_qconfig_(model, qconfig_spec)
convert(model, mapping, inplace=True)
return model

def prepare_qat(model, mapping=None, inplace=False):
    r"""
    Prepares a copy of the model for quantization-aware training and converts
    its submodules to quantization-aware versions according to `mapping`.

Quantization configuration should be assigned preemptively
to individual submodules in `.qconfig` attribute.

Args:
    model: input model to be modified in-place
    mapping: dictionary that maps float modules to quantized modules to be
             replaced.
    inplace: carry out model transformations in-place, the original module
             is mutated
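
A minimal sketch (`float_model` and `train_one_epoch` are user-defined
placeholders):

.. code-block:: python

   float_model.train()  # prepare_qat requires a model in training mode
   float_model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
   qat_model = prepare_qat(float_model)
   train_one_epoch(qat_model)  # fine-tune with fake quantization inserted
   quantized_model = convert(qat_model.eval())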
"""
torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
assert model.training, "prepare_qat only works on models in training mode"
if mapping is None:
    mapping = get_default_qat_module_mappings()

if not inplace:
    model = copy.deepcopy(model)

propagate_qconfig_(model, qconfig_dict=None)
convert(model, mapping=mapping, inplace=True, remove_qconfig=False)
prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True)
return model

def quantize_qat(model, run_fn, run_args, inplace=False):
    r"""Do quantization aware training and output a quantized model.

Args:
    model: input model
    run_fn: a function for evaluating the prepared model, can be a
            function that simply runs the prepared model or a training
            loop
    run_args: positional arguments for `run_fn`

Return:
    Quantized model.
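
A minimal sketch (`float_model`, `train_fn` and `data_loader` are user-defined
placeholders):

.. code-block:: python

   float_model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
   # train_fn(model, data_loader) runs the (short) QAT fine-tuning loop
   quantized_model = quantize_qat(float_model, train_fn, [data_loader])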
"""
torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat")
if not inplace:
    model = copy.deepcopy(model)
model.train()
prepare_qat(model, inplace=True)
run_fn(model, *run_args)
convert(model, inplace=True)
return model

def convert(
    module,
    mapping=None,
    inplace=False,
    remove_qconfig=True,
    is_reference=False,
    convert_custom_config_dict=None,
    use_precomputed_fake_quant=False,
):
    r"""Converts submodules in the input module to a different module according to `mapping`
    by calling the `from_float` method on the target module class. Also removes `qconfig`
    at the end if `remove_qconfig` is set to True.

Args:
    `module`: prepared and calibrated module
    `mapping`: a dictionary that maps from source module type to target
               module type, can be overwritten to allow swapping user defined
               Modules
    `inplace`: carry out model transformations in-place, the original module
               is mutated
    `convert_custom_config_dict`: custom configuration dictionary for convert function
    `use_precomputed_fake_quant`: a flag to enable use of precomputed fake quant

.. code-block:: python

   # Example of convert_custom_config_dict:
   convert_custom_config_dict = {
       # user will manually define the corresponding quantized
       # module class which has a from_observed class method that converts
       # observed custom module to quantized custom module
       "observed_to_quantized_custom_module_class": {
           ObservedCustomModule: QuantizedCustomModule
       }
   }
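
A typical call after calibration or QAT fine-tuning (a sketch; `prepared_model`
is assumed to be the output of `prepare`/`prepare_qat` that has already seen
calibration or training data):

.. code-block:: python

   prepared_model.eval()
   quantized_model = convert(prepared_model)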

"""
torch._C._log_api_usage_once("quantization_api.quantize.convert")
if not inplace:
    module = copy.deepcopy(module)
_convert(
    module,
    mapping,
    inplace=True,
    is_reference=is_reference,
    convert_custom_config_dict=convert_custom_config_dict,
    use_precomputed_fake_quant=use_precomputed_fake_quant,
)
if remove_qconfig:
    _remove_qconfig(module)
return module

def _convert(
    module,
    mapping=None,
    inplace=False,
    is_reference=False,
    convert_custom_config_dict=None,
    use_precomputed_fake_quant=False,
):
    r"""Converts submodules in the input module to a different module according to `mapping`
    by calling the `from_float` method on the target module class.

Args:
    module: input module
    mapping: a dictionary that maps from source module type to target
             module type, can be overwritten to allow swapping user defined
             Modules
    inplace: carry out model transformations in-place, the original module
             is mutated
    is_reference: a flag to enable quantized reference module
    use_precomputed_fake_quant: a flag to enable use of precomputed fake quant

"""
if mapping is None:
    mapping = (
        get_default_static_quant_reference_module_mappings()
        if is_reference
        else get_default_static_quant_module_mappings()
    )
if convert_custom_config_dict is None:
    convert_custom_config_dict = get_default_custom_config_dict()
custom_module_class_mapping = convert_custom_config_dict.get(
    "observed_to_quantized_custom_module_class", {}
)

if not inplace:
    module = copy.deepcopy(module)
reassign = {}
for name, mod in module.named_children():
    # both fused modules and observed custom modules are
    # swapped as one unit
    if (
        not isinstance(mod, _FusedModule)
        and type_before_parametrizations(mod) not in custom_module_class_mapping
    ):
        _convert(
            mod,
            mapping,
            True,  # inplace
            is_reference,
            convert_custom_config_dict,
            use_precomputed_fake_quant=use_precomputed_fake_quant,
        )
    reassign[name] = swap_module(
        mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant
    )

for key, value in reassign.items():
    module._modules[key] = value

return module

def swap_module(
    mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant=False
):
    r"""Swaps the module if it has a quantized counterpart and it has an observer attached.

Args:
    mod: input module
    mapping: a dictionary that maps from nn module to nnq module

Return:
    The corresponding quantized module of `mod`
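
A minimal sketch (`observed_linear` is assumed to be an `nn.Linear` that was
prepared and calibrated, i.e. it carries a `qconfig` and a calibrated
`activation_post_process` observer):

.. code-block:: python

   from torch.ao.quantization.quantization_mappings import (
       get_default_static_quant_module_mappings,
   )

   quantized_linear = swap_module(
       observed_linear, get_default_static_quant_module_mappings(), {}
   )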
"""
new_mod = mod
if hasattr(mod, "qconfig") and mod.qconfig is not None:
    swapped = False
    if type_before_parametrizations(mod) in custom_module_class_mapping:
        new_mod = custom_module_class_mapping[
            type_before_parametrizations(mod)
        ].from_observed(mod)
        swapped = True
    elif type_before_parametrizations(mod) in mapping:
        qmod = mapping[type_before_parametrizations(mod)]
        if hasattr(qmod, "_IS_REFERENCE") and qmod._IS_REFERENCE:
            assert mod.qconfig is not None
            weight_post_process = mod.qconfig.weight()
            weight_post_process(mod.weight)
            weight_qparams = get_qparam_dict(weight_post_process)
            new_mod = qmod.from_float(mod, weight_qparams)
        else:
            sig = inspect.signature(qmod.from_float)
            if "use_precomputed_fake_quant" in sig.parameters:
                new_mod = qmod.from_float(
                    mod, use_precomputed_fake_quant=use_precomputed_fake_quant
                )
            else:
                new_mod = qmod.from_float(mod)
        swapped = True

    if swapped:
        # Preserve module's pre forward hooks. They'll be called on quantized input
        for pre_hook_fn in mod._forward_pre_hooks.values():
            new_mod.register_forward_pre_hook(pre_hook_fn)
        # Preserve module's post forward hooks except _observer_forward_hook
        # After convert they'll work with quantized output
        for hook_fn in mod._forward_hooks.values():
            if hook_fn is not _observer_forward_hook:
                new_mod.register_forward_hook(hook_fn)

        # respect device affinity when swapping modules
        devices = _get_unique_devices_(mod)
        assert len(devices) <= 1 or (
            len(devices) == 2 and torch.device("meta") in devices
        ), f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
        device = next(iter(devices)) if len(devices) > 0 else None
        if device:
            new_mod.to(device)
return new_mod

def _get_observer_dict(mod, target_dict, prefix=""):
    r"""Traverse the modules and save all observers into a dict.

    This is mainly used for quantization accuracy debugging.

    Args:
        mod: the top module from which we want to save all observers
        prefix: the prefix for the current module
        target_dict: the dictionary used to save all the observers
    """

def get_prefix(prefix):
    return prefix if prefix == "" else prefix + "."

if hasattr(mod, "activation_post_process"):
    target_dict[
        get_prefix(prefix) + "activation_post_process"
    ] = mod.activation_post_process
for name, child in mod.named_children():
    module_prefix = get_prefix(prefix) + name if prefix else name
    _get_observer_dict(child, target_dict, module_prefix)