LLVM: lib/Transforms/Scalar/LoopLoadElimination.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
56#include
57#include
58#include <forward_list>
59#include
60#include
61
62using namespace llvm;
63
// Command-line name of this pass; also used as the LLVM_DEBUG type.
64#define LLE_OPTION "loop-load-elim"
65#define DEBUG_TYPE LLE_OPTION
66
// CheckPerElim ("runtime-check-per-loop-load-elim", default 1 per the page
// tooltip): average number of memchecks allowed per eliminated load.
// NOTE(review): the declaration (line 67) and cl::init (line 70) are not
// captured in this listing; processLoop presumably compares its collected
// memchecks against this budget on line 566 (also not captured).
68 "runtime-check-per-loop-load-elim", cl::Hidden,
69 cl::desc("Max number of memchecks allowed per eliminated load on average"),
71
// LoadElimSCEVCheckThreshold (default 8): upper bound on the SCEV predicate
// complexity processLoop tolerates before bailing out with "Too many SCEV
// run-time checks needed." (declaration line 72 not captured).
73 "loop-load-elimination-scev-check-threshold", cl::init(8), cl::Hidden,
74 cl::desc("The maximum number of SCEV checks allowed for Loop "
75 "Load Elimination"));
76
// Incremented by the number of forwarded candidates at the end of
// processLoop.
// NOTE(review): the identifier misspells "Eliminated"; renaming it would also
// require updating its use in processLoop.
77STATISTIC(NumLoopLoadEliminted, "Number of loads eliminated by LLE");
78
79namespace {
80
81
// A store/load pair in the loop where the value written by Store may be
// forwarded to Load across the loop backedge.
// NOTE(review): the member declarations (lines 83-84) and the constructor
// signature (line 86) are not captured in this listing. The initializer below
// and the uses elsewhere (Cand.Load used as a LoadInst * map key, ->Store
// yielding a StoreInst *) suggest members `LoadInst *Load` and
// `StoreInst *Store`.
82struct StoreToLoadForwardingCandidate {
85
87 : Load(Load), Store(Store) {}
88
89
90
91
  // Returns true only when Store writes the location Load reads exactly one
  // stride apart:
  //  * both pointers must report the same non-zero constant stride from
  //    getPtrStride, and that stride must be +1 or -1 access-type units;
  //  * the constant distance Dist (computed on lines 124-130, not captured
  //    here) must equal TypeByteSize * StrideLoad bytes.
  // A non-constant distance (Dist == nullptr) also yields false.
  // NOTE(review): LoadType is declared on line 96 (not captured), presumably
  // via getLoadStoreType(Load). The assert on lines 99-103 is only partly
  // captured: it compares DL.getTypeSizeInBits(LoadType) against a value on
  // a missing line (presumably the store type's size).
92 bool isDependenceDistanceOfOne(PredicatedScalarEvolution &PSE, Loop *L,
93 const DominatorTree &DT) const {
94 Value *LoadPtr = Load->getPointerOperand();
95 Value *StorePtr = Store->getPointerOperand();
97 auto &DL = Load->getDataLayout();
98
101 DL.getTypeSizeInBits(LoadType) ==
103 "Should be a known dependence");
104
  // getPtrStride returns an optional stride in units of the access type;
  // value_or(0) folds "no constant stride" into the zero check below.
105 int64_t StrideLoad =
106 getPtrStride(PSE, LoadType, LoadPtr, L, DT).value_or(0);
107 int64_t StrideStore =
108 getPtrStride(PSE, LoadType, StorePtr, L, DT).value_or(0);
109 if (!StrideLoad || !StrideStore || StrideLoad != StrideStore)
110 return false;
111
112
113
114
115
116
117
118
  // Only unit strides (one element forward or backward per iteration) are
  // handled.
119 if (std::abs(StrideLoad) != 1)
120 return false;
121
122 unsigned TypeByteSize = DL.getTypeAllocSize(LoadType);
123
126
127
128
131 if (!Dist)
132 return false;
133 const APInt &Val = Dist->getAPInt();
  // NOTE(review): TypeByteSize (unsigned) * StrideLoad (int64_t) evaluates to
  // an int64_t, which is negative when StrideLoad == -1, and is compared
  // against an APInt. Confirm that APInt equality with a 64-bit integer
  // accepts a negative distance when Dist's bit width is below 64 (e.g.
  // 32-bit index types).
134 return Val == TypeByteSize * StrideLoad;
135 }
136
  // Pointer operand of the candidate's load; collectMemchecks gathers these
  // into CandLoadPtrs.
137 Value *getLoadPtr() const { return Load->getPointerOperand(); }
138
139#ifndef NDEBUG
  // Debug-only printer: the store followed by " -->", then the load on an
  // indented second line.
140 friend raw_ostream &operator<<(raw_ostream &OS,
141 const StoreToLoadForwardingCandidate &Cand) {
142 OS << *Cand.Store << " -->\n";
143 OS.indent(2) << *Cand.Load << "\n";
144 return OS;
145 }
146#endif
147};
148
149}
150
151
152
// doesStoreDominatesAllLatches(BasicBlock *StoreBlock, Loop *L,
// DominatorTree *DT): the signature (lines 153-155) and line 157 are not
// captured in this listing; the signature comes from the page tooltip.
// Gathers the loop's latch blocks and tests DT->dominates(StoreBlock, Latch)
// for each one. The predicate is presumably combined with all_of on line 157,
// so the function returns true only if the store's block dominates every
// latch.
156 L->getLoopLatches(Latches);
158 return DT->dominates(StoreBlock, Latch);
159 });
160}
161
162
// isLoadConditional(LoadInst *Load, Loop *L): the signature (line 163) is not
// captured. The tooltip describes it as "not executed on all paths in the
// loop". The code approximates that by returning true for any load whose
// block is not the loop header.
164 return Load->getParent() != L->getHeader();
165}
166
167namespace {
168
169
// Performs store-to-load forwarding for a single loop L. It finds loads that
// read a value stored on the previous iteration (dependence distance of one),
// optionally versions the loop behind runtime memchecks / SCEV predicates,
// and feeds such loads from a header PHI carrying the stored value.
// NOTE(review): many source lines of this class are missing from this listing
// (only their line numbers survive); the comments describe only the visible
// lines.
170class LoadEliminationForLoop {
171public:
  // Caches the analyses. PSE is a by-value copy of LAI.getPSE().
172 LoadEliminationForLoop(Loop *L, LoopInfo *LI, const LoopAccessInfo &LAI,
173 DominatorTree *DT, BlockFrequencyInfo *BFI,
174 ProfileSummaryInfo* PSI)
175 : L(L), LI(LI), LAI(LAI), DT(DT), BFI(BFI), PSI(PSI), PSE(LAI.getPSE()) {}
176
177
178
179
180
181
  // Builds store->load forwarding candidates from LAI's dependence list:
  //  * returns an empty list if LAI recorded no dependences (Deps == nullptr);
  //  * lines 201-203 (not captured) decide which dependences add both their
  //    Source and Destination to LoadsWithUnknownDependence;
  //  * other dependences must be backward or forward (asserted). A candidate
  //    is kept only when a StoreInst and a LoadInst can be extracted (lines
  //    214-222, partly captured) and a further type check passes (lines
  //    226-228, presumably isBitOrNoopPointerCastable);
  //  * finally, every candidate whose load is in LoadsWithUnknownDependence is
  //    dropped.
  // Candidates are added with emplace_front, so the result is in reverse
  // dependence order.
182 std::forward_list
183 findStoreToLoadDependences(const LoopAccessInfo &LAI) {
184 std::forward_list Candidates;
185
186 const auto &DepChecker = LAI.getDepChecker();
187 const auto *Deps = DepChecker.getDependences();
188 if (!Deps)
189 return Candidates;
190
191
192
193
194
195 SmallPtrSet<Instruction *, 4> LoadsWithUnknownDependence;
196
197 for (const auto &Dep : *Deps) {
199 Instruction *Destination = Dep.getDestination(DepChecker);
200
204 LoadsWithUnknownDependence.insert(Source);
206 LoadsWithUnknownDependence.insert(Destination);
207 continue;
208 }
209
210 if (Dep.isBackward())
211
212
213
215 else
216 assert(Dep.isForward() && "Needs to be a forward dependence");
217
219 if (!Store)
220 continue;
222 if (!Load)
223 continue;
224
225
228 Store->getDataLayout()))
229 continue;
230
231 Candidates.emplace_front(Load, Store);
232 }
233
234 if (!LoadsWithUnknownDependence.empty())
235 Candidates.remove_if([&](const StoreToLoadForwardingCandidate &C) {
236 return LoadsWithUnknownDependence.count(C.Load);
237 });
238
239 return Candidates;
240 }
241
242
  // Position of Inst in InstOrder; asserts that an index exists. InstOrder is
  // filled in processLoop from generateInstructionOrderMap(), so this must
  // only be called after that point.
