MLIR: lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp Source File

//===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-Exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"

#include "llvm/ADT/ArrayRef.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::nvgpu;
using namespace mlir::NVVM;
using namespace mlir::transform;

#define DEBUG_TYPE "nvgpu-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")

void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });
  // Device-side async tokens cannot be materialized in nvvm; convert them to
  // a dummy i32 so they can easily be dropped during conversion.
  llvmTypeConverter.addConversion(
      [&](nvgpu::DeviceAsyncTokenType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 32));
      });
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
    return llvmTypeConverter.convertType(
        IntegerType::get(type.getContext(), 64));
  });
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
        Type elemType = type.getFragmented().getElementType();
        int64_t sizeM = type.getFragmented().getDimSize(0);
        int64_t sizeN = type.getFragmented().getDimSize(1);

        unsigned numMembers;
        if (elemType.isF32() || elemType.isInteger(32))
          numMembers = sizeN / 2;
        else if (elemType.isF16())
          numMembers = sizeN / 4;
        else
          llvm_unreachable("unsupported type for warpgroup accumulator");

        SmallVector<Type> innerStructBody;
        for (unsigned i = 0; i < numMembers; i++)
          innerStructBody.push_back(elemType);
        auto innerStructType = LLVM::LLVMStructType::getLiteral(
            type.getContext(), innerStructBody);

        SmallVector<Type> structBody;
        for (int i = 0; i < sizeM; i += kWgmmaSizeM)
          structBody.push_back(innerStructType);

        auto convertedType =
            LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
        return llvmTypeConverter.convertType(convertedType);
      });
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
    return llvmTypeConverter.convertType(
        nvgpu::getMBarrierMemrefType(type.getContext(), type));
  });
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 64));
      });
  llvmTypeConverter.addConversion(
      [&](nvgpu::TensorMapDescriptorType type) -> Type {
        return LLVM::LLVMPointerType::get(type.getContext());
      });
  populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
}
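The accumulator conversion above boils down to a small shape computation: an inner struct with one member per register-held element of a row tile, repeated once per wgmma M-tile. The following standalone sketch mirrors that arithmetic with plain integers; it assumes kWgmmaSizeM is 64 (the M size of wgmma.mma_async) and is illustrative only, not part of this file.

#include <cstdint>
#include <cstdio>
#include <utility>

// Illustrative only: mirrors the WarpgroupAccumulatorType shape computation
// above. kWgmmaSizeM = 64 is assumed (M size of wgmma.mma_async).
constexpr int64_t kWgmmaSizeMSketch = 64;

// Returns {inner struct member count, outer struct member count} for a
// fragmented <sizeM x sizeN x elemType> accumulator (sizeM assumed to be a
// multiple of the wgmma M size).
std::pair<unsigned, int64_t> accumulatorStructShape(int64_t sizeM,
                                                    int64_t sizeN,
                                                    bool elemIsF16) {
  unsigned numMembers = elemIsF16 ? sizeN / 4 : sizeN / 2; // f16 vs f32/i32
  int64_t numInnerStructs = sizeM / kWgmmaSizeMSketch;
  return {numMembers, numInnerStructs};
}

int main() {
  // A 128x128 f32 accumulator becomes 2 nested structs of 64 floats each.
  auto [members, structs] =
      accumulatorStructShape(128, 128, /*elemIsF16=*/false);
  std::printf("%u members x %lld structs\n", members,
              static_cast<long long>(structs));
  return 0;
}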

LogicalResult
transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

void transform::CreateAsyncGroupsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::consumesHandle(getTargetMutable(), effects);
  transform::producesHandle(getOperation()->getOpResults(), effects);
  transform::modifiesPayload(effects);
}

DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  nvgpu::createAsyncGroups(rewriter, target, getBypassL1());
  results.push_back(target);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// PipelineSharedMemoryCopiesOp
//===----------------------------------------------------------------------===//

/// Returns true if the given type has the default memory space.
static bool hasDefaultMemorySpace(BaseMemRefType type) {
  return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0;
}

/// Returns true if the given type has the shared (workgroup) memory space.
static bool hasSharedMemorySpace(BaseMemRefType type) {
  auto space =
      dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
  return space &&
         space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
}

/// Returns the value produced by a load from the default memory space.
/// Returns null if the operation is not such a load.
static Value getValueLoadedFromGlobal(Operation *op) {
  auto load = dyn_cast<vector::TransferReadOp>(op);
  if (!load)
    return nullptr;

  auto loadType = dyn_cast<MemRefType>(load.getBase().getType());
  if (!loadType || !hasDefaultMemorySpace(loadType))
    return nullptr;
  return load;
}

/// Returns true if the operation is storing the given value into shared
/// memory.
static bool isStoreToShared(Operation *op, Value v) {
  auto store = dyn_cast<vector::TransferWriteOp>(op);
  if (!store || store.getVector() != v)
    return false;

  auto storeType = dyn_cast<MemRefType>(store.getBase().getType());
  return storeType && hasSharedMemorySpace(storeType);
}

/// Returns true if the operation is a load from the default memory space the
/// result of which is only stored into the shared memory space.
static bool isLoadFromGlobalStoredToShared(Operation *op) {
  Value loaded = getValueLoadedFromGlobal(op);
  if (!loaded || !loaded.hasOneUse())
    return false;

  return isStoreToShared(*loaded.getUsers().begin(), loaded);
}

/// Populate `ops` with the set of operations that belong to the stage 0 of the
/// pipelined version of the given loop when pipelining copies to shared
/// memory, i.e. the device async copies, the barriers they require, and the
/// loads from global memory whose only use is a store to shared memory.
static LogicalResult
collectStage0PipeliningOps(scf::ForOp forOp,
                           llvm::SmallPtrSet<Operation *, 16> &ops) {
  llvm::SmallPtrSet<Operation *, 4> barriers;
  for (Operation &op : *forOp.getBody()) {
    // Bail on nested ops for now.
    if (op.getNumRegions() > 0)
      return failure();

    if (isa<gpu::BarrierOp>(op)) {
      barriers.insert(&op);
      continue;
    }

    if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
      ops.insert(&op);
      ops.insert(std::make_move_iterator(barriers.begin()),
                 std::make_move_iterator(barriers.end()));
      assert(barriers.empty() &&
             "expected to have moved the barriers into another set");
      continue;
    }

    if (isLoadFromGlobalStoredToShared(&op)) {
      ops.insert(&op);
      continue;
    }
  }

  return success();
}

/// Hook for the loop pipeliner that sets the "num groups in flight" attribute
/// of async wait operations corresponding to pipelined shared memory copies.
static void
setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration, unsigned depth) {
  // Only set the number of copies in flight if it is not already set.
  auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
  if (!waitOp || waitOp.getNumGroups())
    return;

  int numGroupInFlight = 0;
  if (part == scf::PipeliningOption::PipelinerPart::Kernel ||
      part == scf::PipeliningOption::PipelinerPart::Prologue) {
    numGroupInFlight = depth - 1;
  } else {
    // In the epilogue, wait for fewer and fewer groups as the loop drains the
    // copies that are still in flight.
    numGroupInFlight = depth - 1 - iteration;
  }
  waitOp.setNumGroups(numGroupInFlight);
}
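As a concrete reading of the two branches above (illustrative only, not part of the pass): with a pipeline depth of 3, the prologue and kernel keep depth - 1 = 2 copy groups in flight, and each epilogue iteration waits until one fewer group remains.

#include <cstdio>

// Worked example of the "num groups in flight" arithmetic above, assuming a
// pipeline depth of 3; purely illustrative.
int main() {
  const unsigned depth = 3;
  std::printf("prologue/kernel: %u groups in flight\n", depth - 1);
  for (unsigned iteration = 0; iteration < depth - 1; ++iteration)
    std::printf("epilogue iteration %u: wait until %u groups remain\n",
                iteration, depth - 1 - iteration);
  return 0;
}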

/// Hook for the loop pipeliner that populates `opsWithPipelineStages` with the
/// stage information as follows:
///   - operations in `stage0Ops` and their backward slice within the loop body
///     are at stage 0;
///   - all other operations (except the terminator) are at stage `depth`;
///   - stage-`depth` operations are listed before stage-0 operations.
static void getPipelineStages(
    scf::ForOp forOp,
    std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
    unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
  SetVector<Operation *> dependencies;
  BackwardSliceOptions options([&](Operation *visited) {
    return visited->getBlock() == forOp.getBody();
  });
  options.inclusive = true;
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (stage0Ops.contains(&op)) {
      LogicalResult result = getBackwardSlice(&op, &dependencies, options);
      assert(result.succeeded() && "expected a backward slice");
      (void)result;
    }
  }

  for (Operation &op : forOp.getBody()->getOperations()) {
    if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
      opsWithPipelineStages.emplace_back(&op, depth);
  }
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (dependencies.contains(&op))
      opsWithPipelineStages.emplace_back(&op, 0);
  }
}

/// Hook for the loop pipeliner. Replaces `op` with a predicated version and
/// returns the resulting operation. Returns the original op if predication is
/// not necessary for it, and null if predication is needed but not supported.
static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
                                            Operation *op, Value predicate) {
  // Some operations may safely be executed "speculatively" more often than
  // strictly needed, as long as they have no side effects visible to other
  // operations.
  if (isMemoryEffectFree(op) ||
      isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
          nvgpu::DeviceAsyncWaitOp>(op)) {
    return op;
  }

  // Otherwise, only async copies can currently be predicated, by zeroing out
  // the number of copied source elements when the predicate is false.
  auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
  if (!asyncCopyOp)
    return nullptr;

  // Create a srcElements value based on `predicate`:
  //   srcElements = predicate ? originalSrcElements : 0
  Location loc = asyncCopyOp->getLoc();
  Value dstElements =
      rewriter.create<arith::ConstantOp>(loc, asyncCopyOp.getDstElementsAttr());
  Value originalSrcElement =
      asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
  Value c0Index = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto srcElements = rewriter.create<arith::SelectOp>(
      loc, predicate, originalSrcElement, c0Index);
  auto asyncCopyZeroFillOp = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
      loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
      asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
      asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
      UnitAttr());
  rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
  return asyncCopyZeroFillOp;
}

/// Applies loop pipelining with the given depth to the given loop so that
/// copies into shared memory are pipelined. Doesn't affect other loops.
/// Returns a pair containing the error state and the pipelined loop, the
/// latter being null in case of failure.
static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
                        bool epiloguePeeling) {
  llvm::SmallPtrSet<Operation *, 16> stage0Ops;
  if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
        scf::ForOp());
  }
  if (stage0Ops.empty()) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
  }

  scf::PipeliningOption options;
  unsigned maxDepth = depth;
  auto setAnnotation = [&](Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration) {
    return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
  };
  options.getScheduleFn =
      [&](scf::ForOp schedulingFor,
          std::vector<std::pair<Operation *, unsigned>> &ops) {
        if (schedulingFor != forOp)
          return;
        return getPipelineStages(forOp, ops, maxDepth, stage0Ops);
      };
  options.annotateFn = setAnnotation;
  if (!epiloguePeeling) {
    options.peelEpilogue = false;
    options.predicateFn = replaceOpWithPredicatedOp;
  }

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forOp);
  bool modifiedIR;
  FailureOr<scf::ForOp> maybePipelined =
      pipelineForLoop(rewriter, forOp, options, &modifiedIR);
  if (succeeded(maybePipelined)) {
    return std::make_tuple(DiagnosedSilenceableFailure::success(),
                           *maybePipelined);
  }
  return std::make_tuple(
      modifiedIR
          ? DiagnosedSilenceableFailure::definiteFailure()
          : emitSilenceableFailure(forOp, "pipelining preconditions failed"),
      scf::ForOp());
}

DiagnosedSilenceableFailure transform::PipelineSharedMemoryCopiesOp::applyToOne(
    transform::TransformRewriter &rewriter, scf::ForOp forOp,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  auto [diag, pipelined] = pipelineForSharedCopies(
      rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
  if (diag.succeeded()) {
    results.push_back(pipelined);
    return DiagnosedSilenceableFailure::success();
  }
  if (diag.isDefiniteFailure()) {
    auto definiteDiag = emitDefiniteFailure("irreversible pipelining failure");
    if (!getPeelEpilogue()) {
      definiteDiag.attachNote(forOp->getLoc()) << "couldn't predicate?";
      definiteDiag.attachNote(getLoc())
          << "try setting " << getPeelEpilogueAttrName();
    }
    return definiteDiag;
  }

  return std::move(diag);
}

//===----------------------------------------------------------------------===//
// RewriteMatmulAsMmaSyncOp
//===----------------------------------------------------------------------===//

/// Helper struct to encode a pair of row/column indexings in the form of
/// affine expressions.
struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
  RowColIndexing(AffineExpr row, AffineExpr col)
      : std::pair<AffineExpr, AffineExpr>(row, col) {}

  AffineExpr row() const { return first; }
  AffineExpr col() const { return second; }

  void print(llvm::raw_ostream &os) const {
    os << "- indexing: " << first << ", " << second;
  }
};

/// Helper struct to provide a simple mapping from matmul operations to the
/// corresponding mma.sync operation. This is constrained to the case where the
/// matmul matches the mma.sync operation 1-1.
struct MmaSyncBuilder {
  MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
      : b(b), loc(loc), laneId(laneId) {}

  using IndexCalculator =
      std::function<SmallVector<RowColIndexing>(MLIRContext *)>;

  /// Create the mma.sync operation corresponding to `linalgOp` along with all
  /// the supporting load/store and vector operations.
  FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);

private:
  struct MmaSyncInfo {
    std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
    std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>>
        vectorShapes;
    SmallVector<int64_t> mmaShape;
    bool tf32Enabled;
  };

  /// Return the specific index calculators for the given op shape and element
  /// types, or failure if the combination is not supported.
  FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
                                             TypeRange elementalTypes);

  //===------------------------------------------------------------------===//
  // Instruction-specific row/column indexing expressions, written as pure
  // functions of the lane id. They follow the mma.sync fragment layouts from
  // the NVIDIA PTX ISA documentation.
  //===------------------------------------------------------------------===//

  static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup}};
  }

  static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{threadIDInGroup, groupID}};
  }

  static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID, threadIDInGroup * 2 + 1},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
  }

  static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;

    return {
      RowColIndexing{groupID, threadIDInGroup * 2 + 0},
      RowColIndexing{groupID, threadIDInGroup * 2 + 1},
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1},
      RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8},
      RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8},
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8},
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8}
    };
  }

  static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;

    return {
      RowColIndexing{threadIDInGroup * 2 + 0, groupID},
      RowColIndexing{threadIDInGroup * 2 + 1, groupID},
      RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID},
      RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID}
    };
  }

  static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;

    return {
      RowColIndexing{groupID, threadIDInGroup * 2 + 0},
      RowColIndexing{groupID, threadIDInGroup * 2 + 1},
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}
    };
  }

  /// Build one memref.load per (row, col) indexing produced by `indexFn` for
  /// the current lane.
  SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      const IndexCalculator &indexFn);

  /// Load the distributed operand for the current lane into vector form.
  Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      IndexCalculator indexFn,
                                      ArrayRef<int64_t> vectorShape);

  /// Build one memref.store per (row, col) indexing produced by `indexFn` for
  /// the current lane.
  SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
                                             ValueRange toStore,
                                             OpFoldResult laneId, Value memref,
                                             const IndexCalculator &indexFn);

  /// Store the distributed result vector for the current lane.
  SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
      OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
      Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);

  OpBuilder &b;
  Location loc;
  OpFoldResult laneId;
};
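The indexing functions above are pure functions of the lane id; for instance, the m16n8k16 f16 result mapping can be checked with plain integer arithmetic. The following standalone sketch is illustrative only and re-implements that single mapping outside of MLIR.

#include <array>
#include <cstdio>
#include <utility>

// Illustrative re-implementation of m16n8k16f16Res with plain integers:
// groupID = laneid / 4, threadIDInGroup = laneid % 4.
std::array<std::pair<int, int>, 4> m16n8k16f16ResPositions(int laneid) {
  int groupID = laneid / 4;
  int threadIDInGroup = laneid % 4;
  return {{{groupID, threadIDInGroup * 2 + 0},
           {groupID, threadIDInGroup * 2 + 1},
           {groupID + 8, threadIDInGroup * 2 + 0},
           {groupID + 8, threadIDInGroup * 2 + 1}}};
}

int main() {
  // Lane 5 owns result elements (1,2), (1,3), (9,2) and (9,3) of the 16x8 tile.
  for (auto [row, col] : m16n8k16f16ResPositions(5))
    std::printf("(%d, %d)\n", row, col);
  return 0;
}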

//===----------------------------------------------------------------------===//
// Helper functions to create customizable load and store operations.
//===----------------------------------------------------------------------===//

template <typename ApplyFn, typename ReduceFn>
static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
                                           ReduceFn reduceFn) {
  VectorType vectorType = cast<VectorType>(vector.getType());
  auto vectorShape = vectorType.getShape();
  auto strides = computeStrides(vectorShape);
  for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
    auto indices = delinearize(idx, strides);
    reduceFn(applyFn(vector, idx, indices), idx, indices);
  }
}
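foreachIndividualVectorElement above walks the vector elements in row-major order via computeStrides and delinearize. The following self-contained sketch of that index arithmetic for a 2x2 shape is illustrative only; the helpers are local stand-ins, not the MLIR functions.

#include <cstdio>
#include <vector>

// Illustrative stand-ins for computeStrides/delinearize on a small shape.
std::vector<long long> computeRowMajorStrides(
    const std::vector<long long> &sizes) {
  std::vector<long long> strides(sizes.size(), 1);
  for (int i = static_cast<int>(sizes.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * sizes[i + 1];
  return strides;
}

std::vector<long long> delinearizeIndex(long long linear,
                                        const std::vector<long long> &strides) {
  std::vector<long long> indices;
  for (long long s : strides) {
    indices.push_back(linear / s);
    linear %= s;
  }
  return indices;
}

int main() {
  // For shape 2x2: strides = {2, 1}; linear indices 0..3 map to
  // (0,0), (0,1), (1,0), (1,1) -- the order used by the loop above.
  auto strides = computeRowMajorStrides({2, 2});
  for (long long idx = 0; idx < 2 * strides[0]; ++idx) {
    auto ind = delinearizeIndex(idx, strides);
    std::printf("%lld -> (%lld, %lld)\n", idx, ind[0], ind[1]);
  }
  return 0;
}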

SmallVector<Value>
MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
                                 OpFoldResult laneId, Value memref,
                                 const IndexCalculator &indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Value> res;
  SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
  for (auto indexing : indexings) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    auto load = b.create<memref::LoadOp>(loc, memref, ValueRange{row, col});
    res.push_back(load);
  }
  return res;
}

Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
    IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));

  Type elementType = getElementTypeOrSelf(memref.getType());
  auto vt = VectorType::get(vectorShape, elementType);
  Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
  foreachIndividualVectorElement(
      res,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return loads[linearIdx];
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        res = b.create<vector::InsertOp>(loc, v, res, indices);
      });

  return res;
}

SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
    OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
    Value memref, const IndexCalculator &indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Operation *> res;
  for (auto [indexing, val] :
       llvm::zip_equal(indexFn(b.getContext()), toStore)) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    Operation *store =
        b.create<memref::StoreOp>(loc, val, memref, ValueRange{row, col});
    res.push_back(store);
  }
  return res;
}

SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
    OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
    Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  SmallVector<Value> toStore;
  toStore.reserve(32);
  foreachIndividualVectorElement(
      vectorToStore,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return b.create<vector::ExtractOp>(loc, vectorToStore, indices);
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        toStore.push_back(v);
      });
  return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn));
}

static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
                  SmallVector<int64_t>>
makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
                 ArrayRef<int64_t> res) {
  SmallVector<int64_t> vlhs(lhs);
  SmallVector<int64_t> vrhs(rhs);
  SmallVector<int64_t> vres(res);
  return std::make_tuple(vlhs, vrhs, vres);
}

FailureOr<MmaSyncBuilder::MmaSyncInfo>
MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
                                    TypeRange elementalTypes) {
  Type f16 = Float16Type::get(b.getContext());
  Type f32 = Float32Type::get(b.getContext());
  if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
      elementalTypes == TypeRange{f32, f32, f32}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
                                       &MmaSyncBuilder::m16n8k4tf32Rhs,
                                       &MmaSyncBuilder::m16n8k4tf32Res),
                       makeVectorShapes(/*lhs=*/{2, 1}, /*rhs=*/{1, 1},
                                        /*res=*/{2, 2}),
                       SmallVector<int64_t>{opShape.begin(), opShape.end()},
                       /*tf32Enabled=*/true};
  }
  // This is the version with f16 accumulation.
  if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
      elementalTypes == TypeRange{f16, f16, f16}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
                                       &MmaSyncBuilder::m16n8k16f16Rhs,
                                       &MmaSyncBuilder::m16n8k16f16Res),
                       makeVectorShapes(/*lhs=*/{4, 2}, /*rhs=*/{2, 2},
                                        /*res=*/{2, 2}),
                       SmallVector<int64_t>{opShape.begin(), opShape.end()},
                       /*tf32Enabled=*/false};
  }
  return failure();
}

FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
  assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
         "expected lhs to be a 2D memref");
  assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
         "expected rhs to be a 2D memref");
  assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
         "expected res to be a 2D memref");

  int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
  int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
  int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
  Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
  Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
  Type resType = getElementTypeOrSelf(resMemRef.getType());

  FailureOr<MmaSyncInfo> maybeInfo =
      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
  if (failed(maybeInfo))
    return failure();

  MmaSyncInfo info = *maybeInfo;
  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
                                            lhsIndexFn, lhsShape);
  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
                                            rhsIndexFn, rhsShape);
  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
                                            resIndexFn, resShape);
  res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
                                   info.tf32Enabled);
  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
                                 resShape);
  return res.getDefiningOp();
}

DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
    transform::TransformRewriter &rewriter, LinalgOp linalgOp,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  bool fail = true;
  // Only the plain linalg.matmul form is detected for now.
  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
    // Do not let matmul ops with extended (user-defined indexing map)
    // semantics through this transform.
    if (linalgOp.hasUserDefinedMaps()) {
      return emitSilenceableError()
             << "only matmul ops with non-extended semantics are supported";
    }
    Location loc = linalgOp.getLoc();
    // Assume a single warp: the lane id is threadIdx.x.
    Value laneId = rewriter.create<gpu::ThreadIdOp>(
        loc, rewriter.getIndexType(), gpu::Dimension::x);
    if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
      fail = false;
  }

  if (fail) {
    DiagnosedSilenceableFailure diag = emitSilenceableError()
                                       << "unsupported target op: " << linalgOp;
    diag.attachNote(linalgOp->getLoc()) << "target op";
    return diag;
  }

  rewriter.eraseOp(linalgOp);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Hopper builders.
//===----------------------------------------------------------------------===//

/// Helper to create the base Hopper-specific operations that are reused in
/// various other places.
struct HopperBuilder {
  HopperBuilder(RewriterBase &rewriter, Location loc)
      : rewriter(rewriter), loc(loc) {}

  TypedValue<nvgpu::MBarrierGroupType>
  buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);

  /// Create a tma descriptor op to initiate the transfer from global to shared
  /// memory. This must be done before the launch op, on the host.
  TypedValue<nvgpu::TensorMapDescriptorType>
  buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                              gpu::LaunchOp launchOp);

  /// Build a tma load from global memory to shared memory using `barrier` to
  /// synchronize. Return the number of bytes that will be transferred.
  OpFoldResult
  buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
                    TypedValue<MemRefType> sharedMemref,
                    TypedValue<nvgpu::MBarrierGroupType> barrier,
                    SmallVectorImpl<Operation *> &loadOps);

  void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier,
                            ArrayRef<OpFoldResult> sizes);

  void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier);

  /// If threadIdx.x == 0, perform the TMA request and wait; other threads just
  /// wait on the barrier.
  SmallVector<Operation *> buildPredicateLoadsOnThread0(
      ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
      ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
      TypedValue<nvgpu::MBarrierGroupType> barrier);

  RewriterBase &rewriter;
  Location loc;
};

SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
    ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
    ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
    TypedValue<nvgpu::MBarrierGroupType> barrier) {
  SmallVector<Operation *> loadOps;
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value tidx = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
  Value cond =
      rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero);
  rewriter.create<scf::IfOp>(
      loc,
      cond,
      /*thenBuilder=*/
      [&](OpBuilder &lb, Location loc) {
        SmallVector<OpFoldResult> sizes;
        sizes.reserve(globalDescriptors.size());
        for (auto [desc, shmem] : llvm::zip_equal(
                 globalDescriptors, sharedMemBuffers)) {
          OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
          sizes.push_back(sz);
        }
        // The total number of bytes expected to arrive is the sum of all the
        // individual loads.
        buildBarrierArriveTx(barrier, sizes);
        rewriter.create<scf::YieldOp>(loc);
      },
      /*elseBuilder=*/
      [&](OpBuilder &lb, Location loc) {
        // Other threads do not issue a TMA load but still arrive on the
        // barrier, expecting 0 additional bytes.
        buildBarrierArriveTx(barrier,
                             getAsIndexOpFoldResult(rewriter.getContext(), 0));
        rewriter.create<scf::YieldOp>(loc);
      });

  return loadOps;
}

static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
  return gpu::AddressSpaceAttr::get(
      b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
}

TypedValue<nvgpu::MBarrierGroupType>
HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value barrier = rewriter.create<nvgpu::MBarrierCreateOp>(
      loc,
      nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierInitOp>(
      loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads),
      zero, Value());
  rewriter.create<gpu::BarrierOp>(loc);
  return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
}

TypedValue<nvgpu::TensorMapDescriptorType>
HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                                           gpu::LaunchOp launchOp) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(launchOp);
  Value unrankedMemRef = rewriter.create<memref::CastOp>(
      loc,
      UnrankedMemRefType::get(memref.getType().getElementType(),
                              memref.getType().getMemorySpace()),
      memref);
  SmallVector<OpFoldResult> mixedSizes =
      memref::getMixedSizes(rewriter, loc, memref);
  SmallVector<Value> sizes =
      getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);

  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
      loc,
      nvgpu::TensorMapDescriptorType::get(
          rewriter.getContext(),
          MemRefType::Builder(memref.getType())
              .setMemorySpace(sharedMemorySpace),
          TensorMapSwizzleKind::SWIZZLE_NONE,
          TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
          TensorMapInterleaveKind::INTERLEAVE_NONE),
      unrankedMemRef, sizes);
  return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
}

OpFoldResult HopperBuilder::buildTmaAsyncLoad(
    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
    TypedValue<MemRefType> sharedMemref,
    TypedValue<nvgpu::MBarrierGroupType> barrier,
    SmallVectorImpl<Operation *> &loadOps) {
  MLIRContext *ctx = rewriter.getContext();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
      loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero,
      Value(), Value());
  loadOps.push_back(loadOp);

  // Compute the expected number of bytes transferred as the product of the
  // shared memref sizes times the element size in bytes.
  auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr prodExprInBytes =
      computeProduct(ctx, symbols) *
      (sharedMemref.getType().getElementTypeBitWidth() / 8);
  auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
                                                   prodExprInBytes, mixedSizes);
  return res;
}

void HopperBuilder::buildBarrierArriveTx(
    TypedValue<nvgpu::MBarrierGroupType> barrier,
    ArrayRef<OpFoldResult> mixedSizes) {
  assert(!mixedSizes.empty() && "expected non-empty sizes");
  MLIRContext *ctx = rewriter.getContext();
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr sumExpr = computeSum(ctx, symbols);
  OpFoldResult size =
      affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
  Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero,
                                                   Value());
}

void HopperBuilder::buildTryWaitParity(
    TypedValue<nvgpu::MBarrierGroupType> barrier) {
  Type i1 = rewriter.getI1Type();
  Value parity = rewriter.create<LLVM::ConstantOp>(loc, i1, 0);
  // 10M is an arbitrary, not too small and not too big, number of ticks to
  // wait before retrying.
  Value ticksBeforeRetry =
      rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
                                                  ticksBeforeRetry, zero);
}

//===----------------------------------------------------------------------===//
// RewriteCopyAsTmaOp
//===----------------------------------------------------------------------===//

/// Helper to create the tma operations corresponding to `linalg::CopyOp`.
struct CopyBuilder : public HopperBuilder {
  CopyBuilder(RewriterBase &rewriter, Location loc)
      : HopperBuilder(rewriter, loc) {}

  SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps);
};

SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
  MLIRContext *ctx = rewriter.getContext();
  if (copyOps.empty())
    return SmallVector<Operation *>();

  auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
  assert(launchOp && "expected launch op");

  // Init a barrier object in shared memory, arrived on by all the threads of
  // the block.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(copyOps.front());
  AffineExpr bx, by, bz;
  bindSymbols(ctx, bx, by, bz);
  AffineExpr prod = bx * by * bz;
  OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
      rewriter, loc, prod,
      ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
                             launchOp.getBlockSizeZ()});

  TypedValue<nvgpu::MBarrierGroupType> barrier =
      buildAndInitBarrierInSharedMemory(numThreads);

  SmallVector<TypedValue<MemRefType>> shmems;
  SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs;
  for (Operation *op : copyOps) {
    auto copyOp = cast<linalg::CopyOp>(op);
    auto inMemRef =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
    assert(inMemRef.getType().getRank() == 2 &&
           "expected in to be a 2D memref");

    // Build the global memory descriptor on the host, before the launch op.
    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc =
        buildGlobalMemRefDescriptor(inMemRef, launchOp);
    globalDescs.push_back(globalDesc);

    // Record the shared memory destination of the copy.
    auto shmem =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
    shmems.push_back(shmem);
  }

  // Load from global to shared memory using tma, predicated on thread 0.
  SmallVector<Operation *> results =
      buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);

  // Spin-wait until the data has arrived.
  buildTryWaitParity(barrier);

  // The copies have been rewritten; erase them.
  for (Operation *op : copyOps)
    rewriter.eraseOp(op);

  return results;
}

DiagnosedSilenceableFailure
transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
                                     transform::TransformResults &results,
                                     transform::TransformState &state) {
  auto payloadOps = state.getPayloadOps(getTarget());
  gpu::LaunchOp commonLaunchOp;
  Operation *firstOp = nullptr, *failingOp = nullptr;
  if (llvm::any_of(payloadOps, [&](Operation *op) {
        if (!commonLaunchOp) {
          commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
          firstOp = op;
        }
        auto fail = !op->getParentOfType<gpu::LaunchOp>() ||
                    commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
                    !isa<linalg::CopyOp>(op);
        if (fail)
          failingOp = op;
        return fail;
      })) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "target ops must be linalg::CopyOp nested under a common "
           "gpu.LaunchOp to be rewritten because the tma descriptors need to "
           "be created on the host.\nBut got: "
        << *firstOp << "\nand " << *failingOp;
    return diag;
  }

  // Rewrite all the copies under the common launch op.
  CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));

  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
class NVGPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          NVGPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)

  NVGPUTransformDialectExtension() {
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<affine::AffineDialect>();
    declareGeneratedDialect<nvgpu::NVGPUDialect>();
    declareGeneratedDialect<NVVM::NVVMDialect>();
    declareGeneratedDialect<vector::VectorDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"

void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<NVGPUTransformDialectExtension>();
}
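For context, a typical way to make these transform ops available in a tool is to register the extension into a DialectRegistry before creating the context. This is a minimal sketch, not taken from this file; only registerTransformDialectExtension is defined above, the rest is the standard registry workflow.

// Hypothetical tool setup; only nvgpu::registerTransformDialectExtension comes
// from this file, the rest is the usual DialectRegistry/MLIRContext workflow.
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::DialectRegistry registry;
  mlir::nvgpu::registerTransformDialectExtension(registry);
  mlir::MLIRContext context(registry);
  // The NVGPU transform ops can now be parsed and applied in this context.
  return 0;
}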

static constexpr int64_t kSharedMemorySpace


constexpr int kWgmmaSizeM

M size of wgmma.mma_async instruction.

static Attribute getSharedAddressSpaceAttribute(OpBuilder &b)

static bool hasDefaultMemorySpace(BaseMemRefType type)

Returns true if the given type has the default memory space.

static LogicalResult collectStage0PipeliningOps(scf::ForOp forOp, llvm::SmallPtrSet< Operation *, 16 > &ops)

Populate ops with the set of operations that belong to the stage 0 of the pipelined version of the gi...

static Operation * replaceOpWithPredicatedOp(RewriterBase &rewriter, Operation *op, Value predicate)

Hook for the loop pipeliner.

static bool isStoreToShared(Operation *op, Value v)

Returns true if the operation is storing the given value into shared memory.

static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn, ReduceFn reduceFn)

Helper functions to create customizable load and store operations.

static std::tuple< DiagnosedSilenceableFailure, scf::ForOp > pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth, bool epiloguePeeling)

Applies loop pipelining with the given depth to the given loop so that copies into the shared memory ...

static bool hasSharedMemorySpace(BaseMemRefType type)

Returns true if the given type has the shared (workgroup) memory space.

static bool isLoadFromGlobalStoredToShared(Operation *op)

Returns true if the operation is a load from the default memory space the result of which is only sto...

static std::tuple< SmallVector< int64_t >, SmallVector< int64_t >, SmallVector< int64_t > > makeVectorShapes(ArrayRef< int64_t > lhs, ArrayRef< int64_t > rhs, ArrayRef< int64_t > res)

static void setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op, scf::PipeliningOption::PipelinerPart part, unsigned iteration, unsigned depth)

Hook for the loop pipeliner that sets the "num groups in flight" attribute of async wait operations c...

static void getPipelineStages(scf::ForOp forOp, std::vector< std::pair< Operation *, unsigned >> &opsWithPipelineStages, unsigned depth, llvm::SmallPtrSetImpl< Operation * > &stage0Ops)

Hook for the loop pipeliner that populates ops with the stage information as follows:

static Value getValueLoadedFromGlobal(Operation *op)

Returns the value produced by a load from the default memory space.


#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)

Base type for affine expression.

AffineExpr floorDiv(uint64_t v) const

Attributes are known-constant values of operations.

This class provides a shared interface for ranked and unranked memref types.

Attribute getMemorySpace() const

Returns the memory space in which data referred to by this memref resides.

unsigned getMemorySpaceAsInt() const

[deprecated] Returns the memory space in old raw integer representation.

MLIRContext * getContext() const

The result of a transform IR operation application.

static DiagnosedSilenceableFailure success()

Constructs a DiagnosedSilenceableFailure in the success state.

static DiagnosedSilenceableFailure definiteFailure()

Constructs a DiagnosedSilenceableFailure in the failure state.

The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.

void addExtensions()

Add the given extensions to the registry.

Conversion from types to the LLVM IR dialect.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

This is a builder type that keeps local references to arguments.

Builder & setMemorySpace(Attribute newMemorySpace)

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

This class represents a single result from folding an operation.

Operation is the basic unit of execution within MLIR.

Block * getBlock()

Returns the operation block that contains this operation.

OpTy getParentOfType()

Return the closest surrounding parent operation that is of type 'OpTy'.

This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements...

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

bool isInteger() const

Return true if this is an integer type (with the specified width).

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

user_range getUsers() const

bool hasOneUse() const

Returns true if this value has exactly one use.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...

void push_back(Operation *op)

Appends an element to the list.

Base class for extensions of the Transform dialect that supports injecting operations into the Transf...

Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...

This is a special rewriter to be used in transform op implementations, providing additional helper fu...

The state maintained across applications of various ops implementing the TransformOpInterface.

@ kGlobalMemorySpace

Global memory space identifier.

OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...

SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)

Return the dimensions of the given memref value.

MemRefType getMBarrierMemrefType(MLIRContext *context, MBarrierGroupType barrierType)

Return the memref type that can be used to represent an mbarrier object.

void registerTransformDialectExtension(DialectRegistry &registry)

void createAsyncGroups(RewriterBase &rewriter, Operation *op, bool bypassL1)

Convert global->shared vector transfers to async device copies.

FailureOr< ForOp > pipelineForLoop(RewriterBase &rewriter, ForOp forOp, const PipeliningOption &options, bool *modifiedIR=nullptr)

Generate a pipelined version of the scf.for loop based on the schedule given as option.

void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

Populates effects with the memory effects indicating the operation on the given handle value:

void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

Populates effects with the memory effects indicating the access to payload IR resource.

Include the generated interface declarations.

OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)

Convert int64_t to integer attributes of index type and return them as OpFoldResult.

LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})

Fills backwardSlice with the computed backward slice (i.e.

std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue

If Ty is mlir::Type this will select Value instead of having a wrapper around it.

SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)

DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})

Emits a silenceable failure with the given message.

SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)

Given the strides together with a linear index in the dimension space, return the vector-space offset...

bool isMemoryEffectFree(Operation *op)

Returns true if the given operation is free of memory effects.

int64_t computeProduct(ArrayRef< int64_t > basis)

Self-explicit.

DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})

Emits a definite failure with the given message.

void populateNVGPUToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)

Type getElementTypeOrSelf(Type type)

Return the element type or return the type itself.

const FrozenRewritePatternSet & patterns

void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)

Bind a list of AffineExpr references to SymbolExpr at positions: [0 .

void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)

Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.

Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)

Converts an OpFoldResult to a Value.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)

These free functions allow clients of the API to not use classes in detail.

int64_t computeSum(ArrayRef< int64_t > basis)

Self-explicit.

void bindSymbolsList(MLIRContext *ctx, MutableArrayRef< AffineExprTy > exprs)

Helper to create the tma operations corresponding to linalg::CopyOp.

SmallVector< Operation * > rewrite(ArrayRef< Operation * > copyOps)

CopyBuilder(RewriterBase &rewriter, Location loc)

Helper to create the base Hopper-specific operations that are reused in various other places.

OpFoldResult buildTmaAsyncLoad(TypedValue< nvgpu::TensorMapDescriptorType > globalDesc, TypedValue< MemRefType > sharedMemref, TypedValue< nvgpu::MBarrierGroupType > barrier, SmallVectorImpl< Operation * > &loadOps)

Build a tma load from global memory to shared memory using barrier to synchronize.

TypedValue< nvgpu::MBarrierGroupType > buildAndInitBarrierInSharedMemory(OpFoldResult numThreads)

void buildTryWaitParity(TypedValue< nvgpu::MBarrierGroupType > barrier)

TypedValue< nvgpu::TensorMapDescriptorType > buildGlobalMemRefDescriptor(TypedValue< MemRefType > memref, gpu::LaunchOp launchOp)

Create tma descriptor op to initiate transfer from global to shared memory.

SmallVector< Operation * > buildPredicateLoadsOnThread0(ArrayRef< TypedValue< nvgpu::TensorMapDescriptorType >> globalDescriptors, ArrayRef< TypedValue< MemRefType >> sharedMemBuffers, TypedValue< nvgpu::MBarrierGroupType > barrier)

If threadIdx.x == 0, performs the TMA request and waits; otherwise just waits.

void buildBarrierArriveTx(TypedValue< nvgpu::MBarrierGroupType > barrier, ArrayRef< OpFoldResult > sizes)

HopperBuilder(RewriterBase &rewriter, Location loc)

Helper struct to provide a simple mapping from matmul operations to the corresponding mma....

std::function< SmallVector< RowColIndexing >(MLIRContext *)> IndexCalculator

MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)

FailureOr< Operation * > buildMmaSync(LinalgOp linalgOp)

Create the mma.sync operation corresponding to linalgOp along with all the supporting load/store and ...

Helper struct to encode a pair of row/column indexings in the form of affine expressions.

RowColIndexing(AffineExpr row, AffineExpr col)

void print(llvm::raw_ostream &os) const

Options to dictate how loops should be pipelined.