MLIR: lib/Dialect/Arith/IR/ArithOps.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9 #include

10 #include

11 #include

12 #include

13

25

26 #include "llvm/ADT/APFloat.h"

27 #include "llvm/ADT/APInt.h"

28 #include "llvm/ADT/APSInt.h"

29 #include "llvm/ADT/FloatingPointMode.h"

30 #include "llvm/ADT/STLExtras.h"

31 #include "llvm/ADT/SmallString.h"

32 #include "llvm/ADT/SmallVector.h"

33 #include "llvm/ADT/TypeSwitch.h"

34

35 using namespace mlir;

37

38

39

40

41

42 static IntegerAttr

45 function_ref<APInt(const APInt &, const APInt &)> binFn) {

46 APInt lhsVal = llvm::cast(lhs).getValue();

47 APInt rhsVal = llvm::cast(rhs).getValue();

48 APInt value = binFn(lhsVal, rhsVal);

50 }

51

55 }

56

60 }

61

64 return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies());

65 }

66

67

68 static IntegerOverflowFlagsAttr

70 IntegerOverflowFlagsAttr val2) {

72 val1.getValue() & val2.getValue());

73 }

74

75

77 switch (pred) {

78 case arith::CmpIPredicate::eq:

79 return arith::CmpIPredicate::ne;

80 case arith::CmpIPredicate::ne:

81 return arith::CmpIPredicate::eq;

82 case arith::CmpIPredicate::slt:

83 return arith::CmpIPredicate::sge;

84 case arith::CmpIPredicate::sle:

85 return arith::CmpIPredicate::sgt;

86 case arith::CmpIPredicate::sgt:

87 return arith::CmpIPredicate::sle;

88 case arith::CmpIPredicate::sge:

89 return arith::CmpIPredicate::slt;

90 case arith::CmpIPredicate::ult:

91 return arith::CmpIPredicate::uge;

92 case arith::CmpIPredicate::ule:

93 return arith::CmpIPredicate::ugt;

94 case arith::CmpIPredicate::ugt:

95 return arith::CmpIPredicate::ule;

96 case arith::CmpIPredicate::uge:

97 return arith::CmpIPredicate::ult;

98 }

99 llvm_unreachable("unknown cmpi predicate kind");

100 }

101

102

103

104

105

106

107

108 static llvm::RoundingMode

110 switch (roundingMode) {

111 case RoundingMode::downward:

112 return llvm::RoundingMode::TowardNegative;

113 case RoundingMode::to_nearest_away:

114 return llvm::RoundingMode::NearestTiesToAway;

115 case RoundingMode::to_nearest_even:

116 return llvm::RoundingMode::NearestTiesToEven;

117 case RoundingMode::toward_zero:

118 return llvm::RoundingMode::TowardZero;

119 case RoundingMode::upward:

120 return llvm::RoundingMode::TowardPositive;

121 }

122 llvm_unreachable("Unhandled rounding mode");

123 }

124

125 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {

128 }

129

134

135 return -1;

136 }

137

140 }

141

143 APInt value;

145 return value;

146

147 return failure();

148 }

149

152 ShapedType shapedType = llvm::dyn_cast_or_null(type);

153 if (!shapedType)

154 return boolAttr;

156 }

157

158

159

160

161

// Textually includes ArithCanonicalization.inc inside an anonymous namespace,
// so everything that file defines has internal linkage in this translation
// unit.
// NOTE(review): the pattern names passed to patterns.add<...>() elsewhere in
// this file (e.g. AddIAddConstant) are presumably defined by this include --
// confirm against the generated file.
162 namespace {

163 #include "ArithCanonicalization.inc"

164 }

165

166

167

168

169

170

173 if (auto shapedType = llvm::dyn_cast(type))

174 return shapedType.cloneWith(std::nullopt, i1Type);

175 if (llvm::isa(type))

177 return i1Type;

178 }

179

180

181

182

183

184 void arith::ConstantOp::getAsmResultNames(

187 if (auto intCst = llvm::dyn_cast(getValue())) {

188 auto intType = llvm::dyn_cast(type);

189

190

191 if (intType && intType.getWidth() == 1)

192 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));

193

194

196 llvm::raw_svector_ostream specialName(specialNameBuffer);

197 specialName << 'c' << intCst.getValue();

198 if (intType)

199 specialName << '_' << type;

200 setNameFn(getResult(), specialName.str());

201 } else {

202 setNameFn(getResult(), "cst");

203 }

204 }

205

206

207

210

211 if (llvm::isa(type) &&

212 !llvm::cast(type).isSignless())

213 return emitOpError("integer return type must be signless");

214

215 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {

216 return emitOpError(

217 "value must be an integer, float, or elements attribute");

218 }

219

220

221

222

223 if (isa(type) && !isa(getValue()))

224 return emitOpError(

225 "intializing scalable vectors with elements attribute is not supported"

226 " unless it's a vector splat");

227 return success();

228 }

229

/// Returns true when `value` can be used together with `type`, per the three
/// checks below; returns false as soon as one of them fails.
///
/// NOTE(review): the template arguments of the dyn_cast/isa/cast calls are
/// missing from this listing, so the exact types tested are not visible here.
230 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {

231

// The attribute must cast successfully and report exactly `type` via getType().
232 auto typedAttr = llvm::dyn_cast(value);

233 if (!typedAttr || typedAttr.getType() != type)

234 return false;

235

// Reject a `type` that matches the isa check but is not signless.
// Presumably the check is for an integer type -- confirm.
236 if (llvm::isa(type) &&

237 !llvm::cast(type).isSignless())

238 return false;

239

// Only integer, float and elements attributes are accepted.
240 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);

241 }

242

243 ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,

245 if (isBuildableWith(value, type))

246 return builder.createarith::ConstantOp(loc, cast(value));

247 return nullptr;

248 }

249

/// Folds the op to the attribute returned by getValue(); `adaptor` is not
/// consulted.
250 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }

251

253 int64_t value, unsigned width) {

255 arith::ConstantOp::build(builder, result, type,

257 }

258

260 int64_t value, Type type) {

262 "ConstantIntOp can only have signless integer type values");

263 arith::ConstantOp::build(builder, result, type,

265 }

266

268 if (auto constOp = dyn_cast_or_nullarith::ConstantOp(op))

269 return constOp.getType().isSignlessInteger();

270 return false;

271 }

272

274 const APFloat &value, FloatType type) {

275 arith::ConstantOp::build(builder, result, type,

277 }

278

280 if (auto constOp = dyn_cast_or_nullarith::ConstantOp(op))

281 return llvm::isa(constOp.getType());

282 return false;

283 }

284

286 int64_t value) {

287 arith::ConstantOp::build(builder, result, builder.getIndexType(),

289 }

290

292 if (auto constOp = dyn_cast_or_nullarith::ConstantOp(op))

293 return constOp.getType().isIndex();

294 return false;

295 }

296

297

298

299

300

301 OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {

302

304 return getLhs();

305

306

307 if (auto sub = getLhs().getDefiningOp())

308 if (getRhs() == sub.getRhs())

309 return sub.getLhs();

310

311

312 if (auto sub = getRhs().getDefiningOp())

313 if (getLhs() == sub.getRhs())

314 return sub.getLhs();

315

316 return constFoldBinaryOp(

317 adaptor.getOperands(),

318 [](APInt a, const APInt &b) { return std::move(a) + b; });

319 }

320

323 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,

324 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);

325 }

326

327

328

329

330

/// Returns the shape reported by getType(0) as a SmallVector<int64_t, 4> when
/// the dyn_cast below succeeds, and std::nullopt otherwise.
///
/// NOTE(review): the dyn_cast target type is missing from this listing; the
/// local name `vt` suggests a vector type -- confirm.
331 std::optional<SmallVector<int64_t, 4>>

332 arith::AddUIExtendedOp::getShapeForUnroll() {

333 if (auto vt = llvm::dyn_cast(getType(0)))

// Copy the shape into an owning vector with four inline elements.
334 return llvm::to_vector<4>(vt.getShape());

335 return std::nullopt;

336 }

337

338

339

341 return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);

342 }

343

344 LogicalResult

345 arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,

347 Type overflowTy = getOverflow().getType();

348

351 auto falseValue = builder.getZeroAttr(overflowTy);

352

353 results.push_back(getLhs());

354 results.push_back(falseValue);

355 return success();

356 }

357

358

359

360

361

362 if (Attribute sumAttr = constFoldBinaryOp(

363 adaptor.getOperands(),

364 [](APInt a, const APInt &b) { return std::move(a) + b; })) {

365 Attribute overflowAttr = constFoldBinaryOp(

366 ArrayRef({sumAttr, adaptor.getLhs()}),

369 if (!overflowAttr)

370 return failure();

371

372 results.push_back(sumAttr);

373 results.push_back(overflowAttr);

374 return success();

375 }

376

377 return failure();

378 }

379

380 void arith::AddUIExtendedOp::getCanonicalizationPatterns(

382 patterns.add(context);

383 }

384

385

386

387

388

389 OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {

390

391 if (getOperand(0) == getOperand(1)) {

392 auto shapedType = dyn_cast(getType());

393

394 if (!shapedType || shapedType.hasStaticShape())

396 }

397

399 return getLhs();

400

401 if (auto add = getLhs().getDefiningOp()) {

402

403 if (getRhs() == add.getRhs())

404 return add.getLhs();

405

406 if (getRhs() == add.getLhs())

407 return add.getRhs();

408 }

409

410 return constFoldBinaryOp(

411 adaptor.getOperands(),

412 [](APInt a, const APInt &b) { return std::move(a) - b; });

413 }

