LLVM: lib/Target/RISCV/RISCVGatherScatterLowering.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

26#include <optional>

27

28using namespace llvm;

30

31#define DEBUG_TYPE "riscv-gather-scatter-lowering"

32

33namespace {

34

35class RISCVGatherScatterLowering : public FunctionPass {

40

42

43

44

45

47

48public:

49 static char ID;

50

52

54

55 void getAnalysisUsage(AnalysisUsage &AU) const override {

59 }

60

61 StringRef getPassName() const override {

62 return "RISC-V gather/scatter lowering";

63 }

64

65private:

67

68 std::pair<Value *, Value *> determineBaseAndStride(Instruction *Ptr,

70

71 bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,

74};

75

76}

77

78char RISCVGatherScatterLowering::ID = 0;

79

81 "RISC-V gather/scatter lowering pass", false, false)

82

84 return new RISCVGatherScatterLowering();

85}

86

87

90 return std::make_pair(nullptr, nullptr);

91

93

94

95 auto *StartVal =

97 if (!StartVal)

98 return std::make_pair(nullptr, nullptr);

99 APInt StrideVal(StartVal->getValue().getBitWidth(), 0);

101 for (unsigned i = 1; i != NumElts; ++i) {

103 if (!C)

104 return std::make_pair(nullptr, nullptr);

105

106 APInt LocalStride = C->getValue() - Prev->getValue();

107 if (i == 1)

108 StrideVal = LocalStride;

109 else if (StrideVal != LocalStride)

110 return std::make_pair(nullptr, nullptr);

111

112 Prev = C;

113 }

114

115 Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);

116

117 return std::make_pair(StartVal, Stride);

118}

119

122

124 if (StartC)

126

127

129 auto *Ty = Start->getType()->getScalarType();

130 return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));

131 }

132

133

134

136 if (!BO || (BO->getOpcode() != Instruction::Add &&

137 BO->getOpcode() != Instruction::Or &&

138 BO->getOpcode() != Instruction::Shl &&

139 BO->getOpcode() != Instruction::Mul))

140 return std::make_pair(nullptr, nullptr);

141

142 if (BO->getOpcode() == Instruction::Or &&

144 return std::make_pair(nullptr, nullptr);

145

146

147 unsigned OtherIndex = 0;

151 OtherIndex = 1;

152 }

154 return std::make_pair(nullptr, nullptr);

155

157 std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex),

158 Builder);

159 if (!Start)

160 return std::make_pair(nullptr, nullptr);

161

162 Builder.SetInsertPoint(BO);

163 Builder.SetCurrentDebugLocation(DebugLoc());

164

165

166 switch (BO->getOpcode()) {

167 default:

169 case Instruction::Or:

170 Start = Builder.CreateOr(Start, Splat, "", true);

171 break;

172 case Instruction::Add:

173 Start = Builder.CreateAdd(Start, Splat);

174 break;

175 case Instruction::Mul:

176 Start = Builder.CreateMul(Start, Splat);

177 Stride = Builder.CreateMul(Stride, Splat);

178 break;

179 case Instruction::Shl:

180 Start = Builder.CreateShl(Start, Splat);

181 Stride = Builder.CreateShl(Stride, Splat);

182 break;

183 }

184

185 return std::make_pair(Start, Stride);

186}

187

188

189

190

191

192bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,

197

199

200

201 if (Phi->getParent() != L->getHeader())

202 return false;

203

206 Inc->getOpcode() != Instruction::Add)

207 return false;

208 assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");

209 unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;

210 assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&

211 "Expected one operand of phi to be Inc");

212

213

215 if (!Step)

216 return false;

217

219 if (!Start)

220 return false;

221 assert(Stride != nullptr);

222

223

226 Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",

228 BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));

229 BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));

230

231

232 MaybeDeadPHIs.push_back(Phi);

233 return true;

234 }

235

236

238 if (!BO)

239 return false;

240

241 switch (BO->getOpcode()) {

242 default:

243 return false;

244 case Instruction::Or:

245

247 return false;

248 break;

249 case Instruction::Add:

250 break;

251 case Instruction::Shl:

252 break;

253 case Instruction::Mul:

254 break;

255 }

256

257

262 OtherOp = BO->getOperand(1);

267 OtherOp = BO->getOperand(0);

268 } else {

269 return false;

270 }

271

272

273 if (!L->isLoopInvariant(OtherOp))

274 return false;

275

276

278 if (!SplatOp)

279 return false;

280

281

282 if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))

283 return false;

