jax.Array.squeeze — JAX documentation

jax.Array.squeeze

abstract Array.squeeze(axis=None)

Remove one or more length-1 axes from the array.

Refer to jax.numpy.squeeze() for full documentation.
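A minimal sketch of the default behavior, using only the public `jax.numpy` API: with `axis=None`, every length-1 axis is dropped, and the method produces the same result as calling `jax.numpy.squeeze()` on the array directly.

```python
import jax.numpy as jnp

# A 3-D array whose first and last axes have length 1.
x = jnp.arange(3).reshape(1, 3, 1)

# With the default axis=None, every length-1 axis is removed.
y = x.squeeze()
print(y.shape)  # (3,)

# The method form matches calling jax.numpy.squeeze on the same array.
z = jnp.squeeze(x)
print(bool(jnp.array_equal(y, z)))  # True
```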

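Parameters:

axis (int | Sequence[int] | None) – The axis or axes to remove. Each selected axis must have length 1. If None (the default), all length-1 axes are removed.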

Return type:

Array
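A short sketch of the accepted forms of `axis`: a single int, a tuple of ints, or a negative index counted from the end. It also shows that the result is a `jax.Array`. The last part selects an axis whose length is not 1. In the JAX versions we are aware of, this raises a `ValueError`. Treat that exception type as observed behavior rather than a documented guarantee.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros((1, 4, 1, 2))

# axis as a single int: remove only axis 0.
a = x.squeeze(axis=0)        # shape (4, 1, 2)
# axis as a tuple: remove axes 0 and 2.
b = x.squeeze(axis=(0, 2))   # shape (4, 2)
# Negative indices count from the end, so -2 refers to axis 2 here.
c = x.squeeze(axis=-2)       # shape (1, 4, 2)

print(a.shape, b.shape, c.shape)
print(isinstance(b, jax.Array))  # True

# Selecting an axis whose length is not 1 is rejected.
raised = False
try:
    x.squeeze(axis=1)  # axis 1 has length 4
except ValueError as err:
    raised = True
    print("ValueError:", err)
```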