MLIR: lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

10

21

22 #include "../LLVMCommon/MemRefDescriptor.h"

23

24 #include "llvm/ADT/STLExtras.h"

25 #include "llvm/ADT/TypeSwitch.h"

26 #include "llvm/Support/Casting.h"

27 #include "llvm/Support/ErrorHandling.h"

28 #include

29

30 namespace mlir {

31 #define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS

32 #include "mlir/Conversion/Passes.h.inc"

33 }

34

35 using namespace mlir;

37

38

43

44

// Body of a helper that normalises the integer value `val` to the builder's
// i32 type:
//   - already i32          -> returned unchanged;
//   - wider than 32 bits   -> truncated with LLVM::TruncOp;
//   - otherwise (narrower) -> zero-extended with LLVM::ZExtOp.
// NOTE(review): the signature (listing lines 45-46) and the template arguments
// of `cast` / `rewriter.create` were dropped by the listing extraction; restore
// them from the real source file before compiling.
47 IntegerType i32 = rewriter.getI32Type();

48

49 auto valTy = cast(val.getType());

// Fast path: nothing to convert.
50 if (i32 == valTy)

51 return val;

52 return valTy.getWidth() > 32

53 ? Value(rewriter.createLLVM::TruncOp(loc, i32, val))

54 : Value(rewriter.createLLVM::ZExtOp(loc, i32, val));

55 }

56

58 Location loc, int32_t value) {

60 return rewriter.createLLVM::ConstantOp(loc, i32, value);

61 }

62

64 bool value) {

66 return rewriter.createLLVM::ConstantOp(loc, llvmI1, value);

67 }

68

69

// Body of a helper that folds per-dimension indices into one linear index:
// each index is multiplied by its stride (skipped when the stride is the
// constant 1) and the products are summed with LLVM::AddOp.
// NOTE(review): the signature, the declaration of `index`, the first operand of
// the dynamic-stride branch (listing line 79) and the final return (line 87)
// are missing from this extract — confirm against the real source.
73 IntegerType i32 = rewriter.getI32Type();

75 for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {

// A unit stride needs no multiplication.
76 if (stride != 1) {

77 Value strideValue =

// Dynamic strides are read from the memref descriptor at runtime; static
// strides become an i32 constant.
78 ShapedType::isDynamic(stride)

80 memRefDescriptor.stride(rewriter, loc, i))

81 : rewriter.createLLVM::ConstantOp(loc, i32, stride);

82 increment = rewriter.createLLVM::MulOp(loc, increment, strideValue);

83 }

// The first term seeds the accumulator; later terms are added to it.
84 index =

85 index ? rewriter.createLLVM::AddOp(loc, index, increment) : increment;

86 }

88 }

89

90

91

92

// Computes a byte count for a memref as max over dimensions of
// (size * stride), multiplied by `elementByteWidth`.
//   - Fully static shape and strides: the value is folded at compile time and
//     returned through createI32Constant.
//   - Otherwise: the per-dimension products are built from the memref
//     descriptor's runtime sizes/strides, combined with LLVM::UMaxOp, and
//     multiplied by the byte-width constant.
// NOTE(review): the first signature line, the `shape`/`maxIndex`/`maxIndexI32`/
// `byteWidthConst` definitions and the head of the "too large" check (listing
// lines 93, 95-96, 101, 105, 109, 118-119) are missing from this extract.
94 MemRefType memrefType,

97 uint32_t elementByteWidth) {

98 if (memrefType.hasStaticShape() &&

99 !llvm::any_of(strides, ShapedType::isDynamic)) {

// A rank-0 memref starts at 1 element; otherwise the loop below finds the max.
100 int64_t size = memrefType.getRank() == 0 ? 1 : 0;

102 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)

103 size = std::max(shape[i] * strides[i], size);

104 size = size * elementByteWidth;

106 "the memref buffer is too large");

// The static result is narrowed to int32_t before being materialised.
107 return createI32Constant(rewriter, loc, static_cast<int32_t>(size));

108 }

// Dynamic path: emit the same max-of-products computation as IR.
110 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {

111 Value size = memrefDescriptor.size(rewriter, loc, i);

112 Value stride = memrefDescriptor.stride(rewriter, loc, i);

113 Value maxThisDim = rewriter.createLLVM::MulOp(loc, size, stride);

114 maxIndex = maxIndex

115 ? rewriter.createLLVM::UMaxOp(loc, maxIndex, maxThisDim)

116 : maxThisDim;

117 }

120 return rewriter.createLLVM::MulOp(loc, maxIndexI32, byteWidthConst);

121 }

122

// Builds a buffer resource value from `basePointer`, a 16-bit stride word,
// `numRecords` and a 32-bit flags constant, and returns it.
// Visible parameters: `boundsCheck`, `chipset`, an optional
// `cacheSwizzleStride` (defaults to nullptr) and `addressSpace` (defaults to 8).
// NOTE(review): the start of the signature, the `i16`/`stride`/`flagsConst`/
// `rsrcType`/`resource` definitions and the original explanatory comments
// (listing lines 123-124, 131-132, 137, 142, 144-158, 165, 167-168) are missing
// from this extract; the meaning of the flag constants below must be taken from
// the real source / hardware documentation.
125 bool boundsCheck, amdgpu::Chipset chipset,

126 Value cacheSwizzleStride = nullptr,

127 unsigned addressSpace = 8) {

128

129

130

// Only when the chipset has major version 9, is at least kGfx942, and a swizzle
// stride was supplied: zero-extend the stride to i16 and OR in a swizzle bit.
// Every other case uses a constant stride word (its value is not visible here).
133 if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {

134 Value cacheStrideZext =

135 rewriter.createLLVM::ZExtOp(loc, i16, cacheSwizzleStride);

136 Value swizzleBit = rewriter.createLLVM::ConstantOp(

138 stride = rewriter.createLLVM::OrOp(loc, cacheStrideZext, swizzleBit,

139 true);

140 } else {

141 stride = rewriter.createLLVM::ConstantOp(loc, i16,

143 }

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

// Base flags for every chipset.
159 uint32_t flags = (7 << 12) | (4 << 15);

// Major version 10 and newer additionally set bit 24 and place an
// out-of-bounds selector at bit 28: 3 when `boundsCheck` is set, otherwise 2.
160 if (chipset.majorVersion >= 10) {

161 flags |= (1 << 24);

162 uint32_t oob = boundsCheck ? 3 : 2;

163 flags |= (oob << 28);

164 }

166 Type rsrcType =

169 loc, rsrcType, basePointer, stride, numRecords, flagsConst);

170 return resource;

171 }

172

173 namespace {

174 struct FatRawBufferCastLowering

178 chipset(chipset) {}

179

181

182 LogicalResult

183 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,

186 Value memRef = adaptor.getSource();

187 Value unconvertedMemref = op.getSource();

188 MemRefType memrefType = cast(unconvertedMemref.getType());

190

192 int64_t elementByteWidth =

194

195 int64_t unusedOffset = 0;

197 if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))

198 return op.emitOpError("Can't lower non-stride-offset memrefs");

199

200 Value numRecords = adaptor.getValidBytes();

201 if (!numRecords)

202 numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,

203 strideVals, elementByteWidth);

204

205 Value basePointer =

206 adaptor.getResetOffset()

207 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),

208 memrefType)

209 : descriptor.alignedPtr(rewriter, loc);

210

211 Value offset = adaptor.getResetOffset()

212 ? rewriter.createLLVM::ConstantOp(

213 loc, getIndexType(), rewriter.getIndexAttr(0))

214 : descriptor.offset(rewriter, loc);

215

216 bool hasSizes = memrefType.getRank() > 0;

217

218

219 Value sizes = hasSizes ? rewriter.createLLVM::ExtractValueOp(

222 Value strides = hasSizes

223 ? rewriter.createLLVM::ExtractValueOp(

226

228 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),

229 chipset, adaptor.getCacheSwizzleStride(), 7);

230

232 rewriter, loc,

233 getTypeConverter()->convertType(op.getResult().getType()));

