LLVM: lib/CodeGen/SwitchLoweringUtils.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
19
20using namespace llvm;
22
// getJumpTableRange: number of values spanned by Clusters[First..Last], i.e.
// (High of the last cluster) - (Low of the first cluster) + 1.
26 const APInt &LowCase = Clusters[First].Low->getValue();
27 const APInt &HighCase = Clusters[Last].High->getValue();
29
30
31
32
// The difference is clamped to (UINT64_MAX - 1) / 100 before adding 1, so the
// result cannot overflow uint64_t. NOTE(review): the extra /100 headroom
// presumably lets callers scale the range by 100 (e.g. in a percentage density
// check) without overflowing -- confirm against the callers.
33 return (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100) + 1;
34}
35
42 TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]);
// Assumes TotalCases is the running prefix sum of per-cluster case counts (as
// built in findJumpTables). Under that assumption, this difference is the number
// of case values in Clusters[First..Last].
43 return NumCases;
44}
45
48 std::optional SL,
52#ifndef NDEBUG
53
// Debug-only preconditions: Clusters is non-empty and sorted so that each
// cluster's High is strictly (signed) less than the next cluster's Low.
54 assert(!Clusters.empty());
57 for (unsigned i = 1, e = Clusters.size(); i < e; ++i)
58 assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
59#endif
60
61 assert(TLI && "TLI not set!");
// Nothing to do if the target does not allow jump tables in the function that
// contains SI.
62 if (!TLI->areJTsAllowed(SI->getParent()->getParent()))
63 return;
64
// Size thresholds used when scoring partitions below. "Table" needs at least
// MinJumpTableEntries clusters; "few cases" means at most half that many.
65 const unsigned MinJumpTableEntries = TLI->getMinimumJumpTableEntries();
66 const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;
67
68
// Too few clusters to form any jump table.
69 const int64_t N = Clusters.size();
70 if (N < 2 || N < MinJumpTableEntries)
71 return;
72
73
// TotalCases[i] = number of case values in Clusters[0..i] (a prefix sum). Each
// cluster contributes High - Low + 1 values.
75 for (unsigned i = 0; i < N; ++i) {
76 const APInt &Hi = Clusters[i].High->getValue();
77 const APInt &Lo = Clusters[i].Low->getValue();
78 TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;
79 if (i != 0)
80 TotalCases[i] += TotalCases[i - 1];
81 }
82
87
88
// First try to cover all clusters [0, N-1] with a single jump table. On
// success, Clusters collapses to that one jump-table cluster.
89 if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {
91 if (buildJumpTable(Clusters, 0, N - 1, SI, SL, DefaultMBB, JTCluster)) {
92 Clusters[0] = JTCluster;
93 Clusters.resize(1);
94 return;
95 }
96 }
97
98
100 return;
101
102
103
104
105
106
107
108
109
110
112
114
115
117
118
119
// Tie-breaking scores for partitionings with the same number of partitions
// (higher is preferred). A single-cluster partition adds SingleCase, a small
// partition adds FewCases and a table-sized partition adds Table. NoTable is
// not referenced by the scoring code below.
120 enum PartitionScores : unsigned {
121 NoTable = 0,
122 Table = 1,
123 FewCases = 1,
124 SingleCase = 2
125 };
126
127
// Base case: the last cluster alone forms one partition.
128 MinPartitions[N - 1] = 1;
129 LastElement[N - 1] = N - 1;
130 PartitionsScore[N - 1] = PartitionScores::SingleCase;
131
132
// Right-to-left dynamic programming. For each i:
//   MinPartitions[i]   - fewest partitions covering Clusters[i..N-1],
//   LastElement[i]     - last cluster of the first partition in that solution,
//   PartitionsScore[i] - tie-breaking score of that solution.
133 for (int64_t i = N - 2; i >= 0; i--) {
134
135
// Baseline: Clusters[i] is a partition by itself.
136 MinPartitions[i] = MinPartitions[i + 1] + 1;
137 LastElement[i] = i;
138 PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;
139
140
// Try every j > i (largest first), treating Clusters[i..j] as one jump-table
// partition.
141 for (int64_t j = N - 1; j > i; j--) {
142
147
// NOTE(review): NumCases/Range are presumably recomputed for Clusters[i..j]
// just above -- confirm.
148 if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {
149 unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
150 unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
151 int64_t NumEntries = j - i + 1;
152
// NOTE(review): the loop guarantees j > i, so NumEntries >= 2 and the
// NumEntries == 1 branch cannot be taken here.
153 if (NumEntries == 1)
154 Score += PartitionScores::SingleCase;
155 else if (NumEntries <= SmallNumberOfEntries)
156 Score += PartitionScores::FewCases;
157 else if (NumEntries >= MinJumpTableEntries)
158 Score += PartitionScores::Table;
159
160
161
// Prefer fewer partitions. Among equal counts, prefer a higher score.
162 if (NumPartitions < MinPartitions[i] ||
163 (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
164 MinPartitions[i] = NumPartitions;
165 LastElement[i] = j;
166 PartitionsScore[i] = Score;
167 }
168 }
169 }
170 }
171
172
// Rebuild Clusters from the chosen partitioning (presumably by following
// LastElement from index 0), compacting the vector in place.
173 unsigned DstIndex = 0;
178 unsigned NumClusters = Last - First + 1;
179
// Only partitions with at least MinJumpTableEntries clusters become a single
// jump-table cluster. Clusters of other partitions are copied over one by one.
181 if (NumClusters >= MinJumpTableEntries &&
183 Clusters[DstIndex++] = JTCluster;
184 } else {
186 std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I]));
187 }
188 }
189 Clusters.resize(DstIndex);
190}
191
195 const std::optional &SL,
199
// Scan Clusters[First..Last], building the destination table and accumulating
// probabilities. NumCmps counts the comparisons a compare-based lowering would
// need: one per single-value cluster, two per range.
201 unsigned NumCmps = 0;
202 std::vector<MachineBasicBlock*> Table;
204
205
208
211 Prob += Clusters[I].Prob;
212 const APInt &Low = Clusters[I].Low->getValue();
213 const APInt &High = Clusters[I].High->getValue();
214 NumCmps += (Low == High) ? 1 : 2;
216
217 const APInt &PreviousHigh = Clusters[I - 1].High->getValue();
// Values strictly between the previous cluster and this one are routed to the
// default destination.
219 uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1;
220 for (uint64_t J = 0; J < Gap; J++)
221 Table.push_back(DefaultMBB);
222 }
// Every value in [Low, High] maps to this cluster's destination.
223 uint64_t ClusterSize = (High - Low).getLimitedValue() + 1;
224 for (uint64_t J = 0; J < ClusterSize; ++J)
225 Table.push_back(Clusters[I].MBB);
226 JTProbs[Clusters[I].MBB] += Clusters[I].Prob;
227 }
228
// Decline to build a jump table if the target considers bit tests suitable
// for this number of destinations and comparisons over the same value range.
229 unsigned NumDests = JTProbs.size();
230 if (TLI->isSuitableForBitTests(NumDests, NumCmps,
231 Clusters[First].Low->getValue(),
232 Clusters[Last].High->getValue(), *DL)) {
233
234 return false;
235 }
236
237
238
242
243
// Handle each distinct successor only once.
246 if (Done.count(Succ))
247 continue;
249 Done.insert(Succ);
250 }
252
255
256
259 Clusters[Last].High->getValue(), SI->getCondition(),
260 nullptr, false);
// Record the new jump table. The resulting cluster (presumably JTCluster)
// refers to it by its index in JTCases, which is the last element.
261 JTCases.emplace_back(std::move(JTH), std::move(JT));
262
264 JTCases.size() - 1, Prob);
265 return true;
266}
267
270
271
272
273#ifndef NDEBUG
274
// Debug-only preconditions: Clusters is non-empty and sorted so that each
// cluster's High is strictly (signed) less than the next cluster's Low.
275 assert(!Clusters.empty());
279 for (unsigned i = 1; i < Clusters.size(); ++i)
280 assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue()));
281#endif
282
283
285 return;
286
287
// Bail out unless shift-left is a legal operation on the pointer-sized type.
// This is presumably needed because the bit-test lowering builds its masks with
// shifts.
288 EVT PTy = TLI->getPointerTy(*DL);
289 if (!TLI->isOperationLegal(ISD::SHL, PTy))
290 return;
291
293 const int64_t N = Clusters.size();
294
295
297
299
300
301
302
// Base case: the last cluster alone forms one partition.
303 MinPartitions[N - 1] = 1;
304 LastElement[N - 1] = N - 1;
305
306
// Right-to-left dynamic programming. MinPartitions[i] is the fewest partitions
// covering Clusters[i..N-1]. LastElement[i] is the last cluster of the first
// partition in that solution.
307 for (int64_t i = N - 2; i >= 0; --i) {
308
309
// Baseline: Clusters[i] is a partition by itself.
310 MinPartitions[i] = MinPartitions[i + 1] + 1;
311 LastElement[i] = i;
312
313
314
// Candidate partitions Clusters[i..j] span at most BitWidth clusters, which
// bounds the search. Larger spans are tried first.
315 for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) {
316
317
318
// The values from Clusters[i].Low to Clusters[j].High must fit in a word.
319 if (!TLI->rangeFitsInWord(Clusters[i].Low->getValue(),
320 Clusters[j].High->getValue(), *DL))
321 continue;
322
323
324
// Only plain range clusters qualify, with at most 3 distinct destinations.
// NOTE(review): failing this check `break`s out of the j loop, so shorter
// spans Clusters[i..j'] (j' < j) are never tried, even ones that would exclude
// the offending cluster or destination. Confirm this early exit is intended.
325 bool RangesOnly = true;
326 BitVector Dests(FuncInfo.MF->getNumBlockIDs());
327 for (int64_t k = i; k <= j; k++) {
328 if (Clusters[k].Kind != CC_Range) {
329 RangesOnly = false;
330 break;
331 }
332 Dests.set(Clusters[k].MBB->getNumber());
333 }
334 if (!RangesOnly || Dests.count() > 3)
335 break;
336
337
// Keep this span if it yields fewer partitions overall.
338 unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
339 if (NumPartitions < MinPartitions[i]) {
340
341 MinPartitions[i] = NumPartitions;
342 LastElement[i] = j;
343 }
344 }
345 }
346
347
// Rebuild Clusters from the chosen partitioning, compacting in place. A
// partition that becomes a bit-test cluster is replaced by one cluster. Other
// partitions are copied over unchanged.
348 unsigned DstIndex = 0;
353
356 Clusters[DstIndex++] = BitTestCluster;
357 } else {
358 size_t NumClusters = Last - First + 1;
359 std::memmove(&Clusters[DstIndex], &Clusters[First],
360 sizeof(Clusters[0]) * NumClusters);
361 DstIndex += NumClusters;
362 }
363 }
364 Clusters.resize(DstIndex);
365}
366
373 return false;
374
// Count the distinct destination blocks and the comparisons a compare-based
// lowering would need: one per single-value cluster, two per range. Low and
// High are ConstantInt pointers, so the equality test below compares pointers.
// It presumably relies on equal constants sharing one object -- confirm.
375 BitVector Dests(FuncInfo.MF->getNumBlockIDs());
376 unsigned NumCmps = 0;
379 Dests.set(Clusters[I].MBB->getNumber());
380 NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2;
381 }
382 unsigned NumDests = Dests.count();
383
387
// Let the target decide whether bit tests are worthwhile here.
388 if (!TLI->isSuitableForBitTests(NumDests, NumCmps, Low, High, *DL))
389 return false;
390
393
394 const int BitWidth = TLI->getPointerTy(*DL).getSizeInBits();
396 "Case range must fit in bit mask!");
397
398
399
// The clusters form a contiguous range when each one starts immediately after
// the previous one ends.
400 bool ContiguousRange = true;
401 for (int64_t I = First + 1; I <= Last; ++I) {
402 if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) {
403 ContiguousRange = false;
404 break;
405 }
406 }
407
409
410
// This branch compares against High directly and drops the contiguity flag.
// The else branch instead rebases case values relative to Low (see the Lo/Hi
// computation below).
412 CmpRange = High;
413 ContiguousRange = false;
414 } else {
415 LowBound = Low;
417 }
418
// Build one CaseBits entry per distinct destination block. For each cluster,
// set the bits of its values (relative to LowBound) and accumulate its
// probability.
421 for (unsigned i = First; i <= Last; ++i) {
422
// Linear search for an existing entry with the same destination block.
423 unsigned j;
424 for (j = 0; j < CBV.size(); ++j)
425 if (CBV[j].BB == Clusters[i].MBB)
426 break;
427 if (j == CBV.size())
428 CBV.push_back(
431
432
// Bit positions are the cluster's bounds relative to LowBound. They must fit
// in a 64-bit mask.
433 uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();
434 uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();
435 assert(Hi >= Lo && Hi < 64 && "Invalid bit case!");
// Set bits Lo..Hi inclusive: the right-shifted all-ones value has Hi - Lo + 1
// low bits set, and shifting it left by Lo places them.
436 CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;
438 CB->ExtraProb += Clusters[i].Prob;
439 TotalProb += Clusters[i].Prob;
440 }
441
// Sort order: higher ExtraProb first, then more Bits, then smaller Mask.
444
446 return a.ExtraProb > b.ExtraProb;
447 if (a.Bits != b.Bits)
448 return a.Bits > b.Bits;
449 return a.Mask < b.Mask;
450 });
451
// Create a new machine basic block, associated with SI's parent IR block, for
// each CaseBits entry.
452 for (auto &CB : CBV) {
454 FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
456 }
457 BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
458 SI->getCondition(), Register(), MVT::Other, false,
459 ContiguousRange, nullptr, nullptr, std::move(BTI),
460 TotalProb);
461
464 return true;
465}
466
// Debug-only precondition: every input cluster covers exactly one value.
468#ifndef NDEBUG
470 assert(CC.Low == CC.High && "Input clusters must be single-case");
471#endif
472
475 });
476
477
// Merge runs of consecutive values that share a successor into range clusters,
// compacting Clusters in place (DstIndex <= SrcIndex throughout).
478 const unsigned N = Clusters.size();
479 unsigned DstIndex = 0;
480 for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) {
484
// Extend the previous output cluster when it has the same successor and this
// value is exactly one past its High. Otherwise start a new output cluster.
485 if (DstIndex != 0 && Clusters[DstIndex - 1].MBB == Succ &&
486 (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) {
487
488
489 Clusters[DstIndex - 1].High = CaseVal;
490 Clusters[DstIndex - 1].Prob += CC.Prob;
491 } else {
492 std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex],
493 sizeof(Clusters[SrcIndex]));
494 }
495 }
496 Clusters.resize(DstIndex);
497}
498
505
506
// Presumably the tie-break when weights are equal: X ranks ahead of CC when
// its Low value is smaller in signed order.
507 return X.Low->getValue().slt(CC.Low->getValue());
508 });
509}
510
// Each side's probability starts from its current boundary cluster plus half of
// the default destination's probability.
516 auto LeftProb = LastLeft->Prob + W.DefaultProb / 2;
517 auto RightProb = FirstRight->Prob + W.DefaultProb / 2;
518
519
520
521
522
// Move the boundaries toward each other, growing the lighter side by one
// cluster at a time. When the probabilities are equal, the parity of I
// alternates which side grows.
523 unsigned I = 0;
524 while (LastLeft + 1 < FirstRight) {
525 if (LeftProb < RightProb || (LeftProb == RightProb && (I & 1)))
526 LeftProb += (++LastLeft)->Prob;
527 else
528 RightProb += (--FirstRight)->Prob;
529 I++;
530 }
531
// Rebalance cluster counts: while one side has fewer than 3 clusters and the
// other has more than 3, consider moving the boundary cluster across.
// NOTE(review): LeftProb/RightProb are not updated when the boundary moves
// here, so the returned probabilities describe the pre-adjustment split. Confirm
// that callers tolerate this.
532 while (true) {
533
534
535
536
537
538 unsigned NumLeft = LastLeft - W.FirstCluster + 1;
539 unsigned NumRight = W.LastCluster - FirstRight + 1;
540
541 if (std::min(NumLeft, NumRight) < 3 && std::max(NumLeft, NumRight) > 3) {
542
543
544
545 if (NumLeft < NumRight) {
546
// Shift the boundary right, so the first right cluster joins the left side.
// This happens when CC's rank among the left clusters is no worse (<=) than
// its rank on the right. CC is presumably *FirstRight -- confirm.
548 unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster);
549 unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft);
550 if (LeftSideRank <= RightSideRank) {
551
552 ++LastLeft;
553 ++FirstRight;
554 continue;
555 }
556 } else {
557 assert(NumRight < NumLeft);
558
// Symmetric case: shift the boundary left, so the last left cluster joins the
// right side. This happens when CC's rank among the right clusters is no worse
// than its rank on the left. CC is presumably *LastLeft -- confirm.
560 unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft);
561 unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster);
562 if (RightSideRank <= LeftSideRank) {
563
564 --LastLeft;
565 --FirstRight;
566 continue;
567 }
568 }
569 }
570 break;
571 }
572
// Postconditions: the two halves are adjacent and both are non-empty.
573 assert(LastLeft + 1 == FirstRight);
574 assert(LastLeft >= W.FirstCluster);
575 assert(FirstRight <= W.LastCluster);
576
577 return SplitWorkItemInfo{LastLeft, FirstRight, LeftProb, RightProb};
578}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
Promote Memory to Register
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
static TableGen::Emitter::OptClass< SkeletonEmitter > X("gen-skeleton-class", "Generate example skeleton class")
This file describes how to lower LLVM code to machine code.
Class for arbitrary precision integers.
unsigned getBitWidth() const
Return the number of bits in the APInt.
bool slt(const APInt &RHS) const
Signed less than comparison.
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
size_type count() const
count - Returns the number of bits which are set.
BlockFrequencyInfo pass uses BlockFrequencyInfoImpl implementation to estimate IR basic block frequencies.
static BranchProbability getZero()
This is the shared class of boolean and integer constants.
const APInt & getValue() const
Return the constant as an APInt value reference.
void normalizeSuccProbs()
Normalize probabilities of all successors so that the sum of them becomes one.
MachineJumpTableInfo * getOrCreateJumpTableInfo(unsigned JTEntryKind)
getOrCreateJumpTableInfo - Get the JumpTableInfo for this function if it already exists, or allocate one if not.
MachineBasicBlock * CreateMachineBasicBlock(const BasicBlock *BB=nullptr, std::optional< UniqueBBID > BBID=std::nullopt)
CreateMachineBasicBlock - Allocate a new MachineBasicBlock.
LLVM_ABI unsigned createJumpTableIndex(const std::vector< MachineBasicBlock * > &DestBBs)
createJumpTableIndex - Create a new jump table.
Analysis providing profile information.
Wrapper class representing virtual and physical registers.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication based on the SmallVector 'N' template parameter.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last, const SwitchInst *SI, CaseCluster &BTCluster)
Build a bit test cluster from Clusters[First..Last].
Definition SwitchLoweringUtils.cpp:367
unsigned caseClusterRank(const CaseCluster &CC, CaseClusterIt First, CaseClusterIt Last)
Determine the rank by weight of CC in [First,Last].
Definition SwitchLoweringUtils.cpp:499
void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI, std::optional< SDLoc > SL, MachineBasicBlock *DefaultMBB, ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI)
Definition SwitchLoweringUtils.cpp:46
virtual void addSuccessorWithProb(MachineBasicBlock *Src, MachineBasicBlock *Dst, BranchProbability Prob=BranchProbability::getUnknown())=0
void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI)
Definition SwitchLoweringUtils.cpp:268
std::vector< BitTestBlock > BitTestCases
Vector of BitTestBlock structures used to communicate SwitchInst code generation information.
SplitWorkItemInfo computeSplitWorkItemInfo(const SwitchWorkListItem &W)
Compute information to balance the tree based on branch probabilities to create a near-optimal (in terms of search time given key frequency) binary search tree.
Definition SwitchLoweringUtils.cpp:512
bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First, unsigned Last, const SwitchInst *SI, const std::optional< SDLoc > &SL, MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster)
Definition SwitchLoweringUtils.cpp:192
std::vector< JumpTableBlock > JTCases
Vector of JumpTable structures used to communicate SwitchInst code generation information.
@ C
The default llvm calling convention, compatible with C.
@ SHL
Shift and rotation operations.
std::vector< CaseBits > CaseBitsVector
uint64_t getJumpTableNumCases(const SmallVectorImpl< unsigned > &TotalCases, unsigned First, unsigned Last)
Return the number of cases within a range.
Definition SwitchLoweringUtils.cpp:37
void sortAndRangeify(CaseClusterVector &Clusters)
Sort Clusters and merge adjacent cases.
Definition SwitchLoweringUtils.cpp:467
std::vector< CaseCluster > CaseClusterVector
uint64_t getJumpTableRange(const CaseClusterVector &Clusters, unsigned First, unsigned Last)
Return the range of values within a range.
Definition SwitchLoweringUtils.cpp:23
@ CC_Range
A cluster of adjacent case labels with the same destination, or just one case.
@ CC_JumpTable
A cluster of cases suitable for jump table lowering.
SmallVector< BitTestCase, 3 > BitTestInfo
CaseClusterVector::iterator CaseClusterIt
This is an optimization pass for GlobalISel generic memory operations.
@ Low
Lower the current thread's priority such that it does not affect foreground tasks significantly.
void sort(IteratorTy Start, IteratorTy End)
@ First
Helpers to iterate all locations in the MemoryEffectsBase class.
constexpr unsigned BitWidth
TypeSize getSizeInBits() const
Return the size of the specified value type in bits.
BranchProbability ExtraProb
A cluster of case labels.
static CaseCluster jumpTable(const ConstantInt *Low, const ConstantInt *High, unsigned JTCasesIndex, BranchProbability Prob)
static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High, unsigned BTCasesIndex, BranchProbability Prob)