MLIR: lib/Dialect/Arith/Transforms/ExpandOps.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

16

// Requests the generated definition of this file's pass base class
// (arith::impl::ArithExpandOpsPassBase, which ArithExpandOpsPass derives from)
// by defining GEN_PASS_DEF_ARITHEXPANDOPSPASS before including Passes.h.inc
// inside the mlir::arith namespace.
17 namespace mlir {

18 namespace arith {

19 #define GEN_PASS_DEF_ARITHEXPANDOPSPASS

20 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"

21 }

22 }

23

24 using namespace mlir;

25

26

// Body of createConst. Its signature and the computation of `attr` are on
// lines missing from this listing. It builds an arith.constant for `type`:
// - For a shaped type, the constant is built over `shapedTy`. The argument
//   line is missing; presumably it is a DenseElementsAttr splat of `attr`
//   (TODO confirm).
// - Otherwise, it is a scalar constant built directly from `attr`.
30 if (auto shapedTy = dyn_cast(type)) {

31 return rewriter.createarith::ConstantOp(

33 }

34 return rewriter.createarith::ConstantOp(loc, attr);

35 }

36

37

// Body of cloneToShapedType (its signature is missing from this listing).
// If `cloneFrom` is a shaped type, return a shaped type with the same shape
// but with `cloneTo` as the element type. Otherwise return `cloneTo` unchanged.
39 if (auto shapedTy = dyn_cast(cloneFrom)) {

40 return shapedTy.clone(cloneTo);

41 }

42 return cloneTo;

43 }

44

