MLIR: lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

17 #include "llvm/ADT/SetOperations.h"

18

19

20

21

22

23 namespace mlir {

24 namespace bufferization {

25

26 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp.inc"

27

28 }

29 }

30

31 using namespace mlir;

32 using namespace bufferization;

33

34

35

36

37

39 return builder.createarith::ConstantOp(loc, builder.getBoolAttr(value));

40 }

41

43

44

45

46

47

49 : indicator(indicator), state(State::Unique) {}

50

53 unknown.indicator = Value();

54 unknown.state = State::Unknown;

55 return unknown;

56 }

59

61 return state == State::Uninitialized;

62 }

65

67 assert(isUnique() && "must have unique ownership to get the indicator");

68 return indicator;

69 }

70

73 return *this;

75 return other;

76

79

80

81

82

84 return *this;

85

86

87

89 }

90

92

93

94

95

96

99 : symbolTable(symbolTables), liveness(op) {}

100

103

104 if (block == nullptr)

106

107

108 ownershipMap[{memref, block}].combine(ownership);

109 }

110

112 for (Value val : memrefs)

114 }

115

117 return ownershipMap.lookup({memref, block});

118 }

119

121 memrefsToDeallocatePerBlock[block].push_back(memref);

122 }

123

125 llvm::erase(memrefsToDeallocatePerBlock[block], memref);

126 }

127

133 memrefs.append(liveMemrefs);

134 }

135

136 std::pair<Value, Value>

139 auto iter = ownershipMap.find({memref, block});

140 assert(iter != ownershipMap.end() &&

141 "Value must already have been registered in the ownership map");

142

143 Ownership ownership = iter->second;

146

147

148

149

150

151

152

153 auto cloneOp =

154 builder.createbufferization::CloneOp(memref.getLoc(), memref);

156 Value newMemref = cloneOp.getResult();

158 memrefsToDeallocatePerBlock[newMemref.getParentBlock()].push_back(newMemref);

159 return {newMemref, condition};

160 }

161

165 for (Value operand : destOperands) {

167 continue;

168 toRetain.push_back(operand);

169 }

170

172 for (auto val : liveness.getLiveOut(fromBlock))

174 liveOut.insert(val);

175

176 if (toBlock)

177 llvm::set_intersect(liveOut, liveness.getLiveIn(toBlock));

178

179

180

181 SmallVector retainedByLiveness(liveOut.begin(), liveOut.end());

183 toRetain.append(retainedByLiveness);

184 }

185

189

190 for (auto [i, memref] :

191 llvm::enumerate(memrefsToDeallocatePerBlock.lookup(block))) {

192 Ownership ownership = ownershipMap.lookup({memref, block});

194 return emitError(memref.getLoc(),

195 "MemRef value does not have valid ownership");

196

197

198

199 if (auto unrankedMemRefTy = dyn_cast(memref.getType()))

200 memref = builder.creatememref::ReinterpretCastOp(

201 loc, memref,

205

206

207

208

209

210 memrefs.push_back(

211 builder.creatememref::ExtractStridedMetadataOp(loc, memref)

213 conditions.push_back(ownership.getIndicator());

214 }

215

216 return success();

217 }

218

219

220

221

222

224 if (lhs == rhs)

225 return false;

226

227

228 bool lhsIsBBArg = isa(lhs);

229 if (lhsIsBBArg != isa(rhs)) {

230 return lhsIsBBArg;

231 }

232

235 if (lhsIsBBArg) {

236 auto lhsBBArg = llvm::cast(lhs);

237 auto rhsBBArg = llvm::cast(rhs);

238 if (lhsBBArg.getArgNumber() != rhsBBArg.getArgNumber()) {

239 return lhsBBArg.getArgNumber() < rhsBBArg.getArgNumber();

240 }

243 assert(lhsRegion != rhsRegion &&

244 "lhsRegion == rhsRegion implies lhs == rhs");

246 return llvm::cast(lhs).getResultNumber() <

247 llvm::cast(rhs).getResultNumber();

248 } else {

251 if (lhsRegion == rhsRegion) {

253 }

254 }

255

256

257

258

259

260

261 while (lhsRegion && rhsRegion) {

264 }

268 }

271 }

272 if (rhsRegion)

273 return true;

274 assert(lhsRegion && "this should only happen if lhs == rhs");

275 return false;

276 }

277

278

279

280

281

286 assert(!op->hasSuccessors() && "must not have any successors");

287

288

292 if (failed(state.getMemrefsAndConditionsToDeallocate(

293 builder, op->getLoc(), block, memrefs, conditions)))

294 return failure();

295

296 state.getMemrefsToRetain(block, nullptr, operands, toRetain);

297 if (memrefs.empty() && toRetain.empty())

298 return op;

299

300 auto deallocOp = builder.createbufferization::DeallocOp(

301 op->getLoc(), memrefs, conditions, toRetain);

302

303

304

305 state.resetOwnerships(deallocOp.getRetained(), block);

306 for (auto [retained, ownership] :

307 llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions()))

308 state.updateOwnership(retained, ownership, block);

309

310 unsigned numMemrefOperands = llvm::count_if(operands, isMemref);

311 auto newOperandOwnerships =

312 deallocOp.getUpdatedConditions().take_front(numMemrefOperands);

313 updatedOperandOwnerships.append(newOperandOwnerships.begin(),

314 newOperandOwnerships.end());

315

316 return op;

317 }

static bool isMemref(Value v)

static Value buildBoolValue(OpBuilder &builder, Location loc, bool value)

Block represents an ordered list of Operations.

IntegerAttr getIndexAttr(int64_t value)

BoolAttr getBoolAttr(bool value)

const ValueSetT & getLiveOut(Block *block) const

Returns a reference to a set containing live-out values (unordered).

const ValueSetT & getLiveIn(Block *block) const

Returns a reference to a set containing live-in values (unordered).

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.

This class helps build Operations.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

This class provides the API for ops that are known to be terminators.

Operation is the basic unit of execution within MLIR.

bool hasTrait()

Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().

bool isBeforeInBlock(Operation *other)

Given an operation 'other' that is within the same parent block, return whether the current operation is before 'other' in the operation list of the parent block.

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

Location getLoc()

The source location the operation was defined or derived from.

Block * getBlock()

Returns the operation block that contains this operation.

Region * getParentRegion()

Returns the region to which the instruction belongs.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

Region * getParentRegion()

Return the region containing this region, or nullptr if the region is attached to a top-level operation.

unsigned getRegionNumber()

Return the number of this region in the parent operation.

Operation * getParentOp()

Return the parent operation this region is attached to.

This class represents a collection of SymbolTables.

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

Type getType() const

Return the type of this value.

Block * getParentBlock()

Return the Block in which this Value is defined.

Location getLoc() const

Return the location of this value.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

This class collects all the state that we need to perform the buffer deallocation pass, together with associated helper functions.

void addMemrefToDeallocate(Value memref, Block *block)

Remember the given 'memref' to deallocate it at the end of the 'block'.

Ownership getOwnership(Value memref, Block *block) const

Returns the ownership of 'memref' for the given basic block.

DeallocationState(Operation *op, SymbolTableCollection &symbolTables)

void resetOwnerships(ValueRange memrefs, Block *block)

Removes ownerships associated with all values in the passed range for 'block'.

void updateOwnership(Value memref, Ownership ownership, Block *block=nullptr)

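Small helper function to update the ownership map by taking the current ownership ('Uninitialized' state if not yet present), computing the join with the passed ownership, and storing this new value in the map.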

std::pair< Value, Value > getMemrefWithUniqueOwnership(OpBuilder &builder, Value memref, Block *block)

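Given an SSA value of MemRef type, this function queries the ownership and, if it is not already in the 'Unique' state, potentially inserts IR to get a new SSA value (returned as the first element of the pair) that has 'Unique' ownership and can be used instead of the passed value; the ownership indicator is returned as the second element of the pair.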

LogicalResult getMemrefsAndConditionsToDeallocate(OpBuilder &builder, Location loc, Block *block, SmallVectorImpl< Value > &memrefs, SmallVectorImpl< Value > &conditions) const

For a given block, computes the list of MemRefs that potentially need to be deallocated at the end of that block, along with the ownership indicator (condition) for each of them.

void getLiveMemrefsIn(Block *block, SmallVectorImpl< Value > &memrefs)

Return a sorted list of MemRef values which are live at the start of the given block.

void dropMemrefToDeallocate(Value memref, Block *block)

Forget about a MemRef that we originally wanted to deallocate at the end of 'block', e.g. because it already gets deallocated before the end of the block.

void getMemrefsToRetain(Block *fromBlock, Block *toBlock, ValueRange destOperands, SmallVectorImpl< Value > &toRetain) const

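Given two basic blocks and the values passed via block arguments to the destination block, compute the list of MemRefs that have to be retained in 'fromBlock' to avoid a use-after-free situation.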

This class is used to track the ownership of values.

static Ownership getUnique(Value indicator)

Get an ownership value in 'Unique' state with 'indicator' as parameter.

Ownership getCombined(Ownership other) const

Get the join of the two-element subset {this,other}.

void combine(Ownership other)

Modify 'this' ownership to be the join of the current 'this' and 'other'.

Ownership()=default

Constructor that creates an 'Uninitialized' ownership.

bool isUnknown() const

Check if this ownership value is in the 'Unknown' state.

bool isUnique() const

Check if this ownership value is in the 'Unique' state.

static Ownership getUnknown()

Get an ownership value in 'Unknown' state.

Value getIndicator() const

If this ownership value is in 'Unique' state, this function can be used to get the indicator parameter.

bool isUninitialized() const

Check if this ownership value is in the 'Uninitialized' state.

static Ownership getUninitialized()

Get an ownership value in 'Uninitialized' state.

FailureOr< Operation * > insertDeallocOpForReturnLike(DeallocationState &state, Operation *op, ValueRange operands, SmallVectorImpl< Value > &updatedOperandOwnerships)

Insert a bufferization.dealloc operation right before op, which has to be a terminator without any successors.

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

Include the generated interface declarations.

bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)

Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.

Compare two SSA values in a deterministic manner.

bool operator()(const Value &lhs, const Value &rhs) const