LLVM: lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

47

48#include

49

50using namespace llvm;

51using namespace PatternMatch;

52

53#define DEBUG_TYPE "lower-matrix-intrinsics"

54

57 cl::desc("Enable/disable fusing matrix instructions."));

58

62 "Tile size for matrix instruction fusion using square-shaped tiles."));

65 cl::desc("Generate loop nest for tiling."));

68 cl::desc("Force matrix instruction fusion even if not profitable."));

71 cl::desc("Allow the use of FMAs if available and profitable. This may "

72 "result in different results, due to less rounding error."));

73

76 cl::desc("Enable/disable matrix shape verification."),

78

80

83 cl::desc("Sets the default matrix layout"),

85 "Use column-major layout"),

87 "Use row-major layout")));

88

91

92

93

95 if (auto *Subprogram = dyn_cast(Scope))

96 return Subprogram;

97 return cast(Scope)->getSubprogram();

98}

99

100

101

103 if (auto *SV = dyn_cast(V))

104 return SV->isZeroEltSplat();

105 return false;

106}

107

108

109template <typename LTy, typename RTy>

110auto m_AnyMul(const LTy &L, const RTy &R) {

112}

113

114

115template <typename LTy, typename RTy>

116auto m_AnyAdd(const LTy &L, const RTy &R) {

118}

119

120namespace {

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

159

160

161

163 unsigned NumElements, Type *EltType,

165

166 assert((!isa(Stride) ||

167 cast(Stride)->getZExtValue() >= NumElements) &&

168 "Stride must be >= the number of elements in the result vector.");

169

170

171 Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");

172

173

174

175 if (isa(VecStart) && cast(VecStart)->isZero())

177 else

178 VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");

179

180 return VecStart;

181}

182

183namespace {

184struct ShapeInfo {

185 unsigned NumRows;

186 unsigned NumColumns;

187

188 bool IsColumnMajor;

189

190 ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)

191 : NumRows(NumRows), NumColumns(NumColumns),

193

194 ShapeInfo(Value *NumRows, Value *NumColumns)

197

198 bool operator==(const ShapeInfo &other) {

199 return NumRows == other.NumRows && NumColumns == other.NumColumns;

200 }

201 bool operator!=(const ShapeInfo &other) { return !(*this == other); }

202

203

204

205 operator bool() const {

206 assert(NumRows == 0 || NumColumns != 0);

207 return NumRows != 0;

208 }

209

210 unsigned getStride() const {

211 if (IsColumnMajor)

212 return NumRows;

213 return NumColumns;

214 }

215

216 unsigned getNumVectors() const {

217 if (IsColumnMajor)

218 return NumColumns;

219 return NumRows;

220 }

221

222

223 ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }

224};

225}

226

227static bool isUniformShape(Value *V) {

229 if (I)

230 return true;

231

232 switch (I->getOpcode()) {

233 case Instruction::FAdd:

234 case Instruction::FSub:

235 case Instruction::FMul:

236 case Instruction::FNeg:

237 case Instruction::Add:

238 case Instruction::Mul:

239 case Instruction::Sub:

240 return true;

241 default:

242 return false;

243 }

244}

245

246

247static std::optional

253 if (match(I, m_IntrinsicIntrinsic::matrix\_multiply(

255 return ShapeInfo(M, K);

258

259 return ShapeInfo(N, M);

260 }

261 if (match(I, m_IntrinsicIntrinsic::matrix\_column\_major\_store(

264 return ShapeInfo(N, M);

265 if (match(I, m_IntrinsicIntrinsic::matrix\_column\_major\_load(

267 return ShapeInfo(M, N);

270 auto OpShape = ShapeMap.find(MatrixA);

271 if (OpShape != ShapeMap.end())

272 return OpShape->second;

273 }

274

275 if (isUniformShape(I)) {

276

277 for (auto &Op : I->operands()) {

278 auto OpShape = ShapeMap.find(Op.get());

279 if (OpShape != ShapeMap.end())

280 return OpShape->second;

281 }

282 }

283 return std::nullopt;

284}

285

286

287

288

289

290

291

292

293

294

295

296

297

298

299

300

301

302

303

304

305

306

307

308

309class LowerMatrixIntrinsics {

318

319

struct OpInfoTy {
  // Operation counters, all starting at zero. They are combined with
  // operator+= below and updated by MatrixTy's addNum*/setNumLoads helpers.
  unsigned NumStores{0};
  unsigned NumLoads{0};
  unsigned NumComputeOps{0};
  // NOTE(review): presumably the number of transposes left in the lowered
  // code; confirm against the remark generator that reads it.
  unsigned NumExposedTransposes{0};

  /// Adds each counter of \p Other to the matching counter of *this and
  /// returns *this so that calls can be chained.
  OpInfoTy &operator+=(const OpInfoTy &Other) {
    NumLoads += Other.NumLoads;
    NumStores += Other.NumStores;
    NumExposedTransposes += Other.NumExposedTransposes;
    NumComputeOps += Other.NumComputeOps;
    return *this;
  }
};

340

341

342

343 class MatrixTy {

345

346 OpInfoTy OpInfo;

347

348 bool IsColumnMajor = true;

349

350 public:

353 : Vectors(Vectors),

355 MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)

357

358 unsigned D = isColumnMajor() ? NumColumns : NumRows;

359 for (unsigned J = 0; J < D; ++J)

361 EltTy, isColumnMajor() ? NumRows : NumColumns)));

362 }

363

364 Value *getVector(unsigned i) const { return Vectors[i]; }

365 Value *getColumn(unsigned i) const {

366 assert(isColumnMajor() && "only supported for column-major matrixes");

367 return Vectors[i];

368 }

369 Value *getRow(unsigned i) const {

370 assert(!isColumnMajor() && "only supported for row-major matrixes");

371 return Vectors[i];

372 }

373

374 void setVector(unsigned i, Value *V) { Vectors[i] = V; }

375

376 Type *getElementType() const { return getVectorTy()->getElementType(); }

377

378 unsigned getNumVectors() const {

379 if (isColumnMajor())

380 return getNumColumns();

381 return getNumRows();

382 }

383

384 unsigned getNumColumns() const {

385 if (isColumnMajor())

386 return Vectors.size();

387 else {

388 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");

389 return cast(Vectors[0]->getType())->getNumElements();

390 }

391 }

392 unsigned getNumRows() const {

393 if (isColumnMajor()) {

394 assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");

395 return cast(Vectors[0]->getType())->getNumElements();

396 } else

397 return Vectors.size();

398 }

399

402 assert(isColumnMajor() && "only supported for column-major matrixes");

403 return getVectorTy();

404 }

405

407 return cast(Vectors[0]->getType());

408 }

409

411 assert(isColumnMajor() &&

412 "columns() only supported for column-major matrixes");

414 }

415

418 }

419

420

421

423 return Vectors.size() == 1 ? Vectors[0]

425 }

426

427 MatrixTy &addNumLoads(unsigned N) {

428 OpInfo.NumLoads += N;

429 return *this;

430 }

431

432 void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }

433

434 MatrixTy &addNumStores(unsigned N) {

435 OpInfo.NumStores += N;

436 return *this;

437 }

438

439 MatrixTy &addNumExposedTransposes(unsigned N) {

440 OpInfo.NumExposedTransposes += N;

441 return *this;

442 }

443

444 MatrixTy &addNumComputeOps(unsigned N) {

445 OpInfo.NumComputeOps += N;

446 return *this;

447 }

448

449 unsigned getNumStores() const { return OpInfo.NumStores; }

450 unsigned getNumLoads() const { return OpInfo.NumLoads; }

451 unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }

452

453 const OpInfoTy &getOpInfo() const { return OpInfo; }

454

455 bool isColumnMajor() const { return IsColumnMajor; }

456

457 unsigned getStride() const {

458 if (isColumnMajor())

459 return getNumRows();

460 return getNumColumns();

461 }

462

463

464

465

468 Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);

469 assert(cast(Vec->getType())->getNumElements() >=

470 NumElts &&

471 "Extracted vector will contain poison values");

474 "block");

475 }

476 };

477

478

479

480

481

482

483

484

485

486

487

488

489

490

491

493

494

495

496

498

499

501

502private:

505

506 if (isa(*Inst))

508

510

511 return FMF;

512 }

513

514public:

517 : Func(F), DL(F.getDataLayout()), TTI(TTI), AM(AM) {}

518

519 unsigned getNumOps(Type *VT) {

520 assert(isa(VT) && "Expected vector type");

522 cast(VT)->getNumElements());

523 }

524

525

526 bool isMinimal() const {

527 return !DT;

528 }

529

530

531

532 unsigned getNumOps(Type *ST, unsigned N) {

533 return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /

537 }

538

539

540

541

542

543

544 MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,

547 assert(VType && "MatrixVal must be a vector type");

549 SI.NumRows * SI.NumColumns &&

550 "The vector size must match the number of matrix elements");

551

552

553

554

555

556 auto Found = Inst2ColumnMatrix.find(MatrixVal);

557 if (Found != Inst2ColumnMatrix.end()) {

558 MatrixTy &M = Found->second;

559

560

561 if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())

562 return M;

563

564 MatrixVal = M.embedInVector(Builder);

565 }

566

567

569 for (unsigned MaskStart = 0;

570 MaskStart < cast(VType)->getNumElements();

571 MaskStart += SI.getStride()) {

574 "split");

576 }

577

578 return {SplitVecs};

579 }

580

581

582

583 bool setShapeInfo(Value *V, ShapeInfo Shape) {

584 assert(Shape && "Shape not set");

585 if (isa(V) || !supportsShapeInfo(V))

586 return false;

587

588 auto SIter = ShapeMap.find(V);

589 if (SIter != ShapeMap.end()) {

590 if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||

591 SIter->second.NumColumns != Shape.NumColumns)) {

592 errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"

593 << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"

594 << Shape.NumColumns << ") for " << *V << "\n";

596 "Matrix shape verification failed, compilation aborted!");

597 }

598

599 LLVM_DEBUG(dbgs() << " not overriding existing shape: "

600 << SIter->second.NumRows << " "

601 << SIter->second.NumColumns << " for " << *V << "\n");

602 return false;

603 }

604

605 ShapeMap.insert({V, Shape});

606 LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns

607 << " for " << *V << "\n");

608 return true;

