MLIR: lib/Target/SPIRV/Deserialization/Deserializer.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

14

23 #include "llvm/ADT/STLExtras.h"

24 #include "llvm/ADT/Sequence.h"

25 #include "llvm/ADT/SmallVector.h"

26 #include "llvm/ADT/StringExtras.h"

27 #include "llvm/ADT/bit.h"

28 #include "llvm/Support/Debug.h"

29 #include "llvm/Support/SaveAndRestore.h"

30 #include "llvm/Support/raw_ostream.h"

31 #include

32

33 using namespace mlir;

34

35 #define DEBUG_TYPE "spirv-deserialization"

36

37

38

39

40

41

44 isa_and_nonnullspirv::FuncOp(block->getParentOp());

45 }

46

47

48

49

50

54 : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),

55 module(createModuleOp()), opBuilder(module->getRegion()), options(options)

56 #ifndef NDEBUG

57 ,

58 logger(llvm::dbgs())

59 #endif

60 {

61 }

62

64 LLVM_DEBUG({

65 logger.resetIndent();

66 logger.startLine()

67 << "//+++---------- start deserialization ----------+++//\n";

68 });

69

70 if (failed(processHeader()))

71 return failure();

72

73 spirv::Opcode opcode = spirv::Opcode::OpNop;

75 auto binarySize = binary.size();

76 while (curOffset < binarySize) {

77

78

79 if (failed(sliceInstruction(opcode, operands)))

80 return failure();

81

82 if (failed(processInstruction(opcode, operands)))

83 return failure();

84 }

85

86 assert(curOffset == binarySize &&

87 "deserializer should never index beyond the binary end");

88

89 for (auto &deferred : deferredInstructions) {

90 if (failed(processInstruction(deferred.first, deferred.second, false))) {

91 return failure();

92 }

93 }

94

95 attachVCETriple();

96

97 LLVM_DEBUG(logger.startLine()

98 << "//+++-------- completed deserialization --------+++//\n");

99 return success();

100 }

101

103 return std::move(module);

104 }

105

106

107

108

109

112 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());

113 spirv::ModuleOp::build(builder, state);

115 }

116

117 LogicalResult spirv::Deserializer::processHeader() {

120 "SPIR-V binary module must have a 5-word header");

121

123 return emitError(unknownLoc, "incorrect magic number");

124

125

126 uint32_t majorVersion = (binary[1] << 8) >> 24;

127 uint32_t minorVersion = (binary[1] << 16) >> 24;

128 if (majorVersion == 1) {

129 switch (minorVersion) {

130 #define MIN_VERSION_CASE(v) \

131 case v: \

132 version = spirv::Version::V_1_##v; \

133 break

134

141 #undef MIN_VERSION_CASE

142 default:

143 return emitError(unknownLoc, "unsupported SPIR-V minor version: ")

144 << minorVersion;

145 }

146 } else {

147 return emitError(unknownLoc, "unsupported SPIR-V major version: ")

148 << majorVersion;

149 }

150

151

153 return success();

154 }

155

156 LogicalResult

157 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {

158 if (operands.size() != 1)

159 return emitError(unknownLoc, "OpCapability must have one parameter");

160

161 auto cap = spirv::symbolizeCapability(operands[0]);

162 if (!cap)

163 return emitError(unknownLoc, "unknown capability: ") << operands[0];

164

165 capabilities.insert(*cap);

166 return success();

167 }

168

169 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {

170 if (words.empty()) {

172 unknownLoc,

173 "OpExtension must have a literal string for the extension name");

174 }

175

176 unsigned wordIndex = 0;

178 if (wordIndex != words.size())

180 "unexpected trailing words in OpExtension instruction");

181 auto ext = spirv::symbolizeExtension(extName);

182 if (!ext)

183 return emitError(unknownLoc, "unknown extension: ") << extName;

184

185 extensions.insert(*ext);

186 return success();

187 }

188

189 LogicalResult

190 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {

191 if (words.size() < 2) {

193 "OpExtInstImport must have a result and a literal "

194 "string for the extended instruction set name");

195 }

196

197 unsigned wordIndex = 1;

199 if (wordIndex != words.size()) {

201 "unexpected trailing words in OpExtInstImport");

202 }

203 return success();

204 }

205

206 void spirv::Deserializer::attachVCETriple() {

207 (*module)->setAttr(

208 spirv::ModuleOp::getVCETripleAttrName(),

210 extensions.getArrayRef(), context));

211 }

212

213 LogicalResult

214 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {

215 if (operands.size() != 2)

216 return emitError(unknownLoc, "OpMemoryModel must have two operands");

217

218 (*module)->setAttr(

219 module->getAddressingModelAttrName(),

220 opBuilder.getAttrspirv::AddressingModelAttr(

221 static_castspirv::AddressingModel\(operands.front())));

222

223 (*module)->setAttr(module->getMemoryModelAttrName(),

224 opBuilder.getAttrspirv::MemoryModelAttr(

225 static_castspirv::MemoryModel\(operands.back())));

226

227 return success();

228 }

229

230 template <typename AttrTy, typename EnumAttrTy, typename EnumTy>

234 StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {

235 if (words.size() != 4) {

236 return emitError(loc, "OpDecoration with ")

237 << decorationName << "needs a cache control integer literal and a "

238 << cacheControlKind << " cache control literal";

239 }

240 unsigned cacheLevel = words[2];

241 auto cacheControlAttr = static_cast<EnumTy>(words[3]);

242 auto value = opBuilder.getAttr(cacheLevel, cacheControlAttr);

244 if (auto attrList =

245 llvm::dyn_cast_or_null(decorations[words[0]].get(symbol)))

246 llvm::append_range(attrs, attrList);

247 attrs.push_back(value);

248 decorations[words[0]].set(symbol, opBuilder.getArrayAttr(attrs));

249 return success();

250 }

251

252 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {

253

254

255

256 if (words.size() < 2) {

258 unknownLoc, "OpDecorate must have at least result and Decoration");

259 }

260 auto decorationName =

261 stringifyDecoration(static_castspirv::Decoration\(words[1]));

262 if (decorationName.empty()) {

263 return emitError(unknownLoc, "invalid Decoration code : ") << words[1];

264 }

265 auto symbol = getSymbolDecoration(decorationName);

266 switch (static_castspirv::Decoration\(words[1])) {

267 case spirv::Decoration::FPFastMathMode:

268 if (words.size() != 3) {

269 return emitError(unknownLoc, "OpDecorate with ")

270 << decorationName << " needs a single integer literal";

271 }

272 decorations[words[0]].set(

274 static_cast<FPFastMathMode>(words[2])));

275 break;

276 case spirv::Decoration::FPRoundingMode:

277 if (words.size() != 3) {

278 return emitError(unknownLoc, "OpDecorate with ")

279 << decorationName << " needs a single integer literal";

280 }

281 decorations[words[0]].set(

283 static_cast<FPRoundingMode>(words[2])));

284 break;

285 case spirv::Decoration::DescriptorSet:

286 case spirv::Decoration::Binding:

287 if (words.size() != 3) {

288 return emitError(unknownLoc, "OpDecorate with ")

289 << decorationName << " needs a single integer literal";

290 }

291 decorations[words[0]].set(

292 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));

293 break;

294 case spirv::Decoration::BuiltIn:

295 if (words.size() != 3) {

296 return emitError(unknownLoc, "OpDecorate with ")

297 << decorationName << " needs a single integer literal";

298 }

299 decorations[words[0]].set(

300 symbol, opBuilder.getStringAttr(

301 stringifyBuiltIn(static_castspirv::BuiltIn\(words[2]))));

302 break;

303 case spirv::Decoration::ArrayStride:

304 if (words.size() != 3) {

305 return emitError(unknownLoc, "OpDecorate with ")

306 << decorationName << " needs a single integer literal";

307 }

308 typeDecorations[words[0]] = words[2];

309 break;

310 case spirv::Decoration::LinkageAttributes: {

311 if (words.size() < 4) {

312 return emitError(unknownLoc, "OpDecorate with ")

313 << decorationName

314 << " needs at least 1 string and 1 integer literal";

315 }

316

317

318

319

320

321

322 unsigned wordIndex = 2;

324 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(

325 static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));

326 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(

328 decorations[words[0]].set(symbol, llvm::dyn_cast(linkageAttr));

329 break;

330 }

331 case spirv::Decoration::Aliased:

332 case spirv::Decoration::AliasedPointer:

333 case spirv::Decoration::Block:

334 case spirv::Decoration::BufferBlock:

335 case spirv::Decoration::Flat:

336 case spirv::Decoration::NonReadable:

337 case spirv::Decoration::NonWritable:

338 case spirv::Decoration::NoPerspective:

339 case spirv::Decoration::NoSignedWrap:

340 case spirv::Decoration::NoUnsignedWrap:

341 case spirv::Decoration::RelaxedPrecision:

342 case spirv::Decoration::Restrict:

343 case spirv::Decoration::RestrictPointer:

344 case spirv::Decoration::NoContraction:

345 case spirv::Decoration::Constant:

346 if (words.size() != 2) {

347 return emitError(unknownLoc, "OpDecoration with ")

348 << decorationName << "needs a single target ";

349 }

350

351

352

353

354 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());

355 break;

356 case spirv::Decoration::Location:

357 case spirv::Decoration::SpecId:

358 if (words.size() != 3) {

359 return emitError(unknownLoc, "OpDecoration with ")

360 << decorationName << "needs a single integer literal";

361 }

362 decorations[words[0]].set(

363 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));

364 break;

365 case spirv::Decoration::CacheControlLoadINTEL: {

367 CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(

368 unknownLoc, opBuilder, decorations, words, symbol, decorationName,

369 "load");

370 if (failed(res))

371 return res;

372 break;

373 }

374 case spirv::Decoration::CacheControlStoreINTEL: {

376 CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(

377 unknownLoc, opBuilder, decorations, words, symbol, decorationName,

378 "store");

379 if (failed(res))

380 return res;

381 break;

382 }

383 default:

384 return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;

385 }

386 return success();

387 }

388

389 LogicalResult

390 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {

391

392 if (words.size() < 3) {

394 "OpMemberDecorate must have at least 3 operands");

395 }

396

397 auto decoration = static_castspirv::Decoration\(words[2]);

398 if (decoration == spirv::Decoration::Offset && words.size() != 4) {

400 " missing offset specification in OpMemberDecorate with "

401 "Offset decoration");

402 }

404 if (words.size() > 3) {

405 decorationOperands = words.slice(3);

406 }

407 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;

408 return success();

409 }

410

411 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {

412 if (words.size() < 3) {

413 return emitError(unknownLoc, "OpMemberName must have at least 3 operands");

414 }

415 unsigned wordIndex = 2;

417 if (wordIndex != words.size()) {

419 "unexpected trailing words in OpMemberName instruction");

420 }

421 memberNameMap[words[0]][words[1]] = name;

422 return success();

423 }

424

