MLIR: lib/Dialect/Linalg/Transforms/HoistPadding.cpp Source File
//===- HoistPadding.cpp - Hoisting for tensor::PadOp ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-Exceptions
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with hoisting padding operations.
//
//===----------------------------------------------------------------------===//

#include "llvm/Support/Debug.h"

using llvm::dbgs;

#define DEBUG_TYPE "hoist-padding"

#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::linalg::detail;

#ifndef NDEBUG
static bool debugPrintLoopInShortForm(Operation *op) {
  AsmState state(op->getParentOfType<func::FuncOp>());
  (void)state;
  if (auto forOp = dyn_cast<scf::ForOp>(op)) {
    forOp.getInductionVar().printAsOperand(dbgs(), state);
    dbgs() << " @ " << forOp.getOperation();
    return true;
  }
  return false;
}
#endif
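
// Example (illustrative, not in the original source): for a loop
// `scf.for %arg4 = %c0 to %c64 step %c8`, debugPrintLoopInShortForm prints
// "%arg4 @ <operation address>" rather than the full loop body, keeping the
// debug dumps below readable.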

static void debugPrintBackwardSlice(SetVector<Operation *> &backwardSlice) {
  LLVM_DEBUG(llvm::interleaveComma(backwardSlice, DBGS() << "--backwardSlice:",
                                   [](Operation *op) {
                                     dbgs() << "\n";
                                     DBGS() << "----";
                                     if (debugPrintLoopInShortForm(op)) {
                                       dbgs() << "\n";
                                       return;
                                     }
                                     dbgs() << *op << "\n";
                                   });
             DBGS() << "\n";);
}

/// Return at most nLevels of immediately enclosing scf::ForOp loops.
/// Stops at the first parent that is not an scf::ForOp.
/// Multi-loops such as scf.parallel or linalg.tiled_loop are not modeled atm.
/// Control-flow and other containing ops with regions are not modeled atm.
static void
getAtMostNEnclosingLoops(tensor::PadOp padOp, int nLevels,
                         SmallVector<scf::ForOp> &reverseEnclosingLoops) {
  scf::ForOp outermostEnclosingForOp = nullptr;
  Operation *nextEnclosingOp = padOp->getParentOp();
  while (nLevels-- > 0 &&
         (outermostEnclosingForOp = dyn_cast<scf::ForOp>(nextEnclosingOp))) {
    LLVM_DEBUG(DBGS() << "loops: ";
               debugPrintLoopInShortForm(outermostEnclosingForOp);
               dbgs() << "\n");
    reverseEnclosingLoops.push_back(outermostEnclosingForOp);
    nextEnclosingOp = outermostEnclosingForOp->getParentOp();
  }
}
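
// Illustrative example (not in the original source): for
// `scf.for %i { scf.for %j { scf.for %k { tensor.pad } } }` and nLevels = 2,
// reverseEnclosingLoops ends up as [loop %k, loop %j], i.e. innermost first,
// which is why callers read the outermost candidate loop from back().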

/// Return the immediately enclosing scf::ForOp loops from `padOp` up to and
/// including `untilLoop`. Stops at the first parent that is not an scf::ForOp.
static void
getEnclosingLoopsUntil(tensor::PadOp padOp, scf::ForOp untilLoop,
                       SmallVector<scf::ForOp> &reverseEnclosingLoops) {
  scf::ForOp outermostEnclosingForOp = nullptr;
  Operation *nextEnclosingOp = padOp->getParentOp();
  while (outermostEnclosingForOp != untilLoop &&
         (outermostEnclosingForOp = dyn_cast<scf::ForOp>(nextEnclosingOp))) {
    LLVM_DEBUG(DBGS() << "loops: ";
               debugPrintLoopInShortForm(outermostEnclosingForOp);
               dbgs() << "\n");
    reverseEnclosingLoops.push_back(outermostEnclosingForOp);
    nextEnclosingOp = outermostEnclosingForOp->getParentOp();
  }
}

// Get all the ops in the backward slice starting from `padOp` and that
// are dominated by the outermost enclosing loop.
// This also requires tracking the backward slice of padOp's operands.
static void computeBackwardSlice(tensor::PadOp padOp,
                                 scf::ForOp outermostEnclosingForOp,
                                 SetVector<Operation *> &backwardSlice) {
  DominanceInfo domInfo(outermostEnclosingForOp);
  BackwardSliceOptions sliceOptions;
  sliceOptions.inclusive = true;
  sliceOptions.filter = [&](Operation *op) {
    return domInfo.dominates(outermostEnclosingForOp, op) &&
           !padOp->isProperAncestor(op);
  };

  // First, add the ops required to compute the pad region to the slice.
  SetVector<Value> valuesDefinedAbove;
  getUsedValuesDefinedAbove(padOp.getRegion(), padOp.getRegion(),
                            valuesDefinedAbove);
  for (Value v : valuesDefinedAbove) {
    LogicalResult result = getBackwardSlice(v, &backwardSlice, sliceOptions);
    assert(result.succeeded() && "expected a backward slice");
    (void)result;
  }

  // Then, add the backward slice from padOp itself.
  LogicalResult result =
      getBackwardSlice(padOp.getOperation(), &backwardSlice, sliceOptions);
  assert(result.succeeded() && "expected a backward slice");
  (void)result;
}
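
// Note (illustrative, not in the original source): the filter keeps only ops
// dominated by the loop we hoist above and not nested inside the pad itself,
// so the slice is exactly the set of ops that may need to be cloned into the
// packing loop nest built further below.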

//===----------------------------------------------------------------------===//
// HoistPaddingAnalysis Implementation.
//===----------------------------------------------------------------------===//

namespace {

/// Analysis class to support tensor::PadOp hoisting across multiple enclosing
/// loops. The failure conditions are:
///   1. Pad op has a use that is not an input of a LinalgOp.
///   2. Pad op does not have a constant padding value.
///   3. There is no immediately enclosing scf::ForOp.
///   4. The backward slice from the pad op to the scf::ForOp to hoist above
///      contains an unknown op with non index type operands, a region, or a
///      memory effect.
///   5. The backward slice from the pad op to the scf::ForOp to hoist above is
///      empty.
///   6. The source tensor of pad op is not defined by an extract slice op.
///   7. The source tensor of the extract slice op is not defined outside of
///      the outermost enclosing scf::ForOp.
///   8. There is no enclosing scf::ForOp that indexes the padded data.
/// Other cases succeed and will trigger hoisting of the pad op.
struct HoistPaddingAnalysis {
  HoistPaddingAnalysis(tensor::PadOp padOp, int numLoops);
  HoistPaddingAnalysis(tensor::PadOp padOp, scf::ForOp outermostEnclosingForOp);

  bool isValid() { return valid.has_value() && valid.value(); }
  bool isInvalid() { return valid.has_value() && !valid.value(); }

  /// Footprint of the hoistedPackedTensor, computed from the packingLoops.
  SmallVector<Value> getHoistedPackedTensorSizes(RewriterBase &rewriter,
                                                 Location loc) const;

  /// Performs optional hoisting to enable hoist padding to occur. This may be
  /// necessary when `sliceOp` is not defined outside of the outermost
  /// enclosing loop we want to hoist above.
  ///
  /// Example:
  /// ```
  /// %source = linalg.fill(%cst, %arg0)
  /// // %source is available for packing here!
  /// scf.for %i
  ///   scf.for %j
  ///     scf.for %k
  ///       %slice = tensor.extract_slice %source [%i, %j]
  ///       %padded_slice = tensor.pad %slice
  /// ```
  void enableHoistPadding(RewriterBase &rewriter);

  /// Performs the analysis on the loop nest determined by the constructor and
  /// sets `valid` accordingly: computes the backward slice, drops the
  /// dependencies not used for indexing, and collects the packing loops.
  void finalizeHoistPaddingAnalysis();

private:
  /// Result of the analysis: std::nullopt if the analysis has not run yet,
  /// true if hoisting is possible, false otherwise.
  std::optional<bool> valid;

  /// The padOp to hoist.
  tensor::PadOp opToHoist;

  /// Immediately enclosing loops considered for hoisting padding, from
  /// innermost to outermost.
  SmallVector<scf::ForOp> reverseEnclosingLoops;

  /// Drops non-index dependencies from `backwardSlice`. The method follows
  /// the use-def chains of the index operands consumed by `opToHoist` and
  /// `sliceOp` and drops the operations not part of this index computation.
  /// Afterwards, the filtered `backwardSlice` contains only the loops whose
  /// induction variable is used, directly or indirectly, to index the padded
  /// tensor. Returns failure if the filtered backward slice contains an
  /// unexpected operation.
  ///
  /// Example:
  /// ```
  /// %source = linalg.fill(%cst, %arg0)
  /// scf.for %i
  ///   %unrelated = linalg.fill(%cst, %arg1)    // not used to index %source!
  ///   scf.for %j (%arg2 = %unrelated)
  ///     scf.for %k                             // not used to index %source!
  ///       %ubi = affine.min #map(%i)
  ///       %ubj = affine.min #map(%j)
  ///       %slice = tensor.extract_slice %source [%i, %j] [%ubi, %ubj]
  ///       %padded_slice = tensor.pad %slice
  /// ```
  /// dropNonIndexDependencies() drops [scf.for %k, linalg.fill(%cst, %arg1)]
  /// from the backward slice.
  LogicalResult dropNonIndexDependencies();

public:
  /// The outermost loop, determined by `nLevels` above which `padOp` will
  /// be hoisted.
  scf::ForOp outermostEnclosingForOp;

  /// Backward slice rooted at `padOp` and nested under
  /// `outermostEnclosingForOp`.
  SetVector<Operation *> backwardSlice;

  /// The scf::ForOps, nested under `outermostEnclosingForOp`, that are used
  /// to index the padded data and into which packing loops are built.
  SmallVector<scf::ForOp> packingLoops;

  /// The ExtractSliceOp that feeds the PadOp we want to hoist.
  tensor::ExtractSliceOp sliceOp;

  /// If non-empty, the loop that consumes the pad through an iter_arg.
  scf::ForOp padConsumingForOp;
};

} // namespace

HoistPaddingAnalysis::HoistPaddingAnalysis(tensor::PadOp padOp, int numLoops)
    : valid(std::nullopt), opToHoist(padOp) {
  // Get at most `numLoops` of immediately enclosing loops.
  getAtMostNEnclosingLoops(padOp, numLoops, reverseEnclosingLoops);
  if (reverseEnclosingLoops.empty()) {
    LLVM_DEBUG(DBGS() << "--No immediately enclosing loop -> Skip\n");
    valid = false;
    return;
  }
  outermostEnclosingForOp = reverseEnclosingLoops.back();
  sliceOp = opToHoist.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp) {
    LLVM_DEBUG(DBGS() << "--Cannot find the extract slice op -> Skip\n");
    valid = false;
    return;
  }
}

HoistPaddingAnalysis::HoistPaddingAnalysis(tensor::PadOp padOp,
                                           scf::ForOp outermostEnclosingForOp)
    : valid(std::nullopt), opToHoist(padOp) {
  // Get all the enclosing loops until `outermostEnclosingForOp`.
  getEnclosingLoopsUntil(padOp, outermostEnclosingForOp,
                         reverseEnclosingLoops);
  if (reverseEnclosingLoops.empty()) {
    LLVM_DEBUG(DBGS() << "--No immediately enclosing loop -> Skip\n");
    valid = false;
    return;
  }
  this->outermostEnclosingForOp = reverseEnclosingLoops.back();
  if (this->outermostEnclosingForOp != outermostEnclosingForOp) {
    LLVM_DEBUG(DBGS() << "--Unexpected outermost enclosing loop -> Skip\n");
    valid = false;
    return;
  }
  sliceOp = opToHoist.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp) {
    LLVM_DEBUG(DBGS() << "--Cannot find the extract slice op -> Skip\n");
    valid = false;
    return;
  }
}
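
// Note (illustrative, not in the original source): the two constructors differ
// only in how the candidate loop nest is discovered -- bounded by a maximum
// depth `numLoops` above, or by walking up to a caller-specified
// `outermostEnclosingForOp` here. Everything downstream is shared.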

void HoistPaddingAnalysis::enableHoistPadding(RewriterBase &rewriter) {
  if (isInvalid())
    return;
  // If the padded data is not yet available before entering the outermost
  // enclosing loop, try to apply hoisting on this outermost loop.
  // TODO: we may want finer-grained hoisting of only that particular
  // `sliceOp`.
  if (!outermostEnclosingForOp.isDefinedOutsideOfLoop(sliceOp.getSource())) {
    outermostEnclosingForOp = cast<scf::ForOp>(
        hoistLoopInvariantSubsets(rewriter, outermostEnclosingForOp));
  }
}

void HoistPaddingAnalysis::finalizeHoistPaddingAnalysis() {
  if (isInvalid())
    return;

  if (!outermostEnclosingForOp.isDefinedOutsideOfLoop(sliceOp.getSource())) {
    LLVM_DEBUG(DBGS() << "--outermostEnclosingForOp:\n"
                      << outermostEnclosingForOp << "\n"
                      << "--sliceOp: " << sliceOp << "\n"
                      << "--sliceOp.getSource(): " << sliceOp.getSource()
                      << "\n");
    LLVM_DEBUG(DBGS() << "----Source not defined outside of loops -> Skip\n");
    valid = false;
    return;
  }
  if (sliceOp->hasOneUse()) {
    padConsumingForOp = dyn_cast<scf::ForOp>(*(sliceOp->getUsers().begin()));
  }

  // Check the padding value is a constant: it must be available above the
  // outermost enclosing loop for the hoisted pad to be reconstructible.
  // TODO: relax this constraint if needed.
  Value paddingValue = opToHoist.getConstantPaddingValue();
  if (!paddingValue ||
      !isa_and_nonnull<arith::ConstantOp>(paddingValue.getDefiningOp())) {
    LLVM_DEBUG(DBGS() << "Cannot find constant padding value -> Skip\n");
    valid = false;
    return;
  }

  computeBackwardSlice(opToHoist, outermostEnclosingForOp, backwardSlice);
  if (backwardSlice.size() <= 1) {
    valid = false;
    return;
  }

  debugPrintBackwardSlice(backwardSlice);
  // Remove all ops in the backward slice that are not used to index the
  // padded tensor. In particular, keep `opToHoist`, `sliceOp`, and the loop
  // and affine operations used for the index computation.
  if (failed(dropNonIndexDependencies())) {
    LLVM_DEBUG(DBGS() << "--Cannot dropNonIndexDependencies -> Skip\n");
    valid = false;
    return;
  }
  debugPrintBackwardSlice(backwardSlice);

  // Add only the loops part of the filtered `backwardSlice` to the packing
  // loops. All other loops are not used to index the padded data and
  // consequently access the same data in every loop iteration. Adding them
  // to the packing loops swaps the order of the hoisted pad with these loops
  // and is not beneficial since the same data is packed in every iteration.
  for (scf::ForOp forOp : llvm::reverse(reverseEnclosingLoops))
    if (backwardSlice.contains(forOp))
      packingLoops.push_back(forOp);

  // TODO: support hoisting multiple loops through an iter_arg.
  if (packingLoops.size() > 1 && padConsumingForOp) {
    LLVM_DEBUG(DBGS() << "--Cannot hoist multiple loops through iter_args -> "
                         "Downgrade to 1 loop\n");
    packingLoops.resize(1);
  }

  // Note: at this point, packing loops may be empty but we would still like
  // to hoist the padding if so specified.

  // The analysis is valid and hoisting can occur.
  valid = true;
}

LogicalResult HoistPaddingAnalysis::dropNonIndexDependencies() {
  // Set of all values used for index computation.
  SetVector<Value> indexEdges;

  // Add all index operands of `operation` to `indexEdges`. An index operand
  // is an operand of type index.
  auto addIndexOperandsToIndexEdges = [&](Operation *operation) {
    for (Value operand : operation->getOperands())
      if (operand.getType().isIndex())
        indexEdges.insert(operand);
  };

  // Check if any operation result is contained in `indexEdges`.
  auto hasIndexResult = [&](Operation *operation) {
    return llvm::any_of(operation->getResults(), [&](Value result) {
      return indexEdges.contains(result);
    });
  };

  // Starting from `opToHoist` and `sliceOp`, walk the use-def edges of index
  // type in `backwardSlice`. Add the index operands of an operation to
  // `indexEdges` and remove all operations from `backwardSlice` that are not
  // part of the index computation.
  //
  // Example:
  // ```
  // %source = linalg.fill(%cst, %arg0)
  // scf.for %i
  //   %unrelated = linalg.fill(%cst, %arg1)    // not used to index %source!
  //   scf.for %j (%arg2 = %unrelated)
  //     scf.for %k                             // not used to index %source!
  //       %ubi = affine.min #map(%i)
  //       %ubj = affine.min #map(%j)
  //       %slice = tensor.extract_slice %source [%i, %j] [%ubi, %ubj]
  //       %padded_slice = tensor.pad %slice
  // ```
  // After iterating `backwardSlice` we obtain:
  //   indexEdges = [%i, %j, %ubi, %ubj]
  //   backwardSlice = backwardSlice / [linalg.fill(%cst, %arg1), scf.for %k]
  SetVector<Operation *> operationsToRemove;
  for (Operation *op : llvm::reverse(backwardSlice)) {
    // Add the index operands of `opToHoist` and `sliceOp` to start the
    // exploration of the index computation.
    if (op == opToHoist || op == sliceOp) {
      addIndexOperandsToIndexEdges(op);
      continue;
    }
    // Add the index operands of the loop if its induction variable is
    // used for index computation.
    if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      if (!hasIndexResult(op) && indexEdges.contains(forOp.getInductionVar())) {
        addIndexOperandsToIndexEdges(op);
        continue;
      }
    }
    // Add the index operands of all other operations if at least one result
    // is used for index computation.
    if (hasIndexResult(op)) {
      addIndexOperandsToIndexEdges(op);
      // Check the operands of the remaining operations all have index type.
      if (llvm::any_of(op->getOperandTypes(),
                       [](Type type) { return !type.isIndex(); })) {
        LLVM_DEBUG(DBGS() << "Unsupported op with non index type operands: "
                          << op << " -> Skip\n");
        return failure();
      }
      // Check the remaining operations do not have regions or memory effects.
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
      bool hasMemoryEffect = effectInterface && !effectInterface.hasNoEffect();
      if (hasMemoryEffect || op->getNumRegions() != 0) {
        LLVM_DEBUG(DBGS() << "Unsupported op with region or memory effect: "
                          << op << " -> Skip\n");
        return failure();
      }
      continue;
    }
    // Remove all other operations not used by the index computation. An
    // exception are constant operations that may be used by `opToHoist`.
    if (!isa<arith::ConstantOp>(op))
      operationsToRemove.insert(op);
  }
  backwardSlice.set_subtract(operationsToRemove);
  return success();
}

SmallVector<Value>
HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter,
                                                  Location loc) const {
  SmallVector<Value> dynamicTensorSizes;

  // Upper bound the packing loop lengths to size the packed tensor. Taking
  // upper bounds can make the sizes of the packed tensor independent of the
  // enclosing loops. This independence is a prerequisite for reusing the same
  // buffer for all enclosing loop iterations and hoisting its allocation out
  // of the enclosing loops.
  for (auto forOp : packingLoops) {
    // Compute an upper bound `ubVal` for the upper bound of `forOp`.
    FailureOr<OpFoldResult> loopUb = affine::reifyIndexValueBound(
        rewriter, loc, presburger::BoundType::UB, forOp.getUpperBound(),
        /*stopCondition=*/
        [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
          if (v == forOp.getUpperBound())
            return false;
          // Compute a bound that is independent of any affine op results.
          Operation *op = v.getDefiningOp();
          if (!op)
            return true;
          return !isa<affine::AffineMinOp, affine::AffineMaxOp,
                      affine::AffineApplyOp>(op);
        },
        /*closedUB=*/true);
    assert(succeeded(loopUb) && "could not get upper bound");
    Value ubVal = getValueOrCreateConstantIndexOp(rewriter, loc, *loopUb);

    // Compute the maximal packing loop length as (ub - lb).ceilDiv(step) and
    // store the result to `dynamicTensorSizes`.
    // TODO: instead of using the lower bound of `forOp` directly, implement a
    // lower bound computation similar to the upper bound computation.
    AffineExpr lb, ub, step;
    bindDims(rewriter.getContext(), lb, ub);
    bindSymbols(rewriter.getContext(), step);
    Value res = rewriter.createOrFold<affine::AffineApplyOp>(
        loc, (ub - lb).ceilDiv(step),
        ValueRange{forOp.getLowerBound(), ubVal,
                   cast<scf::ForOp>(forOp).getStep()});
    dynamicTensorSizes.push_back(res);
  }

  return dynamicTensorSizes;
}
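
// Worked example (illustrative, not in the original source): for a packing
// loop with lb = 0, ub = 10, step = 4, the loop length is
// (10 - 0).ceilDiv(4) = 3, so the corresponding leading dimension of the
// packed tensor is sized to hold 3 padded slices.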

static bool isDefinedOutsideOrConstant(scf::ForOp outer, Value v) {
  return outer.isDefinedOutsideOfLoop(v) || matchPattern(v, m_Constant());
}

//===----------------------------------------------------------------------===//
// buildPackingLoopNest Implementation.
//===----------------------------------------------------------------------===//

/// Return the current iteration number in the loop (iv - lb).ceilDiv(step).
/// The returned Value is guaranteed not to depend on any loop comprised in
/// [`outer`, `forOp`].
/// Return null if such a loop-independent quantity cannot be computed.
static Value buildLoopIterationCount(RewriterBase &rewriter, scf::ForOp outer,
                                     scf::ForOp forOp) {
  MLIRContext *ctx = forOp->getContext();
  AffineExpr iv, lb, step;
  bindDims(ctx, iv, lb);
  bindSymbols(ctx, step);
  if (!isDefinedOutsideOrConstant(outer, forOp.getLowerBound()) ||
      !isDefinedOutsideOrConstant(outer, forOp.getStep()))
    return Value();
  Value ivVal = forOp.getInductionVar(), lbVal = forOp.getLowerBound(),
        stepVal = forOp.getStep();
  auto loc = forOp->getLoc();
  return rewriter.createOrFold<affine::AffineApplyOp>(
      loc, (iv - lb).ceilDiv(step), ValueRange{ivVal, lbVal, stepVal});
}
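
// Worked example (illustrative, not in the original source): with lb = 2 and
// step = 3, the induction variable takes values 2, 5, 8, ... and
// (iv - lb).ceilDiv(step) maps them to iteration numbers 0, 1, 2, ... which
// are then used to index the leading dimensions of the packed tensor.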

/// Build the packing loop nest required to hoist `opToHoist`: traverse the
/// backward slice, clone the non-loop operations, and rebuild the loops used
/// for indexing so that each iteration writes its padded slice into one
/// position of a larger packed tensor initialized by `emptyOp`. The returned
/// PackingResult tracks the cloned loops and the offsets/sizes/strides
/// needed to address the packed tensor.
static FailureOr<PackingResult> buildPackingLoopNestImpl(
    RewriterBase &rewriter, IRMapping &bvm, tensor::PadOp opToHoist,
    ArrayRef<int64_t> transposeVector, RankedTensorType transposedTensorType,
    tensor::EmptyOp emptyOp, const HoistPaddingAnalysis &analysis) {
  SmallVector<OpFoldResult> offsets, sizes, strides;
  SmallVector<Value> clonedLoopIvs, leadingHoistedPackedTensorIndexings;

  scf::ForOp outerLoop = analysis.outermostEnclosingForOp;

  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();
  int paddedRank = paddedTensorType.getRank();

  // Step 0. Populate bvm with opToHoist.getSource if it is a block argument
  // flowing through enclosing loop iter_args.
  BlockArgument bbArg = dyn_cast<BlockArgument>(opToHoist.getSource());
  while (bbArg) {
    auto forOp = dyn_cast<scf::ForOp>(bbArg.getOwner()->getParentOp());
    if (!forOp)
      break;
    if (forOp != outerLoop && !outerLoop->isAncestor(forOp))
      break;
    OpOperand &operand = *forOp.getTiedLoopInit(bbArg);
    bvm.map(bbArg, operand.get());
    bbArg = dyn_cast<BlockArgument>(operand.get());
  }

  // Step 1. Iteratively clone loops and push `hoistedPackedTensor`.
  Value hoistedPackedTensor = emptyOp.getResult();
  OpBuilder::InsertionGuard g(rewriter);
  for (Operation *op : analysis.backwardSlice) {
    // Specifically sit out in the extract_slice(hoistedPackedTensor) case:
    // this is the piece we seek to replace.
    if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
      if (bvm.lookupOrDefault(sliceOp.getSource()) == hoistedPackedTensor) {
        LLVM_DEBUG(DBGS() << "--Skip: " << sliceOp << "\n");
        continue;
      }
    }

    // Clone all operations except loops.
    auto forOp = dyn_cast<scf::ForOp>(op);
    if (!forOp) {
      // We are at the right insertion point within the loop nest.
      rewriter.clone(*op, bvm);
      continue;
    }

    // Create a packing loop that takes `hoistedPackedTensor` as iteration
    // argument.
    auto clonedForOp = rewriter.create<scf::ForOp>(
        loc, bvm.lookupOrDefault(forOp.getLowerBound()),
        bvm.lookupOrDefault(forOp.getUpperBound()),
        bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor);

    // Map the induction var, region args and results to the `clonedForOp`.
    bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar());
    bvm.map(forOp.getRegionIterArgs(), clonedForOp.getRegionIterArgs());
    bvm.map(forOp.getResults(), clonedForOp.getResults());
    assert(clonedForOp->getNumRegions() == 1);
    clonedLoopIvs.push_back(clonedForOp.getInductionVar());

    // Do not insert a guard here, we get deeper into the loop nest.
    rewriter.setInsertionPointToStart(&clonedForOp->getRegion(0).front());
    Value loopIndependentIterationCount =
        buildLoopIterationCount(rewriter, outerLoop, clonedForOp);

    // Assert the loop-independent iteration count can be computed.
    if (!loopIndependentIterationCount)
      llvm_unreachable("loop independence prerequisite not met");
    leadingHoistedPackedTensorIndexings.push_back(
        loopIndependentIterationCount);
    hoistedPackedTensor = clonedForOp.getRegionIterArgs().front();
  }

  // Step 2. Construct the offsets, sizes and strides used to address the
  // packed tensor from within the innermost packing loop.
  int64_t nPackedLoops = clonedLoopIvs.size();
  // offsets = [leading iteration counts, 0 .. 0].
  offsets =
      SmallVector<OpFoldResult>{leadingHoistedPackedTensorIndexings.begin(),
                                leadingHoistedPackedTensorIndexings.end()};
  offsets.append(paddedRank, rewriter.getIndexAttr(0));
  // sizes = [1 .. 1, transposedShape].
  sizes = SmallVector<OpFoldResult>(nPackedLoops, rewriter.getIndexAttr(1));
  for (int64_t sz : transposedTensorType.getShape()) {
    // TODO: go grab dims when needed; atm tensor::PadOp yields a static tensor.
    if (ShapedType::isDynamic(sz))
      return failure();
    sizes.push_back(rewriter.getIndexAttr(sz));
  }
  // strides = [1 .. 1].
  strides = SmallVector<OpFoldResult>(nPackedLoops + paddedRank,
                                      rewriter.getIndexAttr(1));

  // Step 3. Optionally transpose the padded tensor.
  TransposeOp maybeTransposeOp;
  Value paddedTensor = bvm.lookup(opToHoist.getResult());
  if (!transposeVector.empty()) {
    Value outputTensor = rewriter.create<tensor::ExtractSliceOp>(
        loc, transposedTensorType, hoistedPackedTensor, offsets, sizes,
        strides);
    maybeTransposeOp = rewriter.create<linalg::TransposeOp>(
        loc, paddedTensor, outputTensor, transposeVector);
    paddedTensor = maybeTransposeOp.getResult()[0];
  }

  // Innermost tensor.insert_slice and yields are optional / need loops.
  if (nPackedLoops > 0) {
    // Step 4. Create InsertSliceOp at the innermost loop level, inserting an
    // optionally transposed padded slice into the packed tensor.
    Value inserted = rewriter.create<tensor::InsertSliceOp>(
        loc, paddedTensor, hoistedPackedTensor, offsets, sizes, strides);

    // Step 5. Iteratively pop the stack and propagate the yield.
    Value valueToYield = inserted;
    for (Value iv : llvm::reverse(clonedLoopIvs)) {
      auto forOp = scf::getForInductionVarOwner(iv);
      rewriter.setInsertionPointToEnd(&forOp.getRegion().front());
      rewriter.create<scf::YieldOp>(loc, valueToYield);
      valueToYield = forOp.getResult(0);
    }
  }

  return PackingResult{
      offsets,
      sizes,
      strides,
      clonedLoopIvs,
      leadingHoistedPackedTensorIndexings,
      maybeTransposeOp,
      cast<tensor::PadOp>(bvm.lookup(opToHoist.getResult()).getDefiningOp())};
}

/// Build the packing loop nest required to hoist `opToHoist` above
/// `outermostEnclosingForOp`: compute the packed tensor type, create the
/// empty destination tensor, and delegate to the implementation above.
static FailureOr<PackingResult> buildPackingLoopNestImpl(
    RewriterBase &rewriter, IRMapping &bvm, tensor::PadOp opToHoist,
    ArrayRef<int64_t> transposeVector, const HoistPaddingAnalysis &analysis) {
  // Update actual number of loops, which may be smaller.
  int nPackedLoops = analysis.packingLoops.size();
  LLVM_DEBUG(DBGS() << "\n";
             DBGS() << "Func:\n"
                    << *opToHoist->getParentOfType<func::FuncOp>() << "\n";
             DBGS() << "Start hoisting above " << nPackedLoops << " loops\n");

  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();

  // Compute the type of the transposed padded tensor.
  FailureOr<RankedTensorType> transposedTensorType =
      tensor::computeTransposedType(paddedTensorType, transposeVector);
  if (failed(transposedTensorType)) {
    LLVM_DEBUG(DBGS() << "--Could not compute transposed type -> Skip\n");
    return failure();
  }

  // Create the packed tensor<?x?x..?xtransposedShape>.
  SmallVector<int64_t> packedShape(nPackedLoops, ShapedType::kDynamic);
  // TODO: go grab dims when needed; atm tensor::PadOp yields a static tensor.
  llvm::append_range(packedShape, transposedTensorType->getShape());
  auto hoistedPackedTensorType = RankedTensorType::get(
      packedShape, transposedTensorType->getElementType());

  // Set the insertion point right before the outer loop and start packing.
  scf::ForOp outerLoop = analysis.outermostEnclosingForOp;
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(outerLoop);
  SmallVector<Value> dynamicTensorSizes =
      analysis.getHoistedPackedTensorSizes(rewriter, loc);
  auto emptyOp = rewriter.create<tensor::EmptyOp>(
      loc, hoistedPackedTensorType.getShape(),
      hoistedPackedTensorType.getElementType(), dynamicTensorSizes);

  return buildPackingLoopNestImpl(rewriter, bvm, opToHoist, transposeVector,
                                  *transposedTensorType, emptyOp, analysis);
}

/// Build the packing loop nest required to hoist `opToHoist` above
/// `outermostEnclosingForOp`.
/// The loop nest is built just before `outermostEnclosingForOp`.
FailureOr<PackingResult> mlir::linalg::detail::buildPackingLoopNest(
    RewriterBase &rewriter, tensor::PadOp opToHoist,
    scf::ForOp outermostEnclosingForOp, ArrayRef<int64_t> transposeVector) {
  HoistPaddingAnalysis analysis(opToHoist, outermostEnclosingForOp);
  analysis.enableHoistPadding(rewriter);
  analysis.finalizeHoistPaddingAnalysis();
  if (!analysis.isValid()) {
    LLVM_DEBUG(DBGS() << "--Analysis failed -> Skip\n");
    return failure();
  }
  IRMapping bvm;
  return buildPackingLoopNestImpl(rewriter, bvm, opToHoist, transposeVector,
                                  analysis);
}

//===----------------------------------------------------------------------===//
// hoistPaddingOnTensors Implementation.
//===----------------------------------------------------------------------===//

/// Return true if we can walk back the use-def chain from `extractSliceOp` to
/// `expectedSource`, going through DestinationStyleOpInterface inits only.
/// This is a poor man's analysis that is sufficient to check that the
/// extractSliceOp matches the tensor.pad we want to hoist.
/// In the future, it would be better to ensure this with a matching symmetric
/// tensor.unpad op.
static bool tracesBackToExpectedValue(tensor::ExtractSliceOp extractSliceOp,
                                      Value expectedSource) {
  LLVM_DEBUG(DBGS() << "Start tracesBackToExpectedValue on: " << extractSliceOp
                    << "\n");
  LLVM_DEBUG(DBGS() << "--with extractSlice: " << extractSliceOp << "\n");
  Value source = extractSliceOp.getSource();
  LLVM_DEBUG(DBGS() << "--with starting source: " << source << "\n");
  while (source && source != expectedSource) {
    auto destOp =
        dyn_cast_or_null<DestinationStyleOpInterface>(source.getDefiningOp());
    if (!destOp)
      break;
    LLVM_DEBUG(DBGS() << "--step dest op: " << destOp << "\n");
    source = destOp.getDpsInitOperand(cast<OpResult>(source).getResultNumber())
                 ->get();
  }
  LLVM_DEBUG(DBGS() << "--final source: " << source << "\n");
  LLVM_DEBUG(DBGS() << "--expected source: " << expectedSource << "\n");
  return source == expectedSource;
}
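
// Illustrative walk (not in the original source): starting from
//   %e = tensor.extract_slice %r ...
// where %r is a result of `linalg.generic ... outs(%init)`, the loop above
// hops from %r to the tied init operand %init, and keeps hopping through
// destination-style ops; it succeeds iff it reaches `expectedSource`.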

/// If the original consumer of `outerSliceOp` was a `forOp` (i.e. through an
/// iter arg), propagate the `hoistedPackedTensor` value through the same
/// iter arg.
/// TODO: for multiple loops we need to track the use to the innermost loop.
///
/// Match:
/// ```
///   %outerSliceOp = tensor.extract_slice ..
///   %f = scf.for ... iter_args(%arg0 = %outerSliceOp) {
///     %hoistedPackedTensor = tensor.pad %arg0
///     %1 = compute %hoistedPackedTensor
///     %2 = tensor.extract_slice %1
///     scf.yield %2
///   }
/// ```
///
/// and rewrite as:
/// ```
///   %outerSliceOp = tensor.extract_slice ..
///   %hoistedPackedTensor = tensor.pad %outerSliceOp
///   %f = scf.for ... iter_args(%arg0 = %hoistedPackedTensor) {
///     %1 = compute %arg0
///     scf.yield %1
///   }
///   %2 = tensor.extract_slice %f
/// ```
static tensor::ExtractSliceOp
padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
                      Value hoistedPackedTensor,
                      tensor::ExtractSliceOp outerSliceOp, scf::ForOp forOp) {
  LLVM_DEBUG(DBGS() << "Start padThroughLoopIterArg on: " << forOp << "\n");
  LLVM_DEBUG(DBGS() << "--paddedValueBeforeHoisting: "
                    << paddedValueBeforeHoisting << "\n");
  OpOperand *pUse = nullptr;
  for (OpOperand &use : outerSliceOp->getUses()) {
    if (use.getOwner() == forOp) {
      assert(!pUse && "Multiple slice uses in the for loop");
      pUse = &use;
    }
  }
  assert(pUse && "No slice use in the for loop");
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());

  unsigned iterArgNumber = forOp.getTiedLoopResult(pUse).getResultNumber();
  auto yieldingExtractSliceOp = forOp.getYieldedValues()[iterArgNumber]
                                    .getDefiningOp<tensor::ExtractSliceOp>();
  if (!yieldingExtractSliceOp)
    return tensor::ExtractSliceOp();

  // Poor man's analysis, sufficient to check that the yielded slice traces
  // back to the padded value before hoisting.
  if (!tracesBackToExpectedValue(yieldingExtractSliceOp,
                                 paddedValueBeforeHoisting))
    return tensor::ExtractSliceOp();

  SmallVector<Value> initArgs = forOp.getInitArgs();
  initArgs[iterArgNumber] = hoistedPackedTensor;
  SmallVector<Value> yieldOperands = llvm::to_vector(forOp.getYieldedValues());
  yieldOperands[iterArgNumber] = yieldingExtractSliceOp.getSource();

  int64_t numOriginalForOpResults = initArgs.size();
  LLVM_DEBUG(DBGS() << "numOriginalForOpResults: " << numOriginalForOpResults
                    << "\n");
  tensor::ExtractSliceOp extracted;
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(forOp);
    extracted = rewriter.create<tensor::ExtractSliceOp>(
        hoistedPackedTensor.getLoc(), hoistedPackedTensor,
        outerSliceOp.getMixedOffsets(), outerSliceOp.getMixedSizes(),
        outerSliceOp.getMixedStrides());
    rewriter.replaceAllUsesWith(forOp.getResult(iterArgNumber), extracted);
  }
  scf::ForOp newForOp = cast<scf::ForOp>(*forOp.replaceWithAdditionalYields(
      rewriter, initArgs, /*replaceInitOperandUsesInLoop=*/true,
      [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
        return yieldOperands;
      }));

  LLVM_DEBUG(DBGS() << "newForOp results: " << newForOp.getNumResults()
                    << "\n");
  LLVM_DEBUG(DBGS() << "replace source of: " << extracted << "\n");
  LLVM_DEBUG(DBGS() << "with result #"
                    << numOriginalForOpResults + iterArgNumber
                    << " of forOp, giving us: " << extracted << "\n");
  rewriter.startOpModification(extracted);
  extracted.getSourceMutable().assign(
      newForOp.getResult(numOriginalForOpResults + iterArgNumber));
  rewriter.finalizeOpModification(extracted);

  LLVM_DEBUG(DBGS() << "replace uses of: " << paddedValueBeforeHoisting
                    << "\n");
  LLVM_DEBUG(DBGS() << "with region iter arg #"
                    << numOriginalForOpResults + iterArgNumber << "\n");
  rewriter.replaceAllUsesWith(
      paddedValueBeforeHoisting,
      newForOp.getRegionIterArg(numOriginalForOpResults + iterArgNumber));

  return extracted;
}

/// Produce a tensor extracted from the packingResult. This can be used as a
/// replacement for `opToHoist` in callers.
static Value replaceByPackingResult(RewriterBase &rewriter,
                                    const IRMapping &bvm,
                                    tensor::PadOp opToHoist,
                                    RankedTensorType transposedTensorType,
                                    const HoistPaddingAnalysis &analysis,
                                    const PackingResult &packingResult) {
  // The replacement occurs under a single insertion point within the original
  // loop, just before opToHoist.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(opToHoist);

  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();
  int paddedRank = paddedTensorType.getRank();

  int64_t nPackedLoops = packingResult.clonedLoopIvs.size();
  LLVM_DEBUG(DBGS() << "nPackedLoops: " << nPackedLoops << " loops\n");

  scf::ForOp outerLoop = analysis.outermostEnclosingForOp;
  ArrayRef<scf::ForOp> packingLoops = analysis.packingLoops;

  Value hoistedPackedTensor;
  SmallVector<Value> loopIterationCounts;
  SmallVector<OpFoldResult> offsets(nPackedLoops + paddedRank,
                                    rewriter.getIndexAttr(0));
  if (nPackedLoops > 0) {
    loopIterationCounts =
        llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *loop) {
          return buildLoopIterationCount(rewriter, outerLoop,
                                         cast<scf::ForOp>(loop));
        }));
    // Assert all loop iteration counts can be computed.
    if (llvm::any_of(loopIterationCounts, [](Value v) { return !v; }))
      llvm_unreachable("loop independence prerequisite not met");

    // offsets = [maybe_leading_ivs = originalLoopIvs, 0 .. 0].
    std::copy(loopIterationCounts.begin(), loopIterationCounts.end(),
              offsets.begin());
    hoistedPackedTensor =
        scf::getForInductionVarOwner(packingResult.clonedLoopIvs.front())
            ->getResult(0);
  } else {
    // If no loops were created, this is just hoisting without packing.
    hoistedPackedTensor = bvm.lookup(opToHoist.getResult());
  }

  LLVM_DEBUG(DBGS() << "hoistedPackedTensor: " << hoistedPackedTensor << "\n");

  // If the consumer of `sliceOp` was a `forOp`, propagate through iter args.
  scf::ForOp forOp = analysis.padConsumingForOp;
  if (forOp) {
    hoistedPackedTensor = padThroughLoopIterArg(
        rewriter, opToHoist, hoistedPackedTensor, analysis.sliceOp, forOp);
  }

  // offsets = [maybe_leading_ivs, 0 .. 0].
  // sizes = [1 .. 1, transposedShape] (defined when building the loop nest).
  // strides = [1 .. 1] (defined when building the loop nest).
  return rewriter.create<tensor::ExtractSliceOp>(
      loc, transposedTensorType, hoistedPackedTensor, offsets,
      packingResult.sizes, packingResult.strides);
}

FailureOr<Value> mlir::linalg::hoistPaddingOnTensors(
    RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops,
    ArrayRef<int64_t> transposeVector, tensor::PadOp &hoistedOp,
    SmallVectorImpl<TransposeOp> &transposeOps) {
  LLVM_DEBUG(DBGS() << "\n"; DBGS() << " Try to hoist " << *(opToHoist) << "\n";
             DBGS() << " by " << numLoops << " loops\n");

  HoistPaddingAnalysis analysis(opToHoist, numLoops);
  analysis.enableHoistPadding(rewriter);
  analysis.finalizeHoistPaddingAnalysis();
  if (!analysis.isValid()) {
    LLVM_DEBUG(DBGS() << "--Analysis failed -> Skip\n");
    return failure();
  }

  // Construct the packing loop nest.
  IRMapping bvm;
  FailureOr<PackingResult> packingResult = buildPackingLoopNestImpl(
      rewriter, bvm, opToHoist, transposeVector, analysis);
  if (failed(packingResult)) {
    LLVM_DEBUG(DBGS() << "--buildPackingLoopNestImpl failed -> Skip\n");
    return failure();
  }

  if (!transposeVector.empty())
    transposeOps.push_back(packingResult->maybeTransposeOp);

  FailureOr<RankedTensorType> transposedTensorType =
      tensor::computeTransposedType(opToHoist.getResultType(), transposeVector);
  assert(succeeded(transposedTensorType) && "unexpected failure in type");

  // Now the packed tensor is ready, replace the original padding op by a
  // 1x..x1 slice [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1].
  Value newResult =
      replaceByPackingResult(rewriter, bvm, opToHoist, *transposedTensorType,
                             analysis, *packingResult);

  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();
  if (!transposeVector.empty()) {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newResult.getDefiningOp());
    // Transpose the packed tensor back to the original storage order.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, paddedTensorType.getShape(), paddedTensorType.getElementType());
    TransposeOp unTransposeOp = rewriter.create<linalg::TransposeOp>(
        loc, newResult, emptyTensor, transposeVector);
    newResult = unTransposeOp.getResult()[0];
    transposeOps.push_back(unTransposeOp);
  }

  LLVM_DEBUG(DBGS() << "newResult: " << newResult << "\n");
  LLVM_DEBUG(
      DBGS() << "After hoisting: "
             << newResult.getDefiningOp()->getParentOfType<func::FuncOp>()
             << "\n");

  // Make the newly cloned `opToHoist` available to the caller.
  hoistedOp = packingResult->hoistedPadOp;

  LLVM_DEBUG(DBGS() << "--SUCCESS\n");
  return newResult;
}

FailureOr<Value> mlir::linalg::hoistPaddingOnTensors(
    tensor::PadOp opToHoist, int64_t numLoops,
    ArrayRef<int64_t> transposeVector, tensor::PadOp &hoistedOp,
    SmallVectorImpl<TransposeOp> &transposeOps) {
  IRRewriter rewriter(opToHoist.getContext());
  return hoistPaddingOnTensors(rewriter, opToHoist, numLoops, transposeVector,
                               hoistedOp, transposeOps);
}