MLIR: lib/IR/AttrTypeSubElements.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
10#include
11
12using namespace mlir;
13
14
15
16
17
19 return walkImpl(attr, attrWalkFns, order);
20}
22 return walkImpl(type, typeWalkFns, order);
23}
24
25template <typename T, typename WalkFns>
26WalkResult AttrTypeWalker::walkImpl(T element, WalkFns &walkFns,
28
29 auto key = std::make_pair(element.getAsOpaquePointer(), (int)order);
33 return it->second;
34
35
37 if (walkSubElements(element, order).wasInterrupted())
39 }
40
41
42 for (auto &walkFn : llvm::reverse(walkFns)) {
43 WalkResult walkResult = walkFn(element);
48 }
49
50
52 if (walkSubElements(element, order).wasInterrupted())
54 }
56}
57
58template
61 auto walkFn = [&](auto element) {
62 if (element && .wasInterrupted())
63 result = walkImpl(element, order);
64 };
65 interface.walkImmediateSubElements(walkFn, walkFn);
67}
68
69
70
71
72
73template
76 attrReplacementFns.emplace_back(std::move(fn));
77}
78
79template
82 typeReplacementFns.push_back(std::move(fn));
83}
84
85template
87 Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {
88
89
90 auto replaceIfDifferent = [&](auto element) {
91 auto replacement = static_cast<Concrete *>(this)->replace(element);
93 };
94
95
96 if (replaceAttrs) {
98 op->setAttrs(cast(newAttrs));
99 }
100
101
102 if (!replaceTypes && !replaceLocs)
103 return;
104
105
106 if (replaceLocs) {
108 op->setLoc(cast(newLoc));
109 }
110
111
112 if (replaceTypes) {
114 if (Type newType = replaceIfDifferent(result.getType()))
115 result.setType(newType);
116 }
117
118
120 for (Block &block : region) {
121 for (BlockArgument &arg : block.getArguments()) {
122 if (replaceLocs) {
123 if (Attribute newLoc = replaceIfDifferent(arg.getLoc()))
124 arg.setLoc(cast(newLoc));
125 }
126
127 if (replaceTypes) {
128 if (Type newType = replaceIfDifferent(arg.getType()))
129 arg.setType(newType);
130 }
131 }
132 }
133 }
134}
135
136template
138 Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {
140 replaceElementsIn(nestedOp, replaceAttrs, replaceLocs, replaceTypes);
141 });
142}
143
144template <typename T, typename Replacer>
147 FailureOr &changed) {
148
150 return;
152
153 if (!element) {
154 newElements.push_back(nullptr);
155 return;
156 }
157
158
159 if (T result = replacer.replace(element)) {
160 newElements.push_back(result);
161 if (result != element)
163 } else {
165 }
166}
167
168template <typename T, typename Replacer>
170
173 FailureOr changed = false;
174 interface.walkImmediateSubElements(
177 },
178 [&](Type element) {
180 });
182 return nullptr;
183
184
185 T result = interface;
187 result = interface.replaceImmediateSubElements(newAttrs, newTypes);
189}
190
191
192template <typename T, typename ReplaceFns, typename Replacer>
194 Replacer &replacer) {
197 for (auto &replaceFn : llvm::reverse(replaceFns)) {
198 if (std::optional<std::pair<T, WalkResult>> newRes = replaceFn(element)) {
199 std::tie(result, walkResult) = *newRes;
200 break;
201 }
202 }
203
204
206 return nullptr;
207 }
208
209
211
213 return nullptr;
214 }
215 }
216
218}
219
220template
223 *static_cast<Concrete *>(this));
224}
225
226template
229 *static_cast<Concrete *>(this));
230}
231
232
233
234
235
237
238template
239T AttrTypeReplacer::cachedReplaceImpl(T element) {
240 const void *opaqueElement = element.getAsOpaquePointer();
241 auto [it, inserted] = cache.try_emplace(opaqueElement, opaqueElement);
243 return T::getFromOpaquePointer(it->second);
244
246
247 cache[opaqueElement] = result.getAsOpaquePointer();
249}
250
252 return cachedReplaceImpl(attr);
253}
254
256
257
258
259
260
262
264 : cache([&](void *attr) { return breakCycleImpl(attr); }) {}
265
267 attrCycleBreakerFns.emplace_back(std::move(fn));
268}
269
271 typeCycleBreakerFns.emplace_back(std::move(fn));
272}
273
274template
275T CyclicAttrTypeReplacer::cachedReplaceImpl(T element) {
276 void *opaqueTaggedElement = AttrOrType(element).getOpaqueValue();
279 if (auto resultOpt = cacheEntry.get())
280 return T::getFromOpaquePointer(*resultOpt);
281
283
286}
287
289 return cachedReplaceImpl(attr);
290}
291
293 return cachedReplaceImpl(type);
294}
295
296std::optional<const void *>
297CyclicAttrTypeReplacer::breakCycleImpl(void *element) {
298 AttrOrType attrType = AttrOrType::getFromOpaqueValue(element);
299 if (auto attr = dyn_cast(attrType)) {
300 for (auto &cyclicReplaceFn : llvm::reverse(attrCycleBreakerFns)) {
301 if (std::optional newRes = cyclicReplaceFn(attr)) {
302 return newRes->getAsOpaquePointer();
303 }
304 }
305 } else {
306 auto type = dyn_cast(attrType);
307 for (auto &cyclicReplaceFn : llvm::reverse(typeCycleBreakerFns)) {
308 if (std::optional newRes = cyclicReplaceFn(type)) {
309 return newRes->getAsOpaquePointer();
310 }
311 }
312 }
313 return std::nullopt;
314}
315
316
317
318
319
321 if (element)
322 walkAttrsFn(element);
323}
324
326 if (element)
327 walkTypesFn(element);
328}
static void updateSubElementImpl(T element, Replacer &replacer, SmallVectorImpl< T > &newElements, FailureOr< bool > &changed)
Definition AttrTypeSubElements.cpp:145
static T replaceElementImpl(T element, ReplaceFns &replaceFns, Replacer &replacer)
Shared implementation of replacing a given attribute or type element.
Definition AttrTypeSubElements.cpp:193
static T replaceSubElements(T interface, Replacer &replacer)
Definition AttrTypeSubElements.cpp:169
...if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method...
...if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method, the output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`...
Attribute replace(Attribute attr)
Definition AttrTypeSubElements.cpp:251
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
void addCycleBreaker(CycleBreakerFn< Attribute > fn)
Register a cycle-breaking function.
Definition AttrTypeSubElements.cpp:266
Attribute replace(Attribute attr)
Definition AttrTypeSubElements.cpp:288
std::function< std::optional< T >(T)> CycleBreakerFn
A cycle-breaking function.
CyclicAttrTypeReplacer()
Definition AttrTypeSubElements.cpp:263
CacheEntry lookupOrInit(InT element)
Lookup the cache for a pre-calculated replacement for element.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
A utility result that is used to signal how to proceed with an ongoing walk:
bool wasSkipped() const
Returns true if the walk was skipped.
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
This class provides a base utility for replacing attributes/types, and their sub elements.
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
Definition AttrTypeSubElements.cpp:137
Attribute replaceBase(Attribute attr)
Invokes the registered replacement functions from most recently registered to least recently registered, stopping at the first one that returns a result.
Definition AttrTypeSubElements.cpp:221
std::function< ReplaceFnResult< T >(T)> ReplaceFn
void replaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation.
Definition AttrTypeSubElements.cpp:86
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
WalkOrder
Traversal order for region, block and operation walk utilities.
A possibly unresolved cache entry.
void resolve(OutT result)
Resolve an unresolved cache entry by providing the result to be stored in the cache.
const std::optional< OutT > & get() const
Get the resolved result if one exists.