414

417 patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,

418 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,

419 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);

420 }

421

422

423

424

425

426 OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {

427

429 return getRhs();

430

432 return getLhs();

433

434

435

436 return constFoldBinaryOp(

437 adaptor.getOperands(),

438 [](const APInt &a, const APInt &b) { return a * b; });

439 }

440

441 void arith::MulIOp::getAsmResultNames(

443 if (!isa(getType()))

444 return;

445

446

447

448 auto isVscale = [](Operation *op) {

449 return op && op->getName().getStringRef() == "vector.vscale";

450 };

451

452 IntegerAttr baseValue;

453 auto isVscaleExpr = [&](Value a, Value b) {

455 isVscale(b.getDefiningOp());

456 };

457

458 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))

459 return;

460

461

463 llvm::raw_svector_ostream specialName(specialNameBuffer);

464 specialName << 'c' << baseValue.getInt() << "_vscale";

465 setNameFn(getResult(), specialName.str());

466 }

467

470 patterns.add(context);

471 }

472

473

474

475

476

/// Returns the shape reported by getType(0) as a SmallVector<int64_t, 4> when
/// the dyn_cast below succeeds, and std::nullopt otherwise.
///
/// NOTE(review): the dyn_cast target type is missing from this listing; the
/// local name `vt` suggests a vector type -- confirm.
477 std::optional<SmallVector<int64_t, 4>>

478 arith::MulSIExtendedOp::getShapeForUnroll() {

479 if (auto vt = llvm::dyn_cast(getType(0)))

// Copy the shape into an owning vector with four inline elements.
480 return llvm::to_vector<4>(vt.getShape());

481 return std::nullopt;

482 }

483

484 LogicalResult

485 arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,

487

489 Attribute zero = adaptor.getRhs();

490 results.push_back(zero);

491 results.push_back(zero);

492 return success();

493 }

494

495

496 if (Attribute lowAttr = constFoldBinaryOp(

497 adaptor.getOperands(),

498 [](const APInt &a, const APInt &b) { return a * b; })) {

499

500 Attribute highAttr = constFoldBinaryOp(

501 adaptor.getOperands(), [](const APInt &a, const APInt &b) {

502 return llvm::APIntOps::mulhs(a, b);

503 });

504 assert(highAttr && "Unexpected constant-folding failure");

505

506 results.push_back(lowAttr);

507 results.push_back(highAttr);

508 return success();

509 }

510

511 return failure();

512 }

513

514 void arith::MulSIExtendedOp::getCanonicalizationPatterns(

516 patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);

517 }

518

519

520

521

522

/// Returns the shape reported by getType(0) as a SmallVector<int64_t, 4> when
/// the dyn_cast below succeeds, and std::nullopt otherwise.
///
/// NOTE(review): the dyn_cast target type is missing from this listing; the
/// local name `vt` suggests a vector type -- confirm.
523 std::optional<SmallVector<int64_t, 4>>

524 arith::MulUIExtendedOp::getShapeForUnroll() {

525 if (auto vt = llvm::dyn_cast(getType(0)))

// Copy the shape into an owning vector with four inline elements.
526 return llvm::to_vector<4>(vt.getShape());

527 return std::nullopt;

528 }

529

530 LogicalResult

531 arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,

533

535 Attribute zero = adaptor.getRhs();

536 results.push_back(zero);

537 results.push_back(zero);

538 return success();

539 }

540

541

545 results.push_back(getLhs());

546 results.push_back(zero);

547 return success();

548 }

549

550

551 if (Attribute lowAttr = constFoldBinaryOp(

552 adaptor.getOperands(),

553 [](const APInt &a, const APInt &b) { return a * b; })) {

554

555 Attribute highAttr = constFoldBinaryOp(

556 adaptor.getOperands(), [](const APInt &a, const APInt &b) {

557 return llvm::APIntOps::mulhu(a, b);

558 });

559 assert(highAttr && "Unexpected constant-folding failure");

560

561 results.push_back(lowAttr);

562 results.push_back(highAttr);

563 return success();

564 }

565

566 return failure();

567 }

568

569 void arith::MulUIExtendedOp::getCanonicalizationPatterns(

571 patterns.add(context);

572 }

573

574

575

576

577

578

580 arith::IntegerOverflowFlags ovfFlags) {

581 auto mul = lhs.getDefiningOpmlir::arith::MulIOp();

582 if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))

583 return {};

584

585 if (mul.getLhs() == rhs)

586 return mul.getRhs();

587

588 if (mul.getRhs() == rhs)

589 return mul.getLhs();

590

591 return {};

592 }

593

594 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {

595

597 return getLhs();

598

599

600 if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))

601 return val;

602

603

604 bool div0 = false;

605 auto result = constFoldBinaryOp(adaptor.getOperands(),

606 [&](APInt a, const APInt &b) {

607 if (div0 || !b) {

608 div0 = true;

609 return a;

610 }

611 return a.udiv(b);

612 });

613

614 return div0 ? Attribute() : result;

615 }

616

617

619

622

624 }

625

628 }

629

630

631

632

633

634 OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {

635

637 return getLhs();

638

639

640 if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))

641 return val;

642

643

644 bool overflowOrDiv0 = false;

645 auto result = constFoldBinaryOp(

646 adaptor.getOperands(), [&](APInt a, const APInt &b) {

647 if (overflowOrDiv0 || !b) {

648 overflowOrDiv0 = true;

649 return a;

650 }

651 return a.sdiv_ov(b, overflowOrDiv0);

652 });

653

654 return overflowOrDiv0 ? Attribute() : result;

655 }

656

657

658

659

661

662

666

668 }

669

672 }

673

674

675

676

677

679 bool &overflow) {

680

681 APInt one(a.getBitWidth(), 1, true);

682 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);

683 return val.sadd_ov(one, overflow);

684 }

685

686

687

688

689

690 OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {

691

693 return getLhs();

694

695 bool overflowOrDiv0 = false;

696 auto result = constFoldBinaryOp(

697 adaptor.getOperands(), [&](APInt a, const APInt &b) {

698 if (overflowOrDiv0 || !b) {

699 overflowOrDiv0 = true;

700 return a;

701 }

702 APInt quotient = a.udiv(b);

703 if (!a.urem(b))

704 return quotient;

705 APInt one(a.getBitWidth(), 1, true);

706 return quotient.uadd_ov(one, overflowOrDiv0);

707 });

708

709 return overflowOrDiv0 ? Attribute() : result;

710 }

711

714 }

715

716

717

718

719

720 OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {

721

723 return getLhs();

724

725

726

727

728 bool overflowOrDiv0 = false;

729 auto result = constFoldBinaryOp(

730 adaptor.getOperands(), [&](APInt a, const APInt &b) {

731 if (overflowOrDiv0 || !b) {

732 overflowOrDiv0 = true;

733 return a;

734 }

735 if (!a)

736 return a;

737

738 unsigned bits = a.getBitWidth();

740 bool aGtZero = a.sgt(zero);

741 bool bGtZero = b.sgt(zero);

742 if (aGtZero && bGtZero) {

743

745 }

746

747

748

749 bool overflowNegA = false;

750 bool overflowNegB = false;

751 bool overflowDiv = false;

752 bool overflowNegRes = false;

753 if (!aGtZero && !bGtZero) {

754

755 APInt posA = zero.ssub_ov(a, overflowNegA);

756 APInt posB = zero.ssub_ov(b, overflowNegB);

758 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);

759 return res;

760 }

761 if (!aGtZero && bGtZero) {

762

763 APInt posA = zero.ssub_ov(a, overflowNegA);

764 APInt div = posA.sdiv_ov(b, overflowDiv);

765 APInt res = zero.ssub_ov(div, overflowNegRes);

766 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);

767 return res;

768 }

769

770 APInt posB = zero.ssub_ov(b, overflowNegB);

771 APInt div = a.sdiv_ov(posB, overflowDiv);

772 APInt res = zero.ssub_ov(div, overflowNegRes);

773

774 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);

775 return res;

776 });

777

778 return overflowOrDiv0 ? Attribute() : result;

779 }

780

783 }

784

785

786

787

788

789 OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {

790

792 return getLhs();

793

794

795 bool overflowOrDiv = false;

796 auto result = constFoldBinaryOp(

797 adaptor.getOperands(), [&](APInt a, const APInt &b) {

798 if (b.isZero()) {

799 overflowOrDiv = true;

800 return a;

801 }

802 return a.sfloordiv_ov(b, overflowOrDiv);

803 });

804

805 return overflowOrDiv ? Attribute() : result;

806 }

807

808

809

810

811

812 OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {

813

816

817

818 bool div0 = false;

819 auto result = constFoldBinaryOp(adaptor.getOperands(),

820 [&](APInt a, const APInt &b) {

821 if (div0 || b.isZero()) {

822 div0 = true;

823 return a;

824 }

825 return a.urem(b);

826 });

827

828 return div0 ? Attribute() : result;

829 }

830

831

832

833

834

835 OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {

836

839

840

841 bool div0 = false;

842 auto result = constFoldBinaryOp(adaptor.getOperands(),

843 [&](APInt a, const APInt &b) {

844 if (div0 || b.isZero()) {

845 div0 = true;

846 return a;

847 }

848 return a.srem(b);

849 });

850

851 return div0 ? Attribute() : result;

852 }

853

854

855

856

857

858

