MLIR: include/mlir/Support/CyclicReplacerCache.h Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15#ifndef MLIR_SUPPORT_CYCLICREPLACERCACHE_H

16#define MLIR_SUPPORT_CYCLICREPLACERCACHE_H

17

18#include "llvm/ADT/DenseMap.h"

19#include "llvm/ADT/DenseSet.h"

20#include "llvm/ADT/SmallVector.h"

21#include

22#include

23#include

24

25namespace mlir {

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51template <typename InT, typename OutT>

53public:

54

55

56

57 using CycleBreakerFn = std::function<std::optional(InT)>;

58

61 : cycleBreaker(std::move(cycleBreaker)) {}

62

63

64

65 struct CacheEntry {

66 public:

67 ~CacheEntry() { assert(result && "unresovled cache entry"); }

68

69

70

71

72

74

75

76 ReplacementFrame &currFrame = cache.replacementStack.back();

77 size_t currFrameIndex = cache.replacementStack.size() - 1;

78 return currFrame.dependentFrames.count(currFrameIndex);

79 }

80

81

82

84 assert(!this->result && "cache entry already resolved");

85 cache.finalizeReplacement(element, result);

86 this->result = std::move(result);

87 }

88

89

90 const std::optional &get() const { return result; }

91

92 private:

94 CacheEntry() = delete;

96 std::optional result = std::nullopt)

97 : cache(cache), element(std::move(element)), result(result) {}

98

100 InT element;

101 std::optional result;

102 };

103

104

105

106

107

108

109

110

111

112

113

115

116private:

117

118 void finalizeReplacement(InT element, OutT result);

119

122

123 struct DependentReplacement {

125

126

127 size_t highestDependentFrame;

128 };

130

131 struct ReplacementFrame {

132

133

134

136

137

138 std::set<size_t, std::greater<size_t>> dependentFrames;

139 };

140

141

143

145

146

147

148 bool resolvingCycle = false;

149};

150

151template <typename InT, typename OutT>

154 assert(!resolvingCycle &&

155 "illegal recursive invocation while breaking cycle");

156

157 if (auto it = standaloneCache.find(element); it != standaloneCache.end())

158 return CacheEntry(*this, element, it->second);

159

160 if (auto it = dependentCache.find(element); it != dependentCache.end()) {

161

162

163 ReplacementFrame &currFrame = replacementStack.back();

164 currFrame.dependentFrames.insert(it->second.highestDependentFrame);

165 return CacheEntry(*this, element, it->second.replacement);

166 }

167

168 auto [it, inserted] = cyclicElementFrame.try_emplace(element);

170

171 resolvingCycle = true;

172 std::optional result = cycleBreaker(element);

173 resolvingCycle = false;

175

176 size_t dependentFrame = it->second.back();

177 dependentCache[element] = {*result, dependentFrame};

178 ReplacementFrame &currFrame = replacementStack.back();

179

180

181 currFrame.dependentFrames.insert(dependentFrame);

182

184 }

185

186

187

188

189

190

191 assert(it->second.size() <= 2 && "illegal 3rd repeat of input");

192 }

193

194

195

196 it->second.push_back(replacementStack.size());

197 replacementStack.emplace_back();

198

200}

201

202template <typename InT, typename OutT>

203void CyclicReplacerCache<InT, OutT>::finalizeReplacement(InT element,

205 ReplacementFrame &currFrame = replacementStack.back();

206

207

208 currFrame.dependentFrames.erase(replacementStack.size() - 1);

209

210 auto prevLayerIter = ++replacementStack.rbegin();

211 if (prevLayerIter == replacementStack.rend()) {

212

213 assert(currFrame.dependentFrames.empty() &&

214 "internal error: top-level dependent replacement");

215

216 standaloneCache[element] = result;

217 } else if (currFrame.dependentFrames.empty()) {

218

219 standaloneCache[element] = result;

220 } else {

221

222 size_t highestDependentFrame = *currFrame.dependentFrames.begin();

223 dependentCache[element] = {result, highestDependentFrame};

224

225

226 prevLayerIter->dependentFrames.insert(currFrame.dependentFrames.begin(),

227 currFrame.dependentFrames.end());

228

229

230

231 replacementStack[highestDependentFrame].dependingReplacements.insert(

232 element);

233 }

234

235

236 for (InT key : currFrame.dependingReplacements)

237 dependentCache.erase(key);

238

239 replacementStack.pop_back();

240 auto it = cyclicElementFrame.find(element);

241 it->second.pop_back();

242 if (it->second.empty())

243 cyclicElementFrame.erase(it);

244}

245

246

247

248

249

250

251

252

253

254template <typename InT, typename OutT>

256public:

260

263 : replacer(std::move(replacer)), cache(std::move(cycleBreaker)) {}

264

266 auto cacheEntry = cache.lookupOrInit(element);

267 if (std::optional result = cacheEntry.get())

269

270 OutT result = replacer(element);

271 cacheEntry.resolve(result);

273 }

274

275private:

278};

279

280}

281

282#endif

Returns failure if copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies, respectively, should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method, …

Returns failure if copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be inserted. … The output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`, …

CachedCyclicReplacer()=delete

OutT operator()(InT element)

Definition CyclicReplacerCache.h:265

CachedCyclicReplacer(ReplacerFn replacer, CycleBreakerFn cycleBreaker)

Definition CyclicReplacerCache.h:262

typename CyclicReplacerCache< InT, OutT >::CycleBreakerFn CycleBreakerFn

Definition CyclicReplacerCache.h:258

std::function< OutT(InT)> ReplacerFn

Definition CyclicReplacerCache.h:257

A cache for replacer-like functions that map values between two domains.

Definition CyclicReplacerCache.h:52

CacheEntry lookupOrInit(InT element)

Look up the cache for a precalculated replacement for `element`.

Definition CyclicReplacerCache.h:153

std::function< std::optional< OutT >(InT)> CycleBreakerFn

User-provided replacement and cycle-breaking functions.

Definition CyclicReplacerCache.h:57

CyclicReplacerCache()=delete

CyclicReplacerCache(CycleBreakerFn cycleBreaker)

Definition CyclicReplacerCache.h:60

Include the generated interface declarations.

A possibly unresolved cache entry.

Definition CyclicReplacerCache.h:65

void resolve(OutT result)

Resolve an unresolved cache entry by providing the result to be stored in the cache.

Definition CyclicReplacerCache.h:83

friend class CyclicReplacerCache

Definition CyclicReplacerCache.h:93

bool wasRepeated() const

Check whether this node was repeated during recursive replacements.

Definition CyclicReplacerCache.h:73

~CacheEntry()

Definition CyclicReplacerCache.h:67

const std::optional< OutT > & get() const

Get the resolved result if one exists.

Definition CyclicReplacerCache.h:90