MLIR: lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp Source File
//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----*- C++ -*--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to emulate
// narrow types that are not supported by the target hardware, e.g. i4 and i2,
// using wider types, e.g. i8.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include <optional>

using namespace mlir;

#define DEBUG_TYPE "vector-narrow-type-emulation"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using VectorValue = TypedValue<VectorType>;
using MemRefValue = TypedValue<MemRefType>;

/// Returns a compressed mask for the emulated vector. For example, when
/// emulating a sub-byte vector with a byte-sized container type, every
/// `numSrcElemsPerDest` bits of the source mask are OR-combined into a single
/// bit of the compressed (container-level) mask. `numSrcElems` is the number
/// of source (emulated) mask elements, and `numFrontPadElems` is the number
/// of padding elements in front of the source elements within the first
/// container element.
static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                  Location loc, Value mask,
                                                  int numSrcElems,
                                                  int numSrcElemsPerDest,
                                                  int numFrontPadElems = 0) {

  assert(numFrontPadElems < numSrcElemsPerDest &&
         "numFrontPadElems must be less than numSrcElemsPerDest");

  auto numDestElems =
      (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
      numSrcElemsPerDest;

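  // For example (hypothetical values): with numFrontPadElems = 1,
  // numSrcElems = 6 and numSrcElemsPerDest = 4, the compressed mask covers
  // ceil((1 + 6) / 4) = 2 destination (container) elements.
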
  Operation *maskOp = mask.getDefiningOp();
  SmallVector<vector::ExtractOp, 2> extractOps;
  // TODO: add support to `vector.splat`.
  // Finding the mask creation operation.
  while (maskOp &&
         !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
             maskOp)) {
    if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
      maskOp = extractOp.getVector().getDefiningOp();
      extractOps.push_back(extractOp);
    } else {
      // Unrecognized mask producer: stop the walk and let the check below
      // reject the pattern instead of looping forever.
      break;
    }
  }

  if (!maskOp ||
      !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
          maskOp))
    return failure();

  // Compute the "compressed" mask. All the emulation logic (i.e. computing
  // the new mask index) only happens on the last dimension of the vectors.
  SmallVector<int64_t> maskShape(
      cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
  maskShape.back() = numDestElems;
  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
  std::optional<Operation *> newMask =
      TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
          .Case<vector::CreateMaskOp>(
              [&](auto createMaskOp) -> std::optional<Operation *> {
                OperandRange maskOperands = createMaskOp.getOperands();
                // Compress the trailing mask operand: the new count of set
                // elements is the front padding plus the original count,
                // divided by `numSrcElemsPerDest` and rounded up.
                AffineExpr s0;
                bindSymbols(rewriter.getContext(), s0);
                s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
                OpFoldResult origIndex = getAsOpFoldResult(maskOperands.back());
                OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
                    rewriter, loc, s0, origIndex);
                SmallVector<Value> newMaskOperands(maskOperands.drop_back());
                newMaskOperands.push_back(
                    getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
                return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                             newMaskOperands);
              })
          .Case<vector::ConstantMaskOp>(
              [&](auto constantMaskOp) -> std::optional<Operation *> {
                // Take the shape of the mask and compress its trailing
                // dimension.
                SmallVector<int64_t> maskDimSizes(
                    constantMaskOp.getMaskDimSizes());
                int64_t &maskIndex = maskDimSizes.back();
                maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
                                             numSrcElemsPerDest);
                return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
                                                               maskDimSizes);
              })
          .Case<arith::ConstantOp>([&](auto constantOp)
                                       -> std::optional<Operation *> {
            // TODO: Support multiple dimensions.
            if (maskShape.size() != 1)
              return std::nullopt;
            // Rearrange the original mask values to cover the whole potential
            // loading region. The original mask is padded with
            // `numFrontPadElems` false values in front and with additional
            // false values at the back, so that its length becomes a multiple
            // of `numSrcElemsPerDest`. Every group of `numSrcElemsPerDest`
            // padded values is then OR-combined into one compressed value.
            auto originalMask =
                cast<DenseIntElementsAttr>(constantOp.getValue());
            SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
            paddedMaskValues.append(originalMask.template value_begin<bool>(),
                                    originalMask.template value_end<bool>());
            paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);

            // Compressing by combining every `numSrcElemsPerDest` elements:
            SmallVector<bool> compressedMaskValues;
            for (size_t i = 0; i < paddedMaskValues.size();
                 i += numSrcElemsPerDest) {
              bool combinedValue = false;
              for (int j = 0; j < numSrcElemsPerDest; ++j) {
                combinedValue |= paddedMaskValues[i + j];
              }
              compressedMaskValues.push_back(combinedValue);
            }
            return rewriter.create<arith::ConstantOp>(
                loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
          });

  if (!newMask)
    return failure();

  while (!extractOps.empty()) {
    newMask = rewriter.create<vector::ExtractOp>(
        loc, (*newMask)->getResults()[0], extractOps.back().getMixedPosition());
    extractOps.pop_back();
  }

  return *newMask;
}
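
// Example (illustrative, not from the source): compressing an 8-element
// constant mask with numSrcElemsPerDest = 2 and numFrontPadElems = 0
// OR-combines each adjacent pair of mask bits:
//   [0, 1, 0, 1, 0, 0, 1, 1]  ->  [1, 1, 0, 1]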

/// Extracts 1-D subvector from a 1-D vector.
///
/// Given the input rank-1 source vector, extracts `numElemsToExtract` elements
/// from `src`, starting at `offset`. The result is also a rank-1 vector. As
/// `offset` is a known, static value, this hook emits
/// `vector.extract_strided_slice`.
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
                                        Value src, int64_t offset,
                                        int64_t numElemsToExtract) {
  auto vectorType = cast<VectorType>(src.getType());
  assert(vectorType.getRank() == 1 && "expected source to be rank-1 vector");
  assert(offset + numElemsToExtract <= vectorType.getNumElements() &&
         "subvector out of bounds");

  // When extracting all available elements, just use the source vector as the
  // result.
  if (vectorType.getNumElements() == numElemsToExtract)
    return src;

  auto offsets = rewriter.getI64ArrayAttr({offset});
  auto sizes = rewriter.getI64ArrayAttr({numElemsToExtract});
  auto strides = rewriter.getI64ArrayAttr({1});

  auto resultVectorType =
      VectorType::get({numElemsToExtract}, vectorType.getElementType());
  return rewriter
      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, src,
                                             offsets, sizes, strides)
      ->getResult(0);
}
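
// For example (hypothetical types), extracting 3 elements at offset 2 from a
// vector<8xi8> source produces:
//   %out = vector.extract_strided_slice %src
//            {offsets = [2], sizes = [3], strides = [1]}
//            : vector<8xi8> to vector<3xi8>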

/// Inserts 1-D subvector into a 1-D vector.
///
/// Inserts the input rank-1 source vector into the rank-1 destination vector,
/// starting at `offset`. As `offset` is a known, static value, this hook emits
/// `vector.insert_strided_slice`.
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                       Value src, Value dest, int64_t offset) {
  [[maybe_unused]] auto srcVecTy = cast<VectorType>(src.getType());
  [[maybe_unused]] auto destVecTy = cast<VectorType>(dest.getType());
  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
         "expected source and dest to be rank-1 vector types");

  // If overwriting the destination vector in full, just use the source vector.
  if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0)
    return src;

  auto offsets = rewriter.getI64ArrayAttr({offset});
  auto strides = rewriter.getI64ArrayAttr({1});
  return rewriter.create<vector::InsertStridedSliceOp>(loc, destVecTy, src,
                                                       dest, offsets, strides);
}
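
// For example (hypothetical types), inserting a vector<3xi8> into a
// vector<8xi8> destination at offset 2 produces:
//   %out = vector.insert_strided_slice %src, %dest
//            {offsets = [2], strides = [1]} : vector<3xi8> into vector<8xi8>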

/// Extracts 1-D subvector from a 1-D vector.
///
/// Given the input rank-1 source vector, extracts `numElemsToExtract` elements
/// from `src`, starting at `offset`. Since `offset` is a dynamic value, this
/// hook cannot use `vector.extract_strided_slice`; instead it generates a
/// sequence of `vector.extract` + `vector.insert` pairs, one per extracted
/// element, inserting the results into `dest`.
static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
                                         Value src, Value dest,
                                         OpFoldResult offset,
                                         int64_t numElemsToExtract) {
  auto srcVecTy = cast<VectorType>(src.getType());
  assert(srcVecTy.getRank() == 1 && "expected source to be rank-1 vector");
  // NOTE: The offset cannot be taken into account in the following assert, so
  // the subvector may still be out-of-bounds even when the condition holds.
  assert(numElemsToExtract <= srcVecTy.getNumElements() &&
         "subvector out of bounds");

  // When extracting all available elements, just use the source vector as the
  // result.
  if (srcVecTy.getNumElements() == numElemsToExtract)
    return src;

  for (int i = 0; i < numElemsToExtract; ++i) {
    Value extractLoc =
        (i == 0) ? dyn_cast<Value>(offset)
                 : rewriter.create<arith::AddIOp>(
                       loc, rewriter.getIndexType(), dyn_cast<Value>(offset),
                       rewriter.create<arith::ConstantIndexOp>(loc, i));
    auto extractOp = rewriter.create<vector::ExtractOp>(loc, src, extractLoc);
    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
  }
  return dest;
}
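
// E.g. (illustrative) with numElemsToExtract = 2 and a dynamic %offset, this
// expands into:
//   %e0 = vector.extract %src[%offset] : i8 from vector<8xi8>
//   %d0 = vector.insert %e0, %dest [0] : i8 into vector<2xi8>
//   %i1 = arith.addi %offset, %c1 : index
//   %e1 = vector.extract %src[%i1] : i8 from vector<8xi8>
//   %d1 = vector.insert %e1, %d0 [1] : i8 into vector<2xi8>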

/// Inserts 1-D subvector into a 1-D vector.
///
/// Inserts the input rank-1 source vector into the rank-1 destination vector,
/// starting at `offset`. Since `offset` is a dynamic value, this hook cannot
/// use `vector.insert_strided_slice`; instead it generates a sequence of
/// `vector.extract` + `vector.insert` pairs, one per inserted element.
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
                                        Value src, Value dest,
                                        OpFoldResult offset,
                                        int64_t numElemsToInsert) {
  auto srcVecTy = cast<VectorType>(src.getType());
  auto destVecTy = cast<VectorType>(dest.getType());
  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
         "expected source and dest to be rank-1 vector types");
  (void)srcVecTy;
  (void)destVecTy;
  assert(numElemsToInsert > 0 &&
         "the number of elements to insert must be greater than 0");
  // NOTE: The offset cannot be taken into account in the following assert, so
  // the subvector may still be out-of-bounds even when the condition holds.
  assert(numElemsToInsert <= destVecTy.getNumElements() &&
         "subvector out of bounds");

  Value destOffsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
  for (int64_t i = 0; i < numElemsToInsert; ++i) {
    auto insertLoc = i == 0
                         ? destOffsetVal
                         : rewriter.create<arith::AddIOp>(
                               loc, rewriter.getIndexType(), destOffsetVal,
                               rewriter.create<arith::ConstantIndexOp>(loc, i));
    auto extractOp = rewriter.create<vector::ExtractOp>(loc, src, i);
    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
  }
  return dest;
}

/// Emulate a vector load for `emulatedElemTy` using `containerElemTy`.
///
/// Specifically, use `containerElemTy` to load a vector of `emulatedElemTy`
/// from `base` at the linearized index `linearizedIndices`, then bitcast the
/// loaded container vector back to a vector of `emulatedElemTy`.
static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
                                      Value base,
                                      OpFoldResult linearizedIndices,
                                      int64_t numContainerElemsToLoad,
                                      Type emulatedElemTy,
                                      Type containerElemTy) {
  auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() /
                                  emulatedElemTy.getIntOrFloatBitWidth();
  auto newLoad = rewriter.create<vector::LoadOp>(
      loc, VectorType::get(numContainerElemsToLoad, containerElemTy), base,
      getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
  return rewriter.create<vector::BitCastOp>(
      loc,
      VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
                      emulatedElemTy),
      newLoad);
}
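
// Example (i4 emulated via i8, numContainerElemsToLoad = 4):
//   %wide = vector.load %base[%idx] : memref<?xi8>, vector<4xi8>
//   %res  = vector.bitcast %wide : vector<4xi8> to vector<8xi4>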

/// Downcast two values to `downcastType`, then select values
/// based on `mask`, and cast the result to `upcastType`.
static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
                                     VectorType downcastType,
                                     VectorType upcastType, Value mask,
                                     Value trueValue, Value falseValue) {
  assert(
      downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
          upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
      "expected input and output number of bits to match");
  if (trueValue.getType() != downcastType) {
    trueValue = builder.create<vector::BitCastOp>(loc, downcastType, trueValue);
  }
  if (falseValue.getType() != downcastType) {
    falseValue =
        builder.create<vector::BitCastOp>(loc, downcastType, falseValue);
  }
  Value selected =
      builder.create<arith::SelectOp>(loc, mask, trueValue, falseValue);

  // Upcast the selected value back to the original width.
  return builder.create<vector::BitCastOp>(loc, upcastType, selected);
}

/// Emits a `memref.generic_atomic_rmw` op to store a subbyte-sized value to a
/// byte in `linearizedMemref`, with a mask. `valueToStore` is a vector of
/// subbyte-sized elements that together span a single container byte, and
/// `mask` selects which of those elements are actually written; the remaining
/// positions keep the byte's original contents.
static void atomicRMW(OpBuilder &builder, Location loc,
                      MemRefValue linearizedMemref, Value storeIdx,
                      VectorValue valueToStore, Value mask) {
  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

  // Create an atomic load-modify-write region using
  // `memref.generic_atomic_rmw`.
  auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
      loc, linearizedMemref, ValueRange{storeIdx});
  Value origValue = atomicOp.getCurrentValue();

  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(atomicOp.getBody());

  // Load the original value from memory, and cast it to a 1-element vector
  // with the same element type as the value to store.
  auto oneElemVecType = VectorType::get({1}, origValue.getType());
  Value origVecValue = builder.create<vector::FromElementsOp>(
      loc, oneElemVecType, ValueRange{origValue});

  // Construct the final masked value and yield it.
  Value maskedValue =
      downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
                              oneElemVecType, mask, valueToStore, origVecValue);
  auto scalarMaskedValue =
      builder.create<vector::ExtractOp>(loc, maskedValue, 0);
  builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
}

/// Generate a non-atomic read-modify-write sequence for storing to the
/// emulated type. It has similar logic to `atomicRMW`, but without the
/// atomicity: the byte is plainly loaded, merged with the new sub-byte
/// elements under `mask`, and stored back.
static void nonAtomicRMW(OpBuilder &builder, Location loc,
                         MemRefValue linearizedMemref, Value linearizedIndex,
                         VectorValue valueToStore, Value mask) {
  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

  auto oneElemVecType =
      VectorType::get({1}, linearizedMemref.getType().getElementType());
  Value origVecValue = builder.create<vector::LoadOp>(
      loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
  origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
                                                   origVecValue);

  Value maskedValue =
      downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
                              oneElemVecType, mask, valueToStore, origVecValue);
  builder.create<vector::StoreOp>(loc, maskedValue, linearizedMemref,
                                  linearizedIndex);
}
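
// Unlike atomicRMW above, the load/merge/store sequence here is not safe
// against concurrent writers to the same byte; callers opt into it via the
// `disableAtomicRMW` flag on ConvertVectorStore below.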

/// Extract `sliceNumElements` from source `vector` at `extractOffset`, and
/// insert it into an empty vector at `insertOffset`. The source slice and the
/// result vector must both fit within a single container byte.
static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
                                  Location loc, VectorValue vector,
                                  int64_t extractOffset,
                                  int64_t sliceNumElements,
                                  int64_t insertOffset) {
  assert(vector.getType().getRank() == 1 && "expected 1-D vector");
  auto vectorElementType = vector.getType().getElementType();
  // TODO: update and use `alignedConversionPrecondition` in the place of
  // these asserts.
  assert(
      sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
      "sliceNumElements * vector element size must be less than or equal to 8");
  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
         "vector element must be a valid sub-byte type");
  auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
      loc, VectorType::get({emulatedPerContainerElem}, vectorElementType),
      rewriter.getZeroAttr(
          VectorType::get({emulatedPerContainerElem}, vectorElementType)));
  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
                                              extractOffset, sliceNumElements);
  return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
                                   insertOffset);
}
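
// E.g. (illustrative) with i2 elements: extracting 3 elements at offset 2
// from a vector<7xi2> and inserting them at offset 1 yields a byte-sized
// vector<4xi2> of the form [0, e2, e3, e4].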

namespace {

//===----------------------------------------------------------------------===//
// ConvertVectorStore
//===----------------------------------------------------------------------===//

// Emulate `vector.store` using a multi-byte container type.
//
// The container type is obtained through the Op adaptor and would normally be
// generated via a previous pass, e.g. one based on
// `arith::NarrowTypeEmulationConverter`.
//
// For stores that start at a container-element boundary and fully cover the
// underlying container elements, the value is simply bitcast to the container
// type and stored full-width. Unaligned or partially covering stores
// additionally need read-modify-write sequences for the first and/or last
// touched container element; these are emitted either as
// `memref.generic_atomic_rmw` regions or, when `disableAtomicRMW` is set, as
// plain load/select/store sequences.
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
      : OpConversionPattern<vector::StoreOp>(context),
        disableAtomicRMW(disableAtomicRMW) {}

  LogicalResult
  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Only 1-D vectors are handled for now.
    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();

    auto valueToStore = cast<VectorValue>(op.getValueToStore());
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getValueToStore().getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }
    int emulatedPerContainerElem = containerBits / emulatedBits;

    // Adjust the number of elements to store when emulating narrow types.
    // Here only the 1-D vector store is considered, and the N-D memref types
    // should be linearized. For example, to emulate i4 with i8, the
    // following op:
    //
    //   vector.store %arg1, %0[%arg2, %arg3] : memref<4x8xi4>, vector<8xi4>
    //
    // can be replaced with
    //
    //   %bitcast = vector.bitcast %arg1 : vector<8xi4> to vector<4xi8>
    //   vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
    //     vector<4xi8>
    auto origElements = valueToStore.getType().getNumElements();

    // Note, per-element-alignment was already verified above.
    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    // Do the trailing dimensions of the source and the destination match? If
    // so, and the source is divisible in size, the store starts at a
    // container-element boundary.
    auto trailingDim = op.getBase().getType().getShape().back();
    bool trailingDimsMatch =
        ShapedType::isDynamic(trailingDim) || trailingDim == origElements;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedNumFrontPadElems =
        (isDivisibleInSize && trailingDimsMatch)
            ? 0
            : getConstantIntValue(linearizedInfo.intraDataOffset);

    if (!foldedNumFrontPadElems) {
      return rewriter.notifyMatchFailure(
          op, "subbyte store emulation: dynamic front padding size is "
              "not yet implemented");
    }

    auto memrefBase = cast<MemRefValue>(adaptor.getBase());

    // A "partial store" is a store in which only some sub-byte elements
    // within a container element are overwritten, e.g. writing the low nibble
    // of an i8 byte. Partial stores require a read-modify-write sequence.
    // When the store starts at a container-element boundary (no front
    // padding), the value can instead be bitcast to the container type and
    // written with plain full-width stores.
    bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;

    if (!emulationRequiresPartialStores) {
      // Basic case: storing full bytes.
      auto numElements = origElements / emulatedPerContainerElem;
      auto bitCast = rewriter.create<vector::BitCastOp>(
          loc, VectorType::get(numElements, containerElemTy),
          op.getValueToStore());
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          op, bitCast.getResult(), memrefBase,
          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
      return success();
    }

    // Next, handle the case when sub-byte read-modify-write sequences are
    // needed to emulate the store. The store is split into up to three parts:
    //   1. The front partial container element (when the store does not start
    //      at a container boundary): merged into memory with an RMW.
    //   2. The middle, fully covered container elements: written with
    //      full-width `vector.store`s.
    //   3. The back partial container element (when the remaining elements do
    //      not fill a whole container element): merged with an RMW.

    // The index into the target memref we are storing to.
    Value currentDestIndex =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
    // The index into the source vector we are currently processing.
    auto currentSourceIndex = 0;

    // Build a mask used for RMW.
    auto subWidthStoreMaskType =
        VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());

    auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;

    // 1. Partial-width store for the leading container element. When the
    // store address is not aligned to the container-element boundary, deal
    // with the unaligned part first so that the remaining elements are
    // aligned to the width boundary.
    auto frontSubWidthStoreElem =
        (emulatedPerContainerElem - *foldedNumFrontPadElems) %
        emulatedPerContainerElem;
    if (frontSubWidthStoreElem > 0) {
      SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
      if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
                    origElements, true);
        frontSubWidthStoreElem = origElements;
      } else {
        // Mask the trailing `frontSubWidthStoreElem` positions, i.e. the
        // elements actually written past the front padding.
        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
                    frontSubWidthStoreElem, true);
      }
      auto frontMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));

      currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
      auto value =
          extractSliceIntoByte(rewriter, loc, valueToStore, 0,
                               frontSubWidthStoreElem, *foldedNumFrontPadElems);

      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
                cast<VectorValue>(value), frontMask.getResult());
    }

    if (currentSourceIndex >= origElements) {
      rewriter.eraseOp(op);
      return success();
    }

    // Increment the destination index by 1 to align to the container-element
    // boundary.
    auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    currentDestIndex = rewriter.create<arith::AddIOp>(
        loc, rewriter.getIndexType(), currentDestIndex, constantOne);

    // 2. Full-width store for the middle part.
    int64_t fullWidthStoreSize =
        (origElements - currentSourceIndex) / emulatedPerContainerElem;
    int64_t numNonFullWidthElements =
        fullWidthStoreSize * emulatedPerContainerElem;
    if (fullWidthStoreSize > 0) {
      auto fullWidthStorePart = staticallyExtractSubvector(
          rewriter, loc, valueToStore, currentSourceIndex,
          numNonFullWidthElements);

      auto originType = cast<VectorType>(fullWidthStorePart.getType());
      auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
      auto storeType = VectorType::get(
          {originType.getNumElements() / emulatedPerContainerElem},
          memrefElemType);
      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
                                                        fullWidthStorePart);
      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
                                       currentDestIndex);

      currentSourceIndex += numNonFullWidthElements;
      currentDestIndex = rewriter.create<arith::AddIOp>(
          loc, rewriter.getIndexType(), currentDestIndex,
          rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
    }

    // 3. Partial-width store for the trailing container element, needed when
    // the residual elements do not fill a whole container element.
    auto remainingElements = origElements - currentSourceIndex;
    if (remainingElements != 0) {
      auto subWidthStorePart =
          extractSliceIntoByte(rewriter, loc, valueToStore,
                               currentSourceIndex, remainingElements, 0);

      // Generate back mask.
      auto maskValues = SmallVector<bool>(emulatedPerContainerElem, 0);
      std::fill_n(maskValues.begin(), remainingElements, 1);
      auto backMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));

      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
                cast<VectorValue>(subWidthStorePart), backMask.getResult());
    }

    rewriter.eraseOp(op);
    return success();
  }

private:
  const bool disableAtomicRMW;
};

//===----------------------------------------------------------------------===//
// ConvertVectorMaskedStore
//===----------------------------------------------------------------------===//

// Emulate `vector.maskedstore` using a multi-byte container type.
struct ConvertVectorMaskedStore final
    : OpConversionPattern<vector::MaskedStoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Only 1-D vectors are handled for now.
    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getValueToStore().getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }

    int emulatedPerContainerElem = containerBits / emulatedBits;
    int origElements = op.getValueToStore().getType().getNumElements();
    if (origElements % emulatedPerContainerElem != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    OpFoldResult linearizedIndicesOfr;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndicesOfr) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));
    Value linearizedIndices =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);

    // Load the whole data and use arith.select to handle the corner cases.
    //
    // For example, for a masked store of i4 values emulated via i8, given
    //
    //   %mask         = [0, 1, 1, 1, 1, 0, 0, 0]                 (8 * i1)
    //   %val_to_store = [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0] (8 * i4)
    //
    // the emulation:
    //   1. compresses the mask to container granularity ([1, 1, 1, 0]),
    //   2. masked-loads the affected i8 region,
    //   3. bitcasts the loaded data to i4 and selects between it and
    //      %val_to_store under the original (i4-level) %mask,
    //   4. bitcasts the result back to i8 and masked-stores it with the
    //      compressed mask.
    // Elements whose mask bit is 0 thus keep their original memory contents.
    FailureOr<Operation *> newMask = getCompressedMaskOp(
        rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);
    if (failed(newMask))
      return failure();

    auto numElements = (origElements + emulatedPerContainerElem - 1) /
                       emulatedPerContainerElem;
    auto newType = VectorType::get(numElements, containerElemTy);
    auto passThru = rewriter.create<arith::ConstantOp>(
        loc, newType, rewriter.getZeroAttr(newType));

    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, newType, adaptor.getBase(), linearizedIndices,
        newMask.value()->getResult(0), passThru);

    auto newBitCastType =
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
    Value valueToStore =
        rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
    valueToStore = rewriter.create<arith::SelectOp>(
        loc, op.getMask(), op.getValueToStore(), valueToStore);
    valueToStore =
        rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);

    rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
        op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
        valueToStore);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertVectorLoad
//===----------------------------------------------------------------------===//

// Emulate `vector.load` using a multi-byte container type.
struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Only 1-D vectors are handled for now.
    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }
    int emulatedPerContainerElem = containerBits / emulatedBits;

    // Adjust the number of elements to load when emulating narrow types,
    // and then cast back to the original type with a vector.bitcast op.
    // For example, to emulate i4 with i8, the following op:
    //
    //   %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
    //
    // can be replaced with
    //
    //   %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
    //   %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
    //
    // There are cases where the number of elements to load is not
    // byte-aligned. In those cases, extra container elements are loaded and
    // the exact slice is extracted afterwards, e.g. with
    // vector.extract_strided_slice (whose offsets and sizes must currently be
    // compile-time constants).
    auto origElements = op.getVectorType().getNumElements();
    // Note, per-element-alignment was already verified above.
    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0
                          : getConstantIntValue(linearizedInfo.intraDataOffset);

    // Always load enough elements to cover the original elements.
    int64_t maxIntraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
                                        emulatedPerContainerElem);
    Value result =
        emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
                           numElements, emulatedElemTy, containerElemTy);

    if (!foldedIntraVectorOffset) {
      auto resultVector = rewriter.create<arith::ConstantOp>(
          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
      result = dynamicallyExtractSubVector(
          rewriter, loc, result, resultVector, linearizedInfo.intraDataOffset,
          origElements);
    } else if (!isDivisibleInSize) {
      result = staticallyExtractSubvector(
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertVectorMaskedLoad
//===----------------------------------------------------------------------===//

// Emulate `vector.maskedload` using a multi-byte container type.
struct ConvertVectorMaskedLoad final
    : OpConversionPattern<vector::MaskedLoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors are handled for now.
    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();

    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }
    int emulatedPerContainerElem = containerBits / emulatedBits;

    // Adjust the number of elements to load when emulating narrow types,
    // and then cast back to the original type with a vector.bitcast op.
    // For example, to emulate i4 with i8, a masked load
    //
    //   %mask = vector.constant_mask [3] : vector<6xi1>
    //   %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
    //        memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
    //
    // is emulated as a masked load of the i8 container elements with a
    // compressed mask and a bitcast pass-through value:
    //
    //   %new_mask = vector.constant_mask [2] : vector<3xi1>
    //   %new_pass_thru = vector.bitcast %pass_thru :
    //        vector<6xi4> to vector<3xi8>
    //   %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
    //        memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
    //   %2 = vector.bitcast %1 : vector<3xi8> to vector<6xi4>
    //
    // Since the emulated load is wider than the original one, the elements
    // that were not actually requested must be replaced with the pass-through
    // value after bitcasting back, using the original mask:
    //
    //   %3 = arith.select %mask, %2, %pass_thru : vector<6xi1>, vector<6xi4>
    auto origType = op.getVectorType();
    auto origElements = origType.getNumElements();
    // Note, per-element-alignment was already verified above.
    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0
                          : getConstantIntValue(linearizedInfo.intraDataOffset);

    int64_t maxIntraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    FailureOr<Operation *> newMask =
        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
                            emulatedPerContainerElem, maxIntraDataOffset);
    if (failed(newMask))
      return failure();

    Value passthru = op.getPassThru();

    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
                                        emulatedPerContainerElem);
    auto loadType = VectorType::get(numElements, containerElemTy);
    auto newBitcastType =
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);

    auto emptyVector = rewriter.create<arith::ConstantOp>(
        loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
    if (!foldedIntraVectorOffset) {
      passthru = dynamicallyInsertSubVector(
          rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
          origElements);
    } else if (!isDivisibleInSize) {
      passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
                                           *foldedIntraVectorOffset);
    }
    auto newPassThru =
        rewriter.create<vector::BitCastOp>(loc, loadType, passthru);

    // Load the whole region with the compressed mask.
    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, loadType, adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newMask.value()->getResult(0), newPassThru);

    // Set the part that was not effectively loaded from memory to the
    // pass-through value.
    auto bitCast =
        rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);

    Value mask = op.getMask();
    auto newSelectMaskType = VectorType::get(
        numElements * emulatedPerContainerElem, rewriter.getI1Type());
    // TODO: try to fold if the op's mask is constant.
    auto emptyMask = rewriter.create<arith::ConstantOp>(
        loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
    if (!foldedIntraVectorOffset) {
      mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
                                        linearizedInfo.intraDataOffset,
                                        origElements);
    } else if (!isDivisibleInSize) {
      mask = staticallyInsertSubvector(rewriter, loc, mask, emptyMask,
                                       *foldedIntraVectorOffset);
    }

    Value result =
        rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
    if (!foldedIntraVectorOffset) {
      result = dynamicallyExtractSubVector(
          rewriter, loc, result, op.getPassThru(),
          linearizedInfo.intraDataOffset, origElements);
    } else if (!isDivisibleInSize) {
      result = staticallyExtractSubvector(
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);

    return success();
  }
};

/// Check whether `subByteVecTy` fits within a vector of `multiByteScalarTy`.
///
/// "Fitting" means that `subByteVecTy` can be packed into a vector of
/// `multiByteScalarTy` without any loss or padding, i.e. the trailing
/// dimension of the sub-byte vector covers a whole number of multi-byte
/// elements. For example (illustrative):
///
///   vector<4xi4> fits within i8 (two i4 elements per byte, 4 % 2 == 0),
///   vector<3xi2> does not (four i2 elements per byte, 3 % 4 != 0).
static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
                                       Type multiByteScalarTy) {
  assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");

  int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
  int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();

  assert(subByteBits < 8 && "Not a sub-byte scalar type!");
  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
  assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");

  int elemsPerMultiByte = multiByteBits / subByteBits;

  // TODO: This is a bit too restrictive for vectors of rank > 1.
  return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
}

//===----------------------------------------------------------------------===//
// ConvertVectorTransferRead
//===----------------------------------------------------------------------===//

// Emulate `vector.transfer_read` using a multi-byte container type.
struct ConvertVectorTransferRead final
    : OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Only 1-D vectors are handled for now.
    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origElements = op.getVectorType().getNumElements();

    // Note, per-element-alignment was already verified above.
    bool isDivisibleInSize =
        fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);

    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
                                                      adaptor.getPadding());

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0
                          : getConstantIntValue(linearizedInfo.intraDataOffset);

    int64_t maxIntraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
                                        emulatedPerContainerElem);

    auto newRead = rewriter.create<vector::TransferReadOp>(
        loc, VectorType::get(numElements, containerElemTy), adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newPadding);

    auto bitCast = rewriter.create<vector::BitCastOp>(
        loc,
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),
        newRead);

    Value result = bitCast->getResult(0);
    if (!foldedIntraVectorOffset) {
      auto zeros = rewriter.create<arith::ConstantOp>(
          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
      result = dynamicallyExtractSubVector(rewriter, loc, result, zeros,
                                           linearizedInfo.intraDataOffset,
                                           origElements);
    } else if (!isDivisibleInSize) {
      result = staticallyExtractSubvector(
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);

    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// RewriteBitCastOfTruncI
//===----------------------------------------------------------------------===//

namespace {

/// Helper struct to keep track of the provenance of a contiguous set of bits
/// in a source vector element.
struct SourceElementRange {
  /// The index of the source vector element that contributes bits.
  int64_t sourceElementIdx;
  /// The range of bits in the source vector element that contribute.
  int64_t sourceBitBegin;
  int64_t sourceBitEnd;
};

struct SourceElementRangeList : public SmallVector<SourceElementRange> {
  /// Given the index of a SourceElementRange in the SourceElementRangeList,
  /// compute the amount of bits that need to be shifted to the left to get the
  /// bits in their final location. This shift amount is simply the sum of the
  /// bit widths of all the preceding entries in the list.
  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
    int64_t res = 0;
    for (int64_t i = 0; i < shuffleIdx; ++i)
      res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
    return res;
  }
};

/// Helper struct to enumerate, for each target vector element, the source
/// elements and bit ranges that contribute bits to it. This encoding is what
/// allows rewriting a `vector.bitcast` into shuffles and bitwise ops for
/// arbitrary source/target bitwidth combinations.
struct BitCastBitsEnumerator {
  BitCastBitsEnumerator(VectorType sourceVectorType,
                        VectorType targetVectorType);

  int64_t getMaxNumberOfEntries() {
    int64_t numVectors = 0;
    for (const auto &l : sourceElementRanges)
      numVectors = std::max(numVectors, (int64_t)l.size());
    return numVectors;
  }

  VectorType sourceVectorType;
  VectorType targetVectorType;
  SmallVector<SourceElementRangeList> sourceElementRanges;
};

/// Rewrite a `vector.bitcast` into a sequence of shuffles and bitwise ops.
///
/// `BitCastBitsEnumerator` encodes, for each target element, where its bits
/// come from in the source vector. `BitCastRewriter` consumes that encoding
/// and emits at most `getMaxNumberOfEntries()` rewrite steps, each consisting
/// of:
///   1. a `vector.shuffle` that gathers the contributing source elements,
///   2. an `arith.andi` that isolates the contributing bits,
///   3. `arith.shrui` + `arith.shli` that move those bits into position,
///   4. an `arith.ori` that accumulates them into the running result.
struct BitCastRewriter {
  /// Per-step precomputed metadata.
  struct Metadata {
    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
  };

  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);

  /// Verify that the preconditions for the rewrite are met.
  LogicalResult commonPrecondition(PatternRewriter &rewriter,
                                   VectorType preconditionType, Operation *op);

  /// Precompute the metadata for the rewrite steps.
  SmallVector<BitCastRewriter::Metadata>
  precomputeMetadata(IntegerType shuffledElementType);

  /// Rewrite one step of the sequence:
  ///   `(shuffle -> and -> shiftright -> shiftleft -> or)`.
  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
                           Value initialValue, Value runningResult,
                           const BitCastRewriter::Metadata &metadata);

private:
  /// The enumerator encoding the bit provenance of the target elements.
  BitCastBitsEnumerator enumerator;
};

} // namespace

[[maybe_unused]] static raw_ostream &
operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
  for (const auto &l : vec) {
    for (const auto &it : llvm::enumerate(l)) {
      os << "{ " << it.value().sourceElementIdx << ": b@["
         << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
         << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
    }
    os << "\n";
  }
  return os;
}

BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
                                             VectorType targetVectorType)
    : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {

  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
  LDBG("sourceVectorType: " << sourceVectorType);

  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
  LDBG("targetVectorType: " << targetVectorType);

  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
  (void)mostMinorSourceDim;
  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
         "source and target bitwidths must match");

  // Prepopulate one source element range list per target element.
  sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
  for (int64_t resultBit = 0; resultBit < bitwidth;) {
    int64_t resultElement = resultBit / targetBitWidth;
    int64_t resultBitInElement = resultBit % targetBitWidth;
    int64_t sourceElementIdx = resultBit / sourceBitWidth;
    int64_t sourceBitInElement = resultBit % sourceBitWidth;
    int64_t step = std::min(sourceBitWidth - sourceBitInElement,
                            targetBitWidth - resultBitInElement);
    sourceElementRanges[resultElement].push_back(
        {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
    resultBit += step;
  }
}
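
// For example (illustrative): enumerating vector<4xi8> -> vector<8xi4>, each
// target i4 element is covered by a single 4-bit range of one source byte:
//   target[0] <- source[0] bits [0..4), target[1] <- source[0] bits [4..8),
//   target[2] <- source[1] bits [0..4), and so on.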

BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
                                 VectorType targetVectorType)
    : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
  LDBG("\n" << enumerator.sourceElementRanges);
}

/// Verify that the precondition type meets the common preconditions for any
/// conversion.
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
                                                  Operation *op) {
  if (!preconditionType || preconditionType.isScalable())
    return rewriter.notifyMatchFailure(op, "scalable vector");

  // TODO: consider relaxing this restriction in the future if we find ways
  // to really work with sub-byte elements across the MLIR/LLVM boundary.
  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
  if (bitwidth % 8 != 0)
    return rewriter.notifyMatchFailure(op, "unsupported sub-byte bitwidth");

  return success();
}

LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
                                                  Operation *op) {
  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
    return rewriter.notifyMatchFailure(op, "types are not vector");

  if (!preconditionType || preconditionType.getRank() != 1)
    return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");

  return commonConversionPrecondition(rewriter, preconditionType, op);
}

/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
///
/// Alignment means that `subByteVecTy` can be packed into a vector of
/// `containerTy` elements. More specifically:
///   1. The bit-width of `containerTy` is a multiple of the bit-width of
///      `subByteVecTy` elements, and
///   2. the trailing dimension of `subByteVecTy` covers a whole number of
///      `containerTy` elements (see `fitsInMultiByteContainerTy`).
///
/// NOTE: This method assumes that common conversion preconditions are met.
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
                                                   VectorType subByteVecTy,
                                                   Type containerTy,
                                                   Operation *op) {
  assert(containerTy.isIntOrFloat() &&
         "container element type is not a scalar");

  // TODO: This is validating the inputs rather than checking the conversion
  // preconditions - hence, bail out rather than asserting.
  if (!subByteVecTy)
    return rewriter.notifyMatchFailure(op, "not a vector!");

  unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
  unsigned containerBits = containerTy.getIntOrFloatBitWidth();

  // Enforced by the common preconditions.
  assert(containerBits % 8 == 0 && "Not a multi-byte scalar type!");

  // TODO: Add support for other sub-byte types.
  if (subByteBits != 2 && subByteBits != 4)
    return rewriter.notifyMatchFailure(
        op, "only 2-bit and 4-bit sub-byte type is supported at this moment");

  // Condition 1 ("per-element" alignment).
  if (containerBits % subByteBits != 0)
    return rewriter.notifyMatchFailure(op, "unaligned element types");

  // Condition 2 ("full" alignment).
  if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
    return rewriter.notifyMatchFailure(
        op, "not possible to fit this sub-byte vector type into a vector of "
            "the given multi-byte type");

  return success();
}
1636
1638 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
1640 for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
1641 shuffleIdx < e; ++shuffleIdx) {
1644
1645
1646 for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
1647 int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
1648 ? srcEltRangeList[shuffleIdx].sourceElementIdx
1649 : 0;
1650 shuffles.push_back(sourceElement);
1651
1652 int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
1653 ? srcEltRangeList[shuffleIdx].sourceBitBegin
1654 : 0;
1655 int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
1656 ? srcEltRangeList[shuffleIdx].sourceBitEnd
1657 : 0;
1659 shuffledElementType,
1660 llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
1661 bitLo, bitHi));
1662 masks.push_back(mask);
1663
1664 int64_t shiftRight = bitLo;
1665 shiftRightAmounts.push_back(
1667
1668 int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
1669 shiftLeftAmounts.push_back(
1671 }
1672
1673 result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
1674 }
1675 return result;
1676 }

Value BitCastRewriter::genericRewriteStep(
    PatternRewriter &rewriter, Location loc, Value initialValue,
    Value runningResult, const BitCastRewriter::Metadata &metadata) {
  // Create vector.shuffle from the metadata.
  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
      loc, initialValue, initialValue, metadata.shuffles);

  // Intersect with the mask.
  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
  auto constOp = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);

  // Align right on 0.
  auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
  Value shiftedRight =
      rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);

  // Shift bits left into their final position.
  auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
  Value shiftedLeft =
      rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);

  runningResult =
      runningResult
          ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
          : shiftedLeft;

  return runningResult;
}
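
// Each step thus computes, elementwise over the shuffled vector:
//   result |= ((shuffled & mask) >> shiftRight) << shiftLeft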

/// Bitcasts the aligned `subByteVec` vector to a vector of i8.
///
/// Where "aligned" means it satisfies the alignedConversionPrecondition.
///
/// Example:
///   vector<16x16xi2> -> vector<16x4xi8>
///   vector<16x16xi4> -> vector<16x8xi8>
static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
                                      Value subByteVec) {
  auto srcVecType = cast<VectorType>(subByteVec.getType());
  int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
  assert(8 % srcBitwidth == 0 &&
         "Unsupported sub-byte type (not a divisor of i8)");
  int64_t numSrcElemsPerByte = 8 / srcBitwidth;
  SmallVector<int64_t> vecShape(srcVecType.getShape());
  // Adjust the last dimension of the vector so the total size stays the same.
  vecShape.back() = vecShape.back() / numSrcElemsPerByte;
  auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
  return rewriter.create<vector::BitCastOp>(loc, i8VecType, subByteVec);
}

/// Extracts a signed N-bit sequence from each element of a vector of bytes,
/// starting at the specified bit index.
///
/// The sign extension is done by first shifting the relevant bits to the top
/// of the byte (shift left) and then arithmetic-shifting them back down
/// (signed shift right).
static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter,
                                                  Location loc, Value src,
                                                  int bitIdx, int numBits) {
  auto srcType = cast<VectorType>(src.getType());
  Value shl = src;
  int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
  assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
         "Invalid bitIdx range");
  if (bitsToShiftLeft != 0) {
    Value shiftLeftValues = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
    shl = rewriter.create<arith::ShLIOp>(loc, src, shiftLeftValues);
  }

  int8_t bitsToShiftRight = 8 - numBits;
  Value shiftRightValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
  Value shr = rewriter.create<arith::ShRSIOp>(loc, shl, shiftRightValues);
  return shr;
}
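
// Worked example (illustrative): numBits = 4, bitIdx = 0, input byte 0x2F.
// Shift left by 8 - 4 - 0 = 4: 0x2F << 4 = 0xF0. Arithmetic shift right by
// 8 - 4 = 4: 0xF0 >> 4 = 0xFF, i.e. the low nibble 0xF sign-extended to -1.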

/// Extracts an unsigned N-bit sequence from each element of a vector of
/// bytes, starting at the specified bit index.
///
/// The zero extension is done by shifting the relevant bits down to the
/// bottom of the byte (unsigned shift right) and masking off the remaining
/// high bits when needed.
static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter,
                                              Location loc, Value src,
                                              int bitIdx, int numBits) {
  assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
         "Invalid bitIdx range");
  auto srcType = cast<VectorType>(src.getType());
  int8_t bitsToShiftRight = bitIdx;
  Value shr = src;
  if (bitsToShiftRight != 0) {
    Value shiftRightValues = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
    shr = rewriter.create<arith::ShRUIOp>(loc, src, shiftRightValues);
  }
  if (bitIdx + numBits == 8) {
    return shr;
  }
  uint8_t lowBitsMask = (1 << numBits) - 1;
  Value lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(srcType, lowBitsMask));
  return rewriter.create<arith::AndIOp>(loc, shr, lowBitsMaskValues);
}
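
// Worked example (illustrative): numBits = 2, bitIdx = 2, input 0b01101100.
// Unsigned shift right by 2 gives 0b00011011; masking with (1 << 2) - 1 =
// 0b11 keeps only the two extracted bits: 0b00000011.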

using ExtractNBitsFn =
    std::function<Value(PatternRewriter &, Location, Value, int, int)>;

/// Rewrite the i4 -> i8 extension into a sequence of shuffles and bitwise
/// ops to avoid leaving LLVM to scramble with peephole optimizations.
static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
                              Value srcValue, const ExtractNBitsFn &extFn) {
  [[maybe_unused]] auto srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(4) &&
         "Expected i4 type");

  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);

  // 2. Extend the i4 elements to i8. The low i4 elements of each byte are
  // placed in one vector and the high i4 elements in another vector.
  Value low = extFn(rewriter, loc, i8Vector, 0, 4);
  Value high = extFn(rewriter, loc, i8Vector, 4, 4);

  // 3. Interleave the low and high i8 elements.
  return rewriter.create<vector::InterleaveOp>(loc, low, high);
}

/// Rewrite the i2 -> i8 extension into a sequence of shuffles and bitwise
/// ops to avoid leaving LLVM to scramble with peephole optimizations.
static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
                              Value srcValue, const ExtractNBitsFn &extFn) {
  [[maybe_unused]] VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(2) &&
         "Expected i2 type");

  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);

  // 2. Extract each i2 element from its byte.
  // Position 0 (bits 0-1).
  Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
  // Position 1 (bits 2-3).
  Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
  // Position 2 (bits 4-5).
  Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
  // Position 3 (bits 6-7).
  Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);

  // 3. Interleave all the elements. Two stages of `vector.interleave` restore
  // the original element order (0, 1, 2, 3, ...): the first stage interleaves
  // the even positions (vec0, vec2) and the odd positions (vec1, vec3), the
  // second stage interleaves the two intermediate results.
  Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, vec0, vec2);
  Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, vec1, vec3);
  return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
}

/// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
/// ops to avoid leaving LLVM to scramble with peephole optimizations.
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
                                Value srcValue) {
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(8) &&
         "Expected i8 type");

  // 1. De-interleave low and high i8 elements.
  auto deinterleaveOp = rewriter.create<vector::DeinterleaveOp>(loc, srcValue);

  // 2. Zero out the upper side of each low i8 element.
  constexpr int8_t i8LowBitMask = 0x0F;
  VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
  Value zeroOutLow = rewriter.create<arith::AndIOp>(
      loc, deinterleaveOp.getRes1(), zeroOutMask);

  // 3. Move the high i8 element's low nibble to the upper side of the byte.
  constexpr int8_t bitsToShift = 4;
  auto shiftValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
  Value shlHigh = rewriter.create<arith::ShLIOp>(loc, deinterleaveOp.getRes2(),
                                                 shiftValues);

  // 4. Merge the high and low i4 values.
  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);

  // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
}
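
// E.g. (illustrative) for input elements [a, b, c, d, ...] (as i8), the
// deinterleave produces [a, c, ...] and [b, d, ...]; each output byte is
// (b << 4) | (a & 0x0F), and the final bitcast to i4 restores the original
// element order [a, b, c, d, ...] as nibbles.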

namespace {
/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
/// advantage of high-level information to avoid leaving LLVM to scramble with
/// peephole optimizations.
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
                                PatternRewriter &rewriter) const override {
    // The source must be a trunc op.
    auto truncOp =
        bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
    if (!truncOp)
      return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");

    // Set up the BitCastRewriter and verify the precondition.
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
      return failure();

    // Perform the rewrite.
    Value truncValue = truncOp.getIn();
    auto shuffledElementType =
        cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
    Value runningResult;
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
    }

    // Finalize the rewrite.
    bool narrowing = targetVectorType.getElementTypeBitWidth() <=
                     shuffledElementType.getIntOrFloatBitWidth();
    if (narrowing) {
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
      } else {
        rewriter.replaceOpWithNewOp<arith::TruncIOp>(
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      }
    } else {
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
      } else {
        rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      }
    }

    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// RewriteExtOfBitCast
//===----------------------------------------------------------------------===//

namespace {
/// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that
/// take advantage of high-level information to avoid leaving LLVM to scramble
/// with peephole optimizations.
template <typename ExtOpType>
struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
  using OpRewritePattern<ExtOpType>::OpRewritePattern;

  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
      : OpRewritePattern<ExtOpType>(context, benefit) {}

  LogicalResult matchAndRewrite(ExtOpType extOp,
                                PatternRewriter &rewriter) const override {
    // The source must be a bitcast op.
    auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
    if (!bitCastOp)
      return rewriter.notifyMatchFailure(extOp, "not a bitcast source");

    // Set up the BitCastRewriter and verify the precondition.
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(
            rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
      return failure();

    // Perform the rewrite.
    Value runningResult;
    Value sourceValue = bitCastOp.getSource();
    auto shuffledElementType =
        cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
    }

    // Finalize the rewrite.
    bool narrowing =
        cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
        shuffledElementType.getIntOrFloatBitWidth();
    if (narrowing) {
      rewriter.replaceOpWithNewOp<arith::TruncIOp>(
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
    } else {
      rewriter.replaceOpWithNewOp<ExtOpType>(
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
    }

    return success();
  }
};

//===----------------------------------------------------------------------===//
// RewriteAlignedSubByteIntExt
//===----------------------------------------------------------------------===//

/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles
/// and bitwise ops that take advantage of high-level information to avoid
/// leaving LLVM to scramble with peephole optimizations. For example:
///
///   arith.extsi %in : vector<8xi4> to vector<8xi32>
///
/// is rewritten as an i4 -> i8 extension based on shifts and interleaving,
/// followed by
///
///   arith.extsi %ext : vector<8xi8> to vector<8xi32>
///
/// The same scheme applies to arith.extui, arith.sitofp and arith.uitofp,
/// and to i2 sources.
template <typename ConversionOpType, bool isSigned>
struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
  using OpRewritePattern<ConversionOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
                                PatternRewriter &rewriter) const override {
    // Verify the preconditions.
    Value srcValue = conversionOp.getIn();
    VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
    VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());

    if (failed(
            commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
      return failure();

    // Check general alignment preconditions.
    if (failed(alignedConversionPrecondition(
            rewriter, srcVecType, rewriter.getI8Type(), conversionOp)))
      return failure();

    // Perform the rewrite.
    Location loc = conversionOp.getLoc();
    const auto &extFn = isSigned ? extractNBitsPerByteAndSignExtendToI8
                                 : extractNBitsPerByteAndExtendToI8;
    Value subByteExt;
    switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
    case 2:
      subByteExt = rewriteI2ToI8Ext(rewriter, loc, srcValue, extFn);
      break;
    case 4:
      subByteExt = rewriteI4ToI8Ext(rewriter, loc, srcValue, extFn);
      break;
    default:
      return failure();
    }

    // Finalize the rewrite.
    rewriter.replaceOpWithNewOp<ConversionOpType>(
        conversionOp, conversionOp.getType(), subByteExt);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// RewriteAlignedSubByteIntTrunc
//===----------------------------------------------------------------------===//

/// Rewrite the i8 -> i4 part of any truncation into a deinterleave and a
/// series of bitwise ops to avoid leaving LLVM to scramble with peephole
/// optimizations: the source is first truncated to i8, the i8 vector is
/// deinterleaved, the low nibbles are masked, the high nibbles are shifted
/// into place and OR-ed in, and the merged bytes are bitcast back to i4.
struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
                                PatternRewriter &rewriter) const override {
    // Verify the preconditions.
    Value srcValue = truncOp.getIn();
    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
    auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
    if (!srcVecType || !dstVecType)
      return failure();

    if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
      return failure();

    // TODO: Add support for truncating to i2.
    if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
      return failure();

    // Check general alignment preconditions. The source and destination type
    // order is inverted to reuse the existing precondition logic.
    if (failed(alignedConversionPrecondition(
            rewriter, dstVecType, rewriter.getI8Type(), truncOp)))
      return failure();

    // Perform the rewrite.
    Location loc = truncOp.getLoc();
    auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
    Value i8TruncVal =
        rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);

    // Rewrite the i8 -> i4 truncation part.
    Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);

    // Finalize the rewrite.
    rewriter.replaceOp(truncOp, subByteTrunc);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// RewriteVectorTranspose
//===----------------------------------------------------------------------===//

/// Rewrite a sub-byte vector.transpose as extsi + transpose on a byte-aligned
/// element type + trunci back to the sub-byte type, since sub-byte transposes
/// are generally not supported natively.
struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    constexpr unsigned minNativeBitwidth = 8;
    VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
    if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
        srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
      return rewriter.notifyMatchFailure(transposeOp,
                                         "not a sub-byte transpose");
    }

    // Perform the rewrite.
    Location loc = transposeOp.getLoc();
    // The signed/unsigned interpretation shouldn't matter here as we are just
    // transposing the elements. However, signed ext is currently the cheaper
    // option as unsigned ext would also generate a mask.
    auto srcNativeVecType = srcSubByteVecType.cloneWith(
        std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
    Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
                                                  transposeOp.getVector());
    Value newTranspose = rewriter.create<vector::TransposeOp>(
        loc, extOp, transposeOp.getPermutation());
    VectorType dstSubByteVecType = transposeOp.getResultVectorType();
    rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
                                                 newTranspose);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public Interface Definition
//===----------------------------------------------------------------------===//

void vector::populateVectorNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns, bool disableAtomicRMW) {

  // Populate `vector.*` load conversion patterns.
  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
               ConvertVectorMaskedStore, ConvertVectorTransferRead>(
      typeConverter, patterns.getContext());

  // Populate `vector.store` conversion patterns. The caller can choose to
  // avoid emitting atomic operations and reduce them to load-modify-write
  // sequences if it is known that there is no thread contention.
  patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW);
}
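
// A typical driver (illustrative; the exact converter setup depends on the
// client pass, e.g. a memref narrow-type emulation pipeline):
//   arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
//   (void)applyPartialConversion(op, target, std::move(patterns));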

void vector::populateVectorNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
               RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                    benefit);

  // Patterns for aligned cases. They are assigned a higher benefit as they
  // are expected to generate better performance for aligned cases.
  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
               RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
               RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                              benefit.getBenefit() + 1);
  patterns
      .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
           RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
          patterns.getContext(), benefit.getBenefit() + 1);
}

void vector::populateVectorTransposeNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
}