MLIR: lib/Conversion/PDLToPDLInterp/PredicateTree.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
11
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 #include "llvm/Support/Debug.h"
21 #include
22
23 #define DEBUG_TYPE "pdl-predicate-tree"
24
25 using namespace mlir;
27
28
29
30
31
32 static void getTreePredicates(std::vector &predList,
36
37
40 }
41
42
44 return llvm::count_if(values.getTypes(),
45 [](Type type) { return !isapdl::RangeType(type); });
46 }
47
52 assert(isapdl::AttributeType(val.getType()) && "expected attribute type");
53 predList.emplace_back(pos, builder.getIsNotNull());
54
55 if (auto attr = dyn_castpdl::AttributeOp(val.getDefiningOp())) {
56
57 if (Value type = attr.getValueType())
59 else if (Attribute value = attr.getValueAttr())
61 }
62 }
63
64
70 bool isVariadic = isapdl::RangeType(valueType);
71
72
74 .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) {
75
76
77 if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
78 cast(pos)->getOperandGroupNumber())
79 predList.emplace_back(pos, builder.getIsNotNull());
80
81 if (Value type = op.getValueType())
84 })
85 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {
86 std::optional index = op.getIndex();
87
88
89 if (index)
90 predList.emplace_back(pos, builder.getIsNotNull());
91
92
94 predList.emplace_back(parentPos, builder.getIsNotNull());
95
96
97
98 Position *resultPos = nullptr;
99 if (std::is_same<pdl::ResultOp, decltype(op)>::value)
100 resultPos = builder.getResult(parentPos, *index);
101 else
102 resultPos = builder.getResultGroup(parentPos, index, isVariadic);
103 predList.emplace_back(resultPos, builder.getEqualTo(pos));
104
105
108 });
109 }
110
111 static void
115 std::optional ignoreOperand = std::nullopt) {
116 assert(isapdl::OperationType(val.getType()) && "expected operation");
117 pdl::OperationOp op = castpdl::OperationOp(val.getDefiningOp());
119
120
121 if (!opPos->isRoot())
122 predList.emplace_back(pos, builder.getIsNotNull());
123
124
125 if (std::optional opName = op.getOpName())
126 predList.emplace_back(pos, builder.getOperationName(*opName));
127
128
129
130 OperandRange operands = op.getOperandValues();
132 if (minOperands != operands.size()) {
133 if (minOperands)
135 } else {
136 predList.emplace_back(pos, builder.getOperandCount(minOperands));
137 }
138
139
140
143 if (minResults == types.size())
144 predList.emplace_back(pos, builder.getResultCount(types.size()));
145 else if (minResults)
147
148
149 for (auto [attrName, attr] :
150 llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) {
152 predList, attr, builder, inputs,
154 }
155
156
157
158
159
160
161
162 if (operands.size() == 1 && isapdl::RangeType(operands[0].getType())) {
163
164
168 } else {
169 bool foundVariableLength = false;
170 for (const auto &operandIt : llvm::enumerate(operands)) {
171 bool isVariadic = isapdl::RangeType(operandIt.value().getType());
172 foundVariableLength |= isVariadic;
173
174
175
176 if (ignoreOperand == operandIt.index())
177 continue;
178
180 foundVariableLength
181 ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic)
182 : builder.getOperand(opPos, operandIt.index());
183 getTreePredicates(predList, operandIt.value(), builder, inputs, pos);
184 }
185 }
186
187 if (types.size() == 1 && isapdl::RangeType(types[0].getType())) {
190 return;
191 }
192
193 bool foundVariableLength = false;
195 bool isVariadic = isapdl::RangeType(typeValue.getType());
196 foundVariableLength |= isVariadic;
197
198 auto *resultPos = foundVariableLength
201 predList.emplace_back(resultPos, builder.getIsNotNull());
203 builder.getType(resultPos));
204 }
205 }
206
211
212 if (pdl::TypeOp typeOp = val.getDefiningOppdl::TypeOp()) {
213 if (Attribute type = typeOp.getConstantTypeAttr())
215 } else if (pdl::TypesOp typeOp = val.getDefiningOppdl::TypesOp()) {
216 if (Attribute typeAttr = typeOp.getConstantTypesAttr())
218 }
219 }
220
221
226
227 auto it = inputs.try_emplace(val, pos);
228 if (!it.second) {
229
230
231 if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
233 auto minMaxPositions =
235 predList.emplace_back(minMaxPositions.second,
236 builder.getEqualTo(minMaxPositions.first));
237 }
238 return;
239 }
240
244 })
245 .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) {
247 })
248 .Default([](auto *) { llvm_unreachable("unexpected position kind"); });
249 }
250
252 std::vector &predList,
255 Position *&attrPos = inputs[op];
256 if (attrPos)
257 return;
258 Attribute value = op.getValueAttr();
259 assert(value && "expected non-tree `pdl.attribute` to contain a value");
261 }
262
264 std::vector &predList,
268
269 std::vector<Position *> allPositions;
270 allPositions.reserve(arguments.size());
271 for (Value arg : arguments)
272 allPositions.push_back(inputs.lookup(arg));
273
274
279 op.getIsNegated());
280
281
285 auto [it, inserted] = inputs.try_emplace(result, pos);
286
287
288 if (!inserted) {
290 Position *second = it->second;
292 std::tie(second, first) = std::make_pair(first, second);
293
294 predList.emplace_back(second, builder.getEqualTo(first));
295 }
296 }
297 predList.emplace_back(pos, pred);
298 }
299
301 std::vector &predList,
304 Position *&resultPos = inputs[op];
305 if (resultPos)
306 return;
307
308
309 auto *parentPos = cast(inputs.lookup(op.getParent()));
310 resultPos = builder.getResult(parentPos, op.getIndex());
311 predList.emplace_back(resultPos, builder.getIsNotNull());
312 }
313
315 std::vector &predList,
318 Position *&resultPos = inputs[op];
319 if (resultPos)
320 return;
321
322
323 auto *parentPos = cast(inputs.lookup(op.getParent()));
324 bool isVariadic = isapdl::RangeType(op.getType());
325 std::optional index = op.getIndex();
326 resultPos = builder.getResultGroup(parentPos, index, isVariadic);
327 if (index)
328 predList.emplace_back(resultPos, builder.getIsNotNull());
329 }
330
335 Position *&typePos = inputs[typeValue];
336 if (typePos)
337 return;
338 Attribute typeAttr = typeAttrFn();
339 assert(typeAttr &&
340 "expected non-tree `pdl.type`/`pdl.types` to contain a value");
342 }
343
344
345
347 std::vector &predList,
350 for (Operation &op : pattern.getBodyRegion().getOps()) {
352 .Case([&](pdl::AttributeOp attrOp) {
354 })
355 .Casepdl::ApplyNativeConstraintOp([&](auto constraintOp) {
357 })
358 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
360 })
361 .Case([&](pdl::TypeOp typeOp) {
363 typeOp, [&] { return typeOp.getConstantTypeAttr(); }, builder,
364 inputs);
365 })
366 .Case([&](pdl::TypesOp typeOp) {
368 typeOp, [&] { return typeOp.getConstantTypesAttr(); }, builder,
369 inputs);
370 });
371 }
372 }
373
374 namespace {
375
376
377 struct OpIndex {
379 std::optional index;
380 };
381
382
383
385
386 }
387
388
389
391
392
394 for (auto operationOp : pattern.getBodyRegion().getOpspdl::OperationOp()) {
395 for (Value operand : operationOp.getOperandValues())
397 .Case<pdl::ResultOp, pdl::ResultsOp>(
398 [&used](auto resultOp) { used.insert(resultOp.getParent()); });
399 }
400
401
402
403 if (Value root = pattern.getRewriter().getRoot())
404 used.erase(root);
405
406
408 for (Value operationOp : pattern.getBodyRegion().getOpspdl::OperationOp())
409 if (!used.contains(operationOp))
410 roots.push_back(operationOp);
411
412 return roots;
413 }
414
415
416
417
418
419
420
421
423 ParentMaps &parentMaps) {
424
425
426
427
428
429
430 struct Entry {
431 Entry(Value value, Value parent, std::optional index,
432 unsigned depth)
433 : value(value), parent(parent), index(index), depth(depth) {}
434
437 std::optional index;
438 unsigned depth;
439 };
440
441
442 struct RootDepth {
444 unsigned depth = 0;
445 };
446
447
448
449 llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths;
450
451
452 for (Value root : roots) {
453
454
455
456 std::queue toVisit;
457 toVisit.emplace(root, Value(), 0, 0);
458
459
461
462 while (!toVisit.empty()) {
463 Entry entry = toVisit.front();
464 toVisit.pop();
465
466 if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second)
467 continue;
468
469
470 connectorsRootsDepths[entry.value].push_back({root, entry.depth});
471
472
473
474
476 .Casepdl::OperationOp([&](auto operationOp) {
477 OperandRange operands = operationOp.getOperandValues();
478
479
480 if (operands.size() == 1 &&
481 isapdl::RangeType(operands[0].getType())) {
482 toVisit.emplace(operands[0], entry.value, std::nullopt,
483 entry.depth + 1);
484 return;
485 }
486
487
488 for (const auto &p :
490 toVisit.emplace(p.value(), entry.value, p.index(),
491 entry.depth + 1);
492 })
493 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
494 toVisit.emplace(resultOp.getParent(), entry.value,
495 resultOp.getIndex(), entry.depth);
496 });
497 }
498 }
499
500
501
502 unsigned nextID = 0;
503 for (const auto &connectorRootsDepths : connectorsRootsDepths) {
504 Value value = connectorRootsDepths.first;
506
507
508 if (rootsDepths.size() == 1)
509 continue;
510
511 for (const RootDepth &p : rootsDepths) {
512 for (const RootDepth &q : rootsDepths) {
513 if (&p == &q)
514 continue;
515
517 if (!entry.connector || entry.cost.first > q.depth) {
519 entry.cost.second = nextID++;
520 entry.cost.first = q.depth;
522 }
523 }
524 }
525 }
526
527 assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) &&
528 "the pattern contains a candidate root disconnected from the others");
529 }
530
531
532
534 OperandRange operands = op.getOperandValues();
535 assert(index < operands.size() && "operand index out of range");
536 for (unsigned i = 0; i <= index; ++i)
537 if (isapdl::RangeType(operands[i].getType()))
538 return true;
539 return false;
540 }
541
542
543 static void visitUpward(std::vector &predList,
546 Position *&pos, unsigned rootID) {
547 Value value = opIndex.parent;
549 .Casepdl::OperationOp([&](auto operationOp) {
550 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
551
552
553 Position *usersPos = builder.getUsers(pos, true);
556
557
559 if (!opIndex.index) {
560
562 } else if (useOperandGroup(operationOp, *opIndex.index)) {
563
564 Type type = operationOp.getOperandValues()[*opIndex.index].getType();
565 bool variadic = isapdl::RangeType(type);
566 operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);
567 } else {
568
569 operandPos = builder.getOperand(opPos, *opIndex.index);
570 }
571 predList.emplace_back(operandPos, builder.getEqualTo(pos));
572
573
574
575
576
577 bool inserted = valueToPosition.try_emplace(value, opPos).second;
578 (void)inserted;
579 assert(inserted && "duplicate upward visit");
580
581
582 getTreePredicates(predList, value, builder, valueToPosition, opPos,
583 opIndex.index);
584
585
586 pos = opPos;
587 })
588 .Casepdl::ResultOp([&](auto resultOp) {
589
590 auto *opPos = dyn_cast(pos);
591 assert(opPos && "operations and results must be interleaved");
592 pos = builder.getResult(opPos, *opIndex.index);
593
594
595 valueToPosition.try_emplace(value, pos);
596 })
597 .Casepdl::ResultsOp([&](auto resultOp) {
598
599 auto *opPos = dyn_cast(pos);
600 assert(opPos && "operations and results must be interleaved");
601 bool isVariadic = isapdl::RangeType(value.getType());
602 if (opIndex.index)
603 pos = builder.getResultGroup(opPos, opIndex.index, isVariadic);
604 else
606
607
608 valueToPosition.try_emplace(value, pos);
609 });
610 }
611
612
613
616 std::vector &predList,
619
620
622 ParentMaps parentMaps;
624 LLVM_DEBUG({
625 llvm::dbgs() << "Graph:\n";
626 for (auto &target : graph) {
627 llvm::dbgs() << " * " << target.first.getLoc() << " " << target.first
628 << "\n";
629 for (auto &source : target.second) {
631 llvm::dbgs() << " <- " << source.first << ": " << entry.cost.first
632 << ":" << entry.cost.second << " via "
634 }
635 }
636 });
637
638
639
640 Value bestRoot = pattern.getRewriter().getRoot();
642 if (!bestRoot) {
643 unsigned bestCost = 0;
644 LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n");
645 for (Value root : roots) {
647 unsigned cost = solver.solve();
648 LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n");
649 if (!bestRoot || bestCost > cost) {
650 bestCost = cost;
651 bestRoot = root;
653 }
654 }
655 } else {
659 }
660
661
662 LLVM_DEBUG({
663 llvm::dbgs() << "Best tree:\n";
664 for (const std::pair<Value, Value> &edge : bestEdges) {
665 llvm::dbgs() << " * " << edge.first;
666 if (edge.second)
667 llvm::dbgs() << " <- " << edge.second;
668 llvm::dbgs() << "\n";
669 }
670 });
671
672 LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n");
673 LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n");
674
675
676
679
680
681
682
684 Value target = it.value().first;
685 Value source = it.value().second;
686
687
688
689
690
691 if (valueToPosition.count(target))
692 continue;
693
694
695 Value connector = graph[target][source].connector;
696 assert(connector && "invalid edge");
697 LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n");
699 Position *pos = valueToPosition.lookup(connector);
700 assert(pos && "connector has not been traversed yet");
701
702
703 for (Value value = connector; value != target;) {
704 OpIndex opIndex = parentMap.lookup(value);
705 assert(opIndex.parent && "missing parent");
706 visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index());
707 value = opIndex.parent;
708 }
709 }
710
712
713 return bestRoot;
714 }
715
716
717
718
719
720 namespace {
721
722
723
724
725 struct OrderedPredicate {
726 OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
727 : position(ip.first), question(ip.second) {}
729 : position(ip.position), question(ip.question) {}
730
731
733
734
736
737
738
739
740 unsigned primary = 0;
741
742
743
744
745 unsigned secondary = 0;
746
747
748
749
750 unsigned id = 0;
751
752
753
755
756
757
758 bool operator<(const OrderedPredicate &rhs) const {
759
760
761
762
763
764
765 auto *rhsPos = rhs.position;
766 return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
767 rhsPos->getKind(), rhs.question->getKind(), rhs.id) >
768 std::make_tuple(rhs.primary, rhs.secondary,
770 question->getKind(), id);
771 }
772 };
773
774
775
776 struct OrderedPredicateDenseInfo {
778
779 static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
780 static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
781 static bool isEqual(const OrderedPredicate &lhs,
782 const OrderedPredicate &rhs) {
783 return lhs.position == rhs.position && lhs.question == rhs.question;
784 }
785 static unsigned getHashValue(const OrderedPredicate &p) {
786 return llvm::hash_combine(p.position, p.question);
787 }
788 };
789
790
791
792 struct OrderedPredicateList {
793 OrderedPredicateList(pdl::PatternOp pattern, Value root)
794 : pattern(pattern), root(root) {}
795
796 pdl::PatternOp pattern;
799 };
800 }
801
802
803
804
806 return node->getPosition() == predicate->position &&
807 node->getQuestion() == predicate->question;
808 }
809
810
811
813 OrderedPredicate *predicate,
814 pdl::PatternOp pattern) {
816 "expected matcher to equal the given predicate");
817
818 auto it = predicate->patternToAnswer.find(pattern);
819 assert(it != predicate->patternToAnswer.end() &&
820 "expected pattern to exist in predicate");
822 }
823
824
825
826
827
829 OrderedPredicateList &list,
830 std::vector<OrderedPredicate *>::iterator current,
831 std::vector<OrderedPredicate *>::iterator end) {
832 if (current == end) {
833
834 node =
835 std::make_unique(list.pattern, list.root, std::move(node));
836
837
838 } else if (!list.predicates.contains(*current)) {
840
841
842
843 } else if (!node) {
844
845 node = std::make_unique((*current)->position,
846 (*current)->question);
848 getOrCreateChild(cast(&*node), *current, list.pattern),
849 list, std::next(current), end);
850
851
852
855 getOrCreateChild(cast(&*node), *current, list.pattern),
856 list, std::next(current), end);
857
858
859
860 } else {
861 propagatePattern(node->getFailureNode(), list, current, end);
862 }
863 }
864
865
866
868 if (!node)
869 return;
870
871 if (SwitchNode *switchNode = dyn_cast(&*node)) {
873 for (auto &it : children)
875
876
877
878 if (children.size() == 1) {
879 auto *childIt = children.begin();
880 node = std::make_unique(
881 node->getPosition(), node->getQuestion(), childIt->first,
882 std::move(childIt->second), std::move(node->getFailureNode()));
883 }
884 } else if (BoolNode *boolNode = dyn_cast(&*node)) {
886 }
887
889 }
890
891
893 while (*root)
894 root = &(*root)->getFailureNode();
895 *root = std::make_unique();
896 }
897
898
899 template <typename Iterator, typename Compare>
901 while (begin != end) {
902
903
904
906 for (auto i = begin; i != end; ++i) {
907 if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); }))
908 sortBeforeOthers.insert(*i);
909 }
910
911 auto const next = std::stable_partition(begin, end, [&](auto const &a) {
912 return sortBeforeOthers.contains(a);
913 });
914 assert(next != begin && "not a partial ordering");
915 begin = next;
916 }
917 }
918
919
920 static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) {
921 auto *cqa = dyn_cast(a->question);
922 if (!cqa)
923 return false;
924
925 auto positionDependsOnA = [&](Position *p) {
926 auto *cp = dyn_cast(p);
927 return cp && cp->getQuestion() == cqa;
928 };
929
930 if (auto *cqb = dyn_cast(b->question)) {
931
932 return llvm::any_of(cqb->getArgs(), positionDependsOnA);
933 }
934 if (auto *equalTo = dyn_cast(b->question)) {
935 return positionDependsOnA(b->position) ||
936 positionDependsOnA(equalTo->getValue());
937 }
938 return positionDependsOnA(b->position);
939 }
940
941
942
943 std::unique_ptr
944 MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
946
947
948 struct PatternPredicates {
949 PatternPredicates(pdl::PatternOp pattern, Value root,
950 std::vector predicates)
951 : pattern(pattern), root(root), predicates(std::move(predicates)) {}
952
953
954 pdl::PatternOp pattern;
955
956
958
959
960 std::vector predicates;
961 };
962
964 for (pdl::PatternOp pattern : module.getOpspdl::PatternOp()) {
965 std::vector predicateList;
967 buildPredicateList(pattern, builder, predicateList, valueToPosition);
968 patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList));
969 }
970
971
973 for (auto &patternAndPredList : patternsAndPredicates) {
974 for (auto &predicate : patternAndPredList.predicates) {
975 auto it = uniqued.insert(predicate);
976 it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,
977 predicate.answer);
978
979 if (it.second)
980 it.first->id = uniqued.size() - 1;
981 }
982 }
983
984
985 std::vector lists;
986 lists.reserve(patternsAndPredicates.size());
987 for (auto &patternAndPredList : patternsAndPredicates) {
988 OrderedPredicateList list(patternAndPredList.pattern,
989 patternAndPredList.root);
990 for (auto &predicate : patternAndPredList.predicates) {
991 OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
992 list.predicates.insert(orderedPredicate);
993
994
995 ++orderedPredicate->primary;
996 }
997 lists.push_back(std::move(list));
998 }
999
1000
1001
1002
1003 for (auto &list : lists) {
1004 unsigned total = 0;
1005 for (auto *predicate : list.predicates)
1006 total += predicate->primary * predicate->primary;
1007 for (auto *predicate : list.predicates)
1008 predicate->secondary += total;
1009 }
1010
1011
1012
1013 std::vector<OrderedPredicate *> ordered;
1014 ordered.reserve(uniqued.size());
1015 for (auto &ip : uniqued)
1016 ordered.push_back(&ip);
1017 llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) {
1018 return *lhs < *rhs;
1019 });
1020
1021
1022
1024
1025
1026 std::unique_ptr root;
1027 for (OrderedPredicateList &list : lists)
1028 propagatePattern(root, list, ordered.begin(), ordered.end());
1029
1030
1033 return root;
1034 }
1035
1036
1037
1038
1039
1041 std::unique_ptr failureNode)
1042 : position(p), question(q), failureNode(std::move(failureNode)),
1043 matcherTypeID(matcherTypeID) {}
1044
1045
1046
1047
1048
1050 std::unique_ptr successNode,
1051 std::unique_ptr failureNode)
1053 std::move(failureNode)),
1054 answer(answer), successNode(std::move(successNode)) {}
1055
1056
1057
1058
1059
1061 std::unique_ptr failureNode)
1063 nullptr, std::move(failureNode)),
1064 pattern(pattern), root(root) {}
1065
1066
1067
1068
1069
static Value buildPredicateList(pdl::PatternOp pattern, PredicateBuilder &builder, std::vector< PositionalPredicate > &predList, DenseMap< Value, Position * > &valueToPosition)
Given a pattern operation, build the set of matcher predicates necessary to match this pattern.
static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b)
Returns true if 'b' depends on a result of 'a'.
static void getTypePredicates(Value typeValue, function_ref< Attribute()> typeAttrFn, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void getNonTreePredicates(pdl::PatternOp pattern, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
Collect all of the predicates that cannot be determined via walking the tree.
static bool useOperandGroup(pdl::OperationOp op, unsigned index)
Returns true if the operand at the given index needs to be queried using an operand group, i.e., if it is variadic itself or follows a variadic operand.
static void getTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs, Position *pos)
Collect the tree predicates anchored at the given value.
static void getResultPredicates(pdl::ResultOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void getOperandTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs, Position *pos)
Collect all of the predicates for the given operand position.
std::unique_ptr< MatcherNode > & getOrCreateChild(SwitchNode *node, OrderedPredicate *predicate, pdl::PatternOp pattern)
Get or insert a child matcher for the given parent switch node, given a predicate and parent pattern.
static bool comparePosDepth(Position *lhs, Position *rhs)
Compares the depths of two positions.
static void visitUpward(std::vector< PositionalPredicate > &predList, OpIndex opIndex, PredicateBuilder &builder, DenseMap< Value, Position * > &valueToPosition, Position *&pos, unsigned rootID)
Visit a node during upward traversal.
static unsigned getNumNonRangeValues(ValueRange values)
Returns the number of non-range elements within values.
static SmallVector< Value > detectRoots(pdl::PatternOp pattern)
Given a pattern, determines the set of roots present in this pattern.
static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp)
Sorts the range begin/end with the partial order given by cmp.
static void insertExitNode(std::unique_ptr< MatcherNode > *root)
Insert an exit node at the end of the failure path of the root.
static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate)
Returns true if the given matcher refers to the same predicate as the given ordered predicate.
static void foldSwitchToBool(std::unique_ptr< MatcherNode > &node)
Fold any switch nodes nested under node to boolean nodes when possible.
static void buildCostGraph(ArrayRef< Value > roots, RootOrderingGraph &graph, ParentMaps &parentMaps)
Given a list of candidate roots, builds the cost graph for connecting them.
static void getAttributePredicates(pdl::AttributeOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void propagatePattern(std::unique_ptr< MatcherNode > &node, OrderedPredicateList &list, std::vector< OrderedPredicate * >::iterator current, std::vector< OrderedPredicate * >::iterator end)
Build the matcher CFG by "pushing" patterns through by sorted predicate order.
Attributes are known-constant values of operations.
This class implements the operand iterators for the Operation class.
type_range getType() const
Operation is the basic unit of execution within MLIR.
This class implements the result iterators for the Operation class.
type_range getTypes() const
This class provides an efficient unique identifier for a specific C++ type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This class represents the base of a predicate matcher node.
Position * getPosition() const
Returns the position on which the question predicate should be checked.
Qualifier * getQuestion() const
Returns the predicate checked on this node.
The optimal branching algorithm solver.
unsigned solve()
Runs the Edmonds' algorithm for the current graph, returning the total cost of the minimum-weight spanning arborescence (sum of the edge costs).
std::vector< std::pair< Value, Value > > EdgeList
A list of edges (child, parent).
EdgeList preOrderTraversal(ArrayRef< Value > nodes) const
Returns the computed edges as visited in the preorder traversal.
A position describes a value on the input IR on which a predicate may be applied, such as an operation or attribute.
unsigned getOperationDepth() const
Returns the depth of the first ancestor operation position.
Predicates::Kind getKind() const
Returns the kind of this position.
const KeyTy & getValue() const
Return the key value of this predicate.
This class provides utilities for constructing predicates.
ConstraintPosition * getConstraintPosition(ConstraintQuestion *q, unsigned index)
Position * getTypeLiteral(Attribute attr)
Returns a type position for the given type value.
Predicate getOperandCount(unsigned count)
Create a predicate comparing the number of operands of an operation to a known value.
OperationPosition * getPassthroughOp(Position *p)
Returns the operation position equivalent to the given position.
Predicate getIsNotNull()
Create a predicate comparing a value with null.
Predicate getOperandCountAtLeast(unsigned count)
Predicate getResultCountAtLeast(unsigned count)
Position * getType(Position *p)
Returns a type position for the given entity.
Position * getAttribute(OperationPosition *p, StringRef name)
Returns an attribute position for an attribute of the given operation.
Position * getOperandGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of operands of the given operation.
Position * getForEach(Position *p, unsigned id)
Position * getOperand(OperationPosition *p, unsigned operand)
Returns an operand position for an operand of the given operation.
Position * getResult(OperationPosition *p, unsigned result)
Returns a result position for a result of the given operation.
Position * getRoot()
Returns the root operation position.
Predicate getAttributeConstraint(Attribute attr)
Create a predicate comparing an attribute to a known value.
Position * getResultGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of results of the given operation.
Position * getAllResults(OperationPosition *p)
UsersPosition * getUsers(Position *p, bool useRepresentative)
Returns the users of a position using the value at the given operand.
Predicate getTypeConstraint(Attribute type)
Create a predicate comparing the type of an attribute or value to a known type.
OperationPosition * getOperandDefiningOp(Position *p)
Returns the parent position defining the value held by the given operand.
Predicate getResultCount(unsigned count)
Create a predicate comparing the number of results of an operation to a known value.
std::pair< Qualifier *, Qualifier * > Predicate
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Predicate getEqualTo(Position *pos)
Create a predicate checking if two values are equal.
Position * getAllOperands(OperationPosition *p)
Position * getAttributeLiteral(Attribute attr)
Returns an attribute position for the given attribute.
Predicate getConstraint(StringRef name, ArrayRef< Position * > args, ArrayRef< Type > resultTypes, bool isNegated)
Create a predicate that applies a generic constraint.
Predicate getOperationName(StringRef name)
Create a predicate comparing the name of an operation to a known value.
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Predicates::Kind getKind() const
Returns the kind of this qualifier.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool operator<(const Fraction &x, const Fraction &y)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
A position describing an attribute of an operation.
A BoolNode denotes a question with a boolean-like result.
BoolNode(Position *position, Qualifier *question, Qualifier *answer, std::unique_ptr< MatcherNode > successNode, std::unique_ptr< MatcherNode > failureNode=nullptr)
A position describing the result of a native constraint.
Apply a parameterized constraint to multiple position values and possibly produce results.
An operation position describes an operation node in the IR.
bool isRoot() const
Returns if this operation position corresponds to the root.
bool isOperandDefiningOp() const
Returns if this operation represents an operand defining op.
A PositionalPredicate is a predicate that is associated with a specific positional value.
The information associated with an edge in the cost graph.
Value connector
The connector value in the intersection of the two subtrees rooted at the source and target root that results in that cost.
std::pair< unsigned, unsigned > cost
The depth of the connector Value w.r.t. the target root.
A SuccessNode denotes that a given high level pattern has successfully been matched.
SuccessNode(pdl::PatternOp pattern, Value root, std::unique_ptr< MatcherNode > failureNode)
A SwitchNode denotes a question with multiple potential results.
llvm::MapVector< Qualifier *, std::unique_ptr< MatcherNode > > ChildMapT
Returns the children of this switch node.
ChildMapT & getChildren()
SwitchNode(Position *position, Qualifier *question)
A position describing the result type of an entity, i.e. an Attribute, Operand, Result, etc.