LLVM: lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

36

37using namespace llvm;

38

39#define DEBUG_TYPE "aggressive-instcombine"

40

41STATISTIC(NumExprsReduced, "Number of truncations eliminated by reducing bit "

42 "width of expression graph");

44 "Number of instructions whose bit width was reduced");

45

46

47

49 unsigned Opc = I->getOpcode();

50 switch (Opc) {

51 case Instruction::Trunc:

52 case Instruction::ZExt:

53 case Instruction::SExt:

54

55

56 break;

57 case Instruction::Add:

58 case Instruction::Sub:

59 case Instruction::Mul:

60 case Instruction::And:

61 case Instruction::Or:

62 case Instruction::Xor:

63 case Instruction::Shl:

64 case Instruction::LShr:

65 case Instruction::AShr:

66 case Instruction::UDiv:

67 case Instruction::URem:

68 case Instruction::InsertElement:

69 Ops.push_back(I->getOperand(0));

70 Ops.push_back(I->getOperand(1));

71 break;

72 case Instruction::ExtractElement:

73 Ops.push_back(I->getOperand(0));

74 break;

75 case Instruction::Select:

76 Ops.push_back(I->getOperand(1));

77 Ops.push_back(I->getOperand(2));

78 break;

79 case Instruction::PHI:

81 break;

82 default:

84 }

85}

86

87bool TruncInstCombine::buildTruncExpressionGraph() {

88 SmallVector<Value *, 8> Worklist;

89 SmallVector<Instruction *, 8> Stack;

90

91 InstInfoMap.clear();

92

93 Worklist.push_back(CurrentTruncInst->getOperand(0));

94

95 while (!Worklist.empty()) {

96 Value *Curr = Worklist.back();

97

99 Worklist.pop_back();

100 continue;

101 }

102

104 if (I)

105 return false;

106

107 if (Stack.empty() && Stack.back() == I) {

108

109

110 Worklist.pop_back();

111 Stack.pop_back();

112

113 InstInfoMap.try_emplace(I);

114 continue;

115 }

116

117 if (InstInfoMap.count(I)) {

118 Worklist.pop_back();

119 continue;

120 }

121

122

124

125 unsigned Opc = I->getOpcode();

126 switch (Opc) {

127 case Instruction::Trunc:

128 case Instruction::ZExt:

129 case Instruction::SExt:

130

131

132

133

134 break;

135 case Instruction::Add:

136 case Instruction::Sub:

137 case Instruction::Mul:

138 case Instruction::And:

139 case Instruction::Or:

140 case Instruction::Xor:

141 case Instruction::Shl:

142 case Instruction::LShr:

143 case Instruction::AShr:

144 case Instruction::UDiv:

145 case Instruction::URem:

146 case Instruction::InsertElement:

147 case Instruction::ExtractElement:

148 case Instruction::Select: {

149 SmallVector<Value *, 2> Operands;

152 break;

153 }

154 case Instruction::PHI: {

155 SmallVector<Value *, 2> Operands;

157

158 for (auto *Op : Operands)

160 Worklist.push_back(Op);

161 break;

162 }

163 default:

164

165

166

167

168 return false;

169 }

170 }

171 return true;

172}

173

174unsigned TruncInstCombine::getMinBitWidth() {

175 SmallVector<Value *, 8> Worklist;

176 SmallVector<Instruction *, 8> Stack;

177

178 Value *Src = CurrentTruncInst->getOperand(0);

179 Type *DstTy = CurrentTruncInst->getType();

181 unsigned OrigBitWidth =

182 CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();

183

185 return TruncBitWidth;

186

187 Worklist.push_back(Src);

188 InstInfoMap[cast(Src)].ValidBitWidth = TruncBitWidth;

189

190 while (!Worklist.empty()) {

191 Value *Curr = Worklist.back();

192

194 Worklist.pop_back();

195 continue;

196 }

197

198

200

201 auto &Info = InstInfoMap[I];

202

203 SmallVector<Value *, 2> Operands;

205

206 if (Stack.empty() && Stack.back() == I) {

207

208

209 Worklist.pop_back();

210 Stack.pop_back();

211 for (auto *Operand : Operands)

213 Info.MinBitWidth =

214 std::max(Info.MinBitWidth, InstInfoMap[IOp].MinBitWidth);

215 continue;

216 }

217

218

220 unsigned ValidBitWidth = Info.ValidBitWidth;

221

222

223

224 Info.MinBitWidth = std::max(Info.MinBitWidth, Info.ValidBitWidth);

225

226 for (auto *Operand : Operands)

228

229

230

231 unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth;

232 if (IOpBitwidth >= ValidBitWidth)

233 continue;

234 InstInfoMap[IOp].ValidBitWidth = ValidBitWidth;

235 Worklist.push_back(IOp);

236 }

