MLIR: lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
15
24
// Tablegen-generated pass base classes for the three passes in this file.
// The GEN_PASS_DEF_* macros correspond to the bases used further down
// (impl::BufferHoistingPassBase, impl::BufferLoopHoistingPassBase,
// impl::PromoteBuffersToStackPassBase); they have to be defined before
// Passes.h.inc is included for the include to see them.
25 namespace mlir {
26 namespace bufferization {
27 #define GEN_PASS_DEF_BUFFERHOISTINGPASS
28 #define GEN_PASS_DEF_BUFFERLOOPHOISTINGPASS
29 #define GEN_PASS_DEF_PROMOTEBUFFERSTOSTACKPASS
30 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
31 } // namespace bufferization
32 } // namespace mlir
33
34 using namespace mlir;
36
37
38
/// isKnownControlFlowInterface(op): true iff `op` implements
/// LoopLikeOpInterface or RegionBranchOpInterface.
/// NOTE(review): the signature line (presumably source line 39) is missing
/// from this listing.
40 return isa<LoopLikeOpInterface, RegionBranchOpInterface>(op);
41 }
42
43
44
45
46
48
49
/// isLoop(op): decides whether `op` represents a loop.
/// NOTE(review): the signature line and the template arguments of `isa` and
/// `dyn_cast` below were lost in this listing. Presumably they are
/// LoopLikeOpInterface and a region-branch interface that provides
/// hasLoop(); confirm against the real source.
// Ops implementing the first interface are always treated as loops.
50 if (isa(op))
51 return true;
52
53
54
// If `op` cannot be cast to the region interface, no loop can be detected
// and the op is not treated as a loop.
55 auto regionInterface = dyn_cast(op);
56 if (!regionInterface)
57 return false;
58
// Otherwise defer to the interface's own loop detection.
59 return regionInterface.hasLoop();
60 }
61
62
63
66 }
67
68
69
/// allowAllocDominateBlockHoisting(op): true when `op` casts to the
/// allocation interface and the (elided) second conjunct holds.
/// NOTE(review): the signature, the dyn_cast template argument and source
/// line 73 (the rest of the return expression) are missing from this
/// listing. The summary describes the check as "implements the
/// AllocationOpInterface and it supports the dominate block hoisting".
71 auto allocOp = dyn_cast(op);
72 return allocOp &&
74 }
75
76
77
/// allowAllocLoopHoisting(op): true when `op` casts to the allocation
/// interface and the (elided) second conjunct holds.
/// NOTE(review): the signature, the dyn_cast template argument and source
/// line 81 (the rest of the return expression) are missing from this
/// listing. The summary describes the check as "implements the
/// AllocationOpInterface and it supports the loop hoisting".
79 auto allocOp = dyn_cast(op);
80 return allocOp &&
82 }
83
84
85
86
88 unsigned maxRankOfAllocatedMemRef) {
// defaultIsSmallAlloc(alloc, maximumSizeInBytes, maxRankOfAllocatedMemRef):
// decides whether `alloc` counts as "small". The first signature line is
// missing from this listing; the summary gives both limits as `unsigned`.
// Only values whose type casts successfully and that are defined by an
// AllocOp are considered. The garbled call below reads as
// getDefiningOp<memref::AllocOp>() with its angle brackets lost.
89 auto type = dyn_cast(alloc.getType());
90 if (!type || !alloc.getDefiningOpmemref::AllocOp())
91 return false;
92 if (!type.hasStaticShape()) {
93
94
95
96
97
// Dynamic shape: only accepted when the rank is at most
// maxRankOfAllocatedMemRef and the predicate below (every operand produced
// by memref::RankOp) holds. The call consuming the lambda (source line 99)
// is missing from this listing.
98 if (type.getRank() <= maxRankOfAllocatedMemRef) {
100 [&](Value operand) {
101 return operand.getDefiningOpmemref::RankOp();
102 });
103 }
104 return false;
105 }
// Static shape: accepted when the total size in bits is <= the byte limit
// times 8, so a buffer of exactly maximumSizeInBytes bytes qualifies.
// `bitwidth` is defined on source lines 106-107, which are missing here.
// NOTE(review): maximumSizeInBytes is `unsigned`, so `maximumSizeInBytes * 8`
// is computed in unsigned arithmetic. With a 32-bit unsigned it wraps for
// limits >= 2^29 bytes, which would make the check reject almost
// everything. Consider widening to a 64-bit type before multiplying.
108 return type.getNumElements() * bitwidth <= maximumSizeInBytes * 8;
109 }
110
111
/// leavesAllocationScope(parentRegion, aliases): returns true as soon as some
/// user of some value in `aliases` matches the `isa` check below and sits
/// directly in `parentRegion`; returns false if no user does.
/// NOTE(review): signature lines 113-114 are missing from this listing. The
/// summary gives
/// (Region *parentRegion, const BufferViewFlowAnalysis::ValueSetT &aliases).
/// The `isa` template argument on source line 120 was also lost, so the op
/// kind that counts as "leaving the scope" cannot be read from here.
112 static bool
115 for (Value alias : aliases) {
116 for (auto *use : alias.getUsers()) {
117
118
119
120 if (isa(use) &&
121 use->getParentRegion() == parentRegion)
122 return true;
123 }
124 }
125 return false;
126 }
127
128
/// hasAllocationScope(alloc, aliasAnalysis): checks whether an allocation
/// scope exists for `alloc` (per the summary).
/// NOTE(review): most of the body is missing from this listing, including the
/// signature and source lines 133, 137-138, 143 and 146 (the loop
/// condition). What remains shows a do-loop containing an early
/// `return true` and a conditional `break`, followed by `return false`.
/// Do not infer the traversal order from this fragment.
132 do {
134
135
136
139 return true;
140
141
142
144 break;
145 }
147 return false;
148 }
149
150 namespace {
151
152
153
154
155
156
/// Common state shared by the hoisting strategies. It holds the dominance
/// info, the allocation value being moved, and the block it is currently
/// assigned to.
157 struct BufferAllocationHoistingStateBase {
158
// NOTE(review): the declaration of the `dominators` member (initialized by
// the constructor below; presumably `DominanceInfo *dominators;` on source
// line 159) is missing from this listing.
160
161
// The allocation value whose defining operation is being hoisted.
162 Value allocValue;
163
164
// The block chosen for the allocation so far. Derived states update it
// through recordMoveToDominator / recordMoveToParent.
165 Block *placementBlock;
166
167
168 BufferAllocationHoistingStateBase(DominanceInfo *dominators, Value allocValue,
169 Block *placementBlock)
170 : dominators(dominators), allocValue(allocValue),
171 placementBlock(placementBlock) {}
172 };
173
174
175 template
177 public:
// Builds the hoisting helper for the scope rooted at `op`.
// NOTE(review): source line 179 (the start of the member-initializer list) is
// missing from this listing. The visible tail initializes postDominators and
// scopeOp from `op`.
178 BufferAllocationHoisting(Operation *op)
180 postDominators(op), scopeOp(op) {}
181
182
/// Hoists the candidate allocations in the scope. It collects allocation
/// values and every memref::AllocaOp result under scopeOp. For each value
/// whose defining op StateT accepts, it computes a placement block and moves
/// the defining op before the chosen start operation.
/// NOTE(review): source lines 184-185 (presumably the declaration of
/// allocsAndAllocas and the loop over allocation entries that feeds line 186)
/// are missing from this listing.
183 void hoist() {
186 allocsAndAllocas.push_back(std::get<0>(entry));
187 scopeOp->walk([&](memref::AllocaOp op) {
188 allocsAndAllocas.push_back(op.getMemref());
189 });
190
191 for (auto allocValue : allocsAndAllocas) {
// NOTE(review): getDefiningOp() is passed to shouldHoistOpType before the
// null check asserted two lines below. If a candidate could ever be a block
// argument, that assert would come too late to help.
192 if (!StateT::shouldHoistOpType(allocValue.getDefiningOp()))
193 continue;
194 Operation *definingOp = allocValue.getDefiningOp();
195 assert(definingOp && "No defining op");
196 auto operands = definingOp->getOperands();
197 auto resultAliases = aliases.resolve(allocValue);
198
// NOTE(review): the initializer of dominatorBlock (source line 200) is
// missing from this listing. Presumably it is the common dominator of
// allocValue and resultAliases (cf. findCommonDominator in the summary).
199 Block *dominatorBlock =
201
202 StateT state(&dominators, allocValue, allocValue.getParentBlock());
203
204
205 Block *dependencyBlock = nullptr;
206
207
208
// Pick the block defining the alloc's operands that is deepest in dominance
// order: the current pick is replaced whenever it dominates the operand's
// block. If two operand blocks are unrelated by dominance, the one seen
// first is kept. The result bounds the move via computeUpperBound.
209 for (Value depValue : operands) {
210 Block *depBlock = depValue.getParentBlock();
211 if (!dependencyBlock || dominators.dominates(dependencyBlock, depBlock))
212 dependencyBlock = depBlock;
213 }
214
215
216
217
// Find the target block, then the insertion point inside it. Source line 220
// (presumably the start of the call producing `startOperation`; cf.
// getStartOperation in the summary) is missing from this listing.
218 Block *placementBlock = findPlacementBlock(
219 state, state.computeUpperBound(dominatorBlock, dependencyBlock));
221 allocValue, placementBlock, liveness);
222
223
224 Operation *allocOperation = allocValue.getDefiningOp();
225 allocOperation->moveBefore(startOperation);
226 }
227 }
228
229 private:
230
231
232
/// Walks upward from state.placementBlock and returns the block where the
/// allocation should be placed. Each step does one of two things:
///   - moves to the immediate dominator of the current block
///     (recordMoveToDominator), or
///   - leaves the enclosing operation's region for the block containing that
///     operation (recordMoveToParent).
/// StateT decides which of these moves actually update
/// state.placementBlock. A non-null `upperBound` stops the walk once it no
/// longer properly dominates the current block.
233 Block *findPlacementBlock(StateT &state, Block *upperBound) {
234 Block *currentBlock = state.placementBlock;
235
236
237
238
239
240
241
// NOTE(review): source line 242 (presumably the declaration of `parentOp`) is
// missing from this listing.
243 Block *parentBlock;
// Continue while:
//   - the current block is nested inside an operation that itself lives in a
//     block, and
//   - the upper bound, if any, still properly dominates the current block.
244 while ((parentOp = currentBlock->getParentOp()) &&
245 (parentBlock = parentOp->getBlock()) &&
246 (!upperBound ||
247 dominators.properlyDominates(upperBound, currentBlock))) {
248
249
251
252
// NOTE(review): the declaration of `idom` (source lines 250/253) is missing
// from this listing.
254 idom = dominators.getNode(currentBlock)->getIDom();
255
// Prefer moving to the immediate dominator while parentBlock still properly
// dominates it.
256 if (idom && dominators.properlyDominates(parentBlock, idom->getBlock())) {
257
258
259 currentBlock = idom->getBlock();
260 state.recordMoveToDominator(currentBlock);
261 } else {
// Otherwise try to leave parentOp's region. The break condition is only
// partially visible (source line 267 is missing). The visible part stops
// when StateT rejects parentOp via isLegalPlacement.
262
263
264
265
266
268 !state.isLegalPlacement(parentOp))
269 break;
270
271
272 currentBlock = parentBlock;
273 state.recordMoveToParent(currentBlock);
274 }
275 }
276
// Return the state's block, not currentBlock. Strategies whose
// recordMoveToDominator is a no-op only commit moves into parent blocks.
277 return state.placementBlock;
278 }
279
280
281
283
284
285
287
288
290
291
292
294 };
295
296
297
298
/// Hoisting strategy in which every move, whether to a dominator or to a
/// parent block, updates the placement block.
299 struct BufferAllocationHoistingState : BufferAllocationHoistingStateBase {
300 using BufferAllocationHoistingStateBase::BufferAllocationHoistingStateBase;
301
302
// Returns the upper bound for the walk:
//   - dominatorBlock when there is no dependency block;
//   - otherwise dependencyBlock if dominatorBlock properly dominates it, i.e.
//     whichever of the two lies deeper;
//   - dominatorBlock in all remaining cases.
303 Block *computeUpperBound(Block *dominatorBlock, Block *dependencyBlock) {
304
305
306 if (!dependencyBlock)
307 return dominatorBlock;
308
309
310
311 return dominators->properlyDominates(dominatorBlock, dependencyBlock)
312 ? dependencyBlock
313 : dominatorBlock;
314 }
315
316
// NOTE(review): as written, `return (op);` just converts the pointer to bool
// (true for any non-null op). A predicate name appears to have been stripped
// from this listing (cf. isKnownControlFlowInterface / isLoop); confirm
// against the real source.
317 bool isLegalPlacement(Operation *op) { return (op); }
318
319
// NOTE(review): the body (source line 321) is missing from this listing.
// Presumably it uses allowAllocDominateBlockHoisting.
320 static bool shouldHoistOpType(Operation *op) {
322 }
323
324
// Both kinds of move commit the new block as the placement.
325 void recordMoveToDominator(Block *block) { placementBlock = block; }
326
327
328 void recordMoveToParent(Block *block) { recordMoveToDominator(block); }
329 };
330
331
332
/// Hoisting strategy that commits only moves into parent blocks, i.e. moves
/// out of an enclosing operation's region. Moves to a dominator within the
/// same region are not recorded.
333 struct BufferAllocationLoopHoistingState : BufferAllocationHoistingStateBase {
334 using BufferAllocationHoistingStateBase::BufferAllocationHoistingStateBase;
335
336
// The dominatorBlock passed to computeUpperBound, remembered for
// isLegalPlacement.
337 Block *aliasDominatorBlock = nullptr;
338
339
// Records dominatorBlock and bounds the walk only by the operand dependency
// block. The ternary is equivalent to returning dependencyBlock directly.
340 Block *computeUpperBound(Block *dominatorBlock, Block *dependencyBlock) {
341 aliasDominatorBlock = dominatorBlock;
342
343
344 return dependencyBlock ? dependencyBlock : nullptr;
345 }
346
347
348
349
350
351
// The visible conjunct rejects leaving `op` when its block is dominated by
// aliasDominatorBlock.
// NOTE(review): the first half of the return expression (source line 353) is
// missing from this listing.
352 bool isLegalPlacement(Operation *op) {
354 !dominators->dominates(aliasDominatorBlock, op->getBlock());
355 }
356
357
// NOTE(review): the body (source line 359) is missing from this listing.
// Presumably it uses allowAllocLoopHoisting.
358 static bool shouldHoistOpType(Operation *op) {
360 }
361
362
363
// Intentionally a no-op: moving to a dominator does not change the placement.
364 void recordMoveToDominator(Block *block) {}
365
366
367 void recordMoveToParent(Block *block) { placementBlock = block; }
368 };
369
370
371
372
373
374
// BufferPlacementPromotion (class header, source line 375, missing from this
// listing): replaces small allocations that have no dealloc with promoted
// allocations built through the allocation interface.
376 public:
377 BufferPlacementPromotion(Operation *op)
379
380
// NOTE(review): source lines 378 and 381-382 (initializer list, the promote()
// signature and the loop over allocation entries) are missing from this
// listing.
383 Value alloc = std::get<0>(entry);
384 Operation *dealloc = std::get<1>(entry);
385
386
387
388
// Skip allocations that are:
//   - not small according to the predicate,
//   - already paired with a dealloc, or
//   - rejected by the condition on missing source line 390.
389 if (!isSmallAlloc(alloc) || dealloc ||
391 continue;
392
395
396
// NOTE(review): lines 393-394 and 398, which compute startOperation and
// allocOp, are missing from this listing.
397 OpBuilder builder(startOperation);
399 if (auto allocInterface = dyn_cast(allocOp)) {
400 std::optional<Operation *> alloca =
401 allocInterface.buildPromotedAlloc(builder, alloc);
402 if (!alloca)
403 continue;
404
// Source line 405 (presumably redirecting uses of the old alloc) is missing.
// The original alloc op is erased only on this successful path. If the
// dyn_cast fails, the allocation is left untouched.
406 allocOp->erase();
407 }
408 }
409 }
410 };
411
412
413
414
415
416
417
/// Pass that runs BufferAllocationHoisting on the current operation.
/// NOTE(review): the template arguments on the base class (source line 419)
/// and on BufferAllocationHoisting (source line 423) were stripped in this
/// listing. Presumably they are <BufferHoistingPass> and
/// <BufferAllocationHoistingState>.
418 struct BufferHoistingPass
419 : public bufferization::impl::BufferHoistingPassBase {
420
421 void runOnOperation() override {
422
423 BufferAllocationHoisting optimizer(
424 getOperation());
425 optimizer.hoist();
426 }
427 };
428
429
/// Pass for loop-based buffer hoisting.
/// NOTE(review): the body of runOnOperation (source line 436) is missing from
/// this listing. Presumably it calls hoistBuffersFromLoops(getOperation()).
430 struct BufferLoopHoistingPass
431 : public bufferization::impl::BufferLoopHoistingPassBase<
432 BufferLoopHoistingPass> {
433
434 void runOnOperation() override {
435
437 }
438 };
439
440
441
/// Pass that promotes allocations accepted by `isSmallAlloc` via
/// BufferPlacementPromotion. A caller-supplied predicate takes precedence;
/// otherwise initialize() installs a default one.
442 class PromoteBuffersToStackPass
443 : public bufferization::impl::PromoteBuffersToStackPassBase<
444 PromoteBuffersToStackPass> {
445 using Base::Base;
446
447 public:
448 explicit PromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc)
449 : isSmallAlloc(std::move(isSmallAlloc)) {}
450
// Installs a default predicate when none was supplied.
// NOTE(review): source line 454 (the start of the call inside the lambda) is
// missing. maxRankOfAllocatedMemRef is not declared in this class and is
// presumably a pass option from the generated base. `[=]` inside a member
// function captures `this` implicitly to read it, which is deprecated since
// C++20. The predicate therefore must not outlive the pass.
451 LogicalResult initialize(MLIRContext *context) override {
452 if (isSmallAlloc == nullptr) {
453 isSmallAlloc = [=](Value alloc) {
455 maxRankOfAllocatedMemRef);
456 };
457 }
458 return success();
459 }
460
461 void runOnOperation() override {
462
463 BufferPlacementPromotion optimizer(getOperation());
464 optimizer.promote(isSmallAlloc);
465 }
466
467 private:
// Predicate deciding which allocations are small enough to promote.
468 std::function<bool(Value)> isSmallAlloc;
469 };
470
471 }
472
// hoistBuffersFromLoops(op) body. The signature (source line 473) is missing
// from this listing, and the template argument of BufferAllocationHoisting
// was stripped; presumably it is BufferAllocationLoopHoistingState.
474 BufferAllocationHoisting optimizer(op);
475 optimizer.hoist();
476 }
477
479 std::function<bool(Value)> isSmallAlloc) {
// createPromoteBuffersToStackPass(isSmallAlloc): moves the predicate into a
// new pass instance. The first signature line (478) and the make_unique
// template argument (presumably PromoteBuffersToStackPass) are missing from
// this listing.
480 return std::make_unique(std::move(isSmallAlloc));
481 }
static bool leavesAllocationScope(Region *parentRegion, const BufferViewFlowAnalysis::ValueSetT &aliases)
Checks whether the given aliases leave the allocation scope.
static bool isKnownControlFlowInterface(Operation *op)
Returns true if the given operation implements a known high-level region-based control-flow interface (LoopLikeOpInterface or RegionBranchOpInterface).
static bool hasAllocationScope(Value alloc, const BufferViewFlowAnalysis &aliasAnalysis)
Checks whether an automatic allocation scope exists for the given alloc value.
static bool isSequentialLoop(Operation *op)
Return whether the given operation is a loop with sequential execution semantics.
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop, by testing whether it implements the LoopLikeOpInterface or, for region-based control flow, whether its regions contain a loop.
static bool allowAllocDominateBlockHoisting(Operation *op)
Returns true if the given operation implements the AllocationOpInterface and supports dominate-block hoisting.
static bool allowAllocLoopHoisting(Operation *op)
Returns true if the given operation implements the AllocationOpInterface and supports loop hoisting.
static bool defaultIsSmallAlloc(Value alloc, unsigned maximumSizeInBytes, unsigned maxRankOfAllocatedMemRef)
Checks whether the size of the allocation is less than or equal to the given size.
Block represents an ordered list of Operations.
bool isEntryBlock()
Returns true if this block is the entry block of its parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
A straightforward alias analysis which ensures that all dependencies of all values will be determined.
ValueSetT resolve(Value value) const
Find all immediate and indirect views upon this value.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
A class for computing basic dominance information.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
A trait of region holding operations that define a new scope for automatic allocations,...
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Block * getBlock()
Returns the operation block that contains this operation.
operand_range getOperands()
Returns an iterator range over the underlying Values.
void moveBefore(Operation *existingOp)
Unlinks this operation from its current block and inserts it right before existingOp, which may be in the same or another block in the same function.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
void erase()
Remove this operation from its parent block and delete it.
A class for computing basic postdominance information.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Returns the region containing this region, or nullptr if the region is attached to a top-level operation.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
static Operation * getStartOperation(Value allocValue, Block *placementBlock, const Liveness &liveness)
Get the start operation to place the given alloc value within the specified placement block.
std::tuple< Value, Operation * > AllocEntry
Represents a tuple of allocValue and deallocOperation.
The base class for all BufferPlacement transformations.
void hoistBuffersFromLoops(Operation *op)
Within the given operation, hoist buffers from loops where possible.
std::unique_ptr< Pass > createPromoteBuffersToStackPass(std::function< bool(Value)> isSmallAlloc)
Creates a pass that promotes heap-based allocations to stack-based ones.
Block * findCommonDominator(Value value, const BufferViewFlowAnalysis::ValueSetT &values, const DominatorT &doms)
Finds a common dominator for the given value while taking the positions of the values in the value set into account.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Include the generated interface declarations.
llvm::DomTreeNodeBase< Block > DominanceInfoNode