MLIR: lib/Dialect/Linalg/Transforms/Transforms.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
33 #include "llvm/ADT/ScopeExit.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/InterleavedRange.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include <type_traits>
39 #include
40
41 #define DEBUG_TYPE "linalg-transforms"
42
43 using namespace mlir;
45
// Debug-stream helper: evaluates to llvm::dbgs() after writing the
// "[" DEBUG_TYPE "]: " prefix (DEBUG_TYPE is "linalg-transforms" in this
// file), so further output can be chained onto it with <<.
46 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
// Writes a single newline to llvm::dbgs(); used to terminate a DBGS() line.
47 #define DBGSNL() (llvm::dbgs() << "\n")
48
49
50
51
52
53
54
55
56
57
58
59
63 .Casescf::ForOp([&](scf::ForOp forOp) {
64 scf::ForOp partialIteration;
66 partialIteration)))
67 return partialIteration->getResults();
68 assert(!partialIteration && "expected that loop was not peeled");
69 return forOp->getResults();
70 })
72 }
73
74
75
78 for (auto loopOp : loops)
80 }
81
82
83
84
85
86 #ifndef NDEBUG
87
89 bool found = false;
91 if (!e.isFunctionOfDim(dim))
92 continue;
93 if (found)
94 return false;
95 found = true;
96 }
97 return true;
98 }
99
101 return llvm::interleaved(ri, ", ", "|", "");
102 }
103 #endif
104
105
106
108 int64_t dim) {
109 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
112 continue;
113 return i;
114 }
115 return std::nullopt;
116 }
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153 static FailureOr<SmallVector<std::optional<int64_t>>>
156 int64_t dim) {
157 int64_t newDim = iteratorTypes.size();
158 iteratorTypes.push_back(iteratorTypes[dim]);
159
161 indexingMaps.size(), std::nullopt);
163 for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
164 ++operandIdx) {
165 AffineMap map = indexingMaps[operandIdx];
166
167
168 assert(map.getNumDims() == newDim && "num dims invariant violation");
170
171
172
173
174
176 "num results invariant violation");
178 if (!maybeOperandDimensionToPack.has_value()) {
179 newMaps.push_back(map);
180 continue;
181 }
182
183
184 if (!isa(map.getResult(maybeOperandDimensionToPack.value())))
185 return failure();
186
187
190 newMaps.push_back(map);
191
192
193 packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
194 }
195 indexingMaps = newMaps;
196
197 return packedDimPerIndexingMap;
198 }
199
200 namespace {
201
202
203 struct PackedOperandsDim {
206 };
207
208
209 struct PackedOperandsDimList {
// Appends `packedOperandsDims` as a new entry at the end of `spec`.
//
// NOTE(review): although the parameter is an rvalue reference, it is forwarded
// to emplace_back as a named lvalue (no std::move), so the new entry is
// presumably copy-constructed rather than moved. The caller in this file
// invokes `pushBack(std::move(packedOperandsDims))` and afterwards still reads
// `packedOperandsDims.packedDimForEachOperand` in a debug print, which only
// shows the real contents because the argument is left intact here. Confirm
// that call site before changing this to a move.
210 void pushBack(PackedOperandsDim &&packedOperandsDims) {
211 spec.emplace_back(packedOperandsDims);
212 }
213
215
217
218 private:
220 };
221
222 }
223
225 linalg::PackOp packOp,
226 bool lowerPadLikeWithInsertSlice) {
227
228 auto packedTensorType =
229 cast(packOp->getResultTypes().front());
230 if (llvm::any_of(packOp.getStaticInnerTiles(), ShapedType::isDynamic)) {
232 packOp,
233 "non-static shape NYI, needs a more powerful tensor.expand_shape op");
234 }
235
236 Location loc = packOp->getLoc();
239
240
241
242 PackingMetadata packingMetadata = computePackingMetadata(
243 packedTensorType.getRank(), packOp.getInnerDimsPos());
246
247
248
251
252
257 for (auto [pos, innerSize] :
258 llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
259 int outerPos =
260 packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
268 auto map = AffineMap::get(2, 1, d0 * s0 - d1);
270 rewriter, loc, map, {outerSize, origSize, innerSize});
271 }
272 RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
274 packingMetadata.reassociations);
275 Value paddingValue = packOp.getPaddingValue();
276 if (!paddingValue) {
277 paddingValue = rewriter.createarith::ConstantOp(
279 }
280 auto padOp =
281 rewriter.createtensor::PadOp(loc, collapsed, packOp.getSource(), lows,
282 highs, paddingValue, false);
283
284 LLVM_DEBUG(
286 DBGS() << "insertPositions: "
287 << llvm::interleaved(packingMetadata.insertPositions);
288 DBGSNL(); DBGS() << "outerPositions: "
289 << llvm::interleaved(packingMetadata.outerPositions);
291 << llvm::interleaved(packedTensorType.getShape());
292 DBGSNL(); DBGS() << "packedToStripMinedShapePerm: "
293 << llvm::interleaved(packedToStripMinedShapePerm);
295 DBGS() << "reassociations: "
296 << llvm::interleaved(llvm::map_range(
299 DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
300 DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
301
302 if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
303
304
305
308
310
311
312
315
320
321 auto insertSliceOp = rewriter.createtensor::InsertSliceOp(
322 loc, padOp, packOp.getDest(),
323 zeros, sizes, ones);
324
325 LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
326
327 rewriter.replaceOp(packOp, insertSliceOp->getResults());
328
330 nullptr};
331 }
332 }
333
334
335 auto expandShapeResultType =
337 auto reshapeOp = rewriter.createtensor::ExpandShapeOp(
338 loc, expandShapeResultType, padOp.getResult(),
339 packingMetadata.reassociations);
340
341
344 auto transposeOp = rewriter.createlinalg::TransposeOp(
345 loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
346
348 DBGS() << "reshape op: " << reshapeOp; DBGSNL();
349 DBGS() << "transpPerm: " << llvm::interleaved(transpPerm);
350 DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
351
352
353 rewriter.replaceOp(packOp, transposeOp->getResults());
354
356 }
357
358 FailureOr
360 bool lowerUnpadLikeWithExtractSlice) {
361 Location loc = unPackOp->getLoc();
364
365 RankedTensorType packedTensorType = unPackOp.getSourceType();
366 int64_t packedRank = packedTensorType.getRank();
367
369 auto destTensorType = cast(unPackOp.getDest().getType());
370 if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
371
372
374
375
378
379 auto extractSliceOp = rewriter.createtensor::ExtractSliceOp(
380 loc, destTensorType, unPackOp.getSource(),
383
384 rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
385
387 nullptr, extractSliceOp};
388 }
389
390
391
392 PackingMetadata packingMetadata;
395
396
397
400
401
402 RankedTensorType stripMinedTensorType =
404 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
405 stripMinedTensorType, packingMetadata.reassociations);
406
407
408
412 auto emptyOp = rewriter.createtensor::EmptyOp(
413 loc, dims, stripMinedTensorType.getElementType());
414 auto transposeOp = rewriter.createlinalg::TransposeOp(
415 loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
416
417 LLVM_DEBUG(
419 DBGS() << "insertPositions: "
420 << llvm::interleaved(packingMetadata.insertPositions);
422 << llvm::interleaved(packedTensorType.getShape());
423 DBGSNL(); DBGS() << "packedToStripMinedShapePerm: "
424 << llvm::interleaved(packedToStripMinedShapePerm);
426 DBGS() << "reassociations: "
427 << llvm::interleaved(llvm::map_range(
430 DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
431 DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
432
433
434 auto reshapeOp = rewriter.createtensor::CollapseShapeOp(
435 loc, collapsedType, transposeOp->getResult(0),
436 packingMetadata.reassociations);
437
438
439 int64_t destRank = destTensorType.getRank();
440 auto extractSliceOp = rewriter.createtensor::ExtractSliceOp(
441 loc, destTensorType, reshapeOp->getResult(0),
445
446
447 auto copyOp = rewriter.createlinalg::CopyOp(
448 loc, extractSliceOp->getResult(0), unPackOp.getDest());
449
450
451 rewriter.replaceOp(unPackOp, copyOp->getResults());
452
453 return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
454 }
455
457 PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
459 for (auto &i : spec) {
460 if (!i.packedDimForEachOperand[operandPos].has_value())
461 continue;
462 res.push_back(i.packedDimForEachOperand[operandPos].value());
463 }
464 return res;
465 }
466
468 PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
470 for (auto &i : spec) {
471 if (!i.packedDimForEachOperand[operandPos].has_value())
472 continue;
473 res.push_back(i.packedSize);
474 }
475 return res;
476 }
477
478
479
480
482 linalg::LinalgOp linalgOp,
484 if (packedSizes.size() != linalgOp.getNumLoops()) {
486 "incorrect number of pack sizes");
487 }
488
489 Location loc = linalgOp->getLoc();
492 linalgOp.getIteratorTypesArray();
493 LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n"
494 << "maps: " << llvm::interleaved(indexingMaps) << "\n"
495 << "iterators: " << llvm::interleaved(iteratorTypes)
496 << "\n");
497
500
501 PackedOperandsDimList listOfPackedOperandsDim;
502 for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
503 std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
504
505 if (maybeConstant.has_value() && maybeConstant.value() == 0)
506 continue;
507
508 PackedOperandsDim packedOperandsDims;
509 packedOperandsDims.packedSize = packedSizes[i];
510 FailureOr<SmallVector<std::optional<int64_t>>>
511 maybePackedDimForEachOperand =
513 if (failed(maybePackedDimForEachOperand))
514 return failure();
515 packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
516 listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
517
518 LLVM_DEBUG(
519 DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
520 << "\n"
521 << "maps: " << llvm::interleaved(indexingMaps) << "\n"
522 << "iterators: " << llvm::interleaved(iteratorTypes) << "\n"
523 << "packedDimForEachOperand: "
524 << llvm::interleaved(packedOperandsDims.packedDimForEachOperand)
525 << "\n");
526 }
527
528
531 llvm::to_vector(llvm::make_pointer_range(linalgOp.getDpsInitsMutable()));
533 for (const auto &operandsList : {inputOperands, initOperands}) {
534 for (OpOperand *opOperand : operandsList) {
535 int64_t pos = opOperand->getOperandNumber();
536 Value operand = opOperand->get();
538 listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
540 listOfPackedOperandsDim.extractPackSizesForOperand(pos);
541 LLVM_DEBUG(DBGS() << "operand: " << operand << "\n"
542 << "innerPos: " << llvm::interleaved(innerPos) << "\n"
543 << "innerPackSizes: "
544 << llvm::interleaved(innerPackSizes) << "\n");
545 if (innerPackSizes.empty()) {
546 inputsAndInits.push_back(operand);
547 continue;
548 }
549 Value dest = linalg::PackOp::createDestinationTensor(
550 rewriter, loc, operand, innerPackSizes, innerPos,
551 {});
552 ShapedType operandType = cast(operand.getType());
553 bool areConstantTiles =
556 });
557 if (areConstantTiles && operandType.hasStaticShape() &&
558 !linalg::PackOp::requirePaddingValue(
559 operandType.getShape(), innerPos,
560 cast(dest.getType()).getShape(), {},
561 innerPackSizes)) {
562 packOps.push_back(rewriter.createlinalg::PackOp(
563 loc, operand, dest, innerPos, innerPackSizes));
564 } else {
565
566
567 auto zeroAttr =
569 Value zero = rewriter.createarith::ConstantOp(loc, zeroAttr);
570 packOps.push_back(rewriter.createlinalg::PackOp(
571 loc, operand, dest, innerPos, innerPackSizes, zero));
572 }
573 inputsAndInits.push_back(packOps.back());
574 }
575 }
576
577
579 ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
581 ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
582 auto packedLinalgOp = rewriter.createlinalg::GenericOp(
583 linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
584 iteratorTypes);
586
587
588 for (OpResult result : packedLinalgOp->getResults()) {
589 int64_t resultNum = result.getResultNumber();
590 linalg::PackOp maybePackedInit =
591 inits[resultNum].getDefiningOplinalg::PackOp();
592 if (!maybePackedInit) {
593 results.push_back(result);
594 continue;
595 }
596
597 unPackOps.push_back(rewriter.createlinalg::UnPackOp(
598 packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
599 maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
600 results.push_back(unPackOps.back());
601 }
602
603
604 rewriter.replaceOp(linalgOp, results);
605
606
608 castlinalg::LinalgOp(packedLinalgOp.getOperation()),
609 unPackOps};
610 }
611
612
613
614
615
616
617
618
619
620 static RankedTensorType permuteShape(RankedTensorType tensorType,
625 }
626
627
628
629
630
631
632
636
637 assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");
638
639
641 cast(opOperand.get().getType()), permutation);
642 (void)tensorType;
643 assert(tensorType == transposedValue.getType() &&
644 "expected tensor type mismatch");
645
646
647
649 llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
653 permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));
654
655
657 indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
658
661
663 auto transposedGenericOp = rewriter.createlinalg::GenericOp(
664 linalgOp->getLoc(),
665
666 operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
667 operandsRef.take_front(linalgOp.getNumDpsInputs()),
668 operandsRef.drop_front(linalgOp.getNumDpsInputs()),
669 indexingMaps,
670 linalgOp.getIteratorTypesArray());
671 transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
672 rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
673
674 return castlinalg::LinalgOp(transposedGenericOp.getOperation());
675 }
676
677 FailureOr
679 linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
682 Location loc = linalgOp.getLoc();
683
684
686 linalg::PackOp transposedPackOp =
687 packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);
688
689 if (!packOp.getResult().hasOneUse())
690 return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");
691
692 OpOperand &packUse = *packOp->getUses().begin();
693 if (packUse.getOwner() != linalgOp) {
695 linalgOp, "not a single use by the LinalgOp target");
696 }
697 if (maybeUnPackOp &&
698 (!linalgOp.isDpsInit(&packUse) ||
699 maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
701 "not produced by the LinalgOp target");
702 }
703
704
705
706
707 int64_t numLeadingDims = packOp.getSourceRank();
708 int64_t numTrailingDims = packOp.getInnerDimsPos().size();
709
710
712 if (permutation.empty())
713 llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
714
715 if (innerPerm.empty()) {
716 llvm::append_range(
717 permutation,
718 llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
719 } else {
720 llvm::append_range(permutation,
721 llvm::map_range(innerPerm, [&](int64_t pos) {
722 return numLeadingDims + pos;
723 }));
724 }
727
728
729
731
734 rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
735
736
737 linalg::UnPackOp transposedUnPackOp;
738 if (maybeUnPackOp) {
740 transposedLinalgOp->getOpOperand(packUseOperandNumber);
741 OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
743 transposedUnPackOp = maybeUnPackOp.createTransposedClone(
744 rewriter, loc, transposedResult, innerPerm, outerPerm);
745
746 rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
747 }
748
749
750 rewriter.replaceOp(packOp, transposedPackOp->getResults());
751
753 transposedUnPackOp};
754 }
755
756
757
758
759
760
761
762
763
764
765
766
767
768 FailureOr
773 assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
774 assert((mnkPaddedSizesNextMultipleOf.empty() ||
775 mnkPaddedSizesNextMultipleOf.size() == 3) &&
776 "num of packing sizes next multiple should be empty or of size 3");
777 assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
779
780 int64_t numLoops = linalgOp.getNumLoops();
781 if (numLoops <= 2) {
782 LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got "
783 << numLoops << "\nin: " << linalgOp << "\n");
785 linalgOp, "need 3+ loops to find a matmul to pack");
786 }
787
788
789 int64_t numPackedDims = mnkPackedSizes.size();
791 for (int64_t i = 0, e = numPackedDims; i < e; ++i)
792 mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
794 for (int64_t i = 0, e = numPackedDims; i < e; ++i)
795 packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
797 for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
798 paddedSizesNextMultipleOf[mnkOrder[i]] =
799 mnkPaddedSizesNextMultipleOf.empty() ? 0
800 : mnkPaddedSizesNextMultipleOf[i];
801 }
802
803
804 FailureOr maybeDimensions =
806 if (failed(maybeDimensions)) {
807 LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp
808 << "\n");
810 "couldn't infer matmul iterators");
811 }
812
813
814
815
816
817
818 int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
819 kPos = maybeDimensions->k.back();
821 DBGS() << "Start packing generic op greedily with (m@" << mPos
822 << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
823 << "\n";);
824
825
826 auto genericOp = dyn_cast(linalgOp.getOperation());
827 if (!genericOp) {
828 FailureOr generalizeResult =
830 assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
831 genericOp = *generalizeResult;
832 }
833
834
835
836
839 LLVM_DEBUG(DBGS() << "perm: " << llvm::interleaved(permutation) << "\n");
840
842 FailureOr interchangeResult =
844 assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
845 genericOp = *interchangeResult;
846 LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
863 cast(genericOp.getOperation())
864 .createLoopRanges(rewriter, genericOp.getLoc());
865
866
867
868 LLVM_DEBUG(DBGS() << "paddedSizesNextMultipleOf: "
869 << llvm::interleaved(paddedSizesNextMultipleOf) << "\n"
870 << "loopRanges: "
871 << llvm::interleaved(llvm::map_range(
872 loopRanges, [](Range r) { return r.size; }))
873 << "\n");
876 for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
877 if (paddedSizesNextMultipleOf[i] == 0) {
878 adjustedPackedSizes.push_back(packedSizes[i]);
879 continue;
880 }
885 rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
886 {loopRanges[adjustedPackedSizes.size()].size,
887 rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
888 }
889 LLVM_DEBUG(DBGS() << "adjustedPackedSizes: "
890 << llvm::interleaved(adjustedPackedSizes) << "\n");
891
892
893
894
895
896 return pack(rewriter, genericOp, adjustedPackedSizes);
897 }
898
899
900
901
902
905 assert(!tileSizeComputationFunction && "tile sizes already set");
907 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
910 &op->getParentOfTypefunc::FuncOp().getBody().front());
911 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
913 return v;
914 }));
915 };
916 return *this;
917 }
918
920 memref::CopyOp copyOp, PatternRewriter &rewriter) const {
922 }
923
924
925
929 auto padValue = padOp.getConstantPaddingValue();
930 if (padValue) {
931
932 if (padValue.getParentBlock() == &padOp.getRegion().front())
933 rewriter.moveOpBefore(padValue.getDefiningOp(), padOp);
934 return rewriter.create(padOp.getLoc(), padValue, dest).result();
935 }
936
937
938 auto generateOp = rewriter.createtensor::GenerateOp(
939 padOp.getLoc(), padOp.getResultType(), dynSizes);
940
942 padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
943 return generateOp;
944 }
945
946 LogicalResult
949
951 if (auto val = llvm::dyn_cast_if_present(ofr))
952 return val;
953 return rewriter
955 padOp.getLoc(), cast(cast(ofr)).getInt())
956 .getResult();
957 };
958
959 auto resultType = padOp.getResultType();
960
963 for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
964 if (resultType.isDynamicDim(dim)) {
966 padOp.getSource(), dim));
967
968 auto plusLow = rewriter.createOrFoldarith::AddIOp(
969 padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
970 auto plusHigh = rewriter.createOrFoldarith::AddIOp(
971 padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
972 dynSizes.push_back(plusHigh);
973 }
974 staticSizes.push_back(resultType.getDimSize(dim));
975 }
976
977
978 Value emptyTensor = rewriter.createtensor::EmptyOp(
979 padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);
981
982
983 auto sourceType = padOp.getSourceType();
984
987
991 padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
992 strides);
993
994 return success();
995 }
996
998 tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
999 if (!sliceOp.hasUnitStride())
1000 return failure();
1001
1002 auto padOp = sliceOp.getSource().getDefiningOptensor::PadOp();
1003 if (!padOp)
1004 return failure();
1005
1006 bool zeroSliceGuard = true;
1007 if (controlFn) {
1008 if (std::optional control = controlFn(sliceOp))
1009 zeroSliceGuard = *control;
1010 else
1011 return failure();
1012 }
1013
1014 FailureOr tilingResult =
1016 sliceOp.getMixedSizes(), zeroSliceGuard);
1017 if (failed(tilingResult))
1018 return failure();
1019
1020 RankedTensorType sourceType = sliceOp.getSourceType();
1021 RankedTensorType resultType = sliceOp.getResultType();
1022
1023
1024
1025 if (sourceType.getRank() == resultType.getRank()) {
1026 rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
1027 return success();
1028 }
1029
1030
1032 rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);
1033
1034 rewriter.replaceOp(sliceOp, rankReduced);
1035 return success();
1036 }
1037
1038
1039
1040
1041
1042
1044 linalg::PackOp packOp) {
1045 Value input = packOp.getSource();
1046 if (!packOp.getPaddingValue()) {
1047 return input;
1048 }
1049
1050 assert(llvm::all_of(packOp.getAllOuterDims(),
1051 [](int64_t val) { return val == 1; }) &&
1052 "some outer dims are != 1");
1053
1054 Location loc = packOp.getLoc();
1055 ShapedType inputType = packOp.getSourceType();
1056 int64_t inputRank = inputType.getRank();
1057
1059 packOp.getDimAndTileMapping();
1060
1061
1063
1064
1066 for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
1067
1068
1069 if (!tileAndPosMapping.count(dimIdx)) {
1070 int64_t inputDimSize = inputType.getDimSize(dimIdx);
1071 assert(inputDimSize == 1 &&
1072 "with all outer dims == 1, this non-tiled input dim should be 1!");
1073 paddedShape.push_back(inputDimSize);
1074 continue;
1075 }
1076
1077
1078
1079
1080 OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
1081
1082
1083 std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
1084 if (cstTileSize.has_value()) {
1085 paddedShape.push_back(cstTileSize.value());
1086 continue;
1087 }
1088
1089
1090 paddedShape.push_back(ShapedType::kDynamic);
1091
1092
1093 dynamicTileSizes.push_back(llvm::dyn_cast(tileSizeForDim));
1094 }
1095 auto resultType =
1098 false, loc, builder,
1099 dynamicTileSizes);
1100 }
1101
1102
1103
1104
1105
1108 constexpr int64_t kNonTiledMarker = -1;
1111 vec[value] = index;
1113 vec, [&](int64_t v) { return v != kNonTiledMarker; });
1114
1116 }
1117
1118
1119
1127 int64_t dim = 0;
1128 int64_t unpackedRank = shape.size();
1129 for (auto i : llvm::seq(0, unpackedRank)) {
1131 innerDims.push_back(dim++);
1132 continue;
1133 }
1134 if (shape[i] == 1)
1135 continue;
1136 outerDims.push_back(dim++);
1138 rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
1139 }
1140
1141
1144 applyPermutationToVector<int64_t>(innerDims, innerPerm);
1145
1146
1148
1149 rankReducedOuterDimsPerm =
1151 if (!rankReducedOuterDimsPerm.empty())
1152 applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);
1153
1154
1155 perm.append(innerDims);
1156
1157 return perm;
1158 }
1159
1161 linalg::PackOp packOp, PatternRewriter &rewriter) const {
1162
1163
1164 if (llvm::any_of(packOp.getAllOuterDims(),
1165 [](int64_t dim) { return dim != 1; })) {
1167 packOp, "not all outer dimensions of the result are 1s");
1168 }
1169
1172 Location loc = packOp.getLoc();
1173
1176 packOp.getDimAndTileMapping();
1177 int64_t srcRank = packOp.getSourceRank();
1178 int64_t destRank = packOp.getDestRank();
1179 int64_t numTiles = destRank - srcRank;
1180
1181 if (!llvm::all_of(packOp.getInnerDimsPos(),
1182 [&srcRank, &numTiles](int64_t dimPos) {
1183 return dimPos >= (srcRank - numTiles - 1);
1184 }))
1186 packOp, "Attempting to tile non-trailing source dims!");
1187
1188
1189
1190
1192 for (auto i : llvm::seq(0, srcRank)) {
1193 if (dimAndTileMapping.count(i)) {
1194
1195
1196
1197 auto [_, tileSize] =
1199 tileSizes.push_back(tileSize);
1200 }
1201 }
1202
1203
1204
1205
1206
1207
1208
1209
1213 for (int64_t i = 0; i < (srcRank - numTiles); i++)
1214 srcPermForTranspose.push_back(i);
1215
1217
1218 LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"
1219 << "perm: " << llvm::interleaved(srcPermForTranspose)
1220 << "\n");
1221
1222
1224 oneIdxAttr);
1225 transShapeForEmptyOp.append(tileSizes);
1226
1227 applyPermutationToVector(transShapeForEmptyOp,
1228 srcPermForTranspose);
1229 Value empty = rewriter.createtensor::EmptyOp(
1230 loc, transShapeForEmptyOp, packOp.getSourceType().getElementType());
1231
1232
1233 auto transposedOp = rewriter.createlinalg::TransposeOp(loc, input, empty,
1234 srcPermForTranspose);
1235
1236
1237
1240
1242 oneIdxAttr);
1244
1245 for (auto tileSize : packOp.getMixedTiles()) {
1246 auto [tileSizeStatic, tileSizeOfr] =
1248 writeSizes.push_back(tileSizeOfr);
1249 writeShape.push_back(tileSizeStatic);
1250 }
1251
1252
1253 auto insert = rewriter.createtensor::InsertSliceOp(
1254 loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
1255 writeSizes, writeStrides);
1256 rewriter.replaceOp(packOp, insert.getResult());
1257
1258 return success();
1259 }
1260
1262 linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
1263 int64_t srcRank = unpackOp.getSourceRank();
1264 int64_t destRank = unpackOp.getDestRank();
1265 ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
1267 if (llvm::any_of(unpackOp.getTiledOuterDims(),
1268 [](int64_t dim) { return dim != 1; })) {
1270 unpackOp,
1271 "require the tiled outer dimensions of the result are all 1s");
1272 }
1273
1274
1275
1276 Location loc = unpackOp.getLoc();
1277 Value source = unpackOp.getSource();
1279 unpackOp.getDimAndTileMapping();
1282
1283
1284
1285
1287
1288
1289
1291
1294
1295
1296
1297
1298
1299
1301
1302 for (auto i : llvm::seq(0, destRank)) {
1303
1304
1305
1306
1307
1308
1309
1310
1311 if (dimAndTileMapping.count(i)) {
1312 extractSliceSizes.push_back(oneIdxAttr);
1313 continue;
1314 }
1315
1316
1317
1318 if (ShapedType::isDynamic(srcShape[i])) {
1320 rewriter.createtensor::DimOp(loc, source, i).getResult();
1321 extractSliceSizes.push_back(dynamicDim);
1322 shapeForEmptyOp.push_back(dynamicDim);
1323 } else {
1324 extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
1325 if (srcShape[i] != 1)
1326 shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
1327 }
1328
1329
1330 if (srcShape[i] != 1) {
1331 readShapeForExtractSlice.push_back(srcShape[i]);
1332 }
1333 }
1334
1335
1336 auto mixedTiles = unpackOp.getMixedTiles();
1337 extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
1338 shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());
1339
1340
1341
1342 auto tileShape = srcShape.drop_front(destRank);
1343
1344 readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
1345 Type elemType = unpackOp.getSourceType().getElementType();
1347 Value innerTile = rewriter.createtensor::ExtractSliceOp(
1348 loc, readType, unpackOp.getSource(), extractSliceOffsets,
1349 extractSliceSizes, extractSliceStrides);
1350
1351
1353 srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
1354
1356 applyPermutationToVector(shapeForEmptyOp, perm);
1357
1359 rewriter.createtensor::EmptyOp(loc, shapeForEmptyOp, elemType);
1360 auto transposedOp =
1361 rewriter.createlinalg::TransposeOp(loc, innerTile, empty, perm);
1362
1363
1364
1365 int numLoops = shapeForEmptyOp.size();
1369 ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
1370 for (auto i : llvm::seq(0, destRank)) {
1371 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1372 tileSizes.push_back(
1374 }
1375
1376 auto partialTile = rewriter.createtensor::ExtractSliceOp(
1377 loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);
1378
1379
1383 for (int i = 0, idx = 0; i < destRank; ++i) {
1384 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1385 writeSizes.push_back(tileSizes[idx++]);
1386 else
1387 writeSizes.push_back(oneIdxAttr);
1388 }
1389 auto insert = rewriter.createtensor::InsertSliceOp(
1390 loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes,
1391 writeStrides);
1392 rewriter.replaceOp(unpackOp, insert.getResult());
1393
1394 return success();
1395 }
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405 template <typename Conv2DOp, typename Conv1DOp>
1408 if (convOp.hasPureBufferSemantics())
1409 return failure();
1410
1411 Value input = convOp.getInputs().front();
1412 Value kernel = convOp.getInputs().back();
1413 Value output = convOp.getOutputs().front();
1414
1415 auto inputType = dyn_cast(input.getType());
1416 auto kernelType = dyn_cast(kernel.getType());
1417 auto outputType = dyn_cast(output.getType());
1418
1419 auto kernelShape = kernelType.getShape();
1420 auto outputShape = outputType.getShape();
1421
1422
1423 auto [khIndex, kwIndex, ohIndex, owIndex] =
1425 convOp)
1426 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
1427 return std::make_tuple(0, 1, 1, 2);
1428 })
1429 .Case([&](linalg::Conv2DNchwFchwOp op) {
1430 return std::make_tuple(2, 3, 2, 3);
1431 })
1432 .Case([&](linalg::PoolingNhwcSumOp op) {
1433 return std::make_tuple(0, 1, 1, 2);
1434 })
1435 .Case([&](linalg::PoolingNchwSumOp op) {
1436 return std::make_tuple(0, 1, 2, 3);
1437 })
1438 .Case([&](linalg::PoolingNhwcMaxOp op) {
1439 return std::make_tuple(0, 1, 1, 2);
1440 })
1441 .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
1442 return std::make_tuple(0, 1, 1, 2);
1443 })
1444 .Case([&](linalg::PoolingNhwcMinOp op) {
1445 return std::make_tuple(0, 1, 1, 2);
1446 })
1447 .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
1448 return std::make_tuple(0, 1, 1, 2);
1449 })
1450 .Case([&](linalg::PoolingNchwMaxOp op) {
1451 return std::make_tuple(0, 1, 2, 3);
1452 })
1454 llvm_unreachable("unexpected conv2d/pool2d operation.");
1455 return std::make_tuple(0, 0, 0, 0);
1456 });
1457
1458
1459
1460 int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
1461 int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
1462 bool removeH = (khSize == 1 && ohSize == 1);
1463 bool removeW = (kwSize == 1 && owSize == 1);
1464 if (!removeH && !removeW)
1465 return failure();
1466
1467
1468
1470 RankedTensorType newInputType =
1471 RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
1472 RankedTensorType newKernelType =
1473 RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
1474 RankedTensorType newOutputType =
1475 RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));
1476
1477
1478 Location loc = convOp.getLoc();
1480 rewriter, loc, input, newInputType);
1482 rewriter, loc, kernel, newKernelType);
1484 rewriter, loc, output, newOutputType);
1485
1486
1487
1488 auto strides =
1489 llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
1490 strides.erase(strides.begin() + (removeH ? 0 : 1));
1492
1493 auto dilations =
1494 llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
1495 dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1497
1498 auto conv1DOp = rewriter.create(
1499 loc, newOutputType, ValueRange{newInput, newKernel},
1500 ValueRange{newOutput}, stridesAttr, dilationsAttr);
1501
1502
1504 rewriter, loc, conv1DOp.getResult(0), output);
1505 rewriter.replaceOp(convOp, inserted);
1506
1507 return conv1DOp;
1508 }
1509
1511 Conv1DNwcWcfOp>;
1513 Conv1DNcwFcwOp>;
1515 PoolingNwcSumOp>;
1517 PoolingNcwSumOp>;
1519 PoolingNwcMaxOp>;
1521 PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
1523 PoolingNwcMinOp>;
1525 PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
1527 PoolingNcwMaxOp>;
1528
1529 FailureOr
1531 DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
1532 if (convOp.hasPureBufferSemantics())
1533 return failure();
1534
1535 Value input = convOp.getInputs().front();
1536 Value kernel = convOp.getInputs().back();
1537 Value output = convOp.getOutputs().front();
1538
1539 auto inputType = dyn_cast(input.getType());
1540 auto kernelType = dyn_cast(kernel.getType());
1541 auto outputType = dyn_cast(output.getType());
1542
1543 auto kernelShape = kernelType.getShape();
1544 auto outputShape = outputType.getShape();
1545
1546
1547
1548 int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1549 int64_t ohSize = outputShape[1], owSize = outputShape[2];
1550 bool removeH = (khSize == 1 && ohSize == 1);
1551 bool removeW = (kwSize == 1 && owSize == 1);
1552 if (!removeH && !removeW)
1553 return failure();
1554
1555
1556
1558 RankedTensorType newInputType =
1559 RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
1560 RankedTensorType newKernelType =
1561 RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1562 RankedTensorType newOutputType =
1563 RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
1564
1565
1566 Location loc = convOp.getLoc();
1568 rewriter, loc, input, newInputType);
1570 rewriter, loc, kernel, newKernelType);
1572 rewriter, loc, output, newOutputType);
1573
1574
1575
1576 auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
1577 strides.erase(strides.begin() + (removeH ? 0 : 1));
1579
1580 auto dilations =
1581 llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
1582 dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1584
1585 auto conv1DOp = rewriter.create(
1586 loc, newOutputType, ValueRange{newInput, newKernel},
1587 ValueRange{newOutput}, stridesAttr, dilationsAttr);
1588
1589
1591 rewriter, loc, conv1DOp.getResult(0), output);
1592 rewriter.replaceOp(convOp, inserted);
1593
1594 return conv1DOp;
1595 }
1596
1597 FailureOr
1600 if (convOp.hasPureBufferSemantics())
1601 return failure();
1602
1603 Value input = convOp.getInputs().front();
1604 Value kernel = convOp.getInputs().back();
1605 Value output = convOp.getOutputs().front();
1606
1607 auto inputType = dyn_cast(input.getType());
1608 auto kernelType = dyn_cast(kernel.getType());
1609 auto outputType = dyn_cast(output.getType());
1610
1611 auto kernelShape = kernelType.getShape();
1612 auto outputShape = outputType.getShape();
1613
1614
1615
1616 int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1617 int64_t ohSize = outputShape[0], owSize = outputShape[1];
1618 bool removeH = (khSize == 1 && ohSize == 1);
1619 bool removeW = (kwSize == 1 && owSize == 1);
1620 if (!removeH && !removeW)
1621 return failure();
1622
1623
1624
1626 RankedTensorType newInputType =
1627 RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
1628 RankedTensorType newKernelType =
1629 RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1630 RankedTensorType newOutputType =
1631 RTTBuilder(outputType).dropDim(removeH ? 0 : 1);
1632
1633
1634 Location loc = convOp.getLoc();
1636 rewriter, loc, input, newInputType);
1638 rewriter, loc, kernel, newKernelType);
1640 rewriter, loc, output, newOutputType);
1641
1642 auto conv1DOp = rewriter.create(loc, newOutputType,
1645
1646
1648 rewriter, loc, conv1DOp.getResult(0), output);
1649 rewriter.replaceOp(convOp, inserted);
1650
1651 return conv1DOp;
1652 }
1653
1657 Conv1DNwcWcfOp>,
1659 Conv1DNcwFcwOp>,
1661 patterns.getContext(), benefit);
1667 PoolingNwcMaxUnsignedOp>,
1670 PoolingNwcMinUnsignedOp>,
1672 patterns.getContext(), benefit);
1673 }
1674
1678 }
1679
1682 }
SmallVector< int64_t > outerDimsPerm
SmallVector< int64_t > innerDimsPos
static RankedTensorType permuteShape(RankedTensorType tensorType, ArrayRef< int64_t > permutationVector)
Return a copy of tensorType after permutation by permutationVector.
static SmallVector< int64_t > getPackUnpackRankReducedPerm(ArrayRef< int64_t > shape, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
static std::optional< int64_t > getFirstResultIndexFunctionOf(AffineMap map, int64_t dim)
Return the index of the first result of map that is a function of AffineDimExpr(dim),...
static FailureOr< SmallVector< std::optional< int64_t > > > packLinalgMetadataOnce(SmallVectorImpl< AffineMap > &indexingMaps, SmallVectorImpl< utils::IteratorType > &iteratorTypes, int64_t dim)
Perform one step of packing of a LinalgOp's metadata along dim into the newDim at iteratorTypes....
static LinalgOp transposeOneLinalgOperandAndReplace(RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand, ArrayRef< int64_t > permutation, Value transposedValue)
Return a new GenericOp obtained by transposing opOperand by the permutation vector:
static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim)
Return true if map has at most one result that is a function of AffineDimExpr(dim).
static SmallVector< int64_t > getPackUnpackNormalizedPerm(int rank, ArrayRef< int64_t > perm)
static Value getPackOpSourceOrPaddedSource(OpBuilder &builder, linalg::PackOp packOp)
If padding value is set, returns a tensor.pad Op for the source tensor, with the output shape matchin...
static std::string stringifyReassocIndices(ReassociationIndicesRef ri)
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
MLIRContext * getContext() const
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
AffineMap insertResult(AffineExpr expr, unsigned pos) const
Returns a new AffineMap with the same number of dims and symbols and an extra result inserted at pos.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
TypedAttr getZeroAttr(Type type)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
DenseIntElementsAttr getI64VectorAttr(ArrayRef< int64_t > values)
This is a utility class for mapping one set of IR entities to another.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Builder & setShape(ArrayRef< int64_t > newShape)
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Specialization of arith.constant op that returns an integer of index type.
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite unpack as empty + transpose + reshape + extract_slice.
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp)
Shell function to compute the Destination Permutation of PackOp. This function uses the helper functio...
void peelLoops(RewriterBase &rewriter, ArrayRef< scf::ForOp > loops)
Peel 'loops' and applies affine_min/max bounds simplification on the fly where relevant.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Linalg decompose convolutions patterns.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
SmallVector< Value > peelLoop(RewriterBase &rewriter, Operation *op)
Try to peel and canonicalize loop op and return the new result.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
FailureOr< TilingResult > bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, bool generateZeroSliceGuard=true)
Bubbles up a slice of this pad by taking the slice first and then performing the padding.
PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad, bool nofold, Location loc, OpBuilder &builder, SmallVector< Value > dynOutDims={})
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< int64_t > computePermutationVector(int64_t permSize, ArrayRef< int64_t > positions, ArrayRef< int64_t > desiredPositions)
Return a permutation vector of size permSize that would result in moving positions into desiredPositi...
ArrayRef< int64_t > ReassociationIndicesRef
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::pair< int64_t, OpFoldResult > getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b)
Given OpFoldResult representing dim size value (*), generates a pair of sizes:
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter &rewriter) const override
Rewrites a linalg::PackOp into a sequence of:
LogicalResult matchAndRewrite(linalg::PackOp packOp, PatternRewriter &rewriter) const override
Rewrites a linalg::UnPackOp into a sequence of rank-reduced.
LogicalResult matchAndRewrite(linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const override
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const override
Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp, Value dest, const SmallVector< Value > &dynSizes) const
Filling dest using FillOp constant padding value if possible.
FailureOr< Conv1DOp > returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
FailureOr< DepthwiseConv1DNwcWcOp > returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const
Rewrites 2-D convolution ops with size-1 window dimensions into 1-D convolution ops.
FailureOr< Conv1DOp > returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const
LinalgTilingOptions & setTileSizes(const SmallVector< Value, 4 > &ts)
Set the tileSizeComputationFunction to return the values ts.
Struct to hold the result of a pack call.
Struct to hold the result of a packTranspose call.