MLIR: lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
10
24 #include
25
26 using namespace mlir;
28
29
30
31
32
33
34
40 for (auto result : indexingMap.getResults()) {
43 Value v = b.createaffine::AffineApplyOp(loc, m, ivs);
44 indices.push_back(v);
45 }
46 return indices;
47 }
48
49
50
53 Block *body = linalgOp.getBlock();
57 if (auto indexOp = dyn_cast(&op)) {
58 map.map(indexOp.getResult(), ivs[indexOp.getDim()]);
59 continue;
60 }
62 }
63
68 OpOperand *storeInto = linalgOp.getDpsInitOperand(operand.index());
70 b, loc, linalgOp.getMatchingIndexingMap(storeInto), ivs);
71 b.creatememref::StoreOp(
72 loc, toStore, linalgOp.getDpsInitOperand(operand.index())->get(),
73 indices);
74 }
75 return success();
76 }
77
78
79
80
81
82 namespace {
83
84
85
86
87
88 template
89 struct LinalgOpTilingInterface
90 : public TilingInterface::ExternalModel<LinalgOpTilingInterface,
91 LinalgOpTy> {
92
94 LinalgOpTy concreteOp = cast(op);
95 return concreteOp.getIteratorTypesArray();
96 }
97
98
103 LinalgOp linalgOp = cast(op);
105 linalgOp.createFlatListOfOperandDims(b, loc);
106 AffineMap map = linalgOp.getShapesToLoopsMap();
107
108 return llvm::to_vector(
110 OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
111 b, loc, loopExpr, allShapesSizes);
112 return Range{b.getIndexAttr(0), ofr, b.getIndexAttr(1)};
113 }));
114 }
115
116
117 FailureOr
121
122
124 LinalgOp linalgOp = cast(op);
127 b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
129 llvm::make_filter_range(
130 tiledOperands,
131 [](Value v) -> bool {
132 return isa_and_nonnull<tensor::ExtractSliceOp, memref::SubViewOp>(
134 }),
136
139
140 Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
141 offsetIndices(b, cast(tiledOp), offsets);
142
145 }
146
147
148
149
150 void
151 getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
156 unsigned numLoops = linalgOp.getNumLoops();
157 auto tilingInterfaceOp = cast(linalgOp.getOperation());
158 mappedOffsets.resize(numLoops);
159 mappedSizes.resize(numLoops);
162 tilingInterfaceOp.getIterationDomain(b);
163 for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
164 mappedOffsets[index] = value.offset;
165 mappedSizes[index] = value.size;
166 }
167 }
168 for (const auto &&[index, value] :
170 unsigned dimPosition = cast(value).getPosition();
171 mappedOffsets[dimPosition] = offsets[index];
172 mappedSizes[dimPosition] = sizes[index];
173 }
174 }
175
176
177
178 LogicalResult getIterationDomainTileFromOperandTile(
183 auto linalgOp = cast(op);
184
185
186
187
188
190 linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
193 << "unhandled get iter domain position when operand is not "
194 "accessed using a permuted projection";
195 }
196
197 getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
198 iterDomainOffsets, iterDomainSizes);
199 return success();
200 }
201
202
203
204 LogicalResult
211 LinalgOp linalgOp = cast(op);
212
216 llvm::to_vector(llvm::map_range(sizes, [&](OpFoldResult ofr) {
218 }));
219
220 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
222 b, loc, outOperand->get(), sizes,
223 linalgOp.getMatchingIndexingMap(outOperand), offsets,
224 {}, subShapeSizes, true);
225 resultOffsets = sliceParams.offsets;
226 resultSizes = sliceParams.sizes;
227 return success();
228 }
229
230 LogicalResult getIterationDomainTileFromResultTile(
235 auto linalgOp = cast(op);
236
237
238
239
240
242 linalgOp.getIndexingMapMatchingResult(op->getResult(resultNumber));
245 "unhandled tiled implementation generation when result is not "
246 "accessed using a permuted projection");
247 }
248
249 getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
250 iterDomainOffsets, iterDomainSizes);
251 return success();
252 }
253
254 FailureOr
255 generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
259 if (failed(getIterationDomainTileFromResultTile(
260 op, b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
261 return failure();
262 }
263 auto tilingInterfaceOp = cast(op);
264 FailureOr tilingResult =
265 tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
266
267 if (failed(tilingResult))
268 return failure();
269
270 if (tilingResult->tiledOps.size() != 1)
271 return op->emitOpError("failed to generate tiled implementation");
272
276 tilingResult->generatedSlices};
277 }
278
279
280
281 FailureOr getTiledImplementationFromOperandTile(
285 if (failed(getIterationDomainTileFromOperandTile(
286 op, b, operandNumber, offsets, sizes, mappedOffsets,
287 mappedSizes))) {
288 return failure();
289 }
291 }
292
293 LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
296 auto linalgOp = cast(op);
297 if (!linalgOp.hasPureBufferSemantics())
298 return op->emitOpError("expected operation to have buffer semantics");
299
301 indexedValues.reserve(linalgOp->getNumOperands());
303
304
305 for (OpOperand &operand : linalgOp->getOpOperands()) {
306 if (!linalgOp.payloadUsesValueFromOperand(&operand)) {
307 indexedValues.push_back(nullptr);
308 continue;
309 }
310 if (linalgOp.isScalar(&operand)) {
311 indexedValues.push_back(operand.get());
312 continue;
313 }
315 builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand), ivs);
317 builder.creatememref::LoadOp(linalgOpLoc, operand.get(), indices);
318 indexedValues.push_back(load);
319 }
320
321
322 return inlinePayload(builder, linalgOp, ivs, indexedValues);
323 }
324 };
325
326
327
328
329
330
331
332
333
334
335
336
337 static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
339 unsigned resultNumber) {
341 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));
342 for (int redPos : reductionDims) {
345 }
346 return map;
347 }
348
349
350
351 template
352 struct LinalgOpPartialReductionInterface
353 : public PartialReductionOpInterface::ExternalModel<
354 LinalgOpPartialReductionInterface, LinalgOpTy> {
355 FailureOr<SmallVector> generateInitialTensorForPartialReduction(
358 auto linalgOp = cast(op);
360
361 if (linalgOp.hasPureBufferSemantics())
362 return op->emitOpError("expected operation to have tensor semantics");
363
364
365 auto tilingInterfaceOp = cast(linalgOp.getOperation());
367 llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b),
368 [](Range x) { return x.size; });
369
371 for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
373 tiledShape.push_back(dimSize);
374 } else {
375 tiledShape.push_back(tileSize);
376 }
377 }
378
380 for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;
381 ++initIdx) {
383 if ((linalgOp.getRegionOutputArgs(), initIdx,
384 combinerOps) ||
385 combinerOps.size() != 1)
386 return op->emitOpError("Failed to anaysis the reduction operation.");
387
388 Operation *reductionOp = combinerOps[0];
390 if (!identity.has_value())
392 "Failed to get an identity value for the reduction operation.");
393
394
396 getPartialResultAffineMap(linalgOp, reductionDims, initIdx);
399 auto dim = cast(dimExpr);
400 partialResultShape.push_back(tiledShape[dim.getPosition()]);
401 }
402
403 Type elType =
405 Value emptyTensor =
406 b.createtensor::EmptyOp(loc, partialResultShape, elType);
407 Value constantOp = b.createarith::ConstantOp(loc, *identity);
408 auto identityTensor =
409 b.createlinalg::FillOp(loc, constantOp, emptyTensor);
410 inits.push_back(identityTensor.getResult(0));
411 }
412
413 return inits;
414 }
415
416 FailureOr
422 auto linalgOp = cast(op);
423
424
425
427 newInitMaps.reserve(linalgOp.getNumDpsInits());
428 for (int idx : llvm::seq(0, linalgOp.getNumDpsInits())) {
429
430
432 getPartialResultAffineMap(linalgOp, reductionDims, idx);
433 newInitMaps.push_back(newMap);
434 }
435
436
438 b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true);
440 llvm::make_filter_range(
443
444
446 for (auto [valueMap, valueToTile] : llvm::zip_equal(newInitMaps, init)) {
447 int64_t initRank = valueMap.getNumResults();
451 for (AffineExpr dimExpr : valueMap.getResults()) {
452 auto dim = cast(dimExpr);
453 initSizes.push_back(sizes[dim.getPosition()]);
454 }
455
456 auto extractSlice = b.createtensor::ExtractSliceOp(
457 loc, valueToTile, initOffset, initSizes, initStride);
458 tiledInits.push_back(extractSlice);
459 generatedSlices.push_back(extractSlice);
460 }
461
462
464
465 for (int idx : llvm::seq(0, linalgOp.getNumDpsInits())) {
466
467
468 OpOperand *initOperand = linalgOp.getDpsInitOperand(idx);
469 int64_t mapIdx = linalgOp.getIndexingMapIndex(initOperand);
470 newMaps[mapIdx] = newInitMaps[idx];
471 }
472
473
475 linalgOp.getIteratorTypesArray();
476 for (int dim : reductionDims)
477 newIteratorTypes[dim] = utils::IteratorType::parallel;
478
479
480 auto genericOp =
482 tiledInits, newMaps, newIteratorTypes);
485 genericOp.getRegion().begin(), mapping);
487 {genericOp.getOperation()},
488 llvm::map_to_vector(genericOp->getResults(),
490 generatedSlices};
491 }
492
496 auto linalgOp = cast(op);
497
498
499
500 int64_t numInits = linalgOp.getNumDpsInits();
503 for (int idx : llvm::seq(numInits)) {
504
505
506
507
509 getPartialResultAffineMap(linalgOp, reductionDims, idx);
511 for (auto [resultNum, dimExpr] :
513 unsigned dim = cast(dimExpr).getPosition();
514 if (llvm::is_contained(reductionDims, dim)) {
515 partialReductionDims.push_back(resultNum);
516 }
517 }
518
519 Value partialResult = partialReduce[idx];
520 Value init = linalgOp.getDpsInits()[idx];
521
522 auto reduction = b.createlinalg::ReduceOp(
523 loc, partialResult, init, partialReductionDims,
525
527 matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
528 Operation *clonedReductionOp = b.clone(*combinerOps[0]);
529
530 clonedReductionOp->setOperand(0, inputs[0]);
531 clonedReductionOp->setOperand(1, inputs[1]);
532 b.createlinalg::YieldOp(loc, clonedReductionOp->getResult(0));
533 });
534
535 mergeOperations.push_back(reduction);
536 replacements.push_back(reduction->getResult(0));
537 }
538
539 return MergeResult{mergeOperations, replacements};
540 }
541
542 LogicalResult getPartialResultTilePosition(
548 auto linalgOp = cast(op);
549
551 getPartialResultAffineMap(linalgOp, reductionDims, resultNumber);
553 unsigned dim = cast(dimExpr).getPosition();
554 resultSizes.push_back(sizes[dim]);
555
556 if (llvm::is_contained(reductionDims, dim)) {
557
558
560 } else {
561 resultOffsets.push_back(offsets[dim]);
562 }
563 }
564
565 return success();
566 }
567 };
568
569 template
572 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
573 "applies to only pack or unpack operations");
575 int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
576 : op.getDestRank();
582 for (auto dim : llvm::seq<int64_t>(0, rank)) {
583 loopBounds[dim].offset = zero;
584 loopBounds[dim].stride = one;
585 loopBounds[dim].size = resultShape[0][dim];
586 }
587 return loopBounds;
588 }
589
593 if (permutation.empty())
594 return;
595 applyPermutationToVector(offsets, permutation);
596 applyPermutationToVector(sizes, permutation);
597 }
598
599 struct PackOpTiling
600 : public TilingInterface::ExternalModel<PackOpTiling, linalg::PackOp> {
601
603
604
605
606 auto packOp = cast(op);
608 packOp.getSourceRank(), utils::IteratorType::parallel);
609 return iteratorTypes;
610 }
611
613 return getPackUnPackIterationDomain(cast(op), b);
614 }
615
616 FailureOr
620 auto packOp = cast(op);
621 Location loc = packOp.getLoc();
622
623
624
625 int64_t inputRank = packOp.getSourceRank();
628 applyPermToRange(origOffsets, origSizes,
630
632 packOp.getDimAndTileMapping();
636 for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
642 if (dimAndTileMapping.count(dim)) {
643
644
645
646 auto avOffset = AV(dim0).bind(origOffsets[dim]);
647 auto avSize = AV(dim0).bind(origSizes[dim]);
648 auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
649 inputIndices.push_back(ab.mul(avOffset, avTileSize));
650 inputSizes.push_back(ab.mul(avSize, avTileSize));
651 } else {
652 inputIndices.push_back(origOffsets[dim]);
653 inputSizes.push_back(origSizes[dim]);
654 }
655
656
657 if (packOp.getPaddingValue()) {
659 auto avDimSize = AV(dim0).bind(dimSize);
660 auto avInputIdx = AV(dim1).bind(inputIndices.back());
661 inputSizes.back() =
662 ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
663 }
664 }
665
668
670 auto sourceSlice = b.createtensor::ExtractSliceOp(
671 loc, packOp.getSource(), inputIndices, inputSizes, strides);
672 tiledOperands.push_back(sourceSlice);
673
676 outputSizes)))
677 return {};
678
679 strides.append(packOp.getDestRank() - inputRank, oneAttr);
680 auto outSlice = b.createtensor::ExtractSliceOp(
681 loc, packOp.getDest(), outputOffsets, outputSizes, strides);
682 tiledOperands.push_back(outSlice);
683
684 if (auto val = packOp.getPaddingValue())
685 tiledOperands.push_back(val);
686 for (auto tile : packOp.getInnerTiles())
687 tiledOperands.push_back(tile);
688
690 loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());
691
693 {tiledPackOp},
696 }
697
698 LogicalResult
704
705
706
707
708 auto packOp = cast(op);
709 int64_t inputRank = packOp.getSourceRank();
710 int64_t outputRank = packOp.getDestRank();
712 resultOffsets.assign(offsets.begin(), offsets.end());
713 resultOffsets.append(outputRank - inputRank, zeroAttr);
714
717 resultSizes.assign(sizes.begin(), sizes.end());
718 for (auto dataTileDim : llvm::seq(inputRank, outputRank))
719 resultSizes.push_back(outputShape[0][dataTileDim]);
720
721 return success();
722 }
723
724 FailureOr
725 generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
728 auto packOp = cast(op);
729 int64_t numTiles = packOp.getInnerDimsPos().size();
730
731
732
733
734 for (auto offset : offsets.take_back(numTiles))
736 return failure();
737
738 for (auto iter :
739 llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
741 return failure();
742
744 op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
745 if (failed(tilingResult))
746 return failure();
747 return tilingResult.value();
748 }
749
750
751
752
753 LogicalResult getIterationDomainTileFromOperandTile(
758 if (operandNumber != 0)
759 return failure();
760
761 auto packOp = cast(op);
762
763
764 if (packOp.getPaddingValue())
765 return failure();
766
767 Location loc = packOp.getLoc();
768
771 packOp.getDimAndTileMapping();
772 for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
773 if (dimAndTileMapping.count(dim)) {
774 FailureOr<int64_t> cstSize =
777 nullptr, true);
778 std::optional<int64_t> cstInnerSize =
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794 if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
795 return failure();
796 }
797
803 auto avOffset = AV(dim0).bind(offsets[dim]);
804 auto avSize = AV(dim0).bind(sizes[dim]);
805 auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
806 outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
807 outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
808 } else {
809 outerDimOffsets.push_back(offsets[dim]);
810 outerDimSizes.push_back(sizes[dim]);
811 }
812 }
813 applyPermToRange(outerDimOffsets, outerDimSizes, packOp.getOuterDimsPerm());
814 resultOffsets = outerDimOffsets;
815 resultSizes = outerDimSizes;
816 return success();
817 }
818
819
820 FailureOr getTiledImplementationFromOperandTile(
823 if (operandNumber != 0)
824 return failure();
825
826 auto packOp = cast(op);
827 Location loc = packOp.getLoc();
828
829 int64_t inputRank = packOp.getSourceRank();
832
834 auto sourceSlice = b.createtensor::ExtractSliceOp(
835 loc, packOp.getSource(), offsets, sizes, strides);
836 tiledOperands.push_back(sourceSlice);
837
839 if (failed(getIterationDomainTileFromOperandTile(
840 op, b, 0, offsets, sizes, outerDimOffsets,
841 outerDimSizes)))
842 return failure();
843
846 outputOffsets, outputSizes)))
847 return failure();
848
849 strides.append(packOp.getDestRank() - inputRank, oneAttr);
850 auto outSlice = b.createtensor::ExtractSliceOp(
851 loc, packOp.getDest(), outputOffsets, outputSizes, strides);
852 tiledOperands.push_back(outSlice);
853
854 assert(!packOp.getPaddingValue() && "Expect no padding semantic");
855 for (auto tile : packOp.getInnerTiles())
856 tiledOperands.push_back(tile);
857
859 loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());
860
862 {tiledPackOp},
865 }
866 };
867
868 struct UnpackTileDimInfo {
869 bool isAlignedToInnerTileSize;
874 };
875
876
877
878
879 static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
880 int64_t tileDim,
883 UnpackTileDimInfo info;
887 unpackOp.getDimAndTileMapping();
888
889 if (!dimAndTileMapping.count(tileDim)) {
890 info.isAlignedToInnerTileSize = true;
891 info.sourceOffset = tileOffset;
892 info.sourceSize = tileSize;
893 info.resultOffset = zeroAttr;
894 info.destExpandedSize = tileSize;
895 return info;
896 }
897
898 Location loc = unpackOp.getLoc();
904
905 OpFoldResult innerTileSize = dimAndTileMapping[tileDim];
906
907 info.isAlignedToInnerTileSize = false;
910 nullptr, true);
911 std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
912 if (!failed(cstSize) && cstInnerSize) {
913 if (*cstSize % *cstInnerSize == 0)
914 info.isAlignedToInnerTileSize = true;
915
916
917
918 if (*cstInnerSize == *cstSize) {
919 auto lhs = AV(dim0).bind(tileOffset);
920 auto rhs = AV(dim1).bind(innerTileSize);
921 info.sourceOffset = ab.floor(lhs, rhs);
922 info.sourceSize = oneAttr;
923 info.resultOffset = zeroAttr;
924 info.destExpandedSize = tileSize;
925 return info;
926 }
927 }
928
929 if (info.isAlignedToInnerTileSize) {
930 info.sourceOffset =
931 ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
932 info.resultOffset = zeroAttr;
933 info.destExpandedSize = tileSize;
934
935
936
937
938
939
940
941 info.sourceSize =
942 ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
943 return info;
944 }
945
950 ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
952 b, loc,
954 b, loc,
955 ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
957
959 AV(dim1).bind(firstCoord.quotient));
960 info.sourceSize =
961 ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
962 info.sourceOffset = firstCoord.quotient;
963 info.resultOffset = firstCoord.remainder;
964
965
966 info.destExpandedSize = b.createOrFoldarith::MulIOp(
969 return info;
970 }
971
972 struct UnPackOpTiling
973 : public TilingInterface::ExternalModel<UnPackOpTiling, linalg::UnPackOp> {
974
976 auto unpackOp = cast(op);
978 unpackOp.getDestRank(), utils::IteratorType::parallel);
979 return iteratorTypes;
980 }
981
983 return getPackUnPackIterationDomain(cast(op), b);
984 }
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000 FailureOr
1004 auto unpackOp = cast(op);
1005 int64_t srcRank = unpackOp.getSourceRank();
1006 int64_t destRank = unpackOp.getDestRank();
1007 int64_t numInnerTiles = srcRank - destRank;
1008 Location loc = unpackOp.getLoc();
1009
1010
1011
1012
1013 bool isPerfectTilingCase = true;
1018 for (auto dim : llvm::seq<int64_t>(0, destRank)) {
1019 UnpackTileDimInfo info =
1020 getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
1021 if (!info.isAlignedToInnerTileSize)
1022 isPerfectTilingCase = false;
1023 sliceSrcIndices.push_back(info.sourceOffset);
1024 sliceSrcSizes.push_back(info.sourceSize);
1025 destExpandedSizes.push_back(info.destExpandedSize);
1026 resultOffsetsFromDest.push_back(info.resultOffset);
1027 }
1028
1029
1030
1031 applyPermToRange(sliceSrcIndices, sliceSrcSizes,
1032 unpackOp.getOuterDimsPerm());
1034 sliceSrcIndices.append(numInnerTiles, zeroAttr);
1035 sliceSrcSizes.append(unpackOp.getMixedTiles());
1036 sliceSrcStrides.append(numInnerTiles, oneAttr);
1038 tensor::ExtractSliceOp sliceSource = b.createtensor::ExtractSliceOp(
1039 loc, unpackOp.getSource(), sliceSrcIndices, sliceSrcSizes,
1040 sliceSrcStrides);
1041 generatedSlices.push_back(sliceSource);
1042
1044 Value sliceDest;
1045 if (isPerfectTilingCase) {
1046 auto destSliceOp = b.createtensor::ExtractSliceOp(
1047 loc, unpackOp.getDest(), offsets, sizes, destStrides);
1048 sliceDest = destSliceOp;
1049 generatedSlices.push_back(destSliceOp);
1050 } else {
1051 sliceDest = b.createtensor::EmptyOp(
1052 loc, destExpandedSizes, unpackOp.getDestType().getElementType());
1053 }
1054
1055 SmallVector tiledOperands = {sliceSource.getResult(), sliceDest};
1056 for (auto tile : unpackOp.getInnerTiles())
1057 tiledOperands.push_back(tile);
1058
1060 loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());
1061
1062 if (isPerfectTilingCase)
1065 generatedSlices};
1066
1067 auto extractSlice = b.createtensor::ExtractSliceOp(
1068 loc, tiledUnpackOp->getResult(0), resultOffsetsFromDest, sizes,
1069 destStrides);
1071 {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
1072 }
1073
1074 LogicalResult
1080 resultOffsets = llvm::to_vector(offsets);
1081 resultSizes = llvm::to_vector(sizes);
1082 return success();
1083 }
1084
1085 FailureOr
1086 generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
1089 FailureOr tilingResult =
1091 if (failed(tilingResult))
1092 return failure();
1093 return tilingResult.value();
1094 }
1095
1096
1097
1098 LogicalResult getIterationDomainTileFromOperandTile(
1103 auto unPackOp = cast(op);
1104
1105 if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
1106 resultOffsets = llvm::to_vector(offsets);
1107 resultSizes = llvm::to_vector(sizes);
1108 return success();
1109 }
1110 Location loc = unPackOp.getLoc();
1111
1112 int64_t numTiles = unPackOp.getInnerDimsPos().size();
1113 auto destOffsets = offsets.drop_back(numTiles);
1114 auto destSizes = sizes.drop_back(numTiles);
1115
1116
1117 int64_t outputRank = unPackOp.getDestRank();
1119 if (failed(reifyResultShapes(b, unPackOp, reifiedReturnShapes)))
1120 return failure();
1124 applyPermToRange(origOffsets, origSizes,
1126
1128 unPackOp.getDimAndTileMapping();
1129
1130 for (auto dim : llvm::seq<int64_t>(0, outputRank)) {
1136 if (dimAndTileMapping.count(dim)) {
1137
1138
1139
1140 auto avOffset = AV(dim0).bind(origOffsets[dim]);
1141 auto avSize = AV(dim0).bind(origSizes[dim]);
1142 auto avTileSize = AV(sym0).bind(dimAndTileMapping[dim]);
1143 auto avResultSize = AV(dim0).bind(outputMixedSizes[dim]);
1144 resultOffsets.push_back(ab.mul(avOffset, avTileSize));
1145 auto avResultOffset = AV(dim1).bind(resultOffsets.back());
1146 resultSizes.push_back(ab.min({ab.mul(avSize, avTileSize),
1147 ab.sub(avResultSize, avResultOffset)}));
1148 } else {
1149 resultOffsets.push_back(origOffsets[dim]);
1150 resultSizes.push_back(origSizes[dim]);
1151 }
1152 }
1153 return success();
1154 }
1155
1156
1157 FailureOr getTiledImplementationFromOperandTile(
1160 auto unPackOp = cast(op);
1161
1162
1163 int64_t numTiles = unPackOp.getInnerDimsPos().size();
1164 for (auto iter :
1165 llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) {
1167 return failure();
1168 }
1169
1170 Location loc = unPackOp.getLoc();
1171
1172
1173
1175 if (failed(getIterationDomainTileFromOperandTile(
1176 op, b, 0, offsets, sizes, outputOffsets,
1177 outputSizes)))
1178 return failure();
1179
1181 int64_t outputRank = unPackOp.getDestRank();
1183
1185
1186 auto extractDestSlice = b.createtensor::ExtractSliceOp(
1187 loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
1188 tiledOperands.push_back(extractDestSlice);
1189
1190 strides.append(unPackOp.getSourceRank() - outputRank, oneAttr);
1191
1192 auto extractSourceSlice = b.createtensor::ExtractSliceOp(
1193 loc, unPackOp.getSource(), offsets, sizes, strides);
1194 tiledOperands.insert(tiledOperands.begin(), extractSourceSlice);
1195 for (auto tile : unPackOp.getInnerTiles())
1196 tiledOperands.push_back(tile);
1197
1198
1200 b.create(loc, TypeRange{extractDestSlice.getType()},
1201 tiledOperands, op->getAttrs());
1202
1206 extractSourceSlice, extractDestSlice})};
1207 }
1208 };
1209
1210 }
1211
1212 template
1214 OpType::template attachInterface<LinalgOpTilingInterface>(*ctx);
1215 OpType::template attachInterface<LinalgOpPartialReductionInterface>(
1216 *ctx);
1217 }
1218
1219
1220 template <typename... OpTypes>
1222 (registerOne(ctx), ...);
1223 }
1224
1225 #define GET_OP_LIST
1226
1230 registerOnelinalg::GenericOp(ctx);
1231 linalg::PackOp::attachInterface(*ctx);
1232 linalg::UnPackOp::attachInterface(*ctx);
1234 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
1235 >(ctx);
1236 });
1237 }
1238
1242 linalg::PackOp::attachInterface(*ctx);
1243 linalg::UnPackOp::attachInterface(*ctx);
1244 });
1245 }
static LogicalResult getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize, const scf::SCFTilingOptions &options)
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, const scf::SCFTilingOptions &options)
static SmallVector< Value > getIndicesForAccess(OpBuilder &b, Location loc, AffineMap indexingMap, ValueRange ivs)
Return the SSA values that represent the data point accessed using a given indexingMap for a given point in the iteration space.
static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp, ValueRange ivs, ValueRange argValues)
Method to inline the payload of a linalgOp given the iteration space point and values for the arguments of the payload.
static void registerAll(MLIRContext *ctx)
Variadic helper function.
static void registerOne(MLIRContext *ctx)
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap insertResult(AffineExpr expr, unsigned pos) const
Returns a new AffineMap with the same number of dims and symbols and an extra result inserted at pos.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operations within this block, excluding the terminator operation at the end.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI64IntegerAttr(int64_t value)
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
OpOperand & getOpOperand(unsigned idx)
void setOperand(unsigned idx, Value value)
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void cloneInto(Region *dest, IRMapping &mapper)
Clone the internal blocks from this region into dest.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static FailureOr< int64_t > computeConstantBound(presburger::BoundType type, const Variable &var, StopConditionFn stopCondition=nullptr, bool closedUB=false)
Compute a constant bound for the given variable.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs)
Create IR to calculate (div lhs, rhs) and (mod lhs, rhs).
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
SmallVector< Value > makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Creates extract_slice/subview ops for all valuesToTile of the given linalgOp with builder,...
void registerTilingInterfaceExternalModelsForPackUnPackOps(DialectRegistry ®istry)
Similar to the above registration, but it is only for tensor.pack and tensor.unpack ops.
void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef< OpFoldResult > offests)
Add the specified offsets to any linalg.index ops contained in the given linalgOp.
void registerTilingInterfaceExternalModels(DialectRegistry ®istry)
SmallVector< Type > getTensorOutputTypes(LinalgOp op, ValueRange operands)
Returns the list of tensor output types produced when the given structured operation op is applied to...
SliceParameters computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile, ArrayRef< OpFoldResult > tileSizes, AffineMap map, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > subShapeSizes, bool omitPartialTileCheck)
Computes SliceParameters for a single valueToTile assuming that its user is being tiled with the give...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Container for the result of merge operation of tiling.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Container for result values of tiling.
SmallVector< Operation * > tiledOps
Helper struct to build simple AffineValueExprs with minimal type inference support.
Holds the result of (div a, b) and (mod a, b).
A struct containg offsets-sizes-strides arguments of the tiled shape.
SmallVector< OpFoldResult > sizes
SmallVector< OpFoldResult > offsets