clang: lib/CodeGen/CGHLSLRuntime.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
21#include "llvm/IR/GlobalVariable.h"
22#include "llvm/IR/LLVMContext.h"
23#include "llvm/IR/Metadata.h"
24#include "llvm/IR/Module.h"
25#include "llvm/IR/Value.h"
26#include "llvm/Support/Alignment.h"
27
28#include "llvm/Support/FormatVariadic.h"
29
30using namespace clang;
31using namespace CodeGen;
33using namespace llvm;
34
35namespace {
36
37void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
38
39
40 VersionTuple Version;
41 if (Version.tryParse(ValVersionStr) || Version.getBuild() ||
42 Version.getSubminor() || !Version.getMinor()) {
43 return;
44 }
45
46 uint64_t Major = Version.getMajor();
47 uint64_t Minor = *Version.getMinor();
48
49 auto &Ctx = M.getContext();
50 IRBuilder<> B(M.getContext());
51 MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)),
52 ConstantAsMetadata::get(B.getInt32(Minor))});
53 StringRef DXILValKey = "dx.valver";
54 auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey);
55 DXILValMD->addOperand(Val);
56}
57void addDisableOptimizations(llvm::Module &M) {
58 StringRef Key = "dx.disable_optimizations";
59 M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1);
60}
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
83 return;
84
85 std::vector<llvm::Type *> EltTys;
86 for (auto &Const : Buf.Constants) {
87 GlobalVariable *GV = Const.first;
88 Const.second = EltTys.size();
89 llvm::Type *Ty = GV->getValueType();
90 EltTys.emplace_back(Ty);
91 }
92 Buf.LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys);
93}
94
96
97 GlobalVariable *CBGV = new GlobalVariable(
99 GlobalValue::LinkageTypes::ExternalLinkage, nullptr,
100 llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."),
101 GlobalValue::NotThreadLocal);
102
103 IRBuilder<> B(CBGV->getContext());
104 Value *ZeroIdx = B.getInt32(0);
105
106 for (auto &[GV, Offset] : Buf.Constants) {
108 B.CreateGEP(Buf.LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)});
109
110 assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() &&
111 "constant type mismatch");
112
113
114 GV->replaceAllUsesWith(GEP);
115
116 GV->removeDeadConstantUsers();
117 GV->eraseFromParent();
118 }
119 return CBGV;
120}
121
122}
123
126
127
129 return TargetTy;
130
131 llvm_unreachable("Generic handling of HLSL types is not supported.");
132}
133
134llvm::Triple::ArchType CGHLSLRuntime::getArch() {
136}
137
138void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) {
139 if (D->getStorageClass() == SC_Static) {
140
141
143 return;
144 }
145
147
150 codegenoptions::DebugInfoKind::LimitedDebugInfo)
151 DI->EmitGlobalVariable(cast(GV), D);
152
153
154
155 uint32_t Offset = 0;
156 bool HasUserOffset = false;
157
158 unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX;
159 CB.Constants.emplace_back(std::make_pair(GV, LowerBound));
160}
161
162void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) {
164 if (auto *ConstDecl = dyn_cast(it)) {
165 addConstant(ConstDecl, CB);
166 } else if (isa<CXXRecordDecl, EmptyDecl>(it)) {
167
168 } else if (isa(it)) {
169
170
172 }
173 }
174}
175
177 Buffers.emplace_back(Buffer(D));
178 addBufferDecls(D, Buffers.back());
179}
180
184 Triple T(M.getTargetTriple());
185 if (T.getArch() == Triple::ArchType::dxil)
186 addDxilValVersion(TargetOpts.DxilValidatorVersion, M);
187
190 addDisableOptimizations(M);
191
192 const DataLayout &DL = M.getDataLayout();
193
194 for (auto &Buf : Buffers) {
195 layoutBuffer(Buf, DL);
196 GlobalVariable *GV = replaceBuffer(Buf);
197 M.insertGlobalVariable(GV);
198 llvm::hlsl::ResourceClass RC = Buf.IsCBuffer
199 ? llvm::hlsl::ResourceClass::CBuffer
200 : llvm::hlsl::ResourceClass::SRV;
201 llvm::hlsl::ResourceKind RK = Buf.IsCBuffer
202 ? llvm::hlsl::ResourceKind::CBuffer
203 : llvm::hlsl::ResourceKind::TBuffer;
204 addBufferResourceAnnotation(GV, RC, RK, false,
205 llvm::hlsl::ElementType::Invalid, Buf.Binding);
206 }
207}
208
210 : Name(D->getName()), IsCBuffer(D->isCBuffer()),
211 Binding(D->getAttr()) {}
212
213void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,
214 llvm::hlsl::ResourceClass RC,
215 llvm::hlsl::ResourceKind RK,
216 bool IsROV,
217 llvm::hlsl::ElementType ET,
218 BufferResBinding &Binding) {
220
221 NamedMDNode *ResourceMD = nullptr;
222 switch (RC) {
223 case llvm::hlsl::ResourceClass::UAV:
224 ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs");
225 break;
226 case llvm::hlsl::ResourceClass::SRV:
227 ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs");
228 break;
229 case llvm::hlsl::ResourceClass::CBuffer:
230 ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs");
231 break;
232 default:
233 assert(false && "Unsupported buffer type!");
234 return;
235 }
236 assert(ResourceMD != nullptr &&
237 "ResourceMD must have been set by the switch above.");
238
239 llvm::hlsl::FrontendResource Res(
240 GV, RK, ET, IsROV, Binding.Reg.value_or(UINT_MAX), Binding.Space);
241 ResourceMD->addOperand(Res.getMetadata());
242}
243
244static llvm::hlsl::ElementType
246 using llvm::hlsl::ElementType;
247
248
249
251 assert(TST && "Resource types must be template specializations");
253 assert(!Args.empty() && "Resource has no element type");
254
255
256
257 QualType ElTy = Args[0].getAsType();
258
259
261 ElTy = VecTy->getElementType();
262
265 case 16:
266 return ElementType::I16;
267 case 32:
268 return ElementType::I32;
269 case 64:
270 return ElementType::I64;
271 }
274 case 16:
275 return ElementType::U16;
276 case 32:
277 return ElementType::U32;
278 case 64:
279 return ElementType::U64;
280 }
282 return ElementType::F16;
284 return ElementType::F32;
286 return ElementType::F64;
287
288
289 llvm_unreachable("Invalid element type for resource");
290}
291
293 const Type *Ty = D->getType()->getPointeeOrArrayElementType();
294 if (!Ty)
295 return;
297 if (!RD)
298 return;
299
300
301 for (auto *FD : RD->fields()) {
302 const auto *HLSLResAttr = FD->getAttr();
304 dyn_cast(FD->getType().getTypePtr());
305 if (!HLSLResAttr || !AttrResType)
306 continue;
307
308 llvm::hlsl::ResourceClass RC = AttrResType->getAttrs().ResourceClass;
309 if (RC == llvm::hlsl::ResourceClass::UAV ||
310 RC == llvm::hlsl::ResourceClass::SRV)
311
312
313
314
315
316
317 return;
318
319 bool IsROV = AttrResType->getAttrs().IsROV;
320 llvm::hlsl::ResourceKind RK = HLSLResAttr->getResourceKind();
322
323 BufferResBinding Binding(D->getAttr());
324 addBufferResourceAnnotation(GV, RC, RK, IsROV, ET, Binding);
325 }
326}
327
328CGHLSLRuntime::BufferResBinding::BufferResBinding(
329 HLSLResourceBindingAttr *Binding) {
330 if (Binding) {
331 llvm::APInt RegInt(64, 0);
332 Binding->getSlot().substr(1).getAsInteger(10, RegInt);
333 Reg = RegInt.getLimitedValue();
334 llvm::APInt SpaceInt(64, 0);
335 Binding->getSpace().substr(5).getAsInteger(10, SpaceInt);
336 Space = SpaceInt.getLimitedValue();
337 } else {
338 Space = 0;
339 }
340}
341
343 const FunctionDecl *FD, llvm::Function *Fn) {
344 const auto *ShaderAttr = FD->getAttr();
345 assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
346 const StringRef ShaderAttrKindStr = "hlsl.shader";
347 Fn->addFnAttr(ShaderAttrKindStr,
348 llvm::Triple::getEnvironmentTypeName(ShaderAttr->getType()));
349 if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr()) {
350 const StringRef NumThreadsKindStr = "hlsl.numthreads";
351 std::string NumThreadsStr =
352 formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(),
353 NumThreadsAttr->getZ());
354 Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr);
355 }
356 if (HLSLWaveSizeAttr *WaveSizeAttr = FD->getAttr()) {
357 const StringRef WaveSizeKindStr = "hlsl.wavesize";
358 std::string WaveSizeStr =
359 formatv("{0},{1},{2}", WaveSizeAttr->getMin(), WaveSizeAttr->getMax(),
360 WaveSizeAttr->getPreferred());
361 Fn->addFnAttr(WaveSizeKindStr, WaveSizeStr);
362 }
363 Fn->addFnAttr(llvm::Attribute::NoInline);
364}
365
367 if (const auto *VT = dyn_cast(Ty)) {
369 for (unsigned I = 0; I < VT->getNumElements(); ++I) {
370 Value *Elt = B.CreateCall(F, {B.getInt32(I)});
371 Result = B.CreateInsertElement(Result, Elt, I);
372 }
374 }
375 return B.CreateCall(F, {B.getInt32(0)});
376}
377
380 llvm::Type *Ty) {
381 assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");
382 if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {
383 llvm::Function *DxGroupIndex =
384 CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group);
385 return B.CreateCall(FunctionCallee(DxGroupIndex));
386 }
387 if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) {
388 llvm::Function *ThreadIDIntrinsic =
391 }
392 if (D.hasAttr<HLSLSV_GroupThreadIDAttr>()) {
393 llvm::Function *GroupThreadIDIntrinsic =
396 }
397 if (D.hasAttr<HLSLSV_GroupIDAttr>()) {
398 llvm::Function *GroupIDIntrinsic = CGM.getIntrinsic(getGroupIdIntrinsic());
400 }
401 assert(false && "Unhandled parameter attribute");
402 return nullptr;
403}
404
406 llvm::Function *Fn) {
408 llvm::LLVMContext &Ctx = M.getContext();
409 auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
411 Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M);
412
413
414
415 AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex,
416 Fn->getAttributes().getFnAttrs());
417 EntryFn->setAttributes(NewAttrs);
419
420
421 Fn->setLinkage(GlobalValue::InternalLinkage);
422
423 BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);
424 IRBuilder<> B(BB);
426
429 assert(EntryFn->isConvergent());
430 llvm::Value *I = B.CreateIntrinsic(
431 llvm::Intrinsic::experimental_convergence_entry, {}, {});
432 llvm::Value *bundleArgs[] = {I};
433 OB.emplace_back("convergencectrl", bundleArgs);
434 }
435
436
437
438 unsigned SRetOffset = 0;
439 for (const auto &Param : Fn->args()) {
440 if (Param.hasStructRetAttr()) {
441
442
443 SRetOffset = 1;
444 Args.emplace_back(PoisonValue::get(Param.getType()));
445 continue;
446 }
449 }
450
451 CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args, OB);
452 CI->setCallingConv(Fn->getCallingConv());
453
454
455 B.CreateRetVoid();
456}
457
459 llvm::Function *Fn) {
461 const StringRef ExportAttrKindStr = "hlsl.export";
462 Fn->addFnAttr(ExportAttrKindStr);
463 }
464}
465
467 bool CtorOrDtor) {
468 const auto *GV =
469 M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors");
470 if (!GV)
471 return;
472 const auto *CA = dyn_cast(GV->getInitializer());
473 if (!CA)
474 return;
475
476
477
478
480 for (const auto &Ctor : CA->operands()) {
481 if (isa(Ctor))
482 continue;
483 ConstantStruct *CS = cast(Ctor);
484
485 assert(cast(CS->getOperand(0))->getValue() == 65535 &&
486 "HLSL doesn't support setting priority for global ctors.");
487 assert(isa(CS->getOperand(2)) &&
488 "HLSL doesn't support COMDat for global ctors.");
489 Fns.push_back(cast(CS->getOperand(1)));
490 }
491}
492
499
500
501
502
503 for (auto &F : M.functions()) {
504 if (!F.hasFnAttribute("hlsl.shader"))
505 continue;
507 Instruction *IP = &*F.getEntryBlock().begin();
510 llvm::Value *bundleArgs[] = {Token};
511 OB.emplace_back("convergencectrl", bundleArgs);
512 IP = Token->getNextNode();
513 }
514 IRBuilder<> B(IP);
515 for (auto *Fn : CtorFns) {
516 auto CI = B.CreateCall(FunctionCallee(Fn), {}, OB);
517 CI->setCallingConv(Fn->getCallingConv());
518 }
519
520
521 B.SetInsertPoint(F.back().getTerminator());
522 for (auto *Fn : DtorFns) {
523 auto CI = B.CreateCall(FunctionCallee(Fn), {}, OB);
524 CI->setCallingConv(Fn->getCallingConv());
525 }
526 }
527
528
529
530 Triple T(M.getTargetTriple());
531 if (T.getEnvironment() != Triple::EnvironmentType::Library) {
532 if (auto *GV = M.getNamedGlobal("llvm.global_ctors"))
533 GV->eraseFromParent();
534 if (auto *GV = M.getNamedGlobal("llvm.global_dtors"))
535 GV->eraseFromParent();
536 }
537}
538
540 llvm::GlobalVariable *GV) {
541
542
543 const HLSLResourceBindingAttr *RBA = VD->getAttr();
544 if (!RBA)
545 return;
546
549
550
551
552 return;
553
554 ResourcesToBind.emplace_back(VD, GV);
555}
556
558 return !ResourcesToBind.empty();
559}
560
562
564
566 llvm::Type *Int1Ty = llvm::Type::getInt1Ty(Ctx);
567
568 llvm::Function *InitResBindingsFunc =
569 llvm::Function::Create(llvm::FunctionType::get(CGM.VoidTy, false),
570 llvm::GlobalValue::InternalLinkage,
572
573 llvm::BasicBlock *EntryBB =
574 llvm::BasicBlock::Create(Ctx, "entry", InitResBindingsFunc);
576 const DataLayout &DL = CGM.getModule().getDataLayout();
577 Builder.SetInsertPoint(EntryBB);
578
579 for (const auto &[VD, GV] : ResourcesToBind) {
580 for (Attr *A : VD->getAttrs()) {
581 HLSLResourceBindingAttr *RBA = dyn_cast(A);
582 if (!RBA)
583 continue;
584
587 VD->getType().getTypePtr());
588
589
590
591
592 assert(AttrResType != nullptr &&
593 "Resource class must have a handle of HLSLAttributedResourceType");
594
595 llvm::Type *TargetTy =
597 assert(TargetTy != nullptr &&
598 "Failed to convert resource handle to target type");
599
600 auto *Space = llvm::ConstantInt::get(CGM.IntTy, RBA->getSpaceNumber());
601 auto *Slot = llvm::ConstantInt::get(CGM.IntTy, RBA->getSlotNumber());
602
603 auto *Range = llvm::ConstantInt::get(CGM.IntTy, 1);
604 auto *Index = llvm::ConstantInt::get(CGM.IntTy, 0);
605
606 auto *NonUniform = llvm::ConstantInt::get(Int1Ty, false);
607 llvm::Value *Args[] = {Space, Slot, Range, Index, NonUniform};
608
609 llvm::Value *CreateHandle = Builder.CreateIntrinsic(
610 TargetTy, getCreateHandleFromBindingIntrinsic(), Args,
611 nullptr, Twine(VD->getName()).concat("_h"));
612
613 llvm::Value *HandleRef =
614 Builder.CreateStructGEP(GV->getValueType(), GV, 0);
615 Builder.CreateAlignedStore(CreateHandle, HandleRef,
616 HandleRef->getPointerAlignment(DL));
617 }
618 }
619
620 Builder.CreateRetVoid();
621 return InitResBindingsFunc;
622}
623
626 return nullptr;
627
628 auto E = BB.end();
629 for (auto I = BB.begin(); I != E; ++I) {
630 auto *II = dyn_castllvm::IntrinsicInst(&*I);
631 if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID())) {
632 return II;
633 }
634 }
635 llvm_unreachable("Convergence token should have been emitted.");
636 return nullptr;
637}
static llvm::hlsl::ElementType calculateElementType(const ASTContext &Context, const clang::Type *ResourceTy)
static void gatherFunctions(SmallVectorImpl< Function * > &Fns, llvm::Module &M, bool CtorOrDtor)
static Value * buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty)
static std::string getName(const CallEvent &Call)
Defines the clang::TargetOptions class.
Holds long-lived AST nodes (such as types and decls) that can be referred to throughout the semantic ...
uint64_t getTypeSize(QualType T) const
Return the size of the specified (complete) type T, in bits.
Attr - This represents one attribute.
This class gathers all debug information during compilation and is responsible for emitting to llvm g...
llvm::Instruction * getConvergenceToken(llvm::BasicBlock &BB)
void setHLSLEntryAttributes(const FunctionDecl *FD, llvm::Function *Fn)
void setHLSLFunctionAttributes(const FunctionDecl *FD, llvm::Function *Fn)
void emitEntryFunction(const FunctionDecl *FD, llvm::Function *Fn)
bool needsResourceBindingInitFn()
void handleGlobalVarDefinition(const VarDecl *VD, llvm::GlobalVariable *Var)
llvm::Type * convertHLSLSpecificType(const Type *T)
llvm::Value * emitInputSemantic(llvm::IRBuilder<> &B, const ParmVarDecl &D, llvm::Type *Ty)
llvm::Function * createResourceBindingInitFn()
void annotateHLSLResource(const VarDecl *D, llvm::GlobalVariable *GV)
void addBuffer(const HLSLBufferDecl *D)
void generateGlobalCtorDtorCalls()
llvm::Module & getModule() const
CGDebugInfo * getModuleDebugInfo()
const TargetInfo & getTarget() const
void EmitGlobal(GlobalDecl D)
Emit code for a single global function or var decl.
bool shouldEmitConvergenceTokens() const
ASTContext & getContext() const
llvm::Constant * GetAddrOfGlobalVar(const VarDecl *D, llvm::Type *Ty=nullptr, ForDefinition_t IsForDefinition=NotForDefinition)
Return the llvm::Constant for the address of the given global variable.
const TargetCodeGenInfo & getTargetCodeGenInfo()
const CodeGenOptions & getCodeGenOpts() const
llvm::LLVMContext & getLLVMContext()
llvm::Function * getIntrinsic(unsigned IID, ArrayRef< llvm::Type * > Tys={})
void EmitTopLevelDecl(Decl *D)
Emit code for a single top level declaration.
virtual llvm::Type * getHLSLType(CodeGenModule &CGM, const Type *T) const
Return an LLVM type that corresponds to a HLSL type.
DeclContext - This is used only as base class of specific decl types that can act as declaration cont...
decl_range decls() const
decls_begin/decls_end - Iterate over the declarations stored in this context.
Decl - This represents one declaration (or definition), e.g.
bool isInExportDeclContext() const
Whether this declaration was exported in a lexical context.
Represents a function declaration or definition.
const ParmVarDecl * getParamDecl(unsigned i) const
static const HLSLAttributedResourceType * findHandleTypeOnResource(const Type *RT)
HLSLBufferDecl - Represent a cbuffer or tbuffer declaration.
StringRef getName() const
Get the name of identifier for this declaration as a StringRef.
Represents a parameter to a function.
A (possibly-)qualified type.
const Type * getTypePtr() const
Retrieves a pointer to the underlying (unqualified) type.
TargetOptions & getTargetOpts() const
Retrieve the target options.
const llvm::Triple & getTriple() const
Returns the target triple of the primary target.
Represents a type template specialization; the template must be a class template, a type alias templa...
Token - This structure provides full information about a lexed token.
The base class of the type hierarchy.
CXXRecordDecl * getAsCXXRecordDecl() const
Retrieves the CXXRecordDecl that this type refers to, either because the type is a RecordType or beca...
bool isSignedIntegerType() const
Return true if this is an integer type that is signed, according to C99 6.2.5p4 [char,...
bool isSpecificBuiltinType(unsigned K) const
Test for a particular builtin type.
bool isHLSLSpecificType() const
bool isUnsignedIntegerType() const
Return true if this is an integer type that is unsigned, according to C99 6.2.5p6 [which returns true...
const T * getAs() const
Member-template getAs'.
Represents a variable declaration or definition.
Represents a GCC generic vector type.
bool Const(InterpState &S, CodePtr OpPC, const T &Arg)
The JSON file list parser is used to communicate input to InstallAPI.
@ Result
The result type of a method or function.
const FunctionProtoType * T
Diagnostic wrappers for TextAPI types for error reporting.
std::vector< std::pair< llvm::GlobalVariable *, unsigned > > Constants
llvm::StructType * LayoutStruct
Buffer(const HLSLBufferDecl *D)
llvm::IntegerType * IntTy
int