mpi4jax documentation
mpi4jax enables zero-copy, multi-host communication of JAX arrays, even from jitted code and from GPU memory.
But why?
The JAX framework has great performance for scientific computing workloads, but its multi-host capabilities are still limited.
With mpi4jax, you can scale your JAX-based simulations to entire CPU and GPU clusters (without ever leaving jax.jit).
In the spirit of differentiable programming, mpi4jax also supports differentiating through some MPI operations.
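As a rough single-process illustration of what differentiating through a collective means, the sketch below does not use mpi4jax or MPI at all. Instead it assumes four hypothetical "ranks" emulated with jax.vmap, and uses JAX's built-in named-axis collective jax.lax.psum as a stand-in for a sum-allreduce. jax.grad then differentiates the globally reduced value with respect to every rank's local input. The names N_RANKS, per_rank_energy and global_energy are illustrative only; see the mpi4jax documentation for which of its operations support differentiation.

```python
import jax
import jax.numpy as jnp

# Hypothetical number of emulated ranks; no MPI processes are involved here.
N_RANKS = 4


def per_rank_energy(x):
    # Each emulated rank computes a local contribution ...
    local = jnp.sum(x ** 2)
    # ... and a sum over the named axis makes the global total available on
    # every rank, playing the role of an allreduce with op=MPI.SUM.
    return jax.lax.psum(local, axis_name="ranks")


def global_energy(xs):
    # xs has shape (N_RANKS, 3): one slice of data per emulated rank.
    per_rank = jax.vmap(per_rank_energy, axis_name="ranks")(xs)
    # Every rank holds the same reduced value, so rank 0's copy suffices.
    return per_rank[0]


xs = jnp.arange(N_RANKS * 3, dtype=jnp.float32).reshape(N_RANKS, 3)
energy = global_energy(xs)
# The derivative of sum(x**2) is 2 * x, and it reaches every rank's input.
grads = jax.grad(global_energy)(xs)
print(energy)
print(grads)
```

On a real cluster each rank would be a separate MPI process and the reduction would happen over the network; the emulation only shows that the gradient of a sum-reduced quantity flows back to every participant's local data.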
Installation
mpi4jax is available through pip and conda:
$ pip install mpi4jax                      # pip
$ conda install -c conda-forge mpi4jax     # conda
Depending on the JAX backend you want to use, you can install mpi4jax in one of the following ways:

JAX installed with pip install 'jax[cpu]':
$ pip install mpi4jax

JAX installed with pip install -U 'jax[cuda12]':
$ pip install cython
$ pip install mpi4jax --no-build-isolation

JAX installed with pip install -U 'jax[cuda12_local]' (replace XXX with the path to your local CUDA installation):
$ CUDA_ROOT=XXX pip install mpi4jax

(For more information on JAX GPU distributions, see the JAX installation instructions.)
In case your MPI installation is not detected correctly, it can help to install mpi4py separately. When using a pre-installed mpi4py, you must use --no-build-isolation when installing mpi4jax:

# if mpi4py is already installed
$ pip install cython
$ pip install mpi4jax --no-build-isolation
Our documentation includes some more advanced installation examples.
Example usage
from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

@jax.jit
def foo(arr):
    arr = arr + rank
    # sum arr over all MPI processes (the second return value is unused here)
    arr_sum, _ = mpi4jax.allreduce(arr, op=MPI.SUM, comm=comm)
    return arr_sum

a = jnp.zeros((3, 3))
result = foo(a)

if rank == 0:
    print(result)
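Each process adds its rank (0 to 3) to an array of zeros before the sum-allreduce, so every entry of the result is 0 + 1 + 2 + 3 = 6. Running this script on 4 processes gives: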
$ mpirun -n 4 python example.py
[[6. 6. 6.]
 [6. 6. 6.]
 [6. 6. 6.]]
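To check this arithmetic without launching MPI, the following single-process sketch emulates the four ranks with jax.vmap and uses jax.lax.psum over a named axis as a stand-in for mpi4jax.allreduce(op=MPI.SUM). It is an analogue for illustration only, not how mpi4jax works internally; the axis name "ranks" and the constant N_RANKS are hypothetical.

```python
import jax
import jax.numpy as jnp

N_RANKS = 4  # mirrors `mpirun -n 4`; here the ranks are emulated, not real processes


def foo(arr, rank):
    arr = arr + rank
    # Sum over the emulated ranks: a single-process stand-in for
    # mpi4jax.allreduce(arr, op=MPI.SUM, comm=comm).
    return jax.lax.psum(arr, axis_name="ranks")


a = jnp.zeros((3, 3))
ranks = jnp.arange(N_RANKS)
# vmap maps foo over the rank index (arr is shared), and jit compiles the whole thing.
results = jax.jit(jax.vmap(foo, in_axes=(None, 0), axis_name="ranks"))(a, ranks)

# Every emulated rank ends up with the same 3x3 array of 0 + 1 + 2 + 3 = 6.
print(results[0])
```

Changing N_RANKS changes the result accordingly (for n ranks every entry becomes n(n-1)/2), which matches what mpirun with the same number of processes would produce for the original script.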
allreduce is just one example of the MPI primitives you can use. See the API Reference for all supported operations.
How to cite
If you use mpi4jax in your work, please consider citing the following article:
@article{mpi4jax,
  doi = {10.21105/joss.03419},
  url = {https://doi.org/10.21105/joss.03419},
  year = {2021},
  publisher = {The Open Journal},
  volume = {6},
  number = {65},
  pages = {3419},
  author = {Dion Häfner and Filippo Vicentini},
  title = {mpi4jax: Zero-copy MPI communication of JAX arrays},
  journal = {Journal of Open Source Software}
}
Contents
- Installation
- Usage examples
- Demo application: Shallow-water model
- 🔪 The Sharp Bits 🔪
- API Reference
- Developer guide