LLVM: lib/Transforms/Scalar/LoopLoadElimination.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

56#include

57#include

58#include <forward_list>

59#include

60#include

61

62using namespace llvm;

63

64#define LLE_OPTION "loop-load-elim"

65#define DEBUG_TYPE LLE_OPTION

66

68 "runtime-check-per-loop-load-elim", cl::Hidden,

69 cl::desc("Max number of memchecks allowed per eliminated load on average"),

71

73 "loop-load-elimination-scev-check-threshold", cl::init(8), cl::Hidden,

74 cl::desc("The maximum number of SCEV checks allowed for Loop "

75 "Load Elimination"));

76

77STATISTIC(NumLoopLoadEliminted, "Number of loads eliminated by LLE");

78

79namespace {

80

81

82struct StoreToLoadForwardingCandidate {

85

87 : Load(Load), Store(Store) {}

88

89

90

91

92 bool isDependenceDistanceOfOne(PredicatedScalarEvolution &PSE, Loop *L,

93 const DominatorTree &DT) const {

94 Value *LoadPtr = Load->getPointerOperand();

95 Value *StorePtr = Store->getPointerOperand();

97 auto &DL = Load->getDataLayout();

98

101 DL.getTypeSizeInBits(LoadType) ==

103 "Should be a known dependence");

104

105 int64_t StrideLoad =

106 getPtrStride(PSE, LoadType, LoadPtr, L, DT).value_or(0);

107 int64_t StrideStore =

108 getPtrStride(PSE, LoadType, StorePtr, L, DT).value_or(0);

109 if (!StrideLoad || !StrideStore || StrideLoad != StrideStore)

110 return false;

111

112

113

114

115

116

117

118

119 if (std::abs(StrideLoad) != 1)

120 return false;

121

122 unsigned TypeByteSize = DL.getTypeAllocSize(LoadType);

123

126

127

128

131 if (!Dist)

132 return false;

133 const APInt &Val = Dist->getAPInt();

134 return Val == TypeByteSize * StrideLoad;

135 }

136

137 Value *getLoadPtr() const { return Load->getPointerOperand(); }

138

139#ifndef NDEBUG

140 friend raw_ostream &operator<<(raw_ostream &OS,

141 const StoreToLoadForwardingCandidate &Cand) {

142 OS << *Cand.Store << " -->\n";

143 OS.indent(2) << *Cand.Load << "\n";

144 return OS;

145 }

146#endif

147};

148

149}

150

151

152

156 L->getLoopLatches(Latches);

158 return DT->dominates(StoreBlock, Latch);

159 });

160}

161

162

164 return Load->getParent() != L->getHeader();

165}

166

167namespace {

168

169

170class LoadEliminationForLoop {

171public:

172 LoadEliminationForLoop(Loop *L, LoopInfo *LI, const LoopAccessInfo &LAI,

173 DominatorTree *DT, BlockFrequencyInfo *BFI,

174 ProfileSummaryInfo* PSI)

175 : L(L), LI(LI), LAI(LAI), DT(DT), BFI(BFI), PSI(PSI), PSE(LAI.getPSE()) {}

176

177

178

179

180

181

182 std::forward_list

183 findStoreToLoadDependences(const LoopAccessInfo &LAI) {

184 std::forward_list Candidates;

185

186 const auto &DepChecker = LAI.getDepChecker();

187 const auto *Deps = DepChecker.getDependences();

188 if (!Deps)

189 return Candidates;

190

191

192

193

194

195 SmallPtrSet<Instruction *, 4> LoadsWithUnknownDependence;

196

197 for (const auto &Dep : *Deps) {

199 Instruction *Destination = Dep.getDestination(DepChecker);

200

204 LoadsWithUnknownDependence.insert(Source);

206 LoadsWithUnknownDependence.insert(Destination);

207 continue;

208 }

209

210 if (Dep.isBackward())

211

212

213

215 else

216 assert(Dep.isForward() && "Needs to be a forward dependence");

217

219 if (!Store)

220 continue;

222 if (!Load)

223 continue;

224

225

228 Store->getDataLayout()))

229 continue;

230

231 Candidates.emplace_front(Load, Store);

232 }

233

234 if (!LoadsWithUnknownDependence.empty())

235 Candidates.remove_if([&](const StoreToLoadForwardingCandidate &C) {

236 return LoadsWithUnknownDependence.count(C.Load);

237 });

238

239 return Candidates;

240 }

241

242

