mpi4jax API Reference

Utilities

has_cuda_support

mpi4jax.has_cuda_support() → bool

Returns True if mpi4jax is built with CUDA support and can be used with JAX arrays that live on the GPU; returns False otherwise.

Communication primitives

allgather

mpi4jax.allgather(x, *, comm=None, token=None)

Perform an allgather operation: every process in comm sends x and receives the inputs of all processes.

Warning

x must have the same shape and dtype on all processes.

Parameters:

Returns:

Return type:

Tuple[DeviceArray, Token]
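To make the data movement concrete, here is a pure-NumPy model that simulates nproc ranks as a Python list, one entry per rank. It is not mpi4jax code and does not use MPI. The helper name `allgather_model` is hypothetical, and the sketch assumes that every rank's result stacks the inputs along a new leading axis of length nproc (the same layout that `gather` documents for its root).

```python
import numpy as np

def allgather_model(xs):
    """Hypothetical serial model of allgather; xs[r] is the input on rank r.

    Returns one output per rank. Every rank receives the same array:
    all inputs stacked along a new leading axis (assumed layout).
    In a real MPI program each rank would instead call, per the
    signature above:  out, token = mpi4jax.allgather(x, comm=comm)
    """
    # The warning above requires identical shape and dtype on all ranks.
    assert all(x.shape == xs[0].shape and x.dtype == xs[0].dtype for x in xs)
    stacked = np.stack(xs)  # shape (nproc, *x.shape)
    return [stacked.copy() for _ in xs]

# Three simulated ranks, each contributing a length-2 vector.
xs = [np.array([r, 10 * r]) for r in range(3)]
out = allgather_model(xs)
```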

allreduce

mpi4jax.allreduce(x, op, *, comm=None, token=None)

Perform an allreduce operation: combine x across all processes in comm using op, and return the result on every process.

Note

This primitive can be differentiated via jax.grad() and related functions if op is mpi4py.MPI.SUM.

Parameters:

Returns:

Return type:

Tuple[DeviceArray, Token]
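A serial NumPy sketch of the allreduce semantics, with each list entry standing for one rank's x. This is a model, not the library: the name `allreduce_model` is hypothetical, and NumPy ufuncs such as `np.add` and `np.maximum` stand in for `mpi4py.MPI.SUM` and `mpi4py.MPI.MAX`.

```python
import numpy as np
from functools import reduce

def allreduce_model(xs, op):
    """Hypothetical serial model of allreduce.

    Every simulated rank receives op folded over all inputs. `op` is a
    binary NumPy ufunc standing in for an MPI.Op. A real rank would call
    out, token = mpi4jax.allreduce(x, MPI.SUM, comm=comm)
    """
    result = reduce(op, xs)
    return [np.array(result) for _ in xs]  # one independent copy per rank

xs = [np.array([1.0, 2.0]) * (r + 1) for r in range(4)]
summed = allreduce_model(xs, np.add)        # stands in for MPI.SUM
largest = allreduce_model(xs, np.maximum)   # stands in for MPI.MAX
```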

alltoall

mpi4jax.alltoall(x, *, comm=None, token=None)

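Perform an alltoall operation: every process in comm sends a distinct block of x to each process and receives one block from each.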

Parameters:

Returns:

Return type:

Tuple[DeviceArray, Token]
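The block exchange can be modeled serially with NumPy, one list entry per rank. This is a sketch, not mpi4jax code: `alltoall_model` is a hypothetical name, and it assumes (following MPI_Alltoall) that x has a leading axis of length nproc, that block `x[j]` on rank i is delivered to rank j, and that it lands at index i of rank j's result. Under those assumptions the exchange is a swap of the first two axes of the stacked inputs.

```python
import numpy as np

def alltoall_model(xs):
    """Hypothetical serial model of alltoall; xs[i] has shape (nproc, ...).

    Rank j receives xs[i][j] from every rank i, stored at index i.
    A real rank would call: out, token = mpi4jax.alltoall(x, comm=comm)
    """
    nproc = len(xs)
    assert all(x.shape[0] == nproc for x in xs)
    stacked = np.stack(xs)                 # stacked[i, j]: block j sent by rank i
    swapped = np.swapaxes(stacked, 0, 1)   # swapped[j, i]: block rank j gets from rank i
    return [swapped[j].copy() for j in range(nproc)]

# Rank i sends the value 10 * i + j to rank j.
nproc = 3
xs = [np.array([10 * i + j for j in range(nproc)]) for i in range(nproc)]
out = alltoall_model(xs)
```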

barrier

mpi4jax.barrier(*, comm=None, token=None)

Perform a barrier operation: block until every process in comm has entered the barrier.

Parameters:

Returns:

Return type:

Token

bcast

mpi4jax.bcast(x, root, *, comm=None, token=None)

Perform a bcast (broadcast) operation: send x from the root process to every process in comm.

Warning

Unlike mpi4py’s bcast, this returns a new array with the received data.

Parameters:

Returns:

Return type:

Tuple[DeviceArray, Token]
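The warning above can be illustrated with a serial NumPy model in which each list entry stands for one rank's x. It is not mpi4jax code: `bcast_model` is hypothetical, and the assumption that non-root inputs only fix the shape and dtype of the result is labeled as such. The test checks the documented difference from mpi4py: the inputs are left untouched and new arrays are returned.

```python
import numpy as np

def bcast_model(xs, root):
    """Hypothetical serial model of bcast.

    Every rank receives a new copy of the root's input. Assumption:
    non-root inputs only determine the expected shape and dtype, and
    their values are discarded. A real rank would call
    out, token = mpi4jax.bcast(x, root, comm=comm)
    """
    src = xs[root]
    assert all(x.shape == src.shape and x.dtype == src.dtype for x in xs)
    return [src.copy() for _ in xs]

xs = [np.full(2, r, dtype=np.int32) for r in range(3)]
out = bcast_model(xs, root=2)
```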

gather

mpi4jax.gather(x, root, *, comm=None, token=None)

Perform a gather operation: collect x from every process in comm on the root process.

Warning

x must have the same shape and dtype on all processes.

Warning

The shape of the returned data varies between ranks. On the root process, it is (nproc, *input_shape). On all other processes, the output is identical to the input.

Parameters:

Returns:

Return type:

Tuple[DeviceArray, Token]
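The rank-dependent output shape described in the warning can be checked with a small serial NumPy model, one list entry per rank. This is a sketch, not the library implementation; the name `gather_model` is hypothetical.

```python
import numpy as np

def gather_model(xs, root):
    """Hypothetical serial model of gather, following the warning above.

    The root receives all inputs stacked to shape (nproc, *input_shape);
    every other rank gets back its own input unchanged. A real rank
    would call: out, token = mpi4jax.gather(x, root, comm=comm)
    """
    assert all(x.shape == xs[0].shape and x.dtype == xs[0].dtype for x in xs)
    return [np.stack(xs) if r == root else x.copy() for r, x in enumerate(xs)]

xs = [np.arange(2) + 10 * r for r in range(4)]
out = gather_model(xs, root=0)
```

Because the shapes differ, code that runs on all ranks should branch on the rank (for example `comm.Get_rank() == root` with an mpi4py communicator) before relying on the stacked layout.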

recv

mpi4jax.recv(x, source=-1, *, tag=-1, comm=None, status=None, token=None)

Perform a recv (receive) operation.

Warning

Unlike mpi4py’s recv, this returns a new array with the received data.

Parameters:

Returns:

Return type:

Tuple[DeviceArray, Token]

reduce

mpi4jax.reduce(x, op, root, *, comm=None, token=None)

Perform a reduce operation: combine x across all processes in comm using op, and deliver the result to the root process.

Parameters:

Returns:

Return type:

Tuple[DeviceArray, Token]
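A serial NumPy sketch of reduce, one list entry per rank, with a NumPy ufunc standing in for the MPI.Op. It is a model, not mpi4jax code; `reduce_model` is hypothetical. Since this page does not say what non-root ranks receive, the sketch deliberately returns None for them instead of guessing.

```python
import numpy as np
from functools import reduce

def reduce_model(xs, op, root):
    """Hypothetical serial model of reduce: only the root's result is modeled.

    A real rank would call: out, token = mpi4jax.reduce(x, MPI.SUM, root, comm=comm)
    """
    return [reduce(op, xs) if r == root else None for r in range(len(xs))]

xs = [np.array([r, -r]) for r in range(1, 4)]
out = reduce_model(xs, np.add, root=1)   # np.add stands in for MPI.SUM
```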

scan

mpi4jax.scan(x, op, *, comm=None, token=None)

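Perform a scan operation: an inclusive prefix reduction, so the process of rank i receives the result of applying op to the values of x on ranks 0 through i.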

Parameters:

Returns:

Return type:

Tuple[DeviceArray, Token]
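The inclusive-prefix semantics (as in MPI_Scan) can be modeled serially in NumPy, one list entry per rank, with `np.add` standing in for `MPI.SUM`. This is an illustration only; `scan_model` is a hypothetical name.

```python
import numpy as np
from functools import reduce

def scan_model(xs, op):
    """Hypothetical serial model of an inclusive scan.

    Rank i receives op folded over the inputs of ranks 0..i, in rank
    order. A real rank would call: out, token = mpi4jax.scan(x, MPI.SUM, comm=comm)
    """
    return [np.array(reduce(op, xs[: i + 1])) for i in range(len(xs))]

xs = [np.array([1, 1]) * (r + 1) for r in range(4)]
out = scan_model(xs, np.add)
```

With SUM, each rank ends up with a running total, which is the usual way to compute per-rank offsets into a global array.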

scatter

mpi4jax.scatter(x, root, *, comm=None, token=None)

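Perform a scatter operation: the root process splits its input along the first axis and sends one slice to each process in comm.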

Warning

Unlike mpi4py’s scatter, this returns a new array with the received data.

Warning

The expected shape of the first input varies between ranks. On the root process, it is (nproc, *input_shape). On all other processes, it is input_shape.

Parameters:

Returns:

Return type:

Tuple[DeviceArray, Token]
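The two shape rules in the warnings above can be expressed as a serial NumPy model, one list entry per rank. It is a sketch, not mpi4jax code: `scatter_model` is hypothetical, and it assumes that non-root inputs only fix the expected output shape.

```python
import numpy as np

def scatter_model(xs, root):
    """Hypothetical serial model of scatter.

    xs[root] has shape (nproc, *input_shape); every other xs[r] has shape
    input_shape (assumed to only fix the output shape). Rank r receives a
    new array equal to xs[root][r]. A real rank would call
    out, token = mpi4jax.scatter(x, root, comm=comm)
    """
    nproc = len(xs)
    src = xs[root]
    assert src.shape[0] == nproc
    assert all(xs[r].shape == src.shape[1:] for r in range(nproc) if r != root)
    return [src[r].copy() for r in range(nproc)]

nproc = 3
xs = [np.zeros(2) for _ in range(nproc)]
xs[0] = np.arange(2 * nproc, dtype=float).reshape(nproc, 2)  # root's input
out = scatter_model(xs, root=0)
```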

send

mpi4jax.send(x, dest, *, tag=0, comm=None, token=None)

Perform a send operation: send x to the process of rank dest in comm.

Parameters:

Returns:

A new, modified token, that depends on this operation.

Return type:

Token

sendrecv

mpi4jax.sendrecv(sendbuf, recvbuf, source, dest, *, sendtag=0, recvtag=-1, comm=None, status=None, token=None)

Perform a sendrecv operation: send sendbuf to rank dest and receive data from rank source in a single combined call.

Warning

Unlike mpi4py’s sendrecv, this returns a new array with the received data.

Parameters:

Returns:

Return type:

Tuple[DeviceArray, Token]
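A common use of sendrecv is a ring shift, where every rank passes data to its right-hand neighbor and receives from its left-hand neighbor in one call, avoiding the deadlock risk of separate blocking sends and receives. The serial NumPy model below simulates that pattern, one list entry per rank. It is not mpi4jax code; `ring_shift_model` is hypothetical, and the real-call comment assumes that recvbuf only needs the shape and dtype of the incoming data.

```python
import numpy as np

def ring_shift_model(xs):
    """Hypothetical serial model of a ring shift built on sendrecv.

    Each rank r sends to dest=(r + 1) % n and receives from
    source=(r - 1) % n, so it ends up with rank (r - 1)'s send buffer.
    A real rank might call (assuming x can serve as the recvbuf template):
    out, token = mpi4jax.sendrecv(x, x, (r - 1) % n, (r + 1) % n, comm=comm)
    """
    n = len(xs)
    return [xs[(r - 1) % n].copy() for r in range(n)]

xs = [np.array([float(r)]) for r in range(4)]
out = ring_shift_model(xs)
```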