clang: lib/CodeGen/CGHLSLRuntime.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

21#include "llvm/IR/GlobalVariable.h"

22#include "llvm/IR/LLVMContext.h"

23#include "llvm/IR/Metadata.h"

24#include "llvm/IR/Module.h"

25#include "llvm/IR/Value.h"

26#include "llvm/Support/Alignment.h"

27

28#include "llvm/Support/FormatVariadic.h"

29

30using namespace clang;

31using namespace CodeGen;

33using namespace llvm;

34

35namespace {

36

37void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {

38

39

40 VersionTuple Version;

41 if (Version.tryParse(ValVersionStr) || Version.getBuild() ||

42 Version.getSubminor() || !Version.getMinor()) {

43 return;

44 }

45

46 uint64_t Major = Version.getMajor();

47 uint64_t Minor = *Version.getMinor();

48

49 auto &Ctx = M.getContext();

50 IRBuilder<> B(M.getContext());

51 MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)),

52 ConstantAsMetadata::get(B.getInt32(Minor))});

53 StringRef DXILValKey = "dx.valver";

54 auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey);

55 DXILValMD->addOperand(Val);

56}

57void addDisableOptimizations(llvm::Module &M) {

58 StringRef Key = "dx.disable_optimizations";

59 M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1);

60}

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

83 return;

84

85 std::vector<llvm::Type *> EltTys;

86 for (auto &Const : Buf.Constants) {

87 GlobalVariable *GV = Const.first;

88 Const.second = EltTys.size();

89 llvm::Type *Ty = GV->getValueType();

90 EltTys.emplace_back(Ty);

91 }

92 Buf.LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys);

93}

94

96

97 GlobalVariable *CBGV = new GlobalVariable(

99 GlobalValue::LinkageTypes::ExternalLinkage, nullptr,

100 llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."),

101 GlobalValue::NotThreadLocal);

102

103 IRBuilder<> B(CBGV->getContext());

104 Value *ZeroIdx = B.getInt32(0);

105

106 for (auto &[GV, Offset] : Buf.Constants) {

108 B.CreateGEP(Buf.LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)});

109

110 assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() &&

111 "constant type mismatch");

112

113

114 GV->replaceAllUsesWith(GEP);

115

116 GV->removeDeadConstantUsers();

117 GV->eraseFromParent();

118 }

119 return CBGV;

120}

121

122}

123

126

127

129 return TargetTy;

130

131 llvm_unreachable("Generic handling of HLSL types is not supported.");

132}

133

134llvm::Triple::ArchType CGHLSLRuntime::getArch() {

136}

137

138void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) {

139 if (D->getStorageClass() == SC_Static) {

140

141

143 return;

144 }

145

147

150 codegenoptions::DebugInfoKind::LimitedDebugInfo)

151 DI->EmitGlobalVariable(cast(GV), D);

152

153

154

155 uint32_t Offset = 0;

156 bool HasUserOffset = false;

157

158 unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX;

159 CB.Constants.emplace_back(std::make_pair(GV, LowerBound));

160}

161

162void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) {

164 if (auto *ConstDecl = dyn_cast(it)) {

165 addConstant(ConstDecl, CB);

166 } else if (isa<CXXRecordDecl, EmptyDecl>(it)) {

167

168 } else if (isa(it)) {

169

170

172 }

173 }

174}

175

177 Buffers.emplace_back(Buffer(D));

178 addBufferDecls(D, Buffers.back());

179}

180

184 Triple T(M.getTargetTriple());

185 if (T.getArch() == Triple::ArchType::dxil)

186 addDxilValVersion(TargetOpts.DxilValidatorVersion, M);

187

190 addDisableOptimizations(M);

191

192 const DataLayout &DL = M.getDataLayout();

193

194 for (auto &Buf : Buffers) {

195 layoutBuffer(Buf, DL);

196 GlobalVariable *GV = replaceBuffer(Buf);

197 M.insertGlobalVariable(GV);

198 llvm::hlsl::ResourceClass RC = Buf.IsCBuffer

199 ? llvm::hlsl::ResourceClass::CBuffer

200 : llvm::hlsl::ResourceClass::SRV;

201 llvm::hlsl::ResourceKind RK = Buf.IsCBuffer

202 ? llvm::hlsl::ResourceKind::CBuffer

203 : llvm::hlsl::ResourceKind::TBuffer;

204 addBufferResourceAnnotation(GV, RC, RK, false,

205 llvm::hlsl::ElementType::Invalid, Buf.Binding);

206 }

