MLIR: lib/Dialect/Affine/Transforms/LoopFusion.cpp Source File (original) (raw)

256 !commonBlock ? &*sliceInsertionBlock->begin() : &*commonBlock->begin();

259 assert(commonBlock &&

260 "common block of producer stores and slice should exist");

262

263

264 Operation *firstAncestor = nullptr;

265 for (Operation *store : producerStores) {

267 assert(ancestor && "producer store should be contained in common block");

268 firstAncestor = !firstAncestor || ancestor->isBeforeInBlock(firstAncestor)

269 ? ancestor

270 : firstAncestor;

272 return firstAncestor;

273}

277

278

280 AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth,

282 int64_t &fusedLoopNestComputeCost) {

283 LDBG() << "Determining additional compute fraction...";

284

288 LDBG() << "Failed to get source loop nest stats.";

289 return std::nullopt;

295 LDBG() << "Failed to get destination loop nest stats.";

296 return std::nullopt;

297 }

298

299

300 uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);

301

302

303 uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);

304

306

308 LDBG() << "Slice wasn't computed.";

309 return std::nullopt;

313 dstLoopNestStats, slice,

314 &fusedLoopNestComputeCost)) {

315 LDBG() << "Unable to compute fusion compute cost";

316 return std::nullopt;

317 }

318

319 double additionalComputeFraction =

320 fusedLoopNestComputeCost /

321 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -

322 1;

323

324 return additionalComputeFraction;

325}

326

327

328

329

330

331

332

335 unsigned dstLoopDepth,

336 std::optional fastMemorySpace,

337 Block *sliceInsertionBlock,

338 uint64_t localBufSizeThreshold) {

339 assert(!storeOps.empty() && "no source stores supplied");

340

341

342

343

344

345 if (storeOps.size() > 1 &&

346 !std::equal(std::next(storeOps.begin()), storeOps.end(), storeOps.begin(),

348 MemRefAccess aM(cast(a));

349 MemRefAccess bM(cast(b));

350 return aM == bM;

351 })) {

352 LDBG() << "Private memref creation unsupported for multiple producer "

353 << "stores with different access functions.";

354 return nullptr;

355 }

356

357 Operation *srcStoreOp = storeOps[0];

358

359

361

362 OpBuilder top(forOp->getParentRegion());

363

364 auto oldMemRef = cast(srcStoreOp).getMemRef();

365 auto oldMemRefType = cast(oldMemRef.getType());

366 unsigned rank = oldMemRefType.getRank();

367

368

370 bool validRegion = succeeded(

371 region.compute(srcStoreOp, dstLoopDepth, nullptr,

372 true, false));

373

374 (void)validRegion;

375 assert(validRegion && "unexpected memref region failure");

378 lbs.reserve(rank);

379

380

381 std::optional<int64_t> numElements =

383 assert(numElements && "non-constant number of elts in local buffer");

384

386

387

388

391

392

394 offsets.reserve(rank);

395

396

397

399 for (unsigned j = 0, e = lbs[0].getNumSymbols(); j < e; ++j)

401 for (unsigned d = 0; d < rank; ++d) {

402 assert(lbs[d].getNumResults() == 1 &&

403 "invalid private memref bound calculation");

404 offsets.push_back(lbs[d].getResult(0).replaceSymbols(replacements));

405 }

406

407

408

410 assert(eltSize && "memrefs with size elt types expected");

411 uint64_t bufSize = *eltSize * *numElements;

413 if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {

414 newMemSpace = b.getI64IntegerAttr(*fastMemorySpace);

415 } else {

416 newMemSpace = oldMemRefType.getMemorySpace();

417 }

418 auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),

419 AffineMap(), newMemSpace);

420

421

422

423

424

425

426

427 Value newMemRef = memref::AllocOp::create(top, forOp.getLoc(), newMemRefType);

428

429

431 remapExprs.reserve(rank);

432 for (unsigned i = 0; i < rank; i++) {

433 auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);

434

435 auto remapExpr =

437 remapExprs.push_back(remapExpr);

438 }

439

440 auto indexRemap =

441 AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());

442

443

446 auto userFilterFn = [&](Operation *user) {

447 auto domInfo = std::make_unique(

449 return domInfo->dominates(domFilter, user);

450 };

451 LogicalResult res = replaceAllMemRefUsesWith(

452 oldMemRef, newMemRef, {}, indexRemap,

453 outerIVs,

454 {}, userFilterFn);

455 assert(succeeded(res) &&

456 "replaceAllMemrefUsesWith should always succeed here");

458 LDBG() << "Created private memref of type: " << newMemRefType;

459 return newMemRef;

460}

461

462

463

464

