MLIR: lib/Conversion/MathToFuncs/MathToFuncs.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

10

24 #include "llvm/ADT/DenseMap.h"

25 #include "llvm/ADT/TypeSwitch.h"

26 #include "llvm/Support/Debug.h"

27

28 namespace mlir {

29 #define GEN_PASS_DEF_CONVERTMATHTOFUNCS

30 #include "mlir/Conversion/Passes.h.inc"

31 }

32

33 using namespace mlir;

34

35 #define DEBUG_TYPE "math-to-funcs"

36 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

37

// File-local rewrite patterns used by the MathToFuncs conversion.
//
// NOTE(review): this listing was recovered from a rendered source page.
// Several lines of every class below were lost, e.g. the class head of the
// vector-unrolling pattern (line 41) and the constructor initializers
// (lines 57, 74, 91). Template arguments were also stripped, e.g.
// `OpRewritePatternmath::IPowIOp`. Restore from upstream before compiling.
38 namespace {

39

// Generic pattern, parameterized on the op type `Op`. It unrolls an op whose
// result is a vector into one scalar `Op` per element; see
// VecOpToScalarOp::matchAndRewrite further down in this file.
40 template

42 public:

44

45 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;

46 };

47

48

49

51

52

53

// Lowers a scalar math.ipowi op to a func.call of a helper function obtained
// through the stored callback.
54 class IPowIOpLowering : public OpRewritePatternmath::IPowIOp {

55 public:

56 IPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)

58

59

60

61

62 LogicalResult matchAndRewrite(math::IPowIOp op,

64

65 private:

// Called as getFuncOpCallback(op, type) in matchAndRewrite. A null FuncOp
// result means no software implementation is available.
66 GetFuncCallbackTy getFuncOpCallback;

67 };

68

69

70

// Lowers a scalar math.fpowi op to a func.call of a helper function obtained
// through the stored callback. The lookup key is the elemental function type.
71 class FPowIOpLowering : public OpRewritePatternmath::FPowIOp {

72 public:

73 FPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)

75

76

77

78

79 LogicalResult matchAndRewrite(math::FPowIOp op,

81

82 private:

// Helper lookup callback. A null result makes matchAndRewrite fail to match.
83 GetFuncCallbackTy getFuncOpCallback;

84 };

85

86

87

// Lowers a scalar math.ctlz (CountLeadingZerosOp) to a func.call of a helper
// function obtained through the stored callback.
88 class CtlzOpLowering : public OpRewritePatternmath::CountLeadingZerosOp {

89 public:

90 CtlzOpLowering(MLIRContext *context, GetFuncCallbackTy cb)

92 getFuncOpCallback(cb) {}

93

94

95

96 LogicalResult matchAndRewrite(math::CountLeadingZerosOp op,

98

99 private:

// Helper lookup callback. A null result makes matchAndRewrite emit a
// "Missing software implementation" match-failure diagnostic.
100 GetFuncCallbackTy getFuncOpCallback;

101 };

102 }

103

// Unrolls an op whose result type is a ranked vector into one scalar `Op`
// per element. For each element:
//   - every operand is vector.extract'ed at that element's position;
//   - a scalar `Op` is built on the extracted values;
//   - its result is vector.insert'ed into an accumulator that starts as an
//     arith.constant.
// The rewritten value is `result` after the last insertion.
// NOTE(review): several lines of this body were lost in extraction. These
// include the guard return statements, the initValueAttr declaration, the
// integer branch of the initializer, the loc/positions/operands locals, and
// the final replaceOp.
104 template

105 LogicalResult

