MLIR: lib/Dialect/SPIRV/IR/SPIRVOps.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

14

17

32 #include "llvm/ADT/APFloat.h"

33 #include "llvm/ADT/APInt.h"

34 #include "llvm/ADT/ArrayRef.h"

35 #include "llvm/ADT/STLExtras.h"

36 #include "llvm/ADT/StringExtras.h"

37 #include "llvm/ADT/TypeSwitch.h"

38 #include "llvm/Support/InterleavedRange.h"

39 #include

40 #include

41 #include

42 #include <type_traits>

43

44 using namespace mlir;

46

47

48

49

50

52 auto constOp = dyn_cast_or_nullspirv::ConstantOp(op);

53 if (!constOp) {

54 return failure();

55 }

56 auto valueAttr = constOp.getValue();

57 auto integerValueAttr = llvm::dyn_cast(valueAttr);

58 if (!integerValueAttr) {

59 return failure();

60 }

61

62 if (integerValueAttr.getType().isSignlessInteger())

63 value = integerValueAttr.getInt();

64 else

65 value = integerValueAttr.getSInt();

66

67 return success();

68 }

69

70 LogicalResult

72 spirv::MemorySemantics memorySemantics) {

73

74

75

76

77

78

79 auto atMostOneInSet = spirv::MemorySemantics::Acquire |

80 spirv::MemorySemantics::Release |

81 spirv::MemorySemantics::AcquireRelease |

82 spirv::MemorySemantics::SequentiallyConsistent;

83

84 auto bitCount =

85 llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));

86 if (bitCount > 1) {

88 "expected at most one of these four memory constraints "

89 "to be set: `Acquire`, `Release`,"

90 "`AcquireRelease` or `SequentiallyConsistent`");

91 }

92 return success();

93 }

94

97

99 stringifyDecoration(spirv::Decoration::DescriptorSet));

100 auto bindingName = llvm::convertToSnakeFromCamelCase(

101 stringifyDecoration(spirv::Decoration::Binding));

104 if (descriptorSet && binding) {

107 printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()

108 << ")";

109 }

110

111

112 auto builtInName = llvm::convertToSnakeFromCamelCase(

113 stringifyDecoration(spirv::Decoration::BuiltIn));

114 if (auto builtin = op->getAttrOfType(builtInName)) {

115 printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";

116 elidedAttrs.push_back(builtInName);

117 }

118

120 }

121

126

127

133 return failure();

134 auto fnType = llvm::dyn_cast(type);

135 if (!fnType) {

136 parser.emitError(loc, "expected function type");

137 return failure();

138 }

140 return failure();

141 result.addTypes(fnType.getResults());

142 return success();

143 }

149 }

150

152 assert(op->getNumResults() == 1 && "op should have one result");

153

154

155

158 [&](Type type) { return type != resultType; })) {

160 return;

161 }

162

163 p << ' ';

166

167 p << " : " << resultType;

168 }

169

170 template

173 auto valType = val.getType();

174 if (auto valVecTy = llvm::dyn_cast(valType))

175 valType = valVecTy.getElementType();

176

177 if (valType !=

178 llvm::castspirv::PointerType(ptr.getType()).getPointeeType()) {

179 return op.emitOpError("mismatch in result type and pointer type");

180 }

181 return success();

182 }

183

184

185

186

190 if (indices.empty()) {

191 emitErrorFn("expected at least one index for spirv.CompositeExtract");

192 return nullptr;

193 }

194

195 for (auto index : indices) {

196 if (auto cType = llvm::dyn_castspirv::CompositeType(type)) {

197 if (cType.hasCompileTimeKnownNumElements() &&

198 (index < 0 ||

199 static_cast<uint64_t>(index) >= cType.getNumElements())) {

200 emitErrorFn("index ") << index << " out of bounds for " << type;

201 return nullptr;

202 }

203 type = cType.getElementType(index);

204 } else {

205 emitErrorFn("cannot extract from non-composite type ")

206 << type << " with index " << index;

207 return nullptr;

208 }

209 }

210 return type;

211 }

212

216 auto indicesArrayAttr = llvm::dyn_cast(indices);

217 if (!indicesArrayAttr) {

218 emitErrorFn("expected a 32-bit integer array attribute for 'indices'");

219 return nullptr;

220 }

221 if (indicesArrayAttr.empty()) {

222 emitErrorFn("expected at least one index for spirv.CompositeExtract");

223 return nullptr;

224 }

225

227 for (auto indexAttr : indicesArrayAttr) {

228 auto indexIntAttr = llvm::dyn_cast(indexAttr);

229 if (!indexIntAttr) {

230 emitErrorFn("expected an 32-bit integer for index, but found '")

231 << indexAttr << "'";

232 return nullptr;

233 }

234 indexVals.push_back(indexIntAttr.getInt());

235 }

237 }

238

242 };

244 }

245

247 SMLoc loc) {

249 return parser.emitError(loc, err);

250 };

252 }

253

254 template

256 auto resultType = llvm::castspirv::StructType(op.getType());

257 if (resultType.getNumElements() != 2)

258 return op.emitOpError("expected result struct type containing two members");

259

260 if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),

261 resultType.getElementType(0),

262 resultType.getElementType(1)}))

263 return op.emitOpError(

264 "expected all operand types and struct member types are the same");

265

266 return success();

267 }

268

274 return failure();

275

276 Type resultType;

278 if (parser.parseType(resultType))

279 return failure();

280

281 auto structType = llvm::dyn_castspirv::StructType(resultType);

282 if (!structType || structType.getNumElements() != 2)

283 return parser.emitError(loc, "expected spirv.struct type with two members");

284

287 return failure();

288

290 return success();

291 }

292

295 printer << ' ';

299 }

300

303 return op->emitError("expected the same type for the first operand and "

304 "result, but provided ")

307 }

308 return success();

309 }

310

311

312

313

314

316 spirv::GlobalVariableOp var) {

318 }

319

321 auto varOp = dyn_cast_or_nullspirv::GlobalVariableOp(

323 getVariableAttr()));

324 if (!varOp) {

325 return emitOpError("expected spirv.GlobalVariable symbol");

326 }

327 if (getPointer().getType() != varOp.getType()) {

328 return emitOpError(

329 "result type mismatch with the referenced global variable's type");

330 }

331 return success();

332 }

333

334

335

336

337

339 operand_range constituents = this->getConstituents();

340

341

342

343

344

345

346

347 auto coopElementType =

350 [](auto coopType) { return coopType.getElementType(); })

351 .Default([](Type) { return nullptr; });

352

353

