PruningContainer - PyTorch 2.7 documentation

class torch.nn.utils.prune.PruningContainer(*args)

Container holding a sequence of pruning methods for iterative pruning.

Keeps track of the order in which pruning methods are applied and handles combining successive pruning calls.

Accepts as argument an instance of a BasePruningMethod or an iterable of them.
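In everyday use a PruningContainer is usually not built by hand. It appears when the same parameter is pruned more than once. The sketch below assumes the functional helper prune.l1_unstructured and reads the module's private _forward_pre_hooks dictionary purely for inspection. It shows the two successive L1Unstructured calls being combined into one container, in the order they were applied.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 3)  # weight has 12 entries

# First call: a single L1Unstructured hook is registered for "weight".
prune.l1_unstructured(layer, name="weight", amount=0.25)
# Second call on the same parameter: the old and new methods are combined
# into a PruningContainer, which replaces the previous hook.
prune.l1_unstructured(layer, name="weight", amount=0.25)

# _forward_pre_hooks is a private attribute, read here only for inspection.
container = next(
    hook for hook in layer._forward_pre_hooks.values()
    if isinstance(hook, prune.PruningContainer)
)
print(len(container))                            # 2
print([type(m).__name__ for m in container])     # methods in application order
# 25% of 12 entries (3) pruned first, then 25% of the 9 survivors (2).
print(int((layer.weight == 0).sum()))            # 5
```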

add_pruning_method(method)

Add a child pruning method to the container.

Parameters

method (subclass of BasePruningMethod) – child pruning method to be added to the container.
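A minimal sketch of how the container grows: once a container exists for a parameter, each further pruning call on that parameter appends its method to the same container, so the method sequence records the full history. The sketch also shows that add_pruning_method rejects objects that are not BasePruningMethod instances with a TypeError. The string passed in is just an arbitrary invalid value, and the hook lookup again relies on the private _forward_pre_hooks attribute.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 3)

prune.l1_unstructured(layer, name="weight", amount=0.25)      # 3 of 12 pruned
prune.random_unstructured(layer, name="weight", amount=0.25)  # container created; 2 of 9
prune.l1_unstructured(layer, name="weight", amount=1)         # int amount: 1 of 7, appended

container = next(
    hook for hook in layer._forward_pre_hooks.values()
    if isinstance(hook, prune.PruningContainer)
)
print([type(m).__name__ for m in container])
# ['L1Unstructured', 'RandomUnstructured', 'L1Unstructured']

# Anything that is not a BasePruningMethod instance is refused.
rejected = False
try:
    container.add_pruning_method("not a pruning method")
except TypeError:
    rejected = True
print(rejected)  # True
```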

classmethod apply(module, name, *args, importance_scores=None, **kwargs)

Add pruning on the fly and reparametrization of a tensor.

Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask.

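Parameters

module (nn.Module) – module containing the tensor to prune

name (str) – parameter name within module on which pruning will act.

args – arguments passed on to a subclass of BasePruningMethod

importance_scores (torch.Tensor) – tensor of importance scores (of the same shape as the module parameter) used to compute the pruning mask. The values in this tensor indicate the importance of the corresponding elements in the parameter being pruned. If unspecified or None, the parameter itself is used in its place.

kwargs – keyword arguments passed on to a subclass of BasePruningMethod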
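A minimal sketch of the reparametrization that apply sets up. It calls apply through the L1Unstructured subclass, whose apply also takes an amount argument, rather than on PruningContainer directly. The weight values and the importance_scores tensor are arbitrary illustrative choices. After the call, the original values live in the parameter weight_orig, the mask in the buffer weight_mask, and the method is registered as a forward pre-hook that recomputes weight from them.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 1, bias=False)
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[0.1, -0.4, 0.3, 0.2]]))

# Illustrative scores: entries are ranked by these values instead of |weight|.
scores = torch.tensor([[0.9, 0.1, 0.8, 0.7]])
method = prune.L1Unstructured.apply(layer, "weight", amount=1, importance_scores=scores)

print(sorted(n for n, _ in layer.named_parameters()))  # ['weight_orig']
print(sorted(n for n, _ in layer.named_buffers()))     # ['weight_mask']
print(layer.weight_mask)  # lowest score (index 1) pruned: tensor([[1., 0., 1., 1.]])
print(method in layer._forward_pre_hooks.values())     # True (private attribute)
```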

apply_mask(module)

Simply handles the multiplication between the parameter being pruned and the generated mask.

Fetches the mask and the original tensor from the module and returns the pruned version of the tensor.

Parameters

module (nn.Module) – module containing the tensor to prune

Returns

pruned version of the input tensor

Return type

torch.Tensor
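A minimal sketch, assuming a tiny layer with hand-picked weights so the masks are easy to follow. Two pruning calls produce a container, and apply_mask then recomputes the pruned tensor from weight_orig and weight_mask. The result matches both an explicit product and the weight attribute that pruning already set on the module.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 1, bias=False)
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[0.1, -0.4, 0.3, 0.2]]))

prune.l1_unstructured(layer, name="weight", amount=1)  # prunes |0.1|
prune.l1_unstructured(layer, name="weight", amount=1)  # then |0.2| among the rest

container = next(
    hook for hook in layer._forward_pre_hooks.values()
    if isinstance(hook, prune.PruningContainer)
)
# apply_mask fetches weight_orig and weight_mask from the module and multiplies them.
pruned = container.apply_mask(layer)
print(layer.weight_mask)                                           # tensor([[0., 1., 1., 0.]])
print(torch.equal(pruned, layer.weight_orig * layer.weight_mask))  # True
```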

compute_mask(t, default_mask)

Apply the latest pruning method by computing a new partial mask and returning its combination with the default_mask.

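The new partial mask should be computed on the entries or channels that were not zeroed out by the default_mask. Which portions of the tensor t the new mask is calculated from depends on the PRUNING_TYPE of the latest method (handled by the type handler): for 'unstructured', the mask is computed from the raveled list of nonmasked entries; for 'structured', from the nonmasked channels in the tensor; for 'global', across all entries.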

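Parameters

t (torch.Tensor) – tensor representing the parameter to prune (of the same dimensions as default_mask).

default_mask (torch.Tensor) – mask from the previous pruning iteration.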

Returns

new mask that combines the effects of the default_mask and the new mask from the current pruning method (of the same dimensions as default_mask and t).

Return type

torch.Tensor
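A minimal sketch of the combination rule, assuming the same hand-picked weights as a toy example. Only the latest method in the container (L1Unstructured with amount=1) runs, and it only sees the entries that are still 1 in default_mask, here -0.4 and 0.3, so it prunes 0.3 and keeps the earlier zeros. The mask is cloned before the call to leave the module's buffer untouched, mirroring how the library's own apply() clones the existing mask before calling compute_mask.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 1, bias=False)
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[0.1, -0.4, 0.3, 0.2]]))

prune.l1_unstructured(layer, name="weight", amount=1)  # mask: [[0, 1, 1, 1]]
prune.l1_unstructured(layer, name="weight", amount=1)  # mask: [[0, 1, 1, 0]]
container = next(
    hook for hook in layer._forward_pre_hooks.values()
    if isinstance(hook, prune.PruningContainer)
)

# Run the latest method once more on the surviving entries only.
default_mask = layer.weight_mask.clone()
new_mask = container.compute_mask(layer.weight_orig.detach(), default_mask=default_mask)
print(new_mask)  # tensor([[0., 1., 0., 0.]])
```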

prune(t, default_mask=None, importance_scores=None)

Compute and return a pruned version of the input tensor t according to the pruning rule specified in compute_mask().

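Parameters

t (torch.Tensor) – tensor to prune (of the same dimensions as default_mask).

importance_scores (torch.Tensor) – tensor of importance scores (of the same shape as t) used to compute the mask for pruning t. The values in this tensor indicate the importance of the corresponding elements of t. If unspecified or None, t itself is used in its place.

default_mask (torch.Tensor, optional) – mask from the previous pruning iteration, if any, used to determine which portion of the tensor pruning should act on. If None, defaults to a mask of ones.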

Returns

pruned version of tensor t.
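A minimal sketch of calling prune() on a container directly, assuming a container obtained from two L1Unstructured calls with amount=1. Because compute_mask only applies the latest method, prune() does not replay the earlier methods. It prunes one more entry from whatever default_mask allows (all ones if omitted), ranked by importance_scores when they are given. The tensors t, the explicit mask, and the scores are arbitrary illustrative values.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 1, bias=False)
prune.l1_unstructured(layer, name="weight", amount=1)
prune.l1_unstructured(layer, name="weight", amount=1)
container = next(
    hook for hook in layer._forward_pre_hooks.values()
    if isinstance(hook, prune.PruningContainer)
)

t = torch.tensor([[0.5, -0.1, 0.4, 0.2]])

# With an explicit default_mask, -0.1 is already masked and the latest method
# prunes the smallest remaining magnitude (0.2).
out = container.prune(t, default_mask=torch.tensor([[1.0, 0.0, 1.0, 1.0]]))

# With importance_scores, entries are ranked by the scores instead of |t|,
# so the entry with score 0.1 (index 0) is pruned.
out_scored = container.prune(t, importance_scores=torch.tensor([[0.1, 9.0, 0.9, 0.8]]))
print(out)
print(out_scored)
```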

remove(module)

Remove the pruning reparameterization from a module.

The pruned tensor called name stays permanently pruned and is registered again as a regular parameter, the parameter called name+'_orig' is removed from the parameter list, and the buffer called name+'_mask' is removed from the buffers.

Note

Pruning itself is NOT undone or reversed!
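A minimal sketch, assuming the functional helper prune.remove(module, name), which looks up the pruning hook for that parameter (here a PruningContainer) and calls its remove(). Afterwards weight is an ordinary parameter again, weight_orig and weight_mask are gone, and no pruning hook remains. The zeros introduced by pruning remain in the tensor.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 1, bias=False)
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[0.1, -0.4, 0.3, 0.2]]))
prune.l1_unstructured(layer, name="weight", amount=1)
prune.l1_unstructured(layer, name="weight", amount=1)  # hook is now a PruningContainer

prune.remove(layer, "weight")

print(sorted(n for n, _ in layer.named_parameters()))  # ['weight']
print(list(layer.named_buffers()))                     # []
print(prune.is_pruned(layer))                          # False
print(layer.weight)  # zeros at indices 0 and 3 are kept: pruning is not undone
```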