MLIR: lib/Analysis/TopologicalSortUtils.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

13

14 #include "llvm/ADT/PostOrderIterator.h"

15 #include "llvm/ADT/SetVector.h"

16

17 using namespace mlir;

18

19

22

23

24 const auto isReady = [&](Value value) {

25

26 if (isOperandReady && isOperandReady(value, op))

27 return true;

28 Operation *parent = value.getDefiningOp();

29

30 if (!parent)

31 return true;

32

33

34 do {

35

36 if (parent == op)

37 return true;

38 if (unscheduledOps.contains(parent))

39 return false;

40 } while ((parent = parent->getParentOp()));

41

42 return true;

43 };

44

45

46

48 return llvm::all_of(nestedOp->getOperands(),

49 [&](Value operand) { return isReady(operand); })

52 });

54 }

55

59 if (ops.empty())

60 return true;

61

62

64

66 unscheduledOps.insert(&op);

67

70

71 bool allOpsScheduled = true;

72 while (!unscheduledOps.empty()) {

73 bool scheduledAtLeastOnce = false;

74

75

76

77

79 llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {

80 if (!isOpReady(&op, unscheduledOps, isOperandReady)) // Leave ops whose operands are not all available yet for a later pass.

81 continue;

82

83

84 unscheduledOps.erase(&op);

85 op.moveBefore(block, nextScheduledOp);

86 scheduledAtLeastOnce = true;

87

88 if (&op == &*nextScheduledOp)

89 ++nextScheduledOp;

90 }

91

92 if (!scheduledAtLeastOnce) {

93 allOpsScheduled = false;

94 unscheduledOps.erase(&*nextScheduledOp);

95 ++nextScheduledOp;

96 }

97 }

98

99 return allOpsScheduled;

100 }

101

104 if (block->empty())

105 return true;

108 isOperandReady);

110 }

111

115 if (ops.empty())

116 return true;

117

118

119

121

122 unsigned nextScheduledOp = 0;

123

124 bool allOpsScheduled = true;

125 while (!unscheduledOps.empty()) {

126 bool scheduledAtLeastOnce = false;

127

128

129

130

131 for (unsigned i = nextScheduledOp; i < ops.size(); ++i) {

132 if (!isOpReady(ops[i], unscheduledOps, isOperandReady)) // Only ops whose operands are all available may be swapped into the scheduled prefix.

133 continue;

134

135

136 unscheduledOps.erase(ops[i]);

137 std::swap(ops[i], ops[nextScheduledOp]);

138 scheduledAtLeastOnce = true;

139 ++nextScheduledOp;

140 }

141

142

143 if (!scheduledAtLeastOnce) {

144 allOpsScheduled = false;

145 unscheduledOps.erase(ops[nextScheduledOp++]);

146 }

147 }

148

149 return allOpsScheduled;

150 }

151

153

154

156 for (Block &b : region) {

157 if (blocks.count(&b) == 0) {

158 llvm::ReversePostOrderTraversal<Block *> traversal(&b);

159 blocks.insert_range(traversal);

160 }

161 }

162 assert(blocks.size() == region.getBlocks().size() &&

163 "some blocks are not sorted");

164

165 return blocks;

166 }

167

168 namespace {

169 class TopoSortHelper {

170 public:

172 : toSort(toSort) {}

173

174

175

176

178 if (toSort.size() <= 1) {

179

180 return toSort;

181 }

182

183

184

185

186 Region *rootRegion = findCommonAncestorRegion();

187 assert(rootRegion && "expected all ops to have a common ancestor");

188

189

190

192 assert(result.size() == toSort.size() &&

193 "expected all operations to be present in the result");

194 return result;

195 }

196

197 private:

198

199 Region *findCommonAncestorRegion() {

200

202 size_t expectedCount = toSort.size();

203

204

205

206 Region *res = nullptr;

209

210 ancestorBlocks.insert(op->getBlock());

211 while (current) {

212

213 if (++regionCounts[current] == expectedCount) {

214 res = current;

215 break;

216 }

219 }

220 }

221 auto firstRange = llvm::make_first_range(regionCounts);

222 ancestorRegions.insert_range(firstRange);

223 return res;

224 }

225

226

227

230

232

234 stack.push_back(&rootRegion);

235

236

237 while (!stack.empty()) {

238 StackT current = stack.pop_back_val();

239 if (auto *region = dyn_cast<Region *>(current)) {

240

242 for (Block *block : llvm::reverse(sortedBlocks)) {

243

244

245 if (ancestorBlocks.contains(block))

246 stack.push_back(block);

247 }

248 continue;

249 }

250

251 if (auto *block = dyn_cast<Block *>(current)) {

252

253 for (Operation &op : llvm::reverse(*block))

254 stack.push_back(&op);

255 continue;

256 }

257

258 auto *op = cast<Operation *>(current);

259 if (toSort.contains(op))

260 result.insert(op);

261

262

264 if (ancestorRegions.contains(&subRegion))

265 stack.push_back(&subRegion);

266 }

267 return result;

268 }

269

270

272

274

276 };

277 }

278

281 return TopoSortHelper(toSort).sort();

282 }

static bool isOpReady(Operation *op, DenseSet< Operation * > &unscheduledOps, function_ref< bool(Value, Operation *)> isOperandReady)

Return true if the given operation is ready to be scheduled.

Block represents an ordered list of Operations.

OpListType::iterator iterator

iterator_range< iterator > without_terminator()

Return an iterator range over the operations within this block, excluding the terminator operation at the end.

This class provides the API for ops that are known to be terminators.

Operation is the basic unit of execution within MLIR.

bool hasTrait()

Returns true if the operation was registered with a particular trait, e.g. `hasTrait<OperandsAreSignlessIntegerLike>()`.

std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)

Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.

Operation * getParentOp()

Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.

Block * getBlock()

Returns the operation block that contains this operation.

MutableArrayRef< Region > getRegions()

Returns the regions held by this operation.

operand_range getOperands()

Returns an iterator on the underlying Value's.

void moveBefore(Operation *existingOp)

Unlink this operation from its current block and insert it right before `existingOp`, which may be in the same or another block in the same function.

Region * getParentRegion()

Returns the region to which the operation belongs.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

Region * getParentRegion()

Return the region containing this region, or nullptr if the region is attached to a top-level operation.

Operation * getParentOp()

Return the parent operation this region is attached to.

BlockListType & getBlocks()

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

A utility result that is used to signal how to proceed with an ongoing walk: interrupt it, advance normally, or skip the children of the current entity.

static WalkResult advance()

bool wasInterrupted() const

Returns true if the walk was interrupted.

static WalkResult interrupt()

Include the generated interface declarations.

SetVector< Block * > getBlocksSortedByDominance(Region &region)

Gets a list of blocks that is sorted according to dominance.

bool computeTopologicalSorting(MutableArrayRef< Operation * > ops, function_ref< bool(Value, Operation *)> isOperandReady=nullptr)

Compute a topological ordering of the given ops.

bool sortTopologically(Block *block, iterator_range< Block::iterator > ops, function_ref< bool(Value, Operation *)> isOperandReady=nullptr)

Given a block, sort a range operations in said block in topological order.

SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)

Sorts all operations in toSort topologically while also considering region semantics.