106 VecOpToScalarOp::matchAndRewrite(Op op, PatternRewriter &rewriter) const {

107 Type opType = op.getType();

109 auto vecType = dyn_cast(opType);

110

// Only vector-typed results are handled. The bodies of the two guards below
// (presumably `return failure();`) did not survive extraction.
111 if (!vecType)

113 if (!vecType.hasRank())

116 int64_t numElements = vecType.getNumElements();

117

118 Type resultElementType = vecType.getElementType();

// Initial accumulator value: 0.0 for float element types. The other branch
// was lost in extraction.
120 if (isa(resultElementType))

121 initValueAttr = FloatAttr::get(resultElementType, 0.0);

122 else

124 Value result = rewriter.createarith::ConstantOp(

// Visit every element by linear index. `positions` (its declaration was
// lost) holds the multi-dimensional index used for both extract and insert.
// Presumably it is derived from linearIndex via delinearize/computeStrides;
// confirm against upstream.
127 for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {

130 for (Value input : op->getOperands())

131 operands.push_back(

132 rewriter.createvector::ExtractOp(loc, input, positions));

134 rewriter.create<Op>(loc, vecType.getElementType(), operands);

135 result =

136 rewriter.createvector::InsertOp(loc, scalarOp, result, positions);

137 }

139 return success();

140 }

141

146 resultTys.begin(),

147 [](Type ty) { return getElementTypeOrSelf(ty); });

149 inputTys.begin(),

150 [](Type ty) { return getElementTypeOrSelf(ty); });

152 }

153

154

155

156

157

158

159

160

161

162

163

164

165

166

167

168

169

170

171

172

173

174

175

176

177

178

179

180

181

182

183

184

186 assert(isa(elementType) &&

187 "non-integer element type for IPowIOp");

188

191

192 std::string funcName("__mlir_math_ipowi");

193 llvm::raw_string_ostream nameOS(funcName);

194 nameOS << '_' << elementType;

195

197 builder.getContext(), {elementType, elementType}, elementType);

198 auto funcOp = builder.createfunc::FuncOp(funcName, funcType);

199 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;

202 funcOp->setAttr("llvm.linkage", linkage);

203 funcOp.setPrivate();

204

205 Block *entryBlock = funcOp.addEntryBlock();

207

208 Value bArg = funcOp.getArgument(0);

209 Value pArg = funcOp.getArgument(1);

211 Value zeroValue = builder.createarith::ConstantOp(

212 elementType, builder.getIntegerAttr(elementType, 0));

213 Value oneValue = builder.createarith::ConstantOp(

214 elementType, builder.getIntegerAttr(elementType, 1));

215 Value minusOneValue = builder.createarith::ConstantOp(

216 elementType,

219 true)));

220

221

222

223 auto pIsZero =

224 builder.createarith::CmpIOp(arith::CmpIPredicate::eq, pArg, zeroValue);

226 builder.createfunc::ReturnOp(oneValue);

228

230 builder.createcf::CondBranchOp(pIsZero, thenBlock, fallthroughBlock);

231

232

234 auto pIsNeg =

235 builder.createarith::CmpIOp(arith::CmpIPredicate::sle, pArg, zeroValue);

236

238 auto bIsZero =

239 builder.createarith::CmpIOp(arith::CmpIPredicate::eq, bArg, zeroValue);

240

241 thenBlock = builder.createBlock(funcBody);

242 builder.createfunc::ReturnOp(

243 builder.createarith::DivSIOp(oneValue, zeroValue).getResult());

244 fallthroughBlock = builder.createBlock(funcBody);

245

247 builder.createcf::CondBranchOp(bIsZero, thenBlock, fallthroughBlock);

248

249

251 auto bIsOne =

252 builder.createarith::CmpIOp(arith::CmpIPredicate::eq, bArg, oneValue);

253

254 thenBlock = builder.createBlock(funcBody);

255 builder.createfunc::ReturnOp(oneValue);

256 fallthroughBlock = builder.createBlock(funcBody);

257

259 builder.createcf::CondBranchOp(bIsOne, thenBlock, fallthroughBlock);

260

261

263 auto bIsMinusOne = builder.createarith::CmpIOp(arith::CmpIPredicate::eq,

264 bArg, minusOneValue);

