tf.compat.v1.train.init_from_checkpoint | TensorFlow v2.16.1

tf.compat.v1.train.init_from_checkpoint

Replaces tf.Variable initializers so they load from a checkpoint file.

tf.compat.v1.train.init_from_checkpoint(
    ckpt_dir_or_file, assignment_map
)

Migrate to TF2

tf.compat.v1.train.init_from_checkpoint is not recommended for restoring variable values in TF2.

To restore checkpoints in TF2, use tf.keras.Model.load_weights or tf.train.Checkpoint.restore. These APIs use an object-based method of checkpointing, while tf.compat.v1.train.init_from_checkpoint relies on a more fragile, variable-name-based method. There is no object-based equivalent of init_from_checkpoint in TF2.

Please rewrite your checkpoints immediately using the object-based APIs; see the migration guide for more details.

You can load a name-based checkpoint written by tf.compat.v1.train.Saver using tf.train.Checkpoint.restore or tf.keras.Model.load_weights. However, you may have to change the names of the variables in your model to match the variable names in the name-based checkpoint, which can be viewed with tf.train.list_variables(path).
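To make the name inspection step concrete, here is a minimal sketch. It assumes a throwaway checkpoint written to a temporary directory; the scope and variable names (old_scope_1, var1, var2) and the path are illustrative only, not something your checkpoint will necessarily contain.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical location for a throwaway checkpoint.
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Write a small name-based checkpoint with the TF1 Saver (graph mode).
with tf.Graph().as_default():
    with tf.compat.v1.variable_scope("old_scope_1"):
        tf.compat.v1.get_variable(
            "var1", shape=[20, 2], initializer=tf.compat.v1.ones_initializer())
        tf.compat.v1.get_variable(
            "var2", shape=[50, 4], initializer=tf.compat.v1.ones_initializer())
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        tf.compat.v1.train.Saver().save(sess, ckpt_prefix)

# Each entry is a (variable name, shape) pair.
names_and_shapes = tf.train.list_variables(ckpt_prefix)
for name, shape in names_and_shapes:
    print(name, shape)
```

The printed names are the keys you would either mirror in your model or use on the left-hand side of an assignment_map.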

Another option is to create an assignment_map that maps the names of the variables in the name-based checkpoint to the variables in your model, for example:

{
    'sequential/dense/bias': model.variables[0],
    'sequential/dense/kernel': model.variables[1]
}

and use tf.compat.v1.train.init_from_checkpoint(path, assignment_map) to restore the name-based checkpoint.

After restoring, re-encode your checkpoint using tf.train.Checkpoint.save or tf.keras.Model.save_weights.
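The re-encoding step might look like the sketch below. Note the swapped-in technique: to keep the example short, the stored value is copied with tf.train.load_variable rather than with init_from_checkpoint. The variable names, the attribute name bias, and the temporary paths are all assumptions made for illustration.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical throwaway locations for the old and new checkpoints.
old_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")
new_prefix = os.path.join(tempfile.mkdtemp(), "ckpt")

# Write a name-based checkpoint with the TF1 Saver (graph mode).
with tf.Graph().as_default():
    with tf.compat.v1.variable_scope("dense"):
        tf.compat.v1.get_variable(
            "bias", shape=[2], initializer=tf.compat.v1.ones_initializer())
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        tf.compat.v1.train.Saver().save(sess, old_prefix)

# Copy the stored value into a TF2 variable (eager mode).
bias = tf.Variable(tf.zeros([2]), name="bias")
bias.assign(tf.train.load_variable(old_prefix, "dense/bias"))

# Re-save in the object-based format.
save_path = tf.train.Checkpoint(bias=bias).save(new_prefix)

# Restore into a fresh variable to confirm the round trip.
restored = tf.Variable(tf.zeros([2]))
tf.train.Checkpoint(bias=restored).restore(save_path)
print(restored.numpy())
```

After this, the new checkpoint is keyed by the object attribute (bias) rather than by the variable name dense/bias, so later renames of the variable do not break restoring.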

Description

Used in the guide: Migrating model checkpoints

Values are not loaded immediately, but when the initializer is run (typically by running a tf.compat.v1.global_variables_initializer op).
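The deferred loading described above can be observed with a short graph-mode sketch. The scope names and the temporary checkpoint path are hypothetical; the checkpoint stores ones while the new variable's own initializer would produce zeros, so the value read back shows which initializer actually ran.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical location for a throwaway checkpoint.
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Save a checkpoint whose single variable holds ones.
with tf.Graph().as_default():
    with tf.compat.v1.variable_scope("old_scope"):
        tf.compat.v1.get_variable(
            "var1", shape=[2], initializer=tf.compat.v1.ones_initializer())
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        tf.compat.v1.train.Saver().save(sess, ckpt_prefix)

# Build a new graph whose variable would otherwise start at zero.
with tf.Graph().as_default():
    with tf.compat.v1.variable_scope("new_scope"):
        var1 = tf.compat.v1.get_variable(
            "var1", shape=[2], initializer=tf.compat.v1.zeros_initializer())

    # This replaces the initializer; no value is assigned yet.
    tf.compat.v1.train.init_from_checkpoint(
        ckpt_prefix, {"old_scope/var1": var1})

    with tf.compat.v1.Session() as sess:
        # Running the initializer is what loads the checkpoint value.
        sess.run(tf.compat.v1.global_variables_initializer())
        loaded = sess.run(var1)

print(loaded)
```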

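The assignment map supports the following syntax:

'checkpoint_scope_name/': 'scope_name/' loads all variables in the current scope_name from checkpoint_scope_name with matching tensor names.
'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name' initializes the scope_name/variable_name variable from checkpoint_scope_name/some_other_variable.
'scope_variable_name': variable initializes the given tf.Variable object with the tensor 'scope_variable_name' from the checkpoint.
'scope_variable_name': list(variable) initializes a list of partitioned variables with the tensor 'scope_variable_name' from the checkpoint.
'/': 'scope_name/' loads all variables in the current scope_name from the checkpoint's root (that is, no scope).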

Supports loading into partitioned variables, which are represented as '<variable>/part_<part #>'.

Assignment map can be a dict, or a list of pairs. The latter is necessary to initialize multiple variables in the current graph from the same variable in the checkpoint.
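A dict cannot hold the same checkpoint name twice as a key, which is why the list-of-pairs form is needed for this case. The sketch below initializes two variables from one checkpoint tensor; the scope names (old_scope, copy_a, copy_b) and the temporary path are assumptions for illustration.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical location for a throwaway checkpoint.
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Save a checkpoint with one variable holding ones.
with tf.Graph().as_default():
    with tf.compat.v1.variable_scope("old_scope"):
        tf.compat.v1.get_variable(
            "var1", shape=[2], initializer=tf.compat.v1.ones_initializer())
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        tf.compat.v1.train.Saver().save(sess, ckpt_prefix)

with tf.Graph().as_default():
    with tf.compat.v1.variable_scope("copy_a"):
        first = tf.compat.v1.get_variable(
            "var1", shape=[2], initializer=tf.compat.v1.zeros_initializer())
    with tf.compat.v1.variable_scope("copy_b"):
        second = tf.compat.v1.get_variable(
            "var1", shape=[2], initializer=tf.compat.v1.zeros_initializer())

    # The same checkpoint name appears twice, so a list of pairs is used.
    tf.compat.v1.train.init_from_checkpoint(
        ckpt_prefix,
        [("old_scope/var1", first), ("old_scope/var1", second)])

    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        first_value, second_value = sess.run([first, second])

print(first_value, second_value)
```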

Example:


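import tensorflow as tf

# Short alias for the function used throughout this example.
init_from_checkpoint = tf.compat.v1.train.init_from_checkpoint

# Say, '/tmp/model.ckpt' has the following tensors: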
#  -- name='old_scope_1/var1', shape=[20, 2]
#  -- name='old_scope_1/var2', shape=[50, 4]
#  -- name='old_scope_2/var3', shape=[100, 100]

# Create new model's variables
with tf.compat.v1.variable_scope('new_scope_1'):
  var1 = tf.compat.v1.get_variable('var1', shape=[20, 2],
                         initializer=tf.compat.v1.zeros_initializer())
with tf.compat.v1.variable_scope('new_scope_2'):
  var2 = tf.compat.v1.get_variable('var2', shape=[50, 4],
                         initializer=tf.compat.v1.zeros_initializer())
  # Partition into 5 variables along the first axis.
  var3 = tf.compat.v1.get_variable(name='var3', shape=[100, 100],
                         initializer=tf.compat.v1.zeros_initializer(),
                         partitioner=lambda shape, dtype: [5, 1])

# Initialize all variables in `new_scope_1` from `old_scope_1`.
init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1/'})

# Use names to specify which variables to initialize from checkpoint.
init_from_checkpoint('/tmp/model.ckpt',
                     {'old_scope_1/var1': 'new_scope_1/var1',
                      'old_scope_1/var2': 'new_scope_2/var2'})

# Or use tf.Variable objects to identify what to initialize.
init_from_checkpoint('/tmp/model.ckpt',
                     {'old_scope_1/var1': var1,
                      'old_scope_1/var2': var2})

# Initialize partitioned variables using variable's name
init_from_checkpoint('/tmp/model.ckpt',
                     {'old_scope_2/var3': 'new_scope_2/var3'})

# Or specify the list of tf.Variable objects.
init_from_checkpoint('/tmp/model.ckpt',
                     {'old_scope_2/var3': var3._get_variable_list()})

Args
ckpt_dir_or_file: Directory containing checkpoint files, or the path to a checkpoint.
assignment_map: Dict, or a list of key-value pairs, where keys are names of the variables in the checkpoint and values are current variables or names of current variables (in the default graph).

Raises
ValueError: If variables are missing in the current graph, or if checkpoints or tensors in the checkpoints are missing.