234 result = rewriter.createLLVM::InsertValueOp(

236 result = rewriter.createLLVM::InsertValueOp(

238 result = rewriter.createLLVM::InsertValueOp(loc, result, offset,

240 if (hasSizes) {

241 result = rewriter.createLLVM::InsertValueOp(loc, result, sizes,

243 result = rewriter.createLLVM::InsertValueOp(

245 }

247 return success();

248 }

249 };

250

251

252 template <typename GpuOp, typename Intrinsic>

256

258 static constexpr uint32_t maxVectorOpWidth = 128;

259

260 LogicalResult

261 matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,

263 Location loc = gpuOp.getLoc();

264 Value memref = adaptor.getMemref();

265 Value unconvertedMemref = gpuOp.getMemref();

266 MemRefType memrefType = cast(unconvertedMemref.getType());

267

269 return gpuOp.emitOpError("raw buffer ops require GCN or higher");

270

271 Value storeData = adaptor.getODSOperands(0)[0];

272 if (storeData == memref)

273 storeData = Value();

274 Type wantedDataType;

275 if (storeData)

276 wantedDataType = storeData.getType();

277 else

278 wantedDataType = gpuOp.getODSResults(0)[0].getType();

279

281

282 if (storeData) {

283 Value maybeCmpData = adaptor.getODSOperands(1)[0];

284 if (maybeCmpData != memref)

285 atomicCmpData = maybeCmpData;

286 }

287

288 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

289

291

292

294 int64_t elementByteWidth =

297

298

299

300

301

302

303 Type llvmBufferValType = llvmWantedDataType;

304 if (atomicCmpData) {

305 if (auto floatType = dyn_cast(wantedDataType))

306 llvmBufferValType = this->getTypeConverter()->convertType(

308 }

309 if (auto dataVector = dyn_cast(wantedDataType)) {

310 uint32_t vecLen = dataVector.getNumElements();

311 uint32_t elemBits =

313 uint32_t totalBits = elemBits * vecLen;

314 bool usePackedFp16 =

315 isa_and_present(*gpuOp) && vecLen == 2;

316 if (totalBits > maxVectorOpWidth)

317 return gpuOp.emitOpError(

318 "Total width of loads or stores must be no more than " +

319 Twine(maxVectorOpWidth) + " bits, but we call for " +

320 Twine(totalBits) +

321 " bits. This should've been caught in validation");

322 if (!usePackedFp16 && elemBits < 32) {

323 if (totalBits > 32) {

324 if (totalBits % 32 != 0)

325 return gpuOp.emitOpError("Load or store of more than 32-bits that "

326 "doesn't fit into words. Can't happen\n");

327 llvmBufferValType = this->typeConverter->convertType(

329 } else {

330 llvmBufferValType = this->typeConverter->convertType(

332 }

333 }

334 }

335 if (auto vecType = dyn_cast(llvmBufferValType)) {

336

337

338 if (vecType.getNumElements() == 1)

339 llvmBufferValType = vecType.getElementType();

340 }

341

343 if (storeData) {

344 if (llvmBufferValType != llvmWantedDataType) {

345 Value castForStore =

346 rewriter.createLLVM::BitcastOp(loc, llvmBufferValType, storeData);

347 args.push_back(castForStore);

348 } else {

349 args.push_back(storeData);

350 }

351 }

352

353 if (atomicCmpData) {

354 if (llvmBufferValType != llvmWantedDataType) {

355 Value castForCmp = rewriter.createLLVM::BitcastOp(

356 loc, llvmBufferValType, atomicCmpData);

357 args.push_back(castForCmp);

358 } else {

359 args.push_back(atomicCmpData);

360 }

361 }

362

363

364 int64_t offset = 0;

366 if (failed(memrefType.getStridesAndOffset(strides, offset)))

367 return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

368

370

371 Value ptr = memrefDescriptor.bufferPtr(

372 rewriter, loc, *this->getTypeConverter(), memrefType);

374 rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);

376 adaptor.getBoundsCheck(), chipset);

377 args.push_back(resource);

378

379

381 adaptor.getIndices(), strides);

382 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();

383 indexOffset && *indexOffset > 0) {

385 voffset =

386 voffset ? rewriter.createLLVM::AddOp(loc, voffset, extraOffsetConst)

387 : extraOffsetConst;

388 }

389 voffset = rewriter.createLLVM::MulOp(loc, voffset, byteWidthConst);

390 args.push_back(voffset);

391

392

393 Value sgprOffset = adaptor.getSgprOffset();

394 if (!sgprOffset)

396 sgprOffset = rewriter.createLLVM::MulOp(loc, sgprOffset, byteWidthConst);

397 args.push_back(sgprOffset);

398

399

400

401

403

405 llvmBufferValType);

406 Operation *lowered = rewriter.create(loc, resultTypes, args,

410 if (llvmBufferValType != llvmWantedDataType) {

411 replacement = rewriter.createLLVM::BitcastOp(loc, llvmWantedDataType,

412 replacement);

413 }

414 rewriter.replaceOp(gpuOp, replacement);

415 } else {

417 }

418 return success();

419 }

420 };

421

425

427

// Lowers LDSBarrierOp. Three strategies are visible:
//   1. chipset older than kGfx90a, or major version 11: replace the op with
//      inline assembly ("s_waitcnt lgkmcnt(0)" followed by "s_barrier").
//   2. a path that emits ROCDL::SWaitcntOp with a per-generation bit mask;
//   3. a path that emits ROCDL::WaitDscntOp(0) and ROCDL::BarrierSignalOp(-1).
// NOTE(review): the conditions selecting the mask and selecting between paths
// 2 and 3, plus several emitted ops (listing lines 430, 434, 439, 441, 443,
// 448, 456, 458, 460, 465, 467, 469, 471, 474), are missing from this extract.
428 LogicalResult

429 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,

431 bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;

432

433 if (requiresInlineAsm) {

435 LLVM::AsmDialect::AD_ATT);

// The leading ";;;WARNING" text is part of the emitted assembly string.
436 const char *asmStr =

437 ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";

438 const char *constraints = "";

440 op,

442 asmStr, constraints, true,

444 asmDialectAttr,

445 ArrayAttr());

446 return success();

447 }

// Candidate s_waitcnt immediates: all bits set except one cleared field whose
// position/width differs per hardware generation.
449 constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);

450 constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);

451

452

453 constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);

454

455 int32_t ldsOnlyBits;

457 ldsOnlyBits = ldsOnlyBitsGfx11;

459 ldsOnlyBits = ldsOnlyBitsGfx10;

461 ldsOnlyBits = ldsOnlyBitsGfx6789;

// No known mask for this chipset: report an error instead of guessing.
462 else

463 return op.emitOpError(

464 "don't know how to lower this for chipset major version")

466

468 rewriter.createROCDL::SWaitcntOp(loc, ldsOnlyBits);

470 } else {

472 rewriter.createROCDL::WaitDscntOp(loc, 0);

473 rewriter.createROCDL::BarrierSignalOp(loc, -1);

475 }

476

477 return success();

478 }

479 };

480

484

486

487 LogicalResult

488 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,

491 (uint32_t)op.getOpts());

492 return success();

493 }

494 };

495

496 }

497

498

499

500

501

502

503

504

505

506

507

508

509

510

// Body of a helper that bitcasts a vector operand `input` into the form the
// intrinsic expects; non-vector inputs and vectors matching none of the cases
// are returned unchanged.
//   - bf16 elements              -> same-shaped vector of i16;
//   - i8 elements, <= 8 of them  -> one integer of (numElements * 8) bits;
//   - a third case for element types of at most 8 bits (the type tested by
//     `isa` was stripped) computes a width from numElements * elementBitWidth
//     and 32, then bitcasts.
// NOTE(review): the signature, `inputType`, and listing lines 524 and 528 are
// missing from this extract.
514 if (auto vectorType = dyn_cast(inputType)) {

515 if (vectorType.getElementType().isBF16())

516 return rewriter.createLLVM::BitcastOp(

517 loc, vectorType.clone(rewriter.getI16Type()), input);

518 if (vectorType.getElementType().isInteger(8) &&

519 vectorType.getNumElements() <= 8)

520 return rewriter.createLLVM::BitcastOp(

521 loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);

522 if (isa(vectorType.getElementType()) &&

523 vectorType.getElementTypeBitWidth() <= 8) {

525 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),

526 32);

