jax.local_devices — JAX documentation

jax.local_devices

jax.local_devices(process_index=None, backend=None, host_id=None)

Like jax.devices(), but only returns devices local to a given process.

If process_index is None, returns devices local to this process.

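Parameters:

process_index (int | None): the integer index of the process whose local devices are returned. Valid indices run from 0 to jax.process_count() - 1.

backend (str | xla_client.Client | None): optional; a string naming the XLA backend, such as 'cpu', 'gpu', or 'tpu'. This is an experimental feature and the API is likely to change.

host_id (int | None): deprecated. Use process_index instead.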

Returns:

List of Device subclasses.

Return type:

list[xla_client.Device]
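
Example:

The following is a minimal sketch of how the default call relates to jax.devices() and jax.process_index(). It assumes a working JAX installation with at least one device (the CPU backend is enough); the variable names are illustrative only.

```python
import jax

# Devices attached to the current process (process_index=None is the default).
local = jax.local_devices()

# Devices across all processes. In a single-process run this is the same set.
all_devices = jax.devices()

this_process = jax.process_index()
for d in local:
    # Each device records the index of the process it belongs to.
    print(d.id, d.platform, d.process_index == this_process)

# Passing the current process index explicitly selects the same devices.
explicit = jax.local_devices(process_index=this_process)
```

In a multi-process setup, jax.devices() also includes devices owned by other processes, while jax.local_devices() stays limited to the devices this process can address directly.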