465

466

467

468

469

470

471

472

473

474

475

476

477

478

479

480

481

482

483

484

485

486

487

488

489

490

491

492

493

494

495

496

497

498

499

502 AffineForOp dstForOp,

504 unsigned maxLegalFusionDepth,

505 unsigned *dstLoopDepth,

506 double computeToleranceThreshold) {

507 LDBG() << "Checking whether fusion is profitable between source nest:";

508 LDBG() << ' ' << srcForOp << " and destination nest:";

509 LDBG() << dstForOp;

510

511 if (maxLegalFusionDepth == 0) {

512 LDBG() << "Can't fuse: maxLegalFusionDepth is 0";

513 return false;

514 }

515

516

517

518

521 return false;

522

523

526 return false;

527

528

529

530

531

532

533

534

535 if (producerStores.size() > 1) {

536 LDBG() << "Limited profitability analysis. Not "

537 << "supported for multiple producer store case.";

539 int64_t fusedLoopNestComputeCost;

540

541

543 srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost,

544 fusedLoopNestComputeCost);

545 if (!fraction || fraction > computeToleranceThreshold) {

546 LDBG() << "Additional computation exceeds "

547 << "compute tolerance. Not fusing.";

548 return false;

549 }

550 LDBG() << "Considering fusion profitable at max legal depth.";

551 return true;

552 }

553

554 Operation *srcStoreOp = producerStores.front();

555

556

557

558

559

560

561

562 uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();

563 double maxStorageReduction = 0.0;

564 std::optional<uint64_t> sliceMemEstimate;

565

566

567 std::optional bestDstLoopDepth;

568

569

571 if (failed(srcWriteRegion.compute(srcStoreOp, 0))) {

572 LDBG() << "Unable to compute MemRefRegion for source operation";

573 return false;

574 }

575

576 std::optional<int64_t> maybeSrcWriteRegionSizeBytes =

578 if (!maybeSrcWriteRegionSizeBytes.has_value())

579 return false;

580 int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;

581

582

583 uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);

584

585

586 uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);

587

588

589

590 for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {

592

594 continue;

595

596

597

599

600 int64_t fusedLoopNestComputeCost;

601

602 auto mayAdditionalComputeFraction =

604 sliceCost, fusedLoopNestComputeCost);

605 if (!mayAdditionalComputeFraction) {

606 LDBG() << "Can't determine additional compute fraction.";

607 continue;

608 }

609 double additionalComputeFraction = *mayAdditionalComputeFraction;

610

611

612

613

615 if (failed(sliceWriteRegion.compute(srcStoreOp, 0, &slice))) {

616 LDBG() << "Failed to compute slice write region at loopDepth: " << i;

617 continue;

618 }

619

620 std::optional<int64_t> maybeSliceWriteRegionSizeBytes =

622 if (!maybeSliceWriteRegionSizeBytes.has_value() ||

623 *maybeSliceWriteRegionSizeBytes == 0) {

624 LDBG() << "Failed to get slice write region size at loopDepth: " << i;

625 continue;

626 }

627 int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;

628

629 double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /

630 static_cast<double>(sliceWriteRegionSizeBytes);

631

632 LLVM_DEBUG({

633 std::stringstream msg;

634 msg << " evaluating fusion profitability at depth : " << i << "\n"

635 << std::fixed << std::setprecision(2)

636 << " additional compute fraction: "

637 << 100.0 * additionalComputeFraction << "%\n"

638 << " storage reduction factor: " << storageReduction << "x\n"

639 << " fused nest cost: " << fusedLoopNestComputeCost << "\n"

640 << " src write region size: " << srcWriteRegionSizeBytes << "\n"

641 << " slice write region size: " << sliceWriteRegionSizeBytes;

642 LDBG() << msg.str();

643 });

644

645

646

647

648

649 if ((storageReduction > maxStorageReduction) &&

650 (additionalComputeFraction <= computeToleranceThreshold)) {

651 maxStorageReduction = storageReduction;

652 bestDstLoopDepth = i;

653 minFusedLoopNestComputeCost = fusedLoopNestComputeCost;

654 sliceMemEstimate = sliceWriteRegionSizeBytes;

655 }

656 }

657

658

659

660 if (!bestDstLoopDepth) {

661 LDBG() << "All fusion choices involve more than the threshold amount of "

662 << "redundant computation; NOT fusing.";

663 return false;

664 }

665

666 if (!bestDstLoopDepth) {

667 LDBG() << "no fusion depth could be evaluated.";

668 return false;

669 }

670

671

672 *dstLoopDepth = *bestDstLoopDepth;

673

674 LDBG() << " LoopFusion fusion stats:";