354 if (coopElementType) {

355 if (constituents.size() != 1)

356 return emitOpError("has incorrect number of operands: expected ")

357 << "1, but provided " << constituents.size();

358 if (coopElementType != constituents.front().getType())

359 return emitOpError("operand type mismatch: expected operand type ")

360 << coopElementType << ", but provided "

361 << constituents.front().getType();

362 return success();

363 }

364

365

366 auto cType = llvm::castspirv::CompositeType(getType());

367 if (constituents.size() == cType.getNumElements()) {

368 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {

369 if (constituents[index].getType() != cType.getElementType(index)) {

370 return emitOpError("operand type mismatch: expected operand type ")

371 << cType.getElementType(index) << ", but provided "

372 << constituents[index].getType();

373 }

374 }

375 return success();

376 }

377

378

379 auto resultType = llvm::dyn_cast(cType);

380 if (!resultType)

381 return emitOpError(

382 "expected to return a vector or cooperative matrix when the number of "

383 "constituents is less than what the result needs");

384

386 for (Value component : constituents) {

387 if (!llvm::isa(component.getType()) &&

388 !component.getType().isIntOrFloat())

389 return emitOpError("operand type mismatch: expected operand to have "

390 "a scalar or vector type, but provided ")

391 << component.getType();

392

393 Type elementType = component.getType();

394 if (auto vectorType = llvm::dyn_cast(component.getType())) {

395 sizes.push_back(vectorType.getNumElements());

396 elementType = vectorType.getElementType();

397 } else {

398 sizes.push_back(1);

399 }

400

401 if (elementType != resultType.getElementType())

402 return emitOpError("operand element type mismatch: expected to be ")

403 << resultType.getElementType() << ", but provided " << elementType;

404 }

405 unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);

406 if (totalCount != cType.getNumElements())

407 return emitOpError("has incorrect number of operands: expected ")

408 << cType.getNumElements() << ", but provided " << totalCount;

409 return success();

410 }

411

412

413

414

415

420 auto elementType =

422 if (!elementType) {

423 return;

424 }

425 build(builder, state, elementType, composite, indexAttr);

426 }

427

432 StringRef indicesAttrName =

433 spirv::CompositeExtractOp::getIndicesAttrName(result.name);

434 Type compositeType;

435 SMLoc attrLocation;

436

442 return failure();

443 }

444

445 Type resultType =

446 getElementType(compositeType, indicesAttr, parser, attrLocation);

447 if (!resultType) {

448 return failure();

449 }

451 return success();

452 }

453

455 printer << ' ' << getComposite() << getIndices() << " : "

456 << getComposite().getType();

457 }

458

460 auto indicesArrayAttr = llvm::dyn_cast(getIndices());

461 auto resultType =

463 if (!resultType)

464 return failure();

465

466 if (resultType != getType()) {

467 return emitOpError("invalid result type: expected ")

468 << resultType << " but provided " << getType();

469 }

470

471 return success();

472 }

473

474

475

476

477

482 build(builder, state, composite.getType(), object, composite, indexAttr);

483 }

484

488 Type objectType, compositeType;

490 StringRef indicesAttrName =

491 spirv::CompositeInsertOp::getIndicesAttrName(result.name);

493

494 return failure(

499 parser.resolveOperands(operands, {objectType, compositeType}, loc,

502 }

503

505 auto indicesArrayAttr = llvm::dyn_cast(getIndices());

506 auto objectType =

508 if (!objectType)

509 return failure();

510

511 if (objectType != getObject().getType()) {

512 return emitOpError("object operand type should be ")

513 << objectType << ", but found " << getObject().getType();

514 }

515

517 return emitOpError("result type should be the same as "

518 "the composite type, but found ")

519 << getComposite().getType() << " vs " << getType();

520 }

521

522 return success();

523 }

524

526 printer << " " << getObject() << ", " << getComposite() << getIndices()

527 << " : " << getObject().getType() << " into "

528 << getComposite().getType();

529 }

530

531

532

533

534

538 StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name);

540 return failure();

541

543 if (auto typedAttr = llvm::dyn_cast(value))

544 type = typedAttr.getType();

545 if (llvm::isa<NoneType, TensorType>(type)) {

547 return failure();

548 }

549

551 }

552

554 printer << ' ' << getValue();

555 if (llvm::isaspirv::ArrayType(getType()))

556 printer << " : " << getType();

557 }

558

560 Type opType) {

561 if (isaspirv::CooperativeMatrixType(opType)) {

562 auto denseAttr = dyn_cast(value);

563 if (!denseAttr || !denseAttr.isSplat())

564 return op.emitOpError("expected a splat dense attribute for cooperative "

565 "matrix constant, but found ")

566 << denseAttr;

567 }

568 if (llvm::isa<IntegerAttr, FloatAttr>(value)) {

569 auto valueType = llvm::cast(value).getType();

570 if (valueType != opType)

571 return op.emitOpError("result type (")

572 << opType << ") does not match value type (" << valueType << ")";

573 return success();

574 }

575 if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {

576 auto valueType = llvm::cast(value).getType();

577 if (valueType == opType)

578 return success();

579 auto arrayType = llvm::dyn_castspirv::ArrayType(opType);

580 auto shapedType = llvm::dyn_cast(valueType);

581 if (!arrayType)

582 return op.emitOpError("result or element type (")

583 << opType << ") does not match value type (" << valueType

584 << "), must be the same or spirv.array";

585

586 int numElements = arrayType.getNumElements();

587 auto opElemType = arrayType.getElementType();

588 while (auto t = llvm::dyn_castspirv::ArrayType(opElemType)) {

589 numElements *= t.getNumElements();

590 opElemType = t.getElementType();

591 }

592 if (!opElemType.isIntOrFloat())

593 return op.emitOpError("only support nested array result type");

594

595 auto valueElemType = shapedType.getElementType();

596 if (valueElemType != opElemType) {

597 return op.emitOpError("result element type (")

598 << opElemType << ") does not match value element type ("

599 << valueElemType << ")";

600 }

601

602 if (numElements != shapedType.getNumElements()) {

603 return op.emitOpError("result number of elements (")

604 << numElements << ") does not match value number of elements ("

605 << shapedType.getNumElements() << ")";

606 }

607 return success();

608 }

609 if (auto arrayAttr = llvm::dyn_cast(value)) {

610 auto arrayType = llvm::dyn_castspirv::ArrayType(opType);

611 if (!arrayType)

612 return op.emitOpError(

613 "must have spirv.array result type for array value");

614 Type elemType = arrayType.getElementType();

615 for (Attribute element : arrayAttr.getValue()) {

616

618 return failure();

619 }

620 return success();

621 }

622 return op.emitOpError("cannot have attribute: ") << value;

623 }

624

626

627

628

630 }

631

