MLIR: lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp Source File

//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-Exceptions
//
//===----------------------------------------------------------------------===//
//
// \file
// TOSA canonicalization patterns and folders.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#include <functional>

using namespace mlir;
using namespace mlir::tosa;

//===----------------------------------------------------------------------===//
// Operator Canonicalizers.
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Tensor Data Engine Operators.
//===----------------------------------------------------------------------===//

// Check that the zero point of the tensor and padding operations are aligned.
bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
  // Check that padConst is a constant value and a scalar tensor
  DenseElementsAttr padConstAttr;
  if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
      (padConstAttr.size() != 1)) {
    return false;
  }

  // Check that floating point pad is zero
  if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
    float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
    return padConstVal == 0.0f;
  }

  // Check that the zp and padConst align for the integer (quantized) case
  if (auto padConstIntAttr =
          mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
    DenseIntElementsAttr zpAttr;
    // Check that zp is a constant value and a scalar tensor
    if (!matchPattern(zp, m_Constant(&zpAttr)) || (zpAttr.size() != 1)) {
      return false;
    }

    // Check equality
    int64_t zpVal = (*zpAttr.begin()).getSExtValue();
    int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
    return zpVal == padConstVal;
  }

  // Bail-out on unsupported type
  return false;
}
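// Illustrative note (a sketch, not part of the upstream file): for a
// quantized i8 tensor with input zero point 3, a foldable pad must use a
// matching constant, conceptually:
//
//   %zp        = "tosa.const"() <{values = dense<3> : tensor<1xi8>}>
//   %pad_const = "tosa.const"() <{values = dense<3> : tensor<1xi8>}>
//   // checkMatchingPadConstAndZp(%pad_const, %zp) -> true
//
// because padding with the zero point is the quantized encoding of 0.0.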

namespace {
template <typename OpTy>
struct PoolPadFoldAdaptor;

template <>
struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
  using OpTy = tosa::AvgPool2dOp;
  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
    const llvm::ArrayRef<int64_t> kernel = op.getKernel();
    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
      return false;
    return true;
  }
  static bool checkPadConstCompliance(OpTy op, Value padConst) {
    return checkMatchingPadConstAndZp(padConst, op.getInputZp());
  }
  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
                                  Value padInput, ArrayRef<int64_t> newPad) {
    rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
        op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
        op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
        op.getAccType());
  }
};

template <>
struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
  using OpTy = tosa::MaxPool2dOp;
  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
    const llvm::ArrayRef<int64_t> kernel = op.getKernel();
    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
      return false;
    return true;
  }
  static bool checkPadConstCompliance(OpTy, Value padConst) {
    // Check that padConst is a constant value and a scalar tensor
    DenseElementsAttr padConstAttr;
    if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
        padConstAttr.size() != 1) {
      return false;
    }

    // Pad needs to be the minimum value of the type to be able to merge
    if (auto padConstFpAttr =
            mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
      const APFloat padConstVal = *padConstFpAttr.begin();
      const APFloat lowestVal =
          APFloat::getLargest(padConstVal.getSemantics(), true);
      return padConstVal == lowestVal;
    } else if (auto padConstIntAttr =
                   mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
      const APInt padConstVal = *padConstIntAttr.begin();
      const unsigned int bitWidth = padConstVal.getBitWidth();
      const APInt lowestVal =
          padConstIntAttr.getElementType().isUnsignedInteger()
              ? APInt::getMinValue(bitWidth)
              : APInt::getSignedMinValue(bitWidth);
      return padConstVal == lowestVal;
    }

    // Bail-out on unsupported type
    return false;
  }
  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
                                  Value padInput, ArrayRef<int64_t> newPad) {
    rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
        op, op.getType(), padInput, op.getKernel(), op.getStride(),
        rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
  }
};

template <typename OpTy>
struct ConvPadFoldAdaptor {
  static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
    return true;
  }
  static bool checkPadConstCompliance(OpTy op, Value padConst) {
    return checkMatchingPadConstAndZp(padConst, op.getInputZp());
  }
  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
                                  Value padInput, ArrayRef<int64_t> newPad) {
    rewriter.replaceOpWithNewOp<OpTy>(
        op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
        op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
        op.getDilationAttr(), op.getAccType(), op.getLocalBound());
  }
};

// Pattern that attempts to fold a preceding tosa.pad operator into a tensor
// operation (convolution or pooling) by merging the pad amounts into the
// tensor operation's own `pad` attribute.
template <typename OpTy, typename AdaptorTy>
struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy tensorOp,
                                PatternRewriter &rewriter) const override {
    // Check producer is a tosa::PadOp
    auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
    if (!padOp)
      return rewriter.notifyMatchFailure(tensorOp,
                                         "Producer must be a tosa::PadOp.");

    // Validate tensor operation padding
    const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
    if (tensorOpPad.size() != 4)
      return rewriter.notifyMatchFailure(
          tensorOp, "Tensor operation padding shall have 4 elements.");

    // Validate tosa::PadOp padding
    DenseIntElementsAttr padOpPadding;
    if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
      return rewriter.notifyMatchFailure(
          tensorOp,
          "The `padding` input specified on the tosa::PadOp must be constant.");
    }

    // Padding is in before/after pairs for the N, H, W and C dimensions.
    if (padOpPadding.size() != 8)
      return rewriter.notifyMatchFailure(tensorOp,
                                         "Pad padding should have 8 elements.");
    int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
    int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
    int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
    int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
    int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
    int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
    int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
    int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();

    if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
      return rewriter.notifyMatchFailure(
          tensorOp, "Folding padding in N or C dimensions is not supported.");

    // Fold the tosa.pad H/W padding into the tensor operation's pad:
    // [top, bottom, left, right].
    SmallVector<int64_t> foldedPad(tensorOpPad.size());
    foldedPad[0] = padHBefore + tensorOpPad[0];
    foldedPad[1] = padHAfter + tensorOpPad[1];
    foldedPad[2] = padWBefore + tensorOpPad[2];
    foldedPad[3] = padWAfter + tensorOpPad[3];

    // Check kernel related restrictions
    if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
      return rewriter.notifyMatchFailure(
          tensorOp, "Padding size not aligned with kernel restrictions.");
    }

    // Check the padding constant is compatible with the operator
    if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
      return rewriter.notifyMatchFailure(
          tensorOp,
          "Padding constant is not aligned with operator zero-point.");
    }

    // Check that the folded padding stays within the 8K level limit
    if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
      return rewriter.notifyMatchFailure(
          tensorOp, "Padding size more than the 8K level limit.");
    }

    // Replace the tensor operation with one that consumes the pad's input
    AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
                                   foldedPad);

    return success();
  }
};
} // namespace
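// Illustrative sketch of FoldPadToTensorOp (hypothetical IR, not from the
// upstream file): a tosa.pad that only pads H and W, feeding a pooling op,
// is absorbed into the pool's `pad` attribute, conceptually:
//
//   %p = tosa.pad %x, ...        // pads H/W by 1 each side, pad_const = -inf
//   %y = tosa.max_pool2d %p {pad = array<i64: 0, 0, 0, 0>, ...}
// becomes
//   %y = tosa.max_pool2d %x {pad = array<i64: 1, 1, 1, 1>, ...}
//
// provided each folded pad amount stays below the corresponding kernel
// dimension and under the 8192 level limit.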

void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
                                PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
      context);
}

void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<
      FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
      context);
}

void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
                                ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
      context);
}

struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    Value output = op.getOutput();
    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
    ShapedType outputType = llvm::cast<ShapedType>(output.getType());

    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
      return failure();
    }

    // If the input and output H/W dimensions are both 1, this is a no-op.
    ArrayRef<int64_t> outputShape = outputType.getShape();
    if (outputShape[1] != 1 || outputShape[2] != 1) {
      return failure();
    }

    ArrayRef<int64_t> inputShape = inputType.getShape();
    if (inputShape[1] != 1 || inputShape[2] != 1) {
      return failure();
    }

    rewriter.replaceOp(op, input);
    return success();
  }
};

void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<MaxPool2dIsNoOp,
              FoldPadToTensorOp<tosa::MaxPool2dOp,
                                PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
      context);
}
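// Illustrative note (hypothetical shapes, not from the upstream file): with
// NHWC tensors, a max_pool2d whose input and output are both 1x1 in H and W,
// e.g. tensor<8x1x1x32xf32> -> tensor<8x1x1x32xf32>, pools over a single
// element per channel and is simply replaced by its input.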

struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ConcatOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getInput1().size() != 1)
      return failure();
    if (op.getInput1().front().getType() != op.getType()) {
      rewriter
          .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                              op.getInput1().front())
          .getResult();
      return success();
    }

    rewriter.replaceOp(op, op.getInput1().front());
    return success();
  }
};

void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<ConcatOptimization>(context);
}

LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
  auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
  if (!notOp)
    return failure();
  rewriter.modifyOpInPlace(op, [&]() {
    op.getOperation()->setOperands(
        {notOp.getInput1(), op.getInput3(), op.getInput2()});
  });
  return success();
}
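// Illustrative sketch (hypothetical IR, not from the upstream file): the
// canonicalization above drops a logical_not on the predicate by swapping
// the two value operands:
//
//   %np = tosa.logical_not %p
//   %r  = tosa.select %np, %a, %b
// becomes
//   %r  = tosa.select %p, %b, %a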

struct ConsolidateTransposeOptimization
    : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Input is also TransposeOp - transpose(transpose(A)).
    auto innerTranspose =
        transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
    if (!innerTranspose)
      return rewriter.notifyMatchFailure(transposeOp,
                                         "input must be transpose operation");

    const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
    const llvm::ArrayRef<int32_t> innerTransposePerms =
        innerTranspose.getPerms();

    if (transposePerms.size() != innerTransposePerms.size())
      return rewriter.notifyMatchFailure(
          transposeOp,
          "transpose and inner transpose perms sizes must be equal");
    if (transposePerms.empty())
      return rewriter.notifyMatchFailure(
          transposeOp, "transpose perms sizes must be positive");

    // Consolidate the permutations of the two transposes into one.
    SmallVector<int32_t> perms(transposePerms.size());
    for (int i = 0, s = transposePerms.size(); i < s; ++i)
      perms[i] = innerTransposePerms[transposePerms[i]];

    rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));

    return success();
  }
};
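// Worked example of the permutation composition above (illustrative, not
// from the upstream file): with innerPerms = [1, 2, 0] and outerPerms =
// [2, 0, 1], perms[i] = innerPerms[outerPerms[i]] yields [0, 1, 2], i.e.
// the two transposes cancel and the combined transpose is the identity.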

// Determines the case when tosa.transpose is a tosa.reshape operation.
struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
      return rewriter.notifyMatchFailure(
          op, "Src is from transpose, can compose transposes");

    Value result = op.getResult();
    for (Operation *subop : result.getUsers()) {
      if (isa_and_nonnull<tosa::TransposeOp>(subop))
        return rewriter.notifyMatchFailure(
            op, "Dest is used by transpose, can compose transposes");
    }

    auto input = op.getInput1();
    auto inputTy = llvm::cast<ShapedType>(input.getType());
    if (!inputTy.hasRank())
      return rewriter.notifyMatchFailure(op, "Unranked input.");

    int64_t numDynDims = 0;
    for (int i = 0; i < inputTy.getRank(); ++i)
      if (inputTy.isDynamicDim(i))
        numDynDims++;

    if (numDynDims > 1)
      return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");

    const llvm::ArrayRef<int32_t> permValues = op.getPerms();

    SmallVector<int64_t> nonZeroPerms;
    nonZeroPerms.reserve(permValues.size());
    for (auto idx : permValues) {
      auto sz = inputTy.getDimSize(idx);
      if (sz != 1)
        nonZeroPerms.push_back(idx);
    }

    for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
      if (nonZeroPerms[i - 1] > nonZeroPerms[i])
        return rewriter.notifyMatchFailure(op,
                                           "Transpose changes memory layout.");

    SmallVector<int64_t> newShape;
    newShape.reserve(inputTy.getRank());
    for (int i = 0, s = inputTy.getRank(); i < s; ++i)
      newShape.push_back(inputTy.getDimSize(permValues[i]));

    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, op.getType(), op.getInput1(),
        getTosaConstShape(rewriter, op.getLoc(), newShape));
    return success();
  }
};

void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}
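// Illustrative note (hypothetical shapes, not from the upstream file): a
// transpose of tensor<1x5x1x6xf32> with perms [2, 0, 1, 3] only moves unit
// dimensions; the non-unit dims (5 at index 1, 6 at index 3) keep their
// relative order, so the op is replaced by a reshape to tensor<1x1x5x6xf32>.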

struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
    auto inputElementType = inputType.getElementType();

    if (!inputType.hasStaticShape()) {
      return failure();
    }

    if (isa<FloatType>(inputElementType)) {
      // Unlike integer types, floating point types can represent infinity.
      auto minClamp =
          llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
      auto maxClamp =
          llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
      bool isMin = minClamp.isNegInfinity();
      bool isMax = maxClamp.isInfinity();

      if (isMin && isMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (inputElementType.isUnsignedInteger()) {
      int64_t minClamp =
          llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getUInt();
      int64_t maxClamp =
          llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getUInt();

      int64_t intMin =
          APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();
      int64_t intMax =
          APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (llvm::isa<IntegerType>(inputElementType)) {
      int64_t minClamp =
          llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
      int64_t maxClamp =
          llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();

      int64_t intMin =
          APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    return failure();
  }
};

// Attempts the following transformation:
//
// For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
// tensor X the following identity holds:
//
// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
//
// subject to the following valid NaN propagation semantics:
//
// | OUTER CLAMP | INNER CLAMP | RESULT MODE |
// |-------------|-------------|-------------|
// | PROPAGATE   | PROPAGATE   | PROPAGATE   |
// | PROPAGATE   | IGNORE      | IGNORE      |
// | IGNORE      | PROPAGATE   | INVALID     |
// | IGNORE      | IGNORE      | IGNORE      |

template <typename T>
struct ClampRange {
  ClampRange(const T &start, const T &end) : start(start), end(end) {}
  T start;
  T end;

  // Helper to determine if two Clamp ranges intersect.
  bool intersects(const ClampRange<T> &otherRange) {
    return start < otherRange.end && otherRange.start < end;
  }
};

struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();

    // Check the input to the CLAMP op is itself a CLAMP.
    auto clampOp = dyn_cast_if_present<tosa::ClampOp>(input.getDefiningOp());
    if (!clampOp)
      return failure();

    // Check we have a valid NaN propagation combination.
    const auto opNanMode = op.getNanMode();
    const auto clampNanMode = clampOp.getNanMode();
    if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
      return failure();

    auto maxValAttr = op.getMaxValAttr();
    auto minValAttr = op.getMinValAttr();
    auto clampOpMaxValAttr = clampOp.getMaxValAttr();
    auto clampOpMinValAttr = clampOp.getMinValAttr();

    auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
    if (auto quantType =
            llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
      inputEType = quantType.getStorageType();
    }

    Attribute newMinValAttr, newMaxValAttr;
    if (mlir::isa<FloatType>(inputEType)) {
      auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
      auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
      auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
      auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);

      // Check we have intersecting ranges.
      const auto opMinFloat = floatMinValAttr.getValue();
      const auto opMaxFloat = floatMaxValAttr.getValue();
      const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
      const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
      ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
      ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
                                               clampOpMaxFloat);
      if (!opRangeFloatRange.intersects(clampRangeFloatRange))
        return failure();

      // Run the transformation.
      auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
      auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
      newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
      newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
    } else {
      assert(mlir::isa<IntegerType>(inputEType));
      auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
      auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
      auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
      auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);

      if (inputEType.isUnsignedInteger()) {
        // Check we have intersecting ranges.
        const auto opMinInt = intMinValAttr.getUInt();
        const auto opMaxInt = intMaxValAttr.getUInt();
        const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
        const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
        ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
        ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
                                                     clampOpMaxInt);
        if (!opRangeIntRange.intersects(clampRangeIntRange))
          return failure();

        // Run the transformation.
        auto newMinVal = std::max(opMinInt, clampOpMinInt);
        auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
        newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
        newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
      } else {
        // Check we have intersecting ranges.
        const auto opMinInt = intMinValAttr.getInt();
        const auto opMaxInt = intMaxValAttr.getInt();
        const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
        const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
        ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
        ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
                                                    clampOpMaxInt);
        if (!opRangeIntRange.intersects(clampRangeIntRange))
          return failure();

        // Run the transformation.
        auto newMinVal = std::max(opMinInt, clampOpMinInt);
        auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
        newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
        newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
      }
    }

    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
        op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
        rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
                                                           : opNanMode));
    return success();
  }
};

void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ClampIsNoOp>(context);
  results.add<ClampClampOptimization>(context);
}
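// Worked example (illustrative, not from the upstream file): for
// clamp(clamp(x, 0, 10), 5, 20) the ranges [0, 10] and [5, 20] intersect,
// so the pair folds to clamp(x, max(0, 5), min(10, 20)) = clamp(x, 5, 10).
// Were the outer range [11, 20], the ranges would be disjoint and the
// pattern would bail out.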

struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    Value sliceInput = sliceOp.getInput1();
    auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
    if (!concatOp)
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be concat operation");

    OperandRange inputs = concatOp.getInput1();
    auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
    if (!concatType || !concatType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be a static ranked tensor");
    int32_t axis = concatOp.getAxis();

    DenseElementsAttr startElems;
    DenseElementsAttr sizeElems;

    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");

    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");

    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    // Validate slice on the concatenated axis. Slicing along this
    // axis should span only one of the inputs to the concatenate
    // operation.
    std::optional<Value> replaceWithSlice;
    for (auto input : inputs) {
      auto inputType = dyn_cast<RankedTensorType>(input.getType());
      if (!inputType || !inputType.hasStaticShape())
        return rewriter.notifyMatchFailure(
            sliceOp, "concat input must be a static ranked tensor");

      if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
                                        inputType.getDimSize(axis)) {
        auto start_op =
            getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
        auto size_op =
            getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
        replaceWithSlice =
            rewriter
                .create<tosa::SliceOp>(sliceOp.getLoc(), sliceOp.getType(),
                                       input, start_op, size_op)
                .getResult();
        break;
      }
      sliceStarts[axis] -= inputType.getDimSize(axis);
    }

    if (!replaceWithSlice)
      return rewriter.notifyMatchFailure(
          sliceOp, "corresponding concat input not found for slice");

    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
    return success();
  }
};
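// Illustrative sketch (hypothetical shapes, not from the upstream file):
// slicing a concat of tensor<1x4xf32> and tensor<1x6xf32> along axis 1 with
// start [0, 5] and size [1, 2] lies entirely inside the second input, so
// the pattern rewrites it to a slice of that input with the rebased start
// [0, 5 - 4] = [0, 1].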

struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    Value sliceInput = sliceOp.getInput1();

    // Check if the producer is a PadOp
    auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
    if (!padOp)
      return rewriter.notifyMatchFailure(sliceOp,
                                         "slice input must be a pad operation");

    // Check the PadOp has a single consumer
    if (!padOp->hasOneUse())
      return rewriter.notifyMatchFailure(sliceOp,
                                         "pad shall have a single consumer");

    // Check the input is statically ranked
    auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
    auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
    if (!inputTy || !padTy || !inputTy.hasRank())
      return rewriter.notifyMatchFailure(sliceOp,
                                         "slice input must be a ranked tensor");

    // Validate and extract tosa::PadOp padding
    DenseIntElementsAttr paddingElems;
    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
      return rewriter.notifyMatchFailure(
          sliceOp,
          "`padding` input specified on the tosa::PadOp must be constant.");
    }
    llvm::SmallVector<int64_t> padPaddings =
        llvm::to_vector(paddingElems.getValues<int64_t>());

    // Extract slice parameters
    DenseElementsAttr startElems;
    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");
    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());

    DenseElementsAttr sizeElems;
    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    // Check whether dynamic dimensions are sliced
    const int64_t rank = inputTy.getRank();
    if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
          const bool isDimDynamic = inputTy.isDynamicDim(i);
          const bool isDimSliced =
              (sliceStarts[i] != 0) || (sliceSizes[i] != -1);

          return isDimDynamic && isDimSliced;
        })) {
      return rewriter.notifyMatchFailure(
          sliceOp, "axis that are sliced shall be statically known.");
    }

    // Update the parameters
    llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
    llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
    llvm::SmallVector<int64_t> newPadShape(rank, ShapedType::kDynamic);
    bool updated = false;

    for (int64_t i = 0; i < rank; ++i) {
      const int64_t padLo = padPaddings[i * 2];
      const int64_t padHi = padPaddings[i * 2 + 1];
      const int64_t sliceStart = sliceStarts[i];
      const int64_t sliceSize = sliceSizes[i];
      const int64_t sliceEnd = sliceStart + sliceSize;

      // If the dimension is dynamic, pass the parameters through
      if (inputTy.isDynamicDim(i)) {
        newPadPaddings[i * 2] = padLo;
        newPadPaddings[i * 2 + 1] = padHi;
        newSliceStarts[i] = sliceStart;
        continue;
      }

      // Handle static dimensions
      const int64_t dimSize = inputTy.getShape()[i];
      const int64_t dimTotal = padLo + dimSize + padHi;

      // Check slice within bounds
      if (sliceStart < 0 || sliceEnd > dimTotal)
        return rewriter.notifyMatchFailure(sliceOp, "slice is out-of-bounds");

      // Compute updated slice start parameter
      const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
      newSliceStarts[i] = newSliceStart;
      updated |= newSliceStart != sliceStart;

      // Compute updated pad parameters
      const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
      const int64_t newPadHi =
          std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
      newPadPaddings[i * 2] = newPadLo;
      newPadPaddings[i * 2 + 1] = newPadHi;
      updated |= (newPadLo != padLo) || (newPadHi != padHi);

      // Calculate new pad output shape
      newPadShape[i] =
          newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
    }

    // Check that we actually need to proceed with the rewrite
    if (!updated)
      return rewriter.notifyMatchFailure(
          sliceOp, "terminate condition; nothing to rewrite");

    // Create a PadOp with updated padding
    auto newPaddingsOp =
        getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
    auto newPadTy =
        RankedTensorType::get(newPadShape, inputTy.getElementType());
    auto newPadOp = rewriter.create<tosa::PadOp>(
        padOp.getLoc(), newPadTy, padOp.getInput1(), newPaddingsOp,
        padOp.getPadConst());

    // Update the SliceOp to point at the new PadOp
    auto newStartOp =
        getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
    rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
                                               newPadOp.getResult(), newStartOp,
                                               sliceOp.getSize());

    return success();
  }
};
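// Worked example of the rebalancing above (illustrative, not from the
// upstream file): for one dimension with dimSize = 8, padLo = 2, padHi = 2,
// sliceStart = 1 and sliceSize = 9 (sliceEnd = 10):
//   newSliceStart = max(1 - 2, 0)        = 0
//   newPadLo      = max(2 - 1, 0)        = 1
//   newPadHi      = max(10 - (2 + 8), 0) = 0
// so the pad shrinks to [1, 0] and the slice starts at 0, producing the
// same 9 elements while padding one element less.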

// Update the size operand of tosa.slice if size has dynamic dims but the
// corresponding output dim is static
struct SliceDynamicSizeCanonicalization
    : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    ShapedType resultType = cast<ShapedType>(sliceOp.getType());

    ElementsAttr sizeElems;
    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");
    }

    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    bool replaceSliceSize{false};
    // If size has a -1 indicating a dynamic dim but the corresponding dim on
    // the output is statically known, update the size to match the known
    // output dim shape.
    for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
      if (size == -1 && !resultType.isDynamicDim(index)) {
        sliceSizes[index] = resultType.getDimSize(index);
        replaceSliceSize = true;
      }
    }

    if (!replaceSliceSize) {
      return rewriter.notifyMatchFailure(
          sliceOp, "no dimension of size of slice is dynamic that resolves "
                   "to static output shape");
    }

    auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
    auto newSliceOp = rewriter.create<tosa::SliceOp>(
        sliceOp.getLoc(), sliceOp.getType(), sliceOp.getInput1(),
        sliceOp.getStart(), size_op);

    rewriter.replaceOp(sliceOp, newSliceOp.getResult());
    return success();
  }
};

void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ConcatSliceOptimization, PadSliceOptimization,
              SliceDynamicSizeCanonicalization>(context);
}

//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//

template <typename IntFolder, typename FloatFolder>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                               RankedTensorType returnTy) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
    auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
    if (lETy != rETy)
      return {};

    if (llvm::isa<IntegerType>(lETy)) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();
      auto result = IntFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }

    if (llvm::isa<FloatType>(lETy)) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      auto result = FloatFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }
  }

  return {};
}

static bool isSplatZero(Type elemType, DenseElementsAttr val) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
  if (llvm::isa<IntegerType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
  return false;
}

static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() &&
           val.getSplatValue<APFloat>().isExactlyValue(1.0);
  if (llvm::isa<IntegerType>(elemType)) {
    const int64_t shifted = 1LL << shift;
    return val && val.isSplat() &&
           val.getSplatValue<APInt>().getSExtValue() == shifted;
  }
  return false;
}
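// Illustrative note (not from the upstream file): for tosa.mul on i32 with a
// result right-shift, the multiplicative identity is 1 << shift rather than
// 1, since (x * (1 << shift)) >> shift == x. isSplatOne therefore compares
// the splat against the shifted constant, e.g. 16 when shift = 4.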

OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();
  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
    return getInput2();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
                                                            resultTy);
}

OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
      !outputTy.hasStaticShape())
    return {};

  // The argmax over an axis of extent 1 is always index 0.
  if (inputTy.getDimSize(getAxis()) == 1)
    return DenseElementsAttr::get(outputTy, 0);

  return {};
}

OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};
  if (lhsTy != rhsTy)
    return {};

  // IntDivOp inputs must be integer type, no need to check for quantized type
  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  if (lhsAttr && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr;
  }

  if (rhsAttr && rhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        rhsAttr.getSplatValue<APInt>().isOne())
      return getInput1();
  }

  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
      llvm::isa<IntegerType>(resultETy)) {
    APInt l = lhsAttr.getSplatValue<APInt>();
    APInt r = rhsAttr.getSplatValue<APInt>();
    if (!r.isZero()) {
      APInt result = l.sdiv(r);
      return DenseElementsAttr::get(resultTy, result);
    }
  }

  return {};
}

namespace {

// Calculate lhs * rhs >> shift according to the TOSA spec.
// Returns std::nullopt if the result is not in range of int32_t when
// shift > 0.
std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
                            unsigned bitwidth) {
  APInt result = lhs.sext(64) * rhs.sext(64);

  if (shift > 0) {
    auto round = APInt(64, 1) << (shift - 1);
    result += round;
    result.ashrInPlace(shift);

    if (!(result.getSExtValue() >= INT32_MIN &&
          result.getSExtValue() <= INT32_MAX)) {
      // REQUIRE failed: the rounded product is outside the int32 range.
      return std::nullopt;
    }
  }

  return result.trunc(bitwidth);
}

DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                  RankedTensorType ty, int32_t shift) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    if (llvm::isa<IntegerType>(ty.getElementType())) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();

      if (shift == 0) {
        return DenseElementsAttr::get(ty, l * r);
      }

      auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
      const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
      if (!result)
        return {};
      return DenseElementsAttr::get(ty, result.value());
    }

    if (llvm::isa<FloatType>(ty.getElementType())) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      APFloat result = l * r;
      return DenseElementsAttr::get(ty, result);
    }
  }

  return {};
}
} // namespace
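// Worked example of mulInt above (illustrative, not from the upstream file):
// with lhs = 5, rhs = 7 and shift = 2, the 64-bit product is 35; adding the
// round constant 1 << (2 - 1) = 2 gives 37, and the arithmetic shift right
// by 2 yields 9, i.e. round(35 / 4). The result is then range-checked
// against int32 and truncated back to the operand bitwidth.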

OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // The result right shift applies to the i32 data type only. For
  // simplification, synthesize a zero shift for the other data types.
  int32_t shift = 0;
  if (resultETy.isInteger(32)) {
    ElementsAttr shift_elem;
    if (getShift().getImpl()) {
      if (!matchPattern(getShift(), m_Constant(&shift_elem)))
        // Cannot be folded when the shift value is unknown.
        return {};
      shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
    }
  }

  if (rhsTy == resultTy) {
    if (isSplatZero(resultETy, lhsAttr))
      return lhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, lhsAttr, shift))
      return rhs;
  }
  if (lhsTy == resultTy) {
    if (isSplatZero(resultETy, rhsAttr))
      return rhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, rhsAttr, shift))
      return lhs;
  }

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
}

OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
                                                              resultTy);
}

namespace {
template <typename Cmp>
struct ComparisonFold {
  ComparisonFold() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, Cmp()(l, r));
  }

  APInt operator()(const APFloat &l, const APFloat &r) {
    return APInt(1, Cmp()(l, r));
  }
};

struct APIntFoldGreater {
  APIntFoldGreater() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sgt(r));
  }
};

struct APIntFoldGreaterEqual {
  APIntFoldGreaterEqual() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sge(r));
  }
};
} // namespace

OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}

OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreaterEqual,
                      ComparisonFold<std::greater_equal<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}

OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  Value lhs = getInput1();
  Value rhs = getInput2();
  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());

  // Comparing an integer value to itself is always true. The same shortcut
  // is unsound for floats because of NaN values.
  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
      resultTy.hasStaticShape() && lhs == rhs) {
    return DenseElementsAttr::get(resultTy, true);
  }

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
                                                              resultTy);
}

OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  if (getInput().getType() == getType())
    return getInput();

  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  if (operand.isSplat()) {
    if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
      bool overflow;
      auto splatVal = operand.getSplatValue<APFloat>();
      auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
      splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
                       &overflow);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
      splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
                                llvm::RoundingMode::NearestTiesToEven);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
      auto intVal = APSInt(
          llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
      auto floatVal = operand.getSplatValue<APFloat>();
      bool exact;
      floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
                                &exact);
      return SplatElementsAttr::get(outTy, intVal);
    }

    if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      bool trunc =
          inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
      auto intVal = operand.getSplatValue<APInt>();
      auto bitwidth = outETy.getIntOrFloatBitWidth();

      if (trunc) {
        intVal = intVal.trunc(bitwidth);
      } else if (unsignIn) {
        intVal = intVal.zext(bitwidth);
      } else {
        intVal = intVal.sext(bitwidth);
      }

      return SplatElementsAttr::get(outTy, intVal);
    }
  }

  return {};
}

OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }

OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }

#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                    \
      return {};                                                               \
    if (inputTy != getType())                                                  \
      return {};                                                               \
    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
      return getInput();                                                       \
    return {};                                                                 \
  }

REDUCE_FOLDER(ReduceAllOp)
REDUCE_FOLDER(ReduceAnyOp)
REDUCE_FOLDER(ReduceMaxOp)
REDUCE_FOLDER(ReduceMinOp)
REDUCE_FOLDER(ReduceProductOp)
REDUCE_FOLDER(ReduceSumOp)
#undef REDUCE_FOLDER
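// Illustrative note (hypothetical shapes, not from the upstream file): a
// reduction over an axis of extent 1, e.g. tosa.reduce_sum over axis 1 of
// tensor<4x1x8xf32>, returns its input unchanged because the input and
// result types already match and each output element reduces exactly one
// input element.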

OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  // Fold when the input and output types are the same. This is only safe when
  // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
  // there may still be a productive reshape.
  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
    return getInput1();

  // reshape(reshape(x)) -> reshape(x)
  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
          getInput1().getDefiningOp())) {
    getInput1Mutable().assign(reshapeOp.getInput1());
    return getResult();
  }

  // Cannot create an ElementsAttr from non-int/float/index types
  if (!inputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  // reshape(const(x)) -> const(reshape-attr(x))
  if (auto operand =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    // Constants must have static shape.
    if (!outputTy.hasStaticShape())
      return {};

    // Okay to duplicate splat constants.
    if (operand.isSplat())
      return SplatElementsAttr::get(outputTy,
                                    operand.getSplatValue<Attribute>());

    // Don't duplicate other constants.
    if (!getInput1().hasOneUse())
      return {};

    llvm::SmallVector<int64_t> shapeVec;
    if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeVec))
      return {};

    return operand.reshape(
        llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
  }

  return {};
}

OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
  // If the pad is all zeros we can fold this operation away.
  if (adaptor.getPadding() && getInput1().getType() == getType()) {
    auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
    if (densePad && densePad.isSplat() &&
        densePad.getSplatValue<APInt>().isZero()) {
      return getInput1();
    }
  }

  return {};
}

// Fold away cases where a tosa.resize operation returns a copy
// of the input image.
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
  auto scaleAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
  auto offsetAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
  auto borderAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
  if (!scaleAttr || !offsetAttr || !borderAttr) {
    return {};
  }

  auto scale = convertFromIntAttr(scaleAttr, /* rank = */ 4);
  auto offset = convertFromIntAttr(offsetAttr, /* rank = */ 2);
  auto border = convertFromIntAttr(borderAttr, /* rank = */ 2);
  if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
    return {};
  }

  // Check unit scaling.
  if (scale[0] != scale[1] || scale[2] != scale[3]) {
    return {};
  }

  // There should be no offset.
  if (offset[0] != 0 || offset[1] != 0) {
    return {};
  }

  // There should be no border.
  if (border[0] != 0 || border[1] != 0) {
    return {};
  }

  auto input = getInput();
  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
  auto resultTy = llvm::cast<RankedTensorType>(getType());
  if (inputTy != resultTy)
    return {};

  return input;
}

OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
  auto operand = getInput1();
  auto operandTy = llvm::cast<ShapedType>(operand.getType());
  auto axis = getAxis();
  auto operandAttr =
      llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
  if (operandAttr)
    return operandAttr;

  // If the dim-length is 1, tosa.reverse is a no-op.
  if (operandTy.hasRank() &&
      (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
    return operand;

  return {};
}

OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  if (inputTy == outputTy && inputTy.hasStaticShape())
    return getInput1();

  if (!adaptor.getInput1())
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types
  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
      !outputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
  if (operand.isSplat() && outputTy.hasStaticShape()) {
    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
  }

  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
      outputTy.getNumElements() == 1) {
    DenseElementsAttr startElems;
    if (!matchPattern(getStart(), m_Constant(&startElems)))
      return {};

    llvm::SmallVector<uint64_t> indices =
        llvm::to_vector(startElems.getValues<uint64_t>());
    auto value = operand.getValues<Attribute>()[indices];
    return SplatElementsAttr::get(outputTy, value);
  }

  return {};
}

OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
  if (getInput2() == getInput3())
    return getInput2();

  auto predicate =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
  if (!predicate)
    return {};

  if (!predicate.isSplat())
    return {};
  return predicate.getSplatValue<APInt>().getBoolValue() ? getInput2()
                                                         : getInput3();
}

OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
  if (getInput1().getType() == getType()) {
    if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
            adaptor.getMultiples())) {
      if (multiples.isSplat() &&
          multiples.getSplatValue<APInt>().getSExtValue() == 1)
        return getInput1();
      if (auto int_array_attr =
              llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
        if (llvm::all_of(int_array_attr.getValues<APInt>(),
                         [](APInt v) { return v.getSExtValue() == 1; }))
          return getInput1();
      }
    }
  }
  return {};
}

OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::cast<ShapedType>(getType());

  // Transposing splat values just means reshaping.
  if (auto input =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    if (input.isSplat() && resultTy.hasStaticShape() &&
        input.getType().getElementType() == resultTy.getElementType())
      return input.reshape(resultTy);
  }

  // Fold to the input when the transpose is the identity permutation.
  const llvm::ArrayRef<int32_t> perms = getPerms();

  if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
    return {};

  return getInput1();
}

OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise log(exp(x)) = x
  if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise exp(log(x)) = x
  if (auto op = input.getDefiningOp<tosa::LogOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
  // Element-wise negate(negate(x)) = x, but only when all zero points are
  // constant zero.
  auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
  if (!definingOp) {
    // The defining op of input1 is not a negate, cannot fold.
    return {};
  }

  if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
      failed(maybeIZp) || *maybeIZp != 0) {
    // The input1 zero point is not constant 0, cannot fold.
    return {};
  }
  if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
      failed(maybeOZp) || *maybeOZp != 0) {
    // The output zero point is not constant 0, cannot fold.
    return {};
  }
  if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
      failed(maybeIZp) || *maybeIZp != 0) {
    // The defining op's input1 zero point is not constant 0, cannot fold.
    return {};
  }
  if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
      failed(maybeOZp) || *maybeOZp != 0) {
    // The defining op's output zero point is not constant 0, cannot fold.
    return {};
  }

  return definingOp.getInput1();
}

OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise abs(abs(x)) = abs(x)
  if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
    return input;
  }

  return {};
}

OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  // Fold consecutive concats on the same axis into a single op.
  // Keep track of the operands so we are able to construct a new concat
  // later. Conservatively assume that we double the number of operands when
  // folding.
  SmallVector<Value, 8> concatOperands;
  concatOperands.reserve(2 * getNumOperands());

  // Find all operands that are foldable concats
  bool foundFoldableConcat = false;
  for (Value operand : getOperands()) {
    concatOperands.emplace_back(operand);

    auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
    if (!producer)
      continue;

    // Not foldable if axes are not the same
    if (getAxis() != producer.getAxis())
      continue;

    // Replace the original operand with all incoming operands
    foundFoldableConcat = true;
    concatOperands.pop_back();
    llvm::append_range(concatOperands, producer->getOperands());
  }

  if (!foundFoldableConcat)
    return {};

  getOperation()->setOperands(concatOperands);
  return getResult();
}
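// Illustrative sketch (hypothetical IR, not from the upstream file): nested
// concats along the same axis are flattened into one operand list:
//
//   %c0 = tosa.concat %a, %b {axis = 0 : i32}
//   %c1 = tosa.concat %c0, %d {axis = 0 : i32}
// becomes
//   %c1 = tosa.concat %a, %b, %d {axis = 0 : i32}
//
// Concats on different axes are left untouched.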

OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
  auto input = adaptor.getInput1();

  auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
  // Fold splat inputs only.
  if (!inputAttr || !inputAttr.isSplat())
    return {};

  auto shapeType = llvm::cast<ShapedType>(getType());
  if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
    auto floatVal = inputAttr.getSplatValue<APFloat>();
    return DenseElementsAttr::get(shapeType,
                                  ReciprocalOp::calcOneElement(floatVal));
  }

  return {};
}
