MLIR: lib/Dialect/SCF/Transforms/TileUsingInterface.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

14

31 #include "llvm/ADT/ScopeExit.h"

32 #include "llvm/ADT/TypeSwitch.h"

33 #include "llvm/Support/Debug.h"

34 #include

35

36 #define DEBUG_TYPE "tile-using-interface"

37

38 using namespace mlir;

39

43 auto tileSizes = llvm::to_vector(ts);

45 return tileSizes;

46 };

47 return *this;

48 }

49

52 assert(!numThreadsComputationFunction && "num tiles already set");

53 auto numThreads = llvm::to_vector(nt);

54 numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {

55 return numThreads;

56 };

57 return *this;

58 }

59

60

61

64 size_t iterationDomainSize) {

66 if (filledVector.size() < iterationDomainSize) {

67 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);

68 filledVector.append(range.begin(), range.end());

69 }

70 if (filledVector.size() > iterationDomainSize)

71 filledVector.resize(iterationDomainSize);

72 return filledVector;

73 }

74

75

76

77

78

79

80 static LogicalResult

83

84 if (options.numThreadsComputationFunction &&

87 loc, "number of threads can only by specified when loop type is "

88 "set to use `scf.forall`");

89 }

90

91

92 if (options.interchangeVector.empty()) {

95 loc, "invalid interchange vector, not a permutation of the entire "

96 "iteration space");

97 }

98 }

99 return success();

100 }

101

102

103

110 size_t numLoops = iterationDomain.size();

111

112

113 if (options.numThreadsComputationFunction) {

114 numThreads = options.numThreadsComputationFunction(rewriter, op);

115 numThreads.resize(numLoops, zero);

116

117

118 if (options.tileSizeComputationFunction) {

119 tileSizes = options.tileSizeComputationFunction(rewriter, op);

120 tileSizes.resize(numLoops, zero);

121 return {tileSizes, numThreads};

122 }

123

124

125

126

127

130

133 tileSizes.resize(numLoops, zero);

134 for (auto [index, range, nt] :

137 continue;

138

140 rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});

141 }

142 tileSizes.resize(numLoops, zero);

143 return {tileSizes, numThreads};

144 }

145

146

147

148

149

150 assert(options.tileSizeComputationFunction &&

151 "expected tile sizes to be specified");

152 tileSizes = options.tileSizeComputationFunction(rewriter, op);

153 tileSizes.resize(numLoops, zero);

154

155 return {tileSizes, numThreads};

156 }

157

158

162 auto iterators = op.getLoopIteratorTypes();

163 assert(iterators.size() == tileSizes.size() &&

164 "expected as many tile size values as number of loops");

165 assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&

166 "when specified, expected number of threads to use for each loop");

167

168 for (auto [index, iterator, tileSize] :

170

171

172 if (!numThreads.empty()) {

173 if (std::optional<int64_t> constNumThreads =

175 if (constNumThreads.value() > 1 &&

176 iterator != utils::IteratorType::parallel) {

177 op.emitWarning() << "tiling is not thread safe at axis #" << index;

178 }

179 }

180 continue;

181 }

182

183 if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) {

184 if (constTileSize.value() > 0 &&

185 iterator != utils::IteratorType::parallel) {

186 op.emitWarning() << "tiling is not thread safe at axis #" << index;

187 }

188 }

189 }

190 }

191

192

195 if (!offsetAsInt)

196 return false;

198 if (!sizeAsInt)

199 return false;

201 if (!strideAsInt)

202 return false;

203 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);

204 }

205

206

207

212 if (ts && ts.value() == 1)

213 return tileSize;

214

217 return tileSize;

218

219

220

221

229 }

230

231

232

237 std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);

238 std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);

239 if (!tileSizeConst || !numThreadsConst || !iterSizeConst)

240 return false;

241 return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;

242 }

243

244

245

246

253 int materializedLoopNum = 0;

254

255 if (!numThreads.empty()) {

257 AffineExpr offsetExpr, residualTileSizeExpr;

260 offsetExpr = d0 + d1 * s0;

261 residualTileSizeExpr = s1 - (d0 + d1 * s0);

262

263 for (auto [nt, tileSize, loopRange] :

264 llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {

265

266

267

269 offsets.push_back(loopRange.offset);

270 sizes.push_back(loopRange.size);

271 continue;

272 }

273

274 Value iv = ivs[materializedLoopNum++];

276 rewriter, loc, offsetExpr,

279 rewriter, loc, residualTileSizeExpr,

280 {loopRange.offset, nt, tileSize, loopRange.size});

281

286 {offset, loopRange.size});

288 rewriter, loc,

290 {sizeMinusOffsetPerThread, tileSize});

291 }

292

293

294

295

296

297

298

299

300

301

306 rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});

307 }

308

309 offsets.push_back(offset);

310 sizes.push_back(size);

311 }

312 return {offsets, sizes};

313 } else {

314 for (auto [tileSize, loopRange] :

315 llvm::zip_equal(tileSizes, iterationDomain)) {

316

317

318

320 offsets.push_back(loopRange.offset);

321 sizes.push_back(loopRange.size);

322 continue;

323 }

324

325 Value iv = ivs[materializedLoopNum++];

327 offsets.push_back(offset);

330 sizes.push_back(size);

331 }

332 return {offsets, sizes};

333 }

334 }

335

336

342 for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {

343

345 continue;

346 lbs.push_back(loopRange.offset);

347 ubs.push_back(loopRange.size);

348 steps.push_back(tileSize);

349 }

350 return {lbs, ubs, steps};

351 }

352

353

354

355

356

357

358

359

360

361

362

363

364

365

366

367

373

374

375

380 if (newDestArgs.empty())

381 return clonedOp;

382 if (auto destinationStyleOp = dyn_cast(clonedOp))

383 destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);

384 return clonedOp;

385 }

386

387

388

389

390

391

392

393

394

395

401 assert(!loopRanges.empty() && "unexpected empty loop ranges");

402 assert(loopRanges.size() == tileSizes.size() &&

403 "expected as many tile sizes as loop ranges");

405

407 std::tie(lbs, ubs, steps) =

408 getLoopBounds(rewriter, loc, loopRanges, tileSizes);

415

417 for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {

418 auto loop =

419 rewriter.createscf::ForOp(loc, lb, ub, step, destinationTensors,

422 loops.push_back(loop);

423 ivs.push_back(loop.getInductionVar());

425 destinationTensors = loop.getRegionIterArgs();

426 }

427

430 if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,

431 tiledResults, resultOffsets, resultSizes))) {

433 loc, "failed to generate inner tile loop body");

434 }

435 if (loops.empty())

436 return success();

437

438 assert(tiledResults.size() == destinationTensors.size() &&

439 "Number of results of body should be equal to number of iter args");

440

441

443 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :

444 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,

