MLIR: lib/Dialect/Transform/IR/Utils.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
13 #include "llvm/Support/Debug.h"
14
15 using namespace mlir;
16
17 #define DEBUG_TYPE "transform-dialect-utils"
18 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
19
20
21
22
23
24 static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
25 return func1.isExternal() && (func2.isPublic() || func2.isExternal());
26 }
27
28
29
30
32 FunctionOpInterface func2) {
34 assert(func1->getParentOp() == func2->getParentOp() &&
35 "expected func1 and func2 to be in the same parent op");
36
37
38 if (func1.getFunctionType() != func2.getFunctionType()) {
39 return func1.emitError()
40 << "external definition has a mismatching signature ("
41 << func2.getFunctionType() << ")";
42 }
43
44
45 MLIRContext *context = func1->getContext();
46 auto *td = context->getLoadedDialecttransform::TransformDialect();
47 StringAttr consumedName = td->getConsumedAttrName();
48 StringAttr readOnlyName = td->getReadOnlyAttrName();
49 for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
50 bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
51 bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
52 bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
53 bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
54 if (!isExternalConsumed && !isExternalReadonly) {
55 if (isConsumed)
56 func2.setArgAttr(i, consumedName, UnitAttr::get(context));
57 else if (isReadonly)
58 func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
59 continue;
60 }
61
62 if ((isExternalConsumed && !isConsumed) ||
63 (isExternalReadonly && !isReadonly)) {
64 return func1.emitError()
65 << "external definition has mismatching consumption "
66 "annotations for argument #"
67 << i;
68 }
69 }
70
71
72 assert(func1.isExternal());
73 func1->erase();
74
76 }
77
// mergeSymbolsInto: merge all symbols from `other` into `target`.
// NOTE(review): this extraction has lost the function signature and the
// opening of both trait assertions; only their messages survive below.
// Presumably the signature is
//   InFlightDiagnostic mergeSymbolsInto(Operation *target,
//                                       OwningOpRef<Operation *> other)
// - confirm against the header.
// NOTE(review): both messages name `target`; the second assertion presumably
// checks `other`, so its message may be a copy-paste slip - confirm.
82 "requires target to implement the 'SymbolTable' trait");
84 "requires target to implement the 'SymbolTable' trait");
85
// NOTE(review): the construction of `targetSymbolTable` and
// `otherSymbolTable` is missing from this extraction.
88
89
90
91
92
// Step 1: rename private symbols so that name collisions between the two
// symbol tables are resolved by renaming where possible.
93 LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
94
// Check each symbol table against the other one. NOTE(review): the zip's
// argument lists are only partially present in this extraction.
95 for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(
98 &targetSymbolTable})) {
99 Operation *symbolTableOp = symbolTable->getOp();
// NOTE(review): the loop header that binds `op` (presumably iterating the
// ops in the body of `symbolTableOp`) is missing from this extraction.
101 auto symbolOp = dyn_cast(op);
102 if (!symbolOp)
103 continue;
104 StringAttr name = symbolOp.getNameAttr();
105 LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n");
106
// Look for an op with the same name in the other symbol table.
107
108 auto collidingOp =
109 cast_or_null(otherSymbolTable->lookup(name));
110 if (!collidingOp)
111 continue;
112
113 LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue());
114
// A collision between two functions is skipped here when they will be
// merged in step 2. NOTE(review): the condition guarding the `continue`
// below (presumably a canMergeInto check in either direction) is missing.
115
116 if (auto funcOp = dyn_cast(op),
117 collidingFuncOp =
118 dyn_cast(collidingOp.getOperation());
119 funcOp && collidingFuncOp) {
122 LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
123 "will be merged\n");
124 continue;
125 }
126
// Functions that cannot be merged presumably fall through and are handled
// like any other colliding symbols.
127
128 LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
129 }
130
// Resolve a collision by renaming `op` so that its name is unique with
// respect to `otherSymbolTable` as well; a failed rename attaches a note
// pointing at `otherOp`. NOTE(review): the remaining lambda parameters
// (presumably the two SymbolTable references used with `.` below), the
// creation of `diag`, and the return statements are missing here.
131
132 auto renameToUnique =
133 [&](SymbolOpInterface op, SymbolOpInterface otherOp,
136 LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
137 FailureOr maybeNewName =
138 symbolTable.renameToUnique(op, {&otherSymbolTable});
139 if (failed(maybeNewName)) {
141 diag.attachNote(otherOp->getLoc())
142 << "attempted renaming due to collision with this op";
144 }
145 LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue()
146 << "\n");
148 };
149
// Prefer renaming whichever colliding symbol is private. NOTE(review): the
// heads of the renameToUnique calls and the `return diag;` lines are missing
// from this extraction.
150 if (symbolOp.isPrivate()) {
152 symbolOp, collidingOp, *symbolTable, *otherSymbolTable);
153 if (failed(diag))
155 continue;
156 }
157 if (collidingOp.isPrivate()) {
159 collidingOp, symbolOp, *otherSymbolTable, *symbolTable);
160 if (failed(diag))
162 continue;
163 }
// Neither symbol is private, so the collision is reported as an error.
// NOTE(review): the emitError call that creates `diag` and the final
// `return diag;` are missing from this extraction.
164 LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
166 << "doubly defined symbol @" << name.getValue();
167 diag.attachNote(collidingOp->getLoc()) << "previously defined here";
169 }
170 }
171
// Verify the input ops after renaming. NOTE(review): the loop header and the
// failure check guarding this return are missing from this extraction.
172
173
176 return op->emitError() << "failed to verify input op after renaming";
177 }
178
// Step 2: move every symbol op from `other` into the target, merging
// colliding functions.
179
180
181
182 LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
183 {
// Collect the symbol ops first, then move them in a second loop.
// NOTE(review): the declaration of `opsToMove` is missing from this
// extraction.
185 for (Operation &op : other->getRegion(0).front()) {
186 if (auto symbol = dyn_cast(op))
187 opsToMove.push_back(symbol);
188 }
189
190 for (SymbolOpInterface op : opsToMove) {
// Remember any op of the same name that is already in the target.
191
192 auto collidingOp = cast_or_null(
193 targetSymbolTable.lookup(op.getNameAttr()));
194
195
196 LLVM_DEBUG(DBGS() << " moving @" << op.getName());
// NOTE(review): the statement that moves `op` into the target's body
// appears to be missing from this extraction.
199
200
// Without a collision there is nothing left to do for this op.
201 if (!collidingOp) {
202 LLVM_DEBUG(llvm::dbgs() << " without collision\n");
203 continue;
204 }
205
// Step 1 is expected to leave only function/function collisions unresolved,
// so both ops should be functions (`cast` asserts this).
206
207
208 auto funcOp = cast(op.getOperation());
209 auto collidingFuncOp =
210 cast(collidingOp.getOperation());
211
212
213
214
// NOTE(review): as extracted, this condition reads `(funcOp, collidingFuncOp)`,
// a comma expression that discards `funcOp` and only tests whether
// `collidingFuncOp` is non-null, so the swap would always happen. Given the
// assert below, the intended test is presumably
// `!canMergeInto(funcOp, collidingFuncOp)` - confirm against upstream.
215 if ((funcOp, collidingFuncOp)) {
216 std::swap(funcOp, collidingFuncOp);
217 }
218 assert(canMergeInto(funcOp, collidingFuncOp));
219
220 LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
221 << collidingFuncOp.getLoc() << ":\n"
222 << collidingFuncOp << "\n");
223
// Make `collidingFuncOp` the symbol-table entry for this name.
224
225 targetSymbolTable.remove(funcOp);
226 targetSymbolTable.insert(collidingFuncOp);
227 assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
228
// Merge `funcOp` into `collidingFuncOp`. NOTE(review): the call producing
// `diag` (presumably mergeInto(funcOp, collidingFuncOp)) and the
// `return diag;` are missing from this extraction.
229
230 {
232 if (failed(diag))
234 }
235 }
236 }
237
// NOTE(review): only the message of the final target verification survives
// in this extraction; the verify call and its return are missing.
240 << "failed to verify target op after merging symbols";
241
242 LLVM_DEBUG(DBGS() << "done merging ops\n");
// NOTE(review): the success return at the end of the function appears to be
// missing from this extraction.
244 }
static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2)
Return whether func1 can be merged into func2.
static InFlightDiagnostic mergeInto(FunctionOpInterface func1, FunctionOpInterface func2)
Merge func1 into func2.
static std::string diag(const llvm::Value &value)
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
This class acts as an owning reference to an op, and will automatically destroy the held op on destruction if the held op is valid.
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
void remove(Operation *op)
Remove the given symbol from the table, without deleting it.
InFlightDiagnostic mergeSymbolsInto(Operation *target, OwningOpRef< Operation * > other)
Merge all symbols from other into target.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.