MLIR: lib/Dialect/Affine/Transforms/LoopFusion.cpp Source File (original) (raw)
256 !commonBlock ? &*sliceInsertionBlock->begin() : &*commonBlock->begin();
259 assert(commonBlock &&
260 "common block of producer stores and slice should exist");
264 Operation *firstAncestor = nullptr;
265 for (Operation *store : producerStores) {
267 assert(ancestor && "producer store should be contained in common block");
268 firstAncestor = !firstAncestor || ancestor->isBeforeInBlock(firstAncestor)
269 ? ancestor
270 : firstAncestor;
272 return firstAncestor;
273}
278
280 AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth,
282 int64_t &fusedLoopNestComputeCost) {
283 LDBG() << "Determining additional compute fraction...";
284
288 LDBG() << "Failed to get source loop nest stats.";
289 return std::nullopt;
295 LDBG() << "Failed to get destination loop nest stats.";
296 return std::nullopt;
297 }
298
300 uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
301
302
303 uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
304
306
308 LDBG() << "Slice wasn't computed.";
309 return std::nullopt;
313 dstLoopNestStats, slice,
314 &fusedLoopNestComputeCost)) {
315 LDBG() << "Unable to compute fusion compute cost";
316 return std::nullopt;
317 }
318
319 double additionalComputeFraction =
320 fusedLoopNestComputeCost /
321 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
322 1;
323
324 return additionalComputeFraction;
325}
326
327
328
329
330
331
332
335 unsigned dstLoopDepth,
336 std::optional fastMemorySpace,
337 Block *sliceInsertionBlock,
338 uint64_t localBufSizeThreshold) {
339 assert(!storeOps.empty() && "no source stores supplied");
340
341
342
343
344
345 if (storeOps.size() > 1 &&
346 !std::equal(std::next(storeOps.begin()), storeOps.end(), storeOps.begin(),
348 MemRefAccess aM(cast(a));
349 MemRefAccess bM(cast(b));
350 return aM == bM;
351 })) {
352 LDBG() << "Private memref creation unsupported for multiple producer "
353 << "stores with different access functions.";
354 return nullptr;
355 }
356
357 Operation *srcStoreOp = storeOps[0];
358
359
361
362 OpBuilder top(forOp->getParentRegion());
363
364 auto oldMemRef = cast(srcStoreOp).getMemRef();
365 auto oldMemRefType = cast(oldMemRef.getType());
366 unsigned rank = oldMemRefType.getRank();
367
368
370 bool validRegion = succeeded(
371 region.compute(srcStoreOp, dstLoopDepth, nullptr,
372 true, false));
373
374 (void)validRegion;
375 assert(validRegion && "unexpected memref region failure");
378 lbs.reserve(rank);
379
380
381 std::optional<int64_t> numElements =
383 assert(numElements && "non-constant number of elts in local buffer");
384
386
387
388
391
392
394 offsets.reserve(rank);
395
396
397
399 for (unsigned j = 0, e = lbs[0].getNumSymbols(); j < e; ++j)
401 for (unsigned d = 0; d < rank; ++d) {
402 assert(lbs[d].getNumResults() == 1 &&
403 "invalid private memref bound calculation");
404 offsets.push_back(lbs[d].getResult(0).replaceSymbols(replacements));
405 }
406
407
408
410 assert(eltSize && "memrefs with size elt types expected");
411 uint64_t bufSize = *eltSize * *numElements;
413 if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
414 newMemSpace = b.getI64IntegerAttr(*fastMemorySpace);
415 } else {
416 newMemSpace = oldMemRefType.getMemorySpace();
417 }
418 auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
419 AffineMap(), newMemSpace);
420
421
422
423
424
425
426
427 Value newMemRef = memref::AllocOp::create(top, forOp.getLoc(), newMemRefType);
428
429
431 remapExprs.reserve(rank);
432 for (unsigned i = 0; i < rank; i++) {
433 auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);
434
435 auto remapExpr =
437 remapExprs.push_back(remapExpr);
438 }
439
440 auto indexRemap =
441 AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
442
443
446 auto userFilterFn = [&](Operation *user) {
447 auto domInfo = std::make_unique(
449 return domInfo->dominates(domFilter, user);
450 };
451 LogicalResult res = replaceAllMemRefUsesWith(
452 oldMemRef, newMemRef, {}, indexRemap,
453 outerIVs,
454 {}, userFilterFn);
455 assert(succeeded(res) &&
456 "replaceAllMemrefUsesWith should always succeed here");
458 LDBG() << "Created private memref of type: " << newMemRefType;
459 return newMemRef;
460}
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
502 AffineForOp dstForOp,
504 unsigned maxLegalFusionDepth,
505 unsigned *dstLoopDepth,
506 double computeToleranceThreshold) {
507 LDBG() << "Checking whether fusion is profitable between source nest:";
508 LDBG() << ' ' << srcForOp << " and destination nest:";
509 LDBG() << dstForOp;
510
511 if (maxLegalFusionDepth == 0) {
512 LDBG() << "Can't fuse: maxLegalFusionDepth is 0";
513 return false;
514 }
515
516
517
518
521 return false;
522
523
526 return false;
527
528
529
530
531
532
533
534
535 if (producerStores.size() > 1) {
536 LDBG() << "Limited profitability analysis. Not "
537 << "supported for multiple producer store case.";
539 int64_t fusedLoopNestComputeCost;
540
541
543 srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost,
544 fusedLoopNestComputeCost);
545 if (!fraction || fraction > computeToleranceThreshold) {
546 LDBG() << "Additional computation exceeds "
547 << "compute tolerance. Not fusing.";
548 return false;
549 }
550 LDBG() << "Considering fusion profitable at max legal depth.";
551 return true;
552 }
553
554 Operation *srcStoreOp = producerStores.front();
555
556
557
558
559
560
561
562 uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
563 double maxStorageReduction = 0.0;
564 std::optional<uint64_t> sliceMemEstimate;
565
566
567 std::optional bestDstLoopDepth;
568
569
571 if (failed(srcWriteRegion.compute(srcStoreOp, 0))) {
572 LDBG() << "Unable to compute MemRefRegion for source operation";
573 return false;
574 }
575
576 std::optional<int64_t> maybeSrcWriteRegionSizeBytes =
578 if (!maybeSrcWriteRegionSizeBytes.has_value())
579 return false;
580 int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;
581
582
583 uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
584
585
586 uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
587
588
589
590 for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
592
594 continue;
595
596
597
599
600 int64_t fusedLoopNestComputeCost;
601
602 auto mayAdditionalComputeFraction =
604 sliceCost, fusedLoopNestComputeCost);
605 if (!mayAdditionalComputeFraction) {
606 LDBG() << "Can't determine additional compute fraction.";
607 continue;
608 }
609 double additionalComputeFraction = *mayAdditionalComputeFraction;
610
611
612
613
615 if (failed(sliceWriteRegion.compute(srcStoreOp, 0, &slice))) {
616 LDBG() << "Failed to compute slice write region at loopDepth: " << i;
617 continue;
618 }
619
620 std::optional<int64_t> maybeSliceWriteRegionSizeBytes =
622 if (!maybeSliceWriteRegionSizeBytes.has_value() ||
623 *maybeSliceWriteRegionSizeBytes == 0) {
624 LDBG() << "Failed to get slice write region size at loopDepth: " << i;
625 continue;
626 }
627 int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
628
629 double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
630 static_cast<double>(sliceWriteRegionSizeBytes);
631
632 LLVM_DEBUG({
633 std::stringstream msg;
634 msg << " evaluating fusion profitability at depth : " << i << "\n"
635 << std::fixed << std::setprecision(2)
636 << " additional compute fraction: "
637 << 100.0 * additionalComputeFraction << "%\n"
638 << " storage reduction factor: " << storageReduction << "x\n"
639 << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
640 << " src write region size: " << srcWriteRegionSizeBytes << "\n"
641 << " slice write region size: " << sliceWriteRegionSizeBytes;
642 LDBG() << msg.str();
643 });
644
645
646
647
648
649 if ((storageReduction > maxStorageReduction) &&
650 (additionalComputeFraction <= computeToleranceThreshold)) {
651 maxStorageReduction = storageReduction;
652 bestDstLoopDepth = i;
653 minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
654 sliceMemEstimate = sliceWriteRegionSizeBytes;
655 }
656 }
657
658
659
660 if (!bestDstLoopDepth) {
661 LDBG() << "All fusion choices involve more than the threshold amount of "
662 << "redundant computation; NOT fusing.";
663 return false;
664 }
665
666 if (!bestDstLoopDepth) {
667 LDBG() << "no fusion depth could be evaluated.";
668 return false;
669 }
670
671
672 *dstLoopDepth = *bestDstLoopDepth;
673
674 LDBG() << " LoopFusion fusion stats:";
675 LDBG() << " best loop depth: " << bestDstLoopDepth;
676 LDBG() << " src loop nest compute cost: " << srcLoopNestCost;
677 LDBG() << " dst loop nest compute cost: " << dstLoopNestCost;
678 LDBG() << " fused loop nest compute cost: " << minFusedLoopNestComputeCost;
679
682
683 std::optional storageReduction;
684
685 if (!dstMemSize || !srcMemSize) {
686 LDBG() << " fusion memory benefit cannot be evaluated; NOT fusing.";
687 return false;
688 }
689
690 auto srcMemSizeVal = *srcMemSize;
691 auto dstMemSizeVal = *dstMemSize;
692
693 assert(sliceMemEstimate && "expected value");
694 auto fusedMem = dstMemSizeVal + *sliceMemEstimate;
695
696 LDBG() << " src mem: " << srcMemSizeVal;
697 LDBG() << " dst mem: " << dstMemSizeVal;
698 LDBG() << " fused mem: " << fusedMem;
699 LDBG() << " slice mem: " << sliceMemEstimate;
700
701 if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
702 LDBG() << "Fusion is not profitable; NOT fusing.";
703 return false;
704 }
705 storageReduction =
706 100.0 *
707 (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
708
709 double additionalComputeFraction =
710 100.0 * (minFusedLoopNestComputeCost /
711 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
712 1);
713 (void)additionalComputeFraction;
714 LLVM_DEBUG({
715 std::stringstream msg;
716 msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
717 << std::setprecision(2) << additionalComputeFraction
718 << "% redundant computation and a ";
719 msg << (storageReduction ? std::to_string(*storageReduction) : "");
720 msg << "% storage reduction.";
721 LDBG() << msg.str();
722 });
723
724 return true;
725}
726
727namespace {
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775struct GreedyFusion {
776public:
777
778 MemRefDependenceGraph *mdg;
779
780 SmallVector<unsigned, 8> worklist;
781
782 unsigned localBufSizeThreshold;
783
784 std::optional fastMemorySpace;
785
786
787 bool maximalFusion;
788
789
790 double computeToleranceThreshold;
791
792 using Node = MemRefDependenceGraph::Node;
793
794 GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
795 std::optional fastMemorySpace, bool maximalFusion,
796 double computeToleranceThreshold)
797 : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
798 fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
799 computeToleranceThreshold(computeToleranceThreshold) {}
800
801
802 void init() {
803
804
805 worklist.clear();
806 for (auto &idAndNode : mdg->nodes) {
807 const Node &node = idAndNode.second;
808 worklist.push_back(node.id);
809 }
810 }
811
812 void runSiblingFusionOnly() {
813 fuseSiblingNodes();
814 eraseUnusedMemRefAllocations();
815 }
816
817
818 void runProducerConsumerFusionOnly() {
819 fuseProducerConsumerNodes(
820 std::numeric_limits::max());
821 eraseUnusedMemRefAllocations();
822 }
823
824
825
826
827
828
829 void runGreedyFusion() {
830
831 fuseProducerConsumerNodes(1);
832 fuseSiblingNodes();
833 fuseProducerConsumerNodes(
834 std::numeric_limits::max());
835 eraseUnusedMemRefAllocations();
836 }
837
838
839
840 bool canCreatePrivateMemRef(Value memref,
842 unsigned producerId, unsigned consumerId,
843 bool removeSrcNode) {
844
846 return false;
847 const Node *consumerNode = mdg->getNode(consumerId);
848
849
850
851
852
853
854
855 if (srcEscapingMemRefs.count(memref) > 0 &&
856 (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
857 return false;
858
859
860
861 if (mdg->getIncomingMemRefAccesses(producerId, memref) > 0 ||
862 mdg->getOutEdgeCount(consumerId, memref) > 0)
863 return false;
864
865
866
867
868 if (removeSrcNode &&
869 any_of(mdg->outEdges[producerId], [&](const auto &edge) {
870 return edge.value == memref && edge.id != consumerId;
871 }))
872 return false;
873
874 return true;
875 }
876
877
878
879
880 void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
881 LDBG() << "Evaluating dst loop " << dstId;
882
883 if (mdg->nodes.count(dstId) == 0)
884 return;
885
886 auto *dstNode = mdg->getNode(dstId);
887
888 if (!isa(dstNode->op))
889 return;
890
891
892 if (dstNode->op->getNumResults() > 0)
893 return;
894
895 LDBG() << "Evaluating dst loop " << dstId;
896
897
898
899
900
902 auto dstAffineForOp = cast(dstNode->op);
903
904
905
906 bool dstNodeChanged;
907 do {
908
909
910
911
912
913 dstNodeChanged = false;
914 SmallVector<unsigned, 16> srcIdCandidates;
916
917 for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
918
919 auto *srcNode = mdg->getNode(srcId);
920 auto srcAffineForOp = cast(srcNode->op);
921
922 LDBG() << "Trying to fuse producer loop nest " << srcId
923 << " with consumer loop nest " << dstId;
924 LDBG() << "Compute tolerance threshold: " << computeToleranceThreshold;
925 LDBG() << "Producer loop nest:";
926 LDBG() << *srcNode->op << " and consumer loop nest:";
927 LDBG() << *dstNode->op;
928
929 LDBG() << "Evaluating src loop " << srcId << " for dst loop " << dstId;
930
931
932
933 if (isa(srcNode->op) && srcNode->op->getNumResults() > 0)
934 continue;
935
938 producerConsumerMemrefs);
939
940
941
942 if (any_of(producerConsumerMemrefs, [&](Value memref) {
943 return mdg->getOutEdgeCount(srcNode->id, memref) >
944 maxSrcUserCount;
945 }))
946 continue;
947
948
949
950
953
954
955
956 Operation *fusedLoopInsPoint =
957 mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
958 if (fusedLoopInsPoint == nullptr)
959 continue;
960
961
962
963
964
965
966
967 SmallVector<AffineForOp, 4> surroundingLoops;
969 unsigned numSurroundingLoops = surroundingLoops.size();
970
971
972
973 SmallVector<Operation *, 2> dstMemrefOps;
974 for (Operation *op : dstNode->loads)
975 if (producerConsumerMemrefs.count(
976 cast(op).getMemRef()) > 0)
977 dstMemrefOps.push_back(op);
978 for (Operation *op : dstNode->stores)
979 if (producerConsumerMemrefs.count(
980 cast(op).getMemRef()))
981 dstMemrefOps.push_back(op);
982 if (dstMemrefOps.empty())
983 continue;
984 unsigned dstLoopDepthTest =
986
987
988
989 unsigned maxLegalFusionDepth = 0;
990 SmallVector<ComputationSliceState, 8> depthSliceUnions;
991 depthSliceUnions.resize(dstLoopDepthTest);
993 for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
994 FusionResult result =
996 i + numSurroundingLoops,
997 &depthSliceUnions[i - 1], strategy);
999 maxLegalFusionDepth = i;
1000 LDBG() << "Found valid slice for depth: " << i;
1001 }
1002 }
1003
1004 if (maxLegalFusionDepth == 0) {
1005 LDBG() << "Can't fuse: fusion is not legal at any depth";
1006 continue;
1007 }
1008
1009 LDBG() << "Max legal depth for fusion: " << maxLegalFusionDepth;
1010
1011 double computeToleranceThresholdToUse = computeToleranceThreshold;
1012
1013
1014
1015
1016
1017
1019 LDBG() << "Source nest has a cyclic dependence.";
1020
1021
1022
1023 if (maximalFusion) {
1024 auto srcForOp = cast(srcNode->op);
1025 auto dstForOp = cast(dstNode->op);
1026 int64_t sliceCost;
1027 int64_t fusedLoopNestComputeCost;
1029 srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
1030 sliceCost, fusedLoopNestComputeCost);
1031 if (!fraction || fraction > 0) {
1032 LDBG() << "Can't perform maximal fusion with a cyclic dependence "
1033 << "and non-zero additional compute.";
1034 return;
1035 }
1036 } else {
1037
1038
1039 LDBG() << "Setting compute tolerance to zero since "
1040 << "source has a cylic dependence.";
1041 computeToleranceThresholdToUse = 0;
1042 }
1043 }
1044
1045
1046
1047
1048 unsigned bestDstLoopDepth = maxLegalFusionDepth;
1049 if (!maximalFusion) {
1050
1051 SmallVector<Operation *, 2> producerStores;
1052 for (Operation *op : srcNode->stores)
1053 if (producerConsumerMemrefs.count(
1054 cast(op).getMemRef()))
1055 producerStores.push_back(op);
1056
1057 assert(!producerStores.empty() && "Expected producer store");
1059 dstAffineForOp, depthSliceUnions,
1060 maxLegalFusionDepth, &bestDstLoopDepth,
1061 computeToleranceThresholdToUse)) {
1062 continue;
1063 }
1064 }
1065
1066 assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
1067 ComputationSliceState &bestSlice =
1068 depthSliceUnions[bestDstLoopDepth - 1];
1069 assert(!bestSlice.isEmpty() && "Missing slice union for depth");
1070
1071
1072
1073
1075 srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
1076 *mdg);
1077
1079 for (Value memref : producerConsumerMemrefs) {
1080 if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
1081 removeSrcNode)) {
1082
1083 LDBG() << "Creating private memref for " << memref;
1084
1085 privateMemrefs.insert(memref);
1086 }
1087 }
1088
1089
1090 fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
1091 dstNodeChanged = true;
1092
1093 LDBG() << "Fused src loop " << srcId << " into dst loop " << dstId
1094 << " at depth " << bestDstLoopDepth << ":";
1095 LDBG() << dstAffineForOp;
1096
1097
1098 if (fusedLoopInsPoint != dstAffineForOp)
1099 dstAffineForOp->moveBefore(fusedLoopInsPoint);
1100
1101
1102 mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
1103 removeSrcNode);
1104
1105
1106 if (!privateMemrefs.empty()) {
1107
1108
1109 Block *sliceInsertionBlock = bestSlice.insertPoint->getBlock();
1110
1111
1113 dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
1114 Value storeMemRef = storeOp.getMemRef();
1115 if (privateMemrefs.count(storeMemRef) > 0)
1116 privateMemRefToStores[storeMemRef].push_back(storeOp);
1117 });
1118
1119
1120
1121
1122
1123 for (auto &memrefToStoresPair : privateMemRefToStores) {
1124 ArrayRef<Operation *> storesForMemref = memrefToStoresPair.second;
1126 dstAffineForOp, storesForMemref, bestDstLoopDepth,
1127 fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
1128 if (!newMemRef)
1129 continue;
1130
1131 unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
1132
1133 mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
1134 }
1135
1136
1137
1138 dstNode = mdg->getNode(dstId);
1139 }
1140
1141
1142 LoopNestStateCollector dstLoopCollector;
1143 dstLoopCollector.collect(dstAffineForOp);
1144
1145
1146 mdg->clearNodeLoadAndStores(dstNode->id);
1147 mdg->addToNode(
1151
1152 if (removeSrcNode) {
1153 LDBG() << "Removing src loop " << srcId << " after fusion";
1154
1155 srcAffineForOp.erase();
1156 mdg->removeNode(srcId);
1157 srcNode = nullptr;
1158 }
1159 }
1160 } while (dstNodeChanged);
1161 }
1162
1163
1164
1165
1166
1167 void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
1168 LDBG() << "--- Producer/Consumer Fusion ---";
1169 init();
1170 while (!worklist.empty()) {
1171 unsigned dstId = worklist.back();
1172 worklist.pop_back();
1173 performFusionsIntoDest(dstId, maxSrcUserCount);
1174 }
1175 }
1176
1177
1178
1179 void fuseSiblingNodes() {
1180 LDBG() << "--- Sibling Fusion ---";
1181 init();
1182 while (!worklist.empty()) {
1183 unsigned dstId = worklist.back();
1184 worklist.pop_back();
1185
1186
1187 if (mdg->nodes.count(dstId) == 0)
1188 continue;
1189
1190 auto *dstNode = mdg->getNode(dstId);
1191
1192 if (!isa(dstNode->op))
1193 continue;
1194
1195 fuseWithSiblingNodes(dstNode);
1196 }
1197 }
1198
1199
1200 void fuseWithSiblingNodes(Node *dstNode) {
1202 std::pair<unsigned, Value> idAndMemref;
1203 auto dstAffineForOp = cast(dstNode->op);
1204
1205 while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
1206 unsigned sibId = idAndMemref.first;
1207 Value memref = idAndMemref.second;
1208
1209
1210 auto *sibNode = mdg->getNode(sibId);
1211
1212
1213 assert(sibNode->op->getBlock() == dstNode->op->getBlock());
1214 Operation *insertPointInst =
1215 sibNode->op->isBeforeInBlock(dstNode->op)
1216 ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
1217 : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
1218 if (insertPointInst == nullptr)
1219 continue;
1220
1221
1222
1223
1224 SmallVector<Operation *, 2> sibLoadOpInsts;
1225 sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
1226
1227 Operation *sibLoadOpInst = llvm::getSingleElement(sibLoadOpInsts);
1228
1229
1230 SmallVector<Operation *, 2> dstLoadOpInsts;
1231 dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
1232
1233
1234
1235
1236
1237
1238
1239 SmallVector<AffineForOp, 4> surroundingLoops;
1241 unsigned numSurroundingLoops = surroundingLoops.size();
1242 SmallVector<AffineForOp, 4> dstLoopIVs;
1244 unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;
1245 auto sibAffineForOp = cast(sibNode->op);
1246
1247
1248 SmallVector<ComputationSliceState, 8> depthSliceUnions;
1249 depthSliceUnions.resize(dstLoopDepthTest);
1250 unsigned maxLegalFusionDepth = 0;
1251 FusionStrategy strategy(memref);
1252 for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
1253 FusionResult result =
1255 i + numSurroundingLoops,
1256 &depthSliceUnions[i - 1], strategy);
1257
1259 maxLegalFusionDepth = i;
1260 }
1261
1262 LDBG() << "Max legal depth for fusion: " << maxLegalFusionDepth;
1263
1264
1265 if (maxLegalFusionDepth == 0)
1266 continue;
1267
1268 double computeToleranceThresholdToUse = computeToleranceThreshold;
1269
1270
1271
1272
1273
1274
1276 LDBG() << "Source nest has a cyclic dependence.";
1277
1278
1279
1280 if (maximalFusion) {
1281 auto dstForOp = cast(dstNode->op);
1282 int64_t sliceCost;
1283 int64_t fusedLoopNestComputeCost;
1285 sibAffineForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
1286 sliceCost, fusedLoopNestComputeCost);
1287 if (!fraction || fraction > 0) {
1288 LDBG() << "Can't perform maximal fusion with a cyclic dependence "
1289 << "and non-zero additional compute.";
1290 return;
1291 }
1292 } else {
1293
1294
1295 LDBG() << "Setting compute tolerance to zero since "
1296 << "source has a cyclic dependence.";
1297 computeToleranceThresholdToUse = 0.0;
1298 }
1299 }
1300
1301 unsigned bestDstLoopDepth = maxLegalFusionDepth;
1302 if (!maximalFusion) {
1303
1304
1305
1306
1307 if ((sibAffineForOp, sibLoadOpInst, dstAffineForOp,
1308 depthSliceUnions, maxLegalFusionDepth,
1309 &bestDstLoopDepth,
1310 computeToleranceThresholdToUse))
1311 continue;
1312 }
1313
1314 assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
1315
1316 const ComputationSliceState &bestSlice =
1317 depthSliceUnions[bestDstLoopDepth - 1];
1318 assert(!bestSlice.isEmpty() &&
1319 "Fusion depth has no computed slice union");
1320
1321
1322
1323
1324 auto isMaximal = bestSlice.isMaximal();
1325 if (!isMaximal.value_or(false)) {
1326 LDBG() << "Slice isn't maximal; not performing sibling fusion.";
1327 continue;
1328 }
1329
1330
1331
1332
1333 bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
1334
1336 isInnermostInsertion);
1337
1338 auto dstForInst = cast(dstNode->op);
1339
1340 if (insertPointInst != dstForInst)
1341 dstForInst->moveBefore(insertPointInst);
1342
1343 LDBG() << "Fused sibling nest " << sibId << " into destination nest "
1344 << dstNode->id << " at depth " << bestDstLoopDepth << ":";
1345 LDBG() << dstAffineForOp;
1346
1347
1348 updateStateAfterSiblingFusion(sibNode, dstNode);
1349
1350
1351
1352 Operation *op = sibNode->op;
1353 mdg->removeNode(sibNode->id);
1355 }
1356 }
1357
1358
1359
1360
1361
1362 bool findSiblingNodeToFuse(Node *dstNode,
1364 std::pair<unsigned, Value> *idAndMemrefToFuse) {
1365
1366
1367 auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
1368
1369
1370 if (sibNode->getLoadOpCount(memref) != 1)
1371 return false;
1372
1373
1374 if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
1375 mdg->hasDependencePath(dstNode->id, sibNode->id))
1376 return false;
1377
1378
1380 sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
1381 if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) {
1382 return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
1383 }))
1384 return false;
1385
1386
1388 for (auto *storeOpInst : sibNode->stores) {
1389 storeMemrefs.insert(
1390 cast(storeOpInst).getMemRef());
1391 }
1392 return storeMemrefs.size() <= 1;
1393 };
1394
1395
1396 Block *block = dstNode->op->getBlock();
1397 for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) {
1399 auto loadOp = dyn_cast(user);
1400 if (!loadOp)
1401 continue;
1402
1403 SmallVector<AffineForOp, 4> loops;
1405
1406
1407
1408 auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
1409 return loop->getBlock() == &mdg->block;
1410 });
1411
1412 if (it == loops.end())
1413 continue;
1414 Node *sibNode = mdg->getForOpNode(*it);
1415 assert(sibNode != nullptr);
1416
1417 if (sibNode->id == dstNode->id)
1418 continue;
1419
1420 if (visitedSibNodeIds->count(sibNode->id) > 0)
1421 continue;
1422
1423 auto memref = loadOp.getMemRef();
1424 if (dstNode->getLoadOpCount(memref) == 0)
1425 continue;
1426
1427 if (canFuseWithSibNode(sibNode, memref)) {
1428 visitedSibNodeIds->insert(sibNode->id);
1429 idAndMemrefToFuse->first = sibNode->id;
1430 idAndMemrefToFuse->second = memref;
1431 return true;
1432 }
1433 }
1434 }
1435
1436
1437
1438 SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
1439 mdg->forEachMemRefInputEdge(
1440 dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
1441
1442
1443 if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
1444 (mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0 ||
1445 inEdge.value.getDefiningOp() == mdg->getNode(inEdge.id)->op))
1446 inEdges.push_back(inEdge);
1447 });
1448
1449
1450
1451 for (auto &inEdge : inEdges) {
1452
1453 SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
1454 mdg->forEachMemRefOutputEdge(
1455 inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
1456 unsigned sibNodeId = outEdge.id;
1457 if (visitedSibNodeIds->count(sibNodeId) > 0)
1458 return;
1459
1460 if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
1461 return;
1462 auto *sibNode = mdg->getNode(sibNodeId);
1463 if (!isa(sibNode->op))
1464 return;
1465
1466 if (canFuseWithSibNode(sibNode, outEdge.value)) {
1467
1468 outEdges.push_back(outEdge);
1469 }
1470 });
1471
1472
1473 if (!outEdges.empty()) {
1474 visitedSibNodeIds->insert(outEdges[0].id);
1475 idAndMemrefToFuse->first = outEdges[0].id;
1476 idAndMemrefToFuse->second = outEdges[0].value;
1477 return true;
1478 }
1479 }
1480 return false;
1481 }
1482
1483
1484
1485 void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
1486
1487 mdg->updateEdges(sibNode->id, dstNode->id);
1488
1489
1490 auto dstForInst = cast(dstNode->op);
1492 dstLoopCollector.collect(dstForInst);
1493
1498 }
1499
1500
1501 void eraseUnusedMemRefAllocations() {
1503 if (pair.second > 0)
1504 continue;
1505 auto memref = pair.first;
1506
1507 if (.use_empty())
1508 continue;
1509
1510 auto *op = memref.getDefiningOp();
1511 if (isa_and_nonnullmemref::AllocOp(op))
1513 }
1514 }
1515};
1516
1517}
1518
1519
1520void LoopFusion::runOnBlock(Block *block) {
1521 MemRefDependenceGraph g(*block);
1522 if (!g.init()) {
1523 LDBG() << "MDG init failed";
1524 return;
1525 }
1526
1527 std::optional fastMemorySpaceOpt;
1528 if (fastMemorySpace.hasValue())
1529 fastMemorySpaceOpt = fastMemorySpace;
1530 unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
1531 GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
1532 maximalFusion, computeToleranceThreshold);
1533
1534 if (affineFusionMode == FusionMode::ProducerConsumer)
1535 fusion.runProducerConsumerFusionOnly();
1536 else if (affineFusionMode == FusionMode::Sibling)
1537 fusion.runSiblingFusionOnly();
1538 else
1539 fusion.runGreedyFusion();
1540}
1541
1542void LoopFusion::runOnOperation() {
1543
1544
1545 getOperation()->walk([&](Operation *op) {
1546 for (Region ®ion : op->getRegions()) {
1547 for (Block &block : region.getBlocks()) {
1548 auto affineFors = block.getOps();
1549 if (!affineFors.empty() && !llvm::hasSingleElement(affineFors))
1550 runOnBlock(&block);
1551 }
1552 }
1553 });
1554}
1555
1557 unsigned fastMemorySpace, uint64_t localBufSizeThreshold,
1558 bool maximalFusion, enum FusionMode affineFusionMode) {
1559 return std::make_unique(fastMemorySpace, localBufSizeThreshold,
1560 maximalFusion, affineFusionMode);
1561}