jax.lax.convert_element_type (JAX documentation)

jax.lax.convert_element_type

jax.lax.convert_element_type(operand, new_dtype)

Elementwise cast.

This function lowers directly to the stablehlo.convert operation, which performs an elementwise conversion from one type to another, similar to a C++ static_cast.
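As a rough illustration of the `static_cast` analogy, the sketch below converts a few `float32` values to `int32` and back. The input values are arbitrary. The sketch assumes finite, in-range inputs and does not exercise NaN, infinity, or out-of-range values, whose results are not described on this page.

```python
import jax.numpy as jnp
from jax import lax

# Arbitrary finite float32 inputs that fit comfortably in int32.
x = jnp.array([1.7, -1.7, 2.5, -0.5], dtype=jnp.float32)

# Float -> int drops the fractional part (truncation toward zero),
# like static_cast<int32_t>(1.7f) == 1 in C++.
as_int = lax.convert_element_type(x, jnp.int32)

# Int -> float is exact for small integers like these.
back = lax.convert_element_type(as_int, jnp.float32)

print(as_int.tolist())  # [1, -1, 2, 0]
print(back.tolist())    # [1.0, -1.0, 2.0, 0.0]
```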

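Parameters:

operand – an array or scalar value to be cast.

new_dtype – a NumPy dtype representing the target type.

Returns: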

An array with the same shape as operand, cast elementwise to new_dtype.

Return type:

Array
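A minimal sketch of this return contract, using arbitrary example inputs. The result keeps the operand's shape and only the element type changes. A plain Python scalar operand produces a 0-d array.

```python
import jax
import jax.numpy as jnp
from jax import lax

# A 2x3 int32 array: shape is preserved, element type becomes float32.
m = jnp.arange(6, dtype=jnp.int32).reshape(2, 3)
m_f = lax.convert_element_type(m, jnp.float32)

# A Python scalar operand yields a 0-d array of the requested dtype.
s = lax.convert_element_type(3, jnp.float32)

print(m_f.shape, m_f.dtype)  # (2, 3) float32
print(s.shape, s.dtype)      # () float32
```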

Note

If new_dtype is a 64-bit type and x64 mode is not enabled, the appropriate 32-bit type will be used in its place.
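The sketch below illustrates this substitution. It assumes x64 mode is off, which is JAX's default. The `jax.config.update("jax_enable_x64", False)` call only makes that assumption explicit, so the result does not depend on the environment. With x64 off, requesting `float64` or `int64` yields `float32` or `int32` arrays.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Pin the default: x64 mode disabled.
jax.config.update("jax_enable_x64", False)

x = jnp.arange(3, dtype=jnp.int32)

# 64-bit targets are requested, but their 32-bit counterparts are used.
y = lax.convert_element_type(x, jnp.float64)
z = lax.convert_element_type(x, jnp.int64)

print(y.dtype)  # float32
print(z.dtype)  # int32
```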

If the input is a JAX array and the input dtype and output dtype match, then the input array will be returned unmodified.
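A small sketch of this pass-through behavior. It assumes the input is created with an explicit dtype so that it is not weakly typed, because a weakly typed input may not count as an exact dtype match. Object identity (`is`) checks whether a new array was created. A NumPy input is included for contrast, since it is not a JAX array.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

# Explicit dtype, so the array is strongly typed float32.
x = jnp.arange(4, dtype=jnp.float32)

same = lax.convert_element_type(x, jnp.float32)   # dtypes match
other = lax.convert_element_type(x, jnp.int32)    # dtypes differ

# A NumPy array with a matching dtype is still converted to a jax.Array.
n = np.arange(4, dtype=np.float32)
from_np = lax.convert_element_type(n, np.float32)

print(same is x)     # True: the input array itself is returned
print(other is x)    # False: a new array holds the converted values
print(from_np is n)  # False: NumPy input is not a JAX array
```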

See also