265

267 auto pIsOdd = builder.createarith::CmpIOp(

268 arith::CmpIPredicate::ne, builder.createarith::AndIOp(pArg, oneValue),

269 zeroValue);

270

271 thenBlock = builder.createBlock(funcBody);

272 builder.createfunc::ReturnOp(minusOneValue);

273 fallthroughBlock = builder.createBlock(funcBody);

274

276 builder.createcf::CondBranchOp(pIsOdd, thenBlock, fallthroughBlock);

277

278

279

281 builder.createfunc::ReturnOp(oneValue);

282 fallthroughBlock = builder.createBlock(funcBody);

283

285 builder.createcf::CondBranchOp(bIsMinusOne, pIsOdd->getBlock(),

286 fallthroughBlock);

287

288

289

291 builder.createfunc::ReturnOp(zeroValue);

293 funcBody, funcBody->end(), {elementType, elementType, elementType},

295

297

298 builder.createcf::CondBranchOp(pIsNeg, bIsZero->getBlock(), loopHeader,

300

301

302

303

304

305

306

307

308

309

310 Value resultTmp = loopHeader->getArgument(0);

311 Value baseTmp = loopHeader->getArgument(1);

312 Value powerTmp = loopHeader->getArgument(2);

314

315

316 auto powerTmpIsOdd = builder.createarith::CmpIOp(

317 arith::CmpIPredicate::ne,

318 builder.createarith::AndIOp(powerTmp, oneValue), zeroValue);

319 thenBlock = builder.createBlock(funcBody);

320

321 Value newResultTmp = builder.createarith::MulIOp(resultTmp, baseTmp);

322 fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), elementType,

325 builder.createcf::BranchOp(newResultTmp, fallthroughBlock);

326

328 builder.createcf::CondBranchOp(powerTmpIsOdd, thenBlock, fallthroughBlock,

329 resultTmp);

330

331 newResultTmp = fallthroughBlock->getArgument(0);

332

333

335 Value newPowerTmp = builder.createarith::ShRUIOp(powerTmp, oneValue);

336

337

338 auto newPowerIsZero = builder.createarith::CmpIOp(arith::CmpIPredicate::eq,

339 newPowerTmp, zeroValue);

340

341 thenBlock = builder.createBlock(funcBody);

342 builder.createfunc::ReturnOp(newResultTmp);

343 fallthroughBlock = builder.createBlock(funcBody);

344

346 builder.createcf::CondBranchOp(newPowerIsZero, thenBlock, fallthroughBlock);

347

348

349

351 Value newBaseTmp = builder.createarith::MulIOp(baseTmp, baseTmp);

352

353 builder.createcf::BranchOp(

354 ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);

355 return funcOp;

356 }

357

358

359

360

// Replaces a scalar math.ipowi with a func.call of the helper registered for
// the base operand's type. Match fails when no helper exists.
// NOTE(review): the second parameter line (363) and the guard's return
// (367) were lost in extraction.
361 LogicalResult

362 IPowIOpLowering::matchAndRewrite(math::IPowIOp op,

364 auto baseType = dyn_cast(op.getOperands()[0].getType());

365

// Non-matching (e.g. vector) operand types are rejected here. Vectors are
// presumably first unrolled by VecOpToScalarOp; confirm against upstream.
366 if (!baseType)

368

369

370

// The helper is keyed by the scalar base type; a null FuncOp means none was
// generated for this type.
371 func::FuncOp elementFunc = getFuncOpCallback(op, baseType);

372 if (!elementFunc)

373 return rewriter.notifyMatchFailure(op, "missing software implementation");

374

375 rewriter.replaceOpWithNewOpfunc::CallOp(op, elementFunc, op.getOperands());

376 return success();

377 }

378

379

380

381

382

383

384

385

386

387

388

389

390

391

392

393

394

395

396

397

398

399

400

401

402

403

404

405

406

407

