jax.Array.take — JAX documentation

jax.Array.take

abstract Array.take(indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)

Take elements from an array.

Refer to jax.numpy.take() for full documentation.

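Parameters:

indices: integer array of indices to take.
axis: axis along which to take; None (the default) indexes into the flattened array.
out: not supported by JAX; must be None.
mode: out-of-bounds indexing mode, one of "fill", "clip", or "wrap".
unique_indices: whether the indices are known to be unique.
indices_are_sorted: whether the indices are known to be sorted.
fill_value: value returned for out-of-bounds indices when mode is "fill".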

Return type:

Array
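
Example:

The following is a minimal sketch of how the method might be used, not an excerpt from the official documentation. The array `x` and the index values are made up for illustration; the semantics of `axis`, `mode`, and `fill_value` are assumed to match those documented for jax.numpy.take().

```python
import jax.numpy as jnp

# Illustrative 2x3 array (values chosen only for this example).
x = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

# With axis=None (the default), indices refer to the flattened array.
flat = x.take(jnp.array([0, 4]))

# With axis=1, each index selects a whole column.
cols = x.take(jnp.array([2, 0]), axis=1)

# mode="clip" clamps an out-of-bounds index to the last valid position.
clipped = x.take(jnp.array([0, 10]), axis=1, mode="clip")

# mode="fill" substitutes fill_value for an out-of-bounds index.
filled = x.take(jnp.array([1, 10]), mode="fill", fill_value=-1.0)

print(flat)     # elements at flat positions 0 and 4
print(cols)     # columns 2 and 0, in that order
print(clipped)  # columns 0 and 2 (index 10 clamped to 2)
print(filled)   # element at flat position 1, then the fill value
```

Passing `mode` explicitly is useful when indices may fall outside the array, since out-of-bounds accesses do not raise an error in JAX.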