609 }

610

611

612

613 bool supportsShapeInfo(Value *V) {

614 Instruction *Inst = dyn_cast(V);

615 if (!Inst)

616 return false;

617

619 if (II)

620 switch (II->getIntrinsicID()) {

621 case Intrinsic::matrix_multiply:

622 case Intrinsic::matrix_transpose:

623 case Intrinsic::matrix_column_major_load:

624 case Intrinsic::matrix_column_major_store:

625 return true;

626 default:

627 return false;

628 }

629 return isUniformShape(V) || isa(V) || isa(V);

630 }

631

632

633

634

635

639

640

641

643 while (!WorkList.empty()) {

645

646

647 bool Propagate = false;

648 if (auto SI = computeShapeInfoForInst(Inst, ShapeMap))

649 Propagate = setShapeInfo(Inst, *SI);

650

651 if (Propagate) {

656 }

657 }

658

659 return NewWorkList;

660 }

661

662

663

667

668 auto pushInstruction = [](Value *V,

671 if (I)

673 };

674

675

676

677 LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");

678 while (!WorkList.empty()) {

680

681 size_t BeforeProcessingV = WorkList.size();

682 if (!isa(V))

683 continue;

684

690 if (match(V, m_IntrinsicIntrinsic::matrix\_multiply(

693 if (setShapeInfo(MatrixA, {M, N}))

694 pushInstruction(MatrixA, WorkList);

695

696 if (setShapeInfo(MatrixB, {N, K}))

697 pushInstruction(MatrixB, WorkList);

698

699 } else if (match(V, m_IntrinsicIntrinsic::matrix\_transpose(

701

702 if (setShapeInfo(MatrixA, {M, N}))

703 pushInstruction(MatrixA, WorkList);

704 } else if (match(V, m_IntrinsicIntrinsic::matrix\_column\_major\_store(

707 if (setShapeInfo(MatrixA, {M, N})) {

708 pushInstruction(MatrixA, WorkList);

709 }

710 } else if (isa(V) ||

711 match(V, m_IntrinsicIntrinsic::matrix\_column\_major\_load())) {

712

713 } else if (isa(V)) {

714

715

716 } else if (isUniformShape(V)) {

717

718 ShapeInfo Shape = ShapeMap[V];

719 for (Use &U : cast(V)->operands()) {

720 if (setShapeInfo(U.get(), Shape))

721 pushInstruction(U.get(), WorkList);

722 }

723 }

724

725

726

727 for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)

729 if (isa(U) && V != U)

730 NewWorkList.push_back(cast(U));

731 }

732 return NewWorkList;

733 }

734

735

736

737

739 Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,

744 Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");

745

746

747 setShapeInfo(T0, Shape0.t());

749 Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");

750 setShapeInfo(T1, Shape1.t());

751 return Operation(T0, Shape0.t(), T1, Shape1.t());

752 }

753

754

755

756 void eraseFromParentAndRemoveFromShapeMap(Instruction *Inst) {

757 auto Iter = ShapeMap.find(Inst);

758 if (Iter != ShapeMap.end())

759 ShapeMap.erase(Iter);

761 }

762

763

764

767 auto *Inst = cast(V);

768

770 return;

771 if (II != BB.rend() && Inst == &*II)

772 ++II;

773 eraseFromParentAndRemoveFromShapeMap(Inst);

774 }

775

776

777

778 void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {

779

780

781

782 auto S = ShapeMap.find(&Old);

783 if (S != ShapeMap.end()) {

784 ShapeMap.erase(S);

785 if (supportsShapeInfo(New))

786 ShapeMap.insert({New, S->second});

787 }

789 }

790

791

792

793

794

799

802 if (match(&I, m_IntrinsicIntrinsic::matrix\_transpose(

804 return nullptr;

805

806

808 if (match(TA, m_IntrinsicIntrinsic::matrix\_transpose(m_Value(TATA)))) {

809 updateShapeAndReplaceAllUsesWith(I, TATA);

810 eraseFromParentAndMove(&I, II, BB);

811 eraseFromParentAndMove(TA, II, BB);

812 return nullptr;

813 }

814

815

817 updateShapeAndReplaceAllUsesWith(I, TA);

818 eraseFromParentAndMove(&I, II, BB);

819 return nullptr;

820 }

821

822

823

824 if (match(TA, m_IntrinsicIntrinsic::matrix\_multiply(

827 auto NewInst = distributeTransposes(

828 TAMB, {K, C}, TAMA, {R, K}, Builder,

829 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {

831 Shape0.NumColumns,

832 Shape1.NumColumns, "mmul");

833 });

834 updateShapeAndReplaceAllUsesWith(I, NewInst);

835 eraseFromParentAndMove(&I, II, BB);

836 eraseFromParentAndMove(TA, II, BB);

837 return NewInst;

838 }

839

840

841

842

843

847

848

849 auto NewInst = distributeTransposes(

850 TAMA, {R, C}, TAMB, {R, C}, Builder,

851 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {

852 bool IsFP = I.getType()->isFPOrFPVectorTy();

853 auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")

854 : LocalBuilder.CreateMul(T0, T1, "mmul");

855 auto *Result = cast(Mul);

856 setShapeInfo(Result, Shape0);

858 });

859 updateShapeAndReplaceAllUsesWith(I, NewInst);

860 eraseFromParentAndMove(&I, II, BB);

861 eraseFromParentAndMove(TA, II, BB);

862 return NewInst;

863 }

864

865

866

869 auto NewInst = distributeTransposes(

870 TAMA, {R, C}, TAMB, {R, C}, Builder,

871 [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {

872 bool IsFP = I.getType()->isFPOrFPVectorTy();

873 auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")

874 : LocalBuilder.CreateAdd(T0, T1, "madd");

875

876 auto *Result = cast(Add);

877 setShapeInfo(Result, Shape0);

879 });

880 updateShapeAndReplaceAllUsesWith(I, NewInst);

881 eraseFromParentAndMove(&I, II, BB);

882 eraseFromParentAndMove(TA, II, BB);

883 return NewInst;

884 }

885

886 return nullptr;

887 }

888

890

892 if (T.use_empty())

893 eraseFromParentAndRemoveFromShapeMap(&T);

894 if (A->use_empty())

895 eraseFromParentAndRemoveFromShapeMap(cast(A));

896 if (A != B && B->use_empty())

897 eraseFromParentAndRemoveFromShapeMap(cast(B));

898 };

899

902

903 if (match(&I, m_IntrinsicIntrinsic::matrix\_multiply(

906 match(A, m_IntrinsicIntrinsic::matrix\_transpose(m_Value(AT))) &&

907 match(B, m_IntrinsicIntrinsic::matrix\_transpose(m_Value((BT))))) {

911 BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());

912 setShapeInfo(M, {C, R});

914 R->getZExtValue());

915 updateShapeAndReplaceAllUsesWith(I, NewInst);

916 CleanupBinOp(I, A, B);

917 }

918

919

920

922 match(A, m_IntrinsicIntrinsic::matrix\_transpose(

924 match(B, m_IntrinsicIntrinsic::matrix\_transpose(

927 auto *Add = Builder.CreateFAdd(AT, BT, "mfadd");

929 Instruction *NewInst = MBuilder.CreateMatrixTranspose(

930 Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");

931 updateShapeAndReplaceAllUsesWith(I, NewInst);

932 assert(computeShapeInfoForInst(NewInst, ShapeMap) ==

933 computeShapeInfoForInst(&I, ShapeMap) &&

934 "Shape of new instruction doesn't match original shape.");

935 CleanupBinOp(I, A, B);

936 if (auto *AddI = dyn_cast(Add)) {

937 setShapeInfo(AddI, {R, C});

939 computeShapeInfoForInst(AddI, ShapeMap).value_or(ShapeMap[AddI]) ==

940 ShapeMap[AddI] &&

941 "Shape of updated addition doesn't match cached shape.");

942 }

943 }

944 }

945

946

947 void optimizeTransposes() {

948

949

953

954 ++II;

957 }

958 }

959

960

961

964 liftTranspose(I);

965 }

966 }

967 }

968

969 bool Visit() {

971

972

973

977 if (II)

978 continue;

979

980 switch (II->getIntrinsicID()) {

981 case Intrinsic::matrix_multiply:

982 case Intrinsic::matrix_transpose:

983 case Intrinsic::matrix_column_major_load:

984 case Intrinsic::matrix_column_major_store:

986 break;

987 default:

988 break;

989 }

990 }

991

992

993 if (WorkList.empty())

994 return false;

995

996 if (AM) {

1001 }

1002

1003

1004 while (!WorkList.empty()) {

1005 WorkList = propagateShapeForward(WorkList);

1006 WorkList = propagateShapeBackward(WorkList);

1007 }

1008

1009 if (!isMinimal()) {

1010 optimizeTransposes();

1012 dbgs() << "Dump after matrix transpose optimization:\n";

1014 }

1015 }

1016

1017 bool Changed = false;

1021

1022

1023

1025 for (auto *BB : RPOT)

1027 if (match(&I, m_IntrinsicIntrinsic::lifetime\_end()))

1028 LifetimeEnds.push_back(cast(&I));

1029 if (ShapeMap.find(&I) == ShapeMap.end())

1030 continue;

1031 if (match(&I, m_IntrinsicIntrinsic::matrix\_multiply()))

1032 MaybeFusableInsts.push_back(cast(&I));

1034 }

1035

1036

1038 for (CallInst *CI : MaybeFusableInsts)

1039 lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));

1040

1041

1042 for (CallInst *CI : MaybeFusableInsts)

1043 if (!FusedInsts.contains(CI))

1044 LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);

1045

1046 Changed = !FusedInsts.empty();

1047

1048

1050 if (FusedInsts.count(Inst))

1051 continue;

1052

1054

1055 if (CallInst *CInst = dyn_cast(Inst))

1056 Changed |= VisitCallInst(CInst);

1057

1060 if (auto *BinOp = dyn_cast(Inst))

1061 Changed |= VisitBinaryOperator(BinOp);

1062 if (auto *UnOp = dyn_cast(Inst))

1063 Changed |= VisitUnaryOperator(UnOp);

1065 Changed |= VisitLoad(cast(Inst), Op1, Builder);

1067 Changed |= VisitStore(cast(Inst), Op1, Op2, Builder);

1068 }

