MLIR: lib/Dialect/Arith/Transforms/EmulateWideInt.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

10

20 #include "llvm/ADT/APFloat.h"

21 #include "llvm/ADT/APInt.h"

22 #include "llvm/Support/FormatVariadic.h"

23 #include "llvm/Support/MathExtras.h"

24 #include

25

27 #define GEN_PASS_DEF_ARITHEMULATEWIDEINT

28 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"

29 }

30

31 using namespace mlir;

32

33

34

35

36

37

38

39

40

41 static std::pair<APInt, APInt> getHalves(const APInt &value,

42 unsigned newBitWidth) {

43 APInt low = value.extractBits(newBitWidth, 0);

44 APInt high = value.extractBits(newBitWidth, newBitWidth);

45 return {std::move(low), std::move(high)};

46 }

47

48

49

50

51

52

54 if (type.getShape().size() == 1)

55 return type.getElementType();

56

57 auto newShape = to_vector(type.getShape());

58 newShape.back() = 1;

60 }

61

62

63

64

65

66

69 int64_t lastOffset) {

71 assert(lastOffset < shape.back() && "Offset out of bounds");

72

73

74 if (shape.size() == 1)

75 return rewriter.createvector::ExtractOp(loc, input, lastOffset);

76

78 offsets.back() = lastOffset;

79 auto sizes = llvm::to_vector(shape);

80 sizes.back() = 1;

82

83 return rewriter.createvector::ExtractStridedSliceOp(loc, input, offsets,

84 sizes, strides);

85 }

86

87

88

89 static std::pair<Value, Value>

94 }

95

96

97

100 auto vecTy = dyn_cast(input.getType());

101 if (!vecTy)

102 return input;

103

104

106 assert(shape.size() >= 2 && "Expected vector with at list two dims");

107 assert(shape.back() == 1 && "Expected the last vector dim to be x1");

108

109 auto newVecTy = VectorType::get(shape.drop_back(), vecTy.getElementType());

110 return rewriter.createvector::ShapeCastOp(loc, newVecTy, input);

111 }

112

113

114

117 auto vecTy = dyn_cast(input.getType());

118 if (!vecTy)

119 return input;

120

121

122 auto newShape = llvm::to_vector(vecTy.getShape());

123 newShape.push_back(1);

124 auto newTy = VectorType::get(newShape, vecTy.getElementType());

125 return rewriter.createvector::ShapeCastOp(loc, newTy, input);

126 }

127

128

129

130

133 int64_t lastOffset) {

135 assert(lastOffset < shape.back() && "Offset out of bounds");

136

137

138 if (isa(source.getType()))

139 return rewriter.createvector::InsertOp(loc, source, dest, lastOffset);

140

142 offsets.back() = lastOffset;

144 return rewriter.createvector::InsertStridedSliceOp(loc, source, dest,

145 offsets, strides);

146 }

147

148

149

150

151

152

153

155 Location loc, VectorType resultType,

158 (void)resultShape;

159 assert(!resultShape.empty() && "Result expected to have dimensions");

160 assert(resultShape.back() == static_cast<int64_t>(resultComponents.size()) &&

161 "Wrong number of result components");

162

164 for (auto [i, component] : llvm::enumerate(resultComponents))

165 resultVec = insertLastDimSlice(rewriter, loc, component, resultVec, i);

166

167 return resultVec;

168 }

169

170 namespace {

171

172

173

174

177

178 LogicalResult

179 matchAndRewrite(arith::ConstantOp op, OpAdaptor,

181 Type oldType = op.getType();

182 auto newType = getTypeConverter()->convertType(oldType);

183 if (!newType)

185 op, llvm::formatv("unsupported type: {0}", op.getType()));

186

187 unsigned newBitWidth = newType.getElementTypeBitWidth();

188 Attribute oldValue = op.getValueAttr();

189

190 if (auto intAttr = dyn_cast(oldValue)) {

191 auto [low, high] = getHalves(intAttr.getValue(), newBitWidth);

194 return success();

195 }

196

197 if (auto splatAttr = dyn_cast(oldValue)) {

198 auto [low, high] =

199 getHalves(splatAttr.getSplatValue(), newBitWidth);

200 int64_t numSplatElems = splatAttr.getNumElements();

202 values.reserve(numSplatElems * 2);

203 for (int64_t i = 0; i < numSplatElems; ++i) {

204 values.push_back(low);

205 values.push_back(high);

206 }

207

210 return success();

211 }

212

213 if (auto elemsAttr = dyn_cast(oldValue)) {

214 int64_t numElems = elemsAttr.getNumElements();

216 values.reserve(numElems * 2);

217 for (const APInt &origVal : elemsAttr.getValues()) {

218 auto [low, high] = getHalves(origVal, newBitWidth);

219 values.push_back(std::move(low));

220 values.push_back(std::move(high));

221 }

222

225 return success();

226 }

227

229 "unhandled constant attribute");

