LLVM: lib/Frontend/HLSL/RootSignatureMetadata.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
20
21using namespace llvm;
22
23namespace llvm {
24namespace hlsl {
26
28
30 unsigned int OpId) {
31 if (auto *CI =
33 return CI->getZExtValue();
34 return std::nullopt;
35}
36
38 unsigned int OpId) {
40 return CI->getValueAPF().convertToFloat();
41 return std::nullopt;
42}
43
45 unsigned int OpId) {
47 if (NodeText == nullptr)
48 return std::nullopt;
50}
51
52namespace {
53
54
55
56
/// Combines any number of callables into a single overload set. Passing an
/// instance to std::visit dispatches each variant alternative to whichever
/// callable accepts it.
template <class... Callables> struct OverloadedVisit : Callables... {
  // Bring every base's call operator into this scope so they overload one
  // another instead of hiding each other.
  using Callables::operator()...;
};

/// Deduction guide: allows `OverloadedVisit{lambdaA, lambdaB}` to deduce its
/// base classes from the initializers (needed for aggregate CTAD in C++17).
template <class... Callables>
OverloadedVisit(Callables...) -> OverloadedVisit<Callables...>;
61
62struct FmtRange {
66
69 Space(Range.RegisterSpace) {}
70};
71
72raw_ostream &operator<<(llvm::raw_ostream &OS, const FmtRange &Range) {
74 << ", space=" << Range.Space << ")";
75 return OS;
76}
77
78struct FmtMDNode {
79 const MDNode *Node;
80
81 FmtMDNode(const MDNode *Node) : Node(Node) {}
82};
83
84raw_ostream &operator<<(llvm::raw_ostream &OS, FmtMDNode Fmt) {
85 Fmt.Node->printTree(OS);
86 return OS;
87}
88
89static Error makeRSError(const Twine &Msg) {
91}
92}
93
94template <typename T, typename = std::enable_if_t<
95 std::is_enum_v &&
96 std::is_same_v<std::underlying_type_t, uint32_t>>>
97static Expected
101 if (!VerifyFn(*Val))
102 return makeRSError(formatv("Invalid value for {0}: {1}", ErrText, Val));
103 return static_cast<T>(*Val);
104 }
105 return makeRSError(formatv("Invalid value for {0}:", ErrText));
106}
107
109 const auto Visitor = OverloadedVisit{
111 return BuildRootFlags(Flags);
112 },
114 return BuildRootConstants(Constants);
115 },
117 return BuildRootDescriptor(Descriptor);
118 },
120 return BuildDescriptorTableClause(Clause);
121 },
123 return BuildDescriptorTable(Table);
124 },
126 return BuildStaticSampler(Sampler);
127 },
128 };
129
130 for (const RootElement &Element : Elements) {
131 MDNode *ElementMD = std::visit(Visitor, Element);
132 assert(ElementMD != nullptr &&
133 "Root Element must be initialized and validated");
134 GeneratedMetadata.push_back(ElementMD);
135 }
136
137 return MDNode::get(Ctx, GeneratedMetadata);
138}
139
145 };
147}
148
149MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) {
154 Builder.getInt32(to_underlying(Constants.Visibility))),
158 };
160}
161
162MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) {
165 assert(!ResName.empty() && "Provided an invalid Resource Class");
166 SmallString<7> Name({"Root", ResName});
170 Builder.getInt32(to_underlying(Descriptor.Visibility))),
174 Builder.getInt32(to_underlying(Descriptor.Flags))),
175 };
177}
178
179MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) {
182
185 Builder.getInt32(to_underlying(Table.Visibility))));
186
187
188
189
190 assert(Table.NumClauses <= GeneratedMetadata.size() &&
191 "Table expected all owned clauses to be generated already");
192
193 TableOperands.append(GeneratedMetadata.end() - Table.NumClauses,
194 GeneratedMetadata.end());
195
196 GeneratedMetadata.pop_back_n(Table.NumClauses);
197
199}
200
201MDNode *MetadataBuilder::BuildDescriptorTableClause(
205 assert(!ResName.empty() && "Provided an invalid Resource Class");
213 };
215}
216
217MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {
244 };
246}
247
248Error MetadataParser::parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
249 MDNode *RootFlagNode) {
251 return makeRSError("Invalid format for RootFlags Element");
252
253 if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
254 RSD.Flags = *Val;
255 else
256 return makeRSError("Invalid value for RootFlag");
257
259}
260
261Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
262 MDNode *RootConstantNode) {
264 return makeRSError("Invalid format for RootConstants Element");
265
266 Expecteddxbc::ShaderVisibility Visibility =
268 "ShaderVisibility",
271 return Error(std::move(E));
272
274 if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
276 else
277 return makeRSError("Invalid value for ShaderRegister");
278
279 if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
281 else
282 return makeRSError("Invalid value for RegisterSpace");
283
284 if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
286 else
287 return makeRSError("Invalid value for Num32BitValues");
288
290 *Visibility, Constants);
291
293}
294
295Error MetadataParser::parseRootDescriptors(
296 mcdxbc::RootSignatureDesc &RSD, MDNode *RootDescriptorNode,
301 "parseRootDescriptors should only be called with RootDescriptor "
302 "element kind.");
304 return makeRSError("Invalid format for Root Descriptor Element");
305
307 switch (ElementKind) {
309 Type = dxbc::RootParameterType::SRV;
310 break;
312 Type = dxbc::RootParameterType::UAV;
313 break;
315 Type = dxbc::RootParameterType::CBV;
316 break;
317 default:
319 break;
320 }
321
322 Expecteddxbc::ShaderVisibility Visibility =
324 "ShaderVisibility",
327 return Error(std::move(E));
328
329 mcdxbc::RootDescriptor Descriptor;
330 if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
332 else
333 return makeRSError("Invalid value for ShaderRegister");
334
335 if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
337 else
338 return makeRSError("Invalid value for RegisterSpace");
339
340 if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
341 Descriptor.Flags = *Val;
342 else
343 return makeRSError("Invalid value for Root Descriptor Flags");
344
347}
348
349Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table,
350 MDNode *RangeDescriptorNode) {
352 return makeRSError("Invalid format for Descriptor Range");
353
354 mcdxbc::DescriptorRange Range;
355
356 std::optional ElementText =
358
359 if (!ElementText.has_value())
360 return makeRSError("Invalid format for Descriptor Range");
361
362 if (*ElementText == "CBV")
364 else if (*ElementText == "SRV")
366 else if (*ElementText == "UAV")
368 else if (*ElementText == "Sampler")
370 else
371 return makeRSError(formatv("Invalid Descriptor Range type.\n{0}",
372 FmtMDNode{RangeDescriptorNode}));
373
374 if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
375 Range.NumDescriptors = *Val;
376 else
377 return makeRSError(formatv("Invalid number of Descriptor in Range.\n{0}",
378 FmtMDNode{RangeDescriptorNode}));
379
380 if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
381 Range.BaseShaderRegister = *Val;
382 else
383 return makeRSError("Invalid value for BaseShaderRegister");
384
385 if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
386 Range.RegisterSpace = *Val;
387 else
388 return makeRSError("Invalid value for RegisterSpace");
389
390 if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
391 Range.OffsetInDescriptorsFromTableStart = *Val;
392 else
393 return makeRSError("Invalid value for OffsetInDescriptorsFromTableStart");
394
395 if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
396 Range.Flags = *Val;
397 else
398 return makeRSError("Invalid value for Descriptor Range Flags");
399
402}
403
404Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
405 MDNode *DescriptorTableNode) {
406 const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
407 if (NumOperands < 2)
408 return makeRSError("Invalid format for Descriptor Table");
409
410 Expecteddxbc::ShaderVisibility Visibility =
412 "ShaderVisibility",
415 return Error(std::move(E));
416
417 mcdxbc::DescriptorTable Table;
418
419 for (unsigned int I = 2; I < NumOperands; I++) {
421 if (Element == nullptr)
422 return makeRSError(formatv("Missing Root Element Metadata Node.\n{0}",
423 FmtMDNode{DescriptorTableNode}));
424
425 if (auto Err = parseDescriptorRange(Table, Element))
426 return Err;
427 }
428
430 *Visibility, Table);
432}
433
434Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
435 MDNode *StaticSamplerNode) {
437 return makeRSError("Invalid format for Static Sampler");
438
439 mcdxbc::StaticSampler Sampler;
440
443 if (auto E = Filter.takeError())
444 return Error(std::move(E));
446
447 Expecteddxbc::TextureAddressMode AddressU =
451 return Error(std::move(E));
452 Sampler.AddressU = *AddressU;
453
454 Expecteddxbc::TextureAddressMode AddressV =
458 return Error(std::move(E));
459 Sampler.AddressV = *AddressV;
460
461 Expecteddxbc::TextureAddressMode AddressW =
465 return Error(std::move(E));
466 Sampler.AddressW = *AddressW;
467
468 if (std::optional Val = extractMdFloatValue(StaticSamplerNode, 5))
469 Sampler.MipLODBias = *Val;
470 else
471 return makeRSError("Invalid value for MipLODBias");
472
473 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
474 Sampler.MaxAnisotropy = *Val;
475 else
476 return makeRSError("Invalid value for MaxAnisotropy");
477
482 return Error(std::move(E));
484
485 Expecteddxbc::StaticBorderColor BorderColor =
488 if (auto E = BorderColor.takeError())
489 return Error(std::move(E));
490 Sampler.BorderColor = *BorderColor;
491
492 if (std::optional Val = extractMdFloatValue(StaticSamplerNode, 9))
494 else
495 return makeRSError("Invalid value for MinLOD");
496
497 if (std::optional Val = extractMdFloatValue(StaticSamplerNode, 10))
499 else
500 return makeRSError("Invalid value for MaxLOD");
501
502 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
503 Sampler.ShaderRegister = *Val;
504 else
505 return makeRSError("Invalid value for ShaderRegister");
506
507 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
508 Sampler.RegisterSpace = *Val;
509 else
510 return makeRSError("Invalid value for RegisterSpace");
511
512 Expecteddxbc::ShaderVisibility Visibility =
514 "ShaderVisibility",
517 return Error(std::move(E));
518 Sampler.ShaderVisibility = *Visibility;
519
520 if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 14))
522 else
523 return makeRSError("Invalid value for Static Sampler Flags");
524
527}
528
529Error MetadataParser::parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD,
530 MDNode *Element) {
532 if (!ElementText.has_value())
533 return makeRSError("Invalid format for Root Element");
534
536 StringSwitch(*ElementText)
545
546 switch (ElementKind) {
547
549 return parseRootFlags(RSD, Element);
551 return parseRootConstants(RSD, Element);
555 return parseRootDescriptors(RSD, Element, ElementKind);
557 return parseDescriptorTable(RSD, Element);
559 return parseStaticSampler(RSD, Element);
561 return makeRSError(
562 formatv("Invalid Root Signature Element\n{0}", FmtMDNode{Element}));
563 }
564
565 llvm_unreachable("Unhandled RootSignatureElementKind enum.");
566}
567
575 return makeRSError(
576 formatv("Samplers cannot be mixed with other resource types in a "
577 "descriptor table, {0}(location={1})",
578 getResourceClassName(CurrRC), Location));
579 CurrRC = Range.RangeType;
580 }
582}
583
588 bool IsPrevUnbound = false;
590
591 if (Range.NumDescriptors == 0)
592 continue;
593
595 Range.BaseShaderRegister, Range.NumDescriptors);
596
598 return makeRSError(
599 formatv("Overflow for shader register range: {0}", FmtRange{Range}));
600
601 bool IsAppending =
603 if (!IsAppending)
604 Offset = Range.OffsetInDescriptorsFromTableStart;
605
606 if (IsPrevUnbound && IsAppending)
607 return makeRSError(
608 formatv("Range {0} cannot be appended after an unbounded range",
609 FmtRange{Range}));
610
613
615 return makeRSError(formatv("Offset overflow for descriptor range: {0}.",
616 FmtRange{Range}));
617
618 Offset = OffsetBound + 1;
619 IsPrevUnbound =
621 }
622
624}
625
626Error MetadataParser::validateRootSignature(
631 std::move(DeferredErrs),
632 makeRSError(formatv("Invalid value for Version: {0}", RSD.Version)));
633 }
634
637 std::move(DeferredErrs),
638 makeRSError(formatv("Invalid value for RootFlags: {0}", RSD.Flags)));
639 }
640
642
643 switch (Info.Type) {
644 case dxbc::RootParameterType::Constants32Bit:
645 break;
646
647 case dxbc::RootParameterType::CBV:
648 case dxbc::RootParameterType::UAV:
649 case dxbc::RootParameterType::SRV: {
654 std::move(DeferredErrs),
655 makeRSError(formatv("Invalid value for ShaderRegister: {0}",
657
660 std::move(DeferredErrs),
661 makeRSError(formatv("Invalid value for RegisterSpace: {0}",
663
664 bool IsValidFlag =
668 if (!IsValidFlag)
670 std::move(DeferredErrs),
671 makeRSError(formatv("Invalid value for RootDescriptorFlag: {0}",
672 Descriptor.Flags)));
673 break;
674 }
675 case dxbc::RootParameterType::DescriptorTable: {
676 const mcdxbc::DescriptorTable &Table =
678 for (const mcdxbc::DescriptorRange &Range : Table) {
681 std::move(DeferredErrs),
682 makeRSError(formatv("Invalid value for RegisterSpace: {0}",
683 Range.RegisterSpace)));
684
687 std::move(DeferredErrs),
688 makeRSError(formatv("Invalid value for NumDescriptors: {0}",
689 Range.NumDescriptors)));
690
695 if (!IsValidFlag)
697 std::move(DeferredErrs),
698 makeRSError(formatv("Invalid value for DescriptorFlag: {0}",
700
703 DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
704
707 DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
708 }
709 break;
710 }
711 }
712 }
713
714 for (const mcdxbc::StaticSampler &Sampler : RSD.StaticSamplers) {
715
717 DeferredErrs =
719 makeRSError(formatv("Invalid value for MipLODBias: {0:e}",
721
723 DeferredErrs =
725 makeRSError(formatv("Invalid value for MaxAnisotropy: {0}",
726 Sampler.MaxAnisotropy)));
727
729 DeferredErrs =
731 makeRSError(formatv("Invalid value for MinLOD: {0}",
733
735 DeferredErrs =
737 makeRSError(formatv("Invalid value for MaxLOD: {0}",
739
742 std::move(DeferredErrs),
743 makeRSError(formatv("Invalid value for ShaderRegister: {0}",
744 Sampler.ShaderRegister)));
745
747 DeferredErrs =
749 makeRSError(formatv("Invalid value for RegisterSpace: {0}",
750 Sampler.RegisterSpace)));
751 bool IsValidFlag =
755 if (!IsValidFlag)
757 std::move(DeferredErrs),
758 makeRSError(formatv("Invalid value for Static Sampler Flag: {0}",
760 }
761
762 return DeferredErrs;
763}
764
765Expectedmcdxbc::RootSignatureDesc
770 for (const auto &Operand : Root->operands()) {
772 if (Element == nullptr)
774 std::move(DeferredErrs),
775 makeRSError(formatv("Missing Root Element Metadata Node.")));
776
777 if (auto Err = parseRootSignatureElement(RSD, Element))
778 DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
779 }
780
781 if (auto Err = validateRootSignature(RSD))
782 DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
783
784 if (DeferredErrs)
785 return std::move(DeferredErrs);
786
787 return std::move(RSD);
788}
789}
790}
791}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
Analysis containing CSE Info
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
Lightweight error class with error context and mandatory checking.
static ErrorSuccess success()
Create a success value.
Error takeError()
Take ownership of the stored error.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
const MDOperand & getOperand(unsigned I) const
static MDTuple * get(LLVMContext &Context, ArrayRef< Metadata * > MDs)
unsigned getNumOperands() const
Return number of MDNode operands.
LLVM_ABI StringRef getString() const
static LLVM_ABI MDString * get(LLVMContext &Context, StringRef Str)
Wrapper class representing virtual and physical registers.
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
void push_back(const T &Elt)
StringRef - Represent a constant reference to a string, i.e.
constexpr bool empty() const
empty - Check if the string is empty.
The instances of the Type class are immutable: once they are created, they are never changed.
static LLVM_ABI Type * getFloatTy(LLVMContext &C)
An efficient, type-erasing, non-owning reference to a callable.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
bool isValidShaderVisibility(uint32_t V)
bool isValidSamplerFilter(uint32_t V)
bool isValidStaticSamplerFlags(uint32_t V)
bool isValidRootDesciptorFlags(uint32_t V)
bool isValidDescriptorRangeFlags(uint32_t V)
bool isValidBorderColor(uint32_t V)
bool isValidComparisonFunc(uint32_t V)
bool isValidAddress(uint32_t V)
LLVM_ABI StringRef getResourceClassName(ResourceClass RC)
static std::optional< uint32_t > extractMdIntValue(MDNode *Node, unsigned int OpId)
Definition RootSignatureMetadata.cpp:29
LLVM_ABI bool verifyRootDescriptorFlag(uint32_t Version, dxbc::RootDescriptorFlags Flags)
LLVM_ABI uint64_t computeRangeBound(uint64_t Offset, uint32_t Size)
static const uint32_t NumDescriptorsUnbounded
static Error validateDescriptorTableRegisterOverflow(const mcdxbc::DescriptorTable &Table, uint32_t Location)
Definition RootSignatureMetadata.cpp:585
LLVM_ABI bool verifyRegisterSpace(uint32_t RegisterSpace)
static const uint32_t DescriptorTableOffsetAppend
static Error validateDescriptorTableSamplerMixin(const mcdxbc::DescriptorTable &Table, uint32_t Location)
Definition RootSignatureMetadata.cpp:569
LLVM_ABI bool verifyVersion(uint32_t Version)
static std::optional< StringRef > extractMdStringValue(MDNode *Node, unsigned int OpId)
Definition RootSignatureMetadata.cpp:44
LLVM_ABI bool verifyStaticSamplerFlags(uint32_t Version, dxbc::StaticSamplerFlags Flags)
LLVM_ABI bool verifyRootFlag(uint32_t Flags)
LLVM_ABI bool verifyLOD(float LOD)
LLVM_ABI bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type, dxbc::DescriptorRangeFlags Flags)
std::variant< dxbc::RootFlags, RootConstants, RootDescriptor, DescriptorTable, DescriptorTableClause, StaticSampler > RootElement
Models RootElement : RootFlags | RootConstants | RootParam | DescriptorTable | DescriptorTableClause ...
LLVM_ABI bool verifyNoOverflowedOffset(uint64_t Offset)
LLVM_ABI bool verifyMipLODBias(float MipLODBias)
LLVM_ABI bool verifyNumDescriptors(uint32_t NumDescriptors)
LLVM_ABI raw_ostream & operator<<(raw_ostream &OS, const dxbc::RootFlags &Flags)
The following contains the serialization interface for root elements.
LLVM_ABI bool verifyMaxAnisotropy(uint32_t MaxAnisotropy)
static Expected< T > extractEnumValue(MDNode *Node, unsigned int OpId, StringRef ErrText, llvm::function_ref< bool(uint32_t)> VerifyFn)
Definition RootSignatureMetadata.cpp:98
LLVM_ABI bool verifyRegisterValue(uint32_t RegisterValue)
static std::optional< float > extractMdFloatValue(MDNode *Node, unsigned int OpId)
Definition RootSignatureMetadata.cpp:37
std::enable_if_t< detail::IsValidPointer< X, Y >::value, X * > dyn_extract(Y &&MD)
Extract a Value from Metadata, if any.
This is an optimization pass for GlobalISel generic memory operations.
decltype(auto) dyn_cast(const From &Val)
dyn_cast - Return the argument parameter cast to the specified type.
auto formatv(bool Validate, const char *Fmt, Ts &&...Vals)
FunctionAddr VTableAddr uintptr_t uintptr_t Version
Error joinErrors(Error E1, Error E2)
Concatenate errors.
constexpr std::underlying_type_t< Enum > to_underlying(Enum E)
Returns underlying integer value of an enum.
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
Error make_error(ArgTs &&... Args)
Make a Error instance representing failure using the given error info type.
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
SmallVector< DescriptorRange > Ranges
const RootDescriptor & getRootDescriptor(size_t Index) const
const DescriptorTable & getDescriptorTable(size_t Index) const
void addParameter(dxbc::RootParameterType Type, dxbc::ShaderVisibility Visibility, RootConstants Constant)
SmallVector< StaticSampler > StaticSamplers
mcdxbc::RootParametersContainer ParametersContainer