LLVM: lib/Target/NVPTX/NVPTXLowerArgs.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

150#include "llvm/IR/IntrinsicsNVPTX.h"

157#include

158#include

159

160#define DEBUG_TYPE "nvptx-lower-args"

161

162using namespace llvm;

163

164namespace {

165class NVPTXLowerArgsLegacyPass : public FunctionPass {

167

168public:

169 static char ID;

// Returns the fixed, human-readable name string for this legacy FunctionPass.
171 StringRef getPassName() const override {

172 return "Lower pointer arguments of CUDA kernels";

173 }

174 void getAnalysisUsage(AnalysisUsage &AU) const override {

176 }

177};

178}

179

180char NVPTXLowerArgsLegacyPass::ID = 1;

181

183 "Lower arguments (NVPTX)", false, false)

187

188

189

190

191

192

193

194

195

196

197

198

199

200

201

202

203

204

205

206

207

208

209

211 bool IsGridConstant) {

213 assert(I && "OldUse must be in an instruction");

214 struct IP {

215 Use *OldUse;

218 };

221

222 auto CloneInstInParamAS = [HasCvtaParam,

223 IsGridConstant](const IP &I) -> Value * {

225 LI->setOperand(0, I.NewParam);

226 return LI;

227 }

231 GEP->getSourceElementType(), I.NewParam, Indices, GEP->getName(),

232 GEP->getIterator());

233 NewGEP->setIsInBounds(GEP->isInBounds());

234 return NewGEP;

235 }

239 BC->getName(), BC->getIterator());

240 }

243 (void)ASC;

244

245 return I.NewParam;

246 }

248 if (MI->getRawSource() == I.OldUse->get()) {

249

252

253 CallInst *B = Builder.CreateMemTransferInst(

254 ID, MI->getRawDest(), MI->getDestAlign(), I.NewParam,

255 MI->getSourceAlign(), MI->getLength(), MI->isVolatile());

256 for (unsigned I : {0, 1})

257 if (uint64_t Bytes = MI->getParamDereferenceableBytes(I))

258 B->addDereferenceableParamAttr(I, Bytes);

259 return B;

260 }

261

262

263 }

264

265 if (HasCvtaParam) {

266 auto GetParamAddrCastToGeneric =

271 };

272 auto *ParamInGenericAS =

273 GetParamAddrCastToGeneric(I.NewParam, I.OldInstruction);

274

275

277 for (auto [Idx, V] : enumerate(PHI->incoming_values())) {

278 if (V.get() == I.OldUse->get())

279 PHI->setIncomingValue(Idx, ParamInGenericAS);

280 }

281 }

283 if (SI->getTrueValue() == I.OldUse->get())

284 SI->setTrueValue(ParamInGenericAS);

285 if (SI->getFalseValue() == I.OldUse->get())

286 SI->setFalseValue(ParamInGenericAS);

287 }

288

289

290

291 if (IsGridConstant) {

293 I.OldUse->set(ParamInGenericAS);

294 return CI;

295 }

297

298 if (SI->getValueOperand() == I.OldUse->get())

299 SI->setOperand(0, ParamInGenericAS);

300 return SI;

301 }

303 if (PI->getPointerOperand() == I.OldUse->get())

304 PI->setOperand(0, ParamInGenericAS);

305 return PI;

306 }

307

308

309 }

310 }

311

313 };

314

315 while (!ItemsToConvert.empty()) {

317 Value *NewInst = CloneInstInParamAS(I);

318

319 if (NewInst && NewInst != I.OldInstruction) {

320

321

322

323 for (Use &U : I.OldInstruction->uses())

325

326 InstructionsToDelete.push_back(I.OldInstruction);

327 }

328 }

329

330

331

332

333

334

335

336

338 I->eraseFromParent();

339}

340

341

342

343

344

345

346

351 const DataLayout &DL = Func->getDataLayout();

352

353 const Align NewArgAlign =

356

357 if (CurArgAlign >= NewArgAlign)

358 return;

359

361 << " instead of " << CurArgAlign.value() << " for " << *Arg

362 << '\n');

363

364 auto NewAlignAttr =

366 Arg->removeAttr(Attribute::Alignment);