230 }

231 };

232

233

234

235

236

239

240 LogicalResult

241 matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,

244 auto newTy = getTypeConverter()->convertType(op.getType());

245 if (!newTy)

247 loc, llvm::formatv("unsupported type: {0}", op.getType()));

248

250

251 auto [lhsElem0, lhsElem1] =

253 auto [rhsElem0, rhsElem1] =

255

256 auto lowSum =

257 rewriter.createarith::AddUIExtendedOp(loc, lhsElem0, rhsElem0);

258 Value overflowVal =

259 rewriter.createarith::ExtUIOp(loc, newElemTy, lowSum.getOverflow());

260

261 Value high0 = rewriter.createarith::AddIOp(loc, overflowVal, lhsElem1);

262 Value high = rewriter.createarith::AddIOp(loc, high0, rhsElem1);

263

264 Value resultVec =

266 rewriter.replaceOp(op, resultVec);

267 return success();

268 }

269 };

270

271

272

273

274

275

276 template

280

281 LogicalResult

282 matchAndRewrite(BinaryOp op, OpAdaptor adaptor,

285 auto newTy = this->getTypeConverter()->template convertType(

286 op.getType());

287 if (!newTy)

289 loc, llvm::formatv("unsupported type: {0}", op.getType()));

290

291 auto [lhsElem0, lhsElem1] =

293 auto [rhsElem0, rhsElem1] =

295

296 Value resElem0 = rewriter.create(loc, lhsElem0, rhsElem0);

297 Value resElem1 = rewriter.create(loc, lhsElem1, rhsElem1);

298 Value resultVec =

300 rewriter.replaceOp(op, resultVec);

301 return success();

302 }

303 };

304

305

306

307

308

309

310

311 static arith::CmpIPredicate toUnsignedPredicate(arith::CmpIPredicate pred) {

312 using P = arith::CmpIPredicate;

313 switch (pred) {

314 case P::sge:

315 return P::uge;

316 case P::sgt:

317 return P::ugt;

318 case P::sle:

319 return P::ule;

320 case P::slt:

321 return P::ult;

322 default:

323 return pred;

324 }

325 }

326

329

330 LogicalResult

331 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,

334 auto inputTy =

335 getTypeConverter()->convertType(op.getLhs().getType());

336 if (!inputTy)

338 loc, llvm::formatv("unsupported type: {0}", op.getType()));

339

340 arith::CmpIPredicate highPred = adaptor.getPredicate();

341 arith::CmpIPredicate lowPred = toUnsignedPredicate(highPred);

342

343 auto [lhsElem0, lhsElem1] =

345 auto [rhsElem0, rhsElem1] =

347

349 rewriter.createarith::CmpIOp(loc, lowPred, lhsElem0, rhsElem0);

351 rewriter.createarith::CmpIOp(loc, highPred, lhsElem1, rhsElem1);

352

353 Value cmpResult{};

354 switch (highPred) {

355 case arith::CmpIPredicate::eq: {

356 cmpResult = rewriter.createarith::AndIOp(loc, lowCmp, highCmp);

357 break;

358 }

359 case arith::CmpIPredicate::ne: {

360 cmpResult = rewriter.createarith::OrIOp(loc, lowCmp, highCmp);

361 break;

362 }

363 default: {

364

365 Value highEq = rewriter.createarith::CmpIOp(

366 loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1);

367 cmpResult =

368 rewriter.createarith::SelectOp(loc, highEq, lowCmp, highCmp);

369 break;

370 }

371 }

372

373 assert(cmpResult && "Unhandled case");

375 return success();

376 }