445 resultSizes)) {

448 auto insertSlice = rewriter.createtensor::InsertSliceOp(

449 loc, tiledValue, destinationTensor, resultOffset, resultSize,

450 resultStride);

451 yieldedValues.push_back(insertSlice);

452 }

453 rewriter.createscf::YieldOp(loc, yieldedValues);

454

455

456 for (auto [outerLoop, innerLoop] :

460 castscf::ForOp(outerLoop.getOperation()).getBody());

461 rewriter.createscf::YieldOp(outerLoop.getLoc(), innerLoop->getResults());

462 }

463 return success();

464 }

465

466

467

468

469

470

471

472

473

474

475

476

482 assert(!loopRanges.empty() && "unexpected empty loop ranges");

483 assert(loopRanges.size() == tileSizes.size() &&

484 "expected as many tile sizes as loop ranges");

486

487 std::optional mappingAttr;

488 if (!mappingVector.empty())

489 mappingAttr = rewriter.getArrayAttr(mappingVector);

490

491 scf::ForallOp forallOp;

492 bool useNumThreads = !numThreads.empty();

493

494 if (useNumThreads) {

495

497 for (auto nt : numThreads) {

499 continue;

500 nonZeroNumThreads.push_back(nt);

501 }

502 forallOp = rewriter.createscf::ForallOp(loc, nonZeroNumThreads,

503 destinationTensors, mappingAttr);

504 } else {

506 std::tie(lbs, ubs, steps) =

507 getLoopBounds(rewriter, loc, loopRanges, tileSizes);

508 forallOp = rewriter.createscf::ForallOp(loc, lbs, ubs, steps,

509 destinationTensors, mappingAttr);

510 }

511 loops.push_back(forallOp);

512

514 destinationTensors = forallOp.getRegionOutArgs();

515

518 if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),

519 destinationTensors, tiledResults, resultOffsets,

520 resultSizes)))

521 return rewriter.notifyMatchFailure(loc, "failed to generate loop body");

522

524 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :

525 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,

526 resultSizes)) {

529

530 rewriter.createtensor::ParallelInsertSliceOp(

531 loc, tiledValue, destinationTensor, resultOffset, resultSize,

532 resultStride);

533 }

534 return success();

535 }

536

537

538

539

540

541

542

543

544

545

546

552

553

557 return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,

558 tiledResults, resultOffsets, resultSizes);

559 }

562 destinationTensors, tiledBodyFn, loops);

563 }

566 rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector,

567 destinationTensors, tiledBodyFn, loops);

568 }

570 }

571

572 static FailureOr<SmallVector>

578 switch (options.reductionStrategy) {

581 return failure();

582 return initTensors;

585 auto redOp = dyn_cast(op.getOperation());

586 if (!redOp) {

588 op, "PartialReductionOuterReduction tiling strategy is only supported"

589 "for operations implementing PartialReductionOpInterface");

590 }

591

592

593

595 for (auto [idx, iteratorType] :

597 if (iteratorType == utils::IteratorType::reduction)

598 reductionDims.push_back(idx);

599 }

600 return redOp.generateInitialTensorForPartialReduction(

601 rewriter, loc, tileSizes, reductionDims);

602 }

603 default:

605 "unhandled reduction tiling strategy");

606 }

607 }

608

609 static FailureOr

614 switch (options.reductionStrategy) {

616 return op.getTiledImplementation(rewriter, offsets, sizes);

619 auto redOp = dyn_cast(op.getOperation());

620 if (!redOp) {

622 op, "PartialReductionOuterReduction tiling strategy is only "

623 "supported for operations "

624 "implementing PartialReductionOpInterface");

625 }

626

627

628

630 for (auto [idx, iteratorType] :

632 if (iteratorType == utils::IteratorType::reduction)

633 reductionDims.push_back(idx);

634 }

635 return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg,

636 offsets, sizes, reductionDims);

637 }

638 default:

640 "unhandled reduction tiling strategy");

641 }

642 }

643

644 static LogicalResult

651

652 switch (options.reductionStrategy) {

654 return op.getResultTilePosition(rewriter, index, offsets, sizes,

655 resultOffset, resultSize);

658 auto redOp = dyn_cast(op.getOperation());

659 if (!redOp) {

661 op, "PartialReductionOuterReduction tiling strategy is only supported"

662 "for operations implementing PartialReductionOpInterface");

663 }

664

665

666

668 for (auto [idx, iteratorType] :

670 if (iteratorType == utils::IteratorType::reduction)

671 reductionDims.push_back(idx);

672 }

673 return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,

674 resultOffset, resultSize,

675 reductionDims);

676 }

677 default:

679 "unhandled reduction tiling strategy");

680 }

681 }

682

683 static FailureOr

687 switch (options.reductionStrategy) {

689

693 auto redOp = dyn_cast(op.getOperation());

694 if (!redOp) {

696 op, "PartialReductionOuterReduction tiling strategy is only "

697 "supported for operations "

698 "implementing PartialReductionOpInterface");

699 }

700

701

702

704 for (auto [idx, iteratorType] :

706 if (iteratorType == utils::IteratorType::reduction)

707 reductionDims.push_back(idx);

708 }

709 return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,

710 reductionDims);

711 }

712 default:

714 "unhandled reduction tiling strategy");

715 }

716 }

717

718

719

720

721

722

723

724

725 template

726 FailureOr

731 }

732

733

734 template <>

735 FailureOr yieldTiledValuesAndReplaceLoopscf::ForOp(

739 Location loc = loopOp.getLoc();

741

742 auto inits = llvm::to_vector(loopOp.getInitArgs());

743 inits.append(newInitOperands.begin(), newInitOperands.end());

744 auto newLoop = rewriter.createscf::ForOp(

745 loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),

747

748

749 Block *loopBody = loopOp.getBody();

750 Block *newLoopBody = newLoop.getBody();

752 loopBody, newLoopBody,

753 newLoopBody->getArguments().take_front(loopBody->getNumArguments()));

754

755 auto yieldOp = castscf::YieldOp(newLoopBody->getTerminator());

757

761 newLoop.getRegionIterArgs().take_back(newInitOperands.size());

762 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),

763 newRegionIterArgs, tiledValues, resultOffsets,

764 resultSizes))) {

765 rewriter.eraseOp(newLoop);

766 return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");

767 }

768

769 SmallVector newYieldValues = llvm::to_vector(yieldOp.getOperands());

770 for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :

771 llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,

772 resultSizes)) {

775 Value insert = rewriter.createtensor::InsertSliceOp(

776 yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,

777 resultStride);

778 newYieldValues.push_back(insert);

779 }

780

783 newLoop->getResults().take_front(loopOp.getNumResults()));

784 return cast(newLoop.getOperation());

785 }

786

787

788 template <>

