MLIR: lib/Dialect/Index/IR/IndexOps.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

17 #include "llvm/ADT/SmallString.h"

18 #include "llvm/ADT/TypeSwitch.h"

19

20 using namespace mlir;

22

23

24

25

26

27 void IndexDialect::registerOperations() {

28 addOperations<

29 #define GET_OP_LIST

30 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"

31 >();

32 }

33

36

37 if (auto boolValue = dyn_cast(value)) {

39 return nullptr;

40 return b.create(loc, type, boolValue);

41 }

42

43

44 if (auto indexValue = dyn_cast(value)) {

45 if (!llvm::isa(indexValue.getType()) ||

46 !llvm::isa(type))

47 return nullptr;

48 assert(indexValue.getValue().getBitWidth() ==

49 IndexType::kInternalStorageBitWidth);

50 return b.create(loc, indexValue);

51 }

52

53 return nullptr;

54 }

55

56

57

58

59

60

61

62

63

64

65

66

67

68

71 function_ref<std::optional(const APInt &, const APInt &)>

72 calculate) {

73 assert(operands.size() == 2 && "binary operation expected 2 operands");

74 auto lhs = dyn_cast_if_present(operands[0]);

75 auto rhs = dyn_cast_if_present(operands[1]);

76 if (!lhs || !rhs)

77 return {};

78

79 std::optional result = calculate(lhs.getValue(), rhs.getValue());

80 if (!result)

81 return {};

82 assert(result->trunc(32) ==

83 calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)));

85 }

86

87

88

89

90

91

92

93

94

97 function_ref<std::optional(const APInt &, const APInt &lhs)>

98 calculate) {

99 assert(operands.size() == 2 && "binary operation expected 2 operands");

100 auto lhs = dyn_cast_if_present(operands[0]);

101 auto rhs = dyn_cast_if_present(operands[1]);

102

103 if (!lhs || !rhs)

104 return {};

105

106

107 std::optional result64 = calculate(lhs.getValue(), rhs.getValue());

108 if (!result64)

109 return {};

110 std::optional result32 =

111 calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32));

112 if (!result32)

113 return {};

114

115 if (result64->trunc(32) != *result32)

116 return {};

117

119 }

120

121

122

123

124 template

125 LogicalResult

129 return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");

130

131 auto lhsOp = op.getLhs().template getDefiningOp();

132 if (!lhsOp)

133 return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not the same BinaryOp");

134

136 return rewriter.notifyMatchFailure(op.getLoc(), "RHS of LHS op is not a constant");

137

138 Value c = rewriter.createOrFold(op->getLoc(), op.getRhs(),

139 lhsOp.getRhs());

141 return rewriter.notifyMatchFailure(op.getLoc(), "new BinaryOp was not folded");

142

144 return success();

145 }

146

147

148

149

150

151 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {

153 adaptor.getOperands(),

154 [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; }))

155 return result;

156

157 if (auto rhs = dyn_cast_or_null(adaptor.getRhs())) {

158

159 if (rhs.getValue().isZero())

160 return getLhs();

161 }

162

163 return {};

164 }

165

166 LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {

168 }

169

170

171

172

173

174 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {

176 adaptor.getOperands(),

177 [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; }))

178 return result;

179

180 if (auto rhs = dyn_cast_or_null(adaptor.getRhs())) {

181

182 if (rhs.getValue().isZero())

183 return getLhs();

184 }

185

186 return {};

187 }

188

189

190

191

192

193 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {

195 adaptor.getOperands(),

196 [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; }))

197 return result;

198

199 if (auto rhs = dyn_cast_or_null(adaptor.getRhs())) {

200

201 if (rhs.getValue().isOne())

202 return getLhs();

203

204 if (rhs.getValue().isZero())

205 return rhs;

206 }

207

208 return {};

209 }

210

211 LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {

213 }

214

215

216

217

218

219 OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {

221 adaptor.getOperands(),

222 [](const APInt &lhs, const APInt &rhs) -> std::optional {

223

224 if (rhs.isZero())

225 return std::nullopt;

226 return lhs.sdiv(rhs);

227 });

228 }

229

230

231

232

233

234 OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {

236 adaptor.getOperands(),

237 [](const APInt &lhs, const APInt &rhs) -> std::optional {

238

239 if (rhs.isZero())

240 return std::nullopt;

241 return lhs.udiv(rhs);

242 });

243 }

244

245

246

247

248

249

250

251 static std::optional calculateCeilDivS(const APInt &n, const APInt &m) {

252

253 if (m.isZero())

254 return std::nullopt;

255

256 if (n.isZero())

257 return n;

258

259 bool mGtZ = m.sgt(0);

260 if (n.sgt(0) != mGtZ) {

261

262

263

264 return -(-n).sdiv(m);

265 }

266

267

268 int64_t x = mGtZ ? -1 : 1;

269 return (n + x).sdiv(m) + 1;

270 }