860 for (bool reversePrev : {false, true}) {

861 auto prev = (reversePrev ? op.getRhs() : op.getLhs())

862 .getDefiningOparith::AndIOp();

863 if (!prev)

864 continue;

865

866 Value other = (reversePrev ? op.getLhs() : op.getRhs());

867 if (other != prev.getLhs() && other != prev.getRhs())

868 continue;

869

870 return prev.getResult();

871 }

872 return {};

873 }

874

875 OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {

876

878 return getRhs();

879

880 APInt intValue;

882 intValue.isAllOnes())

883 return getLhs();

884

887 intValue.isAllOnes())

889

892 intValue.isAllOnes())

894

895

897 return result;

898

899 return constFoldBinaryOp(

900 adaptor.getOperands(),

901 [](APInt a, const APInt &b) { return std::move(a) & b; });

902 }

903

904

905

906

907

908 OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {

910

911 if (rhsVal.isZero())

912 return getLhs();

913

914 if (rhsVal.isAllOnes())

915 return adaptor.getRhs();

916 }

917

918 APInt intValue;

919

922 intValue.isAllOnes())

923 return getRhs().getDefiningOp().getRhs();

924

927 intValue.isAllOnes())

928 return getLhs().getDefiningOp().getRhs();

929

930 return constFoldBinaryOp(

931 adaptor.getOperands(),

932 [](APInt a, const APInt &b) { return std::move(a) | b; });

933 }

934

935

936

937

938

939 OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {

940

942 return getLhs();

943

944 if (getLhs() == getRhs())

946

947

948 if (arith::XOrIOp prev = getLhs().getDefiningOparith::XOrIOp()) {

949 if (prev.getRhs() == getRhs())

950 return prev.getLhs();

951 if (prev.getLhs() == getRhs())

952 return prev.getRhs();

953 }

954

955

956 if (arith::XOrIOp prev = getRhs().getDefiningOparith::XOrIOp()) {

957 if (prev.getRhs() == getLhs())

958 return prev.getLhs();

959 if (prev.getLhs() == getLhs())

960 return prev.getRhs();

961 }

962

963 return constFoldBinaryOp(

964 adaptor.getOperands(),

965 [](APInt a, const APInt &b) { return std::move(a) ^ b; });

966 }

967

970 patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);

971 }

972

973

974

975

976

/// Folds a floating-point negation.
///
/// If the operand is itself produced by a NegFOp, the inner op's operand is
/// returned, so the two negations cancel. Otherwise the fold is delegated to
/// constFoldUnaryOp with a callback that negates an APFloat.
977 OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {

978

// NOTE(review): the angle brackets around the template argument appear to
// have been lost in this listing (likely getDefiningOp<arith::NegFOp>()).
979 if (auto op = this->getOperand().getDefiningOparith::NegFOp())

980 return op.getOperand();

981 return constFoldUnaryOp(adaptor.getOperands(),

982 [](const APFloat &a) { return -a; });

983 }

984

985

986

987

988

989 OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {

990

992 return getLhs();

993

994 return constFoldBinaryOp(

995 adaptor.getOperands(),

996 [](const APFloat &a, const APFloat &b) { return a + b; });

997 }

998

999

1000

1001

1002

1003 OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {

1004

1006 return getLhs();

1007

1008 return constFoldBinaryOp(

1009 adaptor.getOperands(),

1010 [](const APFloat &a, const APFloat &b) { return a - b; });

1011 }

1012

1013

1014

1015

1016

1017 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {

1018

1019 if (getLhs() == getRhs())

1020 return getRhs();

1021

1022

1024 return getLhs();

1025

1026 return constFoldBinaryOp(

1027 adaptor.getOperands(),

1028 [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });

1029 }

1030

1031

1032

1033

1034

1035 OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {

1036

1037 if (getLhs() == getRhs())

1038 return getRhs();

1039

1040

1042 return getLhs();

1043

1044 return constFoldBinaryOp(adaptor.getOperands(), llvm::maxnum);

1045 }

1046

1047

1048

1049

1050

1051 OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {

1052

1053 if (getLhs() == getRhs())

1054 return getRhs();

1055

1056 if (APInt intValue;

1058

1059 if (intValue.isMaxSignedValue())

1060 return getRhs();

1061

1062 if (intValue.isMinSignedValue())

1063 return getLhs();

1064 }

1065

1066 return constFoldBinaryOp(adaptor.getOperands(),

1067 [](const APInt &a, const APInt &b) {

1068 return llvm::APIntOps::smax(a, b);

1069 });

1070 }

1071

1072

1073

1074

1075

1076 OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {

1077

1078 if (getLhs() == getRhs())

1079 return getRhs();

1080

1081 if (APInt intValue;

1083

1084 if (intValue.isMaxValue())

1085 return getRhs();

1086

1087 if (intValue.isMinValue())

1088 return getLhs();

1089 }

1090

1091 return constFoldBinaryOp(adaptor.getOperands(),

1092 [](const APInt &a, const APInt &b) {

1093 return llvm::APIntOps::umax(a, b);

1094 });

1095 }

1096

1097

1098

1099

1100

1101 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {

1102

1103 if (getLhs() == getRhs())

1104 return getRhs();

1105

1106

1108 return getLhs();

1109

1110 return constFoldBinaryOp(

1111 adaptor.getOperands(),

1112 [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });

1113 }

1114

1115

1116

1117

1118

1119 OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {

1120

1121 if (getLhs() == getRhs())

1122 return getRhs();

1123

1124

1126 return getLhs();

1127

1128 return constFoldBinaryOp(

1129 adaptor.getOperands(),

1130 [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });

1131 }

1132

1133

1134

1135

1136

1137 OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {

1138

1139 if (getLhs() == getRhs())

1140 return getRhs();

1141

1142 if (APInt intValue;

1144

1145 if (intValue.isMinSignedValue())

1146 return getRhs();

1147

1148 if (intValue.isMaxSignedValue())

1149 return getLhs();

1150 }

1151

1152 return constFoldBinaryOp(adaptor.getOperands(),

1153 [](const APInt &a, const APInt &b) {

1154 return llvm::APIntOps::smin(a, b);

1155 });

1156 }

1157

1158

1159

1160

1161

1162 OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {

1163

1164 if (getLhs() == getRhs())

1165 return getRhs();

1166

1167 if (APInt intValue;

1169

1170 if (intValue.isMinValue())

1171 return getRhs();

1172

1173 if (intValue.isMaxValue())

1174 return getLhs();

1175 }

1176

1177 return constFoldBinaryOp(adaptor.getOperands(),

1178 [](const APInt &a, const APInt &b) {

1179 return llvm::APIntOps::umin(a, b);

1180 });

1181 }

1182

1183

1184

1185

1186

1187 OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {

1188

1190 return getLhs();

1191

1192 return constFoldBinaryOp(

1193 adaptor.getOperands(),

1194 [](const APFloat &a, const APFloat &b) { return a * b; });

1195 }

1196

1199 patterns.add(context);

1200 }

1201

1202

1203

1204

1205

1206 OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {

1207

1209 return getLhs();

1210

1211 return constFoldBinaryOp(

1212 adaptor.getOperands(),

1213 [](const APFloat &a, const APFloat &b) { return a / b; });

1214 }

1215

1218 patterns.add(context);

1219 }

1220

1221

1222

1223

1224

/// Folds a floating-point remainder by delegating to constFoldBinaryOp with a
/// callback that computes the remainder of its two APFloat arguments.
1225 OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {

1226 return constFoldBinaryOp(adaptor.getOperands(),

1227 [](const APFloat &a, const APFloat &b) {

// Work on a copy: `a` is a const reference and mod() updates its receiver.
1228 APFloat result(a);

1229

1230

1231

// The status returned by mod() is deliberately discarded.
// NOTE(review): presumably mod() has fmod-style semantics (result takes the
// sign of the dividend) -- confirm against the APFloat documentation.
1232 (void)result.mod(b);

1233 return result;

1234 });

1235 }

1236

1237

1238

1239

1240

1241 template <typename... Types>

1243

1244

1245

1246

1247 template <typename... ShapedTypes, typename... ElementTypes>

1250 if (llvm::isa(type) && !llvm::isa<ShapedTypes...>(type))

1251 return {};

1252

1254 if (!llvm::isa<ElementTypes...>(underlyingType))

1255 return {};

1256

1257 return underlyingType;

1258 }

1259

1260

1261 template <typename... ElementTypes>

1265 }

1266

1267

1268 template <typename... ElementTypes>

1273 }

1274

1275

1277 auto rankedTensorA = dyn_cast(typeA);

1278 auto rankedTensorB = dyn_cast(typeB);

1279 if (!rankedTensorA || !rankedTensorB)

1280 return true;

1281 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();

1282 }

1283

1285 if (inputs.size() != 1 || outputs.size() != 1)

1286 return false;

1288 return false;

1290 }

1291

1292

1293

1294

1295

1296

1297 template <typename ValType, typename Op>

1301

1302 if (llvm::cast(srcType).getWidth() >=

1303 llvm::cast(dstType).getWidth())

1304 return op.emitError("result type ")

1305 << dstType << " must be wider than operand type " << srcType;

1306

1307 return success();

1308 }

1309

1310

1311 template <typename ValType, typename Op>

1315

1316 if (llvm::cast(srcType).getWidth() <=

1317 llvm::cast(dstType).getWidth())

1318 return op.emitError("result type ")

1319 << dstType << " must be shorter than operand type " << srcType;

1320

1321 return success();

1322 }

1323

1324

1325 template <template <typename> class WidthComparator, typename... ElementTypes>

1328 return false;

1329

1330 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());

1331 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());

1332 if (!srcType || !dstType)

1333 return false;