408

409

410

411

413 FunctionType funcType) {

414 auto baseType = cast(funcType.getInput(0));

415 auto powType = cast(funcType.getInput(1));

418

419 std::string funcName("__mlir_math_fpowi");

420 llvm::raw_string_ostream nameOS(funcName);

421 nameOS << '_' << baseType;

422 nameOS << '_' << powType;

423 auto funcOp = builder.createfunc::FuncOp(funcName, funcType);

424 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;

427 funcOp->setAttr("llvm.linkage", linkage);

428 funcOp.setPrivate();

429

430 Block *entryBlock = funcOp.addEntryBlock();

432

433 Value bArg = funcOp.getArgument(0);

434 Value pArg = funcOp.getArgument(1);

436 Value oneBValue = builder.createarith::ConstantOp(

437 baseType, builder.getFloatAttr(baseType, 1.0));

438 Value zeroPValue = builder.createarith::ConstantOp(

440 Value onePValue = builder.createarith::ConstantOp(

442 Value minPValue = builder.createarith::ConstantOp(

443 powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMinValue(

444 powType.getWidth())));

445 Value maxPValue = builder.createarith::ConstantOp(

446 powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMaxValue(

447 powType.getWidth())));

448

449

450

451 auto pIsZero =

452 builder.createarith::CmpIOp(arith::CmpIPredicate::eq, pArg, zeroPValue);

454 builder.createfunc::ReturnOp(oneBValue);

456

458 builder.createcf::CondBranchOp(pIsZero, thenBlock, fallthroughBlock);

459

461

462 auto pIsNeg = builder.createarith::CmpIOp(arith::CmpIPredicate::sle, pArg,

463 zeroPValue);

464

465 auto pIsMin =

466 builder.createarith::CmpIOp(arith::CmpIPredicate::eq, pArg, minPValue);

467

468

469

470

471

472

473 Value negP = builder.createarith::SubIOp(zeroPValue, pArg);

474 auto pInit = builder.createarith::SelectOp(pIsNeg, negP, pArg);

475 pInit = builder.createarith::SelectOp(pIsMin, maxPValue, pInit);

476

477

478

479

480

481

482

483

484

485

486

488 funcBody, funcBody->end(), {baseType, baseType, powType},

490

492 builder.createcf::BranchOp(loopHeader, ValueRange{oneBValue, bArg, pInit});

493

494

495 Value resultTmp = loopHeader->getArgument(0);

496 Value baseTmp = loopHeader->getArgument(1);

497 Value powerTmp = loopHeader->getArgument(2);

499

500

501 auto powerTmpIsOdd = builder.createarith::CmpIOp(

502 arith::CmpIPredicate::ne,

503 builder.createarith::AndIOp(powerTmp, onePValue), zeroPValue);

504 thenBlock = builder.createBlock(funcBody);

505

506 Value newResultTmp = builder.createarith::MulFOp(resultTmp, baseTmp);

507 fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,

510 builder.createcf::BranchOp(newResultTmp, fallthroughBlock);

511

513 builder.createcf::CondBranchOp(powerTmpIsOdd, thenBlock, fallthroughBlock,

514 resultTmp);

515

516 newResultTmp = fallthroughBlock->getArgument(0);

517

518

520 Value newPowerTmp = builder.createarith::ShRUIOp(powerTmp, onePValue);

521

522

523 auto newPowerIsZero = builder.createarith::CmpIOp(arith::CmpIPredicate::eq,

524 newPowerTmp, zeroPValue);

525

526

527

528

529 fallthroughBlock = builder.createBlock(funcBody);

530

531

532

534 Value newBaseTmp = builder.createarith::MulFOp(baseTmp, baseTmp);

535

536 builder.createcf::BranchOp(

537 ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);

538

539

540

541

542 Block *loopExit = builder.createBlock(funcBody, funcBody->end(), baseType,

545 builder.createcf::CondBranchOp(newPowerIsZero, loopExit, newResultTmp,

547

548

549

550

551 newResultTmp = loopExit->getArgument(0);

552 thenBlock = builder.createBlock(funcBody);

553 fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,

556 builder.createcf::CondBranchOp(pIsMin, thenBlock, fallthroughBlock,

557 newResultTmp);

559 newResultTmp = builder.createarith::MulFOp(newResultTmp, bArg);

560 builder.createcf::BranchOp(newResultTmp, fallthroughBlock);

561

562

563

564

565 newResultTmp = fallthroughBlock->getArgument(0);

566 thenBlock = builder.createBlock(funcBody);

567 Block *returnBlock = builder.createBlock(funcBody, funcBody->end(), baseType,

570 builder.createcf::CondBranchOp(pIsNeg, thenBlock, returnBlock,

571 newResultTmp);

573 newResultTmp = builder.createarith::DivFOp(oneBValue, newResultTmp);

574 builder.createcf::BranchOp(newResultTmp, returnBlock);

575

576

579

580 return funcOp;

581 }

582

583

584

585

// Replaces a scalar math.fpowi with a func.call of the helper registered for
// the op's elemental function type. Match fails when no helper exists.
// NOTE(review): the second parameter line (588), the guard's return (590)
// and the funcType declaration (592) were lost in extraction.
586 LogicalResult

587 FPowIOpLowering::matchAndRewrite(math::FPowIOp op,

// Rejects ops of a particular result type (template argument lost).
// Presumably this is VectorType, which is left to VecOpToScalarOp; confirm
// against upstream.
589 if (isa(op.getType()))

591

593

594

595

// funcType is presumably obtained from getElementalFuncTypeForOp(op).
596 func::FuncOp elementFunc = getFuncOpCallback(op, funcType);

597 if (!elementFunc)

598 return rewriter.notifyMatchFailure(op, "missing software implementation");

599

600 rewriter.replaceOpWithNewOpfunc::CallOp(op, elementFunc, op.getOperands());

601 return success();

602 }

603

604

605

606

607

608

609

610

611

612

613

614

615

616

617

618

619

620

621

622

623

624

625

626

627

628

629

630

631

632

633

634

635

636

637

638

639

640

641

642

643

644

645

646

647

648

649

650

652 if (!isa(elementType)) {

653 LLVM_DEBUG({

654 DBGS() << "non-integer element type for CtlzFunc; type was: ";

655 elementType.print(llvm::dbgs());

656 });

657 llvm_unreachable("non-integer element type");

658 }

660

664

665 std::string funcName("__mlir_math_ctlz");

666 llvm::raw_string_ostream nameOS(funcName);

667 nameOS << '_' << elementType;

668 FunctionType funcType =

670 auto funcOp = builder.createfunc::FuncOp(funcName, funcType);

671

672

673

674 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;

677 funcOp->setAttr("llvm.linkage", linkage);

678 funcOp.setPrivate();

679

680

681 Block *funcBody = funcOp.addEntryBlock();

683

684 Value arg = funcOp.getArgument(0);

686 Value bitWidthValue = builder.createarith::ConstantOp(

687 elementType, builder.getIntegerAttr(elementType, bitWidth));

688 Value zeroValue = builder.createarith::ConstantOp(

689 elementType, builder.getIntegerAttr(elementType, 0));

690

691 Value inputEqZero =

692 builder.createarith::CmpIOp(arith::CmpIPredicate::eq, arg, zeroValue);

693

694

695 scf::IfOp ifOp = builder.createscf::IfOp(

696 elementType, inputEqZero, true, true);

697 ifOp.getThenBodyBuilder().createscf::YieldOp(loc, bitWidthValue);

698

699 auto elseBuilder =

701

702 Value oneIndex = elseBuilder.createarith::ConstantOp(

703 indexType, elseBuilder.getIndexAttr(1));

704 Value oneValue = elseBuilder.createarith::ConstantOp(

705 elementType, elseBuilder.getIntegerAttr(elementType, 1));

706 Value bitWidthIndex = elseBuilder.createarith::ConstantOp(

707 indexType, elseBuilder.getIndexAttr(bitWidth));

708 Value nValue = elseBuilder.createarith::ConstantOp(

709 elementType, elseBuilder.getIntegerAttr(elementType, 0));

710

711 auto loop = elseBuilder.createscf::ForOp(

712 oneIndex, bitWidthIndex, oneIndex,

713

714

715

717

718

719

720

721

722

723

725 Value argIter = args[0];

726 Value nIter = args[1];

727

728 Value argIsNonNegative = b.createarith::CmpIOp(

729 loc, arith::CmpIPredicate::slt, argIter, zeroValue);

730 scf::IfOp ifOp = b.createscf::IfOp(

731 loc, argIsNonNegative,

733

734 b.createscf::YieldOp(loc, ValueRange{argIter, nIter});

735 },

737

738 Value nNext = b.createarith::AddIOp(loc, nIter, oneValue);

739 Value argNext = b.createarith::ShLIOp(loc, argIter, oneValue);

740 b.createscf::YieldOp(loc, ValueRange{argNext, nNext});

741 });

742 b.createscf::YieldOp(loc, ifOp.getResults());

743 });

