MLIR: include/mlir/Analysis/DataFlowFramework.h Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16 #ifndef MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H

17 #define MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H

18

21 #include "llvm/ADT/EquivalenceClasses.h"

22 #include "llvm/ADT/Hashing.h"

23 #include "llvm/ADT/SetVector.h"

24 #include "llvm/Support/Compiler.h"

25 #include "llvm/Support/TypeName.h"

26 #include

27 #include

28

29 namespace mlir {

30

31

32

33

34

35

36

40 };

43 }

45 lhs = lhs | rhs;

46 return lhs;

47 }

50 }

51

52

53 class AnalysisState;

54

55

56

58

60 : block(parentBlock), point(pp) {}

61

62

64

65

66

67 using KeyTy = std::tuple<Block *, Block::iterator, Operation *>;

68

69

71

72

74 this->block = point.getBlock();

75 this->point = point.getPoint();

77 }

78

81 if (std::get<0>(key)) {

83 ProgramPoint(std::get<0>(key), std::get<1>(key));

84 }

86 }

87

88

89 bool isNull() const { return block == nullptr && op == nullptr; }

90

91

93 return block == std::get<0>(key) && point == std::get<1>(key) &&

94 op == std::get<2>(key);

95 }

96

98 return block == pp.block && point == pp.point && op == pp.op;

99 }

100

101

103

104

106

107

109

110

113

114

115

116 if (block == nullptr) {

117 return op;

118 }

119 return &*point;

120 }

121

122

125

126

127

128 if (block == nullptr) {

129 return op;

130 }

132 }

133

135

136 bool isBlockEnd() const { return block && block->end() == point; }

137

138

139 void print(raw_ostream &os) const;

140

141 private:

142 Block *block = nullptr;

144

145

146

148 };

149

152 return os;

153 }

154

155

156

157

158

159

160

161

162

163

164

165

166

168 public:

170

171

173

174

176

177

178 virtual void print(raw_ostream &os) const = 0;

179

180 protected:

181

183

184 private:

185

187 };

188

189

190

191

192

193

194

195

196

197

198

199 template <typename ConcreteT, typename Value>

201 public:

202

203

205

207

208

209

210 template

213 value(std::forward(value)) {}

214

215

216

217 template <typename... Args>

219 return uniquer.get({}, std::forward(args)...);

220 }

221

222

223 template

225 ValueT &&value) {

226 return new (alloc.allocate())

227 ConcreteT(std::forward(value));

228 }

229

230

231 bool operator==(const Value &value) const { return this->value == value; }

232

233

235 return point->getTypeID() == TypeID::get();

236 }

237

238

240

241 private:

242

244 };

245

246

247

248

249

250

252 : public PointerUnion<GenericLatticeAnchor *, ProgramPoint *, Value> {

254

255 using ParentTy::PointerUnion;

256

258

259

260 void print(raw_ostream &os) const;

261

262

264 };

265

266

267 class DataFlowAnalysis;

268

269 }

270

271 template <>

274

275 namespace mlir {

276

277

278

279

280

281

282

284 public:

286

287

288

289

290

292 interprocedural = enable;

293 return *this;

294 }

295

296

298

299 private:

300 bool interprocedural = true;

301 };

302

303

304

305

306

307

308

309

310

311

312

313

314

315

316

317

318

319

320

321

322

323

324

326 public:

328 : config(config) {

330 }

331

332

333 template <typename AnalysisT, typename... Args>

334 AnalysisT *load(Args &&...args);

335

336

337

339

340

341

342 template <typename StateT, typename AnchorT>

345 getLeaderAnchorOrSelf(LatticeAnchor(anchor));

346 const auto &mapIt = analysisStates.find(latticeAnchor);

347 if (mapIt == analysisStates.end())

348 return nullptr;

349 auto it = mapIt->second.find(TypeID::get());

350 if (it == mapIt->second.end())

351 return nullptr;

352 return static_cast<const StateT *>(it->second.get());

353 }

354

355

356 template

359

360

361 for (auto &&[TypeId, eqClass] : equivalentAnchorMap) {

362 if (!eqClass.contains(latticeAnchor)) {

363 continue;

364 }

365 llvm::EquivalenceClasses::member_iterator leaderIt =

366 eqClass.findLeader(latticeAnchor);

367

368

369 if (*leaderIt == latticeAnchor && ++leaderIt != eqClass.member_end()) {

370 analysisStates[*leaderIt][TypeId] =

371 std::move(analysisStates[latticeAnchor][TypeId]);

372 }

373

374 eqClass.erase(latticeAnchor);

375 }

376

377

378 analysisStates.erase(latticeAnchor);

