MLIR: lib/Conversion/TosaToLinalg/TosaToLinalg.cpp Source File
//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg dialect.
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"

#include <numeric>
#include <type_traits>

using namespace mlir;
using namespace mlir::tosa;
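// Helper to materialize the semantically correct compare-and-select ops for a
// binary operation with a TOSA NaN-propagation mode. With "PROPAGATE"
// semantics the computed result is returned unchanged; with "IGNORE"
// semantics, whenever one operand is NaN the other operand is returned in
// place of the computed result.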
template <typename OpTy>
static Value materializeBinaryNanCheckIfRequired(OpTy op,
                                                 PatternRewriter &rewriter,
                                                 Value lhs, Value rhs,
                                                 Value result) {
  // NaN propagation has no meaning for non floating-point types.
  if (!isa<FloatType>(getElementTypeOrSelf(lhs)))
    return result;

  auto nanMode = op.getNanMode();
  if (nanMode == "PROPAGATE")
    return result;

  // "IGNORE" semantics: if an operand is NaN, select the other operand in
  // place of the computed result.
  Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
      op.getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
  Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
      op.getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
  Value rhsOrResult =
      rewriter.create<arith::SelectOp>(op.getLoc(), lhsIsNaN, rhs, result);
  return rewriter.create<arith::SelectOp>(op.getLoc(), rhsIsNaN, lhs,
                                          rhsOrResult);
}
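// For example, under "IGNORE" semantics a floating-point maximum expands to
// (illustrative IR sketch, not part of the original source):
//
//   %max   = arith.maximumf %a, %b : f32
//   %a_nan = arith.cmpf uno, %a, %a : f32
//   %b_nan = arith.cmpf uno, %b, %b : f32
//   %t     = arith.select %a_nan, %b, %max : f32
//   %r     = arith.select %b_nan, %a, %t : f32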

static Value createLinalgBodyCalculationForElementwiseOp(
    Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
    ConversionPatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      cast<ShapedType>(op->getOperand(0).getType()).getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::AbsFOp>(loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
    auto zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementTy));
    auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
    return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
  }

  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::SubIOp>(loc, resultTypes, args);

  // tosa::IntDivOp
  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
  }

  // tosa::MulOp
  if (isa<tosa::MulOp>(op)) {
    auto shiftVal = cast<tosa::MulOp>(op).getShift();
    DenseElementsAttr shiftElem;
    if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
      (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
      return nullptr;
    }

    int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();

    if (isa<FloatType>(elementTy)) {
      if (shift != 0) {
        (void)rewriter.notifyMatchFailure(op,
                                          "Cannot have shift value for float");
        return nullptr;
      }
      return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
    }

    if (isa<IntegerType>(elementTy)) {
      Value a = args[0];
      Value b = args[1];

      if (shift > 0) {
        auto shiftConst =
            rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
        if (!a.getType().isInteger(32))
          a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);

        if (!b.getType().isInteger(32))
          b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);

        auto result = rewriter.create<tosa::ApplyScaleOp>(
            loc, rewriter.getI32Type(), a, b, shiftConst,
            rewriter.getStringAttr("SINGLE_ROUND"));

        if (elementTy.isInteger(32))
          return result;

        return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
      }

      int aWidth = a.getType().getIntOrFloatBitWidth();
      int bWidth = b.getType().getIntOrFloatBitWidth();
      int cWidth = resultTypes[0].getIntOrFloatBitWidth();

      if (aWidth < cWidth)
        a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
      if (bWidth < cWidth)
        b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);

      return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
    }
  }

  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op)) {
    auto negate = cast<tosa::NegateOp>(op);

    FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
    if (failed(maybeInZp)) {
      (void)rewriter.notifyMatchFailure(
          op, "input1 zero point cannot be statically determined");
      return nullptr;
    }

    FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
    if (failed(maybeOutZp)) {
      (void)rewriter.notifyMatchFailure(
          op, "output zero point cannot be statically determined");
      return nullptr;
    }

    int64_t inZp = *maybeInZp;
    int64_t outZp = *maybeOutZp;

    if (isa<FloatType>(elementTy))
      return rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);

    if (isa<IntegerType>(elementTy)) {
      if (!inZp && !outZp) {
        auto constant = rewriter.create<arith::ConstantOp>(
            loc, IntegerAttr::get(elementTy, 0));
        return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
                                              args[0]);
      }

      // Compute the maximum value that can occur in the intermediate buffer.
      const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
      const int64_t zpAdd = inZp + outZp;
      const int64_t maxValue =
          APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
          std::abs(zpAdd) + 1;

      // Convert that maximum value into the maximum bitwidth needed to
      // represent it. We assume 48-bit numbers may be supported further in
      // the pipeline.
      int intermediateBitWidth = 64;
      if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
        intermediateBitWidth = 16;
      } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
        intermediateBitWidth = 32;
      } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
        intermediateBitWidth = 48;
      }

      Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
      Value zpAddValue = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

      // The negation can be applied by doing:
      //   outputValue = inZp + outZp - inputValue
      auto ext =
          rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
      auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);

      // Clamp to the negation range.
      Value min = rewriter.create<arith::ConstantIntOp>(
          loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
          intermediateType);
      Value max = rewriter.create<arith::ConstantIntOp>(
          loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
          intermediateType);
      auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);

      // Truncate to the final value.
      return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
    }
  }
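  // Worked example with hypothetical values: for an i8 input with inZp = 10
  // and outZp = -5, zpAdd = 5 and maxValue = 127 + |5| + 1 = 133 <= 32767, so
  // a 16-bit intermediate suffices; the emitted code then computes
  // trunc(clamp(5 - ext(x), -128, 127)).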

  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXorOp
  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);
  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
    auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
    auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
    if (!round) {
      return result;
    }

    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one = rewriter.create<arith::ConstantOp>(
        loc, IntegerAttr::get(elementTy, 1));
    auto zero = rewriter.create<arith::ConstantOp>(
        loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

    // Checking that input2 != 0.
    auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[1], zero);

    // Checking for the last bit of the original value that was shifted out.
    auto subtract =
        rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted =
        rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
            ->getResults();
    auto truncated =
        rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
    auto isInputOdd =
        rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);

    auto shouldRound = rewriter.create<arith::AndIOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
  }
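  // Worked example: x = 7, shift = 1, round = true. The plain shift gives
  // 7 >> 1 = 3; the last bit shifted out is (7 >> 0) & 1 = 1 and shift > 0,
  // so 1 is added back, yielding 4 (round-half-up on the discarded bit).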

  // tosa::ClzOp
  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
  }

  // tosa::LogicalAndOp
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::LogicalNotOp
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementTy, 1));
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOrOp
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::LogicalXorOp
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::SinOp
  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::SinOp>(loc, resultTypes, args);

  // tosa::CosOp
  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::CosOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::ErfOp
  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);

  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
                                          args[0], args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                          args[0], args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
                                          args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
                                          args[0], args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
                                          args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                          args[0], args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
    if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
      return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
  }

  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
    auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
                                               rewriter, args[0], args[1], max);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
    auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
                                               rewriter, args[0], args[1], min);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::CeilOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::FloorOp>(loc, resultTypes, args);

  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
    bool losesInfo = false;
    APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
    minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    auto min = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
    auto max = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
    Value result = clampFloatHelper(loc, args[0], min, max, rewriter);

    auto clampOp = llvm::cast<tosa::ClampOp>(op);
    const auto nanMode = clampOp.getNanMode();

    // NaN propagation has no meaning for non floating-point types.
    if (!isa<FloatType>(elementTy))
      return result;

    // In the case of "PROPAGATE" semantics no compare and selection is
    // required.
    if (nanMode == "PROPAGATE")
      return result;

    // In the case of "IGNORE" semantics the TOSA spec requires that a NaN
    // input produce the clamp's min value, so materialize a NaN check and
    // select between min and the clamped result.
    Value isNaN = rewriter.create<arith::CmpFOp>(
        op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);

    return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, min, result);
  }

  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
    auto intTy = cast<IntegerType>(elementTy);
    int64_t min =
        cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
    int64_t max =
        cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();

    int64_t minRepresentable = std::numeric_limits<int64_t>::min();
    int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
    if (intTy.isUnsignedInteger()) {
      minRepresentable = 0;
      if (intTy.getIntOrFloatBitWidth() <= 63) {
        maxRepresentable =
            (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
                .getZExtValue();
      }
    } else if (intTy.getIntOrFloatBitWidth() <= 64) {
      // Ensure that min & max fit into signed n-bit constants.
      minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
      maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
    }

    // Clip the clamp bounds to the range representable in the element type.
    min = std::max(min, minRepresentable);
    max = std::max(max, minRepresentable);
    min = std::min(min, maxRepresentable);
    max = std::min(max, maxRepresentable);

    auto minVal = rewriter.create<arith::ConstantIntOp>(
        loc, min, intTy.getIntOrFloatBitWidth());
    auto maxVal = rewriter.create<arith::ConstantIntOp>(
        loc, max, intTy.getIntOrFloatBitWidth());
    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
                          intTy.isUnsignedInteger());
  }

  // tosa::SigmoidOp
  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
    auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
  }

  // tosa::CastOp
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
      (void)rewriter.notifyMatchFailure(op, "unsupported type");
      return nullptr;
    }

    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (srcTy == dstTy)
      return args.front();

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
                                            std::nullopt);

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
      return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
                                              std::nullopt);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
                                              std::nullopt);

    if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
                                             std::nullopt);

    // Unsigned integers need an unrealized cast so that they can be passed
    // to UIToFP.
    if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
      auto unrealizedCast =
          rewriter
              .create<UnrealizedConversionCastOp>(
                  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
                  args[0])
              .getResult(0);
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
                                              unrealizedCast);
    }

    // All other si-to-fp conversions should be handled by SIToFP.
    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
                                              std::nullopt);

    // Casting to boolean, floats are compared against zero.
    if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(srcTy, 0.0));
      return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
                                            args.front(), zero);
    }
    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);

      const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();

      // Check whether neither int min nor int max can be represented in the
      // input floating-point type due to a too-short exponent range.
      if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
          APFloat::semanticsMaxExponent(fltSemantics)) {
        // Use cmp + select to replace infinites by int min / int max. Other
        // integer values can be represented in the integer space.
        auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
        auto posInf = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
                                       APFloat::getInf(fltSemantics)));
        auto negInf = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(
                     getElementTypeOrSelf(srcTy),
                     APFloat::getInf(fltSemantics, /*Negative=*/true)));
        auto overflow = rewriter.create<arith::CmpFOp>(
            loc, arith::CmpFPredicate::UEQ, rounded, posInf);
        auto underflow = rewriter.create<arith::CmpFOp>(
            loc, arith::CmpFPredicate::UEQ, rounded, negInf);
        auto intMin = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIntegerAttr(
                     getElementTypeOrSelf(dstTy),
                     APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
        auto intMax = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIntegerAttr(
                     getElementTypeOrSelf(dstTy),
                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
        auto maxClamped =
            rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
        return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
                                                maxClamped);
      }

      auto intMinFP = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   getElementTypeOrSelf(srcTy),
                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      // Check whether the mantissa has enough bits to represent int max.
      if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
          dstTy.getIntOrFloatBitWidth() - 1) {
        // Int min can also be represented since it is a power of two and
        // thus consists of a single leading bit, so we can clamp the input
        // entirely in the floating-point domain.
        auto intMaxFP = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(
                     getElementTypeOrSelf(srcTy),
                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                         .getSExtValue()));

        Value clamped =
            clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
        return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
      }

      // Due to the earlier check we know the exponent range is big enough to
      // represent int min, and therefore int max + 1 as well (it is int min
      // with a positive sign). Clamp the lower limit in float and the upper
      // limit with a cmp + select on the int max + 1 boundary.
      auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   getElementTypeOrSelf(srcTy),
                   static_cast<double>(
                       APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                           .getSExtValue()) +
                       1.0f));

      auto intMax = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIntegerAttr(
                   getElementTypeOrSelf(dstTy),
                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
      auto minClampedFP =
          rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
      auto minClamped =
          rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
      auto overflow = rewriter.create<arith::CmpFOp>(
          loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
      return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
                                              minClamped);
    }

    // Casting to boolean, integers need to only be checked as not-equal to
    // zero.
    if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantIntOp>(
          loc, 0, srcTy.getIntOrFloatBitWidth());
      return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                            args.front(), zero);
    }

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
                                             std::nullopt);

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
      return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}
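// Illustrative sketch (not from the original source): lowering a rank-1
// tosa.add through this helper produces a linalg.generic such as
//
//   #map = affine_map<(d0) -> (d0)>
//   %init = tensor.empty() : tensor<4xf32>
//   %sum = linalg.generic
//            {indexing_maps = [#map, #map, #map],
//             iterator_types = ["parallel"]}
//            ins(%lhs, %rhs : tensor<4xf32>, tensor<4xf32>)
//            outs(%init : tensor<4xf32>) {
//   ^bb0(%a: f32, %b: f32, %out: f32):
//     %0 = arith.addf %a, %b : f32
//     linalg.yield %0 : f32
//   } -> tensor<4xf32>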

using IndexPool = DenseMap<int64_t, Value>;

// Emit an 'arith.constant' op for the given index if it has not been created
// yet, or return the existing constant. This avoids materializing redundant
// constants and eases readability of the emitted code.
static Value createIndex(PatternRewriter &rewriter, Location loc,
                         IndexPool &indexPool, int64_t index) {
  auto [it, inserted] = indexPool.try_emplace(index);
  if (inserted)
    it->second =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
  return it->second;
}

static Value getTensorDim(PatternRewriter &rewriter, Location loc,
                          IndexPool &indexPool, Value tensor, int64_t index) {
  auto indexValue = createIndex(rewriter, loc, indexPool, index);
  return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
}

static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value tensor,
                                       int64_t index) {
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
  if (shapedType.isDynamicDim(index))
    return getTensorDim(rewriter, loc, indexPool, tensor, index);
  return rewriter.getIndexAttr(shapedType.getDimSize(index));
}

static bool operandsAndResultsRanked(Operation *operation) {
  auto isRanked = [](Value value) {
    return isa<RankedTensorType>(value.getType());
  };
  return llvm::all_of(operation->getOperands(), isRanked) &&
         llvm::all_of(operation->getResults(), isRanked);
}
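
// Compute the target size of a broadcastable dimension together with the
// "master operand", i.e. the operand that determines that size, if any:
// a static size > 1 wins outright; otherwise the size is the runtime maximum
// of the dynamic sizes; if no operand is dynamic, all sizes are 1.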
static std::pair<OpFoldResult, Value>
computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
                  ValueRange operands, int64_t dim) {
  // If any operand has a static size greater than 1 in this dimension, that
  // is the target size.
  for (auto operand : operands) {
    auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
    if (!ShapedType::isDynamic(size) && size > 1)
      return {rewriter.getIndexAttr(size), operand};
  }

  // Filter operands with a dynamic dimension.
  auto operandsWithDynamicDim =
      llvm::filter_to_vector(operands, [&](Value operand) {
        return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
      });

  // If no operand has a dynamic dimension, it means all sizes were 1.
  if (operandsWithDynamicDim.empty())
    return {rewriter.getIndexAttr(1), operands.front()};

  // Emit code that computes the runtime size for this dimension. If there is
  // only one operand with a dynamic dimension, it is the target size.
  auto targetSize =
      getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
  if (operandsWithDynamicDim.size() == 1)
    return {targetSize, operandsWithDynamicDim[0]};

  // Calculate the maximum size among all dynamic dimensions.
  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
    auto nextSize =
        getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
    targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
  }
  return {targetSize, nullptr};
}

// Compute the target shape for all dimensions of the given operands.
static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
computeTargetShape(PatternRewriter &rewriter, Location loc,
                   IndexPool &indexPool, ValueRange operands) {
  assert(!operands.empty());
  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
  SmallVector<OpFoldResult> targetShape;
  SmallVector<Value> masterOperands;
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    auto [targetSize, masterOperand] =
        computeTargetSize(rewriter, loc, indexPool, operands, dim);
    targetShape.push_back(targetSize);
    masterOperands.push_back(masterOperand);
  }
  return {targetShape, masterOperands};
}

static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value operand,
                                       int64_t dim, OpFoldResult targetSize,
                                       Value masterOperand) {
  // Nothing to do if this is a static dimension.
  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
  if (!rankedTensorType.isDynamicDim(dim))
    return operand;

  // If the target size for this dimension was directly taken from this
  // operand, it is already of the right size and no broadcast is needed.
  if (operand == masterOperand)
    return operand;

  // Affine maps for 'linalg.generic' op.
  auto rank = rankedTensorType.getRank();
  SmallVector<AffineExpr> affineExprs;
  for (auto index : llvm::seq<int64_t>(0, rank)) {
    auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
                                   : rewriter.getAffineDimExpr(index);
    affineExprs.push_back(affineExpr);
  }
  auto broadcastAffineMap =
      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};

  // Check whether a broadcast is necessary at runtime.
  auto one = createIndex(rewriter, loc, indexPool, 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, runtimeSize, one);

  // Emit 'then' region of 'scf.if'.
  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
    // Emit 'tensor.empty' op.
    SmallVector<OpFoldResult> outputTensorShape;
    for (auto index : llvm::seq<int64_t>(0, rank)) {
      auto size = index == dim ? targetSize
                               : getOrFoldTensorDim(rewriter, loc, indexPool,
                                                    operand, index);
      outputTensorShape.push_back(size);
    }
    Value outputTensor = opBuilder.create<tensor::EmptyOp>(
        loc, outputTensorShape, rankedTensorType.getElementType());

    // Emit 'linalg.generic' op.
    auto resultTensor =
        opBuilder
            .create<linalg::GenericOp>(
                loc, outputTensor.getType(), operand, outputTensor, affineMaps,
                getNParallelLoopsAttrs(rank),
                [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
                  // Emit 'linalg.yield' op.
                  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
                })
            .getResult(0);

    // Cast to the original operand type if necessary.
    auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
        loc, operand.getType(), resultTensor);

    // Emit 'scf.yield' op.
    opBuilder.create<scf::YieldOp>(loc, castResultTensor);
  };

  // Emit 'else' region of 'scf.if'.
  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
    opBuilder.create<scf::YieldOp>(loc, operand);
  };

  // Emit 'scf.if' op.
  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
                                         emitThenRegion, emitElseRegion);
  return ifOp.getResult(0);
}

static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                                        IndexPool &indexPool, Value operand,
                                        ArrayRef<OpFoldResult> targetShape,
                                        ArrayRef<Value> masterOperands) {
  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
  assert((int64_t)targetShape.size() == rank);
  assert((int64_t)masterOperands.size() == rank);
  for (auto index : llvm::seq<int64_t>(0, rank))
    operand =
        broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
                                  targetShape[index], masterOperands[index]);
  return operand;
}

static SmallVector<Value>
broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                           IndexPool &indexPool, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           ArrayRef<Value> masterOperands) {
  // No need to broadcast for unary operations.
  if (operands.size() == 1)
    return operands;

  // Broadcast dynamic dimensions operand by operand.
  return llvm::map_to_vector(operands, [&](Value operand) {
    return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
                                      targetShape, masterOperands);
  });
}
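
// Summary of the dynamic-broadcast strategy above: for each dimension the
// target size is a static size > 1 if one exists, otherwise the runtime
// maximum of the dynamic sizes; each operand whose dynamic size may be 1 is
// then expanded under an scf.if that compares its runtime size against 1.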

static LogicalResult
emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
                           Operation *operation, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           const TypeConverter &converter) {
  // Generate output tensor.
  auto resultType = cast_or_null<RankedTensorType>(
      converter.convertType(operation->getResultTypes().front()));
  if (!resultType) {
    return rewriter.notifyMatchFailure(operation, "failed to convert type");
  }
  Value outputTensor = rewriter.create<tensor::EmptyOp>(
      loc, targetShape, resultType.getElementType());

  // Create affine maps. Input affine maps broadcast static dimensions of
  // size 1. The output affine map is an identity map.
  auto rank = resultType.getRank();
  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
    auto shape = cast<ShapedType>(operand.getType()).getShape();
    SmallVector<AffineExpr> affineExprs;
    for (auto it : llvm::enumerate(shape)) {
      // Prefer producing identity maps whenever possible (i.e. no broadcast
      // needed) because some transforms (like reshape folding) do not support
      // affine constant exprs.
      bool requiresBroadcast =
          (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
      auto affineExpr = requiresBroadcast
                            ? rewriter.getAffineConstantExpr(0)
                            : rewriter.getAffineDimExpr(it.index());
      affineExprs.push_back(affineExpr);
    }
    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  });
  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

  // Emit 'linalg.generic' op.
  bool encounteredError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, outputTensor.getType(), operands, outputTensor, affineMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            {resultType.getElementType()}, rewriter);
        if (!opResult) {
          encounteredError = true;
          return;
        }
        opBuilder.create<linalg::YieldOp>(loc, opResult);
      });
  if (encounteredError)
    return rewriter.notifyMatchFailure(
        operation, "unable to create linalg.generic body for elementwise op");

  // Cast 'tensor.cast' op if necessary.
  auto castResult = rewriter.createOrFold<tensor::CastOp>(
      loc, resultType, linalgOp->getResult(0));
  rewriter.replaceOp(operation, castResult);
  return success();
}

static ValueRange getBroadcastableOperands(Operation *operation,
                                           ValueRange operands) {
  // The shift operand of tosa.mul cannot broadcast.
  if (isa<tosa::MulOp>(operation))
    return operands.take_front(2);
  // The zero-point operands of tosa.negate cannot broadcast.
  if (isa<tosa::NegateOp>(operation))
    return operands.take_front(1);
  return operands;
}

static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                 ConversionPatternRewriter &rewriter,
                                 const TypeConverter &converter) {
  // Collect op properties.
  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
  assert(operation->getNumOperands() >= 1 &&
         "elementwise op expects at least 1 operand");
  if (!operandsAndResultsRanked(operation))
    return rewriter.notifyMatchFailure(operation,
                                       "Unranked tensors not supported");

  // Lower the operation.
  IndexPool indexPool;
  auto loc = operation->getLoc();
  auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
  auto [targetShape, masterOperands] =
      computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
  auto broadcastOperands =
      broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
                                 targetShape, masterOperands);
  return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
                                    targetShape, converter);
}

// Returns the constant initial value for a given reduction operation. The
// attribute type varies depending on the element type required.
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}
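
// For example, reduce_max over i8 starts from -128 (signed min), reduce_min
// over f32 starts from +largest, and reduce_all over i1 starts from true.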

// Creates the body calculation for a reduction. The operations vary depending
// on the input type.
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::AddFOp>(loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::AddIOp>(loc, args);
  }

  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MulFOp>(loc, args);
  }

  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MulIOp>(loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, args);

  return {};
}

// Performs the match and rewrite for reduction operations. This includes
// declaring a correctly sized initial value, and the linalg.reduce operation
// that reduces across the specified axis.
template <typename OpTy>
static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
  auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
  if (!inputTy || !resultTy)
    return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
    }
  }

  SmallVector<Value> inputs, outputs;
  inputs.push_back(input);

  // First fill the output buffer with the init value.
  auto emptyTensor =
      rewriter
          .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
                                   dynDims)
          .getResult();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
  auto filledTensor = rewriter
                          .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                  ValueRange{emptyTensor})
                          .result();
  outputs.push_back(filledTensor);

  bool isNanIgnoreMode = false;
  if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
                std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
    // NaN handling is only meaningful for floating-point reductions.
    if (isa<FloatType>(elementTy) && op.getNanMode() == "IGNORE") {
      isNanIgnoreMode = true;
      // Because the TOSA spec requires the result to be NaN iff all elements
      // in the reduction are NaN, we maintain a boolean flag per reduced
      // value that starts out true and is AND-ed with the NaN check of each
      // element.
      auto trueAttr = rewriter.getBoolAttr(true);
      auto trueValue = rewriter.create<arith::ConstantOp>(loc, trueAttr);
      auto emptyBoolTensor =
          rewriter
              .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
                                       dynDims)
              .getResult();
      auto allResultsNaNTensor =
          rewriter
              .create<linalg::FillOp>(loc, ValueRange{trueValue},
                                      ValueRange{emptyBoolTensor})
              .result();
      // In NaN-ignore mode the reduction carries a second input (the input
      // tensor again) and a second output (the all-NaN flags), so the region
      // below sees four block arguments.
      inputs.push_back(input);
      outputs.push_back(allResultsNaNTensor);
    }
  }
  bool didEncounterError = false;
  linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
      loc, inputs, outputs, axis,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        std::array<Value, 2> binaryArgs{
            blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
        auto result = createLinalgBodyCalculationForReduceOp(
            op, binaryArgs, elementTy, rewriter);
        if (result)
          didEncounterError = true;

        SmallVector<Value> resultsToYield;
        if (isNanIgnoreMode) {
          auto inputValue = blockArgs[0];
          auto initialValue = blockArgs[2];
          auto oldAllResultsNanFlagValue = blockArgs[3];

          // Unordered comparison of NaN against itself always returns true.
          Value isNaN = nestedBuilder.create<arith::CmpFOp>(
              op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue);
          // If we've encountered a NaN, keep the running value instead.
          auto selectOp = nestedBuilder.create<arith::SelectOp>(
              op->getLoc(), isNaN, initialValue, result);
          // Update the flag tracking whether every element so far was NaN.
          auto newAllResultsNanFlagValue = nestedBuilder.create<arith::AndIOp>(
              op->getLoc(), oldAllResultsNanFlagValue, isNaN);
          resultsToYield.push_back(selectOp);
          resultsToYield.push_back(newAllResultsNanFlagValue);
        } else {
          resultsToYield.push_back(result);
        }
        nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
      });

  if (!didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");

  if (isNanIgnoreMode) {
    // Materialize a tensor filled with NaN and select between it and the
    // reduction result based on the all-NaN flags: if every element along
    // the reduced axis was NaN, the result must be NaN as well.
    auto nanValueAttr = rewriter.getFloatAttr(
        elementTy,
        APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
    auto nanValue = rewriter.create<arith::ConstantOp>(loc, nanValueAttr);
    auto emptyNanTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, reduceShape,
                                     resultTy.getElementType(), dynDims)
            .getResult();
    auto nanFilledTensor =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{nanValue},
                                    ValueRange{emptyNanTensor})
            .result();

    // Create an empty tensor; no need to fill it since it is entirely
    // overwritten by the select.
    auto finalEmptyTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, reduceShape,
                                     resultTy.getElementType(), dynDims)
            .getResult();

    // Select between the NaN-filled tensor and the reduction result.
    SmallVector<Value> ins, outs;
    ins.push_back(linalgOp->getOpResult(1));
    ins.push_back(nanFilledTensor);
    ins.push_back(linalgOp->getResult(0));
    outs.push_back(finalEmptyTensor);
    auto linalgSelect =
        rewriter.create<linalg::SelectOp>(op->getLoc(), ins, outs);
    linalgOp = linalgSelect;
  }

  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank =
      cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
  reassociationMap.resize(expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
  }

  if (expandInputRank != 0) {
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
        rewriter.getAffineDimExpr(expandedDim + 1));
  }

  // Lower directly to `tensor.expand_shape` instead of `tosa.reshape`, since
  // here we know which dimension to expand.
  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
      op, resultTy, linalgOp->getResults()[0], reassociationMap);
  return success();
}
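
// Illustrative sketch (not from the original source): tosa.reduce_sum with
// axis = 1 on tensor<5x4xf32> fills a tensor<5xf32> with 0.0, reduces into it
// with arith.addf via linalg.reduce, and expands the result back to
// tensor<5x1xf32> with tensor.expand_shape so the reduced dimension is kept
// with size 1.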

namespace {

template <typename SrcOp>
class PointwiseConverter : public OpConversionPattern<SrcOp> {
public:
  using OpConversionPattern<SrcOp>::OpConversionPattern;
  using typename OpConversionPattern<SrcOp>::OpAdaptor;

  LogicalResult
  matchAndRewrite(SrcOp op, OpAdaptor operands,
                  ConversionPatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(
        op, operands.getOperands(), rewriter, *this->getTypeConverter());
  }
};

class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(op.getInput().getType());
    auto outputTy = cast<ShapedType>(op.getOutput().getType());
    unsigned rank = inputTy.getRank();

    // Reject unsupported configurations.
    if (op.getRoundingMode() == "INEXACT_ROUND")
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
              "currently supported");
    if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    if (!isa<IntegerType>(inputTy.getElementType()))
      return rewriter.notifyMatchFailure(op, "only support integer type");

    SmallVector<Value> dynDims;
    for (int i = 0; i < outputTy.getRank(); i++) {
      if (outputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // The shift and multiplier values must be constants.
    DenseElementsAttr shiftElems;
    if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires constant shift input values");

    DenseElementsAttr multiplierElems;
    if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires constant multiplier input values");

    llvm::SmallVector<int8_t> shiftValues =
        llvm::to_vector(shiftElems.getValues<int8_t>());
    // An explicit cast is required here.
    llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
        llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
                        [](IntegerAttr attr) -> int32_t {
                          return static_cast<int32_t>(attr.getInt());
                        }));

    // If we shift by more than the bitwidth, this just sets to 0.
    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
      if (shiftValues[i] > 63) {
        shiftValues[i] = 0;
        multiplierValues[i] = 0;
      }
    }

    // Double rounding only occurs if the shift is greater than 31; check
    // whether that is ever true.
    bool doubleRound =
        op.getRoundingMode() == "DOUBLE_ROUND" &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
    StringAttr roundingMode = doubleRound
                                  ? rewriter.getStringAttr("DOUBLE_ROUND")
                                  : rewriter.getStringAttr("SINGLE_ROUND");

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType =
          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
                                rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));

      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));

      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel then we need to store the shift
    // values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing map for the output.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the op output and the rescale arithmetic in the body.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(),
        ArrayRef<Value>({dynDims}));

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
          if (failed(maybeIZp)) {
            (void)rewriter.notifyMatchFailure(
                op, "input zero point cannot be statically determined");
            return;
          }

          const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
          // Extend the zero point for sub-32-bit widths.
          const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
          auto inputZp = nestedBuilder.create<arith::ConstantOp>(
              loc, IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
                                    *maybeIZp));

          FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
          if (failed(maybeOZp)) {
            (void)rewriter.notifyMatchFailure(
                op, "output zero point cannot be statically determined");
            return;
          };

          IntegerType outIntType =
              cast<IntegerType>(blockArgs.back().getType());
          unsigned outBitWidth = outIntType.getWidth();
          const int32_t outAttrBitwidth = 32;
          assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
          auto outputZp = nestedBuilder.create<arith::ConstantOp>(
              loc, IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
                                    *maybeOZp));

          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          // Convert an unsigned input to signless before extension.
          if (valueTy.isUnsignedInteger()) {
            value = nestedBuilder
                        .create<UnrealizedConversionCastOp>(
                            nestedLoc,
                            nestedBuilder.getIntegerType(
                                valueTy.getIntOrFloatBitWidth()),
                            value)
                        .getResult(0);
          }
          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (op.getInputUnsigned()) {
              value = nestedBuilder.create<arith::ExtUIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            } else {
              value = nestedBuilder.create<arith::ExtSIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            }
          }

          value =
              nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              roundingMode);

          // Move to the new zero-point.
          value =
              nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different saturation range.
          if (op.getOutputUnsigned()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
                                 nestedBuilder, false);

          if (outIntType.getWidth() < 32) {
            value = nestedBuilder.create<arith::TruncIOp>(
                nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
                value);
          }

          if (outIntType.isUnsignedInteger()) {
            value = nestedBuilder
                        .create<UnrealizedConversionCastOp>(nestedLoc,
                                                            outIntType, value)
                        .getResult(0);
          }
          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
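
// Worked example with hypothetical values: an i8 input v rescaled with
// input_zp = -128, output_zp = 0, multiplier = 1073741824 (0.25 as a
// fixed-point value with shift = 32) computes
//   clamp(apply_scale(ext(v) - (-128), multiplier, shift) + 0, -128, 127)
// and truncates the result back to i8.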
1562
1563
1564
1565
1566 class ResizeUnaryConverter : public OpRewritePatterntosa::ResizeOp {
1567 public:
1569
1570 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1574 auto input = op.getInput();
1575 auto inputTy = cast(input.getType());
1576 auto resultTy = cast(op.getType());
1577 const bool isBilinear = op.getMode() == "BILINEAR";
1578
1579 auto inputH = inputTy.getDimSize(1);
1580 auto inputW = inputTy.getDimSize(2);
1581 auto outputH = resultTy.getDimSize(1);
1582 auto outputW = resultTy.getDimSize(2);
1583
1584 if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
1585 return rewriter.notifyMatchFailure(
1586 op, "tosa.resize is not a pure 1x1->1x1 image operation");
1587
1588
1589 if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1590 return rewriter.notifyMatchFailure(
1591 op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1592
1593 if (inputTy == resultTy) {
1594 rewriter.replaceOp(op, input);
1595 return success();
1596 }
1597
1600 return failure();
1601 }
1602
1603
1605 reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1606 reassociationMap[1].push_back(builder.getAffineDimExpr(1));
1607 reassociationMap[1].push_back(builder.getAffineDimExpr(2));
1608 reassociationMap[1].push_back(builder.getAffineDimExpr(3));
1609
1610 auto collapseTy =
1612 inputTy.getElementType());
1613 Value collapse = builder.createtensor::CollapseShapeOp(collapseTy, input,
1614 reassociationMap);
1615
1616
1618 if (inputTy.isDynamicDim(0))
1619 outputDynSize.push_back(builder.createtensor::DimOp(input, 0));
1620 if (inputTy.isDynamicDim(3))
1621 outputDynSize.push_back(builder.createtensor::DimOp(input, 3));
1622
1623
1624 auto genericTy = collapseTy.clone(resultTy.getElementType());
1625 Value empty = builder.createtensor::EmptyOp(
1626 genericTy.getShape(), resultTy.getElementType(), outputDynSize);
1627 auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
1629 utils::IteratorType::parallel);
1630
1631 auto generic = builder.createlinalg::GenericOp(
1635 Value value = args[0];
1636
1637 if (inputTy.getElementType() != resultTy.getElementType()) {
1638 value =
1639 b.createarith::ExtSIOp(loc, resultTy.getElementType(), value);
1640
1641 if (isBilinear && scale[0] != 0) {
1642 Value scaleY = b.createarith::ConstantOp(
1643 loc, b.getI32IntegerAttr(scale[0]));
1644 value = b.createarith::MulIOp(loc, value, scaleY);
1645 }
1646
1647 if (isBilinear && scale[2] != 0) {
1648 Value scaleX = b.createarith::ConstantOp(
1649 loc, b.getI32IntegerAttr(scale[2]));
1650 value = b.createarith::MulIOp(loc, value, scaleX);
1651 }
1652 }
1653
1654 b.createlinalg::YieldOp(loc, value);
1655 });
1656
1657 rewriter.replaceOpWithNewOptensor::ExpandShapeOp(
1658 op, resultTy, generic.getResults()[0], reassociationMap);
1659 return success();
1660 }
1661 };

// Handle the TOSA resize case where the resize broadcasts along unit spatial
// dimensions: materialize a broadcast-compatible resize, then expand it.
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any dimension that is broadcastable we generate a width of 1
    // on the output.
    llvm::SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    auto resizeTy = resultTy.clone(resizeShape);
    auto resize = builder.create<tosa::ResizeOp>(resizeTy, input, op.getScale(),
                                                 op.getOffset(), op.getBorder(),
                                                 op.getMode());

    // Collapse the unit result dims.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    SmallVector<int64_t> collapseShape = {batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
                                                             reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = builder.create<tensor::EmptyOp>(
        resultTy.getShape(), resultTy.getElementType(), outputDynSize);

    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());

    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          b.create<linalg::YieldOp>(loc, value);
        });

    return success();
  }
};

class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto resultETy = resultTy.getElementType();

    bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
    auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return rewriter.notifyMatchFailure(
          op, "unable to get dynamic dimensions of tosa.resize");

    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
    auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
                                                 *dynamicDimsOr);
    auto genericOp = b.create<linalg::GenericOp>(
        resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    Value resize = genericOp.getResult(0);

    {
      OpBuilder::InsertionGuard regionGuard(b);
      b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
                    TypeRange({resultETy}), loc);
      Value batch = b.create<linalg::IndexOp>(0);
      Value y = b.create<linalg::IndexOp>(1);
      Value x = b.create<linalg::IndexOp>(2);
      Value channel = b.create<linalg::IndexOp>(3);

      Value zeroI32 =
          b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
      Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
      Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
      Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));

      Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
      Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);

      SmallVector<int64_t> scale, offset, border;
      if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
          !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
          !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
        return rewriter.notifyMatchFailure(
            op, "tosa.resize scale/offset/border should have compile time "
                "constant values.");
      }

      Value yScaleN, yScaleD, xScaleN, xScaleD;
      yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
      yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
      xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
      xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));

      Value yOffset, xOffset, yBorder, xBorder;
      yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
      xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
      yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
      xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));

      // Compute the index and delta values - floating-point case.
      auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
                                    Value scaleN, Value scaleD, Value offset,
                                    int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroFp;
          return;
        }
        // x = x * scale_d + offset;
        // ix = floor(x / scale_n)
        Value val = b.create<arith::MulIOp>(in, scaleD);
        val = b.create<arith::AddIOp>(val, offset);
        index = b.create<arith::FloorDivSIOp>(val, scaleN);

        // rx = x % scale_n
        // dx = rx / scale_n
        Value r = b.create<arith::RemSIOp>(val, scaleN);
        Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
        Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
        delta = b.create<arith::DivFOp>(rFp, scaleNfp);
      };

      // Compute the index and delta values - integer case.
      auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
                                     Value scaleN, Value scaleD, Value offset,
                                     int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroI32;
          return;
        }
        // x = x * scale_d + offset;
        // ix = floor(x / scale_n)
        // dx = x - ix * scale_n;
        Value val = b.create<arith::MulIOp>(in, scaleD);
        val = b.create<arith::AddIOp>(val, offset);
        index = b.create<arith::DivSIOp>(val, scaleN);
        delta = b.create<arith::MulIOp>(index, scaleN);
        delta = b.create<arith::SubIOp>(val, delta);
      };

      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      } else {
        getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      }

      if (op.getMode() == "NEAREST_NEIGHBOR") {
        auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
                                           Value max, int size,
                                           ImplicitLocOpBuilder &b) -> Value {
          if (size == 1) {
            return b.create<arith::ConstantIndexOp>(0);
          }

          Value pred;
          if (floatingPointMode) {
            auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
            pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
          } else {
            Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
            pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
                                           dvalDouble, scale);
          }

          auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
          val = b.create<arith::AddIOp>(val, offset);
          val = clampIntHelper(loc, val, zeroI32, max, b, false);
          return b.create<arith::IndexCastOp>(b.getIndexType(), val);
        };

        iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
        ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);

        Value result = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, iy, ix, channel});

        b.create<linalg::YieldOp>(result);
      } else {
        // The mode here must be BILINEAR.
        assert(op.getMode() == "BILINEAR");

        auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
                                  Value max, ImplicitLocOpBuilder &b) {
          val0 = in;
          val1 = b.create<arith::AddIOp>(val0, oneVal);
          val0 = clampIntHelper(loc, val0, zeroI32, max, b, false);
          val1 = clampIntHelper(loc, val1, zeroI32, max, b, false);
          val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
          val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
        };

        // Compute the clamped (y0, y1) and (x0, x1) sample coordinates.
        Value x0, x1, y0, y1;
        getClampedIdxs(y0, y1, imageH, iy, hMax, b);
        getClampedIdxs(x0, x1, imageW, ix, wMax, b);

        Value y0x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x1, channel});

        if (floatingPointMode) {
          auto oneVal =
              b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
          auto interpolate = [&](Value val0, Value val1, Value delta,
                                 int inputSize,
                                 ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return val0;
            Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
            Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
            Value mul1 = b.create<arith::MulFOp>(val1, delta);
            return b.create<arith::AddFOp>(mul0, mul1);
          };

          // Interpolate the top row in x.
          Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);

          // Interpolate the bottom row in x.
          Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);

          // Interpolate between the two rows in y.
          Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
          b.create<linalg::YieldOp>(result);
        } else {
          // Perform in quantized space.
          y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
          y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
          y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
          y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);

          const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
            dx = b.create<arith::ExtSIOp>(resultETy, dx);
            dy = b.create<arith::ExtSIOp>(resultETy, dy);
          }

          Value yScaleNExt = yScaleN;
          Value xScaleNExt = xScaleN;

          const int64_t scaleBitwidth =
              xScaleN.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
            yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
            xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
          }

          auto interpolate = [](Value val0, Value val1, Value weight1,
                                Value scale, int inputSize,
                                ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return b.create<arith::MulIOp>(val0, scale);
            Value weight0 = b.create<arith::SubIOp>(scale, weight1);
            Value mul0 = b.create<arith::MulIOp>(val0, weight0);
            Value mul1 = b.create<arith::MulIOp>(val1, weight1);
            return b.create<arith::AddIOp>(mul0, mul1);
          };

          Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
          Value result =
              interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
          b.create<linalg::YieldOp>(result);
        }
      }
    }

    rewriter.replaceOp(op, resize);
    return success();
  }
};
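
// The bilinear path above evaluates, in scalar form:
//   top    = y0x0 * (1 - dx) + y0x1 * dx
//   bottom = y1x0 * (1 - dx) + y1x1 * dx
//   result = top * (1 - dy) + bottom * dy
// In the integer mode the deltas are fixed-point values scaled by scale_n, so
// each lerp uses (scale_n - d) and d as weights and the result carries an
// extra factor of y_scale_n * x_scale_n.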

// Identity-like ops simply forward their operands.
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, op.getOperation()->getOperands());
    return success();
  }
};

template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
  }
};

class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto axis = op.getAxis();

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);

    // First fill the output buffer with the init value.
    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, inputTy.getShape(),
                                                    inputTy.getElementType(),
                                                    ArrayRef<Value>({dynDims}))
                           .getResult();
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            Value index =
                rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
            if (i == axis) {
              auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
              auto sizeMinusOne =
                  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
              index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
                                                     index);
            }

            indices.push_back(index);
          }

          auto extract = nestedBuilder.create<tensor::ExtractOp>(
              nestedLoc, input, indices);
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
                                                extract.getResult());
        });
    return success();
  }
};
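
// The reversed coordinate along `axis` is (dimSize - 1) - i; every other
// dimension uses its identity index, so the generic simply gathers from the
// mirrored position.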

// Lower tosa.tile to a linalg.generic with an interleaved (multiple, size)
// iteration space, followed by a reshape to the final result type.
class TileConverter : public OpConversionPattern<tosa::TileOp> {
public:
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto inputShape = inputTy.getShape();
    auto resultTy = cast<ShapedType>(op.getType());
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    SmallVector<int64_t> multiples;
    if (failed(op.getConstantMultiples(multiples)))
      return failure();

    // Broadcast the newly added dimensions to their appropriate multiple.
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      int64_t dim = multiples[i];
      genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
      genericShape.push_back(inputShape[i]);
    }

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        op.getLoc(), genericShape, elementTy, dynDims);

    // We need to map the input shape to the non-broadcasted dimensions.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });

    auto shapeValue = getTosaConstShape(
        rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0), shapeValue);
    return success();
  }
};
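
// Example: tiling tensor<2x3xf32> by multiples [3, 2] builds a generic of
// shape [3, 2, 2, 3] (each multiple interleaved with the corresponding input
// size) and reshapes it to the final tensor<6x6xf32>.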
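
// Lowering for tosa.argmax: a linalg.generic reduces along `axis` while
// carrying two outputs, the running maximum value and its index; the index
// result is the one that replaces the op.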
2202 class ArgMaxConverter : public OpRewritePatterntosa::ArgMaxOp {
2203 public:
2205
2206 LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
2208 auto loc = argmaxOp.getLoc();
2209 Value input = argmaxOp.getInput();
2210 auto inputTy = cast(input.getType());
2211 auto resultTy = cast(argmaxOp.getOutput().getType());
2212 auto inElementTy = inputTy.getElementType();
2213 auto outElementTy = resultTy.getElementType();
2214 int axis = argmaxOp.getAxis();
2216
2217 if (!isa(outElementTy))
2219 argmaxOp,
2220 "tosa.arg_max to linalg.* requires integer-like result type");
2221
2223 for (int i = 0; i < inputTy.getRank(); i++) {
2224 if (inputTy.isDynamicDim(i) && i != axis) {
2225 dynDims.push_back(rewriter.createtensor::DimOp(loc, input, i));
2226 }
2227 }
2228
2229
2230 auto emptyTensorIdx = rewriter
2231 .createtensor::EmptyOp(loc, resultTy.getShape(),
2232 outElementTy, dynDims)
2233 .getResult();
2234 auto fillValueIdx = rewriter.createarith::ConstantOp(
2236 auto filledTensorIdx =
2237 rewriter
2240 .result();
2241
2242
2243 auto emptyTensorMax = rewriter
2244 .createtensor::EmptyOp(loc, resultTy.getShape(),
2245 inElementTy, dynDims)
2246 .getResult();
2247 auto fillValueMaxAttr =
2249
2250 if (!fillValueMaxAttr)
2252 argmaxOp, "unsupported tosa.argmax element type");
2253
2254 auto fillValueMax =
2255 rewriter.createarith::ConstantOp(loc, fillValueMaxAttr);
2256 auto filledTensorMax =
2257 rewriter
2260 .result();
2261
2262
2263
2265 iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
2266 iteratorTypes[axis] = utils::IteratorType::reduction;
2267
2270 for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
2272 if (axis != i)
2274 }
2275
2276 bool didEncounterError = false;
2279 auto linalgOp = rewriter.createlinalg::GenericOp(
2280 loc, ArrayRef({resultTy, resultMaxTy}), input,
2281 ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
2284 auto newValue = blockArgs[0];
2285 auto oldIndex = blockArgs[1];
2286 auto oldValue = blockArgs[2];
2287
2288 Value newIndex = rewriter.createarith::IndexCastOp(
2289 nestedLoc, oldIndex.getType(),
2290 rewriter.createlinalg::IndexOp(loc, axis));
2291
2292 Value predicate;
2293 if (isa(inElementTy)) {
2294 if (argmaxOp.getNanMode() == "IGNORE") {
2295
2296
2297 predicate = rewriter.createarith::CmpFOp(
2298 nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
2299 } else {
2300
2301
2302
2303 Value gt = rewriter.createarith::CmpFOp(
2304 nestedLoc, arith::CmpFPredicate::UGT, newValue, oldValue);
2305 Value oldNonNaN = rewriter.createarith::CmpFOp(
2306 nestedLoc, arith::CmpFPredicate::ORD, oldValue, oldValue);
2307 predicate = rewriter.createarith::AndIOp(
2308 nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
2309 }
2310 } else if (isa(inElementTy)) {
2311 predicate = rewriter.createarith::CmpIOp(
2312 nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
2313 } else {
2314 didEncounterError = true;
2315 return;
2316 }
2317
2318 auto resultMax = rewriter.createarith::SelectOp(
2319 nestedLoc, predicate, newValue, oldValue);
2320 auto resultIndex = rewriter.createarith::SelectOp(
2321 nestedLoc, predicate, newIndex, oldIndex);
2322 nestedBuilder.createlinalg::YieldOp(
2323 nestedLoc, ValueRange({resultIndex, resultMax}));
2324 });
2325
2326 if (didEncounterError)
2328 argmaxOp, "unsupported tosa.argmax element type");
2329
2330 rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
2331 return success();
2332 }
2333 };
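// Example (illustrative) of the NaN handling above: for input
// [3.0, NaN, 5.0] reduced along axis 0, the default "PROPAGATE" mode uses
// UGT, so the NaN at index 1 wins over 3.0 and no later value can displace
// it (the current max is no longer ordered); the result is index 1. With
// nan_mode = "IGNORE", the ordered OGT comparison never selects a NaN, so
// the result is index 2 (value 5.0).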

class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = adaptor.getOperands()[0];
    auto indices = adaptor.getOperands()[1];

    auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());
    if (!valuesTy || !resultTy)
      return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

    auto dynamicDims = inferDynamicDimsForGather(
        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();
    auto emptyTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
                                     dynamicDims)
            .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
          Value index1 = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getIndexType(), indexValue);
          auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
          Value extract = rewriter.create<tensor::ExtractOp>(
              loc, input, ValueRange{index0, index1, index2});
          rewriter.create<linalg::YieldOp>(loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }

  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
                                                            Location loc,
                                                            Value values,
                                                            Value indices) {
    llvm::SmallVector<Value> results;

    auto addDynamicDimension = [&](Value source, int64_t dim) {
      auto sz = tensor::getMixedSize(builder, loc, source, dim);
      if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
        results.push_back(dimValue);
    };

    addDynamicDimension(values, 0);
    addDynamicDimension(indices, 1);
    addDynamicDimension(values, 2);
    return results;
  }
};
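// Example (illustrative) for the gather lowering above: with
// %values : tensor<2x4x3xf32> and %indices : tensor<2x5xi32>, the generic
// iterates over the 2x5x3 result space; at (n, w, c) it reads
// %indices[n, w] (input map (d0, d1, d2) -> (d0, d1)), casts it to index,
// and yields %values[n, idx, c].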

// Lowers the TableOp to a gather from the table plus numeric operations. For
// the 16-bit case this includes interpolation between the high/low table
// values; the 8-bit case is a direct lookup.
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    Value table = op.getTable();
    auto inputTy = cast<ShapedType>(input.getType());
    auto tableTy = cast<ShapedType>(table.getType());
    auto resultTy = cast<ShapedType>(op.getType());

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    SmallVector<Value> dynDims;
    for (int i = 0; i < resultTy.getRank(); ++i) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(
            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
      }
    }

    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                    resultElementTy, dynDims)
                           .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({input}), ValueRange{emptyTensor},
        affineMaps, getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block = rewriter.createBlock(
          &genericOp.getRegion(), genericOp.getRegion().end(),
          TypeRange({inputElementTy, resultElementTy}), {loc, loc});

      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
        Value index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), inputValue);
        Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
        index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
                                               index, offset);
        Value extract =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        rewriter.create<linalg::YieldOp>(loc, extract);
        return success();
      }

      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        Value extend = rewriter.create<arith::ExtSIOp>(
            loc, rewriter.getI32Type(), inputValue);

        auto offset = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(32768));
        auto seven = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(7));
        auto one = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(1));
        auto b1111111 = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(127));

        // Compute the index and fractional part from the input value:
        //
        // value = value + 32768
        // index = value >> 7;
        // fraction = value & 0x7f;
        auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
        Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
        Value fraction =
            rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);

        // Extract the base and next values from the table.
        //
        // base = (int32_t) table[index];
        // next = (int32_t) table[index + 1];
        Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);

        index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), index);
        indexPlusOne = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        Value next = rewriter.create<tensor::ExtractOp>(
            loc, table, ValueRange{indexPlusOne});

        base =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
        next =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between the table values:
        //
        // result = (base << 7) + (next - base) * fraction
        Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
        Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
        Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
        Value result =
            rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);

        rewriter.create<linalg::YieldOp>(loc, result);

        return success();
      }
    }

    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};
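// Worked example (illustrative) for the 16-bit path above: input -32768
// maps to value = 0, giving index = 0 and fraction = 0 (a direct table[0]
// read). Input 200 maps to value = 32968, so index = 32968 >> 7 = 257 and
// fraction = 32968 & 0x7f = 72; the yielded result is
//   (table[257] << 7) + (table[258] - table[257]) * 72.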

struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
  using OpRewritePattern<RFFT2dOp>::OpRewritePattern;

  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }

  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
                                  OpFoldResult ofr) {
    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
    auto two = builder.create<arith::ConstantIndexOp>(loc, 2);

    auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
    auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
    auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
    return getAsOpFoldResult(plusOne);
  }

  static RankedTensorType
  computeOutputShape(OpBuilder &builder, Location loc, Value input,
                     llvm::SmallVectorImpl<Value> &dynamicSizes) {
    // Get [N, H, W]
    auto dims = tensor::getMixedSizes(builder, loc, input);

    // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
    // output tensors.
    dims[2] = halfPlusOne(builder, loc, dims[2]);

    llvm::SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto elementType =
        cast<RankedTensorType>(input.getType()).getElementType();
    return RankedTensorType::get(staticSizes, elementType);
  }

  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
                                RankedTensorType type,
                                ValueRange dynamicSizes) {
    auto emptyTensor =
        rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
    auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
    auto filledTensor = rewriter
                            .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                    ValueRange{emptyTensor})
                            .result();
    return filledTensor;
  }

  static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                FloatType type, Value value) {
    auto integerVal = builder.create<arith::IndexCastUIOp>(
        loc,
        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
                                          : builder.getI32Type(),
        value);

    return builder.create<arith::UIToFPOp>(loc, type, integerVal);
  }

  static Value createLinalgIndex(OpBuilder &builder, Location loc,
                                 FloatType type, int64_t index) {
    auto indexVal = builder.create<linalg::IndexOp>(loc, index);
    return castIndexToFloat(builder, loc, type, indexVal);
  }

  template <typename... Args>
  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
                                                         Args... args) {
    return {builder.getAffineDimExpr(args)...};
  }

  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports ranked tensors");
    }

    auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInputReal();
    auto elementType = dyn_cast<FloatType>(
        cast<ShapedType>(input.getType()).getElementType());
    if (!elementType)
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports float element types");

    // Compute the output type and set of dynamic sizes.
    llvm::SmallVector<Value> dynamicSizes;
    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);

    // Iterator types for the linalg.generic implementation.
    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation.
    llvm::SmallVector<Value> genericOpInputs = {input};
    llvm::SmallVector<Value> genericOpOutputs = {
        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};

    // Indexing maps for input and output tensors.
    auto indexingMaps = AffineMap::inferFromExprList(
        llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
                       affineDimsExpr(rewriter, 0, 1, 2),
                       affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);

    // Constants and dimension sizes.
    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value sumReal = args[1];
      Value sumImag = args[2];

      // Indices for angle computation.
      Value oy = builder.create<linalg::IndexOp>(loc, 1);
      Value ox = builder.create<linalg::IndexOp>(loc, 2);
      Value iy = builder.create<linalg::IndexOp>(loc, 3);
      Value ix = builder.create<linalg::IndexOp>(loc, 4);

      // Calculate the angle without the integer parts of the components, as
      // sin/cos are periodic:
      // angle = 2 * pi() * (((iy * oy) % H) / H + ((ix * ox) % W) / W);
      auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
      auto ixXox = builder.create<index::MulOp>(loc, ix, ox);

      auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
      auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);

      auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
      auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);

      auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
      auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
      auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);

      // realComponent = valReal * cos(angle)
      // imagComponent = valReal * sin(angle)
      auto cosAngle = builder.create<math::CosOp>(loc, angle);
      auto sinAngle = builder.create<math::SinOp>(loc, angle);
      auto realComponent =
          builder.create<arith::MulFOp>(loc, valReal, cosAngle);
      auto imagComponent =
          builder.create<arith::MulFOp>(loc, valReal, sinAngle);

      // outReal = sumReal + realComponent
      // outImag = sumImag - imagComponent
      auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
      auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);

      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
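// Example (illustrative) for the rfft2d lowering above: a tensor<1x8x8xf32>
// input produces two tensor<1x8x5xf32> results (real and imaginary), since
// the output width folds to W/2 + 1 = 5. Each output element accumulates
//   sum_{iy, ix} in[n, iy, ix] * e^(-j * 2*pi * (iy*oy/H + ix*ox/W))
// via the reduction iterators on dimensions 3 and 4, matching the
// add-cos / subtract-sin updates in the body.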

struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(FFT2dOp fft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(fft2d->getOperandTypes(),
                      RFFT2dConverter::isRankedTensor) ||
        !llvm::all_of(fft2d->getResultTypes(),
                      RFFT2dConverter::isRankedTensor)) {
      return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
    }

    Location loc = fft2d.getLoc();
    Value input_real = fft2d.getInputReal();
    Value input_imag = fft2d.getInputImag();
    BoolAttr inverse = fft2d.getInverseAttr();

    auto real_el_ty = cast<FloatType>(
        cast<ShapedType>(input_real.getType()).getElementType());
    [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
        cast<ShapedType>(input_imag.getType()).getElementType());

    assert(real_el_ty == imag_el_ty);

    // Compute the output type and set of dynamic sizes.
    llvm::SmallVector<Value> dynamicSizes;

    // Get [N, H, W]
    auto dims = tensor::getMixedSizes(rewriter, loc, input_real);

    llvm::SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto outputType = RankedTensorType::get(staticSizes, real_el_ty);

    // Iterator types for the linalg.generic implementation.
    SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation.
    SmallVector<Value> genericOpInputs = {input_real, input_imag};
    SmallVector<Value> genericOpOutputs = {
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes),
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes)};

    // Indexing maps for input and output tensors.
    auto indexingMaps = AffineMap::inferFromExprList(
        ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);

    // Constants and dimension sizes.
    auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    Value constH =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
    Value constW =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value valImag = args[1];
      Value sumReal = args[2];
      Value sumImag = args[3];

      // Indices for angle computation.
      Value oy = builder.create<linalg::IndexOp>(loc, 1);
      Value ox = builder.create<linalg::IndexOp>(loc, 2);
      Value iy = builder.create<linalg::IndexOp>(loc, 3);
      Value ix = builder.create<linalg::IndexOp>(loc, 4);

      // angle = sign_val * 2 * pi() *
      //         (((iy * oy) % H) / H + ((ix * ox) % W) / W);
      auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
      auto ixXox = builder.create<index::MulOp>(loc, ix, ox);

      auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
      auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);

      auto iyRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
      auto ixRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);

      auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
      auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);

      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
      auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);

      if (inverse.getValue()) {
        angle = builder.create<arith::MulFOp>(
            loc, angle,
            rewriter.create<arith::ConstantOp>(
                loc, rewriter.getFloatAttr(real_el_ty, -1.0)));
      }

      // realComponent = valReal * cos(angle) + valImag * sin(angle)
      // imagComponent = valImag * cos(angle) - valReal * sin(angle)
      auto cosAngle = builder.create<math::CosOp>(loc, angle);
      auto sinAngle = builder.create<math::SinOp>(loc, angle);

      auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
      auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
      auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);

      auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
      auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);

      auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);

      // outReal = sumReal + realComponent
      // outImag = sumImag + imagComponent
      auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
      auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);

      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
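// Note (illustrative): the body above multiplies (valReal + j*valImag) by
// e^(-j*angle), i.e. (cos(angle) - j*sin(angle)). With inverse = true the
// angle is negated, so the accumulation uses the e^(+j*angle) kernel
// instead; apart from that sign flip, the forward and inverse lowerings
// share the same linalg.generic structure.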

} // namespace

void mlir::tosa::populateTosaToLinalgConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {

  // We have multiple resize converters to handle degenerate cases.
  patterns->add<GenericResizeConverter>(patterns->getContext(),
                                        /*benefit=*/100);
  patterns->add<ResizeUnaryConverter>(patterns->getContext(),
                                      /*benefit=*/200);
  patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
                                            /*benefit=*/300);

  patterns->add<
      // clang-format off
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::IntDivOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::SinOp>,
      PointwiseConverter<tosa::CosOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::ErfOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::ClzOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::SigmoidOp>
        >(converter, patterns->getContext());

  patterns->add<
      IdentityNConverter<tosa::IdentityOp>,
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProductOp>,
      ArgMaxConverter,
      GatherConverter,
      RescaleConverter,
      ReverseConverter,
      RFFT2dConverter,
      FFT2dConverter,
      TableConverter,
      TileConverter>(patterns->getContext());
  // clang-format on
}
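// Usage sketch (illustrative; the pass boilerplate and `target` setup are
// assumed, not part of this file): a conversion pass would typically
// populate these patterns and run dialect conversion, e.g.:
//
//   RewritePatternSet patterns(&getContext());
//   TypeConverter converter;  // identity conversions for tensor types
//   mlir::tosa::populateTosaToLinalgConversionPatterns(converter, &patterns);
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();
//
// The benefit values passed above (100/200/300) make the dialect-conversion
// driver prefer the more specialized resize converters over the generic one.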