clang: lib/CodeGen/CGHLSLRuntime.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

21#include "llvm/IR/GlobalVariable.h"

22#include "llvm/IR/LLVMContext.h"

23#include "llvm/IR/Metadata.h"

24#include "llvm/IR/Module.h"

25#include "llvm/IR/Value.h"

26#include "llvm/Support/Alignment.h"

27

28#include "llvm/Support/FormatVariadic.h"

29

30using namespace clang;

31using namespace CodeGen;

33using namespace llvm;

34

35namespace {

36

37void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {

38

39

40 VersionTuple Version;

41 if (Version.tryParse(ValVersionStr) || Version.getBuild() ||

42 Version.getSubminor() || !Version.getMinor()) {

43 return;

44 }

45

46 uint64_t Major = Version.getMajor();

47 uint64_t Minor = *Version.getMinor();

48

49 auto &Ctx = M.getContext();

50 IRBuilder<> B(M.getContext());

51 MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)),

52 ConstantAsMetadata::get(B.getInt32(Minor))});

53 StringRef DXILValKey = "dx.valver";

54 auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey);

55 DXILValMD->addOperand(Val);

56}

57void addDisableOptimizations(llvm::Module &M) {

58 StringRef Key = "dx.disable_optimizations";

59 M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1);

60}

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

83 return;

84

85 std::vector<llvm::Type *> EltTys;

86 for (auto &Const : Buf.Constants) {

87 GlobalVariable *GV = Const.first;

88 Const.second = EltTys.size();

89 llvm::Type *Ty = GV->getValueType();

90 EltTys.emplace_back(Ty);

91 }

92 Buf.LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys);

93}

94

96

97 GlobalVariable *CBGV = new GlobalVariable(

99 GlobalValue::LinkageTypes::ExternalLinkage, nullptr,

100 llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."),

101 GlobalValue::NotThreadLocal);

102

103 IRBuilder<> B(CBGV->getContext());

104 Value *ZeroIdx = B.getInt32(0);

105

106 for (auto &[GV, Offset] : Buf.Constants) {

108 B.CreateGEP(Buf.LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)});

109

110 assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() &&

111 "constant type mismatch");

112

113

114 GV->replaceAllUsesWith(GEP);

115

116 GV->removeDeadConstantUsers();

117 GV->eraseFromParent();

118 }

119 return CBGV;

120}

121

122}

123

126

127

129 return TargetTy;

130

131 llvm_unreachable("Generic handling of HLSL types is not supported.");

132}

133

134llvm::Triple::ArchType CGHLSLRuntime::getArch() {

136}

137

138void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) {

139 if (D->getStorageClass() == SC_Static) {

140

141

143 return;

144 }

145

147

150 codegenoptions::DebugInfoKind::LimitedDebugInfo)

151 DI->EmitGlobalVariable(cast(GV), D);

152

153

154

155 uint32_t Offset = 0;

156 bool HasUserOffset = false;

157

158 unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX;

159 CB.Constants.emplace_back(std::make_pair(GV, LowerBound));

160}

161

162void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) {

164 if (auto *ConstDecl = dyn_cast(it)) {

165 addConstant(ConstDecl, CB);

166 } else if (isa<CXXRecordDecl, EmptyDecl>(it)) {

167

168 } else if (isa(it)) {

169

170

172 }

173 }

174}

175

177 Buffers.emplace_back(Buffer(D));

178 addBufferDecls(D, Buffers.back());

179}

180

184 Triple T(M.getTargetTriple());

185 if (T.getArch() == Triple::ArchType::dxil)

186 addDxilValVersion(TargetOpts.DxilValidatorVersion, M);

187

190 addDisableOptimizations(M);

191

192 const DataLayout &DL = M.getDataLayout();

193

194 for (auto &Buf : Buffers) {

195 layoutBuffer(Buf, DL);

196 GlobalVariable *GV = replaceBuffer(Buf);

197 M.insertGlobalVariable(GV);

198 llvm::hlsl::ResourceClass RC = Buf.IsCBuffer

199 ? llvm::hlsl::ResourceClass::CBuffer

200 : llvm::hlsl::ResourceClass::SRV;

201 llvm::hlsl::ResourceKind RK = Buf.IsCBuffer

202 ? llvm::hlsl::ResourceKind::CBuffer

203 : llvm::hlsl::ResourceKind::TBuffer;

204 addBufferResourceAnnotation(GV, RC, RK, false,

205 llvm::hlsl::ElementType::Invalid, Buf.Binding);

206 }