379 }

380

381

383 analysisStates.clear();

384 equivalentAnchorMap.clear();

385 }

386

387

388

389 template <typename AnchorT, typename... Args>

391 return AnchorT::get(uniquer, std::forward(args)...);

392 }

393

394

399 else

402 }

403

406 nullptr);

407 }

408

413 else

416 }

417

420 nullptr);

421 }

422

423

424

425

426 using WorkItem = std::pair<ProgramPoint *, DataFlowAnalysis *>;

427

429

430

431

432 template <typename StateT, typename AnchorT>

434

435

436

437

438 template

440

441

442 template <typename StateT, typename AnchorT>

444

445

446 template

448

449

450

451

452

454

455

457

458 private:

459

461

462

463 bool isRunning = false;

464

465

466

467

468 std::queue worklist;

469

470

472

473

474

476

477

478

480 analysisStates;

481

482

483

484

485

487

488

490 };

491

492

493

494

495

496

497

498

499

500

501

502

503

504

505

506

507

508

509

511 public:

513

514

516

517

519

520

521 virtual void print(raw_ostream &os) const = 0;

522 LLVM_DUMP_METHOD void dump() const;

523

524

525

526

528

529 protected:

530

531

532

533

537 }

538

539

541

542 #if LLVM_ENABLE_ABI_BREAKING_CHECKS

543

544 StringRef debugName;

545 #endif

546

547 private:

548

549

550

551

552

553

554

555

556

558

559

561 };

562

563

564

565

566

567

568

569

570

571

572

573

574

575

576

577

578

579

580

581

582

584 public:

586

587

589

590

591

592

593

594

595

596

598

599

600

601

602

603

604

605

606

607

608

609

610

611

612

613

614

615

616

618

619

620

621

622

623

624

626

627 protected:

628

629

631

632

634

635

636 template

639 }

640

641

642 template <typename AnchorT, typename... Args>

644 return solver.getLatticeAnchor(std::forward(args)...);

645 }

646

647

648 template <typename StateT, typename AnchorT>

651 }

652

653

654

655

656 template <typename StateT, typename AnchorT>

659 }

660

661

662

663

664 template <typename StateT, typename AnchorT>

666 StateT *state = getOrCreate(anchor);

670 return state;

671 }

672

673

676 }

677

680 }

681

684 }

685

688 }

689

690

692

693 #if LLVM_ENABLE_ABI_BREAKING_CHECKS

694

695 StringRef debugName;

696 #endif

697

698 private:

699

701

702

704 };

705

706 template <typename AnalysisT, typename... Args>

708 childAnalyses.emplace_back(new AnalysisT(*this, std::forward(args)...));

709 #if LLVM_ENABLE_ABI_BREAKING_CHECKS

710 childAnalyses.back()->debugName = llvm::getTypeName();

711 #endif

712 return static_cast<AnalysisT *>(childAnalyses.back().get());

713 }

714

715 template

718 if (!equivalentAnchorMap.contains(TypeID::get())) {

719 return latticeAnchor;

720 }

721 const llvm::EquivalenceClasses &eqClass =

722 equivalentAnchorMap.at(TypeID::get());

723 llvm::EquivalenceClasses::member_iterator leaderIt =

724 eqClass.findLeader(latticeAnchor);

725 if (leaderIt != eqClass.member_end()) {

726 return *leaderIt;

727 }

728 return latticeAnchor;

729 }

730

731 template <typename StateT, typename AnchorT>

733

735 latticeAnchor = getLeaderAnchorOrSelf(latticeAnchor);

736 std::unique_ptr &state =

737 analysisStates[latticeAnchor][TypeID::get()];

738 if (!state) {

739 state = std::unique_ptr(new StateT(anchor));

740 #if LLVM_ENABLE_ABI_BREAKING_CHECKS

741 state->debugName = llvm::getTypeName();

742 #endif

743 }

744 return static_cast<StateT *>(state.get());

745 }

746

747 template

749 if (!equivalentAnchorMap.contains(TypeID::get())) {

750 return false;

751 }

752 const llvm::EquivalenceClasses &eqClass =

753 equivalentAnchorMap.at(TypeID::get());

754 if (!eqClass.contains(lhs) || !eqClass.contains(rhs))

755 return false;

756 return eqClass.isEquivalent(lhs, rhs);

757 }

758

759 template <typename StateT, typename AnchorT>

761 llvm::EquivalenceClasses &eqClass =

762 equivalentAnchorMap[TypeID::get()];

764 }

765

767 state.print(os);

768 return os;

769 }

770

772 anchor.print(os);

773 return os;

774 }

775

776 }

777