377 };

378

379

380

381

382

385

386 LogicalResult

387 matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor,

390 auto newTy = getTypeConverter()->convertType(op.getType());

391 if (!newTy)

393 loc, llvm::formatv("unsupported type: {0}", op.getType()));

394

395 auto [lhsElem0, lhsElem1] =

397 auto [rhsElem0, rhsElem1] =

399

400

401

402

403 auto mulLowLow =

404 rewriter.createarith::MulUIExtendedOp(loc, lhsElem0, rhsElem0);

405 Value mulLowHi = rewriter.createarith::MulIOp(loc, lhsElem0, rhsElem1);

406 Value mulHiLow = rewriter.createarith::MulIOp(loc, lhsElem1, rhsElem0);

407

408 Value resLow = mulLowLow.getLow();

410 rewriter.createarith::AddIOp(loc, mulLowLow.getHigh(), mulLowHi);

411 resHi = rewriter.createarith::AddIOp(loc, resHi, mulHiLow);

412

413 Value resultVec =

415 rewriter.replaceOp(op, resultVec);

416 return success();

417 }

418 };

419

420

421

422

423

426

427 LogicalResult

428 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,

431 auto newTy = getTypeConverter()->convertType(op.getType());

432 if (!newTy)

434 loc, llvm::formatv("unsupported type: {0}", op.getType()));

435

437

438

439

440

441 Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());

443 loc, newResultComponentTy, newOperand);

444 Value operandZeroCst =

446 Value signBit = rewriter.createarith::CmpIOp(

447 loc, arith::CmpIPredicate::slt, extended, operandZeroCst);

448 Value signValue =

449 rewriter.createarith::ExtSIOp(loc, newResultComponentTy, signBit);

450

451 Value resultVec =

453 rewriter.replaceOp(op, resultVec);

454 return success();

455 }

456 };

457

458

459

460

461

464

465 LogicalResult

466 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,

469 auto newTy = getTypeConverter()->convertType(op.getType());

470 if (!newTy)

472 loc, llvm::formatv("unsupported type: {0}", op.getType()));

473

475

476

477

478 Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());

480 loc, newResultComponentTy, newOperand);

484 return success();

485 }

486 };

487

488

489

490

491

492 template <typename SourceOp, arith::CmpIPredicate CmpPred>

495

496 LogicalResult

497 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,

500

501 Type oldTy = op.getType();

502 auto newTy = dyn_cast_or_null(

503 this->getTypeConverter()->convertType(oldTy));

504 if (!newTy)

506 loc, llvm::formatv("unsupported type: {0}", op.getType()));

507

508

509

511 rewriter.createarith::CmpIOp(loc, CmpPred, op.getLhs(), op.getRhs());

513 op.getRhs());

514 return success();

515 }

516 };

517

518

519

520

521

522 static bool isIndexOrIndexVector(Type type) {

523 if (isa(type))

524 return true;

525

526 if (auto vectorTy = dyn_cast(type))

527 if (isa(vectorTy.getElementType()))

528 return true;

529

530 return false;

531 }

532

533 template

536

537 LogicalResult

538 matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,

540 Type resultType = op.getType();

541 if (!isIndexOrIndexVector(resultType))

542 return failure();

543

545 Type inType = op.getIn().getType();

546 auto newInTy =

547 this->getTypeConverter()->template convertType(inType);

548 if (!newInTy)

550 loc, llvm::formatv("unsupported type: {0}", inType));

551

552

556 return success();

557 }

558 };

559

560 template <typename CastOp, typename ExtensionOp>

563

564 LogicalResult

565 matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,

567 Type inType = op.getIn().getType();

568 if (!isIndexOrIndexVector(inType))

569 return failure();

570

572 auto *typeConverter =

573 this->template getTypeConverterarith::WideIntEmulationConverter();

574

575 Type resultType = op.getType();

576 auto newTy = typeConverter->template convertType(resultType);

577 if (!newTy)

579 loc, llvm::formatv("unsupported type: {0}", resultType));

580

581

582 Type narrowTy =

583 rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth());

584 if (auto vecTy = dyn_cast(resultType))

586

587

588

589 Value underlyingVal =

590 rewriter.create(loc, narrowTy, adaptor.getIn());

