LLVM: lib/Target/AMDGPU/AMDGPUIGroupLP.cpp Source File

//===--- AMDGPUIGroupLP.cpp - AMDGPU IGroupLP -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-Exception
//
//===----------------------------------------------------------------------===//
//
// \file This file defines a set of schedule DAG mutations that can be used to
// override default scheduler behavior to enforce specific scheduling patterns.
// They should be used in cases where runtime performance considerations such
// as inter-wave occupancy are dominated by other scheduling concerns.
//
//===----------------------------------------------------------------------===//

#include "AMDGPUIGroupLP.h"
#include "AMDGPUTargetMachine.h"
#include "MCTargetDesc/AMDGPUMCTargetDesc.h"
#include "SIInstrInfo.h"
#include "SIMachineFunctionInfo.h"
#include "llvm/ADT/BitmaskEnum.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/CodeGen/MachineScheduler.h"
#include "llvm/CodeGen/TargetOpcodes.h"

using namespace llvm;

#define DEBUG_TYPE "igrouplp"

namespace {

static cl::opt<bool> EnableExactSolver(
    "amdgpu-igrouplp-exact-solver", cl::Hidden,
    cl::desc("Whether to use the exponential time solver to fit "
             "the instructions to the pipeline as closely as "
             "possible."),
    cl::init(false));

static cl::opt<unsigned> CutoffForExact(
    "amdgpu-igrouplp-exact-solver-cutoff", cl::init(0), cl::Hidden,
    cl::desc("The maximum number of scheduling group conflicts "
             "which we attempt to solve with the exponential time "
             "exact solver. Problem sizes greater than this will "
             "be solved by the less accurate greedy algorithm. Selecting "
             "solver by size is superseded by manually selecting "
             "the solver (e.g. by amdgpu-igrouplp-exact-solver)."));

static cl::opt<uint64_t> MaxBranchesExplored(
    "amdgpu-igrouplp-exact-solver-max-branches", cl::init(0), cl::Hidden,
    cl::desc("The amount of branches that we are willing to explore with "
             "the exact algorithm before giving up."));

static cl::opt<bool> UseCostHeur(
    "amdgpu-igrouplp-exact-solver-cost-heur", cl::init(true), cl::Hidden,
    cl::desc("Whether to use the cost heuristic to make choices as we "
             "traverse the search space using the exact solver. Defaulted "
             "to on, and if turned off, we will use the node order -- "
             "attempting to put the later nodes in the later sched groups. "
             "Experimentally, results are mixed, so this should be set on a "
             "case-by-case basis."));

62

63

64
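// Components of the mask that determine which instruction types may be
// classified into a SchedGroup.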

enum class SchedGroupMask {
  NONE = 0u,
  ALU = 1u << 0,
  VALU = 1u << 1,
  SALU = 1u << 2,
  MFMA = 1u << 3,
  VMEM = 1u << 4,
  VMEM_READ = 1u << 5,
  VMEM_WRITE = 1u << 6,
  DS = 1u << 7,
  DS_READ = 1u << 8,
  DS_WRITE = 1u << 9,
  TRANS = 1u << 10,
  ALL = ALU | VALU | SALU | MFMA | VMEM | VMEM_READ | VMEM_WRITE | DS |
        DS_READ | DS_WRITE | TRANS,
  LLVM_MARK_AS_BITMASK_ENUM(/* LargestFlag = */ ALL)
};

82

83class SchedGroup;

84

85

86

87
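// A rule that a SchedGroup may apply to an SUnit, beyond the basic
// SchedGroupMask check, when deciding whether the instruction may be added.
// A rule may cache SUnits it discovers via DAG queries so that repeated
// applications stay cheap.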

class InstructionRule {
protected:
  const SIInstrInfo *TII;
  unsigned SGID;

  // A cache of SUnits found by previous invocations of this rule.
  std::optional<SmallVector<SUnit *, 4>> Cache;

public:
  virtual bool
  apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
        SmallVectorImpl<SchedGroup> &SyncPipe) {
    return true;
  };

  InstructionRule(const SIInstrInfo *TII, unsigned SGID,
                  bool NeedsCache = false)
      : TII(TII), SGID(SGID) {
    if (NeedsCache) {
      Cache = SmallVector<SUnit *, 4>();
    }
  }

  virtual ~InstructionRule() = default;
};

113

115

116

117

118
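// A SchedGroup is a collection of SUnits that all satisfy the group's
// SchedGroupMask and InstructionRules. Groups with the same SyncID belong to
// the same pipeline, and ordering between groups is enforced by adding
// artificial edges to the scheduling DAG.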

119class SchedGroup {

120private:

121

122

123

124 SchedGroupMask SGMask;

125

126

  std::optional<unsigned> MaxSize;

128

129

130

131 int SyncID = 0;

132

133

134 unsigned SGID;

135

136

138

139

140 static unsigned NumSchedGroups;

141

142

144

145

146

148

149public:

150

152

155

156

157 bool canAddSU(SUnit &SU) const;

158

159

160

161

162 void link(SUnit &SU, bool MakePred = false);

163

164

165

166 int link(SUnit &SU, bool MakePred,

167 std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);

168

169

170

171

173

174

175

176 void link(SchedGroup &OtherGroup);

177

178

179 bool isFull() const { return MaxSize && Collection.size() >= *MaxSize; }

180

181

182

183

184

  void addRule(std::shared_ptr<InstructionRule> NewRule) {
    Rules.push_back(NewRule);
  }

188

189

  bool allowedByRules(const SUnit *SU,
                      SmallVectorImpl<SchedGroup> &SyncPipe) const {
    for (auto &Rule : Rules) {

193 if (!Rule->apply(SU, Collection, SyncPipe))

194 return false;

195 }

196 return true;

197 }

198

199

  void add(SUnit &SU) {
    LLVM_DEBUG(dbgs() << "For SchedGroup with mask "
                      << format_hex((int)SGMask, 10, true) << " adding "
                      << *SU.getInstr());
    Collection.push_back(&SU);
  }

206

207

208 void pop() { Collection.pop_back(); }

209

210

211 void initSchedGroup();

212

213

214

215

216

217

  void initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,

219 SUnitsToCandidateSGsMap &SyncedInstrs);

220

221 void initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs);

222

223 int getSyncID() { return SyncID; }

224

225 int getSGID() { return SGID; }

226

227 SchedGroupMask getMask() { return SGMask; }

228

  SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize,
             ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
      : SGMask(SGMask), MaxSize(MaxSize), DAG(DAG), TII(TII) {
    SGID = NumSchedGroups++;
  }

  SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize, int SyncID,
             ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
      : SGMask(SGMask), MaxSize(MaxSize), SyncID(SyncID), DAG(DAG), TII(TII) {
    SGID = NumSchedGroups++;
  }
};

241

242using SUToCandSGsPair = std::pair<SUnit *, SmallVector<int, 4>>;

244

245

246

247

248

249

250

251

252

253
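// The PipelineSolver assigns instructions that have more than one candidate
// SchedGroup to a specific group. It can either use a fast greedy heuristic
// or an exhaustive (exponential time) search that minimizes the number of
// mask/rule mismatches in the resulting pipeline.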

254class PipelineSolver {

256

257

261

263

265

266

267 bool NeedsSolver = false;

268

269

270

271 unsigned computeProblemSize();

272

273

274 int MissPenalty = 0;

275

276

277 int BestCost = -1;

278 int CurrCost = 0;

279

280

281

282 int CurrConflInstNo = 0;

283

284 int CurrSyncGroupIdx = 0;

285

286 int BeginSyncGroupIdx = 0;

287

288

289 uint64_t BranchesExplored = 0;

290

291

292 bool IsBottomUp = true;

293

294

295 void advancePosition();

296

297

298 void retreatPosition();

299

300

301 bool solveExact();

302

303 bool solveGreedy();

304

305

306

  template <typename T>

308 void greedyFind(std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I,

309 T E);

310

311 bool checkOptimal();

312

313

  template <typename T>

315 void populateReadyList(SmallVectorImpl<std::pair<int, int>> &ReadyList, T I,

316 T E);

317

318 void makePipeline();

319

320

  template <typename T> void linkSchedGroups(T I, T E);

322

323

325 std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);

326

327

328

  template <typename T>

330 int linkSUnit(SUnit *SU, int SGID,

331 std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E);

332

333 void removeEdges(const std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);

334

335 void convertSyncMapsToArrays();

336

337 void reset();

338

339public:

340

341

342 void solve();

343

347 : DAG(DAG), SyncedInstrs(SyncedInstrs),

348 SyncedSchedGroups(SyncedSchedGroups), IsBottomUp(IsBottomUp) {

349

350 for (auto &PipelineInstrs : SyncedInstrs) {

351 if (PipelineInstrs.second.size() > 0) {

352 NeedsSolver = true;

353 break;

354 }

355 }

356

357 if (!NeedsSolver)

358 return;

359

360 convertSyncMapsToArrays();

361

362 CurrPipeline = BestPipeline;

363

364 while (static_cast<size_t>(BeginSyncGroupIdx) < PipelineInstrs.size() &&

365 PipelineInstrs[BeginSyncGroupIdx].size() == 0)

366 ++BeginSyncGroupIdx;

367

368 if (static_cast<size_t>(BeginSyncGroupIdx) >= PipelineInstrs.size())

369 return;

370 }

371};

372

void PipelineSolver::reset() {

  for (auto &SyncPipeline : CurrPipeline) {
    for (auto &SG : SyncPipeline) {
      SmallVector<SUnit *, 32> TempCollection = SG.Collection;
      SG.Collection.clear();
      auto *SchedBarr = llvm::find_if(TempCollection, [](SUnit *SU) {
        return SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER;
      });
      if (SchedBarr != TempCollection.end())
        SG.Collection.push_back(*SchedBarr);
    }
  }

386

387 CurrSyncGroupIdx = BeginSyncGroupIdx;

388 CurrConflInstNo = 0;

389 CurrCost = 0;

390}

391

392void PipelineSolver::convertSyncMapsToArrays() {

393 for (auto &SyncPipe : SyncedSchedGroups) {

394 BestPipeline.insert(BestPipeline.begin(), SyncPipe.second);

395 }

396

397 int PipelineIDx = SyncedInstrs.size() - 1;

398 PipelineInstrs.resize(SyncedInstrs.size());

399 for (auto &SyncInstrMap : SyncedInstrs) {

400 for (auto &SUsToCandSGs : SyncInstrMap.second) {

401 if (PipelineInstrs[PipelineIDx].size() == 0) {

402 PipelineInstrs[PipelineIDx].push_back(

403 std::pair(SUsToCandSGs.first, SUsToCandSGs.second));

404 continue;

405 }

406 auto *SortPosition = PipelineInstrs[PipelineIDx].begin();

407

408

409 while (SortPosition != PipelineInstrs[PipelineIDx].end() &&

410 SUsToCandSGs.first->NodeNum > SortPosition->first->NodeNum)

411 ++SortPosition;

412 PipelineInstrs[PipelineIDx].insert(

413 SortPosition, std::pair(SUsToCandSGs.first, SUsToCandSGs.second));

414 }

415 --PipelineIDx;

416 }

417}

418

template <typename T> void PipelineSolver::linkSchedGroups(T I, T E) {

420 for (; I != E; ++I) {

421 auto &GroupA = *I;

422 for (auto J = std::next(I); J != E; ++J) {

423 auto &GroupB = *J;

424 GroupA.link(GroupB);

425 }

426 }

427}

428
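// Materialize the best pipeline found: link each SchedGroup to its
// SCHED_GROUP_BARRIER and add ordering edges between the groups of each
// sync pipeline.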

429void PipelineSolver::makePipeline() {

430

431 for (auto &SyncPipeline : BestPipeline) {

433 for (auto &SG : SyncPipeline) {

434 LLVM_DEBUG(dbgs() << "SchedGroup with SGID " << SG.getSGID()

435 << " has: \n");

436 SUnit *SGBarr = nullptr;

437 for (auto &SU : SG.Collection) {

438 if (SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)

439 SGBarr = SU;

441 }

442

443 if (!SGBarr)

444 continue;

445 SG.link(*SGBarr, false);

446 }

447 }

448

449 for (auto &SyncPipeline : BestPipeline) {

450 IsBottomUp ? linkSchedGroups(SyncPipeline.rbegin(), SyncPipeline.rend())

451 : linkSchedGroups(SyncPipeline.begin(), SyncPipeline.end());

452 }

453}

454

template <typename T>

456int PipelineSolver::linkSUnit(

    SUnit *SU, int SGID, std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges,
    T I, T E) {

459 bool MakePred = false;

460 int AddedCost = 0;

461 for (; I < E; ++I) {

462 if (I->getSGID() == SGID) {

463 MakePred = true;

464 continue;

465 }

466 auto Group = *I;

467 AddedCost += Group.link(*SU, MakePred, AddedEdges);

468 assert(AddedCost >= 0);

469 }

470 return AddedCost;

471}

472

473int PipelineSolver::addEdges(

    SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
    std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {

476

477

478

479

480

481

482

483

484

485 return IsBottomUp ? linkSUnit(SU, SGID, AddedEdges, SyncPipeline.rbegin(),

486 SyncPipeline.rend())

487 : linkSUnit(SU, SGID, AddedEdges, SyncPipeline.begin(),

488 SyncPipeline.end());

489}

490

void PipelineSolver::removeEdges(
    const std::vector<std::pair<SUnit *, SUnit *>> &EdgesToRemove) {
  // Only remove the edges that we have added when testing
  // the fit.
  for (auto &PredSuccPair : EdgesToRemove) {
    SUnit *Pred = PredSuccPair.first;
    SUnit *Succ = PredSuccPair.second;

    auto *Match = llvm::find_if(
        Succ->Preds, [&Pred](SDep &P) { return P.getSUnit() == Pred; });
    if (Match != Succ->Preds.end()) {
      assert(Match->isArtificial());
      Succ->removePred(*Match);
    }
  }
}

507

508void PipelineSolver::advancePosition() {

509 ++CurrConflInstNo;

510

511 if (static_cast<size_t>(CurrConflInstNo) >=

512 PipelineInstrs[CurrSyncGroupIdx].size()) {

513 CurrConflInstNo = 0;

514 ++CurrSyncGroupIdx;

515

516 while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size() &&

517 PipelineInstrs[CurrSyncGroupIdx].size() == 0)

518 ++CurrSyncGroupIdx;

519 }

520}

521

522void PipelineSolver::retreatPosition() {

523 assert(CurrConflInstNo >= 0);

524 assert(CurrSyncGroupIdx >= 0);

525

526 if (CurrConflInstNo > 0) {

527 --CurrConflInstNo;

528 return;

529 }

530

531 if (CurrConflInstNo == 0) {

532

533

534 if (CurrSyncGroupIdx == BeginSyncGroupIdx)

535 return;

536

537 --CurrSyncGroupIdx;

538

539 while (PipelineInstrs[CurrSyncGroupIdx].size() == 0)

540 --CurrSyncGroupIdx;

541

542 CurrConflInstNo = PipelineInstrs[CurrSyncGroupIdx].size() - 1;

543 }

544}

545

546bool PipelineSolver::checkOptimal() {

547 if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size()) {

548 if (BestCost == -1 || CurrCost < BestCost) {

549 BestPipeline = CurrPipeline;

550 BestCost = CurrCost;

551 LLVM_DEBUG(dbgs() << "Found Fit with cost " << BestCost << "\n");

552 }

553 assert(BestCost >= 0);

554 }

555

556 bool DoneExploring = false;

557 if (MaxBranchesExplored > 0 && BranchesExplored >= MaxBranchesExplored)

558 DoneExploring = true;

559

560 return (DoneExploring || BestCost == 0);

561}

562

template <typename T>

void PipelineSolver::populateReadyList(
    SmallVectorImpl<std::pair<int, int>> &ReadyList, T I, T E) {

566 SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];

567 auto SyncPipeline = CurrPipeline[CurrSyncGroupIdx];

568 assert(CurrSU.second.size() >= 1);

569

570 for (; I != E; ++I) {

571 std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;

572 int CandSGID = *I;

573 SchedGroup *Match = llvm::find_if(SyncPipeline, [CandSGID](SchedGroup &SG) {

574 return SG.getSGID() == CandSGID;

575 });

577

578 if (UseCostHeur) {

579 if (Match->isFull()) {

580 ReadyList.push_back(std::pair(*I, MissPenalty));

581 continue;

582 }

583

584 int TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);

585 ReadyList.push_back(std::pair(*I, TempCost));

586 removeEdges(AddedEdges);

587 } else

588 ReadyList.push_back(std::pair(*I, -1));

589 }

590

591 if (UseCostHeur)

592 std::sort(ReadyList.begin(), ReadyList.end(), llvm::less_second());

593

594 assert(ReadyList.size() == CurrSU.second.size());

595}

596
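// Exhaustively try every candidate SchedGroup for each conflicted
// instruction, backtracking and pruning against the best cost found so far.
// Leaving the instruction unassigned is also explored, at the price of
// MissPenalty.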

597bool PipelineSolver::solveExact() {

598 if (checkOptimal())

599 return true;

600

601 if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size())

602 return false;

603

604 assert(static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size());

605 assert(static_cast<size_t>(CurrConflInstNo) <

606 PipelineInstrs[CurrSyncGroupIdx].size());

607 SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];

608 LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum

609 << ") in Pipeline # " << CurrSyncGroupIdx << "\n");

610

611

  SmallVector<std::pair<int, int>, 4> ReadyList;

614 IsBottomUp ? populateReadyList(ReadyList, CurrSU.second.rbegin(),

615 CurrSU.second.rend())

616 : populateReadyList(ReadyList, CurrSU.second.begin(),

617 CurrSU.second.end());

618

619 auto *I = ReadyList.begin();

620 auto *E = ReadyList.end();

621 for (; I != E; ++I) {

622

623

624

625 if (BestCost != -1 && (CurrCost + I->second > BestCost))

626 return false;

627

628 int CandSGID = I->first;

629 int AddedCost = 0;

630 std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;

631 auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];

632 SchedGroup *Match;

633 for (auto &SG : SyncPipeline) {

634 if (SG.getSGID() == CandSGID)

635 Match = &SG;

636 }

637

638 if (Match->isFull())

639 continue;

640

641 if (!Match->allowedByRules(CurrSU.first, SyncPipeline))

642 continue;

643

644 LLVM_DEBUG(dbgs() << "Assigning to SchedGroup with Mask "

645 << (int)Match->getMask() << "and ID " << CandSGID

646 << "\n");

647 Match->add(*CurrSU.first);

648 AddedCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);

649 LLVM_DEBUG(dbgs() << "Cost of Assignment: " << AddedCost << "\n");

650 CurrCost += AddedCost;

651 advancePosition();

652 ++BranchesExplored;

653 bool FinishedExploring = false;

654

655

656 if (CurrCost < BestCost || BestCost == -1) {

657 if (solveExact()) {

658 FinishedExploring = BestCost != 0;

659 if (!FinishedExploring)

660 return true;

661 }

662 }

663

664 retreatPosition();

665 CurrCost -= AddedCost;

666 removeEdges(AddedEdges);

667 Match->pop();

668 CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;

669 if (FinishedExploring)

670 return true;

671 }

672

673

674

675

676 CurrCost += MissPenalty;

677 advancePosition();

678

679 LLVM_DEBUG(dbgs() << "NOT Assigned (" << CurrSU.first->NodeNum << ")\n");

680

681 bool FinishedExploring = false;

682 if (CurrCost < BestCost || BestCost == -1) {

683 if (solveExact()) {

684 bool FinishedExploring = BestCost != 0;

685 if (!FinishedExploring)

686 return true;

687 }

688 }

689

690 retreatPosition();

691 CurrCost -= MissPenalty;

692 return FinishedExploring;

693}

694

template <typename T>

696void PipelineSolver::greedyFind(

697 std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E) {

698 SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];

699 int BestNodeCost = -1;

700 int TempCost;

701 SchedGroup *BestGroup = nullptr;

702 int BestGroupID = -1;

703 auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];

704 LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum

705 << ") in Pipeline # " << CurrSyncGroupIdx << "\n");

706

707

708

709

710

711 for (; I != E; ++I) {

712 std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;

713 int CandSGID = *I;

714 SchedGroup *Match = llvm::find_if(SyncPipeline, [CandSGID](SchedGroup &SG) {

715 return SG.getSGID() == CandSGID;

716 });

718

719 LLVM_DEBUG(dbgs() << "Trying SGID # " << CandSGID << " with Mask "

720 << (int)Match->getMask() << "\n");

721

722 if (Match->isFull()) {

723 LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " is full\n");

724 continue;

725 }

726 if (!Match->allowedByRules(CurrSU.first, SyncPipeline)) {

727 LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " has conflicting rule\n");

728 continue;

729 }

730 TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);

731 LLVM_DEBUG(dbgs() << "Cost of Group " << TempCost << "\n");

732 if (TempCost < BestNodeCost || BestNodeCost == -1) {

733 BestGroup = Match;

734 BestNodeCost = TempCost;

735 BestGroupID = CandSGID;

736 }

737 removeEdges(AddedEdges);

738 if (BestNodeCost == 0)

739 break;

740 }

741

742 if (BestGroupID != -1) {

743 BestGroup->add(*CurrSU.first);

744 addEdges(SyncPipeline, CurrSU.first, BestGroupID, AddedEdges);

745 LLVM_DEBUG(dbgs() << "Best Group has ID: " << BestGroupID << " and Mask"

746 << (int)BestGroup->getMask() << "\n");

747 BestCost += TempCost;

748 } else

749 BestCost += MissPenalty;

750

751 CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;

752}

753
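// Assign each conflicted instruction to the candidate SchedGroup with the
// lowest immediate edge-addition cost, without backtracking.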

754bool PipelineSolver::solveGreedy() {

755 BestCost = 0;

756 std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;

757

758 while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size()) {

759 SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];

760 IsBottomUp

761 ? greedyFind(AddedEdges, CurrSU.second.rbegin(), CurrSU.second.rend())

762 : greedyFind(AddedEdges, CurrSU.second.begin(), CurrSU.second.end());

763 advancePosition();

764 }

765 BestPipeline = CurrPipeline;

766 removeEdges(AddedEdges);

767 return false;

768}

769

770unsigned PipelineSolver::computeProblemSize() {

771 unsigned ProblemSize = 0;

772 for (auto &PipeConflicts : PipelineInstrs) {

773 ProblemSize += PipeConflicts.size();

774 }

775

776 return ProblemSize;

777}

778
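// Run the greedy solver first; if its result is imperfect and the problem is
// below the exact-solver cutoff (or the exact solver is forced on), retry
// with the exact solver before materializing the pipeline edges.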

779void PipelineSolver::solve() {

780 if (!NeedsSolver)

781 return;

782

783 unsigned ProblemSize = computeProblemSize();

784 assert(ProblemSize > 0);

785

786 bool BelowCutoff = (CutoffForExact > 0) && ProblemSize <= CutoffForExact;

787 MissPenalty = (ProblemSize / 2) + 1;

788

790 if (EnableExactSolver || BelowCutoff) {

791 LLVM_DEBUG(dbgs() << "Starting Greedy pipeline solver\n");

792 solveGreedy();

793 reset();

794 LLVM_DEBUG(dbgs() << "Greedy produced best cost of " << BestCost << "\n");

795 if (BestCost > 0) {

796 LLVM_DEBUG(dbgs() << "Starting EXACT pipeline solver\n");

797 solveExact();

798 LLVM_DEBUG(dbgs() << "Exact produced best cost of " << BestCost << "\n");

799 }

800 } else {

801 LLVM_DEBUG(dbgs() << "Starting GREEDY pipeline solver\n");

802 solveGreedy();

803 }

804

805 makePipeline();

808}

809
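// IDs for the scheduling strategies below; createIGLPStrategy() maps an ID to
// the strategy object that builds the corresponding SchedGroup pipeline.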

810enum IGLPStrategyID : int {

811 MFMASmallGemmOptID = 0,

812 MFMASmallGemmSingleWaveOptID = 1,

813 MFMAExpInterleaveID = 2,

814 MFMAExpSimpleInterleaveID = 3

815};

816

817

818class IGLPStrategy {

819protected:

821

823

824public:

825

826 virtual bool applyIGLPStrategy(

830

831

834

835 bool IsBottomUp = true;

836

  IGLPStrategy(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
      : DAG(DAG), TII(TII) {}

839

840 virtual ~IGLPStrategy() = default;

841};

842

843class MFMASmallGemmOpt final : public IGLPStrategy {

844private:

845public:

846 bool applyIGLPStrategy(

850

853 return true;

854 }

855

  MFMASmallGemmOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
      : IGLPStrategy(DAG, TII) {

858 IsBottomUp = true;

859 }

860};

861

bool MFMASmallGemmOpt::applyIGLPStrategy(
    DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
    DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
    AMDGPU::SchedulingPhase Phase) {
  // Count the number of MFMA instructions.
  unsigned MFMACount = 0;
  for (auto &I : *DAG)
    if (TII->isMFMAorWMMA(I))
      ++MFMACount;

871

872 const unsigned PipelineSyncID = 0;

873 SchedGroup *SG = nullptr;

874 for (unsigned I = 0; I < MFMACount * 3; ++I) {

875 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

876 SchedGroupMask::DS, 2, PipelineSyncID, DAG, TII);

877 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

878

879 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

880 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);

881 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

882 }

883

884 return true;

885}

886
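// Interleaves the TRANS (exp), conversion, and FMA chains that feed the MFMAs
// with the MFMAs themselves, using the DAG analysis in analyzeDAG() to size
// each SchedGroup.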

887class MFMAExpInterleaveOpt final : public IGLPStrategy {

888private:

889

890 static unsigned TransPipeCount;

891

892 static unsigned MFMAPipeCount;

893

894 static unsigned AddPipeCount;

895

896 static unsigned MFMAEnablement;

897

898 static unsigned ExpRequirement;

899

900 static unsigned MFMAChains;

901

902 static unsigned MFMAChainLength;

903

904 static bool HasCvt;

905

906

907 static bool HasChainBetweenCvt;

908

909 static std::optional FirstPipeDSR;

910

912

913

915

916

917

918 class IsPipeExp final : public InstructionRule {

919 public:

922

923 auto *DAG = SyncPipe[0].DAG;

924

925 if (Cache->empty()) {

926 auto I = DAG->SUnits.rbegin();

927 auto E = DAG->SUnits.rend();

928 for (; I != E; I++) {

929 if (TII->isMFMAorWMMA(*I->getInstr()))

930 Cache->push_back(&*I);

931 }

932 if (Cache->empty())

933 return false;

934 }

935

936 auto Reaches = any_of(*Cache, [&SU, &DAG](SUnit *TargetSU) {

937 return DAG->IsReachable(TargetSU, const_cast<SUnit *>(SU));

938 });

939

940 return Reaches;

941 }

942 IsPipeExp(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)

943 : InstructionRule(TII, SGID, NeedsCache) {}

944 };

945

946

947

948 class EnablesNthMFMA final : public InstructionRule {

949 private:

951

952 public:

955 bool FoundTrans = false;

956 unsigned Counter = 1;

957 auto *DAG = SyncPipe[0].DAG;

958

959 if (Cache->empty()) {

960 auto I = DAG->SUnits.begin();

961 auto E = DAG->SUnits.end();

962 for (; I != E; I++) {

963 if (FoundTrans && TII->isMFMAorWMMA(*I->getInstr())) {

964 if (Counter == Number) {

965 Cache->push_back(&*I);

966 break;

967 }

968 ++Counter;

969 }

970 if (!FoundTrans && TII->isTRANS(I->getInstr()->getOpcode()))

971 FoundTrans = true;

972 }

973 if (Cache->empty())

974 return false;

975 }

976

977 return DAG->IsReachable((*Cache)[0], const_cast<SUnit *>(SU));

978 }

979

981 bool NeedsCache = false)

982 : InstructionRule(TII, SGID, NeedsCache), Number(Number) {}

983 };

984

985

986

987 class EnablesNthMFMAInChain final : public InstructionRule {

988 private:

990 SUnit *ChainSeed;

991

992 public:

995 auto *DAG = SyncPipe[0].DAG;

996

997 if (!SU || TII->isMFMAorWMMA(*ChainSeed->getInstr()))

998 return false;

999

1000 if (Cache->empty()) {

1001 auto *TempSU = ChainSeed;

1003 while (Depth > 0) {

1005 bool Found = false;

1006 for (auto &Succ : TempSU->Succs) {

1007 if (TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr())) {

1008 TempSU = Succ.getSUnit();

1009 Found = true;

1010 break;

1011 }

1012 }

1013 if (!Found)

1014 return false;

1015 }

1016

1017 Cache->push_back(TempSU);

1018 }

1019

1020

1021 assert(!Cache->empty());

1022

1023 return DAG->IsReachable((*Cache)[0], const_cast<SUnit *>(SU));

1024 }

1025

1026 EnablesNthMFMAInChain(unsigned Number, SUnit *ChainSeed,

1028 bool NeedsCache = false)

1029 : InstructionRule(TII, SGID, NeedsCache), Number(Number),

1030 ChainSeed(ChainSeed) {}

1031 };

1032

1033

1034

1035

1036 class LessThanNSuccs final : public InstructionRule {

1037 private:

1038 unsigned Size = 1;

1039 bool HasIntermediary = false;

1040

1041 public:

1044 if (!SyncPipe.size())

1045 return false;

1046

1048 return Succ.getKind() == SDep::Data;

1049 });

1050 if (SuccSize >= Size)

1051 return false;

1052

1053 if (HasIntermediary) {

1054 for (auto Succ : SU->Succs) {

1055 auto SuccSize =

1057 return SuccSucc.getKind() == SDep::Data;

1058 });

1059 if (SuccSize >= Size)

1060 return false;

1061 }

1062 }

1063

1064 return true;

1065 }

1066 LessThanNSuccs(unsigned Size, const SIInstrInfo *TII, unsigned SGID,

1067 bool HasIntermediary = false, bool NeedsCache = false)

1068 : InstructionRule(TII, SGID, NeedsCache), Size(Size),

1069 HasIntermediary(HasIntermediary) {}

1070 };

1071

1072

1073

1074

1075

1076 class GreaterThanOrEqualToNSuccs final : public InstructionRule {

1077 private:

1078 unsigned Size = 1;

1079 bool HasIntermediary = false;

1080

1081 public:

1084 if (!SyncPipe.size())

1085 return false;

1086

1088 return Succ.getKind() == SDep::Data;

1089 });

1090 if (SuccSize >= Size)

1091 return true;

1092

1093 if (HasIntermediary) {

1094 for (auto Succ : SU->Succs) {

1095 auto SuccSize =

1097 return SuccSucc.getKind() == SDep::Data;

1098 });

1099 if (SuccSize >= Size)

1100 return true;

1101 }

1102 }

1103

1104 return false;

1105 }

1107 unsigned SGID, bool HasIntermediary = false,

1108 bool NeedsCache = false)

1109 : InstructionRule(TII, SGID, NeedsCache), Size(Size),

1110 HasIntermediary(HasIntermediary) {}

1111 };

1112

1113

1114 class IsCvt final : public InstructionRule {

1115 public:

1119 return Opc == AMDGPU::V_CVT_F16_F32_e32 ||

1120 Opc == AMDGPU::V_CVT_I32_F32_e32;

1121 }

1122 IsCvt(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)

1123 : InstructionRule(TII, SGID, NeedsCache) {}

1124 };

1125

1126

1127 class IsFMA final : public InstructionRule {

1128 public:

1133 }

1134 IsFMA(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)

1135 : InstructionRule(TII, SGID, NeedsCache) {}

1136 };

1137

1138

1139 class IsPipeAdd final : public InstructionRule {

1140 public:

1144 }

1145 IsPipeAdd(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)

1146 : InstructionRule(TII, SGID, NeedsCache) {}

1147 };

1148

1149

1150

1151 class IsSuccOfPrevNthGroup final : public InstructionRule {

1152 private:

1153 unsigned Distance = 1;

1154

1155 public:

1158 SchedGroup *OtherGroup = nullptr;

1159 if (!SyncPipe.size())

1160 return false;

1161

1162 for (auto &PipeSG : SyncPipe) {

1163 if ((unsigned)PipeSG.getSGID() == SGID - Distance)

1164 OtherGroup = &PipeSG;

1165 }

1166

1167 if (!OtherGroup)

1168 return false;

1169 if (!OtherGroup->Collection.size())

1170 return true;

1171

1172 for (auto &OtherEle : OtherGroup->Collection) {

1173 for (auto &Succ : OtherEle->Succs) {

1174 if (Succ.getSUnit() == SU && Succ.getKind() == SDep::Data)

1175 return true;

1176 }

1177 }

1178

1179 return false;

1180 }

1181 IsSuccOfPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,

1182 unsigned SGID, bool NeedsCache = false)

1183 : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}

1184 };

1185

1186

1187

1188 class IsReachableFromPrevNthGroup final : public InstructionRule {

1189 private:

1190 unsigned Distance = 1;

1191

1192 public:

1195 SchedGroup *OtherGroup = nullptr;

1196 if (!SyncPipe.size())

1197 return false;

1198

1199 for (auto &PipeSG : SyncPipe) {

1200 if ((unsigned)PipeSG.getSGID() == SGID - Distance)

1201 OtherGroup = &PipeSG;

1202 }

1203

1204 if (!OtherGroup)

1205 return false;

1206 if (!OtherGroup->Collection.size())

1207 return true;

1208

1209 auto *DAG = SyncPipe[0].DAG;

1210

1211 for (auto &OtherEle : OtherGroup->Collection)

1212 if (DAG->IsReachable(const_cast<SUnit *>(SU), OtherEle))

1213 return true;

1214

1215 return false;

1216 }

1217 IsReachableFromPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,

1218 unsigned SGID, bool NeedsCache = false)

1219 : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}

1220 };

1221

1222

1223 class OccursAtOrAfterNode final : public InstructionRule {

1224 private:

1225 unsigned Number = 1;

1226

1227 public:

1230

1232 }

1234 bool NeedsCache = false)

1235 : InstructionRule(TII, SGID, NeedsCache), Number(Number) {}

1236 };

1237

1238

1239

1240 class IsExactMFMA final : public InstructionRule {

1241 private:

1242 unsigned Number = 1;

1243 SUnit *ChainSeed;

1244

1245 public:

1248 if (!SU || TII->isMFMAorWMMA(*ChainSeed->getInstr()))

1249 return false;

1250

1251 if (Cache->empty()) {

1252 auto *TempSU = ChainSeed;

1254 while (Depth > 0) {

1256 bool Found = false;

1257 for (auto &Succ : TempSU->Succs) {

1258 if (TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr())) {

1259 TempSU = Succ.getSUnit();

1260 Found = true;

1261 break;

1262 }

1263 }

1264 if (!Found) {

1265 return false;

1266 }

1267 }

1268 Cache->push_back(TempSU);

1269 }

1270

1271

1272 assert(!Cache->empty());

1273

1274 return (*Cache)[0] == SU;

1275 }

1276

1278 unsigned SGID, bool NeedsCache = false)

1279 : InstructionRule(TII, SGID, NeedsCache), Number(Number),

1280 ChainSeed(ChainSeed) {}

1281 };

1282

1283

1284

1285

1286 class OccursAfterExp final : public InstructionRule {

1287 public:

1290

1291 auto *DAG = SyncPipe[0].DAG;

1292 if (Cache->empty()) {

1293 for (auto &SU : DAG->SUnits)

1295 Cache->push_back(&SU);

1296 break;

1297 }

1298 if (Cache->empty())

1299 return false;

1300 }

1301

1302 return SU->NodeNum > (*Cache)[0]->NodeNum;

1303 }

1304

1305 OccursAfterExp(const SIInstrInfo *TII, unsigned SGID,

1306 bool NeedsCache = false)

1307 : InstructionRule(TII, SGID, NeedsCache) {}

1308 };

1309

1310public:

1311 bool applyIGLPStrategy(

1315

1318

  MFMAExpInterleaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
      : IGLPStrategy(DAG, TII) {

1321 IsBottomUp = false;

1322 }

1323};

1324

1325unsigned MFMAExpInterleaveOpt::TransPipeCount = 0;

1326unsigned MFMAExpInterleaveOpt::MFMAPipeCount = 0;

1327unsigned MFMAExpInterleaveOpt::AddPipeCount = 0;

1328unsigned MFMAExpInterleaveOpt::MFMAEnablement = 0;

1329unsigned MFMAExpInterleaveOpt::ExpRequirement = 0;

1330unsigned MFMAExpInterleaveOpt::MFMAChains = 0;

1331unsigned MFMAExpInterleaveOpt::MFMAChainLength = 0;

1332bool MFMAExpInterleaveOpt::HasCvt = false;

1333bool MFMAExpInterleaveOpt::HasChainBetweenCvt = false;

1334std::optional MFMAExpInterleaveOpt::FirstPipeDSR = std::nullopt;

1335
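// Analyze the DAG to find the TRANS -> (cvt/pack) -> MFMA structure: the size
// of each pipe, the number and length of MFMA chains, and how many MFMAs each
// TRANS op ultimately enables. The results are stored in the static members
// above and drive applyIGLPStrategy.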

1336bool MFMAExpInterleaveOpt::analyzeDAG(const SIInstrInfo *TII) {

1342

1343 auto isBitPack = [](unsigned Opc) {

1344 return Opc == AMDGPU::V_PACK_B32_F16_e64 || Opc == AMDGPU::V_PERM_B32_e64;

1345 };

1346

1347 auto isCvt = [](unsigned Opc) {

1348 return Opc == AMDGPU::V_CVT_F16_F32_e32 || Opc == AMDGPU::V_CVT_I32_F32_e32;

1349 };

1350

1351 auto isAdd = [](unsigned Opc) { return Opc == AMDGPU::V_ADD_F32_e32; };

1352

1353 AddPipeCount = 0;

1354 for (SUnit &SU : DAG->SUnits) {

1356 if (TII->isTRANS(Opc)) {

1357

1358 if (SU.Succs.size() >= 7)

1359 continue;

1360 for (auto &Succ : SU.Succs) {

1361 if (Succ.getSUnit()->Succs.size() >= 7)

1362 continue;

1363 }

1365 }

1366

1369

1370 if (isBitPack(Opc))

1372

1373 if (isCvt(Opc))

1375

1376 if (isAdd(Opc))

1377 ++AddPipeCount;

1378 }

1379

1380 if (!(PackSUs.size() && MFMAPipeCands.size() && ExpPipeCands.size()))

1381 return false;

1382

1383 TransPipeCount = 0;

1384

1385 std::optional<SUnit *> TempMFMA;

1386 std::optional<SUnit *> TempExp;

1387

1388 for (auto &PredSU : ExpPipeCands) {

1389 for (auto &SuccSU : MFMAPipeCands) {

1390 if (DAG->IsReachable(SuccSU, PredSU)) {

1391 if (!TempExp) {

1392 TempExp = PredSU;

1393 TempMFMA = SuccSU;

1394 }

1396 ++TransPipeCount;

1397 break;

1398 }

1399 }

1400 }

1401

1402 if (!(TempExp && TempMFMA))

1403 return false;

1404

1405 HasChainBetweenCvt = none_of((*TempExp)->Succs, [&isCvt](SDep &Succ) {

1406 return isCvt(Succ.getSUnit()->getInstr()->getOpcode());

1407 });

1408

1409

1410 for (auto &SuccSU : MFMAPipeCands) {

1411 if (MFMAPipeSUs.size() &&

1412 any_of(MFMAPipeSUs, [&SuccSU](SUnit *PotentialMatch) {

1413 return PotentialMatch->NodeNum == SuccSU->NodeNum;

1414 }))

1415 continue;

1416

1417 for (auto &PredSU : ExpPipeCands) {

1418 if (DAG->IsReachable(SuccSU, PredSU)) {

1420 break;

1421 }

1422 }

1423 }

1424

1425 MFMAPipeCount = MFMAPipeSUs.size();

1426

1427 assert(TempExp && TempMFMA);

1428 assert(MFMAPipeCount > 0);

1429

1430 std::optional<SUnit *> TempCvt;

1431 for (auto &SuccSU : CvtSUs) {

1432 if (DAG->IsReachable(SuccSU, *TempExp)) {

1433 TempCvt = SuccSU;

1434 break;

1435 }

1436 }

1437

1438 HasCvt = false;

1439 if (TempCvt.has_value()) {

1440 for (auto &SuccSU : MFMAPipeSUs) {

1441 if (DAG->IsReachable(SuccSU, *TempCvt)) {

1442 HasCvt = true;

1443 break;

1444 }

1445 }

1446 }

1447

1448 MFMAChains = 0;

1449 for (auto &MFMAPipeSU : MFMAPipeSUs) {

1450 if (is_contained(MFMAChainSeeds, MFMAPipeSU))

1451 continue;

1453 return TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr());

1454 })) {

1455 MFMAChainSeeds.push_back(MFMAPipeSU);

1456 ++MFMAChains;

1457 }

1458 }

1459

1460 if (!MFMAChains)

1461 return false;

1462

1463 for (auto Pred : MFMAChainSeeds[0]->Preds) {

1464 if (TII->isDS(Pred.getSUnit()->getInstr()->getOpcode()) &&

1465 Pred.getSUnit()->getInstr()->mayLoad())

1466 FirstPipeDSR = Pred.getSUnit()->NodeNum;

1467 }

1468

1469 MFMAChainLength = MFMAPipeCount / MFMAChains;

1470

1471

1472 unsigned PackSuccCount =

1474 return DAG->IsReachable(VPack, *TempExp);

1475 });

1476

1477

1478 unsigned PackPredCount =

1480 auto Opc = Pred.getSUnit()->getInstr()->getOpcode();

1481 return isBitPack(Opc);

1482 });

1483

1484 auto *PackPred = llvm::find_if((*TempMFMA)->Preds, [&isBitPack](SDep &Pred) {

1485 auto Opc = Pred.getSUnit()->getInstr()->getOpcode();

1486 return isBitPack(Opc);

1487 });

1488

1489 if (PackPred == (*TempMFMA)->Preds.end())

1490 return false;

1491

1492 MFMAEnablement = 0;

1493 ExpRequirement = 0;

1494

1495 MFMAEnablement =

1497 return TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr());

1498 });

1499

1500

1501 MFMAEnablement *= PackSuccCount;

1502

1503

1504 ExpRequirement =

1506 return DAG->IsReachable(PackPred->getSUnit(), ExpBase);

1507 });

1508

1509 ExpRequirement *= PackPredCount;

1510 return true;

1511}

1512

1513bool MFMAExpInterleaveOpt::shouldApplyStrategy(ScheduleDAGInstrs *DAG,

1517

1519 MFMAChainSeeds.clear();

1521 return false;

1522

1523 return true;

1524}

1525

1526bool MFMAExpInterleaveOpt::applyIGLPStrategy(

1530

1531 bool IsSmallKernelType =

1532 MFMAEnablement == 2 && ExpRequirement == 4 && TransPipeCount == 32;

1533 bool IsLargeKernelType =

1534 MFMAEnablement == 4 && ExpRequirement == 4 && TransPipeCount == 64;

1535

1536 if (!(IsSmallKernelType || IsLargeKernelType))

1537 return false;

1538

1541

1542 unsigned PipelineSyncID = 0;

1543 SchedGroup *SG = nullptr;

1544

1545 unsigned MFMAChain = 0;

1546 unsigned PositionInChain = 0;

1547 unsigned CurrMFMAForTransPosition = 0;

1548

1549 auto incrementTransPosition = [&MFMAChain, &PositionInChain,

1550 &CurrMFMAForTransPosition]() {

1551 CurrMFMAForTransPosition += MFMAEnablement;

1552 PositionInChain = (CurrMFMAForTransPosition / MFMAChains);

1553 MFMAChain = CurrMFMAForTransPosition % MFMAChains;

1554 };

1555

1556 auto getNextTransPositionInChain = [&CurrMFMAForTransPosition]() {

1557 auto TempMFMAForTrans = CurrMFMAForTransPosition + MFMAEnablement;

1558 return (TempMFMAForTrans / MFMAChains);

1559 };

1560

1561 auto getNextTransMFMAChain = [&CurrMFMAForTransPosition]() {

1562 auto TempMFMAForTrans = CurrMFMAForTransPosition + MFMAEnablement;

1563 return TempMFMAForTrans % MFMAChains;

1564 };

1565

1566 unsigned CurrMFMAPosition = 0;

1567 unsigned MFMAChainForMFMA = 0;

1568 unsigned PositionInChainForMFMA = 0;

1569

1570 auto incrementMFMAPosition = [&CurrMFMAPosition, &MFMAChainForMFMA,

1571 &PositionInChainForMFMA]() {

1572 ++CurrMFMAPosition;

1573 MFMAChainForMFMA = CurrMFMAPosition % MFMAChains;

1574 PositionInChainForMFMA = CurrMFMAPosition / MFMAChains;

1575 };

1576

1578 assert(IsPostRA || MFMAChainSeeds.size() == MFMAChains);

1579

1580 bool UsesFMA = IsSmallKernelType || !IsPostRA;

1581 bool UsesDSRead = IsLargeKernelType && !IsPostRA && FirstPipeDSR;

1582 bool UsesCvt = HasCvt && (IsSmallKernelType || !IsPostRA);

1583 bool UsesVALU = IsSmallKernelType;

1584

1585

1586 if (UsesFMA) {

1587

1588 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

1589 SchedGroupMask::VALU, ExpRequirement, PipelineSyncID, DAG, TII);

1590 if (!IsPostRA && MFMAChains) {

1591 SG->addRule(std::make_shared(

1592 PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(),

1593 true));

1594 } else

1595 SG->addRule(

1596 std::make_shared(1, TII, SG->getSGID(), true));

1597 SG->addRule(std::make_shared(TII, SG->getSGID()));

1598 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

1599

1600

1601 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

1602 SchedGroupMask::VALU, ExpRequirement, PipelineSyncID, DAG, TII);

1603 if (!IsPostRA && MFMAChains) {

1604 SG->addRule(std::make_shared(

1605 getNextTransPositionInChain(),

1606 MFMAChainSeeds[getNextTransMFMAChain()], TII, SG->getSGID(), true));

1607 } else

1608 SG->addRule(std::make_shared(MFMAEnablement + 1, TII,

1609 SG->getSGID(), true));

1610 SG->addRule(std::make_shared(TII, SG->getSGID()));

1611 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

1612 }

1613

1614 if (UsesDSRead) {

1615 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

1616 SchedGroupMask::DS_READ, 2, PipelineSyncID, DAG, TII);

1617 SG->addRule(std::make_shared(*FirstPipeDSR, TII,

1618 SG->getSGID()));

1619 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

1620 }

1621

1622

1623 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

1624 SchedGroupMask::TRANS, ExpRequirement, PipelineSyncID, DAG, TII);

1625 if (!IsPostRA && MFMAChains)

1626 SG->addRule(std::make_shared(

1627 PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(), true));

1628 else

1629 SG->addRule(std::make_shared(1, TII, SG->getSGID(), true));

1630 SG->addRule(std::make_shared(TII, SG->getSGID(), true));

1631 SG->addRule(std::make_shared(8, TII, SG->getSGID(),

1632 HasChainBetweenCvt));

1633 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

1634

1635 incrementTransPosition();

1636

1637

1638 for (unsigned I = 0; I < ExpRequirement; I++) {

1639

1640 if (UsesCvt) {

1641 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

1642 SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);

1643 SG->addRule(std::make_shared(TII, SG->getSGID()));

1644 if (HasChainBetweenCvt)

1645 SG->addRule(std::make_shared(

1646 1 + (2 + UsesFMA) * I, TII, SG->getSGID()));

1647 else

1648 SG->addRule(std::make_shared(

1649 1 + (2 + UsesFMA) * I, TII, SG->getSGID()));

1650 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

1651 }

1652

1653

1654 if (UsesFMA) {

1655 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

1656 SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);

1657 if (!IsPostRA && MFMAChains) {

1658 SG->addRule(std::make_shared(

1659 getNextTransPositionInChain(),

1660 MFMAChainSeeds[getNextTransMFMAChain()], TII, SG->getSGID(), true));

1661 } else

1662 SG->addRule(std::make_shared(2 * MFMAEnablement + 1,

1663 TII, SG->getSGID(), true));

1664 SG->addRule(std::make_shared(TII, SG->getSGID()));

1665 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

1666 }

1667

1668

1669 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

1670 SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);

1671 if (!IsPostRA && MFMAChains)

1672 SG->addRule(std::make_shared(

1673 PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(),

1674 true));

1675 else

1676 SG->addRule(std::make_shared(MFMAEnablement + 1, TII,

1677 SG->getSGID(), true));

1678 SG->addRule(std::make_shared(TII, SG->getSGID(), true));

1679 SG->addRule(std::make_shared(8, TII, SG->getSGID(),

1680 HasChainBetweenCvt));

1681 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

1682 }

1683

1684

1685

1686 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

1687 SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);

1688 SG->addRule(std::make_shared(TII, SG->getSGID(), true));

1689 SG->addRule(std::make_shared(

1690 8, TII, SG->getSGID(), HasChainBetweenCvt));

1691 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

1692

1693

1694

1695

1696 unsigned MFMARatio =

1697 MFMAEnablement > ExpRequirement ? MFMAEnablement / ExpRequirement : 1;

1698

1699 unsigned ExpRatio =

1700 MFMAEnablement > ExpRequirement ? 1 : ExpRequirement / MFMAEnablement;

1701

1702 unsigned RemainingExp = TransPipeCount > (2 * ExpRequirement)

1703 ? TransPipeCount - (2 * ExpRequirement)

1704 : 0;

1705 unsigned ExpLoopCount = RemainingExp / ExpRatio;

1706

1707 unsigned MFMAInLoop = MFMAPipeCount > (MFMAEnablement * 2)

1708 ? MFMAPipeCount - (MFMAEnablement * 2)

1709 : 0;

1710 unsigned MFMALoopCount = MFMAInLoop / MFMARatio;

1711 unsigned VALUOps =

1712 AddPipeCount < MFMAPipeCount ? 1 : AddPipeCount / MFMAPipeCount;

1713 unsigned LoopSize = std::min(ExpLoopCount, MFMALoopCount);

1714

1715 for (unsigned I = 0; I < LoopSize; I++) {

1716 if (!(I * ExpRatio % ExpRequirement))

1717 incrementTransPosition();

1718

1719

1720 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

1721 SchedGroupMask::MFMA, MFMARatio, PipelineSyncID, DAG, TII);

1722 if (!IsPostRA && MFMAChains)

1723 SG->addRule(std::make_shared(

1724 PositionInChainForMFMA, MFMAChainSeeds[MFMAChainForMFMA], TII,

1725 SG->getSGID(), true));

1726 else

1727 SG->addRule(std::make_shared(TII, SG->getSGID(), true));

1728 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

1729 incrementMFMAPosition();

1730

1731 if (UsesVALU) {

1732 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

1733 SchedGroupMask::VALU, VALUOps, PipelineSyncID, DAG, TII);

1734 SG->addRule(std::make_shared(TII, SG->getSGID()));

1735 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

1736 }

1737

1738 if (UsesDSRead && !(I % 4)) {

1739 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

1740 SchedGroupMask::DS_READ, 2, PipelineSyncID, DAG, TII);

1741 SG->addRule(std::make_shared(*FirstPipeDSR, TII,

1742 SG->getSGID()));

1743 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

1744 }

1745

1746

1747 for (unsigned J = 0; J < ExpRatio; J++) {

1748 auto MFMAOffset = (1 + UsesVALU) * MFMARatio * (I + 1);

1749 auto MaxMFMAOffset =

1750 (1 + UsesVALU) * ExpRequirement * MFMARatio / ExpRatio;

1751

1752

1753 if (UsesCvt) {

1754 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

1755 SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);

1756 SG->addRule(std::make_shared(TII, SG->getSGID()));

1757 auto BaseDiff = (2 + UsesFMA) * (ExpRequirement - 1) + 1;

1758 auto DSROffset = I / 4 + 1;

1759 auto MaxDSROffset = MaxMFMAOffset / 4;

1760

1761 auto ExpOffset = I * ExpRatio + J >= ExpRequirement ? 0 : 1;

1762 auto CurrentOffset = UsesDSRead * std::min(MaxDSROffset, DSROffset) +

1763 std::min(MaxMFMAOffset, MFMAOffset) + BaseDiff +

1764 ExpOffset;

1765 if (HasChainBetweenCvt)

1766 SG->addRule(std::make_shared(

1767 CurrentOffset, TII, SG->getSGID()));

1768 else

1769 SG->addRule(std::make_shared(CurrentOffset, TII,

1770 SG->getSGID()));

1771 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

1772 }

1773

1774

1775 if (UsesFMA) {

1776 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

1777 SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);

1778 if (!IsPostRA && MFMAChains)

1779 SG->addRule(std::make_shared(

1780 getNextTransPositionInChain(),

1781 MFMAChainSeeds[getNextTransMFMAChain()], TII, SG->getSGID(),

1782 true));

1783 else

1784 SG->addRule(std::make_shared(

1785 (((I * ExpRatio + J) / ExpRequirement) + 3) * MFMAEnablement + 1,

1786 TII, SG->getSGID(), true));

1787 SG->addRule(std::make_shared(TII, SG->getSGID()));

1788 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

1789 }

1790

1791

1792 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

1793 SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);

1794 if (!IsPostRA && MFMAChains)

1795 SG->addRule(std::make_shared(

1796 PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(),

1797 true));

1798 else

1799 SG->addRule(std::make_shared(

1800 (((I * ExpRatio + J) / ExpRequirement) + 2) * MFMAEnablement + 1,

1801 TII, SG->getSGID(), true));

1802 SG->addRule(std::make_shared(TII, SG->getSGID(), true));

1803 SG->addRule(std::make_shared(8, TII, SG->getSGID(),

1804 HasChainBetweenCvt));

1805 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

1806 }

1807 }

1808

1809

1810 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

1811 SchedGroupMask::MFMA, MFMAEnablement * 2, PipelineSyncID, DAG, TII);

1812 SG->addRule(std::make_shared(TII, SG->getSGID(), true));

1813 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

1814 return true;

1815}

1816
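// Simple variant of the exp/MFMA interleave: alternates SchedGroups of one
// TRANS op and one MFMA, without the DAG analysis used above.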

1817class MFMAExpSimpleInterleaveOpt final : public IGLPStrategy {

1818public:

1819 bool applyIGLPStrategy(

1823

1826 return true;

1827 }

1828

  MFMAExpSimpleInterleaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
      : IGLPStrategy(DAG, TII) {

1831 IsBottomUp = true;

1832 }

1833};

1834

1835bool MFMAExpSimpleInterleaveOpt::applyIGLPStrategy(

1839

1840 unsigned MFMACount = 0;

1842 if (TII->isMFMAorWMMA(I))

1843 ++MFMACount;

1844

1845 const unsigned PipelineSyncID = 0;

1846 for (unsigned I = 0; I < MFMACount * 3; ++I) {

1847 SchedGroup *SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

1848 SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);

1849 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

1850

1851 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

1852 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);

1853 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

1854 }

1855

1856 return true;

1857}

1858
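// Small-GEMM single-wave strategy: interleaves DS reads/writes, the V_PERMs
// that feed DS writes, VMEM reads, and MFMAs, based on how many DS writes
// share VMEM sources (computed in applyIGLPStrategy below).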

1859class MFMASmallGemmSingleWaveOpt final : public IGLPStrategy {

1860private:

1861

1862 class EnablesInitialMFMA final : public InstructionRule {

1863 public:

1866 if (!SyncPipe.size())

1867 return false;

1868 int MFMAsFound = 0;

1869 if (!Cache->size()) {

1870 for (auto &Elt : SyncPipe[0].DAG->SUnits) {

1871 if (TII->isMFMAorWMMA(*Elt.getInstr())) {

1872 ++MFMAsFound;

1873 if (MFMAsFound > 4)

1874 break;

1875 Cache->push_back(&Elt);

1876 }

1877 }

1878 }

1879

1880 auto *DAG = SyncPipe[0].DAG;

1881 for (auto &Elt : *Cache) {

1883 return true;

1884 }

1885 return false;

1886 }

1887

1888 EnablesInitialMFMA(const SIInstrInfo *TII, unsigned SGID,

1889 bool NeedsCache = false)

1890 : InstructionRule(TII, SGID, NeedsCache) {}

1891 };

1892

1893

1894 class IsPermForDSW final : public InstructionRule {

1895 public:

1899 if (MI->getOpcode() != AMDGPU::V_PERM_B32_e64)

1900 return false;

1901

1902 bool FitsInGroup = false;

1903

1904 if (!Collection.size()) {

1905 for (auto &Succ : SU->Succs) {

1906 SUnit *SuccUnit = Succ.getSUnit();

1909 Cache->push_back(SuccUnit);

1910 FitsInGroup = true;

1911 }

1912 }

1913 return FitsInGroup;

1914 }

1915

1916

1917

1920 return ThisSucc.getSUnit() == Elt;

1921 });

1922 });

1923 }

1924

1925 IsPermForDSW(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)

1926 : InstructionRule(TII, SGID, NeedsCache) {}

1927 };

1928

1929

1930 class IsSuccOfPrevGroup final : public InstructionRule {

1931 public:

1934 SchedGroup *OtherGroup = nullptr;

1935 for (auto &PipeSG : SyncPipe) {

1936 if ((unsigned)PipeSG.getSGID() == SGID - 1) {

1937 OtherGroup = &PipeSG;

1938 }

1939 }

1940

1941 if (!OtherGroup)

1942 return false;

1943 if (!OtherGroup->Collection.size())

1944 return true;

1945

1946

1947 return any_of(OtherGroup->Collection, [&SU](SUnit *Elt) {

1948 return any_of(Elt->Succs,

1949 [&SU](SDep &Succ) { return Succ.getSUnit() == SU; });

1950 });

1951 }

1952 IsSuccOfPrevGroup(const SIInstrInfo *TII, unsigned SGID,

1953 bool NeedsCache = false)

1954 : InstructionRule(TII, SGID, NeedsCache) {}

1955 };

1956

1957

1958 class VMEMSize final : public InstructionRule {

1959 public:

1963 if (MI->getOpcode() == TargetOpcode::BUNDLE)

1964 return false;

1965 if (!Collection.size())

1966 return true;

1967

1968 int NumBits = 0;

1969

1970 auto TRI = TII->getRegisterInfo();

1971 auto &MRI = MI->getMF()->getRegInfo();

1972 for (auto &Elt : Collection) {

1973 auto Op = Elt->getInstr()->getOperand(0);

1975 TRI.getRegSizeInBits(*TRI.getRegClassForOperandReg(MRI, Op));

1976 NumBits += Size;

1977 }

1978

1979 if (NumBits < 128) {

1981 if (NumBits + TRI.getRegSizeInBits(*TRI.getRegClassForOperandReg(

1982 MRI, MI->getOperand(0))) <=

1983 128)

1984 return true;

1985 }

1986

1987 return false;

1988 }

1989

1990 VMEMSize(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)

1991 : InstructionRule(TII, SGID, NeedsCache) {}

1992 };

1993

1994

1995

1996 class SharesPredWithPrevNthGroup final : public InstructionRule {

1997 private:

1998 unsigned Distance = 1;

1999

2000 public:

2003 SchedGroup *OtherGroup = nullptr;

2004 if (!SyncPipe.size())

2005 return false;

2006

2007 if (!Cache->size()) {

2008

2009 for (auto &PipeSG : SyncPipe) {

2010 if ((unsigned)PipeSG.getSGID() == SGID - Distance) {

2011 OtherGroup = &PipeSG;

2012 }

2013 }

2014

2015 if (!OtherGroup)

2016 return false;

2017 if (!OtherGroup->Collection.size())

2018 return true;

2019

2020 for (auto &OtherEle : OtherGroup->Collection) {

2021 for (auto &Pred : OtherEle->Preds) {

2022 if (Pred.getSUnit()->getInstr()->getOpcode() ==

2023 AMDGPU::V_PERM_B32_e64)

2024 Cache->push_back(Pred.getSUnit());

2025 }

2026 }

2027

2028

2029 if (!Cache->size())

2030 return false;

2031 }

2032

2033 auto *DAG = SyncPipe[0].DAG;

2034

2035

2038 });

2039 }

2040 SharesPredWithPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,

2041 unsigned SGID, bool NeedsCache = false)

2042 : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}

2043 };

2044

2045public:

2046 bool applyIGLPStrategy(

2050

2053 return true;

2054 }

2055

  MFMASmallGemmSingleWaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
      : IGLPStrategy(DAG, TII) {

2058 IsBottomUp = false;

2059 }

2060};

2061

2062static unsigned DSWCount = 0;

2063static unsigned DSWWithPermCount = 0;

2064static unsigned DSWWithSharedVMEMCount = 0;

2065

2066bool MFMASmallGemmSingleWaveOpt::applyIGLPStrategy(

2067 DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,

2070 unsigned MFMACount = 0;

2071 unsigned DSRCount = 0;

2072

2073 bool IsInitial = Phase == AMDGPU::SchedulingPhase::Initial;

2074

2075 assert((!IsInitial || (DSWCount == 0 && DSWWithPermCount == 0 &&

2076 DSWWithSharedVMEMCount == 0)) &&

2077 "DSWCounters should be zero in pre-RA scheduling!");

2079 for (auto &SU : DAG->SUnits) {

2080 auto *I = SU.getInstr();

2081 if (TII->isMFMAorWMMA(*I))

2082 ++MFMACount;

2083 else if (TII->isDS(*I)) {

2084 if (I->mayLoad())

2085 ++DSRCount;

2086 else if (I->mayStore() && IsInitial) {

2087 ++DSWCount;

2088 for (auto Pred : SU.Preds) {

2089 if (Pred.getSUnit()->getInstr()->getOpcode() ==

2090 AMDGPU::V_PERM_B32_e64) {

2092 break;

2093 }

2094 }

2095 }

2096 }

2097 }

2098

2099 if (IsInitial) {

2100 DSWWithPermCount = DSWithPerms.size();

2101 auto *I = DSWithPerms.begin();

2102 auto *E = DSWithPerms.end();

2103

2104

2105

2106

2107

2108

2109

2110 DenseMap<MachineInstr *, SUnit *> VMEMLookup;

2112 for (; I != E; I++) {

2113 SUnit *Cand = nullptr;

2114 bool MissedAny = false;

2115 for (auto &Pred : (*I)->Preds) {

2116 if (Pred.getSUnit()->getInstr()->getOpcode() != AMDGPU::V_PERM_B32_e64)

2117 continue;

2118

2120 break;

2121

2122 for (auto &Succ : Pred.getSUnit()->Succs) {

2123 auto *MI = Succ.getSUnit()->getInstr();

2124 if (TII->isVMEM(*MI) || MI->mayLoad())

2125 continue;

2126

2127 if (MissedAny || !VMEMLookup.size()) {

2128 MissedAny = true;

2129 VMEMLookup[MI] = *I;

2130 continue;

2131 }

2132

2134 if (Inserted) {

2135 MissedAny = true;

2136 continue;

2137 }

2138

2139 Cand = It->second;

2141 MissedAny = true;

2142 break;

2143 }

2144 }

2145 }

2146 if (!MissedAny && Cand) {

2147 DSWWithSharedVMEMCount += 2;

2150 }

2151 }

2152 }

2153

2154 assert(DSWWithSharedVMEMCount <= DSWWithPermCount);

2155 SchedGroup *SG;

2156 unsigned PipelineSyncID = 0;

2157

2158 if (DSWWithPermCount) {

2159 for (unsigned I = 0; I < MFMACount; I++) {

2160 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2161 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);

2162 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2163

2164 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2165 SchedGroupMask::VALU, 2, PipelineSyncID, DAG, TII);

2166 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2167 }

2168 }

2169

2170 PipelineSyncID = 1;

2171

2172

2173

2174

2175

2176 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2177 SchedGroupMask::DS_READ, 4, PipelineSyncID, DAG, TII);

2178 SG->addRule(std::make_shared(TII, SG->getSGID(), true));

2179 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2180

2181 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2182 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);

2183 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2184

2185

2186 for (unsigned I = 4; I < DSRCount; ++I) {

2187 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2188 SchedGroupMask::DS_READ, 1, PipelineSyncID, DAG, TII);

2189 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2190

2191 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2192 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);

2193 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2194 }

2195

2196

2197

2198

2199 for (unsigned I = DSWWithSharedVMEMCount; I < DSWWithPermCount; ++I) {

2200 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2201 SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);

2202 SG->addRule(std::make_shared(TII, SG->getSGID(), true));

2203 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2204

2205 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2206 SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);

2207 SG->addRule(std::make_shared(TII, SG->getSGID()));

2208 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2209

2210 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2211 SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);

2212 SG->addRule(std::make_shared(

2213 1, TII, SG->getSGID(), true));

2214 SG->addRule(std::make_shared(TII, SG->getSGID()));

2215 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2216

2217 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2218 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);

2219 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2220

2221 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2222 SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);

2223 SG->addRule(std::make_shared(

2224 3, TII, SG->getSGID(), true));

2225 SG->addRule(std::make_shared(TII, SG->getSGID()));

2226 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2227

2228 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2229 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);

2230 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2231 }

2232

2233

2234

2235

2236 for (unsigned I = DSWWithPermCount; I < DSWCount; I++) {

2237 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2238 SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);

2239 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2240

2241 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2242 SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);

2243 SG->addRule(std::make_shared(TII, SG->getSGID()));

2244 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2245

2246 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2247 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);

2248 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2249 }

2250

2251

2252

2253

2254

2255

2256 for (unsigned I = 0; I < DSWWithSharedVMEMCount; ++I) {

2257 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2258 SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);

2259 SG->addRule(std::make_shared(TII, SG->getSGID(), true));

2260 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2261

2262 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2263 SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);

2264 SG->addRule(std::make_shared(TII, SG->getSGID()));

2265 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2266

2267 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2268 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);

2269 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2270

2271 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2272 SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);

2273 SG->addRule(std::make_shared(TII, SG->getSGID(), true));

2274 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2275

2276 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2277 SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);

2278 SG->addRule(std::make_shared(TII, SG->getSGID()));

2279 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2280

2281 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2282 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);

2283 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2284

2285 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2286 SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);

2287 SG->addRule(std::make_shared(

2288 2, TII, SG->getSGID(), true));

2289 SG->addRule(std::make_shared(TII, SG->getSGID()));

2290 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2291

2292 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2293 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);

2294 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2295

2296 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2297 SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);

2298 SG->addRule(std::make_shared(

2299 4, TII, SG->getSGID(), true));

2300 SG->addRule(std::make_shared(TII, SG->getSGID()));

2301 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2302

2303 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(

2304 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);

2305 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

2306 }

2307

2308 return true;

2309}

2310

2311 static std::unique_ptr<IGLPStrategy>

2312createIGLPStrategy(IGLPStrategyID ID, ScheduleDAGInstrs *DAG,

2313 const SIInstrInfo *TII) {

2314 switch (ID) {

2315 case MFMASmallGemmOptID:

2316 return std::make_unique<MFMASmallGemmOpt>(DAG, TII);

2317 case MFMASmallGemmSingleWaveOptID:

2318 return std::make_unique<MFMASmallGemmSingleWaveOpt>(DAG, TII);

2319 case MFMAExpInterleaveID:

2320 return std::make_unique<MFMAExpInterleaveOpt>(DAG, TII);

2321 case MFMAExpSimpleInterleaveID:

2322 return std::make_unique<MFMAExpSimpleInterleaveOpt>(DAG, TII);

2323 }

2324

2325 llvm_unreachable("Unknown IGLPStrategyID");

2326}

2327

2328class IGroupLPDAGMutation : public ScheduleDAGMutation {

2329private:

2330 const SIInstrInfo *TII;

2331

2332 ScheduleDAGMI *DAG;

2333

2334

2335

2336
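// The SchedGroups created for the pipeline, keyed by the sync ID of the
// SCHED_GROUP_BARRIER / IGLP_OPT stage they belong to.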

2337 DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;

2338

2339
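// For each sync ID, the candidate SchedGroups (by SGID) that each SUnit may
// be assigned to; the PipelineSolver picks the final assignment.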

2340 DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;

2341

2342

2343 void addSchedBarrierEdges(SUnit &SU);

2344

2345

2346

2347

2348

2349

2350

2351

2352

2353
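// Given the mask of instruction types that a SCHED_BARRIER allows to cross
// it, compute the complementary mask of instructions that must be ordered
// against the barrier (used by addSchedBarrierEdges).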

2354 SchedGroupMask invertSchedBarrierMask(SchedGroupMask Mask) const;

2355

2356

2357 void initSchedGroupBarrierPipelineStage(

2358 std::vector<SUnit>::reverse_iterator RIter);

2359

2360 bool initIGLPOpt(SUnit &SU);

2361

2362public:

2363 void apply(ScheduleDAGInstrs *DAGInstrs) override;

2364

2365

2366

2367

2368

2369
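// Whether the PipelineSolver fills the SchedGroups bottom-up (starting from
// the last group) or top-down when assigning candidate SUnits.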

2370 bool IsBottomUp = true;

2371

2372

2373 AMDGPU::SchedulingPhase Phase = AMDGPU::SchedulingPhase::Initial;

2374

2375 IGroupLPDAGMutation() = default;

2376 IGroupLPDAGMutation(AMDGPU::SchedulingPhase Phase) : Phase(Phase) {}

2377};

2378

2379unsigned SchedGroup::NumSchedGroups = 0;

2380

2381 bool SchedGroup::tryAddEdge(SUnit *A, SUnit *B) {

2382 if (A != B && DAG->canAddEdge(B, A)) {

2383 DAG->addEdge(B, SDep(A, SDep::Artificial));

2384 return true;

2385 }

2386 return false;

2387}

2388

2389 bool SchedGroup::canAddMI(const MachineInstr &MI) const {

2390 bool Result = false;

2391 if (MI.isMetaInstruction())

2392 Result = false;

2393

2394 else if (((SGMask & SchedGroupMask::ALU) != SchedGroupMask::NONE) &&

2395 (TII->isVALU(MI) || TII->isMFMAorWMMA(MI) || TII->isSALU(MI) ||

2396 TII->isTRANS(MI)))

2397 Result = !MI.mayLoadOrStore();

2398

2399 else if (((SGMask & SchedGroupMask::VALU) != SchedGroupMask::NONE) &&

2400 TII->isVALU(MI) && !TII->isMFMAorWMMA(MI) && !TII->isTRANS(MI)) {

2401

2402

2403

2404 Result = !MI.mayLoadOrStore();

2405 }

2406

2407 else if (((SGMask & SchedGroupMask::SALU) != SchedGroupMask::NONE) &&

2408 TII->isSALU(MI))

2409 Result = !MI.mayLoadOrStore();

2410

2411 else if (((SGMask & SchedGroupMask::MFMA) != SchedGroupMask::NONE) &&

2412 TII->isMFMAorWMMA(MI))

2413 Result = true;

2414

2415 else if (((SGMask & SchedGroupMask::VMEM) != SchedGroupMask::NONE) &&

2416 TII->isVMEM(MI))

2417 Result = true;

2418

2419 else if (((SGMask & SchedGroupMask::VMEM_READ) != SchedGroupMask::NONE) &&

2420 MI.mayLoad() && TII->isVMEM(MI))

2421 Result = true;

2422

2423 else if (((SGMask & SchedGroupMask::VMEM_WRITE) != SchedGroupMask::NONE) &&

2424 MI.mayStore() && TII->isVMEM(MI))

2425 Result = true;

2426

2427 else if (((SGMask & SchedGroupMask::DS) != SchedGroupMask::NONE) &&

2428 TII->isDS(MI))

2429 Result = true;

2430

2431 else if (((SGMask & SchedGroupMask::DS_READ) != SchedGroupMask::NONE) &&

2432 MI.mayLoad() && TII->isDS(MI))

2433 Result = true;

2434

2435 else if (((SGMask & SchedGroupMask::DS_WRITE) != SchedGroupMask::NONE) &&

2436 MI.mayStore() && TII->isDS(MI))

2437 Result = true;

2438

2439 else if (((SGMask & SchedGroupMask::TRANS) != SchedGroupMask::NONE) &&

2440 TII->isTRANS(MI))

2441 Result = true;

2442

2443 LLVM_DEBUG(

2444 dbgs() << "For SchedGroup with mask " << format_hex((int)SGMask, 10, true)

2445 << (Result ? " could classify " : " unable to classify ") << MI);

2446

2447 return Result;

2448}

2449

2450int SchedGroup::link(SUnit &SU, bool MakePred,

2451 std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {

2452 int MissedEdges = 0;

2453 for (auto *A : Collection) {

2454 SUnit *B = &SU;

2455 if (A == B || A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)

2456 continue;

2457 if (MakePred)

2458 std::swap(A, B);

2459

2460 if (DAG->IsReachable(B, A))

2461 continue;

2462

2463

2464
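// tryAddEdge fails if the edge would create a cycle; count the misses so the
// solver can compare alternative pipeline assignments.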

2465 bool Added = tryAddEdge(A, B);

2466 if (Added)

2467 AddedEdges.emplace_back(A, B);

2468 else

2469 ++MissedEdges;

2470 }

2471

2472 return MissedEdges;

2473}

2474

2475void SchedGroup::link(SUnit &SU, bool MakePred) {

2476 for (auto *A : Collection) {

2477 SUnit *B = &SU;

2478 if (A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)

2479 continue;

2480 if (MakePred)

2481 std::swap(A, B);

2482

2483 tryAddEdge(A, B);

2484 }

2485}

2486

2487void SchedGroup::link(SUnit &SU,

2488 function_ref<bool(const SUnit *A, const SUnit *B)> P) {

2489 for (auto *A : Collection) {

2490 SUnit *B = &SU;

2491 if (P(A, B))

2492 std::swap(A, B);

2493

2494 tryAddEdge(A, B);

2495 }

2496}

2497

2498void SchedGroup::link(SchedGroup &OtherGroup) {

2499 for (auto *B : OtherGroup.Collection)

2500 link(*B);

2501}

2502

2503 bool SchedGroup::canAddSU(SUnit &SU) const {

2504 MachineInstr &MI = *SU.getInstr();

2505 if (MI.getOpcode() != TargetOpcode::BUNDLE)

2506 return canAddMI(MI);

2507

2508

2509 const MachineBasicBlock *MBB = MI.getParent();

2510 MachineBasicBlock::instr_iterator B = MI.getIterator(), E = ++MI.getIterator();

2511 while (E != MBB->end() && E->isBundledWithPred())

2512 ++E;

2513

2514

2515 return std::all_of(B, E, [this](MachineInstr &MI) { return canAddMI(MI); });

2516}

2517

2518void SchedGroup::initSchedGroup() {

2519 for (auto &SU : DAG->SUnits) {

2520 if (isFull())

2521 break;

2522

2523 if (canAddSU(SU))

2524 add(SU);

2525 }

2526}

2527

2528 void SchedGroup::initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,

2529 SUnitsToCandidateSGsMap &SyncedInstrs) {

2530 SUnit &InitSU = *RIter;

2531 for (auto E = DAG->SUnits.rend(); RIter != E; ++RIter) {

2532 auto &SU = *RIter;

2533 if (isFull())

2534 break;

2535

2536 if (canAddSU(SU))

2537 SyncedInstrs[&SU].push_back(SGID);

2538 }

2539

2540 add(InitSU);

2541 assert(MaxSize);

2542 (*MaxSize)++;

2543}

2544

2545void SchedGroup::initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs) {

2546 auto I = DAG->SUnits.rbegin();

2547 auto E = DAG->SUnits.rend();

2548 for (; I != E; ++I) {

2549 auto &SU = *I;

2550 if (isFull())

2551 break;

2552 if (canAddSU(SU))

2553 SyncedInstrs[&SU].push_back(SGID);

2554 }

2555}

2556

2557void IGroupLPDAGMutation::apply(ScheduleDAGInstrs *DAGInstrs) {

2558 const TargetSchedModel *TSchedModel = DAGInstrs->getSchedModel();

2559 if (!TSchedModel || DAGInstrs->SUnits.empty())

2560 return;

2561

2562 LLVM_DEBUG(dbgs() << "Applying IGroupLPDAGMutation...\n");

2563 const GCNSubtarget &ST = DAGInstrs->MF.getSubtarget<GCNSubtarget>();

2564 TII = ST.getInstrInfo();

2565 DAG = static_cast<ScheduleDAGMI *>(DAGInstrs);

2566 SyncedSchedGroups.clear();

2567 SyncedInstrs.clear();

2568 bool FoundSB = false;

2569 bool FoundIGLP = false;

2570 bool ShouldApplyIGLP = false;

2571 for (auto R = DAG->SUnits.rbegin(), E = DAG->SUnits.rend(); R != E; ++R) {

2572 unsigned Opc = R->getInstr()->getOpcode();

2573

2574 if (Opc == AMDGPU::SCHED_BARRIER) {

2575 addSchedBarrierEdges(*R);

2576 FoundSB = true;

2577 } else if (Opc == AMDGPU::SCHED_GROUP_BARRIER) {

2578 initSchedGroupBarrierPipelineStage(R);

2579 FoundSB = true;

2580 } else if (Opc == AMDGPU::IGLP_OPT) {

2581 if (!FoundSB && !FoundIGLP) {

2582 FoundIGLP = true;

2583 ShouldApplyIGLP = initIGLPOpt(*R);

2584 }

2585 }

2586 }

2587

2588 if (FoundSB || (FoundIGLP && ShouldApplyIGLP)) {

2589 PipelineSolver PS(SyncedSchedGroups, SyncedInstrs, DAG, IsBottomUp);

2590

2591
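// The PipelineSolver performs the mutation, adding the artificial edges
// implied by the SchedGroups collected above.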

2592 PS.solve();

2593 return;

2594 }

2595}

2596

2597void IGroupLPDAGMutation::addSchedBarrierEdges(SUnit &SchedBarrier) {

2598 MachineInstr &MI = *SchedBarrier.getInstr();

2599 assert(MI.getOpcode() == AMDGPU::SCHED_BARRIER);

2600

2601

2602 LLVM_DEBUG(dbgs() << "Building SchedGroup for SchedBarrier with Mask: "

2603 << MI.getOperand(0).getImm() << "\n");

2604 auto InvertedMask =

2605 invertSchedBarrierMask((SchedGroupMask)MI.getOperand(0).getImm());

2606 SchedGroup SG(InvertedMask, std::nullopt, DAG, TII);

2607 SG.initSchedGroup();

2608

2609
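// Link the barrier against the blocked instructions, preserving their
// original program order relative to the SCHED_BARRIER (by NodeNum).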

2610 SG.link(

2611 SchedBarrier,

2612 (function_ref<bool(const SUnit *A, const SUnit *B)>)[](

2613 const SUnit *A, const SUnit *B) { return A->NodeNum > B->NodeNum; });

2614}

2615

2616SchedGroupMask

2617IGroupLPDAGMutation::invertSchedBarrierMask(SchedGroupMask Mask) const {

2618

2619
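// The SCHED_BARRIER mask names what may be scheduled across the barrier, so
// its complement is the set to block. Keep the umbrella bits (ALU, VMEM, DS)
// consistent with their more specific sub-type bits below.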

2620 SchedGroupMask InvertedMask = ~Mask;

2621

2622

2623 if ((InvertedMask & SchedGroupMask::ALU) == SchedGroupMask::NONE)

2624 InvertedMask &= ~SchedGroupMask::VALU & ~SchedGroupMask::SALU &

2625 ~SchedGroupMask::MFMA & ~SchedGroupMask::TRANS;

2626

2627 else if ((InvertedMask & SchedGroupMask::VALU) == SchedGroupMask::NONE ||

2628 (InvertedMask & SchedGroupMask::SALU) == SchedGroupMask::NONE ||

2629 (InvertedMask & SchedGroupMask::MFMA) == SchedGroupMask::NONE ||

2630 (InvertedMask & SchedGroupMask::TRANS) == SchedGroupMask::NONE)

2631 InvertedMask &= ~SchedGroupMask::ALU;

2632

2633

2634 if ((InvertedMask & SchedGroupMask::VMEM) == SchedGroupMask::NONE)

2635 InvertedMask &= ~SchedGroupMask::VMEM_READ & ~SchedGroupMask::VMEM_WRITE;

2636

2637 else if ((InvertedMask & SchedGroupMask::VMEM_READ) == SchedGroupMask::NONE ||

2638 (InvertedMask & SchedGroupMask::VMEM_WRITE) == SchedGroupMask::NONE)

2639 InvertedMask &= ~SchedGroupMask::VMEM;

2640

2641

2642 if ((InvertedMask & SchedGroupMask::DS) == SchedGroupMask::NONE)

2643 InvertedMask &= ~SchedGroupMask::DS_READ & ~SchedGroupMask::DS_WRITE;

2644

2645 else if ((InvertedMask & SchedGroupMask::DS_READ) == SchedGroupMask::NONE ||

2646 (InvertedMask & SchedGroupMask::DS_WRITE) == SchedGroupMask::NONE)

2647 InvertedMask &= ~SchedGroupMask::DS;

2648

2649 LLVM_DEBUG(dbgs() << "After Inverting, SchedGroup Mask: " << (int)InvertedMask

2650 << "\n");

2651

2652 return InvertedMask;

2653}

2654

2655void IGroupLPDAGMutation::initSchedGroupBarrierPipelineStage(

2656 std::vector<SUnit>::reverse_iterator RIter) {

2657

2658
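// A SCHED_GROUP_BARRIER carries three immediate operands: the SchedGroup
// mask, the group size, and the sync ID of the pipeline it belongs to.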

2659 MachineInstr &SGB = *RIter->getInstr();

2660 assert(SGB.getOpcode() == AMDGPU::SCHED_GROUP_BARRIER);

2661 int32_t SGMask = SGB.getOperand(0).getImm();

2662 int32_t Size = SGB.getOperand(1).getImm();

2663 int32_t SyncID = SGB.getOperand(2).getImm();

2664

2665 auto &SG = SyncedSchedGroups[SyncID].emplace_back((SchedGroupMask)SGMask,

2666 Size, SyncID, DAG, TII);

2667

2668 SG.initSchedGroup(RIter, SyncedInstrs[SG.getSyncID()]);

2669}

2670

2671bool IGroupLPDAGMutation::initIGLPOpt(SUnit &SU) {

2672 IGLPStrategyID StrategyID =

2673 (IGLPStrategyID)SU.getInstr()->getOperand(0).getImm();

2674 auto S = createIGLPStrategy(StrategyID, DAG, TII);

2675 if (!S->shouldApplyStrategy(DAG, Phase))

2676 return false;

2677

2678 IsBottomUp = S->IsBottomUp;

2679 return S->applyIGLPStrategy(SyncedInstrs, SyncedSchedGroups, Phase);

2680}

2681

2682}

2683

2684 /// \p Phase specifies whether or not this is a reentry into the

2685 /// IGroupLPDAGMutation. Since there may be multiple scheduling passes on the

2686 /// same scheduling region (e.g. pre and post-RA scheduling / multiple

2687 /// scheduling "phases"), we can reenter this mutation framework more than once

2688 /// for a given region.

2689 std::unique_ptr<ScheduleDAGMutation>

2690 llvm::createIGroupLPDAGMutation(AMDGPU::SchedulingPhase Phase) {

2691 return std::make_unique<IGroupLPDAGMutation>(Phase);

2692}
