MLIR: lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

14 #include "llvm/ADT/TypeSwitch.h"

15 #include "llvm/Support/InterleavedRange.h"

16

17 using namespace mlir;

19

20

21

22

23

24 namespace mlir {

25 namespace spirv {

26 #include "mlir/Dialect/SPIRV/IR/SPIRVAttrUtils.inc"

27 }

28

29

30

31

32

33 namespace spirv {

34 namespace detail {

35

37 using KeyTy = std::tuple<Attribute, Attribute, Attribute>;

38

41 : descriptorSet(descriptorSet), binding(binding),

42 storageClass(storageClass) {}

43

45 return std::get<0>(key) == descriptorSet && std::get<1>(key) == binding &&

46 std::get<2>(key) == storageClass;

47 }

48

53 std::get<2>(key));

54 }

55

59 };

60

62 using KeyTy = std::tuple<Attribute, Attribute, Attribute>;

63

66 : version(version), capabilities(capabilities), extensions(extensions) {}

67

69 return std::get<0>(key) == version && std::get<1>(key) == capabilities &&

70 std::get<2>(key) == extensions;

71 }

72

77 std::get<2>(key));

78 }

79

83 };

84

87 std::tuple<Attribute, ClientAPI, Vendor, DeviceType, uint32_t, Attribute>;

88

90 Vendor vendorID, DeviceType deviceType,

91 uint32_t deviceID, Attribute limits)

92 : triple(triple), limits(limits), clientAPI(clientAPI),

93 vendorID(vendorID), deviceType(deviceType), deviceID(deviceID) {}

94

96 return key == std::make_tuple(triple, clientAPI, vendorID, deviceType,

97 deviceID, limits);

98 }

99

104 std::get<2>(key), std::get<3>(key),

105 std::get<4>(key), std::get<5>(key));

106 }

107

114 };

115 }

116 }

117 }

118

119

120

121

122

125 std::optionalspirv::StorageClass storageClass,

130 auto storageClassAttr =

131 storageClass ? b.getI32IntegerAttr(static_cast<uint32_t>(*storageClass))

132 : IntegerAttr();

133 return get(descriptorSetAttr, bindingAttr, storageClassAttr);

134 }

135

138 IntegerAttr storageClass) {

139 assert(descriptorSet && binding);

140 MLIRContext *context = descriptorSet.getContext();

141 return Base::get(context, descriptorSet, binding, storageClass);

142 }

143

145 return "interface_var_abi";

146 }

147

149 return llvm::cast(getImpl()->binding).getInt();

150 }

151

153 return llvm::cast(getImpl()->descriptorSet).getInt();

154 }

155

156 std::optionalspirv::StorageClass

158 if (getImpl()->storageClass)

159 return static_castspirv::StorageClass\(

160 llvm::cast(getImpl()->storageClass)

161 .getValue()

162 .getZExtValue());

163 return std::nullopt;

164 }

165

168 IntegerAttr binding, IntegerAttr storageClass) {

169 if (!descriptorSet.getType().isSignlessInteger(32))

170 return emitError() << "expected 32-bit integer for descriptor set";

171

172 if (!binding.getType().isSignlessInteger(32))

173 return emitError() << "expected 32-bit integer for binding";

174

175 if (storageClass) {

176 if (auto storageClassAttr = llvm::cast(storageClass)) {

177 auto storageClassValue =

178 spirv::symbolizeStorageClass(storageClassAttr.getInt());

179 if (!storageClassValue)

180 return emitError() << "unknown storage class";

181 } else {

182 return emitError() << "expected valid storage class";

183 }

184 }

185

186 return success();

187 }

188

189

190

191

192

197

198 auto versionAttr = b.getI32IntegerAttr(static_cast<uint32_t>(version));

199

201 capAttrs.reserve(capabilities.size());

202 for (spirv::Capability cap : capabilities)

203 capAttrs.push_back(b.getI32IntegerAttr(static_cast<uint32_t>(cap)));

204

206 extAttrs.reserve(extensions.size());

207 for (spirv::Extension ext : extensions)

208 extAttrs.push_back(b.getStringAttr(spirv::stringifyExtension(ext)));

209

210 return get(versionAttr, b.getArrayAttr(capAttrs), b.getArrayAttr(extAttrs));

211 }

212

