MLIR: lib/AsmParser/DialectSymbolParser.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

26 #include "llvm/Support/MemoryBuffer.h"

27 #include "llvm/Support/SourceMgr.h"

28 #include

29 #include

30 #include

31

32 using namespace mlir;

34 using llvm::MemoryBuffer;

35 using llvm::SourceMgr;

36

37 namespace {

38

39

40

/// Dialect-facing parser handed to Dialect::parseAttribute / Dialect::parseType
/// by the extended attribute/type parsing code in this file. It derives from
/// AsmParserImpl (the template argument is not visible in this listing) and
/// additionally remembers the full textual spec of the symbol being parsed.
41 class CustomDialectAsmParser : public AsmParserImpl {

42 public:

// NOTE(review): the base-class initializer (listing line 44) is missing from
// this view; only the `fullSpec` member initializer is shown.
43 CustomDialectAsmParser(StringRef fullSpec, Parser &parser)

45 fullSpec(fullSpec) {}

46 ~CustomDialectAsmParser() override = default;

47

48

49

/// Returns the full symbol text this parser was constructed with.
50 StringRef getFullSymbolSpec() const override { return fullSpec; }

51

52 private:

53

/// Non-owning view of the symbol text. Presumably it points into the
/// enclosing parser's source buffer, which must outlive this object; confirm
/// at call sites.
54 StringRef fullSpec;

55 };

56 }

57

58

59

60

61

62

63

64

65

67 bool &isCodeCompletion) {

// Scans the text of a dialect symbol body starting at the '<' under curPtr.
// It tracks nested <>, [], () and {} until every opened bracket is closed
// again, then shortens `body` so that it ends at the final scan position.
// If the code-completion location is reached first, the scan stops there and
// `isCodeCompletion` is set instead of an error being reported.
// NOTE(review): the first signature line is not visible in this listing, and
// neither are the declarations of curPtr, codeCompleteLoc and
// nestedPunctuation (listing lines 66, 71, 77 and 78 are missing).
68

69

70

72

73

74

75

// Callers must position the scan on the opening '<'.
76 assert(*curPtr == '<');

79

80

// Reports the innermost still-open bracket as unbalanced.
81 auto emitPunctError = [&] {

82 return emitError() << "unbalanced '" << nestedPunctuation.back()

83 << "' character in pretty dialect name";

84 };

85

// Closes the innermost open bracket if it is `expectedToken`; otherwise
// reports it as unbalanced.
86 auto checkNestedPunctuation = [&](char expectedToken) -> ParseResult {

87 if (nestedPunctuation.back() != expectedToken)

88 return emitPunctError();

89 nestedPunctuation.pop_back();

90 return success();

91 };

// Consume one character per iteration until the bracket stack is empty.
92 do {

93

94

// Stop at the code-completion point and leave the loop.
95 if (curPtr == codeCompleteLoc) {

96 isCodeCompletion = true;

97 nestedPunctuation.clear();

98 break;

99 }

100

101 char c = *curPtr++;

102 switch (c) {

// Reaching a nul (or end of input) ends the scan with an error. It is an
// unbalanced-bracket error if a bracket is still open, otherwise a generic
// one.
103 case '\0':

104

105 if (!nestedPunctuation.empty())

106 return emitPunctError();

107 return emitError("unexpected nul or EOF in pretty dialect name");

// Opening brackets are pushed and later matched by their closers.
108 case '<':

109 case '[':

110 case '(':

111 case '{':

112 nestedPunctuation.push_back(c);

113 continue;

114

// Consume the '>' of an arrow ("->") so it is not taken as a closing '>'.
115 case '-':

116

117 if (*curPtr == '>')

118 ++curPtr;

119 continue;

120

// Closing brackets must match the innermost open one.
121 case '>':

122 if (failed(checkNestedPunctuation('<')))

123 return failure();

124 break;

125 case ']':

126 if (failed(checkNestedPunctuation('[')))

127 return failure();

128 break;

129 case ')':

130 if (failed(checkNestedPunctuation('(')))

131 return failure();

132 break;

133 case '}':

134 if (failed(checkNestedPunctuation('{')))

135 return failure();

136 break;

// String literal. Parts of this case (listing lines 139-140, 144 and 151)
// are not visible in this listing.
137 case '"': {

138

141

142

143

145 isCodeCompletion = true;

146 nestedPunctuation.clear();

147 break;

148 }

149

150

152 return failure();

153 break;

154 }

155

// Any other character is plain body text.
156 default:

157 continue;

158 }

159 } while (!nestedPunctuation.empty());