527 return rewriter.createLLVM::BitcastOp(

529 }

530 }

// Fallthrough: operand needs no conversion.
531 return input;

532 }

533

534

535

536

537

538

539

540

541

542

543

// Tail of a helper that converts `input` to `outputType`: when `inputType`
// matches the type tested by dyn_cast (bound to `intType`, presumably an
// integer type — the template argument was stripped) the value is
// zero-extended; anything else is bitcast.
// NOTE(review): the signature and the definitions of `inputType`/`outputType`
// (listing lines 544-547) are missing from this extract.
548 if (auto intType = dyn_cast(inputType))

549 return rewriter.createLLVM::ZExtOp(loc, outputType, input);

550 return rewriter.createLLVM::BitcastOp(loc, outputType, input);

551 }

552

553

554

555

556

557

558

559

560

564 bool isUnsigned, Value llvmInput,

568 auto vectorType = dyn_cast(inputType);

569 if (!vectorType) {

570 operands.push_back(llvmInput);

571 return;

572 }

573 Type elemType = vectorType.getElementType();

574

575 if (elemType.isBF16())

576 llvmInput = rewriter.createLLVM::BitcastOp(

577 loc, vectorType.clone(rewriter.getI16Type()), llvmInput);

579 operands.push_back(llvmInput);

580 return;

581 }

582

583

584

585

586 auto mlirInputType = cast(mlirInput.getType());

587 bool isInputInteger = mlirInputType.getElementType().isInteger();

588 if (isInputInteger) {

589

590 bool localIsUnsigned = isUnsigned;

592 localIsUnsigned = true;

594 localIsUnsigned = false;

595 }

597 operands.push_back(sign);

598 }

599

600 int64_t numBits =

603 Type intrinsicInType = numBits <= 32

606 auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);

608 loc, llvmIntrinsicInType, llvmInput);

609

610

611

612 if (numBits < 32)

613 castInput = rewriter.createLLVM::ZExtOp(loc, i32, castInput);

614 operands.push_back(castInput);

615 }

616

617

618

619

620

621

622

623

// Appends the WMMA output/accumulator operand to `operands`, bitcasting bf16
// vectors to i16 vectors first, then appends one extra operand depending on the
// element type: an i1 constant built from `subwordOffset` in the first branch;
// the i32-element branch body is not visible.
// NOTE(review): `vectorType` comes from dyn_cast and is dereferenced on the
// next line without a null check — confirm callers always pass a vector type,
// or switch to cast. The signature start, `inputType`, and listing lines 636
// and 639 are missing from this extract.
627 Value output, int32_t subwordOffset,

630 auto vectorType = dyn_cast(inputType);

631 Type elemType = vectorType.getElementType();

632 if (elemType.isBF16())

633 output = rewriter.createLLVM::BitcastOp(

634 loc, vectorType.clone(rewriter.getI16Type()), output);

635 operands.push_back(output);

637 operands.push_back(createI1Constant(rewriter, loc, subwordOffset));

638 } else if (elemType.isInteger(32)) {

640 }

641 }

642

643

644

646 return (chipset == kGfx942 && isa(type)) ||

647 (hasOcpFp8(chipset) && isa(type));

648 }

649

650

651

653 return (chipset == kGfx942 && isa(type)) ||

654 (hasOcpFp8(chipset) && isa(type));

655 }

656

657

658

659

662 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),

663 b = mfma.getBlocks();

666

667 if (sourceElem.isF32() && destElem.isF32()) {

668 if (mfma.getReducePrecision() && chipset >= kGfx942) {

669 if (m == 32 && n == 32 && k == 4 && b == 1)

670 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();

671 if (m == 16 && n == 16 && k == 8 && b == 1)

672 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();

673 }

674 if (m == 32 && n == 32 && k == 1 && b == 2)

675 return ROCDL::mfma_f32_32x32x1f32::getOperationName();

676 if (m == 16 && n == 16 && k == 1 && b == 4)

677 return ROCDL::mfma_f32_16x16x1f32::getOperationName();

678 if (m == 4 && n == 4 && k == 1 && b == 16)

679 return ROCDL::mfma_f32_4x4x1f32::getOperationName();

680 if (m == 32 && n == 32 && k == 2 && b == 1)

681 return ROCDL::mfma_f32_32x32x2f32::getOperationName();

682 if (m == 16 && n == 16 && k == 4 && b == 1)

683 return ROCDL::mfma_f32_16x16x4f32::getOperationName();

684 }

685

686 if (sourceElem.isF16() && destElem.isF32()) {

687 if (chipset >= kGfx950) {

688 if (m == 32 && n == 32 && k == 16 && b == 1)

689 return ROCDL::mfma_f32_32x32x16_f16::getOperationName();

690 if (m == 16 && n == 16 && k == 32 && b == 1)

691 return ROCDL::mfma_f32_16x16x32_f16::getOperationName();

692 }

693 if (m == 32 && n == 32 && k == 4 && b == 2)

694 return ROCDL::mfma_f32_32x32x4f16::getOperationName();

695 if (m == 16 && n == 16 && k == 4 && b == 4)

696 return ROCDL::mfma_f32_16x16x4f16::getOperationName();

697 if (m == 4 && n == 4 && k == 4 && b == 16)

698 return ROCDL::mfma_f32_4x4x4f16::getOperationName();

699 if (m == 32 && n == 32 && k == 8 && b == 1)

700 return ROCDL::mfma_f32_32x32x8f16::getOperationName();

701 if (m == 16 && n == 16 && k == 16 && b == 1)

702 return ROCDL::mfma_f32_16x16x16f16::getOperationName();

703 }

704

705 if (sourceElem.isBF16() && destElem.isF32()) {

706 if (chipset >= kGfx950) {

707 if (m == 32 && n == 32 && k == 16 && b == 1)

708 return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();

709 if (m == 16 && n == 16 && k == 32 && b == 1)

710 return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();

711 }

712 if (chipset >= kGfx90a) {

713 if (m == 32 && n == 32 && k == 4 && b == 2)

714 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();

715 if (m == 16 && n == 16 && k == 4 && b == 4)

716 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();

717 if (m == 4 && n == 4 && k == 4 && b == 16)

718 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();

719 if (m == 32 && n == 32 && k == 8 && b == 1)

720 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();

721 if (m == 16 && n == 16 && k == 16 && b == 1)

722 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();

723 }

724 if (m == 32 && n == 32 && k == 2 && b == 2)

725 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();

726 if (m == 16 && n == 16 && k == 2 && b == 4)

727 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();

728 if (m == 4 && n == 4 && k == 2 && b == 16)

729 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();

730 if (m == 32 && n == 32 && k == 4 && b == 1)

731 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();

732 if (m == 16 && n == 16 && k == 8 && b == 1)

733 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();

734 }

735

737 if (chipset >= kGfx950) {

738 if (m == 32 && n == 32 && k == 32 && b == 1)

739 return ROCDL::mfma_i32_32x32x32_i8::getOperationName();

740 if (m == 16 && n == 16 && k == 64 && b == 1)

741 return ROCDL::mfma_i32_16x16x64_i8::getOperationName();

742 }

743 if (m == 32 && n == 32 && k == 4 && b == 2)

744 return ROCDL::mfma_i32_32x32x4i8::getOperationName();

745 if (m == 16 && n == 16 && k == 4 && b == 4)

746 return ROCDL::mfma_i32_16x16x4i8::getOperationName();

747 if (m == 4 && n == 4 && k == 4 && b == 16)

748 return ROCDL::mfma_i32_4x4x4i8::getOperationName();

749 if (m == 32 && n == 32 && k == 8 && b == 1)

750 return ROCDL::mfma_i32_32x32x8i8::getOperationName();

751 if (m == 16 && n == 16 && k == 16 && b == 1)

752 return ROCDL::mfma_i32_16x16x16i8::getOperationName();

753 if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)

754 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();

755 if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)

756 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();

757 }

758

759 if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {

760 if (m == 16 && n == 16 && k == 4 && b == 1)

761 return ROCDL::mfma_f64_16x16x4f64::getOperationName();

762 if (m == 4 && n == 4 && k == 4 && b == 4)

763 return ROCDL::mfma_f64_4x4x4f64::getOperationName();

764 }