591 rewriter.replaceOpWithNewOp(op, resultType, underlyingVal);

592 return success();

593 }

594 };

595

596

597

598

599

602

603 LogicalResult

604 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,

607 auto newTy = getTypeConverter()->convertType(op.getType());

608 if (!newTy)

610 loc, llvm::formatv("unsupported type: {0}", op.getType()));

611

612 auto [trueElem0, trueElem1] =

614 auto [falseElem0, falseElem1] =

616 Value cond = appendX1Dim(rewriter, loc, adaptor.getCondition());

617

619 rewriter.createarith::SelectOp(loc, cond, trueElem0, falseElem0);

621 rewriter.createarith::SelectOp(loc, cond, trueElem1, falseElem1);

622 Value resultVec =

624 rewriter.replaceOp(op, resultVec);

625 return success();

626 }

627 };

628

629

630

631

632

635

636 LogicalResult

637 matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor,

640

641 Type oldTy = op.getType();

642 auto newTy = getTypeConverter()->convertType(oldTy);

643 if (!newTy)

645 loc, llvm::formatv("unsupported type: {0}", op.getType()));

646

648

649 unsigned newBitWidth = newTy.getElementTypeBitWidth();

650

651 auto [lhsElem0, lhsElem1] =

654

655

656

657

658

659

660

661

662

663

664

665

666

667

668

669

670

671

672

673

674

675

676

677

678

680 Value elemBitWidth =

682

683 Value illegalElemShift = rewriter.createarith::CmpIOp(

684 loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);

685

686 Value shiftedElem0 =

687 rewriter.createarith::ShLIOp(loc, lhsElem0, rhsElem0);

688 Value resElem0 = rewriter.createarith::SelectOp(loc, illegalElemShift,

689 zeroCst, shiftedElem0);

690

691 Value cappedShiftAmount = rewriter.createarith::SelectOp(

692 loc, illegalElemShift, elemBitWidth, rhsElem0);

693 Value rightShiftAmount =

694 rewriter.createarith::SubIOp(loc, elemBitWidth, cappedShiftAmount);

695 Value shiftedRight =

696 rewriter.createarith::ShRUIOp(loc, lhsElem0, rightShiftAmount);

697 Value overshotShiftAmount =

698 rewriter.createarith::SubIOp(loc, rhsElem0, elemBitWidth);

699 Value shiftedLeft =

700 rewriter.createarith::ShLIOp(loc, lhsElem0, overshotShiftAmount);

701

702 Value shiftedElem1 =

703 rewriter.createarith::ShLIOp(loc, lhsElem1, rhsElem0);

704 Value resElem1High = rewriter.createarith::SelectOp(

705 loc, illegalElemShift, zeroCst, shiftedElem1);

706 Value resElem1Low = rewriter.createarith::SelectOp(

707 loc, illegalElemShift, shiftedLeft, shiftedRight);

709 rewriter.createarith::OrIOp(loc, resElem1Low, resElem1High);

710

711 Value resultVec =

713 rewriter.replaceOp(op, resultVec);

714 return success();

715 }

716 };

717

718

719

720

721

724

725 LogicalResult

726 matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor,

729

730 Type oldTy = op.getType();

731 auto newTy = getTypeConverter()->convertType(oldTy);

732 if (!newTy)

734 loc, llvm::formatv("unsupported type: {0}", op.getType()));

735

737

738 unsigned newBitWidth = newTy.getElementTypeBitWidth();

739

740 auto [lhsElem0, lhsElem1] =

743

744

745

746

747

748

749

750

751

752

753

754

755

756

757

758

759

760

761

762

763

764

765

766

767

769 Value elemBitWidth =

771

772 Value illegalElemShift = rewriter.createarith::CmpIOp(

773 loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);

774

775 Value shiftedElem0 =

776 rewriter.createarith::ShRUIOp(loc, lhsElem0, rhsElem0);

777 Value resElem0Low = rewriter.createarith::SelectOp(loc, illegalElemShift,

778 zeroCst, shiftedElem0);

779 Value shiftedElem1 =

780 rewriter.createarith::ShRUIOp(loc, lhsElem1, rhsElem0);

781 Value resElem1 = rewriter.createarith::SelectOp(loc, illegalElemShift,

782 zeroCst, shiftedElem1);

