MLIR: lib/Dialect/Traits.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
12 #include "llvm/Support/FormatVariadic.h"
13 #include
14
15 using namespace mlir;
16
20 extents.emplace_back(shape1.begin(), shape1.end());
21 extents.emplace_back(shape2.begin(), shape2.end());
23 }
24
27 assert(!shapes.empty() && "Expected at least one shape");
28 size_t maxRank = shapes[0].size();
29 for (size_t i = 1; i != shapes.size(); ++i)
30 maxRank = std::max(maxRank, shapes[i].size());
31
32
33 for (size_t i = 0; i != maxRank; ++i) {
34 bool seenDynamic = false;
35 std::optional<int64_t> nonOneDim;
37 int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1];
38
39 if (dim == 1)
40 continue;
41
42
43
44 if (ShapedType::isDynamic(dim)) {
45 if (seenDynamic || nonOneDim)
46 return false;
47 seenDynamic = true;
48 }
49
50
51 if (nonOneDim && dim != *nonOneDim)
52 return false;
53
54 nonOneDim = dim;
55 }
56 }
57 return true;
58 }
59
63
64
65
66
67
68
69
70
71 resultShape.clear();
72 if (shape1.size() > shape2.size()) {
73 llvm::append_range(resultShape, shape1);
74 } else {
75 llvm::append_range(resultShape, shape2);
76 }
77
78 auto i1 = shape1.rbegin(), e1 = shape1.rend();
79 auto i2 = shape2.rbegin(), e2 = shape2.rend();
80 auto iR = resultShape.rbegin();
81
82
83 for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
84 if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
85
86
87
88
89 if (*i1 > 1) {
90 *iR = *i1;
91 } else if (*i2 > 1) {
92 *iR = *i2;
93 } else if (*i1 == 1) {
94 *iR = *i2;
95 } else if (*i2 == 1) {
96 *iR = *i1;
97 } else {
98 *iR = ShapedType::kDynamic;
99 }
100 } else {
101 if (*i1 == *i2 || *i2 == 1) {
102 *iR = *i1;
103 } else if (*i1 == 1) {
104 *iR = *i2;
105 } else {
106
107 resultShape.clear();
108 return false;
109 }
110 }
111 }
112
113 return true;
114 }
115
116
117
119 if (auto sType = dyn_cast(type))
120 return sType.getShape();
121 return {};
122 }
123
124
125
126
127
128
129
130
131
132
134 Type elementType) {
135
136
137 if (!elementType) {
140 return {};
141 }
142
143
144
145 if (isa(type1) || isa(type2)) {
146 if (isa(type1) || isa(type2))
147 return {};
149 }
150
151
152
153 auto getCompositeTypeKind = [](Type type) -> std::optional {
154 if (isa<VectorType, RankedTensorType>(type))
155 return type.getTypeID();
156 return std::nullopt;
157 };
158
159
160 std::optional compositeKind1 = getCompositeTypeKind(type1);
161 std::optional compositeKind2 = getCompositeTypeKind(type2);
162 std::optional resultCompositeKind;
163
164 if (compositeKind1 && compositeKind2) {
165
166 if (compositeKind1 != compositeKind2)
167 return {};
168 resultCompositeKind = compositeKind1;
169 } else if (compositeKind1) {
170 resultCompositeKind = compositeKind1;
171 } else if (compositeKind2) {
172 resultCompositeKind = compositeKind2;
173 }
174
175
178 return {};
179
180
181 if (resultCompositeKind == VectorType::getTypeID())
183 if (resultCompositeKind == RankedTensorType::getTypeID())
185 return elementType;
186 }
187
188
189 template <typename iterator_range>
191 return {llvm::any_of(types, llvm::IsaPred),
192 llvm::any_of(types, llvm::IsaPred)};
193 }
194
197
198 auto isCompatible = [](int64_t inferredDim, int64_t existingDim) {
199 return ShapedType::isDynamic(existingDim) ||
200 ShapedType::isDynamic(inferredDim) || inferredDim == existingDim;
201 };
202 if (inferred.size() != existing.size())
203 return false;
204 for (auto [inferredDim, existingDim] : llvm::zip_equal(inferred, existing))
205 if (!isCompatible(inferredDim, existingDim))
206 return false;
207 return true;
208 }
209
211
212
213 std::string ret;
214 llvm::raw_string_ostream ss(ret);
215 ss << '\'';
216 llvm::interleave(
217 shape, ss,
218 [&](int64_t dim) {
219 if (ShapedType::isDynamic(dim))
220 ss << '?';
221 else
222 ss << dim;
223 },
224 "x");
225 ss << '\'';
226 return ret;
227 }
228
230
231 auto operandsHasTensorVectorType =
234 if ((std::get<0>(operandsHasTensorVectorType) ||
235 std::get<0>(resultsHasTensorVectorType)) &&
236 (std::get<1>(operandsHasTensorVectorType) ||
237 std::get<1>(resultsHasTensorVectorType)))
238 return op->emitError("cannot broadcast vector with tensor");
239
240 auto rankedOperands =
241 make_filter_range(op->getOperandTypes(), llvm::IsaPred);
242
243
244 if (rankedOperands.empty())
245 return success();
246
247
248
249
252 resultShape);
253 for (auto other : make_early_inc_range(rankedOperands)) {
256 return op->emitOpError("operands don't have broadcast-compatible shapes");
257 }
258
259 auto rankedResults =
260 make_filter_range(op->getResultTypes(), llvm::IsaPred);
261
262
263 if (rankedResults.empty())
264 return success();
265
266 for (auto type : rankedResults) {
268 getShape(type).take_back(resultShape.size());
272 << " not broadcast compatible with broadcasted operands's shapes "
274 }
275 return success();
276 }
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static std::string getShapeString(ArrayRef< int64_t > shape)
static bool isCompatibleInferredReturnShape(ArrayRef< int64_t > inferred, ArrayRef< int64_t > existing)
static std::tuple< bool, bool > hasTensorOrVectorType(iterator_range types)
Returns a tuple corresponding to whether range has tensor or vector type.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
LogicalResult verifyCompatibleOperandBroadcast(Operation *op)
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 >> shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an error.
Type getBroadcastedType(Type type1, Type type2, Type elementType=nullptr)
Returns the result broadcast composition type from the two given types by following NumPy broadcast semantics.
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broadcast compatible.
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.