MLIR: lib/Tools/PDLL/CodeGen/MLIRGen.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

22 #include "llvm/ADT/ScopedHashTable.h"

23 #include "llvm/ADT/StringExtras.h"

24 #include "llvm/ADT/TypeSwitch.h"

25 #include

26

27 using namespace mlir;

29

30

31

32

33

34 namespace {

35 class CodeGen {

36 public:

38 const llvm::SourceMgr &sourceMgr)

39 : builder(mlirContext), odsContext(context.getODSContext()),

40 sourceMgr(sourceMgr) {

41

42 mlirContext->loadDialectpdl::PDLDialect();

43 }

44

46

47 private:

48

49 Location genLoc(llvm::SMLoc loc);

50 Location genLoc(llvm::SMRange loc) { return genLoc(loc.Start); }

51

52

54

55

57

58

59

60

61

68

69

70

71

72

76

77

78

80

81

82

83

85

86

87

89

90

91

92

93

104

107 bool isNegated = false);

110 template <typename PDLOpT, typename T>

113 bool isNegated = false);

114

115

116

117

118

119

121

122

123 using VariableMapTy =

124 llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector>;

125 VariableMapTy variables;

126

127

129

130

131 const llvm::SourceMgr &sourceMgr;

132 };

133 }

134

137 builder.create(genLoc(module.getLoc()));

138 builder.setInsertionPointToStart(mlirModule->getBody());

139

140

142 gen(decl);

143

144 return mlirModule;

145 }

146

147 Location CodeGen::genLoc(llvm::SMLoc loc) {

148 unsigned fileID = sourceMgr.FindBufferContainingLoc(loc);

149

150

151

152 auto &bufferInfo = sourceMgr.getBufferInfo(fileID);

153 unsigned lineNo = bufferInfo.getLineNumber(loc.getPointer());

154 unsigned column =

155 (loc.getPointer() - bufferInfo.getPointerForLineNumber(lineNo)) + 1;

156 auto *buffer = sourceMgr.getMemoryBuffer(fileID);

157

159 buffer->getBufferIdentifier(), lineNo, column);

160 }

161

165 return builder.getTypepdl::AttributeType();

166 })

168 return builder.getTypepdl::OperationType();

169 })

171 return builder.getTypepdl::TypeType();

172 })

174 return builder.getTypepdl::ValueType();

175 })

178 });

179 }

180

181 void CodeGen::gen(const ast::Node *node) {

187 [&](auto derivedNode) { this->genImpl(derivedNode); })

188 .Case([&](const ast::Expr *expr) { genExpr(expr); });

189 }

190

191

192

193

194

196 VariableMapTy::ScopeTy varScope(variables);

198 gen(childStmt);

199 }

200

201

202

203

208 builder.createpdl::RewriteOp(loc, rootExpr, StringAttr(),

211 }

212 }

213

216 Value rootExpr = genSingleExpr(stmt->getRootOpExpr());

217 Location loc = genLoc(stmt->getLoc());

218

219

222 builder.createpdl::EraseOp(loc, rootExpr);

223 }

224

226

229 Value rootExpr = genSingleExpr(stmt->getRootOpExpr());

230 Location loc = genLoc(stmt->getLoc());

231

232

235

238 replValues.push_back(genSingleExpr(replExpr));

239

240

241

242 bool usesReplOperation =

243 replValues.size() == 1 &&

244 isapdl::OperationType(replValues.front().getType());

245 builder.createpdl::ReplaceOp(

246 loc, rootExpr, usesReplOperation ? replValues[0] : Value(),

248 }

249

252 Value rootExpr = genSingleExpr(stmt->getRootOpExpr());

253

254

258 }

259

261

262

263 }

264

265

266

267

268

270

271

272

273 }

274

276

277

278

279 }

280

282 const ast::Name *name = decl->getName();

283

284

285

286 pdl::PatternOp pattern = builder.createpdl::PatternOp(

287 genLoc(decl->getLoc()), decl->getBenefit(),

288 name ? std::optional(name->getName())

289 : std::optional());

290

292 builder.setInsertionPointToStart(pattern.getBody());

294 }

295

297 auto it = variables.begin(varDecl);

298 if (it != variables.end())

299 return *it;

300

301

302

305 values = genExpr(initExpr);

306 else

307 values.push_back(genNonInitializerVar(varDecl, genLoc(varDecl->getLoc())));

308

309

310 applyVarConstraints(varDecl, values);

311

312 variables.insert(varDecl, values);

313 return values;

314 }

315

318

319 auto getTypeConstraint = [&]() -> Value {

321 Value typeValue =

325 [&, this](auto *cst) -> Value {

326 if (auto *typeConstraintExpr = cst->getTypeExpr())

327 return this->genSingleExpr(typeConstraintExpr);

329 })

