LLVM: lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp Source File
//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MatrixBuilder.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/MatrixUtils.h"

#include <cmath>
50using namespace llvm;
51using namespace PatternMatch;
52
53#define DEBUG_TYPE "lower-matrix-intrinsics"
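// Illustrative sketch (not from the original source): the options declared
// below are consumed from the opt command line, e.g.
//   opt -passes=lower-matrix-intrinsics -fuse-matrix-tile-size=8 in.ll -S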
static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
                                  cl::Hidden,
                                  cl::desc("Generate loop nest for tiling."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

static cl::opt<bool> VerifyShapeInfo(
    "verify-matrix-shapes", cl::Hidden,
    cl::desc("Enable/disable matrix shape verification."), cl::init(false));

static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
                                            cl::init(false));

enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));

/// Helper function to either return Scope, if it is a subprogram or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}
99
/// Return true if V is a splat of a value (which is used when multiplying a
/// matrix with a scalar).
static bool isSplat(Value *V) {
  if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
    return SV->isZeroEltSplat();
  return false;
}
107
/// Match any mul operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyMul(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
}

/// Match any add operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyAdd(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
}
119
namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the address of the vector starting at index \p VecIdx, where consecutive
// vectors of the matrix begin \p Stride elements apart.
Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
                         unsigned NumElements, Type *EltType,
                         IRBuilder<> &Builder) {

  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
         "Stride must be >= the number of elements in the result vector.");
169
170
171 Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
  // Get pointer to the start of the selected vector. Skip GEP creation,
  // if we select vector 0.
  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
    VecStart = BasePtr;
  else
178 VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
179
180 return VecStart;
181}
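// Illustrative sketch (not from the original source): for a column-major 4x4
// matrix of doubles with Stride == 4, the start of column 2 is computed as
//   %vec.start = mul i64 2, 4
//   %vec.gep   = getelementptr double, ptr %base, i64 %vec.start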
182
183namespace {
184struct ShapeInfo {
185 unsigned NumRows;
186 unsigned NumColumns;
187
188 bool IsColumnMajor;
189
  ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
      : NumRows(NumRows), NumColumns(NumColumns),
        IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}

  ShapeInfo(Value *NumRows, Value *NumColumns)
      : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
                  cast<ConstantInt>(NumColumns)->getZExtValue()) {}
198 bool operator==(const ShapeInfo &other) {
199 return NumRows == other.NumRows && NumColumns == other.NumColumns;
200 }
201 bool operator!=(const ShapeInfo &other) { return !(*this == other); }
202
  /// Returns true if shape-information is defined, meaning both dimensions
  /// are != 0.
  operator bool() const {
206 assert(NumRows == 0 || NumColumns != 0);
207 return NumRows != 0;
208 }
209
210 unsigned getStride() const {
211 if (IsColumnMajor)
212 return NumRows;
213 return NumColumns;
214 }
215
216 unsigned getNumVectors() const {
217 if (IsColumnMajor)
218 return NumColumns;
219 return NumRows;
220 }
221
222
223 ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
224};
225}
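// Illustrative sketch (not from the original source): a 3x4 column-major
// matrix is described by ShapeInfo{3, 4}: getNumVectors() == 4 column vectors
// of getStride() == 3 elements each, and t() yields ShapeInfo{4, 3}.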
226
static bool isUniformShape(Value *V) {
  Instruction *I = dyn_cast<Instruction>(V);
  if (!I)
    return true;
231
232 switch (I->getOpcode()) {
233 case Instruction::FAdd:
234 case Instruction::FSub:
235 case Instruction::FMul:
236 case Instruction::FNeg:
237 case Instruction::Add:
238 case Instruction::Mul:
239 case Instruction::Sub:
240 return true;
241 default:
242 return false;
243 }
244}
245
246
/// Compute the shape for the given instruction \p I based on its operands'
/// shapes.
static std::optional<ShapeInfo>
computeShapeInfoForInst(Instruction *I,
                        const DenseMap<Value *, ShapeInfo> &ShapeMap) {
  Value *M;
  Value *N;
  Value *K;
  if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>(
                   m_Value(), m_Value(), m_Value(M), m_Value(N), m_Value(K))))
    return ShapeInfo(M, K);
  if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M),
                                                        m_Value(N))))
    // Flip dimensions.
    return ShapeInfo(N, M);
  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                   m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M),
                   m_Value(N))))
    return ShapeInfo(N, M);
  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                   m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N))))
    return ShapeInfo(M, N);
  Value *MatrixA;
  if (match(I, m_Store(m_Value(MatrixA), m_Value()))) {
    auto OpShape = ShapeMap.find(MatrixA);
    if (OpShape != ShapeMap.end())
      return OpShape->second;
  }
274
275 if (isUniformShape(I)) {
276
277 for (auto &Op : I->operands()) {
278 auto OpShape = ShapeMap.find(Op.get());
279 if (OpShape != ShapeMap.end())
280 return OpShape->second;
281 }
282 }
283 return std::nullopt;
284}
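// Illustrative sketch (not from the original source): given
//   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v6f64.v6f64(
//            <6 x double> %a, <6 x double> %b, i32 2, i32 3, i32 2)
// the multiply match above binds M=2, N=3, K=2 and records shape 2x2 for %c.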
285
/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
///    instructions.
/// 2. Lower instructions with shape information (assuming column-major layout;
///    the lowering works similarly using row-major layout).
///  2.1. Get column vectors for each argument. If we already lowered the
///       definition of an argument, use the produced column vectors directly.
///       If not, split the operand vector containing an embedded matrix into
///       a set of column vectors.
///  2.2. Lower the instruction in terms of column major operations, which
///       yields a set of column vectors containing the result matrix. Note
///       that we lower all instructions that have shape information. Besides
///       the intrinsics, this includes stores for example.
///  2.3. Update uses of the lowered instruction. If we have shape information
///       for a user, there is nothing to do, as we will look up the result
///       column matrix when lowering the user. For other uses, we embed the
///       result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered.
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
///
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;
  FunctionAnalysisManager *AM;
  AliasAnalysis *AA = nullptr;
  DominatorTree *DT = nullptr;
  LoopInfo *LI = nullptr;
  OptimizationRemarkEmitter *ORE = nullptr;

  /// Contains estimates of the number of operations (loads, stores, compute)
  /// required to lower a matrix operation.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;
    /// Most of the time transposes can be fused with matrix multiplies or can
    /// be folded away via algebraic simplifications.
    unsigned NumExposedTransposes = 0;
331
332 OpInfoTy &operator+=(const OpInfoTy &RHS) {
333 NumStores += RHS.NumStores;
334 NumLoads += RHS.NumLoads;
335 NumComputeOps += RHS.NumComputeOps;
336 NumExposedTransposes += RHS.NumExposedTransposes;
337 return *this;
338 }
339 };
340
  /// Wrapper class representing a matrix as a set of vectors, either in row
  /// or column major layout. All vectors must have the same vector type.
  class MatrixTy {
    SmallVector<Value *, 16> Vectors;

    OpInfoTy OpInfo;

    bool IsColumnMajor = true;

  public:
    MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(ArrayRef<Value *> Vectors)
        : Vectors(Vectors.begin(), Vectors.end()),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
        : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {

      unsigned D = isColumnMajor() ? NumColumns : NumRows;
      for (unsigned J = 0; J < D; ++J)
        addVector(ConstantAggregateZero::get(FixedVectorType::get(
            EltTy, isColumnMajor() ? NumRows : NumColumns)));
    }
363
364 Value *getVector(unsigned i) const { return Vectors[i]; }
365 Value *getColumn(unsigned i) const {
366 assert(isColumnMajor() && "only supported for column-major matrixes");
367 return Vectors[i];
368 }
369 Value *getRow(unsigned i) const {
370 assert(!isColumnMajor() && "only supported for row-major matrixes");
371 return Vectors[i];
372 }
373
374 void setVector(unsigned i, Value *V) { Vectors[i] = V; }
375
376 Type *getElementType() const { return getVectorTy()->getElementType(); }
377
378 unsigned getNumVectors() const {
379 if (isColumnMajor())
380 return getNumColumns();
381 return getNumRows();
382 }
383
    unsigned getNumColumns() const {
      if (isColumnMajor())
        return Vectors.size();
      else {
        assert(Vectors.size() > 0 &&
               "Cannot call getNumColumns without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      }
    }
    unsigned getNumRows() const {
      if (isColumnMajor()) {
        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      } else
        return Vectors.size();
    }
    void addVector(Value *V) { Vectors.push_back(V); }
    VectorType *getColumnTy() {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return getVectorTy();
    }

    VectorType *getVectorTy() const {
      return cast<VectorType>(Vectors[0]->getType());
    }

    iterator_range<SmallVector<Value *, 16>::iterator> columns() {
      assert(isColumnMajor() &&
             "columns() only supported for column-major matrixes");
      return make_range(Vectors.begin(), Vectors.end());
    }

    iterator_range<SmallVector<Value *, 16>::const_iterator> vectors() const {
      return make_range(Vectors.begin(), Vectors.end());
    }

    /// Embed the vectors of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Vectors.size() == 1 ? Vectors[0]
                                 : concatenateVectors(Builder, Vectors);
    }
426
427 MatrixTy &addNumLoads(unsigned N) {
428 OpInfo.NumLoads += N;
429 return *this;
430 }
431
432 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
433
434 MatrixTy &addNumStores(unsigned N) {
435 OpInfo.NumStores += N;
436 return *this;
437 }
438
439 MatrixTy &addNumExposedTransposes(unsigned N) {
440 OpInfo.NumExposedTransposes += N;
441 return *this;
442 }
443
444 MatrixTy &addNumComputeOps(unsigned N) {
445 OpInfo.NumComputeOps += N;
446 return *this;
447 }
448
449 unsigned getNumStores() const { return OpInfo.NumStores; }
450 unsigned getNumLoads() const { return OpInfo.NumLoads; }
451 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
452
453 const OpInfoTy &getOpInfo() const { return OpInfo; }
454
455 bool isColumnMajor() const { return IsColumnMajor; }
456
457 unsigned getStride() const {
458 if (isColumnMajor())
459 return getNumRows();
460 return getNumColumns();
461 }
462
    /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
    /// matrix is column-major, the result vector is extracted from a column
    /// vector, otherwise from a row vector.
    Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
                         IRBuilder<> &Builder) const {
      Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
      assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
                 NumElts &&
             "Extracted vector will contain poison values");
      return Builder.CreateShuffleVector(
          Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
          "block");
    }
476 };
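  // Illustrative sketch (not from the original source): a column-major 2x3
  // MatrixTy holds three <2 x double> values; embedInVector concatenates them
  // back into the flat <6 x double> the matrix intrinsics operate on.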
477
  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering. This matches the shape of
  /// the result value of the instruction, with the only exceptions being store
  /// instructions and the matrix_column_major_store intrinsics. For those, the
  /// shape information indicates that those instructions should be lowered
  /// using shape information as well. Note that extra care is needed when
  /// erasing or RAUW'ing a value that is present in ShapeMap. If the
  /// replacement is also a matrix operation, use
  /// updateShapeAndReplaceAllUsesWith to make sure the replacement is added to
  /// ShapeMap. When directly erasing a value with an entry in ShapeMap, use
  /// eraseFromParentAndRemoveFromShapeMap to make sure ShapeMap is updated
  /// accordingly.
  DenseMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove. While lowering, we are not replacing all
  /// users of a lowered instruction, if shape information is available and
  /// those need to be removed after we finished lowering.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
private:
  static FastMathFlags getFastMathFlags(Instruction *Inst) {
    FastMathFlags FMF;

    if (isa<FPMathOperator>(*Inst))
      FMF = Inst->getFastMathFlags();

    FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());

    return FMF;
  }
513
public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
                        FunctionAnalysisManager *AM)
      : Func(F), DL(F.getDataLayout()), TTI(TTI), AM(AM) {}

  unsigned getNumOps(Type *VT) {
    assert(isa<FixedVectorType>(VT) && "Expected vector type");
    return getNumOps(VT->getScalarType(),
                     cast<FixedVectorType>(VT)->getNumElements());
  }
524
  /// Is this the minimal version executed in the backend pipelines.
  bool isMinimal() const {
    return !DT;
  }

  /// Return the estimated number of vector ops required for an operation on
  /// \p VT * N.
  unsigned getNumOps(Type *ST, unsigned N) {
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
                     double(TTI.getRegisterBitWidth(
                                   TargetTransformInfo::RGK_FixedWidthVector)
                                .getFixedValue()));
  }
538
  /// Return the set of vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cached result matrix.
  /// Otherwise split the flat vector \p MatrixVal containing a matrix with
  /// shape \p SI into vectors.
  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                     IRBuilder<> &Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(cast<FixedVectorType>(VType)->getNumElements() ==
               SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information. In that case,
    // return the existing matrix, if it matches the requested shape
    // information. If there is a mis-match, embed the result in a flat
    // vector and split it later.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      MatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information.
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    for (unsigned MaskStart = 0;
         MaskStart < cast<FixedVectorType>(VType)->getNumElements();
         MaskStart += SI.getStride()) {
      Value *V = Builder.CreateShuffleVector(
          MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
          "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }
580
  /// If \p V is a matrix value, set its shape information to \p Shape and
  /// return true. Returns false if shape information could not be set.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;
586 return false;
587
588 auto SIter = ShapeMap.find(V);
589 if (SIter != ShapeMap.end()) {
590 if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
591 SIter->second.NumColumns != Shape.NumColumns)) {
592 errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
593 << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
594 << Shape.NumColumns << ") for " << *V << "\n";
596 "Matrix shape verification failed, compilation aborted!");
597 }
598
599 LLVM_DEBUG(dbgs() << " not overriding existing shape: "
600 << SIter->second.NumRows << " "
601 << SIter->second.NumColumns << " for " << *V << "\n");
602 return false;
603 }
604
605 ShapeMap.insert({V, Shape});
606 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
607 << " for " << *V << "\n");
608 return true;
609 }
610
  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this
  /// pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
    if (II)
620 switch (II->getIntrinsicID()) {
621 case Intrinsic::matrix_multiply:
622 case Intrinsic::matrix_transpose:
623 case Intrinsic::matrix_column_major_load:
624 case Intrinsic::matrix_column_major_store:
625 return true;
626 default:
627 return false;
628 }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
630 }
631
  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we guaranteed to have at least one of the
    // operand shapes. Add the shape for this and then add users to the work
    // list.
    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.pop_back_val();

      // New entry, set the value and insert operands.
      bool Propagate = false;
      if (auto SI = computeShapeInfoForInst(Inst, ShapeMap))
        Propagate = setShapeInfo(Inst, *SI);

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (isa<Instruction>(User) && !isa<PHINode>(User))
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }
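  // Illustrative sketch (not from the original source): starting from a
  // multiply with static shapes, forward propagation tags its users:
  //   %m = call <4 x float> @llvm.matrix.multiply...  ; shape 2x2 (seed)
  //   %s = fadd <4 x float> %m, %m                    ; uniform op, gets 2x2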
661
  /// Propagate the shape to operands of instructions with shape information.
  /// \p WorkList contains the instructions for which we already know the
  /// shape.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with known shape. Traverse the operands, if their shape
    // derives from the result shape and is unknown, add it and add them to
    // the worklist.
    LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Value *V = WorkList.pop_back_val();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);

      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                           m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
                           m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V,
                       m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this, so we would just
        // propagate the shape information back.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }
734
  /// Apply \p Operation to transposed versions of \p Op0 (with shape
  /// \p Shape0) and \p Op1 (with shape \p Shape1), adding shape information
  /// for the newly created transposes.
  Instruction *distributeTransposes(
      Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
      MatrixBuilder &Builder,
      function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
          Operation) {
    Value *T0 = Builder.CreateMatrixTranspose(
        Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
    // We are being run after shape prop, add shape for newly created
    // instructions so that we lower them later.
    setShapeInfo(T0, Shape0.t());
    Value *T1 = Builder.CreateMatrixTranspose(
        Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
    setShapeInfo(T1, Shape1.t());
    return Operation(T0, Shape0.t(), T1, Shape1.t());
  }
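  // Illustrative sketch (not from the original source): sinking a transpose
  // through a multiply relies on (A*B)^T == B^T * A^T. For A of shape RxK and
  // B of shape KxC, the rewritten multiply B^T * A^T has shapes CxK and KxR
  // and produces the CxR result directly, so the outer transpose disappears.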
753
  /// Erase \p Inst from both ShapeMap (if an entry exists) and erase \p Inst
  /// itself.
  void eraseFromParentAndRemoveFromShapeMap(Instruction *Inst) {
    auto Iter = ShapeMap.find(Inst);
    if (Iter != ShapeMap.end())
      ShapeMap.erase(Iter);
    Inst->eraseFromParent();
  }
762
  /// Erase \p V from its parent block, moving the iterator \p II forward if
  /// needed to avoid invalidating it.
  void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
                              BasicBlock &BB) {
    auto *Inst = cast<Instruction>(V);
    // Still used, don't erase.
    if (!Inst->use_empty())
      return;
    if (II != BB.rend() && Inst == &*II)
      ++II;
    eraseFromParentAndRemoveFromShapeMap(Inst);
  }
775
776
777
778 void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
779
780
781
782 auto S = ShapeMap.find(&Old);
783 if (S != ShapeMap.end()) {
784 ShapeMap.erase(S);
785 if (supportsShapeInfo(New))
786 ShapeMap.insert({New, S->second});
    }
    Old.replaceAllUsesWith(New);
  }
790
  /// Sink a top-level transpose inside matmuls and adds.
  /// This creates and erases instructions as needed, and returns the newly
  /// created instruction while updating the iterator to avoid invalidation. If
  /// this returns nullptr, no new instruction was created.
  Instruction *sinkTranspose(Instruction &I,
                             BasicBlock::reverse_iterator &II) {
    BasicBlock &BB = *I.getParent();
    IRBuilder<> IB(&I);
    MatrixBuilder Builder(IB);

    Value *TA, *TAMA, *TAMB;
    ConstantInt *R, *K, *C;
    if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
                       m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
      return nullptr;

    // Transpose of a transpose is a nop.
    Value *TATA;
    if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
      updateShapeAndReplaceAllUsesWith(I, TATA);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return nullptr;
    }

    // k^T -> k
    if (isSplat(TA)) {
      updateShapeAndReplaceAllUsesWith(I, TA);
      eraseFromParentAndMove(&I, II, BB);
      return nullptr;
    }
821
    // (A * B)^t -> B^t * A^t
    // RxK KxC      CxK   KxR   CxR
    if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
                      m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
                      m_ConstantInt(K), m_ConstantInt(C)))) {
      auto NewInst = distributeTransposes(
          TAMB, {K, C}, TAMA, {R, K}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
                                                Shape0.NumColumns,
                                                Shape1.NumColumns, "mmul");
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }
839
    // Same as above, but with a mul, which occurs when multiplied
    // with a scalar.
    // (A * k)^t -> A^t * k
    //  R  x  C      RxC
    if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
        (isSplat(TAMA) || isSplat(TAMB))) {
      IRBuilder<> LocalBuilder(&I);
      // We know that the transposed operand is of shape RxC.
      // When multiplied with a scalar, the shape is preserved.
      auto NewInst = distributeTransposes(
          TAMA, {R, C}, TAMB, {R, C}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            bool IsFP = I.getType()->isFPOrFPVectorTy();
            auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
                             : LocalBuilder.CreateMul(T0, T1, "mmul");
            auto *Result = cast<Instruction>(Mul);
            setShapeInfo(Result, Shape0);
            return Result;
          });
859 updateShapeAndReplaceAllUsesWith(I, NewInst);
860 eraseFromParentAndMove(&I, II, BB);
861 eraseFromParentAndMove(TA, II, BB);
862 return NewInst;
863 }
864
    // (A + B)^t -> A^t + B^t
    // RxC RxC      CxR   CxR
    if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
      IRBuilder<> LocalBuilder(&I);
      auto NewInst = distributeTransposes(
          TAMA, {R, C}, TAMB, {R, C}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            bool IsFP = I.getType()->isFPOrFPVectorTy();
            auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
                             : LocalBuilder.CreateAdd(T0, T1, "madd");

            auto *Result = cast<Instruction>(Add);
            setShapeInfo(Result, Shape0);
            return Result;
          });
880 updateShapeAndReplaceAllUsesWith(I, NewInst);
881 eraseFromParentAndMove(&I, II, BB);
882 eraseFromParentAndMove(TA, II, BB);
883 return NewInst;
884 }
885
886 return nullptr;
887 }
888
  /// Lift a transpose from operands to the result of a binary operation.
  void liftTranspose(Instruction &I) {
    // Erase dead Instructions after lifting transposes from binops.
    auto CleanupBinOp = [this](Instruction &T, Value *A, Value *B) {
      if (T.use_empty())
        eraseFromParentAndRemoveFromShapeMap(&T);
      if (A->use_empty())
        eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(A));
      if (A != B && B->use_empty())
        eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(B));
    };

    Value *A, *B, *AT, *BT;
    ConstantInt *R, *K, *C;
    // A^t * B^t -> (B * A)^t
    if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
                      m_Value(A), m_Value(B), m_ConstantInt(R),
                      m_ConstantInt(K), m_ConstantInt(C))) &&
        match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
        match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(BT)))) {
      IRBuilder<> IB(&I);
      MatrixBuilder Builder(IB);
      Value *M = Builder.CreateMatrixMultiply(
          BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
      setShapeInfo(M, {C, R});
      Instruction *NewInst = Builder.CreateMatrixTranspose(
          M, C->getZExtValue(), R->getZExtValue());
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      CleanupBinOp(I, A, B);
    }
918
    // A^t + B^t -> (A + B)^t. Pick rows and columns from the first transpose.
    // If the shape of the second transpose is different, there's a shape
    // conflict which gets resolved by picking the shape of the first operand.
    else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
             match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
                          m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
             match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
                          m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) {
      IRBuilder<> Builder(&I);
      auto *Add = Builder.CreateFAdd(AT, BT, "mfadd");
      MatrixBuilder MBuilder(Builder);
929 Instruction *NewInst = MBuilder.CreateMatrixTranspose(
930 Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");
931 updateShapeAndReplaceAllUsesWith(I, NewInst);
932 assert(computeShapeInfoForInst(NewInst, ShapeMap) ==
933 computeShapeInfoForInst(&I, ShapeMap) &&
934 "Shape of new instruction doesn't match original shape.");
      CleanupBinOp(I, A, B);
      if (auto *AddI = dyn_cast<Instruction>(Add)) {
        setShapeInfo(AddI, {R, C});
        assert(
            computeShapeInfoForInst(AddI, ShapeMap).value_or(ShapeMap[AddI]) ==
                ShapeMap[AddI] &&
            "Shape of updated addition doesn't match cached shape.");
942 }
943 }
944 }
945
946
  /// Try moving transposes in order to fold them away or into multiplies.
  void optimizeTransposes() {
    // First sink all transposes inside matmuls and adds, hoping that we end
    // up with NN, NT or TN variants.
    for (BasicBlock &BB : reverse(Func)) {
      for (auto II = BB.rbegin(); II != BB.rend();) {
        Instruction &I = *II;
        // We may remove II. By default continue on the next/prev instruction.
        ++II;
        if (Instruction *NewInst = sinkTranspose(I, II))
          II = std::next(BasicBlock::reverse_iterator(NewInst));
      }
    }

    // If we have a TT matmul or a TT add, lift the transpose. We may be able
    // to fold into a consuming multiply or add.
    for (BasicBlock &BB : Func) {
      for (Instruction &I : llvm::make_early_inc_range(BB)) {
        liftTranspose(I);
      }
    }
  }
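  // Illustrative sketch (not from the original source): after sinking,
  // t(t(A)) pairs cancel, and lifting rewrites t(A) * t(B) into t(B * A), so
  // at most one transpose per expression tree remains exposed.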
968
  bool Visit() {
    SmallVector<Instruction *, 32> WorkList;

    // Initially only the shape of matrix intrinsics is known.
    // Initialize the work list with ops carrying shape information.
    for (BasicBlock &BB : Func)
      for (Instruction &Inst : BB) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
        if (!II)
          continue;

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
        case Intrinsic::matrix_transpose:
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          WorkList.push_back(&Inst);
          break;
        default:
          break;
        }
      }

    // Avoid unnecessary work if there are no matrix intrinsics in the
    // function.
    if (WorkList.empty())
      return false;

    if (AM) {
      ORE = &AM->getResult<OptimizationRemarkEmitterAnalysis>(Func);
      AA = &AM->getResult<AAManager>(Func);
      DT = &AM->getResult<DominatorTreeAnalysis>(Func);
      LI = &AM->getResult<LoopAnalysis>(Func);
    }

    // Propagate shapes until nothing changes any longer.
    while (!WorkList.empty()) {
      WorkList = propagateShapeForward(WorkList);
      WorkList = propagateShapeBackward(WorkList);
    }

    if (!isMinimal()) {
      optimizeTransposes();
      if (PrintAfterTransposeOpt) {
        dbgs() << "Dump after matrix transpose optimization:\n";
        Func.print(dbgs());
      }
    }
1016
    bool Changed = false;
    SmallVector<CallInst *, 16> MaybeFusableInsts;
    SmallVector<Instruction *, 16> MatrixInsts;
    SmallVector<IntrinsicInst *, 16> LifetimeEnds;

    // First, collect all instructions with shape information and candidates
    // for fusion (currently only matrix multiplies).
    ReversePostOrderTraversal<Function *> RPOT(&Func);
    for (auto *BB : RPOT)
      for (Instruction &I : *BB) {
        if (match(&I, m_Intrinsic<Intrinsic::lifetime_end>()))
          LifetimeEnds.push_back(cast<IntrinsicInst>(&I));
        if (ShapeMap.find(&I) == ShapeMap.end())
          continue;
        if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
          MaybeFusableInsts.push_back(cast<CallInst>(&I));
        MatrixInsts.push_back(&I);
      }

    // Second, try to lower any dot products.
    SmallPtrSet<Instruction *, 16> FusedInsts;
1038 for (CallInst *CI : MaybeFusableInsts)
1039 lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));
1040
1041
1042 for (CallInst *CI : MaybeFusableInsts)
1043 if (!FusedInsts.contains(CI))
1044 LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
1045
1046 Changed = !FusedInsts.empty();
    // Fourth, lower remaining instructions with shape information.
    for (Instruction *Inst : MatrixInsts) {
      if (FusedInsts.count(Inst))
        continue;

      IRBuilder<> Builder(Inst);

      if (CallInst *CInst = dyn_cast<CallInst>(Inst))
        Changed |= VisitCallInst(CInst);

      Value *Op1;
      Value *Op2;
      if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
        Changed |= VisitBinaryOperator(BinOp);
      if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
        Changed |= VisitUnaryOperator(UnOp);
      if (match(Inst, m_Load(m_Value(Op1))))
        Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
      else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
        Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
    }
1069
1070 if (ORE) {
1071 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
1072 RemarkGen.emitRemarks();
1073 }
    // Delete the instructions backwards, as it has a reduced likelihood of
    // having to update as many def-use and use-def chains.
    //
    // Because we add to ToRemove during fusion we can't guarantee that defs
    // are before uses. Change uses to poison temporarily as these should get
    // removed as well.
    //
    // For verification, we keep track of where we changed uses to poison in
    // PoisonedInsts and then check that we in fact remove them.
    SmallSet<Instruction *, 16> PoisonedInsts;
    for (auto *Inst : reverse(ToRemove)) {
      for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
        if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
          PoisonedInsts.insert(Poisoned);
        U.set(PoisonValue::get(Inst->getType()));
      }
      Inst->eraseFromParent();
      PoisonedInsts.erase(Inst);
    }
    if (!PoisonedInsts.empty()) {
      // If we didn't remove all poisoned instructions, it's a hard error.
      dbgs() << "Poisoned but present instructions:\n";
      for (auto *I : PoisonedInsts)
        dbgs() << *I << "\n";
      llvm_unreachable("Poisoned but instruction not removed");
    }
1101
1102 return Changed;
1103 }
1104
  /// Replace intrinsic calls.
  bool VisitCallInst(CallInst *Inst) {
    if (!Inst->getCalledFunction() ||
        !Inst->getCalledFunction()->isIntrinsic())
      return false;

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
1112 LowerMultiply(Inst);
1113 break;
1114 case Intrinsic::matrix_transpose:
1115 LowerTranspose(Inst);
1116 break;
1117 case Intrinsic::matrix_column_major_load:
1118 LowerColumnMajorLoad(Inst);
1119 break;
1120 case Intrinsic::matrix_column_major_store:
1121 LowerColumnMajorStore(Inst);
1122 break;
1123 default:
1124 return false;
1125 }
1126 return true;
1127 }
1128
  /// Compute the alignment for a column/row \p Idx with \p Stride between
  /// them. The address at \p Idx == 0 has alignment \p A. If \p Stride is a
  /// ConstantInt, reduce the initial alignment based on the byte offset. For
  /// non-ConstantInt strides, return the common alignment of the initial
  /// alignment and the element size in bytes.
  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
                         MaybeAlign A) const {
    Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
    if (Idx == 0)
      return InitialAlign;

    TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
    if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
      uint64_t StrideInBytes =
          ConstStride->getZExtValue() * ElementSizeInBits / 8;
      return commonAlignment(InitialAlign, Idx * StrideInBytes);
    }
    return commonAlignment(InitialAlign, ElementSizeInBits / 8);
  }
1148
  /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride
  /// between vectors.
  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
                      bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
    auto *VType = cast<VectorType>(Ty);
    Type *EltTy = VType->getElementType();
    Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
    Value *EltPtr = Ptr;
    MatrixTy Result;
    for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
      Value *GEP = computeVectorAddr(
          EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
          Stride, Shape.getStride(), EltTy, Builder);
      Value *Vector = Builder.CreateAlignedLoad(
          VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
          IsVolatile, "col.load");

      Result.addVector(Vector);
    }
    return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
                              Result.getNumVectors());
  }
1171
  /// Loads a sub-matrix with shape \p ResultShape from a matrix with shape
  /// \p MatrixShape, starting at \p MatrixPtr[I][J].
  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
                      ShapeInfo MatrixShape, Value *I, Value *J,
                      ShapeInfo ResultShape, Type *EltTy,
                      IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
                                                   ResultShape.NumColumns);

    return loadMatrix(TileTy, TileStart, Align,
                      Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                      ResultShape, Builder);
  }
1190
  /// Lower a load instruction with shape information.
  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align,
                 Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    finalizeLowering(Inst,
                     loadMatrix(Inst->getType(), Ptr, Align, Stride,
                                IsVolatile, Shape, Builder),
                     Builder);
  }
1200
  /// Lowers llvm.matrix.column.major.load.
  ///
  /// The intrinsic loads a matrix from memory using a stride between columns.
  void LowerColumnMajorLoad(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
              cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
              {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }
1213
  /// Stores a sub-matrix \p StoreVal into the matrix with shape
  /// \p MatrixShape, starting at \p MatrixPtr[I][J].
  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
                   MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
                   Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
                                                   StoreVal.getNumColumns());

    storeMatrix(TileTy, StoreVal, TileStart, MAlign,
                Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                Builder);
  }
1229
  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
  /// vectors.
  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
                       MaybeAlign MAlign, Value *Stride, bool IsVolatile,
                       IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    Value *EltPtr = Ptr;
    for (auto Vec : enumerate(StoreVal.vectors())) {
      Value *GEP = computeVectorAddr(
          EltPtr,
          Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
                          Vec.index()),
          Stride, StoreVal.getStride(), VType->getElementType(), Builder);
      Builder.CreateAlignedStore(Vec.value(), GEP,
                                 getAlignForIndex(Vec.index(), Stride,
                                                  VType->getElementType(),
                                                  MAlign),
                                 IsVolatile);
    }
    return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
                                   StoreVal.getNumVectors());
  }
1252
  /// Lower a store instruction with shape information.
  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto StoreVal = getMatrix(Matrix, Shape, Builder);
    finalizeLowering(Inst,
                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
                                 IsVolatile, Builder),
                     Builder);
  }
1263
  /// Lowers llvm.matrix.column.major.store.
  ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
  void LowerColumnMajorStore(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
               cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
               {Inst->getArgOperand(4), Inst->getArgOperand(5)});
  }
1277
  // Set elements I..I+NumElts-1 of Col to Block.
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> &Builder) {
    // First, bring Block to the same size as Col.
    unsigned BlockNumElts =
        cast<FixedVectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Block = Builder.CreateShuffleVector(
        Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));

    // Build a mask selecting the prefix of Col, then the Block elements
    // (which now live at indices >= NumElts in the two-vector shuffle), then
    // the remaining suffix of Col.
    SmallVector<int, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(i);

    unsigned VecNumElts =
        cast<FixedVectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(i - I + VecNumElts);

    for (; i < VecNumElts; i++)
      Mask.push_back(i);

    return Builder.CreateShuffleVector(Col, Block, Mask);
  }
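  // Illustrative sketch (not from the original source): with Col of length 7,
  // I == 2 and Block of length 2, the mask built above is
  // <0, 1, 7, 8, 4, 5, 6>; indices >= 7 select the (widened) Block elements.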
1308
  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction,
                      unsigned &NumComputeOps) {
    NumComputeOps += getNumOps(A->getType());
    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend
        // decide on lowering it to fused multiply-adds.
        return Builder.CreateIntrinsic(Intrinsic::fmuladd, A->getType(),
                                       {A, B, Sum});
      }
      NumComputeOps += getNumOps(A->getType());
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    NumComputeOps += getNumOps(A->getType());
    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }
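  // Illustrative sketch (not from the original source): with contraction
  // allowed, the K-loop accumulation becomes a chain of
  //   %sum = call <4 x double> @llvm.fmuladd.v4f64(%a, %splat, %sum)
  // otherwise a separate fmul/fadd pair is emitted per step.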
1332
  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
  /// users with shape information, there's nothing to do as they will use the
  /// cached value when they are lowered. For other users, \p Matrix is
  /// flattened and the uses are updated to use it. Also marks \p Inst for
  /// deletion.
  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
                        IRBuilder<> &Builder) {
    auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
    (void)inserted;
    assert(inserted.second && "multiple matrix lowering mapping");

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }
1354
  /// Special case for MatMul lowering. Prevents scalar loads of row-major
  /// matrices from being lowered to vector loads followed by shuffles, by
  /// lowering 1xN * Nx1 multiplies to a horizontal reduction.
  void lowerDotProduct(CallInst *MatMul,
                       SmallPtrSet<Instruction *, 16> &FusedInsts,
                       FastMathFlags FMF) {
    if (FusedInsts.contains(MatMul) ||
        MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
      return;
1369
    Value *LHS = MatMul->getArgOperand(0);
    Value *RHS = MatMul->getArgOperand(1);
    Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
    bool IsIntVec = ElementType->isIntegerTy();

    // Floating point reductions require reassocation.
    if (!IsIntVec && !FMF.allowReassoc())
      return;
1379
    auto CanBeFlattened = [](Value *Op) {
      if (match(Op, m_BinOp()))
        return true;
      return match(
          Op, m_OneUse(m_CombineOr(
                  m_Load(m_Value()),
                  m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
                              m_Intrinsic<Intrinsic::matrix_column_major_load>(
                                  m_Value(), m_SpecificInt(1))))));
    };
1390
1391
1392
1393 auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
      if (ShapeMap.find(Op) == ShapeMap.end())
        return InstructionCost::getInvalid();

      if (!isa<Instruction>(Op))
        return InstructionCost(0);
1402
1403 if (!CanBeFlattened(Op)) {
1405
1406 for (unsigned I = 1; I < N; ++I)
1407 EmbedCost +=
1410 return EmbedCost;
1411 }
1412
1416 EltTy) *
1417 N;
1419 cast(Op)->getOpcode(), VecTy);
1420 return NewCost - OriginalCost;
1421 }
1422
1423 if (match(Op, m_IntrinsicIntrinsic::matrix\_transpose())) {
1424
1425
1426
1428 for (unsigned I = 1; I < N; ++I)
1429 EmbedCost -=
1432 return EmbedCost;
1433 }
1434
1435
1436 if (N == 1)
1438
1441 };
1442
    // Iterate over LHS and operations feeding LHS and check if it is
    // profitable to flatten the visited ops. For each op, the difference
    // between the flattened and matrix versions is added to LHSCost.
    SmallPtrSet<Value *, 4> Seen;
    SmallVector<Value *> WorkList;
    SmallVector<Value *> ToFlatten;
    WorkList.push_back(LHS);
    InstructionCost LHSCost(0);
    while (!WorkList.empty()) {
      Value *Op = WorkList.pop_back_val();
      if (!Seen.insert(Op).second)
        continue;

      InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns);
      // Only flatten if the cost is beneficial; an invalid cost keeps the sum
      // invalid and fails this check as well.
      if (OpCost + LHSCost >= LHSCost)
        continue;

      LHSCost += OpCost;
      ToFlatten.push_back(Op);
      if (auto *I = dyn_cast<Instruction>(Op))
        WorkList.append(I->op_begin(), I->op_end());
    }
1465
1466
1467 int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
1468 int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
1471 AddOpCode, cast(LHS->getType()),
1472 IsIntVec ? std::nullopt : std::optional(FMF)) +
1476 (LShape.NumColumns - 1) +
1478 (LShape.NumColumns);
1479 if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))
1480 return;
1481
1482 FusedInsts.insert(MatMul);
1484 auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
1486
1487
1488
1489 if (!CanBeFlattened(Op))
1490 return;
1491
1493 auto It = ShapeMap.find(Op);
1494 if (It != ShapeMap.end()) {
1495 It->second = It->second.t();
1496 return;
1497 }
1498 }
1499
1500 FusedInsts.insert(cast(Op));
1501
1503 if (match(Op, m_IntrinsicIntrinsic::matrix\_column\_major\_load(
1505 auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
1506 Op->replaceAllUsesWith(NewLoad);
1507 eraseFromParentAndRemoveFromShapeMap(cast(Op));
1508 return;
1509 } else if (match(Op, m_IntrinsicIntrinsic::matrix\_transpose(
1511 ToRemove.push_back(cast(Op));
1512 Op->replaceAllUsesWith(Arg);
1513 return;
1514 }
1515 };
1516
1517 for (auto *V : ToFlatten)
1518 FlattenArg(V);
1519
    // Insert mul/fmul and llvm.vector.reduce.fadd.
    Value *Mul =
        IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS);

    Value *Result;
    if (IsIntVec)
      Result = Builder.CreateAddReduce(Mul);
    else {
      Result = Builder.CreateFAddReduce(
          ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
                          0.0),
          Mul);
      cast<Instruction>(Result)->setFastMathFlags(FMF);
    }

    // Pack the scalar back into a matrix and then replace the matmul inst.
    Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()),
                                         Result, uint64_t(0));
    MatMul->replaceAllUsesWith(Result);
    FusedInsts.insert(MatMul);
    ToRemove.push_back(MatMul);
  }
1544
1545
1546
1547
1548
1549
1550
1551
1552 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
1553 const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
1554 bool IsScalarMatrixTransposed, FastMathFlags FMF) {
    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue() /
            Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
        1U);
1560 unsigned R = Result.getNumRows();
1561 unsigned C = Result.getNumColumns();
1562 unsigned M = A.getNumColumns();
1563
1564 bool IsFP = Result.getElementType()->isFloatingPointTy();
1565 assert(A.isColumnMajor() == B.isColumnMajor() &&
1566 Result.isColumnMajor() == A.isColumnMajor() &&
1567 "operands must agree on matrix layout");
1568 unsigned NumComputeOps = 0;
1569
    Builder.setFastMathFlags(FMF);

    if (A.isColumnMajor()) {
      // Multiply columns from the first operand with scalars from the second
      // operand. Then move along the K axes and accumulate the columns. With
      // this the adds can be vectorized without reassociation.
      for (unsigned J = 0; J < C; ++J) {
        unsigned BlockSize = VF;
        // If Result is zero, we don't need to accumulate in the K == 0
        // iteration.
        bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));

        for (unsigned I = 0; I < R; I += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (I + BlockSize > R)
            BlockSize /= 2;

          Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
                               : nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *L = A.extractVector(I, K, BlockSize, Builder);
            Value *RH = Builder.CreateExtractElement(
                B.getColumn(IsScalarMatrixTransposed ? K : J),
                IsScalarMatrixTransposed ? J : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
                             IsFP, Builder, FMF.allowContract(), NumComputeOps);
          }
          Result.setVector(J,
                           insertVector(Result.getVector(J), I, Sum, Builder));
        }
      }
1602 } else {
1603
1604
1605
1606 for (unsigned I = 0; I < R; ++I) {
1608 bool isSumZero = isa(Result.getRow(I));
1609 for (unsigned J = 0; J < C; J += BlockSize) {
1610
1613
1614 Value *Sum = nullptr;
1615 for (unsigned K = 0; K < M; ++K) {
1618 A.getVector(IsScalarMatrixTransposed ? K : I),
1619 IsScalarMatrixTransposed ? I : K);
1621 Sum =
1622 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
1623 IsFP, Builder, FMF.allowContract(), NumComputeOps);
1624 }
1627 }
1628 }
1629 }
1630 Result.addNumComputeOps(NumComputeOps);
1631 }
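  // Illustrative sketch (not from the original source): for a 2x2 * 2x2
  // column-major multiply with VF >= 2, each result column is computed as
  //   col(C, j) = a.col(0) * splat(b[0][j]) + a.col(1) * splat(b[1][j])
  // i.e. one vector multiply-add per K step, without reassociation.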
1632
1633
1634
1635
1640
1641
1642 if (AA->isNoAlias(LoadLoc, StoreLoc))
1643 return Load->getPointerOperand();
1644
1645
1646
1647
1648
1650
1651
1652
1656
1659 nullptr, "alias_cont");
1662 nullptr, "copy");
1665 nullptr, "no_alias");
1666
1667
1668
1669
1675 const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
1677 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
1678 "store.end", true, true);
1680 IntPtrTy, "load.begin");
1682 Fusion);
1683
1684
1685
1686
1690 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
1691 "load.end", true, true);
1693 Fusion);
1694
1695
1697 auto *VT = cast(Load->getType());
1698
1699
1700 auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
1702 Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());
1703
1708 PHI->addIncoming(Load->getPointerOperand(), Check0);
1709 PHI->addIncoming(Load->getPointerOperand(), Check1);
1710 PHI->addIncoming(Alloca, Copy);
1711
1712
1718 return PHI;
1719 }
1720
  bool isFusionProfitable(CallInst *MatMul) {
    if (ForceFusion)
      return true;

    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    const unsigned M = LShape.NumColumns;
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue() /
            EltType->getPrimitiveSizeInBits().getFixedValue(),
        1U);
1738
1739
1740
1741
1742
1743
1744
1745 if (R <= VF && C == 1)
1746 return false;
1747
1748
1749
1750
1751 unsigned Op0Regs = (R + VF - 1) / VF * M;
1752 unsigned Op1Regs = (M + VF - 1) / VF * C;
    return Op0Regs + Op1Regs >
           TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
  }
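  // Illustrative sketch (not from the original source): for a 6x6 * 6x6
  // double multiply with 256-bit vectors (VF == 4), Op0Regs == Op1Regs ==
  // ceil(6/4) * 6 == 12, so 24 registers are needed and tiling is chosen on
  // targets with 16 vector registers.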
1756
  MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
    MatrixTy Res;
    auto *ColumnType = FixedVectorType::get(EltType, R);
    for (unsigned I = 0; I < C; ++I)
      Res.addVector(ConstantAggregateZero::get(ColumnType));
    return Res;
  }
1764
1765 void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
1767 auto *EltType = cast(MatMul->getType())->getElementType();
1768
1769
1770 TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
1771 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1772 Instruction *InsertI = cast(MatMul);
1777 BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);
1778
1779 Type *TileVecTy =
1781 MatrixTy TileResult;
1782
1783 Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
1784
1786 for (unsigned I = 0; I < TileSize; I++) {
1787 auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
1789 TI.RowLoop.Header->getSingleSuccessor());
1790 TileResult.addVector(Phi);
1792 }
1793
1794
1795
1797
1798 MatrixTy A =
1799 loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
1801 MatrixTy B =
1802 loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
1804 emitMatrixMultiply(TileResult, A, B, Builder, true, false,
1805 getFastMathFlags(MatMul));
1806
1807 Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
1808 storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
1809 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
1810 TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);
1811
1812 for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
1813 ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);
1814
1815
1816
1817
1818
1819 unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
1821 "llvm.loop.unroll.count", InnerLoopUnrollCount);
1822 }
1823
1828 "Tiling only supported for column-major matrixes at the moment!");
1829 if (!isFusionProfitable(MatMul))
1830 return;
1831
1834
1835 const unsigned R = LShape.NumRows;
1836 const unsigned C = RShape.NumColumns;
1837 const unsigned M = LShape.NumColumns;
1838 auto *EltType = cast(MatMul->getType())->getElementType();
1839
1840 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
1841 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
1842 Value *CPtr = Store->getPointerOperand();
1843
1845 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
1846 else {
1848 for (unsigned J = 0; J < C; J += TileSize)
1849 for (unsigned I = 0; I < R; I += TileSize) {
1850 const unsigned TileR = std::min(R - I, unsigned(TileSize));
1851 const unsigned TileC = std::min(C - J, unsigned(TileSize));
1852 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
1853
1854 for (unsigned K = 0; K < M; K += TileSize) {
1855 const unsigned TileM = std::min(M - K, unsigned(TileSize));
1856 MatrixTy A =
1859 {TileR, TileM}, EltType, Builder);
1860 MatrixTy B =
1863 {TileM, TileC}, EltType, Builder);
1864 emitMatrixMultiply(Res, A, B, Builder, true, false,
1865 getFastMathFlags(MatMul));
1866 }
1867 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
1869 Builder);
1870 }
1871 }
1872
1873
1874 FusedInsts.insert(Store);
1875 FusedInsts.insert(MatMul);
1876 eraseFromParentAndRemoveFromShapeMap(Store);
1877 eraseFromParentAndRemoveFromShapeMap(MatMul);
1879 FusedInsts.insert(LoadOp0);
1880 eraseFromParentAndRemoveFromShapeMap(LoadOp0);
1881 }
1882 if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
1883 FusedInsts.insert(LoadOp1);
1884 eraseFromParentAndRemoveFromShapeMap(LoadOp1);
1885 }
1886 }
1887
1888
1889
1890
1891
1892 void
1893 LowerMatrixMultiplyFused(CallInst *MatMul,
1897 return;
1898
1899 assert(AA && LI && "Analyses should be available");
1900
1903
1904
1906 if (MatrixLayout == MatrixLayoutTy::ColumnMajor
1907 ? match(B, m_IntrinsicIntrinsic::matrix\_transpose(m_Value(T)))
1908 : match(A, m_IntrinsicIntrinsic::matrix\_transpose(m_Value(T)))) {
1910 auto *EltType = cast(MatMul->getType())->getElementType();
1913 const unsigned R = LShape.NumRows;
1914 const unsigned M = LShape.NumColumns;
1915 const unsigned C = RShape.NumColumns;
1916
1917 MatrixTy MA;
1918 MatrixTy MB;
1919
1920 Value *Transpose;
1921 if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
1922 MA = getMatrix(A, ShapeInfo(R, M), Builder);
1923 MB = getMatrix(T, ShapeInfo(C, M), Builder);
1924 Transpose = B;
1925 } else {
1926 MA = getMatrix(T, ShapeInfo(R, M), Builder);
1927 MB = getMatrix(B, ShapeInfo(C, M), Builder);
1928 Transpose = A;
1929 }
1930
1931
1932 MatrixTy Result(R, C, EltType);
1933
1934 emitMatrixMultiply(Result, MA, MB, Builder, false, true,
1935 getFastMathFlags(MatMul));
1936
1937 FusedInsts.insert(MatMul);
1939 FusedInsts.insert(cast(Transpose));
1940 ToRemove.push_back(cast(Transpose));
1941
1942
1943 Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
1944 }
1945 finalizeLowering(MatMul, Result, Builder);
1946 return;
1947 }
1948
1950 return;
1951
1952
1953
1954 auto *LoadOp0 = dyn_cast(A);
1955 auto *LoadOp1 = dyn_cast(B);
1956 auto *Store = dyn_cast(*MatMul->user_begin());
1957 if (LoadOp0 && LoadOp1 && Store) {
1958
1959
1963 for (unsigned I = 0; I != WorkList.size(); ++I) {
1964 Value *Current = WorkList[I];
1965 auto *CurrI = dyn_cast(Current);
1966 if (!CurrI)
1967 continue;
1968 if (isa(CurrI))
1969 return;
1970 if (DT->dominates(CurrI, MatMul))
1971 continue;
1972 if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
1973 return;
1975 WorkList.insert(CurrI->op_begin(), CurrI->op_end());
1976 }
1977
1980 });
1982 I->moveBefore(MatMul);
1983
1984
1985
1986
1987
1988
1989
1990
1994 bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent &&
1995 LoadOp1->getParent() == StoreParent;
1996 for (unsigned Idx = 0; Idx != LifetimeEnds.size();) {
1999
2000
2002 continue;
2004 continue;
2005
2006
2007 if (FusableOpsInSameBlock && End->getParent() != StoreParent)
2008 continue;
2009
2010
2011
2013 if (!EndLoc.Ptr)
2014 continue;
2015 if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc))
2016 continue;
2017
2018
2019
2020
2021 if (End->getParent() == StoreParent) {
2022 End->moveAfter(Store);
2023 continue;
2024 }
2025
2026
2030 Inc.release();
2031 }
2032
2033 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
2034 return;
2035 }
2036 }
2037
  /// Lowers llvm.matrix.multiply.
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
2045 const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
2046 const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
2047 assert(Lhs.getElementType() == Rhs.getElementType() &&
2048 "Matrix multiply argument element types do not match.");
2049
2050 const unsigned R = LShape.NumRows;
2051 const unsigned C = RShape.NumColumns;
2052 assert(LShape.NumColumns == RShape.NumRows);
2053
2054
2055 MatrixTy Result(R, C, EltType);
2056 assert(Lhs.getElementType() == Result.getElementType() &&
2057 "Matrix multiply result element type does not match arguments.");
2058
2059 emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
2060 getFastMathFlags(MatMul));
2061 finalizeLowering(MatMul, Result, Builder);
2062 }
2063
  /// Lowers llvm.matrix.transpose.
  void LowerTranspose(CallInst *Inst) {
    MatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
2072
2073 const unsigned NewNumVecs =
2074 InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
2075 const unsigned NewNumElts =
2076 InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
2077
    for (unsigned I = 0; I < NewNumVecs; ++I) {
      // Build a single result vector. First initialize it.
      Value *ResultVector = PoisonValue::get(
          FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
      // Go through the old elements and insert them into the resulting
      // vector.
      for (auto J : enumerate(InputMatrix.vectors())) {
        Value *Elt = Builder.CreateExtractElement(J.value(), I);
        // Row and column indices are transposed.
        ResultVector =
            Builder.CreateInsertElement(ResultVector, Elt, J.index());
      }
      Result.addVector(ResultVector);
2090 }
2091
2092
2093
2094
2095 finalizeLowering(
2096 Inst,
2097 Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
2098 .addNumExposedTransposes(1),
2099 Builder);
2100 }
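  // Illustrative sketch (not from the original source): transposing a
  // column-major 2x3 matrix builds 2 result vectors of 3 elements; result
  // vector I gathers element I from each of the 3 input columns via
  // extractelement/insertelement, which later shuffle folding can simplify.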
2101
  /// Lower load instructions, if shape information is available.
  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Inst->getAlign(),
              Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
              I->second);
    return true;
  }

  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
               Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
               I->second);
    return true;
  }
2125
  /// Lower binary operators, if shape information is available.
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy A = getMatrix(Lhs, Shape, Builder);
    MatrixTy B = getMatrix(Rhs, Shape, Builder);
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform binary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
2151 return Builder.CreateAdd(LHS, RHS);
2152 case Instruction::Mul:
2153 return Builder.CreateMul(LHS, RHS);
2154 case Instruction::Sub:
2155 return Builder.CreateSub(LHS, RHS);
2156 case Instruction::FAdd:
2157 return Builder.CreateFAdd(LHS, RHS);
2158 case Instruction::FMul:
2159 return Builder.CreateFMul(LHS, RHS);
2160 case Instruction::FSub:
2161 return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };
2166
2167 for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
2168 Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));
2169
2170 finalizeLowering(Inst,
2171 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2172 Result.getNumVectors()),
2173 Builder);
2174 return true;
2175 }
2176
  /// Lower unary operators, if shape information is available.
  bool VisitUnaryOperator(UnaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Op = Inst->getOperand(0);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy M = getMatrix(Op, Shape, Builder);

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform unary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *Op) {
      switch (Inst->getOpcode()) {
      case Instruction::FNeg:
        return Builder.CreateFNeg(Op);
      default:
        llvm_unreachable("Unsupported unary operator for matrix");
      }
    };
2202
2203 for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
2204 Result.addVector(BuildVectorOp(M.getVector(I)));
2205
2206 finalizeLowering(Inst,
2207 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2208 Result.getNumVectors()),
2209 Builder);
2210 return true;
2211 }
2212
2213
2214
2215
2216 struct ExprLinearizer {
2217 unsigned LengthToBreak = 100;
    std::string Str;
    raw_string_ostream Stream;
    unsigned LineLength = 0;
    const DataLayout &DL;

    /// Mapping from instructions to matrixes. It is used to identify
    /// matrix instructions.
    const MapVector<Value *, MatrixTy> &Inst2Matrix;

    /// Mapping from values to the leaves of all expressions that the value is
    /// part of.
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

    /// Set of matrix expressions in the scope of a given DISubprogram.
    const SmallSetVector<Value *, 32> &ExprsInSubprogram;

    /// Leaf node of the expression to linearize.
    Value *Leaf;

    /// Used to keep track of sub-expressions that get reused while
    /// linearizing the expression. Re-used sub-expressions are marked as
    /// (reused).
    SmallPtrSet<Value *, 8> ReusedExprs;
    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, MatrixTy> &Inst2Matrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                   Value *Leaf)
        : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
          ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
2248
2249 void indent(unsigned N) {
2250 LineLength += N;
2251 for (unsigned i = 0; i < N; i++)
2252 Stream << " ";
2253 }
2254
2255 void lineBreak() {
2256 Stream << "\n";
2257 LineLength = 0;
2258 }
2259
2260 void maybeIndent(unsigned Indent) {
2261 if (LineLength >= LengthToBreak)
2262 lineBreak();
2263
      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }
2272
    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      else if (V->getType()->isPointerTy())
        return getUnderlyingObject(V);
      return V;
    }
2280
2281
2282 bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
2283
2284
2285
2287 auto M = Inst2Matrix.find(V);
2288 if (M == Inst2Matrix.end())
2289 SS << "unknown";
2290 else {
2291 SS << M->second.getNumRows();
2292 SS << "x";
2293 SS << M->second.getNumColumns();
2294 }
2295 }
2296
2297
2298
2299
    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: we write the name, followed by the dimensions of the input
    /// matrixes, followed by the scalar type name.
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.starts_with("llvm.matrix")) {
          write(Name);
          return;
        }
        auto *II = cast<IntrinsicInst>(CI);
        write(Intrinsic::getBaseName(II->getIntrinsicID())
                  .drop_front(StringRef("llvm.").size()));
        write(".");
        std::string Tmp;
        raw_string_ostream SS(Tmp);
2316 switch (II->getIntrinsicID()) {
2317 case Intrinsic::matrix_multiply:
2318 prettyPrintMatrixType(II->getOperand(0), SS);
2319 SS << ".";
2320 prettyPrintMatrixType(II->getOperand(1), SS);
2321 SS << "." << *II->getType()->getScalarType();
2322 break;
2323 case Intrinsic::matrix_transpose:
2324 prettyPrintMatrixType(II->getOperand(0), SS);
2325 SS << "." << *II->getType()->getScalarType();
2326 break;
2327 case Intrinsic::matrix_column_major_load:
2328 prettyPrintMatrixType(II, SS);
2329 SS << "." << *II->getType()->getScalarType();
2330 break;
2331 case Intrinsic::matrix_column_major_store:
2332 prettyPrintMatrixType(II->getOperand(0), SS);
2333 SS << "." << *II->getOperand(0)->getType()->getScalarType();
2334 break;
        default:
          llvm_unreachable("Unhandled case");
        }
        write(Tmp);
      }
    }
2341
    unsigned getNumShapeArgs(CallInst *CI) const {
      if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
2345 case Intrinsic::matrix_multiply:
2346 return 3;
2347 case Intrinsic::matrix_transpose:
2348 return 2;
2349 case Intrinsic::matrix_column_major_load:
2350 case Intrinsic::matrix_column_major_store:
2351 return 3;
2352 default:
2353 return 0;
2354 }
2355 }
2356 return 0;
2357 }
2358
    /// Write \p V to the output: the name of the underlying address for
    /// pointers, or "matrix"/"scalar"/"constant" depending on what the value
    /// represents.
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName() << "";
          LineLength += V->getName().size() + 2;
        }
        return;
      }
2378
      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
2386 else {
2387 if (isMatrix(V))
2388 TmpStream << "matrix";
2389 else
2390 TmpStream << "scalar";
2391 }
2392 Tmp = std::string(StringRef(Tmp).trim());
2393 LineLength += Tmp.size();
2394 Stream << Tmp;
2395 }
2396
2397
2398
2399
2400 void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
2401 bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;
2406
2407 bool ExprShared = false;
2408
2409
2410 if (!ParentShared) {
2411 auto SI = Shared.find(Expr);
2412 assert(SI != Shared.end() && SI->second.count(Leaf));
2413
2414 for (Value *S : SI->second) {
2415 if (S == Leaf)
2416 continue;
2417 DebugLoc DL = cast(S)->getDebugLoc();
2418 write("shared with remark at line " + std::to_string(DL.getLine()) +
2419 " column " + std::to_string(DL.getCol()) + " (");
2420 }
2421 ExprShared = SI->second.size() > 1;
2422 }
2423
2424 bool Reused = !ReusedExprs.insert(Expr).second;
2425 if (Reused && !ParentReused)
2426 write("(reused) ");
2427
      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);

        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
2433
2434
2435 write("matrix");
2436 return;
2437 } else {
2438 Ops.append(I->value_op_begin(), I->value_op_end());
2439 write(std::string(I->getOpcodeName()));
2440 }
2441
2442 write(std::string("("));
2443
      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }
2463
2464 const std::string &getResult() {
2465 return Str;
2466 }
2467 };
2468
2469
2470
2471
2472
2473
2474
2475
2476
2477
2478
2479
2480
2481
  struct RemarkGenerator {
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
          DL(Func.getDataLayout()) {}
2492
    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those
    /// are instructions in Inst2Matrix returning void or without any users in
    /// Inst2Matrix. Currently that should only include stores.
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }
2507
    /// Recursively traverse expression \p V starting at \p Leaf and add
    /// \p Leaf to the set of expressions \p V is part of.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
2515 if (!ExprsInSubprogram.count(V))
2516 return;
2517
2518 Shared[V].insert(Leaf);
2519
2520 for (Value *Op : cast(V)->operand_values())
2521 collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
2522 }
2523
    /// Sum up the count of all operations for expression \p Root and all
    /// expressions only used by \p Root, and the counts of operations shared
    /// with other expression trees.
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
2532 return {};
2533
2534
2535 if (!ReusedExprs.insert(Root).second)
2536 return {};
2537
2538 OpInfoTy SharedCount;
2539 OpInfoTy Count;
2540
2541 auto I = Shared.find(Root);
2542 auto CM = Inst2Matrix.find(Root);
2543 if (I->second.size() == 1)
2544 Count = CM->second.getOpInfo();
2545 else
2546 SharedCount = CM->second.getOpInfo();
2547
2548 for (Value *Op : cast(Root)->operand_values()) {
2549 auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
2550 Count += C.first;
2551 SharedCount += C.second;
2552 }
2553 return {Count, SharedCount};
2554 }
2555
    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;
2559
      // Map matrix operations to their containing subprograms, by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (const auto &KV : Inst2Matrix) {
2565 if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
2569 Subprog2Exprs[getSubprogram(Context->getScope())].push_back(
2570 KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
2573 } else {
2574 Subprog2Exprs[nullptr].push_back(KV.first);
2575 }
2576 }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
2580 auto Leaves = getExpressionLeaves(ExprsInSubprogram);
2581
2583 for (Value *Leaf : Leaves)
2584 collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
2585
2586
2587 for (auto *L : Leaves) {
2588
          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
2591 while (Context) {
2592 if (getSubprogram(Context->getScope()) == KV.first) {
2593 Loc = Context;
2594 break;
2595 }
            Context = DebugLoc(Context).getInlinedAt();
          }
2598
          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
2601 std::tie(Counts, SharedCounts) =
2602 sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
2603
          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());
2606
2607 Rem << "Lowered with ";
2608 Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
2609 << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
2610 << ore::NV("NumComputeOps", Counts.NumComputeOps)
2611 << " compute ops, "
2612 << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
2613 << " exposed transposes";
2614
2615 if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
2616 SharedCounts.NumComputeOps > 0) {
2617 Rem << ",\nadditionally "
2618 << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
2619 << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
2620 << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
2621 << " compute ops"
2622 << " are shared with other expressions";
2623 }
2624
2625 Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
2626 ORE.emit(Rem);
2627 }
2628 }
2629 }
2630
    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
2637 Lin.linearizeExpr(L, 0, false, false);
2638 return Lin.getResult();
2639 }
2640 };
2641};
2642}
PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  LowerMatrixIntrinsics LMT(F, TTI, Minimal ? nullptr : &AM);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    if (!Minimal) {
      PA.preserve<LoopAnalysis>();
      PA.preserve<DominatorTreeAnalysis>();
    }
    return PA;
  }
  return PreservedAnalyses::all();
}
void LowerMatrixIntrinsicsPass::printPipeline(
    raw_ostream &OS,
    function_ref<StringRef(StringRef)> MapClassName2PassName) {
  static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
      OS, MapClassName2PassName);
  OS << '<';
2665 if (Minimal)
2666 OS << "minimal";
2667 OS << '>';
2668}
ReachingDefAnalysis InstSet & ToRemove
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static const Function * getParent(const Value *V)
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
hexagon Hexagon specific predictive commoning for HVX vectors
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
static DISubprogram * getSubprogram(DIScope *Scope)
Helper function to either return Scope, if it is a subprogram or the attached subprogram for a local ...
static cl::opt< bool > ForceFusion("force-fuse-matrix", cl::init(false), cl::Hidden, cl::desc("Force matrix instruction fusion even if not profitable."))
static cl::opt< bool > VerifyShapeInfo("verify-matrix-shapes", cl::Hidden, cl::desc("Enable/disable matrix shape verification."), cl::init(false))
static bool isSplat(Value *V)
Return true if V is a splat of a value (which is used when multiplying a matrix with a scalar).
static cl::opt< bool > TileUseLoops("fuse-matrix-use-loops", cl::init(false), cl::Hidden, cl::desc("Generate loop nest for tiling."))
static cl::opt< bool > FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, cl::desc("Enable/disable fusing matrix instructions."))
auto m_AnyAdd(const LTy &L, const RTy &R)
Match any add operation (fp or integer).
static cl::opt< bool > AllowContractEnabled("matrix-allow-contract", cl::init(false), cl::Hidden, cl::desc("Allow the use of FMAs if available and profitable. This may " "result in different results, due to less rounding error."))
auto m_AnyMul(const LTy &L, const RTy &R)
Match any mul operation (fp or integer).
static cl::opt< bool > PrintAfterTransposeOpt("matrix-print-after-transpose-opt", cl::init(false))
static cl::opt< unsigned > TileSize("fuse-matrix-tile-size", cl::init(4), cl::Hidden, cl::desc("Tile size for matrix instruction fusion using square-shaped tiles."))
static cl::opt< MatrixLayoutTy > MatrixLayout("matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), cl::desc("Sets the default matrix layout"), cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", "Use column-major layout"), clEnumValN(MatrixLayoutTy::RowMajor, "row-major", "Use row-major layout")))
uint64_t IntrinsicInst * II
PowerPC Reduce CR logical Operation
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static unsigned getNumElements(Type *Ty)
static Value * extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex, unsigned EndIndex, const Twine &Name)
static Value * insertVector(IRBuilderTy &IRB, Value *Old, Value *V, unsigned BeginIndex, const Twine &Name)
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
This file defines the SmallSet class.
This file defines the SmallVector class.
static SymbolRef::Type getType(const Symbol *Sym)
static const int BlockSize
This pass exposes codegen information to IR-level passes.
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
static SDValue LowerStore(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)
A manager for alias analyses.
bool isNoAlias(const MemoryLocation &LocA, const MemoryLocation &LocB)
A trivial helper function to check to see if the specified pointers are no-alias.
an instruction to allocate memory on the stack
Align getAlign() const
Return the alignment of the memory that is being allocated by the instruction.
A container for analyses that lazily runs them and caches their results.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
LLVM Basic Block Representation.
iterator begin()
Instruction iterator methods.
reverse_iterator rbegin()
InstListType::reverse_iterator reverse_iterator
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
BinaryOps getOpcode() const
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
User::op_iterator arg_begin()
Return the iterator pointing to the beginning of the argument list.
MaybeAlign getParamAlign(unsigned ArgNo) const
Extract the alignment for a call or parameter (0=unknown).
Value * getArgOperand(unsigned i) const
User::op_iterator arg_end()
Return the iterator pointing to the end of the argument list.
This class represents a function call, abstracting a target machine's calling convention.
static ConstantAggregateZero * get(Type *Ty)
This is the shared class of boolean and integer constants.
DISubprogram * getSubprogram() const
Get the subprogram for this scope.
Base class for scope-like contexts.
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
DILocation * getInlinedAt() const
iterator find(const_arg_type_t< KeyT > Val)
bool erase(const KeyT &Val)
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Analysis pass which computes a DominatorTree.
void applyUpdates(ArrayRef< UpdateType > Updates)
Inform the dominator tree about a sequence of CFG edge insertions and deletions and perform a batch u...
static constexpr UpdateKind Delete
static constexpr UpdateKind Insert
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Convenience struct for specifying and reasoning about fast-math flags.
void setAllowContract(bool B=true)
bool allowReassoc() const
Flag queries.
bool allowContract() const
Class to represent fixed width SIMD vectors.
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
Intrinsic::ID getIntrinsicID() const LLVM_READONLY
getIntrinsicID - This method returns the ID number of the specified function, or Intrinsic::not_intri...
bool isIntrinsic() const
isIntrinsic - Returns true if the function's name starts with "llvm.".
CallInst * CreateFAddReduce(Value *Acc, Value *Src)
Create a sequential vector fadd reduction intrinsic of the source vector.
Value * CreateICmpULT(Value *LHS, Value *RHS, const Twine &Name="")
Value * CreateFSub(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Value * CreateInsertElement(Type *VecTy, Value *NewElt, Value *Idx, const Twine &Name="")
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")
LoadInst * CreateAlignedLoad(Type *Ty, Value *Ptr, MaybeAlign Align, const char *Name)
Value * CreateFAdd(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Value * CreateVectorSplat(unsigned NumElts, Value *V, const Twine &Name="")
Return a vector value that contains.
CallInst * CreateAddReduce(Value *Src)
Create a vector int add reduction intrinsic of the source vector.
IntegerType * getIntPtrTy(const DataLayout &DL, unsigned AddrSpace=0)
Fetch the type of an integer with size at least as big as that of a pointer in the given address spac...
void setFastMathFlags(FastMathFlags NewFMF)
Set the fast-math flags to be used with generated fp-math operators.
Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())
ConstantInt * getInt64(uint64_t C)
Get a constant 64-bit value.
CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
PHINode * CreatePHI(Type *Ty, unsigned NumReservedValues, const Twine &Name="")
Value * CreateSub(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
ConstantInt * getIntN(unsigned N, uint64_t C)
Get a constant N-bit value, zero extended or truncated from a 64-bit value.
BranchInst * CreateCondBr(Value *Cond, BasicBlock *True, BasicBlock *False, MDNode *BranchWeights=nullptr, MDNode *Unpredictable=nullptr)
Create a conditional 'br Cond, TrueDest, FalseDest' instruction.
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool' for the isVolatile parameter.
Value * CreateShuffleVector(Value *V1, Value *V2, Value *Mask, const Twine &Name="")
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreatePtrToInt(Value *V, Type *DestTy, const Twine &Name="")
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
StoreInst * CreateAlignedStore(Value *Val, Value *Ptr, MaybeAlign Align, bool isVolatile=false)
Value * CreateFMul(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)
Value * CreateFNeg(Value *V, const Twine &Name="", MDNode *FPMathTag=nullptr)
CallInst * CreateMemCpy(Value *Dst, MaybeAlign DstAlign, Value *Src, MaybeAlign SrcAlign, uint64_t Size, bool isVolatile=false, MDNode *TBAATag=nullptr, MDNode *TBAAStructTag=nullptr, MDNode *ScopeTag=nullptr, MDNode *NoAliasTag=nullptr)
Create and insert a memcpy between the specified pointers.
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
This provides a uniform API for creating instructions and inserting them into a basic block: either at the end of a BasicBlock, or at a specific iterator location in a block.
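Several of these calls typically appear together when one vector of a matrix is materialized from memory: scale the vector index by the stride, GEP to the start element, then issue an aligned vector load. A simplified sketch, assuming column-major layout (helper name hypothetical):

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"

using namespace llvm;

// Hypothetical sketch: load column Col of a column-major matrix with
// NumRows rows of element type EltTy, starting at pointer Base.
static Value *loadColumn(IRBuilder<> &Builder, Value *Base, Type *EltTy,
                         unsigned NumRows, unsigned Col, Align A) {
  Value *Start = Builder.CreateMul(Builder.getInt64(Col),
                                   Builder.getInt64(NumRows), "vec.start");
  Value *GEP = Builder.CreateGEP(EltTy, Base, Start, "vec.gep");
  auto *VecTy = FixedVectorType::get(EltTy, NumRows);
  return Builder.CreateAlignedLoad(VecTy, GEP, A, "col.load");
}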
static InstructionCost getInvalid(CostType Val=0)
void setFastMathFlags(FastMathFlags FMF)
Convenience function for setting multiple fast-math flags on this instruction, which must be an operator which supports these flags.
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports these flags.
A wrapper class for inspecting calls to intrinsic functions.
An instruction for reading from memory.
bool isVolatile() const
Return true if this is a load from a volatile memory location.
Align getAlign() const
Return the alignment of the access that is being performed.
TypeSize getValue() const
Analysis pass that exposes the LoopInfo for a function.
LoopT * getLoopFor(const BlockT *BB) const
Return the innermost loop that BB lives in.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void printPipeline(raw_ostream &OS, function_ref< StringRef(StringRef)> MapClassName2PassName)
This class implements a map that also provides access to all stored values in a deterministic order.
iterator find(const KeyT &Key)
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
CallInst * CreateMatrixTranspose(Value *Matrix, unsigned Rows, unsigned Columns, const Twine &Name="")
Create a llvm.matrix.transpose call, transposing Matrix with Rows rows and Columns columns.
CallInst * CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, unsigned LHSColumns, unsigned RHSColumns, const Twine &Name="")
Create a llvm.matrix.multiply call, multiplying matrices LHS and RHS.
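MatrixBuilder emits the llvm.matrix.* intrinsics that this pass later lowers. A small sketch that builds transpose(A * B), assuming 4x4 single-precision operands held in <16 x float> values:

#include "llvm/IR/MatrixBuilder.h"

using namespace llvm;

// Sketch: emit llvm.matrix.multiply followed by llvm.matrix.transpose.
static Value *buildMulThenTranspose(IRBuilder<> &Builder, Value *A, Value *B) {
  MatrixBuilder MB(Builder);
  Value *Prod = MB.CreateMatrixMultiply(A, B, /*LHSRows=*/4,
                                        /*LHSColumns=*/4, /*RHSColumns=*/4);
  return MB.CreateMatrixTranspose(Prod, /*Rows=*/4, /*Columns=*/4);
}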
Representation for a specific memory location.
static MemoryLocation get(const LoadInst *LI)
Return a location with information about the memory referenced by the given instruction.
LocationSize Size
The maximum size of the location, in address-units, or UnknownSize if the size is not known.
const Value * Ptr
The address of the start of the location.
static MemoryLocation getForArgument(const CallBase *Call, unsigned ArgIdx, const TargetLibraryInfo *TLI)
Return a location representing a particular argument of a call.
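These factories give alias analysis precise extents to reason about, which is what legality checks for fusing loads and stores need. A hedged sketch of such a check:

#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

// Sketch: conservatively ask whether a store may clobber the memory a
// candidate load reads; a fusion transform must bail out if it might.
static bool mayClobber(AAResults &AA, StoreInst *SI, LoadInst *LI) {
  MemoryLocation StoreLoc = MemoryLocation::get(SI);
  MemoryLocation LoadLoc = MemoryLocation::get(LI);
  return !AA.isNoAlias(StoreLoc, LoadLoc);
}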
static PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
void preserve()
Mark an analysis as preserved.
A vector that has set insertion semantics.
size_type size() const
Determine the number of elements in the SetVector.
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
bool insert(const value_type &X)
Insert a new element into the SetVector.
A templated base class for SmallPtrSet which provides the typesafe interface that is common across all small sizes.
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
A SetVector that performs no allocations if smaller than a certain size.
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less than N).
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
This class consists of common code factored out of the SmallVector class to reduce code duplication based on the SmallVector 'N' template parameter.
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
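This container combination is the usual LLVM worklist idiom: a SetVector for deterministic processing order plus a SmallPtrSet for cheap visited tests. An illustrative sketch:

#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/IR/Instruction.h"

using namespace llvm;

// Sketch: walk over instruction users, visiting each instruction once.
// SetVector keeps insertion order, so the traversal is deterministic.
static void walkUsers(Instruction *Root) {
  SmallSetVector<Instruction *, 8> Worklist;
  SmallPtrSet<Instruction *, 16> Visited;
  Worklist.insert(Root);
  while (!Worklist.empty()) {
    Instruction *I = Worklist.pop_back_val();
    if (!Visited.insert(I).second)
      continue; // Already seen.
    for (User *U : I->users())
      if (auto *UI = dyn_cast<Instruction>(U))
        Worklist.insert(UI);
  }
}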
An instruction for storing to memory.
bool isVolatile() const
Return true if this is a store to a volatile memory location.
StringRef - Represent a constant reference to a string, i.e.
StringRef drop_front(size_t N=1) const
Return a StringRef equal to 'this' but with the first N elements dropped.
constexpr size_t size() const
size - Get the string size.
Analysis pass providing the TargetTransformInfo.
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
TypeSize getRegisterBitWidth(RegisterKind K) const
InstructionCost getMemoryOpCost(unsigned Opcode, Type *Src, Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, OperandValueInfo OpdInfo={OK_AnyValue, OP_None}, const Instruction *I=nullptr) const
InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty, std::optional< FastMathFlags > FMF, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput) const
Calculate the cost of vector reduction intrinsics.
unsigned getRegisterClassForType(bool Vector, Type *Ty=nullptr) const
@ TCK_RecipThroughput
Reciprocal throughput.
InstructionCost getArithmeticInstrCost(unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, TTI::OperandValueInfo Opd1Info={TTI::OK_AnyValue, TTI::OP_None}, TTI::OperandValueInfo Opd2Info={TTI::OK_AnyValue, TTI::OP_None}, ArrayRef< const Value * > Args={}, const Instruction *CxtI=nullptr, const TargetLibraryInfo *TLibInfo=nullptr) const
This is an approximation of reciprocal throughput of a math/logic op.
InstructionCost getShuffleCost(ShuffleKind Kind, VectorType *Tp, ArrayRef< int > Mask={}, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, int Index=0, VectorType *SubTp=nullptr, ArrayRef< const Value * > Args={}, const Instruction *CxtI=nullptr) const
unsigned getNumberOfRegisters(unsigned ClassID) const
@ SK_Splice
Concatenates elements from the first input vector with elements of the second input vector.
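Queries like these feed the profitability decisions: for example, estimating the reciprocal-throughput cost of a multiply feeding a floating-point add reduction, the shape of one dot product. A sketch (cost composition simplified):

#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Instruction.h"

using namespace llvm;

// Sketch: estimate the cost of an fmul whose result is summed by an
// fadd reduction over VecTy, as one input to a fusion heuristic.
static InstructionCost mulReduceCost(const TargetTransformInfo &TTI,
                                     VectorType *VecTy) {
  InstructionCost Cost =
      TTI.getArithmeticInstrCost(Instruction::FMul, VecTy);
  Cost += TTI.getArithmeticReductionCost(Instruction::FAdd, VecTy,
                                         FastMathFlags());
  return Cost;
}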
Twine - A lightweight data structure for efficiently representing the concatenation of temporary values as strings.
The instances of the Type class are immutable: once they are created, they are never changed.
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
bool isVoidTy() const
Return true if this is 'void'.
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
UnaryOps getOpcode() const
A Use represents the edge between a Value definition and its users.
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
user_iterator user_begin()
bool hasOneUse() const
Return true if there is exactly one use of this value.
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
iterator_range< user_iterator > users()
bool hasNUses(unsigned N) const
Return true if this Value has exactly N uses.
iterator_range< use_iterator > uses()
StringRef getName() const
Return a constant reference to the value's name.
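replaceAllUsesWith followed by erasure is the standard way to retire an instruction once its replacement exists. A minimal sketch:

#include "llvm/IR/Instruction.h"
#include <cassert>

using namespace llvm;

// Sketch: route every user of Old to New, then delete Old. Safe only
// when New's type matches Old's.
static void replaceAndErase(Instruction *Old, Value *New) {
  Old->replaceAllUsesWith(New);
  assert(Old->hasNUses(0) && "RAUW should have removed every use");
  Old->eraseFromParent();
}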
Type * getElementType() const
constexpr ScalarTy getFixedValue() const
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
A range adaptor for a pair of iterators.
This class implements an extremely fast bulk output stream that can only output to a stream.
A raw_ostream that writes to an std::string.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default LLVM calling convention, compatible with C.
StringRef getBaseName(ID id)
Return the LLVM name for an intrinsic, without encoded types for overloading, such as "llvm.ssa.copy".
TwoOps_match< ValueOpTy, PointerOpTy, Instruction::Store > m_Store(const ValueOpTy &ValueOp, const PointerOpTy &PointerOp)
Matches StoreInst.
BinaryOp_match< LHS, RHS, Instruction::Add > m_Add(const LHS &L, const RHS &R)
class_match< BinaryOperator > m_BinOp()
Match an arbitrary binary operation and ignore it.
specific_intval< false > m_SpecificInt(const APInt &V)
Match a specific integer value or vector with all elements equal to the value.
BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
BinaryOp_match< LHS, RHS, Instruction::FAdd > m_FAdd(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
OneUse_match< T > m_OneUse(const T &SubPattern)
OneOps_match< OpTy, Instruction::Load > m_Load(const OpTy &Op)
Matches LoadInst.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
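These matchers compose, and m_CombineOr lets one pattern cover both the integer and floating-point forms of an operation. A sketch recognizing (a * b) + c in either form (function name hypothetical):

#include "llvm/IR/PatternMatch.h"

using namespace llvm;
using namespace PatternMatch;

// Sketch: match (a * b) + c where * and + may be integer or fp ops.
// On success, A, B and C are bound to the matched operands.
static bool isMulAdd(Value *V, Value *&A, Value *&B, Value *&C) {
  auto Mul = m_CombineOr(m_Mul(m_Value(A), m_Value(B)),
                         m_FMul(m_Value(A), m_Value(B)));
  return match(V, m_CombineOr(m_Add(Mul, m_Value(C)),
                              m_FAdd(Mul, m_Value(C))));
}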
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to the ValuesClass constructor.
initializer< Ty > init(const Ty &Val)
ElementType
The element type of an SRV or UAV resource.
DiagnosticInfoOptimizationBase::Argument NV
NodeAddr< PhiNode * > Phi
NodeAddr< FuncNode * > Func
This is an optimization pass for GlobalISel generic memory operations.
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
auto enumerate(FirstRange &&First, RestRanges &&...Rest)
Given two or more input ranges, returns a new range whose values are tuples (A, B, C, ...), where A is the 0-based index of the item in the sequence, and B, C, ... are the values from the input ranges.
auto successors(const MachineBasicBlock *BB)
bool operator!=(uint64_t V1, const APInt &V2)
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt & operator+=(DynamicAPInt &A, int64_t B)
const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=6)
This method strips off any GEP address adjustments, pointer casts or llvm.threadlocal.address from the specified value V, returning the original object being addressed.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting iteration.
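The early-increment adaptor advances the iterator before the loop body runs, so the body may erase the current element. A sketch deleting trivially dead instructions from a block:

#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/Transforms/Utils/Local.h"

using namespace llvm;

// Sketch: erase trivially dead instructions while walking the block.
// Erasing I is safe because the iterator has already moved past it.
static void dropDead(BasicBlock &BB) {
  for (Instruction &I : make_early_inc_range(BB))
    if (isInstructionTriviallyDead(&I))
      I.eraseFromParent();
}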
Value * concatenateVectors(IRBuilderBase &Builder, ArrayRef< Value * > Vecs)
Concatenate a list of vectors.
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
const Value * getPointerOperand(const Value *V)
A helper function that returns the pointer operand of a load, store or GEP instruction.
void addStringMetadataToLoop(Loop *TheLoop, const char *MDString, unsigned V=0)
Set input string into loop metadata by keeping other values intact.
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
auto reverse(ContainerTy &&C)
Error write(MCStreamer &Out, ArrayRef< std::string > Inputs, OnCuIndexOverflow OverflowOptValue)
void sort(IteratorTy Start, IteratorTy End)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
DWARFExpression::Operation Op
decltype(auto) cast(const From &Val)
cast - Return the argument parameter cast to the specified type.
BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
llvm::SmallVector< int, 16 > createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs)
Create a sequential shuffle mask.
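A sequential mask is how a contiguous slice, such as one column of a flat column-major matrix value, is carved out with a shufflevector. A sketch (helper name hypothetical):

#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/IRBuilder.h"

using namespace llvm;

// Sketch: extract elements [Start, Start+Len) of the flat vector Mat,
// e.g. one column of a matrix embedded in a larger vector value.
static Value *extractSlice(IRBuilder<> &Builder, Value *Mat,
                           unsigned Start, unsigned Len) {
  SmallVector<int, 16> Mask =
      createSequentialMask(Start, Len, /*NumUndefs=*/0);
  return Builder.CreateShuffleVector(Mat, Mask, "block");
}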
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
This struct is a compact representation of a valid (non-zero power of two) alignment.
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
A CRTP mix-in to automatically provide informational APIs needed for passes.
A helper struct to create IR loop nests for tiling in IR of the following form: for ColumnLoop.Index = 0..NumColumns, for RowLoop.Index = 0..NumRows, for KLoop.Index = 0..NumInner.