45 namespace {

46

47

48

// Expands arith.ceildivui into
//   a == 0 ? 0 : ((a - 1) / b) + 1        (unsigned division)
// The a == 0 guard is needed because a - 1 would wrap around when a is 0.
// NOTE: `loc`, `zero`, `one` and the name `compare` are defined on lines
// missing from this listing. `zero`/`one` are presumably the constants 0 and 1
// of the operand type.
// NOTE: despite its name, `minusOne` holds a - 1, not the constant -1.
49 struct CeilDivUIOpConverter : public OpRewritePatternarith::CeilDivUIOp {

51 LogicalResult matchAndRewrite(arith::CeilDivUIOp op,

54 Value a = op.getLhs();

55 Value b = op.getRhs();

// compare = (a == 0)
58 rewriter.createarith::CmpIOp(loc, arith::CmpIPredicate::eq, a, zero);

60 Value minusOne = rewriter.createarith::SubIOp(loc, a, one);

61 Value quotient = rewriter.createarith::DivUIOp(loc, minusOne, b);

62 Value plusOne = rewriter.createarith::AddIOp(loc, quotient, one);

// Yield 0 when a == 0, otherwise ((a - 1) / b) + 1.
63 rewriter.replaceOpWithNewOparith::SelectOp(op, compare, zero, plusOne);

64 return success();

65 }

66 };

67

68

69

70

71

72

73

74

// Expands arith.ceildivsi on top of truncating signed division:
//   q      = a / b                      (arith.divsi, rounds toward zero)
//   result = (a != q * b && (a < 0) == (b < 0)) ? q + 1 : q
// When the division is inexact and both operands have the same sign, the
// exact quotient is positive. Rounding toward zero therefore rounded it down,
// so 1 is added back.
// NOTE: `loc`, `zero`, `one`, `aNeg`, `bNeg` and `cond` are defined on lines
// missing from this listing. `zero`/`one` are presumably constants built from
// `type`.
75 struct CeilDivSIOpConverter : public OpRewritePatternarith::CeilDivSIOp {

77 LogicalResult matchAndRewrite(arith::CeilDivSIOp op,

80 Type type = op.getType();

81 Value a = op.getLhs();

82 Value b = op.getRhs();

83

86

// Truncated quotient and whether it was inexact (a != q * b).
87 Value quotient = rewriter.createarith::DivSIOp(loc, a, b);

88 Value product = rewriter.createarith::MulIOp(loc, quotient, b);

89 Value notEqualDivisor = rewriter.createarith::CmpIOp(

90 loc, arith::CmpIPredicate::ne, a, product);

91

// aNeg = a < 0 and bNeg = b < 0 (the result names are on missing lines).
93 rewriter.createarith::CmpIOp(loc, arith::CmpIPredicate::slt, a, zero);

95 rewriter.createarith::CmpIOp(loc, arith::CmpIPredicate::slt, b, zero);

96

// cond = inexact && operands share a sign.
97 Value signEqual = rewriter.createarith::CmpIOp(

98 loc, arith::CmpIPredicate::eq, aNeg, bNeg);

100 rewriter.createarith::AndIOp(loc, notEqualDivisor, signEqual);

101

102 Value quotientPlusOne = rewriter.createarith::AddIOp(loc, quotient, one);

103

104 rewriter.replaceOpWithNewOparith::SelectOp(op, cond, quotientPlusOne,

105 quotient);

106 return success();

107 }

108 };

109

110

111

112

113

114

115

116

// Expands arith.floordivsi on top of truncating signed division:
//   q      = a / b                      (arith.divsi, rounds toward zero)
//   result = (a != q * b && (a < 0) != (b < 0)) ? q - 1 : q
// When the division is inexact and the operands have opposite signs, the
// exact quotient is negative. Rounding toward zero therefore rounded it up,
// so 1 is subtracted.
// NOTE: `loc`, `zero`, `aNeg`, `bNeg`, `cond` and `minusOne` are defined on
// lines missing from this listing. `minusOne` is presumably the constant -1,
// since it is added to the quotient.
117 struct FloorDivSIOpConverter : public OpRewritePatternarith::FloorDivSIOp {

119 LogicalResult matchAndRewrite(arith::FloorDivSIOp op,

122 Type type = op.getType();

123 Value a = op.getLhs();

124 Value b = op.getRhs();

125

// Truncated quotient and whether it was inexact (a != q * b).
126 Value quotient = rewriter.createarith::DivSIOp(loc, a, b);

127 Value product = rewriter.createarith::MulIOp(loc, quotient, b);

128 Value notEqualDivisor = rewriter.createarith::CmpIOp(

129 loc, arith::CmpIPredicate::ne, a, product);

131

// aNeg = a < 0 and bNeg = b < 0 (the result names are on missing lines).
133 rewriter.createarith::CmpIOp(loc, arith::CmpIPredicate::slt, a, zero);

135 rewriter.createarith::CmpIOp(loc, arith::CmpIPredicate::slt, b, zero);

136

// cond = inexact && operands have opposite signs.
137 Value signOpposite = rewriter.createarith::CmpIOp(

138 loc, arith::CmpIPredicate::ne, aNeg, bNeg);

140 rewriter.createarith::AndIOp(loc, notEqualDivisor, signOpposite);

141

143 Value quotientMinusOne =

144 rewriter.createarith::AddIOp(loc, quotient, minusOne);

145

146 rewriter.replaceOpWithNewOparith::SelectOp(op, cond, quotientMinusOne,

147 quotient);

148 return success();

149 }

150 };

151

// MaxMinIOpConverter. The struct header and constructor lines are missing from
// this listing. It expands an integer max/min op into
//   select(cmpi(pred, lhs, rhs), lhs, rhs)
// The pattern list at the end of the file instantiates it as follows:
// - sgt / ugt for maxsi / maxui;
// - slt / ult for minsi / minui.
152 template <typename OpTy, arith::CmpIPredicate pred>

154 public:

156

157 LogicalResult matchAndRewrite(OpTy op,

159 Value lhs = op.getLhs();

160 Value rhs = op.getRhs();

161

162 Value cmp = rewriter.createarith::CmpIOp(op.getLoc(), pred, lhs, rhs);

163 rewriter.replaceOpWithNewOparith::SelectOp(op, cmp, lhs, rhs);

164 return success();

165 }

166 };

167