271

272 OpFoldResult CeilDivSOp::fold(FoldAdaptor adaptor) {

274 }

275

276

277

278

279

280 OpFoldResult CeilDivUOp::fold(FoldAdaptor adaptor) {

281

283 adaptor.getOperands(),

284 [](const APInt &n, const APInt &m) -> std::optional {

285

286 if (m.isZero())

287 return std::nullopt;

288

289 if (n.isZero())

290 return n;

291

292 return (n - 1).udiv(m) + 1;

293 });

294 }

295

296

297

298

299

300

301

303

304 if (m.isZero())

305 return std::nullopt;

306

307 if (n.isZero())

308 return n;

309

310 bool mLtZ = m.slt(0);

311 if (n.slt(0) == mLtZ) {

312

313 return n.sdiv(m);

314 }

315

316

317

318 int64_t x = mLtZ ? 1 : -1;

319 return -1 - (x - n).sdiv(m);

320 }

321

322 OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) {

324 }

325

326

327

328

329

330 OpFoldResult RemSOp::fold(FoldAdaptor adaptor) {

332 adaptor.getOperands(),

333 [](const APInt &lhs, const APInt &rhs) -> std::optional {

334

335 if (rhs.isZero())

336 return std::nullopt;

337 return lhs.srem(rhs);

338 });

339 }

340

341

342

343

344

345 OpFoldResult RemUOp::fold(FoldAdaptor adaptor) {

347 adaptor.getOperands(),

348 [](const APInt &lhs, const APInt &rhs) -> std::optional {

349

350 if (rhs.isZero())

351 return std::nullopt;

352 return lhs.urem(rhs);

353 });

354 }

355

356

357

358

359

360 OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) {

362 [](const APInt &lhs, const APInt &rhs) {

363 return lhs.sgt(rhs) ? lhs : rhs;

364 });

365 }

366

367 LogicalResult MaxSOp::canonicalize(MaxSOp op, PatternRewriter &rewriter) {

369 }

370

371

372

373

374

375 OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {

377 [](const APInt &lhs, const APInt &rhs) {

378 return lhs.ugt(rhs) ? lhs : rhs;

379 });

380 }

381

382 LogicalResult MaxUOp::canonicalize(MaxUOp op, PatternRewriter &rewriter) {

384 }

385

386

387

388

389

390 OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {

392 [](const APInt &lhs, const APInt &rhs) {

393 return lhs.slt(rhs) ? lhs : rhs;

394 });

395 }

396

397 LogicalResult MinSOp::canonicalize(MinSOp op, PatternRewriter &rewriter) {

399 }

400

401

402

403

404

405 OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {

407 [](const APInt &lhs, const APInt &rhs) {

408 return lhs.ult(rhs) ? lhs : rhs;

409 });

410 }

411

412 LogicalResult MinUOp::canonicalize(MinUOp op, PatternRewriter &rewriter) {

414 }

415

416

417

418

419

420 OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {

422 adaptor.getOperands(),

423 [](const APInt &lhs, const APInt &rhs) -> std::optional {

424

425

426

427 if (rhs.uge(32))

428 return {};

429 return lhs << rhs;

430 });

431 }

432

433

434

435

436

437 OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {

439 adaptor.getOperands(),

440 [](const APInt &lhs, const APInt &rhs) -> std::optional {

441

442 if (rhs.uge(32))

443 return {};

444 return lhs.ashr(rhs);

445 });

446 }

447

448

449

450

451

452 OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {

454 adaptor.getOperands(),

455 [](const APInt &lhs, const APInt &rhs) -> std::optional {

456

457 if (rhs.uge(32))

458 return {};

459 return lhs.lshr(rhs);

460 });

461 }

462

463

464

465

466

467 OpFoldResult AndOp::fold(FoldAdaptor adaptor) {

469 adaptor.getOperands(),

470 [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });

471 }

472

473 LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {

475 }

476

477

478

479

480

481 OpFoldResult OrOp::fold(FoldAdaptor adaptor) {

483 adaptor.getOperands(),

484 [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });

485 }

486

487 LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {

489 }

490

491

492

493

494

495 OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {

497 adaptor.getOperands(),

498 [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });

499 }

500

501 LogicalResult XOrOp::canonicalize(XOrOp op, PatternRewriter &rewriter) {

503 }

504

505

506

507

508

511 function_ref<APInt(const APInt &, unsigned)> extFn,