// Returns whether a spirv.Constant can be materialized with result type `type`.
// A type is accepted only if it is a SPIR-V type. If it additionally belongs to
// the dialect tested by the `isa` check below, it is accepted only when it is a
// spirv.array. Every other SPIR-V type is accepted.
// NOTE(review): the template argument of that dialect `isa` check was lost in
// this copy of the file. Presumably it names the SPIR-V dialect itself; confirm
// against upstream before relying on this.
632 bool spirv::ConstantOp::isBuildableWith(Type type) {

633

// Reject anything that is not a SPIR-V type up front.
634 if (!llvm::isaspirv::SPIRVType(type))

635 return false;

636

// For types of the checked dialect, only array types are buildable here.
637 if (isa(type.getDialect())) {

638

639 return llvm::isaspirv::ArrayType(type);

640 }

641

// All remaining SPIR-V types are buildable.
642 return true;

643 }

644

647 if (auto intType = llvm::dyn_cast(type)) {

648 unsigned width = intType.getWidth();

649 if (width == 1)

650 return builder.createspirv::ConstantOp(loc, type,

652 return builder.createspirv::ConstantOp(

653 loc, type, builder.getIntegerAttr(type, APInt(width, 0)));

654 }

655 if (auto floatType = llvm::dyn_cast(type)) {

656 return builder.createspirv::ConstantOp(

657 loc, type, builder.getFloatAttr(floatType, 0.0));

658 }

659 if (auto vectorType = llvm::dyn_cast(type)) {

660 Type elemType = vectorType.getElementType();

661 if (llvm::isa(elemType)) {

662 return builder.createspirv::ConstantOp(

663 loc, type,

666 }

667 if (llvm::isa(elemType)) {

668 return builder.createspirv::ConstantOp(

669 loc, type,

672 }

673 }

674

675 llvm_unreachable("unimplemented types for ConstantOp::getZero()");

676 }

677

678 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,

680 if (auto intType = llvm::dyn_cast(type)) {

681 unsigned width = intType.getWidth();

682 if (width == 1)

683 return builder.createspirv::ConstantOp(loc, type,

685 return builder.createspirv::ConstantOp(

686 loc, type, builder.getIntegerAttr(type, APInt(width, 1)));

687 }

688 if (auto floatType = llvm::dyn_cast(type)) {

689 return builder.createspirv::ConstantOp(

690 loc, type, builder.getFloatAttr(floatType, 1.0));

691 }

692 if (auto vectorType = llvm::dyn_cast(type)) {

693 Type elemType = vectorType.getElementType();

694 if (llvm::isa(elemType)) {

695 return builder.createspirv::ConstantOp(

696 loc, type,

699 }

700 if (llvm::isa(elemType)) {

701 return builder.createspirv::ConstantOp(

702 loc, type,

705 }

706 }

707

708 llvm_unreachable("unimplemented types for ConstantOp::getOne()");

709 }

710

711 void mlir::spirv::ConstantOp::getAsmResultNames(

714

716 llvm::raw_svector_ostream specialName(specialNameBuffer);

717 specialName << "cst";

718

719 IntegerType intTy = llvm::dyn_cast(type);

720

721 if (IntegerAttr intCst = llvm::dyn_cast(getValue())) {

722 if (intTy && intTy.getWidth() == 1) {

723 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));

724 }

725

726 if (intTy.isSignless()) {

727 specialName << intCst.getInt();

728 } else if (intTy.isUnsigned()) {

729 specialName << intCst.getUInt();

730 } else {

731 specialName << intCst.getSInt();

732 }

733 }

734

735 if (intTy || llvm::isa(type)) {

736 specialName << '_' << type;

737 }

738

739 if (auto vecType = llvm::dyn_cast(type)) {

740 specialName << "_vec_";

741 specialName << vecType.getDimSize(0);

742

743 Type elementType = vecType.getElementType();

744

745 if (llvm::isa(elementType) ||

746 llvm::isa(elementType)) {

747 specialName << "x" << elementType;

748 }

749 }

750

751 setNameFn(getResult(), specialName.str());

752 }

753

754 void mlir::spirv::AddressOfOp::getAsmResultNames(

757 llvm::raw_svector_ostream specialName(specialNameBuffer);

758 specialName << getVariable() << "_addr";

759 setNameFn(getResult(), specialName.str());

760 }

761

762

763

764

765

768 }

769

770

771

772

773

775 spirv::ExecutionModel executionModel,

776 spirv::FuncOp function,

778 build(builder, state,

781 }

782

785 spirv::ExecutionModel execModel;

787

789 if (parseEnumStrAttrspirv::ExecutionModelAttr(execModel, parser, result) ||

791 return failure();

792 }

793

795

797

798 FlatSymbolRefAttr var;

799 NamedAttrList attrs;

800 if (parser.parseAttribute(var, Type(), "var_symbol", attrs))

801 return failure();

802 interfaceVars.push_back(var);

803 return success();

804 }))

805 return failure();

806 }

807 result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),

809 return success();

810 }

811

813 printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";

815 auto interfaceVars = getInterface().getValue();

816 if (!interfaceVars.empty())

817 printer << ", " << llvm::interleaved(interfaceVars);

818 }

819

821

822

823 return success();

824 }

825

826

827

828

829

831 spirv::FuncOp function,

832 spirv::ExecutionMode executionMode,

837 }

838

841 spirv::ExecutionMode execMode;

844 parseEnumStrAttrspirv::ExecutionModeAttr(execMode, parser, result)) {

845 return failure();

846 }

847

853 if (parser.parseAttribute(value, i32Type, "value", attr)) {

854 return failure();

855 }

856 values.push_back(llvm::cast(value).getInt());

857 }

858 StringRef valuesAttrName =

859 spirv::ExecutionModeOp::getValuesAttrName(result.name);

862 return success();

863 }

864

866 printer << " ";

868 printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";

869 ArrayAttr values = this->getValues();

870 if (!values.empty())

871 printer << ", " << llvm::interleaved(values.getAsValueRange());

872 }

873

874

875

876

877

883

884

885 StringAttr nameAttr;

888 return failure();

889

890

891 bool isVariadic = false;

893 parser, false, entryArgs, isVariadic, resultTypes,

894 resultAttrs))

895 return failure();

896

898 for (auto &arg : entryArgs)

899 argTypes.push_back(arg.type);

900 auto fnType = builder.getFunctionType(argTypes, resultTypes);

903

904

905 spirv::FunctionControl fnControl;

906 if (parseEnumStrAttrspirv::FunctionControlAttr(fnControl, parser, result))

907 return failure();

908

909

911 return failure();

912

913

914 assert(resultAttrs.size() == resultTypes.size());

916 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),

