LLVM: lib/Target/RISCV/RISCVGatherScatterLowering.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
26#include
27
28using namespace llvm;
30
31#define DEBUG_TYPE "riscv-gather-scatter-lowering"
32
33namespace {
34
35class RISCVGatherScatterLowering : public FunctionPass {
40
42
43
44
45
47
48public:
49 static char ID;
50
52
54
55 void getAnalysisUsage(AnalysisUsage &AU) const override {
59 }
60
61 StringRef getPassName() const override {
62 return "RISC-V gather/scatter lowering";
63 }
64
65private:
67
68 std::pair<Value *, Value *> determineBaseAndStride(Instruction *Ptr,
70
71 bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
74};
75
76}
77
78char RISCVGatherScatterLowering::ID = 0;
79
81 "RISC-V gather/scatter lowering pass", false, false)
82
84 return new RISCVGatherScatterLowering();
85}
86
87
90 return std::make_pair(nullptr, nullptr);
91
93
94
95 auto *StartVal =
97 if (!StartVal)
98 return std::make_pair(nullptr, nullptr);
99 APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
101 for (unsigned i = 1; i != NumElts; ++i) {
103 if ()
104 return std::make_pair(nullptr, nullptr);
105
106 APInt LocalStride = C->getValue() - Prev->getValue();
107 if (i == 1)
108 StrideVal = LocalStride;
109 else if (StrideVal != LocalStride)
110 return std::make_pair(nullptr, nullptr);
111
112 Prev = C;
113 }
114
115 Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);
116
117 return std::make_pair(StartVal, Stride);
118}
119
122
124 if (StartC)
126
127
129 auto *Ty = Start->getType()->getScalarType();
130 return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
131 }
132
133
134
136 if (!BO || (BO->getOpcode() != Instruction::Add &&
137 BO->getOpcode() != Instruction::Or &&
138 BO->getOpcode() != Instruction::Shl &&
139 BO->getOpcode() != Instruction::Mul))
140 return std::make_pair(nullptr, nullptr);
141
142 if (BO->getOpcode() == Instruction::Or &&
144 return std::make_pair(nullptr, nullptr);
145
146
147 unsigned OtherIndex = 0;
151 OtherIndex = 1;
152 }
154 return std::make_pair(nullptr, nullptr);
155
157 std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex),
158 Builder);
159 if (!Start)
160 return std::make_pair(nullptr, nullptr);
161
162 Builder.SetInsertPoint(BO);
163 Builder.SetCurrentDebugLocation(DebugLoc());
164
165
166 switch (BO->getOpcode()) {
167 default:
169 case Instruction::Or:
170 Start = Builder.CreateOr(Start, Splat, "", true);
171 break;
172 case Instruction::Add:
173 Start = Builder.CreateAdd(Start, Splat);
174 break;
175 case Instruction::Mul:
176 Start = Builder.CreateMul(Start, Splat);
177 Stride = Builder.CreateMul(Stride, Splat);
178 break;
179 case Instruction::Shl:
180 Start = Builder.CreateShl(Start, Splat);
181 Stride = Builder.CreateShl(Stride, Splat);
182 break;
183 }
184
185 return std::make_pair(Start, Stride);
186}
187
188
189
190
191
192bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
197
199
200
201 if (Phi->getParent() != L->getHeader())
202 return false;
203
206 Inc->getOpcode() != Instruction::Add)
207 return false;
208 assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
209 unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
210 assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
211 "Expected one operand of phi to be Inc");
212
213
215 if (!Step)
216 return false;
217
219 if (!Start)
220 return false;
221 assert(Stride != nullptr);
222
223
226 Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
228 BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
229 BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));
230
231
232 MaybeDeadPHIs.push_back(Phi);
233 return true;
234 }
235
236
238 if (!BO)
239 return false;
240
241 switch (BO->getOpcode()) {
242 default:
243 return false;
244 case Instruction::Or:
245
247 return false;
248 break;
249 case Instruction::Add:
250 break;
251 case Instruction::Shl:
252 break;
253 case Instruction::Mul:
254 break;
255 }
256
257
262 OtherOp = BO->getOperand(1);
267 OtherOp = BO->getOperand(0);
268 } else {
269 return false;
270 }
271
272
273 if (->isLoopInvariant(OtherOp))
274 return false;
275
276
278 if (!SplatOp)
279 return false;
280
281
282 if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
283 return false;
284
285
287 unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
290
291
293 BasePtr->getIncomingBlock(StartBlock)->getTerminator());
295
296
297 switch (BO->getOpcode()) {
298 default:
300 case Instruction::Add:
301 case Instruction::Or: {
302
303
305 break;
306 }
307 case Instruction::Mul: {
309 Stride = Builder.CreateMul(Stride, SplatOp, "stride");
310 break;
311 }
312 case Instruction::Shl: {
314 Stride = Builder.CreateShl(Stride, SplatOp, "stride");
315 break;
316 }
317 }
318
319
320
322 Builder.SetInsertPoint(*StepI->getInsertionPointAfterDef());
323
324 switch (BO->getOpcode()) {
325 default:
326 break;
327 case Instruction::Mul:
328 Step = Builder.CreateMul(Step, SplatOp, "step");
329 break;
330 case Instruction::Shl:
331 Step = Builder.CreateShl(Step, SplatOp, "step");
332 break;
333 }
334
336 BasePtr->setIncomingValue(StartBlock, Start);
337 return true;
338}
339
340std::pair<Value *, Value *>
341RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
342 IRBuilderBase &Builder) {
343
344
346 Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
347 return std::make_pair(BasePtr, ConstantInt::get(IntPtrTy, 0));
348 }
349
351 if ()
352 return std::make_pair(nullptr, nullptr);
353
354 auto I = StridedAddrs.find(GEP);
355 if (I != StridedAddrs.end())
356 return I->second;
357
358 SmallVector<Value *, 2> Ops(GEP->operands());
359
360
363 BaseInst && BaseInst->getType()->isVectorTy()) {
364
365 auto IsScalar = [](Value *Idx) { return !Idx->getType()->isVectorTy(); };
366 if (all_of(GEP->indices(), IsScalar)) {
367 auto [BaseBase, Stride] = determineBaseAndStride(BaseInst, Builder);
368 if (BaseBase) {
371 Value *OffsetBase =
372 Builder.CreateGEP(GEP->getSourceElementType(), BaseBase, Indices,
373 GEP->getName() + "offset", GEP->isInBounds());
374 return {OffsetBase, Stride};
375 }
376 }
377 }
378
379
383 if (!ScalarBase)
384 return std::make_pair(nullptr, nullptr);
385 }
386
387 std::optional VecOperand;
388 unsigned TypeScale = 0;
389
390
392 for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
393 if ([i]->getType()->isVectorTy())
394 continue;
395
396 if (VecOperand)
397 return std::make_pair(nullptr, nullptr);
398
399 VecOperand = i;
400
403 return std::make_pair(nullptr, nullptr);
404
406 }
407
408
409 if (!VecOperand)
410 return std::make_pair(nullptr, nullptr);
411
412
413
414
415
416
417 Value *VecIndex = Ops[*VecOperand];
418 Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
419 if (VecIndex->getType() != VecIntPtrTy) {
421 if (!VecIndexC)
422 return std::make_pair(nullptr, nullptr);
425 else
427 }
428
429
430
432 if (Start) {
435
436
438 Type *SourceTy = GEP->getSourceElementType();
441
442
443 Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
444 assert(Stride->getType() == IntPtrTy && "Unexpected type");
445
446
447 if (TypeScale != 1)
448 Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));
449
450 auto P = std::make_pair(BasePtr, Stride);
452 return P;
453 }
454
455
457 if (!L || ->getLoopPreheader() ||
->getLoopLatch())
458 return std::make_pair(nullptr, nullptr);
459
460 BinaryOperator *Inc;
461 PHINode *BasePhi;
462 if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
463 return std::make_pair(nullptr, nullptr);
464
466 unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
468 "Expected one operand of phi to be Inc");
469
471
472
473 Ops[*VecOperand] = BasePhi;
474 Type *SourceTy = GEP->getSourceElementType();
477
478
481
482
483 Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
484 assert(Stride->getType() == IntPtrTy && "Unexpected type");
485
486
487 if (TypeScale != 1)
488 Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));
489
490 auto P = std::make_pair(BasePtr, Stride);
492 return P;
493}
494
495bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II) {
497 Value *StoreVal = nullptr, *Ptr, *Mask, *EVL = nullptr;
499 switch (II->getIntrinsicID()) {
500 case Intrinsic::masked_gather:
502 Ptr = II->getArgOperand(0);
503 Alignment = II->getParamAlign(0).valueOrOne();
504 Mask = II->getArgOperand(1);
505 break;
506 case Intrinsic::vp_gather:
508 Ptr = II->getArgOperand(0);
509
510 Alignment = II->getParamAlign(0).value_or(
511 DL->getABITypeAlign(DataType->getElementType()));
512 Mask = II->getArgOperand(1);
513 EVL = II->getArgOperand(2);
514 break;
515 case Intrinsic::masked_scatter:
517 StoreVal = II->getArgOperand(0);
518 Ptr = II->getArgOperand(1);
519 Alignment = II->getParamAlign(1).valueOrOne();
520 Mask = II->getArgOperand(2);
521 break;
522 case Intrinsic::vp_scatter:
524 StoreVal = II->getArgOperand(0);
525 Ptr = II->getArgOperand(1);
526
527 Alignment = II->getParamAlign(1).value_or(
528 DL->getABITypeAlign(DataType->getElementType()));
529 Mask = II->getArgOperand(2);
530 EVL = II->getArgOperand(3);
531 break;
532 default:
534 }
535
536
539 return false;
540
541
543 return false;
544
545
547 if (!PtrI)
548 return false;
549
550 LLVMContext &Ctx = PtrI->getContext();
551 IRBuilder Builder(Ctx, InstSimplifyFolder(*DL));
553
555 std::tie(BasePtr, Stride) = determineBaseAndStride(PtrI, Builder);
556 if (!BasePtr)
557 return false;
558 assert(Stride != nullptr);
559
561
562 if (!EVL)
565
567
568 if (!StoreVal) {
570 Intrinsic::experimental_vp_strided_load,
573
574
575 if (II->getIntrinsicID() == Intrinsic::masked_gather)
577 } else
579 Intrinsic::experimental_vp_strided_store,
582
584 II->replaceAllUsesWith(Call);
585 II->eraseFromParent();
586
587 if (PtrI->use_empty())
589
590 return true;
591}
592
593bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
594 if (skipFunction(F))
595 return false;
596
597 auto &TPC = getAnalysis();
598 auto &TM = TPC.getTM();
599 ST = &TM.getSubtarget(F);
600 if (->hasVInstructions() ||
->useRVVForFixedLengthVectors())
601 return false;
602
603 TLI = ST->getTargetLowering();
605 LI = &getAnalysis().getLoopInfo();
606
607 StridedAddrs.clear();
608
610
612
613 for (BasicBlock &BB : F) {
614 for (Instruction &I : BB) {
616 if ()
617 continue;
618 switch (II->getIntrinsicID()) {
619 case Intrinsic::masked_gather:
620 case Intrinsic::masked_scatter:
621 case Intrinsic::vp_gather:
622 case Intrinsic::vp_scatter:
624 break;
625 default:
626 break;
627 }
628 }
629 }
630
631
632 for (auto *II : Worklist)
633 Changed |= tryCreateStridedLoadStore(II);
634
635
636 while (!MaybeDeadPHIs.empty()) {
639 }
640
642}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static bool runOnFunction(Function &F, bool PostInlining)
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
static std::pair< Value *, Value * > matchStridedStart(Value *Start, IRBuilderBase &Builder)
Definition RISCVGatherScatterLowering.cpp:120
static std::pair< Value *, Value * > matchStridedConstant(Constant *StartC)
Definition RISCVGatherScatterLowering.cpp:88
static SymbolRef::Type getType(const Symbol *Sym)
Target-Independent Code Generator Pass Configuration Options pass.
Class for arbitrary precision integers.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
BinaryOps getOpcode() const
This is the shared class of boolean and integer constants.
const APInt & getValue() const
Return the constant as an APInt value reference.
This is an important base class in LLVM.
LLVM_ABI Constant * getAggregateElement(unsigned Elt) const
For aggregates (struct/array/vector) return the constant that corresponds to the specified element if...
A parsed version of the target data layout string in and methods for querying it.
FunctionPass class - This class is used to implement most global optimizations.
Common base class shared among various IRBuilders.
LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
IntegerType * getInt32Ty()
Fetch the type representing a 32-bit integer.
void SetCurrentDebugLocation(DebugLoc L)
Set location information used by debugging information.
Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
Value * CreateShl(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
LLVM_ABI Value * CreateElementCount(Type *Ty, ElementCount EC)
Create an expression which evaluates to the number of elements in EC at runtime.
LLVM_ABI bool isCommutative() const LLVM_READONLY
Return true if the instruction is commutative:
A wrapper class for inspecting calls to intrinsic functions.
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Represents a single loop in the control flow graph.
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
bool isLegalStridedLoadStore(EVT DataType, Align Alignment) const
Return true if a stride load store of the given result type and alignment is legal.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
EVT getValueType(const DataLayout &DL, Type *Ty, bool AllowUnknown=false) const
Return the EVT corresponding to this LLVM type.
bool isTypeLegal(EVT VT) const
Return true if the target has native support for the specified value type.
Target-Independent Code Generator Pass Configuration Options.
bool isVectorTy() const
True if this is an instance of VectorType.
LLVM_ABI unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
void setOperand(unsigned i, Value *Val)
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
LLVM_ABI void takeName(Value *V)
Transfer the name from V to this value.
constexpr ScalarTy getFixedValue() const
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
TypeSize getSequentialElementStride(const DataLayout &DL) const
self_iterator getIterator()
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ C
The default llvm calling convention, compatible with C.
bool match(Val *V, const Pattern &P)
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_IntrinsicIntrinsic::fabs(m_Value(X))
NodeAddr< PhiNode * > Phi
This is an optimization pass for GlobalISel generic memory operations.
FunctionAddr VTableAddr Value
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
decltype(auto) dyn_cast(const From &Val)
dyn_cast - Return the argument parameter cast to the specified type.
LLVM_ABI Value * getSplatValue(const Value *V)
Get splat value if the input is a splat vector or return nullptr.
FunctionPass * createRISCVGatherScatterLoweringPass()
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
auto dyn_cast_or_null(const Y &Val)
generic_gep_type_iterator<> gep_type_iterator
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa - Return true if the parameter to the template is an instance of one of the template type argu...
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
ArrayRef(const T &OneElt) -> ArrayRef< T >
decltype(auto) cast(const From &Val)
cast - Return the argument parameter cast to the specified type.
gep_type_iterator gep_type_begin(const User *GEP)
LLVM_ABI bool RecursivelyDeleteDeadPHINode(PHINode *PN, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr)
If the specified value is an effectively dead PHI node, due to being a def-use chain of single-use no...
LLVM_ABI Constant * ConstantFoldCastInstruction(unsigned opcode, Constant *V, Type *DestTy)