MLIR: lib/Dialect/Transform/IR/Utils.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

13 #include "llvm/Support/Debug.h"

14

15 using namespace mlir;

16

17 #define DEBUG_TYPE "transform-dialect-utils"

18 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

19

20

21

22

23

24 static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {

25 return func1.isExternal() && (func2.isPublic() || func2.isExternal());

26 }

27

28

29

30

32 FunctionOpInterface func2) {

34 assert(func1->getParentOp() == func2->getParentOp() &&

35 "expected func1 and func2 to be in the same parent op");

36

37

38 if (func1.getFunctionType() != func2.getFunctionType()) {

39 return func1.emitError()

40 << "external definition has a mismatching signature ("

41 << func2.getFunctionType() << ")";

42 }

43

44

45 MLIRContext *context = func1->getContext();

46 auto *td = context->getLoadedDialecttransform::TransformDialect();

47 StringAttr consumedName = td->getConsumedAttrName();

48 StringAttr readOnlyName = td->getReadOnlyAttrName();

49 for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {

50 bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;

51 bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;

52 bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;

53 bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;

54 if (!isExternalConsumed && !isExternalReadonly) {

55 if (isConsumed)

56 func2.setArgAttr(i, consumedName, UnitAttr::get(context));

57 else if (isReadonly)

58 func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));

59 continue;

60 }

61

62 if ((isExternalConsumed && !isConsumed) ||

63 (isExternalReadonly && !isReadonly)) {

64 return func1.emitError()

65 << "external definition has mismatching consumption "

66 "annotations for argument #"

67 << i;

68 }

69 }

70

71

72 assert(func1.isExternal());

73 func1->erase();

74

76 }

77

82 "requires target to implement the 'SymbolTable' trait");

84 "requires target to implement the 'SymbolTable' trait");

85

88

89

90

91

92

93 LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");

94

95 for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(

98 &targetSymbolTable})) {

99 Operation *symbolTableOp = symbolTable->getOp();

101 auto symbolOp = dyn_cast(op);

102 if (!symbolOp)

103 continue;

104 StringAttr name = symbolOp.getNameAttr();

105 LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n");

106

107

108 auto collidingOp =

109 cast_or_null(otherSymbolTable->lookup(name));

110 if (!collidingOp)

111 continue;

112

113 LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue());

114

115

116 if (auto funcOp = dyn_cast(op),

117 collidingFuncOp =

118 dyn_cast(collidingOp.getOperation());

119 funcOp && collidingFuncOp) {

122 LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "

123 "will be merged\n");

124 continue;

125 }

126

127

128 LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");

129 }

130

131

132 auto renameToUnique =

133 [&](SymbolOpInterface op, SymbolOpInterface otherOp,

136 LLVM_DEBUG(llvm::dbgs() << ", renaming\n");

137 FailureOr maybeNewName =

138 symbolTable.renameToUnique(op, {&otherSymbolTable});

139 if (failed(maybeNewName)) {

141 diag.attachNote(otherOp->getLoc())

142 << "attempted renaming due to collision with this op";

144 }

145 LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue()

146 << "\n");

148 };

149

150 if (symbolOp.isPrivate()) {

152 symbolOp, collidingOp, *symbolTable, *otherSymbolTable);

153 if (failed(diag))

155 continue;

156 }

157 if (collidingOp.isPrivate()) {

159 collidingOp, symbolOp, *otherSymbolTable, *symbolTable);

160 if (failed(diag))

162 continue;

163 }

164 LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");

166 << "doubly defined symbol @" << name.getValue();

167 diag.attachNote(collidingOp->getLoc()) << "previously defined here";

169 }

170 }

171

172

173

176 return op->emitError() << "failed to verify input op after renaming";

177 }

178

179

180

181

182 LLVM_DEBUG(DBGS() << "moving all symbols into target\n");

183 {

185 for (Operation &op : other->getRegion(0).front()) {

186 if (auto symbol = dyn_cast(op))

187 opsToMove.push_back(symbol);

188 }

189

190 for (SymbolOpInterface op : opsToMove) {

191

192 auto collidingOp = cast_or_null(

193 targetSymbolTable.lookup(op.getNameAttr()));

194

195

196 LLVM_DEBUG(DBGS() << " moving @" << op.getName());

199

200

201 if (!collidingOp) {

202 LLVM_DEBUG(llvm::dbgs() << " without collision\n");

203 continue;

204 }

205

206

207

208 auto funcOp = cast(op.getOperation());

209 auto collidingFuncOp =

210 cast(collidingOp.getOperation());

211

212

213

214

215 if (canMergeInto(funcOp, collidingFuncOp)) {

216 std::swap(funcOp, collidingFuncOp);

217 }

218 assert(canMergeInto(funcOp, collidingFuncOp));

219

220 LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "

221 << collidingFuncOp.getLoc() << ":\n"

222 << collidingFuncOp << "\n");

223

224

225 targetSymbolTable.remove(funcOp);

226 targetSymbolTable.insert(collidingFuncOp);

227 assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);

228

229

230 {

232 if (failed(diag))

234 }

235 }

236 }

237

240 << "failed to verify target op after merging symbols";

241

242 LLVM_DEBUG(DBGS() << "done merging ops\n");

244 }

static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2)

Return whether func1 can be merged into func2.

static InFlightDiagnostic mergeInto(FunctionOpInterface func1, FunctionOpInterface func2)

Merge func1 into func2.

static std::string diag(const llvm::Value &value)

This class represents a diagnostic that is inflight and set to be reported.

MLIRContext is the top-level object for a collection of MLIR operations.

Dialect * getLoadedDialect(StringRef name)

Get a registered IR dialect with the given namespace.

A trait used to provide symbol table functionalities to a region operation.

Operation is the basic unit of execution within MLIR.

bool hasTrait()

Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().

InFlightDiagnostic emitError(const Twine &message={})

Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.

Region & getRegion(unsigned index)

Returns the region held by this operation at position 'index'.

This class acts as an owning reference to an op, and will automatically destroy the held op on destruction if the held op is valid.

This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.

Operation * lookup(StringRef name) const

Look up a symbol with the specified name, returning null if no such name exists.

StringAttr insert(Operation *symbol, Block::iterator insertPt={})

Insert a new symbol into the table, and rename it as necessary to avoid collisions.

void remove(Operation *op)

Remove the given symbol from the table, without deleting it.

InFlightDiagnostic mergeSymbolsInto(Operation *target, OwningOpRef< Operation * > other)

Merge all symbols from other into target.

Include the generated interface declarations.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute construction methods.

LogicalResult verify(Operation *op, bool verifyRecursively=true)

Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.