MLIR: lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp Source File
32 #include "llvm/ADT/ArrayRef.h"
33
34 using namespace mlir;
39
40 #define DEBUG_TYPE "nvgpu-transforms"
41 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
42 #define DBGSNL() (llvm::dbgs() << "\n")
43 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
44 
45 /// M size of wgmma.mma_async instruction.
46 constexpr int kWgmmaSizeM = 64;
47 
48 
49 void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
50     TypeConverter &typeConverter, RewritePatternSet &patterns) {
51   auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
52 
53   // Map gpu.address_space attributes to NVVM address space integer values
54   // when lowering memref types.
55   populateGpuMemorySpaceAttributeConversions(
56       llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned {
57         switch (space) {
58         case gpu::AddressSpace::Global:
59           return static_cast<unsigned>(
60               NVVM::NVVMMemorySpace::kGlobalMemorySpace);
61         case gpu::AddressSpace::Workgroup:
62           return static_cast<unsigned>(
63               NVVM::NVVMMemorySpace::kSharedMemorySpace);
64         case gpu::AddressSpace::Private:
65           return 0;
66         }
67         llvm_unreachable("unknown address space enum value");
68         return 0;
69       });
70   llvmTypeConverter.addConversion(
71       [&](nvgpu::DeviceAsyncTokenType type) -> Type {
72         return llvmTypeConverter.convertType(
73             IntegerType::get(type.getContext(), 32));
74       });
75   llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
76     return llvmTypeConverter.convertType(
77         IntegerType::get(type.getContext(), 64));
78   });
79   llvmTypeConverter.addConversion(
80       [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
81         Type elemType = type.getFragmented().getElementType();
82         int64_t sizeM = type.getFragmented().getDimSize(0);
83         int64_t sizeN = type.getFragmented().getDimSize(1);
84 
85         unsigned numMembers;
86         if (elemType.isF32() || elemType.isInteger(32))
87           numMembers = sizeN / 2;
88         else if (elemType.isF16())
89           numMembers = sizeN / 4;
90         else
91           llvm_unreachable("unsupported type for warpgroup accumulator");
92 
93         SmallVector<Type> innerStructBody;
94         for (unsigned i = 0; i < numMembers; i++)
95           innerStructBody.push_back(elemType);
96         auto innerStructType = LLVM::LLVMStructType::getLiteral(
97             type.getContext(), innerStructBody);
98 
99         SmallVector<Type> structBody;
100         for (int i = 0; i < sizeM; i += kWgmmaSizeM)
101           structBody.push_back(innerStructType);
102 
103         auto convertedType =
104             LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
105         return llvmTypeConverter.convertType(convertedType);
106       });
107   llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
108     return llvmTypeConverter.convertType(
109         nvgpu::getMBarrierMemrefType(type.getContext(), type));
110   });
111   llvmTypeConverter.addConversion(
112       [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
113         return llvmTypeConverter.convertType(
114             IntegerType::get(type.getContext(), 64));
115       });
116   llvmTypeConverter.addConversion(
117       [&](nvgpu::TensorMapDescriptorType type) -> Type {
118         return LLVM::LLVMPointerType::get(type.getContext());
119       });
120   populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
121 }
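The nested-struct layout produced by the warpgroup-accumulator conversion above can be reasoned about with plain arithmetic. The following standalone sketch (not part of the file; the helper name is hypothetical, and kWgmmaSizeM is assumed to be 64, the M tile of wgmma.mma_async) mirrors that logic:

#include <cassert>
#include <cstdint>
#include <utility>

// Hypothetical helper mirroring the WarpgroupAccumulatorType conversion above:
// returns {number of inner structs, members per inner struct} for a fragmented
// accumulator of shape sizeM x sizeN.
std::pair<int64_t, unsigned> accumulatorStructShape(int64_t sizeM,
                                                    int64_t sizeN, bool isF16) {
  constexpr int64_t kWgmmaSizeM = 64; // Assumed value of the constant above.
  // f32/i32 accumulators keep sizeN/2 members per struct, f16 keeps sizeN/4.
  unsigned numMembers = isF16 ? sizeN / 4 : sizeN / 2;
  // One inner struct per kWgmmaSizeM rows, matching the loop above.
  int64_t numInnerStructs = (sizeM + kWgmmaSizeM - 1) / kWgmmaSizeM;
  return {numInnerStructs, numMembers};
}

int main() {
  // A 128x128 f32 accumulator lowers to 2 structs of 64 f32 members each.
  auto [structs, members] = accumulatorStructShape(128, 128, /*isF16=*/false);
  assert(structs == 2 && members == 64);
  return 0;
}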
122
123 LogicalResult
124 transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
125 transform::TypeConverterBuilderOpInterface builder) {
126 if (builder.getTypeConverterType() != "LLVMTypeConverter")
127 return emitOpError("expected LLVMTypeConverter");
128 return success();
129 }
130
131 //===----------------------------------------------------------------------===//
132 // CreateAsyncGroupsOp
133 //===----------------------------------------------------------------------===//
134 
135 void transform::CreateAsyncGroupsOp::getEffects(
136     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
137   transform::consumesHandle(getTargetMutable(), effects);
138   transform::producesHandle(getOperation()->getOpResults(), effects);
139   transform::modifiesPayload(effects);
140 }
141
148 }
149
150
151
152
153
154 /// Returns true if the given type has the default memory space.
155 static bool hasDefaultMemorySpace(BaseMemRefType type) {
156   return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0;
157 }
158 
159 /// Returns true if the given type has the shared (workgroup) memory space.
160 static bool hasSharedMemorySpace(BaseMemRefType type) {
161   auto space =
162       dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
163   return space &&
164          space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
165 }
166 
167 /// Returns the value produced by a load from the default memory space, or
168 /// null if the operation is not such a load.
169 static Value getValueLoadedFromGlobal(Operation *op) {
170 
171   auto load = dyn_cast<vector::TransferReadOp>(op);
172   if (!load)
173     return nullptr;
174 
175   auto loadType = dyn_cast<MemRefType>(load.getBase().getType());
176   if (!loadType || !hasDefaultMemorySpace(loadType))
177     return nullptr;
178   return load;
179 }
180 
181 /// Returns true if the operation is storing the given value into shared
182 /// memory.
183 static bool isStoreToShared(Operation *op, Value v) {
184   auto store = dyn_cast<vector::TransferWriteOp>(op);
185   if (!store || store.getVector() != v)
186     return false;
187 
188   auto storeType = dyn_cast<MemRefType>(store.getBase().getType());
189   return storeType && hasSharedMemorySpace(storeType);
190 }
191 
192 /// Returns true if the operation is a load from the default memory space the
193 /// result of which is only stored into the shared memory space.
194 static bool isLoadFromGlobalStoredToShared(Operation *op) {
195   Value loaded = getValueLoadedFromGlobal(op);
196   if (!loaded || !loaded.hasOneUse())
197     return false;
198 
199   return isStoreToShared(*loaded.getUsers().begin(), loaded);
200 }
201 
202 /// Populate `ops` with the set of operations that belong to the stage 0 of
203 /// the pipelined version of the given loop: async device copies from global
204 /// to shared memory, the commit-group ops grouping them, the barriers
205 /// adjacent to them, and the global-to-shared transfers they pipeline.
206 
211 static LogicalResult
212 collectStage0PipeliningOps(scf::ForOp forOp,
213                            llvm::SmallPtrSet<Operation *, 16> &ops) {
214 
215   llvm::SmallPtrSet<Operation *, 16> barriers;
216   for (Operation &op : *forOp.getBody()) {
217     // Bail out on nested regions for now.
218     if (op.getNumRegions() > 0)
219       return failure();
220 
221     if (isa<gpu::BarrierOp>(op)) {
222       barriers.insert(&op);
223       continue;
224     }
225 
226     if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
227       ops.insert(&op);
228       ops.insert(std::make_move_iterator(barriers.begin()),
229                  std::make_move_iterator(barriers.end()));
230       assert(barriers.empty() &&
231              "expected to have moved the barriers into another set");
232       continue;
233     }
234 
235     if (isLoadFromGlobalStoredToShared(&op)) {
236       ops.insert(&op);
237       continue;
238     }
239   }
240 
241   return success();
242 }
243 
244 /// Hook for the loop pipeliner that sets the "number of groups in flight"
245 /// attribute of async wait operations corresponding to pipelined copies,
246 /// depending on the pipelining part (kernel or epilogue) and iteration.
247 
248 static void
249 setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
250                            scf::PipeliningOption::PipelinerPart part,
251                            unsigned iteration, unsigned depth) {
252   // Based on the order copies are issued, the number of groups that can stay
253   // in flight differs between the kernel and the epilogue.
254   auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
255   if (!waitOp || waitOp.getNumGroups())
256     return;
257 
258   int numGroupInFlight = 0;
259   if (part == scf::PipeliningOption::PipelinerPart::Kernel ||
260       part == scf::PipeliningOption::PipelinerPart::Prologue) {
261     numGroupInFlight = depth - 1;
262   } else {
263     // Epilogue: groups are drained one by one, so fewer of them remain in
264     // flight at each successive iteration.
268     numGroupInFlight = depth - 1 - iteration;
269   }
270   waitOp.setNumGroups(numGroupInFlight);
271 }
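To make the arithmetic above concrete, here is a small standalone illustration (not part of the file) assuming a pipeline depth of 3: the kernel keeps depth - 1 = 2 copy groups in flight, while epilogue iteration i only waits until depth - 1 - i groups remain.

#include <cstdio>

// Toy illustration of the wait-group bookkeeping used above, assuming a
// pipeline depth of 3 and an epilogue that drains one group per iteration.
int main() {
  const unsigned depth = 3;
  std::printf("kernel: wait with numGroups = %u\n", depth - 1);
  for (unsigned iteration = 0; iteration < depth - 1; ++iteration)
    std::printf("epilogue iteration %u: wait with numGroups = %u\n", iteration,
                depth - 1 - iteration);
  return 0;
}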
272 
273 /// Hook for the loop pipeliner that populates `opsWithPipelineStages` with
274 /// the stage information as follows:
275 ///
276 ///   - operations in `stage0Ops` and their backward dependencies within the
277 ///     loop body are assigned stage 0;
278 ///   - all remaining operations (except the terminator) are assigned stage
279 ///     `depth`.
280 
283 static void getPipelineStages(
284     scf::ForOp forOp,
285     std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
286     unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
287   SetVector<Operation *> dependencies;
288   BackwardSliceOptions options([&](Operation *visited) {
289     return visited->getBlock() == forOp.getBody();
290   });
291   options.inclusive = true;
292   for (Operation &op : forOp.getBody()->getOperations()) {
293     if (stage0Ops.contains(&op)) {
294       LogicalResult result = getBackwardSlice(&op, &dependencies, options);
295       assert(result.succeeded() && "expected a backward slice");
296       (void)result;
297     }
298   }
299 
300   for (Operation &op : forOp.getBody()->getOperations()) {
301     if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
302       opsWithPipelineStages.emplace_back(&op, depth);
303   }
304   for (Operation &op : forOp.getBody()->getOperations()) {
305     if (dependencies.contains(&op))
306       opsWithPipelineStages.emplace_back(&op, 0);
307   }
308 }
309 
310 /// Hook for the loop pipeliner. Replaces `op` with a predicated version and
311 /// returns the resulting operation, or `op` itself when predication is not
312 /// necessary. Returns nullptr if predication is needed but not supported.
313 
314 static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
315                                             Operation *op, Value predicate) {
316   // Ops that do not touch memory, barriers and commit/wait-group ops do not
317   // need predication: executing them in all iterations is safe.
318 
319   if (isMemoryEffectFree(op) ||
320       isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
321           nvgpu::DeviceAsyncWaitOp>(op)) {
322     return op;
323   }
324 
325   // Only async copies can be predicated below.
326   auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
327   if (!asyncCopyOp)
328     return nullptr;
329 
330   // Predicate the copy by setting its number of source elements to zero when
331   // the predicate is false: the copy then only zero-fills the destination
332   // instead of reading from the source.
335   Location loc = asyncCopyOp->getLoc();
336   Value dstElements =
337       rewriter.create<arith::ConstantOp>(loc, asyncCopyOp.getDstElementsAttr());
338   Value originalSrcElement =
339       asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
340   Value c0Index = rewriter.create<arith::ConstantIndexOp>(loc, 0);
341   auto srcElements = rewriter.create<arith::SelectOp>(
342       loc, predicate, originalSrcElement, c0Index);
343   auto asyncCopyZeroFillOp = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
344       loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
345       asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
346       asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
347       UnitAttr());
348   rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
349   return asyncCopyZeroFillOp;
350 }
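The predication trick above reduces to simple arithmetic: when the predicate is false, the copy is still issued but with zero source elements, so it degenerates into a zero-fill of the destination. A minimal standalone sketch (hypothetical helper, not part of the file):

#include <cassert>
#include <cstdint>

// Hypothetical scalar model of the predicated async copy above: a disabled
// copy reads zero elements from the source and only zero-fills the destination.
int64_t effectiveSrcElements(bool predicate, int64_t srcElements) {
  return predicate ? srcElements : 0;
}

int main() {
  assert(effectiveSrcElements(true, 8) == 8);  // Enabled: copy 8 elements.
  assert(effectiveSrcElements(false, 8) == 0); // Disabled: pure zero-fill.
  return 0;
}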
351 
352 /// Applies loop pipelining with the given depth to the given loop so that
353 /// copies into the shared memory are pipelined. Doesn't affect other loops.
354 /// Returns a pair containing the error state and the pipelined op, the latter
355 /// being null in case of any failure.
356 static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
357 
358 pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
359                         bool epiloguePeeling) {
360   llvm::SmallPtrSet<Operation *, 16> stage0Ops;
361   if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) {
362     return std::make_tuple(
363         emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
364         scf::ForOp());
365   }
366   if (stage0Ops.empty()) {
367     return std::make_tuple(
368         emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
369   }
370 
371   scf::PipeliningOption options;
372   unsigned maxDepth = depth;
373   auto setAnnotation = [&](Operation *op,
374                            scf::PipeliningOption::PipelinerPart part,
375                            unsigned iteration) {
376     return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
377   };
378   options.getScheduleFn =
379       [&](scf::ForOp schedulingFor,
380           std::vector<std::pair<Operation *, unsigned>> &ops) {
381         if (schedulingFor != forOp)
382           return;
383         return getPipelineStages(forOp, ops, maxDepth, stage0Ops);
384       };
385   options.annotateFn = setAnnotation;
386   if (!epiloguePeeling) {
387     options.peelEpilogue = false;
388     options.predicateFn = replaceOpWithPredicatedOp;
389   }
390 
391   OpBuilder::InsertionGuard guard(rewriter);
392   rewriter.setInsertionPoint(forOp);
393   bool modifiedIR;
394   FailureOr<scf::ForOp> maybePipelined =
395       pipelineForLoop(rewriter, forOp, options, &modifiedIR);
396   if (succeeded(maybePipelined)) {
397     return std::make_tuple(DiagnosedSilenceableFailure::success(),
398                            *maybePipelined);
399   }
400   return std::make_tuple(
401       modifiedIR
402           ? emitDefiniteFailure(forOp, "irreversible pipelining failure")
403           : emitSilenceableFailure(forOp, "pipelining preconditions failed"),
404       scf::ForOp());
405 }
406 
407 DiagnosedSilenceableFailure transform::PipelineSharedMemoryCopiesOp::applyToOne(
408     TransformRewriter &rewriter, scf::ForOp forOp,
409     ApplyToEachResultList &results, TransformState &state) {
410   auto [diag, pipelined] = pipelineForSharedCopies(
411       rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
412   if (diag.succeeded()) {
413     results.push_back(pipelined);
414     return DiagnosedSilenceableFailure::success();
415   }
416   if (diag.isDefiniteFailure()) {
417     auto diag = emitDefiniteFailure() << "irreversible pipelining failure";
418     if (!getPeelEpilogue()) {
419       diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
420       diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
421     }
422     return diag;
423   }
424 
425   return std::move(diag);
426 }
427 
428 //===----------------------------------------------------------------------===//
429 // RewriteMatmulAsMmaSyncOp
430 //===----------------------------------------------------------------------===//
431 
432 /// Helper struct to encode a pair of row/column indexings in the form of
433 /// affine expressions.
434 struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
435   RowColIndexing(AffineExpr row, AffineExpr col)
436       : std::pair<AffineExpr, AffineExpr>(row, col) {}
437 
438   AffineExpr row() const { return first; };
439   AffineExpr col() const { return second; };
440 
441   void print(llvm::raw_ostream &os) const {
442     os << "- indexing: " << first << ", " << second;
443   }
444 };
445 
446 /// Helper struct to provide a simple mapping from matmul operations to the
447 /// corresponding mma.sync operation.
448 
449 struct MmaSyncBuilder {
450   MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
451       : b(b), loc(loc), laneId(laneId) {}
452 
453   using IndexCalculator =
454       std::function<SmallVector<RowColIndexing>(MLIRContext *)>;
455 
456   /// Create the mma.sync operation corresponding to `linalgOp` along with all
457   /// the supporting load/store and vector operations.
458   FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);
459 
460 private:
461   struct MmaSyncInfo {
462     std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
463     std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>>
464         vectorShapes;
465     SmallVector<int64_t> mmaShape;
466     bool tf32Enabled;
467   };
468 
469   /// Return the index calculators and vector/mma shapes for the given
470   /// operation shape and elemental types, or failure if the combination is
471   /// not supported.
472   FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
473                                              TypeRange elementalTypes);
474 
475
476
477
478
479
480
481
482
483
484
485 
486   /// The layout of the A (lhs) fragment for tf32 mma.m16n8k4, as documented
487   /// in the PTX ISA for mma.sync.
492   static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
493     auto dim = getAffineDimExpr(0, ctx);
494     AffineExpr groupID = dim.floorDiv(4);
495     AffineExpr threadIDInGroup = dim % 4;
496     return {RowColIndexing{groupID, threadIDInGroup},
497             RowColIndexing{groupID + 8, threadIDInGroup}};
498   }
499 
500   /// The layout of the B (rhs) fragment for tf32 mma.m16n8k4, as documented
501   /// in the PTX ISA for mma.sync.
505   static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
506     auto dim = getAffineDimExpr(0, ctx);
507     AffineExpr groupID = dim.floorDiv(4);
508     AffineExpr threadIDInGroup = dim % 4;
509     return {RowColIndexing{threadIDInGroup, groupID}};
510   }
511 
512   /// The layout of the C/D (accumulator) fragment for tf32 mma.m16n8k4, as
513   /// documented in the PTX ISA for mma.sync.
518   static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
519     auto dim = getAffineDimExpr(0, ctx);
520     AffineExpr groupID = dim.floorDiv(4);
521     AffineExpr threadIDInGroup = dim % 4;
522     return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
523             RowColIndexing{groupID, threadIDInGroup * 2 + 1},
524             RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
525             RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
526   }
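The affine expressions above encode which rows and columns of the tile each lane owns. A plain-integer illustration (not part of the file) of the same mapping for the 4-register accumulator fragment:

#include <cstdio>

// Toy evaluation of the accumulator indexing above: for a given lane id,
// groupID = lane / 4 and threadIDInGroup = lane % 4, and the lane owns the
// four (row, col) positions listed in m16n8k4tf32Res.
int main() {
  const int lane = 5; // Any lane in [0, 32).
  const int groupID = lane / 4;
  const int threadIDInGroup = lane % 4;
  const int rows[4] = {groupID, groupID, groupID + 8, groupID + 8};
  const int cols[4] = {threadIDInGroup * 2 + 0, threadIDInGroup * 2 + 1,
                       threadIDInGroup * 2 + 0, threadIDInGroup * 2 + 1};
  for (int i = 0; i < 4; ++i)
    std::printf("lane %d owns (%d, %d)\n", lane, rows[i], cols[i]);
  return 0;
}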
527 
528   /// The layout of the A (lhs) fragment for f16 mma.m16n8k16, as documented
529   /// in the PTX ISA for mma.sync.
540   static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
541     auto dim = getAffineDimExpr(0, ctx);
542     AffineExpr groupID = dim.floorDiv(4);
543     AffineExpr threadIDInGroup = dim % 4;
544
545 return {
546 RowColIndexing{groupID, threadIDInGroup * 2 + 0},
547 RowColIndexing{groupID, threadIDInGroup * 2 + 1},
548 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
549 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1},
550 RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8},
551 RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8},
552 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8},
553 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8}
554 };
555
556 }
557 
558   /// The layout of the B (rhs) fragment for f16 mma.m16n8k16, as documented
559   /// in the PTX ISA for mma.sync.
566   static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
567     auto dim = getAffineDimExpr(0, ctx);
568     AffineExpr groupID = dim.floorDiv(4);
569     AffineExpr threadIDInGroup = dim % 4;
570
571 return {
572 RowColIndexing{threadIDInGroup * 2 + 0, groupID},
573 RowColIndexing{threadIDInGroup * 2 + 1, groupID},
574 RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID},
575 RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID}
576 };
577
578 }
579 
580   /// The layout of the C/D (accumulator) fragment for f16 mma.m16n8k16, as
581   /// documented in the PTX ISA for mma.sync.
588   static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
589     auto dim = getAffineDimExpr(0, ctx);
590     AffineExpr groupID = dim.floorDiv(4);
591     AffineExpr threadIDInGroup = dim % 4;
592
593 return {
594 RowColIndexing{groupID, threadIDInGroup * 2 + 0},
595 RowColIndexing{groupID, threadIDInGroup * 2 + 1},
596 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
597 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}
598 };
599
600 }
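The fragment sizes follow from a 32-lane warp splitting each tile evenly. A quick standalone arithmetic check (not part of the file, assuming the warp size of 32) that matches the number of RowColIndexing entries listed above:

#include <cassert>

// Sanity arithmetic for the fragment sizes above: a warp of 32 lanes splits
// an MxN tile evenly, so each lane owns (M * N) / 32 elements.
constexpr int elementsPerLane(int m, int n) { return m * n / 32; }

int main() {
  assert(elementsPerLane(16, 16) == 8); // f16 lhs fragment: 8 entries above.
  assert(elementsPerLane(16, 8) == 4);  // accumulator fragment: 4 entries.
  assert(elementsPerLane(4, 8) == 1);   // tf32 m16n8k4 rhs fragment: 1 entry.
  return 0;
}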
601 
602   //===--------------------------------------------------------------------===//
603   // Helper functions to create customizable load and stores operations. The
604   // specific shapes of each mma.sync operation are passed via the
605   // IndexCalculator callback.
606   //===--------------------------------------------------------------------===//
607 
608   /// Build a list of memref.load operations indexed at `(row, col)` indices
609   /// computed by `indexFn` for the current lane.
610   SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
611                                       OpFoldResult laneId, Value memref,
612                                       const IndexCalculator &indexFn);
613 
614   /// Perform a distributed load of a vector operand of `vectorShape`: each
615   /// lane loads the portion of the data indexed by `indexFn`, and the pieces
616   /// are assembled into a single vector value.
617 
620   Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
621                                       OpFoldResult laneId, Value memref,
622                                       IndexCalculator indexFn,
623                                       ArrayRef<int64_t> vectorShape);
624 
625   /// Build a list of memref.store operations indexed at `(row, col)` indices
626   /// computed by `indexFn` for the current lane.
627   SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
628                                              ValueRange toStore,
629                                              OpFoldResult laneId, Value memref,
631                                              const IndexCalculator &indexFn);
632 
633   /// Perform a distributed store of the vector value `vectorToStore` of shape
634   /// `vectorShape`: each lane stores the portion of the data indexed by
635   /// `indexFn`.
639   SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
640       OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
641       Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);
642 
643   OpBuilder &b;
644   Location loc;
645   OpFoldResult laneId;
646 };
647 
648 //===----------------------------------------------------------------------===//
649 // MmaSyncBuilder implementation.
650 //===----------------------------------------------------------------------===//
651 
652 /// Applies `applyFn` to each element of the given vector and feeds the result
653 /// to `reduceFn`, together with the element's linear and delinearized indices.
654 template <typename ApplyFn, typename ReduceFn>
655 static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
656                                            ReduceFn reduceFn) {
657   VectorType vectorType = cast<VectorType>(vector.getType());
658   auto vectorShape = vectorType.getShape();
659   auto strides = computeStrides(vectorShape);
660   for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
661     auto indices = delinearize(idx, strides);
662     reduceFn(applyFn(vector, idx, indices), idx, indices);
663   }
664 }
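The traversal above walks the vector in linearized order and converts each linear index back to multi-dimensional indices. A standalone sketch of the same stride arithmetic on a plain 2x4 shape (no MLIR dependency, illustrative only):

#include <array>
#include <cstdio>

// Row-major strides of a 2x4 shape are {4, 1}; the linear extent is
// shape[0] * strides[0] = 8, and delinearizing idx recovers (row, col).
int main() {
  const std::array<int, 2> shape = {2, 4};
  const std::array<int, 2> strides = {shape[1], 1};
  for (int idx = 0, e = shape[0] * strides[0]; idx < e; ++idx) {
    int row = idx / strides[0];
    int col = idx % strides[0];
    std::printf("linear %d -> (%d, %d)\n", idx, row, col);
  }
  return 0;
}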
665 
666 SmallVector<Value> MmaSyncBuilder::buildMemRefLoads(
667     OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
668 
669     const IndexCalculator &indexFn) {
670   auto aff = [&](AffineExpr e) {
671     return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
672   };
673   SmallVector<Value> res;
674   SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
675   for (auto indexing : indexings) {
676     Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
677     Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
678     auto load = b.create<memref::LoadOp>(loc, memref, ValueRange{row, col});
679     res.push_back(load);
680   }
681   return res;
682 }
683
684 Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
685     OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
686     IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
687   auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));
688 
689   Type elementType = getElementTypeOrSelf(memref.getType());
690   auto vt = VectorType::get(vectorShape, elementType);
691   Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
692   foreachIndividualVectorElement(
693       res,
694       /*applyFn=*/
695       [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
696         return loads[linearIdx];
697       },
698       /*reduceFn=*/
699       [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
700         res = b.create<vector::InsertOp>(loc, v, res, indices);
701       });
702 
703   return res;
704 }
705 
706 SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
707     OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
708     Value memref, const IndexCalculator &indexFn) {
709   auto aff = [&](AffineExpr e) {
710     return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
711   };
712   SmallVector<Operation *> res;
713   for (auto [indexing, val] :
714        llvm::zip_equal(indexFn(b.getContext()), toStore)) {
715     Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
716     Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
717     Operation *store =
718         b.create<memref::StoreOp>(loc, val, memref, ValueRange{row, col});
719     res.push_back(store);
720   }
721   return res;
722 }
723 
724 SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
725     OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
726     Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
727   SmallVector<Value> toStore;
728   toStore.reserve(32);
729   foreachIndividualVectorElement(
730       vectorToStore,
731       /*applyFn=*/
732       [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
733         return b.create<vector::ExtractOp>(loc, vectorToStore, indices);
734       },
735       /*reduceFn=*/
736       [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
737         toStore.push_back(v);
738       });
739   return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn));
740 }
741 
742 static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
743                   SmallVector<int64_t>>
744 makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
745                  ArrayRef<int64_t> res) {
746   SmallVector<int64_t> vlhs(lhs.begin(), lhs.end());
747   SmallVector<int64_t> vrhs(rhs.begin(), rhs.end());
748   SmallVector<int64_t> vres(res.begin(), res.end());
749   return std::make_tuple(vlhs, vrhs, vres);
750 }
751 
752 FailureOr<MmaSyncBuilder::MmaSyncInfo>
753 MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
754                                     TypeRange elementalTypes) {
755 
756   Type f16 = Float16Type::get(b.getContext());
757   Type f32 = Float32Type::get(b.getContext());
758   if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
759       elementalTypes == TypeRange{f32, f32, f32}) {
760     return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
761                                        &MmaSyncBuilder::m16n8k4tf32Rhs,
762                                        &MmaSyncBuilder::m16n8k4tf32Res),
763                        makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
764                        SmallVector<int64_t>(opShape.begin(), opShape.end()),
765                        /*tf32Enabled=*/true};
766   }
767 
768   // This is the version with f16 accumulation.
769   if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
770       elementalTypes == TypeRange{f16, f16, f16}) {
771     return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
772                                        &MmaSyncBuilder::m16n8k16f16Rhs,
773                                        &MmaSyncBuilder::m16n8k16f16Res),
774                        makeVectorShapes({4, 2}, {2, 2}, {2, 2}),
775                        SmallVector<int64_t>(opShape.begin(), opShape.end()),
776                        /*tf32Enabled=*/false};
777   }
778   return failure();
779 }
780
782 Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
783 Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
784 Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
785 assert(cast(lhsMemRef.getType()).getRank() == 2 &&
786 "expected lhs to be a 2D memref");
787 assert(cast(rhsMemRef.getType()).getRank() == 2 &&
788 "expected rhs to be a 2D memref");
789 assert(cast(resMemRef.getType()).getRank() == 2 &&
790 "expected res to be a 2D memref");
791
792 int64_t m = cast(lhsMemRef.getType()).getShape()[0];
793 int64_t n = cast(rhsMemRef.getType()).getShape()[1];
794 int64_t k = cast(lhsMemRef.getType()).getShape()[1];
798
799 FailureOr maybeInfo =
800 getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
801 if (failed(maybeInfo))
802 return failure();
803
804 MmaSyncInfo info = *maybeInfo;
805 auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
806 auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
807 Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
808 lhsIndexFn, lhsShape);
809 Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
810 rhsIndexFn, rhsShape);
811 Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
812 resIndexFn, resShape);
813 res = b.createnvgpu::MmaSyncOp(loc, lhs, rhs, res, info.mmaShape,
814 info.tf32Enabled);
815 buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
816 resShape);
818 }
819 
820 DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
821     transform::TransformRewriter &rewriter, LinalgOp linalgOp,
822     transform::ApplyToEachResultList &results,
823     transform::TransformState &state) {
824   bool fail = true;
825   // TODO: more robust detection of the matmul, with transposes etc.
826   if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
827     // Reject matmul variants with extended semantics: only the plain op is
828     // supported by this rewrite.
829     if (linalgOp.hasUserDefinedMaps()) {
830       return emitSilenceableError()
831              << "only matmul ops with non-extended semantics are supported";
832     }
833     Location loc = linalgOp.getLoc();
834     // The rewrite is lane-level: index everything by the x thread id.
835     Value laneId = rewriter.create<gpu::ThreadIdOp>(
836         loc, rewriter.getIndexType(), gpu::Dimension::x);
837     if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
838       fail = false;
839   }
840 
841   if (fail) {
842     DiagnosedSilenceableFailure diag = emitSilenceableError()
843                                        << "unsupported target op: " << linalgOp;
844     diag.attachNote(linalgOp->getLoc()) << "target op";
845     return diag;
846   }
847 
848   rewriter.eraseOp(linalgOp);
849   return DiagnosedSilenceableFailure::success();
850 }
851 
852 //===----------------------------------------------------------------------===//
853 // Hopper builders.
854 //===----------------------------------------------------------------------===//
855 
856 /// Helper to create the base Hopper-specific operations that are reused in
857 /// various other places.
858 struct HopperBuilder {
859   HopperBuilder(RewriterBase &rewriter, Location loc)
860       : rewriter(rewriter), loc(loc) {}
861 
862   TypedValue<nvgpu::MBarrierGroupType>
863   buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);
864 
865   /// Create tma descriptor op to initiate transfer from global to shared
866   /// memory. This must be done before the launch op, on the host.
867   TypedValue<nvgpu::TensorMapDescriptorType>
868   buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
869                               gpu::LaunchOp launchOp);
870 
871   /// Build a tma load from global memory to shared memory using `barrier` to
872   /// synchronize. Return the number of bytes that will be transferred.
873   OpFoldResult
874   buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
875                     TypedValue<MemRefType> sharedMemref,
876                     TypedValue<nvgpu::MBarrierGroupType> barrier,
877                     SmallVectorImpl<Operation *> &loadOps);
878 
879   void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier,
880                             ArrayRef<OpFoldResult> sizes);
881 
882   void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier);
883 
884   /// If threadIdx.x == 0 does TMA request + wait, else just wait.
885   /// Return the operations that hold the loaded data.
886   SmallVector<Operation *> buildPredicateLoadsOnThread0(
887       ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
888       ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
889       TypedValue<nvgpu::MBarrierGroupType> barrier);
890 
891   RewriterBase &rewriter;
892   Location loc;
893 };
894 
895 SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
896     ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
897     ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
898     TypedValue<nvgpu::MBarrierGroupType> barrier) {
899   SmallVector<Operation *> loadOps;
900   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
901   Value tidx = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
902   Value cond =
903       rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero);
904 
905   rewriter.create<scf::IfOp>(
906       loc,
907       cond,
908       /*thenBuilder=*/
909       [&](OpBuilder &lb, Location loc) {
910         SmallVector<OpFoldResult> sizes;
911         sizes.reserve(globalDescriptors.size());
912         for (auto [desc, shmem] : llvm::zip_equal(
913                 globalDescriptors, sharedMemBuffers)) {
914           OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
915           sizes.push_back(sz);
916         }
917         // Signal the expected number of bytes once all the loads have been
918         // issued.
919         buildBarrierArriveTx(barrier, sizes);
920         rewriter.create<scf::YieldOp>(loc);
921       },
922       /*elseBuilder=*/
923       [&](OpBuilder &lb, Location loc) {
924         // Other threads issue no loads; they only wait on the barrier, which
925         // happens after this op.
926 
927         rewriter.create<scf::YieldOp>(loc);
928       });
929 
930   return loadOps;
931 }
932 
933 static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
934   return gpu::AddressSpaceAttr::get(
935       b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
936 
937 }
938 
939 TypedValue<nvgpu::MBarrierGroupType>
940 HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
941   auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
942   Value barrier = rewriter.create<nvgpu::MBarrierCreateOp>(
943       loc,
944       nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
945   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
946   rewriter.create<nvgpu::MBarrierInitOp>(
947       loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads),
948       zero, Value());
949   rewriter.create<gpu::BarrierOp>(loc);
950   return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
951 }
952 
953 TypedValue<nvgpu::TensorMapDescriptorType>
954 HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
955                                            gpu::LaunchOp launchOp) {
956   OpBuilder::InsertionGuard guard(rewriter);
957   rewriter.setInsertionPoint(launchOp);
958   Value unrankedMemRef = rewriter.create<memref::CastOp>(
959       loc,
960       UnrankedMemRefType::get(memref.getType().getElementType(),
961                               memref.getType().getMemorySpace()),
962       memref);
963   SmallVector<OpFoldResult> mixedSizes =
964       memref::getMixedSizes(rewriter, loc, memref);
965   SmallVector<Value> sizes =
966       getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);
967 
968   auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
969   Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
970       loc,
971       nvgpu::TensorMapDescriptorType::get(
972           rewriter.getContext(),
973           MemRefType::Builder(memref.getType())
974               .setMemorySpace(sharedMemorySpace),
975           TensorMapSwizzleKind::SWIZZLE_NONE,
976           TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
977           TensorMapInterleaveKind::INTERLEAVE_NONE),
978       unrankedMemRef, sizes);
979   return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
980 }
982 OpFoldResult HopperBuilder::buildTmaAsyncLoad(
983     TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
984     TypedValue<MemRefType> sharedMemref,
985     TypedValue<nvgpu::MBarrierGroupType> barrier,
986     SmallVectorImpl<Operation *> &loadOps) {
988   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
989   Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
990       loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero,
991       Value(), Value());
992   loadOps.push_back(loadOp);
993   auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
994   SmallVector<AffineExpr> symbols(mixedSizes.size());
995   bindSymbolsList(rewriter.getContext(), llvm::MutableArrayRef{symbols});
996   AffineExpr prodExprInBytes =
997       computeProduct(rewriter.getContext(), symbols) *
998       (sharedMemref.getType().getElementTypeBitWidth() / 8);
1000   auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
                                                      prodExprInBytes, mixedSizes);
1001   return res;
1002 }
1003 
1004 void HopperBuilder::buildBarrierArriveTx(
1005     TypedValue<nvgpu::MBarrierGroupType> barrier,
1006     ArrayRef<OpFoldResult> mixedSizes) {
1007   assert(!mixedSizes.empty() && "expected non-empty sizes");
1008   SmallVector<AffineExpr> symbols(mixedSizes.size());
1009   bindSymbolsList(rewriter.getContext(), llvm::MutableArrayRef{symbols});
1010   AffineExpr sumExpr = computeSum(rewriter.getContext(), symbols);
1012   OpFoldResult size =
1013       affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
1014   Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
1015   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1016   rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero,
1017                                                    Value());
1018 }
1019 
1020 void HopperBuilder::buildTryWaitParity(
1021     TypedValue<nvgpu::MBarrierGroupType> barrier) {
1022   Type i1 = rewriter.getI1Type();
1023   Value parity = rewriter.create<LLVM::ConstantOp>(loc, i1, 0);
1024   // 10M is an arbitrary, but large enough, number of cycles before retrying
1025   // the wait; it is not expected to ever be reached in practice.
1026 
1027   Value ticksBeforeRetry =
1028       rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
1029   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1030   rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
1031                                                   ticksBeforeRetry, zero);
1032 }
1033 
1034 //===----------------------------------------------------------------------===//
1035 // RewriteCopyAsTmaOp
1036 //===----------------------------------------------------------------------===//
1037 
1038 /// Helper to create the tma operations corresponding to `linalg::CopyOp`.
1039 struct CopyBuilder : public HopperBuilder {
1040   CopyBuilder(RewriterBase &rewriter, Location loc)
1041       : HopperBuilder(rewriter, loc) {}
1042 
1043   SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps);
1044 };
1045 
1046 SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
1047 
1048   if (copyOps.empty())
1049     return SmallVector<Operation *>();
1050 
1051   auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
1052   assert(launchOp && "expected launch op");
1053 
1054   // Create the barrier in shared memory to synchronize the loads.
1055   OpBuilder::InsertionGuard guard(rewriter);
1056   rewriter.setInsertionPoint(copyOps.front());
1057   AffineExpr bx, by, bz;
1058   bindSymbols(rewriter.getContext(), bx, by, bz);
1059   AffineExpr prod = bx * by * bz;
1060   OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
1061       rewriter, loc, prod,
1062       ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
1063                              launchOp.getBlockSizeZ()});
1064 
1065   TypedValue<nvgpu::MBarrierGroupType> barrier =
1066       buildAndInitBarrierInSharedMemory(numThreads);
1067 
1068   SmallVector<TypedValue<MemRefType>> shmems;
1069   SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs;
1070   for (Operation *op : copyOps) {
1071     auto copyOp = cast<linalg::CopyOp>(op);
1072     auto inMemRef =
1073         cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
1074     assert(inMemRef.getType().getRank() == 2 &&
1075            "expected in to be a 2D memref");
1076 
1077     // Construct the tma descriptor on the host.
1078     TypedValue<nvgpu::TensorMapDescriptorType> globalDesc =
1079         buildGlobalMemRefDescriptor(inMemRef, launchOp);
1080     globalDescs.push_back(globalDesc);
1081 
1082     // Record the shared memory destination of the copy.
1083     auto shmem =
1084         cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
1085     shmems.push_back(shmem);
1086   }
1087 
1088   // Issue the TMA requests on thread 0 and wait on the barrier everywhere.
1089   SmallVector<Operation *> results =
1092       buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);
1093 
1095   buildTryWaitParity(barrier);
1096 
1097   // Erase the original copies now that they have been rewritten.
1098   for (Operation *op : copyOps)
1099     rewriter.eraseOp(op);
1100 
1101   return results;
1102 }
1103 
1104 DiagnosedSilenceableFailure
1105 transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
1106                                      transform::TransformResults &results,
1107                                      transform::TransformState &state) {
1108   auto payloadOps = state.getPayloadOps(getTarget());
1109   gpu::LaunchOp commonLaunchOp;
1110   Operation *firstOp = nullptr, *failingOp = nullptr;
1111   if (llvm::any_of(payloadOps, [&](Operation *op) {
1112         if (!commonLaunchOp) {
1113           commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
1114           firstOp = op;
1115         }
1116         bool fail = !commonLaunchOp ||
1117                     commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
1118                     !isa<linalg::CopyOp>(op);
1119         if (fail)
1120           failingOp = op;
1121         return fail;
1122       })) {
1123     DiagnosedSilenceableFailure diag =
1124         emitSilenceableError()
1125         << "target ops must be linalg::CopyOp nested under a common "
1126            "gpu.LaunchOp to be rewritten because the tma descriptors need to "
1127            "be created on the host.\nBut got: "
1128         << *firstOp << "\nand " << *failingOp;
1129     return diag;
1130   }
1131 
1132   // TODO: more robust detection of the copy, with transposes etc.
1133   CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));
1134 
1135   return DiagnosedSilenceableFailure::success();
1136 }
1137 
1138 //===----------------------------------------------------------------------===//
1139 // Transform op registration
1140 //===----------------------------------------------------------------------===//
1141 
1142 namespace {
1143 class NVGPUTransformDialectExtension
1144     : public transform::TransformDialectExtension<
1145           NVGPUTransformDialectExtension> {
1146 public:
1147   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)
1148 
1149   NVGPUTransformDialectExtension() {
1150     declareGeneratedDialect<arith::ArithDialect>();
1151     declareGeneratedDialect<affine::AffineDialect>();
1152     declareGeneratedDialect<nvgpu::NVGPUDialect>();
1153     declareGeneratedDialect<NVVM::NVVMDialect>();
1154     declareGeneratedDialect<vector::VectorDialect>();
1155     registerTransformOps<
1156 #define GET_OP_LIST
1157 #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
1158         >();
1159   }
1160 };
1161 } // namespace
1162 
1163 #define GET_OP_CLASSES
1164 #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
1165 
1166 void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry &registry) {
1167   registry.addExtensions<NVGPUTransformDialectExtension>();
1168 }
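For downstream users, registering this extension is a one-liner against a DialectRegistry. A minimal sketch follows; the tool wiring is illustrative, while the registration function and registry APIs are the ones declared above:

#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"

// Minimal sketch: make the NVGPU transform ops available in a context so that
// transform scripts using them can be parsed and interpreted.
int main() {
  mlir::DialectRegistry registry;
  mlir::nvgpu::registerTransformDialectExtension(registry);

  mlir::MLIRContext context;
  context.appendDialectRegistry(registry);
  return 0;
}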