GradientAccumulationScheduler — PyTorch Lightning 2.5.1.post0 documentation

class lightning.pytorch.callbacks.GradientAccumulationScheduler(scheduling)[source]

Bases: Callback

Change gradient accumulation factor according to scheduling.

Parameters:

scheduling (dict[int, int]) – scheduling in format {epoch: accumulation_factor}

Note

The argument scheduling is a dictionary. Each key represents an epoch and its value is the associated accumulation factor. Warning: epochs are zero-indexed, i.e. if you want to change the accumulation factor after 4 epochs, set Trainer(accumulate_grad_batches={4: factor}) or GradientAccumulationScheduler(scheduling={4: factor}). For more info, check the example below.

Raises:

TypeError – If scheduling is an empty dict, or if not all keys and values of scheduling are integers.

IndexError – If the smallest epoch key in scheduling is less than 0.

Example:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import GradientAccumulationScheduler

# from epoch 5, it starts accumulating every 2 batches. Here we have 4 instead of 5
# because epoch (key) should be zero-indexed.
accumulator = GradientAccumulationScheduler(scheduling={4: 2})
trainer = Trainer(callbacks=[accumulator])
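To make the scheduling semantics concrete, here is a minimal sketch of how the per-epoch factor resolves. The helper resolve_factor is hypothetical (not part of the Lightning API); it assumes the factor from the largest scheduled epoch key at or below the current epoch applies, with epochs before the first key defaulting to a factor of 1:

```python
def resolve_factor(scheduling: dict[int, int], epoch: int) -> int:
    """Hypothetical helper mirroring the callback's lookup:
    the factor from the largest key <= epoch applies; epochs
    before the first key fall back to a factor of 1."""
    applicable = [e for e in scheduling if e <= epoch]
    return scheduling[max(applicable)] if applicable else 1

# With scheduling={4: 2}: epochs 0-3 accumulate every batch (factor 1),
# epoch 4 onward accumulate every 2 batches.
print([resolve_factor({4: 2}, e) for e in range(6)])  # → [1, 1, 1, 1, 2, 2]
```

A multi-stage schedule works the same way: scheduling={0: 8, 4: 4, 8: 1} would accumulate 8 batches for epochs 0-3, 4 batches for epochs 4-7, and stop accumulating from epoch 8.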

on_train_epoch_start(trainer, *_)[source]

Called when the train epoch begins.

Return type:

None

on_train_start(trainer, pl_module)[source]

Performs a configuration validation before training starts and raises errors for incompatible settings.

Return type:

None