778 namespace llvm {

779

780 template <>

787 }

793 }

795 return hash_combine(pp.getBlock(), pp.getPoint().getNodePtr());

796 }

798 return lhs == rhs;

799 }

800 };

801

802

803 template

805 : public CastInfo<To, mlir::LatticeAnchor::PointerUnion> {};

806

807 template

809 : public CastInfo<To, const mlir::LatticeAnchor::PointerUnion> {};

810

811 }

812

813 #endif

Base class for generic analysis states.

AnalysisState(LatticeAnchor anchor)

Create the analysis state on the given lattice anchor.

LLVM_DUMP_METHOD void dump() const

LatticeAnchor getAnchor() const

Returns the lattice anchor this state is located at.

void addDependency(ProgramPoint *point, DataFlowAnalysis *analysis)

Add a dependency to this analysis state on a program point and an analysis.

virtual void print(raw_ostream &os) const =0

Print the contents of the analysis state.

virtual void onUpdate(DataFlowSolver *solver) const

This function is called by the solver when the analysis state is updated to enqueue more work items.

LatticeAnchor anchor

The lattice anchor to which the state belongs.

Block represents an ordered list of Operations.

OpListType::iterator iterator

Base class for all data-flow analyses.

void addDependency(AnalysisState *state, ProgramPoint *point)

Create a dependency between the given analysis state and program point on this analysis.

void unionLatticeAnchors(AnchorT anchor, AnchorT other)

Union input anchors under the given state.

void propagateIfChanged(AnalysisState *state, ChangeResult changed)

Propagate an update to a state if it changed.

const StateT * getOrCreateFor(ProgramPoint *dependent, AnchorT anchor)

Get a read-only analysis state for the given point and create a dependency on dependent.

ProgramPoint * getProgramPointAfter(Operation *op)

ProgramPoint * getProgramPointBefore(Operation *op)

Get a uniqued program point instance.

virtual void initializeEquivalentLatticeAnchor(Operation *top)

Initialize lattice anchor equivalence class from the provided top-level operation.

AnchorT * getLatticeAnchor(Args &&...args)

Get or create a custom lattice anchor.

virtual ~DataFlowAnalysis()

StateT * getOrCreate(AnchorT anchor)

Get the analysis state associated with the lattice anchor.

const DataFlowConfig & getSolverConfig() const

Return the configuration of the solver used for this analysis.

ProgramPoint * getProgramPointAfter(Block *block)

DataFlowAnalysis(DataFlowSolver &solver)

Create an analysis with a reference to the parent solver.

virtual LogicalResult initialize(Operation *top)=0

Initialize the analysis from the provided top-level operation by building an initial dependency graph between all lattice anchors of interest.

ProgramPoint * getProgramPointBefore(Block *block)

virtual LogicalResult visit(ProgramPoint *point)=0

Visit the given program point.

void registerAnchorKind()

Register a custom lattice anchor class.

Configuration class for data flow solver and child analyses.

DataFlowConfig & setInterprocedural(bool enable)

Set whether the solver should operate interprocedurally, i.e. enter the callee body when available.

bool isInterprocedural() const

Return true if the solver operates interprocedurally, false otherwise.

The general data-flow analysis solver.

void unionLatticeAnchors(AnchorT anchor, AnchorT other)

Union input anchors under the given state.

void enqueue(WorkItem item)

Push a work item onto the worklist.

bool isEquivalent(LatticeAnchor lhs, LatticeAnchor rhs) const

Return true if the two given lattice anchors are equivalent for the given state type.

void eraseState(AnchorT anchor)

Erase any analysis state associated with the given lattice anchor.

void propagateIfChanged(AnalysisState *state, ChangeResult changed)

Propagate an update to an analysis state if it changed by pushing dependent work items to the back of the work queue.

const StateT * lookupState(AnchorT anchor) const

Lookup an analysis state for the given lattice anchor.

ProgramPoint * getProgramPointAfter(Operation *op)

const DataFlowConfig & getConfig() const

Get the configuration of the solver.

ProgramPoint * getProgramPointBefore(Operation *op)

Get a uniqued program point instance.

void eraseAllStates()

Erase all analysis states.

AnchorT * getLatticeAnchor(Args &&...args)

Get a uniqued lattice anchor instance.

ProgramPoint * getProgramPointBefore(Block *block)

StateT * getOrCreateState(AnchorT anchor)

Get the state associated with the given lattice anchor.

LatticeAnchor getLeaderAnchorOrSelf(LatticeAnchor latticeAnchor) const

Get the leader lattice anchor of the equivalence group containing the input, or return the input lattice anchor itself if it does not belong to any group.

