MLIR: lib/IR/AttrTypeSubElements.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

10#include

11

12using namespace mlir;

13

14

15

16

17

19 return walkImpl(attr, attrWalkFns, order);

20}

22 return walkImpl(type, typeWalkFns, order);

23}

24

25template <typename T, typename WalkFns>

26WalkResult AttrTypeWalker::walkImpl(T element, WalkFns &walkFns,

28

29 auto key = std::make_pair(element.getAsOpaquePointer(), (int)order);

33 return it->second;

34

35

37 if (walkSubElements(element, order).wasInterrupted())

39 }

40

41

42 for (auto &walkFn : llvm::reverse(walkFns)) {

43 WalkResult walkResult = walkFn(element);

48 }

49

50

52 if (walkSubElements(element, order).wasInterrupted())

54 }

56}

57

58template

61 auto walkFn = [&](auto element) {

62 if (element && result.wasInterrupted())

63 result = walkImpl(element, order);

64 };

65 interface.walkImmediateSubElements(walkFn, walkFn);

67}

68

69

70

71

72

73template

76 attrReplacementFns.emplace_back(std::move(fn));

77}

78

79template

82 typeReplacementFns.push_back(std::move(fn));

83}

84

85template

87 Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {

88

89

90 auto replaceIfDifferent = [&](auto element) {

91 auto replacement = static_cast<Concrete *>(this)->replace(element);

93 };

94

95

96 if (replaceAttrs) {

98 op->setAttrs(cast(newAttrs));

99 }

100

101

102 if (!replaceTypes && !replaceLocs)

103 return;

104

105

106 if (replaceLocs) {

108 op->setLoc(cast(newLoc));

109 }

110

111

112 if (replaceTypes) {

114 if (Type newType = replaceIfDifferent(result.getType()))

115 result.setType(newType);

116 }

117

118

120 for (Block &block : region) {

121 for (BlockArgument &arg : block.getArguments()) {

122 if (replaceLocs) {

123 if (Attribute newLoc = replaceIfDifferent(arg.getLoc()))

124 arg.setLoc(cast(newLoc));

125 }

126

127 if (replaceTypes) {

128 if (Type newType = replaceIfDifferent(arg.getType()))

129 arg.setType(newType);

130 }

131 }

132 }

133 }

134}

135

136template

138 Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {

140 replaceElementsIn(nestedOp, replaceAttrs, replaceLocs, replaceTypes);

141 });

142}

143

144template <typename T, typename Replacer>

147 FailureOr &changed) {

148

150 return;

152

153 if (!element) {

154 newElements.push_back(nullptr);

155 return;

156 }

157

158

159 if (T result = replacer.replace(element)) {

160 newElements.push_back(result);

161 if (result != element)

163 } else {

165 }

166}

167

168template <typename T, typename Replacer>

170

173 FailureOr changed = false;

174 interface.walkImmediateSubElements(

177 },

178 [&](Type element) {

180 });

182 return nullptr;

183

184

185 T result = interface;

187 result = interface.replaceImmediateSubElements(newAttrs, newTypes);

189}

190

191

192template <typename T, typename ReplaceFns, typename Replacer>

194 Replacer &replacer) {

197 for (auto &replaceFn : llvm::reverse(replaceFns)) {

198 if (std::optional<std::pair<T, WalkResult>> newRes = replaceFn(element)) {

199 std::tie(result, walkResult) = *newRes;

200 break;

201 }

202 }

203

204

206 return nullptr;

207 }

208

209

211

213 return nullptr;

214 }

215 }

216

218}

219

220template

223 *static_cast<Concrete *>(this));

224}

225

226template

229 *static_cast<Concrete *>(this));

230}

231

232

233

234

235

237

238template

239T AttrTypeReplacer::cachedReplaceImpl(T element) {

240 const void *opaqueElement = element.getAsOpaquePointer();

241 auto [it, inserted] = cache.try_emplace(opaqueElement, opaqueElement);

243 return T::getFromOpaquePointer(it->second);

244

246

247 cache[opaqueElement] = result.getAsOpaquePointer();

249}

250

252 return cachedReplaceImpl(attr);

253}

254

256

257

258

259

260

262

264 : cache([&](void *attr) { return breakCycleImpl(attr); }) {}

265

267 attrCycleBreakerFns.emplace_back(std::move(fn));

268}

269

271 typeCycleBreakerFns.emplace_back(std::move(fn));

272}

273

274template

