MLIR: lib/Transforms/Utils/WalkPatternRewriteDriver.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
14
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/ErrorHandling.h"
25
26#define DEBUG_TYPE "walk-rewriter"
27
28namespace mlir {
29
30
31
35 reachableBlocks.insert(entryBlock);
36
38 while (!worklist.empty()) {
39 Block *block = worklist.pop_back_val();
42 if (reachableBlocks.contains(successor))
43 continue;
44 worklist.push_back(successor);
45 reachableBlocks.insert(successor);
46 }
47 }
48}
49
50namespace {
51struct WalkAndApplyPatternsAction final
52 : tracing::ActionImpl {
54 using ActionImpl::ActionImpl;
55 static constexpr StringLiteral tag = "walk-and-apply-patterns";
57};
58
59#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
60
61
62
63
65 using RewriterBase::ForwardingListener::ForwardingListener;
66
67 void notifyOperationErased(Operation *op) override {
68 checkErasure(op);
69 ForwardingListener::notifyOperationErased(op);
70 }
71
72 void notifyBlockErased(Block *block) override {
73 checkErasure(block->getParentOp());
74 ForwardingListener::notifyBlockErased(block);
75 }
76
77 void checkErasure(Operation *op) const {
78 Operation *ancestorOp = op;
79 while (ancestorOp && ancestorOp != visitedOp)
80 ancestorOp = ancestorOp->getParentOp();
81
82 if (ancestorOp != visitedOp)
83 llvm::report_fatal_error(
84 "unsupported erasure in WalkPatternRewriter; "
85 "erasure is only supported for matched ops and their descendants");
86 }
87
88 Operation *visitedOp = nullptr;
89};
90#endif
91}
92
96#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
97 if (failed(verify(op)))
98 llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
99#endif
100
103#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
104 ErasedOpsListener erasedListener(listener);
106#else
108#endif
109
112
113
114
115
116 struct RegionReachableOpIterator {
117 RegionReachableOpIterator(Region *region) : region(region) {
118 regionIt = region->begin();
119 if (regionIt != region->end())
120 blockIt = regionIt->begin();
121 if (!llvm::hasSingleElement(*region))
123 }
124
125 void advance() {
126 assert(regionIt != region->end());
127 hasVisitedRegions = false;
128 if (blockIt == regionIt->end()) {
129 ++regionIt;
130 while (regionIt != region->end() &&
131 !reachableBlocks.contains(&*regionIt))
132 ++regionIt;
133 if (regionIt != region->end())
134 blockIt = regionIt->begin();
135 return;
136 }
137 ++blockIt;
138 if (blockIt != regionIt->end()) {
139 LDBG() << "Incrementing block iterator, next op: "
141 }
142 }
143
145
147
149
151
152 bool hasVisitedRegions = false;
153 };
154
155
157
158 LDBG() << "Starting walk-based pattern rewrite driver";
160 [&] {
161
162
164 assert(worklist.empty());
165 if (region.empty())
166 continue;
167
168
169 worklist.push_back({®ion});
170 while (!worklist.empty()) {
171 RegionReachableOpIterator &it = worklist.back();
172 if (it.regionIt == it.region->end()) {
173
174 worklist.pop_back();
175 continue;
176 }
177 if (it.blockIt == it.regionIt->end()) {
178
179 it.advance();
180 continue;
181 }
183
184
185 if (!it.hasVisitedRegions) {
186 it.hasVisitedRegions = true;
187 for (Region &nestedRegion : llvm::reverse(op->getRegions())) {
188 if (nestedRegion.empty())
189 continue;
190 worklist.push_back({&nestedRegion});
191 }
192 }
193
194
195
196 if (&it != &worklist.back())
197 continue;
198
199
200
201 it.advance();
202
203 LDBG() << "Visiting op: "
205#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
206 erasedListener.visitedOp = op;
207#endif
209 LDBG() << "\tOp matched and rewritten";
210 }
211 }
212 },
213 {op});
214
215#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
216 if (failed(verify(op)))
217 llvm::report_fatal_error(
218 "walk pattern rewriter result IR failed to verify");
219#endif
220}
221
222}
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Block represents an ordered list of Operations.
OpListType::iterator iterator
This class represents a frozen set of patterns that can be processed by a pattern applicator.
MLIRContext is the top-level object for a collection of MLIR operations.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Set of flags used to control the behavior of the various IR print methods (e.g. Operation::Print).
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Operation is the basic unit of execution within MLIR.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
SuccessorRange getSuccessors()
MLIRContext * getContext()
Return the context this operation is associated with.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a pattern can be applied or not, and hooks for if the pattern match was a success or failure.
void applyDefaultCostModel()
Apply the default cost model that solely uses the pattern's static benefit.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType::iterator iterator
Include the generated interface declarations.
static void findReachableBlocks(Region &region, DenseSet< Block * > &reachableBlocks)
Definition WalkPatternRewriteDriver.cpp:32
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
Definition WalkPatternRewriteDriver.cpp:93
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
A listener that forwards all notifications to another listener.