243 unsigned getInstrIndex(Instruction *Inst) {

244 auto I = InstOrder.find(Inst);

245 assert(I != InstOrder.end() && "No index for instruction");

246 return I->second;

247 }

248

249

250

251

252

253

254

255

256

257

258

259

260

261

262

263

264

265

266

267

268 void removeDependencesFromMultipleStores(

269 std::forward_list &Candidates) {

270

271

272 using LoadToSingleCandT =

273 DenseMap<LoadInst *, const StoreToLoadForwardingCandidate *>;

274 LoadToSingleCandT LoadToSingleCand;

275

276 for (const auto &Cand : Candidates) {

277 bool NewElt;

278 LoadToSingleCandT::iterator Iter;

279

280 std::tie(Iter, NewElt) =

281 LoadToSingleCand.insert(std::make_pair(Cand.Load, &Cand));

282 if (!NewElt) {

283 const StoreToLoadForwardingCandidate *&OtherCand = Iter->second;

284

285 if (OtherCand == nullptr)

286 continue;

287

288

289

290

291 if (Cand.Store->getParent() == OtherCand->Store->getParent() &&

292 Cand.isDependenceDistanceOfOne(PSE, L, *DT) &&

293 OtherCand->isDependenceDistanceOfOne(PSE, L, *DT)) {

294

295 if (getInstrIndex(OtherCand->Store) < getInstrIndex(Cand.Store))

296 OtherCand = &Cand;

297 } else

298 OtherCand = nullptr;

299 }

300 }

301

302 Candidates.remove_if([&](const StoreToLoadForwardingCandidate &Cand) {

303 if (LoadToSingleCand[Cand.Load] != &Cand) {

305 dbgs() << "Removing from candidates: \n"

306 << Cand

307 << " The load may have multiple stores forwarding to "

308 << "it\n");

309 return true;

310 }

311 return false;

312 });

313 }

314

315

316

317

318

319

320 bool needsChecking(unsigned PtrIdx1, unsigned PtrIdx2,

321 const SmallPtrSetImpl<Value *> &PtrsWrittenOnFwdingPath,

322 const SmallPtrSetImpl<Value *> &CandLoadPtrs) {

324 LAI.getRuntimePointerChecking()->getPointerInfo(PtrIdx1).PointerValue;

326 LAI.getRuntimePointerChecking()->getPointerInfo(PtrIdx2).PointerValue;

327 return ((PtrsWrittenOnFwdingPath.count(Ptr1) && CandLoadPtrs.count(Ptr2)) ||

328 (PtrsWrittenOnFwdingPath.count(Ptr2) && CandLoadPtrs.count(Ptr1)));

329 }

330

331

332

333

334

335 SmallPtrSet<Value *, 4> findPointersWrittenOnForwardingPath(

336 const SmallVectorImpl &Candidates) {

337

338

339

340

341

342

343

344

345

346

347

348

349

350

351

352

353

354 LoadInst *LastLoad =

356 [&](const StoreToLoadForwardingCandidate &A,

357 const StoreToLoadForwardingCandidate &B) {

358 return getInstrIndex(A.Load) <

359 getInstrIndex(B.Load);

360 })

361 ->Load;

362 StoreInst *FirstStore =

364 [&](const StoreToLoadForwardingCandidate &A,

365 const StoreToLoadForwardingCandidate &B) {

366 return getInstrIndex(A.Store) <

367 getInstrIndex(B.Store);

368 })

369 ->Store;

370

371

372

373

374 SmallPtrSet<Value *, 4> PtrsWrittenOnFwdingPath;

375

378 PtrsWrittenOnFwdingPath.insert(S->getPointerOperand());

379 };

380 const auto &MemInstrs = LAI.getDepChecker().getMemoryInstructions();

381 std::for_each(MemInstrs.begin() + getInstrIndex(FirstStore) + 1,

382 MemInstrs.end(), InsertStorePtr);

383 std::for_each(MemInstrs.begin(), &MemInstrs[getInstrIndex(LastLoad)],

384 InsertStorePtr);

385

386 return PtrsWrittenOnFwdingPath;

387 }

388

389

390