1334

1335 return WidthComparator()(dstType.getIntOrFloatBitWidth(),

1336 srcType.getIntOrFloatBitWidth());

1337 }

1338

1339

1340

1342 APFloat sourceValue, const llvm::fltSemantics &targetSemantics,

1343 llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {

1344 bool losesInfo = false;

1345 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);

1346 if (losesInfo || status != APFloat::opOK)

1347 return failure();

1348

1349 return sourceValue;

1350 }

1351

1352

1353

1354

1355

1356 OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {

1357 if (auto lhs = getIn().getDefiningOp()) {

1358 getInMutable().assign(lhs.getIn());

1359 return getResult();

1360 }

1361

1363 unsigned bitWidth = llvm::cast(resType).getWidth();

1364 return constFoldCastOp<IntegerAttr, IntegerAttr>(

1365 adaptor.getOperands(), getType(),

1366 [bitWidth](const APInt &a, bool &castStatus) {

1367 return a.zext(bitWidth);

1368 });

1369 }

1370

/// Cast-compatibility check for ExtUIOp. Delegates to checkWidthChangeCast
/// with IntegerType as the accepted element type and std::greater as the width
/// comparator; that helper's final comparison is
/// comparator(result bit width, operand bit width), so the result must be
/// strictly wider than the operand.
1371 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {

1372 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);

1373 }

1374

1376 return verifyExtOp(*this);

1377 }

1378

1379

1380

1381

1382

1383 OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {

1384 if (auto lhs = getIn().getDefiningOp()) {

1385 getInMutable().assign(lhs.getIn());

1386 return getResult();

1387 }

1388

1390 unsigned bitWidth = llvm::cast(resType).getWidth();

1391 return constFoldCastOp<IntegerAttr, IntegerAttr>(

1392 adaptor.getOperands(), getType(),

1393 [bitWidth](const APInt &a, bool &castStatus) {

1394 return a.sext(bitWidth);

1395 });

1396 }

1397

/// Cast-compatibility check for ExtSIOp. Delegates to checkWidthChangeCast
/// with IntegerType as the accepted element type and std::greater as the width
/// comparator; that helper's final comparison is
/// comparator(result bit width, operand bit width), so the result must be
/// strictly wider than the operand.
1398 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {

1399 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);

1400 }

1401

1404 patterns.add(context);

1405 }

1406

1408 return verifyExtOp(*this);

1409 }

1410

1411

1412

1413

1414

1415

1416

1417 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {

1418 if (auto truncFOp = getOperand().getDefiningOp()) {

1419 if (truncFOp.getOperand().getType() == getType()) {

1420 arith::FastMathFlags truncFMF =

1421 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);

1422 bool isTruncContract =

1424 arith::FastMathFlags extFMF =

1425 getFastmath().value_or(arith::FastMathFlags::none);

1426 bool isExtContract =

1428 if (isTruncContract && isExtContract) {

1429 return truncFOp.getOperand();

1430 }

1431 }

1432 }

1433

1435 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();

1436 return constFoldCastOp<FloatAttr, FloatAttr>(

1437 adaptor.getOperands(), getType(),

1438 [&targetSemantics](const APFloat &a, bool &castStatus) {

1439 FailureOr result = convertFloatValue(a, targetSemantics);

1440 if (failed(result)) {

1441 castStatus = false;

1442 return a;

1443 }

1444 return *result;

1445 });

1446 }

1447

/// Cast-compatibility check for ExtFOp. Delegates to checkWidthChangeCast with
/// FloatType as the accepted element type and std::greater as the width
/// comparator; that helper's final comparison is
/// comparator(result bit width, operand bit width), so the result must be
/// strictly wider than the operand.
1448 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {

1449 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);

1450 }

1451

1453

1454

1455

1456

1457

1458 bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,

1460 return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);

1461 }

1462

1464 return verifyExtOp(*this);

1465 }

1466

1467

1468

1469

1470

1471 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {

1472 if (matchPattern(getOperand(), m_Oparith::ExtUIOp()) ||

1473 matchPattern(getOperand(), m_Oparith::ExtSIOp())) {

1477

1478

1479 if (llvm::cast(srcType).getWidth() >

1480 llvm::cast(dstType).getWidth()) {

1481 setOperand(src);

1482 return getResult();

1483 }

1484

1485

1486

1487 if (srcType == dstType)

1488 return src;

1489 }

1490

1491

1492 if (matchPattern(getOperand(), m_Oparith::TruncIOp())) {

1493 setOperand(getOperand().getDefiningOp()->getOperand(0));

1494 return getResult();

1495 }

1496

1498 unsigned bitWidth = llvm::cast(resType).getWidth();

1499 return constFoldCastOp<IntegerAttr, IntegerAttr>(

1500 adaptor.getOperands(), getType(),

1501 [bitWidth](const APInt &a, bool &castStatus) {

1502 return a.trunc(bitWidth);

1503 });

1504 }

1505

/// Cast-compatibility check for TruncIOp. Delegates to checkWidthChangeCast
/// with IntegerType as the accepted element type and std::less as the width
/// comparator; that helper's final comparison is
/// comparator(result bit width, operand bit width), so the result must be
/// strictly narrower than the operand.
1506 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {

1507 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);

1508 }

1509

1512 patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,

1513 TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(

1514 context);

1515 }

1516

1518 return verifyTruncateOp(*this);

1519 }

1520

1521

1522

1523

1524

1525

1526

1527 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {

1529 if (auto extOp = getOperand().getDefiningOparith::ExtFOp()) {

1530 Value src = extOp.getIn();

1532 auto intermediateType =

1534

1535 if (llvm::APFloatBase::isRepresentableBy(

1536 srcType.getFloatSemantics(),

1537 intermediateType.getFloatSemantics())) {

1538

1539 if (srcType.getWidth() > resElemType.getWidth()) {

1540 setOperand(src);

1541 return getResult();

1542 }

1543

1544

1545 if (srcType == resElemType)

1546 return src;

1547 }

1548 }

1549

1550 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();

1551 return constFoldCastOp<FloatAttr, FloatAttr>(

1552 adaptor.getOperands(), getType(),

1553 [this, &targetSemantics](const APFloat &a, bool &castStatus) {

1554 RoundingMode roundingMode =

1555 getRoundingmode().value_or(RoundingMode::to_nearest_even);

1556 llvm::RoundingMode llvmRoundingMode =

1558 FailureOr result =

1560 if (failed(result)) {

1561 castStatus = false;

1562 return a;

1563 }

1564 return *result;

1565 });

1566 }

1567

1570 patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);

1571 }

1572

/// Cast-compatibility check for TruncFOp. Delegates to checkWidthChangeCast
/// with FloatType as the accepted element type and std::less as the width
/// comparator; that helper's final comparison is
/// comparator(result bit width, operand bit width), so the result must be
/// strictly narrower than the operand.
1573 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {

1574 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);

1575 }

1576

1578 return verifyTruncateOp(*this);

1579 }

1580

1581

1582

1583

1584

1585 bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,

1587 return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);

1588 }

1589

1591 return verifyTruncateOp(*this);

1592 }

1593

1594

1595

1596

1597

1600 patterns.add<AndOfExtUI, AndOfExtSI>(context);

1601 }

1602

1603

1604

1605

1606

1609 patterns.add<OrOfExtUI, OrOfExtSI>(context);

1610 }

1611

1612

1613

1614

1615

1616 template <typename From, typename To>

1619 return false;

1620

1621 auto srcType = getTypeIfLike(inputs.front());

1622 auto dstType = getTypeIfLike(outputs.back());

1623

1624 return srcType && dstType;

1625 }

1626

1627

1628

1629

1630

// Cast-compatibility hook for arith.uitofp: forwards `inputs` and `outputs`
// to checkIntFloatCast instantiated with <IntegerType, FloatType>.
// NOTE(review): presumably this accepts integer-like operands and float-like
// results -- confirm against checkIntFloatCast.
1631 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {

1632 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);

1633 }

1634

// Constant folder for arith.uitofp. Delegates to constFoldCastOp, converting
// each constant integer operand value to a float of the result element type.
// NOTE(review): the declaration of `resEleType` (listing line 1636) and the
// second APFloat constructor argument (listing line 1642) are absent from this
// listing -- consult the original file before relying on those details.
1635 OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {

1637 return constFoldCastOp<IntegerAttr, FloatAttr>(

1638 adaptor.getOperands(), getType(),

1639 [&resEleType](const APInt &a, bool &castStatus) {

1640 FloatType floatTy = llvm::cast(resEleType);

1641 APFloat apf(floatTy.getFloatSemantics(),

// `false` selects the unsigned interpretation of `a`; rounding is
// round-to-nearest, ties-to-even. `castStatus` is not written here.
1643 apf.convertFromAPInt(a, false,

1644 APFloat::rmNearestTiesToEven);

1645 return apf;

1646 });

1647 }

1648

1649

1650

1651

1652

// Cast-compatibility hook for arith.sitofp: forwards `inputs` and `outputs`
// to checkIntFloatCast instantiated with <IntegerType, FloatType>.
// NOTE(review): presumably this accepts integer-like operands and float-like
// results -- confirm against checkIntFloatCast.
1653 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {

1654 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);

1655 }

1656

