jax.Array.choose (JAX documentation)

jax.Array.choose

abstract Array.choose(choices, out=None, mode='raise')

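Construct an array by choosing elements from multiple arrays. The entries of the calling array serve as integer indices that select, position by position, which array in choices each output element is taken from.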

Refer to jax.numpy.choose() for the full documentation.
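The short sketch below illustrates the selection rule on 1-D inputs: element i of the result is `choices[idx[i]][i]`. The arrays `idx` and `choices` are hypothetical illustration data. The sketch also checks the method against `jax.numpy.choose` and against an equivalent explicit gather.

```python
import jax.numpy as jnp

# Hypothetical example data: the calling array holds indices into `choices`.
idx = jnp.array([0, 2, 1, 0])
choices = [
    jnp.array([10, 11, 12, 13]),
    jnp.array([20, 21, 22, 23]),
    jnp.array([30, 31, 32, 33]),
]

# out[i] == choices[idx[i]][i]
out = idx.choose(choices)  # out == [10, 31, 22, 13]

# The method form matches the function form.
same_as_fn = jnp.choose(idx, choices)

# For 1-D inputs of equal length, this is the same as an explicit gather:
# stack the choices into rows, then take row idx[i] at column i.
manual = jnp.stack(choices)[idx, jnp.arange(idx.size)]
```

With higher-dimensional inputs, the index array and every entry of `choices` are broadcast against each other before the selection is made.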

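Parameters:

choices: sequence of arrays to choose from. The calling array supplies the integer indices, and it must be broadcast-compatible with every entry of choices.
out: not supported by JAX; leave as None.
mode: how out-of-bounds indices are handled. One of 'raise' (default), 'wrap' (indices are taken modulo len(choices)), or 'clip' (indices are clipped to the valid range). According to the jax.numpy.choose() documentation, 'raise' needs concrete index values and therefore cannot be used under jit() or other JAX transformations; use 'wrap' or 'clip' there.
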
Return type:

Array
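The following is a minimal sketch of how the mode argument changes the handling of out-of-range indices. It assumes two hypothetical choice arrays, so only indices 0 and 1 are valid. The jit-compiled call uses mode='clip' because, per the jax.numpy.choose() documentation, the default mode='raise' is not usable under transformations.

```python
import jax
import jax.numpy as jnp

# Hypothetical data: two choices, so valid indices are 0 and 1.
choices = [jnp.array([1, 2, 3]), jnp.array([10, 20, 30])]
idx = jnp.array([0, 2, 3])  # 2 and 3 are out of range

# 'wrap': indices are reduced modulo 2, so 2 -> 0 and 3 -> 1.
wrapped = idx.choose(choices, mode='wrap')  # [1, 2, 30]

# 'clip': indices are clipped to [0, 1], so 2 -> 1 and 3 -> 1.
clipped = idx.choose(choices, mode='clip')  # [1, 20, 30]

# Under jit, pick 'wrap' or 'clip' instead of the default 'raise'.
clip_fn = jax.jit(lambda i: i.choose(choices, mode='clip'))
jitted = clip_fn(idx)
```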