783

784 Value cappedShiftAmount = rewriter.createarith::SelectOp(

785 loc, illegalElemShift, elemBitWidth, rhsElem0);

786 Value leftShiftAmount =

787 rewriter.createarith::SubIOp(loc, elemBitWidth, cappedShiftAmount);

788 Value shiftedLeft =

789 rewriter.createarith::ShLIOp(loc, lhsElem1, leftShiftAmount);

790 Value overshotShiftAmount =

791 rewriter.createarith::SubIOp(loc, rhsElem0, elemBitWidth);

792 Value shiftedRight =

793 rewriter.createarith::ShRUIOp(loc, lhsElem1, overshotShiftAmount);

794

795 Value resElem0High = rewriter.createarith::SelectOp(

796 loc, illegalElemShift, shiftedRight, shiftedLeft);

798 rewriter.createarith::OrIOp(loc, resElem0Low, resElem0High);

799

800 Value resultVec =

802 rewriter.replaceOp(op, resultVec);

803 return success();

804 }

805 };

806

807

808

809

810

813

814 LogicalResult

815 matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,

818

819 Type oldTy = op.getType();

820 auto newTy = getTypeConverter()->convertType(oldTy);

821 if (!newTy)

823 loc, llvm::formatv("unsupported type: {0}", op.getType()));

824

827

829 int64_t origBitwidth = newTy.getElementTypeBitWidth() * 2;

830

831

832

833

835 Value signBit = rewriter.createarith::CmpIOp(

836 loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);

838

839

840

841

842 Value allSign = rewriter.createarith::ExtSIOp(loc, oldTy, signBit);

845 Value numNonSignExtBits =

846 rewriter.createarith::SubIOp(loc, maxShift, rhsElem0);

847 numNonSignExtBits = dropTrailingX1Dim(rewriter, loc, numNonSignExtBits);

848 numNonSignExtBits =

849 rewriter.createarith::ExtUIOp(loc, oldTy, numNonSignExtBits);

851 rewriter.createarith::ShLIOp(loc, allSign, numNonSignExtBits);

852

853

855 rewriter.createarith::ShRUIOp(loc, op.getLhs(), op.getRhs());

856 Value shrsi = rewriter.createarith::OrIOp(loc, shrui, signBits);

857

858

859

860 Value isNoop = rewriter.createarith::CmpIOp(loc, arith::CmpIPredicate::eq,

861 rhsElem0, elemZero);

863 rewriter.replaceOpWithNewOparith::SelectOp(op, isNoop, op.getLhs(),

864 shrsi);

865

866 return success();

867 }

868 };

869

870

871

872

873

876

877 LogicalResult

878 matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor,

881 auto newTy = getTypeConverter()->convertType(op.getType());

882 if (!newTy)

884 loc, llvm::formatv("unsupported type: {}", op.getType()));

885

887

888 auto [lhsElem0, lhsElem1] =

890 auto [rhsElem0, rhsElem1] =

892

893

894

895 Value low = rewriter.createarith::SubIOp(loc, lhsElem0, rhsElem0);

896

897 Value carry0 = rewriter.createarith::CmpIOp(

898 loc, arith::CmpIPredicate::ult, lhsElem0, rhsElem0);

899 Value carryVal = rewriter.createarith::ExtUIOp(loc, newElemTy, carry0);

900

901 Value high0 = rewriter.createarith::SubIOp(loc, lhsElem1, carryVal);

902 Value high = rewriter.createarith::SubIOp(loc, high0, rhsElem1);

903

905 rewriter.replaceOp(op, resultVec);

906 return success();

907 }

908 };

909

910

911

912

913

916

917 LogicalResult

918 matchAndRewrite(arith::SIToFPOp op, OpAdaptor adaptor,

921

922 Value in = op.getIn();

924 auto newTy = getTypeConverter()->convertType(oldTy);

925 if (!newTy)

927 loc, llvm::formatv("unsupported type: {0}", oldTy));

928

930

931

932

933

934

935

936 Value isNeg = rewriter.createarith::CmpIOp(loc, arith::CmpIPredicate::slt,

937 in, zeroCst);

938 Value neg = rewriter.createarith::SubIOp(loc, zeroCst, in);

