MLIR: lib/Dialect/Linalg/Transforms/Vectorization.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
13
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/Sequence.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/TypeSwitch.h"
40 #include "llvm/ADT/iterator_range.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/MathExtras.h"
43 #include "llvm/Support/raw_ostream.h"
44 #include
45 #include <type_traits>
46
47 using namespace mlir;
49
50 #define DEBUG_TYPE "linalg-vectorization"
51
52 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
53 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
54
55
56 static FailureOr<Operation *>
60 bool flatten1DDepthwiseConv = false);
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80 static LogicalResult
84
85
86
87
88
89
90
92
93
94
95 template
97 OpType res;
98 block.walk([&](OpType op) {
99 if (res) {
100 res = nullptr;
102 }
103 res = op;
105 });
106 return res;
107 }
108
109
110
113 int64_t nSize, int64_t wSize, int64_t cSize,
114 int64_t kwSize, int strideW, int dilationW,
115 int64_t wSizeStep, bool isSingleChanneled) {
117 if (isSingleChanneled) {
118
119
122 for (int64_t kw = 0; kw < kwSize; ++kw) {
123 for (int64_t w = 0; w < wSize; w += wSizeStep) {
124 result.push_back(rewriter.createvector::ExtractStridedSliceOp(
125 loc, input, ArrayRef<int64_t>{w + kw}, sizes, strides));
126 }
127 }
128 } else {
129
130
133 for (int64_t kw = 0; kw < kwSize; ++kw) {
134 for (int64_t w = 0; w < wSize; w += wSizeStep) {
135 result.push_back(rewriter.createvector::ExtractStridedSliceOp(
136 loc, input,
137 ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
138 sizes, strides));
139 }
140 }
141 }
142 return result;
143 }
144
145
146
149 int64_t kwSize) {
151
152
153 for (int64_t kw = 0; kw < kwSize; ++kw) {
154 result.push_back(rewriter.createvector::ExtractOp(
156 }
157 return result;
158 }
159
160
161
164 int64_t nSize, int64_t wSize, int64_t fSize,
165 int64_t wSizeStep, bool isSingleChanneled) {
167 if (isSingleChanneled) {
168
171 for (int64_t w = 0; w < wSize; w += wSizeStep) {
172 result.push_back(rewriter.createvector::ExtractStridedSliceOp(
174 }
175 } else {
176
177
180 for (int64_t w = 0; w < wSize; w += wSizeStep) {
181 result.push_back(rewriter.createvector::ExtractStridedSliceOp(
182 loc, res, ArrayRef<int64_t>{0, w, 0}, sizes, strides));
183 }
184 }
185 return result;
186 }
187
188
190 Value res, int64_t wSize, int64_t wSizeStep,
192 bool isSingleChanneled) {
193
194 if (isSingleChanneled) {
195
196
198 for (int64_t w = 0; w < wSize; w += wSizeStep) {
199 res = rewriter.createvector::InsertStridedSliceOp(
200 loc, resVals[w], res, ArrayRef<int64_t>{w}, strides);
201 }
202 } else {
203
204
206 for (int64_t w = 0; w < wSize; w += wSizeStep) {
207 res = rewriter.createvector::InsertStridedSliceOp(
209 strides);
210 }
211 }
212 return res;
213 }
214
215
216
219
220
221
222 LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
225
226
228
229
230
232
233
234
235
236
238 Type elementType,
239 std::optional dimPermutation = std::nullopt) const {
242 if (dimPermutation.has_value()) {
244 applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
245 scalableDims =
246 applyPermutationMap(*dimPermutation, scalableVecDims);
247 } else {
248 vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
249 scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
250 }
251
253 }
254
255
256
257
258
261 std::optional maybeIndexingMap = std::nullopt);
262
263 private:
264
265
266 void initIterSpaceStaticSizes(LinalgOp linalgOp) {
267 iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
268 }
269
270
271
272
273 LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
274 LinalgOp linalgOp);
275
276
277
278
279
281 LinalgOp linalgOp,
282 std::optional maybeMaskingMap);
283
284
285
286
287 bool isValidMaskingMap(AffineMap maskingMap) {
289 }
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
308 }
309
310
311
313
314
315
316
318
319
321
322
323
325
326
327
329
330
331
333 };
334
335 LogicalResult
336 VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
337 LinalgOp linalgOp) {
338
339 for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
340 if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
341
342 iterSpaceValueSizes.push_back(rewriter.createarith::ConstantIndexOp(
343 linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
344 continue;
345 }
346
347
348
350 unsigned operandDimPos;
351 if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
352 operandDimPos)))
353 return failure();
354
355 Value dynamicDim = linalgOp.hasPureTensorSemantics()
357 linalgOp.getLoc(), operand, operandDimPos)
359 linalgOp.getLoc(), operand, operandDimPos);
360 iterSpaceValueSizes.push_back(dynamicDim);
361 }
362
363 return success();
364 }
365
366
367
368
369 LogicalResult
373
375
376 if (!inputVectorSizes.empty()) {
377
378
379
380 canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
381 scalableVecDims.append(inputScalableVecDims.begin(),
382 inputScalableVecDims.end());
383 } else {
384
385
386
387 canonicalVecShape = linalgOp.getStaticLoopRanges();
388 scalableVecDims.append(linalgOp.getNumLoops(), false);
389 }
390
391 LDBG("Canonical vector shape: ");
392 LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
393 LLVM_DEBUG(llvm::dbgs() << "\n");
394 LDBG("Scalable vector dims: ");
395 LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
396 LLVM_DEBUG(llvm::dbgs() << "\n");
397
398 if (ShapedType::isDynamicShape(canonicalVecShape))
399 return failure();
400
401
402 initIterSpaceStaticSizes(linalgOp);
403
404
405
406
407 if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
408 return failure();
409
410 return success();
411 }
412
413
414
415
416
417 Value VectorizationState::getOrCreateMaskFor(
419 std::optional maybeMaskingMap) {
420
421 assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
422 "Ill-formed masking map.");
423
424
425 auto maskableOp = dyn_castvector::MaskableOpInterface(opToMask);
426 if (!maskableOp)
428
429 assert(!maskableOp.isMasked() &&
430 "Masking an operation that is already masked");
431
432
433 assert((!maybeMaskingMap || *maybeMaskingMap) &&
434 "Unexpected null mask permutation map");
436 maybeMaskingMap ? *maybeMaskingMap
438 linalgOp.getNumLoops(), rewriter.getContext());
439
440 LDBG("Masking map: " << maskingMap << "\n");
441
442
443
444 auto activeMaskIt = activeMaskCache.find(maskingMap);
445 if (activeMaskIt != activeMaskCache.end()) {
446 Value mask = activeMaskIt->second;
447 LDBG("Reusing mask: " << mask << "\n");
448 return mask;
449 }
450
451
452
453
454
455
456
458 applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
459 auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
460 auto maskShape = maskType.getShape();
461
462 LDBG("Mask shape: ");
463 LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
464 LLVM_DEBUG(llvm::dbgs() << "\n");
465
466 if (permutedStaticSizes == maskShape) {
467 LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
468 activeMaskCache[maskingMap] = Value();
470 }
471
472
475 assert(!maskShape.empty() && !upperBounds.empty() &&
476 "Masked 0-d vectors are not supported yet");
477
478
479 Value mask = rewriter.createvector::CreateMaskOp(linalgOp.getLoc(),
480 maskType, upperBounds);
481 LDBG("Creating new mask: " << mask << "\n");
482 activeMaskCache[maskingMap] = mask;
483 return mask;
484 }
485
488 LinalgOp linalgOp,
489 std::optional maybeIndexingMap) {
490 LDBG("Trying to mask: " << *opToMask << "\n");
491
492 std::optional maybeMaskingMap = std::nullopt;
493 if (maybeIndexingMap)
494 maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
495
496
498 getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
499
500 if (!mask) {
501 LDBG("No mask required\n");
502 return opToMask;
503 }
504
505
506 assert(opToMask && "Expected a valid operation to mask");
507 auto maskOp = castvector::MaskOp(
509 Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
510
513 maskOpTerminator);
514
515 LDBG("Masked operation: " << *maskOp << "\n");
516 return maskOp;
517 }
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
538 "expected projected permutation");
540 assert(res.getNumDims() ==
541 (res.getNumResults() - res.getNumOfZeroResults()) &&
542 "expected reindexed map with same number of dims and results");
543 return res;
544 }
545
546
548 W,
549 Ncw,
550 Nwc
551 };
552
553
554
556
558
560
561
563
564
567
569
570
572 };
573
574 std::optionalvector::CombiningKind
576 using ::mlir::vector::CombiningKind;
577
578 if (!combinerOp)
579 return std::nullopt;
581 .Case<arith::AddIOp, arith::AddFOp>(
582 [&](auto op) { return CombiningKind::ADD; })
583 .Casearith::AndIOp([&](auto op) { return CombiningKind::AND; })
584 .Casearith::MaxSIOp([&](auto op) { return CombiningKind::MAXSI; })
585 .Casearith::MaxUIOp([&](auto op) { return CombiningKind::MAXUI; })
586 .Casearith::MaximumFOp([&](auto op) { return CombiningKind::MAXIMUMF; })
587 .Casearith::MaxNumFOp([&](auto op) { return CombiningKind::MAXNUMF; })
588 .Casearith::MinSIOp([&](auto op) { return CombiningKind::MINSI; })
590 .Casearith::MinimumFOp([&](auto op) { return CombiningKind::MINIMUMF; })
591 .Casearith::MinNumFOp([&](auto op) { return CombiningKind::MINNUMF; })
592 .Case<arith::MulIOp, arith::MulFOp>(
593 [&](auto op) { return CombiningKind::MUL; })
594 .Casearith::OrIOp([&](auto op) { return CombiningKind::OR; })
595 .Casearith::XOrIOp([&](auto op) { return CombiningKind::XOR; })
596 .Default([&](auto op) { return std::nullopt; });
597 }
598
599
600
601
602
603
604
605
607 auto linalgOp = cast(outputOperand->getOwner());
608 unsigned outputPos =
609 outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
610
612 if ((linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
613 combinerOps.size() != 1)
614 return nullptr;
615
616
617 return combinerOps[0];
618 }
619
620
621
623 auto dstVecType = dyn_cast(dstType);
624
625 if (dstVecType.getRank() == 0)
626 return value;
629 return value;
631 return b.createOrFoldvector::BroadcastOp(loc, dstVecType, value);
632 }
633
634
635
636
637
638
643 assert(maybeKind && "Failed precondition: could not get reduction kind");
644 return b.createvector::MultiDimReductionOp(
645 reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
646 }
647
649 return llvm::to_vector(
651 }
652
653
654
656 return isalinalg::ReduceOp(op) ||
657 (isalinalg::GenericOp(op) &&
659 }
660
661
662
663
664
665
666
671 auto linalgOp = cast(outputOperand->getOwner());
672 AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
673
674
675
676
677
681 return llvm::is_contained(opOperandMap.getResults(), dimExpr);
682 });
683 auto vectorType = state.getCanonicalVecType(
685
687 if (vectorType.getRank() > 0) {
690 rewriter.createarith::ConstantIndexOp(loc, 0));
692 assert(value.getType() == vectorType && "Incorrect type");
693 write = rewriter.createvector::TransferWriteOp(
694 loc, value, outputOperand->get(), indices, writeMap);
695 } else {
696
697 if (!isa(value.getType()))
698 value = rewriter.createvector::BroadcastOp(loc, vectorType, value);
699 assert(value.getType() == vectorType && "Incorrect type");
700 write = rewriter.createvector::TransferWriteOp(
702 }
703
704 write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
705
706
707
708 if (auto maskOp = dyn_castvector::MaskingOpInterface(write)) {
709 auto maskedWriteOp = castvector::TransferWriteOp(maskOp.getMaskableOp());
710 SmallVector inBounds(maskedWriteOp.getVectorType().getRank(), true);
711 maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
712 }
713
714 LDBG("vectorized op: " << *write << "\n");
718 }
719
720
721
722
724 std::function<LogicalResult(Operation *, bool)>;
725
726
727
728
731
732
733
734
735
736
737
738
743 auto yieldOp = dyn_castlinalg::YieldOp(op);
744 if (!yieldOp)
746 for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
747
748
749 Value vectorValue = bvm.lookup(output.value());
750 Value newResult =
752 linalgOp.getDpsInitOperand(output.index()), state);
753 if (newResult)
754 newResults.push_back(newResult);
755 }
756
758 }
759
760
761
762
763
767 LinalgOp linalgOp) {
768 IndexOp indexOp = dyn_castlinalg::IndexOp(op);
769 if (!indexOp)
771 auto loc = indexOp.getLoc();
772
774 auto dim = indexOp.getDim();
775
776 auto indexVectorType =
778 state.getScalableVecDims()[dim]);
779 auto indexSteps = rewriter.createvector::StepOp(loc, indexVectorType);
780
781
782
783 if (dim == targetShape.size() - 1)
785
786
787
788 auto permPattern =
789 llvm::to_vector(llvm::seq(0, targetShape.size()));
790 std::swap(permPattern[dim], permPattern.back());
791 auto permMap =
793
794 auto broadCastOp = rewriter.createvector::BroadcastOp(
795 loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
796 indexSteps);
798 llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
799 std::swap(transposition.back(), transposition[dim]);
800 auto transposeOp =
801 rewriter.createvector::TransposeOp(loc, broadCastOp, transposition);
803 }
804
805
806
807 static LogicalResult
809 tensor::ExtractOp extractOp = dyn_casttensor::ExtractOp(op);
810 if (!extractOp)
811 return failure();
812
813 if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
814 return failure();
815
816
817
818 if (not extractOp.getIndices().empty()) {
819 if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
820 return failure();
821 }
822
823 if (!llvm::all_of(extractOp->getResultTypes(),
824 VectorType::isValidElementType)) {
825 return failure();
826 }
827
828 return success();
829 }
830
831
832
833
834
835
836
837
838
839
840
843 tensor::ExtractOp extractOp,
845
846 auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
847 auto loc = extractOp.getLoc();
848
850 rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
851
852 const size_t numIndices = extractOp.getIndices().size();
853 for (size_t i = 1; i < numIndices; i++) {
854 Value dimIdx = rewriter.createarith::ConstantIndexOp(loc, i);
855
857 rewriter,
858 rewriter.createtensor::DimOp(loc, extractOp.getTensor(), dimIdx),
859 indexVecType);
860
861 offset = rewriter.createarith::MulIOp(loc, offset, dimSize);
862
864 rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
865
866 offset = rewriter.createarith::AddIOp(loc, extractOpIndex, offset);
867 }
868
869 return offset;
870 }
871
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
891 assert(
892 (linalgOp.hasDynamicShape() ||
893 llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
894 "For statically shaped Linalg Ops, only one "
895 "non-unit loop dim is expected");
896 assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");
897
898 size_t idx = loopRanges.size() - 1;
899 for (; idx != 0; idx--)
900 if (loopRanges[idx] != 1)
901 break;
902
903 return idx;
904 }
905
906
908 VectorType resType) {
909
910 assert(((llvm::count_if(resType.getShape(),
911 [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
912 "n-D vectors are not yet supported");
913
914
915
916
917
918 auto *block = linalgOp.getBlock();
919 if (isa(val))
920 return llvm::all_of(block->getArguments(),
921 [&val](Value v) { return (v != val); });
922
924 assert(defOp && "This is neither a block argument nor an operation result");
925
926
927
928
929 if (auto indexOp = dyn_castlinalg::IndexOp(defOp)) {
930 return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
931 }
932
933 auto *ancestor = block->findAncestorOpInBlock(*defOp);
934
935
936 if (!ancestor)
937 return true;
938
939
940 if (isaarith::ConstantOp(ancestor))
941 return true;
942
943 bool result = true;
944 for (auto op : ancestor->getOperands())
946
947 return result;
948 }
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
968 bool &foundIndexOp, VectorType resType) {
969
970 assert(((llvm::count_if(resType.getShape(),
971 [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
972 "n-D vectors are not yet supported");
973
974
975
976
977
978 auto *block = linalgOp.getBlock();
979 if (isa(val))
980 return llvm::all_of(block->getArguments(),
981 [&val](Value v) { return (v != val); });
982
984 assert(defOp && "This is neither a block argument nor an operation result");
985
986 if (auto indexOp = dyn_castlinalg::IndexOp(defOp)) {
988
989 foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
990 return true;
991 }
992
993 auto *ancestor = block->findAncestorOpInBlock(*defOp);
994
995 if (!ancestor)
996 return false;
997
998
999
1000 if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
1001 return false;
1002
1003 bool result = false;
1004 for (auto op : ancestor->getOperands())
1006
1007 return result;
1008 }
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1024 LinalgOp &linalgOp, VectorType resType) {
1025
1026 auto inputShape = cast(extractOp.getTensor().getType());
1027
1028
1029 if (inputShape.getShape().empty())
1031
1032
1033
1034 bool isOutput1DVector =
1035 (llvm::count_if(resType.getShape(),
1036 [](int64_t dimSize) { return dimSize > 1; }) == 1);
1037
1038 if (!isOutput1DVector)
1040
1041 bool leadingIdxsLoopInvariant = true;
1042
1043
1044
1045
1046
1047 auto indices = extractOp.getIndices();
1048 auto leadIndices = indices.drop_back(1);
1049
1050 for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
1051 if (inputShape.getShape()[i] == 1)
1052 continue;
1053
1054 leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
1055 }
1056
1057 if (!leadingIdxsLoopInvariant) {
1058 LDBG("Found gather load: " << extractOp);
1060 }
1061
1062
1063
1064
1065
1066 auto extractOpTrailingIdx = indices.back();
1067
1068
1069
1070 if (leadingIdxsLoopInvariant &&
1072 LDBG("Found scalar broadcast load: " << extractOp);
1073
1075 }
1076
1077
1078
1079
1080
1081 bool foundIndexOp = false;
1082 bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
1083 foundIndexOp, resType);
1084
1085
1086 bool isRowVector = resType.getShape().back() != 1;
1087 isContiguousLoad &= (foundIndexOp && isRowVector);
1088
1089 if (isContiguousLoad) {
1090 LDBG("Found contigous load: " << extractOp);
1092 }
1093
1094
1095 LDBG("Found gather load: " << extractOp);
1097 }
1098
1099
1100
1101
1102
1106 tensor::ExtractOp extractOp = dyn_casttensor::ExtractOp(op);
1107 if (!extractOp)
1109 auto loc = extractOp.getLoc();
1110
1111
1112 auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1113 auto maskConstantOp = rewriter.createarith::ConstantOp(
1114 loc,
1116 true));
1117 auto passThruConstantOp =
1118 rewriter.createarith::ConstantOp(loc, rewriter.getZeroAttr(resultType));
1119
1120
1121
1123 extractOp.getIndices().size(),
1124 rewriter.createarith::ConstantIndexOp(loc, 0));
1125
1128
1129
1132
1133
1134 Operation *gatherOp = rewriter.createvector::GatherOp(
1135 loc, resultType, extractOp.getTensor(), baseIndices, offset,
1136 maskConstantOp, passThruConstantOp);
1137 gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1138
1139 LDBG("Vectorised as gather load: " << extractOp << "\n");
1141 }
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1162 for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1163 Value idx = bvm.lookup(extractOp.getIndices()[i]);
1165 transferReadIdxs.push_back(idx);
1166 continue;
1167 }
1168
1169 auto indexAs1dVector = rewriter.createvector::ShapeCastOp(
1170 loc,
1172 resultType.getScalableDims().back()),
1173 idx);
1174 transferReadIdxs.push_back(
1175 rewriter.createvector::ExtractOp(loc, indexAs1dVector, 0));
1176 }
1177
1178
1179 auto dstRank = resultType.getRank();
1180 auto srcRank = extractOp.getTensor().getType().getRank();
1182
1183
1187 auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1188
1189 auto transferReadOp = rewriter.createvector::TransferReadOp(
1190 loc, resultType, extractOp.getTensor(), transferReadIdxs,
1191 permutationMap, inBounds);
1192
1193
1194
1195
1198 auto allTrue = rewriter.createvector::ConstantMaskOp(
1200 auto *maskedReadOp =
1202
1203 LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
1205 }
1206
1207
1210
1211 int32_t rankDiff = dstRank - srcRank;
1212
1213
1214
1215
1216
1217
1218
1219 while (rankDiff > 0) {
1220 permutationMap = permutationMap.insertResult(
1222 rankDiff--;
1223 }
1224
1225 auto transferReadOp = rewriter.createvector::TransferReadOp(
1226 loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
1227 inBounds);
1228
1229 LDBG("Vectorised as contiguous load: " << extractOp);
1231 }
1232
1233
1234
1235
1236
1238 Value reduceValue, Value initialValue,
1240 Value reduceVec = bvm.lookup(reduceValue);
1241 Value outputVec = bvm.lookup(initialValue);
1242 auto reduceType = dyn_cast(reduceVec.getType());
1243 auto outputType = dyn_cast(outputVec.getType());
1244
1245
1246 if (!reduceType ||
1247 (outputType && reduceType.getShape() == outputType.getShape()))
1248 return nullptr;
1251 }
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1276 LDBG("vectorize op " << *op << "\n");
1277
1278
1279 if (!customVectorizationHooks.empty()) {
1280 for (auto &customFunc : customVectorizationHooks) {
1283 continue;
1284 return result;
1285 }
1286 }
1287
1288
1289
1290 if (isa<arith::ConstantOp, func::ConstantOp>(op))
1292
1293
1296
1297
1300 auto blockArg = dyn_cast(operand);
1301 if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1302 blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1303 continue;
1306 linalgOp.getRegionOutputArgs(),
1307 blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1308 if (!reduceValue)
1309 continue;
1310 reductionOperands.push_back(std::make_pair(reduceValue, operand));
1311 }
1312 if (!reductionOperands.empty()) {
1313 assert(reductionOperands.size() == 1);
1315 reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1316 reductionOperands[0].second, bvm);
1317 if (reduceOp)
1319 }
1320
1321
1322
1323 VectorType firstMaxRankedType;
1325 auto vecOperand = bvm.lookup(operand);
1326 assert(vecOperand && "Vector operand couldn't be found");
1327
1328 auto vecType = dyn_cast(vecOperand.getType());
1329 if (vecType && (!firstMaxRankedType ||
1330 firstMaxRankedType.getRank() < vecType.getRank()))
1331 firstMaxRankedType = vecType;
1332 }
1333
1336 Value vecOperand = bvm.lookup(scalarOperand);
1337 assert(vecOperand && "Vector operand couldn't be found");
1338
1339 if (firstMaxRankedType) {
1340 auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1342 firstMaxRankedType.getScalableDims());
1343 vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1344 } else {
1345 vecOperands.push_back(vecOperand);
1346 }
1347 }
1348
1351 resultTypes.push_back(
1352 firstMaxRankedType
1353 ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1354 firstMaxRankedType.getScalableDims())
1355 : resultType);
1356 }
1357
1361 resultTypes, op->getAttrs())};
1362 }
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386 static LogicalResult
1388 LinalgOp linalgOp,
1390 LDBG("Vectorizing operation as linalg generic\n");
1391 Block *block = linalgOp.getBlock();
1392
1393
1394
1398 bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1399
1400 if (linalgOp.getNumDpsInits() == 0)
1401 return failure();
1402
1403
1404 Location loc = linalgOp.getLoc();
1405 Value zero = rewriter.createarith::ConstantIndexOp(loc, 0);
1406 for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1407 BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1408 if (linalgOp.isScalar(opOperand)) {
1409 bvm.map(bbarg, opOperand->get());
1410 continue;
1411 }
1412
1413
1414
1415 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1416
1418 VectorType readType;
1420 if (linalgOp.isDpsInput(opOperand)) {
1421
1423 readType = state.getCanonicalVecType(elemType);
1424 } else {
1425
1426
1427
1429 readType =
1430 state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1431 }
1432
1433 SmallVector indices(linalgOp.getShape(opOperand).size(), zero);
1434
1435 Operation *read = rewriter.createvector::TransferReadOp(
1436 loc, readType, opOperand->get(), indices, readMap);
1437 read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1439
1440
1441
1442 if (auto maskOp = dyn_castvector::MaskingOpInterface(read)) {
1444 castvector::TransferReadOp(maskOp.getMaskableOp())
1446 }
1447
1448
1449
1450 if (readType.getRank() == 0)
1453
1455 << "\n");
1458 }
1459
1461
1464 return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1465 };
1466 hooks.push_back(vectorizeYield);
1467
1468
1472 };
1473 hooks.push_back(vectorizeIndex);
1474
1475
1479 };
1480 hooks.push_back(vectorizeExtract);
1481
1482
1485 vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1487 LDBG("failed to vectorize: " << op << "\n");
1488 return failure();
1489 }
1492 state.maskOperation(rewriter, result.newOp, linalgOp);
1493 LDBG("New vector op: " << *maybeMaskedOp << "\n");
1495 }
1496 }
1497
1498 return success();
1499 }
1500
1501
1502
1506 }
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1558
1559 if (ShapedType::isDynamicShape(destShape))
1560 return false;
1561
1562
1566 cstMaskSizes.push_back(*intSize);
1567 }
1568 }
1569
1570
1571 if (cstMaskSizes.size() != maskShape.size())
1572 return false;
1573
1574
1577 APSInt intVal;
1579 cstWriteIdxs.push_back(intVal.getSExtValue());
1580 }
1581 }
1582
1583
1584 if (cstWriteIdxs.size() != destShape.size())
1585 return false;
1586
1587
1588
1589
1590
1591
1592
1593 int64_t rankDiff = destShape.size() - cstMaskSizes.size();
1595 if ( maskShape[i] > destShape[rankDiff + i] ||
1596 destShape[rankDiff + i] <
1597 (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
1598 cstWriteIdxs[i]))
1599 return false;
1600 }
1601
1602 return true;
1603 }
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1633 bool useInBoundsInsteadOfMasking = false) {
1634
1635 ShapedType destType = cast(dest.getType());
1636 int64_t destRank = destType.getRank();
1637 auto destShape = destType.getShape();
1638
1639 VectorType vecToStoreType = cast(vecToStore.getType());
1640 int64_t vecToStoreRank = vecToStoreType.getRank();
1641 auto vecToStoreShape = vecToStoreType.getShape();
1642
1643
1645 if (useInBoundsInsteadOfMasking) {
1646
1647
1648 for (unsigned i = 0; i < vecToStoreRank; i++)
1649 inBoundsVal[i] =
1650 (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
1651 !ShapedType::isDynamic(destShape[destRank - vecToStoreRank + i]);
1652 }
1653
1654
1655 assert(writeIndices.empty() ||
1656 writeIndices.size() == static_cast<size_t>(destRank) &&
1657 "Invalid number of write indices!");
1658 if (writeIndices.empty()) {
1659 auto zero = builder.createarith::ConstantIndexOp(loc, 0);
1660 writeIndices.assign(destRank, zero);
1661 }
1662
1663
1665 builder.createvector::TransferWriteOp(loc,
1666 vecToStore,
1667 dest,
1668 writeIndices,
1669 inBoundsVal);
1670
1671
1672 if (useInBoundsInsteadOfMasking)
1673 return write;
1674
1675
1676 if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
1677 return write;
1678
1679
1681
1685 destSizes.end());
1686
1688 vecToStoreShape))
1689 return write;
1690
1691 Value maskForWrite =
1692 builder.createOrFoldvector::CreateMaskOp(loc, writeMaskType, maskSizes);
1694 }
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730 static LogicalResult
1734
1737
1738 Location loc = packOp.getLoc();
1739 auto padValue = packOp.getPaddingValue();
1740 if (!padValue) {
1741 padValue = rewriter.createarith::ConstantOp(
1742 loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1743 }
1745 LogicalResult status =
1746 cast(packOp.getOperation())
1747 .reifyResultShapes(rewriter, reifiedReturnShapes);
1748 (void)status;
1749 assert(succeeded(status) && "failed to reify result shapes");
1750
1751
1752
1753
1754 bool useInBoundsInsteadOfMasking = false;
1755 if (inputVectorSizes.empty()) {
1756 ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1757 inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1758 useInBoundsInsteadOfMasking = true;
1759 }
1760
1761
1763 auto innerTiles = packOp.getStaticInnerTiles();
1764 auto innerDimsPos = packOp.getInnerDimsPos();
1772 rewriter, loc, packOp.getSource(), inputShape, padValue,
1773 useInBoundsInsteadOfMasking);
1774
1775
1779 packOp.getDestType().getElementType());
1780 auto shapeCastOp =
1781 rewriter.createvector::ShapeCastOp(loc, tiledPackType, maskedRead);
1782
1783
1784 auto destPermutation =
1786 auto transposeOp = rewriter.createvector::TransposeOp(
1787 loc, shapeCastOp.getResult(), destPermutation);
1788
1789
1790 Value dest = rewriter.createtensor::EmptyOp(
1791 loc, reifiedReturnShapes[0],
1792 transposeOp.getResult().getType().getElementType());
1795 newResults.push_back(write->getResult(0));
1796 return success();
1797 }
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808 static LogicalResult
1812
1813
1816
1817 RankedTensorType unpackTensorType = unpackOp.getSourceType();
1818
1822 bool useInBoundsInsteadOfMasking = false;
1824
1825 auto destSize = unpackOp.getDestRank();
1826
1827 if (!inputVectorSizes.empty())
1828 assert(inputVectorSizes.size() == destSize &&
1829 "Incorrect number of input vector sizes");
1830
1831
1832
1833
1834
1835
1836
1837
1838
1840 if (vectorSizes.empty()) {
1841 llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
1846
1847 useInBoundsInsteadOfMasking = true;
1848 }
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869 SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
1870
1872 readVectorSizes[innerDimPos[index]] =
1873 llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1874 }
1877 }
1878 readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
1879 sourceShape.end());
1880
1882 LogicalResult status =
1883 cast(unpackOp.getOperation())
1884 .reifyResultShapes(rewriter, reifiedRetShapes);
1885 if (status.failed()) {
1886 LDBG("Unable to reify result shapes of " << unpackOp);
1887 return failure();
1888 }
1889 Location loc = unpackOp->getLoc();
1890
1891 auto padValue = rewriter.createarith::ConstantOp(
1892 loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1893
1894
1895
1897 rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1898 false);
1899
1900 PackingMetadata packMetadata;
1903 ShapedType maskedOpShapedType = cast(readResult.getType());
1905 mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1907 RankedTensorType stripMineTensorType =
1909
1910 vector::TransposeOp transposeOp = rewriter.createvector::TransposeOp(
1911 loc, readResult, lastDimToInsertPosPerm);
1912
1913
1914 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1915 stripMineTensorType, packMetadata.reassociations);
1916 mlir::VectorType vecCollapsedType =
1917 VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1918 vector::ShapeCastOp shapeCastOp = rewriter.createvector::ShapeCastOp(
1919 loc, vecCollapsedType, transposeOp->getResult(0));
1920
1921
1922
1924 unpackOp.getDestType().hasStaticShape()
1925 ? vectorSizes
1926 : shapeCastOp.getResultVectorType().getShape());
1927 Value dest = rewriter.createtensor::EmptyOp(
1928 loc, reifiedRetShapes[0],
1929 shapeCastOp.getResult().getType().getElementType());
1931 rewriter, loc, shapeCastOp.getResult(), dest,
1932 {}, useInBoundsInsteadOfMasking);
1933 newResults.push_back(write->getResult(0));
1934 return success();
1935 }
1936
1937
1938
1939
1940 static LogicalResult
1944 auto padValue = padOp.getConstantPaddingValue();
1945 Location loc = padOp.getLoc();
1946
1947
1950
1952 LogicalResult status =
1953 cast(padOp.getOperation())
1954 .reifyResultShapes(rewriter, reifiedReturnShapes);
1955 (void)status;
1956 assert(succeeded(status) && "failed to reify result shapes");
1958 rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
1959 false);
1960
1961
1962 Value dest = rewriter.createtensor::EmptyOp(
1963 loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
1965 newResults.push_back(write->getResult(0));
1966 return success();
1967 }
1968
1969
1970
1973 LDBG("reduction precondition failed: no reduction iterator\n");
1974 return failure();
1975 }
1976 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1977 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1979 continue;
1980
1983 LDBG("reduction precondition failed: reduction detection failed\n");
1984 return failure();
1985 }
1986 }
1987 return success();
1988 }
1989
1990 static LogicalResult
1992 bool flatten1DDepthwiseConv) {
1993 if (flatten1DDepthwiseConv) {
1994 LDBG("Vectorization of flattened convs with dynamic shapes is not "
1995 "supported\n");
1996 return failure();
1997 }
1998
1999 if (!isalinalg::DepthwiseConv1DNwcWcOp(conv)) {
2000 LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
2001 return failure();
2002 }
2003
2004
2005
2006 Value lhs = conv.getDpsInputOperand(0)->get();
2008 auto shapeWithoutCh = lhsShape.drop_back(1);
2009 if (ShapedType::isDynamicShape(shapeWithoutCh)) {
2010 LDBG("Dynamically-shaped op vectorization precondition failed: only "
2011 "channel dim can be dynamic\n");
2012 return failure();
2013 }
2014
2015 return success();
2016 }
2017
2018 static LogicalResult
2020 bool flatten1DDepthwiseConv) {
2021 if (isa(op.getOperation()))
2023
2026
2027
2028
2030 !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
2031 op.getOperation()))
2032 return failure();
2033
2034 LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
2035 return success();
2036 }
2037
2038
2039 static LogicalResult
2042
2043 if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
2044 return !getConstantIntValue(res).has_value();
2045 })) {
2046 LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
2047 return failure();
2048 }
2049 ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
2050 bool satisfyEmptyCond = inputVectorSizes.empty() &&
2051 unpackOp.getDestType().hasStaticShape() &&
2052 unpackOp.getSourceType().hasStaticShape();
2053 if (!satisfyEmptyCond &&
2055 return failure();
2056
2057 return success();
2058 }
2059
2060 static LogicalResult
2063
2065 auto sourceType = source.getType();
2066 if (!VectorType::isValidElementType(sourceType.getElementType()))
2067 return failure();
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2082 bool isOutOfBoundsRead =
2083 !sourceType.hasStaticShape() && inputVectorSizes.empty();
2084
2085 if (!padValue && isOutOfBoundsRead) {
2086 LDBG("Failed to get a pad value for out-of-bounds read access\n");
2087 return failure();
2088 }
2089 return success();
2090 }
2091
2092 namespace {
2093 enum class ConvOperationKind { Conv, Pool };
2094 }
2095
2097 return isa(op) && op->getNumOperands() == 1 &&
2098 isa(op->getOperand(0));
2099 }
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111 static std::optional
2113 int numBlockArguments =
2114 llvm::count_if(reduceOp->getOperands(), llvm::IsaPred);
2115
2116 switch (numBlockArguments) {
2117 case 1: {
2118
2119
2120
2121
2122 auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
2123 llvm::IsaPred);
2124 assert(feedValIt != reduceOp->operand_end() &&
2125 "Expected a non-block argument operand");
2126 Operation *feedOp = (*feedValIt).getDefiningOp();
2128 return ConvOperationKind::Pool;
2129 }
2130
2131 if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
2132 (isaarith::AndIOp(feedOp) &&
2135 if (isa(v))
2136 return true;
2137 if (Operation *op = v.getDefiningOp())
2138 return isCastOfBlockArgument(op);
2139 return false;
2140 }))) {
2141 return std::nullopt;
2142 }
2143
2144 return ConvOperationKind::Conv;
2145 }
2146 case 2:
2147
2148 return ConvOperationKind::Pool;
2149 default:
2150 return std::nullopt;
2151 }
2152 }
2153
2155 switch (kind) {
2156 case vector::CombiningKind::ADD:
2157 case vector::CombiningKind::MAXNUMF:
2158 case vector::CombiningKind::MAXIMUMF:
2159 case vector::CombiningKind::MAXSI:
2160 case vector::CombiningKind::MAXUI:
2161 case vector::CombiningKind::MINNUMF:
2162 case vector::CombiningKind::MINIMUMF:
2163 case vector::CombiningKind::MINSI:
2165 return true;
2166 default:
2167 return false;
2168 }
2169 }
2170
2172 auto getOperandType = [&](auto operand) {
2173 return dyn_cast((operand->get()).getType());
2174 };
2175 ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
2176 ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
2177 ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
2178
2179
2180
2181 if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2182 (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2183 return failure();
2184
2186 if (!reduceOp)
2187 return failure();
2188
2190 if (!maybeOper.has_value())
2191 return failure();
2192
2194
2195
2196
2197 if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
2198 *maybeKind != vector::CombiningKind::OR) &&
2199 (*maybeOper != ConvOperationKind::Pool ||
2201 return failure();
2202 }
2203
2204 auto rhsRank = rhsShapedType.getRank();
2205 if (*maybeOper == ConvOperationKind::Pool) {
2206 if (rhsRank != 1)
2207 return failure();
2208 } else {
2209 if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2210 return failure();
2211 }
2212
2213 return success();
2214 }
2215
2218 bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
2219
2220 if (llvm::is_contained(linalgOp.getStaticShape(), 0))
2221 return failure();
2222
2223 if (!inputVectorSizes.empty() &&
2225 inputVectorSizes)))
2226 return failure();
2227
2229 linalgOp, flatten1DDepthwiseConv))) {
2230 LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
2231 return failure();
2232 }
2233
2235
2236
2238
2239
2241
2242 if (llvm::any_of(
2243 customPreconditions,
2245 return succeeded(
2246 customPrecondition(&innerOp, vectorizeNDExtract));
2247 })) {
2248 continue;
2249 }
2250 if (!llvm::all_of(innerOp.getOperandTypes(),
2251 VectorType::isValidElementType)) {
2252 return failure();
2253 }
2254 if (!llvm::all_of(innerOp.getResultTypes(),
2255 VectorType::isValidElementType)) {
2256 return failure();
2257 }
2258 }
2260 return success();
2261
2262
2263
2264
2265 if (isa(linalgOp.getOperation()))
2267
2268
2269
2270
2272 LDBG("precondition failed: not projected permutations\n");
2273 return failure();
2274 }
2276 LDBG("precondition failed: reduction preconditions\n");
2277 return failure();
2278 }
2279 return success();
2280 }
2281
2282 static LogicalResult
2285 auto padValue = packOp.getPaddingValue();
2288 LDBG("pad value is not constant: " << packOp << "\n");
2289 return failure();
2290 }
2291 ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
2292 bool satisfyEmptyCond = true;
2293 if (inputVectorSizes.empty()) {
2294 if (!packOp.getDestType().hasStaticShape() ||
2295 !packOp.getSourceType().hasStaticShape())
2296 satisfyEmptyCond = false;
2297 }
2298
2299 if (!satisfyEmptyCond &&
2301 resultTensorShape.take_front(packOp.getSourceRank()),
2302 inputVectorSizes)))
2303 return failure();
2304
2305 if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
2306 return !getConstantIntValue(v).has_value();
2307 })) {
2308 LDBG("inner_tiles must be constant: " << packOp << "\n");
2309 return failure();
2310 }
2311
2312 return success();
2313 }
2314
2315 static LogicalResult
2318 auto padValue = padOp.getConstantPaddingValue();
2319 if (!padValue) {
2320 LDBG("pad value is not constant: " << padOp << "\n");
2321 return failure();
2322 }
2323
2324 ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
2326 inputVectorSizes)))
2327 return failure();
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340 if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) {
2341 Value padValue = en.value();
2342 unsigned pos = en.index();
2343 std::optional<int64_t> pad = getConstantIntValue(padValue);
2344 return (!pad.has_value() || pad.value() != 0) &&
2345 resultTensorShape[pos] != 1;
2346 })) {
2347 LDBG("low pad must all be zero for all non unit dims: " << padOp << "\n");
2348 return failure();
2349 }
2350
2351 return success();
2352 }
2353
2354
2355
2356 static LogicalResult
2360 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2361 "Number of input vector sizes and scalable dims doesn't match");
2362
2363 size_t numOfScalableDims =
2364 llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
2365
2366 if (numOfScalableDims == 0)
2367 return success();
2368
2369 auto linalgOp = dyn_cast(op);
2370
2371
2372
2373 if (!linalgOp)
2374 return failure();
2375
2376
2377 if (numOfScalableDims > 2)
2378 return failure();
2379
2380
2381
2382
2383
2384
2385
2386
2387
2388
2389
2390
2391
2392
2393
2394
2395
2396
2397 bool seenNonUnitParallel = false;
2398 auto iterators = linalgOp.getIteratorTypesArray();
2400 int64_t idx = scalableFlags.size() - 1;
2401 while (!scalableFlags[idx]) {
2402 bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2403 seenNonUnitParallel |=
2404 (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
2405
2406 iterators.pop_back();
2407 scalableFlags.pop_back();
2408 --idx;
2409 }
2410
2411
2412 switch (iterators.back()) {
2413 case utils::IteratorType::reduction: {
2414
2415 if (iterators.size() != inputVectorSizes.size()) {
2416 LDBG("Non-trailing reduction dim requested for scalable "
2417 "vectorization\n");
2418 return failure();
2419 }
2420 if (isalinalg::MatmulOp(op) || isalinalg::MatmulTransposeAOp(op)) {
2421 LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
2422 "is not supported\n");
2423 return failure();
2424 }
2425 break;
2426 }
2427 case utils::IteratorType::parallel: {
2428
2429 if (seenNonUnitParallel) {
2430 LDBG("Inner parallel dim not requested for scalable "
2431 "vectorization\n");
2432 return failure();
2433 }
2434 break;
2435 }
2436 }
2437
2438
2439
2440
2441
2442 if (numOfScalableDims == 2) {
2443
2444
2445
2446 if (iterators.back() == utils::IteratorType::reduction) {
2447 LDBG("Higher dim than the trailing reduction dim requested for scalable "
2448 "vectorization\n");
2449 return failure();
2450 }
2451 scalableFlags.pop_back();
2452 iterators.pop_back();
2453
2454 if (!scalableFlags.back() ||
2455 (iterators.back() != utils::IteratorType::parallel))
2456 return failure();
2457 }
2458
2459
2460
2461 if (linalgOp.hasUserDefinedMaps())
2462 return failure();
2463
2464
2465
2466 return success(isElementwise(linalgOp) || isalinalg::MatmulOp(op) ||
2467 isalinalg::MatmulTransposeAOp(op) ||
2468 isalinalg::DepthwiseConv1DNwcWcOp(op) ||
2470 }
2471
2474 ArrayRef inputScalableVecDims, bool vectorizeNDExtract,
2475 bool flatten1DDepthwiseConv) {
2476
2478 return failure();
2479
2481 inputScalableVecDims)))
2482 return failure();
2483
2485 .Caselinalg::LinalgOp([&](auto linalgOp) {
2487 vectorizeNDExtract,
2488 flatten1DDepthwiseConv);
2489 })
2490 .Casetensor::PadOp([&](auto padOp) {
2492 })
2493 .Caselinalg::PackOp([&](auto packOp) {
2495 })
2496 .Caselinalg::UnPackOp([&](auto unpackOp) {
2498 })
2499 .Casetensor::InsertSliceOp([&](auto sliceOp) {
2501 })
2502 .Default([](auto) { return failure(); });
2503 }
2504
2505
2508 auto toReplace = linalgOp.getBlock()->getOpsaffine::AffineApplyOp();
2509
2510 for (auto op : make_early_inc_range(toReplace)) {
2513 rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2514 op.getOperands().take_front(op.getAffineMap().getNumDims()),
2515 op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2516 rewriter.replaceOp(op, expanded);
2517 }
2518 }
2519
2521 return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
2522 tensor::InsertSliceOp>(op);
2523 }
2524
2525
2526
2527
2528
2529
2530
2534 bool vectorizeNDExtract,
2535 bool flatten1DDepthwiseConv) {
2536 LDBG("Attempting to vectorize:\n" << *op << "\n");
2537 LDBG("Input vector sizes: ");
2538 LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
2539 LLVM_DEBUG(llvm::dbgs() << "\n");
2540 LDBG("Input scalable vector dims: ");
2541 LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
2542 LLVM_DEBUG(llvm::dbgs() << "\n");
2543
2545 vectorizeNDExtract,
2546 flatten1DDepthwiseConv))) {
2547 LDBG("Vectorization pre-conditions failed\n");
2548 return failure();
2549 }
2550
2551
2553 if (auto linalgOp = dyn_castlinalg::LinalgOp(op)) {
2554 if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2555 inputScalableVecDims))) {
2556 LDBG("Vectorization state couldn't be initialized\n");
2557 return failure();
2558 }
2559 }
2560
2562 auto vectorizeResult =
2564 .Caselinalg::LinalgOp([&](auto linalgOp) {
2565
2566
2567
2568 if (isa(linalgOp.getOperation())) {
2570 rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2571 flatten1DDepthwiseConv);
2572 if (succeeded(convOr)) {
2573 llvm::append_range(results, (*convOr)->getResults());
2574 return success();
2575 }
2576
2577 LDBG("Unsupported convolution can't be vectorized.\n");
2578 return failure();
2579 }
2580
2581 LDBG("Vectorize generic by broadcasting to the canonical vector "
2582 "shape\n");
2583
2584
2586
2587
2588
2589
2590
2591
2593 })
2594 .Casetensor::PadOp([&](auto padOp) {
2596 results);
2597 })
2598 .Caselinalg::PackOp([&](auto packOp) {
2600 results);
2601 })
2602 .Caselinalg::UnPackOp([&](auto unpackOp) {
2604 inputVectorSizes, results);
2605 })
2606 .Casetensor::InsertSliceOp([&](auto sliceOp) {
2608 results);
2609 })
2610 .Default([](auto) { return failure(); });
2611
2612 if (failed(vectorizeResult)) {
2613 LDBG("Vectorization failed\n");
2614 return failure();
2615 }
2616
2617 if (!results.empty())
2618 rewriter.replaceOp(op, results);
2619 else
2621
2622 return success();
2623 }
2624
2626 memref::CopyOp copyOp) {
2627 auto srcType = cast(copyOp.getSource().getType());
2628 auto dstType = cast(copyOp.getTarget().getType());
2629 if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2630 return failure();
2631
2634 if (!VectorType::isValidElementType(srcElementType) ||
2635 !VectorType::isValidElementType(dstElementType))
2636 return failure();
2637
2638 auto readType = VectorType::get(srcType.getShape(), srcElementType);
2639 auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2640
2641 Location loc = copyOp->getLoc();
2644
2646 loc, readType, copyOp.getSource(), indices,
2648 if (cast(readValue.getType()).getRank() == 0) {
2652 }
2653 Operation *writeValue = rewriter.createvector::TransferWriteOp(
2654 loc, readValue, copyOp.getTarget(), indices,
2657 return success();
2658 }
2659
2660
2661
2662
2663
2664
2665 template
2668
2672
2673 for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2674 if (auto op = dyn_cast(user))
2675 changed |= rewriteUser(rewriter, padOp, op).succeeded();
2676 return success(changed);
2677 }
2678
2679 protected:
2681 tensor::PadOp padOp, OpTy op) const = 0;
2682 };
2683
2684
2685
2686
2687
2688
2689
2690
2691
2692
2693
2694
2695
2696
2697
2698
2699
2700
2701
2702
2707
2709 vector::TransferReadOp xferOp) const override {
2710
2711 if (!padOp.hasZeroLowPad())
2712 return failure();
2713
2714 auto padValue = padOp.getConstantPaddingValue();
2715 if (!padValue)
2716 return failure();
2717
2718 if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2719 return failure();
2720
2722 SmallVector inBounds(xferOp.getVectorType().getRank(), false);
2723 xferOp->setAttr(xferOp.getInBoundsAttrName(),
2725 xferOp.getBaseMutable().assign(padOp.getSource());
2726 xferOp.getPaddingMutable().assign(padValue);
2727 });
2728
2729 return success();
2730 }
2731 };
2732
2733
2734
2735
2736
2737
2738
2739
2740
2741
2742
2743
2744
2745
2746
2747
2748
2749
2750
2751
2752
2753
2754
2755
2756
2757
2758
2759
2760
2761
2762
2763
2764
2769
2771 vector::TransferWriteOp xferOp) const override {
2772
2773 if (xferOp.getTransferRank() == 0)
2774 return failure();
2775
2776
2777 if (!padOp.hasZeroLowPad())
2778 return failure();
2779
2780 auto padValue = padOp.getConstantPaddingValue();
2781 if (!padValue)
2782 return failure();
2783
2784 if (!xferOp->hasOneUse())
2785 return failure();
2786 auto trimPadding = dyn_casttensor::ExtractSliceOp(*xferOp->user_begin());
2787 if (!trimPadding)
2788 return failure();
2789
2790 if (!trimPadding.hasZeroOffset())
2791 return failure();
2792
2793 if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2794 return failure();
2795
2796
2798
2799 SmallVector inBounds(xferOp.getVectorType().getRank(), false);
2800 auto newXferOp = rewriter.replaceOpWithNewOpvector::TransferWriteOp(
2801 xferOp, padOp.getSource().getType(), xferOp.getVector(),
2802 padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2804 rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2805
2806 return success();
2807 }
2808
2809
2810
2811
2812
2813
2814
2815
2816
2817
2818
2820 tensor::ExtractSliceOp afterTrimming) const {
2821
2822
2823 if (auto castOp = beforePadding.getDefiningOptensor::CastOp())
2824 if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2825 return true;
2826
2827 auto t1 = dyn_cast(beforePadding.getType());
2828 auto t2 = dyn_cast(afterTrimming.getType());
2829
2830 if (!t1 || !t2)
2831 return false;
2832
2833 if (t1.getRank() != t2.getRank())
2834 return false;
2835
2836
2837
2838 for (unsigned i = 0; i < t1.getRank(); ++i) {
2839 if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2840 return false;
2841 if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2842 return false;
2843 }
2844
2845
2846 if (t1.getNumDynamicDims() == 0)
2847 return true;
2848
2849
2850
2851
2852
2853
2854 auto beforeSlice = beforePadding.getDefiningOptensor::ExtractSliceOp();
2855 if (!beforeSlice)
2856 return false;
2857
2858 assert(static_cast<size_t>(t1.getRank()) ==
2859 beforeSlice.getMixedSizes().size());
2860 assert(static_cast<size_t>(t2.getRank()) ==
2861 afterTrimming.getMixedSizes().size());
2862
2863 for (unsigned i = 0; i < t1.getRank(); ++i) {
2864
2865 if (!t1.isDynamicDim(i))
2866 continue;
2867 auto size1 = beforeSlice.getMixedSizes()[i];
2868 auto size2 = afterTrimming.getMixedSizes()[i];
2869
2870
2872 continue;
2873
2874
2875 auto v1 = llvm::dyn_cast_if_present(size1);
2876 auto v2 = llvm::dyn_cast_if_present(size2);
2877 if (!v1 || !v2)
2878 return false;
2879
2880
2881
2882 auto minOp1 = v1.getDefiningOpaffine::AffineMinOp();
2883 auto minOp2 = v2.getDefiningOpaffine::AffineMinOp();
2884 if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2885 minOp1.getOperands() == minOp2.getOperands())
2886 continue;
2887
2888
2889 }
2890
2891
2892 return true;
2893 }
2894 };
2895
2896
2897
2898
2899
2900
2901
2902
2903
2904
2906 if (!op)
2907 return {};
2908
2909
2910
2911 if (auto bcast = llvm::dyn_castvector::BroadcastOp(op)) {
2912 auto source = bcast.getSource();
2913 if (llvm::dyn_cast(source.getType()))
2914 return {};
2915
2916 return source;
2917 }
2918
2919
2920
2921 if (auto fill = llvm::dyn_castlinalg::FillOp(op)) {
2922 return fill.getInputs()[0];
2923 }
2924
2925
2926
2927 if (auto generate = llvm::dyn_casttensor::GenerateOp(op)) {
2928 return {};
2929 }
2930
2931
2932
2933
2934 if (auto xferWrite = llvm::dyn_castvector::TransferWriteOp(op))
2935 return getStaticPadVal(xferWrite.getVector().getDefiningOp());
2936
2937
2938
2939
2940
2941
2942 if (auto slice = llvm::dyn_casttensor::InsertSliceOp(op))
2943 return getStaticPadVal(slice.getDest().getDefiningOp());
2944
2945 return {};
2946 }
2947
2948 static LogicalResult
2952
2955
2957 auto sourceType = source.getType();
2958 auto resultType = sliceOp.getResultType();
2959
2961
2962 if (!padValue) {
2963 auto elemType = sourceType.getElementType();
2964 padValue = rewriter.createarith::ConstantOp(
2965 sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
2966 }
2967
2968
2970 size_t rankDiff = resultType.getRank() - sourceType.getRank();
2971 for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
2972 if (!inputVectorSizes.empty()) {
2973 vecShape.push_back(inputVectorSizes[i]);
2974 } else if (!sourceType.isDynamicDim(i)) {
2975 vecShape.push_back(sourceType.getDimSize(i));
2976 } else if (!resultType.isDynamicDim(i)) {
2977
2978
2979
2980
2981
2982 vecShape.push_back(resultType.getDimSize(rankDiff + i));
2983 } else {
2984
2985
2986 return failure();
2987 }
2988 }
2989 auto vecType = VectorType::get(vecShape, sourceType.getElementType());
2990
2991
2992 auto loc = sliceOp.getLoc();
2993
2994
2996 vecType.getRank(), rewriter.createarith::ConstantIndexOp(loc, 0));
2998 rewriter, loc, source, vecType.getShape(), padValue,
2999 inputVectorSizes.empty());
3000
3001
3002 auto writeIndices =
3006 writeIndices, inputVectorSizes.empty());
3007
3008
3009 newResults.push_back(write->getResult(0));
3010
3011 return success();
3012 }
3013
3014
3015
3016
3017
3018
3019
3020
3021
3022
3023
3024
3025
3026
3027
3028
3029
3030
3031
3032
3033
3034
3035
3036
3041
3043 tensor::InsertSliceOp insertOp) const override {
3044
3045 if (!padOp.hasZeroLowPad())
3046 return failure();
3047
3048 if (!insertOp.hasUnitStride())
3049 return failure();
3050
3051 auto padValue = padOp.getConstantPaddingValue();
3052 if (!padValue)
3053 return failure();
3054
3055 if (!cast(padOp.getResult().getType()).hasStaticShape())
3056 return failure();
3057
3058 if (insertOp.getDest() == padOp.getResult())
3059 return failure();
3060
3061 auto vecType = VectorType::get(padOp.getType().getShape(),
3062 padOp.getType().getElementType());
3063 unsigned vecRank = vecType.getRank();
3064 unsigned tensorRank = insertOp.getType().getRank();
3065
3066
3067
3069 expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
3070 if (!llvm::all_of(
3071 llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
3072 return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
3073 }))
3074 return failure();
3075
3076
3077
3079
3080
3081
3083 vecRank, rewriter.createarith::ConstantIndexOp(padOp.getLoc(), 0));
3084 auto read = rewriter.createvector::TransferReadOp(
3085 padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
3086
3087
3088
3089
3091 rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
3094 insertOp, read, insertOp.getDest(), writeIndices,
3096
3097 return success();
3098 }
3099 };
3100
3107 }
3108
3109
3110
3111
3112
3113
3114
3115
3120 LDBG("interleavedUses precondition failed, firstOp: "
3121 << *firstOp << ", second op: " << *secondOp << "\n");
3122 return true;
3123 }
3124 for (auto v : values) {
3125 for (auto &u : v.getUses()) {
3126 Operation *owner = u.getOwner();
3127 if (owner == firstOp || owner == secondOp)
3128 continue;
3129
3132 continue;
3133 LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
3134 << ", second op: " << *secondOp << "\n");
3135 return true;
3136 }
3137 }
3138 return false;
3139 }
3140
3141
3142
3144 memref::SubViewOp subViewOp;
3145 for (auto &u : v.getUses()) {
3146 if (auto newSubViewOp = dyn_castmemref::SubViewOp(u.getOwner())) {
3147 if (subViewOp)
3148 return memref::SubViewOp();
3149 subViewOp = newSubViewOp;
3150 }
3151 }
3152 return subViewOp;
3153 }
3154
3155
3156
3158 vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
3159
3160
3161 if (xferOp.getMask())
3163
3164
3165 Value viewOrAlloc = xferOp.getBase();
3166 if (!viewOrAlloc.getDefiningOpmemref::ViewOp() &&
3168 return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3169
3170
3172 if (!subViewOp)
3174 Value subView = subViewOp.getResult();
3175
3176
3177 memref::CopyOp copyOp;
3178 for (auto &u : subView.getUses()) {
3179 if (auto newCopyOp = dyn_castmemref::CopyOp(u.getOwner())) {
3180 assert(isa(newCopyOp.getTarget().getType()));
3181 if (newCopyOp.getTarget() != subView)
3182 continue;
3184 continue;
3185 copyOp = newCopyOp;
3186 break;
3187 }
3188 }
3189 if (!copyOp)
3190 return rewriter.notifyMatchFailure(xferOp, "no copy found");
3191
3192 // Find the fill into `viewOrAlloc` without interleaved uses before the
3193 // copy.
3194 FillOp maybeFillOp;
3195 for (auto &u : viewOrAlloc.getUses()) {
3196 if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
3197 assert(isa<MemRefType>(newFillOp.output().getType()));
3198 if (newFillOp.output() != viewOrAlloc)
3199 continue;
3200 if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
3201 continue;
3202 maybeFillOp = newFillOp;
3203 break;
3204 }
3205 }
3206
3207 if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
3208 return rewriter.notifyMatchFailure(xferOp,
3209 "padding value does not match fill");
3210
3211 // `in` is the subview that memref.copy reads. Replace it.
3212 Value in = copyOp.getSource();
3213
3214 // memref.copy + linalg.fill can be used to create a padded local buffer.
3215 // The `masked` attribute is only valid on this padded buffer. When
3216 // forwarding to vector.transfer_read, the attribute must be reset
3217 // conservatively.
3218 auto vectorType = xferOp.getVectorType();
3219 Value res = rewriter.create<vector::TransferReadOp>(
3220 xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
3221 xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
3222 rewriter.getBoolArrayAttr(
3223 SmallVector<bool>(vectorType.getRank(), false)));
3224
3225 if (maybeFillOp)
3226 rewriter.eraseOp(maybeFillOp);
3227 rewriter.eraseOp(copyOp);
3228 rewriter.replaceOp(xferOp, res);
3229
3230 return success();
3231 }
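// For illustration, this pattern forwards (sketch; SSA names illustrative):
//
//   linalg.fill ins(%pad) outs(%view)
//   %sv = memref.subview %view ...
//   memref.copy %in, %sv
//   %v = vector.transfer_read %view[...], %pad
//
// into a single transfer that reads the original source directly:
//
//   %v = vector.transfer_read %in[...], %pad
//
// with the fill and the copy erased.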
3232
3233 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3234 /// when available.
3235 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
3236 vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
3237
3238 if (xferOp.getMask())
3239 return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3240
3241
3242 Value viewOrAlloc = xferOp.getBase();
3243 if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3244 !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3245 return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3246
3247 // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3248 memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3249 if (!subViewOp)
3250 return rewriter.notifyMatchFailure(xferOp, "no subview found");
3251 Value subView = subViewOp.getResult();
3252
3253
3254 memref::CopyOp copyOp;
3255 for (auto &u : subViewOp.getResult().getUses()) {
3256 if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3257 if (newCopyOp.getSource() != subView)
3258 continue;
3259 if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
3260 continue;
3261 copyOp = newCopyOp;
3262 break;
3263 }
3264 }
3265 if (!copyOp)
3266 return rewriter.notifyMatchFailure(xferOp, "no copy found");
3267
3268 // `out` is the view that memref.copy writes into. Replace it.
3269 assert(isa<MemRefType>(copyOp.getTarget().getType()));
3270 Value out = copyOp.getTarget();
3271
3272
3273 // Forward the write directly into the copy's target. As above, the
3274 // `masked` attribute is only valid on the padded local buffer and must
3275 // be reset conservatively.
3276
3277 auto vector = xferOp.getVector();
3278 rewriter.create<vector::TransferWriteOp>(
3279 xferOp.getLoc(), vector, out, xferOp.getIndices(),
3280 xferOp.getPermutationMapAttr(), xferOp.getMask(),
3281 rewriter.getBoolArrayAttr(SmallVector<bool>(
3282 dyn_cast<VectorType>(vector.getType()).getRank(), false)));
3283
3284 rewriter.eraseOp(copyOp);
3285 rewriter.eraseOp(xferOp);
3286
3287 return success();
3288 }
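// Symmetrically, a sketch of the write-side forwarding:
//
//   vector.transfer_write %v, %view[...]
//   memref.copy %sv, %out     // %sv is the unique subview of %view
//
// becomes a direct `vector.transfer_write %v, %out[...]`; the copy and the
// original transfer_write are erased.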
3289
3290 //===----------------------------------------------------------------------===//
3291 // Convolution vectorization patterns
3292 //===----------------------------------------------------------------------===//
3293
3294 template <int N>
3295 static void bindShapeDims(ShapedType shapedType) {}
3296
3297 template <int N, typename IntTy, typename... IntTy2>
3298 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
3299 val = shapedType.getShape()[N];
3300 bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
3301 }
3302
3303 /// Bind a pack of integer references to the leading dims of `shapedType`.
3304 template <typename... IntTy>
3305 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
3306 bindShapeDims<0>(shapedType, vals...);
3307 }
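// Usage sketch: bind the leading dims of a shaped type to named scalars,
// e.g. for a value of type tensor<8x16x4xf32>:
//
//   int64_t n, w, c;
//   bindShapeDims(shapedType, n, w, c); // n = 8, w = 16, c = 4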
3308
3309 namespace {
3310 /// Helper class that generates a vector implementation for any of the
3311 /// supported 1-D convolution / pooling forms below. `kw` is always
3312 /// unrolled; `w` is unrolled iff `dilationW > 1`.
3313 ///
3314 /// Non-channeled convolution:
3315 /// ```
3316 ///   Op def: (     w,     kw  )
3317 ///     Iters: ({Par(), Red()})
3318 ///    Layout: {{w + kw}, {kw}, {w}}
3319 /// ```
3320 ///
3321 /// Nwc convolution / pooling:
3322 /// ```
3323 ///   Op def: (     n,     w,     c,    kw,    f  )
3324 ///     Iters: ({Par(), Par(), Par(), Red(), Red()})
3325 ///    Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3326 /// ```
3327 ///
3328 /// Ncw convolution / pooling:
3329 /// ```
3330 ///   Op def: (     n,     c,     w,    f,    kw )
3331 ///     Iters: ({Par(), Par(), Par(), Red(), Red()})
3332 ///    Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3333 /// ```
3334 ///
3335 /// Nwc depthwise convolution:
3336 /// ```
3337 ///   Op def: (     n,     w,     c,    kw )
3338 ///     Iters: ({Par(), Par(), Par(), Red()})
3339 ///    Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3340 /// ```
3341 struct Conv1DGenerator
3342 : public StructuredGenerator<LinalgOp, utils::IteratorType> {
3343 Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
3347 : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
3348
3349 lhsShaped = linalgOp.getDpsInputOperand(0)->get();
3350 rhsShaped = linalgOp.getDpsInputOperand(1)->get();
3351 resShaped = linalgOp.getDpsInitOperand(0)->get();
3352 lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3353 rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3354 resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3355
3356 // Match the combiner (reduction) op at the end of the convolution body.
3357 Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
3358
3359 setConvOperationKind(reduceOp);
3360
3361 auto maybeKind = getCombinerOpKind(reduceOp);
3362 reductionKind = maybeKind.value();
3363
3364 // The ConvolutionOpInterface gives us guarantees of existence for
3365 // strides and dilations. However, we do not need to rely on those: simply
3366 // use them if present, otherwise use the defaults and let the generic
3367 // conv matchers below succeed or fail.
3368 auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
3369 auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
3370 strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
3371 dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3372 }
3373
3374 /// Generate a vector implementation of the matched 1-D convolution or
3375 /// pooling:
3376 /// - Read lhs/rhs/res with vector.transfer_read (Ncw variants are
3377 ///   transposed to the canonical Nwc form first).
3378 /// - Unroll along kw (and along w when strideW > 1), extracting strided
3379 ///   slices of lhs and res and per-kw slices of rhs.
3380 /// - Compute each slice with vector.contract (channeled conv),
3381 ///   vector.outerproduct (single-channel conv), or the matched pooling
3382 ///   op, accumulating into the result slices.
3383 /// - Insert the computed slices back into res and write it out with
3384 ///   vector.transfer_write (transposing back if needed).
3385 ///
3386 /// `kw` is always unrolled.
3387 /// TODO: `w` (resp. `kw`) is unrolled when the strideW (resp. dilationW)
3388 /// is > 1.
3392 FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3393 int64_t nSize, wSize, cSize, kwSize, fSize;
3394 SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3395 bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3396 switch (conv1DOpOrder) {
3397 case Conv1DOpOrder::W:
3398 // Initialize unused dimensions.
3399 nSize = fSize = cSize = 0;
3400 // out{w} and kernel{kw}.
3401 bindShapeDims(resShapedType, wSize);
3402 bindShapeDims(rhsShapedType, kwSize);
3403 lhsShape = {
3404 // iw = ow + kw - 1
3405 //   (i.e. 16 convolved with 3 -> 14)
3406 (wSize + kwSize - 1)};
3407 rhsShape = {kwSize};
3408 resShape = {wSize};
3409 break;
3410 case Conv1DOpOrder::Nwc:
3411 // out{n, w, f}
3412 bindShapeDims(resShapedType, nSize, wSize, fSize);
3413 switch (oper) {
3414 case ConvOperationKind::Conv:
3415 // kernel{kw, c, f}
3416 bindShapeDims(rhsShapedType, kwSize, cSize);
3417 break;
3418 case ConvOperationKind::Pool:
3419 // kernel{kw}
3420 bindShapeDims(rhsShapedType, kwSize);
3421 cSize = fSize;
3422 break;
3423 }
3424 lhsShape = {nSize,
3425 // iw = ow * sw + kw * dw - 1
3426 //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3427 // Perform the proper inclusive -> exclusive -> inclusive.
3428 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3429 1,
3430 cSize};
3431 switch (oper) {
3432 case ConvOperationKind::Conv:
3433 rhsShape = {kwSize, cSize, fSize};
3434 break;
3435 case ConvOperationKind::Pool:
3436 rhsShape = {kwSize};
3437 break;
3438 }
3439 resShape = {nSize, wSize, fSize};
3440 break;
3441 case Conv1DOpOrder::Ncw:
3442 // out{n, f, w}
3443 bindShapeDims(resShapedType, nSize, fSize, wSize);
3444 switch (oper) {
3445 case ConvOperationKind::Conv:
3446 // kernel{f, c, kw}
3447 bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3448 break;
3449 case ConvOperationKind::Pool:
3450 // kernel{kw}
3451 bindShapeDims(rhsShapedType, kwSize);
3452 cSize = fSize;
3453 break;
3454 }
3455 lhsShape = {nSize, cSize,
3456 // iw = ow * sw + kw * dw - 1
3457 //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3458 // Perform the proper inclusive -> exclusive -> inclusive.
3459 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3460 1};
3461 switch (oper) {
3462 case ConvOperationKind::Conv:
3463 rhsShape = {fSize, cSize, kwSize};
3464 break;
3465 case ConvOperationKind::Pool:
3466 rhsShape = {kwSize};
3467 break;
3468 }
3469 resShape = {nSize, fSize, wSize};
3470 break;
3471 }
3472
3473 vector::TransferWriteOp write;
3474 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3475
3476 // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3477 // When strideW == 1, the contiguous loads can be batched and unrolling
3478 // is avoided.
3479 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3480
3481 Type lhsEltType = lhsShapedType.getElementType();
3482 Type rhsEltType = rhsShapedType.getElementType();
3483 Type resEltType = resShapedType.getElementType();
3484 auto lhsType = VectorType::get(lhsShape, lhsEltType);
3485 auto rhsType = VectorType::get(rhsShape, rhsEltType);
3486 auto resType = VectorType::get(resShape, resEltType);
3487 // Zero-index padding with the corresponding rank for lhs, rhs and res.
3488 SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3489 SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3490 SmallVector<Value> resPadding(resShape.size(), zero);
3491
3492 // Read the whole lhs, rhs and res in one shot.
3493 Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
3494 lhsPadding);
3495 // This is needed only for Conv.
3496 Value rhs = nullptr;
3497 if (oper == ConvOperationKind::Conv)
3498 rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3499 rhsPadding);
3500 Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
3501 resPadding);
3502
3503 // The base vectorization case for channeled convolution is input:
3504 // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
3505 // vectorization case, we do pre-transposes on input, weight, and output.
3506 switch (conv1DOpOrder) {
3507 case Conv1DOpOrder::W:
3508 case Conv1DOpOrder::Nwc:
3509 // Base case, so no transposes necessary.
3510 break;
3511 case Conv1DOpOrder::Ncw: {
3512 // To match the base vectorization case, we pre-transpose the current
3513 // case: ncw -> nwc.
3514 static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3515 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
3516 // fcw -> wcf
3517 static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3518
3519 // This is needed only for Conv.
3520 if (oper == ConvOperationKind::Conv)
3521 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
3522 // nfw -> nwf
3523 static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3524 res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
3525 break;
3526 }
3527 }
3528
3529 //===------------------------------------------------------------------===//
3530 // Begin vector-only rewrite part
3531 //===------------------------------------------------------------------===//
3532 // Unroll along kw and extract slices of lhs, rhs and res.
3533 SmallVector<Value> lhsVals =
3534 extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3535 kwSize, strideW, dilationW, wSizeStep,
3536 isSingleChanneled);
3537 SmallVector<Value> rhsVals;
3538 if (oper == ConvOperationKind::Conv)
3539 rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3540 SmallVector<Value> resVals =
3541 extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3542 wSizeStep, isSingleChanneled);
3543 auto linearIndex = [&](int64_t kw, int64_t w) {
3544 return kw * (wSize / wSizeStep) + w;
3545 };
3546
3547 // Compute the conv/pooling as a sum over kw of slice-wise ops
3548 // (contraction, outer product, or the matched pooling op), accumulating
3549 // into resVals[w].
3550 for (int64_t kw = 0; kw < kwSize; ++kw) {
3551 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3552 switch (oper) {
3553 case ConvOperationKind::Conv:
3554 if (isSingleChanneled) {
3555 resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3556 lhsVals[linearIndex(kw, w)],
3557 rhsVals[kw], resVals[w]);
3558 } else {
3559 resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3560 lhsVals[linearIndex(kw, w)],
3561 rhsVals[kw], resVals[w]);
3562 }
3563 break;
3564 case ConvOperationKind::Pool:
3565 resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3566 resVals[w]);
3567 break;
3568 }
3569 }
3570 }
3571
3572 res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep,
3573 resVals, isSingleChanneled);
3574
3575
3576 //===------------------------------------------------------------------===//
3577 // End vector-only rewrite part
3578 //===------------------------------------------------------------------===//
3579 // The base vectorization case produces output {n, w, f}; transpose back
3580 // to the original layout if needed.
3581 switch (conv1DOpOrder) {
3582 case Conv1DOpOrder::W:
3583 case Conv1DOpOrder::Nwc:
3584 // Base case, so no transposes necessary.
3585 break;
3586 case Conv1DOpOrder::Ncw: {
3587 // nwf -> nfw
3588 static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3589 res = rewriter.create<vector::TransposeOp>(loc, res, perm);
3590 break;
3591 }
3592 }
3593
3594 return rewriter
3595 .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
3596 .getOperation();
3597 }
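// For illustration, a channeled Nwc convolution decomposes roughly into
// (sketch; shapes and SSA names are illustrative):
//
//   %lhs = vector.transfer_read %input[%c0, %c0, %c0], %cst
//   %rhs = vector.transfer_read %filter[%c0, %c0, %c0], %cst
//   %res = vector.transfer_read %output[%c0, %c0, %c0], %cst
//   // for each kw (and each w when strideW > 1):
//   %l = vector.extract_strided_slice %lhs {offsets = [0, kw, 0], ...}
//   %r = vector.extract %rhs[kw]
//   %a = vector.contract %l, %r, %acc
//   %res2 = vector.insert_strided_slice %a, %res {offsets = [0, w, 0], ...}
//   vector.transfer_write %res2, %output[%c0, %c0, %c0]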
3598
3599 // Take `val` and widen it to the element type of `ty` (both element
3600 // types must be int or float; only widening promotions are handled).
3601 Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3602 const Type srcElementType = getElementTypeOrSelf(val.getType());
3603 const Type dstElementType = getElementTypeOrSelf(ty);
3604 assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3605 if (srcElementType == dstElementType)
3606 return val;
3607
3608 const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3609 const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3610 const Type dstType =
3611 cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3612
3613 if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3614 return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
3615 }
3616
3617 if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3618 srcWidth < dstWidth)
3619 return rewriter.create<arith::ExtFOp>(loc, dstType, val);
3620
3621 if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3622 srcWidth < dstWidth)
3623 return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
3624 assert(false && "unhandled promotion case");
3625 return nullptr;
3626 }
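// E.g. promoting a vector<4xi8> operand to an i32 accumulator type yields
// arith.extsi to vector<4xi32>; an i8 operand with an f32 accumulator goes
// through arith.sitofp instead.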
3627
3628 // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}.
3629 Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3630 Value lhs, Value rhs, Value res) {
3631 vector::IteratorType par = vector::IteratorType::parallel;
3632 vector::IteratorType red = vector::IteratorType::reduction;
3633 AffineExpr n, w, f, c;
3634 bindDims(ctx, n, w, f, c);
3635 lhs = promote(rewriter, loc, lhs, res.getType());
3636 rhs = promote(rewriter, loc, rhs, res.getType());
3637 auto contractionOp = rewriter.create<vector::ContractionOp>(
3638 loc, lhs, rhs, res,
3639 /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3640 /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3641 contractionOp.setKind(reductionKind);
3642 return contractionOp;
3643 }
3644
3645 // Create an outer product: lhs{w} * rhs{1} -> res{w} for the
3646 // single-channel convolution.
3647 Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3648 Value lhs, Value rhs, Value res) {
3649 return rewriter.create<vector::OuterProductOp>(
3650 loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
3651 }
3652
3653 // Create a pooling reduction: lhs{n, w, c} -> res{n, w, c}, applying the
3654 // matched extension op first if present.
3655 Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3656 Value res) {
3657 if (isPoolExt)
3658 lhs = rewriter.create(loc, poolExtOp, ValueRange{lhs},
3659 TypeRange{lhs.getType()})->getResult(0);
3660 return rewriter
3661 .create(loc, redOp, ValueRange{lhs, res}, TypeRange{res.getType()})
3662 ->getResult(0);
3663 }
3662
3664 /// Generate a vector implementation for:
3665 /// ```
3666 ///   Op def: (     n,     w,     c,    kw )
3667 ///     Iters: ({Par(), Par(), Par(), Red()})
3668 ///    Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3669 /// ```
3670 /// kw is always unrolled.
3671 /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is > 1.
3672 FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3673 bool channelDimScalableFlag,
3674 bool flatten) {
3675 bool scalableChDim = false;
3676 bool useMasking = false;
3677 int64_t nSize, wSize, cSize, kwSize;
3678 // out{n, w, c} and kernel{kw, c}.
3679 bindShapeDims(resShapedType, nSize, wSize, cSize);
3679 bindShapeDims(rhsShapedType, kwSize);
3680 if (ShapedType::isDynamic(cSize)) {
3681 assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3682 cSize = channelDimVecSize;
3683
3684
3685
3686 scalableChDim = channelDimScalableFlag;
3687 useMasking = true;
3688 }
3689
3690 assert(!(useMasking && flatten) &&
3691 "Unsupported flattened conv with dynamic shapes");
3692
3693
3695
3696 vector::TransferWriteOp write;
3697 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3698
3699
3700 // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1. When strideW == 1
3701 // the contiguous loads can be batched.
3702 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3703
3704 Type lhsEltType = lhsShapedType.getElementType();
3705 Type rhsEltType = rhsShapedType.getElementType();
3706 Type resEltType = resShapedType.getElementType();
3707 VectorType lhsType = VectorType::get(
3708 {nSize,
3709 // iw = ow * sw + kw * dw - 1
3710 //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3711 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3712 cSize},
3713 lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3714 VectorType rhsType =
3715 VectorType::get({kwSize, cSize}, rhsEltType,
3716 /*scalableDims=*/{false, scalableChDim});
3717 VectorType resType =
3718 VectorType::get({nSize, wSize, cSize}, resEltType,
3719 /*scalableDims=*/{false, false, scalableChDim});
3720
3721 // Masks the input xfer Op along the channel dim, iff the corresponding
3722 // scalable flag is set.
3723 auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3724 ArrayRef<bool> scalableDims,
3725 Operation *opToMask) {
3726 if (!useMasking)
3727 return opToMask;
3728 auto maskType =
3729 VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3730
3731 SmallVector<bool> inBounds(maskShape.size(), true);
3732 auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3733 xferOp->setAttr(xferOp.getInBoundsAttrName(),
3734 rewriter.getBoolArrayAttr(inBounds));
3735
3736 SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3737 cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3738
3739 Value maskOp =
3740 rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
3741
3742 return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3743 };
3744
3745 // Read lhs slice of size {n, w * strideW + kw * dilationW, c}
3746 //   @ [0, 0, 0].
3747 Value lhs = rewriter.create<vector::TransferReadOp>(
3748 loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
3749 auto maybeMaskedLhs = maybeMaskXferOp(
3750 lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3751
3752 // Read rhs slice of size {kw, c} @ [0, 0].
3753 Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3754 ValueRange{zero, zero});
3755 auto maybeMaskedRhs = maybeMaskXferOp(
3756 rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3757
3758 // Read res slice of size {n, w, c} @ [0, 0, 0].
3759 Value res = rewriter.create<vector::TransferReadOp>(
3760 loc, resType, resShaped, ValueRange{zero, zero, zero});
3761 auto maybeMaskedRes = maybeMaskXferOp(
3762 resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3763
3764
3765 //===------------------------------------------------------------------===//
3766 // Begin vector-only rewrite part
3767 //===------------------------------------------------------------------===//
3768 // Unroll along kw and extract slices of lhs, rhs and res.
3769 SmallVector<Value> lhsVals, rhsVals, resVals;
3770 SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
3771 SmallVector<int64_t> inOutStrides = {1, 1, 1};
3772
3773 // Extract lhs slice of size {n, wSizeStep, c}
3773 //   @ [0, w * strideW + kw * dilationW, 0].
3774 for (int64_t kw = 0; kw < kwSize; ++kw) {
3775 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3776 lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3777 loc, maybeMaskedLhs->getResult(0),
3778 ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3779 inOutSliceSizes, inOutStrides));
3780 }
3781 }
3782
3783 for (int64_t kw = 0; kw < kwSize; ++kw) {
3784 rhsVals.push_back(rewriter.create<vector::ExtractOp>(
3785 loc, maybeMaskedRhs->getResult(0),
3786 /*position=*/ArrayRef<int64_t>{kw}));
3787 }
3788
3789 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3790 resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3791 loc, maybeMaskedRes->getResult(0),
3792 ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3793 inOutStrides));
3794 }
3795
3796 auto linearIndex = [&](int64_t kw, int64_t w) {
3797 return kw * (wSize / wSizeStep) + w;
3798 };
3799
3800 // Note - the scalable flags are ignored, as flattening combined with
3801 // scalable vectorization is not supported.
3802 SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
3803 auto lhsTypeAfterFlattening =
3804 VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3805 auto resTypeAfterFlattening =
3806 VectorType::get(inOutFlattenSliceSizes, resEltType);
3807
3808
3809 for (int64_t kw = 0; kw < kwSize; ++kw) {
3810 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3811 Value lhsVal = lhsVals[linearIndex(kw, w)];
3812 Value resVal = resVals[w];
3813 if (flatten) {
3814
3815
3816 lhsVal = rewriter.create<vector::ShapeCastOp>(
3817 loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
3818 resVal = rewriter.create<vector::ShapeCastOp>(
3819 loc, resTypeAfterFlattening, resVals[w]);
3820 }
3821 resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3822 rhsVals[kw], resVal, flatten);
3823 if (flatten) {
3824 // Move the result back into its original (un-flattened) shape.
3825 resVals[w] = rewriter.create<vector::ShapeCastOp>(
3826 loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
3827 }
3828 }
3829 }
3830
3831 // It is possible we failed to create the FMA.
3832 if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3833 // Manually revert (in reverse order) to avoid leaving a bad IR state.
3834 for (auto &collection :
3835 {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3836 for (Value v : collection)
3837 rewriter.eraseOp(v.getDefiningOp());
3838 return rewriter.notifyMatchFailure(op, "failed to create FMA");
3839 }
3840
3841 // Write back res slice of size {n, wSizeStep, c} @ [0, w, 0].
3842 // This does not depend on kw.
3843 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3844 maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
3845 loc, resVals[w], maybeMaskedRes->getResult(0),
3846 /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3847 /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3848 }
3849
3850 //===------------------------------------------------------------------===//
3851 // End vector-only rewrite part
3852 //===------------------------------------------------------------------===//
3853 // Write back res slice of size {n, w, c} @ [0, 0, 0].
3854 Operation *resOut = rewriter.create<vector::TransferWriteOp>(
3855 loc, maybeMaskedRes->getResult(0), resShaped,
3856 ValueRange{zero, zero, zero});
3857 return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3858 resOut);
3859 }
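// For illustration, with a dynamic channel dim vectorized as [4] (scalable),
// each transfer is wrapped in a mask (sketch; shapes illustrative):
//
//   %ch = tensor.dim %out, %c2 : tensor<1x8x?xf32>
//   %mask = vector.create_mask %c1, %c8, %ch : vector<1x8x[4]xi1>
//   %res = vector.mask %mask {
//     vector.transfer_read %out[%c0, %c0, %c0], %cst
//         : tensor<1x8x?xf32>, vector<1x8x[4]xf32>
//   } : vector<1x8x[4]xi1> -> vector<1x8x[4]xf32>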
3860
3861 /// Lower:
3862 ///   * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
3863 ///   * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
3864 /// to MulAcc.
3865 Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
3866 Value lhs, Value rhs, Value res,
3867 bool flatten) {
3868 auto rhsTy = cast<VectorType>(rhs.getType());
3869 auto resTy = cast<VectorType>(res.getType());
3870
3871
3872 lhs = promote(rewriter, loc, lhs, resTy);
3873
3874 if (flatten) {
3875 // There are two options for handling the "flattened" filter:
3876 //   * shape_cast(broadcast(filter))
3877 //   * broadcast(shuffle(filter))
3878 // Opt for the option without shape_cast to simplify the codegen: shuffle
3879 // the filter so that its channel values repeat once per flattened
3880 // {w * c} slice, then broadcast to the result shape.
3881
3882
3883 auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
3884 auto resSize = cast<VectorType>(res.getType()).getShape()[1];
3885
3886 SmallVector<int64_t, 16> indices;
3887 for (int i = 0; i < resSize / rhsSize; ++i) {
3888 for (int j = 0; j < rhsSize; ++j)
3889 indices.push_back(j);
3890 }
3891
3892 rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
3893 }
3894
3895 rhs = rewriter.create<vector::BroadcastOp>(
3896 loc, resTy.clone(rhsTy.getElementType()), rhs);
3897
3898 rhs = promote(rewriter, loc, rhs, resTy);
3899
3900 if (!lhs || !rhs)
3901 return nullptr;
3902
3903 if (isa<FloatType>(resTy.getElementType()))
3904 return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
3905
3906 auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
3907 return rewriter.create<arith::AddIOp>(loc, mul, res);
3908 }
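// E.g. with rhs = [r0, r1, r2, r3] (c = 4) and a flattened slice of width
// 8 (w * c = 2 * 4), the shuffle produces [r0, r1, r2, r3, r0, r1, r2, r3],
// so the element-wise multiply-accumulate lines up with the {w * c} layout.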
3909
3910 /// Entry point that matches and rewrites the non-channeled form:
3911 ///   {{w + kw}, {kw}, {w}}
3912 FailureOr<Operation *> generateNonChanneledConv() {
3913 AffineExpr w, kw;
3914 bindDims(ctx, w, kw);
3915 if (!iters({Par(), Red()}))
3916 return rewriter.notifyMatchFailure(op,
3917 "failed to match conv::W 1-par 1-red");
3918
3919 // No transposition needed.
3920 if (layout({/*lhsIndex*/ {w + kw},
3921 /*rhsIndex*/ {kw},
3922 /*resIndex*/ {w}}))
3923 return conv(Conv1DOpOrder::W);
3924
3925 return rewriter.notifyMatchFailure(op, "not a conv::W layout");
3926 }
3927
3928 /// Entry point that matches and rewrites the Nwc form:
3929 ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3930 FailureOr<Operation *> generateNwcConv() {
3931 AffineExpr n, w, f, kw, c;
3932 bindDims(ctx, n, w, f, kw, c);
3933 if (!iters({Par(), Par(), Par(), Red(), Red()}))
3934 return rewriter.notifyMatchFailure(
3935 op, "failed to match conv::Nwc 3-par 2-red");
3936
3937 // No transposition needed.
3938 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3939 /*rhsIndex*/ {kw, c, f},
3940 /*resIndex*/ {n, w, f}}))
3941 return conv(Conv1DOpOrder::Nwc);
3942
3943 return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
3944 }
3945
3946 /// Entry point that matches and rewrites the Ncw form:
3947 ///   {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3948 FailureOr<Operation *> generateNcwConv() {
3949 AffineExpr n, w, f, kw, c;
3950 bindDims(ctx, n, f, w, c, kw);
3951 if (!iters({Par(), Par(), Par(), Red(), Red()}))
3952 return rewriter.notifyMatchFailure(
3953 op, "failed to match conv::Ncw 3-par 2-red");
3954
3955 if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3956 /*rhsIndex*/ {f, c, kw},
3957 /*resIndex*/ {n, f, w}}))
3958 return conv(Conv1DOpOrder::Ncw);
3959
3960 return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
3961 }
3962
3963 /// Entry point that matches and rewrites Nwc pooling:
3964 ///   {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}}
3965 FailureOr<Operation *> generateNwcPooling() {
3966 AffineExpr n, w, c, kw;
3967 bindDims(ctx, n, w, c, kw);
3968 if (!iters({Par(), Par(), Par(), Red()}))
3969 return rewriter.notifyMatchFailure(op,
3970 "failed to match pooling 3-par 1-red");
3971
3972 // No transposition needed.
3973 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3974 /*rhsIndex*/ {kw},
3975 /*resIndex*/ {n, w, c}}))
3976 return conv(Conv1DOpOrder::Nwc);
3977
3978 return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
3979 }
3980
3981 /// Entry point that matches and rewrites Ncw pooling:
3982 ///   {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}}
3983 FailureOr<Operation *> generateNcwPooling() {
3984 AffineExpr n, w, c, kw;
3985 bindDims(ctx, n, c, w, kw);
3986 if (!iters({Par(), Par(), Par(), Red()}))
3987 return rewriter.notifyMatchFailure(op,
3988 "failed to match pooling 3-par 1-red");
3989
3990 if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3991 /*rhsIndex*/ {kw},
3992 /*resIndex*/ {n, c, w}}))
3993 return conv(Conv1DOpOrder::Ncw);
3994
3995 return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
3996 }
3997
3998 /// Entry point that matches and rewrites the Nwc depthwise form:
3999 ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
4000 FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
4001 bool vecChDimScalableFlag = false,
4002 bool flatten = false) {
4003 AffineExpr n, w, c, kw;
4004 bindDims(ctx, n, w, c, kw);
4005 if (!iters({Par(), Par(), Par(), Red()}))
4006 return rewriter.notifyMatchFailure(
4007 op, "failed to match depthwise::Nwc conv 3-par 1-red");
4008
4009 // No transposition needed.
4010 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4011 /*rhsIndex*/ {kw, c},
4012 /*resIndex*/ {n, w, c}}))
4013 return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
4014
4015 return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
4016 }
4017
4018 private:
4019 ConvOperationKind oper = ConvOperationKind::Conv;
4020 StringAttr redOp;
4021 StringAttr poolExtOp;
4022 bool isPoolExt = false;
4023 int strideW, dilationW;
4024 Value lhsShaped, rhsShaped, resShaped;
4025 ShapedType lhsShapedType, rhsShapedType, resShapedType;
4026 vector::CombiningKind reductionKind;
4027
4028
4029 void setConvOperationKind(Operation *reduceOp) {
4030 int numBlockArguments =
4031 llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
4032 if (numBlockArguments == 1) {
4033 // This is a convolution if the feeder of the reduction is a
4034 // multiplication (or its i1 strength-reduced form, AndOp). If the feeder
4035 // is instead a cast of a block argument, this is a pooling with an
4036 // extension op.
4037 auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
4038 llvm::IsaPred<BlockArgument>);
4039 Operation *feedOp = (*feedValIt).getDefiningOp();
4040 if (isCastOfBlockArgument(feedOp)) {
4041 oper = ConvOperationKind::Pool;
4042 isPoolExt = true;
4043 poolExtOp = feedOp->getName().getIdentifier();
4044 return;
4045 }
4046 oper = ConvOperationKind::Conv;
4047 return;
4048 }
4049
4050 oper = ConvOperationKind::Pool;
4051 isPoolExt = false;
4052 }
4053 };
4054 } // namespace
4055
4056 /// Helper function to vectorize a LinalgOp with convolution semantics.
4057 // TODO: extend the generic vectorization to support windows and drop this.
4058 static FailureOr<Operation *> vectorizeConvolution(
4059 RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
4060 ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
4061 Conv1DGenerator conv1dGen(rewriter, op);
4062 auto res = conv1dGen.generateNonChanneledConv();
4063 if (succeeded(res))
4064 return res;
4065 res = conv1dGen.generateNwcConv();
4066 if (succeeded(res))
4067 return res;
4068 res = conv1dGen.generateNcwConv();
4069 if (succeeded(res))
4070 return res;
4071 res = conv1dGen.generateNwcPooling();
4072 if (succeeded(res))
4073 return res;
4074 res = conv1dGen.generateNcwPooling();
4075 if (succeeded(res))
4076 return res;
4077
4078 // The channel dim vector size (and its scalability flag) only matters
4079 // for the depthwise case, and only when the channel dim is dynamic; it
4080 // is taken from the user-provided vector sizes, if any.
4081 uint64_t vecChDimSize = ShapedType::kDynamic;
4082 bool vecChDimScalableFlag = false;
4083 if (!inputVecSizes.empty()) {
4084 // Only use the input vector size corresponding to the channel dim. The
4085 // other vector dims are inferred from the op itself.
4086 assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
4087 isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
4088 "Not a 1D depthwise conv!");
4089 size_t chDimIdx =
4090 TypeSwitch<Operation *, size_t>(op)
4091 .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
4092 .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
4093
4094 vecChDimSize = inputVecSizes[chDimIdx];
4095 vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
4096 }
4097 return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
4098 flatten1DDepthwiseConv);
4099 }
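// A usage sketch via the public entry point (the sizes are illustrative):
// for a depthwise conv whose channel dim is dynamic, pass a vector size and
// a scalability flag per iteration dim; only the channel entries are
// consumed here:
//
//   (void)linalg::vectorize(rewriter, convOp,
//                           /*inputVectorSizes=*/{1, 8, 4},
//                           /*inputScalableVecDims=*/{false, false, true});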
4100
4103 /// Rewrite pattern wrapping `vectorizeConvolution` so it can be used with
4104 /// the greedy pattern rewrite driver.
4105 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
4106 using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
4107 LogicalResult matchAndRewrite(LinalgOp op,
4108 PatternRewriter &rewriter) const override {
4109 FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
4110 if (failed(resultOrFail))
4111 return failure();
4112 Operation *newOp = *resultOrFail;
4113 if (newOp->getNumResults() == 0) {
4114 rewriter.eraseOp(op.getOperation());
4115 return success();
4116 }
4117 assert(newOp->getNumResults() == 1 && "expected single result");
4118 rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
4119 return success();
4120 }
4121 };
4122
4123 void mlir::linalg::populateConvolutionVectorizationPatterns(
4124 RewritePatternSet &patterns, PatternBenefit benefit) {
4125 patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
4126 }
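// Usage sketch, mirroring the pad-op patterns above (assumes a pass context;
// the driver call is the standard MLIR API):
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateConvolutionVectorizationPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));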