// Constant folder for arith.sitofp. Delegates to constFoldCastOp, converting
// each constant integer operand value to a float of the result element type.
// NOTE(review): the declaration of `resEleType` (listing line 1658) and the
// second APFloat constructor argument (listing line 1664) are absent from this
// listing -- consult the original file before relying on those details.
1657 OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {

1659 return constFoldCastOp<IntegerAttr, FloatAttr>(

1660 adaptor.getOperands(), getType(),

1661 [&resEleType](const APInt &a, bool &castStatus) {

1662 FloatType floatTy = llvm::cast(resEleType);

1663 APFloat apf(floatTy.getFloatSemantics(),

// `true` selects the signed interpretation of `a`; rounding is
// round-to-nearest, ties-to-even. `castStatus` is not written here.
1665 apf.convertFromAPInt(a, true,

1666 APFloat::rmNearestTiesToEven);

1667 return apf;

1668 });

1669 }

1670

1671

1672

1673

1674

// Cast-compatibility hook for arith.fptoui: forwards `inputs` and `outputs`
// to checkIntFloatCast instantiated with <FloatType, IntegerType>.
// NOTE(review): presumably this accepts float-like operands and integer-like
// results -- confirm against checkIntFloatCast.
1675 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {

1676 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);

1677 }

1678

// Constant folder for arith.fptoui. Delegates to constFoldCastOp, converting
// each constant float operand value to an unsigned integer of `bitWidth` bits.
// NOTE(review): the declaration of `resType` (listing line 1680) is absent
// from this listing.
1679 OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {

1681 unsigned bitWidth = llvm::cast(resType).getWidth();

1682 return constFoldCastOp<FloatAttr, IntegerAttr>(

1683 adaptor.getOperands(), getType(),

1684 [&bitWidth](const APFloat &a, bool &castStatus) {

// Output flag required by convertToInteger; its value is never read.
1685 bool ignored;

// Second argument `true` makes the destination APSInt unsigned.
1686 APSInt api(bitWidth, true);

// Round toward zero; the fold is reported as failed (castStatus == false)
// exactly when the conversion returns opInvalidOp.
1687 castStatus = APFloat::opInvalidOp !=

1688 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);

1689 return api;

1690 });

1691 }

1692

1693

1694

1695

1696

// Cast-compatibility hook for arith.fptosi: forwards `inputs` and `outputs`
// to checkIntFloatCast instantiated with <FloatType, IntegerType>.
// NOTE(review): presumably this accepts float-like operands and integer-like
// results -- confirm against checkIntFloatCast.
1697 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {

1698 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);

1699 }

1700

// Constant folder for arith.fptosi. Delegates to constFoldCastOp, converting
// each constant float operand value to a signed integer of `bitWidth` bits.
// NOTE(review): the declaration of `resType` (listing line 1702) is absent
// from this listing.
1701 OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {

1703 unsigned bitWidth = llvm::cast(resType).getWidth();

1704 return constFoldCastOp<FloatAttr, IntegerAttr>(

1705 adaptor.getOperands(), getType(),

1706 [&bitWidth](const APFloat &a, bool &castStatus) {

// Output flag required by convertToInteger; its value is never read.
1707 bool ignored;

// Second argument `false` makes the destination APSInt signed.
1708 APSInt api(bitWidth, false);

// Round toward zero; the fold is reported as failed (castStatus == false)
// exactly when the conversion returns opInvalidOp.
1709 castStatus = APFloat::opInvalidOp !=

1710 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);

1711 return api;

1712 });

1713 }

1714

1715

1716

1717

1718

1721 return false;

1722

1723 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());

1724 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());

1725 if (!srcType || !dstType)

1726 return false;

1727

1730 }

1731

1732 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,

1735 }

1736

// Constant folder for arith.index_cast. Delegates to constFoldCastOp,
// sign-extending or truncating each constant operand value to
// `resultBitwidth` bits.
1737 OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {

1738

// Width used unless overridden below.
1739 unsigned resultBitwidth = 64;

// NOTE(review): the statement introducing `intTy` (listing line 1740) is
// absent from this listing; presumably this override applies only when the
// result element type is an integer type -- confirm against the original file.
1741 resultBitwidth = intTy.getWidth();

1742

1743 return constFoldCastOp<IntegerAttr, IntegerAttr>(

1744 adaptor.getOperands(), getType(),

// The cast-status flag is left untouched by this callback.
1745 [resultBitwidth](const APInt &a, bool & ) {

1746 return a.sextOrTrunc(resultBitwidth);

1747 });

1748 }

1749

1750 void arith::IndexCastOp::getCanonicalizationPatterns(

1752 patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);

1753 }

1754

1755

1756

1757

1758

1759 bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,

1762 }

1763

// Constant folder for arith.index_castui. Delegates to constFoldCastOp,
// zero-extending or truncating each constant operand value to
// `resultBitwidth` bits.
1764 OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {

1765

// Width used unless overridden below.
1766 unsigned resultBitwidth = 64;

// NOTE(review): the statement introducing `intTy` (listing line 1767) is
// absent from this listing; presumably this override applies only when the
// result element type is an integer type -- confirm against the original file.
1768 resultBitwidth = intTy.getWidth();

1769

1770 return constFoldCastOp<IntegerAttr, IntegerAttr>(

1771 adaptor.getOperands(), getType(),

// The cast-status flag is left untouched by this callback.
1772 [resultBitwidth](const APInt &a, bool & ) {

1773 return a.zextOrTrunc(resultBitwidth);

1774 });

1775 }

1776

1777 void arith::IndexCastUIOp::getCanonicalizationPatterns(

1779 patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);

1780 }

1781

1782

1783

1784

1785

// Cast-compatibility hook for arith.bitcast.
// NOTE(review): the condition guarding the first `return false` (listing line
// 1787) and the final return statement (listing line 1795) are absent from
// this listing, so the accepting condition cannot be documented from here.
1786 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {

1788 return false;

1789

// Resolve the first input and first output through getTypeIfLikeOrMemRef
// instantiated with <IntegerType, FloatType>; a null result for either one
// rejects the cast.
1790 auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());

1791 auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());

1792 if (!srcType || !dstType)

1793 return false;

1794

1796 }

1797

1798 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {

1799 auto resType = getType();

1800 auto operand = adaptor.getIn();

1801 if (!operand)

1802 return {};

1803

1804

1805 if (auto denseAttr = llvm::dyn_cast_or_null(operand))

1806 return denseAttr.bitcast(llvm::cast(resType).getElementType());

1807

1808 if (llvm::isa(resType))

1809 return {};

1810

1811

1812 if (llvm::isaub::PoisonAttr(operand))

1814

1815

1816 APInt bits = llvm::isa(operand)

1817 ? llvm::cast(operand).getValue().bitcastToAPInt()

1818 : llvm::cast(operand).getValue();

1820 "trying to fold on broken IR: operands have incompatible types");

1821

1822 if (auto resFloatType = llvm::dyn_cast(resType))

1824 APFloat(resFloatType.getFloatSemantics(), bits));

1826 }

1827

1830 patterns.add(context);

1831 }

1832

1833

1834

1835

1836

1837

1838

1840 const APInt &lhs, const APInt &rhs) {

1841 switch (predicate) {

1842 case arith::CmpIPredicate::eq:

1843 return lhs.eq(rhs);

1844 case arith::CmpIPredicate::ne:

1845 return lhs.ne(rhs);

1846 case arith::CmpIPredicate::slt:

1847 return lhs.slt(rhs);

1848 case arith::CmpIPredicate::sle:

1849 return lhs.sle(rhs);

1850 case arith::CmpIPredicate::sgt:

1851 return lhs.sgt(rhs);

1852 case arith::CmpIPredicate::sge:

1853 return lhs.sge(rhs);

1854 case arith::CmpIPredicate::ult:

1855 return lhs.ult(rhs);

1856 case arith::CmpIPredicate::ule:

1857 return lhs.ule(rhs);

1858 case arith::CmpIPredicate::ugt:

1859 return lhs.ugt(rhs);

1860 case arith::CmpIPredicate::uge:

1861 return lhs.uge(rhs);

1862 }

1863 llvm_unreachable("unknown cmpi predicate kind");

1864 }

1865

1866

1868 switch (predicate) {

1869 case arith::CmpIPredicate::eq:

1870 case arith::CmpIPredicate::sle:

1871 case arith::CmpIPredicate::sge:

1872 case arith::CmpIPredicate::ule:

1873 case arith::CmpIPredicate::uge:

1874 return true;

1875 case arith::CmpIPredicate::ne:

1876 case arith::CmpIPredicate::slt:

1877 case arith::CmpIPredicate::sgt:

1878 case arith::CmpIPredicate::ult:

1879 case arith::CmpIPredicate::ugt:

1880 return false;

1881 }

1882 llvm_unreachable("unknown cmpi predicate kind");

1883 }

1884

1886 if (auto intType = llvm::dyn_cast(t)) {

1887 return intType.getWidth();

1888 }

1889 if (auto vectorIntType = llvm::dyn_cast(t)) {

1890 return llvm::cast(vectorIntType.getElementType()).getWidth();

1891 }

1892 return std::nullopt;

1893 }

1894

1895 OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {

1896

1897 if (getLhs() == getRhs()) {

1900 }

1901

1903 if (auto extOp = getLhs().getDefiningOp()) {

1904

1905 std::optional<int64_t> integerWidth =

1907 if (integerWidth && integerWidth.value() == 1 &&

1908 getPredicate() == arith::CmpIPredicate::ne)

1909 return extOp.getOperand();

1910 }

1911 if (auto extOp = getLhs().getDefiningOp()) {

1912

1913 std::optional<int64_t> integerWidth =

1915 if (integerWidth && integerWidth.value() == 1 &&

1916 getPredicate() == arith::CmpIPredicate::ne)

1917 return extOp.getOperand();

1918 }

1919

1920

1922 getPredicate() == arith::CmpIPredicate::ne)

