tf.train.CheckpointManager

Manages multiple checkpoints by keeping some and deleting unneeded ones.

Compat aliases for migration

See the Migration guide for more details.

tf.compat.v1.train.CheckpointManager

tf.train.CheckpointManager(
    checkpoint,
    directory,
    max_to_keep,
    keep_checkpoint_every_n_hours=None,
    checkpoint_name='ckpt',
    step_counter=None,
    checkpoint_interval=None,
    init_fn=None
)

Used in the notebooks

Used in the guide:

- Training checkpoints
- tf.data: Build TensorFlow input pipelines
- Migrate the fault tolerance mechanism

Used in the tutorials:

- Distributed training with DTensors
- Custom training loop with Keras and MultiWorkerMirroredStrategy
- Multi-worker training with Keras
- CycleGAN
- Training with Orbit

Example usage:

import tensorflow as tf

# Assumes `model` and `optimizer` already exist; minimal stand-ins:
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam()

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
manager = tf.train.CheckpointManager(
    checkpoint, directory="/tmp/model", max_to_keep=5)
# Restores the latest checkpoint if one exists; a no-op otherwise.
status = checkpoint.restore(manager.latest_checkpoint)
while True:
  # ... run a training step ...
  manager.save()

CheckpointManager preserves its own state across instantiations (see the __init__ documentation for details). Only one should be active in a particular directory at a time.
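
For example, a manager created later in the same directory picks up the saved state, so checkpoint numbering and deletion continue where the previous manager left off. A minimal sketch, reusing the directory from the example above:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
checkpoint = tf.train.Checkpoint(model=model)

# A fresh manager reads the "checkpoint" state file in the directory,
# so `latest_checkpoint` and the save counter carry over.
manager = tf.train.CheckpointManager(
    checkpoint, directory="/tmp/model", max_to_keep=5)
print(manager.latest_checkpoint)  # e.g. "/tmp/model/ckpt-3", or None
print(manager.checkpoints)        # paths of the currently kept checkpoints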

Args
checkpoint The tf.train.Checkpoint instance to save and manage checkpoints for.
directory The path to a directory in which to write checkpoints. A special file named "checkpoint" is also written to this directory (in a human-readable text format) which contains the state of the CheckpointManager.
max_to_keep An integer, the number of checkpoints to keep. Unless preserved by keep_checkpoint_every_n_hours, checkpoints will be deleted from the active set, oldest first, until only max_to_keep checkpoints remain. If None, no checkpoints are deleted and everything stays in the active set. Note that max_to_keep=None will keep all checkpoint paths in memory and in the checkpoint state protocol buffer on disk.
keep_checkpoint_every_n_hours Upon removal from the active set, a checkpoint will be preserved if it has been at least keep_checkpoint_every_n_hours since the last preserved checkpoint. The default setting of None does not preserve any checkpoints in this way.
checkpoint_name Custom name for the checkpoint file.
step_counter A tf.Variable instance for checking the current step counter value, in case users want to save checkpoints every N steps.
checkpoint_interval An integer specifying the minimum step interval between two checkpoints; see the sketch after this table.
init_fn Callable. A function to do customized initialization if no checkpoints are in the directory.
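
As an illustrative sketch of step-based saving (the `global_step` variable here is an assumed name, not part of the API):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
global_step = tf.Variable(0, dtype=tf.int64)  # assumed step counter
checkpoint = tf.train.Checkpoint(model=model, step=global_step)

# Save at most once every 1000 steps, judged against `step_counter`.
manager = tf.train.CheckpointManager(
    checkpoint, directory="/tmp/model", max_to_keep=3,
    step_counter=global_step, checkpoint_interval=1000)

for _ in range(10000):
  # ... run a training step ...
  global_step.assign_add(1)
  # Skipped unless `global_step` advanced >= 1000 since the last save.
  manager.save()
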
Raises
ValueError If max_to_keep is not a positive integer.
Attributes
checkpoint Returns the tf.train.Checkpoint object.
checkpoint_interval
checkpoints A list of managed checkpoints. Note that checkpoints saved due to keep_checkpoint_every_n_hours will not show up in this list (to avoid ever-growing filename lists).
directory
latest_checkpoint The prefix of the most recent checkpoint in directory. Equivalent to tf.train.latest_checkpoint(directory) where directory is the constructor argument to CheckpointManager. Suitable for passing to tf.train.Checkpoint.restore to resume training.

Methods

restore_or_initialize

restore_or_initialize()

Restore items in checkpoint from the latest checkpoint file.

This method will first try to restore from the most recent checkpoint in directory. If no checkpoints exist in directory, and init_fn is specified, this method will call init_fn to do customized initialization. This can be used to support initialization from pretrained models.

Note that unlike tf.train.Checkpoint.restore(), this method doesn't return a load status object that users can run assertions on (e.g. assert_consumed()). To run such assertions, use the tf.train.Checkpoint.restore() method directly.
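
A minimal sketch of this pattern; `init_from_pretrained` is a hypothetical stand-in for whatever custom initialization is needed:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
checkpoint = tf.train.Checkpoint(model=model)

def init_from_pretrained():
  # Hypothetical custom initialization, e.g. loading pretrained weights.
  print("No checkpoint found; running custom initialization.")

manager = tf.train.CheckpointManager(
    checkpoint, directory="/tmp/model", max_to_keep=5,
    init_fn=init_from_pretrained)

# Restores the latest checkpoint, or calls init_fn if the directory
# has none. Returns the restored path, or None.
path = manager.restore_or_initialize()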

Returns
The restored checkpoint path if the latest checkpoint is found and restored; None otherwise.

save

save(
    checkpoint_number=None, check_interval=True, options=None
)

Creates a new checkpoint and manages it.

Args
checkpoint_number An optional integer, or an integer-dtype Variable or Tensor, used to number the checkpoint. If None (default), checkpoints are numbered using checkpoint.save_counter. Even if checkpoint_number is provided, save_counter is still incremented. A user-provided checkpoint_number is not incremented even if it is a Variable.
check_interval An optional boolean. The argument is only effective when checkpoint_interval is passed into the manager. If True, the manager will only save the checkpoint if the interval between checkpoints is larger than checkpoint_interval. Otherwise it will always save the checkpoint unless a checkpoint has already been saved for the current step.
options Optional tf.train.CheckpointOptions object. This argument only works with TF2 checkpoint objects. For example, options = tf.train.CheckpointOptions(experimental_io_device='/job:localhost')
Returns
The path to the new checkpoint. It is also recorded in the checkpoints and latest_checkpoint properties. None if no checkpoint is saved.
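
A short sketch of these arguments, continuing the step-based setup above (`global_step` and `manager` are the assumed names from that sketch):

# Number the checkpoint by an explicit step instead of `save_counter`.
path = manager.save(checkpoint_number=global_step)
print(path)  # e.g. "/tmp/model/ckpt-1000", or None if the save was skipped

# Force a save regardless of `checkpoint_interval`, routing checkpoint
# I/O through a specific device.
options = tf.train.CheckpointOptions(experimental_io_device='/job:localhost')
path = manager.save(check_interval=False, options=options)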

sync

sync()

Wait for any outstanding save or restore operations.
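
This matters mainly when saves run asynchronously. A sketch, assuming a recent TF 2.x release where tf.train.CheckpointOptions accepts the experimental_enable_async_checkpoint flag:

# Kick off an asynchronous save; control returns before the write finishes.
async_options = tf.train.CheckpointOptions(
    experimental_enable_async_checkpoint=True)
manager.save(options=async_options)

# Block until any pending save has finished, e.g. before reading the
# checkpoint files or exiting the program.
manager.sync()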