MLIR: lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp Source File


44 #include "llvm/ADT/STLExtras.h"

45 #include "llvm/ADT/ScopeExit.h"

46 #include "llvm/ADT/TypeSwitch.h"

47 #include "llvm/Support/Debug.h"

48 #include <type_traits>

49

50 using namespace mlir;

53

54 #define DEBUG_TYPE "linalg-transforms"

55 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

56 #define DBGSNL() (llvm::dbgs() << "\n")

57 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")

58

59

60

61

62

63

64 template <typename PatternTy, typename... Args>

65 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {

66

67 using OpTy = typename llvm::function_traits<

68 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;

69 auto op = dyn_cast<OpTy>(operation);

70 if (!op)

71 return failure();

72

73

74 PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);

75

76

77 class TrivialPatternRewriter : public PatternRewriter {

78 public:

79 explicit TrivialPatternRewriter(MLIRContext *context)

80 : PatternRewriter(context) {}

81 };

82 TrivialPatternRewriter rewriter(operation->getContext());

83 rewriter.setInsertionPoint(operation);

84 auto result = pattern.returningMatchAndRewrite(op, rewriter);

85 if (failed(result))

86 return failure();

87 return cast<LinalgOp>(result->getOperation());

88 }
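// Illustrative sketch of how a helper like tryApply above is typically
// invoked: the pattern type is the template argument and extra constructor
// arguments are forwarded. The pattern name below is a placeholder, not a
// symbol defined in this file.
//
//   FailureOr<LinalgOp> maybeOp =
//       tryApply<SomeDownscalingPattern>(payloadOp /*, extra ctor args */);
//   if (failed(maybeOp))
//     return failure();  // pattern did not match or the rewrite failed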

89

90

91

92

97 if (auto attr = dyn_cast<Attribute>(ofr)) {

98 if (!isa<IntegerAttr>(attr))

99 return transformOp.emitDefiniteFailure() << "expected IntegerAttr";

100 result.push_back(ofr);

101 continue;

102 }

103

104 Value transformValue = cast<Value>(ofr);

105 if (isa<TransformParamTypeInterface>(transformValue.getType())) {

107 if (params.size() != 1)

108 return transformOp.emitDefiniteFailure()

109 << "requires exactly one parameter associated";

110 result.push_back(params[0]);

111 continue;

112 }

113

114 auto payloadOps = state.getPayloadOps(transformValue);

115 if (!llvm::hasSingleElement(payloadOps)) {

116 DiagnosedSilenceableFailure diag =

117 transformOp.emitSilenceableError()

118 << "handle must be mapped to exactly one payload op";

119 diag.attachNote(transformValue.getLoc())

120 << "mapped to " << llvm::range_size(payloadOps) << " payload ops";

122 }

123

124 Operation *op = *payloadOps.begin();

127 transformOp.emitSilenceableError()

128 << "payload op must have exactly 1 index result";

132 }

133 result.push_back(op->getResult(0));

134 }

135

137 }

138

139

140

141

142

143

144

148 if (isa<TransformParamTypeInterface>(packedHandle.getType())) {

149 ArrayRef<Attribute> params = state.getParams(packedHandle);

150 for (auto param : params) {

151 if (!isa<IntegerAttr>(param))

152 return transformOp.emitDefiniteFailure()

153 << "expected the parameter to be associated with an integer "

154 "attribute";

155 result.push_back(param);

156 }

158 }

159

160 for (Operation *op : state.getPayloadOps(packedHandle)) {

161 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {

163 transformOp.emitSilenceableError()

164 << "payload op must have exactly 1 index result";

165 diag.attachNote(op->getLoc())

166 << "has " << op->getNumResults() << " results";

168 }

169 result.push_back(op->getResult(0));

170 }

171

173 }

174

175

176

177

178

180 TransformState &state, TransformOpInterface &transformOp,

182 for (OpFoldResult paramOrHandle : mixedResults) {

183 if (auto attr = dyn_cast<Attribute>(paramOrHandle)) {

184 reified.push_back(cast<IntegerAttr>(attr).getInt());

185 continue;

186 } else if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) {

187 ArrayRef<Attribute> params = state.getParams(cast<Value>(paramOrHandle));

188 if (params.size() != 1)

189 return transformOp.emitSilenceableError() << "expected a single param";

190 reified.push_back(

191 cast<IntegerAttr>(params.front()).getValue().getSExtValue());

192 continue;

193 }

194

195 Value handle = cast<Value>(paramOrHandle);

196 if (!isa<TransformHandleTypeInterface>(handle.getType()))

197 return transformOp.emitSilenceableError() << "unexpected value handle";

198 auto payload = state.getPayloadOps(handle);

199 if (!llvm::hasSingleElement(payload))

200 return transformOp.emitSilenceableError()

201 << "requires param or handle that is mapped to 1 payload op";

202

203 Operation *paramOrHandlePayloadOp = *payload.begin();

204 if (paramOrHandlePayloadOp->getNumResults() != 1 ||

206 return transformOp.emitSilenceableError()

207 << "requires param or handle to be result of op with 1 index "

208 "result";

209 }

210

211 IntegerAttr attr;

213 return transformOp.emitSilenceableError()

214 << "requires param or handle to be the result of a constant like "

215 "op";

216

217 reified.push_back(attr.getInt());

218 }

220 }
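// Sketch, in comment form, of the three inputs the reification helper above
// accepts (the transform-dialect type names here summarize my reading of the
// code rather than quoting it):
//   - an IntegerAttr             -> its value is used directly;
//   - a transform param handle   -> must carry exactly one IntegerAttr;
//   - a value handle             -> must map to one payload op with a single
//                                   index result produced by a constant-like op.
// Anything else is reported as a silenceable failure.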

221

222

223

224

225

226 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(

229 }

230

231 void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(

234 }

235

236 void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(

239 }

240

241 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(

245 }

246

247 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(

250 options.rankReductionStrategy =

253 }

254

255 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(

258 }

259

260 void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(

263 }

264

265 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(

268 }

269

270 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(

273 }

274

275 void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(

278 }
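// Minimal sketch of the shape these populatePatterns hooks take; the op and
// helper names below are placeholders, not symbols from this file:
//
//   void transform::ApplyFooPatternsOp::populatePatterns(
//       RewritePatternSet &patterns) {
//     linalg::populateFooPatterns(patterns);
//   }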

279

280

281

282

283

284 void transform::BufferizeToAllocationOp::build(OpBuilder &b,

289 resultTypes.push_back(b.getType<transform::AnyValueType>());

290 resultTypes.push_back(b.getType<transform::AnyOpType>());

291 return build(b, result,

292 resultTypes,

293 target,

294 memorySpace);

295 }

296

297 void transform::BufferizeToAllocationOp::build(OpBuilder &b,

300 int64_t memorySpace) {

302 resultTypes.push_back(b.getType<transform::AnyValueType>());

303 resultTypes.push_back(b.getType<transform::AnyOpType>());

304 return build(b, result,

305 resultTypes,

306 target,

308 }

309

310 namespace {

312 public:

314

317 }

318

319 private:

320 void notifyOperationInserted(Operation *op,

322 ForwardingListener::notifyOperationInserted(op, previous);

323

324 if (previous.isSet())

325 return;

326 auto inserted = newOps.insert(op);

327 (void)inserted;

328 assert(inserted.second && "expected newly created op");

329 }

330

331 void notifyOperationErased(Operation *op) override {

332 ForwardingListener::notifyOperationErased(op);

333 op->walk([&](Operation *op) { newOps.erase(op); });

334 }

335

337 };

338 }
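// Usage sketch for a listener of this kind, mirroring the scope-exit idiom
// used a few lines below; variable names are illustrative only:
//
//   OpBuilder::Listener *previous = rewriter.getListener();
//   NewOpsListener listener(previous);
//   rewriter.setListener(&listener);
//   auto restore =
//       llvm::make_scope_exit([&] { rewriter.setListener(previous); });
//   // ... perform rewrites; listener.getNewOps() now holds the created ops.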

339

343

345 auto resetListener =

346 llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });

347 NewOpsListener newOpsListener(previousListener);

349

351 if (getMemcpyOp() == "bufferization.materialize_in_destination") {

354 } else if (getMemcpyOp() == "memref.copy") {

357 } else if (getMemcpyOp() == "linalg.copy") {

360 } else {

361 llvm_unreachable("invalid memcpy op");

362 }

363 if (getAllocOp() == "memref.alloc") {

366 } else if (getAllocOp() == "memref.alloca") {

369 } else {

370 llvm_unreachable("invalid alloc op");

371 }

372 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();

373 options.emitDealloc = getEmitDealloc();

374

375

377 getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();

379 for (Operation *op : state.getPayloadOps(getTarget())) {

382 if (!buffer) {

384 << "failed to bufferize operation";

385 diag.attachNote(op->getLoc()) << "target payload op";

387 }

388 allocatedBuffers.push_back(buffer);

389 }

390

391

392 results.setValues(cast(getAllocatedBuffer()), allocatedBuffers);

393 results.set(cast(getNewOps()), newOpsListener.getNewOps());

395 }

396

397 void transform::BufferizeToAllocationOp::getEffects(

399 if (getBufferizeDestinationOnly()) {

400

401

403 } else {

405 }

406 producesHandle(getOperation()->getOpResults(), effects);

408 }

409

411 if (getMemcpyOp() != "bufferization.materialize_in_destination" &&

412 getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")

413 return emitOpError() << "unsupported memcpy op";

414 if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")

415 return emitOpError() << "unsupported alloc op";

416 return success();

417 }

418

419

420

421

422

425 LinalgOp target,

428 #define DOWNSCALE(trans) \

429 { \

430 FailureOr<LinalgOp> res = tryApply<trans>(target); \

431 if (succeeded(res)) { \

432 results.push_back(*res); \

433 return DiagnosedSilenceableFailure::success(); \

434 } \

435 }

436

437 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>

438 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))

439

445 DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)

447 DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)

451 #undef DOWNSCALE_NORMAL

452 #undef DOWNSCALE_CALL

453 #undef DOWNSCALE

454 return emitDefaultSilenceableFailure(target);

455 }
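// For readability, one expansion of DOWNSCALE_NORMAL above, e.g.
// DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp), is
// roughly:
//
//   {
//     FailureOr<LinalgOp> res =
//         tryApply<DownscaleSizeOneWindowed2DConvolution<
//             PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>>(target);
//     if (succeeded(res)) {
//       results.push_back(*res);
//       return DiagnosedSilenceableFailure::success();
//     }
//   }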

456

457

458

459

460

461

462

463

468 auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);

469 if (!decomposableOp) {

471 "payload is not a decomposable op"));

472 return emitDefaultSilenceableFailure(target);

473 }

474

475 FailureOr<SmallVector> maybeNewResults =

476 decomposableOp.decomposeOperation(rewriter);

477 if (failed(maybeNewResults))

478 return emitDefaultSilenceableFailure(target);

479

480 rewriter.replaceOp(decomposableOp, *maybeNewResults);

481 for (Value val : *maybeNewResults) {

482 Operation *definition = val.getDefiningOp();

483 if (definition)

485 }

487 }

488

489

490

491

492

493 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(

497 }

498

500 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(

504 options.allowReturnAllocsFromLoops = true;

505

506 for (Operation *target : state.getPayloadOps(getTarget())) {

508 if (failed(analyzeOp(target, state)))

510 << "failed to analyze op";

512 rewriter, target, state)))

514 << "failed to eliminate LinalgOp anchored tensor.empty ops";

515 }

517 }

518

519

520

521

522

523

524

525 template

529 function_ref<FailureOrscf::SCFTileAndFuseResult(TilingInterface)>

530 applyFn) {

533

534 for (Operation *target : payloadOps) {

535 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);

536 if (!tilingInterfaceOp)

537 return transformOp->emitError("only TilingInterface ops are supported");

538

540 FailureOrscf::SCFTileAndFuseResult tiledResults =

541 applyFn(tilingInterfaceOp);

542 if (failed(tiledResults))

543 return failure();

544

545

547 llvm::append_range(opsToReplace, tiledResults->fusedProducers);

548 for (Operation *toReplace : opsToReplace) {

549 for (OpResult res : toReplace->getResults())

550 if (auto replacement = tiledResults->replacements.lookup(res))

552 if (toReplace->use_empty()) {

553 rewriter.eraseOp(toReplace);

554 }

555 }

556

557

558 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());

559 assert(tiledResults->loops.size() == numLoops &&

560 "Mismatched number of loops, tile and fuse transform should have "

561 "failed");

562 for (unsigned int i = 0; i < numLoops; ++i)

563 loopOps[i].push_back(tiledResults->loops[i]);

564 }

565

566 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);

567 for (unsigned int i = 0; i < numLoops; ++i)

568 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);

569

570 return success();

571 }

572

578 extractFromIntegerArrayAttr<int64_t>(getTileSizes());

580 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());

581

586 tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);

588 tileAndFuseOptions.tilingOptions = tilingOptions;

589

590 if (getApplyCleanup()) {

593 tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);

597 }

598

600 rewriter, getOperation(), state.getPayloadOps(getTarget()),

601 tileSizes.size() - llvm::count(tileSizes, 0), transformResults,

602 [&](TilingInterface tilingInterfaceOp)

603 -> FailureOrscf::SCFTileAndFuseResult {

604 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,

605 tileAndFuseOptions);

606 });

609 }

610

613 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());

614 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));

615 if (!std::is_permutation(sequence.begin(), sequence.end(),

616 permutation.begin(), permutation.end())) {

617 return emitOpError() << "expects interchange to be a permutation, found "

618 << getTileInterchange();

619 }

620

622 extractFromIntegerArrayAttr<int64_t>(getTileSizes());

623 size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);

624 if (numExpectedLoops != getNumResults() - 1)

625 return emitOpError() << "expects " << numExpectedLoops << " loop results";

626

627 return success();

628 }
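// Small worked example for the check above: with tile_sizes = [4, 0, 8] the
// zero entry is not tiled, so numExpectedLoops = 3 - 1 = 2, and the op must
// declare exactly 2 loop results in addition to the tiled-op result.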

629

630

631

632

633

634 void transform::FuseIntoContainingOp::build(OpBuilder &builder,

636 Value producerOp,

637 Value containingOp) {

638 result.addOperands({producerOp, containingOp});

640 result.addTypes({resultType, resultType});

641 }

642

643

644

650

651

655 if (!containingOp->isAncestor(user) &&

656 (domInfo.dominates(containingOp, user))) {

657 dominatedUsers.insert(user);

658 }

659 }

660 if (dominatedUsers.empty())

661 return nullptr;

662

663

664 auto forallOp = cast<scf::ForallOp>(containingOp);

667

668

669 Location loc = forallOp.getLoc();

670 auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);

671 if (!genericOp)

672 return nullptr;

675 newOuts.push_back(outputs[resultNumber]);

676

677

678 auto newforallOp = rewriter.create<scf::ForallOp>(

679 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),

680 forallOp.getMixedStep(), newOuts, forallOp.getMapping());

681 rewriter.eraseBlock(newforallOp.getBody());

682 newforallOp.getRegion().takeBody(forallOp.getRegion());

683

684

685

686

687 newforallOp.getBody()->addArgument(newOuts.back().getType(),

688 newOuts.back().getLoc());

689 auto bbArgs = newforallOp.getBody()->getArguments();

694 });

695

696

697 scf::InParallelOp terminatorOp = newforallOp.getTerminator();

699 terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));

700 Operation *firstYieldOp = yieldingOps.front();

703 Value dst = newforallOp.getRegionIterArgs().back();

705 rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,

706 dst, offsets, sizes, strides);

707

708 for (auto result : llvm::enumerate(forallOp.getResults())) {

710 newforallOp->getResult(result.index()));

711 }

713 newforallOp->getResults().back(),

715 Operation *user = use.getOwner();

716 return dominatedUsers.contains(user);

717 });

718 return newforallOp;

719 }

720

721

722

723

724

725

727

728

730 destWorklist.push_back(dst);

731

732 while (!destWorklist.empty()) {

733 Value currentDst = destWorklist.pop_back_val();

734

735

736

737 if (src == currentDst)

738 return true;

739

740

741

742 auto bbArg = dyn_cast<BlockArgument>(currentDst);

743 if (!bbArg)

744 continue;

745

746 Block *parentBlock = bbArg.getOwner();

747 assert(parentBlock && "unlinked block argument");

748

750 assert(parentOp && "expected block argument with parent operation");

751

752

753 auto parentLoop = dyn_cast(parentOp);

754 if (!parentLoop)

755 continue;

756

757 for (auto innerIterArg : parentLoop.getRegionIterArgs()) {

758

759 OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);

760 Value loopBlockArgument =

762 destWorklist.push_back(loopBlockArgument);

763 }

764 }

765

766 return false;

767 }

768

769

770

771

772

773

774

775 static std::tuple<SmallVector<Operation *>, Operation *>

778 LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");

779 auto tileableProducer = dyn_cast(producerOp);

780 if (!tileableProducer) {

781 diag.attachNote(producerOp->getLoc())

782 << "producer is not a TileableInterface: " << *producerOp;

783 return {};

784 }

785

786

787

788

789 auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {

790 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);

791 return sliceOp && containingOp->isProperAncestor(sliceOp);

792 });

793

794

795 if (it == tileableProducer->getUsers().end()) {

796 diag.attachNote(tileableProducer->getLoc())

797 << "could not find fusion opportunity for: " << *tileableProducer;

798 return {};

799 }

800 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);

801

802

805

806

807

808

809

810

811

812

813 if (LoopLikeOpInterface containerLoop =

814 dyn_cast(sliceOpToTile->getParentOp())) {

817

818

819

821 cast(clone).getDpsInitsMutable()) {

822 Value producerOperand =

825 containerLoop.getRegionIterArgs()) {

826 OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);

827 Value consumerOperand =

828 containerLoop->getOperand(bbArg->getOperandNumber());

829

830 if (sameOrEquivalentIterArg(producerOperand, consumerOperand)) {

831 initOperandPtr.set(containerIterArg);

832 }

833 }

834 }

835 });

836

837 tileableProducer = dyn_cast(clone);

838 }

839

840

841 int64_t resultNumber =

842 cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();

843 LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");

844

847

848 FailureOr tileAndFuseResult =

849 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,

850 sizes);

851

852 if (failed(tileAndFuseResult)) {

853 diag.attachNote(tileableProducer->getLoc())

854 << "failed to tile producer op: " << *tileableProducer;

855 return {};

856 }

857

858 #ifndef NDEBUG

859 for (auto *tiledOp : tileAndFuseResult->tiledOps) {

860 LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n");

861 }

862 #endif

863

864

865 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(

866 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],

867 cast(sliceOpToTile->getResult(0).getType()).getShape());

868 if (failed(maybeRankReduced)) {

869 diag.attachNote(producerOp->getLoc())

870 << "shape types don't match (missing canonicalization?):\nTiledOp: "

871 << tileAndFuseResult->tiledValues[0]

872 << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';

873 return {};

874 }

875 rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);

876

877

879 rewriter, diag, producerOp, containingOp, *tileAndFuseResult,

880 resultNumber, offsets, sizes);

881

882

883 if (dyn_cast(containingOp))

884 rewriter.eraseOp(tileableProducer);

885

886 return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);

887 }

888

889

890

891

892

893

894

899 LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");

900

901 auto tileableProducer = dyn_cast(producerOp);

902 if (!tileableProducer) {

903 diag.attachNote(producerOp->getLoc())

904 << "producer is not a TileableInterface: " << *producerOp;

905 return {};

906 }

907

908

909 scf::ForallOp forallOp;

910 auto itProducerUses =

911 llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {

912 forallOp = dyn_cast<scf::ForallOp>(use.getOwner());

913 return forallOp;

914 });

915

916 if (!forallOp || forallOp != containingOp) {

917 diag.attachNote(tileableProducer->getLoc())

918 << "could not find a use by the containing op: " << *tileableProducer;

919 return {};

920 }

921

922

923

924

925

926 OpOperand *pUse = &(*itProducerUses);

927 BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);

928

929

930

931

932 auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {

933 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);

934 return sliceOp && containingOp->isProperAncestor(sliceOp);

935 });

936

937

938 if (itBBArgUsers == bbArg.getUsers().end()) {

939 diag.attachNote(containingOp->getLoc())

940 << "could not find fusion opportunity for bbArg: " << bbArg;

941 return {};

942 }

943 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);

944

945

948

949

950

951 int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();

952 LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");

953

954

957 rewriter, tileableProducer->getLoc(), tileableProducer,

958 destinationTensors))) {

959 diag.attachNote(tileableProducer->getLoc())

960 << "failed to get destination tensors for: " << *tileableProducer;

961 return {};

962 }

963

965 bvm.map(destinationTensors[resultNumber], bbArg);

966 auto tileableProducerClone =

967 cast(rewriter.clone(*tileableProducer, bvm));

968 auto scopeGuard =

969 llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });

970

971

972 FailureOr tileAndFuseResult =

973 tileableProducerClone.generateResultTileValue(

974 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),

975 sliceOpToTile.getMixedSizes());

976 if (failed(tileAndFuseResult)) {

977 diag.attachNote(tileableProducer->getLoc())

978 << "failed to tile producer op: " << *tileableProducer;

979 return {};

980 }

981

982

983 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(

984 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],

985 cast(sliceOpToTile->getResult(0).getType()).getShape());

986 assert(succeeded(maybeRankReduced) && "unexpected shape");

987 rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);

988

989

992 destinationTensors.front());

993 });

994

995 return tileAndFuseResult->tiledOps;

996 }

997

1001 LLVM_DEBUG(DBGS() << "Try to fuse an use by cloning\n");

1002

1003

1006 for (OpOperand &use : result.getUses()) {

1008 uses.push_back(&use);

1009 continue;

1010 }

1011

1012

1013 if (containingOp == use.getOwner()) {

1014 diag.attachNote(producerOp->getLoc())

1015 << "producer op use by containing op cannot be fused by cloning";

1016 return nullptr;

1017 }

1018 }

1019 }

1020

1021

1022 if (uses.empty()) {

1023 diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";

1024 return nullptr;

1025 }

1026

1027

1030

1031

1032 assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&

1033 "Parallel insert slice is not a valid clone destination");

1034 unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();

1035 LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");

1036

1039 fusedOp = rewriter.clone(*producerOp);

1041 use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });

1042

1043 return fusedOp;

1044 }

1045

1046 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {

1047

1048 return true;

1049 }

1050

1056 auto producerOps = state.getPayloadOps(getProducerOp());

1057 auto containingOps = state.getPayloadOps(getContainingOp());

1058 if (!llvm::hasSingleElement(containingOps)) {

1060 << "requires exactly one containing_op handle (got "

1061 << llvm::range_size(containingOps) << ")";

1062 }

1063 Operation *containingOp = *containingOps.begin();

1064

1065

1066 if (std::empty(producerOps)) {

1068 results.set(cast(getNewContainingOp()), {containingOp});

1070 }

1071

1072

1073

1075 auto getNextProducer = [&]() -> FailureOr<Operation *> {

1076 for (const auto &it : enumerate(remainingProducers)) {

1077 Operation *producerOp = it.value();

1078

1079 int64_t numUsesInContainingOp =

1081 return containingOp->isAncestor(op);

1082 });

1083

1084

1085

1086 if (numUsesInContainingOp > 0) {

1087 if (numUsesInContainingOp == 1)

1088 remainingProducers.erase(remainingProducers.begin() + it.index());

1089 return producerOp;

1090 }

1091 }

1092 return failure();

1093 };

1094

1095 while (!remainingProducers.empty()) {

1096 auto nextProducer = getNextProducer();

1097 if (failed(nextProducer)) {

1099 << "could not find next producer to fuse into container";

1100 diag.attachNote(containingOp->getLoc()) << "containing op";

1101 return diag;

1102 }

1103

1104 Operation *producerOp = *nextProducer;

1105

1106

1108 diag << "could not fuse " << *producerOp << " into " << *containingOp;

1109

1110

1111

1112

1113

1114

1115 auto [tiledOps, newContainingOp] =

1117 if (!tiledOps.empty()) {

1118 LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);

1119 fusedOps.append(tiledOps);

1120 if (newContainingOp) {

1121

1122

1123

1124

1125

1126

1127

1128 LogicalResult replacementStatus =

1130 newContainingOp);

1131 (void)replacementStatus;

1132 assert(succeeded(replacementStatus) &&

1133 "unable to update transform state mapping");

1134 rewriter.eraseOp(containingOp);

1135 containingOp = newContainingOp;

1136 }

1137 continue;

1138 }

1139

1142 rewriter, diag, producerOp, containingOp);

1143 if (!tiledContainingOpOperand.empty()) {

1144 LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"

1145 << *containingOp);

1146 fusedOps.append(tiledContainingOpOperand);

1147 continue;

1148 }

1149

1152 if (cloned) {

1153 LLVM_DEBUG(DBGS() << "\nFused an use by cloning\n" << *containingOp);

1154 fusedOps.push_back(cloned);

1155 continue;

1156 }

1158 }

1159

1160 results.set(cast(getFusedOp()), fusedOps);

1161 results.set(cast(getNewContainingOp()), {containingOp});

1163 }

1164

1165 void transform::FuseIntoContainingOp::getEffects(

1169 producesHandle(getOperation()->getOpResults(), effects);

1171 }

1172

1173

1174

1175

1176

1179 LinalgOp target,

1182

1183 if (isa<GenericOp>(target)) {

1186 }

1188 FailureOr generic = generalizeNamedOp(rewriter, target);

1189 if (succeeded(generic)) {

1190 results.push_back(generic->getOperation());

1192 }

1193 return emitDefaultSilenceableFailure(target);

1194 }

1195

1196

1197

1198

1199

1202 LinalgOp target,

1205

1206 if (!isa<GenericOp>(target)) {

1209 }

1211 FailureOr named =

1213 if (succeeded(named)) {

1214 results.push_back(named->getOperation());

1216 }

1217 return emitDefaultSilenceableFailure(target);

1218 }

1219

1220

1221

1222

1223

1226 GenericOp target,

1230

1231 if (interchangeVector.empty()) {

1234 }

1235

1236 unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();

1237 if (interchangeVector.size() != numLoops) {

1238 return emitSilenceableError()

1239 << getIteratorInterchangeAttrName() << " has length ("

1240 << interchangeVector.size()

1241 << ") different from the number of loops in the target operation ("

1242 << numLoops << ")";

1243 }

1246 if (failed(res))

1248 results.push_back(res->getOperation());

1250 }

1251

1254 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));

1255 if (!std::is_permutation(sequence.begin(), sequence.end(),

1256 permutation.begin(), permutation.end())) {

1257 return emitOpError()

1258 << "expects iterator_interchange to be a permutation, found "

1259 << getIteratorInterchange();

1260 }

1261 return success();

1262 }
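// Self-contained plain-C++ illustration of the permutation test used in the
// verifier above (kept in comment form; names are illustrative):
//
//   #include <algorithm>
//   #include <cstdint>
//   #include <numeric>
//   #include <vector>
//
//   bool isPermutationOfIota(const std::vector<int64_t> &perm) {
//     std::vector<int64_t> iota(perm.size());
//     std::iota(iota.begin(), iota.end(), 0);  // 0, 1, ..., n-1
//     return std::is_permutation(iota.begin(), iota.end(),
//                                perm.begin(), perm.end());
//   }
//   // isPermutationOfIota({2, 0, 1}) == true
//   // isPermutationOfIota({0, 2})    == false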

1263

1264

1265

1266

1267

1272

1273

1274 if (!isa<linalg::CopyOp>(targetOp)) {

1276 emitSilenceableError() << "only linalg.copy target ops are supported";

1277 diag.attachNote(targetOp->getLoc()) << "target op";

1278 return diag;

1279 }

1280

1281 auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);

1282 if (!copyOp.hasPureBufferSemantics()) {

1284 emitSilenceableError()

1285 << "cannot transform a linalg.copy on tensors into a memref.copy";

1286 diag.attachNote(targetOp->getLoc()) << "target op";

1287 return diag;

1288 }

1289

1292 assert(inputs.size() == 1 && "expected linalg copy op with one input");

1293 assert(outputs.size() == 1 && "expected memref copy op with one output");

1294 Value input = inputs.front();

1295 Value output = outputs.front();

1296

1297

1298

1299

1300 if (!isa(input.getType())) {

1302 emitSilenceableError()

1303 << "cannot transform a linalg.copy which input has no shape";

1304 diag.attachNote(targetOp->getLoc()) << "target op";

1305 return diag;

1306 }

1307

1308

1309 assert(isa(output.getType()));

1310

1311 if (cast(input.getType()).getElementType() !=

1312 cast(output.getType()).getElementType()) {

1314 emitSilenceableError()

1315 << "cannot transform a linalg.copy with different source and "

1316 "destination element types ";

1317 diag.attachNote(targetOp->getLoc()) << "target op";

1318 return diag;

1319 }

1320

1321

1322 auto memrefCopyOp =

1323 rewriter.replaceOpWithNewOp<memref::CopyOp>(targetOp, input, output);

1324

1325 results.push_back(memrefCopyOp);

1327 }

1328

1329

1330

1331

1332

1338 bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();

1339 FailureOr res =

1340 lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);

1341 if (failed(res)) {

1343 << "cannot lower to pad + expand + transpose";

1344 }

1345 transformResults.push_back(res->padOp);

1346 transformResults.push_back(res->expandShapeOp);

1347 transformResults.push_back(res->transposeOp);

1349 }

1350

1351

1352

1353

1354

1360 bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();

1361 FailureOr res =

1362 lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);

1363 if (failed(res)) {

1365 emitSilenceableError()

1366 << "cannot lower to transpose + collapse + extract";

1367 diag.attachNote(target->getLoc()) << "target payload op";

1368 return diag;

1369 }

1370 transformResults.push_back(res->emptyOp);

1371 transformResults.push_back(res->transposeOp);

1372 transformResults.push_back(res->collapseShapeOp);

1373 transformResults.push_back(res->extractSliceOp);

1375 }

1376

1377

1378

1379

1380

1387 }

1388

1395 result.addTypes(resultTypes);

1396 }

1397

1403 if (getOps().has_value())

1404 strs.insert_range(getOps()->getAsValueRange<StringAttr>());

1405

1406 auto payloadOps = state.getPayloadOps(getTarget());

1407 if (!llvm::hasSingleElement(payloadOps)) {

1409 }

1410

1412 bool incorrectNumOperandTypes = false;

1413 auto matchFun = [&](Operation *op) {

1414 if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))

1415 return;

1416

1417

1418

1419 if (getInterface().has_value()) {

1420 auto iface = getInterface().value();

1421 if (iface == transform::MatchInterfaceEnum::LinalgOp &&

1422 !isa(op))

1423 return;

1424 if (iface == transform::MatchInterfaceEnum::TilingInterface &&

1425 !isa(op))

1426 return;

1427 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&

1428 !isa(op))

1429 return;

1430 }

1431

1432

1433 if (getOpAttrs().has_value()) {

1434 DictionaryAttr opAttrs = getOpAttrs().value();

1436 if (attr.getName() == getInterfaceAttrName() ||

1437 attr.getName() == getOpsAttrName())

1438 continue;

1439 if (!op->hasAttr(attr.getName()))

1440 return;

1441 if (op->getAttr(attr.getName()) != attr.getValue())

1442 return;

1443 }

1444 }

1445

1446 if (getFilterResultType().has_value()) {

1447 Type t = getFilterResultType().value();

1449 return;

1450 }

1451

1452 if (getFilterOperandTypes().has_value()) {

1453 mlir::ArrayAttr types = getFilterOperandTypes().value();

1455

1456 if (types.size() == 1) {

1457

1458 auto typeattr =

1459 dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);

1460 Type t = cast<::mlir::Type>(typeattr.getValue());

1462 [&](Type operandType) { return operandType == t; }))

1463 return;

1464 } else {

1465

1466

1467 if (types.size() != operandTypes.size()) {

1468 incorrectNumOperandTypes = true;

1469 return;

1470 }

1471

1472 for (auto [attr, operandType] :

1473 llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {

1474 auto typeattr = cast<mlir::TypeAttr>(attr);

1475 Type type = cast<::mlir::Type>(typeattr.getValue());

1476

1477 if (type != operandType)

1478 return;

1479 }

1480 }

1481 }

1482

1483

1484 res.push_back(op);

1485 return;

1486 };

1487

1488 (*payloadOps.begin())->walk(matchFun);

1489 if (incorrectNumOperandTypes)

1490 return emitDefiniteFailure("If filter_operand_types contains more than a "

1491 "type, then it must contain as much types as "

1492 "the number of operands in the target ops");

1493 results.set(cast(getResult()), res);

1495 }

1496

1497

1498

1499

1500

1505 }

1506

1508 Type &targetType, Type &lowSizeType,

1509 Type &highSizeType,

1510 Type &splitPointType) {

1511 FunctionType funcType;

1513 if (failed(parser.parseType(funcType)))

1514 return failure();

1515

1516 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {

1517 parser.emitError(typeLoc) << "expects a trailing functional type with one "

1518 "argument and one result";

1519 }

1520 targetType = funcType.getInput(0);

1521 lowSizeType = highSizeType = splitPointType = funcType.getResult(0);

1522

1523 return success();

1524 }

1525

1529 if (isa(getLowSize().getType())) {

1530 if (target.hasDynamicShape()) {

1531 auto diag = emitSilenceableError()

1532 << "cannot compute parametric tile sizes for dynamically "

1533 "shaped payload op";

1534 diag.attachNote(target->getLoc()) << "payload op";

1535 return diag;

1536 }

1537

1539 target, getDimension(), getTargetSize(), getDivisor());

1540 if (failed(spec)) {

1541 return emitSilenceableError()

1542 << "failed to compute multi-size tiling sizes";

1543 }

1544

1545 Builder builder(target.getContext());

1546 results.assign(llvm::map_range(

1548 spec->lowTileSize * spec->lowTripCount}),

1549 [&builder, this](int64_t value) {

1551 cast(getLowSize().getType()).getType(), value);

1552 }));

1554 }

1555

1556 OpBuilder builder(target.getContext());

1561 builder, target, getDimension(), targetSize, divisor);

1562 if (failed(spec)) {

1563 return emitSilenceableError() << "could not generate tile size computation";

1564 }

1565

1570 {spec->lowTileSize, spec->lowTripCount});

1571 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();

1572 Operation *highTileSize = spec->highTileSize.getDefiningOp();

1573 assert(lowTileSize && highTileSize && splitPoint &&

1574 "tile sizes are not produced by operations");

1577 results.push_back(highTileSize);

1580 }

1581

1582 void transform::MultiTileSizesOp::getEffects(

1585 producesHandle(getOperation()->getOpResults(), effects);

1586 if (isa(getLowSize().getType()))

1588 else

1590 }

1591

1593 if (getLowSize().getType() != getHighSize().getType() ||

1594 getLowSize().getType() != getSplitPoint().getType()) {

1595 return emitOpError() << "expects all results type to be the same";

1596 }

1597 return success();

1598 }
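// Hedged numeric example of what multi-size tiling computes (my reading of
// the helper, not a statement of its exact contract): for a loop of size 10
// with target_size 4 and divisor 1 it may emit lowTileSize = 3 with
// lowTripCount = 2 and highTileSize = 4 with highTripCount = 1, since
// 3 * 2 + 4 * 1 == 10 covers the dimension exactly without a remainder tile.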

1599

1600

1601

1602

1603

1610 staticPackedSizes);

1611

1612

1613

1615 builder.getContext(), GenericOp::getOperationName());

1616 build(builder, result,

1617 linalgOpHType,

1618 target,

1619 dynamicPackedSizes,

1621 }

1622

1625 return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);

1626 }

1627

1632 auto targetOps = state.getPayloadOps(getTarget());

1633

1634 if (std::empty(targetOps)) {

1635 transformResults.set(cast(getPackedOp()),

1638 }

1639

1640 auto linalgOp = dyn_cast(*targetOps.begin());

1641 if (!llvm::hasSingleElement(targetOps) || !linalgOp) {

1642 return emitSilenceableError()

1643 << "requires target to map to exactly 1 LinalgOp (got "

1644 << llvm::range_size(targetOps) << ")";

1645 }

1646

1647 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {

1648 return emitSilenceableError()

1649 << "requires number of packed sizes match the number of loops ("

1650 << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()

1651 << ")";

1652 }

1653

1654

1657 state, *this, packedSizes, getMixedPackedSizes());

1658

1660 FailureOr maybeResult = pack(rewriter, linalgOp, packedSizes);

1661 if (failed(maybeResult))

1663

1664 transformResults.set(cast(getPackedOp()),

1665 {maybeResult->packedLinalgOp.getOperation()});

1667 }

1668

1669 void transform::PackOp::getEffects(

1675 }

1676

1677

1678

1679

1680

1683 return emitOpError() << getMatmulInnerDimsOrderAttrName()

1684 << " is not a valid permutation";

1685 }

1686

1687 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {

1688 for (auto [s, nmo] :

1689 llvm::zip_equal(getMixedMatmulPackedSizes(),

1690 getMatmulPaddedSizesNextMultipleOf())) {

1692 if (nmo != 0 &&

1693 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {

1694 return emitOpError() << "at most one of the packed_size and the "

1695 "padded_sizes_next_multiple_of can be nonzero "

1696 "for the matmul strategy";

1697 }

1698 }

1699 }

1700 return success();

1701 }

1702

1708 for (Operation *op : state.getPayloadOps(getTarget())) {

1709 auto linalgOp = dyn_cast(op);

1710 if (!linalgOp)

1711 continue;

1712

1713

1715

1716

1718 rewriter,

1719 linalgOp,

1720 getMixedMatmulPackedSizes(),

1721

1722 getMatmulPaddedSizesNextMultipleOf(),

1723 getMatmulInnerDimsOrder());

1724 if (succeeded(packResult)) {

1725 results.push_back(packResult->packedLinalgOp);

1726 continue;

1727 }

1728 results.push_back(linalgOp);

1729 }

1730 transformResults.set(cast(getPackedOp()), results);

1732 }

1733

1736 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),

1737 b);

1738 }

1739

1740 void transform::PackGreedilyOp::getEffects(

1746 }

1747

1748

1749

1750

1751

1754 return emitOpError() << getInnerPermAttrName()

1755 << " is not a valid permutation";

1756 }

1758 return emitOpError() << getOuterPermAttrName()

1759 << " is not a valid permutation";

1760 }

1761 if (getInnerPerm().empty() && getOuterPerm().empty()) {

1762 return emitOpError() << " at least one of " << getInnerPermAttrName()

1763 << " or " << getOuterPermAttrName()

1764 << " must be specified";

1765 }

1766 return success();

1767 }

1768

1769 namespace {

1770 enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };

1771 }

1772

1773

1774

1775

1776

1777

1778

1779

1780 template

1783 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {

1784 static_assert(

1785 llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,

1786 "applies to only pack or unpack operations");

1787 if (!op || permutation.empty())

1788 return true;

1789 size_t innerRank = op.getInnerDimsPos().size();

1790 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)

1791 return permutation.size() == innerRank && isPermutationVector(permutation);

1792

1793

1794 if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {

1795 return permutation.size() == op.getSourceRank() &&

1797 }

1798 return permutation.size() == op.getDestRank() &&

1800 }

1801

1806 auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());

1807 auto linalgOps = state.getPayloadOps(getTargetLinalgOp());

1808

1809 if (std::empty(packOrUnpackOps)) {

1810 transformResults.set(cast(getPackedOp()), {});

1811 transformResults.set(cast(getPackOp()), {});

1812 transformResults.set(cast(getUnPackOp()), {});

1814 }

1815

1816

1817

1818 if (!llvm::hasSingleElement(packOrUnpackOps) ||

1819 !llvm::hasSingleElement(linalgOps)) {

1820 return emitSilenceableError()

1821 << "requires target to map to exactly 1 "

1822 "packing op and 1 packed op ("

1823 << "got " << llvm::range_size(packOrUnpackOps) << " and "

1824 << llvm::range_size(linalgOps) << ")";

1825 }

1826

1827

1828 auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());

1829 auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());

1830 if ((!packOp && !unPackOp)) {

1831 return emitSilenceableError() << "requires target to map to a "

1832 "linalg.pack or linalg.unpack";

1833 }

1834 LinalgOp linalgOpTarget = dyn_cast(*linalgOps.begin());

1835 if (!linalgOpTarget)

1836 return emitSilenceableError() << "requires a LinalgOp target";

1837

1838

1839 LinalgOp linalgOp;

1840 if (packOp && packOp.getResult().hasOneUse())

1841 linalgOp = dyn_cast(*(packOp.getResult().getUsers().begin()));

1842 else if (unPackOp)

1843 linalgOp = unPackOp.getSource().getDefiningOp();

1844 if (linalgOp != linalgOpTarget) {

1845 auto errorMsg =

1846 packOp ? StringLiteral{"not a single use by the LinalgOp target"}

1847 : StringLiteral{"not produced by the LinalgOp target"};

1848 return emitSilenceableError() << errorMsg;

1849 }

1850

1851

1852

1853 if (unPackOp) {

1854 assert(!packOp && "packOp must be null on entry when unPackOp is not null");

1855 OpOperand *packUse = linalgOp.getDpsInitOperand(

1856 cast<OpResult>(unPackOp.getSource()).getResultNumber());

1857 packOp = dyn_cast_or_null<linalg::PackOp>(packUse->get().getDefiningOp());

1858 if (!packOp || !packOp.getResult().hasOneUse())

1859 return emitSilenceableError() << "could not find matching pack op";

1860 }

1861

1862

1863 for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {

1865 (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();

1866 auto errorMsg = (permType == OuterOrInnerPerm::Outer)

1867 ? StringLiteral{"invalid outer_perm"}

1868 : StringLiteral{"invalid inner_perm"};

1872 unPackOp ? unPackOp.getOperation() : packOp.getOperation();

1873 return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;

1874 }

1875 }

1876

1877

1878

1879 assert(packOp && linalgOp && "unexpected null op");

1880

1881

1882 FailureOr res = packTranspose(

1883 rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());

1884

1885 assert(succeeded(res) && "unexpected packTranspose failure");

1886

1887

1888 transformResults.set(cast(getPackOp()), {res->transposedPackOp});

1889 transformResults.set(cast(getPackedOp()),

1890 {res->transposedLinalgOp});

1891 if (unPackOp) {

1892 transformResults.set(cast(getUnPackOp()),

1893 {res->transposedUnPackOp});

1894 } else {

1895 transformResults.set(cast(getUnPackOp()), {});

1896 }

1897

1899 }

1900

1901

1902

1903

1904

1910 StringRef copyBackOp) {

1912 return build(b,

1913 result,

1914 TypeRange{resultType, resultType},

1915 target,

1916 ArrayAttr(),

1917 b.getI64ArrayAttr(paddingDimensions),

1919

1920 (padToMultipleOf.empty()

1922 : b.getDenseI64ArrayAttr(padToMultipleOf)),

1923 b.getI64ArrayAttr(nofoldFlags),

1924 b.getArrayAttr(transposePaddings),

1925 b.getStringAttr(copyBackOp));

1926 }

1927

1933 StringRef copyBackOp) {

1938 staticPadToMultipleOf);

1939 return build(b,

1940 result,

1941 TypeRange{resultType, resultType},

1942 target,

1943 ArrayAttr(),

1944 b.getI64ArrayAttr(paddingDimensions),

1945 dynamicPadToMultipleOf,

1946 staticPadToMultipleOf,

1948 b.getArrayAttr(transposePaddings),

1950 }

1951

1952 void PadOp::getEffects(

1956 producesHandle(getOperation()->getOpResults(), effects);

1958 }

1959

1962 return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);

1963 }

1964

1969 auto transformOp = cast(getOperation());

1971

1972 for (Operation *target : state.getPayloadOps(getTarget())) {

1973 auto linalgTarget = dyn_cast(target);

1974 if (!linalgTarget) {

1975 auto diag = emitSilenceableError() << "expected LinalgOp target";

1976 diag.attachNote(target->getLoc()) << "target op";

1977 return diag;

1978 }

1979

1980

1982 for (int64_t packPadding :

1983 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))

1984 nofoldFlags.push_back(static_cast<bool>(packPadding));

1985

1986

1988 for (auto const &it :

1989 llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {

1990 auto attr = dyn_cast<TypedAttr>(std::get<0>(it));

1991 if (!attr) {

1992 emitOpError("expects padding values to be typed attributes");

1994 }

1996

1997 if (auto stringAttr = dyn_cast<StringAttr>(attr)) {

1998 auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(

1999 stringAttr, getContext(), elementType,

2000 nullptr, true));

2001 if (!parsedAttr || parsedAttr.getType() != elementType) {

2002 auto diag = this->emitOpError("expects a padding that parses to ")

2003 << elementType << ", got " << std::get<0>(it);

2004 diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";

2006 }

2007 paddingValues.push_back(parsedAttr);

2008 continue;

2009 }

2010

2011 if (attr.getType() != elementType) {

2012 auto diag = this->emitOpError("expects a padding value of type ")

2013 << elementType << ", got " << attr;

2014 diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";

2016 }

2017 paddingValues.push_back(attr);

2018 }

2019

2020

2022 for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))

2023 transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(

2024 cast<ArrayAttr>(transposeVector)));

2025

2026 LinalgOp paddedOp;

2028 options.paddingDimensions =

2029 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());

2030

2033 state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);

2035 return status;

2036 if (padToMultipleOf.empty())

2037 padToMultipleOf =

2039

2040 options.padToMultipleOf = padToMultipleOf;

2041 options.paddingValues = paddingValues;

2042 options.nofoldFlags = nofoldFlags;

2043 if (getCopyBackOp() ==

2044 bufferization::MaterializeInDestinationOp::getOperationName()) {

2047 } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {

2049 } else if (getCopyBackOp() == kCopyOpNone) {

2051 } else {

2052 llvm_unreachable("unsupported copy_back op");

2053 }

2054

2058 replacements, newPadOps))) {

2059 auto diag = emitSilenceableError() << "failed to pad op";

2060 diag.attachNote(target->getLoc()) << "target op";

2061 return diag;

2062 }

2063

2064

2065

2066

2067

2068

2069 rewriter.replaceOp(linalgTarget, replacements);

2070 paddedOps.push_back(paddedOp);

2071 padOps.append(newPadOps.begin(), newPadOps.end());

2073 for (Value v : replacements) {

2074 Operation *copyBackOp = v.getDefiningOp();

2075 if (!llvm::is_contained(copyBackOps, copyBackOp))

2076 copyBackOps.push_back(copyBackOp);

2077 }

2078 }

2079 }

2080

2081 results.set(cast(getPadded()), paddedOps);

2082 results.set(cast(getPad()), padOps);

2083 results.set(cast(getCopy()), copyBackOps);

2085 }

2086

2089 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());

2090 if (any_of(nofoldFlags, [](int64_t packPadding) {

2091 return packPadding != 0 && packPadding != 1;

2092 })) {

2093 return emitOpError()

2094 << "expects nofold_flags to contain booleans (0/1), found "

2095 << getNofoldFlags();

2096 }

2097

2099 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());

2100 if (any_of(paddingDimensions,

2101 [](int64_t paddingDimension) { return paddingDimension < 0; })) {

2102 return emitOpError() << "expects padding_dimensions to contain positive "

2103 "integers, found "

2104 << getPaddingDimensions();

2105 }

2106 if (!getMixedPadToMultipleOf().empty()) {

2107 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {

2108 return emitOpError() << "expects as many multiples as padding_dimensions";

2109 }

2110 }

2111 ArrayAttr transposes = getTransposePaddings();

2112 for (Attribute attr : transposes) {

2114 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));

2115 if (!std::is_permutation(sequence.begin(), sequence.end(),

2117 return emitOpError()

2118 << "expects transpose_paddings to be a permutation, found "

2119 << attr;

2120 }

2121 }

2122 if (getCopyBackOp() !=

2123 bufferization::MaterializeInDestinationOp::getOperationName() &&

2124 getCopyBackOp() != linalg::CopyOp::getOperationName() &&

2125 getCopyBackOp() != kCopyOpNone)

2126 return emitOpError() << "invalid copy_back_op";

2127 return success();

2128 }

2129

2130

2131

2132

2133

2138 auto targetOps = state.getPayloadOps(getTarget());

2139 auto loopOps = state.getPayloadOps(getLoop());

2140 if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {

2142 << "requires exactly one target and one loop handle (got "

2143 << llvm::range_size(targetOps) << " and "

2144 << llvm::range_size(loopOps) << ")";

2145 }

2146

2147 auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());

2148 auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());

2149 if (!padOp || !loopOp)

2151

2152 FailureOrlinalg::detail::PackingResult result =

2154 getTranspose());

2155 if (failed(result))

2157

2158 if (result->clonedLoopIvs.empty()) {

2159 transformResults.set(cast(getPackingLoop()),

2160 {result->hoistedPadOp.getOperation()});

2162 }

2163 auto outerPackedLoop =

2165 transformResults.set(cast(getPackingLoop()),

2166 {outerPackedLoop.getOperation()});

2168 }

2169

2172 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));

2173 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),

2175 return emitOpError() << "expects transpose to be a permutation, found "

2176 << getTranspose();

2177 }

2178 return success();

2179 }

2180

2181 void transform::HoistPadBuildPackingLoopNestOp::getEffects(

2187 }

2188

2191 tensor::PadOp target,

2194 tensor::PadOp hoistedPadOp;

2196 FailureOr result =

2198 hoistedPadOp, transposeOps);

2199 if (succeeded(result)) {

2200

2201

2202

2203

2204

2205 rewriter.replaceOp(target, *result);

2206 results.push_back(hoistedPadOp);

2208 }

2209 return emitDefaultSilenceableFailure(target);

2210 }

2211

2214 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));

2215 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),

2217 return emitOpError() << "expects transpose to be a permutation, found "

2218 << getTranspose();

2219 }

2220 return success();

2221 }

2222

2223

2224

2225

2226

2229 LinalgOp target,

2233 if (!getOperandsToPromote().empty())

2235 extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));

2236 if (getUseFullTilesByDefault())

2238 getUseFullTilesByDefault());

2239 if (getUseAlloca())

2240 promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());

2241 if (!getUseFullTileBuffers().empty())

2243 llvm::to_vector(getUseFullTileBuffers().getAsValueRange()));

2244 if (getAlignment().has_value())

2245 promotionOptions = promotionOptions.setAlignment(*getAlignment());

2246 if (getMemorySpace().has_value())

2247 promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());

2248

2249 if (getMapping().has_value()) {

2250

2251 auto mapping = *getMapping();

2252 if (mapping.size() > 1)

2253 return emitDefaultDefiniteFailure(target);

2254

2255 auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);

2256

2257 if (addressSpace.getAddressSpace() ==

2258 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {

2259 promotionOptions =

2260 promotionOptions

2265 } else if (addressSpace.getAddressSpace() ==

2266 mlir::gpu::GPUDialect::getPrivateAddressSpace()) {

2267 promotionOptions =

2268 promotionOptions

2273 } else {

2274 return emitDefaultDefiniteFailure(target);

2275 }

2276 }

2277

2279 return emitDefaultDefiniteFailure(target);

2280

2282 FailureOr res = promoteSubViews(rewriter, target, promotionOptions);

2283 if (failed(res))

2284 return emitDefaultDefiniteFailure(target);

2287 }

2288

2289

2290

2291

2292

2297 auto payload = state.getPayloadOps(getTarget());

2298

2299

2300 for (Operation *target : payload) {

2301 if (target->getNumOperands() > 0)

2304 target->getNumRegions() > 0)

2306 << "expected target that is isolated from above";

2307 }

2308

2309

2310 Operation *pattern = &getBodyRegion().front().front();

2312 for (Operation *target : payload) {

2313 if (getOperation()->isAncestor(target))

2314 continue;

2318 replacements.push_back(replacement);

2319 }

2320 transformResults.set(cast(getReplacement()), replacements);

2322 }

2323

2324 void transform::ReplaceOp::getEffects(

2327 producesHandle(getOperation()->getOpResults(), effects);

2329 }

2330

2332 if (!getBodyRegion().hasOneBlock())

2333 return emitOpError() << "expected one block";

2334 if (std::distance(getBodyRegion().front().begin(),

2335 getBodyRegion().front().end()) != 1)

2336 return emitOpError() << "expected one operation in block";

2337 Operation *replacement = &getBodyRegion().front().front();

2340 << "expected replacement without operands";

2344 << "expect op that is isolated from above";

2345 return success();

2346 }

2347

2348

2349

2350

2351

2354 LinalgOp target,

2360 Location loc = target.getLoc();

2362 target.createFlatListOfOperandDims(b, loc);

2363 AffineMap map = target.getShapesToLoopsMap();

2364 if (!map)

2365 return tileSizes;

2368 allShapeSizes);

2369

2370

2374 }

2375 return tileSizes;

2376 });

2378 FailureOrscf::SCFTilingResult maybeTilingResult = tileUsingSCF(

2379 rewriter, cast(target.getOperation()), tilingOptions);

2380 if (failed(maybeTilingResult))

2381 return emitDefaultDefiniteFailure(target);

2382

2383 if (target->getNumResults())

2384 rewriter.replaceOp(target, maybeTilingResult->replacements);

2385 else

2386 rewriter.eraseOp(target);

2387

2388 results.reserve(maybeTilingResult->tiledOps.size());

2389 for (Operation *tiled : maybeTilingResult->tiledOps)

2392 }

2393

2394

2395

2396

2397

2403 for (Operation *target : state.getPayloadOps(getTarget())) {

2404 auto tilingOp = dyn_cast(*target);

2405 if (!tilingOp) {

2407 emitSilenceableError()

2408 << "expected the payload to implement TilingInterface";

2409 diag.attachNote(target->getLoc()) << "payload op";

2410 return diag;

2411 }

2413 FailureOr<SmallVectorscf::ForOp> generatedLoops =

2415 if (failed(generatedLoops))

2416 return emitDefaultDefiniteFailure(target);

2417 for (scf::ForOp &loop : *generatedLoops) {

2418 loops.push_back(loop.getOperation());

2419 }

2420 rewriter.eraseOp(target);

2421 }

2422 results.set(cast(getResult()), loops);

2424 }

2425

2426

2427

2428

2429

2431 transform::RewriteInDestinationPassingStyleOp::applyToOne(

2436 FailureOr<Operation *> maybeResult =

2438 .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(

2439 [&rewriter](auto op) {

2441 });

2442 if (failed(maybeResult))

2443 return emitDefaultSilenceableFailure(target);

2444 results.push_back(*maybeResult);

2446 }

2447

2448

2449

2450

2451

2455

2457 llvm::to_vector(state.getPayloadOps(getTarget()));

2458

2459 bool isMultiwaySplit = getMultiway();

2460

2461 if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {

2463 << "requires exactly one target when "

2464 "multiway split is enabled (got "

2465 << llvm::range_size(payload) << ")";

2466 }

2467

2469

2470 if (!isMultiwaySplit)

2471 chunkSizes.reserve(payload.size());

2472

2473 if (getDynamicChunkSizes()) {

2475 if (isa(getDynamicChunkSizes().getType())) {

2476 chunkSizes = llvm::to_vector(llvm::map_range(

2477 state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {

2480 diag = emitSilenceableError()

2481 << "expected dynamic split point handle to point to a "

2482 "single-result index-typed op";

2483 diag.attachNote(op->getLoc()) << "dynamic split point";

2484 }

2486 }));

2487 } else {

2488 chunkSizes = llvm::to_vector(

2489 llvm::map_range(state.getParams(getDynamicChunkSizes()),

2491 }

2492 if (diag.isSilenceableFailure())

2493 return diag;

2494

2495

2496

2497 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {

2499 << "expected the dynamic split point handle to point to as "

2500 "many operations ("

2501 << chunkSizes.size() << ") as the target handle ("

2502 << payload.size() << ")";

2503 }

2504 } else {

2505 chunkSizes.resize(payload.size(),

2506 rewriter.getIndexAttr(getStaticChunkSizes()));

2507 }

2508

2509 auto checkStructuredOpAndDimensions =

2511 if (!linalgOp) {

2512 auto diag = emitSilenceableError() << "only applies to structured ops";

2513 diag.attachNote(loc) << "target op";

2514 return diag;

2515 }

2516

2517 if (getDimension() >= linalgOp.getNumLoops()) {

2518 auto diag = emitSilenceableError() << "dimension " << getDimension()

2519 << " does not exist in target op";

2520 diag.attachNote(loc) << "target op";

2521 return diag;

2522 }

2524 };

2525

2526 auto checkFailureInSplitting =

2528 if (hasFailed) {

2530 diag.attachNote(loc) << "target op";

2531 return diag;

2532 }

2534 };

2535

2537 if (isMultiwaySplit) {

2538

2539

2540 TilingInterface head, tail;

2541 Operation *target = payload.front();

2542

2543 LinalgOp linalgOp = dyn_cast(target);

2544

2545

2547 checkStructuredOpAndDimensions(linalgOp, target->getLoc());

2548 if (diag.isSilenceableFailure())

2549 return diag;

2550

2551 for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {

2552

2553 if (idx > 0)

2554 target = tail.getOperation();

2555

2556 if (!target)

2557 break;

2558

2559 linalgOp = cast(target);

2561

2564 rewriter, cast(linalgOp.getOperation()),

2565 getDimension(), chunkSize);

2566

2567

2569 checkFailureInSplitting(!head && !tail, loc);

2570 if (diag.isDefiniteFailure())

2571 return diag;

2572

2573 opList.push_back(head.getOperation());

2574 }

2575

2576

2577 if (tail)

2578 opList.push_back(tail.getOperation());

2579

2580 } else {

2581

2583 Operation *noSecondPart = nullptr;

2584 for (const auto &pair : llvm::zip(payload, chunkSizes)) {

2585 Operation *target = std::get<0>(pair);

2587 LinalgOp linalgOp = dyn_cast(target);

2589 checkStructuredOpAndDimensions(linalgOp, target->getLoc());

2590

2591 if (diag.isSilenceableFailure())

2592 return diag;

2593

2595 std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(

2596 rewriter, cast<TilingInterface>(linalgOp.getOperation()),

2597 getDimension(), std::get<1>(pair));

2598

2599

2601 checkFailureInSplitting(!first.back() && !second.back(), loc);

2603 return diag;

2604

2605

2606 if (!second.back()) {

2607 noSecondPart = target;

2608 second.pop_back();

2609 }

2610 }

2611

2612 if (second.size() != first.size() && !second.empty()) {

2613 auto diag = emitSilenceableError()

2614 << "splitting does not produce the second part for a subset "

2615 "of targets";

2616 diag.attachNote()

2617 << "expected splitting to produce the second part of all "

2618 "or none of the targets";

2619 diag.attachNote(noSecondPart->getLoc())

2620 << "first target with no second part";

2621 return diag;

2622 }

2623

2624 opList.append(first);

2625 if (second.size())

2626 opList.append(second);

2627 }

2628 results.set(cast<OpResult>(getSplitList()), opList);

2630 }

2631

2632 void SplitOp::getEffects(

2635 if (getDynamicChunkSizes())

2636 onlyReadsHandle(getDynamicChunkSizesMutable(), effects);

2637 producesHandle(getOperation()->getOpResults(), effects);

2639 }

2640

2643 IntegerAttr staticChunkSizes;

2645 return failure();

2646

2649 if (!dynamicPointParseResult.has_value()) {

2650 int64_t staticChunkSizesValue;

2651 if (failed(parser.parseInteger(staticChunkSizesValue)))

2652 return failure();

2653

2654 staticChunkSizes =

2656 }

2657

2658 Type targetType;

2662 return failure();

2663 }

2664 if (dynamicPointParseResult.has_value()) {

2665 Type ChunkSizesType;

2666 if (failed(*dynamicPointParseResult) || parser.parseComma() ||

2667 parser.parseType(ChunkSizesType) ||

2668 parser.resolveOperand(dynamicChunkSizes, ChunkSizesType,

2670 return failure();

2671 }

2672

2673 staticChunkSizes =

2675 }

2676

2678 SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),

2679 staticChunkSizes);

2680 result.addTypes(targetType);

2681 return success();

2682 }

2683

2685 printer << " " << getTarget() << " after ";

2686 int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());

2687 if (staticChunkSize != ShapedType::kDynamic)

2688 printer << staticChunkSize;

2689 else

2690 printer << getDynamicChunkSizes();

2691 printer << " ";

2693 {getStaticChunkSizesAttrName()});

2694 printer << " : " << getTarget().getType();

2695 if (staticChunkSize == ShapedType::kDynamic)

2696 printer << ", " << getDynamicChunkSizes().getType();

2697 }

2698

2700 if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^

2701 (getDynamicChunkSizes() == nullptr)) {

2702 return emitOpError() << "expects either a dynamic or a static split "

2703 "point to be provided";

2704 }

2705 return success();

2706 }
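SplitOp::apply (above) ultimately calls linalg::splitOp once per chunk, in both the multiway and the two-way paths. Below is a minimal sketch of that call, assuming an op implementing TilingInterface and a static split point; splitAt is a hypothetical helper, not part of this file.

    // Minimal sketch, assuming `op` implements TilingInterface and `dim` is a
    // valid iteration-space dimension of it.
    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
    #include "mlir/IR/PatternMatch.h"

    using namespace mlir;

    // Hypothetical helper: split the iteration space of `op` at a static point.
    static std::pair<Operation *, Operation *>
    splitAt(RewriterBase &rewriter, TilingInterface op, unsigned dim, int64_t point) {
      OpFoldResult splitPoint = rewriter.getIndexAttr(point);
      // Either half may be null when the corresponding part of the space is empty.
      auto [first, second] = linalg::splitOp(rewriter, op, dim, splitPoint);
      return {first ? first.getOperation() : nullptr,
              second ? second.getOperation() : nullptr};
    }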

2707

2708

2709

2710

2711

2712 void transform::SplitReductionOp::build(

2714 int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,

2715 bool useScalingAlgorithm, bool useAlloc) {

2718 result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),

2721 SplitReductionOp::getInsertSplitDimensionAttrName(result.name),

2723 if (innerParallel) {

2724 result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),

2726 }

2727 if (useScalingAlgorithm) {

2729 SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),

2731 }

2732 if (useAlloc) {

2733 result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),

2735 }

2737 result.addTypes({resultType, resultType, resultType, resultType});

2738 }

2739

2746 unsigned(getInsertSplitDimension()),

2747 bool(getInnerParallel())};

2748 };

2750 FailureOr<SplitReductionResult> splitResult =

2751 (getUseScalingAlgorithm())

2753 : splitReduction(rewriter, target, splitFn, getUseAlloc());

2754 if (failed(splitResult))

2755 return emitDefaultDefiniteFailure(target);

2756

2757 results.push_back(splitResult->initOrAlloc);

2758 results.push_back(splitResult->fillOp);

2759 results.push_back(splitResult->splitLinalgOp);

2760 results.push_back(splitResult->resultCombiningLinalgOp);

2762 }
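SplitReductionOp builds a ControlSplitReductionFn from its attributes and hands it to linalg::splitReduction, or to splitReductionByScaling when use_scaling_algorithm is set. A minimal sketch with a fixed control function follows, assuming the split factor divides the reduction size; splitByFour is a hypothetical helper.

    // Minimal sketch, assuming `reduction` is a LinalgOp with a reduction loop
    // whose size is divisible by the chosen factor.
    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"

    using namespace mlir;

    static FailureOr<linalg::SplitReductionResult>
    splitByFour(RewriterBase &rewriter, linalg::LinalgOp reduction) {
      linalg::ControlSplitReductionFn control = [](linalg::LinalgOp) {
        // {split factor, dimension at which the extra parallel loop is inserted,
        //  whether that loop is placed innermost}.
        return linalg::SplitReductionOptions{4, 0, /*innerParallel=*/false};
      };
      return linalg::splitReduction(rewriter, reduction, control, /*useAlloc=*/false);
    }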

2763

2764

2765

2766

2767

2768 void transform::TileReductionUsingForOp::build(

2771

2772

2773

2774

2775

2779 build(builder, result,

2780 TypeRange{opTy, opTy, opTy, opTy},

2781 target,

2782 staticTileSizesAttr);

2783 }

2784

2790

2791 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);

2792 if (!partialReductionOp) {

2795 "Operation should implement PartialReductionOpInterface");

2796 }

2798 rewriter, partialReductionOp,

2800

2801 if (failed(result))

2802 return emitDefaultSilenceableFailure(target);

2803 rewriter.replaceOp(target, result->replacements);

2804 for (Value initValue : result->initialValues)

2806 for (auto parallelTiledOp : result->tiledOps)

2807 results.push_back(parallelTiledOp);

2808 for (auto mergeOp : result->mergeOps)

2810 results.push_back(result->loops.front());

2812 }
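TileReductionUsingForOp delegates to scf::tileReductionUsingScf with the requested tile sizes. A minimal sketch with hard-coded sizes follows, assuming a two-loop payload whose second loop is the reduction dimension; tileReductionBy8 is a hypothetical helper.

    // Minimal sketch, assuming `op` implements PartialReductionOpInterface.
    #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"

    using namespace mlir;

    static FailureOr<scf::SCFTilingResult>
    tileReductionBy8(RewriterBase &rewriter, PartialReductionOpInterface op) {
      SmallVector<OpFoldResult> tileSizes;
      tileSizes.push_back(rewriter.getIndexAttr(0)); // parallel dim: left untiled
      tileSizes.push_back(rewriter.getIndexAttr(8)); // reduction dim: tile by 8
      return scf::tileReductionUsingScf(rewriter, op, tileSizes);
    }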

2813

2814

2815

2816

2817

2818 void transform::TileReductionUsingForallOp::build(

2821 ArrayAttr mapping) {

2822

2823

2824

2825

2826

2831 build(builder, result,

2832 TypeRange{opTy, opTy, opTy, opTy},

2833 target,

2834 staticNumThreadsAttr,

2835 staticTileSizesAttr,

2836 mapping);

2837 }

2838

2848 FailureOr<linalg::ForallReductionTilingResult> result =

2850 rewriter, cast<PartialReductionOpInterface>(target.getOperation()),

2851 numThreads, tileSizes, getMapping());

2852

2853 if (failed(result)) {

2854 auto diag = emitSilenceableError() << "could not tile reduction";

2855 diag.attachNote(target.getLoc()) << "target operation";

2856 return diag;

2857 }

2858 for (Value initValue : result->initialValues)

2860 for (auto parallelTiledOp : result->parallelTiledOps)

2861 results.push_back(parallelTiledOp);

2862 for (auto mergeOp : result->mergeOps)

2864 results.push_back(result->loops);

2866 }
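TileReductionUsingForallOp lowers to linalg::tileReductionUsingForall, which distributes the partial reductions over an scf.forall. A minimal sketch with a fixed thread count follows; tileReductionAcross4Threads is a hypothetical helper.

    // Minimal sketch, assuming the payload implements PartialReductionOpInterface
    // and that splitting the reduction across 4 threads is valid for its shape.
    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"

    using namespace mlir;

    static FailureOr<linalg::ForallReductionTilingResult>
    tileReductionAcross4Threads(RewriterBase &rewriter,
                                PartialReductionOpInterface op) {
      SmallVector<OpFoldResult> numThreads = {rewriter.getIndexAttr(0),
                                              rewriter.getIndexAttr(4)};
      // No explicit tile sizes and no mapping attribute: the defaults mirror the
      // corresponding attributes being absent on the transform op.
      return linalg::tileReductionUsingForall(rewriter, op, numThreads);
    }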

2867

2868

2869

2870

2871

2876

2878 llvm::to_vector(state.getPayloadOps(getTarget()));

2879

2880 if (!llvm::hasSingleElement(targetOps)) {

2882 << "requires exactly one target (got " << llvm::range_size(targetOps)

2883 << ")";

2884 }

2885

2886 Operation *target = *targetOps.begin();

2887 auto linalgOp = dyn_cast<LinalgOp>(target);

2888 auto tileableOp = dyn_cast<TilingInterface>(target);

2889

2890 if (!linalgOp)

2892

2893 OpBuilder builder(linalgOp.getContext());

2894

2895 if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {

2896 if (linalgOp.hasDynamicShape()) {

2897 auto diag = emitSilenceableError()

2898 << "cannot compute parametric tile sizes for dynamically "

2899 "shaped payload op";

2900 diag.attachNote(linalgOp->getLoc()) << "payload op";

2901 return diag;

2902 }

2903

2904 FailureOr<StaticContinuousTileSizeSpecification> spec =

2906 getTargetSize());

2907 if (failed(spec)) {

2908 return emitSilenceableError()

2909 << "failed to compute multi-size tiling sizes";

2910 }

2911

2913

2914 for (auto &&[tileSize, tripCount] :

2915 llvm::zip_equal(spec->tileSizes, spec->tripCounts))

2916 chunkSizes.push_back(tileSize * tripCount);

2917

2919 return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {

2921 });

2922 };

2924 getI64AttrsFromI64(spec->tileSizes));

2925 transformResults.setParams(cast<OpResult>(getChunkSizes()),

2926 getI64AttrsFromI64(chunkSizes));

2927

2929 }

2930

2932

2934 unsigned dimension = getDimension();

2935

2937 builder, tileableOp, dimension, targetSize, true);

2938 if (failed(spec)) {

2939 return emitSilenceableError() << "could not generate tile size computation";

2940 }

2941

2946 ofrs);

2947 };

2948

2950 Value splitPoint;

2951 for (auto &&[tileSize, tripCount] :

2952 llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {

2953 splitPoint = apply(s0 * s1, {tileSize, tripCount});

2954 chunkSizes.push_back(splitPoint);

2955 }

2956

2958 return llvm::map_to_vector(values, [&](Value value) -> Operation * {

2960 });

2961 };

2962

2964 getDefiningOps(spec->tileSizes));

2965 transformResults.set(cast<OpResult>(getChunkSizes()),

2966 getDefiningOps(chunkSizes));

2967

2969 }
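The parameter-typed branch above relies on linalg::computeStaticContinuousTileSizes and derives each chunk size as tileSize * tripCount. A minimal sketch that prints that decomposition follows, assuming a statically shaped payload op; printContinuousSizes is a hypothetical helper.

    // Minimal sketch, assuming `op` has static shapes along dimension `dim`.
    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
    #include "llvm/ADT/STLExtras.h"
    #include "llvm/Support/raw_ostream.h"

    using namespace mlir;

    static LogicalResult printContinuousSizes(linalg::LinalgOp op, unsigned dim,
                                              unsigned targetSize) {
      FailureOr<linalg::StaticContinuousTileSizeSpecification> spec =
          linalg::computeStaticContinuousTileSizes(op, dim, targetSize);
      if (failed(spec))
        return failure();
      // Each tile size covers tileSize * tripCount iterations of the chosen dim.
      for (auto &&[tileSize, tripCount] :
           llvm::zip_equal(spec->tileSizes, spec->tripCounts))
        llvm::outs() << tileSize << " x " << tripCount << " = "
                     << tileSize * tripCount << "\n";
      return success();
    }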

2970

2972

2974 return emitOpError() << "expects all results type to be the same";

2975 }

2976

2977 return success();

2978 }

2979

2980 void transform::ContinuousTileSizesOp::getEffects(

2984 else

2987 producesHandle(getOperation()->getOpResults(), effects);

2988 }

2989

2991 Type targetType, Type tile_sizes,

2994 }

2995

2997 Type &targetType,

2998 Type &tileSizesType,

2999 Type &chunkSizesType) {

3000 FunctionType funcType;

3002 if (failed(parser.parseType(funcType)))

3003 return failure();

3004

3005 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {

3006 parser.emitError(typeLoc) << "expects a trailing functional type with one "

3007 "argument and one result";

3008 }

3009 targetType = funcType.getInput(0);

3010 tileSizesType = chunkSizesType = funcType.getResult(0);

3011

3012 return success();

3013 }

3014

3015

3016

3017

3018

3019 void transform::TileUsingForOp::build(

3024 return build(builder, result, loopTypes,

3025 target,

3026

3028 interchange, scalableSizes);

3029 }

3030

3031 void transform::TileUsingForOp::build(

3035 build(builder, result, target,

3037 interchange, scalableSizes);

3038 }

3039

3040 void transform::TileUsingForOp::build(

3044

3045

3047 build(builder, result, loopTypes, target, mixedTileSizes, interchange,

3048 scalableSizes);

3049 }

3050

3051 void transform::TileUsingForOp::build(

3059

3060

3061

3063 unsigned numExpectedLoops =

3064 staticTileSizes.size() - llvm::count(staticTileSizes, 0);

3066 resultTypes.reserve(numExpectedLoops);

3067 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&

3068 "expected one loop type or as many as loops");

3069 if (loopTypes.size() == 1)

3070 resultTypes.append(numExpectedLoops, loopTypes[0]);

3071 else

3072 llvm::append_range(resultTypes, loopTypes);

3073 SmallVector expandedScalableSizes(mixedTileSizes.size(), false);

3074 if (scalableSizes.has_value())

3075 expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());

3076 build(builder, result, target.getType(),

3077 resultTypes,

3078 target,

3079 dynamicTileSizes,

3080 staticTileSizesAttr,

3082 expandedScalableSizes);

3083 }

3084

3086 if (getMixedSizes().size() != getScalableSizes().size())

3087 return emitOpError("expected same number of sizes (")

3088 << getMixedSizes().size() << ") and scalable sizes ("

3089 << getScalableSizes().size() << ")";

3091 unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);

3092 if (getLoops().size() != numExpectedLoops)

3093 return emitOpError("expected number of loops to tile (")

3094 << numExpectedLoops << ") to match number of `loops` results ("

3095 << getLoops().size() << ")";

3096 return success();

3097 }

3098

3104

3106 llvm::to_vector(state.getPayloadOps(getTarget()));

3112 if (isa<ParamType>(transformValue.getType())) {

3113 dynamicSizeProducers.push_back({});

3115 paramSizes.push_back(

3116 llvm::to_vector(llvm::map_range(params, [](Attribute attr) {

3117 return cast<IntegerAttr>(attr).getValue().getSExtValue();

3118 })));

3119

3120 if (paramSizes.back().size() != targets.size()) {

3122 emitSilenceableError()

3123 << "expected as many parameter values ("

3124 << dynamicSizeProducers.back().size() << ") as target ops ("

3125 << targets.size() << ")";

3126 diag.attachNote(transformValue.getLoc()) << "for this parameter";

3127 return diag;

3128 }

3129

3130 continue;

3131 }

3132 paramSizes.push_back({});

3133 dynamicSizeProducers.push_back(

3134 llvm::to_vector(state.getPayloadOps(transformValue)));

3135

3136 if (dynamicSizeProducers.back().size() != targets.size()) {

3138 emitSilenceableError()

3139 << "expected as many dynamic size-producing operations ("

3140 << dynamicSizeProducers.back().size() << ") as target ops ("

3141 << targets.size() << ")";

3142 diag.attachNote(transformValue.getLoc()) << "for this handle";

3143 return diag;

3144 }

3145

3146 for (Operation *op : dynamicSizeProducers.back()) {

3149 continue;

3150 }

3151

3153 emitSilenceableError() << "expected sizes to be produced by ops "

3154 "with a single index-type result";

3155 diag.attachNote(op->getLoc()) << "size producer op";

3156 diag.attachNote(transformValue.getLoc()) << "for this handle";

3157 return diag;

3158 }

3159 }

3160

3163 loops.resize(getLoops().size());

3164 auto scalableSizes = getScalableSizes();

3166 auto tilingInterface = dyn_cast<TilingInterface>(op);

3167 if (!tilingInterface) {

3169 emitSilenceableError()

3170 << "only ops implementing TilingInterface are supported";

3171 diag.attachNote(op->getLoc()) << "target op";

3172 return diag;

3173 }

3174 if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {

3176 emitSilenceableError()

3177 << "too many tiles provided, expected at most "

3178 << tilingInterface.getLoopIteratorTypes().size() << " found "

3179 << tileSizes.size();

3180 diag.attachNote(op->getLoc()) << "target op";

3181 return diag;

3182 }

3183

3185 if (tileSizes.empty()) {

3188 return {};

3189 });

3190 } else {

3194 sizes.reserve(tileSizes.size());

3195 unsigned dynamicIdx = 0;

3196

3198 if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {

3199 if (scalableSizes[ofrIdx]) {

3200 auto val = b.create<arith::ConstantIndexOp>(

3201 getLoc(), cast<IntegerAttr>(attr).getInt());

3204 sizes.push_back(

3205 b.create<arith::MulIOp>(getLoc(), val, vscale).getResult());

3206 } else {

3207 sizes.push_back(attr);

3208 }

3209 continue;

3210 }

3213 ++dynamicIdx;

3214 assert((dynamicSizes.empty() ^ params.empty()) &&

3215 "expected either dynamic sizes or parameters");

3216 if (!params.empty()) {

3217 sizes.push_back(b.getIndexAttr(params[index]));

3218 } else {

3219 sizes.push_back(dynamicSizes[index]->getResult(0));

3220 }

3221 }

3222 return sizes;

3223 });

3224 }

3225

3227 FailureOr<scf::SCFTilingResult> maybeTilingResult =

3228 tileUsingSCF(rewriter, tilingInterface, tilingOptions);

3229 if (failed(maybeTilingResult))

3231

3232 rewriter.replaceOp(op, maybeTilingResult->replacements);

3233

3234 tiled.append(maybeTilingResult->tiledOps);

3235 for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))

3236 loops[en2.index()].push_back(en2.value());

3237 }

3238

3239 transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);

3241 transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());

3242

3244 }
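TileUsingForOp::apply configures scf::SCFTilingOptions per target, calls scf::tileUsingSCF, and replaces the original op with the tiled computation. A minimal sketch of the same sequence with static sizes follows, assuming a payload implementing TilingInterface with at least two loops; tile32x8 is a hypothetical helper.

    // Minimal sketch, assuming `op` implements TilingInterface.
    #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"

    using namespace mlir;

    static FailureOr<scf::SCFTilingResult> tile32x8(RewriterBase &rewriter,
                                                    TilingInterface op) {
      scf::SCFTilingOptions options;
      SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(32),
                                         rewriter.getIndexAttr(8)};
      options.setTileSizes(sizes);
      FailureOr<scf::SCFTilingResult> result =
          scf::tileUsingSCF(rewriter, op, options);
      if (succeeded(result))
        // As the transform op does, replace the original op with the tiled values.
        rewriter.replaceOp(op, result->replacements);
      return result;
    }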

3245

3250 results.reserve(tileSizes.size());

3251 unsigned dynamicPos = 0;

3253 for (int64_t size : tileSizes) {

3254 if (size == ShapedType::kDynamic) {

3255 results.push_back(dynamic[dynamicPos++]);

3256 } else {

3257 results.push_back(builder.getIndexAttr(size));

3258 }

3259 }

3260 return results;

3261 }

3262

3263 void transform::TileUsingForOp::getEffects(

3267 producesHandle(getOperation()->getOpResults(), effects);

3269 }

3270

3271

3272

3273

3274

3275 void transform::TileUsingForallOp::build(OpBuilder &builder,

3279 ArrayAttr mapping) {

3280 return build(builder, result,

3281 target,

3282

3285 mapping);

3286 }

3287

3288 void transform::TileUsingForallOp::build(OpBuilder &builder,

3292 ArrayAttr mapping) {

3296

3297

3298

3302 build(builder, result,

3303 TypeRange{operationType, operationType},

3304 target,

3306 dynamicTileSizes,

3307 Value(),

3308 Value(),

3310 staticTileSizesAttr,

3311 mapping);

3312 }

3313

3314 void transform::TileUsingForallOp::build(OpBuilder &builder,

3318 ArrayAttr mapping) {

3319 return build(builder, result, target,

3322 }

3323

3324 void transform::TileUsingForallOp::build(OpBuilder &builder,

3328 ArrayAttr mapping) {

3332 staticNumThreads);

3333

3334

3335

3339 build(builder, result,

3340 TypeRange{operationType, operationType},

3341 target,

3342 dynamicNumThreads,

3344 Value(),

3345 Value(),

3346 staticNumThreadsAttr,

3348 mapping);

3349 }

3350

3351

3352

3359 AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);

3361 for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {

3363 rewriter, loc, normalizedUbExpr, {lb, ub, step});

3364 normalizedUbs.push_back(normalizedUb);

3365 }

3366 return normalizedUbs;

3367 }

3368

3369

3370

3379 AffineExpr denormExpr = s0 + d0 * s1;

3381

3382 for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {

3385 denormalizedIvs.push_back(

3387 }

3388 return denormalizedIvs;

3389 }

3390

3391

3392

3393

3394

3395

3396

3397

3399 scf::ForallOp loop) {

3403

3405 return loop;

3406 }

3407

3408 Location loc = loop.getLoc();

3415

3416 auto normalizedForallOp = rewriter.create<scf::ForallOp>(

3417 loc, normalizedLbs, normalizedUbs, normalizedSteps, loop.getOutputs(),

3419

3420 auto normalizedLoopIvs = normalizedForallOp.getInductionVars();

3422 Block *normalizedLoopBlock = normalizedForallOp.getBody();

3424

3426 denormalizeIndVar(rewriter, loc, normalizedLoopIvs, lbs, steps);

3427 argValues.append(normalizedForallOp.getRegionIterArgs().begin(),

3428 normalizedForallOp.getRegionIterArgs().end());

3429 Block *origLoopBlock = loop.getBody();

3430 rewriter.mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);

3431

3432 rewriter.replaceOp(loop, normalizedForallOp);

3433 return normalizedForallOp;

3434 }

3435

3438 TransformOpInterface transformOp, Operation *target,

3442

3443 auto tileableOp = dyn_cast<TilingInterface>(target);

3444 if (!tileableOp) {

3446 transformOp.emitSilenceableError()

3447 << "only TilingInterface ops are supported";

3448 diag.attachNote(target->getLoc()) << "target op";

3449 return diag;

3450 }

3454 if (!mixedNumThreads.empty()) {

3455 options.setNumThreads(mixedNumThreads);

3456 } else {

3457 options.setTileSizes(mixedTileSizes);

3458 }

3459 if (mapping) {

3460 options.setMapping(mapping.value().getValue());

3461 }

3462 FailureOr<scf::SCFTilingResult> maybeTilingResult =

3464

3465 if (failed(maybeTilingResult))

3466 return transformOp.emitDefaultSilenceableFailure(tileableOp);

3467

3468 rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);

3469

3470 tilingResult = *maybeTilingResult;

3471

3472 if (mixedNumThreads.empty()) {

3473 auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front());

3476 scf::ForallOp normalizedForallOp =

3478 tilingResult.loops.front() = normalizedForallOp;

3479 }

3480

3482 }
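tileToForallOpImpl (above) selects either num_threads or tile_sizes, attaches the optional mapping, and calls the SCF tiling entry point, then normalizes the generated scf.forall when tile sizes were used. A minimal sketch of the num_threads path follows, assuming scf::SCFTilingOptions exposes a forall loop type and a mapping setter as used above; tileToForall is a hypothetical helper.

    // Minimal sketch, assuming `op` implements TilingInterface and `mapping` is
    // either null or an ArrayAttr of device mapping attributes.
    #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"

    using namespace mlir;

    static FailureOr<scf::SCFTilingResult>
    tileToForall(RewriterBase &rewriter, TilingInterface op, ArrayAttr mapping) {
      scf::SCFTilingOptions options;
      // Request scf.forall loops instead of scf.for (assumed option name).
      options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
      SmallVector<OpFoldResult> numThreads = {rewriter.getIndexAttr(8),
                                              rewriter.getIndexAttr(4)};
      options.setNumThreads(numThreads);
      if (mapping)
        options.setMapping(mapping.getValue());
      return scf::tileUsingSCF(rewriter, op, options);
    }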

3483

3488 auto transformOp = cast(getOperation());

3489

3490

3493

3494

3497 getPackedNumThreads()

3499 state, transformOp, mixedNumThreads, getPackedNumThreads())

3501 state, transformOp, mixedNumThreads, getMixedNumThreads());

3503 return status;

3505 status = getPackedTileSizes()

3507 state, transformOp, mixedTileSizes, getPackedTileSizes())

3509 state, transformOp, mixedTileSizes, getMixedTileSizes());

3511 return status;

3512

3513 for (Operation *target : state.getPayloadOps(getTarget())) {

3516 rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,

3517 getMapping(), tilingResult);

3518 if (diag.succeeded())

3519 return diag;

3520 tileOps.push_back(tilingResult.loops.front());

3521 tiledOps.append(tilingResult.tiledOps);

3522 }

3523

3524 transformResults.set(cast<OpResult>(getForallOp()), tileOps);

3525 transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);

3526

3528 }

3529

3530 void transform::TileUsingForallOp::getEffects(

3537 producesHandle(getOperation()->getOpResults(), effects);

3539 }

3540

3543 return getMixedValues(getStaticNumThreads(), getNumThreads(), b);

3544 }

3545

3549 }

3550

3552 int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +

3553 static_cast<int>(getPackedNumThreads() != Value());

3554 if (numThreadsSpec > 1)

3555 return emitOpError(

3556 "num_threads and packed_num_threads are mutually exclusive");

3557 int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +

3558 static_cast<int>(getPackedTileSizes() != Value());

3559 if (tileSizesSpec > 1)

3560 return emitOpError(

3561 "tile_sizes and packed_tile_sizes are mutually exclusive");

3562 if (numThreadsSpec == 0 && tileSizesSpec == 0)

3563 return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "

3564 "must be specified");

3565 return success();

3566 }

3567

3568

3569

3570

3571

3572 void transform::VectorizeChildrenAndApplyPatternsOp::build(

3574 bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {

3576 if (vectorizePadding) {

3578 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(

3579 result.name),

3581 }

3582 if (vectorizeExtract) {

3584 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(

3585 result.name),

3587 }

3588 if (flatten1DDepthwiseConv) {

3590 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(

3591 result.name),

3593 }

3595 }

3596

3597 namespace {

3598

3599

3600 struct VectorizationPattern : public RewritePattern {

3601 explicit VectorizationPattern(MLIRContext *context,

3602 bool vectorizeExtract = false,

3603 bool flattenConv = false)

3604 : RewritePattern(MatchAnyOpTypeTag(), 1, context),

3605 vectorizeNDExtract(vectorizeExtract),

3606 flatten1DDepthwiseConv(flattenConv) {}

3607 LogicalResult matchAndRewrite(Operation *op,

3611 "Unsupported Op, cannot vectorize");

3612 return vectorize(rewriter, op, {},

3613 {}, vectorizeNDExtract,

3614 flatten1DDepthwiseConv);

3615 }

3616

3617 private:

3618

3619

3620 bool vectorizeNDExtract = false;

3621

3622

3623

3624 bool flatten1DDepthwiseConv = false;

3625 };

3626 }

3627

3629 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(

3634 auto diag = this->emitOpError("requires isolated-from-above targets");

3635 diag.attachNote(target->getLoc()) << "non-isolated target";

3637 }

3638

3641 patterns.add(ctx, getVectorizeNdExtract(),

3642 getFlatten_1dDepthwiseConv());

3643

3644 if (!getDisableTransferPermutationMapLoweringPatterns())

3646

3647 if (!getDisableMultiReductionToContractPatterns())

3649

3651

3654 2);

3655 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);

3656 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);

3658

3660

3661 if (getVectorizePadding()) {

3663

3664

3666 }

3668

3670 if (failed(

3673 return emitDefaultDefiniteFailure(target);

3674

3677 }

3678

3679

3680

3681

3682

3687 auto targets = state.getPayloadOps(getTarget());

3688 if (std::empty(targets))

3690 auto transformOp = cast(getOperation());

3693 state, transformOp, getMixedVectorSizes(), vectorSizes);

3695 return status;

3696

3697

3698 for (Operation *target : targets) {

3701 << "Unsupported Op, cannot vectorize";

3702 }

3703

3705 getScalableSizes(),

3706 getVectorizeNdExtract().value_or(false)))) {

3708 << "Attempted to vectorize, but failed";

3709 }

3710 }

3711

3713 }
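VectorizeOp::apply checks linalg::hasVectorizationImpl for each target and then calls linalg::vectorize with the requested vector sizes and scalable dimensions. A minimal sketch with hard-coded sizes follows; vectorizeWithSizes is a hypothetical helper, and the {8, [4]} shape is only illustrative.

    // Minimal sketch, assuming `op` is a linalg op compatible with sizes {8, [4]}.
    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"

    using namespace mlir;

    static LogicalResult vectorizeWithSizes(RewriterBase &rewriter, Operation *op) {
      if (!linalg::hasVectorizationImpl(op))
        return failure();
      SmallVector<int64_t> vectorSizes = {8, 4};
      // The second size is scalable, i.e. multiplied by vscale at runtime.
      SmallVector<bool> scalableDims = {false, true};
      return linalg::vectorize(rewriter, op, vectorSizes, scalableDims,
                               /*vectorizeNDExtract=*/false,
                               /*flatten1DDepthwiseConv=*/false);
    }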

3714

3715 void transform::VectorizeOp::getEffects(

3720 }

3721

3724 return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);

3725 }

3726

3728 if (getStaticVectorSizes().size() != getScalableSizes().size())

3729 return emitOpError("expected same number of vector sizes (")

3730 << getStaticVectorSizes().size() << ") and scalable sizes ("

3731 << getScalableSizes().size() << ")";

3732 return success();

3733 }

3734

3735

3736

3737

3738

3740 transform::HoistRedundantVectorTransfersOp::applyToOne(

3744

3745

3746

3750 }

3751

3752

3753

3754

3755

3757 transform::HoistRedundantVectorBroadcastsOp::applyToOne(

3765 }
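The two hoisting ops above are thin wrappers over linalg::hoistRedundantVectorTransfers and linalg::hoistRedundantVectorBroadcasts. A minimal sketch applying both to a function body follows; hoistVectorIdioms is a hypothetical helper.

    // Minimal sketch, assuming `funcOp` is the func.func payload of the handles.
    #include "mlir/Dialect/Func/IR/FuncOps.h"
    #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"

    using namespace mlir;

    static void hoistVectorIdioms(RewriterBase &rewriter, func::FuncOp funcOp) {
      // Moves loop-invariant vector.transfer_read/write pairs on buffers out of
      // their enclosing loops.
      linalg::hoistRedundantVectorTransfers(funcOp.getOperation());
      // Moves matching vector.extract/vector.broadcast pairs out of loops.
      linalg::hoistRedundantVectorBroadcasts(rewriter, funcOp.getOperation());
    }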

3766

3767

3768

3769

3770

3776 auto maybeTransformed =

3778 target)

3779 .Case([&](linalg::Conv2DNhwcHwcfOp op) {

3781 })

3782 .Case([&](linalg::Conv2DNhwcFhwcOp op) {

3784 })

3785 .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {

3787 })

3788 .Case([&](linalg::Conv2DNchwFchwOp op) {

3790 })

3793 });

3794 if (failed(maybeTransformed))

3795 return emitDefaultSilenceableFailure(target);

3796

3797 results.push_back(maybeTransformed->first);

3798

3799 results.push_back(maybeTransformed->second);

3801 }
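ConvertConv2DToImg2ColOp dispatches on the convolution flavor and calls linalg::rewriteInIm2Col, which returns both the im2col packing op and the matmul-like op that replaces the convolution. A minimal sketch for one supported flavor follows; convertToIm2Col is a hypothetical helper.

    // Minimal sketch, assuming `convOp` is a linalg.conv_2d_nhwc_hwcf payload op.
    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"

    using namespace mlir;

    static LogicalResult convertToIm2Col(RewriterBase &rewriter,
                                         linalg::Conv2DNhwcHwcfOp convOp) {
      rewriter.setInsertionPoint(convOp);
      // first = im2col packing op, second = matmul-like op replacing the conv.
      FailureOr<std::pair<Operation *, Operation *>> result =
          linalg::rewriteInIm2Col(rewriter, convOp);
      return success(succeeded(result));
    }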

3802

3803

3804

3805

3806

3814 << "only elementwise flattening is supported";

3815

3816

3817 if (target.getNumLoops() <= 1) {

3820 }

3821

3822

3824 std::iota(reassociation.begin(), reassociation.end(), 0);

3825 auto maybeFlattened =

3827 if (failed(maybeFlattened))

3829 << "attempted to flatten, but failed";

3830 results.push_back(maybeFlattened->collapsedOp);

3831 rewriter.replaceOp(target, maybeFlattened->results);

3833 }
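FlattenElementwiseLinalgOp collapses all iteration dimensions of an elementwise op into a single loop via linalg::collapseOpIterationDims, using one reassociation group that spans every dimension. A minimal sketch of the same call follows; flattenElementwise is a hypothetical helper.

    // Minimal sketch, assuming `op` is elementwise with at least two loops.
    #include <numeric>
    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
    #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"

    using namespace mlir;

    static FailureOr<linalg::CollapseResult>
    flattenElementwise(RewriterBase &rewriter, linalg::LinalgOp op) {
      // One reassociation group containing every iteration dimension.
      ReassociationIndices allDims(op.getNumLoops());
      std::iota(allDims.begin(), allDims.end(), 0);
      SmallVector<ReassociationIndices> groups = {allDims};
      return linalg::collapseOpIterationDims(op, groups, rewriter);
    }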

3834

3835

3836

3837

3838

3844 auto maybeTransformed =

3846 .Case([&](linalg::Conv2DNhwcFhwcOp op) {

3848 })

3849 .Case([&](linalg::Conv2DNhwcFhwcQOp op) {

3851 })

3854 });

3855 if (failed(maybeTransformed))

3856 return emitDefaultSilenceableFailure(target);

3857

3858 results.push_back(*maybeTransformed);

3860 }

3861

3862

3863

3864

3865

3871 bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;

3872 auto maybeTransformed =

3874 .Case([&](linalg::MatmulOp op) {

3876 })

3877 .Case([&](linalg::BatchMatmulOp op) {

3879 })

3880 .Default([&](Operation *op) { return failure(); });

3881 if (failed(maybeTransformed))

3883

3884 results.push_back(*maybeTransformed);

3886 }

3887

3888

3889

3890

3891 template

3895 static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,

3896 tensor::ParallelInsertSliceOp>() &&

3897 "wrong op type");

3898

3899 if (auto copySource =

3900 target.getSource().template getDefiningOp<linalg::CopyOp>()) {

3903 }

3904

3905

3906

3907 if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {

3909 target->template getParentOfType<scf::InParallelOp>());

3910 }

3911

3912 Value extracted = rewriter.create<tensor::ExtractSliceOp>(

3913 target.getLoc(), target.getDest(), target.getMixedOffsets(),

3914 target.getMixedSizes(), target.getMixedStrides());

3915 Value copied = rewriter

3916 .create<linalg::CopyOp>(target.getLoc(),

3917 target.getSource(), extracted)

3918 .getResult(0);

3919

3922 target, copied, target.getDest(), target.getMixedOffsets(),

3923 target.getMixedSizes(), target.getMixedStrides());

3924

3925 results.push_back(copied.getDefiningOp());

3927 }

3928

3933

3935 if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))

3936 return doit(rewriter, target, results, state);

3937 if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))

3938 return doit(rewriter, target, results, state);

3939

3941 emitSilenceableError()

3942 << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";

3943 diag.attachNote(targetOp->getLoc()) << "target op";

3944 return diag;

3945 }

3946

3947

3948

3949

3950

3955

3956 if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {

3958 emitSilenceableError()

3959 << "only linalg.copy and tensor.pad target ops are supported";

3960 diag.attachNote(target->getLoc()) << "target op";

3961 return diag;

3962 }

3963 assert(target->getNumResults() == 1 && "expected single result");

3964 auto resultShapedType = cast<ShapedType>(target->getResult(0).getType());

3965 if (!resultShapedType.hasStaticShape()) {

3967 emitSilenceableError()

3968 << "only statically sized ops of rank <= 3 are supported";

3969 diag.attachNote(target->getLoc()) << "target op";

3970 return diag;

3971 }

3972

3973

3974 int64_t desiredBitAlignment = getDesiredBitAlignment();

3975 int64_t eltBitwidth =

3976 resultShapedType.getElementType().getIntOrFloatBitWidth();

3977 if (desiredBitAlignment % eltBitwidth != 0) {

3978 desiredBitAlignment = eltBitwidth;

3979 }

3980

3983 getTotalNumThreads(),

3984 desiredBitAlignment,

3985 resultShapedType.getShape(),

3986 false,

3987

3988 resultShapedType.getElementType().getIntOrFloatBitWidth());

3989 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {

3991 emitSilenceableError()

3992 << "too few threads to map copy op to threads on the most minor "

3993 "dimension, given alignment and vector size constraints, try "

3994 "smaller tile size of mapping to more threads";

3995 diag.attachNote(target->getLoc()) << "target op";

3996 return diag;

3997 }

3998

3999

4003 rewriter,

4004 state,

4005 *this,

4006 target,

4007 getMixedValues(mapping.numThreads, {}, b),

4009 b.getArrayAttr(mapping.threadMapping),

4010 tilingResult);

4011 if (diag.succeeded())

4012 return diag;

4013

4015 for (auto op : tilingResult.tiledOps)

4018 }

4019

4020

4021

4022

4023

4029 FailureOr<Operation *> maybeTransformed = failure();

4031 .Case([&](linalg::Conv2DNhwcFhwcOp op) {

4032 maybeTransformed =

4034 return true;

4035 })

4036 .Default([&](Operation *op) { return false; });

4037

4038 if (!supported) {

4039 return emitSilenceableError()

4040 << "this operation is not supported to convert to Winograd Conv2D";

4041 }

4042

4043 if (failed(maybeTransformed)) {

4044 return emitSilenceableError() << "apply Winograd Conv2D failed";

4045 }

4046

4047 results.push_back(*maybeTransformed);

4049 }
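WinogradConv2DOp rewrites a supported convolution into the Winograd form via linalg::winogradConv2D, parameterized by the output tile size m and filter size r. A minimal sketch using the common F(4x4, 3x3) configuration follows; toWinograd is a hypothetical helper and the m/r values are only an example.

    // Minimal sketch, assuming `convOp` is a linalg.conv_2d_nhwc_fhwc payload op
    // with a shape compatible with F(4x4, 3x3).
    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"

    using namespace mlir;

    static FailureOr<Operation *> toWinograd(RewriterBase &rewriter,
                                             linalg::Conv2DNhwcFhwcOp convOp) {
      rewriter.setInsertionPoint(convOp);
      return linalg::winogradConv2D(rewriter, convOp, /*m=*/4, /*r=*/3);
    }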

4050

4056 FailureOr<Operation *> maybeTransformed = failure();

4057 bool supported =

4059 .Case([&](linalg::WinogradFilterTransformOp op) {

4061 return true;

4062 })

4063 .Case([&](linalg::WinogradInputTransformOp op) {

4065 return true;

4066 })

4067 .Case([&](linalg::WinogradOutputTransformOp op) {

4069 return true;

4070 })

4071 .Default([&](Operation *op) { return false; });

4072

4073 if (!supported) {

4075 emitSilenceableError()

4076 << "this operation is not supported to decompose into other operations";

4077 diag.attachNote(target->getLoc()) << "target op";

4078 return diag;

4079 }

4080

4081 if (failed(maybeTransformed)) {

4083 emitSilenceableError() << "decompose Winograd operations failed";

4084 diag.attachNote(target->getLoc()) << "target op";

4085 return diag;

4086 }

4087

4088 results.push_back(*maybeTransformed);

4090 }

4091

4092 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"

4093

4094 #define GET_OP_CLASSES

4095 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"

static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)

Maps the 2-dim vector shape to the two 16-bit tile sizes.

static MLIRContext * getContext(OpFoldResult val)

DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state)

static Operation * cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)

bool isValidPackingPermutation(RelayoutOpTy op, ArrayRef< int64_t > permutation, OuterOrInnerPerm outerOrInnerPerm=OuterOrInnerPerm::Outer)

Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...

static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(TransformState &state, TransformOpInterface &transformOp, ArrayRef< OpFoldResult > mixedResults, SmallVectorImpl< int64_t > &reified)

When possible, converts each OpFoldResult in mixedResult to an integer if the value can be statically...

static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs)

Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to e...

static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type tile_sizes, Type)

static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter, scf::ForallOp loop)

Given a scf.forall loop return a loop op with the loop bounds normalized.

static SmallVector< Value > denormalizeIndVar(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > steps)

When a loop is normalized, the uses of the induction variable within the loop need to replaced with o...

#define DOWNSCALE_NORMAL(a, b)

static FailureOr< LinalgOp > tryApply(Operation *operation, Args &&...args)

Attempts to apply the pattern specified as template argument to the given operation.

static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type)

static bool sameOrEquivalentIterArg(Value src, Value dst)

Given two operands coming from a loop iter arg, 'src' and 'dst', return true if the operand 'src' is ...

static Operation * replaceForAllWithNewSignature(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)

Add new operands to the forall op for users of the producerOp that are dominated by the containing sc...

static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser, Type &targetType, Type &tileSizesType, Type &chunkSizesType)

static SmallVector< Operation * > tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)

First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp,...

static ParseResult parseMultitileSizesTypes(OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType)

static SmallVector< OpFoldResult > normalizeUpperBounds(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > steps)

Given lbs, ubs and steps of loops, return (for each loop), the normalized upper bound.

static LogicalResult applyTilingToAll(RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn)

Apply a tiling transformation to all payload ops and store both the tiled operation as well as the cr...

static std::tuple< SmallVector< Operation * >, Operation * > tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)

Find the first "extract" user of producerOp and tile it right before its use.

static std::string diag(const llvm::Value &value)

static llvm::ManagedStatic< PassManagerOptions > options

static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)

static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)

Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...

Base type for affine expression.

A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.

virtual Builder & getBuilder() const =0

Return a builder which provides useful access to MLIRContext, global objects like types and attribute...

virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0

Parse a named dictionary into 'result' if it is present.

virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0

Emit a diagnostic at the specified location and return failure.

ParseResult parseInteger(IntT &result)

Parse an integer value from the stream.

virtual ParseResult parseColonType(Type &result)=0

Parse a colon followed by a type.

virtual SMLoc getCurrentLocation()=0

Get the location of the next token and store it into the argument.

virtual ParseResult parseType(Type &result)=0

Parse a type.

virtual ParseResult parseComma()=0

Parse a , token.

ParseResult parseKeyword(StringRef keyword)

Parse a given keyword.

Attributes are known-constant values of operations.

This class represents an argument of a Block.

Block represents an ordered list of Operations.

Operation * getParentOp()

Returns the closest surrounding operation that contains this block.

This class is a general helper class for creating context-global objects like types,...

IntegerAttr getIndexAttr(int64_t value)

IntegerAttr getIntegerAttr(Type type, int64_t value)

DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)

AffineExpr getAffineSymbolExpr(unsigned position)

IntegerAttr getI64IntegerAttr(int64_t value)

Ty getType(Args &&...args)

Get or construct an instance of the type Ty with provided arguments.

StringAttr getStringAttr(const Twine &bytes)

MLIRContext * getContext() const

ArrayAttr getArrayAttr(ArrayRef< Attribute > value)

ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)

ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)

The result of a transform IR operation application.

static DiagnosedSilenceableFailure success()

Constructs a DiagnosedSilenceableFailure in the success state.

bool isDefiniteFailure() const

Returns true if this is a definite failure.

static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)

Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...

bool succeeded() const

Returns true if this is a success.

static DiagnosedSilenceableFailure definiteFailure()

Constructs a DiagnosedSilenceableFailure in the failure state.

This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.

A class for computing basic dominance information.

bool dominates(Operation *a, Operation *b) const

Return true if operation A dominates operation B, i.e.

This class allows control over how the GreedyPatternRewriteDriver works.

This is a utility class for mapping one set of IR entities to another.

void map(Value from, Value to)

Inserts a new mapping for 'from' to 'to'.

IRValueT get() const

Return the current value being used by this operand.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

NamedAttribute represents a combination of a name and an Attribute value.

The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...

virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0

Resolve an operand to an SSA value, emitting an error on failure.

virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0

Parse a single operand if present.

virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0

Parse a single SSA value operand name along with a result number if allowResultNumber is true.

This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...

virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0

If the specified operation has attributes, print out an attribute dictionary with their values.

void printFunctionalType(Operation *op)

Print the complete type of an operation in functional form.

This class represents a saved insertion point.

bool isSet() const

Returns true if this insert point is set.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

Operation * clone(Operation &op, IRMapping &mapper)

Creates a deep copy of the specified operation, remapping any operands that use values outside of the...

void setListener(Listener *newListener)

Sets the listener of this builder to the one provided.

void setInsertionPointToStart(Block *block)

Sets the insertion point to the start of the specified block.

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

Listener * getListener() const

Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

void setInsertionPointAfter(Operation *op)

Sets the insertion point to the node after the specified operation, which will cause subsequent inser...

This class represents a single result from folding an operation.

This class represents an operand of an operation.

unsigned getOperandNumber()

Return which operand this is in the OpOperand list of the Operation.

This is a value defined by a result of an operation.

This class provides the API for ops that are known to be isolated from above.

StringRef getStringRef() const

Return the name of this operation. This always succeeds.

Operation is the basic unit of execution within MLIR.

Value getOperand(unsigned idx)

OpResult getOpResult(unsigned idx)

bool hasTrait()

Returns true if the operation was registered with a particular trait, e.g.

Attribute getAttr(StringAttr name)

Return the specified attribute if present, null otherwise.

void setOperand(unsigned idx, Value value)

bool hasAttr(StringAttr name)

Return true if the operation has an attribute with the provided name, false otherwise.

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)

Walk the operation by calling the callback for each nested operation (including this one),...

MLIRContext * getContext()

Return the context this operation is associated with.

unsigned getNumRegions()

Returns the number of regions held by this operation.

Location getLoc()

The source location the operation was defined or derived from.

unsigned getNumOperands()

InFlightDiagnostic emitError(const Twine &message={})

Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...

OperationName getName()

The name of an operation is the key identifier for it.

operand_type_range getOperandTypes()

result_type_range getResultTypes()

bool isAncestor(Operation *other)

Return true if this operation is an ancestor of the other operation.

user_range getUsers()

Returns a range of all users.

result_range getOpResults()

result_range getResults()

bool isProperAncestor(Operation *other)

Return true if this operation is a proper ancestor of the other operation.

InFlightDiagnostic emitOpError(const Twine &message={})

Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.

unsigned getNumResults()

Return the number of results held by this operation.

This class implements Optional functionality for ParseResult.

bool has_value() const

Returns true if we contain a valid ParseResult value.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

RewritePattern is the common base class for all DAG to DAG replacements.

This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure,...

virtual void eraseBlock(Block *block)

This method erases all operations in a block.

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements...

void replaceAllUsesWith(Value from, Value to)

Find uses of from and replace them with to.

void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)

Inline the operations of block 'source' into the end of block 'dest'.

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)

Find uses of from and replace them with to if the functor returns true.

void modifyOpInPlace(Operation *root, CallableT &&callable)

This method is a utility wrapper around an in-place modification of an operation.

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

This class provides an abstraction over the different types of ranges over Values.

Type front()

Return first type in the range.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

user_range getUsers() const

Location getLoc() const

Return the location of this value.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

State for analysis-enabled bufferization.

Operation * getOwner() const

Return the owner of this operand.

A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...

void assign(unsigned size, std::nullptr_t)

Sets the list of results to size null pointers.

void reserve(unsigned size)

Reserves space for size elements in the list.

size_t size() const

Returns the number of elements in the list.

void push_back(Operation *op)

Appends an element to the list.

A listener that updates a TransformState based on IR modifications.

Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...

void setValues(OpResult handle, Range &&values)

Indicates that the result of the transform IR op at the given position corresponds to the given range...

void setParams(OpResult value, ArrayRef< TransformState::Param > params)

Indicates that the result of the transform IR op at the given position corresponds to the given list ...

void set(OpResult value, Range &&ops)

Indicates that the result of the transform IR op at the given position corresponds to the given list ...

This is a special rewriter to be used in transform op implementations, providing additional helper fu...

LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)

Notify the transform dialect interpreter that the given op has been replaced with another op and that...

The state maintained across applications of various ops implementing the TransformOpInterface.

SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Variant of makeComposedFoldedAffineApply suitable for multi-result maps.

AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...

OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...

LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)

Analyze op and its nested ops.

void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)

Walk all of the regions, blocks, or operations nested under (and including) the given operation.

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)

Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.

LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)

Pad the iterator dimensions paddingDimensions of all opToPad operands to a static bounding box.

FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)

Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....

bool hasVectorizationImpl(Operation *)

Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...

FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)

Rewrite linalg.winograd_filter_transform.

std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)

Allocate the subview in the GPU workgroup memory.

FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)

Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...

Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)

Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....

FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)

Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.

FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)

Create a namedOp from the given GenericOp and replace the GenericOp.

FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)

Rewrite pack as empty + transpose + reshape + extract_slice.

void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)

Populates patterns with patterns that vectorize tensor.pad.

void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)

LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)

In case of GPU private memory there is no need to deallocate since the memory is freed when going out...

FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)

Rewrite linalg.winograd_output_transform.

std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)

Allocate the subview in the GPU private memory.

FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)

Rewrite tensor.from_elements to linalg.generic.

FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, int64_t m, int64_t r)

Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).

FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)

Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.

void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)

Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...

LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)

Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...

LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)

Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.

LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)

Emit a suitable vector form for an operation.

FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)

Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.

FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)

Pattern to replace.

LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)

Promote memref.subviews feeding linalg-on-buffers operations.

LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)

Normal copy to between src and dst.

bool isElementwise(LinalgOp op)

Check if a LinalgOp is an element-wise operation.

FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)

Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.

FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)

void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)

Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.

FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)

FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)

FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)

void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)

Populates patterns with patterns that fold operations like linalg.pack and linalg....

void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)

Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...

FailureOr< ForallReductionTilingResult > tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes={}, std::optional< ArrayAttr > mapping=std::nullopt)

Method to tile a reduction to parallel iterations computing partial reductions.

FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)

Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...

FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)

Implement packing of a single LinalgOp by packedSizes.

void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)

Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.

FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)

Promote the subViews into a new buffer allocated at the insertion point b.

std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn

Function signature to control reduction splitting.

LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)

In case of GPU group memory there is no need to deallocate.

FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)

Convert Linalg matmul ops to transposed variants.

FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)

Collapses dimensions of linalg.generic/linalg.copy operation.

void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)

Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...

FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)

Rewrite linalg.winograd_input_transform.

void populateDecomposePadPatterns(RewritePatternSet &patterns)

Populates patterns to decompose tensor.pad into e.g.

void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)

Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...

std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)

Split the given op into two parts along the given iteration space dimension at the specified splitPoi...

void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns)

Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...

FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)

Scaling-based implementation of the split reduction transformation.

FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)

Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...

FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)

Rewrite pack as pad + reshape + transpose.
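
A minimal sketch of invoking the rewrite on an existing linalg.pack op; how packOp is obtained is left out:

static LogicalResult lowerPackExample(RewriterBase &rewriter,
                                      linalg::PackOp packOp) {
  // Decompose the pack into pad + reshape + transpose ops.
  FailureOr<linalg::LowerPackResult> res = linalg::lowerPack(rewriter, packOp);
  return failed(res) ? failure() : success();
}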

SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)

Return the dimensions of the given memref value.

QueryRef parse(llvm::StringRef line, const QuerySession &qs)

FailureOr< scf::SCFTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSizes)

Method to tile a reduction and generate a parallel op within a serial loop.

FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)

Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
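
A hedged sketch combining this entry point with the SCFTilingOptions listed further below; the 4x16 tile sizes are arbitrary:

static LogicalResult tileExample(RewriterBase &rewriter, TilingInterface op) {
  scf::SCFTilingOptions options;
  options.setTileSizes({rewriter.getIndexAttr(4), rewriter.getIndexAttr(16)});
  FailureOr<scf::SCFTilingResult> tiled =
      scf::tileUsingSCF(rewriter, op, options);
  if (failed(tiled))
    return failure();
  // tiled->tiledOps holds the tiled operations and tiled->loops the
  // generated loop nest.
  return success();
}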

ForOp getForInductionVarOwner(Value val)

Returns the loop parent of an induction variable.

FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)

Method to lower an op that implements the TilingInterface to loops/scalars.

uint64_t getM(LevelType lt)

void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)

Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.

void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)

Appends patterns that are used to bubble up tensor.extract_slice ops above their producers.

LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)

This is a helper function for DestinationStyleOpInterface.

void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)

Appends patterns for folding tensor subset ops into vector transfer ops.

void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

DiagnosedSilenceableFailure tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, Operation *target, ArrayRef< OpFoldResult > mixedNumThreads, ArrayRef< OpFoldResult > mixedTileSizes, std::optional< ArrayAttr > mapping, scf::SCFTilingResult &tilingResult)

Implementation of tiling operations using scf.forall.

void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

Populates effects with the memory effects indicating the operation on the given handle value:

void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

Populates effects with the memory effects indicating the access to payload IR resource.
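
A hypothetical getEffects implementation wiring these helpers together; the op name and its getTargetMutable() accessor are invented for illustration:

void MyTransformOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // The target handle is consumed, all results are freshly produced handles,
  // and the payload IR is modified.
  transform::consumesHandle(getTargetMutable(), effects);
  transform::producesHandle(getOperation()->getOpResults(), effects);
  transform::modifiesPayload(effects);
}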

void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)

Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....

void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)

Populate the pattern set with the following patterns:

void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)

Patterns that remove redundant Vector Ops by re-ordering them with e.g.

void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)

Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...

static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)

Include the generated interface declarations.

bool matchPattern(Value value, const Pattern &pattern)

Entry point for matching a pattern over a Value.

OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)

Convert int64_t to integer attributes of index type and return them as OpFoldResult.

std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)

If ofr is a constant integer or an IntegerAttr, return the integer.
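
A small sketch of the usual "static if possible, dynamic otherwise" check on an OpFoldResult:

static bool isStaticallyKnown(OpFoldResult size) {
  if (std::optional<int64_t> cst = getConstantIntValue(size)) {
    // The size folds to the constant *cst.
    return true;
  }
  // Otherwise `size` carries a dynamic Value.
  return false;
}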

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)

Bind a list of AffineExpr references to DimExpr at positions: [0 .
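
A minimal example of binding dimension expressions and using them to build an affine map:

static AffineMap buildSumMap(MLIRContext *ctx) {
  AffineExpr d0, d1;
  bindDims(ctx, d0, d1);
  // (d0, d1) -> (d0 + d1)
  return AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 + d1);
}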

LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)

Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
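
A hedged sketch that populates one of the pattern sets listed above and applies it greedily to a function body; the choice of populateFoldIntoPackAndUnpackPatterns is arbitrary:

static LogicalResult foldIntoPackAndUnpack(func::FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  linalg::populateFoldIntoPackAndUnpackPatterns(patterns);
  // The function body is isolated from above, as the greedy driver requires.
  return applyPatternsGreedily(funcOp.getBody(), std::move(patterns));
}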

DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})

Emits a silenceable failure with the given message.

detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr

Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)

This parses a single MLIR attribute into the given MLIR context, if it is valid.

DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})

Emits a definite failure with the given message.
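
A hypothetical precondition check contrasting the two severities: a silenceable failure lets the surrounding transform recover, while a definite failure aborts the interpreter. The specific checks are made up:

static DiagnosedSilenceableFailure checkTarget(Operation *target) {
  if (!isa<linalg::LinalgOp>(target))
    return emitSilenceableFailure(target->getLoc())
           << "expected a structured op";
  if (target->getNumResults() == 0)
    return emitDefiniteFailure(target->getLoc()) << "op has no results";
  return DiagnosedSilenceableFailure::success();
}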

Type getElementTypeOrSelf(Type type)

Return the element type or return the type itself.

const FrozenRewritePatternSet & patterns

bool isZeroInteger(OpFoldResult v)

Return true if v is an IntegerAttr with value 0.

void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)

Bind a list of AffineExpr references to SymbolExpr at positions: [0 .

void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)

Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...

Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)

Converts an OpFoldResult to a Value.

Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

OpFoldResult getAsOpFoldResult(Value val)

Given a value, try to extract a constant Attribute.

SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)

Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
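
A small round-trip sketch between the flat OpFoldResult form and the split static/dynamic encoding used by ops with mixed operands:

static void roundTripSizes(OpBuilder &b, ArrayRef<int64_t> staticSizes,
                           ValueRange dynamicSizes) {
  // Merge the static values and dynamic SSA values into one list.
  SmallVector<OpFoldResult> mixed =
      getMixedValues(staticSizes, dynamicSizes, b.getContext());
  // Split them back apart.
  SmallVector<Value> dynamicVec;
  SmallVector<int64_t> staticVec;
  dispatchIndexOpFoldResults(mixed, dynamicVec, staticVec);
}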

detail::constant_op_matcher m_Constant()

Matches a constant foldable operation.

LogicalResult verify(Operation *op, bool verifyRecursively=true)

Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...

bool isPermutationVector(ArrayRef< int64_t > interchange)

Method to check if an interchange vector is a permutation.

bool isOneInteger(OpFoldResult v)

Return true if v is an IntegerAttr with value 1.

This is the representation of an operand reference.

This class represents a listener that may be used to hook into various actions within an OpBuilder.

This represents an operation in an abstracted form, suitable for use with the builder APIs.

SmallVector< Value, 4 > operands

void addOperands(ValueRange newOperands)

void addAttribute(StringRef name, Attribute attr)

Add an attribute with the specified name.

void addTypes(ArrayRef< Type > newTypes)

Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...

A listener that forwards all notifications to another listener.

ForwardingListener(OpBuilder::Listener *listener)

Container for result values of tiling.

SmallVector< Value > tiledValues

Options for analysis-enabled bufferization.

@ MaterializeInDestination

Transformation to drop unit-extent dimensions from linalg.generic operations.

Vectorization pattern for memref::CopyOp.

Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...

Match and rewrite for the pattern:

Match and rewrite for the pattern:

@ BufferizationMaterializeInDestination

Options used to control tile + fuse.

SCFTilingOptions tilingOptions

The tiling options used to control the tiling of the consumer.

std::optional< FrozenRewritePatternSet > cleanupPatterns

An optional set of rewrite patterns to apply to the results of tiling before fusion.

Options used to control tiling.

SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)

SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)

SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)

Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...

SmallVector< int64_t > interchangeVector

The interchange vector to reorder the tiled loops.

Transformation information returned after tiling.

SmallVector< Operation * > tiledOps

Tiled operations that are generated during tiling.

SmallVector< LoopLikeOpInterface > loops

The scf.for operations that iterate over the tiles.