1069

1070 if (ORE) {

1071 RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);

1072 RemarkGen.emitRemarks();

1073 }

1074

1075

1076

1077

1078

1079

1080

1081

1082

1083

1085 for (auto *Inst : reverse(ToRemove)) {

1087 if (auto *Poisoned = dyn_cast(U.getUser()))

1088 PoisonedInsts.insert(Poisoned);

1090 }

1092 PoisonedInsts.erase(Inst);

1093 }

1094 if (!PoisonedInsts.empty()) {

1095

1096 dbgs() << "Poisoned but present instructions:\n";

1097 for (auto *I : PoisonedInsts)

1098 dbgs() << *I << "\n";

1100 }

1101

1102 return Changed;

1103 }

1104

1105

1106 bool VisitCallInst(CallInst *Inst) {

1108 return false;

1109

1111 case Intrinsic::matrix_multiply:

1112 LowerMultiply(Inst);

1113 break;

1114 case Intrinsic::matrix_transpose:

1115 LowerTranspose(Inst);

1116 break;

1117 case Intrinsic::matrix_column_major_load:

1118 LowerColumnMajorLoad(Inst);

1119 break;

1120 case Intrinsic::matrix_column_major_store:

1121 LowerColumnMajorStore(Inst);

1122 break;

1123 default:

1124 return false;

1125 }

1126 return true;

1127 }

1128

1129

1130

1131

1132

1133

1134 Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,

1136 Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);

1137 if (Idx == 0)

1138 return InitialAlign;

1139

1140 TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);

1141 if (auto *ConstStride = dyn_cast(Stride)) {

1143 ConstStride->getZExtValue() * ElementSizeInBits / 8;

1145 }

1146 return commonAlignment(InitialAlign, ElementSizeInBits / 8);

1147 }

1148

1149

1150

1152 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {

1153 auto *VType = cast(Ty);

1154 Type *EltTy = VType->getElementType();

1158 for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {

1159 Value *GEP = computeVectorAddr(

1161 Stride, Shape.getStride(), EltTy, Builder);

1163 VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),

1164 IsVolatile, "col.load");

1165

1167 }

1168 return Result.addNumLoads(getNumOps(Result.getVectorTy()) *

1169 Result.getNumVectors());

1170 }

1171

1172

1173

1175 ShapeInfo MatrixShape, Value *I, Value *J,

1176 ShapeInfo ResultShape, Type *EltTy,

1178

1180 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

1181

1184 ResultShape.NumColumns);

1185

1186 return loadMatrix(TileTy, TileStart, Align,

1187 Builder.getInt64(MatrixShape.getStride()), IsVolatile,

1188 ResultShape, Builder);

1189 }

1190

1191

1193 bool IsVolatile, ShapeInfo Shape) {

1195 finalizeLowering(Inst,

1196 loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,

1197 Shape, Builder),

1198 Builder);

1199 }

1200

1201

1202

1203

1204 void LowerColumnMajorLoad(CallInst *Inst) {

1206 "Intrinsic only supports column-major layout!");

1210 cast(Inst->getArgOperand(2))->isOne(),

1211 {Inst->getArgOperand(3), Inst->getArgOperand(4)});

1212 }

1213

1214

1215

1216 void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,

1217 MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,

1220 Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

1221

1224 StoreVal.getNumColumns());

1225

1226 storeMatrix(TileTy, StoreVal, TileStart, MAlign,

1227 Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);

1228 }

1229

1230

1231

1232 MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,

1235 auto VType = cast(Ty);

1237 for (auto Vec : enumerate(StoreVal.vectors())) {

1238 Value *GEP = computeVectorAddr(

1239 EltPtr,

1241 Vec.index()),

1242 Stride, StoreVal.getStride(), VType->getElementType(), Builder);

1244 getAlignForIndex(Vec.index(), Stride,

1245 VType->getElementType(),

1246 MAlign),

1247 IsVolatile);

1248 }

1249 return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *

1250 StoreVal.getNumVectors());

1251 }

1252

1253

1255 Value *Stride, bool IsVolatile, ShapeInfo Shape) {

1257 auto StoreVal = getMatrix(Matrix, Shape, Builder);

1258 finalizeLowering(Inst,

1259 storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,

1260 IsVolatile, Builder),

1261 Builder);

1262 }

1263

1264

1265

1266

1267 void LowerColumnMajorStore(CallInst *Inst) {

1269 "Intrinsic only supports column-major layout!");

1274 cast(Inst->getArgOperand(3))->isOne(),

1275 {Inst->getArgOperand(4), Inst->getArgOperand(5)});

1276 }

1277

1278

1281

1282

1283 unsigned BlockNumElts =

1284 cast(Block->getType())->getNumElements();

1285 unsigned NumElts = cast(Col->getType())->getNumElements();

1286 assert(NumElts >= BlockNumElts && "Too few elements for current block");

1287

1290

1291

1292

1294 unsigned i;

1295 for (i = 0; i < I; i++)

1296 Mask.push_back(i);

1297

1298 unsigned VecNumElts =

1299 cast(Col->getType())->getNumElements();

1300 for (; i < I + BlockNumElts; i++)

1301 Mask.push_back(i - I + VecNumElts);

1302

1303 for (; i < VecNumElts; i++)

1304 Mask.push_back(i);

1305

1307 }

1308

1310 IRBuilder<> &Builder, bool AllowContraction,

1311 unsigned &NumComputeOps) {

1312 NumComputeOps += getNumOps(A->getType());

1313 if (!Sum)

1315

1316 if (UseFPOp) {

1317 if (AllowContraction) {

1318

1319

1320 return Builder.CreateIntrinsic(Intrinsic::fmuladd, A->getType(),

1321 {A, B, Sum});

1322 }

1323 NumComputeOps += getNumOps(A->getType());

1326 }

1327

1328 NumComputeOps += getNumOps(A->getType());

1331 }

1332

1333

1334

1335

1336

1337

1340 auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));

1341 (void)inserted;

1342 assert(inserted.second && "multiple matrix lowering mapping");

1343

1345 Value *Flattened = nullptr;

1347 if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {

1348 if (!Flattened)

1349 Flattened = Matrix.embedInVector(Builder);

1350 U.set(Flattened);

1351 }

1352 }

1353 }

1354

1355

1356

1357

1358 void lowerDotProduct(CallInst *MatMul,

1361 if (FusedInsts.contains(MatMul) ||

1362 MatrixLayout != MatrixLayoutTy::ColumnMajor)

1363 return;

1366

1367 if (LShape.NumRows != 1 || RShape.NumColumns != 1)

1368 return;

1369

1372

1374 bool IsIntVec = ElementType->isIntegerTy();

1375

1376

1378 return;

1379

1380 auto CanBeFlattened = [](Value *Op) {

1382 return true;

1386 m_CombineOr(m_IntrinsicIntrinsic::matrix\_transpose(),

1387 m_IntrinsicIntrinsic::matrix\_column\_major\_load(

1389 };

1390

1391

1392

1393 auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {

1394 if (ShapeMap.find(Op) == ShapeMap.end())

1396

1397 if (!isa(Op))

1399

1402

1403 if (!CanBeFlattened(Op)) {

1405

1406 for (unsigned I = 1; I < N; ++I)

1407 EmbedCost +=

1410 return EmbedCost;

1411 }

1412

1416 EltTy) *

1417 N;

1419 cast(Op)->getOpcode(), VecTy);

1420 return NewCost - OriginalCost;

1421 }

1422

1423 if (match(Op, m_IntrinsicIntrinsic::matrix\_transpose())) {

1424

1425

1426

1428 for (unsigned I = 1; I < N; ++I)

1429 EmbedCost -=

1432 return EmbedCost;

1433 }

1434

1435

1436 if (N == 1)

1438

1441 };

1442

1443

1444

1445

1451 while (!WorkList.empty()) {

1453 if (!Seen.insert(Op).second)

1454 continue;

1455

1457 if (OpCost + LHSCost >= LHSCost)

1458 continue;

1459

1460 LHSCost += OpCost;

1462 if (auto *I = dyn_cast(Op))

1463 WorkList.append(I->op_begin(), I->op_end());

1464 }

1465

1466

1467 int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;

1468 int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;

1471 AddOpCode, cast(LHS->getType()),

1472 IsIntVec ? std::nullopt : std::optional(FMF)) +

1476 (LShape.NumColumns - 1) +

1478 (LShape.NumColumns);

1479 if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))

1480 return;

1481

1482 FusedInsts.insert(MatMul);

1484 auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,

1486

1487

1488

1489 if (!CanBeFlattened(Op))

1490 return;

1491

1493 auto It = ShapeMap.find(Op);

1494 if (It != ShapeMap.end()) {

1495 It->second = It->second.t();

1496 return;

1497 }

1498 }

1499

1500 FusedInsts.insert(cast(Op));

1501

1503 if (match(Op, m_IntrinsicIntrinsic::matrix\_column\_major\_load(

1505 auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);

1506 Op->replaceAllUsesWith(NewLoad);

1507 eraseFromParentAndRemoveFromShapeMap(cast(Op));

1508 return;

1509 } else if (match(Op, m_IntrinsicIntrinsic::matrix\_transpose(

1511 ToRemove.push_back(cast(Op));

1512 Op->replaceAllUsesWith(Arg);

1513 return;

1514 }

1515 };

1516

1517 for (auto *V : ToFlatten)

1518 FlattenArg(V);

1519

1521

1522

1525

1527 if (IsIntVec)

1529 else {

1531 ConstantFP::get(cast(LHS->getType())->getElementType(),

1532 0.0),

1535 }

1536

1537

1541 FusedInsts.insert(MatMul);

1543 }

1544

1545

1546

1547

1548

1549

1550

1551

1552 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,

1553 const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,

1554 bool IsScalarMatrixTransposed, FastMathFlags FMF) {

1555 const unsigned VF = std::max(

1558 Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),

1559 1U);

1560 unsigned R = Result.getNumRows();

1561 unsigned C = Result.getNumColumns();

1562 unsigned M = A.getNumColumns();

1563

1564 bool IsFP = Result.getElementType()->isFloatingPointTy();

1565 assert(A.isColumnMajor() == B.isColumnMajor() &&

1566 Result.isColumnMajor() == A.isColumnMajor() &&

1567 "operands must agree on matrix layout");