917 getResAttrsAttrName(result.name));

918

919

923 return failure(parseResult.has_value() && failed(*parseResult));

924 }

925

927

928 printer << " ";

930 auto fnType = getFunctionType();

932 printer, *this, fnType.getInputs(),

933 false, fnType.getResults());

934 printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())

935 << "\"";

937 printer, *this,

938 {spirv::attributeNamespirv::FunctionControl(),

939 getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),

940 getFunctionControlAttrName()});

941

942

943 Region &body = this->getBody();

944 if (!body.empty()) {

945 printer << ' ';

946 printer.printRegion(body, false,

947 true);

948 }

949 }

950

// Verifies the signature of a spirv.func.
// - The function type may have at most one result.
// - For every pointer-typed argument, decoration requirements are enforced
//   when a PhysicalStorageBuffer pointer is involved:
//   * The argument is a pointer whose pointee is itself a pointer into
//     PhysicalStorageBuffer: it must carry AliasedPointer or RestrictPointer.
//   * Otherwise, the pointer checked is either the argument pointer itself, or,
//     if its pointee is an array, the array's element type viewed as a
//     pointer. If that pointer is in PhysicalStorageBuffer, the argument must
//     carry Aliased or Restrict.
// Returns failure (with an op error emitted) on the first violation.
951 LogicalResult spirv::FuncOp::verifyType() {

952 FunctionType fnType = getFunctionType();

953 if (fnType.getNumResults() > 1)

954 return emitOpError("cannot have more than one result");

955

// Returns true iff argument `argIndex` has an attribute named
// spirv::DecorationAttr::name whose value is a DecorationAttr equal to
// `decoration`. Returns false if no such attribute is present.
// NOTE(review): the template arguments of the two casts below were lost in
// this copy; presumably they name the function-op interface type. Confirm.
956 auto hasDecorationAttr = [&](spirv::Decoration decoration,

957 unsigned argIndex) {

958 auto func = llvm::cast(getOperation());

959 for (auto argAttr : cast(func).getArgAttrs(argIndex)) {

960 if (argAttr.getName() != spirv::DecorationAttr::name)

961 continue;

962 if (auto decAttr = dyn_castspirv::DecorationAttr(argAttr.getValue()))

963 return decAttr.getValue() == decoration;

964 }

965 return false;

966 };

967

// Only pointer-typed arguments are subject to the checks below.
968 for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {

969 Type param = fnType.getInputs()[i];

970 auto inputPtrType = dyn_castspirv::PointerType(param);

971 if (!inputPtrType)

972 continue;

973

// Case 1: pointer-to-pointer argument.
974 auto pointeePtrType =

975 dyn_castspirv::PointerType(inputPtrType.getPointeeType());

976 if (pointeePtrType) {

977

978

979

980

981

// Only inner pointers into PhysicalStorageBuffer need a pointer decoration.
982 if (pointeePtrType.getStorageClass() !=

983 spirv::StorageClass::PhysicalStorageBuffer)

984 continue;

985

986 bool hasAliasedPtr =

987 hasDecorationAttr(spirv::Decoration::AliasedPointer, i);

988 bool hasRestrictPtr =

989 hasDecorationAttr(spirv::Decoration::RestrictPointer, i);

990 if (!hasAliasedPtr && !hasRestrictPtr)

991 return emitOpError()

992 << "with a pointer points to a physical buffer pointer must "

993 "be decorated either 'AliasedPointer' or 'RestrictPointer'";

994 continue;

995 }

996

997

998

999

// Case 2: choose the pointer to inspect. For a pointer-to-array argument
// it is the array element type (null if the element is not a pointer);
// otherwise it is the argument pointer itself.
1000 if (auto pointeeArrayType =

1001 dyn_castspirv::ArrayType(inputPtrType.getPointeeType())) {

1002 pointeePtrType =

1003 dyn_castspirv::PointerType(pointeeArrayType.getElementType());

1004 } else {

1005 pointeePtrType = inputPtrType;

1006 }

1007

// Skip unless the chosen pointer exists and targets PhysicalStorageBuffer.
1008 if (!pointeePtrType || pointeePtrType.getStorageClass() !=

1009 spirv::StorageClass::PhysicalStorageBuffer)

1010 continue;

1011

1012 bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);

1013 bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);

1014 if (!hasAliased && !hasRestrict)

1015 return emitOpError() << "with physical buffer pointer must be decorated "

1016 "either 'Aliased' or 'Restrict'";

1017 }

1018

1019 return success();

1020 }

1021

1022 LogicalResult spirv::FuncOp::verifyBody() {

1023 FunctionType fnType = getFunctionType();

1024 if (!isExternal()) {

1025 Block &entryBlock = front();

1026

1027 unsigned numArguments = this->getNumArguments();

1029 return emitOpError("entry block must have ")

1030 << numArguments << " arguments to match function signature";

1031

1032 for (auto [index, fnArgType, blockArgType] :

1034 if (blockArgType != fnArgType) {

1035 return emitOpError("type of entry block argument #")

1036 << index << '(' << blockArgType

1037 << ") must match the type of the corresponding argument in "

1038 << "function signature(" << fnArgType << ')';

1039 }

1040 }

1041 }

1042

1044 if (auto retOp = dyn_castspirv::ReturnOp(op)) {

1045 if (fnType.getNumResults() != 0)

1046 return retOp.emitOpError("cannot be used in functions returning value");

1047 } else if (auto retOp = dyn_castspirv::ReturnValueOp(op)) {

1048 if (fnType.getNumResults() != 1)

1049 return retOp.emitOpError(

1050 "returns 1 value but enclosing function requires ")

1051 << fnType.getNumResults() << " results";

1052

1053 auto retOperandType = retOp.getValue().getType();

1054 auto fnResultType = fnType.getResult(0);

1055 if (retOperandType != fnResultType)

1056 return retOp.emitOpError(" return value's type (")

1057 << retOperandType << ") mismatch with function's result type ("

1058 << fnResultType << ")";

1059 }

1061 });

1062

1063

1064

1065 return failure(walkResult.wasInterrupted());

1066 }

1067

1069 StringRef name, FunctionType type,

1070 spirv::FunctionControl control,

1074 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));

1075 state.addAttribute(spirv::attributeNamespirv::FunctionControl(),

1076 builder.getAttrspirv::FunctionControlAttr(control));

1077 state.attributes.append(attrs.begin(), attrs.end());

1078 state.addRegion();

1079 }

1080

1081

1082

1083

1084

1088 }

1090

1091

1092

1093

1094

1098 }

1100

1101

1102

1103

1104

1108 }

1110

1111

1112

1113

1114

1117 }

1119

1120

1121

1122

1123

1125 Type type, StringRef name,

