MLIR: lib/Conversion/PDLToPDLInterp/PredicateTree.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

11

17 #include "llvm/ADT/MapVector.h"

18 #include "llvm/ADT/SmallPtrSet.h"

19 #include "llvm/ADT/TypeSwitch.h"

20 #include "llvm/Support/Debug.h"

21 #include

22

23 #define DEBUG_TYPE "pdl-predicate-tree"

24

25 using namespace mlir;

27

28

29

30

31

32 static void getTreePredicates(std::vector &predList,

36

37

40 }

41

42

44 return llvm::count_if(values.getTypes(),

45 [](Type type) { return !isapdl::RangeType(type); });

46 }

47

52 assert(isapdl::AttributeType(val.getType()) && "expected attribute type");

53 predList.emplace_back(pos, builder.getIsNotNull());

54

55 if (auto attr = dyn_castpdl::AttributeOp(val.getDefiningOp())) {

56

57 if (Value type = attr.getValueType())

59 else if (Attribute value = attr.getValueAttr())

61 }

62 }

63

64

70 bool isVariadic = isapdl::RangeType(valueType);

71

72

74 .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) {

75

76

77 if (std::is_same<pdl::OperandOp, decltype(op)>::value ||

78 cast(pos)->getOperandGroupNumber())

79 predList.emplace_back(pos, builder.getIsNotNull());

80

81 if (Value type = op.getValueType())

84 })

85 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {

86 std::optional index = op.getIndex();

87

88

89 if (index)

90 predList.emplace_back(pos, builder.getIsNotNull());

91

92

94 predList.emplace_back(parentPos, builder.getIsNotNull());

95

96

97

98 Position *resultPos = nullptr;

99 if (std::is_same<pdl::ResultOp, decltype(op)>::value)

100 resultPos = builder.getResult(parentPos, *index);

101 else

102 resultPos = builder.getResultGroup(parentPos, index, isVariadic);

103 predList.emplace_back(resultPos, builder.getEqualTo(pos));

104

105

108 });

109 }

110

111 static void

115 std::optional ignoreOperand = std::nullopt) {

116 assert(isapdl::OperationType(val.getType()) && "expected operation");

117 pdl::OperationOp op = castpdl::OperationOp(val.getDefiningOp());

119

120

121 if (!opPos->isRoot())

122 predList.emplace_back(pos, builder.getIsNotNull());

123

124

125 if (std::optional opName = op.getOpName())

126 predList.emplace_back(pos, builder.getOperationName(*opName));

127

128

129

130 OperandRange operands = op.getOperandValues();

132 if (minOperands != operands.size()) {

133 if (minOperands)

135 } else {

136 predList.emplace_back(pos, builder.getOperandCount(minOperands));

137 }

138

139

140

143 if (minResults == types.size())

144 predList.emplace_back(pos, builder.getResultCount(types.size()));

145 else if (minResults)

147

148

149 for (auto [attrName, attr] :

150 llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) {

152 predList, attr, builder, inputs,

154 }

155

156

157

158

159

160

161

162 if (operands.size() == 1 && isapdl::RangeType(operands[0].getType())) {

163

164

168 } else {

169 bool foundVariableLength = false;

170 for (const auto &operandIt : llvm::enumerate(operands)) {

171 bool isVariadic = isapdl::RangeType(operandIt.value().getType());

172 foundVariableLength |= isVariadic;

173

174

175

176 if (ignoreOperand == operandIt.index())

177 continue;

178

180 foundVariableLength

181 ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic)

182 : builder.getOperand(opPos, operandIt.index());

183 getTreePredicates(predList, operandIt.value(), builder, inputs, pos);

184 }

185 }

186

187 if (types.size() == 1 && isapdl::RangeType(types[0].getType())) {

190 return;

191 }

192

193 bool foundVariableLength = false;

195 bool isVariadic = isapdl::RangeType(typeValue.getType());

196 foundVariableLength |= isVariadic;

197

198 auto *resultPos = foundVariableLength

201 predList.emplace_back(resultPos, builder.getIsNotNull());

203 builder.getType(resultPos));

204 }

205 }

206

211

212 if (pdl::TypeOp typeOp = val.getDefiningOppdl::TypeOp()) {

213 if (Attribute type = typeOp.getConstantTypeAttr())

215 } else if (pdl::TypesOp typeOp = val.getDefiningOppdl::TypesOp()) {

216 if (Attribute typeAttr = typeOp.getConstantTypesAttr())

218 }

