MLIR: lib/IR/BuiltinTypes.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
10 #include "TypeDetail.h"
19 #include "llvm/ADT/APFloat.h"
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/ADT/TypeSwitch.h"
24
25 using namespace mlir;
27
28
29
30
31
32 #define GET_TYPEDEF_CLASSES
33 #include "mlir/IR/BuiltinTypes.cpp.inc"
34
35 namespace mlir {
36 #include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
37 }
38
39
40
41
42
/// Registers the builtin type classes with the builtin dialect. The type list
/// is not written out here: it is spliced into the `addTypes<...>` template
/// argument list by including the generated `mlir/IR/BuiltinTypes.cpp.inc`
/// with `GET_TYPEDEF_LIST` defined.
43 void BuiltinDialect::registerTypes() {
44 addTypes<
45 #define GET_TYPEDEF_LIST
46 #include "mlir/IR/BuiltinTypes.cpp.inc"
47 >();
48 }
49
50
51
52
53
54
56 Type elementType) {
58 return emitError() << "invalid element type for complex";
59 return success();
60 }
61
62
63
64
65
66
68 unsigned width,
69 SignednessSemantics signedness) {
70 if (width > IntegerType::kMaxWidth) {
71 return emitError() << "integer bitwidth is limited to "
72 << IntegerType::kMaxWidth << " bits";
73 }
74 return success();
75 }
76
77 unsigned IntegerType::getWidth() const { return getImpl()->width; }
78
79 IntegerType::SignednessSemantics IntegerType::getSignedness() const {
80 return getImpl()->signedness;
81 }
82
83 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
84 if (!scale)
85 return IntegerType();
87 }
88
89
90
91
92
93
// Expands to an out-of-line definition of `TYPE::getFloatSemantics()` that
// hands back the semantics object returned by `APFloat::SEM()`.
#define FLOAT_TYPE_SEMANTICS(TYPE, SEM)                                        \
  const llvm::fltSemantics &TYPE::getFloatSemantics() const {                  \
    const llvm::fltSemantics &semantics = APFloat::SEM();                      \
    return semantics;                                                          \
  }
