MLIR: lib/Dialect/SPIRV/IR/SPIRVOps.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
14
17
32 #include "llvm/ADT/APFloat.h"
33 #include "llvm/ADT/APInt.h"
34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/StringExtras.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 #include "llvm/Support/InterleavedRange.h"
39 #include
40 #include
41 #include
42 #include <type_traits>
43
44 using namespace mlir;
46
47
48
49
50
52 auto constOp = dyn_cast_or_nullspirv::ConstantOp(op);
53 if (!constOp) {
54 return failure();
55 }
56 auto valueAttr = constOp.getValue();
57 auto integerValueAttr = llvm::dyn_cast(valueAttr);
58 if (!integerValueAttr) {
59 return failure();
60 }
61
62 if (integerValueAttr.getType().isSignlessInteger())
63 value = integerValueAttr.getInt();
64 else
65 value = integerValueAttr.getSInt();
66
67 return success();
68 }
69
// Checks a memory-semantics bitmask and returns a LogicalResult.
// NOTE(review): this listing was extracted from rendered HTML; the line that
// carries the function name and its first parameter(s), and the line that
// opens the error-reporting call inside the `if` below, are both missing.
// Confirm against the upstream file.
70 LogicalResult
72 spirv::MemorySemantics memorySemantics) {
73
74
75
76
77
78
// Mask of the four ordering bits of which at most one may be set.
79 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
80 spirv::MemorySemantics::Release |
81 spirv::MemorySemantics::AcquireRelease |
82 spirv::MemorySemantics::SequentiallyConsistent;
83
// Count how many of those four bits are set in the incoming value.
84 auto bitCount =
85 llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
86 if (bitCount > 1) {
// NOTE(review): the concatenated message below has no space between
// "`Release`," and "`AcquireRelease`"; left untouched because it is runtime
// text that diagnostics checks may match.
88 "expected at most one of these four memory constraints "
89 "to be set: `Acquire`, `Release`,"
90 "`AcquireRelease` or `SequentiallyConsistent`");
91 }
// Zero or one ordering bit set: accepted.
92 return success();
93 }
94
97
99 stringifyDecoration(spirv::Decoration::DescriptorSet));
100 auto bindingName = llvm::convertToSnakeFromCamelCase(
101 stringifyDecoration(spirv::Decoration::Binding));
104 if (descriptorSet && binding) {
107 printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
108 << ")";
109 }
110
111
112 auto builtInName = llvm::convertToSnakeFromCamelCase(
113 stringifyDecoration(spirv::Decoration::BuiltIn));
114 if (auto builtin = op->getAttrOfType(builtInName)) {
115 printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
116 elidedAttrs.push_back(builtInName);
117 }
118
120 }
121
126
127
133 return failure();
134 auto fnType = llvm::dyn_cast(type);
135 if (!fnType) {
136 parser.emitError(loc, "expected function type");
137 return failure();
138 }
140 return failure();
141 result.addTypes(fnType.getResults());
142 return success();
143 }
149 }
150
152 assert(op->getNumResults() == 1 && "op should have one result");
153
154
155
158 [&](Type type) { return type != resultType; })) {
160 return;
161 }
162
163 p << ' ';
166
167 p << " : " << resultType;
168 }
169
170 template
173 auto valType = val.getType();
174 if (auto valVecTy = llvm::dyn_cast(valType))
175 valType = valVecTy.getElementType();
176
177 if (valType !=
178 llvm::castspirv::PointerType(ptr.getType()).getPointeeType()) {
179 return op.emitOpError("mismatch in result type and pointer type");
180 }
181 return success();
182 }
183
184
185
186
190 if (indices.empty()) {
191 emitErrorFn("expected at least one index for spirv.CompositeExtract");
192 return nullptr;
193 }
194
195 for (auto index : indices) {
196 if (auto cType = llvm::dyn_castspirv::CompositeType(type)) {
197 if (cType.hasCompileTimeKnownNumElements() &&
198 (index < 0 ||
199 static_cast<uint64_t>(index) >= cType.getNumElements())) {
200 emitErrorFn("index ") << index << " out of bounds for " << type;
201 return nullptr;
202 }
203 type = cType.getElementType(index);
204 } else {
205 emitErrorFn("cannot extract from non-composite type ")
206 << type << " with index " << index;
207 return nullptr;
208 }
209 }
210 return type;
211 }
212
216 auto indicesArrayAttr = llvm::dyn_cast(indices);
217 if (!indicesArrayAttr) {
218 emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
219 return nullptr;
220 }
221 if (indicesArrayAttr.empty()) {
222 emitErrorFn("expected at least one index for spirv.CompositeExtract");
223 return nullptr;
224 }
225
227 for (auto indexAttr : indicesArrayAttr) {
228 auto indexIntAttr = llvm::dyn_cast(indexAttr);
229 if (!indexIntAttr) {
230 emitErrorFn("expected an 32-bit integer for index, but found '")
231 << indexAttr << "'";
232 return nullptr;
233 }
234 indexVals.push_back(indexIntAttr.getInt());
235 }
237 }
238
242 };
244 }
245
247 SMLoc loc) {
249 return parser.emitError(loc, err);
250 };
252 }
253
254 template
256 auto resultType = llvm::castspirv::StructType(op.getType());
257 if (resultType.getNumElements() != 2)
258 return op.emitOpError("expected result struct type containing two members");
259
260 if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
261 resultType.getElementType(0),
262 resultType.getElementType(1)}))
263 return op.emitOpError(
264 "expected all operand types and struct member types are the same");
265
266 return success();
267 }
268
274 return failure();
275
276 Type resultType;
278 if (parser.parseType(resultType))
279 return failure();
280
281 auto structType = llvm::dyn_castspirv::StructType(resultType);
282 if (!structType || structType.getNumElements() != 2)
283 return parser.emitError(loc, "expected spirv.struct type with two members");
284
287 return failure();
288
290 return success();
291 }
292
295 printer << ' ';
299 }
300
303 return op->emitError("expected the same type for the first operand and "
304 "result, but provided ")
307 }
308 return success();
309 }
310
311
312
313
314
316 spirv::GlobalVariableOp var) {
318 }
319
321 auto varOp = dyn_cast_or_nullspirv::GlobalVariableOp(
323 getVariableAttr()));
324 if (!varOp) {
325 return emitOpError("expected spirv.GlobalVariable symbol");
326 }
327 if (getPointer().getType() != varOp.getType()) {
328 return emitOpError(
329 "result type mismatch with the referenced global variable's type");
330 }
331 return success();
332 }
333
334
335
336
337
339 operand_range constituents = this->getConstituents();
340
341
342
343
344
345
346
347 auto coopElementType =
350 [](auto coopType) { return coopType.getElementType(); })
351 .Default([](Type) { return nullptr; });
352
353
354 if (coopElementType) {
355 if (constituents.size() != 1)
356 return emitOpError("has incorrect number of operands: expected ")
357 << "1, but provided " << constituents.size();
358 if (coopElementType != constituents.front().getType())
359 return emitOpError("operand type mismatch: expected operand type ")
360 << coopElementType << ", but provided "
361 << constituents.front().getType();
362 return success();
363 }
364
365
366 auto cType = llvm::castspirv::CompositeType(getType());
367 if (constituents.size() == cType.getNumElements()) {
368 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
369 if (constituents[index].getType() != cType.getElementType(index)) {
370 return emitOpError("operand type mismatch: expected operand type ")
371 << cType.getElementType(index) << ", but provided "
372 << constituents[index].getType();
373 }
374 }
375 return success();
376 }
377
378
379 auto resultType = llvm::dyn_cast(cType);
380 if (!resultType)
381 return emitOpError(
382 "expected to return a vector or cooperative matrix when the number of "
383 "constituents is less than what the result needs");
384
386 for (Value component : constituents) {
387 if (!llvm::isa(component.getType()) &&
388 !component.getType().isIntOrFloat())
389 return emitOpError("operand type mismatch: expected operand to have "
390 "a scalar or vector type, but provided ")
391 << component.getType();
392
393 Type elementType = component.getType();
394 if (auto vectorType = llvm::dyn_cast(component.getType())) {
395 sizes.push_back(vectorType.getNumElements());
396 elementType = vectorType.getElementType();
397 } else {
398 sizes.push_back(1);
399 }
400
401 if (elementType != resultType.getElementType())
402 return emitOpError("operand element type mismatch: expected to be ")
403 << resultType.getElementType() << ", but provided " << elementType;
404 }
405 unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
406 if (totalCount != cType.getNumElements())
407 return emitOpError("has incorrect number of operands: expected ")
408 << cType.getNumElements() << ", but provided " << totalCount;
409 return success();
410 }
411
412
413
414
415
420 auto elementType =
422 if (!elementType) {
423 return;
424 }
425 build(builder, state, elementType, composite, indexAttr);
426 }
427
432 StringRef indicesAttrName =
433 spirv::CompositeExtractOp::getIndicesAttrName(result.name);
434 Type compositeType;
435 SMLoc attrLocation;
436
442 return failure();
443 }
444
445 Type resultType =
446 getElementType(compositeType, indicesAttr, parser, attrLocation);
447 if (!resultType) {
448 return failure();
449 }
451 return success();
452 }
453
455 printer << ' ' << getComposite() << getIndices() << " : "
456 << getComposite().getType();
457 }
458
460 auto indicesArrayAttr = llvm::dyn_cast(getIndices());
461 auto resultType =
463 if (!resultType)
464 return failure();
465
466 if (resultType != getType()) {
467 return emitOpError("invalid result type: expected ")
468 << resultType << " but provided " << getType();
469 }
470
471 return success();
472 }
473
474
475
476
477
482 build(builder, state, composite.getType(), object, composite, indexAttr);
483 }
484
488 Type objectType, compositeType;
490 StringRef indicesAttrName =
491 spirv::CompositeInsertOp::getIndicesAttrName(result.name);
493
494 return failure(
499 parser.resolveOperands(operands, {objectType, compositeType}, loc,
502 }
503
505 auto indicesArrayAttr = llvm::dyn_cast(getIndices());
506 auto objectType =
508 if (!objectType)
509 return failure();
510
511 if (objectType != getObject().getType()) {
512 return emitOpError("object operand type should be ")
513 << objectType << ", but found " << getObject().getType();
514 }
515
517 return emitOpError("result type should be the same as "
518 "the composite type, but found ")
519 << getComposite().getType() << " vs " << getType();
520 }
521
522 return success();
523 }
524
526 printer << " " << getObject() << ", " << getComposite() << getIndices()
527 << " : " << getObject().getType() << " into "
528 << getComposite().getType();
529 }
530
531
532
533
534
538 StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name);
540 return failure();
541
543 if (auto typedAttr = llvm::dyn_cast(value))
544 type = typedAttr.getType();
545 if (llvm::isa<NoneType, TensorType>(type)) {
547 return failure();
548 }
549
551 }
552
554 printer << ' ' << getValue();
555 if (llvm::isaspirv::ArrayType(getType()))
556 printer << " : " << getType();
557 }
558
560 Type opType) {
561 if (isaspirv::CooperativeMatrixType(opType)) {
562 auto denseAttr = dyn_cast(value);
563 if (!denseAttr || !denseAttr.isSplat())
564 return op.emitOpError("expected a splat dense attribute for cooperative "
565 "matrix constant, but found ")
566 << denseAttr;
567 }
568 if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
569 auto valueType = llvm::cast(value).getType();
570 if (valueType != opType)
571 return op.emitOpError("result type (")
572 << opType << ") does not match value type (" << valueType << ")";
573 return success();
574 }
575 if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
576 auto valueType = llvm::cast(value).getType();
577 if (valueType == opType)
578 return success();
579 auto arrayType = llvm::dyn_castspirv::ArrayType(opType);
580 auto shapedType = llvm::dyn_cast(valueType);
581 if (!arrayType)
582 return op.emitOpError("result or element type (")
583 << opType << ") does not match value type (" << valueType
584 << "), must be the same or spirv.array";
585
586 int numElements = arrayType.getNumElements();
587 auto opElemType = arrayType.getElementType();
588 while (auto t = llvm::dyn_castspirv::ArrayType(opElemType)) {
589 numElements *= t.getNumElements();
590 opElemType = t.getElementType();
591 }
592 if (!opElemType.isIntOrFloat())
593 return op.emitOpError("only support nested array result type");
594
595 auto valueElemType = shapedType.getElementType();
596 if (valueElemType != opElemType) {
597 return op.emitOpError("result element type (")
598 << opElemType << ") does not match value element type ("
599 << valueElemType << ")";
600 }
601
602 if (numElements != shapedType.getNumElements()) {
603 return op.emitOpError("result number of elements (")
604 << numElements << ") does not match value number of elements ("
605 << shapedType.getNumElements() << ")";
606 }
607 return success();
608 }
609 if (auto arrayAttr = llvm::dyn_cast(value)) {
610 auto arrayType = llvm::dyn_castspirv::ArrayType(opType);
611 if (!arrayType)
612 return op.emitOpError(
613 "must have spirv.array result type for array value");
614 Type elemType = arrayType.getElementType();
615 for (Attribute element : arrayAttr.getValue()) {
616
618 return failure();
619 }
620 return success();
621 }
622 return op.emitOpError("cannot have attribute: ") << value;
623 }
624
626
627
628
630 }
631
// Returns true when `type` passes the checks below, false otherwise.
// NOTE(review): this listing was extracted from rendered HTML and the
// template arguments of the cast helpers were lost (e.g.
// `llvm::isaspirv::SPIRVType(type)` is presumably
// `llvm::isa<spirv::SPIRVType>(type)`); confirm against the upstream file.
632 bool spirv::ConstantOp::isBuildableWith(Type type) {
633
// A type that fails the SPIRVType check is rejected outright.
634 if (!llvm::isaspirv::SPIRVType(type))
635 return false;
636
// For types whose dialect matches the one tested here, only spirv.array
// types are accepted.
// NOTE(review): the dialect argument of `isa` is missing from this listing --
// presumably the SPIR-V dialect; confirm.
637 if (isa(type.getDialect())) {
638
639 return llvm::isaspirv::ArrayType(type);
640 }
641
// Any other type that passed the first check is accepted.
642 return true;
643 }
644
647 if (auto intType = llvm::dyn_cast(type)) {
648 unsigned width = intType.getWidth();
649 if (width == 1)
650 return builder.createspirv::ConstantOp(loc, type,
652 return builder.createspirv::ConstantOp(
653 loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
654 }
655 if (auto floatType = llvm::dyn_cast(type)) {
656 return builder.createspirv::ConstantOp(
657 loc, type, builder.getFloatAttr(floatType, 0.0));
658 }
659 if (auto vectorType = llvm::dyn_cast(type)) {
660 Type elemType = vectorType.getElementType();
661 if (llvm::isa(elemType)) {
662 return builder.createspirv::ConstantOp(
663 loc, type,
666 }
667 if (llvm::isa(elemType)) {
668 return builder.createspirv::ConstantOp(
669 loc, type,
672 }
673 }
674
675 llvm_unreachable("unimplemented types for ConstantOp::getZero()");
676 }
677
// Creates and returns a spirv::ConstantOp of the given `type` at `loc`.
// The integer branch uses APInt(width, 1) and the float branch uses 1.0.
// Types not matched by any branch reach llvm_unreachable.
// NOTE(review): this listing was extracted from rendered HTML. The second
// line of the signature (which must declare `builder`), the attribute used in
// the `width == 1` case, and the attribute arguments of both vector branches
// are missing, and the template arguments of `dyn_cast`/`isa`/`create` were
// lost. Confirm against the upstream file.
678 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
// Integer-like branch (it reads getWidth()); width 1 is special-cased.
680 if (auto intType = llvm::dyn_cast(type)) {
681 unsigned width = intType.getWidth();
682 if (width == 1)
683 return builder.createspirv::ConstantOp(loc, type,
685 return builder.createspirv::ConstantOp(
686 loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
687 }
// Floating-point branch: builds a float attribute holding 1.0.
688 if (auto floatType = llvm::dyn_cast(type)) {
689 return builder.createspirv::ConstantOp(
690 loc, type, builder.getFloatAttr(floatType, 1.0));
691 }
// Vector branch: dispatches on the element type. The attribute arguments are
// not visible here -- presumably splat attributes of one; TODO confirm.
692 if (auto vectorType = llvm::dyn_cast(type)) {
693 Type elemType = vectorType.getElementType();
694 if (llvm::isa(elemType)) {
695 return builder.createspirv::ConstantOp(
696 loc, type,
699 }
700 if (llvm::isa(elemType)) {
701 return builder.createspirv::ConstantOp(
702 loc, type,
705 }
706 }
707
708 llvm_unreachable("unimplemented types for ConstantOp::getOne()");
709 }
710
// Builds a name for this op's result and passes it to `setNameFn`.
// The name starts with "cst" and is extended with the integer value, the
// type, and vector shape details, as the branches below show.
// NOTE(review): this listing was extracted from rendered HTML. The rest of
// the signature (declaring `setNameFn`), the declarations of `type` and
// `specialNameBuffer`, and the template arguments of `dyn_cast`/`isa` are
// missing. Confirm against the upstream file.
711 void mlir::spirv::ConstantOp::getAsmResultNames(
714
716 llvm::raw_svector_ostream specialName(specialNameBuffer);
717 specialName << "cst";
718
719 IntegerType intTy = llvm::dyn_cast(type);
720
721 if (IntegerAttr intCst = llvm::dyn_cast(getValue())) {
// 1-bit integers are named "true"/"false" directly, with no "cst" prefix.
722 if (intTy && intTy.getWidth() == 1) {
723 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
724 }
725
// NOTE(review): `intTy` is null-checked on the line above but dereferenced
// here unconditionally; this relies on an integer attribute always coming
// with an integer result type -- confirm that the verifier guarantees it.
// The accessor is chosen by signedness: signless, unsigned, then signed.
726 if (intTy.isSignless()) {
727 specialName << intCst.getInt();
728 } else if (intTy.isUnsigned()) {
729 specialName << intCst.getUInt();
730 } else {
731 specialName << intCst.getSInt();
732 }
733 }
734
// Append "_<type>" for integer results and for one further kind of type.
// NOTE(review): that `isa` argument is lost -- presumably a float type.
735 if (intTy || llvm::isa(type)) {
736 specialName << '_' << type;
737 }
738
// Vector-like results get "_vec_<dim 0 size>", plus "x<element type>" when
// the element type matches one of the two `isa` checks (arguments lost).
739 if (auto vecType = llvm::dyn_cast(type)) {
740 specialName << "_vec_";
741 specialName << vecType.getDimSize(0);
742
743 Type elementType = vecType.getElementType();
744
745 if (llvm::isa(elementType) ||
746 llvm::isa(elementType)) {
747 specialName << "x" << elementType;
748 }
749 }
750
751 setNameFn(getResult(), specialName.str());
752 }
753
// Names this op's result "<variable>_addr", where <variable> is whatever
// getVariable() streams, and passes that name to `setNameFn`.
// NOTE(review): this listing was extracted from rendered HTML; the rest of
// the signature (declaring `setNameFn`) and the declaration of
// `specialNameBuffer` are missing. Confirm against the upstream file.
754 void mlir::spirv::AddressOfOp::getAsmResultNames(
757 llvm::raw_svector_ostream specialName(specialNameBuffer);
758 specialName << getVariable() << "_addr";
759 setNameFn(getResult(), specialName.str());
760 }
761
762
763
764
765
768 }
769
770
771
772
773
775 spirv::ExecutionModel executionModel,
776 spirv::FuncOp function,
778 build(builder, state,
781 }
782
785 spirv::ExecutionModel execModel;
787
789 if (parseEnumStrAttrspirv::ExecutionModelAttr(execModel, parser, result) ||
791 return failure();
792 }
793
795
797
798 FlatSymbolRefAttr var;
799 NamedAttrList attrs;
800 if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
801 return failure();
802 interfaceVars.push_back(var);
803 return success();
804 }))
805 return failure();
806 }
807 result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),
809 return success();
810 }
811
813 printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
815 auto interfaceVars = getInterface().getValue();
816 if (!interfaceVars.empty())
817 printer << ", " << llvm::interleaved(interfaceVars);
818 }
819
821
822
823 return success();
824 }
825
826
827
828
829
831 spirv::FuncOp function,
832 spirv::ExecutionMode executionMode,
837 }
838
841 spirv::ExecutionMode execMode;
844 parseEnumStrAttrspirv::ExecutionModeAttr(execMode, parser, result)) {
845 return failure();
846 }
847
853 if (parser.parseAttribute(value, i32Type, "value", attr)) {
854 return failure();
855 }
856 values.push_back(llvm::cast(value).getInt());
857 }
858 StringRef valuesAttrName =
859 spirv::ExecutionModeOp::getValuesAttrName(result.name);
862 return success();
863 }
864
866 printer << " ";
868 printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
869 ArrayAttr values = this->getValues();
870 if (!values.empty())
871 printer << ", " << llvm::interleaved(values.getAsValueRange());
872 }
873
874
875
876
877
883
884
885 StringAttr nameAttr;
888 return failure();
889
890
891 bool isVariadic = false;
893 parser, false, entryArgs, isVariadic, resultTypes,
894 resultAttrs))
895 return failure();
896
898 for (auto &arg : entryArgs)
899 argTypes.push_back(arg.type);
900 auto fnType = builder.getFunctionType(argTypes, resultTypes);
903
904
905 spirv::FunctionControl fnControl;
906 if (parseEnumStrAttrspirv::FunctionControlAttr(fnControl, parser, result))
907 return failure();
908
909
911 return failure();
912
913
914 assert(resultAttrs.size() == resultTypes.size());
916 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
917 getResAttrsAttrName(result.name));
918
919
923 return failure(parseResult.has_value() && failed(*parseResult));
924 }
925
927
928 printer << " ";
930 auto fnType = getFunctionType();
932 printer, *this, fnType.getInputs(),
933 false, fnType.getResults());
934 printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
935 << "\"";
937 printer, *this,
938 {spirv::attributeNamespirv::FunctionControl(),
939 getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
940 getFunctionControlAttrName()});
941
942
943 Region &body = this->getBody();
944 if (!body.empty()) {
945 printer << ' ';
946 printer.printRegion(body, false,
947 true);
948 }
949 }
950
// Verifies the function's signature.
// It rejects more than one result, and it requires aliasing decorations on
// pointer arguments that involve the PhysicalStorageBuffer storage class.
// Returns success() when every argument passes, otherwise an op error.
// NOTE(review): this listing was extracted from rendered HTML and the
// template arguments of `cast`/`dyn_cast` were lost (e.g.
// `dyn_castspirv::PointerType(param)` is presumably
// `dyn_cast<spirv::PointerType>(param)`); confirm against the upstream file.
951 LogicalResult spirv::FuncOp::verifyType() {
952 FunctionType fnType = getFunctionType();
953 if (fnType.getNumResults() > 1)
954 return emitOpError("cannot have more than one result");
955
// Returns whether argument `argIndex` carries `decoration`. Only attributes
// named spirv::DecorationAttr::name are considered, and the answer comes
// from the first such attribute whose value is a DecorationAttr; if none is
// found the result is false.
956 auto hasDecorationAttr = [&](spirv::Decoration decoration,
957 unsigned argIndex) {
958 auto func = llvm::cast(getOperation());
959 for (auto argAttr : cast(func).getArgAttrs(argIndex)) {
960 if (argAttr.getName() != spirv::DecorationAttr::name)
961 continue;
962 if (auto decAttr = dyn_castspirv::DecorationAttr(argAttr.getValue()))
963 return decAttr.getValue() == decoration;
964 }
965 return false;
966 };
967
968 for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
969 Type param = fnType.getInputs()[i];
// Non-pointer arguments need no decoration check.
970 auto inputPtrType = dyn_castspirv::PointerType(param);
971 if (!inputPtrType)
972 continue;
973
// Case 1: the argument is a pointer to a pointer. If the inner pointer is in
// PhysicalStorageBuffer, the argument must carry AliasedPointer or
// RestrictPointer.
974 auto pointeePtrType =
975 dyn_castspirv::PointerType(inputPtrType.getPointeeType());
976 if (pointeePtrType) {
977
978
979
980
981
982 if (pointeePtrType.getStorageClass() !=
983 spirv::StorageClass::PhysicalStorageBuffer)
984 continue;
985
986 bool hasAliasedPtr =
987 hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
988 bool hasRestrictPtr =
989 hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
990 if (!hasAliasedPtr && !hasRestrictPtr)
991 return emitOpError()
992 << "with a pointer points to a physical buffer pointer must "
993 "be decorated either 'AliasedPointer' or 'RestrictPointer'";
994 continue;
995 }
996
997
998
999
// Case 2: choose the pointer whose storage class is checked. For a pointer
// to an array it is the array's element type, if that is a pointer (null
// otherwise); in every other case it is the argument pointer itself.
1000 if (auto pointeeArrayType =
1001 dyn_castspirv::ArrayType(inputPtrType.getPointeeType())) {
1002 pointeePtrType =
1003 dyn_castspirv::PointerType(pointeeArrayType.getElementType());
1004 } else {
1005 pointeePtrType = inputPtrType;
1006 }
1007
1008 if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1009 spirv::StorageClass::PhysicalStorageBuffer)
1010 continue;
1011
// A PhysicalStorageBuffer pointer here must carry Aliased or Restrict.
1012 bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1013 bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1014 if (!hasAliased && !hasRestrict)
1015 return emitOpError() << "with physical buffer pointer must be decorated "
1016 "either 'Aliased' or 'Restrict'";
1017 }
1018
1019 return success();
1020 }
1021
// Verifies the function body against its signature.
// For non-external functions the entry block's arguments must match the
// signature. The return ops visited by the callback below must agree with the
// function's result count and result type.
// Returns failure when the walk was interrupted, i.e. a return op was
// rejected.
// NOTE(review): this listing was extracted from rendered HTML. The condition
// guarding the "entry block must have" error, the range expression of the
// `for` loop, the line that starts the walk and declares `walkResult`/`op`,
// and the callback's final statement are missing; template arguments of
// `dyn_cast` were lost. Confirm against the upstream file.
1022 LogicalResult spirv::FuncOp::verifyBody() {
1023 FunctionType fnType = getFunctionType();
1024 if (!isExternal()) {
1025 Block &entryBlock = front();
1026
1027 unsigned numArguments = this->getNumArguments();
1029 return emitOpError("entry block must have ")
1030 << numArguments << " arguments to match function signature";
1031
// Pairwise comparison of signature argument types and block argument types.
1032 for (auto [index, fnArgType, blockArgType] :
1034 if (blockArgType != fnArgType) {
1035 return emitOpError("type of entry block argument #")
1036 << index << '(' << blockArgType
1037 << ") must match the type of the corresponding argument in "
1038 << "function signature(" << fnArgType << ')';
1039 }
1040 }
1041 }
1042
// spirv.Return is only valid when the function has no results.
1044 if (auto retOp = dyn_castspirv::ReturnOp(op)) {
1045 if (fnType.getNumResults() != 0)
1046 return retOp.emitOpError("cannot be used in functions returning value");
// spirv.ReturnValue needs exactly one function result, of the operand's type.
1047 } else if (auto retOp = dyn_castspirv::ReturnValueOp(op)) {
1048 if (fnType.getNumResults() != 1)
1049 return retOp.emitOpError(
1050 "returns 1 value but enclosing function requires ")
1051 << fnType.getNumResults() << " results";
1052
1053 auto retOperandType = retOp.getValue().getType();
1054 auto fnResultType = fnType.getResult(0);
1055 if (retOperandType != fnResultType)
1056 return retOp.emitOpError(" return value's type (")
1057 << retOperandType << ") mismatch with function's result type ("
1058 << fnResultType << ")";
1059 }
1061 });
1062
1063
1064
1065 return failure(walkResult.wasInterrupted());
1066 }
1067
1069 StringRef name, FunctionType type,
1070 spirv::FunctionControl control,
1074 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
1075 state.addAttribute(spirv::attributeNamespirv::FunctionControl(),
1076 builder.getAttrspirv::FunctionControlAttr(control));
1077 state.attributes.append(attrs.begin(), attrs.end());
1078 state.addRegion();
1079 }
1080
1081
1082
1083
1084
1088 }
1090
1091
1092
1093
1094
1098 }
1100
1101
1102
1103
1104
1108 }
1110
1111
1112
1113
1114
1117 }
1119
1120
1121
1122
1123
1125 Type type, StringRef name,
1126 unsigned descriptorSet, unsigned binding) {
1128 state.addAttribute(
1129 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1131 state.addAttribute(
1132 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1134 }
1135
1137 Type type, StringRef name,
1138 spirv::BuiltIn builtin) {
1140 state.addAttribute(
1141 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1142 builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
1143 }
1144
1147
1148 StringAttr nameAttr;
1149 StringRef initializerAttrName =
1150 spirv::GlobalVariableOp::getInitializerAttrName(result.name);
1153 return failure();
1154 }
1155
1156
1163 return failure();
1164 }
1165
1167 return failure();
1168 }
1169
1171 StringRef typeAttrName =
1172 spirv::GlobalVariableOp::getTypeAttrName(result.name);
1175 return failure();
1176 }
1177 if (!llvm::isaspirv::PointerType(type)) {
1178 return parser.emitError(loc, "expected spirv.ptr type");
1179 }
1181
1182 return success();
1183 }
1184
1187 spirv::attributeNamespirv::StorageClass()};
1188
1189
1190 printer << ' ';
1193
1194 StringRef initializerAttrName = this->getInitializerAttrName();
1195
1196 if (auto initializer = this->getInitializer()) {
1197 printer << " " << initializerAttrName << '(';
1199 printer << ')';
1200 elidedAttrs.push_back(initializerAttrName);
1201 }
1202
1203 StringRef typeAttrName = this->getTypeAttrName();
1204 elidedAttrs.push_back(typeAttrName);
1206 printer << " : " << getType();
1207 }
1208
1210 if (!llvm::isaspirv::PointerType(getType()))
1211 return emitOpError("result must be of a !spv.ptr type");
1212
1213
1214
1215
1216
1217 auto storageClass = this->storageClass();
1218 if (storageClass == spirv::StorageClass::Generic ||
1219 storageClass == spirv::StorageClass::Function) {
1220 return emitOpError("storage class cannot be '")
1221 << stringifyStorageClass(storageClass) << "'";
1222 }
1223
1225 this->getInitializerAttrName())) {
1227 (*this)->getParentOp(), init.getAttr());
1228
1229
1230
1231 if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
1232 spirv::SpecConstantCompositeOp>(initOp)) {
1233 return emitOpError("initializer must be result of a "
1234 "spirv.SpecConstant or spirv.GlobalVariable or "
1235 "spirv.SpecConstantCompositeOp op");
1236 }
1237 }
1238
1239 return success();
1240 }
1241
1242
1243
1244
1245
1248 return failure();
1249
1250 return success();
1251 }
1252
1253
1254
1255
1256
1259
1260 spirv::StorageClass storageClass;
1263 Type elementType;
1266 parser.parseType(elementType)) {
1267 return failure();
1268 }
1269
1271 if (auto valVecTy = llvm::dyn_cast(elementType))
1273
1274 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
1276 return failure();
1277 }
1278 return success();
1279 }
1280
1282 printer << " " << getPtr() << ", " << getValue() << " : "
1283 << getValue().getType();
1284 }
1285
1288 return failure();
1289
1290 return success();
1291 }
1292
1293
1294
1295
1296
1299 }
1300
1304 }
1305
1308 }
1309
1310
1311
1312
1313
1316 }
1317
1321 }
1322
1325 }
1326
1327
1328
1329
1330
1333 }
1334
1338 }
1339
1342 }
1343
1344
1345
1346
1347
1350 }
1351
1355 }
1356
1359 }
1360
1361
1362
1363
1364
1367 }
1368
1369
1370
1371
1372
1374 std::optional name) {
1377 if (name) {
1380 }
1381 }
1382
1384 spirv::AddressingModel addressingModel,
1385 spirv::MemoryModel memoryModel,
1386 std::optional vceTriple,
1387 std::optional name) {
1388 state.addAttribute(
1389 "addressing_model",
1390 builder.getAttrspirv::AddressingModelAttr(addressingModel));
1391 state.addAttribute("memory_model",
1392 builder.getAttrspirv::MemoryModelAttr(memoryModel));
1395 if (vceTriple)
1396 state.addAttribute(getVCETripleAttrName(), *vceTriple);
1397 if (name)
1400 }
1401
1405
1406
1407 StringAttr nameAttr;
1410
1411
1412 spirv::AddressingModel addrModel;
1413 spirv::MemoryModel memoryModel;
1414 if (spirv::parseEnumKeywordAttrspirv::AddressingModelAttr(addrModel, parser,
1415 result) ||
1416 spirv::parseEnumKeywordAttrspirv::MemoryModelAttr(memoryModel, parser,
1417 result))
1418 return failure();
1419
1423 spirv::ModuleOp::getVCETripleAttrName(),
1425 return failure();
1426 }
1427
1429 parser.parseRegion(*body, {}))
1430 return failure();
1431
1432
1433 if (body->empty())
1435
1436 return success();
1437 }
1438
1440 if (std::optional name = getName()) {
1441 printer << ' ';
1443 }
1444
1446
1447 printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
1449 auto addressingModelAttrName = spirv::attributeNamespirv::AddressingModel();
1450 auto memoryModelAttrName = spirv::attributeNamespirv::MemoryModel();
1451 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1453
1454 if (std::optionalspirv::VerCapExtAttr triple = getVceTriple()) {
1455 printer << " requires " << *triple;
1456 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1457 }
1458
1460 printer << ' ';
1462 }
1463
// Verifies the ops directly inside the module body.
// Entry points must name an existing spirv::FuncOp, their interface entries
// must be symbol references to spirv::GlobalVariableOp, and each
// (function, execution model) pair may appear only once.
// External functions must have 'Import' linkage.
// Two further diagnostics report ops that are not spirv.* ops, at module
// level and inside functions.
// NOTE(review): this listing was extracted from rendered HTML. The type of
// `entryPoints`, the declaration of `table`, both conditions guarding the
// "can only contain spirv.* ops" errors, and the call that opens the
// external-function error are missing; template arguments of
// `dyn_cast`/`lookup` were lost. Confirm against the upstream file.
1464 LogicalResult spirv::ModuleOp::verifyRegions() {
1465 Dialect *dialect = (*this)->getDialect();
1467 entryPoints;
1469
1470 for (auto &op : *getBody()) {
1472 return op.emitError("'spirv.module' can only contain spirv.* ops");
1473
1474
1475
1476
1477 if (auto entryPointOp = dyn_castspirv::EntryPointOp(op)) {
// The entry point must refer to a function found through `table`.
1478 auto funcOp = table.lookupspirv::FuncOp(entryPointOp.getFn());
1479 if (!funcOp) {
1480 return entryPointOp.emitError("function '")
1481 << entryPointOp.getFn() << "' not found in 'spirv.module'";
1482 }
// Each interface entry must be a symbol reference that resolves to a
// spirv::GlobalVariableOp.
1483 if (auto interface = entryPointOp.getInterface()) {
1484 for (Attribute varRef : interface) {
1485 auto varSymRef = llvm::dyn_cast(varRef);
1486 if (!varSymRef) {
1487 return entryPointOp.emitError(
1488 "expected symbol reference for interface "
1489 "specification instead of '")
1490 << varRef;
1491 }
1492 auto variableOp =
1493 table.lookupspirv::GlobalVariableOp(varSymRef.getValue());
1494 if (!variableOp) {
1495 return entryPointOp.emitError("expected spirv.GlobalVariable "
1496 "symbol reference instead of'")
1497 << varSymRef << "'";
1498 }
1499 }
1500 }
1501
// try_emplace reports whether the (function, execution model) key is new; a
// repeated key means a duplicate entry point.
1502 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1503 funcOp, entryPointOp.getExecutionModel());
1504 if (!entryPoints.try_emplace(key, entryPointOp).second)
1505 return entryPointOp.emitError("duplicate of a previous EntryPointOp");
1506 } else if (auto funcOp = dyn_castspirv::FuncOp(op)) {
1507
1508
1509
// External functions are accepted only with linkage type Import.
1510 auto linkageAttr = funcOp.getLinkageAttributes();
1511 auto hasImportLinkage =
1512 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1513 spirv::LinkageType::Import);
1514 if (funcOp.isExternal() && !hasImportLinkage)
1516 "'spirv.module' cannot contain external functions "
1517 "without 'Import' linkage_attributes (LinkageAttributes)");
1518
1519
// NOTE(review): the inner loop variable `op` shadows the outer `op`.
1520 for (auto &block : funcOp)
1521 for (auto &op : block) {
1524 "functions in 'spirv.module' can only contain spirv.* ops");
1525 }
1526 }
1527 }
1528
1529 return success();
1530 }
1531
1532
1533
1534
1535
1538 (*this)->getParentOp(), getSpecConstAttr());
1539 Type constType;
1540
1541 auto specConstOp = dyn_cast_or_nullspirv::SpecConstantOp(specConstSym);
1542 if (specConstOp)
1543 constType = specConstOp.getDefaultValue().getType();
1544
1545 auto specConstCompositeOp =
1546 dyn_cast_or_nullspirv::SpecConstantCompositeOp(specConstSym);
1547 if (specConstCompositeOp)
1548 constType = specConstCompositeOp.getType();
1549
1550 if (!specConstOp && !specConstCompositeOp)
1551 return emitOpError(
1552 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1553
1554 if (getReference().getType() != constType)
1555 return emitOpError("result type mismatch with the referenced "
1556 "specialization constant's type");
1557
1558 return success();
1559 }
1560
1561
1562
1563
1564
1567 StringAttr nameAttr;
1569 StringRef defaultValueAttrName =
1570 spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
1571
1574 return failure();
1575
1576
1578 IntegerAttr specIdAttr;
1582 return failure();
1583 }
1584
1587 return failure();
1588
1589 return success();
1590 }
1591
1593 printer << ' ';
1595 if (auto specID = (*this)->getAttrOfType(kSpecIdAttrName))
1596 printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
1597 printer << " = " << getDefaultValue();
1598 }
1599
1601 if (auto specID = (*this)->getAttrOfType(kSpecIdAttrName))
1602 if (specID.getValue().isNegative())
1603 return emitOpError("SpecId cannot be negative");
1604
1605 auto value = getDefaultValue();
1606 if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
1607
1608 if (!llvm::isaspirv::SPIRVType(value.getType()))
1609 return emitOpError("default value bitwidth disallowed");
1610 return success();
1611 }
1612 return emitOpError(
1613 "default value can only be a bool, integer, or float scalar");
1614 }
1615
1616
1617
1618
1619
1621 VectorType resultType = llvm::cast(getType());
1622
1623 size_t numResultElements = resultType.getNumElements();
1624 if (numResultElements != getComponents().size())
1625 return emitOpError("result type element count (")
1626 << numResultElements
1627 << ") mismatch with the number of component selectors ("
1628 << getComponents().size() << ")";
1629
1630 size_t totalSrcElements =
1631 llvm::cast(getVector1().getType()).getNumElements() +
1632 llvm::cast(getVector2().getType()).getNumElements();
1633
1634 for (const auto &selector : getComponents().getAsValueRange()) {
1635 uint32_t index = selector.getZExtValue();
1636 if (index >= totalSrcElements &&
1637 index != std::numeric_limits<uint32_t>().max())
1638 return emitOpError("component selector ")
1639 << index << " out of range: expected to be in [0, "
1640 << totalSrcElements << ") or 0xffffffff";
1641 }
1642 return success();
1643 }
1644
1645
1646
1647
1648
1650 Type elementType =
1653 [](auto matrixType) { return matrixType.getElementType(); })
1654 .Default([](Type) { return nullptr; });
1655
1656 assert(elementType && "Unhandled type");
1657
1658
1659 if (getScalar().getType() != elementType)
1660 return emitOpError("input matrix components' type and scaling value must "
1661 "have the same type");
1662
1663 return success();
1664 }
1665
1666
1667
1668
1669
1671 auto inputMatrix = llvm::castspirv::MatrixType(getMatrix().getType());
1672 auto resultMatrix = llvm::castspirv::MatrixType(getResult().getType());
1673
1674
1675 if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1676 return emitError("input matrix rows count must be equal to "
1677 "output matrix columns count");
1678
1679 if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1680 return emitError("input matrix columns count must be equal to "
1681 "output matrix rows count");
1682
1683
1684 if (inputMatrix.getElementType() != resultMatrix.getElementType())
1685 return emitError("input and output matrices must have the same "
1686 "component type");
1687
1688 return success();
1689 }
1690
1691
1692
1693
1694
1696 auto matrixType = llvm::castspirv::MatrixType(getMatrix().getType());
1697 auto vectorType = llvm::cast(getVector().getType());
1698 auto resultType = llvm::cast(getType());
1699
1700 if (matrixType.getNumColumns() != vectorType.getNumElements())
1701 return emitOpError("matrix columns (")
1702 << matrixType.getNumColumns() << ") must match vector operand size ("
1703 << vectorType.getNumElements() << ")";
1704
1705 if (resultType.getNumElements() != matrixType.getNumRows())
1706 return emitOpError("result size (")
1707 << resultType.getNumElements() << ") must match the matrix rows ("
1708 << matrixType.getNumRows() << ")";
1709
1710 if (matrixType.getElementType() != resultType.getElementType())
1711 return emitOpError("matrix and result element types must match");
1712
1713 return success();
1714 }
1715
1716
1717
1718
1719
1721 auto vectorType = llvm::cast(getVector().getType());
1722 auto matrixType = llvm::castspirv::MatrixType(getMatrix().getType());
1723 auto resultType = llvm::cast(getType());
1724
1725 if (matrixType.getNumRows() != vectorType.getNumElements())
1726 return emitOpError("number of components in vector must equal the number "
1727 "of components in each column in matrix");
1728
1729 if (resultType.getNumElements() != matrixType.getNumColumns())
1730 return emitOpError("number of columns in matrix must equal the number of "
1731 "components in result");
1732
1733 if (matrixType.getElementType() != resultType.getElementType())
1734 return emitOpError("matrix must be a matrix with the same component type "
1735 "as the component type in result");
1736
1737 return success();
1738 }
1739
1740
1741
1742
1743
1745 auto leftMatrix = llvm::castspirv::MatrixType(getLeftmatrix().getType());
1746 auto rightMatrix = llvm::castspirv::MatrixType(getRightmatrix().getType());
1747 auto resultMatrix = llvm::castspirv::MatrixType(getResult().getType());
1748
1749
1750 if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1751 return emitError("left matrix columns' count must be equal to "
1752 "the right matrix rows' count");
1753
1754
1755 if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1757 "right and result matrices must have equal columns' count");
1758
1759
1760 if (rightMatrix.getElementType() != resultMatrix.getElementType())
1761 return emitError("right and result matrices' component type must"
1762 " be the same");
1763
1764
1765 if (leftMatrix.getElementType() != resultMatrix.getElementType())
1766 return emitError("left and result matrices' component type"
1767 " must be the same");
1768
1769
1770 if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1771 return emitError("left and result matrices must have equal rows' count");
1772
1773 return success();
1774 }
1775
1776
1777
1778
1779
1782
1783 StringAttr compositeName;
1786 return failure();
1787
1789 return failure();
1790
1792
1793 do {
1794
1795 const char *attrName = "spec_const";
1798
1800 return failure();
1801
1802 constituents.push_back(specConstRef);
1804
1806 return failure();
1807
1808 StringAttr compositeSpecConstituentsName =
1809 spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
1810 result.addAttribute(compositeSpecConstituentsName,
1812
1815 return failure();
1816
1817 StringAttr typeAttrName =
1818 spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
1820
1821 return success();
1822 }
1823
1825 printer << " ";
1827 printer << " (" << llvm::interleaved(this->getConstituents().getValue())
1829 }
1830
1832 auto cType = llvm::dyn_castspirv::CompositeType(getType());
1833 auto constituents = this->getConstituents().getValue();
1834
1835 if (!cType)
1836 return emitError("result type must be a composite type, but provided ")
1838
1839 if (llvm::isaspirv::CooperativeMatrixType(cType))
1840 return emitError("unsupported composite type ") << cType;
1841 if (constituents.size() != cType.getNumElements())
1842 return emitError("has incorrect number of operands: expected ")
1843 << cType.getNumElements() << ", but provided "
1844 << constituents.size();
1845
1846 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1847 auto constituent = llvm::cast(constituents[index]);
1848
1849 auto constituentSpecConstOp =
1851 (*this)->getParentOp(), constituent.getAttr()));
1852
1853 if (constituentSpecConstOp.getDefaultValue().getType() !=
1854 cType.getElementType(index))
1855 return emitError("has incorrect types of operands: expected ")
1856 << cType.getElementType(index) << ", but provided "
1857 << constituentSpecConstOp.getDefaultValue().getType();
1858 }
1859
1860 return success();
1861 }
1862
1863
1864
1865
1866
1870
1872 return failure();
1873
1877
1878 if (!wrappedOp)
1879 return failure();
1880
1885
1887
1889 return failure();
1890
1891 return success();
1892 }
1893
1895 printer << " wraps ";
1897 }
1898
// Region verifier for spirv::SpecConstantOperationOp.
//
// Inspects the front block of this op's region and emits an op error (which
// converts to a failed LogicalResult) for each violated constraint; returns
// success() when no check fires.
//
// NOTE(review): this listing is missing original lines 1902, 1905 and 1907, so
// the conditions guarding the first two errors and the definition of
// `enclosedOp` are not visible here. The notes below on those statements are
// inferred from the error messages only -- confirm against the real file.
1899 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
// Only the first block of the region is examined.
1900 Block &block = getRegion().getBlocks().front();
1901
// NOTE(review): guard not shown; judging by the message it presumably checks
// that the block contains exactly two operations -- TODO confirm.
1903 return emitOpError("expected exactly 2 nested ops");
1904
1906
// NOTE(review): guard not shown; presumably rejects an enclosed op that is not
// allowed inside a spec-constant operation -- TODO confirm.
1908 return emitOpError("invalid enclosed op");
1909
// Every operand of the enclosed op must be produced by a spirv::ConstantOp,
// a spirv::ReferenceOfOp or another spirv::SpecConstantOperationOp.
// NOTE(review): getDefiningOp() is passed straight to isa<>; if an operand has
// no defining op (e.g. a block argument) this would hand isa<> a null pointer.
// Consider a null-tolerant check (isa_and_nonnull / isa_and_present) -- confirm.
1910 for (auto operand : enclosedOp.getOperands())
1911 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
1912 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
1913 return emitOpError(
1914 "invalid operand, must be defined by a constant operation");
1915
1916 return success();
1917 }
1918
1919
1920
1921
1922
1925 llvm::dyn_castspirv::StructType(getResult().getType());
1926
1928 return emitError("result type must be a struct type with two memebers");
1929
1932 VectorType exponentVecTy = llvm::dyn_cast(exponentTy);
1933 IntegerType exponentIntTy = llvm::dyn_cast(exponentTy);
1934
1935 Type operandTy = getOperand().getType();
1936 VectorType operandVecTy = llvm::dyn_cast(operandTy);
1937 FloatType operandFTy = llvm::dyn_cast(operandTy);
1938
1939 if (significandTy != operandTy)
1940 return emitError("member zero of the resulting struct type must be the "
1941 "same type as the operand");
1942
1943 if (exponentVecTy) {
1944 IntegerType componentIntTy =
1945 llvm::dyn_cast(exponentVecTy.getElementType());
1946 if (!componentIntTy || componentIntTy.getWidth() != 32)
1947 return emitError("member one of the resulting struct type must"
1948 "be a scalar or vector of 32 bit integer type");
1949 } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
1950 return emitError("member one of the resulting struct type "
1951 "must be a scalar or vector of 32 bit integer type");
1952 }
1953
1954
1955 if (operandVecTy && exponentVecTy &&
1956 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
1957 return success();
1958
1959 if (operandFTy && exponentIntTy)
1960 return success();
1961
1962 return emitError("member one of the resulting struct type must have the same "
1963 "number of components as the operand type");
1964 }
1965
1966
1967
1968
1969
1971 Type significandType = getX().getType();
1972 Type exponentType = getExp().getType();
1973
1974 if (llvm::isa(significandType) !=
1975 llvm::isa(exponentType))
1976 return emitOpError("operands must both be scalars or vectors");
1977
1979 if (auto vectorType = llvm::dyn_cast(type))
1980 return vectorType.getNumElements();
1981 return 1;
1982 };
1983
1985 return emitOpError("operands must have the same number of elements");
1986
1987 return success();
1988 }
1989
1990
1991
1992
1993
1996 }
1997
1998
1999
2000
2001
2004 }
2005
2006
2007
2008
2009
2012 }
2013
2014
2015
2016
2017
2020 return emitOpError("vector operand and result type mismatch");
2021 auto scalarType = llvm::cast(getType()).getElementType();
2022 if (getScalar().getType() != scalarType)
2023 return emitOpError("scalar operand and result element type match");
2024 return success();
2025 }
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
static LogicalResult verifyShiftOp(Operation *op)
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e. a form representable by a SymbolRefAttr.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
unsigned getNumArguments()
OpListType & getOperations()
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
FloatAttr getFloatAttr(Type type, double value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
AttrClass getAttrOfType(StringAttr name)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void push_back(Block *block)
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static PointerType get(Type pointeeType, StorageClass storageClass)
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.