243 unsigned getInstrIndex(Instruction *Inst) {
244 auto I = InstOrder.find(Inst);
245 assert(I != InstOrder.end() && "No index for instruction");
246 return I->second;
247 }
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
  // Leaves at most one candidate per load. When several candidates share a
  // load and all of their stores are in the same block with dependence
  // distance one, the candidate whose store has the highest instruction index
  // wins. Otherwise the load's entry becomes nullptr (permanently) and all of
  // its candidates are removed.
268 void removeDependencesFromMultipleStores(
269 std::forward_list &Candidates) {
270
271
272 using LoadToSingleCandT =
273 DenseMap<LoadInst *, const StoreToLoadForwardingCandidate *>;
274 LoadToSingleCandT LoadToSingleCand;
275
276 for (const auto &Cand : Candidates) {
277 bool NewElt;
278 LoadToSingleCandT::iterator Iter;
279
280 std::tie(Iter, NewElt) =
281 LoadToSingleCand.insert(std::make_pair(Cand.Load, &Cand));
282 if (!NewElt) {
    // Reference into the map, so the assignments below update the entry in
    // place.
283 const StoreToLoadForwardingCandidate *&OtherCand = Iter->second;
284
285 if (OtherCand == nullptr)
286 continue;
287
288
289
290
291 if (Cand.Store->getParent() == OtherCand->Store->getParent() &&
292 Cand.isDependenceDistanceOfOne(PSE, L, *DT) &&
293 OtherCand->isDependenceDistanceOfOne(PSE, L, *DT)) {
294
295 if (getInstrIndex(OtherCand->Store) < getInstrIndex(Cand.Store))
296 OtherCand = &Cand;
297 } else
298 OtherCand = nullptr;
299 }
300 }
301
  // operator[] inserts nothing here: every Cand.Load was inserted by the loop
  // above. The LLVM_DEBUG( opener on line 304 is not captured.
302 Candidates.remove_if([&](const StoreToLoadForwardingCandidate &Cand) {
303 if (LoadToSingleCand[Cand.Load] != &Cand) {
305 dbgs() << "Removing from candidates: \n"
306 << Cand
307 << " The load may have multiple stores forwarding to "
308 << "it\n");
309 return true;
310 }
311 return false;
312 });
313 }
314
315
316
317
318
319
  // True if the runtime check between pointers PtrIdx1 and PtrIdx2 matters
  // for forwarding: one pointer is written on the forwarding path and the
  // other is a candidate load pointer, in either order. Ptr1/Ptr2 are
  // declared on lines 323/325 (not captured).
