LLVM: lib/Transforms/Scalar/NaryReassociate.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

111#include

112#include

113

114using namespace llvm;

116

117#define DEBUG_TYPE "nary-reassociate"

118

119namespace {

120

121class NaryReassociateLegacyPass : public FunctionPass {

122public:

123 static char ID;

124

127 }

128

129 bool doInitialization(Module &M) override {

130 return false;

131 }

132

134

135 void getAnalysisUsage(AnalysisUsage &AU) const override {

145 }

146

147private:

149};

150

151}

152

153char NaryReassociateLegacyPass::ID = 0;

154

156 "Nary reassociation", false, false)

164

166 return new NaryReassociateLegacyPass();

167}

168

169bool NaryReassociateLegacyPass::runOnFunction(Function &F) {

170 if (skipFunction(F))

171 return false;

172

173 auto *AC = &getAnalysis().getAssumptionCache(F);

174 auto *DT = &getAnalysis().getDomTree();

175 auto *SE = &getAnalysis().getSE();

176 auto *TLI = &getAnalysis().getTLI(F);

177 auto *TTI = &getAnalysis().getTTI(F);

178

179 return Impl.runImpl(F, AC, DT, SE, TLI, TTI);

180}

181

189

190 if (runImpl(F, AC, DT, SE, TLI, TTI))

192

196 return PA;

197}

198

203 AC = AC_;

204 DT = DT_;

205 SE = SE_;

206 TLI = TLI_;

207 TTI = TTI_;

208 DL = &F.getDataLayout();

209

210 bool Changed = false, ChangedInThisIteration;

211 do {

212 ChangedInThisIteration = doOneIteration(F);

213 Changed |= ChangedInThisIteration;

214 } while (ChangedInThisIteration);

216}

217

218bool NaryReassociatePass::doOneIteration(Function &F) {

220 SeenExprs.clear();

221

222

223

228 const SCEV *OrigSCEV = nullptr;

229 if (Instruction *NewI = tryReassociate(&OrigI, OrigSCEV)) {

231 OrigI.replaceAllUsesWith(NewI);

232

233

235

236

237 const SCEV *NewSCEV = SE->getSCEV(NewI);

239

240

241

242

243

244

245

246

247

248

249

250

251

252

253

254

255

256

257

258

259 if (NewSCEV != OrigSCEV)

261 } else if (OrigSCEV)

262 SeenExprs[OrigSCEV].push_back(WeakTrackingVH(&OrigI));

263 }

264 }

265

266

268 DeadInsts, TLI, nullptr, [this](Value *V) { SE->forgetValue(V); });

269

271}

272

273template

275NaryReassociatePass::matchAndReassociateMinOrMax(Instruction *I,

276 const SCEV *&OrigSCEV) {

279

280 auto MinMaxMatcher =

281 MaxMin_match<ICmpInst, bind_ty, bind_ty, PredT>(

283 if (match(I, MinMaxMatcher)) {

284 OrigSCEV = SE->getSCEV(I);

286 tryReassociateMinOrMax(I, MinMaxMatcher, LHS, RHS)))

287 return NewMinMax;

289 tryReassociateMinOrMax(I, MinMaxMatcher, RHS, LHS)))

290 return NewMinMax;

291 }

292 return nullptr;

293}

294

295Instruction *NaryReassociatePass::tryReassociate(Instruction * I,

296 const SCEV *&OrigSCEV) {

297

298 if (!SE->isSCEVable(I->getType()))

299 return nullptr;

300

301 switch (I->getOpcode()) {

302 case Instruction::Add:

303 case Instruction::Mul:

304 OrigSCEV = SE->getSCEV(I);

306 case Instruction::GetElementPtr:

307 OrigSCEV = SE->getSCEV(I);

309 default:

310 break;

311 }

312

313

315

316

317

318 if (I->getType()->isIntegerTy())

319 if ((ResI = matchAndReassociateMinOrMax<umin_pred_ty>(I, OrigSCEV)) ||

320 (ResI = matchAndReassociateMinOrMax<smin_pred_ty>(I, OrigSCEV)) ||

321 (ResI = matchAndReassociateMinOrMax<umax_pred_ty>(I, OrigSCEV)) ||

322 (ResI = matchAndReassociateMinOrMax<smax_pred_ty>(I, OrigSCEV)))

323 return ResI;

324

325 return nullptr;

326}

