MLIR: lib/Dialect/SCF/Transforms/TileUsingInterface.cpp Source File
31 #include "llvm/ADT/ScopeExit.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Support/Debug.h"
34 #include <optional>
35
36 #define DEBUG_TYPE "tile-using-interface"
37
38 using namespace mlir;
39
43 auto tileSizes = llvm::to_vector(ts);
45 return tileSizes;
46 };
47 return *this;
48 }
49
52 assert(!numThreadsComputationFunction && "num tiles already set");
53 auto numThreads = llvm::to_vector(nt);
54 numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {
55 return numThreads;
56 };
57 return *this;
58 }
59
60
61
64 size_t iterationDomainSize) {
66 if (filledVector.size() < iterationDomainSize) {
67 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
68 filledVector.append(range.begin(), range.end());
69 }
70 if (filledVector.size() > iterationDomainSize)
71 filledVector.resize(iterationDomainSize);
72 return filledVector;
73 }
74
75
76
77
78
79
80 static LogicalResult
83
84 if (options.numThreadsComputationFunction &&
87 loc, "number of threads can only by specified when loop type is "
88 "set to use `scf.forall`");
89 }
90
91
92 if (!options.interchangeVector.empty()) {
95 loc, "invalid interchange vector, not a permutation of the entire "
96 "iteration space");
97 }
98 }
99 return success();
100 }
101
102
103
110 size_t numLoops = iterationDomain.size();
111
112
113 if (options.numThreadsComputationFunction) {
114 numThreads = options.numThreadsComputationFunction(rewriter, op);
115 numThreads.resize(numLoops, zero);
116
117
118 if (options.tileSizeComputationFunction) {
119 tileSizes = options.tileSizeComputationFunction(rewriter, op);
120 tileSizes.resize(numLoops, zero);
121 return {tileSizes, numThreads};
122 }
123
124
125
126
127
130
133 tileSizes.resize(numLoops, zero);
134 for (auto [index, range, nt] :
137 continue;
138
140 rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
141 }
142 tileSizes.resize(numLoops, zero);
143 return {tileSizes, numThreads};
144 }
145
146
147
148
149
150 assert(options.tileSizeComputationFunction &&
151 "expected tile sizes to be specified");
152 tileSizes = options.tileSizeComputationFunction(rewriter, op);
153 tileSizes.resize(numLoops, zero);
154
155 return {tileSizes, numThreads};
156 }
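// Illustrative note (not part of the upstream file): if the caller specifies
// only numThreads = {4, 8} for a 128x64 iteration domain, the helper above
// derives the per-thread tile sizes from each loop range via a ceiling
// division (roughly {32, 8} here); if only tile sizes are specified,
// numThreads stays empty and the given sizes are used directly.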
157
158
162 auto iterators = op.getLoopIteratorTypes();
163 assert(iterators.size() == tileSizes.size() &&
164 "expected as many tile size values as number of loops");
165 assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
166 "when specified, expected number of threads to use for each loop");
167
168 for (auto [index, iterator, tileSize] :
170
171
172 if (!numThreads.empty()) {
173 if (std::optional<int64_t> constNumThreads =
175 if (constNumThreads.value() > 1 &&
176 iterator != utils::IteratorType::parallel) {
177 op.emitWarning() << "tiling is not thread safe at axis #" << index;
178 }
179 }
180 continue;
181 }
182
183 if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) {
184 if (constTileSize.value() > 0 &&
185 iterator != utils::IteratorType::parallel) {
186 op.emitWarning() << "tiling is not thread safe at axis #" << index;
187 }
188 }
189 }
190 }
191
192
195 if (!offsetAsInt)
196 return false;
198 if (!sizeAsInt)
199 return false;
201 if (!strideAsInt)
202 return false;
203 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
204 }
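// Illustrative note (not part of the upstream file): for a Range with
// offset 0, size 100 and stride 20, (100 - 0) % 20 == 0, so the tile evenly
// divides the iteration domain; with stride 32 it does not, and the bounded
// tile-size computation below must emit an affine.min.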
205
206
207
212 if (ts && ts.value() == 1)
213 return tileSize;
214
217 return tileSize;
218
219
220
221
229 }
230
231
232
237 std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
238 std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
239 if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
240 return false;
241 return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
242 }
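// Illustrative note (not part of the upstream file): with tileSize = 10,
// numThreads = 4 and iterationSize = 40, the largest tile offset is
// 10 * (4 - 1) = 30 < 40, so every thread starts within bounds and the
// offset check can be omitted; with iterationSize = 25 it cannot.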
243
244
245
246
253 int materializedLoopNum = 0;
254
255 if (!numThreads.empty()) {
257 AffineExpr offsetExpr, residualTileSizeExpr;
260 offsetExpr = d0 + d1 * s0;
261 residualTileSizeExpr = s1 - (d0 + d1 * s0);
262
263 for (auto [nt, tileSize, loopRange] :
264 llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
265
266
267
269 offsets.push_back(loopRange.offset);
270 sizes.push_back(loopRange.size);
271 continue;
272 }
273
274 Value iv = ivs[materializedLoopNum++];
276 rewriter, loc, offsetExpr,
279 rewriter, loc, residualTileSizeExpr,
280 {loopRange.offset, nt, tileSize, loopRange.size});
281
286 {offset, loopRange.size});
288 rewriter, loc,
290 {sizeMinusOffsetPerThread, tileSize});
291 }
292
293
294
295
296
297
298
299
300
301
306 rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
307 }
308
309 offsets.push_back(offset);
310 sizes.push_back(size);
311 }
312 return {offsets, sizes};
313 } else {
314 for (auto [tileSize, loopRange] :
315 llvm::zip_equal(tileSizes, iterationDomain)) {
316
317
318
320 offsets.push_back(loopRange.offset);
321 sizes.push_back(loopRange.size);
322 continue;
323 }
324
325 Value iv = ivs[materializedLoopNum++];
327 offsets.push_back(offset);
330 sizes.push_back(size);
331 }
332 return {offsets, sizes};
333 }
334 }
335
336
342 for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
343
345 continue;
346 lbs.push_back(loopRange.offset);
347 ubs.push_back(loopRange.size);
348 steps.push_back(tileSize);
349 }
350 return {lbs, ubs, steps};
351 }
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
373
374
375
380 if (newDestArgs.empty())
381 return clonedOp;
382 if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
383 destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
384 return clonedOp;
385 }
386
387
388
389
390
391
392
393
394
395
401 assert(!loopRanges.empty() && "unexpected empty loop ranges");
402 assert(loopRanges.size() == tileSizes.size() &&
403 "expected as many tile sizes as loop ranges");
405
407 std::tie(lbs, ubs, steps) =
408 getLoopBounds(rewriter, loc, loopRanges, tileSizes);
415
417 for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
418 auto loop =
419 rewriter.createscf::ForOp(loc, lb, ub, step, destinationTensors,
422 loops.push_back(loop);
423 ivs.push_back(loop.getInductionVar());
425 destinationTensors = loop.getRegionIterArgs();
426 }
427
430 if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
431 tiledResults, resultOffsets, resultSizes))) {
433 loc, "failed to generate inner tile loop body");
434 }
435 if (loops.empty())
436 return success();
437
438 assert(tiledResults.size() == destinationTensors.size() &&
439 "Number of results of body should be equal to number of iter args");
440
441
443 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
444 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
445 resultSizes)) {
448 auto insertSlice = rewriter.createtensor::InsertSliceOp(
449 loc, tiledValue, destinationTensor, resultOffset, resultSize,
450 resultStride);
451 yieldedValues.push_back(insertSlice);
452 }
453 rewriter.createscf::YieldOp(loc, yieldedValues);
454
455
456 for (auto [outerLoop, innerLoop] :
460 castscf::ForOp(outerLoop.getOperation()).getBody());
461 rewriter.createscf::YieldOp(outerLoop.getLoc(), innerLoop->getResults());
462 }
463 return success();
464 }
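// Illustrative sketch (not part of the upstream file) of the nest generated
// above for two non-zero tile sizes, with results threaded through iter_args
// in destination-passing style:
//
//   %r = scf.for %iv0 = %lb0 to %ub0 step %step0 iter_args(%d0 = %dest) {
//     %inner = scf.for %iv1 = %lb1 to %ub1 step %step1 iter_args(%d1 = %d0) {
//       ... tiled computation ...
//       %ins = tensor.insert_slice %tile into %d1 [...]
//       scf.yield %ins
//     }
//     scf.yield %inner
//   }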
465
466
467
468
469
470
471
472
473
474
475
476
482 assert(!loopRanges.empty() && "unexpected empty loop ranges");
483 assert(loopRanges.size() == tileSizes.size() &&
484 "expected as many tile sizes as loop ranges");
486
487 std::optional<ArrayAttr> mappingAttr;
488 if (!mappingVector.empty())
489 mappingAttr = rewriter.getArrayAttr(mappingVector);
490
491 scf::ForallOp forallOp;
492 bool useNumThreads = !numThreads.empty();
493
494 if (useNumThreads) {
495
497 for (auto nt : numThreads) {
499 continue;
500 nonZeroNumThreads.push_back(nt);
501 }
502 forallOp = rewriter.createscf::ForallOp(loc, nonZeroNumThreads,
503 destinationTensors, mappingAttr);
504 } else {
506 std::tie(lbs, ubs, steps) =
507 getLoopBounds(rewriter, loc, loopRanges, tileSizes);
508 forallOp = rewriter.createscf::ForallOp(loc, lbs, ubs, steps,
509 destinationTensors, mappingAttr);
510 }
511 loops.push_back(forallOp);
512
514 destinationTensors = forallOp.getRegionOutArgs();
515
518 if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
519 destinationTensors, tiledResults, resultOffsets,
520 resultSizes)))
521 return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
522
524 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
525 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
526 resultSizes)) {
529
530 rewriter.createtensor::ParallelInsertSliceOp(
531 loc, tiledValue, destinationTensor, resultOffset, resultSize,
532 resultStride);
533 }
534 return success();
535 }
536
537
538
539
540
541
542
543
544
545
546
552
553
557 return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
558 tiledResults, resultOffsets, resultSizes);
559 }
562 destinationTensors, tiledBodyFn, loops);
563 }
566 rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector,
567 destinationTensors, tiledBodyFn, loops);
568 }
570 }
571
572 static FailureOr<SmallVector<Value>>
578 switch (options.reductionStrategy) {
581 return failure();
582 return initTensors;
585 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
586 if (!redOp) {
588 op, "PartialReductionOuterReduction tiling strategy is only supported"
589 "for operations implementing PartialReductionOpInterface");
590 }
591
592
593
595 for (auto [idx, iteratorType] :
597 if (iteratorType == utils::IteratorType::reduction)
598 reductionDims.push_back(idx);
599 }
600 return redOp.generateInitialTensorForPartialReduction(
601 rewriter, loc, tileSizes, reductionDims);
602 }
603 default:
605 "unhandled reduction tiling strategy");
606 }
607 }
608
609 static FailureOr<TilingResult>
614 switch (options.reductionStrategy) {
616 return op.getTiledImplementation(rewriter, offsets, sizes);
619 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
620 if (!redOp) {
622 op, "PartialReductionOuterReduction tiling strategy is only "
623 "supported for operations "
624 "implementing PartialReductionOpInterface");
625 }
626
627
628
630 for (auto [idx, iteratorType] :
632 if (iteratorType == utils::IteratorType::reduction)
633 reductionDims.push_back(idx);
634 }
635 return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg,
636 offsets, sizes, reductionDims);
637 }
638 default:
640 "unhandled reduction tiling strategy");
641 }
642 }
643
644 static LogicalResult
651
652 switch (options.reductionStrategy) {
654 return op.getResultTilePosition(rewriter, index, offsets, sizes,
655 resultOffset, resultSize);
658 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
659 if (!redOp) {
661 op, "PartialReductionOuterReduction tiling strategy is only supported"
662 "for operations implementing PartialReductionOpInterface");
663 }
664
665
666
668 for (auto [idx, iteratorType] :
670 if (iteratorType == utils::IteratorType::reduction)
671 reductionDims.push_back(idx);
672 }
673 return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
674 resultOffset, resultSize,
675 reductionDims);
676 }
677 default:
679 "unhandled reduction tiling strategy");
680 }
681 }
682
683 static FailureOr<MergeResult>
687 switch (options.reductionStrategy) {
689
693 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
694 if (!redOp) {
696 op, "PartialReductionOuterReduction tiling strategy is only "
697 "supported for operations "
698 "implementing PartialReductionOpInterface");
699 }
700
701
702
704 for (auto [idx, iteratorType] :
706 if (iteratorType == utils::IteratorType::reduction)
707 reductionDims.push_back(idx);
708 }
709 return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
710 reductionDims);
711 }
712 default:
714 "unhandled reduction tiling strategy");
715 }
716 }
717
718
719
720
721
722
723
724
725 template <typename LoopType>
726 FailureOr<LoopLikeOpInterface>
727 yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
728                                ValueRange newInitOperands,
729                                YieldTiledValuesFn yieldTiledValuesFn) {
731 }
732
733
734 template <>
735 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
739 Location loc = loopOp.getLoc();
741
742 auto inits = llvm::to_vector(loopOp.getInitArgs());
743 inits.append(newInitOperands.begin(), newInitOperands.end());
744 auto newLoop = rewriter.createscf::ForOp(
745 loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
747
748
749 Block *loopBody = loopOp.getBody();
750 Block *newLoopBody = newLoop.getBody();
752 loopBody, newLoopBody,
753 newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
754
755 auto yieldOp = castscf::YieldOp(newLoopBody->getTerminator());
757
761 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
762 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
763 newRegionIterArgs, tiledValues, resultOffsets,
764 resultSizes))) {
765 rewriter.eraseOp(newLoop);
766 return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
767 }
768
769 SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
770 for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
771 llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
772 resultSizes)) {
775 Value insert = rewriter.createtensor::InsertSliceOp(
776 yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
777 resultStride);
778 newYieldValues.push_back(insert);
779 }
780
783 newLoop->getResults().take_front(loopOp.getNumResults()));
784 return cast<LoopLikeOpInterface>(newLoop.getOperation());
785 }
786
787
788 template <>
789 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
793 Location loc = loopOp.getLoc();
795 auto inits = llvm::to_vector(loopOp.getOutputs());
796 inits.append(newInitOperands.begin(), newInitOperands.end());
797 auto newLoop = rewriter.createscf::ForallOp(
798 loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
799 loopOp.getMixedStep(), inits, loopOp.getMapping(),
801
802
803 Block *loopBody = loopOp.getBody();
804 Block *newLoopBody = newLoop.getBody();
806 loopBody, newLoopBody,
807 newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
808
809 auto terminator = castscf::InParallelOp(newLoopBody->getTerminator());
814 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
815 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
816 regionIterArgs, tiledValues, resultOffsets,
817 resultSizes))) {
818 rewriter.eraseOp(newLoop);
820 "failed to get yielded tiled values");
821 }
822
823
825
826 for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
827 tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
830 rewriter.createtensor::ParallelInsertSliceOp(
831 terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
832 resultStride);
833 }
834
836 newLoop->getResults().take_front(loopOp.getNumResults()));
837 return cast<LoopLikeOpInterface>(newLoop.getOperation());
838 }
839
840
841
842
844 LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
847 loopLikeOp.getOperation())
848 .Case<scf::ForOp, scf::ForallOp>(
849 [&](auto loopOp) -> FailureOr {
851 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
852 })
853 .Default([&](auto loopOp) -> FailureOr {
855 });
856 }
857
858
859
860
861
862
866 if (loops.empty())
867 return success();
870
872 for (auto &loop : loops.drop_back()) {
874
875
876 auto forLoop = castscf::ForOp(loop.getOperation());
877
878
879 SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
880 newInits.append(newInitValues.begin(), newInitValues.end());
881 auto newLoop = rewriter.createscf::ForOp(
882 forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
883 forLoop.getStep(), newInits,
885
886
888 sourceBlockArgs.push_back(newLoop.getInductionVar());
889 auto newRegionIterArgs = newLoop.getRegionIterArgs();
890 sourceBlockArgs.append(
891 newRegionIterArgs.begin(),
892 std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
893 rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
895 forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
896 loop = newLoop;
897 ivs.push_back(newLoop.getInductionVar());
898 newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
899 }
900
901
902 LoopLikeOpInterface innerMostLoop = loops.back();
903 FailureOr<LoopLikeOpInterface> newInnerMostLoop =
905 getNewTiledYieldsFn);
906
907 if (failed(newInnerMostLoop))
908 return innerMostLoop.emitOpError("failed to return additional yields");
909 loops.back() = newInnerMostLoop.value();
910
911
912
913 for (auto [outerLoop, innerLoop] :
914 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
915
916 auto outerForLoop = castscf::ForOp(outerLoop);
917 auto outerLoopYield =
918 castscf::YieldOp(outerForLoop.getBody()->getTerminator());
920 llvm::to_vector(outerLoopYield.getOperands());
922 innerLoop->getResults().take_back(newInitValues.size());
923 newYields.append(additionalYields.begin(), additionalYields.end());
925 rewriter.replaceOpWithNewOpscf::YieldOp(outerLoopYield, newYields);
926 }
927 return success();
928 }
929
930
931
932 FailureOrscf::SCFTilingResult
936 return failure();
937 }
938
941
942
943 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
944
945
947 std::tie(tileSizes, numThreads) =
949
950
951
954 }
955
956
957
959 if (!options.interchangeVector.empty()) {
961 iterationDomain.size());
963 "expected interchange vector to be a permutation");
964
967 if (!numThreads.empty())
969 }
970
971 FailureOr<TilingResult> tilingResult;
972
973
979 -> LogicalResult {
980
983 rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
984
985
986
987 if (!interchangeVector.empty()) {
991 }
992
993
994
995
996 auto clonedOp = cast<TilingInterface>(
998
999
1000
1001
1003 tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
1004 tilingResult =
1005 TilingResult{{clonedOp}, clonedOp->getResults(),
1006 {}};
1007 return success();
1008 }
1009
1010
1012 offsets, sizes, options);
1013 if (failed(tilingResult)) {
1014 rewriter.eraseOp(clonedOp);
1015 return op.emitOpError("faild to tile operation");
1016 }
1017
1018
1019 rewriter.eraseOp(clonedOp);
1020
1021
1022
1023 for (auto [index, tiledValue] :
1025 tiledResults.push_back(tiledValue);
1028 sizes, resultOffset, resultSize,
1030 for (auto op : tilingResult->tiledOps) {
1032 }
1034 op, "failed to get slice of result produced");
1035 }
1036 resultOffsets.emplace_back(std::move(resultOffset));
1037 resultSizes.emplace_back(std::move(resultSize));
1038 }
1039
1040 return success();
1041 };
1042
1043
1044 FailureOr<SmallVector<Value>> maybeInits =
1046 if (failed(maybeInits)) {
1048 op, "unable to create initial tensors for tiling");
1049 }
1051
1052
1055 tileSizes, numThreads, initTensors,
1056 innerYieldTiledValuesFn, loops)))
1057 return op.emitOpError("failed to generate tiling loops");
1058 assert(succeeded(tilingResult) &&
1059 "expected tiling result to be computed after loop generation");
1060
1061 if (loops.empty()) {
1062
1063
1065 initTensors,
1066 loops,
1067 tilingResult->tiledValues,
1068 tilingResult->generatedSlices,
1069 {}};
1070 }
1071
1072 auto loopResults = llvm::map_to_vector(loops.front()->getResults(),
1074
1075
1076 if (options.reductionStrategy ==
1079 tilingResult->tiledOps, initTensors, loops, loopResults,
1080 tilingResult->generatedSlices, {}};
1081 }
1082
1083
1084 FailureOr<MergeResult> mergeResult =
1086 if (failed(mergeResult)) {
1088 op, "Failed to merge partial results from tiling");
1089 }
1091 initTensors,
1092 loops,
1093 mergeResult->replacements,
1094 tilingResult->generatedSlices,
1095 mergeResult->mergeOps};
1096 }
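// Minimal usage sketch (not part of the upstream file): tile an op that
// implements TilingInterface with fixed tile sizes using the entry point
// above. The helper name and the tile sizes are hypothetical.
static FailureOr<scf::SCFTilingResult>
exampleTileWithFixedSizes(RewriterBase &rewriter, TilingInterface target) {
  scf::SCFTilingOptions tilingOptions;
  // Tile the first two loops by 32 and 64; a tile size of 0 would leave the
  // corresponding loop untiled.
  SmallVector<OpFoldResult> tileSizes = {rewriter.getIndexAttr(32),
                                         rewriter.getIndexAttr(64)};
  tilingOptions.setTileSizes(tileSizes);
  return scf::tileUsingSCF(rewriter, target, tilingOptions);
}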
1097
1098 FailureOrscf::SCFTilingResult
1100 PartialReductionOpInterface op,
1104 options.setReductionTilingStrategy(
1106 PartialReductionOuterReduction);
1107 options.setTileSizes(tileSize);
1109 }
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121 static std::tuple<OpResult, std::optional<OpOperand *>>
1124 std::optional<OpOperand *> destinationIterArg;
1125 assert(!loops.empty() && "expected non empty loops container");
1126 auto loopIt = loops.rbegin();
1127 while (loopIt != loops.rend() && isa<BlockArgument>(source->get())) {
1128 auto iterArg = cast<BlockArgument>(source->get());
1129 auto loop = *loopIt;
1130 if (iterArg.getOwner()->getParentOp() != loop)
1131 break;
1132 source = loop.getTiedLoopInit(iterArg);
1133 loopIt++;
1134 }
1135 if (loopIt == loops.rend())
1136 destinationIterArg = source;
1137 return {dyn_cast<OpResult>(source->get()), destinationIterArg};
1138 }
1139
1140
1141
1142 std::optionalscf::SCFFuseProducerOfSliceResult
1144 RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1146
1147
1148 auto [fusableProducer, destinationInitArg] =
1150 loops);
1151 if (!fusableProducer)
1152 return std::nullopt;
1153 unsigned resultNumber = fusableProducer.getResultNumber();
1154
1157
1158
1159
1160 SmallVector origDestinationTensors, clonedOpDestinationTensors;
1161 Operation *fusableProducerOp = fusableProducer.getOwner();
1162 if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1164 rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
1165 origDestinationTensors)))
1166 return std::nullopt;
1167
1168 clonedOpDestinationTensors = origDestinationTensors;
1169 if (destinationInitArg &&
1170 isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1171
1172
1173
1174 clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1175 }
1176
1178 rewriter, fusableProducerOp, clonedOpDestinationTensors);
1179
1180
1181
1183 llvm::to_vector(candidateSliceOp->getOperands());
1184 candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
1185 tensor::ExtractSliceOp clonedCandidateSliceOp =
1187 candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1188
1189
1190 FailureOr<TilingResult> tileAndFuseResult =
1192 rewriter, clonedCandidateSliceOp,
1193 clonedProducerOp->getResult(resultNumber));
1194 if (failed(tileAndFuseResult))
1195 return std::nullopt;
1196
1197
1199 tileAndFuseResult->tiledValues[0]);
1200 rewriter.eraseOp(clonedCandidateSliceOp);
1201 rewriter.eraseOp(clonedProducerOp);
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246 if (destinationInitArg &&
1247 isa(fusableProducerOp) && !loops.empty()) {
1248 loops.front()
1249 ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1250 .set(origDestinationTensors[resultNumber]);
1251 }
1253 fusableProducer, tileAndFuseResult->tiledValues[0],
1254 tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
1255 }
1256
1257
1259 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1263 if (loops.empty())
1264 return success();
1265
1267 *tiledOwner = fusedProducerInfo.tiledOps[0];
1268
1270
1272 yieldResultNumber.empty() ? llvm::to_vector(llvm::seq(
1274 : llvm::to_vector(yieldResultNumber);
1276 for (const auto &resultNumber : initNumberList) {
1278 rewriter, loc, originalOwner->getResult(resultNumber));
1279 if (succeeded(initValue)) {
1280 initValueList.push_back(initValue.value());
1281 } else {
1282 return failure();
1283 }
1284 }
1285
1293
1294
1296 sliceSizes = sliceOp.getMixedSizes();
1297
1298
1299 if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
1300 return failure();
1301
1302 unsigned sliceResultNumber =
1304
1305 auto tilableOp = cast<TilingInterface>(originalOwner);
1306
1308
1309 if (tilableOp->getNumResults() > 1 &&
1310 failed(tilableOp.getIterationDomainTileFromResultTile(
1311 rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1312 iterDomainOffset, iterDomainSizes))) {
1313
1314
1315
1316
1317
1318
1319
1320
1321 return failure();
1322 }
1323
1324
1325
1327 for (const auto &resultNumber : initNumberList) {
1328 if (resultNumber == sliceResultNumber) {
1329 offsetList.push_back(sliceOffset);
1330 sizesList.push_back(sliceSizes);
1331 } else {
1332 assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1333
1335 if (failed(tilableOp.getResultTilePosition(
1336 rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1337 offset, sizes))) {
1338 return failure();
1339 }
1340 offsetList.push_back(offset);
1341 sizesList.push_back(sizes);
1342 }
1343 }
1344
1345
1346
1347 if (auto tiledDestStyleOp =
1348 dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1350 for (const auto &&[index, newRegionArg] :
1352 auto destSlice = rewriter.createtensor::ExtractSliceOp(
1353 loc, newRegionArg, offsetList[index], sizesList[index],
1356 generatedSlices.push_back(destSlice);
1357 unsigned resultNumber = initNumberList[index];
1359 tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1360 });
1361 }
1362 }
1363
1364
1365
1368 for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
1369 tiledResult.push_back(tiledOwner->getResult(resultNumber));
1370 tiledOffset.emplace_back(offsetList[index]);
1371 tiledSizes.emplace_back(sizesList[index]);
1372 }
1373 return success();
1374 };
1375
1377 newYieldValuesFn))) {
1378 return failure();
1379 }
1380 return generatedSlices;
1381 }
1382
1383 namespace {
1384
1385
1386
1387
1388
1389
1390
1391
1393 public:
1394 explicit SliceTrackingListener(
1395 std::optional<FrozenRewritePatternSet> patterns);
1396 SliceTrackingListener() = default;
1397
1398
1399
1400
1401
1403
1404
1405 void notifyOperationInserted(Operation *op,
1407
1408
1410
1411
1412 void notifyOperationErased(Operation *op) override;
1413
1414
1415 void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
1416
1417
1418
1419 std::deque<tensor::ExtractSliceOp> worklist;
1420
1421 private:
1422
1423
1424 std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1425 };
1426
1427 SliceTrackingListener::SliceTrackingListener(
1428 std::optional<FrozenRewritePatternSet> p) {
1430 }
1431
1432 LogicalResult
1435 if (auto slice = dyn_casttensor::ExtractSliceOp(op))
1436 worklist.push_back(slice);
1437 }
1438
1440 return success();
1441
1446 }
1447
1448 void SliceTrackingListener::notifyOperationInserted(
1450 auto slice = dyn_casttensor::ExtractSliceOp(op);
1451 if (!slice)
1452 return;
1453 worklist.push_back(slice);
1454 }
1455
1456
1457
1458
1459 void SliceTrackingListener::removeOp(Operation *op) {
1460 if (!isatensor::ExtractSliceOp(op))
1461 return;
1462 auto iter = worklist.begin();
1463 while (iter != worklist.end()) {
1464 if (*iter == op)
1465 break;
1466 iter++;
1467 }
1468 if (iter == worklist.end())
1469 return;
1470
1471 worklist.erase(iter);
1472 }
1473
1474 void SliceTrackingListener::notifyOperationErased(Operation *op) {
1475 removeOp(op);
1476 }
1477
1478 void SliceTrackingListener::notifyOperationReplaced(Operation *op,
1480 removeOp(op);
1481 }
1482
1483
1484
1485
1486
1487
1488
1489
1491 public:
1494 : ForwardingListener(listener), replacements(replacements) {}
1495
1496 void updateReplacementValues(ValueRange origValues,
1498
1499
1500 for (auto &[key, val] : replacements) {
1501 for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {
1502 if (val == orig) {
1503 val = replace;
1504 }
1505 }
1506 }
1507 }
1508
1509 void notifyOperationReplaced(Operation *op, Operation *newOp) override {
1510 ForwardingListener::notifyOperationReplaced(op, newOp);
1512 }
1513
1514 void notifyOperationReplaced(Operation *op, ValueRange values) override {
1515 ForwardingListener::notifyOperationReplaced(op, values);
1516 updateReplacementValues(op->getResults(), values);
1517 }
1518
1519 private:
1520 DenseMap<Value, Value> &replacements;
1521 };
1522
1523 }
1524
1525
1526 FailureOrscf::SCFTileAndFuseResult
1528 RewriterBase &rewriter, TilingInterface consumer,
1530
1531
1532 if (!consumer->getNumResults()) {
1534 consumer, "invalid pattern for op with no results");
1535 }
1536
1537
1539
1540 FailureOrscf::SCFTilingResult tilingResult =
1542
1543 if (failed(tilingResult))
1544 return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
1545 tiledAndFusedOps.insert_range(tilingResult->tiledOps);
1546
1548 for (auto [origVal, replacement] :
1549 llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1550 replacements[origVal] = replacement;
1551 }
1552
1553
1554 auto &loops = tilingResult->loops;
1555 if (loops.empty()) {
1557 replacements};
1558 }
1559
1560
1561
1562
1564 auto resetListener =
1565 llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
1566 ReplacementListener replaceListener(replacements, previousListener);
1568
1569
1570
1571
1572
1573
1574
1575
1576 struct WorklistItem {
1577 tensor::ExtractSliceOp candidateSlice;
1578 SCFTileAndFuseOptions::ControlFnResult controlFnResult;
1579 };
1580
1581 SliceTrackingListener sliceTracker =
1582 SliceTrackingListener(options.cleanupPatterns);
1583
1584 if (failed(
1585 sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1586 return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1587 }
1589 while (!sliceTracker.worklist.empty()) {
1590 auto candidateSlice = sliceTracker.worklist.front();
1591 sliceTracker.worklist.pop_front();
1592
1593 auto [fusableProducer, destinationInitArg] =
1595 loops);
1596 if (!fusableProducer)
1597 continue;
1598
1599 std::optionalSCFTileAndFuseOptions::ControlFnResult controlFnResult =
1600 options.fusionControlFn(candidateSlice, fusableProducer,
1601 destinationInitArg.has_value());
1602 if (!controlFnResult)
1603 continue;
1604
1605 WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
1606
1607
1608
1609
1610 std::optionalscf::SCFFuseProducerOfSliceResult fusedResult =
1612 loops);
1613 if (!fusedResult)
1614 continue;
1615
1617
1618 if (worklistItem.controlFnResult.yieldProducerReplacement) {
1619
1620
1621
1622
1623 Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1624 FailureOr<SmallVector<Operation *>> newSlices =
1626 worklistItem.candidateSlice,
1627 fusedResult.value(), loops);
1628 if (failed(newSlices)) {
1630 fusableProducerOp, "failed to replacement value for this "
1631 "operation from within the tiled loop");
1632 }
1633 worklistCandidates.append(newSlices.value());
1634 for (auto [index, result] :
1636 replacements[result] = loops.front()->getResult(
1637 loops.front()->getNumResults() -
1639 }
1640 }
1641 if (Operation *tiledAndFusedOp =
1642 fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1643 fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1644 tiledAndFusedOps.insert(tiledAndFusedOp);
1645 }
1646
1647 if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1648 return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1649 }
1650 }
1651
1653 replacements};
1654 }
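// Minimal usage sketch (not part of the upstream file): tile a consumer and
// greedily fuse its slice-defined producers. The helper name and tile sizes
// are hypothetical, and setTilingOptions is assumed to be the way the
// consumer tiling options are attached to SCFTileAndFuseOptions.
static FailureOr<scf::SCFTileAndFuseResult>
exampleTileAndFuse(RewriterBase &rewriter, TilingInterface consumer) {
  SmallVector<OpFoldResult> tileSizes = {rewriter.getIndexAttr(16)};
  scf::SCFTilingOptions tilingOptions;
  tilingOptions.setTileSizes(tileSizes);
  scf::SCFTileAndFuseOptions tileAndFuseOptions;
  tileAndFuseOptions.setTilingOptions(tilingOptions);
  return scf::tileConsumerAndFuseProducersUsingSCF(rewriter, consumer,
                                                   tileAndFuseOptions);
}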
1655
1656
1657
1658
1659
1660
1661
1662 static LogicalResult
1664 Value result = candidateSliceOp.getResult();
1666 if (!llvm::hasSingleElement(uses)) {
1667 LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
1668 return failure();
1669 }
1670 OpOperand &operandUse = (*uses.begin());
1672 if (!isascf::YieldOp(userOp)) {
1673 LLVM_DEBUG(llvm::dbgs()
1674 << "Expected scf.yield to be the only user, but got -> "
1675 << (*userOp));
1676 return failure();
1677 }
1679 LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
1680 "be in the same block\n");
1681 return failure();
1682 }
1683 return success();
1684 }
1685
1686
1687
1689 if (!isa<LoopLikeOpInterface>(loopOp))
1690 return failure();
1691 Operation *firstUserOfLoop = nullptr;
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708 if (isatensor::ParallelInsertSliceOp(userOp))
1710
1711 if (loopOp->getBlock() != userOp->getBlock())
1712 return failure();
1713
1714 if (!firstUserOfLoop || userOp->isBeforeInBlock(firstUserOfLoop))
1715 firstUserOfLoop = userOp;
1716 }
1717 return firstUserOfLoop;
1718 }
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758 static FailureOr<llvm::SetVector<Operation *>>
1760 bool reorderOperations) {
1761 FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1762 if (failed(firstUserOfLoop))
1763 return failure();
1764
1767 options.inclusive = true;
1768 options.omitBlockArguments = true;
1769 bool includeLoopOp = false;
1771 if (op == loopOp) {
1772 includeLoopOp = true;
1773 return false;
1774 }
1775
1776
1778 };
1780 for (auto operand : consumerOp->getOperands()) {
1782 assert(result.succeeded() && "expected a backward slice");
1783 (void)result;
1784 }
1785
1786 if (!slice.empty()) {
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796 if (includeLoopOp || !reorderOperations)
1797 return failure();
1798 }
1799
1800 return slice;
1801 }
1802
1803
1804
1805
1808 unsigned resultNumber) {
1809 if (!isa<LoopLikeOpInterface>(loopOp))
1810 return failure();
1814 Operation *consumerOp = opOperand.getOwner();
1815
1816 if (!isa<TilingInterface>(consumerOp) ||
1817 !isa<DestinationStyleOpInterface>(consumerOp)) {
1818
1819
1820
1821 continue;
1822 }
1823
1824 if (loopBlock != consumerOp->getBlock())
1825 continue;
1826
1827
1829 continue;
1830
1831 FailureOr<llvm::SetVector<Operation *>> slice =
1833 if (failed(slice))
1834 continue;
1835
1836
1837 if (!slice->empty()) {
1839 FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1840 assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
1841 for (auto op : *slice) {
1842 rewriter.moveOpBefore(op, *firstUserOfLoop);
1843 }
1844 }
1845 return &opOperand;
1846 }
1847 return failure();
1848 }
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863 static bool
1865 assert(!loops.empty() && "unexpected empty loop nest");
1866 if (loops.size() == 1) {
1867 return isa_and_nonnullscf::ForOp(loops.front().getOperation());
1868 }
1869 for (auto [outerLoop, innerLoop] :
1870 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
1871 auto outerFor = dyn_cast_or_nullscf::ForOp(outerLoop.getOperation());
1872 auto innerFor = dyn_cast_or_nullscf::ForOp(innerLoop.getOperation());
1873 if (!outerFor || !innerFor) {
1874 return false;
1875 }
1876 auto outerBBArgs = outerFor.getRegionIterArgs();
1877 auto innerIterArgs = innerFor.getInitArgs();
1878 if (outerBBArgs.size() != innerIterArgs.size()) {
1879 return false;
1880 }
1881
1882 for (auto [outerBBArg, innerIterArg] :
1883 llvm::zip_equal(outerBBArgs, innerIterArgs)) {
1884 if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
1885 innerIterArg != outerBBArg) {
1886 return false;
1887 }
1888 }
1889
1891 castscf::YieldOp(outerFor.getBody()->getTerminator())->getOperands();
1892 ValueRange innerResults = innerFor.getResults();
1893 if (outerYields.size() != innerResults.size()) {
1894 return false;
1895 }
1896 for (auto [outerYield, innerResult] :
1897 llvm::zip_equal(outerYields, innerResults)) {
1898 if (!llvm::hasSingleElement(innerResult.getUses()) ||
1899 outerYield != innerResult) {
1900 return false;
1901 }
1902 }
1903 }
1904 return true;
1905 }
1906
1907
1908
1909
1910
1911
1912
1913 static FailureOr<OpOperand *>
1915 tensor::InsertSliceOp candidateSliceOp,
1917 assert(!loops.empty() && "unexpected loops to be empty");
1918
1920 if (containingOp != loops.back()) {
1922 candidateSliceOp,
1923 "expected slice to be within body of inner-most loop");
1924 }
1925
1926
1929 candidateSliceOp, "expected passed loops to be perfectly nested.");
1930 }
1931
1933 return failure();
1934 Value sliceResult = candidateSliceOp.getResult();
1935
1936
1937 OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
1939
1940 scf::ForOp topLevelForOp = castscf::ForOp(loops.front().getOperation());
1941
1943 }
1944
1945
1946
1947 static FailureOr<OpOperand *>
1949 tensor::ParallelInsertSliceOp candidateSliceOp,
1951 assert(!loops.empty() && "unexpected loops to be empty");
1952
1953 if (loops.size() != 1) {
1955 candidateSliceOp, "expected single surrounding scf.forall");
1956 }
1957 auto forallOp = dyn_castscf::ForallOp(loops.front().getOperation());
1958 if (!forallOp) {
1960 candidateSliceOp, "expected single surrounding scf.forall");
1961 }
1962
1963
1964 Value sliceDest = candidateSliceOp.getDest();
1965 auto iterArg = dyn_cast<BlockArgument>(sliceDest);
1966 if (!iterArg)
1967 return failure();
1968 if (iterArg.getOwner()->getParentOp() != forallOp)
1969 return failure();
1970
1971 unsigned resultNumber =
1972 forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
1973 .getResultNumber();
1974
1976 }
1977
1978
1979
1980 static FailureOr<OpOperand *>
1983 assert(!loops.empty() && "unexpected empty loops");
1984 if (auto insertSlice = dyn_casttensor::InsertSliceOp(sliceOp)) {
1986 } else if (auto parallelInsertSlice =
1987 dyn_casttensor::ParallelInsertSliceOp(sliceOp)) {
1989 } else {
1990 return failure();
1991 }
1992 }
1993
1994
1995
1996 FailureOrscf::SCFFuseConsumerOfSliceResult
2000
2001
2002 if (loops.empty()) {
2004 "cannot call tile and fuse consumer with an empty loop nest");
2005 }
2006 if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2007 candidateSliceOp))
2008 return failure();
2009
2010
2011
2012 FailureOr<OpOperand *> maybeConsumerOpOperand =
2014 if (failed(maybeConsumerOpOperand)) {
2016 "could not fetch consumer to fuse");
2017 }
2018 OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
2020 unsigned operandNumber = consumerOpOperand->getOperandNumber();
2021 unsigned resultNumber = 0;
2022 if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
2023 resultNumber = producerResult.getResultNumber();
2024 } else {
2026 consumerOp, "consumer op's operand doesn't seem to be an OpResult");
2027 }
2028
2029 LoopLikeOpInterface outerMostLoop = loops.front();
2030 LoopLikeOpInterface innerMostLoop = loops.back();
2031
2032
2035 outerMostLoop, "the first user of loop should not dominate any define "
2036 "of consumer operand(s)");
2037 }
2038
2040
2041
2042 auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
2043 if (!dstOp)
2045 "consumer op is not DPS operation");
2047 llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
2048 if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
2050 consumerOp,
2051 "consumer op taking the result of scf.for as init is not supported");
2052 }
2054
2055 Location loc = outerMostLoop->getLoc();
2056
2057
2058
2059 FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(outerMostLoop);
2060 if (failed(firstUserOfLoop)) {
2062 outerMostLoop, "could not find the first user of outer most loop");
2063 }
2064 rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);
2065
2066
2067
2068
2069
2070 tensor::InsertSliceOp clonedInsertSliceOp;
2071 if (auto sliceOp =
2072 dyn_casttensor::ParallelInsertSliceOp(candidateSliceOp)) {
2073 auto newForallOp = castscf::ForallOp(innerMostLoop.getOperation());
2075 clonedInsertSliceOp = rewriter.createtensor::InsertSliceOp(
2076 loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
2077 sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
2078 } else {
2080 clonedInsertSliceOp =
2081 casttensor::InsertSliceOp(rewriter.clone(*candidateSliceOp));
2082 }
2083
2084
2085 auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
2086
2087
2088
2089 OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
2091 operandToReplace.set(clonedInsertSliceOp.getResult());
2092 });
2093
2094
2095
2096 auto ossSliceOp =
2097 cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
2098 FailureOr<TilingResult> tileAndFuseResult =
2100 rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
2101 if (failed(tileAndFuseResult)) {
2102 return failure();
2103 }
2104 auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
2105 rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
2106 clonedInsertSliceOp.getSource());
2107
2108
2115
2117
2121
2122
2125 candidateSliceOp, "containingOp's result yield with stride");
2126 }
2127
2128
2129
2130
2131
2132
2133
2135 if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
2136 rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
2137 iterDomainSizes))) {
2139 clonedConsumerOp,
2140 "can't get iter domain position from input position");
2141 }
2142
2143
2144
2145
2146 unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
2148 totalNumResultsOfConsumer);
2150 totalNumResultsOfConsumer);
2151 for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
2152 if (failed(tiledConsumerOp.getResultTilePosition(
2153 rewriter, idx, iterDomainOffsets, iterDomainSizes,
2154 resultOffsets[idx], resultSizes[idx]))) {
2156 tiledConsumerOp,
2157 "can't get result domain position from iter domain position");
2158 }
2159 }
2160
2161
2162
2163 if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2164 tiledConsumerOp.getOperation())) {
2166 for (const auto &&[index, newRegionArg] :
2168 auto destSlice = rewriter.createtensor::ExtractSliceOp(
2169 loc, newRegionArg, resultOffsets[index], resultSizes[index],
2172
2173
2174 auto dstNumber = index;
2176 tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2177 });
2178 }
2179 }
2180
2181
2182
2185 for (const auto &&[index, result] :
2187 tiledResult.push_back(result);
2188 tiledOffset.emplace_back(resultOffsets[index]);
2189 tiledSizes.emplace_back(resultSizes[index]);
2190 }
2191 return success();
2192 };
2193
2195 newYieldValuesFn))) {
2197 "unable to add new inits to nest loop");
2198 }
2199
2200
2201
2202
2203 for (auto &&[oldResult, newResult] :
2205 loops.front()->getResults().take_back(newInits.size()))) {
2207 }
2208
2209
2210 rewriter.eraseOp(clonedConsumerOp);
2211
2213 consumerOpOperand,
2214 &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
2215 tileAndFuseResult->tiledOps};
2216 }
2217
2218
2219
2220
2221
2222 FailureOr<SmallVectorscf::ForOp>
2224 TilingInterface op) {
2225
2226 if (op->getNumResults() > 0) {
2228 op, "unable to lower to loops operations with return values");
2229 }
2230
2235 for (auto loopRange : domain) {
2236 Value offsetVal =
2240 Value strideVal =
2242 auto loop = rewriter.createscf::ForOp(op.getLoc(), offsetVal, sizeVal,
2244 loops.push_back(loop);
2245 ivs.push_back(loop.getInductionVar());
2247 }
2248 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
2249 return failure();
2250 }
2251 return loops;
2252 }
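// Minimal usage sketch (not part of the upstream file): lower an op that
// implements TilingInterface and has no results all the way to scalar loops.
//
//   FailureOr<SmallVector<scf::ForOp>> loops =
//       scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
//   if (failed(loops))
//     return failure();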