320 bool needsChecking(unsigned PtrIdx1, unsigned PtrIdx2,
321 const SmallPtrSetImpl<Value *> &PtrsWrittenOnFwdingPath,
322 const SmallPtrSetImpl<Value *> &CandLoadPtrs) {
324 LAI.getRuntimePointerChecking()->getPointerInfo(PtrIdx1).PointerValue;
326 LAI.getRuntimePointerChecking()->getPointerInfo(PtrIdx2).PointerValue;
327 return ((PtrsWrittenOnFwdingPath.count(Ptr1) && CandLoadPtrs.count(Ptr2)) ||
328 (PtrsWrittenOnFwdingPath.count(Ptr2) && CandLoadPtrs.count(Ptr1)));
329 }
330
331
332
333
334
  // Collects S->getPointerOperand() for instructions passed to InsertStorePtr
  // in two ranges of the dependence checker's memory instructions:
  //  * from just after FirstStore (the candidate store with the lowest
  //    instruction index) to the end;
  //  * from the beginning up to, but excluding, LastLoad (the candidate load
  //    with the highest instruction index).
  // NOTE(review): the max/min selection calls (lines 355/363) and the
  // InsertStorePtr lambda header (lines 376-377) are not captured. They are
  // presumably max_element/min_element and a filter that records only
  // StoreInsts. &MemInstrs[...] serves as an end iterator, which relies on
  // MemInstrs' iterators being plain pointers.
