LLVM: lib/Target/AMDGPU/AMDGPURewriteOutArguments.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

56

57#define DEBUG_TYPE "amdgpu-rewrite-out-arguments"

58

59using namespace llvm;

60

62 "amdgpu-any-address-space-out-arguments",

63 cl::desc("Replace pointer out arguments with "

64 "struct returns for non-private address space"),

67

69 "amdgpu-max-return-arg-num-regs",

70 cl::desc("Approximately limit number of return registers for replacing out arguments"),

73

75 "Number out arguments moved to struct return values");

77 "Number of functions with out arguments moved to struct return values");

78

79namespace {

80

81class AMDGPURewriteOutArguments : public FunctionPass {

82private:

85

86 Type *getStoredType(Value &Arg) const;

87 Type *getOutArgumentType(Argument &Arg) const;

88

89public:

90 static char ID;

91

93

94 void getAnalysisUsage(AnalysisUsage &AU) const override {

95 AU.addRequired();

96 FunctionPass::getAnalysisUsage(AU);

97 }

98

99 bool doInitialization(Module &M) override;

101};

102

103}

104

106 "AMDGPU Rewrite Out Arguments", false, false)

110

111char AMDGPURewriteOutArguments::ID = 0;

112

113Type *AMDGPURewriteOutArguments::getStoredType(Value &Arg) const {

114 const int MaxUses = 10;

115 int UseCount = 0;

116

118

119 Type *StoredType = nullptr;

120 while (!Worklist.empty()) {

122

124 for (Use &U : BCI->uses())

126 continue;

127 }

128

130 if (UseCount++ > MaxUses)

131 return nullptr;

132

133 if (SI->isSimple() ||

135 return nullptr;

136

137 if (StoredType && StoredType != SI->getValueOperand()->getType())

138 return nullptr;

139 StoredType = SI->getValueOperand()->getType();

140 continue;

141 }

142

143

144 return nullptr;

145 }

146

147 return StoredType;

148}

149

150Type *AMDGPURewriteOutArguments::getOutArgumentType(Argument &Arg) const {

151 const unsigned MaxOutArgSizeBytes = 4 * MaxNumRetRegs;

153

154

155 if (!ArgTy || (ArgTy->getAddressSpace() != DL->getAllocaAddrSpace() &&

158 return nullptr;

159 }

160

161 Type *StoredType = getStoredType(Arg);

162 if (!StoredType || DL->getTypeStoreSize(StoredType) > MaxOutArgSizeBytes)

163 return nullptr;

164

165 return StoredType;

166}

167

// Per-module initialization hook: records the address of M's DataLayout in DL
// (declared outside this view). getOutArgumentType() and runOnFunction() later
// dereference DL to query type store sizes, so this runs before they are used.
168bool AMDGPURewriteOutArguments::doInitialization(Module &M) {

169 DL = &M.getDataLayout(); // keep a pointer to the module's layout; nothing is copied

170 return false; // NOTE(review): presumably signals "module not modified" — confirm against the Pass API

171}

172

173bool AMDGPURewriteOutArguments::runOnFunction(Function &F) {

174 if (skipFunction(F))

175 return false;

176

177

178 if (F.isVarArg() || F.hasStructRetAttr() ||

180 return false;

181

182 MDA = &getAnalysis().getMemDep();

183

184 unsigned ReturnNumRegs = 0;

185 SmallDenseMap<int, Type *, 4> OutArgIndexes;

187 Type *RetTy = F.getReturnType();

189 ReturnNumRegs = DL->getTypeStoreSize(RetTy) / 4;

190

192 return false;

193

195 }

196

198 for (Argument &Arg : F.args()) {

199 if (Type *Ty = getOutArgumentType(Arg)) {

200 LLVM_DEBUG(dbgs() << "Found possible out argument " << Arg

201 << " in function " << F.getName() << '\n');

203 }

204 }

205

206 if (OutArgs.empty())

207 return false;

208

210

211 DenseMap<ReturnInst *, ReplacementVec> Replacements;

212

214 for (BasicBlock &BB : F) {

217 }

218

219 if (Returns.empty())

220 return false;

221

222 bool Changing;

223

224 do {

225 Changing = false;

226

227

228

229

230

231

232

233

234 for (const auto &Pair : OutArgs) {

235 bool ThisReplaceable = true;

237

238 Argument *OutArg = Pair.first;

239 Type *ArgTy = Pair.second;

240

241

242

243

244

245

246 unsigned ArgNumRegs = DL->getTypeStoreSize(ArgTy) / 4;

248 continue;

249

250

251

252 for (ReturnInst *RI : Returns) {

254

257 StoreInst *SI = nullptr;

260

261 if (SI) {

262 LLVM_DEBUG(dbgs() << "Found out argument store: " << *SI << '\n');

264 } else {

265 ThisReplaceable = false;

266 break;

267 }

268 }

269

270 if (!ThisReplaceable)

271 continue;

272

273 for (std::pair<ReturnInst *, StoreInst *> Store : ReplaceableStores) {

274 Value *ReplVal = Store.second->getValueOperand();

275

276 auto &ValVec = Replacements[Store.first];

279 << "Saw multiple out arg stores" << *OutArg << '\n');

280

281

282 ThisReplaceable = false;

283 break;

284 }

285

286 ValVec.emplace_back(OutArg, ReplVal);

287 Store.second->eraseFromParent();

288 }

289

290 if (ThisReplaceable) {

293 ++NumOutArgumentsReplaced;

294 Changing = true;

295 }

296 }

297 } while (Changing);