1568 unsigned NumComputeOps = 0;

1569

1571

1572 if (A.isColumnMajor()) {

1573

1574

1575

1576 for (unsigned J = 0; J < C; ++J) {

1578

1579 bool isSumZero = isa(Result.getColumn(J));

1580

1582

1585

1587 : nullptr;

1588 for (unsigned K = 0; K < M; ++K) {

1591 B.getColumn(IsScalarMatrixTransposed ? K : J),

1592 IsScalarMatrixTransposed ? J : K);

1594 Sum =

1595 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,

1596 IsFP, Builder, FMF.allowContract(), NumComputeOps);

1597 }

1600 }

1601 }

1602 } else {

1603

1604

1605

1606 for (unsigned I = 0; I < R; ++I) {

1608 bool isSumZero = isa(Result.getRow(I));

1609 for (unsigned J = 0; J < C; J += BlockSize) {

1610

1613

1614 Value *Sum = nullptr;

1615 for (unsigned K = 0; K < M; ++K) {

1618 A.getVector(IsScalarMatrixTransposed ? K : I),

1619 IsScalarMatrixTransposed ? I : K);

1621 Sum =

1622 createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,

1623 IsFP, Builder, FMF.allowContract(), NumComputeOps);

1624 }

1627 }

1628 }

1629 }

1630 Result.addNumComputeOps(NumComputeOps);

1631 }

1632

1633

1634

1635

1640

1641

1642 if (AA->isNoAlias(LoadLoc, StoreLoc))

1643 return Load->getPointerOperand();

1644

1645

1646

1647

1648

1650

1651

1652

1656

1659 nullptr, "alias_cont");

1662 nullptr, "copy");

1665 nullptr, "no_alias");

1666

1667

1668

1669

1675 const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");

1677 StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),

1678 "store.end", true, true);

1680 IntPtrTy, "load.begin");

1682 Fusion);

1683

1684

1685

1686

1690 LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),

1691 "load.end", true, true);

1693 Fusion);

1694

1695

1697 auto *VT = cast(Load->getType());

1698

1699

1700 auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());

1702 Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());

1703

1708 PHI->addIncoming(Load->getPointerOperand(), Check0);

1709 PHI->addIncoming(Load->getPointerOperand(), Check1);

1710 PHI->addIncoming(Alloca, Copy);

1711

1712

1718 return PHI;

1719 }

1720

1721 bool isFusionProfitable(CallInst *MatMul) {

1723 return true;

1724

1727

1728 const unsigned R = LShape.NumRows;

1729 const unsigned C = RShape.NumColumns;

1730 const unsigned M = LShape.NumColumns;

1731 auto *EltType = cast(MatMul->getType())->getElementType();

1732

1733 const unsigned VF = std::max(

1737 1U);

1738

1739

1740

1741

1742

1743

1744

1745 if (R <= VF && C == 1)

1746 return false;

1747

1748

1749

1750

1751 unsigned Op0Regs = (R + VF - 1) / VF * M;

1752 unsigned Op1Regs = (M + VF - 1) / VF * C;

1753 return Op0Regs + Op1Regs >

1755 }

1756

1757 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {

1758 MatrixTy Res;

1760 for (unsigned I = 0; I < C; ++I)

1762 return Res;

1763 }

1764

// createTiledLoops: emits a loop nest (row, column and K loops from TileInfo)
// that computes MatMul one TileSize x TileSize tile at a time, reading the
// operands through LPtr/RPtr and writing the result to Store's pointer operand.
// NOTE(review): part of the signature (line 1766) and several setup lines
// (1773-1776, 1780, 1785) are missing from this listing.
1765 void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,

1767 auto *EltType = cast(MatMul->getType())->getElementType();

1768

1769

1770 TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);

1771 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);

1772 Instruction *InsertI = cast(MatMul);

1777 BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);

1778

1779 Type *TileVecTy =

1781 MatrixTy TileResult;

1782

// Accumulator PHIs live in the K-loop header, one per tile vector.
1783 Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());

1784

1786 for (unsigned I = 0; I < TileSize; I++) {

1787 auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));

1789 TI.RowLoop.Header->getSingleSuccessor());

1790 TileResult.addVector(Phi);

1792 }

1793

1794

1795

1797

// Load one tile of each operand at the current (row, K) and (K, column)
// loop indices and accumulate their product into TileResult.
1798 MatrixTy A =

1799 loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,

1801 MatrixTy B =

1802 loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,

1804 emitMatrixMultiply(TileResult, A, B, Builder, true, false,

1805 getFastMathFlags(MatMul));

1806

// The accumulated tile is stored at the row-loop latch.
1807 Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());

1808 storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),

1809 Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},

1810 TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);

1811

// Feed the updated accumulators back into the PHIs on the K-loop latch edge.
1812 for (unsigned I = 0; I < TileResult.getNumVectors(); I++)

1813 ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);

1814

1815

1816

1817

1818

// Request an unroll count of min(10, NumColumns / TileSize) through
// "llvm.loop.unroll.count" metadata (the call start, line 1820, is missing;
// presumably attached to the K loop - confirm).
1819 unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);

1821 "llvm.loop.unroll.count", InnerLoopUnrollCount);

1822 }

1823

// emitSIMDTiling (signature and leading assert on lines 1824-1827, missing
// here): lowers a load/load -> multiply -> store chain by computing the product
// tile by tile, then erases the fused store, multiply and now-unused loads.
1828 "Tiling only supported for column-major matrixes at the moment!");

1829 if (!isFusionProfitable(MatMul))

1830 return;

1831

1834

1835 const unsigned R = LShape.NumRows;

1836 const unsigned C = RShape.NumColumns;

1837 const unsigned M = LShape.NumColumns;

1838 auto *EltType = cast(MatMul->getType())->getElementType();

1839

// Operand pointers safe to read while the result is being stored; presumably
// getNonAliasingPointer substitutes a copy when an operand may alias the store.
1840 Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);

1841 Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);

1842 Value *CPtr = Store->getPointerOperand();

1843

// NOTE(review): the condition selecting the loop-nest variant (line 1844) is
// missing; presumably the TileUseLoops option - confirm.
1845 createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);

1846 else {

// Fully unrolled tiling: J walks result columns and I walks result rows in
// TileSize steps, and K accumulates over the inner dimension. Edge tiles are
// clamped with std::min.
1848 for (unsigned J = 0; J < C; J += TileSize)

1849 for (unsigned I = 0; I < R; I += TileSize) {

1850 const unsigned TileR = std::min(R - I, unsigned(TileSize));

1851 const unsigned TileC = std::min(C - J, unsigned(TileSize));

1852 MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);

1853

1854 for (unsigned K = 0; K < M; K += TileSize) {

1855 const unsigned TileM = std::min(M - K, unsigned(TileSize));

1856 MatrixTy A =

1859 {TileR, TileM}, EltType, Builder);

1860 MatrixTy B =

1863 {TileM, TileC}, EltType, Builder);

1864 emitMatrixMultiply(Res, A, B, Builder, true, false,

1865 getFastMathFlags(MatMul));

1866 }

// NOTE(review): the shape passed here is {R, M} although the result is R x C.
// This is harmless only if storeMatrix uses just the stride (the row count for
// column-major) - confirm.
1867 storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},

1869 Builder);

1870 }

1871 }

1872

1873

// Mark the fused store and multiply and erase them. Loads are erased only once
// they have no users (the guard for LoadOp0, line 1878, is missing here).
1874 FusedInsts.insert(Store);

1875 FusedInsts.insert(MatMul);

1876 eraseFromParentAndRemoveFromShapeMap(Store);

1877 eraseFromParentAndRemoveFromShapeMap(MatMul);

1879 FusedInsts.insert(LoadOp0);

1880 eraseFromParentAndRemoveFromShapeMap(LoadOp0);

1881 }

// Skip LoadOp1 when both operands are the same load, which was already handled.
1882 if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {

1883 FusedInsts.insert(LoadOp1);

1884 eraseFromParentAndRemoveFromShapeMap(LoadOp1);

1885 }

1886 }

1887

1888

1889

1890

1891

// LowerMatrixMultiplyFused: tries two fused lowerings of MatMul. (1) If the
// layout-dependent operand is a transpose, multiply against the untransposed
// matrix directly. (2) For a load/load -> multiply -> single store chain in
// column-major layout, hoist the store address above MatMul, resolve
// conflicting lifetime markers, and emit tiled code via emitSIMDTiling.
// NOTE(review): parts of the signature and early exits (lines 1894-1896) are
// missing from this listing.
1892 void

1893 LowerMatrixMultiplyFused(CallInst *MatMul,

1897 return;

1898

1899 assert(AA && LI && "Analyses should be available");

1900

1903

1904

// Transpose folding: the operand checked depends on MatrixLayout (B for
// column-major, A otherwise).
1906 if (MatrixLayout == MatrixLayoutTy::ColumnMajor

1907 ? match(B, m_IntrinsicIntrinsic::matrix\_transpose(m_Value(T)))

1908 : match(A, m_IntrinsicIntrinsic::matrix\_transpose(m_Value(T)))) {

1910 auto *EltType = cast(MatMul->getType())->getElementType();

1913 const unsigned R = LShape.NumRows;

1914 const unsigned M = LShape.NumColumns;

1915 const unsigned C = RShape.NumColumns;

1916

1917 MatrixTy MA;

1918 MatrixTy MB;

1919

1920 Value *Transpose;

1921 if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {

1922 MA = getMatrix(A, ShapeInfo(R, M), Builder);

1923 MB = getMatrix(T, ShapeInfo(C, M), Builder);

1924 Transpose = B;

1925 } else {

1926 MA = getMatrix(T, ShapeInfo(R, M), Builder);

1927 MB = getMatrix(B, ShapeInfo(C, M), Builder);

1928 Transpose = A;

1929 }

1930

1931

1932 MatrixTy Result(R, C, EltType);

1933

// NOTE(review): the `false, true` flags presumably tell emitMatrixMultiply
// that one operand is supplied transposed - confirm against its definition.
1934 emitMatrixMultiply(Result, MA, MB, Builder, false, true,

1935 getFastMathFlags(MatMul));

1936

// The transpose is removed only under the condition on line 1938 (missing
// here; presumably that MatMul is its sole user - confirm). An M x C entry is
// then recorded for it in Inst2ColumnMatrix.
1937 FusedInsts.insert(MatMul);

1939 FusedInsts.insert(cast(Transpose));

1940 ToRemove.push_back(cast(Transpose));

1941

1942

1943 Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);

1944 }

