MLIR: lib/Dialect/Math/Transforms/ExpandPatterns.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
22 #include "llvm/ADT/APFloat.h"
23
24 using namespace mlir;
25
26
29 bool losesInfo = false;
31
32 value.convert(cast(eltType).getFloatSemantics(),
33 APFloat::rmNearestTiesToEven, &losesInfo);
35 if (auto shapedTy = dyn_cast(type)) {
36 return b.createarith::ConstantOp(loc,
38 }
39
40 return b.createarith::ConstantOp(loc, attr);
41 }
42
46 }
47
48
52 if (auto shapedTy = dyn_cast(type)) {
53 return b.createarith::ConstantOp(loc,
55 }
56
57 return b.createarith::ConstantOp(loc, attr);
58 }
59
63 if (auto shapedTy = dyn_cast(opType))
64 i64Ty = shapedTy.clone(i64Ty);
65 Value fixedConvert = b.createarith::FPToSIOp(i64Ty, operand);
66 Value fpFixedConvert = b.createarith::SIToFPOp(opType, fixedConvert);
67
68
69 return b.createmath::CopySignOp(fpFixedConvert, operand);
70 }
71
72
75 Value operand = op.getOperand();
77
78 Value exp = b.createmath::ExpOp(operand);
79 Value neg = b.createarith::NegFOp(operand);
81 Value sub = b.createarith::SubFOp(exp, nexp);
83 Value res = b.createarith::MulFOp(sub, half);
85 return success();
86 }
87
88
91 Value operand = op.getOperand();
93
94 Value exp = b.createmath::ExpOp(operand);
95 Value neg = b.createarith::NegFOp(operand);
97 Value add = b.createarith::AddFOp(exp, nexp);
99 Value res = b.createarith::MulFOp(add, half);
101 return success();
102 }
103
104
105
106
107
108
109
110
111
113 auto floatType = op.getOperand().getType();
118
119
120 Value isNegative = rewriter.createarith::CmpFOp(
121 loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
122 Value isNegativeFloat =
123 rewriter.createarith::UIToFPOp(loc, floatType, isNegative);
124 Value isNegativeTimesNegTwo =
125 rewriter.createarith::MulFOp(loc, isNegativeFloat, negTwo);
126 Value sign = rewriter.createarith::AddFOp(loc, isNegativeTimesNegTwo, one);
127
128
129 Value positiveX = rewriter.createarith::MulFOp(loc, sign, op.getOperand());
130
131
132 Value negDoubledX = rewriter.createarith::MulFOp(loc, negTwo, positiveX);
133 Value exp2x = rewriter.createmath::ExpOp(loc, negDoubledX);
134 Value dividend = rewriter.createarith::SubFOp(loc, one, exp2x);
135 Value divisor = rewriter.createarith::AddFOp(loc, one, exp2x);
136 Value positiveRes = rewriter.createarith::DivFOp(loc, dividend, divisor);
137
138
140
141 return success();
142 }
143
144
147 Value operand = op.getOperand();
149 Value sin = b.createmath::SinOp(type, operand);
150 Value cos = b.createmath::CosOp(type, operand);
151 Value div = b.createarith::DivFOp(type, sin, cos);
153 return success();
154 }
155
156
160 Value operand = op.getOperand();
162
164 Value fma = b.createmath::FmaOp(operand, operand, one);
165 Value sqrt = b.createmath::SqrtOp(fma);
166 Value add = b.createarith::AddFOp(operand, sqrt);
169 return success();
170 }
171
172
176 Value operand = op.getOperand();
178
180 Value fma = b.createmath::FmaOp(operand, operand, negOne);
181 Value sqrt = b.createmath::SqrtOp(fma);
182 Value add = b.createarith::AddFOp(operand, sqrt);
185 return success();
186 }
187
188
192 Value operand = op.getOperand();
194
196 Value add = b.createarith::AddFOp(operand, one);
197 Value neg = b.createarith::NegFOp(operand);
198 Value sub = b.createarith::AddFOp(neg, one);
199 Value div = b.createarith::DivFOp(add, sub);
202 Value res = b.createarith::MulFOp(log, half);
204 return success();
205 }
206
209 Value operandA = op.getOperand(0);
210 Value operandB = op.getOperand(1);
211 Value operandC = op.getOperand(2);
212 Type type = op.getType();
213 Value mult = b.createarith::MulFOp(type, operandA, operandB);
214 Value add = b.createarith::AddFOp(type, mult, operandC);
216 return success();
217 }
218
219
220
221
222
223
225
226 auto shapedType = dyn_cast(op.getType());
227 if (shapedType && !shapedType.hasStaticShape())
228 return failure();
229
231 Value operand = op.getOperand();
234
235
238
239 Value gtCheck = b.createarith::CmpFOp(arith::CmpFPredicate::OGT, operand,
240 fpFixedConvert);
241 Value incrValue = b.createarith::SelectOp(op->getLoc(), gtCheck, one, zero);
242
243 Value ret = b.createarith::AddFOp(opType, fpFixedConvert, incrValue);
245 return success();
246 }
247
248
249
250
251
255 Value base = op.getOperand(0);
256 Value power = op.getOperand(1);
258
259 auto convertFPowItoPowf = [&]() -> LogicalResult {
260 Value castPowerToFp =
261 rewriter.createarith::SIToFPOp(op.getLoc(), baseType, power);
262 Value res = rewriter.createmath::PowFOp(op.getLoc(), baseType, base,
263 castPowerToFp);
265 return success();
266 };
267
270 return convertFPowItoPowf();
271
272 APInt value;
274 return convertFPowItoPowf();
275
276 int64_t powerInt = value.getSExtValue();
277 bool isNegative = powerInt < 0;
278 int64_t absPower = std::abs(powerInt);
281
282 while (absPower > 0) {
283 if (absPower & 1)
284 res = b.createarith::MulFOp(baseType, base, res);
285 absPower >>= 1;
286 base = b.createarith::MulFOp(baseType, base, base);
287 }
288
289
290 if (isNegative) {
292 .getFloatSemantics();
299 Value posInfinity =
301 APFloat::getInf(sem, false), rewriter);
302 Value negInfinity =
304 APFloat::getInf(sem, true), rewriter);
305 Value zeroEqCheck =
306 b.createarith::CmpFOp(arith::CmpFPredicate::OEQ, res, zero);
307 Value negZeroEqCheck =
308 b.createarith::CmpFOp(arith::CmpFPredicate::OEQ, res, negZero);
309 res = b.createarith::DivFOp(baseType, one, res);
310 res =
311 b.createarith::SelectOp(op->getLoc(), zeroEqCheck, posInfinity, res);
312 res = b.createarith::SelectOp(op->getLoc(), negZeroEqCheck, negInfinity,
313 res);
314 }
315
317 return success();
318 }
319
320
321
322
325 Value operandA = op.getOperand(0);
326 Value operandB = op.getOperand(1);
327 auto typeA = operandA.getType();
328 auto typeB = operandB.getType();
329
330 auto &sem =
332 APFloat valueB(sem);
334 return b.createarith::MulFOp(x, y);
335 };
337 if (valueB.isZero()) {
338
341 return success();
342 }
343 if (valueB.isExactlyValue(1.0)) {
344
345 rewriter.replaceOp(op, operandA);
346 return success();
347 }
348 if (valueB.isExactlyValue(-1.0)) {
349
351 Value div = b.createarith::DivFOp(one, operandA);
353 return success();
354 }
355 if (valueB.isExactlyValue(0.5)) {
356
357 Value sqrt = b.createmath::SqrtOp(operandA);
359 return success();
360 }
361 if (valueB.isExactlyValue(-0.5)) {
362
363 Value rsqrt = b.createmath::RsqrtOp(operandA);
365 return success();
366 }
367 if (valueB.isExactlyValue(2.0)) {
368
369 rewriter.replaceOp(op, mulf(operandA, operandA));
370 return success();
371 }
372 if (valueB.isExactlyValue(-2.0)) {
373
376 Value div = b.createarith::DivFOp(one, mulf(operandA, operandA));
378 return success();
379 }
380 if (valueB.isExactlyValue(3.0)) {
381 rewriter.replaceOp(op, mulf(mulf(operandA, operandA), operandA));
382 return success();
383 }
384 }
385
386 Value logA = b.createmath::LogOp(operandA);
387 Value mult = b.createarith::MulFOp(operandB, logA);
388 Value expResult = b.createmath::ExpOp(mult);
389 rewriter.replaceOp(op, expResult);
390 return success();
391 }
392
393
394
395
396
400 Value operand = op.getOperand();
403 Value mult = b.createarith::MulFOp(opType, operand, ln2);
404 Value exp = b.createmath::ExpOp(op->getLoc(), mult);
406 return success();
407 }
408
413 Value operand = op.getOperand();
416
417 if (!opEType.isF32()) {
419 }
420
422 if (auto shapedTy = dyn_cast(opType))
423 i32Ty = shapedTy.clone(i32Ty);
424
429
430 Value incrValue = b.createmath::CopySignOp(half, operand);
431 Value add = b.createarith::AddFOp(opType, operand, incrValue);
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454 Value operandBitcast = b.createarith::BitcastOp(i32Ty, operand);
455 Value operandExp = b.createarith::AndIOp(
456 b.createarith::ShRUIOp(operandBitcast, c23), expMask);
457 Value operandBiasedExp = b.createarith::SubIOp(operandExp, c127);
458 Value isSpecialValOrLargeVal =
459 b.createarith::CmpIOp(arith::CmpIPredicate::sge, operandBiasedExp, c23);
460
461 Value result = b.createarith::SelectOp(isSpecialValOrLargeVal, operand,
462 fpFixedConvert);
464 return success();
465 }
466
467
468
469 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
471 auto operand = op.getOperand();
472 auto operandTy = operand.getType();
475
476 int32_t bitwidth = eTy.getIntOrFloatBitWidth();
477 if (bitwidth > 64)
478 return failure();
479
480 uint64_t allbits = -1;
481 if (bitwidth < 64) {
482 allbits = allbits >> (64 - bitwidth);
483 }
484
485 Value x = operand;
487 for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
488 auto half = bw / 2;
489 auto bits = createIntConst(loc, operandTy, half, rewriter);
490 auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter);
491
493 rewriter.createarith::CmpIOp(loc, arith::CmpIPredicate::ule, x, mask);
494 Value add = rewriter.createarith::AddIOp(loc, count, bits);
495 Value shift = rewriter.createarith::ShLIOp(loc, x, bits);
496
497 x = rewriter.createarith::SelectOp(loc, pred, shift, x);
498 count = rewriter.createarith::SelectOp(loc, pred, add, count);
499 }
500
502 Value pred = rewriter.createarith::CmpIOp(loc, arith::CmpIPredicate::eq,
503 operand, zero);
504
506 Value sel = rewriter.createarith::SelectOp(loc, pred, bwval, count);
508 return success();
509 }
510
511
516 auto operand = op.getOperand();
517 Type operandTy = operand.getType();
518 Type resultTy = op.getType();
521
522 if (!isa(operandETy) || !isa(resultETy)) {
523 return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32.");
524 }
525
526 Type fTy = operandTy;
528 if (auto shapedTy = dyn_cast(fTy)) {
529 iTy = shapedTy.clone(iTy);
530 }
531
533
534 unsigned mantissaWidth =
535 llvm::cast(operandETy).getFPMantissaWidth() - 1;
536 unsigned exponentWidth = bitWidth - mantissaWidth - 1;
537
538
539
540
541
548 Value c127 = createIntConst(loc, iTy, (1ull << (exponentWidth - 1)) - 1, b);
550 Value c23Mask = createIntConst(loc, iTy, (1ull << mantissaWidth) - 1, b);
551 Value expMask = createIntConst(loc, iTy, (1ull << exponentWidth) - 1, b);
552
553 Value operandBitcast = b.createarith::BitcastOp(iTy, operand);
556
557
558 Value operandExp = b.createarith::AndIOp(
559 b.createarith::ShRUIOp(operandBitcast, c23), expMask);
560 Value operandBiasedExp = b.createarith::SubIOp(operandExp, c127);
561 Value roundExp = b.createarith::AndIOp(
562 b.createarith::ShRUIOp(roundBitcast, c23), expMask);
563 Value roundBiasedExp = b.createarith::SubIOp(roundExp, c127);
564
565 auto safeShiftRight = [&](Value x, Value shift) -> Value {
566
567 Value clampedShift = b.createarith::MaxSIOp(shift, c0);
568 clampedShift = b.createarith::MinSIOp(clampedShift, c31);
569 return b.createarith::ShRUIOp(x, clampedShift);
570 };
571
572 auto maskMantissa = [&](Value mantissa,
573 Value mantissaMaskRightShift) -> Value {
574 Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
575 return b.createarith::AndIOp(mantissa, shiftedMantissaMask);
576 };
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592 Value roundBiasedExpEq0 =
593 b.createarith::CmpIOp(arith::CmpIPredicate::eq, roundBiasedExp, c0);
594 Value roundBiasedExpMinus1 = b.createarith::SubIOp(roundBiasedExp, c1);
595 Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
596 Value roundIsNotEvenOrSpecialVal = b.createarith::CmpIOp(
597 arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
598 roundIsNotEvenOrSpecialVal =
599 b.createarith::OrIOp(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
600
601
602
603
604
605
606
607
608 Value operandBiasedExpEqNeg1 = b.createarith::CmpIOp(
609 arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
610 Value expectedOperandMaskedMantissa = b.createarith::SelectOp(
611 operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
612 Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
613 Value operandIsHalfway =
614 b.createarith::CmpIOp(arith::CmpIPredicate::eq, operandMaskedMantissa,
615 expectedOperandMaskedMantissa);
616
617 Value operandBiasedExpGeNeg1 = b.createarith::CmpIOp(
618 arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
619 Value operandBiasedExpLt23 =
620 b.createarith::CmpIOp(arith::CmpIPredicate::slt, operandBiasedExp, c23);
621 operandIsHalfway =
622 b.createarith::AndIOp(operandIsHalfway, operandBiasedExpLt23);
623 operandIsHalfway =
624 b.createarith::AndIOp(operandIsHalfway, operandBiasedExpGeNeg1);
625
626
627
628 Value sign = b.createmath::CopySignOp(c1Float, operand);
630
631
632 Value needsShift =
633 b.createarith::AndIOp(roundIsNotEvenOrSpecialVal, operandIsHalfway);
634 Value result = b.createarith::SelectOp(needsShift, roundShifted, round);
635
636
637
638 result = b.createmath::CopySignOp(result, operand);
640 return success();
641 }
642
643
646
647 auto operand = op.getOperand();
648 auto operandTy = operand.getType();
649
650 auto shapedOperandType = dyn_cast(operandTy);
651 if (shapedOperandType && !shapedOperandType.hasStaticShape())
652 return failure();
653
655 if (!isa(eTy))
656 return failure();
657
659 auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);
660 auto sqrtOp = rewriter.createmath::SqrtOp(loc, operand);
661 rewriter.replaceOpWithNewOparith::DivFOp(op, constOneFloat, sqrtOp);
662 return success();
663 }
664
667 }
668
671 }
672
675 }
676
679 }
680
683 }
684
687 }
688
691 }
692
695 }
696
699 }
700
703 }
704
707 }
708
711 }
712
715 }
716
719 }
720
723 }
724
727 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static LogicalResult convertRsqrtOp(math::RsqrtOp op, PatternRewriter &rewriter)
static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b)
static LogicalResult convertFPowIOp(math::FPowIOp op, PatternRewriter &rewriter)
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundOp(math::RoundOp op, PatternRewriter &rewriter)
static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter)
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, PatternRewriter &rewriter)
static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter)
static LogicalResult convertAtanhOp(math::AtanhOp op, PatternRewriter &rewriter)
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, PatternRewriter &rewriter)
static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter)
static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter)
static Value createFloatConst(Location loc, Type type, APFloat value, OpBuilder &b)
Create a float constant.
static LogicalResult convertAsinhOp(math::AsinhOp op, PatternRewriter &rewriter)
static Value createIntConst(Location loc, Type type, int64_t value, OpBuilder &b)
Create an integer constant.
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter)
Expands tanh op into (1 - exp^{-2x}) / (1 + exp^{-2x}). To avoid overflow we exploit the reflection symmetry tanh(-x) = -tanh(x).
static LogicalResult convertAcoshOp(math::AcoshOp op, PatternRewriter &rewriter)
static LogicalResult convertExp2fOp(math::Exp2Op op, PatternRewriter &rewriter)
Attributes are known-constant values of operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
DynamicAPInt round(const Fraction &f)
Fraction abs(const Fraction &f)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
void populateExpandSinhPattern(RewritePatternSet &patterns)
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
void populateExpandRsqrtPattern(RewritePatternSet &patterns)
void populateExpandTanhPattern(RewritePatternSet &patterns)
void populateExpandFmaFPattern(RewritePatternSet &patterns)
void populateExpandAcoshPattern(RewritePatternSet &patterns)
void populateExpandFPowIPattern(RewritePatternSet &patterns)
void populateExpandPowFPattern(RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateExpandTanPattern(RewritePatternSet &patterns)
const FrozenRewritePatternSet & patterns
void populateExpandCoshPattern(RewritePatternSet &patterns)
void populateExpandRoundFPattern(RewritePatternSet &patterns)
void populateExpandExp2FPattern(RewritePatternSet &patterns)
void populateExpandCeilFPattern(RewritePatternSet &patterns)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void populateExpandCtlzPattern(RewritePatternSet &patterns)
void populateExpandAsinhPattern(RewritePatternSet &patterns)
void populateExpandRoundEvenPattern(RewritePatternSet &patterns)
void populateExpandAtanhPattern(RewritePatternSet &patterns)
detail::constant_float_value_binder m_ConstantFloat(FloatAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor float (splat) and writes the float value to bind_value.