MLIR: lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13 #include
14 #include
15
17
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVectorExtras.h"
26
27 using namespace mlir;
28
29
30
31
32
33
34
// getScalarOrSplatBoolAttr(attr): returns the bool carried by `attr` when it
// is a scalar bool attribute or a splat whose element type is i1; returns
// std::nullopt for a null attribute or any other attribute kind.
// NOTE(review): this listing is lossy — the signature (original line 35) and
// the dyn_cast<...> template arguments below were stripped by the scrape.
36 if (!attr)
37 return std::nullopt;
38
// Scalar case (presumably dyn_cast<BoolAttr>; template argument stripped).
39 if (auto boolAttr = llvm::dyn_cast(attr))
40 return boolAttr.getValue();
// Splat case: accepted only when the splat element type is i1.
41 if (auto splatAttr = llvm::dyn_cast(attr))
42 if (splatAttr.getElementType().isInteger(1))
43 return splatAttr.getSplatValue<bool>();
44 return std::nullopt;
45 }
46
47
48
51
// extractCompositeElement(composite, indices): walks `indices` into a constant
// composite attribute and returns the addressed element. Returns a null
// Attribute when `composite` is null or is neither of the two handled kinds;
// an empty index list returns `composite` itself.
// NOTE(review): the signature (original lines 49-50), line 65 and the
// dyn_cast<...> template arguments were lost in this listing. Line 65
// presumably recurses into element indices[0] of the array — confirm.
52 if (!composite)
53 return {};
54
55 if (indices.empty())
56 return composite;
57
// Vector-like composite: exactly one index selects a scalar element.
58 if (auto vector = llvm::dyn_cast(composite)) {
59 assert(indices.size() == 1 && "must have exactly one index for a vector");
60 return vector.getValues<Attribute>()[indices[0]];
61 }
62
// Array composite: the first index is consumed and the walk continues with
// the remaining indices (drop_front()).
63 if (auto array = llvm::dyn_cast(composite)) {
64 assert(!indices.empty() && "must have at least one index for an array");
66 indices.drop_front());
67 }
68
69 return {};
70 }
71
// isDivZeroOrOverflow(a, b): true when the signed division a / b must not be
// constant-folded — either b is zero, or a is the minimum signed value and b
// is all ones (-1), whose quotient is not representable in the bit width.
// (Signature, original line 72, is elided in this listing.)
73 bool div0 = b.isZero();
74 bool overflow = a.isMinSignedValue() && b.isAllOnes();
75
76 return div0 || overflow;
77 }
78
79
80
81
82
// Pulls the contents of "SPIRVCanonicalization.inc" into an anonymous
// namespace so they stay local to this translation unit.
// NOTE(review): presumably the TableGen-generated rewrite patterns (e.g. the
// ConvertLogicalNotOf* patterns registered for LogicalNotOp, which are not
// defined in this listing) — confirm against the .td source.
83 namespace {
84 #include "SPIRVCanonicalization.inc"
85 }
86
87
88
89
90
91 namespace {
92
93
94
// CombineChainedAccessChain: an AccessChain whose base pointer is produced by
// another AccessChain is replaced by a single AccessChain rooted at the
// parent's base pointer, with this op's indices appended to the parent's.
// NOTE(review): original lines 96-97, 100, 109 and 112 (base class, rewriter
// parameter, `indices` initialization, replaceOpWithNewOp call) are elided.
95 struct CombineChainedAccessChain final
98
99 LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
101 auto parentAccessChainOp =
102 accessChainOp.getBasePtr().getDefiningOpspirv::AccessChainOp();
103
// Only applies when the base pointer comes directly from an AccessChain.
104 if (!parentAccessChainOp) {
105 return failure();
106 }
107
108
// `indices` is declared on an elided line (presumably seeded with the
// parent's indices); the child's indices follow them.
110 llvm::append_range(indices, accessChainOp.getIndices());
111
113 accessChainOp, parentAccessChainOp.getBasePtr(), indices);
114
115 return success();
116 }
117 };
118 }
119
// Canonicalization hook for AccessChainOp. The results.add<...> template
// argument was stripped in this listing (presumably
// CombineChainedAccessChain); the parameter line 121 is elided.
120 void spirv::AccessChainOp::getCanonicalizationPatterns(
122 results.add(context);
123 }
124
125
126
127
128
129
130
// Rewrite pattern on spirv::IAddCarryOp (struct header, original lines
// 131-136, is elided). The result is a two-member struct; the constant-fold
// path below places the sum at index 0 and the carry at index 1.
// NOTE(review): `loc` and `constituentType` are defined on elided lines.
133
137 Value lhs = op.getOperand1();
138 Value rhs = op.getOperand2();
140
141
// NOTE(review): the guard for this early exit (original line 142) is elided.
// The sum belongs in member 0 (see the CompositeInsertOp indices below); if
// the guard matches rhs == 0, {rhs, lhs} puts 0 in the sum slot and x in the
// carry slot, which looks inverted — x + 0 should give {x, 0}, i.e.
// {lhs, rhs}. Verify against the elided guard and the SPIR-V spec.
143 Value constituents[2] = {rhs, lhs};
144 rewriter.replaceOpWithNewOpspirv::CompositeConstructOp(op, op.getType(),
145 constituents);
146 return success();
147 }
148
149
150
151
152
153
154
155
156
157
158
159
164 return failure();
165
// Constant fold of the low part: lhsAttr + rhsAttr with APInt wrap-around.
166 auto adds = constFoldBinaryOp(
167 {lhsAttr, rhsAttr},
168 [](const APInt &a, const APInt &b) { return a + b; });
169 if (!adds)
170 return failure();
171
// Unsigned overflow happened iff the wrapped sum is smaller than lhs, so the
// carry is 1 when adds < lhsAttr and 0 otherwise (`zero` is defined on the
// elided line 174).
172 auto carrys = constFoldBinaryOp(
173 ArrayRef{adds, lhsAttr}, [](const APInt &a, const APInt &b) {
175 return a.ult(b) ? (zero + 1) : zero;
176 });
177
178 if (!carrys)
179 return failure();
180
182 rewriter.createspirv::ConstantOp(loc, constituentType, adds);
183
184 Value carrysVal =
185 rewriter.createspirv::ConstantOp(loc, constituentType, carrys);
186
187
// Assemble the struct from an undef value: sum at index 0, carry at index 1
// (the final replaceOpWithNewOp line 193 is elided).
188 Value undef = rewriter.createspirv::UndefOp(loc, op.getType());
189
190 Value intermediate =
191 rewriter.createspirv::CompositeInsertOp(loc, addsVal, undef, 0);
192
194 intermediate, 1);
195 return success();
196 }
197 };
198
// Canonicalization hook for IAddCarryOp; the parameter list and the
// registration call (original lines 200-201) are elided in this listing.
199 void spirv::IAddCarryOp::getCanonicalizationPatterns(
202 }
203
204
205
206
207
208
209
// MulExtendedFold<MulOp, IsSigned>: shared rewrite for the extended multiply
// ops (e.g. MulExtendedFold<spirv::UMulExtendedOp, false>); IsSigned selects
// the signed or unsigned high half. The result is a two-member struct with
// the low bits at index 0 and the high bits at index 1.
// NOTE(review): the struct header and several guard/definition lines
// (211-216, 219, 222-223, 238-241) are elided in this listing.
210 template <typename MulOp, bool IsSigned>
213
217 Value lhs = op.getOperand1();
218 Value rhs = op.getOperand2();
220
221
// Early exit producing {zero, zero}; its guard is elided (presumably a
// multiply by zero).
224 Value constituents[2] = {zero, zero};
225 rewriter.replaceOpWithNewOpspirv::CompositeConstructOp(op, op.getType(),
226 constituents);
227 return success();
228 }
229
230
231
232
233
234
235
236
237
242 return failure();
243
// Low half: the wrapped N-bit product.
244 auto lowBits = constFoldBinaryOp(
245 {lhsAttr, rhsAttr},
246 [](const APInt &a, const APInt &b) { return a * b; });
247
248 if (!lowBits)
249 return failure();
250
// High half: mulhs / mulhu return the upper N bits of the 2N-bit product of
// the sign- or zero-extended operands.
251 auto highBits = constFoldBinaryOp(
252 {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
253 if (IsSigned) {
254 return llvm::APIntOps::mulhs(a, b);
255 } else {
256 return llvm::APIntOps::mulhu(a, b);
257 }
258 });
259
260 if (!highBits)
261 return failure();
262
263 Value lowBitsVal =
264 rewriter.createspirv::ConstantOp(loc, constituentType, lowBits);
265
266 Value highBitsVal =
267 rewriter.createspirv::ConstantOp(loc, constituentType, highBits);
268
269
// Assemble the struct from undef: low at index 0, high at index 1.
270 Value undef = rewriter.createspirv::UndefOp(loc, op.getType());
271
272 Value intermediate =
273 rewriter.createspirv::CompositeInsertOp(loc, lowBitsVal, undef, 0);
274
275 rewriter.replaceOpWithNewOpspirv::CompositeInsertOp(op, highBitsVal,
276 intermediate, 1);
277 return success();
278 }
279 };
280
// Canonicalization hook for SMulExtendedOp; the parameter list and the
// registration (original lines 283-284) are elided in this listing.
282 void spirv::SMulExtendedOp::getCanonicalizationPatterns(
285 }
286
// Rewrite pattern on spirv::UMulExtendedOp (struct header and signature,
// original lines 287-292, are elided). On its (elided) match condition it
// builds {lhs, zero}: low half = lhs, high half = 0 — presumably the x * 1
// case.
289
293 Value lhs = op.getOperand1();
294 Value rhs = op.getOperand2();
296
297
300 Value constituents[2] = {lhs, zero};
301 rewriter.replaceOpWithNewOpspirv::CompositeConstructOp(op, op.getType(),
302 constituents);
303 return success();
304 }
305
306 return failure();
307 }
308 };
309
// Canonicalization hook for UMulExtendedOp; the parameter list and the
// registration (original lines 312-313) are elided in this listing.
311 void spirv::UMulExtendedOp::getCanonicalizationPatterns(
314 }
315
316
317
318
319
320
321
322
323
324
325
326
327
328
// Rewrite pattern on spirv::UModOp (struct header elided):
// umod(umod(x, C1), C2) becomes umod(x, C2) when every element of C1 is a
// multiple of the corresponding element of C2, because then
// (x mod C1) mod C2 == x mod C2.
331
334 auto prevUMod = umodOp.getOperand(0).getDefiningOpspirv::UModOp();
335 if (!prevUMod)
336 return failure();
337
// Both divisors must be constants; the match lines (340-341) are elided.
338 TypedAttr prevValue;
339 TypedAttr currValue;
342 return failure();
343
344
345
// NOTE(review): APInt::urem asserts on a zero divisor. If currValue can be a
// constant 0 (or contain a 0 element) and the elided guard does not exclude
// it, the checks below would hit that assertion; consider bailing out when
// any element of currValue is zero.
346 bool isApplicable = false;
347 if (auto prevInt = dyn_cast(prevValue)) {
348 auto currInt = cast(currValue);
349 isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0;
350 } else if (auto prevVec = dyn_cast(prevValue)) {
351 auto currVec = cast(currValue);
352 isApplicable = llvm::all_of(llvm::zip_equal(prevVec.getValues(),
353 currVec.getValues()),
354 [](const auto &pair) {
355 auto &[prev, curr] = pair;
356 return prev.urem(curr) == 0;
357 });
358 }
359
360 if (!isApplicable)
361 return failure();
362
363
364
// Replace with a single umod of the original dividend by the outer divisor
// (the replaceOpWithNewOp line 365 is elided).
366 umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));
367
368 return success();
369 }
370 };
371
// Closing brace of a definition whose opening lines (original 372-374) are
// elided in this listing — presumably UModOp::getCanonicalizationPatterns.
375 }
376
377
378
379
380
// Folds BitcastOp:
//  - returns the input unchanged under the elided check on original line 383
//    (presumably: result type equals input type);
//  - for bitcast(bitcast(y)), returns y under the elided check on line 389
//    (presumably a type match). Otherwise it rewires this op's operand to y,
//    skipping the intermediate cast, and returns its own result, which MLIR
//    treats as an in-place fold.
381 OpFoldResult spirv::BitcastOp::fold(FoldAdaptor ) {
382 Value curInput = getOperand();
384 return curInput;
385
386
387 if (auto prevCast = curInput.getDefiningOpspirv::BitcastOp()) {
388 Value prevInput = prevCast.getOperand();
390 return prevInput;
391
392 getOperandMutable().assign(prevInput);
393 return getResult();
394 }
395
396
397 return {};
398 }
399
400
401
402
403
// Folds CompositeExtractOp by looking through producers of the composite:
//  1. Walks a chain of CompositeInsertOps. An insert at exactly the same
//     indices yields the inserted object; otherwise the walk continues on the
//     insert's source composite.
//  2. If the composite comes from a CompositeConstructOp with one constituent
//     per element, returns the constituent selected by the first index,
//     provided it is in range (the condition's first line, 417, is elided).
//  3. Otherwise constant-folds with the index list; the return on line 429 is
//     elided (presumably extractCompositeElement on the constant composite).
//
// NOTE(review): step 1 skips an insert whenever the index lists differ,
// including when one list is a prefix of the other (insert at [0, 1] then
// extract [0], or insert at [0] then extract [0, 1]). In those cases the
// insert overlaps the extracted element and must not be skipped; consider
// stopping the walk when either index list is a prefix of the other.
404 OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
405 Value compositeOp = getComposite();
406
407 while (auto insertOp =
408 compositeOp.getDefiningOpspirv::CompositeInsertOp()) {
409 if (getIndices() == insertOp.getIndices())
410 return insertOp.getObject();
411 compositeOp = insertOp.getComposite();
412 }
413
414 if (auto constructOp =
415 compositeOp.getDefiningOpspirv::CompositeConstructOp()) {
416 auto type = llvm::castspirv::CompositeType(constructOp.getType());
418 constructOp.getConstituents().size() == type.getNumElements()) {
419 auto i = llvm::cast(*getIndices().begin());
420 if (i.getValue().getSExtValue() <
421 static_cast<int64_t>(constructOp.getConstituents().size()))
422 return constructOp.getConstituents()[i.getValue().getSExtValue()];
423 }
424 }
425
// Convert the index attributes to unsigned values for constant extraction.
426 auto indexVector = llvm::map_to_vector(getIndices(), [](Attribute attr) {
427 return static_cast<unsigned>(llvm::cast(attr).getInt());
428 });
430 }
431
432
433
434
435
// A constant folds to its own value attribute.
436 OpFoldResult spirv::ConstantOp::fold(FoldAdaptor ) {
437 return getValue();
438 }
439
440
441
442
443
// Folds IAddOp: returns operand 1 under the elided check on original line 446
// (presumably x + 0 = x); otherwise folds constant operands element-wise with
// wrapping APInt addition.
444 OpFoldResult spirv::IAddOp::fold(FoldAdaptor adaptor) {
445
447 return getOperand1();
448
449
450
451
452
453
454 return constFoldBinaryOp(
455 adaptor.getOperands(),
456 [](APInt a, const APInt &b) { return std::move(a) + b; });
457 }
458
459
460
461
462
// Folds IMulOp: returns operand 2 under the elided check on original line 465
// (presumably x * 0 = 0), and operand 1 under the elided check on line 468
// (presumably x * 1 = x). Otherwise folds constant operands element-wise with
// wrapping APInt multiplication.
463 OpFoldResult spirv::IMulOp::fold(FoldAdaptor adaptor) {
464
466 return getOperand2();
467
469 return getOperand1();
470
471
472
473
474
475
476 return constFoldBinaryOp(
477 adaptor.getOperands(),
478 [](const APInt &a, const APInt &b) { return a * b; });
479 }
480
481
482
483
484
// Folds ISubOp: x - x has a dedicated result whose return line (original 488)
// is elided here (presumably a zero constant). Otherwise folds constant
// operands element-wise with wrapping APInt subtraction.
// NOTE(review): because line 488 is missing, the `if` below appears to guard
// the constFoldBinaryOp return; that is a scrape artifact, not the source.
485 OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
486
487 if (getOperand1() == getOperand2())
489
490
491
492
493
494
495 return constFoldBinaryOp(
496 adaptor.getOperands(),
497 [](APInt a, const APInt &b) { return std::move(a) - b; });
498 }
499
500
501
502
503
// Folds SDivOp: returns operand 1 under the elided check on original line 506
// (presumably x / 1 = x). Constant folding uses signed division and gives up
// (null Attribute) if any element pair divides by zero or overflows
// (INT_MIN / -1).
504 OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
505
507 return getOperand1();
508
509
510
511
512
513
514
515
516
// Sticky flag: once set, remaining elements are skipped and the result is
// discarded.
517 bool div0OrOverflow = false;
518 auto res = constFoldBinaryOp(
519 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
520 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
521 div0OrOverflow = true;
522 return a;
523 }
524 return a.sdiv(b);
525 });
526 return div0OrOverflow ? Attribute() : res;
527 }
528
529
530
531
532
// Folds SModOp (signed modulo whose result takes the sign of operand 2).
// Non-constant special cases on original lines 535-536 are elided. Constant
// folding gives up on division by zero or INT_MIN / -1.
533 OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
534
537
538
539
540
541
542
543
544
545
546
547
548 bool div0OrOverflow = false;
549 auto res = constFoldBinaryOp(
550 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
551 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
552 div0OrOverflow = true;
553 return a;
554 }
// Magnitude of the remainder, computed unsigned: abs(INT_MIN) wraps to
// itself, but read as unsigned it is the correct magnitude 2^(N-1).
555 APInt c = a.abs().urem(b.abs());
556 if (c.isZero())
557 return c;
// Adjust so the result's sign matches the divisor b.
558 if (b.isNegative()) {
559 APInt zero = APInt::getZero(c.getBitWidth());
560 return a.isNegative() ? (zero - c) : (b + c);
561 }
562 return a.isNegative() ? (b - c) : c;
563 });
564 return div0OrOverflow ? Attribute() : res;
565 }
566
567
568
569
570
// Folds SRemOp (signed remainder whose result takes the sign of operand 1, as
// APInt::srem does). Non-constant special cases on original lines 573-574 are
// elided. Constant folding gives up on division by zero or INT_MIN / -1.
571 OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
572
575
576
577
578
579
580
581
582
583
584
585
586 bool div0OrOverflow = false;
587 auto res = constFoldBinaryOp(
588 adaptor.getOperands(), [&](APInt a, const APInt &b) {
589 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
590 div0OrOverflow = true;
591 return a;
592 }
593 return a.srem(b);
594 });
595 return div0OrOverflow ? Attribute() : res;
596 }
597
598
599
600
601
// Folds UDivOp: returns operand 1 under the elided check on original line 604
// (presumably x / 1 = x). Constant folding uses unsigned division and tracks
// any zero divisor in `div0`. The final return (line 622) is elided —
// presumably it discards the result when div0 is set, as the signed variants
// do.
602 OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
603
605 return getOperand1();
606
607
608
609
610
611
612
613 bool div0 = false;
614 auto res = constFoldBinaryOp(
615 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
616 if (div0 || b.isZero()) {
617 div0 = true;
618 return a;
619 }
620 return a.udiv(b);
621 });
623 }
624
625
626
627
628
// Folds UModOp: non-constant special cases (original lines 631-632) are
// elided. Constant folding uses unsigned remainder and tracks any zero divisor
// in `div0`. The final return (line 649) is elided — presumably it discards
// the result when div0 is set.
629 OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
630
633
634
635
636
637
638
639
640 bool div0 = false;
641 auto res = constFoldBinaryOp(
642 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
643 if (div0 || b.isZero()) {
644 div0 = true;
645 return a;
646 }
647 return a.urem(b);
648 });
650 }
651
652
653
654
655
// Folds SNegateOp: -(-x) returns x. A constant operand folds to 0 - a, which
// wraps (so negating INT_MIN gives INT_MIN).
656 OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
657
658 auto op = getOperand();
659 if (auto negateOp = op.getDefiningOpspirv::SNegateOp())
660 return negateOp->getOperand(0);
661
662
663
664
665 return constFoldUnaryOp(
666 adaptor.getOperands(), [](const APInt &a) {
667 APInt zero = APInt::getZero(a.getBitWidth());
668 return zero - a;
669 });
670 }
671
672
673
674
675
// Folds NotOp (bitwise complement): ~~x returns x. A constant operand folds
// by flipping all bits.
676 OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
677
678 auto op = getOperand();
679 if (auto notOp = op.getDefiningOpspirv::NotOp())
680 return notOp->getOperand(0);
681
682
683
684
685 return constFoldUnaryOp(adaptor.getOperands(), [&](APInt a) {
686 a.flipAllBits();
687 return a;
688 });
689 }
690
691
692
693
694
// Folds LogicalAndOp when operand 2 is a known bool constant. The initializer
// on original line 697 is elided — presumably getScalarOrSplatBoolAttr of
// adaptor.getOperand2(). The rules are:
//   x && true  -> x
//   x && false -> the false constant (operand 2's attribute)
// The fallback return (line 707) is elided.
695 OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
696 if (std::optional rhs =
698
699 if (*rhs)
700 return getOperand1();
701
702
703 if (!*rhs)
704 return adaptor.getOperand2();
705 }
706
708 }
709
710
711
712
713
// Folds LogicalEqualOp: x == x is true. For a scalar result this is
// `trueAttr`; for a vector result it is presumably a splat, but the creation
// lines 718 and 722 are elided. Otherwise constant operands fold element-wise
// to i1.
715 spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
716
717 if (getOperand1() == getOperand2()) {
719 if (isa(getType()))
720 return trueAttr;
721 if (auto vecTy = dyn_cast(getType()))
723 }
724
725 return constFoldBinaryOp(
726 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
727 return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
728 });
729 }
730
731
732
733
734
// Folds LogicalNotEqualOp:
//   x != false -> x (operand 2 is a known bool constant; the initializer on
//                 original line 737 is elided)
//   x != x     -> false (for vector results the creation line 749 is elided)
// Otherwise constant operands fold element-wise to i1.
735 OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
736 if (std::optional rhs =
738
739 if (!rhs.value())
740 return getOperand1();
741 }
742
743
744 if (getOperand1() == getOperand2()) {
746 if (isa(getType()))
747 return falseAttr;
748 if (auto vecTy = dyn_cast(getType()))
750 }
751
752 return constFoldBinaryOp(
753 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
754 return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
755 });
756 }
757
758
759
760
761
// Folds LogicalNotOp: !!x returns x. A constant i1 operand folds as 1 -> 0
// and 0 -> 1.
762 OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
763
764 auto op = getOperand();
765 if (auto notOp = op.getDefiningOpspirv::LogicalNotOp())
766 return notOp->getOperand(0);
767
768
769
770
771 return constFoldUnaryOp(adaptor.getOperands(),
772 [](const APInt &a) {
773 APInt zero = APInt::getZero(1);
774 return a == 1 ? zero : (zero + 1);
775 });
776 }
777
// Registers the ConvertLogicalNotOf* patterns for LogicalNotOp. They are not
// defined in this listing (presumably generated into
// SPIRVCanonicalization.inc). The parameter line 779 is elided.
778 void spirv::LogicalNotOp::getCanonicalizationPatterns(
780 results
781 .add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
782 ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
783 context);
784 }
785
786
787
788
789
// Folds LogicalOrOp when operand 2 is a known bool constant. The binding of
// `rhs` on original line 791 is elided. The rules are:
//   x || true  -> the true constant (operand 2's attribute)
//   x || false -> x
// The fallback return (line 803) is elided.
790 OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
792 if (*rhs) {
793
794 return adaptor.getOperand2();
795 }
796
797 if (!*rhs) {
798
799 return getOperand1();
800 }
801 }
802
804 }
805
806
807
808
809
// Folds SelectOp:
//  1. Identical true/false values fold to that value.
//  2. A constant scalar condition (`boolAttr`, bound on an elided line)
//     selects one arm.
//  3. With all three operands constant element attributes, selects
//     element-wise. The final attribute construction (line 846) is elided.
// NOTE(review): `operands`, `boolAttr` and the early-return bodies (lines 817,
// 821, 826, 835) are on elided lines.
810 OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
811
812 Value trueVals = getTrueValue();
813 Value falseVals = getFalseValue();
814 if (trueVals == falseVals)
815 return trueVals;
816
818
819
820
822 return *boolAttr ? trueVals : falseVals;
823
824
825 if (!operands[0] || !operands[1] || !operands[2])
827
828
829
830
831 auto condAttrs = dyn_cast(operands[0]);
832 auto trueAttrs = dyn_cast(operands[1]);
833 auto falseAttrs = dyn_cast(operands[2]);
834 if (!condAttrs || !trueAttrs || !falseAttrs)
836
// Start from the true values and overwrite each element whose condition is
// false with the corresponding false value.
837 auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<Attribute>());
838 auto iters = llvm::zip_equal(elementResults, condAttrs.getValues<BoolAttr>(),
839 falseAttrs.getValues<Attribute>());
840 for (auto [result, cond, falseRes] : iters) {
841 if (!cond.getValue())
842 result = falseRes;
843 }
844
845 auto resultType = trueAttrs.getType();
847 }
848
849
850
851
852
// Folds IEqualOp: x == x is true (vector-result creation line 860 elided).
// Otherwise constant operands fold element-wise to i1 values of the op's
// result type.
853 OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
854
855 if (getOperand1() == getOperand2()) {
857 if (isa(getType()))
858 return trueAttr;
859 if (auto vecTy = dyn_cast(getType()))
861 }
862
863 return constFoldBinaryOp(
864 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
865 return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
866 });
867 }
868
869
870
871
872
// Folds INotEqualOp: x != x is false (vector-result creation line 880
// elided). Otherwise constant operands fold element-wise to i1 values of the
// op's result type.
873 OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
874
875 if (getOperand1() == getOperand2()) {
877 if (isa(getType()))
878 return falseAttr;
879 if (auto vecTy = dyn_cast(getType()))
881 }
882
883 return constFoldBinaryOp(
884 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
885 return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
886 });
887 }
888
889
890
891
892
// Folds SGreaterThanOp: x > x is false (vector-result creation line 901
// elided). Otherwise constant operands fold element-wise using a signed
// comparison. (The return-type line 893 is elided.)
894 spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {
895
896 if (getOperand1() == getOperand2()) {
898 if (isa(getType()))
899 return falseAttr;
900 if (auto vecTy = dyn_cast(getType()))
902 }
903
904 return constFoldBinaryOp(
905 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
906 return a.sgt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
907 });
908 }
909
910
911
912
913
// Folds SGreaterThanEqualOp: x >= x is true (vector-result creation line 922
// elided). Otherwise constant operands fold element-wise using a signed
// comparison.
914 OpFoldResult spirv::SGreaterThanEqualOp::fold(
915 spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {
916
917 if (getOperand1() == getOperand2()) {
919 if (isa(getType()))
920 return trueAttr;
921 if (auto vecTy = dyn_cast(getType()))
923 }
924
925 return constFoldBinaryOp(
926 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
927 return a.sge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
928 });
929 }
930
931
932
933
934
// Folds UGreaterThanOp: x > x is false (vector-result creation line 943
// elided). Otherwise constant operands fold element-wise using an unsigned
// comparison. (The return-type line 935 is elided.)
936 spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {
937
938 if (getOperand1() == getOperand2()) {
940 if (isa(getType()))
941 return falseAttr;
942 if (auto vecTy = dyn_cast(getType()))
944 }
945
946 return constFoldBinaryOp(
947 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
948 return a.ugt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
949 });
950 }
951
952
953
954
955
// Folds UGreaterThanEqualOp: x >= x is true (vector-result creation line 964
// elided). Otherwise constant operands fold element-wise using an unsigned
// comparison.
956 OpFoldResult spirv::UGreaterThanEqualOp::fold(
957 spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {
958
959 if (getOperand1() == getOperand2()) {
961 if (isa(getType()))
962 return trueAttr;
963 if (auto vecTy = dyn_cast(getType()))
965 }
966
967 return constFoldBinaryOp(
968 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
969 return a.uge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
970 });
971 }
972
973
974
975
976
// Folds SLessThanOp: x < x is false (vector-result creation line 984 elided).
// Otherwise constant operands fold element-wise using a signed comparison.
977 OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {
978
979 if (getOperand1() == getOperand2()) {
981 if (isa(getType()))
982 return falseAttr;
983 if (auto vecTy = dyn_cast(getType()))
985 }
986
987 return constFoldBinaryOp(
988 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
989 return a.slt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
990 });
991 }
992
993
994
995
996
// Folds SLessThanEqualOp: x <= x is true (vector-result creation line 1005
// elided). Otherwise constant operands fold element-wise using a signed
// comparison. (The return-type line 997 is elided.)
998 spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {
999
1000 if (getOperand1() == getOperand2()) {
1002 if (isa(getType()))
1003 return trueAttr;
1004 if (auto vecTy = dyn_cast(getType()))
1006 }
1007
1008 return constFoldBinaryOp(
1009 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1010 return a.sle(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1011 });
1012 }
1013
1014
1015
1016
1017
// Folds ULessThanOp: x < x is false (vector-result creation line 1025
// elided). Otherwise constant operands fold element-wise using an unsigned
// comparison.
1018 OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {
1019
1020 if (getOperand1() == getOperand2()) {
1022 if (isa(getType()))
1023 return falseAttr;
1024 if (auto vecTy = dyn_cast(getType()))
1026 }
1027
1028 return constFoldBinaryOp(
1029 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1030 return a.ult(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1031 });
1032 }
1033
1034
1035
1036
1037
// Folds ULessThanEqualOp: x <= x is true (vector-result creation line 1046
// elided). Otherwise constant operands fold element-wise using an unsigned
// comparison. (The return-type line 1038 is elided.)
1039 spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {
1040
1041 if (getOperand1() == getOperand2()) {
1043 if (isa(getType()))
1044 return trueAttr;
1045 if (auto vecTy = dyn_cast(getType()))
1047 }
1048
1049 return constFoldBinaryOp(
1050 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1051 return a.ule(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1052 });
1053 }
1054
1055
1056
1057
1058
// Folds ShiftLeftLogicalOp: returns operand 1 under the elided check on
// original line 1062 (presumably a shift by zero). Constant folding gives up
// (null Attribute) if any shift amount is >= the bit width, since SPIR-V
// leaves such results undefined.
1059 OpFoldResult spirv::ShiftLeftLogicalOp::fold(
1060 spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
1061
1063 return getOperand1();
1064 }
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074 bool shiftToLarge = false;
1075 auto res = constFoldBinaryOp(
1076 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1077 if (shiftToLarge || b.uge(a.getBitWidth())) {
1078 shiftToLarge = true;
1079 return a;
1080 }
1081 return a << b;
1082 });
1083 return shiftToLarge ? Attribute() : res;
1084 }
1085
1086
1087
1088
1089
// Folds ShiftRightArithmeticOp: returns operand 1 under the elided check on
// original line 1093 (presumably a shift by zero). Constant folding uses a
// sign-filling shift (ashr) and gives up if any shift amount is >= the bit
// width.
1090 OpFoldResult spirv::ShiftRightArithmeticOp::fold(
1091 spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
1092
1094 return getOperand1();
1095 }
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105 bool shiftToLarge = false;
1106 auto res = constFoldBinaryOp(
1107 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1108 if (shiftToLarge || b.uge(a.getBitWidth())) {
1109 shiftToLarge = true;
1110 return a;
1111 }
1112 return a.ashr(b);
1113 });
1114 return shiftToLarge ? Attribute() : res;
1115 }
1116
1117
1118
1119
1120
// Folds ShiftRightLogicalOp: returns operand 1 under the elided check on
// original line 1124 (presumably a shift by zero). Constant folding uses a
// zero-filling shift (lshr) and gives up if any shift amount is >= the bit
// width.
1121 OpFoldResult spirv::ShiftRightLogicalOp::fold(
1122 spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
1123
1125 return getOperand1();
1126 }
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136 bool shiftToLarge = false;
1137 auto res = constFoldBinaryOp(
1138 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1139 if (shiftToLarge || b.uge(a.getBitWidth())) {
1140 shiftToLarge = true;
1141 return a;
1142 }
1143 return a.lshr(b);
1144 });
1145 return shiftToLarge ? Attribute() : res;
1146 }
1147
1148
1149
1150
1151
// Folds BitwiseAndOp:
//   x & x -> x
//   with a constant mask on operand 2 (the match on original line 1160 is
//   elided; presumably it binds rhsMask):
//     mask == 0        -> operand 2 (the zero)
//     mask == all ones -> x
//     x = UConvert(y) and the mask's low `valueBits` bits are all ones -> x
//       (valueBits is computed on elided line 1172, presumably the bit width
//       of y)
// Otherwise constant operands fold element-wise.
1153 spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
1154
1155 if (getOperand1() == getOperand2()) {
1156 return getOperand1();
1157 }
1158
1159 APInt rhsMask;
1161
1162 if (rhsMask.isZero())
1163 return getOperand2();
1164
1165
1166 if (rhsMask.isAllOnes())
1167 return getOperand1();
1168
1169
// A zero-extended value already has zeros above its source width, so a mask
// that is all ones over that width leaves it unchanged.
1170 if (auto zext = getOperand1().getDefiningOpspirv::UConvertOp()) {
1171 int valueBits =
1173 if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
1174 return getOperand1();
1175 }
1176 }
1177
1178
1179
1180
1181
1182
1183 return constFoldBinaryOp(
1184 adaptor.getOperands(),
1185 [](const APInt &a, const APInt &b) { return a & b; });
1186 }
1187
1188
1189
1190
1191
// Folds BitwiseOrOp:
//   x | x -> x
//   with a constant mask on operand 2 (match on original line 1199 elided):
//     mask == 0        -> x
//     mask == all ones -> operand 2 (the all-ones value)
// Otherwise constant operands fold element-wise.
1192 OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
1193
1194 if (getOperand1() == getOperand2()) {
1195 return getOperand1();
1196 }
1197
1198 APInt rhsMask;
1200
1201 if (rhsMask.isZero())
1202 return getOperand1();
1203
1204
1205 if (rhsMask.isAllOnes())
1206 return getOperand2();
1207 }
1208
1209
1210
1211
1212
1213
1214 return constFoldBinaryOp(
1215 adaptor.getOperands(),
1216 [](const APInt &a, const APInt &b) { return a | b; });
1217 }
1218
1219
1220
1221
1222
// Folds BitwiseXorOp:
//  - returns operand 1 under the elided check on original line 1226
//    (presumably x ^ 0 = x);
//  - x ^ x has a dedicated result on elided line 1232 (presumably zero).
// Otherwise constant operands fold element-wise.
1224 spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
1225
1227 return getOperand1();
1228 }
1229
1230
1231 if (getOperand1() == getOperand2())
1233
1234
1235
1236
1237
1238
1239 return constFoldBinaryOp(
1240 adaptor.getOperands(),
1241 [](const APInt &a, const APInt &b) { return a ^ b; });
1242 }
1243
1244
1245
1246
1247
1248 namespace {
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
// ConvertSelectionOpToSelect: rewrites a SelectionOp whose region has exactly
// four blocks into a spirv.Select feeding a single store. The four blocks are:
//  - a header holding only a BranchConditional;
//  - a true block and a false block, each holding exactly one store followed
//    by a branch to the merge block;
//  - the merge block.
// Both stores must target the same pointer with identical attributes and
// properties.
// NOTE(review): original lines 1275, 1278, 1280, 1300-1301 and 1320 (using
// declaration, rewriter parameter, `body` binding, true/false block lookup,
// presumably the erase of the selection op) are elided in this listing.
1274 struct ConvertSelectionOpToSelect final : OpRewritePatternspirv::SelectionOp {
1276
1277 LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
1279 Operation *op = selectionOp.getOperation();
1281
1282 if (body.empty()) {
1283 return failure();
1284 }
1285
1286
1287
// Exactly header + true + false + merge.
1288 if (llvm::range_size(body) != 4) {
1289 return failure();
1290 }
1291
1292 Block *headerBlock = selectionOp.getHeaderBlock();
1293 if (!onlyContainsBranchConditionalOp(headerBlock)) {
1294 return failure();
1295 }
1296
1297 auto brConditionalOp =
1298 castspirv::BranchConditionalOp(headerBlock->front());
1299
1302 Block *mergeBlock = selectionOp.getMergeBlock();
1303
1304 if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
1305 return failure();
1306
1307 Value trueValue = getSrcValue(trueBlock);
1308 Value falseValue = getSrcValue(falseBlock);
1309 Value ptrValue = getDstPtr(trueBlock);
// Both stores were checked to have the same attributes, so the true-branch
// store's attributes are reused for the replacement store.
1310 auto storeOpAttributes =
1311 castspirv::StoreOp(trueBlock->front())->getAttrs();
1312
1313 auto selectOp = rewriter.createspirv::SelectOp(
1314 selectionOp.getLoc(), trueValue.getType(),
1315 brConditionalOp.getCondition(), trueValue, falseValue);
1316 rewriter.createspirv::StoreOp(selectOp.getLoc(), ptrValue,
1317 selectOp.getResult(), storeOpAttributes);
1318
1319
1321 return success();
1322 }
1323
1324 private:
1325
1326
1327
1328
1329
1330
// Checks the true/false/merge block shape and store compatibility required by
// the rewrite; returns failure() if the rewrite does not apply.
1331 LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
1332 Block *mergeBlock) const;
1333
// True when `block` consists of a single BranchConditionalOp.
1334 bool onlyContainsBranchConditionalOp(Block *block) const {
1335 return llvm::hasSingleElement(*block) &&
1336 isaspirv::BranchConditionalOp(block->front());
1337 }
1338
// True when both stores have equal discardable attributes and properties.
1339 bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
1340 return lhs->getDiscardableAttrDictionary() ==
1341 rhs->getDiscardableAttrDictionary() &&
1342 lhs.getProperties() == rhs.getProperties();
1343 }
1344
1345
// Value stored by the StoreOp that must be the first op of `block`.
1346 Value getSrcValue(Block *block) const {
1347 auto storeOp = castspirv::StoreOp(block->front());
1348 return storeOp.getValue();
1349 }
1350
1351
// Pointer written by the StoreOp that must be the first op of `block`.
1352 Value getDstPtr(Block *block) const {
1353 auto storeOp = castspirv::StoreOp(block->front());
1354 return storeOp.getPtr();
1355 }
1356 };
1357
1358 LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
1359 Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
1360
// Each branch block must be exactly: store, branch.
1361 if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) {
1362 return failure();
1363 }
1364
1365 auto trueBrStoreOp = dyn_castspirv::StoreOp(trueBlock->front());
1366 auto trueBrBranchOp =
1367 dyn_castspirv::BranchOp(*std::next(trueBlock->begin()));
1368 auto falseBrStoreOp = dyn_castspirv::StoreOp(falseBlock->front());
1369 auto falseBrBranchOp =
1370 dyn_castspirv::BranchOp(*std::next(falseBlock->begin()));
1371
1372 if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
1373 !falseBrBranchOp) {
1374 return failure();
1375 }
1376
1377
1378
1379
1380
1381
// The stored value must be a scalar or vector — presumably because of
// OpSelect's result-type restriction (before SPIR-V 1.4: pointer, scalar or
// vector).
1382 bool isScalarOrVector =
1383 llvm::castspirv::SPIRVType(trueBrStoreOp.getValue().getType())
1384 .isScalarOrVector();
1385
1386
1387
// Both stores must write the same pointer with identical attributes.
1388 if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
1389 !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
1390 return failure();
1391 }
1392
// Both branch blocks must fall through to the merge block.
1393 if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
1394 (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
1395 return failure();
1396 }
1397
1398 return success();
1399 }
1400 }
1401
// Canonicalization hook for SelectionOp. The results.add<...> template
// argument was stripped in this listing (presumably
// ConvertSelectionOpToSelect); the second parameter line 1403 is elided.
1402 void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1404 results.add(context);
1405 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static uint64_t zext(uint32_t arg)
static MLIRContext * getContext(OpFoldResult val)
static Attribute extractCompositeElement(Attribute composite, ArrayRef< unsigned > indices)
MulExtendedFold< spirv::UMulExtendedOp, false > UMulExtendedOpFold
static std::optional< bool > getScalarOrSplatBoolAttr(Attribute attr)
Returns the boolean value under the hood if the given boolAttr is a scalar or splat vector bool const...
static bool isDivZeroOrOverflow(const APInt &a, const APInt &b)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Block * getSuccessor(unsigned i)
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
TypedAttr getZeroAttr(Type type)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult matchAndRewrite(spirv::IAddCarryOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(MulOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UModOp umodOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UMulExtendedOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.