
tf.tpu.experimental.embedding.serving_embedding_lookup


Apply standard lookup ops with tf.tpu.experimental.embedding configs.

View aliases

Compat aliases for migration

See the Migration guide for more details.

tf.compat.v1.tpu.experimental.embedding.serving_embedding_lookup

tf.tpu.experimental.embedding.serving_embedding_lookup(
    inputs: Any,
    weights: Optional[Any],
    tables: Dict[tf.tpu.experimental.embedding.TableConfig, tf.Variable],
    feature_config: Union[tf.tpu.experimental.embedding.FeatureConfig, Iterable]
) -> Any

This function is a utility that allows using the tf.tpu.experimental.embedding config objects with standard lookup ops. This can be used when exporting a model that uses tf.tpu.experimental.embedding.TPUEmbedding for serving on CPU. In particular, tf.tpu.experimental.embedding.TPUEmbedding only supports lookups on TPUs and should not be part of your serving graph.

Note that TPU-specific options (such as max_sequence_length) in the configuration objects will be ignored.

In the following example we take a trained model (see the documentation for tf.tpu.experimental.embedding.TPUEmbedding for the context) and create a saved model with a serving function that will perform the embedding lookup and pass the results to your model:

model = model_fn(...)
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
    feature_config=feature_config,
    batch_size=1024,
    optimizer=tf.tpu.experimental.embedding.SGD(0.1))
checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
checkpoint.restore(...)

@tf.function(input_signature=[{'feature_one': tf.TensorSpec(...),
                               'feature_two': tf.TensorSpec(...),
                               'feature_three': tf.TensorSpec(...)}])
def serve_tensors(embedding_features):
  # Pass weights=None for an unweighted lookup over the restored tables.
  embedded_features = tf.tpu.experimental.embedding.serving_embedding_lookup(
      embedding_features, None, embedding.embedding_tables,
      feature_config)
  return model(embedded_features)

# Attach the embedding object to the model so its tables are tracked and
# included in the SavedModel.
model.embedding_api = embedding
tf.saved_model.save(model,
                    export_dir=...,
                    signatures={'serving_default': serve_tensors})
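
Once exported, the SavedModel can be served on a CPU host without any TPU-specific ops. The following is a hedged sketch of loading the model and invoking the serving signature; the export path, feature names, dtypes and shapes are illustrative placeholders matching the hypothetical serve_tensors signature above, not values from the original example:

# Sketch only: path, feature names, dtypes and shapes are placeholders.
loaded = tf.saved_model.load('/path/to/export_dir')
serve_fn = loaded.signatures['serving_default']
outputs = serve_fn(
    feature_one=tf.constant([[1, 2]], dtype=tf.int32),
    feature_two=tf.constant([[3, 4]], dtype=tf.int32),
    feature_three=tf.constant([[5, 6]], dtype=tf.int32))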

Args

inputs: A nested structure of Tensors, SparseTensors or RaggedTensors.
weights: A nested structure of Tensors, SparseTensors or RaggedTensors, or None for no weights. If not None, the structure must match that of inputs, but entries are allowed to be None.
tables: A dict mapping TableConfig objects to Variables.
feature_config: A nested structure of FeatureConfig objects with the same structure as inputs.

Returns

A nested structure of Tensors with the same structure as inputs.
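
As a minimal, self-contained sketch of how these arguments fit together, the example below builds one table and one feature and runs the lookup on CPU; the table name, feature name, vocabulary size, dimension and ids are illustrative assumptions, not taken from the original docs:

import tensorflow as tf

# Illustrative configuration: a single table and a single feature.
table = tf.tpu.experimental.embedding.TableConfig(
    vocabulary_size=1000, dim=8, name='video')
feature_config = {
    'watched': tf.tpu.experimental.embedding.FeatureConfig(
        table=table, name='watched')}

# On CPU the tables argument is a plain dict mapping each TableConfig to the
# tf.Variable holding its embedding matrix.
tables = {table: tf.Variable(tf.random.uniform([1000, 8]))}

# A batch of two examples with a variable number of ids each.
inputs = {'watched': tf.ragged.constant([[3, 7, 42], [5]], dtype=tf.int64)}

embedded = tf.tpu.experimental.embedding.serving_embedding_lookup(
    inputs, None, tables, feature_config)
# embedded['watched'] has shape [2, 8]: the per-id embeddings are combined
# (mean by default) into one vector per example.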