425 LogicalResult spirv::Deserializer::setFunctionArgAttrs(

427 if (!decorations.contains(argID)) {

429 return success();

430 }

431

432 spirv::DecorationAttr foundDecorationAttr;

434 for (auto decoration :

435 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,

436 spirv::Decoration::AliasedPointer,

437 spirv::Decoration::RestrictPointer}) {

438

439 if (decAttr.getName() !=

440 getSymbolDecoration(stringifyDecoration(decoration)))

441 continue;

442

443 if (foundDecorationAttr)

445 "more than one Aliased/Restrict decorations for "

446 "function argument with result ")

447 << argID;

448

450 break;

451 }

452

453 if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(

454 spirv::Decoration::RelaxedPrecision))) {

455

456

457

458

459 if (foundDecorationAttr)

460 return emitError(unknownLoc, "already found a decoration for function "

461 "argument with result ")

462 << argID;

463

465 context, spirv::Decoration::RelaxedPrecision);

466 }

467 }

468

469 if (!foundDecorationAttr)

470 return emitError(unknownLoc, "unimplemented decoration support for "

471 "function argument with result ")

472 << argID;

473

475 foundDecorationAttr);

477 return success();

478 }

479

480 LogicalResult

482 if (curFunction) {

483 return emitError(unknownLoc, "found function inside function");

484 }

485

486

487 if (operands.size() != 4) {

488 return emitError(unknownLoc, "OpFunction must have 4 parameters");

489 }

491 if (!resultType) {

492 return emitError(unknownLoc, "undefined result type from ")

493 << operands[0];

494 }

495

496 uint32_t fnID = operands[1];

497 if (funcMap.count(fnID)) {

498 return emitError(unknownLoc, "duplicate function definition/declaration");

499 }

500

501 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);

502 if (!fnControl) {

503 return emitError(unknownLoc, "unknown Function Control: ") << operands[2];

504 }

505

507 if (!fnType || !isa(fnType)) {

508 return emitError(unknownLoc, "unknown function type from ")

509 << operands[3];

510 }

511 auto functionType = cast(fnType);

512

513 if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||

514 (functionType.getNumResults() == 1 &&

515 functionType.getResult(0) != resultType)) {

516 return emitError(unknownLoc, "mismatch in function type ")

517 << functionType << " and return type " << resultType << " specified";

518 }

519

520 std::string fnName = getFunctionSymbol(fnID);

521 auto funcOp = opBuilder.createspirv::FuncOp(

522 unknownLoc, fnName, functionType, fnControl.value());

523

524 if (decorations.count(fnID)) {

525 for (auto attr : decorations[fnID].getAttrs()) {

526 funcOp->setAttr(attr.getName(), attr.getValue());

527 }

528 }

529 curFunction = funcMap[fnID] = funcOp;

530 auto *entryBlock = funcOp.addEntryBlock();

531 LLVM_DEBUG({

532 logger.startLine()

533 << "//===-------------------------------------------===//\n";

534 logger.startLine() << "[fn] name: " << fnName << "\n";

535 logger.startLine() << "[fn] type: " << fnType << "\n";

536 logger.startLine() << "[fn] ID: " << fnID << "\n";

537 logger.startLine() << "[fn] entry block: " << entryBlock << "\n";

538 logger.indent();

539 });

540

542 argAttrs.resize(functionType.getNumInputs());

543

544

545 if (functionType.getNumInputs()) {

546 for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {

547 auto argType = functionType.getInput(i);

548 spirv::Opcode opcode = spirv::Opcode::OpNop;

550 if (failed(sliceInstruction(opcode, operands,

551 spirv::Opcode::OpFunctionParameter))) {

552 return failure();

553 }

554 if (opcode != spirv::Opcode::OpFunctionParameter) {

556 unknownLoc,

557 "missing OpFunctionParameter instruction for argument ")

558 << i;

559 }

560 if (operands.size() != 2) {

562 unknownLoc,

563 "expected result type and result for OpFunctionParameter");

564 }

565 auto argDefinedType = getType(operands[0]);

566 if (!argDefinedType || argDefinedType != argType) {

568 "mismatch in argument type between function type "

569 "definition ")

570 << functionType << " and argument type definition "

571 << argDefinedType << " at argument " << i;

572 }

573 if (getValue(operands[1])) {

574 return emitError(unknownLoc, "duplicate definition of result ")

575 << operands[1];

576 }

577 if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {

578 return failure();

579 }

580

581 auto argValue = funcOp.getArgument(i);

582 valueMap[operands[1]] = argValue;

583 }

584 }

585

586 if (llvm::any_of(argAttrs, [](Attribute attr) {

587 auto argAttr = cast(attr);

588 return !argAttr.empty();

589 }))

590 funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));

591

592

593

594

595 auto linkageAttr = funcOp.getLinkageAttributes();

596 auto hasImportLinkage =

597 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==

598 spirv::LinkageType::Import);

599 if (hasImportLinkage)

600 funcOp.eraseBody();

601

602

603

605

606 spirv::Opcode opcode = spirv::Opcode::OpNop;

608

609

610

611

612

613

614 if (failed(sliceInstruction(opcode, instOperands,

615 spirv::Opcode::OpFunctionEnd))) {

616 return failure();

617 }

618 if (opcode == spirv::Opcode::OpFunctionEnd) {

619 return processFunctionEnd(instOperands);

620 }

621 if (opcode != spirv::Opcode::OpLabel) {

622 return emitError(unknownLoc, "a basic block must start with OpLabel");

623 }

624 if (instOperands.size() != 1) {

625 return emitError(unknownLoc, "OpLabel should only have result ");

626 }

627 blockMap[instOperands[0]] = entryBlock;

628 if (failed(processLabel(instOperands))) {

629 return failure();

630 }

631

632

633

634 while (succeeded(sliceInstruction(opcode, instOperands,

635 spirv::Opcode::OpFunctionEnd)) &&

636 opcode != spirv::Opcode::OpFunctionEnd) {

637 if (failed(processInstruction(opcode, instOperands))) {

638 return failure();

639 }

640 }

641 if (opcode != spirv::Opcode::OpFunctionEnd) {

642 return failure();

643 }

644

645 return processFunctionEnd(instOperands);

646 }

647

648 LogicalResult

649 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {

650

651 if (!operands.empty()) {

652 return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");

653 }

654

655

656

657

658 if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {

659 return failure();

660 }

661

662 curBlock = nullptr;

663 curFunction = std::nullopt;

664

665 LLVM_DEBUG({

666 logger.unindent();

667 logger.startLine()

668 << "//===-------------------------------------------===//\n";

669 });

670 return success();

671 }

672

673 std::optional<std::pair<Attribute, Type>>

674 spirv::Deserializer::getConstant(uint32_t id) {

675 auto constIt = constantMap.find(id);

676 if (constIt == constantMap.end())

677 return std::nullopt;

678 return constIt->getSecond();

679 }

680

681 std::optionalspirv::SpecConstOperationMaterializationInfo

682 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {

683 auto constIt = specConstOperationMap.find(id);

684 if (constIt == specConstOperationMap.end())

685 return std::nullopt;

686 return constIt->getSecond();

687 }

688

689 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {

690 auto funcName = nameMap.lookup(id).str();

691 if (funcName.empty()) {

692 funcName = "spirv_fn_" + std::to_string(id);

693 }

694 return funcName;

695 }

696

697 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {

698 auto constName = nameMap.lookup(id).str();

699 if (constName.empty()) {

700 constName = "spirv_spec_const_" + std::to_string(id);

701 }

702 return constName;

703 }

704

705 spirv::SpecConstantOp

706 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,

707 TypedAttr defaultValue) {

708 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));

709 auto op = opBuilder.createspirv::SpecConstantOp(unknownLoc, symName,

710 defaultValue);

711 if (decorations.count(resultID)) {

712 for (auto attr : decorations[resultID].getAttrs())

713 op->setAttr(attr.getName(), attr.getValue());

714 }

715 specConstMap[resultID] = op;

716 return op;

717 }

718

719 LogicalResult

720 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {

721 unsigned wordIndex = 0;

722 if (operands.size() < 3) {

724 unknownLoc,

725 "OpVariable needs at least 3 operands, type, and storage class");

726 }

727

728

729 auto type = getType(operands[wordIndex]);

730 if (!type) {

731 return emitError(unknownLoc, "unknown result type : ")

732 << operands[wordIndex];

733 }

734 auto ptrType = dyn_castspirv::PointerType(type);

735 if (!ptrType) {

737 "expected a result type to be a spirv.ptr, found : ")

738 << type;

739 }

740 wordIndex++;

741

742

743 auto variableID = operands[wordIndex];

744 auto variableName = nameMap.lookup(variableID).str();

745 if (variableName.empty()) {

746 variableName = "spirv_var_" + std::to_string(variableID);

747 }

748 wordIndex++;

749

750

751 auto storageClass = static_castspirv::StorageClass\(operands[wordIndex]);

752 if (ptrType.getStorageClass() != storageClass) {

753 return emitError(unknownLoc, "mismatch in storage class of pointer type ")

754 << type << " and that specified in OpVariable instruction : "

755 << stringifyStorageClass(storageClass);

756 }

757 wordIndex++;

758

759

761

762 if (wordIndex < operands.size()) {

764

765 if (auto initOp = getGlobalVariable(operands[wordIndex]))

766 op = initOp;

767 else if (auto initOp = getSpecConstant(operands[wordIndex]))

768 op = initOp;

769 else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))

770 op = initOp;

771 else

772 return emitError(unknownLoc, "unknown ")

773 << operands[wordIndex] << "used as initializer";

774

776 wordIndex++;

777 }

778 if (wordIndex != operands.size()) {

780 "found more operands than expected when deserializing "

781 "OpVariable instruction, only ")

782 << wordIndex << " of " << operands.size() << " processed";

783 }

784 auto loc = createFileLineColLoc(opBuilder);

785 auto varOp = opBuilder.createspirv::GlobalVariableOp(

786 loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName),

787 initializer);

788

789

790 if (decorations.count(variableID)) {

791 for (auto attr : decorations[variableID].getAttrs())

792 varOp->setAttr(attr.getName(), attr.getValue());

793 }

794 globalVariableMap[variableID] = varOp;

795 return success();

796 }

797

798 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {

799 auto constInfo = getConstant(id);

800 if (!constInfo) {

801 return nullptr;

802 }

803 return dyn_cast(constInfo->first);

804 }

805

806 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {

807 if (operands.size() < 2) {

808 return emitError(unknownLoc, "OpName needs at least 2 operands");

809 }

810 if (!nameMap.lookup(operands[0]).empty()) {

811 return emitError(unknownLoc, "duplicate name found for result ")

812 << operands[0];

813 }

814 unsigned wordIndex = 1;

816 if (wordIndex != operands.size()) {

818 "unexpected trailing words in OpName instruction");

819 }

820 nameMap[operands[0]] = name;

821 return success();

822 }

823

824

825

826

827

828 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,

830 if (operands.empty()) {

831 return emitError(unknownLoc, "type instruction with opcode ")

832 << spirv::stringifyOpcode(opcode) << " needs at least one ";

833 }

834

835

836

837 if (typeMap.count(operands[0])) {

838 return emitError(unknownLoc, "duplicate definition for result ")

839 << operands[0];

840 }

841

