MLIR: lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp Source File
//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// TOSA canonicalization patterns and folders.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#include <functional>

using namespace mlir;
using namespace mlir::tosa;

//===----------------------------------------------------------------------===//
// Operator Canonicalizers.
//===----------------------------------------------------------------------===//

// Check that the zero point of the tensor and padding operations are aligned.
bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
  // Check that padConst is a constant value and a scalar tensor.
  DenseElementsAttr padConstAttr;
  if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
      (padConstAttr.size() != 1)) {
    return false;
  }

  // Check that the floating-point pad constant is zero.
  if (auto padConstFpAttr =
          mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
    float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
    return padConstVal == 0.0f;
  }

  // Check that the zero point and the pad constant align for the integer
  // (quantized) case.
  if (auto padConstIntAttr =
          mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
    DenseIntElementsAttr zpAttr;
    // Check that zp is a constant value and a scalar tensor.
    if (!matchPattern(zp, m_Constant(&zpAttr)) || (zpAttr.size() != 1)) {
      return false;
    }

    // Check equality.
    int64_t zpVal = (*zpAttr.begin()).getSExtValue();
    int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
    return zpVal == padConstVal;
  }

  // Bail out on unsupported element types.
  return false;
}

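// For example, with an i8 input whose zero point is the splat constant 3, a
// tosa.pad whose pad_const is also splat 3 satisfies this check, while a
// pad_const of 0 would not; for float inputs only a pad_const of 0.0 passes.
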
namespace {
template <typename OpTy>
struct PoolPadFoldAdaptor;

template <>
struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
  using OpTy = tosa::AvgPool2dOp;
  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
    const llvm::ArrayRef<int64_t> kernel = op.getKernel();
    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
      return false;
    return true;
  }
  static bool checkPadConstCompliance(OpTy op, Value padConst) {
    return checkMatchingPadConstAndZp(padConst, op.getInputZp());
  }
  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
                                  Value padInput, ArrayRef<int64_t> newPad) {
    rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
        op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
        op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
        op.getAccType());
  }
};

template <>
struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
  using OpTy = tosa::MaxPool2dOp;
  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
    const llvm::ArrayRef<int64_t> kernel = op.getKernel();
    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
      return false;
    return true;
  }
  static bool checkPadConstCompliance(OpTy, Value padConst) {
    // Check that padConst is a constant value and a scalar tensor.
    DenseElementsAttr padConstAttr;
    if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
        padConstAttr.size() != 1) {
      return false;
    }

    // Pad needs to be the minimum value to be able to merge.
    if (auto padConstFpAttr =
            mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
      const APFloat padConstVal = *padConstFpAttr.begin();
      const APFloat lowestVal =
          APFloat::getLargest(padConstVal.getSemantics(), true);
      return padConstVal == lowestVal;
    } else if (auto padConstIntAttr =
                   mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
      const APInt padConstVal = *padConstIntAttr.begin();
      const unsigned int bitWidth = padConstVal.getBitWidth();
      const APInt lowestVal =
          padConstIntAttr.getElementType().isUnsignedInteger()
              ? APInt::getMinValue(bitWidth)
              : APInt::getSignedMinValue(bitWidth);
      return padConstVal == lowestVal;
    }

    // Bail out on unsupported element types.
    return false;
  }
  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
                                  Value padInput, ArrayRef<int64_t> newPad) {
    rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
        op, op.getType(), padInput, op.getKernel(), op.getStride(),
        rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
  }
};

template <typename OpTy>
struct ConvPadFoldAdaptor {
  static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
    return true;
  }
  static bool checkPadConstCompliance(OpTy op, Value padConst) {
    return checkMatchingPadConstAndZp(padConst, op.getInputZp());
  }
  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
                                  Value padInput, ArrayRef<int64_t> newPad) {
    rewriter.replaceOpWithNewOp<OpTy>(
        op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
        op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
        op.getDilationAttr(), op.getAccType(), op.getLocalBound());
  }
};

// Pattern attempts to fold a tosa.pad operator into a following tensor
// operation like conv2d by merging the padding associated with the pad
// operator into the tensor operation's padding.
template <typename OpTy, typename AdaptorTy>
struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy tensorOp,
                                PatternRewriter &rewriter) const override {
    // Check that the producer is a tosa::PadOp.
    auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
    if (!padOp)
      return rewriter.notifyMatchFailure(tensorOp,
                                         "Producer must be a tosa::PadOp.");

    // Validate that the tensor operation has sane padding.
    const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
    if (tensorOpPad.size() != 4)
      return rewriter.notifyMatchFailure(
          tensorOp, "Tensor operation padding shall have 4 elements.");

    // Validate the tosa::PadOp padding.
    DenseIntElementsAttr padOpPadding;
    if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
      return rewriter.notifyMatchFailure(
          tensorOp,
          "The `padding` input specified on the tosa::PadOp must be constant.");
    }

    // Pad padding contains [N_before, N_after, H_before, H_after, W_before,
    // W_after, C_before, C_after].
    if (padOpPadding.size() != 8)
      return rewriter.notifyMatchFailure(tensorOp,
                                         "Pad padding should have 8 elements.");
    int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
    int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
    int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
    int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
    int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
    int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
    int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
    int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();

    if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
      return rewriter.notifyMatchFailure(
          tensorOp, "Folding padding in N or C dimensions is not supported.");

    // Fold the tosa::PadOp padding into the tensor operation's padding.
    SmallVector<int64_t> foldedPad(4);
    foldedPad[0] = padHBefore + tensorOpPad[0];
    foldedPad[1] = padHAfter + tensorOpPad[1];
    foldedPad[2] = padWBefore + tensorOpPad[2];
    foldedPad[3] = padWAfter + tensorOpPad[3];

    // Check kernel-related restrictions.
    if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
      return rewriter.notifyMatchFailure(
          tensorOp, "Padding size not aligned with kernel restrictions.");
    }

    // Check padding-constant restrictions.
    if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
      return rewriter.notifyMatchFailure(
          tensorOp,
          "Padding constant is not aligned with operator zero-point.");
    }

    // Check the folded padding does not exceed the 8K level limit.
    if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
      return rewriter.notifyMatchFailure(
          tensorOp, "Padding size more than the 8K level limit.");
    }

    // Create the new tensor operation.
    AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
                                   foldedPad);

    return success();
  }
};
} // namespace

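// Schematic example of the FoldPadToTensorOp rewrite (illustrative,
// simplified IR, not taken from the source):
//
//   %p = tosa.pad %x {padding = [0, 0, 1, 1, 1, 1, 0, 0]}
//   %y = tosa.conv2d %p {pad = [0, 0, 0, 0], ...}
// becomes
//   %y = tosa.conv2d %x {pad = [1, 1, 1, 1], ...}
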
void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
                                PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
      context);
}

void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<
      FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
      context);
}

void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
                                ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
      context);
}

struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    Value output = op.getOutput();
    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
    ShapedType outputType = llvm::cast<ShapedType>(output.getType());

    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
      return failure();
    }

    // If the input and output shapes are 1x1 in the spatial dimensions, then
    // this is a no-op.
    ArrayRef<int64_t> outputShape = outputType.getShape();
    if (outputShape[1] != 1 || outputShape[2] != 1) {
      return failure();
    }

    ArrayRef<int64_t> inputShape = inputType.getShape();
    if (inputShape[1] != 1 || inputShape[2] != 1) {
      return failure();
    }

    rewriter.replaceOp(op, input);
    return success();
  }
};

void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<MaxPool2dIsNoOp,
              FoldPadToTensorOp<tosa::MaxPool2dOp,
                                PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
      context);
}

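// For example, a tosa.max_pool2d whose input and output both have shape
// Nx1x1xC pools over a single spatial element per channel, so
// MaxPool2dIsNoOp replaces it with its input.
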
struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ConcatOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getInput1().size() != 1)
      return failure();
    if (op.getInput1().front().getType() != op.getType()) {
      rewriter
          .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                              op.getInput1().front())
          .getResult();
      return success();
    }

    rewriter.replaceOp(op, op.getInput1().front());
    return success();
  }
};

void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<ConcatOptimization>(context);
}

LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
  auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
  if (!notOp)
    return failure();
  rewriter.modifyOpInPlace(op, [&]() {
    op.getOperation()->setOperands(
        {notOp.getInput1(), op.getInput3(), op.getInput2()});
  });
  return success();
}

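// In other words: select(logical_not(p), a, b) is rewritten in place to
// select(p, b, a), dropping the negation by swapping the two data operands.
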
struct ConsolidateTransposeOptimization
    : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Input is also a TransposeOp: transpose(transpose(A)).
    auto innerTranspose =
        transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
    if (!innerTranspose)
      return rewriter.notifyMatchFailure(transposeOp,
                                         "input must be transpose operation");

    const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
    const llvm::ArrayRef<int32_t> innerTransposePerms =
        innerTranspose.getPerms();

    if (transposePerms.size() != innerTransposePerms.size())
      return rewriter.notifyMatchFailure(
          transposeOp,
          "transpose and inner transpose perms sizes must be equal");
    if (transposePerms.empty())
      return rewriter.notifyMatchFailure(
          transposeOp, "transpose perms sizes must be positive");

    // Compose the two permutations into a single transpose.
    SmallVector<int32_t> perms(transposePerms.size());
    for (int i = 0, s = transposePerms.size(); i < s; ++i)
      perms[i] = innerTransposePerms[transposePerms[i]];

    rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));

    return success();
  }
};

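// Worked example of the permutation composition above: if the inner
// transpose has perms = [1, 2, 0] and the outer has perms = [2, 0, 1], then
// perms[i] = inner[outer[i]] yields [0, 1, 2], so the pair of transposes
// collapses to a single identity permutation (which TransposeOp::fold can
// then remove entirely).
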
// Determines the case when tosa.transpose is a tosa.reshape operation.
struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
      return rewriter.notifyMatchFailure(
          op, "Src is from transpose, can compose transposes");

    Value result = op.getResult();
    for (Operation *subop : result.getUsers()) {
      if (isa_and_nonnull<tosa::TransposeOp>(subop))
        return rewriter.notifyMatchFailure(
            op, "Dest is used by transpose, can compose transposes");
    }

    auto input = op.getInput1();
    auto inputTy = llvm::cast<ShapedType>(input.getType());
    if (!inputTy.hasRank())
      return rewriter.notifyMatchFailure(op, "Unranked input.");

    int64_t numDynDims = 0;
    for (int i = 0; i < inputTy.getRank(); ++i)
      if (inputTy.isDynamicDim(i))
        numDynDims++;

    if (numDynDims > 1)
      return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");

    const llvm::ArrayRef<int32_t> permValues = op.getPerms();

    SmallVector<int64_t> nonZeroPerms;
    nonZeroPerms.reserve(permValues.size());
    for (auto idx : permValues) {
      auto sz = inputTy.getDimSize(idx);
      if (sz != 1)
        nonZeroPerms.push_back(idx);
    }

    for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
      if (nonZeroPerms[i - 1] > nonZeroPerms[i])
        return rewriter.notifyMatchFailure(op,
                                           "Transpose changes memory layout.");

    SmallVector<int64_t> newShape;
    newShape.reserve(inputTy.getRank());
    for (int i = 0, s = inputTy.getRank(); i < s; ++i)
      newShape.push_back(inputTy.getDimSize(permValues[i]));

    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, op.getType(), op.getInput1(),
        getTosaConstShape(rewriter, op.getLoc(), newShape));
    return success();
  }
};

void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}

// Determines the case when tosa.clamp is a no-op.
struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
    auto inputElementType = inputType.getElementType();

    if (!inputType.hasStaticShape()) {
      return failure();
    }

    if (isa<FloatType>(inputElementType)) {
      // Unlike integer types, floating-point types can represent infinity.
      auto minClamp =
          llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
      auto maxClamp =
          llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
      bool isMin = minClamp.isNegInfinity();
      bool isMax = maxClamp.isInfinity();

      if (isMin && isMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (inputElementType.isUnsignedInteger()) {
      int64_t minClamp =
          llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getUInt();
      int64_t maxClamp =
          llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getUInt();

      int64_t intMin =
          APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();
      int64_t intMax =
          APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (llvm::isa<IntegerType>(inputElementType)) {
      int64_t minClamp =
          llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
      int64_t maxClamp =
          llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();

      int64_t intMin =
          APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    return failure();
  }
};

// Attempts the following transformation:
//
// For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
// tensor X the following identity holds:
//
// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
//
// subject to the following valid NaN propagation semantics:
// --------------------------------------------
// | OUTER CLAMP | INNER CLAMP | RESULT MODE  |
// |-------------|-------------|--------------|
// | PROPAGATE   | PROPAGATE   | PROPAGATE    |
// | PROPAGATE   | IGNORE      | IGNORE       |
// | IGNORE      | PROPAGATE   | INVALID      |
// | IGNORE      | IGNORE      | IGNORE       |
// |------------------------------------------|

template <typename T>
struct ClampRange {
  ClampRange(const T &start, const T &end) : start(start), end(end) {}
  T start;
  T end;

  // Helper to determine if the intersection of two ranges is non-empty.
  bool intersects(const ClampRange<T> &otherRange) {
    return start < otherRange.end && otherRange.start < end;
  }
};

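// For example, ClampRange(0, 5) and ClampRange(3, 8) intersect because
// 0 < 8 and 3 < 5, while ClampRange(0, 3) and ClampRange(3, 8) do not:
// the check treats ranges that merely touch at an endpoint as
// non-overlapping.
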
struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();

    // Check that the input to this clamp is itself a clamp.
    auto clampOp = dyn_cast_if_present<tosa::ClampOp>(input.getDefiningOp());
    if (!clampOp)
      return failure();

    // Check we have a valid NaN propagation combination.
    const auto opNanMode = op.getNanMode();
    const auto clampNanMode = clampOp.getNanMode();
    if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
      return failure();

    auto maxValAttr = op.getMaxValAttr();
    auto minValAttr = op.getMinValAttr();
    auto clampOpMaxValAttr = clampOp.getMaxValAttr();
    auto clampOpMinValAttr = clampOp.getMinValAttr();

    auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
    if (auto quantType =
            llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
      inputEType = quantType.getStorageType();
    }

    Attribute newMinValAttr, newMaxValAttr;
    if (mlir::isa<FloatType>(inputEType)) {
      auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
      auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
      auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
      auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);

      // Check we have intersecting ranges.
      const auto opMinFloat = floatMinValAttr.getValue();
      const auto opMaxFloat = floatMaxValAttr.getValue();
      const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
      const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
      ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
      ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
                                               clampOpMaxFloat);
      if (!opRangeFloatRange.intersects(clampRangeFloatRange))
        return failure();

      // Run the transformation.
      auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
      auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
      newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
      newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
    } else {
      assert(mlir::isa<IntegerType>(inputEType));
      auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
      auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
      auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
      auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);

      if (inputEType.isUnsignedInteger()) {
        // Check we have intersecting ranges.
        const auto opMinInt = intMinValAttr.getUInt();
        const auto opMaxInt = intMaxValAttr.getUInt();
        const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
        const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
        ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
        ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
                                                     clampOpMaxInt);
        if (!opRangeIntRange.intersects(clampRangeIntRange))
          return failure();

        // Run the transformation.
        auto newMinVal = std::max(opMinInt, clampOpMinInt);
        auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
        newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
        newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
      } else {
        // Check we have intersecting ranges.
        const auto opMinInt = intMinValAttr.getInt();
        const auto opMaxInt = intMaxValAttr.getInt();
        const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
        const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
        ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
        ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
                                                    clampOpMaxInt);
        if (!opRangeIntRange.intersects(clampRangeIntRange))
          return failure();

        // Run the transformation.
        auto newMinVal = std::max(opMinInt, clampOpMinInt);
        auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
        newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
        newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
      }
    }

    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
        op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
        rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
                                                           : opNanMode));
    return success();
  }
};

void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ClampIsNoOp>(context);
  results.add<ClampClampOptimization>(context);
}

struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    Value sliceInput = sliceOp.getInput1();
    auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
    if (!concatOp)
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be concat operation");

    auto inputs = concatOp.getInput1();
    auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
    if (!concatType || !concatType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be a static ranked tensor");
    int32_t axis = concatOp.getAxis();

    DenseElementsAttr startElems;
    DenseElementsAttr sizeElems;

    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");

    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");

    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    // Validate slice on the concatenated axis. Slicing along this axis
    // should span only one of the inputs to the concatenation.
    std::optional<Value> replaceWithSlice;
    for (auto input : inputs) {
      auto inputType = dyn_cast<RankedTensorType>(input.getType());
      if (!inputType || !inputType.hasStaticShape())
        return rewriter.notifyMatchFailure(
            sliceOp, "concat input must be a static ranked tensor");

      if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
                                        inputType.getDimSize(axis)) {
        auto start_op =
            getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
        auto size_op =
            getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
        replaceWithSlice =
            rewriter
                .create<tosa::SliceOp>(sliceOp.getLoc(), sliceOp.getType(),
                                       input, start_op, size_op)
                .getResult();
        break;
      }
      sliceStarts[axis] -= inputType.getDimSize(axis);
    }

    if (!replaceWithSlice)
      return rewriter.notifyMatchFailure(
          sliceOp, "corresponding concat input not found for slice");

    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
    return success();
  }
};

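// Schematic example (illustrative shapes): slicing rows [4, 6) of a
// concat(axis = 0) of a 4xK and a 3xK tensor lies entirely within the
// second operand, so the pattern rewrites it to a slice of rows [0, 2) of
// that operand alone.
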
struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    Value sliceInput = sliceOp.getInput1();

    // Check that the producer of the input is a pad operation.
    auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
    if (!padOp)
      return rewriter.notifyMatchFailure(sliceOp,
                                         "slice input must be a pad operation");

    // Bail out if the pad has other consumers.
    if (!padOp->hasOneUse())
      return rewriter.notifyMatchFailure(sliceOp,
                                         "pad shall have a single consumer");

    // Check that the input and pad types are ranked.
    auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
    auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
    if (!inputTy || !padTy || !inputTy.hasRank())
      return rewriter.notifyMatchFailure(sliceOp,
                                         "slice input must be a ranked tensor");

    // Check that the padding is statically known.
    DenseIntElementsAttr paddingElems;
    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
      return rewriter.notifyMatchFailure(
          sliceOp,
          "`padding` input specified on the tosa::PadOp must be constant.");
    }
    llvm::SmallVector<int64_t> padPaddings =
        llvm::to_vector(paddingElems.getValues<int64_t>());

    // Check that the slice start and size are statically known.
    DenseElementsAttr startElems;
    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");
    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());

    DenseElementsAttr sizeElems;
    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    // Sliced dimensions must be statically known.
    const int64_t rank = inputTy.getRank();
    if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
          const bool isDimDynamic = inputTy.isDynamicDim(i);
          const bool isDimSliced =
              (sliceStarts[i] != 0) || (sliceSizes[i] != -1);

          return isDimDynamic && isDimSliced;
        })) {
      return rewriter.notifyMatchFailure(
          sliceOp, "axis that are sliced shall be statically known.");
    }

    // Update the parameters based on the original pad and slice.
    llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
    llvm::SmallVector<int64_t> newPadShape(rank, ShapedType::kDynamic);
    llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
    bool updated = false;

    for (int64_t i = 0; i < rank; ++i) {
      const int64_t padLo = padPaddings[i * 2];
      const int64_t padHi = padPaddings[i * 2 + 1];
      const int64_t sliceStart = sliceStarts[i];
      const int64_t sliceSize = sliceSizes[i];
      const int64_t sliceEnd = sliceStart + sliceSize;

      // If the dimension is dynamic, keep the existing pad and slice values.
      if (inputTy.isDynamicDim(i)) {
        newPadPaddings[i * 2] = padLo;
        newPadPaddings[i * 2 + 1] = padHi;
        newSliceStarts[i] = sliceStart;
        continue;
      }

      // The dimension is statically known.
      const int64_t dimSize = inputTy.getShape()[i];
      const int64_t dimTotal = padLo + dimSize + padHi;

      // The slice must stay within the padded bounds.
      if (sliceStart < 0 || sliceEnd > dimTotal)
        return rewriter.notifyMatchFailure(sliceOp, "slice is out-of-bounds");

      // Compute the updated slice start.
      const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
      newSliceStarts[i] = newSliceStart;
      updated |= newSliceStart != sliceStart;

      // Compute the updated padding values.
      const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
      const int64_t newPadHi =
          std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
      newPadPaddings[i * 2] = newPadLo;
      newPadPaddings[i * 2 + 1] = newPadHi;
      updated |= (newPadLo != padLo) || (newPadHi != padHi);

      // Compute the new shape of the padded dimension.
      newPadShape[i] =
          newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
    }

    // Nothing to rewrite if neither the pad nor the slice changed.
    if (!updated)
      return rewriter.notifyMatchFailure(
          sliceOp, "terminate condition; nothing to rewrite");

    // Create a new pad op with the updated padding values.
    auto newPaddingsOp =
        getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
    auto newPadTy =
        RankedTensorType::get(newPadShape, inputTy.getElementType());
    auto newPadOp = rewriter.create<tosa::PadOp>(
        padOp.getLoc(), newPadTy, padOp.getInput1(), newPaddingsOp,
        padOp.getPadConst());

    // Update the slice to use the new pad and start values.
    auto newStartOp =
        getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
    rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
                                               newPadOp.getResult(), newStartOp,
                                               sliceOp.getSize());

    return success();
  }
};

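// Worked example for a single dimension of the arithmetic above: with
// padLo = 2, dimSize = 4 and a slice of start 1, size 4, the rewrite gives
// newSliceStart = max(1 - 2, 0) = 0, newPadLo = max(2 - 1, 0) = 1 and
// newPadHi = max(5 - (2 + 4), 0) = 0, so the new pad produces 1 + 4 = 5
// elements and the slice reads the same 4 values as before.
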
// Update the size operand of tosa.slice if the size has dynamic dims but
// the corresponding output dim is static.
struct SliceDynamicSizeCanonicalization
    : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    ShapedType resultType = cast<ShapedType>(sliceOp.getType());

    ElementsAttr sizeElems;
    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");
    }

    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    bool replaceSliceSize{false};

    // When a size is -1 but the corresponding output dimension is static,
    // replace the -1 with the static dimension size.
    for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
      if (size == -1 && !resultType.isDynamicDim(index)) {
        sliceSizes[index] = resultType.getDimSize(index);
        replaceSliceSize = true;
      }
    }

    if (!replaceSliceSize) {
      return rewriter.notifyMatchFailure(
          sliceOp, "no dimension of size of slice is dynamic that resolves "
                   "to static output shape");
    }

    auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
    auto newSliceOp = rewriter.create<tosa::SliceOp>(
        sliceOp.getLoc(), sliceOp.getType(), sliceOp.getInput1(),
        sliceOp.getStart(), size_op);

    rewriter.replaceOp(sliceOp, newSliceOp.getResult());
    return success();
  }
};

void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ConcatSliceOptimization, PadSliceOptimization,
              SliceDynamicSizeCanonicalization>(context);
}

//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//

template <typename IntFolder, typename FloatFolder>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                               RankedTensorType returnTy) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
    auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
    if (lETy != rETy)
      return {};

    if (llvm::isa<IntegerType>(lETy)) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();
      auto result = IntFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }

    if (llvm::isa<FloatType>(lETy)) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      auto result = FloatFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }
  }

  return {};
}

static bool isSplatZero(Type elemType, DenseElementsAttr val) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
  if (llvm::isa<IntegerType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
  return false;
}

static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() &&
           val.getSplatValue<APFloat>().isExactlyValue(1.0);
  if (llvm::isa<IntegerType>(elemType)) {
    const int64_t shifted = 1LL << shift;
    return val && val.isSplat() &&
           val.getSplatValue<APInt>().getSExtValue() == shifted;
  }
  return false;
}

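// Note that for integer types isSplatOne accounts for the tosa.mul shift:
// with shift = 6, "one" is the fixed-point value 1 << 6 = 64.
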
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();
  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
    return getInput2();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
                                                            resultTy);
}

OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
      !outputTy.hasStaticShape())
    return {};

  if (inputTy.getDimSize(getAxis()) == 1)
    return DenseElementsAttr::get(outputTy, 0);

  return {};
}

OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};
  if (lhsTy != rhsTy)
    return {};

  // IntDivOp inputs must be integer type; no need to check for quantized type.
  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  if (lhsAttr && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr;
  }

  if (rhsAttr && rhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        rhsAttr.getSplatValue<APInt>().isOne())
      return getInput1();
  }

  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
      llvm::isa<IntegerType>(resultETy)) {
    APInt l = lhsAttr.getSplatValue<APInt>();
    APInt r = rhsAttr.getSplatValue<APInt>();
    if (!r.isZero()) {
      APInt result = l.sdiv(r);
      return DenseElementsAttr::get(resultTy, result);
    }
  }

  return {};
}

namespace {
// Calculate lhs * rhs >> shift according to the TOSA spec. Returns
// std::nullopt if the result is not in the range of int32_t when shift > 0.
std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
                            unsigned bitwidth) {
  APInt result = lhs.sext(64) * rhs.sext(64);

  if (shift > 0) {
    auto round = APInt(64, 1) << (shift - 1);
    result += round;
    result.ashrInPlace(shift);
    // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
    if (!(result.getSExtValue() >= INT32_MIN &&
          result.getSExtValue() <= INT32_MAX)) {
      // REQUIRE failed.
      return std::nullopt;
    }
  }

  return result.trunc(bitwidth);
}

DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                  RankedTensorType ty, int32_t shift) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    if (llvm::isa<IntegerType>(ty.getElementType())) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();

      if (shift == 0) {
        return DenseElementsAttr::get(ty, l * r);
      }

      auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
      const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
      if (!result)
        return {};
      return DenseElementsAttr::get(ty, result.value());
    }

    if (llvm::isa<FloatType>(ty.getElementType())) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      APFloat result = l * r;
      return DenseElementsAttr::get(ty, result);
    }
  }

  return {};
}
} // namespace

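// Worked example of mulInt above: for lhs = 3, rhs = 5, shift = 1, the
// 64-bit product is 15; adding the rounding term 1 << 0 gives 16, and the
// arithmetic shift right by 1 yields 8, which is then truncated back to the
// operand bitwidth.
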
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // The result right-shift only applies to the i32 data type. For
  // simplification, synthesize a zero shift for other data types.
  int32_t shift = 0;
  if (resultETy.isInteger(32)) {
    ElementsAttr shift_elem;
    if (getShift().getImpl()) {
      if (!matchPattern(getShift(), m_Constant(&shift_elem)))
        // Cannot be folded when the shift value is unknown.
        return {};
      shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
    }
  }

  if (rhsTy == resultTy) {
    if (isSplatZero(resultETy, lhsAttr))
      return lhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, lhsAttr, shift))
      return rhs;
  }
  if (lhsTy == resultTy) {
    if (isSplatZero(resultETy, rhsAttr))
      return rhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, rhsAttr, shift))
      return lhs;
  }

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
}

OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
                                                              resultTy);
}

namespace {
template <typename Cmp>
struct ComparisonFold {
  ComparisonFold() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, Cmp()(l, r));
  }

  APInt operator()(const APFloat &l, const APFloat &r) {
    return APInt(1, Cmp()(l, r));
  }
};

struct APIntFoldGreater {
  APIntFoldGreater() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sgt(r));
  }
};

struct APIntFoldGreaterEqual {
  APIntFoldGreaterEqual() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sge(r));
  }
};
} // namespace

OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}

OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreaterEqual,
                      ComparisonFold<std::greater_equal<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}

OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  Value lhs = getInput1();
  Value rhs = getInput2();
  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());

  // If we have a constant integer tensor compared against itself, the result
  // is all-true. Float comparison is excluded because NaN != NaN.
  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
      resultTy.hasStaticShape() && lhs == rhs) {
    return DenseElementsAttr::get(resultTy, true);
  }

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
                                                              resultTy);
}

OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  if (getInput().getType() == getType())
    return getInput();

  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  if (operand.isSplat()) {
    if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
      bool overflow;
      auto splatVal = operand.getSplatValue<APFloat>();
      auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
      splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
                       &overflow);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
      splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
                                llvm::RoundingMode::NearestTiesToEven);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
      auto intVal = APSInt(
          llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
      auto floatVal = operand.getSplatValue<APFloat>();
      bool exact;
      floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
                                &exact);
      return SplatElementsAttr::get(outTy, intVal);
    }

    if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      bool trunc =
          inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
      auto intVal = operand.getSplatValue<APInt>();
      auto bitwidth = outETy.getIntOrFloatBitWidth();

      if (trunc) {
        intVal = intVal.trunc(bitwidth);
      } else if (unsignIn) {
        intVal = intVal.zext(bitwidth);
      } else {
        intVal = intVal.sext(bitwidth);
      }

      return SplatElementsAttr::get(outTy, intVal);
    }
  }

  return {};
}

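// For example, casting a splat f32 constant 3.5 to i32 rounds with
// NearestTiesToEven and folds to a splat i32 constant 4, while 2.5 would
// fold to 2.
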
OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }

OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }

#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                    \
      return {};                                                               \
    if (inputTy != getType())                                                  \
      return {};                                                               \
    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
      return getInput();                                                       \
    return {};                                                                 \
  }

REDUCE_FOLDER(ReduceAllOp)
REDUCE_FOLDER(ReduceAnyOp)
REDUCE_FOLDER(ReduceMaxOp)
REDUCE_FOLDER(ReduceMinOp)
REDUCE_FOLDER(ReduceProductOp)
REDUCE_FOLDER(ReduceSumOp)
#undef REDUCE_FOLDER

OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  // Fold when the input and output types are the same. This is only safe when
  // there is at most 1 dynamic dimension; for 2 or more dynamic dimensions,
  // the reshape may still be productive.
  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
    return getInput1();

  // reshape(reshape(x)) -> reshape(x)
  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
          getInput1().getDefiningOp())) {
    getInput1Mutable().assign(reshapeOp.getInput1());
    return getResult();
  }

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!inputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  // reshape(const(x)) -> const(reshape-attr(x))
  if (auto operand =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    // Constants must have static shape.
    if (!outputTy.hasStaticShape())
      return {};

    // Okay to duplicate splat constants.
    if (operand.isSplat())
      return SplatElementsAttr::get(outputTy,
                                    operand.getSplatValue<Attribute>());

    // Don't duplicate other constants.
    if (!getInput1().hasOneUse())
      return {};

    llvm::SmallVector<int64_t> shapeVec;
    if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeVec))
      return {};

    return operand.reshape(
        llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
  }

  return {};
}

OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
  // If the pad is all zeros we can fold this operation away.
  if (adaptor.getPadding() && getInput1().getType() == getType()) {
    auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
    if (densePad && densePad.isSplat() &&
        densePad.getSplatValue<APInt>().isZero()) {
      return getInput1();
    }
  }

  return {};
}

// Fold away cases where a tosa.resize operation returns a copy
// of the input image.
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
  auto scaleAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
  auto offsetAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
  auto borderAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
  if (!scaleAttr || !offsetAttr || !borderAttr) {
    return {};
  }

  auto scale = tosa::convertFromIntAttr(scaleAttr, /* rank = */ 4);
  auto offset = tosa::convertFromIntAttr(offsetAttr, /* rank = */ 2);
  auto border = tosa::convertFromIntAttr(borderAttr, /* rank = */ 2);
  if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
    return {};
  }

  // Check unit scaling.
  if (scale[0] != scale[1] || scale[2] != scale[3]) {
    return {};
  }

  // There should be no offset.
  if (offset[0] != 0 || offset[1] != 0) {
    return {};
  }

  // There should be no border.
  if (border[0] != 0 || border[1] != 0) {
    return {};
  }

  auto input = getInput();
  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
  auto resultTy = llvm::cast<RankedTensorType>(getType());
  if (inputTy != resultTy)
    return {};

  return input;
}

OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
  auto operand = getInput1();
  auto operandTy = llvm::cast<ShapedType>(operand.getType());
  auto axis = getAxis();
  auto operandAttr =
      llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
  if (operandAttr)
    return operandAttr;

  // If the dim-length is 1, tosa.reverse is a no-op.
  if (operandTy.hasRank() &&
      (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
    return operand;

  return {};
}

OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  if (inputTy == outputTy && inputTy.hasStaticShape())
    return getInput1();

  if (!adaptor.getInput1())
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
      !outputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
  if (operand.isSplat() && outputTy.hasStaticShape()) {
    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
  }

  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
      outputTy.getNumElements() == 1) {
    DenseElementsAttr startElems;
    if (!matchPattern(getStart(), m_Constant(&startElems)))
      return {};

    llvm::SmallVector<uint64_t> indices =
        llvm::to_vector(startElems.getValues<uint64_t>());
    auto value = operand.getValues<Attribute>()[indices];
    return SplatElementsAttr::get(outputTy, value);
  }

  return {};
}

OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
  if (getInput2() == getInput3())
    return getInput2();

  auto predicate =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
  if (!predicate)
    return {};

  if (!predicate.isSplat())
    return {};
  return predicate.getSplatValue<APInt>().getBoolValue() ? getInput2()
                                                         : getInput3();
}

OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
  if (getType() == getInput1().getType()) {
    if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
            adaptor.getMultiples())) {
      if (multiples.isSplat() &&
          multiples.getSplatValue<APInt>().getSExtValue() == 1)
        return getInput1();
      if (auto int_array_attr =
              llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
        if (llvm::all_of(int_array_attr.getValues<APInt>(),
                         [](APInt v) { return v.getSExtValue() == 1; }))
          return getInput1();
      }
    }
  }
  return {};
}

OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::cast<ShapedType>(getType());

  // Transposing splat values just means reshaping.
  if (auto input =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    if (input.isSplat() && resultTy.hasStaticShape() &&
        input.getType().getElementType() == resultTy.getElementType())
      return input.reshape(resultTy);
  }

  // Bail out unless this is the identity transpose.
  const llvm::ArrayRef<int32_t> perms = getPerms();

  if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
    return {};

  return getInput1();
}

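// For example, a transpose with perms = [0, 1, 2] matches the sequence
// 0..rank-1 and folds to its input, while any other permutation is left to
// the rewrite patterns above.
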
OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise log(exp(x)) = x
  if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise exp(log(x)) = x
  if (auto op = input.getDefiningOp<tosa::LogOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
  // Element-wise negate(negate(x)) = x, but only if all zero points are
  // constant 0.
  auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
  if (!definingOp) {
    // The defining op of input1 is not a negate; cannot fold.
    return {};
  }

  if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
      failed(maybeIZp) || *maybeIZp != 0) {
    // The input1 zero point is not constant 0; cannot fold.
    return {};
  }
  if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
      failed(maybeOZp) || *maybeOZp != 0) {
    // The output zero point is not constant 0; cannot fold.
    return {};
  }
  if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
      failed(maybeIZp) || *maybeIZp != 0) {
    // The inner negate's input1 zero point is not constant 0; cannot fold.
    return {};
  }
  if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
      failed(maybeOZp) || *maybeOZp != 0) {
    // The inner negate's output zero point is not constant 0; cannot fold.
    return {};
  }

  return definingOp.getInput1();
}

OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise abs(abs(x)) = abs(x)
  if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
    return input;
  }

  return {};
}

OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  // Fold consecutive concats on the same axis into a single op.
  // Keep track of the operands so we are able to construct a new concat
  // later. Conservatively assume that we double the number of operands when
  // folding.
  SmallVector<Value, 8> concatOperands;
  concatOperands.reserve(2 * getNumOperands());

  // Find all operands that are foldable concats.
  bool foundFoldableConcat = false;
  for (Value operand : getOperands()) {
    concatOperands.emplace_back(operand);

    auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
    if (!producer)
      continue;

    // Not foldable if the axes are not the same.
    if (getAxis() != producer.getAxis())
      continue;

    // Replace the original operand with all incoming operands.
    foundFoldableConcat = true;
    concatOperands.pop_back();
    llvm::append_range(concatOperands, producer->getOperands());
  }

  if (!foundFoldableConcat)
    return {};

  getOperation()->setOperands(concatOperands);
  return getResult();
}

OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
  auto input = adaptor.getInput1();

  auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
  // Fold splat inputs only.
  if (!inputAttr || !inputAttr.isSplat())
    return {};

  auto shapeType = llvm::cast<ShapedType>(getType());
  if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
    auto floatVal = inputAttr.getSplatValue<APFloat>();
    return DenseElementsAttr::get(shapeType,
                                  ReciprocalOp::calcOneElement(floatVal));
  }

  return {};
}