327

331 return TTI->getGEPCost(GEP->getSourceElementType(), GEP->getPointerOperand(),

333}

334

335Instruction *NaryReassociatePass::tryReassociateGEP(GetElementPtrInst *GEP) {

336

338 return nullptr;

339

341 for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) {

343 if (auto *NewGEP = tryReassociateGEPAtIndex(GEP, I - 1,

345 return NewGEP;

346 }

347 }

348 }

349 return nullptr;

350}

351

352bool NaryReassociatePass::requiresSignExtension(Value *Index,

353 GetElementPtrInst *GEP) {

354 unsigned IndexSizeInBits =

355 DL->getIndexSizeInBits(GEP->getType()->getPointerAddressSpace());

357}

358

359GetElementPtrInst *

360NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP,

361 unsigned I, Type *IndexedType) {

362 SimplifyQuery SQ(*DL, DT, AC, GEP);

363 Value *IndexToSplit = GEP->getOperand(I + 1);

365 IndexToSplit = SExt->getOperand(0);

367

369 IndexToSplit = ZExt->getOperand(0);

370 }

371

373

374

375

376 if (requiresSignExtension(IndexToSplit, GEP) &&

378 return nullptr;

379

380 Value *LHS = AO->getOperand(0), *RHS = AO->getOperand(1);

381

382 if (auto *NewGEP = tryReassociateGEPAtIndex(GEP, I, LHS, RHS, IndexedType))

383 return NewGEP;

384

386 if (auto *NewGEP =

387 tryReassociateGEPAtIndex(GEP, I, RHS, LHS, IndexedType))

388 return NewGEP;

389 }

390 }

391 return nullptr;

392}

393

394GetElementPtrInst *

395NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP,

398

399

401 for (Use &Index : GEP->indices())

402 IndexExprs.push_back(SE->getSCEV(Index));

403

404 IndexExprs[I] = SE->getSCEV(LHS);

405 Type *GEPArgType = SE->getEffectiveSCEVType(GEP->getOperand(I)->getType());

406 Type *LHSType = SE->getEffectiveSCEVType(LHS->getType());

407 size_t LHSSize = DL->getTypeSizeInBits(LHSType).getFixedValue();

408 size_t GEPArgSize = DL->getTypeSizeInBits(GEPArgType).getFixedValue();

410 LHSSize < GEPArgSize) {

411

412

413

414

415 IndexExprs[I] = SE->getZeroExtendExpr(IndexExprs[I], GEPArgType);

416 }

418 IndexExprs);

419

420 Value *Candidate = findClosestMatchingDominator(CandidateExpr, GEP);

421 if (Candidate == nullptr)

422 return nullptr;

423

425

427

428

429 uint64_t IndexedSize = DL->getTypeAllocSize(IndexedType);

431 uint64_t ElementSize = DL->getTypeAllocSize(ElementType);

432

433

434

435

436

437

438

439

440

441

442

443

444

445

446 if (IndexedSize % ElementSize != 0)

447 return nullptr;

448

449

450 Type *PtrIdxTy = DL->getIndexType(GEP->getType());

452 RHS = Builder.CreateSExtOrTrunc(RHS, PtrIdxTy);

453 if (IndexedSize != ElementSize) {

454 RHS = Builder.CreateMul(

455 RHS, ConstantInt::get(PtrIdxTy, IndexedSize / ElementSize));

456 }

458 Builder.CreateGEP(GEP->getResultElementType(), Candidate, RHS));

461 return NewGEP;

462}

463

