tf.tpu.experimental.embedding.serving_embedding_lookup  |  TensorFlow v2.16.1

Apply standard lookup ops with tf.tpu.experimental.embedding configs.

Compat aliases for migration

See Migration guide for more details.

tf.compat.v1.tpu.experimental.embedding.serving_embedding_lookup

tf.tpu.experimental.embedding.serving_embedding_lookup(
    inputs: Any,
    weights: Optional[Any],
    tables: Dict[tf.tpu.experimental.embedding.TableConfig, tf.Variable],
    feature_config: Union[tf.tpu.experimental.embedding.FeatureConfig, Iterable]
) -> Any

This function is a utility that lets you use the tf.tpu.experimental.embedding config objects with standard lookup functions. Use it when exporting a model that uses tf.tpu.experimental.embedding.TPUEmbedding for serving on CPU. In particular, tf.tpu.experimental.embedding.TPUEmbedding only supports lookups on TPUs and should not be part of your serving graph.

Note that TPU-specific options (such as max_sequence_length) in the configuration objects will be ignored.
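As a rough illustration of the CPU lookup described above, the sketch below builds a toy TableConfig and FeatureConfig and looks up two ids in a hand-made table variable. The table contents, the names toy_table and feature_one, and the choice of a dense id Tensor are assumptions made for this example only; in a real export the variables would come from a restored TPUEmbedding object rather than being written by hand.

```python
import tensorflow as tf

emb = tf.tpu.experimental.embedding

# Hypothetical toy configuration: 4 ids, 2-dimensional embeddings.
table_config = emb.TableConfig(vocabulary_size=4, dim=2, name='toy_table')
feature_config = {
    'feature_one': emb.FeatureConfig(table=table_config, name='feature_one'),
}

# Hand-made table standing in for trained weights: row i is [i, i].
tables = {
    table_config: tf.Variable(
        [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]),
}

# One id per example, passed as a dense Tensor; None means no weights.
inputs = {'feature_one': tf.constant([1, 3])}

# Runs on CPU; no TPU or TPUEmbedding object is involved here.
embedded = emb.serving_embedding_lookup(inputs, None, tables, feature_config)
print(embedded['feature_one'].numpy())
```

With a dense id Tensor, each id selects the matching row of the table, so the result has one embedding row per input id.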

In the following example, we take a trained model (see the documentation for tf.tpu.experimental.embedding.TPUEmbedding for the context) and create a saved model with a serving function that will perform the embedding lookup and pass the results to your model:

import tensorflow as tf

# Rebuild the model and the TPUEmbedding object, then restore trained weights.
model = model_fn(...)
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
    feature_config=feature_config,
    batch_size=1024,
    optimizer=tf.tpu.experimental.embedding.SGD(0.1))
checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
checkpoint.restore(...)

# Serving function: performs the lookup with standard ops, then runs the model.
@tf.function(input_signature=[{'feature_one': tf.TensorSpec(...),
                               'feature_two': tf.TensorSpec(...),
                               'feature_three': tf.TensorSpec(...)}])
def serve_tensors(embedding_features):
  # weights=None: no per-id weights are supplied.
  embedded_features = tf.tpu.experimental.embedding.serving_embedding_lookup(
      embedding_features, None, embedding.embedding_tables,
      feature_config)
  return model(embedded_features)

# Attach the embedding object to the model so that it is saved along with it.
model.embedding_api = embedding
tf.saved_model.save(model,
                    export_dir=...,
                    signatures={'serving_default': serve_tensors})

Args
inputs: A nested structure of Tensors, SparseTensors or RaggedTensors.
weights: A nested structure of Tensors, SparseTensors or RaggedTensors, or None for no weights. If not None, the structure must match that of inputs, but entries are allowed to be None.
tables: A dict mapping TableConfig objects to Variables.
feature_config: A nested structure of FeatureConfig objects with the same structure as inputs.

Returns
A nested structure of Tensors with the same structure as inputs.
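
The structure contract between inputs, feature_config, and the return value can be sketched as follows. This is a minimal example under stated assumptions: the table names, sizes, feature keys, and the zero- and one-filled variables are all hypothetical, and dense id Tensors are used for simplicity. Two features share one table to show that tables is keyed by TableConfig, not by feature.

```python
import tensorflow as tf

emb = tf.tpu.experimental.embedding

# Hypothetical tables: one for users, one shared by two item features.
user_table = emb.TableConfig(vocabulary_size=5, dim=3, name='user_table')
item_table = emb.TableConfig(vocabulary_size=7, dim=2, name='item_table')

# feature_config is a nested structure (a dict holding a tuple).
feature_config = {
    'user': emb.FeatureConfig(table=user_table, name='user'),
    'items': (emb.FeatureConfig(table=item_table, name='clicked'),
              emb.FeatureConfig(table=item_table, name='viewed')),
}

# One Variable per TableConfig; the values are placeholders, not trained.
tables = {
    user_table: tf.Variable(tf.zeros([5, 3])),
    item_table: tf.Variable(tf.ones([7, 2])),
}

# inputs mirrors the structure of feature_config exactly.
inputs = {
    'user': tf.constant([0, 4]),
    'items': (tf.constant([1, 2]), tf.constant([6, 0])),
}

outputs = emb.serving_embedding_lookup(inputs, None, tables, feature_config)

# The result has the same nesting as inputs.
print(outputs['user'].shape)
print(outputs['items'][0].shape, outputs['items'][1].shape)
```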