jax.numpy.where — JAX documentation
jax.numpy.where
jax.numpy.where(condition, x=None, y=None, /, *, size=None, fill_value=None)
Select elements from two arrays based on a condition.
JAX implementation of numpy.where().
Note
When only condition is provided, jnp.where(condition) is equivalent to jnp.nonzero(condition). For that case, refer to the documentation of jax.numpy.nonzero(). The docstring below focuses on the case where x and y are specified.
The three-term version of jnp.where lowers to jax.lax.select().
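The following is a minimal sketch of what the lowering means in practice; the arrays are illustrative. When the operands already share a shape and dtype, jnp.where(condition, x, y) and jax.lax.select(condition, x, y) produce the same result. The practical difference is that lax.select does not broadcast or promote dtypes, so a scalar fill value has to be expanded by hand first.

```python
import jax.numpy as jnp
from jax import lax

cond = jnp.array([True, False, True])
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([10.0, 20.0, 30.0])

# Three-argument jnp.where: take x where cond is True, otherwise y.
via_where = jnp.where(cond, x, y)

# lax.select uses the same (predicate, on_true, on_false) order, but expects
# on_true and on_false to already share one shape and dtype.
via_select = lax.select(cond, x, y)

# jnp.where broadcasts the scalar 0.0 to the shape of x automatically...
zeroed_where = jnp.where(cond, x, 0.0)
# ...while for lax.select the fill value is materialized at full shape first.
zeroed_select = lax.select(cond, x, jnp.zeros_like(x))
```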
Parameters:
- condition – boolean array. Must be broadcast-compatible with x and y when they are specified.
- x – arraylike. Should be broadcast-compatible with condition and y, and typecast-compatible with y.
- y – arraylike. Should be broadcast-compatible with condition and x, and typecast-compatible with x.
- size – integer, only referenced when x and y are None. For details, see jax.numpy.nonzero().
- fill_value – only referenced when x and y are None. For details, see jax.numpy.nonzero().
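This sketch shows why size and fill_value matter in the one-argument form; the input values are illustrative. The number of matches depends on the data, so inside jax.jit the output length has to be fixed ahead of time. size sets that length. Extra matches are truncated, and when there are too few matches the result is padded with fill_value.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(10)

@jax.jit
def first_three_above_four(v):
    # More than 3 elements satisfy v > 4; size=3 keeps only the first three indices.
    return jnp.where(v > 4, size=3)

@jax.jit
def padded_above_seven(v):
    # Only 2 elements satisfy v > 7; the result is padded to size=4 using fill_value.
    return jnp.where(v > 7, size=4, fill_value=-1)

(idx,) = first_three_above_four(x)
(padded,) = padded_above_seven(x)
```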
Returns:
An array of dtype jnp.result_type(x, y) with values drawn from x where condition is True, and from y where condition is False. If x and y are None, the function behaves differently; see jax.numpy.nonzero() for a description of the return type.
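The next sketch makes the return shape and dtype concrete, using arbitrarily chosen inputs. condition, x and y are broadcast against each other, and the result dtype is jnp.result_type(x, y). A strongly typed float32 scalar promotes an int32 array to float32. A Python int is weakly typed, so it leaves the int32 dtype unchanged.

```python
import jax.numpy as jnp

cond = jnp.array([[True], [False]])          # shape (2, 1)
x = jnp.array([1, 2, 3], dtype=jnp.int32)    # shape (3,)
y = jnp.float32(0.5)                         # strongly typed float32 scalar

# cond, x and y broadcast together to shape (2, 3): row 0 is drawn from x,
# row 1 from y. The dtype is jnp.result_type(x, y), which is float32 here.
mixed = jnp.where(cond, x, y)

# A Python int is weakly typed and does not promote x, so the result stays int32.
kept = jnp.where(x > 1, x, 0)
```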
Notes
Special care is needed when the x or y input to jax.numpy.where() could have a value of NaN. Specifically, when a gradient is taken with jax.grad() (reverse-mode differentiation), a NaN in either x or y will propagate into the gradient, regardless of the value of condition. More information on this behavior and workarounds is available in the JAX FAQ.
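Here is a sketch of the problem and of one common workaround in the spirit of the FAQ, often called a "double where". The function names are hypothetical. In the backward pass, the unselected branch receives a zero cotangent, and that zero is then multiplied by the branch's own derivative. If that derivative is NaN, 0 * NaN is still NaN. Guarding the input with an inner where keeps the unselected branch finite.

```python
import jax
import jax.numpy as jnp

def naive_sqrt(x):
    # For x < 0, jnp.sqrt(x) is NaN. The forward value is still 0.0, but the
    # NaN derivative of the unselected branch reaches the gradient.
    return jnp.where(x > 0, jnp.sqrt(x), 0.0)

def guarded_sqrt(x):
    # Replace unsafe inputs with a harmless value (1.0) before calling sqrt,
    # so the unselected branch stays finite; then select as before.
    safe_x = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, jnp.sqrt(safe_x), 0.0)

grad_naive = jax.grad(naive_sqrt)(-1.0)      # NaN despite condition being False
grad_guarded = jax.grad(guarded_sqrt)(-1.0)  # finite
```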
Examples
When x and y are not provided, where behaves equivalently to jax.numpy.nonzero():
```python
>>> x = jnp.arange(10)
>>> jnp.where(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
>>> jnp.nonzero(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
```
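With a multidimensional condition, the one-argument form returns one index array per dimension, just as jax.numpy.nonzero() does. The sketch below uses an illustrative 2×2 array and shows that indexing with the returned arrays gathers the selected entries.

```python
import jax.numpy as jnp

m = jnp.array([[0, 3], [4, 0]])
# One index array per dimension: row indices and column indices of m > 0.
rows, cols = jnp.where(m > 0)
# Indexing with the index arrays gathers the entries where the condition holds.
selected = m[rows, cols]
```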
When x and y are provided, where selects between them based on the specified condition:
```python
>>> jnp.where(x > 4, x, 0)
Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32)
```
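When x and y are both arrays of the same shape, the same call chooses between them element by element. In this illustrative sketch, the condition a > b makes the result coincide with jnp.maximum(a, b).

```python
import jax.numpy as jnp

a = jnp.array([1, 5, 3])
b = jnp.array([4, 2, 6])
# Take the element of a where it is larger, otherwise the element of b.
larger = jnp.where(a > b, a, b)
```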