464Instruction *NaryReassociatePass::tryReassociateBinaryOp(BinaryOperator *I) {

465 Value *LHS = I->getOperand(0), *RHS = I->getOperand(1);

466

467 if (SE->getSCEV(I)->isZero())

468 return nullptr;

469 if (auto *NewI = tryReassociateBinaryOp(LHS, RHS, I))

470 return NewI;

471 if (auto *NewI = tryReassociateBinaryOp(RHS, LHS, I))

472 return NewI;

473 return nullptr;

474}

475

477 BinaryOperator *I) {

478 Value *A = nullptr, *B = nullptr;

479

480

482

483

484 const SCEV *AExpr = SE->getSCEV(A), *BExpr = SE->getSCEV(B);

485 const SCEV *RHSExpr = SE->getSCEV(RHS);

486 if (BExpr != RHSExpr) {

487 if (auto *NewI =

488 tryReassociatedBinaryOp(getBinarySCEV(I, AExpr, RHSExpr), B, I))

489 return NewI;

490 }

491 if (AExpr != RHSExpr) {

492 if (auto *NewI =

493 tryReassociatedBinaryOp(getBinarySCEV(I, BExpr, RHSExpr), A, I))

494 return NewI;

495 }

496 }

497 return nullptr;

498}

499

500Instruction *NaryReassociatePass::tryReassociatedBinaryOp(const SCEV *LHSExpr,

502 BinaryOperator *I) {

503

504

505 auto *LHS = findClosestMatchingDominator(LHSExpr, I);

506 if (LHS == nullptr)

507 return nullptr;

508

510 switch (I->getOpcode()) {

511 case Instruction::Add:

512 NewI = BinaryOperator::CreateAdd(LHS, RHS, "", I->getIterator());

513 break;

514 case Instruction::Mul:

515 NewI = BinaryOperator::CreateMul(LHS, RHS, "", I->getIterator());

516 break;

517 default:

519 }

522 return NewI;

523}

524

525bool NaryReassociatePass::matchTernaryOp(BinaryOperator *I, Value *V,

527 switch (I->getOpcode()) {

528 case Instruction::Add:

530 case Instruction::Mul:

532 default:

534 }

535 return false;

536}

537

538const SCEV *NaryReassociatePass::getBinarySCEV(BinaryOperator *I,

539 const SCEV *LHS,

540 const SCEV *RHS) {

541 switch (I->getOpcode()) {

542 case Instruction::Add:

543 return SE->getAddExpr(LHS, RHS);

544 case Instruction::Mul:

545 return SE->getMulExpr(LHS, RHS);

546 default:

548 }

549 return nullptr;

550}

551

553NaryReassociatePass::findClosestMatchingDominator(const SCEV *CandidateExpr,

554 Instruction *Dominatee) {

555 auto Pos = SeenExprs.find(CandidateExpr);

556 if (Pos == SeenExprs.end())

557 return nullptr;

558

559 auto &Candidates = Pos->second;

560

561

562

563

564 while (!Candidates.empty()) {

565

566

567 if (Value *Candidate = Candidates.pop_back_val()) {

569 if (!DT->dominates(CandidateInstruction, Dominatee))

570 continue;

571

572

573

575 if (!SE->canReuseInstruction(CandidateExpr, CandidateInstruction,

576 DropPoisonGeneratingInsts))

577 continue;

578

579 for (Instruction *I : DropPoisonGeneratingInsts)

580 I->dropPoisonGeneratingAnnotations();

581

582 return CandidateInstruction;

583 }

584 }

585 return nullptr;

586}

587

589 if (std::is_same_v<smax_pred_ty, typename MaxMinT::PredType>)

591 else if (std::is_same_v<umax_pred_ty, typename MaxMinT::PredType>)

593 else if (std::is_same_v<smin_pred_ty, typename MaxMinT::PredType>)

595 else if (std::is_same_v<umin_pred_ty, typename MaxMinT::PredType>)

597

598 llvm_unreachable("Can't convert MinMax pattern to SCEV type");

600}

601

602