214 ArrayAttr capabilities,

215 ArrayAttr extensions) {

216 assert(version && capabilities && extensions);

217 MLIRContext *context = version.getContext();

218 return Base::get(context, version, capabilities, extensions);

219 }

220

222

224 return static_castspirv::Version\(

225 llvm::cast(getImpl()->version).getValue().getZExtValue());

226 }

227

229 : llvm::mapped_iterator<ArrayAttr::iterator,

230 spirv::Extension (*)(Attribute)>(

232 return *symbolizeExtension(llvm::cast(attr).getValue());

233 }) {}

234

238 }

239

241 return llvm::cast(getImpl()->extensions);

242 }

243

245 : llvm::mapped_iterator<ArrayAttr::iterator,

246 spirv::Capability (*)(Attribute)>(

248 return *symbolizeCapability(

249 llvm::cast(attr).getValue().getZExtValue());

250 }) {}

251

255 }

256

258 return llvm::cast(getImpl()->capabilities);

259 }

260

263 ArrayAttr capabilities, ArrayAttr extensions) {

264 if (!version.getType().isSignlessInteger(32))

265 return emitError() << "expected 32-bit integer for version";

266

267 if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) {

268 if (auto intAttr = llvm::dyn_cast(attr))

269 if (spirv::symbolizeCapability(intAttr.getValue().getZExtValue()))

270 return true;

271 return false;

272 }))

273 return emitError() << "unknown capability in capability list";

274

275 if (!llvm::all_of(extensions.getValue(), [](Attribute attr) {

276 if (auto strAttr = llvm::dyn_cast(attr))

277 if (spirv::symbolizeExtension(strAttr.getValue()))

278 return true;

279 return false;

280 }))

281 return emitError() << "unknown extension in extension list";

282

283 return success();

284 }

285

286

287

288

289

292 Vendor vendorID, DeviceType deviceType, uint32_t deviceID) {

293 assert(triple && limits && "expected valid triple and limits");

294 MLIRContext *context = triple.getContext();

295 return Base::get(context, triple, clientAPI, vendorID, deviceType, deviceID,

296 limits);

297 }

298

300

302 return llvm::castspirv::VerCapExtAttr(getImpl()->triple);

303 }

304

306 return getTripleAttr().getVersion();

307 }

308

310 return getTripleAttr().getExtensions();

311 }

312

314 return getTripleAttr().getExtensionsAttr();

315 }

316

318 return getTripleAttr().getCapabilities();

319 }

320

322 return getTripleAttr().getCapabilitiesAttr();

323 }

324

326 return getImpl()->clientAPI;

327 }

328

330 return getImpl()->vendorID;

331 }

332

334 return getImpl()->deviceType;

335 }

336

338 return getImpl()->deviceID;

339 }

340

342 return llvm::castspirv::ResourceLimitsAttr(getImpl()->limits);

343 }

344

345

346

347

348

349 #define GET_ATTRDEF_CLASSES

350 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.cpp.inc"

351

352

353

354

355

356

357

358 static ParseResult

360 function_ref<LogicalResult(SMLoc, StringRef)> processKeyword) {

362 return failure();

363

364

366 return success();

367

368

369

371 auto loc = parser.getCurrentLocation();

372 StringRef keyword;

373 if (parser.parseKeyword(&keyword) ||

374 failed(processKeyword(loc, keyword)))

375 return failure();

376 return success();

377 })))

378 return failure();

380 }

381

382

385 return {};

386

388

390 return {};

391

392 IntegerAttr descriptorSetAttr;

393 {

395 uint32_t descriptorSet = 0;

397

398 if (!descriptorSetParseResult.has_value() ||

399 failed(*descriptorSetParseResult)) {

400 parser.emitError(loc, "missing descriptor set");

401 return {};

402 }

404 }

405

407 return {};

408

409 IntegerAttr bindingAttr;

410 {

412 uint32_t binding = 0;

414

415 if (!bindingParseResult.has_value() || failed(*bindingParseResult)) {

416 parser.emitError(loc, "missing binding");

417 return {};

418 }

420 }

421

423 return {};

424

425 IntegerAttr storageClassAttr;

426 {

429 StringRef storageClass;

431 return {};

432

433 if (auto storageClassSymbol =

434 spirv::symbolizeStorageClass(storageClass)) {

436 static_cast<uint32_t>(*storageClassSymbol));

437 } else {

438 parser.emitError(loc, "unknown storage class: ") << storageClass;

439 return {};

440 }