675 LDBG() << " best loop depth: " << bestDstLoopDepth;

676 LDBG() << " src loop nest compute cost: " << srcLoopNestCost;

677 LDBG() << " dst loop nest compute cost: " << dstLoopNestCost;

678 LDBG() << " fused loop nest compute cost: " << minFusedLoopNestComputeCost;

679

682

683 std::optional storageReduction;

684

685 if (!dstMemSize || !srcMemSize) {

686 LDBG() << " fusion memory benefit cannot be evaluated; NOT fusing.";

687 return false;

688 }

689

690 auto srcMemSizeVal = *srcMemSize;

691 auto dstMemSizeVal = *dstMemSize;

692

693 assert(sliceMemEstimate && "expected value");

694 auto fusedMem = dstMemSizeVal + *sliceMemEstimate;

695

696 LDBG() << " src mem: " << srcMemSizeVal;

697 LDBG() << " dst mem: " << dstMemSizeVal;

698 LDBG() << " fused mem: " << fusedMem;

699 LDBG() << " slice mem: " << sliceMemEstimate;

700

701 if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {

702 LDBG() << "Fusion is not profitable; NOT fusing.";

703 return false;

704 }

705 storageReduction =

706 100.0 *

707 (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));

708

709 double additionalComputeFraction =

710 100.0 * (minFusedLoopNestComputeCost /

711 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -

712 1);

713 (void)additionalComputeFraction;

714 LLVM_DEBUG({

715 std::stringstream msg;

716 msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "

717 << std::setprecision(2) << additionalComputeFraction

718 << "% redundant computation and a ";

719 msg << (storageReduction ? std::to_string(*storageReduction) : "");

720 msg << "% storage reduction.";

721 LDBG() << msg.str();

722 });

723

724 return true;

725}

726

