MLIR: lib/Conversion/VectorToGPU/VectorToGPU.cpp Source File
//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements conversion of vector operations to GPU MMA operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"

#include <type_traits>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"

#define DEBUG_TYPE "vector-to-gpu"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
/// AffineMap representing offsets to apply to the transfer indices, this
/// function fills `indices` with the adjusted indices. Dimension values that
/// the offset map refers to are provided in `dimValues`.
template <typename TransferOpType>
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
                           AffineMap offsetMap, ArrayRef<Value> dimValues,
                           SmallVector<Value, 4> &indices) {
  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
  Location loc = xferOp.getLoc();
  unsigned offsetsIdx = 0;
  for (auto expr : xferOp.getPermutationMap().getResults()) {
    if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
      Value prevIdx = indices[dim.getPosition()];
      SmallVector<OpFoldResult, 3> dims(dimValues.begin(), dimValues.end());
      dims.push_back(prevIdx);
      AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
      indices[dim.getPosition()] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
      continue;
    }
  }
}

// Return true if the contract op can be converted to MMA matmul.
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
                                          bool useNvGpu) {
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [&](MapList m) {
    return AffineMap::inferFromExprList(m, contract.getContext());
  };
  AffineExpr m, n, k;
  bindDims(contract.getContext(), m, n, k);
  auto iteratorTypes = contract.getIteratorTypes().getValue();
  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
        vector::isParallelIterator(iteratorTypes[1]) &&
        vector::isReductionIterator(iteratorTypes[2])))
    return false;

  // The contract needs to represent a matmul to be able to convert to
  // MMAMatrix matmul.
  if (!useNvGpu &&
      contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
    return false;
  if (useNvGpu &&
      contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
    return false;

  return true;
}
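
// Illustrative example: the default (WMMA) path above accepts the canonical
// row-major matmul indexing maps, i.e. a vector.contract whose maps are
//
//   affine_map<(m, n, k) -> (m, k)>   // LHS
//   affine_map<(m, n, k) -> (k, n)>   // RHS
//   affine_map<(m, n, k) -> (m, n)>   // accumulator
//
// while the nvgpu (mma.sync) path expects the RHS already transposed,
// affine_map<(m, n, k) -> (n, k)>, matching the mma.sync operand layout.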

// Return true if the given map represents a transposed matrix load,
// i.e. (d0, d1, ...) -> (dn-1, dn-2).
static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
  MLIRContext *ctx = permutationMap.getContext();
  // Local OpBuilder is fine here, we just build attributes.
  OpBuilder b(ctx);
  auto nDim = permutationMap.getNumDims();
  AffineExpr zero = b.getAffineConstantExpr(0);
  if (nDim < 2) {
    // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
    AffineExpr dim0 = b.getAffineDimExpr(0);
    return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
  }

  AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
  AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
  // Support both transposed and transposed+broadcasted cases.
  return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
         permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
}

// Return the stride of the second-to-last dimension if the type is a memref
// that is contiguous in its innermost dimension, std::nullopt otherwise.
static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
  auto memrefType = dyn_cast<MemRefType>(type);
  if (!memrefType)
    return std::nullopt;
  // If the memref is 0-D or 1-D the row stride is 0.
  if (memrefType.getRank() < 2)
    return 0;
  int64_t offset = 0;
  SmallVector<int64_t, 2> strides;
  if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
      strides.back() != 1)
    return std::nullopt;
  int64_t stride = strides[strides.size() - 2];
  if (stride == ShapedType::kDynamic)
    return std::nullopt;
  return stride;
}

// Return true if the transfer op can be converted to an MMA matrix load.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  if (!getStaticallyKnownRowStride(readOp.getShapedType()))
    return false;

  // Only allow integer types if the signedness can be inferred.
  if (readOp.getVectorType().getElementType().isInteger(8))
    if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
                                 !isa<arith::ExtUIOp>(*readOp->user_begin())))
      return false;

  AffineMap map = readOp.getPermutationMap();
  MLIRContext *ctx = readOp.getContext();
  AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
  AffineExpr zero = getAffineConstantExpr(0, ctx);
  auto broadcastInnerDim =
      AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
  return map.isMinorIdentity() || map == broadcastInnerDim ||
         isTransposeMatrixLoadMap(map);
}

// Return true if the transfer op can be converted to an MMA matrix store.
static bool
transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
  // TODO: support 0-d corner case.
  if (writeOp.getTransferRank() == 0)
    return false;

  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
    return false;
  if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
    return false;
  // TODO: support transposed store once it is available in the GPU dialect.
  if (!writeOp.getPermutationMap().isMinorIdentity())
    return false;
  return true;
}

/// Return true if the constant is a splat to a 2D vector so that it can be
/// converted to an MMA constant matrix op.
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
  auto vecType = dyn_cast<VectorType>(constantOp.getType());
  if (!vecType || vecType.getRank() != 2)
    return false;
  return isa<SplatElementsAttr>(constantOp.getValue());
}

/// Return true if this is a broadcast from scalar to a 2D vector.
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
  return broadcastOp.getResultVectorType().getRank() == 2;
}

/// Return true if this integer extend op can be folded into a contract op.
template <typename ExtOpTy>
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
  auto transferReadOp =
      extOp.getOperand().template getDefiningOp<vector::TransferReadOp>();
  if (!transferReadOp)
    return false;
  return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
}

static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }

/// Return the MMA elementwise enum associated with `op` if it is supported.
/// Return `std::nullopt` otherwise.
static std::optional<gpu::MMAElementwiseOp>
convertElementwiseOpToMMA(Operation *op) {
  if (isa<arith::AddFOp>(op))
    return gpu::MMAElementwiseOp::ADDF;
  if (isa<arith::MulFOp>(op))
    return gpu::MMAElementwiseOp::MULF;
  if (isa<arith::SubFOp>(op))
    return gpu::MMAElementwiseOp::SUBF;
  if (isa<arith::MaximumFOp>(op))
    return gpu::MMAElementwiseOp::MAXF;
  if (isa<arith::MinimumFOp>(op))
    return gpu::MMAElementwiseOp::MINF;
  if (isa<arith::DivFOp>(op))
    return gpu::MMAElementwiseOp::DIVF;
  if (isa<arith::AddIOp>(op))
    return gpu::MMAElementwiseOp::ADDI;
  if (isa<arith::MulIOp>(op))
    return gpu::MMAElementwiseOp::MULI;
  if (isa<arith::SubIOp>(op))
    return gpu::MMAElementwiseOp::SUBI;
  if (isa<arith::DivSIOp>(op))
    return gpu::MMAElementwiseOp::DIVS;
  if (isa<arith::DivUIOp>(op))
    return gpu::MMAElementwiseOp::DIVU;
  if (isa<arith::NegFOp>(op))
    return gpu::MMAElementwiseOp::NEGATEF;
  if (isa<arith::ExtFOp>(op))
    return gpu::MMAElementwiseOp::EXTF;
  return std::nullopt;
}

/// Return true if the op is supported as elementwise op on MMAMatrix type.
static bool elementwiseSupportsMMAMatrixType(Operation *op) {
  return convertElementwiseOpToMMA(op).has_value();
}

/// Returns true if the extract strided slice op is supported with the
/// `mma.sync` path.
static bool
extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return false;

  FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
  if (failed(contractOp))
    return false;

  // Handle vector.extract_strided_slice on registers containing
  // matrixB and matrixC operands. vector.extract_strided_slice op
  // is not supported on registers containing matrixA operands.
  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
    return (cast<VectorType>(op->getResult(0).getType()) ==
            cast<VectorType>((*contractOp).getRhs().getType()));
  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
    return (cast<VectorType>(op->getResult(0).getType()) ==
            cast<VectorType>((*contractOp).getAcc().getType()));

  return false;
}

static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
  if (isa<scf::ForOp, scf::YieldOp>(op))
    return true;
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
    return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
                    : transferReadSupportsMMAMatrixType(transferRead);
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
    return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
                    : transferWriteSupportsMMAMatrixType(transferWrite);
  if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
    return useNvGpu &&
           extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
    return contractSupportsMMAMatrixType(contract, useNvGpu);
  if (auto constant = dyn_cast<arith::ConstantOp>(op))
    return constantSupportsMMAMatrixType(constant);
  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
    return broadcastSupportsMMAMatrixType(broadcast);
  if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
    return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
  if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
    return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
  if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
    return fpExtendSupportsMMAMatrixType(fpExtend);
  return elementwiseSupportsMMAMatrixType(op);
}

/// Return an unsorted slice handling scf.for region differently than
/// `getSlice`. In scf.for we only want to include as part of the slice
/// elements that are part of the use/def chain.
static SetVector<Operation *>
getSliceContract(Operation *op,
                 const BackwardSliceOptions &backwardSliceOptions,
                 const ForwardSliceOptions &forwardSliceOptions) {
  SetVector<Operation *> slice;
  slice.insert(op);
  unsigned currentIndex = 0;
  SetVector<Operation *> backwardSlice;
  SetVector<Operation *> forwardSlice;
  while (currentIndex != slice.size()) {
    auto *currentOp = (slice)[currentIndex];
    // Compute and insert the backward slice starting from currentOp.
    backwardSlice.clear();
    LogicalResult result =
        getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
    assert(result.succeeded() && "expected a backward slice");
    (void)result;
    slice.insert_range(backwardSlice);

    // Compute and insert the forward slice starting from currentOp.
    forwardSlice.clear();
    // Special case for scf.for: we don't want to include the whole region but
    // only the values flowing through the region arguments and results.
    if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
      for (Value forOpResult : forOp.getResults())
        getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
      for (BlockArgument &arg : forOp.getRegionIterArgs())
        getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
    } else {
      getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
    }
    slice.insert_range(forwardSlice);
    ++currentIndex;
  }
  return slice;
}

// Analyze the slice of operations rooted at each vector.contract to figure out
// whether the whole slice can be converted to MMA operations.
static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
                                             bool useNvGpu) {
  auto hasVectorDest = [](Operation *op) {
    return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
  };
  BackwardSliceOptions backwardSliceOptions;
  backwardSliceOptions.filter = hasVectorDest;

  auto hasVectorSrc = [](Operation *op) {
    return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
  };
  ForwardSliceOptions forwardSliceOptions;
  forwardSliceOptions.filter = hasVectorSrc;

  SetVector<Operation *> opToConvert;
  op->walk([&](vector::ContractionOp contract) {
    if (opToConvert.contains(contract.getOperation()))
      return;
    SetVector<Operation *> dependentOps =
        getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
    // If any instruction cannot use the MMA matrix type, drop the whole
    // chain. MMA matrices are stored in an opaque type so they cannot be used
    // by all operations.
    if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
          if (!supportsMMaMatrixType(op, useNvGpu)) {
            LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
            return true;
          }
          return false;
        }))
      return;

    opToConvert.insert_range(dependentOps);
  });
  // Sort the operations so that we can convert them in topological order.
  return topologicalSort(opToConvert);
}
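
// Illustrative example (schematic IR, attributes elided): the analysis above
// typically selects a chain such as
//
//   %a = vector.transfer_read %A[...] : memref<16x16xf16>, vector<16x16xf16>
//   %b = vector.transfer_read %B[...] : memref<16x16xf16>, vector<16x16xf16>
//   %c = vector.transfer_read %C[...] : memref<16x16xf16>, vector<16x16xf16>
//   %d = vector.contract ... %a, %b, %c : ... into vector<16x16xf16>
//   vector.transfer_write %d, %C[...] : vector<16x16xf16>, memref<16x16xf16>
//
// where every op in the slice supports an MMA matrix type; otherwise the
// whole chain is dropped.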

namespace {

// Convert vector.contract into the canonical (m, k) x (k, n) -> (m, n) form
// expected by the MMA matmul lowering.
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(m, op.getContext());
    };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    if (!(vector::isParallelIterator(iteratorTypes[0]) &&
          vector::isParallelIterator(iteratorTypes[1]) &&
          vector::isReductionIterator(iteratorTypes[2])))
      return rewriter.notifyMatchFailure(op, "not a gemm contraction");
    //
    // Two outer parallel dimensions, one inner reduction (matmat flavor).
    //
    // This is the classical row-major matmul, nothing to do.
    if (maps == infer({{m, k}, {k, n}, {m, n}}))
      return rewriter.notifyMatchFailure(op, "contraction already prepared");
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      // TODO: llvm_unreachable?
      return rewriter.notifyMatchFailure(op, "unexpected contraction case");
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res,
        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
        op.getIteratorTypes());
    return success();
  }
};

// Fold a vector.transpose into the preceding vector.transfer_read by composing
// the transpose permutation into the read's permutation map. This lets the
// transposed operand be loaded directly in the layout the MMA ops expect.
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    // Look through integer extend ops.
    Value source = op.getVector();
    Type resultType = op.getType();
    Operation *extOp;
    if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
        (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
        (extOp = source.getDefiningOp<arith::ExtFOp>())) {
      source = extOp->getOperand(0);
      resultType =
          VectorType::get(cast<VectorType>(resultType).getShape(),
                          cast<VectorType>(source.getType()).getElementType());
    }

    auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
    if (!transferReadOp)
      return rewriter.notifyMatchFailure(op, "no transfer read");

    // TODO: support 0-d corner case.
    if (transferReadOp.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-D transfer read");

    if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(op, "not inbounds transfer read");

    AffineMap permutationMap =
        AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
    AffineMap newMap =
        permutationMap.compose(transferReadOp.getPermutationMap());

    auto loc = op.getLoc();
    Value result =
        rewriter
            .create<vector::TransferReadOp>(
                loc, resultType, transferReadOp.getBase(),
                transferReadOp.getIndices(), AffineMapAttr::get(newMap),
                transferReadOp.getPadding(), transferReadOp.getMask(),
                transferReadOp.getInBoundsAttr())
            .getResult();

    // Fuse through the integer extend op.
    if (extOp) {
      if (isa<arith::ExtSIOp>(extOp))
        result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
                     .getResult();
      else if (isa<arith::ExtUIOp>(extOp))
        result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
                     .getResult();
      else
        result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
                     .getResult();
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

} // namespace

// MMA types have different layouts based on how they are used in matmul ops.
// Figure out the right layout to use by looking at op uses.
// TODO: Change the GPU dialect to abstract the layout at this level and only
// care about it during lowering to NVVM.
static const char *inferFragType(Operation *op) {
  // We can have arith.ext ops before reaching contract ops. See through them
  // and other elementwise ops.
  if (op->hasOneUse()) {
    Operation *userOp = *op->user_begin();
    if (userOp->hasTrait<OpTrait::Elementwise>())
      return inferFragType(userOp);
  }

  for (Operation *users : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(users);
    if (!contract)
      continue;
    assert(op->getNumResults() == 1);
    if (contract.getLhs() == op->getResult(0))
      return "AOp";
    if (contract.getRhs() == op->getResult(0))
      return "BOp";
  }
  return "COp";
}

static LogicalResult
convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
                      llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
  assert(transferReadSupportsMMAMatrixType(op) &&
         "expected convertible operation");

  std::optional<int64_t> stride =
      getStaticallyKnownRowStride(op.getShapedType());
  if (!stride.has_value()) {
    LLVM_DEBUG(DBGS() << "no stride\n");
    return rewriter.notifyMatchFailure(op, "no stride");
  }

  AffineMap map = op.getPermutationMap();
  bool isTranspose = isTransposeMatrixLoadMap(map);

  // Handle broadcast by setting the stride to 0.
  if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
    assert(cstExpr.getValue() == 0);
    stride = 0;
  }

  Value mappingResult = op.getResult();
  auto elType = op.getVectorType().getElementType();
  const char *fragType = inferFragType(op);
  if (op->hasOneUse()) {
    auto *user = *op->user_begin();
    // Infer the signedness of the mma type from the integer extend.
    if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
      elType = IntegerType::get(
          op.getContext(), cast<IntegerType>(elType).getWidth(),
          isa<arith::ExtSIOp>(user) ? IntegerType::Signed
                                    : IntegerType::Unsigned);
      mappingResult = user->getResult(0);
      fragType = inferFragType(user);
    }
  }
  gpu::MMAMatrixType type =
      gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
  Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
      op.getLoc(), type, op.getBase(), op.getIndices(),
      rewriter.getIndexAttr(*stride),
      isTranspose ? rewriter.getUnitAttr() : UnitAttr());
  valueMapping[mappingResult] = load;

  LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
  return success();
}
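
// Illustrative example (schematic IR): a 2-D in-bounds read such as
//
//   %v = vector.transfer_read %m[%i, %j] ...
//          : memref<32x32xf16>, vector<16x16xf16>
//
// is rewritten above into something like
//
//   %v = gpu.subgroup_mma_load_matrix %m[%i, %j] {leadDimension = 32 : index}
//          : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
//
// with the operand tag ("AOp"/"BOp"/"COp") inferred from how the value is
// consumed by the vector.contract.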

static LogicalResult
convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
                       llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  assert(transferWriteSupportsMMAMatrixType(op));
  std::optional<int64_t> stride =
      getStaticallyKnownRowStride(op.getShapedType());
  if (!stride.has_value()) {
    LLVM_DEBUG(DBGS() << "no stride\n");
    return rewriter.notifyMatchFailure(op, "no stride");
  }

  auto it = valueMapping.find(op.getVector());
  if (it == valueMapping.end()) {
    LLVM_DEBUG(DBGS() << "no mapping\n");
    return rewriter.notifyMatchFailure(op, "no mapping");
  }

  Value matrix = it->second;
  auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
      op.getLoc(), matrix, op.getBase(), op.getIndices(),
      rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
  (void)store;

  LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
  return success();
}

/// Returns the vector type which represents a matrix fragment.
static VectorType
getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
  SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
                             regInfo.elementsPerRegister};
  Type elType = regInfo.registerLLVMType;
  if (auto vecType = dyn_cast<VectorType>(elType))
    elType = vecType.getElementType();
  return VectorType::get(shape, elType);
}

/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult
convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo)) {
    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  }

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
  }

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
  if (!dense) {
    LLVM_DEBUG(DBGS() << "not a splat\n");
    return rewriter.notifyMatchFailure(op, "not a splat");
  }

  Value result = rewriter.create<arith::ConstantOp>(
      op.getLoc(), vectorType,
      DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
  valueMapping[op.getResult()] = result;
  return success();
}

/// Check if the loaded matrix operand requires a transposed load.
/// The permutation map of the transfer_read is inspected: if the two result
/// expressions are dim expressions (dm, dn), the operand is considered
/// transposed when m > n.
static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
  mlir::AffineMap map = op.getPermutationMap();

  if (map.getNumResults() != 2) {
    LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
                         "is not a 2d operand\n");
    return failure();
  }

  // Output 2D matrix dimensions in the order of d0, d1.
  mlir::AffineExpr dM = map.getResult(0);
  mlir::AffineExpr dN = map.getResult(1);

  // Find the position of these expressions in the input.
  auto exprM = dyn_cast<AffineDimExpr>(dM);
  auto exprN = dyn_cast<AffineDimExpr>(dN);

  if (!exprM || !exprN) {
    LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
                         "expressions, so the transpose cannot be determined.\n");
    return failure();
  }

  return exprM.getPosition() > exprN.getPosition();
}

static LogicalResult
creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                             llvm::DenseMap<Value, Value> &valueMapping) {
  Location loc = op->getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo)) {
    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  }

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
  }

  FailureOr<bool> transpose = isTransposed(op);
  if (failed(transpose)) {
    LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
    return rewriter.notifyMatchFailure(
        op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
  }

  FailureOr<nvgpu::LdMatrixParams> params =
      nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);

  if (failed(params)) {
    LLVM_DEBUG(
        DBGS()
        << "failed to convert vector.transfer_read to ldmatrix. "
        << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
    return rewriter.notifyMatchFailure(
        op, "failed to convert vector.transfer_read to ldmatrix; this op "
            "likely should not be converted to a nvgpu.ldmatrix call.");
  }

  // Adjust the load offset.
  auto laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
  FailureOr<AffineMap> offsets =
      nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
  if (failed(offsets)) {
    LLVM_DEBUG(DBGS() << "no offsets\n");
    return rewriter.notifyMatchFailure(op, "no offsets");
  }

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);

  SmallVector<Value, 4> indices;
  getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
                                         indices);

  nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
      loc, vectorType, op.getBase(), indices, *transpose, params->numTiles);
  valueMapping[op] = newOp->getResult(0);
  return success();
}

static LogicalResult
createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                       llvm::DenseMap<Value, Value> &valueMapping) {
  Location loc = op.getLoc();

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    return rewriter.notifyMatchFailure(
        op, "Failed to deduce register fragment type during "
            "conversion to distributed non-ldmatrix compatible load");
  }

  Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);

  // This is the individual element type.
  Type loadedElType = regInfo->registerLLVMType;
  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);

  Value fill = rewriter.create<arith::ConstantOp>(
      op.getLoc(), vectorType.getElementType(),
      rewriter.getZeroAttr(vectorType.getElementType()));
  Value result =
      rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);

  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();

  // If we are not transposing, then we can use vectorized loads. Otherwise, we
  // must load each element individually.
  if (!isTransposeLoad) {
    if (!isa<VectorType>(loadedElType)) {
      loadedElType = VectorType::get({1}, loadedElType);
    }

    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
          rewriter, op.getLoc(), *warpMatrixInfo);
      if (failed(coords))
        return rewriter.notifyMatchFailure(op, "no coords");

      Value logicalValueId = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexType(),
          rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
      SmallVector<Value, 4> newIndices;
      getXferIndices<vector::TransferReadOp>(
          rewriter, op, *coords, {laneId, logicalValueId}, newIndices);

      Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
                                                 op.getBase(), newIndices);
      result = rewriter.create<vector::InsertOp>(loc, el, result, i);
    }
  } else {
    if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
      loadedElType = vecType.getElementType();
    }
    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
           innerIdx++) {

        Value logicalValueId = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIndexType(),
            rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
        FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
            rewriter, op.getLoc(), *warpMatrixInfo);
        if (failed(coords))
          return rewriter.notifyMatchFailure(op, "no coords");

        SmallVector<Value, 4> newIndices;
        getXferIndices<vector::TransferReadOp>(
            rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
        Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
                                                   op.getBase(), newIndices);
        result = rewriter.create<vector::InsertOp>(
            loc, el, result, ArrayRef<int64_t>{i, innerIdx});
      }
    }
  }

  valueMapping[op.getResult()] = result;
  return success();
}

/// Return true if this is a shared memory memref type.
static bool isSharedMemory(MemRefType type) {
  auto addressSpace =
      dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
  return addressSpace &&
         addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
}

/// Converts a `vector.transfer_read` operation directly to either a
/// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
/// called when the `vector.transfer_read` can be decomposed in this manner.
static LogicalResult
convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");

  bool isLdMatrixCompatible =
      isSharedMemory(cast<MemRefType>(op.getBase().getType())) &&
      nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;

  VectorType vecTy = op.getVectorType();
  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();

  // When we are transposing the B operand, ldmatrix will only work if we have
  // at least 8 rows to read and the width to read for the transpose is 128
  // bits.
  if (!op.getPermutationMap().isMinorIdentity() &&
      (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
       vecTy.getDimSize(0) * bitWidth < 128))
    isLdMatrixCompatible = false;

  if (!isLdMatrixCompatible)
    return createNonLdMatrixLoads(rewriter, op, valueMapping);

  return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
}
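
// Worked example of the compatibility check above: for a transposed f16 read
// (bitWidth == 16), ldmatrix is used only when at least 8 columns are read
// and the rows cover a full 128 bits. A vector<8x8xf16> transfer qualifies
// (8 * 16 = 128 bits), while a vector<4x8xf16> transfer (4 * 16 = 64 bits)
// falls back to createNonLdMatrixLoads.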

static LogicalResult
convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
                             llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  Location loc = op->getLoc();
  auto it = valueMapping.find(op.getVector());
  if (it == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value matrix = it->second;

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);

  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
    Value logicalValueId = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIndexType(),
        rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
    FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
        rewriter, op.getLoc(), *warpMatrixInfo);
    if (failed(coords))
      return rewriter.notifyMatchFailure(op, "no coords");

    Value el =
        rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
    SmallVector<Value, 4> newIndices;
    getXferIndices<vector::TransferWriteOp>(
        rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
    rewriter.create<vector::StoreOp>(loc, el, op.getBase(), newIndices);
  }

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
  return success();
}

static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
                                       SmallVectorImpl<int64_t> &results) {
  for (auto attr : arrayAttr)
    results.push_back(cast<IntegerAttr>(attr).getInt());
}

static LogicalResult
convertExtractStridedSlice(RewriterBase &rewriter,
                           vector::ExtractStridedSliceOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  Location loc = op->getLoc();

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");

  FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(mmaSyncFragmentInfo))
    return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");

  // Find the vector.transfer_read whose result vector is being sliced.
  auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!transferReadOp)
    return rewriter.notifyMatchFailure(op, "no transfer read");

  warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");

  FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(ldFragmentInfo))
    return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");

  assert(
      (mmaSyncFragmentInfo->elementsPerRegister ==
       ldFragmentInfo->elementsPerRegister) &&
      "Number of elements per register should be same for load and mma.sync");

  // Create vector.extract_strided_slice op for thread-owned fragments.
  std::array<int64_t, 2> strides = {1,
                                    1}; // stride for extract slice is always 1.
  std::array<int64_t, 2> sliceShape = {
      mmaSyncFragmentInfo->numRegistersPerFragment,
      mmaSyncFragmentInfo->elementsPerRegister};
  auto it = valueMapping.find(transferReadOp);
  if (it == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  auto sourceVector = it->second;

  // Offsets and sizes at the warp level of ownership.
  SmallVector<int64_t> offsets;
  populateFromInt64AttrArray(op.getOffsets(), offsets);

  SmallVector<int64_t> sizes;
  populateFromInt64AttrArray(op.getSizes(), sizes);
  ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();

  // Compute offset in vector registers. Note that the mma.sync vector registers
  // are shaped as numberOfFragments x numberOfRegistersPerFragment and can only
  // be sliced along numberOfFragments, i.e. sliceOffset[0].
  std::array<int64_t, 2> sliceOffset = {0, 0};

  if (offsets[0] && offsets[1])
    return op->emitError() << "Slicing fragments in 2D is not supported. ";
  if (offsets[0])
    sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
  else if (offsets[1])
    sliceOffset[0] = (warpVectorShape[1] / offsets[1]);

  Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
      loc, sourceVector, sliceOffset, sliceShape, strides);

  valueMapping[op] = newOp;
  return success();
}

static LogicalResult
convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
                  llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
      op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
      /*b_transpose=*/UnitAttr());
  valueMapping[op.getResult()] = matmul;
  return success();
}

static LogicalResult
convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
  int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
  int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
  Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
      op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
  valueMapping[op.getResult()] = matmul;
  return success();
}
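
// Illustrative example of the shape inference above: with an LHS of type
// vector<16x16xf16> and an RHS of type vector<8x16xf16> (already in the
// transposed (n, k) layout this path expects), the emitted nvgpu.mma.sync
// carries mmaShape = [16, 8, 16]; m and k come from the LHS shape and n from
// the leading RHS dimension.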

/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult
convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
                  llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  assert(constantSupportsMMAMatrixType(op));

  auto splat =
      cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
  auto scalarConstant =
      rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
  const char *fragType = inferFragType(op);
  auto vecType = cast<VectorType>(op.getType());
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
      op.getLoc(), type, scalarConstant);
  valueMapping[op.getResult()] = matrix;
  return success();
}

/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static LogicalResult
convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
                   llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  assert(broadcastSupportsMMAMatrixType(op));

  const char *fragType = inferFragType(op);
  auto vecType = op.getResultVectorType();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
      op.getLoc(), type, op.getSource());
  valueMapping[op.getResult()] = matrix;
  return success();
}

// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
// updated and needs to be updated separately for the loop to be correct.
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
                                               scf::ForOp loop,
                                               ValueRange newInitArgs) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(loop);

  // Create a new loop before the existing one, with the extra operands.
  auto operands = llvm::to_vector<4>(loop.getInitArgs());
  llvm::append_range(operands, newInitArgs);
  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
      operands);
  rewriter.eraseBlock(newLoop.getBody());

  newLoop.getRegion().getBlocks().splice(
      newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
  for (Value operand : newInitArgs)
    newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());

  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                               loop.getNumResults())))
    rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));

  LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
  LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
  LLVM_DEBUG(DBGS() << "erase: " << loop);

  rewriter.eraseOp(loop);
  return newLoop;
}

static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
                                  llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  SmallVector<Value> newOperands;
  SmallVector<std::pair<size_t, size_t>> argMapping;
  for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end()) {
      LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
      continue;
    }
    argMapping.push_back(std::make_pair(
        operand.index(), op.getInitArgs().size() + newOperands.size()));
    newOperands.push_back(it->second);
  }

  scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
  Block &loopBody = *newForOp.getBody();
  for (auto mapping : argMapping) {
    valueMapping[newForOp.getResult(mapping.first)] =
        newForOp.getResult(mapping.second);
    valueMapping[loopBody.getArgument(mapping.first +
                                      newForOp.getNumInductionVars())] =
        loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
  }

  LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
  return success();
}

static LogicalResult
convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
               llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  auto loop = cast<scf::ForOp>(op->getParentOp());
  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
  for (const auto &operand : llvm::enumerate(op.getOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;

    // Replace the yield of the old value with the for op argument to make it
    // easier to remove the dead code.
    yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
    yieldOperands.push_back(it->second);
  }
  rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
  return success();
}

/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static LogicalResult
convertElementwiseOp(RewriterBase &rewriter, Operation *op,
                     gpu::MMAElementwiseOp opType,
                     llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  SmallVector<Value> matrixOperands;
  for (Value operand : op->getOperands()) {
    auto it = valueMapping.find(operand);
    if (it == valueMapping.end())
      return rewriter.notifyMatchFailure(op, "no mapping");
    matrixOperands.push_back(it->second);
  }
  auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
  if (opType == gpu::MMAElementwiseOp::EXTF) {
    // The floating point extension case has a different result type.
    auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
    resultType = gpu::MMAMatrixType::get(resultType.getShape(),
                                         vectorType.getElementType(),
                                         resultType.getOperand());
  }

  Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
      op->getLoc(), resultType, matrixOperands, opType);
  valueMapping[op->getResult(0)] = newOp;
  return success();
}

void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
                                              bool useNvGpu) {
  if (!useNvGpu) {
    patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
        patterns.getContext());
    return;
  }
  vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
  patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
}

LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
                                          Operation *rootOp) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
  llvm::DenseMap<Value, Value> valueMapping;

  auto globalRes = LogicalResult::success();
  for (Operation *op : ops) {
    LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");

    auto res = LogicalResult::success();
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
      res = convertTransferReadOp(rewriter, transferRead, valueMapping);
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
      res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
      res = convertContractOp(rewriter, contractOp, valueMapping);
    } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
      res = convertConstantOp(rewriter, constantOp, valueMapping);
    } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
      res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      res = convertForOp(rewriter, forOp, valueMapping);
    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
      res = convertYieldOp(rewriter, yieldOp, valueMapping);
    } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
      res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
    }
    if (failed(res))
      globalRes = failure();
  }
  return globalRes;
}

LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
                                                          Operation *rootOp) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
  llvm::DenseMap<Value, Value> valueMapping;
  for (Operation *op : ops) {
    if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
            .Case([&](vector::TransferReadOp transferReadOp) {
              return convertTransferReadToLoads(rewriter, transferReadOp,
                                                valueMapping);
            })
            .Case([&](vector::TransferWriteOp transferWriteOp) {
              return convertTransferWriteToStores(rewriter, transferWriteOp,
                                                  valueMapping);
            })
            .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
              return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
                                                valueMapping);
            })
            .Case([&](vector::ContractionOp contractionOp) {
              return convertContractOpToMmaSync(rewriter, contractionOp,
                                                valueMapping);
            })
            .Case([&](scf::ForOp forOp) {
              return convertForOp(rewriter, forOp, valueMapping);
            })
            .Case([&](scf::YieldOp yieldOp) {
              return convertYieldOp(rewriter, yieldOp, valueMapping);
            })
            .Case([&](arith::ConstantOp constOp) {
              return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
            })
            .Default([&](Operation *op) {
              return op->emitError() << "unhandled vector to mma type: " << *op;
            })
            .failed()) {
      return op->emitOpError()
             << "failed to convert op during vector-to-nvgpu conversion";
    }
  }
  return success();
}

namespace {

struct ConvertVectorToGPUPass
    : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {

  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
    useNvGpu.setValue(useNvGpu_);
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();

    IRRewriter rewriter(&getContext());
    if (useNvGpu) {
      if (failed(
              convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
        return signalPassFailure();
      return;
    }
    (void)convertVectorToMMAOps(rewriter, getOperation());
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
}
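
// Usage sketch: the pass created above is typically driven from the command
// line, e.g.
//
//   mlir-opt --convert-vector-to-gpu input.mlir
//   mlir-opt --convert-vector-to-gpu="use-nvgpu=true" input.mlir
//
// or added programmatically with
// pm.addPass(createConvertVectorToGPUPass(/*useNvGpu=*/true)); the option
// spelling follows the pass declaration in the conversion passes tablegen.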