MLIR: lib/Dialect/Math/Transforms/ExpandPatterns.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

22 #include "llvm/ADT/APFloat.h"

23

24 using namespace mlir;

25

26

29 bool losesInfo = false;

31

32 value.convert(cast(eltType).getFloatSemantics(),

33 APFloat::rmNearestTiesToEven, &losesInfo);

35 if (auto shapedTy = dyn_cast(type)) {

36 return b.createarith::ConstantOp(loc,

38 }

39

40 return b.createarith::ConstantOp(loc, attr);

41 }

42

46 }

47

48

52 if (auto shapedTy = dyn_cast(type)) {

53 return b.createarith::ConstantOp(loc,

55 }

56

57 return b.createarith::ConstantOp(loc, attr);

58 }

59

63 if (auto shapedTy = dyn_cast(opType))

64 i64Ty = shapedTy.clone(i64Ty);

65 Value fixedConvert = b.createarith::FPToSIOp(i64Ty, operand);

66 Value fpFixedConvert = b.createarith::SIToFPOp(opType, fixedConvert);

67

68

69 return b.createmath::CopySignOp(fpFixedConvert, operand);

70 }

71

72

75 Value operand = op.getOperand();

77

78 Value exp = b.createmath::ExpOp(operand);

79 Value neg = b.createarith::NegFOp(operand);

81 Value sub = b.createarith::SubFOp(exp, nexp);

83 Value res = b.createarith::MulFOp(sub, half);

85 return success();

86 }

87

88

91 Value operand = op.getOperand();

93

94 Value exp = b.createmath::ExpOp(operand);

95 Value neg = b.createarith::NegFOp(operand);

97 Value add = b.createarith::AddFOp(exp, nexp);

99 Value res = b.createarith::MulFOp(add, half);

101 return success();

102 }

103

104

105

106

107

108

109

110

111

113 auto floatType = op.getOperand().getType();

118

119

120 Value isNegative = rewriter.createarith::CmpFOp(

121 loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);

122 Value isNegativeFloat =

123 rewriter.createarith::UIToFPOp(loc, floatType, isNegative);

124 Value isNegativeTimesNegTwo =

125 rewriter.createarith::MulFOp(loc, isNegativeFloat, negTwo);

126 Value sign = rewriter.createarith::AddFOp(loc, isNegativeTimesNegTwo, one);

127

128

129 Value positiveX = rewriter.createarith::MulFOp(loc, sign, op.getOperand());

130

131

132 Value negDoubledX = rewriter.createarith::MulFOp(loc, negTwo, positiveX);

133 Value exp2x = rewriter.createmath::ExpOp(loc, negDoubledX);

134 Value dividend = rewriter.createarith::SubFOp(loc, one, exp2x);

135 Value divisor = rewriter.createarith::AddFOp(loc, one, exp2x);

136 Value positiveRes = rewriter.createarith::DivFOp(loc, dividend, divisor);

137

138

140

141 return success();

142 }

143

144

147 Value operand = op.getOperand();

149 Value sin = b.createmath::SinOp(type, operand);

150 Value cos = b.createmath::CosOp(type, operand);

151 Value div = b.createarith::DivFOp(type, sin, cos);

153 return success();

154 }

155

156

160 Value operand = op.getOperand();

162

164 Value fma = b.createmath::FmaOp(operand, operand, one);

165 Value sqrt = b.createmath::SqrtOp(fma);

166 Value add = b.createarith::AddFOp(operand, sqrt);

169 return success();

170 }

171

172

176 Value operand = op.getOperand();

178

180 Value fma = b.createmath::FmaOp(operand, operand, negOne);

181 Value sqrt = b.createmath::SqrtOp(fma);

182 Value add = b.createarith::AddFOp(operand, sqrt);

185 return success();

186 }

187

188

192 Value operand = op.getOperand();

194

196 Value add = b.createarith::AddFOp(operand, one);

197 Value neg = b.createarith::NegFOp(operand);

198 Value sub = b.createarith::AddFOp(neg, one);

199 Value div = b.createarith::DivFOp(add, sub);

202 Value res = b.createarith::MulFOp(log, half);

204 return success();

205 }

206

209 Value operandA = op.getOperand(0);

210 Value operandB = op.getOperand(1);

211 Value operandC = op.getOperand(2);

212 Type type = op.getType();

213 Value mult = b.createarith::MulFOp(type, operandA, operandB);

214 Value add = b.createarith::AddFOp(type, mult, operandC);

216 return success();

217 }

218

219

220

221

222

223

225

226 auto shapedType = dyn_cast(op.getType());

227 if (shapedType && !shapedType.hasStaticShape())

228 return failure();

229

231 Value operand = op.getOperand();

234

235

238

239 Value gtCheck = b.createarith::CmpFOp(arith::CmpFPredicate::OGT, operand,

240 fpFixedConvert);

241 Value incrValue = b.createarith::SelectOp(op->getLoc(), gtCheck, one, zero);

242

243 Value ret = b.createarith::AddFOp(opType, fpFixedConvert, incrValue);

245 return success();

246 }

247

248

249

250

251

255 Value base = op.getOperand(0);

256 Value power = op.getOperand(1);

258

259 auto convertFPowItoPowf = [&]() -> LogicalResult {

260 Value castPowerToFp =

261 rewriter.createarith::SIToFPOp(op.getLoc(), baseType, power);

262 Value res = rewriter.createmath::PowFOp(op.getLoc(), baseType, base,

263 castPowerToFp);

265 return success();

266 };

267

270 return convertFPowItoPowf();

271

272 APInt value;

274 return convertFPowItoPowf();

275

276 int64_t powerInt = value.getSExtValue();

277 bool isNegative = powerInt < 0;

278 int64_t absPower = std::abs(powerInt);

281

282 while (absPower > 0) {

283 if (absPower & 1)

284 res = b.createarith::MulFOp(baseType, base, res);

285 absPower >>= 1;

286 base = b.createarith::MulFOp(baseType, base, base);

287 }

288

289

290 if (isNegative) {

292 .getFloatSemantics();

299 Value posInfinity =

301 APFloat::getInf(sem, false), rewriter);

302 Value negInfinity =

304 APFloat::getInf(sem, true), rewriter);

305 Value zeroEqCheck =

306 b.createarith::CmpFOp(arith::CmpFPredicate::OEQ, res, zero);

307 Value negZeroEqCheck =

308 b.createarith::CmpFOp(arith::CmpFPredicate::OEQ, res, negZero);

309 res = b.createarith::DivFOp(baseType, one, res);

310 res =

311 b.createarith::SelectOp(op->getLoc(), zeroEqCheck, posInfinity, res);

312 res = b.createarith::SelectOp(op->getLoc(), negZeroEqCheck, negInfinity,

313 res);

314 }

315

317 return success();

318 }

319

320

321

322

325 Value operandA = op.getOperand(0);

326 Value operandB = op.getOperand(1);

327 auto typeA = operandA.getType();

328 auto typeB = operandB.getType();

329

330 auto &sem =

332 APFloat valueB(sem);

334 return b.createarith::MulFOp(x, y);

335 };

337 if (valueB.isZero()) {

338

341 return success();

342 }

343 if (valueB.isExactlyValue(1.0)) {

344

345 rewriter.replaceOp(op, operandA);

346 return success();

347 }

348 if (valueB.isExactlyValue(-1.0)) {

349

351 Value div = b.createarith::DivFOp(one, operandA);

353 return success();

354 }

355 if (valueB.isExactlyValue(0.5)) {

356

357 Value sqrt = b.createmath::SqrtOp(operandA);

359 return success();

360 }

361 if (valueB.isExactlyValue(-0.5)) {

362

363 Value rsqrt = b.createmath::RsqrtOp(operandA);

365 return success();

366 }

367 if (valueB.isExactlyValue(2.0)) {

368

369 rewriter.replaceOp(op, mulf(operandA, operandA));

370 return success();

371 }

372 if (valueB.isExactlyValue(-2.0)) {

373

376 Value div = b.createarith::DivFOp(one, mulf(operandA, operandA));

378 return success();

379 }

380 if (valueB.isExactlyValue(3.0)) {

381 rewriter.replaceOp(op, mulf(mulf(operandA, operandA), operandA));

382 return success();

383 }

384 }

385

386 Value logA = b.createmath::LogOp(operandA);

387 Value mult = b.createarith::MulFOp(operandB, logA);

388 Value expResult = b.createmath::ExpOp(mult);

389 rewriter.replaceOp(op, expResult);

390 return success();

391 }

392

393

394

395

396

400 Value operand = op.getOperand();

403 Value mult = b.createarith::MulFOp(opType, operand, ln2);

404 Value exp = b.createmath::ExpOp(op->getLoc(), mult);

406 return success();

407 }

408

413 Value operand = op.getOperand();

416

417 if (!opEType.isF32()) {

419 }

420

422 if (auto shapedTy = dyn_cast(opType))

423 i32Ty = shapedTy.clone(i32Ty);

424

429

430 Value incrValue = b.createmath::CopySignOp(half, operand);

431 Value add = b.createarith::AddFOp(opType, operand, incrValue);

433

434

435

436

437

438

439

440

441

442

443

444

445

446

447

448

449

450

451

452

453

454 Value operandBitcast = b.createarith::BitcastOp(i32Ty, operand);

455 Value operandExp = b.createarith::AndIOp(

456 b.createarith::ShRUIOp(operandBitcast, c23), expMask);

457 Value operandBiasedExp = b.createarith::SubIOp(operandExp, c127);

458 Value isSpecialValOrLargeVal =

459 b.createarith::CmpIOp(arith::CmpIPredicate::sge, operandBiasedExp, c23);

460

461 Value result = b.createarith::SelectOp(isSpecialValOrLargeVal, operand,

462 fpFixedConvert);

464 return success();

465 }

466

467

468

469 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,

471 auto operand = op.getOperand();

472 auto operandTy = operand.getType();

475

476 int32_t bitwidth = eTy.getIntOrFloatBitWidth();

477 if (bitwidth > 64)

478 return failure();

479

480 uint64_t allbits = -1;

481 if (bitwidth < 64) {

482 allbits = allbits >> (64 - bitwidth);

483 }

484

485 Value x = operand;

487 for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {

488 auto half = bw / 2;

489 auto bits = createIntConst(loc, operandTy, half, rewriter);

490 auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter);

491

493 rewriter.createarith::CmpIOp(loc, arith::CmpIPredicate::ule, x, mask);

494 Value add = rewriter.createarith::AddIOp(loc, count, bits);

495 Value shift = rewriter.createarith::ShLIOp(loc, x, bits);

496

497 x = rewriter.createarith::SelectOp(loc, pred, shift, x);

498 count = rewriter.createarith::SelectOp(loc, pred, add, count);

499 }

500

502 Value pred = rewriter.createarith::CmpIOp(loc, arith::CmpIPredicate::eq,

503 operand, zero);

504

506 Value sel = rewriter.createarith::SelectOp(loc, pred, bwval, count);

508 return success();

509 }

510

511

516 auto operand = op.getOperand();

517 Type operandTy = operand.getType();

518 Type resultTy = op.getType();

521

522 if (!isa(operandETy) || !isa(resultETy)) {

523 return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32.");

524 }

525

526 Type fTy = operandTy;

528 if (auto shapedTy = dyn_cast(fTy)) {

529 iTy = shapedTy.clone(iTy);

530 }

531

533

534 unsigned mantissaWidth =

535 llvm::cast(operandETy).getFPMantissaWidth() - 1;

536 unsigned exponentWidth = bitWidth - mantissaWidth - 1;

537

538

539

540

541

548 Value c127 = createIntConst(loc, iTy, (1ull << (exponentWidth - 1)) - 1, b);

550 Value c23Mask = createIntConst(loc, iTy, (1ull << mantissaWidth) - 1, b);

551 Value expMask = createIntConst(loc, iTy, (1ull << exponentWidth) - 1, b);

552

553 Value operandBitcast = b.createarith::BitcastOp(iTy, operand);

556

557

558 Value operandExp = b.createarith::AndIOp(

559 b.createarith::ShRUIOp(operandBitcast, c23), expMask);

560 Value operandBiasedExp = b.createarith::SubIOp(operandExp, c127);

561 Value roundExp = b.createarith::AndIOp(

562 b.createarith::ShRUIOp(roundBitcast, c23), expMask);

563 Value roundBiasedExp = b.createarith::SubIOp(roundExp, c127);

564

565 auto safeShiftRight = [&](Value x, Value shift) -> Value {

566

567 Value clampedShift = b.createarith::MaxSIOp(shift, c0);

568 clampedShift = b.createarith::MinSIOp(clampedShift, c31);

569 return b.createarith::ShRUIOp(x, clampedShift);

570 };

571

572 auto maskMantissa = [&](Value mantissa,

573 Value mantissaMaskRightShift) -> Value {

574 Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);

575 return b.createarith::AndIOp(mantissa, shiftedMantissaMask);

576 };

577

578

579

580

581

582

583

584

585

586

587

588

589

590

591

592 Value roundBiasedExpEq0 =

593 b.createarith::CmpIOp(arith::CmpIPredicate::eq, roundBiasedExp, c0);

594 Value roundBiasedExpMinus1 = b.createarith::SubIOp(roundBiasedExp, c1);

595 Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);

596 Value roundIsNotEvenOrSpecialVal = b.createarith::CmpIOp(

597 arith::CmpIPredicate::ne, roundMaskedMantissa, c0);

598 roundIsNotEvenOrSpecialVal =

599 b.createarith::OrIOp(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);

600

601

602

603

604

605

606

607

608 Value operandBiasedExpEqNeg1 = b.createarith::CmpIOp(

609 arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);

610 Value expectedOperandMaskedMantissa = b.createarith::SelectOp(

611 operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));

612 Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);

613 Value operandIsHalfway =

614 b.createarith::CmpIOp(arith::CmpIPredicate::eq, operandMaskedMantissa,

615 expectedOperandMaskedMantissa);

616

617 Value operandBiasedExpGeNeg1 = b.createarith::CmpIOp(

618 arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);

619 Value operandBiasedExpLt23 =

620 b.createarith::CmpIOp(arith::CmpIPredicate::slt, operandBiasedExp, c23);

621 operandIsHalfway =

622 b.createarith::AndIOp(operandIsHalfway, operandBiasedExpLt23);

623 operandIsHalfway =

624 b.createarith::AndIOp(operandIsHalfway, operandBiasedExpGeNeg1);

625

626

627

628 Value sign = b.createmath::CopySignOp(c1Float, operand);

630

631

632 Value needsShift =

633 b.createarith::AndIOp(roundIsNotEvenOrSpecialVal, operandIsHalfway);

634 Value result = b.createarith::SelectOp(needsShift, roundShifted, round);

635

636

637

638 result = b.createmath::CopySignOp(result, operand);

640 return success();

641 }

642

643

646

647 auto operand = op.getOperand();

648 auto operandTy = operand.getType();

649

650 auto shapedOperandType = dyn_cast(operandTy);

651 if (shapedOperandType && !shapedOperandType.hasStaticShape())

652 return failure();

653

655 if (!isa(eTy))

656 return failure();

657

659 auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);

660 auto sqrtOp = rewriter.createmath::SqrtOp(loc, operand);

661 rewriter.replaceOpWithNewOparith::DivFOp(op, constOneFloat, sqrtOp);

662 return success();

663 }

664

667 }

668

671 }

672

675 }

676

679 }

680

683 }

684

687 }

688

691 }

692

695 }

696

699 }

700

703 }

704

707 }

708

711 }

712

715 }

716

719 }

720

723 }

724

727 }

static Value getZero(OpBuilder &b, Location loc, Type elementType)

Get zero value for an element type.

static LogicalResult convertRsqrtOp(math::RsqrtOp op, PatternRewriter &rewriter)

static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b)

static LogicalResult convertFPowIOp(math::FPowIOp op, PatternRewriter &rewriter)

static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter)

static LogicalResult convertRoundOp(math::RoundOp op, PatternRewriter &rewriter)

static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter)

static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, PatternRewriter &rewriter)

static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter)

static LogicalResult convertAtanhOp(math::AtanhOp op, PatternRewriter &rewriter)

static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter)

static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, PatternRewriter &rewriter)

static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter)

static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter)

static Value createFloatConst(Location loc, Type type, APFloat value, OpBuilder &b)

Create a float constant.

static LogicalResult convertAsinhOp(math::AsinhOp op, PatternRewriter &rewriter)

static Value createIntConst(Location loc, Type type, int64_t value, OpBuilder &b)

Create an integer constant.

static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter)

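Expands tanh op into $(1 - e^{-2x}) / (1 + e^{-2x})$. To avoid overflow, we exploit the reflection symmetry $\tanh(-x) = -\tanh(x)$: the expansion is evaluated on $|x|$ and the sign of $x$ is then reapplied.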

static LogicalResult convertAcoshOp(math::AcoshOp op, PatternRewriter &rewriter)

static LogicalResult convertExp2fOp(math::Exp2Op op, PatternRewriter &rewriter)

Attributes are known-constant values of operations.

IntegerAttr getIntegerAttr(Type type, int64_t value)

FloatAttr getFloatAttr(Type type, double value)

IntegerType getIntegerType(unsigned width)

static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)

Constructs a dense elements attribute from an array of element values.

ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.

OpTy create(Args &&...args)

Create an operation of specific op type at the current insertion point and location.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

This class helps build Operations.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason the failure occurred.

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (the replacements).

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (the replacement).

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

unsigned getIntOrFloatBitWidth() const

Return the bit width of an integer or a float type, assert failure on other types.

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

Type getType() const

Return the type of this value.

DynamicAPInt round(const Fraction &f)

Fraction abs(const Fraction &f)

Include the generated interface declarations.

bool matchPattern(Value value, const Pattern &pattern)

Entry point for matching a pattern over a Value.

void populateExpandSinhPattern(RewritePatternSet &patterns)

detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)

Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.

void populateExpandRsqrtPattern(RewritePatternSet &patterns)

void populateExpandTanhPattern(RewritePatternSet &patterns)

void populateExpandFmaFPattern(RewritePatternSet &patterns)

void populateExpandAcoshPattern(RewritePatternSet &patterns)

void populateExpandFPowIPattern(RewritePatternSet &patterns)

void populateExpandPowFPattern(RewritePatternSet &patterns)

Type getElementTypeOrSelf(Type type)

Return the element type or return the type itself.

void populateExpandTanPattern(RewritePatternSet &patterns)

const FrozenRewritePatternSet & patterns

void populateExpandCoshPattern(RewritePatternSet &patterns)

void populateExpandRoundFPattern(RewritePatternSet &patterns)

void populateExpandExp2FPattern(RewritePatternSet &patterns)

void populateExpandCeilFPattern(RewritePatternSet &patterns)

detail::constant_op_matcher m_Constant()

Matches a constant foldable operation.

void populateExpandCtlzPattern(RewritePatternSet &patterns)

void populateExpandAsinhPattern(RewritePatternSet &patterns)

void populateExpandRoundEvenPattern(RewritePatternSet &patterns)

void populateExpandAtanhPattern(RewritePatternSet &patterns)

detail::constant_float_value_binder m_ConstantFloat(FloatAttr::ValueType *bind_value)

Matches a constant holding a scalar/vector/tensor float (splat) and writes the float value to bind_value.