jax.numpy.squeeze - JAX documentation

jax.numpy.squeeze

jax.numpy.squeeze(a, axis=None)

Remove one or more length-1 axes from an array.

JAX implementation of numpy.squeeze(), implemented via jax.lax.squeeze().

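Parameters:

a – input array.

axis – integer or sequence of integers specifying the axes to remove. Each specified axis must have length 1, otherwise an error is raised. If None (the default), all length-1 axes are removed.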

Returns:

copy of a with length-1 axes removed.

Return type:

Array

Notes

Unlike numpy.squeeze(), jax.numpy.squeeze() returns a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this has no performance impact in practice.
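The copy-versus-view distinction can be sketched as follows. This is only an illustration, assuming NumPy and JAX are both installed; the variable names are arbitrary. It shows that a NumPy squeeze result shares memory with its input, while a JAX array is immutable, so an update through .at[...].set() produces a new array and the squeezed input is never modified.

```python
import numpy as np
import jax.numpy as jnp

# NumPy: squeeze returns a view, so writes through the result
# are visible in the original array.
a = np.zeros((3, 1))
b = np.squeeze(a)
shares = np.shares_memory(a, b)
b[0] = 1.0
print(shares, a[0, 0])

# JAX: arrays are immutable. An "update" builds a new array,
# leaving both the squeezed result and the original untouched.
x = jnp.zeros((3, 1))
y = jnp.squeeze(x)
y_updated = y.at[0].set(1.0)
print(x[0, 0], y[0], y_updated[0])
```

Because JAX arrays cannot be mutated in place, code that relied on NumPy's view semantics to write back into the original array needs to be restructured rather than ported directly.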

See also

Examples

>>> x = jnp.array([[[0]], [[1]], [[2]]])
>>> x.shape
(3, 1, 1)

Squeeze all length-1 dimensions:

>>> jnp.squeeze(x)
Array([0, 1, 2], dtype=int32)
>>> _.shape
(3,)

Equivalent while specifying the axes explicitly:

>>> jnp.squeeze(x, axis=(1, 2))
Array([0, 1, 2], dtype=int32)

Attempting to squeeze a non-unit axis results in an error:

>>> jnp.squeeze(x, axis=0)
Traceback (most recent call last):
  ...
ValueError: cannot select an axis to squeeze out which has size not equal to one, got shape=(3, 1, 1) and dimensions=(0,)
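When the input shape is not known ahead of time, one way to avoid this error is to filter the requested axes before squeezing. The sketch below is illustrative only: squeeze_if_unit is a hypothetical helper written for this example, not part of JAX.

```python
import jax.numpy as jnp

x = jnp.array([[[0]], [[1]], [[2]]])  # shape (3, 1, 1)

# Squeezing a non-unit axis raises ValueError, as described above.
try:
    jnp.squeeze(x, axis=0)
    raised = False
except ValueError as err:
    raised = True
    print("squeeze failed:", err)

def squeeze_if_unit(a, axes):
    """Hypothetical helper: squeeze only those requested axes of length 1."""
    unit_axes = tuple(ax for ax in axes if a.shape[ax] == 1)
    if not unit_axes:
        return a  # nothing to squeeze
    return jnp.squeeze(a, axis=unit_axes)

# Axis 0 has length 3 and is skipped; axes 1 and 2 are removed.
y = squeeze_if_unit(x, (0, 1, 2))
print(y.shape)
```

Array shapes are static in JAX, so a check like this runs in ordinary Python and does not depend on array values.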

For convenience, this functionality is also available via the jax.Array.squeeze() method:

>>> x.squeeze()
Array([0, 1, 2], dtype=int32)
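As noted in the description, jnp.squeeze is implemented via jax.lax.squeeze(). The following sketch compares the two calls on the same input. One difference worth noting is that jax.lax.squeeze() takes an explicit dimensions argument, whereas jnp.squeeze can infer the length-1 axes when axis is None. The variable names are arbitrary.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([[[0]], [[1]], [[2]]])  # shape (3, 1, 1)

# NumPy-style API: the axes may be omitted or given explicitly.
via_jnp = jnp.squeeze(x, axis=(1, 2))

# Lower-level lax API: the dimensions to remove are always spelled out.
via_lax = lax.squeeze(x, dimensions=(1, 2))

print(via_jnp.shape, via_lax.shape)
print(bool(jnp.array_equal(via_jnp, via_lax)))
```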