512 function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) {

513 auto attr = dyn_cast_if_present(input);

514 if (!attr)

515 return {};

516 const APInt &value = attr.getValue();

517

518 if (isa(type)) {

519

520

521

522 APInt result = extOrTruncFn(value, 64);

524 }

525

526

527

528 auto intType = cast(type);

529 unsigned width = intType.getWidth();

530

531

532

533 if (width <= 32) {

534 APInt result = value.trunc(width);

536 }

537

538

539

540 if (width >= 64) {

541 if (extFn(value.trunc(32), 64) != value)

542 return {};

543 APInt result = extFn(value, width);

545 }

546

547

548 APInt result = value.trunc(width);

549 if (result != extFn(value.trunc(32), width))

550 return {};

552 }

553

554 bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {

555 return llvm::isa(lhsTypes.front()) !=

556 llvm::isa(rhsTypes.front());

557 }

558

559 OpFoldResult CastSOp::fold(FoldAdaptor adaptor) {

561 adaptor.getInput(), getType(),

562 [](const APInt &x, unsigned width) { return x.sext(width); },

563 [](const APInt &x, unsigned width) { return x.sextOrTrunc(width); });

564 }

565

566

567

568

569

570 bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {

571 return llvm::isa(lhsTypes.front()) !=

572 llvm::isa(rhsTypes.front());

573 }

574

575 OpFoldResult CastUOp::fold(FoldAdaptor adaptor) {

577 adaptor.getInput(), getType(),

578 [](const APInt &x, unsigned width) { return x.zext(width); },

579 [](const APInt &x, unsigned width) { return x.zextOrTrunc(width); });

580 }

581

582

583

584

585

586

588 IndexCmpPredicate pred) {

589 switch (pred) {

590 case IndexCmpPredicate::EQ:

591 return lhs.eq(rhs);

592 case IndexCmpPredicate::NE:

593 return lhs.ne(rhs);

594 case IndexCmpPredicate::SGE:

595 return lhs.sge(rhs);

596 case IndexCmpPredicate::SGT:

597 return lhs.sgt(rhs);

598 case IndexCmpPredicate::SLE:

599 return lhs.sle(rhs);

600 case IndexCmpPredicate::SLT:

601 return lhs.slt(rhs);

602 case IndexCmpPredicate::UGE:

603 return lhs.uge(rhs);

604 case IndexCmpPredicate::UGT:

605 return lhs.ugt(rhs);

606 case IndexCmpPredicate::ULE:

607 return lhs.ule(rhs);

608 case IndexCmpPredicate::ULT:

609 return lhs.ult(rhs);

610 }

611 llvm_unreachable("unhandled IndexCmpPredicate predicate");

612 }

613

614

615

616

617

619 const APInt &cstA,

620 const APInt &cstB, unsigned width,

621 IndexCmpPredicate pred) {

623 .Case([&](MinSOp op) {

624 return ConstantIntRanges::fromSigned(

625 APInt::getSignedMinValue(width), cstA);

626 })

627 .Case([&](MinUOp op) {

628 return ConstantIntRanges::fromUnsigned(

629 APInt::getMinValue(width), cstA);

630 })

631 .Case([&](MaxSOp op) {

632 return ConstantIntRanges::fromSigned(

633 cstA, APInt::getSignedMaxValue(width));

634 })

635 .Case([&](MaxUOp op) {

636 return ConstantIntRanges::fromUnsigned(

637 cstA, APInt::getMaxValue(width));

638 });

640 lhsRange, ConstantIntRanges::constant(cstB));

641 }

642

643

645 switch (pred) {

646 case IndexCmpPredicate::EQ:

647 case IndexCmpPredicate::SGE:

648 case IndexCmpPredicate::SLE:

649 case IndexCmpPredicate::UGE:

650 case IndexCmpPredicate::ULE:

651 return true;

652 case IndexCmpPredicate::NE:

653 case IndexCmpPredicate::SGT:

654 case IndexCmpPredicate::SLT:

655 case IndexCmpPredicate::UGT:

656 case IndexCmpPredicate::ULT:

657 return false;

658 }

659 llvm_unreachable("unknown predicate in compareSameArgs");

660 }

661

662 OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {

663

664 auto lhs = dyn_cast_if_present(adaptor.getLhs());

665 auto rhs = dyn_cast_if_present(adaptor.getRhs());

666 if (lhs && rhs) {

667

668 bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred());

669 bool result32 = compareIndices(lhs.getValue().trunc(32),

670 rhs.getValue().trunc(32), getPred());

671 if (result64 == result32)

673 }

674

675

676 Operation *lhsOp = getLhs().getDefiningOp();

677 IntegerAttr cstA;

678 if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(lhsOp) &&

