LLVM: lib/Target/ARM/MVEGatherScatterLowering.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
35#include "llvm/IR/IntrinsicsARM.h"
43#include
44
45using namespace llvm;
46
47#define DEBUG_TYPE "arm-mve-gather-scatter-lowering"
48
51 cl::desc("Enable the generation of masked gathers and scatters"));
52
53namespace {
54
// Legacy FunctionPass that lowers llvm.masked.gather / llvm.masked.scatter
// intrinsics into the ARM MVE vldr.gather / vstr.scatter intrinsic families.
// NOTE(review): this is a Doxygen scrape; several member declarations and the
// original comments are missing (gaps in the fused line numbering below).
55class MVEGatherScatterLowering : public FunctionPass {
56public:
// Pass identification (address used as unique ID by the pass registry).
57 static char ID;
58
59 explicit MVEGatherScatterLowering() : FunctionPass(ID) {
61 }
62
64
65 StringRef getPassName() const override {
66 return "MVE gather/scatter lowering";
67 }
68
69 void getAnalysisUsage(AnalysisUsage &AU) const override {
73 FunctionPass::getAnalysisUsage(AU);
74 }
75
76private:
// Cached analyses/target data, set up in runOnFunction (see L909/L911 area).
77 LoopInfo *LI = nullptr;
78 const DataLayout *DL;
79
80
// Returns true if NumElements x ElemSize is a 128-bit MVE shape and the
// access is naturally aligned for the element size.
81 bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
82 Align Alignment);
83
84 void lookThroughBitcast(Value *&Ptr);
85
86
87
// Splits a pointer into (base, vector-of-offsets, scale) for the MVE
// gather/scatter offset form; returns the base or nullptr on failure.
88 Value *decomposePtr(Value *Ptr, Value *&Offsets, int &Scale,
89 FixedVectorType *Ty, Type *MemoryTy,
91
92
93
96
// Maps (GEP element size, memory element size) to the intrinsic scale
// immediate, or -1 when no legal encoding exists.
97 int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
98
99
100 std::optional<int64_t> getIfConst(const Value *V);
101
102
103
104 std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);
105
107
// The tryCreate* helpers below each attempt one lowering strategy and
// return the new instruction or nullptr so the caller can fall through.
108 Instruction *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
109 Instruction *&Root,
111
112 Instruction *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
115
116 Instruction *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
119
121
122 Instruction *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
124
125 Instruction *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
128
129 Instruction *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
132
133
134
135 Instruction *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *Ptr,
137
138
139
140 Instruction *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
141 Value *Ptr, unsigned TypeScale,
143
144
// Pre-lowering IR optimisations on the address computation feeding a
// gather/scatter (GEP folding and loop-offset rewriting).
145 bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI);
146
147 Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, unsigned &Scale,
149
150 bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
151
152 void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
153
154 void pushOutMulShl(unsigned Opc, PHINode *&Phi, Value *IncrementPerRound,
155 Value *OffsSecondOperand, unsigned LoopIncrement,
157};
158
159}
160
// Pass ID definition, registration macro and factory function.
// NOTE(review): the INITIALIZE_PASS(...) opening line and the factory's
// signature line were lost in extraction (fused numbering jumps 162->164
// and 165->167); see createMVEGatherScatterLoweringPass in the index below.
161char MVEGatherScatterLowering::ID = 0;
162
164 "MVE gather/scattering lowering pass", false, false)
165
167 return new MVEGatherScatterLowering();
168}
169
170bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
171 unsigned ElemSize,
172 Align Alignment) {
173 if (((NumElements == 4 &&
174 (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
175 (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
176 (NumElements == 16 && ElemSize == 8)) &&
177 Alignment >= ElemSize / 8)
178 return true;
179 LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
180 << "valid alignment or vector type \n");
181 return false;
182}
183
// static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount)
// (signature lost in extraction; name and definition line 184 confirmed by
// the cross-reference index at the bottom of this page). Verifies that every
// offset value fits in the target element width implied by a 128-bit vector
// with TargetElemCount lanes.
185
186
187
188
189
190
191
192
193
194
195 unsigned TargetElemSize = 128 / TargetElemCount;
// NOTE(review): the declarations of OffsetElemSize/ConstOff/OConst/SExtValue
// were dropped by the scrape (numbering gaps 195->197, 199->201, 204->206).
197 ->getElementType()
198 ->getScalarSizeInBits();
199 if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
201 if (!ConstOff)
202 return false;
203 int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
// Each constant offset lane must be a non-negative value representable in
// TargetElemSize bits.
204 auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) {
206 if (!OConst)
207 return false;
209 if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
210 return false;
211 return true;
212 };
214 for (unsigned i = 0; i < TargetElemCount; i++) {
216 return false;
217 }
218 } else {
219 if (!CheckValueSize(ConstOff))
220 return false;
221 }
222 }
223 return true;
224}
225
// Splits Ptr into a scalar base plus a vector of offsets and a scale for the
// MVE offset-form intrinsics. When Ptr is a GEP the scale is derived from the
// GEP source element size vs. the memory element size; otherwise (fallthrough
// at the end) Scale is forced to 0.
226Value *MVEGatherScatterLowering::decomposePtr(Value *Ptr, Value *&Offsets,
227 int &Scale, FixedVectorType *Ty,
228 Type *MemoryTy,
// NOTE(review): the trailing parameters, the GEP dyn_cast and the call to
// decomposeGEP were lost in extraction (numbering jump 228->232).
232 Scale =
233 computeScale(GEP->getSourceElementType()->getPrimitiveSizeInBits(),
// A scale of -1 from computeScale means no legal intrinsic encoding.
235 return Scale == -1 ? nullptr : V;
236 }
237 }
238
239
240
241
244 return nullptr;
// Non-GEP pointer vector: offsets are the pointers themselves, no scaling.
249 Scale = 0;
251}
252
// Extracts base pointer + 32-bit offset vector from a 2-operand GEP feeding a
// gather/scatter. Returns the GEP's pointer operand on success and writes the
// (possibly zero-extended) offsets through Offsets; nullptr means "expand".
253Value *MVEGatherScatterLowering::decomposeGEP(Value *&Offsets,
254 FixedVectorType *Ty,
255 GetElementPtrInst *GEP,
// NOTE(review): the condition below was lost in extraction (original line
// 256 missing) — presumably a null-check of GEP; confirm against upstream.
257 if () {
258 LLVM_DEBUG(dbgs() << "masked gathers/scatters: no getelementpointer "
259 << "found\n");
260 return nullptr;
261 }
262 LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
263 << " Looking at intrinsic for base + vector of offsets\n");
264 Value *GEPPtr = GEP->getPointerOperand();
268 return nullptr;
269
// Only simple base + single-index GEPs can map onto the intrinsic form.
270 if (GEP->getNumOperands() != 2) {
271 LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
272 << " operands. Expanding.\n");
273 return nullptr;
274 }
276 unsigned OffsetsElemCount =
278
280
// Looks through a zext of the offsets (declaration of ZextOffs lost in the
// scrape, gap 280->282).
282 if (ZextOffs)
285
286
287
// The hardware offsets are 32-bit lanes; anything else bails out.
289 ->getElementType()
290 ->getScalarSizeInBits() != 32)
292 return nullptr;
293
294
295
296 if (Ty != Offsets->getType()) {
300 } else {
301 Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty));
302 }
303 }
304
305 LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
306 return GEPPtr;
307}
308
// Replaces Ptr with the operand of a lane-count-preserving vector bitcast so
// later pattern matching sees the original pointer vector.
309void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
// NOTE(review): the dyn_casts producing BitCast/BCTy/BCSrcTy were lost in
// extraction (numbering jump 310->314).
310
314 if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
315 LLVM_DEBUG(dbgs() << "masked gathers/scatters: looking through "
316 << "bitcast\n");
317 Ptr = BitCast->getOperand(0);
318 }
319 }
320}
321
322int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
323 unsigned MemoryElemSize) {
324
325
326 if (GEPElemSize == 32 && MemoryElemSize == 32)
327 return 2;
328 else if (GEPElemSize == 16 && MemoryElemSize == 16)
329 return 1;
330 else if (GEPElemSize == 8)
331 return 0;
332 LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
333 << "create intrinsic\n");
334 return -1;
335}
336
// Returns the (sign-extended) splat-constant value of V, folding through
// add/or/mul/shl of constants recursively; empty optional when V is not a
// recognisable constant expression.
337std::optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
// NOTE(review): the declarations of C (a Constant*) and I (an Instruction*)
// were lost in extraction (numbering gaps 337->339 and 343->345).
339 if (C && C->getSplatValue())
340 return std::optional<int64_t>{C->getUniqueInteger().getSExtValue()};
342 return std::optional<int64_t>{};
343
345 if (I->getOpcode() == Instruction::Add || I->getOpcode() == Instruction::Or ||
346 I->getOpcode() == Instruction::Mul ||
347 I->getOpcode() == Instruction::Shl) {
// Both operands must themselves fold to constants for the fold to apply.
348 std::optional<int64_t> Op0 = getIfConst(I->getOperand(0));
349 std::optional<int64_t> Op1 = getIfConst(I->getOperand(1));
350 if (!Op0 || !Op1)
351 return std::optional<int64_t>{};
352 if (I->getOpcode() == Instruction::Add)
353 return std::optional<int64_t>{*Op0 + *Op1};
354 if (I->getOpcode() == Instruction::Mul)
355 return std::optional<int64_t>{*Op0 * *Op1};
356 if (I->getOpcode() == Instruction::Shl)
357 return std::optional<int64_t>{*Op0 << *Op1};
358 if (I->getOpcode() == Instruction::Or)
359 return std::optional<int64_t>{*Op0 | *Op1};
360 }
361 return std::optional<int64_t>{};
362}
363
// static bool isAddLikeOr(Instruction *I, const DataLayout &DL) — signature
// lost in extraction; name and definition line 366 confirmed by the index
// below. True when an `or` can be treated as an `add` (the second condition,
// lost at original line 368, presumably checks the operands share no bits).
364
365
367 return I->getOpcode() == Instruction::Or &&
369}
370
// Splits an add-like instruction into (variable summand, constant immediate
// << TypeScale), returning {nullptr, 0} unless the scaled immediate is a
// multiple of 4 in [-508, 508] as required by the incrementing intrinsics.
371std::pair<Value *, int64_t>
372MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
373 std::pair<Value *, int64_t> ReturnFalse =
374 std::pair<Value *, int64_t>(nullptr, 0);
375
376
// NOTE(review): the dyn_cast producing Add and part of its guard condition
// were lost in extraction (numbering gaps 376->378 and 378->380).
378 if (Add == nullptr ||
380 return ReturnFalse;
// Declaration of Summand lost (gap 380->382/383).
383 std::optional<int64_t> Const;
384
// Whichever operand folds to a constant leaves the other as the summand.
385 if ((Const = getIfConst(Add->getOperand(0))))
386 Summand = Add->getOperand(1);
387 else if ((Const = getIfConst(Add->getOperand(1))))
388 Summand = Add->getOperand(0);
389 else
390 return ReturnFalse;
391
392
393 int64_t Immediate = *Const << TypeScale;
394 if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
395 return ReturnFalse;
396
397 return std::pair<Value *, int64_t>(Summand, Immediate);
398}
399
// Lowers one llvm.masked.gather intrinsic: validates type/alignment, then
// tries the incrementing, offset and base lowering strategies in turn.
// Returns the replacement load, or nullptr when the gather must be expanded.
400Instruction *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
401 using namespace PatternMatch;
402 LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"
403 << *I << "\n");
404
405
406
407
// Intrinsic signature: (ptr, align-attr on arg 0, mask = arg 1, passthru).
409 Value *Ptr = I->getArgOperand(0);
410 Align Alignment = I->getParamAlign(0).valueOrOne();
412 Value *PassThru = I->getArgOperand(2);
413
// NOTE(review): the isLegalTypeAndAlignment(...) call opening this condition
// was lost in extraction (numbering jump 413->415).
415 Alignment))
416 return nullptr;
417 lookThroughBitcast(Ptr);
419
423
425
// Strategies are tried most-specific first; each returns nullptr on failure.
426 Instruction *Load = tryCreateIncrementingGatScat(I, Ptr, Builder);
427 if (!Load)
428 Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
429 if (!Load)
430 Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
431 if (!Load)
432 return nullptr;
433
// A non-trivial passthru needs an explicit select between loaded and
// passthru lanes (select creation lost in the scrape, gap 436->438).
435 LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
436 << "creating select\n");
438 Builder.Insert(Load);
439 }
440
443 if (Root != I)
444
445
446 I->eraseFromParent();
447
448 LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"
449 << *Load << "\n");
451}
452
// Emits a vldr.gather.base (or its _predicated variant when a mask is
// present) loading directly from a vector of pointers.
453Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBase(
455 using namespace PatternMatch;
457 LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
// NOTE(review): the type check guarding this bail-out and the mask match
// were lost in extraction (gaps 457->459->460 and 463->466).
459
460 return nullptr;
463 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
466 else
468 Intrinsic::arm_mve_vldr_gather_base_predicated,
471}
472
// Writeback variant of tryCreateMaskedGatherBase: emits
// vldr.gather.base.wb(_predicated), which also returns the updated pointers.
473Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
475 using namespace PatternMatch;
477 LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers with "
478 << "writeback\n");
// NOTE(review): guard condition and mask match lost in extraction
// (gaps 478->480->481 and 484->487).
480
481 return nullptr;
484 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
487 else
489 Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
492}
493
// Emits a vldr.gather.offset(_predicated) for a base + offsets gather,
// folding a following zext/sext user into the intrinsic when possible and
// truncating the 32-bit result back down for small memory types.
494Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
495 IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
496 using namespace PatternMatch;
497
498 Type *MemoryTy = I->getType();
499 Type *ResultTy = MemoryTy;
500
502
503
504 auto *Extend = Root;
505 bool TruncResult = false;
// Only a single user can be absorbed as the extend of the gathered value.
507 if (I->hasOneUse()) {
508
509
510
// NOTE(review): the sext/zext dyn_casts on the single user were lost in
// extraction (gaps 510->513 and 517->520) — the two arms below mirror the
// sign- and zero-extend cases.
513 User->getType()->getPrimitiveSizeInBits() == 128) {
514 LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
515 << *User << "\n");
516 Extend = User;
517 ResultTy = User->getType();
520 User->getType()->getPrimitiveSizeInBits() == 128) {
521 LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
522 << *ResultTy << "\n");
523 Extend = User;
524 ResultTy = User->getType();
525 }
526 }
527
528
529
// Sub-128-bit results are gathered at 32 bits and truncated afterwards.
534 TruncResult = true;
535 LLVM_DEBUG(dbgs() << "masked gathers: Small input type, truncing to: "
536 << *ResultTy << "\n");
537 }
538
539
541 LLVM_DEBUG(dbgs() << "masked gathers: Extend needed but not provided "
542 "from the correct type. Expanding\n");
543 return nullptr;
544 }
545 }
546
// decomposePtr call producing BasePtr/Offsets lost (gap 548->551).
548 int Scale;
551 if (!BasePtr)
552 return nullptr;
553
// Root is advanced to the absorbed extend so the caller erases it too.
554 Root = Extend;
559 Intrinsic::arm_mve_vldr_gather_offset_predicated,
563 else
565 Intrinsic::arm_mve_vldr_gather_offset,
569
570 if (TruncResult) {
571 Load = TruncInst::Create(Instruction::Trunc, Load, MemoryTy);
572 Builder.Insert(Load);
573 }
575}
576
// Lowers one llvm.masked.scatter intrinsic, mirroring lowerGather: validates
// type/alignment then tries incrementing, offset and base strategies.
577Instruction *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
578 using namespace PatternMatch;
579 LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n"
580 << *I << "\n");
581
582
583
584
// Intrinsic signature: (value-to-store, ptr, align-attr on arg 1, mask).
585 Value *Input = I->getArgOperand(0);
586 Value *Ptr = I->getArgOperand(1);
587 Align Alignment = I->getParamAlign(1).valueOrOne();
589
// NOTE(review): the isLegalTypeAndAlignment(...) call opening this condition
// was lost in extraction (numbering jump 589->591).
591 Alignment))
592 return nullptr;
593
594 lookThroughBitcast(Ptr);
596
600
601 Instruction *Store = tryCreateIncrementingGatScat(I, Ptr, Builder);
602 if (!Store)
603 Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
604 if (!Store)
605 Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
606 if (!Store)
607 return nullptr;
608
609 LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n"
610 << *Store << "\n");
// The original intrinsic is dead once the MVE store is in place.
611 I->eraseFromParent();
613}
614
// Emits a vstr.scatter.base(_predicated) storing Input to a vector of
// pointers.
615Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
617 using namespace PatternMatch;
618 Value *Input = I->getArgOperand(0);
// NOTE(review): the legality guard and mask match were lost in extraction
// (gaps 618->620->622 and 629->632).
620
622
623 return nullptr;
624 }
626
627 LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
629 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
632 else
634 Intrinsic::arm_mve_vstr_scatter_base_predicated,
637}
638
// Writeback variant of tryCreateMaskedScatterBase: emits
// vstr.scatter.base.wb(_predicated), returning the updated pointers.
639Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
641 using namespace PatternMatch;
642 Value *Input = I->getArgOperand(0);
644 LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers "
645 << "with writeback\n");
// NOTE(review): guard condition and mask match lost in extraction
// (gaps 645->647->648 and 651->654).
647
648 return nullptr;
651 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
654 else
656 Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
659}
660
// Emits a vstr.scatter.offset(_predicated) for a base + offsets scatter,
// folding away a trunc feeding the stored value and zero-extending small
// inputs up to the 32-bit lanes the hardware stores from.
661Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
663 using namespace PatternMatch;
664 Value *Input = I->getArgOperand(0);
// NOTE(review): declarations of InputTy and the Trunc dyn_cast were lost in
// extraction (gaps 664->667 and 672->674).
667 Type *MemoryTy = InputTy;
668
669 LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
670 << " to base + vector of offsets\n");
671
672
// If the value being stored is a trunc of a wider value, store from the
// pre-trunc value directly.
674 Value *PreTrunc = Trunc->getOperand(0);
677 Input = PreTrunc;
678 InputTy = PreTruncTy;
679 }
680 }
681 bool ExtendInput = false;
684
685
686
687
690 ExtendInput = true;
691 LLVM_DEBUG(dbgs() << "masked scatters: Small input type, will extend:\n"
692 << *Input << "\n");
693 }
695 LLVM_DEBUG(dbgs() << "masked scatters: cannot create scatters for "
696 "non-standard input types. Expanding.\n");
697 return nullptr;
698 }
699
// decomposePtr call producing BasePtr/Offsets lost (gap 701->704).
701 int Scale;
704 if (!BasePtr)
705 return nullptr;
706
707 if (ExtendInput)
708 Input = Builder.CreateZExt(Input, InputTy);
711 Intrinsic::arm_mve_vstr_scatter_offset_predicated,
713 Mask->getType()},
717 else
719 Intrinsic::arm_mve_vstr_scatter_offset,
720 {BasePtr->getType(), Offsets->getType(), Input->getType()},
724}
725
// For gathers/scatters inside a loop whose offsets are "phi + constant",
// builds the incrementing (and, when profitable, writeback) base-form
// intrinsic: offsets are pre-scaled and added to the base once, then the
// constant becomes the intrinsic's immediate increment.
726Instruction *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
728 FixedVectorType *Ty;
// Result type comes from the gather's return or the scatter's stored value
// (the two assignments were lost in extraction, gaps 729->731->733).
729 if (I->getIntrinsicID() == Intrinsic::masked_gather)
731 else
733
734
736 return nullptr;
737
// Only worthwhile inside a loop (L from LI->getLoopFor, decl lost).
739 if (L == nullptr)
740 return nullptr;
741
742
746 if (!BasePtr)
747 return nullptr;
748
749 LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
750 "wb gather/scatter\n");
751
752
753
754 int TypeScale =
755 computeScale(DL->getTypeSizeInBits(GEP->getSourceElementType()),
756 DL->getTypeSizeInBits(GEP->getType()) /
758 if (TypeScale == -1)
759 return nullptr;
760
// A single-use GEP lets us try the writeback form, which mutates the phi.
761 if (GEP->hasOneUse()) {
762
763
764
765 if (auto *Load = tryCreateIncrementingWBGatScat(I, BasePtr, Offsets,
766 TypeScale, Builder))
768 }
769
770 LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
771 "non-wb gather/scatter\n");
772
773 std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
774 if (Add.first == nullptr)
775 return nullptr;
776 Value *OffsetsIncoming = Add.first;
777 int64_t Immediate = Add.second;
778
// Materialise base + (offsets << TypeScale) as the start pointer vector.
779
781 Instruction::Shl, OffsetsIncoming,
783 Builder.getInt32(TypeScale)),
784 "ScaledIndex", I->getIterator());
785
787 Instruction::Add, ScaledOffsets,
791 BasePtr,
793 "StartIndex", I->getIterator());
794
795 if (I->getIntrinsicID() == Intrinsic::masked_gather)
796 return tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate);
797 else
798 return tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate);
799}
800
// Writeback form of the incrementing lowering: rewrites the offset phi so the
// gather/scatter intrinsic itself produces the next iteration's pointers.
801Instruction *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
802 IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
804
805
807
808
// The offsets must come from a 2-input phi in the loop header with exactly
// the expected uses (the lost token at "->hasNUses(2)" is presumably the
// phi itself — extraction dropped it; confirm against upstream).
810 if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
811 Phi->getParent() != L->getHeader() || ->hasNUses(2))
812
813
814
815
816 return nullptr;
817
// Identify which incoming edge is the loop-latch (increment) edge.
818 unsigned IncrementIndex =
819 Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
820
821 Offsets = Phi->getIncomingValue(IncrementIndex);
822
823 std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
824 if (Add.first == nullptr)
825 return nullptr;
826 Value *OffsetsIncoming = Add.first;
827 int64_t Immediate = Add.second;
// The increment must be "phi + const"; any other summand bails out.
828 if (OffsetsIncoming != Phi)
829
830
831 return nullptr;
832
// Rebuild the start value in the preheader: base + (start << TypeScale),
// pre-decremented so the first writeback lands on the right addresses.
833 Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
834 unsigned NumElems =
836
837
839 Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
841 "ScaledIndex",
842 Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator());
843
845 Instruction::Add, ScaledOffsets,
847 NumElems,
849 BasePtr,
851 "StartIndex",
852 Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator());
853
855 Instruction::Sub, OffsetsIncoming,
857 "PreIncrementStartIndex",
858 Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator());
859 Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);
860
862
865 if (I->getIntrinsicID() == Intrinsic::masked_gather) {
// Gather WB returns {value, next-pointers}: split into the loaded result
// and the new induction value (extractvalue creation lost, gap 869->872).
866
867 Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
868
869
872 Builder.Insert(EndResult);
873 Builder.Insert(NewInduction);
874 } else {
875
876 EndResult = NewInduction =
877 tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
878 }
// The intrinsic's writeback output becomes the phi's latch value.
882 Phi->setIncomingValue(IncrementIndex, NewInduction);
883
884 return EndResult;
885}
886
// Hoists a loop-invariant add on the offsets out of the loop by folding it
// into the phi's start value, then rebuilding the phi's incoming list.
887void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
888 Value *OffsSecondOperand,
889 unsigned StartIndex) {
890 LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
// NOTE(review): the InsertionPoint declaration and the NewIndex binary-op
// creation were partially lost in extraction (gaps 890->892 and 893->895).
892 Phi->getIncomingBlock(StartIndex)->back().getIterator();
893
895 Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
896 "PushedOutAdd", InsertionPoint);
897 unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;
898
899
// Re-add both edges with the updated start value, then drop the two stale
// original entries (indices shift after each removal).
900 Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
901 Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
902 Phi->getIncomingBlock(IncrementIndex));
903 Phi->removeIncomingValue(1);
904 Phi->removeIncomingValue((unsigned)0);
905}
906
// Hoists a loop-invariant mul/shl on the offsets out of the loop: the start
// value is multiplied once outside the loop and the per-iteration increment
// is replaced by an add of the pre-multiplied step ("Product").
907void MVEGatherScatterLowering::pushOutMulShl(unsigned Opcode, PHINode *&Phi,
908 Value *IncrementPerRound,
909 Value *OffsSecondOperand,
910 unsigned LoopIncrement,
912 LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");
913
914
915
// NOTE(review): the InsertionPoint declaration and the BinaryOperator
// creations for StartIndex/Product were partially lost in extraction
// (gaps 915->917, 920->922 and 924->927).
917 Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back().getIterator();
918
919
920 Value *StartIndex =
922 Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
923 OffsSecondOperand, "PushedOutMul", InsertionPoint);
924
927 OffsSecondOperand, "Product", InsertionPoint);
928
// Insert the new increment just before the latch terminator.
930 Phi->getIncomingBlock(LoopIncrement)->back().getIterator();
931 NewIncrInsertPt = std::prev(NewIncrInsertPt);
932
933
935 Instruction::Add, Phi, Product, "IncrementPushedOutMul", NewIncrInsertPt);
936
// Rebuild the phi's incoming list, then drop the two stale entries.
937 Phi->addIncoming(StartIndex,
938 Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
939 Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
940 Phi->removeIncomingValue((unsigned)0);
941 Phi->removeIncomingValue((unsigned)0);
942}
943
// static bool hasAllGatScatUsers(Instruction *I, const DataLayout &DL) —
// signature lost in extraction; name and definition line 946 confirmed by
// the index below. Returns true when every transitive user of I is a
// gather/scatter (possibly through add/mul/shl arithmetic), so rewriting
// I's value cannot affect unrelated code.
944
945
947 if (I->use_empty()) {
948 return false;
949 }
950 bool Gatscat = true;
951 for (User *U : I->users()) {
// NOTE(review): the user classification (gather/scatter check and the
// recursion) was lost in extraction (gaps 951->953, 953->956, 957->959,
// 960->963).
953 return false;
956 return Gatscat;
957 } else {
959 if ((OpCode == Instruction::Add || OpCode == Instruction::Mul ||
960 OpCode == Instruction::Shl ||
963 continue;
964 }
965 return false;
966 }
967 }
968 return Gatscat;
969}
970
// Rewrites offsets of the form "phi op invariant" (op in {add, or-as-add,
// mul, shl}) so the invariant operation is pushed out of the loop via
// pushOutAdd/pushOutMulShl, leaving a plain incrementing phi the
// incrementing gather/scatter lowering can consume.
971bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
972 LoopInfo *LI) {
973 LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize: "
974 << *Offsets << "\n");
975
976
// NOTE(review): the dyn_cast producing Offs and the add/or acceptance test
// were lost in extraction (gaps 976->978 and 978->981).
978 return false;
981 Offs->getOpcode() != Instruction::Mul &&
982 Offs->getOpcode() != Instruction::Shl)
983 return false;
985 if (L == nullptr)
986 return false;
989 return false;
990 }
991
992
993
// Locate the phi operand of Offs and remember which operand index holds the
// loop-invariant second operand (several matching branches lost in the
// scrape, gaps around 995-1017).
994 PHINode *Phi;
995 int OffsSecondOp;
998 OffsSecondOp = 1;
1001 OffsSecondOp = 0;
1002 } else {
1011 return false;
1014 OffsSecondOp = 1;
1017 OffsSecondOp = 0;
1018 } else {
1019 return false;
1020 }
1021 }
1022
1023
1024 if (Phi->getParent() != L->getHeader())
1025 return false;
1026
1027
// The phi must be a simple induction: incremented by an Add each round
// (matchSimpleRecurrence-style extraction lost, gaps 1028->1031).
1028 BinaryOperator *IncInstruction;
1031 IncInstruction->getOpcode() != Instruction::Add)
1032 return false;
1033
1034 int IncrementingBlock = Phi->getIncomingValue(0) == IncInstruction ? 0 : 1;
1035
1036
1037 Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);
1038
// NOTE(review): the "!L" before isLoopInvariant was dropped by extraction.
1039 if (IncrementPerRound->getType() != OffsSecondOperand->getType() ||
1040 ->isLoopInvariant(OffsSecondOperand))
1041
1042 return false;
1043
1044
1045
1049 return false;
1050
1051
1052
// Reuse the phi when Offs is its only interesting user; otherwise clone a
// fresh phi + increment so other users keep the original recurrence.
1053 PHINode *NewPhi;
1054 if (Phi->hasNUses(2)) {
1055
1056
1057 if (!IncInstruction->hasOneUse()) {
1058
1059
1062 IncrementPerRound, "LoopIncrement", IncInstruction->getIterator());
1063 Phi->setIncomingValue(IncrementingBlock, IncInstruction);
1064 }
1065 NewPhi = Phi;
1066 } else {
1067
1069
1070 NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
1071 Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
1074 IncrementPerRound, "LoopIncrement", IncInstruction->getIterator());
1076 Phi->getIncomingBlock(IncrementingBlock));
1077 IncrementingBlock = 1;
1078 }
1079
1083
// Dispatch on the offset operation being pushed out of the loop.
1085 case Instruction::Add:
1086 case Instruction::Or:
1087 pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
1088 break;
1089 case Instruction::Mul:
1090 case Instruction::Shl:
1091 pushOutMulShl(Offs->getOpcode(), NewPhi, IncrementPerRound,
1092 OffsSecondOperand, IncrementingBlock, Builder);
1093 break;
1094 default:
1095 return false;
1096 }
1097 LLVM_DEBUG(dbgs() << "masked gathers/scatters: simplified loop variable "
1098 << "add/mul\n");
1099
1100
1103
1104
// Clean up the old increment if nothing uses it any more.
1105 if (IncInstruction->use_empty())
1107
1108 return true;
1109}
1110
1112 unsigned ScaleY, IRBuilder<> &Builder) {
1113
1114
1115
1121 uint64_t N = Const->getZExtValue();
1122 if (N < (unsigned)(1 << (TargetElemSize - 1))) {
1123 NonVectorVal = Builder.CreateVectorSplat(
1124 VT->getNumElements(), Builder.getIntN(TargetElemSize, N));
1125 return;
1126 }
1127 }
1128 NonVectorVal =
1129 Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal);
1130 };
1131
1134
1135
1136 if (XElType && !YElType) {
1137 FixSummands(XElType, Y);
1139 } else if (YElType && !XElType) {
1140 FixSummands(YElType, X);
1142 }
1143 assert(XElType && YElType && "Unknown vector types");
1144
1145 if (XElType != YElType) {
1146 LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n");
1147 return nullptr;
1148 }
1149
1151
1152
1155 if (!ConstX || !ConstY)
1156 return nullptr;
1157 unsigned TargetElemSize = 128 / XElType->getNumElements();
1158 for (unsigned i = 0; i < XElType->getNumElements(); i++) {
1163 if (!ConstXEl || !ConstYEl ||
1166 (unsigned)(1 << (TargetElemSize - 1)))
1167 return nullptr;
1168 }
1169 }
1170
1171 Value *XScale = Builder.CreateVectorSplat(
1174 Value *YScale = Builder.CreateVectorSplat(
1177 Value *Add = Builder.CreateAdd(Builder.CreateMul(X, XScale),
1178 Builder.CreateMul(Y, YScale));
1179
1181 return Add;
1182 else
1183 return nullptr;
1184}
1185
// Recursively folds a chain of GEPs into a single base plus one combined
// offset vector (built via CheckAndCreateOffsetAdd); on success the caller
// receives the innermost base, Offsets and Scale == 1.
1186Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP,
1187 Value *&Offsets, unsigned &Scale,
1189 Value *GEPPtr = GEP->getPointerOperand();
// NOTE(review): the Offsets initialisation and the 2-operand guard were
// lost in extraction (gaps 1189->1191 and 1193->1195).
1191 Scale = DL->getTypeAllocSize(GEP->getSourceElementType());
1192
1193
1195 return nullptr;
1197
// Recurse into a GEP base (BaseGEP dyn_cast lost, gap 1195->1197/1198).
1198 Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Scale, Builder);
1199 if (!BaseBasePtr)
1200 return nullptr;
1202 Offsets, Scale, GEP->getOperand(1),
1203 DL->getTypeAllocSize(GEP->getSourceElementType()), Builder);
1204 if (Offsets == nullptr)
1205 return nullptr;
// Scales are folded into the combined offsets, so the result is unscaled.
1206 Scale = 1;
1207 return BaseBasePtr;
1208 }
1209 return GEPPtr;
1210}
1211
// Pre-lowering address optimisation: merges GEP chains via foldGEP into a
// single "gep.merged" and then runs optimiseOffsets on the offset operand.
1212bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB,
1213 LoopInfo *LI) {
// NOTE(review): the condition below lost its operand in extraction —
// presumably the dyn_cast of Address to GetElementPtrInst (original line
// 1214 missing); confirm against upstream.
1215 if ()
1216 return false;
1223 unsigned Scale;
1224 Value *Base = foldGEP(GEP, Offsets, Scale, Builder);
1225
1226
1227
1228
1230 assert(Scale == 1 && "Expected to fold GEP to a scale of 1");
// Creation of the replacement GEP lost (gaps 1230->1236).
1236 "gep.merged", GEP->getIterator());
1238 << "\n new : " << *NewAddress << "\n");
1239 GEP->replaceAllUsesWith(
1241 GEP = NewAddress;
1243 }
1244 }
1245 Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI);
1247}
1248
// Pass entry point: bails out unless the subtarget has MVE integer ops, then
// (1) collects masked gathers/scatters and optimises their addresses, and
// (2) lowers each collected intrinsic via lowerGather/lowerScatter.
1249bool MVEGatherScatterLowering::runOnFunction(Function &F) {
// Early-out when masked gather/scatter generation is disabled (the
// EnableMaskedGatherScatters test on original line 1250 was lost).
1251 return false;
1252 auto &TPC = getAnalysis();
1253 auto &TM = TPC.getTM();
1254 auto *ST = &TM.getSubtarget(F);
// NOTE(review): the "!ST" before hasMVEIntegerOps was dropped by extraction.
1255 if (->hasMVEIntegerOps())
1256 return false;
1257 LI = &getAnalysis().getLoopInfo();
1261
1263
// First sweep: record the intrinsics and optimise their address operands
// (Gathers/Scatters SmallVector declarations lost, gaps 1261->1263).
1264 for (BasicBlock &BB : F) {
1266
1267 for (Instruction &I : BB) {
1269 if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
1272 Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI);
1273 } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
1276 Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI);
1277 }
1278 }
1279 }
// Second sweep: lower the collected intrinsics.
1280 for (IntrinsicInst *I : Gathers) {
1282 if (L == nullptr)
1283 continue;
1284
1285
1288 }
1289
1290 for (IntrinsicInst *I : Scatters) {
1292 if (S == nullptr)
1293 continue;
1294
1295
1298 }
1300}
assert(UImm && (UImm != ~static_cast<T>(0)) && "Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
cl::opt< bool > EnableMaskedGatherScatters
This file contains the declarations for the subclasses of Constant, which represent the different fla...
static Decomposition decomposeGEP(GEPOperator &GEP, SmallVectorImpl< ConditionTy > &Preconditions, bool IsSigned, const DataLayout &DL)
static bool runOnFunction(Function &F, bool PostInlining)
static bool isAddLikeOr(Instruction *I, const DataLayout &DL)
Definition MVEGatherScatterLowering.cpp:366
static bool hasAllGatScatUsers(Instruction *I, const DataLayout &DL)
Definition MVEGatherScatterLowering.cpp:946
static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount)
Definition MVEGatherScatterLowering.cpp:184
static Value * CheckAndCreateOffsetAdd(Value *X, unsigned ScaleX, Value *Y, unsigned ScaleY, IRBuilder<> &Builder)
Definition MVEGatherScatterLowering.cpp:1111
cl::opt< bool > EnableMaskedGatherScatters("enable-arm-maskedgatscat", cl::Hidden, cl::init(true), cl::desc("Enable the generation of masked gathers and scatters"))
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
static unsigned getNumElements(Type *Ty)
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static TableGen::Emitter::OptClass< SkeletonEmitter > X("gen-skeleton-class", "Generate example skeleton class")
This file describes how to lower LLVM code to machine code.
Target-Independent Code Generator Pass Configuration Options pass.
This pass exposes codegen information to IR-level passes.
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by a pass if, and only if, it does not:
InstListType::iterator iterator
Instruction iterators...
LLVM_ABI LLVMContext & getContext() const
Get the context in which this basic block lives.
BinaryOps getOpcode() const
static LLVM_ABI BinaryOperator * Create(BinaryOps Op, Value *S1, Value *S2, const Twine &Name=Twine(), InsertPosition InsertBefore=nullptr)
Construct a binary instruction, given the opcode and the two operands.
Type * getDestTy() const
Return the destination type, as a convenience.
This is the shared class of boolean and integer constants.
int64_t getSExtValue() const
Return the constant as a 64-bit integer value after it has been sign extended as appropriate for the ...
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
This is an important base class in LLVM.
LLVM_ABI Constant * getAggregateElement(unsigned Elt) const
For aggregates (struct/array/vector) return the constant that corresponds to the specified element if...
A parsed version of the target data layout string in and methods for querying it.
Class to represent fixed width SIMD vectors.
unsigned getNumElements() const
static LLVM_ABI FixedVectorType * get(Type *ElementType, unsigned NumElts)
FunctionPass class - This class is used to implement most global optimizations.
static GetElementPtrInst * Create(Type *PointeeType, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
LLVM_ABI Value * CreateVectorSplat(unsigned NumElts, Value *V, const Twine &Name="")
Return a vector value that contains.
Value * CreateIntToPtr(Value *V, Type *DestTy, const Twine &Name="")
IntegerType * getInt32Ty()
Fetch the type representing a 32-bit integer.
void SetCurrentDebugLocation(DebugLoc L)
Set location information used by debugging information.
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
ConstantInt * getInt32(uint32_t C)
Get a constant 32-bit value.
InstTy * Insert(InstTy *I, const Twine &Name="") const
Insert and return the specified instruction.
Value * CreateBitCast(Value *V, Type *DestTy, const Twine &Name="")
Value * CreateZExt(Value *V, Type *DestTy, const Twine &Name="", bool IsNonNeg=false)
Value * CreatePtrToInt(Value *V, Type *DestTy, const Twine &Name="")
Value * CreateTrunc(Value *V, Type *DestTy, const Twine &Name="", bool IsNUW=false, bool IsNSW=false)
PointerType * getPtrTy(unsigned AddrSpace=0)
Fetch the type representing a pointer.
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
IntegerType * getInt8Ty()
Fetch the type representing an 8-bit integer.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
static LLVM_ABI PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
Pass interface - Implemented by all 'passes'.
static SelectInst * Create(Value *C, Value *S1, Value *S2, const Twine &NameStr="", InsertPosition InsertBefore=nullptr, const Instruction *MDFrom=nullptr)
void push_back(const T &Elt)
bool isVectorTy() const
True if this is an instance of VectorType.
bool isIntOrIntVectorTy() const
Return true if this is an integer type or a vector of integer types.
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
LLVM_ABI Type * getWithNewBitWidth(unsigned NewBitWidth) const
Given an integer or vector type, change the lane bitwidth to NewBitwidth, whilst keeping the old numb...
LLVM_ABI unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
bool hasOneUse() const
Return true if there is exactly one use of this value.
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Type * getElementType() const
const ParentTy * getParent() const
self_iterator getIterator()
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ C
The default llvm calling convention, compatible with C.
bool match(Val *V, const Pattern &P)
cst_pred_ty< is_one > m_One()
Match an integer 1 or a vector with all elements equal to 1.
is_zero m_Zero()
Match any null constant or a vector with all elements equal to 0.
Offsets
Offsets in bytes from the start of the input buffer.
initializer< Ty > init(const Ty &Val)
@ User
could "use" a pointer
NodeAddr< PhiNode * > Phi
friend class Instruction
Iterator for Instructions in a `BasicBlock.
This is an optimization pass for GlobalISel generic memory operations.
LLVM_ABI bool haveNoCommonBitsSet(const WithCache< const Value * > &LHSCache, const WithCache< const Value * > &RHSCache, const SimplifyQuery &SQ)
Return true if LHS and RHS have no common bits set.
FunctionAddr VTableAddr Value
decltype(auto) dyn_cast(const From &Val)
dyn_cast - Return the argument parameter cast to the specified type.
Pass * createMVEGatherScatterLoweringPass()
LLVM_ABI bool SimplifyInstructionsInBlock(BasicBlock *BB, const TargetLibraryInfo *TLI=nullptr)
Scan the specified basic block and try to simplify any instructions in it and recursively delete dead...
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
bool isGatherScatter(IntrinsicInst *IntInst)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa - Return true if the parameter to the template is an instance of one of the template type argu...
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
decltype(auto) cast(const From &Val)
cast - Return the argument parameter cast to the specified type.
void initializeMVEGatherScatterLoweringPass(PassRegistry &)
@ Increment
Incrementally increasing token ID.
This struct is a compact representation of a valid (non-zero power of two) alignment.