391 SmallVector<RuntimePointerCheck, 4> collectMemchecks(

392 const SmallVectorImpl &Candidates) {

393

394 SmallPtrSet<Value *, 4> PtrsWrittenOnFwdingPath =

395 findPointersWrittenOnForwardingPath(Candidates);

396

397

398 SmallPtrSet<Value *, 4> CandLoadPtrs;

399 for (const auto &Candidate : Candidates)

400 CandLoadPtrs.insert(Candidate.getLoadPtr());

401

402 const auto &AllChecks = LAI.getRuntimePointerChecking()->getChecks();

403 SmallVector<RuntimePointerCheck, 4> Checks;

404

405 copy_if(AllChecks, std::back_inserter(Checks),

407 for (auto PtrIdx1 : Check.first->Members)

408 for (auto PtrIdx2 : Check.second->Members)

409 if (needsChecking(PtrIdx1, PtrIdx2, PtrsWrittenOnFwdingPath,

410 CandLoadPtrs))

411 return true;

412 return false;

413 });

414

416 << "):\n");

417 LLVM_DEBUG(LAI.getRuntimePointerChecking()->printChecks(dbgs(), Checks));

418

419 return Checks;

420 }

421

422

423 void

424 propagateStoredValueToLoadUsers(const StoreToLoadForwardingCandidate &Cand,

425 SCEVExpander &SEE) {

426

427

428

429

430

431

432

433

434

435

436

437

438

439

440

443 auto *PH = L->getLoopPreheader();

444 assert(PH && "Preheader should exist!");

445 Value *InitialPtr = SEE.expandCodeFor(PtrSCEV->getStart(), Ptr->getType(),

446 PH->getTerminator());

448 new LoadInst(Cand.Load->getType(), InitialPtr, "load_initial",

449 false, Cand.Load->getAlign(),

450 PH->getTerminator()->getIterator());

451

452

453

454

456

458 PHI->insertBefore(L->getHeader()->begin());

459 PHI->addIncoming(Initial, PH);

460

464 (void)DL;

465

466 assert(DL.getTypeSizeInBits(LoadType) == DL.getTypeSizeInBits(StoreType) &&

467 "The type sizes should match!");

468

470 if (LoadType != StoreType) {

472 "store_forward_cast",

474

475

476

478 }

479

480 PHI->addIncoming(StoreValue, L->getLoopLatch());

481

484 }

485

486

487

488 bool processLoop() {

489 LLVM_DEBUG(dbgs() << "\nIn \"" << L->getHeader()->getParent()->getName()

490 << "\" checking " << *L << "\n");

491

492

493

494

495

496

497

498

499

500

501

502

503

504

505

506

507

508

509

510

511 auto StoreToLoadDependences = findStoreToLoadDependences(LAI);

512 if (StoreToLoadDependences.empty())

513 return false;

514

515

516

517 InstOrder = LAI.getDepChecker().generateInstructionOrderMap();

518

519

520

521 removeDependencesFromMultipleStores(StoreToLoadDependences);

522 if (StoreToLoadDependences.empty())

523 return false;

524

525

527 for (const StoreToLoadForwardingCandidate &Cand : StoreToLoadDependences) {

529

530

531

533 continue;

534

535

536

537

539 continue;

540

541

542

543 if (!Cand.isDependenceDistanceOfOne(PSE, L, *DT))

544 continue;

545

547 "Loading from something other than indvar?");

550 "Storing to something other than indvar?");

551

555 << Candidates.size()

556 << ". Valid store-to-load forwarding across the loop backedge\n");

557 }

558 if (Candidates.empty())

559 return false;

560

561

562

563 SmallVector<RuntimePointerCheck, 4> Checks = collectMemchecks(Candidates);

564

565

567 LLVM_DEBUG(dbgs() << "Too many run-time checks needed.\n");

568 return false;

569 }

570

571 if (LAI.getPSE().getPredicate().getComplexity() >

573 LLVM_DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n");

574 return false;

575 }

576

577 if (!L->isLoopSimplifyForm()) {

578 LLVM_DEBUG(dbgs() << "Loop is not is loop-simplify form");

579 return false;

580 }

581

582 if (!Checks.empty() || !LAI.getPSE().getPredicate().isAlwaysTrue()) {

583 if (LAI.hasConvergentOp()) {

584 LLVM_DEBUG(dbgs() << "Versioning is needed but not allowed with "

585 "convergent calls\n");

586 return false;

587 }

588

589 auto *HeaderBB = L->getHeader();

591 PGSOQueryType::IRPass)) {

593 dbgs() << "Versioning is needed but not allowed when optimizing "

594 "for size.\n");

595 return false;

596 }

597

598

599

600

601 LoopVersioning LV(LAI, Checks, L, LI, DT, PSE.getSE());

602 LV.versionLoop();

603

604

605

606 auto NoLongerGoodCandidate = [this](

607 const StoreToLoadForwardingCandidate &Cand) {

612 };

614 }

615

616

617