207}

208

210 : Name(D->getName()), IsCBuffer(D->isCBuffer()),

211 Binding(D->getAttr()) {}

212

213void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,

214 llvm::hlsl::ResourceClass RC,

215 llvm::hlsl::ResourceKind RK,

216 bool IsROV,

217 llvm::hlsl::ElementType ET,

218 BufferResBinding &Binding) {

220

221 NamedMDNode *ResourceMD = nullptr;

222 switch (RC) {

223 case llvm::hlsl::ResourceClass::UAV:

224 ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs");

225 break;

226 case llvm::hlsl::ResourceClass::SRV:

227 ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs");

228 break;

229 case llvm::hlsl::ResourceClass::CBuffer:

230 ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs");

231 break;

232 default:

233 assert(false && "Unsupported buffer type!");

234 return;

235 }

236 assert(ResourceMD != nullptr &&

237 "ResourceMD must have been set by the switch above.");

238

239 llvm::hlsl::FrontendResource Res(

240 GV, RK, ET, IsROV, Binding.Reg.value_or(UINT_MAX), Binding.Space);

241 ResourceMD->addOperand(Res.getMetadata());

242}

243

244static llvm::hlsl::ElementType

246 using llvm::hlsl::ElementType;

247

248

249

251 assert(TST && "Resource types must be template specializations");

253 assert(!Args.empty() && "Resource has no element type");

254

255

256

257 QualType ElTy = Args[0].getAsType();

258

259

261 ElTy = VecTy->getElementType();

262

265 case 16:

266 return ElementType::I16;

267 case 32:

268 return ElementType::I32;

269 case 64:

270 return ElementType::I64;

271 }

274 case 16:

275 return ElementType::U16;

276 case 32:

277 return ElementType::U32;

278 case 64:

279 return ElementType::U64;

280 }

282 return ElementType::F16;

284 return ElementType::F32;

286 return ElementType::F64;

287

288

289 llvm_unreachable("Invalid element type for resource");

290}

291

293 const Type *Ty = D->getType()->getPointeeOrArrayElementType();

294 if (!Ty)

295 return;

297 if (!RD)

298 return;

299

300

301 for (auto *FD : RD->fields()) {

302 const auto *HLSLResAttr = FD->getAttr();

304 dyn_cast(FD->getType().getTypePtr());

305 if (!HLSLResAttr || !AttrResType)

306 continue;

307

308 llvm::hlsl::ResourceClass RC = AttrResType->getAttrs().ResourceClass;

309 if (RC == llvm::hlsl::ResourceClass::UAV ||

310 RC == llvm::hlsl::ResourceClass::SRV)

311

312

313

314

315

316

317 return;

318

319 bool IsROV = AttrResType->getAttrs().IsROV;

320 llvm::hlsl::ResourceKind RK = HLSLResAttr->getResourceKind();

322

323 BufferResBinding Binding(D->getAttr());

324 addBufferResourceAnnotation(GV, RC, RK, IsROV, ET, Binding);

325 }

326}

327

328CGHLSLRuntime::BufferResBinding::BufferResBinding(

329 HLSLResourceBindingAttr *Binding) {

330 if (Binding) {

331 llvm::APInt RegInt(64, 0);

332 Binding->getSlot().substr(1).getAsInteger(10, RegInt);

333 Reg = RegInt.getLimitedValue();

334 llvm::APInt SpaceInt(64, 0);

335 Binding->getSpace().substr(5).getAsInteger(10, SpaceInt);

336 Space = SpaceInt.getLimitedValue();

337 } else {

338 Space = 0;

339 }

340}

341

343 const FunctionDecl *FD, llvm::Function *Fn) {

344 const auto *ShaderAttr = FD->getAttr();

345 assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");

346 const StringRef ShaderAttrKindStr = "hlsl.shader";

347 Fn->addFnAttr(ShaderAttrKindStr,

348 llvm::Triple::getEnvironmentTypeName(ShaderAttr->getType()));

349 if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr()) {

350 const StringRef NumThreadsKindStr = "hlsl.numthreads";

351 std::string NumThreadsStr =

352 formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(),

353 NumThreadsAttr->getZ());

354 Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr);

355 }

