MLIR: lib/Analysis/TopologicalSortUtils.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
13
14 #include "llvm/ADT/PostOrderIterator.h"
15 #include "llvm/ADT/SetVector.h"
16
17 using namespace mlir;
18
19
22
23
24 const auto isReady = [&](Value value) {
25
26 if (isOperandReady && isOperandReady(value, op))
27 return true;
28 Operation *parent = value.getDefiningOp();
29
30 if (!parent)
31 return true;
32
33
34 do {
35
36 if (parent == op)
37 return true;
38 if (unscheduledOps.contains(parent))
39 return false;
40 } while ((parent = parent->getParentOp()));
41
42 return true;
43 };
44
45
46
48 return llvm::all_of(nestedOp->getOperands(),
49 [&](Value operand) { return isReady(operand); })
52 });
54 }
55
59 if (ops.empty())
60 return true;
61
62
64
66 unscheduledOps.insert(&op);
67
70
71 bool allOpsScheduled = true;
72 while (!unscheduledOps.empty()) {
73 bool scheduledAtLeastOnce = false;
74
75
76
77
79 llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
80 if ((&op, unscheduledOps, isOperandReady))
81 continue;
82
83
84 unscheduledOps.erase(&op);
85 op.moveBefore(block, nextScheduledOp);
86 scheduledAtLeastOnce = true;
87
88 if (&op == &*nextScheduledOp)
89 ++nextScheduledOp;
90 }
91
92 if (!scheduledAtLeastOnce) {
93 allOpsScheduled = false;
94 unscheduledOps.erase(&*nextScheduledOp);
95 ++nextScheduledOp;
96 }
97 }
98
99 return allOpsScheduled;
100 }
101
104 if (block->empty())
105 return true;
108 isOperandReady);
110 }
111
115 if (ops.empty())
116 return true;
117
118
119
121
122 unsigned nextScheduledOp = 0;
123
124 bool allOpsScheduled = true;
125 while (!unscheduledOps.empty()) {
126 bool scheduledAtLeastOnce = false;
127
128
129
130
131 for (unsigned i = nextScheduledOp; i < ops.size(); ++i) {
132 if ((ops[i], unscheduledOps, isOperandReady))
133 continue;
134
135
136 unscheduledOps.erase(ops[i]);
137 std::swap(ops[i], ops[nextScheduledOp]);
138 scheduledAtLeastOnce = true;
139 ++nextScheduledOp;
140 }
141
142
143 if (!scheduledAtLeastOnce) {
144 allOpsScheduled = false;
145 unscheduledOps.erase(ops[nextScheduledOp++]);
146 }
147 }
148
149 return allOpsScheduled;
150 }
151
153
154
156 for (Block &b : region) {
157 if (blocks.count(&b) == 0) {
158 llvm::ReversePostOrderTraversal<Block *> traversal(&b);
159 blocks.insert_range(traversal);
160 }
161 }
162 assert(blocks.size() == region.getBlocks().size() &&
163 "some blocks are not sorted");
164
165 return blocks;
166 }
167
168 namespace {
169 class TopoSortHelper {
170 public:
172 : toSort(toSort) {}
173
174
175
176
178 if (toSort.size() <= 1) {
179
180 return toSort;
181 }
182
183
184
185
186 Region *rootRegion = findCommonAncestorRegion();
187 assert(rootRegion && "expected all ops to have a common ancestor");
188
189
190
192 assert(result.size() == toSort.size() &&
193 "expected all operations to be present in the result");
194 return result;
195 }
196
197 private:
198
199 Region *findCommonAncestorRegion() {
200
202 size_t expectedCount = toSort.size();
203
204
205
206 Region *res = nullptr;
209
210 ancestorBlocks.insert(op->getBlock());
211 while (current) {
212
213 if (++regionCounts[current] == expectedCount) {
214 res = current;
215 break;
216 }
219 }
220 }
221 auto firstRange = llvm::make_first_range(regionCounts);
222 ancestorRegions.insert_range(firstRange);
223 return res;
224 }
225
226
227
230
232
234 stack.push_back(&rootRegion);
235
236
237 while (!stack.empty()) {
238 StackT current = stack.pop_back_val();
239 if (auto *region = dyn_cast<Region *>(current)) {
240
242 for (Block *block : llvm::reverse(sortedBlocks)) {
243
244
245 if (ancestorBlocks.contains(block))
246 stack.push_back(block);
247 }
248 continue;
249 }
250
251 if (auto *block = dyn_cast<Block *>(current)) {
252
253 for (Operation &op : llvm::reverse(*block))
254 stack.push_back(&op);
255 continue;
256 }
257
258 auto *op = cast<Operation *>(current);
259 if (toSort.contains(op))
260 result.insert(op);
261
262
264 if (ancestorRegions.contains(&subRegion))
265 stack.push_back(&subRegion);
266 }
267 return result;
268 }
269
270
272
274
276 };
277 }
278
281 return TopoSortHelper(toSort).sort();
282 }
static bool isOpReady(Operation *op, DenseSet< Operation * > &unscheduledOps, function_ref< bool(Value, Operation *)> isOperandReady)
Return true if the given operation is ready to be scheduled.
Block represents an ordered list of Operations.
OpListType::iterator iterator
iterator_range< iterator > without_terminator()
Return an iterator range over the operations within this block, excluding the terminator operation at the end.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. `hasTrait<OpTrait::IsTerminator>()`.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Values.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp, which may be in the same or another block in the same function.
Region * getParentRegion()
Returns the region to which the operation belongs.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region, or nullptr if the region is attached to a top-level operation.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
A utility result that is used to signal how to proceed with an ongoing walk, e.g. `advance()` to continue the walk or `interrupt()` to stop it.
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
bool computeTopologicalSorting(MutableArrayRef< Operation * > ops, function_ref< bool(Value, Operation *)> isOperandReady=nullptr)
Compute a topological ordering of the given ops.
bool sortTopologically(Block *block, iterator_range< Block::iterator > ops, function_ref< bool(Value, Operation *)> isOperandReady=nullptr)
Given a block, sort a range of operations in said block in topological order.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.