clang: lib/CodeGen/CGHLSLRuntime.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
21#include "llvm/IR/GlobalVariable.h"
22#include "llvm/IR/LLVMContext.h"
23#include "llvm/IR/Metadata.h"
24#include "llvm/IR/Module.h"
25#include "llvm/IR/Value.h"
26#include "llvm/Support/Alignment.h"
27
28#include "llvm/Support/FormatVariadic.h"
29
30using namespace clang;
31using namespace CodeGen;
33using namespace llvm;
34
35namespace {
36
37void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
38
39
40 VersionTuple Version;
41 if (Version.tryParse(ValVersionStr) || Version.getBuild() ||
42 Version.getSubminor() || !Version.getMinor()) {
43 return;
44 }
45
46 uint64_t Major = Version.getMajor();
47 uint64_t Minor = *Version.getMinor();
48
49 auto &Ctx = M.getContext();
50 IRBuilder<> B(M.getContext());
51 MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)),
52 ConstantAsMetadata::get(B.getInt32(Minor))});
53 StringRef DXILValKey = "dx.valver";
54 auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey);
55 DXILValMD->addOperand(Val);
56}
57void addDisableOptimizations(llvm::Module &M) {
58 StringRef Key = "dx.disable_optimizations";
59 M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1);
60}
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
83 return;
84
85 std::vector<llvm::Type *> EltTys;
86 for (auto &Const : Buf.Constants) {
87 GlobalVariable *GV = Const.first;
88 Const.second = EltTys.size();
89 llvm::Type *Ty = GV->getValueType();
90 EltTys.emplace_back(Ty);
91 }
92 Buf.LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys);
93}
94
96
97 GlobalVariable *CBGV = new GlobalVariable(
99 GlobalValue::LinkageTypes::ExternalLinkage, nullptr,
100 llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."),
101 GlobalValue::NotThreadLocal);
102
103 IRBuilder<> B(CBGV->getContext());
104 Value *ZeroIdx = B.getInt32(0);
105
106 for (auto &[GV, Offset] : Buf.Constants) {
108 B.CreateGEP(Buf.LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)});
109
110 assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() &&
111 "constant type mismatch");
112
113
114 GV->replaceAllUsesWith(GEP);
115
116 GV->removeDeadConstantUsers();
117 GV->eraseFromParent();
118 }
119 return CBGV;
120}
121
122}
123
126
127
129 return TargetTy;
130
131 llvm_unreachable("Generic handling of HLSL types is not supported.");
132}
133
134llvm::Triple::ArchType CGHLSLRuntime::getArch() {
136}
137
138void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) {
139 if (D->getStorageClass() == SC_Static) {
140
141
143 return;
144 }
145
147
150 codegenoptions::DebugInfoKind::LimitedDebugInfo)
151 DI->EmitGlobalVariable(cast(GV), D);
152
153
154
155 uint32_t Offset = 0;
156 bool HasUserOffset = false;
157
158 unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX;
159 CB.Constants.emplace_back(std::make_pair(GV, LowerBound));
160}
161
162void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) {
164 if (auto *ConstDecl = dyn_cast(it)) {
165 addConstant(ConstDecl, CB);
166 } else if (isa<CXXRecordDecl, EmptyDecl>(it)) {
167
168 } else if (isa(it)) {
169
170
172 }
173 }
174}
175
177 Buffers.emplace_back(Buffer(D));
178 addBufferDecls(D, Buffers.back());
179}
180
184 Triple T(M.getTargetTriple());
185 if (T.getArch() == Triple::ArchType::dxil)
186 addDxilValVersion(TargetOpts.DxilValidatorVersion, M);
187
190 addDisableOptimizations(M);
191
192 const DataLayout &DL = M.getDataLayout();
193
194 for (auto &Buf : Buffers) {
195 layoutBuffer(Buf, DL);
196 GlobalVariable *GV = replaceBuffer(Buf);
197 M.insertGlobalVariable(GV);
198 llvm::hlsl::ResourceClass RC = Buf.IsCBuffer
199 ? llvm::hlsl::ResourceClass::CBuffer
200 : llvm::hlsl::ResourceClass::SRV;
201 llvm::hlsl::ResourceKind RK = Buf.IsCBuffer
202 ? llvm::hlsl::ResourceKind::CBuffer
203 : llvm::hlsl::ResourceKind::TBuffer;
204 addBufferResourceAnnotation(GV, RC, RK, false,
205 llvm::hlsl::ElementType::Invalid, Buf.Binding);
206 }
207}
208
210 : Name(D->getName()), IsCBuffer(D->isCBuffer()),
211 Binding(D->getAttr()) {}
212
213void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,
214 llvm::hlsl::ResourceClass RC,
215 llvm::hlsl::ResourceKind RK,
216 bool IsROV,
217 llvm::hlsl::ElementType ET,
218 BufferResBinding &Binding) {
220
221 NamedMDNode *ResourceMD = nullptr;
222 switch (RC) {
223 case llvm::hlsl::ResourceClass::UAV:
224 ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs");
225 break;
226 case llvm::hlsl::ResourceClass::SRV:
227 ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs");
228 break;
229 case llvm::hlsl::ResourceClass::CBuffer:
230 ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs");
231 break;
232 default:
233 assert(false && "Unsupported buffer type!");
234 return;
235 }
236 assert(ResourceMD != nullptr &&
237 "ResourceMD must have been set by the switch above.");
238
239 llvm::hlsl::FrontendResource Res(
240 GV, RK, ET, IsROV, Binding.Reg.value_or(UINT_MAX), Binding.Space);
241 ResourceMD->addOperand(Res.getMetadata());
242}
243
244static llvm::hlsl::ElementType
246 using llvm::hlsl::ElementType;
247
248
249
251 assert(TST && "Resource types must be template specializations");
253 assert(!Args.empty() && "Resource has no element type");
254
255
256
257 QualType ElTy = Args[0].getAsType();
258
259
261 ElTy = VecTy->getElementType();
262
265 case 16:
266 return ElementType::I16;
267 case 32:
268 return ElementType::I32;
269 case 64:
270 return ElementType::I64;
271 }
274 case 16:
275 return ElementType::U16;
276 case 32:
277 return ElementType::U32;
278 case 64:
279 return ElementType::U64;
280 }
282 return ElementType::F16;
284 return ElementType::F32;
286 return ElementType::F64;
287
288
289 llvm_unreachable("Invalid element type for resource");
290}
291
293 const Type *Ty = D->getType()->getPointeeOrArrayElementType();
294 if (!Ty)
295 return;
297 if (!RD)
298 return;
299
300
301 for (auto *FD : RD->fields()) {
302 const auto *HLSLResAttr = FD->getAttr();
304 dyn_cast(FD->getType().getTypePtr());
305 if (!HLSLResAttr || !AttrResType)
306 continue;
307
308 llvm::hlsl::ResourceClass RC = AttrResType->getAttrs().ResourceClass;
309 if (RC == llvm::hlsl::ResourceClass::UAV ||
310 RC == llvm::hlsl::ResourceClass::SRV)
311
312
313
314
315
316
317 return;
318
319 bool IsROV = AttrResType->getAttrs().IsROV;
320 llvm::hlsl::ResourceKind RK = HLSLResAttr->getResourceKind();
322
323 BufferResBinding Binding(D->getAttr());
324 addBufferResourceAnnotation(GV, RC, RK, IsROV, ET, Binding);
325 }
326}
327
328CGHLSLRuntime::BufferResBinding::BufferResBinding(
329 HLSLResourceBindingAttr *Binding) {
330 if (Binding) {
331 llvm::APInt RegInt(64, 0);
332 Binding->getSlot().substr(1).getAsInteger(10, RegInt);
333 Reg = RegInt.getLimitedValue();
334 llvm::APInt SpaceInt(64, 0);
335 Binding->getSpace().substr(5).getAsInteger(10, SpaceInt);
336 Space = SpaceInt.getLimitedValue();
337 } else {
338 Space = 0;
339 }
340}
341
343 const FunctionDecl *FD, llvm::Function *Fn) {
344 const auto *ShaderAttr = FD->getAttr();
345 assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
346 const StringRef ShaderAttrKindStr = "hlsl.shader";
347 Fn->addFnAttr(ShaderAttrKindStr,
348 llvm::Triple::getEnvironmentTypeName(ShaderAttr->getType()));
349 if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr()) {
350 const StringRef NumThreadsKindStr = "hlsl.numthreads";
351 std::string NumThreadsStr =
352 formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(),
353 NumThreadsAttr->getZ());
354 Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr);
355 }
356 if (HLSLWaveSizeAttr *WaveSizeAttr = FD->getAttr()) {
357 const StringRef WaveSizeKindStr = "hlsl.wavesize";
358 std::string WaveSizeStr =
359 formatv("{0},{1},{2}", WaveSizeAttr->getMin(), WaveSizeAttr->getMax(),
360 WaveSizeAttr->getPreferred());
361 Fn->addFnAttr(WaveSizeKindStr, WaveSizeStr);
362 }
363 Fn->addFnAttr(llvm::Attribute::NoInline);
364}
365
367 if (const auto *VT = dyn_cast(Ty)) {
369 for (unsigned I = 0; I < VT->getNumElements(); ++I) {
370 Value *Elt = B.CreateCall(F, {B.getInt32(I)});
371 Result = B.CreateInsertElement(Result, Elt, I);
372 }
374 }
375 return B.CreateCall(F, {B.getInt32(0)});
376}
377
380 llvm::Type *Ty) {
381 assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");
382 if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {
383 llvm::Function *DxGroupIndex =
384 CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group);
385 return B.CreateCall(FunctionCallee(DxGroupIndex));
386 }
387 if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) {
388 llvm::Function *ThreadIDIntrinsic =
391 }
392 if (D.hasAttr<HLSLSV_GroupThreadIDAttr>()) {
393 llvm::Function *GroupThreadIDIntrinsic =
396 }
397 if (D.hasAttr<HLSLSV_GroupIDAttr>()) {
398 llvm::Function *GroupIDIntrinsic = CGM.getIntrinsic(getGroupIdIntrinsic());
400 }
401 assert(false && "Unhandled parameter attribute");
402 return nullptr;
403}
404
406 llvm::Function *Fn) {
408 llvm::LLVMContext &Ctx = M.getContext();
409 auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
411 Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M);
412
413
414
415 AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex,
416 Fn->getAttributes().getFnAttrs());
417 EntryFn->setAttributes(NewAttrs);
419
420
421 Fn->setLinkage(GlobalValue::InternalLinkage);
422
423 BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);
424 IRBuilder<> B(BB);
426
429 assert(EntryFn->isConvergent());
430 llvm::Value *I = B.CreateIntrinsic(
431 llvm::Intrinsic::experimental_convergence_entry, {}, {});
432 llvm::Value *bundleArgs[] = {I};
433 OB.emplace_back("convergencectrl", bundleArgs);
434 }
435
436
437
438 unsigned SRetOffset = 0;
439 for (const auto &Param : Fn->args()) {
440 if (Param.hasStructRetAttr()) {
441
442
443 SRetOffset = 1;
444 Args.emplace_back(PoisonValue::get(Param.getType()));
445 continue;
446 }
449 }
450
451 CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args, OB);
452 CI->setCallingConv(Fn->getCallingConv());
453
454
455 B.CreateRetVoid();
456}
457
459 llvm::Function *Fn) {
461 const StringRef ExportAttrKindStr = "hlsl.export";
462 Fn->addFnAttr(ExportAttrKindStr);
463 }
464}
465
467 bool CtorOrDtor) {
468 const auto *GV =
469 M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors");
470 if (!GV)
471 return;
472 const auto *CA = dyn_cast(GV->getInitializer());
473 if (!CA)
474 return;
475
476
477
478
480 for (const auto &Ctor : CA->operands()) {
481 if (isa(Ctor))
482 continue;
483 ConstantStruct *CS = cast(Ctor);
484
485 assert(cast(CS->getOperand(0))->getValue() == 65535 &&
486 "HLSL doesn't support setting priority for global ctors.");
487 assert(isa(CS->getOperand(2)) &&
488 "HLSL doesn't support COMDat for global ctors.");
489 Fns.push_back(cast(CS->getOperand(1)));
490 }
491}
492
499
500
501
502
503 for (auto &F : M.functions()) {
504 if (!F.hasFnAttribute("hlsl.shader"))
505 continue;
507 Instruction *IP = &*F.getEntryBlock().begin();
510 llvm::Value *bundleArgs[] = {Token};
511 OB.emplace_back("convergencectrl", bundleArgs);
512 IP = Token->getNextNode();
513 }
514 IRBuilder<> B(IP);
515 for (auto *Fn : CtorFns) {
516 auto CI = B.CreateCall(FunctionCallee(Fn), {}, OB);
517 CI->setCallingConv(Fn->getCallingConv());
518 }
519
520
521 B.SetInsertPoint(F.back().getTerminator());
522 for (auto *Fn : DtorFns) {
523 auto CI = B.CreateCall(FunctionCallee(Fn), {}, OB);
524 CI->setCallingConv(Fn->getCallingConv());
525 }
526 }
527
528
529
530 Triple T(M.getTargetTriple());
531 if (T.getEnvironment() != Triple::EnvironmentType::Library) {
532 if (auto *GV = M.getNamedGlobal("llvm.global_ctors"))
533 GV->eraseFromParent();
534 if (auto *GV = M.getNamedGlobal("llvm.global_dtors"))
535 GV->eraseFromParent();
536 }
537}
538
539
542}
543
545 llvm::GlobalVariable *GV, unsigned Slot,
546 unsigned Space) {
548 llvm::Type *Int1Ty = llvm::Type::getInt1Ty(Ctx);
549
550 llvm::Function *InitResFunc = llvm::Function::Create(
551 llvm::FunctionType::get(CGM.VoidTy, false),
552 llvm::GlobalValue::InternalLinkage,
554 InitResFunc->addFnAttr(llvm::Attribute::AlwaysInline);
555
556 llvm::BasicBlock *EntryBB =
557 llvm::BasicBlock::Create(Ctx, "entry", InitResFunc);
559 const DataLayout &DL = CGM.getModule().getDataLayout();
560 Builder.SetInsertPoint(EntryBB);
561
565
566
567
568
569 assert(AttrResType != nullptr &&
570 "Resource class must have a handle of HLSLAttributedResourceType");
571
572 llvm::Type *TargetTy =
574 assert(TargetTy != nullptr &&
575 "Failed to convert resource handle to target type");
576
577 llvm::Value *Args[] = {
578 llvm::ConstantInt::get(CGM.IntTy, Space),
579 llvm::ConstantInt::get(CGM.IntTy, Slot),
580
581 llvm::ConstantInt::get(CGM.IntTy, 1),
582 llvm::ConstantInt::get(CGM.IntTy, 0),
583
584 llvm::ConstantInt::get(Int1Ty, false)
585 };
586 llvm::Value *CreateHandle = Builder.CreateIntrinsic(
587 TargetTy,
588 CGM.getHLSLRuntime().getCreateHandleFromBindingIntrinsic(), Args, nullptr,
589 Twine(VD->getName()).concat("_h"));
590
591 llvm::Value *HandleRef = Builder.CreateStructGEP(GV->getValueType(), GV, 0);
592 Builder.CreateAlignedStore(CreateHandle, HandleRef,
593 HandleRef->getPointerAlignment(DL));
594 Builder.CreateRetVoid();
595
597}
598
600 llvm::GlobalVariable *GV) {
601
602
603
604 const HLSLResourceBindingAttr *RBA = VD->getAttr();
605 if (!RBA)
606
607
608 return;
609
611
612
613
614 return;
615
617 RBA->getSpaceNumber());
618}
619
622 return nullptr;
623
624 auto E = BB.end();
625 for (auto I = BB.begin(); I != E; ++I) {
626 auto *II = dyn_castllvm::IntrinsicInst(&*I);
627 if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID())) {
628 return II;
629 }
630 }
631 llvm_unreachable("Convergence token should have been emitted.");
632 return nullptr;
633}
static llvm::hlsl::ElementType calculateElementType(const ASTContext &Context, const clang::Type *ResourceTy)
static void gatherFunctions(SmallVectorImpl< Function * > &Fns, llvm::Module &M, bool CtorOrDtor)
static Value * buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty)
static void createResourceInitFn(CodeGenModule &CGM, const VarDecl *VD, llvm::GlobalVariable *GV, unsigned Slot, unsigned Space)
static bool isResourceRecordType(const clang::Type *Ty)
static std::string getName(const CallEvent &Call)
Defines the clang::TargetOptions class.
Holds long-lived AST nodes (such as types and decls) that can be referred to throughout the semantic ...
uint64_t getTypeSize(QualType T) const
Return the size of the specified (complete) type T, in bits.
This class gathers all debug information during compilation and is responsible for emitting to llvm g...
llvm::Instruction * getConvergenceToken(llvm::BasicBlock &BB)
void setHLSLEntryAttributes(const FunctionDecl *FD, llvm::Function *Fn)
void setHLSLFunctionAttributes(const FunctionDecl *FD, llvm::Function *Fn)
void emitEntryFunction(const FunctionDecl *FD, llvm::Function *Fn)
void handleGlobalVarDefinition(const VarDecl *VD, llvm::GlobalVariable *Var)
llvm::Type * convertHLSLSpecificType(const Type *T)
llvm::Value * emitInputSemantic(llvm::IRBuilder<> &B, const ParmVarDecl &D, llvm::Type *Ty)
void annotateHLSLResource(const VarDecl *D, llvm::GlobalVariable *GV)
void addBuffer(const HLSLBufferDecl *D)
void generateGlobalCtorDtorCalls()
This class organizes the cross-function state that is used while generating LLVM code.
CGHLSLRuntime & getHLSLRuntime()
Return a reference to the configured HLSL runtime.
llvm::Module & getModule() const
CGDebugInfo * getModuleDebugInfo()
void AddCXXGlobalInit(llvm::Function *F)
const TargetInfo & getTarget() const
void EmitGlobal(GlobalDecl D)
Emit code for a single global function or var decl.
bool shouldEmitConvergenceTokens() const
ASTContext & getContext() const
llvm::Constant * GetAddrOfGlobalVar(const VarDecl *D, llvm::Type *Ty=nullptr, ForDefinition_t IsForDefinition=NotForDefinition)
Return the llvm::Constant for the address of the given global variable.
const TargetCodeGenInfo & getTargetCodeGenInfo()
const CodeGenOptions & getCodeGenOpts() const
llvm::LLVMContext & getLLVMContext()
llvm::Function * getIntrinsic(unsigned IID, ArrayRef< llvm::Type * > Tys={})
void EmitTopLevelDecl(Decl *D)
Emit code for a single top level declaration.
virtual llvm::Type * getHLSLType(CodeGenModule &CGM, const Type *T) const
Return an LLVM type that corresponds to a HLSL type.
DeclContext - This is used only as base class of specific decl types that can act as declaration cont...
decl_range decls() const
decls_begin/decls_end - Iterate over the declarations stored in this context.
Decl - This represents one declaration (or definition), e.g.
bool isInExportDeclContext() const
Whether this declaration was exported in a lexical context.
Represents a function declaration or definition.
const ParmVarDecl * getParamDecl(unsigned i) const
static const HLSLAttributedResourceType * findHandleTypeOnResource(const Type *RT)
HLSLBufferDecl - Represent a cbuffer or tbuffer declaration.
StringRef getName() const
Get the name of identifier for this declaration as a StringRef.
Represents a parameter to a function.
A (possibly-)qualified type.
const Type * getTypePtr() const
Retrieves a pointer to the underlying (unqualified) type.
TargetOptions & getTargetOpts() const
Retrieve the target options.
const llvm::Triple & getTriple() const
Returns the target triple of the primary target.
Represents a type template specialization; the template must be a class template, a type alias templa...
Token - This structure provides full information about a lexed token.
The base class of the type hierarchy.
CXXRecordDecl * getAsCXXRecordDecl() const
Retrieves the CXXRecordDecl that this type refers to, either because the type is a RecordType or beca...
bool isSignedIntegerType() const
Return true if this is an integer type that is signed, according to C99 6.2.5p4 [char,...
bool isSpecificBuiltinType(unsigned K) const
Test for a particular builtin type.
bool isHLSLSpecificType() const
bool isUnsignedIntegerType() const
Return true if this is an integer type that is unsigned, according to C99 6.2.5p6 [which returns true...
const T * getAs() const
Member-template getAs'.
Represents a variable declaration or definition.
Represents a GCC generic vector type.
bool Const(InterpState &S, CodePtr OpPC, const T &Arg)
The JSON file list parser is used to communicate input to InstallAPI.
@ Result
The result type of a method or function.
const FunctionProtoType * T
Diagnostic wrappers for TextAPI types for error reporting.
std::vector< std::pair< llvm::GlobalVariable *, unsigned > > Constants
llvm::StructType * LayoutStruct
Buffer(const HLSLBufferDecl *D)
llvm::IntegerType * IntTy
int