765

767

768

769 Type sourceBElem =

770 cast(mfma.getSourceB().getType()).getElementType();

771 if (m == 16 && n == 16 && k == 32 && b == 1) {

773 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();

775 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();

776 }

777 if (m == 32 && n == 32 && k == 16 && b == 1) {

779 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();

781 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();

782 }

783 }

784

786 Type sourceBElem =

787 cast(mfma.getSourceB().getType()).getElementType();

788 if (m == 16 && n == 16 && k == 32 && b == 1) {

790 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();

792 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();

793 }

794 if (m == 32 && n == 32 && k == 16 && b == 1) {

796 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();

798 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();

799 }

800 }

801

802 return std::nullopt;

803 }

804

// Type-switch cases mapping a small-float element type to an unsigned code:
//   Float8E4M3FN -> 0, Float8E5M2 -> 1, Float6E2M3FN -> 2,
//   Float6E3M2FN -> 3, Float4E2M1FN -> 4; any other type -> std::nullopt.
// NOTE(review): the function signature and the head of the switch expression
// (listing lines 805-806) are missing from this extract.
807 .Case([](Float8E4M3FNType) { return 0u; })

808 .Case([](Float8E5M2Type) { return 1u; })

809 .Case([](Float6E2M3FNType) { return 2u; })

810 .Case([](Float6E3M2FNType) { return 3u; })

811 .Case([](Float4E2M1FNType) { return 4u; })

812 .Default([](Type) { return std::nullopt; });

813 }

814

815

816

817

818

819

820

821

// Selects a scaled-MFMA intrinsic for the given tile shape.
// Returns (intrinsic operation name, A type code, B type code), or
// std::nullopt when:
//   - an early guard fails (its condition is not visible here),
//   - `destType` is not the type tested by `isa` (template argument stripped),
//   - either operand has no type code,
//   - or (m, n, k, b) is neither 32x32x64x1 nor 16x16x128x1.
// NOTE(review): the function name, its leading parameters, and the statements
// computing `aTypeCode`/`bTypeCode` (listing lines 823, 825-827, 829, 834-835)
// are missing from this extract.
822 static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>

824 uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {

828

830 return std::nullopt;

831 if (!isa(destType))

832 return std::nullopt;

833

// Both operand element types must map to a known type code.
836 if (!aTypeCode || !bTypeCode)

837 return std::nullopt;

838

// Only two tile shapes are supported, each with a single block.
839 if (m == 32 && n == 32 && k == 64 && b == 1)

840 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),

841 *aTypeCode, *bTypeCode};

842 if (m == 16 && n == 16 && k == 128 && b == 1)

843 return std::tuple{

844 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,

845 *bTypeCode};

846

847 return std::nullopt;

848 }

849

850 static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>

853 mfma.getSourceA().getType(), mfma.getSourceB().getType(),

854 mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),

855 mfma.getBlocks(), chipset);

856 }

857

858 static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>

861 smfma.getSourceB().getType(),

862 smfma.getDestC().getType(), smfma.getM(),

863 smfma.getN(), smfma.getK(), 1u, chipset);

864 }

865

866

867

868

// Body of the WMMA intrinsic selector: picks a ROCDL wmma operation name from
// the element types of source A, source B and destination C, returning
// std::nullopt when no combination matches.
// NOTE(review): the three vector types come from dyn_cast and are dereferenced
// immediately without null checks — confirm the op verifier guarantees vector
// operands, or use cast. The signature and the chipset conditions that open the
// two nested blocks (listing lines 869-870, 888, 892) are missing from this
// extract, so which chipsets reach which cases cannot be stated from here.
871 auto sourceVectorType = dyn_cast(wmma.getSourceA().getType());

872 auto sourceBVectorType = dyn_cast(wmma.getSourceB().getType());

873 auto destVectorType = dyn_cast(wmma.getDestC().getType());

874 auto elemSourceType = sourceVectorType.getElementType();

875 auto elemBSourceType = sourceBVectorType.getElementType();

876 auto elemDestType = destVectorType.getElementType();

877

// Cases keyed only on source-A and destination element types.
878 if (elemSourceType.isF16() && elemDestType.isF32())

879 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();

880 if (elemSourceType.isBF16() && elemDestType.isF32())

881 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();

882 if (elemSourceType.isF16() && elemDestType.isF16())

883 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();

884 if (elemSourceType.isBF16() && elemDestType.isBF16())

885 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();

886 if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))

887 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();

889 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))

890 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();

891 }

// 8-bit float cases: both source element types take part in the choice (the
// concrete types tested by `isa` were stripped from this extract).
893 if (isa(elemSourceType) &&

894 isa(elemBSourceType) && elemDestType.isF32())

895 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();

896 if (isa(elemSourceType) &&

897 isa(elemBSourceType) && elemDestType.isF32())

898 return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();

899 if (isa(elemSourceType) &&

900 isa(elemBSourceType) && elemDestType.isF32())

901 return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();

902 if (isa(elemSourceType) &&

903 isa(elemBSourceType) && elemDestType.isF32())

904 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();

// i4 -> i32: the 16x16x32 variant is chosen when (4 destination elements and
// 8 source elements) or (neither); every other combination uses 16x16x16.
905 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {

906 bool isWave64 = destVectorType.getNumElements() == 4;

907

908

909 bool has8Inputs = sourceVectorType.getNumElements() == 8;

910 if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))

911 return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();

912 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();

913 }

914 }

915 return std::nullopt;

916 }

917

918 namespace {

922

924

925 LogicalResult

926 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,

929 Type outType = typeConverter->convertType(op.getDestD().getType());

930 Type intrinsicOutType = outType;

931 if (auto outVecType = dyn_cast(outType))

932 if (outVecType.getElementType().isBF16())

933 intrinsicOutType = outVecType.clone(rewriter.getI16Type());

934

936 return op->emitOpError("MFMA only supported on gfx908+");

937 uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());

938 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {

940 return op.emitOpError("negation unsupported on older than gfx942");

941 getBlgpField |=

942 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);

943 }

944 std::optional maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);

945 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>

947 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())

948 return op.emitOpError("no intrinsic matching MFMA size on given chipset");

949

950 bool isScaled =

951 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();

952 if (isScaled &&

953 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {

954 return op.emitOpError(

955 "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "

956 "be scaled as those fields are used for type information");

957 }

958

959 StringRef intrinsicName =

960 isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;

962 loweredOp.addTypes(intrinsicOutType);

963 loweredOp.addOperands(

966 adaptor.getDestC()});

967 if (isScaled) {

969 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;

970 loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),

972 zero, zero,

973 zero, zero});

974 } else {

975 loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()),

978 };

980 if (outType != intrinsicOutType)

981 lowered = rewriter.createLLVM::BitcastOp(loc, outType, lowered);

983 return success();

984 }

985 };

986

990

992

993 LogicalResult

994 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,

997 Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());

998

1000 return op->emitOpError("scaled MFMA only supported on gfx908+");

1001 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>

1003 if (!maybeScaledIntrinsic.has_value())

1004 return op.emitOpError(

1005 "no intrinsic matching scaled MFMA size on given chipset");

1006

1007 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;

1009 loweredOp.addTypes(intrinsicOutType);

1010 loweredOp.addOperands(

1013 adaptor.getDestC()});

1014 Value scalesIdxA =

1016 Value scalesIdxB =

1018 loweredOp.addOperands(

1021 scalesIdxA,

1022

1024 scalesIdxB,

1025

1028 rewriter.replaceOp(op, lowered);

1029 return success();

1030 }

1031 };

1032

1036

1038

1039 LogicalResult

1040 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,

1043 auto outType =

1044 typeConverter->convertType(op.getDestD().getType());

1045 if (!outType)

1047

1049 return op->emitOpError("WMMA only supported on gfx11 and gfx12");

1050

1051

1052

1053 VectorType rawOutType = outType;

1054 if (outType.getElementType().isBF16())

1055 rawOutType = outType.clone(rewriter.getI16Type());

1056

1057 std::optional maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

1058

