API Reference: mpi4jax documentation
Utilities
has_cuda_support
mpi4jax.has_cuda_support() → bool
Returns True if mpi4jax was built with CUDA support and can be used with GPU-based JAX arrays, and False otherwise.
Communication primitives
allgather
mpi4jax.allgather(x, *, comm=None, token=None)
Perform an allgather operation.
Warning
x must have the same shape and dtype on all processes.
Parameters:
- x – Array or scalar input to send.
- comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
- token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
Returns:
- Received data.
- A new, modified token, that depends on this operation.
Return type:
Tuple[DeviceArray, Token]
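To make the returned shape concrete, here is a single-process NumPy model of allgather. It is not mpi4jax code: a Python list stands in for the array held by each simulated rank, and the helper `simulate_allgather` is hypothetical. The model assumes, as with MPI_Allgather, that contributions are stacked along a new leading axis in rank order, so every rank receives an array of shape (nproc, *x.shape).

```python
import numpy as np


def simulate_allgather(per_rank):
    """Hypothetical model of allgather; per_rank[r] is the array held by rank r.

    Returns the list of results each rank would receive: every rank gets the
    same array of shape (nproc, *x.shape), ordered by rank.
    """
    stacked = np.stack([np.asarray(x) for x in per_rank])
    return [stacked.copy() for _ in per_rank]


# Three simulated ranks, each contributing a length-2 vector [r, 10 * r].
inputs = [np.array([r, 10 * r]) for r in range(3)]
results = simulate_allgather(inputs)
print(results[0].shape)  # (3, 2)
```

Because every rank must contribute the same shape and dtype, the stacked result is a regular array; this is why the warning above matters.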
allreduce
mpi4jax.allreduce(x, op, *, comm=None, token=None)
Perform an allreduce operation.
Note
This primitive can be differentiated via jax.grad() and related functions if op is mpi4py.MPI.SUM.
Parameters:
- x – Array or scalar input.
- op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).
- comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
- token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
Returns:
- Result of the allreduce operation.
- A new, modified token, that depends on this operation.
Return type:
Tuple[DeviceArray, Token]
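The reduction semantics can be sketched without MPI as a single-process NumPy model. This is not mpi4jax code; the helper `simulate_allreduce` is hypothetical, and a NumPy binary ufunc stands in for the mpi4py.MPI.Op (np.add for MPI.SUM, np.maximum for MPI.MAX). Every simulated rank receives the same elementwise-reduced array.

```python
import functools

import numpy as np


def simulate_allreduce(per_rank, op=np.add):
    """Hypothetical model of allreduce: combine all ranks' arrays elementwise.

    `op` stands in for an mpi4py.MPI.Op (np.add ~ MPI.SUM, np.maximum ~ MPI.MAX).
    Every rank receives the same reduced array.
    """
    reduced = functools.reduce(op, [np.asarray(x) for x in per_rank])
    return [reduced.copy() for _ in per_rank]


# Four simulated ranks holding [1.0, r].
inputs = [np.array([1.0, r]) for r in range(4)]
sums = simulate_allreduce(inputs)                # each rank: [4.0, 6.0]
maxima = simulate_allreduce(inputs, np.maximum)  # each rank: [1.0, 3.0]
```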
alltoall
mpi4jax.alltoall(x, *, comm=None, token=None)
Perform an alltoall operation.
Parameters:
- x – Array input to send. Its first axis must have size nproc (the number of processes in comm).
- comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
- token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
Returns:
- Received data.
- A new, modified token, that depends on this operation.
Return type:
Tuple[DeviceArray, Token]
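The block exchange follows MPI_Alltoall: rank i sends block x[j] to rank j, which stores it as block i of its result. The single-process NumPy model below (not mpi4jax code; `simulate_alltoall` is a hypothetical helper) shows that this amounts to swapping the simulated rank axis with the first data axis.

```python
import numpy as np


def simulate_alltoall(per_rank):
    """Hypothetical model of alltoall; per_rank[i] has shape (nproc, ...).

    Rank i sends block per_rank[i][j] to rank j, so rank j receives
    out[j][i] == per_rank[i][j].
    """
    nproc = len(per_rank)
    stacked = np.stack([np.asarray(x) for x in per_rank])  # (nproc, nproc, ...)
    if stacked.shape[1] != nproc:
        raise ValueError("first axis of x must have size nproc")
    swapped = np.swapaxes(stacked, 0, 1)
    return [swapped[j] for j in range(nproc)]


# With 3 ranks, rank i holds blocks [[10*i + 0], [10*i + 1], [10*i + 2]].
inputs = [np.array([[10 * i + j] for j in range(3)]) for i in range(3)]
outputs = simulate_alltoall(inputs)
print(outputs[1].ravel().tolist())  # [1, 11, 21]
```

Each output keeps the input's shape, (nproc, *block_shape).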
barrier
mpi4jax.barrier(*, comm=None, token=None)
Perform a barrier operation.
Parameters:
- comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
- token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
Returns:
- A new, modified token, that depends on this operation.
Return type:
Token
bcast
mpi4jax.bcast(x, root, *, comm=None, token=None)
Perform a bcast (broadcast) operation.
Warning
Unlike mpi4py’s bcast, this returns a new array with the received data.
Parameters:
- x – Array or scalar input. Data is only read on the root process. On non-root processes, this is used to determine the shape and dtype of the result.
- root (int) – Rank of the root MPI process, which acts as the source.
- comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
- token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
Returns:
- Received data.
- A new, modified token, that depends on this operation.
Return type:
Tuple[DeviceArray, Token]
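The difference from mpi4py's in-place bcast can be illustrated with a single-process NumPy model (not mpi4jax code; `simulate_bcast` is a hypothetical helper). Each simulated rank gets a new copy of the root's data, and non-root values are never read; in mpi4jax they only fix the expected shape and dtype.

```python
import numpy as np


def simulate_bcast(per_rank, root):
    """Hypothetical model of bcast: every rank receives a new copy of the
    root's array. Values held by non-root ranks are ignored."""
    src = np.asarray(per_rank[root])
    return [src.copy() for _ in per_rank]


# Four simulated ranks; rank r holds three copies of r.
inputs = [np.full(3, r, dtype=np.float32) for r in range(4)]
received = simulate_bcast(inputs, root=2)  # every rank: [2.0, 2.0, 2.0]
```

The inputs themselves are left untouched, mirroring the functional style of JAX, where arrays are immutable.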
gather
mpi4jax.gather(x, root, *, comm=None, token=None)
Perform a gather operation.
Warning
x must have the same shape and dtype on all processes.
Warning
The shape of the returned data varies between ranks. On the root process, it is (nproc, *input_shape). On all other processes, the output is identical to the input.
Parameters:
- x – Array or scalar input to send.
- root (int) – Rank of the root MPI process.
- comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
- token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
Returns:
- Received data on the root process; the unmodified input on all other processes.
- A new, modified token, that depends on this operation.
Return type:
Tuple[DeviceArray, Token]
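The rank-dependent output shape described in the warning can be checked with a single-process NumPy model. This is not mpi4jax code; `simulate_gather` is a hypothetical helper that follows the documented behavior: the root receives the stacked contributions, and every other rank gets its own input back.

```python
import numpy as np


def simulate_gather(per_rank, root):
    """Hypothetical model of gather as documented: the root receives all
    contributions stacked into shape (nproc, *x.shape); every other rank
    gets its own input back unchanged."""
    stacked = np.stack([np.asarray(x) for x in per_rank])
    return [stacked if r == root else np.asarray(x) for r, x in enumerate(per_rank)]


# Three simulated ranks; rank r holds [2r, 2r + 1].
inputs = [np.arange(2) + 2 * r for r in range(3)]
outputs = simulate_gather(inputs, root=0)
print(outputs[0].shape, outputs[1].shape)  # (3, 2) (2,)
```

Code that runs on all ranks after a gather should therefore branch on the rank before relying on the result's shape.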
recv
mpi4jax.recv(x, source=-1, *, tag=-1, comm=None, status=None, token=None)
Perform a recv (receive) operation.
Warning
Unlike mpi4py’s recv, this returns a new array with the received data.
Parameters:
- x – Array or scalar input with the correct shape and dtype. This can contain arbitrary data and will not be overwritten.
- source (int) – Rank of the source MPI process.
- tag (int) – Tag of this message.
- comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
- status (mpi4py.MPI.Status) – Status object that can be used for introspection.
- token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
Returns:
- Received data.
- A new, modified token, that depends on this operation.
Return type:
Tuple[DeviceArray, Token]
reduce
mpi4jax.reduce(x, op, root, *, comm=None, token=None)
Perform a reduce operation.
Parameters:
- x – Array or scalar input to send.
- op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).
- root (int) – Rank of the root MPI process.
- comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
- token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
Returns:
- Result of the reduce operation on the root process; the unmodified input on all other processes.
- A new, modified token, that depends on this operation.
Return type:
Tuple[DeviceArray, Token]
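Compared with allreduce, only the root receives the reduced value. The single-process NumPy model below (not mpi4jax code; `simulate_reduce` is a hypothetical helper, and np.add stands in for MPI.SUM) follows the documented return values, including the unmodified input on non-root ranks.

```python
import functools

import numpy as np


def simulate_reduce(per_rank, root, op=np.add):
    """Hypothetical model of reduce as documented: the root receives the
    elementwise reduction; other ranks get their own input back."""
    reduced = functools.reduce(op, [np.asarray(x) for x in per_rank])
    return [reduced if r == root else np.asarray(x) for r, x in enumerate(per_rank)]


# Three simulated ranks holding [r, 1].
inputs = [np.array([r, 1]) for r in range(3)]
outputs = simulate_reduce(inputs, root=1)
# outputs[1] -> [3, 3]; outputs[0] -> [0, 1]; outputs[2] -> [2, 1]
```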
scan
mpi4jax.scan(x, op, *, comm=None, token=None)
Perform a scan operation.
Parameters:
- x – Array or scalar input to send.
- op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).
- comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
- token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
Returns:
- Result of the scan operation.
- A new, modified token, that depends on this operation.
Return type:
Tuple[DeviceArray, Token]
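Assuming scan follows MPI_Scan, it is an inclusive prefix reduction: rank r receives op applied over the inputs of ranks 0 through r. The single-process NumPy model below is not mpi4jax code; `simulate_scan` is a hypothetical helper, and np.add stands in for MPI.SUM.

```python
import itertools

import numpy as np


def simulate_scan(per_rank, op=np.add):
    """Hypothetical model of an inclusive scan: rank r receives
    op(x_0, x_1, ..., x_r), applied elementwise."""
    return list(itertools.accumulate([np.asarray(x) for x in per_rank], op))


# Four simulated ranks holding [1, r].
inputs = [np.array([1, r]) for r in range(4)]
prefix = simulate_scan(inputs)
print([p.tolist() for p in prefix])  # [[1, 0], [2, 1], [3, 3], [4, 6]]
```

The first component, all ones, shows the inclusive behavior: rank r's count is r + 1, because its own input is included.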
scatter
mpi4jax.scatter(x, root, *, comm=None, token=None)
Perform a scatter operation.
Warning
Unlike mpi4py’s scatter, this returns a new array with the received data.
Warning
The expected shape of the input x varies between ranks. On the root process, it is (nproc, *input_shape). On all other processes, it is input_shape.
Parameters:
- x – Array or scalar input with the correct shape and dtype. On the root process, this contains the data to send, and its first axis must have size nproc. On non-root processes, this may contain arbitrary data and will not be overwritten.
- root (int) – Rank of the root MPI process.
- comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
- token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
Returns:
- Received data.
- A new, modified token, that depends on this operation.
Return type:
Tuple[DeviceArray, Token]
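Scatter is the inverse of the root-side view of gather: the root's (nproc, *input_shape) array is split along its first axis, and rank r receives block r. The single-process NumPy model below is not mpi4jax code; `simulate_scatter` is a hypothetical helper, and non-root inputs are used only as placeholders, as described above.

```python
import numpy as np


def simulate_scatter(per_rank, root):
    """Hypothetical model of scatter: the root's input has shape
    (nproc, *input_shape) and rank r receives block r. Non-root inputs
    are placeholders whose values are ignored."""
    src = np.asarray(per_rank[root])
    if src.shape[0] != len(per_rank):
        raise ValueError("first axis of the root's x must have size nproc")
    return [src[r].copy() for r in range(len(per_rank))]


nproc = 3
inputs = [np.zeros(2) for _ in range(nproc)]   # placeholders of input_shape (2,)
inputs[0] = np.arange(6.0).reshape(nproc, 2)   # root 0 holds shape (nproc, 2)
outputs = simulate_scatter(inputs, root=0)
print([o.tolist() for o in outputs])  # [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]
```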
send
mpi4jax.send(x, dest, *, tag=0, comm=None, token=None)
Perform a send operation.
Parameters:
- x – Array or scalar input to send.
- dest (int) – Rank of the destination MPI process.
- tag (int) – Tag of this message.
- comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
- token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
Returns:
A new, modified token, that depends on this operation.
Return type:
Token
sendrecv
mpi4jax.sendrecv(sendbuf, recvbuf, source, dest, *, sendtag=0, recvtag=-1, comm=None, status=None, token=None)
Perform a sendrecv operation.
Warning
Unlike mpi4py’s sendrecv, this returns a new array with the received data.
Parameters:
- sendbuf – Array or scalar input to send.
- recvbuf – Array or scalar input with the correct shape and dtype. This can contain arbitrary data and will not be overwritten.
- source (int) – Rank of the source MPI process.
- dest (int) – Rank of the destination MPI process.
- sendtag (int) – Tag of this message for sending.
- recvtag (int) – Tag of this message for receiving.
- comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).
- status (mpi4py.MPI.Status) – Status object that can be used for introspection.
- token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.
Returns:
- Received data.
- A new, modified token, that depends on this operation.
Return type:
Tuple[DeviceArray, Token]
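A common use of sendrecv is a ring shift, in which each rank passes dest=(rank + 1) % nproc and source=(rank - 1) % nproc. The single-process NumPy model below is not mpi4jax code; `simulate_ring_shift` is a hypothetical helper showing what each rank would end up holding under that choice of source and dest.

```python
import numpy as np


def simulate_ring_shift(per_rank):
    """Hypothetical model of a sendrecv ring shift: rank r sends its array
    to (r + 1) % nproc and receives from (r - 1) % nproc, so its result is
    the array held by rank r - 1 (wrapping around)."""
    nproc = len(per_rank)
    return [np.asarray(per_rank[(r - 1) % nproc]).copy() for r in range(nproc)]


inputs = [np.array([r]) for r in range(4)]
shifted = simulate_ring_shift(inputs)
print([s.tolist() for s in shifted])  # [[3], [0], [1], [2]]
```

Combining the send and the receive in one call lets MPI pair them internally, which avoids the deadlock that can occur when every rank in a ring first issues a blocking send and then a receive.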