MLIR: lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
10
16 #include "llvm/ADT/SetOperations.h"
17 #include "llvm/ADT/SetVector.h"
18
19 using namespace mlir;
21
22
23
24
25
26
28
33 queue.push_back(value);
34 while (!queue.empty()) {
35 Value currentValue = queue.pop_back_val();
36 if (result.insert(currentValue).second) {
37 auto it = map.find(currentValue);
38 if (it != map.end()) {
39 for (Value aliasValue : it->second)
40 queue.push_back(aliasValue);
41 }
42 }
43 }
44 return result;
45 }
46
47
48
49
53 }
54
57 return resolveValues(reverseDependencies, rootValue);
58 }
59
60
62 for (auto &entry : dependencies)
63 llvm::set_subtract(entry.second, aliasValues);
64 }
65
67 dependencies[to] = dependencies[from];
68 dependencies.erase(from);
69
70 for (auto &[_, value] : dependencies) {
71 if (value.contains(from)) {
72 value.insert(to);
73 value.erase(from);
74 }
75 }
76 }
77
78
79
80
81
82
83 void BufferViewFlowAnalysis::build(Operation *op) {
84
86 for (auto [value, dep] : llvm::zip_equal(values, dependencies)) {
87 this->dependencies[value].insert(dep);
88 this->reverseDependencies[dep].insert(value);
89 }
90 };
91
92
93
94 auto populateTerminalValues = [&](Operation *op) {
96 if (isa(v.getType()))
97 this->terminals.insert(v);
100 if (isa(v.getType()))
101 this->terminals.insert(v);
102 };
103
105
106
107
108 if (auto bufferViewFlowOp = dyn_cast(op)) {
109 bufferViewFlowOp.populateDependencies(registerDependencies);
111 if (isa(v.getType()) &&
112 bufferViewFlowOp.mayBeTerminalBuffer(v))
113 this->terminals.insert(v);
116 if (isa(v.getType()) &&
117 bufferViewFlowOp.mayBeTerminalBuffer(v))
118 this->terminals.insert(v);
120 }
121
122
123 if (auto viewInterface = dyn_cast(op)) {
124 registerDependencies(viewInterface.getViewSource(),
125 viewInterface->getResult(0));
127 }
128
129 if (auto branchInterface = dyn_cast(op)) {
130
131 Block *parentBlock = branchInterface->getBlock();
132 for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
133 it != e; ++it) {
134
135 auto successorOperands =
136 branchInterface.getSuccessorOperands(it.getIndex());
137
138 registerDependencies(successorOperands.getForwardedOperands(),
139 (*it)->getArguments().drop_front(
140 successorOperands.getProducedOperandCount()));
141 }
143 }
144
145 if (auto regionInterface = dyn_cast(op)) {
146
147
150 entrySuccessors);
151 for (RegionSuccessor &entrySuccessor : entrySuccessors) {
152
153
154 registerDependencies(
155 regionInterface.getEntrySuccessorOperands(entrySuccessor),
156 entrySuccessor.getSuccessorInputs());
157 }
158
159
160 for (Region ®ion : regionInterface->getRegions()) {
161
162
164 regionInterface.getSuccessorRegions(region, successorRegions);
165 for (RegionSuccessor &successorRegion : successorRegions) {
166
167
168 for (Block &block : region)
169 if (auto terminator = dyn_cast(
170 block.getTerminator()))
171 registerDependencies(
172 terminator.getSuccessorOperands(successorRegion),
173 successorRegion.getSuccessorInputs());
174 }
175 }
176
178 }
179
180
181 if (isa(op))
183
184 if (isa(op)) {
185
186
187
188
189 populateTerminalValues(op);
192 registerDependencies({operand}, {result});
194 }
195
196
197 populateTerminalValues(op);
198
200 });
201 }
202
204 assert(isa(value.getType()) && "expected memref");
205 return terminals.contains(value);
206 }
207
208
209
210
211
212
215 if (!op)
216 return false;
217 return hasEffectMemoryEffects::Allocate(op, v);
218 }
219
220
222 auto bbArg = dyn_cast(v);
223 if (!bbArg)
224 return false;
225 Block *b = bbArg.getOwner();
226 auto funcOp = dyn_cast(b->getParentOp());
227 if (!funcOp)
228 return false;
229 return bbArg.getOwner() == &funcOp.getFunctionBody().front();
230 }
231
232
233
235 while (auto viewLikeOp = value.getDefiningOp())
236 value = viewLikeOp.getViewSource();
237 return value;
238 }
239
241
243 assert(isa(v1.getType()) && "expected buffer");
244 assert(isa(v2.getType()) && "expected buffer");
245
246
249
250
251
252 if (v1 == v2)
253 return true;
254
255
258
259
260
261
262
263
264
266
267
268
269 bool allAllocs1 = true, allAllocs2 = true;
270 bool allAllocsOrFuncEntryArgs1 = true, allAllocsOrFuncEntryArgs2 = true;
271
272
275 bool &allAllocs,
276 bool &allAllocsOrFuncEntryArgs) {
277 for (Value v : origin) {
278 if (isa(v.getType()) && analysis.mayBeTerminalBuffer(v)) {
279 terminal.insert(v);
281 allAllocsOrFuncEntryArgs &=
283 }
284 }
285 assert(!terminal.empty() && "expected non-empty terminal set");
286 };
287
288
289 gatherTerminalBuffers(origin1, terminal1, allAllocs1,
290 allAllocsOrFuncEntryArgs1);
291 gatherTerminalBuffers(origin2, terminal2, allAllocs2,
292 allAllocsOrFuncEntryArgs2);
293
294
295
296 if (llvm::hasSingleElement(terminal1) && llvm::hasSingleElement(terminal2) &&
297 *terminal1.begin() == *terminal2.begin())
298 return true;
299
300
301
302
303 bool distinctTerminalSets = true;
304 for (Value v : terminal1)
305 distinctTerminalSets &= !terminal2.contains(v);
306
307
308 if (!distinctTerminalSets)
309 return std::nullopt;
310
311
312
313
314
315
316 bool isolatedAlloc1 = allAllocs1 && (allAllocs2 || allAllocsOrFuncEntryArgs2);
317 bool isolatedAlloc2 = (allAllocs1 || allAllocsOrFuncEntryArgs1) && allAllocs2;
318 if (isolatedAlloc1 || isolatedAlloc2)
319 return false;
320
321
322
323
324
325
326
327
328
329 return std::nullopt;
330 }
static bool isFunctionArgument(Value v)
Return "true" if the given value is a function block argument.
static Value getViewBase(Value value)
Given a memref value, return the "base" value by skipping over all ViewLikeOpInterface ops (if any) in the reverse use-def chain.
static BufferViewFlowAnalysis::ValueSetT resolveValues(const BufferViewFlowAnalysis::ValueMapT &map, Value value)
static bool hasAllocateSideEffect(Value v)
Return "true" if the given value is the result of a memory allocation.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
succ_iterator succ_begin()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
BufferOriginAnalysis(Operation *op)
std::optional< bool > isSameAllocation(Value v1, Value v2)
Return "true" if v1 and v2 originate from the same buffer allocation, "false" if they provably do not, and std::nullopt if the analysis cannot tell.
BufferViewFlowAnalysis(Operation *op)
Constructs a new alias analysis using the op provided.
void remove(const SetVector< Value > &aliasValues)
Removes the given values from all alias sets.
ValueSetT resolve(Value value) const
Find all immediate and indirect views upon this value.
void rename(Value from, Value to)
Replaces all occurrences of 'from' in the internal data structures with 'to'.
bool mayBeTerminalBuffer(Value value) const
Returns "true" if the given value may be a terminal.
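ValueSetT resolveReverse(Value value) const
Find all immediate and indirect values that the given value may originate from (the reverse direction of resolve).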
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
Include the generated interface declarations.