
tf.vectorized_map


Parallel map on the list of tensors unpacked from elems on dimension 0.

View aliases

Compat aliases for migration

See the Migration guide for more details.

tf.compat.v1.vectorized_map

tf.vectorized_map(
    fn, elems
)

This method works similarly to tf.map_fn but is optimized to run much faster, possibly with a much larger memory footprint. The speedups are obtained by vectorization (see https://arxiv.org/pdf/1903.04243.pdf). The idea behind vectorization is to semantically launch all the invocations of fn in parallel and fuse corresponding operations across all these invocations. This fusion is done statically at graph-generation time, and the generated code is often similar in performance to a manually fused version.

Because tf.vectorized_map fully parallelizes the batch, this method will generally be significantly faster than using tf.map_fn, especially in eager mode. However, note that this is an experimental feature with a number of limitations.
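As a rough illustration of the difference (a minimal sketch; the tensor names here are hypothetical, not part of the API), both calls below compute a row-wise softmax, but tf.vectorized_map fuses the per-row work into batched operations instead of looping over rows:

import tensorflow as tf

x = tf.random.normal([8, 4])

# tf.map_fn runs tf.nn.softmax once per row of x.
y_loop = tf.map_fn(tf.nn.softmax, x)

# tf.vectorized_map semantically launches all 8 invocations in parallel
# and fuses them into batched operations at graph-construction time.
y_vec = tf.vectorized_map(tf.nn.softmax, x)

assert y_loop.shape == y_vec.shape == (8, 4)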

Args:
  fn: The callable to be performed. It accepts one argument, which will have the same (possibly nested) structure as elems, and returns a possibly nested structure of Tensors and Operations, which may be different from the structure of elems.
  elems: A tensor or (possibly nested) sequence of tensors, each of which will be unpacked along its first dimension. The nested sequence of the resulting slices will be mapped over by fn.

Returns:
  A tensor or (possibly nested) sequence of tensors. Each tensor packs the results of applying fn to tensors unpacked from elems along the first dimension, from first to last.
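As a small illustration of the structure handling described above (a sketch; the names a, b, and dots are hypothetical), fn may consume a tuple of tensors and return a structure different from that of elems:

a = tf.ones([5, 3])
b = 2.0 * tf.ones([5, 3])

# Each invocation of fn receives a tuple of two slices of shape (3,)
# and returns a scalar; the results are packed into a (5,) tensor.
dots = tf.vectorized_map(lambda pair: tf.tensordot(pair[0], pair[1], 1), (a, b))
assert dots.shape == (5,)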

Examples:

import tensorflow as tf

def outer_product(a):
  # Outer product of a with itself (tensordot with zero contraction axes).
  return tf.tensordot(a, a, 0)

batch_size = 100
a = tf.ones((batch_size, 32, 32))
c = tf.vectorized_map(outer_product, a)
assert c.shape == (batch_size, 32, 32, 32, 32)
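Note that while the mapped computation is fast, the packed result is large: c holds batch_size * 32**4, roughly 1.05e8 elements, or about 400 MB at float32, which gives a concrete sense of the memory-footprint caveat above.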
# Computing per-example gradients

batch_size = 10
num_features = 32
layer = tf.keras.layers.Dense(1)

def model_fn(arg):
  with tf.GradientTape() as g:
    inp, label = arg
    # Add a batch dimension of 1 so the layer sees a single-example batch.
    inp = tf.expand_dims(inp, 0)
    label = tf.expand_dims(label, 0)
    prediction = layer(inp)
    loss = tf.nn.l2_loss(label - prediction)
  # Gradients of this example's loss w.r.t. the layer's variables.
  return g.gradient(loss, (layer.kernel, layer.bias))

inputs = tf.random.uniform([batch_size, num_features])
labels = tf.random.uniform([batch_size, 1])
per_example_gradients = tf.vectorized_map(model_fn, (inputs, labels))
assert per_example_gradients[0].shape == (batch_size, num_features, 1)
assert per_example_gradients[1].shape == (batch_size, 1)
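For comparison, the same per-example gradients can be computed with tf.map_fn, which runs model_fn once per example instead of vectorizing it (a sketch; since the output structure differs from that of elems, the dtype argument spells it out):

per_example_gradients_loop = tf.map_fn(
    model_fn, (inputs, labels), dtype=(tf.float32, tf.float32))
assert per_example_gradients_loop[0].shape == (batch_size, num_features, 1)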