1126 unsigned descriptorSet, unsigned binding) {

1128 state.addAttribute(

1129 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),

1131 state.addAttribute(

1132 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),

1134 }

1135

1137 Type type, StringRef name,

1138 spirv::BuiltIn builtin) {

1140 state.addAttribute(

1141 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),

1142 builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));

1143 }

1144

1147

1148 StringAttr nameAttr;

1149 StringRef initializerAttrName =

1150 spirv::GlobalVariableOp::getInitializerAttrName(result.name);

1153 return failure();

1154 }

1155

1156

1163 return failure();

1164 }

1165

1167 return failure();

1168 }

1169

1171 StringRef typeAttrName =

1172 spirv::GlobalVariableOp::getTypeAttrName(result.name);

1175 return failure();

1176 }

1177 if (!llvm::isaspirv::PointerType(type)) {

1178 return parser.emitError(loc, "expected spirv.ptr type");

1179 }

1181

1182 return success();

1183 }

1184

1187 spirv::attributeNamespirv::StorageClass()};

1188

1189

1190 printer << ' ';

1193

1194 StringRef initializerAttrName = this->getInitializerAttrName();

1195

1196 if (auto initializer = this->getInitializer()) {

1197 printer << " " << initializerAttrName << '(';

1199 printer << ')';

1200 elidedAttrs.push_back(initializerAttrName);

1201 }

1202

1203 StringRef typeAttrName = this->getTypeAttrName();

1204 elidedAttrs.push_back(typeAttrName);

1206 printer << " : " << getType();

1207 }

1208

1210 if (!llvm::isaspirv::PointerType(getType()))

1211 return emitOpError("result must be of a !spv.ptr type");

1212

1213

1214

1215

1216

1217 auto storageClass = this->storageClass();

1218 if (storageClass == spirv::StorageClass::Generic ||

1219 storageClass == spirv::StorageClass::Function) {

1220 return emitOpError("storage class cannot be '")

1221 << stringifyStorageClass(storageClass) << "'";

1222 }

1223

1225 this->getInitializerAttrName())) {

1227 (*this)->getParentOp(), init.getAttr());

1228

1229

1230

1231 if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,

1232 spirv::SpecConstantCompositeOp>(initOp)) {

1233 return emitOpError("initializer must be result of a "

1234 "spirv.SpecConstant or spirv.GlobalVariable or "

1235 "spirv.SpecConstantCompositeOp op");

1236 }

1237 }

1238

1239 return success();

1240 }

1241

1242

1243

1244

1245

1248 return failure();

1249

1250 return success();

1251 }

1252

1253

1254

1255

1256

1259

1260 spirv::StorageClass storageClass;

1263 Type elementType;

1266 parser.parseType(elementType)) {

1267 return failure();

1268 }

1269

1271 if (auto valVecTy = llvm::dyn_cast(elementType))

1273

1274 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,

1276 return failure();

1277 }

1278 return success();

1279 }

1280

1282 printer << " " << getPtr() << ", " << getValue() << " : "

1283 << getValue().getType();

1284 }

1285

1288 return failure();

1289

1290 return success();

1291 }

1292

1293

1294

1295

1296

1299 }

1300

1304 }

1305

1308 }

1309

1310

1311

1312

1313

1316 }

1317

1321 }

1322

1325 }

1326

1327

1328

1329

1330

1333 }

1334

1338 }

1339

1342 }

1343

1344

1345

1346

1347

1350 }

1351

1355 }

1356

1359 }

1360

1361

1362

1363

1364

1367 }

1368

1369

1370

1371

1372

1374 std::optional name) {

1377 if (name) {

1380 }

1381 }

1382

1384 spirv::AddressingModel addressingModel,

1385 spirv::MemoryModel memoryModel,

1386 std::optional vceTriple,

1387 std::optional name) {

1388 state.addAttribute(

1389 "addressing_model",

1390 builder.getAttrspirv::AddressingModelAttr(addressingModel));

1391 state.addAttribute("memory_model",

1392 builder.getAttrspirv::MemoryModelAttr(memoryModel));

1395 if (vceTriple)

1396 state.addAttribute(getVCETripleAttrName(), *vceTriple);

1397 if (name)

1400 }

1401

1405

1406

1407 StringAttr nameAttr;

1410

1411

1412 spirv::AddressingModel addrModel;

1413 spirv::MemoryModel memoryModel;

1414 if (spirv::parseEnumKeywordAttrspirv::AddressingModelAttr(addrModel, parser,

1415 result) ||

1416 spirv::parseEnumKeywordAttrspirv::MemoryModelAttr(memoryModel, parser,

1417 result))

1418 return failure();

1419

1423 spirv::ModuleOp::getVCETripleAttrName(),

1425 return failure();

1426 }

1427

1429 parser.parseRegion(*body, {}))

1430 return failure();

1431

1432

1433 if (body->empty())

1435

1436 return success();

1437 }

1438

1440 if (std::optional name = getName()) {

1441 printer << ' ';

1443 }

1444

1446

1447 printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "

1449 auto addressingModelAttrName = spirv::attributeNamespirv::AddressingModel();

1450 auto memoryModelAttrName = spirv::attributeNamespirv::MemoryModel();

1451 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,

1453

1454 if (std::optionalspirv::VerCapExtAttr triple = getVceTriple()) {

1455 printer << " requires " << *triple;

1456 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());

1457 }

1458

1460 printer << ' ';

1462 }

1463

1464 LogicalResult spirv::ModuleOp::verifyRegions() {

1465 Dialect *dialect = (*this)->getDialect();

1467 entryPoints;

1469

1470 for (auto &op : *getBody()) {

1472 return op.emitError("'spirv.module' can only contain spirv.* ops");

1473

1474

1475

1476

1477 if (auto entryPointOp = dyn_castspirv::EntryPointOp(op)) {

1478 auto funcOp = table.lookupspirv::FuncOp(entryPointOp.getFn());

1479 if (!funcOp) {

1480 return entryPointOp.emitError("function '")

1481 << entryPointOp.getFn() << "' not found in 'spirv.module'";

1482 }

1483 if (auto interface = entryPointOp.getInterface()) {

1484 for (Attribute varRef : interface) {

1485 auto varSymRef = llvm::dyn_cast(varRef);

1486 if (!varSymRef) {

1487 return entryPointOp.emitError(

1488 "expected symbol reference for interface "

1489 "specification instead of '")

1490 << varRef;

1491 }

1492 auto variableOp =

1493 table.lookupspirv::GlobalVariableOp(varSymRef.getValue());

1494 if (!variableOp) {

1495 return entryPointOp.emitError("expected spirv.GlobalVariable "

1496 "symbol reference instead of'")

1497 << varSymRef << "'";

1498 }

1499 }

1500 }

1501

1502 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(

1503 funcOp, entryPointOp.getExecutionModel());

1504 if (!entryPoints.try_emplace(key, entryPointOp).second)

1505 return entryPointOp.emitError("duplicate of a previous EntryPointOp");

1506 } else if (auto funcOp = dyn_castspirv::FuncOp(op)) {

1507

1508

1509

1510 auto linkageAttr = funcOp.getLinkageAttributes();

1511 auto hasImportLinkage =

1512 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==

1513 spirv::LinkageType::Import);

1514 if (funcOp.isExternal() && !hasImportLinkage)

1516 "'spirv.module' cannot contain external functions "

1517 "without 'Import' linkage_attributes (LinkageAttributes)");

1518

1519

1520 for (auto &block : funcOp)

1521 for (auto &op : block) {

1524 "functions in 'spirv.module' can only contain spirv.* ops");

1525 }