ProgramPoint * getProgramPointAfter(Block *block)

AnalysisT * load(Args &&...args)

Load an analysis into the solver. Return the analysis instance.

LogicalResult initializeAndRun(Operation *top)

Initialize the children analyses starting from the provided top-level operation and run the analysis until a fixpoint is reached.

DataFlowSolver(const DataFlowConfig &config=DataFlowConfig())

std::pair< ProgramPoint *, DataFlowAnalysis * > WorkItem

A work item on the solver queue is a program point, child analysis pair.

Base class for generic lattice anchor based on a concrete lattice anchor type and a content key.

bool operator==(const Value &value) const

Two lattice anchors are equal if their values are equal.

static ConcreteT * construct(StorageUniquer::StorageAllocator &alloc, ValueT &&value)

Allocate space for a lattice anchor and construct it in-place.

static bool classof(const GenericLatticeAnchor *point)

Provide LLVM-style RTTI using type IDs.

GenericLatticeAnchorBase(ValueT &&value)

Construct an instance of the lattice anchor using the provided value and the type ID of the concrete type.

const Value & getValue() const

Get the contents of the lattice anchor.

static ConcreteT * get(StorageUniquer &uniquer, Args &&...args)

Get a uniqued instance of this lattice anchor class with the given arguments.

Abstract class for generic lattice anchor.

virtual void print(raw_ostream &os) const =0

Print the lattice anchor.

TypeID getTypeID() const

Get the abstract lattice anchor's type identifier.

virtual Location getLoc() const =0

Get a derived source location for the lattice anchor.

GenericLatticeAnchor(TypeID typeID)

Create an abstract lattice anchor with type identifier.

virtual ~GenericLatticeAnchor()

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.

Operation is the basic unit of execution within MLIR.

Block * getBlock()

Returns the operation block that contains this operation.

This class acts as the base storage that all storage classes must derived from.

This is a utility allocator used to allocate memory for instances of derived types.

T * allocate()

Allocate an instance of the provided type.

A utility class to get or create instances of "storage classes".

Storage * get(function_ref< void(Storage *)> initFn, TypeID id, Args &&...args)

Gets a uniqued instance of 'Storage'.

void registerParametricStorageType(TypeID id)

Register a new parametric storage class, this is necessary to create instances of this class type.

This class provides an efficient unique identifier for a specific C++ type.

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

The OpAsmOpInterface, see OpAsmInterface.td for more details.

Include the generated interface declarations.

ChangeResult & operator|=(ChangeResult &lhs, ChangeResult rhs)

const FrozenRewritePatternSet GreedyRewriteConfig bool * changed

ChangeResult

A result type used to indicate if a change happened.

ChangeResult operator&(ChangeResult lhs, ChangeResult rhs)

ChangeResult operator|(ChangeResult lhs, ChangeResult rhs)

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed; this helps unify some of the attribute construction methods.

raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)

static unsigned getHashValue(mlir::ProgramPoint pp)

static mlir::ProgramPoint getEmptyKey()

static mlir::ProgramPoint getTombstoneKey()

static bool isEqual(mlir::ProgramPoint lhs, mlir::ProgramPoint rhs)

Fundamental IR components are supported as first-class lattice anchors.

LatticeAnchor(ParentTy point=nullptr)

Allow implicit conversion from the parent type.

Location getLoc() const

Get the source location of the lattice anchor.

void print(raw_ostream &os) const

Print the lattice anchor.

Program point represents a specific location in the execution of a program.

bool isNull() const

Returns true if this program point is not set, i.e. both its block and its operation are null.

bool isBlockStart() const

ProgramPoint(Block *parentBlock, Block::iterator pp)

Creates a new program point at the given location.

Block::iterator getPoint() const

Get the iterator this program point refers to.

ProgramPoint()

Create an empty program point.

Operation * getOperation() const

Get the operation this program point refers to.

Operation * getPrevOp() const

Get the previous operation of this program point.

static ProgramPoint * construct(StorageUniquer::StorageAllocator &alloc, KeyTy &&key)

bool operator==(const ProgramPoint &pp) const

bool operator==(const KeyTy &key) const

Two program points are equal if their block, iterator, and operation are all equal.

ProgramPoint(const ProgramPoint &point)

Create a new program point from the given program point.

std::tuple< Block *, Block::iterator, Operation * > KeyTy

The concrete key type used by the storage uniquer.

void print(raw_ostream &os) const

Print the program point.

Operation * getNextOp() const

Get the next operation of this program point.

Block * getBlock() const

Get the block that contains this program point.

ProgramPoint(Operation *op)

Creates a new program point at the given operation.