219 }

220

221

226

227 auto it = inputs.try_emplace(val, pos);

228 if (!it.second) {

229

230

231 if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,

233 auto minMaxPositions =

235 predList.emplace_back(minMaxPositions.second,

236 builder.getEqualTo(minMaxPositions.first));

237 }

238 return;

239 }

240

244 })

245 .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) {

247 })

248 .Default([](auto *) { llvm_unreachable("unexpected position kind"); });

249 }

250

252 std::vector &predList,

255 Position *&attrPos = inputs[op];

256 if (attrPos)

257 return;

258 Attribute value = op.getValueAttr();

259 assert(value && "expected non-tree `pdl.attribute` to contain a value");

261 }

262

264 std::vector &predList,

268

269 std::vector<Position *> allPositions;

270 allPositions.reserve(arguments.size());

271 for (Value arg : arguments)

272 allPositions.push_back(inputs.lookup(arg));

273

274

279 op.getIsNegated());

280

281

285 auto [it, inserted] = inputs.try_emplace(result, pos);

286

287

288 if (!inserted) {

290 Position *second = it->second;

292 std::tie(second, first) = std::make_pair(first, second);

293

294 predList.emplace_back(second, builder.getEqualTo(first));

295 }

296 }

297 predList.emplace_back(pos, pred);

298 }

299

301 std::vector &predList,

304 Position *&resultPos = inputs[op];

305 if (resultPos)

306 return;

307

308

309 auto *parentPos = cast(inputs.lookup(op.getParent()));

310 resultPos = builder.getResult(parentPos, op.getIndex());

311 predList.emplace_back(resultPos, builder.getIsNotNull());

312 }

313

315 std::vector &predList,

318 Position *&resultPos = inputs[op];

319 if (resultPos)

320 return;

321

322

323 auto *parentPos = cast(inputs.lookup(op.getParent()));

324 bool isVariadic = isapdl::RangeType(op.getType());

325 std::optional index = op.getIndex();

326 resultPos = builder.getResultGroup(parentPos, index, isVariadic);

327 if (index)

328 predList.emplace_back(resultPos, builder.getIsNotNull());

329 }

330

335 Position *&typePos = inputs[typeValue];

336 if (typePos)

337 return;

338 Attribute typeAttr = typeAttrFn();

339 assert(typeAttr &&

340 "expected non-tree `pdl.type`/`pdl.types` to contain a value");

342 }

343

344

345

347 std::vector &predList,

350 for (Operation &op : pattern.getBodyRegion().getOps()) {

352 .Case([&](pdl::AttributeOp attrOp) {

354 })

355 .Casepdl::ApplyNativeConstraintOp([&](auto constraintOp) {

357 })

358 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {

360 })

361 .Case([&](pdl::TypeOp typeOp) {

363 typeOp, [&] { return typeOp.getConstantTypeAttr(); }, builder,

364 inputs);

365 })

366 .Case([&](pdl::TypesOp typeOp) {

368 typeOp, [&] { return typeOp.getConstantTypesAttr(); }, builder,

369 inputs);

370 });

371 }

372 }

373

374 namespace {

375

376

377 struct OpIndex {

379 std::optional index;

380 };

381

382

383

385

386 }

387

388

389

391

392

394 for (auto operationOp : pattern.getBodyRegion().getOpspdl::OperationOp()) {

395 for (Value operand : operationOp.getOperandValues())

397 .Case<pdl::ResultOp, pdl::ResultsOp>(

398 [&used](auto resultOp) { used.insert(resultOp.getParent()); });

399 }

400

401

402

403 if (Value root = pattern.getRewriter().getRoot())

404 used.erase(root);

405

406

408 for (Value operationOp : pattern.getBodyRegion().getOpspdl::OperationOp())

409 if (!used.contains(operationOp))

410 roots.push_back(operationOp);

411

412 return roots;

413 }

414

415

416

417

418

419

420

421