441 }

442 }

443

445 return {};

446

448 storageClassAttr);

449 }

450

453 return {};

454

456

457 IntegerAttr versionAttr;

458 {

460 StringRef version;

462 return {};

463

464 if (auto versionSymbol = spirv::symbolizeVersion(version)) {

465 versionAttr =

466 builder.getI32IntegerAttr(static_cast<uint32_t>(*versionSymbol));

467 } else {

468 parser.emitError(loc, "unknown version: ") << version;

469 return {};

470 }

471 }

472

473 ArrayAttr capabilitiesAttr;

474 {

476 SMLoc errorloc;

477 StringRef errorKeyword;

478

479 auto processCapability = [&](SMLoc loc, StringRef capability) {

480 if (auto capSymbol = spirv::symbolizeCapability(capability)) {

481 capabilities.push_back(

483 return success();

484 }

485 return errorloc = loc, errorKeyword = capability, failure();

486 };

488 if (!errorKeyword.empty())

489 parser.emitError(errorloc, "unknown capability: ") << errorKeyword;

490 return {};

491 }

492

493 capabilitiesAttr = builder.getArrayAttr(capabilities);

494 }

495

496 ArrayAttr extensionsAttr;

497 {

499 SMLoc errorloc;

500 StringRef errorKeyword;

501

502 auto processExtension = [&](SMLoc loc, StringRef extension) {

503 if (spirv::symbolizeExtension(extension)) {

504 extensions.push_back(builder.getStringAttr(extension));

505 return success();

506 }

507 return errorloc = loc, errorKeyword = extension, failure();

508 };

510 if (!errorKeyword.empty())

511 parser.emitError(errorloc, "unknown extension: ") << errorKeyword;

512 return {};

513 }

514

515 extensionsAttr = builder.getArrayAttr(extensions);

516 }

517

519 return {};

520

522 extensionsAttr);

523 }

524

525

528 return {};

529

532 return {};

533

534 auto clientAPI = spirv::ClientAPI::Unknown;

537 return {};

539 StringRef apiStr;

541 return {};

542 if (auto apiSymbol = spirv::symbolizeClientAPI(apiStr))

543 clientAPI = *apiSymbol;

544 else

545 parser.emitError(loc, "unknown client API: ") << apiStr;

547 return {};

548 }

549

550

551 Vendor vendorID = Vendor::Unknown;

552 DeviceType deviceType = DeviceType::Unknown;

554 {

556 StringRef vendorStr;

558 if (auto vendorSymbol = spirv::symbolizeVendor(vendorStr))

559 vendorID = *vendorSymbol;

560 else

561 parser.emitError(loc, "unknown vendor: ") << vendorStr;

562

565 StringRef deviceTypeStr;

567 return {};

568 if (auto deviceTypeSymbol = spirv::symbolizeDeviceType(deviceTypeStr))

569 deviceType = *deviceTypeSymbol;

570 else

571 parser.emitError(loc, "unknown device type: ") << deviceTypeStr;

572

576 return {};

577 }

578 }

580 return {};

581 }

582 }

583

584 ResourceLimitsAttr limitsAttr;

586 return {};

587

589 deviceType, deviceID);

590 }

591

593 Type type) const {

594

595 if (type) {

597 return {};

598 }

599

600

601 StringRef attrKind;

604 generatedAttributeParser(parser, &attrKind, type, attr);

606 return attr;

607

614

616 << attrKind;

617 return {};

618 }

619

620

621

622

623

626 << spirv::stringifyVersion(triple.getVersion()) << ", "

627 << llvm::interleaved_array(llvm::map_range(

629 << ", "

630 << llvm::interleaved_array(

632 << ">";

633 }

634

639 if (clientAPI != spirv::ClientAPI::Unknown)

640 printer << ", api=" << clientAPI;

641 spirv::Vendor vendorID = targetEnv.getVendorID();

642 spirv::DeviceType deviceType = targetEnv.getDeviceType();

643 uint32_t deviceID = targetEnv.getDeviceID();

644 if (vendorID != spirv::Vendor::Unknown) {

645 printer << ", " << spirv::stringifyVendor(vendorID);

646 if (deviceType != spirv::DeviceType::Unknown) {

647 printer << ":" << spirv::stringifyDeviceType(deviceType);

649 printer << ":" << deviceID;

650 }

651 }

653 }

654

659 << interfaceVarABIAttr.getBinding() << ")";

660 auto storageClass = interfaceVarABIAttr.getStorageClass();

661 if (storageClass)

662 printer << ", " << spirv::stringifyStorageClass(*storageClass);

663 printer << ">";

664 }

