MLIR: lib/Transforms/Mem2Reg.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
21#include "llvm/ADT/STLExtras.h"
22#include "llvm/Support/DebugLog.h"
23#include "llvm/Support/GenericIteratedDominanceFrontier.h"
24
25namespace mlir {
26#define GEN_PASS_DEF_MEM2REG
27#include "mlir/Transforms/Passes.h.inc"
28}
29
30#define DEBUG_TYPE "mem2reg"
31
32using namespace mlir;
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99namespace {
100
/// Maps an operation to the set of its operand uses recorded as "blocking"
/// for that operation. computeBlockingUses() fills it, keyed by the owner of
/// each use.
101using BlockingUsesMap =
102 llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;
103
104
105
/// Result of the promotion analysis: produced by
/// MemorySlotPromotionAnalyzer::computeInfo() and moved into
/// MemorySlotPromoter.
/// NOTE(review): other code in this file reads `info.mergePoints`, but the
/// declaration of that member (source line 108) is missing from this listing.
106struct MemorySlotPromotionInfo {
107
109
110
111
112
113
 /// For each operation, the uses that computeBlockingUses() recorded for it.
114 BlockingUsesMap userToBlockingUses;
115};
116
117
118
119
/// Decides whether a memory slot can be promoted and gathers the information
/// needed to do so (see computeInfo()).
/// NOTE(review): this listing dropped several declarations of the class: the
/// constructor signature (lines 122-123), the helpers computeSlotLiveIn /
/// computeMergePoints / areMergePointsUsable that are defined later in the
/// file, and the data members `slot`, `dominance` and `dataLayout` that the
/// initializer list below refers to.
120class MemorySlotPromotionAnalyzer {
121public:
 // Tail of the constructor: stores its three arguments in the members of the
 // same name.
124 : slot(slot), dominance(dominance), dataLayout(dataLayout) {}
125
126
127
 /// Runs the analysis. Returns an empty optional when the slot cannot be
 /// promoted, otherwise the populated promotion info.
 /// NOTE(review): the optional's template argument was lost in extraction.
128 std::optional computeInfo();
129
130private:
131
132
133
134
135
136
137
 /// Fills `userToBlockingUses`; fails if some user cannot have its blocking
 /// uses removed.
138 LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses);
139
140
141
142
143
146
147
148
150
151
152
153
154
156
160};
161
163
164
165
166
/// Performs the actual promotion of one memory slot, using the information
/// computed by the analyzer.
/// NOTE(review): this listing dropped parts of the class: two lines of the
/// constructor parameter list (170, 172) and the declarations of members that
/// the definitions below use (`slot`, `builder`, `dominance`, `dataLayout`,
/// `statistics`, `reachingDefs`, `replacedValuesMap`).
167class MemorySlotPromoter {
168public:
169 MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
171 const DataLayout &dataLayout, MemorySlotPromotionInfo info,
173 BlockIndexCache &blockIndexCache);
174
175
176
177
178
179
 /// Promotes the slot and returns whatever the allocator's
 /// handlePromotionComplete() hook returns.
 /// NOTE(review): the optional's template argument was lost in extraction.
180 std::optional promoteSlot();
181
182private:
183
184
185
186
 /// Scans `block` starting from `reachingDef` and returns the definition of
 /// the slot that reaches the end of the block.
187 Value computeReachingDefInBlock(Block *block, Value reachingDef);
188
189
190
191
192
 /// Propagates reaching definitions through all blocks of `region`, starting
 /// from `reachingDef` at the entry block.
193 void computeReachingDefInRegion(Region *region, Value reachingDef);
194
195
 /// Removes the blocking uses recorded in `info` and erases operations that
 /// ask to be deleted.
196 void removeBlockingUses();
197
198
199
 /// Returns the slot's default value, creating it through the allocator on
 /// first use.
200 Value getOrCreateDefaultValue();
201
 /// Allocation operation that owns the slot being promoted.
203 PromotableAllocationOpInterface allocator;
205
206
 /// Cached default value; null until getOrCreateDefaultValue() creates it.
207 Value defaultValue;
208
209
 /// Analysis results (blocking uses, merge points) for this slot.
