MLIR: lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

10

12

43#include "llvm/ADT/STLExtras.h"

44#include "llvm/ADT/ScopeExit.h"

45#include "llvm/ADT/SmallPtrSet.h"

46#include "llvm/ADT/TypeSwitch.h"

47#include "llvm/Support/DebugLog.h"

48#include "llvm/Support/LogicalResult.h"

49#include <type_traits>

50

51using namespace mlir;

54

55#define DEBUG_TYPE "linalg-transforms"

56

57

58

59

60

61

62template <typename PatternTy, typename... Args>

63static FailureOr tryApply(Operation *operation, Args &&...args) {

64

65 using OpTy = typename llvm::function_traits<

66 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;

67 auto op = dyn_cast(operation);

68 if (!op)

69 return failure();

70

71

72 PatternTy pattern(operation->getContext(), std::forward(args)...);

73

74

77 auto result = pattern.returningMatchAndRewrite(op, rewriter);

79 return failure();

80 return cast(result->getOperation());

81}

82

83

84

85

90 if (auto attr = dyn_cast(ofr)) {

91 if (!isa(attr))

92 return transformOp.emitDefiniteFailure() << "expected IntegerAttr";

93 result.push_back(ofr);

94 continue;

95 }

96

97 Value transformValue = cast(ofr);

98 if (isa(transformValue.getType())) {

100 if (params.size() != 1)

101 return transformOp.emitDefiniteFailure()

102 << "requires exactly one parameter associated";

103 result.push_back(params[0]);

104 continue;

105 }

106

107 auto payloadOps = state.getPayloadOps(transformValue);

108 if (!llvm::hasSingleElement(payloadOps)) {

110 transformOp.emitSilenceableError()

111 << "handle must be mapped to exactly one payload op";

112 diag.attachNote(transformValue.getLoc())

113 << "mapped to " << llvm::range_size(payloadOps) << " payload ops";

115 }

116

117 Operation *op = *payloadOps.begin();

120 transformOp.emitSilenceableError()

121 << "payload op must have exactly 1 index result";

125 }

127 }

128

130}

131

132

133

134

135

136

137

141 if (isa(packedHandle.getType())) {

143 for (auto param : params) {

144 if (!isa(param))

145 return transformOp.emitDefiniteFailure()

146 << "expected the parameter to be associated with an integer "

147 "attribute";

148 result.push_back(param);

149 }

151 }

152

154 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {

156 transformOp.emitSilenceableError()

157 << "payload op must have exactly 1 index result";

158 diag.attachNote(op->getLoc())

159 << "has " << op->getNumResults() << " results";

161 }

162 result.push_back(op->getResult(0));

163 }

164

166}

167

168

169

170

171

173 TransformState &state, TransformOpInterface &transformOp,

175 for (OpFoldResult paramOrHandle : mixedResults) {

176 if (auto attr = dyn_cast(paramOrHandle)) {

177 reified.push_back(cast(attr).getInt());

178 continue;

179 }

180 if (isa(cast(paramOrHandle).getType())) {

182 if (params.size() != 1)

183 return transformOp.emitSilenceableError() << "expected a single param";

184 reified.push_back(

185 cast(params.front()).getValue().getSExtValue());

186 continue;

187 }

188

189 Value handle = cast(paramOrHandle);

190 if (!isa(handle.getType()))

191 return transformOp.emitSilenceableError() << "unexpected value handle";

193 if (!llvm::hasSingleElement(payload))

194 return transformOp.emitSilenceableError()

195 << "requires param or handle that is mapped to 1 payload op";

196

197 Operation *paramOrHandlePayloadOp = *payload.begin();

198 if (paramOrHandlePayloadOp->getNumResults() != 1 ||

200 return transformOp.emitSilenceableError()

201 << "requires param or handle to be result of op with 1 index "

202 "result";

203 }

204

205 IntegerAttr attr;

207 return transformOp.emitSilenceableError()

208 << "requires param or handle to be the result of a constant like "

209 "op";

210

211 reified.push_back(attr.getInt());

212 }

214}

215

216

217

218

219

220void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(

223}

224

225void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(

228}

229

230void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(

233}

234

235void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(

239}

240

241void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(

244 options.rankReductionStrategy =

247}

248

249void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(

252}

253

254void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(

257}

258

259void transform::ApplyPadVectorizationPatternsOp::populatePatterns(

262}

263

264void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(

267}

268

269void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(

272}

273

274

275

276

277

278namespace {

280public:

282

283 SmallVector<Operation *> getNewOps() const {

284 return SmallVector<Operation *>(newOps.begin(), newOps.end());

285 }

286

287private:

288 void notifyOperationInserted(Operation *op,

289 OpBuilder::InsertPoint previous) override {

290 ForwardingListener::notifyOperationInserted(op, previous);

291

292 if (previous.isSet())

293 return;

294 auto inserted = newOps.insert(op);

296 assert(inserted.second && "expected newly created op");

297 }

298

299 void notifyOperationErased(Operation *op) override {

300 ForwardingListener::notifyOperationErased(op);

301 op->walk([&](Operation *op) { newOps.erase(op); });

302 }

303

305};

306}

307

311

313 auto resetListener =

314 llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });

315 NewOpsListener newOpsListener(previousListener);

317

319 if (getMemcpyOp() == "bufferization.materialize_in_destination") {

320 options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::

321 MaterializeInDestination;

322 } else if (getMemcpyOp() == "memref.copy") {

325 } else if (getMemcpyOp() == "linalg.copy") {

328 } else {

329 llvm_unreachable("invalid memcpy op");

330 }

331 if (getAllocOp() == "memref.alloc") {

334 } else if (getAllocOp() == "memref.alloca") {

337 } else {

338 llvm_unreachable("invalid alloc op");

339 }

340 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();

341 options.emitDealloc = getEmitDealloc();

342

343

345 getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();

350 if (!buffer) {

352 << "failed to bufferize operation";

353 diag.attachNote(op->getLoc()) << "target payload op";

355 }

356 allocatedBuffers.push_back(buffer);

357 }

358

359

360 results.setValues(cast(getAllocatedBuffer()), allocatedBuffers);

361 results.set(cast(getNewOps()), newOpsListener.getNewOps());

363}

364

365void transform::BufferizeToAllocationOp::getEffects(

367 if (getBufferizeDestinationOnly()) {

368

369

371 } else {

373 }

374 producesHandle(getOperation()->getOpResults(), effects);

376}

377

378LogicalResult transform::BufferizeToAllocationOp::verify() {

379 if (getMemcpyOp() != "bufferization.materialize_in_destination" &&

380 getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")

381 return emitOpError() << "unsupported memcpy op";

382 if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")

383 return emitOpError() << "unsupported alloc op";

385}

386

387

388

389

390

391

392

393

395 auto linalgOp = dyn_castlinalg::LinalgOp(operand.getOwner());

396

397

398 if (!linalgOp)

399 return true;

400

401

402 Value blockArgument = linalgOp.getMatchingBlockArgument(&operand);

403 return !blockArgument.use_empty();

404}

405

406

408

409

410 if (!isa<TensorType, FloatType, IntegerType>(value.getType()))

411 return true;

412 return llvm::any_of(value.getUses(),

414}

415

422 auto type = dyn_cast(tensor.getType());

423 if (!type) {

424 return emitSilenceableError() << "non-tensor type: " << tensor;

425 }

426

428 if (definingOp)

430 else

432

433

435

438 for (auto [pos, dim] : llvm::enumerate(type.getShape())) {

439 if (!ShapedType::isDynamic(dim))

440 continue;

443 auto dimOp =

444 tensor::DimOp::create(rewriter, tensor.getLoc(), tensor, cst);

445 preservedOps.insert(dimOp);

446 dynamicDims.push_back(dimOp);

447 }

448 auto allocation = bufferization::AllocTensorOp::create(

449 rewriter, tensor.getLoc(), type, dynamicDims);

450

451 if (getMemorySpaceAttr())

452 allocation.setMemorySpaceAttr(getMemorySpaceAttr());

453 Value allocated = allocation;

454

455

456

457 if (needsMaterialization) {

458 auto copy = bufferization::MaterializeInDestinationOp::create(

459 rewriter, tensor.getLoc(), tensor, allocated);

460 preservedOps.insert(copy);

461 promoted.push_back(copy.getResult());

462 } else {

463 promoted.push_back(allocated);

464 }

466 }

467 results.setValues(cast(getPromoted()), promoted);

469}

470

471void transform::PromoteTensorOp::getEffects(

476}

477

478

479

480

481

487#define DOWNSCALE(trans) \

488 { \

489 FailureOr res = tryApply(target); \

490 if (succeeded(res)) { \

491 results.push_back(*res); \

492 return DiagnosedSilenceableFailure::success(); \

493 } \

494 }

495

496#define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>

497#define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))

498

504 DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)

506 DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)

510#undef DOWNSCALE_NORMAL

511#undef DOWNSCALE_CALL

512#undef DOWNSCALE

513 return emitDefaultSilenceableFailure(target);

514}

515

516

517

518

519

520

521

522

527 auto decomposableOp = dyn_cast(target);

528 if (!decomposableOp) {

530 "payload is not a decomposable op"));

531 return emitDefaultSilenceableFailure(target);

532 }

533

534 FailureOr<SmallVector> maybeNewResults =

535 decomposableOp.decomposeOperation(rewriter);

536 if (failed(maybeNewResults))

537 return emitDefaultSilenceableFailure(target);

538

539 rewriter.replaceOp(decomposableOp, *maybeNewResults);

540 for (Value val : *maybeNewResults) {

541 Operation *definition = val.getDefiningOp();

542 if (definition)

544 }

546}

547

548

549

550

551

552void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(

556}

557

559transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(

563 options.allowReturnAllocsFromLoops = true;

564

569 << "failed to analyze op";

571 rewriter, target, state)))

573 << "failed to eliminate LinalgOp anchored tensor.empty ops";

574 }

576}

577

578

579

580

581

586 bool applyCleanup, bool useForall) {

587 return build(

588 builder, result, loopTypes,

590

592

594 applyCleanup, useForall);

595}

596

600 bool applyCleanup, bool useForall) {

601 return build(

604

606

608 applyCleanup, useForall);

609}

610

615 bool applyCleanup, bool useForall) {

616

617

619 build(builder, result, loopTypes, target, mixedTileSizes,

620 mixedTileInterchange, applyCleanup, useForall);

621}

622

627 bool applyCleanup, bool useForall) {

634 staticTileInterchange);

635

636

637

639 auto staticTileInterchangeAttr =

641 unsigned numExpectedLoops =

642 useForall ? 1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0);

644 resultTypes.reserve(numExpectedLoops);

645 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&

646 "expected one loop type or as many as loops");

647 if (loopTypes.size() == 1)

648 resultTypes.append(numExpectedLoops, loopTypes[0]);

649 else

650 llvm::append_range(resultTypes, loopTypes);

651 build(builder, result, target.getType(),

652 resultTypes,

654 dynamicTileSizes,

655 dynamicTileInterchange,

656 staticTileSizesAttr,

657 staticTileInterchangeAttr,

658 applyCleanup,

659 useForall);

660}

661

662

663

664template

668 function_ref<FailureOrscf::SCFTileAndFuseResult(TilingInterface)>

669 applyFn) {

672

674 auto tilingInterfaceOp = dyn_cast(target);

675 if (!tilingInterfaceOp)

676 return transformOp->emitError("only TilingInterface ops are supported");

677

679 FailureOrscf::SCFTileAndFuseResult tiledResults =

680 applyFn(tilingInterfaceOp);

681 if (failed(tiledResults))

682 return failure();

683

684

686 llvm::append_range(opsToReplace, tiledResults->fusedProducers);

687 for (Operation *toReplace : opsToReplace) {

688 for (OpResult res : toReplace->getResults())

689 if (auto replacement = tiledResults->replacements.lookup(res))

691 if (toReplace->use_empty()) {

692 rewriter.eraseOp(toReplace);

693 }

694 }

695

696

697 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());

698 assert(tiledResults->loops.size() == numLoops &&

699 "Mismatched number of loops, tile and fuse transform should have "

700 "failed");

701 for (unsigned int i = 0; i < numLoops; ++i)

702 loopOps[i].push_back(tiledResults->loops[i]);

703 }

704

705 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);

706 for (unsigned int i = 0; i < numLoops; ++i)

707 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);

708

710}

711

716 auto transformOp = cast(getOperation());

717

720 state, transformOp, getMixedTileSizes(), tileSizes);

722 return status;

725 state, transformOp, getMixedTileInterchange(), tileInterchange);

727 return status;

728

729 scf::SCFTilingOptions tilingOptions;

730 tilingOptions.interchangeVector = tileInterchange;

731 bool useForall = getUseForall();

732 tilingOptions.setLoopType(useForall

733 ? scf::SCFTilingOptions::LoopType::ForallOp

734 : scf::SCFTilingOptions::LoopType::ForOp);

737 tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);

738 scf::SCFTileAndFuseOptions tileAndFuseOptions;

739 tileAndFuseOptions.tilingOptions = tilingOptions;

740

741 if (getApplyCleanup()) {

744 tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);

747 tileAndFuseOptions.cleanupPatterns = std::move(patterns);

748 }

749

750 size_t numLoops =

751 useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0);

753 rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops,

754 transformResults,

755 [&](TilingInterface tilingInterfaceOp)

756 -> FailureOrscf::SCFTileAndFuseResult {

757 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,

758 tileAndFuseOptions);

759 });

762}

763

764LogicalResult transform::FuseOp::verify() {

765 auto iterspace_rank = getStaticTileSizes().size();

767 if (permutation.size() > iterspace_rank)

769 << "interchange length exceeds iteration space dimensions ("

770 << iterspace_rank << "), found " << getTileInterchange();

772 for (int64_t v : permutation) {

773 if (!ShapedType::isDynamic(v)) {

774 if (v < 0 || v >= static_cast<int64_t>(iterspace_rank))

775 return emitOpError() << "expects interchange values to be in range [0, "

776 << iterspace_rank << "), found: " << v;

777 if (seen[v])

778 return emitOpError() << "found duplicate interchange value: " << v;

779 seen[v] = true;

780 }

781 }

782

784 size_t numExpectedLoops =

785 getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);

786 if (numExpectedLoops != getNumResults() - 1)

787 return emitOpError() << "expects " << numExpectedLoops << " loop results";

788

790}

791

794}

795

797 return getMixedValues(getStaticTileInterchange(), getTileInterchange(),

799}

800

801void transform::FuseOp::getEffects(

806 producesHandle(getOperation()->getOpResults(), effects);

808}

809

810

811

812

813

814void transform::FuseIntoContainingOp::build(OpBuilder &builder,

816 Value producerOp,

817 Value containingOp) {

818 result.addOperands({producerOp, containingOp});

819 auto resultType = transform::AnyOpType::get(builder.getContext());

820 result.addTypes({resultType, resultType});

821}

822

823

824

830

831

835 if (!containingOp->isAncestor(user) &&

836 (domInfo.dominates(containingOp, user))) {

837 dominatedUsers.insert(user);

838 }

839 }

840 if (dominatedUsers.empty())

841 return nullptr;

842

843

844 auto forallOp = castscf::ForallOp(containingOp);

847

848

849 Location loc = forallOp.getLoc();

850 auto genericOp = dyn_castlinalg::GenericOp(producerOp);

851 if (!genericOp)

852 return nullptr;

855 newOuts.push_back(outputs[resultNumber]);

856

857

858 auto newforallOp = scf::ForallOp::create(

859 rewriter, loc, forallOp.getMixedLowerBound(),

860 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,

861 forallOp.getMapping());

862 rewriter.eraseBlock(newforallOp.getBody());

863 newforallOp.getRegion().takeBody(forallOp.getRegion());

864

865

866

867

868 newforallOp.getBody()->addArgument(newOuts.back().getType(),

869 newOuts.back().getLoc());

870 auto bbArgs = newforallOp.getBody()->getArguments();

873 Operation *op = use.getOwner();

874 return newforallOp->isProperAncestor(op);

875 });

876

877

878 scf::InParallelOp terminatorOp = newforallOp.getTerminator();

880 terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));

881 Operation *firstYieldOp = yieldingOps.front();

884 Value dst = newforallOp.getRegionIterArgs().back();

886 tensor::ParallelInsertSliceOp::create(rewriter, firstYieldOp->getLoc(), src,

887 dst, offsets, sizes, strides);

888

889 for (auto result : llvm::enumerate(forallOp.getResults())) {

891 newforallOp->getResult(result.index()));

892 }

894 newforallOp->getResults().back(),

896 Operation *user = use.getOwner();

897 return dominatedUsers.contains(user);

898 });

899 return newforallOp;

900}

901

902

903

904

905

906

908

909

911 destWorklist.push_back(dst);

912

913 while (!destWorklist.empty()) {

914 Value currentDst = destWorklist.pop_back_val();

915

916

917

918 if (src == currentDst)

919 return true;

920

921

922

923 auto bbArg = dyn_cast(currentDst);

924 if (!bbArg)

925 continue;

926

927 Block *parentBlock = bbArg.getOwner();

928 assert(parentBlock && "unlinked block argument");

929

931 assert(parentOp && "expected block argument with parent operation");

932

933

934 auto parentLoop = dyn_cast(parentOp);

935 if (!parentLoop)

936 continue;

937

938 for (auto innerIterArg : parentLoop.getRegionIterArgs()) {

939

940 OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);

941 Value loopBlockArgument =

943 destWorklist.push_back(loopBlockArgument);

944 }

945 }

946

947 return false;

948}

949

950

951

952

953

954

955

956static std::tuple<SmallVector<Operation *>, Operation *>

959 LDBG() << "Try to fuse a direct extract use";

960 auto tileableProducer = dyn_cast(producerOp);

961 if (!tileableProducer) {

962 diag.attachNote(producerOp->getLoc())

963 << "producer is not a TileableInterface: " << *producerOp;

964 return {};

965 }

966

967

968

969

970 auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {

971 auto sliceOp = dyn_casttensor::ExtractSliceOp(user);

972 return sliceOp && containingOp->isProperAncestor(sliceOp);

973 });

974

975

976 if (it == tileableProducer->getUsers().end()) {

977 diag.attachNote(tileableProducer->getLoc())

978 << "could not find fusion opportunity for: " << *tileableProducer;

979 return {};

980 }

981 auto sliceOpToTile = casttensor::ExtractSliceOp(*it);

982

983

986

987

988

989

990

991

992

993

994 if (LoopLikeOpInterface containerLoop =

995 dyn_cast(sliceOpToTile->getParentOp())) {

998

999

1000

1001 auto dpsInterface = dyn_cast(clone);

1002 if (!dpsInterface)

1003 return;

1004

1005 for (OpOperand &initOperandPtr : dpsInterface.getDpsInitsMutable()) {

1006 Value producerOperand =

1007 clone->getOperand(initOperandPtr.getOperandNumber());

1009 containerLoop.getRegionIterArgs()) {

1010 OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);

1011 Value consumerOperand =

1013

1015 initOperandPtr.set(containerIterArg);

1016 }

1017 }

1018 }

1019 });

1020

1021 tileableProducer = dyn_cast(clone);

1022 }

1023

1024

1026 cast(sliceOpToTile.getSource()).getResultNumber();

1027 LDBG() << "resultNumber: " << resultNumber;

1028

1031

1032 FailureOr tileAndFuseResult =

1033 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,

1034 sizes);

1035

1036 if (failed(tileAndFuseResult)) {

1037 diag.attachNote(tileableProducer->getLoc())

1038 << "failed to tile producer op: " << *tileableProducer;

1039 return {};

1040 }

1041

1042#ifndef NDEBUG

1043 for (auto *tiledOp : tileAndFuseResult->tiledOps) {

1044 LDBG() << "tiledProducer: " << *tiledOp;

1045 }

1046#endif

1047

1048

1049 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(

1050 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],

1051 cast(sliceOpToTile->getResult(0).getType()).getShape());

1052 if (failed(maybeRankReduced)) {

1053 diag.attachNote(producerOp->getLoc())

1054 << "shape types don't match (missing canonicalization?):\nTiledOp: "

1055 << tileAndFuseResult->tiledValues[0]

1056 << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';

1057 return {};

1058 }

1059 rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);

1060

1061

1063 rewriter, diag, producerOp, containingOp, *tileAndFuseResult,

1064 resultNumber, offsets, sizes);

1065

1066

1067 if (isa(containingOp))

1068 rewriter.eraseOp(tileableProducer);

1069

1070 return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);

1071}

1072

1073

1074

1075

1076

1077

1078

1083 LDBG() << "Try to fuse an extract use through block argument";

1084

1085 auto tileableProducer = dyn_cast(producerOp);

1086 if (!tileableProducer) {

1087 diag.attachNote(producerOp->getLoc())

1088 << "producer is not a TileableInterface: " << *producerOp;

1089 return {};

1090 }

1091

1092

1093 scf::ForallOp forallOp;

1094 auto itProducerUses =

1095 llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {

1096 forallOp = dyn_castscf::ForallOp(use.getOwner());

1097 return forallOp;

1098 });

1099

1100 if (!forallOp || forallOp != containingOp) {

1101 diag.attachNote(tileableProducer->getLoc())

1102 << "could not find a use by the containing op: " << *tileableProducer;

1103 return {};

1104 }

1105

1106

1107

1108

1109

1110 OpOperand *pUse = &(*itProducerUses);

1111 BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);

1112

1113

1114

1115

1116 auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {

1117 auto sliceOp = dyn_casttensor::ExtractSliceOp(user);

1118 return sliceOp && containingOp->isProperAncestor(sliceOp);

1119 });

1120

1121

1122 if (itBBArgUsers == bbArg.getUsers().end()) {

1123 diag.attachNote(containingOp->getLoc())

1124 << "could not find fusion opportunity for bbArg: " << bbArg;

1125 return {};

1126 }

1127 auto sliceOpToTile = casttensor::ExtractSliceOp(*itBBArgUsers);

1128

1129

1132

1133

1134

1135 int64_t resultNumber = cast(pUse->get()).getResultNumber();

1136 LDBG() << "resultNumber: " << resultNumber;

1137

1138

1141 rewriter, tileableProducer->getLoc(), tileableProducer,

1142 destinationTensors))) {

1143 diag.attachNote(tileableProducer->getLoc())

1144 << "failed to get destination tensors for: " << *tileableProducer;

1145 return {};

1146 }

1147

1149 bvm.map(destinationTensors[resultNumber], bbArg);

1150 auto tileableProducerClone =

1151 cast(rewriter.clone(*tileableProducer, bvm));

1152 auto scopeGuard =

1153 llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });

1154

1155

1156 FailureOr tileAndFuseResult =

1157 tileableProducerClone.generateResultTileValue(

1158 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),

1159 sliceOpToTile.getMixedSizes());

1160 if (failed(tileAndFuseResult)) {

1161 diag.attachNote(tileableProducer->getLoc())

1162 << "failed to tile producer op: " << *tileableProducer;

1163 return {};

1164 }

1165

1166

1167 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(

1168 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],

1169 cast(sliceOpToTile->getResult(0).getType()).getShape());

1170 assert(succeeded(maybeRankReduced) && "unexpected shape");

1171 rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);

1172

1173

1176 destinationTensors.front());

1177 });

1178

1179 return tileAndFuseResult->tiledOps;

1180}

1181

1185 LDBG() << "Try to fuse an use by cloning";

1186

1187

1192 uses.push_back(&use);

1193 continue;

1194 }

1195

1196

1197 if (containingOp == use.getOwner()) {

1198 diag.attachNote(producerOp->getLoc())

1199 << "producer op use by containing op cannot be fused by cloning";

1200 return nullptr;

1201 }

1202 }

1203 }

1204

1205

1206 if (uses.empty()) {

1207 diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";

1208 return nullptr;

1209 }

1210

1211

1214

1215

1216 assert(!isatensor::ParallelInsertSliceOp(use->getOwner()) &&

1217 "Parallel insert slice is not a valid clone destination");

1218 unsigned resultNumber = cast(use->get()).getResultNumber();

1219 LDBG() << "resultNumber: " << resultNumber;

1220

1223 fusedOp = rewriter.clone(*producerOp);

1225 use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });

1226

1227 return fusedOp;

1228}

1229

1230bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {

1231

1232 return true;

1233}

1234

1240 auto producerOps = state.getPayloadOps(getProducerOp());

1241 auto containingOps = state.getPayloadOps(getContainingOp());

1242 if (!llvm::hasSingleElement(containingOps)) {

1244 << "requires exactly one containing_op handle (got "

1245 << llvm::range_size(containingOps) << ")";

1246 }

1247 Operation *containingOp = *containingOps.begin();

1248

1249

1250 if (std::empty(producerOps)) {

1252 results.set(cast(getNewContainingOp()), {containingOp});

1254 }

1255

1256

1257

1259 auto getNextProducer = [&]() -> FailureOr<Operation *> {

1260 for (const auto &it : enumerate(remainingProducers)) {

1261 Operation *producerOp = it.value();

1262

1263 int64_t numUsesInContainingOp =

1265 return containingOp->isAncestor(op);

1266 });

1267

1268

1269

1270 if (numUsesInContainingOp > 0) {

1271 if (numUsesInContainingOp == 1)

1272 remainingProducers.erase(remainingProducers.begin() + it.index());

1273 return producerOp;

1274 }

1275 }

1276 return failure();

1277 };

1278

1279 while (!remainingProducers.empty()) {

1280 auto nextProducer = getNextProducer();

1281 if (failed(nextProducer)) {

1283 << "could not find next producer to fuse into container";

1284 diag.attachNote(containingOp->getLoc()) << "containing op";

1285 return diag;

1286 }

1287

1288 Operation *producerOp = *nextProducer;

1289

1290

1292 diag << "could not fuse " << *producerOp << " into " << *containingOp;

1293

1294

1295

1296

1297

1298

1299 auto [tiledOps, newContainingOp] =

1301 if (!tiledOps.empty()) {

1302 LDBG() << "\nFused a direct extract use\n" << *containingOp;

1303 fusedOps.append(tiledOps);

1304 if (newContainingOp) {

1305

1306

1307

1308

1309

1310

1311

1312 LogicalResult replacementStatus =

1314 newContainingOp);

1315 (void)replacementStatus;

1316 assert(succeeded(replacementStatus) &&

1317 "unable to update transform state mapping");

1318 rewriter.eraseOp(containingOp);

1319 containingOp = newContainingOp;

1320 }

1321 continue;

1322 }

1323

1326 rewriter, diag, producerOp, containingOp);

1327 if (!tiledContainingOpOperand.empty()) {

1328 LDBG() << "\nFused an extract use through block argument\n"

1329 << *containingOp;

1330 fusedOps.append(tiledContainingOpOperand);

1331 continue;

1332 }

1333

1336 if (cloned) {

1337 LDBG() << "\nFused an use by cloning\n" << *containingOp;

1338 fusedOps.push_back(cloned);

1339 continue;

1340 }

1342 }

1343

1344 results.set(cast(getFusedOp()), fusedOps);

1345 results.set(cast(getNewContainingOp()), {containingOp});

1347}

1348

1349void transform::FuseIntoContainingOp::getEffects(

1353 producesHandle(getOperation()->getOpResults(), effects);

1355}

1356

1357

1358

1359

1360

1366

1367 if (isa(target)) {

1370 }

1373 if (succeeded(generic)) {

1374 results.push_back(generic->getOperation());

1376 }

1377 return emitDefaultSilenceableFailure(target);

1378}

1379

1380

1381

1382

1383

1389

1390 if (!isa(target)) {

1393 }

1395 FailureOr named =

1397 if (succeeded(named)) {

1398 results.push_back(named->getOperation());

1400 }

1401 return emitDefaultSilenceableFailure(target);

1402}

1403

1404

1405

1406

1407

1414

1415 if (interchangeVector.empty()) {

1418 }

1419

1420 unsigned numLoops = cast(target.getOperation()).getNumLoops();

1421 if (interchangeVector.size() != numLoops) {

1422 return emitSilenceableError()

1423 << getIteratorInterchangeAttrName() << " has length ("

1424 << interchangeVector.size()

1425 << ") different from the number of loops in the target operation ("

1426 << numLoops << ")";

1427 }

1432 results.push_back(res->getOperation());

1434}

1435

1436LogicalResult transform::InterchangeOp::verify() {

1438 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));

1439 if (!std::is_permutation(sequence.begin(), sequence.end(),

1440 permutation.begin(), permutation.end())) {

1442 << "expects iterator_interchange to be a permutation, found "

1443 << getIteratorInterchange();

1444 }

1446}

1447

1448

1449

1450

1451

1456

1457

1458 if (!isalinalg::CopyOp(targetOp)) {

1460 emitSilenceableError() << "only linalg.copy target ops are supported";

1461 diag.attachNote(targetOp->getLoc()) << "target op";

1462 return diag;

1463 }

1464

1465 auto copyOp = dyn_castlinalg::CopyOp(targetOp);

1466 if (!copyOp.hasPureBufferSemantics()) {

1468 emitSilenceableError()

1469 << "cannot transform a linalg.copy on tensors into a memref.copy";

1470 diag.attachNote(targetOp->getLoc()) << "target op";

1471 return diag;

1472 }

1473

1476 assert(inputs.size() == 1 && "expected linalg copy op with one input");

1477 assert(outputs.size() == 1 && "expected memref copy op with one output");

1478 Value input = inputs.front();

1479 Value output = outputs.front();

1480

1481

1482

1483

1484 if (!isa(input.getType())) {

1486 emitSilenceableError()

1487 << "cannot transform a linalg.copy which input has no shape";

1488 diag.attachNote(targetOp->getLoc()) << "target op";

1489 return diag;

1490 }

1491

1492

1493 assert(isa(output.getType()));

1494

1495 if (cast(input.getType()).getElementType() !=

1496 cast(output.getType()).getElementType()) {

1498 emitSilenceableError()

1499 << "cannot transform a linalg.copy with different source and "

1500 "destination element types ";

1501 diag.attachNote(targetOp->getLoc()) << "target op";

1502 return diag;

1503 }

1504

1505

1506 auto memrefCopyOp =

1507 rewriter.replaceOpWithNewOpmemref::CopyOp(targetOp, input, output);

1508

1509 results.push_back(memrefCopyOp);

1511}

1512

1513

1514

1515

1516

1522 bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();

1523 FailureOr res =

1524 lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);

1527 << "cannot lower to pad + expand + transpose";

1528 }

1529 transformResults.push_back(res->padOp);

1530 transformResults.push_back(res->expandShapeOp);

1531 transformResults.push_back(res->transposeOp);

1533}

1534

1535

1536

1537

1538

1544 bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();

1545 FailureOr res =

1549 emitSilenceableError()

1550 << "cannot lower to transpose + collapse + extract";

1551 diag.attachNote(target->getLoc()) << "target payload op";

1552 return diag;

1553 }

1554 transformResults.push_back(res->emptyOp);

1555 transformResults.push_back(res->transposeOp);

1556 transformResults.push_back(res->collapseShapeOp);

1557 transformResults.push_back(res->extractSliceOp);

1559}

1560

1561

1562

1563

1564

1568 result.addAttribute(MatchOp::getOpsAttrName(result.name),

1570 result.addTypes(transform::AnyOpType::get(builder.getContext()));

1571}

1572

1577 result.addAttribute(MatchOp::getOpsAttrName(result.name),

1579 result.addTypes(resultTypes);

1580}

1581

1587 if (getOps().has_value())

1588 strs.insert_range(getOps()->getAsValueRange());

1589

1590 auto payloadOps = state.getPayloadOps(getTarget());

1591 if (!llvm::hasSingleElement(payloadOps)) {

1593 }

1594

1596 bool incorrectNumOperandTypes = false;

1597 auto matchFun = [&](Operation *op) {

1598 if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))

1599 return;

1600

1601

1602

1603 if (getInterface().has_value()) {

1604 auto iface = getInterface().value();

1605 if (iface == transform::MatchInterfaceEnum::LinalgOp &&

1606 !isa(op))

1607 return;

1608 if (iface == transform::MatchInterfaceEnum::TilingInterface &&

1609 !isa(op))

1610 return;

1611 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&

1612 !isa(op))

1613 return;

1614 }

1615

1616

1617 if (getOpAttrs().has_value()) {

1618 DictionaryAttr opAttrs = getOpAttrs().value();

1620 if (attr.getName() == getInterfaceAttrName() ||

1621 attr.getName() == getOpsAttrName())

1622 continue;

1623 if (!op->hasAttr(attr.getName()))

1624 return;

1625 if (op->getAttr(attr.getName()) != attr.getValue())

1626 return;

1627 }

1628 }

1629

1630 if (getFilterResultType().has_value()) {

1631 Type t = getFilterResultType().value();

1633 return;

1634 }

1635

1636 if (getFilterOperandTypes().has_value()) {

1637 mlir::ArrayAttr types = getFilterOperandTypes().value();

1639

1640 if (types.size() == 1) {

1641

1642 auto typeattr =

1643 dyn_castmlir::TypeAttr(getFilterOperandTypes().value()[0]);

1644 Type t = cast<::mlir::Type>(typeattr.getValue());

1646 [&](Type operandType) { return operandType == t; }))

1647 return;

1648 } else {

1649

1650

1651 if (types.size() != operandTypes.size()) {

1652 incorrectNumOperandTypes = true;

1653 return;

1654 }

1655

1656 for (auto [attr, operandType] :

1657 llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {

1658 auto typeattr = castmlir::TypeAttr(attr);

1659 Type type = cast<::mlir::Type>(typeattr.getValue());

1660

1661 if (type != operandType)

1662 return;

1663 }

1664 }

1665 }

1666

1667

1668 res.push_back(op);

1669 return;

1670 };

1671

1672 (*payloadOps.begin())->walk(matchFun);

1673 if (incorrectNumOperandTypes)

1674 return emitDefiniteFailure("If filter_operand_types contains more than a "

1675 "type, then it must contain as much types as "

1676 "the number of operands in the target ops");

1677 results.set(cast(getResult()), res);

1679}

1680

1681

1682

1683

1684

1690

1692 Type &targetType, Type &lowSizeType,

1693 Type &highSizeType,

1694 Type &splitPointType) {

1695 FunctionType funcType;

1697 if (failed(parser.parseType(funcType)))

1698 return failure();

1699

1700 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {

1701 parser.emitError(typeLoc) << "expects a trailing functional type with one "

1702 "argument and one result";

1703 }

1704 targetType = funcType.getInput(0);

1705 lowSizeType = highSizeType = splitPointType = funcType.getResult(0);

1706

1708}

1709

1713 if (isa(getLowSize().getType())) {

1714 if (target.hasDynamicShape()) {

1715 auto diag = emitSilenceableError()

1716 << "cannot compute parametric tile sizes for dynamically "

1717 "shaped payload op";

1718 diag.attachNote(target->getLoc()) << "payload op";

1719 return diag;

1720 }

1721

1723 target, getDimension(), getTargetSize(), getDivisor());

1725 return emitSilenceableError()

1726 << "failed to compute multi-size tiling sizes";

1727 }

1728

1730 results.assign(llvm::map_range(

1732 spec->lowTileSize * spec->lowTripCount}),

1733 [&builder, this](int64_t value) {

1735 cast(getLowSize().getType()).getType(), value);

1736 }));

1738 }

1739

1745 builder, target, getDimension(), targetSize, divisor);

1747 return emitSilenceableError() << "could not generate tile size computation";

1748 }

1749

1754 {spec->lowTileSize, spec->lowTripCount});

1755 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();

1756 Operation *highTileSize = spec->highTileSize.getDefiningOp();

1757 assert(lowTileSize && highTileSize && splitPoint &&

1758 "tile sizes are not produced by operations");

1761 results.push_back(highTileSize);

1764}

1765

1766void transform::MultiTileSizesOp::getEffects(

1769 producesHandle(getOperation()->getOpResults(), effects);

1770 if (isa(getLowSize().getType()))

1772 else

1774}

1775

1776LogicalResult transform::MultiTileSizesOp::verify() {

1777 if (getLowSize().getType() != getHighSize().getType() ||

1778 getLowSize().getType() != getSplitPoint().getType()) {

1779 return emitOpError() << "expects all results type to be the same";

1780 }

1782}

1783

1784

1785

1786

1787

1794 staticPackedSizes);

1795

1796

1797

1798 Type linalgOpHType = transform::OperationType::get(

1799 builder.getContext(), GenericOp::getOperationName());

1800 build(builder, result,

1801 linalgOpHType,

1803 dynamicPackedSizes,

1805}

1806

1809 return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);

1810}

1811

1816 auto targetOps = state.getPayloadOps(getTarget());

1817

1818 if (std::empty(targetOps)) {

1819 transformResults.set(cast(getPackedOp()),

1822 }

1823

1824 auto linalgOp = dyn_cast(*targetOps.begin());

1825 if (!llvm::hasSingleElement(targetOps) || !linalgOp) {

1826 return emitSilenceableError()

1827 << "requires target to map to exactly 1 LinalgOp (got "

1828 << llvm::range_size(targetOps) << ")";

1829 }

1830

1831 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {

1832 return emitSilenceableError()

1833 << "requires number of packed sizes match the number of loops ("

1834 << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()

1835 << ")";

1836 }

1837

1838

1841 state, *this, packedSizes, getMixedPackedSizes());

1842

1844 FailureOr maybeResult = pack(rewriter, linalgOp, packedSizes);

1845 if (failed(maybeResult))

1847

1848 transformResults.set(cast(getPackedOp()),

1849 {maybeResult->packedLinalgOp.getOperation()});

1851}

1852

1853void transform::PackOp::getEffects(

1859}

1860

1861

1862

1863

1864

1865LogicalResult transform::PackGreedilyOp::verify() {

1867 return emitOpError() << getMatmulInnerDimsOrderAttrName()

1868 << " is not a valid permutation";

1869 }

1870

1871 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {

1872 for (auto [s, nmo] :

1873 llvm::zip_equal(getMixedMatmulPackedSizes(),

1874 getMatmulPaddedSizesNextMultipleOf())) {

1876 if (nmo != 0 &&

1877 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {

1878 return emitOpError() << "at most one of the packed_size and the "

1879 "padded_sizes_next_multiple_of can be nonzero "

1880 "for the matmul strategy";

1881 }

1882 }

1883 }

1885}

1886

1893 auto linalgOp = dyn_cast(op);

1894 if (!linalgOp)

1895 continue;

1896

1897

1899

1900

1902 rewriter,

1903 linalgOp,

1904 getMixedMatmulPackedSizes(),

1905

1906 getMatmulPaddedSizesNextMultipleOf(),

1907 getMatmulInnerDimsOrder());

1908 if (succeeded(packResult)) {

1909 results.push_back(packResult->packedLinalgOp);

1910 continue;

1911 }

1912 results.push_back(linalgOp);

1913 }

1914 transformResults.set(cast(getPackedOp()), results);

1916}

1917

1920 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),

1921 b);

1922}

1923

1924void transform::PackGreedilyOp::getEffects(

1930}

1931

1932

1933

1934

1935

1936LogicalResult transform::PackTransposeOp::verify() {

1938 return emitOpError() << getInnerPermAttrName()

1939 << " is not a valid permutation";

1940 }

1942 return emitOpError() << getOuterPermAttrName()

1943 << " is not a valid permutation";

1944 }

1945 if (getInnerPerm().empty() && getOuterPerm().empty()) {

1946 return emitOpError() << " at least one of " << getInnerPermAttrName()

1947 << " or " << getOuterPermAttrName()

1948 << " must be specified";

1949 }

1951}

1952

1953namespace {

1954enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };

1955}

1956

1957

1958

1959

1960

1961

1962

1963

1964template

1965static bool isValidPackingPermutation(

1967 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {

1968 static_assert(

1969 llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,

1970 "applies to only pack or unpack operations");

1971 if (!op || permutation.empty())

1972 return true;

1973 size_t innerRank = op.getInnerDimsPos().size();

1974 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)

1975 return permutation.size() == innerRank && isPermutationVector(permutation);

1976

1977

1978 if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {

1979 return permutation.size() == op.getSourceRank() &&

1981 }

1982 return permutation.size() == op.getDestRank() &&

1984}

1985

1990 auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());

1991 auto linalgOps = state.getPayloadOps(getTargetLinalgOp());

1992

1993 if (std::empty(packOrUnpackOps)) {

1994 transformResults.set(cast(getPackedOp()), {});

1995 transformResults.set(cast(getPackOp()), {});

1996 transformResults.set(cast(getUnPackOp()), {});

1998 }

1999

2000

2001

2002 if (!llvm::hasSingleElement(packOrUnpackOps) ||

2003 !llvm::hasSingleElement(linalgOps)) {

2004 return emitSilenceableError()

2005 << "requires target to map to exactly 1 "

2006 "packing op and 1 packed op ("

2007 << "got " << llvm::range_size(packOrUnpackOps) << " and "

2008 << llvm::range_size(linalgOps) << ")";

2009 }

2010

2011

2012 auto packOp = dyn_castlinalg::PackOp(*packOrUnpackOps.begin());

2013 auto unPackOp = dyn_castlinalg::UnPackOp(*packOrUnpackOps.begin());

2014 if ((!packOp && !unPackOp)) {

2015 return emitSilenceableError() << "requires target to map to a "

2016 "linalg.pack or linalg.unpack";

2017 }

2018 LinalgOp linalgOpTarget = dyn_cast(*linalgOps.begin());

2019 if (!linalgOpTarget)

2020 return emitSilenceableError() << "requires a LinalgOp target";

2021

2022

2023 LinalgOp linalgOp;

2024 if (packOp && packOp.getResult().hasOneUse())

2025 linalgOp = dyn_cast(*(packOp.getResult().getUsers().begin()));

2026 else if (unPackOp)

2027 linalgOp = unPackOp.getSource().getDefiningOp();

2028 if (linalgOp != linalgOpTarget) {

2029 auto errorMsg =

2030 packOp ? StringLiteral{"not a single use by the LinalgOp target"}

2031 : StringLiteral{"not produced by the LinalgOp target"};

2032 return emitSilenceableError() << errorMsg;

2033 }

2034

2035

2036

2037 if (unPackOp) {

2038 assert(!packOp && "packOp must be null on entry when unPackOp is not null");

2039 OpOperand *packUse = linalgOp.getDpsInitOperand(

2040 cast(unPackOp.getSource()).getResultNumber());

2042 if (!packOp || !packOp.getResult().hasOneUse())

2043 return emitSilenceableError() << "could not find matching pack op";

2044 }

2045

2046

2047 for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {

2049 (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();

2050 auto errorMsg = (permType == OuterOrInnerPerm::Outer)

2051 ? StringLiteral{"invalid outer_perm"}

2052 : StringLiteral{"invalid inner_perm"};

2053 if (!isValidPackingPermutation(packOp, perm, permType) ||

2054 !isValidPackingPermutation(unPackOp, perm, permType)) {

2056 unPackOp ? unPackOp.getOperation() : packOp.getOperation();

2057 return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;

2058 }

2059 }

2060

2061

2062

2063 assert(packOp && linalgOp && "unexpected null op");

2064

2065

2066 FailureOr res = packTranspose(

2067 rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());

2068

2069 assert(succeeded(res) && "unexpected packTranspose failure");

2070

2071

2072 transformResults.set(cast(getPackOp()), {res->transposedPackOp});

2073 transformResults.set(cast(getPackedOp()),

2074 {res->transposedLinalgOp});

2075 if (unPackOp) {

2076 transformResults.set(cast(getUnPackOp()),

2077 {res->transposedUnPackOp});

2078 } else {

2079 transformResults.set(cast(getUnPackOp()), {});

2080 }

2081

2083}

2084

2085

2086

2087

2088

2094 StringRef copyBackOp,

2095 bool usePrescribedTensorShapes) {

2096 auto resultType = transform::AnyOpType::get(b.getContext());

2097 return build(b,

2101 ArrayAttr(),

2102 b.getI64ArrayAttr(paddingDimensions),

2104

2105 (padToMultipleOf.empty()

2107 : b.getDenseI64ArrayAttr(padToMultipleOf)),

2108 b.getI64ArrayAttr(nofoldFlags),

2109 b.getArrayAttr(transposePaddings),

2110 b.getStringAttr(copyBackOp),

2111

2112 usePrescribedTensorShapes ? b.getUnitAttr() : nullptr);

2113}

2114

2120 StringRef copyBackOp,

2121 bool usePrescribedTensorShapes) {

2122 auto resultType = transform::AnyOpType::get(b.getContext());

2126 staticPadToMultipleOf);

2127 return build(b,

2129 TypeRange{resultType, resultType},

2131 ArrayAttr(),

2132 b.getI64ArrayAttr(paddingDimensions),

2133 dynamicPadToMultipleOf,

2134 staticPadToMultipleOf,

2135 b.getI64ArrayAttr(nofoldFlags),

2136 b.getArrayAttr(transposePaddings),

2137 copyBackOp,

2138 usePrescribedTensorShapes);

2139}

2140

2141void PadOp::getEffects(

2145 producesHandle(getOperation()->getOpResults(), effects);

2147}

2148

2149SmallVector PadOp::getMixedPadToMultipleOf() {

2151 return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);

2152}

2153

2154DiagnosedSilenceableFailure

2155transform::PadOp::apply(transform::TransformRewriter &rewriter,

2156 transform::TransformResults &results,

2157 transform::TransformState &state) {

2158 auto transformOp = cast(getOperation());

2159 SmallVector<Operation *> paddedOps, padOps, copyBackOps;

2160

2162 auto linalgTarget = dyn_cast(target);

2163 if (!linalgTarget) {

2164 auto diag = emitSilenceableError() << "expected LinalgOp target";

2165 diag.attachNote(target->getLoc()) << "target op";

2166 return diag;

2167 }

2168

2169

2170 SmallVector nofoldFlags;

2171 for (int64_t packPadding :

2173 nofoldFlags.push_back(static_cast<bool>(packPadding));

2174

2175

2176 SmallVector paddingValues;

2177 for (auto const &[untypedAttr, elementOrTensorType] :

2178 llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {

2179

2180 if (isaub::PoisonAttr(untypedAttr)) {

2181 paddingValues.push_back(untypedAttr);

2182 continue;

2183 }

2184 auto attr = dyn_cast(untypedAttr);

2185 if (!attr) {

2186 emitOpError("expects padding values to be typed attributes or poison");

2188 }

2190

2191 if (auto stringAttr = dyn_cast(attr)) {

2192 auto parsedAttr = dyn_cast_if_present(parseAttribute(

2193 stringAttr, getContext(), elementType,

2194 nullptr, true));

2195 if (!parsedAttr || parsedAttr.getType() != elementType) {

2196 auto diag = this->emitOpError("expects a padding that parses to ")

2197 << elementType << ", got " << untypedAttr;

2198 diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";

2200 }

2201 paddingValues.push_back(parsedAttr);

2202 continue;

2203 }

2204

2205 if (attr.getType() != elementType) {

2206 auto diag = this->emitOpError("expects a padding value of type ")

2207 << elementType << ", got " << attr;

2208 diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";

2210 }

2211 paddingValues.push_back(attr);

2212 }

2213

2214

2215 SmallVector<SmallVector<int64_t>> transposePaddings;

2216 for (Attribute transposeVector : cast(getTransposePaddings()))

2218 cast(transposeVector)));

2219

2220 LinalgOp paddedOp;

2221 LinalgPaddingOptions options;

2222 options.paddingDimensions =

2224

2225 SmallVector<int64_t> padToMultipleOf;

2227 state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);

2229 return status;

2230 if (padToMultipleOf.empty())

2231 padToMultipleOf =

2232 SmallVector<int64_t>(options.paddingDimensions.size(), 1);

2233

2234 options.padToMultipleOf = padToMultipleOf;

2235 options.paddingValues = paddingValues;

2236 options.nofoldFlags = nofoldFlags;

2237 if (getCopyBackOp() ==

2238 bufferization::MaterializeInDestinationOp::getOperationName()) {

2239 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::

2240 BufferizationMaterializeInDestination;

2241 } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {

2242 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;

2243 } else if (getCopyBackOp() == kCopyOpNone) {

2244 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;

2245 } else {

2246 llvm_unreachable("unsupported copy_back op");

2247 }

2248

2249 bool irChanged = false;

2250 if (getUsePrescribedTensorShapes() &&

2251 linalgTarget.hasPureTensorSemantics()) {

2252 OpBuilder::InsertionGuard g(rewriter);

2254 for (OpOperand &operand : linalgTarget->getOpOperands()) {

2255 for (auto [i, dim] : llvm::enumerate(linalgTarget.getShape(&operand))) {

2256 if (ShapedType::isStatic(dim))

2257 continue;

2258 options.setSizeToPadTo(operand.getOperandNumber(), i,

2260 operand.get().getLoc(),

2261 operand.get(), i));

2262 irChanged = true;

2263 }

2264 }

2265 }

2266

2267 SmallVector replacements;

2268 SmallVectortensor::PadOp newPadOps;

2270 replacements, newPadOps))) {

2271 if (irChanged) {

2274 return diag;

2275 }

2276 auto diag = emitSilenceableError() << "failed to pad op";

2277 diag.attachNote(target->getLoc()) << "target op";

2278 return diag;

2279 }

2280

2281

2282

2283

2284

2285

2286 rewriter.replaceOp(linalgTarget, replacements);

2287 paddedOps.push_back(paddedOp);

2288 padOps.append(newPadOps.begin(), newPadOps.end());

2289 if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {

2290 for (Value v : replacements) {

2291 Operation *copyBackOp = v.getDefiningOp();

2292 if (!llvm::is_contained(copyBackOps, copyBackOp))

2293 copyBackOps.push_back(copyBackOp);

2294 }

2295 }

2296 }

2297

2298 results.set(cast(getPadded()), paddedOps);

2299 results.set(cast(getPad()), padOps);

2300 results.set(cast(getCopy()), copyBackOps);

2302}

2303

2304LogicalResult transform::PadOp::verify() {

2305 SmallVector<int64_t> nofoldFlags =

2307 if (any_of(nofoldFlags, [](int64_t packPadding) {

2308 return packPadding != 0 && packPadding != 1;

2309 })) {

2311 << "expects nofold_flags to contain booleans (0/1), found "

2312 << getNofoldFlags();

2313 }

2314

2315 SmallVector<int64_t> paddingDimensions =

2317 if (any_of(paddingDimensions,

2318 [](int64_t paddingDimension) { return paddingDimension < 0; })) {

2319 return emitOpError() << "expects padding_dimensions to contain positive "

2320 "integers, found "

2321 << getPaddingDimensions();

2322 }

2323 if (!getMixedPadToMultipleOf().empty()) {

2324 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {

2325 return emitOpError() << "expects as many multiples as padding_dimensions";

2326 }

2327 }

2328 ArrayAttr transposes = getTransposePaddings();

2329 for (Attribute attr : transposes) {

2331 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));

2332 if (!std::is_permutation(sequence.begin(), sequence.end(),

2333 transpose.begin(), transpose.end())) {

2335 << "expects transpose_paddings to be a permutation, found "

2336 << attr;

2337 }

2338 }

2339 if (getCopyBackOp() !=

2340 bufferization::MaterializeInDestinationOp::getOperationName() &&

2341 getCopyBackOp() != linalg::CopyOp::getOperationName() &&

2342 getCopyBackOp() != kCopyOpNone)

2343 return emitOpError() << "invalid copy_back_op";

2345}

2346

2347

2348

2349

2350

2351void transform::PadTilingInterfaceOp::build(OpBuilder &b,

2352 OperationState &result,

2354 ArrayRef<int64_t> paddingSizes,

2355 bool padToMultipleOf) {

2356 auto resultType = transform::AnyOpType::get(b.getContext());

2357 return build(b,

2359 TypeRange{resultType, resultType},

2361 ArrayAttr(),

2363

2365 : b.getDenseI64ArrayAttr(paddingSizes)),

2366

2367 padToMultipleOf ? b.getUnitAttr() : nullptr);

2368}

2369

2370void transform::PadTilingInterfaceOp::build(

2371 OpBuilder &b, OperationState &result, Value target,

2372 ArrayRef mixedPaddingSizes, bool padToMultipleOf) {

2373 auto resultType = transform::AnyOpType::get(b.getContext());

2374 SmallVector<int64_t> staticPaddingSizes;

2375 SmallVector dynamicPaddingSizes;

2377 staticPaddingSizes);

2378 return build(b,

2380 TypeRange{resultType, resultType},

2382 ArrayAttr(),

2383 dynamicPaddingSizes,

2384 staticPaddingSizes,

2385 padToMultipleOf);

2386}

2387

2388void transform::PadTilingInterfaceOp::getEffects(

2389 SmallVectorImplMemoryEffects::EffectInstance &effects) {

2392 producesHandle(getOperation()->getOpResults(), effects);

2394}

2395

2396SmallVector

2397transform::PadTilingInterfaceOp::getMixedPaddingSizes() {

2399 return getMixedValues(getStaticPaddingSizes(), getPaddingSizes(), b);

2400}

2401

2402DiagnosedSilenceableFailure

2403transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,

2404 transform::TransformResults &results,

2405 transform::TransformState &state) {

2406 SmallVector<Operation *> paddedOps, padOps;

2407

2409 auto targetOp = dyn_cast(target);

2410 if (!targetOp) {

2411 auto diag = emitSilenceableError() << "expected TilingInterface target";

2412 diag.attachNote(target->getLoc()) << "target op";

2413 return diag;

2414 }

2415

2416

2417

2418

2419 if (!isa(targetOp.getOperation())) {

2420 auto diag = emitSilenceableError() << "only IndexingMapOpInterface ops "

2421 "supported atm";

2422 diag.attachNote(target->getLoc()) << "target op";

2423 return diag;

2424 }

2425

2426

2427 SmallVector paddingValues;

2428 for (auto const &[untypedAttr, elementOrTensorType] :

2429 llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {

2430 auto attr = dyn_cast(untypedAttr);

2432

2433 if (isaub::PoisonAttr(untypedAttr)) {

2434 paddingValues.push_back(untypedAttr);

2435 continue;

2436 }

2437 if (!attr) {

2438 emitOpError("expects padding values to be typed attributes or poison");

2440 }

2441

2442 if (auto stringAttr = dyn_cast(attr)) {

2443 auto parsedAttr = dyn_cast_if_present(parseAttribute(

2444 stringAttr, getContext(), elementType,

2445 nullptr, true));

2446 if (!parsedAttr || parsedAttr.getType() != elementType) {

2447 auto diag = this->emitOpError("expects a padding that parses to ")

2448 << elementType << ", got " << attr;

2449 diag.attachNote(targetOp.getLoc()) << "when applied to this op";

2451 }

2452 paddingValues.push_back(parsedAttr);

2453 continue;

2454 }

2455

2456 if (attr.getType() != elementType) {

2457 auto diag = this->emitOpError("expects a padding value of type ")

2458 << elementType << ", got " << attr;

2459 diag.attachNote(targetOp.getLoc()) << "when applied to this op";

2461 }

2462 paddingValues.push_back(attr);

2463 }

2464

2465

2466 PadTilingInterfaceOptions options;

2467 options.setPaddingValues(paddingValues)

2468 .setPaddingSizes(getMixedPaddingSizes())

2469 .setPadToMultipleOf(getPadToMultipleOf());

2470

2471 OpBuilder::InsertionGuard g(rewriter);

2474 rewriter, cast(targetOp.getOperation()), options);

2475 if (failed(maybePadOps)) {

2476 auto diag = emitSilenceableError() << "failed to pad op";

2477 diag.attachNote(target->getLoc()) << "target op";

2478 return diag;

2479 }

2480 const auto &[paddedOperands, paddedOp, slicedResults] = maybePadOps.value();

2481

2482

2483 paddedOps.push_back(paddedOp);

2484 padOps.append(paddedOperands.begin(), paddedOperands.end());

2485 rewriter.replaceOp(targetOp.getOperation(), slicedResults);

2486 }

2487

2488 results.set(cast(getPadded()), paddedOps);

2489 results.set(cast(getPad()), padOps);

2491}

2492

2493LogicalResult transform::PadTilingInterfaceOp::verify() { return success(); }

2494

2495

2496

2497

2498

2499DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(

2500 transform::TransformRewriter &rewriter,

2501 transform::TransformResults &transformResults,

2502 transform::TransformState &state) {

2503 auto targetOps = state.getPayloadOps(getTarget());

2505 if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {

2507 << "requires exactly one target and one loop handle (got "

2508 << llvm::range_size(targetOps) << " and "

2509 << llvm::range_size(loopOps) << ")";

2510 }

2511

2512 auto padOp = dyn_cast_or_nulltensor::PadOp(*targetOps.begin());

2513 auto loopOp = dyn_cast_or_nullscf::ForOp(*loopOps.begin());

2514 if (!padOp || !loopOp)

2516

2517 FailureOrlinalg::detail::PackingResult result =

2519 getTranspose());

2522

2523 if (result->clonedLoopIvs.empty()) {

2524 transformResults.set(cast(getPackingLoop()),

2525 {result->hoistedPadOp.getOperation()});

2527 }

2528 auto outerPackedLoop =

2530 transformResults.set(cast(getPackingLoop()),

2531 {outerPackedLoop.getOperation()});

2533}

2534

2535LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {

2536 ArrayRef<int64_t> transpose = getTranspose();

2537 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));

2538 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),

2539 transpose.end())) {

2540 return emitOpError() << "expects transpose to be a permutation, found "

2541 << getTranspose();

2542 }

2544}

2545

2546void transform::HoistPadBuildPackingLoopNestOp::getEffects(

2547 SmallVectorImplMemoryEffects::EffectInstance &effects) {

2552}

2553

2554DiagnosedSilenceableFailure

2555transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,

2556 tensor::PadOp target,

2557 transform::ApplyToEachResultList &results,

2558 transform::TransformState &state) {

2559 tensor::PadOp hoistedPadOp;

2560 SmallVector transposeOps;

2561 FailureOr result =

2563 hoistedPadOp, transposeOps);

2564 if (succeeded(result)) {

2565

2566

2567

2568

2569

2571 results.push_back(hoistedPadOp);

2573 }

2574 return emitDefaultSilenceableFailure(target);

2575}

2576

2577LogicalResult transform::HoistPadOp::verify() {

2578 ArrayRef<int64_t> transpose = getTranspose();

2579 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));

2580 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),

2581 transpose.end())) {

2582 return emitOpError() << "expects transpose to be a permutation, found "

2583 << getTranspose();

2584 }

2586}

2587

2588

2589

2590

2591

2592DiagnosedSilenceableFailure

2593transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,

2595 transform::ApplyToEachResultList &results,

2596 transform::TransformState &state) {

2597 LinalgPromotionOptions promotionOptions;

2598 if (!getOperandsToPromote().empty())

2601 if (getUseFullTilesByDefault())

2603 getUseFullTilesByDefault());

2604 if (getUseOriginalSubviewSize())

2605 promotionOptions =

2607 if (getUseAlloca())

2608 promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());

2609 if (!getUseFullTileBuffers().empty())

2611 llvm::to_vector(getUseFullTileBuffers().getAsValueRange()));

2612 if (getAlignment().has_value())

2613 promotionOptions = promotionOptions.setAlignment(*getAlignment());

2614 if (getMemorySpace().has_value())

2615 promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());

2616

2617 if (getMapping().has_value()) {

2618

2619 auto mapping = *getMapping();

2620 if (mapping.size() > 1)

2621 return emitDefaultDefiniteFailure(target);

2622

2623 auto addressSpace = castmlir::gpu::GPUMemorySpaceMappingAttr(mapping[0]);

2624

2625 if (addressSpace.getAddressSpace() ==

2626 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {

2627 promotionOptions =

2628 promotionOptions

2633 } else if (addressSpace.getAddressSpace() ==

2634 mlir::gpu::GPUDialect::getPrivateAddressSpace()) {

2635 promotionOptions =

2636 promotionOptions

2641 } else {

2642 return emitDefaultDefiniteFailure(target);

2643 }

2644 }

2645

2647 return emitDefaultDefiniteFailure(target);

2648

2652 return emitDefaultDefiniteFailure(target);

2655}

2656

2657

2658

2659

2660

2661DiagnosedSilenceableFailure

2662transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,

2663 TransformResults &transformResults,

2664 TransformState &state) {

2665 auto payload = state.getPayloadOps(getTarget());

2666

2667

2668 for (Operation *target : payload) {

2669 if (target->getNumOperands() > 0)

2671 if (target->hasTraitOpTrait::IsIsolatedFromAbove() &&

2672 target->getNumRegions() > 0)

2674 << "expected target that is isolated from above";

2675 }

2676

2677

2678 Operation *pattern = &getBodyRegion().front().front();

2679 SmallVector<Operation *> replacements;

2680 for (Operation *target : payload) {

2681 if (getOperation()->isAncestor(target))

2682 continue;

2687 }

2688 transformResults.set(cast(getReplacement()), replacements);

2690}

2691

2692void transform::ReplaceOp::getEffects(

2693 SmallVectorImplMemoryEffects::EffectInstance &effects) {

2695 producesHandle(getOperation()->getOpResults(), effects);

2697}

2698

2699LogicalResult transform::ReplaceOp::verify() {

2700 if (!getBodyRegion().hasOneBlock())

2701 return emitOpError() << "expected one block";

2702 if (std::distance(getBodyRegion().front().begin(),

2703 getBodyRegion().front().end()) != 1)

2704 return emitOpError() << "expected one operation in block";

2705 Operation *replacement = &getBodyRegion().front().front();

2708 << "expected replacement without operands";

2709 if (replacement->hasTraitOpTrait::IsIsolatedFromAbove() &&

2712 << "expect op that is isolated from above";

2714}

2715

2716

2717

2718

2719

2720DiagnosedSilenceableFailure

2721transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,

2723 transform::ApplyToEachResultList &results,

2724 transform::TransformState &state) {

2725 scf::SCFTilingOptions tilingOptions;

2726 tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {

2727 SmallVector tileSizes;

2728 Location loc = target.getLoc();

2729 SmallVector allShapeSizes =

2730 target.createFlatListOfOperandDims(b, loc);

2731 AffineMap map = target.getShapesToLoopsMap();

2732 if (!map)

2733 return tileSizes;

2734 SmallVector shapeSizes =

2736 allShapeSizes);

2737

2738

2739 for (OpFoldResult shapeSize : shapeSizes) {

2741 : b.getIndexAttr(1));

2742 }

2743 return tileSizes;

2744 });

2746 FailureOrscf::SCFTilingResult maybeTilingResult = tileUsingSCF(

2747 rewriter, cast(target.getOperation()), tilingOptions);

2748 if (failed(maybeTilingResult))

2749 return emitDefaultDefiniteFailure(target);

2750

2751 if (target->getNumResults())

2752 rewriter.replaceOp(target, maybeTilingResult->replacements);

2753 else

2755

2756 results.reserve(maybeTilingResult->tiledOps.size());

2757 for (Operation *tiled : maybeTilingResult->tiledOps)

2760}

2761

2762

2763

2764

2765

2766DiagnosedSilenceableFailure

2767transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,

2768 transform::TransformResults &results,

2769 transform::TransformState &state) {

2770 SmallVector<Operation *> loops;

2772 auto tilingOp = dyn_cast(*target);

2773 if (!tilingOp) {

2774 DiagnosedSilenceableFailure diag =

2775 emitSilenceableError()

2776 << "expected the payload to implement TilingInterface";

2777 diag.attachNote(target->getLoc()) << "payload op";

2778 return diag;

2779 }

2781 FailureOr<SmallVectorscf::ForOp> generatedLoops =

2782 scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);

2783 if (failed(generatedLoops))

2784 return emitDefaultDefiniteFailure(target);

2785 for (scf::ForOp &loop : *generatedLoops) {

2786 loops.push_back(loop.getOperation());

2787 }

2789 }

2790 results.set(cast(getResult()), loops);

2792}

2793

2794

2795

2796

2797

2798DiagnosedSilenceableFailure

2799transform::RewriteInDestinationPassingStyleOp::applyToOne(

2800 transform::TransformRewriter &rewriter, Operation *target,

2801 transform::ApplyToEachResultList &results,

2802 transform::TransformState &state) {

2804 FailureOr<Operation *> maybeResult =

2806 .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(

2807 [&rewriter](auto op) {

2809 });

2810 if (failed(maybeResult))

2811 return emitDefaultSilenceableFailure(target);

2812 results.push_back(*maybeResult);

2814}

2815

2816

2817

2818

2819

2820DiagnosedSilenceableFailure

2821SplitOp::apply(transform::TransformRewriter &rewriter,

2822 TransformResults &results, TransformState &state) {

2823

2824 SmallVector<Operation *> payload =

2825 llvm::to_vector(state.getPayloadOps(getTarget()));

2826

2827 bool isMultiwaySplit = getMultiway();

2828

2829 if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {

2831 << "requires exactly one target when "

2832 "multiway split is enabled (got "

2833 << llvm::range_size(payload) << ")";

2834 }

2835

2836 SmallVector chunkSizes;

2837

2838 if (!isMultiwaySplit)

2839 chunkSizes.reserve(payload.size());

2840

2841 if (getDynamicChunkSizes()) {

2843 if (isa(getDynamicChunkSizes().getType())) {

2844 chunkSizes = llvm::to_vector(llvm::map_range(

2845 state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {

2848 diag = emitSilenceableError()

2849 << "expected dynamic split point handle to point to a "

2850 "single-result index-typed op";

2851 diag.attachNote(op->getLoc()) << "dynamic split point";

2852 }

2853 return OpFoldResult(op->getResult(0));

2854 }));

2855 } else {

2856 chunkSizes = llvm::to_vector(

2857 llvm::map_range(state.getParams(getDynamicChunkSizes()),

2858 [](Attribute attr) { return OpFoldResult(attr); }));

2859 }

2860 if (diag.isSilenceableFailure())

2861 return diag;

2862

2863

2864

2865 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {

2867 << "expected the dynamic split point handle to point to as "

2868 "many operations ("

2869 << chunkSizes.size() << ") as the target handle ("

2870 << payload.size() << ")";

2871 }

2872 } else {

2873 chunkSizes.resize(payload.size(),

2874 rewriter.getIndexAttr(getStaticChunkSizes()));

2875 }

2876

2877 auto checkStructuredOpAndDimensions =

2878 [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {

2879 if (!linalgOp) {

2880 auto diag = emitSilenceableError() << "only applies to structured ops";

2881 diag.attachNote(loc) << "target op";

2882 return diag;

2883 }

2884

2885 if (getDimension() >= linalgOp.getNumLoops()) {

2886 auto diag = emitSilenceableError() << "dimension " << getDimension()

2887 << " does not exist in target op";

2888 diag.attachNote(loc) << "target op";

2889 return diag;

2890 }

2892 };

2893

2894 auto checkFailureInSplitting =

2895 [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {

2896 if (hasFailed) {

2899 return diag;

2900 }

2902 };

2903

2904 SmallVector<Operation *> opList;

2905 if (isMultiwaySplit) {

2906

2907

2908 TilingInterface head, tail;

2909 Operation *target = payload.front();

2910

2911 LinalgOp linalgOp = dyn_cast(target);

2912

2913

2914 DiagnosedSilenceableFailure diag =

2915 checkStructuredOpAndDimensions(linalgOp, target->getLoc());

2916 if (diag.isSilenceableFailure())

2917 return diag;

2918

2919 for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {

2920

2921 if (idx > 0)

2922 target = tail.getOperation();

2923

2925 break;

2926

2927 linalgOp = cast(target);

2928 Location loc = target->getLoc();

2929

2932 rewriter, cast(linalgOp.getOperation()),

2933 getDimension(), chunkSize);

2934

2935

2936 DiagnosedSilenceableFailure diag =

2937 checkFailureInSplitting(!head && !tail, loc);

2938 if (diag.isDefiniteFailure())

2939 return diag;

2940

2941 opList.push_back(head.getOperation());

2942 }

2943

2944

2945 if (tail)

2946 opList.push_back(tail.getOperation());

2947

2948 } else {

2949

2950 SmallVector<Operation *> first, second;

2951 Operation *noSecondPart = nullptr;

2952 for (const auto &pair : llvm::zip(payload, chunkSizes)) {

2953 Operation *target = std::get<0>(pair);

2954 Location loc = target->getLoc();

2955 LinalgOp linalgOp = dyn_cast(target);

2956 DiagnosedSilenceableFailure diag =

2957 checkStructuredOpAndDimensions(linalgOp, target->getLoc());

2958

2959 if (diag.isSilenceableFailure())

2960 return diag;

2961

2963 std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(

2964 rewriter, cast(linalgOp.getOperation()),

2965 getDimension(), std::get<1>(pair));

2966

2967

2968 DiagnosedSilenceableFailure diagSplit =

2969 checkFailureInSplitting(!first.back() && !second.back(), loc);

2971 return diag;

2972

2973

2974 if (!second.back()) {

2975 noSecondPart = target;

2976 second.pop_back();

2977 }

2978 }

2979

2980 if (second.size() != first.size() && !second.empty()) {

2981 auto diag = emitSilenceableError()

2982 << "splitting does not produce the second part for a subset "

2983 "of targets";

2984 diag.attachNote()

2985 << "expected splitting to produce the second part of all "

2986 "or none of the targets";

2987 diag.attachNote(noSecondPart->getLoc())

2988 << "first target with no second part";

2989 return diag;

2990 }

2991

2992 opList.append(first);

2993 if (!second.empty())

2994 opList.append(second);

2995 }

2996 results.set(cast(getSplitList()), opList);

2998}

2999

3000void SplitOp::getEffects(

3001 SmallVectorImplMemoryEffects::EffectInstance &effects) {

3003 if (getDynamicChunkSizes())

3004 onlyReadsHandle(getDynamicChunkSizesMutable(), effects);

3005 producesHandle(getOperation()->getOpResults(), effects);

3007}

3008

3009ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {

3010 OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;

3011 IntegerAttr staticChunkSizes;

3013 return failure();

3014

3015 OptionalParseResult dynamicPointParseResult =

3017 if (!dynamicPointParseResult.has_value()) {

3018 int64_t staticChunkSizesValue;

3020 return failure();

3021

3022 staticChunkSizes =

3024 }

3025

3026 Type targetType;

3030 return failure();

3031 }

3032 if (dynamicPointParseResult.has_value()) {

3033 Type chunkSizesType;

3034 if (failed(*dynamicPointParseResult) || parser.parseComma() ||

3035 parser.parseType(chunkSizesType) ||

3036 parser.resolveOperand(dynamicChunkSizes, chunkSizesType,

3038 return failure();

3039 }

3040

3041 staticChunkSizes =

3043 }

3044

3045 result.addAttribute(

3046 SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),

3047 staticChunkSizes);

3048 result.addTypes(targetType);

3050}

3051

3052void SplitOp::print(OpAsmPrinter &printer) {

3053 printer << " " << getTarget() << " after ";

3054 int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());

3055 if (staticChunkSize != ShapedType::kDynamic)

3056 printer << staticChunkSize;

3057 else

3058 printer << getDynamicChunkSizes();

3059 printer << " ";

3061 {getStaticChunkSizesAttrName()});

3062 printer << " : " << getTarget().getType();

3063 if (staticChunkSize == ShapedType::kDynamic)

3064 printer << ", " << getDynamicChunkSizes().getType();

3065}

3066

3067LogicalResult SplitOp::verify() {

3068 if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^

3069 (getDynamicChunkSizes() == nullptr)) {

3070 return emitOpError() << "expects either a dynamic or a static split "

3071 "point to be provided";

3072 }

3074}

3075

3076

3077

3078

3079

3080void transform::SplitReductionOp::build(

3081 OpBuilder &builder, OperationState &result, Value target,

3082 int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,

3083 bool useScalingAlgorithm, bool useAlloc) {

3084 MLIRContext *ctx = builder.getContext();

3086 result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),

3088 result.addAttribute(

3089 SplitReductionOp::getInsertSplitDimensionAttrName(result.name),

3091 if (innerParallel) {

3092 result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),

3094 }

3095 if (useScalingAlgorithm) {

3096 result.addAttribute(

3097 SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),

3099 }

3100 if (useAlloc) {

3101 result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),

3103 }

3104 auto resultType = transform::AnyOpType::get(ctx);

3105 result.addTypes({resultType, resultType, resultType, resultType});

3106}

3107

3108DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(

3109 transform::TransformRewriter &rewriter, LinalgOp target,

3110 transform::ApplyToEachResultList &results,

3111 transform::TransformState &state) {

3113 return linalg::SplitReductionOptions{int64_t(getSplitFactor()),

3114 unsigned(getInsertSplitDimension()),

3115 bool(getInnerParallel())};

3116 };

3118 FailureOr splitResult =

3119 (getUseScalingAlgorithm())

3122 if (failed(splitResult))

3123 return emitDefaultDefiniteFailure(target);

3124

3125 results.push_back(splitResult->initOrAlloc);

3126 results.push_back(splitResult->fillOp);

3127 results.push_back(splitResult->splitLinalgOp);

3128 results.push_back(splitResult->resultCombiningLinalgOp);

3130}

3131

3132

3133

3134

3135

3136void transform::TileReductionUsingForOp::build(

3137 OpBuilder &builder, OperationState &result, Value target,

3138 ArrayRef<int64_t> staticTileSizes) {

3139

3140

3141

3142

3143

3144 MLIRContext *ctx = builder.getContext();

3145 auto opTy = transform::AnyOpType::get(ctx);

3146 auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);

3147 build(builder, result,

3148 TypeRange{opTy, opTy, opTy, opTy},

3150 nullptr,

3151 staticTileSizesAttr);

3152}

3153

3154DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(

3155 transform::TransformRewriter &rewriter, Operation *target,

3156 transform::ApplyToEachResultList &results,

3157 transform::TransformState &state) {

3159

3160 auto partialReductionOp = dyn_cast(target);

3161 if (!partialReductionOp) {

3164 "Operation should implement PartialReductionOpInterface");

3165 }

3166

3167 SmallVector reductionDims =

3169 if (reductionDims.empty()) {

3170 for (auto [idx, iteratorType] :

3171 llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {

3172 if (iteratorType == utils::IteratorType::reduction)

3173 reductionDims.push_back(idx);

3174 }

3175 }

3176

3177 scf::SCFTilingOptions options;

3178 options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);

3179 options.setReductionTilingStrategy(

3182 options.setReductionDims(reductionDims);

3183 FailureOrscf::SCFTilingResult result =

3184 scf::tileUsingSCF(rewriter, partialReductionOp, options);

3185

3188 "failed to tile using partial reduction");

3189 }

3191 for (Value initValue : result->initialValues)

3193 for (auto parallelTiledOp : result->tiledOps)

3194 results.push_back(parallelTiledOp);

3195 for (auto mergeOp : result->mergeOps)

3199}

3200

3201

3202

3203

3204

3205void transform::TileReductionUsingForallOp::build(

3206 OpBuilder &builder, OperationState &result, Value target,

3207 ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,

3209

3210

3211

3212

3213

3214 MLIRContext *ctx = builder.getContext();

3215 auto opTy = transform::AnyOpType::get(ctx);

3218 build(builder, result,

3219 TypeRange{opTy, opTy, opTy, opTy},

3221 {},

3222 staticNumThreadsAttr,

3223 staticTileSizesAttr,

3224 mapping);

3225}

3226

3227DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(

3228 transform::TransformRewriter &rewriter, Operation *target,

3229 transform::ApplyToEachResultList &results,

3230 transform::TransformState &state) {

3232

3233 auto partialReductionOp = dyn_cast(target);

3234 if (!partialReductionOp) {

3237 "Operation should implement PartialReductionOpInterface");

3238 }

3239 SmallVector numThreads =

3241 SmallVector tileSizes =

3243

3244 scf::SCFTilingOptions options;

3245 options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);

3246 options.setReductionTilingStrategy(

3248 if (!getNumThreads().empty()) {

3249 options.setNumThreads(numThreads);

3250 } else {

3251 options.setTileSizes(tileSizes);

3252 }

3253 if (auto mapping = getMapping()) {

3254 options.setMapping(mapping.value().getValue());

3255 }

3256 SmallVector reductionDims =

3258 if (reductionDims.empty()) {

3259 for (auto [idx, iteratorType] :

3260 llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {

3261 if (iteratorType == utils::IteratorType::reduction)

3262 reductionDims.push_back(idx);

3263 }

3264 }

3265 options.setReductionDims(reductionDims);

3266 FailureOrscf::SCFTilingResult result =

3267 scf::tileUsingSCF(rewriter, partialReductionOp, options);

3268

3270 auto diag = emitSilenceableError() << "could not tile reduction";

3271 return diag;

3272 }

3274

3275 for (Value initValue : result->initialValues)

3277 for (auto parallelTiledOp : result->tiledOps)

3278 results.push_back(parallelTiledOp);

3279 for (auto mergeOp : result->mergeOps)

3283}

3284

3285

3286

3287

3288

3289DiagnosedSilenceableFailure

3290transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,

3291 TransformResults &transformResults,

3292 TransformState &state) {

3293

3294 SmallVector<Operation *> targetOps =

3295 llvm::to_vector(state.getPayloadOps(getTarget()));

3296

3297 if (!llvm::hasSingleElement(targetOps)) {

3299 << "requires exactly one target (got " << llvm::range_size(targetOps)

3300 << ")";

3301 }

3302

3303 Operation *target = *targetOps.begin();

3304 auto linalgOp = dyn_cast(target);

3305 auto tileableOp = dyn_cast(target);

3306

3307 if (!linalgOp)

3309

3310 OpBuilder builder(linalgOp.getContext());

3311

3312 if (isa(getChunkSizes().getType())) {

3313 if (linalgOp.hasDynamicShape()) {

3314 auto diag = emitSilenceableError()

3315 << "cannot compute parametric tile sizes for dynamically "

3316 "shaped payload op";

3317 diag.attachNote(linalgOp->getLoc()) << "payload op";

3318 return diag;

3319 }

3320

3321 FailureOr spec =

3323 getTargetSize());

3325 return emitSilenceableError()

3326 << "failed to compute multi-size tiling sizes";

3327 }

3328

3329 SmallVector<int64_t> chunkSizes;

3330

3331 for (auto &&[tileSize, tripCount] :

3332 llvm::zip_equal(spec->tileSizes, spec->tripCounts))

3333 chunkSizes.push_back(tileSize * tripCount);

3334

3335 auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {

3336 return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {

3338 });

3339 };

3341 getI64AttrsFromI64(spec->tileSizes));

3342 transformResults.setParams(cast(getChunkSizes()),

3343 getI64AttrsFromI64(chunkSizes));

3344

3346 }

3347

3349

3350 OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());

3351 unsigned dimension = getDimension();

3352

3354 builder, tileableOp, dimension, targetSize, true);

3356 return emitSilenceableError() << "could not generate tile size computation";

3357 }

3358

3361 auto apply = [&](AffineExpr expr, ArrayRef ofrs) -> Value {

3363 ofrs);

3364 };

3365

3366 SmallVector chunkSizes;

3367 Value splitPoint;

3368 for (auto &&[tileSize, tripCount] :

3369 llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {

3370 splitPoint = apply(s0 * s1, {tileSize, tripCount});

3371 chunkSizes.push_back(splitPoint);

3372 }

3373

3374 auto getDefiningOps = [&](ArrayRef values) {

3375 return llvm::map_to_vector(values, [&](Value value) -> Operation * {

3377 });

3378 };

3379

3381 getDefiningOps(spec->tileSizes));

3382 transformResults.set(cast(getChunkSizes()),

3383 getDefiningOps(chunkSizes));

3384

3386}

3387

3388LogicalResult transform::ContinuousTileSizesOp::verify() {

3389

3391 return emitOpError() << "expects all results type to be the same";

3392 }

3393

3395}

3396

3397void transform::ContinuousTileSizesOp::getEffects(

3398 SmallVectorImplMemoryEffects::EffectInstance &effects) {

3401 else

3404 producesHandle(getOperation()->getOpResults(), effects);

3405}

3406

3408 Type targetType, Type tileSizes,

3411}

3412

3414 Type &targetType,

3415 Type &tileSizesType,

3416 Type &chunkSizesType) {

3417 FunctionType funcType;

3419 if (failed(parser.parseType(funcType)))

3420 return failure();

3421

3422 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {

3423 parser.emitError(typeLoc) << "expects a trailing functional type with one "

3424 "argument and one result";

3425 }

3426 targetType = funcType.getInput(0);

3427 tileSizesType = chunkSizesType = funcType.getResult(0);

3428

3430}

3431

3432

3433

3434

3435

3436void transform::TileUsingForOp::build(

3437 OpBuilder &builder, OperationState &result, TypeRange loopTypes,

3438 Value target, ArrayRef<int64_t> staticTileSizes,

3439 ArrayRef<int64_t> interchange,

3440 std::optional<ArrayRef> scalableSizes) {

3441 return build(builder, result, loopTypes,

3443

3445 interchange, scalableSizes);

3446}

3447

3448void transform::TileUsingForOp::build(

3449 OpBuilder &builder, OperationState &result, Value target,

3450 ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,

3451 std::optional<ArrayRef> scalableSizes) {

3454 interchange, scalableSizes);

3455}

3456

3457void transform::TileUsingForOp::build(

3458 OpBuilder &builder, OperationState &result, Value target,

3459 ArrayRef mixedTileSizes, ArrayRef<int64_t> interchange,

3460 std::optional<ArrayRef> scalableSizes) {

3461

3462

3463 SmallVector loopTypes(1, builder.getTypetransform::AnyOpType());

3464 build(builder, result, loopTypes, target, mixedTileSizes, interchange,

3465 scalableSizes);

3466}

3467

3468void transform::TileUsingForOp::build(

3469 OpBuilder &builder, OperationState &result, TypeRange loopTypes,

3470 Value target, ArrayRef mixedTileSizes,

3471 ArrayRef<int64_t> interchange,

3472 std::optional<ArrayRef> scalableSizes) {

3473 SmallVector<int64_t> staticTileSizes;

3474 SmallVector dynamicTileSizes;

3476

3477

3478

3480 unsigned numExpectedLoops =

3481 staticTileSizes.size() - llvm::count(staticTileSizes, 0);

3482 SmallVector resultTypes;

3483 resultTypes.reserve(numExpectedLoops);

3484 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&

3485 "expected one loop type or as many as loops");

3486 if (loopTypes.size() == 1)

3487 resultTypes.append(numExpectedLoops, loopTypes[0]);

3488 else

3489 llvm::append_range(resultTypes, loopTypes);

3490 SmallVector expandedScalableSizes(mixedTileSizes.size(), false);

3491 if (scalableSizes.has_value())

3492 expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());

3493 build(builder, result, target.getType(),

3494 resultTypes,

3496 dynamicTileSizes,

3497 staticTileSizesAttr,

3499 expandedScalableSizes);

3500}

3501

3502LogicalResult transform::TileUsingForOp::verify() {

3503 if (getMixedSizes().size() != getScalableSizes().size())

3504 return emitOpError("expected same number of sizes (")

3505 << getMixedSizes().size() << ") and scalable sizes ("

3506 << getScalableSizes().size() << ")";

3507 ArrayRef<int64_t> staticSizes = getStaticSizes();

3508 unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);

3509 if (getLoops().size() != numExpectedLoops)

3510 return emitOpError("expected number of loops to tile (")

3511 << numExpectedLoops << ") to match number of `loops` results ("

3512 << getLoops().size() << ")";

3514}

3515

3516DiagnosedSilenceableFailure

3517transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,

3518 TransformResults &transformResults,

3519 TransformState &state) {

3520 ArrayRef<int64_t> tileSizes = getStaticSizes();

3521

3522 SmallVector<Operation *> targets =

3523 llvm::to_vector(state.getPayloadOps(getTarget()));

3524 SmallVector<SmallVector<Operation *>> dynamicSizeProducers;

3525 SmallVector<SmallVector<int64_t>> paramSizes;

3529 if (isa(transformValue.getType())) {

3530 dynamicSizeProducers.push_back({});

3531 ArrayRef params = state.getParams(transformValue);

3532 paramSizes.push_back(

3533 llvm::to_vector(llvm::map_range(params, [](Attribute attr) {

3534 return cast(attr).getValue().getSExtValue();

3535 })));

3536

3537 if (paramSizes.back().size() != targets.size()) {

3538 DiagnosedSilenceableFailure diag =

3539 emitSilenceableError()

3540 << "expected as many parameter values ("

3541 << dynamicSizeProducers.back().size() << ") as target ops ("

3542 << targets.size() << ")";

3543 diag.attachNote(transformValue.getLoc()) << "for this parameter";

3544 return diag;

3545 }

3546

3547 continue;

3548 }

3549 paramSizes.push_back({});

3550 dynamicSizeProducers.push_back(

3551 llvm::to_vector(state.getPayloadOps(transformValue)));

3552

3553 if (dynamicSizeProducers.back().size() != targets.size()) {

3554 DiagnosedSilenceableFailure diag =

3555 emitSilenceableError()

3556 << "expected as many dynamic size-producing operations ("

3557 << dynamicSizeProducers.back().size() << ") as target ops ("

3558 << targets.size() << ")";

3559 diag.attachNote(transformValue.getLoc()) << "for this handle";

3560 return diag;

3561 }

3562

3563 for (Operation *op : dynamicSizeProducers.back()) {

3566 continue;

3567 }

3568

3569 DiagnosedSilenceableFailure diag =

3570 emitSilenceableError() << "expected sizes to be produced by ops "

3571 "with a single index-type result";

3572 diag.attachNote(op->getLoc()) << "size producer op";

3573 diag.attachNote(transformValue.getLoc()) << "for this handle";

3574 return diag;

3575 }

3576 }

3577

3578 SmallVector<Operation *> tiled;

3579 SmallVector<SmallVector<Operation *, 4>, 4> loops;

3580 loops.resize(getLoops().size());

3581 auto scalableSizes = getScalableSizes();

3582 for (auto [i, op] : llvm::enumerate(targets)) {

3583 auto tilingInterface = dyn_cast(op);

3584 if (!tilingInterface) {

3585 DiagnosedSilenceableFailure diag =

3586 emitSilenceableError()

3587 << "only ops implementing TilingInterface are supported";

3588 diag.attachNote(op->getLoc()) << "target op";

3589 return diag;

3590 }

3591 if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {

3592 DiagnosedSilenceableFailure diag =

3593 emitSilenceableError()

3594 << "too many tiles provided, expected at most "

3595 << tilingInterface.getLoopIteratorTypes().size() << " found "

3596 << tileSizes.size();

3597 diag.attachNote(op->getLoc()) << "target op";

3598 return diag;

3599 }

3600

3601 scf::SCFTilingOptions tilingOptions;

3602 if (tileSizes.empty()) {

3603 tilingOptions.setTileSizeComputationFunction(

3604 [](OpBuilder &, Operation *) -> SmallVector {

3605 return {};

3606 });

3607 } else {

3608 tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,

3609 Operation *) {

3610 SmallVector sizes;

3611 sizes.reserve(tileSizes.size());

3612 unsigned dynamicIdx = 0;

3613

3614 for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {

3615 if (auto attr = llvm::dyn_cast_if_present(ofr)) {

3616 if (scalableSizes[ofrIdx]) {

3618 b, getLoc(), cast(attr).getInt());

3619 Value vscale =

3620 vector::VectorScaleOp::create(b, getLoc(), b.getIndexType());

3621 sizes.push_back(

3622 arith::MulIOp::create(b, getLoc(), val, vscale).getResult());

3623 } else {

3624 sizes.push_back(attr);

3625 }

3626 continue;

3627 }

3628 ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];

3629 ArrayRef<int64_t> params = paramSizes[dynamicIdx];

3630 ++dynamicIdx;

3631 assert((dynamicSizes.empty() ^ params.empty()) &&

3632 "expected either dynamic sizes or parameters");

3633 if (!params.empty()) {

3634 sizes.push_back(b.getIndexAttr(params[index]));

3635 } else {

3636 sizes.push_back(dynamicSizes[index]->getResult(0));

3637 }

3638 }

3639 return sizes;

3640 });

3641 }

3642

3643 tilingOptions.setInterchange(getInterchange());

3644 FailureOrscf::SCFTilingResult maybeTilingResult =

3645 tileUsingSCF(rewriter, tilingInterface, tilingOptions);

3646 if (failed(maybeTilingResult))

3648

3649 rewriter.replaceOp(op, maybeTilingResult->replacements);

3650

3651 tiled.append(maybeTilingResult->tiledOps);

3652 for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))

3653 loops[en2.index()].push_back(en2.value());

3654 }

3655

3656 transformResults.set(cast(getTiledLinalgOp()), tiled);

3657 for (const auto &en : llvm::enumerate(loops))

3658 transformResults.set(cast(getLoops()[en.index()]), en.value());

3659

3661}

3662

3663SmallVector transform::TileUsingForOp::getMixedSizes() {

3665 ArrayRef<int64_t> tileSizes = getStaticSizes();

3666 SmallVector results;

3667 results.reserve(tileSizes.size());

3668 unsigned dynamicPos = 0;

3670 for (int64_t size : tileSizes) {

3671 if (size == ShapedType::kDynamic) {

3672 results.push_back(dynamic[dynamicPos++]);

3673 } else {

3674 results.push_back(builder.getIndexAttr(size));

3675 }

3676 }

3677 return results;

3678}

3679

3680void transform::TileUsingForOp::getEffects(

3681 SmallVectorImplMemoryEffects::EffectInstance &effects) {

3684 producesHandle(getOperation()->getOpResults(), effects);

3686}

3687

3688

3689

3690

3691

3692void transform::TileUsingForallOp::build(OpBuilder &builder,

3694 ArrayRef<int64_t> staticTileSizes,

3695 transform::TileSizesSpec,

3697 return build(builder, result,

3699

3701 TileSizesSpec(),

3702 mapping);

3703}

3704

3705void transform::TileUsingForallOp::build(OpBuilder &builder,

3707 ArrayRef mixedTileSizes,

3708 transform::TileSizesSpec,

3710 SmallVector<int64_t> staticTileSizes;

3711 SmallVector dynamicTileSizes;

3713

3714

3715

3716 MLIRContext *ctx = builder.getContext();

3717 auto operationType = transform::AnyOpType::get(ctx);

3719 build(builder, result,

3720 TypeRange{operationType, operationType},

3723 dynamicTileSizes,

3724 Value(),

3725 Value(),

3727 staticTileSizesAttr,

3728 mapping);

3729}

3730

3731void transform::TileUsingForallOp::build(OpBuilder &builder,

3733 ArrayRef<int64_t> staticNumThreads,

3734 transform::NumThreadsSpec,

3738 NumThreadsSpec(), mapping);

3739}

3740

3741void transform::TileUsingForallOp::build(OpBuilder &builder,

3743 ArrayRef mixedNumThreads,

3744 transform::NumThreadsSpec,

3746 SmallVector<int64_t> staticNumThreads;

3747 SmallVector dynamicNumThreads;

3749 staticNumThreads);

3750

3751

3752

3753 MLIRContext *ctx = builder.getContext();

3754 auto operationType = transform::AnyOpType::get(ctx);

3756 build(builder, result,

3757 TypeRange{operationType, operationType},

3759 dynamicNumThreads,

3761 Value(),

3762 Value(),

3763 staticNumThreadsAttr,

3765 mapping);

3766}

3767

3768

3769

3770static SmallVector

3776 AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);

3778 for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {

3780 rewriter, loc, normalizedUbExpr, {lb, ub, step});

3781 normalizedUbs.push_back(normalizedUb);

3782 }

3783 return normalizedUbs;

3784}

3785

3786

3787

3796 AffineExpr denormExpr = s0 + d0 * s1;

3798

3799 for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {

3802 denormalizedIvs.push_back(

3804 }

3805 return denormalizedIvs;

3806}

3807

3808

3809

3810

3811

3812

3813

3814

3816 scf::ForallOp loop) {

3820

3822 return loop;

3823 }

3824

3825 Location loc = loop.getLoc();

3832

3833 auto normalizedForallOp = scf::ForallOp::create(

3834 rewriter, loc, normalizedLbs, normalizedUbs, normalizedSteps,

3835 loop.getOutputs(), loop.getMapping(),

3837

3838 auto normalizedLoopIvs = normalizedForallOp.getInductionVars();

3840 Block *normalizedLoopBlock = normalizedForallOp.getBody();

3842

3844 denormalizeIndVar(rewriter, loc, normalizedLoopIvs, lbs, steps);

3845 argValues.append(normalizedForallOp.getRegionIterArgs().begin(),

3846 normalizedForallOp.getRegionIterArgs().end());

3847 Block *origLoopBlock = loop.getBody();

3848 rewriter.mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);

3849

3850 rewriter.replaceOp(loop, normalizedForallOp);

3851 return normalizedForallOp;

3852}

3853

3859 scf::SCFTilingResult &tilingResult) {

3860

3861 auto tileableOp = dyn_cast(target);

3862 if (!tileableOp) {

3864 transformOp.emitSilenceableError()

3865 << "only TilingInterface ops are supported";

3866 diag.attachNote(target->getLoc()) << "target op";

3867 return diag;

3868 }

3870 scf::SCFTilingOptions options;

3871 options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);

3872 if (!mixedNumThreads.empty()) {

3873 options.setNumThreads(mixedNumThreads);

3874 } else {

3875 options.setTileSizes(mixedTileSizes);

3876 }

3877 if (mapping) {

3878 options.setMapping(mapping.value().getValue());

3879 }

3880 FailureOrscf::SCFTilingResult maybeTilingResult =

3881 scf::tileUsingSCF(rewriter, tileableOp, options);

3882

3883 if (failed(maybeTilingResult))

3884 return transformOp.emitDefaultSilenceableFailure(tileableOp);

3885

3886 rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);

3887

3888 tilingResult = *maybeTilingResult;

3889

3890 if (mixedNumThreads.empty()) {

3891 auto generatedForallOp = castscf::ForallOp(tilingResult.loops.front());

3894 scf::ForallOp normalizedForallOp =

3896 tilingResult.loops.front() = normalizedForallOp;

3897 }

3898

3900}

3901

3906 auto transformOp = cast(getOperation());

3907

3908

3911

3912

3915 getPackedNumThreads()

3917 state, transformOp, mixedNumThreads, getPackedNumThreads())

3919 state, transformOp, mixedNumThreads, getMixedNumThreads());

3921 return status;

3923 status = getPackedTileSizes()

3925 state, transformOp, mixedTileSizes, getPackedTileSizes())

3927 state, transformOp, mixedTileSizes, getMixedTileSizes());

3929 return status;

3930

3932 scf::SCFTilingResult tilingResult;

3934 rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,

3935 getMapping(), tilingResult);

3936 if (diag.succeeded())

3937 return diag;

3938 tileOps.push_back(tilingResult.loops.front());

3939 tiledOps.append(tilingResult.tiledOps);

3940 }

3941

3942 transformResults.set(cast(getForallOp()), tileOps);

3943 transformResults.set(cast(getTiledOp()), tiledOps);

3944

3946}

3947

3948void transform::TileUsingForallOp::getEffects(

3949 SmallVectorImplMemoryEffects::EffectInstance &effects) {

3955 producesHandle(getOperation()->getOpResults(), effects);

3957}

3958

3959SmallVector TileUsingForallOp::getMixedNumThreads() {

3961 return getMixedValues(getStaticNumThreads(), getNumThreads(), b);

3962}

3963

3964SmallVector TileUsingForallOp::getMixedTileSizes() {

3967}

3968

3969LogicalResult TileUsingForallOp::verify() {

3970 int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +

3971 static_cast<int>(getPackedNumThreads() != Value());

3972 if (numThreadsSpec > 1)

3974 "num_threads and packed_num_threads are mutually exclusive");

3975 int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +

3976 static_cast<int>(getPackedTileSizes() != Value());

3977 if (tileSizesSpec > 1)

3979 "tile_sizes and packed_tile_sizes are mutually exclusive");

3980 if (numThreadsSpec == 0 && tileSizesSpec == 0)

3981 return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "

3982 "must be specified");

3984}

3985

3986

3987

3988

3989

3990void transform::VectorizeChildrenAndApplyPatternsOp::build(

3991 OpBuilder &builder, OperationState &result, Value target,

3992 bool foldTypeExtensionsIntoContract, bool vectorizePadding,

3993 bool vectorizeExtract, bool flatten1DDepthwiseConv) {

3995 if (foldTypeExtensionsIntoContract) {

3996 result.addAttribute(

3997 VectorizeChildrenAndApplyPatternsOp::

3998 getFoldTypeExtensionsIntoContractAttrName(result.name),

4000 }

4001 if (vectorizePadding) {

4002 result.addAttribute(

4003 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(

4006 }

4007 if (vectorizeExtract) {

4008 result.addAttribute(

4009 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(

4012 }

4013 if (flatten1DDepthwiseConv) {

4014 result.addAttribute(

4015 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(

4018 }

4019 result.addTypes(transform::AnyOpType::get(builder.getContext()));

4020}

4021

4022namespace {

4023

4024

4025struct VectorizationPattern : public RewritePattern {

4026 explicit VectorizationPattern(MLIRContext *context,

4027 bool vectorizeExtract = false,

4028 bool flattenConv = false)

4029 : RewritePattern(MatchAnyOpTypeTag(), 1, context),

4030 vectorizeNDExtract(vectorizeExtract),

4031 flatten1DDepthwiseConv(flattenConv) {}

4032 LogicalResult matchAndRewrite(Operation *op,

4033 PatternRewriter &rewriter) const override {

4036 "Unsupported Op, cannot vectorize");

4037 FailureOr vectorResults =

4038 vectorize(rewriter, op, {},

4039 {}, vectorizeNDExtract,

4040 flatten1DDepthwiseConv);

4041 if (failed(vectorResults))

4042 return failure();

4043 rewriter.replaceOp(op, vectorResults->replacements);

4045 }

4046

4047private:

4048

4049

4050 bool vectorizeNDExtract = false;

4051

4052

4053

4054 bool flatten1DDepthwiseConv = false;

4055};

4056}

4057

4058DiagnosedSilenceableFailure

4059transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(

4060 transform::TransformRewriter &rewriter, Operation *target,

4061 transform::ApplyToEachResultList &results,

4062 transform::TransformState &state) {

4063 if (target->hasTraitOpTrait::IsIsolatedFromAbove()) {

4064 auto diag = this->emitOpError("requires isolated-from-above targets");

4065 diag.attachNote(target->getLoc()) << "non-isolated target";

4067 }

4068

4070 RewritePatternSet patterns(ctx);

4071 patterns.add(ctx, getVectorizeNdExtract(),

4072 getFlatten_1dDepthwiseConv());

4073

4074 if (!getDisableTransferPermutationMapLoweringPatterns())

4076

4077 if (!getDisableMultiReductionToContractPatterns())

4079

4081

4082 patterns.add<linalg::LinalgCopyVTRForwardingPattern,

4083 linalg::LinalgCopyVTWForwardingPattern>(ctx,

4084 2);

4085 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);

4086 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);

4088

4089 patterns.add(ctx);

4090

4091 if (getFoldTypeExtensionsIntoContract())

4093

4094 if (getVectorizePadding()) {

4096

4097

4099 }

4101

4102 TrackingListener listener(state, *this);

4105 GreedyRewriteConfig().setListener(&listener))))

4106 return emitDefaultDefiniteFailure(target);

4107

4110}

4111

4112

4113

4114

4115

4116DiagnosedSilenceableFailure transform::VectorizeOp::apply(

4117 transform::TransformRewriter &rewriter,

4118 mlir::transform::TransformResults &transformResults,

4119 mlir::transform::TransformState &state) {

4120 auto targets = state.getPayloadOps(getTarget());

4121 if (std::empty(targets))

4123 auto transformOp = cast(getOperation());

4124 SmallVector<int64_t> vectorSizes;

4126 state, transformOp, getMixedVectorSizes(), vectorSizes);

4128 return status;

4129

4130

4131 for (Operation *target : targets) {

4134 << "Unsupported Op, cannot vectorize";

4135 }

4136 FailureOr vectorResults =

4138 getVectorizeNdExtract().value_or(false),

4139 false,

4140 getAssumeDynamicDimsMatchVecSizes().value_or(false),

4141 getCreateNamedContraction().value_or(false));

4142 if (failed(vectorResults)) {

4144 << "Attempted to vectorize, but failed";

4145 }

4146 rewriter.replaceOp(target, vectorResults->replacements);

4147 }

4148

4150}

4151

4152void transform::VectorizeOp::getEffects(

4153 SmallVectorImplMemoryEffects::EffectInstance &effects) {

4157}

4158

4159SmallVector VectorizeOp::getMixedVectorSizes() {

4161 return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);

4162}

4163

4164LogicalResult transform::VectorizeOp::verify() {

4165 if (getStaticVectorSizes().size() != getScalableSizes().size())

4166 return emitOpError("expected same number of vector sizes (")

4167 << getStaticVectorSizes().size() << ") and scalable sizes ("

4168 << getScalableSizes().size() << ")";

4170}

4171

4172

4173

4174

4175

4176DiagnosedSilenceableFailure

4177transform::HoistRedundantVectorTransfersOp::applyToOne(

4178 transform::TransformRewriter &rewriter, func::FuncOp target,

4179 transform::ApplyToEachResultList &results,

4180 transform::TransformState &state) {

4181

4182

4183

4187}

4188

4189

4190

4191

4192

4193DiagnosedSilenceableFailure

4194transform::HoistRedundantVectorBroadcastsOp::applyToOne(

4195 transform::TransformRewriter &rewriter, mlir::Operation *target,

4196 transform::ApplyToEachResultList &results,

4197 transform::TransformState &state) {

4202}

4203

4204

4205

4206

4207

4208DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(

4209 transform::TransformRewriter &rewriter, linalg::LinalgOp target,

4210 transform::ApplyToEachResultList &results,

4211 transform::TransformState &state) {

4213 auto maybeTransformed =

4216 .Case([&](linalg::Conv2DNhwcHwcfOp op) {

4218 })

4219 .Case([&](linalg::Conv2DNhwcFhwcOp op) {

4221 })

4222 .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {

4224 })

4225 .Case([&](linalg::Conv2DNchwFchwOp op) {

4227 })

4228 .Default([&](Operation *op) {

4230 });

4231 if (failed(maybeTransformed))

4232 return emitDefaultSilenceableFailure(target);

4233

4234 results.push_back(maybeTransformed->first);

4235

4236 results.push_back(maybeTransformed->second);

4238}

4239

4240

4241

4242

4243

4244DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(

4245 transform::TransformRewriter &rewriter, linalg::LinalgOp target,

4246 transform::ApplyToEachResultList &results,

4247 transform::TransformState &state) {

4251 << "only elementwise flattening is supported";

4252

4253

4254 if (target.getNumLoops() <= 1) {

4257 }

4258

4259

4261 std::iota(reassociation.begin(), reassociation.end(), 0);

4262 auto maybeFlattened =

4264 if (failed(maybeFlattened))

4266 << "attempted to flatten, but failed";

4267 results.push_back(maybeFlattened->collapsedOp);

4270}

4271

4272

4273

4274

4275

4276DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(

4277 transform::TransformRewriter &rewriter, linalg::LinalgOp target,

4278 transform::ApplyToEachResultList &results,

4279 transform::TransformState &state) {

4281 auto maybeTransformed =

4283 .Case([&](linalg::Conv2DNhwcFhwcOp op) {

4285 })

4286 .Case([&](linalg::Conv2DNhwcFhwcQOp op) {

4288 })

4289 .Default([&](Operation *op) {

4291 });

4292 if (failed(maybeTransformed))

4293 return emitDefaultSilenceableFailure(target);

4294

4295 results.push_back(*maybeTransformed);

4297}

4298

4299

4300

4301

4302

4303DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(

4304 transform::TransformRewriter &rewriter, linalg::LinalgOp target,

4305 transform::ApplyToEachResultList &results,

4306 transform::TransformState &state) {

4308 bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;

4309 auto maybeTransformed =

4311 .Case([&](linalg::MatmulOp op) {

4313 })

4314 .Case([&](linalg::BatchMatmulOp op) {

4316 })

4317 .Default([&](Operation *op) { return failure(); });

4318 if (failed(maybeTransformed))

4320

4321 results.push_back(*maybeTransformed);

4323}

4324

4325

4326

4327

4328template

4329static DiagnosedSilenceableFailure

4333 static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,

4334 tensor::ParallelInsertSliceOp>() &&

4335 "wrong op type");

4336

4337 if (auto copySource =

4338 target.getSource().template getDefiningOplinalg::CopyOp()) {

4341 }

4342

4343

4344

4345

4346 if (isamlir::ParallelCombiningOpInterface(target.getOperation()))

4348

4349 Value extracted = tensor::ExtractSliceOp::create(

4350 rewriter, target.getLoc(), target.getDest(), target.getMixedOffsets(),

4351 target.getMixedSizes(), target.getMixedStrides());

4352 Value copied = linalg::CopyOp::create(rewriter, target.getLoc(),

4353 target.getSource(), extracted)

4354 .getResult(0);

4355

4359 target.getMixedSizes(), target.getMixedStrides());

4360

4363}

4364

4365DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(

4366 transform::TransformRewriter &rewriter, Operation *targetOp,

4367 transform::ApplyToEachResultList &results,

4368 transform::TransformState &state) {

4369

4371 if (auto target = dyn_casttensor::InsertSliceOp(targetOp))

4372 return doit(rewriter, target, results, state);

4373 if (auto target = dyn_casttensor::ParallelInsertSliceOp(targetOp))

4374 return doit(rewriter, target, results, state);

4375

4376 DiagnosedSilenceableFailure diag =

4377 emitSilenceableError()

4378 << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";

4379 diag.attachNote(targetOp->getLoc()) << "target op";

4380 return diag;

4381}

4382

4383

4384

4385

4386

4387DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(

4388 transform::TransformRewriter &rewriter, Operation *target,

4389 transform::ApplyToEachResultList &results,

4390 transform::TransformState &state) {

4391

4392 if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {

4393 DiagnosedSilenceableFailure diag =

4394 emitSilenceableError()

4395 << "only linalg.copy and tensor.pad target ops are supported";

4396 diag.attachNote(target->getLoc()) << "target op";

4397 return diag;

4398 }

4399 assert(target->getNumResults() == 1 && "expected single result");

4400 auto resultShapedType = cast(target->getResult(0).getType());

4401 if (!resultShapedType.hasStaticShape()) {

4402 DiagnosedSilenceableFailure diag =

4403 emitSilenceableError()

4404 << "only statically sized ops of rank <= 3 are supported";

4405 diag.attachNote(target->getLoc()) << "target op";

4406 return diag;

4407 }

4408

4409

4410 int64_t desiredBitAlignment = getDesiredBitAlignment();

4411 int64_t eltBitwidth =

4412 resultShapedType.getElementType().getIntOrFloatBitWidth();

4413 if (desiredBitAlignment % eltBitwidth != 0) {

4414 desiredBitAlignment = eltBitwidth;

4415 }

4416

4417 gpu::CopyMappingInfo mapping(

4419 getTotalNumThreads(),

4420 desiredBitAlignment,

4421 resultShapedType.getShape(),

4422 false,

4423

4424 resultShapedType.getElementType().getIntOrFloatBitWidth());

4425 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {

4426 DiagnosedSilenceableFailure diag =

4427 emitSilenceableError()

4428 << "too few threads to map copy op to threads on the most minor "

4429 "dimension, given alignment and vector size constraints, try "

4430 "smaller tile size of mapping to more threads";

4431 diag.attachNote(target->getLoc()) << "target op";

4432 return diag;

4433 }

4434

4435

4437 scf::SCFTilingResult tilingResult;

4439 rewriter,

4440 state,

4441 *this,

4443 getMixedValues(mapping.numThreads, {}, b),

4444 ArrayRef{},

4445 b.getArrayAttr(mapping.threadMapping),

4446 tilingResult);

4447 if (diag.succeeded())

4448 return diag;

4449

4450 results.push_back(tilingResult.loops.front());

4451 for (auto op : tilingResult.tiledOps)

4454}

4455

4456

4457

4458

4459

4460DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(

4461 transform::TransformRewriter &rewriter, linalg::LinalgOp target,

4462 transform::ApplyToEachResultList &results,

4463 transform::TransformState &state) {

4465 FailureOr<Operation *> maybeTransformed = failure();

4467 .Case([&](linalg::Conv2DNhwcFhwcOp op) {

4468 maybeTransformed =

4470 return true;

4471 })

4472 .Default([&](Operation *op) { return false; });

4473

4474 if (!supported) {

4475 return emitSilenceableError()

4476 << "this operation is not supported to convert to Winograd Conv2D";

4477 }

4478

4479 if (failed(maybeTransformed)) {

4480 return emitSilenceableError() << "apply Winograd Conv2D failed";

4481 }

4482

4483 results.push_back(*maybeTransformed);

4485}

4486

4487DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(

4488 transform::TransformRewriter &rewriter, Operation *target,

4489 transform::ApplyToEachResultList &results,

4490 transform::TransformState &state) {

4492 FailureOr<Operation *> maybeTransformed = failure();

4493 bool supported =

4495 .Case([&](linalg::WinogradFilterTransformOp op) {

4497 return true;

4498 })

4499 .Case([&](linalg::WinogradInputTransformOp op) {

4501 return true;

4502 })

4503 .Case([&](linalg::WinogradOutputTransformOp op) {

4505 return true;

4506 })

4507 .Default(false);

4508

4509 if (!supported) {

4510 DiagnosedSilenceableFailure diag =

4511 emitSilenceableError()

4512 << "this operation is not supported to decompose into other operations";

4513 diag.attachNote(target->getLoc()) << "target op";

4514 return diag;

4515 }

4516

4517 if (failed(maybeTransformed)) {

4518 DiagnosedSilenceableFailure diag =

4519 emitSilenceableError() << "decompose Winograd operations failed";

4520 diag.attachNote(target->getLoc()) << "target op";

4521 return diag;

4522 }

4523

4524 results.push_back(*maybeTransformed);

4526}

4527

4528#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"

4529

4530#define GET_OP_CLASSES

4531#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"

static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)

Maps the 2-dim vector shape to the two 16-bit tile sizes.

p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")

Given a list of lists of parsed operands, populates uniqueOperands with unique operands.

static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)

Copies the given number of bytes from src to dst pointers.

static SmallVector< Value > denormalizeIndVar(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > steps)

When a loop is normalized, the uses of the induction variable within the loop need to replaced with o...

Definition LinalgTransformOps.cpp:3788

b

Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...

Definition LinalgTransformOps.cpp:2097

static Operation * cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)

Definition LinalgTransformOps.cpp:1182

static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(TransformState &state, TransformOpInterface &transformOp, ArrayRef< OpFoldResult > mixedResults, SmallVectorImpl< int64_t > &reified)

When possible, converts each OpFoldResult in mixedResult to an integer if the value can be statically...

Definition LinalgTransformOps.cpp:172

static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs)

Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to e...

Definition LinalgTransformOps.cpp:86

TypeRange

Definition LinalgTransformOps.cpp:2099

b ValueRange

Definition LinalgTransformOps.cpp:2103

static SmallVector< OpFoldResult > normalizeUpperBounds(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > steps)

Given lbs, ubs and steps of loops, return (for each loop), the normalized upper bound.

Definition LinalgTransformOps.cpp:3771

static std::tuple< SmallVector< Operation * >, Operation * > tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)

Find the first "extract" user of producerOp and tile it right before its use.

Definition LinalgTransformOps.cpp:957

static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter, scf::ForallOp loop)

Given a scf.forall loop return a loop op with the loop bounds normalized. TODO: Replace this with a g...

Definition LinalgTransformOps.cpp:3815

target

Definition LinalgTransformOps.cpp:2100

#define DOWNSCALE_NORMAL(a, b)

static bool mayBeRead(OpOperand &operand)

Return true if the operand may be read from by its owner.

Definition LinalgTransformOps.cpp:394

static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type)

Definition LinalgTransformOps.cpp:1685

result

Definition LinalgTransformOps.cpp:2098

static bool sameOrEquivalentIterArg(Value src, Value dst)

Given two operands coming from a loop iter arg, 'src' and 'dst', return true if the operand 'src' is ...

Definition LinalgTransformOps.cpp:907

static SmallVector< Operation * > tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)

First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp,...

Definition LinalgTransformOps.cpp:1080

static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser, Type &targetType, Type &tileSizesType, Type &chunkSizesType)

Definition LinalgTransformOps.cpp:3413

static Operation * replaceForAllWithNewSignature(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)

Add new operands to the forall op for users of the producerOp that are dominated by the containing sc...

Definition LinalgTransformOps.cpp:825

static ParseResult parseMultitileSizesTypes(OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType)

Definition LinalgTransformOps.cpp:1691

static FailureOr< LinalgOp > tryApply(Operation *operation, Args &&...args)

Attempts to apply the pattern specified as template argument to the given operation.

Definition LinalgTransformOps.cpp:63
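
For illustration, a minimal sketch of invoking this helper; the pattern and convolution op types below are illustrative choices, and any pattern exposing returningMatchAndRewrite works:

// Sketch: apply a returningMatchAndRewrite-style pattern to `target`.
FailureOr<LinalgOp> decomposed = tryApply<
    DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
                                          linalg::Conv1DNwcWcfOp>>(target);
if (failed(decomposed))
  return emitSilenceableFailure(target->getLoc())
         << "failed to decompose the convolution";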

static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type tileSizes, Type)

Definition LinalgTransformOps.cpp:3407

static DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state)

Definition LinalgTransformOps.cpp:4330

static LogicalResult applyTilingToAll(RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn)

Apply a tiling transformation to all payload ops and store both the tiled operation as well as the cr...

Definition LinalgTransformOps.cpp:665

Returns failure if copies could not be generated due to yet unimplemented cases. copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method, the output argument nBegin is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`, ...

static std::string diag(const llvm::Value &value)

memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values))

static llvm::ManagedStatic< PassManagerOptions > options

static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)

Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...

Base type for affine expression.

virtual Builder & getBuilder() const =0

Return a builder which provides useful access to MLIRContext, global objects like types and attribute...

virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0

Parse a named dictionary into 'result' if it is present.

virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0

Emit a diagnostic at the specified location and return failure.

ParseResult parseInteger(IntT &result)

Parse an integer value from the stream.

virtual ParseResult parseColonType(Type &result)=0

Parse a colon followed by a type.

virtual SMLoc getCurrentLocation()=0

Get the location of the next token and store it into the argument.

virtual ParseResult parseType(Type &result)=0

Parse a type.

virtual ParseResult parseComma()=0

Parse a , token.

ParseResult parseKeyword(StringRef keyword)

Parse a given keyword.

Attributes are known-constant values of operations.

This class represents an argument of a Block.

Block represents an ordered list of Operations.

Operation * getParentOp()

Returns the closest surrounding operation that contains this block.

This class is a general helper class for creating context-global objects like types,...

IntegerAttr getIndexAttr(int64_t value)

IntegerAttr getIntegerAttr(Type type, int64_t value)

DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)

AffineExpr getAffineSymbolExpr(unsigned position)

IntegerAttr getI64IntegerAttr(int64_t value)

Ty getType(Args &&...args)

Get or construct an instance of the type Ty with provided arguments.

MLIRContext * getContext() const

ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)

ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)

Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)

Attaches a note to the error.

The result of a transform IR operation application.

static DiagnosedSilenceableFailure success()

Constructs a DiagnosedSilenceableFailure in the success state.

bool isDefiniteFailure() const

Returns true if this is a definite failure.

static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)

Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...

bool succeeded() const

Returns true if this is a success.

static DiagnosedSilenceableFailure definiteFailure()

Constructs a DiagnosedSilenceableFailure in the failure state.
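
A minimal sketch of how these states are used in a transform op implementation (the op is hypothetical): a silenceable failure reports a precondition mismatch the surrounding interpreter may choose to ignore, while a definite failure aborts the whole transform script.

DiagnosedSilenceableFailure
MyTransformOp::applyToOne(transform::TransformRewriter &rewriter,
                          LinalgOp target,
                          transform::ApplyToEachResultList &results,
                          transform::TransformState &state) {
  if (!isElementwise(target))
    return emitSilenceableFailure(target->getLoc())
           << "expected an elementwise linalg op";
  results.push_back(target);
  return DiagnosedSilenceableFailure::success();
}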

This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.

A class for computing basic dominance information.

bool dominates(Operation *a, Operation *b) const

Return true if operation A dominates operation B, i.e.

This is a utility class for mapping one set of IR entities to another.

void map(Value from, Value to)

Inserts a new mapping for 'from' to 'to'.

IRValueT get() const

Return the current value being used by this operand.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

NamedAttribute represents a combination of a name and an Attribute value.

The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...

virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0

Resolve an operand to an SSA value, emitting an error on failure.

virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0

Parse a single operand if present.

virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0

Parse a single SSA value operand name along with a result number if allowResultNumber is true.

This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...

virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0

If the specified operation has attributes, print out an attribute dictionary with their values.

void printFunctionalType(Operation *op)

Print the complete type of an operation in functional form.

bool isSet() const

Returns true if this insert point is set.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

Operation * clone(Operation &op, IRMapping &mapper)

Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
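
A minimal sketch of the usual map/clone pairing (producerOp and replacement are illustrative names):

// Sketch: deep-copy producerOp while redirecting its first operand;
// any operand not in the mapping is kept as-is.
IRMapping mapping;
mapping.map(producerOp->getOperand(0), replacement);
Operation *cloned = rewriter.clone(*producerOp, mapping);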

void setListener(Listener *newListener)

Sets the listener of this builder to the one provided.

void setInsertionPointToStart(Block *block)

Sets the insertion point to the start of the specified block.

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

Listener * getListener() const

Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.

void setInsertionPointAfter(Operation *op)

Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
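
A common pairing of these setters with the RAII guard (containingOp is an illustrative name):

// Sketch: temporarily move the insertion point; the guard restores the
// previous point when the scope ends, so the repositioning cannot leak.
{
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(&containingOp->getRegion(0).front());
  // ... ops created here land at the start of that block ...
}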

This class represents a single result from folding an operation.

This class represents an operand of an operation.

unsigned getOperandNumber()

Return which operand this is in the OpOperand list of the Operation.

This is a value defined by a result of an operation.

StringRef getStringRef() const

Return the name of this operation. This always succeeds.

Operation is the basic unit of execution within MLIR.

OpResult getOpResult(unsigned idx)

Attribute getAttr(StringAttr name)

Return the specified attribute if present, null otherwise.

void setOperand(unsigned idx, Value value)

bool hasAttr(StringAttr name)

Return true if the operation has an attribute with the provided name, false otherwise.

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

Location getLoc()

The source location the operation was defined or derived from.

InFlightDiagnostic emitError(const Twine &message={})

Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...

OperationName getName()

The name of an operation is the key identifier for it.

operand_type_range getOperandTypes()

result_type_range getResultTypes()

bool isAncestor(Operation *other)

Return true if this operation is an ancestor of the other operation.

std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)

Walk the operation by calling the callback for each nested operation (including this one),...

user_range getUsers()

Returns a range of all users.

result_range getOpResults()

bool isProperAncestor(Operation *other)

Return true if this operation is a proper ancestor of the other operation.

MLIRContext * getContext()

Return the context this operation is associated with.

unsigned getNumResults()

Return the number of results held by this operation.

bool has_value() const

Returns true if we contain a valid ParseResult value.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...

virtual void eraseBlock(Block *block)

This method erases all operations in a block.

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements...

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)

Find uses of from and replace them with to if the functor returns true.

void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)

Find uses of from and replace them with to except if the user is exceptedUser.

void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})

Inline the operations of block 'source' into the end of block 'dest'.

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure,...

void modifyOpInPlace(Operation *root, CallableT &&callable)

This method is a utility wrapper around an in-place modification of an operation.

virtual void replaceAllUsesWith(Value from, Value to)

Find uses of from and replace them with to.

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...
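
A small sketch combining these rewriter helpers (containingOp is an illustrative name):

// Sketch: redirect only the uses of `from` that are nested within
// containingOp; all other uses keep the old value.
rewriter.replaceUsesWithIf(from, to, [&](OpOperand &use) {
  return containingOp->isProperAncestor(use.getOwner());
});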

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

This class provides an abstraction over the different types of ranges over Values.

Type front()

Return first type in the range.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

bool use_empty() const

Returns true if this value has no uses.

Type getType() const

Return the type of this value.

use_range getUses() const

Returns a range of all uses, which is useful for iterating over all uses.

user_range getUsers() const

Location getLoc() const

Return the location of this value.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)

State for analysis-enabled bufferization.

Operation * getOwner() const

Return the owner of this operand.

A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...

void assign(unsigned size, std::nullptr_t)

Sets the list of results to size null pointers.

void reserve(unsigned size)

Reserves space for size elements in the list.

size_t size() const

Returns the number of elements in the list.

void push_back(Operation *op)

Appends an element to the list.

Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...

void setValues(OpResult handle, Range &&values)

Indicates that the result of the transform IR op at the given position corresponds to the given range...

void setParams(OpResult value, ArrayRef< TransformState::Param > params)

Indicates that the result of the transform IR op at the given position corresponds to the given list ...

void set(OpResult value, Range &&ops)

Indicates that the result of the transform IR op at the given position corresponds to the given list ...

This is a special rewriter to be used in transform op implementations, providing additional helper fu...

LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)

Notify the transform dialect interpreter that the given op has been replaced with another op and that...

The state maintained across applications of various ops implementing the TransformOpInterface.

auto getPayloadOps(Value value) const

Returns an iterator that enumerates all ops that the given transform IR value corresponds to.

auto getPayloadValues(Value handleValue) const

Returns an iterator that enumerates all payload IR values that the given transform IR value correspon...

ArrayRef< Attribute > getParams(Value value) const

Returns the list of parameters that the given transform IR value corresponds to.

AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)

Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...

SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)

Variant of makeComposedFoldedAffineApply suitable for multi-result maps.

OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)

Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
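
A minimal sketch, assuming a builder, loc, and two OpFoldResults offset and size are in scope:

// Sketch: compute offset + size; the result folds to an IntegerAttr
// whenever both operands are statically known.
AffineExpr s0, s1;
bindSymbols(builder.getContext(), s0, s1);
OpFoldResult end = affine::makeComposedFoldedAffineApply(
    builder, loc, s0 + s1, {offset, size});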

LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)

Analyze op and its nested ops.

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)

Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.

LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)

Pad the iterator dimensions options.paddingDimensions of all opToPad operands to a static bounding bo...

FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)

Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....

bool hasVectorizationImpl(Operation *)

Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...

FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)

Rewrite linalg.winograd_filter_transform.

std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)

Allocate the subview in the GPU workgroup memory.

FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)

Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...

Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)

Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....

FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false, bool assumeDynamicDimsMatchVecSizes=false, bool createNamedContraction=false)

Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
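
A minimal usage sketch, assuming a rewriter and a linalgOp with two vectorizable dimensions (the sizes are illustrative, and the `replacements` member name should be treated as an assumption about the current VectorizationResult API):

// Sketch: vectorize with explicit vector sizes; the remaining flags
// keep their defaults.
FailureOr<linalg::VectorizationResult> vecResult =
    linalg::vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{8, 32});
if (failed(vecResult))
  return failure();
rewriter.replaceOp(linalgOp, vecResult->replacements);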

FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)

Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.

FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)

Create a namedOp from the given GenericOp and replace the GenericOp.

FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)

Rewrite unpack as empty + transpose + reshape + extract_slice.

void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)

Populates patterns with patterns that vectorize tensor.pad.

void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)

LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)

In case of GPU private memory there is no need to deallocate since the memory is freed when going out...

FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)

Rewrite linalg.winograd_output_transform.

std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn

Function signature to control reduction splitting.

std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)

Allocate the subview in the GPU private memory.

FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)

Rewrite tensor.from_elements to linalg.generic.

FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, WinogradConv2DFmr fmr)

Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).

FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)

Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.

void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)

Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...

LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)

Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...

LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)

Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.

FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)

Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.

FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)

Pattern to replace linalg.batch_matmul with a transposed variant.

LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)

Promote memref.subviews feeding linalg-on-buffers operations.

LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)

Normal copy between src and dst.

bool isElementwise(LinalgOp op)

Check if a LinalgOp is an element-wise operation.

FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)

Interchange the iterator_types and iterator_maps dimensions and adapt the index accesses of op.

FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)

void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)

Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.

FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)

FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)

FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)

void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)

Populates patterns with patterns that fold operations like linalg.pack and linalg....

void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn=nullptr)

Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...

void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)

Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...

FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)

Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...

FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)

Implement packing of a single LinalgOp by packedSizes.
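
A minimal sketch, assuming a matmul payload op (the tile sizes are illustrative):

// Sketch: pack the (m, n, k) dimensions of a matmul by (8, 16, 32).
SmallVector<OpFoldResult> packedSizes =
    getAsIndexOpFoldResult(rewriter.getContext(), {8, 16, 32});
FailureOr<linalg::PackResult> packed =
    linalg::pack(rewriter, matmulOp, packedSizes);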

void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)

Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.

FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)

Promote the subViews into a new buffer allocated at the insertion point b.

LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)

In case of GPU group memory there is no need to deallocate.

FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)

Convert Linalg matmul ops to transposed variants.

FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)

Collapses dimensions of linalg.generic/linalg.copy operation.

void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)

Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...

FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)

Rewrite linalg.winograd_input_transform.

void populateDecomposePadPatterns(RewritePatternSet &patterns)

Populates patterns to decompose tensor.pad into e.g.

void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)

Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...

std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)

Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
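
A minimal sketch (the split point is illustrative):

// Sketch: split dimension 0 of a TilingInterface op at iteration-space
// point 42; the second half may be null when the split folds away.
auto [firstPart, secondPart] = linalg::splitOp(
    rewriter, cast<TilingInterface>(op), /*dimension=*/0,
    rewriter.getIndexAttr(42));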

FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)

Scaling-based implementation of the split reduction transformation.

FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)

Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
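
As a worked example (illustrative numbers): for an iteration space of size 15 with targetSize 8 and divisor 1, one valid specification is a low tile size of 7 taken once and a high tile size of 8 taken once, since 7 * 1 + 8 * 1 = 15, neither size exceeds the target, and both are divisible by the divisor.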

FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)

Rewrite pack as pad + reshape + transpose.

ForOp getForInductionVarOwner(Value val)

Returns the loop parent of an induction variable.

void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)

Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.

void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)

Appends patterns that are used to bubble up a tensor.extract_slice op above its producer.

OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)

Return the dimension of the given tensor value.

SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)

Return the dimensions of the given tensor value.

LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)

This is a helper function for DestinationStyleOpInterface.

void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)

Appends patterns for folding tensor subset ops into vector transfer ops.

void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

DiagnosedSilenceableFailure tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, Operation *target, ArrayRef< OpFoldResult > mixedNumThreads, ArrayRef< OpFoldResult > mixedTileSizes, std::optional< ArrayAttr > mapping, scf::SCFTilingResult &tilingResult)

Implementation of tiling operations using scf.forall.

Definition LinalgTransformOps.cpp:3854

void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

Populates effects with the memory effects indicating the operation on the given handle value:

void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

Populates effects with the memory effects indicating the access to payload IR resource.

void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)

Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....

void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)

Populate the pattern set with the following patterns:

void populateFoldArithExtensionPatterns(RewritePatternSet &patterns)

Collect a set of patterns that fold arithmetic extension on floating point into vector contract for t...

void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)

Patterns that remove redundant Vector Ops by re-ordering them with e.g.

void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)

Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...

Include the generated interface declarations.

bool matchPattern(Value value, const Pattern &pattern)

Entry point for matching a pattern over a Value.

SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)

Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...

@ PartialReductionOuterReduction

@ PartialReductionOuterParallel

OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)

Convert int64_t to integer attributes of index type and return them as OpFoldResult.

detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr

std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)

If ofr is a constant integer or an IntegerAttr, return the integer.

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)

Bind a list of AffineExpr references to DimExpr at positions: [0 .

LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)

Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
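
A minimal driver sketch; the populated pattern sets are illustrative picks from the utilities listed on this page:

// Sketch: collect patterns and run them to a fixed point on an
// isolated-from-above region.
RewritePatternSet patterns(ctx);
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
if (failed(applyPatternsGreedily(target->getRegion(0), std::move(patterns))))
  return failure();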

DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})

Emits a silenceable failure with the given message.

llvm::DenseSet< ValueT, ValueInfoT > DenseSet

Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)

This parses a single MLIR attribute to an MLIR context if it was valid.

llvm::SetVector< T, Vector, Set, N > SetVector

DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})

Emits a definite failure with the given message.

Type getElementTypeOrSelf(Type type)

Return the element type or return the type itself.

const FrozenRewritePatternSet & patterns

bool isZeroInteger(OpFoldResult v)

Return true if v is an IntegerAttr with value 0.

void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)

Bind a list of AffineExpr references to SymbolExpr at positions: [0 .

void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)

Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
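
A round-trip sketch, assuming a vector of mixed OpFoldResults named mixedSizes:

// Sketch: split mixed sizes into static ints + dynamic values, then
// rebuild them; getMixedValues is the inverse of the dispatch helper.
SmallVector<Value> dynamicVec;
SmallVector<int64_t> staticVec;
dispatchIndexOpFoldResults(mixedSizes, dynamicVec, staticVec);
SmallVector<OpFoldResult> roundTripped =
    getMixedValues(staticVec, dynamicVec, builder.getContext());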

llvm::TypeSwitch< T, ResultT > TypeSwitch

Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)

Converts an OpFoldResult to a Value.

Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)

OpFoldResult getAsOpFoldResult(Value val)

Given a value, try to extract a constant Attribute.
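
A small sketch of the usual pairing with getConstantIntValue (the consumers in the comments are hypothetical):

// Sketch: prefer the attribute form when the size is statically known.
OpFoldResult ofr = getAsOpFoldResult(sizeValue);
if (std::optional<int64_t> cst = getConstantIntValue(ofr)) {
  // Statically known: use *cst directly.
} else {
  Value dynamicSize = cast<Value>(ofr);
  // Dynamic: thread the SSA value through instead.
}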

SmallVector< int64_t, 2 > ReassociationIndices

detail::constant_op_matcher m_Constant()

Matches a constant foldable operation.

SmallVector< IntTy > extractFromIntegerArrayAttr(Attribute attr)

Extract integer values from the assumed ArrayAttr of IntegerAttr.

llvm::function_ref< Fn > function_ref

bool isPermutationVector(ArrayRef< int64_t > interchange)

Method to check if an interchange vector is a permutation.

bool isOneInteger(OpFoldResult v)

Return true if v is an IntegerAttr with value 1.

This class represents a listener that may be used to hook into various actions within an OpBuilder.

This represents an operation in an abstracted form, suitable for use with the builder APIs.

Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...

A listener that forwards all notifications to another listener.

ForwardingListener(OpBuilder::Listener *listener)

Container for result values of tiling.

SmallVector< Value > tiledValues

Options for analysis-enabled bufferization.

Transformation to drop unit-extent dimensions from linalg.generic operations.

Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...