618 SCEVExpander SEE(*PSE.getSE(), L->getHeader()->getDataLayout(),

619 "storeforward");

620 for (const auto &Cand : Candidates)

621 propagateStoredValueToLoadUsers(Cand, SEE);

622 NumLoopLoadEliminted += Candidates.size();

623

624 return true;

625 }

626

627private:

628 Loop *L;

629

630

631

632 DenseMap<Instruction *, unsigned> InstOrder;

633

634

635 LoopInfo *LI;

636 const LoopAccessInfo &LAI;

637 DominatorTree *DT;

638 BlockFrequencyInfo *BFI;

639 ProfileSummaryInfo *PSI;

640 PredicatedScalarEvolution PSE;

641};

642

643}

644

651

652

653

654

655

657

659

660 for (Loop *TopLevelLoop : LI)

663

664 if (L->isInnermost())

666 }

667

668

669 for (Loop *L : Worklist) {

670

671 if (!L->isRotatedForm() || !L->getExitingBlock())

672 continue;

673

674 LoadEliminationForLoop LEL(L, &LI, LAIs.getInfo(*L), &DT, BFI, PSI);

675 Changed |= LEL.processLoop();

678 }

680}

681

685

686

687 if (LI.empty())

694 auto *BFI = (PSI && PSI->hasProfileSummary()) ?

697

699

702

706 return PA;

707}

assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")

This file implements a class to represent arbitrary precision integral constant values and operations...

MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL

static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")

static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")

This file defines the DenseMap class.

This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.

This is the interface for a simple mod/ref and alias analysis over globals.

This header defines various interfaces for pass management in LLVM.

This header provides classes for managing per-loop analyses.

static bool eliminateLoadsAcrossLoops(Function &F, LoopInfo &LI, DominatorTree &DT, BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, ScalarEvolution *SE, AssumptionCache *AC, LoopAccessInfoManager &LAIs)

Definition LoopLoadElimination.cpp:645

static cl::opt< unsigned > LoadElimSCEVCheckThreshold("loop-load-elimination-scev-check-threshold", cl::init(8), cl::Hidden, cl::desc("The maximum number of SCEV checks allowed for Loop " "Load Elimination"))

static bool isLoadConditional(LoadInst *Load, Loop *L)

Return true if the load is not executed on all paths in the loop.

Definition LoopLoadElimination.cpp:163

static bool doesStoreDominatesAllLatches(BasicBlock *StoreBlock, Loop *L, DominatorTree *DT)

Check if the store dominates all latches, so as long as there is no intervening store this value will...

Definition LoopLoadElimination.cpp:153

static cl::opt< unsigned > CheckPerElim("runtime-check-per-loop-load-elim", cl::Hidden, cl::desc("Max number of memchecks allowed per eliminated load on average"), cl::init(1))

This header defines the LoopLoadEliminationPass object.

This file defines the SmallPtrSet class.

This file defines the SmallVector class.

This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...

#define STATISTIC(VARNAME, DESC)

This pass exposes codegen information to IR-level passes.

PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)

Get the result of an analysis pass for a given IR unit.

A function analysis which provides an AssumptionCache.

A cache of @llvm.assume calls within a function.

LLVM Basic Block Representation.

Analysis pass which computes BlockFrequencyInfo.

BlockFrequencyInfo pass uses BlockFrequencyInfoImpl implementation to estimate IR basic block frequen...

static LLVM_ABI bool isBitOrNoopPointerCastable(Type *SrcTy, Type *DestTy, const DataLayout &DL)

Check whether a bitcast, inttoptr, or ptrtoint cast between these types is valid and a no-op.

static LLVM_ABI CastInst * CreateBitOrPointerCast(Value *S, Type *Ty, const Twine &Name="", InsertPosition InsertBefore=nullptr)

Create a BitCast, a PtrToInt, or an IntToPTr cast instruction.

static DebugLoc getDropped()

Analysis pass which computes a DominatorTree.

Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.

LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const

Return true if the (end of the) basic block BB dominates the use U.

const DebugLoc & getDebugLoc() const

Return the debug location for this node as a DebugLoc.

LLVM_ABI const DataLayout & getDataLayout() const

Get the data layout of the module this instruction belongs to.

An instruction for reading from memory.

Value * getPointerOperand()

Align getAlign() const

Return the alignment of the access that is being performed.

This analysis provides dependence information for the memory accesses of a loop.

LLVM_ABI const LoopAccessInfo & getInfo(Loop &L, bool AllowPartial=false)

Analysis pass that exposes the LoopInfo for a function.