939 Value abs = rewriter.createarith::SelectOp(loc, isNeg, neg, in);

940

941 Value absResult = rewriter.createarith::UIToFPOp(loc, op.getType(), abs);

942 Value negResult = rewriter.createarith::NegFOp(loc, absResult);

944 absResult);

945 return success();

946 }

947 };

948

949

950

951

952

955

956 LogicalResult

957 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,

960

961 Type oldTy = op.getIn().getType();

962 auto newTy = getTypeConverter()->convertType(oldTy);

963 if (!newTy)

965 loc, llvm::formatv("unsupported type: {0}", oldTy));

966 unsigned newBitWidth = newTy.getElementTypeBitWidth();

967

973

974

975

976

977

978

979

980

981

982

983

984

985

986

987

988 Value hiEqZero = rewriter.createarith::CmpIOp(

989 loc, arith::CmpIPredicate::eq, hiInt, zeroCst);

990

991 Type resultTy = op.getType();

993 Value lowFp = rewriter.createarith::UIToFPOp(loc, resultTy, lowInt);

994 Value hiFp = rewriter.createarith::UIToFPOp(loc, resultTy, hiInt);

995

996 int64_t pow2Int = int64_t(1) << newBitWidth;

997 TypedAttr pow2Attr =

998 rewriter.getFloatAttr(resultElemTy, static_cast<double>(pow2Int));

999 if (auto vecTy = dyn_cast(resultTy))

1001

1002 Value pow2Val = rewriter.createarith::ConstantOp(loc, resultTy, pow2Attr);

1003

1004 Value hiVal = rewriter.createarith::MulFOp(loc, hiFp, pow2Val);

1005 Value result = rewriter.createarith::AddFOp(loc, lowFp, hiVal);

1006

1007 rewriter.replaceOpWithNewOparith::SelectOp(op, hiEqZero, lowFp, result);

1008 return success();

1009 }

1010 };

1011

1012

1013

1014

1015

1018

1019 LogicalResult

1020 matchAndRewrite(arith::FPToSIOp op, OpAdaptor adaptor,

1023

1024 Value inFp = adaptor.getIn();

1026

1027 Type intTy = op.getType();

1028

1029 auto newTy = getTypeConverter()->convertType(intTy);

1030 if (!newTy)

1032 loc, llvm::formatv("unsupported type: {}", intTy));

1033

1034

1035

1036

1037

1038

1039 TypedAttr zeroAttr = rewriter.getZeroAttr(fpTy);

1040 Value zeroCst = rewriter.createarith::ConstantOp(loc, zeroAttr);

1042

1043

1044

1045 Value isNeg = rewriter.createarith::CmpFOp(loc, arith::CmpFPredicate::OLT,

1046 inFp, zeroCst);

1047 Value negInFp = rewriter.createarith::NegFOp(loc, inFp);

1048

1049 Value absVal = rewriter.createarith::SelectOp(loc, isNeg, negInFp, inFp);

1050

1051

1052 Value res = rewriter.createarith::FPToUIOp(loc, intTy, absVal);

1053

1054

1055 Value neg = rewriter.createarith::SubIOp(loc, zeroCstInt, res);

1056

1058 return success();

1059 }

1060 };

1061

1062

1063

1064

1065

1068

1069 LogicalResult

1070 matchAndRewrite(arith::FPToUIOp op, OpAdaptor adaptor,

1073

1074 Value inFp = adaptor.getIn();

1076

1077 Type intTy = op.getType();

1078 auto newTy = getTypeConverter()->convertType(intTy);

1079 if (!newTy)

1081 loc, llvm::formatv("unsupported type: {}", intTy));

1082 unsigned newBitWidth = newTy.getElementTypeBitWidth();

1083

1085 if (auto vecType = dyn_cast(fpTy))

1086 newHalfType = VectorType::get(vecType.getShape(), newHalfType);

1087

1088

1089

1090

1091

1092

1093

1094 const llvm::fltSemantics &fSemantics =

1096

1097 auto powBitwidth = llvm::APFloat(fSemantics);

1098

1099

1100

1101

1102 if (powBitwidth.convertFromAPInt(APInt(newBitWidth * 2, 1).shl(newBitWidth),

1103 false, llvm::RoundingMode::TowardZero) ==

1104 llvm::detail::opStatus::opInexact)

