MLIR: lib/Dialect/Math/Transforms/UpliftToFMA.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
18
20 #define GEN_PASS_DEF_MATHUPLIFTTOFMA
21 #include "mlir/Dialect/Math/Transforms/Passes.h.inc"
22 }
23
24 using namespace mlir;
25
26 template
29 }
30
31 namespace {
32
35
36 LogicalResult matchAndRewrite(arith::AddFOp op,
39 return rewriter.notifyMatchFailure(op, "addf op is not suitable for fma");
40
42 arith::MulFOp ab;
43 if ((ab = op.getLhs().getDefiningOparith::MulFOp())) {
44 c = op.getRhs();
45 } else if ((ab = op.getRhs().getDefiningOparith::MulFOp())) {
46 c = op.getLhs();
47 } else {
49 }
50
52 return rewriter.notifyMatchFailure(ab, "mulf op is not suitable for fma");
53
54 Value a = ab.getLhs();
55 Value b = ab.getRhs();
56 arith::FastMathFlags fmf = op.getFastmath() & ab.getFastmath();
58 return success();
59 }
60 };
61
62 struct MathUpliftToFMA final
63 : math::impl::MathUpliftToFMABase {
64 using MathUpliftToFMABase::MathUpliftToFMABase;
65
66 void runOnOperation() override {
70 return signalPassFailure();
71 }
72 };
73
74 }
75
78 }
static MLIRContext * getContext(OpFoldResult val)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static bool isValidForFMA(Op op)
This provides public APIs that all operations should have.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
void populateUpliftToFMAPatterns(RewritePatternSet &patterns)
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.