1945 finalizeLowering(MatMul, Result, Builder);

1946 return;

1947 }

1948

1950 return;

1951

1952

1953

// Load/load -> multiply -> store chain.
1954 auto *LoadOp0 = dyn_cast(A);

1955 auto *LoadOp1 = dyn_cast(B);

1956 auto *Store = dyn_cast(*MatMul->user_begin());

1957 if (LoadOp0 && LoadOp1 && Store) {

1958

1959

// Walk the operand tree of the store address (WorkList is seeded on lines
// 1960-1962, missing here). Instructions already dominating MatMul are fine.
// Give up on side-effecting or memory-reading instructions, and on the kind
// tested by isa<> (template argument lost in this listing; presumably PHINode).
1963 for (unsigned I = 0; I != WorkList.size(); ++I) {

1964 Value *Current = WorkList[I];

1965 auto *CurrI = dyn_cast(Current);

1966 if (!CurrI)

1967 continue;

1968 if (isa(CurrI))

1969 return;

1970 if (DT->dominates(CurrI, MatMul))

1971 continue;

1972 if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())

1973 return;

1975 WorkList.insert(CurrI->op_begin(), CurrI->op_end());

1976 }

1977

// Move the collected instructions before MatMul. The collection and sort are
// on lines 1974 and 1978-1981, missing here.
1980 });

1982 I->moveBefore(MatMul);

1983

1984

1985

1986

1987

1988

1989

1990

// Inspect each entry of LifetimeEnds (presumably lifetime.end markers) that
// could overlap the fused accesses. Skip it when it is in another block than
// all fusable ops, or when neither load aliases its location.
1994 bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent &&

1995 LoadOp1->getParent() == StoreParent;

1996 for (unsigned Idx = 0; Idx != LifetimeEnds.size();) {

1999

2000

2002 continue;

2004 continue;

2005

2006

2007 if (FusableOpsInSameBlock && End->getParent() != StoreParent)

2008 continue;

2009

2010

2011

2013 if (!EndLoc.Ptr)

2014 continue;

2015 if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc))

2016 continue;

2017

2018

2019

2020

// In the store's block, move the marker after the store instead of removing it.
2021 if (End->getParent() == StoreParent) {

2022 End->moveAfter(Store);

2023 continue;

2024 }

2025

2026

// Presumably cancels a scope-exit index increment after the current entry
// was removed (lines 2027-2029, missing here) - confirm.
2030 Inc.release();

2031 }

2032

2033 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);

2034 return;

2035 }

2036 }

2037

2038

// LowerMultiply: non-fused lowering of a matrix multiply call. It loads both
// operand matrices, checks that element types and the inner dimension agree,
// emits the R x C product and finalizes the lowering of MatMul.
// NOTE(review): the builder and LShape/RShape setup (lines 2040, 2042-2043)
// are missing from this listing.
2039 void LowerMultiply(CallInst *MatMul) {

2041 auto *EltType = cast(MatMul->getType())->getElementType();

2044

2045 const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);

2046 const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);

2047 assert(Lhs.getElementType() == Rhs.getElementType() &&

2048 "Matrix multiply argument element types do not match.");

2049

2050 const unsigned R = LShape.NumRows;

2051 const unsigned C = RShape.NumColumns;

// Inner dimensions must match: (R x K) * (K x C).
2052 assert(LShape.NumColumns == RShape.NumRows);

2053

2054

2055 MatrixTy Result(R, C, EltType);

2056 assert(Lhs.getElementType() == Result.getElementType() &&

2057 "Matrix multiply result element type does not match arguments.");

2058

2059 emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,

2060 getFastMathFlags(MatMul));

2061 finalizeLowering(MatMul, Result, Builder);

2062 }

2063

2064

// LowerTranspose: builds the transposed matrix as NewNumVecs vectors of
// NewNumElts elements each. For a column-major input, that is one result
// vector per input row. The result is counted as 2 * rows * columns compute
// ops plus one exposed transpose.
// NOTE(review): setup (lines 2066-2070) and the per-element extract/insert
// (lines 2080-2081, 2084, 2087) are missing from this listing.
2065 void LowerTranspose(CallInst *Inst) {

2071 MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

2072

2073 const unsigned NewNumVecs =

2074 InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;

2075 const unsigned NewNumElts =

2076 InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;

2077

2078 for (unsigned I = 0; I < NewNumVecs; ++I) {

2079

2082

// Presumably gathers element I from each input vector J into slot J of the
// result vector - confirm against the missing lines.
2083 for (auto J : enumerate(InputMatrix.vectors())) {

2085

2086 ResultVector =

2088 }

2089 Result.addVector(ResultVector);

2090 }

2091

2092

2093

2094

2095 finalizeLowering(

2096 Inst,

2097 Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)

2098 .addNumExposedTransposes(1),

2099 Builder);

2100 }

2101

2102

// Load visitor (signature on line 2103, missing here; presumably VisitLoad):
// lowers Inst only when ShapeMap has shape information for it. Returns true
// if it was lowered and false otherwise.
2104 auto I = ShapeMap.find(Inst);

2105 if (I == ShapeMap.end())

2106 return false;

2107

2110 I->second);

2111 return true;

2112 }

2113

// Store visitor (signature on lines 2114-2115, missing here; presumably
// VisitStore): lowers the store only when the stored value has shape
// information in ShapeMap. Returns true if it was lowered and false otherwise.
2116 auto I = ShapeMap.find(StoredVal);

2117 if (I == ShapeMap.end())

2118 return false;

2119

2122 I->second);

2123 return true;

2124 }

2125

2126

// Binary-operator visitor (signature on line 2127, missing here): lowers an
// element-wise binary op on a shaped matrix by applying the same opcode to each
// pair of operand vectors. Returns false when Inst has no shape information.
2128 auto I = ShapeMap.find(Inst);

2129 if (I == ShapeMap.end())

2130 return false;

2131

2134

2136 ShapeInfo &Shape = I->second;

2137

2139 MatrixTy A = getMatrix(Lhs, Shape, Builder);

2140 MatrixTy B = getMatrix(Rhs, Shape, Builder);

2141 assert(A.isColumnMajor() == B.isColumnMajor() &&

2142 Result.isColumnMajor() == A.isColumnMajor() &&

2143 "operands must agree on matrix layout");

2144

2146

2147

// Map Inst's opcode to the matching IRBuilder call. The switch head (line
// 2149) and the default-case body (line 2163) are missing from this listing.
2148 auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {

2150 case Instruction::Add:

2151 return Builder.CreateAdd(LHS, RHS);

2152 case Instruction::Mul:

2153 return Builder.CreateMul(LHS, RHS);

2154 case Instruction::Sub:

2155 return Builder.CreateSub(LHS, RHS);

2156 case Instruction::FAdd:

2157 return Builder.CreateFAdd(LHS, RHS);

2158 case Instruction::FMul:

2159 return Builder.CreateFMul(LHS, RHS);

2160 case Instruction::FSub:

2161 return Builder.CreateFSub(LHS, RHS);

2162 default:

2164 }

2165 };

2166

2167 for (unsigned I = 0; I < Shape.getNumVectors(); ++I)

2168 Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));

2169

// Cost estimate: ops per vector times the number of vectors.
2170 finalizeLowering(Inst,

2171 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *

2172 Result.getNumVectors()),

2173 Builder);

2174 return true;

2175 }

2176

2177

// Unary-operator visitor (signature on line 2178, missing here): lowers an
// element-wise unary op on a shaped matrix vector by vector. Only FNeg has a
// visible case label. Returns false when Inst has no shape information.
2179 auto I = ShapeMap.find(Inst);

2180 if (I == ShapeMap.end())

2181 return false;

2182

2184

2186 ShapeInfo &Shape = I->second;

2187

2189 MatrixTy M = getMatrix(Op, Shape, Builder);

2190

2192

2193

// NOTE(review): the switch head and case bodies (lines 2195, 2197, 2199) are
// missing from this listing.
2194 auto BuildVectorOp = [&Builder, Inst](Value *Op) {

2196 case Instruction::FNeg:

2198 default:

2200 }

2201 };

2202

2203 for (unsigned I = 0; I < Shape.getNumVectors(); ++I)

2204 Result.addVector(BuildVectorOp(M.getVector(I)));

2205

2206 finalizeLowering(Inst,

2207 Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *

2208 Result.getNumVectors()),

2209 Builder);

2210 return true;

2211 }

2212

2213

2214

2215

// ExprLinearizer: renders the matrix expression tree rooted at Leaf as indented
// text. Output accumulates in Str through Stream. LineLength tracks the current
// line so maybeIndent can break once LengthToBreak (100) is reached.
// NOTE(review): several member declarations (e.g. Stream, DL, Inst2Matrix,
// Shared, ExprsInSubprogram, Leaf, ReusedExprs) are missing from this listing.
2216 struct ExprLinearizer {

2217 unsigned LengthToBreak = 100;

2218 std::string Str;

2220 unsigned LineLength = 0;

2222

2223

2224

2226

2227

2228

2230

2231

2233

2234

2236

2237

2238

2240

2241 ExprLinearizer(const DataLayout &DL,

2246 : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),

2247 ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}

2248

// Emit N spaces and account for them in LineLength.
2249 void indent(unsigned N) {

2250 LineLength += N;

2251 for (unsigned i = 0; i < N; i++)

2252 Stream << " ";

2253 }

2254

2255 void lineBreak() {

2256 Stream << "\n";

2257 LineLength = 0;

2258 }

2259

// Break overlong lines, then indent if at the start of a line (the indent call,
// line 2265, is missing from this listing).
2260 void maybeIndent(unsigned Indent) {

2261 if (LineLength >= LengthToBreak)

2262 lineBreak();

2263

2264 if (LineLength == 0)

2266 }

2267

2269 LineLength += S.size();

2270 Stream << S;

2271 }

2272

// Follow V back to an underlying object. The load condition (line 2274) and
// the pointer case (line 2277) are missing here. Otherwise returns V unchanged.
2273 Value *getUnderlyingObjectThroughLoads(Value *V) {

2275 return getUnderlyingObjectThroughLoads(Ptr);

2276 else if (V->getType()->isPointerTy())

2278 return V;

2279 }