603

604

605

606

607template

608Value *NaryReassociatePass::tryReassociateMinOrMax(Instruction *I,

609 MaxMinT MaxMinMatch,

611 Value *A = nullptr, *B = nullptr;

613

615 return nullptr;

616

618

619

621 return U != I && !(U->hasOneUser() && *U->users().begin() == I);

622 }))

623 return nullptr;

624

625 auto tryCombination = [&](Value *A, const SCEV *AExpr, Value *B,

626 const SCEV *BExpr, Value *C,

627 const SCEV *CExpr) -> Value * {

630 const SCEV *R1Expr = SE->getMinMaxExpr(SCEVType, Ops1);

631

632 Instruction *R1MinMax = findClosestMatchingDominator(R1Expr, I);

633

634 if (!R1MinMax)

635 return nullptr;

636

637 LLVM_DEBUG(dbgs() << "NARY: Found common sub-expr: " << *R1MinMax << "\n");

638

640 SE->getUnknown(R1MinMax)};

641 const SCEV *R2Expr = SE->getMinMaxExpr(SCEVType, Ops2);

642

643 SCEVExpander Expander(*SE, "nary-reassociate");

644 Value *NewMinMax = Expander.expandCodeFor(R2Expr, I->getType(), I);

645 NewMinMax->setName(Twine(I->getName()).concat(".nary"));

646

648 << "NARY: Inserting: " << *NewMinMax << "\n");

649 return NewMinMax;

650 };

651

652 const SCEV *AExpr = SE->getSCEV(A);

653 const SCEV *BExpr = SE->getSCEV(B);

654 const SCEV *RHSExpr = SE->getSCEV(RHS);

655

656 if (BExpr != RHSExpr) {

657

658 if (auto *NewMinMax = tryCombination(A, AExpr, RHS, RHSExpr, B, BExpr))

659 return NewMinMax;

660 }

661

662 if (AExpr != RHSExpr) {

663

664 if (auto *NewMinMax = tryCombination(RHS, RHSExpr, B, BExpr, A, AExpr))

665 return NewMinMax;

666 }

667

668 return nullptr;

669}

assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")

MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL

static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")

static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")

static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")

This file contains the declarations for the subclasses of Constant, which represent the different fla...

This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.

static bool runOnFunction(Function &F, bool PostInlining)

static bool runImpl(Function &F, const TargetLowering &TLI, const LibcallLoweringInfo &Libcalls, AssumptionCache *AC)

Module.h This file contains the declarations for the Module class.

static SCEVTypes convertToSCEVype(MaxMinT &MM)

Definition NaryReassociate.cpp:588

static bool isGEPFoldable(GetElementPtrInst *GEP, const TargetTransformInfo *TTI)

Definition NaryReassociate.cpp:328

#define INITIALIZE_PASS_DEPENDENCY(depName)

#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)

#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)

This file defines the SmallVector class.

This pass exposes codegen information to IR-level passes.

PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)

Get the result of an analysis pass for a given IR unit.

Represent the analysis usage information of a pass.

AnalysisUsage & addRequired()

AnalysisUsage & addPreserved()

Add the specified Pass class to the set of analyses preserved by this pass.

LLVM_ABI void setPreservesCFG()

This function should be called by the pass, iff they do not:

A function analysis which provides an AssumptionCache.

An immutable pass that tracks lazily created AssumptionCache objects.

A cache of @llvm.assume calls within a function.

LLVM Basic Block Representation.

Represents analyses that only rely on functions' control flow.

Analysis pass which computes a DominatorTree.

Legacy analysis pass which computes a DominatorTree.

Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.

FunctionPass class - This class is used to implement most global optimizations.

an instruction for type-safe pointer arithmetic to access elements of arrays and structs

LLVM_ABI void setIsInBounds(bool b=true)

Set or clear the inbounds flag on this GEP instruction.

void setDebugLoc(DebugLoc Loc)

Set the debug location information for this instruction.

A Module instance is used to store all the information related to an LLVM module.

