MLIR: lib/Dialect/Arith/IR/ArithOps.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9 #include
10 #include
11 #include
12 #include
13
25
26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/APInt.h"
28 #include "llvm/ADT/APSInt.h"
29 #include "llvm/ADT/FloatingPointMode.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallString.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/TypeSwitch.h"
34
35 using namespace mlir;
37
38
39
40
41
// Applies `binFn` to the APInt payloads of the integer attributes `lhs` and
// `rhs` and returns an IntegerAttr built from the computed `value`.
// NOTE(review): source lines 43-44 (rest of the signature) and 49 (the
// statement that builds the returned attribute) are missing from this listing.
42 static IntegerAttr
45 function_ref<APInt(const APInt &, const APInt &)> binFn) {
46 APInt lhsVal = llvm::cast(lhs).getValue();
47 APInt rhsVal = llvm::cast(rhs).getValue();
48 APInt value = binFn(lhsVal, rhsVal);
50 }
51
// NOTE(review): lines 52-54 and 57-59 are missing; only the closing braces of
// two small helpers survive here (presumably add/sub counterparts of the
// multiply helper below -- confirm against the original source).
55 }
56
60 }
61
// Multiplies two integer attributes by forwarding std::multiplies to
// applyToIntegerAttrs. NOTE(review): signature lines 62-63 are missing.
64 return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies());
65 }
66
67
// Returns an IntegerOverflowFlagsAttr whose flags are the bitwise AND of the
// flags of `val1` and `val2`, i.e. only flags present on both inputs survive.
// NOTE(review): lines 69 (first parameter) and 71 (attribute construction) are
// missing from this listing.
68 static IntegerOverflowFlagsAttr
70 IntegerOverflowFlagsAttr val2) {
72 val1.getValue() & val2.getValue());
73 }
74
75
// Returns the logical negation of a cmpi predicate: eq <-> ne, slt <-> sge,
// sle <-> sgt, ult <-> uge, ule <-> ugt.
// NOTE(review): the signature (line 76) is missing from this listing.
77 switch (pred) {
78 case arith::CmpIPredicate::eq:
79 return arith::CmpIPredicate::ne;
80 case arith::CmpIPredicate::ne:
81 return arith::CmpIPredicate::eq;
82 case arith::CmpIPredicate::slt:
83 return arith::CmpIPredicate::sge;
84 case arith::CmpIPredicate::sle:
85 return arith::CmpIPredicate::sgt;
86 case arith::CmpIPredicate::sgt:
87 return arith::CmpIPredicate::sle;
88 case arith::CmpIPredicate::sge:
89 return arith::CmpIPredicate::slt;
90 case arith::CmpIPredicate::ult:
91 return arith::CmpIPredicate::uge;
92 case arith::CmpIPredicate::ule:
93 return arith::CmpIPredicate::ugt;
94 case arith::CmpIPredicate::ugt:
95 return arith::CmpIPredicate::ule;
96 case arith::CmpIPredicate::uge:
97 return arith::CmpIPredicate::ult;
98 }
// Every enumerator returns above; reaching this point indicates a bug.
99 llvm_unreachable("unknown cmpi predicate kind");
100 }
101
102
103
104
105
106
107
108 static llvm::RoundingMode
110 switch (roundingMode) {
111 case RoundingMode::downward:
112 return llvm::RoundingMode::TowardNegative;
113 case RoundingMode::to_nearest_away:
114 return llvm::RoundingMode::NearestTiesToAway;
115 case RoundingMode::to_nearest_even:
116 return llvm::RoundingMode::NearestTiesToEven;
117 case RoundingMode::toward_zero:
118 return llvm::RoundingMode::TowardZero;
119 case RoundingMode:📈
120 return llvm::RoundingMode::TowardPositive;
121 }
122 llvm_unreachable("Unhandled rounding mode");
123 }
124
// Returns a CmpIPredicateAttr for the negation of `pred`'s predicate.
// NOTE(review): the body (lines 126-127) is missing from this listing; the
// description above is presumed from the signature -- confirm.
125 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
128 }
129
// NOTE(review): fragment of a helper whose signature and leading body (lines
// 130-133) are missing; only its trailing `return -1;` survives.
134
135 return -1;
136 }
137
// NOTE(review): fragment; only a closing brace (line 140) survives.
140 }
141
// Fragment that returns the local APInt `value` on one path and failure() on
// the other. NOTE(review): the signature (line 142) and line 144 (presumably
// the match that fills `value` and guards the first return) are missing.
143 APInt value;
145 return value;
146
147 return failure();
148 }
149
// Fragment: when `type` is null or not a ShapedType, `boolAttr` is returned
// unchanged. NOTE(review): the signature (lines 150-151) and the shaped-type
// result (line 155) are missing from this listing.
152 ShapedType shapedType = llvm::dyn_cast_or_null(type);
153 if (!shapedType)
154 return boolAttr;
156 }
157
158
159
160
161
// Includes the generated ArithCanonicalization.inc inside an anonymous
// namespace so its definitions have internal linkage in this translation unit.
162 namespace {
163 #include "ArithCanonicalization.inc"
164 }
165
166
167
168
169
170
// Returns an i1-element type shaped like `type`: a shaped type is cloned with
// element type `i1Type`, keeping its shape (std::nullopt shape argument).
// NOTE(review): lines 171-172 (signature, `i1Type` definition) and 176 (the
// body of the `isa` branch below) are missing, so as listed the final
// `return i1Type;` is really the fall-through return, not that branch's body.
173 if (auto shapedType = llvm::dyn_cast(type))
174 return shapedType.cloneWith(std::nullopt, i1Type);
175 if (llvm::isa(type))
177 return i1Type;
178 }
179
180
181
182
183
// Suggests an SSA name for the constant's result:
//  - integer constants of width-1 integer type: "true" / "false";
//  - other integer constants: "c<value>", suffixed with "_<type>" when
//    `intType` is set (template arguments stripped; presumably IntegerAttr /
//    IntegerType -- confirm);
//  - anything else: "cst".
// NOTE(review): lines 185-186 (rest of the signature, `type` local) and 195
// (the `specialNameBuffer` declaration) are missing from this listing.
184 void arith::ConstantOp::getAsmResultNames(
187 if (auto intCst = llvm::dyn_cast(getValue())) {
188 auto intType = llvm::dyn_cast(type);
189
190
// Boolean-width constants get readable names instead of c0/c1.
191 if (intType && intType.getWidth() == 1)
192 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
193
194
196 llvm::raw_svector_ostream specialName(specialNameBuffer);
197 specialName << 'c' << intCst.getValue();
198 if (intType)
199 specialName << '_' << type;
200 setNameFn(getResult(), specialName.str());
201 } else {
202 setNameFn(getResult(), "cst");
203 }
204 }
205
206
207
// ConstantOp verification (presumably arith::ConstantOp::verify()).
// NOTE(review): the signature and `type` local (lines 208-209) are missing.
// Visible checks, in order:
//  - an integer result type must be signless (isa/cast template argument
//    stripped; presumably IntegerType);
//  - the value must be an IntegerAttr, FloatAttr or ElementsAttr;
//  - per the error text, scalable vectors may only be initialized by a splat
//    elements attribute (isa template arguments stripped on line 223).
// NOTE(review): "intializing" in the error text below is misspelled; left
// unchanged because it is a runtime string that tests may match.
210
211 if (llvm::isa(type) &&
212 !llvm::cast(type).isSignless())
213 return emitOpError("integer return type must be signless");
214
215 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
216 return emitOpError(
217 "value must be an integer, float, or elements attribute");
218 }
219
220
221
222
223 if (isa(type) && !isa(getValue()))
224 return emitOpError(
225 "intializing scalable vectors with elements attribute is not supported"
226 " unless it's a vector splat");
227 return success();
228 }
229
// Returns true when an arith.constant can be built from `value` with result
// type `type`: `value` must be a typed attribute whose type equals `type`
// (dyn_cast template stripped; presumably TypedAttr), an integer `type` must
// be signless, and `value` must be an IntegerAttr, FloatAttr or ElementsAttr.
230 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
231
232 auto typedAttr = llvm::dyn_cast(value);
233 if (!typedAttr || typedAttr.getType() != type)
234 return false;
235
236 if (llvm::isa(type) &&
237 !llvm::cast(type).isSignless())
238 return false;
239
240 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
241 }
242
// Creates an arith.constant for `value` when isBuildableWith(value, type)
// accepts it; otherwise returns nullptr. NOTE(review): line 244 (remaining
// parameters, including `type` and `loc`) is missing from this listing.
243 ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
245 if (isBuildableWith(value, type))
246 return builder.createarith::ConstantOp(loc, cast(value));
247 return nullptr;
248 }
249
// A constant folds to its own value attribute.
250 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
251
// Builder from an int64_t value and a bit width.
// NOTE(review): lines 252, 254 and 256 are missing from this listing.
253 int64_t value, unsigned width) {
255 arith::ConstantOp::build(builder, result, type,
257 }
258
// Builder from an int64_t value and an explicit `type`; the surviving message
// shows the type is checked to be a signless integer.
// NOTE(review): lines 259, 261 and 264 are missing from this listing.
260 int64_t value, Type type) {
262 "ConstantIntOp can only have signless integer type values");
263 arith::ConstantOp::build(builder, result, type,
265 }
266
// Returns true if `op` is an arith.constant of signless integer type
// (signature line 267 missing).
268 if (auto constOp = dyn_cast_or_nullarith::ConstantOp(op))
269 return constOp.getType().isSignlessInteger();
270 return false;
271 }
272
// Builder from an APFloat value and a FloatType (lines 273, 276 missing).
274 const APFloat &value, FloatType type) {
275 arith::ConstantOp::build(builder, result, type,
277 }
278
// Returns true if `op` is an arith.constant whose type passes `llvm::isa`
// (template argument stripped; presumably FloatType -- confirm).
280 if (auto constOp = dyn_cast_or_nullarith::ConstantOp(op))
281 return llvm::isa(constOp.getType());
282 return false;
283 }
284
// Builder for an index-typed constant from an int64_t value (lines 285, 288
// missing).
286 int64_t value) {
287 arith::ConstantOp::build(builder, result, builder.getIndexType(),
289 }
290
// Returns true if `op` is an arith.constant of index type.
292 if (auto constOp = dyn_cast_or_nullarith::ConstantOp(op))
293 return constOp.getType().isIndex();
294 return false;
295 }
296
297
298
299
300
// Folds arith.addi:
//  - a condition on line 303 (missing) returns lhs (presumably addi(x, 0) -> x);
//  - addi(sub(a, b), b) -> a;
//  - addi(b, sub(a, b)) -> a;
//  - otherwise constant-folds with APInt addition.
// NOTE(review): the getDefiningOp template argument is stripped; the `sub`
// variable presumably binds an arith::SubIOp -- confirm.
301 OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
302
304 return getLhs();
305
306
307 if (auto sub = getLhs().getDefiningOp())
308 if (getRhs() == sub.getRhs())
309 return sub.getLhs();
310
311
312 if (auto sub = getRhs().getDefiningOp())
313 if (getLhs() == sub.getRhs())
314 return sub.getLhs();
315
316 return constFoldBinaryOp(
317 adaptor.getOperands(),
318 [](APInt a, const APInt &b) { return std::move(a) + b; });
319 }
320
// Registers the AddI canonicalization patterns listed below
// (signature lines 321-322 missing).
323 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
324 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
325 }
326
327
328
329
330
// Unroll shape: the shape of result 0 when dyn_cast succeeds (template argument
// stripped; presumably VectorType), otherwise std::nullopt.
331 std::optional<SmallVector<int64_t, 4>>
332 arith::AddUIExtendedOp::getShapeForUnroll() {
333 if (auto vt = llvm::dyn_cast(getType(0)))
334 return llvm::to_vector<4>(vt.getShape());
335 return std::nullopt;
336 }
337
338
339
// Unsigned-add overflow bit as an i1 APInt: all-ones (1) when `sum` is
// unsigned-less-than `operand` (the addition wrapped), zero otherwise.
// NOTE(review): signature line 340 is missing from this listing.
341 return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
342 }
343
// Folds arith.addui_extended into (sum, overflow) results.
//  - Under a condition on lines 349-350 (missing; presumably rhs == 0) the
//    results are lhs and a zero (false) attribute of the overflow type.
//  - With constant operands, sum = lhs + rhs; the overflow attribute is folded
//    from (sum, lhs) by a lambda on lines 367-368 (missing; presumably the
//    unsigned-overflow helper above). If it cannot be folded, the fold fails.
// NOTE(review): line 346 (the `results` parameter) is also missing.
344 LogicalResult
345 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
347 Type overflowTy = getOverflow().getType();
348
351 auto falseValue = builder.getZeroAttr(overflowTy);
352
353 results.push_back(getLhs());
354 results.push_back(falseValue);
355 return success();
356 }
357
358
359
360
361
362 if (Attribute sumAttr = constFoldBinaryOp(
363 adaptor.getOperands(),
364 [](APInt a, const APInt &b) { return std::move(a) + b; })) {
365 Attribute overflowAttr = constFoldBinaryOp(
366 ArrayRef({sumAttr, adaptor.getLhs()}),
369 if (!overflowAttr)
370 return failure();
371
372 results.push_back(sumAttr);
373 results.push_back(overflowAttr);
374 return success();
375 }
376
377 return failure();
378 }
379
// Registers the canonicalization patterns for addui_extended (the pattern list
// and line 381 are missing from this listing).
380 void arith::AddUIExtendedOp::getCanonicalizationPatterns(
382 patterns.add(context);
383 }
384
385
386
387
388
// Folds arith.subi:
//  - x - x, when the type is not shaped or has a static shape, returns the
//    value on line 395 (missing; presumably a zero attribute);
//  - a condition on line 398 (missing) returns lhs (presumably x - 0 -> x);
//  - add(a, b) - b -> a and add(a, b) - a -> b;
//  - otherwise constant-folds with APInt subtraction.
// NOTE(review): because lines 395 and 398 are missing, the listed `396 }` and
// `399 return getLhs();` look like an empty body / unconditional return; they
// are not. The getDefiningOp template argument is stripped (the `add`
// variable presumably binds an arith::AddIOp).
389 OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
390
391 if (getOperand(0) == getOperand(1)) {
392 auto shapedType = dyn_cast(getType());
393
394 if (!shapedType || shapedType.hasStaticShape())
396 }
397
399 return getLhs();
400
401 if (auto add = getLhs().getDefiningOp()) {
402
403 if (getRhs() == add.getRhs())
404 return add.getLhs();
405
406 if (getRhs() == add.getLhs())
407 return add.getRhs();
408 }
409
410 return constFoldBinaryOp(
411 adaptor.getOperands(),
412 [](APInt a, const APInt &b) { return std::move(a) - b; });
413 }
414
// Registers the SubI canonicalization patterns listed below
// (signature lines 415-416 missing).
417 patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
418 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
419 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
420 }
421
422
423
424
425
// Folds arith.muli: conditions on lines 428 and 431 (missing) return rhs and
// lhs respectively (presumably muli(x, 0) -> 0 and muli(x, 1) -> x -- confirm);
// otherwise constant-folds with APInt multiplication.
426 OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
427
429 return getRhs();
430
432 return getLhs();
433
434
435
436 return constFoldBinaryOp(
437 adaptor.getOperands(),
438 [](const APInt &a, const APInt &b) { return a * b; });
439 }
440
// Names the result "c<N>_vscale" when it multiplies a constant N by a value
// defined by an op whose name is "vector.vscale" (operands in either order).
// Only applies when the result type passes `isa` (template argument stripped;
// presumably IndexType). NOTE(review): lines 442 (rest of the signature), 454
// (first half of isVscaleExpr, presumably matching `a` as a constant into
// `baseValue`) and 462 (`specialNameBuffer` declaration) are missing.
441 void arith::MulIOp::getAsmResultNames(
443 if (!isa(getType()))
444 return;
445
446
447
// Compares by op name string rather than by op class.
448 auto isVscale = [](Operation *op) {
449 return op && op->getName().getStringRef() == "vector.vscale";
450 };
451
452 IntegerAttr baseValue;
453 auto isVscaleExpr = [&](Value a, Value b) {
455 isVscale(b.getDefiningOp());
456 };
457
458 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
459 return;
460
461
463 llvm::raw_svector_ostream specialName(specialNameBuffer);
464 specialName << 'c' << baseValue.getInt() << "_vscale";
465 setNameFn(getResult(), specialName.str());
466 }
467
// Registers MulI canonicalization patterns (pattern list and signature lines
// 468-469 are missing from this listing).
470 patterns.add(context);
471 }
472
473
474
475
476
// Unroll shape: the shape of result 0 when dyn_cast succeeds (template argument
// stripped; presumably VectorType), otherwise std::nullopt.
477 std::optional<SmallVector<int64_t, 4>>
478 arith::MulSIExtendedOp::getShapeForUnroll() {
479 if (auto vt = llvm::dyn_cast(getType(0)))
480 return llvm::to_vector<4>(vt.getShape());
481 return std::nullopt;
482 }
483
// Folds arith.mulsi_extended into (low, high) results.
//  - Under a condition on line 488 (missing; presumably rhs is constant zero),
//    the rhs attribute is returned as both the low and the high result.
//  - With constant operands: low = a * b, high = APIntOps::mulhs(a, b) (the
//    signed high half). The high fold is asserted to succeed whenever the low
//    fold did.
// NOTE(review): line 486 (the `results` parameter) is missing.
484 LogicalResult
485 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
487
489 Attribute zero = adaptor.getRhs();
490 results.push_back(zero);
491 results.push_back(zero);
492 return success();
493 }
494
495
496 if (Attribute lowAttr = constFoldBinaryOp(
497 adaptor.getOperands(),
498 [](const APInt &a, const APInt &b) { return a * b; })) {
499
500 Attribute highAttr = constFoldBinaryOp(
501 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
502 return llvm::APIntOps::mulhs(a, b);
503 });
504 assert(highAttr && "Unexpected constant-folding failure");
505
506 results.push_back(lowAttr);
507 results.push_back(highAttr);
508 return success();
509 }
510
511 return failure();
512 }
513
// Registers the MulSIExtended canonicalization patterns (line 515 missing).
514 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
516 patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
517 }
518
519
520
521
522
// Unroll shape: the shape of result 0 when dyn_cast succeeds (template argument
// stripped; presumably VectorType), otherwise std::nullopt.
523 std::optional<SmallVector<int64_t, 4>>
524 arith::MulUIExtendedOp::getShapeForUnroll() {
525 if (auto vt = llvm::dyn_cast(getType(0)))
526 return llvm::to_vector<4>(vt.getShape());
527 return std::nullopt;
528 }
529
// Folds arith.mului_extended into (low, high) results.
//  - Under a condition on line 534 (missing; presumably rhs is constant zero),
//    the rhs attribute is returned as both results.
//  - Lines 542-544 (missing) guard a path returning (lhs, zero) (presumably
//    rhs == 1, with `zero` built there -- confirm).
//  - With constant operands: low = a * b, high = APIntOps::mulhu(a, b) (the
//    unsigned high half).
// NOTE(review): line 532 (the `results` parameter) is missing.
530 LogicalResult
531 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
533
535 Attribute zero = adaptor.getRhs();
536 results.push_back(zero);
537 results.push_back(zero);
538 return success();
539 }
540
541
545 results.push_back(getLhs());
546 results.push_back(zero);
547 return success();
548 }
549
550
551 if (Attribute lowAttr = constFoldBinaryOp(
552 adaptor.getOperands(),
553 [](const APInt &a, const APInt &b) { return a * b; })) {
554
555 Attribute highAttr = constFoldBinaryOp(
556 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
557 return llvm::APIntOps::mulhu(a, b);
558 });
559 assert(highAttr && "Unexpected constant-folding failure");
560
561 results.push_back(lowAttr);
562 results.push_back(highAttr);
563 return success();
564 }
565
566 return failure();
567 }
568
// Registers MulUIExtended canonicalization patterns (pattern list and line 570
// are missing from this listing).
569 void arith::MulUIExtendedOp::getCanonicalizationPatterns(
571 patterns.add(context);
572 }
573
574
575
576
577
578
// If `lhs` is defined by an arith.muli carrying all of `ovfFlags`, returns the
// mul operand that is not `rhs`, i.e. (a * b) / b -> a. Otherwise returns a
// null Value. NOTE(review): the signature (line 579) is missing; this is
// presumably the foldDivMul helper called by the div folds below.
580 arith::IntegerOverflowFlags ovfFlags) {
581 auto mul = lhs.getDefiningOpmlir::arith::MulIOp();
582 if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
583 return {};
584
585 if (mul.getLhs() == rhs)
586 return mul.getRhs();
587
588 if (mul.getRhs() == rhs)
589 return mul.getLhs();
590
591 return {};
592 }
593
// Folds arith.divui:
//  - a condition on line 596 (missing) returns lhs (presumably x / 1 -> x);
//  - (a * b) / b -> a when the multiply has `nuw`;
//  - otherwise constant-folds with udiv. If any divisor element is zero, the
//    `div0` flag is set and stays set (later elements short-circuit), and no
//    fold is produced.
594 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
595
597 return getLhs();
598
599
600 if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
601 return val;
602
603
604 bool div0 = false;
605 auto result = constFoldBinaryOp(adaptor.getOperands(),
606 [&](APInt a, const APInt &b) {
607 if (div0 || !b) {
608 div0 = true;
609 return a;
610 }
611 return a.udiv(b);
612 });
613
614 return div0 ? Attribute() : result;
615 }
616
// NOTE(review): lines 618-627 are mostly missing; only blank residue and two
// closing braces of unidentified definitions survive here.
617
619
622
624 }
625
628 }
629
630
631
632
633
// Folds arith.divsi:
//  - a condition on line 636 (missing) returns lhs (presumably x / 1 -> x);
//  - (a * b) / b -> a when the multiply has `nsw`;
//  - otherwise constant-folds with APInt::sdiv_ov. No fold is produced if any
//    element divides by zero or sdiv_ov reports overflow. The flag is sticky
//    because the guard checks it before dividing, so later elements cannot
//    clear it.
634 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
635
637 return getLhs();
638
639
640 if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
641 return val;
642
643
644 bool overflowOrDiv0 = false;
645 auto result = constFoldBinaryOp(
646 adaptor.getOperands(), [&](APInt a, const APInt &b) {
647 if (overflowOrDiv0 || !b) {
648 overflowOrDiv0 = true;
649 return a;
650 }
651 return a.sdiv_ov(b, overflowOrDiv0);
652 });
653
654 return overflowOrDiv0 ? Attribute() : result;
655 }
656
// NOTE(review): lines 660-671 are mostly missing; only blank residue and two
// closing braces of unidentified definitions survive here.
657
658
659
661
662
666
668 }
669
672 }
673
674
675
676
677
// Computes ceil(a / b) as ((a - 1) / b) + 1 using overflow-reporting APInt
// operations. The identity holds for positive `a` and `b`; presumably that is
// the only case callers use it for -- confirm (the signature, line 678, and
// the call site are missing from this listing).
// NOTE(review): every *_ov call writes the same `overflow` reference; confirm
// whether later calls overwrite rather than accumulate an earlier overflow.
679 bool &overflow) {
680
681 APInt one(a.getBitWidth(), 1, true);
682 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
683 return val.sadd_ov(one, overflow);
684 }
685
686
687
688
689
// Folds arith.ceildivui:
//  - a condition on line 692 (missing) returns lhs (presumably x / 1 -> x);
//  - otherwise constant-folds as a udiv b, plus one when the remainder is
//    non-zero. No fold is produced on division by zero or if the +1 overflows
//    (uadd_ov). The flag is checked first, so it stays set once raised.
690 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
691
693 return getLhs();
694
695 bool overflowOrDiv0 = false;
696 auto result = constFoldBinaryOp(
697 adaptor.getOperands(), [&](APInt a, const APInt &b) {
698 if (overflowOrDiv0 || !b) {
699 overflowOrDiv0 = true;
700 return a;
701 }
702 APInt quotient = a.udiv(b);
703 if (!a.urem(b))
704 return quotient;
705 APInt one(a.getBitWidth(), 1, true);
706 return quotient.uadd_ov(one, overflowOrDiv0);
707 });
708
709 return overflowOrDiv0 ? Attribute() : result;
710 }
711
// NOTE(review): fragment; lines 712-713 are missing and only a closing brace
// survives.
714 }
715
716
717
718
719
// Folds arith.ceildivsi:
//  - a condition on line 722 (missing) returns lhs (presumably x / 1 -> x);
//  - otherwise constant-folds element-wise by sign case. No fold is produced
//    on division by zero or on any overflow reported by the negations /
//    divisions. The flag is checked first, so it stays set once raised.
//    a == 0                  -> 0.
//    a > 0 and b > 0         -> line 744 (missing; presumably the
//                               (a - 1) / b + 1 helper).
//    a < 0 and b < 0         -> ceil(-a / -b), computed on line 757 (missing).
//    a < 0 and b > 0         -> -((-a) / b).
//    a > 0 and b < 0         -> -(a / (-b)).
// NOTE(review): line 739 (definition of `zero`, presumably an APInt zero of
// width `bits`) is missing.
720 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
721
723 return getLhs();
724
725
726
727
728 bool overflowOrDiv0 = false;
729 auto result = constFoldBinaryOp(
730 adaptor.getOperands(), [&](APInt a, const APInt &b) {
731 if (overflowOrDiv0 || !b) {
732 overflowOrDiv0 = true;
733 return a;
734 }
735 if (!a)
736 return a;
737
738 unsigned bits = a.getBitWidth();
740 bool aGtZero = a.sgt(zero);
741 bool bGtZero = b.sgt(zero);
742 if (aGtZero && bGtZero) {
743
745 }
746
747
748
749 bool overflowNegA = false;
750 bool overflowNegB = false;
751 bool overflowDiv = false;
752 bool overflowNegRes = false;
753 if (!aGtZero && !bGtZero) {
754
755 APInt posA = zero.ssub_ov(a, overflowNegA);
756 APInt posB = zero.ssub_ov(b, overflowNegB);
758 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
759 return res;
760 }
761 if (!aGtZero && bGtZero) {
762
763 APInt posA = zero.ssub_ov(a, overflowNegA);
764 APInt div = posA.sdiv_ov(b, overflowDiv);
765 APInt res = zero.ssub_ov(div, overflowNegRes);
766 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
767 return res;
768 }
769
770 APInt posB = zero.ssub_ov(b, overflowNegB);
771 APInt div = a.sdiv_ov(posB, overflowDiv);
772 APInt res = zero.ssub_ov(div, overflowNegRes);
773
774 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
775 return res;
776 });
777
778 return overflowOrDiv0 ? Attribute() : result;
779 }
780
// NOTE(review): fragment; lines 781-782 are missing and only a closing brace
// survives.
783 }
784
785
786
787
788
789 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
790
792 return getLhs();
793
794
795 bool overflowOrDiv = false;
796 auto result = constFoldBinaryOp(
797 adaptor.getOperands(), [&](APInt a, const APInt &b) {
798 if (b.isZero()) {
799 overflowOrDiv = true;
800 return a;
801 }
802 return a.sfloordiv_ov(b, overflowOrDiv);
803 });
804
805 return overflowOrDiv ? Attribute() : result;
806 }
807
808
809
810
811
// Folds arith.remui. Lines 814-815 (missing) hold an early fold (presumably
// x % 1 -> 0 -- confirm). Otherwise constant-folds with urem; if any divisor
// element is zero, the sticky `div0` flag suppresses the fold.
812 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
813
816
817
818 bool div0 = false;
819 auto result = constFoldBinaryOp(adaptor.getOperands(),
820 [&](APInt a, const APInt &b) {
821 if (div0 || b.isZero()) {
822 div0 = true;
823 return a;
824 }
825 return a.urem(b);
826 });
827
828 return div0 ? Attribute() : result;
829 }
830
831
832
833
834
// Folds arith.remsi. Lines 837-838 (missing) hold an early fold (presumably
// x % 1 -> 0 -- confirm). Otherwise constant-folds with srem; if any divisor
// element is zero, the sticky `div0` flag suppresses the fold.
835 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
836
839
840
841 bool div0 = false;
842 auto result = constFoldBinaryOp(adaptor.getOperands(),
843 [&](APInt a, const APInt &b) {
844 if (div0 || b.isZero()) {
845 div0 = true;
846 return a;
847 }
848 return a.srem(b);
849 });
850
851 return div0 ? Attribute() : result;
852 }
853
854
855
856
857
858
// If one operand of `op` is an arith.andi whose operands already include the
// other operand of `op` (and(and(a, b), b), in either operand order), returns
// that inner andi's result; otherwise returns a null value.
// NOTE(review): the signature (line 859) is missing from this listing.
860 for (bool reversePrev : {false, true}) {
861 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
862 .getDefiningOparith::AndIOp();
863 if (!prev)
864 continue;
865
866 Value other = (reversePrev ? op.getLhs() : op.getRhs());
867 if (other != prev.getLhs() && other != prev.getRhs())
868 continue;
869
870 return prev.getResult();
871 }
872 return {};
873 }
874
// Folds arith.andi:
//  - a condition on line 877 (missing) returns rhs (presumably and(x, 0) -> 0);
//  - lhs is returned when `intValue` is all-ones (the match on line 881 is
//    missing; presumably and(x, -1) -> x);
//  - lines 885-893 contain two more all-ones checks whose match and result
//    lines are missing (presumably and(x, not(x)) patterns);
//  - line 896 (missing) computes a `result` returned if set (presumably the
//    and(and(a, b), b) helper above);
//  - otherwise constant-folds with bitwise AND.
875 OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
876
878 return getRhs();
879
880 APInt intValue;
882 intValue.isAllOnes())
883 return getLhs();
884
887 intValue.isAllOnes())
889
892 intValue.isAllOnes())
894
895
897 return result;
898
899 return constFoldBinaryOp(
900 adaptor.getOperands(),
901 [](APInt a, const APInt &b) { return std::move(a) & b; });
902 }
903
904
905
906
907
// Folds arith.ori:
//  - with a constant `rhsVal` (bound on line 909, missing): or(x, 0) -> x, and
//    or(x, all-ones) -> the all-ones rhs attribute;
//  - two checks whose match lines (920-921, 925-926) are missing return the
//    rhs operand of the op defining the other side when `intValue` is
//    all-ones (presumably or(x, xor(x, -1)) -> -1 in both operand orders;
//    getDefiningOp template argument stripped);
//  - otherwise constant-folds with bitwise OR.
908 OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
910
911 if (rhsVal.isZero())
912 return getLhs();
913
914 if (rhsVal.isAllOnes())
915 return adaptor.getRhs();
916 }
917
918 APInt intValue;
919
922 intValue.isAllOnes())
923 return getRhs().getDefiningOp().getRhs();
924
927 intValue.isAllOnes())
928 return getLhs().getDefiningOp().getRhs();
929
930 return constFoldBinaryOp(
931 adaptor.getOperands(),
932 [](APInt a, const APInt &b) { return std::move(a) | b; });
933 }
934
935
936
937
938
// Folds arith.xori:
//  - a condition on line 941 (missing) returns lhs (presumably xor(x, 0) -> x);
//  - xor(x, x) yields the value on line 945 (missing; presumably zero);
//  - xor(xor(a, b), b) -> a, with the inner xori on either side and its
//    operands in either order;
//  - otherwise constant-folds with bitwise XOR.
939 OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
940
942 return getLhs();
943
944 if (getLhs() == getRhs())
946
947
948 if (arith::XOrIOp prev = getLhs().getDefiningOparith::XOrIOp()) {
949 if (prev.getRhs() == getRhs())
950 return prev.getLhs();
951 if (prev.getLhs() == getRhs())
952 return prev.getRhs();
953 }
954
955
956 if (arith::XOrIOp prev = getRhs().getDefiningOparith::XOrIOp()) {
957 if (prev.getRhs() == getLhs())
958 return prev.getLhs();
959 if (prev.getLhs() == getLhs())
960 return prev.getRhs();
961 }
962
963 return constFoldBinaryOp(
964 adaptor.getOperands(),
965 [](APInt a, const APInt &b) { return std::move(a) ^ b; });
966 }
967
// Registers the XOrI canonicalization patterns (signature lines 968-969
// missing).
970 patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
971 }
972
973
974
975
976
// Folds arith.negf: negf(negf(x)) -> x; otherwise constant-folds with APFloat
// negation.
977 OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
978
979 if (auto op = this->getOperand().getDefiningOparith::NegFOp())
980 return op.getOperand();
981 return constFoldUnaryOp(adaptor.getOperands(),
982 [](const APFloat &a) { return -a; });
983 }
984
985
986
987
988
// Folds arith.addf: a condition on line 991 (missing) returns lhs (presumably
// an additive-identity rhs -- confirm); otherwise constant-folds with APFloat
// addition.
989 OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
990
992 return getLhs();
993
994 return constFoldBinaryOp(
995 adaptor.getOperands(),
996 [](const APFloat &a, const APFloat &b) { return a + b; });
997 }
998
999
1000
1001
1002
// Folds arith.subf: a condition on line 1005 (missing) returns lhs (presumably
// a zero rhs -- confirm); otherwise constant-folds with APFloat subtraction.
1003 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1004
1006 return getLhs();
1007
1008 return constFoldBinaryOp(
1009 adaptor.getOperands(),
1010 [](const APFloat &a, const APFloat &b) { return a - b; });
1011 }
1012
1013
1014
1015
1016
// Folds arith.maximumf: maximumf(x, x) -> x; a condition on line 1023
// (missing) returns lhs (presumably a -inf rhs -- confirm); otherwise
// constant-folds with llvm::maximum.
1017 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1018
1019 if (getLhs() == getRhs())
1020 return getRhs();
1021
1022
1024 return getLhs();
1025
1026 return constFoldBinaryOp(
1027 adaptor.getOperands(),
1028 [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
1029 }
1030
1031
1032
1033
1034
// Folds arith.maxnumf: maxnumf(x, x) -> x; a condition on line 1041 (missing)
// returns lhs (presumably a -inf rhs -- confirm); otherwise constant-folds
// with llvm::maxnum.
1035 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1036
1037 if (getLhs() == getRhs())
1038 return getRhs();
1039
1040
1042 return getLhs();
1043
1044 return constFoldBinaryOp(adaptor.getOperands(), llvm::maxnum);
1045 }
1046
1047
1048
1049
1050
// Folds arith.maxsi: maxsi(x, x) -> x. With a constant bound into `intValue`
// (the match on line 1057 is missing; presumably the rhs), a signed-max
// constant makes the result rhs and a signed-min constant makes it lhs.
// Otherwise constant-folds with APIntOps::smax.
1051 OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1052
1053 if (getLhs() == getRhs())
1054 return getRhs();
1055
1056 if (APInt intValue;
1058
1059 if (intValue.isMaxSignedValue())
1060 return getRhs();
1061
1062 if (intValue.isMinSignedValue())
1063 return getLhs();
1064 }
1065
1066 return constFoldBinaryOp(adaptor.getOperands(),
1067 [](const APInt &a, const APInt &b) {
1068 return llvm::APIntOps::smax(a, b);
1069 });
1070 }
1071
1072
1073
1074
1075
// Folds arith.maxui: maxui(x, x) -> x. With a constant bound into `intValue`
// (the match on line 1082 is missing; presumably the rhs), an unsigned-max
// constant makes the result rhs and zero makes it lhs. Otherwise constant-folds
// with APIntOps::umax.
1076 OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1077
1078 if (getLhs() == getRhs())
1079 return getRhs();
1080
1081 if (APInt intValue;
1083
1084 if (intValue.isMaxValue())
1085 return getRhs();
1086
1087 if (intValue.isMinValue())
1088 return getLhs();
1089 }
1090
1091 return constFoldBinaryOp(adaptor.getOperands(),
1092 [](const APInt &a, const APInt &b) {
1093 return llvm::APIntOps::umax(a, b);
1094 });
1095 }
1096
1097
1098
1099
1100
// Folds arith.minimumf: minimumf(x, x) -> x; a condition on line 1107
// (missing) returns lhs (presumably a +inf rhs -- confirm); otherwise
// constant-folds with llvm::minimum.
1101 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1102
1103 if (getLhs() == getRhs())
1104 return getRhs();
1105
1106
1108 return getLhs();
1109
1110 return constFoldBinaryOp(
1111 adaptor.getOperands(),
1112 [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
1113 }
1114
1115
1116
1117
1118
// Folds arith.minnumf: minnumf(x, x) -> x; a condition on line 1125 (missing)
// returns lhs (presumably a +inf rhs -- confirm); otherwise constant-folds
// with llvm::minnum.
1119 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1120
1121 if (getLhs() == getRhs())
1122 return getRhs();
1123
1124
1126 return getLhs();
1127
1128 return constFoldBinaryOp(
1129 adaptor.getOperands(),
1130 [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
1131 }
1132
1133
1134
1135
1136
// Folds arith.minsi: minsi(x, x) -> x. With a constant bound into `intValue`
// (the match on line 1143 is missing; presumably the rhs), a signed-min
// constant makes the result rhs and a signed-max constant makes it lhs.
// Otherwise constant-folds with APIntOps::smin.
1137 OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1138
1139 if (getLhs() == getRhs())
1140 return getRhs();
1141
1142 if (APInt intValue;
1144
1145 if (intValue.isMinSignedValue())
1146 return getRhs();
1147
1148 if (intValue.isMaxSignedValue())
1149 return getLhs();
1150 }
1151
1152 return constFoldBinaryOp(adaptor.getOperands(),
1153 [](const APInt &a, const APInt &b) {
1154 return llvm::APIntOps::smin(a, b);
1155 });
1156 }
1157
1158
1159
1160
1161
// Folds arith.minui: minui(x, x) -> x. With a constant bound into `intValue`
// (the match on line 1168 is missing; presumably the rhs), zero makes the
// result rhs and an unsigned-max constant makes it lhs. Otherwise
// constant-folds with APIntOps::umin.
1162 OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1163
1164 if (getLhs() == getRhs())
1165 return getRhs();
1166
1167 if (APInt intValue;
1169
1170 if (intValue.isMinValue())
1171 return getRhs();
1172
1173 if (intValue.isMaxValue())
1174 return getLhs();
1175 }
1176
1177 return constFoldBinaryOp(adaptor.getOperands(),
1178 [](const APInt &a, const APInt &b) {
1179 return llvm::APIntOps::umin(a, b);
1180 });
1181 }
1182
1183
1184
1185
1186
// Folds arith.mulf: a condition on line 1189 (missing) returns lhs (presumably
// a multiplicative-identity rhs -- confirm); otherwise constant-folds with
// APFloat multiplication.
1187 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1188
1190 return getLhs();
1191
1192 return constFoldBinaryOp(
1193 adaptor.getOperands(),
1194 [](const APFloat &a, const APFloat &b) { return a * b; });
1195 }
1196
// Registers MulF canonicalization patterns (pattern list and signature lines
// 1197-1198 missing).
1199 patterns.add(context);
1200 }
1201
1202
1203
1204
1205
// Folds arith.divf: a condition on line 1208 (missing) returns lhs (presumably
// a divide-by-one rhs -- confirm); otherwise constant-folds with APFloat
// division.
1206 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1207
1209 return getLhs();
1210
1211 return constFoldBinaryOp(
1212 adaptor.getOperands(),
1213 [](const APFloat &a, const APFloat &b) { return a / b; });
1214 }
1215
// Registers DivF canonicalization patterns (pattern list and signature lines
// 1216-1217 missing).
1218 patterns.add(context);
1219 }
1220
1221
1222
1223
1224
// Folds arith.remf by computing APFloat::mod on a copy of the lhs. The status
// returned by mod() is explicitly discarded via the (void) cast.
1225 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1226 return constFoldBinaryOp(adaptor.getOperands(),
1227 [](const APFloat &a, const APFloat &b) {
1228 APFloat result(a);
1229
1230
1231
1232 (void)result.mod(b);
1233 return result;
1234 });
1235 }
1236
1237
1238
1239
1240
// NOTE(review): fragment; the declaration that followed this template header
// (line 1242) is missing from this listing.
1241 template <typename... Types>
1243
1244
1245
1246
// Returns the underlying element type of `type` when it is allowed: a shaped
// `type` must be one of ShapedTypes (isa template argument stripped on line
// 1250; presumably ShapedType), and the underlying type must be one of
// ElementTypes. Otherwise returns a null type.
// NOTE(review): the signature (lines 1248-1249) and the computation of
// `underlyingType` (line 1253) are missing.
1247 template <typename... ShapedTypes, typename... ElementTypes>
1250 if (llvm::isa(type) && !llvm::isa<ShapedTypes...>(type))
1251 return {};
1252
1254 if (!llvm::isa<ElementTypes...>(underlyingType))
1255 return {};
1256
1257 return underlyingType;
1258 }
1259
1260
// NOTE(review): fragments of two more templates (presumably getTypeIfLike and
// getTypeIfLikeOrMemRef, which are used later in this file); their bodies
// (lines 1262-1264 and 1269-1272) are missing.
1261 template <typename... ElementTypes>
1265 }
1266
1267
1268 template <typename... ElementTypes>
1273 }
1274
1275
// Returns true unless both types are tensors with differing encodings (dyn_cast
// template argument stripped; presumably RankedTensorType -- the variable
// names suggest so). Any pair that is not two tensors is treated as matching.
// NOTE(review): the signature (line 1276) is missing.
1277 auto rankedTensorA = dyn_cast(typeA);
1278 auto rankedTensorB = dyn_cast(typeB);
1279 if (!rankedTensorA || !rankedTensorB)
1280 return true;
1281 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1282 }
1283
// Rejects cast signatures that do not have exactly one input and one output.
// NOTE(review): the signature (1284), the condition guarding the second
// `return false` (1287, presumably the encoding check above) and the final
// return (1289) are missing.
1285 if (inputs.size() != 1 || outputs.size() != 1)
1286 return false;
1288 return false;
1290 }
1291
1292
1293
1294
1295
1296
// Extension-op verifier (presumably verifyExtOp). Emits an error unless the
// result width is strictly greater than the operand width.
// NOTE(review): lines 1298-1300 (signature, srcType/dstType) are missing, and
// the cast template arguments (presumably ValType) are stripped.
1297 template <typename ValType, typename Op>
1301
1302 if (llvm::cast(srcType).getWidth() >=
1303 llvm::cast(dstType).getWidth())
1304 return op.emitError("result type ")
1305 << dstType << " must be wider than operand type " << srcType;
1306
1307 return success();
1308 }
1309
1310
// Truncation-op verifier (presumably verifyTruncateOp). Emits an error unless
// the result width is strictly smaller than the operand width.
// NOTE(review): lines 1312-1314 (signature, srcType/dstType) are missing.
1311 template <typename ValType, typename Op>
1315
1316 if (llvm::cast(srcType).getWidth() <=
1317 llvm::cast(dstType).getWidth())
1318 return op.emitError("result type ")
1319 << dstType << " must be shorter than operand type " << srcType;
1320
1321 return success();
1322 }
1323
1324
// Returns true when the cast's first input and first output are both
// ElementTypes-like and WidthComparator()(dstWidth, srcWidth) holds, e.g.
// std::greater for extensions and std::less for truncations.
// NOTE(review): the signature and the guard on lines 1326-1327 (presumably the
// single input/output check) are missing.
1325 template <template <typename> class WidthComparator, typename... ElementTypes>
1328 return false;
1329
1330 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1331 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1332 if (!srcType || !dstType)
1333 return false;
1334
1335 return WidthComparator()(dstType.getIntOrFloatBitWidth(),
1336 srcType.getIntOrFloatBitWidth());
1337 }
1338
1339
1340
// Converts `sourceValue` to `targetSemantics` using `roundingMode` (default
// NearestTiesToEven). Returns failure if the conversion loses information or
// reports any status other than opOK; otherwise returns the converted value.
// NOTE(review): line 1341 (return type and name, presumably convertFloatValue
// as called by the extf/truncf folds) is missing.
1342 APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
1343 llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1344 bool losesInfo = false;
1345 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1346 if (losesInfo || status != APFloat::opOK)
1347 return failure();
1348
1349 return sourceValue;
1350 }
1351
1352
1353
1354
1355
// Folds arith.extui:
//  - ext(ext(x)) -> ext(x): rewires this op's input to the inner op's input
//    in place (getDefiningOp template argument stripped; presumably ExtUIOp);
//  - otherwise constant-folds by zero-extending to the result width.
// NOTE(review): line 1362 (definition of `resType`) is missing.
1356 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1357 if (auto lhs = getIn().getDefiningOp()) {
1358 getInMutable().assign(lhs.getIn());
1359 return getResult();
1360 }
1361
1363 unsigned bitWidth = llvm::cast(resType).getWidth();
1364 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1365 adaptor.getOperands(), getType(),
1366 [bitWidth](const APInt &a, bool &castStatus) {
1367 return a.zext(bitWidth);
1368 });
1369 }
1370
// Compatible when the cast is integer-to-integer with a strictly wider result.
1371 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1372 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1373 }
1374
// Verifier body delegating to verifyExtOp (signature line 1375 missing).
1376 return verifyExtOp(*this);
1377 }
1378
1379
1380
1381
1382
// Folds arith.extsi:
//  - ext(ext(x)) -> ext(x): rewires this op's input in place (template
//    argument stripped; presumably ExtSIOp);
//  - otherwise constant-folds by sign-extending to the result width.
// NOTE(review): line 1389 (definition of `resType`) is missing.
1383 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1384 if (auto lhs = getIn().getDefiningOp()) {
1385 getInMutable().assign(lhs.getIn());
1386 return getResult();
1387 }
1388
1390 unsigned bitWidth = llvm::cast(resType).getWidth();
1391 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1392 adaptor.getOperands(), getType(),
1393 [bitWidth](const APInt &a, bool &castStatus) {
1394 return a.sext(bitWidth);
1395 });
1396 }
1397
// Compatible when the cast is integer-to-integer with a strictly wider result.
1398 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1399 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1400 }
1401
// Registers ExtSI canonicalization patterns (pattern list and signature lines
// 1402-1403 missing).
1404 patterns.add(context);
1405 }
1406
// Verifier body delegating to verifyExtOp (signature line 1407 missing).
1408 return verifyExtOp(*this);
1409 }
1410
1411
1412
1413
1414
1415
1416
// Folds arith.extf:
//  - extf(truncf(x)) -> x when x already has this op's result type and both
//    ops are flagged as contractable. The flag tests on lines 1423 and 1427
//    are missing; presumably they check the `contract` fast-math bit, per the
//    variable names. The getDefiningOp template argument is stripped
//    (presumably TruncFOp).
//  - otherwise constant-folds by converting to the result element semantics
//    via convertFloatValue, failing the fold (castStatus = false) when the
//    conversion is inexact.
// NOTE(review): line 1434 (definition of `resElemType`) is missing.
1417 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1418 if (auto truncFOp = getOperand().getDefiningOp()) {
1419 if (truncFOp.getOperand().getType() == getType()) {
1420 arith::FastMathFlags truncFMF =
1421 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1422 bool isTruncContract =
1424 arith::FastMathFlags extFMF =
1425 getFastmath().value_or(arith::FastMathFlags::none);
1426 bool isExtContract =
1428 if (isTruncContract && isExtContract) {
1429 return truncFOp.getOperand();
1430 }
1431 }
1432 }
1433
1435 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1436 return constFoldCastOp<FloatAttr, FloatAttr>(
1437 adaptor.getOperands(), getType(),
1438 [&targetSemantics](const APFloat &a, bool &castStatus) {
1439 FailureOr result = convertFloatValue(a, targetSemantics);
1440 if (failed(result)) {
1441 castStatus = false;
1442 return a;
1443 }
1444 return *result;
1445 });
1446 }
1447
// Compatible when the cast is float-to-float with a strictly wider result.
1448 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1449 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1450 }
1451
1453
1454
1455
1456
1457
// Compatible when the first input and the outputs form a float-widening cast.
// Only inputs.front() is checked; presumably the remaining input is the scale
// operand. NOTE(review): line 1459 (the `outputs` parameter) is missing.
1458 bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
1460 return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
1461 }
1462
// Verifier body delegating to verifyExtOp (signature line 1463 missing).
1464 return verifyExtOp(*this);
1465 }
1466
1467
1468
1469
1470
// Folds arith.trunci:
//  - trunci(extui(x)) / trunci(extsi(x)): if x is wider than the result, the
//    op is retargeted in place to truncate x directly; if x already has the
//    result type, folds to x. Lines 1474-1476, which define `src`, `srcType`
//    and `dstType`, are missing.
//  - trunci(trunci(x)) -> trunci(x), by rewiring the operand in place.
//  - otherwise constant-folds by truncating to the result width (line 1497,
//    defining `resType`, is missing).
1471 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1472 if (matchPattern(getOperand(), m_Oparith::ExtUIOp()) ||
1473 matchPattern(getOperand(), m_Oparith::ExtSIOp())) {
1477
1478
1479 if (llvm::cast(srcType).getWidth() >
1480 llvm::cast(dstType).getWidth()) {
1481 setOperand(src);
1482 return getResult();
1483 }
1484
1485
1486
1487 if (srcType == dstType)
1488 return src;
1489 }
1490
1491
1492 if (matchPattern(getOperand(), m_Oparith::TruncIOp())) {
1493 setOperand(getOperand().getDefiningOp()->getOperand(0));
1494 return getResult();
1495 }
1496
1498 unsigned bitWidth = llvm::cast(resType).getWidth();
1499 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1500 adaptor.getOperands(), getType(),
1501 [bitWidth](const APInt &a, bool &castStatus) {
1502 return a.trunc(bitWidth);
1503 });
1504 }
1505
// Compatible when the cast is integer-to-integer with a strictly narrower
// result.
1506 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1507 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1508 }
1509
// Registers the TruncI canonicalization patterns (signature lines 1510-1511
// missing).
1512 patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
1513 TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
1514 context);
1515 }
1516
// Verifier body delegating to verifyTruncateOp (signature line 1517 missing).
1518 return verifyTruncateOp(*this);
1519 }
1520
1521
1522
1523
1524
1525
1526
// Folds arith.truncf:
//  - truncf(extf(x)): when x's semantics are representable by the intermediate
//    (extended) type, the op is retargeted in place to truncate x directly if
//    x is wider than the result. If x's type equals the result element type,
//    it folds to x. Lines 1528, 1531 and 1533, which define `resElemType`,
//    `srcType` and `intermediateType`, are missing.
//  - otherwise constant-folds by converting to the result semantics using the
//    op's rounding mode (default to_nearest_even), failing the fold
//    (castStatus = false) when the conversion fails. Lines 1557 and 1559,
//    presumably the rounding-mode mapping and the convertFloatValue call, are
//    missing.
1527 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1529 if (auto extOp = getOperand().getDefiningOparith::ExtFOp()) {
1530 Value src = extOp.getIn();
1532 auto intermediateType =
1534
1535 if (llvm::APFloatBase::isRepresentableBy(
1536 srcType.getFloatSemantics(),
1537 intermediateType.getFloatSemantics())) {
1538
1539 if (srcType.getWidth() > resElemType.getWidth()) {
1540 setOperand(src);
1541 return getResult();
1542 }
1543
1544
1545 if (srcType == resElemType)
1546 return src;
1547 }
1548 }
1549
1550 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1551 return constFoldCastOp<FloatAttr, FloatAttr>(
1552 adaptor.getOperands(), getType(),
1553 [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1554 RoundingMode roundingMode =
1555 getRoundingmode().value_or(RoundingMode::to_nearest_even);
1556 llvm::RoundingMode llvmRoundingMode =
1558 FailureOr result =
1560 if (failed(result)) {
1561 castStatus = false;
1562 return a;
1563 }
1564 return *result;
1565 });
1566 }
1567
// Registers the TruncF canonicalization patterns (signature lines 1568-1569
// missing).
1570 patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1571 }
1572
// Compatible when the cast is float-to-float with a strictly narrower result.
1573 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1574 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1575 }
1576
// Verifier body delegating to verifyTruncateOp (signature line 1577 missing).
1578 return verifyTruncateOp(*this);
1579 }
1580
1581
1582
1583
1584
// Compatible when the first input and the outputs form a float-narrowing cast.
// Only inputs.front() is checked; presumably the remaining input is the scale
// operand. NOTE(review): line 1586 (the `outputs` parameter) is missing.
1585 bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
1587 return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
1588 }
1589
// Verifier body delegating to verifyTruncateOp (signature line 1590 missing).
1591 return verifyTruncateOp(*this);
1592 }
1593
1594
1595
1596
1597
// Registers the and-of-ext canonicalization patterns (signature lines
// 1598-1599 missing; presumably AndIOp::getCanonicalizationPatterns).
1600 patterns.add<AndOfExtUI, AndOfExtSI>(context);
1601 }
1602
1603
1604
1605
1606
// Registers the or-of-ext canonicalization patterns (signature lines
// 1607-1608 missing; presumably OrIOp::getCanonicalizationPatterns).
1609 patterns.add<OrOfExtUI, OrOfExtSI>(context);
1610 }
1611
1612
1613
1614
1615
// Returns true when the cast's input is From-like and its output is To-like
// (e.g. IntegerType -> FloatType for the int-to-float ops below).
// NOTE(review): the signature and the guard on lines 1617-1618 (presumably the
// single input/output check) are missing; getTypeIfLike's template arguments
// are stripped. outputs.back() equals outputs.front() once that check holds.
1616 template <typename From, typename To>
1619 return false;
1620
1621 auto srcType = getTypeIfLike(inputs.front());
1622 auto dstType = getTypeIfLike(outputs.back());
1623
1624 return srcType && dstType;
1625 }
1626
1627
1628
1629
1630
// Compatible when the cast is integer-like to float-like.
1631 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1632 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1633 }
1634
// Constant-folds uitofp: each APInt is converted as unsigned (isSigned =
// false) into the result float semantics with round-to-nearest-ties-to-even.
// The visible lambda body never sets castStatus.
// NOTE(review): lines 1636 (definition of `resEleType`) and 1642 (the APFloat's
// initial value argument) are missing.
1635 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1637 return constFoldCastOp<IntegerAttr, FloatAttr>(
1638 adaptor.getOperands(), getType(),
1639 [&resEleType](const APInt &a, bool &castStatus) {
1640 FloatType floatTy = llvm::cast(resEleType);
1641 APFloat apf(floatTy.getFloatSemantics(),
1643 apf.convertFromAPInt(a, false,
1644 APFloat::rmNearestTiesToEven);
1645 return apf;
1646 });
1647 }
1648
1649
1650
1651
1652
// Compatible when the cast is integer-like to float-like.
1653 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1654 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1655 }
1656
// Constant-folds sitofp: like uitofp above, but the APInt is interpreted as
// signed (isSigned = true). NOTE(review): lines 1658 and 1664 are missing.
1657 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1659 return constFoldCastOp<IntegerAttr, FloatAttr>(
1660 adaptor.getOperands(), getType(),
1661 [&resEleType](const APInt &a, bool &castStatus) {
1662 FloatType floatTy = llvm::cast(resEleType);
1663 APFloat apf(floatTy.getFloatSemantics(),
1665 apf.convertFromAPInt(a, true,
1666 APFloat::rmNearestTiesToEven);
1667 return apf;
1668 });
1669 }
1670
1671
1672
1673
1674
// Compatible when the cast is float-like to integer-like.
1675 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1676 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1677 }
1678
// Constant-folds fptoui: converts toward zero into an unsigned APSInt
// (isUnsigned = true) of the result width. The fold fails when
// convertToInteger returns opInvalidOp; other statuses, such as inexact, are
// accepted. NOTE(review): line 1680 (definition of `resType`) is missing.
1679 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1681 unsigned bitWidth = llvm::cast(resType).getWidth();
1682 return constFoldCastOp<FloatAttr, IntegerAttr>(
1683 adaptor.getOperands(), getType(),
1684 [&bitWidth](const APFloat &a, bool &castStatus) {
1685 bool ignored;
1686 APSInt api(bitWidth, true);
1687 castStatus = APFloat::opInvalidOp !=
1688 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1689 return api;
1690 });
1691 }
1692
1693
1694
1695
1696
// Compatible when the cast is float-like to integer-like.
1697 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1698 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1699 }
1700
// Constant-folds fptosi: like fptoui above, but into a signed APSInt
// (isUnsigned = false). NOTE(review): line 1702 (definition of `resType`) is
// missing.
1701 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1703 unsigned bitWidth = llvm::cast(resType).getWidth();
1704 return constFoldCastOp<FloatAttr, IntegerAttr>(
1705 adaptor.getOperands(), getType(),
1706 [&bitWidth](const APFloat &a, bool &castStatus) {
1707 bool ignored;
1708 APSInt api(bitWidth, false);
1709 castStatus = APFloat::opInvalidOp !=
1710 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1711 return api;
1712 });
1713 }
1714
1715
1716
1717
1718
1721 return false;
1722
1723 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1724 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1725 if (!srcType || !dstType)
1726 return false;
1727
1730 }
1731
1732 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1735 }
1736
1737 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1738
1739 unsigned resultBitwidth = 64;
1741 resultBitwidth = intTy.getWidth();
1742
1743 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1744 adaptor.getOperands(), getType(),
1745 [resultBitwidth](const APInt &a, bool & ) {
1746 return a.sextOrTrunc(resultBitwidth);
1747 });
1748 }
1749
1750 void arith::IndexCastOp::getCanonicalizationPatterns(
1752 patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1753 }
1754
1755
1756
1757
1758
1759 bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
1762 }
1763
1764 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1765
1766 unsigned resultBitwidth = 64;
1768 resultBitwidth = intTy.getWidth();
1769
1770 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1771 adaptor.getOperands(), getType(),
1772 [resultBitwidth](const APInt &a, bool & ) {
1773 return a.zextOrTrunc(resultBitwidth);
1774 });
1775 }
1776
1777 void arith::IndexCastUIOp::getCanonicalizationPatterns(
1779 patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1780 }
1781
1782
1783
1784
1785
1786 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1788 return false;
1789
1790 auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
1791 auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
1792 if (!srcType || !dstType)
1793 return false;
1794
1796 }
1797
1798 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1799 auto resType = getType();
1800 auto operand = adaptor.getIn();
1801 if (!operand)
1802 return {};
1803
1804
1805 if (auto denseAttr = llvm::dyn_cast_or_null(operand))
1806 return denseAttr.bitcast(llvm::cast(resType).getElementType());
1807
1808 if (llvm::isa(resType))
1809 return {};
1810
1811
1812 if (llvm::isaub::PoisonAttr(operand))
1814
1815
1816 APInt bits = llvm::isa(operand)
1817 ? llvm::cast(operand).getValue().bitcastToAPInt()
1818 : llvm::cast(operand).getValue();
1820 "trying to fold on broken IR: operands have incompatible types");
1821
1822 if (auto resFloatType = llvm::dyn_cast(resType))
1824 APFloat(resFloatType.getFloatSemantics(), bits));
1826 }
1827
1830 patterns.add(context);
1831 }
1832
1833
1834
1835
1836
1837
1838
1840 const APInt &lhs, const APInt &rhs) {
1841 switch (predicate) {
1842 case arith::CmpIPredicate::eq:
1843 return lhs.eq(rhs);
1844 case arith::CmpIPredicate::ne:
1845 return lhs.ne(rhs);
1846 case arith::CmpIPredicate::slt:
1847 return lhs.slt(rhs);
1848 case arith::CmpIPredicate::sle:
1849 return lhs.sle(rhs);
1850 case arith::CmpIPredicate::sgt:
1851 return lhs.sgt(rhs);
1852 case arith::CmpIPredicate::sge:
1853 return lhs.sge(rhs);
1854 case arith::CmpIPredicate::ult:
1855 return lhs.ult(rhs);
1856 case arith::CmpIPredicate::ule:
1857 return lhs.ule(rhs);
1858 case arith::CmpIPredicate::ugt:
1859 return lhs.ugt(rhs);
1860 case arith::CmpIPredicate::uge:
1861 return lhs.uge(rhs);
1862 }
1863 llvm_unreachable("unknown cmpi predicate kind");
1864 }
1865
1866
1868 switch (predicate) {
1869 case arith::CmpIPredicate::eq:
1870 case arith::CmpIPredicate::sle:
1871 case arith::CmpIPredicate::sge:
1872 case arith::CmpIPredicate::ule:
1873 case arith::CmpIPredicate::uge:
1874 return true;
1875 case arith::CmpIPredicate::ne:
1876 case arith::CmpIPredicate::slt:
1877 case arith::CmpIPredicate::sgt:
1878 case arith::CmpIPredicate::ult:
1879 case arith::CmpIPredicate::ugt:
1880 return false;
1881 }
1882 llvm_unreachable("unknown cmpi predicate kind");
1883 }
1884
1886 if (auto intType = llvm::dyn_cast(t)) {
1887 return intType.getWidth();
1888 }
1889 if (auto vectorIntType = llvm::dyn_cast(t)) {
1890 return llvm::cast(vectorIntType.getElementType()).getWidth();
1891 }
1892 return std::nullopt;
1893 }
1894
1895 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1896
1897 if (getLhs() == getRhs()) {
1900 }
1901
1903 if (auto extOp = getLhs().getDefiningOp()) {
1904
1905 std::optional<int64_t> integerWidth =
1907 if (integerWidth && integerWidth.value() == 1 &&
1908 getPredicate() == arith::CmpIPredicate::ne)
1909 return extOp.getOperand();
1910 }
1911 if (auto extOp = getLhs().getDefiningOp()) {
1912
1913 std::optional<int64_t> integerWidth =
1915 if (integerWidth && integerWidth.value() == 1 &&
1916 getPredicate() == arith::CmpIPredicate::ne)
1917 return extOp.getOperand();
1918 }
1919
1920
1922 getPredicate() == arith::CmpIPredicate::ne)
1923 return getLhs();
1924 }
1925
1927
1929 getPredicate() == arith::CmpIPredicate::eq)
1930 return getLhs();
1931 }
1932
1933
1934 if (adaptor.getLhs() && !adaptor.getRhs()) {
1935
1936 using Pred = CmpIPredicate;
1937 const std::pair<Pred, Pred> invPreds[] = {
1938 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1939 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1940 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1941 {Pred::ne, Pred::ne},
1942 };
1943 Pred origPred = getPredicate();
1944 for (auto pred : invPreds) {
1945 if (origPred == pred.first) {
1946 setPredicate(pred.second);
1947 Value lhs = getLhs();
1948 Value rhs = getRhs();
1949 getLhsMutable().assign(rhs);
1950 getRhsMutable().assign(lhs);
1951 return getResult();
1952 }
1953 }
1954 llvm_unreachable("unknown cmpi predicate kind");
1955 }
1956
1957
1958
1959 if (auto lhs = llvm::dyn_cast_if_present(adaptor.getLhs())) {
1960 return constFoldBinaryOp(
1961 adaptor.getOperands(), getI1SameShape(lhs.getType()),
1962 [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
1963 return APInt(1,
1965 });
1966 }
1967
1968 return {};
1969 }
1970
1973 patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1974 }
1975
1976
1977
1978
1979
1980
1981
1983 const APFloat &lhs, const APFloat &rhs) {
1984 auto cmpResult = lhs.compare(rhs);
1985 switch (predicate) {
1986 case arith::CmpFPredicate::AlwaysFalse:
1987 return false;
1988 case arith::CmpFPredicate::OEQ:
1989 return cmpResult == APFloat::cmpEqual;
1990 case arith::CmpFPredicate::OGT:
1991 return cmpResult == APFloat::cmpGreaterThan;
1992 case arith::CmpFPredicate::OGE:
1993 return cmpResult == APFloat::cmpGreaterThan ||
1994 cmpResult == APFloat::cmpEqual;
1995 case arith::CmpFPredicate::OLT:
1996 return cmpResult == APFloat::cmpLessThan;
1997 case arith::CmpFPredicate::OLE:
1998 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1999 case arith::CmpFPredicate::ONE:
2000 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
2001 case arith::CmpFPredicate::ORD:
2002 return cmpResult != APFloat::cmpUnordered;
2003 case arith::CmpFPredicate::UEQ:
2004 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
2005 case arith::CmpFPredicate::UGT:
2006 return cmpResult == APFloat::cmpUnordered ||
2007 cmpResult == APFloat::cmpGreaterThan;
2008 case arith::CmpFPredicate::UGE:
2009 return cmpResult == APFloat::cmpUnordered ||
2010 cmpResult == APFloat::cmpGreaterThan ||
2011 cmpResult == APFloat::cmpEqual;
2012 case arith::CmpFPredicate::ULT:
2013 return cmpResult == APFloat::cmpUnordered ||
2014 cmpResult == APFloat::cmpLessThan;
2015 case arith::CmpFPredicate::ULE:
2016 return cmpResult == APFloat::cmpUnordered ||
2017 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2018 case arith::CmpFPredicate::UNE:
2019 return cmpResult != APFloat::cmpEqual;
2020 case arith::CmpFPredicate::UNO:
2021 return cmpResult == APFloat::cmpUnordered;
2022 case arith::CmpFPredicate::AlwaysTrue:
2023 return true;
2024 }
2025 llvm_unreachable("unknown cmpf predicate kind");
2026 }
2027
2028 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
2029 auto lhs = llvm::dyn_cast_if_present(adaptor.getLhs());
2030 auto rhs = llvm::dyn_cast_if_present(adaptor.getRhs());
2031
2032
2033 if (lhs && lhs.getValue().isNaN())
2034 rhs = lhs;
2035 if (rhs && rhs.getValue().isNaN())
2036 lhs = rhs;
2037
2038 if (!lhs || !rhs)
2039 return {};
2040
2041 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2043 }
2044
2046 public:
2048
2050 bool isUnsigned) {
2051 using namespace arith;
2052 switch (pred) {
2053 case CmpFPredicate::UEQ:
2054 case CmpFPredicate::OEQ:
2055 return CmpIPredicate::eq;
2056 case CmpFPredicate::UGT:
2057 case CmpFPredicate::OGT:
2058 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2059 case CmpFPredicate::UGE:
2060 case CmpFPredicate::OGE:
2061 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2062 case CmpFPredicate::ULT:
2063 case CmpFPredicate::OLT:
2064 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2065 case CmpFPredicate::ULE:
2066 case CmpFPredicate::OLE:
2067 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2068 case CmpFPredicate::UNE:
2069 case CmpFPredicate::ONE:
2070 return CmpIPredicate::ne;
2071 default:
2072 llvm_unreachable("Unexpected predicate!");
2073 }
2074 }
2075
2078 FloatAttr flt;
2080 return failure();
2081
2082 const APFloat &rhs = flt.getValue();
2083
2084
2085 if (rhs.isNaN())
2086 return failure();
2087
2088
2089
2090 FloatType floatTy = llvm::cast(op.getRhs().getType());
2091 int mantissaWidth = floatTy.getFPMantissaWidth();
2092 if (mantissaWidth <= 0)
2093 return failure();
2094
2095 bool isUnsigned;
2097
2098 if (auto si = op.getLhs().getDefiningOp()) {
2099 isUnsigned = false;
2100 intVal = si.getIn();
2101 } else if (auto ui = op.getLhs().getDefiningOp()) {
2102 isUnsigned = true;
2103 intVal = ui.getIn();
2104 } else {
2105 return failure();
2106 }
2107
2108
2109
2110 auto intTy = llvm::cast(intVal.getType());
2111 auto intWidth = intTy.getWidth();
2112
2113
2114 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2115
2116
2117
2118
2119 if ((int)intWidth > mantissaWidth) {
2120
2121 int exponent = ilogb(rhs);
2122 if (exponent == APFloat::IEK_Inf) {
2123 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2124 if (maxExponent < (int)valueBits) {
2125
2126 return failure();
2127 }
2128 } else {
2129
2130
2131 if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
2132
2133 return failure();
2134 }
2135 }
2136 }
2137
2138
2139 CmpIPredicate pred;
2140 switch (op.getPredicate()) {
2141 case CmpFPredicate::ORD:
2142
2144 1);
2145 return success();
2146 case CmpFPredicate::UNO:
2147
2149 1);
2150 return success();
2151 default:
2153 break;
2154 }
2155
2156 if (!isUnsigned) {
2157
2158
2159 APFloat signedMax(rhs.getSemantics());
2160 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
2161 APFloat::rmNearestTiesToEven);
2162 if (signedMax < rhs) {
2163 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2164 pred == CmpIPredicate::sle)
2166 1);
2167 else
2169 1);
2170 return success();
2171 }
2172 } else {
2173
2174
2175 APFloat unsignedMax(rhs.getSemantics());
2176 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
2177 APFloat::rmNearestTiesToEven);
2178 if (unsignedMax < rhs) {
2179 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2180 pred == CmpIPredicate::ule)
2182 1);
2183 else
2185 1);
2186 return success();
2187 }
2188 }
2189
2190 if (!isUnsigned) {
2191
2192 APFloat signedMin(rhs.getSemantics());
2193 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2194 APFloat::rmNearestTiesToEven);
2195 if (signedMin > rhs) {
2196 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2197 pred == CmpIPredicate::sge)
2199 1);
2200 else
2202 1);
2203 return success();
2204 }
2205 } else {
2206
2207 APFloat unsignedMin(rhs.getSemantics());
2208 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2209 APFloat::rmNearestTiesToEven);
2210 if (unsignedMin > rhs) {
2211 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2212 pred == CmpIPredicate::uge)
2214 1);
2215 else
2217 1);
2218 return success();
2219 }
2220 }
2221
2222
2223
2224
2225
2226 bool ignored;
2227 APSInt rhsInt(intWidth, isUnsigned);
2228 if (APFloat::opInvalidOp ==
2229 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2230
2231
2232 return failure();
2233 }
2234
2235 if (!rhs.isZero()) {
2236 APFloat apf(floatTy.getFloatSemantics(),
2238 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2239
2240 bool equal = apf == rhs;
2241 if (!equal) {
2242
2243
2244
2245 switch (pred) {
2246 case CmpIPredicate::ne:
2248 1);
2249 return success();
2250 case CmpIPredicate::eq:
2252 1);
2253 return success();
2254 case CmpIPredicate::ule:
2255
2256
2257 if (rhs.isNegative()) {
2259 1);
2260 return success();
2261 }
2262 break;
2263 case CmpIPredicate::sle:
2264
2265
2266 if (rhs.isNegative())
2267 pred = CmpIPredicate::slt;
2268 break;
2269 case CmpIPredicate::ult:
2270
2271
2272 if (rhs.isNegative()) {
2274 1);
2275 return success();
2276 }
2277 pred = CmpIPredicate::ule;
2278 break;
2279 case CmpIPredicate::slt:
2280
2281
2282 if (!rhs.isNegative())
2283 pred = CmpIPredicate::sle;
2284 break;
2285 case CmpIPredicate::ugt:
2286
2287
2288 if (rhs.isNegative()) {
2290 1);
2291 return success();
2292 }
2293 break;
2294 case CmpIPredicate::sgt:
2295
2296
2297 if (rhs.isNegative())
2298 pred = CmpIPredicate::sge;
2299 break;
2300 case CmpIPredicate::uge:
2301
2302
2303 if (rhs.isNegative()) {
2305 1);
2306 return success();
2307 }
2308 pred = CmpIPredicate::ugt;
2309 break;
2310 case CmpIPredicate::sge:
2311
2312
2313 if (!rhs.isNegative())
2314 pred = CmpIPredicate::sgt;
2315 break;
2316 }
2317 }
2318 }
2319
2320
2321
2323 op, pred, intVal,
2324 rewriter.create(
2325 op.getLoc(), intVal.getType(),
2327 return success();
2328 }
2329 };
2330
2334 }
2335
2336
2337
2338
2339
2340
2343
2346
2347 if (!llvm::isa(op.getType()) || op.getType().isInteger(1))
2348 return failure();
2349
2350
2354 op.getCondition());
2355 return success();
2356 }
2357
2358
2362 op, op.getType(),
2363 rewriter.createarith::XOrIOp(
2364 op.getLoc(), op.getCondition(),
2365 rewriter.createarith::ConstantIntOp(
2366 op.getLoc(), 1, op.getCondition().getType())));
2367 return success();
2368 }
2369
2370 return failure();
2371 }
2372 };
2373
2374 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2376 results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2378 }
2379
2380 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2381 Value trueVal = getTrueValue();
2382 Value falseVal = getFalseValue();
2383 if (trueVal == falseVal)
2384 return trueVal;
2385
2386 Value condition = getCondition();
2387
2388
2390 return trueVal;
2391
2392
2394 return falseVal;
2395
2396
2397 if (isa_and_nonnullub::PoisonAttr(adaptor.getTrueValue()))
2398 return falseVal;
2399
2400 if (isa_and_nonnullub::PoisonAttr(adaptor.getFalseValue()))
2401 return trueVal;
2402
2403
2404 if (getType().isSignlessInteger(1) &&
2407 return condition;
2408
2409 if (auto cmp = dyn_cast_or_nullarith::CmpIOp(condition.getDefiningOp())) {
2410 auto pred = cmp.getPredicate();
2411 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2412 auto cmpLhs = cmp.getLhs();
2413 auto cmpRhs = cmp.getRhs();
2414
2415
2416
2417
2418
2419
2420
2421 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2422 (cmpRhs == trueVal && cmpLhs == falseVal))
2423 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2424 }
2425 }
2426
2427
2428
2429 if (auto cond =
2430 llvm::dyn_cast_if_present(adaptor.getCondition())) {
2431 if (auto lhs =
2432 llvm::dyn_cast_if_present(adaptor.getTrueValue())) {
2433 if (auto rhs =
2434 llvm::dyn_cast_if_present(adaptor.getFalseValue())) {
2436 results.reserve(static_cast<size_t>(cond.getNumElements()));
2437 auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2438 cond.value_end<BoolAttr>());
2439 auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2441 auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2443
2444 for (auto [condVal, lhsVal, rhsVal] :
2445 llvm::zip_equal(condVals, lhsVals, rhsVals))
2446 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2447
2449 }
2450 }
2451 }
2452
2453 return nullptr;
2454 }
2455
2457 Type conditionType, resultType;
2459 if (parser.parseOperandList(operands, 3) ||
2462 return failure();
2463
2464
2466 conditionType = resultType;
2467 if (parser.parseType(resultType))
2468 return failure();
2469 } else {
2471 }
2472
2473 result.addTypes(resultType);
2475 {conditionType, resultType, resultType},
2477 }
2478
2480 p << " " << getOperands();
2482 p << " : ";
2483 if (ShapedType condType =
2484 llvm::dyn_cast(getCondition().getType()))
2485 p << condType << ", ";
2487 }
2488
2490 Type conditionType = getCondition().getType();
2492 return success();
2493
2494
2495
2497 if (!llvm::isa<TensorType, VectorType>(resultType))
2498 return emitOpError() << "expected condition to be a signless i1, but got "
2499 << conditionType;
2501 if (conditionType != shapedConditionType) {
2502 return emitOpError() << "expected condition type to have the same shape "
2503 "as the result type, expected "
2504 << shapedConditionType << ", but got "
2505 << conditionType;
2506 }
2507 return success();
2508 }
2509
2510
2511
2512
2513 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2514
2516 return getLhs();
2517
2518 bool bounded = false;
2519 auto result = constFoldBinaryOp(
2520 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2521 bounded = b.ult(b.getBitWidth());
2522 return a.shl(b);
2523 });
2524 return bounded ? result : Attribute();
2525 }
2526
2527
2528
2529
2530
2531 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2532
2534 return getLhs();
2535
2536 bool bounded = false;
2537 auto result = constFoldBinaryOp(
2538 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2539 bounded = b.ult(b.getBitWidth());
2540 return a.lshr(b);
2541 });
2542 return bounded ? result : Attribute();
2543 }
2544
2545
2546
2547
2548
2549 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2550
2552 return getLhs();
2553
2554 bool bounded = false;
2555 auto result = constFoldBinaryOp(
2556 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2557 bounded = b.ult(b.getBitWidth());
2558 return a.ashr(b);
2559 });
2560 return bounded ? result : Attribute();
2561 }
2562
2563
2564
2565
2566
2567
2570 bool useOnlyFiniteValue) {
2571 switch (kind) {
2572 case AtomicRMWKind::maximumf: {
2573 const llvm::fltSemantics &semantic =
2574 llvm::cast(resultType).getFloatSemantics();
2575 APFloat identity = useOnlyFiniteValue
2576 ? APFloat::getLargest(semantic, true)
2577 : APFloat::getInf(semantic, true);
2578 return builder.getFloatAttr(resultType, identity);
2579 }
2580 case AtomicRMWKind::maxnumf: {
2581 const llvm::fltSemantics &semantic =
2582 llvm::cast(resultType).getFloatSemantics();
2583 APFloat identity = APFloat::getNaN(semantic, true);
2584 return builder.getFloatAttr(resultType, identity);
2585 }
2586 case AtomicRMWKind::addf:
2587 case AtomicRMWKind::addi:
2588 case AtomicRMWKind::maxu:
2589 case AtomicRMWKind::ori:
2591 case AtomicRMWKind::andi:
2593 resultType,
2594 APInt::getAllOnes(llvm::cast(resultType).getWidth()));
2595 case AtomicRMWKind::maxs:
2597 resultType, APInt::getSignedMinValue(
2598 llvm::cast(resultType).getWidth()));
2599 case AtomicRMWKind::minimumf: {
2600 const llvm::fltSemantics &semantic =
2601 llvm::cast(resultType).getFloatSemantics();
2602 APFloat identity = useOnlyFiniteValue
2603 ? APFloat::getLargest(semantic, false)
2604 : APFloat::getInf(semantic, false);
2605
2606 return builder.getFloatAttr(resultType, identity);
2607 }
2608 case AtomicRMWKind::minnumf: {
2609 const llvm::fltSemantics &semantic =
2610 llvm::cast(resultType).getFloatSemantics();
2611 APFloat identity = APFloat::getNaN(semantic, false);
2612 return builder.getFloatAttr(resultType, identity);
2613 }
2614 case AtomicRMWKind::mins:
2616 resultType, APInt::getSignedMaxValue(
2617 llvm::cast(resultType).getWidth()));
2618 case AtomicRMWKind::minu:
2620 resultType,
2621 APInt::getMaxValue(llvm::cast(resultType).getWidth()));
2622 case AtomicRMWKind::muli:
2624 case AtomicRMWKind::mulf:
2626
2627 default:
2628 (void)emitOptionalError(loc, "Reduction operation type not supported");
2629 break;
2630 }
2631 return nullptr;
2632 }
2633
2634
2636 std::optional maybeKind =
2638
2639 .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2640 .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2641 .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2642 .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2643 .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2644 .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2645
2646 .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2647 .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2648 .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
2649 .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2650 .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2651 .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2652 .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2653 .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2654 .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2655 .Default([](Operation *op) { return std::nullopt; });
2656 if (!maybeKind) {
2657 return std::nullopt;
2658 }
2659
2660 bool useOnlyFiniteValue = false;
2661 auto fmfOpInterface = dyn_cast(op);
2662 if (fmfOpInterface) {
2663 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2664 useOnlyFiniteValue =
2665 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2666 }
2667
2668
2671
2673 useOnlyFiniteValue);
2674 }
2675
2676
2679 bool useOnlyFiniteValue) {
2680 auto attr =
2682 return builder.createarith::ConstantOp(loc, attr);
2683 }
2684
2685
2686
2689 switch (op) {
2690 case AtomicRMWKind::addf:
2691 return builder.createarith::AddFOp(loc, lhs, rhs);
2692 case AtomicRMWKind::addi:
2693 return builder.createarith::AddIOp(loc, lhs, rhs);
2694 case AtomicRMWKind::mulf:
2695 return builder.createarith::MulFOp(loc, lhs, rhs);
2696 case AtomicRMWKind::muli:
2697 return builder.createarith::MulIOp(loc, lhs, rhs);
2698 case AtomicRMWKind::maximumf:
2699 return builder.createarith::MaximumFOp(loc, lhs, rhs);
2700 case AtomicRMWKind::minimumf:
2701 return builder.createarith::MinimumFOp(loc, lhs, rhs);
2702 case AtomicRMWKind::maxnumf:
2703 return builder.createarith::MaxNumFOp(loc, lhs, rhs);
2704 case AtomicRMWKind::minnumf:
2705 return builder.createarith::MinNumFOp(loc, lhs, rhs);
2706 case AtomicRMWKind::maxs:
2707 return builder.createarith::MaxSIOp(loc, lhs, rhs);
2708 case AtomicRMWKind::mins:
2709 return builder.createarith::MinSIOp(loc, lhs, rhs);
2710 case AtomicRMWKind::maxu:
2711 return builder.createarith::MaxUIOp(loc, lhs, rhs);
2712 case AtomicRMWKind::minu:
2713 return builder.createarith::MinUIOp(loc, lhs, rhs);
2714 case AtomicRMWKind::ori:
2715 return builder.createarith::OrIOp(loc, lhs, rhs);
2716 case AtomicRMWKind::andi:
2717 return builder.createarith::AndIOp(loc, lhs, rhs);
2718
2719 default:
2720 (void)emitOptionalError(loc, "Reduction operation type not supported");
2721 break;
2722 }
2723 return nullptr;
2724 }
2725
2726
2727
2728
2729
2730 #define GET_OP_CLASSES
2731 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2732
2733
2734
2735
2736
2737 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static Attribute getBoolAttribute(Type type, bool value)
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
static int64_t getScalarOrElementWidth(Type type)
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
static std::optional< int64_t > getIntegerWidth(Type t)
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
static LogicalResult verifyTruncateOp(Op op)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1203::ArityGroupAndKind::Kind kind
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static void build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type)
Build a constant float op that produces a float of the specified type.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
static bool classof(Operation *op)
Specialization of arith.constant op that returns an integer value.
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
static bool classof(Operation *op)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated with the given op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float NaN.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or a unsigned integer range that does...
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)