MLIR: lib/Dialect/SPIRV/IR/SPIRVTypes.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/TypeSwitch.h"
20
21 #include
22 #include
23
24 using namespace mlir;
26
27
28
29
30
32 using KeyTy = std::tuple<Type, unsigned, unsigned>;
33
35 const KeyTy &key) {
37 }
38
40 return key == KeyTy(elementType, elementCount, stride);
41 }
42
44 : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
45 stride(std::get<2>(key)) {}
46
50 };
51
53 assert(elementCount && "ArrayType needs at least one element");
55 0);
56 }
57
59 unsigned stride) {
60 assert(elementCount && "ArrayType needs at least one element");
61 return Base::get(elementType.getContext(), elementType, elementCount, stride);
62 }
63
65
67
69
71 std::optional storage) {
72 llvm::cast(getElementType()).getExtensions(extensions, storage);
73 }
74
77 std::optional storage) {
79 .getCapabilities(capabilities, storage);
80 }
81
83 auto elementType = llvm::cast(getElementType());
84 std::optional<int64_t> size = elementType.getSizeInBytes();
85 if (!size)
86 return std::nullopt;
88 }
89
90
91
92
93
95 if (auto vectorType = llvm::dyn_cast(type))
96 return isValid(vectorType);
100 }
101
103 return type.getRank() == 1 &&
104 llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
105 llvm::isa(type.getElementType());
106 }
107
111 [](auto type) { return type.getElementType(); })
113 .Case(
115 .Default(
116 [](Type) -> Type { llvm_unreachable("invalid composite type"); });
117 }
118
120 if (auto arrayType = llvm::dyn_cast(*this))
121 return arrayType.getNumElements();
122 if (auto matrixType = llvm::dyn_cast(*this))
123 return matrixType.getNumColumns();
124 if (auto structType = llvm::dyn_cast(*this))
125 return structType.getNumElements();
126 if (auto vectorType = llvm::dyn_cast(*this))
127 return vectorType.getNumElements();
128 if (llvm::isa(*this)) {
129 llvm_unreachable(
130 "invalid to query number of elements of spirv Cooperative Matrix type");
131 }
132 if (llvm::isa(*this)) {
133 llvm_unreachable(
134 "invalid to query number of elements of spirv::RuntimeArray type");
135 }
136 llvm_unreachable("invalid composite type");
137 }
138
140 return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this);
141 }
142
145 std::optional storage) {
149 [&](auto type) { type.getExtensions(extensions, storage); })
150 .Case([&](VectorType type) {
151 return llvm::cast(type.getElementType())
152 .getExtensions(extensions, storage);
153 })
154 .Default([](Type) { llvm_unreachable("invalid composite type"); });
155 }
156
159 std::optional storage) {
163 [&](auto type) { type.getCapabilities(capabilities, storage); })
164 .Case([&](VectorType type) {
166 if (vecSize == 8 || vecSize == 16) {
167 static const Capability caps[] = {Capability::Vector16};
169 capabilities.push_back(ref);
170 }
171 return llvm::cast(type.getElementType())
172 .getCapabilities(capabilities, storage);
173 })
174 .Default([](Type) { llvm_unreachable("invalid composite type"); });
175 }
176
178 if (auto arrayType = llvm::dyn_cast(*this))
179 return arrayType.getSizeInBytes();
180 if (auto structType = llvm::dyn_cast(*this))
181 return structType.getSizeInBytes();
182 if (auto vectorType = llvm::dyn_cast(*this)) {
183 std::optional<int64_t> elementSize =
184 llvm::cast(vectorType.getElementType()).getSizeInBytes();
185 if (!elementSize)
186 return std::nullopt;
187 return *elementSize * vectorType.getNumElements();
188 }
189 return std::nullopt;
190 }
191
192
193
194
195
197
198
199
200
201
202
203
204
205
206
207
208
209
211 std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
212
217 }
218
220 return key == KeyTy(elementType, shape[0], shape[1], scope, use);
221 }
222
224 : elementType(std::get<0>(key)),
225 shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
226 use(std::get<4>(key)) {}
227
229
230 std::array<int64_t, 2> shape;
232 CooperativeMatrixUseKHR use;
233 };
234
236 uint32_t rows,
237 uint32_t columns, Scope scope,
238 CooperativeMatrixUseKHR use) {
240 use);
241 }
242
244 return getImpl()->elementType;
245 }
246
248 assert(getImpl()->shape[0] != ShapedType::kDynamic);
249 return static_cast<uint32_t>(getImpl()->shape[0]);
250 }
251
253 assert(getImpl()->shape[1] != ShapedType::kDynamic);
254 return static_cast<uint32_t>(getImpl()->shape[1]);
255 }
256
259 }
260
262
265 }
266
269 std::optional storage) {
270 llvm::cast(getElementType()).getExtensions(extensions, storage);
271 static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
272 extensions.push_back(exts);
273 }
274
277 std::optional storage) {
279 .getCapabilities(capabilities, storage);
280 static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR};
281 capabilities.push_back(caps);
282 }
283
284
285
286
287
288 template
290 return 0;
291 }
292 template <>
294 static_assert((1 << 3) > getMaxEnumValForDim(),
295 "Not enough bits to encode Dim value");
296 return 3;
297 }
298 template <>
300 static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
301 "Not enough bits to encode ImageDepthInfo value");
302 return 2;
303 }
304 template <>
306 static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
307 "Not enough bits to encode ImageArrayedInfo value");
308 return 1;
309 }
310 template <>
312 static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
313 "Not enough bits to encode ImageSamplingInfo value");
314 return 1;
315 }
316 template <>
318 static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
319 "Not enough bits to encode ImageSamplerUseInfo value");
320 return 2;
321 }
322 template <>
324 static_assert((1 << 6) > getMaxEnumValForImageFormat(),
325 "Not enough bits to encode ImageFormat value");
326 return 6;
327 }
328
330 public:
331 using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
332 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
333
335 const KeyTy &key) {
337 }
338
340 return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
341 samplerUseInfo, format);
342 }
343
345 : elementType(std::get<0>(key)), dim(std::get<1>(key)),
346 depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
347 samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
348 format(std::get<6>(key)) {}
349
357 };
358
361 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
362 value) {
364 }
365
367
369
371
373 return getImpl()->arrayedInfo;
374 }
375
377 return getImpl()->samplingInfo;
378 }
379
381 return getImpl()->samplerUseInfo;
382 }
383
385
387 std::optional) {
388
389 }
390
393 std::optional) {
394 if (auto dimCaps = spirv::getCapabilities(getDim()))
395 capabilities.push_back(*dimCaps);
396
397 if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
398 capabilities.push_back(*fmtCaps);
399 }
400
401
402
403
404
406
407
408 using KeyTy = std::pair<Type, StorageClass>;
409
411 const KeyTy &key) {
414 }
415
417 return key == KeyTy(pointeeType, storageClass);
418 }
419
421 : pointeeType(key.first), storageClass(key.second) {}
422
425 };
426
429 }
430
432
434 return getImpl()->storageClass;
435 }
436
438 std::optional storage) {
439
440
443
444 if (auto scExts = spirv::getExtensions(getStorageClass()))
445 extensions.push_back(*scExts);
446 }
447
450 std::optional storage) {
451
452
455
456 if (auto scCaps = spirv::getCapabilities(getStorageClass()))
457 capabilities.push_back(*scCaps);
458 }
459
460
461
462
463
465 using KeyTy = std::pair<Type, unsigned>;
466
468 const KeyTy &key) {
471 }
472
474 return key == KeyTy(elementType, stride);
475 }
476
478 : elementType(key.first), stride(key.second) {}
479
482 };
483
486 }
487
490 }
491
493
495
498 std::optional storage) {
499 llvm::cast(getElementType()).getExtensions(extensions, storage);
500 }
501
504 std::optional storage) {
505 {
506 static const Capability caps[] = {Capability::Shader};
508 capabilities.push_back(ref);
509 }
511 .getCapabilities(capabilities, storage);
512 }
513
514
515
516
517
519 if (auto floatType = llvm::dyn_cast(type)) {
520 return isValid(floatType);
521 }
522 if (auto intType = llvm::dyn_cast(type)) {
524 }
525 return false;
526 }
527
529 return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
530 }
531
533 return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
534 }
535
537 std::optional storage) {
538 if (isa(*this)) {
539 static const Extension ext = Extension::SPV_KHR_bfloat16;
540 extensions.push_back(ext);
541 }
542
543
544
545
546 if (!storage)
547 return;
548
549 switch (*storage) {
550 case StorageClass::PushConstant:
551 case StorageClass::StorageBuffer:
552 case StorageClass::Uniform:
554 static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
556 extensions.push_back(ref);
557 }
558 [[fallthrough]];
559 case StorageClass::Input:
560 case StorageClass::Output:
562 static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
564 extensions.push_back(ref);
565 }
566 break;
567 default:
568 break;
569 }
570 }
571
574 std::optional storage) {
576
577
578
579
580
// Expands to a `case StorageClass::<storage>` label for use inside a switch.
// When `bitwidth` is 8 it appends `Capability::<cap8>` to `capabilities` and
// returns from the enclosing function; when `bitwidth` is 16 it does the same
// with `Capability::<cap16>`. For any other width it breaks out of the switch
// without adding a capability. Relies on `bitwidth` and `capabilities` being
// in scope at the expansion site.
581 #define STORAGE_CASE(storage, cap8, cap16) \
582 case StorageClass::storage: { \
583 if (bitwidth == 8) { \
584 static const Capability caps[] = {Capability::cap8}; \
585 ArrayRef ref(caps, std::size(caps)); \
586 capabilities.push_back(ref); \
587 return; \
588 } \
589 if (bitwidth == 16) { \
590 static const Capability caps[] = {Capability::cap16}; \
591 ArrayRef ref(caps, std::size(caps)); \
592 capabilities.push_back(ref); \
593 return; \
594 } \
595 \
596 \
597 } break
598
599
600
601 if (storage) {
602 switch (*storage) {
603 STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
605 StorageBuffer16BitAccess);
607 StorageUniform16);
608 case StorageClass::Input:
609 case StorageClass::Output: {
610 if (bitwidth == 16) {
611 static const Capability caps[] = {Capability::StorageInputOutput16};
613 capabilities.push_back(ref);
614 return;
615 }
616 break;
617 }
618 default:
619 break;
620 }
621 }
622 #undef STORAGE_CASE
623
624
625
626
// Expands to a `case <width>` label for use inside a switch. It appends the
// capability whose name is formed by pasting `type` and `width` together
// (for example, WIDTH_CASE(Int, 8) would name Capability::Int8) to
// `capabilities`, then breaks out of the switch. Relies on `capabilities`
// being in scope at the expansion site.
627 #define WIDTH_CASE(type, width) \
628 case width: { \
629 static const Capability caps[] = {Capability::type##width}; \
630 ArrayRef ref(caps, std::size(caps)); \
631 capabilities.push_back(ref); \
632 } break
633
634 if (auto intType = llvm::dyn_cast(*this)) {
635 switch (bitwidth) {
639 case 1:
640 case 32:
641 break;
642 default:
643 llvm_unreachable("invalid bitwidth to getCapabilities");
644 }
645 } else {
646 assert(llvm::isa(*this));
647 switch (bitwidth) {
648 case 16: {
649 if (isa(*this)) {
650 static const Capability cap = Capability::BFloat16TypeKHR;
651 capabilities.push_back(cap);
652 } else {
653 static const Capability cap = Capability::Float16;
654 capabilities.push_back(cap);
655 }
656 break;
657 }
659 case 32:
660 break;
661 default:
662 llvm_unreachable("invalid bitwidth to getCapabilities");
663 }
664 }
665
666 #undef WIDTH_CASE
667 }
668
671
672
673
674
675
676
677 if (bitWidth == 1)
678 return std::nullopt;
679 return bitWidth / 8;
680 }
681
682
683
684
685
687
688 if (llvm::isa(type.getDialect()))
689 return true;
690 if (llvm::isa(type))
691 return true;
692 if (auto vectorType = llvm::dyn_cast(type))
694 return false;
695 }
696
698 return isIntOrFloat() || llvm::isa(*this);
699 }
700
702 std::optional storage) {
703 if (auto scalarType = llvm::dyn_cast(*this)) {
704 scalarType.getExtensions(extensions, storage);
705 } else if (auto compositeType = llvm::dyn_cast(*this)) {
706 compositeType.getExtensions(extensions, storage);
707 } else if (auto imageType = llvm::dyn_cast(*this)) {
708 imageType.getExtensions(extensions, storage);
709 } else if (auto sampledImageType = llvm::dyn_cast(*this)) {
710 sampledImageType.getExtensions(extensions, storage);
711 } else if (auto matrixType = llvm::dyn_cast(*this)) {
712 matrixType.getExtensions(extensions, storage);
713 } else if (auto ptrType = llvm::dyn_cast(*this)) {
714 ptrType.getExtensions(extensions, storage);
715 } else {
716 llvm_unreachable("invalid SPIR-V Type to getExtensions");
717 }
718 }
719
722 std::optional storage) {
723 if (auto scalarType = llvm::dyn_cast(*this)) {
724 scalarType.getCapabilities(capabilities, storage);
725 } else if (auto compositeType = llvm::dyn_cast(*this)) {
726 compositeType.getCapabilities(capabilities, storage);
727 } else if (auto imageType = llvm::dyn_cast(*this)) {
728 imageType.getCapabilities(capabilities, storage);
729 } else if (auto sampledImageType = llvm::dyn_cast(*this)) {
730 sampledImageType.getCapabilities(capabilities, storage);
731 } else if (auto matrixType = llvm::dyn_cast(*this)) {
732 matrixType.getCapabilities(capabilities, storage);
733 } else if (auto ptrType = llvm::dyn_cast(*this)) {
734 ptrType.getCapabilities(capabilities, storage);
735 } else {
736 llvm_unreachable("invalid SPIR-V Type to getCapabilities");
737 }
738 }
739
741 if (auto scalarType = llvm::dyn_cast(*this))
742 return scalarType.getSizeInBytes();
743 if (auto compositeType = llvm::dyn_cast(*this))
744 return compositeType.getSizeInBytes();
745 return std::nullopt;
746 }
747
748
749
750
753
755
757
759 const KeyTy &key) {
762 }
763
765 };
766
769 }
770
773 Type imageType) {
775 }
776
778
779 LogicalResult
781 Type imageType) {
782 if (!llvm::isa(imageType))
783 return emitError() << "expected image type";
784
785 return success();
786 }
787
790 std::optional storage) {
791 llvm::cast(getImageType()).getExtensions(extensions, storage);
792 }
793
796 std::optional storage) {
797 llvm::cast(getImageType()).getCapabilities(capabilities, storage);
798 }
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
819
820
821
823 : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
824 numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
825 identifier(identifier) {}
826
827
828
830 unsigned numMembers, Type const *memberTypes,
833 : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
834 numMembers(numMembers), numMemberDecorations(numMemberDecorations),
835 memberDecorationsInfo(memberDecorationsInfo) {}
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
854
855
856
857
858
859
861 if (isIdentified()) {
862
863 return getIdentifier() == std::get<0>(key);
864 }
865
866 return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
867 getMemberDecorationsInfo());
868 }
869
870
871
872
873
874
875
877 const KeyTy &key) {
878 StringRef keyIdentifier = std::get<0>(key);
879
880 if (!keyIdentifier.empty()) {
881 StringRef identifier = allocator.copyInto(keyIdentifier);
882
883
884
887 }
888
890
891
892 const Type *typesList = nullptr;
893 if (!keyTypes.empty()) {
894 typesList = allocator.copyInto(keyTypes).data();
895 }
896
898 if (!std::get<2>(key).empty()) {
900 assert(keyOffsetInfo.size() == keyTypes.size() &&
901 "size of offset information must be same as the size of number of "
902 "elements");
903 offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
904 }
905
907 unsigned numMemberDecorations = 0;
908 if (!std::get<3>(key).empty()) {
909 auto keyMemberDecorations = std::get<3>(key);
910 numMemberDecorations = keyMemberDecorations.size();
911 memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
912 }
913
916 numMemberDecorations, memberDecorationList);
917 }
918
920 return ArrayRef(memberTypesAndIsBodySet.getPointer(), numMembers);
921 }
922
924 if (offsetInfo) {
926 }
927 return {};
928 }
929
931 if (memberDecorationsInfo) {
933 numMemberDecorations);
934 }
935 return {};
936 }
937
939
// Returns true when this storage holds a non-empty `identifier`.
940 bool isIdentified() const { return !identifier.empty(); }
941
942
943
944
945
946
947
948
949
954 if (!isIdentified())
955 return failure();
956
957 if (memberTypesAndIsBodySet.getInt() &&
958 (getMemberTypes() != structMemberTypes ||
959 getOffsetInfo() != structOffsetInfo ||
960 getMemberDecorationsInfo() != structMemberDecorationInfo))
961 return failure();
962
963 memberTypesAndIsBodySet.setInt(true);
964 numMembers = structMemberTypes.size();
965
966
967 if (!structMemberTypes.empty())
968 memberTypesAndIsBodySet.setPointer(
969 allocator.copyInto(structMemberTypes).data());
970
971 if (!structOffsetInfo.empty()) {
972 assert(structOffsetInfo.size() == structMemberTypes.size() &&
973 "size of offset information must be same as the size of number of "
974 "elements");
975 offsetInfo = allocator.copyInto(structOffsetInfo).data();
976 }
977
978 if (!structMemberDecorationInfo.empty()) {
979 numMemberDecorations = structMemberDecorationInfo.size();
980 memberDecorationsInfo =
981 allocator.copyInto(structMemberDecorationInfo).data();
982 }
983
984 return success();
985 }
986
993 };
994
999 assert(!memberTypes.empty() && "Struct needs at least one member type");
1000
1002 memberDecorations);
1003 llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
1004 return Base::get(memberTypes.vec().front().getContext(),
1005 StringRef(), memberTypes, offsetInfo,
1006 sortedDecorations);
1007 }
1008
1010 StringRef identifier) {
1011 assert(!identifier.empty() &&
1012 "StructType identifier must be non-empty string");
1013
1017 }
1018
1023
1029
1030 return newStructType;
1031 }
1032
1034
1036
1038
1040 assert(getNumElements() > index && "member index out of range");
1041 return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1042 }
1043
1045 return TypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1047 }
1048
1050
1052 assert(getNumElements() > index && "member index out of range");
1053 return getImpl()->offsetInfo[index];
1054 }
1055
1058 const {
1059 memberDecorations.clear();
1060 auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1061 memberDecorations.append(implMemberDecorations.begin(),
1062 implMemberDecorations.end());
1063 }
1064
1066 unsigned index,
1068 assert(getNumElements() > index && "member index out of range");
1069 auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1070 decorationsInfo.clear();
1071 for (const auto &memberDecoration : memberDecorations) {
1072 if (memberDecoration.memberIndex == index) {
1073 decorationsInfo.push_back(memberDecoration);
1074 }
1075 if (memberDecoration.memberIndex > index) {
1076
1077 return;
1078 }
1079 }
1080 }
1081
1082 LogicalResult
1086 return Base::mutate(memberTypes, offsetInfo, memberDecorations);
1087 }
1088
1090 std::optional storage) {
1092 llvm::cast(elementType).getExtensions(extensions, storage);
1093 }
1094
1097 std::optional storage) {
1099 llvm::cast(elementType).getCapabilities(capabilities, storage);
1100 }
1101
1104 return llvm::hash_combine(memberDecorationInfo.memberIndex,
1106 }
1107
1108
1109
1110
1111
1114 : columnType(columnType), columnCount(columnCount) {}
1115
1116 using KeyTy = std::tuple<Type, uint32_t>;
1117
1119 const KeyTy &key) {
1120
1121
1124 }
1125
1127 return key == KeyTy(columnType, columnCount);
1128 }
1129
1132 };
1133
1136 }
1137
1139 Type columnType, uint32_t columnCount) {
1141 columnCount);
1142 }
1143
1144 LogicalResult
1146 Type columnType, uint32_t columnCount) {
1147 if (columnCount < 2 || columnCount > 4)
1148 return emitError() << "matrix can have 2, 3, or 4 columns only";
1149
1151 return emitError() << "matrix columns must be vectors of floats";
1152
1153
1154 ArrayRef<int64_t> columnShape = llvm::cast(columnType).getShape();
1155 if (columnShape.size() != 1)
1156 return emitError() << "matrix columns must be 1D vectors";
1157
1158 if (columnShape[0] < 2 || columnShape[0] > 4)
1159 return emitError() << "matrix columns must be of size 2, 3, or 4";
1160
1161 return success();
1162 }
1163
1164
1166 if (auto vectorType = llvm::dyn_cast(columnType)) {
1167 if (llvm::isa(vectorType.getElementType()))
1168 return true;
1169 }
1170 return false;
1171 }
1172
1174
1176 return llvm::cast(getImpl()->columnType).getElementType();
1177 }
1178
1180
1182 return llvm::cast(getImpl()->columnType).getShape()[0];
1183 }
1184
1187 }
1188
1190 std::optional storage) {
1191 llvm::cast(getColumnType()).getExtensions(extensions, storage);
1192 }
1193
1196 std::optional storage) {
1197 {
1198 static const Capability caps[] = {Capability::Matrix};
1200 capabilities.push_back(ref);
1201 }
1202
1203 llvm::cast(getColumnType()).getCapabilities(capabilities, storage);
1204 }
1205
1206
1207
1208
1209
1210 void SPIRVDialect::registerTypes() {
1213 }
static MLIRContext * getContext(OpFoldResult val)
constexpr unsigned getNumBits< ImageSamplerUseInfo >()
#define STORAGE_CASE(storage, cap8, cap16)
constexpr unsigned getNumBits< ImageFormat >()
static constexpr unsigned getNumBits()
#define WIDTH_CASE(type, width)
constexpr unsigned getNumBits< ImageArrayedInfo >()
constexpr unsigned getNumBits< ImageSamplingInfo >()
constexpr unsigned getNumBits< Dim >()
constexpr unsigned getNumBits< ImageDepthInfo >()
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
This is a utility allocator used to allocate memory for instances of derived types.
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Base storage class appearing in a Type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
ImplType * getImpl() const
Utility for easy access to the storage instance.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
std::optional< int64_t > getSizeInBytes()
Returns the array size in bytes.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
std::optional< int64_t > getSizeInBytes()
bool hasCompileTimeKnownNumElements() const
Return true if the number of elements is known at compile time and is not implementation dependent.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumElements() const
Return the number of elements of the type.
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Type getElementType(unsigned) const
static bool classof(Type type)
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
ArrayRef< int64_t > getShape() const
Type getElementType() const
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumElements() const
Returns total number of elements (rows*columns).
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)
unsigned getNumColumns() const
Returns the number of columns.
static bool isValidColumnType(Type columnType)
Returns true if the matrix elements are vectors of float elements.
Type getElementType() const
Returns the elements' type (i.e., single element type).
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getNumRows() const
Returns the number of rows.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Type getPointeeType() const
StorageClass getStorageClass() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static PointerType get(Type pointeeType, StorageClass storageClass)
Type getElementType() const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
unsigned getArrayStride() const
Returns the array stride in bytes.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static RuntimeArrayType get(Type elementType)
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
static bool classof(Type type)
void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Appends to capabilities the capabilities needed for this type to appear in the given storage class.
void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Appends to extensions the extensions needed for this type to appear in the given storage class.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< spirv::StorageClass > storage=std::nullopt)
static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)
Type getImageType() const
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< spirv::StorageClass > storage=std::nullopt)
static SampledImageType get(Type imageType)
static bool classof(Type type)
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
std::optional< int64_t > getSizeInBytes()
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Sets the contents of an incomplete identified StructType.
TypeRange getElementTypes() const
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction code.
ArrayTypeStorage(const KeyTy &key)
std::tuple< Type, unsigned, unsigned > KeyTy
static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
CooperativeMatrixTypeStorage(const KeyTy &key)
std::tuple< Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR > KeyTy
static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
CooperativeMatrixUseKHR use
std::array< int64_t, 2 > shape
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
ImageSamplerUseInfo samplerUseInfo
static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
ImageTypeStorage(const KeyTy &key)
std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy
ImageSamplingInfo samplingInfo
ImageArrayedInfo arrayedInfo
const uint32_t columnCount
MatrixTypeStorage(Type columnType, uint32_t columnCount)
bool operator==(const KeyTy &key) const
std::tuple< Type, uint32_t > KeyTy
static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
StorageClass storageClass
PointerTypeStorage(const KeyTy &key)
bool operator==(const KeyTy &key) const
std::pair< Type, StorageClass > KeyTy
static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
RuntimeArrayTypeStorage(const KeyTy &key)
std::pair< Type, unsigned > KeyTy
static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
SampledImageTypeStorage(const KeyTy &key)
static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
Type storage for SPIR-V structure types:
ArrayRef< StructType::OffsetInfo > getOffsetInfo() const
StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo)
Construct a storage object for a literal struct type.
StructType::OffsetInfo const * offsetInfo
bool operator==(const KeyTy &key) const
For identified structs, return true if the given key contains the same identifier.
LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo)
Sets the struct type content for identified structs.
static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
If the given key contains a non-empty identifier, this method constructs an identified struct and leaves the rest of the struct type data to be set through a later call to StructType::trySetBody.
ArrayRef< Type > getMemberTypes() const
StructTypeStorage(StringRef identifier)
Construct a storage object for an identified struct type.
ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const
std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo > > KeyTy
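A storage key is divided into 2 parts: a StringRef identifier, which is the only part compared for identified structs, and the member types, offset info, and member decoration info, which are compared for literal structs (whose identifier is empty).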
StructType::MemberDecorationInfo const * memberDecorationsInfo
llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet
StringRef getIdentifier() const
unsigned numMemberDecorations
bool isIdentified() const