MLIR: lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
17 #include "llvm/ADT/SetOperations.h"
18
19
20
21
22
23 namespace mlir {
24 namespace bufferization {
25
26 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp.inc"
27
28 }
29 }
30
31 using namespace mlir;
32 using namespace bufferization;
33
34
35
36
37
39 return builder.createarith::ConstantOp(loc, builder.getBoolAttr(value));
40 }
41
43
44
45
46
47
49 : indicator(indicator), state(State::Unique) {}
50
53 unknown.indicator = Value();
54 unknown.state = State::Unknown;
55 return unknown;
56 }
59
61 return state == State::Uninitialized;
62 }
65
67 assert(isUnique() && "must have unique ownership to get the indicator");
68 return indicator;
69 }
70
73 return *this;
75 return other;
76
79
80
81
82
84 return *this;
85
86
87
89 }
90
92
93
94
95
96
99 : symbolTable(symbolTables), liveness(op) {}
100
103
104 if (block == nullptr)
106
107
108 ownershipMap[{memref, block}].combine(ownership);
109 }
110
112 for (Value val : memrefs)
114 }
115
117 return ownershipMap.lookup({memref, block});
118 }
119
121 memrefsToDeallocatePerBlock[block].push_back(memref);
122 }
123
125 llvm::erase(memrefsToDeallocatePerBlock[block], memref);
126 }
127
133 memrefs.append(liveMemrefs);
134 }
135
136 std::pair<Value, Value>
139 auto iter = ownershipMap.find({memref, block});
140 assert(iter != ownershipMap.end() &&
141 "Value must already have been registered in the ownership map");
142
143 Ownership ownership = iter->second;
146
147
148
149
150
151
152
153 auto cloneOp =
154 builder.createbufferization::CloneOp(memref.getLoc(), memref);
156 Value newMemref = cloneOp.getResult();
158 memrefsToDeallocatePerBlock[newMemref.getParentBlock()].push_back(newMemref);
159 return {newMemref, condition};
160 }
161
165 for (Value operand : destOperands) {
167 continue;
168 toRetain.push_back(operand);
169 }
170
172 for (auto val : liveness.getLiveOut(fromBlock))
174 liveOut.insert(val);
175
176 if (toBlock)
177 llvm::set_intersect(liveOut, liveness.getLiveIn(toBlock));
178
179
180
181 SmallVector retainedByLiveness(liveOut.begin(), liveOut.end());
183 toRetain.append(retainedByLiveness);
184 }
185
189
190 for (auto [i, memref] :
191 llvm::enumerate(memrefsToDeallocatePerBlock.lookup(block))) {
192 Ownership ownership = ownershipMap.lookup({memref, block});
194 return emitError(memref.getLoc(),
195 "MemRef value does not have valid ownership");
196
197
198
199 if (auto unrankedMemRefTy = dyn_cast(memref.getType()))
200 memref = builder.creatememref::ReinterpretCastOp(
201 loc, memref,
205
206
207
208
209
210 memrefs.push_back(
211 builder.creatememref::ExtractStridedMetadataOp(loc, memref)
213 conditions.push_back(ownership.getIndicator());
214 }
215
216 return success();
217 }
218
219
220
221
222
224 if (lhs == rhs)
225 return false;
226
227
228 bool lhsIsBBArg = isa(lhs);
229 if (lhsIsBBArg != isa(rhs)) {
230 return lhsIsBBArg;
231 }
232
235 if (lhsIsBBArg) {
236 auto lhsBBArg = llvm::cast(lhs);
237 auto rhsBBArg = llvm::cast(rhs);
238 if (lhsBBArg.getArgNumber() != rhsBBArg.getArgNumber()) {
239 return lhsBBArg.getArgNumber() < rhsBBArg.getArgNumber();
240 }
243 assert(lhsRegion != rhsRegion &&
244 "lhsRegion == rhsRegion implies lhs == rhs");
246 return llvm::cast(lhs).getResultNumber() <
247 llvm::cast(rhs).getResultNumber();
248 } else {
251 if (lhsRegion == rhsRegion) {
253 }
254 }
255
256
257
258
259
260
261 while (lhsRegion && rhsRegion) {
264 }
268 }
271 }
272 if (rhsRegion)
273 return true;
274 assert(lhsRegion && "this should only happen if lhs == rhs");
275 return false;
276 }
277
278
279
280
281
286 assert(!op->hasSuccessors() && "must not have any successors");
287
288
292 if (failed(state.getMemrefsAndConditionsToDeallocate(
293 builder, op->getLoc(), block, memrefs, conditions)))
294 return failure();
295
296 state.getMemrefsToRetain(block, nullptr, operands, toRetain);
297 if (memrefs.empty() && toRetain.empty())
298 return op;
299
300 auto deallocOp = builder.createbufferization::DeallocOp(
301 op->getLoc(), memrefs, conditions, toRetain);
302
303
304
305 state.resetOwnerships(deallocOp.getRetained(), block);
306 for (auto [retained, ownership] :
307 llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions()))
308 state.updateOwnership(retained, ownership, block);
309
310 unsigned numMemrefOperands = llvm::count_if(operands, isMemref);
311 auto newOperandOwnerships =
312 deallocOp.getUpdatedConditions().take_front(numMemrefOperands);
313 updatedOperandOwnerships.append(newOperandOwnerships.begin(),
314 newOperandOwnerships.end());
315
316 return op;
317 }
static bool isMemref(Value v)
static Value buildBoolValue(OpBuilder &builder, Location loc, bool value)
Block represents an ordered list of Operations.
IntegerAttr getIndexAttr(int64_t value)
BoolAttr getBoolAttr(bool value)
const ValueSetT & getLiveOut(Block *block) const
Returns a reference to a set containing live-out values (unordered).
const ValueSetT & getLiveIn(Block *block) const
Returns a reference to a set containing live-in values (unordered).
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Block * getBlock()
Returns the operation block that contains this operation.
Region * getParentRegion()
Returns the region to which the instruction belongs.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This class collects all the state that we need to perform the buffer deallocation pass with associate...
void addMemrefToDeallocate(Value memref, Block *block)
Remember the given 'memref' to deallocate it at the end of the 'block'.
Ownership getOwnership(Value memref, Block *block) const
Returns the ownership of 'memref' for the given basic block.
DeallocationState(Operation *op, SymbolTableCollection &symbolTables)
void resetOwnerships(ValueRange memrefs, Block *block)
Removes ownerships associated with all values in the passed range for 'block'.
void updateOwnership(Value memref, Ownership ownership, Block *block=nullptr)
Small helper function to update the ownership map by taking the current ownership ('Uninitialized' st...
std::pair< Value, Value > getMemrefWithUniqueOwnership(OpBuilder &builder, Value memref, Block *block)
Given an SSA value of MemRef type, this function queries the ownership and if it is not already in th...
LogicalResult getMemrefsAndConditionsToDeallocate(OpBuilder &builder, Location loc, Block *block, SmallVectorImpl< Value > &memrefs, SmallVectorImpl< Value > &conditions) const
For a given block, computes the list of MemRefs that potentially need to be deallocated at the end of...
void getLiveMemrefsIn(Block *block, SmallVectorImpl< Value > &memrefs)
Return a sorted list of MemRef values which are live at the start of the given block.
void dropMemrefToDeallocate(Value memref, Block *block)
Forget about a MemRef that we originally wanted to deallocate at the end of 'block',...
void getMemrefsToRetain(Block *fromBlock, Block *toBlock, ValueRange destOperands, SmallVectorImpl< Value > &toRetain) const
Given two basic blocks and the values passed via block arguments to the destination block,...
This class is used to track the ownership of values.
static Ownership getUnique(Value indicator)
Get an ownership value in 'Unique' state with 'indicator' as parameter.
Ownership getCombined(Ownership other) const
Get the join of the two-element subset {this,other}.
void combine(Ownership other)
Modify 'this' ownership to be the join of the current 'this' and 'other'.
Ownership()=default
Constructor that creates an 'Uninitialized' ownership.
bool isUnknown() const
Check if this ownership value is in the 'Unknown' state.
bool isUnique() const
Check if this ownership value is in the 'Unique' state.
static Ownership getUnknown()
Get an ownership value in 'Unknown' state.
Value getIndicator() const
If this ownership value is in 'Unique' state, this function can be used to get the indicator paramete...
bool isUninitialized() const
Check if this ownership value is in the 'Uninitialized' state.
static Ownership getUninitialized()
Get an ownership value in 'Uninitialized' state.
FailureOr< Operation * > insertDeallocOpForReturnLike(DeallocationState &state, Operation *op, ValueRange operands, SmallVectorImpl< Value > &updatedOperandOwnerships)
Insert a bufferization.dealloc operation right before op which has to be a terminator without any suc...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Compare two SSA values in a deterministic manner.
bool operator()(const Value &lhs, const Value &rhs) const