275T CyclicAttrTypeReplacer::cachedReplaceImpl(T element) {

276 void *opaqueTaggedElement = AttrOrType(element).getOpaqueValue();

279 if (auto resultOpt = cacheEntry.get())

280 return T::getFromOpaquePointer(*resultOpt);

281

283

286}

287

289 return cachedReplaceImpl(attr);

290}

291

293 return cachedReplaceImpl(type);

294}

295

296std::optional<const void *>

297CyclicAttrTypeReplacer::breakCycleImpl(void *element) {

298 AttrOrType attrType = AttrOrType::getFromOpaqueValue(element);

299 if (auto attr = dyn_cast(attrType)) {

300 for (auto &cyclicReplaceFn : llvm::reverse(attrCycleBreakerFns)) {

301 if (std::optional newRes = cyclicReplaceFn(attr)) {

302 return newRes->getAsOpaquePointer();

303 }

304 }

305 } else {

306 auto type = dyn_cast(attrType);

307 for (auto &cyclicReplaceFn : llvm::reverse(typeCycleBreakerFns)) {

308 if (std::optional newRes = cyclicReplaceFn(type)) {

309 return newRes->getAsOpaquePointer();

310 }

311 }

312 }

313 return std::nullopt;

314}

315

316

317

318

319

321 if (element)

322 walkAttrsFn(element);

323}

324

326 if (element)

327 walkTypesFn(element);

328}

static void updateSubElementImpl(T element, Replacer &replacer, SmallVectorImpl< T > &newElements, FailureOr< bool > &changed)

Definition AttrTypeSubElements.cpp:145

static T replaceElementImpl(T element, ReplaceFns &replaceFns, Replacer &replacer)

Shared implementation of replacing a given attribute or type element.

Definition AttrTypeSubElements.cpp:193

static T replaceSubElements(T interface, Replacer &replacer)

Definition AttrTypeSubElements.cpp:169

*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method

*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`

Attribute replace(Attribute attr)

Definition AttrTypeSubElements.cpp:251

Attributes are known-constant values of operations.

This class represents an argument of a Block.

Block represents an ordered list of Operations.

void addCycleBreaker(CycleBreakerFn< Attribute > fn)

Register a cycle-breaking function.

Definition AttrTypeSubElements.cpp:266

Attribute replace(Attribute attr)

Definition AttrTypeSubElements.cpp:288

std::function< std::optional< T >(T)> CycleBreakerFn

A cycle-breaking function.

CyclicAttrTypeReplacer()

Definition AttrTypeSubElements.cpp:263

CacheEntry lookupOrInit(InT element)

Lookup the cache for a pre-calculated replacement for element.

This is a value defined by a result of an operation.

Operation is the basic unit of execution within MLIR.

void setLoc(Location loc)

Set the source location the operation was defined or derived from.

DictionaryAttr getAttrDictionary()

Return all of the attributes on this operation as a DictionaryAttr.

void setAttrs(DictionaryAttr newAttrs)

Set the attributes from a dictionary on this operation.

Location getLoc()

The source location the operation was defined or derived from.

MutableArrayRef< Region > getRegions()

Returns the regions held by this operation.

std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)

Walk the operation by calling the callback for each nested operation (including this one),...

result_range getResults()

This class contains a list of basic blocks and a link to the parent operation it is attached to.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

A utility result that is used to signal how to proceed with an ongoing walk:

bool wasSkipped() const

Returns true if the walk was skipped.

static WalkResult advance()

bool wasInterrupted() const

Returns true if the walk was interrupted.

static WalkResult interrupt()

This class provides a base utility for replacing attributes/types, and their sub elements.

void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)

Replace the elements within the given operation, and all nested operations.

Definition AttrTypeSubElements.cpp:137

Attribute replaceBase(Attribute attr)

Invokes the registered replacement functions from most recently registered to least recently register...

Definition AttrTypeSubElements.cpp:221

std::function< ReplaceFnResult< T >(T)> ReplaceFn

void replaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)

Replace the elements within the given operation.

Definition AttrTypeSubElements.cpp:86

void addReplacement(ReplaceFn< Attribute > fn)

Register a replacement function for mapping a given attribute or type.

Include the generated interface declarations.

const FrozenRewritePatternSet GreedyRewriteConfig bool * changed

WalkOrder

Traversal order for region, block and operation walk utilities.

A possibly unresolved cache entry.

void resolve(OutT result)

Resolve an unresolved cache entry by providing the result to be stored in the cache.

const std::optional< OutT > & get() const

Get the resolved result if one exists.