[RFC] Add GPU operations to permute data in 2 loaded mma_matrix

What

Hi, we are the team working on an MLIR-based compiler for Arm GPUs. We would like to propose several GPU operations that construct a new subgroup matrix from already-loaded subgroup matrices. We have prototyped these operations, and they help us reduce memory accesses in Conv2D computations.

Motivation

Cooperative matrices are a well-known representation for speeding up matrix multiplication on GPUs. They are called “cooperative” because the matrix is stored across a subgroup and the multiplication is performed by all threads in the subgroup. SPIR-V has an extension that defines the cooperative matrix data type[1]. NVIDIA and Intel GPUs already support this representation and the related operations[2][3]. In MLIR, the gpu dialect has a set of “subgroup_mma” operations for cooperative matrices.

When multiple subgroup matrix loads cover overlapping regions, part of the data is loaded repeatedly. For example, suppose we have the following matrix and want to load 4x4 subgroup matrices starting at positions 00, 10, 20, and 30.

+-----+-----+-----+-----+
| 00  | 01  | 02  | 03  |
+-----+-----+-----+-----+
| 10  | 11  | 12  | 13  |
+-----+-----+-----+-----+
| 20  | 21  | 22  | 23  |
+-----+-----+-----+-----+
| 30  | 31  | 32  | 33  |
+-----+-----+-----+-----+
| 40  | 41  | 42  | 43  |
+-----+-----+-----+-----+
| 50  | 51  | 52  | 53  |
+-----+-----+-----+-----+
| 60  | 61  | 62  | 63  |
+-----+-----+-----+-----+
| 70  | 71  | 72  | 73  |
+-----+-----+-----+-----+

We get the following four matrices.

+-----+-----+-----+-----+
| 00  | 01  | 02  | 03  |
+-----+-----+-----+-----+
| 10  | 11  | 12  | 13  |
+-----+-----+-----+-----+
| 20  | 21  | 22  | 23  |
+-----+-----+-----+-----+
| 30  | 31  | 32  | 33  |
+-----+-----+-----+-----+

+-----+-----+-----+-----+
| 10  | 11  | 12  | 13  |
+-----+-----+-----+-----+
| 20  | 21  | 22  | 23  |
+-----+-----+-----+-----+
| 30  | 31  | 32  | 33  |
+-----+-----+-----+-----+
| 40  | 41  | 42  | 43  |
+-----+-----+-----+-----+

+-----+-----+-----+-----+
| 20  | 21  | 22  | 23  |
+-----+-----+-----+-----+
| 30  | 31  | 32  | 33  |
+-----+-----+-----+-----+
| 40  | 41  | 42  | 43  |
+-----+-----+-----+-----+
| 50  | 51  | 52  | 53  |
+-----+-----+-----+-----+

+-----+-----+-----+-----+
| 30  | 31  | 32  | 33  |
+-----+-----+-----+-----+
| 40  | 41  | 42  | 43  |
+-----+-----+-----+-----+
| 50  | 51  | 52  | 53  |
+-----+-----+-----+-----+
| 60  | 61  | 62  | 63  |
+-----+-----+-----+-----+

Clearly, several values are loaded more than once. SPIR-V provides an operation, GroupNonUniformRotateKHR, that rotates values within a subgroup under the SPV_KHR_subgroup_rotate extension[4], and we have already implemented it in the SPIR-V dialect[5]. It should therefore be possible to extract values from the loaded subgroup matrices, rotate them within the subgroup, and insert them into a newly constructed subgroup matrix, which avoids reloading values that already reside in GPU registers. For the above example, we can first load two disjoint 4x4 matrices starting at positions 00 and 40, after which all required values are in the subgroup registers. We then use CompositeConstruct, CompositeExtract, CompositeInsert, and GroupNonUniformRotateKHR to construct the 4x4 subgroup matrices starting at 10, 20, and 30. This reduces the number of subgroup matrix loads from 4 to 2. However, we cannot find any operation in the gpu dialect that rotates values within a subgroup, which is why we propose new gpu operations for it.
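To make the saving concrete, here is a small Python sketch that counts element loads for the 8x4 example. It is only an illustration of the arithmetic; the names `tile_elements`, `ROWS`, and `COLS` are made up for this sketch and are not part of any MLIR or SPIR-V API.

```python
# Count element loads for the 8x4 example: four overlapping 4x4 tiles
# versus two disjoint 4x4 tiles. All names here are illustrative.
ROWS, COLS = 4, 4  # tile shape


def tile_elements(start_row):
    """Return the (row, col) coordinates covered by a 4x4 tile."""
    return [(start_row + r, c) for r in range(ROWS) for c in range(COLS)]


overlapping = [tile_elements(s) for s in (0, 1, 2, 3)]  # tiles at 00, 10, 20, 30
disjoint = [tile_elements(s) for s in (0, 4)]           # tiles at 00 and 40

loaded_overlapping = sum(len(t) for t in overlapping)
loaded_disjoint = sum(len(t) for t in disjoint)
needed = {e for t in overlapping for e in t}   # distinct elements the four tiles use
covered = {e for t in disjoint for e in t}     # elements held after the two loads

print(loaded_overlapping, loaded_disjoint, len(needed), needed <= covered)
```

The four overlapping loads fetch 64 elements although only 28 distinct ones are needed, while the two disjoint loads fetch 32 elements and cover all of them.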

Proposals

We have 2 proposals to achieve our goal.

Proposal 1:

We propose adding a new GPU operation, gpu.subgroup_mma_rotate, to permute two loaded MMA matrices. The syntax is:

gpu.subgroup_mma_rotate $opA`,` $opB`,` $offset attr-dict `:` type($opA)`,` type($opB)`,` type($offset) `->` type($res)

This operation takes two subgroup matrices of the same type. The result starts at element “offset” of the first subgroup matrix and is completed by appending the first “offset” elements of the second subgroup matrix. The result type is the same as the operand type. For example, suppose the first 4x4 subgroup matrix holds 16 elements, TA0 to TA15, and the second holds TB0 to TB15. When offset is 1, the result is the 4x4 subgroup matrix made of TA1 to TA15 followed by TB0.
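The semantics described above can be written as a tiny reference model in Python. This is a sketch, not the proposed implementation: it assumes each matrix is flattened into a list of 16 elements in the order TA0..TA15, and the function name merely mirrors the proposed op.

```python
def subgroup_mma_rotate(a, b, offset):
    """Reference model: skip the first `offset` elements of `a`,
    then append the first `offset` elements of `b`."""
    assert len(a) == len(b) and 0 <= offset <= len(a)
    return a[offset:] + b[:offset]


ta = [f"TA{i}" for i in range(16)]
tb = [f"TB{i}" for i in range(16)]

res = subgroup_mma_rotate(ta, tb, 1)
print(res[0], res[14], res[15])  # TA1 TA15 TB0
```

The result always has the same number of elements as each operand, which matches the requirement that the result type equals the operand type.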

Using the example above, assume we have already loaded the 4x4 matrices starting at 00 and 40.

MMA0:
+-----+-----+-----+-----+
| 00  | 01  | 02  | 03  |
+-----+-----+-----+-----+
| 10  | 11  | 12  | 13  |
+-----+-----+-----+-----+
| 20  | 21  | 22  | 23  |
+-----+-----+-----+-----+
| 30  | 31  | 32  | 33  |
+-----+-----+-----+-----+
 
MMA1:
+-----+-----+-----+-----+
| 40  | 41  | 42  | 43  |
+-----+-----+-----+-----+
| 50  | 51  | 52  | 53  |
+-----+-----+-----+-----+
| 60  | 61  | 62  | 63  |
+-----+-----+-----+-----+
| 70  | 71  | 72  | 73  |
+-----+-----+-----+-----+
 
Arrangement of invocation ID in the subgroup with size 16:
+-----+-----+-----+-----+
|  0  |  1  |  2  |  3  |
+-----+-----+-----+-----+
|  4  |  5  |  6  |  7  |
+-----+-----+-----+-----+
|  8  |  9  | 10  | 11  |
+-----+-----+-----+-----+
| 12  | 13  | 14  | 15  |
+-----+-----+-----+-----+

To construct the MMA matrix starting at 10, we set the offset to 4.

%res = gpu.subgroup_mma_rotate %MMA0, %MMA1, %c4 : !gpu.mma_matrix<4x4xf32, "AOp">, !gpu.mma_matrix<4x4xf32, "AOp">, i32 -> !gpu.mma_matrix<4x4xf32, "AOp">
 
MMA0:
+-----+-----+-----+-----+
| 00  | 01  | 02  | 03  |
+-----+-----+-----+-----+ -+
| 10  | 11  | 12  | 13  |  |
+-----+-----+-----+-----+  |
| 20  | 21  | 22  | 23  |  |
+-----+-----+-----+-----+  | => +-----+-----+-----+-----+
| 30  | 31  | 32  | 33  |  |    | 10  | 11  | 12  | 13  |
+-----+-----+-----+-----+  |    +-----+-----+-----+-----+
                           |    | 20  | 21  | 22  | 23  |
MMA1:                      |    +-----+-----+-----+-----+
+-----+-----+-----+-----+  |    | 30  | 31  | 32  | 33  |
| 40  | 41  | 42  | 43  |  |    +-----+-----+-----+-----+
+-----+-----+-----+-----+ -+    | 40  | 41  | 42  | 43  |
| 50  | 51  | 52  | 53  |       +-----+-----+-----+-----+
+-----+-----+-----+-----+
| 60  | 61  | 62  | 63  |
+-----+-----+-----+-----+
| 70  | 71  | 72  | 73  |
+-----+-----+-----+-----+
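With the row-major invocation arrangement shown earlier, a tile that starts `r` rows further down corresponds to an offset of `r * 4`. The Python sketch below applies this to the example to build the tiles starting at 10, 20, and 30 from the two loaded matrices. The helper `rotate_tiles` and the flat-list layout are assumptions made for illustration.

```python
COLS = 4  # columns per row in the 4x4 tile


def rotate_tiles(a, b, offset):
    """Illustrative model of the proposed rotate on flattened tiles."""
    return a[offset:] + b[:offset]


# Flattened, row-major contents of the two loaded matrices.
mma0 = [f"{r}{c}" for r in range(0, 4) for c in range(COLS)]  # 00 .. 33
mma1 = [f"{r}{c}" for r in range(4, 8) for c in range(COLS)]  # 40 .. 73

tiles = {}
for start_row in (1, 2, 3):
    offset = start_row * COLS  # 4, 8, 12
    tiles[start_row] = rotate_tiles(mma0, mma1, offset)
    print(start_row, offset, tiles[start_row][0], tiles[start_row][-1])
```

Each derived tile begins at the requested row of the first matrix and ends in the second matrix, so the three remaining tiles need no further loads.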

We also propose the following pattern for lowering to SPIR-V. We already have a prototype that verifies its correctness.

// Create the result matrix, initialized to zero.
%cst_f32 = spirv.Constant 0.000000e+00 : f32
%created_ma = spirv.CompositeConstruct %cst_f32 : (f32) -> !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
// Extract this invocation's element from each source matrix.
%val0 = spirv.CompositeExtract %mma0[0 : i32] : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
%val1 = spirv.CompositeExtract %mma1[0 : i32] : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
%invocation_id = calculation of invocation ID
// Invocations at or beyond the offset contribute from %mma0, the others from %mma1.
%comp = spirv.SGreaterThanEqual %invocation_id, %offset : i32
%extract_val = spirv.Select %comp, %val0, %val1 : i1, f32
// Rotate the selected values within the subgroup and insert the result.
%rotated_val = spirv.GroupNonUniformRotateKHR <Subgroup>, %extract_val, %offset : f32
%val = spirv.CompositeInsert %rotated_val, %created_ma[0 : i32] : f32 into !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
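The select-then-rotate pattern can be checked against the simple "skip and append" semantics with a short Python simulation. This is a sketch under stated assumptions: the subgroup has 16 invocations, each invocation holds exactly one element of each matrix, and `group_rotate` models GroupNonUniformRotateKHR without a ClusterSize operand, where invocation i receives the value held by invocation (i + delta) mod size.

```python
SUBGROUP_SIZE = 16


def group_rotate(values, delta):
    """Model of a subgroup rotate: lane i reads from lane (i + delta) % size."""
    n = len(values)
    return [values[(i + delta) % n] for i in range(n)]


def lowered_rotate(mma0, mma1, offset):
    """Per-lane select followed by a subgroup rotate, as in the lowering."""
    # Lanes at or beyond `offset` contribute their mma0 element,
    # the remaining lanes contribute their mma1 element.
    selected = [mma0[i] if i >= offset else mma1[i] for i in range(SUBGROUP_SIZE)]
    return group_rotate(selected, offset)


mma0 = [f"A{i}" for i in range(SUBGROUP_SIZE)]
mma1 = [f"B{i}" for i in range(SUBGROUP_SIZE)]

# The lowering should agree with the reference semantics for every offset.
matches = all(
    lowered_rotate(mma0, mma1, off) == mma0[off:] + mma1[:off]
    for off in range(SUBGROUP_SIZE)
)
print(matches)
```

The comparison direction matters: lanes below the offset must hold the second matrix's elements before the rotate, because those are the values that wrap around to the end of the result.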

Proposal 2:

Based on the SPIR-V generated for Proposal 1, we can instead add several finer-grained gpu operations to achieve the same goal. What is missing is the ability to extract a value from a subgroup matrix, insert a value into a subgroup matrix, and rotate values within a subgroup. We can therefore add the following three gpu operations.

gpu.subgroup_mma_extract $opA`[` integer-literal (',' integer-literal)* `]` attr-dict `:` type($opA) `->` type($res)
gpu.subgroup_mma_insert $val`,` $opA`[` integer-literal (',' integer-literal)* `]` attr-dict `:` type($val)`,` type($opA) `->` type($res)
gpu.rotate $val`,` $offset`,` $width attr-dict `:` type($val)`,` type($offset)`,` type($width) `->` type($res)

For lowering to SPIR-V, gpu.subgroup_mma_extract will be lowered to spirv.CompositeExtract, gpu.subgroup_mma_insert to spirv.CompositeInsert, and gpu.rotate to spirv.GroupNonUniformRotateKHR.
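The RFC does not spell out the meaning of the `width` operand of gpu.rotate. One plausible reading, given the lowering target, is that it maps to the optional ClusterSize operand of OpGroupNonUniformRotateKHR, so that the rotation happens independently inside each cluster of `width` consecutive invocations. The Python sketch below models that assumed behavior; the function name and the interpretation of `width` are hypothetical.

```python
def gpu_rotate(values, offset, width):
    """Hypothetical model of gpu.rotate: rotate within each contiguous
    cluster of `width` lanes (assumed to behave like ClusterSize)."""
    assert width > 0 and width & (width - 1) == 0, "width must be a power of two"
    assert len(values) % width == 0
    out = []
    for lane in range(len(values)):
        base = lane & ~(width - 1)                    # first lane of this cluster
        src = ((lane + offset) & (width - 1)) + base  # source lane inside the cluster
        out.append(values[src])
    return out


lanes = list(range(8))
narrow = gpu_rotate(lanes, 1, 4)  # two clusters of four lanes
full = gpu_rotate(lanes, 1, 8)    # one cluster spanning all lanes
print(narrow)  # [1, 2, 3, 0, 5, 6, 7, 4]
print(full)    # [1, 2, 3, 4, 5, 6, 7, 0]
```

Under this reading, setting `width` to the subgroup size reproduces the full-subgroup rotate that Proposal 1 relies on.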

References

  1. SPV_KHR_cooperative_matrix
  2. Machine Learning Acceleration in Vulkan with Cooperative Matrices | NVIDIA Technical Blog
  3. https://www.phoronix.com/news/Intel-Xe2-Coop-Matrix-Enable
  4. SPV_KHR_subgroup_rotate
  5. [mlir][spirv] Add instruction OpGroupNonUniformRotateKHR by Hsiangkai · Pull Request #133428 · llvm/llvm-project