LLVM: lib/Target/ARM/MVEGatherScatterLowering.cpp Source File


#include "llvm/IR/IntrinsicsARM.h"
#include

using namespace llvm;

#define DEBUG_TYPE "arm-mve-gather-scatter-lowering"

cl::opt<bool> EnableMaskedGatherScatters(
    "enable-arm-maskedgatscat", cl::Hidden, cl::init(true),
    cl::desc("Enable the generation of masked gathers and scatters"));

namespace {

class MVEGatherScatterLowering : public FunctionPass {
public:
  static char ID;

  explicit MVEGatherScatterLowering() : FunctionPass(ID) {
    initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "MVE gather/scatter lowering";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

private:
  LoopInfo *LI = nullptr;
  const DataLayout *DL;

  bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
                               Align Alignment);

  void lookThroughBitcast(Value *&Ptr);

  Value *decomposePtr(Value *Ptr, Value *&Offsets, int &Scale,
                      FixedVectorType *Ty, Type *MemoryTy,
                      IRBuilder<> &Builder);

  int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);

  std::optional<int64_t> getIfConst(const Value *V);

  std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);

  Instruction *lowerGather(IntrinsicInst *I);
  Instruction *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
                                           Instruction *&Root,
                                           IRBuilder<> &Builder);
  Instruction *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
                                         IRBuilder<> &Builder,
                                         int64_t Increment = 0);
  Instruction *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
                                           IRBuilder<> &Builder,
                                           int64_t Increment = 0);

  Instruction *lowerScatter(IntrinsicInst *I);
  Instruction *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
                                            IRBuilder<> &Builder);
  Instruction *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
                                          IRBuilder<> &Builder,
                                          int64_t Increment = 0);
  Instruction *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
                                            IRBuilder<> &Builder,
                                            int64_t Increment = 0);

  Instruction *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *Ptr,
                                            IRBuilder<> &Builder);

  Instruction *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
                                              Value *Ptr, unsigned TypeScale,
                                              IRBuilder<> &Builder);

  bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI);

  Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, unsigned &Scale,
                 IRBuilder<> &Builder);
  bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);

  void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);

  void pushOutMulShl(unsigned Opc, PHINode *&Phi, Value *IncrementPerRound,
                     Value *OffsSecondOperand, unsigned LoopIncrement,
                     IRBuilder<> &Builder);
};

}

char MVEGatherScatterLowering::ID = 0;

INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
                "MVE gather/scattering lowering pass", false, false)

Pass *llvm::createMVEGatherScatterLoweringPass() {
  return new MVEGatherScatterLowering();
}
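// Returns whether a vector of NumElements lanes of ElemSize bits, with the
// given alignment, is something the MVE gather/scatter instructions can
// access directly: a 128-bit vector of i8/i16/i32 lanes, element-aligned.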

bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
                                                       unsigned ElemSize,
                                                       Align Alignment) {
  if (((NumElements == 4 &&
        (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 16 && ElemSize == 8)) &&
      Alignment >= ElemSize / 8)
    return true;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
                    << "valid alignment or vector type \n");
  return false;
}
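// Check that a vector of offsets can be used by an MVE gather/scatter:
// either the offsets are already <N x i32>, or they are constants that are
// non-negative and small enough to fit in the target element size.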

static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) {
  unsigned TargetElemSize = 128 / TargetElemCount;
  unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType())
                                ->getElementType()
                                ->getScalarSizeInBits();
  if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
    Constant *ConstOff = dyn_cast<Constant>(Offsets);
    if (!ConstOff)
      return false;
    int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
    auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) {
      ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem);
      if (!OConst)
        return false;
      int SExtValue = OConst->getSExtValue();
      if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
        return false;
      return true;
    };
    if (isa<FixedVectorType>(ConstOff->getType())) {
      for (unsigned i = 0; i < TargetElemCount; i++) {
        if (!CheckValueSize(ConstOff->getAggregateElement(i)))
          return false;
      }
    } else {
      if (!CheckValueSize(ConstOff))
        return false;
    }
  }
  return true;
}
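// Decompose an address (a GEP or a plain vector of pointers) into the base
// pointer + vector of offsets + scale form used by the MVE intrinsics.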

Value *MVEGatherScatterLowering::decomposePtr(Value *Ptr, Value *&Offsets,
                                              int &Scale, FixedVectorType *Ty,
                                              Type *MemoryTy,
                                              IRBuilder<> &Builder) {
  if (auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
    if (Value *V = decomposeGEP(Offsets, Ty, GEP, Builder)) {
      Scale =
          computeScale(GEP->getSourceElementType()->getPrimitiveSizeInBits(),
                       MemoryTy->getScalarSizeInBits());
      return Scale == -1 ? nullptr : V;
    }
  }

  FixedVectorType *PtrTy = cast<FixedVectorType>(Ptr->getType());
  if (PtrTy->getNumElements() != 4 || MemoryTy->getScalarSizeInBits() != 32)
    return nullptr;
  Value *Zero = ConstantInt::get(Builder.getInt32Ty(), 0);
  Value *BasePtr = Builder.CreateIntToPtr(Zero, Builder.getPtrTy());
  Offsets = Builder.CreatePtrToInt(
      Ptr, FixedVectorType::get(Builder.getInt32Ty(), 4));
  Scale = 0;
  return BasePtr;
}
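// Check whether a getelementptr can be represented as a scalar base plus a
// vector of 32-bit offsets, inserting any zext/trunc needed to bring the
// offsets up to the expected width.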

Value *MVEGatherScatterLowering::decomposeGEP(Value *&Offsets,
                                              FixedVectorType *Ty,
                                              GetElementPtrInst *GEP,
                                              IRBuilder<> &Builder) {
  if (!GEP) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: no getelementpointer "
                      << "found\n");
    return nullptr;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
                    << " Looking at intrinsic for base + vector of offsets\n");
  Value *GEPPtr = GEP->getPointerOperand();
  Offsets = GEP->getOperand(1);
  if (GEPPtr->getType()->isVectorTy() ||
      !isa<FixedVectorType>(Offsets->getType()))
    return nullptr;

  if (GEP->getNumOperands() != 2) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
                      << " operands. Expanding.\n");
    return nullptr;
  }
  unsigned OffsetsElemCount =
      cast<FixedVectorType>(Offsets->getType())->getNumElements();

  ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets);
  if (ZextOffs)
    Offsets = ZextOffs->getOperand(0);
  FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType());

  if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy())
                           ->getElementType()
                           ->getScalarSizeInBits() != 32)
    if (!checkOffsetSize(Offsets, OffsetsElemCount))
      return nullptr;

  if (Ty != Offsets->getType()) {
    if ((Ty->getElementType()->getScalarSizeInBits() <
         OffsetType->getElementType()->getScalarSizeInBits())) {
      Offsets = Builder.CreateTrunc(Offsets, VectorType::getInteger(Ty));
    } else {
      Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty));
    }
  }

  LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
  return GEPPtr;
}
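// Look through a bitcast of the pointer operand as long as it preserves the
// number of vector lanes.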

void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
  if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
    auto *BCTy = cast<FixedVectorType>(BitCast->getType());
    auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
    if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
      LLVM_DEBUG(dbgs() << "masked gathers/scatters: looking through "
                        << "bitcast\n");
      Ptr = BitCast->getOperand(0);
    }
  }
}
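// Map the GEP element size and the size of the accessed memory elements to
// the scale immediate (0, 1 or 2) accepted by the scaled addressing forms,
// or -1 if no legal scale exists.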

int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
                                           unsigned MemoryElemSize) {
  if (GEPElemSize == 32 && MemoryElemSize == 32)
    return 2;
  else if (GEPElemSize == 16 && MemoryElemSize == 16)
    return 1;
  else if (GEPElemSize == 8)
    return 0;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
                    << "create intrinsic\n");
  return -1;
}
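// If V is a constant splat, or a simple Add/Or/Mul/Shl tree of such splats,
// return its value as a signed 64-bit integer.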

std::optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
  const Constant *C = dyn_cast<Constant>(V);
  if (C && C->getSplatValue())
    return std::optional<int64_t>{C->getUniqueInteger().getSExtValue()};
  if (!isa<Instruction>(V))
    return std::optional<int64_t>{};

  const Instruction *I = cast<Instruction>(V);
  if (I->getOpcode() == Instruction::Add || I->getOpcode() == Instruction::Or ||
      I->getOpcode() == Instruction::Mul ||
      I->getOpcode() == Instruction::Shl) {
    std::optional<int64_t> Op0 = getIfConst(I->getOperand(0));
    std::optional<int64_t> Op1 = getIfConst(I->getOperand(1));
    if (!Op0 || !Op1)
      return std::optional<int64_t>{};
    if (I->getOpcode() == Instruction::Add)
      return std::optional<int64_t>{*Op0 + *Op1};
    if (I->getOpcode() == Instruction::Mul)
      return std::optional<int64_t>{*Op0 * *Op1};
    if (I->getOpcode() == Instruction::Shl)
      return std::optional<int64_t>{*Op0 << *Op1};
    if (I->getOpcode() == Instruction::Or)
      return std::optional<int64_t>{*Op0 | *Op1};
  }
  return std::optional<int64_t>{};
}
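// Returns true if I is an Or that can be treated as an Add because its
// operands have no common bits set.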

static bool isAddLikeOr(Instruction *I, const DataLayout &DL) {
  return I->getOpcode() == Instruction::Or &&
         haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1),
                             SimplifyQuery(DL));
}
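// Split an add into its variable summand and its constant summand (scaled by
// TypeScale), provided the constant is a legal immediate for an incrementing
// gather/scatter: a multiple of 4 in the range [-512, 512].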

std::pair<Value *, int64_t>
MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
  std::pair<Value *, int64_t> ReturnFalse =
      std::pair<Value *, int64_t>(nullptr, 0);

  Instruction *Add = dyn_cast<Instruction>(Inst);
  if (Add == nullptr ||
      (Add->getOpcode() != Instruction::Add && !isAddLikeOr(Add, *DL)))
    return ReturnFalse;

  Value *Summand;
  std::optional<int64_t> Const;

  if ((Const = getIfConst(Add->getOperand(0))))
    Summand = Add->getOperand(1);
  else if ((Const = getIfConst(Add->getOperand(1))))
    Summand = Add->getOperand(0);
  else
    return ReturnFalse;

  int64_t Immediate = *Const << TypeScale;
  if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
    return ReturnFalse;

  return std::pair<Value *, int64_t>(Summand, Immediate);
}
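// Lower a llvm.masked.gather to an MVE gather intrinsic, trying the
// incrementing, base+offsets and vector-of-pointers forms in that order.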

Instruction *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"
                    << *I << "\n");

  auto *Ty = cast<FixedVectorType>(I->getType());
  Value *Ptr = I->getArgOperand(0);
  Align Alignment = I->getParamAlign(0).valueOrOne();
  Value *Mask = I->getArgOperand(1);
  Value *PassThru = I->getArgOperand(2);

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;
  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Instruction *Root = I;
  Instruction *Load = tryCreateIncrementingGatScat(I, Ptr, Builder);
  if (!Load)
    Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
  if (!Load)
    Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
  if (!Load)
    return nullptr;

  if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
    LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
                      << "creating select\n");
    Load = SelectInst::Create(Mask, Load, PassThru);
    Builder.Insert(Load);
  }

  Root->replaceAllUsesWith(Load);
  Root->eraseFromParent();
  if (Root != I)
    I->eraseFromParent();

  LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"
                    << *Load << "\n");
  return Load;
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBase(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    return nullptr;
  Value *Mask = I->getArgOperand(1);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Mask});
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers with "
                    << "writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    return nullptr;
  Value *Mask = I->getArgOperand(1);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Mask});
}
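// Build a gather from a base pointer plus a vector of offsets, folding a
// following sext/zext of the result into the intrinsic when possible.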

Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
    IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
  using namespace PatternMatch;

  Type *MemoryTy = I->getType();
  Type *ResultTy = MemoryTy;

  unsigned Unsigned = 1;

  auto *Extend = Root;
  bool TruncResult = false;
  if (MemoryTy->getPrimitiveSizeInBits() < 128) {
    if (I->hasOneUse()) {
      Instruction *User = cast<Instruction>(*I->users().begin());
      if (isa<SExtInst>(User) &&
          User->getType()->getPrimitiveSizeInBits() == 128) {
        LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
                          << *User << "\n");
        Extend = User;
        ResultTy = User->getType();
        Unsigned = 0;
      } else if (isa<ZExtInst>(User) &&
                 User->getType()->getPrimitiveSizeInBits() == 128) {
        LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
                          << *ResultTy << "\n");
        Extend = User;
        ResultTy = User->getType();
      }
    }

    if (ResultTy->getPrimitiveSizeInBits() < 128 &&
        ResultTy->isIntOrIntVectorTy()) {
      ResultTy = ResultTy->getWithNewBitWidth(
          128 / cast<FixedVectorType>(ResultTy)->getNumElements());
      TruncResult = true;
      LLVM_DEBUG(dbgs() << "masked gathers: Small input type, truncing to: "
                        << *ResultTy << "\n");
    }

    if (ResultTy->getPrimitiveSizeInBits() != 128) {
      LLVM_DEBUG(dbgs() << "masked gathers: Extend needed but not provided "
                           "from the correct type. Expanding\n");
      return nullptr;
    }
  }

  Value *Offsets;
  int Scale;
  Value *BasePtr = decomposePtr(
      Ptr, Offsets, Scale, cast<FixedVectorType>(ResultTy), MemoryTy, Builder);
  if (!BasePtr)
    return nullptr;

  Root = Extend;
  Value *Mask = I->getArgOperand(1);
  Instruction *Load = nullptr;
  if (!match(Mask, m_One()))
    Load = Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset_predicated,
        {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
        {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
  else
    Load = Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset,
        {ResultTy, BasePtr->getType(), Offsets->getType()},
        {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned)});

  if (TruncResult) {
    Load = TruncInst::Create(Instruction::Trunc, Load, MemoryTy);
    Builder.Insert(Load);
  }
  return Load;
}
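// Lower a llvm.masked.scatter to an MVE scatter intrinsic, mirroring the
// strategy used for gathers.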

Instruction *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n"
                    << *I << "\n");

  Value *Input = I->getArgOperand(0);
  Value *Ptr = I->getArgOperand(1);
  Align Alignment = I->getParamAlign(1).valueOrOne();
  auto *Ty = cast<FixedVectorType>(Input->getType());

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;

  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Instruction *Store = tryCreateIncrementingGatScat(I, Ptr, Builder);
  if (!Store)
    Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
  if (!Store)
    Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
  if (!Store)
    return nullptr;

  LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n"
                    << *Store << "\n");
  I->eraseFromParent();
  return Store;
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
    return nullptr;
  }
  Value *Mask = I->getArgOperand(2);
  LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(Increment), Input});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Input, Mask});
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers "
                    << "with writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(Increment), Input});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Input, Mask});
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  Value *Mask = I->getArgOperand(2);
  Type *InputTy = Input->getType();
  Type *MemoryTy = InputTy;

  LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
                    << " to base + vector of offsets\n");

  if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
    Value *PreTrunc = Trunc->getOperand(0);
    Type *PreTruncTy = PreTrunc->getType();
    if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
      Input = PreTrunc;
      InputTy = PreTruncTy;
    }
  }
  bool ExtendInput = false;
  if (InputTy->getPrimitiveSizeInBits() < 128 &&
      InputTy->isIntOrIntVectorTy()) {
    InputTy = InputTy->getWithNewBitWidth(
        128 / cast<FixedVectorType>(InputTy)->getNumElements());
    ExtendInput = true;
    LLVM_DEBUG(dbgs() << "masked scatters: Small input type, will extend:\n"
                      << *Input << "\n");
  }
  if (InputTy->getPrimitiveSizeInBits() != 128) {
    LLVM_DEBUG(dbgs() << "masked scatters: cannot create scatters for "
                         "non-standard input types. Expanding.\n");
    return nullptr;
  }

  Value *Offsets;
  int Scale;
  Value *BasePtr = decomposePtr(
      Ptr, Offsets, Scale, cast<FixedVectorType>(InputTy), MemoryTy, Builder);
  if (!BasePtr)
    return nullptr;

  if (ExtendInput)
    Input = Builder.CreateZExt(Input, InputTy);
  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset_predicated,
        {BasePtr->getType(), Offsets->getType(), Input->getType(),
         Mask->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Mask});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset,
        {BasePtr->getType(), Offsets->getType(), Input->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale)});
}
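// If the offsets are a loop-carried increment, try to build an incrementing
// gather/scatter (optionally with writeback) on the vector base.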

Instruction *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  FixedVectorType *Ty;
  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    Ty = cast<FixedVectorType>(I->getType());
  else
    Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());

  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    return nullptr;
  Loop *L = LI->getLoopFor(I->getParent());
  if (L == nullptr)
    return nullptr;

  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  Value *Offsets;
  Value *BasePtr = decomposeGEP(Offsets, Ty, GEP, Builder);
  if (!BasePtr)
    return nullptr;

  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "wb gather/scatter\n");

  int TypeScale =
      computeScale(DL->getTypeSizeInBits(GEP->getSourceElementType()),
                   DL->getTypeSizeInBits(GEP->getType()) /
                       cast<FixedVectorType>(GEP->getType())->getNumElements());
  if (TypeScale == -1)
    return nullptr;

  if (GEP->hasOneUse()) {
    if (auto *Load = tryCreateIncrementingWBGatScat(I, BasePtr, Offsets,
                                                    TypeScale, Builder))
      return Load;
  }

  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "non-wb gather/scatter\n");

  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;

  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, OffsetsIncoming,
      Builder.CreateVectorSplat(Ty->getNumElements(),
                                Builder.getInt32(TypeScale)),
      "ScaledIndex", I->getIterator());
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          Ty->getNumElements(),
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", I->getIterator());

  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    return tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate);
  else
    return tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate);
}
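// Build the writeback form: the offset phi itself is updated by the
// gather/scatter, so the separate increment in the loop can be removed.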

Instruction *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
    IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
    IRBuilder<> &Builder) {
  Loop *L = LI->getLoopFor(I->getParent());

  PHINode *Phi = dyn_cast<PHINode>(Offsets);
  if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
      Phi->getParent() != L->getHeader() || !Phi->hasNUses(2))
    return nullptr;

  unsigned IncrementIndex =
      Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;

  Offsets = Phi->getIncomingValue(IncrementIndex);

  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;
  if (OffsetsIncoming != Phi)
    return nullptr;

  Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
  unsigned NumElems =
      cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();

  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
      "ScaledIndex",
      Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator());

  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          NumElems,
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex",
      Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator());

  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Sub, OffsetsIncoming,
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
      "PreIncrementStartIndex",
      Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator());
  Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);

  Builder.SetInsertPoint(I);

  Instruction *EndResult;
  Instruction *NewInduction;
  if (I->getIntrinsicID() == Intrinsic::masked_gather) {
    Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
    EndResult = ExtractValueInst::Create(Load, 0, "Gather");
    NewInduction = ExtractValueInst::Create(Load, 1, "GatherIncrement");
    Builder.Insert(EndResult);
    Builder.Insert(NewInduction);
  } else {
    EndResult = NewInduction =
        tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
  }
  Instruction *AddInst = cast<Instruction>(Offsets);
  AddInst->replaceAllUsesWith(Phi);
  AddInst->eraseFromParent();
  Phi->setIncomingValue(IncrementIndex, NewInduction);

  return EndResult;
}

void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
                                          Value *OffsSecondOperand,
                                          unsigned StartIndex) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
  BasicBlock::iterator InsertionPoint =
      Phi->getIncomingBlock(StartIndex)->back().getIterator();

  Instruction *NewIndex = BinaryOperator::Create(
      Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
      "PushedOutAdd", InsertionPoint);
  unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;

  Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
  Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
                   Phi->getIncomingBlock(IncrementIndex));
  Phi->removeIncomingValue(1);
  Phi->removeIncomingValue((unsigned)0);
}

void MVEGatherScatterLowering::pushOutMulShl(unsigned Opcode, PHINode *&Phi,
                                             Value *IncrementPerRound,
                                             Value *OffsSecondOperand,
                                             unsigned LoopIncrement,
                                             IRBuilder<> &Builder) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");

  BasicBlock::iterator InsertionPoint =
      Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back().getIterator();

  Value *StartIndex =
      BinaryOperator::Create((Instruction::BinaryOps)Opcode,
                             Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
                             OffsSecondOperand, "PushedOutMul", InsertionPoint);

  Instruction *Product =
      BinaryOperator::Create((Instruction::BinaryOps)Opcode, IncrementPerRound,
                             OffsSecondOperand, "Product", InsertionPoint);

  BasicBlock::iterator NewIncrInsertPt =
      Phi->getIncomingBlock(LoopIncrement)->back().getIterator();
  NewIncrInsertPt = std::prev(NewIncrInsertPt);

  Instruction *NewIncrement = BinaryOperator::Create(
      Instruction::Add, Phi, Product, "IncrementPushedOutMul", NewIncrInsertPt);

  Phi->addIncoming(StartIndex,
                   Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
  Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
  Phi->removeIncomingValue((unsigned)0);
  Phi->removeIncomingValue((unsigned)0);
}

static bool hasAllGatScatUsers(Instruction *I, const DataLayout &DL) {
  if (I->use_empty()) {
    return false;
  }
  bool Gatscat = true;
  for (User *U : I->users()) {
    if (!isa<Instruction>(U))
      return false;
    if (isa<GetElementPtrInst>(U) ||
        isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
      return Gatscat;
    } else {
      unsigned OpCode = cast<Instruction>(U)->getOpcode();
      if ((OpCode == Instruction::Add || OpCode == Instruction::Mul ||
           OpCode == Instruction::Shl ||
           isAddLikeOr(cast<Instruction>(U), DL)) &&
          hasAllGatScatUsers(cast<Instruction>(U), DL)) {
        continue;
      }
      return false;
    }
  }
  return Gatscat;
}
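// Push loop-invariant parts of an offset calculation (add/mul/shl of a phi)
// out of the loop, so the phi increments by a simple per-iteration amount.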

bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
                                               LoopInfo *LI) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize: "
                    << *Offsets << "\n");

  if (!isa<Instruction>(Offsets))
    return false;
  Instruction *Offs = cast<Instruction>(Offsets);
  if (Offs->getOpcode() != Instruction::Add && !isAddLikeOr(Offs, *DL) &&
      Offs->getOpcode() != Instruction::Mul &&
      Offs->getOpcode() != Instruction::Shl)
    return false;
  Loop *L = LI->getLoopFor(BB);
  if (L == nullptr)
    return false;
    return false;
  }

  PHINode *Phi;
  int OffsSecondOp;
  if (isa<PHINode>(Offs->getOperand(0))) {
    Phi = cast<PHINode>(Offs->getOperand(0));
    OffsSecondOp = 1;
  } else if (isa<PHINode>(Offs->getOperand(1))) {
    Phi = cast<PHINode>(Offs->getOperand(1));
    OffsSecondOp = 0;
  } else {
      return false;
    if (isa<PHINode>(Offs->getOperand(0))) {
      Phi = cast<PHINode>(Offs->getOperand(0));
      OffsSecondOp = 1;
    } else if (isa<PHINode>(Offs->getOperand(1))) {
      Phi = cast<PHINode>(Offs->getOperand(1));
      OffsSecondOp = 0;
    } else {
      return false;
    }
  }

  if (Phi->getParent() != L->getHeader())
    return false;

  BinaryOperator *IncInstruction;
  Value *Start, *IncrementPerRound;
  if (!matchSimpleRecurrence(Phi, IncInstruction, Start, IncrementPerRound) ||
      IncInstruction->getOpcode() != Instruction::Add)
    return false;

  int IncrementingBlock = Phi->getIncomingValue(0) == IncInstruction ? 0 : 1;

  Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);

  if (IncrementPerRound->getType() != OffsSecondOperand->getType() ||
      !L->isLoopInvariant(OffsSecondOperand))
    return false;

    return false;

  PHINode *NewPhi;
  if (Phi->hasNUses(2)) {
    if (!IncInstruction->hasOneUse()) {
          IncrementPerRound, "LoopIncrement", IncInstruction->getIterator());
      Phi->setIncomingValue(IncrementingBlock, IncInstruction);
    }
    NewPhi = Phi;
  } else {
    NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
                        Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
        IncrementPerRound, "LoopIncrement", IncInstruction->getIterator());
                        Phi->getIncomingBlock(IncrementingBlock));
    IncrementingBlock = 1;
  }

  switch (Offs->getOpcode()) {
  case Instruction::Add:
  case Instruction::Or:
    pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
    break;
  case Instruction::Mul:
  case Instruction::Shl:
    pushOutMulShl(Offs->getOpcode(), NewPhi, IncrementPerRound,
                  OffsSecondOperand, IncrementingBlock, Builder);
    break;
  default:
    return false;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: simplified loop variable "
                    << "add/mul\n");

  Offs->replaceAllUsesWith(NewPhi);
  Offs->eraseFromParent();

  if (IncInstruction->use_empty())
    IncInstruction->eraseFromParent();

  return true;
}

static Value *CheckAndCreateOffsetAdd(Value *X, unsigned ScaleX, Value *Y,
                                      unsigned ScaleY, IRBuilder<> &Builder) {
  auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) {
    ConstantInt *Const;
    if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) &&
        VT->getElementType() != NonVectorVal->getType()) {
      unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits();
      uint64_t N = Const->getZExtValue();
      if (N < (unsigned)(1 << (TargetElemSize - 1))) {
        NonVectorVal = Builder.CreateVectorSplat(
            VT->getNumElements(), Builder.getIntN(TargetElemSize, N));
        return;
      }
    }
    NonVectorVal =
        Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal);
  };

  FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType());
  FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType());

  if (XElType && !YElType) {
    FixSummands(XElType, Y);
    YElType = cast<FixedVectorType>(Y->getType());
  } else if (YElType && !XElType) {
    FixSummands(YElType, X);
    XElType = cast<FixedVectorType>(X->getType());
  }
  assert(XElType && YElType && "Unknown vector types");

  if (XElType != YElType) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n");
    return nullptr;
  }

  if (!ConstX || !ConstY)
    return nullptr;
  unsigned TargetElemSize = 128 / XElType->getNumElements();
  for (unsigned i = 0; i < XElType->getNumElements(); i++) {
    if (!ConstXEl || !ConstYEl ||
            (unsigned)(1 << (TargetElemSize - 1)))
      return nullptr;
  }
  }

  Value *XScale = Builder.CreateVectorSplat(
  Value *YScale = Builder.CreateVectorSplat(
  Value *Add = Builder.CreateAdd(Builder.CreateMul(X, XScale),
                                 Builder.CreateMul(Y, YScale));

  if (checkOffsetSize(Add, XElType->getNumElements()))
    return Add;
  else
    return nullptr;
}

Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP,
                                         Value *&Offsets, unsigned &Scale,
                                         IRBuilder<> &Builder) {
  Value *GEPPtr = GEP->getPointerOperand();
  Offsets = GEP->getOperand(1);
  Scale = DL->getTypeAllocSize(GEP->getSourceElementType());

  if (GEP->getNumIndices() != 1 || !isa<Constant>(Offsets))
    return nullptr;
  if (GetElementPtrInst *BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr)) {

    Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Scale, Builder);
    if (!BaseBasePtr)
      return nullptr;
    Offsets = CheckAndCreateOffsetAdd(
        Offsets, Scale, GEP->getOperand(1),
        DL->getTypeAllocSize(GEP->getSourceElementType()), Builder);
    if (Offsets == nullptr)
      return nullptr;
    Scale = 1;
    return BaseBasePtr;
  }
  return GEPPtr;
}
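// Try to fold chained GEPs feeding a gather/scatter address into a single
// base + offsets form, then optimise the resulting offsets.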

bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB,
                                               LoopInfo *LI) {
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address);
  if (!GEP)
    return false;
  bool Changed = false;
  if (GEP->hasOneUse() && isa<GetElementPtrInst>(GEP->getPointerOperand())) {
    IRBuilder<> Builder(GEP->getContext());
    Builder.SetInsertPoint(GEP);
    Builder.SetCurrentDebugLocation(GEP->getDebugLoc());
    Value *Offsets;
    unsigned Scale;
    Value *Base = foldGEP(GEP, Offsets, Scale, Builder);

    if (Base) {
      assert(Scale == 1 && "Expected to fold GEP to a scale of 1");
          "gep.merged", GEP->getIterator());
                        << "\n new : " << *NewAddress << "\n");
      GEP->replaceAllUsesWith(
      GEP = NewAddress;
      Changed = true;
    }
  }
  Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI);
  return Changed;
}
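// Pass entry point: collect the masked gathers/scatters in the function,
// optimise their addresses, and lower each one to MVE intrinsics where
// possible.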

bool MVEGatherScatterLowering::runOnFunction(Function &F) {
  if (!EnableMaskedGatherScatters)
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  if (!ST->hasMVEIntegerOps())
    return false;
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  DL = &F.getDataLayout();
  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;

  bool Changed = false;
  for (BasicBlock &BB : F) {
    Changed |= SimplifyInstructionsInBlock(&BB);

    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
          isa<FixedVectorType>(II->getType())) {
        Gathers.push_back(II);
        Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI);
      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
                 isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
        Scatters.push_back(II);
        Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI);
      }
    }
  }
  for (IntrinsicInst *I : Gathers) {
    Instruction *L = lowerGather(I);
    if (L == nullptr)
      continue;

    SimplifyInstructionsInBlock(L->getParent());
    Changed = true;
  }

  for (IntrinsicInst *I : Scatters) {
    Instruction *S = lowerScatter(I);
    if (S == nullptr)
      continue;

    SimplifyInstructionsInBlock(S->getParent());
    Changed = true;
  }
  return Changed;
}