214 MemorySlotPromotionInfo info;
216
217
 /// Block-index cache owned by the caller and shared across promotions; used
 /// when sorting operations in removeBlockingUses().
218 BlockIndexCache &blockIndexCache;
219};
220
221}
222
/// Stores all constructor arguments in the corresponding members (`info` is
/// moved) and, in assert-enabled builds, checks that the slot belongs to the
/// allocator.
/// NOTE(review): source line 225 of the parameter list (which must declare
/// `builder`, `dominance` and `dataLayout`, as they are used in the
/// initializer list) is missing from this listing.
223MemorySlotPromoter::MemorySlotPromoter(
224 MemorySlot slot, PromotableAllocationOpInterface allocator,
226 MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics,
227 BlockIndexCache &blockIndexCache)
228 : slot(slot), allocator(allocator), builder(builder), dominance(dominance),
229 dataLayout(dataLayout), info(std::move(info)), statistics(statistics),
230 blockIndexCache(blockIndexCache) {
 // Debug-only sanity check; compiled out when NDEBUG is defined.
231#ifndef NDEBUG
 // NOTE(review): lines 233 and 235 of this lambda are missing, so the
 // definition of `arg` and the other return path are not visible. The visible
 // return accepts a block argument whose owning block's parent op is the
 // allocator.
232 auto isResultOrNewBlockArgument = [&]() {
234 return arg.getOwner()->getParentOp() == allocator;
236 };
237
238 assert(isResultOrNewBlockArgument() &&
239 "a slot must be a result of the allocator or an argument of the child "
240 "regions of the allocator");
241#endif
242}
243
/// Returns the default value of the slot. The value is requested from the
/// allocator the first time and cached in `defaultValue` for later calls.
244Value MemorySlotPromoter::getOrCreateDefaultValue() {
 // Fast path: already created.
245 if (defaultValue)
246 return defaultValue;
247
 // The guard saves the builder's insertion point and restores it on scope
 // exit.
 // NOTE(review): source line 249, between the guard and the return, is
 // missing from this listing; presumably it moves the insertion point before
 // the value is created - confirm against the original file.
248 OpBuilder::InsertionGuard guard(builder);
250 return defaultValue = allocator.getDefaultValue(slot, builder);
251}
252
/// Collects, for every operation involved, the uses that have to be removed
/// for the slot to be promoted, and stores them in `userToBlockingUses`.
/// Returns failure() as soon as an operation is found that cannot handle the
/// removal of its blocking uses.
/// NOTE(review): the template arguments of every dyn_cast/isa in this
/// function were lost in extraction, so the interface types being tested are
/// not visible here.
253LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
254 BlockingUsesMap &userToBlockingUses) {
255
256
257
258
259
260
261
262
263
 // Give up when the region holding the slot pointer is a graph region.
 // NOTE(review): the definition of `slotPtrRegion` (source line 264) is
 // missing from this listing.
265 auto slotPtrRegionOp =
266 dyn_cast(slotPtrRegion->getParentOp());
267 if (slotPtrRegionOp &&
268 slotPtrRegionOp.getRegionKind(slotPtrRegion->getRegionNumber()) ==
269 RegionKind::Graph)
270 return failure();
271
272
273
 // Seed the map: every direct use of the slot pointer is a blocking use of
 // the operation owning that use.
274 for (OpOperand &use : slot.ptr.getUses()) {
275 SmallPtrSet<OpOperand *, 4> &blockingUses =
276 userToBlockingUses[use.getOwner()];
277 blockingUses.insert(&use);
278 }
279
280
281
282
283
284
 // NOTE(review): the computation of `forwardSlice` (source lines 285-286) is
 // missing from this listing.
287 for (Operation *user : forwardSlice) {
288
 // Operations without recorded blocking uses need no handling.
289 auto *it = userToBlockingUses.find(user);
290 if (it == userToBlockingUses.end())
291 continue;
292
293 SmallPtrSet<OpOperand *, 4> &blockingUses = it->second;
294
295 SmallVector<OpOperand *> newBlockingUses;
296
297
 // Ask the user whether its blocking uses can be removed. The user may
 // report further uses through `newBlockingUses`. The first form does not
 // receive the slot, the second one does; an operation matching neither cast
 // makes the promotion fail.
298 if (auto promotable = dyn_cast(user)) {
299 if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses,
300 dataLayout))
301 return failure();
302 } else if (auto promotable = dyn_cast(user)) {
303 if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses,
304 dataLayout))
305 return failure();
306 } else {
307
308
309 return failure();
310 }
311
312
 // Every newly reported use must be a use of one of `user`'s results
 // (asserted). Record it under the operation that owns the use.