// Expands arith.maximumf / arith.minimumf (pred = UGT / ULT respectively):
//   select = cmpf(pred, lhs, rhs) ? lhs : rhs
//   result = isNaN(rhs) ? rhs : select
// The unordered predicate is true when lhs is NaN, so a NaN lhs is selected.
// The explicit rhs check covers the remaining case, so a NaN in either operand
// propagates to the result.
// NOTE(review): nothing visible here distinguishes -0.0 from +0.0. They
// compare equal, so rhs is chosen; confirm this matches the signed-zero
// ordering expected of arith.maximumf/minimumf.
// NOTE: `loc` is defined on a line missing from this listing.
168 template <typename OpTy, arith::CmpFPredicate pred>

169 struct MaximumMinimumFOpConverter : public OpRewritePattern {

170 public:

172

173 LogicalResult matchAndRewrite(OpTy op,

175 Value lhs = op.getLhs();

176 Value rhs = op.getRhs();

177

179

180 static_assert(pred == arith::CmpFPredicate::UGT ||

181 pred == arith::CmpFPredicate::ULT,

182 "pred must be either UGT or ULT");

183 Value cmp = rewriter.createarith::CmpFOp(loc, pred, lhs, rhs);

184 Value select = rewriter.createarith::SelectOp(loc, cmp, lhs, rhs);

185

186

// UNO(rhs, rhs) is true exactly when rhs is NaN.
187 Value isNaN = rewriter.createarith::CmpFOp(loc, arith::CmpFPredicate::UNO,

188 rhs, rhs);

189 rewriter.replaceOpWithNewOparith::SelectOp(op, isNaN, rhs, select);

190 return success();

191 }

192 };

193

// Expands arith.maxnumf / arith.minnumf (pred = UGT / ULT respectively):
//   select = cmpf(pred, lhs, rhs) ? lhs : rhs
//   result = isNaN(lhs) ? rhs : select
// If only rhs is NaN, the unordered compare is true and lhs is selected. If lhs
// is NaN, rhs is returned. A single NaN operand is therefore dropped in favor
// of the other operand.
// NOTE: `loc` is defined on a line missing from this listing.
194 template <typename OpTy, arith::CmpFPredicate pred>

195 struct MaxNumMinNumFOpConverter : public OpRewritePattern {

196 public:

198

199 LogicalResult matchAndRewrite(OpTy op,

201 Value lhs = op.getLhs();

202 Value rhs = op.getRhs();

203

205

206 static_assert(pred == arith::CmpFPredicate::UGT ||

207 pred == arith::CmpFPredicate::ULT,

208 "pred must be either UGT or ULT");

209 Value cmp = rewriter.createarith::CmpFOp(loc, pred, lhs, rhs);

210 Value select = rewriter.createarith::SelectOp(loc, cmp, lhs, rhs);

211

212

// UNO(lhs, lhs) is true exactly when lhs is NaN.
213 Value isNaN = rewriter.createarith::CmpFOp(loc, arith::CmpFPredicate::UNO,

214 lhs, lhs);

215 rewriter.replaceOpWithNewOparith::SelectOp(op, isNaN, rhs, select);

216 return success();

217 }

218 };

219

// Expands arith.extf from bf16 to f32 as pure bit manipulation:
//   bitcast to i16 -> zero-extend to i32 -> shift left by 16 -> bitcast to f32.
// This relies on a bf16 bit pattern being the upper 16 bits of the f32 with
// the same value, so the conversion is exact. Any other extf fails to match.
// NOTE: `operandETy`, `resultETy`, the builder `b`, `i16Ty`, `i32Ty` and `c16`
// are defined on lines missing from this listing. `c16` is presumably the
// constant 16.
220 struct BFloat16ExtFOpConverter : public OpRewritePatternarith::ExtFOp {

222 LogicalResult matchAndRewrite(arith::ExtFOp op,

225 auto operand = op.getOperand();

226 Type operandTy = operand.getType();

227 Type resultTy = op.getType();

230

231 if (!operandETy.isBF16() || !resultETy.isF32()) {

232 return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");

233 }

234

237

238 Value bitcast = b.createarith::BitcastOp(i16Ty, operand);

239 Value exti = b.createarith::ExtUIOp(i32Ty, bitcast);

240

242 Value shl = b.createarith::ShLIOp(exti, c16);

243 Value result = b.createarith::BitcastOp(resultTy, shl);

244

245 rewriter.replaceOp(op, result);

246 return success();

247 }

248 };

