jax.numpy.rollaxis (JAX documentation)
jax.numpy.rollaxis
jax.numpy.rollaxis(a, axis, start=0)
Roll the specified axis to a given position.
JAX implementation of numpy.rollaxis().
This function exists for compatibility with NumPy, but in most cases the newer jax.numpy.moveaxis() should be used instead, because the meaning of its arguments is more intuitive.
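Because moveaxis() is the recommended replacement, it can help to see how a rollaxis() call translates into one. The sketch below uses a hypothetical helper, rollaxis_as_moveaxis (not part of JAX), and assumes in-range axis and start values. It applies the rule given for the start parameter: after negative values are normalized, a start that lies past the axis refers to a position counted before the axis was removed, so the moveaxis destination is start - 1; otherwise it is start itself. This is an illustration of the semantics, not necessarily how JAX implements rollaxis.

```python
import jax.numpy as jnp


def rollaxis_as_moveaxis(a, axis, start=0):
    """Hypothetical helper: express jnp.rollaxis(a, axis, start) via jnp.moveaxis."""
    ndim = a.ndim
    # Normalize a negative axis (assumes -ndim <= axis < ndim; out-of-range
    # values would silently wrap here instead of raising).
    axis = axis % ndim
    # Normalize a negative start (assumes -ndim <= start <= ndim).
    if start < 0:
        start += ndim
    # Removing the axis shifts every later position left by one, so a start
    # past the axis means "the position before start".
    destination = start - 1 if start > axis else start
    return jnp.moveaxis(a, axis, destination)


a = jnp.ones((2, 3, 4, 5))
print(rollaxis_as_moveaxis(a, 2).shape)          # (4, 2, 3, 5)
print(rollaxis_as_moveaxis(a, 1, a.ndim).shape)  # (2, 4, 5, 3)
```

Writing the destination out explicitly is exactly the bookkeeping that moveaxis() spares the caller, which is why its arguments read more naturally.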
Parameters:
- a (ArrayLike) – input array.
- axis (int) – index of the axis to roll forward.
- start (int) – index toward which the axis will be rolled (default = 0). After normalizing negative axes, if start <= axis, the axis is rolled to the start index; if start > axis, the axis is rolled until the position before start.
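The two cases of the start rule are easiest to see side by side. A short sketch using the same 4-D shape as the examples on this page; each comment names the case being exercised:

```python
import jax.numpy as jnp

a = jnp.ones((2, 3, 4, 5))

# start <= axis: axis 3 (length 5) lands exactly at index start = 1.
print(jnp.rollaxis(a, 3, 1).shape)   # (2, 5, 3, 4)

# start > axis: axis 1 (length 3) lands at the position before start, index 2.
print(jnp.rollaxis(a, 1, 3).shape)   # (2, 4, 3, 5)

# Negative values are normalized first: axis=-1 becomes 3, start=-1 becomes 3.
print(jnp.rollaxis(a, -1, 1).shape)  # (2, 5, 3, 4)
print(jnp.rollaxis(a, 1, -1).shape)  # (2, 4, 3, 5)

# start == axis and start == axis + 1 both leave the axis where it is.
print(jnp.rollaxis(a, 1, 1).shape)   # (2, 3, 4, 5)
print(jnp.rollaxis(a, 1, 2).shape)   # (2, 3, 4, 5)
```

The last pair shows why the "position before start" wording matters: two different start values can describe the same result.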
Returns:
Copy of a with rolled axis.
Return type:
Array
Notes
Unlike numpy.rollaxis(), jax.numpy.rollaxis() returns a copy rather than a view of the input array. However, under JIT the compiler will optimize away such copies when possible, so this has no performance impact in practice.
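The copy-versus-view difference can be observed from Python, although it rarely matters in JAX code. The sketch below contrasts NumPy, where the result shares memory with the input, with JAX, where arrays are immutable so an "update" of the rolled result never touches the original. The final lines only show rollaxis traced inside a jitted function; they do not measure whether a copy was actually elided, which is up to the compiler.

```python
import numpy as np
import jax
import jax.numpy as jnp

x_np = np.arange(24).reshape(2, 3, 4)
# NumPy's rollaxis returns a transposed view that shares memory with its input.
print(np.shares_memory(x_np, np.rollaxis(x_np, 2)))  # True

x = jnp.arange(24).reshape(2, 3, 4)
y = jnp.rollaxis(x, 2)
# JAX arrays are immutable: .at[...].set() builds a new array, x is unchanged.
updated = y.at[0, 0, 0].set(-1)
print(int(x[0, 0, 0]), int(updated[0, 0, 0]))  # 0 -1

# Inside jit, the rollaxis becomes part of the compiled computation.
f = jax.jit(lambda v: jnp.rollaxis(v, 2).sum(axis=0))
print(f(x).shape)  # (2, 3)
```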
See also
- jax.numpy.moveaxis(): newer API with clearer semantics than rollaxis; this should be preferred to rollaxis in most cases.
- jax.numpy.swapaxes(): swap two axes.
- jax.numpy.transpose(): general permutation of axes.
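To relate rollaxis to the other functions listed here, the following sketch checks a roll against an explicit transpose permutation and contrasts it with swapaxes. The arrays are arbitrary illustrative inputs.

```python
import jax.numpy as jnp

a = jnp.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
# Rolling axis 2 to the front is the axis permutation (2, 0, 1, 3).
assert jnp.array_equal(jnp.rollaxis(a, 2), jnp.transpose(a, (2, 0, 1, 3)))

b = jnp.arange(24).reshape(2, 3, 4)
# With more than two axes a roll is not a swap: the remaining axes keep their order.
print(jnp.rollaxis(b, 2).shape)     # (4, 2, 3)
print(jnp.swapaxes(b, 0, 2).shape)  # (4, 3, 2)

m = jnp.arange(6).reshape(2, 3)
# For a 2-D array, the only nontrivial roll is the same as swapping the two axes.
assert jnp.array_equal(jnp.rollaxis(m, 1), jnp.swapaxes(m, 0, 1))
```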
Examples
>>> import jax.numpy as jnp
>>> a = jnp.ones((2, 3, 4, 5))
Roll axis 2 to the start of the array:
>>> jnp.rollaxis(a, 2).shape
(4, 2, 3, 5)
Roll axis 1 to the end of the array:
>>> jnp.rollaxis(a, 1, a.ndim).shape
(2, 4, 5, 3)
The equivalent of these two calls with moveaxis():
>>> jnp.moveaxis(a, 2, 0).shape
(4, 2, 3, 5)
>>> jnp.moveaxis(a, 1, -1).shape
(2, 4, 5, 3)
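The same calls can also run inside a jitted function. This is a minimal sketch that assumes, as the int annotations on axis and start suggest, that both must be concrete Python integers at trace time, so they are marked static. The wrapper name roll is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical wrapper: axis and start are static (plain Python ints at trace
# time); only the array argument is traced.
@partial(jax.jit, static_argnames=("axis", "start"))
def roll(a, axis, start):
    return jnp.rollaxis(a, axis, start)


a = jnp.ones((2, 3, 4, 5))
print(roll(a, axis=2, start=0).shape)       # (4, 2, 3, 5)
print(roll(a, axis=1, start=a.ndim).shape)  # (2, 4, 5, 3)
```

Each distinct (axis, start) pair triggers a separate compilation, which is the usual cost of static arguments.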