EarlyStopping — PyTorch Lightning 2.5.1.post0 documentation

class lightning.pytorch.callbacks.EarlyStopping(monitor, min_delta=0.0, patience=3, verbose=False, mode='min', strict=True, check_finite=True, stopping_threshold=None, divergence_threshold=None, check_on_train_epoch_end=None, log_rank_zero_only=False)

Bases: Callback

Monitor a metric and stop training when it stops improving.

Parameters:

Raises:

Example:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping

early_stopping = EarlyStopping('val_loss')
trainer = Trainer(callbacks=[early_stopping])
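The stopping rule can be pictured with a small, framework-free sketch. It is not Lightning's implementation: the class name PatienceTracker and its update method are hypothetical. The sketch assumes only that a check counts as an improvement when the metric beats the best value so far by more than min_delta, and that training stops after patience consecutive checks without an improvement.

```python
import math


class PatienceTracker:
    """Hypothetical helper illustrating patience-based stopping (not a Lightning API)."""

    def __init__(self, min_delta=0.0, patience=3, mode="min"):
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.min_delta = abs(min_delta)
        self.patience = patience
        self.mode = mode
        # Start from the worst possible value so the first check always improves.
        self.best = math.inf if mode == "min" else -math.inf
        self.wait_count = 0

    def update(self, current):
        """Record one metric value and return True if training should stop."""
        if self.mode == "min":
            improved = current < self.best - self.min_delta
        else:
            improved = current > self.best + self.min_delta
        if improved:
            self.best = current
            self.wait_count = 0
        else:
            self.wait_count += 1
        return self.wait_count >= self.patience


tracker = PatienceTracker(min_delta=0.01, patience=3, mode="min")
val_losses = [0.90, 0.70, 0.695, 0.71, 0.70, 0.65]

stopped_at = None
for check, loss in enumerate(val_losses):
    if tracker.update(loss):
        stopped_at = check  # index of the check that triggered the stop
        break
```

With these made-up loss values, the third, fourth, and fifth checks all fail to beat 0.70 by more than 0.01, so the loop stops before it ever sees the final value.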

Tip

Saving and restoring multiple early stopping callbacks at the same time is supported under variation in the following arguments:

monitor, mode

Read more: Persisting Callback State

load_state_dict(state_dict)

Called when loading a checkpoint. Implement this to reload the callback state from the callback’s state_dict.

Parameters:

state_dict (dict[str, Any]) – the callback state returned by state_dict.

Return type:

None

on_train_epoch_end(trainer, pl_module)

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

import torch
import lightning as L

class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss

class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

Return type:

None

on_validation_end(trainer, pl_module)

Called when the validation loop ends.

Return type:

None

setup(trainer, pl_module, stage)

Called when fit, validate, test, predict, or tune begins.

Return type:

None

state_dict()

Called when saving a checkpoint. Implement this to generate the callback’s state_dict.

Return type:

dict[str, Any]

Returns:

A dictionary containing callback state.
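The two checkpoint hooks, state_dict and load_state_dict, form a round trip. The following is a minimal sketch using a hypothetical CounterCallback class that does not inherit from Lightning's Callback. The field names wait_count and best_score are illustrative assumptions, not a statement of what EarlyStopping actually stores.

```python
from typing import Any


class CounterCallback:
    """Hypothetical stand-in for a stateful callback (not a Lightning class)."""

    def __init__(self) -> None:
        self.wait_count = 0
        self.best_score = float("inf")

    def state_dict(self) -> dict[str, Any]:
        # Return everything needed to resume where training left off.
        return {"wait_count": self.wait_count, "best_score": self.best_score}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # Restore the fields produced by state_dict().
        self.wait_count = state_dict["wait_count"]
        self.best_score = state_dict["best_score"]


# Simulate saving state from a callback that has made progress...
original = CounterCallback()
original.wait_count = 2
original.best_score = 0.42
saved = original.state_dict()

# ...and restoring it into a freshly constructed instance.
restored = CounterCallback()
restored.load_state_dict(saved)
```

The design point is symmetry: whatever state_dict returns should be sufficient for load_state_dict to rebuild the same state on a fresh instance.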

property state_key: str

Identifier for the state of the callback.

Used to store and retrieve a callback’s state from the checkpoint dictionary by checkpoint["callbacks"][state_key]. A callback implementation needs to provide a unique state key if 1) the callback has state and 2) the state of multiple instances of that callback must be maintained.
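To illustrate why a unique state key matters, here is a small sketch with plain dictionaries. The make_state_key helper and the key format it produces are hypothetical. The only things taken from the text above are that callback state lives under checkpoint["callbacks"][state_key] and that EarlyStopping instances may differ in monitor and mode.

```python
def make_state_key(class_name: str, **identifying_args: str) -> str:
    """Hypothetical helper: combine a class name with the arguments that
    distinguish one callback instance from another."""
    return f"{class_name}{identifying_args!r}"


key_loss = make_state_key("EarlyStopping", monitor="val_loss", mode="min")
key_acc = make_state_key("EarlyStopping", monitor="val_acc", mode="max")

# Each instance's state is stored under its own key, so neither overwrites the other.
checkpoint = {"callbacks": {}}
checkpoint["callbacks"][key_loss] = {"wait_count": 1}
checkpoint["callbacks"][key_acc] = {"wait_count": 0}
```

If both instances shared one key, the second assignment would silently replace the first instance's state.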