1059 if (!maybeIntrinsic.has_value())

1060 return op.emitOpError("no intrinsic matching WMMA on the given chipset");

1061

1062 if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)

1063 return op.emitOpError("subwordOffset not supported on gfx12+");

1064

1066 loweredOp.addTypes(rawOutType);

1067

1070 adaptor.getSourceA(), op.getSourceA(), operands);

1072 adaptor.getSourceB(), op.getSourceB(), operands);

1074 op.getSubwordOffset(), op.getClamp(), operands);

1075

1076 loweredOp.addOperands(operands);

1078

1079 Operation *maybeCastBack = lowered;

1080 if (rawOutType != outType)

1081 maybeCastBack =

1082 rewriter.createLLVM::BitcastOp(loc, outType, lowered->getResult(0));

1084

1085 return success();

1086 }

1087 };

1088

1092

1094

1095 LogicalResult

1096 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,

1099 return op.emitOpError("pre-gfx9 and post-gfx10 not supported");

1100

1102

1103 auto srcMemRefType = cast(op.getSrc().getType());

1104 auto dstMemRefType = cast(op.getDst().getType());

1105

1106

1107

1108

1109 Type transferType = op.getTransferType();

1110 size_t loadWidth = [&]() -> size_t {

1111 if (auto transferVectorType = dyn_cast(transferType)) {

1112 return transferVectorType.getNumElements() *

1113 (transferVectorType.getElementTypeBitWidth() / 8);

1114 }

1116 }();

1117

1118

1119 if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)

1120 return op.emitOpError("chipset unsupported element size");

1121

1124 (adaptor.getSrcIndices()));

1127 (adaptor.getDstIndices()));

1128

1132 rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},

1133 ArrayAttr{});

1134

1135 return success();

1136 }

1137 };

1138

1139 namespace {

1140 struct ExtPackedFp8OpLowering final

1144 chipset(chipset) {}

1146

1147 LogicalResult

1148 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,

1150 };

1151

1152 struct PackedTrunc2xFp8OpLowering final

1157 chipset(chipset) {}

1159

1160 LogicalResult

1161 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,

1163 };

1164

1165 struct PackedStochRoundFp8OpLowering final

1167 PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,

1170 chipset(chipset) {}

1172

1173 LogicalResult

1174 matchAndRewrite(PackedStochRoundFp8Op op,

1175 PackedStochRoundFp8OpAdaptor adaptor,

1177 };

1178

1179 struct ScaledExtPackedOpLowering final

1183 chipset(chipset) {}

1185

1186 LogicalResult

1187 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,

1189 };

1190

1191 struct PackedScaledTruncOpLowering final

1196 chipset(chipset) {}

1198

1199 LogicalResult

1200 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,

1202 };

1203

1204 }

1205

1206 LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(

1207 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,

1212 loc, "Fp8 conversion instructions are not available on target "

1213 "architecture and their emulation is not implemented");

1216 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

1217 Type f32 = getTypeConverter()->convertType(op.getResult().getType());

1218

1219 Value source = adaptor.getSource();

1220 auto sourceVecType = dyn_cast(op.getSource().getType());

1221 auto resultVecType = dyn_cast(op.getResult().getType());

1223

1224 if (!sourceVecType || sourceVecType.getNumElements() < 4) {

1225 Value longVec = rewriter.createLLVM::UndefOp(loc, v4i8);

1226 if (!sourceVecType) {

1227 longVec = rewriter.createLLVM::InsertElementOp(

1229 } else {

1230 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {

1232 Value elem = rewriter.createLLVM::ExtractElementOp(loc, source, idx);

1233 longVec =

1234 rewriter.createLLVM::InsertElementOp(loc, longVec, elem, idx);

1235 }

1236 }

1237 source = longVec;

1238 }

1239 Value i32Source = rewriter.createLLVM::BitcastOp(loc, i32, source);

1240 if (resultVecType) {

1242 rewriter.replaceOpWithNewOpROCDL::CvtPkF32Bf8Op(op, f32, i32Source,

1243 op.getIndex());

1245 rewriter.replaceOpWithNewOpROCDL::CvtPkF32Fp8Op(op, f32, i32Source,

1246 op.getIndex());

1247 }

1248 } else {

1251 op.getIndex());

1254 op.getIndex());

1255 }

1256 }

1257 return success();

1258 }

1259

// Lowers ScaledExtPackedOp:
//   1. picks a packed vector type from the source element type (v4i8 for
//      Float8E5M2/Float8E4M3FN, v8i4 for the second tested type);
//   2. if the source has fewer elements than the packed type, copies its
//      elements into a zero-initialised packed vector;
//   3. bitcasts the packed vector to i32;
//   4. replaces the op with a conversion chosen by (source element type,
//      destination element type in {f32, f16, bf16}); returns failure() when
//      no combination matches.
// NOTE(review): the rest of the signature, the chipset guard that emits the
// "not available on target" diagnostic, the `v4i8`/`v8i4`/`idx` definitions and
// the names of the replacement ops (listing lines 1262-1265, 1280, 1283, 1294,
// 1297, 1308, 1311, ...) are missing from this extract.
1260 LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(

1261 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,

1266 loc, "Scaled fp conversion instructions are not available on target "

1267 "architecture and their emulation is not implemented");

1268 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

1269

1270 Value source = adaptor.getSource();

1271 Value scale = adaptor.getScale();

1272

1273 VectorType sourceVecType = cast(op.getSource().getType());

1274 Type sourceElemType = sourceVecType.getElementType();

1275 VectorType destVecType = cast(op.getResult().getType());

1276 Type destElemType = destVecType.getElementType();

1277

1278 VectorType packedVecType;

1279 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {

1281 packedVecType = cast(getTypeConverter()->convertType(v4i8));

1282 } else if (isa(sourceElemType)) {

1284 packedVecType = cast(getTypeConverter()->convertType(v8i4));

1285 } else {

1286 llvm_unreachable("invalid element type for scaled ext");

1287 }

1288

1289

// Pad short sources up to the packed width; the padding lanes stay zero.
1290 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {

1291 Value longVec = rewriter.createLLVM::ZeroOp(loc, packedVecType);

// NOTE(review): `sourceVecType` was produced by `cast` and has already been
// dereferenced above, so this null branch looks unreachable — confirm and
// consider removing it.
1292 if (!sourceVecType) {

1293 longVec = rewriter.createLLVM::InsertElementOp(

1295 } else {

1296 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {

1298 Value elem = rewriter.createLLVM::ExtractElementOp(loc, source, idx);

1299 longVec =

1300 rewriter.createLLVM::InsertElementOp(loc, longVec, elem, idx);

1301 }

1302 }

1303 source = longVec;

1304 }

1305 Value i32Source = rewriter.createLLVM::BitcastOp(loc, i32, source);

1306

// Dispatch: three source element types (stripped from `isa` here) times three
// destination element types, each forwarding the same operand list.
1307 if (isa(sourceElemType) && destElemType.isF32())

1309 op, destVecType, i32Source, scale, op.getIndex());

1310 else if (isa(sourceElemType) && destElemType.isF16())

1312 op, destVecType, i32Source, scale, op.getIndex());

1313 else if (isa(sourceElemType) && destElemType.isBF16())

1315 op, destVecType, i32Source, scale, op.getIndex());

1316 else if (isa(sourceElemType) && destElemType.isF32())

1318 op, destVecType, i32Source, scale, op.getIndex());

1319 else if (isa(sourceElemType) && destElemType.isF16())

1321 op, destVecType, i32Source, scale, op.getIndex());

1322 else if (isa(sourceElemType) && destElemType.isBF16())

1324 op, destVecType, i32Source, scale, op.getIndex());

1325 else if (isa(sourceElemType) && destElemType.isF32())

1327 op, destVecType, i32Source, scale, op.getIndex());

1328 else if (isa(sourceElemType) && destElemType.isF16())

1330 op, destVecType, i32Source, scale, op.getIndex());

1331 else if (isa(sourceElemType) && destElemType.isBF16())

1333 op, destVecType, i32Source, scale, op.getIndex());

1334 else

1335 return failure();

1336

1337 return success();

1338 }

1339

