MLIR: include/mlir/Interfaces/InferTypeOpInterface.h Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14 #ifndef MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_

15 #define MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_

16

23 #include "llvm/ADT/PointerUnion.h"

24 #include "llvm/ADT/SmallVector.h"

25

26 namespace mlir {

27

28 class ShapedTypeComponents;

30

31

32

33 LogicalResult

36

37

38

40 public:

42 if (auto st = dyn_cast(t))

43 val = st;

44 }

46 if (auto da = dyn_cast(t))

47 val = da;

48 }

51

52

54

55

57

58

59

61

62

63

65

66

67

69

70

71

73 return ShapedType::isDynamic(getDimSize(index));

74 }

75

76

78

79

80

82

83

84

86

87

88 explicit operator bool() const { return !val.isNull(); }

89

90

91 void dump() const;

92

93 private:

94

95

97 };

98

99

100

101

102

103

104

105

106

108

110

111 public:

112

115 : elementType(elementType), attr(nullptr), ranked(false) {}

117 ranked = shapedType.hasRank();

118 elementType = shapedType.getElementType();

119 if (ranked)

120 dims = llvm::to_vector<4>(shapedType.getShape());

121 }

123 ranked = adaptor.hasRank();

125 if (ranked)

127 }

128 template <typename Arg, typename = std::enable_if_t<

129 std::is_constructible<ShapeStorageT, Arg>::value>>

132 : dims(std::forward(arg)), elementType(elementType), attr(attr),

133 ranked(true) {}

136 : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),

137 ranked(true) {}

138

139

140

142 assert(ranked && "requires ranked shape");

143 return dims;

144 }

145

146

147 bool hasRank() const { return ranked; };

148

149

151

152

154

155 private:

157

159 Type elementType;

161 bool ranked{false};

162 };

163

164

165

166

167

168

169

171 public:

173

176 : RangeBaseT(values), operandShape(operandShape),

177 valueToShape(valueToShape) {}

180

182

183

185 valueToShape = fn;

186 return *this;

187 }

188

190 operandShape = fn;

191 return *this;

192 }

193

194

197

198

199

200

201

202

207

208

209

210

212

213

214

215

216

218

219

220

221

223

224

226

227 private:

228

229

231

232

233

235 };

236

237 namespace detail {

238

239

240 LogicalResult

242 SmallVectorImpl &inferredReturnTypes);

243

244

245

247

248

249

251 }

252

253 namespace OpTrait {

254 template

255 class InferTensorType;

256 }

257 }

258

259

260 #include "mlir/Interfaces/InferTypeOpInterface.h.inc"

261

262 namespace mlir {

263 namespace OpTrait {

264

265 template

267 };

268

269 template

271 : public TraitBase<ConcreteType, InferShapedTypeOpAdaptor> {};

272

273

274

275

276

277

278

279

280 template

282

283 }

284 }

285

286 #endif

Attributes are known-constant values of operations.

This class helps build Operations.

Tensor type inference trait that constructs a tensor from the inferred shape and elemental types.

Helper class for implementing traits.

Operation is the basic unit of execution within MLIR.

Adaptor class to abstract the differences between whether a value comes from a ShapedType, a ShapedTypeComponents, or a DenseIntElementsAttr.

bool isDynamicDim(int index) const

Returns whether the index'th dimension is dynamic.

int64_t getDimSize(int index) const

Returns the size of the index'th dimension.

void dump() const

Dumps a textual representation to stderr.

Type getElementType() const

Returns the element type.

ShapeAdaptor(Attribute t)

int64_t getRank() const

Returns the rank of the shape.

bool hasStaticShape() const

Returns whether the shape is fully static.

ShapeAdaptor(ShapedTypeComponents &components)

ShapeAdaptor(ShapedTypeComponents *components)

int64_t getNumElements() const

Returns the number of elements in the shape.

void getDims(SmallVectorImpl< int64_t > &res) const

Populates the dimensions from shape referenced.

bool hasRank() const

Returns whether the shape has a rank.

ShapedTypeComponents represents the components of a ShapedType.

ShapedTypeComponents()

Default construction is an unranked shape.

ShapedTypeComponents(Arg &&arg, Type elementType=nullptr, Attribute attr=nullptr)

ShapedTypeComponents(ShapedType shapedType)

ShapedTypeComponents(Type elementType)

ShapedTypeComponents(ArrayRef< int64_t > vec, Type elementType=nullptr, Attribute attr=nullptr)

ShapedTypeComponents(ShapeAdaptor adaptor)

bool hasRank() const

Return whether the shape has a rank.

Type getElementType() const

Return the element type component.

Attribute getAttribute() const

Return the raw attribute component.

ArrayRef< int64_t > getDims() const

Return the dimensions of the shape.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

This class provides an abstraction over the different types of ranges over Values.

Range of values and shapes (corresponding effectively to the Shapes dialect's ValueShape type concept).

ValueShapeMapFn getOperandShapeMapping() const

ValueShapeMapFn getValueToShapeMapping() const

Returns the set Value to ShapeAdaptor mapping function.

ValueShapeRange & setValueToShapeMapping(ValueShapeMapFn fn)

Sets the Value to ShapeAdaptor mapping function and returns this.

ValueShapeRange(const ValueShapeRange &)=default

ShapeAdaptor getValueAsShape(int index)

Returns an argument as a shape.

type_range getTypes() const

ValueShapeRange(ValueRange values, ValueShapeMapFn operandShape=nullptr, ValueShapeMapFn valueToShape=nullptr)

ValueShapeRange(const std::initializer_list< Value > &values)

ShapeAdaptor getShape(int index) const

Returns the shape of index'th operand.

function_ref< ShapeAdaptor(Value)> ValueShapeMapFn

ValueShapeRange & setOperandShapeMapping(ValueShapeMapFn fn)

ValueRange getValues() const

Returns the Values in the ValueRange.

This class implements iteration on the types of a given range of values.

This class implements iteration on the types of a given range of values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

LogicalResult inferReturnTensorTypes(ArrayRef< ShapedTypeComponents > retComponents, SmallVectorImpl< Type > &inferredReturnTypes)

void reportFatalInferReturnTypesError(OperationState &state)

Report a fatal error indicating that the result types could not be inferred.

LogicalResult verifyInferredResultTypes(Operation *op)

Verifies that the inferred result types match the actual result types for the op.

Include the generated interface declarations.

LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)

Reify the shape of the result of an operation (typically in terms of the shape of its operands).