284

285

287 unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;

290

291

293 BasePtr->getIncomingBlock(StartBlock)->getTerminator());

295

296

297 switch (BO->getOpcode()) {

298 default:

300 case Instruction::Add:

301 case Instruction::Or: {

302

303

305 break;

306 }

307 case Instruction::Mul: {

309 Stride = Builder.CreateMul(Stride, SplatOp, "stride");

310 break;

311 }

312 case Instruction::Shl: {

314 Stride = Builder.CreateShl(Stride, SplatOp, "stride");

315 break;

316 }

317 }

318

319

320

322 Builder.SetInsertPoint(*StepI->getInsertionPointAfterDef());

323

324 switch (BO->getOpcode()) {

325 default:

326 break;

327 case Instruction::Mul:

328 Step = Builder.CreateMul(Step, SplatOp, "step");

329 break;

330 case Instruction::Shl:

331 Step = Builder.CreateShl(Step, SplatOp, "step");

332 break;

333 }

334

336 BasePtr->setIncomingValue(StartBlock, Start);

337 return true;

338}

339

340std::pair<Value *, Value *>

341RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,

342 IRBuilderBase &Builder) {

343

344

346 Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());

347 return std::make_pair(BasePtr, ConstantInt::get(IntPtrTy, 0));

348 }

349

351 if (!GEP)

352 return std::make_pair(nullptr, nullptr);

353

354 auto I = StridedAddrs.find(GEP);

355 if (I != StridedAddrs.end())

356 return I->second;

357

358 SmallVector<Value *, 2> Ops(GEP->operands());

359

360

363 BaseInst && BaseInst->getType()->isVectorTy()) {

364

365 auto IsScalar = [](Value *Idx) { return !Idx->getType()->isVectorTy(); };

366 if (all_of(GEP->indices(), IsScalar)) {

367 auto [BaseBase, Stride] = determineBaseAndStride(BaseInst, Builder);

368 if (BaseBase) {

371 Value *OffsetBase =

372 Builder.CreateGEP(GEP->getSourceElementType(), BaseBase, Indices,

373 GEP->getName() + "offset", GEP->isInBounds());

374 return {OffsetBase, Stride};

375 }

376 }

377 }

378

379

383 if (!ScalarBase)

384 return std::make_pair(nullptr, nullptr);

385 }

386

387 std::optional<unsigned> VecOperand;

388 unsigned TypeScale = 0;

389

390

392 for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {

393 if (!Ops[i]->getType()->isVectorTy())

394 continue;

395

396 if (VecOperand)

397 return std::make_pair(nullptr, nullptr);

398

399 VecOperand = i;

400

403 return std::make_pair(nullptr, nullptr);

404

406 }

407

408

409 if (!VecOperand)

410 return std::make_pair(nullptr, nullptr);

411

412

413

414

415

416

417 Value *VecIndex = Ops[*VecOperand];

418 Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());

419 if (VecIndex->getType() != VecIntPtrTy) {

421 if (!VecIndexC)

422 return std::make_pair(nullptr, nullptr);

425 else

427 }

428

429

430

432 if (Start) {

435

436

438 Type *SourceTy = GEP->getSourceElementType();

441

442

443 Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());

444 assert(Stride->getType() == IntPtrTy && "Unexpected type");

445

446

447 if (TypeScale != 1)

448 Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

449

450 auto P = std::make_pair(BasePtr, Stride);

451 StridedAddrs[GEP] = P;

452 return P;

453 }

454

455

457 if (!L || !L->getLoopPreheader() || !L->getLoopLatch())

458 return std::make_pair(nullptr, nullptr);

459

460 BinaryOperator *Inc;

461 PHINode *BasePhi;

462 if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))

463 return std::make_pair(nullptr, nullptr);

464

466 unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;

468 "Expected one operand of phi to be Inc");

469

471

472

473 Ops[*VecOperand] = BasePhi;

474 Type *SourceTy = GEP->getSourceElementType();

477

478

481

482

483 Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());

484 assert(Stride->getType() == IntPtrTy && "Unexpected type");

485

486

487 if (TypeScale != 1)

488 Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

489

490 auto P = std::make_pair(BasePtr, Stride);

491 StridedAddrs[GEP] = P;

492 return P;

493}

494

495bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II) {

497 Value *StoreVal = nullptr, *Ptr, *Mask, *EVL = nullptr;

499 switch (II->getIntrinsicID()) {

500 case Intrinsic::masked_gather:

502 Ptr = II->getArgOperand(0);

503 Alignment = II->getParamAlign(0).valueOrOne();

504 Mask = II->getArgOperand(1);

505 break;

506 case Intrinsic::vp_gather:

508 Ptr = II->getArgOperand(0);

509

510 Alignment = II->getParamAlign(0).value_or(

511 DL->getABITypeAlign(DataType->getElementType()));

512 Mask = II->getArgOperand(1);

513 EVL = II->getArgOperand(2);

514 break;

515 case Intrinsic::masked_scatter:

517 StoreVal = II->getArgOperand(0);

518 Ptr = II->getArgOperand(1);

519 Alignment = II->getParamAlign(1).valueOrOne();

520 Mask = II->getArgOperand(2);

521 break;

522 case Intrinsic::vp_scatter:

524 StoreVal = II->getArgOperand(0);

525 Ptr = II->getArgOperand(1);

526

527 Alignment = II->getParamAlign(1).value_or(

528 DL->getABITypeAlign(DataType->getElementType()));

529 Mask = II->getArgOperand(2);

530 EVL = II->getArgOperand(3);

531 break;

532 default:

534 }

535

536

539 return false;

540

541

543 return false;

544

545

547 if (!PtrI)

548 return false;

549

550 LLVMContext &Ctx = PtrI->getContext();

551 IRBuilder Builder(Ctx, InstSimplifyFolder(*DL));

553

555 std::tie(BasePtr, Stride) = determineBaseAndStride(PtrI, Builder);

556 if (!BasePtr)

557 return false;

558 assert(Stride != nullptr);

559

561

562 if (!EVL)

565

567

568 if (!StoreVal) {

570 Intrinsic::experimental_vp_strided_load,

573

574

575 if (II->getIntrinsicID() == Intrinsic::masked_gather)

577 } else

579 Intrinsic::experimental_vp_strided_store,

582

584 II->replaceAllUsesWith(Call);

585 II->eraseFromParent();

586

587 if (PtrI->use_empty())

589

590 return true;

591}

592

593bool RISCVGatherScatterLowering::runOnFunction(Function &F) {

594 if (skipFunction(F))

595 return false;

596

597 auto &TPC = getAnalysis<TargetPassConfig>();

598 auto &TM = TPC.getTM<RISCVTargetMachine>();

599 ST = &TM.getSubtarget<RISCVSubtarget>(F);

600 if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())

601 return false;

602

603 TLI = ST->getTargetLowering();

604 DL = &F.getDataLayout();

605 LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();

606

607 StridedAddrs.clear();

608

610

612

613 for (BasicBlock &BB : F) {

614 for (Instruction &I : BB) {

616 if (!II)

617 continue;

618 switch (II->getIntrinsicID()) {

619 case Intrinsic::masked_gather:

620 case Intrinsic::masked_scatter:

621 case Intrinsic::vp_gather:

622 case Intrinsic::vp_scatter:

624 break;

625 default:

626 break;

627 }

628 }

629 }

630

631

632 for (auto *II : Worklist)

633 Changed |= tryCreateStridedLoadStore(II);

634

635

636 while (!MaybeDeadPHIs.empty()) {

639 }

640

642}

assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")

MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL

static bool runOnFunction(Function &F, bool PostInlining)

const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]

uint64_t IntrinsicInst * II

#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)

static std::pair< Value *, Value * > matchStridedStart(Value *Start, IRBuilderBase &Builder)

Definition RISCVGatherScatterLowering.cpp:120

static std::pair< Value *, Value * > matchStridedConstant(Constant *StartC)

Definition RISCVGatherScatterLowering.cpp:88

static SymbolRef::Type getType(const Symbol *Sym)

Target-Independent Code Generator Pass Configuration Options pass.

Class for arbitrary precision integers.

Represent the analysis usage information of a pass.

AnalysisUsage & addRequired()

LLVM_ABI void setPreservesCFG()

This function should be called by the pass, iff they do not:

const Instruction * getTerminator() const LLVM_READONLY

Returns the terminator instruction if the block is well formed or null if the block is not well formed.

BinaryOps getOpcode() const

This is the shared class of boolean and integer constants.

const APInt & getValue() const

Return the constant as an APInt value reference.

This is an important base class in LLVM.

LLVM_ABI Constant * getAggregateElement(unsigned Elt) const

For aggregates (struct/array/vector) return the constant that corresponds to the specified element if possible, or null if not.

A parsed version of the target data layout string in and methods for querying it.

FunctionPass class - This class is used to implement most global optimizations.

