MLIR: lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp Source File

#include "llvm/ADT/ArrayRef.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::nvgpu;
using namespace mlir::NVVM;
using namespace mlir::transform;

#define DEBUG_TYPE "nvgpu-transforms"

void ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  // Device-side async tokens cannot be materialized in nvvm. Convert them to a
  // dummy i32 type so that they are easy to drop during conversion.
  populateCommonGPUTypeAndAttributeConversions(llvmTypeConverter);
  llvmTypeConverter.addConversion([&](DeviceAsyncTokenType type) -> Type {
    return llvmTypeConverter.convertType(
        IntegerType::get(type.getContext(), 32));
  });
  llvmTypeConverter.addConversion([&](MBarrierTokenType type) -> Type {
    return llvmTypeConverter.convertType(
        IntegerType::get(type.getContext(), 64));
  });
  llvmTypeConverter.addConversion([&](WarpgroupAccumulatorType type) -> Type {
    Type elemType = type.getFragmented().getElementType();
    int64_t sizeM = type.getFragmented().getDimSize(0);
    int64_t sizeN = type.getFragmented().getDimSize(1);

    unsigned numMembers;
    if (elemType.isF32() || elemType.isInteger(32))
      numMembers = sizeN / 2;
    else if (elemType.isF16())
      numMembers = sizeN / 4;
    else
      llvm_unreachable("unsupported type for warpgroup accumulator");

    SmallVector<Type> innerStructBody;
    for (unsigned i = 0; i < numMembers; i++)
      innerStructBody.push_back(elemType);
    auto innerStructType =
        LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);

    SmallVector<Type> structBody;
    for (int i = 0; i < sizeM; i += kWgmmaSizeM)
      structBody.push_back(innerStructType);

    auto convertedType =
        LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
    return llvmTypeConverter.convertType(convertedType);
  });
  llvmTypeConverter.addConversion([&](MBarrierGroupType type) -> Type {
    return llvmTypeConverter.convertType(
        getMBarrierMemrefType(type.getContext(), type));
  });
  llvmTypeConverter.addConversion(
      [&](WarpgroupMatrixDescriptorType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 64));
      });
  llvmTypeConverter.addConversion([&](TensorMapDescriptorType type) -> Type {
    return LLVM::LLVMPointerType::get(type.getContext());
  });
  populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
}
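
// Illustrative example of the accumulator conversion above (not part of the
// original source): assuming kWgmmaSizeM is 64 (the M size of a single
// wgmma.mma_async instruction), an accumulator fragmented as
// vector<128x64xf32> has sizeM = 128, sizeN = 64 and f32 elements, so each
// inner struct holds 64 / 2 = 32 f32 members and the outer struct holds
// 128 / 64 = 2 such inner structs.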

LogicalResult ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

//===----------------------------------------------------------------------===//
// CreateAsyncGroupsOp
//===----------------------------------------------------------------------===//

void CreateAsyncGroupsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTargetMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}

DiagnosedSilenceableFailure
CreateAsyncGroupsOp::applyToOne(TransformRewriter &rewriter, Operation *target,
                                ApplyToEachResultList &results,
                                TransformState &state) {
  nvgpu::createAsyncGroups(rewriter, target, getBypassL1());
  results.push_back(target);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// PipelineSharedMemoryCopiesOp
//===----------------------------------------------------------------------===//

/// Returns true if the given type has the default memory space.
static bool hasDefaultMemorySpace(BaseMemRefType type) {
  return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0;
}

/// Returns true if the given type has the shared (workgroup) memory space.
static bool hasSharedMemorySpace(BaseMemRefType type) {
  auto space =
      dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
  return space &&
         space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
}
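
// For example (illustrative): a plain memref<128x32xf16> satisfies
// hasDefaultMemorySpace, while a
// memref<128x32xf16, #gpu.address_space<workgroup>> satisfies
// hasSharedMemorySpace.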

/// Returns the value produced by a load from the default memory space.
/// Returns null if the operation is not such a load.
static Value getValueLoadedFromGlobal(Operation *op) {
  // TODO: consider an interface or leveraging the memory effects interface.
  auto load = dyn_cast<vector::TransferReadOp>(op);
  if (!load)
    return nullptr;

  auto loadType = dyn_cast<MemRefType>(load.getBase().getType());
  if (!loadType || !hasDefaultMemorySpace(loadType))
    return nullptr;
  return load;
}

/// Returns true if the operation is storing the given value into shared
/// memory.
static bool isStoreToShared(Operation *op, Value v) {
  // TODO: consider an interface or leveraging the memory effects interface.
  auto store = dyn_cast<vector::TransferWriteOp>(op);
  if (!store || store.getVector() != v)
    return false;

  auto storeType = dyn_cast<MemRefType>(store.getBase().getType());
  return storeType && hasSharedMemorySpace(storeType);
}

/// Returns true if the operation is a load from the default memory space the
/// result of which is only stored into the shared memory space.
static bool isLoadFromGlobalStoredToShared(Operation *op) {
  Value loaded = getValueLoadedFromGlobal(op);
  if (!loaded || !loaded.hasOneUse())
    return false;

  return isStoreToShared(*loaded.getUsers().begin(), loaded);
}
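
// The shape these helpers match is, schematically (illustrative IR; indices,
// masks and permutation maps omitted):
//   %v = vector.transfer_read %global[...]  : memref<...>, vector<...>
//   vector.transfer_write %v, %shared[...]  : vector<...>, memref<..., workgroup>
// where %v has exactly one use.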

/// Populate `ops` with the set of operations that belong to the stage 0 of the
/// pipelined version of the given loop when pipelining copies to shared
/// memory, i.e. the loads from global memory (sync and async) and the barriers
/// associated with the async copies. Fails if the loop contains operations
/// with nested regions.
static LogicalResult
collectStage0PipeliningOps(scf::ForOp forOp,
                           llvm::SmallPtrSet<Operation *, 16> &ops) {

  llvm::SmallPtrSet<Operation *, 4> barriers;
  for (Operation &op : *forOp.getBody()) {
    // Bail on nested ops for now.
    if (op.getNumRegions() > 0)
      return failure();

    if (isa<gpu::BarrierOp>(op)) {
      barriers.insert(&op);
      continue;
    }

    if (isa<DeviceAsyncCopyOp, DeviceAsyncCreateGroupOp>(op)) {
      ops.insert(&op);
      ops.insert(std::make_move_iterator(barriers.begin()),
                 std::make_move_iterator(barriers.end()));
      assert(barriers.empty() &&
             "expected to have moved the barriers into another set");
      continue;
    }

    if (isLoadFromGlobalStoredToShared(&op)) {
      ops.insert(&op);
      continue;
    }
  }

  return success();
}

/// Hook for the loop pipeliner that sets the "num groups in flight" attribute
/// of async wait operations corresponding to pipelined copies.
static void
setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration, unsigned depth) {
  // Based on the order of copies within the loop we need to set the number
  // of copies in flight, unless it is already set.
  auto waitOp = dyn_cast<DeviceAsyncWaitOp>(op);
  if (!waitOp || waitOp.getNumGroups())
    return;

  int numGroupInFlight = 0;
  if (part == scf::PipeliningOption::PipelinerPart::Kernel ||
      part == scf::PipeliningOption::PipelinerPart::Prologue) {
    numGroupInFlight = depth - 1;
  } else {
    // Waits placed in the epilogue drain the remaining in-flight groups: the
    // later the epilogue iteration, the fewer groups are still outstanding.
    numGroupInFlight = depth - 1 - iteration;
  }
  waitOp.setNumGroups(numGroupInFlight);
}
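
// Worked example (illustrative): with depth = 3, waits in the prologue and the
// kernel allow 2 groups in flight; in the epilogue, iteration 0 waits with 2
// groups in flight, iteration 1 with 1, and iteration 2 with 0.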

/// Hook for the loop pipeliner that populates `ops` with the stage information
/// as follows:
///   - operations in `stage0Ops` (typically loads from global memory and the
///     related barriers) and their backward slice within the loop body are at
///     stage 0;
///   - all other operations are at stage `depth`;
///   - the internal order of the pipelined loop preserves the original
///     operation order.
static void getPipelineStages(
    scf::ForOp forOp,
    std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
    unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
  SetVector<Operation *> dependencies;
  BackwardSliceOptions options([&](Operation *visited) {
    return visited->getBlock() == forOp.getBody();
  });
  options.inclusive = true;
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (stage0Ops.contains(&op)) {
      LogicalResult result = getBackwardSlice(&op, &dependencies, options);
      assert(result.succeeded() && "expected a backward slice");
      (void)result;
    }
  }

  for (Operation &op : forOp.getBody()->getOperations()) {
    if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
      opsWithPipelineStages.emplace_back(&op, depth);
  }
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (dependencies.contains(&op))
      opsWithPipelineStages.emplace_back(&op, 0);
  }
}

/// Hook for the loop pipeliner. Replaces op with a predicated version and
/// returns the resulting operation. Returns the original op if no predication
/// is necessary for the given op. Returns null if predication is needed but
/// not supported.
static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
                                            Operation *op, Value predicate) {
  // Some operations may be executed "speculatively" more times than the
  // original number of iterations, in particular side-effect free operations
  // and barriers, even if they cannot be predicated.
  if (isMemoryEffectFree(op) ||
      isa<gpu::BarrierOp, DeviceAsyncCreateGroupOp, DeviceAsyncWaitOp>(op)) {
    return op;
  }

  // Otherwise, only async copies can currently be predicated.
  auto asyncCopyOp = dyn_cast<DeviceAsyncCopyOp>(op);
  if (!asyncCopyOp)
    return nullptr;

  // Create a srcElements value based on `predicate`. The lines below generate
  // the following code:
  //
  //   srcElements = (predicate) ? originalSrcElements : 0;
  //
  // so that the async copy only zero-fills its destination when the predicate
  // is false.
  Location loc = asyncCopyOp->getLoc();
  Value dstElements = arith::ConstantOp::create(
      rewriter, loc, asyncCopyOp.getDstElementsAttr());
  Value originalSrcElement =
      asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
  Value c0Index = arith::ConstantIndexOp::create(rewriter, loc, 0);
  auto srcElements = arith::SelectOp::create(rewriter, loc, predicate,
                                             originalSrcElement, c0Index);
  auto asyncCopyZeroFillOp = DeviceAsyncCopyOp::create(
      rewriter, loc, DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
      asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
      asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
      UnitAttr());
  rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
  return asyncCopyZeroFillOp;
}
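
// Net effect of the predication above (illustrative): in pipelined iterations
// that fall outside the original trip count, the rewritten
// nvgpu.device_async_copy copies 0 source elements and therefore only
// zero-fills its destination, which keeps the unpeeled prologue/epilogue
// well-defined.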

/// Applies loop pipelining with the given depth to the given loop so that
/// copies into the shared memory are pipelined. Doesn't affect other loops.
/// Returns a pair containing the error state and the pipelined op, the latter
/// being null in case of any failure. The error state contains a definite
/// error if the IR has been modified and a silenceable error otherwise.
static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
                        bool epiloguePeeling) {
  llvm::SmallPtrSet<Operation *, 16> stage0Ops;
  if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
        scf::ForOp());
  }
  if (stage0Ops.empty()) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
  }

  scf::PipeliningOption options;
  unsigned maxDepth = depth;
  auto setAnnotation = [&](Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration) {
    return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
  };
  options.getScheduleFn =
      [&](scf::ForOp schedulingFor,
          std::vector<std::pair<Operation *, unsigned>> &ops) {
        if (schedulingFor != forOp)
          return;
        return getPipelineStages(forOp, ops, maxDepth, stage0Ops);
      };
  options.annotateFn = setAnnotation;
  if (!epiloguePeeling) {
    options.peelEpilogue = false;
    options.predicateFn = replaceOpWithPredicatedOp;
  }

  // Run the generic scf loop pipeliner with the hooks configured above.
  bool modifiedIR;
  FailureOr<scf::ForOp> maybePipelined =
      pipelineForLoop(rewriter, forOp, options, &modifiedIR);
  if (succeeded(maybePipelined)) {
    return std::make_tuple(DiagnosedSilenceableFailure::success(),
                           *maybePipelined);
  }
  return std::make_tuple(
      modifiedIR
          ? DiagnosedSilenceableFailure::definiteFailure()
          : emitSilenceableFailure(forOp, "pipelining preconditions failed"),
      scf::ForOp());
}

DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne(
    TransformRewriter &rewriter, scf::ForOp forOp,
    ApplyToEachResultList &results, TransformState &state) {
  auto [diag, pipelined] = pipelineForSharedCopies(
      rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
  if (diag.succeeded()) {
    results.push_back(pipelined);
    return DiagnosedSilenceableFailure::success();
  }
  if (diag.isDefiniteFailure()) {
    auto diag = emitDefiniteFailure("irreversible pipelining failure");
    if (!getPeelEpilogue()) {
      diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
      diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
    }
    return diag;
  }

  return std::move(diag);
}

//===----------------------------------------------------------------------===//
// RewriteMatmulAsMmaSyncOp
//===----------------------------------------------------------------------===//

/// Helper struct to encode a pair of row/column indexings in the form of
/// affine expressions.
struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
  RowColIndexing(AffineExpr row, AffineExpr col)
      : std::pair<AffineExpr, AffineExpr>(row, col) {}

  AffineExpr row() const { return first; };
  AffineExpr col() const { return second; };

  void print(llvm::raw_ostream &os) const {
    os << "- indexing: " << first << ", " << second;
  }
};
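
// Usage sketch (illustrative): with d0 standing for the lane id,
//   RowColIndexing idx(d0.floorDiv(4), d0 % 4);
// encodes "row = laneid / 4, col = laneid % 4"; idx.row() and idx.col() are
// later composed into affine.apply ops by the builders below.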

/// Helper struct to provide a simple mapping from matmul operations to the
/// corresponding mma.sync operation. This is constrained to the case where the
/// matmul matches the mma.sync operation 1-1.
struct MmaSyncBuilder {
  MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
      : b(b), loc(loc), laneId(laneId) {}

  using IndexCalculator =
      std::function<SmallVector<RowColIndexing>(MLIRContext *)>;

  /// Create the mma.sync operation corresponding to `linalgOp` along with all
  /// the supporting load/store and vector operations.
  FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);

private:
  struct MmaSyncInfo {
    std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
    std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>>
        vectorShapes;
    SmallVector<int64_t> mmaShape;
    bool tf32Enabled;
  };

  /// Return the specific index calculators for the given matmul problem shape
  /// `opShape` and element types, or failure if the combination is not
  /// supported.
  FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
                                             TypeRange elementalTypes);

  //===------------------------------------------------------------------===//
  // Instruction-specific row/column indexing expressions, parameterized by the
  // lane id (dimension d0 below). These follow the thread-to-fragment
  // mappings documented for mma.sync.
  // TODO: Tablegen all this.
  //===------------------------------------------------------------------===//

  /// m16n8k4 tf32 case, LHS fragment: each lane owns 2 elements of A, with
  ///   groupID         = laneid / 4
  ///   threadIDInGroup = laneid % 4
  ///   rows groupID and groupID + 8, column threadIDInGroup.
  static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup},
            RowColIndexing{groupID + 8, threadIDInGroup}};
  }
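
  // Worked example (illustrative): for lane 5, groupID = 1 and
  // threadIDInGroup = 1, so the lane loads A[1][1] and A[9][1].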

  /// m16n8k4 tf32 case, RHS fragment: each lane owns 1 element of B, at
  /// row threadIDInGroup, column groupID.
  static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{threadIDInGroup, groupID}};
  }

  /// m16n8k4 tf32 case, result fragment: each lane owns 4 elements of C/D, at
  /// rows groupID and groupID + 8, columns threadIDInGroup * 2 and
  /// threadIDInGroup * 2 + 1.
  static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID, threadIDInGroup * 2 + 1},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
  }

  /// m16n8k16 f16 case, LHS fragment: each lane owns 8 elements of A spread
  /// over two row groups and two column groups.
  static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    // clang-format off
    return {
      RowColIndexing{groupID, threadIDInGroup * 2 + 0},
      RowColIndexing{groupID, threadIDInGroup * 2 + 1},
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1},
      RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8},
      RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8},
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8},
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8}
    };
    // clang-format on
  }
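
  // Worked example (illustrative): lane 5 (groupID = 1, threadIDInGroup = 1)
  // owns A[1][2], A[1][3], A[9][2], A[9][3], A[1][10], A[1][11], A[9][10]
  // and A[9][11].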

  /// m16n8k16 f16 case, RHS fragment: each lane owns 4 elements of B in a
  /// single column.
  static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    // clang-format off
    return {
      RowColIndexing{threadIDInGroup * 2 + 0, groupID},
      RowColIndexing{threadIDInGroup * 2 + 1, groupID},
      RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID},
      RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID}
    };
    // clang-format on
  }

  /// m16n8k16 f16 case, result fragment: each lane owns 4 elements of C/D.
  static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    // clang-format off
    return {
      RowColIndexing{groupID, threadIDInGroup * 2 + 0},
      RowColIndexing{groupID, threadIDInGroup * 2 + 1},
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}
    };
    // clang-format on
  }

  //===------------------------------------------------------------------===//
  // Helpers to create the load/store operations for each operand, using the
  // IndexCalculator callbacks above to decide which (row, col) elements a
  // given lane touches.
  //===------------------------------------------------------------------===//

  /// Build a list of memref.load operations indexed at the (row, col) pairs
  /// produced by `indexFn` for the current `laneId`.
  SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      const IndexCalculator &indexFn);

  /// Perform a distributed load of a vector operand of `vectorShape`: each
  /// lane loads the sub-portion of the data it owns for the particular MMA
  /// operation and packs it into a vector.
  Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      const IndexCalculator &indexFn,
                                      ArrayRef<int64_t> vectorShape);

  /// Build a list of memref.store operations indexed at the (row, col) pairs
  /// produced by `indexFn` for the current `laneId`.
  SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
                                             ValueRange toStore,
                                             OpFoldResult laneId, Value memref,
                                             const IndexCalculator &indexFn);

  /// Perform a distributed store of a vector operand of `vectorShape`: each
  /// lane stores the sub-portion of the data it owns for the particular MMA
  /// operation.
  SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
      OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
      Value memref, const IndexCalculator &indexFn,
      ArrayRef<int64_t> vectorShape);

  OpBuilder &b;
  Location loc;
  OpFoldResult laneId;
};

/// Helper function to create customizable load and store operations: walk each
/// individual element position of a small static vector; `applyFn` produces a
/// value for each linear index and `reduceFn` consumes it together with its
/// delinearized indices.
template <typename ApplyFn, typename ReduceFn>
static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
                                           ReduceFn reduceFn) {
  VectorType vectorType = cast<VectorType>(vector.getType());
  auto vectorShape = vectorType.getShape();
  auto strides = computeStrides(vectorShape);
  for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
    auto indices = delinearize(idx, strides);
    reduceFn(applyFn(vector, idx, indices), idx, indices);
  }
}
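
// Worked example (illustrative): for a vector<2x2xf32>, strides = {2, 1} and
// the loop visits linear indices 0..3, delinearized as (0,0), (0,1), (1,0),
// (1,1).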

SmallVector<Value>
MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
                                 OpFoldResult laneId, Value memref,
                                 const IndexCalculator &indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Value> res;
  SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
  for (auto indexing : indexings) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    auto load = memref::LoadOp::create(b, loc, memref, ValueRange{row, col});
    res.push_back(load);
  }
  return res;
}

Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
    const IndexCalculator &indexFn, ArrayRef<int64_t> vectorShape) {
  auto loads = buildMemRefLoads(b, loc, laneId, memref, indexFn);

  Type elementType = getElementTypeOrSelf(memref.getType());
  auto vt = VectorType::get(vectorShape, elementType);
  Value res = vector::BroadcastOp::create(b, loc, vt, loads[0]);
  foreachIndividualVectorElement(
      res,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return loads[linearIdx];
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        res = vector::InsertOp::create(b, loc, v, res, indices);
      });

  return res;
}

SmallVector<Operation *>
MmaSyncBuilder::buildMemRefStores(OpBuilder &b, Location loc,
                                  ValueRange toStore, OpFoldResult laneId,
                                  Value memref,
                                  const IndexCalculator &indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Operation *> res;
  for (auto [indexing, val] :
       llvm::zip_equal(indexFn(b.getContext()), toStore)) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    Operation *store =
        memref::StoreOp::create(b, loc, val, memref, ValueRange{row, col});
    res.push_back(store);
  }
  return res;
}

SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
    OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
    Value memref, const IndexCalculator &indexFn,
    ArrayRef<int64_t> vectorShape) {
  SmallVector<Value> toStore;
  toStore.reserve(32);
  foreachIndividualVectorElement(
      vectorToStore,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return vector::ExtractOp::create(b, loc, vectorToStore, indices);
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        toStore.push_back(v);
      });
  return buildMemRefStores(b, loc, toStore, laneId, memref, indexFn);
}

static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
                  SmallVector<int64_t>>
makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
                 ArrayRef<int64_t> res) {
  SmallVector<int64_t> vlhs(lhs), vrhs(rhs), vres(res);
  return std::make_tuple(vlhs, vrhs, vres);
}

FailureOr<MmaSyncBuilder::MmaSyncInfo>
MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
                                    TypeRange elementalTypes) {
  // TODO: Tablegen all this.
  Type f16 = b.getF16Type();
  Type f32 = b.getF32Type();
  if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
      elementalTypes == TypeRange{f32, f32, f32}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
                                       &MmaSyncBuilder::m16n8k4tf32Rhs,
                                       &MmaSyncBuilder::m16n8k4tf32Res),
                       makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
                       SmallVector<int64_t>{opShape},
                       /*tf32Enabled=*/true};
  }
  // This is the version with f16 accumulation.
  // TODO: version with f32 accumulation.
  if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
      elementalTypes == TypeRange{f16, f16, f16}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
                                       &MmaSyncBuilder::m16n8k16f16Rhs,
                                       &MmaSyncBuilder::m16n8k16f16Res),
                       makeVectorShapes({4, 2}, {2, 2}, {2, 2}),
                       SmallVector<int64_t>{opShape},
                       /*tf32Enabled=*/false};
  }
  return failure();
}

FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
  assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
         "expected lhs to be a 2D memref");
  assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
         "expected rhs to be a 2D memref");
  assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
         "expected res to be a 2D memref");

  int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
  int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
  int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
  Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
  Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
  Type resType = getElementTypeOrSelf(resMemRef.getType());

  FailureOr<MmaSyncInfo> maybeInfo =
      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
  if (failed(maybeInfo))
    return failure();

  const MmaSyncInfo &info = *maybeInfo;
  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
                                            lhsIndexFn, lhsShape);
  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
                                            rhsIndexFn, rhsShape);
  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
                                            resIndexFn, resShape);
  res =
      MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape, info.tf32Enabled);
  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
                                 resShape);
  return res.getDefiningOp();
}
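
// The net IR produced per lane (illustrative): a handful of memref.load ops
// packed into small vectors, a single nvgpu.mma.sync op, then vector.extract
// and memref.store ops writing the accumulator fragment back.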

DiagnosedSilenceableFailure RewriteMatmulAsMmaSyncOp::applyToOne(
    TransformRewriter &rewriter, LinalgOp linalgOp,
    ApplyToEachResultList &results, TransformState &state) {
  bool fail = true;
  // TODO: more robust detection of matmulOp, with transposes etc.
  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
    // Do not let matmuls with extended semantics (user-defined indexing maps)
    // through this transform.
    if (linalgOp.hasUserDefinedMaps()) {
      return emitSilenceableError()
             << "only matmul ops with non-extended semantics are supported";
    }
    Location loc = linalgOp.getLoc();
    // TODO: more robust computation of laneId, for now assume a single warp.
    Value laneId = gpu::ThreadIdOp::create(
        rewriter, loc, rewriter.getIndexType(), gpu::Dimension::x);
    if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
      fail = false;
  }

  if (fail) {
    DiagnosedSilenceableFailure diag = emitSilenceableError()
                                       << "unsupported target op: " << linalgOp;
    diag.attachNote(linalgOp->getLoc()) << "target op";
    return diag;
  }

  rewriter.eraseOp(linalgOp);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Hopper builders.
//===----------------------------------------------------------------------===//

/// Helper to create the base Hopper-specific operations that are reused in
/// various other places.
struct HopperBuilder {
  HopperBuilder(RewriterBase &rewriter, Location loc)
      : rewriter(rewriter), loc(loc) {}

  TypedValue<MBarrierGroupType>
  buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);

  /// Create tma descriptor op to initiate transfer from global to shared
  /// memory. This must be done before the launch op, on the host.
  TypedValue<TensorMapDescriptorType>
  buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                              gpu::LaunchOp launchOp);

  /// Build a tma load from global memory to shared memory using `barrier` to
  /// synchronize. Return the number of bytes transferred by the load.
  OpFoldResult buildTmaAsyncLoad(TypedValue<TensorMapDescriptorType> globalDesc,
                                 TypedValue<MemRefType> sharedMemref,
                                 TypedValue<MBarrierGroupType> barrier,
                                 SmallVectorImpl<Operation *> &loadOps);

  void buildBarrierArriveTx(TypedValue<MBarrierGroupType> barrier,
                            ArrayRef<OpFoldResult> sizes);

  /// If threadIdx.x == 0 does TMA request + wait, else just wait.
  SmallVector<Operation *> buildPredicateLoadsOnThread0(
      ArrayRef<TypedValue<TensorMapDescriptorType>> globalDescriptors,
      ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
      TypedValue<MBarrierGroupType> barrier);

  void buildTryWaitParity(TypedValue<MBarrierGroupType> barrier);

  RewriterBase &rewriter;
  Location loc;
};

SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
    ArrayRef<TypedValue<TensorMapDescriptorType>> globalDescriptors,
    ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
    TypedValue<MBarrierGroupType> barrier) {
  SmallVector<Operation *> loadOps;
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  Value tidx = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x);
  Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
                                     tidx, zero);
  scf::IfOp::create(
      rewriter, loc, cond,
      /*thenBuilder=*/
      [&](OpBuilder &lb, Location loc) {
        SmallVector<OpFoldResult> sizes;
        sizes.reserve(globalDescriptors.size());
        for (auto [desc, shmem] :
             llvm::zip_equal(globalDescriptors, sharedMemBuffers)) {
          OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
          sizes.push_back(sz);
        }
        // Only thread 0 issues the TMA requests and signals the expected
        // transaction size on the barrier.
        buildBarrierArriveTx(barrier, sizes);
        scf::YieldOp::create(rewriter, loc);
      },
      /*elseBuilder=*/
      [&](OpBuilder &lb, Location loc) {
        // Other threads only arrive on the barrier, expecting 0 bytes.
        buildBarrierArriveTx(
            barrier, getAsIndexOpFoldResult(rewriter.getContext(), 0));
        scf::YieldOp::create(rewriter, loc);
      });

  return loadOps;
}
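
// Structure of the generated IR (illustrative):
//   %tidx = gpu.thread_id x
//   %cond = arith.cmpi eq, %tidx, %c0
//   scf.if %cond {
//     nvgpu.tma.async.load ...                      // one per copy
//     nvgpu.mbarrier.arrive.expect_tx ... <total bytes>
//   } else {
//     nvgpu.mbarrier.arrive.expect_tx ... 0
//   }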

static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
  return gpu::AddressSpaceAttr::get(
      b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
}

TypedValue<MBarrierGroupType>
HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value barrier = MBarrierCreateOp::create(
      rewriter, loc,
      MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  nvgpu::MBarrierInitOp::create(
      rewriter, loc, barrier,
      getValueOrCreateConstantIndexOp(rewriter, loc, numThreads), zero,
      Value());
  gpu::BarrierOp::create(rewriter, loc);
  return cast<TypedValue<MBarrierGroupType>>(barrier);
}

TypedValue<TensorMapDescriptorType>
HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                                           gpu::LaunchOp launchOp) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(launchOp);
  Value unrankedMemRef = memref::CastOp::create(
      rewriter, loc,
      UnrankedMemRefType::get(memref.getType().getElementType(),
                              memref.getType().getMemorySpace()),
      memref);
  SmallVector<OpFoldResult> mixedSizes =
      memref::getMixedSizes(rewriter, loc, memref);
  SmallVector<Value> sizes =
      getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);

  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value desc = TmaCreateDescriptorOp::create(
      rewriter, loc,
      TensorMapDescriptorType::get(rewriter.getContext(),
                                   MemRefType::Builder(memref.getType())
                                       .setMemorySpace(sharedMemorySpace),
                                   TensorMapSwizzleKind::SWIZZLE_NONE,
                                   TensorMapL2PromoKind::L2PROMO_NONE,
                                   TensorMapOOBKind::OOB_ZERO,
                                   TensorMapInterleaveKind::INTERLEAVE_NONE),
      unrankedMemRef, sizes);
  return cast<TypedValue<TensorMapDescriptorType>>(desc);
}

OpFoldResult HopperBuilder::buildTmaAsyncLoad(
    TypedValue<TensorMapDescriptorType> globalDesc,
    TypedValue<MemRefType> sharedMemref,
    TypedValue<MBarrierGroupType> barrier,
    SmallVectorImpl<Operation *> &loadOps) {
  MLIRContext *ctx = rewriter.getContext();
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  Operation *loadOp =
      TmaAsyncLoadOp::create(rewriter, loc, sharedMemref, barrier, globalDesc,
                             ValueRange{zero, zero}, zero, Value(), Value());
  loadOps.push_back(loadOp);
  auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr prodExprInBytes =
      computeProduct(ctx, symbols) *
      (sharedMemref.getType().getElementTypeBitWidth() / 8);
  auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
                                                   prodExprInBytes, mixedSizes);
  return res;
}
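
// Example of the returned size (illustrative): for a shared memref<64x128xf16>
// destination the affine product above folds to 64 * 128 * 2 = 16384 bytes,
// which is the transaction count later expected on the mbarrier.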

void HopperBuilder::buildBarrierArriveTx(TypedValue<MBarrierGroupType> barrier,
                                         ArrayRef<OpFoldResult> mixedSizes) {
  assert(!mixedSizes.empty() && "expected non-empty sizes");
  MLIRContext *ctx = rewriter.getContext();
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr sumExpr = computeSum(ctx, symbols);
  OpFoldResult size =
      affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
  Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  nvgpu::MBarrierArriveExpectTxOp::create(rewriter, loc, barrier, sizeVal, zero,
                                          Value());
}

void HopperBuilder::buildTryWaitParity(TypedValue<MBarrierGroupType> barrier) {
  Type i1 = rewriter.getI1Type();
  Value parity = LLVM::ConstantOp::create(rewriter, loc, i1, 0);
  // 10M is an arbitrary, not too small and not too big number of ticks before
  // retry.
  // TODO: hoist this into a default dialect constant.
  Value ticksBeforeRetry =
      arith::ConstantIndexOp::create(rewriter, loc, 10000000);
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  nvgpu::MBarrierTryWaitParityOp::create(rewriter, loc, barrier, parity,
                                         ticksBeforeRetry, zero);
}

//===----------------------------------------------------------------------===//
// RewriteCopyAsTmaOp
//===----------------------------------------------------------------------===//

/// Helper to create the tma operations corresponding to `linalg::CopyOp`.
struct CopyBuilder : public HopperBuilder {
  CopyBuilder(RewriterBase &rewriter, Location loc)
      : HopperBuilder(rewriter, loc) {}

  SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps);
};

SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
  MLIRContext *ctx = rewriter.getContext();
  if (copyOps.empty())
    return SmallVector<Operation *>();

  auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
  assert(launchOp && "expected launch op");

  // 1. Init a barrier object in shared memory.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(copyOps.front());
  AffineExpr bx, by, bz;
  bindSymbols(ctx, bx, by, bz);
  AffineExpr prod = computeProduct(ctx, ArrayRef<AffineExpr>{bx, by, bz});
  OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
      rewriter, loc, prod,
      ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
                             launchOp.getBlockSizeZ()});

  TypedValue<MBarrierGroupType> barrier =
      buildAndInitBarrierInSharedMemory(numThreads);

  SmallVector<TypedValue<MemRefType>> shmems;
  SmallVector<TypedValue<TensorMapDescriptorType>> globalDescs;
  for (Operation *op : copyOps) {
    auto copyOp = cast<linalg::CopyOp>(op);
    auto inMemRef =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
    assert(inMemRef.getType().getRank() == 2 &&
           "expected in to be a 2D memref");

    // 2. Build the tma descriptor for the global memory operand, on the host.
    TypedValue<TensorMapDescriptorType> globalDesc =
        buildGlobalMemRefDescriptor(inMemRef, launchOp);
    globalDescs.push_back(globalDesc);

    // 3. Record the shared memory destination buffer.
    auto shmem =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
    shmems.push_back(shmem);
  }

  // 4. Load from global to shared memory using tma, on thread 0 only.
  rewriter.setInsertionPoint(copyOps.front());
  SmallVector<Operation *> results =
      buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);

  // 5. Wait until the data has landed in shared memory.
  buildTryWaitParity(barrier);

  // 6. Erase the original copies, which have now been rewritten.
  for (Operation *op : copyOps)
    rewriter.eraseOp(op);

  return results;
}

DiagnosedSilenceableFailure
RewriteCopyAsTmaOp::apply(TransformRewriter &rewriter,
                          TransformResults &transformResults,
                          TransformState &state) {
  auto payloadOps = state.getPayloadOps(getTarget());
  gpu::LaunchOp commonLaunchOp;
  Operation *firstOp = nullptr, *failingOp = nullptr;
  if (llvm::any_of(payloadOps, [&](Operation *op) {
        if (!commonLaunchOp) {
          commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
          firstOp = op;
        }
        auto fail = !op->getParentOfType<gpu::LaunchOp>() ||
                    commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
                    !isa<linalg::CopyOp>(op);
        if (fail)
          failingOp = op;
        return fail;
      })) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "target ops must be linalg::CopyOp nested under a common "
           "gpu.LaunchOp to be rewritten because the tma descriptors need to "
           "be created on the host.\nBut got: "
        << *firstOp << "\nand " << *failingOp;
    return diag;
  }

  // TODO: more robust detection of copy, with transposes etc.
  CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));

  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
class NVGPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          NVGPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)

  NVGPUTransformDialectExtension() {
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<affine::AffineDialect>();
    declareGeneratedDialect<nvgpu::NVGPUDialect>();
    declareGeneratedDialect<NVVM::NVVMDialect>();
    declareGeneratedDialect<vector::VectorDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"

void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<NVGPUTransformDialectExtension>();
}
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static std::string diag(const llvm::Value &value)
constexpr int kWgmmaSizeM
M size of wgmma.mma_async instruction.
static Attribute getSharedAddressSpaceAttribute(OpBuilder &b)
Definition NVGPUTransformOps.cpp:910
static bool hasDefaultMemorySpace(BaseMemRefType type)
Returns true if the given type has the default memory space.
Definition NVGPUTransformOps.cpp:135
static LogicalResult collectStage0PipeliningOps(scf::ForOp forOp, llvm::SmallPtrSet< Operation *, 16 > &ops)
Populate ops with the set of operations that belong to the stage 0 of the pipelined version of the gi...
Definition NVGPUTransformOps.cpp:192
static std::tuple< DiagnosedSilenceableFailure, scf::ForOp > pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth, bool epiloguePeeling)
Applies loop pipelining with the given depth to the given loop so that copies into the shared memory ...
Definition NVGPUTransformOps.cpp:337
static bool isStoreToShared(Operation *op, Value v)
Returns true if the operation is storing the given value into shared memory.
Definition NVGPUTransformOps.cpp:162
static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn, ReduceFn reduceFn)
Helper functions to create customizable load and stores operations.
Definition NVGPUTransformOps.cpp:634
static bool hasSharedMemorySpace(BaseMemRefType type)
Returns true if the given type has the shared (workgroup) memory space.
Definition NVGPUTransformOps.cpp:140
static bool isLoadFromGlobalStoredToShared(Operation *op)
Returns true if the operation is a load from the default memory space the result of which is only sto...
Definition NVGPUTransformOps.cpp:174
static std::tuple< SmallVector< int64_t >, SmallVector< int64_t >, SmallVector< int64_t > > makeVectorShapes(ArrayRef< int64_t > lhs, ArrayRef< int64_t > rhs, ArrayRef< int64_t > res)
Definition NVGPUTransformOps.cpp:723
static void getPipelineStages(scf::ForOp forOp, std::vector< std::pair< Operation *, unsigned > > &opsWithPipelineStages, unsigned depth, llvm::SmallPtrSetImpl< Operation * > &stage0Ops)
Hook for the loop pipeliner that populates ops with the stage information as follows:
Definition NVGPUTransformOps.cpp:263
static void setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op, scf::PipeliningOption::PipelinerPart part, unsigned iteration, unsigned depth)
Hook for the loop pipeliner that sets the "num groups in flight" attribute of async wait operations c...
Definition NVGPUTransformOps.cpp:229
static Operation * replaceOpWithPredicatedOp(RewriterBase &rewriter, Operation *op, Value predicate)
Hook for the loop pipeliner.
Definition NVGPUTransformOps.cpp:294
static Value getValueLoadedFromGlobal(Operation *op)
Returns the value produced by a load from the default memory space.
Definition NVGPUTransformOps.cpp:149
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< VectorShape > vectorShape(Type type)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This is a builder type that keeps local references to arguments.
Builder & setMemorySpace(Attribute newMemorySpace)
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void push_back(Operation *op)
Appends an element to the list.
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.
MemRefType getMBarrierMemrefType(MLIRContext *context, MBarrierGroupType barrierType)
Return the memref type that can be used to represent an mbarrier object.
void registerTransformDialectExtension(DialectRegistry ®istry)
void createAsyncGroups(RewriterBase &rewriter, Operation *op, bool bypassL1)
Convert global->shared vector transfers to async device copies.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
llvm::SetVector< T, Vector, Set, N > SetVector
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
void populateNVGPUToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
int64_t computeSum(ArrayRef< int64_t > basis)
Self-explicit.
void bindSymbolsList(MLIRContext *ctx, MutableArrayRef< AffineExprTy > exprs)
Helper to create the tma operations corresponding to linalg::CopyOp.
Definition NVGPUTransformOps.cpp:1015
SmallVector< Operation * > rewrite(ArrayRef< Operation * > copyOps)
Definition NVGPUTransformOps.cpp:1022
CopyBuilder(RewriterBase &rewriter, Location loc)
Definition NVGPUTransformOps.cpp:1016
void buildBarrierArriveTx(TypedValue< MBarrierGroupType > barrier, ArrayRef< OpFoldResult > sizes)
Definition NVGPUTransformOps.cpp:982
OpFoldResult buildTmaAsyncLoad(TypedValue< TensorMapDescriptorType > globalDesc, TypedValue< MemRefType > sharedMemref, TypedValue< MBarrierGroupType > barrier, SmallVectorImpl< Operation * > &loadOps)
Build a tma load from global memory to shared memory using barrier to synchronize.
Definition NVGPUTransformOps.cpp:961
RewriterBase & rewriter
Definition NVGPUTransformOps.cpp:868
TypedValue< TensorMapDescriptorType > buildGlobalMemRefDescriptor(TypedValue< MemRefType > memref, gpu::LaunchOp launchOp)
Create tma descriptor op to initiate transfer from global to shared memory.
Definition NVGPUTransformOps.cpp:932
void buildTryWaitParity(TypedValue< MBarrierGroupType > barrier)
Definition NVGPUTransformOps.cpp:997
TypedValue< MBarrierGroupType > buildAndInitBarrierInSharedMemory(OpFoldResult numThreads)
Definition NVGPUTransformOps.cpp:917
SmallVector< Operation * > buildPredicateLoadsOnThread0(ArrayRef< TypedValue< TensorMapDescriptorType > > globalDescriptors, ArrayRef< TypedValue< MemRefType > > sharedMemBuffers, TypedValue< MBarrierGroupType > barrier)
If threadIdx.x == 0 does TMA request + wait, else just wait.
Definition NVGPUTransformOps.cpp:872
Location loc
Definition NVGPUTransformOps.cpp:869
HopperBuilder(RewriterBase &rewriter, Location loc)
Definition NVGPUTransformOps.cpp:837
Helper struct to provide a simple mapping from matmul operations to the corresponding mma....
Definition NVGPUTransformOps.cpp:428
MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
Definition NVGPUTransformOps.cpp:429
std::function< SmallVector< RowColIndexing >(MLIRContext *)> IndexCalculator
Definition NVGPUTransformOps.cpp:432
FailureOr< Operation * > buildMmaSync(LinalgOp linalgOp)
Create the mma.sync operation corresponding to linalgOp along with all the supporting load/store and ...
Definition NVGPUTransformOps.cpp:760
Helper struct to encode a pair of row/column indexings in the form of affine expressions.
Definition NVGPUTransformOps.cpp:413
AffineExpr col() const
Definition NVGPUTransformOps.cpp:418
RowColIndexing(AffineExpr row, AffineExpr col)
Definition NVGPUTransformOps.cpp:414
void print(llvm::raw_ostream &os) const
Definition NVGPUTransformOps.cpp:420
AffineExpr row() const
Definition NVGPUTransformOps.cpp:417
Options to dictate how loops should be pipelined.