LLVM: lib/Transforms/Scalar/StraightLineStrengthReduce.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
100#include
101#include
102#include
103#include
104#include
105#include
106
107using namespace llvm;
109
110#define DEBUG_TYPE "slsr"
111
113 std::numeric_limits::max();
114
115DEBUG_COUNTER(StraightLineStrengthReduceCounter, "slsr-counter",
116 "Controls whether rewriteCandidate is executed.");
117
118
121 cl::desc("Enable poison-reuse guard"));
122
123namespace {
124
125class StraightLineStrengthReduceLegacyPass : public FunctionPass {
127
128public:
129 static char ID;
130
131 StraightLineStrengthReduceLegacyPass() : FunctionPass(ID) {
134 }
135
136 void getAnalysisUsage(AnalysisUsage &AU) const override {
137 AU.addRequired();
138 AU.addRequired();
139 AU.addRequired();
140
142 }
143
144 bool doInitialization(Module &M) override {
145 DL = &M.getDataLayout();
146 return false;
147 }
148
150};
151
152class StraightLineStrengthReduce {
153public:
154 StraightLineStrengthReduce(const DataLayout *DL, DominatorTree *DT,
155 ScalarEvolution *SE, TargetTransformInfo *TTI)
156 : DL(DL), DT(DT), SE(SE), TTI(TTI) {}
157
158
159
160 struct Candidate {
161 enum Kind {
162 Invalid,
163 Add,
164 Mul,
165 GEP,
166 };
167
168 enum DKind {
169 InvalidDelta,
170 IndexDelta,
171 BaseDelta,
172 StrideDelta,
173 };
174
175 Candidate() = default;
176 Candidate(Kind CT, const SCEV *B, ConstantInt *Idx, Value *S,
177 Instruction *I, const SCEV *StrideSCEV)
178 : CandidateKind(CT), Base(B), Index(Idx), Stride(S), Ins(I),
179 StrideSCEV(StrideSCEV) {}
180
181 Kind CandidateKind = Invalid;
182
183 const SCEV *Base = nullptr;
184
185
186
187
188 ConstantInt *Index = nullptr;
189
190 Value *Stride = nullptr;
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
207
208
209
210 Candidate *Basis = nullptr;
211
212 DKind DeltaKind = InvalidDelta;
213
214
215 const SCEV *StrideSCEV = nullptr;
216
217
218 Value *Delta = nullptr;
219
220
221
222
223
224
225
226
227
228 enum EfficiencyLevel : unsigned {
229 Unknown = 0,
230 TwoInstTwoVar = 1,
231 TwoInstOneVar = 2,
232 OneInstTwoVar = 3,
233 OneInstOneVar = 4,
234 ZeroInst = 5
235 };
236
237 static EfficiencyLevel
238 getComputationEfficiency(Kind CandidateKind, const ConstantInt *Index,
239 const Value *Stride, const SCEV *Base = nullptr) {
240 bool IsConstantBase = false;
241 bool IsZeroBase = false;
242
243
245 IsConstantBase = true;
246 IsZeroBase = ConstBase->getValue()->isZero();
247 }
248
250 bool IsZeroStride =
252
253 if (IsConstantBase && IsConstantStride)
254 return ZeroInst;
255
256
257 if (CandidateKind == Mul) {
258 if (IsZeroStride)
259 return ZeroInst;
260 if (Index->isZero())
261 return (IsConstantStride || IsConstantBase) ? OneInstOneVar
262 : OneInstTwoVar;
263
264 if (IsConstantBase)
265 return IsZeroBase && (Index->isOne() || Index->isMinusOne())
266 ? ZeroInst
267 : OneInstOneVar;
268
269 if (IsConstantStride) {
271 return (CI->isOne() || CI->isMinusOne()) ? OneInstOneVar
272 : TwoInstOneVar;
273 }
274 return TwoInstTwoVar;
275 }
276
277
278 assert(CandidateKind == Add || CandidateKind == GEP);
279 if (Index->isZero() || IsZeroStride)
280 return ZeroInst;
281
282 bool IsSimpleIndex = Index->isOne() || Index->isMinusOne();
283
284 if (IsConstantBase)
285 return IsZeroBase ? (IsSimpleIndex ? ZeroInst : OneInstOneVar)
286 : (IsSimpleIndex ? OneInstOneVar : TwoInstOneVar);
287
288 if (IsConstantStride)
289 return IsZeroStride ? ZeroInst : OneInstOneVar;
290
291 if (IsSimpleIndex)
292 return OneInstTwoVar;
293
294 return TwoInstTwoVar;
295 }
296
297
298 bool isProfitableRewrite(const Value &Delta, const DKind DeltaKind) const {
299
300
301
302
303
304
305
306
307
308
309
310 return getComputationEfficiency(CandidateKind, Index, Stride, Base) <=
311 getRewriteEfficiency(Delta, DeltaKind);
312 }
313
314
315 EfficiencyLevel getRewriteEfficiency() const {
316 return Basis ? getRewriteEfficiency(*Delta, DeltaKind) : Unknown;
317 }
318
319
320 EfficiencyLevel getRewriteEfficiency(const Value &Delta,
321 const DKind DeltaKind) const {
322 switch (DeltaKind) {
323 case BaseDelta:
324 return getComputationEfficiency(
325 CandidateKind,
326 ConstantInt::get(cast(Delta.getType()), 1), &Delta);
327 case StrideDelta:
328 return getComputationEfficiency(CandidateKind, Index, &Delta);
329 case IndexDelta:
330 return getComputationEfficiency(CandidateKind,
332 default:
333 return Unknown;
334 }
335 }
336
337 bool isHighEfficiency() const {
338 return getComputationEfficiency(CandidateKind, Index, Stride, Base) >=
339 OneInstOneVar;
340 }
341
342
343
344 bool hasValidDelta(const Candidate &Basis) const {
345 switch (DeltaKind) {
346 case IndexDelta:
347
348 return Base == Basis.Base && StrideSCEV == Basis.StrideSCEV;
349 case StrideDelta:
350
351 return Base == Basis.Base && Index == Basis.Index;
352 case BaseDelta:
353
354 return StrideSCEV == Basis.StrideSCEV && Index == Basis.Index;
355 default:
356 return false;
357 }
358 }
359 };
360
362
363private:
364
365
366 void setBasisAndDeltaFor(Candidate &C);
367
368 bool isFoldable(const Candidate &C, TargetTransformInfo *TTI);
369
370
371
372 void allocateCandidatesAndFindBasis(Instruction *I);
373
374
375 void allocateCandidatesAndFindBasisForAdd(Instruction *I);
376
377
378
379 void allocateCandidatesAndFindBasisForAdd(Value *LHS, Value *RHS,
380 Instruction *I);
381
382 void allocateCandidatesAndFindBasisForMul(Instruction *I);
383
384
385
386 void allocateCandidatesAndFindBasisForMul(Value *LHS, Value *RHS,
387 Instruction *I);
388
389
390 void allocateCandidatesAndFindBasisForGEP(GetElementPtrInst *GEP);
391
392
393
394 void allocateCandidatesAndFindBasis(Candidate::Kind CT, const SCEV *B,
395 ConstantInt *Idx, Value *S,
396 Instruction *I);
397
398
399 void rewriteCandidate(const Candidate &C);
400
401
402 static Value *emitBump(const Candidate &Basis, const Candidate &C,
403 IRBuilder<> &Builder, const DataLayout *DL);
404
405 const DataLayout *DL = nullptr;
406 DominatorTree *DT = nullptr;
407 ScalarEvolution *SE;
408 TargetTransformInfo *TTI = nullptr;
409 std::list Candidates;
410
411
412
413 DenseMap<const SCEV *, SmallSetVector<Instruction *, 2>> SCEVToInsts;
414
415
416
417 MapVector<Instruction *, std::vector<Instruction *>> DependencyGraph;
418
419
420 DenseMap<Instruction *, SmallVector<Candidate *, 3>> RewriteCandidates;
421
422
423
424 std::vector<Instruction *> SortedCandidateInsts;
425
426
427
428 std::vector<Instruction *> DeadInstructions;
429
430
431 class CandidateDictTy {
432 public:
433 using CandsTy = SmallVector<Candidate *, 8>;
434 using BBToCandsTy = DenseMap<const BasicBlock *, CandsTy>;
435
436 private:
437
438 using IndexDeltaKeyTy = std::tuple<const SCEV *, const SCEV *, Type *>;
439 DenseMap<IndexDeltaKeyTy, BBToCandsTy> IndexDeltaCandidates;
440
441
442 using BaseDeltaKeyTy = std::tuple<const SCEV *, ConstantInt *, Type *>;
443 DenseMap<BaseDeltaKeyTy, BBToCandsTy> BaseDeltaCandidates;
444
445
446 using StrideDeltaKeyTy = std::tuple<const SCEV *, ConstantInt *, Type *>;
447 DenseMap<StrideDeltaKeyTy, BBToCandsTy> StrideDeltaCandidates;
448
449 public:
450
451
452 const BBToCandsTy *getCandidatesWithDeltaKind(const Candidate &C,
453 Candidate::DKind K) const {
454 assert(K != Candidate::InvalidDelta);
455 if (K == Candidate::IndexDelta) {
456 IndexDeltaKeyTy IndexDeltaKey(C.Base, C.StrideSCEV, C.Ins->getType());
457 auto It = IndexDeltaCandidates.find(IndexDeltaKey);
458 if (It != IndexDeltaCandidates.end())
459 return &It->second;
460 } else if (K == Candidate::BaseDelta) {
461 BaseDeltaKeyTy BaseDeltaKey(C.StrideSCEV, C.Index, C.Ins->getType());
462 auto It = BaseDeltaCandidates.find(BaseDeltaKey);
463 if (It != BaseDeltaCandidates.end())
464 return &It->second;
465 } else {
466 assert(K == Candidate::StrideDelta);
467 StrideDeltaKeyTy StrideDeltaKey(C.Base, C.Index, C.Ins->getType());
468 auto It = StrideDeltaCandidates.find(StrideDeltaKey);
469 if (It != StrideDeltaCandidates.end())
470 return &It->second;
471 }
472 return nullptr;
473 }
474
475
479 IndexDeltaKeyTy IndexDeltaKey(C.Base, C.StrideSCEV, ValueType);
480 BaseDeltaKeyTy BaseDeltaKey(C.StrideSCEV, C.Index, ValueType);
481 StrideDeltaKeyTy StrideDeltaKey(C.Base, C.Index, ValueType);
482 IndexDeltaCandidates[IndexDeltaKey][BB].push_back(&C);
483 BaseDeltaCandidates[BaseDeltaKey][BB].push_back(&C);
484 StrideDeltaCandidates[StrideDeltaKey][BB].push_back(&C);
485 }
486
487 void clear() {
488 IndexDeltaCandidates.clear();
489 BaseDeltaCandidates.clear();
490 StrideDeltaCandidates.clear();
491 }
492 } CandidateDict;
493
494 const SCEV *getAndRecordSCEV(Value *V) {
495 auto *S = SE->getSCEV(V);
499
500 return S;
501 }
502
503 bool candidatePredicate(Candidate *Basis, Candidate &C, Candidate::DKind K);
504
505 bool searchFrom(const CandidateDictTy::BBToCandsTy &BBToCands, Candidate &C,
506 Candidate::DKind K);
507
508
509
510
511 Value *getNearestValueOfSCEV(const SCEV *S, const Instruction *CI) const {
513 return nullptr;
514
516 return SU->getValue();
518 return SC->getValue();
519
520 auto It = SCEVToInsts.find(S);
521 if (It == SCEVToInsts.end())
522 return nullptr;
523
524
525
526 for (Instruction *I : reverse(It->second))
527 if (DT->dominates(I, CI))
528 return I;
529
530 return nullptr;
531 }
532
533 struct DeltaInfo {
534 Candidate *Cand;
535 Candidate::DKind DeltaKind;
537
538 DeltaInfo()
539 : Cand(nullptr), DeltaKind(Candidate::InvalidDelta), Delta(nullptr) {}
540 DeltaInfo(Candidate *Cand, Candidate::DKind DeltaKind, Value *Delta)
541 : Cand(Cand), DeltaKind(DeltaKind), Delta(Delta) {}
542 operator bool() const { return Cand != nullptr; }
543 };
544
545 friend raw_ostream &operator<<(raw_ostream &OS, const DeltaInfo &DI);
546
547 DeltaInfo compressPath(Candidate &C, Candidate *Basis) const;
548
549 Candidate *pickRewriteCandidate(Instruction *I) const;
550 void sortCandidateInstructions();
551 Value *getDelta(const Candidate &C, const Candidate &Basis,
552 Candidate::DKind K) const;
553 static bool isSimilar(Candidate &C, Candidate &Basis, Candidate::DKind K);
554
555
556
557 void addDependency(Candidate &C, Candidate *Basis) {
558 if (Basis)
559 DependencyGraph[Basis->Ins].emplace_back(C.Ins);
560
561
562
563
564 auto PropagateDependency = [&](Instruction *Inst) {
565 if (auto CandsIt = RewriteCandidates.find(Inst);
566 CandsIt != RewriteCandidates.end() &&
568 [](Candidate *Cand) { return Cand->Basis; }))
569 DependencyGraph[Inst].emplace_back(C.Ins);
570 };
571
572
573
575 PropagateDependency(DeltaInst);
576
577
579 PropagateDependency(StrideInst);
580 };
581};
582
584 const StraightLineStrengthReduce::Candidate &C) {
585 OS << "Ins: " << *C.Ins << "\n Base: " << *C.Base
586 << "\n Index: " << *C.Index << "\n Stride: " << *C.Stride
587 << "\n StrideSCEV: " << *C.StrideSCEV;
588 if (C.Basis)
589 OS << "\n Delta: " << *C.Delta << "\n Basis: \n [ " << *C.Basis << " ]";
590 return OS;
591}
592
595 OS << "Cand: " << *DI.Cand << "\n";
596 OS << "Delta Kind: ";
597 switch (DI.DeltaKind) {
598 case StraightLineStrengthReduce::Candidate::IndexDelta:
599 OS << "Index";
600 break;
601 case StraightLineStrengthReduce::Candidate::BaseDelta:
602 OS << "Base";
603 break;
604 case StraightLineStrengthReduce::Candidate::StrideDelta:
605 OS << "Stride";
606 break;
607 default:
608 break;
609 }
610 OS << "\nDelta: " << *DI.Delta;
611 return OS;
612}
613
614}
615
616char StraightLineStrengthReduceLegacyPass::ID = 0;
617
619 "Straight line strength reduction", false, false)
624 "Straight line strength reduction", false, false)
625
627 return new StraightLineStrengthReduceLegacyPass();
628}
629
630
632 if (A.getBitWidth() < B.getBitWidth())
633 A = A.sext(B.getBitWidth());
634 else if (A.getBitWidth() > B.getBitWidth())
635 B = B.sext(A.getBitWidth());
636}
637
638Value *StraightLineStrengthReduce::getDelta(const Candidate &C,
639 const Candidate &Basis,
640 Candidate::DKind K) const {
641 if (K == Candidate::IndexDelta) {
642 APInt Idx = C.Index->getValue();
643 APInt BasisIdx = Basis.Index->getValue();
645 APInt IndexDelta = Idx - BasisIdx;
646 IntegerType *DeltaType =
648 return ConstantInt::get(DeltaType, IndexDelta);
649 } else if (K == Candidate::BaseDelta || K == Candidate::StrideDelta) {
650 const SCEV *BasisPart =
651 (K == Candidate::BaseDelta) ? Basis.Base : Basis.StrideSCEV;
652 const SCEV *CandPart = (K == Candidate::BaseDelta) ? C.Base : C.StrideSCEV;
653 const SCEV *Diff = SE->getMinusSCEV(CandPart, BasisPart);
654 return getNearestValueOfSCEV(Diff, C.Ins);
655 }
656 return nullptr;
657}
658
659bool StraightLineStrengthReduce::isSimilar(Candidate &C, Candidate &Basis,
660 Candidate::DKind K) {
661 bool SameType = false;
662 switch (K) {
663 case Candidate::StrideDelta:
664 SameType = C.StrideSCEV->getType() == Basis.StrideSCEV->getType();
665 break;
666 case Candidate::BaseDelta:
667 SameType = C.Base->getType() == Basis.Base->getType();
668 break;
669 case Candidate::IndexDelta:
670 SameType = true;
671 break;
672 default:;
673 }
674 return SameType && Basis.Ins != C.Ins &&
675 Basis.CandidateKind == C.CandidateKind;
676}
677
678
679
680
681
682bool StraightLineStrengthReduce::candidatePredicate(Candidate *Basis,
683 Candidate &C,
684 Candidate::DKind K) {
686
687 if (!isSimilar(C, *Basis, K) ||
690 DropPoisonGeneratingInsts)))
691 return false;
692
694 Value *Delta = getDelta(C, *Basis, K);
695 if (!Delta)
696 return false;
697
698
699
700
701
702
703
704 if (K == Candidate::IndexDelta &&
705 .isProfitableRewrite(*Delta, Candidate::IndexDelta))
706 return false;
707
708
709
710
711 for (Instruction *I : DropPoisonGeneratingInsts)
712 I->dropPoisonGeneratingAnnotations();
713
714
715
717 C.Delta = Delta;
718 C.Basis = Basis;
720 }
722}
723
724
725
726
727bool StraightLineStrengthReduce::searchFrom(
728 const CandidateDictTy::BBToCandsTy &BBToCands, Candidate &C,
729 Candidate::DKind K) {
730
731
732
733 if (C.CandidateKind == Candidate::Mul && K != Candidate::IndexDelta)
734 return false;
735
736
737
738
740 while (BB) {
741 auto It = BBToCands.find(BB);
742 if (It != BBToCands.end())
743 for (Candidate *Basis : reverse(It->second))
744 if (candidatePredicate(Basis, C, K))
745 return true;
746
748 if (!Node)
749 break;
751 BB = Node ? Node->getBlock() : nullptr;
752 }
753 return false;
754}
755
756void StraightLineStrengthReduce::setBasisAndDeltaFor(Candidate &C) {
757 if (const auto *BaseDeltaCandidates =
758 CandidateDict.getCandidatesWithDeltaKind(C, Candidate::BaseDelta))
759 if (searchFrom(*BaseDeltaCandidates, C, Candidate::BaseDelta)) {
760 LLVM_DEBUG(dbgs() << "Found delta from Base: " << *C.Delta << "\n");
761 return;
762 }
763
764 if (const auto *StrideDeltaCandidates =
765 CandidateDict.getCandidatesWithDeltaKind(C, Candidate::StrideDelta))
766 if (searchFrom(*StrideDeltaCandidates, C, Candidate::StrideDelta)) {
767 LLVM_DEBUG(dbgs() << "Found delta from Stride: " << *C.Delta << "\n");
768 return;
769 }
770
771 if (const auto *IndexDeltaCandidates =
772 CandidateDict.getCandidatesWithDeltaKind(C, Candidate::IndexDelta))
773 if (searchFrom(*IndexDeltaCandidates, C, Candidate::IndexDelta)) {
774 LLVM_DEBUG(dbgs() << "Found delta from Index: " << *C.Delta << "\n");
775 return;
776 }
777
778
779 if (C.Delta) {
781 dbgs() << "Found delta from ";
782 if (C.DeltaKind == Candidate::BaseDelta)
783 dbgs() << "Base: ";
784 else
785 dbgs() << "Stride: ";
786 dbgs() << *C.Delta << "\n";
787 });
788 assert(C.DeltaKind != Candidate::InvalidDelta && C.Basis);
789 }
790}
791
792
793
794
795
796
797
798
799
800
801
802auto StraightLineStrengthReduce::compressPath(Candidate &C,
803 Candidate *Basis) const
804 -> DeltaInfo {
805 if (!Basis || !Basis->Basis || C.CandidateKind == Candidate::Mul)
806 return {};
807 Candidate *Root = Basis;
808 Value *NewDelta = nullptr;
809 auto NewKind = Candidate::InvalidDelta;
810
811 while (Root->Basis) {
812 Candidate *NextRoot = Root->Basis;
813 if (C.Base == NextRoot->Base && C.StrideSCEV == NextRoot->StrideSCEV &&
814 isSimilar(C, *NextRoot, Candidate::IndexDelta)) {
815 ConstantInt *CI =
818 Root = NextRoot;
819 NewKind = Candidate::IndexDelta;
820 NewDelta = CI;
821 continue;
822 }
823 }
824
825 const SCEV *CandPart = nullptr;
826 const SCEV *BasisPart = nullptr;
827 auto CurrKind = Candidate::InvalidDelta;
828 if (C.Base == NextRoot->Base && C.Index == NextRoot->Index) {
829 CandPart = C.StrideSCEV;
830 BasisPart = NextRoot->StrideSCEV;
831 CurrKind = Candidate::StrideDelta;
832 } else if (C.StrideSCEV == NextRoot->StrideSCEV &&
833 C.Index == NextRoot->Index) {
834 CandPart = C.Base;
835 BasisPart = NextRoot->Base;
836 CurrKind = Candidate::BaseDelta;
837 } else
838 break;
839
840 assert(CandPart && BasisPart);
841 if (!isSimilar(C, *NextRoot, CurrKind))
842 break;
843
844 if (auto DeltaVal =
846 Root = NextRoot;
847 NewDelta = DeltaVal->getValue();
848 NewKind = CurrKind;
849 } else
850 break;
851 }
852
853 if (Root != Basis) {
854 assert(NewKind != Candidate::InvalidDelta && NewDelta);
855 LLVM_DEBUG(dbgs() << "Found new Basis with " << *NewDelta
856 << " from path compression.\n");
857 return {Root, NewKind, NewDelta};
858 }
859
860 return {};
861}
862
863
864
865void StraightLineStrengthReduce::sortCandidateInstructions() {
866 SortedCandidateInsts.clear();
867
868
869
870
871
872 DenseMap<Instruction *, int> InDegree;
873 for (auto &KV : DependencyGraph) {
875
876 for (auto *Child : KV.second) {
877 InDegree[Child]++;
878 }
879 }
880 std::queue<Instruction *> WorkList;
881 DenseSet<Instruction *> Visited;
882
883 for (auto &KV : DependencyGraph)
884 if (InDegree[KV.first] == 0)
885 WorkList.push(KV.first);
886
887 while (!WorkList.empty()) {
889 WorkList.pop();
890 if (!Visited.insert(I).second)
891 continue;
892
893 SortedCandidateInsts.push_back(I);
894
895 for (auto *Next : DependencyGraph[I]) {
896 auto &Degree = InDegree[Next];
897 if (--Degree == 0)
898 WorkList.push(Next);
899 }
900 }
901
902 assert(SortedCandidateInsts.size() == DependencyGraph.size() &&
903 "Dependency graph should not have cycles");
904}
905
906auto StraightLineStrengthReduce::pickRewriteCandidate(Instruction *I) const
907 -> Candidate * {
908
909 auto It = RewriteCandidates.find(I);
910 if (It == RewriteCandidates.end())
911 return nullptr;
912
913 Candidate *BestC = nullptr;
914 auto BestEfficiency = Candidate::Unknown;
915 for (Candidate *C : reverse(It->second))
916 if (C->Basis) {
917 auto Efficiency = C->getRewriteEfficiency();
918 if (Efficiency > BestEfficiency) {
919 BestEfficiency = Efficiency;
920 BestC = C;
921 }
922 }
923
924 return BestC;
925}
926
930 return TTI->getGEPCost(GEP->getSourceElementType(), GEP->getPointerOperand(),
932}
933
934
937
938 return Index->getBitWidth() <= 64 &&
939 TTI->isLegalAddressingMode(Base->getType(), nullptr, 0, true,
941}
942
943bool StraightLineStrengthReduce::isFoldable(const Candidate &C,
944 TargetTransformInfo *TTI) {
945 if (C.CandidateKind == Candidate::Add)
947 if (C.CandidateKind == Candidate::GEP)
949 return false;
950}
951
952void StraightLineStrengthReduce::allocateCandidatesAndFindBasis(
953 Candidate::Kind CT, const SCEV *B, ConstantInt *Idx, Value *S,
954 Instruction *I) {
955
956
957
958 Candidate C(CT, B, Idx, S, I, getAndRecordSCEV(S));
959
960
961
962
963
964
965
966
967
968
969 if (!isFoldable(C, TTI) && .isHighEfficiency()) {
970 setBasisAndDeltaFor(C);
971
972
973 if (auto Res = compressPath(C, C.Basis)) {
974 C.Basis = Res.Cand;
975 C.DeltaKind = Res.DeltaKind;
976 C.Delta = Res.Delta;
977 }
978 }
979
980
981 LLVM_DEBUG(dbgs() << "Allocated Candidate: " << C << "\n");
982 Candidates.push_back(C);
983 RewriteCandidates[C.Ins].push_back(&Candidates.back());
984 CandidateDict.add(Candidates.back());
985}
986
987void StraightLineStrengthReduce::allocateCandidatesAndFindBasis(
988 Instruction *I) {
989 switch (I->getOpcode()) {
990 case Instruction::Add:
991 allocateCandidatesAndFindBasisForAdd(I);
992 break;
993 case Instruction::Mul:
994 allocateCandidatesAndFindBasisForMul(I);
995 break;
996 case Instruction::GetElementPtr:
998 break;
999 }
1000}
1001
1002void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd(
1003 Instruction *I) {
1004
1006 return;
1007
1008 assert(I->getNumOperands() == 2 && "isn't I an add?");
1009 Value *LHS = I->getOperand(0), *RHS = I->getOperand(1);
1010 allocateCandidatesAndFindBasisForAdd(LHS, RHS, I);
1012 allocateCandidatesAndFindBasisForAdd(RHS, LHS, I);
1013}
1014
1015void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd(
1017 Value *S = nullptr;
1018 ConstantInt *Idx = nullptr;
1020
1021 allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS), Idx, S, I);
1023
1026 allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS), Idx, S, I);
1027 } else {
1028
1029 ConstantInt *One = ConstantInt::get(cast(I->getType()), 1);
1030 allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS), One, RHS,
1031 I);
1032 }
1033}
1034
1035
1039
1040
1044
1045void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul(
1048 ConstantInt *Idx = nullptr;
1050
1051
1052 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(B), Idx, RHS, I);
1054
1055
1056
1057
1058 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(B), Idx, RHS, I);
1059 } else {
1060
1062 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(LHS), Zero, RHS,
1063 I);
1064 }
1065}
1066
1067void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul(
1068 Instruction *I) {
1069
1070
1072 return;
1073
1074 assert(I->getNumOperands() == 2 && "isn't I a mul?");
1075 Value *LHS = I->getOperand(0), *RHS = I->getOperand(1);
1076 allocateCandidatesAndFindBasisForMul(LHS, RHS, I);
1078
1079 allocateCandidatesAndFindBasisForMul(RHS, LHS, I);
1080 }
1081}
1082
1083void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP(
1084 GetElementPtrInst *GEP) {
1085
1086 if (GEP->getType()->isVectorTy())
1087 return;
1088
1090 for (Use &Idx : GEP->indices())
1092
1094 for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) {
1096 continue;
1097
1098 const SCEV *OrigIndexExpr = IndexExprs[I - 1];
1099 IndexExprs[I - 1] = SE->getZero(OrigIndexExpr->getType());
1100
1101
1102
1104 Value *ArrayIdx = GEP->getOperand(I);
1107 ConstantInt *ElementSizeIdx = ConstantInt::get(PtrIdxTy, ElementSize, true);
1109 DL->getIndexSizeInBits(GEP->getAddressSpace())) {
1110
1111
1112 allocateCandidatesAndFindBasis(Candidate::GEP, BaseExpr, ElementSizeIdx,
1113 ArrayIdx, GEP);
1114 }
1115
1116
1117
1118 Value *TruncatedArrayIdx = nullptr;
1121 DL->getIndexSizeInBits(GEP->getAddressSpace())) {
1122
1123
1124 allocateCandidatesAndFindBasis(Candidate::GEP, BaseExpr, ElementSizeIdx,
1125 TruncatedArrayIdx, GEP);
1126 }
1127
1128 IndexExprs[I - 1] = OrigIndexExpr;
1129 }
1130}
1131
1132Value *StraightLineStrengthReduce::emitBump(const Candidate &Basis,
1133 const Candidate &C,
1135 const DataLayout *DL) {
1138 const APInt &ConstRHS = CR->getValue();
1139 IntegerType *DeltaType =
1143 ConstantInt::get(DeltaType, ConstRHS.logBase2());
1145 }
1148 ConstantInt::get(DeltaType, (-ConstRHS).logBase2());
1150 }
1151 }
1152
1154 };
1155
1157
1158
1161 return nullptr;
1162
1163 if (C.DeltaKind == Candidate::IndexDelta) {
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174 if (IndexDelta == 1)
1175 return C.Stride;
1176
1179
1180 IntegerType *DeltaType =
1183
1184 return CreateMul(ExtendedStride, C.Delta);
1185 }
1186
1187 assert(C.DeltaKind == Candidate::StrideDelta ||
1188 C.DeltaKind == Candidate::BaseDelta);
1189 assert(C.CandidateKind != Candidate::Mul);
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1205 if (C.DeltaKind == Candidate::StrideDelta) {
1206
1207
1208 if (C.CandidateKind == Candidate::GEP) {
1210 Type *NewScalarIndexTy =
1211 DL->getIndexType(GEP->getPointerOperandType()->getScalarType());
1213 }
1214 if (.Index->isOne()) {
1215 Value *ExtendedIndex =
1217 Bump = CreateMul(Bump, ExtendedIndex);
1218 }
1219 }
1220 return Bump;
1221}
1222
1223void StraightLineStrengthReduce::rewriteCandidate(const Candidate &C) {
1225 return;
1226
1227 const Candidate &Basis = *C.Basis;
1228 assert(C.Delta && C.CandidateKind == Basis.CandidateKind &&
1229 C.hasValidDelta(Basis));
1230
1232 Value *Bump = emitBump(Basis, C, Builder, DL);
1233 Value *Reduced = nullptr;
1234
1235
1236 if (!Bump)
1237 Reduced = Basis.Ins;
1238 else {
1239 switch (C.CandidateKind) {
1240 case Candidate::Add:
1241 case Candidate::Mul: {
1242
1245
1246 Reduced = Builder.CreateSub(Basis.Ins, NegBump);
1247
1248
1250 } else {
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260 Reduced = Builder.CreateAdd(Basis.Ins, Bump);
1261 }
1262 break;
1263 }
1264 case Candidate::GEP: {
1266
1267 Reduced = Builder.CreatePtrAdd(Basis.Ins, Bump, "", InBounds);
1268 break;
1269 }
1270 default:
1272 };
1274 }
1275 C.Ins->replaceAllUsesWith(Reduced);
1276 DeadInstructions.push_back(C.Ins);
1277}
1278
1279bool StraightLineStrengthReduceLegacyPass::runOnFunction(Function &F) {
1280 if (skipFunction(F))
1281 return false;
1282
1283 auto *TTI = &getAnalysis().getTTI(F);
1284 auto *DT = &getAnalysis().getDomTree();
1285 auto *SE = &getAnalysis().getSE();
1286 return StraightLineStrengthReduce(DL, DT, SE, TTI).runOnFunction(F);
1287}
1288
1289bool StraightLineStrengthReduce::runOnFunction(Function &F) {
1290 LLVM_DEBUG(dbgs() << "SLSR on Function: " << F.getName() << "\n");
1291
1292
1294 for (auto &I : *(Node->getBlock()))
1295 allocateCandidatesAndFindBasis(&I);
1296
1297
1298
1299 for (auto &C : Candidates) {
1300 DependencyGraph.try_emplace(C.Ins);
1301 addDependency(C, C.Basis);
1302 }
1303 sortCandidateInstructions();
1304
1305
1306
1307 for (Instruction *I : reverse(SortedCandidateInsts))
1308 if (Candidate *C = pickRewriteCandidate(I))
1309 rewriteCandidate(*C);
1310
1311 for (auto *DeadIns : DeadInstructions)
1312
1313
1314 if (DeadIns->getParent())
1316
1317 bool Ret = !DeadInstructions.empty();
1318 DeadInstructions.clear();
1319 DependencyGraph.clear();
1320 RewriteCandidates.clear();
1321 SortedCandidateInsts.clear();
1322
1323 CandidateDict.clear();
1324
1325 Candidates.clear();
1326 return Ret;
1327}
1328
1329PreservedAnalyses
1335
1338
1344 return PA;
1345}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
This file implements a class to represent arbitrary precision integral constant values and operations...
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
This file contains the declarations for the subclasses of Constant, which represent the different fla...
This file provides an implementation of debug counters.
#define DEBUG_COUNTER(VARNAME, COUNTERNAME, DESC)
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
static bool runOnFunction(Function &F, bool PostInlining)
Module.h This file contains the declarations for the Module class.
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Machine Check Debug Module
static bool isGEPFoldable(GetElementPtrInst *GEP, const TargetTransformInfo *TTI)
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
static BinaryOperator * CreateMul(Value *S1, Value *S2, const Twine &Name, BasicBlock::iterator InsertBefore, Value *FlagsOp)
This file implements a set that has insertion order iteration characteristics.
This file defines the SmallVector class.
static bool matchesOr(Value *A, Value *&B, ConstantInt *&C)
Definition StraightLineStrengthReduce.cpp:1041
static bool isAddFoldable(const SCEV *Base, ConstantInt *Index, Value *Stride, TargetTransformInfo *TTI)
Definition StraightLineStrengthReduce.cpp:935
static void unifyBitWidth(APInt &A, APInt &B)
Definition StraightLineStrengthReduce.cpp:631
static bool matchesAdd(Value *A, Value *&B, ConstantInt *&C)
Definition StraightLineStrengthReduce.cpp:1036
static const unsigned UnknownAddressSpace
Definition StraightLineStrengthReduce.cpp:112
static cl::opt< bool > EnablePoisonReuseGuard("enable-poison-reuse-guard", cl::init(true), cl::desc("Enable poison-reuse guard"))
This pass exposes codegen information to IR-level passes.
Class for arbitrary precision integers.
bool isNegatedPowerOf2() const
Check if this APInt's negated value is a power of two greater than zero.
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
unsigned getBitWidth() const
Return the number of bits in the APInt.
unsigned logBase2() const
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
const Function * getParent() const
Return the enclosing method, or null if none.
Represents analyses that only rely on functions' control flow.
This is the shared class of boolean and integer constants.
bool isOne() const
This is just a convenience method to make client code smaller for a common case.
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
unsigned getBitWidth() const
getBitWidth - Return the scalar bitwidth of this constant.
const APInt & getValue() const
Return the constant as an APInt value reference.
A parsed version of the target data layout string in and methods for querying it.
static bool shouldExecute(CounterInfo &Counter)
iterator find(const_arg_type_t< KeyT > Val)
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Analysis pass which computes a DominatorTree.
DomTreeNodeBase< NodeT > * getNode(const NodeT *BB) const
getNode - return the (Post)DominatorTree node for the specified basic block.
Legacy analysis pass which computes a DominatorTree.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
FunctionPass class - This class is used to implement most global optimizations.
an instruction for type-safe pointer arithmetic to access elements of arrays and structs
Value * CreatePtrAdd(Value *Ptr, Value *Offset, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())
Value * CreateNeg(Value *V, const Twine &Name="", bool HasNSW=false)
Value * CreateSub(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateShl(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateSExtOrTrunc(Value *V, Type *DestTy, const Twine &Name="")
Create a SExt or Trunc from the integer value V to DestTy.
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
static LLVM_ABI PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses & preserveSet()
Mark an analysis set as preserved.
PreservedAnalyses & preserve()
Mark an analysis as preserved.
This class represents an analyzed expression in the program.
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
Analysis pass that exposes the ScalarEvolution for a function.
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
LLVM_ABI const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, ArrayRef< const SCEV * > IndexExprs)
Returns an expression for a GEP.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Definition StraightLineStrengthReduce.cpp:1330
Analysis pass providing the TargetTransformInfo.
Wrapper pass for TargetTransformInfo.
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
@ TCC_Free
Expected to fold away in lowering.
LLVM_ABI unsigned getIntegerBitWidth() const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
LLVM_ABI LLVMContext & getContext() const
All values hold a context through their type.
LLVM_ABI void takeName(Value *V)
Transfer the name from V to this value.
std::pair< iterator, bool > insert(const ValueT &V)
TypeSize getSequentialElementStride(const DataLayout &DL) const
This class implements an extremely fast bulk output stream that can only output to a stream.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
@ C
The default llvm calling convention, compatible with C.
@ BasicBlock
Various leaf nodes.
BinaryOp_match< SpecificConstantMatch, SrcTy, TargetOpcode::G_SUB > m_Neg(const SrcTy &&Src)
Matches a register negated by a G_SUB.
bool match(Val *V, const Pattern &P)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Add, true > m_c_Add(const LHS &L, const RHS &R)
Matches an Add with LHS and RHS in either order.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
CastInst_match< OpTy, SExtInst > m_SExt(const OpTy &Op)
Matches SExt.
BinaryOp_match< LHS, RHS, Instruction::Or, true > m_c_Or(const LHS &L, const RHS &R)
Matches an Or with LHS and RHS in either order.
initializer< Ty > init(const Ty &Val)
NodeAddr< NodeBase * > Node
friend class Instruction
Iterator for Instructions in a `BasicBlock`.
This is an optimization pass for GlobalISel generic memory operations.
LLVM_ABI bool haveNoCommonBitsSet(const WithCache< const Value * > &LHSCache, const WithCache< const Value * > &RHSCache, const SimplifyQuery &SQ)
Return true if LHS and RHS have no common bits set.
FunctionAddr VTableAddr Value
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
decltype(auto) dyn_cast(const From &Val)
dyn_cast - Return the argument parameter cast to the specified type.
LLVM_ABI void initializeStraightLineStrengthReduceLegacyPassPass(PassRegistry &)
DomTreeNodeBase< BasicBlock > DomTreeNode
auto dyn_cast_or_null(const Y &Val)
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
auto reverse(ContainerTy &&C)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
generic_gep_type_iterator<> gep_type_iterator
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference `sizeof(SmallVector<T, 0>)`.
bool isa(const From &Val)
isa - Return true if the parameter to the template is an instance of one of the template type arguments.
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
FunctionAddr VTableAddr Next
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
decltype(auto) cast(const From &Val)
cast - Return the argument parameter cast to the specified type.
gep_type_iterator gep_type_begin(const User *GEP)
PointerUnion< const Value *, const PseudoSourceValue * > ValueType
iterator_range< df_iterator< T > > depth_first(const T &G)
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI FunctionPass * createStraightLineStrengthReducePass()
Definition StraightLineStrengthReduce.cpp:626