MLIR: lib/Dialect/Arith/Utils/Utils.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
18 #include "llvm/ADT/SmallBitVector.h"
19 #include
20
21 using namespace mlir;
22
23 std::optional<SmallVector>
25 ShapedType expandedType,
28
31
32 if (inputShape.empty()) {
33 outputShapeInts.resize(expandedType.getRank(), 1);
34 return getMixedValues(outputShapeInts, outputShapeValues, b);
35 }
36
37
38 if (expandedType.hasStaticShape()) {
40 outputShapeInts.assign(staticShape.begin(), staticShape.end());
41 return getMixedValues(outputShapeInts, outputShapeValues, b);
42 }
43
44 outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
47
48 int64_t indexGroupStaticSizesProductInt = 1;
49 bool foundDynamicShape = false;
50 for (int64_t index : indexGroup) {
51 int64_t outputDimSize = expandedType.getDimSize(index);
52
53
54 if (ShapedType::isDynamic(outputDimSize)) {
55 if (foundDynamicShape)
56 return std::nullopt;
57 foundDynamicShape = true;
58 } else {
59 outputShapeInts[index] = outputDimSize;
60 indexGroupStaticSizesProductInt *= outputDimSize;
61 }
62 }
63 if (!foundDynamicShape)
64 continue;
65
66 int64_t inputIndex = it.index();
67
68
69 Value indexGroupSize = cast(inputShape[inputIndex]);
70 Value indexGroupStaticSizesProduct =
73 loc, indexGroupSize, indexGroupStaticSizesProduct);
74 outputShapeValues.push_back(dynamicDimSize);
75 }
76
77 if ((int64_t)outputShapeValues.size() !=
78 llvm::count(outputShapeInts, ShapedType::kDynamic))
79 return std::nullopt;
80
81 return getMixedValues(outputShapeInts, outputShapeValues, b);
82 }
83
84
85
86
89 }
90
93 llvm::SmallBitVector dimsToProject(shape.size());
94 for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
95 if (shape[pos] == 1) {
96 dimsToProject.set(pos);
97 --rank;
98 }
99 }
100 return dimsToProject;
101 }
102
105 if (auto value = dyn_cast_if_present(ofr))
106 return value;
107 auto attr = cast(cast(ofr));
108 return b.createarith::ConstantOp(
109 loc, b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue()));
110 }
111
114 if (auto value = dyn_cast_if_present(ofr))
115 return value;
116 auto attr = cast(cast(ofr));
118 }
119
122 if (targetType == value.getType())
123 return value;
124
125 bool targetIsIndex = targetType.isIndex();
127 if (targetIsIndex ^ valueIsIndex)
128 return b.createarith::IndexCastOp(loc, targetType, value);
129
130 auto targetIntegerType = dyn_cast(targetType);
131 auto valueIntegerType = dyn_cast(value.getType());
132 assert(targetIntegerType && valueIntegerType &&
133 "unexpected cast between types other than integers and index");
134 assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
135
136 if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
137 return b.createarith::ExtSIOp(loc, targetIntegerType, value);
138 return b.createarith::TruncIOp(loc, targetIntegerType, value);
139 }
140
142 IntegerType toType, bool isUnsigned) {
143
144 if (isa(operand.getType())) {
145 if (isUnsigned)
146 return b.createarith::FPToUIOp(toType, operand);
147 return b.createarith::FPToSIOp(toType, operand);
148 }
149
151 return b.createarith::IndexCastOp(toType, operand);
152 if (auto fromIntType = dyn_cast(operand.getType())) {
153
154 if (toType.getWidth() > fromIntType.getWidth()) {
155 if (isUnsigned)
156 return b.createarith::ExtUIOp(toType, operand);
157 return b.createarith::ExtSIOp(toType, operand);
158 }
159 if (toType.getWidth() < fromIntType.getWidth())
160 return b.createarith::TruncIOp(toType, operand);
161 return operand;
162 }
163
164 return {};
165 }
166
168 FloatType toType, bool isUnsigned) {
169
170
171 if (isa(operand.getType())) {
172 if (isUnsigned)
173 return b.createarith::UIToFPOp(toType, operand);
174 return b.createarith::SIToFPOp(toType, operand);
175 }
176 if (auto fromFpTy = dyn_cast(operand.getType())) {
177 if (toType.getWidth() > fromFpTy.getWidth())
178 return b.createarith::ExtFOp(toType, operand);
179 if (toType.getWidth() < fromFpTy.getWidth())
180 return b.createarith::TruncFOp(toType, operand);
181 return operand;
182 }
183
184 return {};
185 }
186
188 ComplexType targetType,
189 bool isUnsigned) {
190 if (auto fromComplexType = dyn_cast(operand.getType())) {
191 if (isa(targetType.getElementType()) &&
192 isa(fromComplexType.getElementType())) {
193 Value real = b.createcomplex::ReOp(operand);
194 Value imag = b.createcomplex::ImOp(operand);
195 Type targetETy = targetType.getElementType();
196 if (targetType.getElementType().getIntOrFloatBitWidth() <
197 fromComplexType.getElementType().getIntOrFloatBitWidth()) {
198 real = b.createarith::TruncFOp(targetETy, real);
199 imag = b.createarith::TruncFOp(targetETy, imag);
200 } else {
201 real = b.createarith::ExtFOp(targetETy, real);
202 imag = b.createarith::ExtFOp(targetETy, imag);
203 }
204 return b.createcomplex::CreateOp(targetType, real, imag);
205 }
206 }
207
208 if (isa(operand.getType())) {
209 FloatType toFpTy = cast(targetType.getElementType());
210 auto toBitwidth = toFpTy.getIntOrFloatBitWidth();
211 Value from = operand;
213 from = b.createarith::ExtFOp(toFpTy, from);
214 }
216 from = b.createarith::TruncFOp(toFpTy, from);
217 }
219 mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
220 return b.createcomplex::CreateOp(targetType, from, zero);
221 }
222
223 if (isa(operand.getType())) {
224 FloatType toFpTy = cast(targetType.getElementType());
225 Value from = operand;
226 if (isUnsigned) {
227 from = b.createarith::UIToFPOp(toFpTy, from);
228 } else {
229 from = b.createarith::SIToFPOp(toFpTy, from);
230 }
232 mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
233 return b.createcomplex::CreateOp(targetType, from, zero);
234 }
235
236 return {};
237 }
238
240 Type toType, bool isUnsignedCast) {
241 if (operand.getType() == toType)
242 return operand;
245 if (auto intTy = dyn_cast(toType)) {
247 } else if (auto floatTy = dyn_cast(toType)) {
249 } else if (auto complexTy = dyn_cast(toType)) {
250 result =
252 }
253
254 if (result)
255 return result;
256
257 emitWarning(loc) << "could not cast operand of type " << operand.getType()
258 << " to " << toType;
259 return operand;
260 }
261
265 return llvm::to_vector<4>(
266 llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
268 }));
269 }
270
272 Type type, const APInt &value) {
273 TypedAttr attr;
274 if (isa(type)) {
276 } else {
277 auto vecTy = cast(type);
279 }
280
281 return builder.createarith::ConstantOp(loc, attr);
282 }
283
285 Type type, int64_t value) {
286 unsigned elementBitWidth = 0;
287 if (auto intTy = dyn_cast(type))
288 elementBitWidth = intTy.getWidth();
289 else
290 elementBitWidth = cast(type).getElementTypeBitWidth();
291
293 APInt(elementBitWidth, value));
294 }
295
297 Type type, const APFloat &value) {
298 if (isa(type))
299 return builder.createOrFoldarith::ConstantOp(
300 loc, type, builder.getFloatAttr(type, value));
302 return builder.createOrFoldarith::ConstantOp(loc, type, splat);
303 }
304
306 if (auto value = dyn_cast_if_present(ofr))
307 return value.getType();
308 auto attr = cast(cast(ofr));
309 return attr.getType();
310 }
311
313 return b.createarith::AndIOp(loc, lhs, rhs);
314 }
316 if (isa(lhs.getType()))
317 return b.createarith::AddFOp(loc, lhs, rhs);
318 return b.createarith::AddIOp(loc, lhs, rhs, ovf);
319 }
321 if (isa(lhs.getType()))
322 return b.createarith::SubFOp(loc, lhs, rhs);
323 return b.createarith::SubIOp(loc, lhs, rhs, ovf);
324 }
326 if (isa(lhs.getType()))
327 return b.createarith::MulFOp(loc, lhs, rhs);
328 return b.createarith::MulIOp(loc, lhs, rhs, ovf);
329 }
331 if (isa(lhs.getType()))
332 return b.createarith::CmpFOp(loc, arith::CmpFPredicate::OGT, lhs, rhs);
333 return b.createarith::CmpIOp(loc, arith::CmpIPredicate::sgt, lhs, rhs);
334 }
336 if (isa(lhs.getType()))
337 return b.createarith::CmpFOp(loc, arith::CmpFPredicate::OLT, lhs, rhs);
338 return b.createarith::CmpIOp(loc, arith::CmpIPredicate::slt, lhs, rhs);
339 }
341 return b.createarith::SelectOp(loc, cmp, lhs, rhs);
342 }
343
345
347 return createProduct(builder, loc, values, values.front().getType());
348 }
349
351 Type resultType) {
352 Value one = builder.create(loc, resultType,
355 return std::accumulate(
356 values.begin(), values.end(), one,
357 [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
358 }
359
360
364 .Case("f4E2M1FN", b.getType())
365 .Case("f6E2M3FN", b.getType())
366 .Case("f6E3M2FN", b.getType())
367 .Case("f8E5M2", b.getType())
368 .Case("f8E4M3", b.getType())
369 .Case("f8E4M3FN", b.getType())
370 .Case("f8E5M2FNUZ", b.getType())
371 .Case("f8E4M3FNUZ", b.getType())
372 .Case("f8E3M4", b.getType())
373 .Case("f8E8M0FNU", b.getType())
374 .Case("bf16", b.getType())
375 .Case("f16", b.getType())
376 .Case("f32", b.getType())
377 .Case("f64", b.getType())
378 .Case("f80", b.getType())
379 .Case("f128", b.getType())
380 .Default(std::nullopt);
381 }
382
383 }
static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand, ComplexType targetType, bool isUnsigned)
static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand, IntegerType toType, bool isUnsigned)
static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand, FloatType toType, bool isUnsigned)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
TypedAttr getOneAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Specialization of arith.constant op that returns a floating point value.
Specialization of arith.constant op that returns an integer of index type.
std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
Value createProduct(OpBuilder &builder, Location loc, ArrayRef< Value > values)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::op_matcher< arith::ConstantIndexOp > matchConstantIndex()
Matches a ConstantIndexOp.
Helper struct to build simple arithmetic quantities with minimal type inference support.
Value mul(Value lhs, Value rhs)
Value _and(Value lhs, Value rhs)
Value slt(Value lhs, Value rhs)
Value select(Value cmp, Value lhs, Value rhs)
Value add(Value lhs, Value rhs)
Value sgt(Value lhs, Value rhs)
Value sub(Value lhs, Value rhs)
The matcher that matches a certain kind of op.