MLIR: lib/Dialect/SPIRV/IR/SPIRVTypes.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

18 #include "llvm/ADT/STLExtras.h"

19 #include "llvm/ADT/TypeSwitch.h"

20

21 #include

22 #include

23

24 using namespace mlir;

26

27

28

29

30

32 using KeyTy = std::tuple<Type, unsigned, unsigned>;

33

35 const KeyTy &key) {

37 }

38

40 return key == KeyTy(elementType, elementCount, stride);

41 }

42

44 : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),

45 stride(std::get<2>(key)) {}

46

50 };

51

53 assert(elementCount && "ArrayType needs at least one element");

55 0);

56 }

57

59 unsigned stride) {

60 assert(elementCount && "ArrayType needs at least one element");

61 return Base::get(elementType.getContext(), elementType, elementCount, stride);

62 }

63

65

67

69

71 std::optional storage) {

72 llvm::cast(getElementType()).getExtensions(extensions, storage);

73 }

74

77 std::optional storage) {

79 .getCapabilities(capabilities, storage);

80 }

81

83 auto elementType = llvm::cast(getElementType());

84 std::optional<int64_t> size = elementType.getSizeInBytes();

85 if (!size)

86 return std::nullopt;

88 }

89

90

91

92

93

95 if (auto vectorType = llvm::dyn_cast(type))

96 return isValid(vectorType);

100 }

101

103 return type.getRank() == 1 &&

104 llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&

105 llvm::isa(type.getElementType());

106 }

107

111 [](auto type) { return type.getElementType(); })

113 .Case(

115 .Default(

116 [](Type) -> Type { llvm_unreachable("invalid composite type"); });

117 }

118

120 if (auto arrayType = llvm::dyn_cast(*this))

121 return arrayType.getNumElements();

122 if (auto matrixType = llvm::dyn_cast(*this))

123 return matrixType.getNumColumns();

124 if (auto structType = llvm::dyn_cast(*this))

125 return structType.getNumElements();

126 if (auto vectorType = llvm::dyn_cast(*this))

127 return vectorType.getNumElements();

128 if (llvm::isa(*this)) {

129 llvm_unreachable(

130 "invalid to query number of elements of spirv Cooperative Matrix type");

131 }

132 if (llvm::isa(*this)) {

133 llvm_unreachable(

134 "invalid to query number of elements of spirv::RuntimeArray type");

135 }

136 llvm_unreachable("invalid composite type");

137 }

138

140 return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this);

141 }

142

145 std::optional storage) {

149 [&](auto type) { type.getExtensions(extensions, storage); })

150 .Case([&](VectorType type) {

151 return llvm::cast(type.getElementType())

152 .getExtensions(extensions, storage);

153 })

154 .Default([](Type) { llvm_unreachable("invalid composite type"); });

155 }

156

159 std::optional storage) {

163 [&](auto type) { type.getCapabilities(capabilities, storage); })

164 .Case([&](VectorType type) {

166 if (vecSize == 8 || vecSize == 16) {

167 static const Capability caps[] = {Capability::Vector16};

169 capabilities.push_back(ref);

170 }

171 return llvm::cast(type.getElementType())

172 .getCapabilities(capabilities, storage);

173 })

174 .Default([](Type) { llvm_unreachable("invalid composite type"); });

175 }

176

178 if (auto arrayType = llvm::dyn_cast(*this))

179 return arrayType.getSizeInBytes();

180 if (auto structType = llvm::dyn_cast(*this))

181 return structType.getSizeInBytes();

182 if (auto vectorType = llvm::dyn_cast(*this)) {

183 std::optional<int64_t> elementSize =

184 llvm::cast(vectorType.getElementType()).getSizeInBytes();

185 if (!elementSize)

186 return std::nullopt;

187 return *elementSize * vectorType.getNumElements();

188 }

189 return std::nullopt;

190 }

191

192

193

194

195

197

198

199

200

201

202

203

204

205

206

207

208

209

211 std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;

212

217 }

218

220 return key == KeyTy(elementType, shape[0], shape[1], scope, use);

221 }

222

224 : elementType(std::get<0>(key)),

225 shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),

226 use(std::get<4>(key)) {}

227

229

230 std::array<int64_t, 2> shape;

232 CooperativeMatrixUseKHR use;

233 };

234

236 uint32_t rows,

237 uint32_t columns, Scope scope,

238 CooperativeMatrixUseKHR use) {

240 use);

241 }

242

244 return getImpl()->elementType;

245 }

246

248 assert(getImpl()->shape[0] != ShapedType::kDynamic);

249 return static_cast<uint32_t>(getImpl()->shape[0]);

250 }

251

253 assert(getImpl()->shape[1] != ShapedType::kDynamic);

254 return static_cast<uint32_t>(getImpl()->shape[1]);

255 }

256

259 }

260

262

265 }

266

269 std::optional storage) {

270 llvm::cast(getElementType()).getExtensions(extensions, storage);

271 static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};

272 extensions.push_back(exts);

273 }

274

277 std::optional storage) {

279 .getCapabilities(capabilities, storage);

280 static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR};

281 capabilities.push_back(caps);

282 }

283

284

285

286

287

288 template

290 return 0;

291 }

292 template <>

294 static_assert((1 << 3) > getMaxEnumValForDim(),

295 "Not enough bits to encode Dim value");

296 return 3;

297 }

298 template <>

300 static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),

301 "Not enough bits to encode ImageDepthInfo value");

302 return 2;

303 }

304 template <>

306 static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),

307 "Not enough bits to encode ImageArrayedInfo value");

308 return 1;

309 }

310 template <>

312 static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),

313 "Not enough bits to encode ImageSamplingInfo value");

314 return 1;

315 }

316 template <>

318 static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),

319 "Not enough bits to encode ImageSamplerUseInfo value");

320 return 2;

321 }

322 template <>

324 static_assert((1 << 6) > getMaxEnumValForImageFormat(),

325 "Not enough bits to encode ImageFormat value");

326 return 6;

327 }

328

330 public:

331 using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,

332 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;

333

335 const KeyTy &key) {

337 }

338

340 return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,

341 samplerUseInfo, format);

342 }

343

345 : elementType(std::get<0>(key)), dim(std::get<1>(key)),

346 depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),

347 samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),

348 format(std::get<6>(key)) {}

349

357 };

358

361 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>

362 value) {

364 }

365

367

369

371

373 return getImpl()->arrayedInfo;

374 }

375

377 return getImpl()->samplingInfo;

378 }

379

381 return getImpl()->samplerUseInfo;

382 }

383

385

387 std::optional) {

388

389 }

390

393 std::optional) {

394 if (auto dimCaps = spirv::getCapabilities(getDim()))

395 capabilities.push_back(*dimCaps);

396

397 if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))

398 capabilities.push_back(*fmtCaps);

399 }

400

401

402

403

404

406

407

408 using KeyTy = std::pair<Type, StorageClass>;

409

411 const KeyTy &key) {

414 }

415

417 return key == KeyTy(pointeeType, storageClass);

418 }

419

421 : pointeeType(key.first), storageClass(key.second) {}

422

425 };

426

429 }

430

432

434 return getImpl()->storageClass;

435 }

436

438 std::optional storage) {

439

440

443

444 if (auto scExts = spirv::getExtensions(getStorageClass()))

445 extensions.push_back(*scExts);

446 }

447

450 std::optional storage) {

451

452

455

456 if (auto scCaps = spirv::getCapabilities(getStorageClass()))

457 capabilities.push_back(*scCaps);

458 }

459

460

461

462

463

465 using KeyTy = std::pair<Type, unsigned>;

466

468 const KeyTy &key) {

471 }

472

474 return key == KeyTy(elementType, stride);

475 }

476

478 : elementType(key.first), stride(key.second) {}

479

482 };

483

486 }

487

490 }

491

493

495

498 std::optional storage) {

499 llvm::cast(getElementType()).getExtensions(extensions, storage);

500 }

501

504 std::optional storage) {

505 {

506 static const Capability caps[] = {Capability::Shader};

508 capabilities.push_back(ref);

509 }

511 .getCapabilities(capabilities, storage);

512 }

513

514

515

516

517

519 if (auto floatType = llvm::dyn_cast(type)) {

520 return isValid(floatType);

521 }

522 if (auto intType = llvm::dyn_cast(type)) {

524 }

525 return false;

526 }

527

529 return llvm::is_contained({16u, 32u, 64u}, type.getWidth());

530 }

531

533 return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());

534 }

535

537 std::optional storage) {

538 if (isa(*this)) {

539 static const Extension ext = Extension::SPV_KHR_bfloat16;

540 extensions.push_back(ext);

541 }

542

543

544

545

546 if (!storage)

547 return;

548

549 switch (*storage) {

550 case StorageClass::PushConstant:

551 case StorageClass::StorageBuffer:

552 case StorageClass::Uniform:

554 static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};

556 extensions.push_back(ref);

557 }

558 [[fallthrough]];

559 case StorageClass::Input:

560 case StorageClass::Output:

562 static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};

564 extensions.push_back(ref);

565 }

566 break;

567 default:

568 break;

569 }

570 }

571

574 std::optional storage) {

576

577

578

579

580

581 #define STORAGE_CASE(storage, cap8, cap16) \

582 case StorageClass::storage: { \

583 if (bitwidth == 8) { \

584 static const Capability caps[] = {Capability::cap8}; \

585 ArrayRef ref(caps, std::size(caps)); \

586 capabilities.push_back(ref); \

587 return; \

588 } \

589 if (bitwidth == 16) { \

590 static const Capability caps[] = {Capability::cap16}; \

591 ArrayRef ref(caps, std::size(caps)); \

592 capabilities.push_back(ref); \

593 return; \

594 } \

595 \

596 \

597 } break

598

599

600

601 if (storage) {

602 switch (*storage) {

603 STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);

605 StorageBuffer16BitAccess);

607 StorageUniform16);

608 case StorageClass::Input:

609 case StorageClass::Output: {

610 if (bitwidth == 16) {

611 static const Capability caps[] = {Capability::StorageInputOutput16};

613 capabilities.push_back(ref);

614 return;

615 }

616 break;

617 }

618 default:

619 break;

620 }

621 }

622 #undef STORAGE_CASE

623

624

625

626

627 #define WIDTH_CASE(type, width) \

628 case width: { \

629 static const Capability caps[] = {Capability::type##width}; \

630 ArrayRef ref(caps, std::size(caps)); \

631 capabilities.push_back(ref); \

632 } break

633

634 if (auto intType = llvm::dyn_cast(*this)) {

635 switch (bitwidth) {

639 case 1:

640 case 32:

641 break;

642 default:

643 llvm_unreachable("invalid bitwidth to getCapabilities");

644 }

645 } else {

646 assert(llvm::isa(*this));

647 switch (bitwidth) {

648 case 16: {

649 if (isa(*this)) {

650 static const Capability cap = Capability::BFloat16TypeKHR;

651 capabilities.push_back(cap);

652 } else {

653 static const Capability cap = Capability::Float16;

654 capabilities.push_back(cap);

655 }

656 break;

657 }

659 case 32:

660 break;

661 default:

662 llvm_unreachable("invalid bitwidth to getCapabilities");

663 }

664 }

665

666 #undef WIDTH_CASE

667 }

668

671

672

673

674

675

676

677 if (bitWidth == 1)

678 return std::nullopt;

679 return bitWidth / 8;

680 }

681

682

683

684

685

687

688 if (llvm::isa(type.getDialect()))

689 return true;

690 if (llvm::isa(type))

691 return true;

692 if (auto vectorType = llvm::dyn_cast(type))

694 return false;

695 }

696

698 return isIntOrFloat() || llvm::isa(*this);

699 }

700

702 std::optional storage) {

703 if (auto scalarType = llvm::dyn_cast(*this)) {

704 scalarType.getExtensions(extensions, storage);

705 } else if (auto compositeType = llvm::dyn_cast(*this)) {

706 compositeType.getExtensions(extensions, storage);

707 } else if (auto imageType = llvm::dyn_cast(*this)) {

708 imageType.getExtensions(extensions, storage);

709 } else if (auto sampledImageType = llvm::dyn_cast(*this)) {

710 sampledImageType.getExtensions(extensions, storage);

711 } else if (auto matrixType = llvm::dyn_cast(*this)) {

712 matrixType.getExtensions(extensions, storage);

713 } else if (auto ptrType = llvm::dyn_cast(*this)) {

714 ptrType.getExtensions(extensions, storage);

715 } else {

716 llvm_unreachable("invalid SPIR-V Type to getExtensions");

717 }

718 }

719

722 std::optional storage) {

723 if (auto scalarType = llvm::dyn_cast(*this)) {

724 scalarType.getCapabilities(capabilities, storage);

725 } else if (auto compositeType = llvm::dyn_cast(*this)) {

726 compositeType.getCapabilities(capabilities, storage);

727 } else if (auto imageType = llvm::dyn_cast(*this)) {

728 imageType.getCapabilities(capabilities, storage);

729 } else if (auto sampledImageType = llvm::dyn_cast(*this)) {

730 sampledImageType.getCapabilities(capabilities, storage);

731 } else if (auto matrixType = llvm::dyn_cast(*this)) {

732 matrixType.getCapabilities(capabilities, storage);

733 } else if (auto ptrType = llvm::dyn_cast(*this)) {

734 ptrType.getCapabilities(capabilities, storage);

735 } else {

736 llvm_unreachable("invalid SPIR-V Type to getCapabilities");

737 }

738 }

739

741 if (auto scalarType = llvm::dyn_cast(*this))

742 return scalarType.getSizeInBytes();

743 if (auto compositeType = llvm::dyn_cast(*this))

744 return compositeType.getSizeInBytes();

745 return std::nullopt;

746 }

747

748

749

750

753

755

757

759 const KeyTy &key) {

762 }

763

765 };

766

769 }

770

773 Type imageType) {

775 }

776

778

779 LogicalResult

781 Type imageType) {

782 if (!llvm::isa(imageType))

783 return emitError() << "expected image type";

784

785 return success();

786 }

787

790 std::optional storage) {

791 llvm::cast(getImageType()).getExtensions(extensions, storage);

792 }

793

796 std::optional storage) {

797 llvm::cast(getImageType()).getCapabilities(capabilities, storage);

798 }

799

800

801

802

803

804

805

806

807

808

809

810

811

812

813

814

815

816

817

819

820

821

823 : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),

824 numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),

825 identifier(identifier) {}

826

827

828

830 unsigned numMembers, Type const *memberTypes,

833 : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),

834 numMembers(numMembers), numMemberDecorations(numMemberDecorations),

835 memberDecorationsInfo(memberDecorationsInfo) {}

836

837

838

839

840

841

842

843

844

845

846

847

848

849

850

854

855

856

857

858

859

861 if (isIdentified()) {

862

863 return getIdentifier() == std::get<0>(key);

864 }

865

866 return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),

867 getMemberDecorationsInfo());

868 }

869

870

871

872

873

874

875

877 const KeyTy &key) {

878 StringRef keyIdentifier = std::get<0>(key);

879

880 if (!keyIdentifier.empty()) {

881 StringRef identifier = allocator.copyInto(keyIdentifier);

882

883

884

887 }

888

890

891

892 const Type *typesList = nullptr;

893 if (!keyTypes.empty()) {

894 typesList = allocator.copyInto(keyTypes).data();

895 }

896

898 if (!std::get<2>(key).empty()) {

900 assert(keyOffsetInfo.size() == keyTypes.size() &&

901 "size of offset information must be same as the size of number of "

902 "elements");

903 offsetInfoList = allocator.copyInto(keyOffsetInfo).data();

904 }

905

907 unsigned numMemberDecorations = 0;

908 if (!std::get<3>(key).empty()) {

909 auto keyMemberDecorations = std::get<3>(key);

910 numMemberDecorations = keyMemberDecorations.size();

911 memberDecorationList = allocator.copyInto(keyMemberDecorations).data();

912 }

913

916 numMemberDecorations, memberDecorationList);

917 }

918

920 return ArrayRef(memberTypesAndIsBodySet.getPointer(), numMembers);

921 }

922

924 if (offsetInfo) {

926 }

927 return {};

928 }

929

931 if (memberDecorationsInfo) {

933 numMemberDecorations);

934 }

935 return {};

936 }

937

939

// Returns true for identified structs, i.e. storage created with a non-empty
// identifier; literal structs carry an empty identifier.
940 bool isIdentified() const { return !identifier.empty(); }

941

942

943

944

945

946

947

948

949

954 if (!isIdentified())

955 return failure();

956

957 if (memberTypesAndIsBodySet.getInt() &&

958 (getMemberTypes() != structMemberTypes ||

959 getOffsetInfo() != structOffsetInfo ||

960 getMemberDecorationsInfo() != structMemberDecorationInfo))

961 return failure();

962

963 memberTypesAndIsBodySet.setInt(true);

964 numMembers = structMemberTypes.size();

965

966

967 if (!structMemberTypes.empty())

968 memberTypesAndIsBodySet.setPointer(

969 allocator.copyInto(structMemberTypes).data());

970

971 if (!structOffsetInfo.empty()) {

972 assert(structOffsetInfo.size() == structMemberTypes.size() &&

973 "size of offset information must be same as the size of number of "

974 "elements");

975 offsetInfo = allocator.copyInto(structOffsetInfo).data();

976 }

977

978 if (!structMemberDecorationInfo.empty()) {

979 numMemberDecorations = structMemberDecorationInfo.size();

980 memberDecorationsInfo =

981 allocator.copyInto(structMemberDecorationInfo).data();

982 }

983

984 return success();

985 }

986

993 };

994

999 assert(!memberTypes.empty() && "Struct needs at least one member type");

1000

1002 memberDecorations);

1003 llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());

1004 return Base::get(memberTypes.vec().front().getContext(),

1005 StringRef(), memberTypes, offsetInfo,

1006 sortedDecorations);

1007 }

1008

1010 StringRef identifier) {

1011 assert(!identifier.empty() &&

1012 "StructType identifier must be non-empty string");

1013

1017 }

1018

1023

1029

1030 return newStructType;

1031 }

1032

1034

1036

1038

1040 assert(getNumElements() > index && "member index out of range");

1041 return getImpl()->memberTypesAndIsBodySet.getPointer()[index];

1042 }

1043

1045 return TypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),

1047 }

1048

1050

1052 assert(getNumElements() > index && "member index out of range");

1053 return getImpl()->offsetInfo[index];

1054 }

1055

1058 const {

1059 memberDecorations.clear();

1060 auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();

1061 memberDecorations.append(implMemberDecorations.begin(),

1062 implMemberDecorations.end());

1063 }

1064

1066 unsigned index,

1068 assert(getNumElements() > index && "member index out of range");

1069 auto memberDecorations = getImpl()->getMemberDecorationsInfo();

1070 decorationsInfo.clear();

1071 for (const auto &memberDecoration : memberDecorations) {

1072 if (memberDecoration.memberIndex == index) {

1073 decorationsInfo.push_back(memberDecoration);

1074 }

1075 if (memberDecoration.memberIndex > index) {

1076

1077 return;

1078 }

1079 }

1080 }

1081

1082 LogicalResult

1086 return Base::mutate(memberTypes, offsetInfo, memberDecorations);

1087 }

1088

1090 std::optional storage) {

1092 llvm::cast(elementType).getExtensions(extensions, storage);

1093 }

1094

1097 std::optional storage) {

1099 llvm::cast(elementType).getCapabilities(capabilities, storage);

1100 }

1101

1104 return llvm::hash_combine(memberDecorationInfo.memberIndex,

1106 }

1107

1108

1109

1110

1111

1114 : columnType(columnType), columnCount(columnCount) {}

1115

1116 using KeyTy = std::tuple<Type, uint32_t>;

1117

1119 const KeyTy &key) {

1120

1121

1124 }

1125

1127 return key == KeyTy(columnType, columnCount);

1128 }

1129

1132 };

1133

1136 }

1137

1139 Type columnType, uint32_t columnCount) {

1141 columnCount);

1142 }

1143

1144 LogicalResult

1146 Type columnType, uint32_t columnCount) {

1147 if (columnCount < 2 || columnCount > 4)

1148 return emitError() << "matrix can have 2, 3, or 4 columns only";

1149

1151 return emitError() << "matrix columns must be vectors of floats";

1152

1153

1154 ArrayRef<int64_t> columnShape = llvm::cast(columnType).getShape();

1155 if (columnShape.size() != 1)

1156 return emitError() << "matrix columns must be 1D vectors";

1157

1158 if (columnShape[0] < 2 || columnShape[0] > 4)

1159 return emitError() << "matrix columns must be of size 2, 3, or 4";

1160

1161 return success();

1162 }

1163

1164

1166 if (auto vectorType = llvm::dyn_cast(columnType)) {

1167 if (llvm::isa(vectorType.getElementType()))

1168 return true;

1169 }

1170 return false;

1171 }

1172

1174

1176 return llvm::cast(getImpl()->columnType).getElementType();

1177 }

1178

1180

1182 return llvm::cast(getImpl()->columnType).getShape()[0];

1183 }

1184

1187 }

1188

1190 std::optional storage) {

1191 llvm::cast(getColumnType()).getExtensions(extensions, storage);

1192 }

1193

1196 std::optional storage) {

1197 {

1198 static const Capability caps[] = {Capability::Matrix};

1200 capabilities.push_back(ref);

1201 }

1202

1203 llvm::cast(getColumnType()).getCapabilities(capabilities, storage);

1204 }

1205

1206

1207

1208

1209

1210 void SPIRVDialect::registerTypes() {

1213 }

static MLIRContext * getContext(OpFoldResult val)

constexpr unsigned getNumBits< ImageSamplerUseInfo >()

#define STORAGE_CASE(storage, cap8, cap16)

constexpr unsigned getNumBits< ImageFormat >()

static constexpr unsigned getNumBits()

#define WIDTH_CASE(type, width)

constexpr unsigned getNumBits< ImageArrayedInfo >()

constexpr unsigned getNumBits< ImageSamplingInfo >()

constexpr unsigned getNumBits< Dim >()

constexpr unsigned getNumBits< ImageDepthInfo >()

This class represents a diagnostic that is inflight and set to be reported.

MLIRContext is the top-level object for a collection of MLIR operations.

This is a utility allocator used to allocate memory for instances of derived types.

ArrayRef< T > copyInto(ArrayRef< T > elements)

Copy the specified array of elements into memory managed by our bump pointer allocator.

T * allocate()

Allocate an instance of the provided type.

This class provides an abstraction over the various different ranges of value types.

Base storage class appearing in a Type.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

Dialect & getDialect() const

Get the dialect this type is registered to.

MLIRContext * getContext() const

Return the MLIRContext in which this type was uniqued.

bool isIntOrFloat() const

Return true if this is an integer (of any signedness) or a float type.

unsigned getIntOrFloatBitWidth() const

Return the bit width of an integer or a float type, assert failure on other types.

ImplType * getImpl() const

Utility for easy access to the storage instance.

Type getElementType() const

unsigned getArrayStride() const

Returns the array stride in bytes.

unsigned getNumElements() const

static ArrayType get(Type elementType, unsigned elementCount)

std::optional< int64_t > getSizeInBytes()

Returns the array size in bytes.

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)

void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)

void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)

std::optional< int64_t > getSizeInBytes()

bool hasCompileTimeKnownNumElements() const

Return true if the number of elements is known at compile time and is not implementation dependent.

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)

unsigned getNumElements() const

Return the number of elements of the type.

static bool isValid(VectorType)

Returns true if the given vector type is valid for the SPIR-V dialect.

Type getElementType(unsigned) const

static bool classof(Type type)

Scope getScope() const

Returns the scope of the matrix.

uint32_t getRows() const

Returns the number of rows of the matrix.

uint32_t getColumns() const

Returns the number of columns of the matrix.

static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)

ArrayRef< int64_t > getShape() const

Type getElementType() const

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)

void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)

CooperativeMatrixUseKHR getUse() const

Returns the use parameter of the cooperative matrix.

static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)

ImageDepthInfo getDepthInfo() const

ImageArrayedInfo getArrayedInfo() const

ImageFormat getImageFormat() const

ImageSamplerUseInfo getSamplerUseInfo() const

Type getElementType() const

ImageSamplingInfo getSamplingInfo() const

void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)

static MatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)

unsigned getNumElements() const

Returns total number of elements (rows*columns).

static MatrixType get(Type columnType, uint32_t columnCount)

Type getColumnType() const

static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type columnType, uint32_t columnCount)

unsigned getNumColumns() const

Returns the number of columns.

static bool isValidColumnType(Type columnType)

Returns true if the matrix elements are vectors of float elements.

Type getElementType() const

Returns the type of the matrix's individual elements (i.e., the element type of its column vectors).

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)

void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)

unsigned getNumRows() const

Returns the number of rows.

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)

Type getPointeeType() const

StorageClass getStorageClass() const

void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)

static PointerType get(Type pointeeType, StorageClass storageClass)

Type getElementType() const

void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)

unsigned getArrayStride() const

Returns the array stride in bytes.

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)

static RuntimeArrayType get(Type elementType)

std::optional< int64_t > getSizeInBytes()

Returns the size in bytes for each type.

static bool classof(Type type)

void getCapabilities(CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)

Appends to capabilities the capabilities needed for this type to appear in the given storage class.

void getExtensions(ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)

Appends to extensions the extensions needed for this type to appear in the given storage class.

static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type imageType)

void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< spirv::StorageClass > storage=std::nullopt)

static SampledImageType getChecked(function_ref< InFlightDiagnostic()> emitError, Type imageType)

Type getImageType() const

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< spirv::StorageClass > storage=std::nullopt)

static SampledImageType get(Type imageType)

static bool classof(Type type)

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)

static bool isValid(FloatType)

Returns true if the given float type is valid for the SPIR-V dialect.

void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)

std::optional< int64_t > getSizeInBytes()

void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const

void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)

static StructType getIdentified(MLIRContext *context, StringRef identifier)

Construct an identified StructType.

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)

bool isIdentified() const

Returns true if the StructType is identified.

StringRef getIdentifier() const

Returns the struct's identifier for identified structs; for literal structs, returns an empty string.

static StructType getEmpty(MLIRContext *context, StringRef identifier="")

Construct a (possibly identified) StructType with no members.

unsigned getNumElements() const

Type getElementType(unsigned) const

LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})

Sets the contents of an incomplete identified StructType.

TypeRange getElementTypes() const

static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})

Construct a literal StructType with at least one member.

uint64_t getMemberOffset(unsigned) const

llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)

Include the generated interface declarations.

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.

ArrayTypeStorage(const KeyTy &key)

std::tuple< Type, unsigned, unsigned > KeyTy

static ArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)

bool operator==(const KeyTy &key) const

CooperativeMatrixTypeStorage(const KeyTy &key)

std::tuple< Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR > KeyTy

static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)

CooperativeMatrixUseKHR use

std::array< int64_t, 2 > shape

bool operator==(const KeyTy &key) const

bool operator==(const KeyTy &key) const

ImageSamplerUseInfo samplerUseInfo

static ImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)

ImageTypeStorage(const KeyTy &key)

std::tuple< Type, Dim, ImageDepthInfo, ImageArrayedInfo, ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat > KeyTy

ImageSamplingInfo samplingInfo

ImageArrayedInfo arrayedInfo

const uint32_t columnCount

MatrixTypeStorage(Type columnType, uint32_t columnCount)

bool operator==(const KeyTy &key) const

std::tuple< Type, uint32_t > KeyTy

static MatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)

StorageClass storageClass

PointerTypeStorage(const KeyTy &key)

bool operator==(const KeyTy &key) const

std::pair< Type, StorageClass > KeyTy

static PointerTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)

RuntimeArrayTypeStorage(const KeyTy &key)

std::pair< Type, unsigned > KeyTy

static RuntimeArrayTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)

bool operator==(const KeyTy &key) const

bool operator==(const KeyTy &key) const

SampledImageTypeStorage(const KeyTy &key)

static SampledImageTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)

Type storage for SPIR-V structure types:

ArrayRef< StructType::OffsetInfo > getOffsetInfo() const

StructTypeStorage(unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo)

Construct a storage object for a literal struct type.

StructType::OffsetInfo const * offsetInfo

bool operator==(const KeyTy &key) const

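For identified structs, return true if the given key contains the same identifier. For literal structs, return true if the key's member types, offset info, and member decoration info all match.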

LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef< Type > structMemberTypes, ArrayRef< StructType::OffsetInfo > structOffsetInfo, ArrayRef< StructType::MemberDecorationInfo > structMemberDecorationInfo)

Sets the struct type content for identified structs.

static StructTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)

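If the given key contains a non-empty identifier, this method constructs an identified struct and leaves the rest of the struct type data (member types, offsets, decorations) to be set later through StructType::trySetBody(). Otherwise, it constructs a literal struct from the member types, offset info, and member decoration info in the key.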

ArrayRef< Type > getMemberTypes() const

StructTypeStorage(StringRef identifier)

Construct a storage object for an identified struct type.

ArrayRef< StructType::MemberDecorationInfo > getMemberDecorationsInfo() const

std::tuple< StringRef, ArrayRef< Type >, ArrayRef< StructType::OffsetInfo >, ArrayRef< StructType::MemberDecorationInfo > > KeyTy

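A storage key is divided into 2 parts: for identified structs, only the identifier (a non-empty StringRef) is significant; for literal structs (empty identifier), the key consists of the member types, the member offset info, and the member decoration info.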

StructType::MemberDecorationInfo const * memberDecorationsInfo

llvm::PointerIntPair< Type const *, 1, bool > memberTypesAndIsBodySet

StringRef getIdentifier() const

unsigned numMemberDecorations

bool isIdentified() const