MLIR: include/mlir/Support/CyclicReplacerCache.h Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15#ifndef MLIR_SUPPORT_CYCLICREPLACERCACHE_H
16#define MLIR_SUPPORT_CYCLICREPLACERCACHE_H
17
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <functional>
#include <optional>
#include <set>
24
25namespace mlir {
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51template <typename InT, typename OutT>
53public:
54
55
56
57 using CycleBreakerFn = std::function<std::optional(InT)>;
58
61 : cycleBreaker(std::move(cycleBreaker)) {}
62
63
64
65 struct CacheEntry {
66 public:
67 ~CacheEntry() { assert(result && "unresovled cache entry"); }
68
69
70
71
72
74
75
76 ReplacementFrame &currFrame = cache.replacementStack.back();
77 size_t currFrameIndex = cache.replacementStack.size() - 1;
78 return currFrame.dependentFrames.count(currFrameIndex);
79 }
80
81
82
84 assert(!this->result && "cache entry already resolved");
85 cache.finalizeReplacement(element, result);
86 this->result = std::move(result);
87 }
88
89
90 const std::optional &get() const { return result; }
91
92 private:
94 CacheEntry() = delete;
96 std::optional result = std::nullopt)
97 : cache(cache), element(std::move(element)), result(result) {}
98
100 InT element;
101 std::optional result;
102 };
103
104
105
106
107
108
109
110
111
112
113
115
116private:
117
118 void finalizeReplacement(InT element, OutT result);
119
122
123 struct DependentReplacement {
125
126
127 size_t highestDependentFrame;
128 };
130
131 struct ReplacementFrame {
132
133
134
136
137
138 std::set<size_t, std::greater<size_t>> dependentFrames;
139 };
140
141
143
145
146
147
148 bool resolvingCycle = false;
149};
150
151template <typename InT, typename OutT>
154 assert(!resolvingCycle &&
155 "illegal recursive invocation while breaking cycle");
156
157 if (auto it = standaloneCache.find(element); it != standaloneCache.end())
158 return CacheEntry(*this, element, it->second);
159
160 if (auto it = dependentCache.find(element); it != dependentCache.end()) {
161
162
163 ReplacementFrame &currFrame = replacementStack.back();
164 currFrame.dependentFrames.insert(it->second.highestDependentFrame);
165 return CacheEntry(*this, element, it->second.replacement);
166 }
167
168 auto [it, inserted] = cyclicElementFrame.try_emplace(element);
170
171 resolvingCycle = true;
172 std::optional result = cycleBreaker(element);
173 resolvingCycle = false;
175
176 size_t dependentFrame = it->second.back();
177 dependentCache[element] = {*result, dependentFrame};
178 ReplacementFrame &currFrame = replacementStack.back();
179
180
181 currFrame.dependentFrames.insert(dependentFrame);
182
184 }
185
186
187
188
189
190
191 assert(it->second.size() <= 2 && "illegal 3rd repeat of input");
192 }
193
194
195
196 it->second.push_back(replacementStack.size());
197 replacementStack.emplace_back();
198
200}
201
202template <typename InT, typename OutT>
203void CyclicReplacerCache<InT, OutT>::finalizeReplacement(InT element,
205 ReplacementFrame &currFrame = replacementStack.back();
206
207
208 currFrame.dependentFrames.erase(replacementStack.size() - 1);
209
210 auto prevLayerIter = ++replacementStack.rbegin();
211 if (prevLayerIter == replacementStack.rend()) {
212
213 assert(currFrame.dependentFrames.empty() &&
214 "internal error: top-level dependent replacement");
215
216 standaloneCache[element] = result;
217 } else if (currFrame.dependentFrames.empty()) {
218
219 standaloneCache[element] = result;
220 } else {
221
222 size_t highestDependentFrame = *currFrame.dependentFrames.begin();
223 dependentCache[element] = {result, highestDependentFrame};
224
225
226 prevLayerIter->dependentFrames.insert(currFrame.dependentFrames.begin(),
227 currFrame.dependentFrames.end());
228
229
230
231 replacementStack[highestDependentFrame].dependingReplacements.insert(
232 element);
233 }
234
235
236 for (InT key : currFrame.dependingReplacements)
237 dependentCache.erase(key);
238
239 replacementStack.pop_back();
240 auto it = cyclicElementFrame.find(element);
241 it->second.pop_back();
242 if (it->second.empty())
243 cyclicElementFrame.erase(it);
244}
245
246
247
248
249
250
251
252
253
254template <typename InT, typename OutT>
256public:
260
263 : replacer(std::move(replacer)), cache(std::move(cycleBreaker)) {}
264
266 auto cacheEntry = cache.lookupOrInit(element);
267 if (std::optional result = cacheEntry.get())
269
270 OutT result = replacer(element);
271 cacheEntry.resolve(result);
273 }
274
275private:
278};
279
280}
281
282#endif
…if copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method, the output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`, …
CachedCyclicReplacer()=delete
OutT operator()(InT element)
Definition CyclicReplacerCache.h:265
CachedCyclicReplacer(ReplacerFn replacer, CycleBreakerFn cycleBreaker)
Definition CyclicReplacerCache.h:262
typename CyclicReplacerCache< InT, OutT >::CycleBreakerFn CycleBreakerFn
Definition CyclicReplacerCache.h:258
std::function< OutT(InT)> ReplacerFn
Definition CyclicReplacerCache.h:257
A cache for replacer-like functions that map values between two domains.
Definition CyclicReplacerCache.h:52
CacheEntry lookupOrInit(InT element)
Look up the cache for a pre-calculated replacement for element.
Definition CyclicReplacerCache.h:153
std::function< std::optional< OutT >(InT)> CycleBreakerFn
User-provided replacement function & cycle-breaking functions.
Definition CyclicReplacerCache.h:57
CyclicReplacerCache()=delete
CyclicReplacerCache(CycleBreakerFn cycleBreaker)
Definition CyclicReplacerCache.h:60
Include the generated interface declarations.
A possibly unresolved cache entry.
Definition CyclicReplacerCache.h:65
void resolve(OutT result)
Resolve an unresolved cache entry by providing the result to be stored in the cache.
Definition CyclicReplacerCache.h:83
friend class CyclicReplacerCache
Definition CyclicReplacerCache.h:93
bool wasRepeated() const
Check whether this node was repeated during recursive replacements.
Definition CyclicReplacerCache.h:73
~CacheEntry()
Definition CyclicReplacerCache.h:67
const std::optional< OutT > & get() const
Get the resolved result if one exists.
Definition CyclicReplacerCache.h:90