313 for (OpOperand *blockingUse : newBlockingUses) {
314 assert(llvm::is_contained(user->getResults(), blockingUse->get()));
315
316 SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
317 userToBlockingUses[blockingUse->getOwner()];
318 newUserBlockingUseSet.insert(blockingUse);
319 }
320 }
321
322
323
324
325
 // Final rejection check over all collected operations.
 // NOTE(review): the second half of this condition (source line 328) and the
 // function's final return statement (line 331) are missing from this
 // listing.
326 for (auto &[toPromote, _] : userToBlockingUses)
327 if (isa(toPromote) &&
329 return failure();
330
332}
333
/// Computes and returns the set of blocks collected in `liveIn`: blocks in
/// which the slot is read before being written, plus the blocks reached from
/// them by the backward propagation below. `definingBlocks` stops that
/// propagation.
334SmallPtrSet<Block *, 16> MemorySlotPromotionAnalyzer::computeSlotLiveIn(
335 SmallPtrSetImpl<Block *> &definingBlocks) {
336 SmallPtrSet<Block *, 16> liveIn;
337
338
339
340
 // Blocks still to be processed by the propagation loop.
341 SmallVector<Block *> liveInWorkList;
342
343
344
345
 // Seed the work list. Each block containing a user of the slot pointer is
 // scanned once (`visited` de-duplicates); the scan stops at the first
 // operation that loads from or stores to the slot.
346 SmallPtrSet<Block *, 16> visited;
347 for (Operation *user : slot.ptr.getUsers()) {
348 if (!visited.insert(user->getBlock()).second)
349 continue;
350
351 for (Operation &op : user->getBlock()->getOperations()) {
 // NOTE(review): the dyn_cast's interface type was lost in extraction.
352 if (auto memOp = dyn_cast(op)) {
353
354
 // A load seen first: the block needs the value on entry.
355 if (memOp.loadsFrom(slot)) {
356 liveInWorkList.push_back(user->getBlock());
357 break;
358 }
359
360
361
 // A store seen first: the block is not seeded.
362 if (memOp.storesTo(slot))
363 break;
364 }
365 }
366 }
367
368
369
 // Propagate: process each block at most once.
370 while (!liveInWorkList.empty()) {
371 Block *liveInBlock = liveInWorkList.pop_back_val();
372
373 if (!liveIn.insert(liveInBlock).second)
374 continue;
375
376
377
378
379
380
381
 // NOTE(review): the loop header introducing `pred` (source line 382) is
 // missing from this listing; presumably it iterates over the predecessors
 // of `liveInBlock` - confirm. Blocks in `definingBlocks` are not enqueued.
383 if (!definingBlocks.contains(pred))
384 liveInWorkList.push_back(pred);
385 }
386
387 return liveIn;
388}
389
/// Computes the merge points of the slot and inserts them into `mergePoints`.
/// NOTE(review): this listing is heavily truncated here. Missing are the
/// parameter list (lines 392-393), the condition guarding the early return,
/// and the declarations of `idfCalculator`, `definingBlocks`, `liveIn`,
/// `mergePointsVec` and the loop introducing `user`.
391void MemorySlotPromotionAnalyzer::computeMergePoints(
394 return;
395
397
 // Blocks containing an operation that stores to the slot are the defining
 // blocks.
 // NOTE(review): the dyn_cast's interface type was lost in extraction.
400 if (auto storeOp = dyn_cast(user))
401 if (storeOp.storesTo(slot))
402 definingBlocks.insert(user->getBlock());
403
404 idfCalculator.setDefiningBlocks(definingBlocks);
405
407 idfCalculator.setLiveInBlocks(liveIn);
408
 // Run the calculator over the defining and live-in blocks; its result
 // becomes the merge-point set.
410 idfCalculator.calculate(mergePointsVec);
411
412 mergePoints.insert_range(mergePointsVec);
413}
414
/// Returns true only if, for every block in `mergePoints`, the terminator of
/// each predecessor passes the isa<> check below; returns false at the first
/// predecessor that does not.
/// NOTE(review): the type tested by isa<> was lost in extraction. Other code
/// in this file calls getSuccessorOperands() on such terminators, so it is
/// presumably a branch-like interface - confirm against the original file.
415bool MemorySlotPromotionAnalyzer::areMergePointsUsable(
416 SmallPtrSetImpl<Block *> &mergePoints) {
417 for (Block *mergePoint : mergePoints)
418 for (Block *pred : mergePoint->getPredecessors())
419 if (!isa(pred->getTerminator()))
420 return false;
421
422 return true;
423}
424
/// Runs the full analysis for the slot. Returns an empty optional when
/// promotion is not possible, otherwise the populated promotion info.
/// NOTE(review): the optional's template argument was lost in extraction.
425std::optional
426MemorySlotPromotionAnalyzer::computeInfo() {
427 MemorySlotPromotionInfo info;
428
429
430
431
432
 // Step 1: collect the blocking uses; give up if some user cannot have them
 // removed.
433 if (failed(computeBlockingUses(info.userToBlockingUses)))
434 return {};
435
436
437
438
 // Step 2: compute the merge points for the slot.
439 computeMergePoints(info.mergePoints);
440
441
442
443
 // Step 3: give up if the merge points are not usable (see
 // areMergePointsUsable()).
444 if (!areMergePointsUsable(info.mergePoints))
445 return {};
446
447 return info;
448}
449
/// Walks the operations of `block` in order, tracking the current definition
/// of the slot starting from `reachingDef`, and returns the definition that
/// holds at the end of the block.
450Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
451 Value reachingDef) {
 // Snapshot the operations into a vector before processing them.
 // NOTE(review): the loop header that introduces `op` (source line 453) is
 // missing from this listing.
452 SmallVector<Operation *> blockOps;
454 blockOps.push_back(&op);
455 for (Operation *op : blockOps) {
 // NOTE(review): the dyn_cast's interface type was lost in extraction.
456 if (auto memOp = dyn_cast(op)) {
 // Remember which definition reaches this operation if it has blocking
 // uses; removeBlockingUses() looks it up later.
457 if (info.userToBlockingUses.contains(memOp))
458 reachingDefs.insert({memOp, reachingDef});
459
 // A store provides a new definition for the operations that follow. It is
 // also recorded in `replacedValuesMap`.
 // NOTE(review): source line 461 is missing from this listing.
460 if (memOp.storesTo(slot)) {
462 Value stored = memOp.getStored(slot, builder, reachingDef, dataLayout);
463 assert(stored && "a memory operation storing to a slot must provide a "
464 "new definition of the slot");
465 reachingDef = stored;
466 replacedValuesMap[memOp] = stored;
467 }
468 }
469 }
470
471 return reachingDef;
472}
473
/// Computes reaching definitions for all blocks of `region`, starting with
/// `reachingDef` (which must be non-null) at the entry block. Blocks are
/// visited through the children lists of dominator-tree nodes, using an
/// explicit stack.
/// NOTE(review): `®ion` below is a mis-encoded `&region` (HTML entity
/// damage from extraction).
474void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
475 Value reachingDef) {
476 assert(reachingDef && "expected an initial reaching def to be provided");
 // Shortcut that handles only the entry block and returns.
 // NOTE(review): the `if` opening this branch (source line 477) is missing
 // from this listing.
478 computeReachingDefInBlock(®ion->front(), reachingDef);
479 return;
480 }
481
 // One unit of work: a dominator-tree node and the definition reaching the
 // start of its block.
