JAX documentation: jax.numpy.array

jax.numpy.array

jax.numpy.array(object, dtype=None, copy=True, order='K', ndmin=0, *, device=None)

Convert an object to a JAX array.

JAX implementation of numpy.array().

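Parameters:

object – an object convertible to an array, such as a JAX array, a NumPy array, a Python scalar, a Python collection (list or tuple), or an object supporting the Python buffer interface.

dtype – optional dtype of the output array. If not specified, it is inferred from the input.

copy – whether to force a copy of the input (default: True).

order – not implemented in JAX; leave it at its default.

ndmin – integer specifying the minimum number of dimensions of the output array (default: 0).

device – optional Device or Sharding to which the created array is committed.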
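A minimal sketch exercising the dtype, ndmin, and copy parameters from the signature above. It assumes the default JAX configuration (64-bit mode disabled); the variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

# dtype: override the inferred dtype (Python ints would otherwise give int32).
x = jnp.array([1, 2, 3], dtype=jnp.float32)
print(x.dtype)  # float32

# ndmin: size-1 axes are prepended until the result has at least ndmin dims.
y = jnp.array([1, 2, 3], ndmin=3)
print(y.shape)  # (1, 1, 3)

# copy=True (the default): the result does not alias the NumPy input,
# so mutating the NumPy array afterwards leaves the JAX array unchanged.
src = np.array([1.0, 2.0, 3.0], dtype=np.float32)
z = jnp.array(src, copy=True)
src[0] = 99.0
print(float(z[0]))  # 1.0
```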

Returns:

A JAX array constructed from the input.

Return type:

Array

See also

jax.numpy.asarray(): like jax.numpy.array(), but by default copies the input only when necessary.

Examples

Constructing JAX arrays from Python scalars:

```python
>>> import jax.numpy as jnp
>>> jnp.array(True)
Array(True, dtype=bool)
>>> jnp.array(42)
Array(42, dtype=int32, weak_type=True)
>>> jnp.array(3.5)
Array(3.5, dtype=float32, weak_type=True)
>>> jnp.array(1 + 1j)
Array(1.+1.j, dtype=complex64, weak_type=True)
```
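The weak_type=True in these reprs marks a weakly typed array: its dtype yields to the other operand in mixed-dtype arithmetic, following JAX's type-promotion rules. A minimal sketch of the practical difference, assuming the default configuration; passing an explicit dtype produces a strongly typed array.

```python
import jax.numpy as jnp

weak = jnp.array(3.5)                       # inferred float32, weakly typed
strong = jnp.array(3.5, dtype=jnp.float32)  # explicit dtype, strongly typed
b = jnp.ones(3, dtype=jnp.bfloat16)

# The weakly typed scalar defers to the bfloat16 operand ...
print((weak + b).dtype)    # bfloat16
# ... while the strongly typed float32 promotes the result to float32.
print((strong + b).dtype)  # float32
```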

Constructing JAX arrays from Python collections:

```python
>>> jnp.array([1, 2, 3])  # list of ints -> 1D array
Array([1, 2, 3], dtype=int32)
>>> jnp.array([(1, 2, 3), (4, 5, 6)])  # list of tuples of ints -> 2D array
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> jnp.array(range(5))
Array([0, 1, 2, 3, 4], dtype=int32)
```
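A few more collection inputs, sketched under the default configuration: mixed Python ints and floats promote to a float dtype, and nested sequences must be rectangular to form a higher-dimensional array.

```python
import jax.numpy as jnp

# Mixing Python ints and floats promotes the whole array to a float dtype.
mixed = jnp.array([1, 2.5, 3])
print(mixed.dtype)    # float32

# A rectangular nested list becomes a 2D array.
grid = jnp.array([[0, 1], [2, 3], [4, 5]])
print(grid.shape)     # (3, 2)

# A tuple behaves like a list; tolist() converts back to Python values.
print(jnp.array((7, 8, 9)).tolist())  # [7, 8, 9]
```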

Constructing JAX arrays from NumPy arrays:

```python
>>> import numpy as np
>>> jnp.array(np.linspace(0, 2, 5))
Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)
```
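The float32 result above comes from dtype canonicalization: NumPy's default float64 is stored as float32 unless 64-bit mode is enabled. A minimal sketch, assuming that default, which also shows that the JAX array holds its own copy and how to convert back with np.asarray.

```python
import numpy as np
import jax.numpy as jnp

host = np.linspace(0, 2, 5)   # NumPy defaults to float64
x_jax = jnp.array(host)       # stored as float32 unless x64 mode is enabled
print(host.dtype, x_jax.dtype)  # float64 float32

# Later edits to the NumPy array do not leak into the JAX array.
host[0] = -1.0
print(float(x_jax[0]))        # 0.0

# np.asarray converts the JAX array back to a NumPy ndarray.
back = np.asarray(x_jax)
print(type(back).__name__, back.dtype)  # ndarray float32
```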

Constructing a JAX array via the Python buffer interface, using Python’s built-in array module.

```python
>>> from array import array
>>> pybuffer = array('i', [2, 3, 5, 7])
>>> jnp.array(pybuffer)
Array([2, 3, 5, 7], dtype=int32)
```
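With the buffer interface, the element type of the buffer determines the result dtype. A small sketch with two other array typecodes, assuming a platform where C short is 16 bits and the default configuration (so float64 data is stored as float32).

```python
from array import array
import jax.numpy as jnp

# 'h' (C short) maps to int16; 'd' (C double) is float64, stored as float32.
shorts = jnp.array(array('h', [1, -2, 3]))
doubles = jnp.array(array('d', [0.5, 1.5]))
print(shorts.dtype, doubles.dtype)  # int16 float32
print(shorts.tolist())              # [1, -2, 3]
```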