LLVM: lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
36
37using namespace llvm;
38
39#define DEBUG_TYPE "aggressive-instcombine"
40
41STATISTIC(NumExprsReduced, "Number of truncations eliminated by reducing bit "
42 "width of expression graph");
44 "Number of instructions whose bit width was reduced");
45
46
47
49 unsigned Opc = I->getOpcode();
50 switch (Opc) {
51 case Instruction::Trunc:
52 case Instruction::ZExt:
53 case Instruction::SExt:
54
55
56 break;
57 case Instruction::Add:
58 case Instruction::Sub:
59 case Instruction::Mul:
60 case Instruction::And:
61 case Instruction::Or:
62 case Instruction::Xor:
63 case Instruction::Shl:
64 case Instruction::LShr:
65 case Instruction::AShr:
66 case Instruction::UDiv:
67 case Instruction::URem:
68 case Instruction::InsertElement:
69 Ops.push_back(I->getOperand(0));
70 Ops.push_back(I->getOperand(1));
71 break;
72 case Instruction::ExtractElement:
73 Ops.push_back(I->getOperand(0));
74 break;
75 case Instruction::Select:
76 Ops.push_back(I->getOperand(1));
77 Ops.push_back(I->getOperand(2));
78 break;
79 case Instruction::PHI:
81 break;
82 default:
84 }
85}
86
// Builds the expression graph that feeds operand 0 of CurrentTruncInst,
// giving every fully visited instruction an entry in InstInfoMap.
// Returns true once the worklist has been drained, and false when one of the
// early exits is taken (the guarded return near the top of the loop, or the
// default case of the opcode switch).
// NOTE(review): this listing has lost several original lines (the definition
// of `I`, two `if` conditions, the code that fills `Operands`), so the
// comments below describe only what is still visible.
87bool TruncInstCombine::buildTruncExpressionGraph() {
88 SmallVector<Value *, 8> Worklist;
89 SmallVector<Instruction *, 8> Stack;
90
// Start from an empty map on every call.
91 InstInfoMap.clear();
92
// Seed the traversal with the value being truncated.
93 Worklist.push_back(CurrentTruncInst->getOperand(0));
94
95 while (!Worklist.empty()) {
96 Value *Curr = Worklist.back();
97
// NOTE(review): the condition guarding this pop/continue (original line 98)
// is missing from the listing.
99 Worklist.pop_back();
100 continue;
101 }
102
// NOTE(review): the condition of this `if` is empty in the listing; whatever
// it tested, a true result aborts the whole graph construction.
104 if ()
105 return false;
106
// `I` is on top of Stack: remove it from both containers and give it an
// entry in InstInfoMap.
107 if (.empty() && Stack.back() == I) {
108
109
110 Worklist.pop_back();
111 Stack.pop_back();
112
113 InstInfoMap.try_emplace(I);
114 continue;
115 }
116
// `I` already has an InstInfoMap entry: just drop it from the worklist.
117 if (InstInfoMap.count(I)) {
118 Worklist.pop_back();
119 continue;
120 }
121
122
124
125 unsigned Opc = I->getOpcode();
126 switch (Opc) {
// Cast opcodes: the visible code queues nothing for them.
127 case Instruction::Trunc:
128 case Instruction::ZExt:
129 case Instruction::SExt:
130
131
132
133
134 break;
135 case Instruction::Add:
136 case Instruction::Sub:
137 case Instruction::Mul:
138 case Instruction::And:
139 case Instruction::Or:
140 case Instruction::Xor:
141 case Instruction::Shl:
142 case Instruction::LShr:
143 case Instruction::AShr:
144 case Instruction::UDiv:
145 case Instruction::URem:
146 case Instruction::InsertElement:
147 case Instruction::ExtractElement:
148 case Instruction::Select: {
// NOTE(review): the lines that fill and consume `Operands` (original lines
// 150-151) are missing from the listing.
149 SmallVector<Value *, 2> Operands;
152 break;
153 }
154 case Instruction::PHI: {
155 SmallVector<Value *, 2> Operands;
157
// Queue operands for a later visit.
// NOTE(review): original line 159 is missing between the `for` header and
// the push_back -- possibly a guard on which operands are queued; confirm.
158 for (auto *Op : Operands)
160 Worklist.push_back(Op);
161 break;
162 }
// Any opcode not listed above makes the graph construction fail.
163 default:
164
165
166
167
168 return false;
169 }
170 }
171 return true;
172}
173
// Computes a bit width for the expression graph recorded in InstInfoMap,
// using the per-node ValidBitWidth and MinBitWidth fields.
// Returns one of: TruncBitWidth (early exit), OrigBitWidth (the data-layout
// checks at the end reject the narrower width), or the MinBitWidth recorded
// for the truncated value Src.
// NOTE(review): the listing has lost several original lines (the definitions
// of `TruncBitWidth`, `I` and `IOp`, and some `if` conditions); the comments
// below cover only the visible code.
174unsigned TruncInstCombine::getMinBitWidth() {
175 SmallVector<Value *, 8> Worklist;
176 SmallVector<Instruction *, 8> Stack;
177
178 Value *Src = CurrentTruncInst->getOperand(0);
179 Type *DstTy = CurrentTruncInst->getType();
// Scalar bit width of the value being truncated.
181 unsigned OrigBitWidth =
182 CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();
183
// NOTE(review): the condition guarding this early return (original line 184)
// is missing from the listing.
185 return TruncBitWidth;
186
// Seed the traversal: the root must be valid for TruncBitWidth bits.
187 Worklist.push_back(Src);
188 InstInfoMap[cast(Src)].ValidBitWidth = TruncBitWidth;
189
190 while (!Worklist.empty()) {
191 Value *Curr = Worklist.back();
192
// NOTE(review): the condition guarding this pop/continue (original line 193)
// is missing from the listing.
194 Worklist.pop_back();
195 continue;
196 }
197
198
200
201 auto &Info = InstInfoMap[I];
202
203 SmallVector<Value *, 2> Operands;
205
// `I` is on top of Stack: pop it and raise its MinBitWidth to the largest
// MinBitWidth found among the `IOp` entries.
206 if (.empty() && Stack.back() == I) {
207
208
209 Worklist.pop_back();
210 Stack.pop_back();
211 for (auto *Operand : Operands)
213 Info.MinBitWidth =
214 std::max(Info.MinBitWidth, InstInfoMap[IOp].MinBitWidth);
215 continue;
216 }
217
218
220 unsigned ValidBitWidth = Info.ValidBitWidth;
221
222
223
// A node's MinBitWidth is never smaller than its ValidBitWidth.
224 Info.MinBitWidth = std::max(Info.MinBitWidth, Info.ValidBitWidth);
225
// Push this node's ValidBitWidth down to operands whose recorded
// ValidBitWidth is smaller, and queue those operands for a visit.
226 for (auto *Operand : Operands)
228
229
230
231 unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth;
232 if (IOpBitwidth >= ValidBitWidth)
233 continue;
234 InstInfoMap[IOp].ValidBitWidth = ValidBitWidth;
235 Worklist.push_back(IOp);
236 }
237 }
238 unsigned MinBitWidth = InstInfoMap.lookup(cast(Src)).MinBitWidth;
239 assert(MinBitWidth >= TruncBitWidth);
240
241 if (MinBitWidth > TruncBitWidth) {
242
243
244
// NOTE(review): the condition guarding this return (original line 245) and
// the use of `Ty` after it is computed (original line 251) are missing.
246 return OrigBitWidth;
247
248 Type *Ty = DL.getSmallestLegalIntType(DstTy->getContext(), MinBitWidth);
249
250
252 } else {
253
254
255
256
// A width of 1 counts as legal on both sides. For a non-vector destination,
// refuse to go from a legal integer width to an illegal one.
257 bool FromLegal = MinBitWidth == 1 || DL.isLegalInteger(OrigBitWidth);
258 bool ToLegal = MinBitWidth == 1 || DL.isLegalInteger(MinBitWidth);
259 if (!DstTy->isVectorTy() && FromLegal && !ToLegal)
260 return OrigBitWidth;
261 }
262 return MinBitWidth;
263}
264
// Decides whether the expression graph feeding CurrentTruncInst can be
// evaluated in a narrower integer type.
// Returns the IntegerType of the chosen bit width, or nullptr when the graph
// cannot be built or no acceptable width is found.
// NOTE(review): the listing has lost several original lines (definitions of
// `I`, `UI` and `IsExtInst`, and the tails of some expressions); the comments
// below cover only the visible code.
265Type *TruncInstCombine::getBestTruncatedType() {
// Nothing to do if the graph cannot be built.
266 if (!buildTruncExpressionGraph())
267 return nullptr;
268
269
270
271
272
// 0 means no particular width has been required yet (see the
// `DesiredBitWidth &&` tests below).
273 unsigned DesiredBitWidth = 0;
274 for (auto Itr : InstInfoMap) {
// Single-use nodes need no further checking.
276 if (I->hasOneUse())
277 continue;
// A user that is neither CurrentTruncInst nor part of the graph is only
// tolerated when IsExtInst holds, and all such nodes must agree on the
// scalar bit width of their operand 0, which becomes DesiredBitWidth.
279 for (auto *U : I->users())
281 if (UI != CurrentTruncInst && !InstInfoMap.count(UI)) {
282 if (!IsExtInst)
283 return nullptr;
284
285
286
287 unsigned ExtInstBitWidth =
288 I->getOperand(0)->getType()->getScalarSizeInBits();
289 if (DesiredBitWidth && DesiredBitWidth != ExtInstBitWidth)
290 return nullptr;
291 DesiredBitWidth = ExtInstBitWidth;
292 }
293 }
294
295 unsigned OrigBitWidth =
296 CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();
297
298
299
300
301
302
303
304
305
// Pre-compute a MinBitWidth for shift and unsigned div/rem nodes; give up as
// soon as one of them would need the full original width.
306 for (auto &Itr : InstInfoMap) {
308 if (I->isShift()) {
// Start from the largest value the shift amount can take, plus one
// (saturating). NOTE(review): the tail of this expression (original line
// 312) is missing from the listing.
309 KnownBits KnownRHS = computeKnownBits(I->getOperand(1));
310 unsigned MinBitWidth = KnownRHS.getMaxValue()
311 .uadd_sat(APInt(OrigBitWidth, 1))
313 if (MinBitWidth == OrigBitWidth)
314 return nullptr;
// LShr: the known bits of operand 0 are consulted as well.
// NOTE(review): the right-hand side of this assignment is missing.
315 if (I->getOpcode() == Instruction::LShr) {
316 KnownBits KnownLHS = computeKnownBits(I->getOperand(0));
317 MinBitWidth =
319 }
// AShr: the width must be at least OrigBitWidth - NumSignBits + 1, where
// NumSignBits is computed for operand 0.
320 if (I->getOpcode() == Instruction::AShr) {
321 unsigned NumSignBits = ComputeNumSignBits(I->getOperand(0));
322 MinBitWidth = std::max(MinBitWidth, OrigBitWidth - NumSignBits + 1);
323 }
324 if (MinBitWidth >= OrigBitWidth)
325 return nullptr;
326 Itr.second.MinBitWidth = MinBitWidth;
327 }
328 if (I->getOpcode() == Instruction::UDiv ||
329 I->getOpcode() == Instruction::URem) {
// The width is accumulated over all operands from their known bits.
// NOTE(review): the right-hand side of the assignment is missing.
330 unsigned MinBitWidth = 0;
331 for (const auto &Op : I->operands()) {
332 KnownBits Known = computeKnownBits(Op);
333 MinBitWidth =
335 if (MinBitWidth >= OrigBitWidth)
336 return nullptr;
337 }
338 Itr.second.MinBitWidth = MinBitWidth;
339 }
340 }
341
342
343
344 unsigned MinBitWidth = getMinBitWidth();
345
346
347
// Reject when nothing is gained, or when a width was required above and it
// differs from the computed one.
348 if (MinBitWidth >= OrigBitWidth ||
349 (DesiredBitWidth && DesiredBitWidth != MinBitWidth))
350 return nullptr;
351
352 return IntegerType::get(CurrentTruncInst->getContext(), MinBitWidth);
353}
354
355
356
357
359 assert(Ty && !Ty->isVectorTy() && "Expect Scalar Type");
362 return Ty;
363}
364
// Returns the reduced counterpart of V for the reduced scalar type SclTy.
// On the path still visible here, that is the NewValue stored in the
// InstInfoMap entry looked up for `I`.
// NOTE(review): most of the body (original lines 366-370, 373 and 375) is
// missing from this listing, including an earlier branch that ends with the
// closing brace below and the definition of `I`; confirm against the real
// source before relying on this description.
365Value *TruncInstCombine::getReducedOperand(Value *V, Type *SclTy) {
369
371 }
372
374 Info Entry = InstInfoMap.lookup(I);
376 return Entry.NewValue;
377}
378
// Re-creates every instruction recorded in InstInfoMap in reduced form
// (using SclTy), then replaces all uses of CurrentTruncInst with the reduced
// root value and erases CurrentTruncInst. NumInstrsReduced is increased by
// the number of recorded nodes.
// NOTE(review): the listing has lost several original lines (definitions of
// `I`, `Ty`, `Entry`, `NewCI`, `ResI`, `PEO`, `OldNewPHINodes`, the creation
// of the new binary operator, and some loop headers); the comments below
// cover only the visible code.
379void TruncInstCombine::ReduceExpressionGraph(Type *SclTy) {
380 NumInstrsReduced += InstInfoMap.size();
381
383 for (auto &Itr : InstInfoMap) {
385 TruncInstCombine::Info &NodeInfo = Itr.second;
386
// Each node is expected to be rewritten exactly once.
387 assert(!NodeInfo.NewValue && "Instruction has been evaluated");
388
390 Value *Res = nullptr;
391 unsigned Opc = I->getOpcode();
392 switch (Opc) {
393 case Instruction::Trunc:
394 case Instruction::ZExt:
395 case Instruction::SExt: {
397
398
399
// If the cast's source already has type Ty, use the source itself as the new
// value and move on to the next node.
400 if (I->getOperand(0)->getType() == Ty) {
402 NodeInfo.NewValue = I->getOperand(0);
403 continue;
404 }
405
406
// Otherwise cast the source to Ty; the cast is signed only for SExt.
407 Res = Builder.CreateIntCast(I->getOperand(0), Ty,
408 Opc == Instruction::SExt);
409
410
411
412
413
414
// NOTE(review): the surrounding lines are missing; what remains shows an
// entry being looked up in Worklist, erased on one branch, and NewCI being
// appended to Worklist.
416 if (Entry != Worklist.end()) {
419 else
420 Worklist.erase(Entry);
422 Worklist.push_back(NewCI);
423 break;
424 }
425 case Instruction::Add:
426 case Instruction::Sub:
427 case Instruction::Mul:
428 case Instruction::And:
429 case Instruction::Or:
430 case Instruction::Xor:
431 case Instruction::Shl:
432 case Instruction::LShr:
433 case Instruction::AShr:
434 case Instruction::UDiv:
435 case Instruction::URem: {
// Both operands are fetched in reduced form.
// NOTE(review): the line that builds the new binary operator from LHS/RHS is
// missing from the listing.
436 Value *LHS = getReducedOperand(I->getOperand(0), SclTy);
437 Value *RHS = getReducedOperand(I->getOperand(1), SclTy);
439
// Copy the `exact` flag from PEO to the new instruction.
442 ResI->setIsExact(PEO->isExact());
443 break;
444 }
445 case Instruction::ExtractElement: {
// Only the vector operand is reduced; the index is reused unchanged.
446 Value *Vec = getReducedOperand(I->getOperand(0), SclTy);
447 Value *Idx = I->getOperand(1);
448 Res = Builder.CreateExtractElement(Vec, Idx);
449 break;
450 }
451 case Instruction::InsertElement: {
// The vector and the inserted element are reduced; the index is reused
// unchanged.
452 Value *Vec = getReducedOperand(I->getOperand(0), SclTy);
453 Value *NewElt = getReducedOperand(I->getOperand(1), SclTy);
454 Value *Idx = I->getOperand(2);
455 Res = Builder.CreateInsertElement(Vec, NewElt, Idx);
456 break;
457 }
458 case Instruction::Select: {
// The condition is reused unchanged; both selected values are reduced.
459 Value *Op0 = I->getOperand(0);
460 Value *LHS = getReducedOperand(I->getOperand(1), SclTy);
461 Value *RHS = getReducedOperand(I->getOperand(2), SclTy);
462 Res = Builder.CreateSelect(Op0, LHS, RHS, "", I);
463 break;
464 }
465 case Instruction::PHI: {
// Create the new PHI with the reduced type; its incoming values are added in
// the loop over OldNewPHINodes after this loop.
466 Res = Builder.CreatePHI(getReducedType(I, SclTy), I->getNumOperands());
469 break;
470 }
471 default:
473 }
474
475 NodeInfo.NewValue = Res;
478 }
479
// Fill in the incoming (value, block) pairs of each new PHI, using the
// reduced form of each incoming value.
480 for (auto &Node : OldNewPHINodes) {
481 PHINode *OldPN = Node.first;
482 PHINode *NewPN = Node.second;
484 NewPN->addIncoming(getReducedOperand(std::get<0>(Incoming), SclTy),
485 std::get<1>(Incoming));
486 }
487
// Fetch the reduced root. If its type differs from the trunc's type, cast it
// (unsigned) to that type and move the trunc's name onto the cast.
488 Value *Res = getReducedOperand(CurrentTruncInst->getOperand(0), SclTy);
489 Type *DstTy = CurrentTruncInst->getType();
490 if (Res->getType() != DstTy) {
492 Res = Builder.CreateIntCast(Res, DstTy, false);
494 ResI->takeName(CurrentTruncInst);
495 }
496 CurrentTruncInst->replaceAllUsesWith(Res);
497
498
499
500 CurrentTruncInst->eraseFromParent();
501
// Remove the old PHIs' entries from InstInfoMap.
502 for (auto &Node : OldNewPHINodes) {
503 PHINode *OldPN = Node.first;
505 InstInfoMap.erase(OldPN);
507 }
508
509
510
511
513
514
515
// Erase nodes that are left without uses.
// NOTE(review): the loop header and the assertion that carries the message
// below are missing from the listing.
516 if (I.first->use_empty())
517 I.first->eraseFromParent();
518 else
520 "Only {SExt, ZExt}Inst might have unreduced users");
521 }
522}
523
525 bool MadeIRChange = false;
526
527
528 for (auto &BB : F) {
529
530 if (!DT.isReachableFromEntry(&BB))
531 continue;
532 for (auto &I : BB)
534 Worklist.push_back(CI);
535 }
536
537
538
539
540 while (!Worklist.empty()) {
541 CurrentTruncInst = Worklist.pop_back_val();
542
543 if (Type *NewDstSclTy = getBestTruncatedType()) {
545 dbgs() << "ICE: TruncInstCombine reducing type of expression graph "
546 "dominated by: "
547 << CurrentTruncInst << '\n');
548 ReduceExpressionGraph(NewDstSclTy);
549 ++NumExprsReduced;
550 MadeIRChange = true;
551 }
552 }
553
554 return MadeIRChange;
555}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
static Type * getReducedType(Value *V, Type *Ty)
Given a reduced scalar type Ty and a V value, return a reduced type for V, according to its type,...
Definition TruncInstCombine.cpp:358
static void getRelevantOperands(Instruction *I, SmallVectorImpl< Value * > &Ops)
Given an instruction and a container, it fills all the relevant operands of that instruction,...
Definition TruncInstCombine.cpp:48
unsigned getActiveBits() const
Compute the number of active bits in the value.
uint64_t getLimitedValue(uint64_t Limit=UINT64_MAX) const
If this value is smaller than the specified limit, return it, otherwise return the limit value.
LLVM_ABI APInt uadd_sat(const APInt &RHS) const
static LLVM_ABI Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
void push_back(const T &Elt)
bool run(Function &F)
Perform TruncInst pattern optimization on given function.
Definition TruncInstCombine.cpp:524
The instances of the Type class are immutable: once they are created, they are never changed.
bool isVectorTy() const
True if this is an instance of VectorType.
LLVMContext & getContext() const
Return the LLVMContext in which this type was uniqued.
LLVM_ABI unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
LLVM_ABI void takeName(Value *V)
Transfer the name from V to this value.
static LLVM_ABI VectorType * get(Type *ElementType, ElementCount EC)
This static method is the primary way to construct a VectorType.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
@ C
The default llvm calling convention, compatible with C.
NodeAddr< NodeBase * > Node
friend class Instruction
Iterator for Instructions in a `BasicBlock`.
This is an optimization pass for GlobalISel generic memory operations.
detail::zippy< detail::zip_shortest, T, U, Args... > zip(T &&t, U &&u, Args &&...args)
zip iterator for two or more iterable types.
FunctionAddr VTableAddr Value
auto find(R &&Range, const T &Val)
Provide wrappers to std::find which take ranges instead of having to pass begin/end explicitly.
decltype(auto) dyn_cast(const From &Val)
dyn_cast - Return the argument parameter cast to the specified type.
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
LLVM_ABI Constant * ConstantFoldConstant(const Constant *C, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr)
ConstantFoldConstant - Fold the constant using the specified DataLayout.
auto reverse(ContainerTy &&C)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa - Return true if the parameter to the template is an instance of one of the template type argu...
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
DWARFExpression::Operation Op
decltype(auto) cast(const From &Val)
cast - Return the argument parameter cast to the specified type.
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.