1105 powBitwidth = llvm::APFloat::getInf(fSemantics);

1106

1107 TypedAttr powBitwidthAttr =

1109 if (auto vecType = dyn_cast(fpTy))

1111 Value powBitwidthFloatCst =

1112 rewriter.createarith::ConstantOp(loc, powBitwidthAttr);

1113

1114 Value fpDivPowBitwidth =

1115 rewriter.createarith::DivFOp(loc, inFp, powBitwidthFloatCst);

1117 rewriter.createarith::FPToUIOp(loc, newHalfType, fpDivPowBitwidth);

1118

1119 Value remainder =

1120 rewriter.createarith::RemFOp(loc, inFp, powBitwidthFloatCst);

1122 rewriter.createarith::FPToUIOp(loc, newHalfType, remainder);

1123

1126

1128

1129 rewriter.replaceOp(op, resultVec);

1130 return success();

1131 }

1132 };

1133

1134

1135

1136

1137

1140

1141 LogicalResult

1142 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,

1145

1146

1147 if (!getTypeConverter()->isLegal(op.getType()))

1149 loc, llvm::formatv("unsupported truncation result type: {0}",

1150 op.getType()));

1151

1152

1153

1156 Value truncated =

1157 rewriter.createOrFoldarith::TruncIOp(loc, op.getType(), extracted);

1158 rewriter.replaceOp(op, truncated);

1159 return success();

1160 }

1161 };

1162

1163

1164

1165

1166

1167 struct ConvertVectorPrint final : OpConversionPatternvector::PrintOp {

1169

1170 LogicalResult

1171 matchAndRewrite(vector::PrintOp op, OpAdaptor adaptor,

1173 rewriter.replaceOpWithNewOpvector::PrintOp(op, adaptor.getSource());

1174 return success();

1175 }

1176 };

1177

1178

1179

1180

1181

1182 struct EmulateWideIntPass final

1183 : arith::impl::ArithEmulateWideIntBase {

1184 using ArithEmulateWideIntBase::ArithEmulateWideIntBase;

1185

1186 void runOnOperation() override {

1187 if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {

1188 signalPassFailure();

1189 return;

1190 }

1191

1194

1195 arith::WideIntEmulationConverter typeConverter(widestIntSupported);

1197 target.addDynamicallyLegalOpfunc::FuncOp([&typeConverter](Operation *op) {

1198 return typeConverter.isLegal(castfunc::FuncOp(op).getFunctionType());

1199 });

1200 auto opLegalCallback = [&typeConverter](Operation *op) {

1201 return typeConverter.isLegal(op);

1202 };

1203 target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);

1204 target.addDynamicallyLegalOpvector::PrintOp(opLegalCallback);

1205 target.addDynamicallyLegalDialectarith::ArithDialect(opLegalCallback);

1206 target.addLegalDialectvector::VectorDialect();

1207

1210

1211

1212 populateFunctionOpInterfaceTypeConversionPatternfunc::FuncOp(

1216

1218 signalPassFailure();

1219 }

1220 };

1221 }

1222

1223

1224

1225

1226

1228 unsigned widestIntSupportedByTarget)

1229 : maxIntWidth(widestIntSupportedByTarget) {

1230 assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) &&

1231 "Only power-of-two integers with are supported");

1232 assert(widestIntSupportedByTarget >= 2 && "Integer type too narrow");

1233

1234

1235 addConversion([](Type ty) -> std::optional { return ty; });

1236

1237

1238 addConversion([this](IntegerType ty) -> std::optional {

1239 unsigned width = ty.getWidth();

1240 if (width <= maxIntWidth)

1241 return ty;

1242

1243

1244 if (width == 2 * maxIntWidth)

1246

1247 return nullptr;

1248 });

1249

1250

1251 addConversion([this](VectorType ty) -> std::optional {

1252 auto intTy = dyn_cast(ty.getElementType());

1253 if (!intTy)

1254 return ty;

1255

1256 unsigned width = intTy.getWidth();

1257 if (width <= maxIntWidth)

1258 return ty;

1259

1260

1261 if (width == 2 * maxIntWidth) {

1262 auto newShape = to_vector(ty.getShape());

1263 newShape.push_back(2);

1266 }

1267

1268 return nullptr;

1269 });

