jax.numpy.nanmax - JAX documentation

jax.numpy.nanmax

jax.numpy.nanmax(a, axis=None, out=None, keepdims=False, initial=None, where=None)

Return the maximum of the array elements along a given axis, ignoring NaNs.

JAX implementation of numpy.nanmax().
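The following sketch illustrates that jnp.nanmax mirrors numpy.nanmax on the same data. It assumes both NumPy and JAX are installed; the array values are arbitrary and chosen only for illustration.

```python
import numpy as np
import jax.numpy as jnp

# The same illustrative data as a NumPy array and as a JAX array.
data = [[1.0, np.nan, 3.0],
        [np.nan, 5.0, 2.0]]
x_np = np.array(data, dtype=np.float32)
x_jnp = jnp.array(data, dtype=jnp.float32)

# Both functions skip NaNs when reducing along axis 0 (one value per column).
np_result = np.nanmax(x_np, axis=0)
jnp_result = jnp.nanmax(x_jnp, axis=0)

print(np_result.tolist())               # [1.0, 5.0, 3.0]
print(np.asarray(jnp_result).tolist())  # [1.0, 5.0, 3.0]
```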

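Parameters:

a: input array.
axis: int, sequence of ints, or None (default). Axis or axes along which the maximum is computed. If None, the maximum is computed over the flattened array.
out: unused by JAX.
keepdims: bool, default False. If True, the reduced axes are kept in the output with size 1, so the result has the same ndim as the input.
initial: scalar or None (default). Initial value for the maximum; every output element is at least this value. Must be specified when where is used.
where: boolean array or None (default). Elements to include in the maximum. Must have the same shape as a or be broadcast-compatible with it.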
Returns:

An array of maximum values along the given axis, ignoring NaNs. If all values are NaNs along the given axis, returns nan.
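To make the all-NaN behavior concrete, the sketch below uses an illustrative array whose second row contains only NaNs. It also includes a conceptual reformulation (treat NaNs as -inf, take the ordinary maximum, then restore NaN for all-NaN slices). That reformulation is offered only as a way to reason about the result, not as a description of JAX's internal code.

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[1.0, nan, 3.0],
               [nan, nan, nan]])  # second row is entirely NaN

row_max = jnp.nanmax(x, axis=1)
print(row_max.tolist())  # [3.0, nan]

# Conceptual sketch: replace NaNs with -inf, reduce with the ordinary max,
# then put NaN back wherever the whole slice was NaN.
masked = jnp.where(jnp.isnan(x), -jnp.inf, x)
manual = jnp.max(masked, axis=1)
manual = jnp.where(jnp.all(jnp.isnan(x), axis=1), nan, manual)
```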

Return type:

Array

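See also

jax.numpy.nanmin(): compute the minimum of array elements along a given axis, ignoring NaNs.
jax.numpy.max(): compute the maximum of array elements along a given axis, without ignoring NaNs.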

Examples

By default, jnp.nanmax computes the maximum of elements along the flattened array.

>>> import jax.numpy as jnp
>>> nan = jnp.nan
>>> x = jnp.array([[8, nan, 4, 6],
...                [nan, -2, nan, -4],
...                [-2, 1, 7, nan]])
>>> jnp.nanmax(x)
Array(8., dtype=float32)
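As a quick check of the default behavior, the sketch below compares axis=None with an explicit reduction over the flattened (ravelled) array; it reuses the example array x and should give the same 0-d result either way.

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[8, nan, 4, 6],
               [nan, -2, nan, -4],
               [-2, 1, 7, nan]])

# axis=None (the default) reduces over every element, so it matches
# reducing the explicitly flattened 1-D array along its only axis.
full = jnp.nanmax(x)
flat = jnp.nanmax(x.ravel(), axis=0)
print(float(full), float(flat))  # 8.0 8.0
```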

With axis=1, the maximum is computed along axis 1, giving one value per row.

>>> jnp.nanmax(x, axis=1)
Array([ 8., -2.,  7.], dtype=float32)
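The same reduction works along other axes. The sketch below (reusing the example array x) takes the maximum down each column with axis=0, and passes a tuple of axes to reduce over both dimensions at once.

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[8, nan, 4, 6],
               [nan, -2, nan, -4],
               [-2, 1, 7, nan]])

# axis=0 reduces down each column: one maximum per column.
col_max = jnp.nanmax(x, axis=0)
print(col_max.tolist())  # [8.0, 1.0, 7.0, 6.0]

# A tuple of axes reduces over all of them; (0, 1) covers both axes of a 2-D array.
both = jnp.nanmax(x, axis=(0, 1))
print(float(both))  # 8.0
```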

If keepdims=True, the output has the same ndim as the input, with the reduced axis kept at size 1.

>>> jnp.nanmax(x, axis=1, keepdims=True)
Array([[ 8.],
       [-2.],
       [ 7.]], dtype=float32)
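One practical reason for keepdims=True is broadcasting: the (3, 1) result lines up with the (3, 4) input, whereas the plain (3,) result from axis=1 would not broadcast against it. The sketch below, reusing the example array x, computes how far each element sits below its row maximum; the "gap" quantity is purely illustrative.

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[8, nan, 4, 6],
               [nan, -2, nan, -4],
               [-2, 1, 7, nan]])

row_max = jnp.nanmax(x, axis=1, keepdims=True)
print(row_max.shape)  # (3, 1)

# The (3, 1) result broadcasts against the (3, 4) input.
# NaN entries of x stay NaN in the difference.
gap = row_max - x
print(gap[0].tolist())  # [0.0, nan, 4.0, 2.0]
```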

To include only specific elements in computing the maximum, use where. Because the maximum has no identity element, initial must also be given when where is used. The mask can have the same shape as the input:

>>> where = jnp.array([[0, 0, 1, 0],
...                    [0, 0, 1, 1],
...                    [1, 1, 1, 0]], dtype=bool)
>>> jnp.nanmax(x, axis=1, keepdims=True, initial=0, where=where)
Array([[4.],
       [0.],
       [7.]], dtype=float32)
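The sketch below checks what happens when initial is omitted, and how the choice of initial affects a row whose selected elements are all small or NaN. It reuses the example x and mask; catching ValueError reflects the error JAX raises for a masked max reduction without initial.

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[8, nan, 4, 6],
               [nan, -2, nan, -4],
               [-2, 1, 7, nan]])
where = jnp.array([[0, 0, 1, 0],
                   [0, 0, 1, 1],
                   [1, 1, 1, 0]], dtype=bool)

# Maximum has no identity element, so a masked reduction needs a fallback
# value. Passing `where` without `initial` is rejected.
try:
    jnp.nanmax(x, axis=1, where=where)
    raised = False
except ValueError:
    raised = True
print(raised)  # True

# With initial=-inf, each selected row maximum is left unchanged.
result = jnp.nanmax(x, axis=1, initial=-jnp.inf, where=where)
print(result.tolist())  # [4.0, -4.0, 7.0]
```

With initial=0, row 1 (which selects only NaN and -4) yields 0; with initial=-inf, it yields -4, its true masked maximum.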

Alternatively, the mask can be any boolean array that is broadcast-compatible with the input:

>>> where = jnp.array([[True],
...                    [False],
...                    [False]])
>>> jnp.nanmax(x, axis=0, keepdims=True, initial=0, where=where)
Array([[8., 0., 4., 6.]], dtype=float32)
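Broadcasting also works along the other axis. In the sketch below, a 1-D mask of shape (4,) is broadcast to (3, 4), so the same columns (0 and 3, an arbitrary illustrative choice) are selected in every row; initial=-inf is used so that it does not clamp any row maximum.

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[8, nan, 4, 6],
               [nan, -2, nan, -4],
               [-2, 1, 7, nan]])

# A shape-(4,) mask broadcasts to (3, 4) along the leading axis,
# so columns 0 and 3 are selected in every row.
col_mask = jnp.array([True, False, False, True])
result = jnp.nanmax(x, axis=1, initial=-jnp.inf, where=col_mask)
print(result.tolist())  # [8.0, -4.0, -2.0]
```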