249

// Expands arith.truncf from f32 to bf16 using integer arithmetic. It applies
// only when no explicit rounding mode is set:
//   bits   = bitcast<i32>(x)
//   bias   = 0x7fff + ((bits >> 16) & 1)     // + lowest bit that is kept
//   result = trunc<i16>((bits + bias) >> 16)  // round to nearest, ties to even
// NaN inputs (detected as x UNE x) are replaced by 0x7fc0, a bf16 quiet NaN.
// The bias addition is not safe for NaN bit patterns. For example, 0x7f800001
// would become 0x7f80, which is +infinity.
// NOTE: `operandETy`, `resultETy`, the builder `b`, `i16Ty`, `i32Ty`, `c1`,
// `c16`, `isNan`, `bit16` and `select` are defined on lines missing from this
// listing. `c1`/`c16` are presumably the constants 1 and 16.
250 struct BFloat16TruncFOpConverter : public OpRewritePatternarith::TruncFOp {

252 LogicalResult matchAndRewrite(arith::TruncFOp op,

255 auto operand = op.getOperand();

256 Type operandTy = operand.getType();

257 Type resultTy = op.getType();

260

261 if (!operandETy.isF32() || !resultETy.isBF16()) {

262 return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");

263 }

264

265 if (op.getRoundingmodeAttr()) {

266 return rewriter.notifyMatchFailure(

267 op, "only applicable to default rounding mode.");

268 }

269

272

273

274

275

276

277

278

279

280

281

282

283

284

285

// isNan = (operand UNE operand), true only for NaN.
287 b.createarith::CmpFOp(arith::CmpFPredicate::UNE, operand, operand);

288

289 Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);

290

291 Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);

292

295

296 Value bitcast = b.createarith::BitcastOp(i32Ty, operand);

297

// bit16 = lowest bit that survives the final >> 16 (the ties-to-even bit).
299 b.createarith::AndIOp(b.createarith::ShRUIOp(bitcast, c16), c1);

300

301

302 Value roundingBias = b.createarith::AddIOp(bit16, c7FFF);

303

304

305

306

307

308

// Any carry out of the low 16 bits rounds the kept part up; a carry into the
// exponent field yields the next exponent, or infinity.
309 Value biased = b.createarith::AddIOp(bitcast, roundingBias);

310

311

312 Value biasedAndShifted = b.createarith::ShRUIOp(biased, c16);

313 Value normalCaseResultI16 =

314 b.createarith::TruncIOp(i16Ty, biasedAndShifted);

315

316

318 b.createarith::SelectOp(isNan, c7FC0I16, normalCaseResultI16);

319 Value result = b.createarith::BitcastOp(resultTy, select);

320 rewriter.replaceOp(op, result);

321 return success();

322 }

323 };

324

