ModelCheckpoint — PyTorch Lightning 2.5.1.post0 documentation

class lightning.pytorch.callbacks.ModelCheckpoint(dirpath=None, filename=None, monitor=None, verbose=False, save_last=None, save_top_k=1, save_weights_only=False, mode='min', auto_insert_metric_name=True, every_n_train_steps=None, train_time_interval=None, every_n_epochs=None, save_on_train_epoch_end=None, enable_version_counter=True)[source]

Bases: Checkpoint

Save the model periodically by monitoring a quantity. Every metric logged with log() or log_dict() is a candidate for the monitor key. For more information, see Checkpointing.

After training finishes, use best_model_path to retrieve the path to the best checkpoint file and best_model_score to retrieve its score.
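
As an end-to-end sketch of that flow (the toy LitModel, the random-tensor loaders, and every hyperparameter below are illustrative assumptions, not part of this API):

import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# hypothetical toy module, only for illustration
class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        val_loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # anything logged like this becomes a valid `monitor` key
        self.log("val_loss", val_loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

def loader():
    x, y = torch.randn(64, 32), torch.randn(64, 1)
    return DataLoader(TensorDataset(x, y), batch_size=16)

checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
trainer = Trainer(max_epochs=2, callbacks=[checkpoint_callback])
trainer.fit(LitModel(), train_dataloaders=loader(), val_dataloaders=loader())

print(checkpoint_callback.best_model_path)   # path to the best checkpoint
print(checkpoint_callback.best_model_score)  # monitored value of that checkpoint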

Parameters:

dirpath – directory to save the model file.

Example:

# custom path
# saves a file like: my/path/epoch=0-step=10.ckpt
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')

By default, dirpath is None and will be set at runtime to the location specified by Trainer’s default_root_dir argument, and if the Trainer uses a logger, the path will also contain logger name and version.

filename – checkpoint filename. Can contain named formatting options to be auto-filled.

Example:

# save any arbitrary metrics like val_loss, etc. in name
# saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
checkpoint_callback = ModelCheckpoint(
    dirpath='my/path',
    filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
)

By default, filename is None and will be set to '{epoch}-{step}', where “epoch” and “step” match the number of finished epochs and optimizer steps respectively.

Note

For extra customization, ModelCheckpoint includes the following attributes:

CHECKPOINT_JOIN_CHAR = "-"
CHECKPOINT_EQUALS_CHAR = "="
CHECKPOINT_NAME_LAST = "last"
FILE_EXTENSION = ".ckpt"
STARTING_VERSION = 1

For example, you can change the default last checkpoint name by doing checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"

If you want to checkpoint every N hours, every M train batches, and/or every K val epochs, then you should create multiple ModelCheckpoint callbacks.
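
For instance, a minimal sketch with two independent callbacks (the paths and intervals below are arbitrary examples):

from datetime import timedelta
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# one callback checkpoints on wall-clock time, the other at every validation epoch
time_ckpt = ModelCheckpoint(dirpath="ckpts/time", train_time_interval=timedelta(hours=1))
epoch_ckpt = ModelCheckpoint(dirpath="ckpts/epoch", monitor="val_loss", every_n_epochs=1)
trainer = Trainer(callbacks=[time_ckpt, epoch_ckpt])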

If the checkpoint’s dirpath has changed from what it was before, then when resuming training only best_model_path will be reloaded and a warning will be issued.

Raises:

MisconfigurationException – If save_top_k is smaller than -1, if monitor is None and save_top_k is none of None, -1, and 0, or if mode is none of "min" or "max".

Example:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# saves checkpoints to 'my/path/' at every epoch
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
trainer = Trainer(callbacks=[checkpoint_callback])

# save epoch and val_loss in name
# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='my/path/',
    filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
)

# save epoch and val_loss in name, but specify the formatting yourself
# (e.g. to avoid problems with Tensorboard or Neptune, due to the presence
# of characters like '=' or '/')
# saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
checkpoint_callback = ModelCheckpoint(
    monitor='val/loss',
    dirpath='my/path/',
    filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',
    auto_insert_metric_name=False
)

# retrieve the best checkpoint after training
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
trainer = Trainer(callbacks=[checkpoint_callback])
model = ...
trainer.fit(model)
checkpoint_callback.best_model_path

Tip

Saving and restoring multiple checkpoint callbacks at the same time is supported, provided the callbacks differ in at least one of the following arguments:

monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval

Read more: Persisting Callback State

file_exists(filepath, trainer)[source]

Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal state from diverging between ranks.

Return type:

bool
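
As a hypothetical sketch of where this is useful, a subclass could call it from a hook (where a trainer is available) so that all ranks take the same branch before writing a file; the ExtraSaveCheckpoint class and file name below are illustrative only:

import os
from lightning.pytorch.callbacks import ModelCheckpoint

class ExtraSaveCheckpoint(ModelCheckpoint):
    def on_validation_end(self, trainer, pl_module):
        super().on_validation_end(trainer, pl_module)
        candidate = os.path.join(self.dirpath or ".", "extra.ckpt")
        # rank-0 existence check, broadcast so every rank agrees on the result
        if not self.file_exists(candidate, trainer):
            trainer.save_checkpoint(candidate)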

format_checkpoint_name(metrics, filename=None, ver=None)[source]

Generate a filename according to the defined template.

Example:

>>> tmpdir = os.path.dirname(__file__)
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=0)))
'epoch=0.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=5)))
'epoch=005.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456)))
'epoch=2-val_loss=0.12.ckpt'
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.12), filename='{epoch:d}'))
'epoch=2.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir,
...     filename='epoch={epoch}-validation_loss={val_loss:.2f}',
...     auto_insert_metric_name=False)
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456)))
'epoch=2-validation_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
>>> os.path.basename(ckpt.format_checkpoint_name({}))
'missing=0.ckpt'
>>> ckpt = ModelCheckpoint(filename='{step}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(step=0)))
'step=0.ckpt'

Return type:

str

load_state_dict(state_dict)[source]

Called when loading a checkpoint, implement to reload callback state given callback’s state_dict.

Parameters:

state_dict (dict[str, Any]) – the callback state returned by state_dict.

Return type:

None

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)[source]

Save a checkpoint at the end of a training batch if the criteria for every_n_train_steps are met.

Return type:

None

on_train_epoch_end(trainer, pl_module)[source]

Save a checkpoint at the end of the training epoch.

Return type:

None

on_train_start(trainer, pl_module)[source]

Called when the train begins.

Return type:

None

on_validation_end(trainer, pl_module)[source]

Save a checkpoint at the end of the validation stage.

Return type:

None

setup(trainer, pl_module, stage)[source]

Called when fit, validate, test, predict, or tune begins.

Return type:

None

state_dict()[source]

Called when saving a checkpoint, implement to generate callback’s state_dict.

Return type:

dict[str, Any]

Returns:

A dictionary containing callback state.
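
A minimal sketch of how state_dict and load_state_dict pair up outside of a Trainer (the paths are arbitrary; normally the Trainer calls both for you when writing and resuming checkpoints):

from lightning.pytorch.callbacks import ModelCheckpoint

ckpt = ModelCheckpoint(dirpath="my/path/", monitor="val_loss")
state = ckpt.state_dict()          # e.g. best_model_path, best_model_score, dirpath, ...

restored = ModelCheckpoint(dirpath="my/path/", monitor="val_loss")
restored.load_state_dict(state)    # restores the tracked best/last paths and scores
# if the dirpath differed from the one stored in `state`, only best_model_path
# would be reloaded and a warning issued (see the note above)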

to_yaml(filepath=None)[source]

Saves the best_k_models dict containing the checkpoint paths with the corresponding scores to a YAML file.

Return type:

None
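
For example (the file name below is chosen arbitrarily), after a finished trainer.fit(...) run with this callback you might dump the tracked checkpoints like this:

# assumes `checkpoint_callback` was attached to a Trainer and training has finished
checkpoint_callback.to_yaml("best_k_models.yaml")
# writes a mapping of checkpoint path -> monitored score for the current top-k models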

property state_key: str

Identifier for the state of the callback.

Used to store and retrieve a callback’s state from the checkpoint dictionary by checkpoint["callbacks"][state_key]. Implementations of a callback need to provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.
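
For instance, two instances that differ only in their monitor get distinct state keys, so both can be persisted side by side; the printed keys are implementation-defined strings that encode the distinguishing arguments:

from lightning.pytorch.callbacks import ModelCheckpoint

a = ModelCheckpoint(monitor="val_loss")
b = ModelCheckpoint(monitor="val_acc", mode="max")
print(a.state_key)  # e.g. includes monitor='val_loss' and the other distinguishing args
print(b.state_key)
assert a.state_key != b.state_key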