bool runImpl(Function &F, AssumptionCache *AC_, DominatorTree *DT_, ScalarEvolution *SE_, TargetLibraryInfo *TLI_, TargetTransformInfo *TTI_)

Definition NaryReassociate.cpp:199

PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)

Definition NaryReassociate.cpp:182

static LLVM_ABI PassRegistry * getPassRegistry()

getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...

A set of analyses that are preserved following a run of a transformation pass.

static PreservedAnalyses all()

Construct a special preserved set that preserves all passes.

PreservedAnalyses & preserveSet()

Mark an analysis set as preserved.

PreservedAnalyses & preserve()

Mark an analysis as preserved.

This class represents an analyzed expression in the program.

Analysis pass that exposes the ScalarEvolution for a function.

The main scalar evolution driver.

void push_back(const T &Elt)

This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.

Analysis pass providing the TargetTransformInfo.

Analysis pass providing the TargetLibraryInfo.

Provides information about what library functions are available for the current target.

Wrapper pass for TargetTransformInfo.

This pass provides access to the codegen interfaces that are needed for IR-level transformations.

@ TCC_Free

Expected to fold away in lowering.

Type * getType() const

All values are typed, get the type of this value.

LLVM_ABI void setName(const Twine &Name)

Change the name of the value.

bool hasOneUse() const

Return true if there is exactly one use of this value.

iterator_range< user_iterator > users()

LLVM_ABI bool hasNUsesOrMore(unsigned N) const

Return true if this value has N uses or more.

LLVM_ABI void takeName(Value *V)

Transfer the name from V to this value.

Value handle that is nullable, but tries to track the Value.

bool isSequential() const

Type * getIndexedType() const

#define llvm_unreachable(msg)

Marks that the current location is not supposed to be reachable.

unsigned ID

LLVM IR allows to use arbitrary numbers as calling convention identifiers.

@ C

The default llvm calling convention, compatible with C.

BinaryOp_match< LHS, RHS, Instruction::Add > m_Add(const LHS &L, const RHS &R)

bool match(Val *V, const Pattern &P)

BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)

class_match< Value > m_Value()

Match an arbitrary value and ignore it.

ElementType

The element type of an SRV or UAV resource.

friend class Instruction

Iterator for Instructions in a `BasicBlock`.

This is an optimization pass for GlobalISel generic memory operations.

FunctionAddr VTableAddr Value

decltype(auto) dyn_cast(const From &Val)

dyn_cast - Return the argument parameter cast to the specified type.

LLVM_ABI FunctionPass * createNaryReassociatePass()

Definition NaryReassociate.cpp:165

auto dyn_cast_or_null(const Y &Val)

bool any_of(R &&range, UnaryPredicate P)

Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.

LLVM_ABI raw_ostream & dbgs()

dbgs() - This returns a reference to a raw_ostream for debugging messages.

generic_gep_type_iterator<> gep_type_iterator

class LLVM_GSL_OWNER SmallVector

Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...

IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >

LLVM_ABI OverflowResult computeOverflowForSignedAdd(const WithCache< const Value * > &LHS, const WithCache< const Value * > &RHS, const SimplifyQuery &SQ)

LLVM_ABI void initializeNaryReassociateLegacyPassPass(PassRegistry &)

LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructionsPermissive(SmallVectorImpl< WeakTrackingVH > &DeadInsts, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())

Same functionality as RecursivelyDeleteTriviallyDeadInstructions, but allow instructions that are not...

decltype(auto) cast(const From &Val)

cast - Return the argument parameter cast to the specified type.

gep_type_iterator gep_type_begin(const User *GEP)

iterator_range< df_iterator< T > > depth_first(const T &G)

AnalysisManager< Function > FunctionAnalysisManager

Convenience typedef for the Function analysis manager.

LLVM_ABI bool isKnownNonNegative(const Value *V, const SimplifyQuery &SQ, unsigned Depth=0)

Returns true if the give value is known to be non-negative.