jax.Array.squeeze - JAX documentation
jax.Array.squeeze
abstract Array.squeeze(axis=None)
Remove one or more length-1 axes from the array.
Refer to jax.numpy.squeeze() for full documentation.
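The following is a minimal sketch of the default behavior. It assumes a recent JAX release with `jax.numpy` importable, and the array `x` is purely illustrative. When `axis` is left as `None`, every length-1 axis is dropped, and the method returns the same result as calling the `jax.numpy.squeeze()` function on the array.

```python
import jax.numpy as jnp

# Illustrative input with two length-1 axes: shape (1, 3, 1)
x = jnp.arange(3).reshape(1, 3, 1)

# With axis=None (the default), every length-1 axis is removed
y = x.squeeze()
print(y.shape)  # (3,)

# The method form matches the jax.numpy.squeeze() function
z = jnp.squeeze(x)
print(bool((y == z).all()))  # True
```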
Parameters:
- self (Array)
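- axis (reductions.Axis) – an integer or sequence of integers naming the axes to remove; each named axis must have length 1. If None (the default), all length-1 axes are removed.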
Return type:
Array
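The next sketch shows how the `axis` argument narrows which axes are removed. It assumes the semantics of `jax.numpy.squeeze()`, which this method defers to: an integer or a tuple of integers selects axes, negative indices count from the end as in NumPy, and selecting an axis whose length is not 1 raises an error (a `ValueError` in current JAX releases). The arrays `x`, `a`, `b`, and `c` are illustrative.

```python
import jax.numpy as jnp

# Illustrative input with length-1 axes at positions 0 and 2
x = jnp.zeros((1, 3, 1, 4))

# Remove only axis 0; the length-1 axis at position 2 is kept
a = x.squeeze(axis=0)
print(a.shape)  # (3, 1, 4)

# A tuple removes several axes at once
b = x.squeeze(axis=(0, 2))
print(b.shape)  # (3, 4)

# Negative indices count from the end: -2 refers to axis 2 here
c = x.squeeze(axis=-2)
print(c.shape)  # (1, 3, 4)

# Selecting an axis whose length is not 1 is an error
try:
    x.squeeze(axis=1)
except ValueError as err:
    print("ValueError:", err)
```

Passing an explicit `axis` is often safer than relying on the default, because it cannot silently drop a dimension that only happens to have length 1 for a particular input.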