335 SmallPtrSet<Value *, 4> findPointersWrittenOnForwardingPath(
336 const SmallVectorImpl &Candidates) {
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354 LoadInst *LastLoad =
356 [&](const StoreToLoadForwardingCandidate &A,
357 const StoreToLoadForwardingCandidate &B) {
358 return getInstrIndex(A.Load) <
359 getInstrIndex(B.Load);
360 })
361 ->Load;
362 StoreInst *FirstStore =
364 [&](const StoreToLoadForwardingCandidate &A,
365 const StoreToLoadForwardingCandidate &B) {
366 return getInstrIndex(A.Store) <
367 getInstrIndex(B.Store);
368 })
369 ->Store;
370
371
372
373
374 SmallPtrSet<Value *, 4> PtrsWrittenOnFwdingPath;
375
378 PtrsWrittenOnFwdingPath.insert(S->getPointerOperand());
379 };
380 const auto &MemInstrs = LAI.getDepChecker().getMemoryInstructions();
381 std::for_each(MemInstrs.begin() + getInstrIndex(FirstStore) + 1,
382 MemInstrs.end(), InsertStorePtr);
383 std::for_each(MemInstrs.begin(), &MemInstrs[getInstrIndex(LastLoad)],
384 InsertStorePtr);
385
386 return PtrsWrittenOnFwdingPath;
387 }
388
389
390
  // Keeps only those runtime pointer checks from LAI where needsChecking
  // holds for some member pair of the two checked pointer groups. The
  // copy_if lambda header (line 406) and the start of the debug print
  // (line 415) are not captured.
391 SmallVector<RuntimePointerCheck, 4> collectMemchecks(
392 const SmallVectorImpl &Candidates) {
393
394 SmallPtrSet<Value *, 4> PtrsWrittenOnFwdingPath =
395 findPointersWrittenOnForwardingPath(Candidates);
396
397
398 SmallPtrSet<Value *, 4> CandLoadPtrs;
399 for (const auto &Candidate : Candidates)
400 CandLoadPtrs.insert(Candidate.getLoadPtr());
401
402 const auto &AllChecks = LAI.getRuntimePointerChecking()->getChecks();
403 SmallVector<RuntimePointerCheck, 4> Checks;
404
405 copy_if(AllChecks, std::back_inserter(Checks),
407 for (auto PtrIdx1 : Check.first->Members)
408 for (auto PtrIdx2 : Check.second->Members)
409 if (needsChecking(PtrIdx1, PtrIdx2, PtrsWrittenOnFwdingPath,
410 CandLoadPtrs))
411 return true;
412 return false;
413 });
414
416 << "):\n");
417 LLVM_DEBUG(LAI.getRuntimePointerChecking()->printChecks(dbgs(), Checks));
418
419 return Checks;
420 }
421
422
  // Carries the stored value across the backedge for one candidate:
  //  * expands the start of the load pointer's SCEV (PtrSCEV/Ptr, declared on
  //    lines 441-442, not captured) at the preheader terminator, and emits a
  //    "load_initial" load there with the original load's type and alignment;
  //  * inserts a PHI at the top of the loop header whose incoming values are
  //    the initial load (from the preheader) and StoreValue (from the latch);
  //  * when the load and store types differ (their sizes are asserted equal),
  //    first builds a "store_forward_cast" of the stored value.
  // NOTE(review): the PHI/StoreValue declarations and the final rewrite of
  // the load's uses (lines 482-483) are not captured in this listing.
