MLIR: lib/Dialect/SCF/IR/SCF.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

26 #include "llvm/ADT/MapVector.h"

27 #include "llvm/ADT/SmallPtrSet.h"

28 #include "llvm/ADT/TypeSwitch.h"

29

30 using namespace mlir;

32

33 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"

34

35

36

37

38

39 namespace {

42

43

45 IRMapping &valueMapping) const final {

46 return true;

47 }

48

49

51 return true;

52 }

53

54

// Handles the terminator `op` together with the values it is meant to
// replace. When `op` is an scf.yield, each value in `valuesToRepl` has all of
// its uses redirected to the yield operand at the same position. Any other
// kind of terminator is ignored.
55 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {

// Only scf.yield is handled here; dyn_cast yields a null op otherwise.
56 auto retValOp = dyn_cast<scf::YieldOp>(op);

57 if (!retValOp)

58 return;

59

// Pair each value to replace with the yield operand at the same index.
60 for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {

61 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));

62 }

63 }

64 };

65 }

66

67

68

69

70

// Dialect set-up: registers the SCF operations and interfaces, then declares
// the interfaces listed below as "promised" for the dialect / for the named
// ops.
71 void SCFDialect::initialize() {

// The operation list is supplied by including SCFOps.cpp.inc with
// GET_OP_LIST defined, directly inside the template argument list.
72 addOperations<

73 #define GET_OP_LIST

74 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"

75 >();

// NOTE(review): addInterfaces is called here without any template arguments,
// so no interface type is named at this call site. Presumably the argument
// list was lost when this listing was extracted; confirm against the
// original source.
76 addInterfaces();

// Promised on the dialect itself.
77 declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();

// Promised bufferization interfaces, per operation.
78 declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,

79 InParallelOp, ReduceReturnOp>();

80 declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,

81 ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,

82 ForallOp, InParallelOp, WhileOp, YieldOp>();

// Promised value-bounds interface for scf.for only.
83 declarePromisedInterface<ValueBoundsOpInterface, ForOp>();

84 }

85

86

88 builder.createscf::YieldOp(loc);

89 }

90

91

92

93 template

95 StringRef errorMessage) {

96 Operation *terminatorOperation = nullptr;

98 terminatorOperation = &region.front().back();

99 if (auto yield = dyn_cast_or_null(terminatorOperation))

100 return yield;

101 }

103 if (terminatorOperation)

104 diag.attachNote(terminatorOperation->getLoc()) << "terminator here";

105 return nullptr;

106 }

107

108

109

110

111

112

113

116 assert(llvm::hasSingleElement(region) && "expected single-region block");

122 rewriter.eraseOp(terminator);

123 }

124

125

126

127

128

129

130

131

132

133

134

135

139 return failure();

140

141

143 if (parser.parseRegion(*body, {}, {}) ||

145 return failure();

146

147 return success();

148 }

149

152

153 p << ' ';

155 false,

156 true);

157

159 }

160

162 if (getRegion().empty())

163 return emitOpError("region needs to have at least one block");

164 if (getRegion().front().getNumArguments() > 0)

165 return emitOpError("region cannot have any arguments");

166 return success();

167 }

168

169

170

171

172

173

174

175

176

177

178

179

180

181

182

185

188 if (!llvm::hasSingleElement(op.getRegion()))

189 return failure();

191 return success();

192 }

193 };

194

195

196

197

198

199

200

201

202

203

204

205

206

207

208

209

210

211

212

213

214

215

216

217

218

219

220

221

222

223

224

225

226

227

228

229

230

231

234

237 if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))

238 return failure();

239

240 Block *prevBlock = op->getBlock();

241 Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());

243

244 rewriter.createcf::BranchOp(op.getLoc(), &op.getRegion().front());

245

246 for (Block &blk : op.getRegion()) {

247 if (YieldOp yieldOp = dyn_cast(blk.getTerminator())) {

249 rewriter.createcf::BranchOp(yieldOp.getLoc(), postBlock,

250 yieldOp.getResults());

251 rewriter.eraseOp(yieldOp);

252 }

253 }

254

257

258 for (auto res : op.getResults())

259 blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));

260

261 rewriter.replaceOp(op, blockArgs);

262 return success();

263 }

264 };

265

266 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,

269 }

270

271 void ExecuteRegionOp::getSuccessorRegions(

273

276 return;

277 }

278

279

281 }

282

283

284

285

286

289 assert((point.isParent() || point == getParentOp().getAfter()) &&

290 "condition op can only exit the loop or branch to the after"

291 "region");

292

293 return getArgsMutable();

294 }

295

296 void ConditionOp::getSuccessorRegions(

298 FoldAdaptor adaptor(operands, *this);

299

300 WhileOp whileOp = getParentOp();

301

302

303

304 auto boolAttr = dyn_cast_or_null(adaptor.getCondition());

305 if (!boolAttr || boolAttr.getValue())

306 regions.emplace_back(&whileOp.getAfter(),

307 whileOp.getAfter().getArguments());

308 if (!boolAttr || !boolAttr.getValue())

309 regions.emplace_back(whileOp.getResults());

310 }

311

312

313

314

315

318 BodyBuilderFn bodyBuilder) {

320

323 for (Value v : initArgs)

324 result.addTypes(v.getType());

329 for (Value v : initArgs)

330 bodyBlock->addArgument(v.getType(), v.getLoc());

331

332

333

334

335 if (initArgs.empty() && !bodyBuilder) {

336 ForOp::ensureTerminator(*bodyRegion, builder, result.location);

337 } else if (bodyBuilder) {

342 }

343 }

344

346

347 if (getInitArgs().size() != getNumResults())

348 return emitOpError(

349 "mismatch in number of loop-carried values and defined values");

350

351 return success();

352 }

353

354 LogicalResult ForOp::verifyRegions() {

355

356

358 return emitOpError(

359 "expected induction variable to be same type as bounds and step");

360

361 if (getNumRegionIterArgs() != getNumResults())

362 return emitOpError(

363 "mismatch in number of basic block args and defined values");

364

365 auto initArgs = getInitArgs();

366 auto iterArgs = getRegionIterArgs();

367 auto opResults = getResults();

368 unsigned i = 0;

369 for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {

370 if (std::get<0>(e).getType() != std::get<2>(e).getType())

371 return emitOpError() << "types mismatch between " << i

372 << "th iter operand and defined value";

373 if (std::get<1>(e).getType() != std::get<2>(e).getType())

374 return emitOpError() << "types mismatch between " << i

375 << "th iter region arg and defined value";

376

377 ++i;

378 }

379 return success();

380 }

381

382 std::optional<SmallVector> ForOp::getLoopInductionVars() {

384 }

385

386 std::optional<SmallVector> ForOp::getLoopLowerBounds() {

388 }

389

390 std::optional<SmallVector> ForOp::getLoopSteps() {

392 }

393

394 std::optional<SmallVector> ForOp::getLoopUpperBounds() {

396 }

397

// Returns the results of this loop op, wrapped in an optional that is always
// engaged.
398 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }

399

400

401

403 std::optional<int64_t> tripCount =

405 if (!tripCount.has_value() || tripCount != 1)

406 return failure();

407

408

409 auto yieldOp = castscf::YieldOp(getBody()->getTerminator());

411

412

413

416 llvm::append_range(bbArgReplacements, getInitArgs());

417

418

419 rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),

420 getOperation()->getIterator(), bbArgReplacements);

421

422

423 rewriter.eraseOp(yieldOp);

425

426 return success();

427 }

428

429

430

431

432

436 StringRef prefix = "") {

437 assert(blocksArgs.size() == initializers.size() &&

438 "expected same length of arguments and initializers");

439 if (initializers.empty())

440 return;

441

442 p << prefix << '(';

443 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {

444 p << std::get<0>(it) << " = " << std::get<1>(it);

445 });

446 p << ")";

447 }

448

450 p << " " << getInductionVar() << " = " << getLowerBound() << " to "

452

454 if (!getInitArgs().empty())

455 p << " -> (" << getInitArgs().getTypes() << ')';

456 p << ' ';

458 p << " : " << t << ' ';

460 false,

461 !getInitArgs().empty());

463 }

464

468

471

472

474

478 return failure();

479

480

483 regionArgs.push_back(inductionVariable);

484

486 if (hasIterArgs) {

487

490 return failure();

491 }

492

493 if (regionArgs.size() != result.types.size() + 1)

496 "mismatch in number of loop-carried values and defined values");

497

498

502 return failure();

503

504

505 regionArgs.front().type = type;

506 for (auto [iterArg, type] :

507 llvm::zip_equal(llvm::drop_begin(regionArgs), result.types))

508 iterArg.type = type;

509

510

512 if (parser.parseRegion(*body, regionArgs))

513 return failure();

514 ForOp::ensureTerminator(*body, builder, result.location);

515

516

517

521 return failure();

522 if (hasIterArgs) {

523 for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),

524 operands, result.types)) {

525 Type type = std::get<2>(argOperandType);

526 std::get<0>(argOperandType).type = type;

527 if (parser.resolveOperand(std::get<1>(argOperandType), type,

529 return failure();

530 }

531 }

532

533

535 return failure();

536

537 return success();

538 }

539

541

543 return getBody()->getArguments().drop_front(getNumInductionVars());

544 }

545

547 return getInitArgsMutable();

548 }

549

550 FailureOr

551 ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,

553 bool replaceInitOperandUsesInLoop,

555

558 auto inits = llvm::to_vector(getInitArgs());

559 inits.append(newInitOperands.begin(), newInitOperands.end());

560 scf::ForOp newLoop = rewriter.createscf::ForOp(

564

565

566 auto yieldOp = castscf::YieldOp(getBody()->getTerminator());

568 newLoop.getBody()->getArguments().take_back(newInitOperands.size());

569 {

573 newYieldValuesFn(rewriter, getLoc(), newIterArgs);

574 assert(newInitOperands.size() == newYieldedValues.size() &&

575 "expected as many new yield values as new iter operands");

577 yieldOp.getResultsMutable().append(newYieldedValues);

578 });

579 }

580

581

582 rewriter.mergeBlocks(getBody(), newLoop.getBody(),

583 newLoop.getBody()->getArguments().take_front(

584 getBody()->getNumArguments()));

585

586 if (replaceInitOperandUsesInLoop) {

587

588

589 for (auto it : llvm::zip(newInitOperands, newIterArgs)) {

594 });

595 }

596 }

597

598

599 rewriter.replaceOp(getOperation(),

600 newLoop->getResults().take_front(getNumResults()));

601 return cast(newLoop.getOperation());

602 }

603

605 auto ivArg = llvm::dyn_cast(val);

606 if (!ivArg)

607 return ForOp();

608 assert(ivArg.getOwner() && "unlinked block argument");

609 auto *containingOp = ivArg.getOwner()->getParentOp();

610 return dyn_cast_or_null(containingOp);

611 }

612

614 return getInitArgs();

615 }

616

619

620

621

622 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));

624 }

625

627

628

629

631 for (auto [lb, ub, step] :

632 llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {

634 if (!tripCount.has_value() || *tripCount != 1)

635 return failure();

636 }

637

638 promote(rewriter, *this);

639 return success();

640 }

641

643 return getBody()->getArguments().drop_front(getRank());

644 }

645

647 return getOutputsMutable();

648 }

649

650

653 scf::InParallelOp terminator = forallOp.getTerminator();

654

655

656

657 SmallVector bbArgReplacements = forallOp.getLowerBound(rewriter);

658 bbArgReplacements.append(forallOp.getOutputs().begin(),

659 forallOp.getOutputs().end());

660

661

662 rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),

663 forallOp->getIterator(), bbArgReplacements);

664

665

668 results.reserve(forallOp.getResults().size());

669 for (auto &yieldingOp : terminator.getYieldingOps()) {

670 auto parallelInsertSliceOp =

671 casttensor::ParallelInsertSliceOp(yieldingOp);

672

673 Value dst = parallelInsertSliceOp.getDest();

674 Value src = parallelInsertSliceOp.getSource();

675 if (llvm::isa(src.getType())) {

676 results.push_back(rewriter.createtensor::InsertSliceOp(

677 forallOp.getLoc(), dst.getType(), src, dst,

678 parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),

679 parallelInsertSliceOp.getStrides(),

680 parallelInsertSliceOp.getStaticOffsets(),

681 parallelInsertSliceOp.getStaticSizes(),

682 parallelInsertSliceOp.getStaticStrides()));

683 } else {

684 llvm_unreachable("unsupported terminator");

685 }

686 }

688

689

690 rewriter.eraseOp(terminator);

691 rewriter.eraseOp(forallOp);

692 }

693

698 bodyBuilder) {

699 assert(lbs.size() == ubs.size() &&

700 "expected the same number of lower and upper bounds");

701 assert(lbs.size() == steps.size() &&

702 "expected the same number of lower bounds and steps");

703

704

705 if (lbs.empty()) {

707 bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)

709 assert(results.size() == iterArgs.size() &&

710 "loop nest body must return as many values as loop has iteration "

711 "arguments");

712 return LoopNest{{}, std::move(results)};

713 }

714

715

716

720 loops.reserve(lbs.size());

721 ivs.reserve(lbs.size());

722 ValueRange currentIterArgs = iterArgs;

724 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {

725 auto loop = builder.createscf::ForOp(

726 currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,

729 ivs.push_back(iv);

730

731

732 currentIterArgs = args;

733 currentLoc = nestedLoc;

734 });

735

736

737

739 loops.push_back(loop);

740 }

741

742

743 for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {

745 builder.createscf::YieldOp(loc, loops[i + 1].getResults());

746 }

747

748

749

752 ? bodyBuilder(builder, currentLoc, ivs,

753 loops.back().getRegionIterArgs())

755 assert(results.size() == iterArgs.size() &&

756 "loop nest body must return as many values as loop has iteration "

757 "arguments");

759 builder.createscf::YieldOp(loc, results);

760

761

763 llvm::append_range(nestResults, loops.front().getResults());

764 return LoopNest{std::move(loops), std::move(nestResults)};

765 }

766

771

772 return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,

773 [&bodyBuilder](OpBuilder &nestedBuilder,

776 if (bodyBuilder)

777 bodyBuilder(nestedBuilder, nestedLoc, ivs);

778 return {};

779 });

780 }

781

786 assert(operand.getOwner() == forOp);

788

789

790 assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&

791 "expected an iter OpOperand");

793 "Expected a different type");

795 for (OpOperand &opOperand : forOp.getInitArgsMutable()) {

796 if (opOperand.getOperandNumber() == operand.getOperandNumber()) {

797 newIterOperands.push_back(replacement);

798 continue;

799 }

800 newIterOperands.push_back(opOperand.get());

801 }

802

803

804 scf::ForOp newForOp = rewriter.createscf::ForOp(

805 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),

806 forOp.getStep(), newIterOperands);

807 newForOp->setAttrs(forOp->getAttrs());

808 Block &newBlock = newForOp.getRegion().front();

811

812

813

816 BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(

818 Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);

819 newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;

820

821

822 Block &oldBlock = forOp.getRegion().front();

823 rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);

824

825

826 auto clonedYieldOp = castscf::YieldOp(newBlock.getTerminator());

828 unsigned yieldIdx =

829 newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();

830 Value castOut = castFn(rewriter, newForOp.getLoc(), newType,

831 clonedYieldOp.getOperand(yieldIdx));

833 newYieldOperands[yieldIdx] = castOut;

834 rewriter.createscf::YieldOp(newForOp.getLoc(), newYieldOperands);

835 rewriter.eraseOp(clonedYieldOp);

836

837

840 newResults[yieldIdx] =

841 castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);

842

843 return newResults;

844 }

845

846 namespace {

847

848

849

850

851

852

853

854

855

856

857

858 struct ForOpIterArgsFolder : public OpRewritePatternscf::ForOp {

860

861 LogicalResult matchAndRewrite(scf::ForOp forOp,

863 bool canonicalize = false;

864

865

866

867

868

869

870 int64_t numResults = forOp.getNumResults();

872 keepMask.reserve(numResults);

874 newResultValues;

875 newBlockTransferArgs.reserve(1 + numResults);

876 newBlockTransferArgs.push_back(Value());

877 newIterArgs.reserve(forOp.getInitArgs().size());

878 newYieldValues.reserve(numResults);

879 newResultValues.reserve(numResults);

881 for (auto [init, arg, result, yielded] :

882 llvm::zip(forOp.getInitArgs(),

883 forOp.getRegionIterArgs(),

884 forOp.getResults(),

885 forOp.getYieldedValues()

886 )) {

887

888

889

890

891

892 bool forwarded = (arg == yielded) || (init == yielded) ||

893 (arg.use_empty() && result.use_empty());

894 if (forwarded) {

895 canonicalize = true;

896 keepMask.push_back(false);

897 newBlockTransferArgs.push_back(init);

898 newResultValues.push_back(init);

899 continue;

900 }

901

902

903

904 if (auto it = initYieldToArg.find({init, yielded});

905 it != initYieldToArg.end()) {

906 canonicalize = true;

907 keepMask.push_back(false);

908 auto [sameArg, sameResult] = it->second;

911

912 newBlockTransferArgs.push_back(init);

913 newResultValues.push_back(init);

914 continue;

915 }

916

917

918 initYieldToArg.insert({{init, yielded}, {arg, result}});

919 keepMask.push_back(true);

920 newIterArgs.push_back(init);

921 newYieldValues.push_back(yielded);

922 newBlockTransferArgs.push_back(Value());

923 newResultValues.push_back(Value());

924 }

925

926 if (!canonicalize)

927 return failure();

928

929 scf::ForOp newForOp = rewriter.createscf::ForOp(

930 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),

931 forOp.getStep(), newIterArgs);

932 newForOp->setAttrs(forOp->getAttrs());

933 Block &newBlock = newForOp.getRegion().front();

934

935

936 newBlockTransferArgs[0] = newBlock.getArgument(0);

937 for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();

938 idx != e; ++idx) {

939 Value &blockTransferArg = newBlockTransferArgs[1 + idx];

940 Value &newResultVal = newResultValues[idx];

941 assert((blockTransferArg && newResultVal) ||

942 (!blockTransferArg && !newResultVal));

943 if (!blockTransferArg) {

944 blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];

945 newResultVal = newForOp.getResult(collapsedIdx++);

946 }

947 }

948

949 Block &oldBlock = forOp.getRegion().front();

950 assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&

951 "unexpected argument size mismatch");

952

953

954

955

956 if (newIterArgs.empty()) {

957 auto newYieldOp = castscf::YieldOp(newBlock.getTerminator());

958 rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);

960 rewriter.replaceOp(forOp, newResultValues);

961 return success();

962 }

963

964

965 auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {

969 filteredOperands.reserve(newResultValues.size());

970 for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)

971 if (keepMask[idx])

972 filteredOperands.push_back(mergedTerminator.getOperand(idx));

973 rewriter.createscf::YieldOp(mergedTerminator.getLoc(),

974 filteredOperands);

975 };

976

977 rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);

978 auto mergedYieldOp = castscf::YieldOp(newBlock.getTerminator());

979 cloneFilteredTerminator(mergedYieldOp);

980 rewriter.eraseOp(mergedYieldOp);

981 rewriter.replaceOp(forOp, newResultValues);

982 return success();

983 }

984 };

985

986

987

988

989 static std::optional<int64_t> computeConstDiff(Value l, Value u) {

990 IntegerAttr clb, cub;

992 llvm::APInt lbValue = clb.getValue();

993 llvm::APInt ubValue = cub.getValue();

994 return (ubValue - lbValue).getSExtValue();

995 }

996

997

998 llvm::APInt diff;

1003 return diff.getSExtValue();

1004 return std::nullopt;

1005 }

1006

1007

1008

1009

1010 struct SimplifyTrivialLoops : public OpRewritePattern {

1012

1013 LogicalResult matchAndRewrite(ForOp op,

1015

1016

1017 if (op.getLowerBound() == op.getUpperBound()) {

1018 rewriter.replaceOp(op, op.getInitArgs());

1019 return success();

1020 }

1021

1022 std::optional<int64_t> diff =

1023 computeConstDiff(op.getLowerBound(), op.getUpperBound());

1024 if (!diff)

1025 return failure();

1026

1027

1028 if (*diff <= 0) {

1029 rewriter.replaceOp(op, op.getInitArgs());

1030 return success();

1031 }

1032

1033 std::optionalllvm::APInt maybeStepValue = op.getConstantStep();

1034 if (!maybeStepValue)

1035 return failure();

1036

1037

1038

1039 llvm::APInt stepValue = *maybeStepValue;

1040 if (stepValue.sge(*diff)) {

1042 blockArgs.reserve(op.getInitArgs().size() + 1);

1043 blockArgs.push_back(op.getLowerBound());

1044 llvm::append_range(blockArgs, op.getInitArgs());

1046 return success();

1047 }

1048

1049

1050 Block &block = op.getRegion().front();

1051 if (!llvm::hasSingleElement(block))

1052 return failure();

1053

1054

1055 if (llvm::any_of(op.getYieldedValues(),

1056 [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))

1057 return failure();

1058 rewriter.replaceOp(op, op.getYieldedValues());

1059 return success();

1060 }

1061 };

1062

1063

1064

1065

1066

1067

1068

1069

1070

1071

1072

1073

1074

1075

1076

1077

1078

1079

1080

1081

1082

1083

1084

1085

1086

1087

1088

1089 struct ForOpTensorCastFolder : public OpRewritePattern {

1091

1092 LogicalResult matchAndRewrite(ForOp op,

1094 for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {

1095 OpOperand &iterOpOperand = std::get<0>(it);

1096 auto incomingCast = iterOpOperand.get().getDefiningOptensor::CastOp();

1097 if (!incomingCast ||

1098 incomingCast.getSource().getType() == incomingCast.getType())

1099 continue;

1100

1101

1103 incomingCast.getDest().getType(),

1104 incomingCast.getSource().getType()))

1105 continue;

1106 if (!std::get<1>(it).hasOneUse())

1107 continue;

1108

1109

1112 rewriter, op, iterOpOperand, incomingCast.getSource(),

1114 return b.createtensor::CastOp(loc, type, source);

1115 }));

1116 return success();

1117 }

1118 return failure();

1119 }

1120 };

1121

1122 }

1123

1124 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,

1126 results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(

1127 context);

1128 }

1129

1130 std::optional ForOp::getConstantStep() {

1131 IntegerAttr step;

1133 return step.getValue();

1134 return {};

1135 }

1136

// Returns the mutable operand range of the scf.yield that terminates the
// loop body. The terminator is cast unconditionally, so the body block is
// expected to end in scf.yield.
1137 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {

1138 return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();

1139 }

1140

1142

1143

1144 if (auto constantStep = getConstantStep())

1145 if (*constantStep == 1)

1147

1148

1149

1151 }

1152

1153

1154

1155

1156

1158 unsigned numLoops = getRank();

1159

1160 if (getNumResults() != getOutputs().size())

1161 return emitOpError("produces ")

1162 << getNumResults() << " results, but has only "

1163 << getOutputs().size() << " outputs";

1164

1165

1166 auto *body = getBody();

1167 if (body->getNumArguments() != numLoops + getOutputs().size())

1168 return emitOpError("region expects ") << numLoops << " arguments";

1169 for (int64_t i = 0; i < numLoops; ++i)

1171 return emitOpError("expects ")

1172 << i << "-th block argument to be an index";

1173 for (unsigned i = 0; i < getOutputs().size(); ++i)

1174 if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())

1175 return emitOpError("type mismatch between ")

1176 << i << "-th output and corresponding block argument";

1177 if (getMapping().has_value() && !getMapping()->empty()) {

1178 if (static_cast<int64_t>(getMapping()->size()) != numLoops)

1179 return emitOpError() << "mapping attribute size must match op rank";

1180 for (auto map : getMapping()->getValue()) {

1181 if (!isa(map))

1182 return emitOpError()

1184 }

1185 }

1186

1187

1190 getStaticLowerBound(),

1191 getDynamicLowerBound())))

1192 return failure();

1194 getStaticUpperBound(),

1195 getDynamicUpperBound())))

1196 return failure();

1198 getStaticStep(), getDynamicStep())))

1199 return failure();

1200

1201 return success();

1202 }

1203

1206 p << " (" << getInductionVars();

1207 if (isNormalized()) {

1208 p << ") in ";

1210 {}, {},

1212 } else {

1213 p << ") = ";

1215 {}, {},

1217 p << " to ";

1219 {}, {},

1221 p << " step ";

1223 {}, {},

1225 }

1227 p << " ";

1228 if (!getRegionOutArgs().empty())

1229 p << "-> (" << getResultTypes() << ") ";

1230 p.printRegion(getRegion(),

1231 false,

1232 getNumResults() > 0);

1233 p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),

1234 getStaticLowerBoundAttrName(),

1235 getStaticUpperBoundAttrName(),

1236 getStaticStepAttrName()});

1237 }

1238

1241 auto indexType = b.getIndexType();

1242

1243

1244

1245

1248 return failure();

1249

1252 dynamicSteps;

1254

1256 nullptr,

1259 return failure();

1260

1261 unsigned numLoops = ivs.size();

1264 } else {

1265

1268 nullptr,

1270

1272 return failure();

1273

1274

1277 nullptr,

1280 return failure();

1281

1282

1285 nullptr,

1288 return failure();

1289 }

1290

1291

1296 if (outOperands.size() != result.types.size())

1297 return parser.emitError(outOperandsLoc,

1298 "mismatch between out operands and types");

1303 return failure();

1304 }

1305

1306

1308 std::unique_ptr region = std::make_unique();

1309 for (auto &iv : ivs) {

1310 iv.type = b.getIndexType();

1311 regionArgs.push_back(iv);

1312 }

1314 auto &out = it.value();

1315 out.type = result.types[it.index()];

1316 regionArgs.push_back(out);

1317 }

1318 if (parser.parseRegion(*region, regionArgs))

1319 return failure();

1320

1321

1322 ForallOp::ensureTerminator(*region, b, result.location);

1323 result.addRegion(std::move(region));

1324

1325

1327 return failure();

1328

1329 result.addAttribute("staticLowerBound", staticLbs);

1330 result.addAttribute("staticUpperBound", staticUbs);

1331 result.addAttribute("staticStep", staticSteps);

1334 {static_cast<int32_t>(dynamicLbs.size()),

1335 static_cast<int32_t>(dynamicUbs.size()),

1336 static_cast<int32_t>(dynamicSteps.size()),

1337 static_cast<int32_t>(outOperands.size())}));

1338 return success();

1339 }

1340

1341

1342 void ForallOp::build(

1346 std::optional mapping,

1353

1359

1360 result.addAttribute(getStaticLowerBoundAttrName(result.name),

1362 result.addAttribute(getStaticUpperBoundAttrName(result.name),

1367 "operandSegmentSizes",

1369 static_cast<int32_t>(dynamicUbs.size()),

1370 static_cast<int32_t>(dynamicSteps.size()),

1371 static_cast<int32_t>(outputs.size())}));

1372 if (mapping.has_value()) {

1374 mapping.value());

1375 }

1376

1380 Block &bodyBlock = bodyRegion->front();

1381

1382

1389

1391 if (!bodyBuilderFn) {

1392 ForallOp::ensureTerminator(*bodyRegion, b, result.location);

1393 return;

1394 }

1396 }

1397

1398

1399 void ForallOp::build(

1402 std::optional mapping,

1404 unsigned numLoops = ubs.size();

1407 build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);

1408 }

1409

1410

1411 bool ForallOp::isNormalized() {

1413 return llvm::all_of(results, [&](OpFoldResult ofr) {

1415 return intValue.has_value() && intValue == val;

1416 });

1417 };

1418 return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);

1419 }

1420

// Returns the terminator of the body block as an InParallelOp. The cast is
// unconditional, so the body is expected to end in that op.
1421 InParallelOp ForallOp::getTerminator() {

1422 return cast<InParallelOp>(getBody()->getTerminator());

1423 }

1424

1427 InParallelOp inParallelOp = getTerminator();

1428 for (Operation &yieldOp : inParallelOp.getYieldingOps()) {

1429 if (auto parallelInsertSliceOp =

1430 dyn_casttensor::ParallelInsertSliceOp(yieldOp);

1431 parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {

1432 storeOps.push_back(parallelInsertSliceOp);

1433 }

1434 }

1435 return storeOps;

1436 }

1437

// Returns the induction variables of the loop: the first getRank() arguments
// of the body block, copied into an owning vector.
1438 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {

1439 return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};

1440 }

1441

1442

1443 std::optional<SmallVector> ForallOp::getLoopLowerBounds() {

1445 return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);

1446 }

1447

1448

1449 std::optional<SmallVector> ForallOp::getLoopUpperBounds() {

1451 return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);

1452 }

1453

1454

1455 std::optional<SmallVector> ForallOp::getLoopSteps() {

1457 return getMixedValues(getStaticStep(), getDynamicStep(), b);

1458 }

1459

1461 auto tidxArg = llvm::dyn_cast(val);

1462 if (!tidxArg)

1463 return ForallOp();

1464 assert(tidxArg.getOwner() && "unlinked block argument");

1465 auto *containingOp = tidxArg.getOwner()->getParentOp();

1466 return dyn_cast(containingOp);

1467 }

1468

1469 namespace {

1470

1471 struct DimOfForallOp : public OpRewritePatterntensor::DimOp {

1473

1474 LogicalResult matchAndRewrite(tensor::DimOp dimOp,

1476 auto forallOp = dimOp.getSource().getDefiningOp();

1477 if (!forallOp)

1478 return failure();

1479 Value sharedOut =

1480 forallOp.getTiedOpOperand(llvm::cast(dimOp.getSource()))

1481 ->get();

1483 dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });

1484 return success();

1485 }

1486 };

1487

1488 class ForallOpControlOperandsFolder : public OpRewritePattern {

1489 public:

1491

1492 LogicalResult matchAndRewrite(ForallOp op,

1500 return failure();

1501

1503 SmallVector dynamicLowerBound, dynamicUpperBound, dynamicStep;

1506 staticLowerBound);

1507 op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);

1508 op.setStaticLowerBound(staticLowerBound);

1509

1511 staticUpperBound);

1512 op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);

1513 op.setStaticUpperBound(staticUpperBound);

1514

1516 op.getDynamicStepMutable().assign(dynamicStep);

1517 op.setStaticStep(staticStep);

1518

1519 op->setAttr(ForallOp::getOperandSegmentSizeAttr(),

1521 {static_cast<int32_t>(dynamicLowerBound.size()),

1522 static_cast<int32_t>(dynamicUpperBound.size()),

1523 static_cast<int32_t>(dynamicStep.size()),

1524 static_cast<int32_t>(op.getNumResults())}));

1525 });

1526 return success();

1527 }

1528 };

1529

1530

1531

1532

1533

1534

1535

1536

1537

1538

1539

1540

1541

1542

1543

1544

1545

1546

1547

1548

1549

1550

1551

1552

1553

1554

1555

1556

1557

1558

1559

1560

1561

1562

1563

1564

1565

1566

1567

1568

1569

1570

1571

1572

1573

1574

1575

1576

1577

1578

1579

1580

1581

1582

1583

1584

1585

1586

1587

1588

1589

1590

1591

1592

1593

1594

1595

1596

1597

1598

1599

1600

1601

1602

1603 struct ForallOpIterArgsFolder : public OpRewritePattern {

1605

1606 LogicalResult matchAndRewrite(ForallOp forallOp,

1608

1609

1610

1611

1612

1613

1614

1615

1616

1617

1618

1619

1620

1621

1625 for (OpResult result : forallOp.getResults()) {

1626 OpOperand *opOperand = forallOp.getTiedOpOperand(result);

1627 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);

1628 if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {

1629 resultToDelete.insert(result);

1630 } else {

1631 resultToReplace.push_back(result);

1632 newOuts.push_back(opOperand->get());

1633 }

1634 }

1635

1636

1637

1638 if (resultToDelete.empty())

1639 return failure();

1640

1641

1642

1643

1644

1645

1646 for (OpResult result : resultToDelete) {

1647 OpOperand *opOperand = forallOp.getTiedOpOperand(result);

1648 BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);

1650 forallOp.getCombiningOps(blockArg);

1651 for (Operation *combiningOp : combiningOps)

1652 rewriter.eraseOp(combiningOp);

1653 }

1654

1655

1656

1657 auto newForallOp = rewriter.createscf::ForallOp(

1658 forallOp.getLoc(), forallOp.getMixedLowerBound(),

1659 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,

1660 forallOp.getMapping(),

1662

1663

1664

1665 Block *loopBody = forallOp.getBody();

1666 Block *newLoopBody = newForallOp.getBody();

1668

1669

1671 llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),

1674 unsigned index = 0;

1675

1676

1677

1678 for (OpResult result : forallOp.getResults()) {

1679 if (resultToDelete.count(result)) {

1680 newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());

1681 } else {

1682 newBlockArgs.push_back(newSharedOutsArgs[index++]);

1683 }

1684 }

1685 rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);

1686

1687

1688

1689 for (auto &&[oldResult, newResult] :

1690 llvm::zip(resultToReplace, newForallOp->getResults()))

1692

1693

1694

1695

1696 for (OpResult oldResult : resultToDelete)

1698 forallOp.getTiedOpOperand(oldResult)->get());

1699 return success();

1700 }

1701 };

1702

1703 struct ForallOpSingleOrZeroIterationDimsFolder

1706

1707 LogicalResult matchAndRewrite(ForallOp op,

1709

1710 if (op.getMapping().has_value() && !op.getMapping()->empty())

1711 return failure();

1713

1714

1716 newMixedSteps;

1718 for (auto [lb, ub, step, iv] :

1719 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),

1720 op.getMixedStep(), op.getInductionVars())) {

1722 if (numIterations.has_value()) {

1723

1724 if (*numIterations == 0) {

1725 rewriter.replaceOp(op, op.getOutputs());

1726 return success();

1727 }

1728

1729

1730 if (*numIterations == 1) {

1732 continue;

1733 }

1734 }

1735 newMixedLowerBounds.push_back(lb);

1736 newMixedUpperBounds.push_back(ub);

1737 newMixedSteps.push_back(step);

1738 }

1739

1740

1741 if (newMixedLowerBounds.empty()) {

1743 return success();

1744 }

1745

1746

1747 if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {

1749 op, "no dimensions have 0 or 1 iterations");

1750 }

1751

1752

1753 ForallOp newOp;

1754 newOp = rewriter.create(loc, newMixedLowerBounds,

1755 newMixedUpperBounds, newMixedSteps,

1756 op.getOutputs(), std::nullopt, nullptr);

1757 newOp.getBodyRegion().getBlocks().clear();

1758

1759

1760

1762 newOp.getStaticLowerBoundAttrName(),

1763 newOp.getStaticUpperBoundAttrName(),

1764 newOp.getStaticStepAttrName()};

1765 for (const auto &namedAttr : op->getAttrs()) {

1766 if (llvm::is_contained(elidedAttrs, namedAttr.getName()))

1767 continue;

1769 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());

1770 });

1771 }

1773 newOp.getRegion().begin(), mapping);

1774 rewriter.replaceOp(op, newOp.getResults());

1775 return success();

1776 }

1777 };

1778

1779

1780 struct ForallOpReplaceConstantInductionVar : public OpRewritePattern {

1782

1783 LogicalResult matchAndRewrite(ForallOp op,

1787 for (auto [lb, ub, step, iv] :

1788 llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),

1789 op.getMixedStep(), op.getInductionVars())) {

1790 if (iv.hasNUses(0))

1791 continue;

1793 if (!numIterations.has_value() || numIterations.value() != 1) {

1794 continue;

1795 }

1799 }

1800 return success(changed);

1801 }

1802 };

1803

1804 struct FoldTensorCastOfOutputIntoForallOp

1807

1808 struct TypeCast {

1809 Type srcType;

1810 Type dstType;

1811 };

1812

1813 LogicalResult matchAndRewrite(scf::ForallOp forallOp,

1815 llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;

1818 auto castOp = en.value().getDefiningOptensor::CastOp();

1819 if (!castOp)

1820 continue;

1821

1822

1823

1825 castOp.getSource().getType())) {

1826 continue;

1827 }

1828

1829 tensorCastProducers[en.index()] =

1830 TypeCast{castOp.getSource().getType(), castOp.getType()};

1831 newOutputTensors[en.index()] = castOp.getSource();

1832 }

1833

1834 if (tensorCastProducers.empty())

1835 return failure();

1836

1837

1838 Location loc = forallOp.getLoc();

1839 auto newForallOp = rewriter.create(

1840 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),

1841 forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),

1843 auto castBlockArgs =

1844 llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));

1845 for (auto [index, cast] : tensorCastProducers) {

1846 Value &oldTypeBBArg = castBlockArgs[index];

1847 oldTypeBBArg = nestedBuilder.createtensor::CastOp(

1848 nestedLoc, cast.dstType, oldTypeBBArg);

1849 }

1850

1851

1853 llvm::to_vector(bbArgs.take_front(forallOp.getRank()));

1854 ivsBlockArgs.append(castBlockArgs);

1855 rewriter.mergeBlocks(forallOp.getBody(),

1856 bbArgs.front().getParentBlock(), ivsBlockArgs);

1857 });

1858

1859

1860

1861

1862 auto terminator = newForallOp.getTerminator();

1863 for (auto [yieldingOp, outputBlockArg] : llvm::zip(

1864 terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {

1865 auto insertSliceOp = casttensor::ParallelInsertSliceOp(yieldingOp);

1866 insertSliceOp.getDestMutable().assign(outputBlockArg);

1867 }

1868

1869

1872 for (auto &item : tensorCastProducers) {

1873 Value &oldTypeResult = castResults[item.first];

1874 oldTypeResult = rewriter.createtensor::CastOp(loc, item.second.dstType,

1875 oldTypeResult);

1876 }

1877 rewriter.replaceOp(forallOp, castResults);

1878 return success();

1879 }

1880 };

1881

1882 }

1883

1884 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,

1886 results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,

1887 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,

1888 ForallOpSingleOrZeroIterationDimsFolder,

1889 ForallOpReplaceConstantInductionVar>(context);

1890 }

1891

1892

1893

1894

1895

1896

1899

1900

1901

1904 }

1905

1906

1907

1908

1909

1910

1915 }

1916

1918 scf::ForallOp forallOp =

1919 dyn_castscf::ForallOp(getOperation()->getParentOp());

1920 if (!forallOp)

1921 return this->emitOpError("expected forall op parent");

1922

1923

1924 for (Operation &op : getRegion().front().getOperations()) {

1925 if (!isatensor::ParallelInsertSliceOp(op)) {

1926 return this->emitOpError("expected only ")

1927 << tensor::ParallelInsertSliceOp::getOperationName() << " ops";

1928 }

1929

1930

1931 Value dest = casttensor::ParallelInsertSliceOp(op).getDest();

1933 if (!llvm::is_contained(regionOutArgs, dest))

1934 return op.emitOpError("may only insert into an output block argument");

1935 }

1936 return success();

1937 }

1938

1940 p << " ";

1942 false,

1943 false);

1945 }

1946

1948 auto &builder = parser.getBuilder();

1949

1951 std::unique_ptr region = std::make_unique();

1952 if (parser.parseRegion(*region, regionOperands))

1953 return failure();

1954

1955 if (region->empty())

1957 result.addRegion(std::move(region));

1958

1959

1961 return failure();

1962 return success();

1963 }

1964

1965 OpResult InParallelOp::getParentResult(int64_t idx) {

1966 return getOperation()->getParentOp()->getResult(idx);

1967 }

1968

1970 return llvm::to_vector<4>(

1971 llvm::map_range(getYieldingOps(), [](Operation &op) {

1972

1973 auto insertSliceOp = casttensor::ParallelInsertSliceOp(&op);

1974 return llvm::cast(insertSliceOp.getDest());

1975 }));

1976 }

1977

1979 return getRegion().front().getOperations();

1980 }

1981

1982

1983

1984

1985

1987 assert(a && "expected non-empty operation");

1988 assert(b && "expected non-empty operation");

1989

1991 while (ifOp) {

1992

1993 if (ifOp->isProperAncestor(b))

1994

1995

1996 return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=

1997 static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));

1998

1999 ifOp = ifOp->getParentOfType();

2000 }

2001

2002

2003 return false;

2004 }

2005

// Infers the result types of the op from the operand types of the last
// operation in the first block of the "then" region, appending them to
// `inferredReturnTypes`. Returns failure() when no types can be inferred.
// NOTE(review): the remaining parameter line(s) of the signature are not
// visible in this listing.
2006 LogicalResult

2007 IfOp::inferReturnTypes(MLIRContext *ctx, std::optional loc,

2008 IfOp::Adaptor adaptor,

// Nothing to infer from when the adaptor carries no regions.
2010 if (adaptor.getRegions().empty())

2011 return failure();

2012 Region *r = &adaptor.getThenRegion();

// A "then" region without any block has no terminator to inspect.
2013 if (r->empty())

2014 return failure();

2015 Block &b = r->front();

// NOTE(review): the condition guarding this early return is not visible in
// this listing; presumably it rejects an empty block before `b.back()` is
// taken below -- confirm against the upstream source.
2017 return failure();

// NOTE(review): the dyn_cast target type is not visible here; the variable
// name suggests a yield op -- confirm.
2018 auto yieldOp = llvm::dyn_cast(b.back());

2019 if (!yieldOp)

2020 return failure();

// The inferred result types are exactly the operand types of that op.
2021 TypeRange types = yieldOp.getOperandTypes();

2022 llvm::append_range(inferredReturnTypes, types);

2023 return success();

2024 }

2025

2028 return build(builder, result, resultTypes, cond, false,

2029 false);

2030 }

2031

2033 TypeRange resultTypes, Value cond, bool addThenBlock,

2034 bool addElseBlock) {

2035 assert((!addElseBlock || addThenBlock) &&

2036 "must not create else block w/o then block");

2037 result.addTypes(resultTypes);

2039

2040

2043 if (addThenBlock)

2046 if (addElseBlock)

2048 }

2049

2051 bool withElseRegion) {

2052 build(builder, result, TypeRange{}, cond, withElseRegion);

2053 }

2054

2056 TypeRange resultTypes, Value cond, bool withElseRegion) {

2057 result.addTypes(resultTypes);

2059

2060

2064 if (resultTypes.empty())

2065 IfOp::ensureTerminator(*thenRegion, builder, result.location);

2066

2067

2069 if (withElseRegion) {

2071 if (resultTypes.empty())

2072 IfOp::ensureTerminator(*elseRegion, builder, result.location);

2073 }

2074 }

2075

2079 assert(thenBuilder && "the builder callback for 'then' must be present");

2081

2082

2086 thenBuilder(builder, result.location);

2087

2088

2090 if (elseBuilder) {

2092 elseBuilder(builder, result.location);

2093 }

2094

2095

2099 if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,

2100 nullptr, result.regions,

2101 inferredReturnTypes))) {

2102 result.addTypes(inferredReturnTypes);

2103 }

2104 }

2105

2107 if (getNumResults() != 0 && getElseRegion().empty())

2108 return emitOpError("must have an else block if defining values");

2109 return success();

2110 }

2111

2113

2114 result.regions.reserve(2);

2117

2118 auto &builder = parser.getBuilder();

2123 return failure();

2124

2126 return failure();

2127

2128 if (parser.parseRegion(*thenRegion, {}, {}))

2129 return failure();

2130 IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);

2131

2132

2134 if (parser.parseRegion(*elseRegion, {}, {}))

2135 return failure();

2136 IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);

2137 }

2138

2139

2141 return failure();

2142 return success();

2143 }

2144

2146 bool printBlockTerminators = false;

2147

2148 p << " " << getCondition();

2149 if (!getResults().empty()) {

2150 p << " -> (" << getResultTypes() << ")";

2151

2152 printBlockTerminators = true;

2153 }

2154 p << ' ';

2156 false,

2157 printBlockTerminators);

2158

2159

2160 auto &elseRegion = getElseRegion();

2161 if (!elseRegion.empty()) {

2162 p << " else ";

2164 false,

2165 printBlockTerminators);

2166 }

2167

2169 }

2170

2173

2176 return;

2177 }

2178

2180

2181

2182 Region *elseRegion = &this->getElseRegion();

2183 if (elseRegion->empty())

2185 else

2187 }

2188

2191 FoldAdaptor adaptor(operands, *this);

2192 auto boolAttr = dyn_cast_or_null(adaptor.getCondition());

2193 if (!boolAttr || boolAttr.getValue())

2194 regions.emplace_back(&getThenRegion());

2195

2196

2197 if (!boolAttr || !boolAttr.getValue()) {

2198 if (!getElseRegion().empty())

2199 regions.emplace_back(&getElseRegion());

2200 else

2201 regions.emplace_back(getResults());

2202 }

2203 }

2204

// In-place fold for an if whose condition is produced by an arith::XOrIOp:
// the condition is replaced by the xor's left-hand operand and the blocks of
// the "then" and "else" regions are exchanged. Returns failure() when the
// fold does not apply.
// NOTE(review): the second parameter line of the signature is not visible in
// this listing.
2205 LogicalResult IfOp::fold(FoldAdaptor adaptor,

2207

// Without an else region there is nothing to exchange with.
2208 if (getElseRegion().empty())

2209 return failure();

2210

2211 arith::XOrIOp xorStmt = getCondition().getDefiningOparith::XOrIOp();

2212 if (!xorStmt)

2213 return failure();

2214

// NOTE(review): the condition guarding this return is not visible in this
// listing; presumably it checks that the xor's other operand makes the xor a
// logical negation -- confirm against the upstream source.
2216 return failure();

2217

// Use the xor's left-hand operand as the condition from now on.
2218 getConditionMutable().assign(xorStmt.getLhs());

// Remember the original then block before the block lists are modified.
2219 Block *thenBlock = &getThenRegion().front();

2220

2221

// First move every else block to the front of the then region ...
2222 getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),

2223 getElseRegion().getBlocks());

// ... then move the remembered original then block into the (now empty)
// else region.
2224 getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),

2225 getThenRegion().getBlocks(), thenBlock);

2226 return success();

2227 }

2228

// Fills `invocationBounds` with two (lower, upper) invocation-count entries,
// based on whether the first operand is a known constant.
// NOTE(review): the parameter lines of the signature are not visible in this
// listing; the two entries presumably correspond to the then and else regions
// in that order -- confirm.
2229 void IfOp::getRegionInvocationBounds(

// NOTE(review): the dyn_cast_or_null target type is not visible here; the
// result's getValue() is used as a boolean below.
2232 if (auto cond = llvm::dyn_cast_or_null(operands[0])) {

2233

2234

// Known condition: the first entry gets upper bound 1 when the value is
// true, the second entry gets upper bound 1 when it is false; the other
// entry is bounded to 0.
2235 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);

2236 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);

2237 } else {

2238

// Unknown condition: both entries are bounded to [0, 1].
2239 invocationBounds.assign(2, {0, 1});

2240 }

2241 }

2242

2243 namespace {

2244

2245 struct RemoveUnusedResults : public OpRewritePattern {

2247

2250

2252

2253 auto yieldOp = castscf::YieldOp(dest->getTerminator());

2255 llvm::transform(usedResults, std::back_inserter(usedOperands),

2258 });

2260 [&]() { yieldOp->setOperands(usedOperands); });

2261 }

2262

2263 LogicalResult matchAndRewrite(IfOp op,

2265

2267 llvm::copy_if(op.getResults(), std::back_inserter(usedResults),

2268 [](OpResult result) { return !result.use_empty(); });

2269

2270

2271 if (usedResults.size() == op.getNumResults())

2272 return failure();

2273

2274

2276 llvm::transform(usedResults, std::back_inserter(newTypes),

2278

2279

2280 auto newOp =

2281 rewriter.create(op.getLoc(), newTypes, op.getCondition());

2282 rewriter.createBlock(&newOp.getThenRegion());

2283 rewriter.createBlock(&newOp.getElseRegion());

2284

2285

2286

2287 transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);

2288 transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);

2289

2290

2293 repResults[en.value().getResultNumber()] = newOp.getResult(en.index());

2294 rewriter.replaceOp(op, repResults);

2295 return success();

2296 }

2297 };

2298

2299 struct RemoveStaticCondition : public OpRewritePattern {

2301

2302 LogicalResult matchAndRewrite(IfOp op,

2306 return failure();

2307

2310 else if (!op.getElseRegion().empty())

2312 else

2314

2315 return success();

2316 }

2317 };

2318

2319

2320

2321 struct ConvertTrivialIfToSelect : public OpRewritePattern {

2323

2324 LogicalResult matchAndRewrite(IfOp op,

2326 if (op->getNumResults() == 0)

2327 return failure();

2328

2329 auto cond = op.getCondition();

2330 auto thenYieldArgs = op.thenYield().getOperands();

2331 auto elseYieldArgs = op.elseYield().getOperands();

2332

2334 for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {

2335 if (&op.getThenRegion() == trueVal.getParentRegion() ||

2336 &op.getElseRegion() == falseVal.getParentRegion())

2337 nonHoistable.push_back(trueVal.getType());

2338 }

2339

2340

2341 if (nonHoistable.size() == op->getNumResults())

2342 return failure();

2343

2344 IfOp replacement = rewriter.create(op.getLoc(), nonHoistable, cond,

2345 false);

2346 if (replacement.thenBlock())

2347 rewriter.eraseBlock(replacement.thenBlock());

2348 replacement.getThenRegion().takeBody(op.getThenRegion());

2349 replacement.getElseRegion().takeBody(op.getElseRegion());

2350

2352 assert(thenYieldArgs.size() == results.size());

2353 assert(elseYieldArgs.size() == results.size());

2354

2358 for (const auto &it :

2359 llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {

2360 Value trueVal = std::get<0>(it.value());

2361 Value falseVal = std::get<1>(it.value());

2362 if (&replacement.getThenRegion() == trueVal.getParentRegion() ||

2363 &replacement.getElseRegion() == falseVal.getParentRegion()) {

2364 results[it.index()] = replacement.getResult(trueYields.size());

2365 trueYields.push_back(trueVal);

2366 falseYields.push_back(falseVal);

2367 } else if (trueVal == falseVal)

2368 results[it.index()] = trueVal;

2369 else

2370 results[it.index()] = rewriter.createarith::SelectOp(

2371 op.getLoc(), cond, trueVal, falseVal);

2372 }

2373

2375 rewriter.replaceOpWithNewOp(replacement.thenYield(), trueYields);

2376

2378 rewriter.replaceOpWithNewOp(replacement.elseYield(), falseYields);

2379

2380 rewriter.replaceOp(op, results);

2381 return success();

2382 }

2383 };

2384

2385

2386

2387

2388

2389

2390

2391

2392

2393

2394

2395

2396

2397

2398 struct ConditionPropagation : public OpRewritePattern {

2400

2401 LogicalResult matchAndRewrite(IfOp op,

2403

2404

2406 return failure();

2407

2410

2411

2412

2413 Value constantTrue = nullptr;

2414 Value constantFalse = nullptr;

2415

2417 llvm::make_early_inc_range(op.getCondition().getUses())) {

2420

2421 if (!constantTrue)

2422 constantTrue = rewriter.createarith::ConstantOp(

2423 op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));

2424

2426 [&]() { use.set(constantTrue); });

2427 } else if (op.getElseRegion().isAncestor(

2430

2431 if (!constantFalse)

2432 constantFalse = rewriter.createarith::ConstantOp(

2433 op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));

2434

2436 [&]() { use.set(constantFalse); });

2437 }

2438 }

2439

2440 return success(changed);

2441 }

2442 };

2443

2444

2445

2446

2447

2448

2449

2450

2451

2452

2453

2454

2455

2456

2457

2458

2459

2460

2461

2462

2463

2464

2465

2466

2467

2468

2469

2470

2471

2472

2473

2474

2475

2476

2477

2478

2479

2480 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern {

2482

2483 LogicalResult matchAndRewrite(IfOp op,

2485

2486 if (op.getNumResults() == 0)

2487 return failure();

2488

2489 auto trueYield =

2490 castscf::YieldOp(op.getThenRegion().back().getTerminator());

2491 auto falseYield =

2492 castscf::YieldOp(op.getElseRegion().back().getTerminator());

2493

2495 op.getOperation()->getIterator());

2498 for (auto [trueResult, falseResult, opResult] :

2499 llvm::zip(trueYield.getResults(), falseYield.getResults(),

2500 op.getResults())) {

2501 if (trueResult == falseResult) {

2502 if (!opResult.use_empty()) {

2503 opResult.replaceAllUsesWith(trueResult);

2505 }

2506 continue;

2507 }

2508

2509 BoolAttr trueYield, falseYield;

2512 continue;

2513

2514 bool trueVal = trueYield.getValue();

2515 bool falseVal = falseYield.getValue();

2516 if (!trueVal && falseVal) {

2517 if (!opResult.use_empty()) {

2518 Dialect *constDialect = trueResult.getDefiningOp()->getDialect();

2519 Value notCond = rewriter.createarith::XOrIOp(

2520 op.getLoc(), op.getCondition(),

2521 constDialect

2524 op.getLoc())

2528 }

2529 }

2530 if (trueVal && !falseVal) {

2531 if (!opResult.use_empty()) {

2532 opResult.replaceAllUsesWith(op.getCondition());

2534 }

2535 }

2536 }

2537 return success(changed);

2538 }

2539 };

2540

2541

2542

2543

2544

2545

2546

2547

2548

2549

2550

2551

2552

2553

2554

2555

2556

2557

2558

2559

2560

2561

2564

2565 LogicalResult matchAndRewrite(IfOp nextIf,

2567 Block *parent = nextIf->getBlock();

2568 if (nextIf == &parent->front())

2569 return failure();

2570

2571 auto prevIf = dyn_cast(nextIf->getPrevNode());

2572 if (!prevIf)

2573 return failure();

2574

2575

2576

2577

2578

2579 Block *nextThen = nullptr;

2580 Block *nextElse = nullptr;

2581 if (nextIf.getCondition() == prevIf.getCondition()) {

2582 nextThen = nextIf.thenBlock();

2583 if (!nextIf.getElseRegion().empty())

2584 nextElse = nextIf.elseBlock();

2585 }

2586 if (arith::XOrIOp notv =

2587 nextIf.getCondition().getDefiningOparith::XOrIOp()) {

2588 if (notv.getLhs() == prevIf.getCondition() &&

2590 nextElse = nextIf.thenBlock();

2591 if (!nextIf.getElseRegion().empty())

2592 nextThen = nextIf.elseBlock();

2593 }

2594 }

2595 if (arith::XOrIOp notv =

2596 prevIf.getCondition().getDefiningOparith::XOrIOp()) {

2597 if (notv.getLhs() == nextIf.getCondition() &&

2599 nextElse = nextIf.thenBlock();

2600 if (!nextIf.getElseRegion().empty())

2601 nextThen = nextIf.elseBlock();

2602 }

2603 }

2604

2605 if (!nextThen && !nextElse)

2606 return failure();

2607

2609 if (!prevIf.getElseRegion().empty())

2610 prevElseYielded = prevIf.elseYield().getOperands();

2611

2612

2613 for (auto it : llvm::zip(prevIf.getResults(),

2614 prevIf.thenYield().getOperands(), prevElseYielded))

2616 llvm::make_early_inc_range(std::get<0>(it).getUses())) {

2620 use.set(std::get<1>(it));

2625 use.set(std::get<2>(it));

2627 }

2628 }

2629

2631 llvm::append_range(mergedTypes, nextIf.getResultTypes());

2632

2633 IfOp combinedIf = rewriter.create(

2634 nextIf.getLoc(), mergedTypes, prevIf.getCondition(), false);

2635 rewriter.eraseBlock(&combinedIf.getThenRegion().back());

2636

2638 combinedIf.getThenRegion(),

2639 combinedIf.getThenRegion().begin());

2640

2641 if (nextThen) {

2642 YieldOp thenYield = combinedIf.thenYield();

2643 YieldOp thenYield2 = cast(nextThen->getTerminator());

2644 rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());

2646

2648 llvm::append_range(mergedYields, thenYield2.getOperands());

2649 rewriter.create(thenYield2.getLoc(), mergedYields);

2650 rewriter.eraseOp(thenYield);

2651 rewriter.eraseOp(thenYield2);

2652 }

2653

2655 combinedIf.getElseRegion(),

2656 combinedIf.getElseRegion().begin());

2657

2658 if (nextElse) {

2659 if (combinedIf.getElseRegion().empty()) {

2661 combinedIf.getElseRegion(),

2662 combinedIf.getElseRegion().begin());

2663 } else {

2664 YieldOp elseYield = combinedIf.elseYield();

2665 YieldOp elseYield2 = cast(nextElse->getTerminator());

2666 rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());

2667

2669

2671 llvm::append_range(mergedElseYields, elseYield2.getOperands());

2672

2673 rewriter.create(elseYield2.getLoc(), mergedElseYields);

2674 rewriter.eraseOp(elseYield);

2675 rewriter.eraseOp(elseYield2);

2676 }

2677 }

2678

2681 for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {

2682 if (pair.index() < prevIf.getNumResults())

2683 prevValues.push_back(pair.value());

2684 else

2685 nextValues.push_back(pair.value());

2686 }

2687 rewriter.replaceOp(prevIf, prevValues);

2688 rewriter.replaceOp(nextIf, nextValues);

2689 return success();

2690 }

2691 };

2692

2693

2694 struct RemoveEmptyElseBranch : public OpRewritePattern {

2696

2697 LogicalResult matchAndRewrite(IfOp ifOp,

2699

2700 if (ifOp.getNumResults())

2701 return failure();

2702 Block *elseBlock = ifOp.elseBlock();

2703 if (!elseBlock || !llvm::hasSingleElement(*elseBlock))

2704 return failure();

2706 rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),

2707 newIfOp.getThenRegion().begin());

2709 return success();

2710 }

2711 };

2712

2713

2714

2715

2716

2717

2718

2719

2720

2721

2722

2723

2724

2725

2726

2727

2728

2731

2732 LogicalResult matchAndRewrite(IfOp op,

2734 auto nestedOps = op.thenBlock()->without_terminator();

2735

2736 if (!llvm::hasSingleElement(nestedOps))

2737 return failure();

2738

2739

2740 if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))

2741 return failure();

2742

2743 auto nestedIf = dyn_cast(*nestedOps.begin());

2744 if (!nestedIf)

2745 return failure();

2746

2747 if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))

2748 return failure();

2749

2752 if (op.elseBlock())

2753 llvm::append_range(elseYield, op.elseYield().getOperands());

2754

2755

2756

2758

2759

2760

2761

2762

2763

2764

2765

2767 if (tup.value().getDefiningOp() == nestedIf) {

2768 auto nestedIdx = llvm::cast(tup.value()).getResultNumber();

2769 if (nestedIf.elseYield().getOperand(nestedIdx) !=

2770 elseYield[tup.index()]) {

2771 return failure();

2772 }

2773

2774

2775 thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);

2776 continue;

2777 }

2778

2779

2780

2781

2782

2783

2784

2785

2786

2787

2788 if (tup.value().getParentRegion() == &op.getThenRegion()) {

2789 return failure();

2790 }

2791 elseYieldsToUpgradeToSelect.push_back(tup.index());

2792 }

2793

2795 Value newCondition = rewriter.createarith::AndIOp(

2796 loc, op.getCondition(), nestedIf.getCondition());

2797 auto newIf = rewriter.create(loc, op.getResultTypes(), newCondition);

2798 Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());

2799

2801 llvm::append_range(results, newIf.getResults());

2803

2804 for (auto idx : elseYieldsToUpgradeToSelect)

2805 results[idx] = rewriter.createarith::SelectOp(

2806 op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);

2807

2808 rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);

2811 if (!elseYield.empty()) {

2812 rewriter.createBlock(&newIf.getElseRegion());

2814 rewriter.create(loc, elseYield);

2815 }

2816 rewriter.replaceOp(op, results);

2817 return success();

2818 }

2819 };

2820

2821 }

2822

2823 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,

2825 results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,

2826 ConvertTrivialIfToSelect, RemoveEmptyElseBranch,

2827 RemoveStaticCondition, RemoveUnusedResults,

2828 ReplaceIfYieldWithConditionOrValue>(context);

2829 }

2830

2831 Block *IfOp::thenBlock() { return &getThenRegion().back(); }

2832 YieldOp IfOp::thenYield() { return cast(&thenBlock()->back()); }

2833 Block *IfOp::elseBlock() {

2834 Region &r = getElseRegion();

2835 if (r.empty())

2836 return nullptr;

2837 return &r.back();

2838 }

2839 YieldOp IfOp::elseYield() { return cast(&elseBlock()->back()); }

2840

2841

2842

2843

2844

2845 void ParallelOp::build(

2849 bodyBuilderFn) {

2855 ParallelOp::getOperandSegmentSizeAttr(),

2857 static_cast<int32_t>(upperBounds.size()),

2858 static_cast<int32_t>(steps.size()),

2859 static_cast<int32_t>(initVals.size())}));

2861

2863 unsigned numIVs = steps.size();

2867 Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);

2868

2869 if (bodyBuilderFn) {

2871 bodyBuilderFn(builder, result.location,

2872 bodyBlock->getArguments().take_front(numIVs),

2873 bodyBlock->getArguments().drop_front(numIVs));

2874 }

2875

2876 if (initVals.empty())

2877 ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);

2878 }

2879

2880 void ParallelOp::build(

2884

2885

2886

2887 auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,

2890 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);

2891 };

2893 if (bodyBuilderFn)

2894 wrapper = wrappedBuilderFn;

2895

2896 build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),

2897 wrapper);

2898 }

2899

2901

2902

2903

2905 if (stepValues.empty())

2906 return emitOpError(

2907 "needs at least one tuple element for lowerBound, upperBound and step");

2908

2909

2910 for (Value stepValue : stepValues)

2912 if (*cst <= 0)

2913 return emitOpError("constant step operand must be positive");

2914

2915

2916

2917 Block *body = getBody();

2919 return emitOpError() << "expects the same number of induction variables: "

2921 << " as bound and step values: " << stepValues.size();

2923 if (!arg.getType().isIndex())

2924 return emitOpError(

2925 "expects arguments for the induction variable to be of index type");

2926

2927

2928 auto reduceOp = verifyAndGetTerminatorscf::ReduceOp(

2929 *this, getRegion(), "expects body to terminate with 'scf.reduce'");

2930 if (!reduceOp)

2931 return failure();

2932

2933

2934 auto resultsSize = getResults().size();

2935 auto reductionsSize = reduceOp.getReductions().size();

2936 auto initValsSize = getInitVals().size();

2937 if (resultsSize != reductionsSize)

2938 return emitOpError() << "expects number of results: " << resultsSize

2939 << " to be the same as number of reductions: "

2940 << reductionsSize;

2941 if (resultsSize != initValsSize)

2942 return emitOpError() << "expects number of results: " << resultsSize

2943 << " to be the same as number of initial values: "

2944 << initValsSize;

2945

2946

2947 for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {

2948 auto resultType = getOperation()->getResult(i).getType();

2949 auto reductionOperandType = reduceOp.getOperands()[i].getType();

2950 if (resultType != reductionOperandType)

2951 return reduceOp.emitOpError()

2952 << "expects type of " << i

2953 << "-th reduction operand: " << reductionOperandType

2954 << " to be the same as the " << i

2955 << "-th result type: " << resultType;

2956 }

2957 return success();

2958 }

2959

2961 auto &builder = parser.getBuilder();

2962

2965 return failure();

2966

2967

2973 return failure();

2974

2980 return failure();

2981

2982

2988 return failure();

2989

2990

2994 return failure();

2995 }

2996

2997

2999 return failure();

3000

3001

3003 for (auto &iv : ivs)

3006 return failure();

3007

3008

3010 ParallelOp::getOperandSegmentSizeAttr(),

3012 static_cast<int32_t>(upper.size()),

3013 static_cast<int32_t>(steps.size()),

3014 static_cast<int32_t>(initVals.size())}));

3015

3016

3020 return failure();

3021

3022

3023 ParallelOp::ensureTerminator(*body, builder, result.location);

3024 return success();

3025 }

3026

3028 p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()

3029 << ") to (" << getUpperBound() << ") step (" << getStep() << ")";

3030 if (!getInitVals().empty())

3031 p << " init (" << getInitVals() << ")";

3033 p << ' ';

3034 p.printRegion(getRegion(), false);

3036 (*this)->getAttrs(),

3037 ParallelOp::getOperandSegmentSizeAttr());

3038 }

3039

3041

3042 std::optional<SmallVector> ParallelOp::getLoopInductionVars() {

3044 }

3045

3046 std::optional<SmallVector> ParallelOp::getLoopLowerBounds() {

3048 }

3049

3050 std::optional<SmallVector> ParallelOp::getLoopUpperBounds() {

3052 }

3053

// Returns the value of getStep() wrapped in the optional return type.
// NOTE(review): the element type of the SmallVector in the return type is
// not visible in this listing -- confirm against the op declaration.
3054 std::optional<SmallVector> ParallelOp::getLoopSteps() {

3055 return getStep();

3056 }

3057

3059 auto ivArg = llvm::dyn_cast(val);

3060 if (!ivArg)

3061 return ParallelOp();

3062 assert(ivArg.getOwner() && "unlinked block argument");

3063 auto *containingOp = ivArg.getOwner()->getParentOp();

3064 return dyn_cast(containingOp);

3065 }

3066

3067 namespace {

3068

3069 struct ParallelOpSingleOrZeroIterationDimsFolder

3072

3073 LogicalResult matchAndRewrite(ParallelOp op,

3076

3077

3080 for (auto [lb, ub, step, iv] :

3081 llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),

3082 op.getInductionVars())) {

3084 if (numIterations.has_value()) {

3085

3086 if (*numIterations == 0) {

3087 rewriter.replaceOp(op, op.getInitVals());

3088 return success();

3089 }

3090

3091

3092 if (*numIterations == 1) {

3094 continue;

3095 }

3096 }

3097 newLowerBounds.push_back(lb);

3098 newUpperBounds.push_back(ub);

3099 newSteps.push_back(step);

3100 }

3101

3102 if (newLowerBounds.size() == op.getLowerBound().size())

3103 return failure();

3104

3105 if (newLowerBounds.empty()) {

3106

3107

3109 results.reserve(op.getInitVals().size());

3110 for (auto &bodyOp : op.getBody()->without_terminator())

3111 rewriter.clone(bodyOp, mapping);

3112 auto reduceOp = cast(op.getBody()->getTerminator());

3113 for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {

3114 Block &reduceBlock = reduceOp.getReductions()[i].front();

3115 auto initValIndex = results.size();

3116 mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);

3120 rewriter.clone(reduceBodyOp, mapping);

3121

3123 cast(reduceBlock.getTerminator()).getResult());

3124 results.push_back(result);

3125 }

3126

3127 rewriter.replaceOp(op, results);

3128 return success();

3129 }

3130

3131 auto newOp =

3132 rewriter.create(op.getLoc(), newLowerBounds, newUpperBounds,

3133 newSteps, op.getInitVals(), nullptr);

3134

3135 rewriter.eraseBlock(newOp.getBody());

3136

3137

3139 newOp.getRegion().begin(), mapping);

3140 rewriter.replaceOp(op, newOp.getResults());

3141 return success();

3142 }

3143 };

3144

// Rewrite pattern that merges a ParallelOp with a loop nested directly at the
// start of its body into a single op whose lower bounds, upper bounds and steps
// are the outer values followed by the inner values.
// NOTE(review): this listing lost several lines (template arguments, the
// rewriter parameter, the definition of `bodyBuilder`); the comments below
// describe only what is visible.
3145 struct MergeNestedParallelLoops : public OpRewritePattern {

3147

3148 LogicalResult matchAndRewrite(ParallelOp op,

3150 Block &outerBody = *op.getBody();

// The condition guarding this early exit is missing from the listing.
3152 return failure();

3153

// The first operation of the outer body must be the nested loop; the cast's
// target type was stripped (presumably ParallelOp -- TODO confirm).
3154 auto innerOp = dyn_cast(outerBody.front());

3155 if (!innerOp)

3156 return failure();

3157

// Give up when `val` is used as a bound or step of the inner loop. The loop
// header that introduces `val` is missing from this listing.
3159 if (llvm::is_contained(innerOp.getLowerBound(), val) ||

3160 llvm::is_contained(innerOp.getUpperBound(), val) ||

3161 llvm::is_contained(innerOp.getStep(), val))

3162 return failure();

3163

3164

// Loops that carry init values are not merged.
3165 if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())

3166 return failure();

3167

3170 Block &innerBody = *innerOp.getBody();

3171 assert(iterVals.size() ==

3179 builder.clone(op, mapping);

3180 };

3181

// Helper: returns the elements of `first` followed by those of `second`.
3182 auto concatValues = [](const auto &first, const auto &second) {

3184 ret.reserve(first.size() + second.size());

3185 ret.assign(first.begin(), first.end());

3186 ret.append(second.begin(), second.end());

3187 return ret;

3188 };

3189

3190 auto newLowerBounds =

3191 concatValues(op.getLowerBound(), innerOp.getLowerBound());

3192 auto newUpperBounds =

3193 concatValues(op.getUpperBound(), innerOp.getUpperBound());

3194 auto newSteps = concatValues(op.getStep(), innerOp.getStep());

3195

// Replace the outer loop by a new op built from the concatenated bounds and
// steps, without init values, whose body is populated by `bodyBuilder`.
3196 rewriter.replaceOpWithNewOp(op, newLowerBounds, newUpperBounds,

3197 newSteps, std::nullopt,

3198 bodyBuilder);

3199 return success();

3200 }

3201 };

3202

3203 }

3204

// Adds the ParallelOp canonicalization patterns
// (ParallelOpSingleOrZeroIterationDimsFolder and MergeNestedParallelLoops) to
// `results`. The second parameter line is missing from this listing.
3205 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,

3207 results

3208 .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(

3209 context);

3210 }

3211

3212

3213

3214

3215

3216

3217 void ParallelOp::getSuccessorRegions(

3219

3220

3221

3224 }

3225

3226

3227

3228

3229

3231

3235 for (Value v : operands) {

3241 }

3242 }

3243

// Verifies every reduction region of this op: the first block must be
// non-empty, its arguments must have the type of the matching operand, and it
// must end with the expected terminator ('scf.reduce.return' per the message).
3244 LogicalResult ReduceOp::verifyRegions() {

3245

3246

3247 for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {

// The i-th operand provides the expected block-argument type for region i.
3248 auto type = getOperands()[i].getType();

3249 Block &block = getReductions()[i].front();

3250 if (block.empty())

3251 return emitOpError() << i << "-th reduction has an empty body";

// The opening of this argument check is missing from the listing; the visible
// predicate flags any argument whose type differs from `type`.
3254 return arg.getType() != type;

3255 }))

3256 return emitOpError() << "expected two block arguments with type " << type

3257 << " in the " << i << "-th reduction region";

3258

3259

// The isa<> target type was stripped from this listing.
3260 if (!isa(block.getTerminator()))

3261 return emitOpError("reduction bodies must be terminated with an "

3262 "'scf.reduce.return' op");

3263 }

3264

3265 return success();

3266 }

3267

3270

3272 }

3273

3274

3275

3276

3277

3279

3280

3281 Block *reductionBody = getOperation()->getBlock();

3282

3283 assert(isa(reductionBody->getParentOp()) && "expected scf.reduce");

3285 if (expectedResultType != getResult().getType())

3286 return emitOpError() << "must have type " << expectedResultType

3287 << " (the type of the reduction inputs)";

3288 return success();

3289 }

3290

3291

3292

3293

3294

3297 ValueRange inits, BodyBuilderFn beforeBuilder,

3298 BodyBuilderFn afterBuilder) {

3300 odsState.addTypes(resultTypes);

3301

3303

3304

3306 beforeArgLocs.reserve(inits.size());

3307 for (Value operand : inits) {

3308 beforeArgLocs.push_back(operand.getLoc());

3309 }

3310

3312 Block *beforeBlock = odsBuilder.createBlock(beforeRegion, {},

3313 inits.getTypes(), beforeArgLocs);

3314 if (beforeBuilder)

3316

3317

3319

3321 Block *afterBlock = odsBuilder.createBlock(afterRegion, {},

3322 resultTypes, afterArgLocs);

3323

3324 if (afterBuilder)

3326 }

3327

// Returns the terminator of the "before" body as a ConditionOp. The cast's
// template argument was stripped from this listing.
3328 ConditionOp WhileOp::getConditionOp() {

3329 return cast(getBeforeBody()->getTerminator());

3330 }

3331

// Returns the terminator of the "after" body as a YieldOp. The cast's template
// argument was stripped from this listing.
3332 YieldOp WhileOp::getYieldOp() {

3333 return cast(getAfterBody()->getTerminator());

3334 }

3335

// Returns the mutable range of values yielded by the op returned from
// getYieldOp().
3336 std::optional<MutableArrayRef> WhileOp::getYieldedValuesMutable() {

3337 return getYieldOp().getResultsMutable();

3338 }

3339

3341 return getBeforeBody()->getArguments();

3342 }

3343

3345 return getAfterBody()->getArguments();

3346 }

3347

3349 return getBeforeArguments();

3350 }

3351

3353 assert(point == getBefore() &&

3354 "WhileOp is expected to branch only to the first region");

3355 return getInits();

3356 }

3357

3360

3362 regions.emplace_back(&getBefore(), getBefore().getArguments());

3363 return;

3364 }

3365

3366 assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&

3367 "there are only two regions in a WhileOp");

3368

3369 if (point == getAfter()) {

3370 regions.emplace_back(&getBefore(), getBefore().getArguments());

3371 return;

3372 }

3373

3374 regions.emplace_back(getResults());

3375 regions.emplace_back(&getAfter(), getAfter().getArguments());

3376 }

3377

3379 return {&getBefore(), &getAfter()};

3380 }

3381

3382

3383

3384

3385

3386

3387

3388

3394

3397 if (listResult.has_value() && failed(listResult.value()))

3398 return failure();

3399

3400 FunctionType functionType;

3403 return failure();

3404

3405 result.addTypes(functionType.getResults());

3406

3407 if (functionType.getNumInputs() != operands.size()) {

3408 return parser.emitError(typeLoc)

3409 << "expected as many input types as operands "

3410 << "(expected " << operands.size() << " got "

3411 << functionType.getNumInputs() << ")";

3412 }

3413

3414

3415 if (failed(parser.resolveOperands(operands, functionType.getInputs(),

3418 return failure();

3419

3420

3421 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)

3422 regionArgs[i].type = functionType.getInput(i);

3423

3424 return failure(parser.parseRegion(*before, regionArgs) ||

3427 }

3428

3429

3432 p << " : ";

3434 p << ' ';

3435 p.printRegion(getBefore(), false);

3436 p << " do ";

3439 }

3440

3441

3442

3443

// Checks that `left` and `right` hold the same number of types and that the
// types are pairwise equal. On a mismatch an op error mentioning `message` is
// emitted on `op`; otherwise success is returned.
// NOTE(review): the line with the function name and first parameters is
// missing from this listing.
3444 template

3446 TypeRange right, StringRef message) {

3447 if (left.size() != right.size())

3448 return op.emitOpError("expects the same number of ") << message;

3449

3450 for (unsigned i = 0, e = left.size(); i < e; ++i) {

3451 if (left[i] != right[i]) {

// The statement that creates `diag` is missing; the visible part appends
// `message` and attaches a note naming the first differing position.
3453 << message;

3454 diag.attachNote() << "for argument " << i << ", found " << left[i]

3455 << " and " << right[i];

3456 return diag;

3457 }

3458 }

3459

3460 return success();

3461 }

3462

3464 auto beforeTerminator = verifyAndGetTerminatorscf::ConditionOp(

3465 *this, getBefore(),

3466 "expects the 'before' region to terminate with 'scf.condition'");

3467 if (!beforeTerminator)

3468 return failure();

3469

3470 auto afterTerminator = verifyAndGetTerminatorscf::YieldOp(

3471 *this, getAfter(),

3472 "expects the 'after' region to terminate with 'scf.yield'");

3473 return success(afterTerminator != nullptr);

3474 }

3475

3476 namespace {

3477

3478

3479

3480

3481

3482

3483

3484

3485

3486

3487

3488

3489

3490

3491

3492

3493

3494

3495

// Rewrite pattern for WhileOp: when the value used as the loop condition is
// also forwarded by the condition terminator, the matching "after" block
// argument (if it has uses) is handled with a lazily created constant.
// NOTE(review): the lines building the constant's attribute and performing the
// replacement are missing here; the variable name suggests a `true` constant.
3496 struct WhileConditionTruth : public OpRewritePattern {

3498

3499 LogicalResult matchAndRewrite(WhileOp op,

3501 auto term = op.getConditionOp();

3502

3503

3504

// Created on first use so no constant is materialized when nothing matches.
3505 Value constantTrue = nullptr;

3506

3507 bool replaced = false;

// Walk forwarded values and "after" block arguments in lockstep.
3508 for (auto yieldedAndBlockArgs :

3509 llvm::zip(term.getArgs(), op.getAfterArguments())) {

3510 if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {

3511 if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {

3512 if (!constantTrue)

3513 constantTrue = rewriter.createarith::ConstantOp(

3514 op.getLoc(), term.getCondition().getType(),

3516

3518 constantTrue);

3519 replaced = true;

3520 }

3521 }

3522 }

// Report success only if at least one argument was handled.
3523 return success(replaced);

3524 }

3525 };

3526

3527

3528

3529

3530

3531

3532

3533

3534

3535

3536

3537

3538

3539

3540

3541

3542

3543

3544

3545

3546

3547

3548

3549

3550

3551

3552

3553

3554

3555

3556

3557

3558

3559

3560

3561

3562

3563

3564

3565

3566

3567

3568

3569

3570

3571

3572

3573

3574

// Rewrite pattern for WhileOp that drops "before" block arguments whose value
// does not change between iterations. An argument qualifies when the yield
// passes the init value back unchanged, or when the yield forwards an "after"
// block argument whose matching scf.condition operand is either the same
// "before" block argument or the init value.
// NOTE(review): declarations of yieldOpArgs, condOpArgs, beforeBlockArgs and
// the new-value containers are missing from this listing.
3575 struct RemoveLoopInvariantArgsFromBeforeBlock

3578

3579 LogicalResult matchAndRewrite(WhileOp op,

3581 Block &afterBlock = *op.getAfterBody();

3583 ConditionOp condOp = op.getConditionOp();

3587

// First pass: only detect whether any argument qualifies.
3588 bool canSimplify = false;

3589 for (const auto &it :

3590 llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {

3591 auto index = static_cast<unsigned>(it.index());

3592 auto [initVal, yieldOpArg] = it.value();

3593

3594

// Case 1: the yield hands the init value straight back.
3595 if (yieldOpArg == initVal) {

3596 canSimplify = true;

3597 break;

3598 }

3599

3600

3601

3602

3603

// Case 2: the yield forwards an "after" block argument that in turn comes from
// the same "before" argument or from the init value.
3604 auto yieldOpBlockArg = llvm::dyn_cast(yieldOpArg);

3605 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {

3606 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];

3607 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {

3608 canSimplify = true;

3609 break;

3610 }

3611 }

3612 }

3613

3614 if (!canSimplify)

3615 return failure();

3616

// Second pass: record invariant positions (index -> init value) and collect
// the inits, yield operands and argument locations that are kept.
3620 for (const auto &it :

3621 llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {

3622 auto index = static_cast<unsigned>(it.index());

3623 auto [initVal, yieldOpArg] = it.value();

3624

3625

3626

3627 if (yieldOpArg == initVal) {

3628 beforeBlockInitValMap.insert({index, initVal});

3629 continue;

3630 } else {

3631

3632

3633

3634

3635

3636 auto yieldOpBlockArg = llvm::dyn_cast(yieldOpArg);

3637 if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {

3638 Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];

3639 if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {

3640 beforeBlockInitValMap.insert({index, initVal});

3641 continue;

3642 }

3643 }

3644 }

3645 newInitArgs.emplace_back(initVal);

3646 newYieldOpArgs.emplace_back(yieldOpArg);

3647 newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());

3648 }

3649

// The contents of this scope are missing from the listing.
3650 {

3654 }

3655

// The new loop keeps the result types but only takes the remaining inits.
3656 auto newWhile =

3657 rewriter.create(op.getLoc(), op.getResultTypes(), newInitArgs);

3658

3660 &newWhile.getBefore(), {},

3662

3663 Block &beforeBlock = *op.getBeforeBody();

3665

3666

3667

3668

3669

// Build the replacement for each old "before" argument: the recorded init
// value for invariant positions, otherwise the next argument of the new block.
3670 for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {

3671

3672

3673 if (beforeBlockInitValMap.count(i) != 0)

3674 newBeforeBlockArgs[i] = beforeBlockInitValMap[i];

3675 else

3676 newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);

3677 }

3678

// Move the old "before" body into the new block using those replacements.
3679 rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);

3681 newWhile.getAfter().begin());

3682

3683 rewriter.replaceOp(op, newWhile.getResults());

3684 return success();

3685 }

3686 };

3687

3688

3689

3690

3691

3692

3693

3694

3695

3696

3697

3698

3699

3700

3701

3702

3703

3704

3705

3706

3707

3708

3709

3710

3711

3712

3713

3714

3715

3716

3717

3718

3719

3720

3721

3722

3723

3724

3725

3726

3727

// Rewrite pattern for WhileOp that stops forwarding selected scf.condition
// operands through the loop: selected operands are used directly in place of
// the matching "after" block argument and loop result.
// NOTE(review): the test that selects an operand is missing from this listing
// (presumably "defined outside the before block" -- TODO confirm), as are the
// declarations of condOpArgs, afterBlockArgs and the new-value containers.
3728 struct RemoveLoopInvariantValueYielded : public OpRewritePattern {

3730

3731 LogicalResult matchAndRewrite(WhileOp op,

3733 Block &beforeBlock = *op.getBeforeBody();

3734 ConditionOp condOp = op.getConditionOp();

3736

// First pass: detect whether any operand is selected at all.
3737 bool canSimplify = false;

3738 for (Value condOpArg : condOpArgs) {

3739

3740

3741

3743 canSimplify = true;

3744 break;

3745 }

3746 }

3747

3748 if (!canSimplify)

3749 return failure();

3750

3752

// Second pass: selected operands are recorded by index; the rest are kept
// along with their types and the locations of their "after" block arguments.
3758 auto index = static_cast<unsigned>(it.index());

3759 Value condOpArg = it.value();

3760

3761

3762

3764 condOpInitValMap.insert({index, condOpArg});

3765 } else {

3766 newCondOpArgs.emplace_back(condOpArg);

3767 newAfterBlockType.emplace_back(condOpArg.getType());

3768 newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());

3769 }

3770 }

3771

// Replace the condition terminator with one forwarding only the kept operands.
3772 {

3775 rewriter.replaceOpWithNewOp(condOp, condOp.getCondition(),

3776 newCondOpArgs);

3777 }

3778

// The new loop's result types are those of the kept operands.
3779 auto newWhile = rewriter.create(op.getLoc(), newAfterBlockType,

3780 op.getOperands());

3781

3782 Block &newAfterBlock =

3783 *rewriter.createBlock(&newWhile.getAfter(), {},

3784 newAfterBlockType, newAfterBlockArgLocs);

3785

3786 Block &afterBlock = *op.getAfterBody();

3787

3788

3789

3790

// For every old "after" argument i choose its replacement and the replacement
// for old result i: the recorded operand itself, or the j-th new argument and
// j-th new result.
3793 for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {

3794 Value afterBlockArg, result;

3795

3796

3797 if (condOpInitValMap.count(i) != 0) {

3798 afterBlockArg = condOpInitValMap[i];

3799 result = afterBlockArg;

3800 } else {

3801 afterBlockArg = newAfterBlock.getArgument(j);

3802 result = newWhile.getResult(j);

3803 j++;

3804 }

3805 newAfterBlockArgs[i] = afterBlockArg;

3806 newWhileResults[i] = result;

3807 }

3808

3809 rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);

3811 newWhile.getBefore().begin());

3812

3813 rewriter.replaceOp(op, newWhileResults);

3814 return success();

3815 }

3816 };

3817

3818

3819

3820

3821

3822

3823

3824

3825

3826

3827

3828

3829

3830

3831

3832

3833

3834

3835

3836

3837

3838

3839

3840

3841

3842

3843

// Rewrite pattern for WhileOp that rebuilds the loop with a reduced set of
// results / "after" arguments / condition operands.
// NOTE(review): the condition that decides which entries are dropped is
// missing from this listing (going by the pattern name, presumably entries
// whose result and after-argument are both unused -- TODO confirm).
3844 struct WhileUnusedResult : public OpRewritePattern {

3846

3847 LogicalResult matchAndRewrite(WhileOp op,

3849 auto term = op.getConditionOp();

3850 auto afterArgs = op.getAfterArguments();

3851 auto termArgs = term.getArgs();

3852

3853

3858 bool needUpdate = false;

// Visit result, "after" argument and condition operand of each position
// together; kept positions are recorded in newResultsIndices.
3859 for (const auto &it :

3860 llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {

3861 auto i = static_cast<unsigned>(it.index());

3862 Value result = std::get<0>(it.value());

3863 Value afterArg = std::get<1>(it.value());

3864 Value termArg = std::get<2>(it.value());

3866 needUpdate = true;

3867 } else {

3868 newResultsIndices.emplace_back(i);

3869 newTermArgs.emplace_back(termArg);

3870 newResultTypes.emplace_back(result.getType());

3871 newArgLocs.emplace_back(result.getLoc());

3872 }

3873 }

3874

3875 if (!needUpdate)

3876 return failure();

3877

// Most of this scope is missing; it ends by passing newTermArgs to a call.
3878 {

3882 newTermArgs);

3883 }

3884

3885 auto newWhile =

3886 rewriter.create(op.getLoc(), newResultTypes, op.getInits());

3887

3889 &newWhile.getAfter(), {}, newResultTypes, newArgLocs);

3890

3891

3892

// Map each kept original position to the matching new result / new argument.
3895 for (const auto &it : llvm::enumerate(newResultsIndices)) {

3896 newResults[it.value()] = newWhile.getResult(it.index());

3897 newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());

3898 }

3899

3901 newWhile.getBefore().begin());

3902

3903 Block &afterBlock = *op.getAfterBody();

3904 rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);

3905

3906 rewriter.replaceOp(op, newResults);

3907 return success();

3908 }

3909 };

3910

3911

3912

3913

3914

3915

3916

3917

3918

3919

3920

3921

3922

3923

3924

3925

3926

3927

3928

3929

3930

3931

3932

// Rewrite pattern for scf::WhileOp whose condition is produced by an integer
// comparison. When an operand of that comparison is forwarded to the "after"
// region, comparisons there of the forwarded argument against the same other
// operand are replaced by a 1-bit integer constant.
// NOTE(review): the declaration of `changed`, the inner loop header and the
// second predicate test are missing from this listing.
3933 struct WhileCmpCond : public OpRewritePatternscf::WhileOp {

3935

3936 LogicalResult matchAndRewrite(scf::WhileOp op,

3938 using namespace scf;

3939 auto cond = op.getConditionOp();

// Only applies when the condition is defined by a CmpIOp.
3940 auto cmp = cond.getCondition().getDefiningOparith::CmpIOp();

3941 if (!cmp)

3942 return failure();

3944 for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {

// Try both operand positions of the comparison.
3945 for (size_t opIdx = 0; opIdx < 2; opIdx++) {

3946 if (std::get<0>(tup) != cmp.getOperand(opIdx))

3947 continue;

// Early-inc iteration: uses are rewritten while walking them.
3949 llvm::make_early_inc_range(std::get<1>(tup).getUses())) {

3950 auto cmp2 = dyn_castarith::CmpIOp(u.getOwner());

3951 if (!cmp2)

3952 continue;

3953

// The other operand must be the same value as in the loop condition.
3954 if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))

3955 continue;

// Same predicate -> constant 1; the predicate tested on the missing line
// (presumably the inverse -- TODO confirm) -> constant 0; otherwise skip.
3956 bool samePredicate;

3957 if (cmp2.getPredicate() == cmp.getPredicate())

3958 samePredicate = true;

3959 else if (cmp2.getPredicate() ==

3961 samePredicate = false;

3962 else

3963 continue;

3964

3965 rewriter.replaceOpWithNewOparith::ConstantIntOp(cmp2, samePredicate,

3966 1);

3968 }

3969 }

3970 }

3971 return success(changed);

3972 }

3973 };

3974

3975

// Rewrite pattern for WhileOp that removes "before" block arguments without
// uses, together with the matching init values and yield operands.
// NOTE(review): the early-return statement, the declarations of
// newYields/newInits/loc and the statements between lines 4019 and 4023 are
// missing from this listing.
3976 struct WhileRemoveUnusedArgs : public OpRewritePattern {

3978

3979 LogicalResult matchAndRewrite(WhileOp op,

3981

// Only proceed when at least one "before" argument is unused.
3982 if (!llvm::any_of(op.getBeforeArguments(),

3983 [](Value arg) { return arg.use_empty(); }))

3985

3986 YieldOp yield = op.getYieldOp();

3987

3988

// One bit per "before" argument; set when the argument is to be erased.
3991 llvm::BitVector argsToErase;

3992

3993 size_t argsCount = op.getBeforeArguments().size();

3994 newYields.reserve(argsCount);

3995 newInits.reserve(argsCount);

3996 argsToErase.reserve(argsCount);

// Keep the yield operand and init value of every argument that has uses.
3997 for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(

3998 op.getBeforeArguments(), yield.getOperands(), op.getInits())) {

3999 if (beforeArg.use_empty()) {

4000 argsToErase.push_back(true);

4001 } else {

4002 argsToErase.push_back(false);

4003 newYields.emplace_back(yieldValue);

4004 newInits.emplace_back(initValue);

4005 }

4006 }

4007

4008 Block &beforeBlock = *op.getBeforeBody();

4009 Block &afterBlock = *op.getAfterBody();

4010

4012

// New loop: same result types, reduced inits, no body builders.
4014 auto newWhileOp =

4015 rewriter.create(loc, op.getResultTypes(), newInits,

4016 nullptr, nullptr);

4017 Block &newBeforeBlock = *newWhileOp.getBeforeBody();

4018 Block &newAfterBlock = *newWhileOp.getAfterBody();

4019

4023

// Move the old bodies into the blocks of the new loop.
4024 rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,

4025 newBeforeBlock.getArguments());

4026 rewriter.mergeBlocks(&afterBlock, &newAfterBlock,

4028

4029 rewriter.replaceOp(op, newWhileOp.getResults());

4030 return success();

4031 }

4032 };

4033

4034

// Rewrite pattern for WhileOp that forwards each distinct scf.condition
// operand only once; duplicated "after" arguments and results are remapped to
// the single remaining position.
// NOTE(review): declarations of argsSet, newArgs, argsRange, loc and the
// mapping vectors, plus one loop header, are missing from this listing.
4035 struct WhileRemoveDuplicatedResults : public OpRewritePattern {

4037

4038 LogicalResult matchAndRewrite(WhileOp op,

4040 ConditionOp condOp = op.getConditionOp();

4041 ValueRange condOpArgs = condOp.getArgs();

4042

4044

// Equal sizes are tested here; the statement executed in that case is missing
// from the listing.
4045 if (argsSet.size() == condOpArgs.size())

4047

// Maps each distinct operand to its position in the deduplicated list.
4048 llvm::SmallDenseMap<Value, unsigned> argsMap;

4050 argsMap.reserve(condOpArgs.size());

4051 newArgs.reserve(condOpArgs.size());

4052 for (Value arg : condOpArgs) {

4053 if (!argsMap.count(arg)) {

4054 auto pos = static_cast<unsigned>(argsMap.size());

4055 argsMap.insert({arg, pos});

4056 newArgs.emplace_back(arg);

4057 }

4058 }

4059

4061

// New loop whose result types are those of `argsRange`, same inits.
4063 auto newWhileOp = rewriter.createscf::WhileOp(

4064 loc, argsRange.getTypes(), op.getInits(), nullptr,

4065 nullptr);

4066 Block &newBeforeBlock = *newWhileOp.getBeforeBody();

4067 Block &newAfterBlock = *newWhileOp.getAfterBody();

4068

// For each `arg` (loop header missing) look up its deduplicated position and
// record the replacement "after" argument and result.
4072 auto it = argsMap.find(arg);

4073 assert(it != argsMap.end());

4074 auto pos = it->second;

4075 afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));

4076 resultsMapping.emplace_back(newWhileOp->getResult(pos));

4077 }

4078

// Rewrite the terminator to forward `argsRange`.
4081 rewriter.replaceOpWithNewOp(condOp, condOp.getCondition(),

4082 argsRange);

4083

4084 Block &beforeBlock = *op.getBeforeBody();

4085 Block &afterBlock = *op.getAfterBody();

4086

4087 rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,

4088 newBeforeBlock.getArguments());

4089 rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);

4090 rewriter.replaceOp(op, resultsMapping);

4091 return success();

4092 }

4093 };

4094

4095

4096

// Computes how the values of `args1` map onto `args2`: for each element of
// `args1` at index i found at position p in `args2`, sets ret[p] = i.
// Returns std::nullopt when the ranges differ in size or an element of
// `args1` does not occur in `args2`.
// NOTE(review): the second parameter line, the declaration of `ret` and the
// loop header are missing from this listing.
4097 static std::optional<SmallVector> getArgsMapping(ValueRange args1,

4099 if (args1.size() != args2.size())

4100 return std::nullopt;

4101

4104 auto it = llvm::find(args2, arg1);

4105 if (it == args2.end())

4106 return std::nullopt;

4107

4108 ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);

4109 }

4110

4111 return ret;

4112 }

4113

// Returns true if some value occurs more than once in `args`.
4114 static bool hasDuplicates(ValueRange args) {

4115 llvm::SmallDenseSet set;

4116 for (Value arg : args) {

// insert() reports via `.second` whether the value was newly added.
4117 if (!set.insert(arg).second)

4118 return true;

4119 }

4120 return false;

4121 }

4122

4123

4124

4125

4126

// Rewrite pattern for WhileOp: when the scf.condition operands are a
// duplicate-free reordering of the "before" block arguments, rebuild the loop
// so the terminator forwards the block arguments in their original order, and
// permute the result types, results and "after" arguments accordingly.
// NOTE(review): the loop headers that introduce `i` and `j` and the
// declarations of the new-value vectors are missing from this listing.
4127 struct WhileOpAlignBeforeArgs : public OpRewritePattern {

4129

4130 LogicalResult matchAndRewrite(WhileOp loop,

4132 auto oldBefore = loop.getBeforeBody();

4133 ConditionOp oldTerm = loop.getConditionOp();

4134 ValueRange beforeArgs = oldBefore->getArguments();

4135 ValueRange termArgs = oldTerm.getArgs();

// Already aligned: nothing to do.
4136 if (beforeArgs == termArgs)

4137 return failure();

4138

4139 if (hasDuplicates(termArgs))

4140 return failure();

4141

// Fails unless every "before" argument appears among the terminator operands.
4142 auto mapping = getArgsMapping(beforeArgs, termArgs);

4143 if (!mapping)

4144 return failure();

4145

// Make the terminator forward the block arguments in order.
4146 {

4149 rewriter.replaceOpWithNewOp(oldTerm, oldTerm.getCondition(),

4150 beforeArgs);

4151 }

4152

4153 auto oldAfter = loop.getAfterBody();

4154

4157 newResultTypes[j] = loop.getResult(i).getType();

4158

4159 auto newLoop = rewriter.create(

4160 loop.getLoc(), newResultTypes, loop.getInits(),

4161 nullptr, nullptr);

4162 auto newBefore = newLoop.getBeforeBody();

4163 auto newAfter = newLoop.getAfterBody();

4164

// Old position i is served by new position j.
4168 newResults[i] = newLoop.getResult(j);

4169 newAfterArgs[i] = newAfter->getArgument(j);

4170 }

4171

// Move both old bodies to the start of the new loop's blocks.
4172 rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),

4173 newBefore->getArguments());

4174 rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),

4175 newAfterArgs);

4176

4177 rewriter.replaceOp(loop, newResults);

4178 return success();

4179 }

4180 };

4181 }

4182

// Adds the WhileOp canonicalization patterns listed below to `results`.
// The second parameter line is missing from this listing.
4183 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,

4185 results.add<RemoveLoopInvariantArgsFromBeforeBlock,

4186 RemoveLoopInvariantValueYielded, WhileConditionTruth,

4187 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,

4188 WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);

4189 }

4190

4191

4192

4193

4194

4195

4196 static ParseResult

4198 SmallVectorImpl<std::unique_ptr> &caseRegions) {

4201 int64_t value;

4202 Region &region = *caseRegions.emplace_back(std::make_unique());

4204 return failure();

4205 caseValues.push_back(value);

4206 }

4208 return success();

4209 }

4210

4211

4214 for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {

4216 p << "case " << value << ' ';

4217 p.printRegion(*region, false);

4218 }

4219 }

4220

4222 if (getCases().size() != getCaseRegions().size()) {

4223 return emitOpError("has ")

4224 << getCaseRegions().size() << " case regions but "

4225 << getCases().size() << " case values";

4226 }

4227

4229 for (int64_t value : getCases())

4230 if (!valueSet.insert(value).second)

4231 return emitOpError("has duplicate case value: ") << value;

4232 auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {

4233 auto yield = dyn_cast(region.front().back());

4234 if (!yield)

4235 return emitOpError("expected region to end with scf.yield, but got ")

4237

4238 if (yield.getNumOperands() != getNumResults()) {

4239 return (emitOpError("expected each region to return ")

4240 << getNumResults() << " values, but " << name << " returns "

4241 << yield.getNumOperands())

4242 .attachNote(yield.getLoc())

4243 << "see yield operation here";

4244 }

4245 for (auto [idx, result, operand] :

4246 llvm::zip(llvm::seq(0, getNumResults()), getResultTypes(),

4247 yield.getOperandTypes())) {

4248 if (result == operand)

4249 continue;

4250 return (emitOpError("expected result #")

4251 << idx << " of each region to be " << result)

4252 .attachNote(yield.getLoc())

4253 << name << " returns " << operand << " here";

4254 }

4255 return success();

4256 };

4257

4258 if (failed(verifyRegion(getDefaultRegion(), "default region")))

4259 return failure();

4260 for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))

4261 if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))

4262 return failure();

4263

4264 return success();

4265 }

4266

// Returns the number of case values of this switch.
4267 unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }

4268

// Returns the first block of the default region.
4269 Block &scf::IndexSwitchOp::getDefaultBlock() {

4270 return getDefaultRegion().front();

4271 }

4272

// Returns the first block of the case region at index `idx`; asserts that
// `idx` is smaller than the number of cases.
4273 Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {

4274 assert(idx < getNumCases() && "case index out-of-bounds");

4275 return getCaseRegions()[idx].front();

4276 }

4277

// Fills `successors` with the possible successors of this op: in one case
// only the op results, otherwise all regions of the op.
// NOTE(review): the parameter list and the condition of the early-return
// branch are missing from this listing.
4278 void IndexSwitchOp::getSuccessorRegions(

4280

4282 successors.emplace_back(getResults());

4283 return;

4284 }

4285

4286 llvm::append_range(successors, getRegions());

4287 }

4288

// Fills `successors` with the regions that may be entered given `operands`.
// If the switch argument is not available as the expected attribute, all
// regions are added; otherwise only the case region whose value equals the
// argument, or the default region when no case matches.
// NOTE(review): the parameter list and the cast's attribute type are missing
// from this listing.
4289 void IndexSwitchOp::getEntrySuccessorRegions(

4292 FoldAdaptor adaptor(operands, *this);

4293

4294

4295 auto arg = dyn_cast_or_null(adaptor.getArg());

4296 if (!arg) {

4297 llvm::append_range(successors, getRegions());

4298 return;

4299 }

4300

4301

4302

// Case values and case regions are walked in lockstep.
4303 for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {

4304 if (caseValue == arg.getInt()) {

4305 successors.emplace_back(&caseRegion);

4306 return;

4307 }

4308 }

4309 successors.emplace_back(&getDefaultRegion());

4310 }

4311

// Appends one invocation bound per region to `bounds`. Without a usable
// constant operand every region gets bounds (0, 1); otherwise exactly one
// region (index `liveIndex`) gets an upper bound of 1 and all others 0.
// NOTE(review): the parameter list and the cast's attribute type are missing
// from this listing.
4312 void IndexSwitchOp::getRegionInvocationBounds(

4314 auto operandValue = llvm::dyn_cast_or_null(operands.front());

4315 if (!operandValue) {

4316

4317 bounds.append(getNumRegions(), InvocationBounds(0, 1));

4318 return;

4319 }

4320

// NOTE(review): the fallback is the last region index and a match uses the
// position within getCases(); confirm both line up with the region order of
// the op (default region vs. case regions), which is declared elsewhere.
4321 unsigned liveIndex = getNumRegions() - 1;

4322 const auto *it = llvm::find(getCases(), operandValue.getInt());

4323 if (it != getCases().end())

4324 liveIndex = std::distance(getCases().begin(), it);

4325 for (unsigned i = 0, e = getNumRegions(); i < e; ++i)

4326 bounds.emplace_back(0, i == liveIndex);

4327 }

4328

4331

4334

4335

4337 if (!maybeCst.has_value())

4338 return failure();

4339 int64_t cst = *maybeCst;

4340 int64_t caseIdx, e = op.getNumCases();

4341 for (caseIdx = 0; caseIdx < e; ++caseIdx) {

4342 if (cst == op.getCases()[caseIdx])

4343 break;

4344 }

4345

4346 Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]

4347 : op.getDefaultRegion();

4348 Block &source = r.front();

4351

4353 rewriter.eraseOp(terminator);

4354

4355

4356 rewriter.replaceOp(op, results);

4357

4358 return success();

4359 }

4360 };

4361

4362 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,

4365 }

4366

4367

4368

4369

4370

4371 #define GET_OP_CLASSES

4372 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"

static std::optional< int64_t > getUpperBound(Value iv)

Gets the constant upper bound on an affine.for iv.

static std::optional< int64_t > getLowerBound(Value iv)

Gets the constant lower bound on an iv.

static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)

Returns the mutable operand range used to transfer operands from block to its successor with the given index.

static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)

static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})

Replaces the given op with the contents of the given single-block region, using the operands of the block terminator to replace operation results.

static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)

Parse the case regions and values.

static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)

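Verifies that two ranges of types match, i.e. have the same number of entries and pairwise equal types; reports an error on the given operation in case of a mismatch.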

static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")

Prints the initialization list in the form of (inner = outer, inner2 = outer2,...

static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)

Verifies that the first block of the given region is terminated by a TerminatorTy.

static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)

Print the case regions and values.

static MLIRContext * getContext(OpFoldResult val)

static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)

Utility to check that all of the operations within 'src' can be inlined.

static std::string diag(const llvm::Value &value)

static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)

@ Paren

Parens surrounding zero or more operands.

virtual Builder & getBuilder() const =0

Return a builder which provides useful access to MLIRContext, global objects like types and attribute...

virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0

Parse a named dictionary into 'result' if it is present.

virtual ParseResult parseOptionalKeyword(StringRef keyword)=0

Parse the given keyword if present.

MLIRContext * getContext() const

virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0

Emit a diagnostic at the specified location and return failure.

virtual ParseResult parseOptionalColon()=0

Parse a : token if present.

ParseResult parseInteger(IntT &result)

Parse an integer value from the stream.

virtual ParseResult parseEqual()=0

Parse a = token.

virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0

Parse a named dictionary into 'result' if the attributes keyword is present.

virtual ParseResult parseColonType(Type &result)=0

Parse a colon followed by a type.

virtual SMLoc getCurrentLocation()=0

Get the location of the next token and store it into the argument.

virtual SMLoc getNameLoc() const =0

Return the location of the original name token.

virtual ParseResult parseType(Type &result)=0

Parse a type.

virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0

Parse an optional arrow followed by a type list.

virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0

Parse an arrow followed by a type list.

ParseResult parseKeyword(StringRef keyword)

Parse a given keyword.

void printOptionalArrowTypeList(TypeRange &&types)

Print an optional arrow followed by a type list.

This class represents an argument of a Block.

unsigned getArgNumber() const

Returns the number of this argument.

Block represents an ordered list of Operations.

MutableArrayRef< BlockArgument > BlockArgListType

BlockArgument getArgument(unsigned i)

unsigned getNumArguments()

iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)

Add one argument to the argument list for each type specified in the list.

Region * getParent() const

Provide a 'getParent' method for ilist_node_with_parent methods.

Operation * getTerminator()

Get the terminator operation of this block.

BlockArgument addArgument(Type type, Location loc)

Add one value to the argument list.

void eraseArguments(unsigned start, unsigned num)

Erases 'num' arguments from the index 'start'.

BlockArgListType getArguments()

iterator_range< iterator > without_terminator()

Return an iterator range over the operations within this block, excluding the terminator operation at the end.

Operation * getParentOp()

Returns the closest surrounding operation that contains this block.

Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.

bool getValue() const

Return the boolean value of this attribute.

This class is a general helper class for creating context-global objects like types,...

IntegerAttr getIndexAttr(int64_t value)

DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)

IntegerAttr getIntegerAttr(Type type, int64_t value)

DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)

IntegerType getIntegerType(unsigned width)

BoolAttr getBoolAttr(bool value)

MLIRContext * getContext() const

This is the interface that must be implemented by the dialects of operations to be inlined.

DialectInlinerInterface(Dialect *dialect)

Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...

virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)

Registered hook to materialize a single constant operation from a given attribute value with the desired resultant type.

This is a utility class for mapping one set of IR entities to another.

auto lookupOrDefault(T from) const

Lookup a mapped value within the map.

void map(Value from, Value to)

Inserts a new mapping for 'from' to 'to'.

IRValueT get() const

Return the current value being used by this operand.

void set(IRValueT newValue)

Set the current value being used by this operand.

This class represents a diagnostic that is inflight and set to be reported.

This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

This class provides a mutable adaptor for a range of operands.

The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...

virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0

virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0

Parses a region.

virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0

Parse zero or more arguments with a specified surrounding delimiter.

ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)

Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)

virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0

Resolve an operand to an SSA value, emitting an error on failure.

ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)

Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...

virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0

Parse a single SSA value operand name along with a result number if allowResultNumber is true.

virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0

Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...

This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...

virtual void printNewline()=0

Print a newline and indent the printer to the start of the current operation.

virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0

If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...

virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0

If the specified operation has attributes, print out an attribute dictionary with their values.

void printFunctionalType(Operation *op)

Print the complete type of an operation in functional form.

virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0

Prints a region.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

Operation * clone(Operation &op, IRMapping &mapper)

Creates a deep copy of the specified operation, remapping any operands that use values outside of the...

void setInsertionPointToStart(Block *block)

Sets the insertion point to the start of the specified block.

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

void setInsertionPointToEnd(Block *block)

Sets the insertion point to the end of the specified block.

void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)

Clone the blocks that belong to "region" before the given position in another region "parent".

Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)

Add new block with 'argTypes' arguments and set the insertion point to the end of it.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

void setInsertionPointAfter(Operation *op)

Sets the insertion point to the node after the specified operation, which will cause subsequent inser...

Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)

Creates a deep copy of this operation but keep the operation regions empty.

This class represents a single result from folding an operation.

This class represents an operand of an operation.

unsigned getOperandNumber()

Return which operand this is in the OpOperand list of the Operation.

This is a value defined by a result of an operation.

unsigned getResultNumber() const

Returns the number of this result.

This class implements the operand iterators for the Operation class.

Operation is the basic unit of execution within MLIR.

void setAttrs(DictionaryAttr newAttrs)

Set the attributes from a dictionary on this operation.

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

Location getLoc()

The source location the operation was defined or derived from.

ArrayRef< NamedAttribute > getAttrs()

Return all of the attributes on this operation.

OpTy getParentOfType()

Return the closest surrounding parent operation that is of type 'OpTy'.

OperationName getName()

The name of an operation is the key identifier for it.

operand_range getOperands()

Returns an iterator on the underlying Value's.

void replaceAllUsesWith(ValuesT &&values)

Replace all uses of results of this operation with the provided 'values'.

Region * getParentRegion()

Returns the region to which the instruction belongs.

bool isProperAncestor(Operation *other)

Return true if this operation is a proper ancestor of the other operation.

InFlightDiagnostic emitOpError(const Twine &message={})

Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.

This class implements Optional functionality for ParseResult.

ParseResult value() const

Access the internal ParseResult value.

bool has_value() const

Returns true if we contain a valid ParseResult value.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

This class represents a point being branched from in the methods of the RegionBranchOpInterface.

bool isParent() const

Returns true if branching from the parent op.

This class provides an abstraction over the different types of ranges over Regions.

This class represents a successor of a region.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

bool isAncestor(Region *other)

Return true if this region is ancestor of the other region.

unsigned getNumArguments()

BlockArgument getArgument(unsigned i)

RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)

Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.

This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure,...

virtual void eraseBlock(Block *block)

This method erases all operations in a block.

Block * splitBlock(Block *block, Block::iterator before)

Split the operations starting at "before" (inclusive) out of the given block into a new block,...

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements...

void replaceAllUsesWith(Value from, Value to)

Find uses of from and replace them with to.

void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)

Inline the operations of block 'source' into the end of block 'dest'.

virtual void finalizeOpModification(Operation *op)

This method is used to signal the end of an in-place modification of the given operation.

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)

Find uses of from and replace them with to if the functor returns true.

void modifyOpInPlace(Operation *root, CallableT &&callable)

This method is a utility wrapper around an in-place modification of an operation.

void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)

Move the blocks that belong to "region" before the given position in another region "parent".

virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)

Inline the operations of block 'source' into block 'dest' before the given position.

virtual void startOpModification(Operation *op)

This method is used to notify the rewriter that an in-place operation modification is about to happen...

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

This class provides an abstraction over the different types of ranges over Values.

type_range getTypes() const

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

bool use_empty() const

Returns true if this value has no uses.

Type getType() const

Return the type of this value.

Block * getParentBlock()

Return the Block in which this Value is defined.

Location getLoc() const

Return the location of this value.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

Region * getParentRegion()

Return the Region in which this Value is defined.

Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...

ArrayRef< T > asArrayRef() const

Operation * getOwner() const

Return the owner of this operand.

constexpr auto RecursivelySpeculatable

Speculatability

This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...

constexpr auto NotSpeculatable

LogicalResult promoteIfSingleIteration(AffineForOp forOp)

Promotes the loop body of an AffineForOp to its containing block if the loop was known to have a singl...

arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)

Invert an integer comparison predicate.

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

StringRef getMappingAttrName()

Name of the mapping attribute produced by loop mappers.

QueryRef parse(llvm::StringRef line, const QuerySession &qs)

ParallelOp getParallelForInductionVarOwner(Value val)

Returns the parallel loop parent of an induction variable.

void buildTerminatedBody(OpBuilder &builder, Location loc)

Default callback for IfOp builders. Inserts a yield without arguments.

LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)

Creates a perfect nest of "for" loops, i.e.

bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)

Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.

void promote(RewriterBase &rewriter, scf::ForallOp forallOp)

Promotes the loop body of a scf::ForallOp to its containing block.

ForOp getForInductionVarOwner(Value val)

Returns the loop parent of an induction variable.

ForallOp getForallOpThreadIndexOwner(Value val)

Returns the ForallOp parent of a thread index variable.

SmallVector< Value > ValueVector

An owning vector of values, handy to return from functions.

SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)

bool preservesStaticInformation(Type source, Type target)

Returns true if target is a ranked tensor type that preserves static information available in the sou...

Include the generated interface declarations.

bool matchPattern(Value value, const Pattern &pattern)

Entry point for matching a pattern over a Value.

detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)

Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...

const FrozenRewritePatternSet GreedyRewriteConfig bool * changed

std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)

If ofr is a constant integer or an IntegerAttr, return the integer.

ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)

Parser hooks for custom directive in assemblyFormat.

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)

Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.

std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn

A function that returns the additional yielded values during replaceWithAdditionalYields.

detail::constant_int_predicate_matcher m_One()

Matches a constant scalar / vector splat / tensor splat integer one.

void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)

Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...

Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)

Converts an OpFoldResult to a Value.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)

Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...

LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)

Verify that the values has as many elements as the number of entries in attr for which isDynamic ev...

detail::constant_op_matcher m_Constant()

Matches a constant foldable operation.

LogicalResult verify(Operation *op, bool verifyRecursively=true)

Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...

SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)

void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)

Printer hooks for custom directive in assemblyFormat.

LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)

Returns "success" when any of the elements in ofrs is a constant value.

LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override

LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override

LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override

UnresolvedOperand ssaName

This is the representation of an operand reference.

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...

OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})

Patterns must specify the root operation name they match against, and can also specify the benefit of...

This represents an operation in an abstracted form, suitable for use with the builder APIs.

SmallVector< Value, 4 > operands

void addOperands(ValueRange newOperands)

void addAttribute(StringRef name, Attribute attr)

Add an attribute with the specified name.

void addTypes(ArrayRef< Type > newTypes)

SmallVector< std::unique_ptr< Region >, 1 > regions

Regions that the op will hold.

SmallVector< Type, 4 > types

Types of the results of this operation.

Region * addRegion()

Create a region that should be attached to the operation.

Eliminates variable at the specified position using Fourier-Motzkin variable elimination.