MLIR: lib/Dialect/Tosa/IR/TosaOps.cpp Source File

29 #include "llvm/ADT/APFloat.h"

30 #include "llvm/ADT/DenseMap.h"

31 #include "llvm/ADT/TypeSwitch.h"

32

33 #include <numeric>

34

35 using namespace mlir;

37

38 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"

40

41

42

43

44

45 #include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"

46 #include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"

47 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"

48 #include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"

49

50 namespace {

51 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"

52

53

54

55

58

59

60

61

62

63

66 return true;

67 }

68

69

72 return (isa<tosa::IfOp>(dest->getParentOp()) ||

73 isa<tosa::WhileOp>(dest->getParentOp()));

74 }

75 };

76

77

79 TosaDialectBytecodeInterface(Dialect *dialect)

81

82

83

84

86 return ::readAttribute(getContext(), reader);

87 }

88

89 LogicalResult writeAttribute(Attribute attr,

91 return ::writeAttribute(attr, writer);

92 }

93

94

95

96

98 return ::readType(getContext(), reader);

99 }

100

101 LogicalResult writeType(Type type,

103 return ::writeType(type, writer);

104 }

105

107

108 }

109

110 std::unique_ptr

112

113 reader.emitError("Dialect does not support versioning");

114 return nullptr;

115 }

116

117 LogicalResult upgradeFromVersion(Operation *topLevelOp,

119 return success();

120 }

121 };

122

123 }

124

125

126

127

128

129

131 return {&getBodyGraph()};

132 }

133

134

135

136

137
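// Maps TOSA's -1 "unknown dimension" convention onto MLIR's ShapedType::kDynamic.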

139 return to_vector(llvm::map_range(shape, [](int64_t dim) {

140 return dim == -1 ? ShapedType::kDynamic : dim;

141 }));

142 }

143

144

146 Type elementType = variableOp.getType();

148 auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));

150 }

151

152

153

154

155

156 void TosaDialect::initialize() {

157 addTypes<

158 #define GET_TYPEDEF_LIST

159 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"

160 >();

161 addOperations<

162 #define GET_OP_LIST

163 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"

164 >();

165 addAttributes<

166 #define GET_ATTRDEF_LIST

167 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

168 >();

169 addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();

170 declarePromisedInterfaces<

171 mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,

172 ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,

173 LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,

174 LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,

175 BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,

176 NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,

177 GreaterEqualOp, MatMulOp>();

178 }

179

182

183

184 if (llvm::isa<tosa::shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {

185 return builder.create<tosa::ConstShapeOp>(

186 loc, type, llvm::cast<DenseIntElementsAttr>(value));

187 }

188 if (llvm::isa<ElementsAttr>(value))

189 return builder.create<tosa::ConstOp>(loc, type,

190 llvm::cast<ElementsAttr>(value));

191 return nullptr;

192 }

193

194

195

196

197

198 namespace {

199

200 ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,

202 TypeAttr &typeAttr) {

203 if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {

204 if (!shapedType.hasRank())

206 << "expected ranked type";

207

208 auto elementType = shapedType.getElementType();

213 return success();

214 }

216 << "expected shaped type";

217 }

218

219 }

220

221

222

223

224

225

226

231 if (failed(parser.parseAttribute(initialValueAttr))) {

233 << "expected attribute";

234 }

235 if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {

236 return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,

237 typeAttr);

238 }

240 << "expected Typed attr";

241 }

242

243 initialValueAttr = nullptr;

244 Type parsedType;

247 << "expected type after colon";

248 }

249 return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);

250 }

251

254 TypeAttr typeAttr, Attribute initialValueAttr) {

255 bool needsSpace = false;

256 if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {

257 auto shape =

259 Type elementType = typeAttr.getValue();

260 RankedTensorType tensorType =

263 p << ": ";

265 needsSpace = true;

266 }

267 if (initialValueAttr) {

268 if (needsSpace)

269 p << ' ';

270 p << "= ";

272 }

273 }

274

275

276

277

278
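// Returns lhs / rhs when rhs evenly divides lhs, std::nullopt otherwise.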

279 std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {

280 if (lhs % rhs != 0)

281 return std::nullopt;

282 return lhs / rhs;

283 }

284

287 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))

288 srcType = quantType.getStorageType();

289 return srcType;

290 }

291

294 }

295

297 Value valZp, StringRef name) {

300

301 bool bothInts =

302 mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);

303 bool sameBitWidth =

305

306 if (!bothInts || !sameBitWidth) {

308 << "expected " << name << " and " << name

309 << "_zp to both be integer of the same bitwidth, but got " << eType

310 << " vs. " << eZpType;

311 }

312 return success();

313 }

314

315

317 Value src, int32_t val) {

322 const auto padConstAttr{

323 llvm::isa<FloatType>(srcElemType)

328 return builder.create<tosa::ConstOp>(loc, padConstType, padConstAttr);

329 }

330

331

332

333

334
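// Shared verifier for conv-style ops: checks input/weight/bias/result element types and the input/weight zero points.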

335 template

337 const auto inputType = llvm::dyn_cast(op.getInput().getType());

338 const auto weightType = llvm::dyn_cast(op.getWeight().getType());

339

340 auto inputEType = inputType.getElementType();

341 auto weightEType = weightType.getElementType();

342 auto biasEType =

343 llvm::cast<ShapedType>(op.getBias().getType()).getElementType();

344 auto resultEType =

345 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();

346 bool biasIsFloat = llvm::isa<FloatType>(biasEType);

347 bool resultIsFloat = llvm::isa<FloatType>(resultEType);

348

349 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))

350 inputEType = quantType.getStorageType();

351

352 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))

353 weightEType = quantType.getStorageType();

354

355 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))

356 biasEType = quantType.getStorageType();

357

358 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))

359 resultEType = quantType.getStorageType();

360

361 if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {

362

363

364 op.emitOpError(

365 "expect both bias and result to have same element type, got ")

366 << biasEType << " and " << resultEType;

367 return failure();

368 }

369

370 if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3Type>(inputEType) ||

371 isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3Type>(weightEType)) {

372 if (inputEType != weightEType) {

373 op.emitOpError(

374 "expect both input and weight to have same element type, got ")

375 << inputEType << " and " << weightEType;

376 return failure();

377 }

378 }

379

380 bool inputIsFloat = llvm::isa<FloatType>(inputEType);

381 bool weightIsFloat = llvm::isa<FloatType>(weightEType);

382

383

384 if (inputIsFloat != weightIsFloat) {

385 op.emitOpError(

386 "expect both input and weight to be float or not together, got ")

387 << inputEType << " and " << weightEType;

388 return failure();

389 }

390

392 if (inputEType != inputZpEType) {

393 return op.emitOpError("expect both input and its zero point are the same "

394 "element type, got ")

395 << inputEType << " and " << inputZpEType;

396 }

397

399 if (weightEType != weightZpEType) {

400 return op.emitOpError("expect both weight and its zero point are the same "

401 "element type, got ")

402 << weightEType << " and " << weightZpEType;

403 }

404

405 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();

406 if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())

407 return failure();

408

409 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();

410 if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())

411 return failure();

412

413 return success();

414 }

415

417

418 auto attrType = llvm::dyn_cast<ShapedType>(getValuesAttr().getType());

419 auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());

420

421 if (!attrType || !outputType) {

422 emitOpError("expected tensors for attr/result type");

423 return failure();

424 }

425

426 if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(

427 outputType.getElementType())) {

428 if (result.getStorageType() == attrType.getElementType())

429 return success();

430 }

431

432 if (attrType.getElementType() != outputType.getElementType()) {

433 emitOpError("expected same attr/result element types");

434 return failure();

435 }

436

437 return success();

438 }

439

440 template

442 auto inputEType =

443 llvm::cast<ShapedType>(op.getInput().getType()).getElementType();

444

445 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))

446 inputEType = quantType.getStorageType();

447

448 auto accType = op.getAccType();

449 if (inputEType.isInteger(8) && !accType.isInteger(32))

450 return op.emitOpError("accumulator type for i8 tensor is not i32");

451

452 if (inputEType.isInteger(16) && !accType.isInteger(48))

453 return op.emitOpError("accumulator type for i16 tensor is not i48");

454

455 if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())

456 return op.emitOpError("accumulator type for f8 tensor is not f16");

457

458 if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))

459 return op.emitOpError("accumulator type for f16 tensor is not f16/f32");

460

461 if (inputEType.isBF16() && !accType.isF32())

462 return op.emitOpError("accumulator type for bf16 tensor is not f32");

463

464 if (inputEType.isF32() && !accType.isF32())

465 return op.emitOpError("accumulator type for f32 tensor is not f32");

466

467 auto resultEType =

468 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();

469

470 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))

471 resultEType = quantType.getStorageType();

472

473 return success();

474 }

475

476

477

478

479

480

481 template

484 if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))

485 return op.emitOpError("expect all padding values to be >= 0, got ")

486 << padding;

487

489 if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))

490 return op.emitOpError("expect all stride values to be >= 1, got ")

491 << strides;

492

494 if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))

495 return op.emitOpError("expect all dilation values to be >= 1, got ")

496 << dilations;

497

498 const RankedTensorType outputType =

499 llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());

500 if (!outputType)

501

502 return success();

503

504 const RankedTensorType inputType =

505 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());

506 const RankedTensorType weightType =

507 llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

508

509 if (inputType && weightType) {

510 const auto verifyOutputSize =

511 [&op](const int64_t inputSize, const int64_t kernelSize,

512 const int64_t outputSize, const int64_t padBefore,

513 const int64_t padAfter, const int64_t stride,

514 const int64_t dilation, const llvm::StringRef dimName,

515 const llvm::StringRef dimAxis,

516 const llvm::StringRef padBeforeName,

517 const llvm::StringRef padAfterName) -> LogicalResult {

518 if (inputSize == ShapedType::kDynamic ||

519 kernelSize == ShapedType::kDynamic)

520 return success();

521

522

523

524 const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(

525 inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,

526 stride);

527 if (!calculatedOutSizeMinusOne.has_value())

528 return op.emitOpError("expected input_")

529 << dimName << " - 1 + pad_" << padBeforeName << " + pad_"

530 << padAfterName << " - (kernel_" << dimName

531 << " - 1) * dilation_" << dimAxis

532 << " to be wholly divisible by stride_" << dimAxis << ", got ("

533 << inputSize << " - 1 + " << padBefore << " + " << padAfter

534 << " - (" << kernelSize << " - 1) * " << dilation << ") / "

535 << stride;

536

537 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;

538 if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)

539 return op.emitOpError("calculated output ")

540 << dimName << " did not match expected: "

541 << "calculated=" << calculatedOutSize

542 << ", expected=" << outputSize;

543

544 return success();

545 };

546

547

548 if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {

549 if (failed(verifyOutputSize(

550 inputType.getDimSize(1), weightType.getDimSize(1),

551 outputType.getDimSize(1), padding[0], padding[1], strides[0],

552 dilations[0], "height", "y", "top", "bottom")))

553 return failure();

554

555 if (failed(verifyOutputSize(

556 inputType.getDimSize(2), weightType.getDimSize(2),

557 outputType.getDimSize(2), padding[2], padding[3], strides[1],

558 dilations[1], "width", "x", "left", "right")))

559 return failure();

560 }

561

562

563 if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {

564 if (failed(verifyOutputSize(

565 inputType.getDimSize(1), weightType.getDimSize(0),

566 outputType.getDimSize(1), padding[0], padding[1], strides[0],

567 dilations[0], "height", "y", "top", "bottom")))

568 return failure();

569

570 if (failed(verifyOutputSize(

571 inputType.getDimSize(2), weightType.getDimSize(1),

572 outputType.getDimSize(2), padding[2], padding[3], strides[1],

573 dilations[1], "width", "x", "left", "right")))

574 return failure();

575 }

576

577

578 if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {

579 if (failed(verifyOutputSize(

580 inputType.getDimSize(1), weightType.getDimSize(1),

581 outputType.getDimSize(1), padding[0], padding[1], strides[0],

582 dilations[0], "depth", "d", "front", "back")))

583 return failure();

584

585 if (failed(verifyOutputSize(

586 inputType.getDimSize(2), weightType.getDimSize(2),

587 outputType.getDimSize(2), padding[2], padding[3], strides[1],

588 dilations[1], "height", "y", "top", "bottom")))

589 return failure();

590

591 if (failed(verifyOutputSize(

592 inputType.getDimSize(3), weightType.getDimSize(3),

593 outputType.getDimSize(3), padding[4], padding[5], strides[2],

594 dilations[2], "width", "x", "left", "right")))

595 return failure();

596 }

597 }

598

599 const RankedTensorType biasType =

600 llvm::dyn_cast<RankedTensorType>(op.getBias().getType());

601 if (!biasType)

602

603 return success();

604

605 const int64_t biasChannels = biasType.getDimSize(0);

606 const int64_t outputChannels =

607 outputType.getDimSize(outputType.getRank() - 1);

608 if (biasChannels == ShapedType::kDynamic ||

609 outputChannels == ShapedType::kDynamic)

610

611 return success();

612

613 if (biasChannels != outputChannels && biasChannels != 1)

614 return op.emitOpError(

615 "bias channels expected to be equal to output channels (")

616 << outputChannels << ") or 1, got " << biasChannels;

617

618 return success();

619 }

620

621

623 StringRef name1, Type type2,

624 StringRef name2) {

625 auto shapeType1 = dyn_cast<ShapedType>(type1);

626 auto shapeType2 = dyn_cast<ShapedType>(type2);

627 if (!shapeType1 || !shapeType2)

628 return failure();

629

630 auto elemType1 = shapeType1.getElementType();

631 auto elemType2 = shapeType2.getElementType();

632 if (elemType1 != elemType2)

634 << "require same element type for " << name1 << " (" << elemType1

635 << ") and " << name2 << " (" << elemType2 << ")";

636

639 << "require same shapes for " << name1 << " (" << type1 << ") and "

640 << name2 << " (" << type2 << ")";

641

642 return success();

643 }

644

645

647 StringRef name1,

649 StringRef name2) {

650 if (list1.size() != list2.size())

652 << "require same number of values in " << name1 << " ("

653 << list1.size() << ") and " << name2 << " (" << list2.size() << ")";

654

655 for (auto [type1, type2] :

658 return failure();

659 }

660

661 return success();

662 }

663

667 return success();

668

669 return shapeAdaptor.getNumElements() == 1 ? success() : failure();

670 }

671

672

673
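// Searches the enclosing module for the tosa.variable declaration named symName, stopping once the querying op is reached.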

675 StringRef symName) {

677 tosa::VariableOp varOp = nullptr;

678

679

680

681

682

683

684

685 module.walk([&](Operation *tempOp) {

686

687 if (tempOp == op) {

689 }

690

691 if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {

692 if (symName == tosaOp.getName()) {

693 varOp = tosaOp;

695 }

696 }

697

699 });

700

701 if (varOp)

702 return varOp;

703

704 return failure();

705 }

706

707 template

709 StringRef symName = op.getName();

710 FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName);

711 if (failed(varOp))

712 return op->emitOpError("'")

713 << symName << "' has not been declared by 'tosa.variable'";

714

715

718 "the input tensor")

719 .failed())

720 return failure();

721

722 return success();

723 }

724

725

726 template

728 auto inputType = llvm::dyn_cast<ShapedType>(inType);

729 auto outputType = llvm::dyn_cast<ShapedType>(outType);

730 if (!inputType) {

731 op.emitOpError("expect shaped tensor for input, got ") << inType;

732 return failure();

733 }

734 if (!outputType) {

735 op.emitOpError("expect shaped tensor for output, got ") << outType;

736 return failure();

737 }

738 auto inputElementType = inputType.getElementType();

739 auto outputElementType = outputType.getElementType();

740 auto inputQuantType =

741 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);

742 auto outputQuantType =

743 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);

744 if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&

745 (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&

746 inputElementType != outputElementType) {

747

748

749

750

751 op.emitOpError("expect input and output to have same element type, got ")

752 << inputElementType << " and " << outputElementType;

753 return failure();

754 }

755 return success();

756 }

757

759 const ShapedType resultType = llvm::cast<ShapedType>(getType());

760

761

762 if (const auto resultETy = resultType.getElementType();

763 !resultETy.isIntOrIndex())

764 return emitOpError("result tensor is not of integer type");

765

766 const auto inputType = llvm::cast<ShapedType>(getInput().getType());

767 if (!inputType.hasRank())

768 return success();

769

770

771 const int64_t axis = getAxisAttr().getInt();

772 if (((axis < 0) || axis >= inputType.getRank()))

773 return emitOpError("specified axis is outside the rank of the tensor");

774

775 if (!resultType.hasRank())

776 return success();

777

781 expectedOutputShape.erase(expectedOutputShape.begin() + axis);

783 return emitOpError("expected output shape '")

784 << expectedOutputShape << "', got '" << outputShape << "'";

785

786 return success();

787 }

788
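// Shared verifier for pooling ops: validates kernel/stride/pad bounds and the computed output spatial sizes.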

789 template

792 if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))

793 return op.emitOpError("expect all kernel values to be >= 1, got ")

794 << kernel;

795

797 if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))

798 return op.emitOpError("expect all stride values to be >= 1, got ")

799 << strides;

800

802 if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))

803 return op.emitOpError("expect all padding values to be >= 0, got ")

804 << padding;

805

806

807 const int64_t kernelX = kernel[1];

808 const int64_t padLeft = padding[2];

809 const int64_t padRight = padding[3];

810 if (padRight >= kernelX || padLeft >= kernelX)

811 return op.emitOpError("expected left/right padding to be less than the "

812 "width of the kernel, got pad_left=")

813 << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;

814

815 const int64_t kernelY = kernel[0];

816 const int64_t padTop = padding[0];

817 const int64_t padBottom = padding[1];

818 if (padTop >= kernelY || padBottom >= kernelY)

819 return op.emitOpError("expected top/bottom padding to be less than the "

820 "height of the kernel, got pad_top=")

821 << padTop << ", pad_bottom=" << padBottom

822 << ", kernel_y=" << kernelY;

823

824 const auto inputType =

825 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());

826 const auto outputType =

827 llvm::dyn_cast<RankedTensorType>(op.getResult().getType());

828 if (!inputType || !outputType)

829 return success();

830

831 const auto verifyOutputSize =

832 [&op](const int64_t inputSize, const int64_t outputSize,

833 const int64_t kernelSize, const int64_t strideSize,

834 const int64_t padBefore, const int64_t padAfter,

835 const llvm::StringRef dimName, const llvm::StringRef dimAxis,

836 const llvm::StringRef padBeforeName,

837 const llvm::StringRef padAfterName) -> LogicalResult {

838 if (ShapedType::isDynamic(inputSize))

839 return success();

840

841 const std::optional<int64_t> calculatedOutSizeMinusOne =

842 idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);

843 if (!calculatedOutSizeMinusOne.has_value())

844 return op.emitOpError("expected input_")

845 << dimName << " + pad_" << padBeforeName << " + pad_"

846 << padAfterName << " - kernel_" << dimAxis

847 << " to be wholly divisible by stride_" << dimAxis << ", got ("

848 << inputSize << " + " << padBefore << " + " << padAfter << " - "

849 << kernelSize << ") / " << strideSize;

850

851 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;

852 if (!ShapedType::isDynamic(outputSize) && calculatedOutSize != outputSize)

853 return op.emitOpError("calculated output ")

854 << dimName << " did not match expected: "

855 << "calculated=" << calculatedOutSize

856 << ", expected=" << outputSize;

857

858 return success();

859 };

860

861 if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),

862 kernel[0], strides[0], padding[0], padding[1],

863 "height", "y", "top", "bottom")))

864 return failure();

865

866 if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),

867 kernel[1], strides[1], padding[2], padding[3],

868 "width", "x", "left", "right")))

869 return failure();

870

871 return success();

872 }

873

876 return failure();

877

882

883 auto accType = getAccType();

884 if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))

885 return emitOpError("accumulator type for integer tensor is not i32");

886

887 if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))

888 return emitOpError("accumulator type for f16 tensor is not f16/f32");

889

890 if (inputETy.isBF16() && !accType.isF32())

891 return emitOpError("accumulator type for bf16 tensor is not f32");

892

893 if (inputETy.isF32() && !accType.isF32())

894 return emitOpError("accumulator type for f32 tensor is not f32");

895

896 if (inputETy != inputZpETy)

897 return emitOpError("expect both input and its zero point are the same "

898 "element type, got ")

899 << inputETy << " and " << inputZpETy;

900

901 if (resultETy != outputZpETy)

902 return emitOpError("expect both output and its zero point are the same "

903 "element type, got ")

904 << resultETy << " and " << outputZpETy;

905

906 FailureOr<int64_t> maybeIZp = getInputZeroPoint();

907 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())

908 return failure();

909

910 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();

911 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())

912 return failure();

913

914 return success();

915 }

916

919 llvm::cast<ShapedType>(getInput().getType()).getElementType();

920 if (auto quantType =

921 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {

922 inputETy = quantType.getStorageType();

923 }

925 llvm::cast<ShapedType>(getOutput().getType()).getElementType();

926 if (auto quantType =

927 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {

928 outputETy = quantType.getStorageType();

929 }

930 if (inputETy != outputETy)

931 return emitOpError("input/output element types are incompatible.");

932

933 auto maxValAttr = getMaxValAttr();

934 auto minValAttr = getMinValAttr();

935

937

938 if (inputETy.isInteger(dataTypeBitWidth)) {

939

940

941

942 auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);

943 auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);

944 if (!intMaxValAttr || !intMinValAttr ||

945 (intMaxValAttr.getType() != intMinValAttr.getType()) ||

946 (intMaxValAttr.getType() != inputETy))

947 return emitOpError("min/max attributes types are incompatible with "

948 "input/output element types.");

949

950 const bool isUnsigned = cast<IntegerType>(inputETy).isUnsigned();

951 const APInt minVal = intMinValAttr.getValue();

952 const APInt maxVal = intMaxValAttr.getValue();

953 if (isUnsigned ? maxVal.ult(minVal) : maxVal.slt(minVal))

954 return emitOpError("expected min_val <= max_val, got min_val=")

955 << minValAttr << ", max_val=" << maxValAttr;

956 } else {

957

958

959

960 auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);

961 auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);

962 if (!floatMaxValAttr || !floatMinValAttr ||

963 (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||

964 (floatMaxValAttr.getType() != inputETy))

965 return emitOpError("min/max attributes types are incompatible with "

966 "input/output element types.");

967

968 const APFloat minVal = floatMinValAttr.getValue();

969 const APFloat maxVal = floatMaxValAttr.getValue();

970 if (minVal.isNaN() || maxVal.isNaN())

971 return emitOpError("min/max attributes should not be 'NaN', got min_val=")

972 << minValAttr << ", max_val=" << maxValAttr;

973

974 if (maxVal < minVal)

975 return emitOpError("expected min_val <= max_val, got min_val=")

976 << minValAttr << ", max_val=" << maxValAttr;

977 }

978

979 return success();

980 }

981

982

983

984

985

986

987

988

994 TypeAttr accType) {

996 result.addOperands({input, weight, bias, zps.first, zps.second});

1001 Type finalOutputType = outputType;

1003 if (quantAttr) {

1004 finalOutputType =

1006 }

1007 result.addTypes(finalOutputType);

1008 }

1009

1010

1011

1012 static void

1018 result.addOperands({input, weight, bias, zps.first, zps.second});

1022 Type finalOutputType = outputType;

1024 if (quantAttr) {

1025 finalOutputType =

1027 }

1028 result.addTypes(finalOutputType);

1029 }

1030

1031

1032

1033

1034

1039 result.addOperands({a, b, zps.first, zps.second});

1040

1041 Type finalOutputType{outputType};

1044 auto inputBits = eType.getIntOrFloatBitWidth();

1045

1046 auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);

1047 assert(outputShapedType && "Output must be a shaped type");

1048

1049 IntegerType accElementType;

1050 if (inputBits == 16)

1052 else

1053 accElementType = builder.getI32Type();

1054

1055 finalOutputType = outputShapedType.clone(accElementType);

1056 }

1057 result.addTypes(finalOutputType);

1058 }

1059

1060

1061

1062

1063 static void

1066 DenseArrayAttr kernel, DenseArrayAttr stride,

1067 DenseArrayAttr pad, TypeAttr accType) {

1069 int64_t inputZp{0};

1070 int64_t outputZp{0};

1071

1072 if (auto quantAttr =

1074 inputZp = quantAttr.getInputZp();

1075 outputZp = quantAttr.getOutputZp();

1076 }

1077 const std::optional inputZpOp =

1079 if (!inputZpOp) {

1081 loc,

1082 "Failed to create input zero point tensor for quantized AVG_POOL2D op");

1083 }

1084 const std::optional outputZpOp =

1086 if (!outputZpOp) {

1087 (void)emitError(loc, "Failed to create output zero point tensor for "

1088 "quantized AVG_POOL2D op");

1089 }

1090

1091 if (inputZpOp && outputZpOp) {

1092 result.addOperands({input, inputZpOp.value(), outputZpOp.value()});

1093 } else {

1094

1095

1096

1098 }

1103 result.types.push_back(outputType);

1104 }

1105

1106

1107

1108

1113 int64_t input1Zp{0};

1114 int64_t outputZp{0};

1116 if (quantAttr) {

1117 input1Zp = quantAttr.getInputZp();

1118 outputZp = quantAttr.getOutputZp();

1119 }

1120 const std::optional input1ZpOp =

1122 if (!input1ZpOp) {

1124 loc, "Failed to create input1 zero point for quantized NEGATE op");

1125 }

1126

1127 const std::optional outputZpOp =

1129 if (!outputZpOp) {

1131 loc, "Failed to create output zero point for quantized NEGATE op");

1132 }

1133

1134 if (input1ZpOp && outputZpOp) {

1135 result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});

1136 } else {

1137

1138

1139

1141 }

1142

1143 result.types.push_back(outputType);

1144 }

1145

1146

1147

1148

1151 Value paddings) {

1153 int32_t zp{0};

1155 if (quantAttr) {

1156 zp = static_cast<int32_t>(quantAttr.getInputZp());

1157 }

1159 result.addOperands({input, paddings, padConstOp});

1160 result.types.push_back(outputType);

1161 }

1162

1164 StringRef name, Type variableType,

1168

1169 auto shapedType = dyn_cast(variableType);

1170 if (!shapedType) {

1171 (void)emitError(loc, "variable type must be a shaped type");

1172 return;

1173 }

1174 if (!shapedType.hasRank()) {

1175 (void)emitError(loc, "variable type must be a ranked type");

1176 return;

1177 }

1178

1179 auto elementType = shapedType.getElementType();

1180 auto elementTypeAttr = TypeAttr::get(elementType);

1183

1185 result.addAttribute("var_shape", varShapeAttr);

1186 result.addAttribute("type", elementTypeAttr);

1187 result.addAttribute("initial_value", initialValue);

1188 }

1189

1190

1191

1192

1193
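// Resolves the NumPy-style broadcast shape across all ranked operands; fails on unranked operands or mismatched non-1 dimensions.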

1196 int64_t outRank = 0;

1197 for (int i = 0, e = operands.size(); i != e; ++i) {

1198 auto shape = operands.getShape(i);

1199 if (!shape.hasRank()) {

1200

1201

1202 return failure();

1203 }

1204 outRank = std::max<int64_t>(outRank, shape.getRank());

1205 }

1206

1207 outShape.resize(outRank, 1);

1208

1209 for (int i = 0, e = operands.size(); i != e; ++i) {

1210 auto shape = operands.getShape(i);

1211 auto rankDiff = outShape.size() - shape.getRank();

1212

1213 for (size_t i = 0, e = shape.getRank(); i < e; ++i) {

1214 auto dim1 = outShape[i + rankDiff];

1215 auto dim2 = shape.getDimSize(i);

1216 auto resolvedDim = dim1;

1217

1218 if (dim1 == 1) {

1219 resolvedDim = dim2;

1220 } else if (dim2 == 1) {

1221 resolvedDim = dim1;

1222 } else if (dim1 != dim2) {

1223 return failure();

1224 }

1225 outShape[i + rankDiff] = resolvedDim;

1226 }

1227 }

1228

1229 return success();

1230 }

1231

1232 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(

1233 MLIRContext *context, ::std::optional location,

1234 ArgMaxOp::Adaptor adaptor,

1236 ShapeAdaptor inputShape(adaptor.getInput().getType());

1237 IntegerAttr axis = adaptor.getProperties().axis;

1238 int32_t axisVal = axis.getValue().getSExtValue();

1239

1240 if (!inputShape.hasRank()) {

1242 return success();

1243 }

1244

1246 outShape.reserve(inputShape.getRank() - 1);

1247 for (int i = 0, s = inputShape.getRank(); i < s; i++) {

1248 if (i == axisVal)

1249 continue;

1250 outShape.push_back(inputShape.getDimSize(i));

1251 }

1252

1254 return success();

1255 }

1256

1257 LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(

1258 MLIRContext *context, ::std::optional location,

1259 RFFT2dOp::Adaptor adaptor,

1261 ShapeAdaptor inputShape(adaptor.getInputReal().getType());

1262

1263 if (!inputShape.hasRank())

1264 return failure();

1265

1267 outputShape.resize(3, ShapedType::kDynamic);

1268 outputShape[0] = inputShape.getDimSize(0);

1269 outputShape[1] = inputShape.getDimSize(1);

1270 int64_t inWidth = inputShape.getDimSize(2);

1271

1272

1273

1274 if (inWidth != ShapedType::kDynamic)

1275 outputShape[2] = inWidth / 2 + 1;

1276

1279

1280 return success();

1281 }

1282

1284 const llvm::StringRef dimName) {

1285 const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;

1286 if (!isPowerOfTwo)

1288 << dimName << " to be a power of two, got " << dimSize;

1289

1290 return success();

1291 }

1292

1294 const auto outputTypes = getResultTypes();

1296 return emitOpError("expected output shapes to match, got ") << outputTypes;

1297

1298 const auto inputType =

1299 llvm::dyn_cast(getInputReal().getType());

1300 if (!inputType)

1301 return success();

1302

1303 const int64_t height = inputType.getDimSize(1);

1304 if (!ShapedType::isDynamic(height) &&

1306 return failure();

1307

1308 const int64_t width = inputType.getDimSize(2);

1309 if (!ShapedType::isDynamic(width) &&

1311 return failure();

1312

1313 const auto outputType = llvm::dyn_cast(outputTypes[0]);

1314 if (!outputType)

1315 return success();

1316

1317

1319 outputType.getShape().drop_back())))

1320 return emitOpError("expected batch and height dimensions of input/output "

1321 "to match, got input=")

1322 << inputType << " output=" << outputType;

1323

1324

1325 const int64_t outputWidth = outputType.getDimSize(2);

1326 if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&

1327 (outputWidth != (width / 2) + 1))

1328 return emitOpError(

1329 "expected output width to be equal to input_width / 2 + 1, got ")

1330 << outputWidth;

1331

1332 return success();

1333 }

1334

1335 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(

1336 MLIRContext *context, ::std::optional location,

1337 FFT2dOp::Adaptor adaptor,

1339 inferredReturnShapes.push_back(

1341 inferredReturnShapes.push_back(

1343 return success();

1344 }

1345

1347 const auto inputRealType =

1348 llvm::dyn_cast(getInputReal().getType());

1349 const auto inputImagType =

1350 llvm::dyn_cast(getInputImag().getType());

1351 if (!inputRealType || !inputImagType)

1352 return success();

1353

1354 const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {

1355 return ShapedType::isDynamic(a) ? a : b;

1356 };

1357

1358 const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),

1359 inputImagType.getDimSize(1));

1360 if (!ShapedType::isDynamic(height) &&

1362 return failure();

1363

1364 const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),

1365 inputImagType.getDimSize(2));

1366 if (!ShapedType::isDynamic(width) &&

1368 return failure();

1369

1370 return success();

1371 }

1372

1373 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(

1374 MLIRContext *context, ::std::optional location,

1375 ConcatOp::Adaptor adaptor,

1377

1378 const Properties &prop = adaptor.getProperties();

1379 int32_t axis = prop.axis.getValue().getSExtValue();

1381 bool hasRankedInput = false;

1382 for (auto operand : adaptor.getOperands()) {

1383 ShapeAdaptor operandShape(operand.getType());

1384 if (!operandShape.hasRank())

1385 continue;

1386

1387

1388 if (!hasRankedInput)

1389 outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);

1390

1391

1392 for (int i = 0, s = operandShape.getRank(); i < s; i++) {

1393 if (i == axis || operandShape.isDynamicDim(i))

1394 continue;

1395 if (outputShape[i] == ShapedType::kDynamic)

1396 outputShape[i] = operandShape.getDimSize(i);

1397 if (outputShape[i] != operandShape.getDimSize(i))

1399 "Cannot concat tensors with different sizes"

1400 " on the non-axis dimension ",

1401 i);

1402 }

1403

1404 hasRankedInput = true;

1405 }

1406

1407 if (adaptor.getInput1().empty())

1408 return failure();

1409

1410 Type inputType =

1411 llvm::cast(adaptor.getInput1().getType()[0]).getElementType();

1412 if (!hasRankedInput) {

1414 return success();

1415 }

1416

1417

1418 int64_t concatDimSize = 0;

1419 for (auto operand : adaptor.getOperands()) {

1420 ShapeAdaptor operandShape(operand.getType());

1421

1422

1423

1424 if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {

1425 concatDimSize = ShapedType::kDynamic;

1426 break;

1427 }

1428

1429 concatDimSize += operandShape.getDimSize(axis);

1430 }

1431

1432 outputShape[axis] = concatDimSize;

1433

1434 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));

1435 return success();

1436 }

1437

1439

1440 auto outType = getOutput().getType();

1442

1443

1444 if (inputList.empty())

1445 return emitOpError("expect at least one input");

1446

1447 if (!llvm::all_of(inputList, [&](auto input) {

1449 *this, input.getType(), outType));

1450 })) {

1451 return failure();

1452 }

1453

1454 const int32_t axis = getAxis();

1455 ShapeAdaptor firstRankedInputShape = nullptr;

1456 for (const auto &input : inputList) {

1457 const Type inputType = input.getType();

1459 if (currShape.hasRank()) {

1460 firstRankedInputShape = currShape;

1461

1462 if (axis < 0 || axis >= firstRankedInputShape.getRank())

1463 return emitOpError("expect axis to be within range 0 < axis < "

1464 "rank(input1[firstRankedTensorIdx]), got ")

1465 << axis;

1466 break;

1467 }

1468 }

1469

1470 const auto allOperandsHasRank = [](const Value input) {

1472 };

1473 if (llvm::all_of(inputList, allOperandsHasRank)) {

1474 const int64_t firstInputRank = firstRankedInputShape.getRank();

1475

1476 for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {

1477 const ShapeAdaptor inputShape(input.getType());

1478 const int64_t inputRank = inputShape.getRank();

1479 const size_t operandNum = index + 1;

1480

1481

1482 if (inputRank != firstInputRank)

1483 return emitOpError(

1484 "expect all operands to have the same rank, but got ")

1485 << firstInputRank << " vs " << inputRank << " on operands 0 and "

1486 << operandNum;

1487

1488

1489 for (int i = 0; i < inputRank; i++) {

1490 const int64_t inputDim = inputShape.getDimSize(i);

1491 const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);

1492 if (i == axis || firstRankedInputShape.isDynamicDim(i) ||

1493 inputShape.isDynamicDim(i))

1494 continue;

1495 if (inputDim != firstInputDim)

1496 return emitOpError("expect all operand shapes to have the same sizes "

1497 "on non-axis dimensions, but got ")

1498 << inputDim << " vs " << firstInputDim << " at index " << i

1499 << " on operands 0 and " << operandNum;

1500 }

1501 }

1502

1503

1504 int64_t axisSum = 0;

1505 for (const auto &input : inputList) {

1506 const ShapeAdaptor inputShape(input.getType());

1507 if (inputShape.isDynamicDim(axis)) {

1508

1509 axisSum = -1;

1510 break;

1511 }

1512 axisSum += inputShape.getDimSize(axis);

1513 }

1515 if (axisSum >= 0 && outputShape.hasRank() &&

1516 !outputShape.isDynamicDim(axis) &&

1517 axisSum != outputShape.getDimSize(axis))

1518 return emitOpError("requires sum of axis dimensions of input1 "

1519 "equal to output axis dimension, got ")

1520 << axisSum << " and " << outputShape.getDimSize(axis);

1521 }

1522

1523 return success();

1524 }

1525

1526 LogicalResult tosa::EqualOp::inferReturnTypeComponents(

1527 MLIRContext *context, ::std::optional location,

1532

1536 return success();

1537 }

1538

1539 inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));

1540 return success();

1541 }

1542

1543 bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {

1544 if (l.size() != r.size() || l.size() != 1)

1545 return false;

1547 }

1548

1549 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(

1550 MLIRContext *context, ::std::optional location,

1551 MatMulOp::Adaptor adaptor,

1553 ShapeAdaptor lhsShape(adaptor.getA().getType());

1554 ShapeAdaptor rhsShape(adaptor.getB().getType());

1555

1556

1558 outShape.resize(3, ShapedType::kDynamic);

1559

1560 if (lhsShape.hasRank()) {

1561 outShape[0] = lhsShape.getDimSize(0);

1562 outShape[1] = lhsShape.getDimSize(1);

1563 }

1564

1565 if (rhsShape.hasRank()) {

1566 outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)

1567 : outShape[0];

1568 outShape[2] = rhsShape.getDimSize(2);

1569 }

1570

1572 return success();

1573 }

1574

1576 auto aType = llvm::dyn_cast<ShapedType>(getA().getType());

1577 auto bType = llvm::dyn_cast<ShapedType>(getB().getType());

1578

1579

1580 if (!aType)

1581 return emitOpError("expect a shaped tensor for input a, got ")

1582 << getA().getType();

1583

1584 if (!bType)

1585 return emitOpError("expect a shaped tensor for input b, got ")

1586 << getB().getType();

1587

1588 auto aElementType = aType.getElementType();

1589 auto bElementType = bType.getElementType();

1590

1591 auto aQuantizedEType =

1592 llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);

1593 auto bQuantizedEType =

1594 llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);

1595

1596 if (aQuantizedEType || bQuantizedEType) {

1597 if (!aQuantizedEType || !bQuantizedEType) {

1598 return emitOpError("expect operands to be both quantized or both not "

1599 "quantized, got ")

1600 << aElementType << " and " << bElementType;

1601 }

1602

1603 auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();

1604 auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();

1605 if (aQuantWidth != bQuantWidth) {

1606 return emitOpError("expect quantized operands to have same widths, got ")

1607 << aQuantWidth << " and " << bQuantWidth;

1608 }

1609 } else {

1610

1611 if (aElementType != bElementType) {

1612 return emitOpError("expect same element type for inputs a and b, got ")

1613 << aElementType << " and " << bElementType;

1614 }

1615 }

1616

1617

1620 if (aEType != aZpEType) {

1621 return emitOpError("expect input a and a_zp have the same "

1622 "element type, got ")

1623 << aEType << " and " << aZpEType;

1624 }

1625

1628 if (bEType != bZpEType) {

1629 return emitOpError("expect input b and b_zp have the same "

1630 "element type, got ")

1631 << bEType << " and " << bZpEType;

1632 }

1633

1634 FailureOr<int64_t> maybeAZp = getAZeroPoint();

1635 if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())

1636 return failure();

1637

1638 FailureOr<int64_t> maybeBZp = getBZeroPoint();

1639 if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())

1640 return failure();

1641

1642 return success();

1643 }

1644

1645 LogicalResult tosa::PadOp::inferReturnTypeComponents(

1646 MLIRContext *context, ::std::optional location,

1647 PadOp::Adaptor adaptor,

1649 ShapeAdaptor inputShape(adaptor.getInput1().getType());

1650 auto paddingRank =

1651 cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();

1653

1654

1655

1656 if (!inputShape.hasRank()) {

1657 outputShape.resize(paddingRank / 2, ShapedType::kDynamic);

1659 return success();

1660 }

1661

1663

1665 paddingValues)) {

1666 outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);

1668 return success();

1669 }

1670

1671 outputShape.reserve(inputShape.getRank());

1672 for (int i = 0, s = inputShape.getRank(); i < s; i++) {

1673 if (inputShape.isDynamicDim(i)) {

1674 outputShape.push_back(ShapedType::kDynamic);

1675 continue;

1676 }

1677 auto padFront = paddingValues[i * 2];

1678 auto padBack = paddingValues[i * 2 + 1];

1679 if (padFront < 0 || padBack < 0) {

1680

1681 outputShape.push_back(ShapedType::kDynamic);

1682 continue;

1683 }

1684

1685 outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);

1686 }

1687

1689 return success();

1690 }

1691

1694 getOutput().getType())

1695 .failed()) {

1696 return failure();

1697 }

1698

1699 if (auto padConst = getPadConst()) {

1701 getOutput().getType())

1702 .failed()) {

1703 return failure();

1704 }

1705 }

1706

1707 RankedTensorType inputType =

1708 llvm::dyn_cast<RankedTensorType>(getInput1().getType());

1709 RankedTensorType outputType =

1710 llvm::dyn_cast<RankedTensorType>(getOutput().getType());

1711 if (!inputType || !outputType)

1712 return success();

1713

1714 auto inputRank = inputType.getRank();

1715 auto outputRank = outputType.getRank();

1716 if (inputRank != outputRank)

1717 return emitOpError() << "expect same input and output tensor rank, but got "

1718 << "inputRank: " << inputRank

1719 << ", outputRank: " << outputRank;

1720

1723 return failure();

1724 }

1725

1726 auto paddingValues = paddingAttr.getValues<APInt>();

1727 if (paddingValues.size() != static_cast<size_t>(inputRank * 2))

1728 return emitOpError() << "padding tensor must have " << inputRank

1729 << " * 2 = " << inputRank * 2 << " elements, but got "

1730 << paddingValues.size();

1731

1732 auto inputShape = inputType.getShape();

1733 auto outputShape = outputType.getShape();

1734

1735 for (int64_t i = 0; i < inputRank; ++i) {

1736 int64_t padStart = paddingValues[i * 2].getSExtValue();

1737 int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();

1738

1739 if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {

1740 return emitOpError()

1741 << "invalid padding values at dimension " << i

1742 << ": values must be non-negative or -1 for dynamic padding, got ["

1743 << padStart << ", " << padEnd << "]";

1744 }

1745

1746

1747 if (inputShape[i] == ShapedType::kDynamic ||

1748 outputShape[i] == ShapedType::kDynamic)

1749 continue;

1750

1751 if (outputShape[i] != inputShape[i] + padStart + padEnd) {

1752 return emitOpError() << "mismatch in output shape at dimension " << i

1753 << ": expected " << inputShape[i] << " + "

1754 << padStart << " + " << padEnd << " = "

1755 << (inputShape[i] + padStart + padEnd)

1756 << ", but got " << outputShape[i];

1757 }

1758 }

1759

1760 return success();

1761 }

1762

1763 LogicalResult tosa::SliceOp::inferReturnTypeComponents(

1764 MLIRContext *context, ::std::optional location,

1765 SliceOp::Adaptor adaptor,

1767

1771

1774 auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();

1777 return success();

1778 }

1779

1780

1781

1782 ShapeAdaptor inputShape(adaptor.getInput1().getType());

1783

1785 if (inputShape.hasRank()) {

1786 for (size_t i = 0; i < size.size(); i++) {

1787 if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&

1788 (ShapedType::isDynamic(inputShape.getDimSize(i)) ||

1789 start[i] < inputShape.getDimSize(i))) {

1790

1791 if (ShapedType::isDynamic(inputShape.getDimSize(i))) {

1792

1793 if (size[i] > 0) {

1794 outputShape[i] = size[i];

1795 }

1796 } else {

1797

1798 if (size[i] == -1) {

1799 outputShape[i] = inputShape.getDimSize(i) - start[i];

1800 } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {

1801

1802 outputShape[i] = size[i];

1803 }

1804 }

1805 }

1806 }

1807 } else {

1809 }

1811 return success();

1812 }

1813

1816 getOutput().getType())

1817 .failed())

1818 return failure();

1819

1821 if (inputShape.hasRank()) {

1822 const auto inputRank = inputShape.getRank();

1824 if (outputShape.hasRank() && inputRank != outputShape.getRank())

1825 return emitOpError(

1826 "expect input1 and output to have the same ranks, got ")

1827 << inputRank << " and " << outputShape.getRank();

1828

1829 const auto startShapeRank =

1830 llvm::cast<tosa::shapeType>(getStart().getType()).getRank();

1831 if (inputRank != startShapeRank)

1832 return emitOpError("length of start is not equal to rank of input shape");

1833

1834 const auto sizeShapeRank =

1835 llvm::cast<tosa::shapeType>(getSize().getType()).getRank();

1836 if (inputRank != sizeShapeRank)

1837 return emitOpError("length of size is not equal to rank of input shape");

1838 }

1839

1840 return success();

1841 }

1842

1843 LogicalResult tosa::MulOp::inferReturnTypeComponents(

1844 MLIRContext *context, ::std::optional location,

1848

1853 } else {

1855 }

1856 return success();

1857 }

1858

1860 const Value output = getOutput();

1862

1863

1864

1865 if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {

1866 IntegerType lhsIntType =

1868 IntegerType rhsIntType =

1870 if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)

1871 return emitOpError("requires the same element type for all operands");

1872

1873

1874

1875

1876 if (lhsIntType.getWidth() > resIntType.getWidth())

1877 return emitOpError("invalid data type size for operands or result");

1878

1879 } else {

1880

1881

1882 for (int i = 0; i < 2; ++i) {

1884 return emitOpError(

1885 "requires the same element type for all operands and results");

1886 }

1887

1888

1889 ElementsAttr shift_elem;

1891 int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();

1892 if (shift != 0) {

1893 return emitOpError() << "require shift to be 0 for float type";

1894 }

1895 }

1896 }

1897

1898

1899

1900

1901 TypeRange operandTypes = getOperandTypes();

1902 ShapedType aType = cast<ShapedType>(operandTypes[0]);

1903 ShapedType bType = cast<ShapedType>(operandTypes[1]);

1904

1905 const bool aHasRank = aType.hasRank();

1906 const bool bHasRank = bType.hasRank();

1907 if (aHasRank && bHasRank) {

1908 const int64_t aRank = aType.getRank();

1909 const int64_t bRank = bType.getRank();

1910 if (aRank != bRank)

1911 return emitOpError("a and b operands don't have matching ranks, got ")

1912 << aRank << " and " << bRank;

1913

1914

1917 aType.getShape(), bType.getShape(), resultShape))

1918 return emitOpError("a and b operands don't have broadcast-compatible "

1919 "shapes, got ")

1920 << aType << " and " << bType;

1921 }

1922

1923 ShapedType resultType = cast(output.getType());

1924 if (!resultType.hasRank())

1925 return success();

1926

1927 const int64_t resultRank = resultType.getRank();

1928 if (aHasRank && resultRank != aType.getRank())

1929 return emitOpError("result type has different rank than a, got ")

1930 << resultRank << " vs " << aType.getRank();

1931 if (bHasRank && resultRank != bType.getRank())

1932 return emitOpError("result type has different rank than b, got ")

1933 << resultRank << " vs " << bType.getRank();

1934

1935 return success();

1936 }

1937

1938 LogicalResult tosa::TableOp::inferReturnTypeComponents(

1939 MLIRContext *context, ::std::optional location,

1940 TableOp::Adaptor adaptor,

1942 ShapeAdaptor inputShape(adaptor.getInput1().getType());

1943

1944 if (!inputShape.hasRank()) {

1946 return success();

1947 }

1948

1949 inferredReturnShapes.resize(1);

1950 inputShape.getDims(inferredReturnShapes[0]);

1951 return success();

1952 }

1953

1955 TensorType inputType = getInput1().getType();

1956 TensorType outputType = getOutput().getType();

1957

1959 inputType.getRank() != outputType.getRank())

1960 return emitOpError()

1961 << "expected input tensor rank to equal result tensor rank";

1962

1963 auto inputDims = inputType.getShape();

1964 auto outputDims = outputType.getShape();

1965 for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {

1966 int64_t dim = it.index();

1967 auto [inputDim, outputDim] = it.value();

1968 if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) {

1969 return emitOpError() << "dim(result, " << dim << ") = " << outputDim

1970 << " doesn't match dim(input, " << dim

1971 << ") = " << inputDim;

1972 }

1973 }

1974 return success();

1975 }

1976

1977 LogicalResult

1979

1982 return failure();

1983 multiples = llvm::to_vector(

1984 llvm::map_range(multiplesAttr.getValues(),

1985 [](const APInt &val) { return val.getSExtValue(); }));

1986 return success();

1987 }

1988

1989 LogicalResult tosa::TileOp::inferReturnTypeComponents(

1990 MLIRContext *context, ::std::optional location,

1991 TileOp::Adaptor adaptor,

1996 multiples)) {

1997 auto rank =

1998 cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();

2001 return success();

2002 } else {

2004 }

2005

2006 ShapeAdaptor inputShape(adaptor.getInput1().getType());

2008 if (!inputShape.hasRank()) {

2009 outputShape.resize(multiples.size(), ShapedType::kDynamic);

2010 inferredReturnShapes.push_back(

2012 return success();

2013 } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())

2014 return failure();

2015

2016

2017 outputShape.reserve(multiples.size());

2018 for (int i = 0, s = inputShape.getRank(); i < s; i++) {

2019 if (multiples[i] == ShapedType::kDynamic) {

2020 outputShape.push_back(ShapedType::kDynamic);

2021 } else {

2022 int64_t dim = inputShape.getDimSize(i);

2023 if (dim != ShapedType::kDynamic)

2024 dim *= multiples[i];

2025 outputShape.push_back(dim);

2026 }

2027 }

2028

2029 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));

2030 return success();

2031 }

2032

2035 getOutput().getType())

2036 .failed()) {

2037 return failure();

2038 }

2039 ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());

2040 ShapedType outputType = llvm::cast<ShapedType>(getType());

2041

2042 shapeType multiplesType =

2043 llvm::cast<tosa::shapeType>(getMultiples().getType());

2044

2045 auto multiplesRank = multiplesType.getRank();

2046

2047 if (inputType.hasRank()) {

2048 if (inputType.getRank() != multiplesRank)

2049 return emitOpError("expect 'multiples' to have rank ")

2050 << inputType.getRank() << " but got " << multiplesRank << ".";

2051 if (outputType.hasRank() && inputType.getRank() != outputType.getRank())

2052 return emitOpError("expect same input and output tensor rank.");

2053 } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)

2054 return emitOpError("expect 'multiples' array to have length ")

2055 << outputType.getRank() << " but got " << multiplesRank << ".";

2056

2058 if (getConstantMultiples(multiples).succeeded() &&

2059 llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))

2060 return emitOpError(

2061 "expect element of 'multiples' to be positive integer or -1.");

2062

2063 return success();

2064 }

2065

2066 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {

2067 if (l.size() != r.size() || l.size() != 1)

2068 return false;

2070 }

2071

2072 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(

2073 MLIRContext *context, ::std::optional location,

2074 ReshapeOp::Adaptor adaptor,

2076 ShapeAdaptor inputShape(adaptor.getInput1().getType());

2080 newShapeValue)) {

2081 auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();

2084 return success();

2085 } else {

2087 }

2088

2089

2090

2091 if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {

2092 inferredReturnShapes.push_back(

2094 return success();

2095 }

2096

2097

2098

2099

2100 int64_t numElements = inputShape.getNumElements();

2101 int64_t staticMul = 1;

2102 for (auto val : newShapeValue) {

2103 if (!ShapedType::isDynamic(val)) {

2104 staticMul *= val;

2105 }

2106 }

2107

2108

2109 for (auto &val : newShapeValue) {

2110 if (ShapedType::isDynamic(val))

2111 val = numElements / staticMul;

2112 }

2113

2114 inferredReturnShapes.push_back(

2116 return success();

2117 }

2118

2121 getOutput().getType())

2122 .failed()) {

2123 return failure();

2124 }

2125 TensorType inputType = getInput1().getType();

2126

2129

2130 return mlir::success();

2131 }

2132

2133 int missingDims = llvm::count(shapeValues, -1);

2134 if (missingDims > 1)

2135 return emitOpError() << "expected at most one target dimension to be -1";

2136

2137 const auto outputType = dyn_cast(getType());

2138 if (!outputType)

2139 return success();

2140

2141 if ((int64_t)shapeValues.size() != outputType.getRank())

2142 return emitOpError() << "new shape does not match result rank";

2143

2144 for (auto [newShapeDim, outputShapeDim] :

2145 zip(shapeValues, outputType.getShape())) {

2146 if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&

2147 outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)

2148 return emitOpError() << "new shape is inconsistent with result shape";

2149

2150 if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)

2151 return emitOpError() << "new shape has invalid tensor dimension size "

2152 << newShapeDim;

2153 }

2154

2155 if (inputType.hasStaticShape()) {

2156 int64_t inputElementsNum = inputType.getNumElements();

2157 if (outputType.hasStaticShape()) {

2158 int64_t outputElementsNum = outputType.getNumElements();

2159 if (inputElementsNum != outputElementsNum) {

2160 return emitOpError() << "cannot reshape " << inputElementsNum

2161 << " elements into " << outputElementsNum;

2162 }

2163 }

2164

2165 int64_t newShapeElementsNum = std::accumulate(

2166 shapeValues.begin(), shapeValues.end(), 1LL,

2167 [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });

2168 bool isStaticNewShape =

2169 llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });

2170 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||

2171 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {

2172 return emitOpError() << "cannot reshape " << inputElementsNum

2173 << " elements into " << newShapeElementsNum;

2174 }

2175 }

2176

2177 return mlir::success();

2178 }

2179

2180

2181

2182
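// Extracts a constant zero-point value: 0 for a zero float zp, the sign- or zero-extended value for integer zps, -1 otherwise.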

2184 ElementsAttr zpAttr;

2186 return failure();

2187 }

2188

2189 Type zpElemType = zpAttr.getElementType();

2190

2191 if (llvm::isa<FloatType>(zpElemType)) {

2192 if (zpAttr.getValues<APFloat>()[0].isZero()) {

2193 return 0;

2194 }

2195

2196 return -1;

2197 }

2198

2199 if (llvm::isa<IntegerType>(zpElemType)) {

2200 if (signExtend)

2201 return zpAttr.getValues<APInt>()[0].getSExtValue();

2202 else

2203 return zpAttr.getValues<APInt>()[0].getZExtValue();

2204 }

2205

2206

2207 return -1;

2208 }

2209

2210 template <typename T>

2211 static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
2212 const std::string &operand) {

2214

2215 if (!zpElemType.isInteger(8) && zp != 0) {

2216

2217 std::string lower = operand;

2218 std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);

2219 return op.emitOpError()

2220 << lower << " zero point must be zero for non-int8 integer types";

2221 }

2222

2223 return success();

2224 }

2225

2227 const int64_t &zp,

2228 const std::string &operand) {

2229 bool isInputZp = (operand == "Input");

2230

2231 bool tensorUnsigned =

2232 isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();

2233 StringRef tensorName = isInputZp ? "input" : "output";

2234

2236

2237 if (zp != 0) {

2239 !(zpElemType.isInteger(16) && tensorUnsigned)) {

2240 return op.emitOpError()

2241 << "expect " << tensorName << "_zp of 0, got " << zp;

2242 }

2243 if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {

2244 return op.emitOpError() << "expect " << tensorName

2245 << "_zp of 0 or 32768 for unsigned int16 "

2246 << tensorName << ", got " << zp;

2247 }

2248 }

2249

2250 return success();

2251 }

2252

2253 #define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \

2254 FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \

2255 return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \

2256 } \

2257 LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \

2258 return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \

2259 }

2260

2277 #undef ZERO_POINT_HELPER
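// Illustrative expansion of ZERO_POINT_HELPER (the actual instantiation list
// is elided above). Assuming a hypothetical instantiation
// ZERO_POINT_HELPER(RescaleOp, Input, true), the macro would generate:
//
//   FailureOr<int64_t> tosa::RescaleOp::getInputZeroPoint() {
//     return getZeroPoint(getInputZp(), true);
//   }
//   LogicalResult tosa::RescaleOp::verifyInputZeroPoint(int64_t zp) {
//     return verifyZeroPoint(*this, getInputZp(), zp, "Input");
//   }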

2278

2279 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(

2280 MLIRContext *context, ::std::optional location,

2281 TransposeOp::Adaptor adaptor,

2283 ShapeAdaptor inputShape(adaptor.getInput1().getType());

2284

2285

2286

2287 if (!inputShape.hasRank()) {

2289 return success();

2290 }

2291

2292 const auto inputRank = inputShape.getRank();

2293

2294

2295

2296 if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {

2297 return failure();

2298 }

2299

2301

2302 if (inputRank == 0) {

2304 return success();

2305 }

2306

2307

2308 bool allTheSame = true;

2309 for (int i = 1, s = inputRank; i < s; i++) {

2311 allTheSame = false;

2312 break;

2313 }

2314 }

2315

2316

2317

2318 if (allTheSame) {

2319 outputShape.resize(inputRank, inputShape.getDimSize(0));

2321 return success();

2322 }

2323

2324 outputShape.resize(inputRank, ShapedType::kDynamic);

2325

2326

2327 if (llvm::any_of(adaptor.getPerms(),

2328 [inputRank](const auto i) { return i >= inputRank; }))

2329 return failure();

2330

2331 outputShape.reserve(inputRank);

2332 for (int i = 0, s = inputRank; i < s; i++) {

2333 outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);

2334 }

2335

2337 return success();

2338 }
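// Worked example of the permutation loop above (illustrative): for an input
// of type tensor<1x2x3xf32> and perms = [2, 0, 1], each outputShape[i] takes
// inputShape.getDimSize(perms[i]), giving an inferred output shape of
// {3, 1, 2}.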

2339

2342 getOutput().getType())

2343 .failed()) {

2344 return failure();

2345 }

2346

2349

2351

2352 if (inputShape.hasRank() &&

2353 constantPerms.size() != static_cast<size_t>(inputShape.getRank()))

2354 return emitOpError() << "expected perms attribute to have size "

2355 << inputShape.getRank()

2356 << " (input rank) but got size "

2357 << constantPerms.size();

2358

2359 if (inputShape.hasRank() && outputShape.hasRank() &&

2360 inputShape.getRank() != outputShape.getRank())

2361 return emitOpError()

2362 << "expected input tensor rank to equal result tensor rank";

2363

2364 if (outputShape.hasRank() &&

2365 constantPerms.size() != static_cast<size_t>(outputShape.getRank()))

2366 return emitOpError() << "expected perms attribute to have size "

2367 << outputShape.getRank()

2368 << " (output rank) but got size "

2369 << constantPerms.size();

2370

2371 if (!llvm::all_of(constantPerms,

2372 [&constantPerms](int32_t s) {

2373 return s >= 0 &&

2374 static_cast<size_t>(s) < constantPerms.size();

2375 }) ||

2377 constantPerms, [](int32_t v) -> int64_t { return v; }))))

2378 return emitOpError() << "expected valid permutation indices";

2379

2380

2381 if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&

2382 inputShape.getNumElements() != outputShape.getNumElements())

2383 return emitOpError() << "expected input1 and output to have same numbers "

2384 "of elements, got "

2385 << inputShape.getNumElements() << " and "

2386 << outputShape.getNumElements();

2387

2388

2389

2390 if (inputShape.hasRank() && outputShape.hasRank()) {

2391 for (auto i = 0; i < outputShape.getRank(); i++) {

2392 if (inputShape.isDynamicDim(constantPerms[i]) ||

2393 outputShape.isDynamicDim(i))

2394 continue;

2395

2396 if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))

2397 return emitOpError()

2398 << "expected output tensor dim " << i << " to match "

2399 << "input dim " << constantPerms[i] << " with value of "

2400 << inputShape.getDimSize(constantPerms[i]);

2401 }

2402 }

2403

2404 return success();

2405 }

2406

2409

2411

2412 Value input = getInput1();

2413 auto inputType = cast(input.getType());

2414

2416 for (auto dim : transposePerms) {

2417 int32_t dimInInput = transposePerms[dim];

2418 if (inputType.isDynamicDim(dimInInput))

2419 returnedDims[dim] =

2420 builder.createtensor::DimOp(getLoc(), input, dimInInput)

2421 .getResult();

2422 else

2423 returnedDims[dim] =

2424 builder.getIndexAttr(inputType.getDimSize(dimInInput));

2425 }

2426

2427 reifiedReturnShapes.emplace_back(std::move(returnedDims));

2428 return success();

2429 }

2430

2431 LogicalResult tosa::GatherOp::inferReturnTypeComponents(

2432 MLIRContext *context, ::std::optional location,

2433 GatherOp::Adaptor adaptor,

2436 outputShape.resize(3, ShapedType::kDynamic);

2437

2438 ShapeAdaptor valuesShape(adaptor.getValues().getType());

2439 if (valuesShape.hasRank()) {

2440 outputShape[0] = valuesShape.getDimSize(0);

2441 outputShape[2] = valuesShape.getDimSize(2);

2442 }

2443

2444 ShapeAdaptor indicesShape(adaptor.getIndices().getType());

2445 if (indicesShape.hasRank()) {

2446 if (outputShape[0] == ShapedType::kDynamic)

2447 outputShape[0] = indicesShape.getDimSize(0);

2448 if (outputShape[1] == ShapedType::kDynamic)

2449 outputShape[1] = indicesShape.getDimSize(1);

2450 }

2451

2453 return success();

2454 }

2455

2458 getOutput().getType())

2459 .failed()) {

2460 return failure();

2461 }

2462

2466

2467 int64_t N = ShapedType::kDynamic;

2468 int64_t W = ShapedType::kDynamic;

2469 int64_t C = ShapedType::kDynamic;

2470

2471 if (valuesShape.hasRank()) {

2472 N = valuesShape.getDimSize(0);

2473 C = valuesShape.getDimSize(2);

2474 }

2475 if (indicesShape.hasRank()) {

2476 const int64_t indicesN = indicesShape.getDimSize(0);

2477 W = indicesShape.getDimSize(1);

2478 if (N == ShapedType::kDynamic)

2479 N = indicesN;

2480 else if (indicesN != ShapedType::kDynamic && N != indicesN)

2481 return emitOpError() << "requires indices dimension 0 to have size " << N

2482 << ", got " << indicesN;

2483 }

2484 if (outputShape.hasRank()) {

2485 const int64_t outputN = outputShape.getDimSize(0);

2486 const int64_t outputW = outputShape.getDimSize(1);

2487 const int64_t outputC = outputShape.getDimSize(2);

2488 if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&

2489 N != outputN)

2490 return emitOpError() << "requires output dimension 0 to have size " << N

2491 << ", got " << outputN;

2492

2493 if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&

2494 W != outputW)

2495 return emitOpError() << "requires output dimension 1 to have size " << W

2496 << ", got " << outputW;

2497 if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&

2498 C != outputC)

2499 return emitOpError() << "requires output dimension 2 to have size " << C

2500 << ", got " << outputC;

2501 }

2502 return success();

2503 }

2504

2505 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(

2506 MLIRContext *context, ::std::optional location,

2507 ResizeOp::Adaptor adaptor,

2510 outputShape.resize(4, ShapedType::kDynamic);

2511

2512 ShapeAdaptor inputShape(adaptor.getInput().getType());

2513 if (!inputShape.hasRank())

2514 return failure();

2515

2516 outputShape[0] = inputShape.getDimSize(0);

2517 outputShape[3] = inputShape.getDimSize(3);

2518 int64_t inputHeight = inputShape.getDimSize(1);

2519 int64_t inputWidth = inputShape.getDimSize(2);

2520

2521 if ((inputHeight == ShapedType::kDynamic) ||

2522 (inputWidth == ShapedType::kDynamic))

2523 return failure();

2524

2527 scaleInt) ||

2529 offsetInt) ||

2531 borderInt)) {

2532 return failure();

2533 }

2534

2535

2536 outputShape[1] =

2537 (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /

2538 scaleInt[1]) +

2539 1;

2540

2541 outputShape[2] =

2542 (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /

2543 scaleInt[3]) +

2544 1;

2545

2547 return success();

2548 }
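// Worked example of the resize output computation above (illustrative): with
// input height 4, scale_y_n / scale_y_d = 2 / 1 (scaleInt[0] / scaleInt[1]),
// offset_y = 0 and border_y = 0, the inferred output height is
// ((4 - 1) * 2 - 0 + 0) / 1 + 1 = 7.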

2549

2551 const Value input = getInput();

2552 const Value output = getOutput();

2553 const RankedTensorType inputType =

2554 llvm::dyn_cast(input.getType());

2555 const RankedTensorType outputType =

2556 llvm::dyn_cast(output.getType());

2557

2564

2565 return success();

2566 }

2567

2568 if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))

2569 return emitOpError("expect all scale values to be > 0, got ")

2570 << scaleValues;

2571

2572 const int64_t scaleYN = scaleValues[0];

2573 const int64_t scaleYD = scaleValues[1];

2574 const int64_t scaleXN = scaleValues[2];

2575 const int64_t scaleXD = scaleValues[3];

2576

2577 const int64_t offsetY = offsetValues[0];

2578 const int64_t offsetX = offsetValues[1];

2579

2580 const int64_t borderY = borderValues[0];

2581 const int64_t borderX = borderValues[1];

2582

2583 if (!inputType)

2584 return success();

2585 if (!outputType)

2586 return success();

2587

2588 const int64_t oh = outputType.getDimSize(1);

2589 const int64_t ow = outputType.getDimSize(2);

2590 const int64_t ih = inputType.getDimSize(1);

2591 const int64_t iw = inputType.getDimSize(2);

2592

2593

2594

2595

2596

2597 if (ih != ShapedType::kDynamic && ih != 1) {

2598 const std::optional<int64_t> calculatedOutHeightMinusOne =

2599 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);

2600 if (!calculatedOutHeightMinusOne.has_value())

2601 return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "

2602 "border_y ")

2603 << "to be wholly divisible by scale_y_d, got ((" << ih

2604 << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY

2605 << ") / " << scaleYD;

2606 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;

2607 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)

2608 return emitOpError("calculated output height did not match expected: ")

2609 << "calculated=" << calculatedOutHeight << ", expected=" << oh;

2610 }

2611

2612

2613

2614

2615

2616 if (iw != ShapedType::kDynamic && iw != 1) {

2617 const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;

2618 const std::optional<int64_t> calculatedOutWidthMinusOne =

2619 idivCheck(scaledInWidth, scaleXD);

2620 if (!calculatedOutWidthMinusOne.has_value())

2621 return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "

2622 "border_x ")

2623 << "to be wholly divisible by scale_x_d, got ((" << iw

2624 << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX

2625 << ") / " << scaleXD;

2626 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;

2627 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)

2628 return emitOpError("calculated output width did not match expected: ")

2629 << "calculated=" << calculatedOutWidth << ", expected=" << ow;

2630 }

2631

2632 return success();

2633 }

2634

2635 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(

2636 MLIRContext *context, ::std::optional location,

2637 ScatterOp::Adaptor adaptor,

2640 outputShape.resize(3, ShapedType::kDynamic);

2641

2642 ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());

2643 if (valuesInShape.hasRank()) {

2644 outputShape[0] = valuesInShape.getDimSize(0);

2645 outputShape[1] = valuesInShape.getDimSize(1);

2646 outputShape[2] = valuesInShape.getDimSize(2);

2647 }

2648

2649 ShapeAdaptor indicesShape(adaptor.getIndices().getType());

2650 if (indicesShape.hasRank()) {

2651 if (outputShape[0] == ShapedType::kDynamic)

2652 outputShape[0] = indicesShape.getDimSize(0);

2653 }

2654

2655 ShapeAdaptor inputShape(adaptor.getInput().getType());

2656 if (inputShape.hasRank()) {

2657 if (outputShape[0] == ShapedType::kDynamic)

2658 outputShape[0] = inputShape.getDimSize(0);

2659 if (outputShape[2] == ShapedType::kDynamic)

2660 outputShape[2] = inputShape.getDimSize(2);

2661 }

2662

2664 return success();

2665 }

2666

2669 getValuesOut().getType())

2670 .failed() ||

2672 getValuesOut().getType())

2673 .failed()) {

2674 return failure();

2675 }

2676

2681

2682 int64_t N = ShapedType::kDynamic;

2683 int64_t K = ShapedType::kDynamic;

2684 int64_t W = ShapedType::kDynamic;

2685 int64_t C = ShapedType::kDynamic;

2686 if (valuesInShape.hasRank()) {

2687 N = valuesInShape.getDimSize(0);

2688 K = valuesInShape.getDimSize(1);

2689 C = valuesInShape.getDimSize(2);

2690 }

2691 if (indicesShape.hasRank()) {

2692 const int64_t indicesN = indicesShape.getDimSize(0);

2693 W = indicesShape.getDimSize(1);

2694 if (N == ShapedType::kDynamic)

2695 N = indicesN;

2696 else if (indicesN != ShapedType::kDynamic && N != indicesN)

2697 return emitOpError() << "requires indices dimension 0 to have size " << N

2698 << ", got " << indicesN;

2699 }

2700 if (inputShape.hasRank()) {

2701 const int64_t inputN = inputShape.getDimSize(0);

2702 const int64_t inputW = inputShape.getDimSize(1);

2703 const int64_t inputC = inputShape.getDimSize(2);

2704 if (N == ShapedType::kDynamic)

2705 N = inputN;

2706 else if (inputN != ShapedType::kDynamic && N != inputN)

2707 return emitOpError() << "requires input dimension 0 to have size " << N

2708 << ", got " << inputN;

2709 if (W == ShapedType::kDynamic)

2710 W = inputW;

2711 else if (inputW != ShapedType::kDynamic && W != inputW)

2712 return emitOpError() << "requires input dimension 1 to have size " << W

2713 << ", got " << inputW;

2714

2715 if (C == ShapedType::kDynamic)

2716 C = inputC;

2717 else if (inputC != ShapedType::kDynamic && C != inputC)

2718 return emitOpError() << "requires input dimension 2 to have size " << C

2719 << ", got " << inputC;

2720 }

2721 if (outputShape.hasRank()) {

2722 const int64_t outputN = outputShape.getDimSize(0);

2723 const int64_t outputK = outputShape.getDimSize(1);

2724 const int64_t outputC = outputShape.getDimSize(2);

2725 if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&

2726 N != outputN)

2727 return emitOpError() << "requires values_out dimension 0 to have size "

2728 << N << ", got " << outputN;

2729 if (K == ShapedType::kDynamic)

2730 K = outputK;

2731 else if (outputK != ShapedType::kDynamic && K != outputK)

2732 return emitOpError() << "requires values_out dimension 1 to have size "

2733 << K << ", got " << outputK;

2734 if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&

2735 C != outputC)

2736 return emitOpError() << "requires values_out dimension 2 to have size "

2737 << C << ", got " << outputC;

2738 }

2739 if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))

2740 return emitOpError() << "requires dimensions K >= W, got K=" << K

2741 << " and W=" << W;

2742

2743 return success();

2744 }

2745

2747 ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,

2749 int64_t axisVal = axis.getValue().getSExtValue();

2750 if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {

2752 return success();

2753 }

2754

2756 operandShape.getDims(outputShape);

2757 outputShape[axisVal] = 1;

2758 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));

2759 return success();

2760 }

2761

2762 #define COMPATIBLE_RETURN_TYPES(OP) \

2763 bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \

2764 if (l.size() != r.size() || l.size() != 1) \

2765 return false; \

2766 if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \

2767 return false; \

2768 return succeeded(verifyCompatibleShape(l[0], r[0])); \

2769 }

2770

2771 #define REDUCE_SHAPE_INFER(OP) \

2772 LogicalResult OP::inferReturnTypeComponents( \

2773 MLIRContext *context, ::std::optional location, \

2774 OP::Adaptor adaptor, \

2775 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \

2776 Type inputType = \

2777 llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \

2778 ShapeAdaptor inputShape(adaptor.getInput().getType()); \

2779 const Properties &prop = adaptor.getProperties(); \

2780 return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \

2781 inferredReturnShapes); \

2782 } \

2783 COMPATIBLE_RETURN_TYPES(OP)

2784

2791 #undef REDUCE_SHAPE_INFER

2793 #undef COMPATIBLE_RETURN_TYPES
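// Illustrative expansion of REDUCE_SHAPE_INFER (the instantiations themselves
// are elided above). Assuming a hypothetical REDUCE_SHAPE_INFER(ReduceMaxOp)
// instantiation, the macro would generate:
//
//   LogicalResult ReduceMaxOp::inferReturnTypeComponents(
//       MLIRContext *context, ::std::optional<Location> location,
//       ReduceMaxOp::Adaptor adaptor,
//       SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
//     ... // forwards inputShape, element type and prop.axis to
//         // ReduceInferReturnTypes
//   }
//   bool ReduceMaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { ... }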

2794

2795 template <typename T>

2796 static LogicalResult verifyReduceOp(T op) {
2797

2798 TensorType inputType = op.getInput().getType();

2799 TensorType outputType = op.getOutput().getType();

2800 int32_t reduceAxis = op.getAxis();

2801

2802 if (reduceAxis < 0) {

2803 op.emitOpError("reduce axis must not be negative");

2804 return failure();

2805 }

2806 if (inputType.hasRank()) {

2807 int64_t inputRank = inputType.getRank();

2808

2809

2810 if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {

2811 op.emitOpError("expect input tensor rank (")

2812 << inputRank << ") to be larger than reduce axis (" << reduceAxis

2813 << ")";

2814 return failure();

2815 }

2816 }

2817 if (outputType.hasRank()) {

2818 int64_t outputRank = outputType.getRank();

2819 if (inputType.hasRank() && outputRank != inputType.getRank()) {

2820 op.emitOpError(

2821 "expect output tensor rank to be equal to input tensor rank");

2822 return failure();

2823 }

2824 if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {

2825 op.emitOpError("expect output tensor rank (")

2826 << outputRank << ") to be larger than reduce axis (" << reduceAxis

2827 << ")";

2828 return failure();

2829 }

2830

2831

2832 if (outputRank != 0) {

2833 auto outputShape = outputType.getShape();

2834 if (!outputType.isDynamicDim(reduceAxis) &&

2835 outputShape[reduceAxis] != 1) {

2836 op.emitOpError("expect reduced dimension size to be 1, got ")

2837 << outputShape[reduceAxis];

2838 return failure();

2839 }

2840 }

2841 }

2842 return success();

2843 }

2844

2851

2858 } else {

2860 }

2861 return success();

2862 }

2863

2864 #define NARY_SHAPE_INFER(OP) \

2865 LogicalResult OP::inferReturnTypeComponents( \

2866 MLIRContext *context, ::std::optional location, \

2867 ValueShapeRange operands, DictionaryAttr attributes, \

2868 OpaqueProperties properties, RegionRange regions, \

2869 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \

2870 return NAryInferReturnTypes(operands, inferredReturnShapes); \

2871 }

2872

2910 #undef PRED_SHAPE_INFER

2911

2912 LogicalResult tosa::NegateOp::inferReturnTypeComponents(

2913 MLIRContext *context, ::std::optional location,

2914 NegateOp::Adaptor adaptor,

2916 ShapeAdaptor inputShape(adaptor.getInput1().getType());

2918 return success();

2919 }

2920

2922

2923 const Type input1Type = getInput1().getType();

2924 const Type outputType = getOutput().getType();

2926 return failure();

2927

2928

2931 return emitOpError() << "requires the same shape for input1 and output";

2932

2934 const Type input1ZpEType =

2936 if (input1EType != input1ZpEType) {

2937 return emitOpError("expect both input1 and its zero point are the same "

2938 "element type, got ")

2939 << input1EType << " and " << input1ZpEType;

2940 }

2942 const Type outputZpEType =

2944 if (outputEType != outputZpEType) {

2945 return emitOpError("expect both output and its zero point are the same "

2946 "element type, got ")

2947 << outputEType << " and " << outputZpEType;

2948 }

2949

2950 FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();

2951 if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())

2952 return failure();

2953

2954 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();

2955 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())

2956 return failure();

2957

2958 return success();

2959 }

2960

2966 outputShape.resize(4, ShapedType::kDynamic);

2967

2968

2969 if (!inputShape) {

2971 return success();

2972 }

2973

2974

2975 outputShape[0] = inputShape.getDimSize(0);

2976 outputShape[3] = inputShape.getDimSize(3);

2977

2978 int64_t height = inputShape.getDimSize(1);

2979 int64_t width = inputShape.getDimSize(2);

2980

2981 if (!ShapedType::isDynamic(height)) {

2982 int64_t padded = height + pad[0] + pad[1] - kernel[0];

2983 outputShape[1] = padded / stride[0] + 1;

2984 }

2985

2986 if (!ShapedType::isDynamic(width)) {

2987 int64_t padded = width + pad[2] + pad[3] - kernel[1];

2988 outputShape[2] = padded / stride[1] + 1;

2989 }

2990

2992 return success();

2993 }
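// Worked example of the pooling output size above (illustrative): for input
// height 8 with pad = [1, 1, ...], kernel = [3, ...] and stride = [2, ...],
// padded = 8 + 1 + 1 - 3 = 7 and outputShape[1] = 7 / 2 + 1 = 4.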

2994

2995 LogicalResult Conv2DOp::inferReturnTypeComponents(

2996 MLIRContext *context, ::std::optional location,

2997 Conv2DOp::Adaptor adaptor,

3000

3001 int64_t inputWidth = ShapedType::kDynamic;

3002 int64_t inputHeight = ShapedType::kDynamic;

3003 int64_t weightWidth = ShapedType::kDynamic;

3004 int64_t weightHeight = ShapedType::kDynamic;

3005

3006

3007

3008 ShapeAdaptor inputShape(adaptor.getInput().getType());

3009 if (inputShape.hasRank()) {

3010 outputShape[0] = inputShape.getDimSize(0);

3011 inputHeight = inputShape.getDimSize(1);

3012 inputWidth = inputShape.getDimSize(2);

3013 }

3014

3015

3016 ShapeAdaptor weightShape(adaptor.getWeight().getType());

3017 if (weightShape.hasRank()) {

3018 outputShape[3] = weightShape.getDimSize(0);

3019 weightHeight = weightShape.getDimSize(1);

3020 weightWidth = weightShape.getDimSize(2);

3021 }

3022

3023

3024 ShapeAdaptor biasShape(adaptor.getBias().getType());

3025 if (biasShape.hasRank()) {

3026 outputShape[3] = ShapedType::isDynamic(outputShape[3])

3027 ? biasShape.getDimSize(0)

3028 : outputShape[3];

3029 }

3030

3034

3035 if (!ShapedType::isDynamic(inputHeight) &&

3036 !ShapedType::isDynamic(weightHeight)) {

3037 int64_t inputSize = inputHeight + padding[0] + padding[1];

3038 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;

3039 int64_t unstridedResult = inputSize - filterSize + 1;

3040 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;

3041 }

3042

3043 if (!ShapedType::isDynamic(inputWidth) &&

3044 !ShapedType::isDynamic(weightWidth)) {

3045 int64_t inputSize = inputWidth + padding[2] + padding[3];

3046 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;

3047 int64_t unstridedResult = inputSize - filterSize + 1;

3048 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;

3049 }

3050

3052 return success();

3053 }
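// Worked example of the convolution output size above (illustrative): with
// input height 16, padding = [1, 1, ...], a 3-tall kernel, dilation 1 and
// stride 2: inputSize = 18, filterSize = (3 - 1) * 1 + 1 = 3,
// unstridedResult = 18 - 3 + 1 = 16, and outputShape[1] = (16 - 1) / 2 + 1 = 8.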

3054

3058 return failure();

3059 return success();

3060 }

3061

3062 LogicalResult Conv3DOp::inferReturnTypeComponents(

3063 MLIRContext *context, ::std::optional location,

3064 Conv3DOp::Adaptor adaptor,

3067

3068 int64_t inputWidth = ShapedType::kDynamic;

3069 int64_t inputHeight = ShapedType::kDynamic;

3070 int64_t inputDepth = ShapedType::kDynamic;

3071

3072 int64_t weightWidth = ShapedType::kDynamic;

3073 int64_t weightHeight = ShapedType::kDynamic;

3074 int64_t weightDepth = ShapedType::kDynamic;

3075

3076

3077 ShapeAdaptor inputShape(adaptor.getInput().getType());

3078 if (inputShape.hasRank()) {

3079 outputShape[0] = inputShape.getDimSize(0);

3080 inputDepth = inputShape.getDimSize(1);

3081 inputHeight = inputShape.getDimSize(2);

3082 inputWidth = inputShape.getDimSize(3);

3083 }

3084

3085

3086 ShapeAdaptor weightShape(adaptor.getWeight().getType());

3087 if (weightShape.hasRank()) {

3088 outputShape[4] = weightShape.getDimSize(0);

3089 weightDepth = weightShape.getDimSize(1);

3090 weightHeight = weightShape.getDimSize(2);

3091 weightWidth = weightShape.getDimSize(3);

3092 }

3093

3094

3095 ShapeAdaptor biasShape(adaptor.getBias().getType());

3096 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {

3097 outputShape[4] = biasShape.getDimSize(0);

3098 }

3099

3103

3104 if (!ShapedType::isDynamic(inputDepth) &&

3105 !ShapedType::isDynamic(weightDepth)) {

3106 int32_t inputSize = inputDepth + pad[0] + pad[1];

3107 int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;

3108 int32_t unstridedResult = inputSize - filterSize + 1;

3109 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;

3110 }

3111

3112 if (!ShapedType::isDynamic(inputHeight) &&

3113 !ShapedType::isDynamic(weightHeight)) {

3114 int32_t inputSize = inputHeight + pad[2] + pad[3];

3115 int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;

3116 int32_t unstridedResult = inputSize - filterSize + 1;

3117 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;

3118 }

3119

3120 if (!ShapedType::isDynamic(inputWidth) &&

3121 !ShapedType::isDynamic(weightWidth)) {

3122 int32_t inputSize = inputWidth + pad[4] + pad[5];

3123 int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;

3124 int32_t unstridedResult = inputSize - filterSize + 1;

3125 outputShape[3] = (unstridedResult - 1) / stride[2] + 1;

3126 }

3127

3129 return success();

3130 }

3131

3135 return failure();

3136 return success();

3137 }

3138

3139 LogicalResult AvgPool2dOp::inferReturnTypeComponents(

3140 MLIRContext *context, ::std::optional location,

3141 AvgPool2dOp::Adaptor adaptor,

3143 ShapeAdaptor inputShape(adaptor.getInput().getType());

3144 const Properties &prop = adaptor.getProperties();

3146 inferredReturnShapes);

3147 }

3148

3149 LogicalResult MaxPool2dOp::inferReturnTypeComponents(

3150 MLIRContext *context, ::std::optional location,

3151 MaxPool2dOp::Adaptor adaptor,

3153 ShapeAdaptor inputShape(adaptor.getInput().getType());

3154 const Properties &prop = adaptor.getProperties();

3156 inferredReturnShapes);

3157 }

3158

3161 getOutput().getType())))

3162 return failure();

3163

3165 return failure();

3166

3167 return success();

3168 }

3169

3170 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(

3171 MLIRContext *context, ::std::optional location,

3172 DepthwiseConv2DOp::Adaptor adaptor,

3175

3176 int64_t inputWidth = ShapedType::kDynamic;

3177 int64_t inputHeight = ShapedType::kDynamic;

3178 int64_t inputChannels = ShapedType::kDynamic;

3179

3180 int64_t weightWidth = ShapedType::kDynamic;

3181 int64_t weightHeight = ShapedType::kDynamic;

3182 int64_t depthChannels = ShapedType::kDynamic;

3183

3184

3185 ShapeAdaptor inputShape(adaptor.getInput().getType());

3186 if (inputShape.hasRank()) {

3187 outputShape[0] = inputShape.getDimSize(0);

3188 inputHeight = inputShape.getDimSize(1);

3189 inputWidth = inputShape.getDimSize(2);

3190 inputChannels = inputShape.getDimSize(3);

3191 }

3192

3193

3194 ShapeAdaptor weightShape(adaptor.getWeight().getType());

3195 if (weightShape.hasRank()) {

3196 weightHeight = weightShape.getDimSize(0);

3197 weightWidth = weightShape.getDimSize(1);

3198 inputChannels = ShapedType::isDynamic(inputChannels)

3199 ? weightShape.getDimSize(2)

3200 : inputChannels;

3201 depthChannels = weightShape.getDimSize(3);

3202 }

3203

3204

3205

3206 if (!ShapedType::isDynamic(inputChannels) &&

3207 !ShapedType::isDynamic(depthChannels)) {

3208 outputShape[3] = inputChannels * depthChannels;

3209 }

3210

3211

3212 ShapeAdaptor biasShape(adaptor.getBias().getType());

3213 if (biasShape.hasRank()) {

3214 outputShape[3] = ShapedType::isDynamic(outputShape[3])

3215 ? biasShape.getDimSize(0)

3216 : outputShape[3];

3217 }

3218

3222

3223 if (!ShapedType::isDynamic(inputHeight) &&

3224 !ShapedType::isDynamic(weightHeight)) {

3225 int64_t inputSize = inputHeight + padding[0] + padding[1];

3226 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;

3227 int64_t unstridedResult = inputSize - filterSize + 1;

3228 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;

3229 }

3230

3231 if (!ShapedType::isDynamic(inputWidth) &&

3232 !ShapedType::isDynamic(weightWidth)) {

3233 int64_t inputSize = inputWidth + padding[2] + padding[3];

3234 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;

3235 int64_t unstridedResult = inputSize - filterSize + 1;

3236 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;

3237 }

3238

3240 return success();

3241 }
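// Worked example of the depthwise channel computation above (illustrative):
// for 8 input channels and a depth multiplier of 3 (weight shape
// [KH, KW, 8, 3]), outputShape[3] = 8 * 3 = 24.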

3242

3246 return failure();

3247 return success();

3248 }

3249

3250 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(

3251 MLIRContext *context, ::std::optional location,

3252 TransposeConv2DOp::Adaptor adaptor,

3255

3256 int64_t inputWidth = ShapedType::kDynamic;

3257 int64_t inputHeight = ShapedType::kDynamic;

3258 int64_t weightWidth = ShapedType::kDynamic;

3259 int64_t weightHeight = ShapedType::kDynamic;

3260

3261

3262 ShapeAdaptor inputShape(adaptor.getInput().getType());

3263 if (inputShape.hasRank()) {

3264 outputShape[0] = ShapedType::isDynamic(outputShape[0])

3265 ? inputShape.getDimSize(0)

3266 : outputShape[0];

3267 inputHeight = inputShape.getDimSize(1);

3268 inputWidth = inputShape.getDimSize(2);

3269 }

3270

3271

3272 ShapeAdaptor weightShape(adaptor.getWeight().getType());

3273 if (weightShape.hasRank()) {

3274 outputShape[3] = ShapedType::isDynamic(outputShape[3])

3275 ? weightShape.getDimSize(0)

3276 : outputShape[3];

3277 weightHeight = weightShape.getDimSize(1);

3278 weightWidth = weightShape.getDimSize(2);

3279 }

3280

3281

3282 ShapeAdaptor biasShape(adaptor.getInput().getType());

3283 if (biasShape.hasRank()) {

3284 outputShape[3] = ShapedType::isDynamic(outputShape[3])

3285 ? biasShape.getDimSize(0)

3286 : outputShape[3];

3287 }

3288

3291

3292 if (!ShapedType::isDynamic(inputHeight) &&

3293 !ShapedType::isDynamic(weightHeight)) {

3294 int64_t calculateSize =

3295 (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;

3296 outputShape[1] =

3297 ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];

3298 }

3299

3300 if (!ShapedType::isDynamic(inputWidth) &&

3301 !ShapedType::isDynamic(weightWidth)) {

3302 int64_t calculateSize =

3303 (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;

3304 outputShape[2] =

3305 ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];

3306 }

3307

3309 return success();

3310 }
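// Worked example of the transpose-convolution output size above
// (illustrative): with input height 4, stride_y = 2, out_pad_top =
// out_pad_bottom = 0 and a 3-tall kernel, the computed height is
// (4 - 1) * 2 + 0 + 0 + 3 = 9.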

3311

3314 return failure();

3315

3317 const int64_t strideY = strides[0];

3318 const int64_t strideX = strides[1];

3319

3320 if (strideY < 1 || strideX < 1)

3321 return emitOpError("expect all stride values to be >= 1, got [")

3322 << strides << "]";

3323

3324 const auto checkPadAgainstKernelDim =

3325 [this](int64_t pad_value, int64_t kernel_dim_size,

3326 llvm::StringRef pad_name,

3327 llvm::StringRef kernel_dim_name) -> LogicalResult {

3328 if (pad_value <= -kernel_dim_size)

3329 return emitOpError("expected ")

3330 << pad_name << " > -" << kernel_dim_name

3331 << ", but got: " << pad_name << "=" << pad_value << " and "

3332 << kernel_dim_name << "=" << kernel_dim_size;

3333 return success();

3334 };

3335

3337 const int64_t outPadTop = padding[0];

3338 const int64_t outPadBottom = padding[1];

3339 const int64_t outPadLeft = padding[2];

3340 const int64_t outPadRight = padding[3];

3341

3342 const auto weightType =

3343 llvm::dyn_cast(getWeight().getType());

3344

3345 if (weightType) {

3346 const int64_t kernelHeight = weightType.getDimSize(1);

3347 if (!ShapedType::isDynamic(kernelHeight)) {

3348 if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,

3349 "out_pad_top", "KH")))

3350 return failure();

3351

3352 if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,

3353 "out_pad_bottom", "KH")))

3354 return failure();

3355 }

3356

3357 const int64_t kernelWidth = weightType.getDimSize(2);

3358 if (!ShapedType::isDynamic(kernelWidth)) {

3359 if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,

3360 "out_pad_left", "KW")))

3361 return failure();

3362

3363 if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,

3364 "out_pad_right", "KW")))

3365 return failure();

3366 }

3367 }

3368

3369

3370 const auto outputType =

3371 llvm::dyn_cast(getOutput().getType());

3372 if (!outputType)

3373 return success();

3374

3375 const auto inputType = llvm::dyn_cast(getInput().getType());

3376 if (inputType && weightType) {

3377 const int64_t inputHeight = inputType.getDimSize(1);

3378 const int64_t kernelHeight = weightType.getDimSize(1);

3379 const int64_t outputHeight = outputType.getDimSize(1);

3380

3381 if (!ShapedType::isDynamic(inputHeight) &&

3382 !ShapedType::isDynamic(outputHeight)) {

3383 if (outputHeight !=

3384 (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)

3385 return emitOpError(

3386 "dimension mismatch: expected OH == (IH - 1) * stride_y "

3387 "+ out_pad_top + out_pad_bottom + KH, but got ")

3388 << outputHeight << " != (" << inputHeight << " - 1) * "

3389 << strideY << " + " << outPadTop << " + " << outPadBottom

3390 << " + " << kernelHeight;

3391 }

3392

3393 const int64_t inputWidth = inputType.getDimSize(2);

3394 const int64_t kernelWidth = weightType.getDimSize(2);

3395 const int64_t outputWidth = outputType.getDimSize(2);

3396

3397 if (!ShapedType::isDynamic(inputWidth) &&

3398 !ShapedType::isDynamic(outputWidth)) {

3399 if (outputWidth !=

3400 (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)

3401 return emitOpError(

3402 "dimension mismatch: expected OW == (IW - 1) * stride_x "

3403 "+ out_pad_left + out_pad_right + KW, but got ")

3404 << outputWidth << " != (" << inputWidth << " - 1) * " << strideX

3405 << " + " << outPadLeft << " + " << outPadRight << " + "

3406 << kernelWidth;

3407 }

3408 }

3409

3410 const auto biasType = llvm::dyn_cast(getBias().getType());

3411

3412 if (!biasType)

3413 return success();

3414

3415 const int64_t biasChannels = biasType.getDimSize(0);

3416

3417

3418 if (biasChannels == ShapedType::kDynamic)

3419 return success();

3420

3421 const int64_t outputChannels = outputType.getDimSize(3);

3422 if (biasChannels != outputChannels && biasChannels != 1)

3423 return emitOpError(

3424 "bias channels expected to be equal to output channels (")

3425 << outputChannels << ") or 1, got " << biasChannels;

3426

3427 return success();

3428 }

3429

3431 auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());

3432 if (!inputType) {

3433 emitOpError("expect shaped tensor for input, got ") << getInput().getType();

3434 return failure();

3435 }

3436

3437 auto inputElementType =

3439 if (!mlir::isa(inputElementType)) {

3440 emitOpError("expect input to have integer element type, got ")

3441 << inputElementType;

3442 return failure();

3443 }

3444

3445 auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());

3446 if (!outputType) {

3447 emitOpError("expect shaped tensor for output, got ")

3448 << getOutput().getType();

3449 return failure();

3450 }

3451

3452 auto outputElementType =

3454 if (!mlir::isa(outputElementType)) {

3455 emitOpError("expect output to have integer element type, got ")

3456 << outputElementType;

3457 return failure();

3458 }

3459

3461 .failed())

3462 return failure();

3463

3465 .failed())

3466 return failure();

3467

3468 FailureOr<int64_t> maybeIZp = getInputZeroPoint();

3469 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())

3470 return failure();

3471

3472 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();

3473 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())

3474 return failure();

3475

3476 auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());

3477 if (!multiplierType) {

3478 emitOpError("expect shaped tensor for multiplier, got ")

3479 << getMultiplier().getType();

3480 return failure();

3481 }

3482

3483 auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());

3484 if (!shiftType) {

3485 emitOpError("expect shaped tensor for shift, got ") << getShift().getType();

3486 return failure();

3487 }

3488

3489

3490 if (getScale32() && !multiplierType.getElementType().isInteger(32)) {

3491 emitOpError("expect i32 element type for multiplier for scale32=true, got ")

3492 << multiplierType.getElementType();

3493 return failure();

3494 }

3495

3496

3497 if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {

3498 emitOpError(

3499 "expect i16 element type for multiplier for scale32=false, got ")

3500 << multiplierType.getElementType();

3501 return failure();

3502 }

3503

3504 if (!inputType.hasRank())

3505 return success();

3506

3507

3508

3509

3510 int64_t numChannels = 1;

3511 if (getPerChannel()) {

3512 if (inputType.getRank() < 1) {

3513 emitOpError("requires input to be at least rank 1 when per_channel is "

3514 "true, but got rank ")

3515 << inputType.getRank();

3516 return failure();

3517 }

3518 numChannels = inputType.getDimSize(inputType.getRank() - 1);

3519 }

3520

3521 if (!multiplierType.hasRank())

3522 return success();

3523

3525

3526 if (multiplierShape[0] != ShapedType::kDynamic &&

3527 multiplierShape[0] != numChannels) {

3528 emitOpError("expect shape of { ")

3529 << numChannels << " } for multiplier input, got { "

3530 << multiplierShape[0] << " }";

3531 return failure();

3532 }

3533

3534 if (!shiftType.hasRank())

3535 return success();

3536

3538

3539 if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {

3540 emitOpError("expect shape of { ")

3541 << numChannels << " } for shift input, got { " << shiftShape[0] << " }";

3542 return failure();

3543 }

3544

3545 return success();

3546 }

3547

3548 LogicalResult RescaleOp::inferReturnTypeComponents(

3549 MLIRContext *context, ::std::optional location,

3550 RescaleOp::Adaptor adaptor,

3552 ShapeAdaptor inputShape(adaptor.getInput().getType());

3554 return success();

3555 }

3556

3557 LogicalResult IfOp::inferReturnTypeComponents(

3558 MLIRContext *context, ::std::optional location,

3559 IfOp::Adaptor adaptor,

3562 for (Region *region : adaptor.getRegions()) {

3563 for (auto &block : *region)

3564 if (auto returnOp = dyn_casttosa::YieldOp(block.getTerminator()))

3565 yieldOps.push_back(returnOp);

3566 }

3567

3568 if (yieldOps.empty())

3569 return failure();

3570

3571

3573 resultKnowledge.reserve(yieldOps.front().getNumOperands());

3574 for (auto operand : yieldOps.front().getOperands()) {

3575 resultKnowledge.push_back(

3577 }

3578

3579 for (auto yieldOp : yieldOps) {

3580 if (resultKnowledge.size() != yieldOp.getNumOperands())

3581 return failure();

3582

3583 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {

3584 int32_t index = it.index();

3586 resultKnowledge[index],

3588 if (!meet)

3589 continue;

3590 resultKnowledge[index] = meet;

3591 }

3592 }

3593

3594 for (const ValueKnowledge &result : resultKnowledge) {

3595 inferredReturnShapes.push_back(result.getShapedTypeComponents());

3596 }

3597

3598 return success();

3599 }

3600

3601 LogicalResult WhileOp::inferReturnTypeComponents(

3602 MLIRContext *context, ::std::optional location,

3603 WhileOp::Adaptor adaptor,

3606 for (auto &block : adaptor.getBodyGraph())

3607 if (auto returnOp = dyn_casttosa::YieldOp(block.getTerminator()))

3608 yieldOps.push_back(returnOp);

3609

3610

3611

3612 if (yieldOps.empty())

3613 return failure();

3614

3615

3617 resultKnowledge.reserve(yieldOps.front().getNumOperands());

3618 for (auto operand : yieldOps.front().getOperands()) {

3619 resultKnowledge.push_back(

3621 }

3622

3623 for (auto yieldOp : yieldOps) {

3624 if (resultKnowledge.size() != yieldOp.getNumOperands())

3625 return failure();

3626

3627 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {

3628 int32_t index = it.index();

3630 resultKnowledge[index],

3632 resultKnowledge[index] = meet;

3633 }

3634 }

3635 }

3636

3637 for (const ValueKnowledge &result : resultKnowledge) {

3638 inferredReturnShapes.push_back(result.getShapedTypeComponents());

3639 }

3640

3641 return success();

3642 }

3643

3644 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {

3645 if (auto vt = llvm::dyn_cast<VectorType>(getType()))

3646 return llvm::to_vector<4>(vt.getShape());

3647 return std::nullopt;

3648 }

3649

3650

3652

3653 result.regions.reserve(2);

3656

3657 auto &builder = parser.getBuilder();

3659

3663 return failure();

3664

3666 return failure();

3667

3668 if (parser.parseRegion(*thenRegion, {}, {}))

3669 return failure();

3670

3671

3673 if (parser.parseRegion(*elseRegion, {}, {}))

3674 return failure();

3675 }

3676

3677

3679 return failure();

3680 return success();

3681 }

3682

3684 bool printBlockTerminators = false;

3685

3686 p << " " << getCondition();

3687 if (!getResults().empty()) {

3688 p << " -> (" << getResultTypes() << ")";

3689

3690 printBlockTerminators = true;

3691 }

3692 p << ' ';

3694 false,

3695 printBlockTerminators);

3696

3697

3698 auto &elseRegion = getElseGraph();

3699 if (!elseRegion.empty()) {

3700 p << " else ";

3702 false,

3703 printBlockTerminators);

3704 }

3705

3707 }

3708

3711 "'then_graph' arguments", getInputList(),

3712 "'input_list'")

3713 .failed())

3714 return failure();

3715

3717 "'else_graph' arguments", getInputList(),

3718 "'input_list'")

3719 .failed())

3720 return failure();

3721

3722 auto thenYield = casttosa::YieldOp(getThenGraph().front().getTerminator());

3724 "'then_graph' results", getOutputList(),

3725 "'output_list'")

3726 .failed())

3727 return failure();

3728

3729 auto elseYield = casttosa::YieldOp(getElseGraph().front().getTerminator());

3731 "'else_graph' results", getOutputList(),

3732 "'output_list'")

3733 .failed())

3734 return failure();

3735

3736 auto condType = getCondition().getType();

3738 return emitOpError() << "'condition' must be a size 1 tensor, got "

3739 << condType;

3740

3741 return success();

3742 }

3743

3746 getOutputList(), "'output_list'")

3747 .failed())

3748 return failure();

3749

3751 "'cond_graph' arguments", getInputList(),

3752 "'input_list'")

3753 .failed())

3754 return failure();

3755

3757 "'body_graph' arguments", getInputList(),

3758 "'input_list'")

3759 .failed())

3760 return failure();

3761

3762 auto bodyYield = casttosa::YieldOp(getBodyGraph().front().getTerminator());

3764 "'body_graph' results", getInputList(),

3765 "'input_list'")

3766 .failed())

3767 return failure();

3768

3769

3770

3771 auto condYield = casttosa::YieldOp(getCondGraph().front().getTerminator());

3772 if (condYield.getInputs().size() != 1)

3773 return emitOpError() << "require 'cond_graph' only have one result";

3774

3775 auto condOutType = condYield.getInputs()[0].getType();

3777 return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "

3778 << condOutType;

3779

3781 return emitOpError() << "'cond_graph' result must be a boolean tensor, got "

3782 << condOutType;

3783

3784 return success();

3785 }

3786

3789 getOutput().getType())

3790 .failed())

3791 return failure();

3792 TensorType inputType = getInput1().getType();

3793 TensorType outputType = getOutput().getType();

3794 int32_t reverseAxis = getAxis();

3795

3796 if (reverseAxis < 0)

3797 return emitOpError("expected non-negative reverse axis");

3798 if (inputType.hasRank()) {

3799 int64_t inputRank = inputType.getRank();

3800

3801

3802 if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))

3803 return emitOpError("expect input tensor rank (")

3804 << inputRank << ") to be larger than reverse axis (" << reverseAxis

3805 << ")";

3806 }

3807 if (outputType.hasRank()) {

3808 int64_t outputRank = outputType.getRank();

3809 if (inputType.hasRank() && outputRank != inputType.getRank())

3810 return emitOpError(

3811 "expect output tensor rank to be equal to input tensor rank");

3812 if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))

3813 return emitOpError("expect output tensor rank (")

3814 << outputRank << ") to be larger than reverse axis ("

3815 << reverseAxis << ")";

3816 }

3817 return success();

3818 }

3819

3821

3823 getOutput().getType())

3824 .failed() ||

3826 getOutput().getType())

3827 .failed()) {

3828 return failure();

3829 }

3830

3831 auto predicateType = llvm::dyn_cast(getInput1().getType());

3832 if (!predicateType) {

3833 return emitOpError("expect shaped tensor for input1, got ")

3834 << getInput1().getType();

3835 }

3836 auto predicateElementType = predicateType.getElementType();

3837 if (!predicateElementType.isInteger(1)) {

3838 return emitOpError("expect element type of bool for input1, got ")

3839 << predicateElementType;

3840 }

3841

3842 return success();

3843 }

3844

3846 StringRef symName = getName();

3847 FailureOrtosa::VariableOp varOp = findVariableDecl(*this, symName);

3848 if (succeeded(varOp))

3849 return emitOpError("illegal to have multiple declaration of '")

3850 << symName << "'";

3851

3852 return success();

3853 }

3854

3857 .failed())

3858 return failure();

3859

3860 return success();

3861 }

3862

3865 .failed())

3866 return failure();

3867

3868 return success();

3869 }

3870

3871

3877

3880 if (listResult.has_value() && failed(listResult.value()))

3881 return failure();

3882

3883 FunctionType functionType;

3886 return failure();

3887

3888 result.addTypes(functionType.getResults());

3889

3890 if (functionType.getNumInputs() != operands.size()) {

3891 return parser.emitError(typeLoc)

3892 << "expected as many input types as operands "

3893 << "(expected " << operands.size() << " got "

3894 << functionType.getNumInputs() << ")";

3895 }

3896

3897

3898 if (failed(parser.resolveOperands(operands, functionType.getInputs(),

3901 return failure();

3902

3903

3904 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)

3905 regionArgs[i].type = functionType.getInput(i);

3906

3907 return failure(parser.parseRegion(*cond, regionArgs) ||

3910 }

3911

3915 StringRef prefix = "") {

3916 assert(blocksArgs.size() == initializers.size() &&

3917 "expected same length of arguments and initializers");

3918 if (initializers.empty())

3919 return;

3920

3921 parser << prefix << '(';

3922 llvm::interleaveComma(

3923 llvm::zip(blocksArgs, initializers), parser,

3924 [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });

3925 parser << ")";

3926 }

3927

3930 getInputList(), " ");

3931 parser << " : ";

3933 getResults().getTypes());

3934 parser << ' ';

3935 parser.printRegion(getCondGraph(), false);

3936 parser << " do ";

3939 }

3940

3941

3944 Type srcElemType,

3945 int64_t zp) {

3948 if (llvm::isa(srcElemType)) {

3950 zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));

3951 return builder.createtosa::ConstOp(loc, zpType, zpAttr);

3952 }

3953 if (llvm::isa(srcElemType)) {

3954 auto zpAttr =

3956 return builder.createtosa::ConstOp(loc, zpType, zpAttr);

3957 }

3958 llvm::errs() << "zero point is not allowed for unsupported data types\n";

3959 return std::nullopt;

3960 }

3961

3962

3963

3964

3965

3967 return mlir::isatosa::shapeType(t);

3968 }

3969

3970 LogicalResult

3972 int rank) {

3973 if (rank < 0)

3974 return emitError() << "invalid rank (must be >= 0): " << rank;

3975 return success();

3976 }

3977

3980 if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {

3981 Operation *definingOp = v.getDefiningOp();

3983 return op->emitOpError("shape operand is not compile time resolvable");

3984 }

3985 }

3986 }

3987 return success();

3988 }

3989

3992 if (!mlir::isamlir::tosa::shapeType(type)) {

3993 return op->emitOpError("must have operands with tosa shape type");

3994 }

3995 }

3997 if (!mlir::isamlir::tosa::shapeType(type)) {

3998 return op->emitOpError("must have result with tosa shape type");

3999 }

4000 }

4001 return success();

4002 }

4003

4004 LogicalResult

4008 return failure();

4009

4010

4011 auto getRank = [](const Type type) {

4012 return mlir::castmlir::tosa::shapeType(type).getRank();

4013 };

4016

4018 for (auto type : operandTypes) {

4019 if (getRank(type) != rank) {

4020 return op->emitOpError("operands don't have matching ranks");

4021 }

4022 }

4023 for (auto type : resultTypes) {

4024 if (getRank(type) != rank) {

4025 return op->emitOpError("result shape has different rank than operands");

4026 }

4027 }

4028 return success();

4029 }

4030

4031

4032

4033

4034

4036

4037 auto valuesRank = getValues().getType().getRank();

4038 if (valuesRank != 1)

4039 return emitOpError("expect elements in attribute values with rank 1");

4040

4041 auto count = getValues().getNumElements();

4042 auto rank = (casttosa::shapeType(getResult().getType())).getRank();

4043 if (!(count == rank || (count == 1 && rank == 0))) {

4044 return emitOpError("expect number of elements in attribute values (")

4045 << count << ") to be equal to the rank (" << rank

4046 << ") for the result shape type";

4047 }

4048 return success();

4049 }

4050

4051

4052

4053

4054

4055 #define GET_ATTRDEF_CLASSES

4056 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

4057

4058

4059

4060

4061 #define GET_TYPEDEF_CLASSES

4062 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"

4063

4064

4065

4066

4067

4068 #define GET_OP_CLASSES

4069 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"


static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)

A utility function used to materialize a constant for a given attribute and type.

static MLIRContext * getContext(OpFoldResult val)

static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)

Utility to check that all of the operations within 'src' can be inlined.

static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b)

The tosa.matmul op is also intended to be generated where a fully_connected op must be constructed wh...

static FailureOr< int64_t > getZeroPoint(Value val, bool signExtend)

static LogicalResult verifySameElementTypes(T op, Type inType, Type outType)

static LogicalResult ReduceInferReturnTypes(ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)

static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val, Value valZp, StringRef name)

static FailureOr< tosa::VariableOp > findVariableDecl(Operation *op, StringRef symName)

static LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type)

#define REDUCE_SHAPE_INFER(OP)

static LogicalResult verifyConvOp(T op)

static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name)

static LogicalResult poolingInferReturnTypes(ShapeAdaptor inputShape, ArrayRef< int64_t > kernel, ArrayRef< int64_t > stride, ArrayRef< int64_t > pad, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)

static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings)

This builder is called on TOSA pad operator that needs to create its own OptionalAttr quantization_at...

static void buildVariableOp(OpBuilder &builder, OperationState &result, StringRef name, Type variableType, Attribute initialValue)

static LogicalResult verifyReduceOp(T op)

#define NARY_SHAPE_INFER(OP)

#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)

static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, TypeAttr accType)

Handles tosa.transpose_conv2d which has outpad and output shape attributes.

static LogicalResult verifyConvOpErrorIf(T op)

static LogicalResult verifyConvOpModes(T op)

std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)

static LogicalResult NAryInferReturnTypes(const ValueShapeRange &operands, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)

#define COMPATIBLE_RETURN_TYPES(OP)

static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)

Type getStorageElementTypeOrSelf(Type type)

static void buildNegateOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input)

This builder is called on single-parameter negate operator to construct input and output zero points ...

static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType)

This builder is called on all convolution operators except TransposeConv, which has specialized outpu...

static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)

Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr but avg_pool operator has...

static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1, StringRef name1, Type type2, StringRef name2)

static SmallVector< int64_t > convertToMlirShape(ArrayRef< int64_t > shape)

static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, const std::string &operand)

static void printInitializationList(OpAsmPrinter &parser, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")

static LogicalResult verifyPoolingOp(T op)

static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize, const llvm::StringRef dimName)

static ArrayRef< int64_t > getShape(Type type)

Returns the shape of the given type.

virtual Builder & getBuilder() const =0

Return a builder which provides useful access to MLIRContext, global objects like types and attribute...

virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0

Parse a named dictionary into 'result' if it is present.

virtual ParseResult parseOptionalEqual()=0

Parse a = token if present.

virtual ParseResult parseOptionalKeyword(StringRef keyword)=0

Parse the given keyword if present.

MLIRContext * getContext() const

virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0

Emit a diagnostic at the specified location and return failure.

virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0

Parse a named dictionary into 'result' if the attributes keyword is present.

virtual ParseResult parseColonType(Type &result)=0

Parse a colon followed by a type.

virtual SMLoc getCurrentLocation()=0

Get the location of the next token and store it into the argument.

virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0

Parse an optional arrow followed by a type list.

ParseResult parseKeyword(StringRef keyword)

Parse a given keyword.

virtual ParseResult parseAttribute(Attribute &result, Type type={})=0

Parse an arbitrary attribute of a given type and return it in result.

virtual void printAttribute(Attribute attr)

Attributes are known-constant values of operations.

MutableArrayRef< BlockArgument > BlockArgListType

This class is a general helper class for creating context-global objects like types,...

IntegerAttr getIndexAttr(int64_t value)

IntegerAttr getIntegerAttr(Type type, int64_t value)

FloatAttr getFloatAttr(Type type, double value)

IntegerType getIntegerType(unsigned width)

StringAttr getStringAttr(const Twine &bytes)

DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)

An attribute that represents a reference to a dense vector or tensor object.

auto getValues() const

Return the held element values as a range of the given type.

static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)

Constructs a dense elements attribute from an array of element values.

An attribute that represents a reference to a dense integer vector or tensor object.

This class defines a virtual interface for reading a bytecode stream, providing hooks into the byteco...

virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0

Emit an error to the reader.

This class defines a virtual interface for writing to a bytecode stream, providing hooks into the byt...

This is the interface that must be implemented by the dialects of operations to be inlined.

DialectInlinerInterface(Dialect *dialect)

This class is used to represent the version of a dialect, for the purpose of polymorphic destruction.

Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...

This is a utility class for mapping one set of IR entities to another.

This class represents a diagnostic that is inflight and set to be reported.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...

virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0

virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0

Parses a region.

virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0

Resolve an operand to an SSA value, emitting an error on failure.

ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)

Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...

virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0

Parse a single SSA value operand name along with a result number if allowResultNumber is true.

This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...

virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0

If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...

virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0

If the specified operation has attributes, print out an attribute dictionary with their values.

void printFunctionalType(Operation *op)

Print the complete type of an operation in functional form.

virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0

Prints a region.

This class helps build Operations.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

This class indicates that op operates on tosa shape types.

Simple wrapper around a void* in order to express generically how to pass in op properties through AP...

This class implements the operand iterators for the Operation class.

Operation is the basic unit of execution within MLIR.

bool hasTrait()

Returns true if the operation was registered with a particular trait, e.g.

OpTy getParentOfType()

Return the closest surrounding parent operation that is of type 'OpTy'.

operand_type_range getOperandTypes()

result_type_range getResultTypes()

operand_range getOperands()

Returns an iterator over the underlying Values.

InFlightDiagnostic emitOpError(const Twine &message={})

Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.

This class implements Optional functionality for ParseResult.

ParseResult value() const

Access the internal ParseResult value.

bool has_value() const

Returns true if we contain a valid ParseResult value.

This class provides an abstraction over the different types of ranges over Regions.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...

bool isDynamicDim(int index) const

Returns whether the index'th dimension is dynamic.

int64_t getDimSize(int index) const

Returns the size of the index'th dimension.

int64_t getRank() const

Returns the rank of the shape.

bool hasStaticShape() const

Returns whether the shape is fully static.

int64_t getNumElements() const

Returns the number of elements in the shape.

void getDims(SmallVectorImpl< int64_t > &res) const

Populates the dimensions from shape referenced.

bool hasRank() const

Returns whether the shape has a rank.

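Taken together, these ShapeAdaptor queries are what a typical shape-inference implementation loops over; a hedged fragment, assuming operands is a ValueShapeRange in scope:

ShapeAdaptor inputShape = operands.getShape(0);
SmallVector<int64_t> outShape;
if (inputShape.hasRank()) {
  // Copy each dimension, preserving dynamic sizes.
  for (int i = 0, e = inputShape.getRank(); i < e; ++i)
    outShape.push_back(inputShape.isDynamicDim(i) ? ShapedType::kDynamic
                                                  : inputShape.getDimSize(i));
}
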
ShapedTypeComponents that represents the components of a ShapedType.

Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...

ArrayRef< int64_t > getShape() const

Returns the shape of this tensor type.

bool hasRank() const

Returns if this type is ranked, i.e. it has a known number of dimensions.

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

bool isInteger() const

Return true if this is an integer type (with the specified width).

unsigned getIntOrFloatBitWidth() const

Return the bit width of an integer or a float type, assert failure on other types.

This class provides an abstraction over the different types of ranges over Values.

type_range getTypes() const

Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).

ShapeAdaptor getShape(int index) const

Returns the shape of index'th operand.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

static WalkResult advance()

static WalkResult interrupt()

Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...

LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)

LogicalResult verifyTosaShapeOperator(Operation *op)

LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)

LogicalResult verifyTosaResolvableShapeOperands(Operation *op)

bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)

Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...

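For instance (the helper lives in OpTrait::util):

SmallVector<int64_t> resultShape;
bool ok = OpTrait::util::getBroadcastedShape({1, 3}, {2, 1}, resultShape);
// ok == true, resultShape == {2, 3}; incompatible static dims return false.
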
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

Operation::operand_range getIndices(Operation *op)

Get the indices that the given load/store operation is operating on.

QueryRef parse(llvm::StringRef line, const QuerySession &qs)

ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)

Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilde...

RankedTensorType getVariableType(VariableOp variableOp)

Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)

Construct ConvOp output type with the correct bitwidth based on input/weight width.

ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr, Attribute &initialValueAttr)

PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)

Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.

std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)

void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr, TypeAttr typeAttr, Attribute initialValueAttr)

MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)

Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: in...

std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)

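A hedged usage fragment: materialize a zero point as a small constant tensor so it can be passed as an explicit operand (builder, loc, and operands are assumed to be in scope):

std::optional<Value> zp =
    tosa::createZeroPointTensor(builder, loc, builder.getIntegerType(8), /*zp=*/0);
if (zp)
  operands.push_back(*zp);
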
bool isa_tosa_shape_type(mlir::Type t)

SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)

UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)

Builds UnaryOpQuantizationAttr UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zer...

Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)

bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)

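A hedged usage fragment: resolve a TOSA shape operand into static values before using them in a verifier (shapeOperand is illustrative):

llvm::SmallVector<int64_t> shapeValues;
Operation *defOp = shapeOperand.getDefiningOp();
if (!defOp || !tosa::getConstShapeValues(defOp, shapeValues))
  return failure(); // the shape operand is not a compile-time constant
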
Include the generated interface declarations.

bool matchPattern(Value value, const Pattern &pattern)

Entry point for matching a pattern over a Value.

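Combined with m_Constant (listed below), this is the usual way a verifier reads a constant operand; a hedged fragment, with padValue and op assumed to be in scope:

DenseElementsAttr padAttr;
if (matchPattern(padValue, m_Constant(&padAttr))) {
  // padValue is defined by a constant-foldable op; inspect its elements.
  for (const APInt &v : padAttr.getValues<APInt>())
    if (v.isNegative())
      return op->emitOpError("expected non-negative padding values");
}
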
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)

Reify the shape of the result of an operation (typically in terms of the shape of its operands).

LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)

Returns success if the given two arrays have the same number of elements and each pair wise entries h...

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)

Overloads of the above emission functions that take an optionally null location.

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.

Type getElementTypeOrSelf(Type type)

Return the element type or return the type itself.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)

Returns success if the given two shapes are compatible.

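For example, a dynamic dimension is compatible with any static size, while two differing static sizes are not:

// success: ? matches 4
LogicalResult ok = verifyCompatibleShape({2, ShapedType::kDynamic}, {2, 4});
// failure: 3 vs. 4
LogicalResult err = verifyCompatibleShape({2, 3}, {2, 4});
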
detail::constant_op_matcher m_Constant()

Matches a constant foldable operation.

LogicalResult verify(Operation *op, bool verifyRecursively=true)

Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...

bool isPermutationVector(ArrayRef< int64_t > interchange)

Method to check if an interchange vector is a permutation.

This is the representation of an operand reference.

This represents an operation in an abstracted form, suitable for use with the builder APIs.

SmallVector< Value, 4 > operands

void addOperands(ValueRange newOperands)

void addAttribute(StringRef name, Attribute attr)

Add an attribute with the specified name.

void addTypes(ArrayRef< Type > newTypes)

SmallVector< std::unique_ptr< Region >, 1 > regions

Regions that the op will hold.

SmallVector< Type, 4 > types

Types of the results of this operation.

Region * addRegion()

Create a region that should be attached to the operation.

Statically known information for a particular Value.

static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)

static ValueKnowledge getKnowledgeFromType(Type type)
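
These two helpers underpin the shape-propagation lattice used by the TOSA shape-inference code; a hedged fragment combining what two types say about the same value (f32 is assumed to be a Float32Type in scope):

auto lhs = ValueKnowledge::getKnowledgeFromType(
    RankedTensorType::get({2, ShapedType::kDynamic}, f32));
auto rhs = ValueKnowledge::getKnowledgeFromType(
    RankedTensorType::get({2, 4}, f32));
// Combine the per-dimension information carried by the two types.
ValueKnowledge combined = ValueKnowledge::meet(lhs, rhs);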