jax.lax.convert_element_type - JAX documentation
jax.lax.convert_element_type(operand, new_dtype)
Elementwise cast.
This function lowers directly to the stablehlo.convert operation, which performs an elementwise conversion from one type to another, similar to a C++ static_cast.
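A minimal sketch of the static_cast-like behavior, assuming a default CPU install with x64 mode disabled. Converting float32 to int32 truncates toward zero, and lowering a jitted wrapper lets you check that the emitted StableHLO text contains a stablehlo.convert op. The helper name `to_int32` is hypothetical and exists only for this illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.7, -1.7, 3.0], dtype=jnp.float32)

# Float -> int conversion truncates toward zero, as static_cast<int> does in C++.
y = lax.convert_element_type(x, jnp.int32)
print(y)        # [ 1 -1  3]
print(y.dtype)  # int32

# Hypothetical helper, used only to inspect the lowered StableHLO module.
def to_int32(v):
    return lax.convert_element_type(v, jnp.int32)

hlo_text = jax.jit(to_int32).lower(x).as_text()
print("stablehlo.convert" in hlo_text)
```

Avoid relying on out-of-range float-to-int conversions (for example, NaN or values beyond the int32 range); this sketch only uses values that fit the target type.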
Parameters:
- operand (ArrayLike) – an array or scalar value to be cast.
- new_dtype (DTypeLike | dtypes.ExtendedDType) – a dtype-like object (e.g. a numpy.dtype, a scalar type, or a valid dtype name) representing the target dtype.
Returns:
An array with the same shape as operand, cast elementwise to new_dtype.
Return type:
Array
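The parameters and return value can be exercised with a short sketch, assuming default settings. It passes the same target dtype in the three dtype-like forms listed for new_dtype, checks that the output shape matches the input, and converts a plain Python scalar, which yields a 0-d array.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6).reshape(2, 3)  # integer array of shape (2, 3)

# Three dtype-like spellings of the same target dtype.
a = lax.convert_element_type(x, np.dtype("float32"))  # a numpy.dtype
b = lax.convert_element_type(x, jnp.float32)          # a scalar type
c = lax.convert_element_type(x, "float32")            # a dtype name

# The shape is preserved; only the element type changes.
print(a.shape, a.dtype)  # (2, 3) float32

# Scalar operands are accepted too and produce a 0-d array.
s = lax.convert_element_type(3.9, jnp.int32)
print(s.shape, s)  # () 3
```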
Note
If new_dtype is a 64-bit type and x64 mode (the jax_enable_x64 configuration flag) is not enabled, the corresponding 32-bit type is used in its place; for example, float64 becomes float32.
If the input is a JAX array and the input dtype and output dtype match, then the input array will be returned unmodified.
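Both behaviors in the note can be checked with a small sketch. It assumes that x64 mode is read through jax.config.jax_enable_x64 and that JAX may emit a warning when a 64-bit dtype is requested while x64 is off, so the warning is silenced here. The identity check relies only on the documented rule that a matching dtype returns the input JAX array unmodified.

```python
import warnings
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([1, 2, 3], dtype=jnp.int32)

# Requesting a 64-bit dtype: honored only when x64 mode is enabled.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # JAX may warn about the 64-bit request
    y = lax.convert_element_type(x, jnp.float64)
expected = "float64" if jax.config.jax_enable_x64 else "float32"
print(y.dtype)  # float32 under the default configuration

# Same dtype on a JAX array: the input object itself is returned.
z = lax.convert_element_type(x, jnp.int32)
print(z is x)  # True

# A NumPy input is not a JAX array, so a new jax.Array is produced.
n = np.array([1, 2, 3], dtype=np.int32)
m = lax.convert_element_type(n, np.int32)
print(isinstance(m, jax.Array))  # True
```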
See also
- jax.numpy.astype(): NumPy-style dtype casting API.
- jax.Array.astype(): dtype casting as an array method.
- jax.lax.bitcast_convert_type(): cast bits directly to a new dtype.
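To contrast the related APIs, here is a sketch that applies them to the same float32 input. convert_element_type, jax.Array.astype and jax.numpy.astype all convert values, while bitcast_convert_type reinterprets the underlying IEEE 754 bits (1.0f is 0x3F800000 and 2.5f is 0x40200000).

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.5], dtype=jnp.float32)

# Value conversion (truncates toward zero).
converted = lax.convert_element_type(x, jnp.int32)

# NumPy-style spellings of the same value conversion.
via_method = x.astype(jnp.int32)
via_jnp = jnp.astype(x, jnp.int32)

# Bit reinterpretation: the same 32 bits read as int32.
bits = lax.bitcast_convert_type(x, jnp.int32)

print(converted)  # [1 2]
print(bits)       # [1065353216 1075838976]
```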