2280

2281

2282 bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }

2283

2284

2285

// Print V's shape as "<rows>x<columns>" from Inst2Matrix, or "unknown".
2287 auto M = Inst2Matrix.find(V);

2288 if (M == Inst2Matrix.end())

2289 SS << "unknown";

2290 else {

2291 SS << M->second.getNumRows();

2292 SS << "x";

2293 SS << M->second.getNumColumns();

2294 }

2295 }

2296

2297

2298

2299

// Write CI's callee name. For the four matrix intrinsics, the name is suffixed
// with operand shapes and the scalar element type.
// NOTE(review): the placeholder written at line 2302 appears empty in this
// listing (its text was likely lost in extraction).
2300 void writeFnName(CallInst *CI) {

2302 write("");

2303 else {

2305 if (Name.starts_with("llvm.matrix")) {

2307 return;

2308 }

2309 auto *II = cast(CI);

2313 std::string Tmp;

2315

2316 switch (II->getIntrinsicID()) {

2317 case Intrinsic::matrix_multiply:

2318 prettyPrintMatrixType(II->getOperand(0), SS);

2319 SS << ".";

2320 prettyPrintMatrixType(II->getOperand(1), SS);

2321 SS << "." << *II->getType()->getScalarType();

2322 break;

2323 case Intrinsic::matrix_transpose:

2324 prettyPrintMatrixType(II->getOperand(0), SS);

2325 SS << "." << *II->getType()->getScalarType();

2326 break;

2327 case Intrinsic::matrix_column_major_load:

2328 prettyPrintMatrixType(II, SS);

2329 SS << "." << *II->getType()->getScalarType();

2330 break;

2331 case Intrinsic::matrix_column_major_store:

2332 prettyPrintMatrixType(II->getOperand(0), SS);

2333 SS << "." << *II->getOperand(0)->getType()->getScalarType();

2334 break;

2335 default:

2337 }

2339 }

2340 }

2341

// Number of shape arguments for each matrix intrinsic; 0 for anything else.
2342 unsigned getNumShapeArgs(CallInst *CI) const {

2344 switch (II->getIntrinsicID()) {

2345 case Intrinsic::matrix_multiply:

2346 return 3;

2347 case Intrinsic::matrix_transpose:

2348 return 2;

2349 case Intrinsic::matrix_column_major_load:

2350 case Intrinsic::matrix_column_major_store:

2351 return 3;

2352 default:

2353 return 0;

2354 }

2355 }

2356 return 0;

2357 }

2358

2359

2360

2361

// writeOperand (signature on line 2362, missing here): writes a short
// description of operand V. Pointers print as "stack addr" or "addr" plus
// " %<name>" when named. Constant ints print by value, other constants as
// "constant", and everything else as "matrix" or "scalar".
// NOTE(review): lines 2367, 2370 and 2380 are missing from this listing;
// TmpStream is used below without a visible declaration.
2363 V = getUnderlyingObjectThroughLoads(V);

2364 if (V->getType()->isPointerTy()) {

2365 if (isa(V)) {

2366 Stream << "stack addr";

2368 } else {

2369 Stream << "addr";

2371 }

// Append " %<name>" only for named values. The test was previously inverted,
// which emitted a bare " %" for unnamed values and dropped real names.
2372 if (!V->getName().empty()) {

2373 Stream << " %" << V->getName() << "";

2374 LineLength += V->getName().size() + 2;

2375 }

2376 return;

2377 }

2378

2379 std::string Tmp;

2381

2382 if (auto *CI = dyn_cast(V))

2383 TmpStream << CI->getValue();

2384 else if (isa(V))

2385 TmpStream << "constant";

2386 else {

2387 if (isMatrix(V))

2388 TmpStream << "matrix";

2389 else

2390 TmpStream << "scalar";

2391 }

2392 Tmp = std::string(StringRef(Tmp).trim());

2393 LineLength += Tmp.size();

2394 Stream << Tmp;

2395 }

2396

2397

2398

2399

// Recursively write Expr and its operands. Matrix operands recurse one indent
// level deeper; the other branch (line 2456) is missing here. Sub-expressions
// seen before are tagged "(reused) ". When the parent is not shared, the debug
// locations of other leaves sharing Expr are listed.
2400 void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,

2401 bool ParentShared) {

2402 auto *I = cast(Expr);

2403 maybeIndent(Indent);

2405

2406

2407 bool ExprShared = false;

2408

2409

2410 if (!ParentShared) {

2411 auto SI = Shared.find(Expr);

2412 assert(SI != Shared.end() && SI->second.count(Leaf));

2413

2414 for (Value *S : SI->second) {

2415 if (S == Leaf)

2416 continue;

2417 DebugLoc DL = cast(S)->getDebugLoc();

2418 write("shared with remark at line " + std::to_string(DL.getLine()) +

2419 " column " + std::to_string(DL.getCol()) + " (");

2420 }

2421 ExprShared = SI->second.size() > 1;

2422 }

2423

2424 bool Reused = !ReusedExprs.insert(Expr).second;

2425 if (Reused && !ParentReused)

2426 write("(reused) ");

2427

2428 if (auto *CI = dyn_cast(I)) {

2429 writeFnName(CI);

2430

2432 } else if (isa(Expr)) {

2433

2434

2435 write("matrix");

2436 return;

2437 } else {

2438 Ops.append(I->value_op_begin(), I->value_op_end());

2439 write(std::string(I->getOpcodeName()));

2440 }

2441

2442 write(std::string("("));

2443

2444 unsigned NumOpsToBreak = 1;

2445 if (match(Expr, m_IntrinsicIntrinsic::matrix\_column\_major\_load()))

2446 NumOpsToBreak = 2;

2447

2449 if (Ops.size() > NumOpsToBreak)

2450 lineBreak();

2451

2452 maybeIndent(Indent + 1);

2453 if (isMatrix(Op))

2454 linearizeExpr(Op, Indent + 1, Reused, ExprShared);

2455 else

2457 if (Op != Ops.back())

2459 }

2460

2462 }

2463

2464 const std::string &getResult() {

2465 return Str;

2466 }

2467 };

2468

2469

2470

2471

2472

2473

2474

2475

2476

2477

2478

2479

2480

2481

// RemarkGenerator: groups lowered matrix expressions by DISubprogram and emits
// one optimization remark per expression leaf. Each remark carries load/store/
// compute/exposed-transpose counts, the counts shared with other expressions,
// and the ExprLinearizer text of the expression.
// NOTE(review): member declarations and several signatures are missing from
// this listing.
2482 struct RemarkGenerator {

2487

2490 : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),

2491 DL(Func.getDataLayout()) {}

2492

2493

2494

2495

// Leaf collection (signature on lines 2496-2498, missing here). Presumably
// selects expressions none of whose users are in ExprsInSubprogram - confirm.
2499 for (auto *Expr : ExprsInSubprogram)

2502 return ExprsInSubprogram.count(U);

2503 }))

2505 return Leaves;

2506 }

2507

2508

2509

2510

// Record Leaf in Shared[V] for V and, recursively, for every operand that is
// itself one of the expressions in ExprsInSubprogram.
2511 void collectSharedInfo(Value *Leaf, Value *V,

2514

2515 if (!ExprsInSubprogram.count(V))

2516 return;

2517

2518 Shared[V].insert(Leaf);

2519

2520 for (Value *Op : cast(V)->operand_values())

2521 collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);

2522 }

2523

2524

2525

2526

// Sum op counts over the tree rooted at Root, visiting each expression once
// (ReusedExprs). Returns {own counts, shared counts}; an expression counts as
// shared when Shared lists more than one leaf for it.
2527 std::pair<OpInfoTy, OpInfoTy>

2531 if (!ExprsInSubprogram.count(Root))

2532 return {};

2533

2534

2535 if (!ReusedExprs.insert(Root).second)

2536 return {};

2537

2538 OpInfoTy SharedCount;

2539 OpInfoTy Count;

2540

2541 auto I = Shared.find(Root);

2542 auto CM = Inst2Matrix.find(Root);

2543 if (I->second.size() == 1)

2544 Count = CM->second.getOpInfo();

2545 else

2546 SharedCount = CM->second.getOpInfo();

2547

2548 for (Value *Op : cast(Root)->operand_values()) {

2549 auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);

2550 Count += C.first;

2551 SharedCount += C.second;

2552 }

2553 return {Count, SharedCount};

2554 }

2555

2556 void emitRemarks() {

2558 return;

2559

2560

2561

2562

// With debug info, attribute each expression to every subprogram on its
// inlined-at chain (the step on line 2571 is missing here). Without it, use a
// single nullptr bucket.
2564 for (const auto &KV : Inst2Matrix) {

2565 if (Func.getSubprogram()) {

2566 auto *I = cast(KV.first);

2568 while (Context) {

2569 Subprog2Exprs[getSubprogram(Context->getScope())].push_back(

2570 KV.first);

2572 }

2573 } else {

2574 Subprog2Exprs[nullptr].push_back(KV.first);

2575 }

2576 }

2577 for (auto &KV : Subprog2Exprs) {

2579 KV.second.end());

2580 auto Leaves = getExpressionLeaves(ExprsInSubprogram);

2581

2583 for (Value *Leaf : Leaves)

2584 collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

2585

2586

2587 for (auto *L : Leaves) {

2588

// Prefer the location in the inlined-at chain that belongs to this subprogram.
2589 DebugLoc Loc = cast(L)->getDebugLoc();

2590 DILocation *Context = cast(L)->getDebugLoc();

2591 while (Context) {

2592 if (getSubprogram(Context->getScope()) == KV.first) {

2593 Loc = Context;

2594 break;

2595 }

2597 }

2598

2600 OpInfoTy Counts, SharedCounts;

2601 std::tie(Counts, SharedCounts) =

2602 sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

2603

2605 cast(L)->getParent());

2606

2607 Rem << "Lowered with ";

2608 Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "

2609 << ore::NV("NumLoads", Counts.NumLoads) << " loads, "

2610 << ore::NV("NumComputeOps", Counts.NumComputeOps)

2611 << " compute ops, "

2612 << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)

2613 << " exposed transposes";

2614

