MLIR: lib/Dialect/IRDL/IRDLLoading.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

22#include "llvm/ADT/STLExtras.h"

23

24using namespace mlir;

26

27

28

29

30static LogicalResult

33 ArrayRef<std::unique_ptr> constraints,

35 if (params.size() != paramConstraints.size()) {

36 emitError() << "expected " << paramConstraints.size()

37 << " type arguments, but had " << params.size();

38 return failure();

39 }

40

42

43

44 for (auto [i, param] : enumerate(params))

45 if (failed(verifier.verify(emitError, param, paramConstraints[i])))

46 return failure();

47

49}

50

51

53 StringRef attrName, unsigned numElements,

56

58 if (!segmentSizesAttr) {

59 return op->emitError() << "'" << attrName

60 << "' attribute is expected but not provided";

61 }

62

63 auto denseSegmentSizes = dyn_cast(segmentSizesAttr);

64 if (!denseSegmentSizes) {

65 return op->emitError() << "'" << attrName

66 << "' attribute is expected to be a dense i32 array";

67 }

68

69 if (denseSegmentSizes.size() != (int64_t)variadicities.size()) {

70 return op->emitError() << "'" << attrName << "' attribute for specifying "

71 << elemName << " segments must have "

72 << variadicities.size() << " elements, but got "

73 << denseSegmentSizes.size();

74 }

75

76

77 for (auto [i, segmentSize, variadicity] :

78 enumerate(denseSegmentSizes.asArrayRef(), variadicities)) {

79 if (segmentSize < 0)

81 << "'" << attrName << "' attribute for specifying " << elemName

82 << " segments must have non-negative values";

83 if (variadicity == Variadicity::single && segmentSize != 1)

84 return op->emitError() << "element " << i << " in '" << attrName

85 << "' attribute must be equal to 1";

86

87 if (variadicity == Variadicity::optional && segmentSize > 1)

88 return op->emitError() << "element " << i << " in '" << attrName

89 << "' attribute must be equal to 0 or 1";

90

91 segmentSizes.push_back(segmentSize);

92 }

93

94

95 int32_t sum = 0;

96 for (int32_t segmentSize : denseSegmentSizes.asArrayRef())

97 sum += segmentSize;

98 if (sum != static_cast<int32_t>(numElements))

99 return op->emitError() << "sum of elements in '" << attrName

100 << "' attribute must be equal to the number of "

101 << elemName << "s";

102

104}

105

106

107

108

109

110

112 StringRef attrName, unsigned numElements,

115

116

117 int numberNonSingle = count_if(

118 variadicities, [](Variadicity v) { return v != Variadicity::single; });

119 if (numberNonSingle > 1)

121 variadicities, segmentSizes);

122

123

124 if (numberNonSingle == 0) {

125 if (numElements != variadicities.size()) {

126 return op->emitError() << "op expects exactly " << variadicities.size()

127 << " " << elemName << "s, but got " << numElements;

128 }

129 for (size_t i = 0, e = variadicities.size(); i < e; ++i)

130 segmentSizes.push_back(1);

132 }

133

134 assert(numberNonSingle == 1);

135

136

137

138 int nonSingleSegmentSize = static_cast<int>(numElements) -

139 static_cast<int>(variadicities.size()) + 1;

140

141 if (nonSingleSegmentSize < 0) {

142 return op->emitError() << "op expects at least " << variadicities.size() - 1

143 << " " << elemName << "s, but got " << numElements;

144 }

145

146

147 for (Variadicity variadicity : variadicities) {

148 if (variadicity == Variadicity::single) {

149 segmentSizes.push_back(1);

150 continue;

151 }

152

153

154

155 if (nonSingleSegmentSize > 1 && variadicity == Variadicity::optional)

156 return op->emitError() << "op expects at most " << variadicities.size()

157 << " " << elemName << "s, but got " << numElements;

158

159 segmentSizes.push_back(nonSingleSegmentSize);

160 }

161

163}

164

165

166

167

168

175

176

177

178

179

186

187

188

189

195

196

