torch.nn.utils.prune.ln_structured — PyTorch 2.7 documentation

torch.nn.utils.prune.ln_structured(module, name, amount, n, dim, importance_scores=None)

Prune tensor by removing channels with the lowest Ln-norm along the specified dimension.
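The ranking step can be reproduced by hand. The sketch below is an illustration, not the library's internal code. It assumes a Conv2d weight of shape (out_channels, in_channels, kH, kW) and takes dim=0, so each output channel is one slice. It computes the L2 norm of every slice over all remaining dimensions with torch.linalg.vector_norm, then checks that the single channel zeroed by ln_structured is the one with the smallest norm.

```python
import torch
import torch.nn as nn
from torch.nn.utils import prune

torch.manual_seed(0)
conv = nn.Conv2d(5, 3, 2)  # weight shape: (out=3, in=5, kH=2, kW=2)

# Copy the weight before pruning so the norms use the original values.
w = conv.weight.detach().clone()

# With dim=0, each output channel is one slice; its L2 norm is taken over
# the remaining dimensions (1, 2, 3), giving one score per channel.
norms = torch.linalg.vector_norm(w, ord=2, dim=(1, 2, 3))  # shape: (3,)

# Prune exactly one channel (an int amount is an absolute count).
prune.ln_structured(conv, name="weight", amount=1, n=2, dim=0)

# A channel counts as pruned when every entry of its mask slice is zero.
pruned = (conv.weight_mask.sum(dim=(1, 2, 3)) == 0).nonzero().flatten()
print(pruned.tolist(), norms.argmin().item())  # the two indices should match
```

Because whole slices along dim share one score, the resulting mask zeroes entire channels rather than scattered individual weights.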

Prunes the tensor corresponding to the parameter called name in module by removing the specified amount of (currently unpruned) channels along the specified dim with the lowest Ln-norm. Modifies module in place (and also returns the modified module) by:

  1. adding a named buffer called name+'_mask' corresponding to the binary mask applied to the parameter name by the pruning method.
  2. replacing the parameter name by its pruned version, while the original (unpruned) parameter is stored in a new parameter named name+'_orig'.
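The two in-place changes can be observed directly on a small layer. This is a minimal sketch using an arbitrary nn.Conv2d(5, 3, 2) and name='weight': after pruning, 'weight_orig' appears among the parameters, 'weight_mask' among the buffers, and the attribute module.weight equals the original parameter multiplied by the mask.

```python
import torch
import torch.nn as nn
from torch.nn.utils import prune

conv = nn.Conv2d(5, 3, 2)
returned = prune.ln_structured(conv, name="weight", amount=0.3, n=2, dim=1)

param_names = [name for name, _ in conv.named_parameters()]
buffer_names = [name for name, _ in conv.named_buffers()]
print(param_names)   # 'weight_orig' replaces 'weight'; 'bias' is untouched
print(buffer_names)  # ['weight_mask']

# The pruned tensor is the original parameter multiplied by the binary mask.
assert torch.equal(conv.weight, conv.weight_orig * conv.weight_mask)
# The module is modified in place, and the same object is returned.
assert returned is conv
```

Since the unpruned values survive in weight_orig, only the mask decides which channels are currently zeroed.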

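Parameters

  * module (nn.Module): module containing the tensor to prune.
  * name (str): parameter name within module on which pruning will act.
  * amount (int or float): quantity of channels to prune. If float, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune.
  * n (int, float, inf, -inf, 'fro', 'nuc'): see the documentation of valid entries for argument p in torch.norm().
  * dim (int): index of the dim along which we define channels to prune.
  * importance_scores (torch.Tensor): tensor of importance scores (of the same shape as the module parameter) used to compute the mask for pruning. The values in this tensor indicate the importance of the corresponding elements in the parameter being pruned. If unspecified or None, the module parameter will be used in its place.
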
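When importance_scores is given, channels are ranked by the norm of the scores instead of the weights. In the sketch below, the scores tensor is a hypothetical example built for illustration: all ones except output channel 1, which is set to zero. With an int amount of 1, that channel is the one pruned, whatever its weight values are.

```python
import torch
import torch.nn as nn
from torch.nn.utils import prune

conv = nn.Conv2d(5, 3, 2)  # weight shape: (3, 5, 2, 2)

# Hypothetical importance scores with the same shape as conv.weight:
# output channel 1 is marked least important by zeroing its scores.
scores = torch.ones_like(conv.weight)
scores[1] = 0.0

# Channels are ranked by the L1 norm of the scores, not of the weights.
prune.ln_structured(conv, "weight", amount=1, n=1, dim=0, importance_scores=scores)

kept = conv.weight_mask.sum(dim=(1, 2, 3)) > 0
print(kept.tolist())  # [True, False, True]
```

The mask is still applied to the real parameter; the scores only decide which channels are removed.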
Returns

modified (i.e. pruned) version of the input module

Return type

module (nn.Module)

Examples

>>> import torch.nn as nn
>>> from torch.nn.utils import prune
>>> m = prune.ln_structured(
...     nn.Conv2d(5, 3, 2), 'weight', amount=0.3, dim=1, n=float('-inf')
... )
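In this example, dim=1 indexes the 5 input channels of the Conv2d weight, and amount=0.3 is a fraction, so round(0.3 * 5) = 2 input channels are removed. With n=float('-inf'), the "norm" of each slice is its smallest absolute entry. The sketch below assumes a fixed random seed and checks both points by comparing the mask against per-channel minimum absolute values.

```python
import torch
import torch.nn as nn
from torch.nn.utils import prune

torch.manual_seed(0)
conv = nn.Conv2d(5, 3, 2)  # weight shape: (3, 5, 2, 2); dim=1 has size 5
w = conv.weight.detach().clone()

m = prune.ln_structured(conv, 'weight', amount=0.3, dim=1, n=float('-inf'))

# With n=-inf, each input channel is scored by its smallest absolute entry.
min_abs = w.abs().amin(dim=(0, 2, 3))  # one score per input channel
pruned = (m.weight_mask.sum(dim=(0, 2, 3)) == 0).nonzero().flatten()

print(len(pruned))  # 2
# The pruned channels are the two with the smallest scores.
assert set(pruned.tolist()) == set(min_abs.argsort()[:2].tolist())
```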