MLIR: lib/Interfaces/InferTypeOpInterface.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

17 #include "llvm/Support/FormatVariadic.h"

18 #include "llvm/Support/InterleavedRange.h"

19

20 using namespace mlir;

21

22 namespace mlir {

23 #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"

24 }

25

26 LogicalResult

29 auto reifiableOp = dyn_cast(op);

30 if (!reifiableOp)

31 return failure();

32 LogicalResult status = reifiableOp.reifyResultShapes(b, reifiedReturnShapes);

33 #ifndef NDEBUG

34 if (failed(status))

35 return failure();

36

37

38 int64_t resultIdx = 0;

40 auto shapedType = dyn_cast(result.getType());

41 if (!shapedType)

42 continue;

43 if (!shapedType.hasRank()) {

44

45 ++resultIdx;

46 continue;

47 }

48

49 assert(shapedType.getRank() ==

50 static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&

51 "incorrect implementation of ReifyRankedShapedTypeOpInterface");

52 ++resultIdx;

53 }

54

55 assert(resultIdx == static_cast<int64_t>(reifiedReturnShapes.size()) &&

56 "incorrect implementation of ReifyRankedShapedTypeOpInterface");

57 #endif

58 return status;

59 }

60

62 if (val.isNull())

63 return false;

64 if (auto t = llvm::dyn_cast_if_present(val))

65 return cast(t).hasRank();

66 if (isa(val))

67 return true;

68 return cast<ShapedTypeComponents *>(val)->hasRank();

69 }

70

72 if (val.isNull())

73 return nullptr;

74 if (auto t = llvm::dyn_cast_if_present(val))

75 return cast(t).getElementType();

76 if (isa(val))

77 return nullptr;

78 return cast<ShapedTypeComponents *>(val)->getElementType();

79 }

80

83 if (auto t = llvm::dyn_cast_if_present(val)) {

85 res.assign(vals.begin(), vals.end());

86 } else if (auto attr = llvm::dyn_cast_if_present(val)) {

87 auto dattr = cast(attr);

88 res.clear();

89 res.reserve(dattr.size());

90 for (auto it : dattr.getValues())

91 res.push_back(it.getSExtValue());

92 } else {

93 auto vals = cast<ShapedTypeComponents *>(val)->getDims();

94 res.assign(vals.begin(), vals.end());

95 }

96 }

97

100 res.ranked = true;

102 }

103

106 if (auto t = llvm::dyn_cast_if_present(val))

107 return cast(t).getDimSize(index);

108 if (auto attr = llvm::dyn_cast_if_present(val))

109 return cast(attr)

110 .getValues()[index]

111 .getSExtValue();

112 auto *stc = cast<ShapedTypeComponents *>(val);

113 return stc->getDims()[index];

114 }

115

118 if (auto t = llvm::dyn_cast_if_present(val))

119 return cast(t).getRank();

120 if (auto attr = llvm::dyn_cast_if_present(val))

121 return cast(attr).size();

122 return cast<ShapedTypeComponents *>(val)->getDims().size();

123 }

124

127 return false;

128

129 if (auto t = llvm::dyn_cast_if_present(val))

130 return cast(t).hasStaticShape();

131 if (auto attr = llvm::dyn_cast_if_present(val)) {

132 auto dattr = cast(attr);

133 for (auto index : dattr.getValues())

134 if (ShapedType::isDynamic(index.getSExtValue()))

135 return false;

136 return true;

137 }

138 auto *stc = cast<ShapedTypeComponents *>(val);

139 return llvm::none_of(stc->getDims(), ShapedType::isDynamic);

140 }

141

143 assert(hasStaticShape() && "cannot get element count of dynamic shaped type");

144

145 if (auto t = llvm::dyn_cast_if_present(val))

146 return cast(t).getNumElements();

147

148 if (auto attr = llvm::dyn_cast_if_present(val)) {

149 auto dattr = cast(attr);

150 int64_t num = 1;

151 for (auto index : dattr.getValues()) {

152 num *= index.getZExtValue();

153 assert(num >= 0 && "integer overflow in element count computation");

154 }

155 return num;

156 }

157

158 auto *stc = cast<ShapedTypeComponents *>(val);

159 int64_t num = 1;

160 for (int64_t dim : stc->getDims()) {

161 num *= dim;

162 assert(num >= 0 && "integer overflow in element count computation");

163 }

164 return num;

165 }

166

169 llvm::errs() << "<>\n";

170 return;

171 }

172

