jax.distributed module

initialize([coordinator_address, ...])    Initializes the JAX distributed system.
shutdown()                                Shuts down the distributed system.
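
The two functions are typically used as a pair: each process calls initialize once, before running any JAX computation, and calls shutdown when it is done. The following is a minimal sketch of that pattern, not a definitive setup. The MY_* environment variable names and the helper functions distributed_kwargs and run_distributed are hypothetical, introduced only for illustration. The sketch assumes a launcher has already started one process per host and exported those variables.

```python
from typing import Mapping


def distributed_kwargs(env: Mapping[str, str]) -> dict:
    """Build keyword arguments for jax.distributed.initialize.

    The MY_* variable names are hypothetical, chosen for this sketch only.
    """
    return {
        # host:port of the process acting as coordinator, e.g. "10.0.0.1:1234"
        "coordinator_address": env["MY_COORDINATOR_ADDR"],
        # total number of processes taking part in the job
        "num_processes": int(env["MY_NUM_PROCESSES"]),
        # index of this process, in the range [0, num_processes)
        "process_id": int(env["MY_PROCESS_ID"]),
    }


def run_distributed(env: Mapping[str, str]) -> None:
    """Initialize the distributed system, report topology, then shut down."""
    # Imported here so the helper above can be used without JAX installed.
    import jax

    jax.distributed.initialize(**distributed_kwargs(env))
    try:
        print(
            f"process {jax.process_index()} of {jax.process_count()}: "
            f"{jax.local_device_count()} local devices, "
            f"{jax.device_count()} global devices"
        )
    finally:
        # Release the connection to the coordinator even if the work above fails.
        jax.distributed.shutdown()
```

Each process would call run_distributed(os.environ) with its own process ID. On some managed cluster environments initialize may be able to detect these arguments automatically, in which case passing them explicitly is unnecessary.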