207}

208

210 : Name(D->getName()), IsCBuffer(D->isCBuffer()),

211 Binding(D->getAttr()) {}

212

213void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,

214 llvm::hlsl::ResourceClass RC,

215 llvm::hlsl::ResourceKind RK,

216 bool IsROV,

217 llvm::hlsl::ElementType ET,

218 BufferResBinding &Binding) {

220

221 NamedMDNode *ResourceMD = nullptr;

222 switch (RC) {

223 case llvm::hlsl::ResourceClass::UAV:

224 ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs");

225 break;

226 case llvm::hlsl::ResourceClass::SRV:

227 ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs");

228 break;

229 case llvm::hlsl::ResourceClass::CBuffer:

230 ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs");

231 break;

232 default:

233 assert(false && "Unsupported buffer type!");

234 return;

235 }

236 assert(ResourceMD != nullptr &&

237 "ResourceMD must have been set by the switch above.");

238

239 llvm::hlsl::FrontendResource Res(

240 GV, RK, ET, IsROV, Binding.Reg.value_or(UINT_MAX), Binding.Space);

241 ResourceMD->addOperand(Res.getMetadata());

242}

243

244static llvm::hlsl::ElementType

246 using llvm::hlsl::ElementType;

247

248

249

251 assert(TST && "Resource types must be template specializations");

253 assert(!Args.empty() && "Resource has no element type");

254

255

256

257 QualType ElTy = Args[0].getAsType();

258

259

261 ElTy = VecTy->getElementType();

262

265 case 16:

266 return ElementType::I16;

267 case 32:

268 return ElementType::I32;

269 case 64:

270 return ElementType::I64;

271 }

274 case 16:

275 return ElementType::U16;

276 case 32:

277 return ElementType::U32;

278 case 64:

279 return ElementType::U64;

280 }

282 return ElementType::F16;

284 return ElementType::F32;

286 return ElementType::F64;

287

288

289 llvm_unreachable("Invalid element type for resource");

290}

291

293 const Type *Ty = D->getType()->getPointeeOrArrayElementType();

294 if (!Ty)

295 return;

297 if (!RD)

298 return;

299

300

301 for (auto *FD : RD->fields()) {

302 const auto *HLSLResAttr = FD->getAttr();

304 dyn_cast(FD->getType().getTypePtr());

305 if (!HLSLResAttr || !AttrResType)

306 continue;

307

308 llvm::hlsl::ResourceClass RC = AttrResType->getAttrs().ResourceClass;

309 if (RC == llvm::hlsl::ResourceClass::UAV ||

310 RC == llvm::hlsl::ResourceClass::SRV)

311

312

313

314

315

316

317 return;

318

319 bool IsROV = AttrResType->getAttrs().IsROV;

320 llvm::hlsl::ResourceKind RK = HLSLResAttr->getResourceKind();

322

323 BufferResBinding Binding(D->getAttr());

324 addBufferResourceAnnotation(GV, RC, RK, IsROV, ET, Binding);

325 }

326}

327

328CGHLSLRuntime::BufferResBinding::BufferResBinding(

329 HLSLResourceBindingAttr *Binding) {

330 if (Binding) {

331 llvm::APInt RegInt(64, 0);

332 Binding->getSlot().substr(1).getAsInteger(10, RegInt);

333 Reg = RegInt.getLimitedValue();

334 llvm::APInt SpaceInt(64, 0);

335 Binding->getSpace().substr(5).getAsInteger(10, SpaceInt);

336 Space = SpaceInt.getLimitedValue();

337 } else {

338 Space = 0;

339 }

340}

341

343 const FunctionDecl *FD, llvm::Function *Fn) {

344 const auto *ShaderAttr = FD->getAttr();

345 assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");

346 const StringRef ShaderAttrKindStr = "hlsl.shader";

347 Fn->addFnAttr(ShaderAttrKindStr,

348 llvm::Triple::getEnvironmentTypeName(ShaderAttr->getType()));

349 if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr()) {

350 const StringRef NumThreadsKindStr = "hlsl.numthreads";

351 std::string NumThreadsStr =

352 formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(),

353 NumThreadsAttr->getZ());

