MLIR: lib/Dialect/Math/Transforms/UpliftToFMA.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

18

20 #define GEN_PASS_DEF_MATHUPLIFTTOFMA

21 #include "mlir/Dialect/Math/Transforms/Passes.h.inc"

22 }

23

24 using namespace mlir;

25

26 template

29 }

30

31 namespace {

32

35

36 LogicalResult matchAndRewrite(arith::AddFOp op,

39 return rewriter.notifyMatchFailure(op, "addf op is not suitable for fma");

40

42 arith::MulFOp ab;

43 if ((ab = op.getLhs().getDefiningOparith::MulFOp())) {

44 c = op.getRhs();

45 } else if ((ab = op.getRhs().getDefiningOparith::MulFOp())) {

46 c = op.getLhs();

47 } else {

49 }

50

52 return rewriter.notifyMatchFailure(ab, "mulf op is not suitable for fma");

53

54 Value a = ab.getLhs();

55 Value b = ab.getRhs();

56 arith::FastMathFlags fmf = op.getFastmath() & ab.getFastmath();

58 return success();

59 }

60 };

61

62 struct MathUpliftToFMA final

63 : math::impl::MathUpliftToFMABase {

64 using MathUpliftToFMABase::MathUpliftToFMABase;

65

66 void runOnOperation() override {

70 return signalPassFailure();

71 }

72 };

73

74 }

75

78 }

static MLIRContext * getContext(OpFoldResult val)

static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)

Contracts the specified cycle in the given graph in-place.

static bool isValidForFMA(Op op)

This provides public APIs that all operations should have.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure,...

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Include the generated interface declarations.

LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)

Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...

void populateUpliftToFMAPatterns(RewritePatternSet &patterns)

const FrozenRewritePatternSet & patterns

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...

OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})

Patterns must specify the root operation name they match against, and can also specify the benefit of...