356 if (HLSLWaveSizeAttr *WaveSizeAttr = FD->getAttr()) {

357 const StringRef WaveSizeKindStr = "hlsl.wavesize";

358 std::string WaveSizeStr =

359 formatv("{0},{1},{2}", WaveSizeAttr->getMin(), WaveSizeAttr->getMax(),

360 WaveSizeAttr->getPreferred());

361 Fn->addFnAttr(WaveSizeKindStr, WaveSizeStr);

362 }

363 Fn->addFnAttr(llvm::Attribute::NoInline);

364}

365

367 if (const auto *VT = dyn_cast(Ty)) {

369 for (unsigned I = 0; I < VT->getNumElements(); ++I) {

370 Value *Elt = B.CreateCall(F, {B.getInt32(I)});

371 Result = B.CreateInsertElement(Result, Elt, I);

372 }

374 }

375 return B.CreateCall(F, {B.getInt32(0)});

376}

377

380 llvm::Type *Ty) {

381 assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");

382 if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {

383 llvm::Function *DxGroupIndex =

384 CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group);

385 return B.CreateCall(FunctionCallee(DxGroupIndex));

386 }

387 if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) {

388 llvm::Function *ThreadIDIntrinsic =

391 }

392 if (D.hasAttr<HLSLSV_GroupThreadIDAttr>()) {

393 llvm::Function *GroupThreadIDIntrinsic =

396 }

397 if (D.hasAttr<HLSLSV_GroupIDAttr>()) {

398 llvm::Function *GroupIDIntrinsic = CGM.getIntrinsic(getGroupIdIntrinsic());

400 }

401 assert(false && "Unhandled parameter attribute");

402 return nullptr;

403}

404

406 llvm::Function *Fn) {

408 llvm::LLVMContext &Ctx = M.getContext();

409 auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);

411 Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M);

412

413

414

415 AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex,

416 Fn->getAttributes().getFnAttrs());

417 EntryFn->setAttributes(NewAttrs);

419

420

421 Fn->setLinkage(GlobalValue::InternalLinkage);

422

423 BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);

424 IRBuilder<> B(BB);

426

429 assert(EntryFn->isConvergent());

430 llvm::Value *I = B.CreateIntrinsic(

431 llvm::Intrinsic::experimental_convergence_entry, {}, {});

432 llvm::Value *bundleArgs[] = {I};

433 OB.emplace_back("convergencectrl", bundleArgs);

434 }

435

436

437

438 unsigned SRetOffset = 0;

439 for (const auto &Param : Fn->args()) {

440 if (Param.hasStructRetAttr()) {

441

442

443 SRetOffset = 1;

444 Args.emplace_back(PoisonValue::get(Param.getType()));

445 continue;

446 }

449 }

450

451 CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args, OB);

452 CI->setCallingConv(Fn->getCallingConv());

453

454

455 B.CreateRetVoid();

456}

457

459 llvm::Function *Fn) {

461 const StringRef ExportAttrKindStr = "hlsl.export";

462 Fn->addFnAttr(ExportAttrKindStr);

463 }

464}

465

467 bool CtorOrDtor) {

468 const auto *GV =

469 M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors");

470 if (!GV)

471 return;

472 const auto *CA = dyn_cast(GV->getInitializer());

473 if (!CA)

474 return;

475

476

477

478

480 for (const auto &Ctor : CA->operands()) {

481 if (isa(Ctor))

482 continue;

483 ConstantStruct *CS = cast(Ctor);

484

485 assert(cast(CS->getOperand(0))->getValue() == 65535 &&

486 "HLSL doesn't support setting priority for global ctors.");

487 assert(isa(CS->getOperand(2)) &&

488 "HLSL doesn't support COMDat for global ctors.");

489 Fns.push_back(cast(CS->getOperand(1)));

490 }

491}

492

499

500

501

502

503 for (auto &F : M.functions()) {

504 if (!F.hasFnAttribute("hlsl.shader"))

505 continue;

507 Instruction *IP = &*F.getEntryBlock().begin();

510 llvm::Value *bundleArgs[] = {Token};

511 OB.emplace_back("convergencectrl", bundleArgs);

512 IP = Token->getNextNode();

513 }

514 IRBuilder<> B(IP);

515 for (auto *Fn : CtorFns) {

516 auto CI = B.CreateCall(FunctionCallee(Fn), {}, OB);

517 CI->setCallingConv(Fn->getCallingConv());

518 }