1526 }

1527 }

1528

1529 return success();

1530 }

1531

1532

1533

1534

1535

1538 (*this)->getParentOp(), getSpecConstAttr());

1539 Type constType;

1540

1541 auto specConstOp = dyn_cast_or_nullspirv::SpecConstantOp(specConstSym);

1542 if (specConstOp)

1543 constType = specConstOp.getDefaultValue().getType();

1544

1545 auto specConstCompositeOp =

1546 dyn_cast_or_nullspirv::SpecConstantCompositeOp(specConstSym);

1547 if (specConstCompositeOp)

1548 constType = specConstCompositeOp.getType();

1549

1550 if (!specConstOp && !specConstCompositeOp)

1551 return emitOpError(

1552 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");

1553

1554 if (getReference().getType() != constType)

1555 return emitOpError("result type mismatch with the referenced "

1556 "specialization constant's type");

1557

1558 return success();

1559 }

1560

1561

1562

1563

1564

1567 StringAttr nameAttr;

1569 StringRef defaultValueAttrName =

1570 spirv::SpecConstantOp::getDefaultValueAttrName(result.name);

1571

1574 return failure();

1575

1576

1578 IntegerAttr specIdAttr;

1582 return failure();

1583 }

1584

1587 return failure();

1588

1589 return success();

1590 }

1591

1593 printer << ' ';

1595 if (auto specID = (*this)->getAttrOfType(kSpecIdAttrName))

1596 printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';

1597 printer << " = " << getDefaultValue();

1598 }

1599

1601 if (auto specID = (*this)->getAttrOfType(kSpecIdAttrName))

1602 if (specID.getValue().isNegative())

1603 return emitOpError("SpecId cannot be negative");

1604

1605 auto value = getDefaultValue();

1606 if (llvm::isa<IntegerAttr, FloatAttr>(value)) {

1607

1608 if (!llvm::isaspirv::SPIRVType(value.getType()))

1609 return emitOpError("default value bitwidth disallowed");

1610 return success();

1611 }

1612 return emitOpError(

1613 "default value can only be a bool, integer, or float scalar");

1614 }

1615

1616

1617

1618

1619

1621 VectorType resultType = llvm::cast(getType());

1622

1623 size_t numResultElements = resultType.getNumElements();

1624 if (numResultElements != getComponents().size())

1625 return emitOpError("result type element count (")

1626 << numResultElements

1627 << ") mismatch with the number of component selectors ("

1628 << getComponents().size() << ")";

1629

1630 size_t totalSrcElements =

1631 llvm::cast(getVector1().getType()).getNumElements() +

1632 llvm::cast(getVector2().getType()).getNumElements();

1633

1634 for (const auto &selector : getComponents().getAsValueRange()) {

1635 uint32_t index = selector.getZExtValue();

1636 if (index >= totalSrcElements &&

1637 index != std::numeric_limits<uint32_t>().max())

1638 return emitOpError("component selector ")

1639 << index << " out of range: expected to be in [0, "

1640 << totalSrcElements << ") or 0xffffffff";

1641 }

1642 return success();

1643 }

1644

1645

1646

1647

1648

1650 Type elementType =

1653 [](auto matrixType) { return matrixType.getElementType(); })

1654 .Default([](Type) { return nullptr; });

1655

1656 assert(elementType && "Unhandled type");

1657

1658

1659 if (getScalar().getType() != elementType)

1660 return emitOpError("input matrix components' type and scaling value must "

1661 "have the same type");

1662

1663 return success();

1664 }

1665

1666

1667

1668

1669

1671 auto inputMatrix = llvm::castspirv::MatrixType(getMatrix().getType());

1672 auto resultMatrix = llvm::castspirv::MatrixType(getResult().getType());

1673

1674

1675 if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())

1676 return emitError("input matrix rows count must be equal to "

1677 "output matrix columns count");

1678

1679 if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())

1680 return emitError("input matrix columns count must be equal to "

1681 "output matrix rows count");

1682

1683

1684 if (inputMatrix.getElementType() != resultMatrix.getElementType())

1685 return emitError("input and output matrices must have the same "

1686 "component type");

1687

1688 return success();

1689 }

1690

1691

1692

1693

1694

1696 auto matrixType = llvm::castspirv::MatrixType(getMatrix().getType());

1697 auto vectorType = llvm::cast(getVector().getType());

1698 auto resultType = llvm::cast(getType());

1699

1700 if (matrixType.getNumColumns() != vectorType.getNumElements())

1701 return emitOpError("matrix columns (")

1702 << matrixType.getNumColumns() << ") must match vector operand size ("

1703 << vectorType.getNumElements() << ")";

1704

1705 if (resultType.getNumElements() != matrixType.getNumRows())

1706 return emitOpError("result size (")

1707 << resultType.getNumElements() << ") must match the matrix rows ("

1708 << matrixType.getNumRows() << ")";

1709

1710 if (matrixType.getElementType() != resultType.getElementType())

1711 return emitOpError("matrix and result element types must match");

1712

1713 return success();

1714 }

1715

1716

1717

1718

1719

1721 auto vectorType = llvm::cast(getVector().getType());

1722 auto matrixType = llvm::castspirv::MatrixType(getMatrix().getType());

1723 auto resultType = llvm::cast(getType());

1724

1725 if (matrixType.getNumRows() != vectorType.getNumElements())

1726 return emitOpError("number of components in vector must equal the number "

1727 "of components in each column in matrix");

1728

1729 if (resultType.getNumElements() != matrixType.getNumColumns())

1730 return emitOpError("number of columns in matrix must equal the number of "

1731 "components in result");

1732

1733 if (matrixType.getElementType() != resultType.getElementType())

1734 return emitOpError("matrix must be a matrix with the same component type "

1735 "as the component type in result");