Common base class shared among various IRBuilders.

LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)

IntegerType * getInt32Ty()

Fetch the type representing a 32-bit integer.

void SetCurrentDebugLocation(DebugLoc L)

Set location information used by debugging information.

Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())

LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")

Create a call to intrinsic ID with Args, mangled using Types.

Value * CreateShl(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)

Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)

void SetInsertPoint(BasicBlock *TheBB)

This specifies that created instructions should be appended to the end of the specified block.

Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)

LLVM_ABI Value * CreateElementCount(Type *Ty, ElementCount EC)

Create an expression which evaluates to the number of elements in EC at runtime.

LLVM_ABI bool isCommutative() const LLVM_READONLY

Return true if the instruction is commutative:

A wrapper class for inspecting calls to intrinsic functions.

LoopT * getLoopFor(const BlockT *BB) const

Return the inner most loop that BB lives in.

The legacy pass manager's analysis pass to compute loop information.

Represents a single loop in the control flow graph.

BasicBlock * getIncomingBlock(unsigned i) const

Return incoming basic block number i.

Value * getIncomingValue(unsigned i) const

Return incoming value number x.

unsigned getNumIncomingValues() const

Return the number of incoming edges.

static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)

Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will have (use 0 if you really have no idea).

bool isLegalStridedLoadStore(EVT DataType, Align Alignment) const

Return true if a stride load store of the given result type and alignment is legal.

void push_back(const T &Elt)

This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.

StringRef - Represent a constant reference to a string, i.e.

EVT getValueType(const DataLayout &DL, Type *Ty, bool AllowUnknown=false) const

Return the EVT corresponding to this LLVM type.

bool isTypeLegal(EVT VT) const

Return true if the target has native support for the specified value type.

Target-Independent Code Generator Pass Configuration Options.

bool isVectorTy() const

True if this is an instance of VectorType.

LLVM_ABI unsigned getScalarSizeInBits() const LLVM_READONLY

If this is a vector type, return the getPrimitiveSizeInBits value for the element type.

void setOperand(unsigned i, Value *Val)

Value * getOperand(unsigned i) const

LLVM Value Representation.

Type * getType() const

All values are typed, get the type of this value.

LLVM_ABI StringRef getName() const

Return a constant reference to the value's name.

LLVM_ABI void takeName(Value *V)

Transfer the name from V to this value.

constexpr ScalarTy getFixedValue() const

constexpr bool isScalable() const

Returns whether the quantity is scaled by a runtime quantity (vscale).

TypeSize getSequentialElementStride(const DataLayout &DL) const

self_iterator getIterator()

#define llvm_unreachable(msg)

Marks that the current location is not supposed to be reachable.

constexpr char Align[]

Key for Kernel::Arg::Metadata::mAlign.

constexpr std::underlying_type_t< E > Mask()

Get a bitmask with 1s in all places up to the high-order bit of E's largest value.

unsigned ID

LLVM IR allows to use arbitrary numbers as calling convention identifiers.

@ C

The default llvm calling convention, compatible with C.

bool match(Val *V, const Pattern &P)

IntrinsicID_match m_Intrinsic()

Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))

NodeAddr< PhiNode * > Phi

This is an optimization pass for GlobalISel generic memory operations.

FunctionAddr VTableAddr Value

bool all_of(R &&range, UnaryPredicate P)

Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.

LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())

If the specified value is a trivially dead instruction, delete it.

decltype(auto) dyn_cast(const From &Val)

dyn_cast - Return the argument parameter cast to the specified type.

LLVM_ABI Value * getSplatValue(const Value *V)

Get splat value if the input is a splat vector or return nullptr.

FunctionPass * createRISCVGatherScatterLoweringPass()

LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)

Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...

auto dyn_cast_or_null(const Y &Val)

generic_gep_type_iterator<> gep_type_iterator

class LLVM_GSL_OWNER SmallVector

Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...

bool isa(const From &Val)

isa - Return true if the parameter to the template is an instance of one of the template type arguments.

IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >

ArrayRef(const T &OneElt) -> ArrayRef< T >

decltype(auto) cast(const From &Val)

cast - Return the argument parameter cast to the specified type.

gep_type_iterator gep_type_begin(const User *GEP)

LLVM_ABI bool RecursivelyDeleteDeadPHINode(PHINode *PN, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr)

If the specified value is an effectively dead PHI node, due to being a def-use chain of single-use no...

LLVM_ABI Constant * ConstantFoldCastInstruction(unsigned opcode, Constant *V, Type *DestTy)