519

520

521 B.SetInsertPoint(F.back().getTerminator());

522 for (auto *Fn : DtorFns) {

523 auto CI = B.CreateCall(FunctionCallee(Fn), {}, OB);

524 CI->setCallingConv(Fn->getCallingConv());

525 }

526 }

527

528

529

530 Triple T(M.getTargetTriple());

531 if (T.getEnvironment() != Triple::EnvironmentType::Library) {

532 if (auto *GV = M.getNamedGlobal("llvm.global_ctors"))

533 GV->eraseFromParent();

534 if (auto *GV = M.getNamedGlobal("llvm.global_dtors"))

535 GV->eraseFromParent();

536 }

537}

538

539

542}

543

545 llvm::GlobalVariable *GV, unsigned Slot,

546 unsigned Space) {

548 llvm::Type *Int1Ty = llvm::Type::getInt1Ty(Ctx);

549

550 llvm::Function *InitResFunc = llvm::Function::Create(

551 llvm::FunctionType::get(CGM.VoidTy, false),

552 llvm::GlobalValue::InternalLinkage,

554 InitResFunc->addFnAttr(llvm::Attribute::AlwaysInline);

555

556 llvm::BasicBlock *EntryBB =

557 llvm::BasicBlock::Create(Ctx, "entry", InitResFunc);

559 const DataLayout &DL = CGM.getModule().getDataLayout();

560 Builder.SetInsertPoint(EntryBB);

561

565

566

567

568

569 assert(AttrResType != nullptr &&

570 "Resource class must have a handle of HLSLAttributedResourceType");

571

572 llvm::Type *TargetTy =

574 assert(TargetTy != nullptr &&

575 "Failed to convert resource handle to target type");

576

577 llvm::Value *Args[] = {

578 llvm::ConstantInt::get(CGM.IntTy, Space),

579 llvm::ConstantInt::get(CGM.IntTy, Slot),

580

581 llvm::ConstantInt::get(CGM.IntTy, 1),

582 llvm::ConstantInt::get(CGM.IntTy, 0),

583

584 llvm::ConstantInt::get(Int1Ty, false)

585 };

586 llvm::Value *CreateHandle = Builder.CreateIntrinsic(

587 TargetTy,

588 CGM.getHLSLRuntime().getCreateHandleFromBindingIntrinsic(), Args, nullptr,

589 Twine(VD->getName()).concat("_h"));

590

591 llvm::Value *HandleRef = Builder.CreateStructGEP(GV->getValueType(), GV, 0);

592 Builder.CreateAlignedStore(CreateHandle, HandleRef,

593 HandleRef->getPointerAlignment(DL));

594 Builder.CreateRetVoid();

595

597}

598

600 llvm::GlobalVariable *GV) {

601

602

603

604 const HLSLResourceBindingAttr *RBA = VD->getAttr();

605 if (!RBA)

606

607

608 return;

609

611

612

613

614 return;

615

617 RBA->getSpaceNumber());

618}

619

622 return nullptr;

623

624 auto E = BB.end();

625 for (auto I = BB.begin(); I != E; ++I) {

626 auto *II = dyn_castllvm::IntrinsicInst(&*I);

627 if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID())) {

628 return II;

629 }

630 }

631 llvm_unreachable("Convergence token should have been emitted.");

632 return nullptr;

633}

static llvm::hlsl::ElementType calculateElementType(const ASTContext &Context, const clang::Type *ResourceTy)

static void gatherFunctions(SmallVectorImpl< Function * > &Fns, llvm::Module &M, bool CtorOrDtor)

static Value * buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty)

static void createResourceInitFn(CodeGenModule &CGM, const VarDecl *VD, llvm::GlobalVariable *GV, unsigned Slot, unsigned Space)

static bool isResourceRecordType(const clang::Type *Ty)

static std::string getName(const CallEvent &Call)

Defines the clang::TargetOptions class.

Holds long-lived AST nodes (such as types and decls) that can be referred to throughout the semantic analysis of a file.

uint64_t getTypeSize(QualType T) const

Return the size of the specified (complete) type T, in bits.

This class gathers all debug information during compilation and is responsible for emitting to llvm globals or pass directly to the backend.

llvm::Instruction * getConvergenceToken(llvm::BasicBlock &BB)

void setHLSLEntryAttributes(const FunctionDecl *FD, llvm::Function *Fn)