1736

1737 return success();

1738 }

1739

1740

1741

1742

1743

1745 auto leftMatrix = llvm::castspirv::MatrixType(getLeftmatrix().getType());

1746 auto rightMatrix = llvm::castspirv::MatrixType(getRightmatrix().getType());

1747 auto resultMatrix = llvm::castspirv::MatrixType(getResult().getType());

1748

1749

1750 if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())

1751 return emitError("left matrix columns' count must be equal to "

1752 "the right matrix rows' count");

1753

1754

1755 if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())

1757 "right and result matrices must have equal columns' count");

1758

1759

1760 if (rightMatrix.getElementType() != resultMatrix.getElementType())

1761 return emitError("right and result matrices' component type must"

1762 " be the same");

1763

1764

1765 if (leftMatrix.getElementType() != resultMatrix.getElementType())

1766 return emitError("left and result matrices' component type"

1767 " must be the same");

1768

1769

1770 if (leftMatrix.getNumRows() != resultMatrix.getNumRows())

1771 return emitError("left and result matrices must have equal rows' count");

1772

1773 return success();

1774 }

1775

1776

1777

1778

1779

1782

1783 StringAttr compositeName;

1786 return failure();

1787

1789 return failure();

1790

1792

1793 do {

1794

1795 const char *attrName = "spec_const";

1798

1800 return failure();

1801

1802 constituents.push_back(specConstRef);

1804

1806 return failure();

1807

1808 StringAttr compositeSpecConstituentsName =

1809 spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);

1810 result.addAttribute(compositeSpecConstituentsName,

1812

1815 return failure();

1816

1817 StringAttr typeAttrName =

1818 spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);

1820

1821 return success();

1822 }

1823

1825 printer << " ";

1827 printer << " (" << llvm::interleaved(this->getConstituents().getValue())

1829 }

1830

1832 auto cType = llvm::dyn_castspirv::CompositeType(getType());

1833 auto constituents = this->getConstituents().getValue();

1834

1835 if (!cType)

1836 return emitError("result type must be a composite type, but provided ")

1838

1839 if (llvm::isaspirv::CooperativeMatrixType(cType))

1840 return emitError("unsupported composite type ") << cType;

1841 if (constituents.size() != cType.getNumElements())

1842 return emitError("has incorrect number of operands: expected ")

1843 << cType.getNumElements() << ", but provided "

1844 << constituents.size();

1845

1846 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {

1847 auto constituent = llvm::cast(constituents[index]);

1848

1849 auto constituentSpecConstOp =

1851 (*this)->getParentOp(), constituent.getAttr()));

1852

1853 if (constituentSpecConstOp.getDefaultValue().getType() !=

1854 cType.getElementType(index))

1855 return emitError("has incorrect types of operands: expected ")

1856 << cType.getElementType(index) << ", but provided "

1857 << constituentSpecConstOp.getDefaultValue().getType();

1858 }

1859

1860 return success();

1861 }

1862

1863

1864

1865

1866

1870

1872 return failure();

1873

1877

1878 if (!wrappedOp)

1879 return failure();

1880

1885

1887

1889 return failure();

1890

1891 return success();

1892 }

1893

1895 printer << " wraps ";

1897 }

1898

1899 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {

1900 Block &block = getRegion().getBlocks().front();

1901

1903 return emitOpError("expected exactly 2 nested ops");

1904

1906

1908 return emitOpError("invalid enclosed op");

1909

1910 for (auto operand : enclosedOp.getOperands())

1911 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,

1912 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))

1913 return emitOpError(

1914 "invalid operand, must be defined by a constant operation");

1915

1916 return success();

1917 }

1918

1919

1920

1921

1922

1925 llvm::dyn_castspirv::StructType(getResult().getType());

1926

1928 return emitError("result type must be a struct type with two memebers");

1929

1932 VectorType exponentVecTy = llvm::dyn_cast(exponentTy);

1933 IntegerType exponentIntTy = llvm::dyn_cast(exponentTy);

1934

1935 Type operandTy = getOperand().getType();

1936 VectorType operandVecTy = llvm::dyn_cast(operandTy);

1937 FloatType operandFTy = llvm::dyn_cast(operandTy);

1938

1939 if (significandTy != operandTy)

1940 return emitError("member zero of the resulting struct type must be the "

1941 "same type as the operand");

1942

1943 if (exponentVecTy) {

1944 IntegerType componentIntTy =

1945 llvm::dyn_cast(exponentVecTy.getElementType());

1946 if (!componentIntTy || componentIntTy.getWidth() != 32)

1947 return emitError("member one of the resulting struct type must"

1948 "be a scalar or vector of 32 bit integer type");

1949 } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {

1950 return emitError("member one of the resulting struct type "

1951 "must be a scalar or vector of 32 bit integer type");

1952 }

1953

1954

1955 if (operandVecTy && exponentVecTy &&

1956 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))

1957 return success();

1958

1959 if (operandFTy && exponentIntTy)

1960 return success();

1961

1962 return emitError("member one of the resulting struct type must have the same "

1963 "number of components as the operand type");

1964 }

1965

1966

1967

1968

1969

1971 Type significandType = getX().getType();

1972 Type exponentType = getExp().getType();

1973

1974 if (llvm::isa(significandType) !=

1975 llvm::isa(exponentType))

1976 return emitOpError("operands must both be scalars or vectors");

1977

1979 if (auto vectorType = llvm::dyn_cast(type))

1980 return vectorType.getNumElements();

1981 return 1;

1982 };

1983

1985 return emitOpError("operands must have the same number of elements");

1986

1987 return success();

1988 }

1989

1990

1991

1992

1993

1996 }

1997

1998

1999

2000

2001

2004 }

2005

2006

2007

2008

2009

2012 }

2013

2014

2015

2016

2017

2020 return emitOpError("vector operand and result type mismatch");

2021 auto scalarType = llvm::cast(getType()).getElementType();

2022 if (getScalar().getType() != scalarType)

2023 return emitOpError("scalar operand and result element type match");

2024 return success();

2025 }

static std::string bindingName()

Returns the string name of the Binding decoration.

static std::string descriptorSetName()

Returns the string name of the DescriptorSet decoration.

static Value getZero(OpBuilder &b, Location loc, Type elementType)

Get zero value for an element type.

static int64_t getNumElements(Type t)

Compute the total number of elements in the given type, also taking into account nested types.

static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)

static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)

static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)

static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)

static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)

static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)

static LogicalResult verifyShiftOp(Operation *op)

static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)

static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)

Walks the given type hierarchy with the given indices, potentially down to component granularity,...

static void printOneResultOp(Operation *op, OpAsmPrinter &p)