354 Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr);

355 }

356 if (HLSLWaveSizeAttr *WaveSizeAttr = FD->getAttr()) {

357 const StringRef WaveSizeKindStr = "hlsl.wavesize";

358 std::string WaveSizeStr =

359 formatv("{0},{1},{2}", WaveSizeAttr->getMin(), WaveSizeAttr->getMax(),

360 WaveSizeAttr->getPreferred());

361 Fn->addFnAttr(WaveSizeKindStr, WaveSizeStr);

362 }

363 Fn->addFnAttr(llvm::Attribute::NoInline);

364}

365

367 if (const auto *VT = dyn_cast(Ty)) {

369 for (unsigned I = 0; I < VT->getNumElements(); ++I) {

370 Value *Elt = B.CreateCall(F, {B.getInt32(I)});

371 Result = B.CreateInsertElement(Result, Elt, I);

372 }

374 }

375 return B.CreateCall(F, {B.getInt32(0)});

376}

377

380 llvm::Type *Ty) {

381 assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");

382 if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {

383 llvm::Function *DxGroupIndex =

384 CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group);

385 return B.CreateCall(FunctionCallee(DxGroupIndex));

386 }

387 if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) {

388 llvm::Function *ThreadIDIntrinsic =

391 }

392 if (D.hasAttr<HLSLSV_GroupThreadIDAttr>()) {

393 llvm::Function *GroupThreadIDIntrinsic =

396 }

397 if (D.hasAttr<HLSLSV_GroupIDAttr>()) {

398 llvm::Function *GroupIDIntrinsic = CGM.getIntrinsic(getGroupIdIntrinsic());

400 }

401 assert(false && "Unhandled parameter attribute");

402 return nullptr;

403}

404

406 llvm::Function *Fn) {

408 llvm::LLVMContext &Ctx = M.getContext();

409 auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);

411 Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M);

412

413

414

415 AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex,

416 Fn->getAttributes().getFnAttrs());

417 EntryFn->setAttributes(NewAttrs);

419

420

421 Fn->setLinkage(GlobalValue::InternalLinkage);

422

423 BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);

424 IRBuilder<> B(BB);

426

429 assert(EntryFn->isConvergent());

430 llvm::Value *I = B.CreateIntrinsic(

431 llvm::Intrinsic::experimental_convergence_entry, {}, {});

432 llvm::Value *bundleArgs[] = {I};

433 OB.emplace_back("convergencectrl", bundleArgs);

434 }

435

436

437

438 unsigned SRetOffset = 0;

439 for (const auto &Param : Fn->args()) {

440 if (Param.hasStructRetAttr()) {

441

442

443 SRetOffset = 1;

444 Args.emplace_back(PoisonValue::get(Param.getType()));

445 continue;

446 }

449 }

450

451 CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args, OB);

452 CI->setCallingConv(Fn->getCallingConv());

453

454

455 B.CreateRetVoid();

456}

457

459 llvm::Function *Fn) {

461 const StringRef ExportAttrKindStr = "hlsl.export";

462 Fn->addFnAttr(ExportAttrKindStr);

463 }

464}

465

467 bool CtorOrDtor) {

468 const auto *GV =

469 M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors");

470 if (!GV)

471 return;

472 const auto *CA = dyn_cast(GV->getInitializer());

473 if (!CA)

474 return;

475

476

477

478

480 for (const auto &Ctor : CA->operands()) {

481 if (isa(Ctor))

482 continue;

483 ConstantStruct *CS = cast(Ctor);

484

485 assert(cast(CS->getOperand(0))->getValue() == 65535 &&

486 "HLSL doesn't support setting priority for global ctors.");

487 assert(isa(CS->getOperand(2)) &&

488 "HLSL doesn't support COMDat for global ctors.");

489 Fns.push_back(cast(CS->getOperand(1)));

490 }

491}

492

499

500

501

502

503 for (auto &F : M.functions()) {

504 if (!F.hasFnAttribute("hlsl.shader"))

505 continue;

507 Instruction *IP = &*F.getEntryBlock().begin();

510 llvm::Value *bundleArgs[] = {Token};

511 OB.emplace_back("convergencectrl", bundleArgs);

512 IP = Token->getNextNode();

513 }

514 IRBuilder<> B(IP);

