MLIR: lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp Source File (original) (raw)


32#include "llvm/ADT/ArrayRef.h"

33

34using namespace mlir;

39

40#define DEBUG_TYPE "nvgpu-transforms"

41

42

43

44

45
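/// Adds the NVGPU-to-NVVM type conversions used below: device-side async
/// tokens are lowered to i32, mbarrier tokens and warpgroup matrix descriptors
/// to i64, warpgroup accumulators to nested LLVM structs, mbarrier groups to
/// their memref representation, and tensor map descriptors to bare LLVM
/// pointers.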

46void ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(

47 TypeConverter &typeConverter, RewritePatternSet &patterns) {

48 auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);

49

50

51

53 llvmTypeConverter.addConversion([&](DeviceAsyncTokenType type) -> Type {

54 return llvmTypeConverter.convertType(

55 IntegerType::get(type.getContext(), 32));

56 });

57 llvmTypeConverter.addConversion([&](MBarrierTokenType type) -> Type {

58 return llvmTypeConverter.convertType(

59 IntegerType::get(type.getContext(), 64));

60 });

61 llvmTypeConverter.addConversion([&](WarpgroupAccumulatorType type) -> Type {

62 Type elemType = type.getFragmented().getElementType();

63 int64_t sizeM = type.getFragmented().getDimSize(0);

64 int64_t sizeN = type.getFragmented().getDimSize(1);

65

66 unsigned numMembers;

67 if (elemType.isF32() || elemType.isInteger(32))

68 numMembers = sizeN / 2;

69 else if (elemType.isF16())

70 numMembers = sizeN / 4;

71 else

72 llvm_unreachable("unsupported type for warpgroup accumulator");

73

74 SmallVector<Type> innerStructBody;

75 for (unsigned i = 0; i < numMembers; i++)

76 innerStructBody.push_back(elemType);

77 auto innerStructType =

78 LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);

79

80 SmallVector<Type> structBody;

81 for (int i = 0; i < sizeM; i += kWgmmaSizeM)

82 structBody.push_back(innerStructType);

83

84 auto convertedType =

85 LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);

86 return llvmTypeConverter.convertType(convertedType);

87 });

88 llvmTypeConverter.addConversion([&](MBarrierGroupType type) -> Type {

89 return llvmTypeConverter.convertType(

90 getMBarrierMemrefType(type.getContext(), type));

91 });

92 llvmTypeConverter.addConversion(

93 [&](WarpgroupMatrixDescriptorType type) -> Type {

94 return llvmTypeConverter.convertType(

95 IntegerType::get(type.getContext(), 64));

96 });

97 llvmTypeConverter.addConversion([&](TensorMapDescriptorType type) -> Type {

98 return LLVM::LLVMPointerType::get(type.getContext());

99 });

100 populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);

101}

102

103LogicalResult ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(

104 TypeConverterBuilderOpInterface builder) {

105 if (builder.getTypeConverterType() != "LLVMTypeConverter")

106 return emitOpError("expected LLVMTypeConverter");

107 return success();

108}

109

110

111

112

113
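/// CreateAsyncGroupsOp consumes the target handle, converts global->shared
/// vector transfers in the payload into async device copies (see
/// createAsyncGroups), and produces a handle to the result.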

114void CreateAsyncGroupsOp::getEffects(

115 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {

116 consumesHandle(getTargetMutable(), effects);

117 producesHandle(getOperation()->getOpResults(), effects);

118 modifiesPayload(effects);

119}

120

128}

129

130

131

132

133

134

138

139

140static bool hasSharedMemorySpace(BaseMemRefType type) {

141 auto space =

142 dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());

143 return space &&

144 space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();

145}

146

147

148

150

151 auto load = dyn_cast<vector::TransferReadOp>(op);

152 if (!load)

153 return nullptr;

154

155 auto loadType = dyn_cast<MemRefType>(load.getBase().getType());

156 if (!loadType || !hasDefaultMemorySpace(loadType))

157 return nullptr;

158 return load;

159}

160

161

163

164 auto store = dyn_cast<vector::TransferWriteOp>(op);

165 if (!store || store.getVector() != v)

166 return false;

167

168 auto storeType = dyn_cast<MemRefType>(store.getBase().getType());

169 return storeType && hasSharedMemorySpace(storeType);

170}

171

172

173

175 Value loaded = getValueLoadedFromGlobal(op);

176 if (!loaded || !loaded.hasOneUse())

177 return false;

178

179 return isStoreToShared(*loaded.getUsers().begin(), loaded);

180}

181

182

183

184

185

186

187

188

189

190
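/// Populates `ops` with the operations that belong to stage 0 of the pipelined
/// version of the given loop: gpu.barrier ops, async copy and commit-group
/// ops, and loads from global memory whose only use is a store to shared
/// memory. Fails if the loop body contains ops with regions.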

191static LogicalResult

194

196 for (Operation &op : *forOp.getBody()) {

197

198 if (op.getNumRegions() > 0)

199 return failure();

200

201 if (isa<gpu::BarrierOp>(op)) {

202 barriers.insert(&op);

203 continue;

204 }

205

206 if (isa<DeviceAsyncCopyOp, DeviceAsyncCreateGroupOp>(op)) {

207 ops.insert(&op);

208 ops.insert(std::make_move_iterator(barriers.begin()),

209 std::make_move_iterator(barriers.end()));

210 assert(barriers.empty() &&

211 "expected to have moved the barriers into another set");

212 continue;

213 }

214

215 if (isLoadFromGlobalStoredToShared(&op)) {

216 ops.insert(&op);

217 continue;

218 }

219 }

220

221 return success();

222}

223

224

225

226

227
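/// Hook for the loop pipeliner that sets the "num groups in flight" attribute
/// of async wait operations: depth - 1 groups in the steady state, and
/// depth - 1 - iteration for the peeled epilogue iterations.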

228static void

231 unsigned iteration, unsigned depth) {

232

233

234 auto waitOp = dyn_cast<DeviceAsyncWaitOp>(op);

235 if (!waitOp || waitOp.getNumGroups())

236 return;

237

238 int numGroupInFlight = 0;

241 numGroupInFlight = depth - 1;

242 } else {

243

244

246

247

248 numGroupInFlight = depth - 1 - iteration;

249 }

250 waitOp.setNumGroups(numGroupInFlight);

251}

252

253

254

255

256

257

258

259

260

261

262
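/// Hook for the loop pipeliner that assigns a stage to each operation of the
/// loop body: ops in the backward slice of a stage-0 op get stage 0, all other
/// ops (except scf.yield) get stage `depth`.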

263static void getPipelineStages(

264 scf::ForOp forOp,

265 std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,

266 unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {

267 SetVector<Operation *> dependencies;

268 BackwardSliceOptions options([&](Operation *visited) {

269 return visited->getBlock() == forOp.getBody();

270 });

271 options.inclusive = true;

272 for (Operation &op : forOp.getBody()->getOperations()) {

273 if (stage0Ops.contains(&op)) {

274 LogicalResult result = getBackwardSlice(&op, &dependencies, options);

275 assert(result.succeeded() && "expected a backward slice");

276 (void)result;

277 }

278 }

279

280 for (Operation &op : forOp.getBody()->getOperations()) {

281 if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))

282 opsWithPipelineStages.emplace_back(&op, depth);

283 }

284 for (Operation &op : forOp.getBody()->getOperations()) {

285 if (dependencies.contains(&op))

286 opsWithPipelineStages.emplace_back(&op, 0);

287 }

288}

289

290

291

292

293
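/// Hook for the loop pipeliner: returns `op` unchanged when it is memory
/// effect free or a barrier/commit-group/wait op; otherwise the op must be an
/// async copy, which is predicated by selecting zero source elements
/// (zero-fill) when `predicate` is false.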

294static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,

295 Operation *op, Value predicate) {

296

297

298

300 isa<gpu::BarrierOp, DeviceAsyncCreateGroupOp, DeviceAsyncWaitOp>(op)) {

301 return op;

302 }

303

304

305 auto asyncCopyOp = dyn_cast<DeviceAsyncCopyOp>(op);

306 if (!asyncCopyOp)

307 return nullptr;

308

309

310

311

312

313

314 Location loc = asyncCopyOp->getLoc();

315 Value dstElements = arith::ConstantOp::create(

316 rewriter, loc, asyncCopyOp.getDstElementsAttr());

317 Value originalSrcElement =

318 asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;

319 Value c0Index = arith::ConstantIndexOp::create(rewriter, loc, 0);

320 auto srcElements = arith::SelectOp::create(rewriter, loc, predicate,

321 originalSrcElement, c0Index);

322 auto asyncCopyZeroFillOp = DeviceAsyncCopyOp::create(

323 rewriter, loc, DeviceAsyncTokenType::get(asyncCopyOp.getContext()),

324 asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),

325 asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,

326 UnitAttr());

327 rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);

328 return asyncCopyZeroFillOp;

329}

330

331

332

333

334

335
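/// Applies loop pipelining with the given depth to `forOp` so that the copies
/// into shared memory collected as stage-0 ops are overlapped with compute;
/// returns the diagnostic status together with the (possibly rewritten) loop.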

336static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>

337pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,

338 bool epiloguePeeling) {

341 return std::make_tuple(

343 scf::ForOp());

344 }

345 if (stage0Ops.empty()) {

346 return std::make_tuple(

348 }

349

351 unsigned maxDepth = depth;

352 auto setAnnotation = [&](Operation *op,

354 unsigned iteration) {

356 };

358 [&](scf::ForOp schedulingFor,

359 std::vector<std::pair<Operation *, unsigned>> &ops) {

360 if (schedulingFor != forOp)

361 return;

363 };

364 options.annotateFn = setAnnotation;

365 if (!epiloguePeeling) {

366 options.peelEpilogue = false;

368 }

369

372 bool modifiedIR;

373 FailureOr<scf::ForOp> maybePipelined =

374 pipelineForLoop(rewriter, forOp, options, &modifiedIR);

375 if (succeeded(maybePipelined)) {

377 *maybePipelined);

378 }

379 return std::make_tuple(

380 modifiedIR

383 scf::ForOp());

384}

385

390 rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());

391 if (diag.succeeded()) {

394 }

395 if (diag.isDefiniteFailure()) {

397 if (!getPeelEpilogue()) {

398 diag.attachNote(forOp->getLoc()) << "couldn't predicate?";

399 diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();

400 }

402 }

403

404 return std::move(diag);

405}

406

407

408

409

410

411

412
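/// Helper struct to encode a pair of row/column indexings in the form of
/// affine expressions over the lane id.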

413struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {

416

419

420 void print(llvm::raw_ostream &os) const {

421 os << "- indexing: " << first << ", " << second;

422 }

423};

424

425

426

427
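/// Helper struct that maps a matmul-like linalg op to the corresponding
/// nvgpu.mma.sync operation, together with the per-lane indexings used to
/// load and store its operands.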

428struct MmaSyncBuilder {

429 MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)

430 : b(b), loc(loc), laneId(laneId) {}

431

432 using IndexCalculator =

433 std::function<SmallVector<RowColIndexing>(MLIRContext *)>;

434

435

436

437 FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);

438

439private:

440 struct MmaSyncInfo {

441 std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;

442 std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>>

443 vectorShapes;

444 SmallVector<int64_t> mmaShape;

445 bool tf32Enabled;

446 };

447

448

449

450

451 FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,

453

454

455

456

457

458

459

460

461

462

463

464

465

466

467

468

469

470

474 AffineExpr threadIDInGroup = dim % 4;

476 RowColIndexing{groupID + 8, threadIDInGroup}};

477 }

478

479

480

481

482

483

484 static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {

486 AffineExpr groupID = dim.floorDiv(4);

487 AffineExpr threadIDInGroup = dim % 4;

488 return {RowColIndexing{threadIDInGroup, groupID}};

489 }

490

491

492

493

494

495

496

497 static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {

499 AffineExpr groupID = dim.floorDiv(4);

500 AffineExpr threadIDInGroup = dim % 4;

501 return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},

502 RowColIndexing{groupID, threadIDInGroup * 2 + 1},

503 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},

504 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};

505 }

506

507

508

509

510

511

512

513

514

515

516

517

518

519 static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {

521 AffineExpr groupID = dim.floorDiv(4);

522 AffineExpr threadIDInGroup = dim % 4;

523

524 return {

525 RowColIndexing{groupID, threadIDInGroup * 2 + 0},

526 RowColIndexing{groupID, threadIDInGroup * 2 + 1},

527 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},

528 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1},

529 RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8},

530 RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8},

531 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8},

532 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8}

533 };

534

535 }

536

537

538

539

540

541

542

543

544

545 static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {

547 AffineExpr groupID = dim.floorDiv(4);

548 AffineExpr threadIDInGroup = dim % 4;

549

550 return {

551 RowColIndexing{threadIDInGroup * 2 + 0, groupID},

552 RowColIndexing{threadIDInGroup * 2 + 1, groupID},

553 RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID},

554 RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID}

555 };

556

557 }

558

559

560

561

562

563

564

565

566

567 static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {

569 AffineExpr groupID = dim.floorDiv(4);

570 AffineExpr threadIDInGroup = dim % 4;

571

572 return {

573 RowColIndexing{groupID, threadIDInGroup * 2 + 0},

574 RowColIndexing{groupID, threadIDInGroup * 2 + 1},

575 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},

576 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}

577 };

578

579 }

580

581

582

583

584

585

586

587

588

589 SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,

590 OpFoldResult laneId, Value memref,

592

593

594

595

596

597

598

599 Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,

600 OpFoldResult laneId, Value memref,

603

604

605

606

607 SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,

609 OpFoldResult laneId, Value memref,

611

612

613

614

615

616

617

618 SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(

619 OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,

621

622 OpBuilder &b;

623 Location loc;

624 OpFoldResult laneId;

625};

626

627

628

629

630

631

632
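/// Helper that walks every scalar element of `vector` in linearized order,
/// calling `applyFn` to produce a value for the element and `reduceFn` to
/// consume it; used to assemble and disassemble mma.sync operand vectors.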

633template <typename ApplyFn, typename ReduceFn>

634static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,

635 ReduceFn reduceFn) {

636 VectorType vectorType = cast<VectorType>(vector.getType());

637 auto vectorShape = vectorType.getShape();

639 for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {

642 }

643}

644

648 const IndexCalculator &indexFn) {

649 auto aff = [&](AffineExpr e) {

650 return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);

651 };

652 SmallVector<Value> res;

653 SmallVector<RowColIndexing> indexings = indexFn(b.getContext());

654 for (auto indexing : indexings) {

655 Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));

656 Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));

657 auto load = memref::LoadOp::create(b, loc, memref, ValueRange{row, col});

658 res.push_back(load);

659 }

660 return res;

661}

662

663Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(

666 auto loads = buildMemRefLoads(b, loc, laneId, memref, indexFn);

667

669 auto vt = VectorType::get(vectorShape, elementType);

670 Value res = vector::BroadcastOp::create(b, loc, vt, loads[0]);

672 res,

673

674 [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {

675 return loads[linearIdx];

676 },

677

678 [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {

679 res = vector::InsertOp::create(b, loc, v, res, indices);

680 });

681

682 return res;

683}

684

687 Value memref, const IndexCalculator &indexFn) {

688 auto aff = [&](AffineExpr e) {

689 return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);

690 };

691 SmallVector<Operation *> res;

692 for (auto [indexing, val] :

693 llvm::zip_equal(indexFn(b.getContext()), toStore)) {

694 Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));

695 Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));

696 Operation *store =

697 memref::StoreOp::create(b, loc, val, memref, ValueRange{row, col});

698 res.push_back(store);

699 }

700 return res;

701}

702

706 SmallVector<Value> toStore;

707 toStore.reserve(32);

709 vectorToStore,

710

711 [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {

712 return vector::ExtractOp::create(b, loc, vectorToStore, indices);

713 },

714

715 [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {

716 toStore.push_back(v);

717 });

718 return buildMemRefStores(b, loc, toStore, laneId, memref, indexFn);

719}

720
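/// Packs the lhs/rhs/res vector shapes into a tuple of SmallVector<int64_t>
/// suitable for storing in MmaSyncInfo::vectorShapes.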

728 return std::make_tuple(vlhs, vrhs, vres);

729}

730

731FailureOr<MmaSyncBuilder::MmaSyncInfo>

734

735 Type f16 = b.getF16Type();

736 Type f32 = b.getF32Type();

737 if (opShape == ArrayRef<int64_t>{16, 8, 4} &&

738 elementalTypes == TypeRange{f32, f32, f32}) {

739 return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,

740 &MmaSyncBuilder::m16n8k4tf32Rhs,

741 &MmaSyncBuilder::m16n8k4tf32Res),

743 SmallVector<int64_t>{opShape},

744 true};

745 }

746

747

748 if (opShape == ArrayRef<int64_t>{16, 8, 16} &&

749 elementalTypes == TypeRange{f16, f16, f16}) {

750 return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,

751 &MmaSyncBuilder::m16n8k16f16Rhs,

752 &MmaSyncBuilder::m16n8k16f16Res),

754 SmallVector<int64_t>{opShape},

755 false};

756 }

757 return failure();

758}

759

760FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {

761 Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();

762 Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();

763 Value resMemRef = linalgOp.getDpsInitOperand(0)->get();

764 assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&

765 "expected lhs to be a 2D memref");

766 assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&

767 "expected rhs to be a 2D memref");

768 assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&

769 "expected res to be a 2D memref");

770

771 int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];

772 int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];

773 int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];

774 Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());

775 Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());

776 Type resType = getElementTypeOrSelf(resMemRef.getType());

777

778 FailureOr<MmaSyncInfo> maybeInfo =

779 getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});

780 if (failed(maybeInfo))

781 return failure();

782

783 const MmaSyncInfo &info = *maybeInfo;

784 auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;

785 auto [lhsShape, rhsShape, resShape] = info.vectorShapes;

786 Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,

787 lhsIndexFn, lhsShape);

788 Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,

789 rhsIndexFn, rhsShape);

790 Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,

791 resIndexFn, resShape);

792 res =

793 MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape, info.tf32Enabled);

794 buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,

795 resShape);

797}

798

802 bool fail = true;

803

804 if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {

805

806

807 if (linalgOp.hasUserDefinedMaps()) {

808 return emitSilenceableError()

809 << "only matmul ops with non-extended semantics are supported";

810 }

811 Location loc = linalgOp.getLoc();

812

813 Value laneId = gpu::ThreadIdOp::create(

814 rewriter, loc, rewriter.getIndexType(), gpu::Dimension::x);

815 if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))

816 fail = false;

817 }

818

819 if (fail) {

820 DiagnosedSilenceableFailure diag = emitSilenceableError()

821 << "unsupported target op: " << linalgOp;

822 diag.attachNote(linalgOp->getLoc()) << "target op";

823 return diag;

824 }

825

826 rewriter.eraseOp(linalgOp);

827 return DiagnosedSilenceableFailure::success();

828}

829

830

831

832

833

834

835
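/// Helper that materializes TMA-based copies for Hopper: it creates tensor map
/// descriptors on the host, mbarriers in shared memory, and asynchronous TMA
/// loads synchronized through mbarrier arrive/try_wait operations.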

839

842

843

844

847 gpu::LaunchOp launchOp);

848

849

850

857

858

859

860

865

867

870};

871
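/// If threadIdx.x == 0, issues the TMA requests and the expect-tx barrier
/// arrival; all other threads only wait on the barrier.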

878 Value tidx = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x);

879 Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,

880 tidx, zero);

881

883 loc,

884 cond,

885

888 sizes.reserve(globalDescriptors.size());

889 for (auto [desc, shmem] : llvm::zip_equal(

890 globalDescriptors, sharedMemBuffers)) {

892 sizes.push_back(sz);

893 }

894

895

898 },

899

901

902

905 });

906

907 return loadOps;

908}

909

911 return gpu::AddressSpaceAttr::get(

912 b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());

913

914}

915
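/// Creates an mbarrier group in workgroup (shared) memory and initializes it
/// to expect `numThreads` arrivals.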

918 auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);

919 Value barrier = MBarrierCreateOp::create(

920 rewriter, loc,

921 MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));

923 nvgpu::MBarrierInitOp::create(

928 return cast<TypedValue<MBarrierGroupType>>(barrier);

929}

930
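/// Creates the tma descriptor on the host, right before the gpu.launch, so
/// that it can be used inside the kernel to initiate global-to-shared
/// transfers.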

933 gpu::LaunchOp launchOp) {

935 rewriter.setInsertionPoint(launchOp);

936 Value unrankedMemRef = memref::CastOp::create(

938 UnrankedMemRefType::get(memref.getType().getElementType(),

939 memref.getType().getMemorySpace()),

945

947 Value desc = TmaCreateDescriptorOp::create(

949 TensorMapDescriptorType::get(rewriter.getContext(),

952 TensorMapSwizzleKind::SWIZZLE_NONE,

953 TensorMapL2PromoKind::L2PROMO_NONE,

954 TensorMapOOBKind::OOB_ZERO,

955 TensorMapInterleaveKind::INTERLEAVE_NONE),

956 unrankedMemRef, sizes);

957 return cast<TypedValue<TensorMapDescriptorType>>(desc);

958}

959
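/// Builds a tma async load from global to shared memory synchronized by
/// `barrier`, and returns the number of bytes the barrier should expect.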

968 TmaAsyncLoadOp::create(rewriter, loc, sharedMemref, barrier, globalDesc,

970 loadOps.push_back(loadOp);

976 (sharedMemref.getType().getElementTypeBitWidth() / 8);

978 prodExprInBytes, mixedSizes);

979 return res;

980}

981

984 assert(!mixedSizes.empty() && "expected non-empty sizes");

993 nvgpu::MBarrierArriveExpectTxOp::create(rewriter, loc, barrier, sizeVal, zero,

995}

996

999 Value parity = LLVM::ConstantOp::create(rewriter, loc, i1, 0);

1000

1001

1002

1003 Value ticksBeforeRetry =

1006 nvgpu::MBarrierTryWaitParityOp::create(rewriter, loc, barrier, parity,

1007 ticksBeforeRetry, zero);

1008}

1009

1010

1011

1012

1013

1014

1021

1024 if (copyOps.empty())

1026

1027 auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();

1028 assert(launchOp && "expected launch op");

1029

1030

1032 rewriter.setInsertionPoint(copyOps.front());

1039 launchOp.getBlockSizeZ()});

1040

1043

1047 auto copyOp = cast<linalg::CopyOp>(op);

1048 auto inMemRef =

1049 cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());

1050 assert(inMemRef.getType().getRank() == 2 &&

1051 "expected in to be a 2D memref");

1052

1053

1056 globalDescs.push_back(globalDesc);

1057

1058

1059 auto shmem =

1060 cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());

1061 shmems.push_back(shmem);

1062 }

1063

1064

1066 rewriter.setInsertionPoint(copyOps.front());

1069

1070

1072

1073

1076

1077 return results;

1078}

1079

1083 auto payloadOps = state.getPayloadOps(getTarget());

1084 gpu::LaunchOp commonLaunchOp;

1085 Operation *firstOp, *failingOp;

1086 if (llvm::any_of(payloadOps, [&](Operation *op) {

1087 if (!commonLaunchOp) {

1088 commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();

1089 firstOp = op;

1090 }

1091 bool fail = !commonLaunchOp ||

1092 commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||

1093 !isa<linalg::CopyOp>(op);

1094 if (fail)

1095 failingOp = op;

1096 return fail;

1097 })) {

1099 emitSilenceableError()

1100 << "target ops must be linalg::CopyOp nested under a common "

1101 "gpu.LaunchOp to be rewritten because the tma descriptors need to "

1102 "be created on the host.\nBut got: "

1103 << *firstOp << "\nand " << *failingOp;

1104 return diag;

1105 }

1106

1107

1108 CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));

1109

1111}

1112

1113

1114

1115

1116
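/// Transform dialect extension that registers the NVGPU transform ops and
/// declares the dialects their application may generate.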

1117namespace {

1118class NVGPUTransformDialectExtension

1120public:

1122

1123 NVGPUTransformDialectExtension() {

1124 declareGeneratedDialect<arith::ArithDialect>();

1125 declareGeneratedDialect<affine::AffineDialect>();

1126 declareGeneratedDialect<nvgpu::NVGPUDialect>();

1127 declareGeneratedDialect<NVVM::NVVMDialect>();

1128 declareGeneratedDialect<vector::VectorDialect>();

1129 registerTransformOps<

1130#define GET_OP_LIST

1131#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"

1132 >();

1133 }

1134};

1135}

1136

1137#define GET_OP_CLASSES

1138#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"

1139

1140void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry &registry) {

1141 registry.addExtensions<NVGPUTransformDialectExtension>();

1142}


Given a list of lists of parsed operands, populates uniqueOperands with unique operands.


Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...

static std::string diag(const llvm::Value &value)

constexpr int kWgmmaSizeM

M size of wgmma.mma_async instruction.

static Attribute getSharedAddressSpaceAttribute(OpBuilder &b)

Definition NVGPUTransformOps.cpp:910

static bool hasDefaultMemorySpace(BaseMemRefType type)

Returns true if the given type has the default memory space.

Definition NVGPUTransformOps.cpp:135

static LogicalResult collectStage0PipeliningOps(scf::ForOp forOp, llvm::SmallPtrSet< Operation *, 16 > &ops)

Populate ops with the set of operations that belong to the stage 0 of the pipelined version of the gi...

Definition NVGPUTransformOps.cpp:192

static std::tuple< DiagnosedSilenceableFailure, scf::ForOp > pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth, bool epiloguePeeling)

Applies loop pipelining with the given depth to the given loop so that copies into the shared memory ...

Definition NVGPUTransformOps.cpp:337

static bool isStoreToShared(Operation *op, Value v)

Returns true if the operation is storing the given value into shared memory.

Definition NVGPUTransformOps.cpp:162

static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn, ReduceFn reduceFn)

Helper functions to create customizable load and stores operations.

Definition NVGPUTransformOps.cpp:634

static bool hasSharedMemorySpace(BaseMemRefType type)

Returns true if the given type has the shared (workgroup) memory space.

Definition NVGPUTransformOps.cpp:140

static bool isLoadFromGlobalStoredToShared(Operation *op)

Returns true if the operation is a load from the default memory space the result of which is only sto...

Definition NVGPUTransformOps.cpp:174

static std::tuple< SmallVector< int64_t >, SmallVector< int64_t >, SmallVector< int64_t > > makeVectorShapes(ArrayRef< int64_t > lhs, ArrayRef< int64_t > rhs, ArrayRef< int64_t > res)

Definition NVGPUTransformOps.cpp:723

static void getPipelineStages(scf::ForOp forOp, std::vector< std::pair< Operation *, unsigned > > &opsWithPipelineStages, unsigned depth, llvm::SmallPtrSetImpl< Operation * > &stage0Ops)

Hook for the loop pipeliner that populates ops with the stage information as follows:

Definition NVGPUTransformOps.cpp:263

static void setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op, scf::PipeliningOption::PipelinerPart part, unsigned iteration, unsigned depth)

Hook for the loop pipeliner that sets the "num groups in flight" attribute of async wait operations c...

Definition NVGPUTransformOps.cpp:229

static Operation * replaceOpWithPredicatedOp(RewriterBase &rewriter, Operation *op, Value predicate)

Hook for the loop pipeliner.

Definition NVGPUTransformOps.cpp:294

static Value getValueLoadedFromGlobal(Operation *op)

Returns the value produced by a load from the default memory space.

Definition NVGPUTransformOps.cpp:149

static llvm::ManagedStatic< PassManagerOptions > options

static std::optional< VectorShape > vectorShape(Type type)

#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)

Base type for affine expression.

AffineExpr floorDiv(uint64_t v) const

Attributes are known-constant values of operations.

This class provides a shared interface for ranked and unranked memref types.

Attribute getMemorySpace() const

Returns the memory space in which data referred to by this memref resides.

unsigned getMemorySpaceAsInt() const

[deprecated] Returns the memory space in old raw integer representation.

The result of a transform IR operation application.

static DiagnosedSilenceableFailure success()

Constructs a DiagnosedSilenceableFailure in the success state.

static DiagnosedSilenceableFailure definiteFailure()

Constructs a DiagnosedSilenceableFailure in the failure state.

The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.

void addExtensions()

Add the given extensions to the registry.

Conversion from types to the LLVM IR dialect.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

This is a builder type that keeps local references to arguments.

Builder & setMemorySpace(Attribute newMemorySpace)

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

This class represents a single result from folding an operation.

Operation is the basic unit of execution within MLIR.

Block * getBlock()

Returns the operation block that contains this operation.

OpTy getParentOfType()

Return the closest surrounding parent operation that is of type 'OpTy'.

This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements...

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

bool isInteger() const

Return true if this is an integer type (with the specified width).

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

user_range getUsers() const

bool hasOneUse() const

Returns true if this value has exactly one use.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)

A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...

void push_back(Operation *op)

Appends an element to the list.

Base class for extensions of the Transform dialect that supports injecting operations into the Transf...

Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...

This is a special rewriter to be used in transform op implementations, providing additional helper fu...

The state maintained across applications of various ops implementing the TransformOpInterface.

auto getPayloadOps(Value value) const

Returns an iterator that enumerates all ops that the given transform IR value corresponds to.

OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)

Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...

SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)

Return the dimensions of the given memref value.

void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)

Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.

MemRefType getMBarrierMemrefType(MLIRContext *context, MBarrierGroupType barrierType)

Return the memref type that can be used to represent an mbarrier object.

void registerTransformDialectExtension(DialectRegistry &registry)

void createAsyncGroups(RewriterBase &rewriter, Operation *op, bool bypassL1)

Convert global->shared vector transfers to async device copies.

void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

Populates effects with the memory effects indicating the operation on the given handle value:

void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

Populates effects with the memory effects indicating the access to payload IR resource.

Include the generated interface declarations.

OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)

Convert int64_t to integer attributes of index type and return them as OpFoldResult.

LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})

Fills backwardSlice with the computed backward slice (i.e.

SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)

DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})

Emits a silenceable failure with the given message.

SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)

Given the strides together with a linear index in the dimension space, return the vector-space offset...

bool isMemoryEffectFree(Operation *op)

Returns true if the given operation is free of memory effects.

int64_t computeProduct(ArrayRef< int64_t > basis)

Self-explicit.

llvm::SetVector< T, Vector, Set, N > SetVector

DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})

Emits a definite failure with the given message.

void populateNVGPUToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)

Type getElementTypeOrSelf(Type type)

Return the element type or return the type itself.

std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue

If Ty is mlir::Type this will select Value instead of having a wrapper around it.

const FrozenRewritePatternSet & patterns

void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)

Bind a list of AffineExpr references to SymbolExpr at positions: [0 .

Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)

Converts an OpFoldResult to a Value.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)

These free functions allow clients of the API to not use classes in detail.

int64_t computeSum(ArrayRef< int64_t > basis)

Self-explicit.

void bindSymbolsList(MLIRContext *ctx, MutableArrayRef< AffineExprTy > exprs)

Helper to create the tma operations corresponding to linalg::CopyOp.

Definition NVGPUTransformOps.cpp:1015

SmallVector< Operation * > rewrite(ArrayRef< Operation * > copyOps)

Definition NVGPUTransformOps.cpp:1022

CopyBuilder(RewriterBase &rewriter, Location loc)

Definition NVGPUTransformOps.cpp:1016

void buildBarrierArriveTx(TypedValue< MBarrierGroupType > barrier, ArrayRef< OpFoldResult > sizes)

Definition NVGPUTransformOps.cpp:982

OpFoldResult buildTmaAsyncLoad(TypedValue< TensorMapDescriptorType > globalDesc, TypedValue< MemRefType > sharedMemref, TypedValue< MBarrierGroupType > barrier, SmallVectorImpl< Operation * > &loadOps)

Build a tma load from global memory to shared memory using barrier to synchronize.

Definition NVGPUTransformOps.cpp:961

RewriterBase & rewriter

Definition NVGPUTransformOps.cpp:868

TypedValue< TensorMapDescriptorType > buildGlobalMemRefDescriptor(TypedValue< MemRefType > memref, gpu::LaunchOp launchOp)

Create tma descriptor op to initiate transfer from global to shared memory.

Definition NVGPUTransformOps.cpp:932

void buildTryWaitParity(TypedValue< MBarrierGroupType > barrier)

Definition NVGPUTransformOps.cpp:997

TypedValue< MBarrierGroupType > buildAndInitBarrierInSharedMemory(OpFoldResult numThreads)

Definition NVGPUTransformOps.cpp:917

SmallVector< Operation * > buildPredicateLoadsOnThread0(ArrayRef< TypedValue< TensorMapDescriptorType > > globalDescriptors, ArrayRef< TypedValue< MemRefType > > sharedMemBuffers, TypedValue< MBarrierGroupType > barrier)

If threadIdx.x == 0 does TMA request + wait, else just wait.

Definition NVGPUTransformOps.cpp:872

Location loc

Definition NVGPUTransformOps.cpp:869

HopperBuilder(RewriterBase &rewriter, Location loc)

Definition NVGPUTransformOps.cpp:837

Helper struct to provide a simple mapping from matmul operations to the corresponding mma....

Definition NVGPUTransformOps.cpp:428

MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)

Definition NVGPUTransformOps.cpp:429

std::function< SmallVector< RowColIndexing >(MLIRContext *)> IndexCalculator

Definition NVGPUTransformOps.cpp:432

FailureOr< Operation * > buildMmaSync(LinalgOp linalgOp)

Create the mma.sync operation corresponding to linalgOp along with all the supporting load/store and ...

Definition NVGPUTransformOps.cpp:760

Helper struct to encode a pair of row/column indexings in the form of affine expressions.

Definition NVGPUTransformOps.cpp:413

AffineExpr col() const

Definition NVGPUTransformOps.cpp:418

RowColIndexing(AffineExpr row, AffineExpr col)

Definition NVGPUTransformOps.cpp:414

void print(llvm::raw_ostream &os) const

Definition NVGPUTransformOps.cpp:420

AffineExpr row() const

Definition NVGPUTransformOps.cpp:417

Options to dictate how loops should be pipelined.