330 .Default(Value());

331 if (typeValue)

332 return typeValue;

333 }

335 };

336

337

339 Type mlirType = genType(type);

340 if (isaast::ValueType(type))

341 return builder.createpdl::OperandOp(loc, mlirType, getTypeConstraint());

342 if (isaast::TypeType(type))

343 return builder.createpdl::TypeOp(loc, mlirType, TypeAttr());

344 if (isaast::AttributeType(type))

345 return builder.createpdl::AttributeOp(loc, getTypeConstraint());

346 if (ast::OperationType opType = dyn_castast::OperationType(type)) {

347 Value operands = builder.createpdl::OperandsOp(

349 Value());

350 Value results = builder.createpdl::TypesOp(

352 ArrayAttr());

353 return builder.createpdl::OperationOp(

354 loc, opType.getName(), operands, std::nullopt, ValueRange(), results);

355 }

356

357 if (ast::RangeType rangeTy = dyn_castast::RangeType(type)) {

358 ast::Type eleTy = rangeTy.getElementType();

359 if (isaast::ValueType(eleTy))

360 return builder.createpdl::OperandsOp(loc, mlirType,

361 getTypeConstraint());

362 if (isaast::TypeType(eleTy))

363 return builder.createpdl::TypesOp(loc, mlirType, ArrayAttr());

364 }

365

366 llvm_unreachable("invalid non-initialized variable type");

367 }

368

369 void CodeGen::applyVarConstraints(const ast::VariableDecl *varDecl,

371

372

374 if (const auto *userCst = dyn_castast::UserConstraintDecl(ref.constraint))

375 genConstraintCall(userCst, genLoc(ref.referenceLoc), values);

376 }

377

378

379

380

381

382 Value CodeGen::genSingleExpr(const ast::Expr *expr) {

387 [&](auto derivedNode) { return this->genExprImpl(derivedNode); })

388 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(

389 [&](auto derivedNode) {

390 return llvm::getSingleElement(this->genExprImpl(derivedNode));

391 });

392 }

393

396 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(

397 [&](auto derivedNode) { return this->genExprImpl(derivedNode); })

399 return {genSingleExpr(expr)};

400 });

401 }

402

405 assert(attr && "invalid MLIR attribute data");

406 return builder.createpdl::AttributeOp(genLoc(expr->getLoc()), attr);

407 }

408

410 Location loc = genLoc(expr->getLoc());

413 arguments.push_back(genSingleExpr(arg));

414

415

416 auto *callableExpr = dyn_castast::DeclRefExpr(expr->getCallableExpr());

417 assert(callableExpr && "unhandled CallExpr callable");

418

419

420 const ast::Decl *callable = callableExpr->getDecl();

421 if (const auto *decl = dyn_castast::UserConstraintDecl(callable))

422 return genConstraintCall(decl, loc, arguments, expr->getIsNegated());

423 if (const auto *decl = dyn_castast::UserRewriteDecl(callable))

424 return genRewriteCall(decl, loc, arguments);

425 llvm_unreachable("unhandled CallExpr callable");

426 }

427

429 if (const auto *varDecl = dyn_castast::VariableDecl(expr->getDecl()))

430 return genVar(varDecl);

431 llvm_unreachable("unknown decl reference expression");

432 }

433

435 Location loc = genLoc(expr->getLoc());

439

440

441 if (ast::OperationType opType = dyn_castast::OperationType(parentType)) {

442 if (isaast::AllResultsMemberAccessExpr(expr)) {

443 Type mlirType = genType(expr->getType());

444 if (isapdl::ValueType(mlirType))

445 return builder.createpdl::ResultOp(loc, mlirType, parentExprs[0],

446 builder.getI32IntegerAttr(0));

447 return builder.createpdl::ResultsOp(loc, mlirType, parentExprs[0]);

448 }

449

450 const ods::Operation *odsOp = opType.getODSOperation();

451 if (!odsOp) {

452 assert(llvm::isDigit(name[0]) &&

453 "unregistered op only allows numeric indexing");

454 unsigned resultIndex;

455 name.getAsInteger(10, resultIndex);

456 IntegerAttr index = builder.getI32IntegerAttr(resultIndex);

457 return builder.createpdl::ResultOp(loc, genType(expr->getType()),

458 parentExprs[0], index);

459 }

460

461

463 unsigned resultIndex = results.size();

464 if (llvm::isDigit(name[0])) {

465 name.getAsInteger(10, resultIndex);

466 } else {

468 return result.getName() == name;

469 };

470 resultIndex = llvm::find_if(results, findFn) - results.begin();

471 }

472 assert(resultIndex < results.size() && "invalid result index");

473

474

475 IntegerAttr index = builder.getI32IntegerAttr(resultIndex);