1270

1271

1272 addConversion([this](FunctionType ty) -> std::optional {

1273

1274

1276 if (failed(convertTypes(ty.getInputs(), inputs)))

1277 return nullptr;

1278

1280 if (failed(convertTypes(ty.getResults(), results)))

1281 return nullptr;

1282

1284 });

1285 }

1286

1290

1292

1293 ConvertConstant, ConvertCmpI, ConvertSelect, ConvertVectorPrint,

1294

1295 ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRSI, ConvertShRUI,

1296 ConvertMaxMin<arith::MaxUIOp, arith::CmpIPredicate::ugt>,

1297 ConvertMaxMin<arith::MaxSIOp, arith::CmpIPredicate::sgt>,

1298 ConvertMaxMin<arith::MinUIOp, arith::CmpIPredicate::ult>,

1299 ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>, ConvertSubI,

1300

1301 ConvertBitwiseBinaryarith::AndIOp, ConvertBitwiseBinaryarith::OrIOp,

1302 ConvertBitwiseBinaryarith::XOrIOp,

1303

1304 ConvertExtSI, ConvertExtUI, ConvertTruncI,

1305

1306 ConvertIndexCastIntToIndexarith::IndexCastOp,

1307 ConvertIndexCastIntToIndexarith::IndexCastUIOp,

1308 ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,

1309 ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>,

1310 ConvertSIToFP, ConvertUIToFP, ConvertFPToUI, ConvertFPToSI>(

1311 typeConverter, patterns.getContext());

1312 }

static Value insertLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value source, Value dest, int64_t lastOffset)

Inserts the source vector slice into the dest vector at offset lastOffset in the last dimension.

static std::pair< APInt, APInt > getHalves(const APInt &value, unsigned newBitWidth)

Returns N bottom and N top bits from value, where N = newBitWidth.

static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc, Value input)

Performs a vector shape cast to append an x1 dimension.

static std::pair< Value, Value > extractLastDimHalves(ConversionPatternRewriter &rewriter, Location loc, Value input)

Extracts two vector slices from the input whose type is vector<...x2T>, with the first element at off...

static Type reduceInnermostDim(VectorType type)

Returns the type with the last (innermost) dimension reduced to x1.

static Value constructResultVector(ConversionPatternRewriter &rewriter, Location loc, VectorType resultType, ValueRange resultComponents)

Constructs a new vector of type resultType by creating a series of insertions of resultComponents,...

static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter, Location loc, Value input)

static Value extractLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value input, int64_t lastOffset)

Extracts the input vector slice with elements at the last dimension offset by lastOffset.

Attributes are known-constant values of operations.

FloatAttr getFloatAttr(Type type, double value)

IntegerType getIntegerType(unsigned width)

TypedAttr getZeroAttr(Type type)

This class implements a pattern rewriter for use with ConversionPatterns.

void replaceOp(Operation *op, ValueRange newValues) override

Replace the given operation with the new values.

This class describes a specific conversion target.

static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)

Constructs a dense elements attribute from an array of element values.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)

Create an operation of specific op type at the current insertion point, and immediately try to fold i...

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...

typename SourceOp::Adaptor OpAdaptor

OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)

Operation is the basic unit of execution within MLIR.

MLIRContext * getContext()

Return the context this operation is associated with.

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure,...

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

void addConversion(FnT &&callback)

Register a conversion function.

LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const

Convert the given set of types, filling 'results' as necessary.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

MLIRContext * getContext() const

Return the MLIRContext in which this type was uniqued.

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

MLIRContext * getContext() const

Utility to get the associated MLIRContext that this value is defined in.

Type getType() const

Return the type of this value.

Converts integer types that are too wide for the target by splitting them in two halves and thus turn...

WideIntEmulationConverter(unsigned widestIntSupportedByTarget)

void populateArithWideIntEmulationPatterns(const WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns)

Adds patterns to emulate wide Arith and Function ops over integer types into supported ones.

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

Fraction abs(const Fraction &f)

Include the generated interface declarations.

Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)

Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...

Type getElementTypeOrSelf(Type type)

Return the element type or return the type itself.

const FrozenRewritePatternSet & patterns

void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)

Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)

Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...

LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())

Below we define several entry points for operation conversion.