Represents a single loop in the control flow graph.

static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)

Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...

ScalarEvolution * getSE() const

Returns the ScalarEvolution analysis used.

LLVM_ABI const SCEV * getSCEV(Value *V)

Returns the SCEV expression of V, in the context of the current SCEV predicate.

A set of analyses that are preserved following a run of a transformation pass.

static PreservedAnalyses all()

Construct a special preserved set that preserves all passes.

PreservedAnalyses & preserve()

Mark an analysis as preserved.

An analysis pass based on the new PM to deliver ProfileSummaryInfo.

Analysis providing profile information.

Analysis pass that exposes the ScalarEvolution for a function.

The main scalar evolution driver.

LLVM_ABI const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)

Return LHS-RHS.

size_type count(ConstPtrType Ptr) const

count - Return 1 if the specified pointer is in the set, 0 otherwise.

std::pair< iterator, bool > insert(PtrType Ptr)

Inserts Ptr if and only if there is no element in the container equal to Ptr.

void push_back(const T &Elt)

This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.

An instruction for storing to memory.

Value * getValueOperand()

Value * getPointerOperand()

LLVM_ABI unsigned getPointerAddressSpace() const

Get the address space of this pointer or pointer vector type.

Type * getType() const

All values are typed, get the type of this value.

LLVM_ABI void replaceAllUsesWith(Value *V)

Change all uses of this to point to a new Value.

const ParentTy * getParent() const

self_iterator getIterator()

raw_ostream & indent(unsigned NumSpaces)

indent - Insert 'NumSpaces' spaces.

@ C

The default llvm calling convention, compatible with C.

initializer< Ty > init(const Ty &Val)

friend class Instruction

Iterator for Instructions in a `BasicBlock.

This is an optimization pass for GlobalISel generic memory operations.

LLVM_ABI bool simplifyLoop(Loop *L, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, AssumptionCache *AC, MemorySSAUpdater *MSSAU, bool PreserveLCSSA)

Simplify each loop in a loop nest recursively.

FunctionAddr VTableAddr Value

auto min_element(R &&Range)

Provide wrappers to std::min_element which take ranges instead of having to pass begin/end explicitly...

bool all_of(R &&range, UnaryPredicate P)

Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.

std::pair< const RuntimeCheckingPtrGroup *, const RuntimeCheckingPtrGroup * > RuntimePointerCheck

A memcheck which made up of a pair of grouped pointers.

decltype(auto) dyn_cast(const From &Val)

dyn_cast - Return the argument parameter cast to the specified type.

OuterAnalysisManagerProxy< ModuleAnalysisManager, Function > ModuleAnalysisManagerFunctionProxy

Provide the ModuleAnalysisManager to Function proxy.

LLVM_ABI bool shouldOptimizeForSize(const MachineFunction *MF, ProfileSummaryInfo *PSI, const MachineBlockFrequencyInfo *BFI, PGSOQueryType QueryType=PGSOQueryType::Other)

Returns true if machine function MF is suggested to be size-optimized based on the profile.

OutputIt copy_if(R &&Range, OutputIt Out, UnaryPredicate P)

Provide wrappers to std::copy_if which take ranges instead of having to pass begin/end explicitly.

LLVM_ABI raw_ostream & dbgs()

dbgs() - This returns a reference to a raw_ostream for debugging messages.

class LLVM_GSL_OWNER SmallVector

Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...

bool isa(const From &Val)

isa - Return true if the parameter to the template is an instance of one of the template type argu...

auto max_element(R &&Range)

Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...

raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)

decltype(auto) cast(const From &Val)

cast - Return the argument parameter cast to the specified type.

void erase_if(Container &C, UnaryPredicate P)

Provide a container algorithm similar to C++ Library Fundamentals v2's erase_if which is equivalent t...

Type * getLoadStoreType(const Value *I)

A helper function that returns the type of a load or store instruction.

iterator_range< df_iterator< T > > depth_first(const T &G)

AnalysisManager< Function > FunctionAnalysisManager

Convenience typedef for the Function analysis manager.

LLVM_ABI std::optional< int64_t > getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy, Value *Ptr, const Loop *Lp, const DominatorTree &DT, const DenseMap< Value *, const SCEV * > &StridesMap=DenseMap< Value *, const SCEV * >(), bool Assume=false, bool ShouldCheckWrap=true)

If the pointer has a constant stride return it in units of the access type size.

void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)

Implement std::swap in terms of BitVector swap.

PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)

Definition LoopLoadElimination.cpp:682