160

161

162

164

// `body` keeps its start and now ends where the scan stopped.
165 unsigned length = curPtr - body.begin();

166 body = StringRef(body.data(), length);

167 return success();

168 }

169

170

/// Parses an extended dialect symbol (an attribute or a type).
/// - A bare identifier with nothing attached is resolved through `aliases`.
/// - Otherwise the identifier is split into a dialect name and a body, both
///   are forwarded to `createSymbol(dialectName, symbolData, loc)`, and its
///   result is returned.
/// The error paths visible here return nullptr.
/// NOTE(review): the declaration line that names this function (listing line
/// 172) and several interior lines are missing from this listing.
171 template <typename Symbol, typename SymbolAliasMap, typename CreateFn>

173 SymbolAliasMap &aliases,

174 CreateFn &&createSymbol) {

176

177

// Drop the first character of the token spelling (presumably the sigil
// that introduces the symbol; confirm).
178 StringRef identifier = tok.getSpelling().drop_front();

181

182

186

187

// `dialect.body` is the pretty form; `dialect.` (trailing '.') counts too.
188 auto [dialectName, symbolData] = identifier.split('.');

189 bool isPrettyName = !symbolData.empty() || identifier.back() == '.';

190

191

192

// True when the parser's current token begins exactly where the identifier
// ends. Part of this condition (listing line 194) is not visible here.
193 bool hasTrailingData =

195 identifier.bytes_end() == p.getTokenSpelling().bytes_begin();

196

197

198

// Neither pretty nor followed by attached data: the identifier names an alias.
199 if (!hasTrailingData && !isPrettyName) {

200

201 auto aliasIt = aliases.find(identifier);

202 if (aliasIt == aliases.end())

203 return (p.emitWrongTokenError("undefined symbol alias id '" + identifier +

204 "'"),

205 nullptr);

// When asmState is provided, record this alias use. The type- and
// attribute-specific calls (listing lines 208 and 210) are not visible;
// presumably addTypeAliasUses / addAttrAliasUses.
206 if (asmState) {

207 if constexpr (std::is_same_v<Symbol, Type>)

209 else

211 }

212 return aliasIt->second;

213 }

214

215

216

217

// Non-pretty form: the body follows the dialect name directly.
218 if (!isPrettyName) {

219

220 symbolData = StringRef(dialectName.end(), 0);

221

222

223 bool isCodeCompletion = false;

225 return nullptr;

// Drop the leading delimiter of the scanned body (presumably '<').
226 symbolData = symbolData.drop_front();

227

228

229

// The trailing delimiter is presumably absent when the scan was cut short
// by code completion, so it is only dropped otherwise.
230 if (!isCodeCompletion)

231 symbolData = symbolData.drop_back();

232 } else {

// Pretty form: point diagnostics at the start of the body text.
233 loc = SMLoc::getFromPointer(symbolData.data());

234

235

236

238 return nullptr;

239 }

240

// Let the caller build the concrete symbol from the dialect name and body.
241 return createSymbol(dialectName, symbolData, loc);

242 }

243

244

245

246

247

248

249

250

251