842 switch (opcode) {

843 case spirv::Opcode::OpTypeVoid:

844 if (operands.size() != 1)

845 return emitError(unknownLoc, "OpTypeVoid must have no parameters");

846 typeMap[operands[0]] = opBuilder.getNoneType();

847 break;

848 case spirv::Opcode::OpTypeBool:

849 if (operands.size() != 1)

850 return emitError(unknownLoc, "OpTypeBool must have no parameters");

851 typeMap[operands[0]] = opBuilder.getI1Type();

852 break;

853 case spirv::Opcode::OpTypeInt: {

854 if (operands.size() != 3)

856 unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");

857

858

859

860

861

862

863

864

866 : IntegerType::SignednessSemantics::Signless;

867 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);

868 } break;

869 case spirv::Opcode::OpTypeFloat: {

870 if (operands.size() != 2 && operands.size() != 3)

872 "OpTypeFloat expects either 2 operands (type, bitwidth) "

873 "or 3 operands (type, bitwidth, encoding), but got ")

874 << operands.size();

875 uint32_t bitWidth = operands[1];

876

877 Type floatTy;

878 switch (bitWidth) {

879 case 16:

880 floatTy = opBuilder.getF16Type();

881 break;

882 case 32:

883 floatTy = opBuilder.getF32Type();

884 break;

885 case 64:

886 floatTy = opBuilder.getF64Type();

887 break;

888 default:

889 return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")

890 << bitWidth;

891 }

892

893 if (operands.size() == 3) {

894 if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR)

895 return emitError(unknownLoc, "unsupported OpTypeFloat FP encoding: ")

896 << operands[2];

897 if (bitWidth != 16)

899 "invalid OpTypeFloat bitwidth for bfloat16 encoding: ")

900 << bitWidth << " (expected 16)";

901 floatTy = opBuilder.getBF16Type();

902 }

903

904 typeMap[operands[0]] = floatTy;

905 } break;

906 case spirv::Opcode::OpTypeVector: {

907 if (operands.size() != 3) {

909 unknownLoc,

910 "OpTypeVector must have element type and count parameters");

911 }

913 if (!elementTy) {

914 return emitError(unknownLoc, "OpTypeVector references undefined ")

915 << operands[1];

916 }

917 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);

918 } break;

919 case spirv::Opcode::OpTypePointer: {

920 return processOpTypePointer(operands);

921 } break;

922 case spirv::Opcode::OpTypeArray:

923 return processArrayType(operands);

924 case spirv::Opcode::OpTypeCooperativeMatrixKHR:

925 return processCooperativeMatrixTypeKHR(operands);

926 case spirv::Opcode::OpTypeFunction:

927 return processFunctionType(operands);

928 case spirv::Opcode::OpTypeImage:

929 return processImageType(operands);

930 case spirv::Opcode::OpTypeSampledImage:

931 return processSampledImageType(operands);

932 case spirv::Opcode::OpTypeRuntimeArray:

933 return processRuntimeArrayType(operands);

934 case spirv::Opcode::OpTypeStruct:

935 return processStructType(operands);

936 case spirv::Opcode::OpTypeMatrix:

937 return processMatrixType(operands);

938 default:

939 return emitError(unknownLoc, "unhandled type instruction");

940 }

941 return success();

942 }

943

944 LogicalResult

945 spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {

946 if (operands.size() != 3)

947 return emitError(unknownLoc, "OpTypePointer must have two parameters");

948

949 auto pointeeType = getType(operands[2]);

950 if (!pointeeType)

951 return emitError(unknownLoc, "unknown OpTypePointer pointee type ")

952 << operands[2];

953

954 uint32_t typePointerID = operands[0];

955 auto storageClass = static_castspirv::StorageClass\(operands[1]);

957

958 for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);

959 deferredStructIt != std::end(deferredStructTypesInfos);) {

960 for (auto *unresolvedMemberIt =

961 std::begin(deferredStructIt->unresolvedMemberTypes);

962 unresolvedMemberIt !=

963 std::end(deferredStructIt->unresolvedMemberTypes);) {

964 if (unresolvedMemberIt->first == typePointerID) {

965

966

967

968 deferredStructIt->memberTypes[unresolvedMemberIt->second] =

969 typeMap[typePointerID];

970 unresolvedMemberIt =

971 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);

972 } else {

973 ++unresolvedMemberIt;

974 }

975 }

976

977 if (deferredStructIt->unresolvedMemberTypes.empty()) {

978

979 auto structType = deferredStructIt->deferredStructType;

980

981 assert(structType && "expected a spirv::StructType");

982 assert(structType.isIdentified() && "expected an indentified struct");

983

984 if (failed(structType.trySetBody(

985 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,

986 deferredStructIt->memberDecorationsInfo)))

987 return failure();

988

989 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);

990 } else {

991 ++deferredStructIt;

992 }

993 }

994

995 return success();

996 }

997

998 LogicalResult

1000 if (operands.size() != 3) {

1002 "OpTypeArray must have element type and count parameters");

1003 }

1004

1006 if (!elementTy) {

1007 return emitError(unknownLoc, "OpTypeArray references undefined ")

1008 << operands[1];

1009 }

1010

1011 unsigned count = 0;

1012

1013 auto countInfo = getConstant(operands[2]);

1014 if (!countInfo) {

1015 return emitError(unknownLoc, "OpTypeArray count ")

1016 << operands[2] << "can only come from normal constant right now";

1017 }

1018

1019 if (auto intVal = dyn_cast(countInfo->first)) {

1020 count = intVal.getValue().getZExtValue();

1021 } else {

1022 return emitError(unknownLoc, "OpTypeArray count must come from a "

1023 "scalar integer constant instruction");

1024 }

1025

1027 elementTy, count, typeDecorations.lookup(operands[0]));

1028 return success();

1029 }

1030

1031 LogicalResult

1032 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {

1033 assert(!operands.empty() && "No operands for processing function type");

1034 if (operands.size() == 1) {

1035 return emitError(unknownLoc, "missing return type for OpTypeFunction");

1036 }

1037 auto returnType = getType(operands[1]);

1038 if (!returnType) {

1039 return emitError(unknownLoc, "unknown return type in OpTypeFunction");

1040 }

1042 for (size_t i = 2, e = operands.size(); i < e; ++i) {

1043 auto ty = getType(operands[i]);

1044 if (!ty) {

1045 return emitError(unknownLoc, "unknown argument type in OpTypeFunction");

1046 }

1047 argTypes.push_back(ty);

1048 }

1050 if (!isVoidType(returnType)) {

1052 }

1053 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);

1054 return success();

1055 }

1056

// Handles an OpTypeCooperativeMatrixKHR type instruction.
//
// Requires exactly 6 operands. operands[2..5] (scope, rows, columns, use) are
// resolved through getConstantInt(). Scope and use are decoded with
// symbolizeScope() / symbolizeCooperativeMatrixUseKHR(). The resulting type
// is stored in typeMap[operands[0]].
1057 LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(

1059 if (operands.size() != 6) {

1061 "OpTypeCooperativeMatrixKHR must have element type, "

1062 "scope, row and column parameters, and use");

1063 }

1064

1066 if (!elementTy) {

1068 "OpTypeCooperativeMatrixKHR references undefined ")

1069 << operands[1];

1070 }

1071

// NOTE(review): getConstantInt() returns a null IntegerAttr when operands[2]
// is not a known integer constant. `.getInt()` is called on that result
// without a null check, unlike rowsAttr/columnsAttr/useAttr below. A
// malformed binary may therefore assert/crash here instead of reaching the
// "references undefined scope" diagnostic.
// TODO: null-check before calling getInt().
1072 std::optionalspirv::Scope scope =

1073 spirv::symbolizeScope(getConstantInt(operands[2]).getInt());

1074 if (!scope) {

1076 unknownLoc,

1077 "OpTypeCooperativeMatrixKHR references undefined scope ")

1078 << operands[2];

1079 }

1080

// Rows, columns and use must each reference an integer constant; each is
// null-checked below before use.
1081 IntegerAttr rowsAttr = getConstantInt(operands[3]);

1082 IntegerAttr columnsAttr = getConstantInt(operands[4]);

1083 IntegerAttr useAttr = getConstantInt(operands[5]);

1084

1085 if (!rowsAttr)

1086 return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Rows` references "

1087 "undefined constant ")

1088 << operands[3];

1089

1090 if (!columnsAttr)

1091 return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Columns` "

1092 "references undefined constant ")

1093 << operands[4];

1094

1095 if (!useAttr)

