LLVM: lib/CodeGen/SwitchLoweringUtils.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

19

20using namespace llvm;

22

26 const APInt &LowCase = Clusters[First].Low->getValue();

27 const APInt &HighCase = Clusters[Last].High->getValue();

29

30

31

32

33 return (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100) + 1;

34}

35

42 TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]);

43 return NumCases;

44}

45

48 std::optional SL,

52#ifndef NDEBUG

53

54 assert(!Clusters.empty());

57 for (unsigned i = 1, e = Clusters.size(); i < e; ++i)

58 assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));

59#endif

60

61 assert(TLI && "TLI not set!");

62 if (!TLI->areJTsAllowed(SI->getParent()->getParent()))

63 return;

64

65 const unsigned MinJumpTableEntries = TLI->getMinimumJumpTableEntries();

66 const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;

67

68

69 const int64_t N = Clusters.size();

70 if (N < 2 || N < MinJumpTableEntries)

71 return;

72

73

75 for (unsigned i = 0; i < N; ++i) {

76 const APInt &Hi = Clusters[i].High->getValue();

77 const APInt &Lo = Clusters[i].Low->getValue();

78 TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;

79 if (i != 0)

80 TotalCases[i] += TotalCases[i - 1];

81 }

82

87

88

89 if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {

91 if (buildJumpTable(Clusters, 0, N - 1, SI, SL, DefaultMBB, JTCluster)) {

92 Clusters[0] = JTCluster;

93 Clusters.resize(1);

94 return;

95 }

96 }

97

98

100 return;

101

102

103

104

105

106

107

108

109

110

112

114

115

117

118

119

120 enum PartitionScores : unsigned {

121 NoTable = 0,

122 Table = 1,

123 FewCases = 1,

124 SingleCase = 2

125 };

126

127

128 MinPartitions[N - 1] = 1;

129 LastElement[N - 1] = N - 1;

130 PartitionsScore[N - 1] = PartitionScores::SingleCase;

131

132

133 for (int64_t i = N - 2; i >= 0; i--) {

134

135

136 MinPartitions[i] = MinPartitions[i + 1] + 1;

137 LastElement[i] = i;

138 PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;

139

140

141 for (int64_t j = N - 1; j > i; j--) {

142

147

148 if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {

149 unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);

150 unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];

151 int64_t NumEntries = j - i + 1;

152

153 if (NumEntries == 1)

154 Score += PartitionScores::SingleCase;

155 else if (NumEntries <= SmallNumberOfEntries)

156 Score += PartitionScores::FewCases;

157 else if (NumEntries >= MinJumpTableEntries)

158 Score += PartitionScores::Table;

159

160

161

162 if (NumPartitions < MinPartitions[i] ||

163 (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {

164 MinPartitions[i] = NumPartitions;

165 LastElement[i] = j;

166 PartitionsScore[i] = Score;

167 }

168 }

169 }

170 }

171

172

173 unsigned DstIndex = 0;

178 unsigned NumClusters = Last - First + 1;

179

181 if (NumClusters >= MinJumpTableEntries &&

183 Clusters[DstIndex++] = JTCluster;

184 } else {

186 std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I]));

187 }

188 }

189 Clusters.resize(DstIndex);

190}

191

195 const std::optional &SL,

199

201 std::vector<MachineBasicBlock*> Table;

203

204

207

211 Prob += Clusters[I].Prob;

212 const APInt &Low = Clusters[I].Low->getValue();

213 const APInt &High = Clusters[I].High->getValue();

214 unsigned int NumCmp = (Low == High) ? 1 : 2;

215 const BasicBlock *BB = Clusters[I].MBB->getBasicBlock();

216 DestMap[BB] += NumCmp;

217

219

220 const APInt &PreviousHigh = Clusters[I - 1].High->getValue();

222 uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1;

223 for (uint64_t J = 0; J < Gap; J++)

224 Table.push_back(DefaultMBB);

225 }

226 uint64_t ClusterSize = (High - Low).getLimitedValue() + 1;