116 #undef FLOAT_TYPE_SEMANTICS
117
118 FloatType Float16Type::scaleElementBitwidth(unsigned scale) const {
119 if (scale == 2)
121 if (scale == 4)
123 return FloatType();
124 }
125
126 FloatType BFloat16Type::scaleElementBitwidth(unsigned scale) const {
127 if (scale == 2)
129 if (scale == 4)
131 return FloatType();
132 }
133
134 FloatType Float32Type::scaleElementBitwidth(unsigned scale) const {
135 if (scale == 2)
137 return FloatType();
138 }
139
140
141
142
143
144 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
145
147 return getImpl()->getInputs();
148 }
149
150 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
151
153 return getImpl()->getResults();
154 }
155
158 }
159
160
161
162 FunctionType FunctionType::getWithArgsAndResults(
167 insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
169 insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
170 return clone(newArgTypes, newResultTypes);
171 }
172
173
174 FunctionType
175 FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
176 const BitVector &resultIndices) {
180 filterTypesOut(getResults(), resultIndices, resultStorage);
181 return clone(newArgTypes, newResultTypes);
182 }
183
184
185
186
187
188
190 StringAttr dialect, StringRef typeData) {
192 return emitError() << "invalid dialect namespace '" << dialect << "'";
193
194
195 MLIRContext *context = dialect.getContext();
199 << "`!" << dialect << "<\"" << typeData << "\">"
200 << "` type created with unregistered dialect. If this is "
201 "intended, please call allowUnregisteredDialects() on the "
202 "MLIRContext, or use -allow-unregistered-dialect with "
203 "the MLIR opt tool used";
204 }
205
206 return success();
207 }
208
209
210
211
212
213 bool VectorType::isValidElementType(Type t) {
214 return isValidVectorTypeElementType(t);
215 }
216
220 if (!isValidElementType(elementType))
222 << "vector elements must be int/index/float type but got "
223 << elementType;
224
225 if (any_of(shape, [](int64_t i) { return i <= 0; }))
227 << "vector types must have positive constant sizes but got "
228 << shape;
229
230 if (scalableDims.size() != shape.size())
231 return emitError() << "number of dims must match, got "
232 << scalableDims.size() << " and " << shape.size();
233
234 return success();
235 }
236
237 VectorType VectorType::scaleElementBitwidth(unsigned scale) {
238 if (!scale)
239 return VectorType();
240 if (auto et = llvm::dyn_cast(getElementType()))
241 if (auto scaledEt = et.scaleElementBitwidth(scale))
243 if (auto et = llvm::dyn_cast(getElementType()))
244 if (auto scaledEt = et.scaleElementBitwidth(scale))
246 return VectorType();
247 }
248
249 VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
250 Type elementType) const {
252 getScalableDims());
253 }
254
255
256
257
258
261 .Case<RankedTensorType, UnrankedTensorType>(
262 [](auto type) { return type.getElementType(); });
263 }
264
266 return !llvm::isa(*this);
267 }
268
270 return llvm::cast(*this).getShape();
271 }
272
274 Type elementType) const {
275 if (llvm::dyn_cast(*this)) {
276 if (shape)
279 }
280
281 auto rankedTy = llvm::cast(*this);
282 if (!shape)
284 rankedTy.getEncoding());
286 rankedTy.getEncoding());
287 }
288
290 Type elementType) const {
291 return ::llvm::cast(cloneWith(shape, elementType));
292 }
293
295 return ::llvm::cast(cloneWith(shape, getElementType()));
296 }
297
298
299 static LogicalResult
301 Type elementType) {
303 return emitError() << "invalid tensor element type: " << elementType;
304 return success();
305 }
306
307
309
310
311
312 return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
313 IndexType>(type) ||
314 !llvm::isa(type.getDialect());
315 }
316
317
318
319
320
321 LogicalResult
325 for (int64_t s : shape)
326 if (s < 0 && !ShapedType::isDynamic(s))
327 return emitError() << "invalid tensor dimension size";
328 if (auto v = llvm::dyn_cast_or_null(encoding))
329 if (failed(v.verifyEncoding(shape, elementType, emitError)))
330 return failure();
332 }
333
334
335
336
337
338 LogicalResult
340 Type elementType) {
342 }
343
344
345
346
347
351 [](auto type) { return type.getElementType(); });
352 }
353
355 return !llvm::isa(*this);
356 }
357
359 return llvm::cast(*this).getShape();
360 }
361
363 Type elementType) const {
364 if (llvm::dyn_cast(*this)) {
365 if (!shape)
369 return builder;
370 }
371
373 if (shape)
376 return builder;
377 }
378
380 Type elementType) const {
381 return ::llvm::cast(cloneWith(shape, elementType));
382 }
383
385 return ::llvm::cast(cloneWith(shape, getElementType()));
386 }
387
389 if (auto rankedMemRefTy = llvm::dyn_cast(*this))
390 return rankedMemRefTy.getMemorySpace();
391 return llvm::cast(*this).getMemorySpace();
392 }
393
395 if (auto rankedMemRefTy = llvm::dyn_cast(*this))
396 return rankedMemRefTy.getMemorySpaceAsInt();
397 return llvm::cast(*this).getMemorySpaceAsInt();
398 }
399
400
401
402
403
404 std::optional<llvm::SmallDenseSet>
407 bool matchDynamic) {
408 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
409 llvm::SmallDenseSet unusedDims;
410 unsigned reducedIdx = 0;
411 for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
412
413 int64_t origSize = originalShape[originalIdx];
414
415 if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
416 (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
417 ShapedType::isDynamic(origSize))) {
418 reducedIdx++;
419 continue;
420 }
421 if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
422 reducedIdx++;
423 continue;
424 }
425
426 unusedDims.insert(originalIdx);
427
428
429 if (origSize != 1)
430 return std::nullopt;
431 }
432
433 if (reducedIdx != reducedRank)
434 return std::nullopt;
435 return unusedDims;
436 }
437
440 ShapedType candidateReducedType) {
441 if (originalType == candidateReducedType)
443
444 ShapedType originalShapedType = llvm::cast(originalType);
445 ShapedType candidateReducedShapedType =
446 llvm::cast(candidateReducedType);
447
448
451 candidateReducedShapedType.getShape();
452 unsigned originalRank = originalShape.size(),
453 candidateReducedRank = candidateReducedShape.size();
454 if (candidateReducedRank > originalRank)
456
457 auto optionalUnusedDimsMask =
459
460
461 if (!optionalUnusedDimsMask)
463
464 if (originalShapedType.getElementType() !=
465 candidateReducedShapedType.getElementType())
467
469 }
470
472
473 if (!memorySpace)
474 return true;
475
476
477 if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
478 return true;
479
480
481 if (!isa(memorySpace.getDialect()))
482 return true;
483
484 return false;
485 }
486
489 if (memorySpace == 0)
490 return nullptr;
491
493 }
494
496 IntegerAttr intMemorySpace = llvm::dyn_cast_or_null(memorySpace);
497 if (intMemorySpace && intMemorySpace.getValue() == 0)
498 return nullptr;
499
500 return memorySpace;
501 }
502
504 if (!memorySpace)
505 return 0;
506
507 assert(llvm::isa(memorySpace) &&
508 "Using `getMemorySpaceInteger` with non-Integer attribute");
509
510 return static_cast<unsigned>(llvm::cast(memorySpace).getInt());
511 }
512
515 }
516
518 MemRefLayoutAttrInterface layout,
520
521 if (!layout)
523 shape.size(), elementType.getContext()));
524
525
527
529 memorySpace);
530 }
531
532 MemRefType MemRefType::getChecked(
534 Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
535
536
537 if (!layout)
539 shape.size(), elementType.getContext()));
540
541
543
544 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
545 elementType, layout, memorySpace);
546 }
547
550
551
552 if (!map)
555
556
558
559
561
563 memorySpace);
564 }
565
566 MemRefType
570
571
572 if (!map)
575
576
578
579
581
582 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
583 elementType, layout, memorySpace);
584 }
585
587 AffineMap map, unsigned memorySpaceInd) {
588
589
590 if (!map)
593
594
596
597
600
602 memorySpace);
603 }
604
605 MemRefType
608 unsigned memorySpaceInd) {
609
610
611 if (!map)
614
615
617
618
621
622 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
623 elementType, layout, memorySpace);
624 }
625
628 MemRefLayoutAttrInterface layout,
631 return emitError() << "invalid memref element type";
632
633
634 for (int64_t s : shape)
635 if (s < 0 && !ShapedType::isDynamic(s))
636 return emitError() << "invalid memref size";
637
638 assert(layout && "missing layout specification");
639 if (failed(layout.verifyLayout(shape, emitError)))
640 return failure();
641
643 return emitError() << "unsupported memory space Attribute";
644
645 return success();
646 }
647
/// Returns true if the trailing `n` dimensions of this memref are laid out
/// contiguously in memory.
///
/// Visible exits:
///  - false if the innermost dimension is not unit-stride;
///  - false if any of the trailing `n` sizes is dynamic;
///  - true for an identity layout;
///  - true when `strides` is empty.
/// Otherwise, every trailing stride except the innermost must equal the
/// product of the sizes of the trailing dimensions that follow it.
/// NOTE(review): this listing elides the lines that compute `strides` (and
/// guard the early `return false` after `offset`) and that declare
/// `flattenedDims`; presumably a strides/offset query, confirm against the
/// full source.
648 bool MemRefType::areTrailingDimsContiguous(int64_t n) {
649 if (!isLastDimUnitStride())
650 return false;
651
652 auto memrefShape = getShape().take_back(n);
653 if (ShapedType::isDynamicShape(memrefShape))
654 return false;
655
656 if (getLayout().isIdentity())
657 return true;
658
659 int64_t offset;
662 return false;
664
665 if (strides.empty())
666 return true;
667
668
// Build running products of the trailing sizes from the innermost outwards,
// skipping the outermost of the trailing `n` dimensions.
670 auto dimProduct = 1; // NOTE(review): `auto` deduces `int`; `dimProduct *= dim` with an int64_t `dim` narrows each product back to `int` and can truncate for large static shapes. `int64_t dimProduct = 1;` would avoid this.
671 for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
672 dimProduct *= dim;
673 flattenedDims.push_back(dimProduct);
674 }
675
// The innermost stride is already known to be 1 (isLastDimUnitStride), so
// only the outer trailing strides are compared with the running products.
676 strides = strides.drop_back(1);
677 return llvm::equal(strides, llvm::reverse(flattenedDims));
678 }
679
680 MemRefType MemRefType::canonicalizeStridedLayout() {
681 AffineMap m = getLayout().getAffineMap();
682
683
684 if (m.isIdentity())
685 return *this;
686
687
688 if (m.getNumResults() > 1)
689 return *this;
690
691
692 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
693 if (auto cst = llvm::dyn_cast(m.getResult(0)))
694 if (cst.getValue() == 0)
696 return *this;
697 }
698
699
700
701
703 return *this;
704
705
706
707
709 auto simplifiedLayoutExpr =
710 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
711 if (expr != simplifiedLayoutExpr)
714 simplifiedLayoutExpr)));
716 }
717
719 int64_t &offset) {
720 return getLayout().getStridesAndOffset(getShape(), strides, offset);
721 }
722
725 int64_t offset;
727 (void)status;
728 assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
729 return {strides, offset};
730 }
731
732 bool MemRefType::isStrided() {
733 int64_t offset;
736 return succeeded(res);
737 }
738
739 bool MemRefType::isLastDimUnitStride() {
740 int64_t offset;
743 return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
744 }
745
746
747
748
749
752 }
753
754 LogicalResult
758 return emitError() << "invalid memref element type";
759
761 return emitError() << "unsupported memory space Attribute";
762
763 return success();
764 }
765
766
767
768
769
770
771 ArrayRef TupleType::getTypes() const { return getImpl()->getTypes(); }
772
773
774
775
776
778 for (Type type : getTypes()) {
779 if (auto nestedTuple = llvm::dyn_cast(type))
780 nestedTuple.getFlattenedTypes(types);
781 else
782 types.push_back(type);
783 }
784 }
785
786
787 size_t TupleType::size() const { return getImpl()->size(); }
788
789
790
791
792
796
797 if (sizes.empty())
799
800 assert(!exprs.empty() && "expected exprs");
802 assert(!maps.empty() && "Expected one non-empty map");
803 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
804
806 bool dynamicPoisonBit = false;
807 int64_t runningSize = 1;
808 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
809 int64_t size = std::get<1>(en);
810 AffineExpr dimExpr = std::get<0>(en);
814 expr = expr ? expr + dimExpr * stride : dimExpr * stride;
815 if (size > 0) {
816 runningSize *= size;
817 assert(runningSize > 0 && "integer overflow in size computation");
818 } else {
819 dynamicPoisonBit = true;
820 }
821 }
823 }
824
828 exprs.reserve(sizes.size());
829 for (auto dim : llvm::seq(0, sizes.size()))
832 }
static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef< int64_t > shape, SmallVectorImpl< AffineExpr > &strides, AffineExpr &offset)
A stride specification is a list of integer values that are either static or dynamic (encoded with Sh...
static LogicalResult checkTensorElementType(function_ref< InFlightDiagnostic()> emitError, Type elementType)
#define FLOAT_TYPE_SEMANTICS(TYPE, SEM)
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Attributes are known-constant values of operations.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this memref type.
MemRefType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
BaseMemRefType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
bool allowsUnregisteredDialects()
Return true if we allow to create operation for unregistered dialects.
This is a builder type that keeps local references to arguments.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Builder & setElementType(Type newElementType)
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setMemorySpace(Attribute newMemorySpace)
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
TensorType cloneWith(std::optional< ArrayRef< int64_t >> shape, Type elementType) const
Clone this type with the given shape and element type.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
RankedTensorType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Attribute wrapIntegerMemorySpace(unsigned memorySpace, MLIRContext *ctx)
Wraps deprecated integer memory space to the new Attribute form.
unsigned getMemorySpaceAsInt(Attribute memorySpace)
[deprecated] Returns the memory space in old raw integer representation.
bool isSupportedMemorySpace(Attribute memorySpace)
Checks if the memorySpace has supported Attribute type.
Attribute skipDefaultMemorySpace(Attribute memorySpace)
Replaces default memorySpace (integer == 0) with empty Attribute.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
SmallVector< Type, 10 > getFlattenedTypes(TupleType t)
Get the types within a nested Tuple.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
TypeRange filterTypesOut(TypeRange types, const BitVector &indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef< int64_t > sizes, ArrayRef< AffineExpr > exprs, MLIRContext *context)
Given MemRef sizes that are either static or dynamic, returns the canonical "contiguous" strides Affi...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)