MLIR: lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

14

15 using namespace mlir;

16

17

19 if (!val)

20 return false;

21

22

23

25 return true;

26

28 .Case<linalg::FillOp, linalg::CopyOp>([&](auto op) {

29 return op && op.getInputs().size() == 1 &&

31 })

32 .Default([&](auto) { return false; });

33 }

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

57

60

61 if (!addOp.hasPureTensorSemantics())

62 return failure();

63

64 Value dominatingOperand = nullptr;

65 linalg::LinalgOp dominatedOp = nullptr;

66 {

67 Value lhs = addOp.getInputs()[0];

68 Value rhs = addOp.getInputs()[1];

69

70

71

72

73

74 if (auto rhsOp = rhs.getDefiningOplinalg::LinalgOp()) {

77 dominatingOperand = lhs;

78 dominatedOp = rhsOp;

79 }

80 }

81 if (auto lhsOp = lhs.getDefiningOplinalg::LinalgOp()) {

84 dominatingOperand = rhs;

85 dominatedOp = lhsOp;

86 }

87 }

88 if (!dominatingOperand || !dominatedOp)

89 return failure();

90

91

92 }

93

94

95

96

97

98 auto dominatedDestOp =

99 dyn_cast((Operation *)dominatedOp);

100 if (dominatedOp->getNumResults() != 1 ||

102 (!dominatedDestOp || dominatedDestOp.getNumDpsInits() != 1))

104 dominatedOp, "expected dominated op to be single-result "

105 "destination-passing contraction");

106

107

108 if (!dominatedOp->getResult(0).hasOneUse())

110 dominatedOp,

111 "expected linalg.add to be single user of contraction's result");

112

113

114

115 auto *destOperand = dominatedDestOp.getDpsInitOperand(0);

118 dominatedOp, "expected dominated op's dest to be additive zero");

119

120

121

122

123

124

125

127 int prevDimPos = -1;

128 for (auto expr : indexMaps[destOperand->getOperandNumber()].getResults()) {

129 auto dim = dyn_cast(expr);

130 if (!dim || prevDimPos > static_cast<int>(dim.getPosition()))

132 dominatedOp, "expected index_map for contraction's dest to be an "

133 "ordered projection");

134 prevDimPos = dim.getPosition();

135 }

136

137

138

139

141 dominatedOp, [&]() { dominatedOp->setOperand(2, dominatingOperand); });

143 return success();

144 }

145 };

146

148

150 }

static bool isDefinedAsZero(Value val)

A class for computing basic dominance information.

bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const

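Return true if operation A properly dominates operation B, i.e. if A and B are in the same block and A properly dominates B within that block, or if the block that contains A properly dominates the block that contains B.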

Operation is the basic unit of execution within MLIR.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason why the failure occurred.

void modifyOpInPlace(Operation *root, CallableT &&callable)

This method is a utility wrapper around an in-place modification of an operation.

void replaceAllOpUsesWith(Operation *from, ValueRange to)

Find uses of `from` and replace them with `to`.

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it; otherwise (e.g. for a block argument) return null.

bool isaContractionOpInterface(LinalgOp linalgOp)

Checks whether linalgOp conforms to ContractionOpInterface.

void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)

Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the sum.

Include the generated interface declarations.

bool matchPattern(Value value, const Pattern &pattern)

Entry point for matching a pattern over a Value.

detail::constant_int_predicate_matcher m_Zero()

Matches a constant scalar / vector splat / tensor splat integer zero.

const FrozenRewritePatternSet & patterns

detail::constant_float_predicate_matcher m_AnyZeroFloat()

Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.

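Replace a linalg.add, one of whose operands is the single-use result of a contraction that has a zero-filled, "identity-mapped" destination and is dominated by the add's other operand, with that contraction, using the other operand as the contraction's destination.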

LogicalResult matchAndRewrite(linalg::AddOp addOp, PatternRewriter &rewriter) const override

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.