ModelPruning — PyTorch Lightning 2.5.1.post0 documentation

class lightning.pytorch.callbacks.ModelPruning(pruning_fn, parameters_to_prune=(), parameter_names=None, use_global_unstructured=True, amount=0.5, apply_pruning=True, make_pruning_permanent=True, use_lottery_ticket_hypothesis=True, resample_parameters=False, pruning_dim=None, pruning_norm=None, verbose=0, prune_on_train_epoch_end=True)[source]

Bases: Callback

Model pruning Callback, using PyTorch's prune utilities. This callback is responsible for pruning network parameters during training.

To learn more about pruning with PyTorch, please take a look at this tutorial.

parameters_to_prune = [(model.mlp_1, "weight"), (model.mlp_2, "weight")]

trainer = Trainer(
    callbacks=[
        ModelPruning(
            pruning_fn="l1_unstructured",
            parameters_to_prune=parameters_to_prune,
            amount=0.01,
            use_global_unstructured=True,
        )
    ]
)

When parameters_to_prune is None, it will be populated with all parameters from the model. The user can override filter_parameters_to_prune to filter which nn.Module instances are pruned.
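As a minimal sketch of that behaviour (assuming the imports below): passing parameters_to_prune=None prunes every parameter matched by parameter_names, which defaults to "weight" and "bias":

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelPruning

# Prune 1% of all "weight"/"bias" parameters found in the model.
trainer = Trainer(
    callbacks=[
        ModelPruning(
            pruning_fn="l1_unstructured",
            parameters_to_prune=None,  # None => collect all parameters
            amount=0.01,
        )
    ]
)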

Parameters:

Raises:

MisconfigurationException – If parameter_names is neither "weight" nor "bias", if the provided pruning_fn is not supported, if pruning_dim is not provided when using a structured pruning function, if pruning_norm is not provided when using "ln_structured", if pruning_fn is neither a str nor a torch.nn.utils.prune.BasePruningMethod, or if amount is none of int, float, or Callable.
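A hedged sketch of the structured case referenced above: "ln_structured" requires both pruning_dim and pruning_norm, and structured pruning functions cannot be combined with use_global_unstructured=True:

from lightning.pytorch.callbacks import ModelPruning

# Structured pruning along dim 0 using the L2 norm; global unstructured
# pruning is disabled because the pruning function is structured.
pruning = ModelPruning(
    pruning_fn="ln_structured",
    use_global_unstructured=False,
    pruning_dim=0,
    pruning_norm=2,
    amount=0.5,
)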

apply_lottery_ticket_hypothesis()[source]

Lottery ticket hypothesis algorithm (see page 2 of the paper):

  1. Randomly initialize a neural network \(f(x; \theta_0)\) (where \(\theta_0 \sim \mathcal{D}_\theta\)).
  2. Train the network for \(j\) iterations, arriving at parameters \(\theta_j\).
  3. Prune \(p\%\) of the parameters in \(\theta_j\), creating a mask \(m\).
  4. Reset the remaining parameters to their values in \(\theta_0\), creating the winning ticket \(f(x; m \odot \theta_0)\).

This function implements step 4.

The resample_parameters argument can be used to reset the parameters with a new \(\theta_z \sim \mathcal{D}_\theta\).

Return type:

None
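A sketch of enabling this behaviour through the constructor (arguments as in the class signature above):

from lightning.pytorch.callbacks import ModelPruning

# After each pruning step, reset the surviving weights to their original
# initialization, or to a fresh resample of it via resample_parameters.
pruning = ModelPruning(
    pruning_fn="l1_unstructured",
    amount=0.2,
    use_lottery_ticket_hypothesis=True,
    resample_parameters=True,
)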

apply_pruning(amount)[source]

Applies pruning to parameters_to_prune.

Return type:

None
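The amount passed here is resolved from the constructor's amount argument, which per the Raises note above may also be a Callable. A hedged sketch, assuming the callable receives the current epoch and returns the amount to prune:

from lightning.pytorch.callbacks import ModelPruning

def amount_schedule(epoch):
    # Hypothetical schedule: prune gently early on, more aggressively later.
    return 0.05 if epoch < 10 else 0.25

pruning = ModelPruning(pruning_fn="l1_unstructured", amount=amount_schedule)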

filter_parameters_to_prune(parameters_to_prune=())[source]

This function can be overridden to control which modules to prune.

Return type:

Sequence[tuple[Module, str]]
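A minimal sketch of such an override, keeping only torch.nn.Linear modules (the subclass name is illustrative):

import torch.nn as nn
from lightning.pytorch.callbacks import ModelPruning

class LinearOnlyPruning(ModelPruning):
    def filter_parameters_to_prune(self, parameters_to_prune=()):
        # Keep only (module, parameter_name) pairs belonging to Linear layers.
        return [(module, name) for module, name in parameters_to_prune if isinstance(module, nn.Linear)]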

make_pruning_permanent(module)[source]

Removes pruning buffers from any pruned modules.

Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/utils/prune.py#L1118-L1122

Return type:

None
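A sketch of invoking this manually, e.g. when the callback was created with make_pruning_permanent=False:

from lightning.pytorch.callbacks import ModelPruning

pruning = ModelPruning(pruning_fn="l1_unstructured", make_pruning_permanent=False)
# ... after training; `model` is assumed to be the trained module.
# Bake the masks into the weights and drop the weight_orig / weight_mask
# re-parametrization buffers left behind by torch.nn.utils.prune.
pruning.make_pruning_permanent(model)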

on_save_checkpoint(trainer, pl_module, checkpoint)[source]

Called when saving a checkpoint to give you a chance to store anything else you might want to save.

Parameters:

  * trainer – the current Trainer instance.
  * pl_module – the current LightningModule instance.
  * checkpoint – the checkpoint dictionary that will be saved.

Return type:

None

on_train_end(trainer, pl_module)[source]

Called when training ends.

Return type:

None

on_train_epoch_end(trainer, pl_module)[source]

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

Return type:

None

on_validation_epoch_end(trainer, pl_module)[source]

Called when the validation epoch ends.

Return type:

None

static sanitize_parameters_to_prune(pl_module, parameters_to_prune=(), parameter_names=())[source]

This function is responsible for sanitizing parameters_to_prune and parameter_names. If parameters_to_prune is None, it will be generated with all parameters of the model.

Raises:

MisconfigurationException – If parameters_to_prune doesn’t exist in the model, or if parameters_to_prune is neither a list nor a tuple.

Return type:

Sequence[tuple[Module, str]]
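A sketch of calling the static helper directly (pl_module is an assumed LightningModule instance):

from lightning.pytorch.callbacks import ModelPruning

# Returns a sequence of (module, parameter_name) tuples covering every
# "weight" parameter in the model when parameters_to_prune is None.
pairs = ModelPruning.sanitize_parameters_to_prune(
    pl_module,
    parameters_to_prune=None,
    parameter_names=("weight",),
)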

setup(trainer, pl_module, stage)[source]

Called when fit, validate, test, predict, or tune begins.

Return type:

None