MLIR: include/mlir/Interfaces/InferTypeOpInterface.h Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
14 #ifndef MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
15 #define MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
16
23 #include "llvm/ADT/PointerUnion.h"
24 #include "llvm/ADT/SmallVector.h"
25
26 namespace mlir {
27
28 class ShapedTypeComponents;
30
31
32
33 LogicalResult
36
37
38
40 public:
42 if (auto st = dyn_cast(t))
43 val = st;
44 }
46 if (auto da = dyn_cast(t))
47 val = da;
48 }
51
52
54
55
57
58
59
61
62
63
65
66
67
69
70
71
73 return ShapedType::isDynamic(getDimSize(index));
74 }
75
76
78
79
80
82
83
84
86
87
88 explicit operator bool() const { return !val.isNull(); }
89
90
91 void dump() const;
92
93 private:
94
95
97 };
98
99
100
101
102
103
104
105
106
108
110
111 public:
112
115 : elementType(elementType), attr(nullptr), ranked(false) {}
117 ranked = shapedType.hasRank();
118 elementType = shapedType.getElementType();
119 if (ranked)
120 dims = llvm::to_vector<4>(shapedType.getShape());
121 }
123 ranked = adaptor.hasRank();
125 if (ranked)
127 }
128 template <typename Arg, typename = std::enable_if_t<
129 std::is_constructible<ShapeStorageT, Arg>::value>>
132 : dims(std::forward(arg)), elementType(elementType), attr(attr),
133 ranked(true) {}
136 : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),
137 ranked(true) {}
138
139
140
142 assert(ranked && "requires ranked shape");
143 return dims;
144 }
145
146
147 bool hasRank() const { return ranked; };
148
149
151
152
154
155 private:
157
159 Type elementType;
161 bool ranked{false};
162 };
163
164
165
166
167
168
169
171 public:
173
176 : RangeBaseT(values), operandShape(operandShape),
177 valueToShape(valueToShape) {}
180
182
183
185 valueToShape = fn;
186 return *this;
187 }
188
190 operandShape = fn;
191 return *this;
192 }
193
194
197
198
199
200
201
202
207
208
209
210
212
213
214
215
216
218
219
220
221
223
224
226
227 private:
228
229
231
232
233
235 };
236
237 namespace detail {
238
239
240 LogicalResult
242 SmallVectorImpl &inferredReturnTypes);
243
244
245
247
248
249
251 }
252
253 namespace OpTrait {
254 template
255 class InferTensorType;
256 }
257 }
258
259
260 #include "mlir/Interfaces/InferTypeOpInterface.h.inc"
261
262 namespace mlir {
263 namespace OpTrait {
264
265 template
267 };
268
269 template
271 : public TraitBase<ConcreteType, InferShapedTypeOpAdaptor> {};
272
273
274
275
276
277
278
279
280 template
282
283 }
284 }
285
286 #endif
Attributes are known-constant values of operations.
This class helps build Operations.
Tensor type inference trait that constructs a tensor from the inferred shape and elemental types.
Helper class for implementing traits.
Operation is the basic unit of execution within MLIR.
Adaptor class to abstract the differences between whether a value comes from a ShapedType, a ShapedTypeComponents, or a DenseIntElementsAttribute.
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
void dump() const
Dumps a textual representation to stderr.
Type getElementType() const
Returns the element type.
ShapeAdaptor(Attribute t)
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
ShapeAdaptor(ShapedTypeComponents &components)
ShapeAdaptor(ShapedTypeComponents *components)
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
ShapedTypeComponents()
Default construction is an unranked shape.
ShapedTypeComponents(Arg &&arg, Type elementType=nullptr, Attribute attr=nullptr)
ShapedTypeComponents(ShapedType shapedType)
ShapedTypeComponents(Type elementType)
ShapedTypeComponents(ArrayRef< int64_t > vec, Type elementType=nullptr, Attribute attr=nullptr)
ShapedTypeComponents(ShapeAdaptor adaptor)
bool hasRank() const
Return whether the shape has a rank.
Type getElementType() const
Return the element type component.
Attribute getAttribute() const
Return the raw attribute component.
ArrayRef< int64_t > getDims() const
Return the dimensions of the shape.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
Range of values and shapes (corresponding effectively to the Shape dialect's ValueShape type concept).
ValueShapeMapFn getOperandShapeMapping() const
ValueShapeMapFn getValueToShapeMapping() const
Returns the set Value to ShapeAdaptor mapping function.
ValueShapeRange & setValueToShapeMapping(ValueShapeMapFn fn)
Sets the Value to ShapeAdaptor mapping function and returns this.
ValueShapeRange(const ValueShapeRange &)=default
ShapeAdaptor getValueAsShape(int index)
Returns an argument as shape.
type_range getTypes() const
ValueShapeRange(ValueRange values, ValueShapeMapFn operandShape=nullptr, ValueShapeMapFn valueToShape=nullptr)
ValueShapeRange(const std::initializer_list< Value > &values)
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
function_ref< ShapeAdaptor(Value)> ValueShapeMapFn
ValueShapeRange & setOperandShapeMapping(ValueShapeMapFn fn)
ValueRange getValues() const
Returns the Values in the ValueRange.
This class implements iteration on the types of a given range of values.
This class implements iteration on the types of a given range of values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
LogicalResult inferReturnTensorTypes(ArrayRef< ShapedTypeComponents > retComponents, SmallVectorImpl< Type > &inferredReturnTypes)
void reportFatalInferReturnTypesError(OperationState &state)
Report a fatal error indicating that the result types could not be inferred.
LogicalResult verifyInferredResultTypes(Operation *op)
Verifies that the inferred result types match the actual result types for the op.
Include the generated interface declarations.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).