198 if (failed(

200 return failure();

201

202

203

206 return failure();

207

209

210

211

213

214 for (auto [name, constraint] : attributeConstrs) {

215

216 std::optional actual = actualAttrs.getNamed(name);

217 if (!actual.has_value())

219 << "attribute " << name << " is expected but not provided";

220

221

222 if (failed(verifier.verify({emitError}, actual->getValue(), constraint)))

223 return failure();

224 }

225

226

227 int operandIdx = 0;

228 for (auto [defIndex, segmentSize] : enumerate(operandSegmentSizes)) {

229 for (int i = 0; i < segmentSize; i++) {

230 if (failed(verifier.verify(

231 {emitError}, TypeAttr::get(op->getOperandTypes()[operandIdx]),

232 operandConstrs[defIndex])))

233 return failure();

234 ++operandIdx;

235 }

236 }

237

238

239 int resultIdx = 0;

240 for (auto [defIndex, segmentSize] : enumerate(resultSegmentSizes)) {

241 for (int i = 0; i < segmentSize; i++) {

242 if (failed(verifier.verify({emitError},

244 resultConstrs[defIndex])))

245 return failure();

246 ++resultIdx;

247 }

248 }

249

251}

252

255 ArrayRef<std::unique_ptr> regionsConstraints) {

256 if (op->getNumRegions() != regionsConstraints.size()) {

258 << "unexpected number of regions: expected "

259 << regionsConstraints.size() << " but got " << op->getNumRegions();

260 }

261

262 for (auto [constraint, region] :

263 llvm::zip(regionsConstraints, op->getRegions()))

264 if (failed(constraint->verify(region, verifier)))

265 return failure();

266

268}

269

// createVerifier: generate an op verifier function from the given IRDL
// operation definition. Returns nullptr on failure: when a constraint op does
// not have exactly one result (an error is emitted on that op) or when a
// constraint's getVerifier returns null.
270llvm::unique_function<LogicalResult(Operation *) const>

272 OperationOp op,

273 const DenseMap<irdl::TypeOp, std::unique_ptr> &types,

274 const DenseMap<irdl::AttributeOp, std::unique_ptr>

275 &attrs) {

276

// Record the single result of every constraint op in constrToValue and of
// every region constraint op in regionToValue. A value's position in
// constrToValue is the index later used to refer to its constraint.
280 if (isa(op)) {

281 if (op.getNumResults() != 1) {

282 op.emitError()

283 << "IRDL constraint operations must have exactly one result";

284 return nullptr;

285 }

286 constrToValue.push_back(op.getResult(0));

287 }

288 if (isa(op)) {

289 if (op.getNumResults() != 1) {

290 op.emitError()

291 << "IRDL constraint operations must have exactly one result";

292 return nullptr;

293 }

294 regionToValue.push_back(op.getResult(0));

295 }

296 }

297

298

// Build one constraint verifier per collected value, in constrToValue order,
// so constraints[i] corresponds to constrToValue[i].
300 for (Value v : constrToValue) {

301 VerifyConstraintInterface op =

302 cast(v.getDefiningOp());

303 std::unique_ptr verifier =

304 op.getVerifier(constrToValue, types, attrs);

305 if (!verifier)

306 return nullptr;

307 constraints.push_back(std::move(verifier));

308 }

309

310

// Build the region constraint verifiers.
// NOTE(review): unlike the loop above, a null verifier returned by
// getVerifier is not rejected here and is stored as-is, while
// irdlRegionVerifier dereferences every entry (`constraint->verify(...)`).
// If getVerifier can return null for region constraints, confirm and mirror
// the `if (!verifier) return nullptr;` check.
312 for (Value v : regionToValue) {

313 VerifyRegionInterface op = cast(v.getDefiningOp());

314 std::unique_ptr verifier =

315 op.getVerifier(constrToValue, types, attrs);

316 regionConstraints.push_back(std::move(verifier));

317 }

318

321

322

// Map each operand definition value to its index in constrToValue and record
// the declared variadicity of each operand.
// NOTE(review): if an operand value is not found in constrToValue, nothing is
// pushed, so operandConstraints ends up shorter than the operand list and
// later entries shift (irdlOpVerifier indexes it by segment position). Confirm
// every argument is guaranteed to be a collected constraint value; the same
// applies to the result and attribute lookups below. Each lookup is a linear
// scan of constrToValue.
323 auto operandsOp = op.getOp();

324 if (operandsOp.has_value()) {

325 operandConstraints.reserve(operandsOp->getArgs().size());

326 for (Value operand : operandsOp->getArgs()) {

327 for (auto [i, constr] : enumerate(constrToValue)) {

328 if (constr == operand) {

329 operandConstraints.push_back(i);

330 break;

331 }

332 }

333 }

334

335

// Operand variadicities are iterated directly as VariadicityAttr, whereas the
// result variadicities below are iterated as Attribute and cast.
336 for (VariadicityAttr attr : operandsOp->getVariadicity())

337 operandVariadicity.push_back(attr.getValue());

338 }

