MLIR: lib/Conversion/TosaToLinalg/TosaToLinalg.cpp Source File

//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"

#include <numeric>
#include <type_traits>

using namespace mlir;


template <typename OpTy>
static Value
materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
                                    Value lhs, Value rhs, Value result) {
  // NaN propagation is only meaningful for floating-point element types.
  if (!isa<FloatType>(getElementTypeOrSelf(lhs.getType())))
    return result;

  auto nanMode = op.getNanMode();
  if (nanMode == "PROPAGATE")
    return result;

  // "IGNORE" mode: if one operand is NaN, propagate the other operand in its
  // place. An unordered (UNO) self-comparison is true only for NaN.
  Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
      op.getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
  Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
      op.getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
  Value rhsOrResult =
      rewriter.create<arith::SelectOp>(op.getLoc(), lhsIsNaN, rhs, result);
  return rewriter.create<arith::SelectOp>(op.getLoc(), rhsIsNaN, lhs,
                                          rhsOrResult);
}
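// The TOSA spec gives binary min/max ops a nan_mode attribute with two
// values: "PROPAGATE" (any NaN input yields NaN, matching the default
// semantics of arith.maximumf/arith.minimumf) and "IGNORE" (a NaN input is
// skipped and the other operand is returned). The helper above only has to
// materialize extra IR for the "IGNORE" case; roughly (sketch, exact IR may
// differ), for maximum(lhs, rhs) it emits:
//
//   %max   = arith.maximumf %lhs, %rhs
//   %l_nan = arith.cmpf uno, %lhs, %lhs   // true iff %lhs is NaN
//   %r_nan = arith.cmpf uno, %rhs, %rhs   // true iff %rhs is NaN
//   %t     = arith.select %l_nan, %rhs, %max
//   %res   = arith.select %r_nan, %lhs, %t
//
// so the result is NaN only when both operands are NaN.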

static Value createLinalgBodyCalculationForElementwiseOp(
    Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
    ConversionPatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto elementTy =
      cast<ShapedType>(op->getOperand(0).getType()).getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::AbsFOp>(loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
    auto zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementTy));
    auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
    return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
  }

  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::SubIOp>(loc, resultTypes, args);

  // tosa::IntDivOp
  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);

  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
  }

  // tosa::MulOp
  if (isa<tosa::MulOp>(op)) {
    auto shiftVal = cast<tosa::MulOp>(op).getShift();
    DenseElementsAttr shiftElem;
    if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
      (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
      return nullptr;
    }

    int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();

    if (isa<FloatType>(elementTy)) {
      if (shift != 0) {
        (void)rewriter.notifyMatchFailure(op,
                                          "Cannot have shift value for float");
        return nullptr;
      }
      return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0],
                                            args[1]);
    }

    if (isa<IntegerType>(elementTy)) {
      Value a = args[0];
      Value b = args[1];

      if (shift > 0) {
        auto shiftConst =
            rewriter.create<arith::ConstantIntOp>(loc, shift, /*width=*/8);
        if (!a.getType().isInteger(32))
          a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);

        if (!b.getType().isInteger(32))
          b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);

        auto result = rewriter.create<tosa::ApplyScaleOp>(
            loc, rewriter.getI32Type(), a, b, shiftConst,
            rewriter.getStringAttr("SINGLE_ROUND"));

        if (elementTy.isInteger(32))
          return result;

        return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
      }

      int aWidth = a.getType().getIntOrFloatBitWidth();
      int bWidth = b.getType().getIntOrFloatBitWidth();
      int cWidth = resultTypes[0].getIntOrFloatBitWidth();

      if (aWidth < cWidth)
        a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
      if (bWidth < cWidth)
        b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);

      return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
    }
  }

  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op)) {
    auto negate = cast<tosa::NegateOp>(op);

    FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
    if (failed(maybeInZp)) {
      (void)rewriter.notifyMatchFailure(
          op, "input1 zero point cannot be statically determined");
      return nullptr;
    }

    FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
    if (failed(maybeOutZp)) {
      (void)rewriter.notifyMatchFailure(
          op, "output zero point cannot be statically determined");
      return nullptr;
    }

    int64_t inZp = *maybeInZp;
    int64_t outZp = *maybeOutZp;

    if (isa<FloatType>(elementTy))
      return rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);

    if (isa<IntegerType>(elementTy)) {
      if (!inZp && !outZp) {
        auto constant = rewriter.create<arith::ConstantOp>(
            loc, IntegerAttr::get(elementTy, 0));
        return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
                                              args[0]);
      }

      // Compute the maximum value that can occur in the intermediate buffer.
      const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
      const int64_t zpAdd = inZp + outZp;
      const int64_t maxValue =
          APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
          std::abs(zpAdd) + 1;

      // Convert that maximum value into the maximum bitwidth needed to
      // represent it. We assume 48-bit numbers may be supported further in
      // the pipeline.
      int intermediateBitWidth = 64;
      if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
        intermediateBitWidth = 16;
      } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
        intermediateBitWidth = 32;
      } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
        intermediateBitWidth = 48;
      }

      Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
      Value zpAddValue = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

      // The negation can be applied by doing:
      //   outputValue = inZp + outZp - inputValue
      auto ext =
          rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
      auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);

      // Clamp to the negation range.
      Value min = rewriter.create<arith::ConstantIntOp>(
          loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
          intermediateType);
      Value max = rewriter.create<arith::ConstantIntOp>(
          loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
          intermediateType);
      auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);

      // Truncate to the final value.
      return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
    }
  }

  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXOrOp
  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);

  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
    auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
    auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
    if (!round) {
      return result;
    }

    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

    // Checking that the shift value is strictly positive.
    auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[1], zero);

    // Checking the last bit of the input that is shifted out.
    auto subtract =
        rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted =
        rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
            ->getResults();
    auto truncated =
        rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
    auto isInputOdd =
        rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);

    auto shouldRound = rewriter.create<arith::AndIOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
  }
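  // When the "round" attribute is set, the block above rounds the shifted
  // result half upward (toward +infinity) by inspecting the last bit that
  // was shifted out:
  //
  //   shr(x, s) + ((s > 0 && bit(x, s - 1)) ? 1 : 0)
  //
  // e.g. for x = 7 (0b0111), s = 1: 7 >> 1 = 3, the shifted-out bit is 1, so
  // the rounded result is 4 (7 / 2 = 3.5, rounded up).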

  // tosa::ClzOp
  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
  }

  // tosa::LogicalAnd
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::LogicalNot
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementTy, 1));
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOr
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::LogicalXor
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::SinOp
  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::SinOp>(loc, resultTypes, args);

  // tosa::CosOp
  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::CosOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::ErfOp
  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);

  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
                                          args[0], args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                          args[0], args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
                                          args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
                                          args[0], args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
                                          args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                          args[0], args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
    if (isa<IntegerType>(elementTy) || isa<FloatType>(elementTy))
      return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
  }

  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
    auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
                                               rewriter, args[0], args[1], max);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
    auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
                                               rewriter, args[0], args[1], min);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::CeilOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::FloorOp>(loc, resultTypes, args);

  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
    bool losesInfo = false;
    APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
    minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    auto min = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
    auto max = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
    auto result = clampFloatHelper(loc, args[0], min, max, rewriter);

    auto clampOp = llvm::cast<tosa::ClampOp>(op);
    const auto nanMode = clampOp.getNanMode();

    // NaN propagation has no meaning for non floating-point types.
    if (!isa<FloatType>(elementTy))
      return result;

    // In the case of "PROPAGATE" semantics no compare and selection is
    // required, since the arith min/max ops already propagate NaN.
    if (nanMode == "PROPAGATE")
      return result;

    // "IGNORE" semantics: detect a NaN input with an unordered
    // self-comparison and select the clamp minimum in its place, so the
    // result is always an ordered value.
    Value isNaN = rewriter.create<arith::CmpFOp>(
        op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);

    // TOSA specifies that in "ignore" NaN mode the result is "min" if the
    // input is NaN.
    return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, min, result);
  }

  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
    auto intTy = cast<IntegerType>(elementTy);
    int64_t min =
        cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
    int64_t max =
        cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();

    int64_t minRepresentable = std::numeric_limits<int64_t>::min();
    int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
    if (intTy.isUnsignedInteger()) {
      minRepresentable = 0;
      if (intTy.getIntOrFloatBitWidth() <= 63) {
        maxRepresentable =
            (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
                .getZExtValue();
      }
    } else if (intTy.getIntOrFloatBitWidth() <= 64) {
      // Ensure that min & max fit into signed n-bit constants.
      minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
      maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
    }

    // Clip the clamp bounds to the range representable in the element type.
    min = std::max(min, minRepresentable);
    max = std::max(max, minRepresentable);
    min = std::min(min, maxRepresentable);
    max = std::min(max, maxRepresentable);

    auto minVal = rewriter.create<arith::ConstantIntOp>(
        loc, min, intTy.getIntOrFloatBitWidth());
    auto maxVal = rewriter.create<arith::ConstantIntOp>(
        loc, max, intTy.getIntOrFloatBitWidth());
    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
                          intTy.isUnsignedInteger());
  }

  // tosa::SigmoidOp
  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
    auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
  }

  // tosa::CastOp
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
      (void)rewriter.notifyMatchFailure(op, "unsupported type");
      return nullptr;
    }

    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (srcTy == dstTy)
      return args.front();

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
                                            std::nullopt);

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
      return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
                                              std::nullopt);

    // 1-bit integers need to be treated as unsigned.
    if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
                                              std::nullopt);

    if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
                                             std::nullopt);

    // Unsigned integers need an unrealized cast to signless so that they can
    // be passed to UIToFP.
    if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
      auto unrealizedCast =
          rewriter
              .create<UnrealizedConversionCastOp>(
                  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
                  args[0])
              .getResult(0);
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
                                              unrealizedCast);
    }

    // All other si-to-fp conversions should be handled by SIToFP.
    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
                                              std::nullopt);

    // Casting to boolean, floats are compared to zero.
    if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(srcTy, 0.0));
      return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
                                            args.front(), zero);
    }

    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);

      const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();

      // Check whether neither int min nor int max can be represented in the
      // input floating-point type due to a too small exponent range.
      if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
          APFloat::semanticsMaxExponent(fltSemantics)) {
        // Use cmp + select to replace infinites by int min / int max. Other
        // integral values can be represented in the integer space.
        auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
        auto posInf = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
                                       APFloat::getInf(fltSemantics)));
        auto negInf = rewriter.create<arith::ConstantOp>(
            loc,
            rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
                                  APFloat::getInf(fltSemantics, true)));
        auto overflow = rewriter.create<arith::CmpFOp>(
            loc, arith::CmpFPredicate::UEQ, rounded, posInf);
        auto underflow = rewriter.create<arith::CmpFOp>(
            loc, arith::CmpFPredicate::UEQ, rounded, negInf);
        auto intMin = rewriter.create<arith::ConstantOp>(
            loc,
            rewriter.getIntegerAttr(
                getElementTypeOrSelf(dstTy),
                APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
        auto intMax = rewriter.create<arith::ConstantOp>(
            loc,
            rewriter.getIntegerAttr(
                getElementTypeOrSelf(dstTy),
                APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
        auto maxClamped =
            rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
        return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
                                                maxClamped);
      }

      auto intMinFP = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   getElementTypeOrSelf(srcTy),
                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      // Check whether the mantissa has enough bits to represent int max
      // exactly.
      if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
          dstTy.getIntOrFloatBitWidth() - 1) {
        // Int min can also be represented since it is a power of two, so the
        // input can be clamped entirely in the floating-point domain before
        // converting.
        auto intMaxFP = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(
                     getElementTypeOrSelf(srcTy),
                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                         .getSExtValue()));

        Value clamped =
            clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
        return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
      }

      // Due to the earlier check we know the exponent range is big enough to
      // represent int min. We can therefore rely on int max + 1 being
      // representable as well because it's just int min with a positive sign.
      // So clamp the min value and compare against that to select the max int
      // value if needed.
      auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   getElementTypeOrSelf(srcTy),
                   static_cast<double>(
                       APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                           .getSExtValue()) +
                       1.0f));

      auto intMax = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIntegerAttr(
                   getElementTypeOrSelf(dstTy),
                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
      auto minClampedFP =
          rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
      auto minClamped =
          rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
      auto overflow = rewriter.create<arith::CmpFOp>(
          loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
      return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
                                              minClamped);
    }

    // Casting to boolean, integers need to only be checked as not-equal to
    // zero.
    if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantIntOp>(
          loc, 0, srcTy.getIntOrFloatBitWidth());
      return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                            args.front(), zero);
    }

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
                                             std::nullopt);

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
      return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}
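// The builder above only produces the scalar body of a `linalg.generic`; the
// surrounding loop nest comes from emitElementwiseComputation further below.
// As a rough sketch (exact IR may differ), a `tosa.add` on two
// `tensor<4xf32>` operands becomes:
//
//   %empty = tensor.empty() : tensor<4xf32>
//   %0 = linalg.generic
//          {indexing_maps = [#id, #id, #id], iterator_types = ["parallel"]}
//          ins(%lhs, %rhs : tensor<4xf32>, tensor<4xf32>)
//          outs(%empty : tensor<4xf32>) {
//   ^bb0(%a: f32, %b: f32, %out: f32):
//     %s = arith.addf %a, %b : f32
//     linalg.yield %s : f32
//   } -> tensor<4xf32>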

// Pool of pre-built index constants, keyed by their value, used to keep the
// emitted IR small.
using IndexPool = DenseMap<int64_t, Value>;

// Emit an 'arith.constant' op for the given index if it has not been created
// yet, or return an existing constant. This prevents excessive creation of
// redundant constants.
static Value createIndex(PatternRewriter &rewriter, Location loc,
                         IndexPool &indexPool, int64_t index) {
  auto [it, inserted] = indexPool.try_emplace(index);
  if (inserted)
    it->second =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
  return it->second;
}

static Value getTensorDim(PatternRewriter &rewriter, Location loc,
                          IndexPool &indexPool, Value tensor, int64_t index) {
  auto indexValue = createIndex(rewriter, loc, indexPool, index);
  return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
}

static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value tensor,
                                       int64_t index) {
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
  if (shapedType.isDynamicDim(index))
    return getTensorDim(rewriter, loc, indexPool, tensor, index);
  return rewriter.getIndexAttr(shapedType.getDimSize(index));
}

static bool operandsAndResultsRanked(Operation *operation) {
  auto isRanked = [](Value value) {
    return isa<RankedTensorType>(value.getType());
  };
  return llvm::all_of(operation->getOperands(), isRanked) &&
         llvm::all_of(operation->getResults(), isRanked);
}

// Compute the broadcast target size of dimension 'dim' of the elementwise
// result for the given operands, following TOSA broadcast semantics (a
// static size of 1 broadcasts against any other size). Returns the dimension
// size together with the "master" operand that determined it, or a null
// Value when the size is only known at runtime.
static std::pair<OpFoldResult, Value>
computeTargetSize(PatternRewriter &rewriter, Location loc,
                  IndexPool &indexPool, ValueRange operands, int64_t dim) {
  // If any operand has a static size greater than 1 in this dimension, that
  // is the target size of the dimension.
  for (auto operand : operands) {
    auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
    if (!ShapedType::isDynamic(size) && size > 1)
      return {rewriter.getIndexAttr(size), operand};
  }

  // Filter operands with a dynamic dimension.
  auto operandsWithDynamicDim =
      llvm::filter_to_vector(operands, [&](Value operand) {
        return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
      });

  // If all operands have a static size of 1, the target size is 1.
  if (operandsWithDynamicDim.empty())
    return {rewriter.getIndexAttr(1), operands.front()};

  // If there is exactly one operand with a dynamic dimension, its runtime
  // size is the target size.
  auto targetSize =
      getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
  if (operandsWithDynamicDim.size() == 1)
    return {targetSize, operandsWithDynamicDim[0]};

  // Emit code that computes the runtime maximum of all dynamic sizes.
  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
    auto nextSize =
        getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
    targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
  }
  return {targetSize, nullptr};
}

// Compute the target shape of the whole elementwise result, along with the
// master operand for each dimension.
static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
computeTargetShape(PatternRewriter &rewriter, Location loc,
                   IndexPool &indexPool, ValueRange operands) {
  assert(!operands.empty());
  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
  SmallVector<OpFoldResult> targetShape;
  SmallVector<Value> masterOperands;
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    auto [targetSize, masterOperand] =
        computeTargetSize(rewriter, loc, indexPool, operands, dim);
    targetShape.push_back(targetSize);
    masterOperands.push_back(masterOperand);
  }
  return {targetShape, masterOperands};
}
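// For example, operands of types tensor<1x?xf32> and tensor<3x1xf32> yield a
// target shape of [3, %dim], where %dim is the runtime size of dimension 1
// of the first operand: dimension 0 is fixed by the static size 3, and
// dimension 1 is bounded only by the dynamic operand (shapes chosen here
// purely for illustration).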

// Broadcast the given operand along dimension 'dim' up to the target size,
// if necessary.
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value operand,
                                       int64_t dim, OpFoldResult targetSize,
                                       Value masterOperand) {
  // Nothing to do if this is a static dimension.
  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
  if (!rankedTensorType.isDynamicDim(dim))
    return operand;

  // If the target size of this dimension was computed directly from this
  // operand, it is guaranteed to match at runtime and no broadcast code is
  // needed.
  if (operand == masterOperand)
    return operand;

  // Affine maps for 'linalg.generic' op
  auto rank = rankedTensorType.getRank();
  SmallVector<AffineExpr> affineExprs;
  for (auto index : llvm::seq<int64_t>(0, rank)) {
    auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
                                   : rewriter.getAffineDimExpr(index);
    affineExprs.push_back(affineExpr);
  }
  auto broadcastAffineMap =
      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};

  // Check if broadcast is necessary
  auto one = createIndex(rewriter, loc, indexPool, 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, runtimeSize, one);

  // Emit 'then' region of 'scf.if'
  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
    // It is not safe to cache constants across regions, so use a fresh pool.
    IndexPool localPool;

    // Emit 'tensor.empty' op
    SmallVector<OpFoldResult> outputTensorShape;
    for (auto index : llvm::seq<int64_t>(0, rank)) {
      auto size = index == dim ? targetSize
                               : getOrFoldTensorDim(rewriter, loc, localPool,
                                                    operand, index);
      outputTensorShape.push_back(size);
    }
    Value outputTensor = opBuilder.create<tensor::EmptyOp>(
        loc, outputTensorShape, rankedTensorType.getElementType());

    // Emit 'linalg.generic' op
    auto resultTensor =
        opBuilder
            .create<linalg::GenericOp>(
                loc, outputTensor.getType(), operand, outputTensor, affineMaps,
                getNParallelLoopsAttrs(rank),
                [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
                  // Emit 'linalg.yield' op
                  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
                })
            .getResult(0);

    // Cast to original operand type if necessary
    auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
        loc, operand.getType(), resultTensor);

    // Emit 'scf.yield' op
    opBuilder.create<scf::YieldOp>(loc, castResultTensor);
  };

  // Emit 'else' region of 'scf.if'
  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
    opBuilder.create<scf::YieldOp>(loc, operand);
  };

  // Emit 'scf.if' op
  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
                                         emitThenRegion, emitElseRegion);
  return ifOp.getResult(0);
}

static Value broadcastDynamicDimensions(PatternRewriter &rewriter,
                                        Location loc, IndexPool &indexPool,
                                        Value operand,
                                        ArrayRef<OpFoldResult> targetShape,
                                        ArrayRef<Value> masterOperands) {
  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
  assert((int64_t)targetShape.size() == rank);
  assert((int64_t)masterOperands.size() == rank);
  for (auto index : llvm::seq<int64_t>(0, rank))
    operand =
        broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
                                  targetShape[index], masterOperands[index]);
  return operand;
}

static SmallVector<Value>
broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                           IndexPool &indexPool, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           ArrayRef<Value> masterOperands) {
  // No broadcasting needed for unary operations.
  if (operands.size() == 1)
    return operands;

  // Broadcast dynamic dimensions operand by operand
  return llvm::map_to_vector(operands, [&](Value operand) {
    return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
                                      targetShape, masterOperands);
  });
}

static LogicalResult
emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
                           Operation *operation, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           const TypeConverter &converter) {
  // Generate output tensor
  auto resultType = cast_or_null<RankedTensorType>(
      converter.convertType(operation->getResultTypes().front()));
  if (!resultType) {
    return rewriter.notifyMatchFailure(operation, "failed to convert type");
  }
  Value outputTensor = rewriter.create<tensor::EmptyOp>(
      loc, targetShape, resultType.getElementType());

  // Create affine maps. Input affine maps broadcast static dimensions of
  // size 1; the output affine map is always the identity map.
  auto rank = resultType.getRank();
  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
    auto shape = cast<ShapedType>(operand.getType()).getShape();
    SmallVector<AffineExpr> affineExprs;
    for (auto it : llvm::enumerate(shape)) {
      // Prefer producing identity maps whenever possible (i.e. no
      // broadcasting needed) because some transforms (like reshape folding)
      // do not support affine constant exprs.
      bool requiresBroadcast =
          (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
      auto affineExpr = requiresBroadcast
                            ? rewriter.getAffineConstantExpr(0)
                            : rewriter.getAffineDimExpr(it.index());
      affineExprs.push_back(affineExpr);
    }
    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  });
  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

  // Emit 'linalg.generic' op
  bool encounteredError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, outputTensor.getType(), operands, outputTensor, affineMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            {resultType.getElementType()}, rewriter);
        if (!opResult) {
          encounteredError = true;
          return;
        }
        opBuilder.create<linalg::YieldOp>(loc, opResult);
      });
  if (encounteredError)
    return rewriter.notifyMatchFailure(
        operation, "unable to create linalg.generic body for elementwise op");

  // Cast 'linalg.generic' result into original result type if needed
  auto castResult = rewriter.createOrFold<tensor::CastOp>(
      loc, resultType, linalgOp->getResult(0));
  rewriter.replaceOp(operation, castResult);
  return success();
}
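// Putting the helpers together: for an operand whose dimension is dynamic
// but may turn out to have runtime size 1, broadcastDynamicDimension wraps
// the broadcast copy in a conditional, roughly (sketch, names invented):
//
//   %size   = tensor.dim %operand, %c_dim
//   %is_one = arith.cmpi eq, %size, %c1
//   %bcast  = scf.if %is_one -> tensor<...> {
//     ...  // linalg.generic that re-reads index 0 along %c_dim
//     scf.yield %expanded
//   } else {
//     scf.yield %operand
//   }
//
// so only operands that are genuinely size 1 pay for the broadcast copy at
// runtime.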

// Returns the operands of the elementwise operation that participate in
// broadcasting; scalar shift/zero-point operands are excluded.
static ValueRange getBroadcastableOperands(Operation *operation,
                                           ValueRange operands) {
  // Shift cannot broadcast
  if (isa<tosa::MulOp>(operation))
    return operands.take_front(2);
  // input1_zp and output_zp cannot broadcast
  if (isa<tosa::NegateOp>(operation))
    return operands.take_front(1);
  return operands;
}

static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                 ConversionPatternRewriter &rewriter,
                                 const TypeConverter &converter) {
  // Collect op properties
  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
  assert(operation->getNumOperands() >= 1 &&
         "elementwise op expects at least 1 operand");
  assert(operandsAndResultsRanked(operation) &&
         "Unranked tensors not supported");

  // Lower operation
  IndexPool indexPool;
  auto loc = operation->getLoc();
  auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
  auto [targetShape, masterOperands] =
      computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
  auto broadcastOperands =
      broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
                                 targetShape, masterOperands);
  return emitElementwiseComputation(rewriter, loc, operation,
                                    broadcastOperands, targetShape, converter);
}

// Returns the constant initial value for a given reduction operation. The
// attribute type varies depending on the element type required.
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}

// Creates the body calculation for a reduction. The operations type of the
// result must match the initial value read from the output tensor.
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::AddFOp>(loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::AddIOp>(loc, args);
  }

  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MulFOp>(loc, args);
  }

  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MulIOp>(loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, args);

  return {};
}

// Performs the match and rewrite for reduction operations. This includes
// declaring a correctly sized initial value, and the linalg.reduce operation
// that reduces across the specified axis.
template <typename OpTy>
static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
  auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
  if (!inputTy || !resultTy)
    return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
    }
  }

  SmallVector<Value> inputs, outputs;
  inputs.push_back(input);

  // First fill the output buffer with the init value.
  auto emptyTensor =
      rewriter
          .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
                                   dynDims)
          .getResult();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
  auto filledTensor = rewriter
                          .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                  ValueRange{emptyTensor})
                          .result();
  outputs.push_back(filledTensor);

  bool isNanIgnoreMode = false;
  if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
                std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
    // NaN propagation has no meaning for non floating-point types.
    if (isa<FloatType>(elementTy) && op.getNanMode() == "IGNORE") {
      isNanIgnoreMode = true;
      // Because the TOSA spec requires the result to be NaN iff all elements
      // in the reduction are NaN, we maintain a per-element boolean flag that
      // starts out true and is AND-ed with the NaN check of each input value.
      auto trueAttr = rewriter.getBoolAttr(true);
      auto trueValue = rewriter.create<arith::ConstantOp>(loc, trueAttr);
      auto emptyBoolTensor =
          rewriter
              .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
                                       dynDims)
              .getResult();
      auto allResultsNaNTensor =
          rewriter
              .create<linalg::FillOp>(loc, ValueRange{trueValue},
                                      ValueRange{emptyBoolTensor})
              .result();
      // 'linalg.reduce' expects the same number of inputs and outputs in its
      // region, so the input is pushed a second time to line the block
      // arguments up with the two outputs (the running reduction and the NaN
      // flag); the duplicate input itself is never read.
      inputs.push_back(input);
      outputs.push_back(allResultsNaNTensor);
    }
  }

  bool didEncounterError = false;
  linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
      loc, inputs, outputs, axis,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        std::array<Value, 2> binaryArgs{
            blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
        auto result = createLinalgBodyCalculationForReduceOp(
            op, binaryArgs, elementTy, rewriter);
        if (result)
          didEncounterError = true;

        SmallVector<Value> resultsToYield;
        if (isNanIgnoreMode) {
          auto inputValue = blockArgs[0];
          auto initialValue = blockArgs[2];
          auto oldAllResultsNanFlagValue = blockArgs[3];

          // Unordered comparison of NaN against itself will always return
          // true.
          Value isNaN = nestedBuilder.create<arith::CmpFOp>(
              op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue);
          // If we've encountered a NaN, keep the running value instead.
          auto selectOp = nestedBuilder.create<arith::SelectOp>(
              op->getLoc(), isNaN, initialValue, result);
          // Update the flag tracking whether every value so far was NaN.
          auto newAllResultsNanFlagValue = nestedBuilder.create<arith::AndIOp>(
              op->getLoc(), oldAllResultsNanFlagValue, isNaN);
          resultsToYield.push_back(selectOp);
          resultsToYield.push_back(newAllResultsNanFlagValue);
        } else {
          resultsToYield.push_back(result);
        }
        nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
      });

  if (!didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");

  if (isNanIgnoreMode) {
    // Materialize a NaN-filled tensor and, for each result element, select
    // between the reduced value and NaN depending on whether every input
    // along the reduction was NaN.
    auto nanValueAttr = rewriter.getFloatAttr(
        elementTy,
        APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
    auto nanValue = rewriter.create<arith::ConstantOp>(loc, nanValueAttr);
    auto emptyNanTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, reduceShape,
                                     resultTy.getElementType(), dynDims)
            .getResult();
    auto nanFilledTensor =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{nanValue},
                                    ValueRange{emptyNanTensor})
            .result();

    // Create an empty output tensor; no fill is needed since it is entirely
    // overwritten by the select.
    auto finalEmptyTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, reduceShape,
                                     resultTy.getElementType(), dynDims)
            .getResult();

    SmallVector<Value> ins, outs;
    ins.push_back(linalgOp->getOpResult(1));
    ins.push_back(nanFilledTensor);
    ins.push_back(linalgOp->getResult(0));
    outs.push_back(finalEmptyTensor);
    auto linalgSelect =
        rewriter.create<linalg::SelectOp>(op->getLoc(), ins, outs);
    linalgOp = linalgSelect;
  }

  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank =
      cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
  reassociationMap.resize(expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
  }

  if (expandInputRank != 0) {
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
        rewriter.getAffineDimExpr(expandedDim + 1));
  }

  // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
  // since here we know which dimension to expand, and `tosa::ReshapeOp` would
  // not have access to such information.
  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
      op, resultTy, linalgOp->getResults()[0], reassociationMap);
  return success();
}
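// As a rough sketch (exact IR may differ), `%0 = tosa.reduce_sum %arg0
// {axis = 1 : i32} : (tensor<4x8xf32>) -> tensor<4x1xf32>` becomes:
//
//   %empty = tensor.empty() : tensor<4xf32>
//   %zero  = arith.constant 0.0 : f32
//   %fill  = linalg.fill ins(%zero) outs(%empty)
//   %red   = linalg.reduce ins(%arg0) outs(%fill) dimensions = [1]
//              (%in: f32, %acc: f32) { ... arith.addf ... }
//   %res   = tensor.expand_shape %red [[0, 1]]
//              : tensor<4xf32> into tensor<4x1xf32>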

namespace {

template <typename SrcOp>
class PointwiseConverter : public OpConversionPattern<SrcOp> {
public:
  using OpConversionPattern<SrcOp>::OpConversionPattern;
  using typename OpConversionPattern<SrcOp>::OpAdaptor;

  LogicalResult
  matchAndRewrite(SrcOp op, OpAdaptor operands,
                  ConversionPatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(
        op, operands.getOperands(), rewriter, *this->getTypeConverter());
  }
};

class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(op.getInput().getType());
    auto outputTy = cast<ShapedType>(op.getOutput().getType());
    unsigned rank = inputTy.getRank();

    // These are illegal configurations; terminate with a match failure.
    if (op.getRoundingMode() == "INEXACT_ROUND")
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
              "currently supported");
    if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    if (!isa<IntegerType>(inputTy.getElementType()))
      return rewriter.notifyMatchFailure(op, "only support integer type");

    SmallVector<Value> dynDims;
    for (int i = 0; i < outputTy.getRank(); i++) {
      if (outputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // The shift and multiplier values.
    DenseElementsAttr shiftElems;
    if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires constant shift input values");

    DenseElementsAttr multiplierElems;
    if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires constant multiplier input values");

    llvm::SmallVector<int8_t> shiftValues =
        llvm::to_vector(shiftElems.getValues<int8_t>());
    // An explicit cast is required for the multiplier values.
    llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
        llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
                        [](IntegerAttr attr) -> int32_t {
                          return static_cast<int32_t>(attr.getInt());
                        }));

    // If we shift by more than the bitwidth, this just sets to 0.
    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
      if (shiftValues[i] > 63) {
        shiftValues[i] = 0;
        multiplierValues[i] = 0;
      }
    }

    // Double rounding only occurs if the shift is greater than 31; check
    // whether that is ever the case.
    bool doubleRound =
        op.getRoundingMode() == "DOUBLE_ROUND" &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
    StringAttr roundingMode = doubleRound
                                  ? rewriter.getStringAttr("DOUBLE_ROUND")
                                  : rewriter.getStringAttr("SINGLE_ROUND");

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType =
          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
                                rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));

      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));

      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel then we need to store the shift
    // values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      SmallVector<AffineExpr, 2> shiftExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing map for the output values.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the linalg.generic op that performs the rescale.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(),
        ArrayRef<Value>({dynDims}));

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
          if (failed(maybeIZp)) {
            (void)rewriter.notifyMatchFailure(
                op, "input zero point cannot be statically determined");
            return;
          }

          const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
          // Extend the zero point attribute for sub-32-bit widths.
          const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
          auto inputZp = nestedBuilder.create<arith::ConstantOp>(
              loc, IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
                                    *maybeIZp));

          FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
          if (failed(maybeOZp)) {
            (void)rewriter.notifyMatchFailure(
                op, "output zero point cannot be statically determined");
            return;
          };

          IntegerType outIntType =
              cast<IntegerType>(blockArgs.back().getType());
          unsigned outBitWidth = outIntType.getWidth();
          const int32_t outAttrBitwidth = 32;
          assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
          auto outputZp = nestedBuilder.create<arith::ConstantOp>(
              loc, IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
                                    *maybeOZp));

          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          // Unsigned inputs pass through an unrealized cast to a signless
          // integer of the same width.
          if (valueTy.isUnsignedInteger()) {
            value = nestedBuilder
                        .create<UnrealizedConversionCastOp>(
                            nestedLoc,
                            nestedBuilder.getIntegerType(
                                valueTy.getIntOrFloatBitWidth()),
                            value)
                        .getResult(0);
          }
          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (op.getInputUnsigned()) {
              value = nestedBuilder.create<arith::ExtUIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            } else {
              value = nestedBuilder.create<arith::ExtSIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            }
          }

          value =
              nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              roundingMode);

          // Move to the new zero-point.
          value =
              nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different saturation range.
          if (op.getOutputUnsigned()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
                                 nestedBuilder, false);

          if (outIntType.getWidth() < 32) {
            value = nestedBuilder.create<arith::TruncIOp>(
                nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
                value);
          }

          if (outIntType.isUnsignedInteger()) {
            value = nestedBuilder
                        .create<UnrealizedConversionCastOp>(nestedLoc,
                                                            outIntType, value)
                        .getResult(0);
          }
          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
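// tosa.rescale requantizes an integer tensor between two quantized types.
// The scalar computation in the generic body above is, conceptually:
//
//   v   = extend(in) - input_zp
//   v   = apply_scale(v, multiplier, shift)   // roughly (v * multiplier) >> shift,
//                                             // with the configured rounding
//   out = clamp(v + output_zp, out_min, out_max)
//
// Per-channel multipliers/shifts are fed in as extra tensor operands indexed
// along the trailing dimension; a single (multiplier, shift) pair is instead
// emitted as a scalar constant.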

// Handle the resize case where the input is a 1x1 image: the resize then
// reduces to (at most) a cast and, for quantized bilinear mode, a
// multiplication by the scale numerators, which this pattern materializes
// directly.
class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<RankedTensorType>(input.getType());
    auto resultTy = cast<RankedTensorType>(op.getType());
    const bool isBilinear = op.getMode() == "BILINEAR";

    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize is not a pure 1x1->1x1 image operation");

    // Only the two modes defined by the spec are supported.
    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    if (inputTy == resultTy) {
      rewriter.replaceOp(op, input);
      return success();
    }

    SmallVector<int64_t> scale;
    if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale)) {
      return failure();
    }

    // Collapse the unit width and height away.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap[1].push_back(builder.getAffineDimExpr(1));
    reassociationMap[1].push_back(builder.getAffineDimExpr(2));
    reassociationMap[1].push_back(builder.getAffineDimExpr(3));

    auto collapseTy =
        RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
                              inputTy.getElementType());
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
                                                             reassociationMap);

    // Get any dynamic shapes that appear in the input format.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    // Generate the elementwise operation for casting and scaling the input
    // value.
    auto genericTy = collapseTy.clone(resultTy.getElementType());
    Value empty = builder.create<tensor::EmptyOp>(
        genericTy.getShape(), resultTy.getElementType(), outputDynSize);
    auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
    SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
                                               utils::IteratorType::parallel);

    auto generic = builder.create<linalg::GenericOp>(
        genericTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          // This is the quantized case.
          if (inputTy.getElementType() != resultTy.getElementType()) {
            value =
                b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);

            if (isBilinear && scale[0] != 0) {
              Value scaleY = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[0]));
              value = b.create<arith::MulIOp>(loc, value, scaleY);
            }

            if (isBilinear && scale[2] != 0) {
              Value scaleX = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[2]));
              value = b.create<arith::MulIOp>(loc, value, scaleX);
            }
          }

          b.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, resultTy, generic.getResults()[0], reassociationMap);
    return success();
  }
};
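// With a 1x1 input every output pixel samples the same input value, so
// nearest-neighbor is a pure cast. For quantized bilinear resize the
// interpolation weights along each axis sum to the scale numerator, so the
// value is multiplied by scale_y_n * scale_x_n rather than interpolated --
// this matches what the general bilinear path below computes when an image
// dimension has size 1.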

// tosa.resize on a height- or width-1 input can first resize only the
// non-unit dimensions and then broadcast along the unit ones; this pattern
// splits the op accordingly.
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any dimension that is broadcastable we generate a width of 1
    // on the output.
    llvm::SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    auto resizeTy = resultTy.clone(resizeShape);
    auto resize = builder.create<tosa::ResizeOp>(resizeTy, input, op.getScale(),
                                                 op.getOffset(), op.getBorder(),
                                                 op.getMode());

    // Collapse the unit result dims.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    llvm::SmallVector<int64_t> collapseShape = {batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
                                                             reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = builder.create<tensor::EmptyOp>(
        resultTy.getShape(), resultTy.getElementType(), outputDynSize);

    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());

    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          b.create<linalg::YieldOp>(loc, value);
        });

    return success();
  }
};
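// For example (shapes chosen for illustration), with an NHWC input of
// tensor<1x1x8x16xf32> resized to tensor<1x4x16x16xf32>, the rewrite first
// emits a tosa.resize to tensor<1x1x16x16xf32> (only the width actually
// changes), collapses the unit height away to tensor<1x16x16xf32>, and then
// broadcasts the rows back out to height 4 with a linalg.generic copy.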

class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto resultETy = resultTy.getElementType();

    bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
    auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return rewriter.notifyMatchFailure(
          op, "unable to get dynamic dimensions of tosa.resize");

    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
    auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
                                                 *dynamicDimsOr);
    auto genericOp = b.create<linalg::GenericOp>(
        resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    Value resize = genericOp.getResult(0);

    {
      OpBuilder::InsertionGuard regionGuard(b);
      b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
                    TypeRange({resultETy}), loc);
      Value batch = b.create<linalg::IndexOp>(0);
      Value y = b.create<linalg::IndexOp>(1);
      Value x = b.create<linalg::IndexOp>(2);
      Value channel = b.create<linalg::IndexOp>(3);

      Value zeroI32 =
          b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
      Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
      Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
      Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));

      Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
      Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);

      SmallVector<int64_t> scale, offset, border;
      if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
          !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
          !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
        return rewriter.notifyMatchFailure(
            op, "tosa.resize scale/offset/border should have compile time "
                "constant values.");
      }

      Value yScaleN, yScaleD, xScaleN, xScaleD;
      yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
      yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
      xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
      xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));

      Value yOffset, xOffset, yBorder, xBorder;
      yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
      xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
      yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
      xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));

      // Compute the source index and interpolation delta in floating point.
      auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
                                    Value scaleN, Value scaleD, Value offset,
                                    int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroFp;
          return;
        }
        // x = x * scale_d + offset;
        // ix = floor(x / scale_n)
        Value val = b.create<arith::MulIOp>(in, scaleD);
        val = b.create<arith::AddIOp>(val, offset);
        index = b.create<arith::FloorDivSIOp>(val, scaleN);

        // rx = x % scale_n
        // dx = rx / scale_n
        Value r = b.create<arith::RemSIOp>(val, scaleN);
        Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
        Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
        delta = b.create<arith::DivFOp>(rFp, scaleNfp);
      };

      // Compute the source index and interpolation delta in fixed point.
      auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
                                     Value scaleN, Value scaleD, Value offset,
                                     int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroI32;
          return;
        }
        // x = x * scale_d + offset;
        // ix = floor(x / scale_n)
        // dx = x - ix * scale_n;
        Value val = b.create<arith::MulIOp>(in, scaleD);
        val = b.create<arith::AddIOp>(val, offset);
        index = b.create<arith::DivSIOp>(val, scaleN);
        delta = b.create<arith::MulIOp>(index, scaleN);
        delta = b.create<arith::SubIOp>(val, delta);
      };

      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      } else {
        getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      }

      if (op.getMode() == "NEAREST_NEIGHBOR") {
        auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
                                           Value max, int size,
                                           ImplicitLocOpBuilder &b) -> Value {
          if (size == 1) {
            return b.create<arith::ConstantIndexOp>(0);
          }

          Value pred;
          if (floatingPointMode) {
            auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
            pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
          } else {
            Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
            pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
                                           dvalDouble, scale);
          }

          auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
          val = b.create<arith::AddIOp>(val, offset);
          val = clampIntHelper(loc, val, zeroI32, max, b, false);
          return b.create<arith::IndexCastOp>(b.getIndexType(), val);
        };

        iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
        ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);

        Value result = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, iy, ix, channel});

        b.create<linalg::YieldOp>(result);
      } else {
        // The mode here must be bilinear.
        assert(op.getMode() == "BILINEAR");

        auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
                                  Value max, ImplicitLocOpBuilder &b) {
          val0 = in;
          val1 = b.create<arith::AddIOp>(val0, oneVal);
          val0 = clampIntHelper(loc, val0, zeroI32, max, b, false);
          val1 = clampIntHelper(loc, val1, zeroI32, max, b, false);
          val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
          val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
        };

        // Compute the base index and the next index for both dimensions,
        // clamped to the image bounds:
        //   y0 = clamp(iy, 0, IH - 1), y1 = clamp(iy + 1, 0, IH - 1)
        //   x0 = clamp(ix, 0, IW - 1), x1 = clamp(ix + 1, 0, IW - 1)
        Value x0, x1, y0, y1;
        getClampedIdxs(y0, y1, imageH, iy, hMax, b);
        getClampedIdxs(x0, x1, imageW, ix, wMax, b);

        Value y0x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x1, channel});

        if (floatingPointMode) {
          auto oneVal =
              b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
          auto interpolate = [&](Value val0, Value val1, Value delta,
                                 int inputSize,
                                 ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return val0;
            Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
            Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
            Value mul1 = b.create<arith::MulFOp>(val1, delta);
            return b.create<arith::AddFOp>(mul0, mul1);
          };

          // Interpolate the top row: topAcc = v00 * (1 - dx) + v01 * dx.
          Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);

          // Interpolate the bottom row:
          //   bottomAcc = v10 * (1 - dx) + v11 * dx.
          Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);

          // Interpolate the two rows along y.
          Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
          b.create<linalg::YieldOp>(result);
        } else {
          // Perform the interpolation in quantized space.
          y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
          y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
          y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
          y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);

          const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
            dx = b.create<arith::ExtSIOp>(resultETy, dx);
            dy = b.create<arith::ExtSIOp>(resultETy, dy);
          }

          Value yScaleNExt = yScaleN;
          Value xScaleNExt = xScaleN;

          const int64_t scaleBitwidth =
              xScaleN.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
            yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
            xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
          }

          auto interpolate = [](Value val0, Value val1, Value weight1,
                                Value scale, int inputSize,
                                ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return b.create<arith::MulIOp>(val0, scale);
            Value weight0 = b.create<arith::SubIOp>(scale, weight1);
            Value mul0 = b.create<arith::MulIOp>(val0, weight0);
            Value mul1 = b.create<arith::MulIOp>(val1, weight1);
            return b.create<arith::AddIOp>(mul0, mul1);
          };

          Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
          Value result =
              interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
          b.create<linalg::YieldOp>(result);
        }
      }
    }

    rewriter.replaceOp(op, resize);
    return success();
  }
};
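// The bilinear path above computes, in floating-point mode,
//
//   out = (v00*(1-dx) + v01*dx)*(1-dy) + (v10*(1-dx) + v11*dx)*dy
//
// and in quantized mode the same expression with the weights dx/dy running
// over [0, scale_n) instead of [0, 1), so the result carries an extra factor
// of scale_y_n * scale_x_n that is typically divided back out by a following
// tosa.rescale.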

// Identity-like ops are lowered by simply forwarding their operands.
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, op.getOperation()->getOperands());
    return success();
  }
};

template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
  }
};

class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto axis = op.getAxis();

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);

    // Create the empty output tensor.
    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, inputTy.getShape(),
                                                    inputTy.getElementType(),
                                                    ArrayRef<Value>({dynDims}))
                           .getResult();
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            Value index =
                rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
            if (i == axis) {
              auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
              auto sizeMinusOne =
                  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
              index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
                                                     index);
            }

            indices.push_back(index);
          }

          auto extract = nestedBuilder.create<tensor::ExtractOp>(
              nestedLoc, input, indices);
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
                                                extract.getResult());
        });
    return success();
  }
};
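// Each output element simply reads the mirrored input index along the
// reversed axis:
//   out[i0, .., ia, .., in] = in[i0, .., (size_a - 1) - ia, .., in]
// For a tensor<4xf32> reversed on axis 0, output index 0 reads element 3,
// index 1 reads element 2, and so on.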

// This converter translates a tile operation to a reshape, broadcast, and
// reshape. The first reshape minimally expands each tiled dimension to
// include a preceding size-1 dim. This dim is then broadcast to the
// appropriate multiple.
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto inputShape = inputTy.getShape();
    auto resultTy = cast<ShapedType>(op.getType());
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    SmallVector<int64_t> multiples;
    if (failed(op.getConstantMultiples(multiples)))
      return failure();

    // Broadcast the newly added dimensions to their appropriate multiple.
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      int64_t dim = multiples[i];
      genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
      genericShape.push_back(inputShape[i]);
    }

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        op.getLoc(), genericShape, elementTy, dynDims);

    // We need to map the input shape to the non-broadcasted dimensions.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });

    auto shapeValue = getTosaConstShape(
        rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0), shapeValue);
    return success();
  }
};

2188

2189
2190 // The tosa.argmax lowering represents the ArgMax op as a linalg.generic
2191 // op, producing two output buffers.
2192 //
2193 // The first output buffer contains the index of the found maximum value.
2194 // It is initialized to 0 and has the resulting integer type.
2195 //
2196 // The second output buffer contains the maximum value found. It is
2197 // initialized to the minimum representable value of the input element
2198 // type. After the generic op runs, this buffer is discarded as only the
2199 // index is requested.
2200 //
2201 // The body selects the new running max and its index via arith.select.

2202 class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
2203 public:
2204 using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;
2205
2206 LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
2207 PatternRewriter &rewriter) const final {
2208 auto loc = argmaxOp.getLoc();
2209 Value input = argmaxOp.getInput();
2210 auto inputTy = cast<ShapedType>(input.getType());
2211 auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
2212 auto inElementTy = inputTy.getElementType();
2213 auto outElementTy = resultTy.getElementType();
2214 int axis = argmaxOp.getAxis();
2215 auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
2216
2217 if (!isa<IntegerType>(outElementTy))
2218 return rewriter.notifyMatchFailure(
2219 argmaxOp,
2220 "tosa.arg_max to linalg.* requires integer-like result type");
2221
2222 SmallVector<Value> dynDims;
2223 for (int i = 0; i < inputTy.getRank(); i++) {
2224 if (inputTy.isDynamicDim(i) && i != axis) {
2225 dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
2226 }
2227 }
2228
2229 // First fill the output buffer for the index.
2230 auto emptyTensorIdx = rewriter
2231 .create<tensor::EmptyOp>(loc, resultTy.getShape(),
2232 outElementTy, dynDims)
2233 .getResult();
2234 auto fillValueIdx = rewriter.create<arith::ConstantOp>(
2235 loc, rewriter.getIntegerAttr(outElementTy, 0));
2236 auto filledTensorIdx =
2237 rewriter
2238 .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
2239 ValueRange{emptyTensorIdx})
2240 .result();
2241
2242 // Second fill the output buffer for the running max.
2243 auto emptyTensorMax = rewriter
2244 .create<tensor::EmptyOp>(loc, resultTy.getShape(),
2245 inElementTy, dynDims)
2246 .getResult();
2247 auto fillValueMaxAttr =
2248 createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
2249
2250 if (!fillValueMaxAttr)
2251 return rewriter.notifyMatchFailure(
2252 argmaxOp, "unsupported tosa.argmax element type");
2253
2254 auto fillValueMax =
2255 rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
2256 auto filledTensorMax =
2257 rewriter
2258 .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
2259 ValueRange{emptyTensorMax})
2260 .result();
2261
2262 // We need to reduce along the arg-max axis, with parallel operations
2263 // along the rest.
2264 SmallVector<utils::IteratorType, 4> iteratorTypes;
2265 iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
2266 iteratorTypes[axis] = utils::IteratorType::reduction;
2267
2268 SmallVector<AffineExpr, 2> srcExprs;
2269 SmallVector<AffineExpr, 2> dstExprs;
2270 for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
2271 srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2272 if (axis != i)
2273 dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2274 }
2275
2276 bool didEncounterError = false;
2277 auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
2278 rewriter.getContext());
2279 auto linalgOp = rewriter.create<linalg::GenericOp>(
2280 loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
2281 ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
2282 [&](OpBuilder &nestedBuilder, Location nestedLoc,
2283 ValueRange blockArgs) {
2284 auto newValue = blockArgs[0];
2285 auto oldIndex = blockArgs[1];
2286 auto oldValue = blockArgs[2];
2287
2288 Value newIndex = rewriter.create<arith::IndexCastOp>(
2289 nestedLoc, oldIndex.getType(),
2290 rewriter.create<linalg::IndexOp>(loc, axis));
2291
2292 Value predicate;
2293 if (isa<FloatType>(inElementTy)) {
2294 if (argmaxOp.getNanMode() == "IGNORE") {
2295 // Only update the index and max value for non-NaN values. If
2296 // all values are NaN, the initial index of 0 is returned.
2297 predicate = rewriter.create<arith::CmpFOp>(
2298 nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
2299 } else {
2300 // Update the max value if either of the following is true:
2301 // - the new value is bigger
2302 // - the current max is not NaN and the new value is NaN
2303 Value gt = rewriter.create<arith::CmpFOp>(
2304 nestedLoc, arith::CmpFPredicate::UGT, newValue, oldValue);
2305 Value oldNonNaN = rewriter.create<arith::CmpFOp>(
2306 nestedLoc, arith::CmpFPredicate::ORD, oldValue, oldValue);
2307 predicate = rewriter.create<arith::AndIOp>(
2308 nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
2309 }
2310 } else if (isa<IntegerType>(inElementTy)) {
2311 predicate = rewriter.create<arith::CmpIOp>(
2312 nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
2313 } else {
2314 didEncounterError = true;
2315 return;
2316 }
2317
2318 auto resultMax = rewriter.create<arith::SelectOp>(
2319 nestedLoc, predicate, newValue, oldValue);
2320 auto resultIndex = rewriter.create<arith::SelectOp>(
2321 nestedLoc, predicate, newIndex, oldIndex);
2322 nestedBuilder.create<linalg::YieldOp>(
2323 nestedLoc, ValueRange({resultIndex, resultMax}));
2324 });
2325
2326 if (didEncounterError)
2327 return rewriter.notifyMatchFailure(
2328 argmaxOp, "unsupported tosa.argmax element type");
2329
2330 rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
2331 return success();
2332 }
2333 };

2334
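// tosa.gather selects slices along the middle dimension of a rank-3 tensor:
// result[b, i, c] = values[b, indices[b, i], c]. The lowering materializes
// this as a linalg.generic over the result whose body extracts from values.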

2335 class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
2336 public:
2337 using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
2338 LogicalResult
2339 matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2340 ConversionPatternRewriter &rewriter) const final {
2341 auto input = adaptor.getOperands()[0];
2342 auto indices = adaptor.getOperands()[1];
2343
2344 auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
2345 auto resultTy = dyn_cast<RankedTensorType>(op.getType());
2346 if (!valuesTy || !resultTy)
2347 return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
2348
2349 auto dynamicDims = inferDynamicDimsForGather(
2350 rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());
2351
2352 auto resultElementTy = resultTy.getElementType();
2353
2354 auto loc = op.getLoc();
2355 auto emptyTensor =
2356 rewriter
2357 .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
2358 dynamicDims)
2359 .getResult();
2360
2361 SmallVector<AffineMap, 2> affineMaps = {
2362 AffineMap::get(
2363 resultTy.getRank(), 0,
2364 {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
2365 rewriter.getContext()),
2366 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2367
2368 auto genericOp = rewriter.create<linalg::GenericOp>(
2369 loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
2370 ValueRange{emptyTensor}, affineMaps,
2371 getNParallelLoopsAttrs(resultTy.getRank()),
2372 [&](OpBuilder &b, Location loc, ValueRange args) {
2373 auto indexValue = args[0];
2374 auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
2375 Value index1 = rewriter.create<arith::IndexCastOp>(
2376 loc, rewriter.getIndexType(), indexValue);
2377 auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
2378 Value extract = rewriter.create<tensor::ExtractOp>(
2379 loc, input, ValueRange{index0, index1, index2});
2380 rewriter.create<linalg::YieldOp>(loc, extract);
2381 });
2382 rewriter.replaceOp(op, genericOp.getResult(0));
2383 return success();
2384 }
2385
2386 static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
2387 Location loc,
2388 Value values,
2389 Value indices) {
2390 llvm::SmallVector<Value> results;
2391
2392 auto addDynamicDimension = [&](Value source, int64_t dim) {
2393 auto sz = tensor::getMixedSize(builder, loc, source, dim);
2394 if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
2395 results.push_back(dimValue);
2396 };
2397
2398 addDynamicDimension(values, 0);
2399 addDynamicDimension(indices, 1);
2400 addDynamicDimension(values, 2);
2401 return results;
2402 }
2403 };

2404

2405 // Lowers the TableOp to a series of gathers and numeric operations. This
2406 // includes interpolation between the high/low values. For the i8 variant,
2407 // this simplifies to a single gather operation.
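// For the i16 variant the table has 513 entries. A 16-bit input x selects
// index = (x + 32768) >> 7 and fraction = (x + 32768) & 127, and the i32
// result is (table[index] << 7) + (table[index + 1] - table[index]) * fraction.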

2408 class TableConverter : public OpRewritePattern<tosa::TableOp> {
2409 public:
2410 using OpRewritePattern<tosa::TableOp>::OpRewritePattern;
2411
2412 LogicalResult matchAndRewrite(tosa::TableOp op,
2413 PatternRewriter &rewriter) const final {
2414 auto loc = op.getLoc();
2415 Value input = op.getInput1();
2416 Value table = op.getTable();
2417 auto inputTy = cast<ShapedType>(input.getType());
2418 auto tableTy = cast<ShapedType>(table.getType());
2419 auto resultTy = cast<ShapedType>(op.getType());
2420
2421 auto inputElementTy = inputTy.getElementType();
2422 auto tableElementTy = tableTy.getElementType();
2423 auto resultElementTy = resultTy.getElementType();
2424
2425 SmallVector<Value> dynDims;
2426 for (int i = 0; i < resultTy.getRank(); ++i) {
2427 if (inputTy.isDynamicDim(i)) {
2428 dynDims.push_back(
2429 rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
2430 }
2431 }
2432
2433 auto emptyTensor = rewriter
2434 .create<tensor::EmptyOp>(loc, resultTy.getShape(),
2435 resultElementTy, dynDims)
2436 .getResult();
2437
2438 SmallVector<AffineMap, 2> affineMaps = {
2439 rewriter.getMultiDimIdentityMap(resultTy.getRank()),
2440 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2441
2442 auto genericOp = rewriter.create<linalg::GenericOp>(
2443 loc, resultTy, ValueRange({input}), ValueRange{emptyTensor},
2444 affineMaps, getNParallelLoopsAttrs(resultTy.getRank()));
2445 rewriter.replaceOp(op, genericOp.getResult(0));
2446
2447 {
2448 OpBuilder::InsertionGuard regionGuard(rewriter);
2449 Block *block = rewriter.createBlock(
2450 &genericOp.getRegion(), genericOp.getRegion().end(),
2451 TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2452
2453 auto inputValue = block->getArgument(0);
2454 rewriter.setInsertionPointToStart(block);
2455 if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2456 resultElementTy.isInteger(8)) {
2457 Value index = rewriter.create<arith::IndexCastOp>(
2458 loc, rewriter.getIndexType(), inputValue);
2459 Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
2460 index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
2461 index, offset);
2462 Value extract =
2463 rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2464 rewriter.create<linalg::YieldOp>(loc, extract);
2465 return success();
2466 }
2467
2468 if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2469 resultElementTy.isInteger(32)) {
2470 Value extend = rewriter.create<arith::ExtSIOp>(
2471 loc, rewriter.getI32Type(), inputValue);
2472
2473 auto offset = rewriter.create<arith::ConstantOp>(
2474 loc, rewriter.getI32IntegerAttr(32768));
2475 auto seven = rewriter.create<arith::ConstantOp>(
2476 loc, rewriter.getI32IntegerAttr(7));
2477 auto one = rewriter.create<arith::ConstantOp>(
2478 loc, rewriter.getI32IntegerAttr(1));
2479 auto b1111111 = rewriter.create<arith::ConstantOp>(
2480 loc, rewriter.getI32IntegerAttr(127));
2481
2482 // Compute the index and fractional part from the input value:
2483 //
2484 // value = value + 32768
2485 // index = value >> 7; fraction = value & 0x7f
2486 auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
2487 Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
2488 Value fraction =
2489 rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);
2490
2491 // Extract the base and next values from the table:
2492 // base = (int32_t) table[index];
2493 // next = (int32_t) table[index + 1];
2494 Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);
2495
2496 index = rewriter.create<arith::IndexCastOp>(
2497 loc, rewriter.getIndexType(), index);
2498 indexPlusOne = rewriter.create<arith::IndexCastOp>(
2499 loc, rewriter.getIndexType(), indexPlusOne);
2500
2501 Value base =
2502 rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2503 Value next = rewriter.create<tensor::ExtractOp>(
2504 loc, table, ValueRange{indexPlusOne});
2505
2506 base =
2507 rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
2508 next =
2509 rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);
2510
2511 // Use the fractional part to interpolate between the input values:
2512 // result = (base << 7) + (next - base) * fraction
2513 Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
2514 Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
2515 Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
2516 Value result =
2517 rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);
2518
2519 rewriter.create<linalg::YieldOp>(loc, result);
2520
2521 return success();
2522 }
2523 }
2524
2525 return rewriter.notifyMatchFailure(
2526 op, "unable to create body for tosa.table op");
2527 }
2528 };

2529
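// tosa.rfft2d is lowered to a linalg.generic with three parallel and two
// reduction loops computing, for each output position (n, oy, ox):
//   real[n, oy, ox] = sum_{iy, ix} in[n, iy, ix] * cos(angle)
//   imag[n, oy, ox] = sum_{iy, ix} in[n, iy, ix] * -sin(angle)
// where angle = 2*pi*(iy*oy/H + ix*ox/W). Only W/2 + 1 output columns are
// materialized, since the remaining bins are redundant for a real input.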

2530 struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
2531 using OpRewritePattern<RFFT2dOp>::OpRewritePattern;
2532
2533 static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }
2534
2535 static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
2536 OpFoldResult ofr) {
2537 auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
2538 auto two = builder.create<arith::ConstantIndexOp>(loc, 2);
2539
2540 auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
2541 auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
2542 auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
2543 return getAsOpFoldResult(plusOne);
2544 }
2545
2546 static RankedTensorType
2547 computeOutputShape(OpBuilder &builder, Location loc, Value input,
2548 llvm::SmallVectorImpl<Value> &dynamicSizes) {
2549 // Get [N, H, W]
2550 auto dims = tensor::getMixedSizes(builder, loc, input);
2551
2552 // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
2553 // output tensors.
2554 dims[2] = halfPlusOne(builder, loc, dims[2]);
2555
2556 llvm::SmallVector<int64_t, 3> staticSizes;
2557 dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2558
2559 auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
2560 return RankedTensorType::get(staticSizes, elementType);
2561 }
2562
2563 static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
2564 RankedTensorType type,
2565 llvm::ArrayRef<Value> dynamicSizes) {
2566 auto emptyTensor =
2567 rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
2568 auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
2569 auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
2570 auto filledTensor = rewriter
2571 .create<linalg::FillOp>(loc, ValueRange{fillValue},
2572 ValueRange{emptyTensor})
2573 .result();
2574 return filledTensor;
2575 }
2576
2577 static Value castIndexToFloat(OpBuilder &builder, Location loc,
2578 FloatType type, Value value) {
2579 auto integerVal = builder.create<arith::IndexCastUIOp>(
2580 loc,
2581 type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
2582 : builder.getI32Type(),
2583 value);
2584
2585 return builder.create<arith::UIToFPOp>(loc, type, integerVal);
2586 }
2587
2588 static Value createLinalgIndex(OpBuilder &builder, Location loc,
2589 FloatType type, int64_t index) {
2590 auto indexVal = builder.create<linalg::IndexOp>(loc, index);
2591 return castIndexToFloat(builder, loc, type, indexVal);
2592 }
2593
2594 template <typename... Args>
2595 static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
2596 Args... args) {
2597 return {builder.getAffineDimExpr(args)...};
2598 }
2599
2600 LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
2601 PatternRewriter &rewriter) const override {
2602 if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
2603 !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
2604 return rewriter.notifyMatchFailure(rfft2d,
2605 "only supports ranked tensors");
2606 }
2607
2608 auto loc = rfft2d.getLoc();
2609 auto input = rfft2d.getInputReal();
2610 auto elementType =
2611 dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
2612 if (!elementType)
2613 return rewriter.notifyMatchFailure(rfft2d,
2614 "only supports float element types");
2615
2616 // Compute the output type and set of dynamic sizes.
2617 llvm::SmallVector<Value> dynamicSizes;
2618 auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);
2619
2620 // Iterator types for the linalg.generic implementation.
2621 llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
2622 utils::IteratorType::parallel, utils::IteratorType::parallel,
2623 utils::IteratorType::parallel, utils::IteratorType::reduction,
2624 utils::IteratorType::reduction};
2625
2626 // Inputs/outputs to the linalg.generic implementation.
2627 llvm::SmallVector<Value> genericOpInputs = {input};
2628 llvm::SmallVector<Value> genericOpOutputs = {
2629 createZeroTensor(rewriter, loc, outputType, dynamicSizes),
2630 createZeroTensor(rewriter, loc, outputType, dynamicSizes)};
2631
2632 // Indexing maps for input and output tensors.
2633 auto indexingMaps = AffineMap::inferFromExprList(
2634 llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
2635 affineDimsExpr(rewriter, 0, 1, 2),
2636 affineDimsExpr(rewriter, 0, 1, 2)},
2637 rewriter.getContext());
2638
2639 // Width and height dimensions of the original input.
2640 auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
2641 auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);
2642
2643 // Constants and dimension sizes.
2644 auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
2645 auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
2646 auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
2647 auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
2648
2649 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2650 Value valReal = args[0];
2651 Value sumReal = args[1];
2652 Value sumImag = args[2];
2653
2654 // Indices for angle computation.
2655 Value oy = builder.create<linalg::IndexOp>(loc, 1);
2656 Value ox = builder.create<linalg::IndexOp>(loc, 2);
2657 Value iy = builder.create<linalg::IndexOp>(loc, 3);
2658 Value ix = builder.create<linalg::IndexOp>(loc, 4);
2659
2660 // Calculating angle without the integer parts of the components, since
2661 // sin/cos are periodic:
2662 // angle = 2 * pi() * ( ((iy * oy) % H) / H + ((ix * ox) % W) / W )
2663 auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
2664 auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
2665
2666 auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
2667 auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
2668
2669 auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
2670 auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
2671
2672 auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
2673 auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
2674 auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
2675 auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
2676
2677 // realComponent = valReal * cos(angle)
2678 // imagComponent = valReal * sin(angle)
2679 auto cosAngle = builder.create<math::CosOp>(loc, angle);
2680 auto sinAngle = builder.create<math::SinOp>(loc, angle);
2681 auto realComponent =
2682 builder.create<arith::MulFOp>(loc, valReal, cosAngle);
2683 auto imagComponent =
2684 builder.create<arith::MulFOp>(loc, valReal, sinAngle);
2685
2686 // outReal = sumReal + realComponent
2687 // outImag = sumImag - imagComponent
2688 auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
2689 auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);
2690
2691 builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
2692 };
2693
2694 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2695 rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2696 indexingMaps, iteratorTypes, buildBody);
2697
2698 return success();
2699 }
2700 };

2701
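// tosa.fft2d follows the same structure but consumes a complex input
// (separate real and imaginary tensors), produces all W output columns, and
// negates the angle when the inverse attribute is set.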

2702 struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
2703 using OpRewritePattern::OpRewritePattern;
2704
2705 LogicalResult matchAndRewrite(FFT2dOp fft2d,
2706 PatternRewriter &rewriter) const override {
2707 if (!llvm::all_of(fft2d->getOperandTypes(),
2708 RFFT2dConverter::isRankedTensor) ||
2709 !llvm::all_of(fft2d->getResultTypes(),
2710 RFFT2dConverter::isRankedTensor)) {
2711 return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
2712 }
2713
2714 Location loc = fft2d.getLoc();
2715 Value input_real = fft2d.getInputReal();
2716 Value input_imag = fft2d.getInputImag();
2717 BoolAttr inverse = fft2d.getInverseAttr();
2718
2719 auto real_el_ty = cast<FloatType>(
2720 cast<ShapedType>(input_real.getType()).getElementType());
2721 [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
2722 cast<ShapedType>(input_imag.getType()).getElementType());
2723
2724 assert(real_el_ty == imag_el_ty);
2725
2726 // Compute the output type and set of dynamic sizes.
2727 llvm::SmallVector<Value> dynamicSizes;
2728
2729 // Get [N, H, W]
2730 auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
2731
2732 llvm::SmallVector<int64_t, 3> staticSizes;
2733 dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2734
2735 auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
2736
2737 // Iterator types for the linalg.generic implementation.
2738 llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
2739 utils::IteratorType::parallel, utils::IteratorType::parallel,
2740 utils::IteratorType::parallel, utils::IteratorType::reduction,
2741 utils::IteratorType::reduction};
2742
2743 // Inputs/outputs to the linalg.generic implementation.
2744 llvm::SmallVector<Value> genericOpInputs = {input_real, input_imag};
2745 llvm::SmallVector<Value> genericOpOutputs = {
2746 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2747 dynamicSizes),
2748 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2749 dynamicSizes)};
2750
2751 // Indexing maps for input and output tensors.
2752 auto indexingMaps = AffineMap::inferFromExprList(
2753 ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2754 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2755 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2756 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
2757 rewriter.getContext());
2758
2759 // Width and height dimensions of the original input.
2760 auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
2761 auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);
2762
2763 // Constants and dimension sizes.
2764 auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
2765 auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
2766 Value constH =
2767 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2768 Value constW =
2769 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2770
2771 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2772 Value valReal = args[0];
2773 Value valImag = args[1];
2774 Value sumReal = args[2];
2775 Value sumImag = args[3];
2776
2777 // Indices for angle computation.
2778 Value oy = builder.create<linalg::IndexOp>(loc, 1);
2779 Value ox = builder.create<linalg::IndexOp>(loc, 2);
2780 Value iy = builder.create<linalg::IndexOp>(loc, 3);
2781 Value ix = builder.create<linalg::IndexOp>(loc, 4);
2782
2783 // angle = 2 * pi() * ( ((iy * oy) % H) / H + ((ix * ox) % W) / W )
2784 // The integer parts can be dropped because sin/cos are periodic.
2785 auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
2786 auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
2787
2788 auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
2789 auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
2790
2791 auto iyRemFloat =
2792 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
2793 auto ixRemFloat =
2794 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
2795
2796 auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
2797 auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
2798
2799 auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
2800 auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
2801
2802 if (inverse.getValue()) {
2803 angle = builder.create<arith::MulFOp>(
2804 loc, angle,
2805 rewriter.create<arith::ConstantOp>(
2806 loc, rewriter.getFloatAttr(real_el_ty, -1.0)));
2807 }
2808
2809 // realComponent = valReal * cos(angle) + valImag * sin(angle)
2810 // imagComponent = valImag * cos(angle) - valReal * sin(angle)
2811 auto cosAngle = builder.create<math::CosOp>(loc, angle);
2812 auto sinAngle = builder.create<math::SinOp>(loc, angle);
2813
2814 auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
2815 auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
2816 auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);
2817
2818 auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
2819 auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);
2820
2821 auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);
2822
2823 // outReal = sumReal + realComponent
2824 // outImag = sumImag + imagComponent
2825 auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
2826 auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);
2827
2828 builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
2829 };
2830
2831 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2832 fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2833 indexingMaps, iteratorTypes, buildBody);
2834
2835 return success();
2836 }
2837 };

2838

2839 } // namespace
2840
2841 void mlir::tosa::populateTosaToLinalgConversionPatterns(
2842 const TypeConverter &converter, RewritePatternSet *patterns) {
2843
2844 // We have multiple resize converters to handle degenerate cases.
2845 patterns->add<GenericResizeConverter>(patterns->getContext(),
2846 /*benefit=*/100);
2847 patterns->add<ResizeUnaryConverter>(patterns->getContext(),
2848 /*benefit=*/200);
2849 patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
2850 /*benefit=*/300);
2851
2852 patterns->add<
2853 // clang-format off
2854 PointwiseConverter<tosa::AddOp>,
2855 PointwiseConverter<tosa::SubOp>,
2856 PointwiseConverter<tosa::MulOp>,
2857 PointwiseConverter<tosa::IntDivOp>,
2858 PointwiseConverter<tosa::NegateOp>,
2859 PointwiseConverter<tosa::PowOp>,
2860 PointwiseConverter<tosa::ReciprocalOp>,
2861 PointwiseConverter<tosa::RsqrtOp>,
2862 PointwiseConverter<tosa::LogOp>,
2863 PointwiseConverter<tosa::ExpOp>,
2864 PointwiseConverter<tosa::AbsOp>,
2865 PointwiseConverter<tosa::SinOp>,
2866 PointwiseConverter<tosa::CosOp>,
2867 PointwiseConverter<tosa::TanhOp>,
2868 PointwiseConverter<tosa::ErfOp>,
2869 PointwiseConverter<tosa::BitwiseAndOp>,
2870 PointwiseConverter<tosa::BitwiseOrOp>,
2871 PointwiseConverter<tosa::BitwiseNotOp>,
2872 PointwiseConverter<tosa::BitwiseXorOp>,
2873 PointwiseConverter<tosa::LogicalAndOp>,
2874 PointwiseConverter<tosa::LogicalNotOp>,
2875 PointwiseConverter<tosa::LogicalOrOp>,
2876 PointwiseConverter<tosa::LogicalXorOp>,
2877 PointwiseConverter<tosa::CastOp>,
2878 PointwiseConverter<tosa::LogicalLeftShiftOp>,
2879 PointwiseConverter<tosa::LogicalRightShiftOp>,
2880 PointwiseConverter<tosa::ArithmeticRightShiftOp>,
2881 PointwiseConverter<tosa::ClzOp>,
2882 PointwiseConverter<tosa::SelectOp>,
2883 PointwiseConverter<tosa::GreaterOp>,
2884 PointwiseConverter<tosa::GreaterEqualOp>,
2885 PointwiseConverter<tosa::EqualOp>,
2886 PointwiseConverter<tosa::MaximumOp>,
2887 PointwiseConverter<tosa::MinimumOp>,
2888 PointwiseConverter<tosa::CeilOp>,
2889 PointwiseConverter<tosa::FloorOp>,
2890 PointwiseConverter<tosa::ClampOp>,
2891 PointwiseConverter<tosa::SigmoidOp>
2892 >(converter, patterns->getContext());
2893
2894 patterns->add<
2895 IdentityNConverter<tosa::IdentityOp>,
2896 ReduceConverter<tosa::ReduceAllOp>,
2897 ReduceConverter<tosa::ReduceAnyOp>,
2898 ReduceConverter<tosa::ReduceMinOp>,
2899 ReduceConverter<tosa::ReduceMaxOp>,
2900 ReduceConverter<tosa::ReduceSumOp>,
2901 ReduceConverter<tosa::ReduceProductOp>,
2902 ArgMaxConverter,
2903 GatherConverter,
2904 RescaleConverter,
2905 ReverseConverter,
2906 RFFT2dConverter,
2907 FFT2dConverter,
2908 TableConverter,
2909 TileConverter>(patterns->getContext());
2910 // clang-format on
2911 }
