EarlyStopping — PyTorch Lightning 2.5.1.post0 documentation
class lightning.pytorch.callbacks.EarlyStopping(monitor, min_delta=0.0, patience=3, verbose=False, mode='min', strict=True, check_finite=True, stopping_threshold=None, divergence_threshold=None, check_on_train_epoch_end=None, log_rank_zero_only=False)
Bases: Callback
Monitor a metric and stop training when it stops improving.
Parameters:
- monitor (str) – quantity to be monitored.
- min_delta (float) – minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than or equal to min_delta will count as no improvement.
- patience (int) – number of checks with no improvement after which training will be stopped. Under the default configuration, one check happens after every training epoch. However, the frequency of validation can be modified by setting various parameters on the Trainer, for example check_val_every_n_epoch and val_check_interval.

  Note
  The patience parameter counts the number of validation checks with no improvement, not the number of training epochs. Therefore, with parameters check_val_every_n_epoch=10 and patience=3, the trainer will perform at least 40 training epochs before being stopped.
- verbose (bool) – verbosity mode.
- mode (str) – one of 'min', 'max'. In 'min' mode, training will stop when the quantity monitored has stopped decreasing and in 'max' mode it will stop when the quantity monitored has stopped increasing.
- strict (bool) – whether to crash the training if monitor is not found in the validation metrics.
- check_finite (bool) – When set to True, stops training when the monitor becomes NaN or infinite.
- stopping_threshold (Optional[float]) – Stop training immediately once the monitored quantity reaches this threshold.
- divergence_threshold (Optional[float]) – Stop training as soon as the monitored quantity becomes worse than this threshold.
- check_on_train_epoch_end (Optional[bool]) – whether to run early stopping at the end of the training epoch. If this is False, then the check runs at the end of the validation.
- log_rank_zero_only (bool) – When set to True, logs the status of the early stopping callback only for rank 0 process.
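The note on patience can be checked with a little arithmetic. The sketch below is only an illustration: the helper name `min_epochs_before_stop` is hypothetical, and it assumes that the first check establishes the baseline score, that no later check improves on it, and that no other stopping condition (thresholds, non-finite values) triggers first.

```python
def min_epochs_before_stop(check_val_every_n_epoch: int, patience: int) -> int:
    """Lower bound on training epochs before patience alone can stop training.

    Assumes the first validation check sets the baseline score and every
    later check fails to improve on it.
    """
    # One baseline check, followed by `patience` checks without improvement.
    checks_needed = 1 + patience
    return check_val_every_n_epoch * checks_needed


print(min_epochs_before_stop(check_val_every_n_epoch=10, patience=3))  # 40
print(min_epochs_before_stop(check_val_every_n_epoch=1, patience=3))   # 4
```

Under these assumptions, the default configuration (a check after every epoch, patience=3) cannot stop on patience before the fourth epoch.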
Raises:
- MisconfigurationException – If mode is neither "min" nor "max".
- RuntimeError – If the metric monitor is not available.
Example:
    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import EarlyStopping

    early_stopping = EarlyStopping('val_loss')
    trainer = Trainer(callbacks=[early_stopping])
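To make the interaction between mode, min_delta, patience, check_finite and the two thresholds more concrete, here is a minimal pure-Python sketch of the rules as the parameter descriptions state them. It is not Lightning's implementation: the class name `EarlyStopSketch` is hypothetical, and the order of the checks and the use of strict comparisons for the thresholds are assumptions made for illustration.

```python
import math
from typing import Optional


class EarlyStopSketch:
    """Illustrative stand-in for the stopping rules; not Lightning's code."""

    def __init__(
        self,
        mode: str = "min",
        min_delta: float = 0.0,
        patience: int = 3,
        check_finite: bool = True,
        stopping_threshold: Optional[float] = None,
        divergence_threshold: Optional[float] = None,
    ) -> None:
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.mode = mode
        self.min_delta = abs(min_delta)
        self.patience = patience
        self.check_finite = check_finite
        self.stopping_threshold = stopping_threshold
        self.divergence_threshold = divergence_threshold
        # Start from the worst possible score so the first finite value improves on it.
        self.best_score = math.inf if mode == "min" else -math.inf
        self.wait_count = 0

    def _is_better(self, a: float, b: float) -> bool:
        """Strict comparison in the direction given by `mode` (assumed strict)."""
        return a < b if self.mode == "min" else a > b

    def check(self, current: float) -> bool:
        """Process one check of the monitored value; return True to stop."""
        if self.check_finite and not math.isfinite(current):
            return True  # metric became NaN or infinite
        if self.stopping_threshold is not None and self._is_better(
            current, self.stopping_threshold
        ):
            return True  # target reached
        if self.divergence_threshold is not None and self._is_better(
            self.divergence_threshold, current
        ):
            return True  # metric is worse than the divergence limit
        # An improvement must beat the best score by more than min_delta.
        margin = -self.min_delta if self.mode == "min" else self.min_delta
        if self._is_better(current, self.best_score + margin):
            self.best_score = current
            self.wait_count = 0
            return False
        self.wait_count += 1
        return self.wait_count >= self.patience


stopper = EarlyStopSketch(mode="min", min_delta=0.1, patience=2)
for check_index, val_loss in enumerate([1.0, 0.95, 0.97, 0.5]):
    if stopper.check(val_loss):
        print(f"stop after check {check_index}")  # stop after check 2
        break
```

In this run, 0.95 and 0.97 are both lower than the best score of 1.0, but neither beats it by more than min_delta=0.1, so each counts as a check with no improvement and patience=2 is exhausted before the value 0.5 is ever seen.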
Tip
Saving and restoring multiple early stopping callbacks at the same time is supported under variation in the following arguments: monitor, mode
Read more: Persisting Callback State
load_state_dict(state_dict)
Called when loading a checkpoint, implement to reload callback state given callback’s state_dict.
Parameters:
state_dict (dict[str, Any]) – the callback state returned by state_dict.
Return type:
None
on_train_epoch_end(trainer, pl_module)
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

    import lightning as L
    import torch


    class MyLightningModule(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.training_step_outputs = []

        def training_step(self):
            loss = ...
            self.training_step_outputs.append(loss)
            return loss


    class MyCallback(L.Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            # do something with all training_step outputs, for example:
            epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
            pl_module.log("training_epoch_mean", epoch_mean)
            # free up the memory
            pl_module.training_step_outputs.clear()
Return type:
None
on_validation_end(trainer, pl_module)
Called when the validation loop ends.
Return type:
None
setup(trainer, pl_module, stage)
Called when fit, validate, test, predict, or tune begins.
Return type:
None
state_dict()
Called when saving a checkpoint, implement to generate callback’s state_dict.
Return type:
dict[str, Any]
Returns:
A dictionary containing callback state.
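The two checkpoint hooks, state_dict and load_state_dict, are meant to round-trip a callback's state. The sketch below shows that contract with a plain Python class. `PatienceState` is a hypothetical stand-in that does not subclass Callback, and the keys it stores are chosen for illustration rather than taken from this page.

```python
from typing import Any


class PatienceState:
    """Hypothetical stand-in for a stateful callback; it does not subclass Callback."""

    def __init__(self, patience: int = 3) -> None:
        self.patience = patience
        self.wait_count = 0
        self.best_score = float("inf")

    def state_dict(self) -> dict[str, Any]:
        """Return everything needed to resume, as plain Python values."""
        return {
            "wait_count": self.wait_count,
            "best_score": self.best_score,
            "patience": self.patience,
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore the state previously produced by state_dict()."""
        self.wait_count = state_dict["wait_count"]
        self.best_score = state_dict["best_score"]
        self.patience = state_dict["patience"]


# Simulate saving state mid-training and restoring it into a fresh instance.
original = PatienceState(patience=5)
original.wait_count = 2
original.best_score = 0.31
saved = original.state_dict()

restored = PatienceState()
restored.load_state_dict(saved)
print(restored.state_dict() == saved)  # True
```

Keeping the dictionary limited to plain values makes it straightforward to store inside a checkpoint alongside other callback states.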
property state_key: str
Identifier for the state of the callback.
Used to store and retrieve a callback’s state from the checkpoint dictionary by checkpoint["callbacks"][state_key]. Implementations of a callback need to provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.
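As a rough illustration of why a unique state key matters, the sketch below stores the state of two early stopping configurations side by side in a checkpoint-like dictionary. The helper `make_state_key` and the key format it produces are assumptions for this example; the exact format Lightning uses is not shown on this page. The keys are derived from monitor and mode, the two arguments the tip above names as distinguishing.

```python
from typing import Any


def make_state_key(class_name: str, **init_args: Any) -> str:
    """Hypothetical key format: class name plus the distinguishing arguments."""
    return f"{class_name}{init_args!r}"


# Two instances of the same callback class, differing in monitor and mode.
loss_key = make_state_key("EarlyStopping", monitor="val_loss", mode="min")
acc_key = make_state_key("EarlyStopping", monitor="val_acc", mode="max")

# A checkpoint-like dictionary keeps one entry per state key.
checkpoint = {
    "callbacks": {
        loss_key: {"wait_count": 2},
        acc_key: {"wait_count": 0},
    }
}

# Each instance retrieves its own state without clobbering the other.
print(checkpoint["callbacks"][loss_key])
print(checkpoint["callbacks"][acc_key])
```

If both instances shared a single key, the second entry would overwrite the first and only one state could be restored.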