423 ParentMaps &parentMaps) {

424

425

426

427

428

429

430 struct Entry {

431 Entry(Value value, Value parent, std::optional index,

432 unsigned depth)

433 : value(value), parent(parent), index(index), depth(depth) {}

434

437 std::optional index;

438 unsigned depth;

439 };

440

441

442 struct RootDepth {

444 unsigned depth = 0;

445 };

446

447

448

449 llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths;

450

451

452 for (Value root : roots) {

453

454

455

456 std::queue toVisit;

457 toVisit.emplace(root, Value(), 0, 0);

458

459

461

462 while (!toVisit.empty()) {

463 Entry entry = toVisit.front();

464 toVisit.pop();

465

466 if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second)

467 continue;

468

469

470 connectorsRootsDepths[entry.value].push_back({root, entry.depth});

471

472

473

474

476 .Casepdl::OperationOp([&](auto operationOp) {

477 OperandRange operands = operationOp.getOperandValues();

478

479

480 if (operands.size() == 1 &&

481 isapdl::RangeType(operands[0].getType())) {

482 toVisit.emplace(operands[0], entry.value, std::nullopt,

483 entry.depth + 1);

484 return;

485 }

486

487

488 for (const auto &p :

490 toVisit.emplace(p.value(), entry.value, p.index(),

491 entry.depth + 1);

492 })

493 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {

494 toVisit.emplace(resultOp.getParent(), entry.value,

495 resultOp.getIndex(), entry.depth);

496 });

497 }

498 }

499

500

501

502 unsigned nextID = 0;

503 for (const auto &connectorRootsDepths : connectorsRootsDepths) {

504 Value value = connectorRootsDepths.first;

506

507

508 if (rootsDepths.size() == 1)

509 continue;

510

511 for (const RootDepth &p : rootsDepths) {

512 for (const RootDepth &q : rootsDepths) {

513 if (&p == &q)

514 continue;

515

517 if (!entry.connector || entry.cost.first > q.depth) {

519 entry.cost.second = nextID++;

520 entry.cost.first = q.depth;

522 }

523 }

524 }

525 }

526

527 assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) &&

528 "the pattern contains a candidate root disconnected from the others");

529 }

530

531

532

534 OperandRange operands = op.getOperandValues();

535 assert(index < operands.size() && "operand index out of range");

536 for (unsigned i = 0; i <= index; ++i)

537 if (isapdl::RangeType(operands[i].getType()))

538 return true;

539 return false;

540 }

541

542

543 static void visitUpward(std::vector &predList,

546 Position *&pos, unsigned rootID) {

547 Value value = opIndex.parent;

549 .Casepdl::OperationOp([&](auto operationOp) {

550 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");

551

552

553 Position *usersPos = builder.getUsers(pos, true);

556

557

559 if (!opIndex.index) {

560

562 } else if (useOperandGroup(operationOp, *opIndex.index)) {

563

564 Type type = operationOp.getOperandValues()[*opIndex.index].getType();

565 bool variadic = isapdl::RangeType(type);

566 operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);

567 } else {

568

569 operandPos = builder.getOperand(opPos, *opIndex.index);

570 }

571 predList.emplace_back(operandPos, builder.getEqualTo(pos));

572

573

574

575

576

577 bool inserted = valueToPosition.try_emplace(value, opPos).second;

578 (void)inserted;

579 assert(inserted && "duplicate upward visit");

580

581

582 getTreePredicates(predList, value, builder, valueToPosition, opPos,

583 opIndex.index);

584

585

586 pos = opPos;

587 })

588 .Casepdl::ResultOp([&](auto resultOp) {

589

590 auto *opPos = dyn_cast(pos);

591 assert(opPos && "operations and results must be interleaved");

592 pos = builder.getResult(opPos, *opIndex.index);

593

594

595 valueToPosition.try_emplace(value, pos);

596 })

597 .Casepdl::ResultsOp([&](auto resultOp) {

598

599 auto *opPos = dyn_cast(pos);

600 assert(opPos && "operations and results must be interleaved");

601 bool isVariadic = isapdl::RangeType(value.getType());

602 if (opIndex.index)

603 pos = builder.getResultGroup(opPos, opIndex.index, isVariadic);

604 else

606

607

608 valueToPosition.try_emplace(value, pos);

609 });

610 }

611

612

613

616 std::vector &predList,

619

620

622 ParentMaps parentMaps;

