jax.numpy.arange - JAX documentation

jax.numpy.arange

jax.numpy.arange(start, stop=None, step=None, dtype=None, *, device=None, out_sharding=None)

Create an array of evenly-spaced values.

JAX implementation of numpy.arange(), implemented in terms of jax.lax.iota().

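Similar to Python’s range() function, this can be called with a few different positional signatures:

- jnp.arange(stop): generate values from 0 to stop, stepping by 1.
- jnp.arange(start, stop): generate values from start to stop, stepping by 1.
- jnp.arange(start, stop, step): generate values from start to stop, stepping by step.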

As with Python’s range() function, the start value is inclusive and the stop value is exclusive.
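The half-open interval convention can be checked directly against Python's built-in range(). The following is a minimal sketch, assuming JAX is installed and imported as jnp; the variable names are illustrative only.

```python
import jax.numpy as jnp

# jnp.arange uses the same half-open interval [start, stop) as range():
# start=2 is included in the result, stop=7 is not.
values = jnp.arange(2, 7)
as_list = values.tolist()  # plain Python list, for easy comparison

print(as_list)                         # [2, 3, 4, 5, 6]
print(as_list == list(range(2, 7)))    # True
```

Converting to a list is only for the comparison; in normal use the result stays a JAX array.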

Parameters:

Returns:

Array of evenly-spaced values from start to stop, separated by step.

Return type:

Array

Note

Using arange with a floating-point step argument can lead to unexpected results due to accumulation of floating-point errors, especially with lower-precision data types like float8_* and bfloat16. To avoid precision errors, consider generating a range of integers, and scaling it to the desired range. For example, instead of this:

jnp.arange(-1, 1, 0.01, dtype='bfloat16')

it can be more accurate to generate a sequence of integers, and scale them:

(jnp.arange(-100, 100) * 0.01).astype('bfloat16')
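One way to see how close the integer-scaling approach stays to the intended values is to compare it with a higher-precision reference. The sketch below assumes NumPy is available alongside JAX and uses a float64 NumPy array purely as a yardstick; the names scaled, reference, and max_error are illustrative and not part of any API.

```python
import numpy as np
import jax.numpy as jnp

# Integer range scaled into [-1, 1), then cast to the low-precision dtype.
scaled = (jnp.arange(-100, 100) * 0.01).astype('bfloat16')

# Float64 reference built with NumPy, used only to measure the error.
reference = np.arange(-100, 100) * 0.01

# Upcast to float32 inside JAX (exact for bfloat16 values), then compare.
scaled_f64 = np.asarray(scaled.astype('float32'), dtype=np.float64)
max_error = float(np.max(np.abs(scaled_f64 - reference)))

print(scaled.shape, scaled.dtype)
print(max_error)
```

Because each element is computed independently from an exact integer, the error of each value is bounded by a single rounding to bfloat16 rather than growing along the sequence.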

Examples

Single-argument version specifies only the stop value:

>>> jnp.arange(4)
Array([0, 1, 2, 3], dtype=int32)

Passing a floating-point stop value leads to a floating-point result:

>>> jnp.arange(4.0)
Array([0., 1., 2., 3.], dtype=float32)

Two-argument version specifies start and stop, with step=1:

>>> jnp.arange(1, 6)
Array([1, 2, 3, 4, 5], dtype=int32)

Three-argument version specifies start, stop, and step:

>>> jnp.arange(0, 2, 0.5)
Array([0. , 0.5, 1. , 1.5], dtype=float32)
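The lengths of these results follow the rule documented for numpy.arange: the output has ceil((stop - start) / step) elements. The sketch below checks that rule for a few inputs, on the assumption that jnp.arange matches NumPy here; expected_length is a hypothetical helper written for this illustration, not a JAX function.

```python
import math
import jax.numpy as jnp

def expected_length(start, stop, step=1):
    """Hypothetical helper: element count per NumPy's documented arange rule."""
    return max(0, math.ceil((stop - start) / step))

# (start, stop, step) triples, including a negative step that counts down.
cases = [(0, 4, 1), (1, 6, 1), (0, 2, 0.5), (5, 0, -1)]

actual_lengths = [int(jnp.arange(a, b, s).shape[0]) for a, b, s in cases]
predicted_lengths = [expected_length(a, b, s) for a, b, s in cases]

print(actual_lengths)
print(predicted_lengths)
```

The cases are chosen so that (stop - start) / step is exact in floating point; with steps such as 0.1 the computed length can be affected by rounding, as described in the note above.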

See also