744 elseBuilder.createscf::YieldOp(loop.getResult(1));

745

746 builder.createfunc::ReturnOp(ifOp.getResult(0));

747 return funcOp;

748 }

749

750

751

// Replaces a scalar math.ctlz with a func.call of the helper registered for
// its type. When no helper exists, the match fails with a diagnostic naming
// the op and the type.
// NOTE(review): lines 753, 755, 757 and 760 (parameter tail, guard return,
// `type` declaration, and the notifyMatchFailure callback head) were lost in
// extraction.
752 LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op,

// Rejects a particular result type (template argument lost); presumably
// VectorType, which is handled by VecOpToScalarOp.
754 if (isa(op.getType()))

756

758 func::FuncOp elementFunc = getFuncOpCallback(op, type);

759 if (!elementFunc)

761 diag << "Missing software implementation for op " << op->getName()

762 << " and type " << type;

763 });

764

// math.ctlz has a single operand, forwarded as the call's only argument.
765 rewriter.replaceOpWithNewOpfunc::CallOp(op, elementFunc, op.getOperand());

766 return success();

767 }

768

// The math-to-funcs pass. It first materializes helper functions for every
// convertible op in the module (generateOpImplementations). It then rewrites
// the ops into calls of those helpers (runOnOperation).
769 namespace {

770 struct ConvertMathToFuncsPass

771 : public impl::ConvertMathToFuncsBase {

772 ConvertMathToFuncsPass() = default;

773 ConvertMathToFuncsPass(const ConvertMathToFuncsOptions &options)

774 : impl::ConvertMathToFuncsBase(options) {}

775

776 void runOnOperation() override;

777

778 private:

779

780

781

// True when the fpowi exponent is an integer type at least
// minWidthOfFPowIExponent bits wide.
782 bool isFPowIConvertible(math::FPowIOp op);

783

784

// Conversion predicate used for ipowi and ctlz ops.
785 bool isConvertible(Operation *op);

786

787

788

// Walks the module and creates at most one helper per (op name, type) key,
// recording it in funcImpls.
789 void generateOpImplementations();

790

791

792

793

// NOTE(review): the funcImpls member declaration (line 794) was lost in
// extraction. Its uses show it maps std::pair(OperationName, Type) to
// func::FuncOp and supports try_emplace/find. Presumably it is an
// llvm::DenseMap (DenseMap.h is included); confirm against upstream.
795 };

796 }

