MLIR: lib/Dialect/Tensor/IR/TensorOps.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
33 #include "llvm/ADT/DenseSet.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/SmallBitVector.h"
36 #include "llvm/ADT/StringRef.h"
37 #include "llvm/Support/Casting.h"
38 #include "llvm/Support/LogicalResult.h"
39 #include "llvm/Support/MathExtras.h"
40 #include
41 #include
42 #include
43
44 using namespace mlir;
46
47 using llvm::divideCeilSigned;
48 using llvm::divideFloorSigned;
49 using llvm::mod;
50
51
52
56 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
57 return op;
58 if (complex::ConstantOp::isBuildableWith(value, type))
59 return builder.createcomplex::ConstantOp(loc, type,
60 llvm::cast(value));
61 return nullptr;
62 }
63
65 int64_t dim) {
66 auto tensorType = llvm::cast(value.getType());
67 if (tensorType.isDynamicDim(dim))
68 return builder.createOrFoldtensor::DimOp(loc, value, dim);
69
70 return builder.getIndexAttr(tensorType.getDimSize(dim));
71 }
72
75 auto tensorType = llvm::cast(value.getType());
77 for (int64_t i = 0; i < tensorType.getRank(); ++i)
78 result.push_back(getMixedSize(builder, loc, value, i));
79 return result;
80 }
81
84 auto tensorType = llvm::dyn_cast(opResult.getType());
85 assert(tensorType && "expected tensor type");
86
87
88
89 auto destOp = opResult.getDefiningOp();
90 if (destOp)
91 return destOp.getTiedOpOperand(opResult)->get();
92
93
96
97
99 if (!tensorType.hasStaticShape()) {
100
103 return failure();
105 } else {
106
107 for (int64_t sz : tensorType.getShape())
109 }
110
111
112 Value emptyTensor =
113 b.createtensor::EmptyOp(loc, mixedSizes, tensorType.getElementType());
114 return emptyTensor;
115 }
116
121 if (llvm::isa(opResult.getType())) {
123 if (failed(destination))
124 return failure();
125 result.push_back(*destination);
126 }
127 }
128 return success();
129 }
130
132 if (auto rtp1 = llvm::dyn_cast(tp1)) {
133 if (auto rtp2 = llvm::dyn_cast(tp2))
134 return rtp1.getShape() == rtp2.getShape() &&
135 rtp1.getElementType() == rtp2.getElementType();
136 return false;
137 }
138 return tp1 == tp2;
139 }
140
141
142
145 llvm::SmallBitVector droppedDims(mixedSizes.size());
146 int64_t shapePos = reducedShape.size() - 1;
147
148 for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
149 size_t idx = mixedSizes.size() - size.index() - 1;
150
151 bool isStaticUnitSize =
152 isa(size.value()) &&
153 llvm::cast(cast(size.value())).getInt() == 1;
154
155 if (shapePos < 0) {
156
157
158 assert(isStaticUnitSize && "expected unit dim");
159 droppedDims.set(idx);
160 continue;
161 }
162
163
164 if (!isStaticUnitSize) {
165 --shapePos;
166 continue;
167 }
168
169
170 if (reducedShape[shapePos] == 1) {
171 --shapePos;
172 continue;
173 }
174
175
176 droppedDims.set(idx);
177 }
178
179 assert(shapePos < 0 && "dimension mismatch");
180 return droppedDims;
181 }
182
183
184
185
186 static RankedTensorType
190 assert(type.getNumDynamicDims() == dynamicSizes.size() &&
191 "incorrect number of dynamic sizes");
192
193
194 unsigned ctr = 0;
195 for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
196 if (type.isDynamicDim(i)) {
197 Value dynamicSize = dynamicSizes[ctr++];
199 if (cst.has_value()) {
200
201 if (cst.value() < 0) {
202 foldedDynamicSizes.push_back(dynamicSize);
203 continue;
204 }
205 staticShape[i] = *cst;
206 } else {
207 foldedDynamicSizes.push_back(dynamicSize);
208 }
209 }
210 }
211
213 type.getEncoding());
214 }
215
216
217
218
219
220 bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
221 if (inputs.size() != 1 || outputs.size() != 1)
222 return false;
223 Type a = inputs.front(), b = outputs.front();
224 auto aT = dyn_cast(a);
225 auto bT = dyn_cast(b);
226 if (!aT || !bT)
227 return false;
228
229 if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
230 return false;
231
233 }
234
235 namespace {
236
237
238
239 struct ChainedTensorBitcast : public OpRewritePattern {
241
242 LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
244 auto tensorBitcastOperand =
245 tensorBitcast.getOperand().getDefiningOp();
246 if (!tensorBitcastOperand)
247 return failure();
248
249 auto resultType = cast(tensorBitcast.getType());
250 rewriter.replaceOpWithNewOp(tensorBitcast, resultType,
251 tensorBitcastOperand.getOperand());
252 return success();
253 }
254 };
255
256 }
257
258 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
260 results.add(context);
261 }
262
263
264
265
266
267 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
268 setNameFn(getResult(), "cast");
269 }
270
271
272
274 auto sourceType = llvm::dyn_cast(source);
275 auto targetType = llvm::dyn_cast(target);
276
277
278 if (!sourceType || !targetType)
279 return false;
280
281
282 if (sourceType.getElementType() != targetType.getElementType())
283 return false;
284
285
286 if (sourceType.getRank() != targetType.getRank())
287 return false;
288
289
290 if (sourceType.getEncoding() != targetType.getEncoding())
291 return false;
292
293
294 for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
295 if (!ShapedType::isDynamic(std::get<0>(t)) &&
296 ShapedType::isDynamic(std::get<1>(t)))
297 return false;
298 }
299
300 return true;
301 }
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
326 if (!castOp)
327 return false;
328
329
330
332 castOp.getSource().getType());
333 }
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
357 if (!castOp)
358 return false;
360 castOp.getType());
361 }
362
365 if (llvm::isa(opOperand.get()))
366 return false;
367 auto castOp = opOperand.get().getDefiningOptensor::CastOp();
368 return castOp && canFoldIntoConsumerOp(castOp);
369 });
370 }
371
375 newOperands.reserve(op->getNumOperands());
376
378
379
380 int64_t dpsInitIdx = 0;
381 for (OpOperand &opOperand : op->getOpOperands()) {
382 auto tensorCastOp = opOperand.get().getDefiningOptensor::CastOp();
384 newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
385 if (op.isDpsInit(&opOperand) &&
386 !llvm::isa(newOperands.back().getType()))
387 newResTy[dpsInitIdx++] = newOperands.back().getType();
388 }
389 return newOperands;
390 }
391
392
393
395 bool folded = false;
397 auto castOp = operand.get().getDefiningOptensor::CastOp();
399 operand.set(castOp.getOperand());
400 folded = true;
401 }
402 }
403 return success(folded);
404 }
405
406 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
407 if (inputs.size() != 1 || outputs.size() != 1)
408 return false;
409 Type a = inputs.front(), b = outputs.front();
410 auto aT = llvm::dyn_cast(a);
411 auto bT = llvm::dyn_cast(b);
412 if (!aT || !bT)
413 return false;
414
415 if (aT.getElementType() != bT.getElementType())
416 return false;
417
419 }
420
421
422
425
427 return two;
429 return one;
430
431 int64_t rank = one.getRank();
432 if (rank != two.getRank())
433 return {};
434
436 join.reserve(rank);
437 for (int64_t i = 0; i < rank; ++i) {
438 if (one.isDynamicDim(i)) {
439 join.push_back(two.getDimSize(i));
440 continue;
441 }
442 if (two.isDynamicDim(i)) {
443 join.push_back(one.getDimSize(i));
444 continue;
445 }
446 if (one.getDimSize(i) != two.getDimSize(i))
447 return {};
448 join.push_back(one.getDimSize(i));
449 }
451 }
452
453 namespace {
454
455
456
459
460 LogicalResult matchAndRewrite(CastOp tensorCast,
462 auto tensorCastOperand = tensorCast.getOperand().getDefiningOp();
463
464 if (!tensorCastOperand)
465 return failure();
466
467 auto sourceType =
468 llvm::cast(tensorCastOperand.getOperand().getType());
469 auto intermediateType = llvm::cast(tensorCastOperand.getType());
470 auto resultType = llvm::cast(tensorCast.getType());
471
472
473
474 auto firstJoin =
476
477
478 if (!firstJoin)
479 return failure();
480
481
482
483
484 auto newJoin = joinShapes(sourceType, resultType);
485 if (firstJoin != newJoin)
486 return failure();
487
488 rewriter.replaceOpWithNewOp(tensorCast, resultType,
489 tensorCastOperand.getOperand());
490 return success();
491 }
492 };
493
494
495
496
497
498
499
500
501
502
503
504
505
506 struct TensorCastExtractSlice : public OpRewritePattern {
508
509 LogicalResult matchAndRewrite(CastOp tensorCast,
511 auto extractOperand =
512 tensorCast.getOperand().getDefiningOp();
513
514
515 auto rankedResultType =
516 llvm::dyn_cast(tensorCast.getType());
517 if (!rankedResultType)
518 return failure();
519
521 rankedResultType.getShape() ==
522 llvm::cast(tensorCast.getSource().getType())
523 .getShape())
524 return failure();
525
528 extractOperand.getStaticSizes(), extractOperand.getType().getShape());
529 size_t dimIndex = 0;
530 for (size_t i = 0, e = sizes.size(); i < e; i++) {
531 if (dimMask && dimMask->count(i))
532 continue;
533 int64_t dim = rankedResultType.getShape()[dimIndex++];
534 if (ShapedType::isDynamic(dim))
535 continue;
536 sizes[i] = rewriter.getIndexAttr(dim);
537 }
538
539 rewriter.replaceOpWithNewOp(
540 tensorCast, rankedResultType, extractOperand.getSource(),
541 extractOperand.getMixedOffsets(), sizes,
542 extractOperand.getMixedStrides());
543 return success();
544 }
545 };
546
547 }
548
549 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
551 results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
552 }
553
554
555
556
557
558 RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
559 assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
560 auto tensorTypes =
561 llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
562 return llvm::cast(type);
563 }));
564 int64_t concatRank = tensorTypes[0].getRank();
565
566
567 assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
568
570 for (int64_t i = 0, e = concatRank; i < e; ++i) {
571 if (i == dim)
572 continue;
574 for (auto tensorType : tensorTypes)
577 }
579 for (auto tensorType : tensorTypes)
580 concatSize =
582 sizes[dim] = concatSize.asInteger();
584 }
585
588 FailureOr resultType =
589 inferResultType(dim, inputs.getTypes());
590 assert(succeeded(resultType) && "failed to infer concatenation result type");
591 build(builder, result, *resultType, dim, inputs);
592 }
593
595 if (getInputs().size() < 1)
596 return emitOpError("requires at least one input");
597
599 for (auto input : getInputs())
600 inputTypes.push_back(cast(input.getType()));
601
602 RankedTensorType resultType = getResultType();
603 int64_t resultRank = getRank();
604 if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
605 return type.getRank() != resultRank;
606 }))
607 return emitOpError("rank of concatenated inputs must match result rank");
608
609 Type resultElementType = resultType.getElementType();
610 if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
611 return type.getElementType() != resultElementType;
612 }))
613 return emitOpError("inputs and result element type must match");
614
615 int64_t dim = getDim();
616 if (dim >= resultRank)
617 return emitOpError("concatenation dim must be less than the tensor rank");
618
620 for (int64_t i = 0, e = resultRank; i < e; ++i) {
621 if (i == dim)
622 continue;
624 for (auto tensorType : inputTypes) {
625 FailureOr maybeSize =
627 if (failed(maybeSize))
628 return emitOpError("static concatenation size mismatch along ")
629 << "non-concatenated dimension " << i;
630 size = *maybeSize;
631 }
633 }
635 for (auto tensorType : inputTypes)
636 concatSize =
638 sizes[dim] = concatSize.asInteger();
639 auto inferredResultType =
641
642 for (auto [inferredSize, actualSize] :
643 llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
644 bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
645 ShapedType::isDynamic(actualSize);
646 if (!hasDynamic && inferredSize != actualSize)
647 return emitOpError("result type ")
648 << resultType << "does not match inferred shape "
649 << inferredResultType << " static sizes";
650 }
651
652 return success();
653 }
654
655 FailureOr<SmallVector> ConcatOp::decomposeOperation(OpBuilder &builder) {
656 size_t numInputs = getInputs().size();
657 uint64_t concatDim = getDim();
658
660 inputShapes.reserve(numInputs);
662 concatOffsets.reserve(numInputs);
664
669 for (auto [index, input] : llvm::enumerate(getInputs())) {
672 if (index == 0) {
673 outputShape = inputShape;
674 concatOffsets.push_back(zero);
675 } else {
676 concatOffsets.push_back(outputShape[concatDim]);
678 builder, loc, addExpr,
679 {outputShape[concatDim], inputShape[concatDim]});
680 }
681 inputShapes.emplace_back(std::move(inputShape));
682 }
683
684 Value replacement = builder.createtensor::EmptyOp(
685 loc, outputShape, getType().getElementType());
686
687 int64_t rank = getType().getRank();
691 for (auto [index, input] : llvm::enumerate(getInputs())) {
692 offsets[concatDim] = concatOffsets[index];
693 auto insertSlice = builder.createtensor::InsertSliceOp(
694 loc, input, replacement, offsets, inputShapes[index], strides);
695 replacement = insertSlice.getResult();
696 }
697 if (replacement.getType() != getType()) {
698 replacement = builder.createtensor::CastOp(loc, getType(), replacement);
699 }
701 }
702
703 LogicalResult
707 int64_t dim = getDim();
708 RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
709
710 Value init = inputs[0];
711 int64_t rank = getType().getRank();
712
714
715
716
717
718 for (int64_t i = 0; i < rank; ++i) {
719 if (i == dim)
720 continue;
721 if (().isDynamicDim(i)) {
722 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
723 } else if (!inferredResultType.isDynamicDim(i)) {
725 builder, getLoc(),
726 builder.getIndexAttr(inferredResultType.getDimSize(i)));
727 } else {
728 reifiedReturnShapes[0][i] =
729 builder.createtensor::DimOp(init.getLoc(), init, i).getResult();
730 }
731 }
732
733 if (getType().isDynamicDim(dim)) {
734
738 for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
740 sizes.push_back(
741 builder.createOrFoldtensor::DimOp(input.getLoc(), input, dim));
742 }
744 builder, getLoc(),
746 } else {
747
748
749 reifiedReturnShapes[0][dim] =
751 }
752 return success();
753 }
754
755 void ConcatOp::getAsmResultNames(
757 setNameFn(getResult(), "concat");
758 }
759
762 if (inputs.size() == 1 && inputs[0].getType() == getResultType())
763 return inputs[0];
764 return {};
765 }
766
767 namespace {
768
769 struct SingleInputConcatOp : public OpRewritePattern {
771
772 LogicalResult matchAndRewrite(ConcatOp concatOp,
774 if (concatOp.getInputs().size() != 1)
775 return failure();
776 rewriter.replaceOpWithNewOp(concatOp, concatOp.getResultType(),
777 concatOp.getInputs()[0]);
778 return success();
779 }
780 };
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801 struct InferConcatOperandTypes : public OpRewritePattern {
803
804 LogicalResult matchAndRewrite(ConcatOp concatOp,
806 int64_t dim = concatOp.getDim();
807 RankedTensorType inferredResultType =
808 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
809
810
811 LogicalResult matched = failure();
812
813
815 for (auto [operandIdx, operandType] :
817
818 inferredOperandShape[dim] =
819 cast(operandType).getDimSize(dim);
821 inferredOperandShape, inferredResultType.getElementType());
822
823
825 matched = success();
826
827
828 auto castOp =
829 rewriter.create(concatOp->getLoc(), inferredOperandType,
830 concatOp.getOperand(operandIdx));
831 rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
832 concatOp->setOperand(operandIdx, castOp->getResult(0));
833 });
834 }
835 }
836
837 return matched;
838 }
839 };
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855 struct InferConcatResultType : public OpRewritePattern {
857
858 LogicalResult matchAndRewrite(ConcatOp concatOp,
860 int64_t dim = concatOp.getDim();
861 RankedTensorType inferredResultType =
862 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
863
864
866 concatOp.getResultType())) {
867 return failure();
868 }
869
870 auto newConcatOp = rewriter.create(
871 concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands());
872 rewriter.replaceOpWithNewOp(concatOp, concatOp.getResultType(),
873 newConcatOp);
874
875 return success();
876 }
877 };
878 }
879
880 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
882 results
883 .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
884 context);
885 }
886
887
888
889
890
891 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
892 setNameFn(getResult(), "dim");
893 }
894
896 int64_t index) {
898 Value indexValue = builder.createarith::ConstantIndexOp(loc, index);
899 build(builder, result, source, indexValue);
900 }
901
902 std::optional<int64_t> DimOp::getConstantIndex() {
904 }
905
910
911 auto rankedSourceType = dyn_cast(getSource().getType());
912 if (!rankedSourceType)
914
917
919 }
920
923 setResultRange(getResult(),
925 }
926
927 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
928
929 auto index = llvm::dyn_cast_if_present(adaptor.getIndex());
930 if (!index)
931 return {};
932
933
934 auto tensorType = llvm::dyn_cast(getSource().getType());
935 if (!tensorType)
936 return {};
937
938
939
940 int64_t indexVal = index.getInt();
941 if (indexVal < 0 || indexVal >= tensorType.getRank())
942 return {};
943
944
945 if (!tensorType.isDynamicDim(index.getInt())) {
947 return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
948 }
949
950 Operation *definingOp = getSource().getDefiningOp();
951
952
953 if (auto fromElements = dyn_cast_or_nulltensor::GenerateOp(definingOp)) {
954 auto resultType =
955 llvm::cast(fromElements.getResult().getType());
956
957
958 assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
959
960
961 auto dynExtents = fromElements.getDynamicExtents().begin();
962 for (auto dim : resultType.getShape().take_front(index.getInt()))
963 if (ShapedType::isDynamic(dim))
964 dynExtents++;
965
966 return Value{*dynExtents};
967 }
968
969
970 unsigned unsignedIndex = index.getValue().getZExtValue();
971
972 if (auto sliceOp = dyn_cast_or_nulltensor::ExtractSliceOp(definingOp)) {
973
974
975 if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
976 sliceOp.isDynamicSize(unsignedIndex)) {
977 return {sliceOp.getDynamicSize(unsignedIndex)};
978 }
979 }
980
981
983 return getResult();
984
985 return {};
986 }
987
988 namespace {
989
992
993 LogicalResult matchAndRewrite(DimOp dimOp,
995 auto castOp = dimOp.getSource().getDefiningOp();
996 if (!castOp)
997 return failure();
998 Value newSource = castOp.getOperand();
999 rewriter.replaceOpWithNewOp(dimOp, newSource, dimOp.getIndex());
1000 return success();
1001 }
1002 };
1003
1004
1005
1008
1009 LogicalResult matchAndRewrite(DimOp dimOp,
1011 auto source = dimOp.getSource();
1012 auto destOp = source.getDefiningOp();
1013 if (!destOp)
1014 return failure();
1015
1016 auto resultIndex = cast(source).getResultNumber();
1017 auto *initOperand = destOp.getDpsInitOperand(resultIndex);
1018
1020 dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
1021 return success();
1022 }
1023 };
1024
1025
1026
1029
1030 LogicalResult matchAndRewrite(DimOp dim,
1032 auto reshape = dim.getSource().getDefiningOp();
1033
1034 if (!reshape)
1035 return failure();
1036
1037
1038
1040 Location loc = dim.getLoc();
1042 rewriter.create(loc, reshape.getShape(), dim.getIndex());
1043 if (extract.getType() != dim.getType())
1044 extract =
1045 rewriter.createarith::IndexCastOp(loc, dim.getType(), extract);
1046 rewriter.replaceOp(dim, extract);
1047 return success();
1048 }
1049 };
1050 }
1051
1052 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1054 results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
1055 }
1056
1057
1058
1059
1060
1064 assert(none_of(staticShape, ShapedType::isDynamic) &&
1065 "expected only static sizes");
1066 build(builder, result, staticShape, elementType, ValueRange{}, encoding);
1067 }
1068
1073 build(builder, result, tensorType, dynamicSizes);
1074 }
1075
1082 build(builder, result, staticShape, elementType, dynamicSizes, encoding);
1083 }
1084
1087 return emitOpError("incorrect number of dynamic sizes, has ")
1089 << getType().getNumDynamicDims();
1090 return success();
1091 }
1092
1093 LogicalResult
1097 unsigned ctr = 0;
1098 for (int64_t i = 0; i < getType().getRank(); ++i) {
1099 if (getType().isDynamicDim(i)) {
1101 } else {
1102 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
1103 }
1104 }
1105 return success();
1106 }
1107
1108 Value EmptyOp::getDynamicSize(unsigned idx) {
1109 assert(getType().isDynamicDim(idx) && "expected dynamic dim");
1110 unsigned ctr = 0;
1111 for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
1112 if (getType().isDynamicDim(i))
1113 ++ctr;
1115 }
1116
1119 unsigned ctr = 0;
1121 for (int64_t i = 0; i < getType().getRank(); ++i) {
1122 if (getType().isDynamicDim(i)) {
1124 } else {
1125 result.push_back(b.getIndexAttr(getType().getShape()[i]));
1126 }
1127 }
1128 return result;
1129 }
1130
1131 namespace {
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143 struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern {
1145
1146 LogicalResult matchAndRewrite(EmptyOp op,
1150 op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
1151
1152
1153 if (foldedTensorType == op.getType())
1154 return failure();
1155
1156 auto newOp = rewriter.create(op.getLoc(), foldedTensorType,
1157 foldedDynamicSizes);
1158 rewriter.replaceOpWithNewOptensor::CastOp(op, op.getType(), newOp);
1159 return success();
1160 }
1161 };
1162
1163 struct FoldEmptyTensorWithDimOp : public OpRewritePattern {
1165
1166 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1168 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1169 auto emptyTensorOp = dimOp.getSource().getDefiningOp();
1170 if (!emptyTensorOp || !maybeConstantIndex)
1171 return failure();
1172 auto emptyTensorType = emptyTensorOp.getType();
1173 if (*maybeConstantIndex < 0 ||
1174 *maybeConstantIndex >= emptyTensorType.getRank() ||
1175 !emptyTensorType.isDynamicDim(*maybeConstantIndex))
1176 return failure();
1178 emptyTensorOp.getDynamicSize(*maybeConstantIndex));
1179 return success();
1180 }
1181 };
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198 struct FoldEmptyTensorWithCastOp : public OpRewritePattern {
1200
1201 LogicalResult matchAndRewrite(CastOp castOp,
1204 return failure();
1205 auto producer = castOp.getSource().getDefiningOp();
1206 if (!producer)
1207 return failure();
1208
1209 auto resultType =
1210 llvm::cast(castOp->getResult(0).getType());
1214 newMixedSizes.reserve(currMixedSizes.size());
1215 assert(resultShape.size() == currMixedSizes.size() &&
1216 "mismatch in result shape and sizes of empty op");
1217 for (auto it : llvm::zip(resultShape, currMixedSizes)) {
1218 int64_t newDim = std::get<0>(it);
1220
1221
1222 if (auto attr = llvm::dyn_cast_if_present(currDim)) {
1223 if (ShapedType::isDynamic(newDim) ||
1224 newDim != llvm::cast(attr).getInt()) {
1225
1226
1227
1229 producer, "mismatch in static value of shape of empty tensor "
1230 "result and cast result");
1231 }
1232 newMixedSizes.push_back(attr);
1233 continue;
1234 }
1235
1236
1237
1238 if (!ShapedType::isDynamic(newDim)) {
1239 newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1240 continue;
1241 }
1242
1243
1244
1245 newMixedSizes.push_back(currDim);
1246 }
1247
1248
1250 resultType.getElementType());
1251 return success();
1252 }
1253 };
1254
1255 }
1256
1257 void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1259 results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1260 ReplaceEmptyTensorStaticShapeDims>(context);
1261 }
1262
1263
1264
1265
1266
1267 namespace {
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277 struct ExtractFromTensorCast : public OpRewritePatterntensor::ExtractOp {
1279
1280 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1282 auto tensorCast = extract.getTensor().getDefiningOptensor::CastOp();
1283 if (!tensorCast)
1284 return failure();
1285 if (!llvm::isa(tensorCast.getSource().getType()))
1286 return failure();
1288 extract, tensorCast.getSource(), extract.getIndices());
1289 return success();
1290 }
1291 };
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303 struct ExtractFromCollapseShape : public OpRewritePatterntensor::ExtractOp {
1305
1306 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
1308 auto collapseOp =
1309 extractOp.getTensor().getDefiningOptensor::CollapseShapeOp();
1310 if (!collapseOp)
1311 return failure();
1312 if (!collapseOp.getSrcType().hasStaticShape())
1313 return failure();
1314
1315 auto sourceSizes = collapseOp.getSrcType().getShape();
1316
1318 extractOp.getIndices().end());
1320 for (auto [index, group] :
1321 llvm::zip(indices, collapseOp.getReassociationIndices())) {
1322 assert(!group.empty() && "association indices groups cannot be empty");
1323 auto groupSize = group.size();
1324
1325 if (groupSize == 1) {
1326 sourceIndices.push_back(index);
1327 continue;
1328 }
1329
1331 llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
1332 auto delinearize = rewriter.createaffine::AffineDelinearizeIndexOp(
1333 extractOp.getLoc(), index, basis, true);
1334 llvm::append_range(sourceIndices, delinearize.getResults());
1335 }
1336 if (collapseOp.getReassociationIndices().empty()) {
1338 int64_t srcRank =
1339 cast(collapseOp.getSrcType()).getRank();
1341 rewriter, extractOp.getLoc(), zeroAffineMap,
1343 for (int64_t i = 0; i < srcRank; i++) {
1344 sourceIndices.push_back(
1346 }
1347 }
1348
1350 extractOp, collapseOp.getSrc(), sourceIndices);
1351 return success();
1352 }
1353 };
1354
1355 }
1356
1357 void ExtractOp::getAsmResultNames(
1359 setNameFn(getResult(), "extracted");
1360 }
1361
1363
1364 auto tensorType = llvm::cast(getTensor().getType());
1365 if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1366 return emitOpError("incorrect number of indices for extract_element");
1367 return success();
1368 }
1369
1370
1371
1372
1373
1375 auto insertOp = extractOp.getTensor().getDefiningOp();
1376
1377 auto isSame = [](Value a, Value b) {
1379 };
1380 if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
1381 llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
1382 return insertOp.getScalar();
1383
1384 return {};
1385 }
1386
1387 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1388 if (Attribute tensor = adaptor.getTensor()) {
1389
1390
1391 if (auto splatTensor = llvm::dyn_cast(tensor))
1392 return splatTensor.getSplatValue<Attribute>();
1393
1394
1395 if (isa(tensor))
1396 return {};
1397 }
1398
1399
1401 for (Attribute indice : adaptor.getIndices()) {
1402 if (!indice || !llvm::isa(indice))
1403 return {};
1404 indices.push_back(llvm::cast(indice).getInt());
1405 }
1406
1407
1408 if (auto fromElementsOp = getTensor().getDefiningOp()) {
1409 auto tensorType = llvm::cast(fromElementsOp.getType());
1410 auto rank = tensorType.getRank();
1411 assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
1412 "rank mismatch");
1413 int flatIndex = 0;
1414 int stride = 1;
1415 for (int i = rank - 1; i >= 0; --i) {
1416 flatIndex += indices[i] * stride;
1417 stride *= tensorType.getDimSize(i);
1418 }
1419
1420
1421 if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1422 flatIndex < 0)
1423 return {};
1424 return fromElementsOp.getElements()[flatIndex];
1425 }
1426
1427
1428 if (Attribute tensor = adaptor.getTensor()) {
1429 auto elementsAttr = llvm::dyn_cast(tensor);
1430 if (elementsAttr && elementsAttr.isValidIndex(indices))
1431 return elementsAttr.getValues<Attribute>()[indices];
1432 }
1433
1435 return result;
1436
1437 return {};
1438 }
1439
1440 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1442 results.add(context);
1443 }
1444
1447 patterns.add(patterns.getContext());
1448 }
1449
1450
1451
1452
1453
1454 void FromElementsOp::getAsmResultNames(
1456 setNameFn(getResult(), "from_elements");
1457 }
1458
1461 assert(!elements.empty() && "expected at least one element");
1463 {static_cast<int64_t>(elements.size())}, elements.front().getType());
1464 build(builder, result, resultType, elements);
1465 }
1466
1467 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1468 if (!llvm::is_contained(adaptor.getElements(), nullptr))
1470 return {};
1471 }
1472
1473 namespace {
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491 struct ExtractElementFromIndexCast
1494
1495 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1497 Location loc = extract.getLoc();
1498 auto indexCast = extract.getTensor().getDefiningOparith::IndexCastOp();
1499 if (!indexCast)
1500 return failure();
1501
1503
1504 auto newExtract = rewriter.createtensor::ExtractOp(
1505 loc, elementTy, indexCast.getIn(), extract.getIndices());
1506
1507 rewriter.replaceOpWithNewOparith::IndexCastOp(extract, extract.getType(),
1508 newExtract);
1509
1510 return success();
1511 }
1512 };
1513
1514 }
1515
1516 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
1518 results.add(context);
1519 }
1520
1521
1522
1523
1524
1525 void GatherOp::getAsmResultNames(
1527 setNameFn(getResult(), "gather");
1528 }
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542 RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1543 RankedTensorType indicesType,
1545 bool rankReduced) {
1547 resultShape.reserve(resultShape.size() + sourceType.getRank());
1548 for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1549 if (llvm::binary_search(gatherDims, idx)) {
1550 if (!rankReduced)
1551 resultShape.push_back(1);
1552 continue;
1553 }
1554 resultShape.push_back(sourceType.getDimSize(idx));
1555 }
1557 }
1558
1559 static LogicalResult
1562 StringRef gatherOrScatter, StringRef sourceOrDest) {
1563 if (dims.empty())
1564 return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
1565
1566 int64_t numGatherDims = dims.size();
1567 if (numGatherDims > rank)
1569 << "_dims overflow " << sourceOrDest << " rank";
1570 if (indices.empty() || indices.back() != numGatherDims)
1572 << "_dims length must match the size of last dimension of indices";
1573 for (int64_t val : dims) {
1574 if (val < 0)
1576 << "_dims value must be non-negative";
1577 if (val >= rank)
1579 << "_dims value must be smaller than " << sourceOrDest << " rank";
1580 }
1581 for (int64_t i = 1; i < numGatherDims; ++i) {
1582 if (dims[i - 1] >= dims[i])
1584 << "_dims values must be strictly increasing";
1585 }
1586 return success();
1587 }
1588
1590 int64_t sourceRank = getSourceType().getRank();
1593 getIndicesType().getShape(), sourceRank,
1594 "gather", "source")))
1595 return failure();
1596
1597 RankedTensorType expectedResultType = GatherOp::inferResultType(
1598 getSourceType(), getIndicesType(), gatherDims, false);
1599 RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1600 getSourceType(), getIndicesType(), gatherDims, true);
1601 if (getResultType() != expectedResultType &&
1602 getResultType() != expectedRankReducedResultType) {
1603 return emitOpError("result type "
1604 "mismatch: "
1605 "expected ")
1606 << expectedResultType << " or its rank-reduced variant "
1607 << expectedRankReducedResultType << " (got: " << getResultType()
1608 << ")";
1609 }
1610
1611 return success();
1612 }
1613
1614 OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1615 if (OpFoldResult reshapedSource = reshapeConstantSource(
1616 llvm::dyn_cast_if_present(adaptor.getSource()),
1618 return reshapedSource;
1619 return {};
1620 }
1621
1622
1623
1624
1625
1626 void InsertOp::getAsmResultNames(
1628 setNameFn(getResult(), "inserted");
1629 }
1630
1632
1633 auto destType = llvm::cast(getDest().getType());
1634 if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1635 return emitOpError("incorrect number of indices");
1636 return success();
1637 }
1638
1639 OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1640 Attribute scalar = adaptor.getScalar();
1641 Attribute dest = adaptor.getDest();
1642 if (scalar && dest)
1643 if (auto splatDest = llvm::dyn_cast(dest))
1644 if (scalar == splatDest.getSplatValue<Attribute>())
1645 return dest;
1646 return {};
1647 }
1648
1649
1650
1651
1652
1653 void GenerateOp::getAsmResultNames(
1655 setNameFn(getResult(), "generated");
1656 }
1657
1661 int idx = 0;
1662 for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1663 if (getType().isDynamicDim(dim)) {
1664 reifiedReturnShapes[0][dim] = getOperand(idx++);
1665 } else {
1666 reifiedReturnShapes[0][dim] =
1668 }
1669 }
1670 return success();
1671 }
1672
1674
1675
1676 RankedTensorType resultType = llvm::cast(getType());
1677 if (getNumOperands() != resultType.getNumDynamicDims())
1678 return emitError("must have as many index operands as dynamic extents "
1679 "in the result type");
1680 return success();
1681 }
1682
1683 LogicalResult GenerateOp::verifyRegions() {
1684 RankedTensorType resultTy = llvm::cast(getType());
1685
1686 if (!llvm::all_of(getBody().getArgumentTypes(),
1688 return emitError("all body arguments must be index");
1689 if (getBody().getNumArguments() != resultTy.getRank())
1690 return emitError("must have one body argument per input dimension");
1691
1692
1693 auto yieldOp = cast(getBody().getBlocks().front().getTerminator());
1694
1695 if (yieldOp.getValue().getType() != resultTy.getElementType())
1696 return emitOpError(
1697 "body must be terminated with a `yield` operation of the tensor "
1698 "element type");
1699
1700 return success();
1701 }
1702
1703 void GenerateOp::build(
1707 build(b, result, resultTy, dynamicExtents);
1708
1709
1711 Region *bodyRegion = result.regions.front().get();
1712 auto rank = llvm::cast(resultTy).getRank();
1715 Block *bodyBlock =
1716 b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1718 }
1719
1720 namespace {
1721
1722
1723
1724
1725
1726 struct StaticTensorGenerate : public OpRewritePattern {
1728
1729 LogicalResult matchAndRewrite(GenerateOp generateOp,
1733 generateOp.getType(), generateOp.getDynamicExtents(),
1734 foldedDynamicSizes);
1735
1736
1737 if (foldedTensorType == generateOp.getType())
1738 return failure();
1739
1740 auto loc = generateOp.getLoc();
1741 auto newOp =
1742 rewriter.create(loc, foldedTensorType, foldedDynamicSizes);
1744 newOp.getBody().begin());
1746 generateOp.getType(), newOp);
1747 return success();
1748 }
1749 };
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762 struct ExtractFromTensorGenerate : public OpRewritePatterntensor::ExtractOp {
1764
1765 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1767 auto tensorFromElements = extract.getTensor().getDefiningOp();
1769 return failure();
1770
1772 Block *body = &tensorFromElements.getBody().front();
1773 mapping.map(body->getArguments(), extract.getIndices());
1775 rewriter.clone(op, mapping);
1776
1777 auto yield = cast(body->getTerminator());
1778
1780 return success();
1781 }
1782 };
1783
1784 }
1785
1786 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1788
1789 results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1790 }
1791
1792
1793
1794
1795
1796 void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1797 setNameFn(getResult(), "rank");
1798 }
1799
1800 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1801
1802 auto type = getOperand().getType();
1803 auto shapedType = llvm::dyn_cast(type);
1804 if (shapedType && shapedType.hasRank())
1806 return IntegerAttr();
1807 }
1808
1809
1810
1811
1812
1813 void ReshapeOp::getAsmResultNames(
1815 setNameFn(getResult(), "reshape");
1816 }
1817
1819 int64_t numElements = 1;
1820 for (auto dim : type.getShape())
1821 numElements *= dim;
1822 return numElements;
1823 }
1824
1826 TensorType operandType = llvm::cast(getSource().getType());
1827 TensorType resultType = llvm::cast(getResult().getType());
1828
1830 return emitOpError("element types of source and destination tensor "
1831 "types should be the same");
1832
1833 int64_t shapeSize =
1834 llvm::cast(getShape().getType()).getDimSize(0);
1835 auto resultRankedType = llvm::dyn_cast(resultType);
1836 auto operandRankedType = llvm::dyn_cast(operandType);
1837
1838 if (resultRankedType) {
1839 if (operandRankedType && resultRankedType.hasStaticShape() &&
1840 operandRankedType.hasStaticShape()) {
1842 return emitOpError("source and destination tensor should have the "
1843 "same number of elements");
1844 }
1845 if (ShapedType::isDynamic(shapeSize))
1846 return emitOpError("cannot use shape operand with dynamic length to "
1847 "reshape to statically-ranked tensor type");
1848 if (shapeSize != resultRankedType.getRank())
1849 return emitOpError(
1850 "length of shape operand differs from the result's tensor rank");
1851 }
1852 return success();
1853 }
1854
1855 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1856 if (OpFoldResult reshapedSource = reshapeConstantSource(
1857 llvm::dyn_cast_if_present(adaptor.getSource()),
1859 return reshapedSource;
1860
1861
1862
1863
1864 if (auto reshapeOpProducer = getSource().getDefiningOp()) {
1865 getSourceMutable().assign(reshapeOpProducer.getSource());
1866 return getResult();
1867 }
1868
1869 auto source = getSource();
1870 auto sourceTy = dyn_cast(source.getType());
1871 auto resultTy = dyn_cast(getType());
1872 if (!sourceTy || !resultTy || sourceTy != resultTy)
1873 return {};
1874
1875
1876
1877 if (sourceTy.getRank() == 1)
1878 return source;
1879
1880 if (auto fromElements = getShape().getDefiningOptensor::FromElementsOp()) {
1881 auto elements = fromElements.getElements();
1882 bool dynamicNoop =
1883 sourceTy.getRank() == static_cast<int64_t>(elements.size());
1884 for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
1885 auto element = elements[id];
1886
1888 dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
1889 continue;
1890 }
1891
1892 if (auto dimOp = element.getDefiningOptensor::DimOp()) {
1893 dynamicNoop &= dimOp.getSource() == source;
1894
1896 dynamicNoop &=
1897 cst.has_value() && cst.value() == static_cast<int64_t>(id);
1898 continue;
1899 }
1900
1901 dynamicNoop = false;
1902 break;
1903 }
1904
1905 if (dynamicNoop)
1906 return source;
1907 }
1908
1909 return {};
1910 }
1911
1912
1913
1914
1915
1916 void CollapseShapeOp::getAsmResultNames(
1918 setNameFn(getResult(), "collapsed");
1919 }
1920
1921 void ExpandShapeOp::getAsmResultNames(
1923 setNameFn(getResult(), "expanded");
1924 }
1925
1926 int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1927 assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1928 "invalid resultDim");
1929 for (const auto &it : llvm::enumerate(getReassociationIndices()))
1930 if (llvm::is_contained(it.value(), resultDim))
1931 return it.index();
1932 llvm_unreachable("could not find reassociation group");
1933 }
1934
1935 FailureOr<SmallVector>
1937 RankedTensorType expandedType,
1940 std::optional<SmallVector> outputShape =
1942 inputShape);
1943 if (!outputShape)
1944 return failure();
1945 return *outputShape;
1946 }
1947
1950 }
1951
1956 auto [staticOutputShape, dynamicOutputShape] =
1958 build(builder, result, cast(resultType), src,
1960 dynamicOutputShape, staticOutputShape);
1961 }
1962
1968 auto tensorResultTy = cast(resultType);
1969 FailureOr<SmallVector> outputShape = inferOutputShape(
1970 builder, result.location, tensorResultTy, reassociation, inputShape);
1972 if (succeeded(outputShape)) {
1973 outputShapeOrEmpty = *outputShape;
1974 }
1975 build(builder, result, tensorResultTy, src, reassociation,
1976 outputShapeOrEmpty);
1977 }
1978
1981 }
1984 getReassociationIndices());
1985 }
1986
1989 }
1992 getReassociationIndices());
1993 }
1994
1995 RankedTensorType CollapseShapeOp::inferCollapsedType(
1997 return inferCollapsedType(
1999 type.getContext(), reassociation)));
2000 }
2001
2002
2003
2004 RankedTensorType
2005 CollapseShapeOp::inferCollapsedType(RankedTensorType type,
2007 auto shape = type.getShape();
2009 newShape.reserve(reassociation.size());
2010
2011
2012
2014 unsigned currentDim = 0;
2015 for (AffineMap m : reassociation) {
2016 unsigned dim = m.getNumResults();
2017 auto band = shape.slice(currentDim, dim);
2018 int64_t size = 1;
2019 if (llvm::is_contained(band, ShapedType::kDynamic))
2020 size = ShapedType::kDynamic;
2021 else
2022 for (unsigned d = 0; d < dim; ++d)
2023 size *= shape[currentDim + d];
2024 newShape.push_back(size);
2025 currentDim += dim;
2026 }
2027
2029 }
2030
2034 auto resultType = inferCollapsedType(
2035 llvm::cast(src.getType()),
2038 result.addAttribute(getReassociationAttrStrName(),
2040 build(b, result, resultType, src, attrs);
2041 }
2042
2043 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
2044 TensorReshapeOp, ExpandShapeOp>::value>
2046 RankedTensorType expandedType,
2047 RankedTensorType collapsedType) {
2048 if (failed(
2050 return failure();
2051
2052 auto maps = op.getReassociationMaps();
2053 RankedTensorType expectedType =
2054 CollapseShapeOp::inferCollapsedType(expandedType, maps);
2056 return op.emitOpError("expected collapsed type to be ")
2057 << expectedType << ", but got " << collapsedType;
2058 return success();
2059 }
2060
2062 auto srcType = getSrcType();
2063 auto resultType = getResultType();
2064
2065 if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2066 return emitOpError("expected number of static shape dims to be equal to "
2067 "the output rank (")
2068 << resultType.getRank() << ") but found "
2069 << getStaticOutputShape().size() << " inputs instead";
2070
2071 if ((int64_t)getOutputShape().size() !=
2072 llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2073 return emitOpError("mismatch in dynamic dims in output_shape and "
2074 "static_output_shape: static_output_shape has ")
2075 << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2076 << " dynamic dims while output_shape has " << getOutputShape().size()
2077 << " values";
2078
2080 }
2081
2084 }
2085
2086 namespace {
2087
2088
2089 template
2090 struct FoldReshapeWithConstant : OpRewritePattern {
2092 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2096 return failure();
2097 if (!attr || !attr.isSplat())
2098 return failure();
2100 reshapeOp.getResultType(), attr.getRawData());
2102 return success();
2103 }
2104 };
2105
2106
2107 template
2108 class FoldReshapeWithSplat : public OpRewritePattern {
2109 public:
2111
2112 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2114 auto splatOp = reshapeOp.getSrc().template getDefiningOptensor::SplatOp();
2115 if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
2116 return failure();
2117
2119 reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
2120 return success();
2121 }
2122 };
2123
2124
2125
2126 template
2127 struct FoldReshapeWithFromElements : OpRewritePattern {
2129 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2131 auto fromElements =
2132 reshapeOp.getSrc().template getDefiningOp();
2133 if (!fromElements)
2134 return failure();
2135
2136 auto shapedTy = llvm::cast(reshapeOp.getType());
2137
2138 if (!shapedTy.hasStaticShape())
2139 return failure();
2140
2141 rewriter.replaceOpWithNewOp(reshapeOp, reshapeOp.getType(),
2142 fromElements.getElements());
2143 return success();
2144 }
2145 };
2146
2147
2148 struct FoldCollapseOfCastOp : public OpRewritePattern {
2150
2151 LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
2153 auto castOp = collapseShapeOp.getSrc().getDefiningOptensor::CastOp();
2155 return failure();
2156
2157 RankedTensorType srcType =
2158 llvm::cast(castOp.getSource().getType());
2159 RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
2160 srcType, collapseShapeOp.getReassociationMaps());
2161
2162 if (newResultType == collapseShapeOp.getResultType()) {
2164 collapseShapeOp.getSrcMutable().assign(castOp.getSource());
2165 });
2166 } else {
2167 auto newOp = rewriter.create(
2168 collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
2169 collapseShapeOp.getReassociation());
2171 collapseShapeOp, collapseShapeOp.getResultType(), newOp);
2172 }
2173 return success();
2174 }
2175 };
2176
2177
2178
2179
2180
2181 struct ConvertToStaticExpandShape : public OpRewritePattern {
2183
2184 LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
2186 auto castOp = expandOp.getSrc().getDefiningOp();
2188 return failure();
2189
2190 ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
2192 expandOp.getReassociationIndices();
2193
2196 auto outputIt = expandOp.getOutputShape().begin();
2197
2198 for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
2199 for (uint64_t outDim : innerReassoc) {
2200 if (!ShapedType::isDynamic(newOutputShape[outDim]))
2201 continue;
2202
2203
2204
2205
2206
2207 Value val = *outputIt;
2208 ++outputIt;
2209 if (ShapedType::isDynamic(castSrcShape[inputDim])) {
2210 dynamicOutputShape.push_back(val);
2211 continue;
2212 }
2213
2214 APInt cst;
2216 newOutputShape[outDim] = cst.getSExtValue();
2217 } else {
2218 dynamicOutputShape.push_back(val);
2219 }
2220 }
2221 }
2222
2223
2224 if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
2225 return failure();
2226
2227
2229 for (auto inDim : llvm::seq(0, newInputShape.size())) {
2230 for (auto outDim : reassoc[inDim]) {
2231 auto ofr = newOutputShape[outDim];
2232 if (ShapedType::isDynamic(ofr)) {
2233 newInputShape[inDim] = ShapedType::kDynamic;
2234 break;
2235 }
2236 newInputShape[inDim] *= ofr;
2237 }
2238 }
2239
2241 getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
2243 newInputShape, expandOp.getSrcType().getElementType());
2245 newOutputShape, expandOp.getSrcType().getElementType());
2246 auto inputCast = rewriter.create(expandOp.getLoc(), inputType,
2247 expandOp.getSrc());
2248 auto newExpand = rewriter.create(
2249 expandOp.getLoc(), outputType, inputCast.getResult(),
2250 expandOp.getReassociationIndices(), outputOfr);
2252 newExpand.getResult());
2253 return success();
2254 }
2255 };
2256 }
2257
2258 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2260 results.add<
2263 ConvertToStaticExpandShape, FoldReshapeWithConstant,
2264 FoldReshapeWithSplat,
2265 FoldReshapeWithFromElements>(context);
2266 }
2267
2268 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2270 results.add<
2273 tensor::DimOp, RankedTensorType>,
2274 FoldReshapeWithConstant,
2275 FoldReshapeWithSplat,
2276 FoldReshapeWithFromElements, FoldCollapseOfCastOp>(
2277 context);
2278 }
2279
2280 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2281 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2282 adaptor.getOperands());
2283 }
2284
2285 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2286 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2287 adaptor.getOperands());
2288 }
2289
2290
2291
2292
2293
2294 void ExtractSliceOp::getAsmResultNames(
2296 setNameFn(getResult(), "extracted_slice");
2297 }
2298
2299
2300
2301
2302 RankedTensorType ExtractSliceOp::inferResultType(
2303 RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
2305
2306
2307
2308 assert(static_cast<int64_t>(staticSizes.size()) ==
2309 sourceTensorType.getRank() &&
2310 "unexpected staticSizes not equal to rank of source");
2312 sourceTensorType.getEncoding());
2313 }
2314
2315 RankedTensorType ExtractSliceOp::inferResultType(
2323 return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
2324 staticSizes, staticStrides);
2325 }
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2336 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2339
2340 auto inferredType = llvm::cast(
2341 inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2342 int rankDiff = inferredType.getRank() - desiredResultRank;
2343 if (rankDiff > 0) {
2344 auto shape = inferredType.getShape();
2345 llvm::SmallBitVector dimsToProject =
2348
2349 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2350 if (!dimsToProject.test(pos))
2351 projectedShape.push_back(shape[pos]);
2352 inferredType =
2354 }
2355 return inferredType;
2356 }
2357
2358 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2359 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2367 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2368 desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
2369 staticStrides);
2370 }
2371
2372
2373
2375 RankedTensorType resultType, Value source,
2385 auto sourceRankedTensorType = llvm::cast(source.getType());
2386
2387 if (!resultType) {
2388 resultType = llvm::cast(ExtractSliceOp::inferResultType(
2389 sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
2390 }
2392 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2396 }
2397
2398
2399
2405 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2406 }
2407
2408
2409
2414 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2415 }
2416
2417
2418
2420 RankedTensorType resultType, Value source,
2424 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2426 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2428 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2429 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2430 }
2431
2432
2436 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2437 }
2438
2441 RankedTensorType expectedType) {
2442 switch (result) {
2444 return success();
2446 return op->emitError("expected rank to be smaller or equal to ")
2447 << "the other rank. ";
2449 return op->emitError("expected type to be ")
2450 << expectedType << " or a rank-reduced version. (size mismatch) ";
2452 return op->emitError("expected element type to be ")
2453 << expectedType.getElementType();
2454 default:
2455 llvm_unreachable("unexpected extract_slice op verification result");
2456 }
2457 }
2458
2459
2461 RankedTensorType sourceType = getSourceType();
2462
2463
2464 RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2465 sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());
2469
2470
2471
2473 sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
2474 getStaticStrides(), true);
2475 if (!boundsResult.isValid)
2476 return getOperation()->emitError(boundsResult.errorMessage);
2477
2478 return success();
2479 }
2480
2483 }
2484
2485 FailureOr
2488 auto sourceTensorType = llvm::dyn_cast(value.getType());
2489 assert(sourceTensorType && "not a ranked tensor type");
2490 auto sourceShape = sourceTensorType.getShape();
2491 if (sourceShape.equals(desiredShape))
2492 return value;
2493 auto maybeRankReductionMask =
2495 if (!maybeRankReductionMask)
2496 return failure();
2498 b, loc, value,
2500 }
2501
2504 reifiedReturnShapes.resize(1);
2505 reifiedReturnShapes[0].reserve(getType().getRank());
2507 llvm::SmallBitVector droppedDims = getDroppedDims();
2508 for (const auto &size : enumerate(mixedSizes)) {
2509 if (droppedDims.test(size.index()))
2510 continue;
2511 reifiedReturnShapes[0].push_back(size.value());
2512 }
2513 return success();
2514 }
2515
2516 namespace {
2517
2518
2519
2520
2521
2522
2523
2524
2525
2526
2527
2528
2529
2530
2531
2532 class ExtractSliceOpCastFolder final : public OpRewritePattern {
2533 public:
2535
2536 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2538
2539 if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2540 return matchPattern(operand, matchConstantIndex());
2541 }))
2542 return failure();
2543
2544 auto castOp = sliceOp.getSource().getDefiningOp();
2545 if (!castOp)
2546 return failure();
2547
2549 return failure();
2550
2551
2553 cast(castOp.getSource().getType()).getShape(),
2554 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2555 sliceOp.getStaticStrides());
2556 if (!sliceResult.isValid)
2557 return failure();
2558
2559
2560 Location loc = sliceOp.getLoc();
2561 Value newResult = rewriter.create(
2562 loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
2563 sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
2564 sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
2565 rewriter.replaceOp(sliceOp, newResult);
2566 return success();
2567 }
2568 };
2569
2570
2571
2572
2573 template <typename IterTy, typename ElemTy>
2574 static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2578 assert(offsets.size() == sizes.size());
2579 assert(offsets.size() == strides.size());
2580 if (offsets.empty())
2581 return;
2582
2583 int64_t offset = offsets.front();
2584 int64_t size = sizes.front();
2585 int64_t stride = strides.front();
2586 if (offsets.size() == 1) {
2587 for (int64_t i = 0; i < size; ++i, offset += stride)
2588 outValues->push_back(*(values + offset));
2589
2590 return;
2591 }
2592
2593 for (int64_t i = 0; i < size; ++i, offset += stride) {
2594 auto begin = values + offset * counts.front();
2595 sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2596 offsets.drop_front(), sizes.drop_front(),
2597 strides.drop_front(), outValues);
2598 }
2599 }
2600
2601
2602
2603
2604 class ConstantOpExtractSliceFolder final
2606 public:
2608
2609 ConstantOpExtractSliceFolder(MLIRContext *context,
2612 controlFn(std::move(controlFn)) {}
2613
2614 LogicalResult matchAndRewrite(ExtractSliceOp op,
2618 return failure();
2619
2620
2622 return failure();
2623
2624
2625 auto sourceType = llvm::cast(op.getSource().getType());
2626 auto resultType = llvm::cast(op.getResult().getType());
2627 if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2628 return failure();
2629
2630
2631 if (!controlFn(op))
2632 return failure();
2633
2634 int64_t count = sourceType.getNumElements();
2635 if (count == 0)
2636 return failure();
2637
2638
2639 auto offsets = op.getStaticOffsets();
2640 if (llvm::is_contained(offsets, ShapedType::kDynamic))
2641 return failure();
2642 auto sizes = op.getStaticSizes();
2643 if (llvm::is_contained(sizes, ShapedType::kDynamic))
2644 return failure();
2645 auto strides = op.getStaticStrides();
2646 if (llvm::is_contained(strides, ShapedType::kDynamic))
2647 return failure();
2648
2649
2652 counts.reserve(shape.size());
2653 for (int64_t v : shape) {
2654 count = count / v;
2655 counts.push_back(count);
2656 }
2657
2658
2660
2661 if (auto elems = llvm::dyn_cast(attr)) {
2663 outValues.reserve(sourceType.getNumElements());
2664 sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2665 elems.begin(), counts, offsets, sizes, strides, &outValues);
2667 } else if (auto elems = llvm::dyn_cast(attr)) {
2669 outValues.reserve(sourceType.getNumElements());
2670 sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2671 elems.begin(), counts, offsets, sizes, strides, &outValues);
2673 }
2674
2675 if (newAttr) {
2676 rewriter.replaceOpWithNewOparith::ConstantOp(op, resultType, newAttr);
2677 return success();
2678 }
2679
2680 return failure();
2681 }
2682
2683 private:
2684
2685
2687 };
2688
2689 }
2690
2694 patterns.add(patterns.getContext(), controlFn);
2695 }
2696
2697
2703 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2704 op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2705 mixedStrides);
2706 }
2707 };
2708
2709
2712 ExtractSliceOp newOp) {
2713 Value replacement = newOp.getResult();
2714 if (replacement.getType() != op.getType())
2715 replacement = rewriter.createtensor::CastOp(op.getLoc(), op.getType(),
2716 replacement);
2717 rewriter.replaceOp(op, replacement);
2718 }
2719 };
2720
2721 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2723 results.add<
2726 ExtractSliceOpCastFolder>(context);
2727 }
2728
2729
2730 static LogicalResult
2732 ShapedType shapedType) {
2734 for (OpFoldResult ofr : op.getMixedOffsets())
2736 return failure();
2737
2738
2739 auto shape = shapedType.getShape();
2740 for (auto it : llvm::zip(op.getMixedSizes(), shape))
2742 return failure();
2743 for (OpFoldResult ofr : op.getMixedStrides())
2745 return failure();
2746 return success();
2747 }
2748
2749
2750
2751
2752
2754 auto insertOp = extractOp.getSource().getDefiningOp();
2755
2757 if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2758 insertOp.isSameAs(extractOp, isSame))
2759 return insertOp.getSource();
2760
2761 return {};
2762 }
2763
2764 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2765   if (OpFoldResult reshapedSource = reshapeConstantSource(
2766           llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2767           getResult().getType()))
2768     return reshapedSource;
2769   if (getSourceType() == getType() &&
2770       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2771     return this->getSource();
2772   if (Value slice = foldExtractAfterInsertSlice(*this))
2773     return slice;
2774
2775   return OpFoldResult();
2776 }
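// Illustrative fold (hypothetical IR): an extract_slice that spans its whole
// source with zero offsets and unit strides folds to the source itself:
//
//   %r = tensor.extract_slice %t[0, 0] [4, 8] [1, 1]
//       : tensor<4x8xf32> to tensor<4x8xf32>
//
// folds to plain %t, since the offsets are all 0, the sizes match the static
// source shape, and the strides are all 1.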
2777
2778 Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2779     OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2780   auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2781   unsigned rank = rankedTensorType.getRank();
2782   SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2783   SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2784   SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2785   return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2786                                                 offsets, sizes, strides);
2787 }
2788
2789 //===----------------------------------------------------------------------===//
2790 // InsertSliceOp
2791 //===----------------------------------------------------------------------===//
2792
2793 void InsertSliceOp::getAsmResultNames(
2794     function_ref<void(Value, StringRef)> setNameFn) {
2795   setNameFn(getResult(), "inserted_slice");
2796 }
2797
2798 // Build an InsertSliceOp with mixed static and dynamic entries.
2799 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2800                           Value dest, ArrayRef<OpFoldResult> offsets,
2801                           ArrayRef<OpFoldResult> sizes,
2802                           ArrayRef<OpFoldResult> strides,
2803                           ArrayRef<NamedAttribute> attrs) {
2804   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2805   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2806   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2807   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2808   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2809   result.addAttributes(attrs);
2810   build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2811         dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2812         b.getDenseI64ArrayAttr(staticSizes),
2813         b.getDenseI64ArrayAttr(staticStrides));
2814 }
2815
2816 /// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2817 /// Range vector.
2818 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2819                           Value dest, ArrayRef<Range> ranges,
2820                           ArrayRef<NamedAttribute> attrs) {
2821   auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2822   build(b, result, source, dest, offsets, sizes, strides, attrs);
2823 }
2824
2825 // Build an InsertSliceOp with dynamic entries.
2826 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2827                           Value dest, ValueRange offsets, ValueRange sizes,
2828                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2829   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2830       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2831   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2832       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2833   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2834       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2835   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2836 }
2837
2838 /// Rank-reducing type verification for both InsertSliceOp and
2839 /// ParallelInsertSliceOp.
2840 static SliceVerificationResult verifyInsertSliceOp(
2841     RankedTensorType srcType, RankedTensorType dstType,
2842     ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2843     ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2844   // insert_slice is the inverse of extract_slice, use the same type
2845   // inference.
2846   RankedTensorType expected = ExtractSliceOp::inferResultType(
2847       dstType, staticOffsets, staticSizes, staticStrides);
2848   if (expectedType)
2849     *expectedType = expected;
2850   return isRankReducedType(expected, srcType);
2851 }
2852
2853 /// Verifier for InsertSliceOp.
2854 LogicalResult InsertSliceOp::verify() {
2855   // Verify result type against inferred type.
2856   RankedTensorType expectedType;
2857   SliceVerificationResult result =
2858       verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2859                           getStaticSizes(), getStaticStrides(), &expectedType);
2860   if (result != SliceVerificationResult::Success)
2861     return produceSliceErrorMsg(result, *this, expectedType);
2862
2863   // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2864   // to the destination tensor.
2865   SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2866       getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
2867       getStaticStrides(), /*generateErrorMessage=*/true);
2868   if (!boundsResult.isValid)
2869     return getOperation()->emitError(boundsResult.errorMessage);
2870
2871   return success();
2872 }
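// Illustrative note on the verifier (hypothetical IR): a rank-reducing insert
// such as
//
//   %r = tensor.insert_slice %src into %dst[0, 0, 0] [1, 16, 4] [1, 1, 1]
//       : tensor<16x4xf32> into tensor<8x16x4xf32>
//
// passes verification because tensor<16x4xf32> is a rank-reduced version of
// the inferred slice type tensor<1x16x4xf32>; mismatched sizes or
// out-of-bounds offsets are rejected with the errors produced above.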
2873
2874 /// If we have two consecutive InsertSliceOp writing to the same slice, we
2875 /// can mutate the second InsertSliceOp's dest to the first one's.
2876 ///
2877 /// Example:
2878 ///
2879 /// ```mlir
2880 ///   %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2881 ///   %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2882 /// ```
2883 ///
2884 /// folds into:
2885 ///
2886 /// ```mlir
2887 ///   %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2888 /// ```
2889 ///
2890 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2891 static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2892   auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2893
2894   auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2895   if (!prevInsertOp ||
2896       prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2897       !prevInsertOp.isSameAs(insertOp, isSame))
2898     return failure();
2899
2900   insertOp.getDestMutable().assign(prevInsertOp.getDest());
2901   return success();
2902 }
2903
2904 /// Folds round-trip extract/insert slice op pairs.
2905 /// Example:
2906 /// ```mlir
2907 ///   %0 = tensor.extract_slice %val[0, 0] [1, 4] [1, 1]
2908 ///   %1 = tensor.insert_slice %0 into %val[0, 0] [1, 4] [1, 1]
2909 /// ```
2910 /// can be folded into %val.
2911 static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2912   auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2913
2914   auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2915   if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2916       !extractOp.isSameAs(insertOp, isSame))
2917     return nullptr;
2918
2919   return extractOp.getSource();
2920 }
2921
2922 OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2923   if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2924       getSourceType() == getType() &&
2925       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2926     return this->getSource();
2927   if (succeeded(foldInsertAfterInsertSlice(*this)))
2928     return getResult();
2929   if (auto result = foldInsertAfterExtractSlice(*this))
2930     return result;
2931   if (llvm::all_of(getMixedSizes(), isZeroInteger))
2932     return getDest();
2933   return OpFoldResult();
2934 }
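// Illustrative round-trip fold (hypothetical IR):
//
//   %0 = tensor.extract_slice %val[0, 0] [2, 4] [1, 1]
//       : tensor<8x8xf32> to tensor<2x4xf32>
//   %1 = tensor.insert_slice %0 into %val[0, 0] [2, 4] [1, 1]
//       : tensor<2x4xf32> into tensor<8x8xf32>
//
// Here %1 folds to %val, because the insert re-inserts exactly the slice that
// was extracted, at the same offsets/sizes/strides.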
2935
2936 LogicalResult InsertSliceOp::reifyResultShapes(
2937     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2938   reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2939   reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2940   return success();
2941 }
2942
2943 namespace {
2944 /// Pattern to rewrite an insert_slice op with constant arguments.
2945 ///
2946 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2947 template <typename InsertOpTy>
2948 class InsertSliceOpConstantArgumentFolder final
2949     : public OpRewritePattern<InsertOpTy> {
2950 public:
2951   using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2952
2953   LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2954                                 PatternRewriter &rewriter) const override {
2955     SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2956     SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2957     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2958
2959     // No constant operands were folded, just return.
2960     if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
2961         failed(foldDynamicOffsetSizeList(mixedSizes)) &&
2962         failed(foldDynamicStrideList(mixedStrides)))
2963       return failure();
2964
2965     // Pattern does not apply if the produced op would not verify.
2966     SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
2967         insertSliceOp.getDest().getType().getShape(),
2968         mixedOffsets, mixedSizes, mixedStrides);
2969     if (!sliceResult.isValid)
2970       return failure();
2971
2972     // Create the new op in canonical form.
2973     auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2974         insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2975         mixedOffsets, mixedSizes, mixedStrides);
2976     Value toInsert = insertSliceOp.getSource();
2977     if (sourceType != insertSliceOp.getSourceType()) {
2978       OpBuilder::InsertionGuard g(rewriter);
2979       // The only difference between InsertSliceOp and ParallelInsertSliceOp
2980       // is that the insertion point is just before the ParallelCombiningOp in
2981       // the parallel case.
2982       if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2983         rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2984       toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2985                                                  sourceType, toInsert);
2986     }
2987     rewriter.replaceOpWithNewOp<InsertOpTy>(
2988         insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2989         mixedSizes, mixedStrides);
2990     return success();
2991   }
2992 };
2993
2994 /// Fold tensor_casts with insert_slice operations. If the source or
2995 /// destination tensor is a tensor_cast that removes static type information,
2996 /// the cast is folded into the insert_slice operation. E.g.:
2997 ///
2998 /// ```mlir
2999 ///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
3000 ///   %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
3001 /// ```
3002 ///
3003 /// is rewritten into:
3004 ///
3005 /// ```mlir
3006 ///   %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
3007 /// ```
3008 ///
3009 /// Note: When folding a cast on the destination tensor, the result of the
3010 /// insert_slice operation is casted to ensure that the type of the result did
3011 /// not change.
3012 ///
3013 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
3014 template <typename InsertOpTy>
3015 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
3016   using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3017
3018   LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3019                                 PatternRewriter &rewriter) const override {
3020     if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
3021           return matchPattern(operand, matchConstantIndex());
3022         }))
3023       return failure();
3024
3025     auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
3026       auto castOp = v.getDefiningOp<tensor::CastOp>();
3027       if (!castOp || !canFoldIntoConsumerOp(castOp))
3028         return std::nullopt;
3029       return castOp.getSource();
3030     };
3031     std::optional<Value> sourceCastSource =
3032         getSourceOfCastOp(insertSliceOp.getSource());
3033     std::optional<Value> destCastSource =
3034         getSourceOfCastOp(insertSliceOp.getDest());
3035 if (!sourceCastSource && !destCastSource)
3036 return failure();
3037
3038 auto src =
3039 (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
3040 auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
3041     auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
3042     auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
3043 if (!srcType || !dstType)
3044 return failure();
3045
3046     // The tensor.cast source could have additional static information not seen
3047     // in the insert slice op static sizes, so we ignore dynamic dims when
3048     // computing the rank reduction mask.
3049     SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
3050     auto rankReductionMask = computeRankReductionMask(
3051         staticSizes, srcType.getShape(), /*matchDynamic=*/true);
3052     if (!rankReductionMask.has_value())
3053       return failure();
3054
3055     // If a size is static in the source type, propagate it into the mixed
3056     // sizes of the new op. The rank-reduction mask tells us which of the
3057     // insert_slice sizes correspond to dimensions that were dropped from the
3058     // source type.
3059     SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
3060     int64_t rankReducedIdx = 0;
3061     for (auto [idx, size] : enumerate(staticSizes)) {
3062       if (!rankReductionMask.value().contains(idx) &&
3063           !srcType.isDynamicDim(rankReducedIdx)) {
3064         mixedSizes[idx] = getAsIndexOpFoldResult(
3065             rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
3066         size = srcType.getDimSize(rankReducedIdx++);
3067       }
3068     }
3069
3070     // The pattern does not apply if the resulting op would not verify.
3071     if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
3072                             staticSizes, insertSliceOp.getStaticStrides()) !=
3073         SliceVerificationResult::Success)
3074       return failure();
3075     SliceBoundsVerificationResult sliceResult =
3076         verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
3077                             mixedSizes, insertSliceOp.getMixedStrides());
3078     if (!sliceResult.isValid)
3079       return failure();
3080
3081     Operation *replacement = rewriter.create<InsertOpTy>(
3082         insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
3083         mixedSizes, insertSliceOp.getMixedStrides());
3084
3085     // In the parallel case there is no result and so nothing to cast.
3086     bool isParallelInsert =
3087         std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
3088     if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
3089       replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
3090                                                     insertSliceOp.getDestType(),
3091                                                     replacement->getResult(0));
3092     }
3093     rewriter.replaceOp(insertSliceOp, replacement->getResults());
3094     return success();
3096 };
3097
3098 /// If additional static type information can be deduced from an
3099 /// insert_slice's size operands, insert an explicit cast of the op's source
3100 /// operand. This enables other canonicalization patterns that are matching
3101 /// for tensor_cast ops such as `ForOpTensorCastFolder` in SCF.
3102 ///
3103 /// Example:
3104 ///
3105 /// ```mlir
3106 ///   %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
3107 ///       : tensor<?x?xf32> into ...
3108 /// ```
3109 ///
3110 /// is rewritten into:
3111 ///
3112 /// ```mlir
3113 ///   %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
3114 ///   %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
3115 ///       : tensor<64x64xf32> into ...
3116 /// ```
3117 ///
3118 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
3119 template <typename InsertOpTy>
3120 struct InsertSliceOpSourceCastInserter final
3121     : public OpRewritePattern<InsertOpTy> {
3122   using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3123
3124   LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3125                                 PatternRewriter &rewriter) const override {
3126     RankedTensorType srcType = insertSliceOp.getSourceType();
3127     if (srcType.getRank() != insertSliceOp.getDestType().getRank())
3128       return failure();
3129     SmallVector<int64_t> newSrcShape(srcType.getShape());
3130     for (int64_t i = 0; i < srcType.getRank(); ++i) {
3131       if (std::optional<int64_t> constInt =
3132               getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
3133         // Bail on invalid IR.
3134         if (*constInt < 0)
3135           return failure();
3136         newSrcShape[i] = *constInt;
3137       }
3138     }
3139     if (!hasValidSizesOffsets(newSrcShape))
3140       return failure();
3141
3142     RankedTensorType newSrcType = RankedTensorType::get(
3143         newSrcShape, srcType.getElementType(), srcType.getEncoding());
3144     if (srcType == newSrcType ||
3145         !preservesStaticInformation(srcType, newSrcType) ||
3146         !tensor::CastOp::areCastCompatible(srcType, newSrcType))
3147       return failure();
3148
3149     // newSrcType is:
3150     //   1) Different from srcType.
3151     //   2) "More static" than srcType.
3152     //   3) Cast-compatible with srcType.
3153     // Insert the cast.
3154     OpBuilder::InsertionGuard g(rewriter);
3155     // The only difference between InsertSliceOp and ParallelInsertSliceOp is
3156     // that the insertion point is just before the ParallelCombiningOp in the
3157     // parallel case.
3158     if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
3159       rewriter.setInsertionPoint(insertSliceOp->getParentOp());
3160     Value cast = rewriter.create<tensor::CastOp>(
3161         insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
3162     rewriter.replaceOpWithNewOp<InsertOpTy>(
3163         insertSliceOp, cast, insertSliceOp.getDest(),
3164 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
3165 insertSliceOp.getMixedStrides());
3166 return success();
3167 }
3168 };
3169 } // namespace
3170
3171 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
3172   return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3173 }
3174
3175 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3176                                                 MLIRContext *context) {
3177   results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
3178               InsertSliceOpCastFolder<InsertSliceOp>,
3179               InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
3180 }
3181
3182 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
3183                                                              Location loc,
3184                                                              Value tensor,
3185                                                              Value dest) {
3186   auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
3187   unsigned rank = rankedTensorType.getRank();
3188   SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3189   SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
3190   SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3191   return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
3192                                                sizes, strides);
3193 }
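// Illustrative usage (hypothetical, not from the original source) of the
// canonical rank-reducing helpers: extracting a tensor<16xf32> row view from
// a higher-rank value and inserting it back. The helpers fill in zero
// offsets, unit strides, and sizes taken from the ranked tensor type.
//
//   Value row = tensor::createCanonicalRankReducingExtractSliceOp(
//       b, loc, paddedRow, RankedTensorType::get({16}, b.getF32Type()));
//   Value updated =
//       tensor::createCanonicalRankReducingInsertSliceOp(b, loc, row, dest);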
3194
3195 //===----------------------------------------------------------------------===//
3196 // PadOp
3197 //===----------------------------------------------------------------------===//
3198
3199 void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3200 setNameFn(getResult(), "padded");
3201 }
3202
3203 // TODO: Replace the custom<InferType> directive with AllTypesMatch as soon as
3204 // it supports optional types.
3205 void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
3206                     Type typeToInfer, Type typeToInferFrom) {}
3207
3208 ParseResult
3209 parseInferType(OpAsmParser &parser,
3210                std::optional<OpAsmParser::UnresolvedOperand> optOperand,
3211                Type &typeToInfer, Type typeToInferFrom) {
3212   if (optOperand)
3213     typeToInfer = typeToInferFrom;
3214   return success();
3215 }
3216
3217 LogicalResult PadOp::verify() {
3218   auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
3219   auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
3220 auto expectedType =
3221 PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
3222 if (!expectedType) {
3223 return emitError("failed to infer expectedType from sourceType ")
3224 << sourceType << ", specified resultType is " << resultType;
3225 }
3226 if (resultType.getRank() != expectedType.getRank()) {
3227 return emitError("specified type ")
3228 << resultType << " does not match the inferred type "
3229 << expectedType;
3230 }
3231 for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
3232 if (resultType.getDimSize(i) == expectedType.getDimSize(i))
3233 continue;
3234 if (expectedType.isDynamicDim(i))
3235 continue;
3236 return emitError("specified type ")
3237 << resultType << " does not match the inferred type "
3238 << expectedType;
3239 }
3240
3241 return success();
3242 }
3243
3244 LogicalResult PadOp::verifyRegions() {
3245   auto &region = getRegion();
3246   unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
3247   Block &block = region.front();
3248   if (block.getNumArguments() != rank)
3249     return emitError("expected the block to have ") << rank << " arguments";
3250
3251   // Note: the number and type of yield values are checked in the YieldOp.
3252   for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
3253     if (!en.value().isIndex())
3254       return emitOpError("expected block argument ")
3255              << (en.index() + 1) << " to be an index";
3256   }
3257
3258   // Ensure that the region yields an element of the right type.
3259   auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
3260   if (yieldOp.getValue().getType() !=
3261       llvm::cast<ShapedType>(getType()).getElementType())
3262     return emitOpError("expected yield type to match shape element type");
3263
3264   return success();
3265 }
3266
3267 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3268                                         ArrayRef<int64_t> staticLow,
3269                                         ArrayRef<int64_t> staticHigh,
3270                                         ArrayRef<int64_t> resultShape) {
3271   unsigned rank = sourceType.getRank();
3272 if (staticLow.size() != rank)
3273 return RankedTensorType();
3274 if (staticHigh.size() != rank)
3275 return RankedTensorType();
3276 if (!resultShape.empty() && resultShape.size() != rank)
3277 return RankedTensorType();
3278
3279   SmallVector<int64_t, 4> inferredShape;
3280   for (auto i : llvm::seq<unsigned>(0, rank)) {
3281 if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
3282 staticHigh[i] == ShapedType::kDynamic) {
3283 inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
3284 : resultShape[i]);
3285 } else {
3286 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
3287 assert((resultShape.empty() || size == resultShape[i] ||
3288 resultShape[i] == ShapedType::kDynamic) &&
3289 "mismatch between inferred shape and result shape");
3290 inferredShape.push_back(size);
3291 }
3292 }
3293
3294   return RankedTensorType::get(inferredShape, sourceType.getElementType());
3295 }
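// Worked example (hypothetical values): for a tensor<4x?xf32> source with
// staticLow = [1, 2] and staticHigh = [2, kDynamic], the inferred shape is
//   dim 0: 4 + 1 + 2 = 7      (all terms static)
//   dim 1: dynamic            (source dim and high pad are unknown)
// yielding tensor<7x?xf32> when resultShape is empty.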
3296
3297 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3298                   Value source, ArrayRef<int64_t> staticLow,
3299                   ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
3300                   bool nofold, ArrayRef<NamedAttribute> attrs) {
3301   auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3302   if (!resultType)
3303     resultType = inferResultType(sourceType, staticLow, staticHigh);
3304   result.addAttributes(attrs);
3305   build(b, result, resultType, source, low, high,
3306         b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3307         b.getBoolAttr(nofold));
3308 }
3309
3310 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3311                   Value source, ValueRange low, ValueRange high, bool nofold,
3312                   ArrayRef<NamedAttribute> attrs) {
3313   auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3314   unsigned rank = sourceType.getRank();
3315   SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3316   build(b, result, resultType, source, staticVector, staticVector, low, high,
3317         nofold, attrs);
3318 }
3319
3320 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3321                   Value source, ArrayRef<OpFoldResult> low,
3322                   ArrayRef<OpFoldResult> high, bool nofold,
3323                   ArrayRef<NamedAttribute> attrs) {
3324   auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3325   SmallVector<Value, 4> dynamicLow, dynamicHigh;
3326   SmallVector<int64_t, 4> staticLow, staticHigh;
3327   // staticLow and staticHigh have full information of the padding config.
3328   // This will grow staticLow and staticHigh with 1 value. If the config is
3329   // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
3330   // value as well.
3331   dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
3332   dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
3333   if (!resultType) {
3334     resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3335   }
3336   assert(llvm::isa<RankedTensorType>(resultType));
3337   result.addAttributes(attrs);
3338   build(b, result, resultType, source, dynamicLow, dynamicHigh,
3339         b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3340         b.getBoolAttr(nofold));
3341 }
3342
3343 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3344                   Value source, ArrayRef<OpFoldResult> low,
3345                   ArrayRef<OpFoldResult> high, Value constantPadValue,
3346                   bool nofold, ArrayRef<NamedAttribute> attrs) {
3347   build(b, result, resultType, source, low, high, nofold, attrs);
3348
3349   // Add a region and a block to yield the pad value.
3350   Region *region = result.regions[0].get();
3351   int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
3352   SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
3353   SmallVector<Location> blockArgLocs(sourceRank, result.location);
3354
3355   // `builder.createBlock` changes the insertion point within the block. Create
3356   // a guard to reset the insertion point of the builder after it is destroyed.
3357   OpBuilder::InsertionGuard guard(b);
3358   b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
3359   b.create<tensor::YieldOp>(result.location, constantPadValue);
3360 }
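// Illustrative builder usage (hypothetical): pad a tensor<4x8xf32> value by
// one row on each side with a constant zero, yielding tensor<6x8xf32>.
// Passing a null resultType lets the builder infer it.
//
//   Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0f));
//   SmallVector<OpFoldResult> low(2, b.getIndexAttr(0));
//   SmallVector<OpFoldResult> high(2, b.getIndexAttr(0));
//   low[0] = b.getIndexAttr(1);
//   high[0] = b.getIndexAttr(1);
//   auto padded = b.create<tensor::PadOp>(loc, /*resultType=*/Type(), src,
//                                         low, high, zero, /*nofold=*/false);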
3361
3362 llvm::SmallBitVector PadOp::getPaddedDims() {
3363 llvm::SmallBitVector paddedDims(getSourceType().getRank());
3364   auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3365     for (const auto &en : enumerate(paddingWidths))
3366       if (!isZeroInteger(en.value()))
3367         paddedDims.set(en.index());
3368   };
3369 extractPaddedDims(getMixedLowPad());
3370 extractPaddedDims(getMixedHighPad());
3371 return paddedDims;
3372 }
3373
3374 namespace {
3375 // Folds tensor.pad when padding is static zeros and the attribute
3376 // doesn't request otherwise.
3377 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
3378   using OpRewritePattern<PadOp>::OpRewritePattern;
3379
3380   LogicalResult matchAndRewrite(PadOp padTensorOp,
3381                                 PatternRewriter &rewriter) const override {
3382     if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3383       return failure();
3384     if (padTensorOp.getNofold())
3385       return failure();
3386     rewriter.replaceOpWithNewOp<tensor::CastOp>(
3387         padTensorOp, padTensorOp.getResult().getType(),
3388         padTensorOp.getSource());
3389     return success();
3390   }
3391 };
3392
3393 // Fold CastOp into PadOp when adding static information.
3394 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
3395   using OpRewritePattern<PadOp>::OpRewritePattern;
3396
3397   LogicalResult matchAndRewrite(PadOp padTensorOp,
3398                                 PatternRewriter &rewriter) const override {
3399     auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3400     if (!tensor::canFoldIntoConsumerOp(castOp))
3401       return failure();
3402
3403 auto newResultType = PadOp::inferResultType(
3404         llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3405 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3406 padTensorOp.getResultType().getShape());
3407
3408 if (newResultType == padTensorOp.getResultType()) {
3409       rewriter.modifyOpInPlace(padTensorOp, [&]() {
3410         padTensorOp.getSourceMutable().assign(castOp.getSource());
3411 });
3412 } else {
3413       auto newOp = rewriter.create<PadOp>(
3414 padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
3415 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3416           padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
3417           getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3418       IRMapping mapper;
3419       padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3420
3421       rewriter.replaceOpWithNewOp<tensor::CastOp>(
3422           padTensorOp, padTensorOp.getResultType(), newOp);
3423 }
3424 return success();
3425 }
3426 };
3427
3428 // Fold CastOp using the result of PadOp back into the latter if it adds
3429 // static information.
3430 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
3431   using OpRewritePattern<PadOp>::OpRewritePattern;
3432
3433   LogicalResult matchAndRewrite(PadOp padTensorOp,
3434                                 PatternRewriter &rewriter) const override {
3435 if (!padTensorOp.getResult().hasOneUse())
3436 return failure();
3437 auto tensorCastOp =
3438         dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3439 if (!tensorCastOp)
3440 return failure();
3441     if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
3442                                             tensorCastOp.getDest().getType()))
3443 return failure();
3444
3445     auto replacementOp = rewriter.create<PadOp>(
3446 padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3447 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3448 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3449 padTensorOp.getHigh(), padTensorOp.getNofold(),
3450         getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3451     replacementOp.getRegion().takeBody(padTensorOp.getRegion());
3452
3453 rewriter.replaceOp(padTensorOp, replacementOp.getResult());
3454 rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
3455 return success();
3456 }
3457 };
3458
3459 /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3460 /// different dimensions. The pattern applies if the following preconditions
3461 /// hold:
3462 ///   1) the tensor::ExtractSliceOps are not rank-reducing,
3463 ///   2) the tensor::ExtractSliceOps have only unit-strides,
3464 ///   3) the tensor::PadOps perform only high-padding,
3465 ///   4) the tensor::PadOps have the same constant padding value,
3466 ///   5) the tensor::PadOps do not have common padding dimensions,
3467 ///   6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3468 ///      zero-offset for every dimension,
3469 ///   7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
3470 ///      the padded source dimensions.
3471 ///
3472 /// Example:
3473 ///
3474 /// ```mlir
3475 ///   %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3476 ///        : tensor<64x64xf32> to tensor<?x64xf32>
3477 ///   %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3478 ///     } : tensor<?x64xf32> to tensor<8x64xf32>
3479 ///   %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3480 ///        : tensor<8x64xf32> to tensor<8x?xf32>
3481 ///   %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3482 ///     } : tensor<8x?xf32> to tensor<8x4xf32>
3483 /// ```
3484 ///
3485 /// folds into:
3486 ///
3487 /// ```mlir
3488 ///   %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3489 ///        : tensor<64x64xf32> to tensor<?x?xf32>
3490 ///   %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3491 ///     } : tensor<?x?xf32> to tensor<8x4xf32>
3492 /// ```
3493 ///
3494 struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
3495   using OpRewritePattern<PadOp>::OpRewritePattern;
3496
3497   LogicalResult matchAndRewrite(PadOp padOp,
3498                                 PatternRewriter &rewriter) const override {
3499     auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3500 if (!innerSliceOp)
3501 return failure();
3502     auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3503 if (!outerPadOp || outerPadOp.getNofold())
3504 return failure();
3505     auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3506 if (!outerSliceOp)
3507 return failure();
3508
3509
3510 int64_t rank = padOp.getSourceType().getRank();
3511 if (outerSliceOp.getSourceType().getRank() != rank) {
3513 "cannot fold rank-reducing chain");
3514 }
3515
3516
3517 if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3519 padOp, "cannot fold non-unit stride ExtractSliceOps");
3520 }
3521
3522
3523 if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3525 "cannot fold PadOps with low padding");
3526 }
3527
3528     // Check if the padding values are equal.
3529     Attribute innerAttr, outerAttr;
3530     Value innerValue = padOp.getConstantPaddingValue();
3531     Value outerValue = outerPadOp.getConstantPaddingValue();
3532     if (!innerValue || !outerValue ||
3533         !matchPattern(innerValue, m_Constant(&innerAttr)) ||
3534         !matchPattern(outerValue, m_Constant(&outerAttr)) ||
3535         innerAttr != outerAttr) {
3536       return rewriter.notifyMatchFailure(
3537           padOp, "cannot fold PadOps with different padding values");
3538 }
3539
3540
3541 llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3542 llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3543 if (innerDims.anyCommon(outerDims)) {
3545 padOp, "cannot fold PadOps with common padding dimensions");
3546 }
3547
3548     // Combine the offsets of the two tensor::ExtractSliceOps. Find the
3549     // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3550     // for every dimension, and use the offset of the other pair. Fail if no
3551     // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3552     // exists.
3553     SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
3554     for (auto en : enumerate(newOffsets)) {
3555 OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3556 OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3557 if (!innerDims.test(en.index()) &&
3558           isZeroInteger(innerOffset)) {
3559         en.value() = outerOffset;
3560 continue;
3561 }
3562 if (!outerDims.test(en.index()) &&
3563           isZeroInteger(outerOffset)) {
3564         en.value() = innerOffset;
3565 continue;
3566 }
3568 padOp, "cannot find zero-offset and zero-padding pair");
3569 }
3570
3571     // Combine the sizes of the two tensor::ExtractSliceOps. Take the size
3572     // of the outer tensor::ExtractSliceOp for the padded dimensions and the
3573     // size of the inner tensor::ExtractSliceOp for all other dimensions. Fail
3574     // if the inner size does not match the size of the padded dimension of
3575     // the outer tensor::ExtractSliceOp source.
3576     SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3577     for (auto en : enumerate(newSizes)) {
3578 if (!outerDims.test(en.index()))
3579 continue;
3580 OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3581 int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3582 assert(!ShapedType::isDynamic(sourceSize) &&
3583 "expected padded dimension to have a static size");
3586 padOp, "cannot fold since the inner ExtractSliceOp size does not "
3587 "match the size of the outer padding");
3588 }
3589 en.value() = outerSliceOp.getMixedSizes()[en.index()];
3590 }
3591
3592     // Combine the high paddings of the two tensor::PadOps.
3593     SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
3594     for (auto en : enumerate(newHighPad)) {
3595 if (innerDims.test(en.index()))
3596 newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3597 if (outerDims.test(en.index()))
3598 newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3599 }
3600
3601
3602
3603 auto newSliceOp = rewriter.create(
3604 padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
3605 innerSliceOp.getMixedStrides());
3606     auto newPadOp = rewriter.create<PadOp>(
3607 padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3608 padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3609         getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
3610     rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3611 newPadOp.getRegion().begin());
3612 rewriter.replaceOp(padOp, newPadOp.getResult());
3613 return success();
3614 }
3615 };
3616
3617 /// Fold tensor.pad when its low/high padding operands are constants, making
3618 /// the padding amounts static.
3619 struct FoldStaticPadding : public OpRewritePattern<PadOp> {
3620   using OpRewritePattern<PadOp>::OpRewritePattern;
3621
3622   LogicalResult matchAndRewrite(PadOp padTensorOp,
3623                                 PatternRewriter &rewriter) const override {
3622 Value input = padTensorOp.getSource();
3625     if (!llvm::isa<RankedTensorType>(input.getType()))
3626       return failure();
3627     auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
3626 auto inputRank = inputDims.size();
3627
3628 auto oldResultType =
3631         dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3630 if (!oldResultType)
3631 return failure();
3632
3633 auto outputDims = oldResultType.getShape();
3634
3635     // Extract the static info from the high and low operands.
3636     SmallVector<int64_t> constOperandsLow;
3637     SmallVector<Value> newLows;
3638     for (auto operand : padTensorOp.getLow()) {
3639       APSInt intOp;
3640       if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3641         constOperandsLow.push_back(ShapedType::kDynamic);
3642 newLows.push_back(operand);
3643 continue;
3644 }
3645 constOperandsLow.push_back(intOp.getExtValue());
3646 }
3647     SmallVector<int64_t> constOperandsHigh;
3648     SmallVector<Value> newHighs;
3649     for (auto operand : padTensorOp.getHigh()) {
3650       APSInt intOp;
3651       if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3652         constOperandsHigh.push_back(ShapedType::kDynamic);
3653 newHighs.push_back(operand);
3654 continue;
3655 }
3656 constOperandsHigh.push_back(intOp.getExtValue());
3657 }
3658
3659     SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3660     SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3661
3662
3663 if (inputDims.size() != outputDims.size() ||
3664 inputDims.size() != constLow.size() ||
3665 inputDims.size() != constHigh.size())
3666 return failure();
3667
3668 auto lowCount = 0;
3669 auto highCount = 0;
3670 for (size_t i = 0; i < inputRank; i++) {
3671 if (constLow[i] == ShapedType::kDynamic)
3672 constLow[i] = constOperandsLow[lowCount++];
3673 if (constHigh[i] == ShapedType::kDynamic)
3674 constHigh[i] = constOperandsHigh[highCount++];
3675 }
3676
3677     auto staticLow = ArrayRef<int64_t>(constLow);
3678     auto staticHigh = ArrayRef<int64_t>(constHigh);
3679
3680     // Calculate the output sizes with the static information.
3681     SmallVector<int64_t> newOutDims;
3682 for (size_t i = 0; i < inputRank; i++) {
3683 if (outputDims[i] == ShapedType::kDynamic) {
3684 newOutDims.push_back(
3685 (staticLow[i] == ShapedType::kDynamic ||
3686 staticHigh[i] == ShapedType::kDynamic ||
3687 inputDims[i] == ShapedType::kDynamic
3688 ? ShapedType::kDynamic
3689 : inputDims[i] + staticLow[i] + staticHigh[i]));
3690 } else {
3691 newOutDims.push_back(outputDims[i]);
3692 }
3693 }
3694
3696 llvm::all_of(newOutDims,
3697 [&](int64_t x) { return x == ShapedType::kDynamic; }))
3698 return failure();
3699
3700     // Rewrite the op using the new static type.
3701     auto newResultType = RankedTensorType::get(
3702         newOutDims, padTensorOp.getType().getElementType());
3703     auto newOp = rewriter.create<PadOp>(
3704         padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
3705         newLows, newHighs, padTensorOp.getNofold(),
3706         getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3707
3708     IRMapping mapper;
3709     padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3710     rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3711 newOp);
3712
3713 return success();
3714 }
3715 };
3716
3717 /// Folds a chain of `tensor.pad` ops with the same constant padding value.
3718 ///
3719 /// Example:
3720 ///
3721 /// ```mlir
3722 ///   %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3723 ///       tensor.yield %val
3724 ///     } : tensor<1x2xf32> to tensor<1x5xf32>
3725 ///   %res = tensor.pad %1 low[0, 2] high[3, 0] {
3726 ///       tensor.yield %val
3727 ///     } : tensor<1x5xf32> to tensor<4x7xf32>
3728 /// ```
3729 ///
3730 /// folds into:
3731 ///
3732 /// ```mlir
3733 ///   %res = tensor.pad %0 low[0, 3] high[3, 2] {
3734 ///       tensor.yield %val
3735 ///     } : tensor<1x2xf32> to tensor<4x7xf32>
3736 /// ```
3737 struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
3738   using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
3739
3740   LogicalResult matchAndRewrite(tensor::PadOp padOp,
3741                                 PatternRewriter &rewriter) const override {
3742 if (padOp.getNofold()) {
3743 return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
3744 }
3745
3746     auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3747 if (!producerPad || producerPad.getNofold()) {
3749 padOp, "producer is not a foldable tensor.pad op");
3750 }
3751
3752
3753 Value consumerPadValue = padOp.getConstantPaddingValue();
3754 Value producerPadValue = producerPad.getConstantPaddingValue();
3755 if (!consumerPadValue || !producerPadValue ||
3756 consumerPadValue != producerPadValue) {
3757       return rewriter.notifyMatchFailure(
3758           padOp,
3759 "cannot fold PadOps with different or non-constant padding values");
3760 }
3761
3762 Location loc = padOp.getLoc();
3763     AffineExpr d0, d1;
3764     bindDims(rewriter.getContext(), d0, d1);
3765
3766     // Combine the low/high paddings of the two tensor.pad ops.
3767     auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
3768                            ArrayRef<OpFoldResult> producerPaddings) {
3769       SmallVector<OpFoldResult> sumPaddings;
3770       for (auto [consumerIndex, producerIndex] :
3771            llvm::zip_equal(consumerPaddings, producerPaddings)) {
3772         sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
3773             rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3774 }
3775 return sumPaddings;
3776 };
3777
3778     SmallVector<OpFoldResult> newHighPad =
3779         addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3780     SmallVector<OpFoldResult> newLowPad =
3781         addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3782
3783     auto newPadOp = rewriter.create<tensor::PadOp>(
3784 padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
3785 newLowPad, newHighPad, padOp.getNofold(),
3786         getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
3787     rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3788 newPadOp.getRegion().begin());
3789 rewriter.replaceOp(padOp, newPadOp.getResult());
3790 return success();
3791 }
3792 };
3793
3794 } // namespace
3795
3796 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3797                                         MLIRContext *context) {
3798   results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3799               FoldOrthogonalPaddings, FoldStaticPadding,
3800               FoldConsecutiveConstantPadding>(context);
3801 }
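// Illustrative effect of FoldStaticPadding (hypothetical IR): with
// %c2 = arith.constant 2 : index,
//
//   %r = tensor.pad %t low[%c2, 0] high[0, 0] { ... }
//       : tensor<4x8xf32> to tensor<?x8xf32>
//
// becomes a statically shaped pad plus a cast back to the old result type:
//
//   %p = tensor.pad %t low[2, 0] high[0, 0] { ... }
//       : tensor<4x8xf32> to tensor<6x8xf32>
//   %r = tensor.cast %p : tensor<6x8xf32> to tensor<?x8xf32>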
3802
3803 /// Returns the constant padding value of the PadOp if any, and an empty
3804 /// Value otherwise. The padding value is considered "constant" if it is
3805 /// either:
3806 ///   1) produced by a constant-like op, or
3807 ///   2) defined outside of the PadOp's region (i.e. it is invariant to the
3808 ///      pad and behaves like a constant from the op's perspective).
3809 ///
3810 /// The single-block region of a PadOp is terminated by a tensor.yield whose
3811 /// operand is the padding value; this helper inspects that terminator.
3812 Value PadOp::getConstantPaddingValue() {
3813   auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3814 if (!yieldOp)
3815 return {};
3816 Value padValue = yieldOp.getValue();
3817   // Check if yield value is a constant.
3818   if (matchPattern(padValue, m_Constant()))
3819     return padValue;
3820   // Check if yield value is defined inside the PadOp block.
3821 if (padValue.getParentBlock() == &getRegion().front())
3822 return {};
3823
3824 return padValue;
3825 }
3826
3827 OpFoldResult PadOp::fold(FoldAdaptor) {
3828   if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3829 !getNofold())
3830 return getSource();
3831 return {};
3832 }
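// Illustrative fold (hypothetical IR): a pad with all-zero low/high padding
// and matching static types,
//
//   %r = tensor.pad %t low[0, 0] high[0, 0] { ... }
//       : tensor<4x8xf32> to tensor<4x8xf32>
//
// folds to %t, unless the op carries the `nofold` attribute.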
3833
3834 //===----------------------------------------------------------------------===//
3835 // ParallelInsertSliceOp
3836 //===----------------------------------------------------------------------===//
3837
3838 OpResult ParallelInsertSliceOp::getTiedOpResult() {
3839 ParallelCombiningOpInterface parallelCombiningParent =
3840 getParallelCombiningParent();
3841 for (const auto &it :
3842 llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3843     Operation &nextOp = it.value();
3844     if (&nextOp == getOperation())
3845 return parallelCombiningParent.getParentResult(it.index());
3846 }
3847 llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3848 }
3849
3850 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3851 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3852                                   Value source, Value dest,
3853                                   ArrayRef<OpFoldResult> offsets,
3854                                   ArrayRef<OpFoldResult> sizes,
3855                                   ArrayRef<OpFoldResult> strides,
3856                                   ArrayRef<NamedAttribute> attrs) {
3857   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3858   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3859   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3860   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3861   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3862   result.addAttributes(attrs);
3863   build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3864         dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3865         b.getDenseI64ArrayAttr(staticSizes),
3866         b.getDenseI64ArrayAttr(staticStrides));
3867 }
3868
3869 /// Build a ParallelInsertSliceOp with mixed static and dynamic entries packed
3870 /// into a Range vector.
3871 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3872                                   Value source, Value dest,
3873                                   ArrayRef<Range> ranges,
3874                                   ArrayRef<NamedAttribute> attrs) {
3875   auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3876   build(b, result, source, dest, offsets, sizes, strides, attrs);
3877 }
3878
3879 // Build a ParallelInsertSliceOp with dynamic entries.
3880 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3881                                   Value source, Value dest, ValueRange offsets,
3882                                   ValueRange sizes, ValueRange strides,
3883                                   ArrayRef<NamedAttribute> attrs) {
3884   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3885       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3886   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3887       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3888   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3889       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3890   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3891 }
3892
3893 LogicalResult ParallelInsertSliceOp::verify() {
3894   if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
3895     return this->emitError("expected ParallelCombiningOpInterface parent, got:")
3896            << *(getOperation()->getParentOp());
3897
3898   // Verify result type against inferred type.
3899   RankedTensorType expectedType;
3900   SliceVerificationResult result =
3901       verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3902                           getStaticSizes(), getStaticStrides(), &expectedType);
3903   if (result != SliceVerificationResult::Success)
3904     return produceSliceErrorMsg(result, *this, expectedType);
3905
3906   // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3907   // to the destination tensor.
3908   SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
3909       getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
3910 getStaticStrides(), true);
3911 if (!boundsResult.isValid)
3912 return getOperation()->emitError(boundsResult.errorMessage);
3913
3914 return success();
3915 }
3916
3917 void ParallelInsertSliceOp::getCanonicalizationPatterns(
3918     RewritePatternSet &results, MLIRContext *context) {
3919   results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3920               InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3921               InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3922 }
3923
3924 llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3925   return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3926 }
3927
3928 //===----------------------------------------------------------------------===//
3929 // ScatterOp
3930 //===----------------------------------------------------------------------===//
3931
3932 void ScatterOp::getAsmResultNames(
3933     function_ref<void(Value, StringRef)> setNameFn) {
3934 setNameFn(getResult(), "scatter");
3935 }
3936
3937 LogicalResult ScatterOp::verify() {
3938   int64_t destRank = getDestType().getRank();
3939   ArrayRef<int64_t> scatterDims = getScatterDims();
3940   if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
3941                                        getIndicesType().getShape(), destRank,
3942 "scatter", "dest")))
3943 return failure();
3944
3945 if (!getUnique())
3946 return emitOpError("requires 'unique' attribute to be set");
3947
3948
3949
3950
3951
3952
3953   RankedTensorType expectedSourceType = GatherOp::inferResultType(
3954       getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
3955   RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3956       getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
3957 if (getSourceType() != expectedSourceType &&
3958 getSourceType() != expectedRankReducedSourceType) {
3959 return emitOpError("source type "
3960 "mismatch: "
3961 "expected ")
3962 << expectedSourceType << " or its rank-reduced variant "
3963 << expectedRankReducedSourceType << " (got: " << getSourceType()
3964 << ")";
3965 }
3966
3967 return success();
3968 }
3969
3970 //===----------------------------------------------------------------------===//
3971 // SplatOp
3972 //===----------------------------------------------------------------------===//
3973
3974 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3975                     Type aggregateType, ValueRange dynamicSizes) {
3976   build(builder, result, aggregateType, element, dynamicSizes);
3977 }
3978
3979 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3980                     ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
3981   auto aggregateType = RankedTensorType::get(staticShape, element.getType());
3982   build(builder, result, aggregateType, element, dynamicSizes);
3983 }
3984
3985 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3986                     ArrayRef<OpFoldResult> sizes) {
3987   SmallVector<int64_t> staticShape;
3988   SmallVector<Value> dynamicSizes;
3989   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
3990   build(builder, result, element, staticShape, dynamicSizes);
3991 }
3992
3993 void SplatOp::getAsmResultNames(
3994     function_ref<void(Value, StringRef)> setNameFn) {
3995 setNameFn(getResult(), "splat");
3996 }
3997
4000 return emitOpError("incorrect number of dynamic sizes, has ")
4002 << getType().getNumDynamicDims();
4003 return success();
4004 }
4005
4006 LogicalResult
4007 SplatOp::reifyResultShapes(OpBuilder &builder,
4008                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4009   reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
4010   unsigned ctr = 0;
4011   for (int64_t i = 0; i < getType().getRank(); ++i) {
4012     if (getType().isDynamicDim(i)) {
4013       reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
4014     } else {
4015 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
4016 }
4017 }
4018 return success();
4019 }
4020
4021 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
4022 auto constOperand = adaptor.getInput();
4023 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
4024 return {};
4025
4026   // Do not fold if the splat is not statically shaped.
4027   if (!getType().hasStaticShape())
4028     return {};
4029
4030   // SplatElementsAttr::get treats single value for second arg as being a
4031   // splat.
4032   return SplatElementsAttr::get(getType(), {constOperand});
4033 }
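// Illustrative fold (hypothetical IR):
//
//   %cst = arith.constant 1.0 : f32
//   %t = tensor.splat %cst : tensor<4x4xf32>
//
// folds to the splat constant dense<1.0> : tensor<4x4xf32>, provided the
// result type is statically shaped.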
4034
4035 //===----------------------------------------------------------------------===//
4036 // Common Canonicalizers and Folders.
4037 //===----------------------------------------------------------------------===//
4038 bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
4039   // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
4040   // 2. Exclude DPS ops that are also LoopLike from this interface as they
4041   //    might need special handling of attached regions.
4042   if (isa<InsertSliceOp>(op.getOperation()) ||
4043       isa<LoopLikeOpInterface>(op.getOperation()))
4044     return false;
4045
4046   return hasFoldableTensorCastOperand(op);
4047 }
4048
4049 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4050 /// the tensor.cast has a source that is more static than the consuming op.
4051 ///
4052 /// Example:
4053 /// ```mlir
4054 ///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4055 ///   %2 = consumer %1 ... : tensor<?x?xf32> ...
4056 /// ```
4057 ///
4058 /// is rewritten to:
4059 /// ```mlir
4060 ///   %2 = consumer %0 ... : tensor<8x16xf32> ...
4061 /// ```
4062 ///
4063 ///
4064 struct FoldTensorCastProducerOp
4065     : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
4066   using OpInterfaceRewritePattern<
4067       DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4068
4069
4070   LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4071                                 PatternRewriter &rewriter) const override {
4072     // Reject relayout ops (e.g. pack/unpack-style ops implementing
4073     // linalg::RelayoutOpInterface), which have their own cast-folding logic.
4074
4075     if (!foldTensorCastPrecondition(op) ||
4076         isa<linalg::RelayoutOpInterface>(*op))
4077       return failure();
4078
4079     SmallVector<Type> newResultTypes(op->getResultTypes());
4080     SmallVector<Value> newOperands =
4081         getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
4082
4083     // Clone the op with the updated result types and operands.
4084     auto newOp = clone(rewriter, op, newResultTypes, newOperands);
4085
4086     SmallVector<Value> replacements;
4087     replacements.reserve(newOp->getNumResults());
4088 for (auto [oldResult, newResult] :
4089 llvm::zip(op->getResults(), newOp->getResults())) {
4090 if (newResult.getType() != oldResult.getType()) {
4091         replacements.push_back(rewriter.create<tensor::CastOp>(
4092 op->getLoc(), oldResult.getType(), newResult));
4093 } else {
4094 replacements.push_back(newResult);
4095 }
4096 }
4097 rewriter.replaceOp(op, replacements);
4098
4099 return success();
4100 }
4101 };
4102
4103 //===----------------------------------------------------------------------===//
4104 // TensorDialect
4105 //===----------------------------------------------------------------------===//
4106
4107 void TensorDialect::getCanonicalizationPatterns(
4108     RewritePatternSet &results) const {
4109   results.add<FoldTensorCastProducerOp>(getContext());
4110 }
4111
4112 //===----------------------------------------------------------------------===//
4113 // TableGen'd op method definitions
4114 //===----------------------------------------------------------------------===//
4115
4116 #define GET_OP_CLASSES
4117 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"