// Expands arith.extf from f8E8M0FNU. The 8-bit value is placed directly into
// the f32 exponent field: zero-extend to i32, then shift left by the 23
// mantissa bits.
// The byte 0xff maps to the all-ones f32 pattern 0xffffffff, which is a NaN.
// The f32 result is then truncf'd or extf'd to the requested result type; the
// branch conditions are on lines missing from this listing.
// NOTE(review): the byte 0x00 becomes f32 bits 0x00000000 (+0.0). If 0x00
// encodes 2^-127 in f8E8M0FNU (exponent bias 127), the exact f32 value is the
// subnormal 0x00400000. Confirm that flushing to zero is intended.
// NOTE(review): 0xffffffff is passed to createConst's `int` parameter. It only
// yields an all-ones i32 because the unsigned-to-int conversion produces -1,
// which is implementation-defined before C++20.
// NOTE: `operandETy`, the builder `b`, `i8Ty`, `i32Ty`, `f32Ty` and `isNan`
// are defined on lines missing from this listing.
325 struct F8E8M0ExtFOpConverter : public OpRewritePatternarith::ExtFOp {

327 LogicalResult matchAndRewrite(arith::ExtFOp op,

330 Value operand = op.getOperand();

332 Type resultTy = op.getType();

335

336 if (!llvm::isa(operandETy)) {

337 return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");

338 }

339

343

344 Value bitcast = b.createarith::BitcastOp(i8Ty, operand);

345

346 Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);

347 Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);

348 Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);

349

// Move the 8 bits into the f32 exponent field.
350 Value exti = b.createarith::ExtUIOp(i32Ty, bitcast);

351 Value f32Bits = b.createarith::ShLIOp(exti, cF32MantissaWidth);

352

// isNan = (byte == 0xff)
354 b.createarith::CmpIOp(arith::CmpIPredicate::eq, bitcast, cF8NaN);

355

356 f32Bits = b.createarith::SelectOp(isNan, cF32NaN, f32Bits);

357 Value result = b.createarith::BitcastOp(f32Ty, f32Bits);

359 result = b.createarith::TruncFOp(resultTy, result, nullptr,

360 op.getFastmathAttr());

362 result = b.createarith::ExtFOp(resultTy, result, op.getFastmathAttr());

363 }

364 rewriter.replaceOp(op, result);

365 return success();

366 }

367 };

368

369

370

371

372

373

// Expands arith.truncf to f8E8M0FNU. It applies only when no explicit rounding
// mode is set.
// The operand is first brought to f32 with extf or truncf; the branch
// conditions are on lines missing from this listing. Its bits are then shifted
// right by 23 and truncated to i8, keeping only the 8 exponent bits.
// The truncation discards the sign bit (bit 8 after the shift), and the
// mantissa is dropped without rounding.
// NOTE: `resultETy`, the builder `b`, `i8Ty`, `i32Ty` and `f32Ty` are defined
// on lines missing from this listing.
374 struct F8E8M0TruncFOpConverter : public OpRewritePatternarith::TruncFOp {

376 LogicalResult matchAndRewrite(arith::TruncFOp op,

379 Value operand = op.getOperand();

382 Type resultTy = op.getType();

384 if (!llvm::isa(resultETy)) {

385 return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");

386 }

387

388 if (op.getRoundingmodeAttr()) {

389 return rewriter.notifyMatchFailure(

390 op, "only applicable to default rounding mode.");

391 }

392

396

398 operand = b.createarith::ExtFOp(f32Ty, operand, op.getFastmathAttr());

// A non-null rounding mode was rejected above, so the attribute passed here
// is always null.
400 operand = b.createarith::TruncFOp(

401 f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());

402 }

403 Value f32Bits = b.createarith::BitcastOp(i32Ty, operand);

404 Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);

405 Value f32SignExp = b.createarith::ShRUIOp(f32Bits, cF32MantissaWidth);

406 Value exp8Bits = b.createarith::TruncIOp(i8Ty, f32SignExp);

407 Value result = b.createarith::BitcastOp(resultTy, exp8Bits);

408 rewriter.replaceOp(op, result);

409 return success();

410 }

411 };

412