227 for (uint64_t J = 0; J < ClusterSize; ++J)

228 Table.push_back(Clusters[I].MBB);

229 JTProbs[Clusters[I].MBB] += Clusters[I].Prob;

230 }

231

232 if (TLI->isSuitableForBitTests(DestMap, Clusters[First].Low->getValue(),

233 Clusters[Last].High->getValue(), *DL)) {

234

235 return false;

236 }

237

238

239

243

244

247 if (Done.count(Succ))

248 continue;

250 Done.insert(Succ);

251 }

253

256

257

260 Clusters[Last].High->getValue(), SI->getCondition(),

261 nullptr, false);

262 JTCases.emplace_back(std::move(JTH), std::move(JT));

263

265 JTCases.size() - 1, Prob);

266 return true;

267}

268

271

272

273

274#ifndef NDEBUG

275

276 assert(!Clusters.empty());

280 for (unsigned i = 1; i < Clusters.size(); ++i)

281 assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue()));

282#endif

283

284

286 return;

287

288

289 EVT PTy = TLI->getPointerTy(*DL);

290 if (!TLI->isOperationLegal(ISD::SHL, PTy))

291 return;

292

294 const int64_t N = Clusters.size();

295

296

298

300

301

302

303

304 MinPartitions[N - 1] = 1;

305 LastElement[N - 1] = N - 1;

306

307

308 for (int64_t i = N - 2; i >= 0; --i) {

309

310

311 MinPartitions[i] = MinPartitions[i + 1] + 1;

312 LastElement[i] = i;

313

314

315

316 for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) {

317

318

319

320 if (!TLI->rangeFitsInWord(Clusters[i].Low->getValue(),

321 Clusters[j].High->getValue(), *DL))

322 continue;

323

324

325

326 bool RangesOnly = true;

327 BitVector Dests(FuncInfo.MF->getNumBlockIDs());

328 for (int64_t k = i; k <= j; k++) {

329 if (Clusters[k].Kind != CC_Range) {

330 RangesOnly = false;

331 break;

332 }

333 Dests.set(Clusters[k].MBB->getNumber());

334 }

335 if (!RangesOnly || Dests.count() > 3)

336 break;

337

338

339 unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);

340 if (NumPartitions < MinPartitions[i]) {

341

342 MinPartitions[i] = NumPartitions;

343 LastElement[i] = j;

344 }

345 }

346 }

347

348

349 unsigned DstIndex = 0;

354

357 Clusters[DstIndex++] = BitTestCluster;

358 } else {

359 size_t NumClusters = Last - First + 1;

360 std::memmove(&Clusters[DstIndex], &Clusters[First],

361 sizeof(Clusters[0]) * NumClusters);

362 DstIndex += NumClusters;

363 }

364 }

365 Clusters.resize(DstIndex);

366}

367

374 return false;

375

379 unsigned NumCmp = (Clusters[I].Low == Clusters[I].High) ? 1 : 2;

380 const BasicBlock *BB = Clusters[I].MBB->getBasicBlock();

381 DestMap[BB] += NumCmp;

382 }

383

387

388 if (!TLI->isSuitableForBitTests(DestMap, Low, High, *DL))

389 return false;

390

393

394 const int BitWidth = TLI->getPointerTy(*DL).getSizeInBits();

396 "Case range must fit in bit mask!");

397

398

399

400 bool ContiguousRange = true;

401 for (int64_t I = First + 1; I <= Last; ++I) {

402 if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) {

403 ContiguousRange = false;

404 break;

405 }

406 }

407

409

410

412 CmpRange = High;

413 ContiguousRange = false;

414 } else {

415 LowBound = Low;

417 }

418

421 for (unsigned i = First; i <= Last; ++i) {

422

423 unsigned j;

424 for (j = 0; j < CBV.size(); ++j)

425 if (CBV[j].BB == Clusters[i].MBB)

426 break;

427 if (j == CBV.size())

428 CBV.push_back(

431

432

433 uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();

434 uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();

435 assert(Hi >= Lo && Hi < 64 && "Invalid bit case!");

436 CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;

438 CB->ExtraProb += Clusters[i].Prob;

439 TotalProb += Clusters[i].Prob;

440 }

441

444

446 return a.ExtraProb > b.ExtraProb;

447 if (a.Bits != b.Bits)

448 return a.Bits > b.Bits;

449 return a.Mask < b.Mask;

450 });

