tf.ensure_shape  |  TensorFlow v2.16.1

Updates the shape of a tensor and checks at runtime that the shape holds.

Compat aliases for migration (see the Migration guide for more details):

tf.compat.v1.ensure_shape

tf.ensure_shape(
    x, shape, name=None
)

When executed, this operation asserts that the input tensor x's shape is compatible with the shape argument. See tf.TensorShape.is_compatible_with for details.

>>> x = tf.constant([[1, 2, 3], [4, 5, 6]])
>>> x = tf.ensure_shape(x, [2, 3])

Use None for unknown dimensions:

>>> x = tf.ensure_shape(x, [None, 3])
>>> x = tf.ensure_shape(x, [2, None])

If the tensor's shape is not compatible with the shape argument, an error is raised:

>>> x = tf.ensure_shape(x, [5])
Traceback (most recent call last):
...
tf.errors.InvalidArgumentError: Shape of tensor dummy_input [2,3] is not compatible with expected shape [5]. [Op:EnsureShape]
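The compatibility rule that decides whether the examples above pass or fail can be sketched in plain Python. This is a simplified model of the semantics described for tf.TensorShape.is_compatible_with, not TensorFlow's implementation; the helpers shapes_compatible and dims_compatible are hypothetical, and a shape is modeled as a list of ints with None for an unknown dimension, or None for an unknown rank.

```python
from typing import Optional, Sequence

# Hypothetical model of a shape: a sequence of dimensions (int, or None when
# unknown), or None when even the rank is unknown.
Shape = Optional[Sequence[Optional[int]]]


def dims_compatible(a: Optional[int], b: Optional[int]) -> bool:
    """Two dimensions are compatible if either is unknown or they are equal."""
    return a is None or b is None or a == b


def shapes_compatible(actual: Shape, expected: Shape) -> bool:
    """Sketch of the rule: an unknown rank matches anything; otherwise the
    ranks must agree and every pair of dimensions must be compatible."""
    if actual is None or expected is None:
        return True
    if len(actual) != len(expected):
        return False
    return all(dims_compatible(a, b) for a, b in zip(actual, expected))


print(shapes_compatible([2, 3], [2, 3]))     # True
print(shapes_compatible([2, 3], [None, 3]))  # True
print(shapes_compatible([2, 3], [2, None]))  # True
print(shapes_compatible([2, 3], [5]))        # False: the ranks differ
```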

During graph construction (typically when tracing a tf.function), tf.ensure_shape updates the static shape of the result tensor by merging the two shapes. See tf.TensorShape.merge_with for details.

This is most useful when you know a shape that can't be determined statically by TensorFlow.
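A similar plain-Python sketch shows how merging refines a static shape: each unknown dimension takes the known value from the other shape, and two known but different dimensions are an error. The helper merge_shapes is hypothetical; it raises ValueError on a conflict, mirroring tf.TensorShape.merge_with, but it models the rule rather than reproducing TensorFlow's code.

```python
from typing import List, Optional, Sequence

# Hypothetical shape model: dims are int or None; a None shape is unknown rank.
Shape = Optional[Sequence[Optional[int]]]


def merge_dims(a: Optional[int], b: Optional[int]) -> Optional[int]:
    """Keep whichever dimension is known; reject two different known values."""
    if a is None:
        return b
    if b is None:
        return a
    if a != b:
        raise ValueError(f"Dimensions {a} and {b} are not compatible")
    return a


def merge_shapes(static: Shape, expected: Shape) -> Optional[List[Optional[int]]]:
    """Sketch of merging a traced (static) shape with an expected shape."""
    if static is None:
        return None if expected is None else list(expected)
    if expected is None:
        return list(static)
    if len(static) != len(expected):
        raise ValueError(f"Ranks differ: {list(static)} vs {list(expected)}")
    return [merge_dims(a, b) for a, b in zip(static, expected)]


print(merge_shapes([None, None], [None, 3]))        # [None, 3]
print(merge_shapes([None, None, 3], [28, 28, 3]))   # [28, 28, 3]
```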

The following trivial tf.function prints the input tensor's static shape before and after ensure_shape is applied.

>>> @tf.function
... def f(tensor):
...   print("Static-shape before:", tensor.shape)
...   tensor = tf.ensure_shape(tensor, [None, 3])
...   print("Static-shape after:", tensor.shape)
...   return tensor

This lets you see the effect of tf.ensure_shape when the function is traced:

>>> cf = f.get_concrete_function(tf.TensorSpec([None, None]))
Static-shape before: (None, None)
Static-shape after: (None, 3)

>>> cf(tf.zeros([3, 3]))  # Passes
>>> cf(tf.constant([1, 2, 3]))  # Fails
Traceback (most recent call last):
...
InvalidArgumentError: Shape of tensor x [3] is not compatible with expected shape [?,3].

The above example raises tf.errors.InvalidArgumentError because the input's shape, (3,), is not compatible with the shape argument, (None, 3).
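Because the runtime check fails while the graph executes, the failure surfaces as a tf.errors.InvalidArgumentError that can be caught like any other Python exception. The following minimal sketch assumes TensorFlow 2.x with eager execution enabled; require_three_columns and passes are illustrative names, not TensorFlow APIs. The input_signature keeps the traced shape at (None, None) so that only the runtime check can reject an input; if the signature were instead inferred from a tensor whose static shape is already incompatible, the error may occur while tracing rather than at runtime.

```python
import tensorflow as tf


# Illustrative function: the second dimension is unknown when tracing,
# so the [None, 3] requirement can only be enforced at runtime.
@tf.function(input_signature=[tf.TensorSpec([None, None], tf.float32)])
def require_three_columns(tensor):
    return tf.ensure_shape(tensor, [None, 3])


def passes(tensor) -> bool:
    """Returns True if the runtime shape check inside the graph succeeds."""
    try:
        require_three_columns(tensor)
        return True
    except tf.errors.InvalidArgumentError:
        return False


print(passes(tf.zeros([3, 3])))  # True
print(passes(tf.zeros([2, 4])))  # False
```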

Inside a tf.function or v1.Graph context, it checks both the build-time and runtime shapes. This is stricter than tf.Tensor.set_shape, which only checks the build-time shape.
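The difference from tf.Tensor.set_shape can be seen with a one-dimensional input whose length is only known at runtime. In this sketch (TensorFlow 2.x assumed; both function names are illustrative), each function declares a length of 3 at build time, but only the ensure_shape version rejects a length-4 input when it runs; the set_shape version sums all four elements without complaint.

```python
import tensorflow as tf

spec = tf.TensorSpec([None], tf.float32)


@tf.function(input_signature=[spec])
def sum_with_set_shape(t):
    t.set_shape([3])  # Updates the build-time (static) shape only.
    return tf.reduce_sum(t)


@tf.function(input_signature=[spec])
def sum_with_ensure_shape(t):
    t = tf.ensure_shape(t, [3])  # Updates the static shape and checks at runtime.
    return tf.reduce_sum(t)


four = tf.constant([1.0, 2.0, 3.0, 4.0])

# The wrong length goes unnoticed: all four elements are summed.
print(sum_with_set_shape(four).numpy())  # 10.0

try:
    sum_with_ensure_shape(four)
except tf.errors.InvalidArgumentError:
    print("ensure_shape rejected the length-4 input")
```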

For example, when loading images of a known size:

>>> @tf.function
... def decode_image(png):
...   image = tf.image.decode_png(png, channels=3)
...   # the `print` executes during tracing.
...   print("Initial shape: ", image.shape)
...   image = tf.ensure_shape(image, [28, 28, 3])
...   print("Final shape: ", image.shape)
...   return image

When a function is traced, no ops are executed, so some shapes may be unknown. See the Concrete Functions Guide for details.

>>> concrete_decode = decode_image.get_concrete_function(
...     tf.TensorSpec([], dtype=tf.string))
Initial shape: (None, None, 3)
Final shape: (28, 28, 3)

>>> image = tf.random.uniform(maxval=255, shape=[28, 28, 3], dtype=tf.int32)
>>> image = tf.cast(image, tf.uint8)
>>> png = tf.image.encode_png(image)
>>> image2 = concrete_decode(png)
>>> print(image2.shape)
(28, 28, 3)

>>> image = tf.concat([image, image], axis=0)
>>> print(image.shape)
(56, 28, 3)
>>> png = tf.image.encode_png(image)
>>> image2 = concrete_decode(png)
Traceback (most recent call last):
...
tf.errors.InvalidArgumentError: Shape of tensor DecodePng [56,28,3] is not compatible with expected shape [28,28,3].

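tf.ensure_shape returns a new tensor; it does not modify its input. If the returned tensor is discarded, as in the following function, the static shape is not updated and the mismatched image is not rejected:

>>> @tf.function
... def bad_decode_image(png):
...   image = tf.image.decode_png(png, channels=3)
...   # the `print` executes during tracing.
...   print("Initial shape: ", image.shape)
...   # BAD: forgot to use the returned tensor.
...   tf.ensure_shape(image, [28, 28, 3])
...   print("Final shape: ", image.shape)
...   return image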

>>> image = bad_decode_image(png)
Initial shape: (None, None, 3)
Final shape: (None, None, 3)
>>> print(image.shape)
(56, 28, 3)
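The same pattern is common in tf.data input pipelines, where the map function is traced once and decode ops such as tf.image.decode_png produce partially unknown shapes. A minimal sketch, assuming TensorFlow 2.x and using all-zero PNGs as stand-in data (decode is an illustrative name): without the tf.ensure_shape call, the batched element spec would be (None, None, None, 3) rather than (None, 28, 28, 3).

```python
import tensorflow as tf


def decode(png):
    image = tf.image.decode_png(png, channels=3)  # Static shape: (None, None, 3)
    return tf.ensure_shape(image, [28, 28, 3])    # Static shape: (28, 28, 3)


# Stand-in data: two PNG-encoded, all-zero 28x28 RGB images.
pngs = tf.stack([tf.image.encode_png(tf.zeros([28, 28, 3], tf.uint8))] * 2)

ds = tf.data.Dataset.from_tensor_slices(pngs).map(decode).batch(2)
print(ds.element_spec.shape)  # (None, 28, 28, 3)
```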

Args
x: A Tensor.
shape: A TensorShape representing the shape of this tensor, a TensorShapeProto, a list, a tuple, or None.
name: A name for this operation (optional). Defaults to "EnsureShape".
Returns
A Tensor. Has the same type and contents as x.
Raises
tf.errors.InvalidArgumentError: If shape is incompatible with the shape of x.
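As a quick check of the argument and return types listed above, the following sketch (TensorFlow 2.x assumed) passes four of the accepted forms of shape and confirms that each result has the same dtype and contents as x. A TensorShapeProto is also accepted but is omitted here; passing None describes an unknown rank, which is compatible with any tensor.

```python
import tensorflow as tf

x = tf.constant([[1, 2, 3], [4, 5, 6]])

# Four accepted forms of `shape`, all compatible with x's shape (2, 3):
# a list, a tuple, a TensorShape, and None (unknown rank).
shapes = [[2, 3], (2, None), tf.TensorShape([None, 3]), None]

results = [tf.ensure_shape(x, s) for s in shapes]
for y in results:
    # The result has the same dtype and contents as x.
    assert y.dtype == x.dtype
    assert bool(tf.reduce_all(y == x))
```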