515 for (auto *Fn : CtorFns) {

516 auto CI = B.CreateCall(FunctionCallee(Fn), {}, OB);

517 CI->setCallingConv(Fn->getCallingConv());

518 }

519

520

521 B.SetInsertPoint(F.back().getTerminator());

522 for (auto *Fn : DtorFns) {

523 auto CI = B.CreateCall(FunctionCallee(Fn), {}, OB);

524 CI->setCallingConv(Fn->getCallingConv());

525 }

526 }

527

528

529

530 Triple T(M.getTargetTriple());

531 if (T.getEnvironment() != Triple::EnvironmentType::Library) {

532 if (auto *GV = M.getNamedGlobal("llvm.global_ctors"))

533 GV->eraseFromParent();

534 if (auto *GV = M.getNamedGlobal("llvm.global_dtors"))

535 GV->eraseFromParent();

536 }

537}

538

540 llvm::GlobalVariable *GV) {

541

542

543 const HLSLResourceBindingAttr *RBA = VD->getAttr();

544 if (!RBA)

545 return;

546

549

550

551

552 return;

553

554 ResourcesToBind.emplace_back(VD, GV);

555}

556

558 return !ResourcesToBind.empty();

559}

560

562

564

566 llvm::Type *Int1Ty = llvm::Type::getInt1Ty(Ctx);

567

568 llvm::Function *InitResBindingsFunc =

569 llvm::Function::Create(llvm::FunctionType::get(CGM.VoidTy, false),

570 llvm::GlobalValue::InternalLinkage,

572

573 llvm::BasicBlock *EntryBB =

574 llvm::BasicBlock::Create(Ctx, "entry", InitResBindingsFunc);

576 const DataLayout &DL = CGM.getModule().getDataLayout();

577 Builder.SetInsertPoint(EntryBB);

578

579 for (const auto &[VD, GV] : ResourcesToBind) {

580 for (Attr *A : VD->getAttrs()) {

581 HLSLResourceBindingAttr *RBA = dyn_cast(A);

582 if (!RBA)

583 continue;

584

587 VD->getType().getTypePtr());

588

589

590

591

592 assert(AttrResType != nullptr &&

593 "Resource class must have a handle of HLSLAttributedResourceType");

594

595 llvm::Type *TargetTy =

597 assert(TargetTy != nullptr &&

598 "Failed to convert resource handle to target type");

599

600 auto *Space = llvm::ConstantInt::get(CGM.IntTy, RBA->getSpaceNumber());

601 auto *Slot = llvm::ConstantInt::get(CGM.IntTy, RBA->getSlotNumber());

602

603 auto *Range = llvm::ConstantInt::get(CGM.IntTy, 1);

604 auto *Index = llvm::ConstantInt::get(CGM.IntTy, 0);

605

606 auto *NonUniform = llvm::ConstantInt::get(Int1Ty, false);

607 llvm::Value *Args[] = {Space, Slot, Range, Index, NonUniform};

608

609 llvm::Value *CreateHandle = Builder.CreateIntrinsic(

610 TargetTy, getCreateHandleFromBindingIntrinsic(), Args,

611 nullptr, Twine(VD->getName()).concat("_h"));

612

613 llvm::Value *HandleRef =

614 Builder.CreateStructGEP(GV->getValueType(), GV, 0);

615 Builder.CreateAlignedStore(CreateHandle, HandleRef,

616 HandleRef->getPointerAlignment(DL));

617 }

618 }

619

620 Builder.CreateRetVoid();

621 return InitResBindingsFunc;

622}

623

626 return nullptr;

627

628 auto E = BB.end();

629 for (auto I = BB.begin(); I != E; ++I) {

630 auto *II = dyn_castllvm::IntrinsicInst(&*I);

631 if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID())) {

632 return II;

633 }

634 }

635 llvm_unreachable("Convergence token should have been emitted.");

636 return nullptr;

637}

static llvm::hlsl::ElementType calculateElementType(const ASTContext &Context, const clang::Type *ResourceTy)

static void gatherFunctions(SmallVectorImpl< Function * > &Fns, llvm::Module &M, bool CtorOrDtor)

static Value * buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty)

static std::string getName(const CallEvent &Call)

Defines the clang::TargetOptions class.

Holds long-lived AST nodes (such as types and decls) that can be referred to throughout the semantic ...

uint64_t getTypeSize(QualType T) const