// NOTE(review): shared compute ops are emitted under key "NumFPOps", whereas
// the unshared count uses "NumComputeOps". This is left as-is because remark
// consumers may depend on the key.
2615 if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||

2616 SharedCounts.NumComputeOps > 0) {

2617 Rem << ",\nadditionally "

2618 << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "

2619 << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "

2620 << ore::NV("NumFPOps", SharedCounts.NumComputeOps)

2621 << " compute ops"

2622 << " are shared with other expressions";

2623 }

2624

2625 Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));

2626 ORE.emit(Rem);

2627 }

2628 }

2629 }

2630

// Render the expression rooted at L as text using ExprLinearizer.
2631 std::string

2632 linearize(Value *L,

2636 ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);

2637 Lin.linearizeExpr(L, 0, false, false);

2638 return Lin.getResult();

2639 }

2640 };

2641};

2642}

2643

2647

// Pass entry (signature on lines 2644-2646, missing here). AM is passed only
// when not Minimal. If Visit() reports a change, the PreservedAnalyses PA
// (built on lines 2650, 2652-2653, missing here) is returned. The fallthrough
// return (line 2657) is also missing.
2648 LowerMatrixIntrinsics LMT(F, TTI, Minimal ? nullptr : &AM);

2649 if (LMT.Visit()) {

2651 if (!Minimal) {

2654 }

2655 return PA;

2656 }

2658}

2659

// Pipeline printing (signature and the start of the base-class call on lines
// 2660-2662, missing here): appends "<minimal>" or "<>" after the pass name.
2663 OS, MapClassName2PassName);

2664 OS << '<';

2665 if (Minimal)

2666 OS << "minimal";

2667 OS << '>';

2668}

ReachingDefAnalysis InstSet & ToRemove

MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL

static const Function * getParent(const Value *V)

static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")

static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")

static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")

#define clEnumValN(ENUMVAL, FLAGNAME, DESC)

Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx

hexagon Hexagon specific predictive commoning for HVX vectors

This file provides various utilities for inspecting and working with the control flow graph in LLVM I...

static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)

static DISubprogram * getSubprogram(DIScope *Scope)

Helper function to either return Scope, if it is a subprogram or the attached subprogram for a local ...

static cl::opt< bool > ForceFusion("force-fuse-matrix", cl::init(false), cl::Hidden, cl::desc("Force matrix instruction fusion even if not profitable."))

static cl::opt< bool > VerifyShapeInfo("verify-matrix-shapes", cl::Hidden, cl::desc("Enable/disable matrix shape verification."), cl::init(false))

static bool isSplat(Value *V)

Return true if V is a splat of a value (which is used when multiplying a matrix with a scalar).

static cl::opt< bool > TileUseLoops("fuse-matrix-use-loops", cl::init(false), cl::Hidden, cl::desc("Generate loop nest for tiling."))

static cl::opt< bool > FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, cl::desc("Enable/disable fusing matrix instructions."))

auto m_AnyAdd(const LTy &L, const RTy &R)

Match any add operation (fp or integer).

static cl::opt< bool > AllowContractEnabled("matrix-allow-contract", cl::init(false), cl::Hidden, cl::desc("Allow the use of FMAs if available and profitable. This may " "result in different results, due to less rounding error."))

auto m_AnyMul(const LTy &L, const RTy &R)

Match any mul operation (fp or integer).

static cl::opt< bool > PrintAfterTransposeOpt("matrix-print-after-transpose-opt", cl::init(false))

static cl::opt< unsigned > TileSize("fuse-matrix-tile-size", cl::init(4), cl::Hidden, cl::desc("Tile size for matrix instruction fusion using square-shaped tiles."))

static cl::opt< MatrixLayoutTy > MatrixLayout("matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), cl::desc("Sets the default matrix layout"), cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", "Use column-major layout"), clEnumValN(MatrixLayoutTy::RowMajor, "row-major", "Use row-major layout")))

uint64_t IntrinsicInst * II

PowerPC Reduce CR logical Operation

This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.

assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())

static unsigned getNumElements(Type *Ty)

static Value * extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex, unsigned EndIndex, const Twine &Name)

static Value * insertVector(IRBuilderTy &IRB, Value *Old, Value *V, unsigned BeginIndex, const Twine &Name)

This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...

This file defines the SmallSet class.

This file defines the SmallVector class.

static SymbolRef::Type getType(const Symbol *Sym)

static const int BlockSize

This pass exposes codegen information to IR-level passes.

static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)

Returns the opcode of Values or ~0 if they do not all agree.

static SDValue LowerStore(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)

static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG)

A manager for alias analyses.

bool isNoAlias(const MemoryLocation &LocA, const MemoryLocation &LocB)

A trivial helper function to check to see if the specified pointers are no-alias.

an instruction to allocate memory on the stack

Align getAlign() const

Return the alignment of the memory that is being allocated by the instruction.

A container for analyses that lazily runs them and caches their results.

PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)

Get the result of an analysis pass for a given IR unit.

ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...

LLVM Basic Block Representation.

iterator begin()

Instruction iterator methods.

reverse_iterator rbegin()

InstListType::reverse_iterator reverse_iterator

const Instruction * getTerminator() const LLVM_READONLY

Returns the terminator instruction if the block is well formed or null if the block is not well forme...

BinaryOps getOpcode() const

Function * getCalledFunction() const

Returns the function called, or null if this is an indirect function invocation or the function signa...

User::op_iterator arg_begin()

Return the iterator pointing to the beginning of the argument list.

MaybeAlign getParamAlign(unsigned ArgNo) const

Extract the alignment for a call or parameter (0=unknown).

Value * getArgOperand(unsigned i) const

User::op_iterator arg_end()

Return the iterator pointing to the end of the argument list.

This class represents a function call, abstracting a target machine's calling convention.

static ConstantAggregateZero * get(Type *Ty)

This is the shared class of boolean and integer constants.

DISubprogram * getSubprogram() const

Get the subprogram for this scope.

Base class for scope-like contexts.

This class represents an Operation in the Expression.

A parsed version of the target data layout string in and methods for querying it.

DILocation * getInlinedAt() const

iterator find(const_arg_type_t< KeyT > Val)

bool erase(const KeyT &Val)

size_type count(const_arg_type_t< KeyT > Val) const

Return 1 if the specified key is in the map, 0 otherwise.

std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)

Analysis pass which computes a DominatorTree.

void applyUpdates(ArrayRef< UpdateType > Updates)

Inform the dominator tree about a sequence of CFG edge insertions and deletions and perform a batch u...

static constexpr UpdateKind Delete

static constexpr UpdateKind Insert

Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.

bool dominates(const BasicBlock *BB, const Use &U) const

Return true if the (end of the) basic block BB dominates the use U.

Convenience struct for specifying and reasoning about fast-math flags.

void setAllowContract(bool B=true)

bool allowReassoc() const

Flag queries.

bool allowContract() const

Class to represent fixed width SIMD vectors.

static FixedVectorType * get(Type *ElementType, unsigned NumElts)

Intrinsic::ID getIntrinsicID() const LLVM_READONLY

getIntrinsicID - This method returns the ID number of the specified function, or Intrinsic::not_intri...

bool isIntrinsic() const

isIntrinsic - Returns true if the function's name starts with "llvm.".

CallInst * CreateFAddReduce(Value *Acc, Value *Src)

Create a sequential vector fadd reduction intrinsic of the source vector.

Value * CreateICmpULT(Value *LHS, Value *RHS, const Twine &Name="")

Value * CreateFSub(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)

Value * CreateInsertElement(Type *VecTy, Value *NewElt, Value *Idx, const Twine &Name="")

AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")

Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")

LoadInst * CreateAlignedLoad(Type *Ty, Value *Ptr, MaybeAlign Align, const char *Name)

Value * CreateFAdd(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)

Value * CreateVectorSplat(unsigned NumElts, Value *V, const Twine &Name="")

Return a vector value that contains.

CallInst * CreateAddReduce(Value *Src)

Create a vector int add reduction intrinsic of the source vector.

IntegerType * getIntPtrTy(const DataLayout &DL, unsigned AddrSpace=0)

Fetch the type of an integer with size at least as big as that of a pointer in the given address spac...

void setFastMathFlags(FastMathFlags NewFMF)

Set the fast-math flags to be used with generated fp-math operators.

Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())

ConstantInt * getInt64(uint64_t C)

Get a constant 64-bit value.

CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")

Create a call to intrinsic ID with Args, mangled using Types.

PHINode * CreatePHI(Type *Ty, unsigned NumReservedValues, const Twine &Name="")

Value * CreateSub(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)

ConstantInt * getIntN(unsigned N, uint64_t C)

Get a constant N-bit value, zero extended or truncated from a 64-bit value.

BranchInst * CreateCondBr(Value *Cond, BasicBlock *True, BasicBlock *False, MDNode *BranchWeights=nullptr, MDNode *Unpredictable=nullptr)

Create a conditional 'br Cond, TrueDest, FalseDest' instruction.

LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)

Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...

Value * CreateShuffleVector(Value *V1, Value *V2, Value *Mask, const Twine &Name="")

Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)

Value * CreatePtrToInt(Value *V, Type *DestTy, const Twine &Name="")

void SetInsertPoint(BasicBlock *TheBB)

This specifies that created instructions should be appended to the end of the specified block.

StoreInst * CreateAlignedStore(Value *Val, Value *Ptr, MaybeAlign Align, bool isVolatile=false)

Value * CreateFMul(Value *L, Value *R, const Twine &Name="", MDNode *FPMD=nullptr)

Value * CreateFNeg(Value *V, const Twine &Name="", MDNode *FPMathTag=nullptr)

CallInst * CreateMemCpy(Value *Dst, MaybeAlign DstAlign, Value *Src, MaybeAlign SrcAlign, uint64_t Size, bool isVolatile=false, MDNode *TBAATag=nullptr, MDNode *TBAAStructTag=nullptr, MDNode *ScopeTag=nullptr, MDNode *NoAliasTag=nullptr)

Create and insert a memcpy between the specified pointers.

Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)

This provides a uniform API for creating instructions and inserting them into a basic block: either a...

static InstructionCost getInvalid(CostType Val=0)

void setFastMathFlags(FastMathFlags FMF)

Convenience function for setting multiple fast-math flags on this instruction, which must be an opera...

InstListType::iterator eraseFromParent()