339

342

343

// Same mapping as for operands, applied to the result definitions.
344 auto resultsOp = op.getOp();

345 if (resultsOp.has_value()) {

346 resultConstraints.reserve(resultsOp->getArgs().size());

347 for (Value result : resultsOp->getArgs()) {

348 for (auto [i, constr] : enumerate(constrToValue)) {

349 if (constr == result) {

350 resultConstraints.push_back(i);

351 break;

352 }

353 }

354 }

355

356

357 for (Attribute attr : resultsOp->getVariadicity())

358 resultVariadicity.push_back(cast(attr).getValue());

359 }

360

361

// Map each attribute name to the index of its constraint value.
363 auto attributesOp = op.getOp();

364 if (attributesOp.has_value()) {

366 const ArrayAttr names = attributesOp->getAttributeValueNames();

367

368 for (const auto &[name, value] : llvm::zip(names, values)) {

369 for (auto [i, constr] : enumerate(constrToValue)) {

370 if (constr == value) {

371 attributeConstraints[cast(name)] = i;

372 break;

373 }

374 }

375 }

376 }

377

// The returned verifier takes ownership of all collected data (moved into the
// captures). It computes the op verifier result and the region verifier
// result before combining them, so both run even if the first fails, and it
// succeeds only if both succeed.
378 return

379 [constraints{std::move(constraints)},

380 regionConstraints{std::move(regionConstraints)},

381 operandConstraints{std::move(operandConstraints)},

382 operandVariadicity{std::move(operandVariadicity)},

383 resultConstraints{std::move(resultConstraints)},

384 resultVariadicity{std::move(resultVariadicity)},

385 attributeConstraints{std::move(attributeConstraints)}](Operation *op) {

387 const LogicalResult opVerifierResult = irdlOpVerifier(

388 op, verifier, operandConstraints, operandVariadicity,

389 resultConstraints, resultVariadicity, attributeConstraints);

390 const LogicalResult opRegionVerifierResult =

392 return LogicalResult::success(opVerifierResult.succeeded() &&

393 opRegionVerifierResult.succeeded());

394 };

395}

396

397

398

401 const DenseMap<TypeOp, std::unique_ptr> &types,

402 const DenseMap<AttributeOp, std::unique_ptr>

403 &attrs) {

404

405

407 return failure();

408 };

410 printer.printGenericOp(op);

411 };

412

414 if (!verifier)

416

417

418

419 auto regionVerifier = [](Operation *op) { return LogicalResult::success(); };

420

422 op.getName(), dialect, std::move(verifier), std::move(regionVerifier),

423 std::move(parser), std::move(printer));

425

427}

428

429

430

433 DenseMap<TypeOp, std::unique_ptr> &types,

434 DenseMap<AttributeOp, std::unique_ptr> &attrs) {

435 assert((isa(attrOrTypeDef) || isa(attrOrTypeDef)) &&

436 "Expected an attribute or type definition");

437

438

441 if (isa(op)) {

442 assert(op.getNumResults() == 1 &&

443 "IRDL constraint operations must have exactly one result");

444 constrToValue.push_back(op.getResult(0));

445 }

446 }

447

448

450 for (Value v : constrToValue) {

451 VerifyConstraintInterface op =

452 cast(v.getDefiningOp());

453 std::unique_ptr verifier =

454 op.getVerifier(constrToValue, types, attrs);

455 if (!verifier)

456 return {};

457 constraints.push_back(std::move(verifier));

458 }

459

460

461 std::optional params;

462 if (auto attr = dyn_cast(attrOrTypeDef))

463 params = attr.getOp();

464 else if (auto type = dyn_cast(attrOrTypeDef))

465 params = type.getOp();

466

467

469 if (params.has_value()) {

470 paramConstraints.reserve(params->getArgs().size());

471 for (Value param : params->getArgs()) {

472 for (auto [i, constr] : enumerate(constrToValue)) {

473 if (constr == param) {

474 paramConstraints.push_back(i);

475 break;

476 }

477 }

478 }

479 }

480

481 auto verifier = [paramConstraints{std::move(paramConstraints)},

482 constraints{std::move(constraints)}](

486 paramConstraints);

487 };

488

489

490

491 return std::move(verifier);

492}

493