624 LLVM_DEBUG({

625 llvm::dbgs() << "Graph:\n";

626 for (auto &target : graph) {

627 llvm::dbgs() << " * " << target.first.getLoc() << " " << target.first

628 << "\n";

629 for (auto &source : target.second) {

631 llvm::dbgs() << " <- " << source.first << ": " << entry.cost.first

632 << ":" << entry.cost.second << " via "

634 }

635 }

636 });

637

638

639

640 Value bestRoot = pattern.getRewriter().getRoot();

642 if (!bestRoot) {

643 unsigned bestCost = 0;

644 LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n");

645 for (Value root : roots) {

647 unsigned cost = solver.solve();

648 LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n");

649 if (!bestRoot || bestCost > cost) {

650 bestCost = cost;

651 bestRoot = root;

653 }

654 }

655 } else {

659 }

660

661

662 LLVM_DEBUG({

663 llvm::dbgs() << "Best tree:\n";

664 for (const std::pair<Value, Value> &edge : bestEdges) {

665 llvm::dbgs() << " * " << edge.first;

666 if (edge.second)

667 llvm::dbgs() << " <- " << edge.second;

668 llvm::dbgs() << "\n";

669 }

670 });

671

672 LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n");

673 LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n");

674

675

676

679

680

681

682

684 Value target = it.value().first;

685 Value source = it.value().second;

686

687

688

689

690

691 if (valueToPosition.count(target))

692 continue;

693

694

695 Value connector = graph[target][source].connector;

696 assert(connector && "invalid edge");

697 LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n");

699 Position *pos = valueToPosition.lookup(connector);

700 assert(pos && "connector has not been traversed yet");

701

702

703 for (Value value = connector; value != target;) {

704 OpIndex opIndex = parentMap.lookup(value);

705 assert(opIndex.parent && "missing parent");

706 visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index());

707 value = opIndex.parent;

708 }

709 }

710

712

713 return bestRoot;

714 }

715

716

717

718

719

720 namespace {

721

722

723

724

725 struct OrderedPredicate {

726 OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)

727 : position(ip.first), question(ip.second) {}

729 : position(ip.position), question(ip.question) {}

730

731

733

734

736

737

738

739

740 unsigned primary = 0;

741

742

743

744

745 unsigned secondary = 0;

746

747

748

749

750 unsigned id = 0;

751

752

753

755

756

757

758 bool operator<(const OrderedPredicate &rhs) const {

759

760

761

762

763

764

765 auto *rhsPos = rhs.position;

766 return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),

767 rhsPos->getKind(), rhs.question->getKind(), rhs.id) >

768 std::make_tuple(rhs.primary, rhs.secondary,

770 question->getKind(), id);

771 }

772 };

773

774

775

776 struct OrderedPredicateDenseInfo {

778

779 static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }

780 static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }

781 static bool isEqual(const OrderedPredicate &lhs,

782 const OrderedPredicate &rhs) {

783 return lhs.position == rhs.position && lhs.question == rhs.question;

784 }

785 static unsigned getHashValue(const OrderedPredicate &p) {

786 return llvm::hash_combine(p.position, p.question);

787 }

788 };

789

790

791

792 struct OrderedPredicateList {

793 OrderedPredicateList(pdl::PatternOp pattern, Value root)

794 : pattern(pattern), root(root) {}

795

796 pdl::PatternOp pattern;

799 };

800 }

801

802

803

804

806 return node->getPosition() == predicate->position &&

807 node->getQuestion() == predicate->question;

808 }

809

810

811

813 OrderedPredicate *predicate,

814 pdl::PatternOp pattern) {

816 "expected matcher to equal the given predicate");

817

818 auto it = predicate->patternToAnswer.find(pattern);

819 assert(it != predicate->patternToAnswer.end() &&

820 "expected pattern to exist in predicate");

822 }

823

824

825

826

827

829 OrderedPredicateList &list,

830 std::vector<OrderedPredicate *>::iterator current,

831 std::vector<OrderedPredicate *>::iterator end) {

832 if (current == end) {

833

834 node =

835 std::make_unique(list.pattern, list.root, std::move(node));

836

837

838 } else if (!list.predicates.contains(*current)) {

840

841

842

843 } else if (!node) {

844

845 node = std::make_unique((*current)->position,

846 (*current)->question);

848 getOrCreateChild(cast(&*node), *current, list.pattern),

849 list, std::next(current), end);

850

851

852

855 getOrCreateChild(cast(&*node), *current, list.pattern),

856 list, std::next(current), end);

