LLVM: lib/Target/X86/X86LowerAMXType.cpp Source File

//===- Target/X86/X86LowerAMXType.cpp - ------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcasted to x86_amx on X86, and AMX instruction set only
/// provides simple operation on x86_amx. The basic elementwise operation
/// is not supported by AMX. Since x86_amx is bitcasted from vector <256 x i32>
/// and only AMX intrinsics can operate on the type, we need to transform
/// load/store <256 x i32> instructions to AMX load/store. If the bitcast can
/// not be combined with load/store, we transform the bitcast to amx load/store
/// and <256 x i32> store/load.
///
/// If the front end does not use O0 but the mid/back end does (e.g.
/// "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the amx
/// data is volatile; that is necessary for AMX fast register allocation.
//
//===----------------------------------------------------------------------===//
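//
// Illustrative example of the core rewrite (condensed, not verbatim from the
// test suite): a vector load feeding a bitcast to x86_amx
//   %src = load <256 x i32>, ptr %addr, align 64
//   %2 = bitcast <256 x i32> %src to x86_amx
// becomes a single AMX tile load:
//   %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                    ptr %addr, i64 64)
//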

#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/Local.h"
#include <map>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "x86-lower-amx-type"

static bool isAMXCast(Instruction *II) {
  return match(II,
               m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
         match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
}

static bool isAMXIntrinsic(Value *I) {
  auto *II = dyn_cast<IntrinsicInst>(I);
  if (!II)
    return false;
  if (isAMXCast(II))
    return false;
  // Check if return type or parameter is x86_amx. If it is x86_amx
  // the intrinsic must be x86 amx intrinsics.
  if (II->getType()->isX86_AMXTy())
    return true;
  for (Value *V : II->args()) {
    if (V->getType()->isX86_AMXTy())
      return true;
  }

  return false;
}

static bool containsAMXCode(Function &F) {
  for (BasicBlock &BB : F)
    for (Instruction &I : BB)
      if (I.getType()->isX86_AMXTy())
        return true;
  return false;
}
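// Create a 64-byte-aligned alloca in the function entry block; tile data
// spilled by this pass always round-trips through such a stack buffer.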

static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
                                           Type *Ty) {
  Function &F = *BB->getParent();
  const DataLayout &DL = F.getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(Ty, AllocaAS, "", F.getEntryBlock().begin());
  AllocaRes->setAlignment(Align(64));
  return AllocaRes;
}
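// AMX shape operands carry the column in bytes; when a row operand has to be
// derived from a column value, the column is divided by the element
// granularity (e.g. 4 for 32-bit elements) in getRowFromCol below.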

static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
  for (Instruction &I : F.getEntryBlock())
    if (!isa<AllocaInst>(&I))
      return &I;
  llvm_unreachable("No terminator in the entry block!");
}

static Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity) {
  IRBuilder<> Builder(II);
  Value *RealRow = nullptr;
  if (isa<ConstantInt>(V))
    RealRow =
        Builder.getInt16((cast<ConstantInt>(V)->getSExtValue()) / Granularity);
  else if (isa<Instruction>(V)) {
    // When it is not a const value and it is an instruction, we create the
    // row right after the definition of V instead of right before II.
    // Otherwise the new udiv could end up defined after one of its users
    // (e.g. a tileload created for another operand of II that also consumes
    // the row).
    Builder.SetInsertPoint(cast<Instruction>(V));
    RealRow = Builder.CreateUDiv(V, Builder.getInt16(4));
    cast<Instruction>(RealRow)->moveAfter(cast<Instruction>(V));
  } else {
    // When it is not a const value and it is not an instruction (e.g. a
    // function argument), we create the row at the entry block.
    IRBuilder<> NewBuilder(
        getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
    RealRow = NewBuilder.CreateUDiv(V, NewBuilder.getInt16(Granularity));
  }
  return RealRow;
}
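// Deduce the (row, col) shape for operand OpNo of an AMX intrinsic. Tile
// loads/stores carry the shape in their first two operands; for dot-product
// style intrinsics (a * b + c) the shape depends on which tile operand is
// queried, since the accumulator and the two multiplicands differ in shape.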

std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tileloaddt164_internal:
  case Intrinsic::x86_tilestored64_internal:
  case Intrinsic::x86_t2rpntlvwz0rs_internal:
  case Intrinsic::x86_t2rpntlvwz0rst1_internal:
  case Intrinsic::x86_t2rpntlvwz1rs_internal:
  case Intrinsic::x86_t2rpntlvwz1rst1_internal:
  case Intrinsic::x86_tileloaddrs64_internal:
  case Intrinsic::x86_tileloaddrst164_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c
  // The shape depends on which operand.
  case Intrinsic::x86_tcmmimfp16ps_internal:
  case Intrinsic::x86_tcmmrlfp16ps_internal:
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal:
  case Intrinsic::x86_tdpfp16ps_internal:
  case Intrinsic::x86_tmmultf32ps_internal:
  case Intrinsic::x86_tdpbf8ps_internal:
  case Intrinsic::x86_tdpbhf8ps_internal:
  case Intrinsic::x86_tdphbf8ps_internal:
  case Intrinsic::x86_tdphf8ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      Row = getRowFromCol(II, II->getArgOperand(2), 4);
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  case Intrinsic::x86_tcvtrowd2ps_internal:
  case Intrinsic::x86_tcvtrowps2bf16h_internal:
  case Intrinsic::x86_tcvtrowps2bf16l_internal:
  case Intrinsic::x86_tcvtrowps2phh_internal:
  case Intrinsic::x86_tcvtrowps2phl_internal:
  case Intrinsic::x86_tilemovrow_internal: {
    assert(OpNo == 2 && "Illegal Operand Number.");
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  }

  return std::make_pair(Row, Col);
}
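// Shape lookup for a PHI of tile values: follow the first user through PHIs
// and AMX casts until an AMX intrinsic with a known shape is reached.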

static std::pair<Value *, Value *> getShape(PHINode *Phi) {
  Use &U = *(Phi->use_begin());
  unsigned OpNo = U.getOperandNo();
  User *V = U.getUser();
  // TODO We don't traverse all users. To make the algorithm simple, here we
  // just traverse the first user. If we can find shape, then return the shape,
  // otherwise just return nullptr and the optimization for undef/zero will be
  // skipped.
  while (V) {
    if (isa<PHINode>(V)) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      OpNo = U.getOperandNo();
      V = U.getUser();
    } else if (isAMXIntrinsic(V)) {
      return getShape(cast<IntrinsicInst>(V), OpNo);
    } else if (isa<Instruction>(V) && isAMXCast(cast<Instruction>(V))) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      V = U.getUser();
    } else {
      break;
    }
  }

  return std::make_pair(nullptr, nullptr);
}

namespace {
class X86LowerAMXType {
  Function &Func;

  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol as a new Row for
  // other new created AMX intrinsics.
  std::map<Value *, Value *> Col2Row;

public:
  X86LowerAMXType(Function &F) : Func(F) {}
  bool visit();
  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
  bool transformBitcast(BitCastInst *Bitcast);
};

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                  i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column as stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr = LD->getOperand(0);
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {

  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is output from AMX intrinsic. The first operand of the
  // intrinsic is row, the second operand of the intrinsic is column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same with load
  // stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr = ST->getOperand(1);
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// transform bitcast to <store, load> instructions.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
    I8Ptr = AllocaAddr;
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2, i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(Bitcast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst =
        Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(Bitcast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}
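// Walk each block in post order and rewrite every bitcast between
// <256 x i32> and x86_amx: fold it into an adjacent load/store where
// possible, otherwise spill through a stack slot via transformBitcast.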

bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;
  Col2Row.clear();

  for (BasicBlock *BB : post_order(&Func)) {
    for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If load has multi-user, duplicate a vector load.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                            i8* %addr, i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2

        // If load has one user, the load will be eliminated in DAG ISel.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                            i8* %addr, i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (Use &U : Bitcast->uses()) {
          ST = dyn_cast<StoreInst>(U.getUser());
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If bitcast (%13) has one use, combine bitcast and store to amx
        // store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col,
        //                                                    %addr, %stride);
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If bitcast (%13) has multi-use, transform as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = <256 x i32> %14, <256 x i32> %src2
        combineBitcastStore(Bitcast, ST);
        // Delete user first.
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();

  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace

static Value *getAllocaPos(BasicBlock *BB) {
  Function *F = BB->getParent();
  IRBuilder<> Builder(&F->getEntryBlock().front());
  const DataLayout &DL = F->getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", F->getEntryBlock().begin());
  BasicBlock::iterator Iter = AllocaRes->getIterator();
  ++Iter;
  Builder.SetInsertPoint(&*Iter);
  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getPtrTy());
  return I8Ptr;
}
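// Materialize a tilestore of TileDef into the stack buffer at Ptr, reusing
// the shape operands of the defining intrinsic.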

static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
  auto *II = dyn_cast<IntrinsicInst>(TileDef);

  assert(II && "Not tile intrinsic!");
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  BasicBlock *BB = TileDef->getParent();
  BasicBlock::iterator Iter = TileDef->getIterator();
  IRBuilder<> Builder(BB, ++Iter);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};

  Instruction *TileStore =
      Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
  return TileStore;
}
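// Replace one use of a tile value with a fresh tileload from Ptr inserted
// right before the user; for PHI uses the shape is taken from the first
// incoming def, which all incoming values share.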

static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
  Value *V = U.get();
  assert(V->getType()->isX86_AMXTy() && "Not define tile!");

  // Get tile shape.
  IntrinsicInst *II = nullptr;
  if (IsPHI) {
    Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else {
    II = cast<IntrinsicInst>(V);
  }
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  Instruction *UserI = cast<Instruction>(U.getUser());
  IRBuilder<> Builder(UserI);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};

  Value *TileLoad =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
  UserI->replaceUsesOfWith(V, TileLoad);
}

static bool isIncomingOfPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (isa<PHINode>(V))
      return true;
  }
  return false;
}

// Let all AMX tile data become volatile data, shorten the life range
// of each tile register before fast register allocation.
namespace {
class X86VolatileTileData {
  Function &F;

public:
  X86VolatileTileData(Function &Func) : F(Func) {}
  Value *updatePhiIncomings(BasicBlock *BB,
                            SmallVector<Instruction *, 2> &Incomings);
  void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
  bool volatileTileData();
  void volatileTilePHI(PHINode *PHI);
  void volatileTileNonPHI(Instruction *I);
};

Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  Value *I8Ptr = getAllocaPos(BB);

  for (auto *I : Incomings) {
    User *Store = createTileStore(I, I8Ptr);

    // All its uses (except phi) should load from stored mem.
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (isa<PHINode>(V) || V == Store)
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }
  return I8Ptr;
}

void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {
  for (Use &U : PHI->uses())
    replaceWithTileLoad(U, StorePtr, true);
  PHI->eraseFromParent();
}

// Volatile Tile PHI:
// If the incoming value of a tile PHI is defined in another basic block, the
// tile register would have to stay live across blocks, which fast register
// allocation cannot handle. Instead, every incoming def is stored to a common
// stack slot right after its definition, and the PHI def is replaced by tile
// loads from that slot immediately before each use. Conceptually:
//
// bb.1:                              bb.1:
//   %t1 = ...                          %t1 = ...
//                                      store %t1 to %mem
// bb.2:                      -->    bb.2:
//   %t2 = ...                          %t2 = ...
//                                      store %t2 to %mem
// bb.3:                              bb.3:
//   %td = phi [%t1, bb.1],             %td2 = load tile from %mem
//             [%t2, bb.2]              "use %td2"
//   "use %td"

void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Incomings;

  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Value *Op = PHI->getIncomingValue(I);
    Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);
  }

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}

// Store the defined tile and load it before use.
// All its users are not PHI.
// e.g.
// ------------------------------------------------------
// def %td = ...                 def %td = ...
// ...                   -->     call amx_store(%td, mem)
// "use %td"                     ...
//                               %td2 = call amx_load(mem)
//                               "use %td2"
// ------------------------------------------------------
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  BasicBlock *BB = I->getParent();
  Value *I8Ptr = getAllocaPos(BB);
  User *Store = createTileStore(I, I8Ptr);

  // All its uses should load from stored mem.
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
    if (V != Store)
      replaceWithTileLoad(U, I8Ptr);
  }
}

// Volatile Tile Model:
// 1) All the uses of tile data comes from tileload in time.
// 2) All the defs of tile data tilestore into mem in time.
// For example:
// ------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(mem)            key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)            amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(%t1, %t2)      area
// call void @llvm.x86.tilestored64.internal(mem, %td)
// ------------------------------------------------------------------------
// 3) No terminator, call or other amx instructions in the key amx area.
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  for (BasicBlock &BB : F) {
    SmallVector<Instruction *, 2> PHIInsts;
    SmallVector<Instruction *, 8> AMXDefInsts;

    for (Instruction &I : BB) {
      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      else
        AMXDefInsts.push_back(&I);
    }

    // First we "volatile" the non-phi related amx intrinsics.
    for (Instruction *I : AMXDefInsts) {
      if (isIncomingOfPHI(I))
        continue;
      volatileTileNonPHI(I);
      Changed = true;
    }

    for (Instruction *I : PHIInsts) {
      volatileTilePHI(dyn_cast<PHINode>(I));
      Changed = true;
    }
  }
  return Changed;
}

} // anonymous namespace

namespace {

class X86LowerAMXCast {
  Function &Func;
  std::unique_ptr<DominatorTree> DT;

public:
  X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
  bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
  bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
  bool combineTilezero(IntrinsicInst *Cast);
  bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
  bool combineAMXcast(TargetLibraryInfo *TLI);
  bool transformAMXCast(IntrinsicInst *AMXCast);
  bool transformAllAMXCast();
  bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
                              SmallSetVector<Instruction *, 16> &DeadInst);
};

755

757 SmallSetVector<Instruction *, 16> &WorkList,

758 const TargetLibraryInfo *TLI) {

762

763

764

765 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {

766 Value *OpV = I->getOperand(i);

767 I->setOperand(i, nullptr);

768

770 continue;

771

772

773

774

777 WorkList.insert(OpI);

778 }

779 }

780 }

781 I->eraseFromParent();

782 return true;

783 }

784 return false;

785}

/// optimizeAMXCastFromPhi handles the following case:
///
///     A  ->  B    cast
///     PHI
///     B  ->  A    cast
///
/// All the related PHI nodes can be replaced by new PHI nodes with type A.
/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
bool X86LowerAMXCast::optimizeAMXCastFromPhi(
    IntrinsicInst *CI, PHINode *PN,
    SmallSetVector<Instruction *, 16> &DeadInst) {
  IRBuilder<> Builder(CI);
  Value *Src = CI->getOperand(0);
  Type *SrcTy = Src->getType(); // Type B
  Type *DestTy = CI->getType(); // Type A

  SmallVector<PHINode *, 4> PhiWorklist;
  SmallSetVector<PHINode *, 4> OldPhiNodes;

  // Find all of the A->B casts and PHI nodes.
  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
  // OldPhiNodes is used to track all known PHI nodes, before adding a new
  // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
  PhiWorklist.push_back(PN);
  OldPhiNodes.insert(PN);
  while (!PhiWorklist.empty()) {
    auto *OldPN = PhiWorklist.pop_back_val();
    for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
      Value *IncValue = OldPN->getIncomingValue(I);
      // TODO: currently, We ignore cases where it is a const. In the future,
      // we might support const.
      if (isa<Constant>(IncValue)) {
        auto *IncConst = dyn_cast<Constant>(IncValue);
        if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
          return false;
        Value *Row = nullptr, *Col = nullptr;
        std::tie(Row, Col) = getShape(OldPN);
        // TODO: If it is not constant the Row and Col must dominate the
        // tilezero that we are going to create.
        if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
          return false;
        // Create tilezero at the end of incoming block.
        auto *Block = OldPN->getIncomingBlock(I);
        BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
        Instruction *NewInst = Builder.CreateIntrinsic(
            Intrinsic::x86_tilezero_internal, {}, {Row, Col});
        NewInst->moveBefore(Iter);
        NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
                                          {IncValue->getType()}, {NewInst});
        NewInst->moveBefore(Iter);
        // Replace InValue with new Value.
        OldPN->setIncomingValue(I, NewInst);
        IncValue = NewInst;
      }

      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
        if (OldPhiNodes.insert(PNode))
          PhiWorklist.push_back(PNode);
        continue;
      }
      Instruction *ACI = dyn_cast<Instruction>(IncValue);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's a A->B cast.
        Type *TyA = ACI->getOperand(0)->getType();
        Type *TyB = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
        continue;
      }
      return false;
    }
  }

  // Check that each user of each old PHI node is something that we can
  // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
  for (auto *OldPN : OldPhiNodes) {
    for (User *V : OldPN->users()) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's a B->A cast.
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // As long as the user is another old PHI node, then even if we don't
        // rewrite it, the PHI web we're considering won't have any users
        // outside itself, so it'll be dead.
        if (OldPhiNodes.count(PHI) == 0)
          return false;
      } else
        return false;
    }
  }

  // For each old PHI node, create a corresponding new PHI node with a type A.
  SmallDenseMap<PHINode *, PHINode *> NewPNodes;
  for (auto *OldPN : OldPhiNodes) {
    Builder.SetInsertPoint(OldPN);
    PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
    NewPNodes[OldPN] = NewPN;
  }

  // Fill in the operands of new PHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
      Value *V = OldPN->getOperand(j);
      Value *NewV = nullptr;
      Instruction *ACI = dyn_cast<Instruction>(V);
      // There should not be an AMXcast from a const.
      if (ACI && isAMXCast(ACI))
        NewV = ACI->getOperand(0);
      else if (auto *PrevPN = dyn_cast<PHINode>(V))
        NewV = NewPNodes[PrevPN];
      assert(NewV);
      NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
    }
  }

  // Traverse all accumulated PHI nodes and process its users,
  // which are Stores and BitCasts. Without this processing
  // NewPHI nodes could be replicated and could lead to extra
  // moves generated after DeSSA.
  // If there is a store with type B, change it to type A.

  // Replace users of BitCast B->A with NewPHI. These will help
  // later to get rid of a closure formed by OldPHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (User *V : make_early_inc_range(OldPN->users())) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        assert(TyA == DestTy && TyB == SrcTy);
        (void)TyA;
        (void)TyB;
        ACI->replaceAllUsesWith(NewPN);
        DeadInst.insert(ACI);
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // We don't need to push PHINode into DeadInst since they are operands
        // of rootPN; DCE can safely delete rootPN's operands if rootPN is
        // dead.
        assert(OldPhiNodes.contains(PHI));
        (void)PHI;
      } else
        llvm_unreachable("all uses should be handled");
    }
  }
  return true;
}

// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
// store <256 x i32> %43, <256 x i32>* %p, align 64
// -->
// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
//                                           i64 64, x86_amx %42)
bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
  Value *Tile = Cast->getOperand(0);

  assert(Tile->getType()->isX86_AMXTy() && "Not Tile Operand!");

  // TODO: Specially handle the multi-use case.
  if (!Tile->hasOneUse())
    return false;

  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is output from AMX intrinsic. The first operand of the
  // intrinsic is row, the second operand of the intrinsic is column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  IRBuilder<> Builder(ST);
  // Stride should be equal to col (measured in bytes).
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Value *I8Ptr = Builder.CreateBitCast(ST->getOperand(1), Builder.getPtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
  return true;
}

// %65 = load <256 x i32>, <256 x i32>* %p, align 64
// %66 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %65)
// -->
// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                   i8* %p, i64 64)
bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
  bool EraseLoad = true;
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Cast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  // TODO: If it is cast intrinsic or phi node, we can propagate the
  // shape information through def-use chain.
  if (!isAMXIntrinsic(II))
    return false;
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(LD);
  // Stride should be equal to col (measured in bytes).
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Value *I8Ptr;

  // To save compiling time, we create the dominator tree only when it is
  // really needed.
  if (!DT)
    DT.reset(new DominatorTree(Func));
  if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
    // Store the value to stack and reload it from stack before cast.
    auto *AllocaAddr =
        createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType());
    Builder.SetInsertPoint(&*std::next(LD->getIterator()));
    Builder.CreateStore(LD, AllocaAddr);

    Builder.SetInsertPoint(Cast);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
    EraseLoad = false;
  } else {
    I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getPtrTy());
  }
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
  Cast->replaceAllUsesWith(NewInst);

  return EraseLoad;
}

// %19 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32>
//                                                          zeroinitializer)
// -->
// %19 = call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
bool X86LowerAMXCast::combineTilezero(IntrinsicInst *Cast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Cast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  if (!isAMXIntrinsic(II))
    return false;

  std::tie(Row, Col) = getShape(II, OpNo);

  IRBuilder<> Builder(Cast);
  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tilezero_internal, {},
                              {Row, Col});
  Cast->replaceAllUsesWith(NewInst);
  return true;
}
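// Fold the remaining casts into adjacent memory operations: a tile2vec
// feeding stores becomes tilestores, and a single-use load feeding a vec2tile
// becomes a tileload (or a tilezero when the source is a zero vector).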

bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
  bool Change = false;
  for (auto *Cast : Casts) {
    auto *II = cast<IntrinsicInst>(Cast);
    // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
    // store <256 x i32> %43, <256 x i32>* %p, align 64
    // -->
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
    //                                           i64 64, x86_amx %42)
    if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
      SmallVector<Instruction *, 2> DeadStores;
      for (User *U : Cast->users()) {
        StoreInst *Store = dyn_cast<StoreInst>(U);
        if (!Store)
          continue;
        if (combineCastStore(cast<IntrinsicInst>(Cast), Store)) {
          DeadStores.push_back(Store);
          Change = true;
        }
      }
      for (auto *Store : DeadStores)
        Store->eraseFromParent();
    } else { // x86_cast_vector_to_tile
      // %19 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32>
      //                                                  zeroinitializer)
      // -->
      // %19 = call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
      if (isa<ConstantAggregateZero>(Cast->getOperand(0))) {
        Change |= combineTilezero(cast<IntrinsicInst>(Cast));
        continue;
      }

      auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
      if (!Load || !Load->hasOneUse())
        continue;
      // %65 = load <256 x i32>, <256 x i32>* %p, align 64
      // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
      // -->
      // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
      //                                                   i8* %p, i64 64)
      if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
        // Set the operand to null so that the load instruction can be erased.
        Cast->setOperand(0, nullptr);
        Load->eraseFromParent();
        Change = true;
      }
    }
  }
  return Change;
}
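// combineAMXcast runs in phases: cancel vec2tile/tile2vec round trips, fold
// surviving casts into loads/stores, merge casts across PHI webs, then DCE
// whatever became dead.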

bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
  bool Change = false;
  // Collect tile cast instruction.
  SmallVector<Instruction *, 8> Vec2TileInsts;
  SmallVector<Instruction *, 8> Tile2VecInsts;
  SmallVector<Instruction *, 8> PhiCastWorkList;
  SmallSetVector<Instruction *, 16> DeadInst;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      Value *Vec;
      if (match(&I,
                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
        Vec2TileInsts.push_back(&I);
      else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
                             m_Value(Vec))))
        Tile2VecInsts.push_back(&I);
    }
  }

  auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
    for (auto *Inst : Insts) {
      for (User *U : Inst->users()) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
        if (!II || II->getIntrinsicID() != IID)
          continue;
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V2
        // -->
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V0
        II->replaceAllUsesWith(Inst->getOperand(0));
        Change = true;
      }
    }
  };

  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);

  SmallVector<Instruction *, 8> LiveCasts;
  auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
    for (auto *Inst : Insts) {
      if (Inst->use_empty()) {
        Inst->eraseFromParent();
        Change = true;
      } else {
        LiveCasts.push_back(Inst);
      }
    }
  };

  EraseInst(Vec2TileInsts);
  EraseInst(Tile2VecInsts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "Vec2Tile and Tile2Vec:\n";
             Func.dump());
  Change |= combineLdSt(LiveCasts);
  EraseInst(LiveCasts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "AMXCast and load/store:\n";
             Func.dump());

  // Handle the A->B->A cast, and there is an intervening PHI node.
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I)) {
        if (isa<PHINode>(I.getOperand(0)))
          PhiCastWorkList.push_back(&I);
      }
    }
  }
  for (auto *I : PhiCastWorkList) {
    // We skip the dead Amxcast.
    if (DeadInst.contains(I))
      continue;
    PHINode *PN = cast<PHINode>(I->getOperand(0));
    if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
      DeadInst.insert(PN);
      Change = true;
    }
  }

  // Since we create new phi and merge AMXCast, some old phis and AMXCast might
  // have no uses. We do some DeadCodeElimination for them.
  while (!DeadInst.empty()) {
    Instruction *I = DeadInst.pop_back_val();
    Change |= DCEInstruction(I, DeadInst, TLI);
  }
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after "
                       "optimizeAMXCastFromPhi:\n";
             Func.dump());
  return Change;
}

// There might be remaining AMXcast after combineAMXcast and they should be
// handled elegantly.
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
  IRBuilder<> Builder(AMXCast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = AMXCast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
    Stride = Builder.getInt64(64);
  };

  if (AMXCast->getType()->isX86_AMXTy()) {
    // %2 = amxcast <225 x i32> %src to x86_amx
    // -->
    // %addr = alloca <225 x i32>, align 64
    // store <225 x i32> %src, <225 x i32>* %addr, align 64
    // %addr2 = bitcast <225 x i32>* %addr to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
    //                                                  i8* %addr2, i64 60)
    if (AMXCast->use_empty()) {
      AMXCast->eraseFromParent();
      return true;
    }
    Use &U = *(AMXCast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(AMXCast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
    Value *NewInst =
        Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  } else {
    // %2 = amxcast x86_amx %src to <225 x i32>
    // -->
    // %addr = alloca <225 x i32>, align 64
    // %addr2 = bitcast <225 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(AMXCast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
    Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  }

  return true;
}

bool X86LowerAMXCast::transformAllAMXCast() {
  bool Change = false;
  // Collect tile cast instruction.
  SmallVector<Instruction *, 8> WorkLists;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I))
        WorkLists.push_back(&I);
    }
  }

  for (auto *Inst : WorkLists) {
    Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
  }

  return Change;
}

bool lowerAmxType(Function &F, const TargetMachine *TM,
                  TargetLibraryInfo *TLI) {
  // Performance optimization: most code doesn't use AMX, so return early if
  // there's no use of AMX instructions in this function.
  bool HasAMX = containsAMXCode(F);
  if (!HasAMX)
    return false;

  bool C = false;
  X86LowerAMXCast LAC(F);
  C |= LAC.combineAMXcast(TLI);
  // There might be remaining AMXcast after combineAMXcast and they should be
  // handled elegantly.
  C |= LAC.transformAllAMXCast();

  X86LowerAMXType LAT(F);
  C |= LAT.visit();

  // Prepare for fast register allocation at O0.
  // Todo: May better check the volatile model of AMX code, not just
  // by checking Attribute::OptimizeNone and CodeGenOptLevel::None.
  if (TM->getOptLevel() == CodeGenOptLevel::None) {
    // If the front end did not use O0 but the mid/back end does (e.g.
    // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make
    // sure the amx data is volatile, which is necessary for AMX fast
    // register allocation.
    if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
      X86VolatileTileData VTD(F);
      C = VTD.volatileTileData() || C;
    }
  }

  return C;
}

} // anonymous namespace

PreservedAnalyses X86LowerAMXTypePass::run(Function &F,
                                           FunctionAnalysisManager &FAM) {
  TargetLibraryInfo &TLI = FAM.getResult<TargetLibraryAnalysis>(F);
  bool Changed = lowerAmxType(F, TM, &TLI);
  if (!Changed)
    return PreservedAnalyses::all();

  PreservedAnalyses PA = PreservedAnalyses::none();
  PA.preserveSet<CFGAnalyses>();
  return PA;
}
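// Legacy pass manager wrapper around lowerAmxType().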

namespace {

class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override {
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
    return lowerAmxType(F, TM, TLI);
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypeLegacyPass() {
  return new X86LowerAMXTypeLegacyPass();
}