void setHLSLFunctionAttributes(const FunctionDecl *FD, llvm::Function *Fn)

void emitEntryFunction(const FunctionDecl *FD, llvm::Function *Fn)

void handleGlobalVarDefinition(const VarDecl *VD, llvm::GlobalVariable *Var)

llvm::Type * convertHLSLSpecificType(const Type *T)

llvm::Value * emitInputSemantic(llvm::IRBuilder<> &B, const ParmVarDecl &D, llvm::Type *Ty)

void annotateHLSLResource(const VarDecl *D, llvm::GlobalVariable *GV)

void addBuffer(const HLSLBufferDecl *D)

void generateGlobalCtorDtorCalls()

This class organizes the cross-function state that is used while generating LLVM code.

CGHLSLRuntime & getHLSLRuntime()

Return a reference to the configured HLSL runtime.

llvm::Module & getModule() const

CGDebugInfo * getModuleDebugInfo()

void AddCXXGlobalInit(llvm::Function *F)

const TargetInfo & getTarget() const

void EmitGlobal(GlobalDecl D)

Emit code for a single global function or var decl.

bool shouldEmitConvergenceTokens() const

ASTContext & getContext() const

llvm::Constant * GetAddrOfGlobalVar(const VarDecl *D, llvm::Type *Ty=nullptr, ForDefinition_t IsForDefinition=NotForDefinition)

Return the llvm::Constant for the address of the given global variable.

const TargetCodeGenInfo & getTargetCodeGenInfo()

const CodeGenOptions & getCodeGenOpts() const

llvm::LLVMContext & getLLVMContext()

llvm::Function * getIntrinsic(unsigned IID, ArrayRef< llvm::Type * > Tys={})

void EmitTopLevelDecl(Decl *D)

Emit code for a single top level declaration.

virtual llvm::Type * getHLSLType(CodeGenModule &CGM, const Type *T) const

Return an LLVM type that corresponds to a HLSL type.

DeclContext - This is used only as base class of specific decl types that can act as declaration contexts.

decl_range decls() const

decls_begin/decls_end - Iterate over the declarations stored in this context.

Decl - This represents one declaration (or definition), e.g. a variable, typedef, function, struct, etc.

bool isInExportDeclContext() const

Whether this declaration was exported in a lexical context.

Represents a function declaration or definition.

const ParmVarDecl * getParamDecl(unsigned i) const

static const HLSLAttributedResourceType * findHandleTypeOnResource(const Type *RT)

HLSLBufferDecl - Represent a cbuffer or tbuffer declaration.

StringRef getName() const

Get the name of identifier for this declaration as a StringRef.

Represents a parameter to a function.

A (possibly-)qualified type.

const Type * getTypePtr() const

Retrieves a pointer to the underlying (unqualified) type.

TargetOptions & getTargetOpts() const

Retrieve the target options.

const llvm::Triple & getTriple() const

Returns the target triple of the primary target.

Represents a type template specialization; the template must be a class template, a type alias template, or a template template parameter.

Token - This structure provides full information about a lexed token.

The base class of the type hierarchy.

CXXRecordDecl * getAsCXXRecordDecl() const

Retrieves the CXXRecordDecl that this type refers to, either because the type is a RecordType or because it is the injected-class-name type of a class template or class template partial specialization.

bool isSignedIntegerType() const

Return true if this is an integer type that is signed, according to C99 6.2.5p4 [char, signed char, short, int, long..], or an enum decl which has a signed representation.

bool isSpecificBuiltinType(unsigned K) const

Test for a particular builtin type.

bool isHLSLSpecificType() const

bool isUnsignedIntegerType() const

Return true if this is an integer type that is unsigned, according to C99 6.2.5p6 [which returns true for _Bool], or an enum decl which has an unsigned representation.

const T * getAs() const

Member-template getAs<specific type>'.

Represents a variable declaration or definition.

Represents a GCC generic vector type.

bool Const(InterpState &S, CodePtr OpPC, const T &Arg)

The JSON file list parser is used to communicate input to InstallAPI.

@ Result

The result type of a method or function.

const FunctionProtoType * T

Diagnostic wrappers for TextAPI types for error reporting.

std::vector< std::pair< llvm::GlobalVariable *, unsigned > > Constants

llvm::StructType * LayoutStruct

Buffer(const HLSLBufferDecl *D)

llvm::IntegerType * IntTy

int