jax.Array.astype — JAX documentation
jax.Array.astype
abstract Array.astype(dtype, copy=False, device=None)
Copy the array and cast to a specified dtype.
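A minimal sketch of the basic call follows. The input values and the variable names `x` and `y` are illustrative only. It shows that the result carries the requested dtype while the original array keeps its own dtype, because JAX arrays are immutable and `astype` returns a new array.

```python
import jax.numpy as jnp

# Illustrative int32 input; JAX arrays are immutable.
x = jnp.arange(4, dtype=jnp.int32)

# astype returns a new array with the requested dtype.
y = x.astype(jnp.float32)

print(x.dtype, y.dtype)  # int32 float32
print(y)                 # [0. 1. 2. 3.]
```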
This is implemented via jax.lax.convert_element_type(), which may behave slightly differently from numpy.ndarray.astype() in some cases. In particular, the details of float-to-int and int-to-float casts are implementation-dependent.
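The sketch below compares `astype` with a direct call to `jax.lax.convert_element_type()` and with NumPy on a float-to-int cast. It deliberately restricts itself to finite, in-range inputs (the sample values are arbitrary), where truncation toward zero is expected on all three paths. Out-of-range or non-finite values such as NaN or inf are exactly the cases the text calls implementation-dependent, so they are not exercised or asserted here.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Arbitrary finite, in-range sample values.
vals = [1.7, -1.7, 2.5, -0.5]
x = jnp.array(vals, dtype=jnp.float32)

# Array.astype goes through lax.convert_element_type, so these should agree.
a = x.astype(jnp.int32)
b = jax.lax.convert_element_type(x, jnp.int32)

# NumPy reference: for in-range finite floats, the cast truncates toward zero.
ref = np.array(vals, dtype=np.float32).astype(np.int32)

# Note: NaN, inf, or values outside the int32 range may give results that
# differ across backends and from NumPy; they are intentionally left out.
```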
Parameters:
dtype – the dtype to cast the array to.
copy – defaults to False.
device – defaults to None.
Return type:
Array
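A sketch exercising each parameter in the signature, under a few stated assumptions: that `dtype` accepts the usual dtype-like spellings (a string, a NumPy scalar type, or a `jax.numpy` scalar type), that `copy=True` leaves the values unchanged, and that passing a device places the result there. The first entry of `jax.devices()` is used only as a convenient device that always exists; all variable names are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.array([1, 2, 3], dtype=jnp.int32)

# dtype: several dtype-like spellings for the same target type.
h1 = x.astype("float16")
h2 = x.astype(np.float16)
h3 = x.astype(jnp.float16)

# copy: requesting a copy does not change the values.
c = x.astype(jnp.int32, copy=True)

# device: place the result on a chosen device (here, the first available one).
dev = jax.devices()[0]
d = x.astype(jnp.float32, device=dev)
```

Asking for the same dtype with `copy=True` is a simple way to see that the call still returns an array of that dtype with identical contents.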