237 }

238 unsigned MinBitWidth = InstInfoMap.lookup(cast(Src)).MinBitWidth;

239 assert(MinBitWidth >= TruncBitWidth);

240

241 if (MinBitWidth > TruncBitWidth) {

242

243

244

246 return OrigBitWidth;

247

248 Type *Ty = DL.getSmallestLegalIntType(DstTy->getContext(), MinBitWidth);

249

250

252 } else {

253

254

255

256

257 bool FromLegal = MinBitWidth == 1 || DL.isLegalInteger(OrigBitWidth);

258 bool ToLegal = MinBitWidth == 1 || DL.isLegalInteger(MinBitWidth);

259 if (!DstTy->isVectorTy() && FromLegal && !ToLegal)

260 return OrigBitWidth;

261 }

262 return MinBitWidth;

263}

264

265Type *TruncInstCombine::getBestTruncatedType() {

266 if (!buildTruncExpressionGraph())

267 return nullptr;

268

269

270

271

272

273 unsigned DesiredBitWidth = 0;

274 for (auto Itr : InstInfoMap) {

276 if (I->hasOneUse())

277 continue;

279 for (auto *U : I->users())

281 if (UI != CurrentTruncInst && !InstInfoMap.count(UI)) {

282 if (!IsExtInst)

283 return nullptr;

284

285

286

287 unsigned ExtInstBitWidth =

288 I->getOperand(0)->getType()->getScalarSizeInBits();

289 if (DesiredBitWidth && DesiredBitWidth != ExtInstBitWidth)

290 return nullptr;

291 DesiredBitWidth = ExtInstBitWidth;

292 }

293 }

294

295 unsigned OrigBitWidth =

296 CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();

297

298

299

300

301

302

303

304

305

306 for (auto &Itr : InstInfoMap) {

308 if (I->isShift()) {

309 KnownBits KnownRHS = computeKnownBits(I->getOperand(1));

310 unsigned MinBitWidth = KnownRHS.getMaxValue()

311 .uadd_sat(APInt(OrigBitWidth, 1))

313 if (MinBitWidth == OrigBitWidth)

314 return nullptr;

315 if (I->getOpcode() == Instruction::LShr) {

316 KnownBits KnownLHS = computeKnownBits(I->getOperand(0));

317 MinBitWidth =

319 }

320 if (I->getOpcode() == Instruction::AShr) {

321 unsigned NumSignBits = ComputeNumSignBits(I->getOperand(0));

322 MinBitWidth = std::max(MinBitWidth, OrigBitWidth - NumSignBits + 1);

323 }

324 if (MinBitWidth >= OrigBitWidth)

325 return nullptr;

326 Itr.second.MinBitWidth = MinBitWidth;

327 }

328 if (I->getOpcode() == Instruction::UDiv ||

329 I->getOpcode() == Instruction::URem) {

330 unsigned MinBitWidth = 0;

331 for (const auto &Op : I->operands()) {

332 KnownBits Known = computeKnownBits(Op);

333 MinBitWidth =

335 if (MinBitWidth >= OrigBitWidth)

336 return nullptr;

337 }

338 Itr.second.MinBitWidth = MinBitWidth;

339 }

340 }

341

342

343

344 unsigned MinBitWidth = getMinBitWidth();

345

346

347

348 if (MinBitWidth >= OrigBitWidth ||

349 (DesiredBitWidth && DesiredBitWidth != MinBitWidth))

350 return nullptr;

351

352 return IntegerType::get(CurrentTruncInst->getContext(), MinBitWidth);

353}

354

355

356

357

359 assert(Ty && !Ty->isVectorTy() && "Expect Scalar Type");

362 return Ty;

363}

364

365Value *TruncInstCombine::getReducedOperand(Value *V, Type *SclTy) {

369

371 }

372

374 Info Entry = InstInfoMap.lookup(I);

376 return Entry.NewValue;