494

495

496

497

498

499

500

501

502

// getBases: collect the possible bases of a constraint into the output sets
// (referenced parametric definition ops into `paramIrdlOps`, the TypeIDs of
// `is` constraints into `isIds`). Returns true when the constraint may match
// any base; checkCorrectAnyOf treats a true result as an any_of it cannot
// handle and returns failure.
// NOTE(review): in the code below `paramIds` is only forwarded to recursive
// calls and is never inserted into -- confirm whether that is intended.
506

// Presumably `irdl.any_of`: gather the bases of every argument.
507 if (auto anyOf = dyn_cast(op)) {

508 bool hasAny = false;

// NOTE(review): `hasAny` starts as false and is combined with `&=`, so it can
// never become true. This branch therefore always returns false, even when
// an argument reports that it may match any base. The intent is most likely
// `hasAny |= ...` -- TODO confirm and fix.
510 hasAny &= getBases(arg.getDefiningOp(), paramIds, paramIrdlOps, isIds);

511 return hasAny;

512 }

513

514

515

// Presumably `irdl.all_of`: only the first argument's bases are considered.
516 if (auto allOf = dyn_cast(op))

517 return getBases(allOf.getArgs()[0].getDefiningOp(), paramIds, paramIrdlOps,

518 isIds);

519

520

// Presumably `irdl.parametric`: record the definition op that the base symbol
// refers to, and return false.
521 if (auto params = dyn_cast(op)) {

522 SymbolRefAttr symRef = params.getBaseType();

524 assert(defOp && "symbol reference should refer to an existing operation");

525 paramIrdlOps.insert(defOp);

526 return false;

527 }

528

529

// Presumably `irdl.is`: record the TypeID of the expected attribute.
530 if (auto is = dyn_cast(op)) {

531 Attribute expected = is.getExpected();

532 isIds.insert(expected.getTypeID());

533 return false;

534 }

535

536

537

// Presumably `irdl.any`: may match any base.
538 if (auto isA = dyn_cast(op))

539 return true;

540

// Any other constraint kind is not expected here.
541 llvm_unreachable("unknown IRDL constraint");

542}

543

544

545

546

547

548

549

550

551

552

553

554

// checkCorrectAnyOf (per-argument loop): check that an any_of is in the
// subset IRDL can handle. For each argument `arg`, its bases are collected
// by getBases into the `arg*` sets and then checked against, and merged
// into, the sets accumulated from earlier arguments (`paramIds`,
// `paramIrdlOps`, `isIds`).
559

561 Operation *argOp = arg.getDefiningOp();

565

566

567

// An argument that may match any base cannot be handled.
568 if (getBases(argOp, argParamIds, argParamIrdlOps, argIsIds))

569 return failure();

570

571

572

573

// Parametric TypeIDs must not collide with already-seen `is` TypeIDs; the
// following failure presumably fires when `id` was already recorded.
574 for (TypeID id : argParamIds) {

575 if (isIds.count(id))

576 return failure();

577 bool inserted = paramIds.insert(id).second;

579 return failure();

580 }

581

582

583

// NOTE(review): this loop iterates the accumulated `isIds` set and inserts
// back into that same set. As a result, the current argument's `argIsIds`
// (filled by getBases above) are never checked against `paramIds` and never
// merged in. The loop most likely should be `for (TypeID id : argIsIds)` --
// TODO confirm and fix. (Inserting into the set being range-iterated is also
// fragile if an insert ever grows it.)
584 for (TypeID id : isIds) {

585 if (paramIds.count(id))

586 return failure();

587 isIds.insert(id);

588 }

589

590

591

592

593

// Referenced parametric definition ops are merged into `paramIrdlOps`; the
// failure presumably fires when `op` was already seen in an earlier argument.
594 for (Operation *op : argParamIrdlOps) {

595 bool inserted = paramIrdlOps.insert(op).second;

597 return failure();

598 }

599 }

600

602}

603

604

605

608 op.walk([&](DialectOp dialectOp) {

609 MLIRContext *ctx = dialectOp.getContext();

610 StringRef dialectName = dialectOp.getName();

611

614

615 dialects.insert({dialectOp, dialect});

616 });

617 return dialects;

618}

619

620

621

626 op.walk([&](TypeOp typeOp) {

629 typeOp.getName(), dialect,

632 });

633 typeDefs.try_emplace(typeOp, std::move(typeDef));

634 });

635 return typeDefs;

636}

637

