LLVM: lib/Target/AMDGPU/AMDGPULowerKernelAttributes.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

24#include "llvm/IR/IntrinsicsAMDGPU.h"

28

29#define DEBUG_TYPE "amdgpu-lower-kernel-attributes"

30

31using namespace llvm;

32

33namespace {

34

35

// Byte offsets of the work-group-size and grid-size fields inside the HSA
// dispatch packet (hsa_kernel_dispatch_packet_t), used when the base pointer
// is amdgcn.dispatch.ptr (code object versions before V5).
enum DispatchPackedOffsets {
  WORKGROUP_SIZE_X = 4,  // uint16_t fields, 2 bytes apart.
  WORKGROUP_SIZE_Y = 6,
  WORKGROUP_SIZE_Z = 8,

  GRID_SIZE_X = 12,      // uint32_t fields, 4 bytes apart.
  GRID_SIZE_Y = 16,
  GRID_SIZE_Z = 20
};

45

46

// Byte offsets of the hidden kernel arguments reachable through
// amdgcn.implicitarg.ptr (code object version V5 and above).
enum ImplicitArgOffsets {
  HIDDEN_BLOCK_COUNT_X = 0,  // uint32_t fields.
  HIDDEN_BLOCK_COUNT_Y = 4,
  HIDDEN_BLOCK_COUNT_Z = 8,

  HIDDEN_GROUP_SIZE_X = 12,  // uint16_t fields.
  HIDDEN_GROUP_SIZE_Y = 14,
  HIDDEN_GROUP_SIZE_Z = 16,

  HIDDEN_REMAINDER_X = 18,   // uint16_t fields.
  HIDDEN_REMAINDER_Y = 20,
  HIDDEN_REMAINDER_Z = 22,
};

60

61class AMDGPULowerKernelAttributes : public ModulePass {

62public:

63 static char ID;

64

65 AMDGPULowerKernelAttributes() : ModulePass(ID) {}

66

68

70 return "AMDGPU Kernel Attributes";

71 }

72

75 }

76};

77

78Function *getBasePtrIntrinsic(Module &M, bool IsV5OrAbove) {

79 auto IntrinsicId = IsV5OrAbove ? Intrinsic::amdgcn_implicitarg_ptr

80 : Intrinsic::amdgcn_dispatch_ptr;

82}

83

84}

85

88 if (MaxNumGroups == 0 || MaxNumGroups == std::numeric_limits<uint32_t>::max())

89 return;

90

91 if (!Load->getType()->isIntegerTy(32))

92 return;

93

94

95 MDBuilder MDB(Load->getContext());

97 Load->setMetadata(LLVMContext::MD_range, Range);

98}

99

102

103 auto *MD = F->getMetadata("reqd_work_group_size");

104 const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3;

105

106 const bool HasUniformWorkGroupSize =

107 F->getFnAttribute("uniform-work-group-size").getValueAsBool();

108

111

112 if (!HasReqdWorkGroupSize && !HasUniformWorkGroupSize &&

113 none_of(MaxNumWorkgroups, [](unsigned X) { return X != 0; }))

114 return false;

115

116 Value *BlockCounts[3] = {nullptr, nullptr, nullptr};

117 Value *GroupSizes[3] = {nullptr, nullptr, nullptr};

118 Value *Remainders[3] = {nullptr, nullptr, nullptr};

119 Value *GridSizes[3] = {nullptr, nullptr, nullptr};

120

122

123

124

126 if (!U->hasOneUse())

127 continue;

128

130 auto *Load = dyn_cast(U);

131 auto *BCI = dyn_cast(U);

132 if (!Load && !BCI) {

134 continue;

135 Load = dyn_cast(*U->user_begin());

136 BCI = dyn_cast(*U->user_begin());

137 }

138

139 if (BCI) {

140 if (!BCI->hasOneUse())

141 continue;

142 Load = dyn_cast(*BCI->user_begin());

143 }

144

145 if (!Load || !Load->isSimple())

146 continue;

147

148 unsigned LoadSize = DL.getTypeStoreSize(Load->getType());

149

150

