jax.Array.take — JAX documentation
jax.Array.take
abstract Array.take(indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)
Take elements from an array along the given axis (or from the flattened array when axis is None).
Refer to jax.numpy.take() for full documentation.
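The snippet below is a minimal sketch of how `axis` controls what gets indexed. The array values are arbitrary illustrative data, and the results follow the NumPy-compatible semantics documented for jax.numpy.take. Indices are passed as JAX arrays because JAX array functions generally reject plain Python lists.

```python
import jax.numpy as jnp

# Illustrative 2x3 input (values chosen arbitrarily).
x = jnp.array([[10, 20, 30],
               [40, 50, 60]])

# axis=None (the default): x is flattened to [10, 20, 30, 40, 50, 60] first.
flat = x.take(jnp.array([0, 4]))            # -> [10, 50]

# axis=0: pick whole rows, in the order given by the indices.
rows = x.take(jnp.array([1, 0]), axis=0)    # -> [[40, 50, 60], [10, 20, 30]]

# axis=1: pick columns. A 2-D index array yields an output of shape
# x.shape[:1] + indices.shape + x.shape[2:] = (2, 2, 2).
cols = x.take(jnp.array([[0, 2], [2, 1]]), axis=1)

# The method form is equivalent to calling jax.numpy.take directly.
same = jnp.take(x, jnp.array([[0, 2], [2, 1]]), axis=1)
```

Along the chosen axis, the indexed dimension is replaced by the shape of `indices`, so a 2-D index array adds a dimension to the result.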
Parameters:
- self (Array)
- indices (ArrayLike)
- axis (int | None)
- out (None)
- mode (str | None)
- unique_indices (bool)
- indices_are_sorted (bool)
- fill_value (StaticScalar | None)
Return type:
Array
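A minimal sketch of the out-of-bounds parameters follows, assuming the defaults documented for jax.numpy.take: `mode=None` behaves like `"fill"`, and the default `fill_value` for floating-point inputs is NaN. The input values are illustrative only. `unique_indices` and `indices_are_sorted` are promises about `indices` that may allow a faster gather on some backends. If the promise is false, the output is undefined, so the example sets them only for indices that actually satisfy both.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
idx = jnp.array([0, 2, 4])   # index 4 is out of bounds for length 3

# mode=None behaves like mode="fill": out-of-bounds positions receive a fill
# value, which defaults to NaN for floating-point inputs.
filled = x.take(idx)                        # -> [1., 3., nan]

# An explicit fill_value replaces the NaN default (used only in "fill" mode).
filled_neg = x.take(idx, fill_value=-1.0)   # -> [1., 3., -1.]

# mode="clip": out-of-bounds indices are clamped to the last valid index.
clipped = x.take(idx, mode="clip")          # -> [1., 3., 3.]

# Performance hints: safe here because [0, 1] really is sorted and has no duplicates.
hinted = x.take(jnp.array([0, 1]), unique_indices=True, indices_are_sorted=True)
```

Unlike NumPy, JAX does not raise on out-of-bounds indices inside traced code, so choosing between `"fill"` and `"clip"` is how you decide what those positions contain.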