175 auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string {

176 if (ShapedType::isDynamic(dim))

177 return "?";

178 return llvm::formatv("{0}", dim).str();

179 });

180 llvm::errs() << "rank = " << getRank()

181 << " dims = " << llvm::interleaved_array(mapped, "x") << "\n";

182 }

183

185 Value val = operator[](index);

186 if (valueToShape)

188 return ret;

189

192 return nullptr;

193 if (attr.getType().getRank() != 1)

194 return nullptr;

195 return attr;

196 }

197

199 if (operandShape)

201 return ret;

203 }

204

206 if (index < 0 || static_cast<size_t>(index) >= size())

207 return nullptr;

208 return getShape(operator[](index));

209 }

210

214 for (const auto &shapeAndType : retComponents) {

215 Type elementTy = shapeAndType.getElementType();

216 assert(elementTy && "element type required to construct tensor");

217

218 Attribute attr = shapeAndType.getAttribute();

219 if (shapeAndType.hasRank()) {

220 inferredReturnTypes.push_back(

222 } else {

223 assert(attr == nullptr && "attribute not supported");

225 }

226 }

227 return success();

228 }

229

232 auto retTypeFn = cast(op);

233 auto result = retTypeFn.refineReturnTypes(

236 inferredReturnTypes);

237 if (failed(result))

238 op->emitOpError() << "failed to infer returned types";

239

240 return result;

241 }

242

244 std::string buffer;

245 llvm::raw_string_ostream os(buffer);

246 os << "Failed to infer result type(s):\n"

247 << "\"" << state.name << "\"(...) "

248 << state.attributes.getDictionary(state.location.getContext()) << " : ("

249 << llvm::interleaved(llvm::map_range(

250 state.operands, [](Value val) { return val.getType(); }))

251 << ") -> ( ??? )";

252 emitRemark(state.location, "location of op");

253 llvm::report_fatal_error(llvm::StringRef(buffer));

254 }

Attributes are known-constant values of operations.

An attribute that represents a reference to a dense integer vector or tensor object.

This class helps build Operations.

This is a value defined by a result of an operation.

Operation is the basic unit of execution within MLIR.

MLIRContext * getContext()

Return the context this operation is associated with.

Location getLoc()

The source location the operation was defined or derived from.

DictionaryAttr getRawDictionaryAttrs()

Return all attributes that are not stored as properties.

MutableArrayRef< Region > getRegions()

Returns the regions held by this operation.

result_type_range getResultTypes()

operand_range getOperands()

Returns an iterator on the underlying Value's.

result_range getResults()

InFlightDiagnostic emitOpError(const Twine &message={})

Emit an error with the op name prefixed, like "'dim' op ", which is convenient for verifiers.

OpaqueProperties getPropertiesStorage()

Returns the properties storage.

Adaptor class to abstract the differences between whether a value is from a ShapedType, a ShapedTypeComponents, or a DenseIntElementsAttr.

int64_t getDimSize(int index) const

Returns the size of the index'th dimension.

void dump() const

Dumps a textual representation to stderr.

Type getElementType() const

Returns the element type.

int64_t getRank() const

Returns the rank of the shape.

bool hasStaticShape() const

Returns whether the shape is fully static.

int64_t getNumElements() const

Returns the number of elements in the shape.

void getDims(SmallVectorImpl< int64_t > &res) const

Populates the dimensions from the referenced shape.

bool hasRank() const

Returns whether the shape has a rank.

ShapedTypeComponents that represents the components of a ShapedType.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

ShapeAdaptor getValueAsShape(int index)

Returns an argument as a shape.

ShapeAdaptor getShape(int index) const

Returns the shape of index'th operand.

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

Type getType() const

Return the type of this value.

LogicalResult inferReturnTensorTypes(ArrayRef< ShapedTypeComponents > retComponents, SmallVectorImpl< Type > &inferredReturnTypes)

void reportFatalInferReturnTypesError(OperationState &state)

Report a fatal error indicating that the result types could not be inferred.

LogicalResult verifyInferredResultTypes(Operation *op)

Verifies that the inferred result types match the actual result types for the op.

Include the generated interface declarations.

bool matchPattern(Value value, const Pattern &pattern)

Entry point for matching a pattern over a Value.

LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)

Reify the shape of the result of an operation (typically in terms of the shape of its operands).

InFlightDiagnostic emitRemark(Location loc)

Utility method to emit a remark message using this location.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed; this helps unify some of the attribute construction methods.

detail::constant_op_matcher m_Constant()

Matches a constant foldable operation.

This represents an operation in an abstracted form, suitable for use with the builder APIs.