LLVM: lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp Source File

#include

using namespace llvm;

#define DEBUG_TYPE "lower-matrix-intrinsics"

STATISTIC(FlattenedMatrices, "Number of matrix flattenings");
STATISTIC(ReshapedMatrices, "Number of matrix reshapes");
STATISTIC(SplitMatrices, "Number of matrix splits");

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));

static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));

static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
                                  cl::Hidden,
                                  cl::desc("Generate loop nest for tiling."));

static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));

static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

static cl::opt<bool> VerifyShapeInfo(
    "verify-matrix-shapes", cl::Hidden,
    cl::desc("Enable/disable matrix shape verification."), cl::init(false));

enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));

static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
                                            cl::init(false));

static cl::opt<unsigned> SplitMatmulRemainderOverThreshold(
    "matrix-split-matmul-remainder-over-threshold", cl::Hidden,
    cl::desc("Illegal remainder vectors over this size in bits should be split "
             "in the inner loop of matmul"),
    cl::init(0));
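// Illustrative usage (assuming a stock LLVM build and a placeholder input
// file test.ll that uses the llvm.matrix.* intrinsics): the cl::opt flags
// above are ordinary command-line switches, so the pass can be exercised
// directly with `opt`, e.g.
//
//   opt -passes=lower-matrix-intrinsics -matrix-allow-contract \
//       -fuse-matrix-tile-size=8 -S test.ll -o lowered.ll
//
// The flag spellings come from the definitions above; the file name and the
// tile size of 8 are placeholder values.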

/// Helper function to either return Scope, if it is a subprogram or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}

/// Return true if V is a splat of a value (which is used when multiplying a
/// matrix with a scalar).
static bool isSplat(Value *V) {
  if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
    return SV->isZeroEltSplat();
  return false;
}

/// Match any mul operation (fp or integer).
template <typename LTy, typename RTy>
static auto m_AnyMul(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
}

/// Match any add operation (fp or integer).
template <typename LTy, typename RTy>
static auto m_AnyAdd(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
}

133

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

159

160

161

162

163

164

165

166

167

168

169

170

171

172

173

175 unsigned NumElements, Type *EltType,

177

180 "Stride must be >= the number of elements in the result vector.");

181

182

183 Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");

184

185

186

188 VecStart = BasePtr;

189 else

190 VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");

191

192 return VecStart;

193}
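// Worked example for computeVectorAddr (values are illustrative): for a
// column-major 4x3 matrix of doubles at BasePtr with Stride == 4, requesting
// VecIdx == 2 yields the address of the third column, i.e. roughly
//
//   %vec.start = mul i64 2, 4                                      ; offset 8
//   %vec.gep   = getelementptr double, ptr %BasePtr, i64 %vec.start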

194

195namespace {

196struct ShapeInfo {

197 unsigned NumRows;

198 unsigned NumColumns;

199

200 bool IsColumnMajor;

201

202 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)

203 : NumRows(NumRows), NumColumns(NumColumns),

205

206 ShapeInfo(Value *NumRows, Value *NumColumns)

207 : ShapeInfo(cast(NumRows)->getZExtValue(),

208 cast(NumColumns)->getZExtValue()) {}

209

210 bool operator==(const ShapeInfo &other) {

211 return NumRows == other.NumRows && NumColumns == other.NumColumns;

212 }

213 bool operator!=(const ShapeInfo &other) { return !(*this == other); }

214

215

216

217 operator bool() const {

218 assert(NumRows == 0 || NumColumns != 0);

219 return NumRows != 0;

220 }

221

222 unsigned getStride() const {

223 if (IsColumnMajor)

224 return NumRows;

225 return NumColumns;

226 }

227

228 unsigned getNumVectors() const {

229 if (IsColumnMajor)

230 return NumColumns;

231 return NumRows;

232 }

233

234

235 ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }

236

237 friend raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI);

238

240};

241

243 return OS << SI.NumRows << 'x' << SI.NumColumns;

244}

245

246}
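// Illustrative sketch of the shape bookkeeping above: with the default
// column-major layout, a 4x2 matrix is handled as 2 vectors (columns) of 4
// elements, so getNumVectors() == 2 and getStride() == 4; in row-major layout
// the same shape becomes 4 vectors of stride 2. For example (hypothetical
// usage, ShapeInfo is internal to this file):
//
//   ShapeInfo SI(/*NumRows=*/4, /*NumColumns=*/2);
//   unsigned NumVecs = SI.getNumVectors(); // 2 (column-major)
//   unsigned Stride  = SI.getStride();     // 4 (column-major)
//   ShapeInfo T      = SI.t();             // 2x4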

247

250 if (I)

251 return true;

252

254 return true;

255

256 if (I->isBinaryOp())

257 return true;

258

260 switch (Cast->getOpcode()) {

261 case llvm::Instruction::Trunc:

262 case llvm::Instruction::ZExt:

263 case llvm::Instruction::SExt:

264 case llvm::Instruction::FPToUI:

265 case llvm::Instruction::FPToSI:

266 case llvm::Instruction::UIToFP:

267 case llvm::Instruction::SIToFP:

268 case llvm::Instruction::FPTrunc:

269 case llvm::Instruction::FPExt:

270 return true;

271 case llvm::Instruction::AddrSpaceCast:

272 case CastInst::PtrToAddr:

273 case CastInst::PtrToInt:

274 case CastInst::IntToPtr:

275 return false;

276 case CastInst::BitCast: {

279 return SrcVTy->getNumElements() == DestVTy->getNumElements();

280 return false;

281 }

282 case llvm::Instruction::CastOpsEnd:

284 }

286 }

287

289 switch (II->getIntrinsicID()) {

290 case Intrinsic::abs:

291 case Intrinsic::fabs:

292 return true;

293 default:

294 return false;

295 }

296

297 switch (I->getOpcode()) {

298 case Instruction::PHI:

299 case Instruction::FNeg:

300 return true;

301 default:

302 return false;

303 }

304}

305

306

307

310 "Can't retrieve shaped operands for an instruction that does not "

311 "preserve shape information");

312 auto Ops = I->operands();

314}

315

316

317static std::optional

325 return ShapeInfo(M, K);

328

329 return ShapeInfo(N, M);

330 }

334 return ShapeInfo(N, M);

337 return ShapeInfo(M, N);

340 auto OpShape = ShapeMap.find(MatrixA);

341 if (OpShape != ShapeMap.end())

342 return OpShape->second;

343 }

344

347

348 for (auto &Op : ShapedOps) {

349 auto OpShape = ShapeMap.find(Op.get());

350 if (OpShape != ShapeMap.end())

351 return OpShape->second;

352 }

353 }

354 return std::nullopt;

355}

356

357namespace {

358

359

360

361

362

363

364

365

366

367

368

369

370

371

372

373

374

375

376

377

378

379

380

381

382class LowerMatrixIntrinsics {

384 const DataLayout &DL;

385 const TargetTransformInfo &TTI;

388 DominatorTree *DT = nullptr;

389 LoopInfo *LI = nullptr;

390 OptimizationRemarkEmitter *ORE = nullptr;

391

392

393

394 struct OpInfoTy {

395

396 unsigned NumStores = 0;

397

398 unsigned NumLoads = 0;

399

400 unsigned NumComputeOps = 0;

401

402

403

404 unsigned NumExposedTransposes = 0;

405

407 NumStores += RHS.NumStores;

408 NumLoads += RHS.NumLoads;

409 NumComputeOps += RHS.NumComputeOps;

410 NumExposedTransposes += RHS.NumExposedTransposes;

411 return *this;

412 }

413 };

414

415

416

417 class MatrixTy {

418 SmallVector<Value *, 16> Vectors;

419

420 OpInfoTy OpInfo;

421

422 bool IsColumnMajor = true;

423

424 public:

427 : Vectors(Vectors),

429 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)

431

432 unsigned D = isColumnMajor() ? NumColumns : NumRows;

433 for (unsigned J = 0; J < D; ++J)

435 EltTy, isColumnMajor() ? NumRows : NumColumns)));

436 }

437

438 Value *getVector(unsigned i) const { return Vectors[i]; }

439 Value *getColumn(unsigned i) const {

440 assert(isColumnMajor() && "only supported for column-major matrixes");

441 return Vectors[i];

442 }

443 Value *getRow(unsigned i) const {

444 assert(!isColumnMajor() && "only supported for row-major matrixes");

445 return Vectors[i];

446 }

447

448 void setVector(unsigned i, Value *V) { Vectors[i] = V; }

449

450 Type *getElementType() const { return getVectorTy()->getElementType(); }

451

452 unsigned getNumVectors() const {

453 if (isColumnMajor())

454 return getNumColumns();

455 return getNumRows();

456 }

457

458 unsigned getNumColumns() const {

459 if (isColumnMajor())

460 return Vectors.size();

461 else {

462 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");

463 return getVectorTy()->getNumElements();

464 }

465 }

466 unsigned getNumRows() const {

467 if (isColumnMajor()) {

468 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");

469 return getVectorTy()->getNumElements();

470 } else

471 return Vectors.size();

472 }

473

474 void addVector(Value *V) { Vectors.push_back(V); }

475 FixedVectorType *getColumnTy() {

476 assert(isColumnMajor() && "only supported for column-major matrixes");

477 return getVectorTy();

478 }

479

480 FixedVectorType *getVectorTy() const {

482 }

483

484 iterator_range<SmallVector<Value *, 8>::iterator> columns() {

485 assert(isColumnMajor() &&

486 "columns() only supported for column-major matrixes");

487 return make_range(Vectors.begin(), Vectors.end());

488 }

489

490 iterator_range<SmallVector<Value *, 8>::iterator> vectors() {

491 return make_range(Vectors.begin(), Vectors.end());

492 }

493

494

495

497 return Vectors.size() == 1 ? Vectors[0]

499 }

500

501 MatrixTy &addNumLoads(unsigned N) {

502 OpInfo.NumLoads += N;

503 return *this;

504 }

505

506 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }

507

508 MatrixTy &addNumStores(unsigned N) {

509 OpInfo.NumStores += N;

510 return *this;

511 }

512

513 MatrixTy &addNumExposedTransposes(unsigned N) {

514 OpInfo.NumExposedTransposes += N;

515 return *this;

516 }

517

518 MatrixTy &addNumComputeOps(unsigned N) {

519 OpInfo.NumComputeOps += N;

520 return *this;

521 }

522

523 unsigned getNumStores() const { return OpInfo.NumStores; }

524 unsigned getNumLoads() const { return OpInfo.NumLoads; }

525 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }

526

527 const OpInfoTy &getOpInfo() const { return OpInfo; }

528

529 bool isColumnMajor() const { return IsColumnMajor; }

530

531 unsigned getStride() const {

532 if (isColumnMajor())

533 return getNumRows();

534 return getNumColumns();

535 }

536

537 ShapeInfo shape() const { return {getNumRows(), getNumColumns()}; }

538

539

540

541

544 Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);

546 NumElts &&

547 "Extracted vector will contain poison values");

550 "block");

551 }

552 };

553

554

555

556

557

558

559

560

561

562

563

564

565

566

567

568 DenseMap<Value *, ShapeInfo> ShapeMap;

569

570

571

572

573 SmallVector<Instruction *, 16> ToRemove;

574

575

576 MapVector<Value *, MatrixTy> Inst2ColumnMatrix;

577

578private:

579 static FastMathFlags getFastMathFlags(Instruction *Inst) {

580 FastMathFlags FMF;

581

584

586

587 return FMF;

588 }

589

590public:

591 LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,

593 : Func(F), DL(F.getDataLayout()), TTI(TTI), AM(AM) {}

594

595 unsigned getNumOps(Type *VT) {

599 }

600

601

602 bool isMinimal() const {

603 return !DT;

604 }

605

606

607

608 unsigned getNumOps(Type *ST, unsigned N) {

609 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /

610 double(TTI.getRegisterBitWidth(

612 .getFixedValue()));

613 }
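// Illustrative arithmetic (assuming TTI reports 128-bit vector registers): a
// <8 x float> operand occupies 256 bits, so getNumOps(FloatTy, 8) returns
// ceil(256 / 128) == 2, i.e. one such matrix vector is counted as two ops in
// the cost bookkeeping below.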

614

615

616

617

618

619

620 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,

624 "The vector size must match the number of matrix elements");

625

626

627

628

629

630 auto Found = Inst2ColumnMatrix.find(MatrixVal);

631 if (Found != Inst2ColumnMatrix.end()) {

632 MatrixTy &M = Found->second;

633

634

635 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())

636 return M;

637

638 MatrixVal = M.embedInVector(Builder);

639 }

640

641

642 SmallVector<Value *, 16> SplitVecs;

643 for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();

644 MaskStart += SI.getStride()) {

647 "split");

649 }

650

652 if (Found != Inst2ColumnMatrix.end()) {

653

654

655 LLVM_DEBUG(dbgs() << "matrix reshape from " << Found->second.shape()

656 << " to " << SI << " using at least "

657 << SplitVecs.size() << " shuffles on behalf of:\n"

658 << *Inst << '\n');

659 ReshapedMatrices++;

660 } else if (!ShapeMap.contains(MatrixVal)) {

663 << "splitting a " << SI << " matrix with " << SplitVecs.size()

664 << " shuffles beacuse we do not have a shape-aware lowering for "

665 "its def:\n"

666 << *Inst << '\n');

667 (void)Inst;

668 SplitMatrices++;

669 } else {

670

671

672

673 }

674 }

675

676 return {SplitVecs};

677 }

678

679

680

681 bool setShapeInfo(Value *V, ShapeInfo Shape) {

682 assert(Shape && "Shape not set");

684 return false;

685

686 auto SIter = ShapeMap.find(V);

687 if (SIter != ShapeMap.end()) {

688 if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||

689 SIter->second.NumColumns != Shape.NumColumns)) {

690 errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"

691 << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"

692 << Shape.NumColumns << ") for " << *V << "\n";

694 "Matrix shape verification failed, compilation aborted!");

695 }

696

697 LLVM_DEBUG(dbgs() << " not overriding existing shape: "

698 << SIter->second.NumRows << " "

699 << SIter->second.NumColumns << " for " << *V << "\n");

700 return false;

701 }

702

703 ShapeMap.insert({V, Shape});

704 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns

705 << " for " << *V << "\n");

706 return true;

707 }

708

709

710

711 bool supportsShapeInfo(Value *V) {

713 if (!Inst)

714 return false;

715

717 if (II)

718 switch (II->getIntrinsicID()) {

719 case Intrinsic::matrix_multiply:

720 case Intrinsic::matrix_transpose:

721 case Intrinsic::matrix_column_major_load:

722 case Intrinsic::matrix_column_major_store:

723 return true;

724 default:

725 break;

726 }

728 }

729

730

731

732

733

735 propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {

737

738

739

741 while (!WorkList.empty()) {

743

744

745 bool Propagate = false;

747 Propagate = setShapeInfo(Inst, *SI);

748

749 if (Propagate) {

751 for (auto *User : Inst->users())

752 if (ShapeMap.count(User) == 0)

754 }

755 }

756

757 return NewWorkList;

758 }

759

760

761

763 propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {

765

766 auto pushInstruction = [](Value *V,

767 SmallVectorImpl<Instruction *> &WorkList) {

769 if (I)

771 };

772

773

774

775 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");

776 while (!WorkList.empty()) {

778

779 size_t BeforeProcessingV = WorkList.size();

781 continue;

782

791 if (setShapeInfo(MatrixA, {M, N}))

792 pushInstruction(MatrixA, WorkList);

793

794 if (setShapeInfo(MatrixB, {N, K}))

795 pushInstruction(MatrixB, WorkList);

796

799

800 if (setShapeInfo(MatrixA, {M, N}))

801 pushInstruction(MatrixA, WorkList);

805 if (setShapeInfo(MatrixA, {M, N})) {

806 pushInstruction(MatrixA, WorkList);

807 }

810

812

813

816

817 ShapeInfo Shape = ShapeMap[V];

818 for (Use &U : ShapedOps) {

819 if (setShapeInfo(U.get(), Shape))

820 pushInstruction(U.get(), WorkList);

821 }

822 }

823

824

825

826 for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)

827 for (User *U : WorkList[I]->users())

830 }

831 return NewWorkList;

832 }
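// Illustrative example of the forward/backward shape propagation (the
// intrinsic mangling below is a sketch and may differ by version):
//
//   %t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %a,
//                                                       i32 2, i32 3)
//   %s = fadd <6 x double> %t, %t
//
// The forward pass records %t as 3x2 (the transposed 2x3 shape) and then %s
// as 3x2 via its shaped operands; the backward pass assigns %a the 2x3
// operand shape if it is not already known.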

833

834

835

836

838 Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,

839 MatrixBuilder &Builder,

840 function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>

843 Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");

844

845

846 setShapeInfo(T0, Shape0.t());

848 Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");

849 setShapeInfo(T1, Shape1.t());

850 return Operation(T0, Shape0.t(), T1, Shape1.t());

851 }

852

853

854

855 void eraseFromParentAndRemoveFromShapeMap(Instruction *Inst) {

856 ShapeMap.erase(Inst);

858 }

859

860

861

863 BasicBlock &BB) {

865

867 return;

868 if (II != BB.rend() && Inst == &*II)

869 ++II;

870 eraseFromParentAndRemoveFromShapeMap(Inst);

871 }

872

873

874

875 void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {

876

877

878

879 auto S = ShapeMap.find(&Old);

880 if (S != ShapeMap.end()) {

881 ShapeMap.erase(S);

882 if (supportsShapeInfo(New))

883 ShapeMap.insert({New, S->second});

884 }

886 }

887

888

889

890

891

896 MatrixBuilder Builder(IB);

897

899 ConstantInt *R, *K, *C;

902 return nullptr;

903

904

908 updateShapeAndReplaceAllUsesWith(I, TATA);

909 eraseFromParentAndMove(&I, II, BB);

910 eraseFromParentAndMove(TA, II, BB);

912 return nullptr;

913 }

914

915

917 updateShapeAndReplaceAllUsesWith(I, TA);

918 eraseFromParentAndMove(&I, II, BB);

920 return nullptr;

921 }

922

923

924

928 auto NewInst = distributeTransposes(

929 TAMB, {K, C}, TAMA, {R, K}, Builder,

930 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {

932 Shape0.NumColumns,

933 Shape1.NumColumns, "mmul");

934 });

935 updateShapeAndReplaceAllUsesWith(I, NewInst);

936 eraseFromParentAndMove(&I, II, BB);

937 eraseFromParentAndMove(TA, II, BB);

939 return NewInst;

940 }

941

942

943

944

945

949

950

951 auto NewInst = distributeTransposes(

952 TAMA, {R, C}, TAMB, {R, C}, Builder,

953 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {

954 bool IsFP = I.getType()->isFPOrFPVectorTy();

955 auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")

956 : LocalBuilder.CreateMul(T0, T1, "mmul");

958 setShapeInfo(Result, Shape0);

960 });

961 updateShapeAndReplaceAllUsesWith(I, NewInst);

962 eraseFromParentAndMove(&I, II, BB);

963 eraseFromParentAndMove(TA, II, BB);

965 return NewInst;

966 }

967

968

969

972 auto NewInst = distributeTransposes(

973 TAMA, {R, C}, TAMB, {R, C}, Builder,

974 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {

975 bool IsFP = I.getType()->isFPOrFPVectorTy();

976 auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")

977 : LocalBuilder.CreateAdd(T0, T1, "madd");

978

980 setShapeInfo(Result, Shape0);

982 });

983 updateShapeAndReplaceAllUsesWith(I, NewInst);

984 eraseFromParentAndMove(&I, II, BB);

985 eraseFromParentAndMove(TA, II, BB);

987 return NewInst;

988 }

989

990 return nullptr;

991 }

992

993 bool liftTranspose(Instruction &I) {

994

996 if (T.use_empty())

997 eraseFromParentAndRemoveFromShapeMap(&T);

998 if (A->use_empty())

1000 if (A != B && B->use_empty())

1002 };

1003

1005 ConstantInt *R, *K, *C;

1006

1013 MatrixBuilder Builder(IB);

1015 BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());

1016 setShapeInfo(M, {C, R});

1018 R->getZExtValue());

1019 updateShapeAndReplaceAllUsesWith(I, NewInst);

1020 CleanupBinOp(I, A, B);

1021 return true;

1022 }

1023

1024

1025

1032 auto *Add = Builder.CreateFAdd(AT, BT, "mfadd");

1033 MatrixBuilder MBuilder(Builder);

1034 Instruction *NewInst = MBuilder.CreateMatrixTranspose(

1035 Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");

1036 updateShapeAndReplaceAllUsesWith(I, NewInst);

1039 "Shape of new instruction doesn't match original shape.");

1040 CleanupBinOp(I, A, B);

1042 setShapeInfo(AddI, {R, C});

1045 ShapeMap[AddI] &&

1046 "Shape of updated addition doesn't match cached shape.");

1047 }

1048 return true;

1049 }

1050 return false;

1051 }

1052

1053

1054 bool optimizeTransposes() {

1056

1057

1058 for (BasicBlock &BB : reverse(Func)) {

1061

1062 ++II;

1063 if (Instruction *NewInst = sinkTranspose(I, II, Changed))

1065 }

1066 }

1067

1068

1069

1070 for (BasicBlock &BB : Func) {

1072 Changed |= liftTranspose(I);

1073 }

1074 }

1076 }

1077

1078 bool Visit() {

1080

1081

1082

1083 for (BasicBlock &BB : Func)

1084 for (Instruction &Inst : BB) {

1086 if (II)

1087 continue;

1088

1089 switch (II->getIntrinsicID()) {

1090 case Intrinsic::matrix_multiply:

1091 case Intrinsic::matrix_transpose:

1092 case Intrinsic::matrix_column_major_load:

1093 case Intrinsic::matrix_column_major_store:

1095 break;

1096 default:

1097 break;

1098 }

1099 }

1100

1101

1102 if (WorkList.empty())

1103 return false;

1104

1105 if (AM) {

1106 ORE = &AM->getResult(Func);

1107 AA = &AM->getResult(Func);

1108 DT = &AM->getResult(Func);

1109 LI = &AM->getResult(Func);

1110 }

1111

1112

1113 while (!WorkList.empty()) {

1114 WorkList = propagateShapeForward(WorkList);

1115 WorkList = propagateShapeBackward(WorkList);

1116 }

1117

1119 if (!isMinimal()) {

1120 Changed |= optimizeTransposes();

1122 dbgs() << "Dump after matrix transpose optimization:\n";

1123 Func.print(dbgs());

1124 }

1125 }

1126

1128 SmallVector<Instruction *, 16> MatrixInsts;

1130

1131

1132

1133 ReversePostOrderTraversal<Function *> RPOT(&Func);

1134 for (auto *BB : RPOT)

1135 for (Instruction &I : *BB) {

1138 if (!ShapeMap.contains(&I))

1139 continue;

1143 }

1144

1145

1146 SmallPtrSet<Instruction *, 16> FusedInsts;

1147 for (CallInst *CI : MaybeFusableInsts)

1148 lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));

1149

1150

1151 for (CallInst *CI : MaybeFusableInsts)

1152 if (!FusedInsts.contains(CI))

1153 LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);

1154

1156

1157

1158

1159 for (Instruction *Inst : MatrixInsts) {

1160 if (FusedInsts.count(Inst))

1161 continue;

1162

1164 if (PHI)

1165 continue;

1166

1167 const ShapeInfo &SI = ShapeMap.at(Inst);

1169 MatrixTy PhiM(SI.NumRows, SI.NumColumns, EltTy);

1170

1172 for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI)

1173 PhiM.setVector(VI, Builder.CreatePHI(PhiM.getVectorTy(),

1174 PHI->getNumIncomingValues(),

1175 PHI->getName()));

1176 assert(!Inst2ColumnMatrix.contains(PHI) && "map already contains phi?");

1177 Inst2ColumnMatrix[PHI] = PhiM;

1178 }

1179

1180

1181 for (Instruction *Inst : MatrixInsts) {

1182 if (FusedInsts.count(Inst))

1183 continue;

1184

1185 const ShapeInfo &SI = ShapeMap.at(Inst);

1186

1192 Result = VisitBinaryOperator(BinOp, SI, Builder);

1194 Result = VisitCastInstruction(Cast, SI, Builder);

1196 Result = VisitUnaryOperator(UnOp, SI, Builder);

1198 Result = VisitIntrinsicInst(Intr, SI, Builder);

1200 Result = VisitSelectInst(Select, SI, Builder);

1206 Result = VisitPHI(PHI, SI, Builder);

1207 else

1208 continue;

1209

1210 finalizeLowering(Inst, Result, Builder);

1212 }

1213

1214 if (ORE) {

1215 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);

1216 RemarkGen.emitRemarks();

1217 }

1218

1219

1220

1221

1222

1223

1224

1225

1226

1227

1228 SmallPtrSet<Instruction *, 16> PoisonedInsts;

1229 for (auto *Inst : reverse(ToRemove)) {

1232 PoisonedInsts.insert(Poisoned);

1234 }

1236 PoisonedInsts.erase(Inst);

1237 }

1238 if (!PoisonedInsts.empty()) {

1239

1240 dbgs() << "Poisoned but present instructions:\n";

1241 for (auto *I : PoisonedInsts)

1242 dbgs() << *I << "\n";

1244 }

1245

1247 }

1248

1249

1250 MatrixTy VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &SI,

1254

1256 case Intrinsic::matrix_multiply:

1257 return LowerMultiply(Inst, Builder);

1258 case Intrinsic::matrix_transpose:

1259 return LowerTranspose(Inst, Builder);

1260 case Intrinsic::matrix_column_major_load:

1261 return LowerColumnMajorLoad(Inst, Builder);

1262 case Intrinsic::matrix_column_major_store:

1263 return LowerColumnMajorStore(Inst, Builder);

1264 case Intrinsic::abs:

1265 case Intrinsic::fabs: {

1267 MatrixTy M = getMatrix(Inst->getOperand(0), SI, Builder);

1269

1270 for (auto *Vector : M.vectors()) {

1272 case Intrinsic::abs:

1275 continue;

1276 case Intrinsic::fabs:

1279 continue;

1280 default:

1282 }

1283 }

1284

1285 return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *

1286 Result.getNumVectors());

1287 }

1288 default:

1289 break;

1290 }

1292 "only intrinsics supporting shape info should be seen here");

1293 }

1294

1295

1296

1297

1298

1299

1300 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,

1301 MaybeAlign A) const {

1302 Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);

1303 if (Idx == 0)

1304 return InitialAlign;

1305

1306 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);

1308 uint64_t StrideInBytes =

1309 ConstStride->getZExtValue() * ElementSizeInBits / 8;

1310 return commonAlignment(InitialAlign, Idx * StrideInBytes);

1311 }

1312 return commonAlignment(InitialAlign, ElementSizeInBits / 8);

1313 }

1314

1317 }

1318

1319 Value *getIndex(Value *Ptr, uint64_t V) const {

1320 return ConstantInt::get(getIndexType(Ptr), V);

1321 }

1322

1325 "Attempted to cast non-integral type to integer index");

1326

1327

1328

1330 V->getName() + ".cast");

1331 }

1332

1333

1334

1335 MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,

1336 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {

1340 Value *EltPtr = Ptr;

1342 Stride = castToIndexType(Ptr, Stride, Builder);

1343 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {

1346 Stride, Shape.getStride(), EltTy, Builder);

1348 VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),

1349 IsVolatile, "col.load");

1350

1352 }

1353 return Result.addNumLoads(getNumOps(Result.getVectorTy()) *

1354 Result.getNumVectors());

1355 }

1356

1357

1358

1359 MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,

1360 ShapeInfo MatrixShape, Value *I, Value *J,

1361 ShapeInfo ResultShape, Type *EltTy,

1364 Builder.CreateMul(J, getIndex(MatrixPtr, MatrixShape.getStride())), I);

1365

1368 ResultShape.NumColumns);

1369

1370 return loadMatrix(TileTy, TileStart, Align,

1371 getIndex(MatrixPtr, MatrixShape.getStride()), IsVolatile,

1372 ResultShape, Builder);

1373 }

1374

1375

1376 MatrixTy LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align,

1377 Value *Stride, bool IsVolatile, ShapeInfo Shape,

1379 return loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, Shape,

1380 Builder);

1381 }

1382

1383

1384

1385

1386 MatrixTy LowerColumnMajorLoad(CallInst *Inst, IRBuilder<> &Builder) {

1388 "Intrinsic only supports column-major layout!");

1393 {Inst->getArgOperand(3), Inst->getArgOperand(4)}, Builder);

1394 }
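// Illustrative form of the call lowered here, using the operand order implied
// by the getArgOperand indices above (pointer, stride, volatile flag, rows,
// columns); the overload mangling is a sketch and may differ by version:
//
//   %m = call <6 x double> @llvm.matrix.column.major.load.v6f64.i64(
//            ptr %p, i64 3, i1 false, i32 3, i32 2)
//
// This loads a 3x2 column-major matrix whose columns start 3 elements apart.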

1395

1396

1397

1398 void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,

1399 MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,

1402 Builder.CreateMul(J, getIndex(MatrixPtr, MatrixShape.getStride())), I);

1403

1406 StoreVal.getNumColumns());

1407

1408 storeMatrix(TileTy, StoreVal, TileStart, MAlign,

1409 getIndex(MatrixPtr, MatrixShape.getStride()), IsVolatile,

1410 Builder);

1411 }

1412

1413

1414

1415 MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,

1416 MaybeAlign MAlign, Value *Stride, bool IsVolatile,

1419 Value *EltPtr = Ptr;

1420 Stride = castToIndexType(Ptr, Stride, Builder);

1421 for (auto Vec : enumerate(StoreVal.vectors())) {

1423 EltPtr,

1425 Vec.index()),

1426 Stride, StoreVal.getStride(), VType->getElementType(), Builder);

1428 getAlignForIndex(Vec.index(), Stride,

1430 MAlign),

1431 IsVolatile);

1432 }

1433 return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *

1434 StoreVal.getNumVectors());

1435 }

1436

1437

1439 MaybeAlign A, Value *Stride, bool IsVolatile,

1440 ShapeInfo Shape, IRBuilder<> &Builder) {

1441 auto StoreVal = getMatrix(Matrix, Shape, Builder);

1442 return storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, IsVolatile,

1443 Builder);

1444 }

1445

1446

1447

1448

1449 MatrixTy LowerColumnMajorStore(CallInst *Inst, IRBuilder<> &Builder) {

1451 "Intrinsic only supports column-major layout!");

1457 {Inst->getArgOperand(4), Inst->getArgOperand(5)},

1458 Builder);

1459 }

1460

1461

1464

1465

1466 unsigned BlockNumElts =

1469 assert(NumElts >= BlockNumElts && "Too few elements for current block");

1470

1473

1474

1475

1476 SmallVector<int, 16> Mask;

1477 unsigned i;

1478 for (i = 0; i < I; i++)

1479 Mask.push_back(i);

1480

1481 unsigned VecNumElts =

1483 for (; i < I + BlockNumElts; i++)

1484 Mask.push_back(i - I + VecNumElts);

1485

1486 for (; i < VecNumElts; i++)

1487 Mask.push_back(i);

1488

1490 }

1491

1493 IRBuilder<> &Builder, bool AllowContraction,

1494 unsigned &NumComputeOps) {

1495 NumComputeOps += getNumOps(A->getType());

1496 if (!Sum)

1498

1499 if (UseFPOp) {

1500 if (AllowContraction) {

1501

1502

1503 return Builder.CreateIntrinsic(Intrinsic::fmuladd, A->getType(),

1504 {A, B, Sum});

1505 }

1506 NumComputeOps += getNumOps(A->getType());

1509 }

1510

1511 NumComputeOps += getNumOps(A->getType());

1514 }

1515

1516

1517

1518

1519

1520

1521 void finalizeLowering(Instruction *Inst, MatrixTy Matrix,

1523 auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));

1524 (void)inserted;

1526 "multiple matrix lowering mapping");

1527

1528 ToRemove.push_back(Inst);

1529 Value *Flattened = nullptr;

1531 if (ShapeMap.contains(U.getUser()))

1532 continue;

1533

1534 if (!Flattened) {

1535 Flattened = Matrix.embedInVector(Builder);

1538 << "flattening a " << Matrix.shape() << " matrix:\n"

1539 << *Inst

1540 << "\nbecause we do not have a shape-aware lowering for its "

1541 "user:\n"

1542 << *User << '\n';);

1543 FlattenedMatrices++;

1544 }

1545 U.set(Flattened);

1546 }

1547 }

1548

1549

1550

1551

1552 void lowerDotProduct(CallInst *MatMul,

1553 SmallPtrSet<Instruction *, 16> &FusedInsts,

1554 FastMathFlags FMF) {

1555 if (FusedInsts.contains(MatMul) ||

1557 return;

1560

1561 if (LShape.NumRows != 1 || RShape.NumColumns != 1)

1562 return;

1563

1566

1568 bool IsIntVec = ElementType->isIntegerTy();

1569

1570

1572 return;

1573

1574 auto CanBeFlattened = [](Value *Op) {

1576 return true;

1583 };

1584

1585

1586

1587 auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {

1588 if (!ShapeMap.contains(Op))

1589 return InstructionCost::getInvalid();

1590

1593

1596

1597 if (!CanBeFlattened(Op)) {

1599

1600 for (unsigned I = 1; I < N; ++I)

1601 EmbedCost += TTI.getShuffleCost(

1604 return EmbedCost;

1605 }

1606

1610 EltTy) *

1611 N;

1614 return NewCost - OriginalCost;

1615 }

1616

1618

1619

1620

1622 for (unsigned I = 1; I < N; ++I)

1623 EmbedCost -= TTI.getShuffleCost(

1626 return EmbedCost;

1627 }

1628

1629

1630 if (N == 1)

1632

1633 return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -

1634 N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);

1635 };

1636

1637

1638

1639

1640 SmallPtrSet<Value *, 4> Seen;

1645 while (!WorkList.empty()) {

1647 if (!Seen.insert(Op).second)

1648 continue;

1649

1651 if (OpCost + LHSCost >= LHSCost)

1652 continue;

1653

1654 LHSCost += OpCost;

1657 WorkList.append(I->op_begin(), I->op_end());

1658 }

1659

1660

1661 int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;

1662 int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;

1664 TTI.getArithmeticReductionCost(

1666 IsIntVec ? std::nullopt : std::optional(FMF)) +

1667 TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());

1669 TTI.getArithmeticInstrCost(AddOpCode, ElementType) *

1670 (LShape.NumColumns - 1) +

1671 TTI.getArithmeticInstrCost(MulOpCode, ElementType) *

1672 (LShape.NumColumns);

1673 if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))

1674 return;

1675

1676 FusedInsts.insert(MatMul);

1678 auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,

1680

1681

1682

1683 if (!CanBeFlattened(Op))

1684 return;

1685

1687 auto It = ShapeMap.find(Op);

1688 if (It != ShapeMap.end()) {

1689 It->second = It->second.t();

1690 return;

1691 }

1692 }

1693

1695

1699 auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);

1700 Op->replaceAllUsesWith(NewLoad);

1702 return;

1706 Op->replaceAllUsesWith(Arg);

1707 return;

1708 }

1709 };

1710

1711 for (auto *V : ToFlatten)

1712 FlattenArg(V);

1713

1715

1716

1719

1721 if (IsIntVec)

1723 else {

1725 ConstantFP::get(

1729 }

1730

1731

1733 Result, uint64_t(0));

1735 FusedInsts.insert(MatMul);

1736 ToRemove.push_back(MatMul);

1737 }

1738

1739

1740

1741

1742 unsigned capBlockSize(unsigned BlockSize, unsigned Remainder, Type *EltType) {

1745

1746

1748 if (TTI.isTypeLegal(VecTy))

1749 return Remainder;

1750

1751

1752

1754 return Remainder;

1755

1756

1757

1758 do {

1760 } while (BlockSize > Remainder);

1762 }

1763

1764

1765

1766

1767

1768

1769

1770

1771 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,

1772 const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,

1773 bool IsScalarMatrixTransposed, FastMathFlags FMF) {

1774 const unsigned VF = std::max(

1776 .getFixedValue() /

1777 Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),

1778 1U);

1779 unsigned R = Result.getNumRows();

1780 unsigned C = Result.getNumColumns();

1781 unsigned M = A.getNumColumns();

1782

1783 bool IsFP = Result.getElementType()->isFloatingPointTy();

1784 assert(A.isColumnMajor() == B.isColumnMajor() &&

1785 Result.isColumnMajor() == A.isColumnMajor() &&

1786 "operands must agree on matrix layout");

1787 unsigned NumComputeOps = 0;

1788

1790

1791 if (A.isColumnMajor()) {

1792

1793

1794

1795 for (unsigned J = 0; J < C; ++J) {

1797

1799

1801

1804 : nullptr;

1805 for (unsigned K = 0; K < M; ++K) {

1808 B.getColumn(IsScalarMatrixTransposed ? K : J),

1809 IsScalarMatrixTransposed ? J : K);

1811 Sum =

1812 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,

1813 IsFP, Builder, FMF.allowContract(), NumComputeOps);

1814 }

1817 }

1818 }

1819 } else {

1820

1821

1822

1823 for (unsigned I = 0; I < R; ++I) {

1826 for (unsigned J = 0; J < C; J += BlockSize) {

1827

1829

1830 Value *Sum = nullptr;

1831 for (unsigned K = 0; K < M; ++K) {

1834 A.getVector(IsScalarMatrixTransposed ? K : I),

1835 IsScalarMatrixTransposed ? I : K);

1837 Sum =

1838 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,

1839 IsFP, Builder, FMF.allowContract(), NumComputeOps);

1840 }

1843 }

1844 }

1845 }

1846 Result.addNumComputeOps(NumComputeOps);

1847 }
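// The loop nests above implement the textbook product
//
//   Result(i, j) = sum over k of A(i, k) * B(k, j)
//
// vectorized along result columns (column-major) or row blocks (row-major);
// createMulAdd emits mul + add/fadd, or llvm.fmuladd when contraction is
// permitted by the fast-math flags.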

1848

1849

1850

1851

1852 Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,

1853 CallInst *MatMul) {

1856

1857

1858 if (AA->isNoAlias(LoadLoc, StoreLoc))

1859 return Load->getPointerOperand();

1860

1861

1862

1863

1864

1866

1867

1868

1870 for (BasicBlock *Succ : successors(Check0))

1871 DTUpdates.push_back({DT->Delete, Check0, Succ});

1872

1875 nullptr, "alias_cont");

1878 nullptr, "copy");

1881 nullptr, "no_alias");

1882

1883

1884

1885

1891 const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");

1893 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),

1894 "store.end", true, true);

1896 IntPtrTy, "load.begin");

1898 Fusion);

1899

1900

1901

1902

1906 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),

1907 "load.end", true, true);

1909 Fusion);

1910

1911

1914

1915

1916 auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());

1917 AllocaInst *Alloca =

1918 Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());

1919

1923 PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);

1925 PHI->addIncoming(Load->getPointerOperand(), Check1);

1926 PHI->addIncoming(Alloca, Copy);

1927

1928

1929 DTUpdates.push_back({DT->Insert, Check0, Check1});

1930 DTUpdates.push_back({DT->Insert, Check0, Fusion});

1931 DTUpdates.push_back({DT->Insert, Check1, Copy});

1932 DTUpdates.push_back({DT->Insert, Check1, Fusion});

1933 DT->applyUpdates(DTUpdates);

1934 return PHI;

1935 }

1936

1937 bool isFusionProfitable(CallInst *MatMul) {

1939 return true;

1940

1943

1944 const unsigned R = LShape.NumRows;

1945 const unsigned C = RShape.NumColumns;

1946 const unsigned M = LShape.NumColumns;

1948

1949 const unsigned VF = std::max(

1951 .getFixedValue() /

1953 1U);

1954

1955

1956

1957

1958

1959

1960

1961 if (R <= VF && C == 1)

1962 return false;

1963

1964

1965

1966

1967 unsigned Op0Regs = (R + VF - 1) / VF * M;

1968 unsigned Op1Regs = (M + VF - 1) / VF * C;

1969 return Op0Regs + Op1Regs >

1970 TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));

1971 }
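// Illustrative arithmetic for the register estimate above: with
// R == C == M == 8 and VF == 4 (e.g. 128-bit registers holding four floats),
// Op0Regs == (8 + 4 - 1) / 4 * 8 == 16 and Op1Regs == 16, so the heuristic
// reports fusion as profitable exactly when those 32 vector registers exceed
// the register count returned by TTI.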

1972

1973 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {

1974 MatrixTy Res;

1976 for (unsigned I = 0; I < C; ++I)

1978 return Res;

1979 }

1980

1981 void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,

1982 Value *RPtr, ShapeInfo RShape, StoreInst *Store) {

1984

1985

1986 TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);

1987 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);

1993 BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);

1994

1995 Type *TileVecTy =

1997 MatrixTy TileResult;

1998

1999 Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());

2000

2002 for (unsigned I = 0; I < TileSize; I++) {

2003 auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));

2005 TI.RowLoop.Header->getSingleSuccessor());

2006 TileResult.addVector(Phi);

2008 }

2009

2010

2011

2013

2014 MatrixTy A =

2015 loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,

2017 MatrixTy B =

2018 loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,

2020 emitMatrixMultiply(TileResult, A, B, Builder, true, false,

2021 getFastMathFlags(MatMul));

2022

2023 Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());

2024 storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),

2025 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},

2026 TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);

2027

2028 for (unsigned I = 0; I < TileResult.getNumVectors(); I++)

2029 ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);

2030

2031

2032

2033

2034

2035 unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);

2037 "llvm.loop.unroll.count", InnerLoopUnrollCount);

2038 }

2039

2040 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,

2041 StoreInst *Store,

2042 SmallPtrSetImpl<Instruction *> &FusedInsts) {

2044 "Tiling only supported for column-major matrixes at the moment!");

2045 if (!isFusionProfitable(MatMul))

2046 return;

2047

2050

2051 const unsigned R = LShape.NumRows;

2052 const unsigned C = RShape.NumColumns;

2053 const unsigned M = LShape.NumColumns;

2055

2056 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);

2057 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);

2058 Value *CPtr = Store->getPointerOperand();

2059

2061 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);

2062 else {

2064 for (unsigned J = 0; J < C; J += TileSize)

2065 for (unsigned I = 0; I < R; I += TileSize) {

2066 const unsigned TileR = std::min(R - I, unsigned(TileSize));

2067 const unsigned TileC = std::min(C - J, unsigned(TileSize));

2068 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);

2069

2070 for (unsigned K = 0; K < M; K += TileSize) {

2071 const unsigned TileM = std::min(M - K, unsigned(TileSize));

2072 MatrixTy A =

2074 LShape, getIndex(APtr, I), getIndex(APtr, K),

2075 {TileR, TileM}, EltType, Builder);

2076 MatrixTy B =

2078 RShape, getIndex(BPtr, K), getIndex(BPtr, J),

2079 {TileM, TileC}, EltType, Builder);

2080 emitMatrixMultiply(Res, A, B, Builder, true, false,

2081 getFastMathFlags(MatMul));

2082 }

2083 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},

2084 getIndex(CPtr, I), getIndex(CPtr, J), EltType, Builder);

2085 }

2086 }

2087

2088

2089 FusedInsts.insert(Store);

2090 FusedInsts.insert(MatMul);

2091 eraseFromParentAndRemoveFromShapeMap(Store);

2092 eraseFromParentAndRemoveFromShapeMap(MatMul);

2094 FusedInsts.insert(LoadOp0);

2095 eraseFromParentAndRemoveFromShapeMap(LoadOp0);

2096 }

2097 if (LoadOp1 != LoadOp0 && LoadOp1->use_empty()) {

2098 FusedInsts.insert(LoadOp1);

2099 eraseFromParentAndRemoveFromShapeMap(LoadOp1);

2100 }

2101 }
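// Conceptual sketch of the straight-line tiling emitted above when loop
// generation is disabled (TileSize defaults to 4 per the flag near the top of
// the file):
//
//   for (J = 0; J < C; J += TileSize)
//     for (I = 0; I < R; I += TileSize) {
//       Res = zero tile;
//       for (K = 0; K < M; K += TileSize)
//         Res += A(I.., K..) * B(K.., J..);  // TileR x TileM times TileM x TileC
//       store Res to the destination tile at (I, J);
//     }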

2102

2103

2104

2105

2106

2107 void

2108 LowerMatrixMultiplyFused(CallInst *MatMul,

2109 SmallPtrSetImpl<Instruction *> &FusedInsts,

2112 return;

2113

2114 assert(AA && LI && "Analyses should be available");

2115

2118

2119

2125 auto *EltType =

2129 const unsigned R = LShape.NumRows;

2130 const unsigned M = LShape.NumColumns;

2131 const unsigned C = RShape.NumColumns;

2132

2133 MatrixTy MA;

2134 MatrixTy MB;

2135

2136 Value *Transpose;

2138 MA = getMatrix(A, ShapeInfo(R, M), Builder);

2139 MB = getMatrix(T, ShapeInfo(C, M), Builder);

2140 Transpose = B;

2141 } else {

2142 MA = getMatrix(T, ShapeInfo(R, M), Builder);

2143 MB = getMatrix(B, ShapeInfo(C, M), Builder);

2144 Transpose = A;

2145 }

2146

2147

2148 MatrixTy Result(R, C, EltType);

2149

2150 emitMatrixMultiply(Result, MA, MB, Builder, false, true,

2151 getFastMathFlags(MatMul));

2152

2153 FusedInsts.insert(MatMul);

2157

2158

2159 Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);

2160 }

2161 finalizeLowering(MatMul, Result, Builder);

2162 return;

2163 }

2164

2166 return;

2167

2168

2169

2173 if (LoadOp0 && LoadOp1 && Store) {

2174

2175

2176 SetVector<Value *> WorkList;

2179 for (unsigned I = 0; I != WorkList.size(); ++I) {

2180 Value *Current = WorkList[I];

2182 if (!CurrI)

2183 continue;

2185 return;

2186 if (DT->dominates(CurrI, MatMul))

2187 continue;

2188 if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())

2189 return;

2192 }

2193

2194 sort(ToHoist, [this](Instruction *A, Instruction *B) {

2195 return DT->dominates(A, B);

2196 });

2197 for (Instruction *I : ToHoist)

2199

2200

2201

2202

2203

2204

2205

2206

2210 bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent &&

2211 LoadOp1->getParent() == StoreParent;

2212 for (unsigned Idx = 0; Idx != LifetimeEnds.size();) {

2213 IntrinsicInst *End = LifetimeEnds[Idx];

2215

2216

2217 if (DT->dominates(End, LoadOp0) && DT->dominates(End, LoadOp1))

2218 continue;

2219 if (DT->dominates(Store, End))

2220 continue;

2221

2222

2223 if (FusableOpsInSameBlock && End->getParent() != StoreParent)

2224 continue;

2225

2226

2227

2229 if (!EndLoc.Ptr)

2230 continue;

2231 if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc))

2232 continue;

2233

2234

2235

2236

2237 if (End->getParent() == StoreParent) {

2239 continue;

2240 }

2241

2242

2243 ToRemove.push_back(End);

2244 std::swap(LifetimeEnds[Idx], LifetimeEnds.back());

2246 Inc.release();

2247 }

2248

2249 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);

2250 return;

2251 }

2252 }

2253

2254

2255 MatrixTy LowerMultiply(CallInst *MatMul, IRBuilder<> &Builder) {

2259

2260 const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);

2261 const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);

2262 assert(Lhs.getElementType() == Rhs.getElementType() &&

2263 "Matrix multiply argument element types do not match.");

2264

2265 const unsigned R = LShape.NumRows;

2266 const unsigned C = RShape.NumColumns;

2267 assert(LShape.NumColumns == RShape.NumRows);

2268

2269

2270 MatrixTy Result(R, C, EltType);

2271 assert(Lhs.getElementType() == Result.getElementType() &&

2272 "Matrix multiply result element type does not match arguments.");

2273

2274 emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,

2275 getFastMathFlags(MatMul));

2277 }

2278

2279

2280 MatrixTy LowerTranspose(CallInst *Inst, IRBuilder<> &Builder) {

2285 MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

2286

2287 const unsigned NewNumVecs =

2288 InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;

2289 const unsigned NewNumElts =

2290 InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;

2291

2292 for (unsigned I = 0; I < NewNumVecs; ++I) {

2293

2296

2297 for (auto J : enumerate(InputMatrix.vectors())) {

2299

2300 ResultVector =

2302 }

2303 Result.addVector(ResultVector);

2304 }

2305

2306

2307

2308

2309 return Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)

2310 .addNumExposedTransposes(1);

2311 }

2312

2313

2314 MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,

2316 return LowerLoad(Inst, Ptr, Inst->getAlign(), getIndex(Ptr, SI.getStride()),

2318 }

2319

2320 MatrixTy VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,

2323 getIndex(Ptr, SI.getStride()), Inst->isVolatile(), SI,

2324 Builder);

2325 }

2326

2327 MatrixTy VisitPHI(PHINode *Inst, const ShapeInfo &SI, IRBuilder<> &Builder) {

2328 auto BlockIP = Inst->getParent()->getFirstInsertionPt();

2330 MatrixTy PhiM = getMatrix(Inst, SI, Builder);

2331

2332 for (auto [IncomingV, IncomingB] :

2334

2335

2336

2339 if (auto MaybeIP = IncomingInst->getInsertionPointAfterDef())

2341

2342 MatrixTy OpM = getMatrix(IncomingV, SI, Builder);

2343

2344 for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) {

2345 PHINode *NewPHI = cast(PhiM.getVector(VI));

2346 NewPHI->addIncoming(OpM.getVector(VI), IncomingB);

2347 }

2348 }

2349

2350

2351

2353 return PhiM;

2354 }

2355

2356

2357 MatrixTy VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI,

2361

2363 MatrixTy A = getMatrix(Lhs, SI, Builder);

2364 MatrixTy B = getMatrix(Rhs, SI, Builder);

2365 assert(A.isColumnMajor() == B.isColumnMajor() &&

2366 Result.isColumnMajor() == A.isColumnMajor() &&

2367 "operands must agree on matrix layout");

2368

2370

2373

2374 return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *

2375 Result.getNumVectors());

2376 }

2377

2378

2379 MatrixTy VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI,

2382

2384 MatrixTy M = getMatrix(Op, SI, Builder);

2385

2387

2388

2389 auto BuildVectorOp = [&Builder, Inst](Value *Op) {

2391 case Instruction::FNeg:

2393 default:

2395 }

2396 };

2397

2398 for (auto *Vector : M.vectors())

2400

2401 return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *

2402 Result.getNumVectors());

2403 }

2404

2405

2406 MatrixTy VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape,

2409

2411 MatrixTy M = getMatrix(Op, Shape, Builder);

2412

2414

2416 auto *NewVTy = VectorType::get(OrigVTy->getElementType(),

2418

2419 for (auto *Vector : M.vectors())

2421

2422 return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *

2423 Result.getNumVectors());

2424 }

2425

2426

2427 MatrixTy VisitSelectInst(SelectInst *Inst, const ShapeInfo &Shape,

2432

2434 MatrixTy A = getMatrix(OpA, Shape, Builder);

2435 MatrixTy B = getMatrix(OpB, Shape, Builder);

2436

2439 MatrixTy C = getMatrix(Cond, Shape, Builder);

2440 llvm::copy(C.vectors(), std::back_inserter(CondV));

2441 } else {

2442 CondV.resize(A.getNumVectors());

2444 }

2445

2446 for (auto [CV, AV, BV] : llvm::zip_equal(CondV, A.vectors(), B.vectors()))

2448

2449 return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *

2450 Result.getNumVectors());

2451 }

2452

2453

2454

2455

2456 struct ExprLinearizer {

2457 unsigned LengthToBreak = 100;

2458 std::string Str;

2459 raw_string_ostream Stream;

2460 unsigned LineLength = 0;

2461 const DataLayout &DL;

2462

2463

2464

2465 const MapVector<Value *, MatrixTy> &Inst2Matrix;

2466

2467

2468

2469 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

2470

2471

2472 const SmallSetVector<Value *, 32> &ExprsInSubprogram;

2473

2474

2476

2477

2478

2479 SmallPtrSet<Value *, 8> ReusedExprs;

2480

2481 ExprLinearizer(const DataLayout &DL,

2482 const MapVector<Value *, MatrixTy> &Inst2Matrix,

2483 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,

2484 const SmallSetVector<Value *, 32> &ExprsInSubprogram,

2486 : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),

2487 ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}

2488

2489 void indent(unsigned N) {

2490 LineLength += N;

2491 for (unsigned i = 0; i < N; i++)

2492 Stream << " ";

2493 }

2494

2495 void lineBreak() {

2496 Stream << "\n";

2497 LineLength = 0;

2498 }

2499

2500 void maybeIndent(unsigned Indent) {

2501 if (LineLength >= LengthToBreak)

2502 lineBreak();

2503

2504 if (LineLength == 0)

2505 indent(Indent);

2506 }

2507

2508 void write(StringRef S) {

2509 LineLength += S.size();

2510 Stream << S;

2511 }

2512

2513 Value *getUnderlyingObjectThroughLoads(Value *V) {

2515 return getUnderlyingObjectThroughLoads(Ptr);

2516 else if (V->getType()->isPointerTy())

2518 return V;

2519 }

2520

2521

2522 bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }

2523

2524

2525

2526 void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {

2527 auto M = Inst2Matrix.find(V);

2528 if (M == Inst2Matrix.end())

2529 SS << "unknown";

2530 else {

2531 SS << M->second.getNumRows();

2532 SS << "x";

2533 SS << M->second.getNumColumns();

2534 }

2535 }

2536

2537

2538

2539

2540 void writeFnName(CallInst *CI) {

2542 write("");

2543 else {

2545 if (Name.starts_with("llvm.matrix")) {

2547 return;

2548 }

2553 std::string Tmp;

2554 raw_string_ostream SS(Tmp);

2555

2556 switch (II->getIntrinsicID()) {

2557 case Intrinsic::matrix_multiply:

2558 prettyPrintMatrixType(II->getOperand(0), SS);

2559 SS << ".";

2560 prettyPrintMatrixType(II->getOperand(1), SS);

2561 SS << "." << *II->getType()->getScalarType();

2562 break;

2563 case Intrinsic::matrix_transpose:

2564 prettyPrintMatrixType(II->getOperand(0), SS);

2565 SS << "." << *II->getType()->getScalarType();

2566 break;

2567 case Intrinsic::matrix_column_major_load:

2568 prettyPrintMatrixType(II, SS);

2569 SS << "." << *II->getType()->getScalarType();

2570 break;

2571 case Intrinsic::matrix_column_major_store:

2572 prettyPrintMatrixType(II->getOperand(0), SS);

2573 SS << "." << *II->getOperand(0)->getType()->getScalarType();

2574 break;

2575 default:

2577 }

2579 }

2580 }

2581

2582 unsigned getNumShapeArgs(CallInst *CI) const {

2584 switch (II->getIntrinsicID()) {

2585 case Intrinsic::matrix_multiply:

2586 return 3;

2587 case Intrinsic::matrix_transpose:

2588 return 2;

2589 case Intrinsic::matrix_column_major_load:

2590 case Intrinsic::matrix_column_major_store:

2591 return 3;

2592 default:

2593 return 0;

2594 }

2595 }

2596 return 0;

2597 }

2598

2599

2600

2601

2603 V = getUnderlyingObjectThroughLoads(V);

2604 if (V->getType()->isPointerTy()) {

2606 Stream << "stack addr";

2607 LineLength += StringRef("stack addr").size();

2608 } else {

2609 Stream << "addr";

2610 LineLength += StringRef("addr").size();

2611 }

2612 if (V->getName().empty()) {

2613 Stream << " %" << V->getName() << "";

2614 LineLength += V->getName().size() + 2;

2615 }

2616 return;

2617 }

2618

2619 std::string Tmp;

2620 raw_string_ostream TmpStream(Tmp);

2621

2623 TmpStream << CI->getValue();

2625 TmpStream << "constant";

2626 else {

2627 if (isMatrix(V))

2628 TmpStream << "matrix";

2629 else

2630 TmpStream << "scalar";

2631 }

2632 Tmp = std::string(StringRef(Tmp).trim());

2633 LineLength += Tmp.size();

2634 Stream << Tmp;

2635 }

2636

2637

2638

2639

2640 void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,

2641 bool ParentShared) {

2643 maybeIndent(Indent);

2644 SmallVector<Value *, 8> Ops;

2645

2646

2647 bool ExprShared = false;

2648

2649

2650 if (!ParentShared) {

2651 auto SI = Shared.find(Expr);

2652 assert(SI != Shared.end() && SI->second.count(Leaf));

2653

2654 for (Value *S : SI->second) {

2655 if (S == Leaf)

2656 continue;

2658 write("shared with remark at line " + std::to_string(DL.getLine()) +

2659 " column " + std::to_string(DL.getCol()) + " (");

2660 }

2661 ExprShared = SI->second.size() > 1;

2662 }

2663

2664 bool Reused = !ReusedExprs.insert(Expr).second;

2665 if (Reused && !ParentReused)

2666 write("(reused) ");

2667

2669 writeFnName(CI);

2670

2673

2674

2675 write("matrix");

2676 return;

2677 } else {

2678 Ops.append(I->value_op_begin(), I->value_op_end());

2679 write(I->getOpcodeName());

2680 }

2681

2683

2684 unsigned NumOpsToBreak = 1;

2686 NumOpsToBreak = 2;

2687

2689 if (Ops.size() > NumOpsToBreak)

2690 lineBreak();

2691

2692 maybeIndent(Indent + 1);

2693 if (isMatrix(Op))

2694 linearizeExpr(Op, Indent + 1, Reused, ExprShared);

2695 else

2697 if (Op != Ops.back())

2699 }

2700

2702 }

2703

2704 const std::string &getResult() {

2705 return Str;

2706 }

2707 };

2708

2709

2710

2711

2712

2713

2714

2715

2716

2717

2718

2719

2720

2721

2722 struct RemarkGenerator {

2723 const MapVector<Value *, MatrixTy> &Inst2Matrix;

2724 OptimizationRemarkEmitter &ORE;

2726 const DataLayout &DL;

2727

2728 RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,

2729 OptimizationRemarkEmitter &ORE, Function &Func)

2730 : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),

2731 DL(Func.getDataLayout()) {}

2732

2733

2734

2735

2737 getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {

2739 for (auto *Expr : ExprsInSubprogram)

2741 any\_of(Expr->users(), [&ExprsInSubprogram](User *U) {

2742 return ExprsInSubprogram.count(U);

2743 }))

2745 return Leaves;

2746 }

2747

2748

2749

2750

2751 void collectSharedInfo(Value *Leaf, Value *V,

2752 const SmallSetVector<Value *, 32> &ExprsInSubprogram,

2753 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {

2754

2755 if (!ExprsInSubprogram.count(V))

2756 return;

2757

2759

2761 collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);

2762 }

2763

2764

2765

2766

2767 std::pair<OpInfoTy, OpInfoTy>

2768 sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,

2769 const SmallSetVector<Value *, 32> &ExprsInSubprogram,

2770 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {

2771 if (!ExprsInSubprogram.count(Root))

2772 return {};

2773

2774

2775 if (!ReusedExprs.insert(Root).second)

2776 return {};

2777

2778 OpInfoTy SharedCount;

2780

2781 auto I = Shared.find(Root);

2782 auto CM = Inst2Matrix.find(Root);

2783 if (I->second.size() == 1)

2784 Count = CM->second.getOpInfo();

2785 else

2786 SharedCount = CM->second.getOpInfo();

2787

2789 auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);

2791 SharedCount += C.second;

2792 }

2793 return {Count, SharedCount};

2794 }

2795

2796 void emitRemarks() {

2797 if (!ORE.allowExtraAnalysis(DEBUG_TYPE))

2798 return;

2799

2800

2801

2802

2803 MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;

2804 for (const auto &KV : Inst2Matrix) {

2805 if (Func.getSubprogram()) {

2807 DILocation *Context = I->getDebugLoc();

2810 KV.first);

2812 }

2813 } else {

2814 Subprog2Exprs[nullptr].push_back(KV.first);

2815 }

2816 }

2817 for (auto &KV : Subprog2Exprs) {

2818 SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),

2819 KV.second.end());

2820 auto Leaves = getExpressionLeaves(ExprsInSubprogram);

2821

2822 DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;

2823 for (Value *Leaf : Leaves)

2824 collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

2825

2826

2827 for (auto *L : Leaves) {

2828

2834 break;

2835 }

2837 }

2838

2839 SmallPtrSet<Value *, 8> ReusedExprs;

2840 OpInfoTy Counts, SharedCounts;

2841 std::tie(Counts, SharedCounts) =

2842 sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

2843

2844 OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,

2846

2847 Rem << "Lowered with ";

2848 Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "

2849 << ore::NV("NumLoads", Counts.NumLoads) << " loads, "

2850 << ore::NV("NumComputeOps", Counts.NumComputeOps)

2851 << " compute ops, "

2852 << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)

2853 << " exposed transposes";

2854

2855 if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||

2856 SharedCounts.NumComputeOps > 0) {

2857 Rem << ",\nadditionally "

2858 << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "

2859 << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "

2860 << ore::NV("NumFPOps", SharedCounts.NumComputeOps)

2861 << " compute ops"

2862 << " are shared with other expressions";

2863 }

2864

2865 Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));

2866 ORE.emit(Rem);

2867 }

2868 }

2869 }

2870

2871 std::string

2872 linearize(Value *L,

2873 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,

2874 const SmallSetVector<Value *, 32> &ExprsInSubprogram,

2875 const DataLayout &DL) {

2876 ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);

2877 Lin.linearizeExpr(L, 0, false, false);

2878 return Lin.getResult();

2879 }

2880 };

2881};

2882}

2883

2887

2888 LowerMatrixIntrinsics LMT(F, TTI, Minimal ? nullptr : &AM);

2889 if (LMT.Visit()) {

2891 if (!Minimal) {

2894 }

2895 return PA;

2896 }

2898}

2899

2903 OS, MapClassName2PassName);

2904 OS << '<';

2905 if (Minimal)

2906 OS << "minimal";

2907 OS << '>';

2908}


const SmallVectorImpl< MachineOperand > & Cond

static Value * extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex, unsigned EndIndex, const Twine &Name)

static Value * insertVector(IRBuilderTy &IRB, Value *Old, Value *V, unsigned BeginIndex, const Twine &Name)

This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...

This file defines the SmallVector class.

This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...

#define STATISTIC(VARNAME, DESC)

static SymbolRef::Type getType(const Symbol *Sym)

static const int BlockSize

This pass exposes codegen information to IR-level passes.

static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)

Returns the opcode of Values or ~0 if they do not all agree.

static SDValue LowerStore(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)

static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)

Align getAlign() const

Return the alignment of the memory that is being allocated by the instruction.

PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)

Get the result of an analysis pass for a given IR unit.

iterator begin()

Instruction iterator methods.

const Function * getParent() const

Return the enclosing method, or null if none.

reverse_iterator rbegin()

InstListType::reverse_iterator reverse_iterator

const Instruction * getTerminator() const LLVM_READONLY

Returns the terminator instruction if the block is well formed or null if the block is not well forme...

BinaryOps getOpcode() const

Function * getCalledFunction() const

Returns the function called, or null if this is an indirect function invocation or the function signa...

User::op_iterator arg_begin()

Return the iterator pointing to the beginning of the argument list.

MaybeAlign getParamAlign(unsigned ArgNo) const

Extract the alignment for a call or parameter (0=unknown).

Value * getArgOperand(unsigned i) const

User::op_iterator arg_end()

Return the iterator pointing to the end of the argument list.

Instruction::CastOps getOpcode() const

Return the opcode of this CastInst.

static LLVM_ABI ConstantAggregateZero * get(Type *Ty)

LLVM_ABI DISubprogram * getSubprogram() const

Get the subprogram for this scope.

Base class for scope-like contexts.

Subprogram description. Uses SubclassData1.

iterator find(const_arg_type_t< KeyT > Val)

Analysis pass which computes a DominatorTree.

static constexpr ElementCount getFixed(ScalarTy MinVal)

void setAllowContract(bool B=true)

bool allowReassoc() const

Flag queries.

bool allowContract() const

unsigned getNumElements() const

static LLVM_ABI FixedVectorType * get(Type *ElementType, unsigned NumElts)

Intrinsic::ID getIntrinsicID() const LLVM_READONLY

getIntrinsicID - This method returns the ID number of the specified function, or Intrinsic::not_intri...

bool isIntrinsic() const

isIntrinsic - Returns true if the function's name starts with "llvm.".

LLVM_ABI CallInst * CreateFAddReduce(Value *Acc, Value *Src)

Create a sequential vector fadd reduction intrinsic of the source vector.

Value * CreateICmpULT(Value *LHS, Value *RHS, const Twine &Name="")

Value * CreateInsertElement(Type *VecTy, Value *NewElt, Value *Idx, const Twine &Name="")

AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")

Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")

LoadInst * CreateAlignedLoad(Type *Ty, Value *Ptr, MaybeAlign Align, const char *Name)

Value * CreateZExtOrTrunc(Value *V, Type *DestTy, const Twine &Name="")

Create a ZExt or Trunc from the integer value V to DestTy.

CallInst * CreateMemCpy(Value *Dst, MaybeAlign DstAlign, Value *Src, MaybeAlign SrcAlign, uint64_t Size, bool isVolatile=false, const AAMDNodes &AAInfo=AAMDNodes())

Create and insert a memcpy between the specified pointers.

Value * CreateFAdd(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)

LLVM_ABI Value * CreateVectorSplat(unsigned NumElts, Value *V, const Twine &Name="")

Return a vector value that contains.

LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)

LLVM_ABI CallInst * CreateAddReduce(Value *Src)

Create a vector int add reduction intrinsic of the source vector.

IntegerType * getIntPtrTy(const DataLayout &DL, unsigned AddrSpace=0)

Fetch the type of an integer with size at least as big as that of a pointer in the given address spac...

Value * CreateCast(Instruction::CastOps Op, Value *V, Type *DestTy, const Twine &Name="", MDNode *FPMathTag=nullptr, FMFSource FMFSource={})

void setFastMathFlags(FastMathFlags NewFMF)

Set the fast-math flags to be used with generated fp-math operators.

Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())

LLVM_ABI Value * CreateBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, Value *RHS, FMFSource FMFSource={}, const Twine &Name="")

Create a call to intrinsic ID with 2 operands which is mangled on the first type.

LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")

Create a call to intrinsic ID with Args, mangled using Types.

PHINode * CreatePHI(Type *Ty, unsigned NumReservedValues, const Twine &Name="")

ConstantInt * getIntN(unsigned N, uint64_t C)

Get a constant N-bit value, zero extended or truncated from a 64-bit value.

BranchInst * CreateCondBr(Value *Cond, BasicBlock *True, BasicBlock *False, MDNode *BranchWeights=nullptr, MDNode *Unpredictable=nullptr)

Create a conditional 'br Cond, TrueDest, FalseDest' instruction.

LLVM_ABI CallInst * CreateUnaryIntrinsic(Intrinsic::ID ID, Value *V, FMFSource FMFSource={}, const Twine &Name="")

Create a call to intrinsic ID with 1 operand which is mangled on its type.

LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)

Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...

Value * CreateShuffleVector(Value *V1, Value *V2, Value *Mask, const Twine &Name="")

Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)

Value * CreatePtrToInt(Value *V, Type *DestTy, const Twine &Name="")

Value * CreateBinOp(Instruction::BinaryOps Opc, Value *LHS, Value *RHS, const Twine &Name="", MDNode *FPMathTag=nullptr)

void SetInsertPoint(BasicBlock *TheBB)

This specifies that created instructions should be appended to the end of the specified block.

StoreInst * CreateAlignedStore(Value *Val, Value *Ptr, MaybeAlign Align, bool isVolatile=false)

Value * CreateFMul(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)

Value * CreateFNeg(Value *V, const Twine &Name="", MDNode *FPMathTag=nullptr)

Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)

This provides a uniform API for creating instructions and inserting them into a basic block: either a...

LLVM_ABI void moveAfter(Instruction *MovePos)

Unlink this instruction from its current basic block and insert it into the basic block that MovePos ...

LLVM_ABI void setFastMathFlags(FastMathFlags FMF)

Convenience function for setting multiple fast-math flags on this instruction, which must be an opera...

LLVM_ABI InstListType::iterator eraseFromParent()

This method unlinks 'this' from the containing basic block and deletes it.

LLVM_ABI FastMathFlags getFastMathFlags() const LLVM_READONLY

Convenience function for getting all the fast-math flags, which must be an operator which supports th...

Intrinsic::ID getIntrinsicID() const

Return the intrinsic ID of this intrinsic.

bool isVolatile() const

Return true if this is a load from a volatile memory location.

Align getAlign() const

Return the alignment of the access that is being performed.

TypeSize getValue() const

Analysis pass that exposes the LoopInfo for a function.

PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)

Definition LowerMatrixIntrinsics.cpp:2884

void printPipeline(raw_ostream &OS, function_ref< StringRef(StringRef)> MapClassName2PassName)

Definition LowerMatrixIntrinsics.cpp:2900

CallInst * CreateMatrixTranspose(Value *Matrix, unsigned Rows, unsigned Columns, const Twine &Name="")

Create a llvm.matrix.transpose call, transposing Matrix with Rows rows and Columns columns.

CallInst * CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, unsigned LHSColumns, unsigned RHSColumns, const Twine &Name="")

Create a llvm.matrix.multiply call, multiplying matrixes LHS and RHS.

static LLVM_ABI MemoryLocation get(const LoadInst *LI)

Return a location with information about the memory reference by the given instruction.

LocationSize Size

The maximum size of the location, in address-units, or UnknownSize if the size is not known.

const Value * Ptr

The address of the start of the location.

static LLVM_ABI MemoryLocation getForArgument(const CallBase *Call, unsigned ArgIdx, const TargetLibraryInfo *TLI)

Return a location representing a particular argument of a call.

void addIncoming(Value *V, BasicBlock *BB)

Add an incoming value to the end of the PHI list.

iterator_range< const_block_iterator > blocks() const

op_range incoming_values()

static LLVM_ABI PoisonValue * get(Type *T)

Static factory methods - Return an 'poison' object of the specified type.

A set of analyses that are preserved following a run of a transformation pass.

static PreservedAnalyses all()

Construct a special preserved set that preserves all passes.

PreservedAnalyses & preserve()

Mark an analysis as preserved.

size_type size() const

Determine the number of elements in the SetVector.

void insert_range(Range &&R)

size_type count(const_arg_type key) const

Count the number of elements of a given key in the SetVector.

bool insert(const value_type &X)

Insert a new element into the SetVector.

bool erase(PtrType Ptr)

Remove pointer from the set.

size_type count(ConstPtrType Ptr) const

count - Return 1 if the specified pointer is in the set, 0 otherwise.

std::pair< iterator, bool > insert(PtrType Ptr)

Inserts Ptr if and only if there is no element in the container equal to Ptr.

bool contains(ConstPtrType Ptr) const

void append(ItTy in_start, ItTy in_end)

Add the specified range to the end of the SmallVector.

void push_back(const T &Elt)

bool isVolatile() const

Return true if this is a store to a volatile memory location.

StringRef - Represent a constant reference to a string, i.e.

StringRef drop_front(size_t N=1) const

Return a StringRef equal to 'this' but with the first N elements dropped.

constexpr size_t size() const

size - Get the string size.

Analysis pass providing the TargetTransformInfo.

@ TCK_RecipThroughput

Reciprocal throughput.

@ SK_Splice

Concatenates elements from the first input vector with elements of the second input vector.

The instances of the Type class are immutable: once they are created, they are never changed.

Type * getScalarType() const

If this is a vector type, return the element type, otherwise return 'this'.

LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY

Return the basic size of this type if it is a primitive type.

LLVM_ABI unsigned getScalarSizeInBits() const LLVM_READONLY

If this is a vector type, return the getPrimitiveSizeInBits value for the element type.

bool isVoidTy() const

Return true if this is 'void'.

UnaryOps getOpcode() const

Value * getOperand(unsigned i) const

LLVM Value Representation.

Type * getType() const

All values are typed, get the type of this value.

user_iterator user_begin()

bool hasOneUse() const

Return true if there is exactly one use of this value.

LLVM_ABI void replaceAllUsesWith(Value *V)

Change all uses of this to point to a new Value.

iterator_range< user_iterator > users()

iterator_range< use_iterator > uses()

LLVM_ABI StringRef getName() const

Return a constant reference to the value's name.

Type * getElementType() const

constexpr ScalarTy getFixedValue() const

An efficient, type-erasing, non-owning reference to a callable.

const ParentTy * getParent() const

self_iterator getIterator()

A range adaptor for a pair of iterators.

This class implements an extremely fast bulk output stream that can only output to a stream.

#define llvm_unreachable(msg)

Marks that the current location is not supposed to be reachable.

constexpr char Align[]

Key for Kernel::Arg::Metadata::mAlign.

constexpr std::underlying_type_t< E > Mask()

Get a bitmask with 1s in all places up to the high-order bit of E's largest value.

@ C

The default llvm calling convention, compatible with C.

@ BasicBlock

Various leaf nodes.

LLVM_ABI StringRef getBaseName(ID id)

Return the LLVM name for an intrinsic, without encoded types for overloading, such as "llvm....

OneUse_match< SubPat > m_OneUse(const SubPat &SP)

TwoOps_match< ValueOpTy, PointerOpTy, Instruction::Store > m_Store(const ValueOpTy &ValueOp, const PointerOpTy &PointerOp)

Matches StoreInst.

BinaryOp_match< LHS, RHS, Instruction::Add > m_Add(const LHS &L, const RHS &R)

class_match< BinaryOperator > m_BinOp()

Match an arbitrary binary operation and ignore it.

specific_intval< false > m_SpecificInt(const APInt &V)

Match a specific integer value or vector with all elements equal to the value.

BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)

bool match(Val *V, const Pattern &P)

specificval_ty m_Specific(const Value *V)

Match if we have a specific specified value.

class_match< ConstantInt > m_ConstantInt()

Match an arbitrary ConstantInt and ignore it.

IntrinsicID_match m_Intrinsic()

Match intrinsic calls like this: m_IntrinsicIntrinsic::fabs(m_Value(X))

BinaryOp_match< LHS, RHS, Instruction::FAdd > m_FAdd(const LHS &L, const RHS &R)

BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)

OneOps_match< OpTy, Instruction::Load > m_Load(const OpTy &Op)

Matches LoadInst.

class_match< Value > m_Value()

Match an arbitrary value and ignore it.

match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)

Combine two pattern matchers matching L || R.

ValuesClass values(OptsTy... Options)

Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...

initializer< Ty > init(const Ty &Val)

ElementType

The element type of an SRV or UAV resource.

DiagnosticInfoOptimizationBase::Argument NV

NodeAddr< PhiNode * > Phi

friend class Instruction

Iterator for Instructions in a `BasicBlock.

This is an optimization pass for GlobalISel generic memory operations.

auto drop_begin(T &&RangeOrContainer, size_t N=1)

Return a range covering RangeOrContainer with the first N elements excluded.

void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)

FunctionAddr VTableAddr Value

void fill(R &&Range, T &&Value)

Provide wrappers to std::fill which take ranges instead of having to pass begin/end explicitly.

auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)

Get the size of a range.

detail::zippy< detail::zip_first, T, U, Args... > zip_equal(T &&t, U &&u, Args &&...args)

zip iterator that assumes that all iteratees have the same length.

detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)

auto enumerate(FirstRange &&First, RestRanges &&...Rest)

Given two or more input ranges, returns a new range whose values are tuples (A, B,...

decltype(auto) dyn_cast(const From &Val)

dyn_cast - Return the argument parameter cast to the specified type.

auto successors(const MachineBasicBlock *BB)

bool operator!=(uint64_t V1, const APInt &V2)

iterator_range< T > make_range(T x, T y)

Convenience function for iterating over sub-ranges.

LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt & operator+=(DynamicAPInt &A, int64_t B)

iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)

Make a range that does early increment to allow mutation of the underlying range without disrupting i...

LLVM_ABI Value * concatenateVectors(IRBuilderBase &Builder, ArrayRef< Value * > Vecs)

Concatenate a list of vectors.

bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)

const Value * getPointerOperand(const Value *V)

A helper function that returns the pointer operand of a load, store or GEP instruction.

LLVM_ABI void addStringMetadataToLoop(Loop *TheLoop, const char *MDString, unsigned V=0)

Set input string into loop metadata by keeping other values intact.

bool any_of(R &&range, UnaryPredicate P)

Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.

auto reverse(ContainerTy &&C)

void sort(IteratorTy Start, IteratorTy End)

LLVM_ABI raw_ostream & dbgs()

dbgs() - This returns a reference to a raw_ostream for debugging messages.

LLVM_ABI void report_fatal_error(Error Err, bool gen_crash_diag=true)

FunctionAddr VTableAddr Count

class LLVM_GSL_OWNER SmallVector

Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...

bool isa(const From &Val)

isa - Return true if the parameter to the template is an instance of one of the template type argu...

LLVM_ABI raw_fd_ostream & errs()

This returns a reference to a raw_ostream for standard error.

IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >

@ Mul

Product of integers.

DWARFExpression::Operation Op

raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)

ArrayRef(const T &OneElt) -> ArrayRef< T >

OutputIt copy(R &&Range, OutputIt Out)

LLVM_ABI Error write(MCStreamer &Out, ArrayRef< std::string > Inputs, OnCuIndexOverflow OverflowOptValue, Dwarf64StrOffsetsPromotion StrOffsetsOptValue)

decltype(auto) cast(const From &Val)

cast - Return the argument parameter cast to the specified type.

LLVM_ABI BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)

Split the specified block at the specified instruction.

Align commonAlignment(Align A, uint64_t Offset)

Returns the alignment that satisfies both alignments.

AnalysisManager< Function > FunctionAnalysisManager

Convenience typedef for the Function analysis manager.

LLVM_ABI const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=MaxLookupSearchDepth)

This method strips off any GEP address adjustments, pointer casts or llvm.threadlocal....

AAResults AliasAnalysis

Temporary typedef for legacy code that uses a generic AliasAnalysis pointer or reference.

LLVM_ABI llvm::SmallVector< int, 16 > createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs)

Create a sequential shuffle mask.

void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)

Implement std::swap in terms of BitVector swap.

A CRTP mix-in to automatically provide informational APIs needed for passes.