StochasticWeightAveraging — PyTorch Lightning 2.5.1.post0 documentation

class lightning.pytorch.callbacks.StochasticWeightAveraging(swa_lrs, swa_epoch_start=0.8, annealing_epochs=10, annealing_strategy='cos', avg_fn=None, device=device(type='cpu'))[source]

Bases: Callback

Implements the Stochastic Weight Averaging (SWA) Callback to average a model.

Stochastic Weight Averaging was proposed in Averaging Weights Leads to Wider Optima and Better Generalization by Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson (UAI 2018).

This documentation is adapted from PyTorch’s work on SWA. The callback arguments follow the scheme defined in PyTorch’s swa_utils package.

For an explanation of SWA, please take a look here.

Warning

StochasticWeightAveraging is currently not supported for multiple optimizers/schedulers.

Warning

StochasticWeightAveraging is currently only supported on an epoch basis; the weight average is updated once per epoch, not per step.

See also how to enable it directly on the Trainer.
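
A minimal usage sketch: the callback is passed to the Trainer like any other callback. MyLightningModule and the hyperparameter values here are illustrative assumptions, not part of this API:

    import lightning as L
    from lightning.pytorch.callbacks import StochasticWeightAveraging

    # MyLightningModule is a placeholder for any LightningModule you define.
    model = MyLightningModule()

    # With the default swa_epoch_start=0.8, averaging starts in the last
    # 20% of training while the learning rate anneals towards swa_lrs.
    trainer = L.Trainer(
        max_epochs=20,
        callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)],
    )
    trainer.fit(model)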

Parameters:

swa_lrs (float or list[float]) – the SWA learning rate to use. A float value is used for all parameter groups of the optimizer; a list gives one value per parameter group.

swa_epoch_start (int or float) – if an int, the SWA procedure starts from that epoch. If a float between 0 and 1, the procedure starts at int(swa_epoch_start * max_epochs) (default: 0.8).

annealing_epochs (int) – number of epochs in the annealing phase (default: 10).

annealing_strategy (str) – the annealing strategy: "cos" for cosine annealing, "linear" for linear annealing (default: "cos").

avg_fn (callable, optional) – the averaging function used to update the parameters. It receives the current averaged parameter, the current model parameter, and the number of models averaged so far. If None, an equally weighted running average is used (default: None).

device (torch.device or str, optional) – if provided, the averaged model is stored on that device. If None, the device is inferred from pl_module (default: torch.device("cpu")).

static avg_fn(averaged_model_parameter, model_parameter, num_averaged)[source]

Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97.

Return type:

Tensor
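
The default averaging function keeps a running mean of the weights seen so far. A sketch equivalent to the referenced swa_utils code:

    from torch import Tensor

    # Running mean over the models averaged so far:
    # new_avg = avg + (param - avg) / (num_averaged + 1)
    def avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: int) -> Tensor:
        return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)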

load_state_dict(state_dict)[source]

Called when loading a checkpoint, implement to reload callback state given callback’s state_dict.

Parameters:

state_dict (dict[str, Any]) – the callback state returned by state_dict.

Return type:

None

on_fit_start(trainer, pl_module)[source]

Called when fit begins.

Return type:

None

on_train_end(trainer, pl_module)[source]

Called when the train ends.

Return type:

None

on_train_epoch_end(trainer, *args)[source]

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

    import lightning as L
    import torch


    class MyLightningModule(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.training_step_outputs = []

        def training_step(self, batch, batch_idx):
            loss = ...
            self.training_step_outputs.append(loss)
            return loss


    class MyCallback(L.Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            # do something with all training_step outputs, for example:
            epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
            pl_module.log("training_epoch_mean", epoch_mean)
            # free up the memory
            pl_module.training_step_outputs.clear()

Return type:

None

on_train_epoch_start(trainer, pl_module)[source]

Called when the train epoch begins.

Return type:

None

reset_batch_norm_and_save_state(pl_module)[source]

Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154.

Return type:

None
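
A rough sketch of what the referenced swa_utils code does, under the assumption that the Lightning version follows it closely: zero out each BatchNorm module’s running statistics and stash its momentum, so a subsequent forward pass can re-estimate the statistics for the averaged weights:

    import torch

    def reset_batch_norm_and_save_state(model: torch.nn.Module) -> dict:
        # Stash each BatchNorm momentum and reset its running statistics;
        # momentum=None makes BatchNorm use a cumulative moving average.
        momenta = {}
        for module in model.modules():
            if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                module.running_mean = torch.zeros_like(module.running_mean)
                module.running_var = torch.ones_like(module.running_var)
                momenta[module] = module.momentum
                module.momentum = None
                module.num_batches_tracked *= 0
        return momenta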

reset_momenta()[source]

Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165.

Return type:

None

setup(trainer, pl_module, stage)[source]

Called when fit, validate, test, predict, or tune begins.

Return type:

None

state_dict()[source]

Called when saving a checkpoint, implement to generate callback’s state_dict.

Return type:

dict[str, Any]

Returns:

A dictionary containing callback state.
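
state_dict and load_state_dict form a round-trip contract: whatever state_dict returns at checkpoint time is handed back to load_state_dict on restore. A minimal hypothetical sketch for a custom callback (CounterCallback and batches_seen are illustrative names):

    import lightning as L

    class CounterCallback(L.Callback):
        def __init__(self):
            self.batches_seen = 0

        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
            self.batches_seen += 1

        def state_dict(self):
            # saved into the checkpoint under this callback's key
            return {"batches_seen": self.batches_seen}

        def load_state_dict(self, state_dict):
            # restored from the checkpoint when training resumes
            self.batches_seen = state_dict["batches_seen"]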

static update_parameters(average_model, model, n_averaged, avg_fn)[source]

Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112.

Return type:

None
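
A sketch of the referenced swa_utils logic, assuming n_averaged is a tensor counter: copy the weights verbatim on the first update, then fold each new snapshot into the running average via avg_fn:

    import torch

    def update_parameters(average_model, model, n_averaged, avg_fn):
        for p_swa, p_model in zip(average_model.parameters(), model.parameters()):
            p_swa_ = p_swa.detach()
            p_model_ = p_model.detach().to(p_swa_.device)
            if n_averaged == 0:
                # first snapshot: the average is just the current weights
                p_swa_.copy_(p_model_)
            else:
                # fold the new weights into the running average
                p_swa_.copy_(avg_fn(p_swa_, p_model_, n_averaged.to(p_swa_.device)))
        n_averaged += 1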