tf.math.approx_max_k | TensorFlow v2.16.1



Returns the k largest values of the input operand, and their indices, computed approximately.


Main aliases

tf.nn.approx_max_k

Compat aliases for migration

See the Migration guide for more details.

tf.compat.v1.math.approx_max_k, tf.compat.v1.nn.approx_max_k

tf.math.approx_max_k(
    operand,
    k,
    reduction_dimension=-1,
    recall_target=0.95,
    reduction_input_size_override=-1,
    aggregate_to_topk=True,
    name=None
)

See https://arxiv.org/abs/2206.14286 for the algorithm details. Currently, this op is optimized only for TPU.
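To give a feel for where the approximation comes from, here is a simplified, pure-Python illustration of a bin-based approximate top-k in the spirit of the linked paper: split the input into bins, keep only the largest element of each bin, then take an exact top-k over those bin winners. When two of the true top-k elements land in the same bin, only one survives, so more bins generally means higher recall. This is not TensorFlow's implementation. The helper names bucketed_approx_max_k and recall, the strided bin assignment, and the explicit num_bins parameter are assumptions for illustration; the real op takes a recall_target instead of a bin count.

```python
import random

def bucketed_approx_max_k(values, k, num_bins):
    """Hypothetical sketch: keep the max of each bin, then exact top-k of the winners."""
    if num_bins < k:
        raise ValueError("num_bins must be >= k")
    # Assumption for this sketch: element i goes to bin i % num_bins (strided bins).
    best = {}  # bin id -> index of the largest element seen so far in that bin
    for i, v in enumerate(values):
        b = i % num_bins
        if b not in best or v > values[best[b]]:
            best[b] = i
    # Exact top-k over the (at most num_bins) bin winners, largest first.
    winners = sorted(best.values(), key=lambda i: values[i], reverse=True)
    top = winners[:k]
    return [values[i] for i in top], top

def recall(approx_idx, exact_idx):
    """Fraction of the exact top-k indices that the approximation recovered."""
    return len(set(approx_idx) & set(exact_idx)) / len(exact_idx)

random.seed(0)
data = [random.random() for _ in range(10_000)]
k = 10
# Exact top-k indices for reference.
exact = sorted(range(len(data)), key=lambda i: data[i], reverse=True)[:k]

for num_bins in (16, 128, 1024):
    _, approx = bucketed_approx_max_k(data, k, num_bins)
    print(num_bins, recall(approx, exact))
```

With one bin per element the sketch degenerates into an exact top-k, which makes a convenient sanity check.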

Args
operand: Array to search for the max-k values. Must be a floating-point type.
k: The number of largest values to return along reduction_dimension.
reduction_dimension: Integer dimension along which to search. Default: -1.
recall_target: Recall target for the approximation. Default: 0.95.
reduction_input_size_override: When set to a positive value, overrides the size given by operand.shape[reduction_dimension] when evaluating the recall. This is useful when the given operand is only a subset of the overall computation in SPMD or distributed pipelines, where the true input size cannot be inferred from the operand shape.
aggregate_to_topk: When true, aggregates the approximate results to the top k. When false, returns the approximate results without aggregation; their number is implementation-defined and is greater than or equal to the specified k.
name: Optional name for the operation.
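As a rough sketch of how reduction_dimension and aggregate_to_topk affect the output shapes, the snippet below calls the op three ways on a random float32 matrix. The helper names topk_rows, topk_cols, and candidates are hypothetical, and each call is compiled with tf.function(jit_compile=True) since the op is meant to run under XLA.

```python
import tensorflow as tf

# Hypothetical helper names; each one is compiled with XLA via jit_compile=True.
@tf.function(jit_compile=True)
def topk_rows(x):
    # Default reduction_dimension=-1: search along the last axis.
    return tf.math.approx_max_k(x, k=5)

@tf.function(jit_compile=True)
def topk_cols(x):
    # Search along axis 0 instead; k should not exceed the size of that axis.
    return tf.math.approx_max_k(x, k=2, reduction_dimension=0)

@tf.function(jit_compile=True)
def candidates(x):
    # Skip the final aggregation; the candidate count is implementation-defined
    # but at least k.
    return tf.math.approx_max_k(x, k=5, aggregate_to_topk=False)

x = tf.random.uniform((4, 1000))  # float32, a floating-point type as required

row_vals, row_idx = topk_rows(x)
col_vals, col_idx = topk_cols(x)
cand_vals, cand_idx = candidates(x)

print(row_vals.shape, col_vals.shape, cand_vals.shape)
```

Following the Returns description, row_vals has shape (4, 5) and col_vals has shape (2, 1000), while cand_vals has shape (4, n) for some implementation-defined n greater than or equal to 5.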
Returns
A tuple of two arrays: the max-k values and their corresponding indices along reduction_dimension of the input operand. Both arrays have the same dimensions as operand except along reduction_dimension: when aggregate_to_topk is true, that dimension has size k; otherwise, its size is implementation-defined and greater than or equal to k.
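One property that follows from this description: the returned values are exactly the operand entries at the returned indices along reduction_dimension. The hypothetical helper approx_top8 below checks this with tf.gather on a 2-D operand, using batch_dims=1 so that each row of indices gathers from its own row of x.

```python
import tensorflow as tf

@tf.function(jit_compile=True)
def approx_top8(x):
    # Hypothetical helper: approximate top-8 along the last axis.
    return tf.math.approx_max_k(x, k=8)

x = tf.random.uniform((3, 512))
values, indices = approx_top8(x)

# Gather each row's entries at that row's returned indices; with batch_dims=1
# the result has shape (3, 8), matching `values`.
gathered = tf.gather(x, indices, batch_dims=1)
print(bool(tf.reduce_all(gathered == values)))
```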

We encourage users to wrap approx_max_k in a function compiled with XLA, for example with tf.function(jit_compile=True). The following example performs maximum inner product search (MIPS):

import tensorflow as tf

@tf.function(jit_compile=True)
def mips(qy, db, k=10, recall_target=0.95):
    # Inner product between every query row and every database row.
    dists = tf.einsum('ik,jk->ij', qy, db)
    # returns (f32[qy_size, k], i32[qy_size, k])
    return tf.nn.approx_max_k(dists, k=k, recall_target=recall_target)

qy = tf.random.uniform((256, 128))
db = tf.random.uniform((2048, 128))
dot_products, neighbors = mips(qy, db, k=20)
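Because recall_target is a target rather than a hard guarantee, it can be worth measuring the recall you actually get on your own data. A minimal sketch, assuming recall means the fraction of the exact top-k neighbors that the approximate result also returns: it repeats the MIPS setup and compares the approximate neighbors with those from the exact tf.math.top_k. The helper mean_recall is hypothetical.

```python
import tensorflow as tf

@tf.function(jit_compile=True)
def mips(qy, db, k=10, recall_target=0.95):
    dists = tf.einsum('ik,jk->ij', qy, db)
    return tf.nn.approx_max_k(dists, k=k, recall_target=recall_target)

def mean_recall(approx_idx, exact_idx):
    # Hypothetical helper. hits[q, a] is True when approximate neighbor a of
    # query q also appears in the exact top-k list for q.
    hits = tf.reduce_any(approx_idx[:, :, None] == exact_idx[:, None, :], axis=-1)
    # Each row holds k hits, so the overall mean equals the mean per-query recall.
    return tf.reduce_mean(tf.cast(hits, tf.float32))

qy = tf.random.uniform((256, 128))
db = tf.random.uniform((2048, 128))
k = 20

_, approx_neighbors = mips(qy, db, k=k)
# Exact reference: full top-k over the same inner products.
_, exact_neighbors = tf.math.top_k(tf.einsum('ik,jk->ij', qy, db), k=k)

print(float(mean_recall(approx_neighbors, exact_neighbors)))
```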