MLIR source listing: lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
10
12
44 #include "llvm/ADT/STLExtras.h"
45 #include "llvm/ADT/ScopeExit.h"
46 #include "llvm/ADT/TypeSwitch.h"
47 #include "llvm/Support/Debug.h"
48 #include <type_traits>
49
50 using namespace mlir;
53
54 #define DEBUG_TYPE "linalg-transforms"
55 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
56 #define DBGSNL() (llvm::dbgs() << "\n")
57 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
58
59
60
61
62
63
64 template <typename PatternTy, typename... Args>
65 static FailureOr tryApply(Operation *operation, Args &&...args) {
66
67 using OpTy = typename llvm::function_traits<
68 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
69 auto op = dyn_cast(operation);
70 if (!op)
71 return failure();
72
73
74 PatternTy pattern(operation->getContext(), std::forward(args)...);
75
76
78 public:
79 explicit TrivialPatternRewriter(MLIRContext *context)
81 };
82 TrivialPatternRewriter rewriter(operation->getContext());
83 rewriter.setInsertionPoint(operation);
84 auto result = pattern.returningMatchAndRewrite(op, rewriter);
85 if (failed(result))
86 return failure();
87 return cast(result->getOperation());
88 }
89
90
91
92
97 if (auto attr = dyn_cast(ofr)) {
98 if (!isa(attr))
99 return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
100 result.push_back(ofr);
101 continue;
102 }
103
104 Value transformValue = cast(ofr);
105 if (isa(transformValue.getType())) {
107 if (params.size() != 1)
108 return transformOp.emitDefiniteFailure()
109 << "requires exactly one parameter associated";
110 result.push_back(params[0]);
111 continue;
112 }
113
114 auto payloadOps = state.getPayloadOps(transformValue);
115 if (!llvm::hasSingleElement(payloadOps)) {
117 transformOp.emitSilenceableError()
118 << "handle must be mapped to exactly one payload op";
119 diag.attachNote(transformValue.getLoc())
120 << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
122 }
123
124 Operation *op = *payloadOps.begin();
127 transformOp.emitSilenceableError()
128 << "payload op must have exactly 1 index result";
132 }
133 result.push_back(op->getResult(0));
134 }
135
137 }
138
139
140
141
142
143
144
148 if (isa(packedHandle.getType())) {
150 for (auto param : params) {
151 if (!isa(param))
152 return transformOp.emitDefiniteFailure()
153 << "expected the parameter to be associated with an integer "
154 "attribute";
155 result.push_back(param);
156 }
158 }
159
160 for (Operation *op : state.getPayloadOps(packedHandle)) {
161 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
163 transformOp.emitSilenceableError()
164 << "payload op must have exactly 1 index result";
165 diag.attachNote(op->getLoc())
166 << "has " << op->getNumResults() << " results";
168 }
169 result.push_back(op->getResult(0));
170 }
171
173 }
174
175
176
177
178
180 TransformState &state, TransformOpInterface &transformOp,
182 for (OpFoldResult paramOrHandle : mixedResults) {
183 if (auto attr = dyn_cast(paramOrHandle)) {
184 reified.push_back(cast(attr).getInt());
185 continue;
186 } else if (isa(cast(paramOrHandle).getType())) {
187 ArrayRef params = state.getParams(cast(paramOrHandle));
188 if (params.size() != 1)
189 return transformOp.emitSilenceableError() << "expected a single param";
190 reified.push_back(
191 cast(params.front()).getValue().getSExtValue());
192 continue;
193 }
194
195 Value handle = cast(paramOrHandle);
196 if (!isa(handle.getType()))
197 return transformOp.emitSilenceableError() << "unexpected value handle";
198 auto payload = state.getPayloadOps(handle);
199 if (!llvm::hasSingleElement(payload))
200 return transformOp.emitSilenceableError()
201 << "requires param or handle that is mapped to 1 payload op";
202
203 Operation *paramOrHandlePayloadOp = *payload.begin();
204 if (paramOrHandlePayloadOp->getNumResults() != 1 ||
206 return transformOp.emitSilenceableError()
207 << "requires param or handle to be result of op with 1 index "
208 "result";
209 }
210
211 IntegerAttr attr;
213 return transformOp.emitSilenceableError()
214 << "requires param or handle to be the result of a constant like "
215 "op";
216
217 reified.push_back(attr.getInt());
218 }
220 }
221
222
223
224
225
226 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
229 }
230
231 void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
234 }
235
236 void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
239 }
240
241 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
245 }
246
247 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
250 options.rankReductionStrategy =
253 }
254
255 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
258 }
259
260 void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
263 }
264
265 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
268 }
269
270 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
273 }
274
275 void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
278 }
279
280
281
282
283
284 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
289 resultTypes.push_back(b.getTypetransform::AnyValueType());
290 resultTypes.push_back(b.getTypetransform::AnyOpType());
291 return build(b, result,
292 resultTypes,
293 target,
294 memorySpace);
295 }
296
297 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
300 int64_t memorySpace) {
302 resultTypes.push_back(b.getTypetransform::AnyValueType());
303 resultTypes.push_back(b.getTypetransform::AnyOpType());
304 return build(b, result,
305 resultTypes,
306 target,
308 }
309
310 namespace {
312 public:
314
317 }
318
319 private:
320 void notifyOperationInserted(Operation *op,
322 ForwardingListener::notifyOperationInserted(op, previous);
323
324 if (previous.isSet())
325 return;
326 auto inserted = newOps.insert(op);
327 (void)inserted;
328 assert(inserted.second && "expected newly created op");
329 }
330
331 void notifyOperationErased(Operation *op) override {
332 ForwardingListener::notifyOperationErased(op);
333 op->walk([&](Operation *op) { newOps.erase(op); });
334 }
335
337 };
338 }
339
343
345 auto resetListener =
346 llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
347 NewOpsListener newOpsListener(previousListener);
349
351 if (getMemcpyOp() == "bufferization.materialize_in_destination") {
354 } else if (getMemcpyOp() == "memref.copy") {
357 } else if (getMemcpyOp() == "linalg.copy") {
360 } else {
361 llvm_unreachable("invalid memcpy op");
362 }
363 if (getAllocOp() == "memref.alloc") {
366 } else if (getAllocOp() == "memref.alloca") {
369 } else {
370 llvm_unreachable("invalid alloc op");
371 }
372 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
373 options.emitDealloc = getEmitDealloc();
374
375
377 getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
379 for (Operation *op : state.getPayloadOps(getTarget())) {
382 if (!buffer) {
384 << "failed to bufferize operation";
385 diag.attachNote(op->getLoc()) << "target payload op";
387 }
388 allocatedBuffers.push_back(buffer);
389 }
390
391
392 results.setValues(cast(getAllocatedBuffer()), allocatedBuffers);
393 results.set(cast(getNewOps()), newOpsListener.getNewOps());
395 }
396
397 void transform::BufferizeToAllocationOp::getEffects(
399 if (getBufferizeDestinationOnly()) {
400
401
403 } else {
405 }
406 producesHandle(getOperation()->getOpResults(), effects);
408 }
409
411 if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
412 getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
413 return emitOpError() << "unsupported memcpy op";
414 if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
415 return emitOpError() << "unsupported alloc op";
416 return success();
417 }
418
419
420
421
422
425 LinalgOp target,
428 #define DOWNSCALE(trans) \
429 { \
430 FailureOr res = tryApply(target); \
431 if (succeeded(res)) { \
432 results.push_back(*res); \
433 return DiagnosedSilenceableFailure::success(); \
434 } \
435 }
436
437 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
438 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
439
445 DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
447 DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
451 #undef DOWNSCALE_NORMAL
452 #undef DOWNSCALE_CALL
453 #undef DOWNSCALE
454 return emitDefaultSilenceableFailure(target);
455 }
456
457
458
459
460
461
462
463
468 auto decomposableOp = dyn_cast(target);
469 if (!decomposableOp) {
471 "payload is not a decomposable op"));
472 return emitDefaultSilenceableFailure(target);
473 }
474
475 FailureOr<SmallVector> maybeNewResults =
476 decomposableOp.decomposeOperation(rewriter);
477 if (failed(maybeNewResults))
478 return emitDefaultSilenceableFailure(target);
479
480 rewriter.replaceOp(decomposableOp, *maybeNewResults);
481 for (Value val : *maybeNewResults) {
482 Operation *definition = val.getDefiningOp();
483 if (definition)
485 }
487 }
488
489
490
491
492
493 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
497 }
498
500 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
504 options.allowReturnAllocsFromLoops = true;
505
506 for (Operation *target : state.getPayloadOps(getTarget())) {
508 if (failed(analyzeOp(target, state)))
510 << "failed to analyze op";
512 rewriter, target, state)))
514 << "failed to eliminate LinalgOp anchored tensor.empty ops";
515 }
517 }
518
519
520
521
522
523
524
525 template
529 function_ref<FailureOrscf::SCFTileAndFuseResult(TilingInterface)>
530 applyFn) {
533
534 for (Operation *target : payloadOps) {
535 auto tilingInterfaceOp = dyn_cast(target);
536 if (!tilingInterfaceOp)
537 return transformOp->emitError("only TilingInterface ops are supported");
538
540 FailureOrscf::SCFTileAndFuseResult tiledResults =
541 applyFn(tilingInterfaceOp);
542 if (failed(tiledResults))
543 return failure();
544
545
547 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
548 for (Operation *toReplace : opsToReplace) {
549 for (OpResult res : toReplace->getResults())
550 if (auto replacement = tiledResults->replacements.lookup(res))
552 if (toReplace->use_empty()) {
553 rewriter.eraseOp(toReplace);
554 }
555 }
556
557
558 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
559 assert(tiledResults->loops.size() == numLoops &&
560 "Mismatched number of loops, tile and fuse transform should have "
561 "failed");
562 for (unsigned int i = 0; i < numLoops; ++i)
563 loopOps[i].push_back(tiledResults->loops[i]);
564 }
565
566 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
567 for (unsigned int i = 0; i < numLoops; ++i)
568 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
569
570 return success();
571 }
572
578 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
580 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
581
586 tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
588 tileAndFuseOptions.tilingOptions = tilingOptions;
589
590 if (getApplyCleanup()) {
593 tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
597 }
598
600 rewriter, getOperation(), state.getPayloadOps(getTarget()),
601 tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
602 [&](TilingInterface tilingInterfaceOp)
603 -> FailureOrscf::SCFTileAndFuseResult {
604 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
605 tileAndFuseOptions);
606 });
609 }
610
613 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
614 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
615 if (!std::is_permutation(sequence.begin(), sequence.end(),
616 permutation.begin(), permutation.end())) {
617 return emitOpError() << "expects interchange to be a permutation, found "
618 << getTileInterchange();
619 }
620
622 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
623 size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
624 if (numExpectedLoops != getNumResults() - 1)
625 return emitOpError() << "expects " << numExpectedLoops << " loop results";
626
627 return success();
628 }
629
630
631
632
633
634 void transform::FuseIntoContainingOp::build(OpBuilder &builder,
636 Value producerOp,
637 Value containingOp) {
638 result.addOperands({producerOp, containingOp});
640 result.addTypes({resultType, resultType});
641 }
642
643
644
650
651
655 if (!containingOp->isAncestor(user) &&
656 (domInfo.dominates(containingOp, user))) {
657 dominatedUsers.insert(user);
658 }
659 }
660 if (dominatedUsers.empty())
661 return nullptr;
662
663
664 auto forallOp = castscf::ForallOp(containingOp);
667
668
669 Location loc = forallOp.getLoc();
670 auto genericOp = dyn_castlinalg::GenericOp(producerOp);
671 if (!genericOp)
672 return nullptr;
675 newOuts.push_back(outputs[resultNumber]);
676
677
678 auto newforallOp = rewriter.createscf::ForallOp(
679 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
680 forallOp.getMixedStep(), newOuts, forallOp.getMapping());
681 rewriter.eraseBlock(newforallOp.getBody());
682 newforallOp.getRegion().takeBody(forallOp.getRegion());
683
684
685
686
687 newforallOp.getBody()->addArgument(newOuts.back().getType(),
688 newOuts.back().getLoc());
689 auto bbArgs = newforallOp.getBody()->getArguments();
694 });
695
696
697 scf::InParallelOp terminatorOp = newforallOp.getTerminator();
699 terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
700 Operation *firstYieldOp = yieldingOps.front();
703 Value dst = newforallOp.getRegionIterArgs().back();
705 rewriter.createtensor::ParallelInsertSliceOp(firstYieldOp->getLoc(), src,
706 dst, offsets, sizes, strides);
707
708 for (auto result : llvm::enumerate(forallOp.getResults())) {
710 newforallOp->getResult(result.index()));
711 }
713 newforallOp->getResults().back(),
715 Operation *user = use.getOwner();
716 return dominatedUsers.contains(user);
717 });
718 return newforallOp;
719 }
720
721
722
723
724
725
727
728
730 destWorklist.push_back(dst);
731
732 while (!destWorklist.empty()) {
733 Value currentDst = destWorklist.pop_back_val();
734
735
736
737 if (src == currentDst)
738 return true;
739
740
741
742 auto bbArg = dyn_cast(currentDst);
743 if (!bbArg)
744 continue;
745
746 Block *parentBlock = bbArg.getOwner();
747 assert(parentBlock && "unlinked block argument");
748
750 assert(parentOp && "expected block argument with parent operation");
751
752
753 auto parentLoop = dyn_cast(parentOp);
754 if (!parentLoop)
755 continue;
756
757 for (auto innerIterArg : parentLoop.getRegionIterArgs()) {
758
759 OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
760 Value loopBlockArgument =
762 destWorklist.push_back(loopBlockArgument);
763 }
764 }
765
766 return false;
767 }
768
769
770
771
772
773
774
775 static std::tuple<SmallVector<Operation *>, Operation *>
778 LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
779 auto tileableProducer = dyn_cast(producerOp);
780 if (!tileableProducer) {
781 diag.attachNote(producerOp->getLoc())
782 << "producer is not a TileableInterface: " << *producerOp;
783 return {};
784 }
785
786
787
788
789 auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
790 auto sliceOp = dyn_casttensor::ExtractSliceOp(user);
791 return sliceOp && containingOp->isProperAncestor(sliceOp);
792 });
793
794
795 if (it == tileableProducer->getUsers().end()) {
796 diag.attachNote(tileableProducer->getLoc())
797 << "could not find fusion opportunity for: " << *tileableProducer;
798 return {};
799 }
800 auto sliceOpToTile = casttensor::ExtractSliceOp(*it);
801
802
805
806
807
808
809
810
811
812
813 if (LoopLikeOpInterface containerLoop =
814 dyn_cast(sliceOpToTile->getParentOp())) {
817
818
819
821 cast(clone).getDpsInitsMutable()) {
822 Value producerOperand =
825 containerLoop.getRegionIterArgs()) {
826 OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
827 Value consumerOperand =
828 containerLoop->getOperand(bbArg->getOperandNumber());
829
830 if (sameOrEquivalentIterArg(producerOperand, consumerOperand)) {
831 initOperandPtr.set(containerIterArg);
832 }
833 }
834 }
835 });
836
837 tileableProducer = dyn_cast(clone);
838 }
839
840
841 int64_t resultNumber =
842 cast(sliceOpToTile.getSource()).getResultNumber();
843 LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
844
847
848 FailureOr tileAndFuseResult =
849 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
850 sizes);
851
852 if (failed(tileAndFuseResult)) {
853 diag.attachNote(tileableProducer->getLoc())
854 << "failed to tile producer op: " << *tileableProducer;
855 return {};
856 }
857
858 #ifndef NDEBUG
859 for (auto *tiledOp : tileAndFuseResult->tiledOps) {
860 LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n");
861 }
862 #endif
863
864
865 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
866 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
867 cast(sliceOpToTile->getResult(0).getType()).getShape());
868 if (failed(maybeRankReduced)) {
869 diag.attachNote(producerOp->getLoc())
870 << "shape types don't match (missing canonicalization?):\nTiledOp: "
871 << tileAndFuseResult->tiledValues[0]
872 << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
873 return {};
874 }
875 rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
876
877
879 rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
880 resultNumber, offsets, sizes);
881
882
883 if (dyn_cast(containingOp))
884 rewriter.eraseOp(tileableProducer);
885
886 return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
887 }
888
889
890
891
892
893
894
899 LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");
900
901 auto tileableProducer = dyn_cast(producerOp);
902 if (!tileableProducer) {
903 diag.attachNote(producerOp->getLoc())
904 << "producer is not a TileableInterface: " << *producerOp;
905 return {};
906 }
907
908
909 scf::ForallOp forallOp;
910 auto itProducerUses =
911 llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
912 forallOp = dyn_castscf::ForallOp(use.getOwner());
913 return forallOp;
914 });
915
916 if (!forallOp || forallOp != containingOp) {
917 diag.attachNote(tileableProducer->getLoc())
918 << "could not find a use by the containing op: " << *tileableProducer;
919 return {};
920 }
921
922
923
924
925
926 OpOperand *pUse = &(*itProducerUses);
927 BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);
928
929
930
931
932 auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
933 auto sliceOp = dyn_casttensor::ExtractSliceOp(user);
934 return sliceOp && containingOp->isProperAncestor(sliceOp);
935 });
936
937
938 if (itBBArgUsers == bbArg.getUsers().end()) {
939 diag.attachNote(containingOp->getLoc())
940 << "could not find fusion opportunity for bbArg: " << bbArg;
941 return {};
942 }
943 auto sliceOpToTile = casttensor::ExtractSliceOp(*itBBArgUsers);
944
945
948
949
950
951 int64_t resultNumber = cast(pUse->get()).getResultNumber();
952 LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
953
954
957 rewriter, tileableProducer->getLoc(), tileableProducer,
958 destinationTensors))) {
959 diag.attachNote(tileableProducer->getLoc())
960 << "failed to get destination tensors for: " << *tileableProducer;
961 return {};
962 }
963
965 bvm.map(destinationTensors[resultNumber], bbArg);
966 auto tileableProducerClone =
967 cast(rewriter.clone(*tileableProducer, bvm));
968 auto scopeGuard =
969 llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
970
971
972 FailureOr tileAndFuseResult =
973 tileableProducerClone.generateResultTileValue(
974 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
975 sliceOpToTile.getMixedSizes());
976 if (failed(tileAndFuseResult)) {
977 diag.attachNote(tileableProducer->getLoc())
978 << "failed to tile producer op: " << *tileableProducer;
979 return {};
980 }
981
982
983 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
984 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
985 cast(sliceOpToTile->getResult(0).getType()).getShape());
986 assert(succeeded(maybeRankReduced) && "unexpected shape");
987 rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
988
989
992 destinationTensors.front());
993 });
994
995 return tileAndFuseResult->tiledOps;
996 }
997
1001 LLVM_DEBUG(DBGS() << "Try to fuse an use by cloning\n");
1002
1003
1006 for (OpOperand &use : result.getUses()) {
1008 uses.push_back(&use);
1009 continue;
1010 }
1011
1012
1013 if (containingOp == use.getOwner()) {
1014 diag.attachNote(producerOp->getLoc())
1015 << "producer op use by containing op cannot be fused by cloning";
1016 return nullptr;
1017 }
1018 }
1019 }
1020
1021
1022 if (uses.empty()) {
1023 diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
1024 return nullptr;
1025 }
1026
1027
1030
1031
1032 assert(!isatensor::ParallelInsertSliceOp(use->getOwner()) &&
1033 "Parallel insert slice is not a valid clone destination");
1034 unsigned resultNumber = cast(use->get()).getResultNumber();
1035 LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
1036
1039 fusedOp = rewriter.clone(*producerOp);
1041 use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
1042
1043 return fusedOp;
1044 }
1045
1046 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
1047
1048 return true;
1049 }
1050
1056 auto producerOps = state.getPayloadOps(getProducerOp());
1057 auto containingOps = state.getPayloadOps(getContainingOp());
1058 if (!llvm::hasSingleElement(containingOps)) {
1060 << "requires exactly one containing_op handle (got "
1061 << llvm::range_size(containingOps) << ")";
1062 }
1063 Operation *containingOp = *containingOps.begin();
1064
1065
1066 if (std::empty(producerOps)) {
1068 results.set(cast(getNewContainingOp()), {containingOp});
1070 }
1071
1072
1073
1075 auto getNextProducer = [&]() -> FailureOr<Operation *> {
1076 for (const auto &it : enumerate(remainingProducers)) {
1077 Operation *producerOp = it.value();
1078
1079 int64_t numUsesInContainingOp =
1081 return containingOp->isAncestor(op);
1082 });
1083
1084
1085
1086 if (numUsesInContainingOp > 0) {
1087 if (numUsesInContainingOp == 1)
1088 remainingProducers.erase(remainingProducers.begin() + it.index());
1089 return producerOp;
1090 }
1091 }
1092 return failure();
1093 };
1094
1095 while (!remainingProducers.empty()) {
1096 auto nextProducer = getNextProducer();
1097 if (failed(nextProducer)) {
1099 << "could not find next producer to fuse into container";
1100 diag.attachNote(containingOp->getLoc()) << "containing op";
1101 return diag;
1102 }
1103
1104 Operation *producerOp = *nextProducer;
1105
1106
1108 diag << "could not fuse " << *producerOp << " into " << *containingOp;
1109
1110
1111
1112
1113
1114
1115 auto [tiledOps, newContainingOp] =
1117 if (!tiledOps.empty()) {
1118 LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
1119 fusedOps.append(tiledOps);
1120 if (newContainingOp) {
1121
1122
1123
1124
1125
1126
1127
1128 LogicalResult replacementStatus =
1130 newContainingOp);
1131 (void)replacementStatus;
1132 assert(succeeded(replacementStatus) &&
1133 "unable to update transform state mapping");
1134 rewriter.eraseOp(containingOp);
1135 containingOp = newContainingOp;
1136 }
1137 continue;
1138 }
1139
1142 rewriter, diag, producerOp, containingOp);
1143 if (!tiledContainingOpOperand.empty()) {
1144 LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
1145 << *containingOp);
1146 fusedOps.append(tiledContainingOpOperand);
1147 continue;
1148 }
1149
1152 if (cloned) {
1153 LLVM_DEBUG(DBGS() << "\nFused an use by cloning\n" << *containingOp);
1154 fusedOps.push_back(cloned);
1155 continue;
1156 }
1158 }
1159
1160 results.set(cast(getFusedOp()), fusedOps);
1161 results.set(cast(getNewContainingOp()), {containingOp});
1163 }
1164
1165 void transform::FuseIntoContainingOp::getEffects(
1169 producesHandle(getOperation()->getOpResults(), effects);
1171 }
1172
1173
1174
1175
1176
1179 LinalgOp target,
1182
1183 if (isa(target)) {
1186 }
1188 FailureOr generic = generalizeNamedOp(rewriter, target);
1189 if (succeeded(generic)) {
1190 results.push_back(generic->getOperation());
1192 }
1193 return emitDefaultSilenceableFailure(target);
1194 }
1195
1196
1197
1198
1199
1202 LinalgOp target,
1205
1206 if (!isa(target)) {
1209 }
1211 FailureOr named =
1213 if (succeeded(named)) {
1214 results.push_back(named->getOperation());
1216 }
1217 return emitDefaultSilenceableFailure(target);
1218 }
1219
1220
1221
1222
1223
1226 GenericOp target,
1230
1231 if (interchangeVector.empty()) {
1234 }
1235
1236 unsigned numLoops = cast(target.getOperation()).getNumLoops();
1237 if (interchangeVector.size() != numLoops) {
1238 return emitSilenceableError()
1239 << getIteratorInterchangeAttrName() << " has length ("
1240 << interchangeVector.size()
1241 << ") different from the number of loops in the target operation ("
1242 << numLoops << ")";
1243 }
1246 if (failed(res))
1248 results.push_back(res->getOperation());
1250 }
1251
1254 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1255 if (!std::is_permutation(sequence.begin(), sequence.end(),
1256 permutation.begin(), permutation.end())) {
1257 return emitOpError()
1258 << "expects iterator_interchange to be a permutation, found "
1259 << getIteratorInterchange();
1260 }
1261 return success();
1262 }
1263
1264
1265
1266
1267
1272
1273
1274 if (!isalinalg::CopyOp(targetOp)) {
1276 emitSilenceableError() << "only linalg.copy target ops are supported";
1277 diag.attachNote(targetOp->getLoc()) << "target op";
1278 return diag;
1279 }
1280
1281 auto copyOp = dyn_castlinalg::CopyOp(targetOp);
1282 if (!copyOp.hasPureBufferSemantics()) {
1284 emitSilenceableError()
1285 << "cannot transform a linalg.copy on tensors into a memref.copy";
1286 diag.attachNote(targetOp->getLoc()) << "target op";
1287 return diag;
1288 }
1289
1292 assert(inputs.size() == 1 && "expected linalg copy op with one input");
1293 assert(outputs.size() == 1 && "expected memref copy op with one output");
1294 Value input = inputs.front();
1295 Value output = outputs.front();
1296
1297
1298
1299
1300 if (!isa(input.getType())) {
1302 emitSilenceableError()
1303 << "cannot transform a linalg.copy which input has no shape";
1304 diag.attachNote(targetOp->getLoc()) << "target op";
1305 return diag;
1306 }
1307
1308
1309 assert(isa(output.getType()));
1310
1311 if (cast(input.getType()).getElementType() !=
1312 cast(output.getType()).getElementType()) {
1314 emitSilenceableError()
1315 << "cannot transform a linalg.copy with different source and "
1316 "destination element types ";
1317 diag.attachNote(targetOp->getLoc()) << "target op";
1318 return diag;
1319 }
1320
1321
1322 auto memrefCopyOp =
1323 rewriter.replaceOpWithNewOpmemref::CopyOp(targetOp, input, output);
1324
1325 results.push_back(memrefCopyOp);
1327 }
1328
1329
1330
1331
1332
1338 bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1339 FailureOr res =
1340 lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
1341 if (failed(res)) {
1343 << "cannot lower to pad + expand + transpose";
1344 }
1345 transformResults.push_back(res->padOp);
1346 transformResults.push_back(res->expandShapeOp);
1347 transformResults.push_back(res->transposeOp);
1349 }
1350
1351
1352
1353
1354
1360 bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1361 FailureOr res =
1362 lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
1363 if (failed(res)) {
1365 emitSilenceableError()
1366 << "cannot lower to transpose + collapse + extract";
1367 diag.attachNote(target->getLoc()) << "target payload op";
1368 return diag;
1369 }
1370 transformResults.push_back(res->emptyOp);
1371 transformResults.push_back(res->transposeOp);
1372 transformResults.push_back(res->collapseShapeOp);
1373 transformResults.push_back(res->extractSliceOp);
1375 }
1376
1377
1378
1379
1380
1387 }
1388
1395 result.addTypes(resultTypes);
1396 }
1397
1403 if (getOps().has_value())
1404 strs.insert_range(getOps()->getAsValueRange());
1405
1406 auto payloadOps = state.getPayloadOps(getTarget());
1407 if (!llvm::hasSingleElement(payloadOps)) {
1409 }
1410
1412 bool incorrectNumOperandTypes = false;
1413 auto matchFun = [&](Operation *op) {
1414 if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
1415 return;
1416
1417
1418
1419 if (getInterface().has_value()) {
1420 auto iface = getInterface().value();
1421 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1422 !isa(op))
1423 return;
1424 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1425 !isa(op))
1426 return;
1427 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1428 !isa(op))
1429 return;
1430 }
1431
1432
1433 if (getOpAttrs().has_value()) {
1434 DictionaryAttr opAttrs = getOpAttrs().value();
1436 if (attr.getName() == getInterfaceAttrName() ||
1437 attr.getName() == getOpsAttrName())
1438 continue;
1439 if (!op->hasAttr(attr.getName()))
1440 return;
1441 if (op->getAttr(attr.getName()) != attr.getValue())
1442 return;
1443 }
1444 }
1445
1446 if (getFilterResultType().has_value()) {
1447 Type t = getFilterResultType().value();
1449 return;
1450 }
1451
1452 if (getFilterOperandTypes().has_value()) {
1453 mlir::ArrayAttr types = getFilterOperandTypes().value();
1455
1456 if (types.size() == 1) {
1457
1458 auto typeattr =
1459 dyn_castmlir::TypeAttr(getFilterOperandTypes().value()[0]);
1460 Type t = cast<::mlir::Type>(typeattr.getValue());
1462 [&](Type operandType) { return operandType == t; }))
1463 return;
1464 } else {
1465
1466
1467 if (types.size() != operandTypes.size()) {
1468 incorrectNumOperandTypes = true;
1469 return;
1470 }
1471
1472 for (auto [attr, operandType] :
1473 llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1474 auto typeattr = castmlir::TypeAttr(attr);
1475 Type type = cast<::mlir::Type>(typeattr.getValue());
1476
1477 if (type != operandType)
1478 return;
1479 }
1480 }
1481 }
1482
1483
1484 res.push_back(op);
1485 return;
1486 };
1487
1488 (*payloadOps.begin())->walk(matchFun);
1489 if (incorrectNumOperandTypes)
1490 return emitDefiniteFailure("If filter_operand_types contains more than a "
1491 "type, then it must contain as much types as "
1492 "the number of operands in the target ops");
1493 results.set(cast(getResult()), res);
1495 }
1496
1497
1498
1499
1500
1505 }
1506
1508 Type &targetType, Type &lowSizeType,
1509 Type &highSizeType,
1510 Type &splitPointType) {
1511 FunctionType funcType;
1513 if (failed(parser.parseType(funcType)))
1514 return failure();
1515
1516 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1517 parser.emitError(typeLoc) << "expects a trailing functional type with one "
1518 "argument and one result";
1519 }
1520 targetType = funcType.getInput(0);
1521 lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1522
1523 return success();
1524 }
1525
1529 if (isa(getLowSize().getType())) {
1530 if (target.hasDynamicShape()) {
1531 auto diag = emitSilenceableError()
1532 << "cannot compute parametric tile sizes for dynamically "
1533 "shaped payload op";
1534 diag.attachNote(target->getLoc()) << "payload op";
1535 return diag;
1536 }
1537
1539 target, getDimension(), getTargetSize(), getDivisor());
1540 if (failed(spec)) {
1541 return emitSilenceableError()
1542 << "failed to compute multi-size tiling sizes";
1543 }
1544
1545 Builder builder(target.getContext());
1546 results.assign(llvm::map_range(
1548 spec->lowTileSize * spec->lowTripCount}),
1549 [&builder, this](int64_t value) {
1551 cast(getLowSize().getType()).getType(), value);
1552 }));
1554 }
1555
1556 OpBuilder builder(target.getContext());
1561 builder, target, getDimension(), targetSize, divisor);
1562 if (failed(spec)) {
1563 return emitSilenceableError() << "could not generate tile size computation";
1564 }
1565
1570 {spec->lowTileSize, spec->lowTripCount});
1571 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1572 Operation *highTileSize = spec->highTileSize.getDefiningOp();
1573 assert(lowTileSize && highTileSize && splitPoint &&
1574 "tile sizes are not produced by operations");
1577 results.push_back(highTileSize);
1580 }
1581
1582 void transform::MultiTileSizesOp::getEffects(
1585 producesHandle(getOperation()->getOpResults(), effects);
1586 if (isa(getLowSize().getType()))
1588 else
1590 }
1591
1593 if (getLowSize().getType() != getHighSize().getType() ||
1594 getLowSize().getType() != getSplitPoint().getType()) {
1595 return emitOpError() << "expects all results type to be the same";
1596 }
1597 return success();
1598 }
1599
1600
1601
1602
1603
1610 staticPackedSizes);
1611
1612
1613
1615 builder.getContext(), GenericOp::getOperationName());
1616 build(builder, result,
1617 linalgOpHType,
1618 target,
1619 dynamicPackedSizes,
1621 }
1622
1625 return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1626 }
1627
1632 auto targetOps = state.getPayloadOps(getTarget());
1633
1634 if (std::empty(targetOps)) {
1635 transformResults.set(cast(getPackedOp()),
1638 }
1639
1640 auto linalgOp = dyn_cast(*targetOps.begin());
1641 if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1642 return emitSilenceableError()
1643 << "requires target to map to exactly 1 LinalgOp (got "
1644 << llvm::range_size(targetOps) << ")";
1645 }
1646
1647 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1648 return emitSilenceableError()
1649 << "requires number of packed sizes match the number of loops ("
1650 << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
1651 << ")";
1652 }
1653
1654
1657 state, *this, packedSizes, getMixedPackedSizes());
1658
1660 FailureOr maybeResult = pack(rewriter, linalgOp, packedSizes);
1661 if (failed(maybeResult))
1663
1664 transformResults.set(cast(getPackedOp()),
1665 {maybeResult->packedLinalgOp.getOperation()});
1667 }
1668
1669 void transform::PackOp::getEffects(
1675 }
1676
1677
1678
1679
1680
1683 return emitOpError() << getMatmulInnerDimsOrderAttrName()
1684 << " is not a valid permutation";
1685 }
1686
1687 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1688 for (auto [s, nmo] :
1689 llvm::zip_equal(getMixedMatmulPackedSizes(),
1690 getMatmulPaddedSizesNextMultipleOf())) {
1692 if (nmo != 0 &&
1693 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1694 return emitOpError() << "at most one of the packed_size and the "
1695 "padded_sizes_next_multiple_of can be nonzero "
1696 "for the matmul strategy";
1697 }
1698 }
1699 }
1700 return success();
1701 }
1702
1708 for (Operation *op : state.getPayloadOps(getTarget())) {
1709 auto linalgOp = dyn_cast(op);
1710 if (!linalgOp)
1711 continue;
1712
1713
1715
1716
1718 rewriter,
1719 linalgOp,
1720 getMixedMatmulPackedSizes(),
1721
1722 getMatmulPaddedSizesNextMultipleOf(),
1723 getMatmulInnerDimsOrder());
1724 if (succeeded(packResult)) {
1725 results.push_back(packResult->packedLinalgOp);
1726 continue;
1727 }
1728 results.push_back(linalgOp);
1729 }
1730 transformResults.set(cast(getPackedOp()), results);
1732 }
1733
1736 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1737 b);
1738 }
1739
1740 void transform::PackGreedilyOp::getEffects(
1746 }
1747
1748
1749
1750
1751
1754 return emitOpError() << getInnerPermAttrName()
1755 << " is not a valid permutation";
1756 }
1758 return emitOpError() << getOuterPermAttrName()
1759 << " is not a valid permutation";
1760 }
1761 if (getInnerPerm().empty() && getOuterPerm().empty()) {
1762 return emitOpError() << " at least one of " << getInnerPermAttrName()
1763 << " or " << getOuterPermAttrName()
1764 << " must be specified";
1765 }
1766 return success();
1767 }
1768
namespace {
// Discriminates which permutation of a linalg.pack / linalg.unpack op is being
// validated or applied: Outer selects the outer (tile) dimensions, checked
// against the source/dest rank; Inner selects the inner (packed) dimensions,
// checked against the size of inner_dims_pos.
enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
} // namespace
1772
1773
1774
1775
1776
1777
1778
1779
1780 template
1783 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1784 static_assert(
1785 llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1786 "applies to only pack or unpack operations");
1787 if (!op || permutation.empty())
1788 return true;
1789 size_t innerRank = op.getInnerDimsPos().size();
1790 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1791 return permutation.size() == innerRank && isPermutationVector(permutation);
1792
1793
1794 if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
1795 return permutation.size() == op.getSourceRank() &&
1797 }
1798 return permutation.size() == op.getDestRank() &&
1800 }
1801
1806 auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1807 auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1808
1809 if (std::empty(packOrUnpackOps)) {
1810 transformResults.set(cast(getPackedOp()), {});
1811 transformResults.set(cast(getPackOp()), {});
1812 transformResults.set(cast(getUnPackOp()), {});
1814 }
1815
1816
1817
1818 if (!llvm::hasSingleElement(packOrUnpackOps) ||
1819 !llvm::hasSingleElement(linalgOps)) {
1820 return emitSilenceableError()
1821 << "requires target to map to exactly 1 "
1822 "packing op and 1 packed op ("
1823 << "got " << llvm::range_size(packOrUnpackOps) << " and "
1824 << llvm::range_size(linalgOps) << ")";
1825 }
1826
1827
1828 auto packOp = dyn_castlinalg::PackOp(*packOrUnpackOps.begin());
1829 auto unPackOp = dyn_castlinalg::UnPackOp(*packOrUnpackOps.begin());
1830 if ((!packOp && !unPackOp)) {
1831 return emitSilenceableError() << "requires target to map to a "
1832 "linalg.pack or linalg.unpack";
1833 }
1834 LinalgOp linalgOpTarget = dyn_cast(*linalgOps.begin());
1835 if (!linalgOpTarget)
1836 return emitSilenceableError() << "requires a LinalgOp target";
1837
1838
1839 LinalgOp linalgOp;
1840 if (packOp && packOp.getResult().hasOneUse())
1841 linalgOp = dyn_cast(*(packOp.getResult().getUsers().begin()));
1842 else if (unPackOp)
1843 linalgOp = unPackOp.getSource().getDefiningOp();
1844 if (linalgOp != linalgOpTarget) {
1845 auto errorMsg =
1846 packOp ? StringLiteral{"not a single use by the LinalgOp target"}
1847 : StringLiteral{"not produced by the LinalgOp target"};
1848 return emitSilenceableError() << errorMsg;
1849 }
1850
1851
1852
1853 if (unPackOp) {
1854 assert(!packOp && "packOp must be null on entry when unPackOp is not null");
1855 OpOperand *packUse = linalgOp.getDpsInitOperand(
1856 cast(unPackOp.getSource()).getResultNumber());
1857 packOp = dyn_cast_or_nulllinalg::PackOp(packUse->get().getDefiningOp());
1858 if (!packOp || !packOp.getResult().hasOneUse())
1859 return emitSilenceableError() << "could not find matching pack op";
1860 }
1861
1862
1863 for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
1865 (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
1866 auto errorMsg = (permType == OuterOrInnerPerm::Outer)
1867 ? StringLiteral{"invalid outer_perm"}
1868 : StringLiteral{"invalid inner_perm"};
1872 unPackOp ? unPackOp.getOperation() : packOp.getOperation();
1873 return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
1874 }
1875 }
1876
1877
1878
1879 assert(packOp && linalgOp && "unexpected null op");
1880
1881
1882 FailureOr res = packTranspose(
1883 rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
1884
1885 assert(succeeded(res) && "unexpected packTranspose failure");
1886
1887
1888 transformResults.set(cast(getPackOp()), {res->transposedPackOp});
1889 transformResults.set(cast(getPackedOp()),
1890 {res->transposedLinalgOp});
1891 if (unPackOp) {
1892 transformResults.set(cast(getUnPackOp()),
1893 {res->transposedUnPackOp});
1894 } else {
1895 transformResults.set(cast(getUnPackOp()), {});
1896 }
1897
1899 }
1900
1901
1902
1903
1904
1910 StringRef copyBackOp) {
1912 return build(b,
1913 result,
1914 TypeRange{resultType, resultType},
1915 target,
1916 ArrayAttr(),
1917 b.getI64ArrayAttr(paddingDimensions),
1919
1920 (padToMultipleOf.empty()
1922 : b.getDenseI64ArrayAttr(padToMultipleOf)),
1923 b.getI64ArrayAttr(nofoldFlags),
1924 b.getArrayAttr(transposePaddings),
1925 b.getStringAttr(copyBackOp));
1926 }
1927
1933 StringRef copyBackOp) {
1938 staticPadToMultipleOf);
1939 return build(b,
1940 result,
1941 TypeRange{resultType, resultType},
1942 target,
1943 ArrayAttr(),
1944 b.getI64ArrayAttr(paddingDimensions),
1945 dynamicPadToMultipleOf,
1946 staticPadToMultipleOf,
1948 b.getArrayAttr(transposePaddings),
1950 }
1951
1952 void PadOp::getEffects(
1956 producesHandle(getOperation()->getOpResults(), effects);
1958 }
1959
1962 return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
1963 }
1964
1969 auto transformOp = cast(getOperation());
1971
1972 for (Operation *target : state.getPayloadOps(getTarget())) {
1973 auto linalgTarget = dyn_cast(target);
1974 if (!linalgTarget) {
1975 auto diag = emitSilenceableError() << "expected LinalgOp target";
1976 diag.attachNote(target->getLoc()) << "target op";
1977 return diag;
1978 }
1979
1980
1982 for (int64_t packPadding :
1983 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
1984 nofoldFlags.push_back(static_cast<bool>(packPadding));
1985
1986
1988 for (auto const &it :
1989 llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1990 auto attr = dyn_cast(std::get<0>(it));
1991 if (!attr) {
1992 emitOpError("expects padding values to be typed attributes");
1994 }
1996
1997 if (auto stringAttr = dyn_cast(attr)) {
1998 auto parsedAttr = dyn_cast_if_present(parseAttribute(
1999 stringAttr, getContext(), elementType,
2000 nullptr, true));
2001 if (!parsedAttr || parsedAttr.getType() != elementType) {
2002 auto diag = this->emitOpError("expects a padding that parses to ")
2003 << elementType << ", got " << std::get<0>(it);
2004 diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
2006 }
2007 paddingValues.push_back(parsedAttr);
2008 continue;
2009 }
2010
2011 if (attr.getType() != elementType) {
2012 auto diag = this->emitOpError("expects a padding value of type ")
2013 << elementType << ", got " << attr;
2014 diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
2016 }
2017 paddingValues.push_back(attr);
2018 }
2019
2020
2022 for (Attribute transposeVector : cast(getTransposePaddings()))
2023 transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
2024 cast(transposeVector)));
2025
2026 LinalgOp paddedOp;
2028 options.paddingDimensions =
2029 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2030
2033 state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
2035 return status;
2036 if (padToMultipleOf.empty())
2037 padToMultipleOf =
2039
2040 options.padToMultipleOf = padToMultipleOf;
2041 options.paddingValues = paddingValues;
2042 options.nofoldFlags = nofoldFlags;
2043 if (getCopyBackOp() ==
2044 bufferization::MaterializeInDestinationOp::getOperationName()) {
2047 } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
2049 } else if (getCopyBackOp() == kCopyOpNone) {
2051 } else {
2052 llvm_unreachable("unsupported copy_back op");
2053 }
2054
2058 replacements, newPadOps))) {
2059 auto diag = emitSilenceableError() << "failed to pad op";
2060 diag.attachNote(target->getLoc()) << "target op";
2061 return diag;
2062 }
2063
2064
2065
2066
2067
2068
2069 rewriter.replaceOp(linalgTarget, replacements);
2070 paddedOps.push_back(paddedOp);
2071 padOps.append(newPadOps.begin(), newPadOps.end());
2073 for (Value v : replacements) {
2074 Operation *copyBackOp = v.getDefiningOp();
2075 if (!llvm::is_contained(copyBackOps, copyBackOp))
2076 copyBackOps.push_back(copyBackOp);
2077 }
2078 }
2079 }
2080
2081 results.set(cast(getPadded()), paddedOps);
2082 results.set(cast(getPad()), padOps);
2083 results.set(cast(getCopy()), copyBackOps);
2085 }
2086
2089 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
2090 if (any_of(nofoldFlags, [](int64_t packPadding) {
2091 return packPadding != 0 && packPadding != 1;
2092 })) {
2093 return emitOpError()
2094 << "expects nofold_flags to contain booleans (0/1), found "
2095 << getNofoldFlags();
2096 }
2097
2099 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2100 if (any_of(paddingDimensions,
2101 [](int64_t paddingDimension) { return paddingDimension < 0; })) {
2102 return emitOpError() << "expects padding_dimensions to contain positive "
2103 "integers, found "
2104 << getPaddingDimensions();
2105 }
2106 if (!getMixedPadToMultipleOf().empty()) {
2107 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2108 return emitOpError() << "expects as many multiples as padding_dimensions";
2109 }
2110 }
2111 ArrayAttr transposes = getTransposePaddings();
2112 for (Attribute attr : transposes) {
2114 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2115 if (!std::is_permutation(sequence.begin(), sequence.end(),
2117 return emitOpError()
2118 << "expects transpose_paddings to be a permutation, found "
2119 << attr;
2120 }
2121 }
2122 if (getCopyBackOp() !=
2123 bufferization::MaterializeInDestinationOp::getOperationName() &&
2124 getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2125 getCopyBackOp() != kCopyOpNone)
2126 return emitOpError() << "invalid copy_back_op";
2127 return success();
2128 }
2129
2130
2131
2132
2133
2138 auto targetOps = state.getPayloadOps(getTarget());
2139 auto loopOps = state.getPayloadOps(getLoop());
2140 if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
2142 << "requires exactly one target and one loop handle (got "
2143 << llvm::range_size(targetOps) << " and "
2144 << llvm::range_size(loopOps) << ")";
2145 }
2146
2147 auto padOp = dyn_cast_or_nulltensor::PadOp(*targetOps.begin());
2148 auto loopOp = dyn_cast_or_nullscf::ForOp(*loopOps.begin());
2149 if (!padOp || !loopOp)
2151
2152 FailureOrlinalg::detail::PackingResult result =
2154 getTranspose());
2155 if (failed(result))
2157
2158 if (result->clonedLoopIvs.empty()) {
2159 transformResults.set(cast(getPackingLoop()),
2160 {result->hoistedPadOp.getOperation()});
2162 }
2163 auto outerPackedLoop =
2165 transformResults.set(cast(getPackingLoop()),
2166 {outerPackedLoop.getOperation()});
2168 }
2169
2172 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2173 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2175 return emitOpError() << "expects transpose to be a permutation, found "
2176 << getTranspose();
2177 }
2178 return success();
2179 }
2180
2181 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2187 }
2188
2191 tensor::PadOp target,
2194 tensor::PadOp hoistedPadOp;
2196 FailureOr result =
2198 hoistedPadOp, transposeOps);
2199 if (succeeded(result)) {
2200
2201
2202
2203
2204
2205 rewriter.replaceOp(target, *result);
2206 results.push_back(hoistedPadOp);
2208 }
2209 return emitDefaultSilenceableFailure(target);
2210 }
2211
2214 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2215 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2217 return emitOpError() << "expects transpose to be a permutation, found "
2218 << getTranspose();
2219 }
2220 return success();
2221 }
2222
2223
2224
2225
2226
2229 LinalgOp target,
2233 if (!getOperandsToPromote().empty())
2235 extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2236 if (getUseFullTilesByDefault())
2238 getUseFullTilesByDefault());
2239 if (getUseAlloca())
2240 promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
2241 if (!getUseFullTileBuffers().empty())
2243 llvm::to_vector(getUseFullTileBuffers().getAsValueRange()));
2244 if (getAlignment().has_value())
2245 promotionOptions = promotionOptions.setAlignment(*getAlignment());
2246 if (getMemorySpace().has_value())
2247 promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());
2248
2249 if (getMapping().has_value()) {
2250
2251 auto mapping = *getMapping();
2252 if (mapping.size() > 1)
2253 return emitDefaultDefiniteFailure(target);
2254
2255 auto addressSpace = castmlir::gpu::GPUMemorySpaceMappingAttr(mapping[0]);
2256
2257 if (addressSpace.getAddressSpace() ==
2258 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2259 promotionOptions =
2260 promotionOptions
2265 } else if (addressSpace.getAddressSpace() ==
2266 mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2267 promotionOptions =
2268 promotionOptions
2273 } else {
2274 return emitDefaultDefiniteFailure(target);
2275 }
2276 }
2277
2279 return emitDefaultDefiniteFailure(target);
2280
2282 FailureOr res = promoteSubViews(rewriter, target, promotionOptions);
2283 if (failed(res))
2284 return emitDefaultDefiniteFailure(target);
2287 }
2288
2289
2290
2291
2292
2297 auto payload = state.getPayloadOps(getTarget());
2298
2299
2300 for (Operation *target : payload) {
2301 if (target->getNumOperands() > 0)
2304 target->getNumRegions() > 0)
2306 << "expected target that is isolated from above";
2307 }
2308
2309
2310 Operation *pattern = &getBodyRegion().front().front();
2312 for (Operation *target : payload) {
2313 if (getOperation()->isAncestor(target))
2314 continue;
2318 replacements.push_back(replacement);
2319 }
2320 transformResults.set(cast(getReplacement()), replacements);
2322 }
2323
2324 void transform::ReplaceOp::getEffects(
2327 producesHandle(getOperation()->getOpResults(), effects);
2329 }
2330
2332 if (!getBodyRegion().hasOneBlock())
2333 return emitOpError() << "expected one block";
2334 if (std::distance(getBodyRegion().front().begin(),
2335 getBodyRegion().front().end()) != 1)
2336 return emitOpError() << "expected one operation in block";
2337 Operation *replacement = &getBodyRegion().front().front();
2340 << "expected replacement without operands";
2344 << "expect op that is isolated from above";
2345 return success();
2346 }
2347
2348
2349
2350
2351
2354 LinalgOp target,
2360 Location loc = target.getLoc();
2362 target.createFlatListOfOperandDims(b, loc);
2363 AffineMap map = target.getShapesToLoopsMap();
2364 if (!map)
2365 return tileSizes;
2368 allShapeSizes);
2369
2370
2374 }
2375 return tileSizes;
2376 });
2378 FailureOrscf::SCFTilingResult maybeTilingResult = tileUsingSCF(
2379 rewriter, cast(target.getOperation()), tilingOptions);
2380 if (failed(maybeTilingResult))
2381 return emitDefaultDefiniteFailure(target);
2382
2383 if (target->getNumResults())
2384 rewriter.replaceOp(target, maybeTilingResult->replacements);
2385 else
2386 rewriter.eraseOp(target);
2387
2388 results.reserve(maybeTilingResult->tiledOps.size());
2389 for (Operation *tiled : maybeTilingResult->tiledOps)
2392 }
2393
2394
2395
2396
2397
2403 for (Operation *target : state.getPayloadOps(getTarget())) {
2404 auto tilingOp = dyn_cast(*target);
2405 if (!tilingOp) {
2407 emitSilenceableError()
2408 << "expected the payload to implement TilingInterface";
2409 diag.attachNote(target->getLoc()) << "payload op";
2410 return diag;
2411 }
2413 FailureOr<SmallVectorscf::ForOp> generatedLoops =
2415 if (failed(generatedLoops))
2416 return emitDefaultDefiniteFailure(target);
2417 for (scf::ForOp &loop : *generatedLoops) {
2418 loops.push_back(loop.getOperation());
2419 }
2420 rewriter.eraseOp(target);
2421 }
2422 results.set(cast(getResult()), loops);
2424 }
2425
2426
2427
2428
2429
2431 transform::RewriteInDestinationPassingStyleOp::applyToOne(
2436 FailureOr<Operation *> maybeResult =
2438 .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2439 [&rewriter](auto op) {
2441 });
2442 if (failed(maybeResult))
2443 return emitDefaultSilenceableFailure(target);
2444 results.push_back(*maybeResult);
2446 }
2447
2448
2449
2450
2451
2455
2457 llvm::to_vector(state.getPayloadOps(getTarget()));
2458
2459 bool isMultiwaySplit = getMultiway();
2460
2461 if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2463 << "requires exactly one target when "
2464 "multiway split is enabled (got "
2465 << llvm::range_size(payload) << ")";
2466 }
2467
2469
2470 if (!isMultiwaySplit)
2471 chunkSizes.reserve(payload.size());
2472
2473 if (getDynamicChunkSizes()) {
2475 if (isa(getDynamicChunkSizes().getType())) {
2476 chunkSizes = llvm::to_vector(llvm::map_range(
2477 state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
2480 diag = emitSilenceableError()
2481 << "expected dynamic split point handle to point to a "
2482 "single-result index-typed op";
2483 diag.attachNote(op->getLoc()) << "dynamic split point";
2484 }
2486 }));
2487 } else {
2488 chunkSizes = llvm::to_vector(
2489 llvm::map_range(state.getParams(getDynamicChunkSizes()),
2491 }
2492 if (diag.isSilenceableFailure())
2493 return diag;
2494
2495
2496
2497 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2499 << "expected the dynamic split point handle to point to as "
2500 "many operations ("
2501 << chunkSizes.size() << ") as the target handle ("
2502 << payload.size() << ")";
2503 }
2504 } else {
2505 chunkSizes.resize(payload.size(),
2506 rewriter.getIndexAttr(getStaticChunkSizes()));
2507 }
2508
2509 auto checkStructuredOpAndDimensions =
2511 if (!linalgOp) {
2512 auto diag = emitSilenceableError() << "only applies to structured ops";
2513 diag.attachNote(loc) << "target op";
2514 return diag;
2515 }
2516
2517 if (getDimension() >= linalgOp.getNumLoops()) {
2518 auto diag = emitSilenceableError() << "dimension " << getDimension()
2519 << " does not exist in target op";
2520 diag.attachNote(loc) << "target op";
2521 return diag;
2522 }
2524 };
2525
2526 auto checkFailureInSplitting =
2528 if (hasFailed) {
2530 diag.attachNote(loc) << "target op";
2531 return diag;
2532 }
2534 };
2535
2537 if (isMultiwaySplit) {
2538
2539
2540 TilingInterface head, tail;
2541 Operation *target = payload.front();
2542
2543 LinalgOp linalgOp = dyn_cast(target);
2544
2545
2547 checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2548 if (diag.isSilenceableFailure())
2549 return diag;
2550
2551 for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
2552
2553 if (idx > 0)
2554 target = tail.getOperation();
2555
2556 if (!target)
2557 break;
2558
2559 linalgOp = cast(target);
2561
2564 rewriter, cast(linalgOp.getOperation()),
2565 getDimension(), chunkSize);
2566
2567
2569 checkFailureInSplitting(!head && !tail, loc);
2570 if (diag.isDefiniteFailure())
2571 return diag;
2572
2573 opList.push_back(head.getOperation());
2574 }
2575
2576
2577 if (tail)
2578 opList.push_back(tail.getOperation());
2579
2580 } else {
2581
2583 Operation *noSecondPart = nullptr;
2584 for (const auto &pair : llvm::zip(payload, chunkSizes)) {
2585 Operation *target = std::get<0>(pair);
2587 LinalgOp linalgOp = dyn_cast(target);
2589 checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2590
2591 if (diag.isSilenceableFailure())
2592 return diag;
2593
2595 std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2596 rewriter, cast(linalgOp.getOperation()),
2597 getDimension(), std::get<1>(pair));
2598
2599
2601 checkFailureInSplitting(!first.back() && !second.back(), loc);
2603 return diag;
2604
2605
2606 if (!second.back()) {
2607 noSecondPart = target;
2608 second.pop_back();
2609 }
2610 }
2611
2612 if (second.size() != first.size() && !second.empty()) {
2613 auto diag = emitSilenceableError()
2614 << "splitting does not produce the second part for a subset "
2615 "of targets";
2616 diag.attachNote()
2617 << "expected splitting to produce the second part of all "
2618 "or none of the targets";
2619 diag.attachNote(noSecondPart->getLoc())
2620 << "first target with no second part";
2621 return diag;
2622 }
2623
2624 opList.append(first);
2625 if (second.size())
2626 opList.append(second);
2627 }
2628 results.set(cast(getSplitList()), opList);
2630 }
2631
2632 void SplitOp::getEffects(
2635 if (getDynamicChunkSizes())
2636 onlyReadsHandle(getDynamicChunkSizesMutable(), effects);
2637 producesHandle(getOperation()->getOpResults(), effects);
2639 }
2640
2643 IntegerAttr staticChunkSizes;
2645 return failure();
2646
2649 if (!dynamicPointParseResult.has_value()) {
2650 int64_t staticChunkSizesValue;
2651 if (failed(parser.parseInteger(staticChunkSizesValue)))
2652 return failure();
2653
2654 staticChunkSizes =
2656 }
2657
2658 Type targetType;
2662 return failure();
2663 }
2664 if (dynamicPointParseResult.has_value()) {
2665 Type ChunkSizesType;
2666 if (failed(*dynamicPointParseResult) || parser.parseComma() ||
2667 parser.parseType(ChunkSizesType) ||
2668 parser.resolveOperand(dynamicChunkSizes, ChunkSizesType,
2670 return failure();
2671 }
2672
2673 staticChunkSizes =
2675 }
2676
2678 SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
2679 staticChunkSizes);
2680 result.addTypes(targetType);
2681 return success();
2682 }
2683
2685 printer << " " << getTarget() << " after ";
2686 int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
2687 if (staticChunkSize != ShapedType::kDynamic)
2688 printer << staticChunkSize;
2689 else
2690 printer << getDynamicChunkSizes();
2691 printer << " ";
2693 {getStaticChunkSizesAttrName()});
2694 printer << " : " << getTarget().getType();
2695 if (staticChunkSize == ShapedType::kDynamic)
2696 printer << ", " << getDynamicChunkSizes().getType();
2697 }
2698
2700 if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
2701 (getDynamicChunkSizes() == nullptr)) {
2702 return emitOpError() << "expects either a dynamic or a static split "
2703 "point to be provided";
2704 }
2705 return success();
2706 }
2707
2708
2709
2710
2711
2712 void transform::SplitReductionOp::build(
2714 int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
2715 bool useScalingAlgorithm, bool useAlloc) {
2718 result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
2721 SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
2723 if (innerParallel) {
2724 result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
2726 }
2727 if (useScalingAlgorithm) {
2729 SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
2731 }
2732 if (useAlloc) {
2733 result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
2735 }
2737 result.addTypes({resultType, resultType, resultType, resultType});
2738 }
2739
2746 unsigned(getInsertSplitDimension()),
2747 bool(getInnerParallel())};
2748 };
2750 FailureOr splitResult =
2751 (getUseScalingAlgorithm())
2753 : splitReduction(rewriter, target, splitFn, getUseAlloc());
2754 if (failed(splitResult))
2755 return emitDefaultDefiniteFailure(target);
2756
2757 results.push_back(splitResult->initOrAlloc);
2758 results.push_back(splitResult->fillOp);
2759 results.push_back(splitResult->splitLinalgOp);
2760 results.push_back(splitResult->resultCombiningLinalgOp);
2762 }
2763
2764
2765
2766
2767
2768 void transform::TileReductionUsingForOp::build(
2771
2772
2773
2774
2775
2779 build(builder, result,
2780 TypeRange{opTy, opTy, opTy, opTy},
2781 target,
2782 staticTileSizesAttr);
2783 }
2784
2790
2791 auto partialReductionOp = dyn_cast(target);
2792 if (!partialReductionOp) {
2795 "Operation should implement PartialReductionOpInterface");
2796 }
2798 rewriter, partialReductionOp,
2800
2801 if (failed(result))
2802 return emitDefaultSilenceableFailure(target);
2803 rewriter.replaceOp(target, result->replacements);
2804 for (Value initValue : result->initialValues)
2806 for (auto parallelTiledOp : result->tiledOps)
2807 results.push_back(parallelTiledOp);
2808 for (auto mergeOp : result->mergeOps)
2810 results.push_back(result->loops.front());
2812 }
2813
2814
2815
2816
2817
2818 void transform::TileReductionUsingForallOp::build(
2821 ArrayAttr mapping) {
2822
2823
2824
2825
2826
2831 build(builder, result,
2832 TypeRange{opTy, opTy, opTy, opTy},
2833 target,
2834 staticNumThreadsAttr,
2835 staticTileSizesAttr,
2836 mapping);
2837 }
2838
2848 FailureOrlinalg::ForallReductionTilingResult result =
2850 rewriter, cast(target.getOperation()),
2851 numThreads, tileSizes, getMapping());
2852
2853 if (failed(result)) {
2854 auto diag = emitSilenceableError() << "could not tile reduction";
2855 diag.attachNote(target.getLoc()) << "target operation";
2856 return diag;
2857 }
2858 for (Value initValue : result->initialValues)
2860 for (auto parallelTiledOp : result->parallelTiledOps)
2861 results.push_back(parallelTiledOp);
2862 for (auto mergeOp : result->mergeOps)
2864 results.push_back(result->loops);
2866 }
2867
2868
2869
2870
2871
2876
2878 llvm::to_vector(state.getPayloadOps(getTarget()));
2879
2880 if (!llvm::hasSingleElement(targetOps)) {
2882 << "requires exactly one target (got " << llvm::range_size(targetOps)
2883 << ")";
2884 }
2885
2886 Operation *target = *targetOps.begin();
2887 auto linalgOp = dyn_cast(target);
2888 auto tileableOp = dyn_cast(target);
2889
2890 if (!linalgOp)
2892
2893 OpBuilder builder(linalgOp.getContext());
2894
2895 if (isa(getChunkSizes().getType())) {
2896 if (linalgOp.hasDynamicShape()) {
2897 auto diag = emitSilenceableError()
2898 << "cannot compute parametric tile sizes for dynamically "
2899 "shaped payload op";
2900 diag.attachNote(linalgOp->getLoc()) << "payload op";
2901 return diag;
2902 }
2903
2904 FailureOr spec =
2906 getTargetSize());
2907 if (failed(spec)) {
2908 return emitSilenceableError()
2909 << "failed to compute multi-size tiling sizes";
2910 }
2911
2913
2914 for (auto &&[tileSize, tripCount] :
2915 llvm::zip_equal(spec->tileSizes, spec->tripCounts))
2916 chunkSizes.push_back(tileSize * tripCount);
2917
2919 return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
2921 });
2922 };
2924 getI64AttrsFromI64(spec->tileSizes));
2925 transformResults.setParams(cast(getChunkSizes()),
2926 getI64AttrsFromI64(chunkSizes));
2927
2929 }
2930
2932
2934 unsigned dimension = getDimension();
2935
2937 builder, tileableOp, dimension, targetSize, true);
2938 if (failed(spec)) {
2939 return emitSilenceableError() << "could not generate tile size computation";
2940 }
2941
2946 ofrs);
2947 };
2948
2950 Value splitPoint;
2951 for (auto &&[tileSize, tripCount] :
2952 llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
2953 splitPoint = apply(s0 * s1, {tileSize, tripCount});
2954 chunkSizes.push_back(splitPoint);
2955 }
2956
2958 return llvm::map_to_vector(values, [&](Value value) -> Operation * {
2960 });
2961 };
2962
2964 getDefiningOps(spec->tileSizes));
2965 transformResults.set(cast(getChunkSizes()),
2966 getDefiningOps(chunkSizes));
2967
2969 }
2970
2972
2974 return emitOpError() << "expects all results type to be the same";
2975 }
2976
2977 return success();
2978 }
2979
2980 void transform::ContinuousTileSizesOp::getEffects(
2984 else
2987 producesHandle(getOperation()->getOpResults(), effects);
2988 }
2989
2991 Type targetType, Type tile_sizes,
2994 }
2995
2997 Type &targetType,
2998 Type &tileSizesType,
2999 Type &chunkSizesType) {
3000 FunctionType funcType;
3002 if (failed(parser.parseType(funcType)))
3003 return failure();
3004
3005 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
3006 parser.emitError(typeLoc) << "expects a trailing functional type with one "
3007 "argument and one result";
3008 }
3009 targetType = funcType.getInput(0);
3010 tileSizesType = chunkSizesType = funcType.getResult(0);
3011
3012 return success();
3013 }
3014
3015
3016
3017
3018
3019 void transform::TileUsingForOp::build(
3024 return build(builder, result, loopTypes,
3025 target,
3026
3028 interchange, scalableSizes);
3029 }
3030
3031 void transform::TileUsingForOp::build(
3035 build(builder, result, target,
3037 interchange, scalableSizes);
3038 }
3039
3040 void transform::TileUsingForOp::build(
3044
3045
3047 build(builder, result, loopTypes, target, mixedTileSizes, interchange,
3048 scalableSizes);
3049 }
3050
3051 void transform::TileUsingForOp::build(
3059
3060
3061
3063 unsigned numExpectedLoops =
3064 staticTileSizes.size() - llvm::count(staticTileSizes, 0);
3066 resultTypes.reserve(numExpectedLoops);
3067 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
3068 "expected one loop type or as many as loops");
3069 if (loopTypes.size() == 1)
3070 resultTypes.append(numExpectedLoops, loopTypes[0]);
3071 else
3072 llvm::append_range(resultTypes, loopTypes);
3073 SmallVector expandedScalableSizes(mixedTileSizes.size(), false);
3074 if (scalableSizes.has_value())
3075 expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
3076 build(builder, result, target.getType(),
3077 resultTypes,
3078 target,
3079 dynamicTileSizes,
3080 staticTileSizesAttr,
3082 expandedScalableSizes);
3083 }
3084
3086 if (getMixedSizes().size() != getScalableSizes().size())
3087 return emitOpError("expected same number of sizes (")
3088 << getMixedSizes().size() << ") and scalable sizes ("
3089 << getScalableSizes().size() << ")";
3091 unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
3092 if (getLoops().size() != numExpectedLoops)
3093 return emitOpError("expected number of loops to tile (")
3094 << numExpectedLoops << ") to match number of `loops` results ("
3095 << getLoops().size() << ")";
3096 return success();
3097 }
3098
3104
3106 llvm::to_vector(state.getPayloadOps(getTarget()));
3112 if (isa(transformValue.getType())) {
3113 dynamicSizeProducers.push_back({});
3115 paramSizes.push_back(
3116 llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
3117 return cast(attr).getValue().getSExtValue();
3118 })));
3119
3120 if (paramSizes.back().size() != targets.size()) {
3122 emitSilenceableError()
3123 << "expected as many parameter values ("
3124 << dynamicSizeProducers.back().size() << ") as target ops ("
3125 << targets.size() << ")";
3126 diag.attachNote(transformValue.getLoc()) << "for this parameter";
3127 return diag;
3128 }
3129
3130 continue;
3131 }
3132 paramSizes.push_back({});
3133 dynamicSizeProducers.push_back(
3134 llvm::to_vector(state.getPayloadOps(transformValue)));
3135
3136 if (dynamicSizeProducers.back().size() != targets.size()) {
3138 emitSilenceableError()
3139 << "expected as many dynamic size-producing operations ("
3140 << dynamicSizeProducers.back().size() << ") as target ops ("
3141 << targets.size() << ")";
3142 diag.attachNote(transformValue.getLoc()) << "for this handle";
3143 return diag;
3144 }
3145
3146 for (Operation *op : dynamicSizeProducers.back()) {
3149 continue;
3150 }
3151
3153 emitSilenceableError() << "expected sizes to be produced by ops "
3154 "with a single index-type result";
3155 diag.attachNote(op->getLoc()) << "size producer op";
3156 diag.attachNote(transformValue.getLoc()) << "for this handle";
3157 return diag;
3158 }
3159 }
3160
3163 loops.resize(getLoops().size());
3164 auto scalableSizes = getScalableSizes();
3166 auto tilingInterface = dyn_cast(op);
3167 if (!tilingInterface) {
3169 emitSilenceableError()
3170 << "only ops implementing TilingInterface are supported";
3171 diag.attachNote(op->getLoc()) << "target op";
3172 return diag;
3173 }
3174 if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3176 emitSilenceableError()
3177 << "too many tiles provided, expected at most "
3178 << tilingInterface.getLoopIteratorTypes().size() << " found "
3179 << tileSizes.size();
3180 diag.attachNote(op->getLoc()) << "target op";
3181 return diag;
3182 }
3183
3185 if (tileSizes.empty()) {
3188 return {};
3189 });
3190 } else {
3194 sizes.reserve(tileSizes.size());
3195 unsigned dynamicIdx = 0;
3196
3198 if (auto attr = llvm::dyn_cast_if_present(ofr)) {
3199 if (scalableSizes[ofrIdx]) {
3200 auto val = b.createarith::ConstantIndexOp(
3201 getLoc(), cast(attr).getInt());
3204 sizes.push_back(
3205 b.createarith::MulIOp(getLoc(), val, vscale).getResult());
3206 } else {
3207 sizes.push_back(attr);
3208 }
3209 continue;
3210 }
3213 ++dynamicIdx;
3214 assert((dynamicSizes.empty() ^ params.empty()) &&
3215 "expected either dynamic sizes or parameters");
3216 if (!params.empty()) {
3217 sizes.push_back(b.getIndexAttr(params[index]));
3218 } else {
3219 sizes.push_back(dynamicSizes[index]->getResult(0));
3220 }
3221 }
3222 return sizes;
3223 });
3224 }
3225
3227 FailureOrscf::SCFTilingResult maybeTilingResult =
3228 tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3229 if (failed(maybeTilingResult))
3231
3232 rewriter.replaceOp(op, maybeTilingResult->replacements);
3233
3234 tiled.append(maybeTilingResult->tiledOps);
3235 for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
3236 loops[en2.index()].push_back(en2.value());
3237 }
3238
3239 transformResults.set(cast(getTiledLinalgOp()), tiled);
3241 transformResults.set(cast(getLoops()[en.index()]), en.value());
3242
3244 }
3245
3250 results.reserve(tileSizes.size());
3251 unsigned dynamicPos = 0;
3253 for (int64_t size : tileSizes) {
3254 if (size == ShapedType::kDynamic) {
3255 results.push_back(dynamic[dynamicPos++]);
3256 } else {
3257 results.push_back(builder.getIndexAttr(size));
3258 }
3259 }
3260 return results;
3261 }
3262
3263 void transform::TileUsingForOp::getEffects(
3267 producesHandle(getOperation()->getOpResults(), effects);
3269 }
3270
3271
3272
3273
3274
3275 void transform::TileUsingForallOp::build(OpBuilder &builder,
3279 ArrayAttr mapping) {
3280 return build(builder, result,
3281 target,
3282
3285 mapping);
3286 }
3287
3288 void transform::TileUsingForallOp::build(OpBuilder &builder,
3292 ArrayAttr mapping) {
3296
3297
3298
3302 build(builder, result,
3303 TypeRange{operationType, operationType},
3304 target,
3306 dynamicTileSizes,
3307 Value(),
3308 Value(),
3310 staticTileSizesAttr,
3311 mapping);
3312 }
3313
3314 void transform::TileUsingForallOp::build(OpBuilder &builder,
3318 ArrayAttr mapping) {
3319 return build(builder, result, target,
3322 }
3323
3324 void transform::TileUsingForallOp::build(OpBuilder &builder,
3328 ArrayAttr mapping) {
3332 staticNumThreads);
3333
3334
3335
3339 build(builder, result,
3340 TypeRange{operationType, operationType},
3341 target,
3342 dynamicNumThreads,
3344 Value(),
3345 Value(),
3346 staticNumThreadsAttr,
3348 mapping);
3349 }
3350
3351
3352
3359 AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3361 for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3363 rewriter, loc, normalizedUbExpr, {lb, ub, step});
3364 normalizedUbs.push_back(normalizedUb);
3365 }
3366 return normalizedUbs;
3367 }
3368
3369
3370
3379 AffineExpr denormExpr = s0 + d0 * s1;
3381
3382 for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3385 denormalizedIvs.push_back(
3387 }
3388 return denormalizedIvs;
3389 }
3390
3391
3392
3393
3394
3395
3396
3397
3399 scf::ForallOp loop) {
3403
3405 return loop;
3406 }
3407
3408 Location loc = loop.getLoc();
3415
3416 auto normalizedForallOp = rewriter.createscf::ForallOp(
3417 loc, normalizedLbs, normalizedUbs, normalizedSteps, loop.getOutputs(),
3419
3420 auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3422 Block *normalizedLoopBlock = normalizedForallOp.getBody();
3424
3426 denormalizeIndVar(rewriter, loc, normalizedLoopIvs, lbs, steps);
3427 argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3428 normalizedForallOp.getRegionIterArgs().end());
3429 Block *origLoopBlock = loop.getBody();
3430 rewriter.mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3431
3432 rewriter.replaceOp(loop, normalizedForallOp);
3433 return normalizedForallOp;
3434 }
3435
3438 TransformOpInterface transformOp, Operation *target,
3442
3443 auto tileableOp = dyn_cast(target);
3444 if (!tileableOp) {
3446 transformOp.emitSilenceableError()
3447 << "only TilingInterface ops are supported";
3448 diag.attachNote(target->getLoc()) << "target op";
3449 return diag;
3450 }
3454 if (!mixedNumThreads.empty()) {
3455 options.setNumThreads(mixedNumThreads);
3456 } else {
3457 options.setTileSizes(mixedTileSizes);
3458 }
3459 if (mapping) {
3460 options.setMapping(mapping.value().getValue());
3461 }
3462 FailureOrscf::SCFTilingResult maybeTilingResult =
3464
3465 if (failed(maybeTilingResult))
3466 return transformOp.emitDefaultSilenceableFailure(tileableOp);
3467
3468 rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
3469
3470 tilingResult = *maybeTilingResult;
3471
3472 if (mixedNumThreads.empty()) {
3473 auto generatedForallOp = castscf::ForallOp(tilingResult.loops.front());
3476 scf::ForallOp normalizedForallOp =
3478 tilingResult.loops.front() = normalizedForallOp;
3479 }
3480
3482 }
3483
3488 auto transformOp = cast(getOperation());
3489
3490
3493
3494
3497 getPackedNumThreads()
3499 state, transformOp, mixedNumThreads, getPackedNumThreads())
3501 state, transformOp, mixedNumThreads, getMixedNumThreads());
3503 return status;
3505 status = getPackedTileSizes()
3507 state, transformOp, mixedTileSizes, getPackedTileSizes())
3509 state, transformOp, mixedTileSizes, getMixedTileSizes());
3511 return status;
3512
3513 for (Operation *target : state.getPayloadOps(getTarget())) {
3516 rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3517 getMapping(), tilingResult);
3518 if (.succeeded())
3519 return diag;
3520 tileOps.push_back(tilingResult.loops.front());
3521 tiledOps.append(tilingResult.tiledOps);
3522 }
3523
3524 transformResults.set(cast(getForallOp()), tileOps);
3525 transformResults.set(cast(getTiledOp()), tiledOps);
3526
3528 }
3529
3530 void transform::TileUsingForallOp::getEffects(
3537 producesHandle(getOperation()->getOpResults(), effects);
3539 }
3540
3543 return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3544 }
3545
3549 }
3550
3552 int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
3553 static_cast<int>(getPackedNumThreads() != Value());
3554 if (numThreadsSpec > 1)
3555 return emitOpError(
3556 "num_threads and packed_num_threads are mutually exclusive");
3557 int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
3558 static_cast<int>(getPackedTileSizes() != Value());
3559 if (tileSizesSpec > 1)
3560 return emitOpError(
3561 "tile_sizes and packed_tile_sizes are mutually exclusive");
3562 if (numThreadsSpec == 0 && tileSizesSpec == 0)
3563 return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
3564 "must be specified");
3565 return success();
3566 }
3567
3568
3569
3570
3571
3572 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3574 bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
3576 if (vectorizePadding) {
3578 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3579 result.name),
3581 }
3582 if (vectorizeExtract) {
3584 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3585 result.name),
3587 }
3588 if (flatten1DDepthwiseConv) {
3590 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3591 result.name),
3593 }
3595 }
3596
3597 namespace {
3598
3599
3600 struct VectorizationPattern : public RewritePattern {
3601 explicit VectorizationPattern(MLIRContext *context,
3602 bool vectorizeExtract = false,
3603 bool flattenConv = false)
3604 : RewritePattern(MatchAnyOpTypeTag(), 1, context),
3605 vectorizeNDExtract(vectorizeExtract),
3606 flatten1DDepthwiseConv(flattenConv) {}
3607 LogicalResult matchAndRewrite(Operation *op,
3611 "Unsupported Op, cannot vectorize");
3612 return vectorize(rewriter, op, {},
3613 {}, vectorizeNDExtract,
3614 flatten1DDepthwiseConv);
3615 }
3616
3617 private:
3618
3619
3620 bool vectorizeNDExtract = false;
3621
3622
3623
3624 bool flatten1DDepthwiseConv = false;
3625 };
3626 }
3627
3629 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3634 auto diag = this->emitOpError("requires isolated-from-above targets");
3635 diag.attachNote(target->getLoc()) << "non-isolated target";
3637 }
3638
3641 patterns.add(ctx, getVectorizeNdExtract(),
3642 getFlatten_1dDepthwiseConv());
3643
3644 if (!getDisableTransferPermutationMapLoweringPatterns())
3646
3647 if (!getDisableMultiReductionToContractPatterns())
3649
3651
3654 2);
3655 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
3656 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
3658
3660
3661 if (getVectorizePadding()) {
3663
3664
3666 }
3668
3670 if (failed(
3673 return emitDefaultDefiniteFailure(target);
3674
3677 }
3678
3679
3680
3681
3682
3687 auto targets = state.getPayloadOps(getTarget());
3688 if (std::empty(targets))
3690 auto transformOp = cast(getOperation());
3693 state, transformOp, getMixedVectorSizes(), vectorSizes);
3695 return status;
3696
3697
3698 for (Operation *target : targets) {
3701 << "Unsupported Op, cannot vectorize";
3702 }
3703
3705 getScalableSizes(),
3706 getVectorizeNdExtract().value_or(false)))) {
3708 << "Attempted to vectorize, but failed";
3709 }
3710 }
3711
3713 }
3714
3715 void transform::VectorizeOp::getEffects(
3720 }
3721
3724 return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3725 }
3726
3728 if (getStaticVectorSizes().size() != getScalableSizes().size())
3729 return emitOpError("expected same number of vector sizes (")
3730 << getStaticVectorSizes().size() << ") and scalable sizes ("
3731 << getScalableSizes().size() << ")";
3732 return success();
3733 }
3734
3735
3736
3737
3738
3740 transform::HoistRedundantVectorTransfersOp::applyToOne(
3744
3745
3746
3750 }
3751
3752
3753
3754
3755
3757 transform::HoistRedundantVectorBroadcastsOp::applyToOne(
3765 }
3766
3767
3768
3769
3770
3776 auto maybeTransformed =
3778 target)
3779 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
3781 })
3782 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3784 })
3785 .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
3787 })
3788 .Case([&](linalg::Conv2DNchwFchwOp op) {
3790 })
3793 });
3794 if (failed(maybeTransformed))
3795 return emitDefaultSilenceableFailure(target);
3796
3797 results.push_back(maybeTransformed->first);
3798
3799 results.push_back(maybeTransformed->second);
3801 }
3802
3803
3804
3805
3806
3814 << "only elementwise flattening is supported";
3815
3816
3817 if (target.getNumLoops() <= 1) {
3820 }
3821
3822
3824 std::iota(reassociation.begin(), reassociation.end(), 0);
3825 auto maybeFlattened =
3827 if (failed(maybeFlattened))
3829 << "attempted to flatten, but failed";
3830 results.push_back(maybeFlattened->collapsedOp);
3831 rewriter.replaceOp(target, maybeFlattened->results);
3833 }
3834
3835
3836
3837
3838
3844 auto maybeTransformed =
3846 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3848 })
3849 .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
3851 })
3854 });
3855 if (failed(maybeTransformed))
3856 return emitDefaultSilenceableFailure(target);
3857
3858 results.push_back(*maybeTransformed);
3860 }
3861
3862
3863
3864
3865
3871 bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
3872 auto maybeTransformed =
3874 .Case([&](linalg::MatmulOp op) {
3876 })
3877 .Case([&](linalg::BatchMatmulOp op) {
3879 })
3880 .Default([&](Operation *op) { return failure(); });
3881 if (failed(maybeTransformed))
3883
3884 results.push_back(*maybeTransformed);
3886 }
3887
3888
3889
3890
3891 template
3895 static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
3896 tensor::ParallelInsertSliceOp>() &&
3897 "wrong op type");
3898
3899 if (auto copySource =
3900 target.getSource().template getDefiningOplinalg::CopyOp()) {
3903 }
3904
3905
3906
3907 if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
3909 target->template getParentOfTypescf::InParallelOp());
3910 }
3911
3912 Value extracted = rewriter.createtensor::ExtractSliceOp(
3913 target.getLoc(), target.getDest(), target.getMixedOffsets(),
3914 target.getMixedSizes(), target.getMixedStrides());
3915 Value copied = rewriter
3916 .createlinalg::CopyOp(target.getLoc(),
3917 target.getSource(), extracted)
3918 .getResult(0);
3919
3922 target, copied, target.getDest(), target.getMixedOffsets(),
3923 target.getMixedSizes(), target.getMixedStrides());
3924
3925 results.push_back(copied.getDefiningOp());
3927 }
3928
3933
3935 if (auto target = dyn_casttensor::InsertSliceOp(targetOp))
3936 return doit(rewriter, target, results, state);
3937 if (auto target = dyn_casttensor::ParallelInsertSliceOp(targetOp))
3938 return doit(rewriter, target, results, state);
3939
3941 emitSilenceableError()
3942 << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
3943 diag.attachNote(targetOp->getLoc()) << "target op";
3944 return diag;
3945 }
3946
3947
3948
3949
3950
3955
3956 if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
3958 emitSilenceableError()
3959 << "only linalg.copy and tensor.pad target ops are supported";
3960 diag.attachNote(target->getLoc()) << "target op";
3961 return diag;
3962 }
3963 assert(target->getNumResults() == 1 && "expected single result");
3964 auto resultShapedType = cast(target->getResult(0).getType());
3965 if (!resultShapedType.hasStaticShape()) {
3967 emitSilenceableError()
3968 << "only statically sized ops of rank <= 3 are supported";
3969 diag.attachNote(target->getLoc()) << "target op";
3970 return diag;
3971 }
3972
3973
3974 int64_t desiredBitAlignment = getDesiredBitAlignment();
3975 int64_t eltBitwidth =
3976 resultShapedType.getElementType().getIntOrFloatBitWidth();
3977 if (desiredBitAlignment % eltBitwidth != 0) {
3978 desiredBitAlignment = eltBitwidth;
3979 }
3980
3983 getTotalNumThreads(),
3984 desiredBitAlignment,
3985 resultShapedType.getShape(),
3986 false,
3987
3988 resultShapedType.getElementType().getIntOrFloatBitWidth());
3989 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
3991 emitSilenceableError()
3992 << "too few threads to map copy op to threads on the most minor "
3993 "dimension, given alignment and vector size constraints, try "
3994 "smaller tile size of mapping to more threads";
3995 diag.attachNote(target->getLoc()) << "target op";
3996 return diag;
3997 }
3998
3999
4003 rewriter,
4004 state,
4005 *this,
4006 target,
4007 getMixedValues(mapping.numThreads, {}, b),
4009 b.getArrayAttr(mapping.threadMapping),
4010 tilingResult);
4011 if (.succeeded())
4012 return diag;
4013
4015 for (auto op : tilingResult.tiledOps)
4018 }
4019
4020
4021
4022
4023
4029 FailureOr<Operation *> maybeTransformed = failure();
4031 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4032 maybeTransformed =
4034 return true;
4035 })
4036 .Default([&](Operation *op) { return false; });
4037
4038 if (!supported) {
4039 return emitSilenceableError()
4040 << "this operation is not supported to convert to Winograd Conv2D";
4041 }
4042
4043 if (failed(maybeTransformed)) {
4044 return emitSilenceableError() << "apply Winograd Conv2D failed";
4045 }
4046
4047 results.push_back(*maybeTransformed);
4049 }
4050
4056 FailureOr<Operation *> maybeTransformed = failure();
4057 bool supported =
4059 .Case([&](linalg::WinogradFilterTransformOp op) {
4061 return true;
4062 })
4063 .Case([&](linalg::WinogradInputTransformOp op) {
4065 return true;
4066 })
4067 .Case([&](linalg::WinogradOutputTransformOp op) {
4069 return true;
4070 })
4071 .Default([&](Operation *op) { return false; });
4072
4073 if (!supported) {
4075 emitSilenceableError()
4076 << "this operation is not supported to decompose into other operations";
4077 diag.attachNote(target->getLoc()) << "target op";
4078 return diag;
4079 }
4080
4081 if (failed(maybeTransformed)) {
4083 emitSilenceableError() << "decompose Winograd operations failed";
4084 diag.attachNote(target->getLoc()) << "target op";
4085 return diag;
4086 }
4087
4088 results.push_back(*maybeTransformed);
4090 }
4091
4092 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4093
4094 #define GET_OP_CLASSES
4095 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
static MLIRContext * getContext(OpFoldResult val)
DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state)
static Operation * cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
bool isValidPackingPermutation(RelayoutOpTy op, ArrayRef< int64_t > permutation, OuterOrInnerPerm outerOrInnerPerm=OuterOrInnerPerm::Outer)
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(TransformState &state, TransformOpInterface &transformOp, ArrayRef< OpFoldResult > mixedResults, SmallVectorImpl< int64_t > &reified)
When possible, converts each OpFoldResult in mixedResult to an integer if the value can be statically...
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs)
Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to e...
static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type tile_sizes, Type)
static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter, scf::ForallOp loop)
Given a scf.forall loop return a loop op with the loop bounds normalized.
static SmallVector< Value > denormalizeIndVar(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > steps)
When a loop is normalized, the uses of the induction variable within the loop need to be replaced with o...
#define DOWNSCALE_NORMAL(a, b)
static FailureOr< LinalgOp > tryApply(Operation *operation, Args &&...args)
Attempts to apply the pattern specified as template argument to the given operation.
static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type)
static bool sameOrEquivalentIterArg(Value src, Value dst)
Given two operands coming from a loop iter arg, 'src' and 'dst', return true if the operand 'src' is ...
static Operation * replaceForAllWithNewSignature(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)
Add new operands to the forall op for users of the producerOp that are dominated by the containing sc...
static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser, Type &targetType, Type &tileSizesType, Type &chunkSizesType)
static SmallVector< Operation * > tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp,...
static ParseResult parseMultitileSizesTypes(OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType)
static SmallVector< OpFoldResult > normalizeUpperBounds(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > steps)
Given lbs, ubs and steps of loops, return (for each loop), the normalized upper bound.
static LogicalResult applyTilingToAll(RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn)
Apply a tiling transformation to all payload ops and store both the tiled operation as well as the cr...
static std::tuple< SmallVector< Operation * >, Operation * > tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
Find the first "extract" user of producerOp and tile it right before its use.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineExpr getAffineSymbolExpr(unsigned position)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
This class represents a saved insertion point.
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getOpResult(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
void setOperand(unsigned idx, Value value)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
user_range getUsers()
Returns a range of all users.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void assign(unsigned size, std::nullptr_t)
Sets the list of results to size null pointers.
void reserve(unsigned size)
Reserves space for size elements in the list.
size_t size() const
Returns the number of elements in the list.
void push_back(Operation *op)
Appends an element to the list.
A listener that updates a TransformState based on IR modifications.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
The state maintained across applications of various ops implementing the TransformOpInterface.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions paddingDimensions of all opToPad operands to a static bounding box.
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, int64_t m, int64_t r)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy to between src and dst.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
FailureOr< ForallReductionTilingResult > tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes={}, std::optional< ArrayAttr > mapping=std::nullopt)
Method to tile a reduction to parallel iterations computing partial reductions.
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
FailureOr< scf::SCFTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSizes)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
uint64_t getM(LevelType lt)
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, Operation *target, ArrayRef< OpFoldResult > mixedNumThreads, ArrayRef< OpFoldResult > mixedTileSizes, std::optional< ArrayAttr > mapping, scf::SCFTilingResult &tilingResult)
Implementation of tiling operations using scf.forall.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
ForwardingListener(OpBuilder::Listener *listener)
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
@ MaterializeInDestination
Transformation to drop unit-extent dimensions from linalg.generic operations.
Vectorization pattern for memref::CopyOp.
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
Match and rewrite for the pattern:
Match and rewrite for the pattern:
@ BufferizationMaterializeInDestination
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
std::optional< FrozenRewritePatternSet > cleanupPatterns
An optional set of rewrite patterns to apply to the results of tiling before fusion.
Options to use to control tiling.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.
SmallVector< LoopLikeOpInterface > loops
The scf.for operations that iterate over the tiles.