482 struct DfsJob {
483 llvm::DomTreeNodeBase *block;
484 Value reachingDef;
485 };
486
487 SmallVector dfsStack;
488
 // NOTE(review): the declaration of `domTree` (source line 489) is missing
 // from this listing.
490
491 dfsStack.emplace_back(
492 {domTree.getNode(®ion->front()), reachingDef});
493
494 while (!dfsStack.empty()) {
495 DfsJob job = dfsStack.pop_back_val();
496 Block *block = job.block->getBlock();
497
 // At a merge point the reaching definition becomes a block argument, which
 // is also reported to the allocator.
 // NOTE(review): the initializer of `blockArgument` (lines 500-501) and
 // lines 505-506 are missing from this listing.
498 if (info.mergePoints.contains(block)) {
499 BlockArgument blockArgument =
502 allocator.handleBlockArgument(slot, blockArgument, builder);
503 job.reachingDef = blockArgument;
504
507 }
508
509 job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
510 assert(job.reachingDef);
511
 // Forward the definition at the end of this block to every successor that
 // is a merge point, as an extra successor operand.
 // NOTE(review): the dyn_cast's interface type was lost in extraction.
512 if (auto terminator = dyn_cast(block->getTerminator())) {
513 for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
514 if (info.mergePoints.contains(blockOperand.get())) {
515 terminator.getSuccessorOperands(blockOperand.getOperandNumber())
516 .append(job.reachingDef);
517 }
518 }
519 }
520
 // Children in the dominator tree inherit the definition at the end of this
 // block.