857

858

859

860 } else {

861 propagatePattern(node->getFailureNode(), list, current, end);

862 }

863 }

864

865

866

868 if (!node)

869 return;

870

871 if (SwitchNode *switchNode = dyn_cast(&*node)) {

873 for (auto &it : children)

875

876

877

878 if (children.size() == 1) {

879 auto *childIt = children.begin();

880 node = std::make_unique(

881 node->getPosition(), node->getQuestion(), childIt->first,

882 std::move(childIt->second), std::move(node->getFailureNode()));

883 }

884 } else if (BoolNode *boolNode = dyn_cast(&*node)) {

886 }

887

889 }

890

891

893 while (*root)

894 root = &(*root)->getFailureNode();

895 *root = std::make_unique();

896 }

897

898

899 template <typename Iterator, typename Compare>

901 while (begin != end) {

902

903

904

906 for (auto i = begin; i != end; ++i) {

907 if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); }))

908 sortBeforeOthers.insert(*i);

909 }

910

911 auto const next = std::stable_partition(begin, end, [&](auto const &a) {

912 return sortBeforeOthers.contains(a);

913 });

914 assert(next != begin && "not a partial ordering");

915 begin = next;

916 }

917 }

918

919

920 static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) {

921 auto *cqa = dyn_cast(a->question);

922 if (!cqa)

923 return false;

924

925 auto positionDependsOnA = [&](Position *p) {

926 auto *cp = dyn_cast(p);

927 return cp && cp->getQuestion() == cqa;

928 };

929

930 if (auto *cqb = dyn_cast(b->question)) {

931

932 return llvm::any_of(cqb->getArgs(), positionDependsOnA);

933 }

934 if (auto *equalTo = dyn_cast(b->question)) {

935 return positionDependsOnA(b->position) ||

936 positionDependsOnA(equalTo->getValue());

937 }

938 return positionDependsOnA(b->position);

939 }

940

941

942

943 std::unique_ptr

944 MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,

946

947

948 struct PatternPredicates {

949 PatternPredicates(pdl::PatternOp pattern, Value root,

950 std::vector predicates)

951 : pattern(pattern), root(root), predicates(std::move(predicates)) {}

952

953

954 pdl::PatternOp pattern;

955

956

958

959

960 std::vector predicates;

961 };

962

964 for (pdl::PatternOp pattern : module.getOpspdl::PatternOp()) {

965 std::vector predicateList;

967 buildPredicateList(pattern, builder, predicateList, valueToPosition);

968 patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList));

969 }

970

971

973 for (auto &patternAndPredList : patternsAndPredicates) {

974 for (auto &predicate : patternAndPredList.predicates) {

975 auto it = uniqued.insert(predicate);

976 it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,

977 predicate.answer);

978

979 if (it.second)

980 it.first->id = uniqued.size() - 1;

981 }

982 }

983

984

985 std::vector lists;

986 lists.reserve(patternsAndPredicates.size());

987 for (auto &patternAndPredList : patternsAndPredicates) {

988 OrderedPredicateList list(patternAndPredList.pattern,

989 patternAndPredList.root);

990 for (auto &predicate : patternAndPredList.predicates) {

991 OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);

992 list.predicates.insert(orderedPredicate);

993

994

995 ++orderedPredicate->primary;

996 }

997 lists.push_back(std::move(list));

998 }

999

1000

1001

1002

1003 for (auto &list : lists) {

1004 unsigned total = 0;

1005 for (auto *predicate : list.predicates)

1006 total += predicate->primary * predicate->primary;

1007 for (auto *predicate : list.predicates)

1008 predicate->secondary += total;

1009 }

1010

1011

1012

1013 std::vector<OrderedPredicate *> ordered;

1014 ordered.reserve(uniqued.size());

1015 for (auto &ip : uniqued)

1016 ordered.push_back(&ip);

1017 llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) {

1018 return *lhs < *rhs;

1019 });

1020

1021

1022

1024

1025

1026 std::unique_ptr root;

1027 for (OrderedPredicateList &list : lists)

1028 propagatePattern(root, list, ordered.begin(), ordered.end());

1029

1030

1033 return root;

1034 }

1035

1036

1037

1038

1039

1041 std::unique_ptr failureNode)

1042 : position(p), question(q), failureNode(std::move(failureNode)),

