MLIR: lib/Conversion/MathToFuncs/MathToFuncs.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
10
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/Debug.h"
27
// Pull in the TableGen-generated definition for this pass; presumably this
// provides impl::ConvertMathToFuncsBase and ConvertMathToFuncsOptions used by
// ConvertMathToFuncsPass below.
28 namespace mlir {
29 #define GEN_PASS_DEF_CONVERTMATHTOFUNCS
30 #include "mlir/Conversion/Passes.h.inc"
31 }
32
33 using namespace mlir;
34
// Debug category name for this file; DBGS() starts an llvm::dbgs() line
// prefixed with "[math-to-funcs]: ".
35 #define DEBUG_TYPE "math-to-funcs"
36 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
37
38 namespace {
39
// Rewrite pattern templated on a math op type `Op`: its matchAndRewrite
// unrolls an `Op` with a vector result into one scalar `Op` per element.
// (The class header and constructor lines are not visible in this listing.)
40 template
42 public:
44
45 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
46 };
47
48
49
51
52
53
// The three lowering patterns below share one shape: they hold a
// GetFuncCallbackTy (its typedef is not visible here; runOnOperation passes a
// lambda taking (Operation *, Type) and returning func::FuncOp, null when no
// implementation exists) and replace the matched op with a func.call to the
// FuncOp it returns.

