MLIR: lib/Transforms/Utils/WalkPatternRewriteDriver.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

14

22#include "llvm/ADT/STLExtras.h"

23#include "llvm/Support/DebugLog.h"

24#include "llvm/Support/ErrorHandling.h"

25

26#define DEBUG_TYPE "walk-rewriter"

27

28namespace mlir {

29

30

31

35 reachableBlocks.insert(entryBlock);

36

38 while (!worklist.empty()) {

39 Block *block = worklist.pop_back_val();

42 if (reachableBlocks.contains(successor))

43 continue;

44 worklist.push_back(successor);

45 reachableBlocks.insert(successor);

46 }

47 }

48}

49

50namespace {

51struct WalkAndApplyPatternsAction final

52 : tracing::ActionImpl {

54 using ActionImpl::ActionImpl;

55 static constexpr StringLiteral tag = "walk-and-apply-patterns";

57};

58

59#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

60

61

62

63

65 using RewriterBase::ForwardingListener::ForwardingListener;

66

67 void notifyOperationErased(Operation *op) override {

68 checkErasure(op);

69 ForwardingListener::notifyOperationErased(op);

70 }

71

72 void notifyBlockErased(Block *block) override {

73 checkErasure(block->getParentOp());

74 ForwardingListener::notifyBlockErased(block);

75 }

76

77 void checkErasure(Operation *op) const {

78 Operation *ancestorOp = op;

79 while (ancestorOp && ancestorOp != visitedOp)

80 ancestorOp = ancestorOp->getParentOp();

81

82 if (ancestorOp != visitedOp)

83 llvm::report_fatal_error(

84 "unsupported erasure in WalkPatternRewriter; "

85 "erasure is only supported for matched ops and their descendants");

86 }

87

88 Operation *visitedOp = nullptr;

89};

90#endif

91}

92

96#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

97 if (failed(verify(op)))

98 llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");

99#endif

100

103#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

104 ErasedOpsListener erasedListener(listener);

106#else

108#endif

109

112

113

114

115

116 struct RegionReachableOpIterator {

117 RegionReachableOpIterator(Region *region) : region(region) {

118 regionIt = region->begin();

119 if (regionIt != region->end())

120 blockIt = regionIt->begin();

121 if (!llvm::hasSingleElement(*region))

123 }

124

125 void advance() {

126 assert(regionIt != region->end());

127 hasVisitedRegions = false;

128 if (blockIt == regionIt->end()) {

129 ++regionIt;

130 while (regionIt != region->end() &&

131 !reachableBlocks.contains(&*regionIt))

132 ++regionIt;

133 if (regionIt != region->end())

134 blockIt = regionIt->begin();

135 return;

136 }

137 ++blockIt;

138 if (blockIt != regionIt->end()) {

139 LDBG() << "Incrementing block iterator, next op: "

141 }

142 }

143

145

147

149

151

152 bool hasVisitedRegions = false;

153 };

154

155

157

158 LDBG() << "Starting walk-based pattern rewrite driver";

160 [&] {

161

162

164 assert(worklist.empty());

165 if (region.empty())

166 continue;

167

168

169 worklist.push_back({&region});

170 while (!worklist.empty()) {

171 RegionReachableOpIterator &it = worklist.back();

172 if (it.regionIt == it.region->end()) {

173

174 worklist.pop_back();

175 continue;

176 }

177 if (it.blockIt == it.regionIt->end()) {

178

179 it.advance();

180 continue;

181 }

183

184

185 if (!it.hasVisitedRegions) {

186 it.hasVisitedRegions = true;

187 for (Region &nestedRegion : llvm::reverse(op->getRegions())) {

188 if (nestedRegion.empty())

189 continue;

190 worklist.push_back({&nestedRegion});

191 }

192 }

193

194

195

196 if (&it != &worklist.back())

197 continue;

198

199

200

201 it.advance();

202

203 LDBG() << "Visiting op: "

205#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

206 erasedListener.visitedOp = op;

207#endif

209 LDBG() << "\tOp matched and rewritten";

210 }

211 }

212 },

213 {op});

214

215#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

216 if (failed(verify(op)))

217 llvm::report_fatal_error(

218 "walk pattern rewriter result IR failed to verify");

219#endif

220}

221

222}

static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)

#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)

Block represents an ordered list of Operations.

OpListType::iterator iterator

This class represents a frozen set of patterns that can be processed by a pattern applicator.

MLIRContext is the top-level object for a collection of MLIR operations.

void executeAction(function_ref< void()> actionFn, const tracing::Action &action)

Dispatch the provided action to the handler if any, or just execute it.

void setListener(Listener *newListener)

Sets the listener of this builder to the one provided.

Set of flags used to control the behavior of the various IR print methods (e.g. Operation::Print).

A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream modifier" to customize printing an operation with a stream using the operator<< overload.

Operation is the basic unit of execution within MLIR.

MutableArrayRef< Region > getRegions()

Returns the regions held by this operation.

SuccessorRange getSuccessors()

MLIRContext * getContext()

Return the context this operation is associated with.

This class manages the application of a group of rewrite patterns, with a user-provided cost model.

LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})

Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a pattern can be applied or not, and hooks for if the pattern match was a success or failure.

void applyDefaultCostModel()

Apply the default cost model that solely uses the pattern's static benefit.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

BlockListType::iterator iterator

Include the generated interface declarations.

static void findReachableBlocks(Region &region, DenseSet< Block * > &reachableBlocks)

Definition WalkPatternRewriteDriver.cpp:32

llvm::DenseSet< ValueT, ValueInfoT > DenseSet

const FrozenRewritePatternSet & patterns

void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)

A fast walk-based pattern rewrite driver.

Definition WalkPatternRewriteDriver.cpp:93

LogicalResult verify(Operation *op, bool verifyRecursively=true)

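Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations. If `verifyRecursively` is false, nested operations are assumed to have already been verified and their verifiers are not invoked explicitly.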

A listener that forwards all notifications to another listener.