521 for (auto *child : job.block->children())
522 dfsStack.emplace_back({child, job.reachingDef});
523 }
524}
525
526
// Body of getOrCreateBlockIndices(BlockIndexCache &, Region *): returns the
// block-to-index map cached for `region`, filling it on first request.
// NOTE(review): the function header (source lines 527-528), the check between
// try_emplace and the early return (line 530), and the definitions of
// `blockIndices` and `topologicalOrder` (lines 533-534) are missing from this
// listing.
529 auto [it, inserted] = blockIndexCache.try_emplace(region);
 // Return the cached map (the condition for this path is not visible here).
531 return it->second;
532
 // Number the blocks by their position in `topologicalOrder`.
535 for (auto [index, block] : llvm::enumerate(topologicalOrder))
536 blockIndices[block] = index;
537 return blockIndices;
538}
539
540
541
542
// Tail of dominanceSort(ops, region, blockIndexCache): orders operations
// first by the index of their block, then by position within the block.
// NOTE(review): the first line of the signature (source line 543), the lookup
// of `topoBlockIndices` and the opening of the sort call with its comparator
// lambda (lines 547-553) are missing from this listing.
544 BlockIndexCache &blockIndexCache) {
545
546
549
550
551
552
 // Comparator body: `lhs` sorts before `rhs` when its block index is smaller.
 // Operations in the same block are ordered by isBeforeInBlock().
554 size_t lhsBlockIndex = topoBlockIndices.at(lhs->getBlock());
555 size_t rhsBlockIndex = topoBlockIndices.at(rhs->getBlock());
556 if (lhsBlockIndex == rhsBlockIndex)
557 return lhs->isBeforeInBlock(rhs);
558 return lhsBlockIndex < rhsBlockIndex;
559 });
560}
561
/// Asks every operation recorded in `info.userToBlockingUses` to remove its
/// blocking uses, then erases the operations that requested deletion.
/// NOTE(review): the template arguments of dyn_cast/cast and of `toVisit`
/// were lost in extraction.
562void MemorySlotPromoter::removeBlockingUses() {
 // Collect the operations (keys of the map).
563 llvm::SmallVector<Operation *> usersToRemoveUses(
564 llvm::make_first_range(info.userToBlockingUses));
565
566
 // NOTE(review): the start of this call (source line 567) is missing from
 // this listing; presumably a dominanceSort() of `usersToRemoveUses` -
 // confirm.
568 blockIndexCache);
569
 // Operations to erase once all uses have been handled.
570 llvm::SmallVector<Operation *> toErase;
571
 // (store operation, value it stored) pairs handed to visitReplacedValues().
572 llvm::SmallVector<std::pair<Operation *, Value>> replacedValuesList;
573
 // Operations that asked to be told about replaced values.
574 llvm::SmallVector toVisit;
 // Process the operations in reverse of the order established above.
