LnStructured — PyTorch 2.7 documentation

class torch.nn.utils.prune.LnStructured(amount, n, dim=-1)

Prune entire (currently unpruned) channels in a tensor based on their Ln-norm.
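A minimal usage sketch, assuming a small nn.Linear layer with made-up sizes and a fixed seed. The module-level helper torch.nn.utils.prune.ln_structured applies LnStructured to one named parameter. With amount=0.5, n=2 and dim=0, the two output rows of weight with the smallest L2 norm are zeroed, and the remaining rows keep their original values.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(in_features=3, out_features=4)  # weight shape: (4, 3)

# Prune 50% of the channels along dim 0 (output rows), ranked by L2 norm.
prune.ln_structured(layer, name="weight", amount=0.5, n=2, dim=0)

# A pruned row is entirely zero; the kept rows are copied from weight_orig.
row_is_zero = (layer.weight == 0).all(dim=1)
print(row_is_zero)              # two True entries, two False entries
print(int(row_is_zero.sum()))   # 2
```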

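Parameters

amount (int or float) – quantity of channels to prune. If float, should be between 0.0 and 1.0 and represent the fraction of channels to prune. If int, it represents the absolute number of channels to prune.

n (int, float, inf, -inf, 'fro', 'nuc') – see the documentation of valid entries for argument p in torch.norm().

dim (int, optional) – index of the dim along which channels to prune are defined. Default: -1.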

classmethod apply(module, name, amount, n, dim, importance_scores=None)

Add on-the-fly pruning and reparametrization of a tensor.

Registers a forward pre-hook that recomputes the pruned tensor before every forward pass, and reparametrizes the tensor in terms of the original tensor (kept as the parameter name+'_orig') and the pruning mask (kept as the buffer name+'_mask').
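The sketch below, assuming an nn.Conv2d with arbitrary sizes chosen only for illustration, shows the reparametrization that apply() sets up: the original values move to the parameter weight_orig, the mask is stored in the buffer weight_mask, and the forward pre-hook recomputes weight from the two before each forward pass. The mask itself is not recomputed by the hook, so the pruned channel stays zero even after weight_orig changes.

```python
import torch
import torch.nn as nn
from torch.nn.utils.prune import LnStructured

torch.manual_seed(0)
conv = nn.Conv2d(in_channels=2, out_channels=4, kernel_size=3)  # weight: (4, 2, 3, 3)

# Prune 1 output channel (an int amount is an absolute count), ranked by L1 norm.
method = LnStructured.apply(conv, "weight", amount=1, n=1, dim=0)

# The original values now live in the parameter `weight_orig`, the mask in the
# buffer `weight_mask`, and `weight` is an ordinary (non-parameter) attribute.
param_names = sorted(name for name, _ in conv.named_parameters())
buffer_names = sorted(name for name, _ in conv.named_buffers())
print(param_names)   # ['bias', 'weight_orig']
print(buffer_names)  # ['weight_mask']

# Change the original tensor, then run a forward pass: the registered
# forward pre-hook recomputes `weight` as weight_orig * weight_mask.
with torch.no_grad():
    conv.weight_orig.add_(1.0)
_ = conv(torch.randn(1, 2, 5, 5))
print(torch.equal(conv.weight, conv.weight_orig * conv.weight_mask))  # True
```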

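Parameters

module (nn.Module) – module containing the tensor to prune

name (str) – parameter name within module on which pruning will act.

amount (int or float) – quantity of channels to prune. If float, should be between 0.0 and 1.0 and represent the fraction of channels to prune. If int, it represents the absolute number of channels to prune.

n (int, float, inf, -inf, 'fro', 'nuc') – see the documentation of valid entries for argument p in torch.norm().

dim (int) – index of the dim along which channels to prune are defined.

importance_scores (torch.Tensor) – tensor of importance scores (of the same shape as the module parameter) used to compute the pruning mask. The values in this tensor indicate the importance of the corresponding elements in the parameter being pruned. If unspecified or None, the module parameter is used in its place.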

apply_mask(module)

Multiply the parameter being pruned by the generated mask.

Fetches the mask and the original tensor from the module and returns the pruned version of the tensor.
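A short check of what apply_mask() computes, assuming a toy nn.Linear layer pruned along dim=1 (its input columns). The tensor it returns is weight_orig multiplied by weight_mask, both read from the module.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 3)  # weight shape: (3, 4)

# Prune 1 of the 4 input columns (dim=1), ranked by L2 norm.
method = prune.LnStructured.apply(layer, "weight", amount=1, n=2, dim=1)

# apply_mask fetches weight_orig and weight_mask from the module and multiplies them.
recomputed = method.apply_mask(layer)
print(torch.equal(recomputed, layer.weight_orig * layer.weight_mask))  # True
print((recomputed == 0).all(dim=0))  # exactly one input column is all zeros
```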

Parameters

module (nn.Module) – module containing the tensor to prune

Returns

pruned version of the input tensor

Return type

pruned_tensor (torch.Tensor)

compute_mask(t, default_mask)

Compute and return a mask for the input tensor t.

Starting from a base default_mask (which should be a mask of ones if the tensor has not been pruned yet), generate a mask to apply on top of the default_mask by zeroing out the channels along the specified dim with the lowest Ln-norm.
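The channel-selection rule described above can be sketched with torch.norm and torch.topk. The helper ln_structured_mask is hypothetical, written only to illustrate the steps; it skips the input validation and the amount == 0 special case that the library handles. The final comparison against LnStructured.compute_mask on a random tensor is a sanity check for this example only.

```python
import torch
from torch.nn.utils.prune import LnStructured


def ln_structured_mask(t, default_mask, amount, n, dim):
    """Hypothetical re-implementation of the channel-selection rule."""
    dim = dim % t.dim()  # normalize a negative dim
    num_channels = t.shape[dim]
    # An int amount is a channel count; a float is a fraction of the channels.
    num_prune = amount if isinstance(amount, int) else round(amount * num_channels)
    num_keep = num_channels - num_prune
    # One Ln-norm per channel: reduce over every axis except `dim`.
    other_dims = [d for d in range(t.dim()) if d != dim]
    norms = torch.norm(t, p=n, dim=other_dims)
    # Keeping the largest-norm channels == zeroing the lowest-norm ones.
    keep = torch.topk(norms, k=num_keep, largest=True).indices
    mask = torch.zeros_like(t)
    index = [slice(None)] * t.dim()
    index[dim] = keep
    mask[tuple(index)] = 1
    # Apply on top of default_mask so earlier pruning is respected.
    return mask * default_mask


torch.manual_seed(0)
t = torch.randn(3, 5, 2)
ones = torch.ones_like(t)
expected = LnStructured(amount=0.4, n=2, dim=1).compute_mask(t, ones)
sketch = ln_structured_mask(t, ones, amount=0.4, n=2, dim=1)
print(torch.equal(sketch, expected))  # True
```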

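Parameters

t (torch.Tensor) – tensor representing the parameter to prune

default_mask (torch.Tensor) – base mask from previous pruning iterations that must be respected after the new mask is applied. Same dims as t.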

Returns

mask to apply to t, of same dims as t

Return type

mask (torch.Tensor)

Raises

IndexError – if self.dim >= len(t.shape)
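A minimal sketch of the documented IndexError, assuming a 2-D tensor t: a dim of 2 is out of range for it, so compute_mask raises before computing any norms.

```python
import torch
from torch.nn.utils.prune import LnStructured

t = torch.randn(4, 3)  # 2-D tensor: valid non-negative dims are 0 and 1
method = LnStructured(amount=1, n=2, dim=2)  # dim 2 does not exist for t

raised = False
try:
    method.compute_mask(t, torch.ones_like(t))
except IndexError as err:
    raised = True
    print("IndexError:", err)
```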

prune(t, default_mask=None, importance_scores=None)

Compute and return a pruned version of the input tensor t, according to the pruning rule specified in compute_mask().
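A worked example of prune() on a small hand-written tensor (values chosen for illustration). The row L1 norms are 4, 2, 3 and 4, so with amount=0.25 exactly one of the four rows is pruned: row 1, whose norm is the smallest. Passing a default_mask from a hypothetical earlier round shows that entries already zeroed in it stay zero, because the new mask is multiplied by default_mask.

```python
import torch
from torch.nn.utils.prune import LnStructured

t = torch.tensor([[4.0, 0.0],
                  [1.0, 1.0],
                  [3.0, 0.0],
                  [2.0, 2.0]])
method = LnStructured(amount=0.25, n=1, dim=0)  # round(0.25 * 4) = 1 row

# Without a default_mask, a mask of ones is assumed.
pruned = method.prune(t)
print(pruned)  # row 1 is zeroed

# Hypothetical mask from an earlier round in which column 1 was removed.
previous = torch.tensor([[1.0, 0.0]]).repeat(4, 1)
pruned_again = method.prune(t, default_mask=previous)
print(pruned_again)  # row 1 and column 1 are zeroed
```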

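Parameters

t (torch.Tensor) – tensor to prune (of the same dimensions as default_mask).

default_mask (torch.Tensor, optional) – mask from the previous pruning iteration, if any, to be considered when determining which portion of the tensor pruning should act on. If None, defaults to a mask of ones.

importance_scores (torch.Tensor) – tensor of importance scores (of the same shape as t) used to compute the mask for pruning t. The values in this tensor indicate the importance of the corresponding elements of t. If unspecified or None, t itself is used in its place.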

Returns

pruned version of tensor t.
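prune() also accepts importance_scores, which take the place of t when ranking channels, while the returned tensor is still t multiplied by the mask. In this sketch the scores are made-up values that disagree with the magnitudes in t, so the two calls prune different rows.

```python
import torch
from torch.nn.utils.prune import LnStructured

t = torch.tensor([[1.0, 1.0],
                  [5.0, 5.0],
                  [3.0, 3.0]])
# Hypothetical external scores: they rank row 1 as least important,
# even though row 1 has the largest norm in `t` itself.
scores = torch.tensor([[9.0, 9.0],
                       [0.1, 0.1],
                       [4.0, 4.0]])

method = LnStructured(amount=1, n=2, dim=0)
by_weight = method.prune(t)                            # ranks rows using t
by_scores = method.prune(t, importance_scores=scores)  # ranks rows using scores
print(by_weight)  # row 0 zeroed
print(by_scores)  # row 1 zeroed
```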

remove(module)

Remove the pruning reparameterization from a module.

The pruned parameter named name remains permanently pruned, and the parameter named name+'_orig' is removed from the parameter list. Similarly, the buffer named name+'_mask' is removed from the buffers.

Note

Pruning itself is NOT undone or reversed!
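A minimal sketch of making pruning permanent, assuming a toy nn.Linear layer. The module-level helper torch.nn.utils.prune.remove(module, name) calls this remove() method and also unregisters the pruning forward pre-hook. Afterwards weight is an ordinary parameter again, the mask buffer is gone, and the pruned rows are still zero.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(3, 4)
prune.ln_structured(layer, "weight", amount=2, n=2, dim=0)  # prune 2 of 4 rows
zero_rows_before = (layer.weight == 0).all(dim=1)

# Make the pruning permanent and drop the reparametrization.
prune.remove(layer, "weight")

print(sorted(name for name, _ in layer.named_parameters()))  # ['bias', 'weight']
print(list(layer.named_buffers()))                           # []
print(prune.is_pruned(layer))                                # False
# The zeros stay: removing the reparametrization does not undo pruning.
print(torch.equal((layer.weight == 0).all(dim=1), zero_rows_before))  # True
```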