// Rewrites a PackedScaledTruncOp: selects one of the ROCDL CvtScaleF32Pk* ops
// from the source and result element types, feeding it the "existing" packed
// value (or zero), the source, the scale and the op's index.
// NOTE(review): this listing dropped several original lines (e.g. 1342-1345,
// 1349, 1353, 1368, 1370, 1377-1378, 1383, 1414) and the template arguments of
// cast/isa/create calls; comments below describe only what is still visible.
1340 LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(

1341 PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,

// Diagnostic arguments; presumably belong to a match-failure call whose guard
// condition is on lines missing from this listing -- TODO confirm.
1346 loc, "Scaled fp conversion instructions are not available on target "

1347 "architecture and their emulation is not implemented");

// The argument of this convertType call is on a missing line.
1348 Type v2i16 = getTypeConverter()->convertType(

1350 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

1351

1352 Type resultType = op.getResult().getType();

1354 VectorType sourceVecType = cast(op.getSource().getType());

1355 Type sourceElemType = sourceVecType.getElementType();

1356

// Integer carrier type for the packed result: i32 for one result element type
// (the isa<> argument is lost in this listing), v2i16 otherwise.
// resultElemType is defined on a missing line.
1357 Type intResultType = isa(resultElemType) ? i32 : v2i16;

1358

1359 Value source = adaptor.getSource();

1360 Value scale = adaptor.getScale();

// Start from the caller-provided "existing" value reinterpreted as the carrier
// type, or from an all-zero carrier when none was given.
1361 Value existing = adaptor.getExisting();

1362 if (existing)

1363 existing = rewriter.createLLVM::BitcastOp(loc, intResultType, existing);

1364 else

1365 existing = rewriter.createLLVM::ZeroOp(loc, intResultType);

1366

// A source with fewer than two elements is widened: element c0 is extracted
// and re-inserted into a zero value of type v2 (c0 and v2 are defined on
// missing lines; presumably index 0 and a 2-element vector -- confirm).
1367 if (sourceVecType.getNumElements() < 2) {

1369 Value elem0 = rewriter.createLLVM::ExtractElementOp(loc, source, c0);

1371 source = rewriter.createLLVM::ZeroOp(loc, v2);

1372 source = rewriter.createLLVM::InsertElementOp(loc, source, elem0, c0);

1373 }

1374

// f32 sources are split into two scalars, which the *F32Op variants below take
// as separate operands; other source types pass the vector itself.
1375 Value sourceA, sourceB;

1376 if (sourceElemType.isF32()) {

1379 sourceA = rewriter.createLLVM::ExtractElementOp(loc, source, c0);

1380 sourceB = rewriter.createLLVM::ExtractElementOp(loc, source, c1);

1381 }

1382

// Dispatch on (source element type, result element type). The isa<> arguments
// are lost in this listing; judging by the created op names the three groups
// target Bf8, Fp8 and Fp4 results respectively. `result` is declared on a
// missing line.
1384 if (sourceElemType.isF32() && isa(resultElemType))

1385 result = rewriter.createROCDL::CvtScaleF32PkBf8F32Op(

1386 loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());

1387 else if (sourceElemType.isF16() && isa(resultElemType))

1388 result = rewriter.createROCDL::CvtScaleF32PkBf8F16Op(

1389 loc, intResultType, existing, source, scale, op.getIndex());

1390 else if (sourceElemType.isBF16() && isa(resultElemType))

1391 result = rewriter.createROCDL::CvtScaleF32PkBf8Bf16Op(

1392 loc, intResultType, existing, source, scale, op.getIndex());

1393 else if (sourceElemType.isF32() && isa(resultElemType))

1394 result = rewriter.createROCDL::CvtScaleF32PkFp8F32Op(

1395 loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());

1396 else if (sourceElemType.isF16() && isa(resultElemType))

1397 result = rewriter.createROCDL::CvtScaleF32PkFp8F16Op(

1398 loc, intResultType, existing, source, scale, op.getIndex());

1399 else if (sourceElemType.isBF16() && isa(resultElemType))

1400 result = rewriter.createROCDL::CvtScaleF32PkFp8Bf16Op(

1401 loc, intResultType, existing, source, scale, op.getIndex());

1402 else if (sourceElemType.isF32() && isa(resultElemType))

1403 result = rewriter.createROCDL::CvtScaleF32PkFp4F32Op(

1404 loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());

1405 else if (sourceElemType.isF16() && isa(resultElemType))

1406 result = rewriter.createROCDL::CvtScaleF32PkFp4F16Op(

1407 loc, intResultType, existing, source, scale, op.getIndex());

1408 else if (sourceElemType.isBF16() && isa(resultElemType))

1409 result = rewriter.createROCDL::CvtScaleF32PkFp4Bf16Op(

1410 loc, intResultType, existing, source, scale, op.getIndex());

1411 else

// No branch matched the (source, result) element type pair.
1412 return failure();

1413

// Arguments of the final replacement (its first line is missing): the op is
// replaced using the converted result type and the integer `result` value.
1415 op, getTypeConverter()->convertType(resultType), result);

1416 return success();

1417 }

1418

// Rewrites a PackedTrunc2xFp8Op into a ROCDL CvtPkBf8F32Op or CvtPkFp8F32Op
// that packs sourceA/sourceB into an i32 at the op's word index.
// NOTE(review): this listing dropped several original lines (e.g. 1421-1424,
// 1430, 1442-1443, 1446, 1450); the conditions choosing between the two ROCDL
// ops are among them.
1419 LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(

1420 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,

// Diagnostic arguments; presumably belong to a match-failure call whose guard
// condition is on lines missing from this listing -- TODO confirm.
1425 loc, "Fp8 conversion instructions are not available on target "

1426 "architecture and their emulation is not implemented");

1427 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

1428

1429 Type resultType = op.getResult().getType();

1431

1432 Value sourceA = adaptor.getSourceA();

// sourceB may be absent; an undef of sourceA's type stands in for it.
1433 Value sourceB = adaptor.getSourceB();

1434 if (!sourceB)

1435 sourceB = rewriter.createLLVM::UndefOp(loc, sourceA.getType());

// The previous packed value is reinterpreted as i32, or left undef when the
// op carries none.
1436 Value existing = adaptor.getExisting();

1437 if (existing)

1438 existing = rewriter.createLLVM::BitcastOp(loc, i32, existing);

1439 else

1440 existing = rewriter.createLLVM::UndefOp(loc, i32);

1441

// Two alternative conversions (Bf8 vs Fp8); the selecting conditions and the
// declaration of `result` are on missing lines.
1444 result = rewriter.createROCDL::CvtPkBf8F32Op(loc, i32, sourceA, sourceB,

1445 existing, op.getWordIndex());

1447 result = rewriter.createROCDL::CvtPkFp8F32Op(loc, i32, sourceA, sourceB,

1448 existing, op.getWordIndex());

1449

// Arguments of the final replacement (its first line is missing): the op is
// replaced using the converted result type and the i32 `result` value.
1451 op, getTypeConverter()->convertType(resultType), result);

1452 return success();

1453 }

1454

// Rewrites a PackedStochRoundFp8Op into a ROCDL CvtSrBf8F32Op or CvtSrFp8F32Op
// taking the source, the stochastic parameter, the existing packed value and
// the op's store index.
// NOTE(review): this listing dropped several original lines (e.g. 1457-1460,
// 1466, 1476-1477, 1480, 1484); the conditions choosing between the two ROCDL
// ops are among them.
1455 LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(

1456 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,

// Diagnostic arguments; presumably belong to a match-failure call whose guard
// condition is on lines missing from this listing -- TODO confirm.
1461 loc, "Fp8 conversion instructions are not available on target "

1462 "architecture and their emulation is not implemented");

1463 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

1464

1465 Type resultType = op.getResult().getType();

1467

1468 Value source = adaptor.getSource();

1469 Value stoch = adaptor.getStochiasticParam();

// The previous packed value is reinterpreted as i32, or left undef when the
// op carries none.
1470 Value existing = adaptor.getExisting();

1471 if (existing)

1472 existing = rewriter.createLLVM::BitcastOp(loc, i32, existing);

1473 else

1474 existing = rewriter.createLLVM::UndefOp(loc, i32);

