jax.numpy.insert (JAX documentation)
jax.numpy.insert
jax.numpy.insert(arr, obj, values, axis=None)
Insert entries into an array at specified indices.
JAX implementation of numpy.insert().
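Because jnp.insert is meant to follow numpy.insert(), one way to check the correspondence is to run both on the same inputs. This sketch assumes NumPy is installed alongside JAX. The cases are arbitrary illustrations, not an exhaustive test:

```python
import numpy as np
import jax.numpy as jnp

x_np = np.arange(6).reshape(2, 3)
x_jnp = jnp.asarray(x_np)

# (obj, values, axis) combinations to compare; chosen only for illustration.
# Index and value arrays are NumPy arrays because jnp functions reject lists.
cases = [
    (1, 99, None),
    (slice(None, None, 2), -1, 1),
    (np.array([0, 2]), np.array([7, 8]), 1),
]

for obj, values, axis in cases:
    expected = np.insert(x_np, obj, values, axis=axis)
    result = jnp.insert(x_jnp, obj, values, axis=axis)
    # Compare values only: JAX defaults to int32 while NumPy may use int64.
    assert np.array_equal(np.asarray(result), expected), (obj, values, axis)
print("all cases match")
```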
Parameters:
- arr (ArrayLike) – array object into which values will be inserted.
- obj (ArrayLike | slice) – integer index, slice, or array of indices specifying the insertion locations. Indices refer to positions in the original arr.
- values (ArrayLike) – value or array of values to be inserted. A scalar is broadcast to every insertion location.
- axis (int | None) – the axis along which to insert, for multi-dimensional arrays. If unspecified (None), arr is flattened before insertion.
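The following sketch illustrates the axis parameter by contrasting the default flattening behaviour with insertion along an explicit axis. The 2x2 input is an arbitrary example:

```python
import jax.numpy as jnp

x = jnp.array([[1, 2],
               [3, 4]])

# With axis=None (the default) the input is flattened to [1, 2, 3, 4]
# before insertion, so the result is one-dimensional.
flat = jnp.insert(x, 1, 0)
print(flat.tolist())  # [1, 0, 2, 3, 4]

# With axis=0 a new row is inserted (the scalar 0 is broadcast across it)
# and the result stays two-dimensional.
rows = jnp.insert(x, 1, 0, axis=0)
print(rows.tolist())  # [[1, 2], [0, 0], [3, 4]]
```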
Returns:
A copy of arr with values inserted at the specified locations.
Return type:
Array
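Because a new array is returned, the input is left untouched (JAX arrays are immutable in any case). The following is a small check using arbitrary inputs:

```python
import jax.numpy as jnp

x = jnp.arange(3)
y = jnp.insert(x, 0, 10)

# x keeps its original contents and length; y is a new, longer array.
print(x.tolist())  # [0, 1, 2]
print(y.tolist())  # [10, 0, 1, 2]
```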
Examples
Inserting a single value:
>>> import jax.numpy as jnp
>>> x = jnp.arange(5)
>>> jnp.insert(x, 2, 99)
Array([ 0,  1, 99,  2,  3,  4], dtype=int32)
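Index handling follows NumPy: negative indices count from the end of the axis, and an index equal to the axis length appends. The following short sketch uses the same x, with values chosen only for illustration:

```python
import jax.numpy as jnp

x = jnp.arange(5)

# -1 refers to the last original position, so 99 lands before the last element.
before_last = jnp.insert(x, -1, 99)
print(before_last.tolist())  # [0, 1, 2, 3, 99, 4]

# An index equal to the axis length inserts after the last element.
appended = jnp.insert(x, len(x), 99)
print(appended.tolist())  # [0, 1, 2, 3, 4, 99]
```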
Inserting multiple identical values using a slice:
>>> jnp.insert(x, slice(None, None, 2), -1)
Array([-1,  0,  1, -1,  2,  3, -1,  4], dtype=int32)
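The slice is resolved against the length of the insertion axis in the original array. Pairing it with an array that has one entry per selected position inserts distinct values instead of repeating one. The inputs in this minimal sketch are illustrative:

```python
import jax.numpy as jnp

x = jnp.arange(5)

# slice(0, 2) selects original positions 0 and 1; each gets its own value.
out = jnp.insert(x, slice(0, 2), jnp.array([7, 8]))
print(out.tolist())  # [7, 0, 8, 1, 2, 3, 4]
```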
Inserting multiple values using an array of indices (each index refers to a position in the original array):
>>> indices = jnp.array([4, 2, 5])
>>> values = jnp.array([10, 11, 12])
>>> jnp.insert(x, indices, values)
Array([ 0,  1, 11,  2,  3, 10,  4, 12], dtype=int32)
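The result above follows a simple placement rule. Sort the indices stably, shift each one right by its rank to make room for the values inserted before it, write the values at the shifted positions, and fill the remaining slots with the original elements in order. The sketch below spells out that rule for the 1-D case in plain NumPy. The helper name insert_1d is hypothetical, and this illustrates the semantics rather than reproducing JAX's actual implementation:

```python
import numpy as np

def insert_1d(arr, indices, values):
    """Hypothetical reference for 1-D insertion with an index array."""
    arr = np.asarray(arr)
    indices = np.asarray(indices).copy()
    values = np.asarray(values, dtype=arr.dtype)
    n = len(indices)
    # Negative indices count from the end of the original array.
    indices[indices < 0] += len(arr)
    # Stable sort so equal indices keep the order given in `values`.
    order = np.argsort(indices, kind="stable")
    # The k-th smallest index moves right by k positions.
    positions = np.empty_like(indices)
    positions[order] = indices[order] + np.arange(n)
    out = np.empty(len(arr) + n, dtype=arr.dtype)
    mask = np.ones(len(out), dtype=bool)
    mask[positions] = False
    out[positions] = values
    out[mask] = arr  # original elements keep their relative order
    return out

print(insert_1d(np.arange(5), [4, 2, 5], [10, 11, 12]).tolist())
# [0, 1, 11, 2, 3, 10, 4, 12]
```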
Inserting columns into a 2D array:
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> indices = jnp.array([1, 3])
>>> values = jnp.array([[10, 11],
...                     [12, 13]])
>>> jnp.insert(x, indices, values, axis=1)
Array([[ 1, 10,  2,  3, 11],
       [ 4, 12,  5,  6, 13]], dtype=int32)
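The following sketch shows two further variations on the 2D case using illustrative inputs. The first broadcasts a scalar across a new row. The second uses a negative axis, where axis=-1 is the column axis of a 2D array. With a scalar index, a 1-D values array is laid out down the new column, one entry per row, as in NumPy:

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

# A scalar value is broadcast across the inserted row.
row_inserted = jnp.insert(x, 1, 0, axis=0)
print(row_inserted.tolist())  # [[1, 2, 3], [0, 0, 0], [4, 5, 6]]

# Scalar index 0 on the last axis: one value per row fills the new column.
col_inserted = jnp.insert(x, 0, jnp.array([7, 8]), axis=-1)
print(col_inserted.tolist())  # [[7, 1, 2, 3], [8, 4, 5, 6]]
```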