789 FailureOr yieldTiledValuesAndReplaceLoopscf::ForallOp(

793 Location loc = loopOp.getLoc();

795 auto inits = llvm::to_vector(loopOp.getOutputs());

796 inits.append(newInitOperands.begin(), newInitOperands.end());

797 auto newLoop = rewriter.createscf::ForallOp(

798 loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),

799 loopOp.getMixedStep(), inits, loopOp.getMapping(),

801

802

803 Block *loopBody = loopOp.getBody();

804 Block *newLoopBody = newLoop.getBody();

806 loopBody, newLoopBody,

807 newLoopBody->getArguments().take_front(loopBody->getNumArguments()));

808

809 auto terminator = castscf::InParallelOp(newLoopBody->getTerminator());

814 newLoop.getRegionIterArgs().take_back(newInitOperands.size());

815 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),

816 regionIterArgs, tiledValues, resultOffsets,

817 resultSizes))) {

818 rewriter.eraseOp(newLoop);

820 "failed to get yielded tiled values");

821 }

822

823

825

826 for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(

827 tiledValues, regionIterArgs, resultOffsets, resultSizes)) {

830 rewriter.createtensor::ParallelInsertSliceOp(

831 terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,

832 resultStride);

833 }

834

836 newLoop->getResults().take_front(loopOp.getNumResults()));

837 return cast(newLoop.getOperation());

838 }

839

840

841

842

844 LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,

847 loopLikeOp.getOperation())

848 .Case<scf::ForOp, scf::ForallOp>(

849 [&](auto loopOp) -> FailureOr {

851 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);

852 })

853 .Default([&](auto loopOp) -> FailureOr {

855 });

856 }

857

858

859

860

861

862

866 if (loops.empty())

867 return success();

870

872 for (auto &loop : loops.drop_back()) {

874

875

876 auto forLoop = castscf::ForOp(loop.getOperation());

877

878

879 SmallVector newInits = llvm::to_vector(forLoop.getInitArgs());

880 newInits.append(newInitValues.begin(), newInitValues.end());

881 auto newLoop = rewriter.createscf::ForOp(

882 forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),

883 forLoop.getStep(), newInits,

885

886

888 sourceBlockArgs.push_back(newLoop.getInductionVar());

889 auto newRegionIterArgs = newLoop.getRegionIterArgs();

890 sourceBlockArgs.append(

891 newRegionIterArgs.begin(),

892 std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));

893 rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);

895 forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));

896 loop = newLoop;

897 ivs.push_back(newLoop.getInductionVar());

898 newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());

899 }

900

901

902 LoopLikeOpInterface innerMostLoop = loops.back();

903 FailureOr newInnerMostLoop =

905 getNewTiledYieldsFn);

906

907 if (failed(newInnerMostLoop))

908 return innerMostLoop.emitOpError("failed to return additional yields");

909 loops.back() = newInnerMostLoop.value();

910

911

912

913 for (auto [outerLoop, innerLoop] :

914 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {

915

916 auto outerForLoop = castscf::ForOp(outerLoop);

917 auto outerLoopYield =

918 castscf::YieldOp(outerForLoop.getBody()->getTerminator());

920 llvm::to_vector(outerLoopYield.getOperands());

922 innerLoop->getResults().take_back(newInitValues.size());

923 newYields.append(additionalYields.begin(), additionalYields.end());

925 rewriter.replaceOpWithNewOpscf::YieldOp(outerLoopYield, newYields);

926 }

927 return success();

928 }

929

930

931

932 FailureOrscf::SCFTilingResult

936 return failure();

937 }

938

941

942

943 SmallVector iterationDomain = op.getIterationDomain(rewriter);

944

945

947 std::tie(tileSizes, numThreads) =

949

950

951

954 }

955

956

957

959 if (options.interchangeVector.empty()) {

961 iterationDomain.size());

963 "expected interchange vector to be a permutation");

964

967 if (!numThreads.empty())

969 }

970

971 FailureOr tilingResult;

972

973

979 -> LogicalResult {

980

983 rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);

984

985

986

987 if (!interchangeVector.empty()) {

991 }

992

993

994

995

996 auto clonedOp = cast(

998

999

1000

1001

1003 tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());

1004 tilingResult =

1005 TilingResult{{clonedOp}, clonedOp->getResults(),

1006 {}};

1007 return success();

1008 }

1009

1010

1012 offsets, sizes, options);

1013 if (failed(tilingResult)) {

1014 rewriter.eraseOp(clonedOp);

1015 return op.emitOpError("faild to tile operation");

1016 }

1017

1018

1019 rewriter.eraseOp(clonedOp);

1020

1021

1022

1023 for (auto [index, tiledValue] :

1025 tiledResults.push_back(tiledValue);

1028 sizes, resultOffset, resultSize,

1030 for (auto op : tilingResult->tiledOps) {

1032 }

1034 op, "failed to get slice of result produced");

1035 }

1036 resultOffsets.emplace_back(std::move(resultOffset));

1037 resultSizes.emplace_back(std::move(resultSize));

1038 }

1039

1040 return success();

1041 };

1042

1043

1044 FailureOr<SmallVector> maybeInits =

1046 if (failed(maybeInits)) {

1048 op, "unable to create initial tensors for tiling");

1049 }

1051

1052

1055 tileSizes, numThreads, initTensors,

1056 innerYieldTiledValuesFn, loops)))

1057 return op.emitOpError("failed to generate tiling loops");

1058 assert(succeeded(tilingResult) &&

1059 "expected tiling result to be computed after loop generation");

1060

1061 if (loops.empty()) {

1062

1063

1065 initTensors,

1066 loops,

1067 tilingResult->tiledValues,

1068 tilingResult->generatedSlices,

1069 {}};

1070 }

1071

1072 auto loopResults = llvm::map_to_vector(loops.front()->getResults(),

1074

1075

1076 if (options.reductionStrategy ==

1079 tilingResult->tiledOps, initTensors, loops, loopResults,

1080 tilingResult->generatedSlices, {}};

1081 }

1082

1083

1084 FailureOr mergeResult =

1086 if (failed(mergeResult)) {

1088 op, "Failed to merge partial results from tiling");

1089 }

1091 initTensors,

1092 loops,

1093 mergeResult->replacements,

1094 tilingResult->generatedSlices,

1095 mergeResult->mergeOps};

1096 }

1097

1098 FailureOrscf::SCFTilingResult

1100 PartialReductionOpInterface op,

1104 options.setReductionTilingStrategy(

1106 PartialReductionOuterReduction);

1107 options.setTileSizes(tileSize);

1109 }

1110

1111

1112

1113

1114

1115

1116

1117

1118

1119

1120

1121 static std::tuple<OpResult, std::optional<OpOperand *>>

1124 std::optional<OpOperand *> destinationIterArg;

1125 assert(!loops.empty() && "expected non empty loops container");

1126 auto loopIt = loops.rbegin();

1127 while (loopIt != loops.rend() && isa(source->get())) {

1128 auto iterArg = cast(source->get());

1129 auto loop = *loopIt;

1130 if (iterArg.getOwner()->getParentOp() != loop)

1131 break;

1132 source = loop.getTiedLoopInit(iterArg);

1133 loopIt++;

1134 }

