jax.numpy.astype: JAX documentation
jax.numpy.astype
jax.numpy.astype(x, dtype, /, *, copy=False, device=None)
Convert an array to a specified dtype.
JAX implementation of numpy.astype().
This is implemented via jax.lax.convert_element_type(), which may behave slightly differently from numpy.astype() in some cases. In particular, the details of float-to-int and int-to-float casts are implementation-dependent.
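A minimal sketch of an in-range float-to-int cast, assuming the default 32-bit mode (jax_enable_x64 disabled). For finite values that fit in the target type, both JAX and NumPy truncate toward zero, so negative fractions round up. Out-of-range values and NaN are where the implementation-dependent differences show up, so the sketch deliberately avoids them.

```python
import jax.numpy as jnp
import numpy as np

# Only in-range, finite values: these truncate toward zero in both libraries.
vals = [-1.7, -0.5, 0.5, 1.7]

jax_result = jnp.astype(jnp.array(vals, dtype=jnp.float32), jnp.int32)
np_result = np.asarray(vals, dtype=np.float32).astype(np.int32)

print(jax_result.tolist())  # [-1, 0, 0, 1]
print(np_result.tolist())   # [-1, 0, 0, 1]
```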
Parameters:
- x (ArrayLike) – input array to convert
- dtype (DTypeLike | None) – output dtype
- copy (bool) – if True, always return a copy. If False (default), return a copy only if necessary.
- device (xc.Device | Sharding | None) – optionally specify the device to which the output will be committed.
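A short sketch of the copy and device keywords, assuming a JAX version whose jnp.astype accepts device (as in the signature above) and using whichever device jax.devices() lists first. The variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(4, dtype=jnp.int32)

# copy=True requests a copy even though x already has the target dtype.
x_copy = jnp.astype(x, jnp.int32, copy=True)

# device= commits the converted output to an explicit device.
dev = jax.devices()[0]
x_f32 = jnp.astype(x, jnp.float32, device=dev)

print(x_copy.dtype, x_f32.dtype)  # int32 float32
print(dev in x_f32.devices())     # True
```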
Returns:
An array with the same shape as x, containing values of the specified dtype.
Return type:
Array
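To illustrate the return contract, the sketch below converts a 2-D array and checks that only the element type changes. The choice of float16 as the target dtype is arbitrary.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((2, 3), dtype=jnp.int32)
y = jnp.astype(x, jnp.float16)

# The shape is preserved; only the dtype differs.
print(y.shape, y.dtype)          # (2, 3) float16
print(isinstance(y, jax.Array))  # True
```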
See also
- jax.lax.convert_element_type(): lower-level function for XLA-style dtype conversions.
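Because jnp.astype is built on jax.lax.convert_element_type(), the two should agree for ordinary in-range inputs. A minimal comparison, assuming default 32-bit mode and sticking to finite values that fit in int32:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([0.25, 1.5, 2.75], dtype=jnp.float32)

via_astype = jnp.astype(x, jnp.int32)
via_lax = lax.convert_element_type(x, jnp.int32)

# Both paths truncate toward zero for these in-range values.
print(via_astype.tolist(), via_lax.tolist())  # [0, 1, 2] [0, 1, 2]
```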
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([0, 1, 2, 3])
>>> x
Array([0, 1, 2, 3], dtype=int32)
>>> x.astype('float32')
Array([0.0, 1.0, 2.0, 3.0], dtype=float32)

>>> y = jnp.array([0.0, 0.5, 1.0])
>>> y.astype(int)  # truncates fractional values
Array([0, 0, 1], dtype=int32)
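Casting to and from bool follows the usual NumPy convention: nonzero values become True, and True/False map back to 1/0. A small sketch using the functional form jnp.astype:

```python
import jax.numpy as jnp

x = jnp.array([0, 2, -1, 0])

# Nonzero -> True, zero -> False.
mask = jnp.astype(x, bool)
print(mask.tolist())  # [False, True, True, False]

# Back to integers: True -> 1, False -> 0.
back = jnp.astype(mask, jnp.int32)
print(back.tolist())  # [0, 1, 1, 0]
```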