// Parses an extended attribute: a dialect attribute or an attribute alias.
// `type`, if non-null, is the type the resulting attribute is expected to have.
// NOTE(review): the signature of this function (listing lines 252-253) and the
// line passing the leading arguments to parseExtendedSymbol (255) are missing
// from this listing. The cross-reference index lists the function as
// `Attribute parseExtendedAttr(Type type)`.
254 Attribute attr = parseExtendedSymbol(

// Builds the attribute from a dialect name and its body text.
256 [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute {

257

// An explicit `: type` after the body takes precedence over `type`.
258 Type attrType = type;

259 if (consumeIf(Token::colon) && !(attrType = parseType()))

260 return Attribute();

261

262

// Loadable dialect: let it parse the body through a CustomDialectAsmParser
// positioned at the body text, then restore the lexer position.
263 if (Dialect *dialect =

264 builder.getContext()->getOrLoadDialect(dialectName)) {

265

266 const char *curLexerPos = getToken().getLoc().getPointer();

267 resetToken(symbolData.data());

268

269

270 CustomDialectAsmParser customParser(symbolData, *this);

271 Attribute attr = dialect->parseAttribute(customParser, attrType);

272 resetToken(curLexerPos);

273 return attr;

274 }

275

276

// Otherwise keep the body verbatim as an opaque attribute. It is typed as
// NoneType when no type is available.
277 return OpaqueAttr::getChecked(

278 [&] { return emitError(loc); }, StringAttr::get(ctx, dialectName),

279 symbolData, attrType ? attrType : NoneType::get(ctx));

280 });

281

282

// If a type was expected, a typed attribute must carry exactly that type.
// NOTE(review): the template argument of dyn_cast_or_null was lost in this
// listing.
283 auto typedAttr = dyn_cast_or_null(attr);

284 if (type && typedAttr && typedAttr.getType() != type) {

285 emitError("attribute type different than expected: expected ")

286 << type << ", but got " << typedAttr.getType();

287 return nullptr;

288 }

289 return attr;

290 }

291

292

293

294

295

296

297

298

/// Parses an extended type: a dialect type or a type alias (resolved through
/// state.symbols.typeAliasDefinitions). The dialect parses the type if it
/// can be loaded; otherwise the body is kept as an opaque type.
/// NOTE(review): listing lines 300 and 319 are missing from this view.
299 Type Parser::parseExtendedType() {

301 return parseExtendedSymbol(

302 *this, state.asmState, state.symbols.typeAliasDefinitions,

// Builds the type from a dialect name and its body text.
303 [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type {

304

// Loadable dialect: parse the body with a CustomDialectAsmParser positioned
// at the body text, then restore the lexer position.
305 if (auto *dialect = ctx->getOrLoadDialect(dialectName)) {

306

307 const char *curLexerPos = getToken().getLoc().getPointer();

308 resetToken(symbolData.data());

309

310

311 CustomDialectAsmParser customParser(symbolData, *this);

312 Type type = dialect->parseType(customParser);

313 resetToken(curLexerPos);

314 return type;

315 }

316

317

// Otherwise keep the body verbatim as an opaque type.
318 return OpaqueType::getChecked([&] { return emitError(loc); },

320 symbolData);

321 });

322 }

323

324

325

326

327

328

329

/// Parses one symbol of type T from `inputStr` by running `parserFn` on a
/// fresh Parser over that text.
/// - If `numReadOut` is non-null, it receives the number of characters
///   consumed.
/// - Otherwise, any unconsumed trailing text is reported as an error.
/// Returns a null T on failure.
/// NOTE(review): listing lines 331, 343-344, 349 and 355 are missing from this
/// view. The declarations of `config`, `aliasState`, `startTok` and `endTok`
/// are not visible here.
330 template <typename T, typename ParserFn>

332 size_t *numReadOut, bool isKnownNullTerminated,

333 ParserFn &&parserFn) {

334

335

// Reference the caller's text directly only when it is known to be
// null-terminated; otherwise parse a copy. In both calls the input text is
// also passed as the buffer identifier.
336 auto memBuffer =

337 isKnownNullTerminated

338 ? MemoryBuffer::getMemBuffer(inputStr,

339 inputStr)

340 : MemoryBuffer::getMemBufferCopy(inputStr, inputStr);

341 SourceMgr sourceMgr;

342 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());

345 ParserState state(sourceMgr, config, aliasState, nullptr,

346 nullptr);

347 Parser parser(state);

348

350 T symbol = parserFn(parser);

