[mlir] Initial patch to add an MPI dialect by AntonLydike · Pull Request #68892 · llvm/llvm-project (original) (raw)

This patch introduces the new MPI dialect into MLIR. The Message Passing Interface (MPI) is a widely-used standard for distributed programs to exchange data. This PR goes together with a talk later at today's LLVM Dev Meeting.

This is just a first, small patch to get going and add the necessary base files, so that we can add more operations in further patches.

Here's the documentation as generated by ninja mlir-doc:

'mpi' Dialect

This dialect models the Message Passing Interface (MPI), version
4.0. It is meant to serve as an interfacing dialect that is targeted
by higher-level dialects. The MPI dialect itself can be lowered to
multiple MPI implementations and hide differences in ABI. The dialect
models the functions of the MPI specification as close to 1:1 as possible
while preserving SSA value semantics where it makes sense, and uses
memref types instead of bare pointers.

This dialect is under active development, and while stability is an
eventual goal, it is not guaranteed at this juncture. Given the early
state, it is recommended to inquire further prior to using this dialect.

For an in-depth documentation of the MPI library interface, please refer
to official documentation such as the
OpenMPI online documentation.

[TOC]

Operation definition

mpi.comm_rank (mpi::CommRankOp)

Get the current rank, equivalent to MPICommrank(MPICOMMWORLD, &rank)

Syntax:

operation ::= `mpi.comm_rank` attr-dict `:` type(results)

Communicators other than MPI_COMM_WORLD are not supported for now.

This operation can optionally return an !mpi.retval value that can be used
to check for errors.

Results:

Result Description
retval MPI function call return value
rank 32-bit signless integer

mpi.error_class (mpi::ErrorClassOp)

Get the error class from an error code, equivalent to the MPIErrorclass function

Syntax:

operation ::= `mpi.error_class` <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>v</mi><mi>a</mi><mi>l</mi><mi>a</mi><mi>t</mi><mi>t</mi><mi>r</mi><mo>−</mo><mi>d</mi><mi>i</mi><mi>c</mi><mi>t</mi><mi mathvariant="normal">‘</mi><mo>:</mo><mi mathvariant="normal">‘</mi><mi>t</mi><mi>y</mi><mi>p</mi><mi>e</mi><mo stretchy="false">(</mo></mrow><annotation encoding="application/x-tex">val attr-dict `:` type(</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.7778em;vertical-align:-0.0833em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">v</span><span class="mord mathnormal">a</span><span class="mord mathnormal" style="margin-right:0.01968em;">l</span><span class="mord mathnormal">a</span><span class="mord mathnormal">tt</span><span class="mord mathnormal" style="margin-right:0.02778em;">r</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal">d</span><span class="mord mathnormal">i</span><span class="mord mathnormal">c</span><span class="mord mathnormal">t</span><span class="mord">‘</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">:</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord">‘</span><span class="mord mathnormal">t</span><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="mord mathnormal">p</span><span class="mord mathnormal">e</span><span class="mopen">(</span></span></span></span>val)

MPI_Error_class maps return values from MPI calls to a set of well-known
MPI error classes.

Operands:

Operand Description
val MPI function call return value

Results:

Result Description
errclass MPI function call return value

mpi.finalize (mpi::FinalizeOp)

Finalize the MPI library, equivalent to MPIFinalize()

Syntax:

operation ::= `mpi.finalize` attr-dict (`:` type($retval)^)?

This function cleans up the MPI state. Afterwards, no MPI methods may
be invoked (excpet for MPI_Get_version, MPI_Initialized, and MPI_Finalized).
Notably, MPI_Init cannot be called again in the same program.

This operation can optionally return an !mpi.retval value that can be used
to check for errors.

Results:

Result Description
retval MPI function call return value

mpi.init (mpi::InitOp)

Initialize the MPI library, equivalent to MPIInit(NULL, NULL)

Syntax:

operation ::= `mpi.init` attr-dict (`:` type($retval)^)?

This operation must preceed most MPI calls (except for very few exceptions,
please consult with the MPI specification on these).

Passing &argc, &argv is not supported currently.

This operation can optionally return an !mpi.retval value that can be used
to check for errors.

Results:

Result Description
retval MPI function call return value

mpi.recv (mpi::RecvOp)

Equivalent to MPIRecv(ptr, size, dtype, dest, tag, MPICOMMWORLD, MPISTATUSIGNORE)

Syntax:

operation ::= `mpi.recv` `(` <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>r</mi><mi>e</mi><mi>f</mi><mi mathvariant="normal">‘</mi><mo separator="true">,</mo><mi mathvariant="normal">‘</mi></mrow><annotation encoding="application/x-tex">ref `,` </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal">re</span><span class="mord mathnormal" style="margin-right:0.10764em;">f</span><span class="mord">‘</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">‘</span></span></span></span>tag `,` <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>r</mi><mi>a</mi><mi>n</mi><mi>k</mi><mi mathvariant="normal">‘</mi><mo stretchy="false">)</mo><mi mathvariant="normal">‘</mi><mi>a</mi><mi>t</mi><mi>t</mi><mi>r</mi><mo>−</mo><mi>d</mi><mi>i</mi><mi>c</mi><mi>t</mi><mi mathvariant="normal">‘</mi><mo>:</mo><mi mathvariant="normal">‘</mi><mi>t</mi><mi>y</mi><mi>p</mi><mi>e</mi><mo stretchy="false">(</mo></mrow><annotation encoding="application/x-tex">rank `)` attr-dict `:` type(</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">r</span><span class="mord mathnormal" style="margin-right:0.03148em;">ank</span><span class="mord">‘</span><span class="mclose">)</span><span class="mord">‘</span><span class="mord mathnormal">a</span><span class="mord mathnormal">tt</span><span class="mord mathnormal" style="margin-right:0.02778em;">r</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal">d</span><span class="mord mathnormal">i</span><span class="mord mathnormal">c</span><span class="mord mathnormal">t</span><span class="mord">‘</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">:</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord">‘</span><span class="mord mathnormal">t</span><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="mord mathnormal">p</span><span class="mord mathnormal">e</span><span class="mopen">(</span></span></span></span>ref) `,` type($tag) `,` type($rank)(`->` type($retval)^)?

MPI_Recv performs a blocking receive of size elements of type dtype
from rank dest. The tag value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.

Communicators other than MPI_COMM_WORLD are not supprted for now.
The MPI_Status is set to MPI_STATUS_IGNORE, as the status object
is not yet ported to MLIR.

This operation can optionally return an !mpi.retval value that can be used
to check for errors.

Operands:

Operand Description
ref memref of any type values
tag 32-bit signless integer
rank 32-bit signless integer

Results:

Result Description
retval MPI function call return value

mpi.retval_check (mpi::RetvalCheckOp)

Check an MPI return value against an error class

Syntax:

operation ::= `mpi.retval_check` <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>v</mi><mi>a</mi><mi>l</mi><mi mathvariant="normal">‘</mi><mo>=</mo><mi mathvariant="normal">‘</mi></mrow><annotation encoding="application/x-tex">val `=` </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">v</span><span class="mord mathnormal">a</span><span class="mord mathnormal" style="margin-right:0.01968em;">l</span><span class="mord">‘</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord">‘</span></span></span></span>errclass attr-dict `:` type($res)

This operation compares MPI status codes to known error class
constants such as MPI_SUCCESS, or MPI_ERR_COMM.

Attributes:

Attribute MLIR Type Description
errclass ::mlir::mpi::MPI_ErrorClassEnumAttr MPI error class name{{% markdown %}}Enum cases: * MPI_SUCCESS (`MPI_SUCCESS`) * MPI_ERR_ACCESS (`MPI_ERR_ACCESS`) * MPI_ERR_AMODE (`MPI_ERR_AMODE`) * MPI_ERR_ARG (`MPI_ERR_ARG`) * MPI_ERR_ASSERT (`MPI_ERR_ASSERT`) * MPI_ERR_BAD_FILE (`MPI_ERR_BAD_FILE`) * MPI_ERR_BASE (`MPI_ERR_BASE`) * MPI_ERR_BUFFER (`MPI_ERR_BUFFER`) * MPI_ERR_COMM (`MPI_ERR_COMM`) * MPI_ERR_CONVERSION (`MPI_ERR_CONVERSION`) * MPI_ERR_COUNT (`MPI_ERR_COUNT`) * MPI_ERR_DIMS (`MPI_ERR_DIMS`) * MPI_ERR_DISP (`MPI_ERR_DISP`) * MPI_ERR_DUP_DATAREP (`MPI_ERR_DUP_DATAREP`) * MPI_ERR_ERRHANDLER (`MPI_ERR_ERRHANDLER`) * MPI_ERR_FILE (`MPI_ERR_FILE`) * MPI_ERR_FILE_EXISTS (`MPI_ERR_FILE_EXISTS`) * MPI_ERR_FILE_IN_USE (`MPI_ERR_FILE_IN_USE`) * MPI_ERR_GROUP (`MPI_ERR_GROUP`) * MPI_ERR_INFO (`MPI_ERR_INFO`) * MPI_ERR_INFO_KEY (`MPI_ERR_INFO_KEY`) * MPI_ERR_INFO_NOKEY (`MPI_ERR_INFO_NOKEY`) * MPI_ERR_INFO_VALUE (`MPI_ERR_INFO_VALUE`) * MPI_ERR_IN_STATUS (`MPI_ERR_IN_STATUS`) * MPI_ERR_INTERN (`MPI_ERR_INTERN`) * MPI_ERR_IO (`MPI_ERR_IO`) * MPI_ERR_KEYVAL (`MPI_ERR_KEYVAL`) * MPI_ERR_LOCKTYPE (`MPI_ERR_LOCKTYPE`) * MPI_ERR_NAME (`MPI_ERR_NAME`) * MPI_ERR_NO_MEM (`MPI_ERR_NO_MEM`) * MPI_ERR_NO_SPACE (`MPI_ERR_NO_SPACE`) * MPI_ERR_NO_SUCH_FILE (`MPI_ERR_NO_SUCH_FILE`) * MPI_ERR_NOT_SAME (`MPI_ERR_NOT_SAME`) * MPI_ERR_OP (`MPI_ERR_OP`) * MPI_ERR_OTHER (`MPI_ERR_OTHER`) * MPI_ERR_PENDING (`MPI_ERR_PENDING`) * MPI_ERR_PORT (`MPI_ERR_PORT`) * MPI_ERR_PROC_ABORTED (`MPI_ERR_PROC_ABORTED`) * MPI_ERR_QUOTA (`MPI_ERR_QUOTA`) * MPI_ERR_RANK (`MPI_ERR_RANK`) * MPI_ERR_READ_ONLY (`MPI_ERR_READ_ONLY`) * MPI_ERR_REQUEST (`MPI_ERR_REQUEST`) * MPI_ERR_RMA_ATTACH (`MPI_ERR_RMA_ATTACH`) * MPI_ERR_RMA_CONFLICT (`MPI_ERR_RMA_CONFLICT`) * MPI_ERR_RMA_FLAVOR (`MPI_ERR_RMA_FLAVOR`) * MPI_ERR_RMA_RANGE (`MPI_ERR_RMA_RANGE`) * MPI_ERR_RMA_SHARED (`MPI_ERR_RMA_SHARED`) * MPI_ERR_RMA_SYNC (`MPI_ERR_RMA_SYNC`) * MPI_ERR_ROOT (`MPI_ERR_ROOT`) * MPI_ERR_SERVICE (`MPI_ERR_SERVICE`) * MPI_ERR_SESSION (`MPI_ERR_SESSION`) * MPI_ERR_SIZE (`MPI_ERR_SIZE`) * MPI_ERR_SPAWN (`MPI_ERR_SPAWN`) * MPI_ERR_TAG (`MPI_ERR_TAG`) * MPI_ERR_TOPOLOGY (`MPI_ERR_TOPOLOGY`) * MPI_ERR_TRUNCATE (`MPI_ERR_TRUNCATE`) * MPI_ERR_TYPE (`MPI_ERR_TYPE`) * MPI_ERR_UNKNOWN (`MPI_ERR_UNKNOWN`) * MPI_ERR_UNSUPPORTED_DATAREP (`MPI_ERR_UNSUPPORTED_DATAREP`) * MPI_ERR_UNSUPPORTED_OPERATION (`MPI_ERR_UNSUPPORTED_OPERATION`) * MPI_ERR_VALUE_TOO_LARGE (`MPI_ERR_VALUE_TOO_LARGE`) * MPI_ERR_WIN (`MPI_ERR_WIN`) * MPI_ERR_LASTCODE (`MPI_ERR_LASTCODE`){{% /markdown %}}

Operands:

Operand Description
val MPI function call return value

Results:

Result Description
res 1-bit signless integer

mpi.send (mpi::SendOp)

Equivalent to MPISend(ptr, size, dtype, dest, tag, MPICOMMWORLD)

Syntax:

operation ::= `mpi.send` `(` <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>r</mi><mi>e</mi><mi>f</mi><mi mathvariant="normal">‘</mi><mo separator="true">,</mo><mi mathvariant="normal">‘</mi></mrow><annotation encoding="application/x-tex">ref `,` </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal">re</span><span class="mord mathnormal" style="margin-right:0.10764em;">f</span><span class="mord">‘</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">‘</span></span></span></span>tag `,` <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>r</mi><mi>a</mi><mi>n</mi><mi>k</mi><mi mathvariant="normal">‘</mi><mo stretchy="false">)</mo><mi mathvariant="normal">‘</mi><mi>a</mi><mi>t</mi><mi>t</mi><mi>r</mi><mo>−</mo><mi>d</mi><mi>i</mi><mi>c</mi><mi>t</mi><mi mathvariant="normal">‘</mi><mo>:</mo><mi mathvariant="normal">‘</mi><mi>t</mi><mi>y</mi><mi>p</mi><mi>e</mi><mo stretchy="false">(</mo></mrow><annotation encoding="application/x-tex">rank `)` attr-dict `:` type(</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">r</span><span class="mord mathnormal" style="margin-right:0.03148em;">ank</span><span class="mord">‘</span><span class="mclose">)</span><span class="mord">‘</span><span class="mord mathnormal">a</span><span class="mord mathnormal">tt</span><span class="mord mathnormal" style="margin-right:0.02778em;">r</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal">d</span><span class="mord mathnormal">i</span><span class="mord mathnormal">c</span><span class="mord mathnormal">t</span><span class="mord">‘</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">:</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord">‘</span><span class="mord mathnormal">t</span><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="mord mathnormal">p</span><span class="mord mathnormal">e</span><span class="mopen">(</span></span></span></span>ref) `,` type($tag) `,` type($rank)(`->` type($retval)^)?

MPI_Send performs a blocking send of size elements of type dtype to rank
dest. The tag value and communicator enables the library to determine
the matching of multiple sends and receives between the same ranks.

Communicators other than MPI_COMM_WORLD are not supprted for now.

This operation can optionally return an !mpi.retval value that can be used
to check for errors.

Operands:

Operand Description
ref memref of any type values
tag 32-bit signless integer
rank 32-bit signless integer

Results:

Result Description
retval MPI function call return value

Attribute definition

MPI_ErrorClassEnumAttr

MPI error class name

Syntax:

#mpi.errclass<
  ::mlir::mpi::MPI_ErrorClassEnum   # value
>

Enum cases:

Parameters:

Parameter C++ type Description
value ::mlir::mpi::MPI_ErrorClassEnum an enum of type MPI_ErrorClassEnum

Type definition

RetvalType

MPI function call return value

Syntax: !mpi.retval

This type represents a return value from an MPI function vall.
This value can be MPI_SUCCESS, MPI_ERR_IN_STATUS, or any error code.

This return value can be compared agains the known MPI error classes
represented by #mpi.errclass using the mpi.retval_check operation.