423 void
424 propagateStoredValueToLoadUsers(const StoreToLoadForwardingCandidate &Cand,
425 SCEVExpander &SEE) {
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
443 auto *PH = L->getLoopPreheader();
444 assert(PH && "Preheader should exist!");
445 Value *InitialPtr = SEE.expandCodeFor(PtrSCEV->getStart(), Ptr->getType(),
446 PH->getTerminator());
448 new LoadInst(Cand.Load->getType(), InitialPtr, "load_initial",
449 false, Cand.Load->getAlign(),
450 PH->getTerminator()->getIterator());
451
452
453
454
456
458 PHI->insertBefore(L->getHeader()->begin());
459 PHI->addIncoming(Initial, PH);
460
464 (void)DL;
465
466 assert(DL.getTypeSizeInBits(LoadType) == DL.getTypeSizeInBits(StoreType) &&
467 "The type sizes should match!");
468
470 if (LoadType != StoreType) {
472 "store_forward_cast",
474
475
476
478 }
479
480 PHI->addIncoming(StoreValue, L->getLoopLatch());
481
484 }
485
486
487
  // Runs the transformation on L; returns true if any load was forwarded.
  // Visible steps:
  //  1. gather store->load dependences; bail out if there are none;
  //  2. build InstOrder, then drop loads that have several candidate stores;
  //  3. keep candidates with dependence distance one (further filters on
  //     lines 528-539 are not captured) in Candidates;
  //  4. collect the needed memchecks, then give up if there are too many, if
  //     the SCEV predicate is too complex, or if L is not in loop-simplify
  //     form;
  //  5. if memchecks or SCEV predicates are required, refuse when the loop
  //     has convergent ops or the header should be optimized for size;
  //     otherwise version the loop with LoopVersioning;
  //  6. forward each remaining candidate and bump the statistic.
488 bool processLoop() {
489 LLVM_DEBUG(dbgs() << "\nIn \"" << L->getHeader()->getParent()->getName()
490 << "\" checking " << *L << "\n");
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511 auto StoreToLoadDependences = findStoreToLoadDependences(LAI);
512 if (StoreToLoadDependences.empty())
513 return false;
514
515
516
517 InstOrder = LAI.getDepChecker().generateInstructionOrderMap();
518
519
520
521 removeDependencesFromMultipleStores(StoreToLoadDependences);
522 if (StoreToLoadDependences.empty())
523 return false;
524
525
527 for (const StoreToLoadForwardingCandidate &Cand : StoreToLoadDependences) {
529
530
531
533 continue;
534
535
536
537
539 continue;
540
541
542
543 if (!Cand.isDependenceDistanceOfOne(PSE, L, *DT))
544 continue;
545
547 "Loading from something other than indvar?");
550 "Storing to something other than indvar?");
551
555 << Candidates.size()
556 << ". Valid store-to-load forwarding across the loop backedge\n");
557 }
558 if (Candidates.empty())
559 return false;
560
561
562
563 SmallVector<RuntimePointerCheck, 4> Checks = collectMemchecks(Candidates);
564
565
567 LLVM_DEBUG(dbgs() << "Too many run-time checks needed.\n");
568 return false;
569 }
570
571 if (LAI.getPSE().getPredicate().getComplexity() >
573 LLVM_DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n");
574 return false;
575 }
576
  // NOTE(review): this debug message reads "not is" (probably meant "not
  // in"), and unlike the neighboring messages it lacks a trailing "\n".
577 if (!L->isLoopSimplifyForm()) {
578 LLVM_DEBUG(dbgs() << "Loop is not is loop-simplify form");
579 return false;
580 }
581
582 if (!Checks.empty() || !LAI.getPSE().getPredicate().isAlwaysTrue()) {
583 if (LAI.hasConvergentOp()) {
584 LLVM_DEBUG(dbgs() << "Versioning is needed but not allowed with "
585 "convergent calls\n");
586 return false;
587 }
588
589 auto *HeaderBB = L->getHeader();
591 PGSOQueryType::IRPass)) {
593 dbgs() << "Versioning is needed but not allowed when optimizing "
594 "for size.\n");
595 return false;
596 }
597
598
599
600
601 LoopVersioning LV(LAI, Checks, L, LI, DT, PSE.getSE());
602 LV.versionLoop();
603
604
605
    // NOTE(review): the lambda body (lines 608-611) and its use (line 613)
    // are not captured; presumably it prunes candidates invalidated by
    // versioning.