665

666 void SPIRVDialect::printAttribute(Attribute attr,

668 if (succeeded(generatedAttributePrinter(attr, printer)))

669 return;

670

671 if (auto targetEnv = llvm::dyn_cast(attr))

672 print(targetEnv, printer);

673 else if (auto vceAttr = llvm::dyn_cast(attr))

674 print(vceAttr, printer);

675 else if (auto interfaceVarABIAttr = llvm::dyn_cast(attr))

676 print(interfaceVarABIAttr, printer);

677 else

678 llvm_unreachable("unhandled SPIR-V attribute kind");

679 }

680

681

682

683

684

// Registers all SPIR-V dialect attributes with the dialect: first the three
// attribute classes whose storage and accessors are hand-written in this file,
// then the attribute classes listed by the generated SPIRVAttributes.cpp.inc.
685 void spirv::SPIRVDialect::registerAttributes() {

// Hand-written attributes (storage structs live in spirv::detail above).
686 addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();

687 addAttributes<

// Defining GET_ATTRDEF_LIST before including the generated file makes it
// expand to the list of generated attribute classes, which fills this
// template argument list. Keep the #define immediately before the #include.
688 #define GET_ATTRDEF_LIST

689 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.cpp.inc"

690 >();

691 }

static Attribute parseTargetEnvAttr(DialectAsmParser &parser)

Parses a spirv::TargetEnvAttr.

static ParseResult parseKeywordList(DialectAsmParser &parser, function_ref< LogicalResult(SMLoc, StringRef)> processKeyword)

Parses a comma-separated list of keywords, invokes processKeyword on each of the parsed keywords,...

static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)

static Attribute parseVerCapExtAttr(DialectAsmParser &parser)

static Attribute parseInterfaceVarABIAttr(DialectAsmParser &parser)

Parses a spirv::InterfaceVarABIAttr.

virtual OptionalParseResult parseOptionalInteger(APInt &result)=0

Parse an optional integer value from the stream.

virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0

Parse a list of comma-separated items with an optional delimiter.

virtual Builder & getBuilder() const =0

Return a builder which provides useful access to MLIRContext, global objects like types and attribute...

virtual ParseResult parseOptionalKeyword(StringRef keyword)=0

Parse the given keyword if present.

virtual ParseResult parseRParen()=0

Parse a ) token.

virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0

Emit a diagnostic at the specified location and return failure.

virtual ParseResult parseOptionalColon()=0

Parse a : token if present.

virtual ParseResult parseLSquare()=0

Parse a [ token.

virtual ParseResult parseRSquare()=0

Parse a ] token.

ParseResult parseInteger(IntT &result)

Parse an integer value from the stream.

virtual ParseResult parseLess()=0

Parse a '<' token.

virtual ParseResult parseEqual()=0

Parse a = token.

virtual SMLoc getCurrentLocation()=0

Return the location of the next token.

virtual ParseResult parseOptionalComma()=0

Parse a , token if present.

virtual SMLoc getNameLoc() const =0

Return the location of the original name token.

virtual ParseResult parseOptionalRSquare()=0

Parse a ] token if present.

virtual ParseResult parseGreater()=0

Parse a '>' token.

virtual ParseResult parseLParen()=0