727namespace {

728

729

730

731

732

733

734

735

736

737

738

739

740

741

742

743

744

745

746

747

748

749

750

751

752

753

754

755

756

757

758

759

760

761

762

763

764

765

766

767

768

769

770

771

772

773

774

775struct GreedyFusion {

776public:

777

// NOTE(review): this listing has dropped template arguments (e.g. the
// `std::optional` element type below) and stripped comments; the code tokens
// are kept exactly as they appear here.
// Dependence graph of the block being fused; the pointer is supplied by the
// caller (see the constructor).
778 MemRefDependenceGraph *mdg;
779
// Ids of graph nodes still to be visited by the current fusion phase; refilled
// by init().
780 SmallVector<unsigned, 8> worklist;
781
// Size threshold in bytes: private memrefs at or below it are placed in
// fastMemorySpace when one is configured.
782 unsigned localBufSizeThreshold;
783
// Optional memory space used for small private memrefs.
784 std::optional fastMemorySpace;
785
786
// When true, fusion happens at the maximal legal depth and the profitability
// analysis is skipped.
787 bool maximalFusion;
788
789
// Upper bound on the fraction of redundant computation fusion may introduce.
790 double computeToleranceThreshold;
791
792 using Node = MemRefDependenceGraph::Node;
793
// Stores the configuration only; no fusion happens until one of the run*()
// entry points is called.
794 GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
795 std::optional fastMemorySpace, bool maximalFusion,
796 double computeToleranceThreshold)
797 : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
798 fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
799 computeToleranceThreshold(computeToleranceThreshold) {}

800

801

802 void init() {

803

804

805 worklist.clear();

806 for (auto &idAndNode : mdg->nodes) {

807 const Node &node = idAndNode.second;

808 worklist.push_back(node.id);

809 }

810 }

811

812 void runSiblingFusionOnly() {

813 fuseSiblingNodes();

814 eraseUnusedMemRefAllocations();

815 }

816

817

818 void runProducerConsumerFusionOnly() {

819 fuseProducerConsumerNodes(

820 std::numeric_limits::max());

821 eraseUnusedMemRefAllocations();

822 }

823

824

825

826

827

828

829 void runGreedyFusion() {

830

831 fuseProducerConsumerNodes(1);

832 fuseSiblingNodes();

833 fuseProducerConsumerNodes(

834 std::numeric_limits::max());

835 eraseUnusedMemRefAllocations();

836 }

837

838

839

// Decides whether `memref` may be given a private copy when fusing producer
// node `producerId` into consumer node `consumerId`. `removeSrcNode` tells
// whether the producer nest will be erased after fusion.
// NOTE(review): this listing lacks one parameter line (the declaration of
// `srcEscapingMemRefs`, used below) and the condition guarding the first
// `return false;` — as shown, that return looks unconditional. Confirm both
// against the original source before relying on this text.
840 bool canCreatePrivateMemRef(Value memref,
842 unsigned producerId, unsigned consumerId,
843 bool removeSrcNode) {
844
846 return false;
847 const Node *consumerNode = mdg->getNode(consumerId);
848
849
850
851
852
853
854
// An escaping memref must keep receiving the producer's writes, so it cannot
// be privatized if the producer is removed or the consumer also writes to it.
855 if (srcEscapingMemRefs.count(memref) > 0 &&
856 (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
857 return false;
858
859
860
// Reject if the producer has incoming accesses on `memref` or the consumer has
// outgoing edges on it.
861 if (mdg->getIncomingMemRefAccesses(producerId, memref) > 0 ||
862 mdg->getOutEdgeCount(consumerId, memref) > 0)
863 return false;
864
865
866
867
// If the producer is going away but other nodes (besides the consumer) depend
// on it through `memref`, those dependences must be preserved.
868 if (removeSrcNode &&
869 any_of(mdg->outEdges[producerId], [&](const auto &edge) {
870 return edge.value == memref && edge.id != consumerId;
871 }))
872 return false;
873
874 return true;
875 }

876

877

878

879

// Attempts producer-consumer fusion of candidate source loop nests into the
// destination node `dstId`, repeating until a full pass fuses nothing more.
// A source is skipped when, on any producer-consumer memref, it has more than
// `maxSrcUserCount` outgoing edges.
// NOTE(review): this listing is missing several source lines (e.g. how
// `srcIdCandidates`, `producerConsumerMemrefs`, `srcEscapingMemRefs`,
// `removeSrcNode`, `privateMemrefs` and `privateMemRefToStores` are populated)
// and template arguments (`isa(...)`, `cast(...)`); the comments below only
// describe the visible code.
880 void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
881 LDBG() << "Evaluating dst loop " << dstId;
882
// The node may have been removed from the graph by an earlier fusion.
883 if (mdg->nodes.count(dstId) == 0)
884 return;
885
886 auto *dstNode = mdg->getNode(dstId);
887
// Only loop-nest nodes can be fusion destinations.
888 if (!isa(dstNode->op))
889 return;
890
891
// Destination nests that produce results are not handled.
892 if (dstNode->op->getNumResults() > 0)
893 return;
894
895 LDBG() << "Evaluating dst loop " << dstId;
896
897
898
899
900
902 auto dstAffineForOp = cast(dstNode->op);
903
904
905
// Repeat while the previous pass changed the destination nest.
906 bool dstNodeChanged;
907 do {
908
909
910
911
912
913 dstNodeChanged = false;
914 SmallVector<unsigned, 16> srcIdCandidates;
916
// Candidates are visited from the back of the list to the front.
917 for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
918
919 auto *srcNode = mdg->getNode(srcId);
920 auto srcAffineForOp = cast(srcNode->op);
921
922 LDBG() << "Trying to fuse producer loop nest " << srcId
923 << " with consumer loop nest " << dstId;
924 LDBG() << "Compute tolerance threshold: " << computeToleranceThreshold;
925 LDBG() << "Producer loop nest:";
926 LDBG() << *srcNode->op << " and consumer loop nest:";
927 LDBG() << *dstNode->op;
928
929 LDBG() << "Evaluating src loop " << srcId << " for dst loop " << dstId;
930
931
932
// Source nests that produce results are not handled.
933 if (isa(srcNode->op) && srcNode->op->getNumResults() > 0)
934 continue;
935
938 producerConsumerMemrefs);
939
940
941
// Respect the per-phase limit on how many outgoing edges a producer may have
// on a producer-consumer memref.
942 if (any_of(producerConsumerMemrefs, [&](Value memref) {
943 return mdg->getOutEdgeCount(srcNode->id, memref) >
944 maxSrcUserCount;
945 }))
946 continue;
947
948
949
950
953
954
955
// No insertion point for the fused nest could be found: skip this source.
956 Operation *fusedLoopInsPoint =
957 mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
958 if (fusedLoopInsPoint == nullptr)
959 continue;
960
961
962
963
964
965
966
// Number of loops surrounding the destination; fusion depths below are
// relative and are offset by this count when querying legality.
967 SmallVector<AffineForOp, 4> surroundingLoops;
969 unsigned numSurroundingLoops = surroundingLoops.size();
970
971
972
// Destination loads/stores on producer-consumer memrefs; without any there is
// nothing to fuse on.
973 SmallVector<Operation *, 2> dstMemrefOps;
974 for (Operation *op : dstNode->loads)
975 if (producerConsumerMemrefs.count(
976 cast(op).getMemRef()) > 0)
977 dstMemrefOps.push_back(op);
978 for (Operation *op : dstNode->stores)
979 if (producerConsumerMemrefs.count(
980 cast(op).getMemRef()))
981 dstMemrefOps.push_back(op);
982 if (dstMemrefOps.empty())
983 continue;
984 unsigned dstLoopDepthTest =
986
987
988
// Try every relative depth, keeping the slice union for each, and record the
// deepest depth at which a valid slice was found.
989 unsigned maxLegalFusionDepth = 0;
990 SmallVector<ComputationSliceState, 8> depthSliceUnions;
991 depthSliceUnions.resize(dstLoopDepthTest);
993 for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
994 FusionResult result =
996 i + numSurroundingLoops,
997 &depthSliceUnions[i - 1], strategy);
999 maxLegalFusionDepth = i;
1000 LDBG() << "Found valid slice for depth: " << i;
1001 }
1002 }
1003
1004 if (maxLegalFusionDepth == 0) {
1005 LDBG() << "Can't fuse: fusion is not legal at any depth";
1006 continue;
1007 }
1008
1009 LDBG() << "Max legal depth for fusion: " << maxLegalFusionDepth;
1010
1011 double computeToleranceThresholdToUse = computeToleranceThreshold;
1012
1013
1014
1015
1016
1017
// Cyclic dependence in the source: with maximal fusion, abort this
// destination unless fusion adds no computation; otherwise force the compute
// tolerance to zero.
1019 LDBG() << "Source nest has a cyclic dependence.";
1020
1021
1022
1023 if (maximalFusion) {
1024 auto srcForOp = cast(srcNode->op);
1025 auto dstForOp = cast(dstNode->op);
1026 int64_t sliceCost;
1027 int64_t fusedLoopNestComputeCost;
1029 srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
1030 sliceCost, fusedLoopNestComputeCost);
1031 if (!fraction || fraction > 0) {
1032 LDBG() << "Can't perform maximal fusion with a cyclic dependence "
1033 << "and non-zero additional compute.";
1034 return;
1035 }
1036 } else {
1037
1038
// NOTE(review): the log text below misspells "cyclic".
1039 LDBG() << "Setting compute tolerance to zero since "
1040 << "source has a cylic dependence.";
1041 computeToleranceThresholdToUse = 0;
1042 }
1043 }
1044
1045
1046
1047
// Maximal fusion uses the deepest legal depth; otherwise the profitability
// analysis over the producer's stores to producer-consumer memrefs decides
// (and may lower bestDstLoopDepth).
1048 unsigned bestDstLoopDepth = maxLegalFusionDepth;
1049 if (!maximalFusion) {
1050
1051 SmallVector<Operation *, 2> producerStores;
1052 for (Operation *op : srcNode->stores)
1053 if (producerConsumerMemrefs.count(
1054 cast(op).getMemRef()))
1055 producerStores.push_back(op);
1056
1057 assert(!producerStores.empty() && "Expected producer store");
1059 dstAffineForOp, depthSliceUnions,
1060 maxLegalFusionDepth, &bestDstLoopDepth,
1061 computeToleranceThresholdToUse)) {
1062 continue;
1063 }
1064 }
1065
1066 assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
1067 ComputationSliceState &bestSlice =
1068 depthSliceUnions[bestDstLoopDepth - 1];
1069 assert(!bestSlice.isEmpty() && "Missing slice union for depth");
1070
1071
1072
1073
1075 srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
1076 *mdg);
1077
// Select the producer-consumer memrefs that may be privatized.
1079 for (Value memref : producerConsumerMemrefs) {
1080 if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
1081 removeSrcNode)) {
1082
1083 LDBG() << "Creating private memref for " << memref;
1084
1085 privateMemrefs.insert(memref);
1086 }
1087 }
1088
1089
// Fuse the chosen slice of the source into the destination nest.
1090 fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
1091 dstNodeChanged = true;
1092
1093 LDBG() << "Fused src loop " << srcId << " into dst loop " << dstId
1094 << " at depth " << bestDstLoopDepth << ":";
1095 LDBG() << dstAffineForOp;
1096
1097
// Move the fused nest to the computed insertion point if needed.
1098 if (fusedLoopInsPoint != dstAffineForOp)
1099 dstAffineForOp->moveBefore(fusedLoopInsPoint);
1100
1101
1102 mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
1103 removeSrcNode);
1104
1105
// Create one private memref per privatized memref, from the stores to it
// inside the fused nest, and register it in the graph with an edge to dst.
1106 if (!privateMemrefs.empty()) {
1107
1108
1109 Block *sliceInsertionBlock = bestSlice.insertPoint->getBlock();
1110
1111
1113 dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
1114 Value storeMemRef = storeOp.getMemRef();
1115 if (privateMemrefs.count(storeMemRef) > 0)
1116 privateMemRefToStores[storeMemRef].push_back(storeOp);
1117 });
1118
1119
1120
1121
1122
1123 for (auto &memrefToStoresPair : privateMemRefToStores) {
1124 ArrayRef<Operation *> storesForMemref = memrefToStoresPair.second;
1126 dstAffineForOp, storesForMemref, bestDstLoopDepth,
1127 fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
1128 if (!newMemRef)
1129 continue;
1130
1131 unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
1132
1133 mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
1134 }
1135
1136
1137
// Re-fetch the destination node after the graph was modified above.
1138 dstNode = mdg->getNode(dstId);
1139 }
1140
1141
// Re-collect loads/stores of the fused nest and refresh the destination node.
1142 LoopNestStateCollector dstLoopCollector;
1143 dstLoopCollector.collect(dstAffineForOp);
1144
1145
1146 mdg->clearNodeLoadAndStores(dstNode->id);
1147 mdg->addToNode(
1151
// Erase the source nest and its graph node when it is no longer needed.
1152 if (removeSrcNode) {
1153 LDBG() << "Removing src loop " << srcId << " after fusion";
1154
1155 srcAffineForOp.erase();
1156 mdg->removeNode(srcId);
1157 srcNode = nullptr;
1158 }
1159 }
1160 } while (dstNodeChanged);
1161 }

