jax.Array.choose — JAX documentation
jax.Array.choose
abstract Array.choose(choices, out=None, mode='raise')
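Construct an array by choosing from the elements of multiple arrays. Each element of this array is an integer index that selects which array in choices supplies the output element at the same position.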
Refer to jax.numpy.choose() for the full documentation.
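A minimal sketch of the method form, assuming a recent JAX release. The names idx and choices are illustrative only. Each entry of idx picks which of the three candidate arrays supplies the value at that position, and the result should match a direct call to jax.numpy.choose().

```python
import jax.numpy as jnp

# Illustrative integer indices: each entry picks which array in `choices`
# supplies the output value at that position.
idx = jnp.array([0, 2, 1, 0])

# Three illustrative candidate arrays with the same shape as `idx`.
choices = [
    jnp.array([10, 11, 12, 13]),
    jnp.array([20, 21, 22, 23]),
    jnp.array([30, 31, 32, 33]),
]

# Method form; should agree with jnp.choose(idx, choices).
result = idx.choose(choices)
same_as_function = bool((result == jnp.choose(idx, choices)).all())

print(result)  # [10 31 22 13]
```

Position 1 holds index 2, so it takes choices[2][1] (31). Position 2 holds index 1, so it takes choices[1][2] (22).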
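Parameters:
  choices: the arrays to choose from. This array and the arrays in choices are broadcast to a common shape.
  out: not supported by JAX and must be None, since JAX arrays are immutable.
  mode: how out-of-range indices are handled. It is one of 'raise' (the default), 'wrap', or 'clip'.
Return type:
  Array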
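The sketch below contrasts the out-of-bounds modes. The index and choice values and the helper pick are hypothetical.

Our understanding of these modes, which may differ across JAX versions:
- 'wrap' reduces indices modulo the number of choices.
- 'clip' clamps indices to the valid range.
- The default 'raise' needs concrete index values to check bounds, so it cannot be used inside jax.jit.

Treat the jax.jit point as an assumption to check against jax.numpy.choose().

```python
import jax
import jax.numpy as jnp

# Two hypothetical choice arrays, so valid indices are 0 and 1.
choices = [
    jnp.array([10, 11, 12, 13]),
    jnp.array([20, 21, 22, 23]),
]

# Entries 2 and -1 are out of range for two choices.
idx = jnp.array([0, 1, 2, -1])

# 'wrap': indices taken modulo len(choices), so 2 -> 0 and -1 -> 1.
wrapped = idx.choose(choices, mode='wrap')

# 'clip': indices clamped to [0, len(choices) - 1], so 2 -> 1 and -1 -> 0.
clipped = idx.choose(choices, mode='clip')

# Hypothetical jitted helper. Assumption: mode='raise' needs concrete
# indices for its bounds check, so 'clip' is used under jax.jit instead.
@jax.jit
def pick(indices):
    return indices.choose(choices, mode='clip')

print(wrapped)      # [10 21 12 23]
print(clipped)      # [10 21 22 13]
print(pick(idx))    # [10 21 22 13]
```

Either 'wrap' or 'clip' keeps the computation traceable. The choice between them depends on whether out-of-range indices should cycle through the choices or stick to the first and last one.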