1135 if (loopIt == loops.rend())

1136 destinationIterArg = source;

1137 return {dyn_cast(source->get()), destinationIterArg};

1138 }

1139

1140

1141

1142 std::optionalscf::SCFFuseProducerOfSliceResult

1144 RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,

1146

1147

1148 auto [fusableProducer, destinationInitArg] =

1150 loops);

1151 if (!fusableProducer)

1152 return std::nullopt;

1153 unsigned resultNumber = fusableProducer.getResultNumber();

1154

1157

1158

1159

1160 SmallVector origDestinationTensors, clonedOpDestinationTensors;

1161 Operation *fusableProducerOp = fusableProducer.getOwner();

1162 if (isa(fusableProducerOp) &&

1164 rewriter, fusableProducerOp->getLoc(), fusableProducerOp,

1165 origDestinationTensors)))

1166 return std::nullopt;

1167

1168 clonedOpDestinationTensors = origDestinationTensors;

1169 if (destinationInitArg &&

1170 isa(fusableProducerOp)) {

1171

1172

1173

1174 clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();

1175 }

1176

1178 rewriter, fusableProducerOp, clonedOpDestinationTensors);

1179

1180

1181

1183 llvm::to_vector(candidateSliceOp->getOperands());

1184 candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);

1185 tensor::ExtractSliceOp clonedCandidateSliceOp =

1187 candidateSliceOp->getResultTypes(), candidateSliceOpOperands);

1188

1189

1190 FailureOr tileAndFuseResult =

1192 rewriter, clonedCandidateSliceOp,

1193 clonedProducerOp->getResult(resultNumber));

1194 if (failed(tileAndFuseResult))

1195 return std::nullopt;

1196

1197

1199 tileAndFuseResult->tiledValues[0]);

1200 rewriter.eraseOp(clonedCandidateSliceOp);

1201 rewriter.eraseOp(clonedProducerOp);

1202

1203

1204

1205

1206

1207

1208

1209

1210

1211

1212

1213

1214

1215

1216

1217

1218

1219

1220

1221

1222

1223

1224

1225

1226

1227

1228

1229

1230

1231

1232

1233

1234

1235

1236

1237

1238

1239

1240

1241

1242

1243

1244

1245

1246 if (destinationInitArg &&

1247 isa(fusableProducerOp) && !loops.empty()) {

1248 loops.front()

1249 ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]

1250 .set(origDestinationTensors[resultNumber]);

1251 }

1253 fusableProducer, tileAndFuseResult->tiledValues[0],

1254 tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};

1255 }

1256

1257

1259 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,

1263 if (loops.empty())

1264 return success();

1265

1267 *tiledOwner = fusedProducerInfo.tiledOps[0];

1268

1270

1272 yieldResultNumber.empty() ? llvm::to_vector(llvm::seq(

1274 : llvm::to_vector(yieldResultNumber);

1276 for (const auto &resultNumber : initNumberList) {

1278 rewriter, loc, originalOwner->getResult(resultNumber));

1279 if (succeeded(initValue)) {

1280 initValueList.push_back(initValue.value());

1281 } else {

1282 return failure();

1283 }

1284 }

1285

1293

1294

1296 sliceSizes = sliceOp.getMixedSizes();

1297

1298

1299 if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))

1300 return failure();

1301

1302 unsigned sliceResultNumber =

1304

1305 auto tilableOp = cast(originalOwner);

1306

1308

1309 if (tilableOp->getNumResults() > 1 &&

1310 failed(tilableOp.getIterationDomainTileFromResultTile(

1311 rewriter, sliceResultNumber, sliceOffset, sliceSizes,

1312 iterDomainOffset, iterDomainSizes))) {

1313

1314

1315

1316

1317

1318

1319

1320

1321 return failure();

1322 }

1323

1324

1325

1327 for (const auto &resultNumber : initNumberList) {

1328 if (resultNumber == sliceResultNumber) {

1329 offsetList.push_back(sliceOffset);

1330 sizesList.push_back(sliceSizes);

1331 } else {

1332 assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());

1333

1335 if (failed(tilableOp.getResultTilePosition(

1336 rewriter, resultNumber, iterDomainOffset, iterDomainSizes,

1337 offset, sizes))) {

1338 return failure();

1339 }

1340 offsetList.push_back(offset);

1341 sizesList.push_back(sizes);

1342 }

1343 }

1344

1345

1346

1347 if (auto tiledDestStyleOp =

1348 dyn_cast(tiledOwner)) {

1350 for (const auto &&[index, newRegionArg] :

1352 auto destSlice = rewriter.createtensor::ExtractSliceOp(

1353 loc, newRegionArg, offsetList[index], sizesList[index],

1356 generatedSlices.push_back(destSlice);

1357 unsigned resultNumber = initNumberList[index];

1359 tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);

1360 });

1361 }

1362 }

1363

1364

1365

1368 for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {

1369 tiledResult.push_back(tiledOwner->getResult(resultNumber));

1370 tiledOffset.emplace_back(offsetList[index]);

1371 tiledSizes.emplace_back(sizesList[index]);

1372 }

1373 return success();

1374 };

1375

1377 newYieldValuesFn))) {

1378 return failure();

1379 }

1380 return generatedSlices;

1381 }

1382

1383 namespace {

1384

1385

1386

1387

1388

1389

1390

1391

1393 public:

1394 explicit SliceTrackingListener(

1395 std::optional patterns);

1396 SliceTrackingListener() = default;

1397

1398

1399

1400

1401

1403

1404

1405 void notifyOperationInserted(Operation *op,

1407

1408

1410

1411

1412 void notifyOperationErased(Operation *op) override;

1413

1414

1415 void notifyOperationReplaced(Operation *op, ValueRange replacement) override;

1416

1417

1418

1419 std::dequetensor::ExtractSliceOp worklist;

1420

1421 private:

1422

1423

1424 std::optional patterns = std::nullopt;

1425 };

1426

1427 SliceTrackingListener::SliceTrackingListener(

1428 std::optional p) {

1430 }

1431

1432 LogicalResult

1435 if (auto slice = dyn_casttensor::ExtractSliceOp(op))

1436 worklist.push_back(slice);

1437 }

1438

1440 return success();

1441

1446 }

1447

1448 void SliceTrackingListener::notifyOperationInserted(

1450 auto slice = dyn_casttensor::ExtractSliceOp(op);

1451 if (!slice)

1452 return;

1453 worklist.push_back(slice);

1454 }

1455

1456

1457

1458

1459 void SliceTrackingListener::removeOp(Operation *op) {

1460 if (!isatensor::ExtractSliceOp(op))

1461 return;

1462 auto iter = worklist.begin();

1463 while (iter != worklist.end()) {

1464 if (*iter == op)

1465 break;

1466 iter++;

1467 }

1468 if (iter == worklist.end())

1469 return;

1470

1471 worklist.erase(iter);

1472 }