1475

// Two alternative conversions (Bf8 vs Fp8); the selecting conditions and the
// declaration of `result` are on missing lines.
1478 result = rewriter.createROCDL::CvtSrBf8F32Op(

1479 loc, i32, source, stoch, existing, op.getStoreIndex());

1481 result = rewriter.createROCDL::CvtSrFp8F32Op(

1482 loc, i32, source, stoch, existing, op.getStoreIndex());

1483

// Arguments of the final replacement (its first line is missing): the op is
// replaced using the converted result type and the i32 `result` value.
1485 op, getTypeConverter()->convertType(resultType), result);

1486 return success();

1487 }

1488

1489

1490

1495

// Rewrites a DPPOp into a ROCDL::DPPUpdateOp. Narrow operands are widened
// first, the permutation kind/argument is encoded into a single control word,
// and the result is narrowed back afterwards.
// NOTE(review): this listing dropped several original lines (e.g. 1498,
// 1504-1505, 1507-1508, 1510-1512, 1519, 1528, 1532, 1624, 1633) and the
// template arguments of cast/isa/create/getAttrOfType calls; comments below
// describe only what is still visible.
1496 LogicalResult

1497 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,

1499

1500

1501 Location loc = DppOp.getLoc();

1502 Value src = adaptor.getSrc();

1503 Value old = adaptor.getOld();

// llvmType is assigned in an if/else-if chain over srcType whose bodies are on
// missing lines; it is later the result type of the DPP update op.
1506 Type llvmType = nullptr;

1509 } else if (isa(srcType)) {

1513 } else if (isa(srcType)) {

1517 }

1518 auto llvmSrcIntType = typeConverter->convertType(

1520

1521

// Widens an operand whose bit width is at most 16: optionally bitcast it to
// llvmSrcIntType, insert it into an undef vector (vector type built from
// 32 / bitwidth and llvmSrcIntType on a partially missing line), then bitcast
// the vector to llvmType. Wider operands are returned unchanged.
1522 auto convertOperand = [&](Value operand, Type operandType) {

1523 if (operandType.getIntOrFloatBitWidth() <= 16) {

1524 if (llvm::isa(operandType)) {

1525 operand =

1526 rewriter.createLLVM::BitcastOp(loc, llvmSrcIntType, operand);

1527 }

1529 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));

1530 Value undefVec = rewriter.createLLVM::UndefOp(loc, llvmVecType);

1531 operand = rewriter.createLLVM::InsertElementOp(

1533 operand = rewriter.createLLVM::BitcastOp(loc, llvmType, operand);

1534 }

1535 return operand;

1536 };

1537

1538 src = convertOperand(src, srcType);

1539 old = convertOperand(old, oldType);

1540

1541

// Control-word constants used by the switch below: base values for the
// row shift/rotate kinds (the attribute value is added to them) and fixed
// values for the remaining kinds.
1542 enum DppCtrl : unsigned {

1543 ROW_SHL0 = 0x100,

1544 ROW_SHR0 = 0x110,

1545 ROW_ROR0 = 0x120,

1546 WAVE_SHL1 = 0x130,

1547 WAVE_ROL1 = 0x134,

1548 WAVE_SHR1 = 0x138,

1549 WAVE_ROR1 = 0x13C,

1550 ROW_MIRROR = 0x140,

1551 ROW_HALF_MIRROR = 0x141,

1552 BCAST15 = 0x142,

1553 BCAST31 = 0x143,

1554 };

1555

1556 auto kind = DppOp.getKind();

1557 auto permArgument = DppOp.getPermArgument();

// Note: this local variable shares its name with the enum above; the
// DppCtrl::X spellings below refer to the enumerators.
1558 uint32_t DppCtrl = 0;

1559

1560 switch (kind) {

1561

// quad_perm: each element of the array attribute occupies two bits, element i
// at bit position 2 * i.
1562 case DPPPerm::quad_perm:

1563 if (auto quadPermAttr = cast(*permArgument)) {

1564 int32_t i = 0;

1565 for (auto elem : quadPermAttr.getAsRange()) {

1566 uint32_t num = elem.getInt();

1567 DppCtrl |= num << (i * 2);

1568 i++;

1569 }

1570 }

1571 break;

// row_shl / row_shr / row_ror: control word = base constant + integer
// attribute value.
1572 case DPPPerm::row_shl:

1573 if (auto intAttr = cast(*permArgument)) {

1574 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;

1575 }

1576 break;

1577 case DPPPerm::row_shr:

1578 if (auto intAttr = cast(*permArgument)) {

1579 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;

1580 }

1581 break;

1582 case DPPPerm::row_ror:

1583 if (auto intAttr = cast(*permArgument)) {

1584 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;

1585 }

1586 break;

// The remaining kinds take no argument and map to a fixed control word.
1587 case DPPPerm::wave_shl:

1588 DppCtrl = DppCtrl::WAVE_SHL1;

1589 break;

1590 case DPPPerm::wave_shr:

1591 DppCtrl = DppCtrl::WAVE_SHR1;

1592 break;

1593 case DPPPerm::wave_rol:

1594 DppCtrl = DppCtrl::WAVE_ROL1;

1595 break;

1596 case DPPPerm::wave_ror:

1597 DppCtrl = DppCtrl::WAVE_ROR1;

1598 break;

1599 case DPPPerm::row_mirror:

1600 DppCtrl = DppCtrl::ROW_MIRROR;

1601 break;

1602 case DPPPerm::row_half_mirror:

1603 DppCtrl = DppCtrl::ROW_HALF_MIRROR;

1604 break;

1605 case DPPPerm::row_bcast_15:

1606 DppCtrl = DppCtrl::BCAST15;

1607 break;

1608 case DPPPerm::row_bcast_31:

1609 DppCtrl = DppCtrl::BCAST31;

1610 break;

1611 }

1612

1613

1614

// Masks and bound-control flag are read from the op's attributes by name.
1615 auto rowMask = DppOp->getAttrOfType("row_mask").getInt();

1616 auto bankMask = DppOp->getAttrOfType("bank_mask").getInt();

1617 bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();

1618

1619

1620 auto dppMovOp = rewriter.createROCDL::DPPUpdateOp(

1621 loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);

1622

// Undo the widening: truncate back to llvmSrcIntType and, depending on
// srcType, bitcast to it. The enclosing condition is on a missing line
// (presumably the same <= 16-bit test as in convertOperand -- confirm).
1623 Value result = dppMovOp.getRes();

1625 result = rewriter.createLLVM::TruncOp(loc, llvmSrcIntType, result);

1626 if (!llvm::isa(srcType)) {

1627 result = rewriter.createLLVM::BitcastOp(loc, srcType, result);

1628 }

1629 }

1630

1631

1632

1634 return success();

1635 }

1636 };

1637

// Conversion pattern for SwizzleBitModeOp: combines the op's and/or/xor masks
// into one mask value and emits a ROCDL::DsSwizzleOp for every value in
// `decomposed`, collecting the results in `swizzled`.
// NOTE(review): this listing dropped several original lines (e.g. 1639-1640,
// 1644-1646, 1648-1649, 1657-1658, 1660, 1665-1666), including the base class,
// the definitions of `decomposed`, `maskValue`, `swizzled`, and the final
// replacement of the op.
1638 struct AMDGPUSwizzleBitModeLowering

1641

1642 LogicalResult

1643 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,

1647 Value src = adaptor.getSrc();

1650 unsigned andMask = op.getAndMask();

1651 unsigned orMask = op.getOrMask();

1652 unsigned xorMask = op.getXorMask();

1653

1654

1655

// Pack the three masks into one word: and-mask in the low bits, or-mask
// starting at bit 5, xor-mask starting at bit 10 (assumes each mask fits in
// 5 bits -- TODO confirm against the op's verifier).
1656 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);

// One swizzle per decomposed piece; each result keeps the piece's own type.
1659 for (Value v : decomposed) {

1661 rewriter.createROCDL::DsSwizzleOp(loc, v.getType(), v, maskValue);

1662 swizzled.emplace_back(res);

1663 }

1664

1667 return success();

1668 }

1669 };

1670