638

639

644 op.walk([&](AttributeOp attrOp) {

647 attrOp.getName(), dialect,

650 });

651 attrDefs.try_emplace(attrOp, std::move(attrDef));

652 });

653 return attrDefs;

654}

655

657

658

662 return op.emitError("any_of constraints are not in the correct form");

663

664

665

666

672

673

674 WalkResult res = op.walk([&](TypeOp typeOp) {

676 typeOp, dialects[typeOp.getParentOp()], types, attrs);

677 if (!verifier)

679 types[typeOp]->setVerifyFn(std::move(verifier));

681 });

683 return failure();

684

685

686 res = op.walk([&](AttributeOp attrOp) {

688 attrOp, dialects[attrOp.getParentOp()], types, attrs);

689 if (!verifier)

691 attrs[attrOp]->setVerifyFn(std::move(verifier));

693 });

695 return failure();

696

697

698 res = op.walk([&](OperationOp opOp) {

699 return loadOperation(opOp, dialects[opOp.getParentOp()], types, attrs);

700 });

702 return failure();

703

704

705 for (auto &pair : types) {

708 }

709

710

711 for (auto &pair : attrs) {

714 }

715

717}

static bool getBases(Operation *op, SmallPtrSet< TypeID, 4 > &paramIds, SmallPtrSet< Operation *, 4 > &paramIrdlOps, SmallPtrSet< TypeID, 4 > &isIds)

Get the possible bases of a constraint.

Definition IRDLLoading.cpp:503

static LogicalResult checkCorrectAnyOf(AnyOfOp anyOf)

Check that an any_of is in the subset IRDL can handle.

Definition IRDLLoading.cpp:555

static DenseMap< TypeOp, std::unique_ptr< DynamicTypeDefinition > > preallocateTypeDefs(ModuleOp op, DenseMap< DialectOp, ExtensibleDialect * > dialects)

Preallocate type definitions objects with empty verifiers.

Definition IRDLLoading.cpp:623

LogicalResult getSegmentSizesFromAttr(Operation *op, StringRef elemName, StringRef attrName, unsigned numElements, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)

Get the segment sizes of the given element (operands or results) from the attribute dictionary.

Definition IRDLLoading.cpp:52

static LogicalResult irdlOpVerifier(Operation *op, ConstraintVerifier &verifier, ArrayRef< size_t > operandConstrs, ArrayRef< Variadicity > operandVariadicity, ArrayRef< size_t > resultConstrs, ArrayRef< Variadicity > resultVariadicity, const DenseMap< StringAttr, size_t > &attributeConstrs)

Verify that the given operation satisfies the given constraints.

Definition IRDLLoading.cpp:190

static LogicalResult irdlAttrOrTypeVerifier(function_ref< InFlightDiagnostic()> emitError, ArrayRef< Attribute > params, ArrayRef< std::unique_ptr< Constraint > > constraints, ArrayRef< size_t > paramConstraints)

Verify that the given list of parameters satisfy the given constraints.

Definition IRDLLoading.cpp:31

static DynamicAttrDefinition::VerifierFn getAttrOrTypeVerifier(Operation *attrOrTypeDef, ExtensibleDialect *dialect, DenseMap< TypeOp, std::unique_ptr< DynamicTypeDefinition > > &types, DenseMap< AttributeOp, std::unique_ptr< DynamicAttrDefinition > > &attrs)

Get the verifier of a type or attribute definition.

Definition IRDLLoading.cpp:431

static DenseMap< AttributeOp, std::unique_ptr< DynamicAttrDefinition > > preallocateAttrDefs(ModuleOp op, DenseMap< DialectOp, ExtensibleDialect * > dialects)

Preallocate attribute definitions objects with empty verifiers.

Definition IRDLLoading.cpp:641

LogicalResult getOperandSegmentSizes(Operation *op, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)

Compute the segment sizes of the given operands.

Definition IRDLLoading.cpp:169

static LogicalResult irdlRegionVerifier(Operation *op, ConstraintVerifier &verifier, ArrayRef< std::unique_ptr< RegionConstraint > > regionsConstraints)

Definition IRDLLoading.cpp:253

static DenseMap< DialectOp, ExtensibleDialect * > loadEmptyDialects(ModuleOp op)

Load all dialects in the given module, without loading any operation, type or attribute definitions.

Definition IRDLLoading.cpp:606

LogicalResult getSegmentSizes(Operation *op, StringRef elemName, StringRef attrName, unsigned numElements, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)