151 if (IsV5OrAbove) {

153 case HIDDEN_BLOCK_COUNT_X:

154 if (LoadSize == 4) {

155 BlockCounts[0] = Load;

157 }

158 break;

159 case HIDDEN_BLOCK_COUNT_Y:

160 if (LoadSize == 4) {

161 BlockCounts[1] = Load;

163 }

164 break;

165 case HIDDEN_BLOCK_COUNT_Z:

166 if (LoadSize == 4) {

167 BlockCounts[2] = Load;

169 }

170 break;

171 case HIDDEN_GROUP_SIZE_X:

172 if (LoadSize == 2)

173 GroupSizes[0] = Load;

174 break;

175 case HIDDEN_GROUP_SIZE_Y:

176 if (LoadSize == 2)

177 GroupSizes[1] = Load;

178 break;

179 case HIDDEN_GROUP_SIZE_Z:

180 if (LoadSize == 2)

181 GroupSizes[2] = Load;

182 break;

183 case HIDDEN_REMAINDER_X:

184 if (LoadSize == 2)

185 Remainders[0] = Load;

186 break;

187 case HIDDEN_REMAINDER_Y:

188 if (LoadSize == 2)

189 Remainders[1] = Load;

190 break;

191 case HIDDEN_REMAINDER_Z:

192 if (LoadSize == 2)

193 Remainders[2] = Load;

194 break;

195 default:

196 break;

197 }

198 } else {

200 case WORKGROUP_SIZE_X:

201 if (LoadSize == 2)

202 GroupSizes[0] = Load;

203 break;

204 case WORKGROUP_SIZE_Y:

205 if (LoadSize == 2)

206 GroupSizes[1] = Load;

207 break;

208 case WORKGROUP_SIZE_Z:

209 if (LoadSize == 2)

210 GroupSizes[2] = Load;

211 break;

212 case GRID_SIZE_X:

213 if (LoadSize == 4)

214 GridSizes[0] = Load;

215 break;

216 case GRID_SIZE_Y:

217 if (LoadSize == 4)

218 GridSizes[1] = Load;

219 break;

220 case GRID_SIZE_Z:

221 if (LoadSize == 4)

222 GridSizes[2] = Load;

223 break;

224 default:

225 break;

226 }

227 }

228 }

229

230 bool MadeChange = false;

231 if (IsV5OrAbove && HasUniformWorkGroupSize) {

232

233

234

235

236

237

238

239 for (int I = 0; I < 3; ++I) {

240 Value *BlockCount = BlockCounts[I];

241 if (!BlockCount)

242 continue;

243

245 auto GroupIDIntrin =

246 I == 0 ? m_IntrinsicIntrinsic::amdgcn\_workgroup\_id\_x()

247 : (I == 1 ? m_IntrinsicIntrinsic::amdgcn\_workgroup\_id\_y()

248 : m_IntrinsicIntrinsic::amdgcn\_workgroup\_id\_z());

249

250 for (User *ICmp : BlockCount->users()) {

254 MadeChange = true;

255 }

256 }

257 }

258

259

260 for (Value *Remainder : Remainders) {

261 if (!Remainder)

262 continue;

264 MadeChange = true;

265 }

266 } else if (HasUniformWorkGroupSize) {

267

268

269

270

271

272

273

274

275

276

277

278

279

280

281

282

283

284

285

286 for (int I = 0; I < 3; ++I) {

287 Value *GroupSize = GroupSizes[I];

288 Value *GridSize = GridSizes[I];

289 if (!GroupSize || !GridSize)

290 continue;

291

293 auto GroupIDIntrin =

294 I == 0 ? m_IntrinsicIntrinsic::amdgcn\_workgroup\_id\_x()

295 : (I == 1 ? m_IntrinsicIntrinsic::amdgcn\_workgroup\_id\_y()

296 : m_IntrinsicIntrinsic::amdgcn\_workgroup\_id\_z());

297

298 for (User *U : GroupSize->users()) {

299 auto *ZextGroupSize = dyn_cast(U);

300 if (!ZextGroupSize)

301 continue;

302

303 for (User *UMin : ZextGroupSize->users()) {

308 if (HasReqdWorkGroupSize) {

310 = mdconst::extract(MD->getOperand(I));

312 KnownSize, UMin->getType(), false, DL));

313 } else {

314 UMin->replaceAllUsesWith(ZextGroupSize);

315 }

316

317 MadeChange = true;

318 }

319 }

320 }

321 }

322 }

323

324

325 if (!HasReqdWorkGroupSize)

326 return MadeChange;

327

328 for (int I = 0; I < 3; I++) {

329 Value *GroupSize = GroupSizes[I];

330 if (!GroupSize)

331 continue;

332

333 ConstantInt *KnownSize = mdconst::extract(MD->getOperand(I));

336 MadeChange = true;

337 }

338

339 return MadeChange;

340}

341

342

343

344

345bool AMDGPULowerKernelAttributes::runOnModule(Module &M) {

346 bool MadeChange = false;

347 bool IsV5OrAbove =

350

351 if (!BasePtr)

352 return false;

353

355 for (auto *U : BasePtr->users()) {

356 CallInst *CI = cast(U);

357 if (HandledUses.insert(CI).second) {

359 MadeChange = true;

360 }

361 }

362

363 return MadeChange;

364}