// Expands arith.scaling_extf as
//   result = extf(in) * extf(scale)   (both extended to the result type)
// Under a condition on lines missing from this listing, the scale is first
// truncf'd to f8E8M0FNU. If the (possibly updated) scale element type still
// fails the f8E8M0FNU check, the pattern fails to match.
// NOTE: the builder `b`, `scaleETy`, `scaleTy`, `scaleExt` and `inputExt` are
// defined on lines missing from this listing.
413 struct ScalingExtFOpConverter : public OpRewritePatternarith::ScalingExtFOp {

415 LogicalResult matchAndRewrite(arith::ScalingExtFOp op,

418 Value inputOperand = op.getIn();

419 Value scaleOperand = op.getScale();

422

424 scaleETy = b.getF8E8M0Type();

426 scaleOperand = b.createarith::TruncFOp(scaleTy, scaleOperand, nullptr,

427 op.getFastmathAttr());

428 }

429 if (!llvm::isa(scaleETy)) {

430 return rewriter.notifyMatchFailure(

431 op, "scaling_extf is using scales of type which can not be converted "

432 "to f8E8M0FNU");

433 }

434 Type resultTy = op.getType();

435

436

438 b.createarith::ExtFOp(resultTy, scaleOperand, op.getFastmathAttr());

440 b.createarith::ExtFOp(resultTy, inputOperand, op.getFastmathAttr());

442 b.createarith::MulFOp(inputExt, scaleExt, op.getFastmathAttr());

443 rewriter.replaceOp(op, result);

444 return success();

445 }

446 };

447

448

449

450

451

452

// Expands arith.scaling_truncf as
//   result = truncf(in / extf(scale))
// Steps:
// - The scale is extended to `inputTy`, which is defined on a line missing
//   from this listing (presumably the input's type).
// - The input is divided by the extended scale.
// - The quotient is truncated to the result type with the op's rounding mode.
// Under a condition on lines missing here, the scale is first truncf'd to
// f8E8M0FNU. If the scale element type still fails the f8E8M0FNU check, the
// pattern fails to match.
// NOTE: the builder `b`, `scaleETy` and `scaleTy` are also defined on missing
// lines.
453 struct ScalingTruncFOpConverter

456 LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,

459 Value inputOperand = op.getIn();

460 Value scaleOperand = op.getScale();

463

465 scaleETy = b.getF8E8M0Type();

467 scaleOperand = b.createarith::TruncFOp(scaleTy, scaleOperand, nullptr,

468 op.getFastmathAttr());

469 }

470 if (!llvm::isa(scaleETy)) {

471 return rewriter.notifyMatchFailure(

472 op, "scaling_truncf is using scales type which can not be converted "

473 "to f8E8M0FNU");

474 }

475 Type resultTy = op.getType();

477

478

479 scaleOperand =

480 b.createarith::ExtFOp(inputTy, scaleOperand, op.getFastmathAttr());

481 Value result = b.createarith::DivFOp(inputOperand, scaleOperand,

482 op.getFastmathAttr());

483 Value resultCast = b.createarith::TruncFOp(

484 resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());

485 rewriter.replaceOp(op, resultCast);

486 return success();

487 }

488 };

489

// Pass that expands arith ops through a dialect conversion:
// - The rest of the arith dialect stays legal.
// - These ops are always illegal, so they must be expanded:
//   - ceil/floor division;
//   - integer and float max/min;
//   - scaling_extf / scaling_truncf.
// - extf/truncf are illegal only in two cases:
//   - the bf16 <-> f32 cases, when includeBf16 is set;
//   - the f8E8M0FNU cases, when includeF8E8M0 is set. The llvm::isa<> type
//     argument is missing from this listing; presumably it is f8E8M0FNU.
// A failed conversion signals pass failure.
// The target/pattern setup and the conversion call are on lines missing from
// this listing. The call is presumably applyPartialConversion, which appears
// in the symbol index.
// NOTE(review): the [=] lambdas read includeBf16/includeF8E8M0, which are not
// declared in the visible lines and appear to be pass options. If so, [=]
// captures `this` implicitly, which is deprecated in C++20; consider [this]
// or copying the flags into locals.
490 struct ArithExpandOpsPass