298

299 if (Replacements.empty())

300 return false;

301

302 LLVMContext &Ctx = F.getContext();

303 StructType *NewRetTy = StructType::create(Ctx, ReturnTypes, F.getName());

304

305 FunctionType *NewFuncTy = FunctionType::get(NewRetTy,

306 F.getFunctionType()->params(),

307 F.isVarArg());

308

309 LLVM_DEBUG(dbgs() << "Computed new return type: " << *NewRetTy << '\n');

310

312 F.getName() + ".body");

313 F.getParent()->getFunctionList().insert(F.getIterator(), NewFunc);

316

317

318

320

321 AttributeMask RetAttrs;

326

327

328

329

331

332 for (std::pair<ReturnInst *, ReplacementVec> &Replacement : Replacements) {

333 ReturnInst *RI = Replacement.first;

335 B.SetCurrentDebugLocation(RI->getDebugLoc());

336

337 int RetIdx = 0;

339

341 if (RetVal)

342 NewRetVal = B.CreateInsertValue(NewRetVal, RetVal, RetIdx++);

343

344 for (std::pair<Argument *, Value *> ReturnPoint : Replacement.second)

345 NewRetVal = B.CreateInsertValue(NewRetVal, ReturnPoint.second, RetIdx++);

346

347 if (RetVal)

349 else {

350 B.CreateRet(NewRetVal);

352 }

353 }

354

355 SmallVector<Value *, 16> StubCallArgs;

356 for (Argument &Arg : F.args()) {

358

359

361 } else {

363 }

364 }

365

368 CallInst *StubCall = B.CreateCall(NewFunc, StubCallArgs);

369

370 int RetIdx = RetTy->isVoidTy() ? 0 : 1;

371 for (Argument &Arg : F.args()) {

372 auto It = OutArgIndexes.find(Arg.getArgNo());

373 if (It == OutArgIndexes.end())

374 continue;

375

376 Type *EltTy = It->second;

377 const auto Align =

378 DL->getValueOrABITypeAlignment(Arg.getParamAlign(), EltTy);

379

380 Value *Val = B.CreateExtractValue(StubCall, RetIdx++);

381 B.CreateAlignedStore(Val, &Arg, Align);

382 }

383

385 B.CreateRet(B.CreateExtractValue(StubCall, 0));

386 } else {

387 B.CreateRetVoid();

388 }

389

390

391 F.addFnAttr(Attribute::AlwaysInline);

392

393 ++NumOutArgumentFunctionsReplaced;

394 return true;

395}

396

398 return new AMDGPURewriteOutArguments();

399}

static cl::opt< unsigned > MaxNumRetRegs("amdgpu-max-return-arg-num-regs", cl::desc("Approximately limit number of return registers for replacing out arguments"), cl::Hidden, cl::init(16))

static cl::opt< bool > AnyAddressSpace("amdgpu-any-address-space-out-arguments", cl::desc("Replace pointer out arguments with " "struct returns for non-private address space"), cl::Hidden, cl::init(false))

MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL

static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")

static bool runOnFunction(Function &F, bool PostInlining)

Machine Check Debug Module

#define INITIALIZE_PASS_DEPENDENCY(depName)

#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)

#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)

This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...

#define STATISTIC(VARNAME, DESC)