1473

1474 void SliceTrackingListener::notifyOperationErased(Operation *op) {

1475 removeOp(op);

1476 }

1477

1478 void SliceTrackingListener::notifyOperationReplaced(Operation *op,

1480 removeOp(op);

1481 }

1482

1483

1484

1485

1486

1487

1488

1489

1491 public:

1494 : ForwardingListener(listener), replacements(replacements) {}

1495

1496 void updateReplacementValues(ValueRange origValues,

1498

1499

1500 for (auto &[key, val] : replacements) {

1501 for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {

1502 if (val == orig) {

1503 val = replace;

1504 }

1505 }

1506 }

1507 }

1508

1509 void notifyOperationReplaced(Operation *op, Operation *newOp) override {

1510 ForwardingListener::notifyOperationReplaced(op, newOp);

1512 }

1513

1514 void notifyOperationReplaced(Operation *op, ValueRange values) override {

1515 ForwardingListener::notifyOperationReplaced(op, values);

1516 updateReplacementValues(op->getResults(), values);

1517 }

1518

1519 private:

1521 };

1522

1523 }

1524

1525

1526 FailureOrscf::SCFTileAndFuseResult

1528 RewriterBase &rewriter, TilingInterface consumer,

1530

1531

1532 if (!consumer->getNumResults()) {

1534 consumer, "invalid pattern for op with no results");

1535 }

1536

1537

1539

1540 FailureOrscf::SCFTilingResult tilingResult =

1542

1543 if (failed(tilingResult))

1544 return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");

1545 tiledAndFusedOps.insert_range(tilingResult->tiledOps);

1546

1548 for (auto [origVal, replacement] :

1549 llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {

1550 replacements[origVal] = replacement;

1551 }

1552

1553

1554 auto &loops = tilingResult->loops;

1555 if (loops.empty()) {

1557 replacements};

1558 }

1559

1560

1561

1562

1564 auto resetListener =

1565 llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });

1566 ReplacementListener replaceListener(replacements, previousListener);

1568

1569

1570

1571

1572

1573

1574

1575

1576 struct WorklistItem {

1577 tensor::ExtractSliceOp candidateSlice;

1579 };

1580

1581 SliceTrackingListener sliceTracker =

1582 SliceTrackingListener(options.cleanupPatterns);

1583

1584 if (failed(

1585 sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {

1586 return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");

1587 }

1589 while (!sliceTracker.worklist.empty()) {

1590 auto candidateSlice = sliceTracker.worklist.front();

1591 sliceTracker.worklist.pop_front();

1592

1593 auto [fusableProducer, destinationInitArg] =

1595 loops);

1596 if (!fusableProducer)

1597 continue;

1598

1599 std::optionalSCFTileAndFuseOptions::ControlFnResult controlFnResult =

1600 options.fusionControlFn(candidateSlice, fusableProducer,

1601 destinationInitArg.has_value());

1602 if (!controlFnResult)

1603 continue;

1604

1605 WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};

1606

1607

1608

1609

1610 std::optionalscf::SCFFuseProducerOfSliceResult fusedResult =

1612 loops);

1613 if (!fusedResult)

1614 continue;

1615

1617

1618 if (worklistItem.controlFnResult.yieldProducerReplacement) {

1619

1620

1621

1622

1623 Operation *fusableProducerOp = fusedResult->origProducer.getOwner();

1624 FailureOr<SmallVector<Operation *>> newSlices =

1626 worklistItem.candidateSlice,

1627 fusedResult.value(), loops);

1628 if (failed(newSlices)) {

1630 fusableProducerOp, "failed to replacement value for this "

1631 "operation from within the tiled loop");

1632 }

1633 worklistCandidates.append(newSlices.value());

1634 for (auto [index, result] :

1636 replacements[result] = loops.front()->getResult(

1637 loops.front()->getNumResults() -

1639 }

1640 }

1641 if (Operation *tiledAndFusedOp =

1642 fusedResult->tiledAndFusedProducer.getDefiningOp()) {

1643 fusedProducers.insert(fusedResult->origProducer.getDefiningOp());

1644 tiledAndFusedOps.insert(tiledAndFusedOp);

1645 }

1646

1647 if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {

1648 return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");

1649 }

1650 }

1651

1653 replacements};

1654 }

1655

1656

1657

1658

1659

1660

1661

1662 static LogicalResult

1664 Value result = candidateSliceOp.getResult();

1666 if (!llvm::hasSingleElement(uses)) {

1667 LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");

1668 return failure();

1669 }

1670 OpOperand &operandUse = (*uses.begin());

1672 if (!isascf::YieldOp(userOp)) {

1673 LLVM_DEBUG(llvm::dbgs()

1674 << "Expected scf.yield to be the only user, but got -> "

1675 << (*userOp));

1676 return failure();

1677 }

1679 LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "

1680 "be in the same block\n");

1681 return failure();

1682 }

1683 return success();

1684 }

1685

1686

1687

1689 if (!isa(loopOp))

1690 return failure();

1691 Operation *firstUserOfLoop = nullptr;

1693

1694

1695

1696

1697

1698

1699

1700

1701

1702

1703

1704

1705

1706

1707

1708 if (isatensor::ParallelInsertSliceOp(userOp))

1710

1711 if (loopOp->getBlock() != userOp->getBlock())

1712 return failure();

1713

1714 if (!firstUserOfLoop || userOp->isBeforeInBlock(firstUserOfLoop))

1715 firstUserOfLoop = userOp;

1716 }

1717 return firstUserOfLoop;

1718 }

1719

1720

1721

1722

1723

1724

1725

1726

1727

1728

1729

1730

1731

1732

1733

1734

1735

1736

1737

1738

1739

1740

1741

1742

1743

1744

1745

1746

1747

1748

1749

1750

1751

1752

1753

1754

1755

1756

1757

1758 static FailureOr<llvm::SetVector<Operation *>>

1760 bool reorderOperations) {

1761 FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);

1762 if (failed(firstUserOfLoop))

1763 return failure();

1764

1767 options.inclusive = true;

1768 options.omitBlockArguments = true;

1769 bool includeLoopOp = false;

1771 if (op == loopOp) {

1772 includeLoopOp = true;

1773 return false;

1774 }

1775

1776

1778 };

1780 for (auto operand : consumerOp->getOperands()) {

1782 assert(result.succeeded() && "expected a backward slice");

1783 (void)result;

1784 }

1785

1786 if (!slice.empty()) {

1787

1788

1789

1790

1791

1792

1793

1794

1795

1796 if (includeLoopOp || !reorderOperations)

1797 return failure();

1798 }

1799

1800 return slice;

1801 }

1802

1803

1804

1805