// Lowers math.ipowi to a func.call of the implementation registered for the
// op's base type; reports a match failure if none is registered.
54 class IPowIOpLowering : public OpRewritePatternmath::IPowIOp {
55 public:
56 IPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
58
59
60
61
62 LogicalResult matchAndRewrite(math::IPowIOp op,
64
65 private:
// Looks up the generated implementation func for a given op/type.
66 GetFuncCallbackTy getFuncOpCallback;
67 };
68
69
70
// Lowers math.fpowi to a func.call of the implementation registered for the
// op's elemental FunctionType; reports a match failure if none is registered.
71 class FPowIOpLowering : public OpRewritePatternmath::FPowIOp {
72 public:
73 FPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
75
76
77
78
79 LogicalResult matchAndRewrite(math::FPowIOp op,
81
82 private:
// Looks up the generated implementation func for a given op/type.
83 GetFuncCallbackTy getFuncOpCallback;
84 };
85
86
87
// Lowers math.ctlz (CountLeadingZerosOp) to a func.call of the implementation
// registered for the op's type. runOnOperation only adds this pattern when
// the convertCtlz option is set.
88 class CtlzOpLowering : public OpRewritePatternmath::CountLeadingZerosOp {
89 public:
90 CtlzOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
92 getFuncOpCallback(cb) {}
93
94
95
96 LogicalResult matchAndRewrite(math::CountLeadingZerosOp op,
98
99 private:
// Looks up the generated implementation func for a given op/type.
100 GetFuncCallbackTy getFuncOpCallback;
101 };
102 }
103
// VecOpToScalarOp<Op>::matchAndRewrite: unrolls an `Op` whose result type is
// a vector into one scalar `Op` per vector element.
//
// `result` starts as an arith.constant. Its element value is 0.0 for float
// element types; the non-float initializer and the rest of the constant's
// construction are on lines not visible here. Ops whose type is not a vector,
// or is an unranked vector, presumably bail out at the two early checks
// below; their return lines are not visible in this listing.
104 template
105 LogicalResult
106 VecOpToScalarOp::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
107 Type opType = op.getType();
109 auto vecType = dyn_cast(opType);
110
111 if (!vecType)
113 if (!vecType.hasRank())
116 int64_t numElements = vecType.getNumElements();
117
118 Type resultElementType = vecType.getElementType();
120 if (isa(resultElementType))
121 initValueAttr = FloatAttr::get(resultElementType, 0.0);
122 else
124 Value result = rewriter.createarith::ConstantOp(
127 for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
// Each iteration handles one element: extract it from every operand, build
// the scalar `Op` on those values with the vector's element type, and insert
// the scalar result back into `result` at the same position.
// NOTE(review): `positions` is computed on lines not visible here, presumably
// by delinearizing linearIndex into a multi-dimensional index — confirm.
130 for (Value input : op->getOperands())
131 operands.push_back(
132 rewriter.createvector::ExtractOp(loc, input, positions));
134 rewriter.create<Op>(loc, vecType.getElementType(), operands);
135 result =
136 rewriter.createvector::InsertOp(loc, scalarOp, result, positions);
137 }
139 return success();
140 }
141
// (Tail of getElementalFuncTypeForOp; its leading lines are not visible
// here.) Map each result type to its element type; non-shaped types map to
// themselves via getElementTypeOrSelf.
146 resultTys.begin(),
147 [](Type ty) { return getElementTypeOrSelf(ty); });
// Same mapping for the operand (input) types.
149 inputTys.begin(),
150 [](Type ty) { return getElementTypeOrSelf(ty); });
152 }
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
// createElementIPowIFunc (signature and some lines not visible here).
//
// Builds a private func.func with llvm.linkage = linkonce_odr, named
// "__mlir_math_ipowi_<elementType>", of type (T, T) -> T. The function
// computes b**p for integer element type T, where b is argument 0 and p is
// argument 1.
//
// Control flow of the generated function:
//   p == 0            -> return 1
//   p < 0:  b == 0    -> return arith.divsi(1, 0)
//           b == 1    -> return 1
//           b == -1   -> return -1 if p is odd, else 1
//           otherwise -> return 0
//   p > 0             -> square-and-multiply loop over (result, base, power)
// Only integer element types are accepted.
186 assert(isa(elementType) &&
187 "non-integer element type for IPowIOp");
188
191
// Mangle the element type into the name so each type gets its own function.
192 std::string funcName("__mlir_math_ipowi");
193 llvm::raw_string_ostream nameOS(funcName);
194 nameOS << '_' << elementType;
195
197 builder.getContext(), {elementType, elementType}, elementType);
198 auto funcOp = builder.createfunc::FuncOp(funcName, funcType);
199 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
202 funcOp->setAttr("llvm.linkage", linkage);
203 funcOp.setPrivate();
204
205 Block *entryBlock = funcOp.addEntryBlock();
207
208 Value bArg = funcOp.getArgument(0);
209 Value pArg = funcOp.getArgument(1);
211 Value zeroValue = builder.createarith::ConstantOp(
212 elementType, builder.getIntegerAttr(elementType, 0));
213 Value oneValue = builder.createarith::ConstantOp(
214 elementType, builder.getIntegerAttr(elementType, 1));
215 Value minusOneValue = builder.createarith::ConstantOp(
216 elementType,
219 true)));
220
221
222
// if (p == 0) return 1;
223 auto pIsZero =
224 builder.createarith::CmpIOp(arith::CmpIPredicate::eq, pArg, zeroValue);
226 builder.createfunc::ReturnOp(oneValue);
228
230 builder.createcf::CondBranchOp(pIsZero, thenBlock, fallthroughBlock);
231
232
// pIsNeg uses `sle`, but p == 0 has already returned above, so on this path
// it is true exactly when p < 0.
234 auto pIsNeg =
235 builder.createarith::CmpIOp(arith::CmpIPredicate::sle, pArg, zeroValue);
236
// Negative-exponent chain, first test: b == 0 returns arith.divsi(1, 0), a
// deliberate integer division by zero.
// NOTE(review): the arith dialect documents divsi by zero as undefined
// behavior, so this path has no defined result — confirm this is intended.
238 auto bIsZero =
239 builder.createarith::CmpIOp(arith::CmpIPredicate::eq, bArg, zeroValue);
240
241 thenBlock = builder.createBlock(funcBody);
242 builder.createfunc::ReturnOp(
243 builder.createarith::DivSIOp(oneValue, zeroValue).getResult());
244 fallthroughBlock = builder.createBlock(funcBody);
245
247 builder.createcf::CondBranchOp(bIsZero, thenBlock, fallthroughBlock);
248
249
// if (b == 1) return 1;
251 auto bIsOne =
252 builder.createarith::CmpIOp(arith::CmpIPredicate::eq, bArg, oneValue);
253
254 thenBlock = builder.createBlock(funcBody);
255 builder.createfunc::ReturnOp(oneValue);
256 fallthroughBlock = builder.createBlock(funcBody);
257
259 builder.createcf::CondBranchOp(bIsOne, thenBlock, fallthroughBlock);
260
261
// if (b == -1) return (p & 1) ? -1 : 1;
// The parity test lives in its own block, which the bIsMinusOne branch below
// targets through pIsOdd->getBlock().
263 auto bIsMinusOne = builder.createarith::CmpIOp(arith::CmpIPredicate::eq,
264 bArg, minusOneValue);
265
267 auto pIsOdd = builder.createarith::CmpIOp(
268 arith::CmpIPredicate::ne, builder.createarith::AndIOp(pArg, oneValue),
269 zeroValue);
270
271 thenBlock = builder.createBlock(funcBody);
272 builder.createfunc::ReturnOp(minusOneValue);
273 fallthroughBlock = builder.createBlock(funcBody);
274
276 builder.createcf::CondBranchOp(pIsOdd, thenBlock, fallthroughBlock);
277
278
279
281 builder.createfunc::ReturnOp(oneValue);
282 fallthroughBlock = builder.createBlock(funcBody);
283
285 builder.createcf::CondBranchOp(bIsMinusOne, pIsOdd->getBlock(),
286 fallthroughBlock);
287
288
289
// Any other base with a negative exponent yields 0.
291 builder.createfunc::ReturnOp(zeroValue);
293 funcBody, funcBody->end(), {elementType, elementType, elementType},
295
297
// Negative p enters the chain starting at the b == 0 test. Positive p jumps
// to the loop header, which carries three values: (result, base, power).
// The initial loop operands are on a line not visible here, presumably
// (1, b, p) — confirm.
298 builder.createcf::CondBranchOp(pIsNeg, bIsZero->getBlock(), loopHeader,
300
301
302
303
304
305
306
307
308
309
310 Value resultTmp = loopHeader->getArgument(0);
311 Value baseTmp = loopHeader->getArgument(1);
312 Value powerTmp = loopHeader->getArgument(2);
314
315
// if (power & 1) result *= base;
// Both paths merge into fallthroughBlock, whose block argument is the updated
// result.
316 auto powerTmpIsOdd = builder.createarith::CmpIOp(
317 arith::CmpIPredicate::ne,
318 builder.createarith::AndIOp(powerTmp, oneValue), zeroValue);
319 thenBlock = builder.createBlock(funcBody);
320
321 Value newResultTmp = builder.createarith::MulIOp(resultTmp, baseTmp);
322 fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), elementType,
325 builder.createcf::BranchOp(newResultTmp, fallthroughBlock);
326
328 builder.createcf::CondBranchOp(powerTmpIsOdd, thenBlock, fallthroughBlock,
329 resultTmp);
330
331 newResultTmp = fallthroughBlock->getArgument(0);
332
333
// power >>= 1 (unsigned shift). power is positive inside the loop, so this
// is equivalent to a signed halving.
335 Value newPowerTmp = builder.createarith::ShRUIOp(powerTmp, oneValue);
336
337
// if (power == 0) return result;
338 auto newPowerIsZero = builder.createarith::CmpIOp(arith::CmpIPredicate::eq,
339 newPowerTmp, zeroValue);
340
341 thenBlock = builder.createBlock(funcBody);
342 builder.createfunc::ReturnOp(newResultTmp);
343 fallthroughBlock = builder.createBlock(funcBody);
344
346 builder.createcf::CondBranchOp(newPowerIsZero, thenBlock, fallthroughBlock);
347
348
349
// base *= base; jump back to the loop header.
351 Value newBaseTmp = builder.createarith::MulIOp(baseTmp, baseTmp);
352
353 builder.createcf::BranchOp(
354 ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
355 return funcOp;
356 }
357
358
359
360
// Replaces a math.ipowi with a func.call to the software implementation
// registered for its base type.
// - If the base operand's type fails the dyn_cast below (target type not
//   visible here), the op is presumably rejected; that return line is
//   missing.
// - If getFuncOpCallback yields a null FuncOp, reports a match failure.
361 LogicalResult
362 IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
364 auto baseType = dyn_cast(op.getOperands()[0].getType());
365
366 if (!baseType)
368
369
370
371 func::FuncOp elementFunc = getFuncOpCallback(op, baseType);
372 if (!elementFunc)
373 return rewriter.notifyMatchFailure(op, "missing software implementation");
374
// The call takes the original (base, exponent) operands unchanged.
375 rewriter.replaceOpWithNewOpfunc::CallOp(op, elementFunc, op.getOperands());
376 return success();
377 }
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
413 FunctionType funcType) {
// createElementFPowIFunc (first signature line not visible here).
//
// Builds a private func.func with llvm.linkage = linkonce_odr, named
// "__mlir_math_fpowi_<B>_<P>", using the given funcType. It computes b**p,
// where b (argument 0) has float type B and p (argument 1) has integer type P.
//
// Generated logic:
//   p == 0 -> return 1.0
//   otherwise:
//     pInit = |p|, except that p == INT_MIN (signed minimum of P) maps to
//       INT_MAX, because -INT_MIN is not representable.
//     A square-and-multiply loop computes b**pInit.
//     If p == INT_MIN, the result is multiplied by b once more, since
//       INT_MAX + 1 == |INT_MIN|.
//     If p < 0, the result becomes 1.0 / result.
414 auto baseType = cast(funcType.getInput(0));
415 auto powType = cast(funcType.getInput(1));
418
// Mangle both operand types into the name so each (B, P) pair gets its own
// function.
419 std::string funcName("__mlir_math_fpowi");
420 llvm::raw_string_ostream nameOS(funcName);
421 nameOS << '_' << baseType;
422 nameOS << '_' << powType;
423 auto funcOp = builder.createfunc::FuncOp(funcName, funcType);
424 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
427 funcOp->setAttr("llvm.linkage", linkage);
428 funcOp.setPrivate();
429
430 Block *entryBlock = funcOp.addEntryBlock();
432
433 Value bArg = funcOp.getArgument(0);
434 Value pArg = funcOp.getArgument(1);
436 Value oneBValue = builder.createarith::ConstantOp(
437 baseType, builder.getFloatAttr(baseType, 1.0));
438 Value zeroPValue = builder.createarith::ConstantOp(
440 Value onePValue = builder.createarith::ConstantOp(
// Signed min/max of the exponent type, used to special-case p == INT_MIN.
442 Value minPValue = builder.createarith::ConstantOp(
443 powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMinValue(
444 powType.getWidth())));
445 Value maxPValue = builder.createarith::ConstantOp(
446 powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMaxValue(
447 powType.getWidth())));
448
449
450
// if (p == 0) return 1.0;
451 auto pIsZero =
452 builder.createarith::CmpIOp(arith::CmpIPredicate::eq, pArg, zeroPValue);
454 builder.createfunc::ReturnOp(oneBValue);
456
458 builder.createcf::CondBranchOp(pIsZero, thenBlock, fallthroughBlock);
459
461
// p == 0 has already returned, so `sle` here is true exactly for p < 0.
462 auto pIsNeg = builder.createarith::CmpIOp(arith::CmpIPredicate::sle, pArg,
463 zeroPValue);
464
465 auto pIsMin =
466 builder.createarith::CmpIOp(arith::CmpIPredicate::eq, pArg, minPValue);
467
468
469
470
471
472
// pInit = p < 0 ? -p : p, then INT_MIN -> INT_MAX.
// (0 - INT_MIN is not representable; the missing factor of b is restored
// after the loop.)
473 Value negP = builder.createarith::SubIOp(zeroPValue, pArg);
474 auto pInit = builder.createarith::SelectOp(pIsNeg, negP, pArg);
475 pInit = builder.createarith::SelectOp(pIsMin, maxPValue, pInit);
476
477
478
479
480
481
482
483
484
485
486
// Loop header carries (result, base, power); it is entered with (1.0, b,
// pInit).
488 funcBody, funcBody->end(), {baseType, baseType, powType},
490
492 builder.createcf::BranchOp(loopHeader, ValueRange{oneBValue, bArg, pInit});
493
494
495 Value resultTmp = loopHeader->getArgument(0);
496 Value baseTmp = loopHeader->getArgument(1);
497 Value powerTmp = loopHeader->getArgument(2);
499
500
// if (power & 1) result *= base;
// Both paths merge into fallthroughBlock, whose block argument is the updated
// result.
501 auto powerTmpIsOdd = builder.createarith::CmpIOp(
502 arith::CmpIPredicate::ne,
503 builder.createarith::AndIOp(powerTmp, onePValue), zeroPValue);
504 thenBlock = builder.createBlock(funcBody);
505
506 Value newResultTmp = builder.createarith::MulFOp(resultTmp, baseTmp);
507 fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
510 builder.createcf::BranchOp(newResultTmp, fallthroughBlock);
511
513 builder.createcf::CondBranchOp(powerTmpIsOdd, thenBlock, fallthroughBlock,
514 resultTmp);
515
516 newResultTmp = fallthroughBlock->getArgument(0);
517
518
// power >>= 1 (unsigned shift). pInit is non-negative, so this halves it.
520 Value newPowerTmp = builder.createarith::ShRUIOp(powerTmp, onePValue);
521
522
523 auto newPowerIsZero = builder.createarith::CmpIOp(arith::CmpIPredicate::eq,
524 newPowerTmp, zeroPValue);
525
526
527
528
// Loop continuation: base *= base, then back to the header.
529 fallthroughBlock = builder.createBlock(funcBody);
530
531
532
534 Value newBaseTmp = builder.createarith::MulFOp(baseTmp, baseTmp);
535
536 builder.createcf::BranchOp(
537 ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
538
539
540
541
// Loop exit takes the final result as a block argument. The conditional
// branch on newPowerIsZero is emitted only now, once both of its successors
// (loopExit and the continuation block) exist.
542 Block *loopExit = builder.createBlock(funcBody, funcBody->end(), baseType,
545 builder.createcf::CondBranchOp(newPowerIsZero, loopExit, newResultTmp,
547
548
549
550
// if (p == INT_MIN) result *= b;
// This compensates for having looped on INT_MAX instead of |INT_MIN|.
551 newResultTmp = loopExit->getArgument(0);
552 thenBlock = builder.createBlock(funcBody);
553 fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
556 builder.createcf::CondBranchOp(pIsMin, thenBlock, fallthroughBlock,
557 newResultTmp);
559 newResultTmp = builder.createarith::MulFOp(newResultTmp, bArg);
560 builder.createcf::BranchOp(newResultTmp, fallthroughBlock);
561
562
563
564
// if (p < 0) result = 1.0 / result;
// The value reaching returnBlock is the final result; the func.return itself
// is on lines not visible here.
565 newResultTmp = fallthroughBlock->getArgument(0);
566 thenBlock = builder.createBlock(funcBody);
567 Block *returnBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
570 builder.createcf::CondBranchOp(pIsNeg, thenBlock, returnBlock,
571 newResultTmp);
573 newResultTmp = builder.createarith::DivFOp(oneBValue, newResultTmp);
574 builder.createcf::BranchOp(newResultTmp, returnBlock);
575
576
579
580 return funcOp;
581 }
582
583
584
585
// Replaces a math.fpowi with a func.call to the software implementation
// registered for its elemental function type.
// - The isa<> check on the result type (its template argument is not visible
//   here) presumably rejects vector-typed ops; those are unrolled by
//   VecOpToScalarOp first.
// - funcType is computed on lines not visible here.
// - A null FuncOp from the callback yields a match failure.
586 LogicalResult
587 FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
589 if (isa(op.getType()))
591
593
594
595
596 func::FuncOp elementFunc = getFuncOpCallback(op, funcType);
597 if (!elementFunc)
598 return rewriter.notifyMatchFailure(op, "missing software implementation");
599
// The call takes the original (base, exponent) operands unchanged.
600 rewriter.replaceOpWithNewOpfunc::CallOp(op, elementFunc, op.getOperands());
601 return success();
602 }
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
// createCtlzFunc (signature and some lines not visible here).
//
// Builds a private func.func with llvm.linkage = linkonce_odr, named
// "__mlir_math_ctlz_<elementType>", that counts the leading zero bits of its
// integer argument.
//
// Generated logic:
//   arg == 0 -> bitWidth
//   otherwise -> an scf.for over [1, bitWidth) with step 1 that carries
//     (argIter, n), presumably starting from (arg, 0); the iter_args are on
//     lines not visible here.
//     While argIter's top bit is clear, it is shifted left by one and n is
//     incremented. Once the top bit is set, both values pass through
//     unchanged. The final n is the result.
//   A nonzero value has at most bitWidth - 1 leading zeros, so bitWidth - 1
//   iterations suffice.
//
// Only integer element types are supported; anything else is treated as a
// programmer error (debug print, then llvm_unreachable).
652 if (!isa(elementType)) {
653 LLVM_DEBUG({
654 DBGS() << "non-integer element type for CtlzFunc; type was: ";
655 elementType.print(llvm::dbgs());
656 });
657 llvm_unreachable("non-integer element type");
658 }
660
664
// Mangle the element type into the name so each type gets its own function.
665 std::string funcName("__mlir_math_ctlz");
666 llvm::raw_string_ostream nameOS(funcName);
667 nameOS << '_' << elementType;
668 FunctionType funcType =
670 auto funcOp = builder.createfunc::FuncOp(funcName, funcType);
671
672
673
674 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
677 funcOp->setAttr("llvm.linkage", linkage);
678 funcOp.setPrivate();
679
680
681 Block *funcBody = funcOp.addEntryBlock();
683
684 Value arg = funcOp.getArgument(0);
686 Value bitWidthValue = builder.createarith::ConstantOp(
687 elementType, builder.getIntegerAttr(elementType, bitWidth));
688 Value zeroValue = builder.createarith::ConstantOp(
689 elementType, builder.getIntegerAttr(elementType, 0));
690
691 Value inputEqZero =
692 builder.createarith::CmpIOp(arith::CmpIPredicate::eq, arg, zeroValue);
693
694
// Top-level scf.if: the then-region yields bitWidth for a zero input; the
// else-region holds the counting loop.
695 scf::IfOp ifOp = builder.createscf::IfOp(
696 elementType, inputEqZero, true, true);
697 ifOp.getThenBodyBuilder().createscf::YieldOp(loc, bitWidthValue);
698
699 auto elseBuilder =
701
702 Value oneIndex = elseBuilder.createarith::ConstantOp(
703 indexType, elseBuilder.getIndexAttr(1));
704 Value oneValue = elseBuilder.createarith::ConstantOp(
705 elementType, elseBuilder.getIntegerAttr(elementType, 1));
706 Value bitWidthIndex = elseBuilder.createarith::ConstantOp(
707 indexType, elseBuilder.getIndexAttr(bitWidth));
708 Value nValue = elseBuilder.createarith::ConstantOp(
709 elementType, elseBuilder.getIntegerAttr(elementType, 0));
710
// Loop bounds: lower = 1, upper = bitWidth, step = 1.
711 auto loop = elseBuilder.createscf::ForOp(
712 oneIndex, bitWidthIndex, oneIndex,
713
714
715
717
718
719
720
721
722
723
725 Value argIter = args[0];
726 Value nIter = args[1];
727
// NOTE(review): despite its name, argIsNonNegative is `argIter < 0` (slt).
// It is therefore true when the top bit is SET, which is when counting has
// finished. Its then-region passes (argIter, nIter) through unchanged; the
// else-region shifts and counts. Consider renaming it to argIsNegative.
728 Value argIsNonNegative = b.createarith::CmpIOp(
729 loc, arith::CmpIPredicate::slt, argIter, zeroValue);
730 scf::IfOp ifOp = b.createscf::IfOp(
731 loc, argIsNonNegative,
733
734 b.createscf::YieldOp(loc, ValueRange{argIter, nIter});
735 },
737
// Top bit still clear: one more leading zero, then shift the next bit up.
738 Value nNext = b.createarith::AddIOp(loc, nIter, oneValue);
739 Value argNext = b.createarith::ShLIOp(loc, argIter, oneValue);
740 b.createscf::YieldOp(loc, ValueRange{argNext, nNext});
741 });
742 b.createscf::YieldOp(loc, ifOp.getResults());
743 });
// Loop result 1 is the leading-zero count n.
744 elseBuilder.createscf::YieldOp(loop.getResult(1));
745
746 builder.createfunc::ReturnOp(ifOp.getResult(0));
747 return funcOp;
748 }
749
750
751
// Replaces a math.ctlz with a func.call to the software implementation
// registered for its type.
// - The isa<> check on the op type (template argument not visible here)
//   presumably rejects vector-typed ops, which VecOpToScalarOp unrolls.
// - `type` is computed on lines not visible here.
// - A null FuncOp produces a match failure whose diagnostic names the op and
//   the type.
752 LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op,
754 if (isa(op.getType()))
756
758 func::FuncOp elementFunc = getFuncOpCallback(op, type);
759 if (!elementFunc)
761 diag << "Missing software implementation for op " << op->getName()
762 << " and type " << type;
763 });
764
// ctlz is unary, so the call takes the single operand.
765 rewriter.replaceOpWithNewOpfunc::CallOp(op, elementFunc, op.getOperand());
766 return success();
767 }
768
769 namespace {
// Module pass that replaces selected math ops with func.calls to generated
// software implementations:
// - math.ipowi and math.fpowi are always candidates;
// - math.ctlz is a candidate only when the convertCtlz option is set.
// Vector-typed ops are first unrolled to scalar ops.
// NOTE(review): the implementation cache `funcImpls` used by the methods is
// declared on lines not visible here. From its uses it is keyed by
// (OperationName, Type) and maps to func::FuncOp.
770 struct ConvertMathToFuncsPass
771 : public impl::ConvertMathToFuncsBase {
772 ConvertMathToFuncsPass() = default;
773 ConvertMathToFuncsPass(const ConvertMathToFuncsOptions &options)
774 : impl::ConvertMathToFuncsBase(options) {}
775
776 void runOnOperation() override;
777
778 private:
779
780
781
// True if this math.fpowi's exponent type is wide enough to convert
// (see minWidthOfFPowIExponent).
782 bool isFPowIConvertible(math::FPowIOp op);
783
784
// Predicate shared by implementation generation and the legality callbacks
// for IPowIOp and CountLeadingZerosOp.
785 bool isConvertible(Operation *op);
786
787
788
// Creates one implementation func per distinct (op name, type) among the
// convertible ops in the module.
789 void generateOpImplementations();
790
791
792
793
795 };
796 }
797
// Returns true when the exponent type of `op` passes the cast (its
// initializer is on a line not visible here, presumably an integer-type
// dyn_cast of the exponent's element type) and is at least
// minWidthOfFPowIExponent bits wide.
798 bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) {
799 auto expTy =
801 return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent);
802 }
803
// Decides whether `op` should be replaced by a call to a generated
// implementation.
// Callers:
// - generateOpImplementations, to decide which IPowIOp/CountLeadingZerosOp
//   implementations to create;
// - runOnOperation, where the legality callbacks treat those ops as illegal
//   exactly when this returns true.
// NOTE(review): the predicate's body is not visible in this listing.
804 bool ConvertMathToFuncsPass::isConvertible(Operation *op) {
806 }
807
// Visits ops in the module; the walk/TypeSwitch header lines are not visible
// here. For each convertible ctlz, ipowi and fpowi op, ensures that funcImpls
// holds an implementation func for the op's (name, type) key.
// try_emplace first inserts a null FuncOp placeholder, so each function is
// created only when its key is inserted for the first time. Later ops with
// the same key reuse it.
808 void ConvertMathToFuncsPass::generateOpImplementations() {
809 ModuleOp module = getOperation();
810
813 .Casemath::CountLeadingZerosOp([&](math::CountLeadingZerosOp op) {
814 if (!convertCtlz || !isConvertible(op))
815 return;
// Key on the ctlz result type; resultType is computed on lines not visible
// here.
817
818
819
820 auto key = std::pair(op->getName(), resultType);
821 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
822 if (entry.second)
823 entry.first->second = createCtlzFunc(&module, resultType);
824 })
825 .Casemath::IPowIOp([&](math::IPowIOp op) {
826 if (!isConvertible(op))
827 return;
828
// Key on the ipowi element type; the creation call for a new entry is on a
// line not visible here.
830
831
832
833 auto key = std::pair(op->getName(), resultType);
834 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
835 if (entry.second)
837 })
838 .Casemath::FPowIOp([&](math::FPowIOp op) {
839 if (!isFPowIConvertible(op))
840 return;
841
// fpowi is keyed on its full elemental FunctionType, because base and
// exponent types vary independently.
843
844
845
846
847
848 auto key = std::pair(op->getName(), funcType);
849 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
850 if (entry.second)
852 });
853 });
854 }
855
// Pass entry point:
// 1. Generate implementation funcs for all convertible ops.
// 2. Register patterns that unroll vector ops and replace scalar ops with
//    calls to those funcs.
// 3. Run a partial conversion in which the converted math ops are illegal.
// On conversion failure the pass is marked failed.
856 void ConvertMathToFuncsPass::runOnOperation() {
857 ModuleOp module = getOperation();
858
859
// Implementations must exist before the call-lowering patterns look them up.
860 generateOpImplementations();
861
// Unroll vector-typed ipowi/fpowi/ctlz into per-element scalar ops.
863 patterns.add<VecOpToScalarOpmath::IPowIOp, VecOpToScalarOpmath::FPowIOp,
864 VecOpToScalarOpmath::CountLeadingZerosOp>(
866
867
// Lookup shared by all lowering patterns: the implementation registered for
// (op name, type), or a null FuncOp if none was generated.
868 auto getFuncOpByType = [&](Operation *op, Type type) -> func::FuncOp {
869 auto it = funcImpls.find(std::pair(op->getName(), type));
870 if (it == funcImpls.end())
871 return {};
872
873 return it->second;
874 };
875 patterns.add<IPowIOpLowering, FPowIOpLowering>(patterns.getContext(),
876 getFuncOpByType);
877
878 if (convertCtlz)
879 patterns.add(patterns.getContext(), getFuncOpByType);
880
// The ops produced by the generated functions and the rewrites belong to
// these dialects, which are all legal.
882 target.addLegalDialect<arith::ArithDialect, cf::ControlFlowDialect,
883 func::FuncDialect, scf::SCFDialect,
884 vector::VectorDialect>();
885
// Each math op stays legal unless it is one this pass was asked to convert.
886 target.addDynamicallyLegalOpmath::IPowIOp(
887 [this](math::IPowIOp op) { return !isConvertible(op); });
888 if (convertCtlz) {
889 target.addDynamicallyLegalOpmath::CountLeadingZerosOp(
890 [this](math::CountLeadingZerosOp op) { return !isConvertible(op); });
891 }
892 target.addDynamicallyLegalOpmath::FPowIOp(
893 [this](math::FPowIOp op) { return !isFPowIConvertible(op); });
// The conversion call is on a line not visible here; signalPassFailure runs
// when it fails.
895 signalPassFailure();
896 }
static MLIRContext * getContext(OpFoldResult val)
static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType)
Create linkonce_odr function to implement the power function with the given elementType type inside module.
static FunctionType getElementalFuncTypeForOp(Operation *op)
static func::FuncOp createElementFPowIFunc(ModuleOp *module, FunctionType funcType)
Create linkonce_odr function to implement the power function with the given funcType type inside module.
static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType)
Create function to implement the ctlz function for the given elementType type inside module.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
MLIRContext * getContext() const
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
static ImplicitLocOpBuilder atBlockEnd(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
operand_type_iterator operand_type_end()
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumOperands()
result_type_iterator result_type_end()
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
result_type_iterator result_type_begin()
OperationName getName()
The name of an operation is the key identifier for it.
unsigned getNumResults()
Return the number of results held by this operation.
operand_type_iterator operand_type_begin()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
void print(raw_ostream &os) const
Print the current type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...