476 return builder.createpdl::ResultsOp(loc, genType(expr->getType()),

477 parentExprs[0], index);

478 }

479

480

481 if (auto tupleType = dyn_castast::TupleType(parentType)) {

482 auto elementNames = tupleType.getElementNames();

483

484

485 unsigned index = 0;

486 if (llvm::isDigit(name[0]))

487 name.getAsInteger(10, index);

488 else

489 index = llvm::find(elementNames, name) - elementNames.begin();

490

491 assert(index < parentExprs.size() && "invalid result index");

492 return parentExprs[index];

493 }

494

495 llvm_unreachable("unhandled member access expression");

496 }

497

499 Location loc = genLoc(expr->getLoc());

500 std::optional opName = expr->getName();

501

502

505 operands.push_back(genSingleExpr(operand));

506

507

511 attrNames.push_back(attr->getName().getName());

512 attrValues.push_back(genSingleExpr(attr->getValue()));

513 }

514

515

518 results.push_back(genSingleExpr(result));

519

520 return builder.createpdl::OperationOp(loc, opName, operands, attrNames,

521 attrValues, results);

522 }

523

527 llvm::append_range(elements, genExpr(element));

528

529 return builder.createpdl::RangeOp(genLoc(expr->getLoc()),

530 genType(expr->getType()), elements);

531 }

532

536 elements.push_back(genSingleExpr(element));

537 return elements;

538 }

539

542 assert(type && "invalid MLIR type data");

543 return builder.createpdl::TypeOp(genLoc(expr->getLoc()),

544 builder.getTypepdl::TypeType(),

546 }

547

551

552 for (auto it : llvm::zip(decl->getInputs(), inputs))

553 applyVarConstraints(std::get<0>(it), std::get<1>(it));

554

555

557 genConstraintOrRewriteCallpdl::ApplyNativeConstraintOp(

558 decl, loc, inputs, isNegated);

559

560

561 for (auto it : llvm::zip(decl->getResults(), results))

562 applyVarConstraints(std::get<0>(it), std::get<1>(it));

563 return results;

564 }

565

568 return genConstraintOrRewriteCallpdl::ApplyNativeRewriteOp(decl, loc,

569 inputs);

570 }

571

572 template <typename PDLOpT, typename T>

574 CodeGen::genConstraintOrRewriteCall(const T *decl, Location loc,

577

578

579 if (!cstBody) {

580 ast::Type declResultType = decl->getResultType();

582 if (ast::TupleType tupleType = dyn_castast::TupleType(declResultType)) {

583 for (ast::Type type : tupleType.getElementTypes())

584 resultTypes.push_back(genType(type));

585 } else {

586 resultTypes.push_back(genType(declResultType));

587 }

588 PDLOpT pdlOp = builder.create(loc, resultTypes,

589 decl->getName().getName(), inputs);

590 if (isNegated && std::is_same_v<PDLOpT, pdl::ApplyNativeConstraintOp>)

591 castpdl::ApplyNativeConstraintOp(pdlOp).setIsNegated(true);

592 return pdlOp->getResults();

593 }

594

595

596 VariableMapTy::ScopeTy varScope(variables);

597

598

599

600

601 for (auto it : llvm::zip(inputs, decl->getInputs()))

602 variables.insert(std::get<1>(it), {std::get<0>(it)});

603

604

605 gen(cstBody);

606

607

610 auto *returnStmt = dyn_castast::ReturnStmt(cstBody->getChildren().back());

611 if (!returnStmt)

613

614

615 return genExpr(returnStmt->getResultExpr());

616 }

617

618

619

620

621

624 const llvm::SourceMgr &sourceMgr, const ast::Module &module) {

625 CodeGen codegen(mlirContext, context, sourceMgr);

627 if (failed(verify(*mlirModule)))

628 return nullptr;

629 return mlirModule;

630 }

static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr, Location loc)

If the given builder is nested under a PDL PatternOp, build a rewrite operation and update the builder to nest under it.

static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)

Rewrite the given regions using the computing analysis.

Attributes are known-constant values of operations.

Operation * getParentOp()

Returns the closest surrounding operation that contains this block.

static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.

MLIRContext is the top-level object for a collection of MLIR operations.

void loadDialect()

Load a dialect in the context.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)

Add new block with 'argTypes' arguments and set the insertion point to the end of it.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

Block * getInsertionBlock() const

Return the block the current insertion point belongs to.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

The class represents an Attribute constraint, and constrains a variable to be an Attribute.

This expression represents a literal MLIR Attribute, and contains the textual assembly format of that attribute.

StringRef getValue() const

Get the raw value of this expression.

This class represents a PDLL type that corresponds to an mlir::Attribute.