Parse a ( token.

virtual ParseResult parseComma()=0

Parse a , token.

ParseResult parseKeyword(StringRef keyword)

Parse a given keyword.

virtual ParseResult parseAttribute(Attribute &result, Type type={})=0

Parse an arbitrary attribute of a given type and return it in result.

Base storage class appearing in an attribute.

Attributes are known-constant values of operations.

This class is a general helper class for creating context-global objects like types,...

IntegerAttr getI32IntegerAttr(int32_t value)

StringAttr getStringAttr(const Twine &bytes)

ArrayAttr getArrayAttr(ArrayRef< Attribute > value)

The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...

This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...

This class represents a diagnostic that is inflight and set to be reported.

MLIRContext is the top-level object for a collection of MLIR operations.

This class implements Optional functionality for ParseResult.

bool has_value() const

Returns true if we contain a valid ParseResult value.

This is a utility allocator used to allocate memory for instances of derived types.

T * allocate()

Allocate an instance of the provided type.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

ImplType * getImpl() const

Utility for easy access to the storage instance.

An attribute that specifies the information regarding the interface variable: descriptor set,...

uint32_t getBinding()

Returns binding.

static StringRef getKindName()

Returns the attribute kind's name (without the 'spirv.' prefix).

uint32_t getDescriptorSet()

Returns descriptor set.

static InterfaceVarABIAttr get(uint32_t descriptorSet, uint32_t binding, std::optional< StorageClass > storageClass, MLIRContext *context)

Gets an InterfaceVarABIAttr.

static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, IntegerAttr descriptorSet, IntegerAttr binding, IntegerAttr storageClass)

std::optional< StorageClass > getStorageClass()

Returns spirv::StorageClass.

An attribute that specifies the target version, allowed extensions and capabilities,...

Version getVersion() const

Returns the target version.

VerCapExtAttr::cap_range getCapabilities()

Returns the target capabilities.

ResourceLimitsAttr getResourceLimits() const

Returns the target resource limits.

static StringRef getKindName()

Returns the attribute kind's name (without the 'spirv.' prefix).

VerCapExtAttr getTripleAttr() const

Returns the (version, capabilities, extensions) triple attribute.

ArrayAttr getCapabilitiesAttr()

Returns the target capabilities as an integer array attribute.

VerCapExtAttr::ext_range getExtensions()

Returns the target extensions.

Vendor getVendorID() const

Returns the vendor ID.

DeviceType getDeviceType() const

Returns the device type.

ClientAPI getClientAPI() const

Returns the client API.

ArrayAttr getExtensionsAttr()

Returns the target extensions as a string array attribute.

uint32_t getDeviceID() const

Returns the device ID.

static constexpr uint32_t kUnknownDeviceID

ID for unknown devices.

static TargetEnvAttr get(VerCapExtAttr triple, ResourceLimitsAttr limits, ClientAPI clientAPI=ClientAPI::Unknown, Vendor vendorID=Vendor::Unknown, DeviceType deviceType=DeviceType::Unknown, uint32_t deviceId=kUnknownDeviceID)

Gets a TargetEnvAttr instance.

An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.

cap_range getCapabilities()

Returns the capabilities.

static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, IntegerAttr version, ArrayAttr capabilities, ArrayAttr extensions)

Version getVersion()

Returns the version.

static StringRef getKindName()

Returns the attribute kind's name (without the 'spirv.' prefix).

static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)

Gets a VerCapExtAttr instance.

ArrayAttr getCapabilitiesAttr()

Returns the capabilities as an integer array attribute.

ext_range getExtensions()

Returns the extensions.

ArrayAttr getExtensionsAttr()

Returns the extensions as a string array attribute.

The OpAsmOpInterface, see OpAsmInterface.td for more details.

Include the generated interface declarations.

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.

Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)

This parses a single MLIR attribute to an MLIR context if it was valid.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed; this helps unify some of the attribute construction...

cap_iterator(ArrayAttr::iterator it)

ext_iterator(ArrayAttr::iterator it)

InterfaceVarABIAttributeStorage(Attribute descriptorSet, Attribute binding, Attribute storageClass)

std::tuple< Attribute, Attribute, Attribute > KeyTy

bool operator==(const KeyTy &key) const

static InterfaceVarABIAttributeStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)

static TargetEnvAttributeStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)

TargetEnvAttributeStorage(Attribute triple, ClientAPI clientAPI, Vendor vendorID, DeviceType deviceType, uint32_t deviceID, Attribute limits)

bool operator==(const KeyTy &key) const

std::tuple< Attribute, ClientAPI, Vendor, DeviceType, uint32_t, Attribute > KeyTy

static VerCapExtAttributeStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)

std::tuple< Attribute, Attribute, Attribute > KeyTy

VerCapExtAttributeStorage(Attribute version, Attribute capabilities, Attribute extensions)

bool operator==(const KeyTy &key) const