681 lhsOp, cstA.getValue(), rhs.getValue(), 64, getPred());

682 std::optional result32 =

684 rhs.getValue().trunc(32), 32, getPred());

685

686 if (result64 && result32 && *result64 == *result32)

688 }

689

690

691 if (getLhs() == getRhs())

693

694 return {};

695 }

696

697

698

699

700 LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {

701 IntegerAttr cmpRhs;

702 IntegerAttr cmpLhs;

703

705 cmpRhs.getValue().isZero();

707 cmpLhs.getValue().isZero();

708 if (!rhsIsZero && !lhsIsZero)

710 "cmp is not comparing something with 0");

711 SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOpindex::SubOp()

712 : op.getRhs().getDefiningOpindex::SubOp();

713 if (!subOp)

715 op.getLoc(), "non-zero operand is not a result of subtraction");

716

717 index::CmpOp newCmp;

718 if (rhsIsZero)

719 newCmp = rewriter.createindex::CmpOp(op.getLoc(), op.getPred(),

720 subOp.getLhs(), subOp.getRhs());

721 else

722 newCmp = rewriter.createindex::CmpOp(op.getLoc(), op.getPred(),

723 subOp.getRhs(), subOp.getLhs());

725 return success();

726 }

727

728

729

730

731

732 void ConstantOp::getAsmResultNames(

735 llvm::raw_svector_ostream specialName(specialNameBuffer);

736 specialName << "idx" << getValueAttr().getValue();

737 setNameFn(getResult(), specialName.str());

738 }

739

740 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }

741

744 }

745

746

747

748

749

750 OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {

751 return getValueAttr();

752 }

753

754 void BoolConstantOp::getAsmResultNames(

756 setNameFn(getResult(), getValue() ? "true" : "false");

757 }

758

759

760

761

762

763 #define GET_OP_CLASSES

764 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"

static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)

A utility function used to materialize a constant for a given attribute and type.

static OpFoldResult foldBinaryOpUnchecked(ArrayRef< Attribute > operands, function_ref< std::optional< APInt >(const APInt &, const APInt &)> calculate)

Fold an index operation irrespective of the target bitwidth.

LogicalResult canonicalizeAssociativeCommutativeBinaryOp(BinaryOp op, PatternRewriter &rewriter)

Helper for associative and commutative binary ops that can be transformed: x = op(v,...

bool compareIndices(const APInt &lhs, const APInt &rhs, IndexCmpPredicate pred)

Compare two integers according to the comparison predicate.

static OpFoldResult foldBinaryOpChecked(ArrayRef< Attribute > operands, function_ref< std::optional< APInt >(const APInt &, const APInt &lhs)> calculate)

Fold an index operation only if the truncated 64-bit result matches the 32-bit result for operations ...

static std::optional< bool > foldCmpOfMaxOrMin(Operation *lhsOp, const APInt &cstA, const APInt &cstB, unsigned width, IndexCmpPredicate pred)

cmp(max/min(x, cstA), cstB) can be folded to a constant depending on the values of cstA and cstB,...

static OpFoldResult foldCastOp(Attribute input, Type type, function_ref< APInt(const APInt &, unsigned)> extFn, function_ref< APInt(const APInt &, unsigned)> extOrTruncFn)

static std::optional< APInt > calculateCeilDivS(const APInt &n, const APInt &m)

Compute ceildivs(n, m) as x = m > 0 ? -1 : 1 and then n*m > 0 ? (n+x)/m + 1 : -(-n/m).

static bool compareSameArgs(IndexCmpPredicate pred)

Return the result of cmp(pred, x, x).

static std::optional< APInt > calculateFloorDivS(const APInt &n, const APInt &m)

Compute floordivs(n, m) as x = m < 0 ? 1 : -1 and then n*m < 0 ? -1 - (x-n)/m : n/m.

static MLIRContext * getContext(OpFoldResult val)

Attributes are known-constant values of operations.

IntegerAttr getIndexAttr(int64_t value)

A set of arbitrary-precision integers representing bounds on a given integer value.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

This class helps build Operations.

void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)

Create an operation of specific op type at the current insertion point, and immediately try to fold i...

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

This class represents a single result from folding an operation.

Operation is the basic unit of execution within MLIR.

Value getOperand(unsigned idx)

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure,...

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements...

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

bool isSignlessInteger() const

Return true if this is a signless integer type (with the specified width).

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)

Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...

CmpPredicate

Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...

Include the generated interface declarations.

bool matchPattern(Value value, const Pattern &pattern)

Entry point for matching a pattern over a Value.

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

detail::constant_op_matcher m_Constant()

Matches a constant foldable operation.

This represents an operation in an abstracted form, suitable for use with the builder APIs.