jax.numpy.load - JAX documentation

jax.numpy.load

jax.numpy.load(file, *args, **kwargs)

Load JAX arrays from npy files.

JAX wrapper of numpy.load().

This function is a simple wrapper of numpy.load(), but for .npy files created with numpy.save() or jax.numpy.save(), the output is returned as a jax.Array and bfloat16 data types are restored. For .npz files, the results are returned as ordinary NumPy arrays.
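A minimal sketch of the .npy/.npz distinction described above, using in-memory buffers. The buffer names and sample values are arbitrary. It assumes that NumPy, which has no native bfloat16, writes a bfloat16 array to a .npy file as raw 2-byte void records; that is why plain numpy.load() alone does not give bfloat16 back, while jnp.load() does.

```python
import io

import jax
import jax.numpy as jnp
import numpy as np

# --- .npy: bfloat16 survives a round trip through jnp.save / jnp.load ---
npy_buf = io.BytesIO()  # in-memory stand-in for a .npy file
x = jnp.array([1.5, -2.0, 3.25], dtype=jnp.bfloat16)
jnp.save(npy_buf, x)

npy_buf.seek(0)
raw = np.load(npy_buf)  # plain NumPy: sees untyped 2-byte void records (assumption)
npy_buf.seek(0)
restored = jnp.load(npy_buf)  # JAX wrapper: reinterprets the bytes as bfloat16

print(raw.dtype)
print(type(restored).__name__, restored.dtype)

# --- .npz: entries come back as ordinary NumPy arrays ---
npz_buf = io.BytesIO()
np.savez(npz_buf, a=np.arange(3, dtype=np.float32))
npz_buf.seek(0)
with jnp.load(npz_buf) as archive:  # behaves like the archive object from np.load
    a = archive["a"]

print(type(a).__name__, a.dtype)
```

Comparing `raw` with `restored` makes the purpose of the wrapper visible: the bytes on disk are identical, and only the dtype interpretation differs.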

This function requires concrete array inputs and is not compatible with transformations such as jax.jit() or jax.vmap().
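Since jnp.load() cannot be used inside jax.jit() or jax.vmap(), one workable pattern is to do all loading in ordinary Python and pass only the resulting arrays into the transformed functions. The sketch below follows that pattern; the helpers `save_to_buffer` and `normalize` are hypothetical names chosen for illustration.

```python
import io

import jax
import jax.numpy as jnp


def save_to_buffer(arr):
    """Hypothetical helper: serialize `arr` into an in-memory .npy buffer."""
    buf = io.BytesIO()
    jnp.save(buf, arr)
    buf.seek(0)  # rewind so the buffer can be read back from the start
    return buf


@jax.jit
def normalize(v):
    """Hypothetical pure array computation: safe to trace and compile."""
    return v / jnp.sum(v)


buffers = [
    save_to_buffer(jnp.array([1.0, 1.0, 2.0])),
    save_to_buffer(jnp.array([3.0, 1.0, 0.0])),
]

# File reading happens eagerly, outside any transformation.
batch = jnp.stack([jnp.load(b) for b in buffers])

# Only already-loaded arrays enter the jitted and vmapped functions.
single = normalize(batch[0])
batched = jax.vmap(normalize)(batch)
print(single)
print(batched)
```

Keeping I/O at the boundary means the transformed code only ever sees arrays, which is what jax.jit() and jax.vmap() are designed to trace.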

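Parameters:

file: a file-like object, string, or path to read from.

*args, **kwargs: additional arguments passed through to numpy.load().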

Returns:

the array stored in the file.

Return type:

Array
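Because the signature is load(file, *args, **kwargs) and the function wraps numpy.load(), extra arguments can be supplied as they would be to numpy.load(). The sketch below assumes that keyword arguments such as `allow_pickle` are forwarded unchanged, and reads from a string path in a temporary directory; the file name `weights.npy` is arbitrary.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "weights.npy")  # arbitrary file name
    jnp.save(path, jnp.arange(4, dtype=jnp.int32))

    # `file` may be a string path instead of a file-like object.
    # `allow_pickle=False` is a numpy.load() keyword, assumed to be passed through.
    loaded = jnp.load(path, allow_pickle=False)

# The data was copied into a jax.Array, so it outlives the temporary file.
print(type(loaded).__name__, loaded.dtype, loaded)
```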

Examples

```python
>>> import io
>>> import jax.numpy as jnp
>>> f = io.BytesIO()  # use an in-memory file-like object.
>>> x = jnp.array([2, 4, 6, 8], dtype='bfloat16')
>>> jnp.save(f, x)
>>> f.seek(0)
0
>>> jnp.load(f)
Array([2, 4, 6, 8], dtype=bfloat16)
```