jax.local_devices (JAX documentation)
jax.local_devices
jax.local_devices(process_index=None, backend=None, host_id=None)
Like jax.devices(), but only returns devices local to a given process.
If process_index is None, returns devices local to this process.
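A minimal sketch of the default call, assuming a standard JAX install. It compares the local devices with jax.devices() and uses jax.process_index(), jax.local_device_count() and the Device.process_index attribute, none of which are described on this page. In a single-process run every device is local, so both lists contain the same devices.

```python
import jax

# Devices local to the calling process (process_index=None).
local = jax.local_devices()
# Devices across every process in the job.
all_devices = jax.devices()

# Each local device reports the calling process's index.
assert all(d.process_index == jax.process_index() for d in local)
# jax.local_device_count() counts the same set of devices.
assert len(local) == jax.local_device_count()
print(f"{len(local)} local of {len(all_devices)} total devices")
```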
Parameters:
- process_index (int | None) – the integer index of the process. Valid indices are 0 through jax.process_count() - 1, that is, range(jax.process_count()).
- backend (str | xla_client.Client | None) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
- host_id (int | None) – deprecated alias for process_index.
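A sketch of both parameters, assuming only that the CPU backend is present (it ships with every jaxlib build). It enumerates valid process indices with range(jax.process_count()) and filters by backend; the Device.platform attribute used in the check is not documented on this page.

```python
import jax

# Process indices run from 0 to jax.process_count() - 1.
by_process = {
    i: jax.local_devices(process_index=i) for i in range(jax.process_count())
}
# Every device in the job belongs to exactly one process.
assert sum(len(v) for v in by_process.values()) == len(jax.devices())

# Restrict the query to a single backend by name.
cpu_devices = jax.local_devices(backend="cpu")
print([d.platform for d in cpu_devices])
```

Passing host_id instead of process_index still works but is best avoided in new code, since it is only a deprecated alias.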
Returns:
List of Device instances.
Return type:
list[xla_client.Device]
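A short sketch of consuming the returned list, assuming the common Device attributes id, platform, device_kind and process_index (not listed on this page). It commits an array to the first local device with jax.device_put, which is one typical use of these objects.

```python
import jax
import jax.numpy as jnp

devices = jax.local_devices()
for d in devices:
    # Each element is a Device object carrying identifying attributes.
    print(d.id, d.platform, d.device_kind, d.process_index)

# A returned Device can be passed wherever JAX expects a device,
# e.g. to commit an array to it with jax.device_put.
x = jax.device_put(jnp.arange(4), devices[0])
print(x.devices())
```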