1923 return getLhs();

1924 }

1925

1927

1929 getPredicate() == arith::CmpIPredicate::eq)

1930 return getLhs();

1931 }

1932

1933

1934 if (adaptor.getLhs() && !adaptor.getRhs()) {

1935

1936 using Pred = CmpIPredicate;

1937 const std::pair<Pred, Pred> invPreds[] = {

1938 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},

1939 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},

1940 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},

1941 {Pred::ne, Pred::ne},

1942 };

1943 Pred origPred = getPredicate();

1944 for (auto pred : invPreds) {

1945 if (origPred == pred.first) {

1946 setPredicate(pred.second);

1947 Value lhs = getLhs();

1948 Value rhs = getRhs();

1949 getLhsMutable().assign(rhs);

1950 getRhsMutable().assign(lhs);

1951 return getResult();

1952 }

1953 }

1954 llvm_unreachable("unknown cmpi predicate kind");

1955 }

1956

1957

1958

1959 if (auto lhs = llvm::dyn_cast_if_present(adaptor.getLhs())) {

1960 return constFoldBinaryOp(

1961 adaptor.getOperands(), getI1SameShape(lhs.getType()),

1962 [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {

1963 return APInt(1,

1965 });

1966 }

1967

1968 return {};

1969 }

1970

1973 patterns.insert<CmpIExtSI, CmpIExtUI>(context);

1974 }

1975

1976

1977

1978

1979

1980

1981

1983 const APFloat &lhs, const APFloat &rhs) {

1984 auto cmpResult = lhs.compare(rhs);

1985 switch (predicate) {

1986 case arith::CmpFPredicate::AlwaysFalse:

1987 return false;

1988 case arith::CmpFPredicate::OEQ:

1989 return cmpResult == APFloat::cmpEqual;

1990 case arith::CmpFPredicate::OGT:

1991 return cmpResult == APFloat::cmpGreaterThan;

1992 case arith::CmpFPredicate::OGE:

1993 return cmpResult == APFloat::cmpGreaterThan ||

1994 cmpResult == APFloat::cmpEqual;

1995 case arith::CmpFPredicate::OLT:

1996 return cmpResult == APFloat::cmpLessThan;

1997 case arith::CmpFPredicate::OLE:

1998 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;

1999 case arith::CmpFPredicate::ONE:

2000 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;

2001 case arith::CmpFPredicate::ORD:

2002 return cmpResult != APFloat::cmpUnordered;

2003 case arith::CmpFPredicate::UEQ:

2004 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;

2005 case arith::CmpFPredicate::UGT:

2006 return cmpResult == APFloat::cmpUnordered ||

2007 cmpResult == APFloat::cmpGreaterThan;

2008 case arith::CmpFPredicate::UGE:

2009 return cmpResult == APFloat::cmpUnordered ||

2010 cmpResult == APFloat::cmpGreaterThan ||

2011 cmpResult == APFloat::cmpEqual;

2012 case arith::CmpFPredicate::ULT:

2013 return cmpResult == APFloat::cmpUnordered ||

2014 cmpResult == APFloat::cmpLessThan;

2015 case arith::CmpFPredicate::ULE:

2016 return cmpResult == APFloat::cmpUnordered ||

2017 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;

2018 case arith::CmpFPredicate::UNE:

2019 return cmpResult != APFloat::cmpEqual;

2020 case arith::CmpFPredicate::UNO:

2021 return cmpResult == APFloat::cmpUnordered;

2022 case arith::CmpFPredicate::AlwaysTrue:

2023 return true;

2024 }

2025 llvm_unreachable("unknown cmpf predicate kind");

2026 }

2027

// Constant folder for arith.cmpf. Evaluates the predicate with
// applyCmpPredicate once both operand values are available, and returns an
// empty result otherwise.
// NOTE(review): the cast target type of the two dyn_cast_if_present calls was
// stripped from this listing, and the final return statement (listing line
// 2042) is absent, so how `val` is packaged into the result is not shown here.
2028 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {

2029 auto lhs = llvm::dyn_cast_if_present(adaptor.getLhs());

2030 auto rhs = llvm::dyn_cast_if_present(adaptor.getRhs());

2031

2032

// If one side is a constant NaN, copy it over the other side so that both
// are non-null and the comparison below can still be evaluated even when the
// other operand is not a constant.
2033 if (lhs && lhs.getValue().isNaN())

2034 rhs = lhs;

2035 if (rhs && rhs.getValue().isNaN())

2036 lhs = rhs;

2037

// Nothing to fold unless both sides are now known.
2038 if (!lhs || !rhs)

2039 return {};

2040

2041 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());

2043 }

2044

2046 public:

2048

2050 bool isUnsigned) {

2051 using namespace arith;

2052 switch (pred) {

2053 case CmpFPredicate::UEQ:

2054 case CmpFPredicate::OEQ:

2055 return CmpIPredicate::eq;

2056 case CmpFPredicate::UGT:

2057 case CmpFPredicate::OGT:

2058 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;

2059 case CmpFPredicate::UGE:

2060 case CmpFPredicate::OGE:

2061 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;

2062 case CmpFPredicate::ULT:

2063 case CmpFPredicate::OLT:

2064 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;

2065 case CmpFPredicate::ULE:

2066 case CmpFPredicate::OLE:

2067 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;

2068 case CmpFPredicate::UNE:

2069 case CmpFPredicate::ONE:

2070 return CmpIPredicate::ne;

2071 default:

2072 llvm_unreachable("Unexpected predicate!");

2073 }

2074 }

2075

2078 FloatAttr flt;

2080 return failure();

2081

2082 const APFloat &rhs = flt.getValue();

2083

2084

2085 if (rhs.isNaN())

2086 return failure();

2087

2088

2089

2090 FloatType floatTy = llvm::cast(op.getRhs().getType());

2091 int mantissaWidth = floatTy.getFPMantissaWidth();

2092 if (mantissaWidth <= 0)

2093 return failure();

2094

2095 bool isUnsigned;

2097

2098 if (auto si = op.getLhs().getDefiningOp()) {

2099 isUnsigned = false;

2100 intVal = si.getIn();

2101 } else if (auto ui = op.getLhs().getDefiningOp()) {

2102 isUnsigned = true;

2103 intVal = ui.getIn();

2104 } else {

2105 return failure();

2106 }

2107

2108

2109

2110 auto intTy = llvm::cast(intVal.getType());

2111 auto intWidth = intTy.getWidth();

2112

2113

2114 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);

2115

2116

2117

2118

2119 if ((int)intWidth > mantissaWidth) {

2120

2121 int exponent = ilogb(rhs);

2122 if (exponent == APFloat::IEK_Inf) {

2123 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));

2124 if (maxExponent < (int)valueBits) {

2125

2126 return failure();

2127 }

2128 } else {

2129

2130

2131 if (mantissaWidth <= exponent && exponent <= (int)valueBits) {

2132

2133 return failure();

2134 }

2135 }

2136 }

2137

2138

2139 CmpIPredicate pred;

2140 switch (op.getPredicate()) {

2141 case CmpFPredicate::ORD:

2142

2144 1);

2145 return success();

2146 case CmpFPredicate::UNO:

2147

2149 1);

2150 return success();

2151 default:

2153 break;

2154 }

2155

2156 if (!isUnsigned) {

2157

2158

2159 APFloat signedMax(rhs.getSemantics());

2160 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,

2161 APFloat::rmNearestTiesToEven);

2162 if (signedMax < rhs) {

2163 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||

2164 pred == CmpIPredicate::sle)

2166 1);

2167 else

2169 1);

2170 return success();

2171 }

2172 } else {

2173

2174

2175 APFloat unsignedMax(rhs.getSemantics());

2176 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,

2177 APFloat::rmNearestTiesToEven);

2178 if (unsignedMax < rhs) {

2179 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||

2180 pred == CmpIPredicate::ule)

2182 1);

2183 else

2185 1);

2186 return success();

2187 }

2188 }

2189

2190 if (!isUnsigned) {

2191

2192 APFloat signedMin(rhs.getSemantics());

2193 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,

2194 APFloat::rmNearestTiesToEven);

2195 if (signedMin > rhs) {

2196 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||

2197 pred == CmpIPredicate::sge)

2199 1);

2200 else

2202 1);

2203 return success();

2204 }

2205 } else {

2206

2207 APFloat unsignedMin(rhs.getSemantics());

2208 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,

2209 APFloat::rmNearestTiesToEven);

2210 if (unsignedMin > rhs) {

2211 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||

2212 pred == CmpIPredicate::uge)

2214 1);

2215 else

2217 1);

2218 return success();

2219 }

2220 }

2221

2222

2223

2224

2225

2226 bool ignored;

2227 APSInt rhsInt(intWidth, isUnsigned);

2228 if (APFloat::opInvalidOp ==

2229 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {

2230

2231

2232 return failure();

2233 }

2234

2235 if (!rhs.isZero()) {

2236 APFloat apf(floatTy.getFloatSemantics(),

2238 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);

2239

2240 bool equal = apf == rhs;

2241 if (!equal) {

2242

2243

2244

2245 switch (pred) {

2246 case CmpIPredicate::ne:

2248 1);

2249 return success();

2250 case CmpIPredicate::eq:

2252 1);

2253 return success();

2254 case CmpIPredicate::ule:

2255

2256

2257 if (rhs.isNegative()) {

2259 1);

2260 return success();

2261 }

2262 break;

2263 case CmpIPredicate::sle:

2264

2265

2266 if (rhs.isNegative())

2267 pred = CmpIPredicate::slt;

2268 break;

2269 case CmpIPredicate::ult:

2270

2271

2272 if (rhs.isNegative()) {

2274 1);

2275 return success();

2276 }

