MLIR: lib/Dialect/IRDL/IRDLLoading.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
22#include "llvm/ADT/STLExtras.h"
23
24using namespace mlir;
26
27
28
29
// irdlAttrOrTypeVerifier: check that the parameters of a dynamic attribute or
// type instance satisfy the IRDL parameter constraints of its definition.
// - Emits "expected N type arguments, but had M" and fails when the number of
//   parameters differs from the number of parameter constraints.
// - Otherwise checks parameter i against the constraint variable
//   paramConstraints[i], stopping at the first failing parameter.
// NOTE(review): the embedded line numbers skip values, so part of the signature
// and the construction of `verifier` are missing from this listing; `verifier`
// is presumably a ConstraintVerifier built over `constraints` — confirm.
30static LogicalResult
33 ArrayRef<std::unique_ptr> constraints,
35 if (params.size() != paramConstraints.size()) {
36 emitError() << "expected " << paramConstraints.size()
37 << " type arguments, but had " << params.size();
38 return failure();
39 }
40
42
43
// Check each parameter against the constraint variable assigned to it.
44 for (auto [i, param] : enumerate(params))
45 if (failed(verifier.verify(emitError, param, paramConstraints[i])))
46 return failure();
47
49}
50
51
// getSegmentSizesFromAttr (signature partially missing from this listing):
// read the segment sizes of the `elemName`s of `op` from the attribute named
// `attrName`, validate them against `variadicities`, and append them to
// `segmentSizes`.
53 StringRef attrName, unsigned numElements,
56
// The segment sizes attribute is mandatory.
58 if (!segmentSizesAttr) {
59 return op->emitError() << "'" << attrName
60 << "' attribute is expected but not provided";
61 }
62
// It must be a dense i32 array (the dyn_cast target type was lost in this
// listing)...
63 auto denseSegmentSizes = dyn_cast(segmentSizesAttr);
64 if (!denseSegmentSizes) {
65 return op->emitError() << "'" << attrName
66 << "' attribute is expected to be a dense i32 array";
67 }
68
// ...with exactly one entry per element definition.
69 if (denseSegmentSizes.size() != (int64_t)variadicities.size()) {
70 return op->emitError() << "'" << attrName << "' attribute for specifying "
71 << elemName << " segments must have "
72 << variadicities.size() << " elements, but got "
73 << denseSegmentSizes.size();
74 }
75
76
// Validate each entry against the variadicity of its definition: it must be
// non-negative, exactly 1 for `single`, and 0 or 1 for `optional`.
// NOTE(review): the `return op->emitError()` line of the negative-size check
// (original line 80) is missing from this listing.
77 for (auto [i, segmentSize, variadicity] :
78 enumerate(denseSegmentSizes.asArrayRef(), variadicities)) {
79 if (segmentSize < 0)
81 << "'" << attrName << "' attribute for specifying " << elemName
82 << " segments must have non-negative values";
83 if (variadicity == Variadicity::single && segmentSize != 1)
84 return op->emitError() << "element " << i << " in '" << attrName
85 << "' attribute must be equal to 1";
86
87 if (variadicity == Variadicity::optional && segmentSize > 1)
88 return op->emitError() << "element " << i << " in '" << attrName
89 << "' attribute must be equal to 0 or 1";
90
91 segmentSizes.push_back(segmentSize);
92 }
93
94
95 int32_t sum = 0;
96 for (int32_t segmentSize : denseSegmentSizes.asArrayRef())
97 sum += segmentSize;
98 if (sum != static_cast<int32_t>(numElements))
99 return op->emitError() << "sum of elements in '" << attrName
100 << "' attribute must be equal to the number of "
101 << elemName << "s";
102
104}
105
106
107
108
109
110
// getSegmentSizes (signature partially missing from this listing): compute the
// segment sizes of the `elemName`s (operands or results) of `op` from the
// variadicities of their definitions, appending them to `segmentSizes`.
112 StringRef attrName, unsigned numElements,
115
116
// Count the definitions that are not `single`.
117 int numberNonSingle = count_if(
118 variadicities, [](Variadicity v) { return v != Variadicity::single; });
// With more than one non-single definition the sizes cannot be inferred from
// the element count, so they are read from the segment sizes attribute
// (presumably via getSegmentSizesFromAttr; the start of that call, original
// line 120, is missing from this listing).
119 if (numberNonSingle > 1)
121 variadicities, segmentSizes);
122
123
// Only `single` definitions: the element count must match exactly and every
// segment has size 1. (The `return success();` of this branch, original line
// 131, is missing from this listing.)
124 if (numberNonSingle == 0) {
125 if (numElements != variadicities.size()) {
126 return op->emitError() << "op expects exactly " << variadicities.size()
127 << " " << elemName << "s, but got " << numElements;
128 }
129 for (size_t i = 0, e = variadicities.size(); i < e; ++i)
130 segmentSizes.push_back(1);
132 }
133
134 assert(numberNonSingle == 1);
135
136
137
// Exactly one non-single definition: it receives every element not consumed
// by the `single` definitions (each of which takes exactly one).
138 int nonSingleSegmentSize = static_cast<int>(numElements) -
139 static_cast<int>(variadicities.size()) + 1;
140
// A negative size means there are fewer elements than `single` definitions.
141 if (nonSingleSegmentSize < 0) {
142 return op->emitError() << "op expects at least " << variadicities.size() - 1
143 << " " << elemName << "s, but got " << numElements;
144 }
145
146
// Emit sizes in definition order: 1 for each `single` definition, the computed
// size for the non-single one. An `optional` definition may hold at most one.
147 for (Variadicity variadicity : variadicities) {
148 if (variadicity == Variadicity::single) {
149 segmentSizes.push_back(1);
150 continue;
151 }
152
153
154
155 if (nonSingleSegmentSize > 1 && variadicity == Variadicity::optional)
156 return op->emitError() << "op expects at most " << variadicities.size()
157 << " " << elemName << "s, but got " << numElements;
158
159 segmentSizes.push_back(nonSingleSegmentSize);
160 }
161
163}
164
165
166
167
168
175
176
177
178
179
186
187
188
189
195
196
// irdlOpVerifier (signature and setup missing from this listing): verify that
// `op` satisfies the attribute, operand and result constraints of its IRDL
// definition. The two partially visible `failed(...)` checks below presumably
// compute the operand and result segment sizes -- confirm against upstream.
198 if (failed(
200 return failure();
201
202
203
206 return failure();
207
209
210
211
213
// Every constrained attribute must be present on `op` and satisfy its
// constraint variable.
214 for (auto [name, constraint] : attributeConstrs) {
215
216 std::optional actual = actualAttrs.getNamed(name);
217 if (!actual.has_value())
219 << "attribute " << name << " is expected but not provided";
220
221
222 if (failed(verifier.verify({emitError}, actual->getValue(), constraint)))
223 return failure();
224 }
225
226
// Check each operand type (wrapped in a TypeAttr) against the constraint of the
// operand definition whose segment it belongs to.
227 int operandIdx = 0;
228 for (auto [defIndex, segmentSize] : enumerate(operandSegmentSizes)) {
229 for (int i = 0; i < segmentSize; i++) {
230 if (failed(verifier.verify(
231 {emitError}, TypeAttr::get(op->getOperandTypes()[operandIdx]),
232 operandConstrs[defIndex])))
233 return failure();
234 ++operandIdx;
235 }
236 }
237
238
// Same for results. The checked value (original line 243) is missing from this
// listing -- presumably the result type at `resultIdx` wrapped in a TypeAttr.
239 int resultIdx = 0;
240 for (auto [defIndex, segmentSize] : enumerate(resultSegmentSizes)) {
241 for (int i = 0; i < segmentSize; i++) {
242 if (failed(verifier.verify({emitError},
244 resultConstrs[defIndex])))
245 return failure();
246 ++resultIdx;
247 }
248 }
249
251}
252
// irdlRegionVerifier (signature partially missing from this listing): check
// that `op` has exactly one region per region constraint, and that each region
// satisfies the constraint at the same position.
255 ArrayRef<std::unique_ptr> regionsConstraints) {
256 if (op->getNumRegions() != regionsConstraints.size()) {
258 << "unexpected number of regions: expected "
259 << regionsConstraints.size() << " but got " << op->getNumRegions();
260 }
261
// Pairwise check. Every stored constraint is dereferenced, so all entries of
// `regionsConstraints` must be non-null.
262 for (auto [constraint, region] :
263 llvm::zip(regionsConstraints, op->getRegions()))
264 if (failed(constraint->verify(region, verifier)))
265 return failure();
266
268}
269
// createVerifier: build the verifier of an operation defined by the
// `irdl.operation` op `op`. Returns nullptr on failure: when a constraint op
// does not have exactly one result (an error is emitted), or when a constraint
// verifier cannot be created.
270llvm::unique_function<LogicalResult(Operation *) const>
272 OperationOp op,
273 const DenseMap<irdl::TypeOp, std::unique_ptr> &types,
274 const DenseMap<irdl::AttributeOp, std::unique_ptr>
275 &attrs) {
276
// Collect the SSA values of constraint ops and of region-constraint ops; both
// must produce exactly one result. NOTE(review): the loop header over the
// definition body (original lines 277-279) and the isa<> template arguments
// are missing from this listing; `op` below presumably names the nested op.
280 if (isa(op)) {
281 if (op.getNumResults() != 1) {
282 op.emitError()
283 << "IRDL constraint operations must have exactly one result";
284 return nullptr;
285 }
286 constrToValue.push_back(op.getResult(0));
287 }
288 if (isa(op)) {
289 if (op.getNumResults() != 1) {
290 op.emitError()
291 << "IRDL constraint operations must have exactly one result";
292 return nullptr;
293 }
294 regionToValue.push_back(op.getResult(0));
295 }
296 }
297
298
// Instantiate one verifier per constraint value. The index of a value in
// `constrToValue` is the constraint variable the verifiers refer to.
300 for (Value v : constrToValue) {
301 VerifyConstraintInterface op =
302 cast(v.getDefiningOp());
303 std::unique_ptr verifier =
304 op.getVerifier(constrToValue, types, attrs);
305 if (!verifier)
306 return nullptr;
307 constraints.push_back(std::move(verifier));
308 }
309
310
312 for (Value v : regionToValue) {
313 VerifyRegionInterface op = cast(v.getDefiningOp());
314 std::unique_ptr verifier =
315 op.getVerifier(constrToValue, types, attrs);
316 regionConstraints.push_back(std::move(verifier));
317 }
318
321
322
// Map each operand definition to the index of its constraint value and record
// its variadicity. (`getOp` lost its template argument in this listing; it
// presumably fetches the optional operands op of the definition.)
// NOTE(review): an operand value not found in `constrToValue` is silently
// skipped, which would leave `operandConstraints` shorter than the operand
// list. The nested search is O(#operands x #constraints); a DenseMap<Value,
// size_t> built once would make it linear. The same applies to results,
// attributes and getAttrOrTypeVerifier's parameters.
323 auto operandsOp = op.getOp();
324 if (operandsOp.has_value()) {
325 operandConstraints.reserve(operandsOp->getArgs().size());
326 for (Value operand : operandsOp->getArgs()) {
327 for (auto [i, constr] : enumerate(constrToValue)) {
328 if (constr == operand) {
329 operandConstraints.push_back(i);
330 break;
331 }
332 }
333 }
334
335
336 for (VariadicityAttr attr : operandsOp->getVariadicity())
337 operandVariadicity.push_back(attr.getValue());
338 }
339
342
343
// Same for results. The variadicity loop iterates plain Attributes and casts
// each one, unlike the operand loop above.
344 auto resultsOp = op.getOp();
345 if (resultsOp.has_value()) {
346 resultConstraints.reserve(resultsOp->getArgs().size());
347 for (Value result : resultsOp->getArgs()) {
348 for (auto [i, constr] : enumerate(constrToValue)) {
349 if (constr == result) {
350 resultConstraints.push_back(i);
351 break;
352 }
353 }
354 }
355
356
357 for (Attribute attr : resultsOp->getVariadicity())
358 resultVariadicity.push_back(cast(attr).getValue());
359 }
360
361
// Map each attribute name to the index of its constraint value. (`values` is
// defined on a line missing from this listing, original line 365.)
363 auto attributesOp = op.getOp();
364 if (attributesOp.has_value()) {
366 const ArrayAttr names = attributesOp->getAttributeValueNames();
367
368 for (const auto &[name, value] : llvm::zip(names, values)) {
369 for (auto [i, constr] : enumerate(constrToValue)) {
370 if (constr == value) {
371 attributeConstraints[cast(name)] = i;
372 break;
373 }
374 }
375 }
376 }
377
// The returned verifier owns all collected constraint data. Both the op
// verifier and the region verifier are evaluated, and the op verifies only if
// both succeed. (The construction of `verifier`, original line 386, and the
// region-verifier call, original line 391, are missing from this listing --
// presumably a ConstraintVerifier over `constraints` and irdlRegionVerifier.)
378 return
379 [constraints{std::move(constraints)},
380 regionConstraints{std::move(regionConstraints)},
381 operandConstraints{std::move(operandConstraints)},
382 operandVariadicity{std::move(operandVariadicity)},
383 resultConstraints{std::move(resultConstraints)},
384 resultVariadicity{std::move(resultVariadicity)},
385 attributeConstraints{std::move(attributeConstraints)}](Operation *op) {
387 const LogicalResult opVerifierResult = irdlOpVerifier(
388 op, verifier, operandConstraints, operandVariadicity,
389 resultConstraints, resultVariadicity, attributeConstraints);
390 const LogicalResult opRegionVerifierResult =
392 return LogicalResult::success(opVerifierResult.succeeded() &&
393 opRegionVerifierResult.succeeded());
394 };
395}
396
397
398
// loadOperation (signature partially missing from this listing): define an
// operation from an `irdl.operation` op and register it with `dialect`.
// The visible parser body only returns failure(), and operations are printed
// in generic form.
401 const DenseMap<TypeOp, std::unique_ptr> &types,
402 const DenseMap<AttributeOp, std::unique_ptr>
403 &attrs) {
404
405
407 return failure();
408 };
410 printer.printGenericOp(op);
411 };
412
// `verifier` is created on a line missing from this listing (presumably via
// createVerifier); a null verifier aborts loading (the return statement,
// original line 415, is also missing here).
414 if (!verifier)
416
417
418
// The region-invariants hook accepts everything; region constraints are
// presumably checked by `verifier` itself -- confirm.
419 auto regionVerifier = [](Operation *op) { return LogicalResult::success(); };
420
422 op.getName(), dialect, std::move(verifier), std::move(regionVerifier),
423 std::move(parser), std::move(printer));
425
427}
428
429
430
// getAttrOrTypeVerifier (signature partially missing from this listing): build
// the parameter verifier of an IRDL attribute or type definition. Returns an
// empty function if a constraint verifier cannot be created.
433 DenseMap<TypeOp, std::unique_ptr> &types,
434 DenseMap<AttributeOp, std::unique_ptr> &attrs) {
435 assert((isa(attrOrTypeDef) || isa(attrOrTypeDef)) &&
436 "Expected an attribute or type definition");
437
438
// Collect the SSA values of the constraint ops in the definition body.
// NOTE(review): a constraint op without exactly one result only trips an
// assert here, whereas createVerifier reports the same condition as an error
// and fails gracefully.
441 if (isa(op)) {
442 assert(op.getNumResults() == 1 &&
443 "IRDL constraint operations must have exactly one result");
444 constrToValue.push_back(op.getResult(0));
445 }
446 }
447
448
// Instantiate one verifier per constraint value.
450 for (Value v : constrToValue) {
451 VerifyConstraintInterface op =
452 cast(v.getDefiningOp());
453 std::unique_ptr verifier =
454 op.getVerifier(constrToValue, types, attrs);
455 if (!verifier)
456 return {};
457 constraints.push_back(std::move(verifier));
458 }
459
460
// Fetch the parameters op of the attribute or type definition, if present.
461 std::optional params;
462 if (auto attr = dyn_cast(attrOrTypeDef))
463 params = attr.getOp();
464 else if (auto type = dyn_cast(attrOrTypeDef))
465 params = type.getOp();
466
467
// Map each parameter to the index of its constraint value.
469 if (params.has_value()) {
470 paramConstraints.reserve(params->getArgs().size());
471 for (Value param : params->getArgs()) {
472 for (auto [i, constr] : enumerate(constrToValue)) {
473 if (constr == param) {
474 paramConstraints.push_back(i);
475 break;
476 }
477 }
478 }
479 }
480
// The verifier owns the constraints; its body (mostly missing here) presumably
// forwards to irdlAttrOrTypeVerifier with `paramConstraints`.
481 auto verifier = [paramConstraints{std::move(paramConstraints)},
482 constraints{std::move(constraints)}](
486 paramConstraints);
487 };
488
489
490
// The explicit move presumably exists because the lambda is converted to the
// function's VerifierFn return type rather than returned as the same type.
491 return std::move(verifier);
492}
493
494
495
496
497
498
499
500
501
502
506
507 if (auto anyOf = dyn_cast(op)) {
508 bool hasAny = false;
510 hasAny &= getBases(arg.getDefiningOp(), paramIds, paramIrdlOps, isIds);
511 return hasAny;
512 }
513
514
515
// Presumably `irdl.all_of`: only the bases of its first argument are
// considered.
// NOTE(review): getArgs()[0] assumes the constraint has at least one argument.
516 if (auto allOf = dyn_cast(op))
517 return getBases(allOf.getArgs()[0].getDefiningOp(), paramIds, paramIrdlOps,
518 isIds);
519
520
// Parametric constraint: record the definition op referenced by its base
// symbol. `defOp` is looked up on a line missing from this listing
// (original line 523).
521 if (auto params = dyn_cast(op)) {
522 SymbolRefAttr symRef = params.getBaseType();
524 assert(defOp && "symbol reference should refer to an existing operation");
525 paramIrdlOps.insert(defOp);
526 return false;
527 }
528
529
// `is` constraint: record the TypeID of the exact expected attribute.
530 if (auto is = dyn_cast(op)) {
531 Attribute expected = is.getExpected();
532 isIds.insert(expected.getTypeID());
533 return false;
534 }
535
536
537
// The op matched here (presumably `irdl.any`) can match anything; signal it by
// returning true.
538 if (auto isA = dyn_cast(op))
539 return true;
540
541 llvm_unreachable("unknown IRDL constraint");
542}
543
544
545
546
547
548
549
550
551
552
553
554
559
// checkCorrectAnyOf, per alternative `arg` (the loop header and the per-argument
// set declarations are missing from this listing): collect the alternative's
// bases into fresh argParamIds / argParamIrdlOps / argIsIds sets.
561 Operation *argOp = arg.getDefiningOp();
565
566
567
// An alternative that can match anything makes the any_of unsupported.
568 if (getBases(argOp, argParamIds, argParamIrdlOps, argIsIds))
569 return failure();
570
571
572
573
574 for (TypeID id : argParamIds) {
575 if (isIds.count(id))
576 return failure();
577 bool inserted = paramIds.insert(id).second;
579 return failure();
580 }
581
582
583
584 for (TypeID id : isIds) {
585 if (paramIds.count(id))
586 return failure();
587 isIds.insert(id);
588 }
589
590
591
592
593
594 for (Operation *op : argParamIrdlOps) {
595 bool inserted = paramIrdlOps.insert(op).second;
597 return failure();
598 }
599 }
600
602}
603
604
605
// loadEmptyDialects (header missing from this listing): for every DialectOp in
// the module, obtain the corresponding dynamic dialect and record it, keyed by
// the DialectOp. The lookup itself (original lines 612-613) is missing here;
// it presumably calls getOrLoadDynamicDialect with `dialectName` on `ctx`.
608 op.walk([&](DialectOp dialectOp) {
609 MLIRContext *ctx = dialectOp.getContext();
610 StringRef dialectName = dialectOp.getName();
611
614
615 dialects.insert({dialectOp, dialect});
616 });
617 return dialects;
618}
619
620
621
// preallocateTypeDefs (header partially missing from this listing): create one
// dynamic type definition per TypeOp, named after it and tied to its `dialect`,
// with a placeholder verifier that loadDialects later replaces through
// setVerifyFn. Results are keyed by the TypeOp.
626 op.walk([&](TypeOp typeOp) {
629 typeOp.getName(), dialect,
632 });
633 typeDefs.try_emplace(typeOp, std::move(typeDef));
634 });
635 return typeDefs;
636}
637
638
639
// preallocateAttrDefs (header partially missing from this listing): create one
// dynamic attribute definition per AttributeOp, named after it and tied to its
// `dialect`, with a placeholder verifier that loadDialects later replaces
// through setVerifyFn. Results are keyed by the AttributeOp.
644 op.walk([&](AttributeOp attrOp) {
647 attrOp.getName(), dialect,
650 });
651 attrDefs.try_emplace(attrOp, std::move(attrDef));
652 });
653 return attrDefs;
654}
655
657
658
// loadDialects (header and setup missing from this listing). Visible steps, in
// order:
//   1. reject any_of constraints outside the subset IRDL can handle;
//   2. install the verifier of every preallocated type definition;
//   3. likewise for attribute definitions;
//   4. define every operation;
//   5. register the type and attribute definitions (loop bodies missing here;
//      presumably registerDynamicType / registerDynamicAttr).
// Each walk presumably interrupts on failure, which makes loadDialects fail.
// NOTE(review): `dialects[...]` uses DenseMap::operator[], which default-inserts
// a null entry if a parent op has no recorded dialect.
662 return op.emitError("any_of constraints are not in the correct form");
663
664
665
666
672
673
// Step 2: type verifiers.
674 WalkResult res = op.walk([&](TypeOp typeOp) {
676 typeOp, dialects[typeOp.getParentOp()], types, attrs);
677 if (!verifier)
679 types[typeOp]->setVerifyFn(std::move(verifier));
681 });
683 return failure();
684
685
// Step 3: attribute verifiers.
686 res = op.walk([&](AttributeOp attrOp) {
688 attrOp, dialects[attrOp.getParentOp()], types, attrs);
689 if (!verifier)
691 attrs[attrOp]->setVerifyFn(std::move(verifier));
693 });
695 return failure();
696
697
// Step 4: operations (after types/attributes, whose definitions they use).
698 res = op.walk([&](OperationOp opOp) {
699 return loadOperation(opOp, dialects[opOp.getParentOp()], types, attrs);
700 });
702 return failure();
703
704
// Step 5: hand the type and attribute definitions over to their dialects.
705 for (auto &pair : types) {
708 }
709
710
711 for (auto &pair : attrs) {
714 }
715
717}
static bool getBases(Operation *op, SmallPtrSet< TypeID, 4 > &paramIds, SmallPtrSet< Operation *, 4 > &paramIrdlOps, SmallPtrSet< TypeID, 4 > &isIds)
Get the possible bases of a constraint.
Definition IRDLLoading.cpp:503
static LogicalResult checkCorrectAnyOf(AnyOfOp anyOf)
Check that an any_of is in the subset IRDL can handle.
Definition IRDLLoading.cpp:555
static DenseMap< TypeOp, std::unique_ptr< DynamicTypeDefinition > > preallocateTypeDefs(ModuleOp op, DenseMap< DialectOp, ExtensibleDialect * > dialects)
Preallocate type definition objects with empty verifiers.
Definition IRDLLoading.cpp:623
LogicalResult getSegmentSizesFromAttr(Operation *op, StringRef elemName, StringRef attrName, unsigned numElements, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Get the segment sizes of the given elements (operands or results) from the attribute dictionary.
Definition IRDLLoading.cpp:52
static LogicalResult irdlOpVerifier(Operation *op, ConstraintVerifier &verifier, ArrayRef< size_t > operandConstrs, ArrayRef< Variadicity > operandVariadicity, ArrayRef< size_t > resultConstrs, ArrayRef< Variadicity > resultVariadicity, const DenseMap< StringAttr, size_t > &attributeConstrs)
Verify that the given operation satisfies the given constraints.
Definition IRDLLoading.cpp:190
static LogicalResult irdlAttrOrTypeVerifier(function_ref< InFlightDiagnostic()> emitError, ArrayRef< Attribute > params, ArrayRef< std::unique_ptr< Constraint > > constraints, ArrayRef< size_t > paramConstraints)
Verify that the given list of parameters satisfies the given constraints.
Definition IRDLLoading.cpp:31
static DynamicAttrDefinition::VerifierFn getAttrOrTypeVerifier(Operation *attrOrTypeDef, ExtensibleDialect *dialect, DenseMap< TypeOp, std::unique_ptr< DynamicTypeDefinition > > &types, DenseMap< AttributeOp, std::unique_ptr< DynamicAttrDefinition > > &attrs)
Get the verifier of a type or attribute definition.
Definition IRDLLoading.cpp:431
static DenseMap< AttributeOp, std::unique_ptr< DynamicAttrDefinition > > preallocateAttrDefs(ModuleOp op, DenseMap< DialectOp, ExtensibleDialect * > dialects)
Preallocate attribute definition objects with empty verifiers.
Definition IRDLLoading.cpp:641
LogicalResult getOperandSegmentSizes(Operation *op, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Compute the segment sizes of the given operands.
Definition IRDLLoading.cpp:169
static LogicalResult irdlRegionVerifier(Operation *op, ConstraintVerifier &verifier, ArrayRef< std::unique_ptr< RegionConstraint > > regionsConstraints)
Definition IRDLLoading.cpp:253
static DenseMap< DialectOp, ExtensibleDialect * > loadEmptyDialects(ModuleOp op)
Load all dialects in the given module, without loading any operation, type or attribute definitions.
Definition IRDLLoading.cpp:606
LogicalResult getSegmentSizes(Operation *op, StringRef elemName, StringRef attrName, unsigned numElements, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Compute the segment sizes of the given element (operands, results).
Definition IRDLLoading.cpp:111
static WalkResult loadOperation(OperationOp op, ExtensibleDialect *dialect, const DenseMap< TypeOp, std::unique_ptr< DynamicTypeDefinition > > &types, const DenseMap< AttributeOp, std::unique_ptr< DynamicAttrDefinition > > &attrs)
Define and load an operation represented by an `irdl.operation` operation.
Definition IRDLLoading.cpp:399
LogicalResult getResultSegmentSizes(Operation *op, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Compute the segment sizes of the given results.
Definition IRDLLoading.cpp:180
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
Attributes are known-constant values of operations.
TypeID getTypeID()
Return a unique identifier for the concrete attribute type.
static std::unique_ptr< DynamicAttrDefinition > get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier)
Create a new attribute definition at runtime.
llvm::unique_function< LogicalResult( function_ref< InFlightDiagnostic()>, ArrayRef< Attribute >) const > VerifierFn
A dialect that can be defined at runtime.
static std::unique_ptr< DynamicOpDefinition > get(StringRef name, ExtensibleDialect *dialect, OperationName::VerifyInvariantsFn &&verifyFn, OperationName::VerifyRegionInvariantsFn &&verifyRegionFn)
Create a new op at runtime.
static std::unique_ptr< DynamicTypeDefinition > get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier)
Create a new dynamic type definition.
A dialect that can be extended with new operations/types/attributes at runtime.
void registerDynamicOp(std::unique_ptr< DynamicOpDefinition > &&type)
Add a new operation defined at runtime to the dialect.
void registerDynamicType(std::unique_ptr< DynamicTypeDefinition > &&type)
Add a new type defined at runtime to the dialect.
void registerDynamicAttr(std::unique_ptr< DynamicAttrDefinition > &&attr)
Add a new attribute defined at runtime to the dialect.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
DynamicDialect * getOrLoadDynamicDialect(StringRef dialectNamespace, function_ref< void(DynamicDialect *)> ctor)
Get (or create) a dynamic dialect for the given name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Operation is the basic unit of execution within MLIR.
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
unsigned getNumRegions()
Returns the number of regions held by this operation.
unsigned getNumOperands()
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_type_range getResultTypes()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
iterator_range< OpIterator > getOps()
This class provides an efficient unique identifier for a specific C++ type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Provides context to the verification of constraints.
LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Attribute attr, unsigned variable)
Check that a constraint is satisfied by an attribute.
llvm::LogicalResult loadDialects(ModuleOp op)
Load all the dialects defined in the module.
Definition IRDLLoading.cpp:656
Operation * lookupSymbolNearDialect(SymbolTableCollection &symbolTable, Operation *source, SymbolRefAttr symbol)
Looks up a symbol from the symbol table containing the source operation's dialect definition operatio...
llvm::unique_function< LogicalResult(Operation *) const > createVerifier(OperationOp operation, const DenseMap< irdl::TypeOp, std::unique_ptr< DynamicTypeDefinition > > &typeDefs, const DenseMap< irdl::AttributeOp, std::unique_ptr< DynamicAttrDefinition > > &attrDefs)
Generate an op verifier function from the given IRDL operation definition.
Definition IRDLLoading.cpp:271
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::function_ref< Fn > function_ref
This represents an operation in an abstracted form, suitable for use with the builder APIs.