377}

378

379void TruncInstCombine::ReduceExpressionGraph(Type *SclTy) {

380 NumInstrsReduced += InstInfoMap.size();

381

383 for (auto &Itr : InstInfoMap) {

385 TruncInstCombine::Info &NodeInfo = Itr.second;

386

387 assert(!NodeInfo.NewValue && "Instruction has been evaluated");

388

390 Value *Res = nullptr;

391 unsigned Opc = I->getOpcode();

392 switch (Opc) {

393 case Instruction::Trunc:

394 case Instruction::ZExt:

395 case Instruction::SExt: {

397

398

399

400 if (I->getOperand(0)->getType() == Ty) {

402 NodeInfo.NewValue = I->getOperand(0);

403 continue;

404 }

405

406

407 Res = Builder.CreateIntCast(I->getOperand(0), Ty,

408 Opc == Instruction::SExt);

409

410

411

412

413

414

416 if (Entry != Worklist.end()) {

419 else

420 Worklist.erase(Entry);

422 Worklist.push_back(NewCI);

423 break;

424 }

425 case Instruction::Add:

426 case Instruction::Sub:

427 case Instruction::Mul:

428 case Instruction::And:

429 case Instruction::Or:

430 case Instruction::Xor:

431 case Instruction::Shl:

432 case Instruction::LShr:

433 case Instruction::AShr:

434 case Instruction::UDiv:

435 case Instruction::URem: {

436 Value *LHS = getReducedOperand(I->getOperand(0), SclTy);

437 Value *RHS = getReducedOperand(I->getOperand(1), SclTy);

439

442 ResI->setIsExact(PEO->isExact());

443 break;

444 }

445 case Instruction::ExtractElement: {

446 Value *Vec = getReducedOperand(I->getOperand(0), SclTy);

447 Value *Idx = I->getOperand(1);

448 Res = Builder.CreateExtractElement(Vec, Idx);

449 break;

450 }

451 case Instruction::InsertElement: {

452 Value *Vec = getReducedOperand(I->getOperand(0), SclTy);

453 Value *NewElt = getReducedOperand(I->getOperand(1), SclTy);

454 Value *Idx = I->getOperand(2);

455 Res = Builder.CreateInsertElement(Vec, NewElt, Idx);

456 break;

457 }

458 case Instruction::Select: {

459 Value *Op0 = I->getOperand(0);

460 Value *LHS = getReducedOperand(I->getOperand(1), SclTy);

461 Value *RHS = getReducedOperand(I->getOperand(2), SclTy);

462 Res = Builder.CreateSelect(Op0, LHS, RHS, "", I);

463 break;

464 }

465 case Instruction::PHI: {

466 Res = Builder.CreatePHI(getReducedType(I, SclTy), I->getNumOperands());

469 break;

470 }

471 default:

473 }

474

475 NodeInfo.NewValue = Res;

478 }

479

480 for (auto &Node : OldNewPHINodes) {

481 PHINode *OldPN = Node.first;

482 PHINode *NewPN = Node.second;

484 NewPN->addIncoming(getReducedOperand(std::get<0>(Incoming), SclTy),

485 std::get<1>(Incoming));

486 }

487

488 Value *Res = getReducedOperand(CurrentTruncInst->getOperand(0), SclTy);

489 Type *DstTy = CurrentTruncInst->getType();

490 if (Res->getType() != DstTy) {

492 Res = Builder.CreateIntCast(Res, DstTy, false);

494 ResI->takeName(CurrentTruncInst);

495 }

496 CurrentTruncInst->replaceAllUsesWith(Res);

497

498

499

500 CurrentTruncInst->eraseFromParent();

501

502 for (auto &Node : OldNewPHINodes) {

503 PHINode *OldPN = Node.first;

505 InstInfoMap.erase(OldPN);

507 }

508

509

510

511

513

514

515

516 if (I.first->use_empty())

517 I.first->eraseFromParent();

518 else

520 "Only {SExt, ZExt}Inst might have unreduced users");

521 }

522}

523

525 bool MadeIRChange = false;

526

527

528 for (auto &BB : F) {

529

530 if (!DT.isReachableFromEntry(&BB))

531 continue;

532 for (auto &I : BB)

534 Worklist.push_back(CI);

535 }

536

537

538

539