2277 pred = CmpIPredicate::ule;

2278 break;

2279 case CmpIPredicate::slt:

2280

2281

2282 if (!rhs.isNegative())

2283 pred = CmpIPredicate::sle;

2284 break;

2285 case CmpIPredicate::ugt:

2286

2287

2288 if (rhs.isNegative()) {

2290 1);

2291 return success();

2292 }

2293 break;

2294 case CmpIPredicate::sgt:

2295

2296

2297 if (rhs.isNegative())

2298 pred = CmpIPredicate::sge;

2299 break;

2300 case CmpIPredicate::uge:

2301

2302

2303 if (rhs.isNegative()) {

2305 1);

2306 return success();

2307 }

2308 pred = CmpIPredicate::ugt;

2309 break;

2310 case CmpIPredicate::sge:

2311

2312

2313 if (!rhs.isNegative())

2314 pred = CmpIPredicate::sgt;

2315 break;

2316 }

2317 }

2318 }

2319

2320

2321

2323 op, pred, intVal,

2324 rewriter.create(

2325 op.getLoc(), intVal.getType(),

2327 return success();

2328 }

2329 };

2330

2334 }

2335

2336

2337

2338

2339

2340

2343

2346

2347 if (!llvm::isa(op.getType()) || op.getType().isInteger(1))

2348 return failure();

2349

2350

2354 op.getCondition());

2355 return success();

2356 }

2357

2358

2362 op, op.getType(),

2363 rewriter.createarith::XOrIOp(

2364 op.getLoc(), op.getCondition(),

2365 rewriter.createarith::ConstantIntOp(

2366 op.getLoc(), 1, op.getCondition().getType())));

2367 return success();

2368 }

2369

2370 return failure();

2371 }

2372 };

2373

2374 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,

2376 results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,

2378 }

2379

2380 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {

2381 Value trueVal = getTrueValue();

2382 Value falseVal = getFalseValue();

2383 if (trueVal == falseVal)

2384 return trueVal;

2385

2386 Value condition = getCondition();

2387

2388

2390 return trueVal;

2391

2392

2394 return falseVal;

2395

2396

2397 if (isa_and_nonnullub::PoisonAttr(adaptor.getTrueValue()))

2398 return falseVal;

2399

2400 if (isa_and_nonnullub::PoisonAttr(adaptor.getFalseValue()))

2401 return trueVal;

2402

2403

2404 if (getType().isSignlessInteger(1) &&

2407 return condition;

2408

2409 if (auto cmp = dyn_cast_or_nullarith::CmpIOp(condition.getDefiningOp())) {

2410 auto pred = cmp.getPredicate();

2411 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {

2412 auto cmpLhs = cmp.getLhs();

2413 auto cmpRhs = cmp.getRhs();

2414

2415

2416

2417

2418

2419

2420

2421 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||

2422 (cmpRhs == trueVal && cmpLhs == falseVal))

2423 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;

2424 }

2425 }

2426

2427

2428

2429 if (auto cond =

2430 llvm::dyn_cast_if_present(adaptor.getCondition())) {

2431 if (auto lhs =

2432 llvm::dyn_cast_if_present(adaptor.getTrueValue())) {

2433 if (auto rhs =

2434 llvm::dyn_cast_if_present(adaptor.getFalseValue())) {

2436 results.reserve(static_cast<size_t>(cond.getNumElements()));

2437 auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),

2438 cond.value_end<BoolAttr>());

2439 auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),

2441 auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),

2443

2444 for (auto [condVal, lhsVal, rhsVal] :

2445 llvm::zip_equal(condVals, lhsVals, rhsVals))

2446 results.push_back(condVal.getValue() ? lhsVal : rhsVal);

2447

2449 }

2450 }

2451 }

2452

2453 return nullptr;

2454 }

2455

2457 Type conditionType, resultType;

2459 if (parser.parseOperandList(operands, 3) ||

2462 return failure();

2463

2464

2466 conditionType = resultType;

2467 if (parser.parseType(resultType))

2468 return failure();

2469 } else {

2471 }

2472

2473 result.addTypes(resultType);

2475 {conditionType, resultType, resultType},

2477 }

2478

2480 p << " " << getOperands();

2482 p << " : ";

2483 if (ShapedType condType =

2484 llvm::dyn_cast(getCondition().getType()))

2485 p << condType << ", ";

2487 }

2488

2490 Type conditionType = getCondition().getType();

2492 return success();

2493

2494

2495

2497 if (!llvm::isa<TensorType, VectorType>(resultType))

2498 return emitOpError() << "expected condition to be a signless i1, but got "

2499 << conditionType;

2501 if (conditionType != shapedConditionType) {

2502 return emitOpError() << "expected condition type to have the same shape "

2503 "as the result type, expected "

2504 << shapedConditionType << ", but got "

2505 << conditionType;

2506 }

2507 return success();

2508 }

2509

2510

2511

2512

// Constant folder for arith.shli (left shift).
2513 OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {

2514

// NOTE(review): the condition guarding this early return (listing line 2515)
// is absent from this listing; presumably it matches a zero shift amount --
// confirm against the original file.
2516 return getLhs();

2517

// Stays false unless the folding callback runs and sees a shift amount that
// is strictly less than its own bit width.
2518 bool bounded = false;

2519 auto result = constFoldBinaryOp(

2520 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {

2521 bounded = b.ult(b.getBitWidth());

2522 return a.shl(b);

2523 });

// Discard the folded value (return a null Attribute) when `bounded` is false.
2524 return bounded ? result : Attribute();

2525 }

2526

2527

2528

2529

2530

// Constant folder for arith.shrui (logical, zero-filling right shift).
2531 OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {

2532

// NOTE(review): the condition guarding this early return (listing line 2533)
// is absent from this listing; presumably it matches a zero shift amount --
// confirm against the original file.
2534 return getLhs();

2535

// Stays false unless the folding callback runs and sees a shift amount that
// is strictly less than its own bit width.
2536 bool bounded = false;

2537 auto result = constFoldBinaryOp(

2538 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {

2539 bounded = b.ult(b.getBitWidth());

2540 return a.lshr(b);

2541 });

// Discard the folded value (return a null Attribute) when `bounded` is false.
2542 return bounded ? result : Attribute();

2543 }

2544

2545

2546

2547

2548

// Constant folder for arith.shrsi (arithmetic, sign-filling right shift).
2549 OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {

2550

// NOTE(review): the condition guarding this early return (listing line 2551)
// is absent from this listing; presumably it matches a zero shift amount --
// confirm against the original file.
2552 return getLhs();

2553

// Stays false unless the folding callback runs and sees a shift amount that
// is strictly less than its own bit width.
2554 bool bounded = false;

2555 auto result = constFoldBinaryOp(

2556 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {

2557 bounded = b.ult(b.getBitWidth());

2558 return a.ashr(b);

2559 });

// Discard the folded value (return a null Attribute) when `bounded` is false.
2560 return bounded ? result : Attribute();

2561 }

2562

2563

2564

2565

2566

2567

2570 bool useOnlyFiniteValue) {

2571 switch (kind) {

2572 case AtomicRMWKind::maximumf: {

2573 const llvm::fltSemantics &semantic =

2574 llvm::cast(resultType).getFloatSemantics();

2575 APFloat identity = useOnlyFiniteValue

2576 ? APFloat::getLargest(semantic, true)

2577 : APFloat::getInf(semantic, true);

2578 return builder.getFloatAttr(resultType, identity);

2579 }

2580 case AtomicRMWKind::maxnumf: {

2581 const llvm::fltSemantics &semantic =

2582 llvm::cast(resultType).getFloatSemantics();

2583 APFloat identity = APFloat::getNaN(semantic, true);

2584 return builder.getFloatAttr(resultType, identity);

2585 }

2586 case AtomicRMWKind::addf:

2587 case AtomicRMWKind::addi:

2588 case AtomicRMWKind::maxu:

2589 case AtomicRMWKind::ori:

2591 case AtomicRMWKind::andi:

2593 resultType,

2594 APInt::getAllOnes(llvm::cast(resultType).getWidth()));

2595 case AtomicRMWKind::maxs:

2597 resultType, APInt::getSignedMinValue(

2598 llvm::cast(resultType).getWidth()));

2599 case AtomicRMWKind::minimumf: {

2600 const llvm::fltSemantics &semantic =

2601 llvm::cast(resultType).getFloatSemantics();

2602 APFloat identity = useOnlyFiniteValue

2603 ? APFloat::getLargest(semantic, false)

2604 : APFloat::getInf(semantic, false);

2605

2606 return builder.getFloatAttr(resultType, identity);

2607 }

2608 case AtomicRMWKind::minnumf: {

2609 const llvm::fltSemantics &semantic =

2610 llvm::cast(resultType).getFloatSemantics();

2611 APFloat identity = APFloat::getNaN(semantic, false);

2612 return builder.getFloatAttr(resultType, identity);

2613 }

2614 case AtomicRMWKind::mins:

2616 resultType, APInt::getSignedMaxValue(

2617 llvm::cast(resultType).getWidth()));

2618 case AtomicRMWKind::minu:

2620 resultType,

2621 APInt::getMaxValue(llvm::cast(resultType).getWidth()));

2622 case AtomicRMWKind::muli:

2624 case AtomicRMWKind::mulf:

2626

2627 default:

2628 (void)emitOptionalError(loc, "Reduction operation type not supported");

2629 break;

2630 }

2631 return nullptr;

2632 }

2633

2634

2636 std::optional maybeKind =

2638

2639 .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })

2640 .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })

2641 .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })

2642 .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })

2643 .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })

2644 .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })

2645

2646 .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })

2647 .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })

2648 .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })

2649 .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })

2650 .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })

2651 .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })

2652 .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })

2653 .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })

2654 .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })

2655 .Default([](Operation *op) { return std::nullopt; });

2656 if (!maybeKind) {

2657 return std::nullopt;

2658 }

2659

2660 bool useOnlyFiniteValue = false;

2661 auto fmfOpInterface = dyn_cast(op);

2662 if (fmfOpInterface) {

2663 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();

2664 useOnlyFiniteValue =

2665 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);

2666 }

2667

2668

2671

2673 useOnlyFiniteValue);

2674 }

2675

2676

2679 bool useOnlyFiniteValue) {

2680 auto attr =

2682 return builder.createarith::ConstantOp(loc, attr);

2683 }

2684

2685

2686

2689 switch (op) {

2690 case AtomicRMWKind::addf:

2691 return builder.createarith::AddFOp(loc, lhs, rhs);

2692 case AtomicRMWKind::addi:

2693 return builder.createarith::AddIOp(loc, lhs, rhs);

2694 case AtomicRMWKind::mulf:

2695 return builder.createarith::MulFOp(loc, lhs, rhs);

2696 case AtomicRMWKind::muli:

2697 return builder.createarith::MulIOp(loc, lhs, rhs);

2698 case AtomicRMWKind::maximumf:

2699 return builder.createarith::MaximumFOp(loc, lhs, rhs);

2700 case AtomicRMWKind::minimumf:

2701 return builder.createarith::MinimumFOp(loc, lhs, rhs);

2702 case AtomicRMWKind::maxnumf:

2703 return builder.createarith::MaxNumFOp(loc, lhs, rhs);

2704 case AtomicRMWKind::minnumf:

2705 return builder.createarith::MinNumFOp(loc, lhs, rhs);

2706 case AtomicRMWKind::maxs:

2707 return builder.createarith::MaxSIOp(loc, lhs, rhs);

2708 case AtomicRMWKind::mins:

2709 return builder.createarith::MinSIOp(loc, lhs, rhs);

2710 case AtomicRMWKind::maxu:

2711 return builder.createarith::MaxUIOp(loc, lhs, rhs);

2712 case AtomicRMWKind::minu:

2713 return builder.createarith::MinUIOp(loc, lhs, rhs);

2714 case AtomicRMWKind::ori:

2715 return builder.createarith::OrIOp(loc, lhs, rhs);

2716 case AtomicRMWKind::andi:

2717 return builder.createarith::AndIOp(loc, lhs, rhs);

2718

2719 default:

2720 (void)emitOptionalError(loc, "Reduction operation type not supported");

2721 break;

2722 }

2723 return nullptr;

2724 }

2725

2726

2727

2728

2729

2730 #define GET_OP_CLASSES

2731 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"

2732

2733

2734

2735

2736

2737 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"

static Speculation::Speculatability getDivUISpeculatability(Value divisor)

Returns whether an unsigned division by divisor is speculatable.

static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)

Validate a cast that changes the width of a type.

static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)

static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)

static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)

Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).

static Type getTypeIfLike(Type type)

Get allowed underlying types for vectors and tensors.

static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)

Returns true if the predicate is true for two equal operands.

static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)

Fold (a * b) / b -> a

static bool hasSameEncoding(Type typeA, Type typeB)

Return false if both types are ranked tensor with mismatching encoding.

static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)

Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...

static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)

static Speculation::Speculatability getDivSISpeculatability(Value divisor)

Returns whether a signed division by divisor is speculatable.

static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)

static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)

static Attribute getBoolAttribute(Type type, bool value)

static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)

Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...

static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)

static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)

static LogicalResult verifyExtOp(Op op)

static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)

static int64_t getScalarOrElementWidth(Type type)

static Value foldAndIofAndI(arith::AndIOp op)

Fold and(a, and(a, b)) to and(a, b)

static Type getTypeIfLikeOrMemRef(Type type)

Get allowed underlying types for vectors, tensors, and memrefs.

static Type getI1SameShape(Type type)

Return the type of the same shape (scalar, vector or tensor) containing i1.

static std::optional< int64_t > getIntegerWidth(Type t)

static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)

std::tuple< Types... > * type_list

static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)

static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)

static LogicalResult verifyTruncateOp(Op op)

static Value getZero(OpBuilder &b, Location loc, Type elementType)

Get zero value for an element type.

static MLIRContext * getContext(OpFoldResult val)

union mlir::linalg::@1203::ArityGroupAndKind::Kind kind

static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)

Contracts the specified cycle in the given graph in-place.

static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)

static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)

Walks the given type hierarchy with the given indices, potentially down to component granularity,...

LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override

static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)

virtual Builder & getBuilder() const =0

Return a builder which provides useful access to MLIRContext, global objects like types and attribute...

virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0

Parse a named dictionary into 'result' if it is present.

virtual ParseResult parseColonType(Type &result)=0

Parse a colon followed by a type.

virtual ParseResult parseOptionalComma()=0

Parse a `,` token if present.

virtual SMLoc getNameLoc() const =0

Return the location of the original name token.

virtual ParseResult parseType(Type &result)=0

Parse a type.

Attributes are known-constant values of operations.

Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.

static BoolAttr get(MLIRContext *context, bool value)

This class is a general helper class for creating context-global objects like types,...

IntegerAttr getIndexAttr(int64_t value)

IntegerAttr getIntegerAttr(Type type, int64_t value)

FloatAttr getFloatAttr(Type type, double value)

IntegerType getIntegerType(unsigned width)

TypedAttr getZeroAttr(Type type)

static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)

Constructs a dense elements attribute from an array of element values.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...

ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)

Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...

virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0

Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...

This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...

virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0

If the specified operation has attributes, print out an attribute dictionary with their values.

This class helps build Operations.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

This class represents a single result from folding an operation.

InFlightDiagnostic emitError(const Twine &message={})

Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...

This provides public APIs that all operations should have.

Operation is the basic unit of execution within MLIR.

Value getOperand(unsigned idx)

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

MLIRContext * getContext()

Return the context this operation is associated with.

Location getLoc()

The source location the operation was defined or derived from.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)

Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

MLIRContext * getContext() const

Return the MLIRContext in which this type was uniqued.

bool isSignlessInteger() const

Return true if this is a signless integer type (with the specified width).

bool isIntOrFloat() const

Return true if this is an integer (of any signedness) or a float type.

unsigned getIntOrFloatBitWidth() const

Return the bit width of an integer or a float type, assert failure on other types.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

static void build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type)

Build a constant float op that produces a float of the specified type.

static bool classof(Operation *op)

static void build(OpBuilder &builder, OperationState &result, int64_t value)

Build a constant int op that produces an index.

static bool classof(Operation *op)

Specialization of arith.constant op that returns an integer value.

static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)

Build a constant int op that produces an integer of the specified width.

static bool classof(Operation *op)

Speculatability

This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...

constexpr auto Speculatable

constexpr auto NotSpeculatable

std::optional< TypedAttr > getNeutralElement(Operation *op)

Return the identity numeric value associated with the given op.

bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)

Compute lhs pred rhs, where pred is one of the known integer comparison predicates.

TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)

Returns the identity value attribute associated with an AtomicRMWKind op.

Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)

Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...

Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)

Returns the identity value associated with an AtomicRMWKind op.

arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)

Invert an integer comparison predicate.

QueryRef parse(llvm::StringRef line, const QuerySession &qs)

Include the generated interface declarations.

bool matchPattern(Value value, const Pattern &pattern)

Entry point for matching a pattern over a Value.

detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)

Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...

detail::constant_float_predicate_matcher m_NaNFloat()

Matches a constant scalar / vector splat / tensor splat float NaN.

LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)

Returns success if the given two arrays have the same number of elements and each pair wise entries h...

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()

Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...

LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)

Overloads of the above emission functions that take an optionally null location.

detail::constant_float_predicate_matcher m_PosZeroFloat()

Matches a constant scalar / vector splat / tensor splat float positive zero.

detail::constant_int_predicate_matcher m_Zero()

Matches a constant scalar / vector splat / tensor splat integer zero.

Type getElementTypeOrSelf(Type type)

Return the element type or return the type itself.

const FrozenRewritePatternSet & patterns

detail::constant_int_predicate_matcher m_One()

Matches a constant scalar / vector splat / tensor splat integer one.

detail::constant_float_predicate_matcher m_NegInfFloat()

Matches a constant scalar / vector splat / tensor splat float negative infinity.

detail::constant_float_predicate_matcher m_NegZeroFloat()

Matches a constant scalar / vector splat / tensor splat float negative zero.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()

Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...

detail::constant_op_matcher m_Constant()

Matches a constant foldable operation.

detail::constant_float_predicate_matcher m_PosInfFloat()

Matches a constant scalar / vector splat / tensor splat float positive infinity.

LogicalResult verify(Operation *op, bool verifyRecursively=true)

Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...

detail::constant_float_predicate_matcher m_OneFloat()

Matches a constant scalar / vector splat / tensor splat float ones.

detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()

Matches a constant scalar / vector splat / tensor splat integer or an unsigned integer range that does...

LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...

This represents an operation in an abstracted form, suitable for use with the builder APIs.

SmallVector< Value, 4 > operands

void addTypes(ArrayRef< Type > newTypes)