This method unlinks 'this' from the containing basic block and deletes it.

FastMathFlags getFastMathFlags() const LLVM_READONLY

Convenience function for getting all the fast-math flags, which must be an operator which supports th...

A wrapper class for inspecting calls to intrinsic functions.

An instruction for reading from memory.

bool isVolatile() const

Return true if this is a load from a volatile memory location.

Align getAlign() const

Return the alignment of the access that is being performed.

TypeSize getValue() const

Analysis pass that exposes the LoopInfo for a function.

LoopT * getLoopFor(const BlockT *BB) const

Return the inner most loop that BB lives in.

PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)

void printPipeline(raw_ostream &OS, function_ref< StringRef(StringRef)> MapClassName2PassName)

This class implements a map that also provides access to all stored values in a deterministic order.

iterator find(const KeyT &Key)

std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)

CallInst * CreateMatrixTranspose(Value *Matrix, unsigned Rows, unsigned Columns, const Twine &Name="")

Create a llvm.matrix.transpose call, transposing Matrix with Rows rows and Columns columns.

CallInst * CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, unsigned LHSColumns, unsigned RHSColumns, const Twine &Name="")

Create a llvm.matrix.multiply call, multiplying matrixes LHS and RHS.

Representation for a specific memory location.

static MemoryLocation get(const LoadInst *LI)

Return a location with information about the memory reference by the given instruction.

LocationSize Size

The maximum size of the location, in address-units, or UnknownSize if the size is not known.

const Value * Ptr

The address of the start of the location.

static MemoryLocation getForArgument(const CallBase *Call, unsigned ArgIdx, const TargetLibraryInfo *TLI)

Return a location representing a particular argument of a call.

static PoisonValue * get(Type *T)

Static factory methods - Return an 'poison' object of the specified type.

A set of analyses that are preserved following a run of a transformation pass.

static PreservedAnalyses all()

Construct a special preserved set that preserves all passes.

void preserve()

Mark an analysis as preserved.

A vector that has set insertion semantics.

size_type size() const

Determine the number of elements in the SetVector.

size_type count(const key_type &key) const

Count the number of elements of a given key in the SetVector.

bool insert(const value_type &X)

Insert a new element into the SetVector.

A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...

size_type count(ConstPtrType Ptr) const

count - Return 1 if the specified pointer is in the set, 0 otherwise.

std::pair< iterator, bool > insert(PtrType Ptr)

Inserts Ptr if and only if there is no element in the container equal to Ptr.

bool contains(ConstPtrType Ptr) const

SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.

A SetVector that performs no allocations if smaller than a certain size.

SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...

std::pair< const_iterator, bool > insert(const T &V)

insert - Insert an element into the set if it isn't already there.

This class consists of common code factored out of the SmallVector class to reduce code duplication b...

void append(ItTy in_start, ItTy in_end)

Add the specified range to the end of the SmallVector.

void push_back(const T &Elt)

This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.

An instruction for storing to memory.

bool isVolatile() const

Return true if this is a store to a volatile memory location.

StringRef - Represent a constant reference to a string, i.e.

StringRef drop_front(size_t N=1) const

Return a StringRef equal to 'this' but with the first N elements dropped.

constexpr size_t size() const

size - Get the string size.

Analysis pass providing the TargetTransformInfo.

This pass provides access to the codegen interfaces that are needed for IR-level transformations.

TypeSize getRegisterBitWidth(RegisterKind K) const

InstructionCost getMemoryOpCost(unsigned Opcode, Type *Src, Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, OperandValueInfo OpdInfo={OK_AnyValue, OP_None}, const Instruction *I=nullptr) const

InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty, std::optional< FastMathFlags > FMF, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput) const

Calculate the cost of vector reduction intrinsics.

unsigned getRegisterClassForType(bool Vector, Type *Ty=nullptr) const

@ TCK_RecipThroughput

Reciprocal throughput.

InstructionCost getArithmeticInstrCost(unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, TTI::OperandValueInfo Opd1Info={TTI::OK_AnyValue, TTI::OP_None}, TTI::OperandValueInfo Opd2Info={TTI::OK_AnyValue, TTI::OP_None}, ArrayRef< const Value * > Args={}, const Instruction *CxtI=nullptr, const TargetLibraryInfo *TLibInfo=nullptr) const

This is an approximation of reciprocal throughput of a math/logic op.

InstructionCost getShuffleCost(ShuffleKind Kind, VectorType *Tp, ArrayRef< int > Mask={}, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, int Index=0, VectorType *SubTp=nullptr, ArrayRef< const Value * > Args={}, const Instruction *CxtI=nullptr) const

unsigned getNumberOfRegisters(unsigned ClassID) const

@ SK_Splice

Concatenates elements from the first input vector with elements of the second input vector.

Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...

The instances of the Type class are immutable: once they are created, they are never changed.

unsigned getScalarSizeInBits() const LLVM_READONLY

If this is a vector type, return the getPrimitiveSizeInBits value for the element type.

TypeSize getPrimitiveSizeInBits() const LLVM_READONLY

Return the basic size of this type if it is a primitive type.

bool isVoidTy() const

Return true if this is 'void'.

Type * getScalarType() const

If this is a vector type, return the element type, otherwise return 'this'.

UnaryOps getOpcode() const

A Use represents the edge between a Value definition and its users.

Value * getOperand(unsigned i) const

LLVM Value Representation.

Type * getType() const

All values are typed, get the type of this value.

user_iterator user_begin()

bool hasOneUse() const

Return true if there is exactly one use of this value.

void replaceAllUsesWith(Value *V)

Change all uses of this to point to a new Value.

iterator_range< user_iterator > users()

bool hasNUses(unsigned N) const

Return true if this Value has exactly N uses.

iterator_range< use_iterator > uses()

StringRef getName() const

Return a constant reference to the value's name.

Type * getElementType() const

constexpr ScalarTy getFixedValue() const

An efficient, type-erasing, non-owning reference to a callable.

const ParentTy * getParent() const

A range adaptor for a pair of iterators.

This class implements an extremely fast bulk output stream that can only output to a stream.

A raw_ostream that writes to an std::string.

#define llvm_unreachable(msg)

Marks that the current location is not supposed to be reachable.

constexpr std::underlying_type_t< E > Mask()

Get a bitmask with 1s in all places up to the high-order bit of E's largest value.

@ C

The default llvm calling convention, compatible with C.

StringRef getBaseName(ID id)

Return the LLVM name for an intrinsic, without encoded types for overloading, such as "llvm....

TwoOps_match< ValueOpTy, PointerOpTy, Instruction::Store > m_Store(const ValueOpTy &ValueOp, const PointerOpTy &PointerOp)

Matches StoreInst.

BinaryOp_match< LHS, RHS, Instruction::Add > m_Add(const LHS &L, const RHS &R)

class_match< BinaryOperator > m_BinOp()

Match an arbitrary binary operation and ignore it.

specific_intval< false > m_SpecificInt(const APInt &V)

Match a specific integer value or vector with all elements equal to the value.

BinaryOp_match< LHS, RHS, Instruction::FMul > m_FMul(const LHS &L, const RHS &R)

bool match(Val *V, const Pattern &P)

class_match< ConstantInt > m_ConstantInt()

Match an arbitrary ConstantInt and ignore it.

BinaryOp_match< LHS, RHS, Instruction::FAdd > m_FAdd(const LHS &L, const RHS &R)

BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)

OneUse_match< T > m_OneUse(const T &SubPattern)

OneOps_match< OpTy, Instruction::Load > m_Load(const OpTy &Op)

Matches LoadInst.

class_match< Value > m_Value()

Match an arbitrary value and ignore it.

match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)

Combine two pattern matchers matching L || R.

ValuesClass values(OptsTy... Options)

Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...

initializer< Ty > init(const Ty &Val)

ElementType

The element type of an SRV or UAV resource.

DiagnosticInfoOptimizationBase::Argument NV

NodeAddr< PhiNode * > Phi

NodeAddr< FuncNode * > Func

This is an optimization pass for GlobalISel generic memory operations.

auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)

Get the size of a range.

detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)

auto enumerate(FirstRange &&First, RestRanges &&...Rest)

Given two or more input ranges, returns a new range whose values are tuples (A, B,...

auto successors(const MachineBasicBlock *BB)

bool operator!=(uint64_t V1, const APInt &V2)

iterator_range< T > make_range(T x, T y)

Convenience function for iterating over sub-ranges.

LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt & operator+=(DynamicAPInt &A, int64_t B)

const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=6)

This method strips off any GEP address adjustments, pointer casts or llvm.threadlocal....

iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)

Make a range that does early increment to allow mutation of the underlying range without disrupting i...

Value * concatenateVectors(IRBuilderBase &Builder, ArrayRef< Value * > Vecs)

Concatenate a list of vectors.

bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)

const Value * getPointerOperand(const Value *V)

A helper function that returns the pointer operand of a load, store or GEP instruction.

void addStringMetadataToLoop(Loop *TheLoop, const char *MDString, unsigned V=0)

Set input string into loop metadata by keeping other values intact.

bool any_of(R &&range, UnaryPredicate P)

Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.

auto reverse(ContainerTy &&C)

Error write(MCStreamer &Out, ArrayRef< std::string > Inputs, OnCuIndexOverflow OverflowOptValue)

void sort(IteratorTy Start, IteratorTy End)

raw_ostream & dbgs()

dbgs() - This returns a reference to a raw_ostream for debugging messages.

void report_fatal_error(Error Err, bool gen_crash_diag=true)

Report a serious error, calling any installed error handler.

raw_fd_ostream & errs()

This returns a reference to a raw_ostream for standard error.

DWARFExpression::Operation Op

decltype(auto) cast(const From &Val)

cast - Return the argument parameter cast to the specified type.

BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)

Split the specified block at the specified instruction.

Align commonAlignment(Align A, uint64_t Offset)

Returns the alignment that satisfies both alignments.

llvm::SmallVector< int, 16 > createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs)

Create a sequential shuffle mask.

void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)

Implement std::swap in terms of BitVector swap.

This struct is a compact representation of a valid (non-zero power of two) alignment.

This struct is a compact representation of a valid (power of two) or undefined (0) alignment.

A CRTP mix-in to automatically provide informational APIs needed for passes.

A helper struct to create IR loop nests for tiling in IR of the following form: for ColumnLoop....