367 Arg->addAttr(NewAlignAttr);

368

369 struct Load {

372 };

373

374 struct LoadContext {

375 Value *InitialVal;

377 };

378

380 std::queue Worklist;

381 Worklist.push({ArgInParamAS, 0});

382

383 while (!Worklist.empty()) {

384 LoadContext Ctx = Worklist.front();

385 Worklist.pop();

386

387 for (User *CurUser : Ctx.InitialVal->users()) {

393 APInt OffsetAccumulated =

395

// accumulateConstantOffset() returns true when every GEP index is constant and
// the byte offset was added to OffsetAccumulated. Only GEPs for which that
// FAILS must be skipped: the code below reads OffsetAccumulated and pushes
// Ctx.Offset + Offset, which is meaningful only for a known constant offset.
396 if (!I->accumulateConstantOffset(DL, OffsetAccumulated))

397 continue;

398

401 assert(Offset != OffsetLimit && "Expect Offset less than UINT64_MAX");

402

403 Worklist.push({I, Ctx.Offset + Offset});

404 }

405 }

406 }

407

408 for (Load &CurLoad : Loads) {

409 Align NewLoadAlign(std::gcd(NewArgAlign.value(), CurLoad.Offset));

410 Align CurLoadAlign = CurLoad.Inst->getAlign();

411 CurLoad.Inst->setAlignment(std::max(NewLoadAlign, CurLoadAlign));

412 }

413}

414

415

416

420 IRB.CreateIntrinsic(Intrinsic::nvvm_internal_addrspace_wrap,

422 &Arg, {}, Arg.getName() + ".param");

423

427

428 return ArgInParam;

429}

430

431namespace {

432struct ArgUseChecker : PtrUseVisitor {

433 using Base = PtrUseVisitor;

434

435 bool IsGridConstant;

436

437 SmallPtrSet<Instruction *, 4> Conditionals;

438

439 ArgUseChecker(const DataLayout &DL, bool IsGridConstant)

440 : PtrUseVisitor(DL), IsGridConstant(IsGridConstant) {}

441

442 PtrInfo visitArgPtr(Argument &A) {

443 assert(A.getType()->isPointerTy());

445 IsOffsetKnown = false;

447 PI.reset();

448 Conditionals.clear();

449

451

452 enqueueUsers(A);

453

454

455

456

457 while (!(Worklist.empty() || PI.isAborted())) {

458 UseToVisit ToVisit = Worklist.pop_back_val();

459 U = ToVisit.UseAndIsOffsetKnown.getPointer();

462 Conditionals.insert(I);

465 }

466 if (PI.isEscaped())

467 LLVM_DEBUG(dbgs() << "Argument pointer escaped: " << *PI.getEscapingInst()

468 << "\n");

469 else if (PI.isAborted())

470 LLVM_DEBUG(dbgs() << "Pointer use needs a copy: " << *PI.getAbortingInst()

471 << "\n");

472 LLVM_DEBUG(dbgs() << "Traversed " << Conditionals.size()

473 << " conditionals\n");

474 return PI;

475 }

476

// Visitor hook for a store that uses the tracked argument pointer.
// Records in PI whether the walk must stop: an aborted walk is later reported
// by visitArgPtr as "Pointer use needs a copy", an escaped one as
// "Argument pointer escaped".
477 void visitStoreInst(StoreInst &SI) {

478 // The current use U is the value being stored: mark escaped and aborted.

479 if (U->get() == SI.getValueOperand())

480 return PI.setEscapedAndAborted(&SI);

481 // Otherwise U is another operand of the store (presumably the address, i.e.

482 // a write through the pointer); abort unless the argument is grid-constant.

483 if (!IsGridConstant)

484 return PI.setAborted(&SI);

485 }

486

487 void visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) {

488

490 return PI.setEscapedAndAborted(&ASC);

492 }

493

494 void visitPtrToIntInst(PtrToIntInst &I) {

495 if (IsGridConstant)

496 return;

498 }

499 void visitPHINodeOrSelectInst(Instruction &I) {

501 }

502

// A PHI does not stop the walk: keep following the pointer through the PHI's users.
503 void visitPHINode(PHINode &PN) { enqueueUsers(PN); }

// A select does not stop the walk: keep following the pointer through the select's users.
504 void visitSelectInst(SelectInst &SI) { enqueueUsers(SI); }

505

// Visitor hook for a memory-transfer intrinsic that uses the tracked pointer.
// Aborts the walk only when the current use U is the transfer's raw
// destination and the argument is not grid-constant.
506 void visitMemTransferInst(MemTransferInst &II) {

507 if (*U == II.getRawDest() && !IsGridConstant)

508 PI.setAborted(&II);

509 // Any other use by the transfer (e.g. as its source) leaves PI unchanged,

510 // and nothing further is enqueued from here.

511 }

512

// Visitor hook for a memset that uses the tracked pointer: aborts the walk
// unless the argument is grid-constant.
513 void visitMemSetInst(MemSetInst &II) {

514 if (!IsGridConstant)

515 PI.setAborted(&II);

516 }

517};

