torch.nn.utils.prune.custom_from_mask — PyTorch 2.7 documentation
torch.nn.utils.prune.custom_from_mask(module, name, mask)
Prune tensor corresponding to parameter called name in module by applying the pre-computed mask in mask.
Modifies module in place (and also returns the modified module) by:
- adding a named buffer called name+'_mask' corresponding to the binary mask applied to the parameter name by the pruning method.
- replacing the parameter name by its pruned version, while the original (unpruned) parameter is stored in a new parameter named name+'_orig'.
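The two in-place changes can be inspected directly on a small module. The sketch below is illustrative only: the layer size, the mask values, and the variable names are assumptions chosen for the example, not part of the API.

```python
import torch
from torch import nn
from torch.nn.utils import prune

layer = nn.Linear(4, 2)

# Illustrative binary mask with the same shape as layer.weight (2 x 4).
mask = torch.tensor([[1, 0, 1, 0],
                     [0, 1, 0, 1]])

returned = prune.custom_from_mask(layer, name="weight", mask=mask)

# The module is modified in place and the same object is returned.
same_object = returned is layer

# The unpruned values are kept in a parameter named 'weight_orig'.
param_names = {n for n, _ in layer.named_parameters()}

# The mask is stored as a buffer named 'weight_mask'.
buffer_names = {n for n, _ in layer.named_buffers()}

# 'weight' now holds the pruned tensor: the original values times the mask.
pruned_matches = torch.equal(layer.weight, layer.weight_orig * layer.weight_mask)

print(same_object, sorted(param_names), sorted(buffer_names), pruned_matches)
```

After the call, weight no longer appears among the module's parameters; only weight_orig does.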
Parameters
- module (nn.Module) – module containing the tensor to prune
- name (str) – parameter name within module on which pruning will act.
- mask (Tensor) – binary mask to be applied to the parameter.
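As a rough sketch of how a mask might be prepared before the call, the example below derives one from the parameter itself, so it has the same shape as the tensor being pruned. The magnitude criterion and the threshold of 0.1 are hypothetical choices made for illustration; any pre-computed binary tensor of matching shape could be passed instead.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)
layer = nn.Linear(5, 3)

# Hypothetical criterion: keep entries whose magnitude exceeds 0.1.
threshold = 0.1
mask = (layer.weight.detach().abs() > threshold).float()

prune.custom_from_mask(layer, name="weight", mask=mask)

# Every position where the mask is 0 is now 0 in the pruned weight.
zeroed = int((layer.weight == 0).sum().item())
expected = int((mask == 0).sum().item())
print(zeroed == expected)
```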
Returns
modified (i.e. pruned) version of the input module
Return type
module (nn.Module)
Examples
>>> import torch
>>> from torch import nn
>>> from torch.nn.utils import prune
>>> m = prune.custom_from_mask(
...     nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0])
... )
>>> print(m.bias_mask)
tensor([0., 1., 0.])
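To see the effect of the mask on the module's output, one option is to feed an all-zero input, so that the linear layer returns only its effective bias. This is a minimal sketch built on the example above; the zero input and the variable names are assumptions made for illustration.

```python
import torch
from torch import nn
from torch.nn.utils import prune

m = prune.custom_from_mask(
    nn.Linear(5, 3), name="bias", mask=torch.tensor([0, 1, 0])
)

# With an all-zero input, the output of the layer equals its effective bias.
out = m(torch.zeros(1, 5))

# Positions 0 and 2 are masked out; position 1 keeps its original value.
masked_positions_are_zero = bool((out[0, [0, 2]] == 0).all())
kept_matches_original = bool(out[0, 1] == m.bias_orig[1])
print(masked_positions_are_zero, kept_matches_original)
```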