GradientAccumulationScheduler — PyTorch Lightning 2.5.1.post0 documentation
class lightning.pytorch.callbacks.GradientAccumulationScheduler(scheduling)[source]
Bases: Callback
Change gradient accumulation factor according to scheduling.
Parameters:
scheduling¶ (dict[int, int]) – scheduling in format {epoch: accumulation_factor}
Note
The argument scheduling is a dictionary. Each key represents an epoch and its value is the associated accumulation factor. Warning: epochs are zero-indexed, i.e. if you want to change the accumulation factor after 4 epochs, set Trainer(accumulate_grad_batches={4: factor}) or GradientAccumulationScheduler(scheduling={4: factor}). For more info, check the example below.
Raises:
- TypeError – If scheduling is an empty dict, or not all keys and values of scheduling are integers.
- IndexError – If minimal_epoch is less than 0.
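For instance, based on the constraints above, an empty dictionary or non-integer keys/values are rejected at construction time. A minimal sketch:

from lightning.pytorch.callbacks import GradientAccumulationScheduler

# An empty schedule cannot be interpreted and is rejected.
try:
    GradientAccumulationScheduler(scheduling={})
except TypeError as err:
    print(err)

# Non-integer keys or values are rejected as well.
try:
    GradientAccumulationScheduler(scheduling={"4": 2.0})
except TypeError as err:
    print(err)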
Example:
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import GradientAccumulationScheduler

# from epoch 5, it starts accumulating every 2 batches. Here we have 4 instead of 5
# because epoch (key) should be zero-indexed.
accumulator = GradientAccumulationScheduler(scheduling={4: 2})
trainer = Trainer(callbacks=[accumulator])
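A schedule may also contain several milestones; each factor stays in effect until the next scheduled epoch. The epochs and factors below are illustrative:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import GradientAccumulationScheduler

# epochs 0-3: no accumulation, epochs 4-7: accumulate every 2 batches,
# epoch 8 onwards: accumulate every 4 batches.
accumulator = GradientAccumulationScheduler(scheduling={0: 1, 4: 2, 8: 4})
trainer = Trainer(callbacks=[accumulator])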
on_train_epoch_start(trainer, *_)[source]
Called when the train epoch begins.
Return type: None
on_train_start(trainer, pl_module)[source]
Performs a configuration validation before training starts and raises errors for incompatible settings.
Return type: None
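Conceptually, at each epoch start the callback resolves the factor attached to the largest scheduled epoch that does not exceed the current epoch and applies it to the trainer. A rough, hypothetical sketch of that lookup (not the library's actual code; get_active_factor is an illustrative helper):

def get_active_factor(scheduling, current_epoch):
    # Pick the factor of the largest scheduled epoch <= current_epoch.
    # Assumes a factor of 1 before the first scheduled epoch.
    active = 1
    for epoch in sorted(scheduling):
        if epoch <= current_epoch:
            active = scheduling[epoch]
    return active

# With scheduling={4: 2}: factor 1 for epochs 0-3, factor 2 from epoch 4 onwards.
print([get_active_factor({4: 2}, e) for e in range(6)])  # [1, 1, 1, 1, 2, 2]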