jax.numpy.astype — JAX documentation

jax.numpy.astype

jax.numpy.astype(x, dtype, /, *, copy=False, device=None)

Convert an array to a specified dtype.

JAX implementation of numpy.astype().

This is implemented via jax.lax.convert_element_type(), which may behave slightly differently from numpy.astype() in some cases. In particular, the details of float-to-int and int-to-float casts are implementation dependent.
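As a minimal sketch of the relationship described above, the following compares jnp.astype with a direct call to jax.lax.convert_element_type on a small in-range input. The variable names are illustrative only. The comparison deliberately sticks to finite, in-range values; out-of-range or non-finite float-to-int casts are the cases where results may be implementation dependent, so no output is claimed for them here.

```python
import jax
import jax.numpy as jnp

# Illustrative input: finite float32 values well inside the int32 range.
x = jnp.array([0.0, 1.5, 2.75], dtype=jnp.float32)

# Cast through the NumPy-style API.
via_astype = jnp.astype(x, jnp.int32)

# Cast through the lower-level primitive that the documentation
# names as the underlying implementation.
via_lax = jax.lax.convert_element_type(x, jnp.int32)

print(via_astype)
print(via_lax)
```

For inputs like these, both calls should produce the same int32 values, with fractional parts discarded.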

Parameters:

Returns:

An array with the same shape as x, containing values of the specified dtype.
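A short sketch of the return contract, assuming a small two-dimensional input chosen purely for illustration: the result keeps the shape of x and carries the requested dtype, while the input array itself is left unchanged.

```python
import jax.numpy as jnp

# Illustrative 2x3 integer input.
x = jnp.arange(6, dtype=jnp.int32).reshape(2, 3)

# Function form of the cast; x.astype(jnp.float32) is the method form.
y = jnp.astype(x, jnp.float32)

print(y.shape, y.dtype)  # shape matches x; dtype is the requested one
print(x.dtype)           # the input keeps its original dtype
```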

Return type:

Array

See also

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([0, 1, 2, 3])
>>> x
Array([0, 1, 2, 3], dtype=int32)
>>> x.astype('float32')
Array([0., 1., 2., 3.], dtype=float32)

>>> y = jnp.array([0.0, 0.5, 1.0])
>>> y.astype(int)  # truncates fractional values
Array([0, 0, 1], dtype=int32)