MLIR: lib/Dialect/Arith/Transforms/ExpandOps.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
16
17 namespace mlir {
18 namespace arith {
19 #define GEN_PASS_DEF_ARITHEXPANDOPSPASS
20 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
21 }
22 }
23
24 using namespace mlir;
25
26
// Tail of a constant-building helper. Its signature is not visible in this
// listing; presumably this is createConst(loc, type, value, rewriter) —
// TODO confirm against the upstream file.
// If the dyn_cast of `type` succeeds, the constant is produced by the
// ConstantOp call on the next line (its arguments are elided here);
// otherwise a ConstantOp is created directly from `attr` at `loc`.
30 if (auto shapedTy = dyn_cast(type)) {
31 return rewriter.createarith::ConstantOp(
33 }
34 return rewriter.createarith::ConstantOp(loc, attr);
35 }
36
37
// Tail of a type helper whose signature is not visible in this listing;
// presumably cloneToShapedType(cloneFrom, cloneTo) — TODO confirm.
// When the dyn_cast of `cloneFrom` succeeds, the result is
// shapedTy.clone(cloneTo); otherwise `cloneTo` is returned unchanged.
39 if (auto shapedTy = dyn_cast(cloneFrom)) {
40 return shapedTy.clone(cloneTo);
41 }
42 return cloneTo;
43 }
44
45 namespace {
46
47
48
// Rewrite pattern that expands arith::CeilDivUIOp into simpler arith ops.
// From the visible statements the replacement is
//   select(compare, zero, ((a - one) divui b) + one)
// where `compare` is presumably the result of the `a == zero` CmpIOp below.
// NOTE(review): this listing drops some original lines (the embedded line
// numbers skip), so the rest of the matchAndRewrite signature and the
// definitions of `loc`, `zero`, `one` and `compare` are not visible here.
49 struct CeilDivUIOpConverter : public OpRewritePatternarith::CeilDivUIOp {
51 LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
54 Value a = op.getLhs();
55 Value b = op.getRhs();
// Equality test of the dividend against `zero`.
58 rewriter.createarith::CmpIOp(loc, arith::CmpIPredicate::eq, a, zero);
// ((a - one) divui b) + one, used when the select condition is false.
60 Value minusOne = rewriter.createarith::SubIOp(loc, a, one);
61 Value quotient = rewriter.createarith::DivUIOp(loc, minusOne, b);
62 Value plusOne = rewriter.createarith::AddIOp(loc, quotient, one);
// Replace the op: `zero` when `compare` holds, otherwise `plusOne`.
63 rewriter.replaceOpWithNewOparith::SelectOp(op, compare, zero, plusOne);
64 return success();
65 }
66 };
67
68
69
70
71
72
73
74
// Rewrite pattern that expands arith::CeilDivSIOp into simpler arith ops.
// Visible computation:
//   quotient        = a divsi b
//   notEqualDivisor = (a != quotient * b)
//   signEqual       = (a < zero) == (b < zero)
//   result          = select(cond, quotient + one, quotient)
// where `cond` is presumably the AndIOp of notEqualDivisor and signEqual.
// NOTE(review): some original lines are missing from this listing, so the
// definitions of `loc`, `zero`, `one`, `aNeg`, `bNeg` and `cond` are not
// visible here.
75 struct CeilDivSIOpConverter : public OpRewritePatternarith::CeilDivSIOp {
77 LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
80 Type type = op.getType();
81 Value a = op.getLhs();
82 Value b = op.getRhs();
83
86
// Signed quotient, then check whether multiplying back reproduces `a`.
87 Value quotient = rewriter.createarith::DivSIOp(loc, a, b);
88 Value product = rewriter.createarith::MulIOp(loc, quotient, b);
89 Value notEqualDivisor = rewriter.createarith::CmpIOp(
90 loc, arith::CmpIPredicate::ne, a, product);
91
// Signed less-than-zero tests on each operand (presumably aNeg / bNeg).
93 rewriter.createarith::CmpIOp(loc, arith::CmpIPredicate::slt, a, zero);
95 rewriter.createarith::CmpIOp(loc, arith::CmpIPredicate::slt, b, zero);
96
// True when both less-than-zero tests agree.
97 Value signEqual = rewriter.createarith::CmpIOp(
98 loc, arith::CmpIPredicate::eq, aNeg, bNeg);
100 rewriter.createarith::AndIOp(loc, notEqualDivisor, signEqual);
101
102 Value quotientPlusOne = rewriter.createarith::AddIOp(loc, quotient, one);
103
// Replace the op: quotient + one when `cond` holds, else the plain quotient.
104 rewriter.replaceOpWithNewOparith::SelectOp(op, cond, quotientPlusOne,
105 quotient);
106 return success();
107 }
108 };
109
110
111
112
113
114
115
116
// Rewrite pattern that expands arith::FloorDivSIOp into simpler arith ops.
// Visible computation:
//   quotient        = a divsi b
//   notEqualDivisor = (a != quotient * b)
//   signOpposite    = (a < zero) != (b < zero)
//   result          = select(cond, quotient + minusOne, quotient)
// where `cond` is presumably the AndIOp of notEqualDivisor and signOpposite.
// NOTE(review): some original lines are missing from this listing, so the
// definitions of `loc`, `zero`, `minusOne`, `aNeg`, `bNeg` and `cond` are not
// visible here.
117 struct FloorDivSIOpConverter : public OpRewritePatternarith::FloorDivSIOp {
119 LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
122 Type type = op.getType();
123 Value a = op.getLhs();
124 Value b = op.getRhs();
125
// Signed quotient, then check whether multiplying back reproduces `a`.
126 Value quotient = rewriter.createarith::DivSIOp(loc, a, b);
127 Value product = rewriter.createarith::MulIOp(loc, quotient, b);
128 Value notEqualDivisor = rewriter.createarith::CmpIOp(
129 loc, arith::CmpIPredicate::ne, a, product);
131
// Signed less-than-zero tests on each operand (presumably aNeg / bNeg).
133 rewriter.createarith::CmpIOp(loc, arith::CmpIPredicate::slt, a, zero);
135 rewriter.createarith::CmpIOp(loc, arith::CmpIPredicate::slt, b, zero);
136
// True when the two less-than-zero tests disagree.
137 Value signOpposite = rewriter.createarith::CmpIOp(
138 loc, arith::CmpIPredicate::ne, aNeg, bNeg);
140 rewriter.createarith::AndIOp(loc, notEqualDivisor, signOpposite);
141
// The adjustment is expressed as an addition of `minusOne`.
143 Value quotientMinusOne =
144 rewriter.createarith::AddIOp(loc, quotient, minusOne);
145
// Replace the op: adjusted quotient when `cond` holds, else the plain one.
146 rewriter.replaceOpWithNewOparith::SelectOp(op, cond, quotientMinusOne,
147 quotient);
148 return success();
149 }
150 };
151
// Generic integer max/min expansion, parameterised on the op type `OpTy` and
// the integer comparison predicate `pred`. The op is replaced by
//   select(cmpi(pred, lhs, rhs), lhs, rhs).
// NOTE(review): the struct declaration line is missing from this listing;
// presumably this is MaxMinIOpConverter, the only template registered
// elsewhere in the file with a CmpIPredicate argument — TODO confirm.
152 template <typename OpTy, arith::CmpIPredicate pred>
154 public:
156
157 LogicalResult matchAndRewrite(OpTy op,
159 Value lhs = op.getLhs();
160 Value rhs = op.getRhs();
161
// Compare the operands with the template predicate, then pick lhs when the
// comparison holds and rhs otherwise.
162 Value cmp = rewriter.createarith::CmpIOp(op.getLoc(), pred, lhs, rhs);
163 rewriter.replaceOpWithNewOparith::SelectOp(op, cmp, lhs, rhs);
164 return success();
165 }
166 };
167
// Floating-point maximum/minimum expansion, parameterised on the op type
// `OpTy` and a float comparison predicate `pred` (statically required to be
// UGT or ULT). The op is replaced by
//   select(isNaN(rhs), rhs, select(cmpf(pred, lhs, rhs), lhs, rhs))
// so rhs is the result whenever the UNO(rhs, rhs) check holds.
// NOTE(review): UGT/ULT are presumably the "unordered or greater/less"
// predicates, in which case the inner select already yields lhs when lhs is
// NaN — confirm against the CmpFPredicate documentation.
// NOTE(review): the definition of `loc` and the tail of the signature are
// missing from this listing.
168 template <typename OpTy, arith::CmpFPredicate pred>
169 struct MaximumMinimumFOpConverter : public OpRewritePattern {
170 public:
172
173 LogicalResult matchAndRewrite(OpTy op,
175 Value lhs = op.getLhs();
176 Value rhs = op.getRhs();
177
179
// Compile-time guard: only these two predicates are accepted.
180 static_assert(pred == arith::CmpFPredicate::UGT ||
181 pred == arith::CmpFPredicate::ULT,
182 "pred must be either UGT or ULT");
183 Value cmp = rewriter.createarith::CmpFOp(loc, pred, lhs, rhs);
184 Value select = rewriter.createarith::SelectOp(loc, cmp, lhs, rhs);
185
186
// Self-comparison of rhs with the UNO predicate, used as the NaN test.
187 Value isNaN = rewriter.createarith::CmpFOp(loc, arith::CmpFPredicate::UNO,
188 rhs, rhs);
189 rewriter.replaceOpWithNewOparith::SelectOp(op, isNaN, rhs, select);
190 return success();
191 }
192 };
193
// Floating-point maxnum/minnum expansion, parameterised on the op type
// `OpTy` and a float comparison predicate `pred` (statically required to be
// UGT or ULT). The op is replaced by
//   select(isNaN(lhs), rhs, select(cmpf(pred, lhs, rhs), lhs, rhs))
// so rhs is the result whenever the UNO(lhs, lhs) check holds. Note that the
// NaN test here is on lhs, unlike MaximumMinimumFOpConverter which tests rhs.
// NOTE(review): the definition of `loc` and the tail of the signature are
// missing from this listing.
194 template <typename OpTy, arith::CmpFPredicate pred>
195 struct MaxNumMinNumFOpConverter : public OpRewritePattern {
196 public:
198
199 LogicalResult matchAndRewrite(OpTy op,
201 Value lhs = op.getLhs();
202 Value rhs = op.getRhs();
203
205
// Compile-time guard: only these two predicates are accepted.
206 static_assert(pred == arith::CmpFPredicate::UGT ||
207 pred == arith::CmpFPredicate::ULT,
208 "pred must be either UGT or ULT");
209 Value cmp = rewriter.createarith::CmpFOp(loc, pred, lhs, rhs);
210 Value select = rewriter.createarith::SelectOp(loc, cmp, lhs, rhs);
211
212
// Self-comparison of lhs with the UNO predicate, used as the NaN test.
213 Value isNaN = rewriter.createarith::CmpFOp(loc, arith::CmpFPredicate::UNO,
214 lhs, lhs);
215 rewriter.replaceOpWithNewOparith::SelectOp(op, isNaN, rhs, select);
216 return success();
217 }
218 };
219
// Rewrite pattern that expands an arith::ExtFOp from bf16 to f32 into integer
// bit manipulation: bitcast the operand to `i16Ty`, zero-extend to `i32Ty`,
// shift left by `c16`, and bitcast the shifted value to the result type.
// The pattern reports a match failure for any other type combination.
// NOTE(review): this listing omits the lines defining `b`, `operandETy`,
// `resultETy`, `i16Ty`, `i32Ty` and `c16`; `c16` is presumably the constant
// 16 — TODO confirm.
220 struct BFloat16ExtFOpConverter : public OpRewritePatternarith::ExtFOp {
222 LogicalResult matchAndRewrite(arith::ExtFOp op,
225 auto operand = op.getOperand();
226 Type operandTy = operand.getType();
227 Type resultTy = op.getType();
230
// Only the bf16 -> f32 case is handled by this pattern.
231 if (!operandETy.isBF16() || !resultETy.isF32()) {
232 return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
233 }
234
237
238 Value bitcast = b.createarith::BitcastOp(i16Ty, operand);
239 Value exti = b.createarith::ExtUIOp(i32Ty, bitcast);
240
// Shift the extended bits left by c16 before reinterpreting as the result.
242 Value shl = b.createarith::ShLIOp(exti, c16);
243 Value result = b.createarith::BitcastOp(resultTy, shl);
244
245 rewriter.replaceOp(op, result);
246 return success();
247 }
248 };
249
// Rewrite pattern that expands an arith::TruncFOp from f32 to bf16 into
// integer bit manipulation. It bails out (match failure) for other type
// combinations and when an explicit rounding-mode attribute is present.
// Visible computation on the i32 bit pattern of the operand:
//   roundingBias = ((bits >> c16) & c1) + 0x7fff
//   normal       = trunci_i16((bits + roundingBias) >> c16)
//   result       = bitcast(select(isNan, 0x7fc0, normal))
// NOTE(review): adding 0x7fff plus the lowest kept bit before truncating
// looks like the usual round-to-nearest-even bias — confirm against the
// upstream comments, which are not present in this listing.
// NOTE(review): the lines defining `b`, `operandETy`, `resultETy`, `i16Ty`,
// `i32Ty`, `c1`, `c16`, `isNan`, `bit16` and `select` are missing here.
250 struct BFloat16TruncFOpConverter : public OpRewritePatternarith::TruncFOp {
252 LogicalResult matchAndRewrite(arith::TruncFOp op,
255 auto operand = op.getOperand();
256 Type operandTy = operand.getType();
257 Type resultTy = op.getType();
260
// Only the f32 -> bf16 case is handled by this pattern.
261 if (!operandETy.isF32() || !resultETy.isBF16()) {
262 return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
263 }
264
// An explicit rounding mode is not supported by this expansion.
265 if (op.getRoundingmodeAttr()) {
266 return rewriter.notifyMatchFailure(
267 op, "only applicable to default rounding mode.");
268 }
269
272
273
274
275
276
277
278
279
280
281
282
283
284
285
// Self-comparison with UNE; presumably bound to `isNan`.
287 b.createarith::CmpFOp(arith::CmpFPredicate::UNE, operand, operand);
288
289 Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
290
// i16 value selected further down when `isNan` holds.
291 Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
292
295
296 Value bitcast = b.createarith::BitcastOp(i32Ty, operand);
297
// (bits >> c16) & c1; presumably bound to `bit16`.
299 b.createarith::AndIOp(b.createarith::ShRUIOp(bitcast, c16), c1);
300
301
302 Value roundingBias = b.createarith::AddIOp(bit16, c7FFF);
303
304
305
306
307
308
309 Value biased = b.createarith::AddIOp(bitcast, roundingBias);
310
311
// Drop the low bits after biasing and narrow to i16.
312 Value biasedAndShifted = b.createarith::ShRUIOp(biased, c16);
313 Value normalCaseResultI16 =
314 b.createarith::TruncIOp(i16Ty, biasedAndShifted);
315
316
// Choose 0x7fc0 when `isNan` holds; presumably bound to `select`.
318 b.createarith::SelectOp(isNan, c7FC0I16, normalCaseResultI16);
319 Value result = b.createarith::BitcastOp(resultTy, select);
320 rewriter.replaceOp(op, result);
321 return success();
322 }
323 };
324
// Rewrite pattern that expands an arith::ExtFOp whose operand element type is
// F8E8M0FNU (per the failure message; the isa<> template argument is elided
// in this listing). Visible computation:
//   bits8   = bitcast operand to i8Ty
//   f32Bits = extui(bits8) << 23
//   f32Bits = select(bits8 == 0xff, 0xffffffff, f32Bits)
//   result  = bitcast f32Bits to f32Ty
// followed by a TruncFOp or ExtFOp to the op's result type.
// NOTE(review): the conditions guarding the final TruncFOp / ExtFOp, and the
// definitions of `b`, `operandETy`, `i8Ty`, `i32Ty`, `f32Ty` and `isNan`,
// are missing from this listing.
325 struct F8E8M0ExtFOpConverter : public OpRewritePatternarith::ExtFOp {
327 LogicalResult matchAndRewrite(arith::ExtFOp op,
330 Value operand = op.getOperand();
332 Type resultTy = op.getType();
335
336 if (!llvm::isa(operandETy)) {
337 return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
338 }
339
343
344 Value bitcast = b.createarith::BitcastOp(i8Ty, operand);
345
// 0xff is the i8 pattern compared against below; 0xffffffff is the i32
// pattern selected when that comparison holds.
346 Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
347 Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
348 Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
349
// Zero-extend the 8 bits and shift them left by 23.
350 Value exti = b.createarith::ExtUIOp(i32Ty, bitcast);
351 Value f32Bits = b.createarith::ShLIOp(exti, cF32MantissaWidth);
352
// Equality test against cF8NaN; presumably bound to `isNan`.
354 b.createarith::CmpIOp(arith::CmpIPredicate::eq, bitcast, cF8NaN);
355
356 f32Bits = b.createarith::SelectOp(isNan, cF32NaN, f32Bits);
357 Value result = b.createarith::BitcastOp(f32Ty, f32Bits);
// Convert the f32 value to the requested result type, forwarding the op's
// fastmath attribute (guarding conditions elided in this listing).
359 result = b.createarith::TruncFOp(resultTy, result, nullptr,
360 op.getFastmathAttr());
362 result = b.createarith::ExtFOp(resultTy, result, op.getFastmathAttr());
363 }
364 rewriter.replaceOp(op, result);
365 return success();
366 }
367 };
368
369
370
371
372
373
// Rewrite pattern that expands an arith::TruncFOp whose result element type
// is F8E8M0FNU (per the failure message; the isa<> template argument is
// elided in this listing). It bails out when an explicit rounding-mode
// attribute is present. Visible computation:
//   operand  = ExtFOp or TruncFOp of the operand to f32Ty (conditions elided)
//   f32Bits  = bitcast operand to i32Ty
//   exp8Bits = trunci_i8(f32Bits >> 23)
//   result   = bitcast exp8Bits to the result type
// NOTE(review): the definitions of `b`, `resultETy`, `i8Ty`, `i32Ty` and
// `f32Ty`, and the `if` heads around the ExtFOp/TruncFOp, are missing here.
374 struct F8E8M0TruncFOpConverter : public OpRewritePatternarith::TruncFOp {
376 LogicalResult matchAndRewrite(arith::TruncFOp op,
379 Value operand = op.getOperand();
382 Type resultTy = op.getType();
384 if (!llvm::isa(resultETy)) {
385 return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
386 }
387
// An explicit rounding mode is not supported by this expansion.
388 if (op.getRoundingmodeAttr()) {
389 return rewriter.notifyMatchFailure(
390 op, "only applicable to default rounding mode.");
391 }
392
396
// Bring the operand to f32Ty first (guarding conditions elided).
398 operand = b.createarith::ExtFOp(f32Ty, operand, op.getFastmathAttr());
400 operand = b.createarith::TruncFOp(
401 f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
402 }
// Shift the bit pattern right by 23 and keep the low bits that fit i8Ty.
403 Value f32Bits = b.createarith::BitcastOp(i32Ty, operand);
404 Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
405 Value f32SignExp = b.createarith::ShRUIOp(f32Bits, cF32MantissaWidth);
406 Value exp8Bits = b.createarith::TruncIOp(i8Ty, f32SignExp);
407 Value result = b.createarith::BitcastOp(resultTy, exp8Bits);
408 rewriter.replaceOp(op, result);
409 return success();
410 }
411 };
412
// Rewrite pattern that expands arith::ScalingExtFOp. Visible computation:
//   - optionally TruncFOp the scale operand to `scaleTy`, with `scaleETy`
//     set to b.getF8E8M0Type() (the guarding condition is elided here);
//   - fail the match if `scaleETy` does not pass the isa<> check (template
//     argument elided; the message names f8E8M0FNU);
//   - ExtFOp both the scale and the input to the op's result type;
//   - replace the op with MulFOp(inputExt, scaleExt).
// The op's fastmath attribute is forwarded to every created op.
// NOTE(review): the definitions of `b`, `scaleTy`, `scaleETy`, `scaleExt`,
// `inputExt` and `result` are missing from this listing.
413 struct ScalingExtFOpConverter : public OpRewritePatternarith::ScalingExtFOp {
415 LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
418 Value inputOperand = op.getIn();
419 Value scaleOperand = op.getScale();
422
424 scaleETy = b.getF8E8M0Type();
426 scaleOperand = b.createarith::TruncFOp(scaleTy, scaleOperand, nullptr,
427 op.getFastmathAttr());
428 }
429 if (!llvm::isa(scaleETy)) {
430 return rewriter.notifyMatchFailure(
431 op, "scaling_extf is using scales of type which can not be converted "
432 "to f8E8M0FNU");
433 }
434 Type resultTy = op.getType();
435
436
// Extend scale and input to the result type (presumably bound to scaleExt
// and inputExt), then multiply them.
438 b.createarith::ExtFOp(resultTy, scaleOperand, op.getFastmathAttr());
440 b.createarith::ExtFOp(resultTy, inputOperand, op.getFastmathAttr());
442 b.createarith::MulFOp(inputExt, scaleExt, op.getFastmathAttr());
443 rewriter.replaceOp(op, result);
444 return success();
445 }
446 };
447
448
449
450
451
452
// Rewrite pattern that expands arith::ScalingTruncFOp. Visible computation:
//   - optionally TruncFOp the scale operand to `scaleTy`, with `scaleETy`
//     set to b.getF8E8M0Type() (the guarding condition is elided here);
//   - fail the match if `scaleETy` does not pass the isa<> check (template
//     argument elided; the message names f8E8M0FNU);
//   - ExtFOp the scale to `inputTy`;
//   - DivFOp(input, scale), then TruncFOp the quotient to the result type
//     using the op's rounding-mode and fastmath attributes.
// NOTE(review): the base-class line and the definitions of `b`, `scaleTy`,
// `scaleETy` and `inputTy` are missing from this listing.
453 struct ScalingTruncFOpConverter
456 LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
459 Value inputOperand = op.getIn();
460 Value scaleOperand = op.getScale();
463
465 scaleETy = b.getF8E8M0Type();
467 scaleOperand = b.createarith::TruncFOp(scaleTy, scaleOperand, nullptr,
468 op.getFastmathAttr());
469 }
470 if (!llvm::isa(scaleETy)) {
471 return rewriter.notifyMatchFailure(
472 op, "scaling_truncf is using scales type which can not be converted "
473 "to f8E8M0FNU");
474 }
475 Type resultTy = op.getType();
477
478
// Bring the scale to the input's type so the division is well-typed.
479 scaleOperand =
480 b.createarith::ExtFOp(inputTy, scaleOperand, op.getFastmathAttr());
481 Value result = b.createarith::DivFOp(inputOperand, scaleOperand,
482 op.getFastmathAttr());
// Final narrowing to the requested result type.
483 Value resultCast = b.createarith::TruncFOp(
484 resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
485 rewriter.replaceOp(op, resultCast);
486 return success();
487 }
488 };
489
// Pass that drives the expansions in this file. runOnOperation() configures
// a conversion `target`:
//   - the arith dialect is legal by default;
//   - the ceil/floor division, integer and float max/min, and scaling
//     ext/trunc ops listed below are illegal, so they must be rewritten;
//   - arith::ExtFOp and arith::TruncFOp are dynamically legal, depending on
//     the `includeBf16` / `includeF8E8M0` flags and the element types.
// NOTE(review): this listing omits the lines that build the pattern set and
// `target`, the bodies of the two `if (include...)` blocks, the definitions
// of `inETy` / `outETy`, and the condition guarding signalPassFailure()
// (presumably a failed applyPartialConversion — TODO confirm).
490 struct ArithExpandOpsPass
491 : public arith::impl::ArithExpandOpsPassBase {
492 using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
493
494 void runOnOperation() override {
497
499
500 target.addLegalDialectarith::ArithDialect();
501
// Ops that must not survive the conversion.
502 target.addIllegalOp<
503 arith::CeilDivSIOp,
504 arith::CeilDivUIOp,
505 arith::FloorDivSIOp,
506 arith::MaxSIOp,
507 arith::MaxUIOp,
508 arith::MinSIOp,
509 arith::MinUIOp,
510 arith::MaximumFOp,
511 arith::MinimumFOp,
512 arith::MaxNumFOp,
513 arith::MinNumFOp,
514 arith::ScalingExtFOp,
515 arith::ScalingTruncFOp
516 >();
517
518 if (includeBf16) {
520 }
521 if (includeF8E8M0) {
523 }
524
// ExtFOp stays legal unless an enabled expansion covers its types:
// bf16 -> f32 when includeBf16 is set, or an input element type matching
// the (elided) isa<> check when includeF8E8M0 is set.
525 target.addDynamicallyLegalOparith::ExtFOp(
526 [=](arith::ExtFOp op) {
529 bool legalTypes = true;
530 if (includeBf16)
531 legalTypes &= !(inETy.isBF16() && outETy.isF32());
532 if (includeF8E8M0)
533 legalTypes &= !llvm::isa(inETy);
534 return legalTypes;
535 });
536
// TruncFOp stays legal unless an enabled expansion covers its types:
// f32 -> bf16 when includeBf16 is set, or an output element type matching
// the (elided) isa<> check when includeF8E8M0 is set.
537 target.addDynamicallyLegalOparith::TruncFOp(
538 [=](arith::TruncFOp op) {
541 bool legalTypes = true;
542 if (includeBf16)
543 legalTypes &= !(inETy.isF32() && outETy.isBF16());
544 if (includeF8E8M0)
545 legalTypes &= !(llvm::isa(outETy));
546 return legalTypes;
547 });
548
549
552 signalPassFailure();
553 }
554 };
555
556 }
557
// Registers the three ceil/floor division converters on a pattern set.
// NOTE(review): the enclosing function header and the call's arguments are
// missing from this listing; presumably this is the body of
// populateCeilFloorDivExpandOpsPatterns — TODO confirm.
561 .add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
563 }
564
// Registers the bf16 ExtF/TruncF converters on `patterns`.
// NOTE(review): the enclosing function header and the call's arguments are
// missing from this listing; presumably this is the body of
// populateExpandBFloat16Patterns — TODO confirm.
566 patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
568 }
569
// Registers the F8E8M0 ExtF/TruncF converters on `patterns`.
// NOTE(review): the enclosing function header and the call's arguments are
// missing from this listing; presumably this is the body of
// populateExpandF8E8M0Patterns — TODO confirm.
571 patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
573 }
574
// Registers the scaling ExtF/TruncF converters on `patterns`.
// NOTE(review): the enclosing function header and the call's arguments are
// missing from this listing; presumably this is the body of
// populateExpandScalingExtTruncPatterns — TODO confirm.
577 patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(
579 }
580
584
// Template-argument list of a pattern registration: each max/min op is paired
// with the comparison predicate its converter is instantiated with
// (sgt/ugt/slt/ult for the integer ops, UGT/ULT for the float ops).
// NOTE(review): the enclosing function header and the surrounding
// `patterns.add<...>(...)` call are missing from this listing; presumably
// this is part of populateArithExpandOpsPatterns — TODO confirm.
586 MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
587 MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
588 MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
589 MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
590 MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
591 MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
592 MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
593 MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
595
596 }
static Value createConst(Location loc, Type type, int value, PatternRewriter &rewriter)
Create an integer or index constant.
static Type cloneToShapedType(Type cloneFrom, Type cloneTo)
Creates shapedType using shape from cloneFrom and base type from cloneTo.
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getIntegerAttr(Type type, int64_t value)
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateExpandBFloat16Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns)
Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops.
void populateExpandF8E8M0Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
void populateArithExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ops.
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...