static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)

ParseResult parseSymbolName(StringAttr &result)

Parse an @-identifier and store it (without the '@' symbol) in a string attribute.

virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0

Parse an optional @-identifier and store it (without the '@' symbol) in a string attribute.

virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0

Parse a list of comma-separated items with an optional delimiter.

virtual Builder & getBuilder() const =0

Return a builder which provides useful access to MLIRContext, global objects like types and attribute...

virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0

Parse a named dictionary into 'result' if it is present.

virtual ParseResult parseOptionalKeyword(StringRef keyword)=0

Parse the given keyword if present.

MLIRContext * getContext() const

virtual ParseResult parseRParen()=0

Parse a ) token.

virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0

Emit a diagnostic at the specified location and return failure.

ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)

Add the specified type to the end of the specified type list and return success.

virtual ParseResult parseEqual()=0

Parse a = token.

virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0

Parse a named dictionary into 'result' if the attributes keyword is present.

virtual ParseResult parseColonType(Type &result)=0

Parse a colon followed by a type.

virtual SMLoc getCurrentLocation()=0

Get the location of the next token and store it into the argument.

virtual ParseResult parseOptionalComma()=0

Parse a , token if present.

virtual ParseResult parseColon()=0

Parse a : token.

ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)

Add the specified types to the end of the specified type list and return success.

virtual ParseResult parseLParen()=0

Parse a ( token.

virtual ParseResult parseType(Type &result)=0

Parse a type.

virtual ParseResult parseOptionalLParen()=0

Parse a ( token if present.

ParseResult parseKeywordType(const char *keyword, Type &result)

Parse a keyword followed by a type.

ParseResult parseKeyword(StringRef keyword)

Parse a given keyword.

virtual ParseResult parseAttribute(Attribute &result, Type type={})=0

Parse an arbitrary attribute of a given type and return it in result.

virtual void printSymbolName(StringRef symbolRef)

Print the given string as a symbol reference, i.e. a form representable by a SymbolRefAttr.

Attributes are known-constant values of operations.

Block represents an ordered list of Operations.

ValueTypeRange< BlockArgListType > getArgumentTypes()

Return a range containing the types of the arguments for this block.

unsigned getNumArguments()

OpListType & getOperations()

IntegerAttr getI32IntegerAttr(int32_t value)

IntegerAttr getIntegerAttr(Type type, int64_t value)

ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)

FloatAttr getFloatAttr(Type type, double value)

FunctionType getFunctionType(TypeRange inputs, TypeRange results)

IntegerType getIntegerType(unsigned width)

BoolAttr getBoolAttr(bool value)

StringAttr getStringAttr(const Twine &bytes)

MLIRContext * getContext() const

ArrayAttr getArrayAttr(ArrayRef< Attribute > value)

Attr getAttr(Args &&...args)

Get or construct an instance of the attribute Attr with provided arguments.

static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)

Constructs a dense elements attribute from an array of element values.

static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)

Get an instance of a DenseFPElementsAttr with the given arguments.

Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...

A symbol reference with a reference path containing a single element.

This class represents a diagnostic that is inflight and set to be reported.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...

The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...

virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0

Parses a region.

virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0

Resolve an operand to an SSA value, emitting an error on failure.

virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0

Parses a region if present.

ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)

Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...

virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0

Parse an operation in its generic form.

virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0

Parse a single SSA value operand name along with a result number if allowResultNumber is true.

virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0

Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...

This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...

void printOperands(const ContainerType &container)

Print a comma separated list of operands.

virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0

If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...

virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0

If the specified operation has attributes, print out an attribute dictionary with their values.

virtual void printGenericOp(Operation *op, bool printOpName=true)=0

Print the entire operation with the default generic assembly form.

virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0

Prints a region.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

void setInsertionPointToEnd(Block *block)

Sets the insertion point to the end of the specified block.

Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)

Add new block with 'argTypes' arguments and set the insertion point to the end of it.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.

type_range getType() const

Operation is the basic unit of execution within MLIR.

Value getOperand(unsigned idx)

bool hasTrait()

Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().

Dialect * getDialect()

Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...

AttrClass getAttrOfType(StringAttr name)

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

Location getLoc()

The source location the operation was defined or derived from.

ArrayRef< NamedAttribute > getAttrs()

Return all of the attributes on this operation.

InFlightDiagnostic emitError(const Twine &message={})

Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...

operand_type_range getOperandTypes()

result_type_range getResultTypes()

operand_range getOperands()

Returns an iterator on the underlying Value's.

unsigned getNumResults()

Return the number of results held by this operation.

This class implements Optional functionality for ParseResult.

bool has_value() const

Returns true if we contain a valid ParseResult value.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

void push_back(Block *block)

This class allows for representing and managing the symbol table used by operations with the 'SymbolT...

static StringRef getSymbolAttrName()

Return the name of the attribute used for symbol names.

static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)

Returns the operation registered with the given symbol name within the closest parent operation of,...

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

Dialect & getDialect() const

Get the dialect this type is registered to.

Type front()

Return first type in the range.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

A utility result that is used to signal how to proceed with an ongoing walk: interrupt it, advance to the next element, or skip the current operation, region or block.

static WalkResult advance()

static PointerType get(Type pointeeType, StorageClass storageClass)

unsigned getNumElements() const

Type getElementType(unsigned) const

An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.

void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)

Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...

void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)

Walk all of the regions, blocks, or operations nested under (and including) the given operation.

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)

Return all of the attributes for the argument at 'index'.

ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)

Parses a function signature using parser.

void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})

Prints the list of function prefixed with the "attributes" keyword.

void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)

Prints the signature of the function-like operation op.

Operation::operand_range getIndices(Operation *op)

Get the indices that the given load/store operation is operating on.

QueryRef parse(llvm::StringRef line, const QuerySession &qs)

constexpr char kFnNameAttrName[]

constexpr char kSpecIdAttrName[]

LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)

ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())

Parses the next string attribute in parser as an enumerant of the given EnumClass.

void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)

AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)

Returns addressing model selected based on target environment.

FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)

Returns execution model selected based on target environment.

FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)

Returns memory model selected based on target environment.

LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)

ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)

Include the generated interface declarations.

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

LogicalResult verify(Operation *op, bool verifyRecursively=true)

Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...

This is the representation of an operand reference.

This represents an operation in an abstracted form, suitable for use with the builder APIs.

SmallVector< Value, 4 > operands

void addAttribute(StringRef name, Attribute attr)

Add an attribute with the specified name.

void addTypes(ArrayRef< Type > newTypes)

SmallVector< Type, 4 > types

Types of the results of this operation.

Region * addRegion()

Create a region that should be attached to the operation.