// Pass that converts the AMDGPU dialect to ROCDL/LLVM: it parses the `chipset`
// option, marks the AMDGPU dialect illegal and the LLVM and ROCDL dialects
// legal, and signals pass failure on error.
// NOTE(review): this listing dropped several original lines (e.g. 1676, 1679,
// 1683-1686, 1690-1691), including the construction of `target`, the pattern
// setup and the conversion call whose failure presumably guards the last
// signalPassFailure() -- confirm.
1671 struct ConvertAMDGPUToROCDLPass

1672 : public impl::ConvertAMDGPUToROCDLPassBase {

1673 using Base::Base;

1674

1675 void runOnOperation() override {

// An unparsable chipset string aborts the pass.
1677 FailureOr maybeChipset = Chipset::parse(chipset);

1678 if (failed(maybeChipset)) {

1680 return signalPassFailure();

1681 }

1682

// Legality setup: every AMDGPU op must be rewritten; LLVM and ROCDL ops may
// remain.
1687 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();

1688 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();

1689 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();

1692 signalPassFailure();

1693 }

1694 };

1695 }

1696

// Fragment of a function that registers a callback taking a memref type and an
// amdgpu::AddressSpaceAttr and switching over the attribute's value.
// NOTE(review): the function signature (original lines 1697-1699) and the body
// of every case (1706, 1708, 1710, 1712) are missing from this listing;
// presumably this is populateAMDGPUMemorySpaceAttributeConversions and each
// case returns an LLVM address-space attribute -- confirm against the full
// source.
1700 [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)

// One case per AMDGPU address space handled by the callback.
1704 switch (as.getValue()) {

1705 case amdgpu::AddressSpace::FatRawBuffer:

1707 case amdgpu::AddressSpace::BufferRsrc:

1709 case amdgpu::AddressSpace::FatStructuredBuffer:

1711 }

1713 });

1714 }

1715

// Fragment of the pattern-population function: registers the lowering patterns
// of this file.
// NOTE(review): the function signature and the receiver of `.add<...>`
// (original lines 1716-1720) are missing from this listing; presumably this is
// populateAMDGPUToROCDLConversionPatterns adding to `patterns` -- confirm.
// Patterns in the first group are constructed with (converter, chipset).
1721 .add<FatRawBufferCastLowering,

// Each raw-buffer AMDGPU op is paired with the ROCDL op it lowers to.
1722 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,

1723 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,

1724 RawBufferOpLowering<RawBufferAtomicFaddOp,

1725 ROCDL::RawPtrBufferAtomicFaddOp>,

1726 RawBufferOpLowering<RawBufferAtomicFmaxOp,

1727 ROCDL::RawPtrBufferAtomicFmaxOp>,

1728 RawBufferOpLowering<RawBufferAtomicSmaxOp,

1729 ROCDL::RawPtrBufferAtomicSmaxOp>,

1730 RawBufferOpLowering<RawBufferAtomicUminOp,

1731 ROCDL::RawPtrBufferAtomicUminOp>,

1732 RawBufferOpLowering<RawBufferAtomicCmpswapOp,

1733 ROCDL::RawPtrBufferAtomicCmpSwap>,

1734 AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,

1735 MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,

1736 ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,

1737 PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,

1738 PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,

1739 chipset);

// A second registration that passes only the converter; its template argument
// is lost in this listing.
1740 patterns.add(converter);

1741 }

static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type)

Return true if type is the E4M3FN variant of an 8-bit float that is supported by the _fp8 instruction...

constexpr Chipset kGfx942

static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)

Return the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.

static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, bool value)

constexpr Chipset kGfx908

constexpr Chipset kGfx90a

static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)

Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.

static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVector< Value, 4 > &operands)

Push the output operand.

static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type)

Return true if type is the E5M2 variant of an 8-bit float that is supported by the _bf8 instructions ...

static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride=nullptr, unsigned addressSpace=8)

static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVector< Value, 4 > &operands)

Push an input operand.

static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)

Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsic...

static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, ValueRange indices, ArrayRef< int64_t > strides)

Returns the linear index used to access an element in the memref.

static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val)

Convert an unsigned number val to i32.

static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)

Converts the scaled MFMA operands, scalesA and scalesB, from MLIR AMDGPU dialect convention to ROCDL ...

static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)

static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef< int64_t > strides, uint32_t elementByteWidth)

Compute the contents of the num_records field for a given memref descriptor - that is,...

static std::optional< uint32_t > mfmaTypeSelectCode(Type mlirElemType)

static std::optional< std::tuple< StringRef, uint32_t, uint32_t > > mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset)

If there is a scaled MFMA instruction for the input element types aType and bType,...

constexpr Chipset kGfx950

static MLIRContext * getContext(OpFoldResult val)

union mlir::linalg::@1203::ArityGroupAndKind::Kind kind

static constexpr unsigned kSizePosInMemRefDescriptor

static constexpr unsigned kStridePosInMemRefDescriptor

static constexpr unsigned kOffsetPosInMemRefDescriptor

static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor

static constexpr unsigned kAlignedPtrPosInMemRefDescriptor

static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)

static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)

This class provides a shared interface for ranked and unranked memref types.

Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.

IntegerAttr getIndexAttr(int64_t value)

IntegerAttr getI32IntegerAttr(int32_t value)

IntegerAttr getI16IntegerAttr(int16_t value)

IntegerType getIntegerType(unsigned width)

MLIRContext * getContext() const

This class implements a pattern rewriter for use with ConversionPatterns.

void replaceOp(Operation *op, ValueRange newValues) override

Replace the given operation with the new values.

void eraseOp(Operation *op) override

PatternRewriter hook for erasing a dead operation.

Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...

ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)

The main mechanism for performing data layout queries.

static DataLayout closest(Operation *op)

Returns the layout of the closest parent operation carrying layout info.

llvm::TypeSize getTypeSizeInBits(Type t) const

Returns the size in bits of the given type in the current scope.

Derived class that automatically populates legalization information for different LLVM ops.

Conversion from types to the LLVM IR dialect.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...

Value stride(OpBuilder &builder, Location loc, unsigned pos)

Builds IR extracting the pos-th stride from the descriptor.

static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)

Builds IR creating a poison value of the descriptor type.

Value size(OpBuilder &builder, Location loc, unsigned pos)

Builds IR extracting the pos-th size from the descriptor.

void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)

Create an operation of specific op type at the current insertion point, and immediately try to fold i...

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

Operation is the basic unit of execution within MLIR.

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

result_range getResults()

unsigned getNumResults()

Return the number of results held by this operation.

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure,...

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

The general result of a type attribute conversion callback, allowing for early termination.

static AttributeConversionResult abort()

LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const

Convert the given type.

void addTypeAttributeConversion(FnT &&callback)

Register a conversion function for attributes within types.

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

bool isSignedInteger() const

Return true if this is a signed integer type (with the specified width).

bool isUnsignedInteger() const

Return true if this is an unsigned integer type (with the specified width).

bool isInteger() const

Return true if this is an integer type (with the specified width).

unsigned getIntOrFloatBitWidth() const

Return the bit width of an integer or a float type, assert failure on other types.

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)

Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...

Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)

Composes a set of src values into a single value of type dstType through series of bitcasts and vecto...

SmallVector< Value > decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType)

Decomposes a src value into a set of values of type dstType through series of bitcasts and vector ops...

bool hasOcpFp8(const Chipset &chipset)

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)

Divides the known min value of the numerator by the denominator and rounds the result up to the next ...

Include the generated interface declarations.

void populateAMDGPUMemorySpaceAttributeConversions(TypeConverter &typeConverter)

Remap AMDGPU memory spaces to LLVM address spaces by mapping amdgpu::AddressSpace::fat_raw_buffer to ...

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.

Type getElementTypeOrSelf(Type type)

Return the element type or return the type itself.

const FrozenRewritePatternSet & patterns

void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)

Note: This function will also add conversions for the AMDGPU-specific address spaces,...

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())

Below we define several entry points for operation conversion.

This represents an operation in an abstracted form, suitable for use with the builder APIs.

Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.

static FailureOr< Chipset > parse(StringRef name)

Parses the chipset version string and returns the chipset on success, and failure otherwise.