tf.compat.v1.map_fn | TensorFlow v2.16.1
tf.compat.v1.map_fn
Transforms elems by applying fn to each element unstacked on axis 0. (Deprecated arguments: dtype is deprecated; use fn_output_signature instead.)
tf.compat.v1.map_fn(
fn,
elems,
dtype=None,
parallel_iterations=None,
back_prop=True,
swap_memory=False,
infer_shape=True,
name=None,
fn_output_signature=None
)
See also tf.scan.
map_fn unstacks elems on axis 0 to obtain a sequence of elements; calls fn to transform each element; and then stacks the transformed values back together.
Mapping functions with single-Tensor inputs and outputs
If elems is a single tensor and fn's signature is tf.Tensor->tf.Tensor, then map_fn(fn, elems) is equivalent to tf.stack([fn(elem) for elem in tf.unstack(elems)]). E.g.:
tf.map_fn(fn=lambda t: tf.range(t, t + 3), elems=tf.constant([3, 5, 2]))
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[3, 4, 5],
[5, 6, 7],
[2, 3, 4]], dtype=int32)>
map_fn(fn, elems).shape = [elems.shape[0]] + fn(elems[0]).shape.
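The equivalence with tf.stack/tf.unstack and the shape rule above can be checked directly. This is a minimal sketch, assuming TensorFlow 2.x in eager mode; the outer-product fn is an illustrative choice, not part of the API:

```python
import tensorflow as tf

# elems has shape [4, 2]; map_fn unstacks it into four [2]-shaped rows.
elems = tf.reshape(tf.range(8, dtype=tf.float32), [4, 2])

def fn(row):
    # Illustrative fn: outer product of a [2] row with a [3] vector -> [2, 3].
    return tf.tensordot(row, tf.constant([1.0, 2.0, 3.0]), axes=0)

mapped = tf.map_fn(fn, elems)

# Reference implementation: unstack on axis 0, apply fn, stack the results.
reference = tf.stack([fn(e) for e in tf.unstack(elems)])

# Shape rule: [elems.shape[0]] + fn(elems[0]).shape
expected_shape = [elems.shape[0]] + fn(elems[0]).shape.as_list()
print(mapped.shape)  # (4, 2, 3)
```

Because fn preserves the input dtype (float32) and structure (a single tensor), no fn_output_signature is needed here.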
Mapping functions with multi-arity inputs and outputs
map_fn also supports functions with multi-arity inputs and outputs:
- If elems is a tuple (or nested structure) of tensors, then those tensors must all have the same outer-dimension size (num_elems); and fn is used to transform each tuple (or structure) of corresponding slices from elems. E.g., if elems is a tuple (t1, t2, t3), then fn is used to transform each tuple of slices (t1[i], t2[i], t3[i]) (where 0 <= i < num_elems).
- If fn returns a tuple (or nested structure) of tensors, then the result is formed by stacking corresponding elements from those structures.
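A minimal sketch of multi-arity mapping, assuming TensorFlow 2.x in eager mode: elems is a tuple of two equal-length tensors, so fn receives a tuple of corresponding scalar slices and returns a tuple. Because the output structure and dtypes (two int32 tensors) match the input, no fn_output_signature is needed.

```python
import tensorflow as tf

t1 = tf.constant([1, 2, 3])
t2 = tf.constant([10, 20, 30])

# fn receives (t1[i], t2[i]) and returns a tuple of two int32 scalars;
# map_fn stacks each position of the returned tuples separately.
sums, products = tf.map_fn(
    fn=lambda pair: (pair[0] + pair[1], pair[0] * pair[1]),
    elems=(t1, t2))
```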
Specifying fn's output signature
If fn's input and output signatures are different, then the output signature must be specified using fn_output_signature. (The input and output signatures differ if their structures, dtypes, or tensor types do not match.) E.g.:
tf.map_fn(fn=tf.strings.length, # input & output have different dtypes
elems=tf.constant(["hello", "moon"]),
fn_output_signature=tf.int32)
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([5, 4], dtype=int32)>
tf.map_fn(fn=tf.strings.join, # input & output have different structures
elems=[tf.constant(['The', 'A']), tf.constant(['Dog', 'Cat'])],
fn_output_signature=tf.string)
<tf.Tensor: shape=(2,), dtype=string,
numpy=array([b'TheDog', b'ACat'], dtype=object)>
fn_output_signature can be specified using any of the following:
- A tf.DType or tf.TensorSpec (to describe a tf.Tensor)
- A tf.RaggedTensorSpec (to describe a tf.RaggedTensor)
- A tf.SparseTensorSpec (to describe a tf.sparse.SparseTensor)
- A (possibly nested) tuple, list, or dict containing the above types.
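As a sketch of a nested signature (assuming TensorFlow 2.x; the dict keys "text" and "length" are illustrative), fn below takes a scalar tf.string and returns a dict whose structure and dtypes differ from the input, so fn_output_signature must describe the dict. It mixes a tf.TensorSpec and a plain tf.DType:

```python
import tensorflow as tf

words = tf.constant(["hello", "moon", "tensorflow"])

result = tf.map_fn(
    # Input: a scalar tf.string; output: a dict with a string and an int32.
    fn=lambda w: {"text": w, "length": tf.strings.length(w)},
    elems=words,
    # The signature mirrors the dict structure returned by fn.
    fn_output_signature={"text": tf.TensorSpec(shape=[], dtype=tf.string),
                         "length": tf.int32})
```

The result is a dict of stacked tensors with the same keys as the signature.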
RaggedTensors
map_fn supports tf.RaggedTensor inputs and outputs. In particular:
- If elems is a RaggedTensor, then fn will be called with each row of that ragged tensor.
  - If elems has only one ragged dimension, then the values passed to fn will be tf.Tensors.
  - If elems has multiple ragged dimensions, then the values passed to fn will be tf.RaggedTensors with one fewer ragged dimension.
- If the result of map_fn should be a RaggedTensor, then use a tf.RaggedTensorSpec to specify fn_output_signature.
  - If fn returns tf.Tensors with varying sizes, then use a tf.RaggedTensorSpec with ragged_rank=0 to combine them into a single ragged tensor (which will have ragged_rank=1).
  - If fn returns tf.RaggedTensors, then use a tf.RaggedTensorSpec with the same ragged_rank.
# Example: RaggedTensor input
rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
tf.map_fn(tf.reduce_sum, rt, fn_output_signature=tf.int32)
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([6, 0, 9, 6], dtype=int32)>
# Example: RaggedTensor output
elems = tf.constant([3, 5, 0, 2])
tf.map_fn(tf.range, elems,
fn_output_signature=tf.RaggedTensorSpec(shape=[None],
dtype=tf.int32))
<tf.RaggedTensor [[0, 1, 2], [0, 1, 2, 3, 4], [], [0, 1]]>
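The case where elems has more than one ragged dimension can be sketched as follows, assuming TensorFlow 2.x in eager mode. Each row passed to fn is then a tf.RaggedTensor, and fn (an illustrative per-row reduction) returns dense tensors of varying length, which a tf.RaggedTensorSpec with ragged_rank=0 combines into one RaggedTensor:

```python
import tensorflow as tf

# Two ragged dimensions (ragged_rank=2): 3 rows of variable-length lists.
rt = tf.ragged.constant([[[1, 2], [3]], [], [[4, 5, 6]]])

# Each row arrives as a RaggedTensor with one fewer ragged dimension;
# summing over its inner axis yields a dense 1-D tensor of varying length.
row_sums = tf.map_fn(
    lambda row: tf.reduce_sum(row, axis=1),
    rt,
    fn_output_signature=tf.RaggedTensorSpec(
        shape=[None], dtype=tf.int32, ragged_rank=0))
```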
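Note: map_fn should only be used if you need to map a function over the rows of a RaggedTensor. If you wish to map a function over the individual values, then you should use:
- tf.ragged.map_flat_values(fn, rt) (if fn is expressible as TensorFlow ops)
- rt.with_flat_values(map_fn(fn, rt.flat_values)) (otherwise)
E.g.: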
rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
tf.ragged.map_flat_values(lambda x: x + 2, rt)
<tf.RaggedTensor [[3, 4, 5], [], [6, 7], [8]]>
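The rt.with_flat_values(map_fn(fn, rt.flat_values)) form applies fn to each individual value and then re-wraps the result with the original row partitioning. A sketch, assuming TensorFlow 2.x; the squaring lambda is only illustrative, and since it is expressible as TensorFlow ops, tf.ragged.map_flat_values would work equally well:

```python
import tensorflow as tf

rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])

# Map over the individual values (rt.flat_values has shape [6]),
# then restore the original row structure with with_flat_values.
squared = rt.with_flat_values(
    tf.map_fn(lambda x: x * x, rt.flat_values))
```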
SparseTensors
map_fn supports tf.sparse.SparseTensor inputs and outputs. In particular:
- If elems is a SparseTensor, then fn will be called with each row of that sparse tensor. In particular, the value passed to fn will be a tf.sparse.SparseTensor with one fewer dimension than elems.
- If the result of map_fn should be a SparseTensor, then use a tf.SparseTensorSpec to specify fn_output_signature. The individual SparseTensors returned by fn will be stacked into a single SparseTensor with one more dimension.
# Example: SparseTensor input
st = tf.sparse.SparseTensor([[0, 0], [2, 0], [2, 1]], [2, 3, 4], [4, 4])
tf.map_fn(tf.sparse.reduce_sum, st, fn_output_signature=tf.int32)
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([2, 0, 7, 0], dtype=int32)>
# Example: SparseTensor output
tf.sparse.to_dense(
tf.map_fn(tf.sparse.eye, tf.constant([2, 3]),
fn_output_signature=tf.SparseTensorSpec(None, tf.float32)))
<tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
array([[[1., 0., 0.],
[0., 1., 0.],
[0., 0., 0.]],
[[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]]], dtype=float32)>
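One way to see that each row passed to fn is itself a tf.sparse.SparseTensor with one fewer dimension is to densify the rows inside fn; stacking the dense rows reproduces tf.sparse.to_dense(st). A sketch, assuming TensorFlow 2.x in eager mode:

```python
import tensorflow as tf

st = tf.sparse.SparseTensor([[0, 0], [2, 0], [2, 1]], [2, 3, 4], [4, 4])

# Each row arrives as a rank-1 SparseTensor with dense_shape [4];
# densifying it yields a [4]-shaped tf.Tensor, so the output is dense int32.
rows_dense = tf.map_fn(tf.sparse.to_dense, st, fn_output_signature=tf.int32)
```

In practice tf.sparse.to_dense(st) does this directly; the map_fn version is only meant to expose the per-row SparseTensor inputs.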
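Note: map_fn should only be used if you need to map a function over the rows of a SparseTensor. If you wish to map a function over the nonzero values, then you should use:
- If the function is expressible as TensorFlow ops, use: tf.sparse.SparseTensor(st.indices, fn(st.values), st.dense_shape)
- Otherwise, use: tf.sparse.SparseTensor(st.indices, tf.map_fn(fn, st.values), st.dense_shape)
map_fn vs. vectorized operations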
map_fn will apply the operations used by fn to each element of elems, resulting in O(elems.shape[0]) total operations. This is somewhat mitigated by the fact that map_fn can process elements in parallel. However, a transform expressed using map_fn is still typically less efficient than an equivalent transform expressed using vectorized operations.
map_fn should typically only be used if one of the following is true:
- It is difficult or expensive to express the desired transform with vectorized operations.
- fn creates large intermediate values, so an equivalent vectorized transform would take too much memory.
- Processing elements in parallel is more efficient than an equivalent vectorized transform.
- Efficiency of the transform is not critical, and using map_fn is more readable.
E.g., the example given above that maps fn=lambda t: tf.range(t, t + 3) across elems could be rewritten more efficiently using vectorized ops:
elems = tf.constant([3, 5, 2])
tf.range(3) + tf.expand_dims(elems, 1)
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[3, 4, 5],
[5, 6, 7],
[2, 3, 4]], dtype=int32)>
In some cases, tf.vectorized_map can be used to automatically convert a function to a vectorized equivalent.
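A sketch of the running tf.range example with tf.vectorized_map, assuming TensorFlow 2.x. Here fn is rewritten as t + tf.range(3), a scalar plus a loop-invariant vector, which is likely to vectorize cleanly; if some op in fn cannot be vectorized, tf.vectorized_map falls back to a while loop by default (fallback_to_while_loop=True) rather than failing.

```python
import tensorflow as tf

elems = tf.constant([3, 5, 2])

# Equivalent to lambda t: tf.range(t, t + 3), but written as a scalar plus a
# loop-invariant vector so that tf.vectorized_map can batch it.
result = tf.vectorized_map(lambda t: t + tf.range(3), elems)
```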
Eager execution
When executing eagerly, map_fn does not execute in parallel even if parallel_iterations is set to a value > 1. You can still get the performance benefits of running a function in parallel by using the tf.function decorator:
fn=lambda t: tf.range(t, t + 3)
@tf.function
def func(elems):
return tf.map_fn(fn, elems, parallel_iterations=3)
func(tf.constant([3, 5, 2]))
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[3, 4, 5],
[5, 6, 7],
[2, 3, 4]], dtype=int32)>
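For comparison, a minimal sketch of calling tf.compat.v1.map_fn in TF1-style graph mode, assuming TensorFlow 2.x with the tf.compat.v1 APIs available. When graph building, parallel_iterations defaults to 10, so the loop iterations are eligible to run in parallel without a tf.function wrapper:

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Ops created here are added to `graph` instead of executing eagerly.
    elems = tf.constant([3, 5, 2])
    result = tf.compat.v1.map_fn(lambda t: tf.range(t, t + 3), elems)

# Run the graph in a TF1-style session to obtain a NumPy array.
with tf.compat.v1.Session(graph=graph) as sess:
    value = sess.run(result)
```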
| Args | |
|---|---|
| fn | The callable to be performed. It accepts one argument, which will have the same (possibly nested) structure as elems. Its output must have the same structure as fn_output_signature if one is provided; otherwise it must have the same structure as elems. |
| elems | A tensor or (possibly nested) sequence of tensors, each of which will be unstacked along its first dimension. fn will be applied to the nested sequence of the resulting slices. elems may include ragged and sparse tensors. elems must consist of at least one tensor. |
| dtype | Deprecated: Equivalent to fn_output_signature. |
| parallel_iterations | (optional) The number of iterations allowed to run in parallel. When graph building, the default value is 10. While executing eagerly, the default value is set to 1. |
| back_prop | (optional) False disables support for back propagation. |
| swap_memory | (optional) True enables GPU-CPU memory swapping. |
| infer_shape | (optional) False disables tests for consistent output shapes. |
| name | (optional) Name prefix for the returned tensors. |
| fn_output_signature | The output signature of fn. Must be specified if fn's input and output signatures are different (i.e., if their structures, dtypes, or tensor types do not match). fn_output_signature can be specified using any of the following: a tf.DType or tf.TensorSpec (to describe a tf.Tensor); a tf.RaggedTensorSpec (to describe a tf.RaggedTensor); a tf.SparseTensorSpec (to describe a tf.sparse.SparseTensor); or a (possibly nested) tuple, list, or dict containing the above types. |
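
To illustrate back_prop, a minimal sketch, assuming TensorFlow 2.x in eager mode: with the default back_prop=True, gradients flow through every iteration of tf.compat.v1.map_fn, so a tf.GradientTape can differentiate through it.

```python
import tensorflow as tf

x = tf.Variable([1.0, 2.0, 3.0])

with tf.GradientTape() as tape:
    # back_prop=True (the default) keeps each iteration differentiable.
    y = tf.compat.v1.map_fn(lambda t: t * t, x, back_prop=True)
    loss = tf.reduce_sum(y)

# d(sum of t^2)/dt = 2t for each element.
grad = tape.gradient(loss, x)
```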
| Returns |
|---|
| A tensor or (possibly nested) sequence of tensors. Each tensor stacks the results of applying fn to tensors unstacked from elems along the first dimension, from first to last. The result may include ragged and sparse tensors. |
| Raises | |
|---|---|
| TypeError | if fn is not callable or the structure of the output of fn and fn_output_signature do not match. |
| ValueError | if the lengths of the output of fn and fn_output_signature do not match, or if elems does not contain any tensor. |
Examples
elems = np.array([1, 2, 3, 4, 5, 6])
tf.map_fn(lambda x: x * x, elems)
<tf.Tensor: shape=(6,), dtype=int64, numpy=array([ 1, 4, 9, 16, 25, 36])>
elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
tf.map_fn(lambda x: x[0] * x[1], elems, fn_output_signature=tf.int64)
<tf.Tensor: shape=(3,), dtype=int64, numpy=array([-1, 2, -3])>
elems = np.array([1, 2, 3])
tf.map_fn(lambda x: (x, -x), elems, fn_output_signature=(tf.int64, tf.int64))
(<tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 3])>,
 <tf.Tensor: shape=(3,), dtype=int64, numpy=array([-1, -2, -3])>)
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Some content is licensed under the numpy license.
Last updated 2024-04-26 UTC.