518

520 LLVM_DEBUG(dbgs() << "Creating a local copy of " << Arg << "\n");

521

527

528

529

533

535

536

537

538

540 IRB.CreateMemCpy(AllocA, AllocA->getAlign(), ArgInParam, AllocA->getAlign(),

541 ArgSize);

542}

543}

544

550 const DataLayout &DL = Func->getDataLayout();

554

555 ArgUseChecker AUC(DL, IsGridConstant);

556 ArgUseChecker::PtrInfo PI = AUC.visitArgPtr(*Arg);

557 bool ArgUseIsReadOnly = !(PI.isEscaped() || PI.isAborted());

558

559 if (ArgUseIsReadOnly && AUC.Conditionals.empty()) {

560

561

563

566

567 for (Use *U : UsesToUpdate)

568 convertToParamAS(U, ArgInParamAS, HasCvtaParam, IsGridConstant);

569 LLVM_DEBUG(dbgs() << "No need to copy or cast " << *Arg << "\n");

570

571 const auto *TLI =

573

575

576 return;

577 }

578

579

580

581

582

583

584 if (IsGridConstant || (HasCvtaParam && ArgUseIsReadOnly)) {

585 LLVM_DEBUG(dbgs() << "Using non-copy pointer to " << *Arg << "\n");

586

587

588

589 IRBuilder<> IRB(&Func->getEntryBlock().front());

590

591

592

593

594

596

597

600 Arg->getName() + ".gen");

601

603

604

605 ParamSpaceArg->setOperand(0, Arg);

606 } else

607 copyByValParam(*Func, *Arg);

608}

609

612 return;

613

614

617

619 } else {

620

622 assert(InsertPt != InsertPt->getParent()->end() &&

623 "We don't call this function with Ptr being a terminator.");

624 }

625

629 Ptr->getName(), InsertPt);

630

633}

634

638

639

640

641

643

644

645

646

647 auto HandleIntToPtr = [](Value &V) {

648 if (llvm::all_of(V.users(), [](User *U) { return isa(U); })) {

650 for (User *U : UsersToUpdate)

652 }

653 };

655

656 for (auto &B : F) {

657 for (auto &I : B) {

659 if (LI->getType()->isPointerTy() || LI->getType()->isIntegerTy()) {

663

664 if (LI->getType()->isPointerTy())

666 else

667 HandleIntToPtr(*LI);

668 }

669 }

670 }

671 }

672 }

673 }

674 }

675

676 LLVM_DEBUG(dbgs() << "Lowering kernel args of " << F.getName() << "\n");

682 HandleIntToPtr(Arg);

683 }

684 }

685 return true;

686}

687

688

690 LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n");

691

692 const auto *TLI =

694

698

699 return true;

700}

701

706

707bool NVPTXLowerArgsLegacyPass::runOnFunction(Function &F) {

708 auto &TM = getAnalysis().getTM();

710}

712 return new NVPTXLowerArgsLegacyPass();

713}

714

716 LLVM_DEBUG(dbgs() << "Creating a copy of byval args of " << F.getName()

717 << "\n");

723 copyByValParam(F, Arg);

725 }

726 }

728}

729

735

assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")

MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL

static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")

static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")