AnalysisUsage & addRequired()

This class represents an incoming formal argument to a Function.

LLVM_ABI bool hasByValAttr() const

Return true if this argument has the byval attribute.

unsigned getArgNo() const

Return the index of this formal argument in its containing function.

LLVM_ABI MaybeAlign getParamAlign() const

If this is a byval or inalloca argument, return its alignment.

LLVM_ABI bool hasStructRetAttr() const

Return true if this argument has the sret attribute.

AttributeMask & addAttribute(Attribute::AttrKind Val)

Add an attribute to the mask.

static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)

Creates a new BasicBlock.

A parsed version of the target data layout string, and methods for querying it.

iterator find(const_arg_type_t< KeyT > Val)

size_type count(const_arg_type_t< KeyT > Val) const

Return 1 if the specified key is in the map, 0 otherwise.

std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)

FunctionPass class - This class is used to implement most global optimizations.

static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)

void splice(Function::iterator ToIt, Function *FromF)

Transfer all blocks from FromF to this function at ToIt.

void stealArgumentListFrom(Function &Src)

Steal arguments from another function.

void removeRetAttrs(const AttributeMask &Attrs)

Removes the attributes from the return value's list of attributes.

void copyAttributesFrom(const Function *Src)

copyAttributesFrom - copy all additional attributes (those not needed to create a Function) from the ...

LLVM_ABI void setComdat(Comdat *C)

const DebugLoc & getDebugLoc() const

Return the debug location for this node as a DebugLoc.

LLVM_ABI InstListType::iterator eraseFromParent()

This method unlinks 'this' from the containing basic block and deletes it.

bool isDef() const

Tests if this MemDepResult represents a query that is an instruction definition dependency.

Instruction * getInst() const

If this is a normal dependency, returns the instruction that is depended on.

Provides a lazy, caching interface for making common memory aliasing information queries,...

MemDepResult getPointerDependencyFrom(const MemoryLocation &Loc, bool isLoad, BasicBlock::iterator ScanIt, BasicBlock *BB, Instruction *QueryInst=nullptr, unsigned *Limit=nullptr)

Returns the instruction on which a memory location depends.

A wrapper analysis pass for the legacy pass manager that exposes a MemoryDependenceResults instance.

static MemoryLocation getBeforeOrAfter(const Value *Ptr, const AAMDNodes &AATags=AAMDNodes())

Return a location that may access any location before or after Ptr, while remaining within the underl...

static LLVM_ABI PoisonValue * get(Type *T)

Static factory methods - Return a 'poison' object of the specified type.

Value * getReturnValue() const

Convenience accessor. Returns null if there is no return value.

reference emplace_back(ArgTypes &&... Args)

void push_back(const T &Elt)

This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.

static unsigned getPointerOperandIndex()

static LLVM_ABI StructType * create(LLVMContext &Context, StringRef Name)

This creates an identified struct.

The instances of the Type class are immutable: once they are created, they are never changed.

bool isVoidTy() const

Return true if this is 'void'.

A Use represents the edge between a Value definition and its users.

void setOperand(unsigned i, Value *Val)

LLVM Value Representation.

Type * getType() const

All values are typed, get the type of this value.

iterator_range< use_iterator > uses()

constexpr char Align[]

Key for Kernel::Arg::Metadata::mAlign.

LLVM_READNONE constexpr bool isEntryFunctionCC(CallingConv::ID CC)

unsigned ID

LLVM IR allows arbitrary numbers to be used as calling convention identifiers.

@ BasicBlock

Various leaf nodes.

initializer< Ty > init(const Ty &Val)

This is an optimization pass for GlobalISel generic memory operations.

FunctionAddr VTableAddr Value

decltype(auto) dyn_cast(const From &Val)

dyn_cast - Return the argument parameter cast to the specified type.

FunctionPass * createAMDGPURewriteOutArgumentsPass()

Definition AMDGPURewriteOutArguments.cpp:397

LLVM_ABI raw_ostream & dbgs()

dbgs() - This returns a reference to a raw_ostream for debugging messages.

auto make_first_range(ContainerTy &&c)

Given a container of pairs, return a range over the first elements.

class LLVM_GSL_OWNER SmallVector

Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...

IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >

iterator_range< pointer_iterator< WrappedIteratorT > > make_pointer_range(RangeT &&Range)

bool is_contained(R &&Range, const E &Element)

Returns true if Element is found in Range.