1096 return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Use` references "

1097 "undefined constant ")

1098 << operands[5];

1099

// NOTE(review): getInt() values are narrowed into `unsigned` without a range
// check; negative or oversized constants are not rejected here.
1100 unsigned rows = rowsAttr.getInt();

1101 unsigned columns = columnsAttr.getInt();

1102

1103 std::optionalspirv::CooperativeMatrixUseKHR use =

1104 spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());

1105 if (!use) {

1107 unknownLoc,

1108 "OpTypeCooperativeMatrixKHR references undefined use ")

1109 << operands[5];

1110 }

1111

1112 typeMap[operands[0]] =

1114 return success();

1115 }

1116

1117 LogicalResult

1118 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {

1119 if (operands.size() != 2) {

1120 return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");

1121 }

1123 if (!memberType) {

1125 "OpTypeRuntimeArray references undefined ")

1126 << operands[1];

1127 }

1129 memberType, typeDecorations.lookup(operands[0]));

1130 return success();

1131 }

1132

// Deserializes OpTypeStruct. operands[0] is the struct's result <id>; any
// remaining operands are member type <id>s.
//  - A lone result <id> yields an empty struct.
//  - A member whose <id> is known only through OpTypeForwardPointer
//    (typeForwardPointerIDs) is accepted and recorded in unresolvedMemberTypes
//    as (<id>, member index) so it can be patched later.
//  - Offset member decorations go to offsetInfo; every other member
//    decoration goes to memberDecorationsInfo.
//  - If nameMap has a name for the struct, an identified struct type is
//    created; its body is deferred via deferredStructTypesInfos while members
//    are unresolved, otherwise set immediately with trySetBody. Unnamed
//    structs must have every member resolved.
// NOTE(review): this listing dropped several source lines (e.g. the
// per-member `memberType` lookup, the local container declarations and the
// struct type constructor calls); confirm against the upstream file.
1133 LogicalResult

1134 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {

1135

1136

1137 if (operands.empty()) {

1138 return emitError(unknownLoc, "OpTypeStruct must have at least result ");

1139 }

1140

// Only the result <id>: an empty struct.
1141 if (operands.size() == 1) {

1142

1143 typeMap[operands[0]] =

1145 return success();

1146 }

1147

1148

1151

1152 for (auto op : llvm::drop_begin(operands, 1)) {

1154 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);

1155

// Undefined members are tolerated only if they were forward-declared.
1156 if (!memberType && !typeForwardPtr)

1157 return emitError(unknownLoc, "OpTypeStruct references undefined ")

1158 << op;

1159

// Remember which slot still needs its real type.
1160 if (!memberType)

1161 unresolvedMemberTypes.emplace_back(op, memberTypes.size());

1162

1163 memberTypes.push_back(memberType);

1164 }

1165

1168 if (memberDecorationMap.count(operands[0])) {

1169 auto &allMemberDecorations = memberDecorationMap[operands[0]];

1170 for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {

1171 if (allMemberDecorations.count(memberIndex)) {

1172 for (auto &memberDecoration : allMemberDecorations[memberIndex]) {

1173

1174 if (memberDecoration.first == spirv::Decoration::Offset) {

1175

// offsetInfo is only allocated once some member carries an Offset.
1176 if (offsetInfo.empty()) {

1177 offsetInfo.resize(memberTypes.size());

1178 }

// NOTE(review): assumes an Offset decoration always carries one operand.
1179 offsetInfo[memberIndex] = memberDecoration.second[0];

1180 } else {

// The second field appears to flag whether the decoration has a literal
// value (1) or not (0) — TODO confirm against the decoration-info type.
1181 if (!memberDecoration.second.empty()) {

1182 memberDecorationsInfo.emplace_back(memberIndex, 1,

1183 memberDecoration.first,

1184 memberDecoration.second[0]);

1185 } else {

1186 memberDecorationsInfo.emplace_back(memberIndex, 0,

1187 memberDecoration.first, 0);

1188 }

1189 }

1190 }

1191 }

1192 }

1193 }

1194

1195 uint32_t structID = operands[0];

1196 std::string structIdentifier = nameMap.lookup(structID).str();

1197

// Unnamed struct: every member type is expected to be resolved already.
1198 if (structIdentifier.empty()) {

1199 assert(unresolvedMemberTypes.empty() &&

1200 "didn't expect unresolved member types");

1201 typeMap[structID] =

1203 } else {

1205 typeMap[structID] = structTy;

1206

// Defer the body until the forward-declared members are known.
1207 if (!unresolvedMemberTypes.empty())

1208 deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,

1209 memberTypes, offsetInfo,

1210 memberDecorationsInfo});

1211 else if (failed(structTy.trySetBody(memberTypes, offsetInfo,

1212 memberDecorationsInfo)))

1213 return failure();

1214 }

1215

1216

1217

1218 return success();

1219 }

1220

// Deserializes OpTypeMatrix: (result <id>, column type <id>, column count).
// NOTE(review): the lines that look up `elementTy` (presumably from
// operands[1]), open the emitError(...) call and build/register the matrix
// type from colsCount were dropped from this listing.
1221 LogicalResult

1222 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {

1223 if (operands.size() != 3) {

1224

1225 return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"

1226 " (result_id, column_type, and column_count)");

1227 }

1228

1230 if (!elementTy) {

1232 "OpTypeMatrix references undefined column type.")

1233 << operands[1];

1234 }

1235

// The column count is a literal, not an <id>.
1236 uint32_t colsCount = operands[2];

1238 return success();

1239 }

1240

// Handles OpTypeForwardPointer. Only the forward-declared pointer type <id>
// (operands[0]) is recorded in typeForwardPointerIDs; processStructType uses
// that set to accept members whose type is not defined yet. operands[1] is
// not read by the visible code.
// NOTE(review): the emitError(...) opener was dropped from this listing.
1241 LogicalResult

1242 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {

1243 if (operands.size() != 2)

1245 "OpTypeForwardPointer instruction must have two operands");

1246

1247 typeForwardPointerIDs.insert(operands[0]);

1248

1249

1250

1251 return success();

1252 }

1253

// Deserializes OpTypeImage. Exactly eight operands are required:
//   [0] result <id>, [1] sampled type <id>, [2] Dim, [3] Depth, [4] Arrayed,
//   [5] MS, [6] Sampled, [7] Image Format.
// Each enumerant is converted with its spirv::symbolize* helper and rejected
// with a diagnostic if unknown. The optional access qualifier is not
// supported (non-eight operand counts are rejected).
// NOTE(review): the emitError(...) opener, the `elementTy` lookup and the
// image-type constructor / typeMap assignment were dropped from this listing.
1254 LogicalResult

1255 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {

1256

1257 if (operands.size() != 8)

1259 unknownLoc,

1260 "OpTypeImage with non-eight operands are not supported yet");

1261

1263 if (!elementTy)

1264 return emitError(unknownLoc, "OpTypeImage references undefined : ")

1265 << operands[1];

1266

1267 auto dim = spirv::symbolizeDim(operands[2]);

1268 if (!dim)

1269 return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")

1270 << operands[2];

1271

1272 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);

1273 if (!depthInfo)

1274 return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")

1275 << operands[3];

1276

1277 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);

1278 if (!arrayedInfo)

1279 return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")

1280 << operands[4];

1281

1282 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);

1283 if (!samplingInfo)

1284 return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];

1285

1286 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);

1287 if (!samplerUseInfo)

1288 return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")

1289 << operands[6];

1290

1291 auto format = spirv::symbolizeImageFormat(operands[7]);

1292 if (!format)

1293 return emitError(unknownLoc, "unknown Format for OpTypeImage: ")

1294 << operands[7];

1295

1297 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),

1298 samplingInfo.value(), samplerUseInfo.value(), format.value());

1299 return success();

1300 }

1301

// Deserializes OpTypeSampledImage: (result <id>, image type <id>).
// NOTE(review): the `elementTy` lookup, the emitError(...) opener and the
// typeMap assignment were dropped from this listing.
1302 LogicalResult

1303 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {

1304 if (operands.size() != 2)

1305 return emitError(unknownLoc, "OpTypeSampledImage must have two operands");

1306

1308 if (!elementTy)

1310 "OpTypeSampledImage references undefined : ")

1311 << operands[1];

1312

1314 return success();

1315 }

1316

1317

1318

1319

1320

// Deserializes OpConstant / OpSpecConstant (selected by isSpec) for scalar
// integer and floating-point result types.
// operands: (result type <id>, result <id>, literal word(s)...).
// Values of at most 32 bits use one literal word. 64-bit values use two, and
// per the SPIR-V spec the low-order word comes first; the value is composed
// arithmetically so the result is independent of host byte order.
// Spec constants are materialized with createSpecConstant; normal constants
// are cached in constantMap as (attribute, type) for later materialization.
// NOTE(review): this listing dropped the `emitError(unknownLoc)` openers and
// the `resultType` lookup; they are unaffected by this edit.
1321 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,

1322 bool isSpec) {

1323 StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";

1324

1325 if (operands.size() < 2) {

1327 << opname << " must have type and result ";

1328 }

1329 if (operands.size() < 3) {

1331 << opname << " must have at least 1 more parameter";

1332 }

1333

1335 if (!resultType) {

1336 return emitError(unknownLoc, "undefined result type from ")

1337 << operands[0];

1338 }

1339

// Checks that the number of literal words matches the type's bit width.
1340 auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {

1341 if (bitwidth == 64) {

1342 if (operands.size() == 4) {

1343 return success();

1344 }

1346 << opname << " should have 2 parameters for 64-bit values";

1347 }

1348 if (bitwidth <= 32) {

1349 if (operands.size() == 3) {

1350 return success();

1351 }

1352

1354 << opname

1355 << " should have 1 parameter for values with no more than 32 bits";

1356 }

1357 return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")

1358 << bitwidth;

1359 };

1360

1361 auto resultID = operands[1];

1362

1363 if (auto intType = dyn_cast(resultType)) {

1364 auto bitwidth = intType.getWidth();

1365 if (failed(checkOperandSizeForBitwidth(bitwidth))) {

1366 return failure();

1367 }

1368

1369 APInt value;

1370 if (bitwidth == 64) {

1371

1372

1373

// SPIR-V stores the low-order word first (operands[2]). Compose the value
// arithmetically; bit_cast-ing a {low, high} struct would swap the halves
// on big-endian hosts.
1377 uint64_t words = (static_cast<uint64_t>(operands[3]) << 32) | operands[2];

1378 value = APInt(64, words, true);

1379 } else if (bitwidth <= 32) {

1380 value = APInt(bitwidth, operands[2], true,

1381 true);

1382 }

1383

1384 auto attr = opBuilder.getIntegerAttr(intType, value);

1385

1386 if (isSpec) {

1387 createSpecConstant(unknownLoc, resultID, attr);

1388 } else {

1389

1390

1391 constantMap.try_emplace(resultID, attr, intType);

1392 }

1393

1394 return success();

1395 }

1396

1397 if (auto floatType = dyn_cast(resultType)) {

1398 auto bitwidth = floatType.getWidth();

1399 if (failed(checkOperandSizeForBitwidth(bitwidth))) {

1400 return failure();

1401 }

1402

// NOTE(review): a float type other than f16/bf16/f32/f64 would keep this f32
// zero; presumably getType never yields one here — TODO confirm.
1403 APFloat value(0.f);

1404 if (floatType.isF64()) {

1405

1406

1407

// Same low-word-first layout as the 64-bit integer case above.
1411 uint64_t words = (static_cast<uint64_t>(operands[3]) << 32) | operands[2];

1412 value = APFloat(llvm::bit_cast<double>(words));

1413 } else if (floatType.isF32()) {

// NOTE(review): the listing stripped this bit_cast's template argument.
1414 value = APFloat(llvm::bit_cast(operands[2]));

1415 } else if (floatType.isF16()) {

1416 APInt data(16, operands[2]);

1417 value = APFloat(APFloat::IEEEhalf(), data);

1418 } else if (floatType.isBF16()) {

1419 APInt data(16, operands[2]);

1420 value = APFloat(APFloat::BFloat(), data);

1421 }

1422

1423 auto attr = opBuilder.getFloatAttr(floatType, value);

1424 if (isSpec) {

1425 createSpecConstant(unknownLoc, resultID, attr);

1426 } else {

1427

1428

1429 constantMap.try_emplace(resultID, attr, floatType);

1430 }

1431

1432 return success();

1433 }

1434

1435 return emitError(unknownLoc, "OpConstant can only generate values of "

1436 "scalar integer or floating-point type");

1437 }

1438

// Deserializes OpConstantTrue/False and OpSpecConstantTrue/False. Requires
// exactly (result type <id>, result <id>). The result type operand is not
// read: a normal constant is cached in constantMap with type i1, and a spec
// constant is handed to createSpecConstant.
// NOTE(review): the second signature line (presumably declaring `isTrue`,
// `operands` and `isSpec`) was dropped from this listing.
1439 LogicalResult spirv::Deserializer::processConstantBool(

1441 if (operands.size() != 2) {

1442 return emitError(unknownLoc, "Op")

1443 << (isSpec ? "Spec" : "") << "Constant"

1444 << (isTrue ? "True" : "False")

1445 << " must have type and result ";

1446 }

1447

1448 auto attr = opBuilder.getBoolAttr(isTrue);

1449 auto resultID = operands[1];

1450 if (isSpec) {

1451 createSpecConstant(unknownLoc, resultID, attr);

1452 } else {

1453

1454

1455 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());

1456 }

1457

1458 return success();

1459 }

1460

// Deserializes OpConstantComposite: (result type <id>, result <id>,
// constituent <id>...). Each constituent must already be a normal constant
// (getConstant); its attribute is collected into `elements`. The composite is
// cached in constantMap as an attribute built from those elements for shaped
// result types (construction line not shown), or as an ArrayAttr for
// spirv::ArrayType. Any other result type is rejected.
// NOTE(review): the emitError(...) openers, the `resultType` lookup, the
// `elements` declaration and the shaped-type attribute construction were
// dropped from this listing.
1461 LogicalResult

1462 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {

1463 if (operands.size() < 2) {

1465 "OpConstantComposite must have type and result ");

1466 }

1467 if (operands.size() < 3) {

1469 "OpConstantComposite must have at least 1 parameter");

1470 }

1471

1473 if (!resultType) {

1474 return emitError(unknownLoc, "undefined result type from ")

1475 << operands[0];

1476 }

1477

1479 elements.reserve(operands.size() - 2);

1480 for (unsigned i = 2, e = operands.size(); i < e; ++i) {

1481 auto elementInfo = getConstant(operands[i]);

// Spec constants and other non-constant values are not accepted here.
1482 if (!elementInfo) {

1483 return emitError(unknownLoc, "OpConstantComposite component ")

1484 << operands[i] << " must come from a normal constant";

1485 }

1486 elements.push_back(elementInfo->first);

1487 }

1488

1489 auto resultID = operands[1];

// The dyn_cast template argument was stripped; presumably ShapedType.
1490 if (auto shapedType = dyn_cast(resultType)) {

1492

1493

1494 constantMap.try_emplace(resultID, attr, shapedType);

1495 } else if (auto arrayType = dyn_castspirv::ArrayType(resultType)) {

1496 auto attr = opBuilder.getArrayAttr(elements);

1497 constantMap.try_emplace(resultID, attr, resultType);

1498 } else {

1499 return emitError(unknownLoc, "unsupported OpConstantComposite type: ")

1500 << resultType;

1501 }

1502

1503 return success();

1504 }

1505

// Deserializes OpSpecConstantComposite: (result type <id>, result <id>,
// constituent <id>...). The composite becomes a spirv::SpecConstantCompositeOp
// whose symbol comes from getSpecConstantSymbol(resultID), and the op is
// recorded in specConstCompositeMap under the result <id>.
// Constituents are resolved with getSpecConstant. A constituent that does not
// resolve is now reported as an error instead of being passed on as a null
// symbol, mirroring the check in processConstantComposite.
// NOTE(review): this listing dropped the emitError(...) openers, the
// `resultType` lookup, the `elements` declaration, the per-constituent push
// and part of the create<> argument list.
1506 LogicalResult

1507 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {

1508 if (operands.size() < 2) {

1510 "OpSpecConstantComposite must have type and result ");

1511 }

1512 if (operands.size() < 3) {

1514 "OpSpecConstantComposite must have at least 1 parameter");

1515 }

1516

1518 if (!resultType) {

1519 return emitError(unknownLoc, "undefined result type from ")

1520 << operands[0];

1521 }

1522

1523 auto resultID = operands[1];

1524 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));

1525

1527 elements.reserve(operands.size() - 2);

1528 for (unsigned i = 2, e = operands.size(); i < e; ++i) {

1529 auto elementInfo = getSpecConstant(operands[i]);

// Reject constituents that are not spec constants before they are turned
// into symbol references.
if (!elementInfo)
return emitError(unknownLoc, "OpSpecConstantComposite component <id> ")
<< operands[i] << " must come from a spec constant";

1531 }

1532

1533 auto op = opBuilder.createspirv::SpecConstantCompositeOp(

1535 opBuilder.getArrayAttr(elements));

1536 specConstCompositeMap[resultID] = op;

1537

1538 return success();

1539 }

1540

// Records an OpSpecConstantOp for lazy materialization.
// operands: (result type <id>, result <id>, enclosed opcode, enclosed operand
// <id>s...). Nothing is built here. The enclosed opcode and result type
// (plus, presumably, the remaining operands) are stored in
// specConstOperationMap under the result <id>. Redefining an <id> that is
// already recorded is rejected.
// NOTE(review): the line supplying the enclosed operands to
// SpecConstOperationMaterializationInfo was dropped from this listing.
1541 LogicalResult

1542 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {

1543 if (operands.size() < 3)

1544 return emitError(unknownLoc, "OpConstantOperation must have type , "

1545 "result , and operand opcode");

1546

1547 uint32_t resultTypeID = operands[0];

1548

// The result type must already be defined; only an undefined type is an
// error.
1549 if (!getType(resultTypeID))

1550 return emitError(unknownLoc, "undefined result type from ")

1551 << resultTypeID;

1552

1553 uint32_t resultID = operands[1];

1554 spirv::Opcode enclosedOpcode = static_castspirv::Opcode\(operands[2]);

1555 auto emplaceResult = specConstOperationMap.try_emplace(

1556 resultID,

1557 SpecConstOperationMaterializationInfo{

1558 enclosedOpcode, resultTypeID,

1560

// try_emplace does not overwrite, so a failed insert means a redefinition.
1561 if (!emplaceResult.second)

1562 return emitError(unknownLoc, "value with : ")

1563 << resultID << " is probably defined before.";

1564

1565 return success();

1566 }

1567

// Materializes a previously recorded OpSpecConstantOp as a
// spirv::SpecConstantOperationOp and returns that op's result.
//  1. valueMap is temporarily replaced (SaveAndRestore restores it on exit).
//  2. The enclosed instruction is re-dispatched through processInstruction,
//     using the original result type and a placeholder result <id> (fakeID)
//     followed by the enclosed operands.
//  3. The op it emitted (the current block's last op) is split off into its
//     own block, which is presumably what gets spliced into the new op's body.
//  4. A spirv.Yield of that block's first op's result is appended.
// NOTE(review): this listing dropped several lines (the last parameter,
// presumably the enclosed operand list; the newValueMap and
// enclosedOpResultTypeAndOperands declarations; the early return on failure;
// the splice arguments; and the declaration of `block`).
1568 Value spirv::Deserializer::materializeSpecConstantOperation(

1569 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,

1571

1572 Type resultType = getType(resultTypeID);

1573

1574

1575

1576

1577

1578

1579

1580

1581

1582

1583

1585 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);

// Placeholder result <id> for the enclosed instruction; presumably chosen so
// it cannot collide with a real <id> — TODO confirm.
1586 constexpr uint32_t fakeID = static_cast<uint32_t>(-3);

1587

1589 enclosedOpResultTypeAndOperands.push_back(resultTypeID);

1590 enclosedOpResultTypeAndOperands.push_back(fakeID);

1591 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),

1592 enclosedOpOperands.end());

1593

1594

1595

1596

1597

1598 if (failed(

1599 processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))

1601

1602

1603

// Isolate the op just emitted by processInstruction in its own block.
1604 Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());

1605

1606 auto loc = createFileLineColLoc(opBuilder);

1607 auto specConstOperationOp =

1608 opBuilder.createspirv::SpecConstantOperationOp(loc, resultType);

1609

1610 Region &body = specConstOperationOp.getBody();

1611

1612 body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),

1615

1616

1617

1619 opBuilder.setInsertionPointToEnd(&block);

1620

1621 opBuilder.createspirv::YieldOp(loc, block.front().getResult(0));

1622 return specConstOperationOp.getResult();

1623 }

1624

// Deserializes OpConstantNull: (result type <id>, result <id>). For scalar
// integer/float types, and for one more type category whose isa<> template
// argument was stripped from this listing (presumably vectors — TODO
// confirm), a zero attribute from getZeroAttr is cached in constantMap. Any
// other result type is rejected.
// NOTE(review): the emitError(...) opener and the `resultType` lookup were
// dropped from this listing.
1625 LogicalResult

1626 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {

1627 if (operands.size() != 2) {

1629 "OpConstantNull must have type and result ");

1630 }

1631

1633 if (!resultType) {

1634 return emitError(unknownLoc, "undefined result type from ")

1635 << operands[0];

1636 }

1637

1638 auto resultID = operands[1];

1639 if (resultType.isIntOrFloat() || isa(resultType)) {

1640 auto attr = opBuilder.getZeroAttr(resultType);

1641

1642

1643 constantMap.try_emplace(resultID, attr, resultType);

1644 return success();

1645 }

1646

1647 return emitError(unknownLoc, "unsupported OpConstantNull type: ")

1648 << resultType;

1649 }

1650

1651

1652

1653

1654

// Returns the block registered for label <id>. If none exists yet, a new
// block is appended to curFunction and registered in blockMap. Branch and
// merge instructions can therefore reference a label before its OpLabel is
// seen; processLabel later fills the (still empty) block.
// Assumes curFunction is set — callers check curBlock/curFunction first.
1655 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {

1656 if (auto *block = getBlock(id)) {

1657 LLVM_DEBUG(logger.startLine() << "[block] got existing block for id = " << id

1658 << " @ " << block << "\n");

1659 return block;

1660 }

1661

1662

1663

1664

1665 auto *block = curFunction->addBlock();

1666 LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id

1667 << " @ " << block << "\n");

1668 return blockMap[id] = block;

1669 }

1670

// Deserializes OpBranch into a spirv.Branch to the target label's block,
// which is created on demand. The branch is created without block
// arguments; OpPhi values are attached afterwards by wireUpBlockArgument. The
// pending OpLine location is used for the branch and then cleared.
1671 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {

1672 if (!curBlock) {

1673 return emitError(unknownLoc, "OpBranch must appear inside a block");

1674 }

1675

1676 if (operands.size() != 1) {

1677 return emitError(unknownLoc, "OpBranch must take exactly one target label");

1678 }

1679

1680 auto *target = getOrCreateBlock(operands[0]);

1681 auto loc = createFileLineColLoc(opBuilder);

1682

1683

1684

1685 opBuilder.createspirv::BranchOp(loc, target);

1686

1687 clearDebugLine();

1688 return success();

1689 }

1690

// Deserializes OpBranchConditional:
//   (condition <id>, true label, false label [, true weight, false weight]).
// Target blocks are created on demand, and the optional branch weights are
// captured as a pair. The pending OpLine location is used and then cleared.
// NOTE(review): `condition` is not null-checked here; presumably getValue
// reports undefined <id>s itself — TODO confirm. The emitError(...) openers
// and the trailing create<> arguments were dropped from this listing.
1691 LogicalResult

1692 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {

1693 if (!curBlock) {

1695 "OpBranchConditional must appear inside a block");

1696 }

1697

1698 if (operands.size() != 3 && operands.size() != 5) {

1700 "OpBranchConditional must have condition, true label, "

1701 "false label, and optionally two branch weights");

1702 }

1703

1704 auto condition = getValue(operands[0]);

1705 auto *trueBlock = getOrCreateBlock(operands[1]);

1706 auto *falseBlock = getOrCreateBlock(operands[2]);

1707

1708 std::optional<std::pair<uint32_t, uint32_t>> weights;

1709 if (operands.size() == 5) {

1710 weights = std::make_pair(operands[3], operands[4]);

1711 }

1712

1713

1714

1715 auto loc = createFileLineColLoc(opBuilder);

1716 opBuilder.createspirv::BranchConditionalOp(

1717 loc, condition, trueBlock,

1720

1721 clearDebugLine();

1722 return success();

1723 }

1724

// Deserializes OpLabel: starts populating the block for the label <id>
// (created here unless an earlier branch/merge already created it) and makes
// it curBlock, with the builder positioned at its start.
// A label that was already populated means the binary defines it twice.
// That is now reported as an error rather than only asserted, because
// release builds would otherwise append into the existing block.
1725 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {

1726 if (!curFunction) {

1727 return emitError(unknownLoc, "OpLabel must appear inside a function");

1728 }

1729

1730 if (operands.size() != 1) {

1731 return emitError(unknownLoc, "OpLabel should only have result ");

1732 }

1733

1734 auto labelID = operands[0];

1735

1736 auto *block = getOrCreateBlock(labelID);

1737 LLVM_DEBUG(logger.startLine()

1738 << "[block] populating block " << block << "\n");

1739

1740 if (!block->empty())
return emitError(unknownLoc, "duplicate OpLabel for result <id> ")
<< labelID;

1741

1742 opBuilder.setInsertionPointToStart(block);

1743 blockMap[labelID] = curBlock = block;

1744

1745 return success();

1746 }

1747

// Deserializes OpSelectionMerge: (merge label, selection control). Records
// curBlock as a selection header in blockMergeInfo with its location, raw
// control word and merge block; the continue block is left at its default.
// The recorded info is consumed later by control-flow structurization. A
// block may carry at most one merge instruction.
// NOTE(review): the emitError(...) openers were dropped from this listing.
1748 LogicalResult

1749 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {

1750 if (!curBlock) {

1751 return emitError(unknownLoc, "OpSelectionMerge must appear in a block");

1752 }

1753

1754 if (operands.size() < 2) {

1756 unknownLoc,

1757 "OpSelectionMerge must specify merge target and selection control");

1758 }

1759

1760 auto *mergeBlock = getOrCreateBlock(operands[0]);

1761 auto loc = createFileLineColLoc(opBuilder);

1762 auto selectionControl = operands[1];

1763

1764 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)

1765 .second) {

1767 unknownLoc,

1768 "a block cannot have more than one OpSelectionMerge instruction");

1769 }

1770

1771 return success();

1772 }

1773

// Deserializes OpLoopMerge: (merge label, continue label, loop control).
// Records curBlock as a loop header in blockMergeInfo with its location, raw
// control word, merge block and continue block. The non-null continue block
// is what later marks the construct as a loop. A block may carry at most one
// merge instruction.
// NOTE(review): the emitError(...) opener of the duplicate check was dropped
// from this listing.
1774 LogicalResult

1775 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {

1776 if (!curBlock) {

1777 return emitError(unknownLoc, "OpLoopMerge must appear in a block");

1778 }

1779

1780 if (operands.size() < 3) {

1781 return emitError(unknownLoc, "OpLoopMerge must specify merge target, "

1782 "continue target and loop control");

1783 }

1784

1785 auto *mergeBlock = getOrCreateBlock(operands[0]);

1786 auto *continueBlock = getOrCreateBlock(operands[1]);

1787 auto loc = createFileLineColLoc(opBuilder);

1788 uint32_t loopControl = operands[2];

1789

1790 if (!blockMergeInfo

1791 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)

1792 .second) {

1794 unknownLoc,

1795 "a block cannot have more than one OpLoopMerge instruction");

1796 }

1797

1798 return success();

1799 }

1800

// Deserializes OpPhi:
//   (result type <id>, result <id>, (value <id>, parent label)+).
// The phi becomes a new argument of curBlock, registered in valueMap under
// the result <id>. Each (value, parent) pair is queued in blockPhiInfo under
// the (predecessor, curBlock) edge; wireUpBlockArgument later turns the
// queued values into branch operands.
// Robustness fixes:
//  - Malformed input with an odd operand count is rejected; it would
//    otherwise read one past the end when fetching the last parent label.
//  - An undefined result type is reported instead of creating a null-typed
//    block argument.
1801 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {

1802 if (!curBlock) {

1803 return emitError(unknownLoc, "OpPhi must appear in a block");

1804 }

1805

1806 if (operands.size() < 4 || operands.size() % 2 != 0) {

1807 return emitError(unknownLoc, "OpPhi must specify result type, result , "

1808 "and variable-parent pairs");

1809 }

1810

1811

1812 Type blockArgType = getType(operands[0]);

if (!blockArgType)
return emitError(unknownLoc, "undefined result type from <id> ")
<< operands[0];

1813 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);

1814 valueMap[operands[1]] = blockArg;

1815 LLVM_DEBUG(logger.startLine()

1816 << "[phi] created block argument " << blockArg

1817 << " id = " << operands[1] << " of type " << blockArgType << "\n");

1818

1819

1820

// Operands from index 2 on come in (value <id>, parent label) pairs.
1821 for (unsigned i = 2, e = operands.size(); i < e; i += 2) {

1822 uint32_t value = operands[i];

1823 Block *predecessor = getOrCreateBlock(operands[i + 1]);

1824 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};

1825 blockPhiInfo[predecessorTargetPair].push_back(value);

1826 LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor

1827 << " with arg id = " << value << "\n");

1828 }

1829

1830 return success();

1831 }

1832

1833 namespace {

1834

1835

// Rewrites one selection or loop construct, identified by its header block,
// into a structured spirv::SelectionOp or spirv::LoopOp. The construct is
// treated as a loop exactly when a continue block is supplied.
1836 class ControlFlowStructurizer {

1837 public:

// The debug logger is only taken (and stored) when NDEBUG is not defined.
// NOTE(review): part of both constructor signatures (the merge-info map and
// header/merge/continue block parameters) was dropped from this listing.
1838 #ifndef NDEBUG

1839 ControlFlowStructurizer(Location loc, uint32_t control,

1842 llvm::ScopedPrinter &logger)

1843 : location(loc), control(control), blockMergeInfo(mergeInfo),

1844 headerBlock(header), mergeBlock(merge), continueBlock(cont),

1845 logger(logger) {}

1846 #else

1847 ControlFlowStructurizer(Location loc, uint32_t control,

1850 : location(loc), control(control), blockMergeInfo(mergeInfo),

1851 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}

1852 #endif

1853

1854

1855

1856

1857

1858

1859

1860

// Performs the rewrite; returns failure if the construct cannot be
// structurized.
1861 LogicalResult structurize();

1862

1863 private:

1864

1865

// Creates an empty selection op (with its merge block) in front of
// mergeBlock.
1866 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);

1867

1868

// Creates an empty loop op (with entry and merge blocks) in front of
// mergeBlock.
1869 spirv::LoopOp createLoopOp(uint32_t loopControl);

1870

1871

// Fills constructBlocks with the blocks that belong to this construct.
1872 void collectBlocksInConstruct();

1873

// Raw selection/loop control word taken from the merge instruction.
1875 uint32_t control;

1876

1878

1879 Block *headerBlock;

1880 Block *mergeBlock;

// Non-null only for loop constructs.
1881 Block *continueBlock;

1882

1884

// NOTE(review): the declarations of `location`, `blockMergeInfo` and
// `constructBlocks` were dropped from this listing.
1885 #ifndef NDEBUG

1886

1887 llvm::ScopedPrinter &logger;

1888 #endif

1889 };

1890 }

1891

// Creates a spirv::SelectionOp at the start of mergeBlock, using the raw
// selection control word cast to spirv::SelectionControl, and gives it a
// merge block. The region's remaining blocks are cloned in by structurize().
1892 spirv::SelectionOp

1893 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {

1894

1895

1896 OpBuilder builder(&mergeBlock->front());

1897

1898 auto control = static_castspirv::SelectionControl\(selectionControl);

1899 auto selectionOp = builder.createspirv::SelectionOp(location, control);

1900 selectionOp.addMergeBlock(builder);

1901

1902 return selectionOp;

1903 }

1904

// Creates a spirv::LoopOp at the start of mergeBlock, using the raw loop
// control word cast to spirv::LoopControl, and gives it entry and merge
// blocks. The loop body is cloned in by structurize().
1905 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {

1906

1907

1908 OpBuilder builder(&mergeBlock->front());

1909

1910 auto control = static_castspirv::LoopControl\(loopControl);

1911 auto loopOp = builder.createspirv::LoopOp(location, control);

1912 loopOp.addEntryAndMergeBlock(builder);

1913

1914 return loopOp;

1915 }

1916

// Collects into constructBlocks every block reachable from headerBlock
// without passing through mergeBlock, starting with headerBlock itself.
// Blocks are visited in insertion order (breadth-first).
// constructBlocks is presumably a SetVector (indexable and deduplicating) —
// TODO confirm against its dropped declaration.
1917 void ControlFlowStructurizer::collectBlocksInConstruct() {

1918 assert(constructBlocks.empty() && "expected empty constructBlocks");

1919

1920

1921 constructBlocks.insert(headerBlock);

1922

1923

1924

// size() is re-read every iteration, so blocks appended by the inner loop
// are visited too.
1925 for (unsigned i = 0; i < constructBlocks.size(); ++i) {

1926 for (auto *successor : constructBlocks[i]->getSuccessors())

1927 if (successor != mergeBlock)

1928 constructBlocks.insert(successor);

1929 }

1930 }

1931

// Structurizes the selection/loop construct headed by headerBlock:
//  1. Creates a spirv::LoopOp (continueBlock set) or spirv::SelectionOp in
//     front of mergeBlock.
//  2. Clones every block of the construct (collectBlocksInConstruct) into
//     the op's region, remapping values and successors through `mapper`.
//  3. Redirects all uses of headerBlock to mergeBlock; for loops, also adds
//     the entry-block branch to the cloned header.
//  4. If values are still used outside the construct, recreates the op with
//     one result per such use and replaces the uses with those results.
//  5. Re-keys nested entries of blockMergeInfo to the cloned blocks, then
//     erases the original blocks. A function entry block is instead reduced
//     to a branch to mergeBlock.
// Returns failure if a value or block argument of the construct is still
// used outside it, or if a nested header/continue block was not remapped.
// Fix: the nested merge-info control word is now read before
// blockMergeInfo.erase(it); previously it was read through the erased
// iterator.
// NOTE(review): this listing dropped a number of source lines, e.g. the
// declarations of `op`, `body`, `mapper`, `builder`, `blockArgs`,
// `valuesToYield`, `outsideUses` and `newOp`, and several `if` headers.
1932 LogicalResult ControlFlowStructurizer::structurize() {

1934 bool isLoop = continueBlock != nullptr;

1936 if (auto loopOp = createLoopOp(control))

1937 op = loopOp.getOperation();

1938 } else {

1939 if (auto selectionOp = createSelectionOp(control))

1940 op = selectionOp.getOperation();

1941 }

1942 if (!op)

1943 return failure();

1945

1947

1948

// References to the old merge block map to the region's own merge block
// (its last block).
1949 mapper.map(mergeBlock, &body.back());

1950

1951 collectBlocksInConstruct();

1952

1953

1954

1955

1956

1957

1958

1959

1960

1961

1962

1963

1964

1965

1966

1967

1968

1969

1970

1971

// Clone each construct block into the region, just before its merge block.
// Block arguments are remapped except for a function entry block (else
// branch below).
1973 for (auto *block : constructBlocks) {

1974

1975

1976 auto *newBlock = builder.createBlock(&body.back());

1977 mapper.map(block, newBlock);

1978 LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock

1979 << " from block " << block << "\n");

1982 auto newArg =

1983 newBlock->addArgument(blockArg.getType(), blockArg.getLoc());

1984 mapper.map(blockArg, newArg);

1985 LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "

1986 << blockArg << " to " << newArg << "\n");

1987 }

1988 } else {

1989 LLVM_DEBUG(logger.startLine()

1990 << "[cf] block " << block << " is a function entry block\n");

1991 }

1992

1993 for (auto &op : *block)

1994 newBlock->push_back(op.clone(mapper));

1995 }

1996

1997

// Point operands and successors of the cloned ops at the cloned values and
// blocks.
1998 auto remapOperands = [&](Operation *op) {

2001 operand.set(mappedOp);

2004 succOp.set(mappedOp);

2005 };

2006 for (auto &block : body)

2007 block.walk(remapOperands);

2008

2009

2010

2011

2012

2013

2014

// Every existing branch into the header now targets mergeBlock, where the
// new op lives.
2015 headerBlock->replaceAllUsesWith(mergeBlock);

2016

2017 LLVM_DEBUG({

2018 logger.startLine() << "[cf] after cloning and fixing references:\n";

2019 headerBlock->getParentOp()->print(logger.getOStream());

2020 logger.startLine() << "\n";

2021 });

2022

// Loop-only section (its enclosing `if` was dropped from this listing).
2024 if (!mergeBlock->args_empty()) {

2025 return mergeBlock->getParentOp()->emitError(

2026 "OpPhi in loop merge block unsupported");

2027 }

2028

2029

2030

2031

2032 for (BlockArgument blockArg : headerBlock->getArguments())

2033 mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());

2034

2035

2036

2038 if (!headerBlock->args_empty())

2039 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};

2040

2041

2042

2043 builder.setInsertionPointToEnd(&body.front());

2044 builder.createspirv::BranchOp(location, mapper.lookupOrNull(headerBlock),

2046 }

2047

2048

2049

2051

2052

2054

2055

2056

2057

2058

2059

2060

2061

2062

2063

2064

2065

// Merge-block arguments become values yielded from the region's merge
// block; the original arguments are erased further below.
2067 for (BlockArgument blockArg : mergeBlock->getArguments()) {

2068

2069

2070

2071

2072 body.back().addArgument(blockArg.getType(), blockArg.getLoc());

2073 valuesToYield.push_back(body.back().getArguments().back());

2074 outsideUses.push_back(blockArg);

2075 }

2076

2077

2078

2079 LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");

2080

2081

// Drop references first: construct blocks may use each other's values.
2082 for (auto *block : constructBlocks)

2083 block->dropAllReferences();

2084

2085

2086

2087

// Remaining uses now come from outside the construct and must be yielded.
2088 for (Block *block : constructBlocks) {

2092 valuesToYield.push_back(mapper.lookupOrNull(result));

2093 outsideUses.push_back(result);

2094 }

2095 }

2096 for (BlockArgument &arg : block->getArguments()) {

2097 if (!arg.use_empty()) {

2098 valuesToYield.push_back(mapper.lookupOrNull(arg));

2099 outsideUses.push_back(arg);

2100 }

2101 }

2102 }

2103

2104 assert(valuesToYield.size() == outsideUses.size());

2105

2106

2107

// An op's result count cannot change, so a new op with one result per
// outside use is created and takes over the region.
2108 if (!valuesToYield.empty()) {

2109 LLVM_DEBUG(logger.startLine()

2110 << "[cf] yielding values from the selection / loop region\n");

2111

2112

2113 auto mergeOps = body.back().getOpsspirv::MergeOp();

2114 Operation *merge = llvm::getSingleElement(mergeOps);

2115 assert(merge);

2116 merge->setOperands(valuesToYield);

2117

2118

2119

2120

2121

2122

2123

2124 builder.setInsertionPoint(&mergeBlock->front());

2125

2127

2129 newOp = builder.createspirv::LoopOp(

2131 static_castspirv::LoopControl\(control));

2132 else

2133 newOp = builder.createspirv::SelectionOp(

2135 static_castspirv::SelectionControl\(control));

2136

2138

2139

2141 op = newOp;

2142

2143

2144

2145 for (unsigned i = 0, e = outsideUses.size(); i != e; ++i)

2146 outsideUses[i].replaceAllUsesWith(op->getResult(i));

2147

2148

2149

2150

2152 mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());

2153 }

2154

2155

2156

2157

2158

// Any use still left here cannot be served from inside the new op.
2159 for (auto *block : constructBlocks) {

2162 return op.emitOpError("failed control flow structurization: value has "

2163 "uses outside of the "

2164 "enclosing selection/loop construct");

2165 for (BlockArgument &arg : block->getArguments())

2166 if (!arg.use_empty())

2167 return emitError(arg.getLoc(), "failed control flow structurization: "

2168 "block argument has uses outside of the "

2169 "enclosing selection/loop construct");

2170 }

2171

2172

2173 for (auto *block : constructBlocks) {

2174

2175

2176

2177

2178

2179

2180

2181

2182

2183

2184

2185

2186

2187

2188

2189

2190

2191

2192

2193

2194

2195

2196

2197

2198

2199

2200

2201

2202

2203

2204

2205

2206

2207

2208

2209

2210

2211

2212

// Nested constructs recorded in blockMergeInfo are re-keyed to the cloned
// blocks so they can still be structurized later.
2213 auto updateMergeInfo = [&](Block *block) -> WalkResult {

2214 auto it = blockMergeInfo.find(block);

2215 if (it != blockMergeInfo.end()) {

2216

2217 Location loc = it->second.loc;

2218

2220 if (!newHeader)

2221 return emitError(loc, "failed control flow structurization: nested "

2222 "loop header block should be remapped!");

2223

2224 Block *newContinue = it->second.continueBlock;

2225 if (newContinue) {

2226 newContinue = mapper.lookupOrNull(newContinue);

2227 if (!newContinue)

2228 return emitError(loc, "failed control flow structurization: nested "

2229 "loop continue block should be remapped!");

2230 }

2231

2232 Block *newMerge = it->second.mergeBlock;

2234 newMerge = mappedTo;

2235

2236

2237

// erase(it) invalidates `it`, so read the control word beforehand.
uint32_t nestedControl = it->second.control;

2238 blockMergeInfo.erase(it);

2239 blockMergeInfo.try_emplace(newHeader, loc, nestedControl, newMerge,

2240 newContinue);

2241 }

2242

2244 };

2245

2246 if (block->walk(updateMergeInfo).wasInterrupted())

2247 return failure();

2248

2249

2250

2251

2252

// A function entry block cannot be erased: it is emptied and made to branch
// to mergeBlock instead.
2254 LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block

2255 << " to only contain a spirv.Branch op\n");

2256

2257

2258 block->clear();

2259 builder.setInsertionPointToEnd(block);

2260 builder.createspirv::BranchOp(location, mergeBlock);

2261 } else {

2262 LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");

2263 block->erase();

2264 }

2265 }

2266

2267 LLVM_DEBUG(logger.startLine()

2268 << "[cf] after structurizing construct with header block "

2269 << headerBlock << ":\n"

2270 << *op << "\n");

2271

2272 return success();

2273 }

2274

// Turns the OpPhi values queued in blockPhiInfo into branch operands. For
// each (predecessor, target) edge, the queued value <id>s are resolved with
// getValue, and the predecessor's terminator is recreated with those values
// as the target's block arguments:
//  - spirv.Branch: rebuilt with the values as its operands;
//  - spirv.BranchConditional: only the side that jumps to `target` gets the
//    values; the other side's arguments and the branch weights are kept.
// Any other terminator, or an undefined value <id>, is an error.
// blockPhiInfo is cleared afterwards.
// NOTE(review): if both conditional targets are the same block, only the
// true-side arguments are set. The lines declaring `op` (presumably the
// predecessor's terminator) and `blockArgs`, and part of the debug printing,
// were dropped from this listing.
2275 LogicalResult spirv::Deserializer::wireUpBlockArgument() {

2276 LLVM_DEBUG({

2277 logger.startLine()

2278 << "//----- [phi] start wiring up block arguments -----//\n";

2279 logger.indent();

2280 });

2281

2283

2284 for (const auto &info : blockPhiInfo) {

2285 Block *block = info.first.first;

2286 Block *target = info.first.second;

2287 const BlockPhiInfo &phiInfo = info.second;

2288 LLVM_DEBUG({

2289 logger.startLine() << "[phi] block " << block << "\n";

2290 logger.startLine() << "[phi] before creating block argument:\n";

2292 logger.startLine() << "\n";

2293 });

2294

2295

2296

// The replacement terminator is inserted right before the old one.
2298 opBuilder.setInsertionPoint(op);

2299

2301 blockArgs.reserve(phiInfo.size());

2302 for (uint32_t valueId : phiInfo) {

2303 if (Value value = getValue(valueId)) {

2304 blockArgs.push_back(value);

2305 LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value

2306 << " id = " << valueId << "\n");

2307 } else {

2308 return emitError(unknownLoc, "OpPhi references undefined value!");

2309 }

2310 }

2311

2312 if (auto branchOp = dyn_castspirv::BranchOp(op)) {

2313

2314 opBuilder.createspirv::BranchOp(branchOp.getLoc(), branchOp.getTarget(),

2315 blockArgs);

2316 branchOp.erase();

2317 } else if (auto branchCondOp = dyn_castspirv::BranchConditionalOp(op)) {

2318 assert((branchCondOp.getTrueBlock() == target ||

2319 branchCondOp.getFalseBlock() == target) &&

2320 "expected target to be either the true or false target");

// Keep the other side's arguments so an earlier edge's wiring survives.
2321 if (target == branchCondOp.getTrueTarget())

2322 opBuilder.createspirv::BranchConditionalOp(

2323 branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,

2324 branchCondOp.getFalseBlockArguments(),

2325 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),

2326 branchCondOp.getFalseTarget());

2327 else

2328 opBuilder.createspirv::BranchConditionalOp(

2329 branchCondOp.getLoc(), branchCondOp.getCondition(),

2330 branchCondOp.getTrueBlockArguments(), blockArgs,

2331 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),

2332 branchCondOp.getFalseBlock());

2333

2334 branchCondOp.erase();

2335 } else {

2336 return emitError(unknownLoc, "unimplemented terminator for Phi creation");

2337 }

2338

2339 LLVM_DEBUG({

2340 logger.startLine() << "[phi] after creating block argument:\n";

2342 logger.startLine() << "\n";

2343 });

2344 }

2345 blockPhiInfo.clear();

2346

2347 LLVM_DEBUG({

2348 logger.unindent();

2349 logger.startLine()

2350 << "//--- [phi] completed wiring up block arguments ---//\n";

2351 });

2352 return success();

2353 }

2354

// Ensures each selection header that ends in spirv.BranchConditional is a
// block containing only that branch. A header is split when it holds other
// ops, or when it is also the merge block of another construct. The split
// presumably moves the terminator into a new block, and the old block gets
// an unconditional branch to it. The selection's blockMergeInfo entry is
// then re-keyed to the new block. Loop headers (entries with a continue
// block) are skipped.
// Iterates over a copy because blockMergeInfo is mutated inside the loop.
// NOTE(review): the lines declaring blockMergeInfoCopy, the second `continue`
// condition, `terminator`, and the split/builder setup were dropped from
// this listing.
2355 LogicalResult spirv::Deserializer::splitConditionalBlocks() {

2356

2358 for (auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();

2359 it != e; ++it) {

2360 auto &[block, mergeInfo] = *it;

2361

2362

2363 if (mergeInfo.continueBlock)

2364 continue;

2365

2367 continue;

2368

2370 assert(terminator);

2371

2372 if (!isaspirv::BranchConditionalOp(terminator))

2373 continue;

2374

2375

2376 bool splitHeaderMergeBlock = false;

// This `mergeInfo` shadows the outer binding only inside this loop; the
// try_emplace below uses the outer one.
2377 for (const auto &[_, mergeInfo] : blockMergeInfo) {

2378 if (mergeInfo.mergeBlock == block)

2379 splitHeaderMergeBlock = true;

2380 }

2381

2382

2383

2384

2385

2386 if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {

2389 builder.createspirv::BranchOp(block->getParent()->getLoc(), newBlock);

2390

2391

2392

2393 blockMergeInfo.erase(block);

2394 blockMergeInfo.try_emplace(newBlock, mergeInfo);

2395 }

2396 }

2397

2398 return success();

2399 }

2400

// Converts the unstructured control flow recorded in blockMergeInfo into
// structured selection/loop ops. When
// options.enableControlFlowStructurization is off, it returns success
// without touching the IR.
// Steps:
//  1. Conditional selection headers are split first (splitConditionalBlocks).
//  2. Entries are popped from blockMergeInfo one at a time and handed to a
//     ControlFlowStructurizer, which may re-key nested entries it clones.
// Returns failure on the first construct that cannot be structurized, or if
// a loop's merge block has OpPhi arguments.
// Fixes:
//  - The skip test previously fired when structurization was *enabled*.
//  - The skip path no longer indents the debug logger, since it never
//    reaches the matching unindent.
2401 LogicalResult spirv::Deserializer::structurizeControlFlow() {

2402 if (!options.enableControlFlowStructurization) {

2403 LLVM_DEBUG(

2404 {

2405 logger.startLine()

2406 << "//----- [cf] skip structurizing control flow -----//\n";

2408 });

2409 return success();

2410 }

2411

2412 LLVM_DEBUG({

2413 logger.startLine()

2414 << "//----- [cf] start structurizing control flow -----//\n";

2415 logger.indent();

2416 });

2417

2418 LLVM_DEBUG({

2419 logger.startLine() << "[cf] split conditional blocks\n";

2420 logger.startLine() << "\n";

2421 });

2422

2423 if (failed(splitConditionalBlocks())) {

2424 return failure();

2425 }

2426

2427

2428

2429

// Structurization may add and remove entries, so always take the current
// first one.
2430 while (!blockMergeInfo.empty()) {

2431 Block *headerBlock = blockMergeInfo.begin()->first;

2432 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;

2433

2434 LLVM_DEBUG({

2435 logger.startLine() << "[cf] header block " << headerBlock << ":\n";

2436 headerBlock->print(logger.getOStream());

2437 logger.startLine() << "\n";

2438 });

2439

2440 auto *mergeBlock = mergeInfo.mergeBlock;

2441 assert(mergeBlock && "merge block cannot be nullptr");

2442 if (mergeInfo.continueBlock && !mergeBlock->args_empty())

2443 return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");

2444 LLVM_DEBUG({

2445 logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";

2446 mergeBlock->print(logger.getOStream());

2447 logger.startLine() << "\n";

2448 });

2449

2450 auto *continueBlock = mergeInfo.continueBlock;

2451 LLVM_DEBUG(if (continueBlock) {

2452 logger.startLine() << "[cf] continue block " << continueBlock << ":\n";

2453 continueBlock->print(logger.getOStream());

2454 logger.startLine() << "\n";

2455 });

2456

2457

// Remove this construct's entry before structurizing: the structurizer
// updates blockMergeInfo for nested constructs.
2458 blockMergeInfo.erase(blockMergeInfo.begin());

2459 ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,

2460 blockMergeInfo, headerBlock,

2461 mergeBlock, continueBlock

2462 #ifndef NDEBUG

2463 ,

2464 logger

2465 #endif

2466 );

2467 if (failed(structurizer.structurize()))

2468 return failure();

2469 }

2470

2471 LLVM_DEBUG({

2472 logger.unindent();

2473 logger.startLine()

2474 << "//--- [cf] completed structurizing control flow ---//\n";

2475 });

2476 return success();

2477 }

2478

2479

2480

2481

2482

// Builds a file/line/column location from the active OpLine (debugLine), or
// returns unknownLoc when no OpLine is active. The file name is resolved
// through debugInfoMap (populated by OpString) using the OpLine file <id>.
// When the name is unknown or empty, the placeholder "<unknown>" is used.
// Previously "" was assigned to an already-empty string, which is a no-op.
// `opBuilder` is taken by value and shadows the member of the same name.
// NOTE(review): the FileLineColLoc::get(...) opener was dropped from this
// listing.
2483 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {

2484 if (!debugLine)

2485 return unknownLoc;

2486

2487 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();

2488 if (fileName.empty())

2489 fileName = "<unknown>";

2491 debugLine->column);

2492 }

2493

// Handles OpLine: (file <id>, line, column). The triple is stored in
// debugLine and becomes the location for ops created until it is cleared
// (see clearDebugLine / createFileLineColLoc).
2494 LogicalResult

2495 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {

2496

2497

2498

2499

2500

2501 if (operands.size() != 3)

2502 return emitError(unknownLoc, "OpLine must have 3 operands");

2503 debugLine = DebugLine{operands[0], operands[1], operands[2]};

2504 return success();

2505 }

2506

// Forgets the active OpLine, so later ops fall back to unknownLoc.
2507 void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }

2508

// Handles OpString: (result <id>, literal string words...). The string is
// decoded starting at word 1 and must consume every remaining word. It is
// stored in debugInfoMap under the result <id>, where OpLine file <id>s
// resolve it.
// NOTE(review): the duplicate check tests for a non-empty stored string, so
// a second OpString for an <id> that previously mapped to "" is not
// detected. The emitError(...) openers and the line decoding `debugString`
// were dropped from this listing.
2509 LogicalResult

2510 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {

2511 if (operands.size() < 2)

2512 return emitError(unknownLoc, "OpString needs at least 2 operands");

2513

2514 if (!debugInfoMap.lookup(operands[0]).empty())

2516 "duplicate debug string found for result ")

2517 << operands[0];

2518

2519 unsigned wordIndex = 1;

2521 if (wordIndex != operands.size())

2523 "unexpected trailing words in OpString instruction");

2524

2525 debugInfoMap[operands[0]] = debugString;

2526 return success();

2527 }

static bool isLoop(Operation *op)

Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...

static bool isFnEntryBlock(Block *block)

Returns true if the given block is a function entry block.

#define MIN_VERSION_CASE(v)

LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)

static llvm::ManagedStatic< PassManagerOptions > options

static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)

Attributes are known-constant values of operations.

This class represents an argument of a Block.

Location getLoc() const

Return the location for this argument.

Block represents an ordered list of Operations.

void erase()

Unlink this Block from its parent region and delete it.

Block * splitBlock(iterator splitBefore)

Split the block into two blocks before the specified operation or iterator.

Region * getParent() const

Provide a 'getParent' method for ilist_node_with_parent methods.

Operation * getTerminator()

Get the terminator operation of this block.

void print(raw_ostream &os)

bool mightHaveTerminator()

Check whether this block might have a terminator.

BlockArgListType getArguments()

bool isEntryBlock()

Return if this block is the entry block in the parent region.

void push_back(Operation *op)

Operation * getParentOp()

Returns the closest surrounding operation that contains this block.

StringAttr getStringAttr(const Twine &bytes)

ArrayAttr getArrayAttr(ArrayRef< Attribute > value)

Attr getAttr(Args &&...args)

Get or construct an instance of the attribute Attr with provided arguments.

static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)

Constructs a dense elements attribute from an array of element values.

static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)

A symbol reference with a reference path containing a single element.

This is a utility class for mapping one set of IR entities to another.

void map(Value from, Value to)

Inserts a new mapping for 'from' to 'to'.

auto lookupOrNull(T from) const

Lookup a mapped value within the map.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

NamedAttribute represents a combination of a name and an Attribute value.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

Operation is the basic unit of execution within MLIR.

bool use_empty()

Returns true if this operation has no uses.

Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())

Create a deep copy of this operation, remapping any operands that use values outside of the operation...

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)

Create a new Operation with the specific fields.

Region & getRegion(unsigned index)

Returns the region held by this operation at position 'index'.

MutableArrayRef< BlockOperand > getBlockOperands()

MutableArrayRef< OpOperand > getOpOperands()

result_range getResults()

InFlightDiagnostic emitOpError(const Twine &message={})

Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.

void erase()

Remove this operation from its parent block and delete it.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

BlockListType & getBlocks()

Location getLoc()

Return a location for this region.

BlockListType::iterator iterator

void takeBody(Region &other)

Takes body of another region (that region will have no body after this operation completes).

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

bool isIntOrFloat() const

Return true if this is an integer (of any signedness) or a float type.

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

A utility result that is used to signal how to proceed with an ongoing walk:

static WalkResult advance()

static ArrayType get(Type elementType, unsigned elementCount)

static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)

LogicalResult deserialize()

Deserializes the remembered SPIR-V binary module.

Deserializer(ArrayRef< uint32_t > binary, MLIRContext *context, const DeserializationOptions &options)

Creates a deserializer for the given SPIR-V binary module.

OwningOpRef< spirv::ModuleOp > collect()

Collects the final SPIR-V ModuleOp.

static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)

static MatrixType get(Type columnType, uint32_t columnCount)

static PointerType get(Type pointeeType, StorageClass storageClass)

static RuntimeArrayType get(Type elementType)

static SampledImageType get(Type imageType)

static StructType getIdentified(MLIRContext *context, StringRef identifier)

Construct an identified StructType.

static StructType getEmpty(MLIRContext *context, StringRef identifier="")

Construct a (possibly identified) StructType with no members.

static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})

Construct a literal StructType with at least one member.

static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)

Gets a VerCapExtAttr instance.

The OpAsmOpInterface, see OpAsmInterface.td for more details.

constexpr uint32_t kMagicNumber

SPIR-V magic number.

StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)

Decodes a string literal in words starting at wordIndex.

DenseMap< Block *, BlockMergeInfo > BlockMergeInfoMap

Map from a selection/loop's header block to its merge (and continue) target.

constexpr unsigned kHeaderWordCount

SPIR-V binary header word count.

Include the generated interface declarations.

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.

static std::string debugString(T &&op)

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

This represents an operation in an abstracted form, suitable for use with the builder APIs.