1043 matcherTypeID(matcherTypeID) {}

1044

1045

1046

1047

1048

1050 std::unique_ptr successNode,

1051 std::unique_ptr failureNode)

1053 std::move(failureNode)),

1054 answer(answer), successNode(std::move(successNode)) {}

1055

1056

1057

1058

1059

1061 std::unique_ptr failureNode)

1063 nullptr, std::move(failureNode)),

1064 pattern(pattern), root(root) {}

1065

1066

1067

1068

1069

static Value buildPredicateList(pdl::PatternOp pattern, PredicateBuilder &builder, std::vector< PositionalPredicate > &predList, DenseMap< Value, Position * > &valueToPosition)

Given a pattern operation, build the set of matcher predicates necessary to match this pattern.

static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)

static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b)

Returns true if 'b' depends on a result of 'a'.

static void getTypePredicates(Value typeValue, function_ref< Attribute()> typeAttrFn, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)

static void getNonTreePredicates(pdl::PatternOp pattern, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)

Collect all of the predicates that cannot be determined via walking the tree.

static bool useOperandGroup(pdl::OperationOp op, unsigned index)

Returns true if the operand at the given index needs to be queried using an operand group, i.e., if it is variadic itself or follows a variadic operand.

static void getTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs, Position *pos)

Collect the tree predicates anchored at the given value.

static void getResultPredicates(pdl::ResultOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)

static void getOperandTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs, Position *pos)

Collect all of the predicates for the given operand position.

std::unique_ptr< MatcherNode > & getOrCreateChild(SwitchNode *node, OrderedPredicate *predicate, pdl::PatternOp pattern)

Get or insert a child matcher for the given parent switch node, given a predicate and parent pattern.

static bool comparePosDepth(Position *lhs, Position *rhs)

Compares the depths of two positions.

static void visitUpward(std::vector< PositionalPredicate > &predList, OpIndex opIndex, PredicateBuilder &builder, DenseMap< Value, Position * > &valueToPosition, Position *&pos, unsigned rootID)

Visit a node during upward traversal.

static unsigned getNumNonRangeValues(ValueRange values)

Returns the number of non-range elements within values.

static SmallVector< Value > detectRoots(pdl::PatternOp pattern)

Given a pattern, determines the set of roots present in this pattern.

static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp)

Sorts the range begin/end with the partial order given by cmp.

static void insertExitNode(std::unique_ptr< MatcherNode > *root)

Insert an exit node at the end of the failure path of the root.

static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate)

Returns true if the given matcher refers to the same predicate as the given ordered predicate.

static void foldSwitchToBool(std::unique_ptr< MatcherNode > &node)

Fold any switch nodes nested under node to boolean nodes when possible.

static void buildCostGraph(ArrayRef< Value > roots, RootOrderingGraph &graph, ParentMaps &parentMaps)

Given a list of candidate roots, builds the cost graph for connecting them.

static void getAttributePredicates(pdl::AttributeOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)

static void propagatePattern(std::unique_ptr< MatcherNode > &node, OrderedPredicateList &list, std::vector< OrderedPredicate * >::iterator current, std::vector< OrderedPredicate * >::iterator end)

Build the matcher CFG by "pushing" patterns through by sorted predicate order.

Attributes are known-constant values of operations.

This class implements the operand iterators for the Operation class.

type_range getType() const

Operation is the basic unit of execution within MLIR.

This class implements the result iterators for the Operation class.

type_range getTypes() const

This class provides an efficient unique identifier for a specific C++ type.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

This class provides an abstraction over the different types of ranges over Values.

type_range getTypes() const

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

Type getType() const

Return the type of this value.

Location getLoc() const

Return the location of this value.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

This class represents the base of a predicate matcher node.

Position * getPosition() const

Returns the position on which the question predicate should be checked.

Qualifier * getQuestion() const

Returns the predicate checked on this node.

The optimal branching algorithm solver.

unsigned solve()

Runs the Edmonds' algorithm for the current graph, returning the total cost of the minimum-weight spanning arborescence.

std::vector< std::pair< Value, Value > > EdgeList

A list of edges (child, parent).

EdgeList preOrderTraversal(ArrayRef< Value > nodes) const

Returns the computed edges as visited in the preorder traversal.

A position describes a value on the input IR on which a predicate may be applied, such as an operation or attribute.