1162

1163

1164

1165

1166

1167 void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {

1168 LDBG() << "--- Producer/Consumer Fusion ---";

1169 init();

1170 while (!worklist.empty()) {

1171 unsigned dstId = worklist.back();

1172 worklist.pop_back();

1173 performFusionsIntoDest(dstId, maxSrcUserCount);

1174 }

1175 }

1176

1177

1178

1179 void fuseSiblingNodes() {

1180 LDBG() << "--- Sibling Fusion ---";

1181 init();

1182 while (!worklist.empty()) {

1183 unsigned dstId = worklist.back();

1184 worklist.pop_back();

1185

1186

1187 if (mdg->nodes.count(dstId) == 0)

1188 continue;

1189

1190 auto *dstNode = mdg->getNode(dstId);

1191

1192 if (!isa(dstNode->op))

1193 continue;

1194

1195 fuseWithSiblingNodes(dstNode);

1196 }

1197 }

1198

1199

1200 void fuseWithSiblingNodes(Node *dstNode) {

1202 std::pair<unsigned, Value> idAndMemref;

1203 auto dstAffineForOp = cast(dstNode->op);

1204

1205 while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {

1206 unsigned sibId = idAndMemref.first;

1207 Value memref = idAndMemref.second;

1208

1209

1210 auto *sibNode = mdg->getNode(sibId);

1211

1212

1213 assert(sibNode->op->getBlock() == dstNode->op->getBlock());

1214 Operation *insertPointInst =

1215 sibNode->op->isBeforeInBlock(dstNode->op)

1216 ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)

1217 : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);

