tf.saved_model.load  |  TensorFlow v2.16.1

Load a SavedModel from export_dir.

Compat aliases for migration

See the Migration guide for more details.

tf.compat.v1.saved_model.load_v2

tf.saved_model.load(
    export_dir, tags=None, options=None
)

Signatures associated with the SavedModel are available as functions:

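import tensorflow as tf

imported = tf.saved_model.load(path)  # path: a SavedModel directory
f = imported.signatures["serving_default"]
# Signature functions take keyword arguments named after the inputs (here, x).
print(f(x=tf.constant([[1.]])))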
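As a self-contained sketch of where the "serving_default" key comes from, the following exports a small tf.Module (Doubler is a hypothetical example class) with its __call__ method as the signature and loads it back. Passing a tf.function that has an input_signature as signatures= registers it under the default serving key; the input keyword is taken from the argument name, and a returned dict supplies the output names.

```python
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
    """Hypothetical example module that doubles its input."""

    @tf.function(input_signature=[tf.TensorSpec(shape=[None, 1], dtype=tf.float32)])
    def __call__(self, x):
        # Returning a dict names the signature's outputs ("y").
        return {"y": 2.0 * x}


path = tempfile.mkdtemp()
module = Doubler()
# A tf.function with an input_signature is exported under "serving_default".
tf.saved_model.save(module, path, signatures=module.__call__)

imported = tf.saved_model.load(path)
print(list(imported.signatures.keys()))  # ['serving_default']
f = imported.signatures["serving_default"]
# Signature functions are called with keyword arguments and return a dict.
result = f(x=tf.constant([[1.]]))
print(result["y"].numpy())  # [[2.]]
```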

Objects exported with tf.saved_model.save additionally have trackable objects and functions assigned to attributes:

import tempfile

import tensorflow as tf

path = tempfile.mkdtemp()  # any writable directory
exported = tf.train.Checkpoint(v=tf.Variable(3.))
exported.f = tf.function(
    lambda x: exported.v * x,
    input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
tf.saved_model.save(exported, path)
imported = tf.saved_model.load(path)
# The variable and the function are restored as attributes.
assert 3. == imported.v.numpy()
assert 6. == imported.f(x=tf.constant(2.)).numpy()
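Trackable objects nested inside the exported object are restored the same way, as attributes of the loaded object, and restored functions keep using the restored variables. A minimal sketch, using hypothetical Inner and Outer tf.Module subclasses:

```python
import tempfile

import tensorflow as tf


class Inner(tf.Module):
    """Hypothetical module holding a variable and a traced method."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
    def scale(self, x):
        return self.w * x


class Outer(tf.Module):
    """Hypothetical module that only holds a reference to Inner."""

    def __init__(self):
        super().__init__()
        self.inner = Inner()


path = tempfile.mkdtemp()
tf.saved_model.save(Outer(), path)
imported = tf.saved_model.load(path)

# The nested module, its variable and its function are all attributes.
print(imported.inner.w.numpy())  # 2.0
print(imported.inner.scale(tf.constant(3.)).numpy())  # 6.0
# The restored function reads the restored variable, so assigning to it
# changes the function's result.
imported.inner.w.assign(5.)
print(imported.inner.scale(tf.constant(3.)).numpy())  # 15.0
```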

Loading Keras models

Keras models are trackable, so they can be saved to SavedModel. The object returned by tf.saved_model.load is not a Keras object (i.e. it doesn't have .fit, .predict, etc. methods). A few attributes and functions are still available: .variables, .trainable_variables and .__call__.

model = tf.keras.Model(...)
tf.saved_model.save(model, path)
imported = tf.saved_model.load(path)
outputs = imported(inputs)

Use tf.keras.models.load_model to restore the Keras model.

Importing SavedModels from TensorFlow 1.x

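SavedModels produced by the TensorFlow 1.x SavedModel APIs have a flat graph instead of tf.function objects. These SavedModels are loaded with a .signatures attribute, a dictionary mapping signature names to functions, and a .prune(feeds, fetches) method, which extracts a function for a new subgraph from the given feed and fetch tensors: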

imported = tf.saved_model.load(path_to_v1_saved_model)
# Feeds and fetches are tensor names from the 1.x graph.
pruned = imported.prune("x:0", "out:0")
pruned(tf.ones([]))

See tf.compat.v1.wrap_function for details.
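For a self-contained illustration, the sketch below writes a tiny 1.x-style SavedModel with the deprecated tf.compat.v1.saved_model.simple_save, then loads it in TensorFlow 2 and uses both .signatures and .prune. The graph itself (a variable w and tensors named x and out computing out = w * x) is an assumption made for this example.

```python
import os
import tempfile

import tensorflow as tf

# simple_save refuses to overwrite, so use a directory that does not exist yet.
path = os.path.join(tempfile.mkdtemp(), "v1_model")

# Build and export a flat TF 1.x graph computing out = w * x with w = 2.
with tf.Graph().as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[], name="x")
    w = tf.Variable(2.0, name="w")
    out = tf.identity(w * x, name="out")
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        tf.compat.v1.saved_model.simple_save(
            sess, path, inputs={"x": x}, outputs={"out": out})

# A single MetaGraph, so tags can be omitted.
imported = tf.saved_model.load(path)

# simple_save registers a "serving_default" signature keyed by the given names.
sig_out = imported.signatures["serving_default"](x=tf.constant(3.))
print(sig_out["out"].numpy())  # 6.0

# prune builds a new function from feed and fetch tensor names.
pruned = imported.prune("x:0", "out:0")
print(pruned(tf.constant(3.)).numpy())  # 6.0
```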

Consuming SavedModels asynchronously

When consuming SavedModels asynchronously (the producer is a separate process), the SavedModel directory will appear before all files have been written, and tf.saved_model.load will fail if pointed at an incomplete SavedModel. Rather than checking for the directory, check for "saved_model_dir/saved_model.pb". This file is written atomically as the last tf.saved_model.save file operation.
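One way to follow this advice is to poll for saved_model.pb before calling tf.saved_model.load. The helper below is a hypothetical sketch (its name, timeout and polling interval are assumptions, not part of the TensorFlow API) that uses only the Python standard library:

```python
import os
import time


def wait_for_saved_model(export_dir, timeout_s=60.0, poll_s=1.0):
    """Hypothetical helper: poll until export_dir/saved_model.pb exists.

    Returns True once the file is present, or False if timeout_s elapses first.
    """
    marker = os.path.join(export_dir, "saved_model.pb")
    deadline = time.monotonic() + timeout_s
    while True:
        # saved_model.pb is written last, so its presence marks a complete export.
        if os.path.exists(marker):
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(poll_s)
```

A consumer would then call tf.saved_model.load(export_dir) only after wait_for_saved_model(export_dir) returns True. Polling keeps the sketch portable; a file-system notification mechanism would avoid the sleep loop but is platform specific.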

Args
export_dir: The SavedModel directory to load from.
tags: A tag or sequence of tags identifying the MetaGraph to load. Optional if the SavedModel contains a single MetaGraph, as for those exported from tf.saved_model.save.
options: A tf.saved_model.LoadOptions object that specifies options for loading.
Returns
A trackable object with a signatures attribute mapping from signature keys to functions. If the SavedModel was exported by tf.saved_model.save, it also points to the trackable objects, functions, and debug info it was saved with.
Raises
ValueError: If tags don't match a MetaGraph in the SavedModel.
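The tags and options arguments, and the ValueError for mismatched tags, can be exercised with a small export. This sketch relies on tf.saved_model.save tagging its single MetaGraph with tf.saved_model.SERVING ("serve"), and uses experimental_io_device, one of the tf.saved_model.LoadOptions fields, set to the local host device purely as an illustration.

```python
import tempfile

import tensorflow as tf

path = tempfile.mkdtemp()
tf.saved_model.save(tf.train.Checkpoint(v=tf.Variable(1.)), path)

# Explicit tags must match the tags of the exported MetaGraph.
imported = tf.saved_model.load(path, tags=[tf.saved_model.SERVING])

# LoadOptions configures loading, e.g. the device used to read checkpoint files.
options = tf.saved_model.LoadOptions(experimental_io_device="/job:localhost")
imported = tf.saved_model.load(path, options=options)

# Tags that match no MetaGraph raise ValueError.
try:
    tf.saved_model.load(path, tags=["no-such-tag"])
    raised = False
except ValueError:
    raised = True
print(raised)  # True
```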