Compute the segment sizes of the given element (operands, results).

Definition IRDLLoading.cpp:111

static WalkResult loadOperation(OperationOp op, ExtensibleDialect *dialect, const DenseMap< TypeOp, std::unique_ptr< DynamicTypeDefinition > > &types, const DenseMap< AttributeOp, std::unique_ptr< DynamicAttrDefinition > > &attrs)

Define and load an operation represented by an `irdl.operation` operation.

Definition IRDLLoading.cpp:399

LogicalResult getResultSegmentSizes(Operation *op, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)

Compute the segment sizes of the given results.

Definition IRDLLoading.cpp:180

...if copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method...

Attributes are known-constant values of operations.

TypeID getTypeID()

Return a unique identifier for the concrete attribute type.

static std::unique_ptr< DynamicAttrDefinition > get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier)

Create a new attribute definition at runtime.

llvm::unique_function< LogicalResult( function_ref< InFlightDiagnostic()>, ArrayRef< Attribute >) const > VerifierFn

A dialect that can be defined at runtime.

static std::unique_ptr< DynamicOpDefinition > get(StringRef name, ExtensibleDialect *dialect, OperationName::VerifyInvariantsFn &&verifyFn, OperationName::VerifyRegionInvariantsFn &&verifyRegionFn)

Create a new op at runtime.

static std::unique_ptr< DynamicTypeDefinition > get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier)

Create a new dynamic type definition.

A dialect that can be extended with new operations/types/attributes at runtime.

void registerDynamicOp(std::unique_ptr< DynamicOpDefinition > &&type)

Add a new operation defined at runtime to the dialect.

void registerDynamicType(std::unique_ptr< DynamicTypeDefinition > &&type)

Add a new type defined at runtime to the dialect.

void registerDynamicAttr(std::unique_ptr< DynamicAttrDefinition > &&attr)

Add a new attribute defined at runtime to the dialect.

This class represents a diagnostic that is inflight and set to be reported.

MLIRContext is the top-level object for a collection of MLIR operations.

DynamicDialect * getOrLoadDynamicDialect(StringRef dialectNamespace, function_ref< void(DynamicDialect *)> ctor)

Get (or create) a dynamic dialect for the given name.

The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...

This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...

Operation is the basic unit of execution within MLIR.

DictionaryAttr getAttrDictionary()

Return all of the attributes on this operation as a DictionaryAttr.

Region & getRegion(unsigned index)

Returns the region held by this operation at position 'index'.

Attribute getAttr(StringAttr name)

Return the specified attribute if present, null otherwise.

unsigned getNumRegions()

Returns the number of regions held by this operation.

unsigned getNumOperands()

OperandRange operand_range

InFlightDiagnostic emitError(const Twine &message={})

Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...

operand_type_range getOperandTypes()

MutableArrayRef< Region > getRegions()

Returns the regions held by this operation.

result_type_range getResultTypes()

InFlightDiagnostic emitOpError(const Twine &message={})

Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.

unsigned getNumResults()

Return the number of results held by this operation.

iterator_range< OpIterator > getOps()

This class provides an efficient unique identifier for a specific C++ type.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

A utility result that is used to signal how to proceed with an ongoing walk:

static WalkResult advance()

bool wasInterrupted() const

Returns true if the walk was interrupted.

static WalkResult interrupt()

Provides context to the verification of constraints.

LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Attribute attr, unsigned variable)

Check that a constraint is satisfied by an attribute.

llvm::LogicalResult loadDialects(ModuleOp op)

Load all the dialects defined in the module.

Definition IRDLLoading.cpp:656

Operation * lookupSymbolNearDialect(SymbolTableCollection &symbolTable, Operation *source, SymbolRefAttr symbol)

Looks up a symbol from the symbol table containing the source operation's dialect definition operatio...

llvm::unique_function< LogicalResult(Operation *) const > createVerifier(OperationOp operation, const DenseMap< irdl::TypeOp, std::unique_ptr< DynamicTypeDefinition > > &typeDefs, const DenseMap< irdl::AttributeOp, std::unique_ptr< DynamicAttrDefinition > > &attrDefs)

Generate an op verifier function from the given IRDL operation definition.

Definition IRDLLoading.cpp:271

Include the generated interface declarations.

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.

llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap

llvm::function_ref< Fn > function_ref

This represents an operation in an abstracted form, suitable for use with the builder APIs.