LLVM: lib/Target/X86/X86LowerAMXType.cpp Source File
#include "llvm/IR/IntrinsicsX86.h"

#include <map>

using namespace llvm;

#define DEBUG_TYPE "x86-lower-amx-type"

static bool isAMXCast(Instruction *II) {
  return match(II,
               m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
         match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
}

static bool isAMXIntrinsic(Value *I) {
  auto *II = dyn_cast<IntrinsicInst>(I);
  if (!II)
    return false;
  if (isAMXCast(II))
    return false;
  // Check if the return type or any parameter is x86_amx. If it is x86_amx,
  // the intrinsic must be an x86 AMX intrinsic.
  if (II->getType()->isX86_AMXTy())
    return true;
  for (Value *V : II->args()) {
    if (V->getType()->isX86_AMXTy())
      return true;
  }

  return false;
}
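
// Illustrative example (not part of the original source): for
//   %t = call x86_amx @llvm.x86.tileloadd64.internal(i16 %r, i16 %c, ptr %p, i64 64)
// isAMXIntrinsic returns true because the result type is x86_amx, while the
// shape-less cast intrinsics (llvm.x86.cast.vector.to.tile /
// llvm.x86.cast.tile.to.vector) are deliberately excluded by the isAMXCast
// check above.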

static bool containsAMXCode(Function &F) {
  for (BasicBlock &BB : F)
    for (Instruction &I : BB)
      if (I.getType()->isX86_AMXTy())
        return true;
  return false;
}

static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
                                           Type *Ty) {
  Function &F = *BB->getParent();
  const DataLayout &DL = F.getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(Ty, AllocaAS, "", F.getEntryBlock().begin());
  AllocaRes->setAlignment(Align(64));
  return AllocaRes;
}

static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
  for (Instruction &I : F.getEntryBlock())
    if (!isa<AllocaInst>(&I))
      return &I;
  llvm_unreachable("No terminator in the entry block!");
}

static Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity) {
  IRBuilder<> Builder(II);
  Value *RealRow = nullptr;
  if (isa<ConstantInt>(V))
    RealRow =
        Builder.getInt16((cast<ConstantInt>(V)->getSExtValue()) / Granularity);
  else if (isa<Instruction>(V)) {
    // Create the udiv next to the definition of V so that the computed row
    // value dominates every user that may need it.
    Builder.SetInsertPoint(cast<Instruction>(V));
    RealRow = Builder.CreateUDiv(V, Builder.getInt16(4));
    cast<Instruction>(RealRow)->moveAfter(cast<Instruction>(V));
  } else {
    // V is a function argument: compute the row in the entry block with a
    // separate builder so that it dominates all uses.
    IRBuilder<> NewBuilder(
        getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
    RealRow = NewBuilder.CreateUDiv(V, NewBuilder.getInt16(Granularity));
  }
  return RealRow;
}

// Get the shape (row, column) of the tile that appears as operand \p OpNo of
// the AMX intrinsic \p II. For most intrinsics the shape is carried by the
// leading i16 operands.
std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tileloaddt164_internal:
  case Intrinsic::x86_tilestored64_internal:
  case Intrinsic::x86_t2rpntlvwz0rs_internal:
  case Intrinsic::x86_t2rpntlvwz0rst1_internal:
  case Intrinsic::x86_t2rpntlvwz1rs_internal:
  case Intrinsic::x86_t2rpntlvwz1rst1_internal:
  case Intrinsic::x86_tileloaddrs64_internal:
  case Intrinsic::x86_tileloaddrst164_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c, the shape depends on which operand.
  case Intrinsic::x86_tcmmimfp16ps_internal:
  case Intrinsic::x86_tcmmrlfp16ps_internal:
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal:
  case Intrinsic::x86_tdpfp16ps_internal:
  case Intrinsic::x86_tmmultf32ps_internal:
  case Intrinsic::x86_tdpbf8ps_internal:
  case Intrinsic::x86_tdpbhf8ps_internal:
  case Intrinsic::x86_tdphbf8ps_internal:
  case Intrinsic::x86_tdphf8ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      Row = getRowFromCol(II, II->getArgOperand(2), 4);
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  case Intrinsic::x86_tcvtrowd2ps_internal:
  case Intrinsic::x86_tcvtrowps2bf16h_internal:
  case Intrinsic::x86_tcvtrowps2bf16l_internal:
  case Intrinsic::x86_tcvtrowps2phh_internal:
  case Intrinsic::x86_tcvtrowps2phl_internal:
  case Intrinsic::x86_tilemovrow_internal: {
    assert(OpNo == 2 && "Illegal Operand Number.");
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  }

  return std::make_pair(Row, Col);
}
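
// Illustrative example (not part of the original source): for a dot-product
// intrinsic
//   %d = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k,
//                                                x86_amx %c, x86_amx %a, x86_amx %b)
// getShape(II, /*OpNo=*/4) returns {%m, %k}: operand 4 is the A tile, whose
// rows come from operand 0 and whose columns (in bytes) come from operand 2.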

static std::pair<Value *, Value *> getShape(PHINode *Phi) {
  Use &U = *(Phi->use_begin());
  unsigned OpNo = U.getOperandNo();
  User *V = U.getUser();
  // TODO We don't traverse all users. To make the algorithm simple, here we
  // just traverse the first user. If we can find shape, then return the shape,
  // otherwise just return nullptr and the optimization for undef/zero will be
  // abandoned.
  while (V) {
    if (isAMXCast(dyn_cast<Instruction>(V))) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      OpNo = U.getOperandNo();
      V = U.getUser();
    } else if (isAMXIntrinsic(V)) {
      return getShape(cast<IntrinsicInst>(V), OpNo);
    } else if (isa<PHINode>(V)) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      V = U.getUser();
    } else {
      break;
    }
  }

  return std::make_pair(nullptr, nullptr);
}
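
// Sketch of the pattern the PHI overload is meant to handle (illustrative,
// not from the source):
//   %vec = phi <256 x i32> [ zeroinitializer, %entry ], [ %v2, %loop ]
//   %t   = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %vec)
//   %d   = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k,
//                                                  x86_amx %d0, x86_amx %t, ...)
// Walking use->user from the phi reaches the tdpbssd call, so the phi's shape
// can be taken from that intrinsic's shape operands.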
252
namespace {
class X86LowerAMXType {
  Function &Func;

  // In AMX intrinsics the shape is {Row, Col}, where Col is measured in
  // bytes. A column value divided by the element size may be reused as the
  // row of another AMX intrinsic, so cache the mapping here.
  std::map<Value *, Value *> Col2Row;

public:
  X86LowerAMXType(Function &F) : Func(F) {}
  bool visit();
  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
  bool transformBitcast(BitCastInst *Bitcast);
};
269
// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                  i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column as stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr = LD->getOperand(0);
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}
291

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {

  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is output from an AMX intrinsic. The first operand of the
  // intrinsic is row, the second operand of the intrinsic is column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr = ST->getOperand(1);
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}
328
// transform bitcast to <store, load> instructions.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
    I8Ptr = AllocaAddr;
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %alloca = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %alloca, align 64
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %alloca, i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(Bitcast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst =
        Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %alloca = alloca <256 x i32>, align 64
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %alloca, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %alloca, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(Bitcast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}
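
// Illustrative fallback (not part of the original source): when the bitcast
// source is not a plain load/store that can be combined directly, e.g.
//   %amx = bitcast <256 x i32> %sum to x86_amx
// the code above spills %sum to a 64-byte-aligned stack slot with a regular
// store and materializes the tile with llvm.x86.tileloadd64.internal from
// that slot, using the shape taken from the AMX intrinsic that consumes %amx.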
387
bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;
  Col2Row.clear();

  for (BasicBlock *BB : post_order(&Func)) {
    for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has a single use it will be folded into the tile load
        // and erased; with multiple users the vector load is kept for them:
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr, i64 64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (Use &U : Bitcast->uses()) {
          ST = dyn_cast<StoreInst>(U.getUser());
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the bitcast has a single use it is folded into an AMX tile
        // store; with multiple uses the stored value is reloaded as a vector
        // for the remaining users (see combineBitcastStore above).
        combineBitcastStore(Bitcast, ST);
        // Delete user first.
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }
473
474 bool C = !DeadInsts.empty();
475
476 for (auto *Inst : DeadInsts)
477 Inst->eraseFromParent();
478
479 return C;
480}
481}

static Value *getAllocaPos(BasicBlock *BB) {
  Function *F = BB->getParent();
  IRBuilder<> Builder(&F->getEntryBlock().front());
  const DataLayout &DL = F->getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", F->getEntryBlock().begin());
  BasicBlock::iterator Iter = AllocaRes->getIterator();
  ++Iter;
  Builder.SetInsertPoint(&*Iter);
  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getPtrTy());
  return I8Ptr;
}

static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
  assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
  auto *II = dyn_cast<IntrinsicInst>(TileDef);
  assert(II && "Not tile intrinsic!");
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  BasicBlock *BB = TileDef->getParent();
  BasicBlock::iterator Iter = TileDef->getIterator();
  IRBuilder<> Builder(BB, ++Iter);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};

  Instruction *TileStore =
      Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
  return TileStore;
}
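
// Illustrative use (not part of the original source): given a tile definition
//   %t = call x86_amx @llvm.x86.tilezero.internal(i16 %r, i16 %c)
// createTileStore(%t, %slot) inserts, right after the definition,
//   call void @llvm.x86.tilestored64.internal(i16 %r, i16 %c, ptr %slot,
//                                             i64 64, x86_amx %t)
// so the tile value can later be rematerialized with replaceWithTileLoad.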

static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
  Value *V = U.get();
  assert(V->getType()->isX86_AMXTy() && "Not define tile!");

  // Get tile shape.
  IntrinsicInst *II = nullptr;
  if (IsPHI) {
    Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else {
    II = cast<IntrinsicInst>(V);
  }
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  Instruction *UserI = cast<Instruction>(U.getUser());
  IRBuilder<> Builder(UserI);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};

  Value *TileLoad =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
  UserI->replaceUsesOfWith(V, TileLoad);
}

static bool isIncomingOfPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (isa<PHINode>(V))
      return true;
  }
  return false;
}
550

namespace {
class X86VolatileTileData {
  Function &F;

public:
  X86VolatileTileData(Function &Func) : F(Func) {}
559 Value *updatePhiIncomings(BasicBlock *BB,
560 SmallVector<Instruction *, 2> &Incomings);
561 void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
562 bool volatileTileData();
563 void volatileTilePHI(PHINode *PHI);
564 void volatileTileNonPHI(Instruction *I);
565};
566
Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  Value *I8Ptr = getAllocaPos(BB);

  for (auto *I : Incomings) {
    User *Store = createTileStore(I, I8Ptr);

    // All its uses (except the phi) should load from the stored memory.
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (V == Store || isa<PHINode>(V))
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }
  return I8Ptr;
}
584
void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {
  for (Use &U : PHI->uses())
    replaceWithTileLoad(U, StorePtr, true);
  PHI->eraseFromParent();
}
591

// Volatile Tile Data
// ------------------
// (The original block comment here walks through why tile values that are
// live across basic blocks -- in particular values merged by PHI nodes -- are
// spilled to a stack slot with llvm.x86.tilestored64.internal right after
// their definition and re-materialized with llvm.x86.tileloadd64.internal
// immediately before each use, so that at O0 the fast register allocator
// never has to keep an x86_amx value live across a block boundary.)
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Incomings;

  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Value *Op = PHI->getIncomingValue(I);
    Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);
  }

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}
661

// Store the defined tile right after its definition and reload it before
// every use that is not the store itself, e.g.
//   def %td = tileload(...)
//   ...
//   "use %td"
// becomes
//   def %td = tileload(...)
//   tilestore %td to a stack slot
//   ...
//   %td2 = tileload from the stack slot
//   "use %td2"
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  BasicBlock *BB = I->getParent();
  Value *I8Ptr = getAllocaPos(BB);
  User *Store = createTileStore(I, I8Ptr);

  // All its uses should load from the stored memory.
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (V != Store)
      replaceWithTileLoad(U, I8Ptr);
  }
}
691

// Volatile Tile Data
// ------------------
// An x86_amx value that is alive across basic blocks is made "volatile": it
// is stored to a stack slot in its defining block and reloaded before every
// use, so each live range the fast register allocator sees is local to a
// single block.
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  for (BasicBlock &BB : F) {
    SmallVector<Instruction *, 2> PHIInsts;
    SmallVector<Instruction *, 8> AMXDefInsts;

    for (Instruction &I : BB) {
      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      else
        AMXDefInsts.push_back(&I);
    }

    // First we "volatile" the non-phi related amx intrinsics.
    for (Instruction *I : AMXDefInsts) {
      if (isIncomingOfPHI(I))
        continue;
      volatileTileNonPHI(I);
      Changed = true;
    }

    for (Instruction *I : PHIInsts) {
      volatileTilePHI(dyn_cast<PHINode>(I));
      Changed = true;
    }
  }
  return Changed;
}
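
// Net effect at O0 (illustrative, not part of the original source): a tile
// that is defined in one block and used in another, e.g. through
//   %phi = phi x86_amx [ %t0, %bb0 ], [ %t1, %bb1 ]
// is rewritten so that %t0 and %t1 are stored to a shared stack slot in their
// defining blocks and every user loads a fresh tile from that slot, which
// lets the fast register allocator treat each tile interval as block-local.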
734
735}
736
737namespace {
738
class X86LowerAMXCast {
  Function &Func;
  std::unique_ptr<DominatorTree> DT;
742
743public:
744 X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
745 bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
746 bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
747 bool combineTilezero(IntrinsicInst *Cast);
748 bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
749 bool combineAMXcast(TargetLibraryInfo *TLI);
750 bool transformAMXCast(IntrinsicInst *AMXCast);
751 bool transformAllAMXCast();
752 bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
753 SmallSetVector<Instruction *, 16> &DeadInst);
754};

static bool DCEInstruction(Instruction *I,
                           SmallSetVector<Instruction *, 16> &WorkList,
                           const TargetLibraryInfo *TLI) {
  if (isInstructionTriviallyDead(I, TLI)) {
    salvageDebugInfo(*I);
    salvageKnowledge(I);

    // Null out all of the instruction's operands to see if any operand becomes
    // dead as we go.
    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
      Value *OpV = I->getOperand(i);
      I->setOperand(i, nullptr);

      if (!OpV->use_empty() || I == OpV)
        continue;

      // If the operand is an instruction that became dead as we nulled out the
      // operand, and if it is 'trivially' dead, delete it in a future loop
      // iteration.
      if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
        if (isInstructionTriviallyDead(OpI, TLI))
          WorkList.insert(OpI);
      }
    }
    I->eraseFromParent();
    return true;
  }
  return false;
}

/// This function handles the following case:
///
///     A  ->  B    amxcast
///     PHI
///     B  ->  A    amxcast
///
/// All the related PHI nodes can be replaced by new PHI nodes with type A.
/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
bool X86LowerAMXCast::optimizeAMXCastFromPhi(
    IntrinsicInst *CI, PHINode *PN,
    SmallSetVector<Instruction *, 16> &DeadInst) {
  IRBuilder<> Builder(CI);
  Value *Src = CI->getOperand(0);
  Type *SrcTy = Src->getType();
  Type *DestTy = CI->getType();

  SmallVector<PHINode *, 4> PhiWorklist;
  SmallSetVector<PHINode *, 4> OldPhiNodes;

  // Find all of the A->B casts and PHI nodes.
  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
  // OldPhiNodes is used to track all known PHI nodes; before adding a new
  // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
  PhiWorklist.push_back(PN);
  OldPhiNodes.insert(PN);
  while (!PhiWorklist.empty()) {
    auto *OldPN = PhiWorklist.pop_back_val();
    for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
      Value *IncValue = OldPN->getIncomingValue(I);
      // TODO: currently, we ignore cases where it is a const. In the future,
      // we might support const.
      if (isa<Constant>(IncValue)) {
        auto *IncConst = dyn_cast<Constant>(IncValue);
        if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
          return false;
        Value *Row = nullptr, *Col = nullptr;
        std::tie(Row, Col) = getShape(OldPN);
        // TODO: If it is not a constant, the Row and Col must dominate the
        // tilezero that we are going to create.
        if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
          return false;
        // Create tilezero at the end of incoming block.
        auto *Block = OldPN->getIncomingBlock(I);
        BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
        Instruction *NewInst = Builder.CreateIntrinsic(
            Intrinsic::x86_tilezero_internal, {}, {Row, Col});
        NewInst->moveBefore(Iter);
        NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
                                          {IncValue->getType()}, {NewInst});
        NewInst->moveBefore(Iter);
        // Replace IncValue with the new value.
        OldPN->setIncomingValue(I, NewInst);
        IncValue = NewInst;
      }

      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
        if (OldPhiNodes.insert(PNode))
          PhiWorklist.push_back(PNode);
        continue;
      }
      Instruction *ACI = dyn_cast<Instruction>(IncValue);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's a A->B cast.
        Type *TyA = ACI->getOperand(0)->getType();
        Type *TyB = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
        continue;
      }
      return false;
    }
  }
859
  // Check that each user of each old PHI node is something that we can
  // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
  for (auto *OldPN : OldPhiNodes) {
    for (User *V : OldPN->users()) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's a B->A cast.
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // As long as the user is another old PHI node, then even if we don't
        // rewrite it, the PHI web we're considering won't have any users
        // outside itself, so it'll be dead.
        if (OldPhiNodes.count(PHI) == 0)
          return false;
      } else
        return false;
    }
  }
896
897
898 SmallDenseMap<PHINode *, PHINode *> NewPNodes;
899 for (auto *OldPN : OldPhiNodes) {
900 Builder.SetInsertPoint(OldPN);
901 PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
902 NewPNodes[OldPN] = NewPN;
903 }
904
  // Fill in all the incoming values of the new PHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
      Value *V = OldPN->getOperand(j);
      Value *NewV = nullptr;
      Instruction *ACI = dyn_cast<Instruction>(V);
      // There should not be an AMXcast from a const.
      if (ACI && isAMXCast(ACI))
        NewV = ACI->getOperand(0);
      else if (auto *PrevPN = dyn_cast<PHINode>(V))
        NewV = NewPNodes[PrevPN];
      assert(NewV);
      NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
    }
  }
921

  // Traverse all accumulated PHI nodes and process their users, which are
  // stores and AMX casts. Replace users of the B->A cast with the new PHI:
  // this helps later to get rid of the closure formed by the old PHI nodes
  // and avoids extra tile moves after DeSSA.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (User *V : make_early_inc_range(OldPN->users())) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        assert(TyA == DestTy && TyB == SrcTy);
        (void)TyA;
        (void)TyB;
        ACI->replaceAllUsesWith(NewPN);
        DeadInst.insert(ACI);
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // We don't need to push PHINodes into DeadInst since they are
        // operands of rootPN; DCE can safely delete rootPN's operands if
        // rootPN is dead.
        assert(OldPhiNodes.contains(PHI));
        (void)PHI;
      } else
        llvm_unreachable("all uses should be handled");
    }
  }
  return true;
}
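
// Illustrative rewrite (not part of the original source):
//   %vec  = phi <256 x i32> [ %v0, %bb0 ], [ %v1, %bb1 ]
//   %tile = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %vec)
// becomes, when %v0/%v1 are themselves tile-to-vector casts of %t0/%t1,
//   %tile = phi x86_amx [ %t0, %bb0 ], [ %t1, %bb1 ]
// and the now-dead vector phi and cast pairs are queued in DeadInst for DCE.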
953

// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
// store <256 x i32> %43, <256 x i32>* %p, align 64
// -->
// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
//                                           i64 64, x86_amx %42)
bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
  Value *Tile = Cast->getOperand(0);

  assert(Tile->getType()->isX86_AMXTy() && "Not Tile Operand!");

  // TODO: Specially handle the multi-use case.
  if (!Tile->hasOneUse())
    return false;

  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is output from an AMX intrinsic. The first operand of the
  // intrinsic is row, the second operand of the intrinsic is column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  IRBuilder<> Builder(ST);
  // Stride should be equal to col (measured in bytes).
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Value *I8Ptr = Builder.CreateBitCast(ST->getOperand(1), Builder.getPtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
  return true;
}
983

// %65 = load <256 x i32>, <256 x i32>* %p, align 64
// %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
// -->
// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                   i8* %p, i64 64)
bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
  bool EraseLoad = true;
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Cast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  // The shape is only available from a real AMX intrinsic user of the cast;
  // give up if the user is not one.
  if (!isAMXIntrinsic(II))
    return false;
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(LD);
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Value *I8Ptr;

  // Check that the shape values properly dominate the load; otherwise fall
  // back to copying the loaded vector to a stack slot and loading the tile
  // from there.
  if (!DT)
    DT.reset(new DominatorTree(Func));
  if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
    // Store the value to the stack and reload it from the stack before cast.
    auto *AllocaAddr =
        createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType());
    Builder.SetInsertPoint(&*std::next(LD->getIterator()));
    Builder.CreateStore(LD, AllocaAddr);

    Builder.SetInsertPoint(Cast);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
    EraseLoad = false;
  } else {
    I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getPtrTy());
  }
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
  Cast->replaceAllUsesWith(NewInst);

  return EraseLoad;
}
1030

// %19 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> zeroinitializer)
// -->
// %19 = call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
bool X86LowerAMXCast::combineTilezero(IntrinsicInst *Cast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Cast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  if (!isAMXIntrinsic(II))
    return false;

  std::tie(Row, Col) = getShape(II, OpNo);

  IRBuilder<> Builder(Cast);
  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tilezero_internal, {}, {Row, Col});
  Cast->replaceAllUsesWith(NewInst);
  return true;
}
1050
bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
  bool Change = false;
  for (auto *Cast : Casts) {
    auto *II = cast<IntrinsicInst>(Cast);
    // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
    // store <256 x i32> %43, <256 x i32>* %p, align 64
    // -->
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
    //                                           i64 64, x86_amx %42)
    if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
      SmallVector<Instruction *, 2> DeadStores;
      for (User *U : Cast->users()) {
        StoreInst *Store = dyn_cast<StoreInst>(U);
        if (!Store)
          continue;
        if (combineCastStore(cast<IntrinsicInst>(Cast), Store)) {
          DeadStores.push_back(Store);
          Change = true;
        }
      }
      for (auto *Store : DeadStores)
        Store->eraseFromParent();
    } else { // x86_cast_vector_to_tile
      // %19 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> zeroinitializer)
      // -->
      // %19 = call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
      if (isa<ConstantAggregateZero>(Cast->getOperand(0))) {
        Change |= combineTilezero(cast<IntrinsicInst>(Cast));
        continue;
      }

      auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
      if (!Load || !Load->hasOneUse())
        continue;
      // %65 = load <256 x i32>, <256 x i32>* %p, align 64
      // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
      // -->
      // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
      //                                                   i8* %p, i64 64)
      if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
        // Set the operand to null so that the load instruction can be erased.
        Cast->setOperand(0, nullptr);
        Load->eraseFromParent();
        Change = true;
      }
    }
  }
  return Change;
}
1100
bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> Vec2TileInsts;
  SmallVector<Instruction *, 8> Tile2VecInsts;
  SmallVector<Instruction *, 8> PhiCastWorkList;
  SmallSetVector<Instruction *, 16> DeadInst;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      Value *Vec;
      if (match(&I,
                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
        Vec2TileInsts.push_back(&I);
      else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
                             m_Value(Vec))))
        Tile2VecInsts.push_back(&I);
    }
  }

  auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
    for (auto *Inst : Insts) {
      for (User *U : Inst->users()) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
        if (!II || II->getIntrinsicID() != IID)
          continue;
        // A pair of cancelling casts is folded away, e.g.
        // %1 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %0)
        // %2 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %1)
        // %2 is replaced by %0 and both casts become dead.
        II->replaceAllUsesWith(Inst->getOperand(0));
        Change = true;
      }
    }
  };
1138
1139 Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
1140 Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
1141
1142 SmallVector<Instruction *, 8> LiveCasts;
1143 auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
1144 for (auto *Inst : Insts) {
1145 if (Inst->use_empty()) {
1146 Inst->eraseFromParent();
1147 Change = true;
      } else {
        LiveCasts.push_back(Inst);
      }
1151 }
1152 };
1153
1154 EraseInst(Vec2TileInsts);
1155 EraseInst(Tile2VecInsts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "Vec2Tile and Tile2Vec:\n";
             Func.dump());
1159 Change |= combineLdSt(LiveCasts);
1160 EraseInst(LiveCasts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "AMXCast and load/store:\n";
             Func.dump());

  // Perform the optimization of cast from phi.
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I)) {
        if (isa<PHINode>(I.getOperand(0)))
          PhiCastWorkList.push_back(&I);
      }
    }
  }
  for (auto *I : PhiCastWorkList) {
    // We skip the dead AMXcast.
    if (DeadInst.contains(I))
      continue;
    PHINode *PN = cast<PHINode>(I->getOperand(0));
    if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
      DeadInst.insert(PN);
      Change = true;
    }
  }

  // Since we create new phis and merge AMXCasts, some old phis and AMXCasts
  // might be dead, and we should delete them.
  while (!DeadInst.empty()) {
    Instruction *I = DeadInst.pop_back_val();
    Change |= DCEInstruction(I, DeadInst, TLI);
  }
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after "
                       "optimizeAMXCastFromPhi:\n";
             Func.dump());
1194 return Change;
1195}
1196

// There might be remaining AMXcasts after combineAMXcast, and they should be
// handled elegantly.
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
  IRBuilder<> Builder(AMXCast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = AMXCast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
    Stride = Builder.getInt64(64);
  };

  if (AMXCast->getType()->isX86_AMXTy()) {
    // %2 = amxcast <225 x i32> %src to x86_amx
    // -->
    // %addr = alloca <225 x i32>, align 64
    // store <225 x i32> %src, <225 x i32>* %addr, align 64
    // %addr2 = bitcast <225 x i32>* %addr to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2, i64 %stride)
    if (AMXCast->use_empty()) {
      AMXCast->eraseFromParent();
      return true;
    }
    Use &U = *(AMXCast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(AMXCast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
    Value *NewInst =
        Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  } else {
    // %2 = amxcast x86_amx %src to <225 x i32>
    // -->
    // %addr = alloca <225 x i32>, align 64
    // %addr2 = bitcast <225 x i32>* %addr to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(AMXCast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, Args);
    Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  }

  return true;
}
1268
bool X86LowerAMXCast::transformAllAMXCast() {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> WorkLists;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I))
        WorkLists.push_back(&I);
    }
  }

  for (auto *Inst : WorkLists) {
    Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
  }

  return Change;
}
1286
bool lowerAmxType(Function &F, const TargetMachine *TM,
                  TargetLibraryInfo *TLI) {
  // Performance optimization: most code doesn't use AMX, so return early if
  // there are no instructions that produce AMX values. This is sufficient, as
  // AMX arguments and constants are not allowed -- so any producer of an AMX
  // value must be an instruction.
  // TODO: find a cheaper way for this, without looking at all instructions.
  if (!containsAMXCode(F))
    return false;
1296
1297 bool C = false;
1298 X86LowerAMXCast LAC(F);
1299 C |= LAC.combineAMXcast(TLI);
1300
1301
1302 C |= LAC.transformAllAMXCast();
1303
1304 X86LowerAMXType LAT(F);
1305 C |= LAT.visit();
1306
  // Prepare for fast register allocation at O0.
  // TODO: Better check the volatile model of AMX code, not just by checking
  // Attribute::OptimizeNone and CodeGenOptLevel::None.
  if (TM->getOptLevel() == CodeGenOptLevel::None) {
    // If the front end did not use O0 but the mid/back end runs at O0 (e.g.
    // "clang -O2 -S -emit-llvm t.c" followed by "llc t.ll"), we should make
    // sure the AMX data is volatile; that is necessary for AMX fast register
    // allocation.
    if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
      X86VolatileTileData VTD(F);
      C = VTD.volatileTileData() || C;
    }
1319 }
1320
1321 return C;
1322}
1323
1324}

PreservedAnalyses X86LowerAMXTypePass::run(Function &F,
                                           FunctionAnalysisManager &FAM) {
  TargetLibraryInfo &TLI = FAM.getResult<TargetLibraryAnalysis>(F);
  bool Changed = lowerAmxType(F, TM, &TLI);
  if (!Changed)
    return PreservedAnalyses::all();
  PreservedAnalyses PA = PreservedAnalyses::none();
  PA.preserveSet<CFGAnalyses>();
  return PA;
}
1337
namespace {

class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override {
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
    return lowerAmxType(F, TM, TLI);
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
  }
};
1359
1360}
1361
static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypeLegacyPass() {
  return new X86LowerAMXTypeLegacyPass();
}
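
// Usage note (assumption, not stated in this file): this transform normally
// runs as part of the X86 codegen pipeline, e.g. when compiling IR that uses
// the AMX intrinsics with "llc -mtriple=x86_64 -mattr=+amx-int8 ...". The
// DEBUG_TYPE string ("x86-lower-amx-type") is what -debug-only= expects when
// tracing the pass.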