606 auto NoLongerGoodCandidate = [this](
607 const StoreToLoadForwardingCandidate &Cand) {
612 };
614 }
615
616
617
618 SCEVExpander SEE(*PSE.getSE(), L->getHeader()->getDataLayout(),
619 "storeforward");
620 for (const auto &Cand : Candidates)
621 propagateStoredValueToLoadUsers(Cand, SEE);
622 NumLoopLoadEliminted += Candidates.size();
623
624 return true;
625 }
626
627private:
  // The loop being transformed.
628 Loop *L;
629
630
631
  // Maps each memory instruction to its position in the dependence checker's
  // instruction order. Populated in processLoop and read via getInstrIndex.
632 DenseMap<Instruction *, unsigned> InstOrder;
633
634
635 LoopInfo *LI;
636 const LoopAccessInfo &LAI;
637 DominatorTree *DT;
638 BlockFrequencyInfo *BFI;
639 ProfileSummaryInfo *PSI;
  // Copy of LAI.getPSE() made in the constructor.
640 PredicatedScalarEvolution PSE;
641};
642
643}
644
// eliminateLoadsAcrossLoops(Function &F, LoopInfo &LI, DominatorTree &DT,
//     BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, ScalarEvolution *SE,
//     AssumptionCache *AC, LoopAccessInfoManager &LAIs)
// Defined at line 645 per the page tooltip; its opening lines are not
// captured in this listing. It builds a Worklist of innermost loops (the
// traversal on lines 661-665 is only partly captured). It then runs
// LoadEliminationForLoop::processLoop on each loop that is in rotated form
// and whose getExitingBlock() is non-null, OR-ing the results into Changed.
651
652
653
654
655
657
659
660 for (Loop *TopLevelLoop : LI)
663
664 if (L->isInnermost())
666 }
667
668
669 for (Loop *L : Worklist) {
670
671 if (!L->isRotatedForm() || !L->getExitingBlock())
672 continue;
673
674 LoadEliminationForLoop LEL(L, &LI, LAIs.getInfo(*L), &DT, BFI, PSI);
675 Changed |= LEL.processLoop();
678 }
680}
681
// PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
// Defined at line 682 per the page tooltip (presumably
// LoopLoadEliminationPass::run); most of its body is not captured. Visible
// behavior: it exits early when the function has no loops (LI.empty(); the
// returned value on line 688 is not captured), and it takes a
// BlockFrequencyInfo only when PSI exists and has a profile summary.
685
686
687 if (LI.empty())
694 auto *BFI = (PSI && PSI->hasProfileSummary()) ?
697
699
702
706 return PA;
707}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
This file implements a class to represent arbitrary precision integral constant values and operations...
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build a generic depth-first graph iterator.
This is the interface for a simple mod/ref and alias analysis over globals.
This header defines various interfaces for pass management in LLVM.
This header provides classes for managing per-loop analyses.
static bool eliminateLoadsAcrossLoops(Function &F, LoopInfo &LI, DominatorTree &DT, BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, ScalarEvolution *SE, AssumptionCache *AC, LoopAccessInfoManager &LAIs)
Definition LoopLoadElimination.cpp:645
static cl::opt< unsigned > LoadElimSCEVCheckThreshold("loop-load-elimination-scev-check-threshold", cl::init(8), cl::Hidden, cl::desc("The maximum number of SCEV checks allowed for Loop " "Load Elimination"))
static bool isLoadConditional(LoadInst *Load, Loop *L)
Return true if the load is not executed on all paths in the loop.
Definition LoopLoadElimination.cpp:163
static bool doesStoreDominatesAllLatches(BasicBlock *StoreBlock, Loop *L, DominatorTree *DT)
Check if the store dominates all latches, so as long as there is no intervening store this value will...
Definition LoopLoadElimination.cpp:153
static cl::opt< unsigned > CheckPerElim("runtime-check-per-loop-load-elim", cl::Hidden, cl::desc("Max number of memchecks allowed per eliminated load on average"), cl::init(1))
This header defines the LoopLoadEliminationPass object.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
This pass exposes codegen information to IR-level passes.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
A function analysis which provides an AssumptionCache.
A cache of @llvm.assume calls within a function.
LLVM Basic Block Representation.
Analysis pass which computes BlockFrequencyInfo.
BlockFrequencyInfo pass uses BlockFrequencyInfoImpl implementation to estimate IR basic block frequen...
static LLVM_ABI bool isBitOrNoopPointerCastable(Type *SrcTy, Type *DestTy, const DataLayout &DL)
Check whether a bitcast, inttoptr, or ptrtoint cast between these types is valid and a no-op.
static LLVM_ABI CastInst * CreateBitOrPointerCast(Value *S, Type *Ty, const Twine &Name="", InsertPosition InsertBefore=nullptr)
Create a BitCast, a PtrToInt, or an IntToPtr cast instruction.
static DebugLoc getDropped()
Analysis pass which computes a DominatorTree.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
LLVM_ABI const DataLayout & getDataLayout() const
Get the data layout of the module this instruction belongs to.
An instruction for reading from memory.
Value * getPointerOperand()
Align getAlign() const
Return the alignment of the access that is being performed.
This analysis provides dependence information for the memory accesses of a loop.
LLVM_ABI const LoopAccessInfo & getInfo(Loop &L, bool AllowPartial=false)
Analysis pass that exposes the LoopInfo for a function.
Represents a single loop in the control flow graph.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
ScalarEvolution * getSE() const
Returns the ScalarEvolution analysis used.
LLVM_ABI const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses & preserve()
Mark an analysis as preserved.
An analysis pass based on the new PM to deliver ProfileSummaryInfo.
Analysis providing profile information.
Analysis pass that exposes the ScalarEvolution for a function.
The main scalar evolution driver.
LLVM_ABI const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
Value * getValueOperand()
Value * getPointerOperand()
LLVM_ABI unsigned getPointerAddressSpace() const
Get the address space of this pointer or pointer vector type.
Type * getType() const
All values are typed, get the type of this value.
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
const ParentTy * getParent() const
self_iterator getIterator()
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
@ C
The default llvm calling convention, compatible with C.
initializer< Ty > init(const Ty &Val)
friend class Instruction
Iterator for Instructions in a `BasicBlock`.
This is an optimization pass for GlobalISel generic memory operations.
LLVM_ABI bool simplifyLoop(Loop *L, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, AssumptionCache *AC, MemorySSAUpdater *MSSAU, bool PreserveLCSSA)
Simplify each loop in a loop nest recursively.
FunctionAddr VTableAddr Value
auto min_element(R &&Range)
Provide wrappers to std::min_element which take ranges instead of having to pass begin/end explicitly...
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
std::pair< const RuntimeCheckingPtrGroup *, const RuntimeCheckingPtrGroup * > RuntimePointerCheck
A memcheck made up of a pair of grouped pointers.
decltype(auto) dyn_cast(const From &Val)
dyn_cast - Return the argument parameter cast to the specified type.
OuterAnalysisManagerProxy< ModuleAnalysisManager, Function > ModuleAnalysisManagerFunctionProxy
Provide the ModuleAnalysisManager to Function proxy.
LLVM_ABI bool shouldOptimizeForSize(const MachineFunction *MF, ProfileSummaryInfo *PSI, const MachineBlockFrequencyInfo *BFI, PGSOQueryType QueryType=PGSOQueryType::Other)
Returns true if machine function MF is suggested to be size-optimized based on the profile.
OutputIt copy_if(R &&Range, OutputIt Out, UnaryPredicate P)
Provide wrappers to std::copy_if which take ranges instead of having to pass begin/end explicitly.
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa - Return true if the parameter to the template is an instance of one of the template type argu...
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
decltype(auto) cast(const From &Val)
cast - Return the argument parameter cast to the specified type.
void erase_if(Container &C, UnaryPredicate P)
Provide a container algorithm similar to C++ Library Fundamentals v2's erase_if which is equivalent t...
Type * getLoadStoreType(const Value *I)
A helper function that returns the type of a load or store instruction.
iterator_range< df_iterator< T > > depth_first(const T &G)
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI std::optional< int64_t > getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy, Value *Ptr, const Loop *Lp, const DominatorTree &DT, const DenseMap< Value *, const SCEV * > &StridesMap=DenseMap< Value *, const SCEV * >(), bool Assume=false, bool ShouldCheckWrap=true)
If the pointer has a constant stride return it in units of the access type size.
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Definition LoopLoadElimination.cpp:682