351 if (!symbol)

352 return T();

353

354

// Characters consumed: the distance from startTok to endTok.
356 size_t numRead =

357 endTok.getLoc().getPointer() - startTok.getLoc().getPointer();

358 if (numReadOut) {

359 *numReadOut = numRead;

360 } else if (numRead != inputStr.size()) {

361 parser.emitError(endTok.getLoc()) << "found trailing characters: '"

362 << inputStr.drop_front(numRead) << "'";

363 return T();

364 }

365 return symbol;

366 }

367

369 Type type, size_t *numRead,

370 bool isKnownNullTerminated) {

// Parses a single attribute from `attrStr` via parseSymbol.
// NOTE(review): the first signature line (listing 368) and the parsing
// callback (373) are not visible here. The cross-reference index lists this as
// `Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context,
// Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)`.
371 return parseSymbol(

372 attrStr, context, numRead, isKnownNullTerminated,

374 }

376 bool isKnownNullTerminated) {

// Parses a single type from `typeStr` via parseSymbol.
// NOTE(review): listing lines 375 and 378 (the first signature line and the
// parsing callback) are not visible here.
377 return parseSymbol(typeStr, context, numRead, isKnownNullTerminated,

379 }

static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState, SymbolAliasMap &aliases, CreateFn &&createSymbol)

Parse an extended dialect symbol.

static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t *numReadOut, bool isKnownNullTerminated, ParserFn &&parserFn)

Parses a symbol, of type 'T', and returns it if parsing was successful.

static MLIRContext * getContext(OpFoldResult val)

This class represents state from a parsed MLIR textual format string.

void addTypeAliasUses(StringRef name, SMRange locations)

void addAttrAliasUses(StringRef name, SMRange locations)

Attributes are known-constant values of operations.

The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and types.

const char * getCodeCompleteLoc() const

Return the code completion location of the lexer, or nullptr if there is none.

MLIRContext is the top-level object for a collection of MLIR operations.

This class represents a configuration for the MLIR assembly parser.

This represents a token in the MLIR syntax.

SMRange getLocRange() const

bool isCodeCompletion() const

Returns true if the current token represents a code completion.

StringRef getSpelling() const

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

This class provides the implementation of the generic parser methods within AsmParser.

This class implements support for parsing global entities like attributes and types.

Type parseType()

Parse an arbitrary type.

InFlightDiagnostic emitError(const Twine &message={})

Emit an error and return failure.

ParserState & state

The Parser is subclassed and reinstantiated.

Attribute parseAttribute(Type type={})

Parse an arbitrary attribute with an optional type.

StringRef getTokenSpelling() const

void consumeToken()

Advance the current lexer onto the next token.

ParseResult parseDialectSymbolBody(StringRef &body, bool &isCodeCompletion)

Parse the body of a dialect symbol, which starts and ends with <>'s, and may be recursive.

MLIRContext * getContext() const

InFlightDiagnostic emitWrongTokenError(const Twine &message={})

Emit an error about a "wrong token".

void resetToken(const char *tokPos)

Reset the parser to the given lexer position.

Attribute parseExtendedAttr(Type type)

Parse an extended attribute.

const Token & getToken() const

Return the current token the parser is inspecting.

Attribute codeCompleteDialectSymbol(const llvm::StringMap< Attribute > &aliases)

Include the generated interface declarations.

const FrozenRewritePatternSet GreedyRewriteConfig config

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.

Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)

This parses a single MLIR attribute to an MLIR context if it was valid.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.

Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)

This parses a single MLIR type to an MLIR context if it was valid.

This class refers to all of the state maintained globally by the parser, such as the current lexer position.

SymbolState & symbols

The current state for symbol parsing.

Lexer lex

The lexer for the source file we're parsing.

Token curToken

This is the next token that hasn't been consumed yet.

AsmParserState * asmState

An optional pointer to a struct containing high level parser state to be populated during parsing.

This class contains record of any parsed top-level symbols.

llvm::StringMap< Attribute > attributeAliasDefinitions

A map from attribute alias identifier to Attribute.