This expression represents a call to a decl, such as a UserConstraintDecl/UserRewriteDecl.

Expr * getCallableExpr() const

Return the callable of this call.

MutableArrayRef< Expr * > getArguments()

Return the arguments of this call.

bool getIsNegated() const

Returns whether the result of this call is to be negated.

This statement represents a compound statement, which contains a collection of other statements.

MutableArrayRef< Stmt * > getChildren()

Return the children of this compound statement.

This class represents the main context of the PDLL AST.

This expression represents a reference to a Decl node.

Decl * getDecl() const

Get the decl referenced by this expression.

This class represents the base Decl node.

This statement represents the erase statement in PDLL.

This class represents a base AST Expression node.

Type getType() const

Return the type of this expression.

This statement represents a let statement in PDLL.

VariableDecl * getVarDecl() const

Return the variable defined by this statement.

This expression represents a named member or field access of a given parent expression.

const Expr * getParentExpr() const

Get the parent expression of this access.

StringRef getMemberName() const

Return the name of the member being accessed.

This class represents a top-level AST module.

MutableArrayRef< Decl * > getChildren()

Return the children of this module.

This Decl represents a NamedAttribute, and contains a string name and attribute value.

This class represents a base AST node.

This expression represents the structural form of an MLIR Operation.

MutableArrayRef< Expr * > getResultTypes()

Return the result types of this operation.

MutableArrayRef< NamedAttributeDecl * > getAttributes()

Return the attributes of this operation.

MutableArrayRef< Expr * > getOperands()

Return the operands of this operation.

std::optional< StringRef > getName() const

Return the name of the operation, or std::nullopt if there isn't one.

This class represents a PDLL type that corresponds to an mlir::Operation.

This Decl represents a single Pattern.

const CompoundStmt * getBody() const

Return the body of this pattern.

static PatternDecl * create(Context &ctx, SMRange location, const Name *name, std::optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)

std::optional< uint16_t > getBenefit() const

Return the benefit of this pattern if specified, or std::nullopt.

This expression builds a range from a set of element values (which may be ranges themselves).

MutableArrayRef< Expr * > getElements()

Return the element expressions of this range.

RangeType getType() const

Return the range result type of this expression.

This class represents a PDLL type that corresponds to a range of elements with a given element type.

Type getElementType() const

Return the element type of this range.

This statement represents the replace statement in PDLL.

MutableArrayRef< Expr * > getReplExprs()

Return the replacement values of this statement.

This statement represents a return from a "callable" like decl, e.g. a Constraint or a Rewrite.

This statement represents an operation rewrite that contains a block of nested rewrite commands.

CompoundStmt * getRewriteBody() const

Return the compound rewrite body.

This class represents a base AST Statement node.

This expression builds a tuple from a set of element values.

MutableArrayRef< Expr * > getElements()

Return the element expressions of this tuple.

This class represents a PDLL tuple type, i.e. an ordered set of element types with optional names.

This expression represents a literal MLIR Type, and contains the textual assembly format of that type.

StringRef getValue() const

Get the raw value of this expression.

This class represents a PDLL type that corresponds to an mlir::Type.

This decl represents a user defined constraint.

MutableArrayRef< VariableDecl * > getInputs()

Return the input arguments of this constraint.

MutableArrayRef< VariableDecl * > getResults()

Return the explicit results of the constraint declaration.

This decl represents a user defined rewrite.

The class represents a Value constraint, and constrains a variable to be a Value.

The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.

This class represents a PDLL type that corresponds to an mlir::Value.

This Decl represents the definition of a PDLL variable.

Expr * getInitExpr() const

Return the initializer expression of this statement, or nullptr if there was no initializer.

MutableArrayRef< ConstraintRef > getConstraints()

Return the constraints of this variable.

Type getType() const

Return the type of the decl.

This class contains all of the registered ODS operation classes.

This class provides an ODS representation of a specific operation operand or result.

This class provides an ODS representation of a specific operation.

ArrayRef< OperandOrResult > getResults() const

Returns the results of this operation.

OwningOpRef< ModuleOp > codegenPDLLToMLIR(MLIRContext *mlirContext, const ast::Context &context, const llvm::SourceMgr &sourceMgr, const ast::Module &module)

Given a PDLL module, generate an MLIR PDL pattern module within the given MLIR context.

Include the generated interface declarations.

Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)

This parses a single MLIR attribute to an MLIR context if it was valid.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed; this helps unify some of the attribute construction methods.

Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)

This parses a single MLIR type to an MLIR context if it was valid.

LogicalResult verify(Operation *op, bool verifyRecursively=true)

Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.

This class represents a reference to a constraint, and contains a constraint and the location of the reference.

This class provides a convenient API for interacting with source names.

StringRef getName() const

Return the raw string name.