static bool runOnFunction(Function &F, bool PostInlining)

NVPTX address space definition.

static bool runOnDeviceFunction(const NVPTXTargetMachine &TM, Function &F)

Definition NVPTXLowerArgs.cpp:689

static CallInst * createNVVMInternalAddrspaceWrap(IRBuilder<> &IRB, Argument &Arg)

Definition NVPTXLowerArgs.cpp:417

static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS, const NVPTXTargetLowering *TLI)

Definition NVPTXLowerArgs.cpp:347

static bool copyFunctionByValArgs(Function &F)

Definition NVPTXLowerArgs.cpp:715

static void markPointerAsAS(Value *Ptr, const unsigned AS)

Definition NVPTXLowerArgs.cpp:610

static void convertToParamAS(Use *OldUse, Value *Param, bool HasCvtaParam, bool IsGridConstant)

Definition NVPTXLowerArgs.cpp:210

static bool processFunction(Function &F, NVPTXTargetMachine &TM)

Definition NVPTXLowerArgs.cpp:702

static bool runOnKernelFunction(const NVPTXTargetMachine &TM, Function &F)

Definition NVPTXLowerArgs.cpp:642

static void markPointerAsGlobal(Value *Ptr)

Definition NVPTXLowerArgs.cpp:635

static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg)

Definition NVPTXLowerArgs.cpp:545

uint64_t IntrinsicInst * II

#define INITIALIZE_PASS_DEPENDENCY(depName)

#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)

#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)

This file provides a collection of visitors which walk the (instruction) uses of a pointer.

Target-Independent Code Generator Pass Configuration Options pass.

Class for arbitrary precision integers.

uint64_t getLimitedValue(uint64_t Limit=UINT64_MAX) const

If this value is smaller than the specified limit, return it, otherwise return the limit value.

static APInt getZero(unsigned numBits)

Get the '0' value for the specified bit-width.

This class represents a conversion between pointers from one address space to another.

unsigned getDestAddressSpace() const

Returns the address space of the result.

an instruction to allocate memory on the stack

Align getAlign() const

Return the alignment of the memory that is being allocated by the instruction.

LLVM_ABI std::optional< TypeSize > getAllocationSize(const DataLayout &DL) const

Get allocation size in bytes.

void setAlignment(Align Align)

Represent the analysis usage information of a pass.

AnalysisUsage & addRequired()

This class represents an incoming formal argument to a Function.

LLVM_ABI void addAttr(Attribute::AttrKind Kind)

LLVM_ABI bool hasByValAttr() const

Return true if this argument has the byval attribute.

LLVM_ABI void removeAttr(Attribute::AttrKind Kind)

Remove attributes from an argument.

const Function * getParent() const

LLVM_ABI Type * getParamByValType() const

If this is a byval argument, return its type.

LLVM_ABI MaybeAlign getParamAlign() const

If this is a byval or inalloca argument, return its alignment.

static LLVM_ABI Attribute getWithAlignment(LLVMContext &Context, Align Alignment)

Return a uniquified Attribute object that has the specific alignment set.

iterator begin()

Instruction iterator methods.

InstListType::iterator iterator

Instruction iterators...

void addRetAttr(Attribute::AttrKind Kind)

Adds the attribute to the return value.

This class represents a function call, abstracting a target machine's calling convention.

static LLVM_ABI CastInst * Create(Instruction::CastOps, Value *S, Type *Ty, const Twine &Name="", InsertPosition InsertBefore=nullptr)

Provides a way to construct any of the CastInst subclasses using an opcode instead of the subclass's ...

A parsed version of the target data layout string, and methods for querying it.

FunctionPass class - This class is used to implement most global optimizations.

const BasicBlock & getEntryBlock() const

static GetElementPtrInst * Create(Type *PointeeType, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)

LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")

Create a call to intrinsic ID with Args, mangled using Types.

PointerType * getPtrTy(unsigned AddrSpace=0)

Fetch the type representing a pointer.

Value * CreateAddrSpaceCast(Value *V, Type *DestTy, const Twine &Name="")

This provides a uniform API for creating instructions and inserting them into a basic block: either a...

void visit(Iterator Start, Iterator End)