797

// Returns true when `op`'s exponent type (expTy) is present and at least
// minWidthOfFPowIExponent bits wide. Narrower exponents are left unconverted.
// NOTE(review): expTy's initializer (line 800) was lost in extraction.
798 bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) {

799 auto expTy =

801 return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent);

802 }

803

// Predicate deciding whether an ipowi/ctlz op is lowered to a helper call.
// It is used both when generating helpers and in the conversion target's
// legality callbacks.
// NOTE(review): the body (line 805) was lost in extraction; restore from
// upstream.
804 bool ConvertMathToFuncsPass::isConvertible(Operation *op) {

806 }

807

// Creates the helper functions that the lowering patterns later call. Each
// helper is keyed by (op name, type): the result type for ctlz/ipowi and the
// elemental function type for fpowi. try_emplace inserts a null placeholder
// and reports `second == true` only on first insertion, so each helper is
// created at most once per key.
808 void ConvertMathToFuncsPass::generateOpImplementations() {

809 ModuleOp module = getOperation();

810

// NOTE(review): the module walk / TypeSwitch header lines (811-812) were lost
// in extraction.
813 .Casemath::CountLeadingZerosOp([&](math::CountLeadingZerosOp op) {

// ctlz is only converted when the convertCtlz option is set.
814 if (!convertCtlz || !isConvertible(op))

815 return;

817

818

819

820 auto key = std::pair(op->getName(), resultType);

821 auto entry = funcImpls.try_emplace(key, func::FuncOp{});

822 if (entry.second)

823 entry.first->second = createCtlzFunc(&module, resultType);

824 })

825 .Casemath::IPowIOp([&](math::IPowIOp op) {

826 if (!isConvertible(op))

827 return;

828

830

831

832

833 auto key = std::pair(op->getName(), resultType);

834 auto entry = funcImpls.try_emplace(key, func::FuncOp{});

// NOTE(review): the assignment guarded here (line 836) was lost; presumably
// it calls createElementIPowIFunc.
835 if (entry.second)

837 })

838 .Casemath::FPowIOp([&](math::FPowIOp op) {

// Narrow exponents are skipped; see isFPowIConvertible.
839 if (!isFPowIConvertible(op))

840 return;

841

843

844

845

846

847

848 auto key = std::pair(op->getName(), funcType);

849 auto entry = funcImpls.try_emplace(key, func::FuncOp{});

// NOTE(review): the assignment guarded here (line 851) was lost; presumably
// it calls createElementFPowIFunc.
850 if (entry.second)

852 });

853 });