365

366

368 "AMDGPU Kernel Attributes", false, false)

371

372char AMDGPULowerKernelAttributes::ID = 0;

373

375 return new AMDGPULowerKernelAttributes();

376}

377

380 bool IsV5OrAbove =

382 Function *BasePtr = getBasePtrIntrinsic(*F.getParent(), IsV5OrAbove);

383

384 if (!BasePtr)

386

388 if (CallInst *CI = dyn_cast(&I)) {

391 }

392 }

393

395}

static void annotateGridSizeLoadWithRangeMD(LoadInst *Load, uint32_t MaxNumGroups)

static bool processUse(CallInst *CI, bool IsV5OrAbove)

MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL

Expand Atomic instructions

This file contains the declarations for the subclasses of Constant, which represent the different flavors of constant values that live in LLVM.

static GCMetadataPrinterRegistry::Add< ErlangGCPrinter > X("erlang", "erlang-compatible garbage collector")

ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))

#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)

#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)

Class for arbitrary precision integers.

A container for analyses that lazily runs them and caches their results.

Represent the analysis usage information of a pass.

void setPreservesAll()

Set by analyses that do not transform their input at all.

Function * getCalledFunction() const

Returns the function called, or null if this is an indirect function invocation or the function signature does not match the call signature.

This class represents a function call, abstracting a target machine's calling convention.

This is the shared class of boolean and integer constants.

static ConstantInt * getTrue(LLVMContext &Context)

static Constant * getNullValue(Type *Ty)

Constructor to create a '0' constant of arbitrary type.

A parsed version of the target data layout string in and methods for querying it.

An instruction for reading from memory.

MDNode * createRange(const APInt &Lo, const APInt &Hi)

Return metadata describing the range [Lo, Hi).

ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...

virtual bool runOnModule(Module &M)=0

runOnModule - Virtual method overriden by subclasses to process the module being operated on.

A Module instance is used to store all the information related to an LLVM module.

virtual void getAnalysisUsage(AnalysisUsage &) const

getAnalysisUsage - This function should be overridden by passes that need analysis information to do their job.

virtual StringRef getPassName() const

getPassName - Return a nice clean name for a pass.

A set of analyses that are preserved following a run of a transformation pass.

static PreservedAnalyses all()

Construct a special preserved set that preserves all passes.

std::pair< iterator, bool > insert(PtrType Ptr)

Inserts Ptr if and only if there is no element in the container equal to Ptr.

SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.

This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.

StringRef - Represent a constant reference to a string, i.e.

LLVM Value Representation.

Type * getType() const

All values are typed, get the type of this value.

void replaceAllUsesWith(Value *V)

Change all uses of this to point to a new Value.

iterator_range< user_iterator > users()

const ParentTy * getParent() const

unsigned getAMDHSACodeObjectVersion(const Module &M)

SmallVector< unsigned > getIntegerVecAttribute(const Function &F, StringRef Name, unsigned Size, unsigned DefaultVal)

unsigned ID

LLVM IR allows to use arbitrary numbers as calling convention identifiers.

Function * getDeclarationIfExists(Module *M, ID id, ArrayRef< Type * > Tys, FunctionType *FT=nullptr)

This version supports overloaded intrinsics.

bool match(Val *V, const Pattern &P)

specificval_ty m_Specific(const Value *V)

Match if we have a specific specified value.

BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)

SpecificCmpClass_match< LHS, RHS, ICmpInst > m_SpecificICmp(CmpPredicate MatchPred, const LHS &L, const RHS &R)

BinaryOp_match< LHS, RHS, Instruction::Sub > m_Sub(const LHS &L, const RHS &R)

MaxMin_match< ICmpInst, LHS, RHS, umin_pred_ty > m_UMin(const LHS &L, const RHS &R)

This is an optimization pass for GlobalISel generic memory operations.

Value * GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, const DataLayout &DL, bool AllowNonInbounds=true)

Analyze the specified pointer to see if it can be expressed as a base pointer plus a constant offset.

ModulePass * createAMDGPULowerKernelAttributesPass()

bool none_of(R &&Range, UnaryPredicate P)

Provide wrappers to std::none_of which take ranges instead of having to pass begin/end explicitly.

@ UMin

Unsigned integer min implemented in terms of select(cmp()).

Constant * ConstantFoldIntegerCast(Constant *C, Type *DestTy, bool IsSigned, const DataLayout &DL)

Constant fold a zext, sext or trunc, depending on IsSigned and whether the DestTy is wider or narrower than the source type.

PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)