1218 if (insertPointInst == nullptr)

1219 continue;

1220

1221

1222

1223

1224 SmallVector<Operation *, 2> sibLoadOpInsts;

1225 sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);

1226

1227 Operation *sibLoadOpInst = llvm::getSingleElement(sibLoadOpInsts);

1228

1229

1230 SmallVector<Operation *, 2> dstLoadOpInsts;

1231 dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);

1232

1233

1234

1235

1236

1237

1238

1239 SmallVector<AffineForOp, 4> surroundingLoops;

1241 unsigned numSurroundingLoops = surroundingLoops.size();

1242 SmallVector<AffineForOp, 4> dstLoopIVs;

1244 unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;

1245 auto sibAffineForOp = cast(sibNode->op);

1246

1247

1248 SmallVector<ComputationSliceState, 8> depthSliceUnions;

1249 depthSliceUnions.resize(dstLoopDepthTest);

1250 unsigned maxLegalFusionDepth = 0;

1251 FusionStrategy strategy(memref);

1252 for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {

1253 FusionResult result =

1255 i + numSurroundingLoops,

1256 &depthSliceUnions[i - 1], strategy);

1257

1259 maxLegalFusionDepth = i;

1260 }

1261

1262 LDBG() << "Max legal depth for fusion: " << maxLegalFusionDepth;

1263

1264

1265 if (maxLegalFusionDepth == 0)

1266 continue;

1267

1268 double computeToleranceThresholdToUse = computeToleranceThreshold;

1269

1270

1271

1272

1273

1274

1276 LDBG() << "Source nest has a cyclic dependence.";

1277

1278

1279

