MLIR: lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/StringSwitch.h"
30
31 #include
32
33 namespace mlir {
34
35
36
37
38
39
40
41
42
// createElementwiseOp: the reference index at the end of this listing gives
// the signature as
//   static bool createElementwiseOp(ConversionPatternRewriter &builder,
//       gpu::SubgroupMmaElementwiseOp op, Type coopType, ValueRange operands)
// Visible behaviour: asserts on coopType, then switches on op.getOpType().
// Every case shown returns true; any other op type falls to `default` and the
// function returns false.
// NOTE(review): the listing numbers skip (e.g. 49 -> 51, 51 -> 54), so the
// statement that precedes each `return true`, and some `case` labels, are not
// shown here. Only the EXTF case is fully visible. Confirm against upstream.
44 gpu::SubgroupMmaElementwiseOp op, Type coopType,
// NOTE(review): angle brackets appear to have been stripped by the extraction;
// presumably this is an isa<> check that coopType is a
// spirv::CooperativeMatrixType -- confirm.
46 assert((isaspirv::CooperativeMatrixType(coopType)));
47
48 switch (op.getOpType()) {
49 case gpu::MMAElementwiseOp::ADDF:
51 return true;
54 return true;
55 case gpu::MMAElementwiseOp::SUBF:
57 return true;
60 return true;
61 case gpu::MMAElementwiseOp::DIVF:
63 return true;
64 case gpu::MMAElementwiseOp::DIVS:
66 return true;
67 case gpu::MMAElementwiseOp::DIVU:
69 return true;
70 case gpu::MMAElementwiseOp::NEGATEF:
72 return true;
73 case gpu::MMAElementwiseOp::NEGATES:
75 return true;
76 case gpu::MMAElementwiseOp::EXTF:
// EXTF: the op is replaced by a spirv::FConvertOp with result type coopType,
// forwarding `operands` unchanged.
77 builder.replaceOpWithNewOpspirv::FConvertOp(op, coopType, operands);
78 return true;
79 default:
80 break;
81 }
// Reached only through `default`: no replacement was made in this function.
82 return false;
83 }
84
// allOperandsHaveSameCoopMatrixType: the reference index at the end of this
// listing gives the signature as
//   bool allOperandsHaveSameCoopMatrixType(ValueRange operands)
// The signature line itself (listing line 85) is not shown here.
// Precondition (asserted): `operands` is not empty.
86 assert(!operands.empty());
// Returns false as soon as the operand types are not all equal.
87 if (!llvm::all_equal(
88 llvm::map_range(operands, [](Value v) { return v.getType(); })))
89 return false;
90
// With all types equal, only the first operand's type is inspected.
// NOTE(review): angle brackets appear stripped; presumably an isa<> test for
// spirv::CooperativeMatrixType -- confirm.
91 return isaspirv::CooperativeMatrixType(operands.front().getType());
92 }
93
94 namespace {
95
96
// Conversion pattern for gpu::SubgroupMmaConstantMatrixOp.
// Replaces the op with a spirv::CompositeConstructOp whose result type is the
// converted type of the op's result and whose operand is the op's adapted
// operand.
// NOTE(review): this listing appears to have lost template angle brackets
// (e.g. after OpConversionPattern / replaceOpWithNewOp) and listing line 99 is
// absent; compare against the upstream file before relying on exact syntax.
97 struct WmmaConstantOpToSPIRVLowering final
98 : OpConversionPatterngpu::SubgroupMmaConstantMatrixOp {
100
101 LogicalResult
102 matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
103 ConversionPatternRewriter &rewriter) const override {
// The whole adapted operand list is handed to getSingleElement; by its name
// it presumably expects exactly one entry -- TODO confirm.
104 Value cst = llvm::getSingleElement(adaptor.getOperands());
105 auto coopType = getTypeConverter()->convertType(op.getType());
// A null converted type is reported as a match failure, not an error.
106 if (!coopType)
107 return rewriter.notifyMatchFailure(op, "type conversion failed");
108
109 rewriter.replaceOpWithNewOpspirv::CompositeConstructOp(op, coopType, cst);
110 return success();
111 }
112 };
113
114
115
// Conversion pattern for gpu::SubgroupMmaExtractThreadLocalOp.
// Replaces the op with a spirv::CompositeExtractOp that reads one element of
// the adapted matrix at constant indices.
// Match fails when the matrix type does not convert, or when any index is not
// defined by an arith::ConstantIndexOp.
// NOTE(review): template angle brackets appear stripped in this listing and
// listing line 118 is absent -- confirm against upstream.
116 struct WmmaExtractOpToSPIRVLowering final
117 : OpConversionPatterngpu::SubgroupMmaExtractThreadLocalOp {
119
120 LogicalResult
121 matchAndRewrite(gpu::SubgroupMmaExtractThreadLocalOp op, OpAdaptor adaptor,
122 ConversionPatternRewriter &rewriter) const override {
123 Value matrix = adaptor.getMatrix();
// The type being converted is that of the *adapted* matrix value.
124 auto coopType =
125 getTypeConverter()->convertTypespirv::CooperativeMatrixType(
126 matrix.getType());
127 if (!coopType)
128 return rewriter.notifyMatchFailure(op, "type conversion failed");
129
// Indices are read from the original op (not the adaptor) and each constant is
// cast to int32_t for the I32 array attribute built below.
130 SmallVector<int32_t> intValues;
131 for (Value val : op.getIndices()) {
132 if (auto constOp = val.getDefiningOparith::ConstantIndexOp()) {
133 intValues.push_back(static_cast<int32_t>(constOp.value()));
134 } else {
135 return rewriter.notifyMatchFailure(op, "indices must be constants");
136 }
137 }
138
// Result type of the extract is the cooperative matrix's element type.
139 Type elementType = coopType.getElementType();
140 rewriter.replaceOpWithNewOpspirv::CompositeExtractOp(
141 op, elementType, matrix, rewriter.getI32ArrayAttr(intValues));
142 return success();
143 }
144 };
145
146
147
// Conversion pattern for gpu::SubgroupMmaInsertThreadLocalOp.
// Replaces the op with a spirv::CompositeInsertOp that writes the adapted
// value into the adapted matrix at constant indices.
// Match fails when the matrix type does not convert, or when any index is not
// defined by an arith::ConstantIndexOp.
// NOTE(review): template angle brackets appear stripped in this listing and
// listing line 150 is absent -- confirm against upstream.
148 struct WmmaInsertOpToSPIRVLowering final
149 : OpConversionPatterngpu::SubgroupMmaInsertThreadLocalOp {
151
152 LogicalResult
153 matchAndRewrite(gpu::SubgroupMmaInsertThreadLocalOp op, OpAdaptor adaptor,
154 ConversionPatternRewriter &rewriter) const override {
155 Value value = adaptor.getValue();
156 Value matrix = adaptor.getMatrix();
// The result type is obtained by converting the adapted matrix's type.
157 auto coopType = getTypeConverter()->convertType(matrix.getType());
158 if (!coopType)
159 return rewriter.notifyMatchFailure(op, "type conversion failed");
160
// Indices are read from the original op (not the adaptor) and each constant is
// cast to int32_t for the I32 array attribute built below.
161 SmallVector<int32_t> intValues;
162 for (Value val : op.getIndices()) {
163 if (auto constOp = val.getDefiningOparith::ConstantIndexOp()) {
164 intValues.push_back(static_cast<int32_t>(constOp.value()));
165 } else {
166 return rewriter.notifyMatchFailure(op, "indices must be constants");
167 }
168 }
169
170 rewriter.replaceOpWithNewOpspirv::CompositeInsertOp(
171 op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues));
172 return success();
173 }
174 };
175
176
177
// Conversion pattern for gpu::SubgroupMmaElementwiseOp (default path).
// Visible behaviour: reports a match failure with "not all operands are coop
// matrices" under a guard, converts the op's result type (match failure if
// that yields null), and returns success(...) of an expression.
// NOTE(review): listing lines 180, 186 and 196 are absent. The guard condition
// presumably calls allOperandsHaveSameCoopMatrixType and the success()
// argument presumably calls createElementwiseOp -- confirm against upstream.
// Template angle brackets also appear stripped.
178 struct WmmaElementwiseOpToSPIRVDefaultLowering final
179 : OpConversionPatterngpu::SubgroupMmaElementwiseOp {
181
182 LogicalResult
183 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
184 ConversionPatternRewriter &rewriter) const override {
185
187 return rewriter.notifyMatchFailure(op,
188 "not all operands are coop matrices");
189 }
190
191 auto coopType = getTypeConverter()->convertType(op.getType());
192 if (!coopType)
193 return rewriter.notifyMatchFailure(op, "type conversion failed");
194
195 return success(
197 }
198 };
199
200
201
// Conversion pattern for gpu::SubgroupMmaElementwiseOp, restricted to a MULF
// with exactly two operands where one operand is produced by a
// gpu::SubgroupMmaConstantMatrixOp. That op is replaced by a
// spirv::MatrixTimesScalarOp taking the other operand and a scalar.
// NOTE(review): listing lines 204, 213 and 238 are absent. The guard before
// "not all operands are coop matrices" and the declaration of `scalar` are
// therefore not visible -- confirm against upstream. Template angle brackets
// also appear stripped.
202 struct WmmaElementwiseOpToSPIRVScalarMulLowering final
203 : OpConversionPatterngpu::SubgroupMmaElementwiseOp {
205
206 LogicalResult
207 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
208 ConversionPatternRewriter &rewriter) const override {
// Only binary elementwise ops are considered; anything else fails silently.
209 if (adaptor.getOperands().size() != 2)
210 return failure();
211
212
214 return rewriter.notifyMatchFailure(op,
215 "not all operands are coop matrices");
216 }
217
218 if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
219 return failure();
220
221
222
// The splat side is identified on the ORIGINAL operands (by defining op), but
// the values kept are the corresponding ADAPTED operands.
223 Value lhs = op.getOperands().front();
224 Value rhs = op.getOperands().back();
225 Value splat = nullptr;
226 Value matrix = nullptr;
// lhs is tested first, so if both operands come from constant-matrix ops the
// lhs is the one treated as the splat.
227 if (lhs.getDefiningOpgpu::SubgroupMmaConstantMatrixOp()) {
228 splat = adaptor.getOperands().front();
229 matrix = adaptor.getOperands().back();
230 } else if (rhs.getDefiningOpgpu::SubgroupMmaConstantMatrixOp()) {
231 matrix = adaptor.getOperands().front();
232 splat = adaptor.getOperands().back();
233 }
234 if (!splat || !matrix)
235 return rewriter.notifyMatchFailure(op, "no splat operand");
236
237
// The adapted splat must itself be defined by a spirv::CompositeConstructOp;
// its constituents are what the scalar is taken from.
239 auto cc = splat.getDefiningOpspirv::CompositeConstructOp();
240 if (!cc) {
241 return rewriter.notifyMatchFailure(op,
242 "splat is not a composite construct");
243 }
244
// getSingleElement is given all constituents; by its name it presumably
// expects exactly one -- TODO confirm.
245 scalar = llvm::getSingleElement(cc.getConstituents());
246
247 auto coopType = getTypeConverter()->convertType(op.getType());
248 if (!coopType)
249 return rewriter.notifyMatchFailure(op, "type conversion failed");
// Operand order passed to the new op is {matrix, scalar}.
250 rewriter.replaceOpWithNewOpspirv::MatrixTimesScalarOp(
251 op, coopType, ValueRange{matrix, scalar});
252 return success();
253 }
254 };
255 }
256
257
258
259
260
261 namespace khr {
262 namespace {
263
264
265
// Conversion pattern (in namespace khr) whose matchAndRewrite takes a
// gpu::SubgroupMmaLoadMatrixOp. Visible steps: compute a buffer pointer from
// the adapted indices, build an i32-typed stride constant from the op's
// leading dimension, choose a KHR cooperative-matrix layout from the op's
// transpose flag, and replace the op.
// NOTE(review): this listing is heavily gapped (listing lines 267-268, 272,
// 274, 279, 283, 285, 290 and 296 are absent), so the base class, the `loc`
// definition, the pointer helper call, the type conversion, the failure
// return, the constant's arguments and the replacement op name are not
// visible. The end-of-listing index mentions getElementPtr(typeConverter,
// baseType, basePtr, indices, loc, builder), which presumably is the helper
// used for bufferPtr -- confirm against upstream.
266 struct WmmaLoadOpToSPIRVLowering final
269
270 LogicalResult
271 matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
273 const auto &typeConverter = *getTypeConverter();
275
276 auto retType = castgpu::MMAMatrixType(op.getRes().getType());
277 MemRefType memrefType = op.getSrcMemref().getType();
278 Value bufferPtr =
280 adaptor.getIndices(), loc, rewriter);
281
282 auto coopType =
284 if (!coopType)
286
// The leading dimension is read as a sign-extended 64-bit integer.
287 int64_t stride = op.getLeadDimension().getSExtValue();
288 IntegerType i32Type = rewriter.getI32Type();
289 auto strideValue = rewriter.createspirv::ConstantOp(
291
// An absent transpose attribute is treated as false, i.e. RowMajor.
292 bool isColMajor = op.getTranspose().value_or(false);
293 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
294 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
295
297 op, coopType, bufferPtr, strideValue, layout);
298 return success();
299 }
300 };
301
302
303
// Conversion pattern (in namespace khr) whose matchAndRewrite takes a
// gpu::SubgroupMmaStoreMatrixOp. Visible steps: compute a buffer pointer from
// the adapted indices, build an i32-typed stride constant from the op's
// leading dimension, choose a KHR cooperative-matrix layout from the op's
// transpose flag, and replace the op using the pointer, the adapted source
// value, the stride and the layout.
// NOTE(review): this listing is heavily gapped (listing lines 305-306, 310,
// 312, 316, 322 and 328 are absent), so the base class, the `loc` definition,
// the pointer helper call, the constant's arguments and the replacement op
// name are not visible. The `cast(` on listing line 314 also appears to have
// lost its template argument -- confirm against upstream.
304 struct WmmaStoreOpToSPIRVLowering final
307
308 LogicalResult
309 matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
311 const auto &typeConverter = *getTypeConverter();
313
314 auto memrefType = cast(op.getDstMemref().getType());
315 Value bufferPtr =
317 adaptor.getIndices(), loc, rewriter);
318
// The leading dimension is read as a sign-extended 64-bit integer.
319 int64_t stride = op.getLeadDimension().getSExtValue();
320 IntegerType i32Type = rewriter.getI32Type();
321 auto strideValue = rewriter.createspirv::ConstantOp(
323
// An absent transpose attribute is treated as false, i.e. RowMajor.
324 bool isColMajor = op.getTranspose().value_or(false);
325 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
326 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
327
329 op, bufferPtr, adaptor.getSrc(), strideValue, layout);
330 return success();
331 }
332 };
333
334
335
// Conversion pattern (in namespace khr) whose matchAndRewrite takes a
// gpu::SubgroupMmaComputeOp. The visible call passes the original op followed
// by the adapted A, B and C operands, then returns success unconditionally.
// NOTE(review): listing lines 337-338 and 343-344 are absent, so the base
// class and the name of the replacement op are not visible -- confirm against
// upstream.
336 struct WmmaMmaOpToSPIRVLowering final
339
340 LogicalResult
341 matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
342 OpAdaptor adaptor,
345 subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
346 adaptor.getOpC());
347 return success();
348 }
349 };
350
351 }
352 }
353 }
354
357 using namespace mlir;
// Body of the pattern-population function. The end-of-listing index gives
// populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
//     const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns),
// which is presumably this function -- its signature line (listing line 358)
// is not shown here.
// First call: registers the KHR load / mma / store patterns together with the
// constant, extract, insert and default elementwise patterns.
359 patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
360 khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
361 WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering,
362 WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
363
// Second call passes an extra trailing argument 2.
// NOTE(review): the template argument of this add() appears to have been
// stripped by the extraction. Presumably it registers
// WmmaElementwiseOpToSPIRVScalarMulLowering with the 2 as its pattern
// benefit -- confirm against upstream.
364 patterns.add(converter, context,
365 2);
366 }
367
// Tail of a type-conversion callback. The end-of-listing index gives
// populateMMAToSPIRVCoopMatrixTypeConversion(SPIRVTypeConverter &typeConverter),
// which is presumably the enclosing function -- its signature and the start of
// the callback (listing lines 368-372) are not shown here.
// `use` is selected by string: "AOp" -> MatrixA, "BOp" -> MatrixB, anything
// else -> MatrixAcc.
// NOTE(review): listing lines 374 and 379 are absent, so the switched-on
// string and the head of the final call are not visible. The visible
// arguments are retTypeShape[1], spirv::Scope::Subgroup and `use` -- confirm
// the rest against upstream.
373 auto use =
375 .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
376 .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
377 .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
378
380 retTypeShape[1],
381 spirv::Scope::Subgroup, use);
382 });
383 }
This class implements a pattern rewriter for use with ConversionPatterns.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Type getElementType() const
Get elementType of a single element.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
@ Type
An inlay hint that is for a type annotation.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
static bool createElementwiseOp(ConversionPatternRewriter &builder, gpu::SubgroupMmaElementwiseOp op, Type coopType, ValueRange operands)
Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op when the elementwise op dire...
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV, using the KHR Cooperative Ma...
const FrozenRewritePatternSet & patterns
bool allOperandsHaveSameCoopMatrixType(ValueRange operands)
void populateMMAToSPIRVCoopMatrixTypeConversion(SPIRVTypeConverter &typeConverter)
Adds MMAMatrixType conversions to SPIR-V cooperative matrix KHR type conversion to the type converter...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...