491 : public arith::impl::ArithExpandOpsPassBase {

492 using ArithExpandOpsPassBase::ArithExpandOpsPassBase;

493

494 void runOnOperation() override {

497

499

500 target.addLegalDialectarith::ArithDialect();

501

// Ops that are always expanded by this pass.
502 target.addIllegalOp<

503 arith::CeilDivSIOp,

504 arith::CeilDivUIOp,

505 arith::FloorDivSIOp,

506 arith::MaxSIOp,

507 arith::MaxUIOp,

508 arith::MinSIOp,

509 arith::MinUIOp,

510 arith::MaximumFOp,

511 arith::MinimumFOp,

512 arith::MaxNumFOp,

513 arith::MinNumFOp,

514 arith::ScalingExtFOp,

515 arith::ScalingTruncFOp

516 >();

517

518 if (includeBf16) {

520 }

521 if (includeF8E8M0) {

523 }

524

// extf is legal unless it is one of the enabled special cases.
// `inETy`/`outETy` are defined on lines missing from this listing.
525 target.addDynamicallyLegalOparith::ExtFOp(

526 [=](arith::ExtFOp op) {

529 bool legalTypes = true;

530 if (includeBf16)

531 legalTypes &= !(inETy.isBF16() && outETy.isF32());

532 if (includeF8E8M0)

533 legalTypes &= !llvm::isa(inETy);

534 return legalTypes;

535 });

536

// truncf is legal unless it is one of the enabled special cases.
537 target.addDynamicallyLegalOparith::TruncFOp(

538 [=](arith::TruncFOp op) {

541 bool legalTypes = true;

542 if (includeBf16)

543 legalTypes &= !(inETy.isF32() && outETy.isBF16());

544 if (includeF8E8M0)

545 legalTypes &= !(llvm::isa(outETy));

546 return legalTypes;

547 });

548

549

552 signalPassFailure();

553 }

554 };

555

556 }

557

// Fragments of the populate*Patterns entry points. Their function headers are
// missing from this listing, so the names below are matched by content against
// the symbol index and should be confirmed.
// Presumably populateCeilFloorDivExpandOpsPatterns (ceil/floor division).
561 .add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(

563 }

564

// Presumably populateExpandBFloat16Patterns (bf16 <-> f32 extf/truncf).
566 patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(

568 }

569

// Presumably populateExpandF8E8M0Patterns (f8E8M0FNU extf/truncf).
571 patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(

573 }

574

// Presumably populateExpandScalingExtTruncPatterns (scaling_extf/truncf).
577 patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(

579 }

580

// Presumably populateArithExpandOpsPatterns: integer max/min, plus float
// maximumf/minimumf (NaN-propagating) and maxnumf/minnumf (NaN-ignoring).
584

586 MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,

587 MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,

588 MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,

589 MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,

590 MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,

591 MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,

592 MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,

593 MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>

595

596 }

static Value createConst(Location loc, Type type, int value, PatternRewriter &rewriter)

Create an integer or index constant.

static Type cloneToShapedType(Type cloneFrom, Type cloneTo)

Creates shapedType using shape from cloneFrom and base type from cloneTo.

static int64_t product(ArrayRef< int64_t > vals)

static MLIRContext * getContext(OpFoldResult val)

IntegerAttr getIntegerAttr(Type type, int64_t value)

This class describes a specific conversion target.

static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)

Constructs a dense elements attribute from an array of element values.

ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

unsigned getIntOrFloatBitWidth() const

Return the bit width of an integer or a float type, assert failure on other types.

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

Type getType() const

Return the type of this value.

void populateExpandBFloat16Patterns(RewritePatternSet &patterns)

Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.

void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns)

Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops.

void populateExpandF8E8M0Patterns(RewritePatternSet &patterns)

Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.

void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)

Add patterns to expand Arith ceil/floor division ops.

void populateArithExpandOpsPatterns(RewritePatternSet &patterns)

Add patterns to expand Arith ops.

int compare(const Fraction &x, const Fraction &y)

Three-way comparison between two fractions.

Include the generated interface declarations.

Type getElementTypeOrSelf(Type type)

Return the element type or return the type itself.

const FrozenRewritePatternSet & patterns

LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())

Below we define several entry points for operation conversion.

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.

OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})

Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.