MLIR: lib/Dialect/Arith/Transforms/EmulateWideInt.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
10
20 #include "llvm/ADT/APFloat.h"
21 #include "llvm/ADT/APInt.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/Support/MathExtras.h"
24 #include
25
27 #define GEN_PASS_DEF_ARITHEMULATEWIDEINT
28 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
29 }
30
31 using namespace mlir;
32
33
34
35
36
37
38
39
40
41 static std::pair<APInt, APInt> getHalves(const APInt &value,
42 unsigned newBitWidth) {
43 APInt low = value.extractBits(newBitWidth, 0);
44 APInt high = value.extractBits(newBitWidth, newBitWidth);
45 return {std::move(low), std::move(high)};
46 }
47
48
49
50
51
52
54 if (type.getShape().size() == 1)
55 return type.getElementType();
56
57 auto newShape = to_vector(type.getShape());
58 newShape.back() = 1;
60 }
61
62
63
64
65
66
69 int64_t lastOffset) {
71 assert(lastOffset < shape.back() && "Offset out of bounds");
72
73
74 if (shape.size() == 1)
75 return rewriter.createvector::ExtractOp(loc, input, lastOffset);
76
78 offsets.back() = lastOffset;
79 auto sizes = llvm::to_vector(shape);
80 sizes.back() = 1;
82
83 return rewriter.createvector::ExtractStridedSliceOp(loc, input, offsets,
84 sizes, strides);
85 }
86
87
88
89 static std::pair<Value, Value>
94 }
95
96
97
100 auto vecTy = dyn_cast(input.getType());
101 if (!vecTy)
102 return input;
103
104
106 assert(shape.size() >= 2 && "Expected vector with at list two dims");
107 assert(shape.back() == 1 && "Expected the last vector dim to be x1");
108
109 auto newVecTy = VectorType::get(shape.drop_back(), vecTy.getElementType());
110 return rewriter.createvector::ShapeCastOp(loc, newVecTy, input);
111 }
112
113
114
117 auto vecTy = dyn_cast(input.getType());
118 if (!vecTy)
119 return input;
120
121
122 auto newShape = llvm::to_vector(vecTy.getShape());
123 newShape.push_back(1);
124 auto newTy = VectorType::get(newShape, vecTy.getElementType());
125 return rewriter.createvector::ShapeCastOp(loc, newTy, input);
126 }
127
128
129
130
133 int64_t lastOffset) {
135 assert(lastOffset < shape.back() && "Offset out of bounds");
136
137
138 if (isa(source.getType()))
139 return rewriter.createvector::InsertOp(loc, source, dest, lastOffset);
140
142 offsets.back() = lastOffset;
144 return rewriter.createvector::InsertStridedSliceOp(loc, source, dest,
145 offsets, strides);
146 }
147
148
149
150
151
152
153
155 Location loc, VectorType resultType,
158 (void)resultShape;
159 assert(!resultShape.empty() && "Result expected to have dimensions");
160 assert(resultShape.back() == static_cast<int64_t>(resultComponents.size()) &&
161 "Wrong number of result components");
162
164 for (auto [i, component] : llvm::enumerate(resultComponents))
165 resultVec = insertLastDimSlice(rewriter, loc, component, resultVec, i);
166
167 return resultVec;
168 }
169
170 namespace {
171
172
173
174
177
178 LogicalResult
179 matchAndRewrite(arith::ConstantOp op, OpAdaptor,
181 Type oldType = op.getType();
182 auto newType = getTypeConverter()->convertType(oldType);
183 if (!newType)
185 op, llvm::formatv("unsupported type: {0}", op.getType()));
186
187 unsigned newBitWidth = newType.getElementTypeBitWidth();
188 Attribute oldValue = op.getValueAttr();
189
190 if (auto intAttr = dyn_cast(oldValue)) {
191 auto [low, high] = getHalves(intAttr.getValue(), newBitWidth);
194 return success();
195 }
196
197 if (auto splatAttr = dyn_cast(oldValue)) {
198 auto [low, high] =
199 getHalves(splatAttr.getSplatValue(), newBitWidth);
200 int64_t numSplatElems = splatAttr.getNumElements();
202 values.reserve(numSplatElems * 2);
203 for (int64_t i = 0; i < numSplatElems; ++i) {
204 values.push_back(low);
205 values.push_back(high);
206 }
207
210 return success();
211 }
212
213 if (auto elemsAttr = dyn_cast(oldValue)) {
214 int64_t numElems = elemsAttr.getNumElements();
216 values.reserve(numElems * 2);
217 for (const APInt &origVal : elemsAttr.getValues()) {
218 auto [low, high] = getHalves(origVal, newBitWidth);
219 values.push_back(std::move(low));
220 values.push_back(std::move(high));
221 }
222
225 return success();
226 }
227
229 "unhandled constant attribute");
230 }
231 };
232
233
234
235
236
239
240 LogicalResult
241 matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,
244 auto newTy = getTypeConverter()->convertType(op.getType());
245 if (!newTy)
247 loc, llvm::formatv("unsupported type: {0}", op.getType()));
248
250
251 auto [lhsElem0, lhsElem1] =
253 auto [rhsElem0, rhsElem1] =
255
256 auto lowSum =
257 rewriter.createarith::AddUIExtendedOp(loc, lhsElem0, rhsElem0);
258 Value overflowVal =
259 rewriter.createarith::ExtUIOp(loc, newElemTy, lowSum.getOverflow());
260
261 Value high0 = rewriter.createarith::AddIOp(loc, overflowVal, lhsElem1);
262 Value high = rewriter.createarith::AddIOp(loc, high0, rhsElem1);
263
264 Value resultVec =
266 rewriter.replaceOp(op, resultVec);
267 return success();
268 }
269 };
270
271
272
273
274
275
276 template
280
281 LogicalResult
282 matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
285 auto newTy = this->getTypeConverter()->template convertType(
286 op.getType());
287 if (!newTy)
289 loc, llvm::formatv("unsupported type: {0}", op.getType()));
290
291 auto [lhsElem0, lhsElem1] =
293 auto [rhsElem0, rhsElem1] =
295
296 Value resElem0 = rewriter.create(loc, lhsElem0, rhsElem0);
297 Value resElem1 = rewriter.create(loc, lhsElem1, rhsElem1);
298 Value resultVec =
300 rewriter.replaceOp(op, resultVec);
301 return success();
302 }
303 };
304
305
306
307
308
309
310
311 static arith::CmpIPredicate toUnsignedPredicate(arith::CmpIPredicate pred) {
312 using P = arith::CmpIPredicate;
313 switch (pred) {
314 case P::sge:
315 return P::uge;
316 case P::sgt:
317 return P::ugt;
318 case P::sle:
319 return P::ule;
320 case P::slt:
321 return P::ult;
322 default:
323 return pred;
324 }
325 }
326
329
330 LogicalResult
331 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
334 auto inputTy =
335 getTypeConverter()->convertType(op.getLhs().getType());
336 if (!inputTy)
338 loc, llvm::formatv("unsupported type: {0}", op.getType()));
339
340 arith::CmpIPredicate highPred = adaptor.getPredicate();
341 arith::CmpIPredicate lowPred = toUnsignedPredicate(highPred);
342
343 auto [lhsElem0, lhsElem1] =
345 auto [rhsElem0, rhsElem1] =
347
349 rewriter.createarith::CmpIOp(loc, lowPred, lhsElem0, rhsElem0);
351 rewriter.createarith::CmpIOp(loc, highPred, lhsElem1, rhsElem1);
352
353 Value cmpResult{};
354 switch (highPred) {
355 case arith::CmpIPredicate::eq: {
356 cmpResult = rewriter.createarith::AndIOp(loc, lowCmp, highCmp);
357 break;
358 }
359 case arith::CmpIPredicate::ne: {
360 cmpResult = rewriter.createarith::OrIOp(loc, lowCmp, highCmp);
361 break;
362 }
363 default: {
364
365 Value highEq = rewriter.createarith::CmpIOp(
366 loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1);
367 cmpResult =
368 rewriter.createarith::SelectOp(loc, highEq, lowCmp, highCmp);
369 break;
370 }
371 }
372
373 assert(cmpResult && "Unhandled case");
375 return success();
376 }
377 };
378
379
380
381
382
385
386 LogicalResult
387 matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor,
390 auto newTy = getTypeConverter()->convertType(op.getType());
391 if (!newTy)
393 loc, llvm::formatv("unsupported type: {0}", op.getType()));
394
395 auto [lhsElem0, lhsElem1] =
397 auto [rhsElem0, rhsElem1] =
399
400
401
402
403 auto mulLowLow =
404 rewriter.createarith::MulUIExtendedOp(loc, lhsElem0, rhsElem0);
405 Value mulLowHi = rewriter.createarith::MulIOp(loc, lhsElem0, rhsElem1);
406 Value mulHiLow = rewriter.createarith::MulIOp(loc, lhsElem1, rhsElem0);
407
408 Value resLow = mulLowLow.getLow();
410 rewriter.createarith::AddIOp(loc, mulLowLow.getHigh(), mulLowHi);
411 resHi = rewriter.createarith::AddIOp(loc, resHi, mulHiLow);
412
413 Value resultVec =
415 rewriter.replaceOp(op, resultVec);
416 return success();
417 }
418 };
419
420
421
422
423
426
427 LogicalResult
428 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
431 auto newTy = getTypeConverter()->convertType(op.getType());
432 if (!newTy)
434 loc, llvm::formatv("unsupported type: {0}", op.getType()));
435
437
438
439
440
441 Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
443 loc, newResultComponentTy, newOperand);
444 Value operandZeroCst =
446 Value signBit = rewriter.createarith::CmpIOp(
447 loc, arith::CmpIPredicate::slt, extended, operandZeroCst);
448 Value signValue =
449 rewriter.createarith::ExtSIOp(loc, newResultComponentTy, signBit);
450
451 Value resultVec =
453 rewriter.replaceOp(op, resultVec);
454 return success();
455 }
456 };
457
458
459
460
461
464
465 LogicalResult
466 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
469 auto newTy = getTypeConverter()->convertType(op.getType());
470 if (!newTy)
472 loc, llvm::formatv("unsupported type: {0}", op.getType()));
473
475
476
477
478 Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
480 loc, newResultComponentTy, newOperand);
484 return success();
485 }
486 };
487
488
489
490
491
492 template <typename SourceOp, arith::CmpIPredicate CmpPred>
495
496 LogicalResult
497 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
500
501 Type oldTy = op.getType();
502 auto newTy = dyn_cast_or_null(
503 this->getTypeConverter()->convertType(oldTy));
504 if (!newTy)
506 loc, llvm::formatv("unsupported type: {0}", op.getType()));
507
508
509
511 rewriter.createarith::CmpIOp(loc, CmpPred, op.getLhs(), op.getRhs());
513 op.getRhs());
514 return success();
515 }
516 };
517
518
519
520
521
522 static bool isIndexOrIndexVector(Type type) {
523 if (isa(type))
524 return true;
525
526 if (auto vectorTy = dyn_cast(type))
527 if (isa(vectorTy.getElementType()))
528 return true;
529
530 return false;
531 }
532
533 template
536
537 LogicalResult
538 matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
540 Type resultType = op.getType();
541 if (!isIndexOrIndexVector(resultType))
542 return failure();
543
545 Type inType = op.getIn().getType();
546 auto newInTy =
547 this->getTypeConverter()->template convertType(inType);
548 if (!newInTy)
550 loc, llvm::formatv("unsupported type: {0}", inType));
551
552
556 return success();
557 }
558 };
559
560 template <typename CastOp, typename ExtensionOp>
563
564 LogicalResult
565 matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
567 Type inType = op.getIn().getType();
568 if (!isIndexOrIndexVector(inType))
569 return failure();
570
572 auto *typeConverter =
573 this->template getTypeConverterarith::WideIntEmulationConverter();
574
575 Type resultType = op.getType();
576 auto newTy = typeConverter->template convertType(resultType);
577 if (!newTy)
579 loc, llvm::formatv("unsupported type: {0}", resultType));
580
581
582 Type narrowTy =
583 rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth());
584 if (auto vecTy = dyn_cast(resultType))
586
587
588
589 Value underlyingVal =
590 rewriter.create(loc, narrowTy, adaptor.getIn());
591 rewriter.replaceOpWithNewOp(op, resultType, underlyingVal);
592 return success();
593 }
594 };
595
596
597
598
599
602
603 LogicalResult
604 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
607 auto newTy = getTypeConverter()->convertType(op.getType());
608 if (!newTy)
610 loc, llvm::formatv("unsupported type: {0}", op.getType()));
611
612 auto [trueElem0, trueElem1] =
614 auto [falseElem0, falseElem1] =
616 Value cond = appendX1Dim(rewriter, loc, adaptor.getCondition());
617
619 rewriter.createarith::SelectOp(loc, cond, trueElem0, falseElem0);
621 rewriter.createarith::SelectOp(loc, cond, trueElem1, falseElem1);
622 Value resultVec =
624 rewriter.replaceOp(op, resultVec);
625 return success();
626 }
627 };
628
629
630
631
632
635
636 LogicalResult
637 matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor,
640
641 Type oldTy = op.getType();
642 auto newTy = getTypeConverter()->convertType(oldTy);
643 if (!newTy)
645 loc, llvm::formatv("unsupported type: {0}", op.getType()));
646
648
649 unsigned newBitWidth = newTy.getElementTypeBitWidth();
650
651 auto [lhsElem0, lhsElem1] =
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
680 Value elemBitWidth =
682
683 Value illegalElemShift = rewriter.createarith::CmpIOp(
684 loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
685
686 Value shiftedElem0 =
687 rewriter.createarith::ShLIOp(loc, lhsElem0, rhsElem0);
688 Value resElem0 = rewriter.createarith::SelectOp(loc, illegalElemShift,
689 zeroCst, shiftedElem0);
690
691 Value cappedShiftAmount = rewriter.createarith::SelectOp(
692 loc, illegalElemShift, elemBitWidth, rhsElem0);
693 Value rightShiftAmount =
694 rewriter.createarith::SubIOp(loc, elemBitWidth, cappedShiftAmount);
695 Value shiftedRight =
696 rewriter.createarith::ShRUIOp(loc, lhsElem0, rightShiftAmount);
697 Value overshotShiftAmount =
698 rewriter.createarith::SubIOp(loc, rhsElem0, elemBitWidth);
699 Value shiftedLeft =
700 rewriter.createarith::ShLIOp(loc, lhsElem0, overshotShiftAmount);
701
702 Value shiftedElem1 =
703 rewriter.createarith::ShLIOp(loc, lhsElem1, rhsElem0);
704 Value resElem1High = rewriter.createarith::SelectOp(
705 loc, illegalElemShift, zeroCst, shiftedElem1);
706 Value resElem1Low = rewriter.createarith::SelectOp(
707 loc, illegalElemShift, shiftedLeft, shiftedRight);
709 rewriter.createarith::OrIOp(loc, resElem1Low, resElem1High);
710
711 Value resultVec =
713 rewriter.replaceOp(op, resultVec);
714 return success();
715 }
716 };
717
718
719
720
721
724
725 LogicalResult
726 matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor,
729
730 Type oldTy = op.getType();
731 auto newTy = getTypeConverter()->convertType(oldTy);
732 if (!newTy)
734 loc, llvm::formatv("unsupported type: {0}", op.getType()));
735
737
738 unsigned newBitWidth = newTy.getElementTypeBitWidth();
739
740 auto [lhsElem0, lhsElem1] =
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
769 Value elemBitWidth =
771
772 Value illegalElemShift = rewriter.createarith::CmpIOp(
773 loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
774
775 Value shiftedElem0 =
776 rewriter.createarith::ShRUIOp(loc, lhsElem0, rhsElem0);
777 Value resElem0Low = rewriter.createarith::SelectOp(loc, illegalElemShift,
778 zeroCst, shiftedElem0);
779 Value shiftedElem1 =
780 rewriter.createarith::ShRUIOp(loc, lhsElem1, rhsElem0);
781 Value resElem1 = rewriter.createarith::SelectOp(loc, illegalElemShift,
782 zeroCst, shiftedElem1);
783
784 Value cappedShiftAmount = rewriter.createarith::SelectOp(
785 loc, illegalElemShift, elemBitWidth, rhsElem0);
786 Value leftShiftAmount =
787 rewriter.createarith::SubIOp(loc, elemBitWidth, cappedShiftAmount);
788 Value shiftedLeft =
789 rewriter.createarith::ShLIOp(loc, lhsElem1, leftShiftAmount);
790 Value overshotShiftAmount =
791 rewriter.createarith::SubIOp(loc, rhsElem0, elemBitWidth);
792 Value shiftedRight =
793 rewriter.createarith::ShRUIOp(loc, lhsElem1, overshotShiftAmount);
794
795 Value resElem0High = rewriter.createarith::SelectOp(
796 loc, illegalElemShift, shiftedRight, shiftedLeft);
798 rewriter.createarith::OrIOp(loc, resElem0Low, resElem0High);
799
800 Value resultVec =
802 rewriter.replaceOp(op, resultVec);
803 return success();
804 }
805 };
806
807
808
809
810
813
814 LogicalResult
815 matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,
818
819 Type oldTy = op.getType();
820 auto newTy = getTypeConverter()->convertType(oldTy);
821 if (!newTy)
823 loc, llvm::formatv("unsupported type: {0}", op.getType()));
824
827
829 int64_t origBitwidth = newTy.getElementTypeBitWidth() * 2;
830
831
832
833
835 Value signBit = rewriter.createarith::CmpIOp(
836 loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
838
839
840
841
842 Value allSign = rewriter.createarith::ExtSIOp(loc, oldTy, signBit);
845 Value numNonSignExtBits =
846 rewriter.createarith::SubIOp(loc, maxShift, rhsElem0);
847 numNonSignExtBits = dropTrailingX1Dim(rewriter, loc, numNonSignExtBits);
848 numNonSignExtBits =
849 rewriter.createarith::ExtUIOp(loc, oldTy, numNonSignExtBits);
851 rewriter.createarith::ShLIOp(loc, allSign, numNonSignExtBits);
852
853
855 rewriter.createarith::ShRUIOp(loc, op.getLhs(), op.getRhs());
856 Value shrsi = rewriter.createarith::OrIOp(loc, shrui, signBits);
857
858
859
860 Value isNoop = rewriter.createarith::CmpIOp(loc, arith::CmpIPredicate::eq,
861 rhsElem0, elemZero);
863 rewriter.replaceOpWithNewOparith::SelectOp(op, isNoop, op.getLhs(),
864 shrsi);
865
866 return success();
867 }
868 };
869
870
871
872
873
876
877 LogicalResult
878 matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor,
881 auto newTy = getTypeConverter()->convertType(op.getType());
882 if (!newTy)
884 loc, llvm::formatv("unsupported type: {}", op.getType()));
885
887
888 auto [lhsElem0, lhsElem1] =
890 auto [rhsElem0, rhsElem1] =
892
893
894
895 Value low = rewriter.createarith::SubIOp(loc, lhsElem0, rhsElem0);
896
897 Value carry0 = rewriter.createarith::CmpIOp(
898 loc, arith::CmpIPredicate::ult, lhsElem0, rhsElem0);
899 Value carryVal = rewriter.createarith::ExtUIOp(loc, newElemTy, carry0);
900
901 Value high0 = rewriter.createarith::SubIOp(loc, lhsElem1, carryVal);
902 Value high = rewriter.createarith::SubIOp(loc, high0, rhsElem1);
903
905 rewriter.replaceOp(op, resultVec);
906 return success();
907 }
908 };
909
910
911
912
913
916
917 LogicalResult
918 matchAndRewrite(arith::SIToFPOp op, OpAdaptor adaptor,
921
922 Value in = op.getIn();
924 auto newTy = getTypeConverter()->convertType(oldTy);
925 if (!newTy)
927 loc, llvm::formatv("unsupported type: {0}", oldTy));
928
930
931
932
933
934
935
936 Value isNeg = rewriter.createarith::CmpIOp(loc, arith::CmpIPredicate::slt,
937 in, zeroCst);
938 Value neg = rewriter.createarith::SubIOp(loc, zeroCst, in);
939 Value abs = rewriter.createarith::SelectOp(loc, isNeg, neg, in);
940
941 Value absResult = rewriter.createarith::UIToFPOp(loc, op.getType(), abs);
942 Value negResult = rewriter.createarith::NegFOp(loc, absResult);
944 absResult);
945 return success();
946 }
947 };
948
949
950
951
952
955
956 LogicalResult
957 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
960
961 Type oldTy = op.getIn().getType();
962 auto newTy = getTypeConverter()->convertType(oldTy);
963 if (!newTy)
965 loc, llvm::formatv("unsupported type: {0}", oldTy));
966 unsigned newBitWidth = newTy.getElementTypeBitWidth();
967
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988 Value hiEqZero = rewriter.createarith::CmpIOp(
989 loc, arith::CmpIPredicate::eq, hiInt, zeroCst);
990
991 Type resultTy = op.getType();
993 Value lowFp = rewriter.createarith::UIToFPOp(loc, resultTy, lowInt);
994 Value hiFp = rewriter.createarith::UIToFPOp(loc, resultTy, hiInt);
995
996 int64_t pow2Int = int64_t(1) << newBitWidth;
997 TypedAttr pow2Attr =
998 rewriter.getFloatAttr(resultElemTy, static_cast<double>(pow2Int));
999 if (auto vecTy = dyn_cast(resultTy))
1001
1002 Value pow2Val = rewriter.createarith::ConstantOp(loc, resultTy, pow2Attr);
1003
1004 Value hiVal = rewriter.createarith::MulFOp(loc, hiFp, pow2Val);
1005 Value result = rewriter.createarith::AddFOp(loc, lowFp, hiVal);
1006
1007 rewriter.replaceOpWithNewOparith::SelectOp(op, hiEqZero, lowFp, result);
1008 return success();
1009 }
1010 };
1011
1012
1013
1014
1015
1018
1019 LogicalResult
1020 matchAndRewrite(arith::FPToSIOp op, OpAdaptor adaptor,
1023
1024 Value inFp = adaptor.getIn();
1026
1027 Type intTy = op.getType();
1028
1029 auto newTy = getTypeConverter()->convertType(intTy);
1030 if (!newTy)
1032 loc, llvm::formatv("unsupported type: {}", intTy));
1033
1034
1035
1036
1037
1038
1039 TypedAttr zeroAttr = rewriter.getZeroAttr(fpTy);
1040 Value zeroCst = rewriter.createarith::ConstantOp(loc, zeroAttr);
1042
1043
1044
1045 Value isNeg = rewriter.createarith::CmpFOp(loc, arith::CmpFPredicate::OLT,
1046 inFp, zeroCst);
1047 Value negInFp = rewriter.createarith::NegFOp(loc, inFp);
1048
1049 Value absVal = rewriter.createarith::SelectOp(loc, isNeg, negInFp, inFp);
1050
1051
1052 Value res = rewriter.createarith::FPToUIOp(loc, intTy, absVal);
1053
1054
1055 Value neg = rewriter.createarith::SubIOp(loc, zeroCstInt, res);
1056
1058 return success();
1059 }
1060 };
1061
1062
1063
1064
1065
1068
1069 LogicalResult
1070 matchAndRewrite(arith::FPToUIOp op, OpAdaptor adaptor,
1073
1074 Value inFp = adaptor.getIn();
1076
1077 Type intTy = op.getType();
1078 auto newTy = getTypeConverter()->convertType(intTy);
1079 if (!newTy)
1081 loc, llvm::formatv("unsupported type: {}", intTy));
1082 unsigned newBitWidth = newTy.getElementTypeBitWidth();
1083
1085 if (auto vecType = dyn_cast(fpTy))
1086 newHalfType = VectorType::get(vecType.getShape(), newHalfType);
1087
1088
1089
1090
1091
1092
1093
1094 const llvm::fltSemantics &fSemantics =
1096
1097 auto powBitwidth = llvm::APFloat(fSemantics);
1098
1099
1100
1101
1102 if (powBitwidth.convertFromAPInt(APInt(newBitWidth * 2, 1).shl(newBitWidth),
1103 false, llvm::RoundingMode::TowardZero) ==
1104 llvm::detail::opStatus::opInexact)
1105 powBitwidth = llvm::APFloat::getInf(fSemantics);
1106
1107 TypedAttr powBitwidthAttr =
1109 if (auto vecType = dyn_cast(fpTy))
1111 Value powBitwidthFloatCst =
1112 rewriter.createarith::ConstantOp(loc, powBitwidthAttr);
1113
1114 Value fpDivPowBitwidth =
1115 rewriter.createarith::DivFOp(loc, inFp, powBitwidthFloatCst);
1117 rewriter.createarith::FPToUIOp(loc, newHalfType, fpDivPowBitwidth);
1118
1119 Value remainder =
1120 rewriter.createarith::RemFOp(loc, inFp, powBitwidthFloatCst);
1122 rewriter.createarith::FPToUIOp(loc, newHalfType, remainder);
1123
1126
1128
1129 rewriter.replaceOp(op, resultVec);
1130 return success();
1131 }
1132 };
1133
1134
1135
1136
1137
1140
1141 LogicalResult
1142 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
1145
1146
1147 if (!getTypeConverter()->isLegal(op.getType()))
1149 loc, llvm::formatv("unsupported truncation result type: {0}",
1150 op.getType()));
1151
1152
1153
1156 Value truncated =
1157 rewriter.createOrFoldarith::TruncIOp(loc, op.getType(), extracted);
1158 rewriter.replaceOp(op, truncated);
1159 return success();
1160 }
1161 };
1162
1163
1164
1165
1166
1167 struct ConvertVectorPrint final : OpConversionPatternvector::PrintOp {
1169
1170 LogicalResult
1171 matchAndRewrite(vector::PrintOp op, OpAdaptor adaptor,
1173 rewriter.replaceOpWithNewOpvector::PrintOp(op, adaptor.getSource());
1174 return success();
1175 }
1176 };
1177
1178
1179
1180
1181
1182 struct EmulateWideIntPass final
1183 : arith::impl::ArithEmulateWideIntBase {
1184 using ArithEmulateWideIntBase::ArithEmulateWideIntBase;
1185
1186 void runOnOperation() override {
1187 if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
1188 signalPassFailure();
1189 return;
1190 }
1191
1194
1195 arith::WideIntEmulationConverter typeConverter(widestIntSupported);
1197 target.addDynamicallyLegalOpfunc::FuncOp([&typeConverter](Operation *op) {
1198 return typeConverter.isLegal(castfunc::FuncOp(op).getFunctionType());
1199 });
1200 auto opLegalCallback = [&typeConverter](Operation *op) {
1201 return typeConverter.isLegal(op);
1202 };
1203 target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
1204 target.addDynamicallyLegalOpvector::PrintOp(opLegalCallback);
1205 target.addDynamicallyLegalDialectarith::ArithDialect(opLegalCallback);
1206 target.addLegalDialectvector::VectorDialect();
1207
1210
1211
1212 populateFunctionOpInterfaceTypeConversionPatternfunc::FuncOp(
1216
1218 signalPassFailure();
1219 }
1220 };
1221 }
1222
1223
1224
1225
1226
1228 unsigned widestIntSupportedByTarget)
1229 : maxIntWidth(widestIntSupportedByTarget) {
1230 assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) &&
1231 "Only power-of-two integers with are supported");
1232 assert(widestIntSupportedByTarget >= 2 && "Integer type too narrow");
1233
1234
1235 addConversion([](Type ty) -> std::optional { return ty; });
1236
1237
1238 addConversion([this](IntegerType ty) -> std::optional {
1239 unsigned width = ty.getWidth();
1240 if (width <= maxIntWidth)
1241 return ty;
1242
1243
1244 if (width == 2 * maxIntWidth)
1246
1247 return nullptr;
1248 });
1249
1250
1251 addConversion([this](VectorType ty) -> std::optional {
1252 auto intTy = dyn_cast(ty.getElementType());
1253 if (!intTy)
1254 return ty;
1255
1256 unsigned width = intTy.getWidth();
1257 if (width <= maxIntWidth)
1258 return ty;
1259
1260
1261 if (width == 2 * maxIntWidth) {
1262 auto newShape = to_vector(ty.getShape());
1263 newShape.push_back(2);
1266 }
1267
1268 return nullptr;
1269 });
1270
1271
1272 addConversion([this](FunctionType ty) -> std::optional {
1273
1274
1276 if (failed(convertTypes(ty.getInputs(), inputs)))
1277 return nullptr;
1278
1280 if (failed(convertTypes(ty.getResults(), results)))
1281 return nullptr;
1282
1284 });
1285 }
1286
1290
1292
1293 ConvertConstant, ConvertCmpI, ConvertSelect, ConvertVectorPrint,
1294
1295 ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRSI, ConvertShRUI,
1296 ConvertMaxMin<arith::MaxUIOp, arith::CmpIPredicate::ugt>,
1297 ConvertMaxMin<arith::MaxSIOp, arith::CmpIPredicate::sgt>,
1298 ConvertMaxMin<arith::MinUIOp, arith::CmpIPredicate::ult>,
1299 ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>, ConvertSubI,
1300
1301 ConvertBitwiseBinaryarith::AndIOp, ConvertBitwiseBinaryarith::OrIOp,
1302 ConvertBitwiseBinaryarith::XOrIOp,
1303
1304 ConvertExtSI, ConvertExtUI, ConvertTruncI,
1305
1306 ConvertIndexCastIntToIndexarith::IndexCastOp,
1307 ConvertIndexCastIntToIndexarith::IndexCastUIOp,
1308 ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
1309 ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>,
1310 ConvertSIToFP, ConvertUIToFP, ConvertFPToUI, ConvertFPToSI>(
1311 typeConverter, patterns.getContext());
1312 }
static Value insertLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value source, Value dest, int64_t lastOffset)
Inserts the source vector slice into the dest vector at offset lastOffset in the last dimension.
static std::pair< APInt, APInt > getHalves(const APInt &value, unsigned newBitWidth)
Returns N bottom and N top bits from value, where N = newBitWidth.
static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc, Value input)
Performs a vector shape cast to append an x1 dimension.
static std::pair< Value, Value > extractLastDimHalves(ConversionPatternRewriter &rewriter, Location loc, Value input)
Extracts two vector slices from the input whose type is vector<...x2T>, with the first element at off...
static Type reduceInnermostDim(VectorType type)
Returns the type with the last (innermost) dimension reduced to x1.
static Value constructResultVector(ConversionPatternRewriter &rewriter, Location loc, VectorType resultType, ValueRange resultComponents)
Constructs a new vector of type resultType by creating a series of insertions of resultComponents,...
static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter, Location loc, Value input)
static Value extractLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value input, int64_t lastOffset)
Extracts the input vector slice with elements at the last dimension offset by lastOffset.
Attributes are known-constant values of operations.
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Converts integer types that are too wide for the target by splitting them in two halves and thus turn...
WideIntEmulationConverter(unsigned widestIntSupportedByTarget)
void populateArithWideIntEmulationPatterns(const WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns)
Adds patterns to emulate wide Arith and Function ops over integer types into supported ones.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Fraction abs(const Fraction &f)
Include the generated interface declarations.
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.