451

452 for (auto &CB : CBV) {

454 FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());

456 }

457 BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),

458 SI->getCondition(), Register(), MVT::Other, false,

459 ContiguousRange, nullptr, nullptr, std::move(BTI),

460 TotalProb);

461

464 return true;

465}

466

468#ifndef NDEBUG

470 assert(CC.Low == CC.High && "Input clusters must be single-case");

471#endif

472

475 });

476

477

478 const unsigned N = Clusters.size();

479 unsigned DstIndex = 0;

480 for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) {

484

485 if (DstIndex != 0 && Clusters[DstIndex - 1].MBB == Succ &&

486 (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) {

487

488

489 Clusters[DstIndex - 1].High = CaseVal;

490 Clusters[DstIndex - 1].Prob += CC.Prob;

491 } else {

492 std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex],

493 sizeof(Clusters[SrcIndex]));

494 }

495 }

496 Clusters.resize(DstIndex);

497}

498

503 if (X.Prob != CC.Prob)

504 return X.Prob > CC.Prob;

505

506

507 return X.Low->getValue().slt(CC.Low->getValue());

508 });

509}

510

516 auto LeftProb = LastLeft->Prob + W.DefaultProb / 2;

517 auto RightProb = FirstRight->Prob + W.DefaultProb / 2;

518

519

520

521

522

523 unsigned I = 0;

524 while (LastLeft + 1 < FirstRight) {

525 if (LeftProb < RightProb || (LeftProb == RightProb && (I & 1)))

526 LeftProb += (++LastLeft)->Prob;

527 else

528 RightProb += (--FirstRight)->Prob;

529 I++;

530 }

531

532 while (true) {

533

534

535

536

537

538 unsigned NumLeft = LastLeft - W.FirstCluster + 1;

539 unsigned NumRight = W.LastCluster - FirstRight + 1;

540

541 if (std::min(NumLeft, NumRight) < 3 && std::max(NumLeft, NumRight) > 3) {

542

543

544

545 if (NumLeft < NumRight) {

546

548 unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster);

549 unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft);

550 if (LeftSideRank <= RightSideRank) {

551

552 ++LastLeft;

553 ++FirstRight;

554 continue;

555 }

556 } else {

557 assert(NumRight < NumLeft);

558

560 unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft);

561 unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster);

562 if (RightSideRank <= LeftSideRank) {

563

564 --LastLeft;

565 --FirstRight;

566 continue;

567 }

568 }

569 }

570 break;

571 }

572

573 assert(LastLeft + 1 == FirstRight);

574 assert(LastLeft >= W.FirstCluster);

575 assert(FirstRight <= W.LastCluster);

576

577 return SplitWorkItemInfo{LastLeft, FirstRight, LeftProb, RightProb};

578}

assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")

Promote Memory to Register

ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))

static TableGen::Emitter::OptClass< SkeletonEmitter > X("gen-skeleton-class", "Generate example skeleton class")

This file describes how to lower LLVM code to machine code.

Class for arbitrary precision integers.

unsigned getBitWidth() const

Return the number of bits in the APInt.

bool slt(const APInt &RHS) const

Signed less than comparison.

static APInt getZero(unsigned numBits)

Get the '0' value for the specified bit-width.

LLVM Basic Block Representation.

size_type count() const

count - Returns the number of bits which are set.

BlockFrequencyInfo pass uses BlockFrequencyInfoImpl implementation to estimate IR basic block frequencies.

static BranchProbability getZero()

This is the shared class of boolean and integer constants.

const APInt & getValue() const

Return the constant as an APInt value reference.

void normalizeSuccProbs()

Normalize probabilities of all successors so that the sum of them becomes one.