1808 unsigned resultNumber) {

1809 if (!isa(loopOp))

1810 return failure();

1814 Operation *consumerOp = opOperand.getOwner();

1815

1816 if (!isa(consumerOp) ||

1817 !isa(consumerOp)) {

1818

1819

1820

1821 continue;

1822 }

1823

1824 if (loopBlock != consumerOp->getBlock())

1825 continue;

1826

1827

1829 continue;

1830

1831 FailureOr<llvm::SetVector<Operation *>> slice =

1833 if (failed(slice))

1834 continue;

1835

1836

1837 if (!slice->empty()) {

1839 FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);

1840 assert(succeeded(firstUserOfLoop) && "First user of loop is not found");

1841 for (auto op : *slice) {

1842 rewriter.moveOpBefore(op, *firstUserOfLoop);

1843 }

1844 }

1845 return &opOperand;

1846 }

1847 return failure();

1848 }

1849

1850

1851

1852

1853

1854

1855

1856

1857

1858

1859

1860

1861

1862

1863 static bool

1865 assert(!loops.empty() && "unexpected empty loop nest");

1866 if (loops.size() == 1) {

1867 return isa_and_nonnullscf::ForOp(loops.front().getOperation());

1868 }

1869 for (auto [outerLoop, innerLoop] :

1870 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {

1871 auto outerFor = dyn_cast_or_nullscf::ForOp(outerLoop.getOperation());

1872 auto innerFor = dyn_cast_or_nullscf::ForOp(innerLoop.getOperation());

1873 if (!outerFor || !innerFor) {

1874 return false;

1875 }

1876 auto outerBBArgs = outerFor.getRegionIterArgs();

1877 auto innerIterArgs = innerFor.getInitArgs();

1878 if (outerBBArgs.size() != innerIterArgs.size()) {

1879 return false;

1880 }

1881

1882 for (auto [outerBBArg, innerIterArg] :

1883 llvm::zip_equal(outerBBArgs, innerIterArgs)) {

1884 if (!llvm::hasSingleElement(outerBBArg.getUses()) ||

1885 innerIterArg != outerBBArg) {

1886 return false;

1887 }

1888 }

1889

1891 castscf::YieldOp(outerFor.getBody()->getTerminator())->getOperands();

1892 ValueRange innerResults = innerFor.getResults();

1893 if (outerYields.size() != innerResults.size()) {

1894 return false;

1895 }

1896 for (auto [outerYield, innerResult] :

1897 llvm::zip_equal(outerYields, innerResults)) {

1898 if (!llvm::hasSingleElement(innerResult.getUses()) ||

1899 outerYield != innerResult) {

1900 return false;

1901 }

1902 }

1903 }

1904 return true;

1905 }

1906

1907

1908

1909

1910

1911

1912

1913 static FailureOr<OpOperand *>

1915 tensor::InsertSliceOp candidateSliceOp,

1917 assert(!loops.empty() && "unexpected loops to be empty");

1918

1920 if (containingOp != loops.back()) {

1922 candidateSliceOp,

1923 "expected slice to be within body of inner-most loop");

1924 }

1925

1926

1929 candidateSliceOp, "expected passed loops to be perfectly nested.");

1930 }

1931

1933 return failure();

1934 Value sliceResult = candidateSliceOp.getResult();

1935

1936

1937 OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());

1939

1940 scf::ForOp topLevelForOp = castscf::ForOp(loops.front().getOperation());

1941

1943 }

1944

1945

1946

1947 static FailureOr<OpOperand *>

1949 tensor::ParallelInsertSliceOp candidateSliceOp,

1951 assert(!loops.empty() && "unexpected loops to be empty");

1952

1953 if (loops.size() != 1) {

1955 candidateSliceOp, "expected single surrounding scf.forall");

1956 }

1957 auto forallOp = dyn_castscf::ForallOp(loops.front().getOperation());

1958 if (!forallOp) {

1960 candidateSliceOp, "expected single surrounding scf.forall");

1961 }

1962

1963

1964 Value sliceDest = candidateSliceOp.getDest();

1965 auto iterArg = dyn_cast(sliceDest);

1966 if (!iterArg)

1967 return failure();

1968 if (iterArg.getOwner()->getParentOp() != forallOp)

1969 return failure();

1970

1971 unsigned resultNumber =

1972 forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))

1973 .getResultNumber();

1974

1976 }

1977

1978

1979

1980 static FailureOr<OpOperand *>

1983 assert(!loops.empty() && "unexpected empty loops");

1984 if (auto insertSlice = dyn_casttensor::InsertSliceOp(sliceOp)) {

1986 } else if (auto parallelInsertSlice =

1987 dyn_casttensor::ParallelInsertSliceOp(sliceOp)) {

1989 } else {

1990 return failure();

1991 }

1992 }

1993

1994

1995

1996 FailureOrscf::SCFFuseConsumerOfSliceResult

2000

2001

2002 if (loops.empty()) {

2004 "cannot call tile and fuse consumer with an empty loop nest");

2005 }

2006 if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(

2007 candidateSliceOp))

2008 return failure();

2009

2010

2011

2012 FailureOr<OpOperand *> maybeConsumerOpOperand =

2014 if (failed(maybeConsumerOpOperand)) {

2016 "could not fetch consumer to fuse");

2017 }

2018 OpOperand *consumerOpOperand = *maybeConsumerOpOperand;

2020 unsigned operandNumber = consumerOpOperand->getOperandNumber();

2021 unsigned resultNumber = 0;

2022 if (auto producerResult = dyn_cast(consumerOpOperand->get())) {

2023 resultNumber = producerResult.getResultNumber();

2024 } else {

2026 consumerOp, "consumer op's operand doesn't seem to be an OpResult");

2027 }

2028

2029 LoopLikeOpInterface outerMostLoop = loops.front();

2030 LoopLikeOpInterface innerMostLoop = loops.back();

2031

2032

2035 outerMostLoop, "the first user of loop should not dominate any define "

2036 "of consumer operand(s)");

2037 }

2038

2040

2041

2042 auto dstOp = dyn_cast(consumerOp);

2043 if (!dstOp)

2045 "consumer op is not DPS operation");

2047 llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });

2048 if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {

2050 consumerOp,

2051 "consumer op taking the result of scf.for as init is not supported");

2052 }

2054

2055 Location loc = outerMostLoop->getLoc();

2056

2057

2058

2059 FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(outerMostLoop);

2060 if (failed(firstUserOfLoop)) {

2062 outerMostLoop, "could not find the first user of outer most loop");

2063 }

2064 rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);

2065

2066

2067

2068

2069

2070 tensor::InsertSliceOp clonedInsertSliceOp;

2071 if (auto sliceOp =

2072 dyn_casttensor::ParallelInsertSliceOp(candidateSliceOp)) {

2073 auto newForallOp = castscf::ForallOp(innerMostLoop.getOperation());

2075 clonedInsertSliceOp = rewriter.createtensor::InsertSliceOp(

2076 loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),

2077 sliceOp.getMixedSizes(), sliceOp.getMixedStrides());

2078 } else {

2080 clonedInsertSliceOp =

2081 casttensor::InsertSliceOp(rewriter.clone(*candidateSliceOp));

2082 }

