//===- MemRefOps.cpp - MemRef operation utilities -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::memref;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common utility used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && operand.get() != inner &&
        !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}
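// Illustrative use (a sketch added for exposition; not part of the upstream
// file): given
//
//   %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
//   memref.dealloc %0 : memref<?xf32>
//
// an op's fold hook that calls foldMemRefCast rewrites the op in place to
// consume the cast's source directly:
//
//   memref.dealloc %arg : memref<8xf32>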
/// Return an unranked/ranked tensor type for the given unranked/ranked memref
/// type.
Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
  if (auto memref = llvm::dyn_cast<MemRefType>(type))
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}

OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
                                  int64_t dim) {
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  if (memrefType.isDynamicDim(dim))
    return builder.createOrFold<memref::DimOp>(loc, value, dim);

  return builder.getIndexAttr(memrefType.getDimSize(dim));
}

SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
                                                Location loc, Value value) {
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  SmallVector<OpFoldResult> result;
  for (int64_t i = 0; i < memrefType.getRank(); ++i)
    result.push_back(getMixedSize(builder, loc, value, i));
  return result;
}

//===----------------------------------------------------------------------===//
// Utility functions for propagating static information
//===----------------------------------------------------------------------===//

/// Helper function that sets `values[i]` to `constValues[i]` whenever the
/// latter is a static value, as indicated by ShapedType::kDynamic. Otherwise,
/// tries to extract a constant from `values[i]` itself to expose additional
/// folding opportunities.
static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
                                ArrayRef<int64_t> constValues) {
  assert(constValues.size() == values.size() &&
         "incorrect number of const values");
  for (auto [i, cstVal] : llvm::enumerate(constValues)) {
    Builder builder(values[i].getContext());
    if (!ShapedType::isDynamic(cstVal)) {
      // Constant value is known; use it directly.
      values[i] = builder.getIndexAttr(cstVal);
      continue;
    }
    if (std::optional<int64_t> cst = getConstantIntValue(values[i])) {
      // Fall back to a constant found in the IR, if any.
      values[i] = builder.getIndexAttr(*cst);
    }
  }
}
//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

void AllocOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloc");
}

void AllocaOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloca");
}

template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (op.getSymbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.getSymbolOperands().size();

  return success();
}

LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }

LogicalResult AllocaOp::verify() {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(*this);
}
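// Illustrative example (a sketch added for exposition; not part of the
// upstream file): an alloc whose result type has one dynamic dimension must
// take exactly one dynamic-size operand, e.g.
//
//   %d = "test.get_size"() : () -> index
//   %m = memref.alloc(%d) : memref<4x?xf32>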
namespace {
/// Fold constant dimension operands of alloc-like operations into the result
/// type, producing a more static allocation plus a cast back to the original
/// type.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants. If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
          APInt constSizeArg;
          if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
            return false;
          return constSizeArg.isNonNegative();
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands. Collect the non-constant ones
    // and keep track of the resultant memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
      APInt constSizeArg;
      if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
          constSizeArg.isNonNegative()) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constSizeArg.getZExtValue());
      } else {
        // Dynamic shape dimension not folded; copy the operand from the old
        // memref.
        newShapeConstants.push_back(ShapedType::kDynamic);
        dynamicSizes.push_back(dynamicSize);
      }
      dynamicDimPos++;
    }

    // Create the new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
        alloc.getAlignmentAttr());
    // Insert a cast so we have the same type as the old alloc.
    rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
    return success();
  }
};
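// Illustrative rewrite (a sketch added for exposition; not part of the
// upstream file):
//
//   %c8 = arith.constant 8 : index
//   %0 = memref.alloc(%c8) : memref<?xf32>
//
// becomes
//
//   %1 = memref.alloc() : memref<8xf32>
//   %0 = memref.cast %1 : memref<8xf32> to memref<?xf32>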

/// Fold alloc operations with no users or only store and dealloc uses.
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.getValue() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
    return success();
  }
};
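// Illustrative example (a sketch added for exposition; not part of the
// upstream file): an alloc whose only uses are stores into it and a dealloc
// is dead, so
//
//   %m = memref.alloc() : memref<4xf32>
//   memref.store %v, %m[%i] : memref<4xf32>
//   memref.dealloc %m : memref<4xf32>
//
// is removed entirely by SimplifyDeadAlloc.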
} // namespace

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// ReallocOp
//===----------------------------------------------------------------------===//

LogicalResult ReallocOp::verify() {
  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
  MemRefType resultType = getType();

  // The source memref should have identity layout (or none).
  if (!sourceType.getLayout().isIdentity())
    return emitError("unsupported layout for source memref type ")
           << sourceType;

  // The result memref should have identity layout (or none).
  if (!resultType.getLayout().isIdentity())
    return emitError("unsupported layout for result memref type ")
           << resultType;

  // The source memref and the result memref should be in the same memory
  // space.
  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  // The source memref and the result memref should have the same element type.
  if (sourceType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  // Verify that we have the dynamic dimension operand when it is needed.
  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
    return emitError("missing dimension operand for result type ")
           << resultType;
  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
    return emitError("unnecessary dimension operand for result type ")
           << resultType;

  return success();
}

void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
}
//===----------------------------------------------------------------------===//
// AllocaScopeOp
//===----------------------------------------------------------------------===//

void AllocaScopeOp::print(OpAsmPrinter &p) {
  bool printBlockTerminators = false;

  p << ' ';
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    printBlockTerminators = true;
  }
  p << ' ';
  p.printRegion(getBodyRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);
  p.printOptionalAttrDict((*this)->getAttrs());
}

ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create a region for the body.
  result.regions.reserve(1);
  Region *bodyRegion = result.addRegion();

  // Parse the optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the body region.
  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
    return failure();
  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

void AllocaScopeOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  if (!point.isParent()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&getBodyRegion()));
}

/// Given an operation, return whether this op is guaranteed to
/// allocate an AutomaticAllocationScopeResource.
static bool isGuaranteedAutomaticAllocation(Operation *op) {
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!interface)
    return false;
  for (auto res : op->getResults()) {
    if (auto effect =
            interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
      if (isa<SideEffects::AutomaticAllocationScopeResource>(
              effect->getResource()))
        return true;
    }
  }
  return false;
}

/// Given an operation, return whether this op itself could
/// allocate an AutomaticAllocationScopeResource. Note that
/// this will not check the effects of nested ops.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
  // This op itself doesn't create a stack allocation;
  // the inner allocation should be handled separately.
  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
    return false;
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!interface)
    return true;
  for (auto res : op->getResults()) {
    if (auto effect =
            interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
      if (isa<SideEffects::AutomaticAllocationScopeResource>(
              effect->getResource()))
        return true;
    }
  }
  return false;
}

/// Return whether this op is the last non-terminating op
/// in a region. That is to say, it is in a one-block region
/// and is only followed by a terminator. This prevents
/// extending the lifetime of allocations.
static bool lastNonTerminatorInRegion(Operation *op) {
  return op->getNextNode() == op->getBlock()->getTerminator() &&
         llvm::hasSingleElement(op->getParentRegion()->getBlocks());
}

/// Inline an AllocaScopeOp if either the direct parent is an allocation scope
/// or it contains no allocation.
struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocaScopeOp op,
                                PatternRewriter &rewriter) const override {
    bool hasPotentialAlloca =
        op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
            if (alloc == op)
              return WalkResult::advance();
            if (isOpItselfPotentialAutomaticAllocation(alloc))
              return WalkResult::interrupt();
            if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
              return WalkResult::skip();
            return WalkResult::advance();
          }).wasInterrupted();

    // If this contains no potential allocation, it is always legal to
    // inline. Otherwise, consider two conditions:
    if (hasPotentialAlloca) {
      // If the parent isn't an allocation scope, or we are not the last
      // non-terminator op in the parent, we will extend the lifetime.
      if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
        return failure();
      if (!lastNonTerminatorInRegion(op))
        return failure();
    }

    Block *block = &op.getRegion().front();
    Operation *terminator = block->getTerminator();
    ValueRange results = terminator->getOperands();
    rewriter.inlineBlockBefore(block, op);
    rewriter.replaceOp(op, results);
    rewriter.eraseOp(terminator);
    return success();
  }
};
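// Illustrative example (a sketch added for exposition; not part of the
// upstream file): a scope that contains no stack allocation can always be
// inlined, so
//
//   memref.alloca_scope {
//     "test.use"(%x) : (f32) -> ()
//   }
//
// is replaced by its body, with the implicit alloca_scope terminator dropped.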

/// Move allocations into an allocation scope, if it is legal to
/// move them (e.g. their operands are available at the location
/// the op would be moved to).
struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocaScopeOp op,
                                PatternRewriter &rewriter) const override {
    if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
      return failure();

    Operation *lastParentWithoutScope = op->getParentOp();

    if (!lastParentWithoutScope ||
        lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
      return failure();

    // Only apply if this is the last non-terminator op in the block
    // (lest lifetime be extended) of a one-block region.
    if (!lastNonTerminatorInRegion(op) ||
        !lastNonTerminatorInRegion(lastParentWithoutScope))
      return failure();

    while (!lastParentWithoutScope->getParentOp()
                ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
      lastParentWithoutScope = lastParentWithoutScope->getParentOp();
      if (!lastParentWithoutScope ||
          !lastNonTerminatorInRegion(lastParentWithoutScope))
        return failure();
    }
    assert(lastParentWithoutScope->getParentOp()
               ->hasTrait<OpTrait::AutomaticAllocationScope>());

    Region *containingRegion = nullptr;
    for (auto &r : lastParentWithoutScope->getRegions()) {
      if (r.isAncestor(op->getParentRegion())) {
        assert(containingRegion == nullptr &&
               "only one region can contain the op");
        containingRegion = &r;
      }
    }
    assert(containingRegion && "op must be contained in a region");

    SmallVector<Operation *> toHoist;
    op->walk([&](Operation *alloc) {
      if (!isGuaranteedAutomaticAllocation(alloc))
        return WalkResult::skip();

      // If any operand is not defined before the location of
      // lastParentWithoutScope (i.e. where we would hoist to), skip.
      if (llvm::any_of(alloc->getOperands(), [&](Value v) {
            return containingRegion->isAncestor(v.getParentRegion());
          }))
        return WalkResult::skip();
      toHoist.push_back(alloc);
      return WalkResult::advance();
    });

    if (toHoist.empty())
      return failure();
    rewriter.setInsertionPoint(lastParentWithoutScope);
    for (auto *op : toHoist) {
      auto *cloned = rewriter.clone(*op);
      rewriter.replaceOp(op, cloned->getResults());
    }
    return success();
  }
};

void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

LogicalResult AssumeAlignmentOp::verify() {
  if (!llvm::isPowerOf2_32(getAlignment()))
    return emitOpError("alignment must be power of 2");
  return success();
}

void AssumeAlignmentOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "assume_align");
}

OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
  auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
  if (!source)
    return {};
  if (source.getAlignment() != getAlignment())
    return {};
  return getMemref();
}
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "cast");
}

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable memref.cast
/// operations are typically inserted as `view` and `subview` ops are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
///    element type and rank.
/// 2. each of the source's size, offset or stride has more static information
///    than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```mlir
///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType =
      llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
      failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
        return false;
  }

  // If cast is towards a more static offset, don't fold.
  if (sourceOffset != resultOffset)
    if (ShapedType::isDynamic(sourceOffset) &&
        !ShapedType::isDynamic(resultOffset))
      return false;

  // If cast is towards more static strides along any dimension, don't fold.
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
        return false;
  }

  return true;
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
          failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same. They are also compatible if either one of them is dynamic.
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (const auto &aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&
          aDim != bDim)
        return false;
    }
    return true;
  } else {
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported.
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    return aMemSpace == bMemSpace;
  }

  return false;
}

OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}
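// Illustrative examples (a sketch added for exposition; not part of the
// upstream file):
//
//   // Ranked to unranked: compatible (same element type and memory space).
//   %u = memref.cast %r : memref<4xf32> to memref<*xf32>
//   // Erasing static size information: compatible.
//   %d = memref.cast %r : memref<4xf32> to memref<?xf32>
//
// Changing the element type or the memory space is never cast-compatible.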
//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {
/// If the source/target of a CopyOp is a CastOp that does not modify the shape
/// and element type, the cast can be skipped. Such CastOps only cast the layout
/// of the type.
struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;

    // Check source.
    if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
      auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          rewriter.modifyOpInPlace(copyOp, [&] {
            copyOp.getSourceMutable().assign(castOp.getSource());
          });
          modified = true;
        }
      }
    }

    // Check target.
    if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
      auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          rewriter.modifyOpInPlace(copyOp, [&] {
            copyOp.getTargetMutable().assign(castOp.getSource());
          });
          modified = true;
        }
      }
    }

    return success(modified);
  }
};

/// Fold memref.copy(%x, %x).
struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getSource() != copyOp.getTarget())
      return failure();

    rewriter.eraseOp(copyOp);
    return success();
  }
};

/// Erase copies involving a memref with a zero-sized dimension: such copies
/// move no data.
struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  static bool isEmptyMemRef(BaseMemRefType type) {
    return type.hasRank() && llvm::is_contained(type.getShape(), 0);
  }

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (isEmptyMemRef(copyOp.getSource().getType()) ||
        isEmptyMemRef(copyOp.getTarget().getType())) {
      rewriter.eraseOp(copyOp);
      return success();
    }

    return failure();
  }
};
} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
}

LogicalResult CopyOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  /// copy(memrefcast) -> copy
  bool folded = false;
  Operation *op = *this;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "dim");
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

std::optional<int64_t> DimOp::getConstantIndex() {
  return getConstantIntValue(getIndex());
}

Speculation::Speculatability DimOp::getSpeculatability() {
  auto constantIndex = getConstantIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
  if (!rankedSourceType)
    return Speculation::NotSpeculatable;

  if (rankedSourceType.getRank() <= constantIndex)
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}

void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
                                          SetIntLatticeFn setResultRange) {
  setResultRange(getResult(),
                 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
}

/// Return a map with key being elements in `vals` and data being number of
/// occurences of it. Use std::map, since the `vals` here are strides and the
/// dynamic stride value is the same as the tombstone value for
/// DenseMap<int64_t>.
static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}

/// Given the `originalType` and a `reducedType` whose shape is assumed to be
/// a subset of `originalType` with some `1` entries erased, return the set of
/// indices that specifies which of the entries of `originalShape` are dropped
/// to obtain `reducedShape`.
/// This accounts for cases where there are multiple unit-dims, but only a
/// subset of them are dropped. If the mask cannot be determined, return
/// failure.
static FailureOr<llvm::SmallBitVector>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayRef<OpFoldResult> sizes) {
  llvm::SmallBitVector unusedDims(originalType.getRank());
  if (originalType.getRank() == reducedType.getRank())
    return unusedDims;

  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
        unusedDims.set(dim.index());

  // Early exit for the case where the number of unused dims matches the number
  // of ranks reduced.
  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
      originalType.getRank())
    return unusedDims;

  SmallVector<int64_t> originalStrides, candidateStrides;
  int64_t originalOffset, candidateOffset;
  if (failed(
          originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
      failed(
          reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
    return failure();

  // For memrefs, a dimension is truly dropped if its corresponding stride is
  // also dropped. This is particularly important when more than one of the
  // dims is 1. Track the number of occurences of the strides in the original
  // type and the candidate type. For each unused dim that stride should not
  // be present in the candidate type. Note that there could be multiple
  // dimensions that have the same size. We don't need to figure out exactly
  // which dim corresponds to which stride; we just need to verify that the
  // number of repetitions of a stride in the original + number of unused dims
  // with that stride == number of repetitions of that stride in the candidate.
  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
      continue;
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      // This dim can be treated as dropped.
      currUnaccountedStrides[originalStride]--;
      continue;
    }
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      // The stride for this dimension is not actually dropped. Keep the dim.
      unusedDims.reset(dim);
      continue;
    }
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
      // This should never happen: a stride cannot appear in the reduced-rank
      // type that wasn't in the original one.
      return failure();
    }
  }

  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
      originalType.getRank())
    return failure();
  return unusedDims;
}

llvm::SmallBitVector SubViewOp::getDroppedDims() {
  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();
  FailureOr<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
  return *unusedDims;
}
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types (UnrankedMemRefType) is not supported.
  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
  if (!memrefType)
    return {};

  // Out of bound indices produce undefined behavior but are still valid IR.
  // Don't choke on them.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= memrefType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  // Fold dim to the size argument of an `AllocOp`, `AllocaOp`, `ViewOp`, or
  // `SubViewOp`.
  Operation *definingOp = getSource().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    llvm::SmallBitVector unusedDims = subview.getDroppedDims();
    unsigned resultIndex = 0;
    unsigned sourceRank = subview.getSourceType().getRank();
    unsigned sourceIndex = 0;
    for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
      if (unusedDims.test(i))
        continue;
      if (resultIndex == unsignedIndex) {
        sourceIndex = i;
        break;
      }
      resultIndex++;
    }
    assert(subview.isDynamicSize(sourceIndex) &&
           "expected dynamic subview size");
    return subview.getDynamicSize(sourceIndex);
  }

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}

namespace {

/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return rewriter.notifyMatchFailure(
          dim, "Dim op is not defined by a reshape op.");

    // dim of a memref reshape can be folded if dim.getIndex() dominates the
    // reshape. Instead of using `DominanceInfo` (which is usually costly) we
    // cheaply check that either of the following conditions hold:
    //   1. dim.getIndex() is defined in the same block as reshape but before
    //      reshape.
    //   2. dim.getIndex() is defined in a parent block of reshape.

    // Check condition 1.
    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
        if (reshape->isBeforeInBlock(definingOp)) {
          return rewriter.notifyMatchFailure(
              dim,
              "dim.getIndex is not defined before reshape in the same block.");
        }
      } // else dim.getIndex is a block argument to reshape's block and
        // dominates reshape.
    } // Check condition 2.
    else if (dim->getBlock() != reshape->getBlock() &&
             !dim.getIndex().getParentRegion()->isProperAncestor(
                 reshape->getParentRegion())) {
      // If dim and reshape are in the same block but dim.getIndex() isn't, we
      // already know dim.getIndex() dominates reshape without calling
      // `isBeforeInBlock`.
      return rewriter.notifyMatchFailure(
          dim, "dim.getIndex does not dominate reshape.");
    }

    // Place the load directly after the reshape to ensure that the shape
    // memref was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    Value load =
        rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
    if (load.getType() != dim.getType())
      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }
};

} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape>(context);
}
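// Illustrative rewrite (a sketch added for exposition; not part of the
// upstream file):
//
//   %r = memref.reshape %src(%shape)
//       : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
//   %d = memref.dim %r, %c1 : memref<?x?xf32>
//
// canonicalizes to a load of the shape operand:
//
//   %d = memref.load %shape[%c1] : memref<2xindex>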
//===----------------------------------------------------------------------===//
// DmaStartOp
//===----------------------------------------------------------------------===//

void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
  if (stride)
    result.addOperands({stride, elementsPerStride});
}

void DmaStartOp::print(OpAsmPrinter &p) {
  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
    << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
    << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}

// Parse DmaStartOp.
// Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//                       %tag[%index], %stride, %num_elt_per_stride
//                     : memref<3076 x f32, 0>,
//                       memref<1024 x f32, 2>,
//                       memref<1 x i32>
//
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand srcMemRefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
  OpAsmParser::UnresolvedOperand dstMemRefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
  OpAsmParser::UnresolvedOperand numElementsInfo;
  OpAsmParser::UnresolvedOperand tagMemrefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) dma size in number of elements.
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  if (parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}

LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref
  // and the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}

LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DmaWaitOp
//===----------------------------------------------------------------------===//

LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

LogicalResult DmaWaitOp::verify() {
  // Check that the number of tag indices matches the tagMemRef rank.
  unsigned numTagIndices = getTagIndices().size();
  unsigned tagMemRefRank = getTagMemRefRank();
  if (numTagIndices != tagMemRefRank)
    return emitOpError() << "expected tagIndices to have the same number of "
                            "elements as the tagMemRef rank, expected "
                         << tagMemRefRank << ", but got " << numTagIndices;
  return success();
}
//===----------------------------------------------------------------------===//
// ExtractAlignedPointerAsIndexOp
//===----------------------------------------------------------------------===//

void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "intptr");
}

//===----------------------------------------------------------------------===//
// ExtractStridedMetadataOp
//===----------------------------------------------------------------------===//

/// The number and type of the results are inferred from the rank of the
/// source.
LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ExtractStridedMetadataOp::Adaptor adaptor,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();

  unsigned sourceRank = sourceType.getRank();
  IndexType indexType = IndexType::get(context);
  auto memrefType =
      MemRefType::get({}, sourceType.getElementType(),
                      MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
  // Base buffer.
  inferredReturnTypes.push_back(memrefType);
  // Offset.
  inferredReturnTypes.push_back(indexType);
  // Sizes and strides.
  for (unsigned i = 0; i < sourceRank * 2; ++i)
    inferredReturnTypes.push_back(indexType);
  return success();
}

void ExtractStridedMetadataOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getBaseBuffer(), "base_buffer");
  setNameFn(getOffset(), "offset");
  // For the sizes and strides, we cannot provide a single name; only give a
  // hint for the first element of each group.
  if (!getSizes().empty()) {
    setNameFn(getSizes().front(), "sizes");
    setNameFn(getStrides().front(), "strides");
  }
}
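// Illustrative example (a sketch added for exposition; not part of the
// upstream file): for a rank-2 source, the op yields the base buffer, the
// offset, two sizes, and two strides:
//
//   %base, %offset, %sizes:2, %strides:2 =
//       memref.extract_strided_metadata %m
//       : memref<10x?xf32> -> memref<f32>, index, index, index, index, index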
/// Helper function to perform the replacement of all constant uses of `values`
/// by a materialized constant extracted from `maybeConstants`.
/// Returns true if at least one replacement occurred, false otherwise.
template <typename Container>
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
                                  Container values,
                                  ArrayRef<OpFoldResult> maybeConstants) {
  assert(values.size() == maybeConstants.size() &&
         " expected values and maybeConstants of the same size");
  bool atLeastOneReplacement = false;
  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
    // Don't materialize a constant if there are no uses: this would induce
    // infinite loops in the driver.
    if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
      continue;
    assert(isa<Attribute>(maybeConstant) &&
           "The constified value should be either unchanged (i.e., == result) "
           "or a constant");
    Value constantVal = rewriter.create<arith::ConstantIndexOp>(
        loc, llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
    for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
      // Replace all uses of the old result with the materialized constant.
      op->replaceUsesOfWith(result, constantVal);
      atLeastOneReplacement = true;
    }
  }
  return atLeastOneReplacement;
}

LogicalResult
ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  OpBuilder builder(*this);

  bool atLeastOneReplacement = replaceConstantUsesOf(
      builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
      getConstifiedMixedOffset());
  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
                                                 getConstifiedMixedSizes());
  atLeastOneReplacement |= replaceConstantUsesOf(
      builder, getLoc(), getStrides(), getConstifiedMixedStrides());

  return success(atLeastOneReplacement);
}

SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
  constifyIndexValues(values, getSource().getType().getShape());
  return values;
}

SmallVector<OpFoldResult>
ExtractStridedMetadataOp::getConstifiedMixedStrides() {
  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
  SmallVector<int64_t> staticValues;
  int64_t unused;
  LogicalResult status =
      getSource().getType().getStridesAndOffset(staticValues, unused);
  (void)status;
  assert(succeeded(status) && "could not get strides from type");
  constifyIndexValues(values, staticValues);
  return values;
}

OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
  SmallVector<OpFoldResult> values(1, offsetOfr);
  SmallVector<int64_t> staticValues, unused;
  int64_t offset;
  LogicalResult status =
      getSource().getType().getStridesAndOffset(unused, offset);
  (void)status;
  assert(succeeded(status) && "could not get offset from type");
  staticValues.push_back(offset);
  constifyIndexValues(values, staticValues);
  return values[0];
}
//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp
//===----------------------------------------------------------------------===//

void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
                               Value memref, ValueRange ivs) {
  OpBuilder::InsertionGuard g(builder);
  result.addOperands(memref);
  result.addOperands(ivs);

  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
    Type elementType = memrefType.getElementType();
    result.addTypes(elementType);

    Region *bodyRegion = result.addRegion();
    Block *bodyBlock = builder.createBlock(bodyRegion);
    bodyBlock->addArgument(elementType, memref.getLoc());
  }
}

LogicalResult GenericAtomicRMWOp::verify() {
  auto &body = getRegion();
  if (body.getNumArguments() != 1)
    return emitOpError("expected single number of entry block arguments");

  if (getResult().getType() != body.getArgument(0).getType())
    return emitOpError("expected block argument of the same type result type");

  bool hasSideEffects =
      body.walk([&](Operation *nestedOp) {
            if (isMemoryEffectFree(nestedOp))
              return WalkResult::advance();
            nestedOp->emitError(
                "body of 'memref.generic_atomic_rmw' should contain "
                "only operations with no side effects");
            return WalkResult::interrupt();
          })
          .wasInterrupted();
  return hasSideEffects ? failure() : success();
}

ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
  OpAsmParser::UnresolvedOperand memref;
  Type memrefType;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.parseOperand(memref) ||
      parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
      parser.parseColonType(memrefType) ||
      parser.resolveOperand(memref, memrefType, result.operands) ||
      parser.resolveOperands(ivs, indexType, result.operands))
    return failure();

  Region *body = result.addRegion();
  if (parser.parseRegion(*body, {}) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();
  result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
  return success();
}

void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
  p << ' ' << getMemref() << "[" << getIndices()
    << "] : " << getMemref().getType() << ' ';
  p.printRegion(getRegion());
  p.printOptionalAttrDict((*this)->getAttrs());
}

//===----------------------------------------------------------------------===//
// AtomicYieldOp
//===----------------------------------------------------------------------===//

LogicalResult AtomicYieldOp::verify() {
  Type parentType = (*this)->getParentOp()->getResultTypes().front();
  Type resultType = getResult().getType();
  if (parentType != resultType)
    return emitOpError() << "types mismatch between yield op: " << resultType
                         << " and its parent: " << parentType;
  return success();
}
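// Illustrative example (a sketch added for exposition; not part of the
// upstream file): an atomic read-modify-write whose body increments the
// current value and yields the result:
//
//   %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
//   ^bb0(%current_value : f32):
//     %c1 = arith.constant 1.0 : f32
//     %inc = arith.addf %c1, %current_value : f32
//     memref.atomic_yield %inc : f32
//   }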
//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//

static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
                                                   TypeAttr type,
                                                   Attribute initialValue) {
  p << type;
  if (!op.isExternal()) {
    p << " = ";
    if (op.isUninitialized())
      p << "uninitialized";
    else
      p.printAttributeWithoutType(initialValue);
  }
}

static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
                                       Attribute &initialValue) {
  Type type;
  if (parser.parseType(type))
    return failure();

  auto memrefType = llvm::dyn_cast<MemRefType>(type);
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

  if (parser.parseOptionalEqual())
    return success();

  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
    initialValue = UnitAttr::get(parser.getContext());
    return success();
  }

  Type tensorType = getTensorTypeFromMemRefType(memrefType);
  if (parser.parseAttribute(initialValue, tensorType))
    return failure();
  if (!llvm::isa<ElementsAttr>(initialValue))
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
  return success();
}

LogicalResult GlobalOp::verify() {
  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
  if (!memrefType || !memrefType.hasStaticShape())
    return emitOpError("type should be static shaped memref, but got ")
           << getType();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (getInitialValue().has_value()) {
    Attribute initValue = getInitialValue().value();
    if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
      return emitOpError("initial value should be a unit or elements "
                         "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
      Type initType = elementsAttr.getType();
      Type tensorType = getTensorTypeFromMemRefType(memrefType);
      if (initType != tensorType)
        return emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;
    }
  }

  if (std::optional<uint64_t> alignAttr = getAlignment()) {
    uint64_t alignment = *alignAttr;

    if (!llvm::isPowerOf2_64(alignment))
      return emitError() << "alignment attribute value " << alignment
                         << " is not a power of 2";
  }

  // TODO: verify visibility for declarations.
  return success();
}

ElementsAttr GlobalOp::getConstantInitValue() {
  auto initVal = getInitialValue();
  if (getConstant() && initVal.has_value())
    return llvm::cast<ElementsAttr>(initVal.value());
  return {};
}
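// Illustrative examples (a sketch added for exposition; not part of the
// upstream file):
//
//   memref.global "private" constant @c : memref<2xf32> = dense<[0.0, 2.0]>
//   memref.global @z : memref<3xf16> = uninitialized
//
// The initial value, when present, is parsed with the tensor type derived
// from the memref type above.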
//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//

LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is the same as the type of the referenced
  // memref.global op.
  auto global =
      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
  if (!global)
    return emitOpError("'")
           << getName() << "' does not reference a valid global memref";

  Type resultType = getResult().getType();
  if (global.getType() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.getType()
           << " of the global memref @" << getName();
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

LogicalResult LoadOp::verify() {
  if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
    return emitOpError("incorrect number of indices for load, expected ")
           << getMemRefType().getRank() << " but got " << getIndices().size();
  }
  return success();
}

OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
  /// load(memrefcast) -> load
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

//===----------------------------------------------------------------------===//
// MemorySpaceCastOp
//===----------------------------------------------------------------------===//

void MemorySpaceCastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "memspacecast");
}

bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout())
      return false;
    if (aT.getShape() != bT.getShape())
      return false;
    return true;
  }
  if (uaT && ubT) {
    return uaT.getElementType() == ubT.getElementType();
  }
  return false;
}

OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
  // t2)
  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
    getSourceMutable().assign(parentCast.getSource());
    return getResult();
  }
  return Value{};
}

//===----------------------------------------------------------------------===//
// PrefetchOp
//===----------------------------------------------------------------------===//

void PrefetchOp::print(OpAsmPrinter &p) {
  p << " " << getMemref() << '[';
  p.printOperands(getIndices());
  p << ']' << ", " << (getIsWrite() ? "write" : "read");
  p << ", locality<" << getLocalityHint();
  p << ">, " << (getIsDataCache() ? "data" : "instr");
  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
  p << " : " << getMemRefType();
}

ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand memrefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
  IntegerAttr localityHint;
  MemRefType type;
  StringRef readOrWrite, cacheType;

  auto indexTy = parser.getBuilder().getIndexType();
  auto i32Type = parser.getBuilder().getIntegerType(32);
  if (parser.parseOperand(memrefInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
      parser.parseComma() || parser.parseKeyword("locality") ||
      parser.parseLess() ||
      parser.parseAttribute(localityHint, i32Type, "localityHint",
                            result.attributes) ||
      parser.parseGreater() || parser.parseComma() ||
      parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(indexInfo, indexTy, result.operands))
    return failure();

  if (readOrWrite != "read" && readOrWrite != "write")
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
                      parser.getBuilder().getBoolAttr(readOrWrite == "write"));

  if (cacheType != "data" && cacheType != "instr")
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");

  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
                      parser.getBuilder().getBoolAttr(cacheType == "data"));

  return success();
}

LogicalResult PrefetchOp::verify() {
  if (getNumOperands() != 1 + getMemRefType().getRank())
    return emitOpError("too few indices");

  return success();
}

LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  // prefetch(memrefcast) -> prefetch
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
  // Constant fold rank when the rank of the operand is known.
  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()),
                            shapedType.getRank());
  return IntegerAttr();
}
1761
1762
1763
1764
1765
1766 void ReinterpretCastOp::getAsmResultNames(
1768 setNameFn(getResult(), "reinterpret_cast");
1769 }
1770
1771
1772
1773
1775 MemRefType resultType, Value source,
1785 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1789 }
1790
1796 auto sourceType = cast(source.getType());
1803 b.getContext(), staticOffsets.front(), staticStrides);
1804 auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
1805 stridedLayout, sourceType.getMemorySpace());
1806 build(b, result, resultType, source, offset, sizes, strides, attrs);
1807 }
1808
1810 MemRefType resultType, Value source,
1815 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1817 }));
1819 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1821 }));
1822 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1823 strideValues, attrs);
1824 }
1825
1827 MemRefType resultType, Value source, Value offset,
1831 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1833 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1834 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1835 }
1836
1837
1838
1840
1841 auto srcType = llvm::cast(getSource().getType());
1842 auto resultType = llvm::cast(getType());
1843 if (srcType.getMemorySpace() != resultType.getMemorySpace())
1844 return emitError("different memory spaces specified for source type ")
1845 << srcType << " and result memref type " << resultType;
1846 if (srcType.getElementType() != resultType.getElementType())
1847 return emitError("different element types specified for source type ")
1848 << srcType << " and result memref type " << resultType;
1849
1850
1851 for (auto [idx, resultSize, expectedSize] :
1852 llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1853 if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
1854 return emitError("expected result type with size = ")
1855 << (ShapedType::isDynamic(expectedSize)
1856 ? std::string("dynamic")
1857 : std::to_string(expectedSize))
1858 << " instead of " << resultSize << " in dim = " << idx;
1859 }
1860
1861
1862
1863
1864 int64_t resultOffset;
1866 if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
1867 return emitError("expected result type to have strided layout but found ")
1868 << resultType;
1869
1870
1871 int64_t expectedOffset = getStaticOffsets().front();
1872 if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset)
1873 return emitError("expected result type with offset = ")
1874 << (ShapedType::isDynamic(expectedOffset)
1875 ? std::string("dynamic")
1876 : std::to_string(expectedOffset))
1877 << " instead of " << resultOffset;
1878
1879
1880 for (auto [idx, resultStride, expectedStride] :
1882 if (!ShapedType::isDynamic(resultStride) && resultStride != expectedStride)
1883 return emitError("expected result type with stride = ")
1884 << (ShapedType::isDynamic(expectedStride)
1885 ? std::string("dynamic")
1886 : std::to_string(expectedStride))
1887 << " instead of " << resultStride << " in dim = " << idx;
1888 }
1889
1890 return success();
1891 }
1892
1893 OpFoldResult ReinterpretCastOp::fold(FoldAdaptor ) {
1894 Value src = getSource();
1895 auto getPrevSrc = [&]() -> Value {
1896
1897 if (auto prev = src.getDefiningOp())
1898 return prev.getSource();
1899
1900
1902 return prev.getSource();
1903
1904
1905
1907 if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger))
1908 return prev.getSource();
1909
1910 return nullptr;
1911 };
1912
1913 if (auto prevSrc = getPrevSrc()) {
1914 getSourceMutable().assign(prevSrc);
1915 return getResult();
1916 }
1917
1918
1919 if (!ShapedType::isDynamicShape(getType().getShape()) &&
1920 src.getType() == getType() && getStaticOffsets().front() == 0) {
1921 return src;
1922 }
1923
1924 return nullptr;
1925 }
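// Illustrative fold (a sketch added for exposition; not part of the upstream
// file): because a reinterpret_cast fully re-specifies the resulting view, it
// only needs the base buffer of its source, so
//
//   %a = memref.reinterpret_cast %src to offset: [2], sizes: [4], strides: [1] ...
//   %b = memref.reinterpret_cast %a to offset: [0], sizes: [8], strides: [1] ...
//
// folds so that %b consumes %src directly.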
SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
  SmallVector<OpFoldResult> values = getMixedSizes();
  constifyIndexValues(values, getType().getShape());
  return values;
}

SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
  SmallVector<OpFoldResult> values = getMixedStrides();
  SmallVector<int64_t> staticValues;
  int64_t unused;
  LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
  (void)status;
  assert(succeeded(status) && "could not get strides from type");
  constifyIndexValues(values, staticValues);
  return values;
}

OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
  SmallVector<OpFoldResult> values = getMixedOffsets();
  assert(values.size() == 1 &&
         "reinterpret_cast must have one and only one offset");
  SmallVector<int64_t> staticValues, unused;
  int64_t offset;
  LogicalResult status = getType().getStridesAndOffset(unused, offset);
  (void)status;
  assert(succeeded(status) && "could not get offset from type");
  staticValues.push_back(offset);
  constifyIndexValues(values, staticValues);
  return values[0];
}

namespace {
/// Replace the sequence:
/// ```
/// base, offset, sizes, strides = extract_strided_metadata src
/// dst = reinterpret_cast base to offset, sizes, strides
/// ```
/// with
/// ```
/// dst = memref.cast src
/// ```
///
/// Note: The cast operation is only inserted when the type of dst and src
/// are not the same.
///
/// This pattern also matches when the offset, sizes, and strides don't come
/// directly from the `extract_strided_metadata`'s results but it can be
/// statically proven that they would hold the same values. For instance,
/// `reinterpret_cast base to 0, [3, 4], strides` matches when the source type
/// is `memref<3x4xty>`: the type guarantees that `offset` is 0 and `sizes`
/// are [3, 4]. The same reasoning applies to constants materialized in the
/// IR (e.g., `arith.constant 0`).
struct ReinterpretCastOpExtractStridedMetadataFolder
    : public OpRewritePattern<ReinterpretCastOp> {
public:
  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                PatternRewriter &rewriter) const override {
    auto extractStridedMetadata =
        op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
    if (!extractStridedMetadata)
      return failure();

    // Check if the reinterpret cast reconstructs a memref with the exact same
    // properties as the extract strided metadata.
    auto isReinterpretCastNoop = [&]() -> bool {
      // First, check that the strides are the same.
      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
                       op.getConstifiedMixedStrides()))
        return false;

      // Second, check the sizes.
      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
                       op.getConstifiedMixedSizes()))
        return false;

      // Finally, check the offset.
      assert(op.getMixedOffsets().size() == 1 &&
             "reinterpret_cast with more than one offset should have been "
             "rejected by the verifier");
      return extractStridedMetadata.getConstifiedMixedOffset() ==
             op.getConstifiedMixedOffset();
    };

    if (!isReinterpretCastNoop()) {
      // If the extract_strided_metadata / reinterpret_cast pair is not a
      // noop, we can still simplify the IR: the reinterpret_cast is a
      // complete description of the final memref, so the only thing it needs
      // from its source is the base buffer, and the base buffer of the
      // extract_strided_metadata's source is the same. Fold that source
      // directly into the reinterpret_cast.
      rewriter.modifyOpInPlace(op, [&]() {
        op.getSourceMutable().assign(extractStridedMetadata.getSource());
      });
      return success();
    }

    // At this point, we know that the back and forth between extract strided
    // metadata and reinterpret cast is a noop. However, the final type of the
    // reinterpret cast may not be the same as the type of the source. Insert
    // a cast if necessary.
    Type srcTy = extractStridedMetadata.getSource().getType();
    if (srcTy == op.getResult().getType())
      rewriter.replaceOp(op, extractStridedMetadata.getSource());
    else
      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
                                          extractStridedMetadata.getSource());

    return success();
  }
};
} // namespace

void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
}
//===----------------------------------------------------------------------===//
// Reassociative reshape ops
//===----------------------------------------------------------------------===//

void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapse_shape");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expand_shape");
}

LogicalResult ExpandShapeOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
  reifiedResultShapes = {
      getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
  return success();
}

/// Helper function for verifying the shape of ExpandShapeOp and CollapseShapeOp
/// result and operand. Layout maps are verified separately.
///
/// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
/// allowed in a reassociation group.
static LogicalResult
verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
                     ArrayRef<int64_t> expandedShape,
                     ArrayRef<ReassociationIndices> reassociation,
                     bool allowMultipleDynamicDimsPerGroup) {
  // There must be one reassociation group per collapsed dimension.
  if (collapsedShape.size() != reassociation.size())
    return op->emitOpError("invalid number of reassociation groups: found ")
           << reassociation.size() << ", expected " << collapsedShape.size();

  // The next expected expanded dimension index (while iterating over
  // reassociation indices).
  int64_t nextDim = 0;
  for (const auto &it : llvm::enumerate(reassociation)) {
    ReassociationIndices group = it.value();
    int64_t collapsedDim = it.index();

    bool foundDynamic = false;
    for (int64_t expandedDim : group) {
      if (expandedDim != nextDim++)
        return op->emitOpError("reassociation indices must be contiguous");

      if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
        return op->emitOpError("reassociation index ")
               << expandedDim << " is out of bounds";

      // Check if there are multiple dynamic dims in a reassociation group.
      if (ShapedType::isDynamic(expandedShape[expandedDim])) {
        if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
          return op->emitOpError(
              "at most one dimension in a reassociation group may be dynamic");
        foundDynamic = true;
      }
    }

    // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
    if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
      return op->emitOpError("collapsed dim (")
             << collapsedDim
             << ") must be dynamic if and only if reassociation group is "
                "dynamic";

    // If all dims in the reassociation group are static, the size of the
    // collapsed dim can be verified.
    if (!foundDynamic) {
      int64_t groupSize = 1;
      for (int64_t expandedDim : group)
        groupSize *= expandedShape[expandedDim];
      if (groupSize != collapsedShape[collapsedDim])
        return op->emitOpError("collapsed dim size (")
               << collapsedShape[collapsedDim]
               << ") must equal reassociation group size (" << groupSize << ")";
    }
  }

  if (collapsedShape.empty()) {
    // Rank 0: all expanded dimensions must be 1.
    for (int64_t d : expandedShape)
      if (d != 1)
        return op->emitOpError(
            "rank 0 memrefs can only be extended/collapsed with/from ones");
  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
    // Rank >= 1: the number of dimensions among all reassociation groups must
    // match the number of expanded dimensions.
    return op->emitOpError("expanded rank (")
           << expandedShape.size()
           << ") inconsistent with number of reassociation indices (" << nextDim
           << ")";
  }

  return success();
}
SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}

SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}

SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

/// Compute the layout map after expanding a given source MemRef type with the
/// specified reassociation indices.
static FailureOr<StridedLayoutAttr>
computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
                         ArrayRef<ReassociationIndices> reassociation) {
  int64_t srcOffset;
  SmallVector<int64_t> srcStrides;
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();
  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");

  // 1-1 mapping between srcStrides and reassociation packs.
  // Each srcStride starts with the given value and gets expanded according to
  // the proper entries in resultShape.
  // Example:
  //   srcStrides     =                   [10000,  1 ,    100   ],
  //   reassociations =                   [  [0], [1], [2, 3, 4]],
  //   resultSizes    = [2, 5, 4, 3, 2] = [  [2], [5], [4, 3, 2]]
  //     -> resultStrides = [10000, 1, 600, 200, 100]
  // Note that a stride does not get expanded along the first entry of each
  // shape pack.
  SmallVector<int64_t> reverseResultStrides;
  reverseResultStrides.reserve(resultShape.size());
  unsigned shapeIndex = resultShape.size() - 1;
  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
    ReassociationIndices reassoc = std::get<0>(it);
    int64_t currentStrideToExpand = std::get<1>(it);
    for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
      reverseResultStrides.push_back(currentStrideToExpand);
      currentStrideToExpand =
          (SaturatedInteger::wrap(currentStrideToExpand) *
           SaturatedInteger::wrap(resultShape[shapeIndex--]))
              .asInteger();
    }
  }
  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
  resultStrides.resize(resultShape.size(), 1);
  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
}

FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
    MemRefType srcType, ArrayRef<int64_t> resultShape,
    ArrayRef<ReassociationIndices> reassociation) {
  if (srcType.getLayout().isIdentity()) {
    // If the source is contiguous (i.e., no layout map specified), so is the
    // result.
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(resultShape, srcType.getElementType(), layout,
                           srcType.getMemorySpace());
  }

  // Source may not be contiguous. Compute the layout map.
  FailureOr<StridedLayoutAttr> computedLayout =
      computeExpandedLayoutMap(srcType, resultShape, reassociation);
  if (failed(computedLayout))
    return failure();
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());
}

FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &builder, Location loc,
                                MemRefType expandedType,
                                ArrayRef<ReassociationIndices> reassociation,
                                ArrayRef<OpFoldResult> inputShape) {
  std::optional<SmallVector<OpFoldResult>> outputShape =
      inferExpandShapeOutputShape(builder, loc, expandedType, reassociation,
                                  inputShape);
  if (!outputShape)
    return failure();
  return *outputShape;
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) {
  auto [staticOutputShape, dynamicOutputShape] =
      decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
  build(builder, result, llvm::cast<MemRefType>(resultType), src,
        getReassociationIndicesAttribute(builder, reassociation),
        dynamicOutputShape, staticOutputShape);
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> inputShape) {
  MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
      builder, result.location, memrefResultTy, reassociation, inputShape);
  // Failure of this assertion usually indicates presence of multiple
  // dynamic dimensions in the same reassociation group.
  assert(succeeded(outputShape) && "unable to infer output shape");
  build(builder, result, memrefResultTy, src, reassociation, *outputShape);
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          ArrayRef<int64_t> resultShape, Value src,
                          ArrayRef<ReassociationIndices> reassociation) {
  // Only ranked memref source values are supported.
  auto srcType = llvm::cast<MemRefType>(src.getType());
  FailureOr<MemRefType> resultType =
      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
  // Failure of this assertion usually indicates a problem with the source
  // type, e.g., could not get strides/offset.
  assert(succeeded(resultType) && "could not compute layout");
  build(builder, result, *resultType, src, reassociation);
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          ArrayRef<int64_t> resultShape, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) {
  // Only ranked memref source values are supported.
  auto srcType = llvm::cast<MemRefType>(src.getType());
  FailureOr<MemRefType> resultType =
      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
  // Failure of this assertion usually indicates a problem with the source
  // type, e.g., could not get strides/offset.
  assert(succeeded(resultType) && "could not compute layout");
  build(builder, result, *resultType, src, reassociation, outputShape);
}

LogicalResult ExpandShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  if (srcType.getRank() > resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
    return emitOpError("has source rank ")
           << r0 << " and result rank " << r1 << ". This is not an expansion ("
           << r0 << " > " << r1 << ").";
  }

  // Verify result shape.
  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
                                  resultType.getShape(),
                                  getReassociationIndices(),
                                  /*allowMultipleDynamicDimsPerGroup=*/true)))
    return failure();

  // Compute expected result type (including layout map).
  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
      srcType, resultType.getShape(), getReassociationIndices());
  if (failed(expectedResultType))
    return emitOpError("invalid source layout map");

  // Check actual result type.
  if (*expectedResultType != resultType)
    return emitOpError("expected expanded type to be ")
           << *expectedResultType << " but found " << resultType;

  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape bounds to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()
           << " values";

  // Verify if provided output shapes are in agreement with output type.
  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
  ArrayRef<int64_t> resShape = getResult().getType().getShape();
  for (auto [pos, shape] : llvm::enumerate(resShape)) {
    if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos]) {
      return emitOpError("invalid output shape provided at pos ") << pos;
    }
  }

  return success();
}

void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
}
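// Illustrative example (a sketch added for exposition; not part of the
// upstream file): expanding the first dimension of a 2-D memref into two
// dimensions, with the dynamic extent supplied via output_shape:
//
//   %r = memref.expand_shape %m [[0, 1], [2]] output_shape [%d, 4, 5]
//       : memref<?x5xf32> into memref<?x4x5xf32>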
/// Compute the layout map after collapsing a given source MemRef type with the
/// specified reassociation indices.
///
/// Note: All collapsed dims in a reassociation group must be contiguous. It is
/// not always possible to check this statically by inspecting a MemRefType. If
/// non-contiguity cannot be checked statically, the collapse is assumed to be
/// valid (and thus accepted by this function) unless `strict = true`.
static FailureOr<StridedLayoutAttr>
computeCollapsedLayoutMap(MemRefType srcType,
                          ArrayRef<ReassociationIndices> reassociation,
                          bool strict = false) {
  int64_t srcOffset;
  SmallVector<int64_t> srcStrides;
  auto srcShape = srcType.getShape();
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();

  // The result stride of a reassociation group is the stride of the last
  // entry of the reassociation. Dimensions of size 1 should be skipped,
  // because their strides are meaningless and could have any arbitrary value.
  SmallVector<int64_t> resultStrides;
  resultStrides.reserve(reassociation.size());
  for (const ReassociationIndices &reassoc : reassociation) {
    ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
      resultStrides.push_back(srcStrides[ref.back()]);
    } else {
      // Dynamically-sized dims may turn out to be dims of size 1 at runtime,
      // so the corresponding stride may have to be skipped. (See above
      // comment.) Therefore, the result stride cannot be statically
      // determined and must be dynamic.
      resultStrides.push_back(ShapedType::kDynamic);
    }
  }

  // Validate that each reassociation group is contiguous.
  unsigned resultStrideIndex = resultStrides.size() - 1;
  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
    auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
    auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
    for (int64_t idx : llvm::reverse(trailingReassocs)) {
      stride = stride * SaturatedInteger::wrap(srcShape[idx]);

      // Both the source and result stride must have the same static value. In
      // that case, we can be sure that the dimensions are collapsible (because
      // they are contiguous).
      // If `strict = false` (default during op verification), we accept cases
      // where one or both strides are dynamic. This is best-effort: we reject
      // ops where obviously non-contiguous dims are collapsed, but accept ops
      // where we cannot be sure statically. Such ops may fail at runtime. See
      // the op documentation for details.
      auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
      if (strict && (stride.saturated || srcStride.saturated))
        return failure();

      // Dimensions of size 1 should be skipped, because their strides are
      // meaningless and could have any arbitrary value.
      if (srcShape[idx - 1] == 1)
        continue;

      if (!stride.saturated && !srcStride.saturated && stride != srcStride)
        return failure();
    }
  }
  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
}

bool CollapseShapeOp::isGuaranteedCollapsible(
    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
  // MemRefs with identity layout are always collapsible.
  if (srcType.getLayout().isIdentity())
    return true;

  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
                                             /*strict=*/true));
}

MemRefType CollapseShapeOp::computeCollapsedType(
    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<int64_t> resultShape;
  resultShape.reserve(reassociation.size());
  for (const ReassociationIndices &group : reassociation) {
    auto groupSize = SaturatedInteger::wrap(1);
    for (int64_t srcDim : group)
      groupSize =
          groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
    resultShape.push_back(groupSize.asInteger());
  }

  if (srcType.getLayout().isIdentity()) {
    // If the source is contiguous (i.e., no layout map specified), so is the
    // result.
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(resultShape, srcType.getElementType(), layout,
                           srcType.getMemorySpace());
  }

  // Source may not be fully contiguous. Compute the layout map.
  // Note: Dimensions that are collapsed into a single dim are assumed to be
  // contiguous.
  FailureOr<StridedLayoutAttr> computedLayout =
      computeCollapsedLayoutMap(srcType, reassociation);
  assert(succeeded(computedLayout) &&
         "invalid source layout map or collapsing non-contiguous dims");
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());
}

void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<NamedAttribute> attrs) {
  auto srcType = llvm::cast<MemRefType>(src.getType());
  MemRefType resultType =
      CollapseShapeOp::computeCollapsedType(srcType, reassociation);
  result.addAttribute(::mlir::getReassociationAttrName(),
                      getReassociationIndicesAttribute(b, reassociation));
  build(b, result, resultType, src, attrs);
}

LogicalResult CollapseShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  if (srcType.getRank() < resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
    return emitOpError("has source rank ")
           << r0 << " and result rank " << r1 << ". This is not a collapse ("
           << r0 << " < " << r1 << ").";
  }

  // Verify result shape.
  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
                                  srcType.getShape(), getReassociationIndices(),
                                  /*allowMultipleDynamicDimsPerGroup=*/true)))
    return failure();

  // Compute expected result type (including layout map).
  MemRefType expectedResultType;
  if (srcType.getLayout().isIdentity()) {
    // If the source is contiguous (i.e., no layout map specified), so is the
    // result.
    MemRefLayoutAttrInterface layout;
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
                        srcType.getMemorySpace());
  } else {
    // Source may not be fully contiguous. Compute the layout map.
    // Note: Dimensions that are collapsed into a single dim are assumed to be
    // contiguous.
    FailureOr<StridedLayoutAttr> computedLayout =
        computeCollapsedLayoutMap(srcType, getReassociationIndices());
    if (failed(computedLayout))
      return emitOpError(
          "invalid source layout map or collapsing non-contiguous dims");
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(),
                        *computedLayout, srcType.getMemorySpace());
  }

  if (expectedResultType != resultType)
    return emitOpError("expected collapsed type to be ")
           << expectedResultType << " but found " << resultType;

  return success();
}

/// Fold a collapse_shape of a cast into the collapse_shape when the cast only
/// relaxes static type information.
class CollapseShapeOpMemRefCastFolder
    : public OpRewritePattern<CollapseShapeOp> {
public:
  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CollapseShapeOp op,
                                PatternRewriter &rewriter) const override {
    auto cast = op.getOperand().getDefiningOp<CastOp>();
    if (!cast)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(cast))
      return failure();

    Type newResultType = CollapseShapeOp::computeCollapsedType(
        llvm::cast<MemRefType>(cast.getOperand().getType()),
        op.getReassociationIndices());

    if (newResultType == op.getResultType()) {
      rewriter.modifyOpInPlace(
          op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
    } else {
      Value newOp = rewriter.create<CollapseShapeOp>(
          op->getLoc(), cast.getSource(), op.getReassociationIndices());
      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
    }
    return success();
  }
};

void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
                                memref::DimOp, MemRefType>,
      CollapseShapeOpMemRefCastFolder>(context);
}

OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}

LogicalResult ReshapeOp::verify() {
  Type operandType = getSource().getType();
  Type resultType = getResult().getType();

  Type operandElementType =
      llvm::cast<ShapedType>(operandType).getElementType();
  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
  if (operandElementType != resultElementType)
    return emitOpError("element types of source and destination memref "
                       "types should be the same");

  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
    if (!operandMemRefType.getLayout().isIdentity())
      return emitOpError("source memref type should have identity affine map");

  int64_t shapeSize =
      llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
  if (resultMemRefType) {
    if (!resultMemRefType.getLayout().isIdentity())
      return emitOpError("result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamic)
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's memref rank");
  }
  return success();
}
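// Illustrative example (a sketch added for exposition; not part of the
// upstream file): reshaping a statically-shaped memref via a shape operand
// whose static length matches the result rank:
//
//   %dst = memref.reshape %src(%shape)
//       : (memref<4x5xf32>, memref<1xi32>) -> memref<20xf32>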
//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//

LogicalResult StoreOp::verify() {
  if (getNumOperands() != 2 + getMemRefType().getRank())
    return emitOpError("store index operand count not equal to memref rank");

  return success();
}

LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  return foldMemRefCast(*this, getValueToStore());
}

//===----------------------------------------------------------------------===//
// SubViewOp
//===----------------------------------------------------------------------===//

void SubViewOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "subview");
}

/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                      ArrayRef<int64_t> staticOffsets,
                                      ArrayRef<int64_t> staticSizes,
                                      ArrayRef<int64_t> staticStrides) {
  unsigned rank = sourceMemRefType.getRank();
  (void)rank;
  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
  assert(staticSizes.size() == rank && "staticSizes length mismatch");
  assert(staticStrides.size() == rank && "staticStrides length mismatch");

  // Extract source offset and strides.
  auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();

  // Compute target offset whose value is:
  //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
    targetOffset = (SaturatedInteger::wrap(targetOffset) +
                    SaturatedInteger::wrap(staticOffset) *
                        SaturatedInteger::wrap(sourceStride))
                       .asInteger();
  }

  // Compute target stride whose value is:
  //   `sourceStrides_i * staticStrides_i`.
  SmallVector<int64_t, 4> targetStrides;
  targetStrides.reserve(staticOffsets.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
    targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
                             SaturatedInteger::wrap(staticStride))
                                .asInteger());
  }

  // The type is now known.
  return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
                         StridedLayoutAttr::get(sourceMemRefType.getContext(),
                                                targetOffset, targetStrides),
                         sourceMemRefType.getMemorySpace());
}
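// Worked example (a sketch added for exposition; not part of the upstream
// file): for a source type memref<8x16xf32> (identity strides [16, 1], offset
// 0), offsets [2, 4], sizes [4, 4], and strides [1, 2] give:
//
//   targetOffset  = 0 + 2 * 16 + 4 * 1 = 36
//   targetStrides = [16 * 1, 1 * 2]    = [16, 2]
//
// so the inferred result type is
//   memref<4x4xf32, strided<[16, 2], offset: 36>>.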
2715
2716 MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2726 return {};
2728 return {};
2730 return {};
2731 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2732 staticSizes, staticStrides);
2733 }
2734
2735 MemRefType SubViewOp::inferRankReducedResultType(
2736 ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2739 MemRefType inferredType =
2740 inferResultType(sourceRankedTensorType, offsets, sizes, strides);
2741 assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2742 "expected ");
2743 if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2744 return inferredType;
2745
2746
2747 std::optional<llvm::SmallDenseSet> dimsToProject =
2749 assert(dimsToProject.has_value() && "invalid rank reduction");
2750
2751
2752 auto inferredLayout = llvm::cast(inferredType.getLayout());
2754 rankReducedStrides.reserve(resultShape.size());
2755 for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2756 if (!dimsToProject->contains(idx))
2757 rankReducedStrides.push_back(value);
2758 }
2759 return MemRefType::get(resultShape, inferredType.getElementType(),
2761 inferredLayout.getOffset(),
2762 rankReducedStrides),
2763 inferredType.getMemorySpace());
2764 }
2765
2766 MemRefType SubViewOp::inferRankReducedResultType(
2767 ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2775 return SubViewOp::inferRankReducedResultType(
2776 resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2777 staticStrides);
2778 }
2779
2780
2781
2783 MemRefType resultType, Value source,
2793 auto sourceMemRefType = llvm::cast(source.getType());
2794
2795 if (!resultType) {
2796 resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2797 staticSizes, staticStrides);
2798 }
2800 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2804 }
2805
2806
2807
2813 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2814 }

// Build a SubViewOp with static entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
}

// Build a SubViewOp with static entries and custom result type. If the type
// passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

// Build a SubViewOp with dynamic entries and custom result type. If the
// type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source, ValueRange offsets,
                      ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
}

// Build a SubViewOp with dynamic entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ValueRange offsets, ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

Value SubViewOp::getViewSource() { return getSource(); }

/// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
/// static value).
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
  int64_t t1Offset, t2Offset;
  SmallVector<int64_t> t1Strides, t2Strides;
  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
}

/// Return true if `t1` and `t2` have equal strides (both dynamic or of same
/// static value). Dimensions of `t1` may be dropped in `t2`; these must be
/// marked as dropped in `droppedDims`.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
                                  const llvm::SmallBitVector &droppedDims) {
  assert(size_t(t1.getRank()) == droppedDims.size() &&
         "incorrect number of bits");
  assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
         "incorrect number of dropped dims");
  int64_t t1Offset, t2Offset;
  SmallVector<int64_t> t1Strides, t2Strides;
  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
  if (failed(res1) || failed(res2))
    return false;
  for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
    if (droppedDims[i])
      continue;
    if (t1Strides[i] != t2Strides[j])
      return false;
    ++j;
  }
  return true;
}

static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
                                            Operation *op, Type expectedType) {
  auto memrefType = llvm::cast<ShapedType>(expectedType);
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op->emitError("expected result rank to be smaller or equal to ")
           << "the source rank. ";
  case SliceVerificationResult::SizeMismatch:
    return op->emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) ";
  case SliceVerificationResult::ElemTypeMismatch:
    return op->emitError("expected result element type to be ")
           << memrefType.getElementType();
  case SliceVerificationResult::MemSpaceMismatch:
    return op->emitError("expected result and source memory spaces to match.");
  case SliceVerificationResult::LayoutMismatch:
    return op->emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result layout) ";
  }
  llvm_unreachable("unexpected subview verification result");
}

/// Verifier for SubViewOp.
LogicalResult SubViewOp::verify() {
  MemRefType baseType = getSourceType();
  MemRefType subViewType = getType();
  ArrayRef<int64_t> staticOffsets = getStaticOffsets();
  ArrayRef<int64_t> staticSizes = getStaticSizes();
  ArrayRef<int64_t> staticStrides = getStaticStrides();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!baseType.isStrided())
    return emitError("base type ") << baseType << " is not strided";

  // Compute the expected result type, assuming that there are no rank
  // reductions.
  MemRefType expectedType = SubViewOp::inferResultType(
      baseType, staticOffsets, staticSizes, staticStrides);

  // Verify result type against the inferred type, taking potential rank
  // reductions into account.
  auto shapedTypeVerification = isRankReducedType(
      /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
  if (shapedTypeVerification != SliceVerificationResult::Success)
    return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);

  // Make sure that the memory space did not change.
  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
    return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
                                  *this, expectedType);

  // Verify the offset of the layout map.
  if (!haveCompatibleOffsets(expectedType, subViewType))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  // Compute the dropped dimensions for the rank-reducing case; failure to
  // compute a mask means the sizes do not line up.
  FailureOr<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(expectedType, subViewType,
                                     getMixedSizes());
  if (failed(unusedDims))
    return produceSubViewErrorMsg(SliceVerificationResult::SizeMismatch,
                                  *this, expectedType);

  // Strides must match.
  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                  *this, expectedType);

  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
  // to the base memref.
  SliceBoundsVerificationResult boundsResult =
      verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
                          staticStrides, /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}

raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}

/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size =
        op.isDynamicSize(idx)
            ? op.getDynamicSize(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}

/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
/// to deduce the result type for the given `sourceType`. Additionally, reduce
/// the rank of the inferred result type if `currentResultType` is lower rank
/// than `currentSourceType`. Use this signature if `sourceType` is updated
/// together with the result type. In this case, it is important to compute
/// the dropped dimensions using `currentSourceType` whose strides align with
/// `currentResultType`.
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType currentSourceType,
    MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
  MemRefType nonRankReducedType = SubViewOp::inferResultType(
      sourceType, mixedOffsets, mixedSizes, mixedStrides);
  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
      currentSourceType, currentResultType, mixedSizes);
  if (failed(unusedDims))
    return nullptr;

  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
  SmallVector<int64_t> shape, strides;
  unsigned numDimsAfterReduction =
      nonRankReducedType.getRank() - unusedDims->count();
  shape.reserve(numDimsAfterReduction);
  strides.reserve(numDimsAfterReduction);
  for (const auto &[idx, size, stride] :
       llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
                 nonRankReducedType.getShape(), layout.getStrides())) {
    if (unusedDims->test(idx))
      continue;
    shape.push_back(size);
    strides.push_back(stride);
  }

  return MemRefType::get(shape, nonRankReducedType.getElementType(),
                         StridedLayoutAttr::get(sourceType.getContext(),
                                                layout.getOffset(), strides),
                         nonRankReducedType.getMemorySpace());
}

Value mlir::memref::createCanonicalRankReducingSubViewOp(
    OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  unsigned rank = memrefType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  MemRefType targetType = SubViewOp::inferRankReducedResultType(
      targetShape, memrefType, offsets, sizes, strides);
  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
                                           sizes, strides);
}
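
// Illustrative use (%m and the shapes are hypothetical): with targetShape =
// [16] and a memref<1x16xf32> source, this builds
//   %sv = memref.subview %m[0, 0] [1, 16] [1, 1]
//       : memref<1x16xf32> to memref<16xf32, strided<[1]>>
// i.e. zero offsets, unit strides, sizes taken from the source, and a
// rank-reduced result type.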

FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
                                               Value value,
                                               ArrayRef<int64_t> desiredShape) {
  auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
  assert(sourceMemrefType && "not a ranked memref type");
  auto sourceShape = sourceMemrefType.getShape();
  if (sourceShape.equals(desiredShape))
    return value;
  auto maybeRankReductionMask =
      mlir::computeRankReductionMask(sourceShape, desiredShape);
  if (!maybeRankReductionMask)
    return failure();
  return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
}

/// Helper method to check if a `subview` operation is trivially a no-op. This
/// is the case if all offsets are zero, all strides are 1, and the source
/// shape is same as the size of the subview. In such cases, the subview can
/// be folded into its source.
static bool isTrivialSubViewOp(SubViewOp subViewOp) {
  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
    return false;

  auto mixedOffsets = subViewOp.getMixedOffsets();
  auto mixedSizes = subViewOp.getMixedSizes();
  auto mixedStrides = subViewOp.getMixedStrides();

  // Check offsets are zero.
  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
        std::optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.value() != 0;
      }))
    return false;

  // Check strides are one.
  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
        std::optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.value() != 1;
      }))
    return false;

  // Check all sizes are static and match the (static) source shape.
  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
  for (const auto &size : llvm::enumerate(mixedSizes)) {
    std::optional<int64_t> intValue = getConstantIntValue(size.value());
    if (!intValue || *intValue != sourceShape[size.index()])
      return false;
  }
  // All conditions met. The `SubViewOp` is foldable as a no-op.
  return true;
}
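
// Example of a trivial subview in the sense above (IR is illustrative):
//   %0 = memref.subview %arg[0, 0] [8, 16] [1, 1]
//       : memref<8x16xf32> to memref<8x16xf32, strided<[16, 1]>>
// All offsets are 0, all strides are 1, and the sizes match the source
// shape, so the subview reads the whole buffer and is foldable (modulo a
// cast when the two layouts differ only syntactically).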

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///     memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
/// ```
/// is rewritten into:
/// ```
///   %0 = memref.subview %V : memref<16x16xf32> to
///     memref<3x4xf32, strided<[16, 1]>>
///   %1 = memref.cast %0 : memref<3x4xf32, strided<[16, 1]>> to
///     memref<3x4xf32, strided<[?, 1], offset: ?>>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let the constant folder kick in.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    // Compute the SubViewOp result type after folding the MemRefCastOp. Use
    // the MemRefCastOp source operand type to infer the result type and the
    // current SubViewOp source operand type to compute the dropped dimensions
    // if the operation is rank-reducing.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType(), subViewOp.getSourceType(),
        llvm::cast<MemRefType>(castOp.getSource().getType()),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    if (!resultType)
      return failure();

    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.getSource(),
        subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
        subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
        subViewOp.getStaticStrides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};

/// Canonicalize subview ops that are no-ops. When the source shape is not
/// same as a result shape due to use of `affine_map`.
class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    if (!isTrivialSubViewOp(subViewOp))
      return failure();
    if (subViewOp.getSourceType() == subViewOp.getType()) {
      rewriter.replaceOp(subViewOp, subViewOp.getSource());
      return success();
    }
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        subViewOp.getSource());
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    // Infer a memref type without taking into account any rank reductions.
    MemRefType resTy = SubViewOp::inferResultType(
        op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
    if (!resTy)
      return {};
    MemRefType nonReducedType = resTy;

    // Directly return the non-rank reduced type if there are no dropped dims.
    llvm::SmallBitVector droppedDims = op.getDroppedDims();
    if (droppedDims.none())
      return nonReducedType;

    // Take the strides and offset from the non-rank reduced type.
    auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();

    // Drop dims from shape and strides.
    SmallVector<int64_t> targetShape;
    SmallVector<int64_t> targetStrides;
    for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
      if (droppedDims.test(i))
        continue;
      targetStrides.push_back(nonReducedStrides[i]);
      targetShape.push_back(nonReducedType.getDimSize(i));
    }

    return MemRefType::get(targetShape, nonReducedType.getElementType(),
                           StridedLayoutAttr::get(nonReducedType.getContext(),
                                                  offset, targetStrides),
                           nonReducedType.getMemorySpace());
  }
};

/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
  }
};

void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
}

OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
  MemRefType sourceMemrefType = getSource().getType();
  MemRefType resultMemrefType = getResult().getType();
  auto resultLayout =
      dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());

  if (resultMemrefType == sourceMemrefType &&
      resultMemrefType.hasStaticShape() &&
      (!resultLayout || resultLayout.hasStaticLayout())) {
    return getViewSource();
  }

  // Fold subview(subview(x)), where both subviews have the same size and the
  // second subview's offsets are all zero. (I.e., the second subview is a
  // no-op.)
  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
    auto srcSizes = srcSubview.getMixedSizes();
    auto sizes = getMixedSizes();
    auto offsets = getMixedOffsets();
    bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
    auto strides = getMixedStrides();
    bool allStridesOne = llvm::all_of(strides, isOneInteger);
    bool allSizesSame = llvm::equal(sizes, srcSizes);
    if (allOffsetsZero && allStridesOne && allSizesSame &&
        resultMemrefType == sourceMemrefType)
      return getViewSource();
  }

  return {};
}
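
// Sketch of the subview-of-subview fold above (IR is illustrative):
//   %1 = memref.subview %0[%i, %j] [4, 4] [1, 1]
//       : memref<8x8xf32> to memref<4x4xf32, strided<[8, 1], offset: ?>>
//   %2 = memref.subview %1[0, 0] [4, 4] [1, 1]
//       : memref<4x4xf32, strided<[8, 1], offset: ?>>
//       to memref<4x4xf32, strided<[8, 1], offset: ?>>
// %2 selects exactly what %1 already selects, so it folds to %1.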

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

void TransposeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "transpose");
}

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto originalSizes = memRefType.getShape();
  auto [originalStrides, offset] = memRefType.getStridesAndOffset();
  assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));

  // Compute permuted sizes and strides.
  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);

  return MemRefType::Builder(memRefType)
      .setShape(sizes)
      .setLayout(
          StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
}
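
// Worked example (types chosen for illustration): applying the permutation
// (d0, d1) -> (d1, d0) to memref<8x16xf32>, whose strides are [16, 1], swaps
// both sizes and strides and yields memref<16x8xf32, strided<[1, 16]>>.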

void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = llvm::cast<MemRefType>(in.getType());
  // Compute result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
  build(b, result, resultType, in, attrs);
}

// transpose $in $permutation attr-dict : type($in) `to` type(results)
void TransposeOp::print(OpAsmPrinter &p) {
  p << " " << getIn() << " " << getPermutation();
  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
  p << " : " << getIn().getType() << " to " << getType();
}

ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
                      AffineMapAttr::get(permutation));
  return success();
}

LogicalResult TransposeOp::verify() {
  if (!getPermutation().isPermutation())
    return emitOpError("expected a permutation map");
  if (getPermutation().getNumDims() != getIn().getType().getRank())
    return emitOpError("expected a permutation map of same rank as the input");

  auto srcType = llvm::cast<MemRefType>(getIn().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
  auto canonicalResultType =
      inferTransposeResultType(srcType, getPermutation())
          .canonicalizeStridedLayout();

  if (resultType.canonicalizeStridedLayout() != canonicalResultType)
    return emitOpError("result type ")
           << resultType
           << " is not equivalent to the canonical transposed input type "
           << canonicalResultType;
  return success();
}

OpFoldResult TransposeOp::fold(FoldAdaptor) {
  // First check for identity permutation, we can fold it away if input and
  // result types are identical already.
  if (getPermutation().isIdentity() && getType() == getIn().getType())
    return getIn();

  // Fold memref.transpose(memref.transpose(x)) by composing the permutations.
  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
    AffineMap composedPermutation =
        getPermutation().compose(otherTransposeOp.getPermutation());
    getInMutable().assign(otherTransposeOp.getIn());
    setPermutation(composedPermutation);
    return getResult();
  }
  return {};
}
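
// Sketch of the composition fold above (IR is illustrative):
//   %1 = memref.transpose %0 (d0, d1) -> (d1, d0)
//       : memref<4x8xf32> to memref<8x4xf32, strided<[1, 8]>>
//   %2 = memref.transpose %1 (d0, d1) -> (d1, d0)
//       : memref<8x4xf32, strided<[1, 8]>> to memref<4x8xf32>
// The second transpose is updated in place to read %0 directly with the
// composed (here: identity) permutation.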

//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "view");
}

LogicalResult ViewOp::verify() {
  auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
  auto viewType = getType();

  // The base memref should have identity layout map (or none).
  if (!baseType.getLayout().isIdentity())
    return emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout map (or none).
  if (!viewType.getLayout().isIdentity())
    return emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (getSizes().size() != numDynamicDims)
    return emitError("incorrect number of size operands for type ") << viewType;

  return success();
}
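
// Example of a well-formed view (IR is illustrative): the result type below
// has exactly one dynamic dimension, so exactly one size operand is required:
//   %0 = memref.view %buf[%offset][%n]
//       : memref<2048xi8> to memref<?x4xf32>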

Value ViewOp::getViewSource() { return getSource(); }

namespace {

struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from old memref view type 'memRefType'.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Offset cannot be folded into result type.

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp =
              dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.value());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(
        viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
        viewOp.getByteShift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
    return success();
  }
};

struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.getByteShift(),
                                        viewOp.getSizes());
    return success();
  }
};

} // namespace

void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

LogicalResult AtomicRMWOp::verify() {
  if (getMemRefType().getRank() != getNumOperands() - 2)
    return emitOpError(
        "expects the number of subscripts to be equal to memref rank");
  switch (getKind()) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::maximumf:
  case arith::AtomicRMWKind::minimumf:
  case arith::AtomicRMWKind::mulf:
    if (!llvm::isa<FloatType>(getValue().getType()))
      return emitOpError() << "with kind '"
                           << arith::stringifyAtomicRMWKind(getKind())
                           << "' expects a floating-point type";
    break;
  case arith::AtomicRMWKind::addi:
  case arith::AtomicRMWKind::maxs:
  case arith::AtomicRMWKind::maxu:
  case arith::AtomicRMWKind::mins:
  case arith::AtomicRMWKind::minu:
  case arith::AtomicRMWKind::muli:
  case arith::AtomicRMWKind::ori:
  case arith::AtomicRMWKind::andi:
    if (!llvm::isa<IntegerType>(getValue().getType()))
      return emitOpError() << "with kind '"
                           << arith::stringifyAtomicRMWKind(getKind())
                           << "' expects an integer type";
    break;
  default:
    break;
  }
  return success();
}
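
// Examples of the kind/type pairing enforced above (IR is illustrative):
//   %x = memref.atomic_rmw addf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
//   %y = memref.atomic_rmw addi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// An `addf` applied to an integer value, for instance, is rejected.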

OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
  /// atomicrmw(memrefcast) -> atomicrmw
  if (succeeded(foldMemRefCast(*this, getValue())))
    return getResult();
  return OpFoldResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"