tf.expand_dims  |  TensorFlow v2.16.1

tf.expand_dims


Returns a tensor with a length-1 axis inserted at index axis.

tf.expand_dims(
    input, axis, name=None
)


Given a tensor input, this operation inserts a dimension of length 1 at the dimension index axis of input's shape. The dimension index follows Python indexing rules: it is zero-based, and a negative index is counted backward from the end.
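As a minimal sketch of how a negative axis relates to a non-negative one, the snippet below converts each valid axis for a rank-D input into the equivalent position in [0, D] by adding D + 1, then checks that both calls produce the same shape. The helper as_nonnegative_axis is hypothetical and is not part of TensorFlow.

```python
import tensorflow as tf

def as_nonnegative_axis(axis, rank):
    # Hypothetical helper (not a TensorFlow API): maps an axis in
    # [-(rank + 1), rank] to the equivalent position in [0, rank].
    return axis + rank + 1 if axis < 0 else axis

image = tf.zeros([10, 10, 3])
rank = image.shape.rank  # D = 3

for axis in range(-(rank + 1), rank + 1):
    neg = tf.expand_dims(image, axis).shape.as_list()
    pos = tf.expand_dims(image, as_nonnegative_axis(axis, rank)).shape.as_list()
    # Both calls insert the new length-1 axis at the same position.
    assert neg == pos
```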

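This operation is useful to add an outer "batch" dimension to a single element, to align axes for broadcasting, or to add an inner vector-length axis to a tensor of scalars.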
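One common use is aligning axes so that two tensors broadcast against each other. The sketch below (the vectors a and b are illustrative values) builds an outer product by turning one vector into a column of shape [3, 1] and the other into a row of shape [1, 2]; elementwise multiplication then broadcasts to [3, 2].

```python
import tensorflow as tf

a = tf.constant([1.0, 2.0, 3.0])  # shape [3]
b = tf.constant([10.0, 20.0])     # shape [2]

column = tf.expand_dims(a, axis=1)  # shape [3, 1]
row = tf.expand_dims(b, axis=0)     # shape [1, 2]

# Broadcasting [3, 1] against [1, 2] yields the [3, 2] outer product.
outer = column * row
print(outer.shape.as_list())  # [3, 2]
```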

For example:

If you have a single image of shape [height, width, channels]:

image = tf.zeros([10,10,3])

You can add an outer batch axis by passing axis=0:

tf.expand_dims(image, axis=0).shape.as_list()  # [1, 10, 10, 3]

For a non-negative axis, the new axis location matches Python list.insert(axis, 1):
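A quick check of this correspondence, written as a sketch: for non-negative axis values the resulting shape agrees with list.insert on the shape list. For negative values it does not agree directly, because list.insert(-1, x) places x before the last element rather than after it; shifting the axis by D + 1 first restores the match.

```python
import tensorflow as tf

image = tf.zeros([10, 10, 3])
dims = image.shape.as_list()
rank = len(dims)

# Non-negative axes: tf.expand_dims agrees with list.insert.
for axis in range(0, rank + 1):
    expected = list(dims)
    expected.insert(axis, 1)
    assert tf.expand_dims(image, axis).shape.as_list() == expected

# Negative axes: list.insert(-1, 1) lands one slot earlier.
naive = list(dims)
naive.insert(-1, 1)                                  # [10, 10, 1, 3]
actual = tf.expand_dims(image, -1).shape.as_list()   # [10, 10, 3, 1]

# Adding rank + 1 to the negative axis gives the matching list.insert index.
shifted = list(dims)
shifted.insert(-1 + rank + 1, 1)                     # [10, 10, 3, 1]
```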

tf.expand_dims(image, axis=1).shape.as_list()  # [10, 1, 10, 3]

Following standard Python indexing rules, a negative axis counts from the end, so axis=-1 adds an innermost dimension:

tf.expand_dims(image, -1).shape.as_list()  # [10, 10, 3, 1]

This operation requires that axis is a valid index for input.shape, following Python indexing rules:

-1-tf.rank(input) <= axis <= tf.rank(input)
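To see the bounds in practice, the sketch below (assuming eager execution, the TF 2 default) calls tf.expand_dims at both ends of the range for a rank-3 input and one step past each end. The helper is_rejected is hypothetical and only wraps the call in a try/except for tf.errors.InvalidArgumentError.

```python
import tensorflow as tf

image = tf.zeros([10, 10, 3])
rank = image.shape.rank  # D = 3

# Both ends of the range -(D+1) <= axis <= D are accepted.
lowest = tf.expand_dims(image, -(rank + 1)).shape.as_list()   # [1, 10, 10, 3]
highest = tf.expand_dims(image, rank).shape.as_list()         # [10, 10, 3, 1]

def is_rejected(axis):
    # Hypothetical helper: True if eager tf.expand_dims rejects this axis.
    try:
        tf.expand_dims(image, axis)
    except tf.errors.InvalidArgumentError:
        return True
    return False

# One step past either end is out of range.
print(is_rejected(rank + 1), is_rejected(-(rank + 2)))
```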

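This operation is related to tf.squeeze, which removes dimensions of size 1, and tf.reshape, which provides more flexible reshaping capability.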
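tf.squeeze removes length-1 dimensions and tf.reshape can produce any compatible shape, so both can be set against tf.expand_dims. In the sketch below (the tensor x is an illustrative value), tf.squeeze with the same axis undoes the expansion, and tf.reshape with the target shape written out explicitly produces the same tensor as tf.expand_dims.

```python
import tensorflow as tf

x = tf.constant([[1, 2, 3], [4, 5, 6]])  # shape [2, 3]

expanded = tf.expand_dims(x, axis=1)     # shape [2, 1, 3]

# tf.squeeze on the same axis removes the inserted length-1 dimension.
restored = tf.squeeze(expanded, axis=1)  # shape [2, 3]

# tf.reshape gives the same result when the full target shape is spelled out.
reshaped = tf.reshape(x, [2, 1, 3])

same_as_input = bool(tf.reduce_all(tf.equal(restored, x)))
same_as_expand = bool(tf.reduce_all(tf.equal(reshaped, expanded)))
```

tf.expand_dims only needs the axis position, whereas tf.reshape needs the full target shape, which is why expand_dims is usually the more convenient choice for adding a single axis.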

Args
input: A Tensor.
axis: Integer specifying the dimension index at which to expand the shape of input. Given an input of D dimensions, axis must be in the range [-(D+1), D] (inclusive).
name: Optional string. The name of the output Tensor.
Returns
A tensor with the same data as input, with an additional dimension inserted at the index specified by axis.
Raises
TypeError: If axis is not specified.
InvalidArgumentError: If axis is out of range [-(D+1), D].
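Because axis has no default value, leaving it out fails before any shape check runs. The sketch below shows that omitting axis raises a TypeError rather than silently choosing a position.

```python
import tensorflow as tf

image = tf.zeros([10, 10, 3])

try:
    tf.expand_dims(image)  # axis omitted; it has no default value
    raised = False
except TypeError:
    raised = True

print(raised)
```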