854 }

855

// Pass entry point. It generates helper functions, registers vector-unrolling
// and call-lowering patterns, and marks convertible math ops illegal. It
// signals failure when the conversion does not succeed.
856 void ConvertMathToFuncsPass::runOnOperation() {

857 ModuleOp module = getOperation();

858

859

// Helpers must exist before patterns run: the lowering patterns below only
// look functions up via getFuncOpByType; they never create them.
860 generateOpImplementations();

861

// Unroll vector-typed ipowi/fpowi/ctlz into scalar ops.
// NOTE(review): the `patterns` declaration (862) and constructor arguments
// (865) were lost in extraction.
863 patterns.add<VecOpToScalarOpmath::IPowIOp, VecOpToScalarOpmath::FPowIOp,

864 VecOpToScalarOpmath::CountLeadingZerosOp>(

866

867

// Looks up the helper generated for (op name, type). It returns a null
// FuncOp when none exists, which makes the lowering patterns report a match
// failure.
868 auto getFuncOpByType = [&](Operation *op, Type type) -> func::FuncOp {

869 auto it = funcImpls.find(std::pair(op->getName(), type));

870 if (it == funcImpls.end())

871 return {};

872

873 return it->second;

874 };

875 patterns.add<IPowIOpLowering, FPowIOpLowering>(patterns.getContext(),

876 getFuncOpByType);

877

// NOTE(review): the pattern template argument on the next call was lost in
// extraction; presumably CtlzOpLowering.
878 if (convertCtlz)

879 patterns.add(patterns.getContext(), getFuncOpByType);

880

// The dialects that the generated calls and unrolled code are built from are
// legal.
882 target.addLegalDialect<arith::ArithDialect, cf::ControlFlowDialect,

883 func::FuncDialect, scf::SCFDialect,

884 vector::VectorDialect>();

885

// An op stays legal (unconverted) exactly when its predicate rejects it.
886 target.addDynamicallyLegalOpmath::IPowIOp(

887 [this](math::IPowIOp op) { return !isConvertible(op); });

888 if (convertCtlz) {

889 target.addDynamicallyLegalOpmath::CountLeadingZerosOp(

890 [this](math::CountLeadingZerosOp op) { return !isConvertible(op); });

891 }

892 target.addDynamicallyLegalOpmath::FPowIOp(

893 [this](math::FPowIOp op) { return !isFPowIConvertible(op); });

// NOTE(review): the conversion call guarding this (line 894) was lost;
// presumably `if (failed(applyPartialConversion(module, target, ...)))`.
895 signalPassFailure();

896 }

static MLIRContext * getContext(OpFoldResult val)

static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType)

