MLIR: lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
10
21
22 #include "../LLVMCommon/MemRefDescriptor.h"
23
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/Casting.h"
27 #include "llvm/Support/ErrorHandling.h"
28 #include
29
30 namespace mlir {
31 #define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
32 #include "mlir/Conversion/Passes.h.inc"
33 }
34
35 using namespace mlir;
37
38
43
44
47 IntegerType i32 = rewriter.getI32Type();
48
49 auto valTy = cast(val.getType());
50 if (i32 == valTy)
51 return val;
52 return valTy.getWidth() > 32
53 ? Value(rewriter.createLLVM::TruncOp(loc, i32, val))
54 : Value(rewriter.createLLVM::ZExtOp(loc, i32, val));
55 }
56
58 Location loc, int32_t value) {
60 return rewriter.createLLVM::ConstantOp(loc, i32, value);
61 }
62
64 bool value) {
66 return rewriter.createLLVM::ConstantOp(loc, llvmI1, value);
67 }
68
69
73 IntegerType i32 = rewriter.getI32Type();
75 for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
76 if (stride != 1) {
77 Value strideValue =
78 ShapedType::isDynamic(stride)
80 memRefDescriptor.stride(rewriter, loc, i))
81 : rewriter.createLLVM::ConstantOp(loc, i32, stride);
82 increment = rewriter.createLLVM::MulOp(loc, increment, strideValue);
83 }
84 index =
85 index ? rewriter.createLLVM::AddOp(loc, index, increment) : increment;
86 }
88 }
89
90
91
92
94 MemRefType memrefType,
97 uint32_t elementByteWidth) {
98 if (memrefType.hasStaticShape() &&
99 !llvm::any_of(strides, ShapedType::isDynamic)) {
100 int64_t size = memrefType.getRank() == 0 ? 1 : 0;
102 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
103 size = std::max(shape[i] * strides[i], size);
104 size = size * elementByteWidth;
106 "the memref buffer is too large");
107 return createI32Constant(rewriter, loc, static_cast<int32_t>(size));
108 }
110 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
111 Value size = memrefDescriptor.size(rewriter, loc, i);
112 Value stride = memrefDescriptor.stride(rewriter, loc, i);
113 Value maxThisDim = rewriter.createLLVM::MulOp(loc, size, stride);
114 maxIndex = maxIndex
115 ? rewriter.createLLVM::UMaxOp(loc, maxIndex, maxThisDim)
116 : maxThisDim;
117 }
120 return rewriter.createLLVM::MulOp(loc, maxIndexI32, byteWidthConst);
121 }
122
125 bool boundsCheck, amdgpu::Chipset chipset,
126 Value cacheSwizzleStride = nullptr,
127 unsigned addressSpace = 8) {
128
129
130
133 if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
134 Value cacheStrideZext =
135 rewriter.createLLVM::ZExtOp(loc, i16, cacheSwizzleStride);
136 Value swizzleBit = rewriter.createLLVM::ConstantOp(
138 stride = rewriter.createLLVM::OrOp(loc, cacheStrideZext, swizzleBit,
139 true);
140 } else {
141 stride = rewriter.createLLVM::ConstantOp(loc, i16,
143 }
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159 uint32_t flags = (7 << 12) | (4 << 15);
160 if (chipset.majorVersion >= 10) {
161 flags |= (1 << 24);
162 uint32_t oob = boundsCheck ? 3 : 2;
163 flags |= (oob << 28);
164 }
166 Type rsrcType =
169 loc, rsrcType, basePointer, stride, numRecords, flagsConst);
170 return resource;
171 }
172
173 namespace {
174 struct FatRawBufferCastLowering
178 chipset(chipset) {}
179
181
182 LogicalResult
183 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
186 Value memRef = adaptor.getSource();
187 Value unconvertedMemref = op.getSource();
188 MemRefType memrefType = cast(unconvertedMemref.getType());
190
192 int64_t elementByteWidth =
194
195 int64_t unusedOffset = 0;
197 if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
198 return op.emitOpError("Can't lower non-stride-offset memrefs");
199
200 Value numRecords = adaptor.getValidBytes();
201 if (!numRecords)
202 numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
203 strideVals, elementByteWidth);
204
205 Value basePointer =
206 adaptor.getResetOffset()
207 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
208 memrefType)
209 : descriptor.alignedPtr(rewriter, loc);
210
211 Value offset = adaptor.getResetOffset()
212 ? rewriter.createLLVM::ConstantOp(
213 loc, getIndexType(), rewriter.getIndexAttr(0))
214 : descriptor.offset(rewriter, loc);
215
216 bool hasSizes = memrefType.getRank() > 0;
217
218
219 Value sizes = hasSizes ? rewriter.createLLVM::ExtractValueOp(
222 Value strides = hasSizes
223 ? rewriter.createLLVM::ExtractValueOp(
226
228 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
229 chipset, adaptor.getCacheSwizzleStride(), 7);
230
232 rewriter, loc,
233 getTypeConverter()->convertType(op.getResult().getType()));
234 result = rewriter.createLLVM::InsertValueOp(
236 result = rewriter.createLLVM::InsertValueOp(
238 result = rewriter.createLLVM::InsertValueOp(loc, result, offset,
240 if (hasSizes) {
241 result = rewriter.createLLVM::InsertValueOp(loc, result, sizes,
243 result = rewriter.createLLVM::InsertValueOp(
245 }
247 return success();
248 }
249 };
250
251
252 template <typename GpuOp, typename Intrinsic>
256
258 static constexpr uint32_t maxVectorOpWidth = 128;
259
260 LogicalResult
261 matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
263 Location loc = gpuOp.getLoc();
264 Value memref = adaptor.getMemref();
265 Value unconvertedMemref = gpuOp.getMemref();
266 MemRefType memrefType = cast(unconvertedMemref.getType());
267
269 return gpuOp.emitOpError("raw buffer ops require GCN or higher");
270
271 Value storeData = adaptor.getODSOperands(0)[0];
272 if (storeData == memref)
273 storeData = Value();
274 Type wantedDataType;
275 if (storeData)
276 wantedDataType = storeData.getType();
277 else
278 wantedDataType = gpuOp.getODSResults(0)[0].getType();
279
281
282 if (storeData) {
283 Value maybeCmpData = adaptor.getODSOperands(1)[0];
284 if (maybeCmpData != memref)
285 atomicCmpData = maybeCmpData;
286 }
287
288 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
289
291
292
294 int64_t elementByteWidth =
297
298
299
300
301
302
303 Type llvmBufferValType = llvmWantedDataType;
304 if (atomicCmpData) {
305 if (auto floatType = dyn_cast(wantedDataType))
306 llvmBufferValType = this->getTypeConverter()->convertType(
308 }
309 if (auto dataVector = dyn_cast(wantedDataType)) {
310 uint32_t vecLen = dataVector.getNumElements();
311 uint32_t elemBits =
313 uint32_t totalBits = elemBits * vecLen;
314 bool usePackedFp16 =
315 isa_and_present(*gpuOp) && vecLen == 2;
316 if (totalBits > maxVectorOpWidth)
317 return gpuOp.emitOpError(
318 "Total width of loads or stores must be no more than " +
319 Twine(maxVectorOpWidth) + " bits, but we call for " +
320 Twine(totalBits) +
321 " bits. This should've been caught in validation");
322 if (!usePackedFp16 && elemBits < 32) {
323 if (totalBits > 32) {
324 if (totalBits % 32 != 0)
325 return gpuOp.emitOpError("Load or store of more than 32-bits that "
326 "doesn't fit into words. Can't happen\n");
327 llvmBufferValType = this->typeConverter->convertType(
329 } else {
330 llvmBufferValType = this->typeConverter->convertType(
332 }
333 }
334 }
335 if (auto vecType = dyn_cast(llvmBufferValType)) {
336
337
338 if (vecType.getNumElements() == 1)
339 llvmBufferValType = vecType.getElementType();
340 }
341
343 if (storeData) {
344 if (llvmBufferValType != llvmWantedDataType) {
345 Value castForStore =
346 rewriter.createLLVM::BitcastOp(loc, llvmBufferValType, storeData);
347 args.push_back(castForStore);
348 } else {
349 args.push_back(storeData);
350 }
351 }
352
353 if (atomicCmpData) {
354 if (llvmBufferValType != llvmWantedDataType) {
355 Value castForCmp = rewriter.createLLVM::BitcastOp(
356 loc, llvmBufferValType, atomicCmpData);
357 args.push_back(castForCmp);
358 } else {
359 args.push_back(atomicCmpData);
360 }
361 }
362
363
364 int64_t offset = 0;
366 if (failed(memrefType.getStridesAndOffset(strides, offset)))
367 return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");
368
370
371 Value ptr = memrefDescriptor.bufferPtr(
372 rewriter, loc, *this->getTypeConverter(), memrefType);
374 rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
376 adaptor.getBoundsCheck(), chipset);
377 args.push_back(resource);
378
379
381 adaptor.getIndices(), strides);
382 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
383 indexOffset && *indexOffset > 0) {
385 voffset =
386 voffset ? rewriter.createLLVM::AddOp(loc, voffset, extraOffsetConst)
387 : extraOffsetConst;
388 }
389 voffset = rewriter.createLLVM::MulOp(loc, voffset, byteWidthConst);
390 args.push_back(voffset);
391
392
393 Value sgprOffset = adaptor.getSgprOffset();
394 if (!sgprOffset)
396 sgprOffset = rewriter.createLLVM::MulOp(loc, sgprOffset, byteWidthConst);
397 args.push_back(sgprOffset);
398
399
400
401
403
405 llvmBufferValType);
406 Operation *lowered = rewriter.create(loc, resultTypes, args,
410 if (llvmBufferValType != llvmWantedDataType) {
411 replacement = rewriter.createLLVM::BitcastOp(loc, llvmWantedDataType,
412 replacement);
413 }
414 rewriter.replaceOp(gpuOp, replacement);
415 } else {
417 }
418 return success();
419 }
420 };
421
425
427
428 LogicalResult
429 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
431 bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;
432
433 if (requiresInlineAsm) {
435 LLVM::AsmDialect::AD_ATT);
436 const char *asmStr =
437 ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
438 const char *constraints = "";
440 op,
442 asmStr, constraints, true,
444 asmDialectAttr,
445 ArrayAttr());
446 return success();
447 }
449 constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
450 constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
451
452
453 constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);
454
455 int32_t ldsOnlyBits;
457 ldsOnlyBits = ldsOnlyBitsGfx11;
459 ldsOnlyBits = ldsOnlyBitsGfx10;
461 ldsOnlyBits = ldsOnlyBitsGfx6789;
462 else
463 return op.emitOpError(
464 "don't know how to lower this for chipset major version")
466
468 rewriter.createROCDL::SWaitcntOp(loc, ldsOnlyBits);
470 } else {
472 rewriter.createROCDL::WaitDscntOp(loc, 0);
473 rewriter.createROCDL::BarrierSignalOp(loc, -1);
475 }
476
477 return success();
478 }
479 };
480
484
486
487 LogicalResult
488 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
491 (uint32_t)op.getOpts());
492 return success();
493 }
494 };
495
496 }
497
498
499
500
501
502
503
504
505
506
507
508
509
510
514 if (auto vectorType = dyn_cast(inputType)) {
515 if (vectorType.getElementType().isBF16())
516 return rewriter.createLLVM::BitcastOp(
517 loc, vectorType.clone(rewriter.getI16Type()), input);
518 if (vectorType.getElementType().isInteger(8) &&
519 vectorType.getNumElements() <= 8)
520 return rewriter.createLLVM::BitcastOp(
521 loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
522 if (isa(vectorType.getElementType()) &&
523 vectorType.getElementTypeBitWidth() <= 8) {
525 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
526 32);
527 return rewriter.createLLVM::BitcastOp(
529 }
530 }
531 return input;
532 }
533
534
535
536
537
538
539
540
541
542
543
548 if (auto intType = dyn_cast(inputType))
549 return rewriter.createLLVM::ZExtOp(loc, outputType, input);
550 return rewriter.createLLVM::BitcastOp(loc, outputType, input);
551 }
552
553
554
555
556
557
558
559
560
564 bool isUnsigned, Value llvmInput,
568 auto vectorType = dyn_cast(inputType);
569 if (!vectorType) {
570 operands.push_back(llvmInput);
571 return;
572 }
573 Type elemType = vectorType.getElementType();
574
575 if (elemType.isBF16())
576 llvmInput = rewriter.createLLVM::BitcastOp(
577 loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
579 operands.push_back(llvmInput);
580 return;
581 }
582
583
584
585
586 auto mlirInputType = cast(mlirInput.getType());
587 bool isInputInteger = mlirInputType.getElementType().isInteger();
588 if (isInputInteger) {
589
590 bool localIsUnsigned = isUnsigned;
592 localIsUnsigned = true;
594 localIsUnsigned = false;
595 }
597 operands.push_back(sign);
598 }
599
600 int64_t numBits =
603 Type intrinsicInType = numBits <= 32
606 auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
608 loc, llvmIntrinsicInType, llvmInput);
609
610
611
612 if (numBits < 32)
613 castInput = rewriter.createLLVM::ZExtOp(loc, i32, castInput);
614 operands.push_back(castInput);
615 }
616
617
618
619
620
621
622
623
627 Value output, int32_t subwordOffset,
630 auto vectorType = dyn_cast(inputType);
631 Type elemType = vectorType.getElementType();
632 if (elemType.isBF16())
633 output = rewriter.createLLVM::BitcastOp(
634 loc, vectorType.clone(rewriter.getI16Type()), output);
635 operands.push_back(output);
637 operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
638 } else if (elemType.isInteger(32)) {
640 }
641 }
642
643
644
646 return (chipset == kGfx942 && isa(type)) ||
647 (hasOcpFp8(chipset) && isa(type));
648 }
649
650
651
653 return (chipset == kGfx942 && isa(type)) ||
654 (hasOcpFp8(chipset) && isa(type));
655 }
656
657
658
659
662 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
663 b = mfma.getBlocks();
666
667 if (sourceElem.isF32() && destElem.isF32()) {
668 if (mfma.getReducePrecision() && chipset >= kGfx942) {
669 if (m == 32 && n == 32 && k == 4 && b == 1)
670 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
671 if (m == 16 && n == 16 && k == 8 && b == 1)
672 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
673 }
674 if (m == 32 && n == 32 && k == 1 && b == 2)
675 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
676 if (m == 16 && n == 16 && k == 1 && b == 4)
677 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
678 if (m == 4 && n == 4 && k == 1 && b == 16)
679 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
680 if (m == 32 && n == 32 && k == 2 && b == 1)
681 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
682 if (m == 16 && n == 16 && k == 4 && b == 1)
683 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
684 }
685
686 if (sourceElem.isF16() && destElem.isF32()) {
687 if (chipset >= kGfx950) {
688 if (m == 32 && n == 32 && k == 16 && b == 1)
689 return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
690 if (m == 16 && n == 16 && k == 32 && b == 1)
691 return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
692 }
693 if (m == 32 && n == 32 && k == 4 && b == 2)
694 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
695 if (m == 16 && n == 16 && k == 4 && b == 4)
696 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
697 if (m == 4 && n == 4 && k == 4 && b == 16)
698 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
699 if (m == 32 && n == 32 && k == 8 && b == 1)
700 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
701 if (m == 16 && n == 16 && k == 16 && b == 1)
702 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
703 }
704
705 if (sourceElem.isBF16() && destElem.isF32()) {
706 if (chipset >= kGfx950) {
707 if (m == 32 && n == 32 && k == 16 && b == 1)
708 return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
709 if (m == 16 && n == 16 && k == 32 && b == 1)
710 return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
711 }
712 if (chipset >= kGfx90a) {
713 if (m == 32 && n == 32 && k == 4 && b == 2)
714 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
715 if (m == 16 && n == 16 && k == 4 && b == 4)
716 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
717 if (m == 4 && n == 4 && k == 4 && b == 16)
718 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
719 if (m == 32 && n == 32 && k == 8 && b == 1)
720 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
721 if (m == 16 && n == 16 && k == 16 && b == 1)
722 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
723 }
724 if (m == 32 && n == 32 && k == 2 && b == 2)
725 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
726 if (m == 16 && n == 16 && k == 2 && b == 4)
727 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
728 if (m == 4 && n == 4 && k == 2 && b == 16)
729 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
730 if (m == 32 && n == 32 && k == 4 && b == 1)
731 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
732 if (m == 16 && n == 16 && k == 8 && b == 1)
733 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
734 }
735
737 if (chipset >= kGfx950) {
738 if (m == 32 && n == 32 && k == 32 && b == 1)
739 return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
740 if (m == 16 && n == 16 && k == 64 && b == 1)
741 return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
742 }
743 if (m == 32 && n == 32 && k == 4 && b == 2)
744 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
745 if (m == 16 && n == 16 && k == 4 && b == 4)
746 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
747 if (m == 4 && n == 4 && k == 4 && b == 16)
748 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
749 if (m == 32 && n == 32 && k == 8 && b == 1)
750 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
751 if (m == 16 && n == 16 && k == 16 && b == 1)
752 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
753 if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
754 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
755 if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
756 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
757 }
758
759 if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
760 if (m == 16 && n == 16 && k == 4 && b == 1)
761 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
762 if (m == 4 && n == 4 && k == 4 && b == 4)
763 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
764 }
765
767
768
769 Type sourceBElem =
770 cast(mfma.getSourceB().getType()).getElementType();
771 if (m == 16 && n == 16 && k == 32 && b == 1) {
773 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
775 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
776 }
777 if (m == 32 && n == 32 && k == 16 && b == 1) {
779 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
781 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
782 }
783 }
784
786 Type sourceBElem =
787 cast(mfma.getSourceB().getType()).getElementType();
788 if (m == 16 && n == 16 && k == 32 && b == 1) {
790 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
792 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
793 }
794 if (m == 32 && n == 32 && k == 16 && b == 1) {
796 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
798 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
799 }
800 }
801
802 return std::nullopt;
803 }
804
807 .Case([](Float8E4M3FNType) { return 0u; })
808 .Case([](Float8E5M2Type) { return 1u; })
809 .Case([](Float6E2M3FNType) { return 2u; })
810 .Case([](Float6E3M2FNType) { return 3u; })
811 .Case([](Float4E2M1FNType) { return 4u; })
812 .Default([](Type) { return std::nullopt; });
813 }
814
815
816
817
818
819
820
821
822 static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
824 uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
828
830 return std::nullopt;
831 if (!isa(destType))
832 return std::nullopt;
833
836 if (!aTypeCode || !bTypeCode)
837 return std::nullopt;
838
839 if (m == 32 && n == 32 && k == 64 && b == 1)
840 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
841 *aTypeCode, *bTypeCode};
842 if (m == 16 && n == 16 && k == 128 && b == 1)
843 return std::tuple{
844 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
845 *bTypeCode};
846
847 return std::nullopt;
848 }
849
850 static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
853 mfma.getSourceA().getType(), mfma.getSourceB().getType(),
854 mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
855 mfma.getBlocks(), chipset);
856 }
857
858 static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
861 smfma.getSourceB().getType(),
862 smfma.getDestC().getType(), smfma.getM(),
863 smfma.getN(), smfma.getK(), 1u, chipset);
864 }
865
866
867
868
871 auto sourceVectorType = dyn_cast(wmma.getSourceA().getType());
872 auto sourceBVectorType = dyn_cast(wmma.getSourceB().getType());
873 auto destVectorType = dyn_cast(wmma.getDestC().getType());
874 auto elemSourceType = sourceVectorType.getElementType();
875 auto elemBSourceType = sourceBVectorType.getElementType();
876 auto elemDestType = destVectorType.getElementType();
877
878 if (elemSourceType.isF16() && elemDestType.isF32())
879 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
880 if (elemSourceType.isBF16() && elemDestType.isF32())
881 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
882 if (elemSourceType.isF16() && elemDestType.isF16())
883 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
884 if (elemSourceType.isBF16() && elemDestType.isBF16())
885 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
886 if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
887 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
889 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
890 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
891 }
893 if (isa(elemSourceType) &&
894 isa(elemBSourceType) && elemDestType.isF32())
895 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
896 if (isa(elemSourceType) &&
897 isa(elemBSourceType) && elemDestType.isF32())
898 return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
899 if (isa(elemSourceType) &&
900 isa(elemBSourceType) && elemDestType.isF32())
901 return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
902 if (isa(elemSourceType) &&
903 isa(elemBSourceType) && elemDestType.isF32())
904 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
905 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
906 bool isWave64 = destVectorType.getNumElements() == 4;
907
908
909 bool has8Inputs = sourceVectorType.getNumElements() == 8;
910 if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
911 return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
912 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
913 }
914 }
915 return std::nullopt;
916 }
917
918 namespace {
922
924
925 LogicalResult
926 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
929 Type outType = typeConverter->convertType(op.getDestD().getType());
930 Type intrinsicOutType = outType;
931 if (auto outVecType = dyn_cast(outType))
932 if (outVecType.getElementType().isBF16())
933 intrinsicOutType = outVecType.clone(rewriter.getI16Type());
934
936 return op->emitOpError("MFMA only supported on gfx908+");
937 uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
938 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
940 return op.emitOpError("negation unsupported on older than gfx942");
941 getBlgpField |=
942 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
943 }
944 std::optional maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
945 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
947 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
948 return op.emitOpError("no intrinsic matching MFMA size on given chipset");
949
950 bool isScaled =
951 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
952 if (isScaled &&
953 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
954 return op.emitOpError(
955 "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
956 "be scaled as those fields are used for type information");
957 }
958
959 StringRef intrinsicName =
960 isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
962 loweredOp.addTypes(intrinsicOutType);
963 loweredOp.addOperands(
966 adaptor.getDestC()});
967 if (isScaled) {
969 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
970 loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
972 zero, zero,
973 zero, zero});
974 } else {
975 loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()),
978 };
980 if (outType != intrinsicOutType)
981 lowered = rewriter.createLLVM::BitcastOp(loc, outType, lowered);
983 return success();
984 }
985 };
986
990
992
993 LogicalResult
994 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
997 Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
998
1000 return op->emitOpError("scaled MFMA only supported on gfx908+");
1001 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1003 if (!maybeScaledIntrinsic.has_value())
1004 return op.emitOpError(
1005 "no intrinsic matching scaled MFMA size on given chipset");
1006
1007 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1009 loweredOp.addTypes(intrinsicOutType);
1010 loweredOp.addOperands(
1013 adaptor.getDestC()});
1014 Value scalesIdxA =
1016 Value scalesIdxB =
1018 loweredOp.addOperands(
1021 scalesIdxA,
1022
1024 scalesIdxB,
1025
1028 rewriter.replaceOp(op, lowered);
1029 return success();
1030 }
1031 };
1032
1036
1038
1039 LogicalResult
1040 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
1043 auto outType =
1044 typeConverter->convertType(op.getDestD().getType());
1045 if (!outType)
1047
1049 return op->emitOpError("WMMA only supported on gfx11 and gfx12");
1050
1051
1052
1053 VectorType rawOutType = outType;
1054 if (outType.getElementType().isBF16())
1055 rawOutType = outType.clone(rewriter.getI16Type());
1056
1057 std::optional maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
1058
1059 if (!maybeIntrinsic.has_value())
1060 return op.emitOpError("no intrinsic matching WMMA on the given chipset");
1061
1062 if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
1063 return op.emitOpError("subwordOffset not supported on gfx12+");
1064
1066 loweredOp.addTypes(rawOutType);
1067
1070 adaptor.getSourceA(), op.getSourceA(), operands);
1072 adaptor.getSourceB(), op.getSourceB(), operands);
1074 op.getSubwordOffset(), op.getClamp(), operands);
1075
1076 loweredOp.addOperands(operands);
1078
1079 Operation *maybeCastBack = lowered;
1080 if (rawOutType != outType)
1081 maybeCastBack =
1082 rewriter.createLLVM::BitcastOp(loc, outType, lowered->getResult(0));
1084
1085 return success();
1086 }
1087 };
1088
1092
1094
1095 LogicalResult
1096 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
1099 return op.emitOpError("pre-gfx9 and post-gfx10 not supported");
1100
1102
1103 auto srcMemRefType = cast(op.getSrc().getType());
1104 auto dstMemRefType = cast(op.getDst().getType());
1105
1106
1107
1108
1109 Type transferType = op.getTransferType();
1110 size_t loadWidth = [&]() -> size_t {
1111 if (auto transferVectorType = dyn_cast(transferType)) {
1112 return transferVectorType.getNumElements() *
1113 (transferVectorType.getElementTypeBitWidth() / 8);
1114 }
1116 }();
1117
1118
1119 if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
1120 return op.emitOpError("chipset unsupported element size");
1121
1124 (adaptor.getSrcIndices()));
1127 (adaptor.getDstIndices()));
1128
1132 rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
1133 ArrayAttr{});
1134
1135 return success();
1136 }
1137 };
1138
1139 namespace {
1140 struct ExtPackedFp8OpLowering final
1144 chipset(chipset) {}
1146
1147 LogicalResult
1148 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1150 };
1151
1152 struct PackedTrunc2xFp8OpLowering final
1157 chipset(chipset) {}
1159
1160 LogicalResult
1161 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1163 };
1164
1165 struct PackedStochRoundFp8OpLowering final
1167 PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
1170 chipset(chipset) {}
1172
1173 LogicalResult
1174 matchAndRewrite(PackedStochRoundFp8Op op,
1175 PackedStochRoundFp8OpAdaptor adaptor,
1177 };
1178
1179 struct ScaledExtPackedOpLowering final
1183 chipset(chipset) {}
1185
1186 LogicalResult
1187 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1189 };
1190
1191 struct PackedScaledTruncOpLowering final
1196 chipset(chipset) {}
1198
1199 LogicalResult
1200 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1202 };
1203
1204 }
1205
1206 LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
1207 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1212 loc, "Fp8 conversion instructions are not available on target "
1213 "architecture and their emulation is not implemented");
1216 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1217 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
1218
1219 Value source = adaptor.getSource();
1220 auto sourceVecType = dyn_cast(op.getSource().getType());
1221 auto resultVecType = dyn_cast(op.getResult().getType());
1223
1224 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
1225 Value longVec = rewriter.createLLVM::UndefOp(loc, v4i8);
1226 if (!sourceVecType) {
1227 longVec = rewriter.createLLVM::InsertElementOp(
1229 } else {
1230 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1232 Value elem = rewriter.createLLVM::ExtractElementOp(loc, source, idx);
1233 longVec =
1234 rewriter.createLLVM::InsertElementOp(loc, longVec, elem, idx);
1235 }
1236 }
1237 source = longVec;
1238 }
1239 Value i32Source = rewriter.createLLVM::BitcastOp(loc, i32, source);
1240 if (resultVecType) {
1242 rewriter.replaceOpWithNewOpROCDL::CvtPkF32Bf8Op(op, f32, i32Source,
1243 op.getIndex());
1245 rewriter.replaceOpWithNewOpROCDL::CvtPkF32Fp8Op(op, f32, i32Source,
1246 op.getIndex());
1247 }
1248 } else {
1251 op.getIndex());
1254 op.getIndex());
1255 }
1256 }
1257 return success();
1258 }
1259
1260 LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
1261 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1266 loc, "Scaled fp conversion instructions are not available on target "
1267 "architecture and their emulation is not implemented");
1268 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1269
1270 Value source = adaptor.getSource();
1271 Value scale = adaptor.getScale();
1272
1273 VectorType sourceVecType = cast(op.getSource().getType());
1274 Type sourceElemType = sourceVecType.getElementType();
1275 VectorType destVecType = cast(op.getResult().getType());
1276 Type destElemType = destVecType.getElementType();
1277
1278 VectorType packedVecType;
1279 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
1281 packedVecType = cast(getTypeConverter()->convertType(v4i8));
1282 } else if (isa(sourceElemType)) {
1284 packedVecType = cast(getTypeConverter()->convertType(v8i4));
1285 } else {
1286 llvm_unreachable("invalid element type for scaled ext");
1287 }
1288
1289
1290 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
1291 Value longVec = rewriter.createLLVM::ZeroOp(loc, packedVecType);
1292 if (!sourceVecType) {
1293 longVec = rewriter.createLLVM::InsertElementOp(
1295 } else {
1296 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1298 Value elem = rewriter.createLLVM::ExtractElementOp(loc, source, idx);
1299 longVec =
1300 rewriter.createLLVM::InsertElementOp(loc, longVec, elem, idx);
1301 }
1302 }
1303 source = longVec;
1304 }
1305 Value i32Source = rewriter.createLLVM::BitcastOp(loc, i32, source);
1306
1307 if (isa(sourceElemType) && destElemType.isF32())
1309 op, destVecType, i32Source, scale, op.getIndex());
1310 else if (isa(sourceElemType) && destElemType.isF16())
1312 op, destVecType, i32Source, scale, op.getIndex());
1313 else if (isa(sourceElemType) && destElemType.isBF16())
1315 op, destVecType, i32Source, scale, op.getIndex());
1316 else if (isa(sourceElemType) && destElemType.isF32())
1318 op, destVecType, i32Source, scale, op.getIndex());
1319 else if (isa(sourceElemType) && destElemType.isF16())
1321 op, destVecType, i32Source, scale, op.getIndex());
1322 else if (isa(sourceElemType) && destElemType.isBF16())
1324 op, destVecType, i32Source, scale, op.getIndex());
1325 else if (isa(sourceElemType) && destElemType.isF32())
1327 op, destVecType, i32Source, scale, op.getIndex());
1328 else if (isa(sourceElemType) && destElemType.isF16())
1330 op, destVecType, i32Source, scale, op.getIndex());
1331 else if (isa(sourceElemType) && destElemType.isBF16())
1333 op, destVecType, i32Source, scale, op.getIndex());
1334 else
1335 return failure();
1336
1337 return success();
1338 }
1339
// Lowers amdgpu.packed_scaled_trunc to the ROCDL CvtScaleF32Pk* intrinsics,
// dispatching on the (source element type, result element type) pair:
// f32/f16/bf16 sources into packed bf8/fp8/fp4 results.
// NOTE(review): this text is an HTML extraction of the upstream file --
// template arguments (e.g. `cast<VectorType>`, `create<LLVM::BitcastOp>`)
// and several whole lines were dropped; reconcile with upstream MLIR before
// compiling.
1340 LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
1341 PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
// These two lines are the tail of a notifyMatchFailure emitted when the
// target chipset lacks the scaled-conversion instructions (the guarding
// condition was dropped by extraction).
1346 loc, "Scaled fp conversion instructions are not available on target "
1347 "architecture and their emulation is not implemented");
// Integer carrier types for the packed result (v2i16 or i32, chosen below).
1348 Type v2i16 = getTypeConverter()->convertType(
1350 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1351
1352 Type resultType = op.getResult().getType();
1354 VectorType sourceVecType = cast(op.getSource().getType());
1355 Type sourceElemType = sourceVecType.getElementType();
1356
// Pick the integer type the intrinsic operates on; which result element
// types map to i32 vs v2i16 is not visible here (the isa<> argument was
// stripped by extraction) -- confirm against upstream.
1357 Type intResultType = isa(resultElemType) ? i32 : v2i16;
1358
1359 Value source = adaptor.getSource();
1360 Value scale = adaptor.getScale();
// The intrinsics update an existing packed word: reuse the caller-provided
// value (bitcast to the integer carrier) or start from zero.
1361 Value existing = adaptor.getExisting();
1362 if (existing)
1363 existing = rewriter.createLLVM::BitcastOp(loc, intResultType, existing);
1364 else
1365 existing = rewriter.createLLVM::ZeroOp(loc, intResultType);
1366
// Widen a single-element source vector to the 2-element shape the packed
// intrinsics expect, placing the one lane at index 0.
1367 if (sourceVecType.getNumElements() < 2) {
1369 Value elem0 = rewriter.createLLVM::ExtractElementOp(loc, source, c0);
1371 source = rewriter.createLLVM::ZeroOp(loc, v2);
1372 source = rewriter.createLLVM::InsertElementOp(loc, source, elem0, c0);
1373 }
1374
// The *F32 variants take two scalar operands rather than a vector, so
// extract both lanes up front for f32 sources.
1375 Value sourceA, sourceB;
1376 if (sourceElemType.isF32()) {
1379 sourceA = rewriter.createLLVM::ExtractElementOp(loc, source, c0);
1380 sourceB = rewriter.createLLVM::ExtractElementOp(loc, source, c1);
1381 }
1382
// Dispatch table: source element type x destination 8/4-bit float type.
// The isa<> result-type arguments were stripped by extraction; the intrinsic
// names (Bf8/Fp8/Fp4 x F32/F16/Bf16) indicate the intended pairing.
1384 if (sourceElemType.isF32() && isa(resultElemType))
1385 result = rewriter.createROCDL::CvtScaleF32PkBf8F32Op(
1386 loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
1387 else if (sourceElemType.isF16() && isa(resultElemType))
1388 result = rewriter.createROCDL::CvtScaleF32PkBf8F16Op(
1389 loc, intResultType, existing, source, scale, op.getIndex());
1390 else if (sourceElemType.isBF16() && isa(resultElemType))
1391 result = rewriter.createROCDL::CvtScaleF32PkBf8Bf16Op(
1392 loc, intResultType, existing, source, scale, op.getIndex());
1393 else if (sourceElemType.isF32() && isa(resultElemType))
1394 result = rewriter.createROCDL::CvtScaleF32PkFp8F32Op(
1395 loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
1396 else if (sourceElemType.isF16() && isa(resultElemType))
1397 result = rewriter.createROCDL::CvtScaleF32PkFp8F16Op(
1398 loc, intResultType, existing, source, scale, op.getIndex());
1399 else if (sourceElemType.isBF16() && isa(resultElemType))
1400 result = rewriter.createROCDL::CvtScaleF32PkFp8Bf16Op(
1401 loc, intResultType, existing, source, scale, op.getIndex());
1402 else if (sourceElemType.isF32() && isa(resultElemType))
1403 result = rewriter.createROCDL::CvtScaleF32PkFp4F32Op(
1404 loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
1405 else if (sourceElemType.isF16() && isa(resultElemType))
1406 result = rewriter.createROCDL::CvtScaleF32PkFp4F16Op(
1407 loc, intResultType, existing, source, scale, op.getIndex());
1408 else if (sourceElemType.isBF16() && isa(resultElemType))
1409 result = rewriter.createROCDL::CvtScaleF32PkFp4Bf16Op(
1410 loc, intResultType, existing, source, scale, op.getIndex());
// Unsupported type combination: let another pattern (or legalization
// failure) handle it.
1411 else
1412 return failure();
1413
// Bitcast the integer carrier back to the op's converted result type and
// replace the original op.
1415 op, getTypeConverter()->convertType(resultType), result);
1416 return success();
1417 }
1418
// Lowers amdgpu.packed_trunc_2xfp8 to ROCDL CvtPkBf8F32 / CvtPkFp8F32:
// truncates two f32 values into one 16-bit half of an i32 packed word,
// selected by the op's word index.
// NOTE(review): extraction dropped template arguments and some lines here
// (e.g. the bf8-vs-fp8 selection condition before line 1444); verify against
// upstream MLIR.
1419 LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
1420 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
// Tail of a notifyMatchFailure for chipsets without fp8 conversion
// instructions (the guard condition itself was dropped by extraction).
1425 loc, "Fp8 conversion instructions are not available on target "
1426 "architecture and their emulation is not implemented");
1427 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1428
1429 Type resultType = op.getResult().getType();
1431
1432 Value sourceA = adaptor.getSourceA();
// sourceB is optional on the op; feed undef to the intrinsic when absent.
1433 Value sourceB = adaptor.getSourceB();
1434 if (!sourceB)
1435 sourceB = rewriter.createLLVM::UndefOp(loc, sourceA.getType());
// The intrinsic updates one half-word of an existing i32; reuse the
// caller's packed value (bitcast to i32) or start from undef.
1436 Value existing = adaptor.getExisting();
1437 if (existing)
1438 existing = rewriter.createLLVM::BitcastOp(loc, i32, existing);
1439 else
1440 existing = rewriter.createLLVM::UndefOp(loc, i32);
1441
// Emit the bf8 or fp8 packing intrinsic; wordIndex picks which 16-bit half
// of the i32 receives the two converted bytes.
1444 result = rewriter.createROCDL::CvtPkBf8F32Op(loc, i32, sourceA, sourceB,
1445 existing, op.getWordIndex());
1447 result = rewriter.createROCDL::CvtPkFp8F32Op(loc, i32, sourceA, sourceB,
1448 existing, op.getWordIndex());
1449
// Bitcast the packed i32 back to the converted result type.
1451 op, getTypeConverter()->convertType(resultType), result);
1452 return success();
1453 }
1454
// Lowers amdgpu.packed_stoch_round_fp8 to ROCDL CvtSrBf8F32 / CvtSrFp8F32:
// stochastic-rounding truncation of one f32 into a byte of an i32 packed
// word, at the byte position given by the op's store index.
// NOTE(review): extraction dropped template arguments and some lines here
// (e.g. the bf8-vs-fp8 selection condition before line 1478); verify against
// upstream MLIR.
1455 LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
1456 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
// Tail of a notifyMatchFailure for chipsets without fp8 conversion
// instructions (the guard condition itself was dropped by extraction).
1461 loc, "Fp8 conversion instructions are not available on target "
1462 "architecture and their emulation is not implemented");
1463 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1464
1465 Type resultType = op.getResult().getType();
1467
1468 Value source = adaptor.getSource();
// `stoch` is the randomness operand used by the stochastic-rounding
// hardware instruction.
1469 Value stoch = adaptor.getStochiasticParam();
// Reuse the caller's packed word (bitcast to i32) or start from undef.
1470 Value existing = adaptor.getExisting();
1471 if (existing)
1472 existing = rewriter.createLLVM::BitcastOp(loc, i32, existing);
1473 else
1474 existing = rewriter.createLLVM::UndefOp(loc, i32);
1475
// Emit the bf8 or fp8 stochastic-round intrinsic; storeIndex selects the
// destination byte within the i32.
1478 result = rewriter.createROCDL::CvtSrBf8F32Op(
1479 loc, i32, source, stoch, existing, op.getStoreIndex());
1481 result = rewriter.createROCDL::CvtSrFp8F32Op(
1482 loc, i32, source, stoch, existing, op.getStoreIndex());
1483
// Bitcast the packed i32 back to the converted result type.
1485 op, getTypeConverter()->convertType(resultType), result);
1486 return success();
1487 }
1488
1489
1490
1495
// Lowers amdgpu.dpp to ROCDL::DPPUpdateOp by encoding the requested
// permutation kind into the hardware dpp_ctrl immediate, after widening
// sub-32-bit operands to a 32-bit value the instruction can operate on.
// NOTE(review): extraction dropped template arguments and several lines in
// this method (type-selection branches around 1506-1517, vector-type
// construction at 1528, the replaceOp at the end); verify against upstream.
1496 LogicalResult
1497 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
1499
1500
1501 Location loc = DppOp.getLoc();
1502 Value src = adaptor.getSrc();
1503 Value old = adaptor.getOld();
// llvmType is filled in by type-dispatch branches that extraction partially
// dropped (only the trailing `else if` lines remain).
1506 Type llvmType = nullptr;
1509 } else if (isa(srcType)) {
1513 } else if (isa(srcType)) {
1517 }
// Integer type with the same bit width as the source element, used as the
// bitcast staging type for float operands.
1518 auto llvmSrcIntType = typeConverter->convertType(
1520
1521
// Widen an operand narrower than 32 bits: bitcast floats to the same-width
// integer, insert into an undef vector of 32/width elements, then bitcast
// the whole vector to the 32-bit llvmType.
1522 auto convertOperand = [&](Value operand, Type operandType) {
1523 if (operandType.getIntOrFloatBitWidth() <= 16) {
1524 if (llvm::isa(operandType)) {
1525 operand =
1526 rewriter.createLLVM::BitcastOp(loc, llvmSrcIntType, operand);
1527 }
1529 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
1530 Value undefVec = rewriter.createLLVM::UndefOp(loc, llvmVecType);
1531 operand = rewriter.createLLVM::InsertElementOp(
1533 operand = rewriter.createLLVM::BitcastOp(loc, llvmType, operand);
1534 }
1535 return operand;
1536 };
1537
1538 src = convertOperand(src, srcType);
1539 old = convertOperand(old, oldType);
1540
1541
// Base values of the hardware dpp_ctrl encoding; per-row shift/rotate
// amounts are added onto the *0/*1 bases below.
1542 enum DppCtrl : unsigned {
1543 ROW_SHL0 = 0x100,
1544 ROW_SHR0 = 0x110,
1545 ROW_ROR0 = 0x120,
1546 WAVE_SHL1 = 0x130,
1547 WAVE_ROL1 = 0x134,
1548 WAVE_SHR1 = 0x138,
1549 WAVE_ROR1 = 0x13C,
1550 ROW_MIRROR = 0x140,
1551 ROW_HALF_MIRROR = 0x141,
1552 BCAST15 = 0x142,
1553 BCAST31 = 0x143,
1554 };
1555
1556 auto kind = DppOp.getKind();
1557 auto permArgument = DppOp.getPermArgument();
1558 uint32_t DppCtrl = 0;
1559
// Translate the dialect-level permutation kind (plus its optional argument)
// into the dpp_ctrl immediate.
1560 switch (kind) {
1561
// quad_perm: four 2-bit lane selectors packed into the low byte.
1562 case DPPPerm::quad_perm:
1563 if (auto quadPermAttr = cast(*permArgument)) {
1564 int32_t i = 0;
1565 for (auto elem : quadPermAttr.getAsRange()) {
1566 uint32_t num = elem.getInt();
1567 DppCtrl |= num << (i * 2);
1568 i++;
1569 }
1570 }
1571 break;
// row_shl/shr/ror: shift amount added to the corresponding base code.
1572 case DPPPerm::row_shl:
1573 if (auto intAttr = cast(*permArgument)) {
1574 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
1575 }
1576 break;
1577 case DPPPerm::row_shr:
1578 if (auto intAttr = cast(*permArgument)) {
1579 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
1580 }
1581 break;
1582 case DPPPerm::row_ror:
1583 if (auto intAttr = cast(*permArgument)) {
1584 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
1585 }
1586 break;
// Wave-wide and mirror/broadcast modes map to fixed encodings.
1587 case DPPPerm::wave_shl:
1588 DppCtrl = DppCtrl::WAVE_SHL1;
1589 break;
1590 case DPPPerm::wave_shr:
1591 DppCtrl = DppCtrl::WAVE_SHR1;
1592 break;
1593 case DPPPerm::wave_rol:
1594 DppCtrl = DppCtrl::WAVE_ROL1;
1595 break;
1596 case DPPPerm::wave_ror:
1597 DppCtrl = DppCtrl::WAVE_ROR1;
1598 break;
1599 case DPPPerm::row_mirror:
1600 DppCtrl = DppCtrl::ROW_MIRROR;
1601 break;
1602 case DPPPerm::row_half_mirror:
1603 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
1604 break;
1605 case DPPPerm::row_bcast_15:
1606 DppCtrl = DppCtrl::BCAST15;
1607 break;
1608 case DPPPerm::row_bcast_31:
1609 DppCtrl = DppCtrl::BCAST31;
1610 break;
1611 }
1612
1613
1614
// Remaining DPP modifiers are carried as discrete attributes on the op.
1615 auto rowMask = DppOp->getAttrOfType("row_mask").getInt();
1616 auto bankMask = DppOp->getAttrOfType("bank_mask").getInt();
1617 bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
1618
1619
// Emit the actual DPP move with the assembled control word.
1620 auto dppMovOp = rewriter.createROCDL::DPPUpdateOp(
1621 loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
1622
// Narrow the 32-bit result back to the original width and, for float
// sources, bitcast back to the source type (the guarding condition at
// line 1624 was dropped by extraction).
1623 Value result = dppMovOp.getRes();
1625 result = rewriter.createLLVM::TruncOp(loc, llvmSrcIntType, result);
1626 if (!llvm::isa(srcType)) {
1627 result = rewriter.createLLVM::BitcastOp(loc, srcType, result);
1628 }
1629 }
1630
1631
1632
1634 return success();
1635 }
1636 };
1637
// Lowers amdgpu.swizzle_bitmode to ROCDL ds_swizzle in bit-mode: the three
// 5-bit masks are packed into the 15-bit offset immediate, and values wider
// than 32 bits are decomposed into 32-bit pieces that are swizzled
// individually.
// NOTE(review): extraction dropped several lines of this struct (base-class
// line, decompose/compose calls around 1648 and 1657-1666); verify against
// upstream MLIR.
1638 struct AMDGPUSwizzleBitModeLowering
1641
1642 LogicalResult
1643 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
1647 Value src = adaptor.getSrc();
1650 unsigned andMask = op.getAndMask();
1651 unsigned orMask = op.getOrMask();
1652 unsigned xorMask = op.getXorMask();
1653
1654
1655
// ds_swizzle bit-mode offset layout: and[4:0] | or[9:5] | xor[14:10].
1656 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
// Swizzle each 32-bit fragment of the (possibly decomposed) source.
1659 for (Value v : decomposed) {
1661 rewriter.createROCDL::DsSwizzleOp(loc, v.getType(), v, maskValue);
1662 swizzled.emplace_back(res);
1663 }
1664
1667 return success();
1668 }
1669 };
1670
// Pass driver for the AMDGPU->ROCDL conversion: parses the chipset option,
// marks the AMDGPU dialect illegal and LLVM/ROCDL legal, then applies the
// conversion patterns via partial conversion.
// NOTE(review): extraction dropped several lines (converter/pattern setup
// around 1683-1686, the applyPartialConversion call at 1690-1691); verify
// against upstream MLIR.
1671 struct ConvertAMDGPUToROCDLPass
1672 : public impl::ConvertAMDGPUToROCDLPassBase {
1673 using Base::Base;
1674
1675 void runOnOperation() override {
// Fail the pass early on an unrecognized chipset string (e.g. not of the
// form gfxNNN).
1677 FailureOr maybeChipset = Chipset::parse(chipset);
1678 if (failed(maybeChipset)) {
1680 return signalPassFailure();
1681 }
1682
// Legality: every amdgpu op must be rewritten; LLVM and ROCDL ops are the
// accepted output.
1687 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
1688 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
1689 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
1692 signalPassFailure();
1693 }
1694 };
1695 }
1696
// Body fragment of populateAMDGPUMemorySpaceAttributeConversions (the
// signature line was dropped by extraction): registers a type-attribute
// conversion that remaps amdgpu address-space attributes on memref types to
// the corresponding LLVM/ROCDL address spaces. The per-case return values
// were also dropped; verify against upstream MLIR.
1700 [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
1704 switch (as.getValue()) {
1705 case amdgpu::AddressSpace::FatRawBuffer:
1707 case amdgpu::AddressSpace::BufferRsrc:
1709 case amdgpu::AddressSpace::FatStructuredBuffer:
1711 }
1713 });
1714 }
1715
// Body fragment of populateAMDGPUToROCDLConversionPatterns (the signature
// and the preceding setup lines were dropped by extraction): registers every
// AMDGPU->ROCDL lowering pattern, parameterized by the target chipset.
1721 .add<FatRawBufferCastLowering,
// Raw-buffer memory ops map 1:1 onto the ROCDL raw-pointer buffer
// intrinsics (load/store plus the fadd/fmax/smax/umin/cmpswap atomics).
1722 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
1723 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
1724 RawBufferOpLowering<RawBufferAtomicFaddOp,
1725 ROCDL::RawPtrBufferAtomicFaddOp>,
1726 RawBufferOpLowering<RawBufferAtomicFmaxOp,
1727 ROCDL::RawPtrBufferAtomicFmaxOp>,
1728 RawBufferOpLowering<RawBufferAtomicSmaxOp,
1729 ROCDL::RawPtrBufferAtomicSmaxOp>,
1730 RawBufferOpLowering<RawBufferAtomicUminOp,
1731 ROCDL::RawPtrBufferAtomicUminOp>,
1732 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
1733 ROCDL::RawPtrBufferAtomicCmpSwap>,
// Chipset-dependent compute/conversion lowerings defined earlier in this
// file.
1734 AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
1735 MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
1736 ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
1737 PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
1738 PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
1739 chipset);
// Swizzle lowering needs no chipset parameter (the template argument here
// was dropped by extraction; upstream adds AMDGPUSwizzleBitModeLowering).
1740 patterns.add(converter);
1741 }
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type)
Return true if type is the E4M3FN variant of an 8-bit float that is supported by the _fp8 instruction...
constexpr Chipset kGfx942
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Return the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, bool value)
constexpr Chipset kGfx908
constexpr Chipset kGfx90a
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVector< Value, 4 > &operands)
Push the output operand.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type)
Return true if type is the E5M2 variant of an 8-bit float that is supported by the _bf8 instructions ...
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride=nullptr, unsigned addressSpace=8)
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVector< Value, 4 > &operands)
Push an input operand.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsic...
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, ValueRange indices, ArrayRef< int64_t > strides)
Returns the linear index used to access an element in the memref.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i32.
static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts the scaled MFMA operands, scalesA and scalesB, from MLIR AMDGPU dialect convention to ROCDL ...
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef< int64_t > strides, uint32_t elementByteWidth)
Compute the contents of the num_records field for a given memref descriptor - that is,...
static std::optional< uint32_t > mfmaTypeSelectCode(Type mlirElemType)
static std::optional< std::tuple< StringRef, uint32_t, uint32_t > > mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset)
If there is a scaled MFMA instruction for the input element types aType and bType,...
constexpr Chipset kGfx950
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1203::ArityGroupAndKind::Kind kind
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class provides a shared interface for ranked and unranked memref types.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getI16IntegerAttr(int16_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult abort()
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)
Composes a set of src values into a single value of type dstType through series of bitcasts and vecto...
SmallVector< Value > decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType)
Decomposes a src value into a set of values of type dstType through series of bitcasts and vector ops...
bool hasOcpFp8(const Chipset &chipset)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Include the generated interface declarations.
void populateAMDGPUMemorySpaceAttributeConversions(TypeConverter &typeConverter)
Remap AMDGPU memory spaces to LLVM address spaces by mapping amdgpu::AddressSpace::fat_raw_buffer to ...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces,...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.