1280 if (maximalFusion) {

1281 auto dstForOp = cast(dstNode->op);

1282 int64_t sliceCost;

1283 int64_t fusedLoopNestComputeCost;

1285 sibAffineForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,

1286 sliceCost, fusedLoopNestComputeCost);

1287 if (!fraction || fraction > 0) {

1288 LDBG() << "Can't perform maximal fusion with a cyclic dependence "

1289 << "and non-zero additional compute.";

1290 return;

1291 }

1292 } else {

1293

1294

1295 LDBG() << "Setting compute tolerance to zero since "

1296 << "source has a cyclic dependence.";

1297 computeToleranceThresholdToUse = 0.0;

1298 }

1299 }

1300

1301 unsigned bestDstLoopDepth = maxLegalFusionDepth;

1302 if (!maximalFusion) {

1303

1304

1305

1306

1307 if (isFusionProfitable(sibAffineForOp, sibLoadOpInst, dstAffineForOp,

1308 depthSliceUnions, maxLegalFusionDepth,

1309 &bestDstLoopDepth,

1310 computeToleranceThresholdToUse))

1311 continue;

1312 }

1313

1314 assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");

1315

1316 const ComputationSliceState &bestSlice =

1317 depthSliceUnions[bestDstLoopDepth - 1];

1318 assert(!bestSlice.isEmpty() &&

1319 "Fusion depth has no computed slice union");

1320

1321

1322

1323

1324 auto isMaximal = bestSlice.isMaximal();

1325 if (!isMaximal.value_or(false)) {

1326 LDBG() << "Slice isn't maximal; not performing sibling fusion.";

1327 continue;

1328 }

1329

1330

1331

1332

1333 bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);

1334

1336 isInnermostInsertion);

1337

1338 auto dstForInst = cast(dstNode->op);

1339

1340 if (insertPointInst != dstForInst)

1341 dstForInst->moveBefore(insertPointInst);

1342

1343 LDBG() << "Fused sibling nest " << sibId << " into destination nest "

1344 << dstNode->id << " at depth " << bestDstLoopDepth << ":";

1345 LDBG() << dstAffineForOp;

1346

1347

1348 updateStateAfterSiblingFusion(sibNode, dstNode);

1349

1350

1351

1352 Operation *op = sibNode->op;

1353 mdg->removeNode(sibNode->id);

1355 }

1356 }

1357

1358

1359

1360

1361

// Looks for a sibling of `dstNode`: another loop nest that loads a memref
// `dstNode` also loads and that passes canFuseWithSibNode. On success the
// sibling is marked visited, (sibling id, memref) is written to
// `idAndMemrefToFuse`, and true is returned. Two sources of candidates are
// tried: (1) loads of arguments of dstNode's block, (2) memrefs on incoming
// dependence edges of `dstNode`.
// NOTE(review): this listing is missing the `visitedSibNodeIds` parameter
// line, the declarations of `loadAndStoreMemrefSet` and `storeMemrefs`, the
// loop over block-argument users, and the line filling `loops`.
1362 bool findSiblingNodeToFuse(Node *dstNode,
1364 std::pair<unsigned, Value> *idAndMemrefToFuse) {
1365
1366
// A sibling qualifies on `memref` when: it loads `memref` exactly once; there
// is no dependence path between it and dstNode in either direction; none of
// the memrefs it loads/stores has incoming accesses; and it stores to at most
// one memref.
1367 auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
1368
1369
1370 if (sibNode->getLoadOpCount(memref) != 1)
1371 return false;
1372
1373
1374 if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
1375 mdg->hasDependencePath(dstNode->id, sibNode->id))
1376 return false;
1377
1378
1380 sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
1381 if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) {
1382 return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
1383 }))
1384 return false;
1385
1386
1388 for (auto *storeOpInst : sibNode->stores) {
1389 storeMemrefs.insert(
1390 cast(storeOpInst).getMemRef());
1391 }
1392 return storeMemrefs.size() <= 1;
1393 };
1394
1395
// Search 1: loads of the block's arguments.
1396 Block *block = dstNode->op->getBlock();
1397 for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) {
1399 auto loadOp = dyn_cast(user);
1400 if (!loadOp)
1401 continue;
1402
1403 SmallVector<AffineForOp, 4> loops;
1405
1406
1407
// Pick the enclosing loop that sits directly in the graph's block; loads not
// nested in such a loop are ignored.
1408 auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
1409 return loop->getBlock() == &mdg->block;
1410 });
1411
1412 if (it == loops.end())
1413 continue;
1414 Node *sibNode = mdg->getForOpNode(*it);
1415 assert(sibNode != nullptr);
1416
// Skip dstNode itself and siblings already visited.
1417 if (sibNode->id == dstNode->id)
1418 continue;
1419
1420 if (visitedSibNodeIds->count(sibNode->id) > 0)
1421 continue;
1422
// dstNode must load the same memref.
1423 auto memref = loadOp.getMemRef();
1424 if (dstNode->getLoadOpCount(memref) == 0)
1425 continue;
1426
1427 if (canFuseWithSibNode(sibNode, memref)) {
1428 visitedSibNodeIds->insert(sibNode->id);
1429 idAndMemrefToFuse->first = sibNode->id;
1430 idAndMemrefToFuse->second = memref;
1431 return true;
1432 }
1433 }
1434 }
1435
1436
1437
// Search 2: incoming edges whose memref dstNode loads and which the source
// node either stores to or defines.
1438 SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
1439 mdg->forEachMemRefInputEdge(
1440 dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
1441
1442
1443 if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
1444 (mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0 ||
1445 inEdge.value.getDefiningOp() == mdg->getNode(inEdge.id)->op))
1446 inEdges.push_back(inEdge);
1447 });
1448
1449
1450
// For each such edge, other unvisited consumers of the same memref from the
// same source node are sibling candidates; the first qualifying one wins.
1451 for (auto &inEdge : inEdges) {
1452
1453 SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
1454 mdg->forEachMemRefOutputEdge(
1455 inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
1456 unsigned sibNodeId = outEdge.id;
1457 if (visitedSibNodeIds->count(sibNodeId) > 0)
1458 return;
1459
1460 if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
1461 return;
1462 auto *sibNode = mdg->getNode(sibNodeId);
1463 if (!isa(sibNode->op))
1464 return;
1465
1466 if (canFuseWithSibNode(sibNode, outEdge.value)) {
1467
1468 outEdges.push_back(outEdge);
1469 }
1470 });
1471
1472
1473 if (!outEdges.empty()) {
1474 visitedSibNodeIds->insert(outEdges[0].id);
1475 idAndMemrefToFuse->first = outEdges[0].id;
1476 idAndMemrefToFuse->second = outEdges[0].value;
1477 return true;
1478 }
1479 }
1480 return false;
1481 }

