MLIR: lib/Dialect/Index/IR/IndexOps.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
17 #include "llvm/ADT/SmallString.h"
18 #include "llvm/ADT/TypeSwitch.h"
19
20 using namespace mlir;
22
23
24
25
26
// Registers every operation of the index dialect with the dialect instance.
// The op class list is not written by hand: defining GET_OP_LIST before
// including the generated IndexOps.cpp.inc selects (presumably) the
// comma-separated op list section of that TableGen output, which becomes the
// template argument list of addOperations<...>.
27 void IndexDialect::registerOperations() {
28 addOperations<
29 #define GET_OP_LIST
30 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
31 >();
32 }
33
// NOTE(review): extracted fragment. This appears to be the dialect's constant
// materialization hook; the page index describes materializeConstant as "A
// utility function used to materialize a constant for a given attribute and
// type". Its signature (source lines 34-35), source line 38, and every
// template argument (dyn_cast<...>, isa<...>, create<...>) were lost during
// extraction, so the exact attribute and op types are not recoverable here.
36
// First case: attributes matched as `boolValue` (presumably a bool attribute)
// are built into an op from (loc, type, boolValue). The lost source line 38
// holds the condition of an early `return nullptr;` guard.
37 if (auto boolValue = dyn_cast(value)) {
39 return nullptr;
40 return b.create(loc, type, boolValue);
41 }
42
43
// Second case: attributes matched as `indexValue` are only materialized when
// both the attribute's own type and the requested `type` pass the (stripped)
// isa<> checks. The stored value must be exactly
// IndexType::kInternalStorageBitWidth bits wide (debug-build assertion).
44 if (auto indexValue = dyn_cast(value)) {
45 if (!llvm::isa(indexValue.getType()) ||
46 !llvm::isa(type))
47 return nullptr;
48 assert(indexValue.getValue().getBitWidth() ==
49 IndexType::kInternalStorageBitWidth);
50 return b.create(loc, indexValue);
51 }
52
// Any other attribute kind is not materialized by this hook.
53 return nullptr;
54 }
55
56
57
58
59
60
61
62
63
64
65
66
67
68
// Fold an index operation irrespective of the target bitwidth.
// NOTE(review): extracted fragment; the leading signature lines (69-70) and
// the final return (line 84) were lost. Per the page index the signature is
// foldBinaryOpUnchecked(ArrayRef<Attribute> operands,
//                       function_ref<std::optional<APInt>(const APInt &, const APInt &)> calculate).
// Behavior visible below:
//  - requires exactly two operands (assertion);
//  - does nothing unless both operands are present and cast-able to the
//    (stripped) constant attribute type;
//  - applies `calculate` to the full-width values; an empty optional means
//    "do not fold";
//  - in debug builds, asserts that truncating the result to 32 bits matches
//    `calculate` applied to the 32-bit-truncated operands. In other words,
//    callers promise the operation is bitwidth-agnostic. A 32-bit `calculate`
//    that returns nullopt also trips the assertion.
71 function_ref<std::optional(const APInt &, const APInt &)>
72 calculate) {
73 assert(operands.size() == 2 && "binary operation expected 2 operands");
74 auto lhs = dyn_cast_if_present(operands[0]);
75 auto rhs = dyn_cast_if_present(operands[1]);
76 if (!lhs || !rhs)
77 return {};
78
79 std::optional result = calculate(lhs.getValue(), rhs.getValue());
80 if (!result)
81 return {};
82 assert(result->trunc(32) ==
83 calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)));
85 }
86
87
88
89
90
91
92
93
94
// Fold an index operation only if the truncated 64-bit result matches the
// 32-bit result, i.e. the fold is valid whichever index width the target uses.
// NOTE(review): extracted fragment; the leading signature lines (95-96) and
// the final return (line 118) were lost.
// NOTE(review): the second parameter of the callback type is named `lhs`
// (`const APInt &lhs`). That name has no effect, but it reads as if both
// parameters were the LHS; consider dropping or renaming it.
97 function_ref<std::optional(const APInt &, const APInt &lhs)>
98 calculate) {
99 assert(operands.size() == 2 && "binary operation expected 2 operands");
100 auto lhs = dyn_cast_if_present(operands[0]);
101 auto rhs = dyn_cast_if_present(operands[1]);
102
// Both operands must be constants for any fold to happen.
103 if (!lhs || !rhs)
104 return {};
105
106
// Evaluate at full width; an empty optional means "do not fold" (e.g. the
// callback refuses a division by zero).
107 std::optional result64 = calculate(lhs.getValue(), rhs.getValue());
108 if (!result64)
109 return {};
// Evaluate again on the operands truncated to 32 bits.
110 std::optional result32 =
111 calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32));
112 if (!result32)
113 return {};
114
// Refuse to fold if the two target widths would disagree on the low 32 bits.
115 if (result64->trunc(32) != *result32)
116 return {};
117
119 }
120
121
122
123
// Helper pattern for associative and commutative binary ops.
// NOTE(review): extracted fragment; the template parameter list, the function
// name/parameters (lines 126-128), and lines 135, 140 and 143 were lost. Per
// the page index the signature is
// canonicalizeAssociativeCommutativeBinaryOp(BinaryOp op, PatternRewriter &rewriter).
// Visible preconditions, checked in order:
//  1. op's RHS is a constant;
//  2. op's LHS is produced by an op of the same kind (`lhsOp`);
//  3. lhsOp's RHS is a constant.
// It then builds (createOrFold) a new op combining op.getRhs() and
// lhsOp.getRhs(), and fails unless that new op folded away. The actual rewrite
// (line 143) is not shown; presumably it turns op(op(v, c1), c2) into
// op(v, fold(op(c2, c1))) — confirm against the full source.
124 template
125 LogicalResult
129 return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
130
131 auto lhsOp = op.getLhs().template getDefiningOp();
132 if (!lhsOp)
133 return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not the same BinaryOp");
134
136 return rewriter.notifyMatchFailure(op.getLoc(), "RHS of LHS op is not a constant");
137
// Combine the two constants with the same op kind; expected to fold.
138 Value c = rewriter.createOrFold(op->getLoc(), op.getRhs(),
139 lhsOp.getRhs());
141 return rewriter.notifyMatchFailure(op.getLoc(), "new BinaryOp was not folded");
142
144 return success();
145 }
146
147
148
149
150
// AddOp folding: when both operands are constants, a binary-fold helper
// computes lhs + rhs. The helper call on source line 152 was lost in
// extraction; presumably it is foldBinaryOpUnchecked, since addition is
// bitwidth-agnostic.
151 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
153 adaptor.getOperands(),
154 [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; }))
155 return result;
156
// Identity: x + 0 folds to x.
157 if (auto rhs = dyn_cast_or_null(adaptor.getRhs())) {
158
159 if (rhs.getValue().isZero())
160 return getLhs();
161 }
162
163 return {};
164 }
165
// AddOp canonicalization. The body (source line 167) was lost; presumably it
// forwards to canonicalizeAssociativeCommutativeBinaryOp — confirm.
166 LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
168 }
169
170
171
172
173
// SubOp folding: when both operands are constants, a binary-fold helper
// computes lhs - rhs. The helper call on source line 175 was lost in
// extraction.
174 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
176 adaptor.getOperands(),
177 [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; }))
178 return result;
179
// Identity: x - 0 folds to x.
180 if (auto rhs = dyn_cast_or_null(adaptor.getRhs())) {
181
182 if (rhs.getValue().isZero())
183 return getLhs();
184 }
185
186 return {};
187 }
188
189
190
191
192
// MulOp folding: when both operands are constants, a binary-fold helper
// computes lhs * rhs. The helper call on source line 194 was lost in
// extraction.
193 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
195 adaptor.getOperands(),
196 [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; }))
197 return result;
198
199 if (auto rhs = dyn_cast_or_null(adaptor.getRhs())) {
200
// Identity: x * 1 folds to x.
201 if (rhs.getValue().isOne())
202 return getLhs();
203
// Annihilator: x * 0 folds to the zero constant attribute itself.
204 if (rhs.getValue().isZero())
205 return rhs;
206 }
207
208 return {};
209 }
210
// MulOp canonicalization. The body (source line 212) was lost; presumably it
// forwards to canonicalizeAssociativeCommutativeBinaryOp — confirm.
211 LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
213 }
214
215
216
217
218
// DivSOp folding: signed truncating division of two constants. A zero divisor
// yields an empty optional, so division by zero is never folded. The helper
// call on source line 220 was lost in extraction.
219 OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
221 adaptor.getOperands(),
222 [](const APInt &lhs, const APInt &rhs) -> std::optional {
223
224 if (rhs.isZero())
225 return std::nullopt;
226 return lhs.sdiv(rhs);
227 });
228 }
229
230
231
232
233
// DivUOp folding: unsigned division of two constants. A zero divisor yields an
// empty optional, so division by zero is never folded. The helper call on
// source line 235 was lost in extraction.
234 OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
236 adaptor.getOperands(),
237 [](const APInt &lhs, const APInt &rhs) -> std::optional {
238
239 if (rhs.isZero())
240 return std::nullopt;
241 return lhs.udiv(rhs);
242 });
243 }
244
245
246
247
248
249
250
251 static std::optional calculateCeilDivS(const APInt &n, const APInt &m) {
252
253 if (m.isZero())
254 return std::nullopt;
255
256 if (n.isZero())
257 return n;
258
259 bool mGtZ = m.sgt(0);
260 if (n.sgt(0) != mGtZ) {
261
262
263
264 return -(-n).sdiv(m);
265 }
266
267
268 int64_t x = mGtZ ? -1 : 1;
269 return (n + x).sdiv(m) + 1;
270 }
271
// CeilDivSOp folding. The body (source line 273) was lost in extraction.
// Presumably it folds constant operands through calculateCeilDivS with a
// width-consistency check — confirm.
272 OpFoldResult CeilDivSOp::fold(FoldAdaptor adaptor) {
274 }
275
276
277
278
279
// CeilDivUOp folding: unsigned division rounded up, ceil(n / m). The helper
// call on source line 282 was lost in extraction.
280 OpFoldResult CeilDivUOp::fold(FoldAdaptor adaptor) {
281
283 adaptor.getOperands(),
284 [](const APInt &n, const APInt &m) -> std::optional {
285
// Don't fold division by zero.
286 if (m.isZero())
287 return std::nullopt;
288
// ceil(0 / m) == 0; handled separately because n - 1 would wrap below.
289 if (n.isZero())
290 return n;
291
// For n >= 1: ceil(n / m) == (n - 1) / m + 1, which cannot overflow.
292 return (n - 1).udiv(m) + 1;
293 });
294 }
295
296
297
298
299
300
301
// Compute floordivs(n, m), the signed quotient rounded towards negative
// infinity: with x = m < 0 ? 1 : -1, the result is
// n*m < 0 ? -1 - (x - n) / m : n / m.
// NOTE(review): extracted fragment; the signature line (source line 302) was
// lost. Per the page index it is
// static std::optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m).
303
// Don't fold division by zero.
304 if (m.isZero())
305 return std::nullopt;
306
// floor(0 / m) == 0.
307 if (n.isZero())
308 return n;
309
310 bool mLtZ = m.slt(0);
311 if (n.slt(0) == mLtZ) {
312
// Same signs: the quotient is positive, so truncation equals floor.
313 return n.sdiv(m);
314 }
315
316
317
// Different signs: the quotient is negative. x - n moves n one step towards
// zero before dividing, so that subtracting from -1 rounds down.
318 int64_t x = mLtZ ? 1 : -1;
319 return -1 - (x - n).sdiv(m);
320 }
321
// FloorDivSOp folding. The body (source line 323) was lost in extraction.
// Presumably it folds constant operands through calculateFloorDivS with a
// width-consistency check — confirm.
322 OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) {
324 }
325
326
327
328
329
// RemSOp folding: signed remainder of two constants. A zero divisor yields an
// empty optional, so it is never folded. The helper call on source line 331
// was lost in extraction.
330 OpFoldResult RemSOp::fold(FoldAdaptor adaptor) {
332 adaptor.getOperands(),
333 [](const APInt &lhs, const APInt &rhs) -> std::optional {
334
335 if (rhs.isZero())
336 return std::nullopt;
337 return lhs.srem(rhs);
338 });
339 }
340
341
342
343
344
// RemUOp folding: unsigned remainder of two constants. A zero divisor yields
// an empty optional, so it is never folded. The helper call on source line 346
// was lost in extraction.
345 OpFoldResult RemUOp::fold(FoldAdaptor adaptor) {
347 adaptor.getOperands(),
348 [](const APInt &lhs, const APInt &rhs) -> std::optional {
349
350 if (rhs.isZero())
351 return std::nullopt;
352 return lhs.urem(rhs);
353 });
354 }
355
356
357
358
359
// MaxSOp folding: for two constants, pick the larger one under a signed
// comparison. The helper call on source line 361 was lost in extraction.
360 OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) {
362 [](const APInt &lhs, const APInt &rhs) {
363 return lhs.sgt(rhs) ? lhs : rhs;
364 });
365 }
366
// MaxSOp canonicalization. The body (source line 368) was lost; presumably it
// forwards to canonicalizeAssociativeCommutativeBinaryOp — confirm.
367 LogicalResult MaxSOp::canonicalize(MaxSOp op, PatternRewriter &rewriter) {
369 }
370
371
372
373
374
// MaxUOp folding: for two constants, pick the larger one under an unsigned
// comparison. The helper call on source line 376 was lost in extraction.
375 OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {
377 [](const APInt &lhs, const APInt &rhs) {
378 return lhs.ugt(rhs) ? lhs : rhs;
379 });
380 }
381
// MaxUOp canonicalization. The body (source line 383) was lost; presumably it
// forwards to canonicalizeAssociativeCommutativeBinaryOp — confirm.
382 LogicalResult MaxUOp::canonicalize(MaxUOp op, PatternRewriter &rewriter) {
384 }
385
386
387
388
389
// MinSOp folding: for two constants, pick the smaller one under a signed
// comparison. The helper call on source line 391 was lost in extraction.
390 OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
392 [](const APInt &lhs, const APInt &rhs) {
393 return lhs.slt(rhs) ? lhs : rhs;
394 });
395 }
396
// MinSOp canonicalization. The body (source line 398) was lost; presumably it
// forwards to canonicalizeAssociativeCommutativeBinaryOp — confirm.
397 LogicalResult MinSOp::canonicalize(MinSOp op, PatternRewriter &rewriter) {
399 }
400
401
402
403
404
// MinUOp folding: for two constants, pick the smaller one under an unsigned
// comparison. The helper call on source line 406 was lost in extraction.
405 OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
407 [](const APInt &lhs, const APInt &rhs) {
408 return lhs.ult(rhs) ? lhs : rhs;
409 });
410 }
411
// MinUOp canonicalization. The body (source line 413) was lost; presumably it
// forwards to canonicalizeAssociativeCommutativeBinaryOp — confirm.
412 LogicalResult MinUOp::canonicalize(MinUOp op, PatternRewriter &rewriter) {
414 }
415
416
417
418
419
// ShlOp folding: left shift of two constants. Shift amounts of 32 or more are
// not folded (empty optional), presumably because 32-bit and 64-bit targets
// would disagree on the result — confirm. The helper call on source line 421
// was lost in extraction.
420 OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
422 adaptor.getOperands(),
423 [](const APInt &lhs, const APInt &rhs) -> std::optional {
424
425
426
427 if (rhs.uge(32))
428 return {};
429 return lhs << rhs;
430 });
431 }
432
433
434
435
436
// ShrSOp folding: arithmetic (sign-propagating) right shift of two constants.
// Shift amounts of 32 or more are not folded (empty optional). The helper call
// on source line 438 was lost in extraction.
437 OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
439 adaptor.getOperands(),
440 [](const APInt &lhs, const APInt &rhs) -> std::optional {
441
442 if (rhs.uge(32))
443 return {};
444 return lhs.ashr(rhs);
445 });
446 }
447
448
449
450
451
// ShrUOp folding: logical (zero-filling) right shift of two constants. Shift
// amounts of 32 or more are not folded (empty optional). The helper call on
// source line 453 was lost in extraction.
452 OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
454 adaptor.getOperands(),
455 [](const APInt &lhs, const APInt &rhs) -> std::optional {
456
457 if (rhs.uge(32))
458 return {};
459 return lhs.lshr(rhs);
460 });
461 }
462
463
464
465
466
// AndOp folding: bitwise AND of two constants. The helper call on source line
// 468 was lost in extraction.
467 OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
469 adaptor.getOperands(),
470 [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
471 }
472
// AndOp canonicalization. The body (source line 474) was lost; presumably it
// forwards to canonicalizeAssociativeCommutativeBinaryOp — confirm.
473 LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
475 }
476
477
478
479
480
// OrOp folding: bitwise OR of two constants. The helper call on source line
// 482 was lost in extraction.
481 OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
483 adaptor.getOperands(),
484 [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
485 }
486
// OrOp canonicalization. The body (source line 488) was lost; presumably it
// forwards to canonicalizeAssociativeCommutativeBinaryOp — confirm.
487 LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
489 }
490
491
492
493
494
// XOrOp folding: bitwise XOR of two constants. The helper call on source line
// 496 was lost in extraction.
495 OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
497 adaptor.getOperands(),
498 [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
499 }
500
// XOrOp canonicalization. The body (source line 502) was lost; presumably it
// forwards to canonicalizeAssociativeCommutativeBinaryOp — confirm.
501 LogicalResult XOrOp::canonicalize(XOrOp op, PatternRewriter &rewriter) {
503 }
504
505
506
507
508
// Shared constant folder for the sign- and zero-extending cast ops. `extFn`
// extends a value to a wider width; `extOrTruncFn` extends or truncates it.
// NOTE(review): extracted fragment; the signature head (lines 509-510), every
// `return` of a folded attribute (lines 523, 535, 544, 551), all comments and
// all template arguments were lost. Per the page index the signature is
// foldCastOp(Attribute input, Type type, extFn, extOrTruncFn).
511 function_ref<APInt(const APInt &, unsigned)> extFn,
512 function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) {
// Only constant inputs can be folded.
513 auto attr = dyn_cast_if_present(input);
514 if (!attr)
515 return {};
516 const APInt &value = attr.getValue();
517
// Target type passes the (stripped) isa<> check — presumably the cast into
// index. The value is brought to 64 bits.
518 if (isa(type)) {
519
520
521
522 APInt result = extOrTruncFn(value, 64);
524 }
525
526
527
// Otherwise the target is an integer-like type with an explicit width.
528 auto intType = cast(type);
529 unsigned width = intType.getWidth();
530
531
532
// Target no wider than 32 bits: truncation gives the same bits for either
// index width.
533 if (width <= 32) {
534 APInt result = value.trunc(width);
536 }
537
538
539
// Target 64 bits or wider: fold only if the value equals the extension of its
// own low 32 bits, so that 32-bit and 64-bit index targets agree.
540 if (width >= 64) {
541 if (extFn(value.trunc(32), 64) != value)
542 return {};
543 APInt result = extFn(value, width);
545 }
546
547
// Width strictly between 32 and 64: the truncated value must match what a
// 32-bit index would extend to.
548 APInt result = value.trunc(width);
549 if (result != extFn(value.trunc(32), width))
550 return {};
552 }
553
// Two types are cast-compatible for CastSOp when exactly one of them satisfies
// the (stripped during extraction) isa<> check, i.e. the cast must cross that
// type boundary. Presumably the check is for the index type — confirm.
554 bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
555 return llvm::isa(lhsTypes.front()) !=
556 llvm::isa(rhsTypes.front());
557 }
558
// CastSOp folding: delegates (the call on source line 560 was lost; presumably
// foldCastOp) with sign extension and sign-extend-or-truncate.
559 OpFoldResult CastSOp::fold(FoldAdaptor adaptor) {
561 adaptor.getInput(), getType(),
562 [](const APInt &x, unsigned width) { return x.sext(width); },
563 [](const APInt &x, unsigned width) { return x.sextOrTrunc(width); });
564 }
565
566
567
568
569
// Two types are cast-compatible for CastUOp when exactly one of them satisfies
// the (stripped during extraction) isa<> check, i.e. the cast must cross that
// type boundary. Presumably the check is for the index type — confirm.
570 bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
571 return llvm::isa(lhsTypes.front()) !=
572 llvm::isa(rhsTypes.front());
573 }
574
// CastUOp folding: delegates (the call on source line 576 was lost; presumably
// foldCastOp) with zero extension and zero-extend-or-truncate.
575 OpFoldResult CastUOp::fold(FoldAdaptor adaptor) {
577 adaptor.getInput(), getType(),
578 [](const APInt &x, unsigned width) { return x.zext(width); },
579 [](const APInt &x, unsigned width) { return x.zextOrTrunc(width); });
580 }
581
582
583
584
585
586
// Compare two integers according to the comparison predicate: each
// IndexCmpPredicate maps to the matching APInt comparison (S* = signed,
// U* = unsigned).
// NOTE(review): extracted fragment; the first signature line (587) was lost.
// Per the page index it is
// bool compareIndices(const APInt &lhs, const APInt &rhs, IndexCmpPredicate pred).
588 IndexCmpPredicate pred) {
589 switch (pred) {
590 case IndexCmpPredicate::EQ:
591 return lhs.eq(rhs);
592 case IndexCmpPredicate::NE:
593 return lhs.ne(rhs);
594 case IndexCmpPredicate::SGE:
595 return lhs.sge(rhs);
596 case IndexCmpPredicate::SGT:
597 return lhs.sgt(rhs);
598 case IndexCmpPredicate::SLE:
599 return lhs.sle(rhs);
600 case IndexCmpPredicate::SLT:
601 return lhs.slt(rhs);
602 case IndexCmpPredicate::UGE:
603 return lhs.uge(rhs);
604 case IndexCmpPredicate::UGT:
605 return lhs.ugt(rhs);
606 case IndexCmpPredicate::ULE:
607 return lhs.ule(rhs);
608 case IndexCmpPredicate::ULT:
609 return lhs.ult(rhs);
610 }
// Every enumerator returns above; reaching here means an invalid enum value.
611 llvm_unreachable("unhandled IndexCmpPredicate predicate");
612 }
613
614
615
616
617
// cmp(max/min(x, cstA), cstB) can sometimes be folded to a constant based only
// on cstA and cstB. The range of possible LHS values is bounded on one side by
// cstA:
//   mins(x, cstA) in [SIGNED_MIN, cstA]    minu(x, cstA) in [0, cstA]
//   maxs(x, cstA) in [cstA, SIGNED_MAX]    maxu(x, cstA) in [cstA, UNSIGNED_MAX]
// and that range is compared against the single value cstB at bit `width`.
// NOTE(review): extracted fragment; the signature head (line 618), the
// TypeSwitch header (line 622) and the evaluation call (line 639) were lost.
// Per the page index the function returns std::optional<bool> (presumably
// empty when the predicate is not decided by the ranges) and line 639
// presumably calls evaluatePred — confirm.
619 const APInt &cstA,
620 const APInt &cstB, unsigned width,
621 IndexCmpPredicate pred) {
623 .Case([&](MinSOp op) {
624 return ConstantIntRanges::fromSigned(
625 APInt::getSignedMinValue(width), cstA);
626 })
627 .Case([&](MinUOp op) {
628 return ConstantIntRanges::fromUnsigned(
629 APInt::getMinValue(width), cstA);
630 })
631 .Case([&](MaxSOp op) {
632 return ConstantIntRanges::fromSigned(
633 cstA, APInt::getSignedMaxValue(width));
634 })
635 .Case([&](MaxUOp op) {
636 return ConstantIntRanges::fromUnsigned(
637 cstA, APInt::getMaxValue(width));
638 });
640 lhsRange, ConstantIntRanges::constant(cstB));
641 }
642
643
// Return the result of cmp(pred, x, x): reflexive predicates (EQ, SGE, SLE,
// UGE, ULE) are true and strict or inequality predicates are false.
// NOTE(review): extracted fragment; the signature line (644) was lost. Per the
// page index it is static bool compareSameArgs(IndexCmpPredicate pred).
645 switch (pred) {
646 case IndexCmpPredicate::EQ:
647 case IndexCmpPredicate::SGE:
648 case IndexCmpPredicate::SLE:
649 case IndexCmpPredicate::UGE:
650 case IndexCmpPredicate::ULE:
651 return true;
652 case IndexCmpPredicate::NE:
653 case IndexCmpPredicate::SGT:
654 case IndexCmpPredicate::SLT:
655 case IndexCmpPredicate::UGT:
656 case IndexCmpPredicate::ULT:
657 return false;
658 }
659 llvm_unreachable("unknown predicate in compareSameArgs");
660 }
661
// CmpOp folding, tried in order:
//  1. Both operands constant: evaluate the predicate on the 64-bit values and
//     on their 32-bit truncations, and fold only if both agree. The return on
//     source line 672 was lost.
//  2. LHS defined by MinSOp/MinUOp/MaxSOp/MaxUOp: decide the comparison with
//     foldCmpOfMaxOrMin at widths 64 and 32, and fold only if both give the
//     same known answer. Source lines 679-680 (the rest of the condition,
//     which presumably binds cstA and checks that rhs is constant), 683 and
//     687 were lost.
//  3. Identical operands: the result depends only on the predicate
//     (presumably via compareSameArgs; source line 692 was lost).
// NOTE(review): branch 2 calls rhs.getValue(); make sure the lost condition
// lines guard against a non-constant RHS.
662 OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
663
664 auto lhs = dyn_cast_if_present(adaptor.getLhs());
665 auto rhs = dyn_cast_if_present(adaptor.getRhs());
666 if (lhs && rhs) {
667
668 bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred());
669 bool result32 = compareIndices(lhs.getValue().trunc(32),
670 rhs.getValue().trunc(32), getPred());
671 if (result64 == result32)
673 }
674
675
676 Operation *lhsOp = getLhs().getDefiningOp();
677 IntegerAttr cstA;
678 if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(lhsOp) &&
681 lhsOp, cstA.getValue(), rhs.getValue(), 64, getPred());
682 std::optional result32 =
684 rhs.getValue().trunc(32), 32, getPred());
685
686 if (result64 && result32 && *result64 == *result32)
688 }
689
690
691 if (getLhs() == getRhs())
693
694 return {};
695 }
696
697
698
699
// Canonicalizes a comparison between a subtraction and zero:
//   cmp(pred, a - b, 0) -> cmp(pred, a, b)
//   cmp(pred, 0, a - b) -> cmp(pred, b, a)
// The predicate is reused unchanged (op.getPred()).
// NOTE(review): extracted fragment; source lines 704 and 706 (presumably the
// constant matches binding cmpRhs/cmpLhs), 709 and 714 (heads of the
// failure returns) and 724 (presumably replacing op with newCmp) were lost.
// NOTE(review): this rewrite only looks sound for EQ/NE. For unsigned
// predicates it is wrong even without overflow: `ult(1 - 2, 0)` is false,
// since no value is unsigned-less-than 0, while `ult(1, 2)` is true. Signed
// predicates are also affected when `a - b` wraps (e.g. a = INT_MIN, b = 1).
// Consider restricting the pattern to EQ/NE — confirm against upstream.
700 LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {
701 IntegerAttr cmpRhs;
702 IntegerAttr cmpLhs;
703
705 cmpRhs.getValue().isZero();
707 cmpLhs.getValue().isZero();
708 if (!rhsIsZero && !lhsIsZero)
710 "cmp is not comparing something with 0");
// The non-zero side must be produced by a subtraction.
711 SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOpindex::SubOp()
712 : op.getRhs().getDefiningOpindex::SubOp();
713 if (!subOp)
715 op.getLoc(), "non-zero operand is not a result of subtraction");
716
717 index::CmpOp newCmp;
// Zero on the right: compare the subtraction's operands in order; zero on
// the left: compare them swapped.
718 if (rhsIsZero)
719 newCmp = rewriter.createindex::CmpOp(op.getLoc(), op.getPred(),
720 subOp.getLhs(), subOp.getRhs());
721 else
722 newCmp = rewriter.createindex::CmpOp(op.getLoc(), op.getPred(),
723 subOp.getRhs(), subOp.getLhs());
725 return success();
726 }
727
728
729
730
731
// Suggests a readable SSA name for the result: "idx" followed by the constant's
// value (e.g. a constant 42 would be named %idx42).
// NOTE(review): extracted fragment; the parameter line and the declaration of
// specialNameBuffer (source lines 733-734) were lost.
732 void ConstantOp::getAsmResultNames(
735 llvm::raw_svector_ostream specialName(specialNameBuffer);
736 specialName << "idx" << getValueAttr().getValue();
737 setNameFn(getResult(), specialName.str());
738 }
739
740 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
741
744 }
745
746
747
748
749
750 OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
751 return getValueAttr();
752 }
753
// Suggests the SSA result name "true" or "false" depending on the constant's
// value.
// NOTE(review): extracted fragment; the parameter line (source line 755) was
// lost.
754 void BoolConstantOp::getAsmResultNames(
756 setNameFn(getResult(), getValue() ? "true" : "false");
757 }
758
759
760
761
762
763 #define GET_OP_CLASSES
764 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static OpFoldResult foldBinaryOpUnchecked(ArrayRef< Attribute > operands, function_ref< std::optional< APInt >(const APInt &, const APInt &)> calculate)
Fold an index operation irrespective of the target bitwidth.
LogicalResult canonicalizeAssociativeCommutativeBinaryOp(BinaryOp op, PatternRewriter &rewriter)
Helper for associative and commutative binary ops that can be transformed: x = op(v,...
bool compareIndices(const APInt &lhs, const APInt &rhs, IndexCmpPredicate pred)
Compare two integers according to the comparison predicate.
static OpFoldResult foldBinaryOpChecked(ArrayRef< Attribute > operands, function_ref< std::optional< APInt >(const APInt &, const APInt &lhs)> calculate)
Fold an index operation only if the truncated 64-bit result matches the 32-bit result for operations ...
static std::optional< bool > foldCmpOfMaxOrMin(Operation *lhsOp, const APInt &cstA, const APInt &cstB, unsigned width, IndexCmpPredicate pred)
cmp(max/min(x, cstA), cstB) can be folded to a constant depending on the values of cstA and cstB,...
static OpFoldResult foldCastOp(Attribute input, Type type, function_ref< APInt(const APInt &, unsigned)> extFn, function_ref< APInt(const APInt &, unsigned)> extOrTruncFn)
static std::optional< APInt > calculateCeilDivS(const APInt &n, const APInt &m)
Compute ceildivs(n, m) as x = m > 0 ? -1 : 1 and then n*m > 0 ? (n+x)/m + 1 : -(-n/m).
static bool compareSameArgs(IndexCmpPredicate pred)
Return the result of cmp(pred, x, x)
static std::optional< APInt > calculateFloorDivS(const APInt &n, const APInt &m)
Compute floordivs(n, m) as x = m < 0 ? 1 : -1 and then n*m < 0 ? -1 - (x-n)/m : n/m.
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
A set of arbitrary-precision integers representing bounds on a given integer value.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.