MLIR: lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
14
27 #include "llvm/ADT/TypeSwitch.h"
28
29 #include
30 #include
31
32 using namespace mlir;
34
35 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
36
// Dialect initialization hook: registers with this dialect every operation
// class listed in the GET_OP_LIST section of the generated AMDGPU.cpp.inc and
// every attribute class listed in the GET_ATTRDEF_LIST section of the
// generated AMDGPUAttributes.cpp.inc.
37 void AMDGPUDialect::initialize() {
// The template argument list is supplied by the generated include below.
38 addOperations<
39 #define GET_OP_LIST
40 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
41 >();
// Likewise for the dialect's attribute definitions.
42 addAttributes<
43 #define GET_ATTRDEF_LIST
44 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
45 >();
46 }
47
48
49
50
52 if (getExisting() && getExisting().getType() != getResult().getType())
53 return emitOpError("existing values must have same type as result");
54 return success();
55 }
56
58 if (getExisting() && getExisting().getType() != getResult().getType())
59 return emitOpError("existing values must have same type as result");
60 return success();
61 }
62
63
64
65
67 if (getExisting() && getExisting().getType() != getResult().getType())
68 return emitOpError("existing values must have same type as result");
69 return success();
70 }
71
72
73
74
75
76
77
78
79
81 bool resetOffset) {
86 MemRefLayoutAttrInterface layout = source.getLayout();
87 if (resetOffset && !layout.isIdentity()) {
88 auto stridedLayout = dyn_cast(layout);
89 if (!stridedLayout)
90 return failure();
92 }
93 return (MemRefType)(mb);
94 }
95
96 LogicalResult FatRawBufferCastOp::inferReturnTypes(
100 Adaptor adaptor(operands, attributes, properties, regions);
101 auto sourceType =
102 dyn_cast_if_present(adaptor.getSource().getType());
103 if (!sourceType)
104 return failure();
105 FailureOr resultType =
107 if (failed(resultType))
108 return failure();
110 return success();
111 }
112
114 FailureOr expectedResultType =
116 if (failed(expectedResultType))
117 return emitOpError("source type ")
118 << getSource().getType() << " can't have its offset reset";
119 if (getResult().getType() != *expectedResultType)
120 return emitOpError("expected result type to be ")
121 << *expectedResultType << " but got " << getResult().getType();
122 return success();
123 }
124
126 if (!memorySpace)
127 return true;
128 if (auto intMemorySpace = dyn_cast(memorySpace))
129 return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
130 if (auto gpuMemorySpace = dyn_castgpu::AddressSpaceAttr(memorySpace))
131 return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
132 return false;
133 }
134
136 if (auto intMemorySpace = dyn_cast(memorySpace))
137 return intMemorySpace.getInt() == 3;
138 if (auto gpuMemorySpace = dyn_castgpu::AddressSpaceAttr(memorySpace))
139 return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
140 return false;
141 }
142
144 if (auto intMemorySpace = dyn_cast(memorySpace))
145 return intMemorySpace.getInt() == 7;
146 if (auto gpuMemorySpace = dyn_castamdgpu::AddressSpaceAttr(memorySpace))
147 return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
148 return false;
149 }
150
151
152
153
154 template
156 MemRefType bufferType = llvm::cast(op.getMemref().getType());
158
159 if (!isGlobal)
160 return op.emitOpError(
161 "Buffer ops must operate on a memref in global memory");
162 if (!bufferType.hasRank())
163 return op.emitOpError(
164 "Cannot meaningfully buffer_store to an unranked memref");
165 if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
166 return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
167 " indices to memref");
168 return success();
169 }
170
172
174
177 }
178
181 }
182
185 }
186
189 }
190
193 }
194
196 APInt cst;
198 return std::nullopt;
200 return cst.getZExtValue();
201 return std::nullopt;
202 }
203
204 template
206 if (!op.getBoundsCheck())
207 return false;
208 MemRefType bufferType = op.getMemref().getType();
209 if (!bufferType.hasStaticShape())
210 return false;
211 int64_t offset;
213 if (failed(bufferType.getStridesAndOffset(strides, offset)))
214 return false;
215 int64_t result = offset + op.getIndexOffset().value_or(0);
216 if (op.getSgprOffset()) {
217 std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
218 if (!sgprOffset)
219 return false;
220 result += *sgprOffset;
221 }
222 if (strides.size() != op.getIndices().size())
223 return false;
224 int64_t indexVal = 0;
225 for (auto pair : llvm::zip(strides, op.getIndices())) {
226 int64_t stride = std::get<0>(pair);
227 Value idx = std::get<1>(pair);
229 if (!idxVal)
230 return false;
231 indexVal += stride * *idxVal;
232 }
233 result += indexVal;
235
236 return false;
237 return result >= bufferType.getNumElements();
238 }
239
240 namespace {
241 template
242 struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern {
244
245 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
247 return failure();
248 Type loadType = op.getResult().getType();
251 return success();
252 }
253 };
254
255 template
256 struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern {
258
259 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
261 return failure();
262
264 return success();
265 }
266 };
267 }
268
269 void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
271 results.add<RemoveStaticallyOobBufferLoads>(context);
272 }
273
274 void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
276 results.add<RemoveStaticallyOobBufferWrites>(context);
277 }
278
279 void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
281 results.add<RemoveStaticallyOobBufferWrites>(context);
282 }
283
284 void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
286 results.add<RemoveStaticallyOobBufferWrites>(context);
287 }
288
289 void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
291 results.add<RemoveStaticallyOobBufferWrites>(context);
292 }
293
294 void RawBufferAtomicUminOp::getCanonicalizationPatterns(
296 results.add<RemoveStaticallyOobBufferWrites>(context);
297 }
298
299 void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
301 results.add<RemoveStaticallyOobBufferLoads>(
302 context);
303 }
304
305
306
307
309 Type sourceAType = getSourceA().getType();
310 Type sourceBType = getSourceB().getType();
311 Type destType = getDestC().getType();
312
313 VectorType sourceVectorAType = dyn_cast(sourceAType);
314 VectorType sourceVectorBType = dyn_cast(sourceBType);
315 VectorType destVectorType = dyn_cast(destType);
316
317 Type sourceAElemType = sourceVectorAType.getElementType();
318 Type sourceBElemType = sourceVectorBType.getElementType();
319 Type destElemType = destVectorType.getElementType();
320
321 if (sourceVectorAType.getNumElements() !=
322 sourceVectorBType.getNumElements()) {
323 return emitOpError("source vectors have different lengths: ")
324 << sourceVectorAType << " vs. " << sourceVectorBType;
325 }
326
327 bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
328 bool isSrcFloat =
329 isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
330 sourceAElemType);
331
332 if (isDestFloat && !isSrcFloat) {
333 return emitOpError("Expected float sources with float destination");
334 }
335
336 if (!isDestFloat && isSrcFloat) {
337 return emitOpError("Expected int sources with int destination");
338 }
339
340 if (sourceAElemType != sourceBElemType &&
341 !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
342 isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
343 return emitOpError(
344 "source element types much match (except for fp8) but have ")
345 << sourceAType << " and " << sourceBType;
346 }
347 return success();
348 }
349
350
351
352
354 constexpr uint32_t waveSize = 64;
356
357 Type sourceType = getSourceA().getType();
358 Type destType = getDestC().getType();
359
360 Type sourceElem = sourceType, destElem = destType;
361 uint32_t sourceLen = 1, destLen = 1;
362 if (auto sourceVector = llvm::dyn_cast(sourceType)) {
363 sourceLen = sourceVector.getNumElements();
364 sourceElem = sourceVector.getElementType();
365 }
366 if (auto destVector = llvm::dyn_cast(destType)) {
367 destLen = destVector.getNumElements();
368 destElem = destVector.getElementType();
369 }
370
371 Type sourceBType = getSourceB().getType();
373 int64_t sourceBLen = 1;
374 Type sourceBElem = sourceBType;
375 if (auto sourceBVector = llvm::dyn_cast(sourceBType)) {
376 sourceBLen = sourceBVector.getNumElements();
377 sourceBElem = sourceBVector.getElementType();
378 }
379 if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
381 return emitOpError("expected both source operands to have small-float "
382 "elements if one does");
383 if (sourceLen != sourceBLen)
384 return emitOpError(
385 "expected both small-float source vectors to have the same length");
386 } else {
387 if (sourceType != sourceBType)
388 return emitOpError("expected both non-small-float source operand types "
389 "to match exactly");
390 }
391
393 sourceLen *= 4;
394 sourceElem = b.getI8Type();
395 }
397 sourceLen *= 8;
398 sourceElem = b.getI8Type();
399 }
400
401 int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
402 if (sourceLen != numSourceElems)
403 return emitOpError("expected " + Twine(numSourceElems) +
404 " source values for this operation but got " +
405 Twine(sourceLen));
406
407 int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
408 if (destLen != numDestElems)
409 return emitOpError("expected " + Twine(numDestElems) +
410 " result values for this operation but got " +
411 Twine(destLen));
412
413 if (destElem.isF64() && getBlgp() != MFMAPermB::none)
414 return emitOpError(
415 "double-precision ops do not support permuting lanes of B");
416 if (destElem.isF64() && getCbsz() != 0)
417 return emitOpError(
418 "double-precision ops do not support permuting lanes of A");
419 if (getAbid() >= (1u << getCbsz()))
420 return emitOpError(
421 "block ID for permuting A (abid) must be below 2 ** cbsz");
422
423 if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
424 return emitOpError(
425 "negation flags only available for double-precision operations");
426
427 return success();
428 }
429
430
431
432
434 Type srcType = getSrc().getType();
436 return emitOpError("integer and floating point types larger than 64 bits "
437 "are not supported");
438 }
439
440 DPPPerm kind = getKind();
442
443 switch (kind) {
444
445 case DPPPerm::quad_perm: {
446 auto quadPermAttr = dyn_cast_or_null(permArgument);
447 if (!quadPermAttr || quadPermAttr.size() != 4) {
448 return emitOpError("quad_perm attribute must have exactly 4 elements");
449 }
450 for (auto elem : quadPermAttr.getAsRange()) {
451 int32_t num = elem.getInt();
452 if (num < 0 || num > 3) {
453 return emitOpError(
454 "Each element of quad_perm must be in the range [0, 3]");
455 }
456 }
457 } break;
458
459 case DPPPerm::row_shl:
460 case DPPPerm::row_shr:
461 case DPPPerm::row_ror: {
462 if (!permArgument) {
463 return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
464 "' value not specified");
465 }
466 if (auto intAttr = dyn_cast(permArgument)) {
467 uint32_t attrValue = intAttr.getInt();
468 if (attrValue < 1 || attrValue > 15) {
469 return emitOpError("Attribute value must be between 1 and 15");
470 }
471 }
472 } break;
473
474 case DPPPerm::wave_shl:
475 case DPPPerm::wave_shr:
476 case DPPPerm::wave_rol:
477 case DPPPerm::wave_ror:
478 case DPPPerm::row_mirror:
479 case DPPPerm::row_half_mirror:
480 case DPPPerm::row_bcast_15:
481 case DPPPerm::row_bcast_31: {
482 if (permArgument && !isa(permArgument)) {
483 return emitOpError("Expected unit attribute for permArgument, but found "
484 "non-trivial argument");
485 }
486 break;
487 }
488 }
489 return success();
490 }
491
493 MemRefType srcType = cast(getSrc().getType());
494 MemRefType dstType = cast(getDst().getType());
495
496 if (!dstType.areTrailingDimsContiguous(dstType.getRank()))
497 return emitOpError("destination types must be contiguous");
498
499 auto elemType = srcType.getElementType();
500
501 if (elemType != dstType.getElementType())
502 return emitOpError("source and destination element types must match");
503
504
505 auto transferType = getTransferType();
506 size_t transferSize;
507 if (auto vectorTransfer = dyn_cast(transferType)) {
508 transferSize = vectorTransfer.getNumElements() *
509 vectorTransfer.getElementTypeBitWidth();
510 } else {
511 transferSize = transferType.getIntOrFloatBitWidth();
512 }
513 if (transferSize != 8 && transferSize != 16 && transferSize != 32)
514 return emitOpError("Transfering type size must be 8, 16, or 32 bits");
515
518 return emitOpError(
519 "source memory address space must be global or fat raw buffer");
520
522 return emitOpError("destination memory address space must be Workgroup");
523
524 return success();
525 }
526
527 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
528
529 #define GET_ATTRDEF_CLASSES
530 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
531
532 #define GET_OP_CLASSES
533 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
static LogicalResult verifyRawBufferOp(T &op)
static FailureOr< MemRefType > getFatRawBufferTypeLike(MemRefType source, bool resetOffset)
Convert the type source to one with the same sizes and strides - and offset, unless stripOffset is tr...
static bool hasGlobalMemorySpace(Attribute memorySpace)
static bool hasWorkgroupMemorySpace(Attribute memorySpace)
static std::optional< uint32_t > getConstantUint32(Value v)
static bool hasFatRawBufferMemorySpace(Attribute memorySpace)
static bool staticallyOutOfBounds(OpType op)
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1203::ArityGroupAndKind::Kind kind
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
TypedAttr getZeroAttr(Type type)
MLIRContext is the top-level object for a collection of MLIR operations.
This is a builder type that keeps local references to arguments.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Builder & setMemorySpace(Attribute newMemorySpace)
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class provides an abstraction over the different types of ranges over Regions.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isFloat() const
Return true if this is an float type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
uint64_t getN(LevelType lt)
uint64_t getM(LevelType lt)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...