575 for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
576 if (auto toPromoteMemOp = dyn_cast(toPromote)) {
577 Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
578
579
 // No reaching definition was recorded: fall back to the default value.
580 if (!reachingDef)
581 reachingDef = getOrCreateDefaultValue();
582
 // NOTE(review): source line 583 is missing from this listing.
584 if (toPromoteMemOp.removeBlockingUses(
585 slot, info.userToBlockingUses[toPromote], builder, reachingDef,
586 dataLayout) == DeletionKind::Delete)
587 toErase.push_back(toPromote);
 // For stores with a recorded replacement value, remember the pair.
588 if (toPromoteMemOp.storesTo(slot))
589 if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
590 replacedValuesList.push_back({toPromoteMemOp, replacedValue});
591 continue;
592 }
593
 // All remaining operations must satisfy the cast below.
 // NOTE(review): source line 595 is missing from this listing.
594 auto toPromoteBasic = cast(toPromote);
596 if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
597 builder) == DeletionKind::Delete)
598 toErase.push_back(toPromote);
599 if (toPromoteBasic.requiresReplacedValues())
600 toVisit.push_back(toPromoteBasic);
601 }
 // Notify the interested operations of all replaced values.
 // NOTE(review): source line 603 is missing from this listing.
602 for (PromotableOpInterface op : toVisit) {
604 op.visitReplacedValues(replacedValuesList, builder);
605 }
606
 // Erase only after every operation has been processed.
607 for (Operation *toEraseOp : toErase)
608 toEraseOp->erase();
609
 // NOTE(review): the start of this statement (source line 610) is missing
 // from this listing; the message suggests an assert that the slot pointer
 // has no remaining uses.
611 "after promotion, the slot pointer should not be used anymore");
612}
613
/// Promotes the slot: removes the blocking uses, completes the successor
/// operands of branches into merge points, and returns the result of the
/// allocator's handlePromotionComplete() hook.
/// NOTE(review): the optional's template argument was lost in extraction.
614std::optional
615MemorySlotPromoter::promoteSlot() {
 // NOTE(review): the start of this call (source line 616) is missing from
 // this listing. The default value is passed as its last argument, so it is
 // presumably the computeReachingDefInRegion() call seeding the analysis -
 // confirm.
617 getOrCreateDefaultValue());
618
619
620 removeBlockingUses();
621
622
623
 // Visit every branch into a merge point and append the default value to its
 // successor operands.
 // NOTE(review): source lines 629-631 are missing from this listing, so the
 // append below may be guarded by a condition that is not shown; the cast's
 // interface type was also lost in extraction.
624 for (Block *mergePoint : info.mergePoints) {
625 for (BlockOperand &use : mergePoint->getUses()) {
626 auto user = cast(use.getOwner());
627 SuccessorOperands succOperands =
628 user.getSuccessorOperands(use.getOperandNumber());
632 succOperands.append(getOrCreateDefaultValue());
633 }
634 }
635
636 LDBG() << "Promoted memory slot: " << slot.ptr;
637
 // NOTE(review): source lines 638-639 are missing from this listing.
640
 // `defaultValue` is passed as-is, so it is null if no default was created.
641 return allocator.handlePromotionComplete(slot, defaultValue, builder);
642}
643
// Body of mlir::tryToPromoteMemorySlots(allocators, builder, dataLayout,
// dominance, statistics): repeatedly tries to promote slots of the given
// allocators until a full round makes no change. Returns success if at least
// one round promoted something.
// NOTE(review): the function header (source lines 644-647) and the
// declarations of `workList` / `newWorkList` (lines 655-657) are missing from
// this listing.
648 bool promotedAny = false;
649
650
651
652
 // One block-index cache shared by every promotion in this call.
653 BlockIndexCache blockIndexCache;
654
656
658 newWorkList.reserve(workList.size());
 // Fixed-point loop: one iteration per round.
