MLIR: lib/TableGen/Pattern.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
14#include
15
17#include "llvm/ADT/StringExtras.h"
18#include "llvm/ADT/Twine.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Support/FormatVariadic.h"
21#include "llvm/Support/Path.h"
22#include "llvm/TableGen/Error.h"
23#include "llvm/TableGen/Record.h"
24
25#define DEBUG_TYPE "mlir-tblgen-pattern"
26
27using namespace mlir;
28using namespace tblgen;
29
30using llvm::DagInit;
31using llvm::dbgs;
32using llvm::DefInit;
33using llvm::formatv;
34using llvm::IntInit;
35using llvm::Record;
36
37
38
39
40
42 return isa_and_nonnullllvm::UnsetInit(def);
43}
44
46
47 return isSubClassOf("TypeConstraint");
48}
49
51
52 return isSubClassOf("AttrConstraint");
53}
54
56
57 return isSubClassOf("PropConstraint");
58}
59
61
62 return isSubClassOf("Property");
63}
64
66 return isSubClassOf("NativeCodeCall");
67}
68
70
72
74
76
79 "the DAG leaf must be operand, attribute, or property");
80 return Constraint(cast(def)->getDef());
81}
82
84 assert(isPropMatcher() && "the DAG leaf must be a property matcher");
86}
87
89 assert(isPropDefinition() && "the DAG leaf must be a property definition");
90 return Property(cast(def)->getDef());
91}
92
94 assert(isConstantAttr() && "the DAG leaf must be constant attribute");
96}
97
99 assert(isEnumCase() && "the DAG leaf must be an enum attribute case");
100 return EnumCase(cast(def));
101}
102
104 assert(isConstantProp() && "the DAG leaf must be a constant property value");
106}
107
111
113 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
114 return cast(def)->getDef()->getValueAsString("expression");
115}
116
118 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
119 return cast(def)->getDef()->getValueAsInt("numReturns");
120}
121
123 assert(isStringAttr() && "the DAG leaf must be string attribute");
124 return def->getAsUnquotedString();
125}
126bool DagLeaf::isSubClassOf(StringRef superclass) const {
127 if (auto *defInit = dyn_cast_or_null(def))
128 return defInit->getDef()->isSubClassOf(superclass);
129 return false;
130}
131
133 if (def)
134 def->print(os);
135}
136
137
138
139
140
142 if (auto *defInit = dyn_cast_or_null(node->getOperator()))
143 return defInit->getDef()->isSubClassOf("NativeCodeCall");
144 return false;
145}
146
152
154 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
155 return cast(node->getOperator())
156 ->getDef()
157 ->getValueAsString("expression");
158}
159
161 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
162 return cast(node->getOperator())
163 ->getDef()
164 ->getValueAsInt("numReturns");
165}
166
168
170 const Record *opDef = cast(node->getOperator())->getDef();
171 auto [it, inserted] = mapper->try_emplace(opDef);
173 it->second = std::make_unique(opDef);
174 return *it->second;
175}
176
178
179
181 for (int i = 0, e = getNumArgs(); i != e; ++i) {
183 count += child.getNumOps();
184 }
185 return count;
186}
187
189
191 return isa(node->getArg(index));
192}
193
195 return DagNode(dyn_cast_or_null(node->getArg(index)));
196}
197
202
204 return node->getArgNameStr(index);
205}
206
208 auto *dagOpDef = cast(node->getOperator())->getDef();
209 return dagOpDef->getName() == "replaceWithValue";
210}
211
213 auto *dagOpDef = cast(node->getOperator())->getDef();
214 return dagOpDef->getName() == "location";
215}
216
218 auto *dagOpDef = cast(node->getOperator())->getDef();
219 return dagOpDef->getName() == "returnType";
220}
221
223 auto *dagOpDef = cast(node->getOperator())->getDef();
224 return dagOpDef->getName() == "either";
225}
226
228 auto *dagOpDef = cast(node->getOperator())->getDef();
229 return dagOpDef->getName() == "variadic";
230}
231
233 if (node)
234 node->print(os);
235}
236
237
238
239
240
242 int idx = -1;
243 auto [name, indexStr] = symbol.rsplit("__");
244
245 if (indexStr.consumeInteger(10, idx)) {
246
247 return symbol;
248 }
251 }
252 return name;
253}
254
255SymbolInfoMap::SymbolInfo::SymbolInfo(
256 const Operator *op, SymbolInfo::Kind kind,
257 std::optional dagAndConstant)
258 : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
259
260int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
261 switch (kind) {
262 case Kind::Attr:
263 case Kind::Prop:
264 case Kind::Operand:
265 case Kind::Value:
266 return 1;
267 case Kind::Result:
269 case Kind::MultipleValues:
270 return getSize();
271 }
272 llvm_unreachable("unknown kind");
273}
274
276 return alternativeName ? *alternativeName : name.str();
277}
278
280 LLVM_DEBUG(dbgs() << "getVarTypeStr for '" << name << "': ");
281 switch (kind) {
282 case Kind::Attr: {
283 if (op)
284 return cast<NamedAttribute *>(op->getArg(getArgIndex()))
285 ->attr.getStorageType()
286 .str();
287
288 return "::mlir::Attribute";
289 }
290 case Kind::Prop: {
291 if (op)
292 return cast<NamedProperty *>(op->getArg(getArgIndex()))
293 ->prop.getInterfaceType()
294 .str();
295 assert(dagAndConstant && dagAndConstant->dag &&
296 "generic properties must carry their constraint");
297 return reinterpret_cast<const DagLeaf *>(dagAndConstant->dag)
298 ->getAsPropConstraint()
299 .getInterfaceType()
300 .str();
301 }
302 case Kind::Operand: {
303
304
305 return "::mlir::Operation::operand_range";
306 }
307 case Kind::Value: {
308 return "::mlir::Value";
309 }
310 case Kind::MultipleValues: {
311 return "::mlir::ValueRange";
312 }
313 case Kind::Result: {
314
315 return op->getQualCppClassName();
316 }
317 }
318 llvm_unreachable("unknown kind");
319}
320
322 LLVM_DEBUG(dbgs() << "getVarDecl for '" << name << "': ");
323 std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
324 return std::string(
326}
327
329 LLVM_DEBUG(dbgs() << "getArgDecl for '" << name << "': ");
330 return std::string(
332}
333
334std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
335 StringRef name, int index, const char *fmt, const char *separator) const {
336 LLVM_DEBUG(dbgs() << "getValueAndRangeUse for '" << name << "': ");
337 switch (kind) {
338 case Kind::Attr: {
339 assert(index < 0);
340 auto repl = formatv(fmt, name);
341 LLVM_DEBUG(dbgs() << repl << " (Attr)\n");
342 return std::string(repl);
343 }
344 case Kind::Prop: {
345 assert(index < 0);
346 auto repl = formatv(fmt, name);
347 LLVM_DEBUG(dbgs() << repl << " (Prop)\n");
348 return std::string(repl);
349 }
350 case Kind::Operand: {
351 assert(index < 0);
352 auto *operand = cast<NamedTypeConstraint *>(op->getArg(getArgIndex()));
353 if (operand->isOptional()) {
354 auto repl = formatv(
355 fmt, formatv("({0}.empty() ? ::mlir::Value() : *{0}.begin())", name));
356 LLVM_DEBUG(dbgs() << repl << " (OptionalOperand)\n");
357 return std::string(repl);
358 }
359
360
361
362 if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
363 auto repl = formatv(fmt, name);
364 LLVM_DEBUG(dbgs() << repl << " (VariadicOperand)\n");
365 return std::string(repl);
366 }
367 auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
368 LLVM_DEBUG(dbgs() << repl << " (SingleOperand)\n");
369 return std::string(repl);
370 }
371 case Kind::Result: {
372
373
374 if (index >= 0) {
375 std::string v =
376 std::string(formatv("{0}.getODSResults({1})", name, index));
377 if (!op->getResult(index).isVariadic())
378 v = std::string(formatv("(*{0}.begin())", v));
379 auto repl = formatv(fmt, v);
380 LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
381 return std::string(repl);
382 }
383
384
385
386 if (op->getNumResults() == 0) {
387 LLVM_DEBUG(dbgs() << name << " (Op)\n");
388 return formatv(fmt, name);
389 }
390
391
392
393 SmallVector<std::string, 4> values;
394 values.reserve(op->getNumResults());
395
396 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
397 std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
398 if (!op->getResult(i).isVariadic()) {
399 v = std::string(formatv("(*{0}.begin())", v));
400 }
401 values.push_back(std::string(formatv(fmt, v)));
402 }
403 auto repl = llvm::join(values, separator);
404 LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
405 return repl;
406 }
407 case Kind::Value: {
408 assert(index < 0);
409 assert(op == nullptr);
410 auto repl = formatv(fmt, name);
411 LLVM_DEBUG(dbgs() << repl << " (Value)\n");
412 return std::string(repl);
413 }
414 case Kind::MultipleValues: {
415 assert(op == nullptr);
416 assert(index < getSize());
417 if (index >= 0) {
418 std::string repl =
419 formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
420 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
421 return repl;
422 }
423
424 auto repl =
425 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
426 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
427 return std::string(repl);
428 }
429 }
430 llvm_unreachable("unknown kind");
431}
432
433std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
434 StringRef name, int index, const char *fmt, const char *separator) const {
435 LLVM_DEBUG(dbgs() << "getAllRangeUse for '" << name << "': ");
436 switch (kind) {
437 case Kind::Attr:
438 case Kind::Prop:
439 case Kind::Operand: {
440 assert(index < 0 && "only allowed for symbol bound to result");
441 auto repl = formatv(fmt, name);
442 LLVM_DEBUG(dbgs() << repl << " (Operand/Attr/Prop)\n");
443 return std::string(repl);
444 }
445 case Kind::Result: {
446 if (index >= 0) {
447 auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
448 LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
449 return std::string(repl);
450 }
451
452
453
454 SmallVector<std::string, 4> values;
455 values.reserve(op->getNumResults());
456
457 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
458 values.push_back(std::string(
459 formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
460 }
461 auto repl = llvm::join(values, separator);
462 LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n");
463 return repl;
464 }
465 case Kind::Value: {
466 assert(index < 0 && "only allowed for symbol bound to result");
467 assert(op == nullptr);
468 auto repl = formatv(fmt, formatv("{{{0}}", name));
469 LLVM_DEBUG(dbgs() << repl << " (Value)\n");
470 return std::string(repl);
471 }
472 case Kind::MultipleValues: {
473 assert(op == nullptr);
474 assert(index < getSize());
475 if (index >= 0) {
476 std::string repl =
477 formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
478 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
479 return repl;
480 }
481 auto repl =
482 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
483 LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n");
484 return std::string(repl);
485 }
486 }
487 llvm_unreachable("unknown kind");
488}
489
491 const Operator &op, int argIndex,
492 std::optional variadicSubIndex) {
494 if (name != symbol) {
495 auto error = formatv(
496 "symbol '{0}' with trailing index cannot bind to op argument", symbol);
497 PrintFatalError(loc, error);
498 }
499
502 isa<NamedAttribute *>(arg) ? SymbolInfo::getAttr(&op, argIndex)
503 : isa<NamedProperty *>(arg)
504 ? SymbolInfo::getProp(&op, argIndex)
505 : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex);
506
507 std::string key = symbol.str();
508 if (symbolInfoMap.count(key)) {
509
510 if (symInfo.kind != SymbolInfo::Kind::Operand) {
511 return false;
512 }
513
514
515
516 if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
517 return false;
518 }
519 }
520
521 symbolInfoMap.emplace(key, symInfo);
522 return true;
523}
524
527 auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
528
529 return symbolInfoMap.count(inserted->first) == 1;
530}
531
534 if (numValues > 1)
537}
538
540 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
541 return symbolInfoMap.count(inserted->first) == 1;
542}
543
547 symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
548 return symbolInfoMap.count(inserted->first) == 1;
549}
550
552 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
553 return symbolInfoMap.count(inserted->first) == 1;
554}
555
559 symbolInfoMap.emplace(symbol.str(), SymbolInfo::getProp(&constraint));
560 return symbolInfoMap.count(inserted->first) == 1;
561}
562
564 return find(symbol) != symbolInfoMap.end();
565}
566
569
570 return symbolInfoMap.find(name);
571}
572
575 int argIndex,
576 std::optional variadicSubIndex) const {
578 key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex));
579}
580
583 const SymbolInfo &symbolInfo) const {
585 auto range = symbolInfoMap.equal_range(name);
586
587 for (auto it = range.first; it != range.second; ++it)
588 if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
589 return it;
590
591 return symbolInfoMap.end();
592}
593
594std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
597
598 return symbolInfoMap.equal_range(name);
599}
600
603 return symbolInfoMap.count(name);
604}
605
608 if (name != symbol) {
609
610
611 return 1;
612 }
613
614 return find(name)->second.getStaticValueCount();
615}
616
618 const char *fmt,
619 const char *separator) const {
622
623 auto it = symbolInfoMap.find(name.str());
624 if (it == symbolInfoMap.end()) {
625 auto error = formatv("referencing unbound symbol '{0}'", symbol);
626 PrintFatalError(loc, error);
627 }
628
629 return it->second.getValueAndRangeUse(name, index, fmt, separator);
630}
631
633 const char *separator) const {
636
637 auto it = symbolInfoMap.find(name.str());
638 if (it == symbolInfoMap.end()) {
639 auto error = formatv("referencing unbound symbol '{0}'", symbol);
640 PrintFatalError(loc, error);
641 }
642
643 return it->second.getAllRangeUse(name, index, fmt, separator);
644}
645
648
649 for (auto symbolInfoIt = symbolInfoMap.begin();
650 symbolInfoIt != symbolInfoMap.end();) {
651 auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
652 auto startRange = range.first;
653 auto endRange = range.second;
654
655 auto operandName = symbolInfoIt->first;
656 int startSearchIndex = 0;
657 for (++startRange; startRange != endRange; ++startRange) {
658
659
660 for (int i = startSearchIndex;; ++i) {
661 std::string alternativeName = operandName + std::to_string(i);
662 if (!usedNames.contains(alternativeName) &&
663 symbolInfoMap.count(alternativeName) == 0) {
664 usedNames.insert(alternativeName);
665 startRange->second.alternativeName = alternativeName;
666 startSearchIndex = i + 1;
667
668 break;
669 }
670 }
671 }
672
673 symbolInfoIt = endRange;
674 }
675}
676
677
678
679
680
682 : def(*def), recordOpMap(mapper) {}
683
685 return DagNode(def.getValueAsDag("sourcePattern"));
686}
687
689 auto *results = def.getValueAsListInit("resultPatterns");
690 return results->size();
691}
692
694 auto *results = def.getValueAsListInit("resultPatterns");
695 return DagNode(cast(results->getElement(index)));
696}
697
699 LLVM_DEBUG(dbgs() << "start collecting source pattern bound symbols\n");
701 LLVM_DEBUG(dbgs() << "done collecting source pattern bound symbols\n");
702
703 LLVM_DEBUG(dbgs() << "start assigning alternative names for symbols\n");
705 LLVM_DEBUG(dbgs() << "done assigning alternative names for symbols\n");
706}
707
709 LLVM_DEBUG(dbgs() << "start collecting result pattern bound symbols\n");
713 }
714 LLVM_DEBUG(dbgs() << "done collecting result pattern bound symbols\n");
715}
716
720
724
726 auto *listInit = def.getValueAsListInit("constraints");
727 std::vector ret;
728 ret.reserve(listInit->size());
729
730 for (auto *it : *listInit) {
731 auto *dagInit = dyn_cast(it);
732 if (!dagInit)
733 PrintFatalError(&def, "all elements in Pattern multi-entity "
734 "constraints should be DAG nodes");
735
736 std::vectorstd::string entities;
737 entities.reserve(dagInit->arg_size());
738 for (auto *argName : dagInit->getArgNames()) {
739 if (!argName) {
740 PrintFatalError(
741 &def,
742 "operands to additional constraints can only be symbol references");
743 }
744 entities.emplace_back(argName->getValue());
745 }
746
747 ret.emplace_back(cast(dagInit->getOperator())->getDef(),
748 dagInit->getNameStr(), std::move(entities));
749 }
750 return ret;
751}
752
754 auto *results = def.getValueAsListInit("supplementalPatterns");
755 return results->size();
756}
757
759 auto *results = def.getValueAsListInit("supplementalPatterns");
760 return DagNode(cast(results->getElement(index)));
761}
762
764
765
767 const DagInit *delta = def.getValueAsDag("benefitDelta");
768 if (delta->getNumArgs() != 1 || !isa(delta->getArg(0))) {
769 PrintFatalError(&def,
770 "The 'addBenefit' takes and only takes one integer value");
771 }
772 return initBenefit + dyn_cast(delta->getArg(0))->getValue();
773}
774
775std::vectorPattern::IdentifierLine
777 std::vector<std::pair<StringRef, unsigned>> result;
778 result.reserve(def.getLoc().size());
779 for (auto loc : def.getLoc()) {
780 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
781 assert(buf && "invalid source location");
782
783 StringRef bufferName =
784 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier();
785
786
787
788
789
790
791 if (forSourceOutput && llvm::sys::path::is_absolute(bufferName))
792 bufferName = llvm::sys::path::filename(bufferName);
793
794 result.emplace_back(bufferName,
795 llvm::SrcMgr.getLineAndColumn(loc, buf).first);
796 }
798}
799
800void Pattern::verifyBind(bool result, StringRef symbolName) {
802 auto err = formatv("symbol '{0}' bound more than once", symbolName);
803 PrintFatalError(&def, err);
804 }
805}
806
808 bool isSrcPattern) {
809 auto treeName = tree.getSymbol();
810 auto numTreeArgs = tree.getNumArgs();
811
813 if (!treeName.empty()) {
814 if (!isSrcPattern) {
815 LLVM_DEBUG(dbgs() << "found symbol bound to NativeCodeCall: "
816 << treeName << '\n');
817 verifyBind(
819 treeName);
820 } else {
821 PrintFatalError(&def,
822 formatv("binding symbol '{0}' to NativecodeCall in "
823 "MatchPattern is not supported",
824 treeName));
825 }
826 }
827
828 for (int i = 0; i != numTreeArgs; ++i) {
830
832 continue;
833 }
834
835 if (!isSrcPattern)
836 continue;
837
838
839
840 auto treeArgName = tree.getArgName(i);
841
842
843 if (!treeArgName.empty() && treeArgName != "_") {
845
846
847
849
850 verifyBind(infoMap.bindValue(treeArgName), treeArgName);
852
854 if (propConstraint.getInterfaceType().empty()) {
855 PrintFatalError(&def,
856 formatv("binding symbol '{0}' in NativeCodeCall to "
857 "a property constraint without specifying "
858 "that constraint's type is unsupported",
859 treeArgName));
860 }
861 verifyBind(infoMap.bindProp(treeArgName, propConstraint),
862 treeArgName);
863 } else {
868
869 if (isAttr) {
870
871 verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
872 continue;
873 }
874
875
876 verifyBind(infoMap.bindValue(treeArgName), treeArgName);
877 }
878 }
879 }
880
881 return;
882 }
883
886 auto numOpArgs = op.getNumArgs();
887 int numEither = 0;
888
889
890
891 int numDirectives = 0;
892 for (int i = numTreeArgs - 1; i >= 0; --i) {
894 if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
895 ++numDirectives;
896 else if (dagArg.isEither())
897 ++numEither;
898 }
899 }
900
901 if (numOpArgs != numTreeArgs - numDirectives + numEither) {
902 auto err =
903 formatv("op '{0}' argument number mismatch: "
904 "{1} in pattern vs. {2} in definition",
905 op.getOperationName(), numTreeArgs + numEither, numOpArgs);
906 PrintFatalError(&def, err);
907 }
908
909
910
911 if (!treeName.empty()) {
912 LLVM_DEBUG(dbgs() << "found symbol bound to op result: " << treeName
913 << '\n');
914 verifyBind(infoMap.bindOpResult(treeName, op), treeName);
915 }
916
917
918
919 auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
920 int opArgIdx) {
921 for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
924 } else {
925 auto argName = tree.getArgName(i);
926 if (!argName.empty() && argName != "_") {
927 verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx),
928 argName);
929 }
930 }
931 }
932 };
933
934
935
936
937 auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree,
938 int opArgIdx) {
939 auto treeName = tree.getSymbol();
940 if (!treeName.empty()) {
941
942 verifyBind(infoMap.bindOpArgument(parent, treeName, op, opArgIdx,
943 std::nullopt),
944 treeName);
945 }
946
947 for (int i = 0; i < tree.getNumArgs(); ++i) {
950 } else {
951 auto argName = tree.getArgName(i);
952 if (!argName.empty() && argName != "_") {
953 verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx,
954 i),
955 argName);
956 }
957 }
958 }
959 };
960
961 for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
962 if (auto treeArg = tree.getArgAsNestedDag(i)) {
963 if (treeArg.isEither()) {
964 collectSymbolInEither(tree, treeArg, opArgIdx);
965
966
967
968
969
970
971
972 ++opArgIdx;
973 } else if (treeArg.isVariadic()) {
974 collectSymbolInVariadic(tree, treeArg, opArgIdx);
975 } else {
976
978 }
979 continue;
980 }
981
982 if (isSrcPattern) {
983
984
985 auto treeArgName = tree.getArgName(i);
986
987 if (!treeArgName.empty() && treeArgName != "_") {
988 LLVM_DEBUG(dbgs() << "found symbol bound to op argument: "
989 << treeArgName << '\n');
990 verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx),
991 treeArgName);
992 }
993 }
994 }
995 return;
996 }
997
998 if (!treeName.empty()) {
999 PrintFatalError(
1000 &def, formatv("binding symbol '{0}' to non-operation/native code call "
1001 "unsupported right now",
1002 treeName));
1003 }
1004}
... if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming copies and outgoing copies, respectively, should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method ...
std::string getConditionTemplate() const
Constraint getAsConstraint() const
Definition Pattern.cpp:77
bool isNativeCodeCall() const
Definition Pattern.cpp:65
bool isPropMatcher() const
Definition Pattern.cpp:55
int getNumReturnsOfNativeCode() const
Definition Pattern.cpp:117
ConstantAttr getAsConstantAttr() const
Definition Pattern.cpp:93
void print(raw_ostream &os) const
Definition Pattern.cpp:132
std::string getStringAttr() const
Definition Pattern.cpp:122
Property getAsProperty() const
Definition Pattern.cpp:88
bool isEnumCase() const
Definition Pattern.cpp:71
StringRef getNativeCodeTemplate() const
Definition Pattern.cpp:112
std::string getConditionTemplate() const
Definition Pattern.cpp:108
bool isConstantProp() const
Definition Pattern.cpp:73
PropConstraint getAsPropConstraint() const
Definition Pattern.cpp:83
ConstantProp getAsConstantProp() const
Definition Pattern.cpp:103
bool isUnspecified() const
Definition Pattern.cpp:41
EnumCase getAsEnumCase() const
Definition Pattern.cpp:98
bool isAttrMatcher() const
Definition Pattern.cpp:50
bool isOperandMatcher() const
Definition Pattern.cpp:45
bool isPropDefinition() const
Definition Pattern.cpp:60
bool isConstantAttr() const
Definition Pattern.cpp:69
bool isStringAttr() const
Definition Pattern.cpp:75
bool isReturnTypeDirective() const
Definition Pattern.cpp:217
bool isLocationDirective() const
Definition Pattern.cpp:212
bool isReplaceWithValue() const
Definition Pattern.cpp:207
DagNode getArgAsNestedDag(unsigned index) const
Definition Pattern.cpp:194
bool isOperation() const
Definition Pattern.cpp:147
DagLeaf getArgAsLeaf(unsigned index) const
Definition Pattern.cpp:198
int getNumReturnsOfNativeCode() const
Definition Pattern.cpp:160
StringRef getNativeCodeTemplate() const
Definition Pattern.cpp:153
void print(raw_ostream &os) const
Definition Pattern.cpp:232
int getNumOps() const
Definition Pattern.cpp:177
Operator & getDialectOp(RecordOperatorMap *mapper) const
Definition Pattern.cpp:169
bool isVariadic() const
Definition Pattern.cpp:227
bool isNativeCodeCall() const
Definition Pattern.cpp:141
bool isEither() const
Definition Pattern.cpp:222
bool isNestedDagArg(unsigned index) const
Definition Pattern.cpp:190
StringRef getSymbol() const
Definition Pattern.cpp:167
int getNumArgs() const
Definition Pattern.cpp:188
DagNode(const llvm::DagInit *node)
StringRef getArgName(unsigned index) const
Definition Pattern.cpp:203
Wrapper class that contains an MLIR op's information (e.g., operands, attributes) defined in TableGen ...
int getNumResults() const
Returns the number of results this op produces.
const llvm::Record & getDef() const
Returns the Tablegen definition this operator was constructed from.
Argument getArg(int index) const
Op argument (attribute or operand) accessors.
int getNumResultPatterns() const
Definition Pattern.cpp:688
DagNode getSourcePattern() const
Definition Pattern.cpp:684
const Operator & getSourceRootOp()
Definition Pattern.cpp:717
std::vector< AppliedConstraint > getConstraints() const
Definition Pattern.cpp:725
DagNode getResultPattern(unsigned index) const
Definition Pattern.cpp:693
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern)
Definition Pattern.cpp:807
Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
Operator & getDialectOp(DagNode node)
Definition Pattern.cpp:721
DagNode getSupplementalPattern(unsigned index) const
Definition Pattern.cpp:758
void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap)
Definition Pattern.cpp:698
void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap)
Definition Pattern.cpp:708
int getNumSupplementalPatterns() const
Definition Pattern.cpp:753
std::vector< IdentifierLine > getLocation(bool forSourceOutput=false) const
Definition Pattern.cpp:776
std::string getArgDecl(StringRef name) const
Definition Pattern.cpp:328
std::string getVarName(StringRef name) const
Definition Pattern.cpp:275
std::string getVarTypeStr(StringRef name) const
Definition Pattern.cpp:279
std::string getVarDecl(StringRef name) const
Definition Pattern.cpp:321
static StringRef getValuePackName(StringRef symbol, int *index=nullptr)
Definition Pattern.cpp:241
int count(StringRef key) const
Definition Pattern.cpp:601
const_iterator find(StringRef key) const
Definition Pattern.cpp:567
void assignUniqueAlternativeNames()
Definition Pattern.cpp:646
bool bindMultipleValues(StringRef symbol, int numValues)
Definition Pattern.cpp:544
bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op, int argIndex, std::optional< int > variadicSubIndex=std::nullopt)
Definition Pattern.cpp:490
std::string getAllRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
Definition Pattern.cpp:632
bool bindValues(StringRef symbol, int numValues=1)
Definition Pattern.cpp:532
bool bindAttr(StringRef symbol)
Definition Pattern.cpp:551
bool bindProp(StringRef symbol, const PropConstraint &constraint)
Definition Pattern.cpp:556
bool bindValue(StringRef symbol)
Definition Pattern.cpp:539
const_iterator findBoundSymbol(StringRef key, DagNode node, const Operator &op, int argIndex, std::optional< int > variadicSubIndex) const
Definition Pattern.cpp:574
std::pair< iterator, iterator > getRangeOfEqualElements(StringRef key)
Definition Pattern.cpp:595
int getStaticValueCount(StringRef symbol) const
Definition Pattern.cpp:606
bool contains(StringRef symbol) const
Definition Pattern.cpp:563
BaseT::const_iterator const_iterator
bool bindOpResult(StringRef symbol, const Operator &op)
Definition Pattern.cpp:525
std::string getValueAndRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
Definition Pattern.cpp:617
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
DenseMap< const llvm::Record *, std::unique_ptr< Operator > > RecordOperatorMap
Include the generated interface declarations.