MLIR: include/mlir/Analysis/DataFlowFramework.h Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 #ifndef MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
17 #define MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
18
21 #include "llvm/ADT/EquivalenceClasses.h"
22 #include "llvm/ADT/Hashing.h"
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/Support/Compiler.h"
25 #include "llvm/Support/TypeName.h"
26 #include
27 #include
28
29 namespace mlir {
30
31
32
33
34
35
36
40 };
43 }
45 lhs = lhs | rhs;
46 return lhs;
47 }
50 }
51
52
53 class AnalysisState;
54
55
56
58
60 : block(parentBlock), point(pp) {}
61
62
64
65
66
67 using KeyTy = std::tuple<Block *, Block::iterator, Operation *>;
68
69
71
72
74 this->block = point.getBlock();
75 this->point = point.getPoint();
77 }
78
81 if (std::get<0>(key)) {
83 ProgramPoint(std::get<0>(key), std::get<1>(key));
84 }
86 }
87
88
89 bool isNull() const { return block == nullptr && op == nullptr; }
90
91
93 return block == std::get<0>(key) && point == std::get<1>(key) &&
94 op == std::get<2>(key);
95 }
96
98 return block == pp.block && point == pp.point && op == pp.op;
99 }
100
101
103
104
106
107
109
110
113
114
115
116 if (block == nullptr) {
117 return op;
118 }
119 return &*point;
120 }
121
122
125
126
127
128 if (block == nullptr) {
129 return op;
130 }
132 }
133
135
136 bool isBlockEnd() const { return block && block->end() == point; }
137
138
139 void print(raw_ostream &os) const;
140
141 private:
142 Block *block = nullptr;
144
145
146
148 };
149
152 return os;
153 }
154
155
156
157
158
159
160
161
162
163
164
165
166
168 public:
170
171
173
174
176
177
178 virtual void print(raw_ostream &os) const = 0;
179
180 protected:
181
183
184 private:
185
187 };
188
189
190
191
192
193
194
195
196
197
198
199 template <typename ConcreteT, typename Value>
201 public:
202
203
205
207
208
209
210 template
213 value(std::forward(value)) {}
214
215
216
217 template <typename... Args>
219 return uniquer.get({}, std::forward(args)...);
220 }
221
222
223 template
225 ValueT &&value) {
226 return new (alloc.allocate())
227 ConcreteT(std::forward(value));
228 }
229
230
231 bool operator==(const Value &value) const { return this->value == value; }
232
233
235 return point->getTypeID() == TypeID::get();
236 }
237
238
240
241 private:
242
244 };
245
246
247
248
249
250
252 : public PointerUnion<GenericLatticeAnchor *, ProgramPoint *, Value> {
254
255 using ParentTy::PointerUnion;
256
258
259
260 void print(raw_ostream &os) const;
261
262
264 };
265
266
267 class DataFlowAnalysis;
268
269 }
270
271 template <>
274
275 namespace mlir {
276
277
278
279
280
281
282
284 public:
286
287
288
289
290
292 interprocedural = enable;
293 return *this;
294 }
295
296
298
299 private:
300 bool interprocedural = true;
301 };
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
326 public:
328 : config(config) {
330 }
331
332
333 template <typename AnalysisT, typename... Args>
334 AnalysisT *load(Args &&...args);
335
336
337
339
340
341
342 template <typename StateT, typename AnchorT>
345 getLeaderAnchorOrSelf(LatticeAnchor(anchor));
346 const auto &mapIt = analysisStates.find(latticeAnchor);
347 if (mapIt == analysisStates.end())
348 return nullptr;
349 auto it = mapIt->second.find(TypeID::get());
350 if (it == mapIt->second.end())
351 return nullptr;
352 return static_cast<const StateT *>(it->second.get());
353 }
354
355
356 template
359
360
361 for (auto &&[TypeId, eqClass] : equivalentAnchorMap) {
362 if (!eqClass.contains(latticeAnchor)) {
363 continue;
364 }
365 llvm::EquivalenceClasses::member_iterator leaderIt =
366 eqClass.findLeader(latticeAnchor);
367
368
369 if (*leaderIt == latticeAnchor && ++leaderIt != eqClass.member_end()) {
370 analysisStates[*leaderIt][TypeId] =
371 std::move(analysisStates[latticeAnchor][TypeId]);
372 }
373
374 eqClass.erase(latticeAnchor);
375 }
376
377
378 analysisStates.erase(latticeAnchor);
379 }
380
381
383 analysisStates.clear();
384 equivalentAnchorMap.clear();
385 }
386
387
388
389 template <typename AnchorT, typename... Args>
391 return AnchorT::get(uniquer, std::forward(args)...);
392 }
393
394
399 else
402 }
403
406 nullptr);
407 }
408
413 else
416 }
417
420 nullptr);
421 }
422
423
424
425
426 using WorkItem = std::pair<ProgramPoint *, DataFlowAnalysis *>;
427
429
430
431
432 template <typename StateT, typename AnchorT>
434
435
436
437
438 template
440
441
442 template <typename StateT, typename AnchorT>
444
445
446 template
448
449
450
451
452
454
455
457
458 private:
459
461
462
463 bool isRunning = false;
464
465
466
467
468 std::queue worklist;
469
470
472
473
474
476
477
478
480 analysisStates;
481
482
483
484
485
487
488
490 };
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
511 public:
513
514
516
517
519
520
521 virtual void print(raw_ostream &os) const = 0;
522 LLVM_DUMP_METHOD void dump() const;
523
524
525
526
528
529 protected:
530
531
532
533
537 }
538
539
541
542 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
543
544 StringRef debugName;
545 #endif
546
547 private:
548
549
550
551
552
553
554
555
556
558
559
561 };
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
584 public:
586
587
589
590
591
592
593
594
595
596
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
618
619
620
621
622
623
624
626
627 protected:
628
629
631
632
634
635
636 template
639 }
640
641
642 template <typename AnchorT, typename... Args>
644 return solver.getLatticeAnchor(std::forward(args)...);
645 }
646
647
648 template <typename StateT, typename AnchorT>
651 }
652
653
654
655
656 template <typename StateT, typename AnchorT>
659 }
660
661
662
663
664 template <typename StateT, typename AnchorT>
666 StateT *state = getOrCreate(anchor);
670 return state;
671 }
672
673
676 }
677
680 }
681
684 }
685
688 }
689
690
692
693 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
694
695 StringRef debugName;
696 #endif
697
698 private:
699
701
702
704 };
705
706 template <typename AnalysisT, typename... Args>
708 childAnalyses.emplace_back(new AnalysisT(*this, std::forward(args)...));
709 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
710 childAnalyses.back()->debugName = llvm::getTypeName();
711 #endif
712 return static_cast<AnalysisT *>(childAnalyses.back().get());
713 }
714
715 template
718 if (!equivalentAnchorMap.contains(TypeID::get())) {
719 return latticeAnchor;
720 }
721 const llvm::EquivalenceClasses &eqClass =
722 equivalentAnchorMap.at(TypeID::get());
723 llvm::EquivalenceClasses::member_iterator leaderIt =
724 eqClass.findLeader(latticeAnchor);
725 if (leaderIt != eqClass.member_end()) {
726 return *leaderIt;
727 }
728 return latticeAnchor;
729 }
730
731 template <typename StateT, typename AnchorT>
733
735 latticeAnchor = getLeaderAnchorOrSelf(latticeAnchor);
736 std::unique_ptr &state =
737 analysisStates[latticeAnchor][TypeID::get()];
738 if (!state) {
739 state = std::unique_ptr(new StateT(anchor));
740 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
741 state->debugName = llvm::getTypeName();
742 #endif
743 }
744 return static_cast<StateT *>(state.get());
745 }
746
747 template
749 if (!equivalentAnchorMap.contains(TypeID::get())) {
750 return false;
751 }
752 const llvm::EquivalenceClasses &eqClass =
753 equivalentAnchorMap.at(TypeID::get());
754 if (!eqClass.contains(lhs) || !eqClass.contains(rhs))
755 return false;
756 return eqClass.isEquivalent(lhs, rhs);
757 }
758
759 template <typename StateT, typename AnchorT>
761 llvm::EquivalenceClasses &eqClass =
762 equivalentAnchorMap[TypeID::get()];
764 }
765
767 state.print(os);
768 return os;
769 }
770
772 anchor.print(os);
773 return os;
774 }
775
776 }
777
778 namespace llvm {
779
780 template <>
787 }
793 }
795 return hash_combine(pp.getBlock(), pp.getPoint().getNodePtr());
796 }
798 return lhs == rhs;
799 }
800 };
801
802
803 template
805 : public CastInfo<To, mlir::LatticeAnchor::PointerUnion> {};
806
807 template
809 : public CastInfo<To, const mlir::LatticeAnchor::PointerUnion> {};
810
811 }
812
813 #endif
Base class for generic analysis states.
AnalysisState(LatticeAnchor anchor)
Create the analysis state on the given lattice anchor.
LLVM_DUMP_METHOD void dump() const
LatticeAnchor getAnchor() const
Returns the lattice anchor this state is located at.
void addDependency(ProgramPoint *point, DataFlowAnalysis *analysis)
Add a dependency to this analysis state on a lattice anchor and an analysis.
virtual void print(raw_ostream &os) const =0
Print the contents of the analysis state.
virtual void onUpdate(DataFlowSolver *solver) const
This function is called by the solver when the analysis state is updated to enqueue more work items.
LatticeAnchor anchor
The lattice anchor to which the state belongs.
Block represents an ordered list of Operations.
OpListType::iterator iterator
Base class for all data-flow analyses.
void addDependency(AnalysisState *state, ProgramPoint *point)
Create a dependency between the given analysis state and lattice anchor on this analysis.
void unionLatticeAnchors(AnchorT anchor, AnchorT other)
Union input anchors under the given state.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
const StateT * getOrCreateFor(ProgramPoint *dependent, AnchorT anchor)
Get a read-only analysis state for the given point and create a dependency on dependent.
ProgramPoint * getProgramPointAfter(Operation *op)
ProgramPoint * getProgramPointBefore(Operation *op)
Get a uniqued program point instance.
virtual void initializeEquivalentLatticeAnchor(Operation *top)
Initialize lattice anchor equivalence class from the provided top-level operation.
AnchorT * getLatticeAnchor(Args &&...args)
Get or create a custom lattice anchor.
virtual ~DataFlowAnalysis()
StateT * getOrCreate(AnchorT anchor)
Get the analysis state associated with the lattice anchor.
const DataFlowConfig & getSolverConfig() const
Return the configuration of the solver used for this analysis.
ProgramPoint * getProgramPointAfter(Block *block)
DataFlowAnalysis(DataFlowSolver &solver)
Create an analysis with a reference to the parent solver.
virtual LogicalResult initialize(Operation *top)=0
Initialize the analysis from the provided top-level operation by building an initial dependency graph...
ProgramPoint * getProgramPointBefore(Block *block)
virtual LogicalResult visit(ProgramPoint *point)=0
Visit the given program point.
void registerAnchorKind()
Register a custom lattice anchor class.
Configuration class for data flow solver and child analyses.
DataFlowConfig & setInterprocedural(bool enable)
Set whether the solver should operate interprocedurally, i.e.
bool isInterprocedural() const
Return true if the solver operates interprocedurally, false otherwise.
The general data-flow analysis solver.
void unionLatticeAnchors(AnchorT anchor, AnchorT other)
Union input anchors under the given state.
void enqueue(WorkItem item)
Push a work item onto the worklist.
bool isEquivalent(LatticeAnchor lhs, LatticeAnchor rhs) const
Return whether the given lattice anchors are equivalent under the given state.
void eraseState(AnchorT anchor)
Erase any analysis state associated with the given lattice anchor.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to an analysis state if it changed by pushing dependent work items to the back of...
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
ProgramPoint * getProgramPointAfter(Operation *op)
const DataFlowConfig & getConfig() const
Get the configuration of the solver.
ProgramPoint * getProgramPointBefore(Operation *op)
Get a uniqued program point instance.
void eraseAllStates()
Erase all analysis states.
AnchorT * getLatticeAnchor(Args &&...args)
Get a uniqued lattice anchor instance.
ProgramPoint * getProgramPointBefore(Block *block)
StateT * getOrCreateState(AnchorT anchor)
Get the state associated with the given lattice anchor.
LatticeAnchor getLeaderAnchorOrSelf(LatticeAnchor latticeAnchor) const
Get leader lattice anchor in equivalence lattice anchor group, return input lattice anchor if input n...
ProgramPoint * getProgramPointAfter(Block *block)
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
DataFlowSolver(const DataFlowConfig &config=DataFlowConfig())
std::pair< ProgramPoint *, DataFlowAnalysis * > WorkItem
A work item on the solver queue is a program point, child analysis pair.
Base class for generic lattice anchor based on a concrete lattice anchor type and a content key.
bool operator==(const Value &value) const
Two lattice anchors are equal if their values are equal.
static ConcreteT * construct(StorageUniquer::StorageAllocator &alloc, ValueT &&value)
Allocate space for a lattice anchor and construct it in-place.
static bool classof(const GenericLatticeAnchor *point)
Provide LLVM-style RTTI using type IDs.
GenericLatticeAnchorBase(ValueT &&value)
Construct an instance of the lattice anchor using the provided value and the type ID of the concrete ...
const Value & getValue() const
Get the contents of the lattice anchor.
static ConcreteT * get(StorageUniquer &uniquer, Args &&...args)
Get a uniqued instance of this lattice anchor class with the given arguments.
Abstract class for generic lattice anchor.
virtual void print(raw_ostream &os) const =0
Print the lattice anchor.
TypeID getTypeID() const
Get the abstract lattice anchor's type identifier.
virtual Location getLoc() const =0
Get a derived source location for the lattice anchor.
GenericLatticeAnchor(TypeID typeID)
Create an abstract lattice anchor with type identifier.
virtual ~GenericLatticeAnchor()
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation is the basic unit of execution within MLIR.
Block * getBlock()
Returns the operation block that contains this operation.
This class acts as the base storage that all storage classes must derive from.
This is a utility allocator used to allocate memory for instances of derived types.
T * allocate()
Allocate an instance of the provided type.
A utility class to get or create instances of "storage classes".
Storage * get(function_ref< void(Storage *)> initFn, TypeID id, Args &&...args)
Gets a uniqued instance of 'Storage'.
void registerParametricStorageType(TypeID id)
Register a new parametric storage class, this is necessary to create instances of this class type.
This class provides an efficient unique identifier for a specific C++ type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Include the generated interface declarations.
ChangeResult & operator|=(ChangeResult &lhs, ChangeResult rhs)
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
ChangeResult
A result type used to indicate if a change happened.
ChangeResult operator&(ChangeResult lhs, ChangeResult rhs)
ChangeResult operator|(ChangeResult lhs, ChangeResult rhs)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
static unsigned getHashValue(mlir::ProgramPoint pp)
static mlir::ProgramPoint getEmptyKey()
static mlir::ProgramPoint getTombstoneKey()
static bool isEqual(mlir::ProgramPoint lhs, mlir::ProgramPoint rhs)
Fundamental IR components are supported as first-class lattice anchor.
LatticeAnchor(ParentTy point=nullptr)
Allow implicit conversion from the parent type.
Location getLoc() const
Get the source location of the lattice anchor.
void print(raw_ostream &os) const
Print the lattice anchor.
Program point represents a specific location in the execution of a program.
bool isNull() const
Returns true if this program point is null, i.e. neither a block nor an operation is set.
bool isBlockStart() const
ProgramPoint(Block *parentBlock, Block::iterator pp)
Creates a new program point at the given location.
Block::iterator getPoint() const
Get the iterator this program point refers to.
ProgramPoint()
Create an empty program point.
Operation * getOperation() const
Get the operation this program point refers to.
Operation * getPrevOp() const
Get the previous operation of this program point.
static ProgramPoint * construct(StorageUniquer::StorageAllocator &alloc, KeyTy &&key)
bool operator==(const ProgramPoint &pp) const
bool operator==(const KeyTy &key) const
Two program points are equal if their block, iterator, and operation are equal.
ProgramPoint(const ProgramPoint &point)
Create a new program point from the given program point.
std::tuple< Block *, Block::iterator, Operation * > KeyTy
The concrete key type used by the storage uniquer.
void print(raw_ostream &os) const
Print the program point.
Operation * getNextOp() const
Get the next operation of this program point.
Block * getBlock() const
Get the block that contains this program point.
ProgramPoint(Operation *op)
Creates a new program point at the given operation.