jax.numpy.where — JAX documentation

jax.numpy.where

jax.numpy.where(condition, x=None, y=None, /, *, size=None, fill_value=None)

Select elements from two arrays based on a condition.

JAX implementation of numpy.where().

Note

When only condition is provided, jnp.where(condition) is equivalent to jnp.nonzero(condition). For that case, refer to the documentation of jax.numpy.nonzero(). The docstring below focuses on the case where x and y are specified.

The three-argument form of jnp.where lowers to jax.lax.select().
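As a rough illustration of that relationship, the sketch below compares jnp.where with a direct call to jax.lax.select on inputs that already share a shape and dtype, which lax.select requires. jnp.where additionally handles broadcasting and dtype promotion for you. The arrays are arbitrary examples.

```python
import jax
import jax.numpy as jnp

# lax.select expects a boolean predicate and two operands with matching
# shapes and dtypes, so the inputs below are chosen to satisfy that.
cond = jnp.array([True, False, True])
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([10.0, 20.0, 30.0])

via_where = jnp.where(cond, x, y)        # x[i] where cond[i], else y[i]
via_select = jax.lax.select(cond, x, y)  # the lower-level primitive call

print(via_where)  # values 1.0, 20.0, 3.0
print(jnp.array_equal(via_where, via_select))
```

For inputs of differing shapes or dtypes, jnp.where is the more convenient entry point, since lax.select does not broadcast or promote.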

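Parameters:

condition: boolean condition. When x and y are specified, it must be broadcast-compatible with them.

x: array-like values selected where condition is True. Must be broadcast-compatible with condition and y, and type-promotable with y.

y: array-like values selected where condition is False. Must be broadcast-compatible with condition and x, and type-promotable with x.

size: integer; used only when x and y are None. See jax.numpy.nonzero() for details.

fill_value: used only when x and y are None. See jax.numpy.nonzero() for details.
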
Returns:

An array of dtype jnp.result_type(x, y) with values drawn from x where condition is True, and from y where condition is False. If x and y are both None, the function instead behaves like jax.numpy.nonzero(); see its documentation for a description of the return type.
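A small sketch of the broadcasting and dtype rules described above, assuming JAX's default 32-bit mode: an integer x and a float32 y are broadcast against a column-shaped condition, and the result takes the promoted dtype jnp.result_type(x, y). The specific shapes and values are illustrative only.

```python
import jax.numpy as jnp

cond = jnp.array([[True], [False]])           # shape (2, 1)
x = jnp.arange(3)                             # shape (3,), integer dtype
y = jnp.full((2, 3), 0.5, dtype=jnp.float32)  # shape (2, 3), float32

out = jnp.where(cond, x, y)

# All three inputs broadcast to a common shape of (2, 3).
print(out.shape)
# The output dtype is the promotion of x's and y's dtypes.
print(out.dtype, jnp.result_type(x, y))
# Row 0 comes from x (cond is True), row 1 from y (cond is False).
print(out)
```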

Notes

Special care is needed when the x or y input to jax.numpy.where() could have a value of NaN. Specifically, when a gradient is taken with jax.grad() (reverse-mode differentiation), a NaN in either x or y will propagate into the gradient, regardless of the value of condition. More information on this behavior and workarounds is available in the JAX FAQ.
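The sketch below illustrates the problem with a common pattern, a guarded logarithm, and one workaround sometimes called the "double where" trick: sanitize the input before the unsafe operation so that neither branch produces non-finite values. The function names are hypothetical; the JAX FAQ remains the authoritative discussion.

```python
import jax
import jax.numpy as jnp

def unsafe_log(x):
    # Both branches are computed. At x == 0, jnp.log(x) is -inf and its
    # derivative 1/x is infinite, so the zero cotangent reaching the
    # unselected branch turns into NaN during the backward pass.
    return jnp.where(x > 0, jnp.log(x), 0.0)

def safe_log(x):
    # "Double where": first replace inputs outside the safe domain with a
    # harmless value (1.0), so log and its derivative stay finite everywhere.
    safe_x = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, jnp.log(safe_x), 0.0)

print(jax.grad(unsafe_log)(0.0))  # nan
print(jax.grad(safe_log)(0.0))    # 0.0
print(jax.grad(safe_log)(2.0))    # 0.5
```

The forward values of the two functions agree; only the gradient at the guarded points differs.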

Examples

When x and y are not provided, where behaves equivalently to jax.numpy.nonzero():

>>> import jax.numpy as jnp
>>> x = jnp.arange(10)
>>> jnp.where(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
>>> jnp.nonzero(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
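Because the number of True entries is data-dependent, the single-argument form needs a static size when used inside jax.jit. The sketch below assumes size and fill_value behave as documented for jax.numpy.nonzero(): the result is truncated or padded to length size, with padding taken from fill_value. The helper name is hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def first_three_above_four(x):
    # Without size=..., the output length would depend on the data and
    # tracing under jit would fail. With size=3 the result always has
    # length 3: extra matches are dropped, missing ones padded with -1.
    (idx,) = jnp.where(x > 4, size=3, fill_value=-1)
    return idx

print(first_three_above_four(jnp.arange(10)))  # indices 5, 6, 7
print(first_three_above_four(jnp.arange(6)))   # index 5, then two -1 pads
```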

When x and y are provided, where selects between them based on the specified condition:

>>> jnp.where(x > 4, x, 0)
Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32)
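In contrast to the single-argument form, the three-argument form has an output shape fixed by broadcasting, so it can be used under jax.jit without a size argument. A minimal sketch, using a hypothetical ReLU-style helper:

```python
import jax
import jax.numpy as jnp

@jax.jit
def clip_negative(x):
    # The output shape is the broadcast shape of (x > 0, x, 0), which is
    # x.shape, so nothing data-dependent needs to be declared under jit.
    return jnp.where(x > 0, x, 0)

out = clip_negative(jnp.array([-1.5, 0.0, 2.0]))
print(out)        # values 0.0, 0.0, 2.0
print(out.dtype)  # the Python scalar 0 does not change x's float dtype
```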