MLIR: lib/Dialect/SCF/IR/SCF.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
26 #include "llvm/ADT/MapVector.h"
27 #include "llvm/ADT/SmallPtrSet.h"
28 #include "llvm/ADT/TypeSwitch.h"
29
30 using namespace mlir;
32
33 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
34
35
36
37
38
39 namespace {
42
43
45 IRMapping &valueMapping) const final {
46 return true;
47 }
48
49
51 return true;
52 }
53
54
55 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
56 auto retValOp = dyn_castscf::YieldOp(op);
57 if (!retValOp)
58 return;
59
60 for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
61 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
62 }
63 }
64 };
65 }
66
67
68
69
70
71 void SCFDialect::initialize() {
72 addOperations<
73 #define GET_OP_LIST
74 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
75 >();
76 addInterfaces();
77 declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
78 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
79 InParallelOp, ReduceReturnOp>();
80 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
81 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
82 ForallOp, InParallelOp, WhileOp, YieldOp>();
83 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
84 }
85
86
88 builder.createscf::YieldOp(loc);
89 }
90
91
92
93 template
95 StringRef errorMessage) {
96 Operation *terminatorOperation = nullptr;
98 terminatorOperation = ®ion.front().back();
99 if (auto yield = dyn_cast_or_null(terminatorOperation))
100 return yield;
101 }
103 if (terminatorOperation)
104 diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
105 return nullptr;
106 }
107
108
109
110
111
112
113
116 assert(llvm::hasSingleElement(region) && "expected single-region block");
122 rewriter.eraseOp(terminator);
123 }
124
125
126
127
128
129
130
131
132
133
134
135
139 return failure();
140
141
143 if (parser.parseRegion(*body, {}, {}) ||
145 return failure();
146
147 return success();
148 }
149
152
153 p << ' ';
155 false,
156 true);
157
159 }
160
162 if (getRegion().empty())
163 return emitOpError("region needs to have at least one block");
164 if (getRegion().front().getNumArguments() > 0)
165 return emitOpError("region cannot have any arguments");
166 return success();
167 }
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
185
188 if (!llvm::hasSingleElement(op.getRegion()))
189 return failure();
191 return success();
192 }
193 };
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
234
237 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
238 return failure();
239
240 Block *prevBlock = op->getBlock();
241 Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
243
244 rewriter.createcf::BranchOp(op.getLoc(), &op.getRegion().front());
245
246 for (Block &blk : op.getRegion()) {
247 if (YieldOp yieldOp = dyn_cast(blk.getTerminator())) {
249 rewriter.createcf::BranchOp(yieldOp.getLoc(), postBlock,
250 yieldOp.getResults());
251 rewriter.eraseOp(yieldOp);
252 }
253 }
254
257
258 for (auto res : op.getResults())
259 blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
260
261 rewriter.replaceOp(op, blockArgs);
262 return success();
263 }
264 };
265
266 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
269 }
270
271 void ExecuteRegionOp::getSuccessorRegions(
273
276 return;
277 }
278
279
281 }
282
283
284
285
286
289 assert((point.isParent() || point == getParentOp().getAfter()) &&
290 "condition op can only exit the loop or branch to the after"
291 "region");
292
293 return getArgsMutable();
294 }
295
296 void ConditionOp::getSuccessorRegions(
298 FoldAdaptor adaptor(operands, *this);
299
300 WhileOp whileOp = getParentOp();
301
302
303
304 auto boolAttr = dyn_cast_or_null(adaptor.getCondition());
305 if (!boolAttr || boolAttr.getValue())
306 regions.emplace_back(&whileOp.getAfter(),
307 whileOp.getAfter().getArguments());
308 if (!boolAttr || !boolAttr.getValue())
309 regions.emplace_back(whileOp.getResults());
310 }
311
312
313
314
315
318 BodyBuilderFn bodyBuilder) {
320
323 for (Value v : initArgs)
324 result.addTypes(v.getType());
329 for (Value v : initArgs)
330 bodyBlock->addArgument(v.getType(), v.getLoc());
331
332
333
334
335 if (initArgs.empty() && !bodyBuilder) {
336 ForOp::ensureTerminator(*bodyRegion, builder, result.location);
337 } else if (bodyBuilder) {
342 }
343 }
344
346
347 if (getInitArgs().size() != getNumResults())
348 return emitOpError(
349 "mismatch in number of loop-carried values and defined values");
350
351 return success();
352 }
353
354 LogicalResult ForOp::verifyRegions() {
355
356
358 return emitOpError(
359 "expected induction variable to be same type as bounds and step");
360
361 if (getNumRegionIterArgs() != getNumResults())
362 return emitOpError(
363 "mismatch in number of basic block args and defined values");
364
365 auto initArgs = getInitArgs();
366 auto iterArgs = getRegionIterArgs();
367 auto opResults = getResults();
368 unsigned i = 0;
369 for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
370 if (std::get<0>(e).getType() != std::get<2>(e).getType())
371 return emitOpError() << "types mismatch between " << i
372 << "th iter operand and defined value";
373 if (std::get<1>(e).getType() != std::get<2>(e).getType())
374 return emitOpError() << "types mismatch between " << i
375 << "th iter region arg and defined value";
376
377 ++i;
378 }
379 return success();
380 }
381
382 std::optional<SmallVector> ForOp::getLoopInductionVars() {
384 }
385
386 std::optional<SmallVector> ForOp::getLoopLowerBounds() {
388 }
389
390 std::optional<SmallVector> ForOp::getLoopSteps() {
392 }
393
394 std::optional<SmallVector> ForOp::getLoopUpperBounds() {
396 }
397
398 std::optional ForOp::getLoopResults() { return getResults(); }
399
400
401
403 std::optional<int64_t> tripCount =
405 if (!tripCount.has_value() || tripCount != 1)
406 return failure();
407
408
409 auto yieldOp = castscf::YieldOp(getBody()->getTerminator());
411
412
413
416 llvm::append_range(bbArgReplacements, getInitArgs());
417
418
419 rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
420 getOperation()->getIterator(), bbArgReplacements);
421
422
423 rewriter.eraseOp(yieldOp);
425
426 return success();
427 }
428
429
430
431
432
436 StringRef prefix = "") {
437 assert(blocksArgs.size() == initializers.size() &&
438 "expected same length of arguments and initializers");
439 if (initializers.empty())
440 return;
441
442 p << prefix << '(';
443 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
444 p << std::get<0>(it) << " = " << std::get<1>(it);
445 });
446 p << ")";
447 }
448
450 p << " " << getInductionVar() << " = " << getLowerBound() << " to "
452
454 if (!getInitArgs().empty())
455 p << " -> (" << getInitArgs().getTypes() << ')';
456 p << ' ';
458 p << " : " << t << ' ';
460 false,
461 !getInitArgs().empty());
463 }
464
468
471
472
474
478 return failure();
479
480
483 regionArgs.push_back(inductionVariable);
484
486 if (hasIterArgs) {
487
490 return failure();
491 }
492
493 if (regionArgs.size() != result.types.size() + 1)
496 "mismatch in number of loop-carried values and defined values");
497
498
502 return failure();
503
504
505 regionArgs.front().type = type;
506 for (auto [iterArg, type] :
507 llvm::zip_equal(llvm::drop_begin(regionArgs), result.types))
508 iterArg.type = type;
509
510
512 if (parser.parseRegion(*body, regionArgs))
513 return failure();
514 ForOp::ensureTerminator(*body, builder, result.location);
515
516
517
521 return failure();
522 if (hasIterArgs) {
523 for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
524 operands, result.types)) {
525 Type type = std::get<2>(argOperandType);
526 std::get<0>(argOperandType).type = type;
527 if (parser.resolveOperand(std::get<1>(argOperandType), type,
529 return failure();
530 }
531 }
532
533
535 return failure();
536
537 return success();
538 }
539
541
543 return getBody()->getArguments().drop_front(getNumInductionVars());
544 }
545
547 return getInitArgsMutable();
548 }
549
550 FailureOr
551 ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
553 bool replaceInitOperandUsesInLoop,
555
558 auto inits = llvm::to_vector(getInitArgs());
559 inits.append(newInitOperands.begin(), newInitOperands.end());
560 scf::ForOp newLoop = rewriter.createscf::ForOp(
564
565
566 auto yieldOp = castscf::YieldOp(getBody()->getTerminator());
568 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
569 {
573 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
574 assert(newInitOperands.size() == newYieldedValues.size() &&
575 "expected as many new yield values as new iter operands");
577 yieldOp.getResultsMutable().append(newYieldedValues);
578 });
579 }
580
581
582 rewriter.mergeBlocks(getBody(), newLoop.getBody(),
583 newLoop.getBody()->getArguments().take_front(
584 getBody()->getNumArguments()));
585
586 if (replaceInitOperandUsesInLoop) {
587
588
589 for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
594 });
595 }
596 }
597
598
599 rewriter.replaceOp(getOperation(),
600 newLoop->getResults().take_front(getNumResults()));
601 return cast(newLoop.getOperation());
602 }
603
605 auto ivArg = llvm::dyn_cast(val);
606 if (!ivArg)
607 return ForOp();
608 assert(ivArg.getOwner() && "unlinked block argument");
609 auto *containingOp = ivArg.getOwner()->getParentOp();
610 return dyn_cast_or_null(containingOp);
611 }
612
614 return getInitArgs();
615 }
616
619
620
621
622 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
624 }
625
627
628
629
631 for (auto [lb, ub, step] :
632 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
634 if (!tripCount.has_value() || *tripCount != 1)
635 return failure();
636 }
637
638 promote(rewriter, *this);
639 return success();
640 }
641
643 return getBody()->getArguments().drop_front(getRank());
644 }
645
647 return getOutputsMutable();
648 }
649
650
653 scf::InParallelOp terminator = forallOp.getTerminator();
654
655
656
657 SmallVector bbArgReplacements = forallOp.getLowerBound(rewriter);
658 bbArgReplacements.append(forallOp.getOutputs().begin(),
659 forallOp.getOutputs().end());
660
661
662 rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
663 forallOp->getIterator(), bbArgReplacements);
664
665
668 results.reserve(forallOp.getResults().size());
669 for (auto &yieldingOp : terminator.getYieldingOps()) {
670 auto parallelInsertSliceOp =
671 casttensor::ParallelInsertSliceOp(yieldingOp);
672
673 Value dst = parallelInsertSliceOp.getDest();
674 Value src = parallelInsertSliceOp.getSource();
675 if (llvm::isa(src.getType())) {
676 results.push_back(rewriter.createtensor::InsertSliceOp(
677 forallOp.getLoc(), dst.getType(), src, dst,
678 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
679 parallelInsertSliceOp.getStrides(),
680 parallelInsertSliceOp.getStaticOffsets(),
681 parallelInsertSliceOp.getStaticSizes(),
682 parallelInsertSliceOp.getStaticStrides()));
683 } else {
684 llvm_unreachable("unsupported terminator");
685 }
686 }
688
689
690 rewriter.eraseOp(terminator);
691 rewriter.eraseOp(forallOp);
692 }
693
698 bodyBuilder) {
699 assert(lbs.size() == ubs.size() &&
700 "expected the same number of lower and upper bounds");
701 assert(lbs.size() == steps.size() &&
702 "expected the same number of lower bounds and steps");
703
704
705 if (lbs.empty()) {
707 bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
709 assert(results.size() == iterArgs.size() &&
710 "loop nest body must return as many values as loop has iteration "
711 "arguments");
712 return LoopNest{{}, std::move(results)};
713 }
714
715
716
720 loops.reserve(lbs.size());
721 ivs.reserve(lbs.size());
722 ValueRange currentIterArgs = iterArgs;
724 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
725 auto loop = builder.createscf::ForOp(
726 currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
729 ivs.push_back(iv);
730
731
732 currentIterArgs = args;
733 currentLoc = nestedLoc;
734 });
735
736
737
739 loops.push_back(loop);
740 }
741
742
743 for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
745 builder.createscf::YieldOp(loc, loops[i + 1].getResults());
746 }
747
748
749
752 ? bodyBuilder(builder, currentLoc, ivs,
753 loops.back().getRegionIterArgs())
755 assert(results.size() == iterArgs.size() &&
756 "loop nest body must return as many values as loop has iteration "
757 "arguments");
759 builder.createscf::YieldOp(loc, results);
760
761
763 llvm::append_range(nestResults, loops.front().getResults());
764 return LoopNest{std::move(loops), std::move(nestResults)};
765 }
766
771
772 return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
773 [&bodyBuilder](OpBuilder &nestedBuilder,
776 if (bodyBuilder)
777 bodyBuilder(nestedBuilder, nestedLoc, ivs);
778 return {};
779 });
780 }
781
786 assert(operand.getOwner() == forOp);
788
789
790 assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
791 "expected an iter OpOperand");
793 "Expected a different type");
795 for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
796 if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
797 newIterOperands.push_back(replacement);
798 continue;
799 }
800 newIterOperands.push_back(opOperand.get());
801 }
802
803
804 scf::ForOp newForOp = rewriter.createscf::ForOp(
805 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
806 forOp.getStep(), newIterOperands);
807 newForOp->setAttrs(forOp->getAttrs());
808 Block &newBlock = newForOp.getRegion().front();
811
812
813
816 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
818 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
819 newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
820
821
822 Block &oldBlock = forOp.getRegion().front();
823 rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
824
825
826 auto clonedYieldOp = castscf::YieldOp(newBlock.getTerminator());
828 unsigned yieldIdx =
829 newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
830 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
831 clonedYieldOp.getOperand(yieldIdx));
833 newYieldOperands[yieldIdx] = castOut;
834 rewriter.createscf::YieldOp(newForOp.getLoc(), newYieldOperands);
835 rewriter.eraseOp(clonedYieldOp);
836
837
840 newResults[yieldIdx] =
841 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
842
843 return newResults;
844 }
845
846 namespace {
847
848
849
850
851
852
853
854
855
856
857
858 struct ForOpIterArgsFolder : public OpRewritePatternscf::ForOp {
860
861 LogicalResult matchAndRewrite(scf::ForOp forOp,
863 bool canonicalize = false;
864
865
866
867
868
869
870 int64_t numResults = forOp.getNumResults();
872 keepMask.reserve(numResults);
874 newResultValues;
875 newBlockTransferArgs.reserve(1 + numResults);
876 newBlockTransferArgs.push_back(Value());
877 newIterArgs.reserve(forOp.getInitArgs().size());
878 newYieldValues.reserve(numResults);
879 newResultValues.reserve(numResults);
881 for (auto [init, arg, result, yielded] :
882 llvm::zip(forOp.getInitArgs(),
883 forOp.getRegionIterArgs(),
884 forOp.getResults(),
885 forOp.getYieldedValues()
886 )) {
887
888
889
890
891
892 bool forwarded = (arg == yielded) || (init == yielded) ||
893 (arg.use_empty() && result.use_empty());
894 if (forwarded) {
895 canonicalize = true;
896 keepMask.push_back(false);
897 newBlockTransferArgs.push_back(init);
898 newResultValues.push_back(init);
899 continue;
900 }
901
902
903
904 if (auto it = initYieldToArg.find({init, yielded});
905 it != initYieldToArg.end()) {
906 canonicalize = true;
907 keepMask.push_back(false);
908 auto [sameArg, sameResult] = it->second;
911
912 newBlockTransferArgs.push_back(init);
913 newResultValues.push_back(init);
914 continue;
915 }
916
917
918 initYieldToArg.insert({{init, yielded}, {arg, result}});
919 keepMask.push_back(true);
920 newIterArgs.push_back(init);
921 newYieldValues.push_back(yielded);
922 newBlockTransferArgs.push_back(Value());
923 newResultValues.push_back(Value());
924 }
925
926 if (!canonicalize)
927 return failure();
928
929 scf::ForOp newForOp = rewriter.createscf::ForOp(
930 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
931 forOp.getStep(), newIterArgs);
932 newForOp->setAttrs(forOp->getAttrs());
933 Block &newBlock = newForOp.getRegion().front();
934
935
936 newBlockTransferArgs[0] = newBlock.getArgument(0);
937 for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
938 idx != e; ++idx) {
939 Value &blockTransferArg = newBlockTransferArgs[1 + idx];
940 Value &newResultVal = newResultValues[idx];
941 assert((blockTransferArg && newResultVal) ||
942 (!blockTransferArg && !newResultVal));
943 if (!blockTransferArg) {
944 blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
945 newResultVal = newForOp.getResult(collapsedIdx++);
946 }
947 }
948
949 Block &oldBlock = forOp.getRegion().front();
950 assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
951 "unexpected argument size mismatch");
952
953
954
955
956 if (newIterArgs.empty()) {
957 auto newYieldOp = castscf::YieldOp(newBlock.getTerminator());
958 rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
960 rewriter.replaceOp(forOp, newResultValues);
961 return success();
962 }
963
964
965 auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
969 filteredOperands.reserve(newResultValues.size());
970 for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
971 if (keepMask[idx])
972 filteredOperands.push_back(mergedTerminator.getOperand(idx));
973 rewriter.createscf::YieldOp(mergedTerminator.getLoc(),
974 filteredOperands);
975 };
976
977 rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
978 auto mergedYieldOp = castscf::YieldOp(newBlock.getTerminator());
979 cloneFilteredTerminator(mergedYieldOp);
980 rewriter.eraseOp(mergedYieldOp);
981 rewriter.replaceOp(forOp, newResultValues);
982 return success();
983 }
984 };
985
986
987
988
989 static std::optional<int64_t> computeConstDiff(Value l, Value u) {
990 IntegerAttr clb, cub;
992 llvm::APInt lbValue = clb.getValue();
993 llvm::APInt ubValue = cub.getValue();
994 return (ubValue - lbValue).getSExtValue();
995 }
996
997
998 llvm::APInt diff;
1003 return diff.getSExtValue();
1004 return std::nullopt;
1005 }
1006
1007
1008
1009
1010 struct SimplifyTrivialLoops : public OpRewritePattern {
1012
1013 LogicalResult matchAndRewrite(ForOp op,
1015
1016
1017 if (op.getLowerBound() == op.getUpperBound()) {
1018 rewriter.replaceOp(op, op.getInitArgs());
1019 return success();
1020 }
1021
1022 std::optional<int64_t> diff =
1023 computeConstDiff(op.getLowerBound(), op.getUpperBound());
1024 if (!diff)
1025 return failure();
1026
1027
1028 if (*diff <= 0) {
1029 rewriter.replaceOp(op, op.getInitArgs());
1030 return success();
1031 }
1032
1033 std::optionalllvm::APInt maybeStepValue = op.getConstantStep();
1034 if (!maybeStepValue)
1035 return failure();
1036
1037
1038
1039 llvm::APInt stepValue = *maybeStepValue;
1040 if (stepValue.sge(*diff)) {
1042 blockArgs.reserve(op.getInitArgs().size() + 1);
1043 blockArgs.push_back(op.getLowerBound());
1044 llvm::append_range(blockArgs, op.getInitArgs());
1046 return success();
1047 }
1048
1049
1050 Block &block = op.getRegion().front();
1051 if (!llvm::hasSingleElement(block))
1052 return failure();
1053
1054
1055 if (llvm::any_of(op.getYieldedValues(),
1056 [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1057 return failure();
1058 rewriter.replaceOp(op, op.getYieldedValues());
1059 return success();
1060 }
1061 };
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089 struct ForOpTensorCastFolder : public OpRewritePattern {
1091
1092 LogicalResult matchAndRewrite(ForOp op,
1094 for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1095 OpOperand &iterOpOperand = std::get<0>(it);
1096 auto incomingCast = iterOpOperand.get().getDefiningOptensor::CastOp();
1097 if (!incomingCast ||
1098 incomingCast.getSource().getType() == incomingCast.getType())
1099 continue;
1100
1101
1103 incomingCast.getDest().getType(),
1104 incomingCast.getSource().getType()))
1105 continue;
1106 if (!std::get<1>(it).hasOneUse())
1107 continue;
1108
1109
1112 rewriter, op, iterOpOperand, incomingCast.getSource(),
1114 return b.createtensor::CastOp(loc, type, source);
1115 }));
1116 return success();
1117 }
1118 return failure();
1119 }
1120 };
1121
1122 }
1123
1124 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1126 results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1127 context);
1128 }
1129
1130 std::optional ForOp::getConstantStep() {
1131 IntegerAttr step;
1133 return step.getValue();
1134 return {};
1135 }
1136
1137 std::optional<MutableArrayRef> ForOp::getYieldedValuesMutable() {
1138 return castscf::YieldOp(getBody()->getTerminator()).getResultsMutable();
1139 }
1140
1142
1143
1144 if (auto constantStep = getConstantStep())
1145 if (*constantStep == 1)
1147
1148
1149
1151 }
1152
1153
1154
1155
1156
1158 unsigned numLoops = getRank();
1159
1160 if (getNumResults() != getOutputs().size())
1161 return emitOpError("produces ")
1162 << getNumResults() << " results, but has only "
1163 << getOutputs().size() << " outputs";
1164
1165
1166 auto *body = getBody();
1167 if (body->getNumArguments() != numLoops + getOutputs().size())
1168 return emitOpError("region expects ") << numLoops << " arguments";
1169 for (int64_t i = 0; i < numLoops; ++i)
1171 return emitOpError("expects ")
1172 << i << "-th block argument to be an index";
1173 for (unsigned i = 0; i < getOutputs().size(); ++i)
1174 if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1175 return emitOpError("type mismatch between ")
1176 << i << "-th output and corresponding block argument";
1177 if (getMapping().has_value() && !getMapping()->empty()) {
1178 if (static_cast<int64_t>(getMapping()->size()) != numLoops)
1179 return emitOpError() << "mapping attribute size must match op rank";
1180 for (auto map : getMapping()->getValue()) {
1181 if (!isa(map))
1182 return emitOpError()
1184 }
1185 }
1186
1187
1190 getStaticLowerBound(),
1191 getDynamicLowerBound())))
1192 return failure();
1194 getStaticUpperBound(),
1195 getDynamicUpperBound())))
1196 return failure();
1198 getStaticStep(), getDynamicStep())))
1199 return failure();
1200
1201 return success();
1202 }
1203
1206 p << " (" << getInductionVars();
1207 if (isNormalized()) {
1208 p << ") in ";
1210 {}, {},
1212 } else {
1213 p << ") = ";
1215 {}, {},
1217 p << " to ";
1219 {}, {},
1221 p << " step ";
1223 {}, {},
1225 }
1227 p << " ";
1228 if (!getRegionOutArgs().empty())
1229 p << "-> (" << getResultTypes() << ") ";
1230 p.printRegion(getRegion(),
1231 false,
1232 getNumResults() > 0);
1233 p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1234 getStaticLowerBoundAttrName(),
1235 getStaticUpperBoundAttrName(),
1236 getStaticStepAttrName()});
1237 }
1238
1241 auto indexType = b.getIndexType();
1242
1243
1244
1245
1248 return failure();
1249
1252 dynamicSteps;
1254
1256 nullptr,
1259 return failure();
1260
1261 unsigned numLoops = ivs.size();
1264 } else {
1265
1268 nullptr,
1270
1272 return failure();
1273
1274
1277 nullptr,
1280 return failure();
1281
1282
1285 nullptr,
1288 return failure();
1289 }
1290
1291
1296 if (outOperands.size() != result.types.size())
1297 return parser.emitError(outOperandsLoc,
1298 "mismatch between out operands and types");
1303 return failure();
1304 }
1305
1306
1308 std::unique_ptr region = std::make_unique();
1309 for (auto &iv : ivs) {
1310 iv.type = b.getIndexType();
1311 regionArgs.push_back(iv);
1312 }
1314 auto &out = it.value();
1315 out.type = result.types[it.index()];
1316 regionArgs.push_back(out);
1317 }
1318 if (parser.parseRegion(*region, regionArgs))
1319 return failure();
1320
1321
1322 ForallOp::ensureTerminator(*region, b, result.location);
1323 result.addRegion(std::move(region));
1324
1325
1327 return failure();
1328
1329 result.addAttribute("staticLowerBound", staticLbs);
1330 result.addAttribute("staticUpperBound", staticUbs);
1331 result.addAttribute("staticStep", staticSteps);
1334 {static_cast<int32_t>(dynamicLbs.size()),
1335 static_cast<int32_t>(dynamicUbs.size()),
1336 static_cast<int32_t>(dynamicSteps.size()),
1337 static_cast<int32_t>(outOperands.size())}));
1338 return success();
1339 }
1340
1341
1342 void ForallOp::build(
1346 std::optional mapping,
1353
1359
1360 result.addAttribute(getStaticLowerBoundAttrName(result.name),
1362 result.addAttribute(getStaticUpperBoundAttrName(result.name),
1367 "operandSegmentSizes",
1369 static_cast<int32_t>(dynamicUbs.size()),
1370 static_cast<int32_t>(dynamicSteps.size()),
1371 static_cast<int32_t>(outputs.size())}));
1372 if (mapping.has_value()) {
1374 mapping.value());
1375 }
1376
1380 Block &bodyBlock = bodyRegion->front();
1381
1382
1389
1391 if (!bodyBuilderFn) {
1392 ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1393 return;
1394 }
1396 }
1397
1398
1399 void ForallOp::build(
1402 std::optional mapping,
1404 unsigned numLoops = ubs.size();
1407 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1408 }
1409
1410
1411 bool ForallOp::isNormalized() {
1413 return llvm::all_of(results, [&](OpFoldResult ofr) {
1415 return intValue.has_value() && intValue == val;
1416 });
1417 };
1418 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1419 }
1420
1421 InParallelOp ForallOp::getTerminator() {
1422 return cast(getBody()->getTerminator());
1423 }
1424
1427 InParallelOp inParallelOp = getTerminator();
1428 for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
1429 if (auto parallelInsertSliceOp =
1430 dyn_casttensor::ParallelInsertSliceOp(yieldOp);
1431 parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1432 storeOps.push_back(parallelInsertSliceOp);
1433 }
1434 }
1435 return storeOps;
1436 }
1437
1438 std::optional<SmallVector> ForallOp::getLoopInductionVars() {
1439 return SmallVector{getBody()->getArguments().take_front(getRank())};
1440 }
1441
1442
1443 std::optional<SmallVector> ForallOp::getLoopLowerBounds() {
1445 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1446 }
1447
1448
1449 std::optional<SmallVector> ForallOp::getLoopUpperBounds() {
1451 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1452 }
1453
1454
1455 std::optional<SmallVector> ForallOp::getLoopSteps() {
1457 return getMixedValues(getStaticStep(), getDynamicStep(), b);
1458 }
1459
1461 auto tidxArg = llvm::dyn_cast(val);
1462 if (!tidxArg)
1463 return ForallOp();
1464 assert(tidxArg.getOwner() && "unlinked block argument");
1465 auto *containingOp = tidxArg.getOwner()->getParentOp();
1466 return dyn_cast(containingOp);
1467 }
1468
1469 namespace {
1470
1471 struct DimOfForallOp : public OpRewritePatterntensor::DimOp {
1473
1474 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1476 auto forallOp = dimOp.getSource().getDefiningOp();
1477 if (!forallOp)
1478 return failure();
1479 Value sharedOut =
1480 forallOp.getTiedOpOperand(llvm::cast(dimOp.getSource()))
1481 ->get();
1483 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1484 return success();
1485 }
1486 };
1487
1488 class ForallOpControlOperandsFolder : public OpRewritePattern {
1489 public:
1491
1492 LogicalResult matchAndRewrite(ForallOp op,
1500 return failure();
1501
1503 SmallVector dynamicLowerBound, dynamicUpperBound, dynamicStep;
1506 staticLowerBound);
1507 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1508 op.setStaticLowerBound(staticLowerBound);
1509
1511 staticUpperBound);
1512 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1513 op.setStaticUpperBound(staticUpperBound);
1514
1516 op.getDynamicStepMutable().assign(dynamicStep);
1517 op.setStaticStep(staticStep);
1518
1519 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1521 {static_cast<int32_t>(dynamicLowerBound.size()),
1522 static_cast<int32_t>(dynamicUpperBound.size()),
1523 static_cast<int32_t>(dynamicStep.size()),
1524 static_cast<int32_t>(op.getNumResults())}));
1525 });
1526 return success();
1527 }
1528 };
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603 struct ForallOpIterArgsFolder : public OpRewritePattern {
1605
1606 LogicalResult matchAndRewrite(ForallOp forallOp,
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1625 for (OpResult result : forallOp.getResults()) {
1626 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1627 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1628 if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1629 resultToDelete.insert(result);
1630 } else {
1631 resultToReplace.push_back(result);
1632 newOuts.push_back(opOperand->get());
1633 }
1634 }
1635
1636
1637
1638 if (resultToDelete.empty())
1639 return failure();
1640
1641
1642
1643
1644
1645
1646 for (OpResult result : resultToDelete) {
1647 OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1648 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1650 forallOp.getCombiningOps(blockArg);
1651 for (Operation *combiningOp : combiningOps)
1652 rewriter.eraseOp(combiningOp);
1653 }
1654
1655
1656
1657 auto newForallOp = rewriter.createscf::ForallOp(
1658 forallOp.getLoc(), forallOp.getMixedLowerBound(),
1659 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1660 forallOp.getMapping(),
1662
1663
1664
1665 Block *loopBody = forallOp.getBody();
1666 Block *newLoopBody = newForallOp.getBody();
1668
1669
1671 llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1674 unsigned index = 0;
1675
1676
1677
1678 for (OpResult result : forallOp.getResults()) {
1679 if (resultToDelete.count(result)) {
1680 newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1681 } else {
1682 newBlockArgs.push_back(newSharedOutsArgs[index++]);
1683 }
1684 }
1685 rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1686
1687
1688
1689 for (auto &&[oldResult, newResult] :
1690 llvm::zip(resultToReplace, newForallOp->getResults()))
1692
1693
1694
1695
1696 for (OpResult oldResult : resultToDelete)
1698 forallOp.getTiedOpOperand(oldResult)->get());
1699 return success();
1700 }
1701 };
1702
1703 struct ForallOpSingleOrZeroIterationDimsFolder
1706
1707 LogicalResult matchAndRewrite(ForallOp op,
1709
1710 if (op.getMapping().has_value() && !op.getMapping()->empty())
1711 return failure();
1713
1714
1716 newMixedSteps;
1718 for (auto [lb, ub, step, iv] :
1719 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1720 op.getMixedStep(), op.getInductionVars())) {
1722 if (numIterations.has_value()) {
1723
1724 if (*numIterations == 0) {
1725 rewriter.replaceOp(op, op.getOutputs());
1726 return success();
1727 }
1728
1729
1730 if (*numIterations == 1) {
1732 continue;
1733 }
1734 }
1735 newMixedLowerBounds.push_back(lb);
1736 newMixedUpperBounds.push_back(ub);
1737 newMixedSteps.push_back(step);
1738 }
1739
1740
1741 if (newMixedLowerBounds.empty()) {
1743 return success();
1744 }
1745
1746
1747 if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1749 op, "no dimensions have 0 or 1 iterations");
1750 }
1751
1752
1753 ForallOp newOp;
1754 newOp = rewriter.create(loc, newMixedLowerBounds,
1755 newMixedUpperBounds, newMixedSteps,
1756 op.getOutputs(), std::nullopt, nullptr);
1757 newOp.getBodyRegion().getBlocks().clear();
1758
1759
1760
1762 newOp.getStaticLowerBoundAttrName(),
1763 newOp.getStaticUpperBoundAttrName(),
1764 newOp.getStaticStepAttrName()};
1765 for (const auto &namedAttr : op->getAttrs()) {
1766 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1767 continue;
1769 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1770 });
1771 }
1773 newOp.getRegion().begin(), mapping);
1774 rewriter.replaceOp(op, newOp.getResults());
1775 return success();
1776 }
1777 };
1778
1779
1780 struct ForallOpReplaceConstantInductionVar : public OpRewritePattern {
1782
1783 LogicalResult matchAndRewrite(ForallOp op,
1787 for (auto [lb, ub, step, iv] :
1788 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1789 op.getMixedStep(), op.getInductionVars())) {
1790 if (iv.hasNUses(0))
1791 continue;
1793 if (!numIterations.has_value() || numIterations.value() != 1) {
1794 continue;
1795 }
1799 }
1800 return success(changed);
1801 }
1802 };
1803
1804 struct FoldTensorCastOfOutputIntoForallOp
1807
1808 struct TypeCast {
1809 Type srcType;
1810 Type dstType;
1811 };
1812
1813 LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1815 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1818 auto castOp = en.value().getDefiningOptensor::CastOp();
1819 if (!castOp)
1820 continue;
1821
1822
1823
1825 castOp.getSource().getType())) {
1826 continue;
1827 }
1828
1829 tensorCastProducers[en.index()] =
1830 TypeCast{castOp.getSource().getType(), castOp.getType()};
1831 newOutputTensors[en.index()] = castOp.getSource();
1832 }
1833
1834 if (tensorCastProducers.empty())
1835 return failure();
1836
1837
1838 Location loc = forallOp.getLoc();
1839 auto newForallOp = rewriter.create(
1840 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1841 forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1843 auto castBlockArgs =
1844 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1845 for (auto [index, cast] : tensorCastProducers) {
1846 Value &oldTypeBBArg = castBlockArgs[index];
1847 oldTypeBBArg = nestedBuilder.createtensor::CastOp(
1848 nestedLoc, cast.dstType, oldTypeBBArg);
1849 }
1850
1851
1853 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1854 ivsBlockArgs.append(castBlockArgs);
1855 rewriter.mergeBlocks(forallOp.getBody(),
1856 bbArgs.front().getParentBlock(), ivsBlockArgs);
1857 });
1858
1859
1860
1861
1862 auto terminator = newForallOp.getTerminator();
1863 for (auto [yieldingOp, outputBlockArg] : llvm::zip(
1864 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1865 auto insertSliceOp = casttensor::ParallelInsertSliceOp(yieldingOp);
1866 insertSliceOp.getDestMutable().assign(outputBlockArg);
1867 }
1868
1869
1872 for (auto &item : tensorCastProducers) {
1873 Value &oldTypeResult = castResults[item.first];
1874 oldTypeResult = rewriter.createtensor::CastOp(loc, item.second.dstType,
1875 oldTypeResult);
1876 }
1877 rewriter.replaceOp(forallOp, castResults);
1878 return success();
1879 }
1880 };
1881
1882 }
1883
1884 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1886 results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1887 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1888 ForallOpSingleOrZeroIterationDimsFolder,
1889 ForallOpReplaceConstantInductionVar>(context);
1890 }
1891
1892
1893
1894
1895
1896
1899
1900
1901
1904 }
1905
1906
1907
1908
1909
1910
1915 }
1916
1918 scf::ForallOp forallOp =
1919 dyn_castscf::ForallOp(getOperation()->getParentOp());
1920 if (!forallOp)
1921 return this->emitOpError("expected forall op parent");
1922
1923
1924 for (Operation &op : getRegion().front().getOperations()) {
1925 if (!isatensor::ParallelInsertSliceOp(op)) {
1926 return this->emitOpError("expected only ")
1927 << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
1928 }
1929
1930
1931 Value dest = casttensor::ParallelInsertSliceOp(op).getDest();
1933 if (!llvm::is_contained(regionOutArgs, dest))
1934 return op.emitOpError("may only insert into an output block argument");
1935 }
1936 return success();
1937 }
1938
1940 p << " ";
1942 false,
1943 false);
1945 }
1946
1948 auto &builder = parser.getBuilder();
1949
1951 std::unique_ptr region = std::make_unique();
1952 if (parser.parseRegion(*region, regionOperands))
1953 return failure();
1954
1955 if (region->empty())
1957 result.addRegion(std::move(region));
1958
1959
1961 return failure();
1962 return success();
1963 }
1964
1965 OpResult InParallelOp::getParentResult(int64_t idx) {
1966 return getOperation()->getParentOp()->getResult(idx);
1967 }
1968
1970 return llvm::to_vector<4>(
1971 llvm::map_range(getYieldingOps(), [](Operation &op) {
1972
1973 auto insertSliceOp = casttensor::ParallelInsertSliceOp(&op);
1974 return llvm::cast(insertSliceOp.getDest());
1975 }));
1976 }
1977
1979 return getRegion().front().getOperations();
1980 }
1981
1982
1983
1984
1985
1987 assert(a && "expected non-empty operation");
1988 assert(b && "expected non-empty operation");
1989
1991 while (ifOp) {
1992
1993 if (ifOp->isProperAncestor(b))
1994
1995
1996 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1997 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1998
1999 ifOp = ifOp->getParentOfType();
2000 }
2001
2002
2003 return false;
2004 }
2005
2006 LogicalResult
2007 IfOp::inferReturnTypes(MLIRContext *ctx, std::optional loc,
2008 IfOp::Adaptor adaptor,
2010 if (adaptor.getRegions().empty())
2011 return failure();
2012 Region *r = &adaptor.getThenRegion();
2013 if (r->empty())
2014 return failure();
2015 Block &b = r->front();
2017 return failure();
2018 auto yieldOp = llvm::dyn_cast(b.back());
2019 if (!yieldOp)
2020 return failure();
2021 TypeRange types = yieldOp.getOperandTypes();
2022 llvm::append_range(inferredReturnTypes, types);
2023 return success();
2024 }
2025
2028 return build(builder, result, resultTypes, cond, false,
2029 false);
2030 }
2031
2033 TypeRange resultTypes, Value cond, bool addThenBlock,
2034 bool addElseBlock) {
2035 assert((!addElseBlock || addThenBlock) &&
2036 "must not create else block w/o then block");
2037 result.addTypes(resultTypes);
2039
2040
2043 if (addThenBlock)
2046 if (addElseBlock)
2048 }
2049
2051 bool withElseRegion) {
2052 build(builder, result, TypeRange{}, cond, withElseRegion);
2053 }
2054
2056 TypeRange resultTypes, Value cond, bool withElseRegion) {
2057 result.addTypes(resultTypes);
2059
2060
2064 if (resultTypes.empty())
2065 IfOp::ensureTerminator(*thenRegion, builder, result.location);
2066
2067
2069 if (withElseRegion) {
2071 if (resultTypes.empty())
2072 IfOp::ensureTerminator(*elseRegion, builder, result.location);
2073 }
2074 }
2075
2079 assert(thenBuilder && "the builder callback for 'then' must be present");
2081
2082
2086 thenBuilder(builder, result.location);
2087
2088
2090 if (elseBuilder) {
2092 elseBuilder(builder, result.location);
2093 }
2094
2095
2099 if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
2100 nullptr, result.regions,
2101 inferredReturnTypes))) {
2102 result.addTypes(inferredReturnTypes);
2103 }
2104 }
2105
2107 if (getNumResults() != 0 && getElseRegion().empty())
2108 return emitOpError("must have an else block if defining values");
2109 return success();
2110 }
2111
2113
2114 result.regions.reserve(2);
2117
2118 auto &builder = parser.getBuilder();
2123 return failure();
2124
2126 return failure();
2127
2128 if (parser.parseRegion(*thenRegion, {}, {}))
2129 return failure();
2130 IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2131
2132
2134 if (parser.parseRegion(*elseRegion, {}, {}))
2135 return failure();
2136 IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2137 }
2138
2139
2141 return failure();
2142 return success();
2143 }
2144
2146 bool printBlockTerminators = false;
2147
2148 p << " " << getCondition();
2149 if (!getResults().empty()) {
2150 p << " -> (" << getResultTypes() << ")";
2151
2152 printBlockTerminators = true;
2153 }
2154 p << ' ';
2156 false,
2157 printBlockTerminators);
2158
2159
2160 auto &elseRegion = getElseRegion();
2161 if (!elseRegion.empty()) {
2162 p << " else ";
2164 false,
2165 printBlockTerminators);
2166 }
2167
2169 }
2170
2173
2176 return;
2177 }
2178
2180
2181
2182 Region *elseRegion = &this->getElseRegion();
2183 if (elseRegion->empty())
2185 else
2187 }
2188
2191 FoldAdaptor adaptor(operands, *this);
2192 auto boolAttr = dyn_cast_or_null(adaptor.getCondition());
2193 if (!boolAttr || boolAttr.getValue())
2194 regions.emplace_back(&getThenRegion());
2195
2196
2197 if (!boolAttr || !boolAttr.getValue()) {
2198 if (!getElseRegion().empty())
2199 regions.emplace_back(&getElseRegion());
2200 else
2201 regions.emplace_back(getResults());
2202 }
2203 }
2204
2205 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2207
2208 if (getElseRegion().empty())
2209 return failure();
2210
2211 arith::XOrIOp xorStmt = getCondition().getDefiningOparith::XOrIOp();
2212 if (!xorStmt)
2213 return failure();
2214
2216 return failure();
2217
2218 getConditionMutable().assign(xorStmt.getLhs());
2219 Block *thenBlock = &getThenRegion().front();
2220
2221
2222 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2223 getElseRegion().getBlocks());
2224 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2225 getThenRegion().getBlocks(), thenBlock);
2226 return success();
2227 }
2228
2229 void IfOp::getRegionInvocationBounds(
2232 if (auto cond = llvm::dyn_cast_or_null(operands[0])) {
2233
2234
2235 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2236 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2237 } else {
2238
2239 invocationBounds.assign(2, {0, 1});
2240 }
2241 }
2242
2243 namespace {
2244
2245 struct RemoveUnusedResults : public OpRewritePattern {
2247
2250
2252
2253 auto yieldOp = castscf::YieldOp(dest->getTerminator());
2255 llvm::transform(usedResults, std::back_inserter(usedOperands),
2258 });
2260 [&]() { yieldOp->setOperands(usedOperands); });
2261 }
2262
2263 LogicalResult matchAndRewrite(IfOp op,
2265
2267 llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2268 [](OpResult result) { return !result.use_empty(); });
2269
2270
2271 if (usedResults.size() == op.getNumResults())
2272 return failure();
2273
2274
2276 llvm::transform(usedResults, std::back_inserter(newTypes),
2278
2279
2280 auto newOp =
2281 rewriter.create(op.getLoc(), newTypes, op.getCondition());
2282 rewriter.createBlock(&newOp.getThenRegion());
2283 rewriter.createBlock(&newOp.getElseRegion());
2284
2285
2286
2287 transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2288 transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2289
2290
2293 repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2294 rewriter.replaceOp(op, repResults);
2295 return success();
2296 }
2297 };
2298
2299 struct RemoveStaticCondition : public OpRewritePattern {
2301
2302 LogicalResult matchAndRewrite(IfOp op,
2306 return failure();
2307
2310 else if (!op.getElseRegion().empty())
2312 else
2314
2315 return success();
2316 }
2317 };
2318
2319
2320
2321 struct ConvertTrivialIfToSelect : public OpRewritePattern {
2323
2324 LogicalResult matchAndRewrite(IfOp op,
2326 if (op->getNumResults() == 0)
2327 return failure();
2328
2329 auto cond = op.getCondition();
2330 auto thenYieldArgs = op.thenYield().getOperands();
2331 auto elseYieldArgs = op.elseYield().getOperands();
2332
2334 for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2335 if (&op.getThenRegion() == trueVal.getParentRegion() ||
2336 &op.getElseRegion() == falseVal.getParentRegion())
2337 nonHoistable.push_back(trueVal.getType());
2338 }
2339
2340
2341 if (nonHoistable.size() == op->getNumResults())
2342 return failure();
2343
2344 IfOp replacement = rewriter.create(op.getLoc(), nonHoistable, cond,
2345 false);
2346 if (replacement.thenBlock())
2347 rewriter.eraseBlock(replacement.thenBlock());
2348 replacement.getThenRegion().takeBody(op.getThenRegion());
2349 replacement.getElseRegion().takeBody(op.getElseRegion());
2350
2352 assert(thenYieldArgs.size() == results.size());
2353 assert(elseYieldArgs.size() == results.size());
2354
2358 for (const auto &it :
2359 llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2360 Value trueVal = std::get<0>(it.value());
2361 Value falseVal = std::get<1>(it.value());
2362 if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
2363 &replacement.getElseRegion() == falseVal.getParentRegion()) {
2364 results[it.index()] = replacement.getResult(trueYields.size());
2365 trueYields.push_back(trueVal);
2366 falseYields.push_back(falseVal);
2367 } else if (trueVal == falseVal)
2368 results[it.index()] = trueVal;
2369 else
2370 results[it.index()] = rewriter.createarith::SelectOp(
2371 op.getLoc(), cond, trueVal, falseVal);
2372 }
2373
2375 rewriter.replaceOpWithNewOp(replacement.thenYield(), trueYields);
2376
2378 rewriter.replaceOpWithNewOp(replacement.elseYield(), falseYields);
2379
2380 rewriter.replaceOp(op, results);
2381 return success();
2382 }
2383 };
2384
2385
2386
2387
2388
2389
2390
2391
2392
2393
2394
2395
2396
2397
2398 struct ConditionPropagation : public OpRewritePattern {
2400
2401 LogicalResult matchAndRewrite(IfOp op,
2403
2404
2406 return failure();
2407
2410
2411
2412
2413 Value constantTrue = nullptr;
2414 Value constantFalse = nullptr;
2415
2417 llvm::make_early_inc_range(op.getCondition().getUses())) {
2420
2421 if (!constantTrue)
2422 constantTrue = rewriter.createarith::ConstantOp(
2423 op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
2424
2426 [&]() { use.set(constantTrue); });
2427 } else if (op.getElseRegion().isAncestor(
2430
2431 if (!constantFalse)
2432 constantFalse = rewriter.createarith::ConstantOp(
2433 op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
2434
2436 [&]() { use.set(constantFalse); });
2437 }
2438 }
2439
2440 return success(changed);
2441 }
2442 };
2443
2444
2445
2446
2447
2448
2449
2450
2451
2452
2453
2454
2455
2456
2457
2458
2459
2460
2461
2462
2463
2464
2465
2466
2467
2468
2469
2470
2471
2472
2473
2474
2475
2476
2477
2478
2479
2480 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern {
2482
2483 LogicalResult matchAndRewrite(IfOp op,
2485
2486 if (op.getNumResults() == 0)
2487 return failure();
2488
2489 auto trueYield =
2490 castscf::YieldOp(op.getThenRegion().back().getTerminator());
2491 auto falseYield =
2492 castscf::YieldOp(op.getElseRegion().back().getTerminator());
2493
2495 op.getOperation()->getIterator());
2498 for (auto [trueResult, falseResult, opResult] :
2499 llvm::zip(trueYield.getResults(), falseYield.getResults(),
2500 op.getResults())) {
2501 if (trueResult == falseResult) {
2502 if (!opResult.use_empty()) {
2503 opResult.replaceAllUsesWith(trueResult);
2505 }
2506 continue;
2507 }
2508
2509 BoolAttr trueYield, falseYield;
2512 continue;
2513
2514 bool trueVal = trueYield.getValue();
2515 bool falseVal = falseYield.getValue();
2516 if (!trueVal && falseVal) {
2517 if (!opResult.use_empty()) {
2518 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2519 Value notCond = rewriter.createarith::XOrIOp(
2520 op.getLoc(), op.getCondition(),
2521 constDialect
2524 op.getLoc())
2528 }
2529 }
2530 if (trueVal && !falseVal) {
2531 if (!opResult.use_empty()) {
2532 opResult.replaceAllUsesWith(op.getCondition());
2534 }
2535 }
2536 }
2537 return success(changed);
2538 }
2539 };
2540
2541
2542
2543
2544
2545
2546
2547
2548
2549
2550
2551
2552
2553
2554
2555
2556
2557
2558
2559
2560
2561
2564
2565 LogicalResult matchAndRewrite(IfOp nextIf,
2567 Block *parent = nextIf->getBlock();
2568 if (nextIf == &parent->front())
2569 return failure();
2570
2571 auto prevIf = dyn_cast(nextIf->getPrevNode());
2572 if (!prevIf)
2573 return failure();
2574
2575
2576
2577
2578
2579 Block *nextThen = nullptr;
2580 Block *nextElse = nullptr;
2581 if (nextIf.getCondition() == prevIf.getCondition()) {
2582 nextThen = nextIf.thenBlock();
2583 if (!nextIf.getElseRegion().empty())
2584 nextElse = nextIf.elseBlock();
2585 }
2586 if (arith::XOrIOp notv =
2587 nextIf.getCondition().getDefiningOparith::XOrIOp()) {
2588 if (notv.getLhs() == prevIf.getCondition() &&
2590 nextElse = nextIf.thenBlock();
2591 if (!nextIf.getElseRegion().empty())
2592 nextThen = nextIf.elseBlock();
2593 }
2594 }
2595 if (arith::XOrIOp notv =
2596 prevIf.getCondition().getDefiningOparith::XOrIOp()) {
2597 if (notv.getLhs() == nextIf.getCondition() &&
2599 nextElse = nextIf.thenBlock();
2600 if (!nextIf.getElseRegion().empty())
2601 nextThen = nextIf.elseBlock();
2602 }
2603 }
2604
2605 if (!nextThen && !nextElse)
2606 return failure();
2607
2609 if (!prevIf.getElseRegion().empty())
2610 prevElseYielded = prevIf.elseYield().getOperands();
2611
2612
2613 for (auto it : llvm::zip(prevIf.getResults(),
2614 prevIf.thenYield().getOperands(), prevElseYielded))
2616 llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2620 use.set(std::get<1>(it));
2625 use.set(std::get<2>(it));
2627 }
2628 }
2629
2631 llvm::append_range(mergedTypes, nextIf.getResultTypes());
2632
2633 IfOp combinedIf = rewriter.create(
2634 nextIf.getLoc(), mergedTypes, prevIf.getCondition(), false);
2635 rewriter.eraseBlock(&combinedIf.getThenRegion().back());
2636
2638 combinedIf.getThenRegion(),
2639 combinedIf.getThenRegion().begin());
2640
2641 if (nextThen) {
2642 YieldOp thenYield = combinedIf.thenYield();
2643 YieldOp thenYield2 = cast(nextThen->getTerminator());
2644 rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
2646
2648 llvm::append_range(mergedYields, thenYield2.getOperands());
2649 rewriter.create(thenYield2.getLoc(), mergedYields);
2650 rewriter.eraseOp(thenYield);
2651 rewriter.eraseOp(thenYield2);
2652 }
2653
2655 combinedIf.getElseRegion(),
2656 combinedIf.getElseRegion().begin());
2657
2658 if (nextElse) {
2659 if (combinedIf.getElseRegion().empty()) {
2661 combinedIf.getElseRegion(),
2662 combinedIf.getElseRegion().begin());
2663 } else {
2664 YieldOp elseYield = combinedIf.elseYield();
2665 YieldOp elseYield2 = cast(nextElse->getTerminator());
2666 rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
2667
2669
2671 llvm::append_range(mergedElseYields, elseYield2.getOperands());
2672
2673 rewriter.create(elseYield2.getLoc(), mergedElseYields);
2674 rewriter.eraseOp(elseYield);
2675 rewriter.eraseOp(elseYield2);
2676 }
2677 }
2678
2681 for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2682 if (pair.index() < prevIf.getNumResults())
2683 prevValues.push_back(pair.value());
2684 else
2685 nextValues.push_back(pair.value());
2686 }
2687 rewriter.replaceOp(prevIf, prevValues);
2688 rewriter.replaceOp(nextIf, nextValues);
2689 return success();
2690 }
2691 };
2692
2693
2694 struct RemoveEmptyElseBranch : public OpRewritePattern {
2696
2697 LogicalResult matchAndRewrite(IfOp ifOp,
2699
2700 if (ifOp.getNumResults())
2701 return failure();
2702 Block *elseBlock = ifOp.elseBlock();
2703 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2704 return failure();
2706 rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2707 newIfOp.getThenRegion().begin());
2709 return success();
2710 }
2711 };
2712
2713
2714
2715
2716
2717
2718
2719
2720
2721
2722
2723
2724
2725
2726
2727
2728
2731
2732 LogicalResult matchAndRewrite(IfOp op,
2734 auto nestedOps = op.thenBlock()->without_terminator();
2735
2736 if (!llvm::hasSingleElement(nestedOps))
2737 return failure();
2738
2739
2740 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2741 return failure();
2742
2743 auto nestedIf = dyn_cast(*nestedOps.begin());
2744 if (!nestedIf)
2745 return failure();
2746
2747 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2748 return failure();
2749
2752 if (op.elseBlock())
2753 llvm::append_range(elseYield, op.elseYield().getOperands());
2754
2755
2756
2758
2759
2760
2761
2762
2763
2764
2765
2767 if (tup.value().getDefiningOp() == nestedIf) {
2768 auto nestedIdx = llvm::cast(tup.value()).getResultNumber();
2769 if (nestedIf.elseYield().getOperand(nestedIdx) !=
2770 elseYield[tup.index()]) {
2771 return failure();
2772 }
2773
2774
2775 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2776 continue;
2777 }
2778
2779
2780
2781
2782
2783
2784
2785
2786
2787
2788 if (tup.value().getParentRegion() == &op.getThenRegion()) {
2789 return failure();
2790 }
2791 elseYieldsToUpgradeToSelect.push_back(tup.index());
2792 }
2793
2795 Value newCondition = rewriter.createarith::AndIOp(
2796 loc, op.getCondition(), nestedIf.getCondition());
2797 auto newIf = rewriter.create(loc, op.getResultTypes(), newCondition);
2798 Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2799
2801 llvm::append_range(results, newIf.getResults());
2803
2804 for (auto idx : elseYieldsToUpgradeToSelect)
2805 results[idx] = rewriter.createarith::SelectOp(
2806 op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2807
2808 rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2811 if (!elseYield.empty()) {
2812 rewriter.createBlock(&newIf.getElseRegion());
2814 rewriter.create(loc, elseYield);
2815 }
2816 rewriter.replaceOp(op, results);
2817 return success();
2818 }
2819 };
2820
2821 }
2822
2823 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2825 results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2826 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2827 RemoveStaticCondition, RemoveUnusedResults,
2828 ReplaceIfYieldWithConditionOrValue>(context);
2829 }
2830
2831 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2832 YieldOp IfOp::thenYield() { return cast(&thenBlock()->back()); }
2833 Block *IfOp::elseBlock() {
2834 Region &r = getElseRegion();
2835 if (r.empty())
2836 return nullptr;
2837 return &r.back();
2838 }
2839 YieldOp IfOp::elseYield() { return cast(&elseBlock()->back()); }
2840
2841
2842
2843
2844
2845 void ParallelOp::build(
2849 bodyBuilderFn) {
2855 ParallelOp::getOperandSegmentSizeAttr(),
2857 static_cast<int32_t>(upperBounds.size()),
2858 static_cast<int32_t>(steps.size()),
2859 static_cast<int32_t>(initVals.size())}));
2861
2863 unsigned numIVs = steps.size();
2867 Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2868
2869 if (bodyBuilderFn) {
2871 bodyBuilderFn(builder, result.location,
2872 bodyBlock->getArguments().take_front(numIVs),
2873 bodyBlock->getArguments().drop_front(numIVs));
2874 }
2875
2876 if (initVals.empty())
2877 ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2878 }
2879
2880 void ParallelOp::build(
2884
2885
2886
2887 auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2890 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2891 };
2893 if (bodyBuilderFn)
2894 wrapper = wrappedBuilderFn;
2895
2896 build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2897 wrapper);
2898 }
2899
2901
2902
2903
2905 if (stepValues.empty())
2906 return emitOpError(
2907 "needs at least one tuple element for lowerBound, upperBound and step");
2908
2909
2910 for (Value stepValue : stepValues)
2912 if (*cst <= 0)
2913 return emitOpError("constant step operand must be positive");
2914
2915
2916
2917 Block *body = getBody();
2919 return emitOpError() << "expects the same number of induction variables: "
2921 << " as bound and step values: " << stepValues.size();
2923 if (!arg.getType().isIndex())
2924 return emitOpError(
2925 "expects arguments for the induction variable to be of index type");
2926
2927
2928 auto reduceOp = verifyAndGetTerminatorscf::ReduceOp(
2929 *this, getRegion(), "expects body to terminate with 'scf.reduce'");
2930 if (!reduceOp)
2931 return failure();
2932
2933
2934 auto resultsSize = getResults().size();
2935 auto reductionsSize = reduceOp.getReductions().size();
2936 auto initValsSize = getInitVals().size();
2937 if (resultsSize != reductionsSize)
2938 return emitOpError() << "expects number of results: " << resultsSize
2939 << " to be the same as number of reductions: "
2940 << reductionsSize;
2941 if (resultsSize != initValsSize)
2942 return emitOpError() << "expects number of results: " << resultsSize
2943 << " to be the same as number of initial values: "
2944 << initValsSize;
2945
2946
2947 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2948 auto resultType = getOperation()->getResult(i).getType();
2949 auto reductionOperandType = reduceOp.getOperands()[i].getType();
2950 if (resultType != reductionOperandType)
2951 return reduceOp.emitOpError()
2952 << "expects type of " << i
2953 << "-th reduction operand: " << reductionOperandType
2954 << " to be the same as the " << i
2955 << "-th result type: " << resultType;
2956 }
2957 return success();
2958 }
2959
2961 auto &builder = parser.getBuilder();
2962
2965 return failure();
2966
2967
2973 return failure();
2974
2980 return failure();
2981
2982
2988 return failure();
2989
2990
2994 return failure();
2995 }
2996
2997
2999 return failure();
3000
3001
3003 for (auto &iv : ivs)
3006 return failure();
3007
3008
3010 ParallelOp::getOperandSegmentSizeAttr(),
3012 static_cast<int32_t>(upper.size()),
3013 static_cast<int32_t>(steps.size()),
3014 static_cast<int32_t>(initVals.size())}));
3015
3016
3020 return failure();
3021
3022
3023 ParallelOp::ensureTerminator(*body, builder, result.location);
3024 return success();
3025 }
3026
3028 p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
3029 << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
3030 if (!getInitVals().empty())
3031 p << " init (" << getInitVals() << ")";
3033 p << ' ';
3034 p.printRegion(getRegion(), false);
3036 (*this)->getAttrs(),
3037 ParallelOp::getOperandSegmentSizeAttr());
3038 }
3039
3041
3042 std::optional<SmallVector> ParallelOp::getLoopInductionVars() {
3044 }
3045
3046 std::optional<SmallVector> ParallelOp::getLoopLowerBounds() {
3048 }
3049
3050 std::optional<SmallVector> ParallelOp::getLoopUpperBounds() {
3052 }
3053
// Returns the step operands of this scf.parallel (the verifier checks there is
// one step per induction variable).
// NOTE(review): the SmallVector element type appears to have been lost in this
// copy of the file — confirm the exact return type against upstream.
3054 std::optional<SmallVector> ParallelOp::getLoopSteps() {
3055 return getStep();
3056 }
3057
3059 auto ivArg = llvm::dyn_cast(val);
3060 if (!ivArg)
3061 return ParallelOp();
3062 assert(ivArg.getOwner() && "unlinked block argument");
3063 auto *containingOp = ivArg.getOwner()->getParentOp();
3064 return dyn_cast(containingOp);
3065 }
3066
3067 namespace {
3068
3069 struct ParallelOpSingleOrZeroIterationDimsFolder
3072
3073 LogicalResult matchAndRewrite(ParallelOp op,
3076
3077
3080 for (auto [lb, ub, step, iv] :
3081 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3082 op.getInductionVars())) {
3084 if (numIterations.has_value()) {
3085
3086 if (*numIterations == 0) {
3087 rewriter.replaceOp(op, op.getInitVals());
3088 return success();
3089 }
3090
3091
3092 if (*numIterations == 1) {
3094 continue;
3095 }
3096 }
3097 newLowerBounds.push_back(lb);
3098 newUpperBounds.push_back(ub);
3099 newSteps.push_back(step);
3100 }
3101
3102 if (newLowerBounds.size() == op.getLowerBound().size())
3103 return failure();
3104
3105 if (newLowerBounds.empty()) {
3106
3107
3109 results.reserve(op.getInitVals().size());
3110 for (auto &bodyOp : op.getBody()->without_terminator())
3111 rewriter.clone(bodyOp, mapping);
3112 auto reduceOp = cast(op.getBody()->getTerminator());
3113 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3114 Block &reduceBlock = reduceOp.getReductions()[i].front();
3115 auto initValIndex = results.size();
3116 mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
3120 rewriter.clone(reduceBodyOp, mapping);
3121
3123 cast(reduceBlock.getTerminator()).getResult());
3124 results.push_back(result);
3125 }
3126
3127 rewriter.replaceOp(op, results);
3128 return success();
3129 }
3130
3131 auto newOp =
3132 rewriter.create(op.getLoc(), newLowerBounds, newUpperBounds,
3133 newSteps, op.getInitVals(), nullptr);
3134
3135 rewriter.eraseBlock(newOp.getBody());
3136
3137
3139 newOp.getRegion().begin(), mapping);
3140 rewriter.replaceOp(op, newOp.getResults());
3141 return success();
3142 }
3143 };
3144
3145 struct MergeNestedParallelLoops : public OpRewritePattern {
3147
3148 LogicalResult matchAndRewrite(ParallelOp op,
3150 Block &outerBody = *op.getBody();
3152 return failure();
3153
3154 auto innerOp = dyn_cast(outerBody.front());
3155 if (!innerOp)
3156 return failure();
3157
3159 if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3160 llvm::is_contained(innerOp.getUpperBound(), val) ||
3161 llvm::is_contained(innerOp.getStep(), val))
3162 return failure();
3163
3164
3165 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3166 return failure();
3167
3170 Block &innerBody = *innerOp.getBody();
3171 assert(iterVals.size() ==
3179 builder.clone(op, mapping);
3180 };
3181
3182 auto concatValues = [](const auto &first, const auto &second) {
3184 ret.reserve(first.size() + second.size());
3185 ret.assign(first.begin(), first.end());
3186 ret.append(second.begin(), second.end());
3187 return ret;
3188 };
3189
3190 auto newLowerBounds =
3191 concatValues(op.getLowerBound(), innerOp.getLowerBound());
3192 auto newUpperBounds =
3193 concatValues(op.getUpperBound(), innerOp.getUpperBound());
3194 auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3195
3196 rewriter.replaceOpWithNewOp(op, newLowerBounds, newUpperBounds,
3197 newSteps, std::nullopt,
3198 bodyBuilder);
3199 return success();
3200 }
3201 };
3202
3203 }
3204
3205 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3207 results
3208 .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3209 context);
3210 }
3211
3212
3213
3214
3215
3216
3217 void ParallelOp::getSuccessorRegions(
3219
3220
3221
3224 }
3225
3226
3227
3228
3229
3231
3235 for (Value v : operands) {
3241 }
3242 }
3243
3244 LogicalResult ReduceOp::verifyRegions() {
3245
3246
3247 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3248 auto type = getOperands()[i].getType();
3249 Block &block = getReductions()[i].front();
3250 if (block.empty())
3251 return emitOpError() << i << "-th reduction has an empty body";
3254 return arg.getType() != type;
3255 }))
3256 return emitOpError() << "expected two block arguments with type " << type
3257 << " in the " << i << "-th reduction region";
3258
3259
3260 if (!isa(block.getTerminator()))
3261 return emitOpError("reduction bodies must be terminated with an "
3262 "'scf.reduce.return' op");
3263 }
3264
3265 return success();
3266 }
3267
3270
3272 }
3273
3274
3275
3276
3277
3279
3280
3281 Block *reductionBody = getOperation()->getBlock();
3282
3283 assert(isa(reductionBody->getParentOp()) && "expected scf.reduce");
3285 if (expectedResultType != getResult().getType())
3286 return emitOpError() << "must have type " << expectedResultType
3287 << " (the type of the reduction inputs)";
3288 return success();
3289 }
3290
3291
3292
3293
3294
3297 ValueRange inits, BodyBuilderFn beforeBuilder,
3298 BodyBuilderFn afterBuilder) {
3300 odsState.addTypes(resultTypes);
3301
3303
3304
3306 beforeArgLocs.reserve(inits.size());
3307 for (Value operand : inits) {
3308 beforeArgLocs.push_back(operand.getLoc());
3309 }
3310
3312 Block *beforeBlock = odsBuilder.createBlock(beforeRegion, {},
3313 inits.getTypes(), beforeArgLocs);
3314 if (beforeBuilder)
3316
3317
3319
3321 Block *afterBlock = odsBuilder.createBlock(afterRegion, {},
3322 resultTypes, afterArgLocs);
3323
3324 if (afterBuilder)
3326 }
3327
3328 ConditionOp WhileOp::getConditionOp() {
3329 return cast(getBeforeBody()->getTerminator());
3330 }
3331
3332 YieldOp WhileOp::getYieldOp() {
3333 return cast(getAfterBody()->getTerminator());
3334 }
3335
3336 std::optional<MutableArrayRef> WhileOp::getYieldedValuesMutable() {
3337 return getYieldOp().getResultsMutable();
3338 }
3339
3341 return getBeforeBody()->getArguments();
3342 }
3343
3345 return getAfterBody()->getArguments();
3346 }
3347
3349 return getBeforeArguments();
3350 }
3351
3353 assert(point == getBefore() &&
3354 "WhileOp is expected to branch only to the first region");
3355 return getInits();
3356 }
3357
3360
3362 regions.emplace_back(&getBefore(), getBefore().getArguments());
3363 return;
3364 }
3365
3366 assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3367 "there are only two regions in a WhileOp");
3368
3369 if (point == getAfter()) {
3370 regions.emplace_back(&getBefore(), getBefore().getArguments());
3371 return;
3372 }
3373
3374 regions.emplace_back(getResults());
3375 regions.emplace_back(&getAfter(), getAfter().getArguments());
3376 }
3377
3379 return {&getBefore(), &getAfter()};
3380 }
3381
3382
3383
3384
3385
3386
3387
3388
3394
3397 if (listResult.has_value() && failed(listResult.value()))
3398 return failure();
3399
3400 FunctionType functionType;
3403 return failure();
3404
3405 result.addTypes(functionType.getResults());
3406
3407 if (functionType.getNumInputs() != operands.size()) {
3408 return parser.emitError(typeLoc)
3409 << "expected as many input types as operands "
3410 << "(expected " << operands.size() << " got "
3411 << functionType.getNumInputs() << ")";
3412 }
3413
3414
3415 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3418 return failure();
3419
3420
3421 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3422 regionArgs[i].type = functionType.getInput(i);
3423
3424 return failure(parser.parseRegion(*before, regionArgs) ||
3427 }
3428
3429
3432 p << " : ";
3434 p << ' ';
3435 p.printRegion(getBefore(), false);
3436 p << " do ";
3439 }
3440
3441
3442
3443
3444 template
3446 TypeRange right, StringRef message) {
3447 if (left.size() != right.size())
3448 return op.emitOpError("expects the same number of ") << message;
3449
3450 for (unsigned i = 0, e = left.size(); i < e; ++i) {
3451 if (left[i] != right[i]) {
3453 << message;
3454 diag.attachNote() << "for argument " << i << ", found " << left[i]
3455 << " and " << right[i];
3456 return diag;
3457 }
3458 }
3459
3460 return success();
3461 }
3462
3464 auto beforeTerminator = verifyAndGetTerminatorscf::ConditionOp(
3465 *this, getBefore(),
3466 "expects the 'before' region to terminate with 'scf.condition'");
3467 if (!beforeTerminator)
3468 return failure();
3469
3470 auto afterTerminator = verifyAndGetTerminatorscf::YieldOp(
3471 *this, getAfter(),
3472 "expects the 'after' region to terminate with 'scf.yield'");
3473 return success(afterTerminator != nullptr);
3474 }
3475
3476 namespace {
3477
3478
3479
3480
3481
3482
3483
3484
3485
3486
3487
3488
3489
3490
3491
3492
3493
3494
3495
3496 struct WhileConditionTruth : public OpRewritePattern {
3498
3499 LogicalResult matchAndRewrite(WhileOp op,
3501 auto term = op.getConditionOp();
3502
3503
3504
3505 Value constantTrue = nullptr;
3506
3507 bool replaced = false;
3508 for (auto yieldedAndBlockArgs :
3509 llvm::zip(term.getArgs(), op.getAfterArguments())) {
3510 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3511 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3512 if (!constantTrue)
3513 constantTrue = rewriter.createarith::ConstantOp(
3514 op.getLoc(), term.getCondition().getType(),
3516
3518 constantTrue);
3519 replaced = true;
3520 }
3521 }
3522 }
3523 return success(replaced);
3524 }
3525 };
3526
3527
3528
3529
3530
3531
3532
3533
3534
3535
3536
3537
3538
3539
3540
3541
3542
3543
3544
3545
3546
3547
3548
3549
3550
3551
3552
3553
3554
3555
3556
3557
3558
3559
3560
3561
3562
3563
3564
3565
3566
3567
3568
3569
3570
3571
3572
3573
3574
3575 struct RemoveLoopInvariantArgsFromBeforeBlock
3578
3579 LogicalResult matchAndRewrite(WhileOp op,
3581 Block &afterBlock = *op.getAfterBody();
3583 ConditionOp condOp = op.getConditionOp();
3587
3588 bool canSimplify = false;
3589 for (const auto &it :
3590 llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3591 auto index = static_cast<unsigned>(it.index());
3592 auto [initVal, yieldOpArg] = it.value();
3593
3594
3595 if (yieldOpArg == initVal) {
3596 canSimplify = true;
3597 break;
3598 }
3599
3600
3601
3602
3603
3604 auto yieldOpBlockArg = llvm::dyn_cast(yieldOpArg);
3605 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3606 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3607 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3608 canSimplify = true;
3609 break;
3610 }
3611 }
3612 }
3613
3614 if (!canSimplify)
3615 return failure();
3616
3620 for (const auto &it :
3621 llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3622 auto index = static_cast<unsigned>(it.index());
3623 auto [initVal, yieldOpArg] = it.value();
3624
3625
3626
3627 if (yieldOpArg == initVal) {
3628 beforeBlockInitValMap.insert({index, initVal});
3629 continue;
3630 } else {
3631
3632
3633
3634
3635
3636 auto yieldOpBlockArg = llvm::dyn_cast(yieldOpArg);
3637 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3638 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3639 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3640 beforeBlockInitValMap.insert({index, initVal});
3641 continue;
3642 }
3643 }
3644 }
3645 newInitArgs.emplace_back(initVal);
3646 newYieldOpArgs.emplace_back(yieldOpArg);
3647 newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3648 }
3649
3650 {
3654 }
3655
3656 auto newWhile =
3657 rewriter.create(op.getLoc(), op.getResultTypes(), newInitArgs);
3658
3660 &newWhile.getBefore(), {},
3662
3663 Block &beforeBlock = *op.getBeforeBody();
3665
3666
3667
3668
3669
3670 for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
3671
3672
3673 if (beforeBlockInitValMap.count(i) != 0)
3674 newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3675 else
3676 newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
3677 }
3678
3679 rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3681 newWhile.getAfter().begin());
3682
3683 rewriter.replaceOp(op, newWhile.getResults());
3684 return success();
3685 }
3686 };
3687
3688
3689
3690
3691
3692
3693
3694
3695
3696
3697
3698
3699
3700
3701
3702
3703
3704
3705
3706
3707
3708
3709
3710
3711
3712
3713
3714
3715
3716
3717
3718
3719
3720
3721
3722
3723
3724
3725
3726
3727
3728 struct RemoveLoopInvariantValueYielded : public OpRewritePattern {
3730
3731 LogicalResult matchAndRewrite(WhileOp op,
3733 Block &beforeBlock = *op.getBeforeBody();
3734 ConditionOp condOp = op.getConditionOp();
3736
3737 bool canSimplify = false;
3738 for (Value condOpArg : condOpArgs) {
3739
3740
3741
3743 canSimplify = true;
3744 break;
3745 }
3746 }
3747
3748 if (!canSimplify)
3749 return failure();
3750
3752
3758 auto index = static_cast<unsigned>(it.index());
3759 Value condOpArg = it.value();
3760
3761
3762
3764 condOpInitValMap.insert({index, condOpArg});
3765 } else {
3766 newCondOpArgs.emplace_back(condOpArg);
3767 newAfterBlockType.emplace_back(condOpArg.getType());
3768 newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3769 }
3770 }
3771
3772 {
3775 rewriter.replaceOpWithNewOp(condOp, condOp.getCondition(),
3776 newCondOpArgs);
3777 }
3778
3779 auto newWhile = rewriter.create(op.getLoc(), newAfterBlockType,
3780 op.getOperands());
3781
3782 Block &newAfterBlock =
3783 *rewriter.createBlock(&newWhile.getAfter(), {},
3784 newAfterBlockType, newAfterBlockArgLocs);
3785
3786 Block &afterBlock = *op.getAfterBody();
3787
3788
3789
3790
3793 for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
3794 Value afterBlockArg, result;
3795
3796
3797 if (condOpInitValMap.count(i) != 0) {
3798 afterBlockArg = condOpInitValMap[i];
3799 result = afterBlockArg;
3800 } else {
3801 afterBlockArg = newAfterBlock.getArgument(j);
3802 result = newWhile.getResult(j);
3803 j++;
3804 }
3805 newAfterBlockArgs[i] = afterBlockArg;
3806 newWhileResults[i] = result;
3807 }
3808
3809 rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3811 newWhile.getBefore().begin());
3812
3813 rewriter.replaceOp(op, newWhileResults);
3814 return success();
3815 }
3816 };
3817
3818
3819
3820
3821
3822
3823
3824
3825
3826
3827
3828
3829
3830
3831
3832
3833
3834
3835
3836
3837
3838
3839
3840
3841
3842
3843
3844 struct WhileUnusedResult : public OpRewritePattern {
3846
3847 LogicalResult matchAndRewrite(WhileOp op,
3849 auto term = op.getConditionOp();
3850 auto afterArgs = op.getAfterArguments();
3851 auto termArgs = term.getArgs();
3852
3853
3858 bool needUpdate = false;
3859 for (const auto &it :
3860 llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
3861 auto i = static_cast<unsigned>(it.index());
3862 Value result = std::get<0>(it.value());
3863 Value afterArg = std::get<1>(it.value());
3864 Value termArg = std::get<2>(it.value());
3866 needUpdate = true;
3867 } else {
3868 newResultsIndices.emplace_back(i);
3869 newTermArgs.emplace_back(termArg);
3870 newResultTypes.emplace_back(result.getType());
3871 newArgLocs.emplace_back(result.getLoc());
3872 }
3873 }
3874
3875 if (!needUpdate)
3876 return failure();
3877
3878 {
3882 newTermArgs);
3883 }
3884
3885 auto newWhile =
3886 rewriter.create(op.getLoc(), newResultTypes, op.getInits());
3887
3889 &newWhile.getAfter(), {}, newResultTypes, newArgLocs);
3890
3891
3892
3895 for (const auto &it : llvm::enumerate(newResultsIndices)) {
3896 newResults[it.value()] = newWhile.getResult(it.index());
3897 newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3898 }
3899
3901 newWhile.getBefore().begin());
3902
3903 Block &afterBlock = *op.getAfterBody();
3904 rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3905
3906 rewriter.replaceOp(op, newResults);
3907 return success();
3908 }
3909 };
3910
3911
3912
3913
3914
3915
3916
3917
3918
3919
3920
3921
3922
3923
3924
3925
3926
3927
3928
3929
3930
3931
3932
3933 struct WhileCmpCond : public OpRewritePatternscf::WhileOp {
3935
3936 LogicalResult matchAndRewrite(scf::WhileOp op,
3938 using namespace scf;
3939 auto cond = op.getConditionOp();
3940 auto cmp = cond.getCondition().getDefiningOparith::CmpIOp();
3941 if (!cmp)
3942 return failure();
3944 for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3945 for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3946 if (std::get<0>(tup) != cmp.getOperand(opIdx))
3947 continue;
3949 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3950 auto cmp2 = dyn_castarith::CmpIOp(u.getOwner());
3951 if (!cmp2)
3952 continue;
3953
3954 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3955 continue;
3956 bool samePredicate;
3957 if (cmp2.getPredicate() == cmp.getPredicate())
3958 samePredicate = true;
3959 else if (cmp2.getPredicate() ==
3961 samePredicate = false;
3962 else
3963 continue;
3964
3965 rewriter.replaceOpWithNewOparith::ConstantIntOp(cmp2, samePredicate,
3966 1);
3968 }
3969 }
3970 }
3971 return success(changed);
3972 }
3973 };
3974
3975
3976 struct WhileRemoveUnusedArgs : public OpRewritePattern {
3978
3979 LogicalResult matchAndRewrite(WhileOp op,
3981
3982 if (!llvm::any_of(op.getBeforeArguments(),
3983 [](Value arg) { return arg.use_empty(); }))
3985
3986 YieldOp yield = op.getYieldOp();
3987
3988
3991 llvm::BitVector argsToErase;
3992
3993 size_t argsCount = op.getBeforeArguments().size();
3994 newYields.reserve(argsCount);
3995 newInits.reserve(argsCount);
3996 argsToErase.reserve(argsCount);
3997 for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3998 op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
3999 if (beforeArg.use_empty()) {
4000 argsToErase.push_back(true);
4001 } else {
4002 argsToErase.push_back(false);
4003 newYields.emplace_back(yieldValue);
4004 newInits.emplace_back(initValue);
4005 }
4006 }
4007
4008 Block &beforeBlock = *op.getBeforeBody();
4009 Block &afterBlock = *op.getAfterBody();
4010
4012
4014 auto newWhileOp =
4015 rewriter.create(loc, op.getResultTypes(), newInits,
4016 nullptr, nullptr);
4017 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4018 Block &newAfterBlock = *newWhileOp.getAfterBody();
4019
4023
4024 rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4025 newBeforeBlock.getArguments());
4026 rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
4028
4029 rewriter.replaceOp(op, newWhileOp.getResults());
4030 return success();
4031 }
4032 };
4033
4034
4035 struct WhileRemoveDuplicatedResults : public OpRewritePattern {
4037
4038 LogicalResult matchAndRewrite(WhileOp op,
4040 ConditionOp condOp = op.getConditionOp();
4041 ValueRange condOpArgs = condOp.getArgs();
4042
4044
4045 if (argsSet.size() == condOpArgs.size())
4047
4048 llvm::SmallDenseMap<Value, unsigned> argsMap;
4050 argsMap.reserve(condOpArgs.size());
4051 newArgs.reserve(condOpArgs.size());
4052 for (Value arg : condOpArgs) {
4053 if (!argsMap.count(arg)) {
4054 auto pos = static_cast<unsigned>(argsMap.size());
4055 argsMap.insert({arg, pos});
4056 newArgs.emplace_back(arg);
4057 }
4058 }
4059
4061
4063 auto newWhileOp = rewriter.createscf::WhileOp(
4064 loc, argsRange.getTypes(), op.getInits(), nullptr,
4065 nullptr);
4066 Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4067 Block &newAfterBlock = *newWhileOp.getAfterBody();
4068
4072 auto it = argsMap.find(arg);
4073 assert(it != argsMap.end());
4074 auto pos = it->second;
4075 afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
4076 resultsMapping.emplace_back(newWhileOp->getResult(pos));
4077 }
4078
4081 rewriter.replaceOpWithNewOp(condOp, condOp.getCondition(),
4082 argsRange);
4083
4084 Block &beforeBlock = *op.getBeforeBody();
4085 Block &afterBlock = *op.getAfterBody();
4086
4087 rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4088 newBeforeBlock.getArguments());
4089 rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4090 rewriter.replaceOp(op, resultsMapping);
4091 return success();
4092 }
4093 };
4094
4095
4096
4097 static std::optional<SmallVector> getArgsMapping(ValueRange args1,
4099 if (args1.size() != args2.size())
4100 return std::nullopt;
4101
4104 auto it = llvm::find(args2, arg1);
4105 if (it == args2.end())
4106 return std::nullopt;
4107
4108 ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
4109 }
4110
4111 return ret;
4112 }
4113
4114 static bool hasDuplicates(ValueRange args) {
4115 llvm::SmallDenseSet set;
4116 for (Value arg : args) {
4117 if (!set.insert(arg).second)
4118 return true;
4119 }
4120 return false;
4121 }
4122
4123
4124
4125
4126
4127 struct WhileOpAlignBeforeArgs : public OpRewritePattern {
4129
4130 LogicalResult matchAndRewrite(WhileOp loop,
4132 auto oldBefore = loop.getBeforeBody();
4133 ConditionOp oldTerm = loop.getConditionOp();
4134 ValueRange beforeArgs = oldBefore->getArguments();
4135 ValueRange termArgs = oldTerm.getArgs();
4136 if (beforeArgs == termArgs)
4137 return failure();
4138
4139 if (hasDuplicates(termArgs))
4140 return failure();
4141
4142 auto mapping = getArgsMapping(beforeArgs, termArgs);
4143 if (!mapping)
4144 return failure();
4145
4146 {
4149 rewriter.replaceOpWithNewOp(oldTerm, oldTerm.getCondition(),
4150 beforeArgs);
4151 }
4152
4153 auto oldAfter = loop.getAfterBody();
4154
4157 newResultTypes[j] = loop.getResult(i).getType();
4158
4159 auto newLoop = rewriter.create(
4160 loop.getLoc(), newResultTypes, loop.getInits(),
4161 nullptr, nullptr);
4162 auto newBefore = newLoop.getBeforeBody();
4163 auto newAfter = newLoop.getAfterBody();
4164
4168 newResults[i] = newLoop.getResult(j);
4169 newAfterArgs[i] = newAfter->getArgument(j);
4170 }
4171
4172 rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
4173 newBefore->getArguments());
4174 rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
4175 newAfterArgs);
4176
4177 rewriter.replaceOp(loop, newResults);
4178 return success();
4179 }
4180 };
4181 }
4182
4183 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
4185 results.add<RemoveLoopInvariantArgsFromBeforeBlock,
4186 RemoveLoopInvariantValueYielded, WhileConditionTruth,
4187 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4188 WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4189 }
4190
4191
4192
4193
4194
4195
4196 static ParseResult
4198 SmallVectorImpl<std::unique_ptr> &caseRegions) {
4201 int64_t value;
4202 Region ®ion = *caseRegions.emplace_back(std::make_unique());
4204 return failure();
4205 caseValues.push_back(value);
4206 }
4208 return success();
4209 }
4210
4211
4214 for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
4216 p << "case " << value << ' ';
4217 p.printRegion(*region, false);
4218 }
4219 }
4220
4222 if (getCases().size() != getCaseRegions().size()) {
4223 return emitOpError("has ")
4224 << getCaseRegions().size() << " case regions but "
4225 << getCases().size() << " case values";
4226 }
4227
4229 for (int64_t value : getCases())
4230 if (!valueSet.insert(value).second)
4231 return emitOpError("has duplicate case value: ") << value;
4232 auto verifyRegion = [&](Region ®ion, const Twine &name) -> LogicalResult {
4233 auto yield = dyn_cast(region.front().back());
4234 if (!yield)
4235 return emitOpError("expected region to end with scf.yield, but got ")
4237
4238 if (yield.getNumOperands() != getNumResults()) {
4239 return (emitOpError("expected each region to return ")
4240 << getNumResults() << " values, but " << name << " returns "
4241 << yield.getNumOperands())
4242 .attachNote(yield.getLoc())
4243 << "see yield operation here";
4244 }
4245 for (auto [idx, result, operand] :
4246 llvm::zip(llvm::seq(0, getNumResults()), getResultTypes(),
4247 yield.getOperandTypes())) {
4248 if (result == operand)
4249 continue;
4250 return (emitOpError("expected result #")
4251 << idx << " of each region to be " << result)
4252 .attachNote(yield.getLoc())
4253 << name << " returns " << operand << " here";
4254 }
4255 return success();
4256 };
4257
4258 if (failed(verifyRegion(getDefaultRegion(), "default region")))
4259 return failure();
4260 for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4261 if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
4262 return failure();
4263
4264 return success();
4265 }
4266
4267 unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
4268
4269 Block &scf::IndexSwitchOp::getDefaultBlock() {
4270 return getDefaultRegion().front();
4271 }
4272
4273 Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
4274 assert(idx < getNumCases() && "case index out-of-bounds");
4275 return getCaseRegions()[idx].front();
4276 }
4277
4278 void IndexSwitchOp::getSuccessorRegions(
4280
4282 successors.emplace_back(getResults());
4283 return;
4284 }
4285
4286 llvm::append_range(successors, getRegions());
4287 }
4288
4289 void IndexSwitchOp::getEntrySuccessorRegions(
4292 FoldAdaptor adaptor(operands, *this);
4293
4294
4295 auto arg = dyn_cast_or_null(adaptor.getArg());
4296 if (!arg) {
4297 llvm::append_range(successors, getRegions());
4298 return;
4299 }
4300
4301
4302
4303 for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4304 if (caseValue == arg.getInt()) {
4305 successors.emplace_back(&caseRegion);
4306 return;
4307 }
4308 }
4309 successors.emplace_back(&getDefaultRegion());
4310 }
4311
4312 void IndexSwitchOp::getRegionInvocationBounds(
4314 auto operandValue = llvm::dyn_cast_or_null(operands.front());
4315 if (!operandValue) {
4316
4317 bounds.append(getNumRegions(), InvocationBounds(0, 1));
4318 return;
4319 }
4320
4321 unsigned liveIndex = getNumRegions() - 1;
4322 const auto *it = llvm::find(getCases(), operandValue.getInt());
4323 if (it != getCases().end())
4324 liveIndex = std::distance(getCases().begin(), it);
4325 for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
4326 bounds.emplace_back(0, i == liveIndex);
4327 }
4328
4331
4334
4335
4337 if (!maybeCst.has_value())
4338 return failure();
4339 int64_t cst = *maybeCst;
4340 int64_t caseIdx, e = op.getNumCases();
4341 for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4342 if (cst == op.getCases()[caseIdx])
4343 break;
4344 }
4345
4346 Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4347 : op.getDefaultRegion();
4348 Block &source = r.front();
4351
4353 rewriter.eraseOp(terminator);
4354
4355
4356 rewriter.replaceOp(op, results);
4357
4358 return success();
4359 }
4360 };
4361
4362 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4365 }
4366
4367
4368
4369
4370
4371 #define GET_OP_CLASSES
4372 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
static LogicalResult verifyRegion(emitc::SwitchOp op, Region ®ion, const Twine &name)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of (inner = outer, inner2 = outer2,...
static TerminatorTy verifyAndGetTerminator(Operation *op, Region ®ion, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Region * getParentRegion()
Returns the region to which the instruction belongs.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
ArrayRef< T > asArrayRef() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop was known to have a single iteration.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which ShapedType::isDynamic is true will be replaced by dynamicValues.
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that the values have as many elements as the number of entries in attr for which isDynamic evaluates to true.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.