Create linkonce_odr function to implement the power function with the given elementType type inside module.

static FunctionType getElementalFuncTypeForOp(Operation *op)

static func::FuncOp createElementFPowIFunc(ModuleOp *module, FunctionType funcType)

Create linkonce_odr function to implement the power function with the given funcType type inside module.

static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType)

Create function to implement the ctlz function for the given elementType type inside module.

static std::string diag(const llvm::Value &value)

static llvm::ManagedStatic< PassManagerOptions > options

Attributes are known-constant values of operations.

Block represents an ordered list of Operations.

BlockArgument getArgument(unsigned i)

Region * getParent() const

Provide a 'getParent' method for ilist_node_with_parent methods.

IntegerAttr getIntegerAttr(Type type, int64_t value)

FloatAttr getFloatAttr(Type type, double value)

MLIRContext * getContext() const

This class describes a specific conversion target.

static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)

Constructs a dense elements attribute from an array of element values.

This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.

ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.

Location getLoc() const

Accessors for the implied location.

OpTy create(Args &&...args)

Create an operation of specific op type at the current insertion point and location.

static ImplicitLocOpBuilder atBlockEnd(Location loc, Block *block, Listener *listener=nullptr)

Create a builder and set the insertion point to after the last operation in the block but still inside the block.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.

MLIRContext is the top-level object for a collection of MLIR operations.

This class helps build Operations.

void setInsertionPointToStart(Block *block)

Sets the insertion point to the start of the specified block.

void setInsertionPointToEnd(Block *block)

Sets the insertion point to the end of the specified block.

Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)

Add new block with 'argTypes' arguments and set the insertion point to the end of it.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

Location getLoc()

The source location the operation was defined or derived from.

This provides public APIs that all operations should have.

Operation is the basic unit of execution within MLIR.

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

operand_type_iterator operand_type_end()

MLIRContext * getContext()

Return the context this operation is associated with.

unsigned getNumOperands()

result_type_iterator result_type_end()

static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)

Create a new Operation with the specific fields.

result_type_iterator result_type_begin()

OperationName getName()

The name of an operation is the key identifier for it.

unsigned getNumResults()

Return the number of results held by this operation.

operand_type_iterator operand_type_begin()

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements...

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

void print(raw_ostream &os) const

Print the current type.

unsigned getIntOrFloatBitWidth() const

Return the bit width of an integer or a float type, assert failure on other types.

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

Include the generated interface declarations.

SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)

SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)

Given the strides together with a linear index in the dimension space, return the vector-space offset...

Type getElementTypeOrSelf(Type type)

Return the element type or return the type itself.

const FrozenRewritePatternSet & patterns

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed; this helps unify some of the attribute construction methods.

LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())

Below we define several entry points for operation conversion.

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.