1482

1483

1484

// Updates dependence-graph state after `sibNode` has been fused into
// `dstNode`: rewires edges via mdg->updateEdges and re-collects the state of
// the fused loop nest.
// NOTE(review): the declaration of `dstLoopCollector` and the lines that
// refresh dstNode's loads/stores are missing from this listing.
1485 void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
1486
1487 mdg->updateEdges(sibNode->id, dstNode->id);
1488
1489
1490 auto dstForInst = cast(dstNode->op);
1492 dstLoopCollector.collect(dstForInst);
1493
1498 }

1499

1500

1501 void eraseUnusedMemRefAllocations() {

1503 if (pair.second > 0)

1504 continue;

1505 auto memref = pair.first;

1506

1507 if (memref.use_empty())

1508 continue;

1509

1510 auto *op = memref.getDefiningOp();

1511 if (isa_and_nonnullmemref::AllocOp(op))

1513 }

1514 }

1515};

1516

1517}

1518

1519

1520void LoopFusion::runOnBlock(Block *block) {

1521 MemRefDependenceGraph g(*block);

1522 if (!g.init()) {

1523 LDBG() << "MDG init failed";

1524 return;

1525 }

1526

1527 std::optional fastMemorySpaceOpt;

1528 if (fastMemorySpace.hasValue())

1529 fastMemorySpaceOpt = fastMemorySpace;

1530 unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;

1531 GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,

1532 maximalFusion, computeToleranceThreshold);

1533

1534 if (affineFusionMode == FusionMode::ProducerConsumer)

1535 fusion.runProducerConsumerFusionOnly();

1536 else if (affineFusionMode == FusionMode::Sibling)

1537 fusion.runSiblingFusionOnly();

1538 else

1539 fusion.runGreedyFusion();

1540}

1541

1542void LoopFusion::runOnOperation() {

1543

1544

1545 getOperation()->walk([&](Operation *op) {

1546 for (Region &region : op->getRegions()) {

1547 for (Block &block : region.getBlocks()) {

1548 auto affineFors = block.getOps();

1549 if (!affineFors.empty() && !llvm::hasSingleElement(affineFors))

1550 runOnBlock(&block);

1551 }

1552 }

1553 });

1554}

1555

1557 unsigned fastMemorySpace, uint64_t localBufSizeThreshold,

1558 bool maximalFusion, enum FusionMode affineFusionMode) {

1559 return std::make_unique(fastMemorySpace, localBufSizeThreshold,

1560 maximalFusion, affineFusionMode);

1561}