ModelPruning — PyTorch Lightning 2.5.1.post0 documentation
class lightning.pytorch.callbacks.ModelPruning(pruning_fn, parameters_to_prune=(), parameter_names=None, use_global_unstructured=True, amount=0.5, apply_pruning=True, make_pruning_permanent=True, use_lottery_ticket_hypothesis=True, resample_parameters=False, pruning_dim=None, pruning_norm=None, verbose=0, prune_on_train_epoch_end=True)
Bases: Callback
Model pruning Callback, using PyTorch's prune utilities. This callback is responsible for pruning network parameters during training.
To learn more about pruning with PyTorch, please take a look at this tutorial.
```python
parameters_to_prune = [(model.mlp_1, "weight"), (model.mlp_2, "weight")]

trainer = Trainer(
    callbacks=[
        ModelPruning(
            pruning_fn="l1_unstructured",
            parameters_to_prune=parameters_to_prune,
            amount=0.01,
            use_global_unstructured=True,
        )
    ]
)
```
When `parameters_to_prune` is `None`, `parameters_to_prune` will contain all parameters from the model. The user can override `filter_parameters_to_prune` to filter any `nn.Module` to be pruned.
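For example, you could subclass the callback so that only `nn.Linear` layers are pruned; a minimal sketch (the `LinearOnlyPruning` name is illustrative, not part of the API):

```python
import torch.nn as nn

from lightning.pytorch.callbacks import ModelPruning


class LinearOnlyPruning(ModelPruning):
    def filter_parameters_to_prune(self, parameters_to_prune=()):
        # Keep only the (module, parameter_name) pairs that belong to nn.Linear layers.
        return [(m, name) for m, name in parameters_to_prune if isinstance(m, nn.Linear)]
```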
Parameters:
- `pruning_fn` (Union[Callable, str]) – Function from the `torch.nn.utils.prune` module or your own PyTorch `BasePruningMethod` subclass. Can also be a string, e.g. `"l1_unstructured"`. See the PyTorch docs for more details.
- `parameters_to_prune` (Sequence[tuple[Module, str]]) – List of tuples `(nn.Module, "parameter_name_string")`.
- `parameter_names` (Optional[list[str]]) – List of parameter names to be pruned from the `nn.Module`. Can either be `"weight"` or `"bias"`.
- `use_global_unstructured` (bool) – Whether to apply pruning globally on the model. If `parameters_to_prune` is provided, global unstructured pruning will be restricted to them.
- `amount` (Union[int, float, Callable[[int], Union[int, float]]]) – Quantity of parameters to prune:
  - `float`. Between 0.0 and 1.0. Represents the fraction of parameters to prune.
  - `int`. Represents the absolute number of parameters to prune.
  - `Callable`. For dynamic values. Will be called every epoch. Should return a value.
- `apply_pruning` (Union[bool, Callable[[int], bool]]) – Whether to apply pruning:
  - `bool`. Always apply it or not.
  - `Callable[[epoch], bool]`. For dynamic values. Will be called every epoch.
- `make_pruning_permanent` (bool) – Whether to remove all reparametrization pre-hooks and apply masks when training ends or the model is saved.
- `use_lottery_ticket_hypothesis` (Union[bool, Callable[[int], bool]]) – See The lottery ticket hypothesis:
  - `bool`. Whether to apply it or not.
  - `Callable[[epoch], bool]`. For dynamic values. Will be called every epoch.
- `resample_parameters` (bool) – Used with `use_lottery_ticket_hypothesis`. If `True`, the model parameters will be resampled; otherwise, the exact original parameters will be used.
- `pruning_dim` (Optional[int]) – If you are using a structured pruning method, you need to specify the dimension.
- `pruning_norm` (Optional[int]) – If you are using `ln_structured`, you need to specify the norm.
- `verbose` (int) – Verbosity level. 0 to disable, 1 to log overall sparsity, 2 to log per-layer sparsity.
- `prune_on_train_epoch_end` (bool) – Whether to apply pruning at the end of the training epoch. If this is `False`, then the check runs at the end of the validation epoch.
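Because `amount` and `apply_pruning` accept callables that receive the current epoch, a pruning schedule can be expressed directly; a minimal sketch (the schedule values are illustrative):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelPruning


def compute_amount(epoch):
    # Returning 0 (or None) for an epoch means no pruning is applied that epoch.
    if epoch == 10:
        return 0.5  # prune half of the parameters at epoch 10
    if epoch == 50:
        return 0.25
    if 75 < epoch < 99:
        return 0.01
    return 0.0


trainer = Trainer(
    callbacks=[ModelPruning(pruning_fn="l1_unstructured", amount=compute_amount)]
)
```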
Raises:
MisconfigurationException – If `parameter_names` is neither `"weight"` nor `"bias"`, if the provided `pruning_fn` is not supported, if `pruning_dim` is not provided when using a structured pruning method, if `pruning_norm` is not provided when using `"ln_structured"`, if `pruning_fn` is neither a `str` nor a `torch.nn.utils.prune.BasePruningMethod`, or if `amount` is none of `int`, `float`, or `Callable`.
apply_lottery_ticket_hypothesis()
Lottery ticket hypothesis algorithm (see page 2 of the paper):
1. Randomly initialize a neural network \(f(x; \theta_0)\) (where \(\theta_0 \sim \mathcal{D}_\theta\)).
2. Train the network for \(j\) iterations, arriving at parameters \(\theta_j\).
3. Prune \(p\%\) of the parameters in \(\theta_j\), creating a mask \(m\).
4. Reset the remaining parameters to their values in \(\theta_0\), creating the winning ticket \(f(x; m \odot \theta_0)\).
This function implements step 4. The `resample_parameters` argument can be used to reset the parameters with a new \(\theta_z \sim \mathcal{D}_\theta\).
Return type: None
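For example, step 4 can be enabled after each pruning pass with freshly drawn weights rather than the original \(\theta_0\); a minimal sketch (the argument values are illustrative):

```python
from lightning.pytorch.callbacks import ModelPruning

# Prune 20% of the weights each epoch, then reset the surviving weights.
# With resample_parameters=True, the reset uses a new draw theta_z ~ D_theta
# instead of the original initialization theta_0.
pruning = ModelPruning(
    pruning_fn="l1_unstructured",
    amount=0.2,
    use_lottery_ticket_hypothesis=True,
    resample_parameters=True,
)
```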
apply_pruning(amount)
Applies pruning to `parameters_to_prune`.
Return type: None
filter_parameters_to_prune(parameters_to_prune=())
This function can be overridden to control which modules to prune.
Return type: Sequence[tuple[Module, str]]
make_pruning_permanent(module)
Removes pruning buffers from any pruned modules.
Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/utils/prune.py#L1118-L1122
Return type: None
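While a parameter is pruned, PyTorch stores the original tensor as `weight_orig` together with a `weight_mask` buffer and recomputes `weight` in a forward pre-hook; making pruning permanent bakes the mask into `weight` and removes those buffers. A small sketch of checking this after training (reusing `model.mlp_1` from the example above):

```python
import torch.nn.utils.prune as prune

# After fitting with make_pruning_permanent=True, the reparametrization is gone:
assert not prune.is_pruned(model.mlp_1)                        # no pruning pre-hooks left
assert "weight_mask" not in dict(model.mlp_1.named_buffers())  # mask buffer removed
```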
on_save_checkpoint(trainer, pl_module, checkpoint)
Called when saving a checkpoint to give you a chance to store anything else you might want to save.
Parameters:
- `trainer` (Trainer) – the current Trainer instance.
- `pl_module` (LightningModule) – the current LightningModule instance.
- `checkpoint` (dict[str, Any]) – the checkpoint dictionary that will be saved.
Return type: None
on_train_end(trainer, pl_module)
Called when the train ends.
Return type: None
on_train_epoch_end(trainer, pl_module)
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the `lightning.pytorch.core.LightningModule` and access them in this hook:

```python
class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()
```

Return type: None
on_validation_epoch_end(trainer, pl_module)
Called when the val epoch ends.
Return type: None
static sanitize_parameters_to_prune(pl_module, parameters_to_prune=(), parameter_names=())
This function is responsible for sanitizing `parameters_to_prune` and `parameter_names`. If `parameters_to_prune` is `None`, it will be generated with all parameters of the model.
Raises:
MisconfigurationException – If `parameters_to_prune` doesn't exist in the model, or if `parameters_to_prune` is neither a list nor a tuple.
Return type: list[tuple[Module, str]]
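A small usage sketch of this static helper, expanding a `parameter_names` filter into explicit `(module, name)` pairs (assuming `model` is a `LightningModule`):

```python
from lightning.pytorch.callbacks import ModelPruning

# Build the full (module, "weight") list for every module that has a weight parameter.
parameters_to_prune = ModelPruning.sanitize_parameters_to_prune(
    model, parameters_to_prune=(), parameter_names=["weight"]
)
```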
setup(trainer, pl_module, stage)
Called when fit, validate, test, predict, or tune begins.
Return type: None