2083

2084

2085 auto clonedConsumerOp = cast(rewriter.clone(*consumerOp));

2086

2087

2088

2089 OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);

2091 operandToReplace.set(clonedInsertSliceOp.getResult());

2092 });

2093

2094

2095

2096 auto ossSliceOp =

2097 cast(clonedInsertSliceOp.getOperation());

2098 FailureOr tileAndFuseResult =

2100 rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));

2101 if (failed(tileAndFuseResult)) {

2102 return failure();

2103 }

2104 auto tiledConsumerOp = cast(tileAndFuseResult->tiledOps[0]);

2105 rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),

2106 clonedInsertSliceOp.getSource());

2107

2108

2115

2117

2121

2122

2125 candidateSliceOp, "containingOp's result yield with stride");

2126 }

2127

2128

2129

2130

2131

2132

2133

2135 if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(

2136 rewriter, operandNumber, offsets, sizes, iterDomainOffsets,

2137 iterDomainSizes))) {

2139 clonedConsumerOp,

2140 "can't get iter domain position from input position");

2141 }

2142

2143

2144

2145

2146 unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();

2148 totalNumResultsOfConsumer);

2150 totalNumResultsOfConsumer);

2151 for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {

2152 if (failed(tiledConsumerOp.getResultTilePosition(

2153 rewriter, idx, iterDomainOffsets, iterDomainSizes,

2154 resultOffsets[idx], resultSizes[idx]))) {

2156 tiledConsumerOp,

2157 "can't get result domain position from iter domain position");

2158 }

2159 }

2160

2161

2162

2163 if (auto tiledDestStyleOp = dyn_cast(

2164 tiledConsumerOp.getOperation())) {

2166 for (const auto &&[index, newRegionArg] :

2168 auto destSlice = rewriter.createtensor::ExtractSliceOp(

2169 loc, newRegionArg, resultOffsets[index], resultSizes[index],

2172

2173

2174 auto dstNumber = index;

2176 tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);

2177 });

2178 }

2179 }

2180

2181

2182

2185 for (const auto &&[index, result] :

2187 tiledResult.push_back(result);

2188 tiledOffset.emplace_back(resultOffsets[index]);

2189 tiledSizes.emplace_back(resultSizes[index]);

2190 }

2191 return success();

2192 };

2193

2195 newYieldValuesFn))) {

2197 "unable to add new inits to nest loop");

2198 }

2199

2200

2201

2202

2203 for (auto &&[oldResult, newResult] :

2205 loops.front()->getResults().take_back(newInits.size()))) {

2207 }

2208

2209

2210 rewriter.eraseOp(clonedConsumerOp);

2211

2213 consumerOpOperand,

2214 &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),

2215 tileAndFuseResult->tiledOps};

2216 }

2217

2218

2219

2220

2221

2222 FailureOr<SmallVectorscf::ForOp>

2224 TilingInterface op) {

2225

2226 if (op->getNumResults() > 0) {

2228 op, "unable to lower to loops operations with return values");

2229 }

2230

2235 for (auto loopRange : domain) {

2236 Value offsetVal =

2240 Value strideVal =

2242 auto loop = rewriter.createscf::ForOp(op.getLoc(), offsetVal, sizeVal,

2244 loops.push_back(loop);

2245 ivs.push_back(loop.getInductionVar());

2247 }

2248 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {

2249 return failure();

2250 }

2251 return loops;

2252 }

static llvm::ManagedStatic< PassManagerOptions > options

static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)

Helper method to adjust the interchange vector to match the iteration domain.

static LogicalResult verifyTileSizeOptions(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options)

Verify the tile size options are set in a consistent manner.

static LogicalResult checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp)

A utility function that checks whether the only use of the result of a tensor.insert_slice op is in a...

static bool isPerfectlyNestedForLoops(MutableArrayRef< LoopLikeOpInterface > loops)

Check that the loop is perfectly nested.

std::function< LogicalResult(RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, SmallVector< Value > &tiledValues, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> YieldTiledValuesFn

A function that allows returning additional yielded values during yieldTiledValuesAndReplace.

static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, OpFoldResult numThreads, OpFoldResult iterationSize)

Returns true if the maximum tile offset, tileSize * (numThreads - 1), is less than iterationSize.

static LogicalResult getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize, const scf::SCFTilingOptions &options)

static FailureOr< llvm::SetVector< Operation * > > checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, bool reorderOperations)

This utility currently checks whether the first userOp of loop is NOT before the last defineOp of con...

static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads)

Compute the OpFoldResults that represents the multi-dimensional offsets and sizes of the tile of the ...

static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, OpFoldResult offset, OpFoldResult tileSize)

Returns the bounded tile size given the current offset, loopRange and tileSize, i....

FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)

Append the specified additional newInitOperands to the loop's existing init operands (or simi...

static LogicalResult generateLoopNestUsingForOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ValueRange destinationTensors, YieldTiledValuesFn yieldTiledValuesFn, SmallVector< LoopLikeOpInterface > &loops)

Generate the tile-loop nest using scf.for operation.

static FailureOr< OpOperand * > getConsumerFromLoopUses(RewriterBase &rewriter, Operation *loopOp, unsigned resultNumber)

Fetches the OpOperand of the first valid user (and use) of the value val which implements TilingInter...

static void checkSafeToTileToForall(TilingInterface op, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads)

Checks if any of the tiled loops are not parallel.

static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, ArrayRef< Range > iterationDomain, const scf::SCFTilingOptions &options)

Method to instantiate the tile sizes and/or number of threads specified by the user.

static std::tuple< OpResult, std::optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< LoopLikeOpInterface > loops)

Return the untiled producer whose slice is used in a tiled consumer.

static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)

Generate the tile-loop nest using the loop construct specified in options.

static bool tileDividesIterationDomain(Range loopRange)

Check if the stride evenly divides the trip count (size - offset).

static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes)

Function to return the bounds of the loops to be generated.

static Operation * cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, Operation *op, ValueRange newDestArgs)

Clones the operation and updates the destination if the operation implements the DestinationStyleOpIn...

static FailureOr< Operation * > getFirstUserOfLoop(Operation *loopOp)

A utility to get the first user of the given loopOp.

static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, const scf::SCFTilingOptions &options)

static LogicalResult addInitOperandsToLoopNest(RewriterBase &rewriter, MutableArrayRef< LoopLikeOpInterface > loops, ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn)

Method to add new init values to a loop nest.

static FailureOr< MergeResult > mergeTilingResults(RewriterBase &rewriter, TilingInterface op, ValueRange partialResults, const scf::SCFTilingOptions &options)

static FailureOr< OpOperand * > getUntiledConsumerFromSlice(RewriterBase &rewriter, tensor::InsertSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)

Fetch the untiled consumer of the outermost scf.for's result which is yielded by a tensor....