unsigned getOperationDepth() const

Returns the depth of the first ancestor operation position.

Predicates::Kind getKind() const

Returns the kind of this position.

const KeyTy & getValue() const

Return the key value of this predicate.

This class provides utilities for constructing predicates.

ConstraintPosition * getConstraintPosition(ConstraintQuestion *q, unsigned index)

Position * getTypeLiteral(Attribute attr)

Returns a type position for the given type value.

Predicate getOperandCount(unsigned count)

Create a predicate comparing the number of operands of an operation to a known value.

OperationPosition * getPassthroughOp(Position *p)

Returns the operation position equivalent to the given position.

Predicate getIsNotNull()

Create a predicate comparing a value with null.

Predicate getOperandCountAtLeast(unsigned count)

Predicate getResultCountAtLeast(unsigned count)

Position * getType(Position *p)

Returns a type position for the given entity.

Position * getAttribute(OperationPosition *p, StringRef name)

Returns an attribute position for an attribute of the given operation.

Position * getOperandGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)

Returns a position for a group of operands of the given operation.

Position * getForEach(Position *p, unsigned id)

Position * getOperand(OperationPosition *p, unsigned operand)

Returns an operand position for an operand of the given operation.

Position * getResult(OperationPosition *p, unsigned result)

Returns a result position for a result of the given operation.

Position * getRoot()

Returns the root operation position.

Predicate getAttributeConstraint(Attribute attr)

Create a predicate comparing an attribute to a known value.

Position * getResultGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)

Returns a position for a group of results of the given operation.

Position * getAllResults(OperationPosition *p)

UsersPosition * getUsers(Position *p, bool useRepresentative)

Returns the users of a position using the value at the given operand.

Predicate getTypeConstraint(Attribute type)

Create a predicate comparing the type of an attribute or value to a known type.

OperationPosition * getOperandDefiningOp(Position *p)

Returns the parent position defining the value held by the given operand.

Predicate getResultCount(unsigned count)

Create a predicate comparing the number of results of an operation to a known value.

std::pair< Qualifier *, Qualifier * > Predicate

An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to ordinal values).

Predicate getEqualTo(Position *pos)

Create a predicate checking if two values are equal.

Position * getAllOperands(OperationPosition *p)

Position * getAttributeLiteral(Attribute attr)

Returns an attribute position for the given attribute.

Predicate getConstraint(StringRef name, ArrayRef< Position * > args, ArrayRef< Type > resultTypes, bool isNegated)

Create a predicate that applies a generic constraint.

Predicate getOperationName(StringRef name)

Create a predicate comparing the name of an operation to a known value.

An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to ordinal values).

Predicates::Kind getKind() const

Returns the kind of this qualifier.

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

bool operator<(const Fraction &x, const Fraction &y)

Include the generated interface declarations.

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed; this helps unify some of the attribute construction code.

A position describing an attribute of an operation.

A BoolNode denotes a question with a boolean-like result.

BoolNode(Position *position, Qualifier *question, Qualifier *answer, std::unique_ptr< MatcherNode > successNode, std::unique_ptr< MatcherNode > failureNode=nullptr)

A position describing the result of a native constraint.

Apply a parameterized constraint to multiple position values and possibly produce results.

An operation position describes an operation node in the IR.

bool isRoot() const

Returns if this operation position corresponds to the root.

bool isOperandDefiningOp() const

Returns if this operation represents an operand defining op.

A PositionalPredicate is a predicate that is associated with a specific positional value.

The information associated with an edge in the cost graph.

Value connector

The connector value in the intersection of the two subtrees rooted at the source and target root that results in the least cost to connect the two subtrees.

std::pair< unsigned, unsigned > cost

The depth of the connector Value w.r.t. the target root, paired with a priority used for breaking ties.

A SuccessNode denotes that a given high level pattern has successfully been matched.

SuccessNode(pdl::PatternOp pattern, Value root, std::unique_ptr< MatcherNode > failureNode)

A SwitchNode denotes a question with multiple potential results.

llvm::MapVector< Qualifier *, std::unique_ptr< MatcherNode > > ChildMapT

Returns the children of this switch node.

ChildMapT & getChildren()

SwitchNode(Position *position, Qualifier *question)

A position describing the result type of an entity, i.e. an Attribute, Operand, Result, etc.