MLIR: lib/Dialect/Tosa/IR/TosaOps.cpp Source File
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/TypeSwitch.h"
32
33 #include <numeric>
34
35 using namespace mlir;
36 using namespace mlir::tosa;
37
38 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
40
41
42
43
44
45 #include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
46 #include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
47 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
48 #include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
49
50 namespace {
51 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
52
53
54
55
58
59
60
61
62
63
66 return true;
67 }
68
69
72 return (isa<tosa::IfOp>(dest->getParentOp()) ||
73 isa<tosa::WhileOp>(dest->getParentOp()));
74 }
75 };
76
77
79 TosaDialectBytecodeInterface(Dialect *dialect)
81
82
83
84
86 return ::readAttribute(getContext(), reader);
87 }
88
89 LogicalResult writeAttribute(Attribute attr,
91 return ::writeAttribute(attr, writer);
92 }
93
94
95
96
98 return ::readType(getContext(), reader);
99 }
100
101 LogicalResult writeType(Type type,
103 return ::writeType(type, writer);
104 }
105
107
108 }
109
110 std::unique_ptr<DialectVersion>
112
113 reader.emitError("Dialect does not support versioning");
114 return nullptr;
115 }
116
117 LogicalResult upgradeFromVersion(Operation *topLevelOp,
119 return success();
120 }
121 };
122
123 }
124
125
126
127
128
129
131 return {&getBodyGraph()};
132 }
133
134
135
136
137
139 return to_vector(llvm::map_range(shape, [](int64_t dim) {
140 return dim == -1 ? ShapedType::kDynamic : dim;
141 }));
142 }
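// For example, convertToMlirShape({2, -1, 3}) yields {2, ShapedType::kDynamic, 3}.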
143
144
146 Type elementType = variableOp.getType();
148 auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
150 }
151
152
153
154
155
156 void TosaDialect::initialize() {
157 addTypes<
158 #define GET_TYPEDEF_LIST
159 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
160 >();
161 addOperations<
162 #define GET_OP_LIST
163 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
164 >();
165 addAttributes<
166 #define GET_ATTRDEF_LIST
167 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
168 >();
169 addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
170 declarePromisedInterfaces<
171 mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
172 ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
173 LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
174 LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
175 BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
176 NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
177 GreaterEqualOp, MatMulOp>();
178 }
179
182
183
184 if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
185 return builder.create<tosa::ConstShapeOp>(
186 loc, type, llvm::cast<DenseIntElementsAttr>(value));
187 }
188 if (llvm::isa<ElementsAttr>(value))
189 return builder.create<tosa::ConstOp>(loc, type,
190 llvm::cast<ElementsAttr>(value));
191 return nullptr;
192 }
193
194
195
196
197
198 namespace {
199
200 ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
202 TypeAttr &typeAttr) {
203 if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
204 if (!shapedType.hasRank())
206 << "expected ranked type";
207
208 auto elementType = shapedType.getElementType();
213 return success();
214 }
216 << "expected shaped type";
217 }
218
219 }
220
221
222
223
224
225
226
231 if (failed(parser.parseAttribute(initialValueAttr))) {
233 << "expected attribute";
234 }
235 if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
236 return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
237 typeAttr);
238 }
240 << "expected Typed attr";
241 }
242
243 initialValueAttr = nullptr;
244 Type parsedType;
247 << "expected type after colon";
248 }
249 return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
250 }
251
254 TypeAttr typeAttr, Attribute initialValueAttr) {
255 bool needsSpace = false;
256 if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
257 auto shape =
259 Type elementType = typeAttr.getValue();
260 RankedTensorType tensorType =
263 p << ": ";
265 needsSpace = true;
266 }
267 if (initialValueAttr) {
268 if (needsSpace)
269 p << ' ';
270 p << "= ";
272 }
273 }
274
275
276
277
278
279 std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
280 if (lhs % rhs != 0)
281 return std::nullopt;
282 return lhs / rhs;
283 }
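// For example, idivCheck(6, 3) yields 2, while idivCheck(7, 3) yields std::nullopt.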
284
287 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
288 srcType = quantType.getStorageType();
289 return srcType;
290 }
291
294 }
295
297 Value valZp, StringRef name) {
300
301 bool bothInts =
302 mlir::isa(eType) && mlir::isa(eZpType);
303 bool sameBitWidth =
305
306 if (!bothInts || !sameBitWidth) {
308 << "expected " << name << " and " << name
309 << "_zp to both be integer of the same bitwidth, but got " << eType
310 << " vs. " << eZpType;
311 }
312 return success();
313 }
314
315
317 Value src, int32_t val) {
322 const auto padConstAttr{
323 llvm::isa(srcElemType)
328 return builder.create<tosa::ConstOp>(loc, padConstType, padConstAttr);
329 }
330
331
332
333
334
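// Shared verifier for convolution-style ops: after peeling off quantized
// storage types, a float bias must match the float result element type,
// input and weight must both be float or both be integer, the zero-point
// operands must match their tensor's element type, and the zero-point
// values themselves must be valid.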
335 template
337 const auto inputType = llvm::dyn_cast(op.getInput().getType());
338 const auto weightType = llvm::dyn_cast(op.getWeight().getType());
339
340 auto inputEType = inputType.getElementType();
341 auto weightEType = weightType.getElementType();
342 auto biasEType =
343 llvm::cast(op.getBias().getType()).getElementType();
344 auto resultEType =
345 llvm::cast(op.getResult().getType()).getElementType();
346 bool biasIsFloat = llvm::isa(biasEType);
347 bool resultIsFloat = llvm::isa(resultEType);
348
349 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
350 inputEType = quantType.getStorageType();
351
352 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
353 weightEType = quantType.getStorageType();
354
355 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
356 biasEType = quantType.getStorageType();
357
358 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
359 resultEType = quantType.getStorageType();
360
361 if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
362
363
364 op.emitOpError(
365 "expect both bias and result to have same element type, got ")
366 << biasEType << " and " << resultEType;
367 return failure();
368 }
369
370 if (isa(inputEType) || isa(inputEType) ||
371 isa(weightEType) || isa(weightEType)) {
372 if (inputEType != weightEType) {
373 op.emitOpError(
374 "expect both input and weight to have same element type, got ")
375 << inputEType << " and " << weightEType;
376 return failure();
377 }
378 }
379
380 bool inputIsFloat = llvm::isa(inputEType);
381 bool weightIsFloat = llvm::isa(weightEType);
382
383
384 if (inputIsFloat != weightIsFloat) {
385 op.emitOpError(
386 "expect both input and weight to be float or not together, got ")
387 << inputEType << " and " << weightEType;
388 return failure();
389 }
390
392 if (inputEType != inputZpEType) {
393 return op.emitOpError("expect both input and its zero point are the same "
394 "element type, got ")
395 << inputEType << " and " << inputZpEType;
396 }
397
399 if (weightEType != weightZpEType) {
400 return op.emitOpError("expect both weight and its zero point are the same "
401 "element type, got ")
402 << weightEType << " and " << weightZpEType;
403 }
404
405 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
406 if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
407 return failure();
408
409 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
410 if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
411 return failure();
412
413 return success();
414 }
415
417
418 auto attrType = llvm::dyn_cast(getValuesAttr().getType());
419 auto outputType = llvm::dyn_cast(getOutput().getType());
420
421 if (!attrType || !outputType) {
422 emitOpError("expected tensors for attr/result type");
423 return failure();
424 }
425
426 if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
427 outputType.getElementType())) {
428 if (result.getStorageType() == attrType.getElementType())
429 return success();
430 }
431
432 if (attrType.getElementType() != outputType.getElementType()) {
433 emitOpError("expected same attr/result element types");
434 return failure();
435 }
436
437 return success();
438 }
439
440 template
442 auto inputEType =
443 llvm::cast(op.getInput().getType()).getElementType();
444
445 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
446 inputEType = quantType.getStorageType();
447
448 auto accType = op.getAccType();
449 if (inputEType.isInteger(8) && !accType.isInteger(32))
450 return op.emitOpError("accumulator type for i8 tensor is not i32");
451
452 if (inputEType.isInteger(16) && !accType.isInteger(48))
453 return op.emitOpError("accumulator type for i16 tensor is not i48");
454
455 if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
456 return op.emitOpError("accumulator type for f8 tensor is not f16");
457
458 if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
459 return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
460
461 if (inputEType.isBF16() && !accType.isF32())
462 return op.emitOpError("accumulator type for bf16 tensor is not f32");
463
464 if (inputEType.isF32() && !accType.isF32())
465 return op.emitOpError("accumulator type for f32 tensor is not f32");
466
467 auto resultEType =
468 llvm::cast(op.getResult().getType()).getElementType();
469
470 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
471 resultEType = quantType.getStorageType();
472
473 return success();
474 }
475
476
477
478
479
480
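// Shared ERROR_IF checks for convolution-style ops: padding must be >= 0,
// strides and dilations >= 1, each static output size must equal
// (input - 1 + pad_before + pad_after - (kernel - 1) * dilation) / stride + 1,
// and bias channels must equal the output channels or be 1.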
481 template
484 if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
485 return op.emitOpError("expect all padding values to be >= 0, got ")
486 << padding;
487
489 if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
490 return op.emitOpError("expect all stride values to be >= 1, got ")
491 << strides;
492
494 if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
495 return op.emitOpError("expect all dilation values to be >= 1, got ")
496 << dilations;
497
498 const RankedTensorType outputType =
499 llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
500 if (!outputType)
501
502 return success();
503
504 const RankedTensorType inputType =
505 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
506 const RankedTensorType weightType =
507 llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
508
509 if (inputType && weightType) {
510 const auto verifyOutputSize =
511 [&op](const int64_t inputSize, const int64_t kernelSize,
512 const int64_t outputSize, const int64_t padBefore,
513 const int64_t padAfter, const int64_t stride,
514 const int64_t dilation, const llvm::StringRef dimName,
515 const llvm::StringRef dimAxis,
516 const llvm::StringRef padBeforeName,
517 const llvm::StringRef padAfterName) -> LogicalResult {
518 if (inputSize == ShapedType::kDynamic ||
519 kernelSize == ShapedType::kDynamic)
520 return success();
521
522
523
524 const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
525 inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
526 stride);
527 if (!calculatedOutSizeMinusOne.has_value())
528 return op.emitOpError("expected input_")
529 << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
530 << padAfterName << " - (kernel_" << dimName
531 << " - 1) * dilation_" << dimAxis
532 << " to be wholly divisible by stride_" << dimAxis << ", got ("
533 << inputSize << " - 1 + " << padBefore << " + " << padAfter
534 << " - (" << kernelSize << " - 1) * " << dilation << ") / "
535 << stride;
536
537 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
538 if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
539 return op.emitOpError("calculated output ")
540 << dimName << " did not match expected: "
541 << "calculated=" << calculatedOutSize
542 << ", expected=" << outputSize;
543
544 return success();
545 };
546
547
548 if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
549 if (failed(verifyOutputSize(
550 inputType.getDimSize(1), weightType.getDimSize(1),
551 outputType.getDimSize(1), padding[0], padding[1], strides[0],
552 dilations[0], "height", "y", "top", "bottom")))
553 return failure();
554
555 if (failed(verifyOutputSize(
556 inputType.getDimSize(2), weightType.getDimSize(2),
557 outputType.getDimSize(2), padding[2], padding[3], strides[1],
558 dilations[1], "width", "x", "left", "right")))
559 return failure();
560 }
561
562
563 if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
564 if (failed(verifyOutputSize(
565 inputType.getDimSize(1), weightType.getDimSize(0),
566 outputType.getDimSize(1), padding[0], padding[1], strides[0],
567 dilations[0], "height", "y", "top", "bottom")))
568 return failure();
569
570 if (failed(verifyOutputSize(
571 inputType.getDimSize(2), weightType.getDimSize(1),
572 outputType.getDimSize(2), padding[2], padding[3], strides[1],
573 dilations[1], "width", "x", "left", "right")))
574 return failure();
575 }
576
577
578 if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
579 if (failed(verifyOutputSize(
580 inputType.getDimSize(1), weightType.getDimSize(1),
581 outputType.getDimSize(1), padding[0], padding[1], strides[0],
582 dilations[0], "depth", "d", "front", "back")))
583 return failure();
584
585 if (failed(verifyOutputSize(
586 inputType.getDimSize(2), weightType.getDimSize(2),
587 outputType.getDimSize(2), padding[2], padding[3], strides[1],
588 dilations[1], "height", "y", "top", "bottom")))
589 return failure();
590
591 if (failed(verifyOutputSize(
592 inputType.getDimSize(3), weightType.getDimSize(3),
593 outputType.getDimSize(3), padding[4], padding[5], strides[2],
594 dilations[2], "width", "x", "left", "right")))
595 return failure();
596 }
597 }
598
599 const RankedTensorType biasType =
600 llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
601 if (!biasType)
602
603 return success();
604
605 const int64_t biasChannels = biasType.getDimSize(0);
606 const int64_t outputChannels =
607 outputType.getDimSize(outputType.getRank() - 1);
608 if (biasChannels == ShapedType::kDynamic ||
609 outputChannels == ShapedType::kDynamic)
610
611 return success();
612
613 if (biasChannels != outputChannels && biasChannels != 1)
614 return op.emitOpError(
615 "bias channels expected to be equal to output channels (")
616 << outputChannels << ") or 1, got " << biasChannels;
617
618 return success();
619 }
620
621
623 StringRef name1, Type type2,
624 StringRef name2) {
625 auto shapeType1 = dyn_cast(type1);
626 auto shapeType2 = dyn_cast(type2);
627 if (!shapeType1 || !shapeType2)
628 return failure();
629
630 auto elemType1 = shapeType1.getElementType();
631 auto elemType2 = shapeType2.getElementType();
632 if (elemType1 != elemType2)
634 << "require same element type for " << name1 << " (" << elemType1
635 << ") and " << name2 << " (" << elemType2 << ")";
636
639 << "require same shapes for " << name1 << " (" << type1 << ") and "
640 << name2 << " (" << type2 << ")";
641
642 return success();
643 }
644
645
647 StringRef name1,
649 StringRef name2) {
650 if (list1.size() != list2.size())
652 << "require same number of values in " << name1 << " ("
653 << list1.size() << ") and " << name2 << " (" << list2.size() << ")";
654
655 for (auto [type1, type2] :
658 return failure();
659 }
660
661 return success();
662 }
663
667 return success();
668
669 return shapeAdaptor.getNumElements() == 1 ? success() : failure();
670 }
671
672
673
675 StringRef symName) {
677 tosa::VariableOp varOp = nullptr;
678
679
680
681
682
683
684
685 module.walk([&](Operation *tempOp) {
686
687 if (tempOp == op) {
689 }
690
691 if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
692 if (symName == tosaOp.getName()) {
693 varOp = tosaOp;
695 }
696 }
697
699 });
700
701 if (varOp)
702 return varOp;
703
704 return failure();
705 }
706
707 template
709 StringRef symName = op.getName();
710 FailureOrtosa::VariableOp varOp = findVariableDecl(op, symName);
711 if (failed(varOp))
712 return op->emitOpError("'")
713 << symName << "' has not been declared by 'tosa.variable'";
714
715
718 "the input tensor")
719 .failed())
720 return failure();
721
722 return success();
723 }
724
725
726 template
728 auto inputType = llvm::dyn_cast(inType);
729 auto outputType = llvm::dyn_cast(outType);
730 if (!inputType) {
731 op.emitOpError("expect shaped tensor for input, got ") << inType;
732 return failure();
733 }
734 if (!outputType) {
735 op.emitOpError("expect shaped tensor for output, got ") << outType;
736 return failure();
737 }
738 auto inputElementType = inputType.getElementType();
739 auto outputElementType = outputType.getElementType();
740 auto inputQuantType =
741 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
742 auto outputQuantType =
743 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
744 if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
745 (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
746 inputElementType != outputElementType) {
747
748
749
750
751 op.emitOpError("expect input and output to have same element type, got ")
752 << inputElementType << " and " << outputElementType;
753 return failure();
754 }
755 return success();
756 }
757
759 const ShapedType resultType = llvm::cast<ShapedType>(getType());
760
761
762 if (const auto resultETy = resultType.getElementType();
763 !resultETy.isIntOrIndex())
764 return emitOpError("result tensor is not of integer type");
765
766 const auto inputType = llvm::cast(getInput().getType());
767 if (!inputType.hasRank())
768 return success();
769
770
771 const int64_t axis = getAxisAttr().getInt();
772 if (((axis < 0) || axis >= inputType.getRank()))
773 return emitOpError("specified axis is outside the rank of the tensor");
774
775 if (!resultType.hasRank())
776 return success();
777
781 expectedOutputShape.erase(expectedOutputShape.begin() + axis);
783 return emitOpError("expected output shape '")
784 << expectedOutputShape << "', got '" << outputShape << "'";
785
786 return success();
787 }
788
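// Shared verifier for pooling ops: kernel and stride values must be >= 1,
// padding must be >= 0 and smaller than the kernel in each spatial dimension,
// and static output sizes must equal
// (input + pad_before + pad_after - kernel) / stride + 1.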
789 template
792 if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
793 return op.emitOpError("expect all kernel values to be >= 1, got ")
794 << kernel;
795
797 if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
798 return op.emitOpError("expect all stride values to be >= 1, got ")
799 << strides;
800
802 if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
803 return op.emitOpError("expect all padding values to be >= 0, got ")
804 << padding;
805
806
807 const int64_t kernelX = kernel[1];
808 const int64_t padLeft = padding[2];
809 const int64_t padRight = padding[3];
810 if (padRight >= kernelX || padLeft >= kernelX)
811 return op.emitOpError("expected left/right padding to be less than the "
812 "width of the kernel, got pad_left=")
813 << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;
814
815 const int64_t kernelY = kernel[0];
816 const int64_t padTop = padding[0];
817 const int64_t padBottom = padding[1];
818 if (padTop >= kernelY || padBottom >= kernelY)
819 return op.emitOpError("expected top/bottom padding to be less than the "
820 "height of the kernel, got pad_top=")
821 << padTop << ", pad_bottom=" << padBottom
822 << ", kernel_y=" << kernelY;
823
824 const auto inputType =
825 llvm::dyn_cast(op.getInput().getType());
826 const auto outputType =
827 llvm::dyn_cast(op.getResult().getType());
828 if (!inputType || !outputType)
829 return success();
830
831 const auto verifyOutputSize =
832 [&op](const int64_t inputSize, const int64_t outputSize,
833 const int64_t kernelSize, const int64_t strideSize,
834 const int64_t padBefore, const int64_t padAfter,
835 const llvm::StringRef dimName, const llvm::StringRef dimAxis,
836 const llvm::StringRef padBeforeName,
837 const llvm::StringRef padAfterName) -> LogicalResult {
838 if (ShapedType::isDynamic(inputSize))
839 return success();
840
841 const std::optional<int64_t> calculatedOutSizeMinusOne =
842 idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
843 if (!calculatedOutSizeMinusOne.has_value())
844 return op.emitOpError("expected input_")
845 << dimName << " + pad_" << padBeforeName << " + pad_"
846 << padAfterName << " - kernel_" << dimAxis
847 << " to be wholly divisible by stride_" << dimAxis << ", got ("
848 << inputSize << " + " << padBefore << " + " << padAfter << " - "
849 << kernelSize << ") / " << strideSize;
850
851 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
852 if (!ShapedType::isDynamic(outputSize) && calculatedOutSize != outputSize)
853 return op.emitOpError("calculated output ")
854 << dimName << " did not match expected: "
855 << "calculated=" << calculatedOutSize
856 << ", expected=" << outputSize;
857
858 return success();
859 };
860
861 if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
862 kernel[0], strides[0], padding[0], padding[1],
863 "height", "y", "top", "bottom")))
864 return failure();
865
866 if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
867 kernel[1], strides[1], padding[2], padding[3],
868 "width", "x", "left", "right")))
869 return failure();
870
871 return success();
872 }
873
876 return failure();
877
882
883 auto accType = getAccType();
884 if (llvm::isa(inputETy) && !accType.isInteger(32))
885 return emitOpError("accumulator type for integer tensor is not i32");
886
887 if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
888 return emitOpError("accumulator type for f16 tensor is not f16/f32");
889
890 if (inputETy.isBF16() && !accType.isF32())
891 return emitOpError("accumulator type for bf16 tensor is not f32");
892
893 if (inputETy.isF32() && !accType.isF32())
894 return emitOpError("accumulator type for f32 tensor is not f32");
895
896 if (inputETy != inputZpETy)
897 return emitOpError("expect both input and its zero point are the same "
898 "element type, got ")
899 << inputETy << " and " << inputZpETy;
900
901 if (resultETy != outputZpETy)
902 return emitOpError("expect both output and its zero point are the same "
903 "element type, got ")
904 << resultETy << " and " << outputZpETy;
905
906 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
907 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
908 return failure();
909
910 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
911 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
912 return failure();
913
914 return success();
915 }
916
919 llvm::cast(getInput().getType()).getElementType();
920 if (auto quantType =
921 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
922 inputETy = quantType.getStorageType();
923 }
925 llvm::cast(getOutput().getType()).getElementType();
926 if (auto quantType =
927 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
928 outputETy = quantType.getStorageType();
929 }
930 if (inputETy != outputETy)
931 return emitOpError("input/output element types are incompatible.");
932
933 auto maxValAttr = getMaxValAttr();
934 auto minValAttr = getMinValAttr();
935
937
938 if (inputETy.isInteger(dataTypeBitWidth)) {
939
940
941
942 auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
943 auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
944 if (!intMaxValAttr || !intMinValAttr ||
945 (intMaxValAttr.getType() != intMinValAttr.getType()) ||
946 (intMaxValAttr.getType() != inputETy))
947 return emitOpError("min/max attributes types are incompatible with "
948 "input/output element types.");
949
950 const bool isUnsigned = cast<IntegerType>(inputETy).isUnsigned();
951 const APInt minVal = intMinValAttr.getValue();
952 const APInt maxVal = intMaxValAttr.getValue();
953 if (isUnsigned ? maxVal.ult(minVal) : maxVal.slt(minVal))
954 return emitOpError("expected min_val <= max_val, got min_val=")
955 << minValAttr << ", max_val=" << maxValAttr;
956 } else {
957
958
959
960 auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
961 auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
962 if (!floatMaxValAttr || !floatMinValAttr ||
963 (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
964 (floatMaxValAttr.getType() != inputETy))
965 return emitOpError("min/max attributes types are incompatible with "
966 "input/output element types.");
967
968 const APFloat minVal = floatMinValAttr.getValue();
969 const APFloat maxVal = floatMaxValAttr.getValue();
970 if (minVal.isNaN() || maxVal.isNaN())
971 return emitOpError("min/max attributes should not be 'NaN', got min_val=")
972 << minValAttr << ", max_val=" << maxValAttr;
973
974 if (maxVal < minVal)
975 return emitOpError("expected min_val <= max_val, got min_val=")
976 << minValAttr << ", max_val=" << maxValAttr;
977 }
978
979 return success();
980 }
981
982
983
984
985
986
987
988
994 TypeAttr accType) {
996 result.addOperands({input, weight, bias, zps.first, zps.second});
1001 Type finalOutputType = outputType;
1003 if (quantAttr) {
1004 finalOutputType =
1006 }
1007 result.addTypes(finalOutputType);
1008 }
1009
1010
1011
1012 static void
1018 result.addOperands({input, weight, bias, zps.first, zps.second});
1022 Type finalOutputType = outputType;
1024 if (quantAttr) {
1025 finalOutputType =
1027 }
1028 result.addTypes(finalOutputType);
1029 }
1030
1031
1032
1033
1034
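// The quantized MatMul builder widens the result element type to an
// accumulator type: i48 for 16-bit integer operands, i32 otherwise.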
1039 result.addOperands({a, b, zps.first, zps.second});
1040
1041 Type finalOutputType{outputType};
1044 auto inputBits = eType.getIntOrFloatBitWidth();
1045
1046 auto outputShapedType = llvm::dyn_cast(outputType);
1047 assert(outputShapedType && "Output must be a shaped type");
1048
1049 IntegerType accElementType;
1050 if (inputBits == 16)
1052 else
1053 accElementType = builder.getI32Type();
1054
1055 finalOutputType = outputShapedType.clone(accElementType);
1056 }
1057 result.addTypes(finalOutputType);
1058 }
1059
1060
1061
1062
1063 static void
1066 DenseArrayAttr kernel, DenseArrayAttr stride,
1067 DenseArrayAttr pad, TypeAttr accType) {
1069 int64_t inputZp{0};
1070 int64_t outputZp{0};
1071
1072 if (auto quantAttr =
1074 inputZp = quantAttr.getInputZp();
1075 outputZp = quantAttr.getOutputZp();
1076 }
1077 const std::optional inputZpOp =
1079 if (!inputZpOp) {
1081 loc,
1082 "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1083 }
1084 const std::optional outputZpOp =
1086 if (!outputZpOp) {
1087 (void)emitError(loc, "Failed to create output zero point tensor for "
1088 "quantized AVG_POOL2D op");
1089 }
1090
1091 if (inputZpOp && outputZpOp) {
1092 result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
1093 } else {
1094
1095
1096
1098 }
1103 result.types.push_back(outputType);
1104 }
1105
1106
1107
1108
1113 int64_t input1Zp{0};
1114 int64_t outputZp{0};
1116 if (quantAttr) {
1117 input1Zp = quantAttr.getInputZp();
1118 outputZp = quantAttr.getOutputZp();
1119 }
1120 const std::optional input1ZpOp =
1122 if (!input1ZpOp) {
1124 loc, "Failed to create input1 zero point for quantized NEGATE op");
1125 }
1126
1127 const std::optional outputZpOp =
1129 if (!outputZpOp) {
1131 loc, "Failed to create output zero point for quantized NEGATE op");
1132 }
1133
1134 if (input1ZpOp && outputZpOp) {
1135 result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1136 } else {
1137
1138
1139
1141 }
1142
1143 result.types.push_back(outputType);
1144 }
1145
1146
1147
1148
1151 Value paddings) {
1153 int32_t zp{0};
1155 if (quantAttr) {
1156 zp = static_cast<int32_t>(quantAttr.getInputZp());
1157 }
1159 result.addOperands({input, paddings, padConstOp});
1160 result.types.push_back(outputType);
1161 }
1162
1164 StringRef name, Type variableType,
1168
1169 auto shapedType = dyn_cast(variableType);
1170 if (!shapedType) {
1171 (void)emitError(loc, "variable type must be a shaped type");
1172 return;
1173 }
1174 if (!shapedType.hasRank()) {
1175 (void)emitError(loc, "variable type must be a ranked type");
1176 return;
1177 }
1178
1179 auto elementType = shapedType.getElementType();
1180 auto elementTypeAttr = TypeAttr::get(elementType);
1183
1185 result.addAttribute("var_shape", varShapeAttr);
1186 result.addAttribute("type", elementTypeAttr);
1187 result.addAttribute("initial_value", initialValue);
1188 }
1189
1190
1191
1192
1193
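// Resolve the broadcast output shape across all operands: ranks are aligned
// to the widest operand, size-1 dimensions broadcast against the other size,
// and mismatched non-1 static sizes cause the resolution to fail.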
1196 int64_t outRank = 0;
1197 for (int i = 0, e = operands.size(); i != e; ++i) {
1198 auto shape = operands.getShape(i);
1199 if (!shape.hasRank()) {
1200
1201
1202 return failure();
1203 }
1204 outRank = std::max<int64_t>(outRank, shape.getRank());
1205 }
1206
1207 outShape.resize(outRank, 1);
1208
1209 for (int i = 0, e = operands.size(); i != e; ++i) {
1210 auto shape = operands.getShape(i);
1211 auto rankDiff = outShape.size() - shape.getRank();
1212
1213 for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
1214 auto dim1 = outShape[i + rankDiff];
1215 auto dim2 = shape.getDimSize(i);
1216 auto resolvedDim = dim1;
1217
1218 if (dim1 == 1) {
1219 resolvedDim = dim2;
1220 } else if (dim2 == 1) {
1221 resolvedDim = dim1;
1222 } else if (dim1 != dim2) {
1223 return failure();
1224 }
1225 outShape[i + rankDiff] = resolvedDim;
1226 }
1227 }
1228
1229 return success();
1230 }
1231
1232 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1233 MLIRContext *context, ::std::optional location,
1234 ArgMaxOp::Adaptor adaptor,
1236 ShapeAdaptor inputShape(adaptor.getInput().getType());
1237 IntegerAttr axis = adaptor.getProperties().axis;
1238 int32_t axisVal = axis.getValue().getSExtValue();
1239
1240 if (!inputShape.hasRank()) {
1242 return success();
1243 }
1244
1246 outShape.reserve(inputShape.getRank() - 1);
1247 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1248 if (i == axisVal)
1249 continue;
1250 outShape.push_back(inputShape.getDimSize(i));
1251 }
1252
1254 return success();
1255 }
1256
1257 LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1258 MLIRContext *context, ::std::optional location,
1259 RFFT2dOp::Adaptor adaptor,
1261 ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1262
1263 if (!inputShape.hasRank())
1264 return failure();
1265
1267 outputShape.resize(3, ShapedType::kDynamic);
1268 outputShape[0] = inputShape.getDimSize(0);
1269 outputShape[1] = inputShape.getDimSize(1);
1270 int64_t inWidth = inputShape.getDimSize(2);
1271
1272
1273
1274 if (inWidth != ShapedType::kDynamic)
1275 outputShape[2] = inWidth / 2 + 1;
1276
1279
1280 return success();
1281 }
1282
1284 const llvm::StringRef dimName) {
1285 const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1286 if (!isPowerOfTwo)
1288 << dimName << " to be a power of two, got " << dimSize;
1289
1290 return success();
1291 }
1292
1294 const auto outputTypes = getResultTypes();
1296 return emitOpError("expected output shapes to match, got ") << outputTypes;
1297
1298 const auto inputType =
1299 llvm::dyn_cast(getInputReal().getType());
1300 if (!inputType)
1301 return success();
1302
1303 const int64_t height = inputType.getDimSize(1);
1304 if (!ShapedType::isDynamic(height) &&
1306 return failure();
1307
1308 const int64_t width = inputType.getDimSize(2);
1309 if (!ShapedType::isDynamic(width) &&
1311 return failure();
1312
1313 const auto outputType = llvm::dyn_cast(outputTypes[0]);
1314 if (!outputType)
1315 return success();
1316
1317
1319 outputType.getShape().drop_back())))
1320 return emitOpError("expected batch and height dimensions of input/output "
1321 "to match, got input=")
1322 << inputType << " output=" << outputType;
1323
1324
1325 const int64_t outputWidth = outputType.getDimSize(2);
1326 if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&
1327 (outputWidth != (width / 2) + 1))
1328 return emitOpError(
1329 "expected output width to be equal to input_width / 2 + 1, got ")
1330 << outputWidth;
1331
1332 return success();
1333 }
1334
1335 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1336 MLIRContext *context, ::std::optional location,
1337 FFT2dOp::Adaptor adaptor,
1339 inferredReturnShapes.push_back(
1341 inferredReturnShapes.push_back(
1343 return success();
1344 }
1345
1347 const auto inputRealType =
1348 llvm::dyn_cast(getInputReal().getType());
1349 const auto inputImagType =
1350 llvm::dyn_cast(getInputImag().getType());
1351 if (!inputRealType || !inputImagType)
1352 return success();
1353
1354 const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
1355 return ShapedType::isDynamic(a) ? a : b;
1356 };
1357
1358 const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1359 inputImagType.getDimSize(1));
1360 if (!ShapedType::isDynamic(height) &&
1362 return failure();
1363
1364 const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1365 inputImagType.getDimSize(2));
1366 if (!ShapedType::isDynamic(width) &&
1368 return failure();
1369
1370 return success();
1371 }
1372
1373 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1374 MLIRContext *context, ::std::optional location,
1375 ConcatOp::Adaptor adaptor,
1377
1378 const Properties &prop = adaptor.getProperties();
1379 int32_t axis = prop.axis.getValue().getSExtValue();
1381 bool hasRankedInput = false;
1382 for (auto operand : adaptor.getOperands()) {
1383 ShapeAdaptor operandShape(operand.getType());
1384 if (!operandShape.hasRank())
1385 continue;
1386
1387
1388 if (!hasRankedInput)
1389 outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1390
1391
1392 for (int i = 0, s = operandShape.getRank(); i < s; i++) {
1393 if (i == axis || operandShape.isDynamicDim(i))
1394 continue;
1395 if (outputShape[i] == ShapedType::kDynamic)
1396 outputShape[i] = operandShape.getDimSize(i);
1397 if (outputShape[i] != operandShape.getDimSize(i))
1399 "Cannot concat tensors with different sizes"
1400 " on the non-axis dimension ",
1401 i);
1402 }
1403
1404 hasRankedInput = true;
1405 }
1406
1407 if (adaptor.getInput1().empty())
1408 return failure();
1409
1410 Type inputType =
1411 llvm::cast(adaptor.getInput1().getType()[0]).getElementType();
1412 if (!hasRankedInput) {
1414 return success();
1415 }
1416
1417
1418 int64_t concatDimSize = 0;
1419 for (auto operand : adaptor.getOperands()) {
1420 ShapeAdaptor operandShape(operand.getType());
1421
1422
1423
1424 if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1425 concatDimSize = ShapedType::kDynamic;
1426 break;
1427 }
1428
1429 concatDimSize += operandShape.getDimSize(axis);
1430 }
1431
1432 outputShape[axis] = concatDimSize;
1433
1434 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1435 return success();
1436 }
1437
1439
1440 auto outType = getOutput().getType();
1442
1443
1444 if (inputList.empty())
1445 return emitOpError("expect at least one input");
1446
1447 if (!llvm::all_of(inputList, [&](auto input) {
1449 *this, input.getType(), outType));
1450 })) {
1451 return failure();
1452 }
1453
1454 const int32_t axis = getAxis();
1455 ShapeAdaptor firstRankedInputShape = nullptr;
1456 for (const auto &input : inputList) {
1457 const Type inputType = input.getType();
1459 if (currShape.hasRank()) {
1460 firstRankedInputShape = currShape;
1461
1462 if (axis < 0 || axis >= firstRankedInputShape.getRank())
1463 return emitOpError("expect axis to be within range 0 < axis < "
1464 "rank(input1[firstRankedTensorIdx]), got ")
1465 << axis;
1466 break;
1467 }
1468 }
1469
1470 const auto allOperandsHasRank = [](const Value input) {
1472 };
1473 if (llvm::all_of(inputList, allOperandsHasRank)) {
1474 const int64_t firstInputRank = firstRankedInputShape.getRank();
1475
1476 for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
1477 const ShapeAdaptor inputShape(input.getType());
1478 const int64_t inputRank = inputShape.getRank();
1479 const size_t operandNum = index + 1;
1480
1481
1482 if (inputRank != firstInputRank)
1483 return emitOpError(
1484 "expect all operands to have the same rank, but got ")
1485 << firstInputRank << " vs " << inputRank << " on operands 0 and "
1486 << operandNum;
1487
1488
1489 for (int i = 0; i < inputRank; i++) {
1490 const int64_t inputDim = inputShape.getDimSize(i);
1491 const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
1492 if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
1493 inputShape.isDynamicDim(i))
1494 continue;
1495 if (inputDim != firstInputDim)
1496 return emitOpError("expect all operand shapes to have the same sizes "
1497 "on non-axis dimensions, but got ")
1498 << inputDim << " vs " << firstInputDim << " at index " << i
1499 << " on operands 0 and " << operandNum;
1500 }
1501 }
1502
1503
1504 int64_t axisSum = 0;
1505 for (const auto &input : inputList) {
1506 const ShapeAdaptor inputShape(input.getType());
1507 if (inputShape.isDynamicDim(axis)) {
1508
1509 axisSum = -1;
1510 break;
1511 }
1512 axisSum += inputShape.getDimSize(axis);
1513 }
1515 if (axisSum >= 0 && outputShape.hasRank() &&
1516 !outputShape.isDynamicDim(axis) &&
1517 axisSum != outputShape.getDimSize(axis))
1518 return emitOpError("requires sum of axis dimensions of input1 "
1519 "equal to output axis dimension, got ")
1520 << axisSum << " and " << outputShape.getDimSize(axis);
1521 }
1522
1523 return success();
1524 }
1525
1526 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1527 MLIRContext *context, ::std::optional location,
1532
1536 return success();
1537 }
1538
1539 inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
1540 return success();
1541 }
1542
1543 bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1544 if (l.size() != r.size() || l.size() != 1)
1545 return false;
1547 }
1548
1549 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1550 MLIRContext *context, ::std::optional location,
1551 MatMulOp::Adaptor adaptor,
1553 ShapeAdaptor lhsShape(adaptor.getA().getType());
1554 ShapeAdaptor rhsShape(adaptor.getB().getType());
1555
1556
1558 outShape.resize(3, ShapedType::kDynamic);
1559
1560 if (lhsShape.hasRank()) {
1561 outShape[0] = lhsShape.getDimSize(0);
1562 outShape[1] = lhsShape.getDimSize(1);
1563 }
1564
1565 if (rhsShape.hasRank()) {
1566 outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
1567 : outShape[0];
1568 outShape[2] = rhsShape.getDimSize(2);
1569 }
1570
1572 return success();
1573 }
1574
1576 auto aType = llvm::dyn_cast(getA().getType());
1577 auto bType = llvm::dyn_cast(getB().getType());
1578
1579
1580 if (!aType)
1581 return emitOpError("expect a shaped tensor for input a, got ")
1582 << getA().getType();
1583
1584 if (!bType)
1585 return emitOpError("expect a shaped tensor for input b, got ")
1586 << getB().getType();
1587
1588 auto aElementType = aType.getElementType();
1589 auto bElementType = bType.getElementType();
1590
1591 auto aQuantizedEType =
1592 llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1593 auto bQuantizedEType =
1594 llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1595
1596 if (aQuantizedEType || bQuantizedEType) {
1597 if (!aQuantizedEType || !bQuantizedEType) {
1598 return emitOpError("expect operands to be both quantized or both not "
1599 "quantized, got ")
1600 << aElementType << " and " << bElementType;
1601 }
1602
1603 auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
1604 auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
1605 if (aQuantWidth != bQuantWidth) {
1606 return emitOpError("expect quantized operands to have same widths, got ")
1607 << aQuantWidth << " and " << bQuantWidth;
1608 }
1609 } else {
1610
1611 if (aElementType != bElementType) {
1612 return emitOpError("expect same element type for inputs a and b, got ")
1613 << aElementType << " and " << bElementType;
1614 }
1615 }
1616
1617
1620 if (aEType != aZpEType) {
1621 return emitOpError("expect input a and a_zp have the same "
1622 "element type, got ")
1623 << aEType << " and " << aZpEType;
1624 }
1625
1628 if (bEType != bZpEType) {
1629 return emitOpError("expect input b and b_zp have the same "
1630 "element type, got ")
1631 << bEType << " and " << bZpEType;
1632 }
1633
1634 FailureOr<int64_t> maybeAZp = getAZeroPoint();
1635 if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
1636 return failure();
1637
1638 FailureOr<int64_t> maybeBZp = getBZeroPoint();
1639 if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
1640 return failure();
1641
1642 return success();
1643 }
1644
1645 LogicalResult tosa::PadOp::inferReturnTypeComponents(
1646 MLIRContext *context, ::std::optional location,
1647 PadOp::Adaptor adaptor,
1649 ShapeAdaptor inputShape(adaptor.getInput1().getType());
1650 auto paddingRank =
1651 cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
1653
1654
1655
1656 if (!inputShape.hasRank()) {
1657 outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
1659 return success();
1660 }
1661
1663
1665 paddingValues)) {
1666 outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
1668 return success();
1669 }
1670
1671 outputShape.reserve(inputShape.getRank());
1672 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1673 if (inputShape.isDynamicDim(i)) {
1674 outputShape.push_back(ShapedType::kDynamic);
1675 continue;
1676 }
1677 auto padFront = paddingValues[i * 2];
1678 auto padBack = paddingValues[i * 2 + 1];
1679 if (padFront < 0 || padBack < 0) {
1680
1681 outputShape.push_back(ShapedType::kDynamic);
1682 continue;
1683 }
1684
1685 outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
1686 }
1687
1689 return success();
1690 }
1691
1694 getOutput().getType())
1695 .failed()) {
1696 return failure();
1697 }
1698
1699 if (auto padConst = getPadConst()) {
1701 getOutput().getType())
1702 .failed()) {
1703 return failure();
1704 }
1705 }
1706
1707 RankedTensorType inputType =
1708 llvm::dyn_cast(getInput1().getType());
1709 RankedTensorType outputType =
1710 llvm::dyn_cast(getOutput().getType());
1711 if (!inputType || !outputType)
1712 return success();
1713
1714 auto inputRank = inputType.getRank();
1715 auto outputRank = outputType.getRank();
1716 if (inputRank != outputRank)
1717 return emitOpError() << "expect same input and output tensor rank, but got "
1718 << "inputRank: " << inputRank
1719 << ", outputRank: " << outputRank;
1720
1723 return failure();
1724 }
1725
1726 auto paddingValues = paddingAttr.getValues();
1727 if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
1728 return emitOpError() << "padding tensor must have " << inputRank
1729 << " * 2 = " << inputRank * 2 << " elements, but got "
1730 << paddingValues.size();
1731
1732 auto inputShape = inputType.getShape();
1733 auto outputShape = outputType.getShape();
1734
1735 for (int64_t i = 0; i < inputRank; ++i) {
1736 int64_t padStart = paddingValues[i * 2].getSExtValue();
1737 int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
1738
1739 if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
1740 return emitOpError()
1741 << "invalid padding values at dimension " << i
1742 << ": values must be non-negative or -1 for dynamic padding, got ["
1743 << padStart << ", " << padEnd << "]";
1744 }
1745
1746
1747 if (inputShape[i] == ShapedType::kDynamic ||
1748 outputShape[i] == ShapedType::kDynamic)
1749 continue;
1750
1751 if (outputShape[i] != inputShape[i] + padStart + padEnd) {
1752 return emitOpError() << "mismatch in output shape at dimension " << i
1753 << ": expected " << inputShape[i] << " + "
1754 << padStart << " + " << padEnd << " = "
1755 << (inputShape[i] + padStart + padEnd)
1756 << ", but got " << outputShape[i];
1757 }
1758 }
1759
1760 return success();
1761 }
1762
1763 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
1764 MLIRContext *context, ::std::optional location,
1765 SliceOp::Adaptor adaptor,
1767
1771
1774 auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
1777 return success();
1778 }
1779
1780
1781
1782 ShapeAdaptor inputShape(adaptor.getInput1().getType());
1783
1785 if (inputShape.hasRank()) {
1786 for (size_t i = 0; i < size.size(); i++) {
1787 if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
1788 (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
1789 start[i] < inputShape.getDimSize(i))) {
1790
1791 if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
1792
1793 if (size[i] > 0) {
1794 outputShape[i] = size[i];
1795 }
1796 } else {
1797
1798 if (size[i] == -1) {
1799 outputShape[i] = inputShape.getDimSize(i) - start[i];
1800 } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
1801
1802 outputShape[i] = size[i];
1803 }
1804 }
1805 }
1806 }
1807 } else {
1809 }
1811 return success();
1812 }
1813
1816 getOutput().getType())
1817 .failed())
1818 return failure();
1819
1821 if (inputShape.hasRank()) {
1822 const auto inputRank = inputShape.getRank();
1824 if (outputShape.hasRank() && inputRank != outputShape.getRank())
1825 return emitOpError(
1826 "expect input1 and output to have the same ranks, got ")
1827 << inputRank << " and " << outputShape.getRank();
1828
1829 const auto startShapeRank =
1830 llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
1831 if (inputRank != startShapeRank)
1832 return emitOpError("length of start is not equal to rank of input shape");
1833
1834 const auto sizeShapeRank =
1835 llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
1836 if (inputRank != sizeShapeRank)
1837 return emitOpError("length of size is not equal to rank of input shape");
1838 }
1839
1840 return success();
1841 }
1842
1843 LogicalResult tosa::MulOp::inferReturnTypeComponents(
1844 MLIRContext *context, ::std::optional location,
1848
1853 } else {
1855 }
1856 return success();
1857 }
1858
1860 const Value output = getOutput();
1862
1863
1864
1865 if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
1866 IntegerType lhsIntType =
1868 IntegerType rhsIntType =
1870 if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
1871 return emitOpError("requires the same element type for all operands");
1872
1873
1874
1875
1876 if (lhsIntType.getWidth() > resIntType.getWidth())
1877 return emitOpError("invalid data type size for operands or result");
1878
1879 } else {
1880
1881
1882 for (int i = 0; i < 2; ++i) {
1884 return emitOpError(
1885 "requires the same element type for all operands and results");
1886 }
1887
1888
1889 ElementsAttr shift_elem;
1891 int32_t shift = shift_elem.getValues()[0].getInt();
1892 if (shift != 0) {
1893 return emitOpError() << "require shift to be 0 for float type";
1894 }
1895 }
1896 }
1897
1898
1899
1900
1901 TypeRange operandTypes = getOperandTypes();
1902 ShapedType aType = cast<ShapedType>(operandTypes[0]);
1903 ShapedType bType = cast<ShapedType>(operandTypes[1]);
1904
1905 const bool aHasRank = aType.hasRank();
1906 const bool bHasRank = bType.hasRank();
1907 if (aHasRank && bHasRank) {
1908 const int64_t aRank = aType.getRank();
1909 const int64_t bRank = bType.getRank();
1910 if (aRank != bRank)
1911 return emitOpError("a and b operands don't have matching ranks, got ")
1912 << aRank << " and " << bRank;
1913
1914
1917 aType.getShape(), bType.getShape(), resultShape))
1918 return emitOpError("a and b operands don't have broadcast-compatible "
1919 "shapes, got ")
1920 << aType << " and " << bType;
1921 }
1922
1923 ShapedType resultType = cast(output.getType());
1924 if (!resultType.hasRank())
1925 return success();
1926
1927 const int64_t resultRank = resultType.getRank();
1928 if (aHasRank && resultRank != aType.getRank())
1929 return emitOpError("result type has different rank than a, got ")
1930 << resultRank << " vs " << aType.getRank();
1931 if (bHasRank && resultRank != bType.getRank())
1932 return emitOpError("result type has different rank than b, got ")
1933 << resultRank << " vs " << bType.getRank();
1934
1935 return success();
1936 }
1937
1938 LogicalResult tosa::TableOp::inferReturnTypeComponents(
1939 MLIRContext *context, ::std::optional location,
1940 TableOp::Adaptor adaptor,
1942 ShapeAdaptor inputShape(adaptor.getInput1().getType());
1943
1944 if (!inputShape.hasRank()) {
1946 return success();
1947 }
1948
1949 inferredReturnShapes.resize(1);
1950 inputShape.getDims(inferredReturnShapes[0]);
1951 return success();
1952 }
1953
1955 TensorType inputType = getInput1().getType();
1956 TensorType outputType = getOutput().getType();
1957
1959 inputType.getRank() != outputType.getRank())
1960 return emitOpError()
1961 << "expected input tensor rank to equal result tensor rank";
1962
1963 auto inputDims = inputType.getShape();
1964 auto outputDims = outputType.getShape();
1965 for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
1966 int64_t dim = it.index();
1967 auto [inputDim, outputDim] = it.value();
1968 if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) {
1969 return emitOpError() << "dim(result, " << dim << ") = " << outputDim
1970 << " doesn't match dim(input, " << dim
1971 << ") = " << inputDim;
1972 }
1973 }
1974 return success();
1975 }
1976
1977 LogicalResult
1979
1982 return failure();
1983 multiples = llvm::to_vector(
1984 llvm::map_range(multiplesAttr.getValues<APInt>(),
1985 [](const APInt &val) { return val.getSExtValue(); }));
1986 return success();
1987 }
1988
1989 LogicalResult tosa::TileOp::inferReturnTypeComponents(
1990 MLIRContext *context, ::std::optional location,
1991 TileOp::Adaptor adaptor,
1996 multiples)) {
1997 auto rank =
1998 cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
2001 return success();
2002 } else {
2004 }
2005
2006 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2008 if (!inputShape.hasRank()) {
2009 outputShape.resize(multiples.size(), ShapedType::kDynamic);
2010 inferredReturnShapes.push_back(
2012 return success();
2013 } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
2014 return failure();
2015
2016
2017 outputShape.reserve(multiples.size());
2018 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2019 if (multiples[i] == ShapedType::kDynamic) {
2020 outputShape.push_back(ShapedType::kDynamic);
2021 } else {
2022 int64_t dim = inputShape.getDimSize(i);
2023 if (dim != ShapedType::kDynamic)
2024 dim *= multiples[i];
2025 outputShape.push_back(dim);
2026 }
2027 }
2028
2029 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2030 return success();
2031 }
2032
2035 getOutput().getType())
2036 .failed()) {
2037 return failure();
2038 }
2039 ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
2040 ShapedType outputType = llvm::cast<ShapedType>(getType());
2041
2042 shapeType multiplesType =
2043 llvm::cast<tosa::shapeType>(getMultiples().getType());
2044
2045 auto multiplesRank = multiplesType.getRank();
2046
2047 if (inputType.hasRank()) {
2048 if (inputType.getRank() != multiplesRank)
2049 return emitOpError("expect 'multiples' to have rank ")
2050 << inputType.getRank() << " but got " << multiplesRank << ".";
2051 if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
2052 return emitOpError("expect same input and output tensor rank.");
2053 } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2054 return emitOpError("expect 'multiples' array to have length ")
2055 << outputType.getRank() << " but got " << multiplesRank << ".";
2056
2058 if (getConstantMultiples(multiples).succeeded() &&
2059 llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
2060 return emitOpError(
2061 "expect element of 'multiples' to be positive integer or -1.");
2062
2063 return success();
2064 }
2065
2066 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
2067 if (l.size() != r.size() || l.size() != 1)
2068 return false;
2070 }
2071
2072 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2073 MLIRContext *context, ::std::optional location,
2074 ReshapeOp::Adaptor adaptor,
2076 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2080 newShapeValue)) {
2081 auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2084 return success();
2085 } else {
2087 }
2088
2089
2090
2091 if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2092 inferredReturnShapes.push_back(
2094 return success();
2095 }
2096
2097
2098
2099
2100 int64_t numElements = inputShape.getNumElements();
2101 int64_t staticMul = 1;
2102 for (auto val : newShapeValue) {
2103 if (!ShapedType::isDynamic(val)) {
2104 staticMul *= val;
2105 }
2106 }
2107
2108
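// Replace the single dynamic (-1) dimension with the remaining element
// count, e.g. reshaping 12 elements to {3, -1} infers the shape {3, 4}.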
2109 for (auto &val : newShapeValue) {
2110 if (ShapedType::isDynamic(val))
2111 val = numElements / staticMul;
2112 }
2113
2114 inferredReturnShapes.push_back(
2116 return success();
2117 }
2118
2121 getOutput().getType())
2122 .failed()) {
2123 return failure();
2124 }
2125 TensorType inputType = getInput1().getType();
2126
2129
2130 return mlir::success();
2131 }
2132
2133 int missingDims = llvm::count(shapeValues, -1);
2134 if (missingDims > 1)
2135 return emitOpError() << "expected at most one target dimension to be -1";
2136
2137 const auto outputType = dyn_cast(getType());
2138 if (!outputType)
2139 return success();
2140
2141 if ((int64_t)shapeValues.size() != outputType.getRank())
2142 return emitOpError() << "new shape does not match result rank";
2143
2144 for (auto [newShapeDim, outputShapeDim] :
2145 zip(shapeValues, outputType.getShape())) {
2146 if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
2147 outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2148 return emitOpError() << "new shape is inconsistent with result shape";
2149
2150 if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
2151 return emitOpError() << "new shape has invalid tensor dimension size "
2152 << newShapeDim;
2153 }
2154
2155 if (inputType.hasStaticShape()) {
2156 int64_t inputElementsNum = inputType.getNumElements();
2157 if (outputType.hasStaticShape()) {
2158 int64_t outputElementsNum = outputType.getNumElements();
2159 if (inputElementsNum != outputElementsNum) {
2160 return emitOpError() << "cannot reshape " << inputElementsNum
2161 << " elements into " << outputElementsNum;
2162 }
2163 }
2164
2165 int64_t newShapeElementsNum = std::accumulate(
2166 shapeValues.begin(), shapeValues.end(), 1LL,
2167 [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
2168 bool isStaticNewShape =
2169 llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
2170 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2171 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2172 return emitOpError() << "cannot reshape " << inputElementsNum
2173 << " elements into " << newShapeElementsNum;
2174 }
2175 }
2176
2177 return mlir::success();
2178 }
2179
2180
2181
2182
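// Extract a zero point from a constant zero-point tensor: fail if it is not
// a constant, return 0 for a zero float value, the sign- or zero-extended
// value for integer types, and -1 otherwise.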
2184 ElementsAttr zpAttr;
2186 return failure();
2187 }
2188
2189 Type zpElemType = zpAttr.getElementType();
2190
2191 if (llvm::isa<FloatType>(zpElemType)) {
2192 if (zpAttr.getValues<APFloat>()[0].isZero()) {
2193 return 0;
2194 }
2195
2196 return -1;
2197 }
2198
2199 if (llvm::isa<IntegerType>(zpElemType)) {
2200 if (signExtend)
2201 return zpAttr.getValues<APInt>()[0].getSExtValue();
2202 else
2203 return zpAttr.getValues<APInt>()[0].getZExtValue();
2204 }
2205
2206
2207 return -1;
2208 }
2209
2210 template
2212 const std::string &operand) {
2214
2215 if (!zpElemType.isInteger(8) && zp != 0) {
2216
2217 std::string lower = operand;
2218 std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
2219 return op.emitOpError()
2220 << lower << " zero point must be zero for non-int8 integer types";
2221 }
2222
2223 return success();
2224 }
2225
2227 const int64_t &zp,
2228 const std::string &operand) {
2229 bool isInputZp = (operand == "Input");
2230
2231 bool tensorUnsigned =
2232 isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
2233 StringRef tensorName = isInputZp ? "input" : "output";
2234
2236
2237 if (zp != 0) {
2239 !(zpElemType.isInteger(16) && tensorUnsigned)) {
2240 return op.emitOpError()
2241 << "expect " << tensorName << "_zp of 0, got " << zp;
2242 }
2243 if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
2244 return op.emitOpError() << "expect " << tensorName
2245 << "_zp of 0 or 32768 for unsigned int16 "
2246 << tensorName << ", got " << zp;
2247 }
2248 }
2249
2250 return success();
2251 }
2252
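// Generates the get<OPERAND_NAME>ZeroPoint / verify<OPERAND_NAME>ZeroPoint
// helpers for a given op's zero-point operand.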
2253 #define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \
2254 FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
2255 return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \
2256 } \
2257 LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
2258 return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
2259 }
2260
2277 #undef ZERO_POINT_HELPER
2278
2279 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
2280 MLIRContext *context, ::std::optional location,
2281 TransposeOp::Adaptor adaptor,
2283 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2284
2285
2286
2287 if (!inputShape.hasRank()) {
2289 return success();
2290 }
2291
2292 const auto inputRank = inputShape.getRank();
2293
2294
2295
2296 if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
2297 return failure();
2298 }
2299
2301
2302 if (inputRank == 0) {
2304 return success();
2305 }
2306
2307
2308 bool allTheSame = true;
2309 for (int i = 1, s = inputRank; i < s; i++) {
2311 allTheSame = false;
2312 break;
2313 }
2314 }
2315
2316
2317
2318 if (allTheSame) {
2319 outputShape.resize(inputRank, inputShape.getDimSize(0));
2321 return success();
2322 }
2323
2324 outputShape.resize(inputRank, ShapedType::kDynamic);
2325
2326
2327 if (llvm::any_of(adaptor.getPerms(),
2328 [inputRank](const auto i) { return i >= inputRank; }))
2329 return failure();
2330
2331 outputShape.reserve(inputRank);
2332 for (int i = 0, s = inputRank; i < s; i++) {
2333 outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
2334 }
2335
2337 return success();
2338 }
2339
2342 getOutput().getType())
2343 .failed()) {
2344 return failure();
2345 }
2346
2349
2351
2352 if (inputShape.hasRank() &&
2353 constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
2354 return emitOpError() << "expected perms attribute to have size "
2355 << inputShape.getRank()
2356 << " (input rank) but got size "
2357 << constantPerms.size();
2358
2359 if (inputShape.hasRank() && outputShape.hasRank() &&
2360 inputShape.getRank() != outputShape.getRank())
2361 return emitOpError()
2362 << "expected input tensor rank to equal result tensor rank";
2363
2364 if (outputShape.hasRank() &&
2365 constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
2366 return emitOpError() << "expected perms attribute to have size "
2367 << outputShape.getRank()
2368 << " (output rank) but got size "
2369 << constantPerms.size();
2370
2371 if (!llvm::all_of(constantPerms,
2372 [&constantPerms](int32_t s) {
2373 return s >= 0 &&
2374 static_cast<size_t>(s) < constantPerms.size();
2375 }) ||
2376 !isPermutationVector(llvm::to_vector(llvm::map_range(
2377 constantPerms, [](int32_t v) -> int64_t { return v; }))))
2378 return emitOpError() << "expected valid permutation indices";
2379
2380
2381 if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
2382 inputShape.getNumElements() != outputShape.getNumElements())
2383 return emitOpError() << "expected input1 and output to have same numbers "
2384 "of elements, got "
2385 << inputShape.getNumElements() << " and "
2386 << outputShape.getNumElements();
2387
2388
2389
2390 if (inputShape.hasRank() && outputShape.hasRank()) {
2391 for (auto i = 0; i < outputShape.getRank(); i++) {
2392 if (inputShape.isDynamicDim(constantPerms[i]) ||
2393 outputShape.isDynamicDim(i))
2394 continue;
2395
2396 if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
2397 return emitOpError()
2398 << "expected output tensor dim " << i << " to match "
2399 << "input dim " << constantPerms[i] << " with value of "
2400 << inputShape.getDimSize(constantPerms[i]);
2401 }
2402 }
2403
2404 return success();
2405 }
2406
2409
2411
2412 Value input = getInput1();
2413 auto inputType = cast(input.getType());
2414
2416 for (auto dim : transposePerms) {
2417 int32_t dimInInput = transposePerms[dim];
2418 if (inputType.isDynamicDim(dimInInput))
2419 returnedDims[dim] =
2420 builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
2421 .getResult();
2422 else
2423 returnedDims[dim] =
2424 builder.getIndexAttr(inputType.getDimSize(dimInInput));
2425 }
2426
2427 reifiedReturnShapes.emplace_back(std::move(returnedDims));
2428 return success();
2429 }
2430
2431 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
2432 MLIRContext *context, ::std::optional<Location> location,
2433 GatherOp::Adaptor adaptor,
2436 outputShape.resize(3, ShapedType::kDynamic);
2437
2438 ShapeAdaptor valuesShape(adaptor.getValues().getType());
2439 if (valuesShape.hasRank()) {
2440 outputShape[0] = valuesShape.getDimSize(0);
2441 outputShape[2] = valuesShape.getDimSize(2);
2442 }
2443
2444 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2445 if (indicesShape.hasRank()) {
2446 if (outputShape[0] == ShapedType::kDynamic)
2447 outputShape[0] = indicesShape.getDimSize(0);
2448 if (outputShape[1] == ShapedType::kDynamic)
2449 outputShape[1] = indicesShape.getDimSize(1);
2450 }
2451
2453 return success();
2454 }
2455
2458 getOutput().getType())
2459 .failed()) {
2460 return failure();
2461 }
2462
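// GatherOp operands follow the TOSA layout: values is [N, K, C], indices is
// [N, W], and the output is [N, W, C]; the checks below make sure the shared
// N, W and C dimensions agree wherever they are statically known.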
2466
2467 int64_t N = ShapedType::kDynamic;
2468 int64_t W = ShapedType::kDynamic;
2469 int64_t C = ShapedType::kDynamic;
2470
2471 if (valuesShape.hasRank()) {
2472 N = valuesShape.getDimSize(0);
2473 C = valuesShape.getDimSize(2);
2474 }
2475 if (indicesShape.hasRank()) {
2476 const int64_t indicesN = indicesShape.getDimSize(0);
2477 W = indicesShape.getDimSize(1);
2478 if (N == ShapedType::kDynamic)
2479 N = indicesN;
2480 else if (indicesN != ShapedType::kDynamic && N != indicesN)
2481 return emitOpError() << "requires indices dimension 0 to have size " << N
2482 << ", got " << indicesN;
2483 }
2484 if (outputShape.hasRank()) {
2485 const int64_t outputN = outputShape.getDimSize(0);
2486 const int64_t outputW = outputShape.getDimSize(1);
2487 const int64_t outputC = outputShape.getDimSize(2);
2488 if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2489 N != outputN)
2490 return emitOpError() << "requires output dimension 0 to have size " << N
2491 << ", got " << outputN;
2492
2493 if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
2494 W != outputW)
2495 return emitOpError() << "requires output dimension 1 to have size " << W
2496 << ", got " << outputW;
2497 if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2498 C != outputC)
2499 return emitOpError() << "requires output dimension 2 to have size " << C
2500 << ", got " << outputC;
2501 }
2502 return success();
2503 }
2504
2505 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
2506 MLIRContext *context, ::std::optional<Location> location,
2507 ResizeOp::Adaptor adaptor,
2510 outputShape.resize(4, ShapedType::kDynamic);
2511
2512 ShapeAdaptor inputShape(adaptor.getInput().getType());
2513 if (!inputShape.hasRank())
2514 return failure();
2515
2516 outputShape[0] = inputShape.getDimSize(0);
2517 outputShape[3] = inputShape.getDimSize(3);
2518 int64_t inputHeight = inputShape.getDimSize(1);
2519 int64_t inputWidth = inputShape.getDimSize(2);
2520
2521 if ((inputHeight == ShapedType::kDynamic) ||
2522 (inputWidth == ShapedType::kDynamic))
2523 return failure();
2524
2527 scaleInt) ||
2529 offsetInt) ||
2531 borderInt)) {
2532 return failure();
2533 }
2534
2535
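// Output spatial sizes follow the TOSA resize formula:
// out = ((in - 1) * scale_n - offset + border) / scale_d + 1.
// For example, a height of 8 with scale 2/1, offset 0 and border 0 gives
// ((8 - 1) * 2 - 0 + 0) / 1 + 1 = 15.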
2536 outputShape[1] =
2537 (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
2538 scaleInt[1]) +
2539 1;
2540
2541 outputShape[2] =
2542 (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
2543 scaleInt[3]) +
2544 1;
2545
2547 return success();
2548 }
2549
2551 const Value input = getInput();
2552 const Value output = getOutput();
2553 const RankedTensorType inputType =
2554 llvm::dyn_cast<RankedTensorType>(input.getType());
2555 const RankedTensorType outputType =
2556 llvm::dyn_cast<RankedTensorType>(output.getType());
2557
2564
2565 return success();
2566 }
2567
2568 if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
2569 return emitOpError("expect all scale values to be > 0, got ")
2570 << scaleValues;
2571
2572 const int64_t scaleYN = scaleValues[0];
2573 const int64_t scaleYD = scaleValues[1];
2574 const int64_t scaleXN = scaleValues[2];
2575 const int64_t scaleXD = scaleValues[3];
2576
2577 const int64_t offsetY = offsetValues[0];
2578 const int64_t offsetX = offsetValues[1];
2579
2580 const int64_t borderY = borderValues[0];
2581 const int64_t borderX = borderValues[1];
2582
2583 if (!inputType)
2584 return success();
2585 if (!outputType)
2586 return success();
2587
2588 const int64_t oh = outputType.getDimSize(1);
2589 const int64_t ow = outputType.getDimSize(2);
2590 const int64_t ih = inputType.getDimSize(1);
2591 const int64_t iw = inputType.getDimSize(2);
2592
2593
2594
2595
2596
2597 if (ih != ShapedType::kDynamic && ih != 1) {
2598 const std::optional<int64_t> calculatedOutHeightMinusOne =
2599 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
2600 if (!calculatedOutHeightMinusOne.has_value())
2601 return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
2602 "border_y ")
2603 << "to be wholly divisible by scale_y_d, got ((" << ih
2604 << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
2605 << ") / " << scaleYD;
2606 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
2607 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
2608 return emitOpError("calculated output height did not match expected: ")
2609 << "calculated=" << calculatedOutHeight << ", expected=" << oh;
2610 }
2611
2612
2613
2614
2615
2616 if (iw != ShapedType::kDynamic && iw != 1) {
2617 const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
2618 const std::optional<int64_t> calculatedOutWidthMinusOne =
2619 idivCheck(scaledInWidth, scaleXD);
2620 if (!calculatedOutWidthMinusOne.has_value())
2621 return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
2622 "border_x ")
2623 << "to be wholly divisible by scale_x_d, got ((" << iw
2624 << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
2625 << ") / " << scaleXD;
2626 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
2627 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
2628 return emitOpError("calculated output width did not match expected: ")
2629 << "calculated=" << calculatedOutWidth << ", expected=" << ow;
2630 }
2631
2632 return success();
2633 }
2634
2635 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
2636 MLIRContext *context, ::std::optional<Location> location,
2637 ScatterOp::Adaptor adaptor,
2640 outputShape.resize(3, ShapedType::kDynamic);
2641
2642 ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
2643 if (valuesInShape.hasRank()) {
2644 outputShape[0] = valuesInShape.getDimSize(0);
2645 outputShape[1] = valuesInShape.getDimSize(1);
2646 outputShape[2] = valuesInShape.getDimSize(2);
2647 }
2648
2649 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2650 if (indicesShape.hasRank()) {
2651 if (outputShape[0] == ShapedType::kDynamic)
2652 outputShape[0] = indicesShape.getDimSize(0);
2653 }
2654
2655 ShapeAdaptor inputShape(adaptor.getInput().getType());
2656 if (inputShape.hasRank()) {
2657 if (outputShape[0] == ShapedType::kDynamic)
2658 outputShape[0] = inputShape.getDimSize(0);
2659 if (outputShape[2] == ShapedType::kDynamic)
2660 outputShape[2] = inputShape.getDimSize(2);
2661 }
2662
2664 return success();
2665 }
2666
2669 getValuesOut().getType())
2670 .failed() ||
2672 getValuesOut().getType())
2673 .failed()) {
2674 return failure();
2675 }
2676
2681
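// ScatterOp shapes: values_in and values_out are [N, K, C], indices is [N, W],
// and input is [N, W, C]; the verifier additionally requires K >= W.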
2682 int64_t N = ShapedType::kDynamic;
2683 int64_t K = ShapedType::kDynamic;
2684 int64_t W = ShapedType::kDynamic;
2685 int64_t C = ShapedType::kDynamic;
2686 if (valuesInShape.hasRank()) {
2687 N = valuesInShape.getDimSize(0);
2688 K = valuesInShape.getDimSize(1);
2689 C = valuesInShape.getDimSize(2);
2690 }
2691 if (indicesShape.hasRank()) {
2692 const int64_t indicesN = indicesShape.getDimSize(0);
2693 W = indicesShape.getDimSize(1);
2694 if (N == ShapedType::kDynamic)
2695 N = indicesN;
2696 else if (indicesN != ShapedType::kDynamic && N != indicesN)
2697 return emitOpError() << "requires indices dimension 0 to have size " << N
2698 << ", got " << indicesN;
2699 }
2700 if (inputShape.hasRank()) {
2701 const int64_t inputN = inputShape.getDimSize(0);
2702 const int64_t inputW = inputShape.getDimSize(1);
2703 const int64_t inputC = inputShape.getDimSize(2);
2704 if (N == ShapedType::kDynamic)
2705 N = inputN;
2706 else if (inputN != ShapedType::kDynamic && N != inputN)
2707 return emitOpError() << "requires input dimension 0 to have size " << N
2708 << ", got " << inputN;
2709 if (W == ShapedType::kDynamic)
2710 W = inputW;
2711 else if (inputW != ShapedType::kDynamic && W != inputW)
2712 return emitOpError() << "requires input dimension 1 to have size " << W
2713 << ", got " << inputW;
2714
2715 if (C == ShapedType::kDynamic)
2716 C = inputC;
2717 else if (inputC != ShapedType::kDynamic && C != inputC)
2718 return emitOpError() << "requires input dimension 2 to have size " << C
2719 << ", got " << inputC;
2720 }
2721 if (outputShape.hasRank()) {
2722 const int64_t outputN = outputShape.getDimSize(0);
2723 const int64_t outputK = outputShape.getDimSize(1);
2724 const int64_t outputC = outputShape.getDimSize(2);
2725 if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2726 N != outputN)
2727 return emitOpError() << "requires values_out dimension 0 to have size "
2728 << N << ", got " << outputN;
2729 if (K == ShapedType::kDynamic)
2730 K = outputK;
2731 else if (outputK != ShapedType::kDynamic && K != outputK)
2732 return emitOpError() << "requires values_out dimension 1 to have size "
2733 << K << ", got " << outputK;
2734 if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2735 C != outputC)
2736 return emitOpError() << "requires values_out dimension 2 to have size "
2737 << C << ", got " << outputC;
2738 }
2739 if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
2740 return emitOpError() << "requires dimensions K >= W, got K=" << K
2741 << " and W=" << W;
2742
2743 return success();
2744 }
2745
2746 static LogicalResult ReduceInferReturnTypes(
2747 ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
2748 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2749 int64_t axisVal = axis.getValue().getSExtValue();
2750 if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
2752 return success();
2753 }
2754
2756 operandShape.getDims(outputShape);
2757 outputShape[axisVal] = 1;
2758 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2759 return success();
2760 }
2761
2762 #define COMPATIBLE_RETURN_TYPES(OP) \
2763 bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
2764 if (l.size() != r.size() || l.size() != 1) \
2765 return false; \
2766 if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
2767 return false; \
2768 return succeeded(verifyCompatibleShape(l[0], r[0])); \
2769 }
2770
2771 #define REDUCE_SHAPE_INFER(OP) \
2772 LogicalResult OP::inferReturnTypeComponents( \
2773 MLIRContext *context, ::std::optional<Location> location, \
2774 OP::Adaptor adaptor, \
2775 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
2776 Type inputType = \
2777 llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
2778 ShapeAdaptor inputShape(adaptor.getInput().getType()); \
2779 const Properties &prop = adaptor.getProperties(); \
2780 return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
2781 inferredReturnShapes); \
2782 } \
2783 COMPATIBLE_RETURN_TYPES(OP)
2784
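// REDUCE_SHAPE_INFER is instantiated for each reduction op; for example,
// REDUCE_SHAPE_INFER(tosa::ReduceMaxOp) would define ReduceMaxOp's
// inferReturnTypeComponents (the reduced axis becomes size 1) together with
// its isCompatibleReturnTypes predicate.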
2791 #undef REDUCE_SHAPE_INFER
2793 #undef COMPATIBLE_RETURN_TYPES
2794
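// Shared verifier for the reduce ops: the axis must be non-negative, must be a
// valid dimension of both the input and the output, and the reduced output
// dimension must have static size 1 when it is known.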
2795 template <typename T>
2796 static LogicalResult verifyReduceOp(T op) {
2797
2798 TensorType inputType = op.getInput().getType();
2799 TensorType outputType = op.getOutput().getType();
2800 int32_t reduceAxis = op.getAxis();
2801
2802 if (reduceAxis < 0) {
2803 op.emitOpError("reduce axis must not be negative");
2804 return failure();
2805 }
2806 if (inputType.hasRank()) {
2807 int64_t inputRank = inputType.getRank();
2808
2809
2810 if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
2811 op.emitOpError("expect input tensor rank (")
2812 << inputRank << ") to be larger than reduce axis (" << reduceAxis
2813 << ")";
2814 return failure();
2815 }
2816 }
2817 if (outputType.hasRank()) {
2818 int64_t outputRank = outputType.getRank();
2819 if (inputType.hasRank() && outputRank != inputType.getRank()) {
2820 op.emitOpError(
2821 "expect output tensor rank to be equal to input tensor rank");
2822 return failure();
2823 }
2824 if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
2825 op.emitOpError("expect output tensor rank (")
2826 << outputRank << ") to be larger than reduce axis (" << reduceAxis
2827 << ")";
2828 return failure();
2829 }
2830
2831
2832 if (outputRank != 0) {
2833 auto outputShape = outputType.getShape();
2834 if (!outputType.isDynamicDim(reduceAxis) &&
2835 outputShape[reduceAxis] != 1) {
2836 op.emitOpError("expect reduced dimension size to be 1, got ")
2837 << outputShape[reduceAxis];
2838 return failure();
2839 }
2840 }
2841 }
2842 return success();
2843 }
2844
2851
2858 } else {
2860 }
2861 return success();
2862 }
2863
2864 #define NARY_SHAPE_INFER(OP) \
2865 LogicalResult OP::inferReturnTypeComponents( \
2866 MLIRContext *context, ::std::optional<Location> location, \
2867 ValueShapeRange operands, DictionaryAttr attributes, \
2868 OpaqueProperties properties, RegionRange regions, \
2869 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
2870 return NAryInferReturnTypes(operands, inferredReturnShapes); \
2871 }
2872
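// NARY_SHAPE_INFER is instantiated for the elementwise ops; for example,
// NARY_SHAPE_INFER(tosa::AddOp) would give AddOp a shape-inference method that
// simply delegates to NAryInferReturnTypes over its operands.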
2910 #undef PRED_SHAPE_INFER
2911
2912 LogicalResult tosa::NegateOp::inferReturnTypeComponents(
2913 MLIRContext *context, ::std::optional<Location> location,
2914 NegateOp::Adaptor adaptor,
2916 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2918 return success();
2919 }
2920
2922
2923 const Type input1Type = getInput1().getType();
2924 const Type outputType = getOutput().getType();
2926 return failure();
2927
2928
2931 return emitOpError() << "requires the same shape for input1 and output";
2932
2934 const Type input1ZpEType =
2936 if (input1EType != input1ZpEType) {
2937 return emitOpError("expect both input1 and its zero point are the same "
2938 "element type, got ")
2939 << input1EType << " and " << input1ZpEType;
2940 }
2942 const Type outputZpEType =
2944 if (outputEType != outputZpEType) {
2945 return emitOpError("expect both output and its zero point are the same "
2946 "element type, got ")
2947 << outputEType << " and " << outputZpEType;
2948 }
2949
2950 FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
2951 if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
2952 return failure();
2953
2954 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2955 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
2956 return failure();
2957
2958 return success();
2959 }
2960
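// Pooling shape inference shared by avg_pool2d and max_pool2d: batch and
// channel dimensions pass through, and each static spatial output size is
// (in + pad_before + pad_after - kernel) / stride + 1; for example, height 7
// with no padding, kernel 3 and stride 2 gives (7 - 3) / 2 + 1 = 3.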
2966 outputShape.resize(4, ShapedType::kDynamic);
2967
2968
2969 if (!inputShape) {
2971 return success();
2972 }
2973
2974
2975 outputShape[0] = inputShape.getDimSize(0);
2976 outputShape[3] = inputShape.getDimSize(3);
2977
2978 int64_t height = inputShape.getDimSize(1);
2979 int64_t width = inputShape.getDimSize(2);
2980
2981 if (!ShapedType::isDynamic(height)) {
2982 int64_t padded = height + pad[0] + pad[1] - kernel[0];
2983 outputShape[1] = padded / stride[0] + 1;
2984 }
2985
2986 if (!ShapedType::isDynamic(width)) {
2987 int64_t padded = width + pad[2] + pad[3] - kernel[1];
2988 outputShape[2] = padded / stride[1] + 1;
2989 }
2990
2992 return success();
2993 }
2994
2995 LogicalResult Conv2DOp::inferReturnTypeComponents(
2996 MLIRContext *context, ::std::optional<Location> location,
2997 Conv2DOp::Adaptor adaptor,
3000
3001 int64_t inputWidth = ShapedType::kDynamic;
3002 int64_t inputHeight = ShapedType::kDynamic;
3003 int64_t weightWidth = ShapedType::kDynamic;
3004 int64_t weightHeight = ShapedType::kDynamic;
3005
3006
3007
3008 ShapeAdaptor inputShape(adaptor.getInput().getType());
3009 if (inputShape.hasRank()) {
3010 outputShape[0] = inputShape.getDimSize(0);
3011 inputHeight = inputShape.getDimSize(1);
3012 inputWidth = inputShape.getDimSize(2);
3013 }
3014
3015
3016 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3017 if (weightShape.hasRank()) {
3018 outputShape[3] = weightShape.getDimSize(0);
3019 weightHeight = weightShape.getDimSize(1);
3020 weightWidth = weightShape.getDimSize(2);
3021 }
3022
3023
3024 ShapeAdaptor biasShape(adaptor.getBias().getType());
3025 if (biasShape.hasRank()) {
3026 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3027 ? biasShape.getDimSize(0)
3028 : outputShape[3];
3029 }
3030
3034
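// Convolution output size: with an effective filter extent of
// (kernel - 1) * dilation + 1, the output is
// (in + pad_before + pad_after - effective_filter) / stride + 1; for example,
// height 8, pads 1 and 1, kernel 3, dilation 1, stride 2 gives
// (8 + 2 - 3) / 2 + 1 = 4.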
3035 if (!ShapedType::isDynamic(inputHeight) &&
3036 !ShapedType::isDynamic(weightHeight)) {
3037 int64_t inputSize = inputHeight + padding[0] + padding[1];
3038 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3039 int64_t unstridedResult = inputSize - filterSize + 1;
3040 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3041 }
3042
3043 if (!ShapedType::isDynamic(inputWidth) &&
3044 !ShapedType::isDynamic(weightWidth)) {
3045 int64_t inputSize = inputWidth + padding[2] + padding[3];
3046 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3047 int64_t unstridedResult = inputSize - filterSize + 1;
3048 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3049 }
3050
3052 return success();
3053 }
3054
3058 return failure();
3059 return success();
3060 }
3061
3062 LogicalResult Conv3DOp::inferReturnTypeComponents(
3063 MLIRContext *context, ::std::optional<Location> location,
3064 Conv3DOp::Adaptor adaptor,
3067
3068 int64_t inputWidth = ShapedType::kDynamic;
3069 int64_t inputHeight = ShapedType::kDynamic;
3070 int64_t inputDepth = ShapedType::kDynamic;
3071
3072 int64_t weightWidth = ShapedType::kDynamic;
3073 int64_t weightHeight = ShapedType::kDynamic;
3074 int64_t weightDepth = ShapedType::kDynamic;
3075
3076
3077 ShapeAdaptor inputShape(adaptor.getInput().getType());
3078 if (inputShape.hasRank()) {
3079 outputShape[0] = inputShape.getDimSize(0);
3080 inputDepth = inputShape.getDimSize(1);
3081 inputHeight = inputShape.getDimSize(2);
3082 inputWidth = inputShape.getDimSize(3);
3083 }
3084
3085
3086 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3087 if (weightShape.hasRank()) {
3088 outputShape[4] = weightShape.getDimSize(0);
3089 weightDepth = weightShape.getDimSize(1);
3090 weightHeight = weightShape.getDimSize(2);
3091 weightWidth = weightShape.getDimSize(3);
3092 }
3093
3094
3095 ShapeAdaptor biasShape(adaptor.getBias().getType());
3096 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
3097 outputShape[4] = biasShape.getDimSize(0);
3098 }
3099
3103
3104 if (!ShapedType::isDynamic(inputDepth) &&
3105 !ShapedType::isDynamic(weightDepth)) {
3106 int32_t inputSize = inputDepth + pad[0] + pad[1];
3107 int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
3108 int32_t unstridedResult = inputSize - filterSize + 1;
3109 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3110 }
3111
3112 if (!ShapedType::isDynamic(inputHeight) &&
3113 !ShapedType::isDynamic(weightHeight)) {
3114 int32_t inputSize = inputHeight + pad[2] + pad[3];
3115 int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
3116 int32_t unstridedResult = inputSize - filterSize + 1;
3117 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3118 }
3119
3120 if (!ShapedType::isDynamic(inputWidth) &&
3121 !ShapedType::isDynamic(weightWidth)) {
3122 int32_t inputSize = inputWidth + pad[4] + pad[5];
3123 int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
3124 int32_t unstridedResult = inputSize - filterSize + 1;
3125 outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
3126 }
3127
3129 return success();
3130 }
3131
3135 return failure();
3136 return success();
3137 }
3138
3139 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
3140 MLIRContext *context, ::std::optional<Location> location,
3141 AvgPool2dOp::Adaptor adaptor,
3143 ShapeAdaptor inputShape(adaptor.getInput().getType());
3144 const Properties &prop = adaptor.getProperties();
3146 inferredReturnShapes);
3147 }
3148
3149 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
3150 MLIRContext *context, ::std::optional<Location> location,
3151 MaxPool2dOp::Adaptor adaptor,
3153 ShapeAdaptor inputShape(adaptor.getInput().getType());
3154 const Properties &prop = adaptor.getProperties();
3156 inferredReturnShapes);
3157 }
3158
3161 getOutput().getType())))
3162 return failure();
3163
3165 return failure();
3166
3167 return success();
3168 }
3169
3170 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
3171 MLIRContext *context, ::std::optional<Location> location,
3172 DepthwiseConv2DOp::Adaptor adaptor,
3175
3176 int64_t inputWidth = ShapedType::kDynamic;
3177 int64_t inputHeight = ShapedType::kDynamic;
3178 int64_t inputChannels = ShapedType::kDynamic;
3179
3180 int64_t weightWidth = ShapedType::kDynamic;
3181 int64_t weightHeight = ShapedType::kDynamic;
3182 int64_t depthChannels = ShapedType::kDynamic;
3183
3184
3185 ShapeAdaptor inputShape(adaptor.getInput().getType());
3186 if (inputShape.hasRank()) {
3187 outputShape[0] = inputShape.getDimSize(0);
3188 inputHeight = inputShape.getDimSize(1);
3189 inputWidth = inputShape.getDimSize(2);
3190 inputChannels = inputShape.getDimSize(3);
3191 }
3192
3193
3194 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3195 if (weightShape.hasRank()) {
3196 weightHeight = weightShape.getDimSize(0);
3197 weightWidth = weightShape.getDimSize(1);
3198 inputChannels = ShapedType::isDynamic(inputChannels)
3199 ? weightShape.getDimSize(2)
3200 : inputChannels;
3201 depthChannels = weightShape.getDimSize(3);
3202 }
3203
3204
3205
3206 if (!ShapedType::isDynamic(inputChannels) &&
3207 !ShapedType::isDynamic(depthChannels)) {
3208 outputShape[3] = inputChannels * depthChannels;
3209 }
3210
3211
3212 ShapeAdaptor biasShape(adaptor.getBias().getType());
3213 if (biasShape.hasRank()) {
3214 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3215 ? biasShape.getDimSize(0)
3216 : outputShape[3];
3217 }
3218
3222
3223 if (!ShapedType::isDynamic(inputHeight) &&
3224 !ShapedType::isDynamic(weightHeight)) {
3225 int64_t inputSize = inputHeight + padding[0] + padding[1];
3226 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3227 int64_t unstridedResult = inputSize - filterSize + 1;
3228 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3229 }
3230
3231 if (!ShapedType::isDynamic(inputWidth) &&
3232 !ShapedType::isDynamic(weightWidth)) {
3233 int64_t inputSize = inputWidth + padding[2] + padding[3];
3234 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3235 int64_t unstridedResult = inputSize - filterSize + 1;
3236 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3237 }
3238
3240 return success();
3241 }
3242
3246 return failure();
3247 return success();
3248 }
3249
3250 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
3251 MLIRContext *context, ::std::optional<Location> location,
3252 TransposeConv2DOp::Adaptor adaptor,
3255
3256 int64_t inputWidth = ShapedType::kDynamic;
3257 int64_t inputHeight = ShapedType::kDynamic;
3258 int64_t weightWidth = ShapedType::kDynamic;
3259 int64_t weightHeight = ShapedType::kDynamic;
3260
3261
3262 ShapeAdaptor inputShape(adaptor.getInput().getType());
3263 if (inputShape.hasRank()) {
3264 outputShape[0] = ShapedType::isDynamic(outputShape[0])
3265 ? inputShape.getDimSize(0)
3266 : outputShape[0];
3267 inputHeight = inputShape.getDimSize(1);
3268 inputWidth = inputShape.getDimSize(2);
3269 }
3270
3271
3272 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3273 if (weightShape.hasRank()) {
3274 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3275 ? weightShape.getDimSize(0)
3276 : outputShape[3];
3277 weightHeight = weightShape.getDimSize(1);
3278 weightWidth = weightShape.getDimSize(2);
3279 }
3280
3281
3282 ShapeAdaptor biasShape(adaptor.getInput().getType());
3283 if (biasShape.hasRank()) {
3284 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3285 ? biasShape.getDimSize(0)
3286 : outputShape[3];
3287 }
3288
3291
3292 if (!ShapedType::isDynamic(inputHeight) &&
3293 !ShapedType::isDynamic(weightHeight)) {
3294 int64_t calculateSize =
3295 (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
3296 outputShape[1] =
3297 ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
3298 }
3299
3300 if (!ShapedType::isDynamic(inputWidth) &&
3301 !ShapedType::isDynamic(weightWidth)) {
3302 int64_t calculateSize =
3303 (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
3304 outputShape[2] =
3305 ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
3306 }
3307
3309 return success();
3310 }
3311
3314 return failure();
3315
3317 const int64_t strideY = strides[0];
3318 const int64_t strideX = strides[1];
3319
3320 if (strideY < 1 || strideX < 1)
3321 return emitOpError("expect all stride values to be >= 1, got [")
3322 << strides << "]";
3323
3324 const auto checkPadAgainstKernelDim =
3325 [this](int64_t pad_value, int64_t kernel_dim_size,
3326 llvm::StringRef pad_name,
3327 llvm::StringRef kernel_dim_name) -> LogicalResult {
3328 if (pad_value <= -kernel_dim_size)
3329 return emitOpError("expected ")
3330 << pad_name << " > -" << kernel_dim_name
3331 << ", but got: " << pad_name << "=" << pad_value << " and "
3332 << kernel_dim_name << "=" << kernel_dim_size;
3333 return success();
3334 };
3335
3337 const int64_t outPadTop = padding[0];
3338 const int64_t outPadBottom = padding[1];
3339 const int64_t outPadLeft = padding[2];
3340 const int64_t outPadRight = padding[3];
3341
3342 const auto weightType =
3343 llvm::dyn_cast<RankedTensorType>(getWeight().getType());
3344
3345 if (weightType) {
3346 const int64_t kernelHeight = weightType.getDimSize(1);
3347 if (!ShapedType::isDynamic(kernelHeight)) {
3348 if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
3349 "out_pad_top", "KH")))
3350 return failure();
3351
3352 if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
3353 "out_pad_bottom", "KH")))
3354 return failure();
3355 }
3356
3357 const int64_t kernelWidth = weightType.getDimSize(2);
3358 if (!ShapedType::isDynamic(kernelWidth)) {
3359 if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
3360 "out_pad_left", "KW")))
3361 return failure();
3362
3363 if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
3364 "out_pad_right", "KW")))
3365 return failure();
3366 }
3367 }
3368
3369
3370 const auto outputType =
3371 llvm::dyn_cast<RankedTensorType>(getOutput().getType());
3372 if (!outputType)
3373 return success();
3374
3375 const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
3376 if (inputType && weightType) {
3377 const int64_t inputHeight = inputType.getDimSize(1);
3378 const int64_t kernelHeight = weightType.getDimSize(1);
3379 const int64_t outputHeight = outputType.getDimSize(1);
3380
3381 if (!ShapedType::isDynamic(inputHeight) &&
3382 !ShapedType::isDynamic(outputHeight)) {
3383 if (outputHeight !=
3384 (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
3385 return emitOpError(
3386 "dimension mismatch: expected OH == (IH - 1) * stride_y "
3387 "+ out_pad_top + out_pad_bottom + KH, but got ")
3388 << outputHeight << " != (" << inputHeight << " - 1) * "
3389 << strideY << " + " << outPadTop << " + " << outPadBottom
3390 << " + " << kernelHeight;
3391 }
3392
3393 const int64_t inputWidth = inputType.getDimSize(2);
3394 const int64_t kernelWidth = weightType.getDimSize(2);
3395 const int64_t outputWidth = outputType.getDimSize(2);
3396
3397 if (!ShapedType::isDynamic(inputWidth) &&
3398 !ShapedType::isDynamic(outputWidth)) {
3399 if (outputWidth !=
3400 (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
3401 return emitOpError(
3402 "dimension mismatch: expected OW == (IW - 1) * stride_x "
3403 "+ out_pad_left + out_pad_right + KW, but got ")
3404 << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
3405 << " + " << outPadLeft << " + " << outPadRight << " + "
3406 << kernelWidth;
3407 }
3408 }
3409
3410 const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
3411
3412 if (!biasType)
3413 return success();
3414
3415 const int64_t biasChannels = biasType.getDimSize(0);
3416
3417
3418 if (biasChannels == ShapedType::kDynamic)
3419 return success();
3420
3421 const int64_t outputChannels = outputType.getDimSize(3);
3422 if (biasChannels != outputChannels && biasChannels != 1)
3423 return emitOpError(
3424 "bias channels expected to be equal to output channels (")
3425 << outputChannels << ") or 1, got " << biasChannels;
3426
3427 return success();
3428 }
3429
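// RescaleOp verification: input and output must be shaped tensors with integer
// storage element types, the zero points must satisfy the per-type rules
// above, the multiplier must be i32 (scale32 = true) or i16 (scale32 = false),
// and multiplier/shift must carry one entry per channel when per_channel is
// set.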
3430 LogicalResult tosa::RescaleOp::verify() {
3431 auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
3432 if (!inputType) {
3433 emitOpError("expect shaped tensor for input, got ") << getInput().getType();
3434 return failure();
3435 }
3436
3437 auto inputElementType =
3438 getStorageElementTypeOrSelf(inputType.getElementType());
3439 if (!mlir::isa<IntegerType>(inputElementType)) {
3440 emitOpError("expect input to have integer element type, got ")
3441 << inputElementType;
3442 return failure();
3443 }
3444
3445 auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
3446 if (!outputType) {
3447 emitOpError("expect shaped tensor for output, got ")
3448 << getOutput().getType();
3449 return failure();
3450 }
3451
3452 auto outputElementType =
3453 getStorageElementTypeOrSelf(outputType.getElementType());
3454 if (!mlir::isa<IntegerType>(outputElementType)) {
3455 emitOpError("expect output to have integer element type, got ")
3456 << outputElementType;
3457 return failure();
3458 }
3459
3461 .failed())
3462 return failure();
3463
3465 .failed())
3466 return failure();
3467
3468 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
3469 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
3470 return failure();
3471
3472 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3473 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3474 return failure();
3475
3476 auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
3477 if (!multiplierType) {
3478 emitOpError("expect shaped tensor for multiplier, got ")
3479 << getMultiplier().getType();
3480 return failure();
3481 }
3482
3483 auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
3484 if (!shiftType) {
3485 emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
3486 return failure();
3487 }
3488
3489
3490 if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
3491 emitOpError("expect i32 element type for multiplier for scale32=true, got ")
3492 << multiplierType.getElementType();
3493 return failure();
3494 }
3495
3496
3497 if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
3498 emitOpError(
3499 "expect i16 element type for multiplier for scale32=false, got ")
3500 << multiplierType.getElementType();
3501 return failure();
3502 }
3503
3504 if (!inputType.hasRank())
3505 return success();
3506
3507
3508
3509
3510 int64_t numChannels = 1;
3511 if (getPerChannel()) {
3512 if (inputType.getRank() < 1) {
3513 emitOpError("requires input to be at least rank 1 when per_channel is "
3514 "true, but got rank ")
3515 << inputType.getRank();
3516 return failure();
3517 }
3518 numChannels = inputType.getDimSize(inputType.getRank() - 1);
3519 }
3520
3521 if (!multiplierType.hasRank())
3522 return success();
3523
3525
3526 if (multiplierShape[0] != ShapedType::kDynamic &&
3527 multiplierShape[0] != numChannels) {
3528 emitOpError("expect shape of { ")
3529 << numChannels << " } for multiplier input, got { "
3530 << multiplierShape[0] << " }";
3531 return failure();
3532 }
3533
3534 if (!shiftType.hasRank())
3535 return success();
3536
3538
3539 if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
3540 emitOpError("expect shape of { ")
3541 << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
3542 return failure();
3543 }
3544
3545 return success();
3546 }
3547
3548 LogicalResult RescaleOp::inferReturnTypeComponents(
3549 MLIRContext *context, ::std::optional<Location> location,
3550 RescaleOp::Adaptor adaptor,
3552 ShapeAdaptor inputShape(adaptor.getInput().getType());
3554 return success();
3555 }
3556
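// Result shape inference for the control-flow ops: collect the tosa.yield
// terminators of each region and meet the ValueKnowledge of their operands, so
// the inferred result shapes are the most precise shapes common to every
// yield.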
3557 LogicalResult IfOp::inferReturnTypeComponents(
3558 MLIRContext *context, ::std::optional<Location> location,
3559 IfOp::Adaptor adaptor,
3562 for (Region *region : adaptor.getRegions()) {
3563 for (auto &block : *region)
3564 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3565 yieldOps.push_back(returnOp);
3566 }
3567
3568 if (yieldOps.empty())
3569 return failure();
3570
3571
3573 resultKnowledge.reserve(yieldOps.front().getNumOperands());
3574 for (auto operand : yieldOps.front().getOperands()) {
3575 resultKnowledge.push_back(
3577 }
3578
3579 for (auto yieldOp : yieldOps) {
3580 if (resultKnowledge.size() != yieldOp.getNumOperands())
3581 return failure();
3582
3583 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
3584 int32_t index = it.index();
3586 resultKnowledge[index],
3588 if (!meet)
3589 continue;
3590 resultKnowledge[index] = meet;
3591 }
3592 }
3593
3594 for (const ValueKnowledge &result : resultKnowledge) {
3595 inferredReturnShapes.push_back(result.getShapedTypeComponents());
3596 }
3597
3598 return success();
3599 }
3600
3601 LogicalResult WhileOp::inferReturnTypeComponents(
3602 MLIRContext *context, ::std::optional<Location> location,
3603 WhileOp::Adaptor adaptor,
3606 for (auto &block : adaptor.getBodyGraph())
3607 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3608 yieldOps.push_back(returnOp);
3609
3610
3611
3612 if (yieldOps.empty())
3613 return failure();
3614
3615
3617 resultKnowledge.reserve(yieldOps.front().getNumOperands());
3618 for (auto operand : yieldOps.front().getOperands()) {
3619 resultKnowledge.push_back(
3621 }
3622
3623 for (auto yieldOp : yieldOps) {
3624 if (resultKnowledge.size() != yieldOp.getNumOperands())
3625 return failure();
3626
3627 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
3628 int32_t index = it.index();
3630 resultKnowledge[index],
3632 resultKnowledge[index] = meet;
3633 }
3634 }
3635 }
3636
3637 for (const ValueKnowledge &result : resultKnowledge) {
3638 inferredReturnShapes.push_back(result.getShapedTypeComponents());
3639 }
3640
3641 return success();
3642 }
3643
3644 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
3645 if (auto vt = llvm::dyn_cast<VectorType>(getType()))
3646 return llvm::to_vector<4>(vt.getShape());
3647 return std::nullopt;
3648 }
3649
3650
3652
3653 result.regions.reserve(2);
3656
3657 auto &builder = parser.getBuilder();
3659
3663 return failure();
3664
3666 return failure();
3667
3668 if (parser.parseRegion(*thenRegion, {}, {}))
3669 return failure();
3670
3671
3673 if (parser.parseRegion(*elseRegion, {}, {}))
3674 return failure();
3675 }
3676
3677
3679 return failure();
3680 return success();
3681 }
3682
3684 bool printBlockTerminators = false;
3685
3686 p << " " << getCondition();
3687 if (!getResults().empty()) {
3688 p << " -> (" << getResultTypes() << ")";
3689
3690 printBlockTerminators = true;
3691 }
3692 p << ' ';
3694 false,
3695 printBlockTerminators);
3696
3697
3698 auto &elseRegion = getElseGraph();
3699 if (!elseRegion.empty()) {
3700 p << " else ";
3702 false,
3703 printBlockTerminators);
3704 }
3705
3707 }
3708
3711 "'then_graph' arguments", getInputList(),
3712 "'input_list'")
3713 .failed())
3714 return failure();
3715
3717 "'else_graph' arguments", getInputList(),
3718 "'input_list'")
3719 .failed())
3720 return failure();
3721
3722 auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
3724 "'then_graph' results", getOutputList(),
3725 "'output_list'")
3726 .failed())
3727 return failure();
3728
3729 auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
3731 "'else_graph' results", getOutputList(),
3732 "'output_list'")
3733 .failed())
3734 return failure();
3735
3736 auto condType = getCondition().getType();
3738 return emitOpError() << "'condition' must be a size 1 tensor, got "
3739 << condType;
3740
3741 return success();
3742 }
3743
3746 getOutputList(), "'output_list'")
3747 .failed())
3748 return failure();
3749
3751 "'cond_graph' arguments", getInputList(),
3752 "'input_list'")
3753 .failed())
3754 return failure();
3755
3757 "'body_graph' arguments", getInputList(),
3758 "'input_list'")
3759 .failed())
3760 return failure();
3761
3762 auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
3764 "'body_graph' results", getInputList(),
3765 "'input_list'")
3766 .failed())
3767 return failure();
3768
3769
3770
3771 auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
3772 if (condYield.getInputs().size() != 1)
3773 return emitOpError() << "require 'cond_graph' only have one result";
3774
3775 auto condOutType = condYield.getInputs()[0].getType();
3777 return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
3778 << condOutType;
3779
3781 return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
3782 << condOutType;
3783
3784 return success();
3785 }
3786
3789 getOutput().getType())
3790 .failed())
3791 return failure();
3792 TensorType inputType = getInput1().getType();
3793 TensorType outputType = getOutput().getType();
3794 int32_t reverseAxis = getAxis();
3795
3796 if (reverseAxis < 0)
3797 return emitOpError("expected non-negative reverse axis");
3798 if (inputType.hasRank()) {
3799 int64_t inputRank = inputType.getRank();
3800
3801
3802 if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
3803 return emitOpError("expect input tensor rank (")
3804 << inputRank << ") to be larger than reverse axis (" << reverseAxis
3805 << ")";
3806 }
3807 if (outputType.hasRank()) {
3808 int64_t outputRank = outputType.getRank();
3809 if (inputType.hasRank() && outputRank != inputType.getRank())
3810 return emitOpError(
3811 "expect output tensor rank to be equal to input tensor rank");
3812 if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
3813 return emitOpError("expect output tensor rank (")
3814 << outputRank << ") to be larger than reverse axis ("
3815 << reverseAxis << ")";
3816 }
3817 return success();
3818 }
3819
3821
3823 getOutput().getType())
3824 .failed() ||
3826 getOutput().getType())
3827 .failed()) {
3828 return failure();
3829 }
3830
3831 auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
3832 if (!predicateType) {
3833 return emitOpError("expect shaped tensor for input1, got ")
3834 << getInput1().getType();
3835 }
3836 auto predicateElementType = predicateType.getElementType();
3837 if (!predicateElementType.isInteger(1)) {
3838 return emitOpError("expect element type of bool for input1, got ")
3839 << predicateElementType;
3840 }
3841
3842 return success();
3843 }
3844
3846 StringRef symName = getName();
3847 FailureOrtosa::VariableOp varOp = findVariableDecl(*this, symName);
3848 if (succeeded(varOp))
3849 return emitOpError("illegal to have multiple declaration of '")
3850 << symName << "'";
3851
3852 return success();
3853 }
3854
3857 .failed())
3858 return failure();
3859
3860 return success();
3861 }
3862
3865 .failed())
3866 return failure();
3867
3868 return success();
3869 }
3870
3871
3877
3880 if (listResult.has_value() && failed(listResult.value()))
3881 return failure();
3882
3883 FunctionType functionType;
3886 return failure();
3887
3888 result.addTypes(functionType.getResults());
3889
3890 if (functionType.getNumInputs() != operands.size()) {
3891 return parser.emitError(typeLoc)
3892 << "expected as many input types as operands "
3893 << "(expected " << operands.size() << " got "
3894 << functionType.getNumInputs() << ")";
3895 }
3896
3897
3898 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3901 return failure();
3902
3903
3904 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3905 regionArgs[i].type = functionType.getInput(i);
3906
3907 return failure(parser.parseRegion(*cond, regionArgs) ||
3910 }
3911
3915 StringRef prefix = "") {
3916 assert(blocksArgs.size() == initializers.size() &&
3917 "expected same length of arguments and initializers");
3918 if (initializers.empty())
3919 return;
3920
3921 parser << prefix << '(';
3922 llvm::interleaveComma(
3923 llvm::zip(blocksArgs, initializers), parser,
3924 [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
3925 parser << ")";
3926 }
3927
3930 getInputList(), " ");
3931 parser << " : ";
3933 getResults().getTypes());
3934 parser << ' ';
3935 parser.printRegion(getCondGraph(), false);
3936 parser << " do ";
3939 }
3940
3941
3944 Type srcElemType,
3945 int64_t zp) {
3948 if (llvm::isa<FloatType>(srcElemType)) {
3950 zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
3951 return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
3952 }
3953 if (llvm::isa<IntegerType>(srcElemType)) {
3954 auto zpAttr =
3956 return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
3957 }
3958 llvm::errs() << "zero point is not allowed for unsupported data types\n";
3959 return std::nullopt;
3960 }
3961
3962
3963
3964
3965
3967 return mlir::isa<tosa::shapeType>(t);
3968 }
3969
3970 LogicalResult
3972 int rank) {
3973 if (rank < 0)
3974 return emitError() << "invalid rank (must be >= 0): " << rank;
3975 return success();
3976 }
3977
3980 if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
3981 Operation *definingOp = v.getDefiningOp();
3983 return op->emitOpError("shape operand is not compile time resolvable");
3984 }
3985 }
3986 }
3987 return success();
3988 }
3989
3992 if (!mlir::isa<mlir::tosa::shapeType>(type)) {
3993 return op->emitOpError("must have operands with tosa shape type");
3994 }
3995 }
3997 if (!mlir::isa<mlir::tosa::shapeType>(type)) {
3998 return op->emitOpError("must have result with tosa shape type");
3999 }
4000 }
4001 return success();
4002 }
4003
4004 LogicalResult
4008 return failure();
4009
4010
4011 auto getRank = [](const Type type) {
4012 return mlir::cast<mlir::tosa::shapeType>(type).getRank();
4013 };
4016
4018 for (auto type : operandTypes) {
4019 if (getRank(type) != rank) {
4020 return op->emitOpError("operands don't have matching ranks");
4021 }
4022 }
4023 for (auto type : resultTypes) {
4024 if (getRank(type) != rank) {
4025 return op->emitOpError("result shape has different rank than operands");
4026 }
4027 }
4028 return success();
4029 }
4030
4031
4032
4033
4034
4036
4037 auto valuesRank = getValues().getType().getRank();
4038 if (valuesRank != 1)
4039 return emitOpError("expect elements in attribute values with rank 1");
4040
4041 auto count = getValues().getNumElements();
4042 auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
4043 if (!(count == rank || (count == 1 && rank == 0))) {
4044 return emitOpError("expect number of elements in attribute values (")
4045 << count << ") to be equal to the rank (" << rank
4046 << ") for the result shape type";
4047 }
4048 return success();
4049 }
4050
4051
4052
4053
4054
4055 #define GET_ATTRDEF_CLASSES
4056 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
4057
4058
4059
4060
4061 #define GET_TYPEDEF_CLASSES
4062 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
4063
4064
4065
4066
4067
4068 #define GET_OP_CLASSES
4069 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
static Value getStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b)
The tosa.matmul op is also intended to be generated where a fully_connected op must be constructed wh...
static FailureOr< int64_t > getZeroPoint(Value val, bool signExtend)
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType)
static LogicalResult ReduceInferReturnTypes(ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val, Value valZp, StringRef name)
static FailureOr< tosa::VariableOp > findVariableDecl(Operation *op, StringRef symName)
static LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type)
#define REDUCE_SHAPE_INFER(OP)
static LogicalResult verifyConvOp(T op)
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name)
static LogicalResult poolingInferReturnTypes(ShapeAdaptor inputShape, ArrayRef< int64_t > kernel, ArrayRef< int64_t > stride, ArrayRef< int64_t > pad, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings)
This builder is called on TOSA pad operator that needs to create its own OptionalAttr quantization_at...
static void buildVariableOp(OpBuilder &builder, OperationState &result, StringRef name, Type variableType, Attribute initialValue)
static LogicalResult verifyReduceOp(T op)
#define NARY_SHAPE_INFER(OP)
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)
static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, TypeAttr accType)
Handles tosa.transpose_conv2d which has outpad and output shape attributes.
static LogicalResult verifyConvOpErrorIf(T op)
static LogicalResult verifyConvOpModes(T op)
std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
static LogicalResult NAryInferReturnTypes(const ValueShapeRange &operands, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
#define COMPATIBLE_RETURN_TYPES(OP)
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)
Type getStorageElementTypeOrSelf(Type type)
static void buildNegateOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input)
This builder is called on single-parameter negate operator to construct input and output zero points ...
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType)
This builder is called on all convolution operators except TransposeConv, which has specialized outpu...
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)
Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr but avg_pool operator has...
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1, StringRef name1, Type type2, StringRef name2)
static SmallVector< int64_t > convertToMlirShape(ArrayRef< int64_t > shape)
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, const std::string &operand)
static void printInitializationList(OpAsmPrinter &parser, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
static LogicalResult verifyPoolingOp(T op)
static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize, const llvm::StringRef dimName)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
MutableArrayRef< BlockArgument > BlockArgListType
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines a virtual interface for reading a bytecode stream, providing hooks into the byteco...
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
This class defines a virtual interface for writing to a bytecode stream, providing hooks into the byt...
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This class is used to represent the version of a dialect, for the purpose of polymorphic destruction.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
This is a utility class for mapping one set of IR entities to another.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class indicates that op operates on tosa shape types.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class provides an abstraction over the different types of ranges over Regions.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeComponents.
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates res with the dimensions of the referenced shape.
bool hasRank() const
Returns whether the shape has a rank.
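A short sketch of how the ShapeAdaptor queries above are typically combined in shape-inference code; the helper below is hypothetical and assumes the usual MLIR headers.
// Hypothetical helper: collect the indices of the dynamic dimensions of a
// shape using only hasRank/getRank/isDynamicDim from the entries above.
SmallVector<int64_t> getDynamicDimIndices(ShapeAdaptor shape) {
  SmallVector<int64_t> dynamicDims;
  if (!shape.hasRank())
    return dynamicDims;
  for (int i = 0, e = shape.getRank(); i < e; ++i)
    if (shape.isDynamicDim(i))
      dynamicDims.push_back(i);
  return dynamicDims;
}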
ShapedTypeComponents that represents the components of a ShapedType.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and UnrankedTensorType.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of the index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
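A sketch of the usual WalkResult idiom these two helpers support; the function below is hypothetical, not code from this file.
// Hypothetical example: walk an operation and stop at the first tosa.while,
// returning advance() for everything else so the traversal continues.
bool containsWhileOp(Operation *root) {
  WalkResult result = root->walk([](Operation *op) {
    if (isa<tosa::WhileOp>(op))
      return WalkResult::interrupt();
    return WalkResult::advance();
  });
  return result.wasInterrupted();
}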
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type below.
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
LogicalResult verifyTosaShapeOperator(Operation *op)
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)
LogicalResult verifyTosaResolvableShapeOperands(Operation *op)
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broadcast compatible.
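A sketch of how this broadcast helper is commonly used in a verifier, assuming the mlir::OpTrait::util declaration of the same signature; the wrapper below is hypothetical.
// Hypothetical verifier helper: compute the broadcasted shape of two ranked
// operand shapes, diagnosing incompatible shapes.
LogicalResult checkBroadcastable(Operation *op, ArrayRef<int64_t> shape1,
                                 ArrayRef<int64_t> shape2) {
  SmallVector<int64_t> resultShape;
  if (!OpTrait::util::getBroadcastedShape(shape1, shape2, resultShape))
    return op->emitOpError("operands are not broadcast compatible");
  return success();
}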
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder.
RankedTensorType getVariableType(VariableOp variableOp)
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
Construct ConvOp output type with the correct bitwidth based on input/weight width.
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr, Attribute &initialValueAttr)
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr, TypeAttr typeAttr, Attribute initialValueAttr)
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: input b zeropoint.
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
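A sketch of assumed usage for createZeroPointTensor; the wrapper and its name are hypothetical, not code from this file.
// Hypothetical helper: materialize a zero-valued zero-point tensor for an
// element type, propagating failure when the element type is unsupported.
FailureOr<Value> materializeZeroPoint(OpBuilder &builder, Location loc,
                                      Type srcElemType) {
  std::optional<Value> zp =
      createZeroPointTensor(builder, loc, srcElemType, /*zp=*/0);
  if (!zp)
    return failure();
  return *zp;
}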
bool isa_tosa_shape_type(mlir::Type t)
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
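A plausible implementation sketch of convertFromMlirShape, written here as an assumption rather than copied from the file: it maps MLIR's dynamic-dimension sentinel back to the -1 convention used by TOSA shape attributes.
// Assumed implementation sketch: ShapedType::kDynamic becomes -1, static
// dimensions pass through unchanged.
SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape) {
  return to_vector(llvm::map_range(shape, [](int64_t dim) {
    return ShapedType::isDynamic(dim) ? -1 : dim;
  }));
}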
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zeropoint.
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair-wise entries have compatible shape.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
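A sketch of assumed usage for verifyCompatibleShape inside a verifier; the wrapper below is hypothetical, not code from this file.
// Hypothetical helper: reject mismatched input/output shapes while still
// letting dynamic dimensions match anything.
LogicalResult checkShapesMatch(Operation *op, ArrayRef<int64_t> inShape,
                               ArrayRef<int64_t> outShape) {
  if (failed(verifyCompatibleShape(inShape, outShape)))
    return op->emitOpError("input and output shapes are not compatible");
  return success();
}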
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
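A sketch combining m_Constant with matchPattern from the entry above; the helper and its name are hypothetical and assume <optional> plus the usual MLIR headers.
// Hypothetical helper: if `value` is produced by a foldable constant op, bind
// and return its elements attribute.
std::optional<DenseElementsAttr> getConstantElements(Value value) {
  DenseElementsAttr elements;
  if (matchPattern(value, m_Constant(&elements)))
    return elements;
  return std::nullopt;
}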
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Statically known information for a particular Value.
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
static ValueKnowledge getKnowledgeFromType(Type type)
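A sketch of assumed usage for ValueKnowledge: meeting the knowledge derived from two types that must agree; the helper below is hypothetical, not code from this file.
// Hypothetical helper: conservatively combine what is statically known about
// two values, e.g. a region argument and the value branched to it.
ValueKnowledge combineKnowledge(Value a, Value b) {
  return ValueKnowledge::meet(ValueKnowledge::getKnowledgeFromType(a.getType()),
                              ValueKnowledge::getKnowledgeFromType(b.getType()));
}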