MLIR: lib/Dialect/Traits.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

12 #include "llvm/Support/FormatVariadic.h"

13 #include

14

15 using namespace mlir;

16

20 extents.emplace_back(shape1.begin(), shape1.end());

21 extents.emplace_back(shape2.begin(), shape2.end());

23 }

24

27 assert(!shapes.empty() && "Expected at least one shape");

28 size_t maxRank = shapes[0].size();

29 for (size_t i = 1; i != shapes.size(); ++i)

30 maxRank = std::max(maxRank, shapes[i].size());

31

32

33 for (size_t i = 0; i != maxRank; ++i) {

34 bool seenDynamic = false;

35 std::optional<int64_t> nonOneDim;

37 int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1];

38

39 if (dim == 1)

40 continue;

41

42

43

44 if (ShapedType::isDynamic(dim)) {

45 if (seenDynamic || nonOneDim)

46 return false;

47 seenDynamic = true;

48 }

49

50

51 if (nonOneDim && dim != *nonOneDim)

52 return false;

53

54 nonOneDim = dim;

55 }

56 }

57 return true;

58 }

59

63

64

65

66

67

68

69

70

71 resultShape.clear();

72 if (shape1.size() > shape2.size()) {

73 llvm::append_range(resultShape, shape1);

74 } else {

75 llvm::append_range(resultShape, shape2);

76 }

77

78 auto i1 = shape1.rbegin(), e1 = shape1.rend();

79 auto i2 = shape2.rbegin(), e2 = shape2.rend();

80 auto iR = resultShape.rbegin();

81

82

83 for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {

84 if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {

85

86

87

88

89 if (*i1 > 1) {

90 *iR = *i1;

91 } else if (*i2 > 1) {

92 *iR = *i2;

93 } else if (*i1 == 1) {

94 *iR = *i2;

95 } else if (*i2 == 1) {

96 *iR = *i1;

97 } else {

98 *iR = ShapedType::kDynamic;

99 }

100 } else {

101 if (*i1 == *i2 || *i2 == 1) {

102 *iR = *i1;

103 } else if (*i1 == 1) {

104 *iR = *i2;

105 } else {

106

107 resultShape.clear();

108 return false;

109 }

110 }

111 }

112

113 return true;

114 }

115

116

117

119 if (auto sType = dyn_cast(type))

120 return sType.getShape();

121 return {};

122 }

123

124

125

126

127

128

129

130

131

132

134 Type elementType) {

135

136

137 if (!elementType) {

140 return {};

141 }

142

143

144

145 if (isa(type1) || isa(type2)) {

146 if (isa(type1) || isa(type2))

147 return {};

149 }

150

151

152

153 auto getCompositeTypeKind = [](Type type) -> std::optional {

154 if (isa<VectorType, RankedTensorType>(type))

155 return type.getTypeID();

156 return std::nullopt;

157 };

158

159

160 std::optional compositeKind1 = getCompositeTypeKind(type1);

161 std::optional compositeKind2 = getCompositeTypeKind(type2);

162 std::optional resultCompositeKind;

163

164 if (compositeKind1 && compositeKind2) {

165

166 if (compositeKind1 != compositeKind2)

167 return {};

168 resultCompositeKind = compositeKind1;

169 } else if (compositeKind1) {

170 resultCompositeKind = compositeKind1;

171 } else if (compositeKind2) {

172 resultCompositeKind = compositeKind2;

173 }

174

175

178 return {};

179

180

181 if (resultCompositeKind == VectorType::getTypeID())

183 if (resultCompositeKind == RankedTensorType::getTypeID())

185 return elementType;

186 }

187

188

189 template <typename iterator_range>

191 return {llvm::any_of(types, llvm::IsaPred),

192 llvm::any_of(types, llvm::IsaPred)};

193 }

194

197

198 auto isCompatible = [](int64_t inferredDim, int64_t existingDim) {

199 return ShapedType::isDynamic(existingDim) ||

200 ShapedType::isDynamic(inferredDim) || inferredDim == existingDim;

201 };

202 if (inferred.size() != existing.size())

203 return false;

204 for (auto [inferredDim, existingDim] : llvm::zip_equal(inferred, existing))

205 if (!isCompatible(inferredDim, existingDim))

206 return false;

207 return true;

208 }

209

211

212

213 std::string ret;

214 llvm::raw_string_ostream ss(ret);

215 ss << '\'';

216 llvm::interleave(

217 shape, ss,

218 [&](int64_t dim) {

219 if (ShapedType::isDynamic(dim))

220 ss << '?';

221 else

222 ss << dim;

223 },

224 "x");

225 ss << '\'';

226 return ret;

227 }

228

230

231 auto operandsHasTensorVectorType =

234 if ((std::get<0>(operandsHasTensorVectorType) ||

235 std::get<0>(resultsHasTensorVectorType)) &&

236 (std::get<1>(operandsHasTensorVectorType) ||

237 std::get<1>(resultsHasTensorVectorType)))

238 return op->emitError("cannot broadcast vector with tensor");

239

240 auto rankedOperands =

241 make_filter_range(op->getOperandTypes(), llvm::IsaPred);

242

243

244 if (rankedOperands.empty())

245 return success();

246

247

248

249

252 resultShape);

253 for (auto other : make_early_inc_range(rankedOperands)) {

256 return op->emitOpError("operands don't have broadcast-compatible shapes");

257 }

258

259 auto rankedResults =

260 make_filter_range(op->getResultTypes(), llvm::IsaPred);

261

262

263 if (rankedResults.empty())

264 return success();

265

266 for (auto type : rankedResults) {

268 getShape(type).take_back(resultShape.size());

272 << " not broadcast compatible with broadcasted operands's shapes "

274 }

275 return success();

276 }

static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)

static std::string getShapeString(ArrayRef< int64_t > shape)

static bool isCompatibleInferredReturnShape(ArrayRef< int64_t > inferred, ArrayRef< int64_t > existing)

static std::tuple< bool, bool > hasTensorOrVectorType(iterator_range types)

Returns a tuple corresponding to whether range has tensor or vector type.

static ArrayRef< int64_t > getShape(Type type)

Returns the shape of the given type.

Operation is the basic unit of execution within MLIR.

InFlightDiagnostic emitError(const Twine &message={})

Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.

operand_type_range getOperandTypes()

result_type_range getResultTypes()

InFlightDiagnostic emitOpError(const Twine &message={})

Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

LogicalResult verifyCompatibleOperandBroadcast(Operation *op)

bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 >> shapes)

Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an error.

Type getBroadcastedType(Type type1, Type type2, Type elementType=nullptr)

Returns the result broadcast composition type from the two given types by following NumPy broadcast semantics.

bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)

Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broadcast compatible.

Include the generated interface declarations.

Type getElementTypeOrSelf(Type type)

Return the element type or return the type itself.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed; this helps unify some of the attribute construction methods.