MachineJumpTableInfo * getOrCreateJumpTableInfo(unsigned JTEntryKind)

getOrCreateJumpTableInfo - Get the JumpTableInfo for this function if it already exists; otherwise allocate one.

MachineBasicBlock * CreateMachineBasicBlock(const BasicBlock *BB=nullptr, std::optional< UniqueBBID > BBID=std::nullopt)

CreateMachineBasicBlock - Allocate a new MachineBasicBlock. Use this instead of `new MachineBasicBlock'.

LLVM_ABI unsigned createJumpTableIndex(const std::vector< MachineBasicBlock * > &DestBBs)

createJumpTableIndex - Create a new jump table.

Analysis providing profile information.

Wrapper class representing virtual and physical registers.

SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.

This class consists of common code factored out of the SmallVector class to reduce code duplication based on the SmallVector 'N' template parameter.

void push_back(const T &Elt)

This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.

bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last, const SwitchInst *SI, CaseCluster &BTCluster)

Build a bit test cluster from Clusters[First..Last].

Definition SwitchLoweringUtils.cpp:368

unsigned caseClusterRank(const CaseCluster &CC, CaseClusterIt First, CaseClusterIt Last)

Determine the rank by weight of CC in [First,Last].

Definition SwitchLoweringUtils.cpp:499

void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI, std::optional< SDLoc > SL, MachineBasicBlock *DefaultMBB, ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI)

Definition SwitchLoweringUtils.cpp:46

virtual void addSuccessorWithProb(MachineBasicBlock *Src, MachineBasicBlock *Dst, BranchProbability Prob=BranchProbability::getUnknown())=0

void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI)

Definition SwitchLoweringUtils.cpp:269

std::vector< BitTestBlock > BitTestCases

Vector of BitTestBlock structures used to communicate SwitchInst code generation information.

SplitWorkItemInfo computeSplitWorkItemInfo(const SwitchWorkListItem &W)

Compute information to balance the tree based on branch probabilities to create a near-optimal (in terms of search time given key frequency) binary search tree.

Definition SwitchLoweringUtils.cpp:512

bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First, unsigned Last, const SwitchInst *SI, const std::optional< SDLoc > &SL, MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster)

Definition SwitchLoweringUtils.cpp:192

std::vector< JumpTableBlock > JTCases

Vector of JumpTable structures used to communicate SwitchInst code generation information.

@ C

The default llvm calling convention, compatible with C.

@ SHL

Shift and rotation operations.

std::vector< CaseBits > CaseBitsVector

uint64_t getJumpTableNumCases(const SmallVectorImpl< unsigned > &TotalCases, unsigned First, unsigned Last)

Return the number of cases within a range.

Definition SwitchLoweringUtils.cpp:37

void sortAndRangeify(CaseClusterVector &Clusters)

Sort Clusters and merge adjacent cases.

Definition SwitchLoweringUtils.cpp:467

std::vector< CaseCluster > CaseClusterVector

uint64_t getJumpTableRange(const CaseClusterVector &Clusters, unsigned First, unsigned Last)

Return the range of values within a range.

Definition SwitchLoweringUtils.cpp:23

@ CC_Range

A cluster of adjacent case labels with the same destination, or just one case.

@ CC_JumpTable

A cluster of cases suitable for jump table lowering.

SmallVector< BitTestCase, 3 > BitTestInfo

CaseClusterVector::iterator CaseClusterIt

This is an optimization pass for GlobalISel generic memory operations.

@ Low

Lower the current thread's priority such that it does not affect foreground tasks significantly.

void sort(IteratorTy Start, IteratorTy End)

@ First

Helpers to iterate all locations in the MemoryEffectsBase class.

constexpr unsigned BitWidth

TypeSize getSizeInBits() const

Return the size of the specified value type in bits.

BranchProbability ExtraProb

A cluster of case labels.

static CaseCluster jumpTable(const ConstantInt *Low, const ConstantInt *High, unsigned JTCasesIndex, BranchProbability Prob)

static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High, unsigned BTCasesIndex, BranchProbability Prob)