jax.Array.mean - JAX documentation
jax.Array.mean
abstract Array.mean(axis=None, dtype=None, out=None, keepdims=False, *, where=None)
Return the mean of array elements along a given axis.
Refer to jax.numpy.mean() for the full documentation.
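A minimal sketch of the method with the default and explicit `axis` values, assuming a standard JAX installation; the array `x` is illustrative only. The last line checks that the method agrees with `jax.numpy.mean()`, which the reference above points to.

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

# With the default axis=None, the mean is taken over every element.
overall = x.mean()            # 3.5

# axis=0 averages down each column; axis=1 averages across each row.
col_means = x.mean(axis=0)    # [2.5, 3.5, 4.5]
row_means = x.mean(axis=1)    # [2.0, 5.0]

# The method form should match the function form in jax.numpy.
assert jnp.allclose(x.mean(axis=1), jnp.mean(x, axis=1))
```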
Parameters:
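- self (Array): the input array.
- axis (reductions.Axis): an int, a sequence of ints, or None (the default). None averages over all elements.
- dtype (DTypeLike | None): dtype of the result. If None, integer inputs produce a floating-point result.
- out (None): not supported by JAX; must be None.
- keepdims (bool): if True, reduced axes are kept in the result with size 1.
- where (ArrayLike | None): optional boolean mask; only elements where it is True are included in the mean.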
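The following sketch exercises `keepdims`, `where`, and `dtype` on a small integer array, assuming a standard JAX installation with the default 32-bit mode; the arrays and the centering step are illustrative, not part of the API.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])   # integer input

# Integer inputs are averaged to a floating-point result by default.
m = x.mean()                                  # 3.5

# keepdims=True keeps the reduced axis with size 1, so the result
# broadcasts back against x (here, to center each row).
row_means = x.mean(axis=1, keepdims=True)     # shape (2, 1)
centered = x - row_means                      # shape (2, 3)

# where= restricts the mean to elements where the mask is True.
mask = x > 2
masked_mean = x.mean(where=mask)              # mean of [3, 4, 5, 6] = 4.5

# dtype= requests the dtype of the result.
m16 = x.mean(dtype=jnp.float16)
```

Because `keepdims=True` preserves the reduced axis, no explicit reshape is needed before subtracting the row means.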
Return type:
Array