tf.data.experimental.map_and_batch  |  TensorFlow v2.16.1

tf.data.experimental.map_and_batch

Fused implementation of map and batch. (deprecated)

Compat aliases for migration

See Migration guide for more details.

tf.compat.v1.data.experimental.map_and_batch

tf.data.experimental.map_and_batch(
    map_func,
    batch_size,
    num_parallel_batches=None,
    drop_remainder=False,
    num_parallel_calls=None
)

Maps map_func across batch_size consecutive elements of this dataset and then combines them into a batch. Functionally, it is equivalent to map followed by batch. This API is temporary and deprecated because input pipeline optimization now fuses consecutive map and batch operations automatically.
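
Because the fused API is deprecated, the equivalent pipeline can be written as a plain map followed by batch. The following is a minimal sketch of that form; the `double` function and the ten-element range dataset are illustrative assumptions, not part of the original page.

```python
import tensorflow as tf


def double(x):
    # Hypothetical map function used only for illustration.
    return x * 2


dataset = tf.data.Dataset.range(10)

# map followed by batch; tf.data can fuse these two steps automatically.
batched = dataset.map(double, num_parallel_calls=tf.data.AUTOTUNE).batch(4)

# Materialize the batches as plain Python lists for inspection.
batches = [batch.numpy().tolist() for batch in batched]
print(batches)
```

With the default drop_remainder=False, the final batch holds the two leftover elements.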

Args
map_func A function mapping a nested structure of tensors to another nested structure of tensors.
batch_size A tf.int64 scalar tf.Tensor, representing the number of consecutive elements of this dataset to combine in a single batch.
num_parallel_batches (Optional.) A tf.int64 scalar tf.Tensor, representing the number of batches to create in parallel. On one hand, higher values can help mitigate the effect of stragglers. On the other hand, higher values can increase contention if CPU is scarce.
drop_remainder (Optional.) A tf.bool scalar tf.Tensor, representing whether the last batch should be dropped in case its size is smaller than desired; the default behavior is not to drop the smaller batch.
num_parallel_calls (Optional.) A tf.int32 scalar tf.Tensor, representing the number of elements to process in parallel. If not specified, batch_size * num_parallel_batches elements will be processed in parallel. If the value tf.data.AUTOTUNE is used, then the number of parallel calls is set dynamically based on available CPU.
Returns
A Dataset transformation function, which can be passed to tf.data.Dataset.apply.
Raises
ValueError If both num_parallel_batches and num_parallel_calls are specified.
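
The arguments above can be exercised with a short sketch, assuming a TensorFlow 2.x release in which the deprecated function is still available. The lambda map functions and the range dataset are hypothetical choices made for illustration. The first part passes the returned transformation to tf.data.Dataset.apply with drop_remainder=True; the second part shows the ValueError raised when both parallelism arguments are given.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)

# The returned transformation function is passed to Dataset.apply.
fused = dataset.apply(
    tf.data.experimental.map_and_batch(
        map_func=lambda x: x + 1,  # illustrative map function
        batch_size=4,
        drop_remainder=True,  # drop the final, smaller batch
    )
)
fused_batches = [batch.numpy().tolist() for batch in fused]
print(fused_batches)

# Specifying both num_parallel_batches and num_parallel_calls is an error.
try:
    tf.data.experimental.map_and_batch(
        lambda x: x,
        batch_size=4,
        num_parallel_batches=2,
        num_parallel_calls=8,
    )
    raised = False
except ValueError:
    raised = True
print(raised)
```

Calling the function is expected to log a deprecation warning; new code would normally use map followed by batch instead.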

Last updated 2024-04-26 UTC.