540 while (!Worklist.empty()) {

541 CurrentTruncInst = Worklist.pop_back_val();

542

543 if (Type *NewDstSclTy = getBestTruncatedType()) {

545 dbgs() << "ICE: TruncInstCombine reducing type of expression graph "

546 "dominated by: "

547 << CurrentTruncInst << '\n');

548 ReduceExpressionGraph(NewDstSclTy);

549 ++NumExprsReduced;

550 MadeIRChange = true;

551 }

552 }

553

554 return MadeIRChange;

555}

assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")

const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]

This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...

#define STATISTIC(VARNAME, DESC)

static Type * getReducedType(Value *V, Type *Ty)

Given a reduced scalar type Ty and a V value, return a reduced type for V, according to its type,...

Definition TruncInstCombine.cpp:358

static void getRelevantOperands(Instruction *I, SmallVectorImpl< Value * > &Ops)

Given an instruction and a container, it fills all the relevant operands of that instruction,...

Definition TruncInstCombine.cpp:48

unsigned getActiveBits() const

Compute the number of active bits in the value.

uint64_t getLimitedValue(uint64_t Limit=UINT64_MAX) const

If this value is smaller than the specified limit, return it, otherwise return the limit value.

LLVM_ABI APInt uadd_sat(const APInt &RHS) const

static LLVM_ABI Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)

LLVM_ABI InstListType::iterator eraseFromParent()

This method unlinks 'this' from the containing basic block and deletes it.

static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)

This static method is the primary way of constructing an IntegerType.

void addIncoming(Value *V, BasicBlock *BB)

Add an incoming value to the end of the PHI list.

iterator_range< const_block_iterator > blocks() const

op_range incoming_values()

static LLVM_ABI PoisonValue * get(Type *T)

Static factory methods - Return a 'poison' object of the specified type.

This class consists of common code factored out of the SmallVector class to reduce code duplication b...

void push_back(const T &Elt)

bool run(Function &F)

Perform TruncInst pattern optimization on given function.

Definition TruncInstCombine.cpp:524

The instances of the Type class are immutable: once they are created, they are never changed.

bool isVectorTy() const

True if this is an instance of VectorType.

LLVMContext & getContext() const

Return the LLVMContext in which this type was uniqued.

LLVM_ABI unsigned getScalarSizeInBits() const LLVM_READONLY

If this is a vector type, return the getPrimitiveSizeInBits value for the element type.

LLVM Value Representation.

Type * getType() const

All values are typed, get the type of this value.

LLVM_ABI void replaceAllUsesWith(Value *V)

Change all uses of this to point to a new Value.

LLVM_ABI void takeName(Value *V)

Transfer the name from V to this value.

static LLVM_ABI VectorType * get(Type *ElementType, ElementCount EC)

This static method is the primary way to construct an VectorType.

#define llvm_unreachable(msg)

Marks that the current location is not supposed to be reachable.

@ C

The default llvm calling convention, compatible with C.

NodeAddr< NodeBase * > Node

friend class Instruction

Iterator for Instructions in a `BasicBlock`.

This is an optimization pass for GlobalISel generic memory operations.

detail::zippy< detail::zip_shortest, T, U, Args... > zip(T &&t, U &&u, Args &&...args)

zip iterator for two or more iteratable types.

FunctionAddr VTableAddr Value

auto find(R &&Range, const T &Val)

Provide wrappers to std::find which take ranges instead of having to pass begin/end explicitly.

decltype(auto) dyn_cast(const From &Val)

dyn_cast - Return the argument parameter cast to the specified type.

void append_range(Container &C, Range &&R)

Wrapper function to append range R to container C.

LLVM_ABI Constant * ConstantFoldConstant(const Constant *C, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr)

ConstantFoldConstant - Fold the constant using the specified DataLayout.

auto reverse(ContainerTy &&C)

LLVM_ABI raw_ostream & dbgs()

dbgs() - This returns a reference to a raw_ostream for debugging messages.

class LLVM_GSL_OWNER SmallVector

Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...

bool isa(const From &Val)

isa - Return true if the parameter to the template is an instance of one of the template type argu...

IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >

DWARFExpression::Operation Op

decltype(auto) cast(const From &Val)

cast - Return the argument parameter cast to the specified type.

bool is_contained(R &&Range, const E &Element)

Returns true if Element is found in Range.

APInt getMaxValue() const

Return the maximal unsigned value possible given these KnownBits.