659 while (true) {
660 bool changesInThisRound = false;
661 for (PromotableAllocationOpInterface allocator : workList) {
662 bool changedAllocator = false;
663 for (MemorySlot slot : allocator.getPromotableSlots()) {
 // NOTE(review): the condition guarding this `continue` (source line 664)
 // is missing from this listing.
665 continue;
666
667 MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
668 std::optional info = analyzer.computeInfo();
669 if (info) {
670 std::optional newAllocator =
671 MemorySlotPromoter(slot, allocator, builder, dominance,
672 dataLayout, std::move(*info), statistics,
673 blockIndexCache)
674 .promoteSlot();
675 changedAllocator = true;
676
677
 // If promotion produced an allocator, revisit it in the next round.
678 if (newAllocator)
679 newWorkList.push_back(*newAllocator);
680
681
682
 // Stop looking at further slots of this allocator after a promotion.
683 break;
684 }
685 }
 // Allocators with nothing promoted this round stay on the work list.
686 if (!changedAllocator)
687 newWorkList.push_back(allocator);
688 changesInThisRound |= changedAllocator;
689 }
690 if (!changesInThisRound)
691 break;
692 promotedAny = true;
693
694
695
 // The next round works on the list built during this one.
696 workList.swap(newWorkList);
697 newWorkList.clear();
698 }
699
700 return success(promotedAny);
701}
702
703namespace {
704
707
 // Pass entry point: for each region of the operation the pass runs on,
 // collects the allocators nested in it and tries to promote their slots.
 // NOTE(review): several lines are missing from this listing (713, 720, 732,
 // 734, 736), and `®ion` below is a mis-encoded `&region` (HTML entity
 // damage); the template arguments of getAnalysis / SmallVector were lost too.
708 void runOnOperation() override {
709 Operation *scopeOp = getOperation();
710
 // Counters the promotion code updates through pointers.
711 Mem2RegStatistics statistics{&promotedAmount, &newBlockArgumentAmount};
712
714
 // Use the data layout that applies at (or above) the scope operation.
715 auto &dataLayoutAnalysis = getAnalysis();
716 const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
717 auto &dominance = getAnalysis();
718
719 for (Region ®ion : scopeOp->getRegions()) {
 // NOTE(review): the condition guarding this `continue` (source line 720)
 // is missing from this listing.
721 continue;
722
 // Builder positioned at the start of the region's first block.
723 OpBuilder builder(®ion.front(), region.front().begin());
724
725 SmallVector allocators;
726
 // Gather every allocator nested anywhere in the region.
727 region.walk([&](PromotableAllocationOpInterface allocator) {
728 allocators.emplace_back(allocator);
729 });
730
731
 // NOTE(review): the start of this statement (source line 732) and its body
 // (line 734) are missing; the visible tail passes `dominance` and
 // `statistics` and closes three parentheses.
733 dominance, statistics)))
735 }
 // NOTE(review): source line 736 is missing, so this call may be conditional
 // in the original file.
737 markAllAnalysesPreserved();
738 }
739};
740
741}
If copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method
static void dominanceSort(SmallVector< Operation * > &ops, Region &region, BlockIndexCache &blockIndexCache)
Sorts ops according to dominance.
Definition Mem2Reg.cpp:543
static const DenseMap< Block *, size_t > & getOrCreateBlockIndices(BlockIndexCache &blockIndexCache, Region *region)
Gets or creates a block index mapping for region.
Definition Mem2Reg.cpp:528
llvm::IDFCalculatorBase< Block, false > IDFCalculator
Definition Mem2Reg.cpp:390
This class represents an argument of a Block.
Block represents an ordered list of Operations.
unsigned getNumArguments()
iterator_range< pred_iterator > getPredecessors()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
The main mechanism for performing data layout queries.
A class for computing basic dominance information.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
bool hasOneBlock()
Return true if this region has exactly one block.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
void append(ValueRange valueRange)
Add new operands that are forwarded to the successor.
unsigned size() const
Returns the amount of operands passed to the successor.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
DomTree & getDomTree(Region *region) const
Definition Mem2Reg.cpp:831
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
llvm::SetVector< T, Vector, Set, N > SetVector
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
LogicalResult tryToPromoteMemorySlots(ArrayRef< PromotableAllocationOpInterface > allocators, OpBuilder &builder, const DataLayout &dataLayout, DominanceInfo &dominance, Mem2RegStatistics statistics={})
Attempts to promote the memory slots of the provided allocators.
Definition Mem2Reg.cpp:644
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
Statistics collected while applying mem2reg.
llvm::Statistic * promotedAmount
Total amount of memory slots promoted.
llvm::Statistic * newBlockArgumentAmount
Total amount of new block arguments inserted in blocks.
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.