MLIR: lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
14
15 using namespace mlir;
16
17
19 if (!val)
20 return false;
21
22
23
25 return true;
26
28 .Case<linalg::FillOp, linalg::CopyOp>([&](auto op) {
29 return op && op.getInputs().size() == 1 &&
31 })
32 .Default([&](auto) { return false; });
33 }
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
57
60
61 if (!addOp.hasPureTensorSemantics())
62 return failure();
63
64 Value dominatingOperand = nullptr;
65 linalg::LinalgOp dominatedOp = nullptr;
66 {
67 Value lhs = addOp.getInputs()[0];
68 Value rhs = addOp.getInputs()[1];
69
70
71
72
73
74 if (auto rhsOp = rhs.getDefiningOplinalg::LinalgOp()) {
77 dominatingOperand = lhs;
78 dominatedOp = rhsOp;
79 }
80 }
81 if (auto lhsOp = lhs.getDefiningOplinalg::LinalgOp()) {
84 dominatingOperand = rhs;
85 dominatedOp = lhsOp;
86 }
87 }
88 if (!dominatingOperand || !dominatedOp)
89 return failure();
90
91
92 }
93
94
95
96
97
98 auto dominatedDestOp =
99 dyn_cast((Operation *)dominatedOp);
100 if (dominatedOp->getNumResults() != 1 ||
102 (!dominatedDestOp || dominatedDestOp.getNumDpsInits() != 1))
104 dominatedOp, "expected dominated op to be single-result "
105 "destination-passing contraction");
106
107
108 if (!dominatedOp->getResult(0).hasOneUse())
110 dominatedOp,
111 "expected linalg.add to be single user of contraction's result");
112
113
114
115 auto *destOperand = dominatedDestOp.getDpsInitOperand(0);
118 dominatedOp, "expected dominated op's dest to be additive zero");
119
120
121
122
123
124
125
127 int prevDimPos = -1;
128 for (auto expr : indexMaps[destOperand->getOperandNumber()].getResults()) {
129 auto dim = dyn_cast(expr);
130 if (!dim || prevDimPos > static_cast<int>(dim.getPosition()))
132 dominatedOp, "expected index_map for contraction's dest to be an "
133 "ordered projection");
134 prevDimPos = dim.getPosition();
135 }
136
137
138
139
141 dominatedOp, [&]() { dominatedOp->setOperand(2, dominatingOperand); });
143 return success();
144 }
145 };
146
148
150 }
static bool isDefinedAsZero(Value val)
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
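Return true if operation A properly dominates operation B, i.e. if A and B are in the same block and A properly dominates B within the block, or if the block that contains A properly dominates the block that contains B.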
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
const FrozenRewritePatternSet & patterns
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
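Replace a linalg.add in which one operand is the result of a contraction whose only user is that add, and that contraction has a zero-filled destination with an ordered-projection indexing map and is dominated by the other operand: rewrite the contraction to use the other operand as its destination.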
LogicalResult matchAndRewrite(linalg::AddOp addOp, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...