MLIR: lib/Dialect/Affine/IR/AffineOps.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/ScopeExit.h"
25 #include "llvm/ADT/SmallBitVector.h"
26 #include "llvm/ADT/SmallVectorExtras.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/MathExtras.h"
30 #include
31 #include
32
33 using namespace mlir;
35
36 using llvm::divideCeilSigned;
37 using llvm::divideFloorSigned;
38 using llvm::mod;
39
40 #define DEBUG_TYPE "affine-ops"
41
42 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
43
44
45
46
47
49 if (auto arg = llvm::dyn_cast(value))
50 return arg.getParentRegion() == region;
52 }
53
54
55
56
57
58
59 static bool
63
64
65
66
68 return true;
69
70
71
72
73 if (llvm::isa(value))
74 return legalityCheck(mapping.lookup(value), dest);
75
76
77
78
79
81 bool isDimLikeOp = isa(value.getDefiningOp());
83 isDimLikeOp;
84 }
85
86
87
88 static bool
92 return llvm::all_of(values, [&](Value v) {
94 });
95 }
96
97
98
99 template
102 static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
103 AffineWriteOpInterface>::value,
104 "only ops with affine read/write interface are supported");
105
106 AffineMap map = op.getAffineMap();
109 op.getMapOperands().take_back(map.getNumSymbols());
111 dimOperands, src, dest, mapping,
113 return false;
115 symbolOperands, src, dest, mapping,
117 return false;
118 return true;
119 }
120
121
122
123
124
125 template <>
129
132 op.getMapOperands(), src, dest, mapping,
134
135
137 op.getMapOperands(), src, dest, mapping,
139 }
140
141
142
143
144
145 namespace {
146
147
150
151
152
153
154
155
156
157
158
160 IRMapping &valueMapping) const final {
161
162
164 if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
165 return false;
166
167
168
169 if (!llvm::hasSingleElement(*src))
170 return false;
171
172
173
176
177 if (auto iface = dyn_cast(op)) {
178 if (iface.hasNoEffect())
179 continue;
180 }
181
182
183
184 bool remainsValid =
186 .Case<AffineApplyOp, AffineReadOpInterface,
187 AffineWriteOpInterface>([&](auto op) {
189 })
191
192 return false;
193 });
194
195 if (!remainsValid)
196 return false;
197 }
198
199 return true;
200 }
201
202
203
205 IRMapping &valueMapping) const final {
206
207
208
209
212 isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
213 }
214
215
/// Unconditionally answers `true`; the `op` argument is never inspected.
216 bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
217 };
218 }
219
220
221
222
223
/// Initialization hook of the affine dialect: pulls in the generated op list,
/// adds dialect interfaces and declares promised op interfaces.
///
/// NOTE(review): this listing skips embedded line 225, so the call that
/// opens the template argument list closed by `>();` below is not visible.
224 void AffineDialect::initialize() {
// Select the op-list section of the generated .inc file; its expansion is
// consumed as the argument list terminated by `>();`.
226 #define GET_OP_LIST
227 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
228 >();
// NOTE(review): presumably the template arguments of this call were dropped
// by the extraction - confirm against the upstream file.
229 addInterfaces();
// ValueBoundsOpInterface is declared as a promised interface for
// AffineApplyOp, AffineMaxOp and AffineMinOp.
230 declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
231 AffineMinOp>();
232 }
233
234
235
239 if (auto poison = dyn_castub::PoisonAttr(value))
240 return builder.createub::PoisonOp(loc, type, poison);
241 return arith::ConstantOp::materialize(builder, value, type, loc);
242 }
243
244
245
246
247
249 if (auto arg = llvm::dyn_cast(value)) {
250
251
252
255 }
256
259 }
260
261
262
264 auto *curOp = op;
265 while (auto *parentOp = curOp->getParentOp()) {
268 curOp = parentOp;
269 }
270 return nullptr;
271 }
272
275 while (auto *parentOp = curOp->getParentOp()) {
276 if (!isa<AffineForOp, AffineIfOp, AffineParallelOp>(parentOp))
278 curOp = parentOp;
279 }
280 return nullptr;
281 }
282
283
284
285
286
287
289
291 return false;
292
295
296
297
298
300 return true;
301 auto *parentOp = llvm::cast(value).getOwner()->getParentOp();
303 }
304
305
306
307
308
309
310
311
313
315 return false;
316
317
319 return true;
320
322 if (!op) {
323
324
326 }
327
328
329 if (auto applyOp = dyn_cast(op))
330 return applyOp.isValidDim(region);
331
332
333 if (isa<AffineDelinearizeIndexOp, AffineLinearizeIndexOp>(op))
334 return llvm::all_of(op->getOperands(),
335 [&](Value arg) { return ::isValidDim(arg, region); });
336
337
338 if (auto dimOp = dyn_cast(op))
340 return false;
341 }
342
343
344
345
346 template
349 MemRefType memRefType = memrefDefOp.getType();
350
351
352 if (index >= memRefType.getRank()) {
353 return false;
354 }
355
356
357 if (!memRefType.isDynamicDim(index))
358 return true;
359
360 unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
361 return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
362 region);
363 }
364
365
367
369 return true;
370
371
372
373 if (llvm::isa(dimOp.getShapedValue()))
374 return false;
375
376
377
378 std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
379
380
381 if (!index.has_value())
382 return false;
383
384
385 Operation *op = dimOp.getShapedValue().getDefiningOp();
386 while (auto castOp = dyn_castmemref::CastOp(op)) {
387
388 if (isa(castOp.getSource().getType()))
389 return false;
390 op = castOp.getSource().getDefiningOp();
391 if (!op)
392 return false;
393 }
394
395 int64_t i = index.value();
397 .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
399 .Default([](Operation *) { return false; });
400 }
401
402
403
404
405
406
407
408
409
411 if (!value)
412 return false;
413
414
416 return false;
417
418
420 return true;
421
424
425 return false;
426 }
427
428
429
430
431
432
433
434
435
436
437
438
439
441
443 return false;
444
445
447 return true;
448
450 if (!defOp) {
451
452
457 return false;
458 }
459
460
463 return true;
464
465
466 if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
467 return affine::isValidSymbol(operand, region);
468 })) {
469 return true;
470 }
471
472
473 if (auto dimOp = dyn_cast(defOp))
475
476
481
482 return false;
483 }
484
485
486
487
490 }
491
492
497 printer << '(' << operands.take_front(numDims) << ')';
498 if (operands.size() > numDims)
499 printer << '[' << operands.drop_front(numDims) << ']';
500 }
501
502
507 return failure();
508
509 numDims = opInfos.size();
510
511
516 }
517
518
519
520
521
522
523 template
524 static LogicalResult
526 unsigned numDims) {
527 unsigned opIt = 0;
528 for (auto operand : operands) {
529 if (opIt++ < numDims) {
531 return op.emitOpError("operand cannot be used as a dimension id");
533 return op.emitOpError("operand cannot be used as a symbol");
534 }
535 }
536 return success();
537 }
538
539
540
541
542
544 return AffineValueMap(getAffineMap(), getOperands(), getResult());
545 }
546
550
551 AffineMapAttr mapAttr;
552 unsigned numDims;
556 return failure();
557 auto map = mapAttr.getValue();
558
559 if (map.getNumDims() != numDims ||
560 numDims + map.getNumSymbols() != result.operands.size()) {
562 "dimension or symbol index mismatch");
563 }
564
565 result.types.append(map.getNumResults(), indexTy);
566 return success();
567 }
568
570 p << " " << getMapAttr();
572 getAffineMap().getNumDims(), p);
574 }
575
577
579
580
582 return emitOpError(
583 "operand count and affine map dimension and symbol count must match");
584
585
587 return emitOpError("mapping must produce one value");
588
589
590
591
593 for (Value operand : getMapOperands().drop_front(affineMap.getNumDims())) {
595 return emitError("dimensional operand cannot be used as a symbol");
596 }
597
598 return success();
599 }
600
601
602
604 return llvm::all_of(getOperands(),
606 }
607
608
609
610
612 return llvm::all_of(getOperands(),
614 }
615
616
617
619 return llvm::all_of(getOperands(),
621 }
622
623
624
626 return llvm::all_of(getOperands(), [&](Value operand) {
628 });
629 }
630
/// Folds this affine.apply.
///
/// Two strategies are visible:
///   1. If result 0 of the map matches one of the two `dyn_cast` checks
///      below, the matching operand is returned directly.
///   2. Otherwise the map is constant-folded over the adaptor's map operands
///      and the first folded value is returned; an empty OpFoldResult is
///      returned when constant folding fails.
///
/// NOTE(review): embedded lines 642 (declaration of `result`) and 647 (the
/// statement guarded by `if (hasPoison)`) are absent from this listing.
631 OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
632 auto map = getAffineMap();
633
634
// Only result 0 of the map is examined.
635 auto expr = map.getResult(0);
// NOTE(review): the cast target types were stripped by the extraction;
// presumably a dimension expression here and a symbol expression below.
636 if (auto dim = dyn_cast(expr))
637 return getOperand(dim.getPosition());
638 if (auto sym = dyn_cast(expr))
// The operand index is offset by the number of dims of the map.
639 return getOperand(map.getNumDims() + sym.getPosition());
640
641
// Constant-fold the map; `hasPoison` is filled in by constantFold.
643 bool hasPoison = false;
644 auto foldResult =
645 map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
646 if (hasPoison)
// NOTE(review): the statement belonging to the `if` above is missing here.
648 if (failed(foldResult))
649 return {};
650 return result[0];
651 }
652
653
654
656
658
659
660
661
662
663
664
665 auto dimExpr = dyn_cast(e);
666
667 if (!dimExpr)
668 return div;
669
670
671
672
673
674
675 Value operand = operands[dimExpr.getPosition()];
676 int64_t operandDivisor = 1;
677
678
680 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
681 operandDivisor = forOp.getStepAsInt();
682 } else {
683 uint64_t lbLargestKnownDivisor =
684 forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
685 operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
686 }
687 }
688 return operandDivisor;
689 }
690
691
692
694 int64_t k) {
695 if (auto constExpr = dyn_cast(e)) {
696 int64_t constVal = constExpr.getValue();
697 return constVal >= 0 && constVal < k;
698 }
699 auto dimExpr = dyn_cast(e);
700 if (!dimExpr)
701 return false;
702 Value operand = operands[dimExpr.getPosition()];
703
704
706 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
707 forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
708 return true;
709 }
710 }
711
712
713
714
715 return false;
716 }
717
718
719
720
723 auto bin = dyn_cast(e);
725 return false;
726
731 quotientTimesDiv = llhs;
732 rem = rlhs;
733 return true;
734 }
737 quotientTimesDiv = rlhs;
738 rem = llhs;
739 return true;
740 }
741 return false;
742 }
743
744
747 if (forOp && forOp.hasConstantLowerBound())
748 return forOp.getConstantLowerBound();
749 return std::nullopt;
750 }
751
752
755 if (!forOp || !forOp.hasConstantUpperBound())
756 return std::nullopt;
757
758
759
760 if (forOp.hasConstantLowerBound()) {
761 return forOp.getConstantUpperBound() - 1 -
762 (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
763 forOp.getStepAsInt();
764 }
765 return forOp.getConstantUpperBound() - 1;
766 }
767
768
769
770
772 unsigned numSymbols,
774
776 constLowerBounds.reserve(operands.size());
777 constUpperBounds.reserve(operands.size());
778 for (Value operand : operands) {
779 constLowerBounds.push_back(getLowerBound(operand));
780 constUpperBounds.push_back(getUpperBound(operand));
781 }
782
783 if (auto constExpr = dyn_cast(expr))
784 return constExpr.getValue();
785
787 constUpperBounds,
788 true);
789 }
790
791
792
793
795 unsigned numSymbols,
797
799 constLowerBounds.reserve(operands.size());
800 constUpperBounds.reserve(operands.size());
801 for (Value operand : operands) {
802 constLowerBounds.push_back(getLowerBound(operand));
803 constUpperBounds.push_back(getUpperBound(operand));
804 }
805
806 std::optional<int64_t> lowerBound;
807 if (auto constExpr = dyn_cast(expr)) {
808 lowerBound = constExpr.getValue();
809 } else {
811 constLowerBounds, constUpperBounds,
812 false);
813 }
814 return lowerBound;
815 }
816
817
819 unsigned numSymbols,
821
822 auto binExpr = dyn_cast(expr);
823 if (!binExpr)
824 return;
825
826
832
833 binExpr = dyn_cast(expr);
837 return;
838 }
839
840
841 lhs = binExpr.getLHS();
842 rhs = binExpr.getRHS();
843 auto rhsConst = dyn_cast(rhs);
844 if (!rhsConst)
845 return;
846
847 int64_t rhsConstVal = rhsConst.getValue();
848
849 if (rhsConstVal <= 0)
850 return;
851
852
854 std::optional<int64_t> lhsLbConst =
855 getLowerBound(lhs, numDims, numSymbols, operands);
856 std::optional<int64_t> lhsUbConst =
857 getUpperBound(lhs, numDims, numSymbols, operands);
858 if (lhsLbConst && lhsUbConst) {
859 int64_t lhsLbConstVal = *lhsLbConst;
860 int64_t lhsUbConstVal = *lhsUbConst;
861
862
864 divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
865 divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
867 divideFloorSigned(lhsLbConstVal, rhsConstVal), context);
868 return;
869 }
870
871
873 divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
874 divideCeilSigned(lhsUbConstVal, rhsConstVal)) {
876 context);
877 return;
878 }
879
881 lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
882 expr = lhs;
883 return;
884 }
885 }
886
887
888
889
890
892 int64_t divisor;
893 if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
894 if (rhsConstVal % divisor == 0 &&
896 expr = quotientTimesDiv.floorDiv(rhsConst);
897 } else if (divisor % rhsConstVal == 0 &&
899 expr = rem % rhsConst;
900 }
901 return;
902 }
903
904
905
906
907
913 }
914 }
915
916
917
918
919
920
923 bool isMax) {
924
925 if (operands.empty())
926 return;
927
928
929
931 constLowerBounds.reserve(operands.size());
932 constUpperBounds.reserve(operands.size());
933 for (Value operand : operands) {
934 constLowerBounds.push_back(getLowerBound(operand));
935 constUpperBounds.push_back(getUpperBound(operand));
936 }
937
938
939
940
941
942
947 if (auto constExpr = dyn_cast(e)) {
948 lowerBounds.push_back(constExpr.getValue());
949 upperBounds.push_back(constExpr.getValue());
950 } else {
951 lowerBounds.push_back(
953 constLowerBounds, constUpperBounds,
954 false));
955 upperBounds.push_back(
957 constLowerBounds, constUpperBounds,
958 true));
959 }
960 }
961
962
966 unsigned i = exprEn.index();
967
968 if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
970
971
972 if (isMax) {
973 if (!upperBounds[i]) {
974 irredundantExprs.push_back(e);
975 continue;
976 }
977
978
979 if (!llvm::any_of(llvm::enumerate(lowerBounds), [&](const auto &en) {
980 auto otherLowerBound = en.value();
981 unsigned pos = en.index();
982 if (pos == i || !otherLowerBound)
983 return false;
984 if (*otherLowerBound > *upperBounds[i])
985 return true;
986 if (*otherLowerBound < *upperBounds[i])
987 return false;
988
989
990
991 if (upperBounds[pos] && lowerBounds[i] &&
992 lowerBounds[i] == upperBounds[i] &&
993 otherLowerBound == *upperBounds[pos] && i < pos)
994 return false;
995 return true;
996 }))
997 irredundantExprs.push_back(e);
998 } else {
999 if (!lowerBounds[i]) {
1000 irredundantExprs.push_back(e);
1001 continue;
1002 }
1003
1004 if (!llvm::any_of(llvm::enumerate(upperBounds), [&](const auto &en) {
1005 auto otherUpperBound = en.value();
1006 unsigned pos = en.index();
1007 if (pos == i || !otherUpperBound)
1008 return false;
1009 if (*otherUpperBound < *lowerBounds[i])
1010 return true;
1011 if (*otherUpperBound > *lowerBounds[i])
1012 return false;
1013 if (lowerBounds[pos] && upperBounds[i] &&
1014 lowerBounds[i] == upperBounds[i] &&
1015 otherUpperBound == lowerBounds[pos] && i < pos)
1016 return false;
1017 return true;
1018 }))
1019 irredundantExprs.push_back(e);
1020 }
1021 }
1022
1023
1026 }
1027
1028
1029
1030
1031 static void LLVM_ATTRIBUTE_UNUSED
1033 assert(map.getNumInputs() == operands.size() && "invalid operands for map");
1038 operands);
1039 newResults.push_back(expr);
1040 }
1043 }
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1056 unsigned dimOrSymbolPosition,
1060 bool isDimReplacement = (dimOrSymbolPosition < dims.size());
1061 unsigned pos = isDimReplacement ? dimOrSymbolPosition
1062 : dimOrSymbolPosition - dims.size();
1063 Value &v = isDimReplacement ? dims[pos] : syms[pos];
1064 if (!v)
1065 return failure();
1066
1067 auto affineApply = v.getDefiningOp();
1068 if (!affineApply)
1069 return failure();
1070
1071
1072
1073 v = nullptr;
1074
1075
1076 AffineMap composeMap = affineApply.getAffineMap();
1077 assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
1078 SmallVector composeOperands(affineApply.getMapOperands().begin(),
1079 affineApply.getMapOperands().end());
1080
1081
1091
1092
1093 dims.append(composeDims.begin(), composeDims.end());
1094 syms.append(composeSyms.begin(), composeSyms.end());
1095 *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());
1096
1097 return success();
1098 }
1099
1100
1101
1102
1108 return;
1109 }
1110
1113 operands->begin() + map->getNumDims());
1115 operands->end());
1116
1117
1118
1119
1120
1121
1122 while (true) {
1124 for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1126 break;
1128 break;
1129 }
1130
1131
1132 operands->clear();
1133
1134
1135
1136 unsigned nDims = 0, nSyms = 0;
1138 dimReplacements.reserve(dims.size());
1139 symReplacements.reserve(syms.size());
1140 for (auto *container : {&dims, &syms}) {
1141 bool isDim = (container == &dims);
1142 auto &repls = isDim ? dimReplacements : symReplacements;
1144 Value v = en.value();
1145 if (!v) {
1148 "map is function of unexpected expr@pos");
1150 continue;
1151 }
1154 operands->push_back(v);
1155 }
1156 }
1158 nSyms);
1159
1160
1163 }
1164
1167 while (llvm::any_of(*operands, [](Value v) {
1168 return isa_and_nonnull(v.getDefiningOp());
1169 })) {
1171 }
1172 }
1173
1174 AffineApplyOp
1180 assert(map);
1181 return b.create(loc, map, valueOperands);
1182 }
1183
1184 AffineApplyOp
1188 b, loc,
1190 .front(),
1191 operands);
1192 }
1193
1194
1195
1198
1199
1200
1203 for (unsigned i : llvm::seq(0, map.getNumResults())) {
1204 SmallVector submapOperands(operands.begin(), operands.end());
1208 unsigned numNewDims = submap.getNumDims();
1210 llvm::append_range(dims,
1211 ArrayRef(submapOperands).take_front(numNewDims));
1212 llvm::append_range(symbols,
1213 ArrayRef(submapOperands).drop_front(numNewDims));
1214 exprs.push_back(submap.getResult(0));
1215 }
1216
1217
1218
1219 operands = llvm::to_vector(llvm::concat(dims, symbols));
1222 }
1223
1228 assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
1229
1230
1231
1232
1233
1236
1237
1238 AffineApplyOp applyOp =
1240
1241
1243 for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1245
1246
1248 if (failed(applyOp->fold(constOperands, foldResults)) ||
1249 foldResults.empty()) {
1251 listener->notifyOperationInserted(applyOp, {});
1252 return applyOp.getResult();
1253 }
1254
1255 applyOp->erase();
1256 return llvm::getSingleElement(foldResults);
1257 }
1258
1264 b, loc,
1266 .front(),
1267 operands);
1268 }
1269
1274 return llvm::map_to_vector(llvm::seq(0, map.getNumResults()),
1275 [&](unsigned i) {
1276 return makeComposedFoldedAffineApply(
1277 b, loc, map.getSubMap({i}), operands);
1278 });
1279 }
1280
1281 template
1288 }
1289
1290 AffineMinOp
1293 return makeComposedMinMax(b, loc, map, operands);
1294 }
1295
1296 template
1300
1301
1302
1303
1306
1307
1308 auto minMaxOp = makeComposedMinMax(newBuilder, loc, map, operands);
1309
1310
1312 for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1314
1315
1317 if (failed(minMaxOp->fold(constOperands, foldResults)) ||
1318 foldResults.empty()) {
1320 listener->notifyOperationInserted(minMaxOp, {});
1321 return minMaxOp.getResult();
1322 }
1323
1324 minMaxOp->erase();
1325 return llvm::getSingleElement(foldResults);
1326 }
1327
1332 return makeComposedFoldedMinMax(b, loc, map, operands);
1333 }
1334
1339 return makeComposedFoldedMinMax(b, loc, map, operands);
1340 }
1341
1342
1343
1344 template
1347 if (!mapOrSet || operands->empty())
1348 return;
1349
1350 assert(mapOrSet->getNumInputs() == operands->size() &&
1351 "map/set inputs must match number of operands");
1352
1353 auto *context = mapOrSet->getContext();
1355 resultOperands.reserve(operands->size());
1357 remappedSymbols.reserve(operands->size());
1358 unsigned nextDim = 0;
1359 unsigned nextSym = 0;
1360 unsigned oldNumSyms = mapOrSet->getNumSymbols();
1362 for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1363 if (i < mapOrSet->getNumDims()) {
1365
1367 remappedSymbols.push_back((*operands)[i]);
1368 } else {
1370 resultOperands.push_back((*operands)[i]);
1371 }
1372 } else {
1373 resultOperands.push_back((*operands)[i]);
1374 }
1375 }
1376
1377 resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1378 *operands = resultOperands;
1379 *mapOrSet = mapOrSet->replaceDimsAndSymbols(
1380 dimRemapping, {}, nextDim, oldNumSyms + nextSym);
1381
1382 assert(mapOrSet->getNumInputs() == operands->size() &&
1383 "map/set inputs must match number of operands");
1384 }
1385
1386
1387
1388
1389
1390
1391
1392 template
1395 if (!mapOrSet || operands.empty())
1396 return;
1397
1398 unsigned numOperands = operands.size();
1399
1400 assert(mapOrSet.getNumInputs() == numOperands &&
1401 "map/set inputs must match number of operands");
1402
1403 auto *context = mapOrSet.getContext();
1405 resultOperands.reserve(numOperands);
1407 remappedDims.reserve(numOperands);
1409 symOperands.reserve(mapOrSet.getNumSymbols());
1410 unsigned nextSym = 0;
1411 unsigned nextDim = 0;
1412 unsigned oldNumDims = mapOrSet.getNumDims();
1414 resultOperands.assign(operands.begin(), operands.begin() + oldNumDims);
1415 for (unsigned i = oldNumDims, e = mapOrSet.getNumInputs(); i != e; ++i) {
1417
1418 symRemapping[i - oldNumDims] =
1420 remappedDims.push_back(operands[i]);
1421 } else {
1423 symOperands.push_back(operands[i]);
1424 }
1425 }
1426
1427 append_range(resultOperands, remappedDims);
1428 append_range(resultOperands, symOperands);
1429 operands = resultOperands;
1430 mapOrSet = mapOrSet.replaceDimsAndSymbols(
1431 {}, symRemapping, oldNumDims + nextDim, nextSym);
1432
1433 assert(mapOrSet.getNumInputs() == operands.size() &&
1434 "map/set inputs must match number of operands");
1435 }
1436
1437
1438 template
1441 static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
1442 "Argument must be either of AffineMap or IntegerSet type");
1443
1444 if (!mapOrSet || operands->empty())
1445 return;
1446
1447 assert(mapOrSet->getNumInputs() == operands->size() &&
1448 "map/set inputs must match number of operands");
1449
1450 canonicalizePromotedSymbols(mapOrSet, operands);
1451 legalizeDemotedDims(*mapOrSet, *operands);
1452
1453
1454 llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
1455 llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
1456 mapOrSet->walkExprs([&](AffineExpr expr) {
1457 if (auto dimExpr = dyn_cast(expr))
1458 usedDims[dimExpr.getPosition()] = true;
1459 else if (auto symExpr = dyn_cast(expr))
1460 usedSyms[symExpr.getPosition()] = true;
1461 });
1462
1463 auto *context = mapOrSet->getContext();
1464
1466 resultOperands.reserve(operands->size());
1467
1468 llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
1470 unsigned nextDim = 0;
1471 for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
1472 if (usedDims[i]) {
1473
1474 auto it = seenDims.find((*operands)[i]);
1475 if (it == seenDims.end()) {
1477 resultOperands.push_back((*operands)[i]);
1478 seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
1479 } else {
1480 dimRemapping[i] = it->second;
1481 }
1482 }
1483 }
1484 llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
1486 unsigned nextSym = 0;
1487 for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
1488 if (!usedSyms[i])
1489 continue;
1490
1491
1492
1493 IntegerAttr operandCst;
1494 if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
1496 symRemapping[i] =
1498 continue;
1499 }
1500
1501 auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
1502 if (it == seenSymbols.end()) {
1504 resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
1505 seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
1506 symRemapping[i]));
1507 } else {
1508 symRemapping[i] = it->second;
1509 }
1510 }
1511 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
1512 nextDim, nextSym);
1513 *operands = resultOperands;
1514 }
1515
1518 canonicalizeMapOrSetAndOperands(map, operands);
1519 }
1520
1523 canonicalizeMapOrSetAndOperands(set, operands);
1524 }
1525
1526 namespace {
1527
1528
1529
/// Rewrite pattern, templated on the affine op type, that replaces an op when
/// its map or map operands differ from a recomputed (map, operands) pair.
///
/// NOTE(review): embedded lines 1532, 1537, 1540, 1548 and 1550-1553 are
/// absent from this listing. The definitions of `oldMap` and
/// `resultOperands`, and whatever transforms `map`, live on those lines and
/// cannot be described from here.
1530 template
1531 struct SimplifyAffineOp : public OpRewritePattern {
1533
1534
1535
// Replacement hook invoked by matchAndRewrite with the new map and operands;
// explicit specializations of it follow this struct in the file.
1536 void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
1538
// Returns failure() when nothing changed, otherwise delegates to
// replaceAffineOp and returns success().
1539 LogicalResult matchAndRewrite(AffineOpTy affineOp,
// Compile-time restriction of AffineOpTy to the op types listed below.
1541 static_assert(
1542 llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1543 AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1544 AffineVectorStoreOp, AffineVectorLoadOp>::value,
1545 "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1546 "expected");
1547 auto map = affineOp.getAffineMap();
1549 auto oldOperands = affineOp.getMapOperands();
// No-op detection: same map and element-wise equal operands means there is
// nothing to rewrite.
1554 if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1555 resultOperands.begin()))
1556 return failure();
1557
1558 replaceAffineOp(rewriter, affineOp, map, resultOperands);
1559 return success();
1560 }
1561 };
1562
1563
1564
1565 template <>
1566 void SimplifyAffineOp::replaceAffineOp(
1569 rewriter.replaceOpWithNewOp(load, load.getMemRef(), map,
1570 mapOperands);
1571 }
1572 template <>
1573 void SimplifyAffineOp::replaceAffineOp(
1577 prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1578 prefetch.getLocalityHint(), prefetch.getIsDataCache());
1579 }
1580 template <>
1581 void SimplifyAffineOp::replaceAffineOp(
1585 store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1586 }
1587 template <>
1588 void SimplifyAffineOp::replaceAffineOp(
1592 vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1593 mapOperands);
1594 }
1595 template <>
1596 void SimplifyAffineOp::replaceAffineOp(
1600 vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1601 mapOperands);
1602 }
1603
1604
1605 template
1606 void SimplifyAffineOp::replaceAffineOp(
1610 }
1611 }
1612
/// Adds the SimplifyAffineOp pattern to `results`, constructed with `context`.
///
/// NOTE(review): embedded line 1614 (the remainder of the parameter list,
/// which presumably declares `context`) is absent from this listing.
1613 void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1615 results.add<SimplifyAffineOp>(context);
1616 }
1617
1618
1619
1620
1621
1622
1629 Value stride, Value elementsPerStride) {
1640 if (stride) {
1641 result.addOperands({stride, elementsPerStride});
1642 }
1643 }
1644
1646 p << " " << getSrcMemRef() << '[';
1648 p << "], " << getDstMemRef() << '[';
1650 p << "], " << getTagMemRef() << '[';
1653 if (isStrided()) {
1655 p << ", " << getNumElementsPerStride();
1656 }
1657 p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1658 << getTagMemRefType();
1659 }
1660
1661
1662
1663
1664
1665
1666
1670 AffineMapAttr srcMapAttr;
1673 AffineMapAttr dstMapAttr;
1676 AffineMapAttr tagMapAttr;
1680
1683
1684
1685
1686
1687
1688
1691 getSrcMapAttrStrName(),
1695 getDstMapAttrStrName(),
1699 getTagMapAttrStrName(),
1702 return failure();
1703
1704
1706 return failure();
1707
1708 if (!strideInfo.empty() && strideInfo.size() != 2) {
1710 "expected two stride related operands");
1711 }
1712 bool isStrided = strideInfo.size() == 2;
1713
1715 return failure();
1716
1717 if (types.size() != 3)
1719
1727 return failure();
1728
1729 if (isStrided) {
1731 return failure();
1732 }
1733
1734
1735 if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1736 dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1737 tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1739 "memref operand count not equal to map.numInputs");
1740 return success();
1741 }
1742
/// Verifies an affine.dma_start op.
///
/// Checks, in order: the source, destination and tag operands pass the type
/// test (the diagnostics name it "memref type"); the total operand count
/// matches one of two accepted sizes; every source, destination and tag index
/// has `index` type. Returns success() only when no check emitted an error.
///
/// NOTE(review): embedded lines 1759, 1763, 1770 and 1777 are absent from
/// this listing. Judging by the diagnostics, each of the last three guarded
/// the "valid dimension or symbol identifier" error that follows it - confirm
/// against the upstream file before relying on the control flow shown here.
1743 LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
// NOTE(review): the type argument of each `llvm::isa` was stripped by the
// extraction.
1744 if (!llvm::isa(getOperand(getSrcMemRefOperandIndex()).getType()))
1745 return emitOpError("expected DMA source to be of memref type");
1746 if (!llvm::isa(getOperand(getDstMemRefOperandIndex()).getType()))
1747 return emitOpError("expected DMA destination to be of memref type");
1748 if (!llvm::isa(getOperand(getTagMemRefOperandIndex()).getType()))
1749 return emitOpError("expected DMA tag to be of memref type");
1750
// Sum of the inputs of the source, destination and tag maps.
1751 unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1752 getDstMap().getNumInputs() +
1753 getTagMap().getNumInputs();
// Two operand counts are accepted: map inputs + 3 + 1, or that plus 2.
// NOTE(review): presumably 3 = the three memrefs checked above, 1 = the
// element-count operand, 2 = the optional stride pair - confirm.
1754 if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1755 getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1756 return emitOpError("incorrect number of operands");
1757 }
1758
// Per-index checks for the source indices.
1760 for (auto idx : getSrcIndices()) {
1761 if (!idx.getType().isIndex())
1762 return emitOpError("src index to dma_start must have 'index' type");
1764 return emitOpError(
1765 "src index must be a valid dimension or symbol identifier");
1766 }
// Same checks for the destination indices.
1767 for (auto idx : getDstIndices()) {
1768 if (!idx.getType().isIndex())
1769 return emitOpError("dst index to dma_start must have 'index' type");
1771 return emitOpError(
1772 "dst index must be a valid dimension or symbol identifier");
1773 }
// Same checks for the tag indices.
1774 for (auto idx : getTagIndices()) {
1775 if (!idx.getType().isIndex())
1776 return emitOpError("tag index to dma_start must have 'index' type");
1778 return emitOpError(
1779 "tag index must be a valid dimension or symbol identifier");
1780 }
1781 return success();
1782 }
1783
1786
1788 }
1789
1790 void AffineDmaStartOp::getEffects(
1792 &effects) {
1799 }
1800
1801
1802
1803
1804
1805
1813 }
1814
1816 p << " " << getTagMemRef() << '[';
1819 p << "], ";
1821 p << " : " << getTagMemRef().getType();
1822 }
1823
1824
1825
1826
1827
1828
1832 AffineMapAttr tagMapAttr;
1837
1838
1841 getTagMapAttrStrName(),
1848 return failure();
1849
1850 if (!llvm::isa(type))
1852 "expected tag to be of memref type");
1853
1854 if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1856 "tag memref operand count != to map.numInputs");
1857 return success();
1858 }
1859
/// Verifies an affine.dma_wait op: operand 0 must pass the type test (the
/// diagnostic names it "memref type") and every tag index must have `index`
/// type. Returns success() when no check emitted an error.
///
/// NOTE(review): embedded lines 1863 and 1867 are absent from this listing.
/// Line 1867 presumably guarded the "valid dimension or symbol identifier"
/// error below - confirm against the upstream file.
1860 LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
// NOTE(review): the type argument of `llvm::isa` was stripped by extraction.
1861 if (!llvm::isa(getOperand(0).getType()))
1862 return emitOpError("expected DMA tag to be of memref type");
// Per-index checks for the tag indices.
1864 for (auto idx : getTagIndices()) {
1865 if (!idx.getType().isIndex())
1866 return emitOpError("index to dma_wait must have 'index' type");
1868 return emitOpError(
1869 "index must be a valid dimension or symbol identifier");
1870 }
1871 return success();
1872 }
1873
1876
1878 }
1879
1880 void AffineDmaWaitOp::getEffects(
1882 &effects) {
1885 }
1886
1887
1888
1889
1890
1891
1892
1896 ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
1897 assert(((!lbMap && lbOperands.empty()) ||
1898 lbOperands.size() == lbMap.getNumInputs()) &&
1899 "lower bound operand count does not match the affine map");
1900 assert(((!ubMap && ubOperands.empty()) ||
1901 ubOperands.size() == ubMap.getNumInputs()) &&
1902 "upper bound operand count does not match the affine map");
1903 assert(step > 0 && "step has to be a positive integer constant");
1904
1906
1907
1909 getOperandSegmentSizeAttr(),
1911 static_cast<int32_t>(ubOperands.size()),
1912 static_cast<int32_t>(iterArgs.size())}));
1913
1914 for (Value val : iterArgs)
1915 result.addTypes(val.getType());
1916
1917
1920
1921
1925
1926
1930
1932
1933
1936 Value inductionVar =
1938 for (Value val : iterArgs)
1939 bodyBlock->addArgument(val.getType(), val.getLoc());
1940
1941
1942
1943
1944 if (iterArgs.empty() && !bodyBuilder) {
1945 ensureTerminator(*bodyRegion, builder, result.location);
1946 } else if (bodyBuilder) {
1949 bodyBuilder(builder, result.location, inductionVar,
1951 }
1952 }
1953
1955 int64_t ub, int64_t step, ValueRange iterArgs,
1956 BodyBuilderFn bodyBuilder) {
1959 return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1960 bodyBuilder);
1961 }
1962
/// Region-level verifier of affine.for.
///
/// Checks that the body has at least one argument and that argument 0 has
/// `index` type, that both bound maps have at least one result and, when the
/// op has results, that the number of loop-carried operands and of region
/// iter args each equal the number of results.
///
/// NOTE(review): embedded lines 1974 and 1979 are absent from this listing.
/// They hold the calls that validate the bound operands when a bound map has
/// inputs; only the tail of each call is visible.
1963 LogicalResult AffineForOp::verifyRegions() {
1964
1965
// The first body block argument is the induction variable and must be index.
1966 auto *body = getBody();
1967 if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
1968 return emitOpError("expected body to have a single index argument for the "
1969 "induction variable");
1970
1971
1972
// Bound operands are only checked when the corresponding map takes inputs.
1973 if (getLowerBoundMap().getNumInputs() > 0)
1975 getLowerBoundMap().getNumDims())))
1976 return failure();
1977
1978 if (getUpperBoundMap().getNumInputs() > 0)
1980 getUpperBoundMap().getNumDims())))
1981 return failure();
// Each bound map must yield at least one result.
1982 if (getLowerBoundMap().getNumResults() < 1)
1983 return emitOpError("expected lower bound map to have at least one result");
1984 if (getUpperBoundMap().getNumResults() < 1)
1985 return emitOpError("expected upper bound map to have at least one result");
1986
// With no results there is nothing further to cross-check.
1987 unsigned opNumResults = getNumResults();
1988 if (opNumResults == 0)
1989 return success();
1990
1991
1992
1993
// Iter operands and region iter args must each match the result count.
1994 if (getNumIterOperands() != opNumResults)
1995 return emitOpError(
1996 "mismatch between the number of loop-carried values and results");
1997 if (getNumRegionIterArgs() != opNumResults)
1998 return emitOpError(
1999 "mismatch between the number of basic block args and results");
2000
2001 return success();
2002 }
2003
2004
2007
2008
2009 bool failedToParsedMinMax =
2011
2013 auto boundAttrStrName =
2014 isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
2015 : AffineForOp::getUpperBoundMapAttrName(result.name);
2016
2017
2020 return failure();
2021
2022 if (!boundOpInfos.empty()) {
2023
2024 if (boundOpInfos.size() > 1)
2026 "expected only one loop bound operand");
2027
2028
2029
2032 return failure();
2033
2034
2035
2036
2039 return success();
2040 }
2041
2042
2044
2048 return failure();
2049
2050
2051 if (auto affineMapAttr = llvm::dyn_cast(boundAttr)) {
2052 unsigned currentNumOperands = result.operands.size();
2053 unsigned numDims;
2055 return failure();
2056
2057 auto map = affineMapAttr.getValue();
2061 "dim operand count and affine map dim count must match");
2062
2063 unsigned numDimAndSymbolOperands =
2064 result.operands.size() - currentNumOperands;
2065 if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
2068 "symbol operand count and affine map symbol count must match");
2069
2070
2071
2072 if (map.getNumResults() > 1 && failedToParsedMinMax) {
2073 if (isLower) {
2074 return p.emitError(attrLoc, "lower loop bound affine map with "
2075 "multiple results requires 'max' prefix");
2076 }
2077 return p.emitError(attrLoc, "upper loop bound affine map with multiple "
2078 "results requires 'min' prefix");
2079 }
2080 return success();
2081 }
2082
2083
2084 if (auto integerAttr = llvm::dyn_cast(boundAttr)) {
2087 boundAttrStrName,
2089 return success();
2090 }
2091
2094 "expected valid affine map representation for loop bounds");
2095 }
2096
2098 auto &builder = parser.getBuilder();
2101
2103 return failure();
2104
2105
2106 int64_t numOperands = result.operands.size();
2107 if (parseBound(true, result, parser))
2108 return failure();
2109 int64_t numLbOperands = result.operands.size() - numOperands;
2110 if (parser.parseKeyword("to", " between bounds"))
2111 return failure();
2112 numOperands = result.operands.size();
2113 if (parseBound(false, result, parser))
2114 return failure();
2115 int64_t numUbOperands = result.operands.size() - numOperands;
2116
2117
2120 getStepAttrName(result.name),
2122 } else {
2124 IntegerAttr stepAttr;
2126 getStepAttrName(result.name).data(),
2128 return failure();
2129
2130 if (stepAttr.getValue().isNegative())
2132 stepLoc,
2133 "expected step to be representable as a positive signed integer");
2134 }
2135
2136
2139
2140
2141 regionArgs.push_back(inductionVariable);
2142
2144
2147 return failure();
2148
2149 for (auto argOperandType :
2150 llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
2151 Type type = std::get<2>(argOperandType);
2152 std::get<0>(argOperandType).type = type;
2153 if (parser.resolveOperand(std::get<1>(argOperandType), type,
2155 return failure();
2156 }
2157 }
2158
2160 getOperandSegmentSizeAttr(),
2162 static_cast<int32_t>(numUbOperands),
2163 static_cast<int32_t>(operands.size())}));
2164
2165
2167 if (regionArgs.size() != result.types.size() + 1)
2170 "mismatch between the number of loop-carried values and results");
2171 if (parser.parseRegion(*body, regionArgs))
2172 return failure();
2173
2174 AffineForOp::ensureTerminator(*body, builder, result.location);
2175
2176
2178 }
2179
2183 AffineMap map = boundMap.getValue();
2184
2185
2186
2187
2188
2189
2190
2193
2194
2196 if (auto constExpr = dyn_cast(expr)) {
2197 p << constExpr.getValue();
2198 return;
2199 }
2200 }
2201
2202
2203
2205 if (isa(expr)) {
2207 return;
2208 }
2209 }
2210 } else {
2211
2212 p << prefix << ' ';
2213 }
2214
2215
2216 p << boundMap;
2219 }
2220
2221 unsigned AffineForOp::getNumIterOperands() {
2222 AffineMap lbMap = getLowerBoundMapAttr().getValue();
2223 AffineMap ubMap = getUpperBoundMapAttr().getValue();
2224
2226 }
2227
/// Returns the mutable operands of the terminator of the loop body, i.e. the
/// values the body hands back on each iteration. The visible code always
/// returns an engaged optional.
2228 std::optional<MutableArrayRef>
2229 AffineForOp::getYieldedValuesMutable() {
2230 return cast(getBody()->getTerminator()).getOperandsMutable();
2231 }
2232
2234 p << ' ';
2236 true);
2237 p << " = ";
2239 p << " to ";
2241
2242 if (getStepAsInt() != 1)
2243 p << " step " << getStepAsInt();
2244
2245 bool printBlockTerminators = false;
2246 if (getNumIterOperands() > 0) {
2247 p << " iter_args(";
2248 auto regionArgs = getRegionIterArgs();
2249 auto operands = getInits();
2250
2251 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
2252 p << std::get<0>(it) << " = " << std::get<1>(it);
2253 });
2254 p << ") -> (" << getResultTypes() << ")";
2255 printBlockTerminators = true;
2256 }
2257
2258 p << ' ';
2259 p.printRegion(getRegion(), false,
2260 printBlockTerminators);
2262 (*this)->getAttrs(),
2263 {getLowerBoundMapAttrName(getOperation()->getName()),
2264 getUpperBoundMapAttrName(getOperation()->getName()),
2265 getStepAttrName(getOperation()->getName()),
2266 getOperandSegmentSizeAttr()});
2267 }
2268
2269
2271 auto foldLowerOrUpperBound = [&forOp](bool lower) {
2272
2273
2275 auto boundOperands =
2276 lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2277 for (auto operand : boundOperands) {
2280 operandConstants.push_back(operandCst);
2281 }
2282
2284 lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2286 "bound maps should have at least one result");
2288 if (failed(boundMap.constantFold(operandConstants, foldedResults)))
2289 return failure();
2290
2291
2292 assert(!foldedResults.empty() && "bounds should have at least one result");
2293 auto maxOrMin = llvm::cast(foldedResults[0]).getValue();
2294 for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2295 auto foldedResult = llvm::cast(foldedResults[i]).getValue();
2296 maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2297 : llvm::APIntOps::smin(maxOrMin, foldedResult);
2298 }
2299 lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2300 : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2301 return success();
2302 };
2303
2304
2305 bool folded = false;
2306 if (!forOp.hasConstantLowerBound())
2307 folded |= succeeded(foldLowerOrUpperBound(true));
2308
2309
2310 if (!forOp.hasConstantUpperBound())
2311 folded |= succeeded(foldLowerOrUpperBound(false));
2312 return success(folded);
2313 }
2314
2315
2319
2320 auto lbMap = forOp.getLowerBoundMap();
2321 auto ubMap = forOp.getUpperBoundMap();
2322 auto prevLbMap = lbMap;
2323 auto prevUbMap = ubMap;
2324
2330
2334
2335
2336 if (lbMap == prevLbMap && ubMap == prevUbMap)
2337 return failure();
2338
2339 if (lbMap != prevLbMap)
2340 forOp.setLowerBound(lbOperands, lbMap);
2341 if (ubMap != prevUbMap)
2342 forOp.setUpperBound(ubOperands, ubMap);
2343 return success();
2344 }
2345
2346 namespace {
2347
2348 static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2349 int64_t step = forOp.getStepAsInt();
2350 if (!forOp.hasConstantBounds() || step <= 0)
2351 return std::nullopt;
2352 int64_t lb = forOp.getConstantLowerBound();
2353 int64_t ub = forOp.getConstantUpperBound();
2354 return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
2355 }
2356
2357
2358
/// Rewrite pattern for affine.for loops whose body holds a single operation
/// (the terminator). Such a loop is replaced by the values it would produce:
/// its init values when the trip count is known to be zero, otherwise the
/// values derived from the operands of the body terminator.
2359 struct AffineForEmptyLoopFolder : public OpRewritePattern {
2361
2362 LogicalResult matchAndRewrite(AffineForOp forOp,
2364 // Only loops whose body block contains exactly one operation are handled.
2365 if (!llvm::hasSingleElement(*forOp.getBody()))
2366 return failure();
// NOTE(review): this path reports success without any visible change to the
// IR; confirm that is intended.
2367 if (forOp.getNumResults() == 0)
2368 return success();
2369 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2370 if (tripCount == 0) {
2371 // The optional equals 0 only when a trip count is known and is zero;
2372 // the results are then exactly the init values.
2373 rewriter.replaceOp(forOp, forOp.getInits());
2374 return success();
2375 }
2377 auto yieldOp = cast(forOp.getBody()->getTerminator());
2378 auto iterArgs = forOp.getRegionIterArgs();
2379 bool hasValDefinedOutsideLoop = false;
2380 bool iterArgsNotInOrder = false;
2381 for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2382 Value val = yieldOp.getOperand(i);
2383 auto *iterArgIt = llvm::find(iterArgs, val);
2384 // A terminator operand that is the induction variable is not handled:
2385 // give up on the whole rewrite.
2386 if (val == forOp.getInductionVar())
2387 return failure();
2388 if (iterArgIt == iterArgs.end()) {
2389 // Not an iter arg: asserted to come from outside the loop and used as is.
2390 assert(forOp.isDefinedOutsideOfLoop(val) &&
2391 "must be defined outside of the loop");
2392 hasValDefinedOutsideLoop = true;
2393 replacements.push_back(val);
2394 } else {
// An iter arg is replaced by the init value at its own position; remember
// when that position differs from the terminator operand position.
2395 unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2396 if (pos != i)
2397 iterArgsNotInOrder = true;
2398 replacements.push_back(forOp.getInits()[pos]);
2399 }
2400 }
2401 // Without a known trip count, bail out if any returned value is defined
2402 // outside the loop or any iter arg is returned out of order.
2403 if (!tripCount.has_value() &&
2404 (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2405 return failure();
2406 // With a known trip count of two or more, bail out if any iter arg is
2407 // returned out of order.
2408 if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2409 return failure();
2410 rewriter.replaceOp(forOp, replacements);
2411 return success();
2412 }
2413 };
2414 }
2415
2416 void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
2418 results.add(context);
2419 }
2420
2422 assert((point.isParent() || point == getRegion()) && "invalid region point");
2423
2424
2425
2426 return getInits();
2427 }
2428
2429 void AffineForOp::getSuccessorRegions(
2431 assert((point.isParent() || point == getRegion()) && "expected loop region");
2432
2433
2434
2435
2436 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
2437 if (point.isParent() && tripCount.has_value()) {
2438 if (tripCount.value() > 0) {
2439 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2440 return;
2441 }
2442 if (tripCount.value() == 0) {
2444 return;
2445 }
2446 }
2447
2448
2449
2450 if (!point.isParent() && tripCount == 1) {
2452 return;
2453 }
2454
2455
2456
2457 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2459 }
2460
2461
2463 return getTrivialConstantTripCount(op) == 0;
2464 }
2465
2466 LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2471
2472
2473
2474
2475
2476 results.assign(getInits().begin(), getInits().end());
2477 folded = true;
2478 }
2479 return success(folded);
2480 }
2481
2484 }
2485
2488 }
2489
2491 assert(lbOperands.size() == map.getNumInputs());
2492 assert(map.getNumResults() >= 1 && "bound map has at least one result");
2493 getLowerBoundOperandsMutable().assign(lbOperands);
2494 setLowerBoundMap(map);
2495 }
2496
2498 assert(ubOperands.size() == map.getNumInputs());
2499 assert(map.getNumResults() >= 1 && "bound map has at least one result");
2500 getUpperBoundOperandsMutable().assign(ubOperands);
2501 setUpperBoundMap(map);
2502 }
2503
2504 bool AffineForOp::hasConstantLowerBound() {
2505 return getLowerBoundMap().isSingleConstant();
2506 }
2507
2508 bool AffineForOp::hasConstantUpperBound() {
2509 return getUpperBoundMap().isSingleConstant();
2510 }
2511
2512 int64_t AffineForOp::getConstantLowerBound() {
2513 return getLowerBoundMap().getSingleConstantResult();
2514 }
2515
2516 int64_t AffineForOp::getConstantUpperBound() {
2517 return getUpperBoundMap().getSingleConstantResult();
2518 }
2519
2520 void AffineForOp::setConstantLowerBound(int64_t value) {
2522 }
2523
2524 void AffineForOp::setConstantUpperBound(int64_t value) {
2526 }
2527
2528 AffineForOp::operand_range AffineForOp::getControlOperands() {
2531 }
2532
/// Returns true when, for every input i of the lower bound map, operand i of
/// this op is the same value as operand (numInputs + i).
/// NOTE(review): presumably the first group is the lower bound operands and
/// the second the upper bound operands; confirm against the operand layout.
2533 bool AffineForOp::matchingBoundOperandList() {
2534 auto lbMap = getLowerBoundMap();
2535 auto ubMap = getUpperBoundMap();
// NOTE(review): the condition guarding this early exit is not visible in this
// listing; restore it from the upstream file.
2538 return false;
2539
2540 unsigned numOperands = lbMap.getNumInputs();
2541 for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2542 // Compare operand i with its counterpart numOperands positions later.
2543 if (getOperand(i) != getOperand(numOperands + i))
2544 return false;
2545 }
2546 return true;
2547 }
2548
2550
2551 std::optional<SmallVector> AffineForOp::getLoopInductionVars() {
2553 }
2554
2555 std::optional<SmallVector> AffineForOp::getLoopLowerBounds() {
2556 if (!hasConstantLowerBound())
2557 return std::nullopt;
2560 OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
2561 }
2562
2563 std::optional<SmallVector> AffineForOp::getLoopSteps() {
2566 OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
2567 }
2568
2569 std::optional<SmallVector> AffineForOp::getLoopUpperBounds() {
2570 if (!hasConstantUpperBound())
2571 return {};
2574 OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
2575 }
2576
2577 FailureOr AffineForOp::replaceWithAdditionalYields(
2579 bool replaceInitOperandUsesInLoop,
2581
2584 auto inits = llvm::to_vector(getInits());
2585 inits.append(newInitOperands.begin(), newInitOperands.end());
2586 AffineForOp newLoop = rewriter.create(
2589
2590
2591 auto yieldOp = cast(getBody()->getTerminator());
2593 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
2594 {
2598 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
2599 assert(newInitOperands.size() == newYieldedValues.size() &&
2600 "expected as many new yield values as new iter operands");
2602 yieldOp.getOperandsMutable().append(newYieldedValues);
2603 });
2604 }
2605
2606
2607 rewriter.mergeBlocks(getBody(), newLoop.getBody(),
2608 newLoop.getBody()->getArguments().take_front(
2609 getBody()->getNumArguments()));
2610
2611 if (replaceInitOperandUsesInLoop) {
2612
2613
2614 for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
2619 });
2620 }
2621 }
2622
2623
2624 rewriter.replaceOp(getOperation(),
2625 newLoop->getResults().take_front(getNumResults()));
2626 return cast(newLoop.getOperation());
2627 }
2628
2630
2631
2632
2633
2634
2637 }
2638
2639
2640
2643 }
2644
2647 }
2648
2651 }
2652
2654 auto ivArg = llvm::dyn_cast(val);
2655 if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
2656 return AffineForOp();
2657 if (auto forOp =
2658 ivArg.getOwner()->getParent()->getParentOfType())
2659
2660 return forOp.getInductionVar() == val ? forOp : AffineForOp();
2661 return AffineForOp();
2662 }
2663
2665 auto ivArg = llvm::dyn_cast(val);
2666 if (!ivArg || !ivArg.getOwner())
2667 return nullptr;
2669 auto parallelOp = dyn_cast_if_present(containingOp);
2670 if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2671 return parallelOp;
2672 return nullptr;
2673 }
2674
2675
2676
2679 ivs->reserve(forInsts.size());
2680 for (auto forInst : forInsts)
2681 ivs->push_back(forInst.getInductionVar());
2682 }
2683
2686 ivs.reserve(affineOps.size());
2687 for (Operation *op : affineOps) {
2688
2689 if (auto forOp = dyn_cast(op))
2690 ivs.push_back(forOp.getInductionVar());
2691 else if (auto parallelOp = dyn_cast(op))
2692 for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
2693 ivs.push_back(parallelOp.getBody()->getArgument(i));
2694 }
2695 }
2696
2697
2698
2699 template <typename BoundListTy, typename LoopCreatorTy>
2701 OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
2704 LoopCreatorTy &&loopCreatorFn) {
2705 assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
2706 assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
2707
2708
2710 if (lbs.empty()) {
2711 if (bodyBuilderFn)
2712 bodyBuilderFn(builder, loc, ValueRange());
2713 return;
2714 }
2715
2716
2718 ivs.reserve(lbs.size());
2719 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
2720
2723 ivs.push_back(iv);
2724
2725 if (i == e - 1 && bodyBuilderFn) {
2727 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2728 }
2729 nestedBuilder.create(nestedLoc);
2730 };
2731
2732
2733
2734 auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
2736 }
2737 }
2738
2739
2740 static AffineForOp
2742 int64_t ub, int64_t step,
2743 AffineForOp::BodyBuilderFn bodyBuilderFn) {
2744 return builder.create(loc, lb, ub, step,
2745 std::nullopt, bodyBuilderFn);
2746 }
2747
2748
2749 static AffineForOp
2751 int64_t step,
2752 AffineForOp::BodyBuilderFn bodyBuilderFn) {
2755 if (lbConst && ubConst)
2757 ubConst.value(), step, bodyBuilderFn);
2760 std::nullopt, bodyBuilderFn);
2761 }
2762
2769 }
2770
2777 }
2778
2779
2780
2781
2782
2783 namespace {
2784
2785 struct SimplifyDeadElse : public OpRewritePattern {
2787
2788 LogicalResult matchAndRewrite(AffineIfOp ifOp,
2790 if (ifOp.getElseRegion().empty() ||
2791 !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
2792 return failure();
2793
2795 rewriter.eraseBlock(ifOp.getElseBlock());
2797 return success();
2798 }
2799 };
2800
2801
2802
2803 struct AlwaysTrueOrFalseIf : public OpRewritePattern {
2805
2806 LogicalResult matchAndRewrite(AffineIfOp op,
2808
2809 auto isTriviallyFalse = [](IntegerSet iSet) {
2810 return iSet.isEmptyIntegerSet();
2811 };
2812
2813 auto isTriviallyTrue = [](IntegerSet iSet) {
2814 return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
2815 iSet.getConstraint(0) == 0);
2816 };
2817
2818 IntegerSet affineIfConditions = op.getIntegerSet();
2819 Block *blockToMove;
2820 if (isTriviallyFalse(affineIfConditions)) {
2821
2822
2823
2824 if (op.getNumResults() == 0 && !op.hasElse()) {
2825
2826
2828 return success();
2829 }
2830 blockToMove = op.getElseBlock();
2831 } else if (isTriviallyTrue(affineIfConditions)) {
2832 blockToMove = op.getThenBlock();
2833 } else {
2834 return failure();
2835 }
2837
2838
2840
2841
2842
2843
2844
2845
2846
2848
2849
2850 rewriter.eraseOp(blockToMoveTerminator);
2851 return success();
2852 }
2853 };
2854 }
2855
2856
2857
2858 void AffineIfOp::getSuccessorRegions(
2860
2861
2863 regions.reserve(2);
2864 regions.push_back(
2865 RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
2866
2867 if (getElseRegion().empty()) {
2868 regions.push_back(getResults());
2869 } else {
2870 regions.push_back(
2871 RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
2872 }
2873 return;
2874 }
2875
2876
2877
2879 }
2880
2882
2883
2884 auto conditionAttr =
2885 (*this)->getAttrOfType(getConditionAttrStrName());
2886 if (!conditionAttr)
2887 return emitOpError("requires an integer set attribute named 'condition'");
2888
2889
2890 IntegerSet condition = conditionAttr.getValue();
2891 if (getNumOperands() != condition.getNumInputs())
2892 return emitOpError("operand count and condition integer set dimension and "
2893 "symbol count must match");
2894
2895
2898 return failure();
2899
2900 return success();
2901 }
2902
2904
2905 IntegerSetAttr conditionAttr;
2906 unsigned numDims;
2908 AffineIfOp::getConditionAttrStrName(),
2911 return failure();
2912
2913
2914 auto set = conditionAttr.getValue();
2915 if (set.getNumDims() != numDims)
2918 "dim operand count and integer set dim count must match");
2919 if (numDims + set.getNumSymbols() != result.operands.size())
2922 "symbol operand count and integer set symbol count must match");
2923
2925 return failure();
2926
2927
2928
2929 result.regions.reserve(2);
2932
2933
2934 if (parser.parseRegion(*thenRegion, {}, {}))
2935 return failure();
2936 AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
2938
2939
2941 if (parser.parseRegion(*elseRegion, {}, {}))
2942 return failure();
2943 AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
2945 }
2946
2947
2949 return failure();
2950
2951 return success();
2952 }
2953
2955 auto conditionAttr =
2956 (*this)->getAttrOfType(getConditionAttrStrName());
2957 p << " " << conditionAttr;
2959 conditionAttr.getValue().getNumDims(), p);
2961 p << ' ';
2962 p.printRegion(getThenRegion(), false,
2963 getNumResults());
2964
2965
2966 auto &elseRegion = this->getElseRegion();
2967 if (!elseRegion.empty()) {
2968 p << " else ";
2970 false,
2971 getNumResults());
2972 }
2973
2974
2976 getConditionAttrStrName());
2977 }
2978
/// Returns the integer set stored in the attribute named by
/// getConditionAttrStrName(). The attribute is dereferenced without a null
/// check, so it must be present on the op.
2979 IntegerSet AffineIfOp::getIntegerSet() {
2980 return (*this)
2981 ->getAttrOfType(getConditionAttrStrName())
2982 .getValue();
2983 }
2984
2985 void AffineIfOp::setIntegerSet(IntegerSet newSet) {
2986 (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));
2987 }
2988
2990 setIntegerSet(set);
2991 (*this)->setOperands(operands);
2992 }
2993
2996 bool withElseRegion) {
2997 assert(resultTypes.empty() || withElseRegion);
2999
3000 result.addTypes(resultTypes);
3003
3006 if (resultTypes.empty())
3007 AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
3008
3010 if (withElseRegion) {
3012 if (resultTypes.empty())
3013 AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
3014 }
3015 }
3016
3019 AffineIfOp::build(builder, result, {}, set, args,
3020 withElseRegion);
3021 }
3022
3023
3024
3025
3028
3029
3030
3033
3034 if (llvm::none_of(operands,
3036 return;
3037
3041 }
3042
3043
3045 auto set = getIntegerSet();
3049
3050
3051 if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
3052 return failure();
3053
3054 setConditional(set, operands);
3055 return success();
3056 }
3057
3058 void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
3060 results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
3061 }
3062
3063
3064
3065
3066
3069 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
3071 if (map)
3073 auto memrefType = llvm::cast(operands[0].getType());
3074 result.types.push_back(memrefType.getElementType());
3075 }
3076
3079 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3082 auto memrefType = llvm::cast(memref.getType());
3084 result.types.push_back(memrefType.getElementType());
3085 }
3086
3089 auto memrefType = llvm::cast(memref.getType());
3090 int64_t rank = memrefType.getRank();
3091
3092
3093 auto map =
3095 build(builder, result, memref, map, indices);
3096 }
3097
3099 auto &builder = parser.getBuilder();
3101
3102 MemRefType type;
3104 AffineMapAttr mapAttr;
3106 return failure(
3109 AffineLoadOp::getMapAttrStrName(),
3116 }
3117
3119 p << " " << getMemRef() << '[';
3120 if (AffineMapAttr mapAttr =
3121 (*this)->getAttrOfType(getMapAttrStrName()))
3123 p << ']';
3125 {getMapAttrStrName()});
3127 }
3128
3129
3130
3131 template
3132 static LogicalResult
3135 MemRefType memrefType, unsigned numIndexOperands) {
3136 AffineMap map = mapAttr.getValue();
3137 if (map.getNumResults() != memrefType.getRank())
3138 return op->emitOpError("affine map num results must equal memref rank");
3140 return op->emitOpError("expects as many subscripts as affine map inputs");
3141
3142 for (auto idx : mapOperands) {
3143 if (!idx.getType().isIndex())
3144 return op->emitOpError("index to load must have 'index' type");
3145 }
3147 return failure();
3148
3149 return success();
3150 }
3151
3154 if (getType() != memrefType.getElementType())
3155 return emitOpError("result type must match element type of memref");
3156
3158 *this, (*this)->getAttrOfType(getMapAttrStrName()),
3159 getMapOperands(), memrefType,
3160 getNumOperands() - 1)))
3161 return failure();
3162
3163 return success();
3164 }
3165
3166 void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3168 results.add<SimplifyAffineOp>(context);
3169 }
3170
/// Folds this load. Beyond the first (partly missing) step, the visible code
/// folds a load from a memref produced by a get-global op whose global has a
/// constant initial value: a splat initializer folds to its splat value, and
/// otherwise the element is looked up when the affine map has only constant
/// results. Returns an empty OpFoldResult when none of this applies.
3171 OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
3172 // NOTE(review): the condition guarding the return below is not visible
// in this listing; restore it from the upstream file.
3174 return getResult();
3175
3176 // The remaining folds need the memref to come from a get-global op.
3177 auto getGlobalOp = getMemref().getDefiningOpmemref::GetGlobalOp();
3178 if (!getGlobalOp)
3179 return {};
3180 // Find the enclosing symbol table op and the global from it.
// NOTE(review): the arguments of the lookup below are not visible here.
3181 auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
3182 if (!symbolTableOp)
3183 return {};
3184 auto global = dyn_cast_or_nullmemref::GlobalOp(
3186 if (!global)
3187 return {};
3188
3189 // Only globals with a constant initial value of the expected kind fold.
3190 auto cstAttr =
3191 llvm::dyn_cast_or_null(global.getConstantInitValue());
3192 if (!cstAttr)
3193 return {};
3194 // A splat initializer has the same value at every index.
3195 if (auto splatAttr = llvm::dyn_cast(cstAttr))
3196 return splatAttr.getSplatValue<Attribute>();
3197 // Otherwise every index expression of the map must be a constant.
3198 if (!getAffineMap().isConstant())
3199 return {};
// Convert each constant map result to uint64_t and use them as the index
// into the initializer's values.
3200 auto indices = llvm::to_vector<4>(
3201 llvm::map_range(getAffineMap().getConstantResults(),
3202 [](int64_t v) -> uint64_t { return v; }));
3203 return cstAttr.getValues<Attribute>()[indices];
3204 }
3205
3206
3207
3208
3209
3213 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3218 }
3219
3220
3224 auto memrefType = llvm::cast(memref.getType());
3225 int64_t rank = memrefType.getRank();
3226
3227
3228 auto map =
3230 build(builder, result, valueToStore, memref, map, indices);
3231 }
3232
3235
3236 MemRefType type;
3239 AffineMapAttr mapAttr;
3244 mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
3248 parser.resolveOperand(storeValueInfo, type.getElementType(),
3252 }
3253
3255 p << " " << getValueToStore();
3256 p << ", " << getMemRef() << '[';
3257 if (AffineMapAttr mapAttr =
3258 (*this)->getAttrOfType(getMapAttrStrName()))
3260 p << ']';
3262 {getMapAttrStrName()});
3264 }
3265
3267
3269 if (getValueToStore().getType() != memrefType.getElementType())
3270 return emitOpError(
3271 "value to store must have the same type as memref element type");
3272
3274 *this, (*this)->getAttrOfType(getMapAttrStrName()),
3275 getMapOperands(), memrefType,
3276 getNumOperands() - 2)))
3277 return failure();
3278
3279 return success();
3280 }
3281
3282 void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
3284 results.add<SimplifyAffineOp>(context);
3285 }
3286
3287 LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
3289
3291 }
3292
3293
3294
3295
3296
3297 template
3299
3300 if (op.getNumOperands() !=
3301 op.getMap().getNumDims() + op.getMap().getNumSymbols())
3302 return op.emitOpError(
3303 "operand count and affine map dimension and symbol count must match");
3304
3305 if (op.getMap().getNumResults() == 0)
3306 return op.emitOpError("affine map expect at least one result");
3307 return success();
3308 }
3309
3310 template
3312 p << ' ' << op->getAttr(T::getMapAttrStrName());
3313 auto operands = op.getOperands();
3314 unsigned numDims = op.getMap().getNumDims();
3315 p << '(' << operands.take_front(numDims) << ')';
3316
3317 if (operands.size() != numDims)
3318 p << '[' << operands.drop_front(numDims) << ']';
3320 {T::getMapAttrStrName()});
3321 }
3322
3323 template
3326 auto &builder = parser.getBuilder();
3330 AffineMapAttr mapAttr;
3331 return failure(
3332 parser.parseAttribute(mapAttr, T::getMapAttrStrName(),
3341 }
3342
3343
3344
3345
3346 template
3348 static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
3349 "expected affine min or max op");
3350
3351
3352
3353
3355 auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3356
3357 if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3358 return op.getOperand(0);
3359
3360
3361 if (results.empty()) {
3362
3363 if (foldedMap == op.getMap())
3364 return {};
3366 return op.getResult();
3367 }
3368
3369
3370 auto resultIt = std::is_same<T, AffineMinOp>::value
3371 ? llvm::min_element(results)
3372 : llvm::max_element(results);
3373 if (resultIt == results.end())
3374 return {};
3376 }
3377
3378
3379 template
3382
3385 AffineMap oldMap = affineOp.getAffineMap();
3386
3389
3390
3391 if (!llvm::is_contained(newExprs, expr))
3392 newExprs.push_back(expr);
3393 }
3394
3396 return failure();
3397
3400 rewriter.replaceOpWithNewOp(affineOp, newMap, affineOp.getMapOperands());
3401
3402 return success();
3403 }
3404 };
3405
3406
3407
3408
3409
3410
3411
3412
3413
3414
3415
3416
3417
3418
3419
3420
3421
3422 template
3425
3428 AffineMap oldMap = affineOp.getAffineMap();
3430 affineOp.getMapOperands().take_front(oldMap.getNumDims());
3432 affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
3433
3434 auto newDimOperands = llvm::to_vector<8>(dimOperands);
3435 auto newSymOperands = llvm::to_vector<8>(symOperands);
3438
3439
3440
3441
3443 if (auto symExpr = dyn_cast(expr)) {
3444 Value symValue = symOperands[symExpr.getPosition()];
3445 if (auto producerOp = symValue.getDefiningOp()) {
3446 producerOps.push_back(producerOp);
3447 continue;
3448 }
3449 } else if (auto dimExpr = dyn_cast(expr)) {
3450 Value dimValue = dimOperands[dimExpr.getPosition()];
3451 if (auto producerOp = dimValue.getDefiningOp()) {
3452 producerOps.push_back(producerOp);
3453 continue;
3454 }
3455 }
3456
3457
3458
3459 newExprs.push_back(expr);
3460 }
3461
3462 if (producerOps.empty())
3463 return failure();
3464
3465 unsigned numUsedDims = oldMap.getNumDims();
3467
3468
3469 for (T producerOp : producerOps) {
3470 AffineMap producerMap = producerOp.getAffineMap();
3471 unsigned numProducerDims = producerMap.getNumDims();
3472 unsigned numProducerSyms = producerMap.getNumSymbols();
3473
3474
3476 producerOp.getMapOperands().take_front(numProducerDims);
3478 producerOp.getMapOperands().take_back(numProducerSyms);
3479 newDimOperands.append(dimValues.begin(), dimValues.end());
3480 newSymOperands.append(symValues.begin(), symValues.end());
3481
3482
3484 newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
3485 .shiftSymbols(numProducerSyms, numUsedSyms));
3486 }
3487
3488 numUsedDims += numProducerDims;
3489 numUsedSyms += numProducerSyms;
3490 }
3491
3492 auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs,
3494 auto newOperands =
3495 llvm::to_vector<8>(llvm::concat(newDimOperands, newSymOperands));
3497
3498 return success();
3499 }
3500 };
3501
3502
3503
3504
3505
3506
3507
3508
3509
3513
3514 if (!resultExpr.isPureAffine())
3515 return failure();
3516
3518 auto flattenResult = flattener.walkPostOrder(resultExpr);
3519 if (failed(flattenResult))
3520 return failure();
3521
3522
3525 return failure();
3526
3527 flattenedExprs.emplace_back(flattener.operandExprStack.back().begin(),
3529 }
3530
3531
3532 if (llvm::is_sorted(flattenedExprs))
3533 return failure();
3534
3535
3537 llvm::to_vector(llvm::seq(0, map.getNumResults()));
3538 llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
3539 return flattenedExprs[lhs] < flattenedExprs[rhs];
3540 });
3542 for (unsigned idx : resultPermutation)
3543 newExprs.push_back(map.getResult(idx));
3544
3547 return success();
3548 }
3549
3550
3551
3552
3553
3554
3555
3556
3557
3558
3559
3560
3561
3562
3563 template
3566
3569 AffineMap map = affineOp.getAffineMap();
3571 return failure();
3572 rewriter.replaceOpWithNewOp(affineOp, map, affineOp.getMapOperands());
3573 return success();
3574 }
3575 };
3576
3577 template
3580
3583 if (affineOp.getMap().getNumResults() != 1)
3584 return failure();
3585 rewriter.replaceOpWithNewOp(affineOp, affineOp.getMap(),
3586 affineOp.getOperands());
3587 return success();
3588 }
3589 };
3590
3591
3592
3593
3594
3595
3596
3597
3598 OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
3599 return foldMinMaxOp(*this, adaptor.getOperands());
3600 }
3601
3608 context);
3609 }
3610
3612
3614 return parseAffineMinMaxOp(parser, result);
3615 }
3616
3618
3619
3620
3621
3622
3623
3624
3625
3626 OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
3627 return foldMinMaxOp(*this, adaptor.getOperands());
3628 }
3629
3636 context);
3637 }
3638
3640
3642 return parseAffineMinMaxOp(parser, result);
3643 }
3644
3646
3647
3648
3649
3650
3651
3652
3653
3656 auto &builder = parser.getBuilder();
3658
3659 MemRefType type;
3661 IntegerAttr hintInfo;
3663 StringRef readOrWrite, cacheType;
3664
3665 AffineMapAttr mapAttr;
3669 AffinePrefetchOp::getMapAttrStrName(),
3675 AffinePrefetchOp::getLocalityHintAttrStrName(),
3683 return failure();
3684
3685 if (readOrWrite != "read" && readOrWrite != "write")
3687 "rw specifier has to be 'read' or 'write'");
3688 result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
3690
3691 if (cacheType != "data" && cacheType != "instr")
3693 "cache type has to be 'data' or 'instr'");
3694
3695 result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
3697
3698 return success();
3699 }
3700
3702 p << " " << getMemref() << '[';
3703 AffineMapAttr mapAttr =
3704 (*this)->getAttrOfType(getMapAttrStrName());
3705 if (mapAttr)
3707 p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
3708 << "locality<" << getLocalityHint() << ">, "
3709 << (getIsDataCache() ? "data" : "instr");
3711 (*this)->getAttrs(),
3712 {getMapAttrStrName(), getLocalityHintAttrStrName(),
3713 getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
3715 }
3716
3718 auto mapAttr = (*this)->getAttrOfType(getMapAttrStrName());
3719 if (mapAttr) {
3720 AffineMap map = mapAttr.getValue();
3722 return emitOpError("affine.prefetch affine map num results must equal"
3723 " memref rank");
3724 if (map.getNumInputs() + 1 != getNumOperands())
3725 return emitOpError("too few operands");
3726 } else {
3727 if (getNumOperands() != 1)
3728 return emitOpError("too few operands");
3729 }
3730
3732 for (auto idx : getMapOperands()) {
3734 return emitOpError(
3735 "index must be a valid dimension or symbol identifier");
3736 }
3737 return success();
3738 }
3739
3740 void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
3742
3743 results.add<SimplifyAffineOp>(context);
3744 }
3745
3746 LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
3748
3750 }
3751
3752
3753
3754
3755
3761 auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
3763 }));
3765 build(builder, result, resultTypes, reductions, lbs, {}, ubs,
3766 {}, steps);
3767 }
3768
3775 assert(llvm::all_of(lbMaps,
3777 return m.getNumDims() == lbMaps[0].getNumDims() &&
3778 m.getNumSymbols() == lbMaps[0].getNumSymbols();
3779 }) &&
3780 "expected all lower bounds maps to have the same number of dimensions "
3781 "and symbols");
3782 assert(llvm::all_of(ubMaps,
3784 return m.getNumDims() == ubMaps[0].getNumDims() &&
3785 m.getNumSymbols() == ubMaps[0].getNumSymbols();
3786 }) &&
3787 "expected all upper bounds maps to have the same number of dimensions "
3788 "and symbols");
3789 assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
3790 "expected lower bound maps to have as many inputs as lower bound "
3791 "operands");
3792 assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
3793 "expected upper bound maps to have as many inputs as upper bound "
3794 "operands");
3795
3797 result.addTypes(resultTypes);
3798
3799
3801 for (arith::AtomicRMWKind reduction : reductions)
3802 reductionAttrs.push_back(
3804 result.addAttribute(getReductionsAttrStrName(),
3806
3807
3808
3811 if (maps.empty())
3814 groups.reserve(groups.size() + maps.size());
3815 exprs.reserve(maps.size());
3817 llvm::append_range(exprs, m.getResults());
3818 groups.push_back(m.getNumResults());
3819 }
3820 return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
3822 };
3823
3824
3826 AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
3827 AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
3828 result.addAttribute(getLowerBoundsMapAttrStrName(),
3830 result.addAttribute(getLowerBoundsGroupsAttrStrName(),
3832 result.addAttribute(getUpperBoundsMapAttrStrName(),
3834 result.addAttribute(getUpperBoundsGroupsAttrStrName(),
3839
3840
3841 auto *bodyRegion = result.addRegion();
3843
3844
3845 for (unsigned i = 0, e = steps.size(); i < e; ++i)
3847 if (resultTypes.empty())
3848 ensureTerminator(*bodyRegion, builder, result.location);
3849 }
3850
3852 return {&getRegion()};
3853 }
3854
3855 unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
3856
3857 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
3858 return getOperands().take_front(getLowerBoundsMap().getNumInputs());
3859 }
3860
3861 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
3862 return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
3863 }
3864
3865 AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
3866 auto values = getLowerBoundsGroups().getValues<int32_t>();
3867 unsigned start = 0;
3868 for (unsigned i = 0; i < pos; ++i)
3869 start += values[i];
3870 return getLowerBoundsMap().getSliceMap(start, values[pos]);
3871 }
3872
3873 AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
3874 auto values = getUpperBoundsGroups().getValues<int32_t>();
3875 unsigned start = 0;
3876 for (unsigned i = 0; i < pos; ++i)
3877 start += values[i];
3878 return getUpperBoundsMap().getSliceMap(start, values[pos]);
3879 }
3880
3881 AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
3882 return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
3883 }
3884
3885 AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
3886 return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
3887 }
3888
// Returns one constant per result of the difference between the upper-bounds
// and lower-bounds value maps, or std::nullopt when the op has min/max bounds
// or when some result of that difference is not a constant expression.
3889 std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
3890 if (hasMinMaxBounds())
3891 return std::nullopt;
3892
3893
// NOTE(review): the declarations of `rangesValueMap` and `out` are missing
// from this extract.
// Presumably computes (upper bounds - lower bounds) into `rangesValueMap`.
3896 AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
3897 &rangesValueMap);
3899 for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
3900 auto expr = rangesValueMap.getResult(i);
// NOTE(review): the target type of this dyn_cast was stripped by the
// extraction; `getValue()` below suggests a constant affine expression.
3901 auto cst = dyn_cast(expr);
// A single non-constant result makes the whole query fail.
3902 if (!cst)
3903 return std::nullopt;
3904 out.push_back(cst.getValue());
3905 }
3906 return out;
3907 }
3908
3909 Block *AffineParallelOp::getBody() { return &getRegion().front(); }
3910
3911 OpBuilder AffineParallelOp::getBodyBuilder() {
3912 return OpBuilder(getBody(), std::prev(getBody()->end()));
3913 }
3914
// Replaces the lower-bound operands of the op. `lbOperands` must supply exactly
// one value per input of `map` (asserted below).
// NOTE(review): the declaration of `newOperands` and the final statement of
// this function are missing from this extract; presumably `newOperands` starts
// out as a copy of `lbOperands` and the last statement installs `map` as the
// new lower-bounds map -- confirm upstream.
3915 void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
3916 assert(lbOperands.size() == map.getNumInputs() &&
3917 "operands to map must match number of inputs");
3918
// Capture the current upper-bound operands before the operand list is replaced.
3919 auto ubOperands = getUpperBoundsOperands();
3920
// The upper-bound operands are appended at the end of the new operand list.
3922 newOperands.append(ubOperands.begin(), ubOperands.end());
3923 (*this)->setOperands(newOperands);
3924
3926 }
3927
// Replaces the upper-bound operands of the op. `ubOperands` must supply exactly
// one value per input of `map` (asserted below).
// NOTE(review): the declaration of `newOperands` and the final statement of
// this function are missing from this extract; presumably `newOperands` starts
// out holding the current lower-bound operands and the last statement installs
// `map` as the new upper-bounds map -- confirm upstream.
3928 void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
3929 assert(ubOperands.size() == map.getNumInputs() &&
3930 "operands to map must match number of inputs");
3931
// The new upper-bound operands are appended at the end of the operand list.
3933 newOperands.append(ubOperands.begin(), ubOperands.end());
3934 (*this)->setOperands(newOperands);
3935
3937 }
3938
3940 setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
3941 }
3942
3943
// NOTE(review): the first signature line of this helper is missing from this
// extract, and the type arguments of the isa/dyn_cast calls below were
// stripped, so the concrete type classes tested in each case cannot be read
// from here.
// Returns whether `resultType` is acceptable for the reduction kind `op`.
3945 arith::AtomicRMWKind op) {
3946 switch (op) {
3947 case arith::AtomicRMWKind::addf:
3948 return isa(resultType);
3949 case arith::AtomicRMWKind::addi:
3950 return isa(resultType);
3951 case arith::AtomicRMWKind::assign:
// `assign` is accepted for every result type.
3952 return true;
3953 case arith::AtomicRMWKind::mulf:
3954 return isa(resultType);
3955 case arith::AtomicRMWKind::muli:
3956 return isa(resultType);
3957 case arith::AtomicRMWKind::maximumf:
3958 return isa(resultType);
3959 case arith::AtomicRMWKind::minimumf:
3960 return isa(resultType);
// maxs/mins: the dyn_cast must succeed and the type must report isSigned().
3961 case arith::AtomicRMWKind::maxs: {
3962 auto intType = llvm::dyn_cast(resultType);
3963 return intType && intType.isSigned();
3964 }
3965 case arith::AtomicRMWKind::mins: {
3966 auto intType = llvm::dyn_cast(resultType);
3967 return intType && intType.isSigned();
3968 }
// maxu/minu: the dyn_cast must succeed and the type must report isUnsigned().
3969 case arith::AtomicRMWKind::maxu: {
3970 auto intType = llvm::dyn_cast(resultType);
3971 return intType && intType.isUnsigned();
3972 }
3973 case arith::AtomicRMWKind::minu: {
3974 auto intType = llvm::dyn_cast(resultType);
3975 return intType && intType.isUnsigned();
3976 }
3977 case arith::AtomicRMWKind::ori:
3978 return isa(resultType);
3979 case arith::AtomicRMWKind::andi:
3980 return isa(resultType);
// Any reduction kind not handled above is rejected.
3981 default:
3982 return false;
3983 }
3984 }
3985
// NOTE(review): the signature line and several interior lines of this
// definition are missing from this extract. The diagnostics indicate it
// verifies the structural invariants of the parallel op -- confirm upstream.
3987 auto numDims = getNumDims();
// The number of lower-bound groups, upper-bound groups, steps and body block
// arguments must all be equal.
3988 if (getLowerBoundsGroups().getNumElements() != numDims ||
3989 getUpperBoundsGroups().getNumElements() != numDims ||
3990 getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
3991 return emitOpError() << "the number of region arguments ("
3992 << getBody()->getNumArguments()
3993 << ") and the number of map groups for lower ("
3994 << getLowerBoundsGroups().getNumElements()
3995 << ") and upper bound ("
3996 << getUpperBoundsGroups().getNumElements()
3997 << "), and the number of steps (" << getSteps().size()
3998 << ") must all match";
3999 }
4000
// Every lower-bound group must be non-empty, and the group sizes must add up
// to the number of results of the flattened lower-bounds map.
4001 unsigned expectedNumLBResults = 0;
4002 for (APInt v : getLowerBoundsGroups()) {
4003 unsigned results = v.getZExtValue();
4004 if (results == 0)
4005 return emitOpError()
4006 << "expected lower bound map to have at least one result";
4007 expectedNumLBResults += results;
4008 }
4009 if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
4010 return emitOpError() << "expected lower bounds map to have "
4011 << expectedNumLBResults << " results";
// Same two checks for the upper-bound groups and the upper-bounds map.
4012 unsigned expectedNumUBResults = 0;
4013 for (APInt v : getUpperBoundsGroups()) {
4014 unsigned results = v.getZExtValue();
4015 if (results == 0)
4016 return emitOpError()
4017 << "expected upper bound map to have at least one result";
4018 expectedNumUBResults += results;
4019 }
4020 if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
4021 return emitOpError() << "expected upper bounds map to have "
4022 << expectedNumUBResults << " results";
4023
// One reduction entry is required per op result.
4024 if (getReductions().size() != getNumResults())
4025 return emitOpError("a reduction must be specified for each output");
4026
4027
4028
// NOTE(review): the loop header that introduces `attr` is missing from this
// extract, as is the condition guarding the last diagnostic in the loop body.
// Each reduction attribute must convert to a known arith::AtomicRMWKind.
4031 auto intAttr = llvm::dyn_cast(attr);
4032 if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
4033 return emitOpError("invalid reduction attribute");
4034 auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
4036 return emitOpError("result type cannot match reduction attribute");
4037 }
4038
4039
4040
// NOTE(review): the two calls whose trailing arguments appear below are cut
// off in this extract; each receives the dim count of one bounds map and makes
// verification fail when it reports failure.
4042 getLowerBoundsMap().getNumDims())))
4043 return failure();
4044
4046 getUpperBoundsMap().getNumDims())))
4047 return failure();
4048 return success();
4049 }
4050
// Canonicalizes this value map in place. Returns failure() when the map and
// operands are unchanged, and success() after resetting the value map to the
// new map and operands otherwise.
// NOTE(review): the declaration of `newOperands` and the statement that
// actually rewrites `newMap`/`newOperands` are missing from this extract.
4051 LogicalResult AffineValueMap::canonicalize() {
4053 auto newMap = getAffineMap();
// Nothing changed: report failure so callers can tell no update happened.
4055 if (newMap == getAffineMap() && newOperands == operands)
4056 return failure();
4057 reset(newMap, newOperands);
4058 return success();
4059 }
4060
4061
4064 bool lbCanonicalized = succeeded(lb.canonicalize());
4065
4067 bool ubCanonicalized = succeeded(ub.canonicalize());
4068
4069
4070 if (!lbCanonicalized && !ubCanonicalized)
4071 return failure();
4072
4073 if (lbCanonicalized)
4075 if (ubCanonicalized)
4077
4078 return success();
4079 }
4080
4081 LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
4084 }
4085
4086
4087
4088
4089
4090
// NOTE(review): the first signature lines of this helper are missing from this
// extract; its callers in this file pass a printer, a map attribute, a groups
// attribute, an operand range and the keyword "max" or "min".
// Prints the bounds group by group, separated by ", ". A group of size one is
// handled by the first branch; a larger group is wrapped as `keyword(...)`.
4093 StringRef keyword) {
4094 AffineMap map = mapAttr.getValue();
4095 unsigned numDims = map.getNumDims();
// The leading `numDims` operands are the dim operands, the rest the symbols.
4096 ValueRange dimOperands = operands.take_front(numDims);
4097 ValueRange symOperands = operands.drop_front(numDims);
// `start` is the index of the first map result belonging to the current group.
4098 unsigned start = 0;
4099 for (llvm::APInt groupSize : group) {
// No separator before the first group.
4100 if (start != 0)
4101 p << ", ";
4102
4103 unsigned size = groupSize.getZExtValue();
4104 if (size == 1) {
// NOTE(review): the statement printing the single result is not visible.
4106 ++start;
4107 } else {
4108 p << keyword << '(';
// NOTE(review): the statements printing the grouped results are not visible.
4111 p << ')';
4112 start += size;
4113 }
4114 }
4115 }
4116
// NOTE(review): the signature line, the declaration of `steps` and the start
// of the final call are missing from this extract. The body streams the op's
// textual form into `p`, so this is presumably the custom printer of the
// parallel op -- confirm upstream.
// Output order: body arguments, lower bounds, "to", upper bounds, optional
// "step (...)", optional "reduce (...) -> (...)", then the region.
4118 p << " (" << getBody()->getArguments() << ") = (";
// Lower-bound groups with several results are printed under the "max" keyword.
4119 printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(),
4120 getLowerBoundsOperands(), "max");
4121 p << ") to (";
// Upper-bound groups with several results are printed under the "min" keyword.
4122 printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
4123 getUpperBoundsOperands(), "min");
4124 p << ')';
// The step clause is omitted when every step equals 1.
4126 bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
4127 if (!elideSteps) {
4128 p << " step (";
4129 llvm::interleaveComma(steps, p);
4130 p << ')';
4131 }
// The reduce clause is printed only when the op has results; each reduction is
// emitted as the quoted string form of its arith::AtomicRMWKind.
4132 if (getNumResults()) {
4133 p << " reduce (";
4134 llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
4135 arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4136 llvm::cast(attr).getInt());
4137 p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
4138 });
4139 p << ") -> (" << getResultTypes() << ")";
4140 }
4141
4142 p << ' ';
// The second argument is `false` and the third is the result count; their
// parameter names are not visible in this extract.
4143 p.printRegion(getRegion(), false,
4144 getNumResults());
// NOTE(review): the call these arguments belong to is not visible here; it
// receives the op's attributes together with the names of the six attributes
// listed below.
4146 (*this)->getAttrs(),
4147 {AffineParallelOp::getReductionsAttrStrName(),
4148 AffineParallelOp::getLowerBoundsMapAttrStrName(),
4149 AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4150 AffineParallelOp::getUpperBoundsMapAttrStrName(),
4151 AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4152 AffineParallelOp::getStepsAttrStrName()});
4153 }
4154
4155
4156
4157
4158
4165 "expected operands to be dim or symbol expression");
4166
4168 for (const auto &list : operands) {
4170 if (parser.resolveOperands(list, indexType, valueOperands))
4171 return failure();
4172 for (Value operand : valueOperands) {
4173 unsigned pos = std::distance(uniqueOperands.begin(),
4174 llvm::find(uniqueOperands, operand));
4175 if (pos == uniqueOperands.size())
4176 uniqueOperands.push_back(operand);
4177 replacements.push_back(
4181 }
4182 }
4183 return success();
4184 }
4185
// File-local selector between the two kinds of bound list. Elsewhere in this
// file `Min` selects the upper-bounds attribute names and the "min" keyword,
// while `Max` selects the lower-bounds attribute names and the "max" keyword.
4186 namespace {
4187 enum class MinMaxKind { Min, Max };
4188 }
4189
4190
4191
4192
4193
4194
4195
4196
4197
4198
4199
4200
4201
4202
4203
4204
4205
4208 MinMaxKind kind) {
4209
4210
4211 const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";
4212
4213 StringRef mapName = kind == MinMaxKind::Min
4214 ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4215 : AffineParallelOp::getLowerBoundsMapAttrStrName();
4216 StringRef groupsName =
4217 kind == MinMaxKind::Min
4218 ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4219 : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4220
4222 return failure();
4223
4228 return success();
4229 }
4230
4236 auto parseOperands = [&]() {
4238 kind == MinMaxKind::Min ? "min" : "max"))) {
4239 mapOperands.clear();
4240 AffineMapAttr map;
4244 return failure();
4246 llvm::append_range(flatExprs, map.getValue().getResults());
4248 auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4250 auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4252 flatDimOperands.append(map.getValue().getNumResults(), dims);
4253 flatSymOperands.append(map.getValue().getNumResults(), syms);
4254 numMapsPerGroup.push_back(map.getValue().getNumResults());
4255 } else {
4257 flatSymOperands.emplace_back(),
4258 flatExprs.emplace_back())))
4259 return failure();
4260 numMapsPerGroup.push_back(1);
4261 }
4262 return success();
4263 };
4265 return failure();
4266
4267 unsigned totalNumDims = 0;
4268 unsigned totalNumSyms = 0;
4269 for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4270 unsigned numDims = flatDimOperands[i].size();
4271 unsigned numSyms = flatSymOperands[i].size();
4272 flatExprs[i] = flatExprs[i]
4273 .shiftDims(numDims, totalNumDims)
4274 .shiftSymbols(numSyms, totalNumSyms);
4275 totalNumDims += numDims;
4276 totalNumSyms += numSyms;
4277 }
4278
4279
4286 return failure();
4287
4288 result.operands.append(dimOperands.begin(), dimOperands.end());
4289 result.operands.append(symOperands.begin(), symOperands.end());
4290
4292 auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
4294 flatMap = flatMap.replaceDimsAndSymbols(
4295 dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
4296
4299 return success();
4300 }
4301
4302
4303
4304
4305
4306
4309 auto &builder = parser.getBuilder();
4317 return failure();
4318
4319 AffineMapAttr stepsMapAttr;
4324 result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4326 } else {
4328 AffineParallelOp::getStepsAttrStrName(),
4329 stepsAttrs,
4331 return failure();
4332
4333
4335 auto stepsMap = stepsMapAttr.getValue();
4336 for (const auto &result : stepsMap.getResults()) {
4337 auto constExpr = dyn_cast(result);
4338 if (!constExpr)
4340 "steps must be constant integers");
4341 steps.push_back(constExpr.getValue());
4342 }
4343 result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4345 }
4346
4347
4348
4352 return failure();
4353 auto parseAttributes = [&]() -> ParseResult {
4354
4355
4356
4357 StringAttr attrVal;
4361 attrStorage))
4362 return failure();
4363 std::optionalarith::AtomicRMWKind reduction =
4364 arith::symbolizeAtomicRMWKind(attrVal.getValue());
4365 if (!reduction)
4366 return parser.emitError(loc, "invalid reduction value: ") << attrVal;
4367 reductions.push_back(
4368 builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
4369
4370 return success();
4371 };
4373 return failure();
4374 }
4375 result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4377
4378
4380 return failure();
4381
4382
4384 for (auto &iv : ivs)
4385 iv.type = indexType;
4388 return failure();
4389
4390
4391 AffineParallelOp::ensureTerminator(*body, builder, result.location);
4392 return success();
4393 }
4394
4395
4396
4397
4398
// NOTE(review): the signature line of this definition is missing from this
// extract; the diagnostics indicate it is the verifier of the affine yield op.
4400 auto *parentOp = (*this)->getParentOp();
4401 auto results = parentOp->getResults();
4402 auto operands = getOperands();
4403
// The yield may only terminate the region of one of these three affine ops.
4404 if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4405 return emitOpError() << "only terminates affine.if/for/parallel regions";
// The yielded values must match the parent's results in number ...
4406 if (parentOp->getNumResults() != getNumOperands())
4407 return emitOpError() << "parent of yield must have same number of "
4408 "results as the yield operands";
// ... and pairwise in type.
4409 for (auto it : llvm::zip(results, operands)) {
4410 if (std::get<0>(it).getType() != std::get<1>(it).getType())
4411 return emitOpError() << "types mismatch between yield op and its parent";
4412 }
4413
4414 return success();
4415 }
4416
4417
4418
4419
4420
4422 VectorType resultType, AffineMap map,
4424 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
4426 if (map)
4428 result.types.push_back(resultType);
4429 }
4430
4432 VectorType resultType, Value memref,
4434 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4438 result.types.push_back(resultType);
4439 }
4440
4442 VectorType resultType, Value memref,
4444 auto memrefType = llvm::cast(memref.getType());
4445 int64_t rank = memrefType.getRank();
4446
4447
4448 auto map =
4450 build(builder, result, resultType, memref, map, indices);
4451 }
4452
4453 void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4455 results.add<SimplifyAffineOp>(context);
4456 }
4457
4460 auto &builder = parser.getBuilder();
4462
4463 MemRefType memrefType;
4464 VectorType resultType;
4466 AffineMapAttr mapAttr;
4468 return failure(
4471 AffineVectorLoadOp::getMapAttrStrName(),
4479 }
4480
4482 p << " " << getMemRef() << '[';
4483 if (AffineMapAttr mapAttr =
4484 (*this)->getAttrOfType(getMapAttrStrName()))
4486 p << ']';
4488 {getMapAttrStrName()});
4490 }
4491
4492
4494 VectorType vectorType) {
4495
4496 if (memrefType.getElementType() != vectorType.getElementType())
4498 "requires memref and vector types of the same elemental type");
4499 return success();
4500 }
4501
4505 *this, (*this)->getAttrOfType(getMapAttrStrName()),
4506 getMapOperands(), memrefType,
4507 getNumOperands() - 1)))
4508 return failure();
4509
4511 return failure();
4512
4513 return success();
4514 }
4515
4516
4517
4518
4519
4523 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4528 }
4529
4530
4534 auto memrefType = llvm::cast(memref.getType());
4535 int64_t rank = memrefType.getRank();
4536
4537
4538 auto map =
4540 build(builder, result, valueToStore, memref, map, indices);
4541 }
4542 void AffineVectorStoreOp::getCanonicalizationPatterns(
4544 results.add<SimplifyAffineOp>(context);
4545 }
4546
4550
4551 MemRefType memrefType;
4552 VectorType resultType;
4555 AffineMapAttr mapAttr;
4557 return failure(
4561 AffineVectorStoreOp::getMapAttrStrName(),
4569 }
4570
4572 p << " " << getValueToStore();
4573 p << ", " << getMemRef() << '[';
4574 if (AffineMapAttr mapAttr =
4575 (*this)->getAttrOfType(getMapAttrStrName()))
4577 p << ']';
4579 {getMapAttrStrName()});
4580 p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4581 }
4582
4586 *this, (*this)->getAttrOfType(getMapAttrStrName()),
4587 getMapOperands(), memrefType,
4588 getNumOperands() - 2)))
4589 return failure();
4590
4592 return failure();
4593
4594 return success();
4595 }
4596
4597
4598
4599
4600
4601 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4605 bool hasOuterBound) {
4606 SmallVector returnTypes(hasOuterBound ? staticBasis.size()
4607 : staticBasis.size() + 1,
4609 build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
4610 staticBasis);
4611 }
4612
4613 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4616 bool hasOuterBound) {
4617 if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
4618 hasOuterBound = false;
4619 basis = basis.drop_front();
4620 }
4624 staticBasis);
4625 build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4626 hasOuterBound);
4627 }
4628
4629 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4631 Value linearIndex,
4633 bool hasOuterBound) {
4634 if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
4635 hasOuterBound = false;
4636 basis = basis.drop_front();
4637 }
4641 build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4642 hasOuterBound);
4643 }
4644
4645 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4648 bool hasOuterBound) {
4649 build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);
4650 }
4651
// NOTE(review): the signature line and the declaration of `staticBasis` are
// missing from this extract; the checks below suggest this is the verifier of
// the delinearize-index op -- confirm upstream.
// The op must have either one result per basis element or exactly one more.
4654 if (getNumResults() != staticBasis.size() &&
4655 getNumResults() != staticBasis.size() + 1)
4656 return emitOpError("should return an index for each basis element and up "
4657 "to one extra index");
4658
// Every entry of the static basis flagged dynamic must be backed by exactly
// one dynamic basis operand.
4659 auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
4660 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4661 return emitOpError(
4662 "mismatch between dynamic and static basis (kDynamic marker but no "
4663 "corresponding dynamic basis entry) -- this can only happen due to an "
4664 "incorrect fold/rewrite");
4665
// Static basis entries must be strictly positive unless flagged dynamic.
4666 if (!llvm::all_of(staticBasis, [](int64_t v) {
4667 return v > 0 || ShapedType::isDynamic(v);
4668 }))
4669 return emitOpError("no basis element may be statically non-positive");
4670
4671 return success();
4672 }
4673
4674
4675
4676
4677
4678 static std::optional<SmallVector<int64_t>>
4682 uint64_t dynamicBasisIndex = 0;
4684 if (basis) {
4685 mutableDynamicBasis.erase(dynamicBasisIndex);
4686 } else {
4687 ++dynamicBasisIndex;
4688 }
4689 }
4690
4691
4692 if (dynamicBasisIndex == dynamicBasis.size())
4693 return std::nullopt;
4694
4698 if (!basisVal)
4699 staticBasis.push_back(ShapedType::kDynamic);
4700 else
4701 staticBasis.push_back(*basisVal);
4702 }
4703
4704 return staticBasis;
4705 }
4706
// Folder for the delinearize-index op.
// NOTE(review): several lines are missing from this extract: the rest of the
// signature (which declares `result`), the callee that produces
// `maybeStaticBasis`, the declaration of `staticBasis`, and one statement
// between the loop and the final std::reverse.
4707 LogicalResult
4708 AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
4710 std::optional<SmallVector<int64_t>> maybeStaticBasis =
4712 adaptor.getDynamicBasis());
// When a new static basis could be computed, update the op in place and report
// success.
4713 if (maybeStaticBasis) {
4714 setStaticBasis(*maybeStaticBasis);
4715 return success();
4716 }
4717
4718
// With a single result the op is the identity on its linear index.
4719 if (getNumResults() == 1) {
4720 result.push_back(getLinearIndex());
4721 return success();
4722 }
4723
// Constant folding needs a constant linear index ...
4724 if (adaptor.getLinearIndex() == nullptr)
4725 return failure();
4726
// ... and is only attempted when there are no dynamic basis operands.
4727 if (!adaptor.getDynamicBasis().empty())
4728 return failure();
4729
4730 int64_t highPart = cast(adaptor.getLinearIndex()).getInt();
4731 Type attrType = getLinearIndex().getType();
4732
// The outer bound, when present, is not used as a modulus.
4734 if (hasOuterBound())
4735 staticBasis = staticBasis.drop_front();
// Walk the basis from last to first element: each step emits the remainder for
// that element and keeps the floor-divided quotient for the next one.
4736 for (int64_t modulus : llvm::reverse(staticBasis)) {
4737 result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
4738 highPart = llvm::divideFloorSigned(highPart, modulus);
4739 }
// The values were produced last-to-first, so flip them into result order.
4741 std::reverse(result.begin(), result.end());
4742 return success();
4743 }
4744
4747 if (hasOuterBound()) {
4748 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
4749 return getMixedValues(getStaticBasis().drop_front(),
4750 getDynamicBasis().drop_front(), builder);
4751
4752 return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
4753 builder);
4754 }
4755
4756 return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
4757 }
4758
4761 if (!hasOuterBound())
4763 return ret;
4764 }
4765
4766 namespace {
4767
4768
// Rewrite pattern on the delinearize-index op: results whose basis element is
// the constant 1 are replaced by a constant index 0, and the remaining results
// come from a new delinearize op built on the remaining basis elements.
// NOTE(review): angle brackets were stripped from this extract, and several
// lines are missing (the rest of the matchAndRewrite signature, the head of
// the `getZero` lambda, the declaration of `newBasis`, the range expression of
// the loop, the constant lookup, and the failure-return call).
4769 struct DropUnitExtentBasis
4770 : public OpRewritePatternaffine::AffineDelinearizeIndexOp {
4772
4773 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
// One slot per original result; a null entry means "not replaced yet".
4775 SmallVector replacements(delinearizeOp->getNumResults(), nullptr);
4776 std::optional zero = std::nullopt;
4777 Location loc = delinearizeOp->getLoc();
// The zero constant is created on first use and reused afterwards.
4779 if (!zero)
4780 zero = rewriter.createarith::ConstantIndexOp(loc, 0);
4781 return zero.value();
4782 };
4783
4784
4785
4787 for (auto [index, basis] :
4789 std::optional<int64_t> basisVal =
// A basis element known to be 1 yields the zero constant for that result;
// every other element is kept for the new op.
4791 if (basisVal == 1)
4792 replacements[index] = getZero();
4793 else
4794 newBasis.push_back(basis);
4795 }
4796
// Nothing was dropped: the pattern does not apply.
4797 if (newBasis.size() == delinearizeOp.getNumResults())
4799 "no unit basis elements");
4800
4801 if (!newBasis.empty()) {
4802
4803 auto newDelinearizeOp = rewriter.createaffine::AffineDelinearizeIndexOp(
4804 loc, delinearizeOp.getLinearIndex(), newBasis);
4805 int newIndex = 0;
4806
// Fill the slots that were not set to zero with the new op's results, in order.
4807 for (auto &replacement : replacements) {
4808 if (replacement)
4809 continue;
4810 replacement = newDelinearizeOp->getResult(newIndex++);
4811 }
4812 }
4813
4814 rewriter.replaceOp(delinearizeOp, replacements);
4815 return success();
4816 }
4817 };
4818
4819
4820
4821
4822
4823
4824
4825
4826
4827
4828
// Rewrite pattern on a delinearize-index op whose linear index is produced by
// a linearize-index op marked disjoint. Trailing basis elements that are equal
// in both ops cancel: the matching trailing linearize inputs are forwarded
// directly, and only the unmatched leading part is re-linearized and
// re-delinearized.
// NOTE(review): angle brackets were stripped from this extract, and several
// lines are missing (the rest of the matchAndRewrite signature, the
// failure-return calls, the declarations of `linearizeBasis`,
// `delinearizeBasis` and `mergedResults`, and one argument of each create
// call).
4829 struct CancelDelinearizeOfLinearizeDisjointExactTail
4830 : public OpRewritePatternaffine::AffineDelinearizeIndexOp {
4832
4833 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4835 auto linearizeOp = delinearizeOp.getLinearIndex()
4836 .getDefiningOpaffine::AffineLinearizeIndexOp();
4837 if (!linearizeOp)
4839 "index doesn't come from linearize");
4840
// The pattern only applies when the linearize op is marked disjoint.
4841 if (!linearizeOp.getDisjoint())
4843
4844 ValueRange linearizeIns = linearizeOp.getMultiIndex();
4845
// Count how many basis elements agree, comparing from the back of both bases.
4848 size_t numMatches = 0;
4849 for (auto [linSize, delinSize] : llvm::zip(
4850 llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
4851 if (linSize != delinSize)
4852 break;
4853 ++numMatches;
4854 }
4855
4856 if (numMatches == 0)
4858 delinearizeOp, "final basis element doesn't match linearize");
4859
4860
// Full cancellation: both bases match entirely and the number of linearize
// inputs equals the number of delinearize results, so the results are exactly
// the linearize inputs.
4861 if (numMatches == linearizeBasis.size() &&
4862 numMatches == delinearizeBasis.size() &&
4863 linearizeIns.size() == delinearizeOp.getNumResults()) {
4864 rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
4865 return success();
4866 }
4867
// Partial cancellation: rebuild both ops without the matched tail ...
4868 Value newLinearize = rewriter.createaffine::AffineLinearizeIndexOp(
4869 linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
4871 linearizeOp.getDisjoint());
4872 auto newDelinearize = rewriter.createaffine::AffineDelinearizeIndexOp(
4873 delinearizeOp.getLoc(), newLinearize,
4875 delinearizeOp.hasOuterBound());
// ... and append the matched trailing linearize inputs to the replacement list.
4877 mergedResults.append(linearizeIns.take_back(numMatches).begin(),
4878 linearizeIns.take_back(numMatches).end());
4879 rewriter.replaceOp(delinearizeOp, mergedResults);
4880 return success();
4881 }
4882 };
4883
4884
4885
4886
4887
4888
4889
4890
4891
4892
4893
4894
4895
4896
// Rewrite pattern on a delinearize-index op fed by a linearize-index op marked
// disjoint. When the product of two or more trailing delinearize basis
// elements equals the last (static) linearize basis element, the op is split
// in two: one delinearize over the linearization of all but the last input,
// and one delinearize of the last input over those trailing elements.
// NOTE(review): angle brackets were stripped from this extract, and several
// lines are missing (the base-class clause, the rest of the matchAndRewrite
// signature, the failure-return calls, and the declarations of `basis` and
// `results`).
4897 struct SplitDelinearizeSpanningLastLinearizeArg final
4900
4901 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4903 auto linearizeOp = delinearizeOp.getLinearIndex()
4904 .getDefiningOpaffine::AffineLinearizeIndexOp();
4905 if (!linearizeOp)
4907 "index doesn't come from linearize");
4908
4909 if (!linearizeOp.getDisjoint())
4911 "linearize isn't disjoint");
4912
// `target` is the size the trailing delinearize basis elements must multiply to.
4913 int64_t target = linearizeOp.getStaticBasis().back();
4914 if (ShapedType::isDynamic(target))
4916 linearizeOp, "linearize ends with dynamic basis value");
4917
// Accumulate basis elements from the back until the product reaches `target`;
// a dynamic element or a product larger than `target` aborts the match.
4918 int64_t sizeToSplit = 1;
4919 size_t elemsToSplit = 0;
4921 for (int64_t basisElem : llvm::reverse(basis)) {
4922 if (ShapedType::isDynamic(basisElem))
4924 delinearizeOp, "dynamic basis element while scanning for split");
4925 sizeToSplit *= basisElem;
4926 elemsToSplit += 1;
4927
4928 if (sizeToSplit > target)
4930 "overshot last argument size");
4931 if (sizeToSplit == target)
4932 break;
4933 }
4934
// The whole basis was consumed without reaching `target`.
4935 if (sizeToSplit < target)
4937 delinearizeOp, "product of known basis elements doesn't exceed last "
4938 "linearize argument");
4939
// Matches that consumed fewer than two basis elements are rejected.
4940 if (elemsToSplit < 2)
4942 delinearizeOp,
4943 "need at least two elements to form the basis product");
4944
// Linearize everything except the last input, with the last static basis
// element dropped.
4945 Value linearizeWithoutBack =
4946 rewriter.createaffine::AffineLinearizeIndexOp(
4947 linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
4948 linearizeOp.getDynamicBasis(),
4949 linearizeOp.getStaticBasis().drop_back(),
4950 linearizeOp.getDisjoint());
// Delinearize that value over the basis elements that were not split off.
4951 auto delinearizeWithoutSplitPart =
4952 rewriter.createaffine::AffineDelinearizeIndexOp(
4953 delinearizeOp.getLoc(), linearizeWithoutBack,
4954 delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
4955 delinearizeOp.hasOuterBound());
// Delinearize the last linearize input over the split-off trailing elements.
4956 auto delinearizeBack = rewriter.createaffine::AffineDelinearizeIndexOp(
4957 delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
4958 basis.take_back(elemsToSplit), true);
// The replacement values are the results of the first new op followed by those
// of the second.
4960 llvm::concat(delinearizeWithoutSplitPart.getResults(),
4961 delinearizeBack.getResults()));
4962 rewriter.replaceOp(delinearizeOp, results);
4963
4964 return success();
4965 }
4966 };
4967 }
4968
4969 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
4972 .insert<CancelDelinearizeOfLinearizeDisjointExactTail,
4973 DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
4974 context);
4975 }
4976
4977
4978
4979
4980
4981 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4984 bool disjoint) {
4985 if (!basis.empty() && basis.front() == Value())
4986 basis = basis.drop_front();
4990 staticBasis);
4991 build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4992 }
4993
4994 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4998 bool disjoint) {
4999 if (!basis.empty() && basis.front() == OpFoldResult())
5000 basis = basis.drop_front();
5004 build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
5005 }
5006
5007 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5011 build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
5012 }
5013
// NOTE(review): the signature line of this definition is missing from this
// extract; the checks below suggest it is the verifier of the linearize-index
// op -- confirm upstream.
5015 size_t numIndexes = getMultiIndex().size();
5016 size_t numBasisElems = getStaticBasis().size();
// The number of indices must equal the number of basis elements or exceed it
// by exactly one.
5017 if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
5018 return emitOpError("should be passed a basis element for each index except "
5019 "possibly the first");
5020
// Every entry of the static basis flagged dynamic must be backed by exactly
// one dynamic basis operand.
5021 auto dynamicMarkersCount =
5022 llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
5023 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
5024 return emitOpError(
5025 "mismatch between dynamic and static basis (kDynamic marker but no "
5026 "corresponding dynamic basis entry) -- this can only happen due to an "
5027 "incorrect fold/rewrite");
5028
5029 return success();
5030 }
5031
// Folder for the linearize-index op.
// NOTE(review): several lines are missing from this extract: the callee that
// produces `maybeStaticBasis`, the statement guarded by the empty-multi-index
// check, and the final return statement.
5032 OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
5033 std::optional<SmallVector<int64_t>> maybeStaticBasis =
5035 adaptor.getDynamicBasis());
// When a new static basis could be computed, update the op in place and return
// its own result.
5036 if (maybeStaticBasis) {
5037 setStaticBasis(*maybeStaticBasis);
5038 return getResult();
5039 }
5040
5041 if (getMultiIndex().empty())
5043
5044
// With a single index the op is the identity on that index.
5045 if (getMultiIndex().size() == 1)
5046 return getMultiIndex().front();
5047
// Constant folding needs every index to be a constant ...
5048 if (llvm::is_contained(adaptor.getMultiIndex(), nullptr))
5049 return nullptr;
5050
// ... and is only attempted when there are no dynamic basis operands.
5051 if (!adaptor.getDynamicBasis().empty())
5052 return nullptr;
5053
// Accumulate index * stride from the last index backwards; the stride grows by
// the corresponding basis element after each step. zip_first stops when the
// basis is exhausted.
5054 int64_t result = 0;
5055 int64_t stride = 1;
5056 for (auto [length, indexAttr] :
5057 llvm::zip_first(llvm::reverse(getStaticBasis()),
5058 llvm::reverse(adaptor.getMultiIndex()))) {
5059 result = result + cast(indexAttr).getInt() * stride;
5060 stride = stride * length;
5061 }
5062
// Without an outer bound the first index has no basis element of its own and
// was skipped by the loop above; add its contribution with the final stride.
5063 if (!hasOuterBound())
5064 result =
5065 result +
5066 cast(adaptor.getMultiIndex().front()).getInt() * stride;
5067
5069 }
5070
5073 if (hasOuterBound()) {
5074 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
5075 return getMixedValues(getStaticBasis().drop_front(),
5076 getDynamicBasis().drop_front(), builder);
5077
5078 return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
5079 builder);
5080 }
5081
5082 return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
5083 }
5084
5087 if (!hasOuterBound())
5089 return ret;
5090 }
5091
5092 namespace {
5093
5094
5095
5096
5097
5098
5099
5100
5101
5102
// Rewrite pattern on the linearize-index op that removes (index, basis) pairs
// whose basis element is the constant 1. Such a pair is removed when the op is
// marked disjoint or when the index itself is the constant 0; otherwise it is
// kept.
// NOTE(review): several lines are missing from this extract (the base-class
// clause, the rest of the matchAndRewrite signature, the declarations of
// `newIndices`, `newBasis` and `basis`, the constant lookups that define
// `basisEntry` and `indexValue`, the failure-return call, and the start of
// both replacement calls).
5103 struct DropLinearizeUnitComponentsIfDisjointOrZero final
5106
5107 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5109 ValueRange multiIndex = op.getMultiIndex();
5110 size_t numIndices = multiIndex.size();
5112 newIndices.reserve(numIndices);
5114 newBasis.reserve(numIndices);
5115
// Without an outer bound the first index has no basis element; keep it as is
// so the remaining indices pair up one-to-one with the basis.
5116 if (!op.hasOuterBound()) {
5117 newIndices.push_back(multiIndex.front());
5118 multiIndex = multiIndex.drop_front();
5119 }
5120
5122 for (auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {
// Basis element not known to be 1: keep the pair.
5124 if (!basisEntry || *basisEntry != 1) {
5125 newIndices.push_back(index);
5126 newBasis.push_back(basisElem);
5127 continue;
5128 }
5129
// Unit basis element, but the op is not disjoint and the index is not known to
// be 0: keep the pair.
5131 if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
5132 newIndices.push_back(index);
5133 newBasis.push_back(basisElem);
5134 continue;
5135 }
5136 }
// Nothing was dropped: the pattern does not apply.
5137 if (newIndices.size() == numIndices)
5139 "no unit basis entries to replace");
5140
// Every index was dropped; the replacement statement is not visible here.
5141 if (newIndices.size() == 0) {
5143 return success();
5144 }
5146 op, newIndices, newBasis, op.getDisjoint());
5147 return success();
5148 }
5149 };
5150
5153 int64_t nDynamic = 0;
5157 if (!term)
5158 return term;
5160 if (maybeConst) {
5162 } else {
5163 dynamicPart.push_back(cast(term));
5165 }
5166 }
5167 if (auto constant = dyn_cast(result))
5169 return builder.create(loc, result, dynamicPart).getResult();
5170 }
5171
5172
5173
5174
5175
5176
5177
5178
5179
5180
5181
5182
5183
5184
5185
5186
5187
5188
5189
5190
5191
5192
5193
5194
5195
5196
5197
5198
/// Canonicalization pattern on `affine::AffineLinearizeIndexOp` that looks for
/// runs of two or more consecutive linearize inputs that are consecutive
/// results of one defining op (held in `delinearizeOp`) with equal basis
/// elements on both sides. Each run is collapsed into a single linearize input
/// whose basis element is the product of the run's basis elements.
///
/// When a run covers every result of the defining op, that op's linear index
/// is used directly. Otherwise a new op with the merged basis is created, and
/// a second op re-splits the merged component so that the original results
/// can be replaced.
///
/// NOTE(review): template arguments and several declarations (`matches`,
/// `alreadyMatchedDelinearize`, `linBasis`, `linBasisRef`, `delinBasis`,
/// `newIndex`, `newBasis`, `basisToMerge`, `newSize`, `newDelinBasis`,
/// `newDelinResults`, `delinearizeReplacements`) are missing from this
/// listing. The defining op is presumably affine.delinearize_index - confirm.
5199 struct CancelLinearizeOfDelinearizePortion final
5202
5203 private:
5204
5205
5206
5207
// Fields of the `Match` record (its opening line and first member are not
// visible here). As used below: `linStart` is the position of the run within
// the linearize inputs, `delinStart` the position within the defining op's
// results, and `length` the number of elements in the run.
5210 unsigned linStart = 0;
5211 unsigned delinStart = 0;
5212 unsigned length = 0;
5213 };
5214
5215 public:
5216 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
5219
5222
5223 ValueRange multiIndex = linearizeOp.getMultiIndex();
5224 unsigned numLinArgs = multiIndex.size();
5225 unsigned linArgIdx = 0;
5226
5227
// Phase 1: scan the linearize inputs left to right and collect runs.
5229 while (linArgIdx < numLinArgs) {
// Skip inputs for which the cast fails (target type not visible; presumably
// an op result, given the getOwner()/getResultNumber() calls below).
5230 auto asResult = dyn_cast(multiIndex[linArgIdx]);
5231 if (!asResult) {
5232 linArgIdx++;
5233 continue;
5234 }
5235
// Skip inputs whose defining op is not of the expected kind.
5236 auto delinearizeOp =
5237 dyn_cast(asResult.getOwner());
5238 if (!delinearizeOp) {
5239 linArgIdx++;
5240 continue;
5241 }
5242
5243
5244
5245
5246
5247
5248
5249
5250
5251
5252
5253
5254
// A run may start at this input when one of these holds:
//  - the first bounds on both sides are equal,
//  - both positions are index 0 of their respective ops, or
//  - the linearize op is disjoint, the run starts at result 0 of the
//    defining op, and that op's first bound is null.
5255 unsigned delinArgIdx = asResult.getResultNumber();
5257 OpFoldResult firstDelinBound = delinBasis[delinArgIdx];
5258 OpFoldResult firstLinBound = linBasis[linArgIdx];
5259 bool boundsMatch = firstDelinBound == firstLinBound;
5260 bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
5261 bool knownByDisjoint =
5262 linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
5263 if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
5264 linArgIdx++;
5265 continue;
5266 }
5267
// Extend the run while the next linearize input is the next result of the
// same defining op and the corresponding basis elements are equal.
5268 unsigned j = 1;
5269 unsigned numDelinOuts = delinearizeOp.getNumResults();
5270 for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
5271 ++j) {
5272 if (multiIndex[linArgIdx + j] !=
5273 delinearizeOp.getResult(delinArgIdx + j))
5274 break;
5275 if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
5276 break;
5277 }
5278
5279
5280
// Ignore runs of length 1, and defining ops that an earlier run already
// claimed (the insert reports whether the op was newly added).
5281 if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {
5282 linArgIdx++;
5283 continue;
5284 }
// Record the run and continue scanning after it.
5285 matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j});
5286 linArgIdx += j;
5287 }
5288
// No run found. The message below is presumably passed to a match-failure
// notification whose call line is not visible here.
5289 if (matches.empty())
5291 linearizeOp, "no run of delinearize outputs to deal with");
5292
5293
5294
5295
5297
// Phase 2: build the new index and basis lists, one match at a time.
5299 newIndex.reserve(numLinArgs);
5301 newBasis.reserve(numLinArgs);
5302 unsigned prevMatchEnd = 0;
5303 for (Match m : matches) {
// Copy through the inputs and basis elements lying between the previous
// match and this one.
5304 unsigned gap = m.linStart - prevMatchEnd;
5305 llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
5306 llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
5307
5308 prevMatchEnd = m.linStart + m.length;
5309
// The guard restores the rewriter's insertion point when `g` goes out of
// scope (the statement that moves the insertion point is not visible here).
5310 PatternRewriter::InsertionGuard g(rewriter);
5312
// Basis elements covered by this run, and their product as computed by
// `computeProduct` (left-hand sides of both statements are not visible).
5314 linBasisRef.slice(m.linStart, m.length);
5315
5316
5318 computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);
5319
5320
// The run covers every result of the defining op: feed that op's linear
// index straight into the new linearize with the merged size.
5321 if (m.length == m.delinearize.getNumResults()) {
5322 newIndex.push_back(m.delinearize.getLinearIndex());
5323 newBasis.push_back(newSize);
5324
5326 continue;
5327 }
5328
// Partial run: in the defining op's basis, replace the run's entries by the
// single merged size and create a new op over the same linear index.
5331 newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
5332 newDelinBasis.begin() + m.delinStart + m.length);
5333 newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
5334 auto newDelinearize = rewriter.create(
5335 m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
5336 newDelinBasis);
5337
5338
5339
5340
// Split the merged component again using the run's original basis, so the
// per-component values of the old op remain available as replacements.
5341 Value combinedElem = newDelinearize.getResult(m.delinStart);
5342 auto residualDelinearize = rewriter.create(
5343 m.delinearize.getLoc(), combinedElem, basisToMerge);
5344
5345
5346
5347
// Replacement values in the old op's result order: results before the run,
// then the re-split run components, then results after the merged element.
5348 llvm::append_range(newDelinResults,
5349 newDelinearize.getResults().take_front(m.delinStart));
5350 llvm::append_range(newDelinResults, residualDelinearize.getResults());
5351 llvm::append_range(
5352 newDelinResults,
5353 newDelinearize.getResults().drop_front(m.delinStart + 1));
5354
5355 delinearizeReplacements.push_back(newDelinResults);
5356 newIndex.push_back(combinedElem);
5357 newBasis.push_back(newSize);
5358 }
// Copy through whatever follows the last match, then rebuild the linearize
// op with the original disjoint flag (head of the call not visible).
5359 llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
5360 llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
5362 linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());
5363
// Replace the old defining ops for which replacement values were recorded.
// Matches with an empty list are skipped (presumably the fully covered ones,
// whose list is pushed on a line missing from this listing - confirm).
5364 for (auto [m, newResults] :
5365 llvm::zip_equal(matches, delinearizeReplacements)) {
5366 if (newResults.empty())
5367 continue;
5368 rewriter.replaceOp(m.delinearize, newResults);
5369 }
5370
5371 return success();
5372 }
5373 };
5374
5375
5376
5377
5378
/// Canonicalization pattern on `affine::AffineLinearizeIndexOp` that removes
/// the leading index of the op's multi-index. With a single index the op is
/// replaced by that index itself; otherwise the op is rebuilt without the
/// leading index, also dropping the leading basis element when the op has an
/// outer bound.
///
/// NOTE(review): the guard that precedes `return failure()` is on a line
/// missing from this listing; going by the pattern's name it presumably
/// requires the leading index to be a constant zero - confirm.
5379 struct DropLinearizeLeadingZero final
5382
5383 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5385 Value leadingIdx = op.getMultiIndex().front();
// Bail out when the (not visible) precondition on `leadingIdx` fails.
5387 return failure();
5388
// Only one index: the op's result is replaced by that index directly.
5389 if (op.getMultiIndex().size() == 1) {
5390 rewriter.replaceOp(op, leadingIdx);
5391 return success();
5392 }
5393
// With an outer bound, the first basis element is dropped as well.
// (`newMixedBasis` is declared on a line missing from this listing.)
5396 if (op.hasOuterBound())
5397 newMixedBasis = newMixedBasis.drop_front();
5398
// Rebuild the op from the remaining indices and basis, passing the original
// disjoint flag through (head of the call not visible).
5400 op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());
5401 return success();
5402 }
5403 };
5404 }
5405
/// Adds the canonicalization patterns of `AffineLinearizeIndexOp` to
/// `patterns`: CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero
/// and DropLinearizeUnitComponentsIfDisjointOrZero, each constructed with
/// `context`. (The parameter-list line is missing from this listing.)
5406 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
5408 patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
5409 DropLinearizeUnitComponentsIfDisjointOrZero>(context);
5410 }
5411
5412
5413
5414
5415
5416 #define GET_OP_CLASSES
5417 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
static Value getStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
static AffineForOp buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds known to be constants.
static bool hasTrivialZeroTripCount(AffineForOp op)
Returns true if the affine.for has zero iterations in trivial cases.
static void composeMultiResultAffineMap(AffineMap &map, SmallVectorImpl< Value > &operands)
Composes the given affine map with the given list of operands, pulling in the maps from any affine....
static LogicalResult verifyMemoryOpIndexing(AffineMemOpTy op, AffineMapAttr mapAttr, Operation::operand_range mapOperands, MemRefType memrefType, unsigned numIndexOperands)
Verify common indexing invariants of affine.load, affine.store, affine.vector_load and affine....
static void printAffineMinMaxOp(OpAsmPrinter &p, T op)
static bool isResultTypeMatchAtomicRMWKind(Type resultType, arith::AtomicRMWKind op)
static bool remainsLegalAfterInline(Value value, Region *src, Region *dest, const IRMapping &mapping, function_ref< bool(Value, Region *)> legalityCheck)
Checks if value known to be a legal affine dimension or symbol in src region remains legal if the ope...
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, DenseIntElementsAttr group, ValueRange operands, StringRef keyword)
Prints a lower(upper) bound of an affine parallel loop with max(min) conditions in it.
static void LLVM_ATTRIBUTE_UNUSED simplifyMapWithOperands(AffineMap &map, ArrayRef< Value > operands)
Simplify the map while exploiting information on the values in operands.
static OpFoldResult foldMinMaxOp(T op, ArrayRef< Attribute > operands)
Fold an affine min or max operation with the given operands.
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp)
Canonicalize the bounds of the given loop.
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims, unsigned numSymbols, ArrayRef< Value > operands)
Simplify expr while exploiting information from the values in operands.
static bool isValidAffineIndexOperand(Value value, Region *region)
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
static void composeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Iterate over operands and fold away all those produced by an AffineApplyOp iteratively.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static ParseResult parseBound(bool isLower, OperationState &result, OpAsmParser &p)
Parse a for operation loop bounds.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static void composeSetAndOperands(IntegerSet &set, SmallVectorImpl< Value > &operands)
Compose any affine.apply ops feeding into operands of the integer set set by composing the maps of su...
static LogicalResult replaceDimOrSym(AffineMap *map, unsigned dimOrSymbolPosition, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms)
Replace all occurrences of AffineExpr at position pos in map by the defining AffineApplyOp expression...
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType)
Verify common invariants of affine.vector_load and affine.vector_store.
static void simplifyMinOrMaxExprWithOperands(AffineMap &map, ArrayRef< Value > operands, bool isMax)
Simplify the expressions in map while making use of lower or upper bounds of its operands.
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, OperationState &result)
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, Region *region)
Returns true if the 'index' dimension of the memref defined by memrefDefOp is a statically shaped one...
static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef< Value > operands, int64_t k)
Check if e is known to be: 0 <= e < k.
static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser, OperationState &result, MinMaxKind kind)
Parses an affine map that can contain a min/max for groups of its results, e.g., max(expr-1,...
static AffineForOp buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds that may or may not be constants.
static void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &printer)
Prints dimension and symbol list.
static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef< Value > operands)
Returns the largest known divisor of e.
static void legalizeDemotedDims(MapOrSet &mapOrSet, SmallVectorImpl< Value > &operands)
A valid affine dimension may appear as a symbol in affine.apply operations.
static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
static void buildAffineLoopNestImpl(OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, LoopCreatorTy &&loopCreatorFn)
Builds an affine loop nest, using "loopCreatorFn" to create individual loop operations.
static LogicalResult foldLoopBounds(AffineForOp forOp)
Fold the constant bounds of a loop.
static LogicalResult verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, unsigned numDims)
Utility function to verify that a set of operands are valid dimension and symbol identifiers.
static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region)
Returns true if the result of the dim op is a valid symbol for region.
static bool isQTimesDPlusR(AffineExpr e, ArrayRef< Value > operands, int64_t &div, AffineExpr &quotientTimesDiv, AffineExpr &rem)
Check if expression e is of the form d*e_1 + e_2 where 0 <= e_2 < d.
static ParseResult deduplicateAndResolveOperands(OpAsmParser &parser, ArrayRef< SmallVector< OpAsmParser::UnresolvedOperand >> operands, SmallVectorImpl< Value > &uniqueOperands, SmallVectorImpl< AffineExpr > &replacements, AffineExprKind kind)
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult verifyAffineMinMaxOp(T op)
static void printBound(AffineMapAttr boundMap, Operation::operand_range boundOperands, const char *prefix, OpAsmPrinter &p)
static std::optional< SmallVector< int64_t > > foldCstValueToCstAttrBasis(ArrayRef< OpFoldResult > mixedBasis, MutableOperandRange mutableDynamicBasis, ArrayRef< Attribute > dynamicBasis)
Given mixed basis of affine.delinearize_index/linearize_index replace constant SSA values with the co...
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map)
Canonicalize the result expression order of an affine map and return success if the order changed.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
union mlir::linalg::@1203::ArityGroupAndKind::Kind kind
static Operation::operand_range getLowerBoundOperands(AffineForOp forOp)
static Operation::operand_range getUpperBoundOperands(AffineForOp forOp)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
AffineExprKind getKind() const
Return the classification for this type.
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
MLIRContext * getContext() const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
MLIRContext * getContext() const
bool isFunctionOfDim(unsigned position) const
Return true if any affine expression involves AffineDimExpr position.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
bool isFunctionOfSymbol(unsigned position) const
Return true if any affine expression involves AffineSymbolExpr position.
unsigned getNumResults() const
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
unsigned getNumInputs() const
AffineMap shiftSymbols(unsigned shift, unsigned offset=0) const
Replace symbols[offset ...
AffineExpr getResult(unsigned idx) const
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
static AffineMap getConstantMap(int64_t val, MLIRContext *context)
Returns a single constant result affine map.
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
LogicalResult constantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< Attribute > &results, bool *hasPoison=nullptr) const
Folds the results of the application of an affine map on the provided operands to a constant if possi...
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
AffineMap getDimIdentityMap()
AffineMap getMultiDimIdentityMap(unsigned rank)
AffineExpr getAffineSymbolExpr(unsigned position)
AffineExpr getAffineConstantExpr(int64_t constant)
DenseIntElementsAttr getI32TensorAttr(ArrayRef< int32_t > values)
Tensor-typed DenseIntElementsAttr getters.
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
AffineMap getEmptyAffineMap()
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
MLIRContext * getContext() const
AffineMap getSymbolIdentityMap()
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
An attribute that represents a reference to a dense integer vector or tensor object.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
An integer set representing a conjunction of one or more affine equalities and inequalities.
unsigned getNumDims() const
static IntegerSet get(unsigned dimCount, unsigned symbolCount, ArrayRef< AffineExpr > constraints, ArrayRef< bool > eqFlags)
MLIRContext * getContext() const
unsigned getNumInputs() const
ArrayRef< AffineExpr > getConstraints() const
ArrayRef< bool > getEqFlags() const
Returns the equality bits, which specify whether each of the constraints is an equality or inequality...
unsigned getNumSymbols() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void pop_back()
Pop last element from list.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseAffineMapOfSSAIds(SmallVectorImpl< UnresolvedOperand > &operands, Attribute &map, StringRef attrName, NamedAttrList &attrs, Delimiter delimiter=Delimiter::Square)=0
Parses an affine map attribute where dims and symbols are SSA operands.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseAffineExprOfSSAIds(SmallVectorImpl< UnresolvedOperand > &dimOperands, SmallVectorImpl< UnresolvedOperand > &symbOperands, AffineExpr &expr)=0
Parses an affine expression where dims and symbols are SSA operands.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands)=0
Prints an affine expression of SSA ids with SSA id names used instead of dims and symbols.
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands)=0
Prints an affine map of SSA ids, where SSA id names are used in place of dims/symbols.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
A trait of region holding operations that defines a new scope for polyhedral optimization purposes.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
operand_range::iterator operand_iterator
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
std::vector< SmallVector< int64_t, 8 > > operandExprStack
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
AffineBound represents a lower or upper bound in the for operation.
AffineDmaStartOp starts a non-blocking DMA operation that transfers data from a source memref to a de...
AffineDmaWaitOp blocks until the completion of a DMA operation associated with the tag element 'tag[i...
An AffineValueMap is an affine map plus its ML value operands and results for analysis purposes.
LogicalResult canonicalize()
Attempts to canonicalize the map and operands.
ArrayRef< Value > getOperands() const
AffineExpr getResult(unsigned i)
AffineMap getAffineMap() const
unsigned getNumResults() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only ano...
void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Given an affine map map and its input operands, this method composes into map, maps of AffineApplyOps...
void extractForInductionVars(ArrayRef< AffineForOp > forInsts, SmallVectorImpl< Value > *ivs)
Extracts the induction variables from a list of AffineForOps and places them in the output argument i...
bool isValidDim(Value value)
Returns true if the given Value can be used as a dimension id in the region of the closest surroundin...
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
bool isAffineInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp or AffineParallelOp.
AffineForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
bool isAffineForInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp.
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
bool isTopLevelValue(Value value)
A utility function to check if a value is defined at the top level of an op with trait AffineScope or...
Region * getAffineAnalysisScope(Operation *op)
Returns the closest region enclosing op that is held by a non-affine operation; nullptr if there is n...
void canonicalizeSetAndOperands(IntegerSet *set, SmallVectorImpl< Value > *operands)
Canonicalizes an integer set the same way canonicalizeMapAndOperands does for affine maps.
void extractInductionVars(ArrayRef< Operation * > affineOps, SmallVectorImpl< Value > &ivs)
Extracts the induction variables from a list of either AffineForOp or AffineParallelOp and places the...
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
AffineParallelOp getAffineParallelInductionVarOwner(Value val)
Returns the AffineParallelOp whose induction variables include the provided value.
Region * getAffineScope(Operation *op)
Returns the closest region enclosing op that is held by an operation with trait AffineScope; nullptr ...
ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< Value > &operands, unsigned &numDims)
Parses dimension and symbol list.
bool isAffineParallelInductionVar(Value val)
Returns true if val is the induction variable of an AffineParallelOp.
AffineMinOp makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns an AffineMinOp obtained by composing map and operands with AffineApplyOps supplying those ope...
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
AffineMap removeDuplicateExprs(AffineMap map)
Returns a map with the same dimension and symbol count as map, but whose results are the unique affin...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::optional< int64_t > getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t >> constLowerBounds, ArrayRef< std::optional< int64_t >> constUpperBounds, bool isUpper)
Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable and does not touch memory.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explanatory.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ SymbolId
Symbolic identifier.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
AffineMap foldAttributesIntoMap(Builder &b, AffineMap map, ArrayRef< OpFoldResult > operands, SmallVector< Value > &remainingValues)
Fold all attributes among the given operands into the affine map.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Canonicalize the affine map result expression order of an affine min/max operation.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Remove duplicated expressions in affine min/max ops.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Merge an affine min/max op to its consumers if its consumer is also an affine min/max op.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties object of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.