unsigned getBitWidth() const

Get the number of bits in this IntegerType.

An instruction for reading from memory.

const NVPTXTargetLowering * getTargetLowering() const override

bool hasCvtaParam() const

Align getFunctionParamOptimizedAlign(const Function *F, Type *ArgTy, const DataLayout &DL) const

getFunctionParamOptimizedAlign - since function arguments are passed via .param space,...

NVPTX::DrvInterface getDrvInterface() const

const NVPTXSubtarget * getSubtargetImpl(const Function &) const override

Virtual method implemented by subclasses that returns a reference to that target's TargetSubtargetInf...

static LLVM_ABI PointerType * get(Type *ElementType, unsigned AddressSpace)

This constructs a pointer to an object of the specified type in a numbered address space.

A set of analyses that are preserved following a run of a transformation pass.

static PreservedAnalyses none()

Convenience factory function for the empty preserved set.

static PreservedAnalyses all()

Construct a special preserved set that preserves all passes.

A base class for visitors over the uses of a pointer value.

void visitAddrSpaceCastInst(AddrSpaceCastInst &ASC)

void visitPtrToIntInst(PtrToIntInst &I)

void push_back(const T &Elt)

This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.

StringRef - Represent a constant reference to a string, i.e.

Class to represent struct types.

Target-Independent Code Generator Pass Configuration Options.

The instances of the Type class are immutable: once they are created, they are never changed.

bool isPointerTy() const

True if this is an instance of PointerType.

LLVM_ABI unsigned getPointerAddressSpace() const

Get the address space of this pointer or pointer vector type.

bool isIntegerTy() const

True if this is an instance of IntegerType.

A Use represents the edge between a Value definition and its users.

void setOperand(unsigned i, Value *Val)

LLVM Value Representation.

Type * getType() const

All values are typed, get the type of this value.

LLVM_ABI void replaceAllUsesWith(Value *V)

Change all uses of this to point to a new Value.

iterator_range< user_iterator > users()

LLVM_ABI LLVMContext & getContext() const

All values hold a context through their type.

iterator_range< use_iterator > uses()

LLVM_ABI StringRef getName() const

Return a constant reference to the value's name.

#define llvm_unreachable(msg)

Marks that the current location is not supposed to be reachable.

unsigned ID

LLVM IR allows to use arbitrary numbers as calling convention identifiers.

friend class Instruction

Iterator for Instructions in a `BasicBlock`.

This is an optimization pass for GlobalISel generic memory operations.

bool all_of(R &&range, UnaryPredicate P)

Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.

auto enumerate(FirstRange &&First, RestRanges &&...Rest)

Given two or more input ranges, returns a new range whose values are tuples (A, B,...

decltype(auto) dyn_cast(const From &Val)

dyn_cast - Return the argument parameter cast to the specified type.

FunctionPass * createNVPTXLowerArgsPass()

Definition NVPTXLowerArgs.cpp:711

auto reverse(ContainerTy &&C)

LLVM_ABI raw_ostream & dbgs()

dbgs() - This returns a reference to a raw_ostream for debugging messages.

bool isa(const From &Val)

isa<X> - Return true if the parameter to the template is an instance of one of the template type arguments.

bool isParamGridConstant(const Argument &Arg)

bool isKernelFunction(const Function &F)

decltype(auto) cast(const From &Val)

cast - Return the argument parameter cast to the specified type.

iterator_range< pointer_iterator< WrappedIteratorT > > make_pointer_range(RangeT &&Range)

AnalysisManager< Function > FunctionAnalysisManager

Convenience typedef for the Function analysis manager.

LLVM_ABI const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=MaxLookupSearchDepth)

This method strips off any GEP address adjustments, pointer casts or llvm.threadlocal....

This struct is a compact representation of a valid (non-zero power of two) alignment.

constexpr uint64_t value() const

This is a hole in the type system and should not be abused.

This struct is a compact representation of a valid (power of two) or undefined (0) alignment.

Align valueOrOne() const

For convenience, returns a valid alignment or 1 if undefined.

PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)

Definition NVPTXLowerArgs.cpp:730

PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)

Definition NVPTXLowerArgs.cpp:736