Return the size of the specified (complete) type T, in bits.

Attr - This represents one attribute.

This class gathers all debug information during compilation and is responsible for emitting to llvm g...

llvm::Instruction * getConvergenceToken(llvm::BasicBlock &BB)

void setHLSLEntryAttributes(const FunctionDecl *FD, llvm::Function *Fn)

void setHLSLFunctionAttributes(const FunctionDecl *FD, llvm::Function *Fn)

void emitEntryFunction(const FunctionDecl *FD, llvm::Function *Fn)

bool needsResourceBindingInitFn()

void handleGlobalVarDefinition(const VarDecl *VD, llvm::GlobalVariable *Var)

llvm::Type * convertHLSLSpecificType(const Type *T)

llvm::Value * emitInputSemantic(llvm::IRBuilder<> &B, const ParmVarDecl &D, llvm::Type *Ty)

llvm::Function * createResourceBindingInitFn()

void annotateHLSLResource(const VarDecl *D, llvm::GlobalVariable *GV)

void addBuffer(const HLSLBufferDecl *D)

void generateGlobalCtorDtorCalls()

llvm::Module & getModule() const

CGDebugInfo * getModuleDebugInfo()

const TargetInfo & getTarget() const

void EmitGlobal(GlobalDecl D)

Emit code for a single global function or var decl.

bool shouldEmitConvergenceTokens() const

ASTContext & getContext() const

llvm::Constant * GetAddrOfGlobalVar(const VarDecl *D, llvm::Type *Ty=nullptr, ForDefinition_t IsForDefinition=NotForDefinition)

Return the llvm::Constant for the address of the given global variable.

const TargetCodeGenInfo & getTargetCodeGenInfo()

const CodeGenOptions & getCodeGenOpts() const

llvm::LLVMContext & getLLVMContext()

llvm::Function * getIntrinsic(unsigned IID, ArrayRef< llvm::Type * > Tys={})

void EmitTopLevelDecl(Decl *D)

Emit code for a single top level declaration.

virtual llvm::Type * getHLSLType(CodeGenModule &CGM, const Type *T) const

Return an LLVM type that corresponds to a HLSL type.

DeclContext - This is used only as base class of specific decl types that can act as declaration cont...

decl_range decls() const

decls_begin/decls_end - Iterate over the declarations stored in this context.

Decl - This represents one declaration (or definition), e.g.

bool isInExportDeclContext() const

Whether this declaration was exported in a lexical context.

Represents a function declaration or definition.

const ParmVarDecl * getParamDecl(unsigned i) const

static const HLSLAttributedResourceType * findHandleTypeOnResource(const Type *RT)

HLSLBufferDecl - Represent a cbuffer or tbuffer declaration.

StringRef getName() const

Get the name of identifier for this declaration as a StringRef.

Represents a parameter to a function.

A (possibly-)qualified type.

const Type * getTypePtr() const

Retrieves a pointer to the underlying (unqualified) type.

TargetOptions & getTargetOpts() const

Retrieve the target options.

const llvm::Triple & getTriple() const

Returns the target triple of the primary target.

Represents a type template specialization; the template must be a class template, a type alias templa...

Token - This structure provides full information about a lexed token.

The base class of the type hierarchy.

CXXRecordDecl * getAsCXXRecordDecl() const

Retrieves the CXXRecordDecl that this type refers to, either because the type is a RecordType or beca...

bool isSignedIntegerType() const

Return true if this is an integer type that is signed, according to C99 6.2.5p4 [char,...

bool isSpecificBuiltinType(unsigned K) const

Test for a particular builtin type.

bool isHLSLSpecificType() const

bool isUnsignedIntegerType() const

Return true if this is an integer type that is unsigned, according to C99 6.2.5p6 [which returns true...

const T * getAs() const

Member-template getAs<specific type>.

Represents a variable declaration or definition.

Represents a GCC generic vector type.

bool Const(InterpState &S, CodePtr OpPC, const T &Arg)

The JSON file list parser is used to communicate input to InstallAPI.

@ Result

The result type of a method or function.

const FunctionProtoType * T

Diagnostic wrappers for TextAPI types for error reporting.

std::vector< std::pair< llvm::GlobalVariable *, unsigned > > Constants

llvm::StructType * LayoutStruct

Buffer(const HLSLBufferDecl *D)

llvm::IntegerType * IntTy

int