static FailureOr< SmallVector< Value > > createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op, ArrayRef< OpFoldResult > tileSizes, const scf::SCFTilingOptions &options)

static LogicalResult generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ArrayRef< Attribute > mappingVector, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)

Generate the tile-loop nest using scf.forall operation.

Base type for affine expression.

AffineExpr ceilDiv(uint64_t v) const

A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.

static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)

Returns an AffineMap with 'numDims' identity result dim exprs.

static AffineMap get(MLIRContext *context)

Returns a zero result affine map with no dimensions or symbols: () -> ().

Block represents an ordered list of Operations.

Operation * getTerminator()

Get the terminator operation of this block.

BlockArgListType getArguments()

IntegerAttr getIndexAttr(int64_t value)

MLIRContext * getContext() const

ArrayAttr getArrayAttr(ArrayRef< Attribute > value)

A class for computing basic dominance information.

bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const

Return true if operation A properly dominates operation B, i.e.

This class allows control over how the GreedyPatternRewriteDriver works.

IRValueT get() const

Return the current value being used by this operand.

void set(IRValueT newValue)

Set the current value being used by this operand.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

This class represents a saved insertion point.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

Block::iterator getInsertionPoint() const

Returns the current insertion point of the builder.

Operation * clone(Operation &op, IRMapping &mapper)

Creates a deep copy of the specified operation, remapping any operands that use values outside of the...

void setListener(Listener *newListener)

Sets the listener of this builder to the one provided.

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

void setInsertionPointToEnd(Block *block)

Sets the insertion point to the end of the specified block.

Listener * getListener() const

Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

void setInsertionPointAfter(Operation *op)

Sets the insertion point to the node after the specified operation, which will cause subsequent inser...

This class represents a single result from folding an operation.

This class represents an operand of an operation.

unsigned getOperandNumber()

Return which operand this is in the OpOperand list of the Operation.

This is a value defined by a result of an operation.

Operation * getOwner() const

Returns the operation that owns this result.

unsigned getResultNumber() const

Returns the number of this result.

Operation is the basic unit of execution within MLIR.

bool use_empty()

Returns true if this operation has no uses.

bool isBeforeInBlock(Operation *other)

Given an operation 'other' that is within the same parent block, return whether the current operation...

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

Location getLoc()

The source location the operation was defined or derived from.

Operation * getParentOp()

Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...

Block * getBlock()

Returns the operation block that contains this operation.

OpTy getParentOfType()

Return the closest surrounding parent operation that is of type 'OpTy'.

operand_range getOperands()

Returns an iterator on the underlying Value's.

user_range getUsers()

Returns a range of all users.

result_range getResults()

InFlightDiagnostic emitOpError(const Twine &message={})

Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.

unsigned getNumResults()

Return the number of results held by this operation.

This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure,...

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements...

void replaceAllUsesWith(Value from, Value to)

Find uses of from and replace them with to.

void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)

Inline the operations of block 'source' into the end of block 'dest'.

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

void moveOpBefore(Operation *op, Operation *existingOp)

Unlink this operation from its current block and insert it right before existingOp which may be in th...

void modifyOpInPlace(Operation *root, CallableT &&callable)

This method is a utility wrapper around an in-place modification of an operation.

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

use_range getUses() const

Returns a range of all uses, which is useful for iterating over all uses.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

Operation * getOwner() const

Return the owner of this operand.

OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...

OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...

OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

FailureOr< scf::SCFFuseConsumerOfSliceResult > tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)

Implementation of fusing consumer of a single slice by computing the slice of the consumer in-place f...

FailureOr< scf::SCFTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSizes)

Method to tile a reduction and generate a parallel op within a serial loop.

FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)

Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.

FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)

Method to lower an op that implements the TilingInterface to loops/scalars.

FailureOr< SmallVector< Operation * > > yieldReplacementForFusedProducer(RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef< LoopLikeOpInterface > loops, ArrayRef< unsigned > yieldResultNumber=ArrayRef< unsigned >{})

Reconstruct the fused producer from within the tiled-and-fused code.

FailureOr< SCFTileAndFuseResult > tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options)

Method to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers.

std::optional< SCFFuseProducerOfSliceResult > tileAndFuseProducerOfSlice(RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)

Implementation of fusing producer of a single slice by computing the slice of the producer in-place.

FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)

Method to swap a tensor.extract_slice with its producer when the producer implements the TilingInter...

FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)

This is a helper function for DestinationStyleOpInterface.

LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)

This is a helper function for DestinationStyleOpInterface.

FailureOr< TilingResult > replaceInsertSliceWithTiledConsumer(OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp, OpOperand &consumerOp)

Method to swap a tensor.insert_slice with its consumer when the consumer implements the TilingInterf...

Include the generated interface declarations.

LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})

Fills backwardSlice with the computed backward slice (i.e.

std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)

If ofr is a constant integer or an IntegerAttr, return the integer.

void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)

Bind a list of AffineExpr references to DimExpr at positions: [0 .

AffineMap inversePermutation(AffineMap map)

Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...

LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)

Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...

const FrozenRewritePatternSet & patterns

bool isZeroInteger(OpFoldResult v)

Return true if v is an IntegerAttr with value 0.

void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)

Bind a list of AffineExpr references to SymbolExpr at positions: [0 .

Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)

Converts an OpFoldResult to a Value.

Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)

OpFoldResult getAsOpFoldResult(Value val)

Given a value, try to extract a constant Attribute.

@ ExistingAndNewOps

Only pre-existing and newly created ops are processed.

void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)

Apply the permutation defined by permutation to inVec.

bool isPermutationVector(ArrayRef< int64_t > interchange)

Method to check if an interchange vector is a permutation.

bool isOneInteger(OpFoldResult v)

Return true if v is an IntegerAttr with value 1.

SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)

Sorts all operations in toSort topologically while also considering region semantics.

SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)

Helper method to compute the inverse of a permutation.

Container for the result of merge operation of tiling.

This class represents a listener that may be used to hook into various actions within an OpBuilder.

Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...

A listener that forwards all notifications to another listener.

Container for result values of tiling.

Fuse the consumer of the source of candidateSliceOp by computing the required slice of the consumer i...

Fuse the producer of the source of candidateSliceOp by computing the required slice of the producer i...

SmallVector< Operation * > tiledOps

Control function to check if a slice needs to be fused or not. The control function receives 1) the s...

Options used to control tile + fuse.

Transformation information returned after tile and fuse.

Options to use to control tiling.

SCFTileSizeComputationFunction tileSizeComputationFunction

Computation function that returns the tile sizes to use for each loop.

SCFTilingOptions & setNumThreads(ArrayRef< OpFoldResult > numThreads)

Convenience function to set the numThreadsComputationFunction to a function that computes num threads...

SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)

Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...

ReductionTilingStrategy

Specify how reduction dimensions should be tiled.

@ PartialReductionOuterReduction

Transformation information returned after tiling.

SmallVector< Operation * > tiledOps

Tiled operations that are generated during tiling.