//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h"
#include <optional>

using namespace mlir;
using namespace mlir::tensor;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *TensorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;
  if (complex::ConstantOp::isBuildableWith(value, type))
    return complex::ConstantOp::create(builder, loc, type,
                                       llvm::cast<ArrayAttr>(value));
  return nullptr;
}

OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
                                  int64_t dim) {
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  if (tensorType.isDynamicDim(dim))
    return builder.createOrFold<tensor::DimOp>(loc, value, dim);

  return builder.getIndexAttr(tensorType.getDimSize(dim));
}

SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
                                                Location loc, Value value) {
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  SmallVector<OpFoldResult> result;
  for (int64_t i = 0; i < tensorType.getRank(); ++i)
    result.push_back(getMixedSize(builder, loc, value, i));
  return result;
}

FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
                                                OpResult opResult) {
  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
  assert(tensorType && "expected tensor type");

  // If the op has a destination, it implements DestinationStyleOpInterface and
  // we can query the destination operand from that interface.
  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
  if (destOp)
    return destOp.getTiedOpOperand(opResult)->get();

  // Otherwise, create a new destination tensor with the same shape.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(opResult.getDefiningOp());

  // Compute sizes.
  SmallVector<OpFoldResult> mixedSizes;
  if (!tensorType.hasStaticShape()) {
    // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
    ReifiedRankedShapedTypeDims reifiedShapes;
    if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
      return failure();
    mixedSizes = reifiedShapes[opResult.getResultNumber()];
  } else {
    // Static shape: Take static sizes directly.
    for (int64_t sz : tensorType.getShape())
      mixedSizes.push_back(b.getIndexAttr(sz));
  }

  // Create empty tensor.
  Value emptyTensor =
      tensor::EmptyOp::create(b, loc, mixedSizes, tensorType.getElementType());
  return emptyTensor;
}

LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
                                              Operation *op,
                                              SmallVector<Value> &result) {
  for (OpResult opResult : op->getResults()) {
    if (llvm::isa<TensorType>(opResult.getType())) {
      FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
      if (failed(destination))
        return failure();
      result.push_back(*destination);
    }
  }
  return success();
}

bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
    if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
      return rtp1.getShape() == rtp2.getShape() &&
             rtp1.getElementType() == rtp2.getElementType();
    return false;
  }
  return tp1 == tp2;
}

/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op
/// or rank-extending tensor.insert_slice op.
static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
                                           ArrayRef<OpFoldResult> mixedSizes) {
  llvm::SmallBitVector droppedDims(mixedSizes.size());
  int64_t shapePos = reducedShape.size() - 1;

  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
    size_t idx = mixedSizes.size() - size.index() - 1;
    // Rank-reduced dims must have a static unit dimension.
    bool isStaticUnitSize =
        isa<Attribute>(size.value()) &&
        llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;

    if (shapePos < 0) {
      // There are no more dims in the reduced shape. All remaining sizes must
      // be rank-reduced dims.
      assert(isStaticUnitSize && "expected unit dim");
      droppedDims.set(idx);
      continue;
    }

    // Dim is preserved if the size is not a static 1.
    if (!isStaticUnitSize) {
      --shapePos;
      continue;
    }

    // Dim is preserved if the next dim in the reduced shape is also 1.
    if (reducedShape[shapePos] == 1) {
      --shapePos;
      continue;
    }

    // Otherwise: dim is dropped.
    droppedDims.set(idx);
  }

  assert(shapePos < 0 && "dimension mismatch");
  return droppedDims;
}
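
// Worked example (added for illustration; not from the upstream sources):
// with reducedShape = [6, 8] and mixedSizes = [1, 6, 1, 8], the sizes 8 and 6
// match reduced dims while the static unit sizes at indices 0 and 2 do not,
// so the returned bit vector has bits 0 and 2 set.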

/// Given a ranked tensor type and a range of values that defines its dynamic
/// dimension sizes, turn all dynamic sizes that have a constant value into
/// static dimension sizes.
static RankedTensorType
foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
                            SmallVector<Value> &foldedDynamicSizes) {
  SmallVector<int64_t> staticShape(type.getShape());
  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
         "incorrect number of dynamic sizes");

  // Compute new static and dynamic sizes.
  unsigned ctr = 0;
  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
    if (type.isDynamicDim(i)) {
      Value dynamicSize = dynamicSizes[ctr++];
      std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
      if (cst.has_value()) {
        // Dynamic size must be non-negative.
        if (cst.value() < 0) {
          foldedDynamicSizes.push_back(dynamicSize);
          continue;
        }
        staticShape[i] = *cst;
      } else {
        foldedDynamicSizes.push_back(dynamicSize);
      }
    }
  }

  return RankedTensorType::get(staticShape, type.getElementType(),
                               type.getEncoding());
}
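
// Example (illustrative, not from the upstream sources): for the type
// tensor<?x?xf32> with dynamicSizes = [%c4, %n], where %c4 is a constant 4,
// the folded type is tensor<4x?xf32> and foldedDynamicSizes = [%n].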

//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//

bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = dyn_cast<TensorType>(a);
  auto bT = dyn_cast<TensorType>(b);
  if (!aT || !bT)
    return false;

  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}
namespace {

/// Replaces chains of two tensor.bitcast operations by a single
/// tensor.bitcast operation.
struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
  using OpRewritePattern<BitcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
                                PatternRewriter &rewriter) const final {
    auto tensorBitcastOperand =
        tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
    if (!tensorBitcastOperand)
      return failure();

    auto resultType = cast<TensorType>(tensorBitcast.getType());
    rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
                                           tensorBitcastOperand.getOperand());
    return success();
  }
};

} // namespace

void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ChainedTensorBitcast>(context);
}
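
// For illustration (a sketch, not an upstream test case):
//   %a = tensor.bitcast %src : tensor<4xi32> to tensor<4xui32>
//   %b = tensor.bitcast %a : tensor<4xui32> to tensor<4xsi32>
// canonicalizes to a single bitcast from %src to tensor<4xsi32>.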

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "cast");
}

/// Returns true if `target` is a ranked tensor type that preserves static
/// information available in the `source` ranked tensor type.
bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
  auto targetType = llvm::dyn_cast<RankedTensorType>(target);

  // Requires RankedTensorType.
  if (!sourceType || !targetType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != targetType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != targetType.getRank())
    return false;

  // Requires same encoding.
  if (sourceType.getEncoding() != targetType.getEncoding())
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
    if (ShapedType::isStatic(std::get<0>(t)) &&
        ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

  return true;
}
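
// Example (illustrative): preservesStaticInformation(tensor<?x8xf32>,
// tensor<4x8xf32>) is true because every static extent of the source is kept
// in the target, while preservesStaticInformation(tensor<4x8xf32>,
// tensor<?x8xf32>) is false: the target drops the static size 4.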

/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of tensor.cast operations. Such foldable
/// tensor.cast operations are typically inserted as `slice` ops and are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when the conversion can fold the cast towards the consumer:
///
/// Ex:
/// ```
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```
///   %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  if (!castOp)
    return false;

  // Can fold if the source of cast has at least as much static information as
  // its results.
  return preservesStaticInformation(castOp.getType(),
                                    castOp.getSource().getType());
}

/// Determines whether the tensor::CastOp casts to a more static version of
/// the source tensor. This is useful to fold into a producing op and
/// implement canonicalization patterns with the `tensor.cast` op as the root,
/// but producer being from different dialects. Returns true when the
/// conversion can fold into the producer:
///
/// Ex:
/// ```
///   %2 = producer ... : tensor<?x?xf32>
///   %3 = tensor.cast %2 : tensor<?x?xf32> to tensor<8x16xf32>
/// ```
///
/// can be canonicalized to:
///
/// ```
///   %2 = producer ... : tensor<8x16xf32>
/// ```
bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
  if (!castOp)
    return false;
  return preservesStaticInformation(castOp.getSource().getType(),
                                    castOp.getType());
}

bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) {
  return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
    if (llvm::isa<BlockArgument>(opOperand.get()))
      return false;
    auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
    return castOp && canFoldIntoConsumerOp(castOp);
  });
}

SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding(
    DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
  SmallVector<Value> newOperands;
  newOperands.reserve(op->getNumOperands());

  assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");

  // Assumes that the result has dpsInits followed by nonDpsInits.
  int64_t dpsInitIdx = 0;
  for (OpOperand &opOperand : op->getOpOperands()) {
    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
    bool fold = canFoldIntoConsumerOp(tensorCastOp);
    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
    if (op.isDpsInit(&opOperand) &&
        !llvm::isa<MemRefType>(newOperands.back().getType()))
      newResTy[dpsInitIdx++] = newOperands.back().getType();
  }
  return newOperands;
}

/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<TensorType>(a);
  auto bT = llvm::dyn_cast<TensorType>(b);
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match.
static TensorType joinShapes(TensorType one, TensorType two) {
  assert(one.getElementType() == two.getElementType());

  if (!one.hasRank())
    return two;
  if (!two.hasRank())
    return one;

  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
}
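
// Example (illustrative): joining tensor<?x8xf32> with tensor<4x?xf32> yields
// tensor<4x8xf32>; joining tensor<4x8xf32> with tensor<5x8xf32> returns a
// null type because the static sizes conflict.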

namespace {

/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
    auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
    auto resultType = llvm::cast<TensorType>(tensorCast.getType());

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();

    // The newJoin always exists if the above join exists, it might just
    // contain less information. If so, we cannot drop the intermediate cast,
    // as doing so would remove runtime checks.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};

/// Fold tensor.cast into tensor.extract_slice producer.
/// Example:
/// ```
///  %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
///    tensor<128x512xf32> to tensor<?x512xf32>
///  %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
/// ```
/// ->
/// ```
/// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
///   tensor<128x512xf32> to tensor<16x512xf32>
/// ```
struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto extractOperand =
        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();

    // Cannot fold cast to unranked tensor.
    auto rankedResultType =
        llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
    if (!rankedResultType)
      return failure();

    if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
        rankedResultType.getShape() ==
            llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
                .getShape())
      return failure();

    SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
    auto dimMask = computeRankReductionMask(
        extractOperand.getStaticSizes(), extractOperand.getType().getShape());
    size_t dimIndex = 0;
    for (size_t i = 0, e = sizes.size(); i < e; i++) {
      if (dimMask && dimMask->count(i))
        continue;
      int64_t dim = rankedResultType.getShape()[dimIndex++];
      if (ShapedType::isDynamic(dim))
        continue;
      sizes[i] = rewriter.getIndexAttr(dim);
    }

    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        tensorCast, rankedResultType, extractOperand.getSource(),
        extractOperand.getMixedOffsets(), sizes,
        extractOperand.getMixedStrides());
    return success();
  }
};

} // namespace

void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
}

//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//

RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
  auto tensorTypes =
      llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>);
  int64_t concatRank = tensorTypes[0].getRank();

  // The concatenation dim must be in the range [0, rank).
  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");

  SmallVector<int64_t> sizes(concatRank);
  for (int64_t i = 0, e = concatRank; i < e; ++i) {
    if (i == dim)
      continue;
    SaturatedInteger size = SaturatedInteger::wrap(tensorTypes[0].getDimSize(i));
    for (auto tensorType : tensorTypes)
      size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
    sizes[i] = size.asInteger();
  }
  auto concatSize = SaturatedInteger::wrap(0);
  for (auto tensorType : tensorTypes)
    concatSize =
        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
  sizes[dim] = concatSize.asInteger();
  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
}

void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
                     ValueRange inputs) {
  FailureOr<RankedTensorType> resultType =
      inferResultType(dim, inputs.getTypes());
  assert(succeeded(resultType) && "failed to infer concatenation result type");
  build(builder, result, *resultType, dim, inputs);
}

LogicalResult ConcatOp::verify() {
  if (getInputs().size() < 1)
    return emitOpError("requires at least one input");

  SmallVector<RankedTensorType> inputTypes;
  for (auto input : getInputs())
    inputTypes.push_back(cast<RankedTensorType>(input.getType()));

  RankedTensorType resultType = getResultType();
  int64_t resultRank = getRank();
  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
        return type.getRank() != resultRank;
      }))
    return emitOpError("rank of concatenated inputs must match result rank");

  Type resultElementType = resultType.getElementType();
  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
        return type.getElementType() != resultElementType;
      }))
    return emitOpError("inputs and result element type must match");

  int64_t dim = getDim();
  if (dim >= resultRank)
    return emitOpError("concatenation dim must be less than the tensor rank");

  SmallVector<int64_t> sizes(resultRank);
  for (int64_t i = 0, e = resultRank; i < e; ++i) {
    if (i == dim)
      continue;
    SaturatedInteger size = SaturatedInteger::wrap(inputTypes[0].getDimSize(i));
    for (auto tensorType : inputTypes) {
      FailureOr<SaturatedInteger> maybeSize =
          size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
      if (failed(maybeSize))
        return emitOpError("static concatenation size mismatch along ")
               << "non-concatenated dimension " << i;
      size = *maybeSize;
    }
    sizes[i] = size.asInteger();
  }
  auto concatSize = SaturatedInteger::wrap(0);
  for (auto tensorType : inputTypes)
    concatSize =
        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
  sizes[dim] = concatSize.asInteger();
  auto inferredResultType =
      RankedTensorType::get(sizes, inputTypes[0].getElementType());

  for (auto [inferredSize, actualSize] :
       llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
    bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
                      ShapedType::isDynamic(actualSize);
    if (!hasDynamic && inferredSize != actualSize)
      return emitOpError("result type ")
             << resultType << " does not match inferred shape "
             << inferredResultType << " static sizes";
  }

  return success();
}

FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
  size_t numInputs = getInputs().size();
  uint64_t concatDim = getDim();

  SmallVector<SmallVector<OpFoldResult>> inputShapes;
  inputShapes.reserve(numInputs);
  SmallVector<OpFoldResult> concatOffsets;
  concatOffsets.reserve(numInputs);
  SmallVector<OpFoldResult> outputShape;

  AffineExpr addExpr =
      builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
  OpFoldResult zero = builder.getIndexAttr(0);
  Location loc = getLoc();
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    SmallVector<OpFoldResult> inputShape =
        tensor::getMixedSizes(builder, input.getLoc(), input);
    if (index == 0) {
      outputShape = inputShape;
      concatOffsets.push_back(zero);
    } else {
      concatOffsets.push_back(outputShape[concatDim]);
      outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
          builder, loc, addExpr,
          {outputShape[concatDim], inputShape[concatDim]});
    }
    inputShapes.emplace_back(std::move(inputShape));
  }

  Value replacement = tensor::EmptyOp::create(builder, loc, outputShape,
                                              getType().getElementType());

  int64_t rank = getType().getRank();
  OpFoldResult one = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> strides(rank, one);
  SmallVector<OpFoldResult> offsets(rank, zero);
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    offsets[concatDim] = concatOffsets[index];
    auto insertSlice = tensor::InsertSliceOp::create(
        builder, loc, input, replacement, offsets, inputShapes[index], strides);
    replacement = insertSlice.getResult();
  }
  if (replacement.getType() != getType()) {
    replacement = tensor::CastOp::create(builder, loc, getType(), replacement);
  }
  return SmallVector<Value>{replacement};
}
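
// Decomposition sketch (illustrative IR, assuming static shapes):
//   %r = tensor.concat dim(0) %a, %b : (tensor<2x4xf32>, tensor<3x4xf32>)
//        -> tensor<5x4xf32>
// becomes
//   %e  = tensor.empty() : tensor<5x4xf32>
//   %i0 = tensor.insert_slice %a into %e[0, 0] [2, 4] [1, 1]
//   %i1 = tensor.insert_slice %b into %i0[2, 0] [3, 4] [1, 1]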

LogicalResult
ConcatOp::reifyResultShapes(OpBuilder &builder,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  ValueRange inputs = getInputs();
  int64_t dim = getDim();
  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());

  Value init = inputs[0];
  int64_t rank = getType().getRank();

  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));

  // Pre-populate the result sizes with as much static information as possible
  // from the given result type, as well as the inferred result type, otherwise
  // use the dim sizes from the first input.
  for (int64_t i = 0; i < rank; ++i) {
    if (i == dim)
      continue;
    if (!getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    } else if (!inferredResultType.isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
          builder, getLoc(),
          builder.getIndexAttr(inferredResultType.getDimSize(i)));
    } else {
      reifiedReturnShapes[0][i] =
          tensor::DimOp::create(builder, init.getLoc(), init, i).getResult();
    }
  }

  if (getType().isDynamicDim(dim)) {
    // Take the sum of the input sizes along the concatenated dimension.
    AffineExpr sum = builder.getAffineDimExpr(0);
    SmallVector<OpFoldResult> sizes = {
        builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
    for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
      sum = sum + builder.getAffineDimExpr(idx + 1);
      sizes.push_back(
          builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
    }
    reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
        builder, getLoc(),
        affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
  } else {
    // If the result shape is static along the concatenated dimension, use the
    // static shape.
    reifiedReturnShapes[0][dim] =
        builder.getIndexAttr(getType().getDimSize(dim));
  }
  return success();
}

void ConcatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "concat");
}

OpFoldResult ConcatOp::fold(FoldAdaptor) {
  ValueRange inputs = getInputs();
  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
    return inputs[0];
  return {};
}

namespace {

/// Fold a concat with a single input to a cast.
struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    if (concatOp.getInputs().size() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
                                        concatOp.getInputs()[0]);
    return success();
  }
};

/// Propagate static shapes into the operands of a `tensor.concat`.
///
/// `tensor.concat` requires every operand to match on all dimensions except
/// the concatenation dimension. If one operand is already static in those
/// dimensions, the other operands may be refined with `tensor.cast` ops.
///
/// Example:
///
/// ```mlir
///   %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
///        tensor<?x12xi32>
/// ```
/// ->
/// ```mlir
///   %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
///   %2 = tensor.concat dim(0) %0, %cast :
///        (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
/// ```
struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

    // Find operands for which a more static shape can be inferred.
    LogicalResult matched = failure();
    // Inferred operand shapes are identical in every dimension except the
    // concatenation dimension.
    SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
    for (auto [operandIdx, operandType] :
         llvm::enumerate(concatOp->getOperandTypes())) {
      // Compute the inferred type for this operand.
      inferredOperandShape[dim] =
          cast<RankedTensorType>(operandType).getDimSize(dim);
      auto inferredOperandType = RankedTensorType::get(
          inferredOperandShape, inferredResultType.getElementType());

      // Check if the inferred operand type is more static.
      if (!preservesStaticInformation(inferredOperandType, operandType)) {
        matched = success();

        // Use refined operand type and create cast from original operand.
        auto castOp =
            CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
                           concatOp.getOperand(operandIdx));
        rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
          concatOp->setOperand(operandIdx, castOp->getResult(0));
        });
      }
    }

    return matched;
  }
};

/// While the result type is always inferred from the operand types during
/// construction, the operands may later be refined, leaving the result type
/// overly dynamic. Infer a more static result type and insert a cast back to
/// the original result type.
///
/// Example:
/// ```mlir
///   %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
///        tensor<?x?xi32>
/// ```
/// ->
/// ```mlir
///   %3 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
///        tensor<?x12xi32>
///   %2 = tensor.cast %3 : tensor<?x12xi32> to tensor<?x?xi32>
/// ```
struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

    // The result type should be at least as static as the operand types.
    if (preservesStaticInformation(inferredResultType,
                                   concatOp.getResultType())) {
      return failure();
    }

    auto newConcatOp =
        ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
                         concatOp->getOperands());
    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
                                        newConcatOp);

    return success();
  }
};
} // namespace

void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results
      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
          context);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "dim");
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
  build(builder, result, source, indexValue);
}

std::optional<int64_t> DimOp::getConstantIndex() {
  return getConstantIntValue(getIndex());
}

Speculation::Speculatability DimOp::getSpeculatability() {
  auto constantIndex = getConstantIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
  if (!rankedSourceType)
    return Speculation::NotSpeculatable;

  if (rankedSourceType.getRank() <= constantIndex)
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}

void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
                                          SetIntLatticeFn setResultRange) {
  setResultRange(getResult(),
                 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
}

OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types (UnrankedTensorType) is not supported.
  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
  if (!tensorType)
    return {};

  // Out of bound indices produce undefined behavior but are still valid IR.
  // Don't choke on them.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= tensorType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!tensorType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
  }

  Operation *definingOp = getSource().getDefiningOp();

  // Fold dim to the operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        llvm::cast<RankedTensorType>(fromElements.getResult().getType());
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

    // Find the operand of the fromElements that corresponds to this index.
    auto dynExtents = fromElements.getDynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (ShapedType::isDynamic(dim))
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    // Fold only for non-rank-reduced ops. For the rank-reduced version, rely
    // on `resolveRankedShapedTypeResultDims`.
    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(unsignedIndex)) {
      return {sliceOp.getDynamicSize(unsignedIndex)};
    }
  }

  // dim(cast) -> dim
  if (succeeded(foldTensorCast(*this)))
    return getResult();

  return {};
}

namespace {
/// Fold dim of a cast into the dim of the source of the tensor cast.
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
    return success();
  }
};

/// Fold dim of a destination passing style op into the dim of the
/// corresponding init.
struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto source = dimOp.getSource();
    auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
    if (!destOp)
      return failure();

    auto resultIndex = cast<OpResult>(source).getResultNumber();
    auto *initOperand = destOp.getDpsInitOperand(resultIndex);

    rewriter.modifyOpInPlace(
        dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
    return success();
  }
};

/// Fold dim of a tensor reshape operation to an extract into the reshape's
/// shape operand.
struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return failure();

    // Since tensors are immutable we don't need to worry about where to place
    // the extract call.
    rewriter.setInsertionPointAfter(dim);
    Location loc = dim.getLoc();
    Value extract =
        ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
    if (extract.getType() != dim.getType())
      extract =
          arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);
    rewriter.replaceOp(dim, extract);
    return success();
  }
};
} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
}
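
// For illustration (a sketch, not an upstream test): DimOfReshapeOp turns
//   %r = tensor.reshape %src(%shape) : (tensor<*xf32>, tensor<?xindex>)
//        -> tensor<*xf32>
//   %d = tensor.dim %r, %i : tensor<*xf32>
// into
//   %d = tensor.extract %shape[%i] : tensor<?xindex>
// (plus an arith.index_cast when the shape tensor's element type is not
// index).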

//===----------------------------------------------------------------------===//
// EmptyOp
//===----------------------------------------------------------------------===//

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<int64_t> staticShape, Type elementType,
                    Attribute encoding) {
  assert(none_of(staticShape, ShapedType::isDynamic) &&
         "expected only static sizes");
  build(builder, result, staticShape, elementType, ValueRange{}, encoding);
}

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<int64_t> staticShape, Type elementType,
                    ValueRange dynamicSizes, Attribute encoding) {
  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
  build(builder, result, tensorType, dynamicSizes);
}

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<OpFoldResult> sizes, Type elementType,
                    Attribute encoding) {
  SmallVector<int64_t> staticShape;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
}

LogicalResult EmptyOp::verify() {
  if (getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitOpError("incorrect number of dynamic sizes, has ")
           << getDynamicSizes().size() << ", expected "
           << getType().getNumDynamicDims();
  return success();
}

LogicalResult
EmptyOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  unsigned ctr = 0;
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
    } else {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    }
  }
  return success();
}

Value EmptyOp::getDynamicSize(unsigned idx) {
  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
  unsigned ctr = 0;
  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
    if (getType().isDynamicDim(i))
      ++ctr;
  return getDynamicSizes()[ctr];
}

SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
  SmallVector<OpFoldResult> result;
  unsigned ctr = 0;
  OpBuilder b(getContext());
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      result.push_back(getDynamicSizes()[ctr++]);
    } else {
      result.push_back(b.getIndexAttr(getType().getShape()[i]));
    }
  }
  return result;
}

namespace {
/// Change the type of the result of a `tensor.empty` by making the result
/// type statically sized along dimensions that in the original operation were
/// defined as dynamic, but the size was defined using a `constant` op. For
/// example:
///
///  %c5 = arith.constant 5: index
///  %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
///
///  to
///
///  %0 = tensor.empty(%arg0) : tensor<?x5xf32>
struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
  using OpRewritePattern<EmptyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(EmptyOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        op.getType(), op.getDynamicSizes(), foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedTensorType == op.getType())
      return failure();

    auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
                                 foldedDynamicSizes);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
    if (!emptyTensorOp || !maybeConstantIndex)
      return failure();
    auto emptyTensorType = emptyTensorOp.getType();
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= emptyTensorType.getRank() ||
        !emptyTensorType.isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(dimOp,
                       emptyTensorOp.getDynamicSize(*maybeConstantIndex));
    return success();
  }
};

/// Canonicalize
///
/// ```mlir
///   %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
/// ```
///
/// into
///
/// ```mlir
///   %0 = tensor.empty(%d1) : tensor<4x?xf32>
/// ```
///
/// This assumes the input program is correct in terms of its shape. So it is
/// safe to assume that `%d0` is in fact 4.
struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp castOp,
                                PatternRewriter &rewriter) const override {
    if (!canFoldIntoProducerOp(castOp))
      return failure();
    auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
    if (!producer)
      return failure();

    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    ArrayRef<int64_t> resultShape = resultType.getShape();
    SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
    SmallVector<OpFoldResult> newMixedSizes;
    newMixedSizes.reserve(currMixedSizes.size());
    assert(resultShape.size() == currMixedSizes.size() &&
           "mismatch in result shape and sizes of empty op");
    for (auto it : llvm::zip(resultShape, currMixedSizes)) {
      int64_t newDim = std::get<0>(it);
      OpFoldResult currDim = std::get<1>(it);
      // Case 1: The empty tensor dim is static. Check that the tensor cast
      // result dim matches.
      if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
        if (ShapedType::isDynamic(newDim) ||
            newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
          // Something is off, the cast result shape cannot be more dynamic
          // than the empty tensor result shape (enforced by
          // `canFoldIntoProducerOp`). Abort for now.
          return rewriter.notifyMatchFailure(
              producer, "mismatch in static value of shape of empty tensor "
                        "result and cast result");
        }
        newMixedSizes.push_back(attr);
        continue;
      }

      // Case 2: The tensor cast shape is static, but the empty tensor result
      // shape is dynamic.
      if (ShapedType::isStatic(newDim)) {
        newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
        continue;
      }

      // Case 3: The tensor cast shape is dynamic and the empty tensor result
      // shape is dynamic. Use the dynamic value from the empty op.
      newMixedSizes.push_back(currDim);
    }

    // TODO: Do not drop tensor encoding.
    rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
                                         resultType.getElementType());
    return success();
  }
};

} // namespace

void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
              ReplaceEmptyTensorStaticShapeDims>(context);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

namespace {

/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();
    if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
      return failure();
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.getSource(), extract.getIndices());
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
/// tensor<12xf64>
/// %extracted_element = tensor.extract %val[%c10] : tensor<12xf64>
///
/// to
///
/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const final {
    auto collapseOp =
        extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    if (!collapseOp.getSrcType().hasStaticShape())
      return failure();

    auto sourceSizes = collapseOp.getSrcType().getShape();

    SmallVector<Value> indices(extractOp.getIndices().begin(),
                               extractOp.getIndices().end());
    SmallVector<Value> sourceIndices;
    for (auto [index, group] :
         llvm::zip(indices, collapseOp.getReassociationIndices())) {
      assert(!group.empty() && "association indices groups cannot be empty");
      auto groupSize = group.size();

      if (groupSize == 1) {
        sourceIndices.push_back(index);
        continue;
      }

      SmallVector<int64_t> basis =
          llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
      auto delinearize = affine::AffineDelinearizeIndexOp::create(
          rewriter, extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
      llvm::append_range(sourceIndices, delinearize.getResults());
    }
    if (collapseOp.getReassociationIndices().empty()) {
      auto zeroAffineMap = rewriter.getConstantAffineMap(0);
      int64_t srcRank =
          cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(), zeroAffineMap,
          ArrayRef<OpFoldResult>{});
      for (int64_t i = 0; i < srcRank; i++) {
        sourceIndices.push_back(
            getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extractOp, collapseOp.getSrc(), sourceIndices);
    return success();
  }
};

} // namespace

void ExtractOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted");
}

LogicalResult ExtractOp::verify() {
  // Verify the # indices match if we have a ranked type.
  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices for extract_element");
  return success();
}

/// If we have an ExtractOp consuming an InsertOp with the same indices, we
/// can return the InsertOp's scalar directly.
// TODO: This only checks the immediate producer; extend to go up the
// insert/extract chain if the slices are disjoint.
static Value foldExtractAfterInsert(ExtractOp extractOp) {
  auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();

  auto isSame = [](Value a, Value b) {
    return getAsOpFoldResult(a) == getAsOpFoldResult(b);
  };
  if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
      llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
    return insertOp.getScalar();

  return {};
}

OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  if (Attribute tensor = adaptor.getTensor()) {
    // If this is a splat elements attribute, simply return the value.
    // All of the elements of a splat attribute are the same.
    if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
      return splatTensor.getSplatValue<Attribute>();

    // Unpacking a dense resource elements attribute is too expensive; bail.
    if (isa<DenseResourceElementsAttr>(tensor))
      return {};
  }

  // Collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indice : adaptor.getIndices()) {
    if (!indice || !llvm::isa<IntegerAttr>(indice))
      return {};
    indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
  }

  // Fold extract(from_elements(...)).
  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
    auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
    auto rank = tensorType.getRank();
    assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
           "rank mismatch");
    int flatIndex = 0;
    int stride = 1;
    for (int i = rank - 1; i >= 0; --i) {
      flatIndex += indices[i] * stride;
      stride *= tensorType.getDimSize(i);
    }
    // Prevent out of bounds accesses. This can happen in invalid code that
    // will never execute.
    if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
        flatIndex < 0)
      return {};
    return fromElementsOp.getElements()[flatIndex];
  }

  // If this is an elements attribute, query the value at the given indices.
  if (Attribute tensor = adaptor.getTensor()) {
    auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
    if (elementsAttr && elementsAttr.isValidIndex(indices))
      return elementsAttr.getValues<Attribute>()[indices];
  }

  if (Value result = foldExtractAfterInsert(*this))
    return result;

  return {};
}
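
// Example of the from_elements fold above (illustrative): extracting at
// indices [1, 0] from tensor.from_elements %a, %b, %c, %d : tensor<2x2xf32>
// computes the flattened row-major index 1 * 2 + 0 = 2 and folds to %c.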

void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractFromTensorCast>(context);
}

void mlir::tensor::populateFoldCollapseExtractPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ExtractFromCollapseShape>(patterns.getContext());
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

void FromElementsOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "from_elements");
}

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           ValueRange elements) {
  assert(!elements.empty() && "expected at least one element");
  Type resultType = RankedTensorType::get(
      {static_cast<int64_t>(elements.size())}, elements.front().getType());
  build(builder, result, resultType, elements);
}

OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
  if (!llvm::is_contained(adaptor.getElements(), nullptr))
    return DenseElementsAttr::get(getType(), adaptor.getElements());
  return {};
}
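
// Example (illustrative): when every element is a constant, e.g.
//   %t = tensor.from_elements %c1, %c2 : tensor<2xi32>
// the op folds to the dense constant [1, 2]; with any non-constant element
// the fold bails out and returns nothing.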

namespace {

// Pushes the index_casts that occur before extractions to after the extract.
// This minimizes type conversion in some cases and enables the extract
// canonicalizer. This changes:
//
// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
// %extract = tensor.extract %cast[%index] : tensor<1xindex>
//
// to the following:
//
// %extract = tensor.extract %tensor[%index] : tensor<1xi32>
// %cast = arith.index_cast %extract : i32 to index
//
// Consider expanding this to a template and handle all tensor cast
// operations.
struct ExtractElementFromIndexCast
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    Location loc = extract.getLoc();
    auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
    if (!indexCast)
      return failure();

    Type elementTy = getElementTypeOrSelf(indexCast.getIn());

    auto newExtract = tensor::ExtractOp::create(
        rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());

    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
                                                    newExtract);

    return success();
  }
};

} // namespace

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<ExtractElementFromIndexCast>(context);
}

//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}

/// Return the inferred result type for a gatherOp where:
///   - sourceType is the type of the source tensor gathered from
///   - indicesType is the type of the indices used to gather
///   - gatherDims are the dims along which the gather occurs.
/// Return a full rank or rank-reduced variant of the type depending on the
/// value of rankReduced.
///
/// The leading dimensions of the index tensor give the result tensor its
/// leading dimensions. The trailing dimensions of the result tensor are those
/// of the source tensor, with the gathered dims kept as static unit sizes
/// (full rank) or dropped (rank-reduced).
RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
                                           RankedTensorType indicesType,
                                           ArrayRef<int64_t> gatherDims,
                                           bool rankReduced) {
  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
  resultShape.reserve(resultShape.size() + sourceType.getRank());
  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
    if (llvm::binary_search(gatherDims, idx)) {
      if (!rankReduced)
        resultShape.push_back(1);
      continue;
    }
    resultShape.push_back(sourceType.getDimSize(idx));
  }
  return RankedTensorType::Builder(sourceType).setShape(resultShape);
}
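
// Example (illustrative): with source tensor<4x5x6xf32>, indices
// tensor<1x2x1xindex>, and gather_dims = [1], the inferred result type is
// tensor<1x2x4x1x6xf32>, or tensor<1x2x4x6xf32> in the rank-reduced variant
// (the gathered dimension collapses away).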

static LogicalResult
verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
                          ArrayRef<int64_t> indices, int64_t rank,
                          StringRef gatherOrScatter, StringRef sourceOrDest) {
  if (dims.empty())
    return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";

  int64_t numGatherDims = dims.size();
  if (numGatherDims > rank)
    return op->emitOpError(gatherOrScatter)
           << "_dims overflow " << sourceOrDest << " rank";
  if (indices.empty() || indices.back() != numGatherDims)
    return op->emitOpError(gatherOrScatter)
           << "_dims length must match the size of last dimension of indices";
  for (int64_t val : dims) {
    if (val < 0)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be non-negative";
    if (val >= rank)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be smaller than " << sourceOrDest << " rank";
  }
  for (int64_t i = 1; i < numGatherDims; ++i) {
    if (dims[i - 1] >= dims[i])
      return op->emitOpError(gatherOrScatter)
             << "_dims values must be strictly increasing";
  }
  return success();
}

LogicalResult GatherOp::verify() {
  int64_t sourceRank = getSourceType().getRank();
  ArrayRef<int64_t> gatherDims = getGatherDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
                                       getIndicesType().getShape(), sourceRank,
                                       "gather", "source")))
    return failure();

  RankedTensorType expectedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
  if (getResultType() != expectedResultType &&
      getResultType() != expectedRankReducedResultType) {
    return emitOpError("result type "
                       "mismatch: "
                       "expected ")
           << expectedResultType << " or its rank-reduced variant "
           << expectedRankReducedResultType << " (got: " << getResultType()
           << ")";
  }

  return success();
}

OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  return {};
}

//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//

void InsertOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted");
}

LogicalResult InsertOp::verify() {
  // Verify the # indices match if we have a ranked type.
  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices");
  return success();
}

OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
  Attribute scalar = adaptor.getScalar();
  Attribute dest = adaptor.getDest();
  if (scalar && dest)
    if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
      if (scalar == splatDest.getSplatValue<Attribute>())
        return dest;
  return {};
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//

void GenerateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "generated");
}

LogicalResult GenerateOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  int idx = 0;
  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
    if (getType().isDynamicDim(dim)) {
      reifiedReturnShapes[0][dim] = getOperand(idx++);
    } else {
      reifiedReturnShapes[0][dim] =
          builder.getIndexAttr(getType().getDimSize(dim));
    }
  }
  return success();
}

LogicalResult GenerateOp::verify() {
  // Ensure that the tensor type has as many dynamic dimensions as are
  // specified by the operands.
  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
  if (getNumOperands() != resultType.getNumDynamicDims())
    return emitError("must have as many index operands as dynamic extents "
                     "in the result type");
  return success();
}

LogicalResult GenerateOp::verifyRegions() {
  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
  // Ensure that region arguments span the index space.
  if (!llvm::all_of(getBody().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return emitError("all body arguments must be index");
  if (getBody().getNumArguments() != resultTy.getRank())
    return emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());

  if (yieldOp.getValue().getType() != resultTy.getElementType())
    return emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

  return success();
}

void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate the body.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  SmallVector<Location, 2> argumentLocs(rank, result.location);
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}

namespace {

/// Canonicalizes tensor.generate operations with a constant
/// operand into the equivalent operation with the operand expressed in the
/// result type, instead. We also insert a type cast to make sure that the
/// resulting IR is still well-typed.
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp generateOp,
                                PatternRewriter &rewriter) const final {
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        generateOp.getType(), generateOp.getDynamicExtents(),
        foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedTensorType == generateOp.getType())
      return failure();

    auto loc = generateOp.getLoc();
    auto newOp =
        GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);
    rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
                                newOp.getBody().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
                                                generateOp.getType(), newOp);
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
///   ^bb0(%arg0: index):
///   <computation>
///   yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
    if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
      return failure();

    IRMapping mapping;
    Block *body = &tensorFromElements.getBody().front();
    mapping.map(body->getArguments(), extract.getIndices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());

    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
    return success();
  }
};

} // namespace

void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // TODO: Move extract pattern to tensor::ExtractOp.
  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "rank");
}

OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
  // Constant fold rank when the rank of the operand is known.
  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
  return IntegerAttr();
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}

static int64_t getNumElements(ShapedType type) {
  int64_t numElements = 1;
  for (auto dim : type.getShape())
    numElements *= dim;
  return numElements;
}

LogicalResult ReshapeOp::verify() {
  TensorType operandType = llvm::cast<TensorType>(getSource().getType());
  TensorType resultType = llvm::cast<TensorType>(getResult().getType());

  if (operandType.getElementType() != resultType.getElementType())
    return emitOpError("element types of source and destination tensor "
                       "types should be the same");

  int64_t shapeSize =
      llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);

  if (resultRankedType) {
    if (operandRankedType && resultRankedType.hasStaticShape() &&
        operandRankedType.hasStaticShape()) {
      if (getNumElements(operandRankedType) !=
          getNumElements(resultRankedType))
        return emitOpError("source and destination tensor should have the "
                           "same number of elements");
    }
    if (ShapedType::isDynamic(shapeSize))
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked tensor type");
    if (shapeSize != resultRankedType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's tensor rank");
  }
  return success();
}

OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;

  // If the producer of operand 'source' is another 'tensor.reshape' op, use
  // the producer's input instead as the original tensor to reshape. This
  // could render such producer dead code.
  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
    getSourceMutable().assign(reshapeOpProducer.getSource());
    return getResult();
  }

  auto source = getSource();
  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
  auto resultTy = dyn_cast<RankedTensorType>(getType());
  if (!sourceTy || !resultTy || sourceTy != resultTy)
    return {};

  // If the source and result are both 1D tensors and have the same type, the
  // reshape has no effect, even if the tensor is dynamically shaped.
  if (sourceTy.getRank() <= 1)
    return source;

  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
    auto elements = fromElements.getElements();
    bool dynamicNoop =
        sourceTy.getRank() == static_cast<int64_t>(elements.size());
    for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
      auto element = elements[id];

      if (auto cst = getConstantIntValue(element)) {
        dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
        continue;
      }

      if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
        dynamicNoop &= dimOp.getSource() == source;

        auto cst = getConstantIntValue(dimOp.getIndex());
        dynamicNoop &=
            cst.has_value() && cst.value() == static_cast<int64_t>(id);
        continue;
      }

      dynamicNoop = false;
      break;
    }

    if (dynamicNoop)
      return source;
  }

  return {};
}
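
// Example of the dynamic no-op fold above (illustrative): reshaping a
// tensor<?x?xf32> %t with a shape tensor built as
//   %d0 = tensor.dim %t, %c0
//   %d1 = tensor.dim %t, %c1
//   %s  = tensor.from_elements %d0, %d1 : tensor<2xindex>
// reproduces %t's own shape, so the reshape folds to %t.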

//===----------------------------------------------------------------------===//
// Reassociative reshape ops
//===----------------------------------------------------------------------===//

void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapsed");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expanded");
}

int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
         "invalid resultDim");
  for (const auto &it : llvm::enumerate(getReassociationIndices()))
    if (llvm::is_contained(it.value(), resultDim))
      return it.index();
  llvm_unreachable("could not find reassociation group");
}

FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
                                RankedTensorType expandedType,
                                ArrayRef<ReassociationIndices> reassociation,
                                ArrayRef<OpFoldResult> inputShape) {
  std::optional<SmallVector<OpFoldResult>> outputShape =
      inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
                                  inputShape);
  if (!outputShape)
    return failure();
  return *outputShape;
}

SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
  return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) {
  auto [staticOutputShape, dynamicOutputShape] =
      decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
  build(builder, result, cast<RankedTensorType>(resultType), src,
        getReassociationIndicesAttribute(builder, reassociation),
        dynamicOutputShape, staticOutputShape);
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<OpFoldResult> inputShape =
      getMixedSizes(builder, result.location, src);
  auto tensorResultTy = cast<RankedTensorType>(resultType);
  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
      builder, result.location, tensorResultTy, reassociation, inputShape);
  SmallVector<OpFoldResult> outputShapeOrEmpty;
  if (succeeded(outputShape)) {
    outputShapeOrEmpty = *outputShape;
  }
  build(builder, result, tensorResultTy, src, reassociation,
        outputShapeOrEmpty);
}

SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

RankedTensorType CollapseShapeOp::inferCollapsedType(
    RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
  return inferCollapsedType(
      type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
                type.getContext(), reassociation)));
}

/// Compute the RankedTensorType obtained by applying `reassociation` to
/// `type`.
RankedTensorType
CollapseShapeOp::inferCollapsedType(RankedTensorType type,
                                    ArrayRef<AffineMap> reassociation) {
  auto shape = type.getShape();
  SmallVector<int64_t, 4> newShape;
  newShape.reserve(reassociation.size());

  // Use the fact that reassociation is valid to simplify the logic: only use
  // each map's rank.
  assert(isReassociationValid(reassociation) && "invalid reassociation");
  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    auto band = shape.slice(currentDim, dim);
    int64_t size = 1;
    if (llvm::is_contained(band, ShapedType::kDynamic))
      size = ShapedType::kDynamic;
    else
      for (unsigned d = 0; d < dim; ++d)
        size *= shape[currentDim + d];
    newShape.push_back(size);
    currentDim += dim;
  }

  return RankedTensorType::get(newShape, type.getElementType());
}
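
// Example (illustrative): collapsing tensor<2x3x?x5xf32> with reassociation
// [[0, 1], [2, 3]] yields tensor<6x?xf32>: the first group multiplies the
// static sizes 2 * 3; the second group is dynamic because it contains a
// dynamic extent.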

void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<NamedAttribute> attrs) {
  auto resultType = inferCollapsedType(
      llvm::cast<RankedTensorType>(src.getType()),
      getSymbolLessAffineMaps(
          convertReassociationIndicesToExprs(b.getContext(), reassociation)));
  result.addAttribute(getReassociationAttrStrName(),
                      getReassociationIndicesAttribute(b, reassociation));
  build(b, result, resultType, src, attrs);
}

template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                        TensorReshapeOp, ExpandShapeOp>::value>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
                                           RankedTensorType expandedType,
                                           RankedTensorType collapsedType) {
  if (failed(
          verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
    return failure();

  auto maps = op.getReassociationMaps();
  RankedTensorType expectedType =
      CollapseShapeOp::inferCollapsedType(expandedType, maps);
  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
    return op.emitOpError("expected collapsed type to be ")
           << expectedType << ", but got " << collapsedType;
  return success();
}

LogicalResult ExpandShapeOp::verify() {
  auto srcType = getSrcType();
  auto resultType = getResultType();

  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape dims to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()
           << " values";

  return verifyTensorReshapeOp(*this, resultType, srcType);
}

LogicalResult CollapseShapeOp::verify() {
  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
}

namespace {

/// Reshape of a splat constant can be replaced with a constant of the result
/// type.
template <typename TensorReshapeOp>
struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    DenseElementsAttr attr;
    if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
      return failure();
    if (!attr || !attr.isSplat())
      return failure();
    DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
        reshapeOp.getResultType(), attr.getRawData());
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
    return success();
  }
};

// Folds TensorReshapeOp(splat x : src_type) : res_type into splat x :
// res_type.
template <typename TensorReshapeOp>
class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
public:
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
    if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<tensor::SplatOp>(
        reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
    return success();
  }
};

/// Reshape of a FromElements can be replaced with a FromElements of the
/// result type.
template <typename TensorReshapeOp>
struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto fromElements =
        reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
    if (!fromElements)
      return failure();

    auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());

    if (!shapedTy.hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
                                                fromElements.getElements());
    return success();
  }
};

/// Fold CastOp into CollapseShapeOp when adding static information.
struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
    if (!tensor::canFoldIntoConsumerOp(castOp))
      return failure();

    RankedTensorType srcType =
        llvm::cast<RankedTensorType>(castOp.getSource().getType());
    RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
        srcType, collapseShapeOp.getReassociationMaps());

    if (newResultType == collapseShapeOp.getResultType()) {
      rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
        collapseShapeOp.getSrcMutable().assign(castOp.getSource());
      });
    } else {
      auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
                                           newResultType, castOp.getSource(),
                                           collapseShapeOp.getReassociation());
      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          collapseShapeOp, collapseShapeOp.getResultType(), newOp);
    }
    return success();
  }
};

/// Converts `tensor.expand_shape(tensor.cast)` into a more static
/// `tensor.expand_shape` when the cast only erases static shape information,
/// casting the result back to the original (more dynamic) result type.
struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
    SmallVector<ReassociationIndices, 4> reassoc =
        expandOp.getReassociationIndices();

    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
    SmallVector<Value> dynamicOutputShape;
    auto outputIt = expandOp.getOutputShape().begin();

    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
      for (uint64_t outDim : innerReassoc) {
        if (ShapedType::isStatic(newOutputShape[outDim]))
          continue;

        // If the cast's src type is dynamic, don't infer any of the
        // corresponding expanded dimensions. `tensor.expand_shape` requires
        // at least one of the expanded dimensions to be dynamic if the input
        // is dynamic.
        Value val = *outputIt;
        ++outputIt;
        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
          dynamicOutputShape.push_back(val);
          continue;
        }

        APInt cst;
        if (matchPattern(val, m_ConstantInt(&cst))) {
          newOutputShape[outDim] = cst.getSExtValue();
        } else {
          dynamicOutputShape.push_back(val);
        }
      }
    }

    // Couldn't match any values, nothing to change.
    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
      return failure();

    // Calculate the input shape from the output.
    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
    for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
      for (auto outDim : reassoc[inDim]) {
        auto ofr = newOutputShape[outDim];
        if (ShapedType::isDynamic(ofr)) {
          newInputShape[inDim] = ShapedType::kDynamic;
          break;
        }
        newInputShape[inDim] *= ofr;
      }
    }

    SmallVector<OpFoldResult> outputOfr =
        getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
    auto inputType = RankedTensorType::get(
        newInputShape, expandOp.getSrcType().getElementType());
    auto outputType = RankedTensorType::get(
        newOutputShape, expandOp.getSrcType().getElementType());
    auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,
                                    expandOp.getSrc());
    auto newExpand = ExpandShapeOp::create(
        rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
        expandOp.getReassociationIndices(), outputOfr);
    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
                                        newExpand.getResult());
    return success();
  }
};
} // namespace

void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
      ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
      FoldReshapeWithSplat<ExpandShapeOp>,
      FoldReshapeWithFromElements<ExpandShapeOp>>(context);
}

void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
                                tensor::DimOp, RankedTensorType>,
      FoldReshapeWithConstant<CollapseShapeOp>,
      FoldReshapeWithSplat<CollapseShapeOp>,
      FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
      context);
}

OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
}

//===----------------------------------------------------------------------===//
// ExtractSliceOp
//===----------------------------------------------------------------------===//

void ExtractSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted_slice");
}

/// An extract_slice result type can be inferred, when it is not
/// rank-reduced, from the source type and the static representation of
/// offsets, sizes and strides. Special sentinels encode the dynamic case.
RankedTensorType
ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
                                ArrayRef<int64_t> staticSizes) {
  // An extract_slice op may specify only a leading subset of offset/sizes/
  // strides in which case we complete with offset=0, sizes from memref type
  // and strides=1.
  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
                               sourceTensorType.getEncoding());
}

// Overload accepting mixed sizes: only the static representation matters for
// the inferred type.
RankedTensorType
ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
                                ArrayRef<OpFoldResult> sizes) {
  SmallVector<int64_t> staticSizes;
  std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);

  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
                               sourceTensorType.getEncoding());
}

/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
/// number of sizes), drop as many size 1 as needed.
/// Note that there could be multiple ways to compute this rank-reduced type:
///   e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensor.
/// To disambiguate, this function always drops the first 1 sizes occurrences.
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<int64_t> sizes) {
  // Type inferred in the absence of rank-reducing behavior.
  auto inferredType = llvm::cast<RankedTensorType>(
      inferResultType(sourceRankedTensorType, sizes));
  int rankDiff = inferredType.getRank() - desiredResultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    llvm::SmallBitVector dimsToProject =
        getPositionsOfShapeOne(rankDiff, shape);
    SmallVector<int64_t> projectedShape;
    // Best effort rank-reducing: drop 1s in order.
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.test(pos))
        projectedShape.push_back(shape[pos]);
    inferredType =
        RankedTensorType::get(projectedShape, inferredType.getElementType());
  }
  return inferredType;
}

RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<OpFoldResult> sizes) {
  SmallVector<int64_t> staticSizes;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      desiredResultRank, sourceRankedTensorType, staticSizes);
}
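
// Example (illustrative): inferring a rank-2 result from source
// tensor<8x16x4xf32> with sizes [1, 16, 4] drops the leading unit dim and
// produces tensor<16x4xf32>; for sizes [1, 16, 1] and desired rank 2 only
// the first unit size is dropped, giving tensor<16x1xf32>.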

// Build an ExtractSliceOp with mixed static and dynamic entries and custom
// result type. If the type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
  // Structuring implementation this way avoids duplication between builders.
  if (!resultType) {
    resultType = llvm::cast<RankedTensorType>(
        ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes));
  }
  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}

// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
// result type.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}

// Build an ExtractSliceOp with mixed static and dynamic entries packed into a
// Range vector.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ArrayRef<Range> ranges,
                           ArrayRef<NamedAttribute> attrs) {
  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}

// Build an ExtractSliceOp with dynamic entries and custom result type. If
// the type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
}

// Build an ExtractSliceOp with dynamic entries and inferred result type.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}
2427
2428 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2429 Operation *op,
2430 RankedTensorType expectedType) {
2431 switch (result) {
2432 case SliceVerificationResult::Success:
2433 return success();
2434 case SliceVerificationResult::RankTooLarge:
2435 return op->emitError("expected rank to be smaller or equal to ")
2436 << "the other rank. ";
2437 case SliceVerificationResult::SizeMismatch:
2438 return op->emitError("expected type to be ")
2439 << expectedType << " or a rank-reduced version. (size mismatch) ";
2440 case SliceVerificationResult::ElemTypeMismatch:
2441 return op->emitError("expected element type to be ")
2442 << expectedType.getElementType();
2443 default:
2444 llvm_unreachable("unexpected extract_slice op verification result");
2445 }
2446}
2447
2448 /// Build an ExtractSliceOp with zero offsets, unit strides, the given sizes,
2449 /// and a custom result type.
2450 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2451 RankedTensorType resultType, Value source,
2452 ArrayRef<OpFoldResult> sizes,
2453 ArrayRef<NamedAttribute> attrs) {
2454 Attribute zeroIdxAttr = b.getIndexAttr(0);
2455 Attribute oneIdxAttr = b.getIndexAttr(1);
2456 SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr);
2457 SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr);
2458 build(b, result, resultType, source, readOffsets, sizes, readStrides, attrs);
2459}
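// Illustrative usage (an assumption, not part of the original source): a
// caller that only knows the slice sizes can rely on the builder above to
// fill in zero offsets and unit strides, e.g.
//   SmallVector<OpFoldResult> sizes = {b.getIndexAttr(4), b.getIndexAttr(8)};
//   auto slice =
//       tensor::ExtractSliceOp::create(b, loc, resultType, source, sizes);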
2460
2461 /// Verifier for ExtractSliceOp.
2462 LogicalResult ExtractSliceOp::verify() {
2463 RankedTensorType sourceType = getSourceType();
2464
2465 // Verify result type against inferred type.
2466 RankedTensorType expectedType =
2467 ExtractSliceOp::inferResultType(sourceType, getMixedSizes());
2468 SliceVerificationResult result = isRankReducedType(expectedType, getType());
2469 if (result != SliceVerificationResult::Success)
2470 return produceSliceErrorMsg(result, *this, expectedType);
2471
2472 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2473 // to the source tensor.
2474 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2475 sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
2476 getStaticStrides(), /*generateErrorMessage=*/true);
2477 if (!boundsResult.isValid)
2478 return getOperation()->emitError(boundsResult.errorMessage);
2479
2480 return success();
2481 }
2482
2483 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2484 return ::getDroppedDims(getType().getShape(), getMixedSizes());
2485 }
2486
2487 FailureOr<Value>
2488 ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2489 ArrayRef<int64_t> desiredShape) {
2490 auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2491 assert(sourceTensorType && "not a ranked tensor type");
2492 auto sourceShape = sourceTensorType.getShape();
2493 if (sourceShape.equals(desiredShape))
2494 return value;
2495 auto maybeRankReductionMask =
2496 mlir::computeRankReductionMask(sourceShape, desiredShape);
2497 if (!maybeRankReductionMask)
2498 return failure();
2499 return createCanonicalRankReducingExtractSliceOp(
2500 b, loc, value,
2501 RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2502 }
2503
2504 LogicalResult ExtractSliceOp::reifyResultShapes(
2505 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2506 reifiedReturnShapes.resize(1);
2507 reifiedReturnShapes[0].reserve(getType().getRank());
2508 SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2509 llvm::SmallBitVector droppedDims = getDroppedDims();
2510 for (const auto &size : enumerate(mixedSizes)) {
2511 if (droppedDims.test(size.index()))
2512 continue;
2513 reifiedReturnShapes[0].push_back(size.value());
2514 }
2515 return success();
2516 }
2517
2518namespace {
2519 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2520 /// This essentially pushes the cast past its consuming slice when
2521 /// `canFoldIntoConsumerOp` is true.
2522 ///
2523 /// Example:
2524 /// ```
2525 ///   %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2526 ///   %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2527 ///   tensor<3x4xf32>
2528 /// ```
2529 /// is rewritten into:
2530 /// ```
2531 ///   %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
2532 ///   tensor<3x4xf32>
2533 /// ```
2534 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2535 public:
2536 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2537
2538 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2539 PatternRewriter &rewriter) const override {
2540 // Any constant operand, just return to let the constant folder kick in.
2541 if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2542 return matchPattern(operand, matchConstantIndex());
2543 }))
2544 return failure();
2545
2546 auto castOp = sliceOp.getSource().getDefiningOp<tensor::CastOp>();
2547 if (!castOp)
2548 return failure();
2549
2550 if (!canFoldIntoConsumerOp(castOp))
2551 return failure();
2552
2553 // Pattern does not apply if the produced op would not verify.
2554 SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
2555 cast<ShapedType>(castOp.getSource().getType()).getShape(),
2556 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2558 if (!sliceResult.isValid)
2559 return failure();
2560
2561 // Create folded extract.
2562 Location loc = sliceOp.getLoc();
2563 Value newResult = ExtractSliceOp::create(
2564 rewriter, loc, sliceOp.getType(), castOp.getSource(),
2565 sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),
2566 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2567 sliceOp.getStaticStrides());
2568 rewriter.replaceOp(sliceOp, newResult);
2569 return success();
2570 }
2571};
2572
2573 // Slice elements from `values` into `outValues`. `counts` represents the
2574 // numbers of elements to stride in the original values for each dimension.
2575 // The output values can be used to construct a DenseElementsAttr.
2576template <typename IterTy, typename ElemTy>
2577static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2578 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2579 ArrayRef<int64_t> strides,
2580 llvm::SmallVectorImpl<ElemTy> *outValues) {
2581 assert(offsets.size() == sizes.size());
2582 assert(offsets.size() == strides.size());
2583 if (offsets.empty())
2584 return;
2585
2586 int64_t offset = offsets.front();
2587 int64_t size = sizes.front();
2588 int64_t stride = strides.front();
2589 if (offsets.size() == 1) {
2590 for (int64_t i = 0; i < size; ++i, offset += stride)
2591 outValues->push_back(*(values + offset));
2592
2593 return;
2594 }
2595
2596 for (int64_t i = 0; i < size; ++i, offset += stride) {
2597 auto begin = values + offset * counts.front();
2598 sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2599 offsets.drop_front(), sizes.drop_front(),
2600 strides.drop_front(), outValues);
2601 }
2602}
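// Worked example (illustrative, not part of the original source): for a 2x3
// constant laid out row-major as [a, b, c, d, e, f], the caller computes
// counts = [3, 1]. With offsets = [0, 1], sizes = [2, 2], strides = [1, 1],
// the recursion visits the two rows (flat bases 0 and 3) and copies the
// elements at flat positions 1, 2 and 4, 5, producing [b, c, e, f].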
2603
2604 /// Fold arith.constant and tensor.extract_slice into arith.constant. The
2605 /// folded operation might introduce more constant data; users can control
2606 /// their heuristics via the control function.
2607 class ConstantOpExtractSliceFolder final
2608 : public OpRewritePattern<ExtractSliceOp> {
2609 public:
2610 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2611
2612 ConstantOpExtractSliceFolder(MLIRContext *context,
2613 ControlConstantExtractSliceFusionFn controlFn)
2614 : OpRewritePattern<ExtractSliceOp>(context),
2615 controlFn(std::move(controlFn)) {}
2616
2617 LogicalResult matchAndRewrite(ExtractSliceOp op,
2618 PatternRewriter &rewriter) const override {
2619 DenseElementsAttr attr;
2620 if (!matchPattern(op.getSource(), m_Constant(&attr)))
2621 return failure();
2622
2623 // A constant splat is handled by fold().
2624 if (attr.isSplat())
2625 return failure();
2626
2627 // Dynamic result shape is not supported.
2628 auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2629 auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2630 if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2631 return failure();
2632
2633 // Control the folding.
2634 if (!controlFn(op))
2635 return failure();
2636
2637 int64_t count = sourceType.getNumElements();
2638 if (count == 0)
2639 return failure();
2640
2641 // Check if there are any dynamic parts, which are not supported.
2642 auto offsets = op.getStaticOffsets();
2643 if (llvm::is_contained(offsets, ShapedType::kDynamic))
2644 return failure();
2645 auto sizes = op.getStaticSizes();
2646 if (llvm::is_contained(sizes, ShapedType::kDynamic))
2647 return failure();
2648 auto strides = op.getStaticStrides();
2649 if (llvm::is_contained(strides, ShapedType::kDynamic))
2650 return failure();
2651
2652 // Compute the flattened stride (elements per index step) of each dimension.
2653 SmallVector<int64_t> counts;
2654 ArrayRef<int64_t> shape = sourceType.getShape();
2655 counts.reserve(shape.size());
2656 for (int64_t v : shape) {
2657 count = count / v;
2658 counts.push_back(count);
2659 }
2660
2661 // New attribute constructed by the sliced values.
2662 DenseElementsAttr newAttr;
2663
2664 if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2665 SmallVector<APInt> outValues;
2666 outValues.reserve(sourceType.getNumElements());
2667 sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2668 elems.begin(), counts, offsets, sizes, strides, &outValues);
2669 newAttr = DenseElementsAttr::get(resultType, outValues);
2670 } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2671 SmallVector<APFloat> outValues;
2672 outValues.reserve(sourceType.getNumElements());
2673 sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2674 elems.begin(), counts, offsets, sizes, strides, &outValues);
2675 newAttr = DenseElementsAttr::get(resultType, outValues);
2676 }
2677
2678 if (newAttr) {
2679 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
2680 return success();
2681 }
2682
2683 return failure();
2684 }
2685
2686 private:
2687 /// This additionally controls whether the fold happens or not. Users can
2688 /// impose their heuristics in the function.
2689 ControlConstantExtractSliceFusionFn controlFn;
2690 };
2691
2692 } // namespace
2693
2694 void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2695 RewritePatternSet &patterns,
2696 const ControlConstantExtractSliceFusionFn &controlFn) {
2697 patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2698 }
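// Illustrative usage (an assumption, not part of the original source): a
// client can restrict the fold to small constants via the control function:
//   RewritePatternSet patterns(ctx);
//   tensor::populateFoldConstantExtractSlicePatterns(
//       patterns, [](tensor::ExtractSliceOp op) {
//         auto type = llvm::cast<ShapedType>(op.getSource().getType());
//         return type.hasStaticShape() && type.getNumElements() <= 64;
//       });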
2699
2700 /// Return the canonical type of the result of an extract_slice op.
2701 struct SliceReturnTypeCanonicalizer {
2702 RankedTensorType operator()(ExtractSliceOp op,
2703 ArrayRef<OpFoldResult> mixedOffsets,
2704 ArrayRef<OpFoldResult> mixedSizes,
2705 ArrayRef<OpFoldResult> mixedStrides) {
2706 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2707 op.getType().getRank(), op.getSourceType(), mixedSizes);
2708 }
2709 };
2710
2711 /// A canonicalizer wrapper to replace ExtractSliceOps.
2712 struct SliceCanonicalizer {
2713 void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2714 ExtractSliceOp newOp) {
2715 Value replacement = newOp.getResult();
2716 if (replacement.getType() != op.getType())
2717 replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),
2718 replacement);
2719 rewriter.replaceOp(op, replacement);
2720 }
2721 };
2722
2723void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2724 MLIRContext *context) {
2725 results.add<
2726 OpWithOffsetSizesAndStridesConstantArgumentFolder<
2727 ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2728 ExtractSliceOpCastFolder>(context);
2729}
2730
2731
2732 static LogicalResult
2733 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2734 ShapedType shapedType) {
2735 OpBuilder b(op.getContext());
2736 for (OpFoldResult ofr : op.getMixedOffsets())
2737 if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2738 return failure();
2739
2740 // Rank-reducing noops only need to inspect the leading dimensions.
2741 auto shape = shapedType.getShape();
2742 for (auto it : llvm::zip(op.getMixedSizes(), shape))
2743 if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2744 return failure();
2745 for (OpFoldResult ofr : op.getMixedStrides())
2746 if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2747 return failure();
2748 return success();
2749 }
2750
2751 /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2752 /// slice, we can return the InsertSliceOp's source directly.
2753 // TODO: This only checks the immediate producer; it could be extended to go
2754 // up the insert/extract chain if the slices are disjoint.
2755 static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2756 auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2757
2758 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2759 if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2760 insertOp.isSameAs(extractOp, isSame))
2761 return insertOp.getSource();
2762
2763 return {};
2764 }
2765
2766 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2767 if (OpFoldResult reshapedSource = reshapeConstantSource(
2768 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2769 getResult().getType()))
2770 return reshapedSource;
2771 if (getSourceType() == getType() &&
2772 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2773 return this->getSource();
2774 if (Value slice = foldExtractAfterInsertSlice(*this))
2775 return slice;
2776
2777 return OpFoldResult();
2778 }
2779
2780 Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2781 OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2782 auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2783 unsigned rank = rankedTensorType.getRank();
2784 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2785 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2786 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2787 return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2788 offsets, sizes, strides);
2789 }
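// Illustrative example (not part of the original source): with targetType
// tensor<6xf32> and a tensor<1x6xf32> operand, the helper above materializes
//   %slice = tensor.extract_slice %t[0, 0] [1, 6] [1, 1]
//       : tensor<1x6xf32> to tensor<6xf32>
// i.e. a zero-offset, unit-stride, rank-reducing slice of the whole tensor.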
2790
2791 //===----------------------------------------------------------------------===//
2792 // InsertSliceOp
2793 //===----------------------------------------------------------------------===//
2794
2795 void InsertSliceOp::getAsmResultNames(
2796 function_ref<void(Value, StringRef)> setNameFn) {
2797 setNameFn(getResult(), "inserted_slice");
2798 }
2799
2800 // Build a InsertSliceOp with mixed static and dynamic entries.
2801 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2802 Value dest, ArrayRef<OpFoldResult> offsets,
2803 ArrayRef<OpFoldResult> sizes,
2804 ArrayRef<OpFoldResult> strides,
2805 ArrayRef<NamedAttribute> attrs) {
2806 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2807 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2808 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2809 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2810 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2811 result.addAttributes(attrs);
2812 build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2813 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2814 b.getDenseI64ArrayAttr(staticSizes),
2815 b.getDenseI64ArrayAttr(staticStrides));
2816}
2817
2818 /// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2819 /// Range vector.
2820 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2821 Value dest, ArrayRef<Range> ranges,
2822 ArrayRef<NamedAttribute> attrs) {
2823 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2824 build(b, result, source, dest, offsets, sizes, strides, attrs);
2825}
2826
2827 // Build a InsertSliceOp with dynamic entries.
2828 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2829 Value dest, ValueRange offsets, ValueRange sizes,
2830 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2831 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2832 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2833 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2834 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2835 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2836 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2837 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2838}
2839
2840 /// Rank-reducing type verification for both InsertSliceOp and
2841 /// ParallelInsertSliceOp.
2842 static SliceVerificationResult verifyInsertSliceOp(
2843 RankedTensorType srcType, RankedTensorType dstType,
2844 ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2845 ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2846 // insert_slice is the inverse of extract_slice, use the same type
2847 // inference.
2848 RankedTensorType expected =
2849 ExtractSliceOp::inferResultType(dstType, staticSizes);
2850 if (expectedType)
2851 *expectedType = expected;
2852 return isRankReducedType(expected, srcType);
2853 }
2854
2855 /// Verifier for InsertSliceOp.
2856 LogicalResult InsertSliceOp::verify() {
2857 // Verify result type against inferred type.
2858 RankedTensorType expectedType;
2859 SliceVerificationResult result =
2860 verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2861 getStaticSizes(), getStaticStrides(), &expectedType);
2862 if (result != SliceVerificationResult::Success)
2863 return produceSliceErrorMsg(result, *this, expectedType);
2864
2865 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2866 // to the destination tensor.
2867 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2868 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
2869 getStaticStrides(), /*generateErrorMessage=*/true);
2870 if (!boundsResult.isValid)
2871 return getOperation()->emitError(boundsResult.errorMessage);
2872
2873 return success();
2874 }
2875
2876
2877 /// If we have two consecutive InsertSliceOp writing to the same slice, we
2878 /// can mutate the second InsertSliceOp's destination to the first one's.
2879 ///
2880 /// Example:
2881 ///
2882 /// ```mlir
2883 ///   %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2884 ///   %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2885 /// ```
2886 ///
2887 /// folds into:
2888 ///
2889 /// ```mlir
2890 ///   %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2891 /// ```
2892 ///
2893 static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2894 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2895
2896 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2897 if (!prevInsertOp ||
2898 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2899 !prevInsertOp.isSameAs(insertOp, isSame))
2900 return failure();
2901
2902 insertOp.getDestMutable().assign(prevInsertOp.getDest());
2903 return success();
2904 }
2905
2906 /// Folds round-trip extract/insert slice op pairs.
2907 /// Example:
2908 /// ```mlir
2909 /// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2910 /// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2911 /// ```
2912 /// can be folded into %val.
2913 static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2914 auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2915
2916 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2917 if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2918 !extractOp.isSameAs(insertOp, isSame))
2919 return nullptr;
2920
2921 return extractOp.getSource();
2922 }
2923
2924 OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2925 if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2926 getSourceType() == getType() &&
2927 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2928 return this->getSource();
2929 if (succeeded(foldInsertAfterInsertSlice(*this)))
2930 return getResult();
2931
2932 // A round-trip extract/insert slice pair folds to the destination.
2933 if (foldInsertAfterExtractSlice(*this))
2934 return getDest();
2935 return OpFoldResult();
2936}
2937
2938 LogicalResult InsertSliceOp::reifyResultShapes(
2939 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2940 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2941 reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2942 return success();
2943 }
2944
2945namespace {
2946
2947 /// Pattern to rewrite an insert_slice op with constant arguments. Also
2948 /// applies to ParallelInsertSliceOp via the InsertOpTy template parameter.
2949 template <typename InsertOpTy>
2950 class InsertSliceOpConstantArgumentFolder final
2951 : public OpRewritePattern<InsertOpTy> {
2952 public:
2953 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2954
2955 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2956 PatternRewriter &rewriter) const override {
2957 SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2958 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2959 SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2960
2961 // No constant operands were folded, just return;
2962 if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
2963 failed(foldDynamicOffsetSizeList(mixedSizes)) &&
2964 failed(foldDynamicStrideList(mixedStrides)))
2965 return failure();
2966
2967 // Pattern does not apply if the produced op would not verify.
2968 SliceBoundsVerificationResult sliceResult =
2969 verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
2970 mixedOffsets, mixedSizes, mixedStrides);
2971 if (!sliceResult.isValid)
2972 return failure();
2973
2974 // Create the new op in canonical form.
2975 auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2976 insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2977 mixedSizes);
2978 Value toInsert = insertSliceOp.getSource();
2979 if (sourceType != insertSliceOp.getSourceType()) {
2980 OpBuilder::InsertionGuard g(rewriter);
2981 // The only difference between InsertSliceOp and ParallelInsertSliceOp
2982 // is that the insertion point of the latter cannot be in the body of
2983 // its parallel combining parent, so set the insertion point before it.
2984 if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
2985 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2986 toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
2987 sourceType, toInsert);
2988 }
2989 rewriter.replaceOpWithNewOp<InsertOpTy>(
2990 insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2991 mixedSizes, mixedStrides);
2992 return success();
2993 }
2994};
2995
2996 /// Fold tensor_casts with insert_slice operations. If the source or
2997 /// destination tensor is a tensor_cast that removes static type information,
2998 /// the cast is folded into the insert_slice operation. E.g.:
2999 ///
3000 /// ```mlir
3001 ///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
3002 ///   %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
3003 /// ```
3004 ///
3005 /// folds into:
3006 ///
3007 /// ```mlir
3008 ///   %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
3009 /// ```
3010 ///
3011 /// Note: When folding a cast on the destination tensor, the result of the
3012 /// insert_slice operation is casted to ensure that the type of the result
3013 /// did not change.
3014 ///
3015 /// This pattern applies to both InsertSliceOp and ParallelInsertSliceOp.
3016 template <typename InsertOpTy>
3017 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
3018 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3019
3020 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3021 PatternRewriter &rewriter) const override {
3022 if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
3023 return matchPattern(operand, matchConstantIndex());
3024 }))
3025 return failure();
3026
3027 auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
3028 auto castOp = v.getDefiningOp<tensor::CastOp>();
3029 if (!castOp || !canFoldIntoConsumerOp(castOp))
3030 return std::nullopt;
3031 return castOp.getSource();
3032 };
3033 std::optional<Value> sourceCastSource =
3034 getSourceOfCastOp(insertSliceOp.getSource());
3035 std::optional<Value> destCastSource =
3036 getSourceOfCastOp(insertSliceOp.getDest());
3037 if (!sourceCastSource && !destCastSource)
3038 return failure();
3039
3040 auto src =
3041 (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
3042 auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
3043 auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
3044 auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
3045 if (!srcType || !dstType)
3046 return failure();
3047
3048
3049 // The tensor.cast source could have additional static information not seen
3050 // in the insert slice op static sizes, so use the source shape to refine.
3051 SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
3052 auto rankReductionMask = computeRankReductionMask(
3053 staticSizes, srcType.getShape(), /*matchDynamic=*/true);
3054 if (!rankReductionMask.has_value())
3055 return failure();
3056
3057 // Take the mixed sizes from the insert_slice op, but replace entries with
3058 // the static size of the cast source wherever the dimension is not
3059 // rank-reduced and the source dimension is static. This propagates the
3060 // extra static information into the new op.
3061 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
3062 int64_t rankReducedIdx = 0;
3063 for (auto [idx, size] : enumerate(staticSizes)) {
3064 if (!rankReductionMask.value().contains(idx) &&
3065 !srcType.isDynamicDim(rankReducedIdx)) {
3066 mixedSizes[idx] = getAsIndexOpFoldResult(
3067 rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
3068 size = srcType.getDimSize(rankReducedIdx++);
3069 }
3070 }
3071
3072 // Fail if the new insert_slice op would not verify.
3073 if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
3074 staticSizes, insertSliceOp.getStaticStrides()) !=
3075 SliceVerificationResult::Success)
3076 return failure();
3077 SliceBoundsVerificationResult sliceResult =
3078 verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
3079 mixedSizes, insertSliceOp.getMixedStrides());
3080 if (!sliceResult.isValid)
3081 return failure();
3082
3083 Operation *replacement =
3084 InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,
3085 insertSliceOp.getMixedOffsets(), mixedSizes,
3086 insertSliceOp.getMixedStrides());
3087
3088 // In the parallel case there is no result and so nothing to cast.
3089 bool isParallelInsert =
3090 std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
3091 if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
3092 replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3093 insertSliceOp.getDestType(),
3094 replacement->getResult(0));
3095 }
3096 rewriter.replaceOp(insertSliceOp, replacement->getResults());
3097 return success();
3098 }
3099};
3100
3101 /// If additional static type information can be deduced from a insert_slice's
3102 /// size operands, insert an explicit cast of the op's source operand. This
3103 /// enables other canonicalization patterns that are matching for tensor_cast
3104 /// ops such as `ForOpTensorCastFolder` in SCF.
3105 ///
3106 /// Example:
3107 ///
3108 /// ```mlir
3109 ///   %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
3110 ///       : tensor<?x?xf32> into ...
3111 /// ```
3112 ///
3113 /// folds into:
3114 ///
3115 /// ```mlir
3116 ///   %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
3117 ///   %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
3118 ///       : tensor<64x64xf32> into ...
3119 /// ```
3120 ///
3121 /// This pattern applies to both InsertSliceOp and ParallelInsertSliceOp.
3122 template <typename InsertOpTy>
3123 struct InsertSliceOpSourceCastInserter final
3124 : public OpRewritePattern<InsertOpTy> {
3125 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3126
3127 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3128 PatternRewriter &rewriter) const override {
3129 RankedTensorType srcType = insertSliceOp.getSourceType();
3130 if (srcType.getRank() != insertSliceOp.getDestType().getRank())
3131 return failure();
3132 SmallVector<int64_t> newSrcShape(srcType.getShape());
3133 for (int64_t i = 0; i < srcType.getRank(); ++i) {
3134 if (std::optional<int64_t> constInt =
3135 getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
3136 // Bail on invalid IR.
3137 if (*constInt < 0)
3138 return failure();
3139 newSrcShape[i] = *constInt;
3140 }
3141 }
3142 if (!hasValidSizesOffsets(newSrcShape))
3143 return failure();
3144
3145 RankedTensorType newSrcType = RankedTensorType::get(
3146 newSrcShape, srcType.getElementType(), srcType.getEncoding());
3147 if (srcType == newSrcType ||
3148 !preservesStaticInformation(srcType, newSrcType) ||
3149 !tensor::CastOp::areCastCompatible(srcType, newSrcType))
3150 return failure();
3151
3152 // newSrcType is:
3153 //   1) Different from srcType.
3154 //   2) "More static" than srcType.
3155 //   3) Cast-compatible with srcType.
3156 // Insert the cast.
3157 OpBuilder::InsertionGuard g(rewriter);
3158 // The only difference between InsertSliceOp and ParallelInsertSliceOp
3159 // is that the insertion point of the latter cannot be in the body of
3160 // its parallel combining parent, so set the insertion point before it.
3161 if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
3162 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
3163 Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3164 newSrcType, insertSliceOp.getSource());
3165 rewriter.replaceOpWithNewOp<InsertOpTy>(
3166 insertSliceOp, cast, insertSliceOp.getDest(),
3167 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
3168 insertSliceOp.getMixedStrides());
3169 return success();
3170 }
3171};
3172 } // namespace
3173
3174 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
3175 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3176 }
3177
3178void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3179 MLIRContext *context) {
3180 results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
3181 InsertSliceOpCastFolder<InsertSliceOp>,
3182 InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
3183}
3184
3185 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
3186 Location loc,
3187 Value tensor,
3188 Value dest) {
3189 auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
3190 unsigned rank = rankedTensorType.getRank();
3191 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3192 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
3193 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3194 return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
3195 sizes, strides);
3196 }
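// Illustrative example (not part of the original source): inserting a
// tensor<6xf32> into a tensor<1x6xf32> destination with this helper yields
//   %r = tensor.insert_slice %t into %d[0, 0] [1, 6] [1, 1]
//       : tensor<6xf32> into tensor<1x6xf32>
// the rank-extending counterpart of createCanonicalRankReducingExtractSliceOp.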
3197
3198 //===----------------------------------------------------------------------===//
3199 // PadOp
3200 //===----------------------------------------------------------------------===//
3201
3202void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3203 setNameFn(getResult(), "padded");
3204}
3205
3206LogicalResult PadOp::verify() {
3207 auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
3208 auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
3209 auto expectedType =
3210 PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
3211 if (!expectedType) {
3212 return emitError("failed to infer expectedType from sourceType ")
3213 << sourceType << ", specified resultType is " << resultType;
3214 }
3215 if (resultType.getRank() != expectedType.getRank()) {
3216 return emitError("specified type ")
3217 << resultType << " does not match the inferred type "
3218 << expectedType;
3219 }
3220 for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
3221 if (resultType.getDimSize(i) == expectedType.getDimSize(i))
3222 continue;
3223 if (expectedType.isDynamicDim(i))
3224 continue;
3225 return emitError("specified type ")
3226 << resultType << " does not match the inferred type "
3227 << expectedType;
3228 }
3229
3230 return success();
3231 }
3232
3233LogicalResult PadOp::verifyRegions() {
3234 auto &region = getRegion();
3235 unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
3236 Block &block = region.front();
3237 if (block.getNumArguments() != rank)
3238 return emitError("expected the block to have ") << rank << " arguments";
3239
3240 // Note: the number and type of yield values are checked in the YieldOp.
3241 for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
3242 if (!en.value().isIndex())
3243 return emitOpError("expected block argument ")
3244 << (en.index() + 1) << " to be an index";
3245 }
3246
3247 // Ensure that the region yields an element of the right type.
3248 auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
3249 if (yieldOp.getValue().getType() !=
3250 llvm::cast<ShapedType>(getType()).getElementType())
3251 return emitOpError("expected yield type to match shape element type");
3252
3253 return success();
3254 }
3255
3256RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3257 ArrayRef<int64_t> staticLow,
3258 ArrayRef<int64_t> staticHigh,
3259 ArrayRef<int64_t> resultShape) {
3260 unsigned rank = sourceType.getRank();
3261 if (staticLow.size() != rank)
3262 return RankedTensorType();
3263 if (staticHigh.size() != rank)
3264 return RankedTensorType();
3265 if (!resultShape.empty() && resultShape.size() != rank)
3266 return RankedTensorType();
3267
3268 SmallVector<int64_t, 4> inferredShape;
3269 for (auto i : llvm::seq<unsigned>(0, rank)) {
3270 if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
3271 staticHigh[i] == ShapedType::kDynamic) {
3272 inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
3273 : resultShape[i]);
3274 } else {
3275 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
3276 assert((resultShape.empty() || size == resultShape[i] ||
3277 resultShape[i] == ShapedType::kDynamic) &&
3278 "mismatch between inferred shape and result shape");
3279 inferredShape.push_back(size);
3280 }
3281 }
3282
3283 return RankedTensorType::get(inferredShape, sourceType.getElementType());
3284}
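// Worked example (illustrative, not part of the original source): for a
// source tensor<4x?x8xf32> with staticLow = [1, 2, 0] and staticHigh =
// [1, kDynamic, 2], dim 0 infers to 4 + 1 + 1 = 6, dim 1 stays dynamic
// because the source is dynamic, and dim 2 becomes dynamic because its high
// pad is dynamic, so the inferred result type is tensor<6x?x?xf32>.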
3285
3286 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3287 Value source, ArrayRef<int64_t> staticLow,
3288 ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
3289 bool nofold, ArrayRef<NamedAttribute> attrs) {
3290 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3291 if (!resultType)
3292 resultType = inferResultType(sourceType, staticLow, staticHigh);
3293 result.addAttributes(attrs);
3294 build(b, result, resultType, source, low, high,
3295 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3296 nofold ? b.getUnitAttr() : UnitAttr());
3297}
3298
3299 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3300 Value source, ValueRange low, ValueRange high, bool nofold,
3301 ArrayRef<NamedAttribute> attrs) {
3302 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3303 unsigned rank = sourceType.getRank();
3304 SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3305 build(b, result, resultType, source, staticVector, staticVector, low, high,
3306 nofold, attrs);
3307}
3308
3309 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3310 Value source, ArrayRef<OpFoldResult> low,
3311 ArrayRef<OpFoldResult> high, bool nofold,
3312 ArrayRef<NamedAttribute> attrs) {
3313 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3314 SmallVector<Value, 4> dynamicLow, dynamicHigh;
3315 SmallVector<int64_t, 4> staticLow, staticHigh;
3316
3317 // staticLow and staticHigh have full information of the padding config.
3318 // For each dimension they grow by one entry; if the entry is dynamic (i.e.
3319 // not a constant), dynamicLow and dynamicHigh grow by one value as well.
3320 dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
3321 dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
3322 if (!resultType) {
3323 resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3324 }
3325 assert(llvm::isa<RankedTensorType>(resultType));
3326 result.addAttributes(attrs);
3327 build(b, result, resultType, source, dynamicLow, dynamicHigh,
3328 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3329 nofold ? b.getUnitAttr() : UnitAttr());
3330}
3331
3332 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3333 Value source, ArrayRef<OpFoldResult> low,
3334 ArrayRef<OpFoldResult> high, Value constantPadValue,
3335 bool nofold, ArrayRef<NamedAttribute> attrs) {
3336 build(b, result, resultType, source, low, high, nofold, attrs);
3337
3338 // Add a region and a block to yield the pad value.
3339 Region *region = result.regions[0].get();
3340 int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
3341 SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
3342 SmallVector<Location> blockArgLocs(sourceRank, result.location);
3343
3344 // `b.createBlock` changes the insertion point within the block; create a
3345 // guard to reset the insertion point once the block is constructed.
3346 OpBuilder::InsertionGuard guard(b);
3347 b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
3348 tensor::YieldOp::create(b, result.location, constantPadValue);
3349}
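// Illustrative usage (an assumption, not part of the original source):
// padding a 2-D tensor with a constant zero on the right of dim 1:
//   Value zero = arith::ConstantOp::create(
//       b, loc, b.getZeroAttr(srcType.getElementType()));
//   auto padded = tensor::PadOp::create(
//       b, loc, /*resultType=*/Type(), source,
//       /*low=*/{b.getIndexAttr(0), b.getIndexAttr(0)},
//       /*high=*/{b.getIndexAttr(0), b.getIndexAttr(4)}, zero,
//       /*nofold=*/false);
// The builder above creates the region and yields `zero` automatically.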
3350
3351llvm::SmallBitVector PadOp::getPaddedDims() {
3352 llvm::SmallBitVector paddedDims(getSourceType().getRank());
3353 auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3354 for (const auto &en : enumerate(paddingWidths))
3355 if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
3356 paddedDims.set(en.index());
3357 };
3358 extractPaddedDims(getMixedLowPad());
3359 extractPaddedDims(getMixedHighPad());
3360 return paddedDims;
3361}
3362
3363namespace {
3364
3365 // Folds tensor.pad when padding is static zeros and nofold is not set.
3366 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
3367 using OpRewritePattern<PadOp>::OpRewritePattern;
3368
3369 LogicalResult matchAndRewrite(PadOp padTensorOp,
3370 PatternRewriter &rewriter) const override {
3371 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3372 return failure();
3373 if (padTensorOp.getNofold())
3374 return failure();
3375 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3376 padTensorOp, padTensorOp.getResult().getType(),
3377 padTensorOp.getSource());
3378 return success();
3379 }
3380};
3381
3382 // Fold CastOp into PadOp when adding static information.
3383 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
3384 using OpRewritePattern<PadOp>::OpRewritePattern;
3385
3386 LogicalResult matchAndRewrite(PadOp padTensorOp,
3387 PatternRewriter &rewriter) const override {
3388 auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3389 if (!castOp || !tensor::canFoldIntoConsumerOp(castOp))
3390 return failure();
3391
3392 auto newResultType = PadOp::inferResultType(
3393 llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3394 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3395 padTensorOp.getResultType().getShape());
3396
3397 if (newResultType == padTensorOp.getResultType()) {
3398 rewriter.modifyOpInPlace(padTensorOp, [&]() {
3399 padTensorOp.getSourceMutable().assign(castOp.getSource());
3400 });
3401 } else {
3402 auto newOp = PadOp::create(
3403 rewriter, padTensorOp->getLoc(), newResultType,
3404 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3405 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3406 padTensorOp.getHigh(), padTensorOp.getNofold(),
3407 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3408 IRMapping mapper;
3409 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3410
3411 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3412 padTensorOp, padTensorOp.getResultType(), newOp);
3413 }
3414 return success();
3415 }
3416};
3417
3418 // Fold CastOp using the result of PadOp back into the latter if it adds
3419 // static information.
3420 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
3421 using OpRewritePattern<PadOp>::OpRewritePattern;
3422
3423 LogicalResult matchAndRewrite(PadOp padTensorOp,
3424 PatternRewriter &rewriter) const override {
3425 if (!padTensorOp.getResult().hasOneUse())
3426 return failure();
3427 auto tensorCastOp =
3428 dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3429 if (!tensorCastOp)
3430 return failure();
3431 if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
3432 tensorCastOp.getDest().getType()))
3433 return failure();
3434
3435 auto replacementOp = PadOp::create(
3436 rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3437 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3438 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3439 padTensorOp.getHigh(), padTensorOp.getNofold(),
3440 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3441 replacementOp.getRegion().takeBody(padTensorOp.getRegion());
3442
3443 rewriter.replaceOp(padTensorOp, replacementOp.getResult());
3444 rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
3445 return success();
3446 }
3447};
3448
3449 /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3450 /// different dimensions. The pattern applies if the following preconditions
3451 /// hold:
3452 ///   1) the tensor::ExtractSliceOps are not rank-reducing,
3453 ///   2) the tensor::ExtractSliceOps have only unit-strides,
3454 ///   3) the tensor::PadOps perform only high-padding,
3455 ///   4) the tensor::PadOps have the same constant padding value,
3456 ///   5) the tensor::PadOps do not have common padding dimensions,
3457 ///   6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-offset and
3458 ///      zero-padding for every dimension,
3459 ///   7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
3460 ///      the padded source dimensions.
3461 ///
3462 /// Example (shapes are illustrative):
3463 ///
3464 /// ```mlir
3465 ///   %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3466 ///       : tensor<64x64xf32> to tensor<?x64xf32>
3467 ///   %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3468 ///     } : tensor<?x64xf32> to tensor<8x64xf32>
3469 ///   %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3470 ///        : tensor<8x64xf32> to tensor<8x?xf32>
3471 ///   %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3472 ///     } : tensor<8x?xf32> to tensor<8x4xf32>
3473 /// ```
3474 ///
3475 /// folds into:
3476 ///
3477 /// ```mlir
3478 ///   %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3479 ///        : tensor<64x64xf32> to tensor<?x?xf32>
3480 ///   %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3481 ///     } : tensor<?x?xf32> to tensor<8x4xf32>
3482 /// ```
3483 ///
3484 struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
3485 using OpRewritePattern<PadOp>::OpRewritePattern;
3486
3487 LogicalResult matchAndRewrite(PadOp padOp,
3488 PatternRewriter &rewriter) const override {
3489 auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3490 if (!innerSliceOp)
3491 return failure();
3492 auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3493 if (!outerPadOp || outerPadOp.getNofold())
3494 return failure();
3495 auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3496 if (!outerSliceOp)
3497 return failure();
3498
3499 // Fail if the chain is rank-reducing.
3500 int64_t rank = padOp.getSourceType().getRank();
3501 if (outerSliceOp.getSourceType().getRank() != rank) {
3502 return rewriter.notifyMatchFailure(padOp,
3503 "cannot fold rank-reducing chain");
3504 }
3505
3506 // Fail if the ExtractSliceOps have non-unit strides.
3507 if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3508 return rewriter.notifyMatchFailure(
3509 padOp, "cannot fold non-unit stride ExtractSliceOps");
3510 }
3511
3512 // Fail if the PadOps have non-zero low padding.
3513 if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3514 return rewriter.notifyMatchFailure(padOp,
3515 "cannot fold PadOps with low padding");
3516 }
3517
3518 // Fail if the PadOps have different padding values.
3519 Attribute innerAttr, outerAttr;
3520 Value innerValue = padOp.getConstantPaddingValue();
3521 Value outerValue = outerPadOp.getConstantPaddingValue();
3522 if (!innerValue || !outerValue ||
3523 !matchPattern(innerValue, m_Constant(&innerAttr)) ||
3524 !matchPattern(outerValue, m_Constant(&outerAttr)) ||
3525 innerAttr != outerAttr) {
3526 return rewriter.notifyMatchFailure(
3527 padOp, "cannot fold PadOps with different padding values");
3528 }
3529
3530 // Fail if a dimension is padded by both PadOps.
3531 llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3532 llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3533 if (innerDims.anyCommon(outerDims)) {
3534 return rewriter.notifyMatchFailure(
3535 padOp, "cannot fold PadOps with common padding dimensions");
3536 }
3537
3538
3539 // Compute the new offsets of the combined ExtractSliceOp: for every
3540 // dimension, take the offset of the ExtractSliceOp that does not pad that
3541 // dimension, provided the other one has a zero offset. Fail if neither
3542 // slice provides such a zero-offset, zero-padding pair.
3543 SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
3544 for (auto en : enumerate(newOffsets)) {
3545 OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3546 OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3547 if (!innerDims.test(en.index()) &&
3548 (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
3549 en.value() = outerOffset;
3550 continue;
3551 }
3552 if (!outerDims.test(en.index()) &&
3553 (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
3554 en.value() = innerOffset;
3555 continue;
3556 }
3557 return rewriter.notifyMatchFailure(
3558 padOp, "cannot find zero-offset and zero-padding pair");
3559 }
3560
3561
3562 // Compute the sizes of the combined ExtractSliceOp: take the sizes of the
3563 // inner ExtractSliceOp, except for the dimensions padded by the outer
3564 // PadOp, which take the sizes of the outer ExtractSliceOp. Fail if the
3565 // inner ExtractSliceOp does not slice the full padded source dimension.
3566 SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3567 for (auto en : enumerate(newSizes)) {
3568 if (!outerDims.test(en.index()))
3569 continue;
3570 OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3571 int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3572 assert(ShapedType::isStatic(sourceSize) &&
3573 "expected padded dimension to have a static size");
3574 if (getConstantIntValue(sliceSize) != sourceSize) {
3575 return rewriter.notifyMatchFailure(
3576 padOp, "cannot fold since the inner ExtractSliceOp size does not "
3577 "match the size of the outer padding");
3578 }
3579 en.value() = outerSliceOp.getMixedSizes()[en.index()];
3580 }
3581
3582 // Combine the high paddings of the two tensor::PadOps.
3583 SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
3584 for (auto en : enumerate(newHighPad)) {
3585 if (innerDims.test(en.index()))
3586 newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3587 if (outerDims.test(en.index()))
3588 newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3589 }
3590
3591 // Create a new ExtractSliceOp, PadOp pair that performs the two padding
3592 // operations in a single step.
3593 auto newSliceOp = ExtractSliceOp::create(
3594 rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,
3595 newSizes, innerSliceOp.getMixedStrides());
3596 auto newPadOp = PadOp::create(
3597 rewriter, padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3598 padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3599 getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
3600 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3601 newPadOp.getRegion().begin());
3602 rewriter.replaceOp(padOp, newPadOp.getResult());
3603 return success();
3604 }
3605};
3606
3607 struct FoldStaticPadding : public OpRewritePattern<PadOp> {
3608 using OpRewritePattern<PadOp>::OpRewritePattern;
3609
3610 LogicalResult matchAndRewrite(PadOp padTensorOp,
3611 PatternRewriter &rewriter) const override {
3612 Value input = padTensorOp.getSource();
3613 if (!llvm::isa<RankedTensorType>(input.getType()))
3614 return failure();
3615 auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
3616 auto inputRank = inputDims.size();
3617
3618 auto oldResultType =
3619 dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3620 if (!oldResultType)
3621 return failure();
3622
3623 auto outputDims = oldResultType.getShape();
3624
3625 // Extract the static info from the low and high operands.
3626 SmallVector<int64_t> constOperandsLow;
3627 SmallVector<Value> newLows;
3628 for (auto operand : padTensorOp.getLow()) {
3629 APSInt intOp;
3630 if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3631 constOperandsLow.push_back(ShapedType::kDynamic);
3632 newLows.push_back(operand);
3633 continue;
3634 }
3635 constOperandsLow.push_back(intOp.getExtValue());
3636 }
3637 SmallVector<int64_t> constOperandsHigh;
3638 SmallVector<Value> newHighs;
3639 for (auto operand : padTensorOp.getHigh()) {
3640 APSInt intOp;
3641 if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3642 constOperandsHigh.push_back(ShapedType::kDynamic);
3643 newHighs.push_back(operand);
3644 continue;
3645 }
3646 constOperandsHigh.push_back(intOp.getExtValue());
3647 }
3648
3649 SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3650 SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3651
3652 // Verify the op is well-formed.
3653 if (inputDims.size() != outputDims.size() ||
3654 inputDims.size() != constLow.size() ||
3655 inputDims.size() != constHigh.size())
3656 return failure();
3657
3658 auto lowCount = 0;
3659 auto highCount = 0;
3660 for (size_t i = 0; i < inputRank; i++) {
3661 if (constLow[i] == ShapedType::kDynamic)
3662 constLow[i] = constOperandsLow[lowCount++];
3663 if (constHigh[i] == ShapedType::kDynamic)
3664 constHigh[i] = constOperandsHigh[highCount++];
3665 }
3666
3667 auto staticLow = ArrayRef<int64_t>(constLow);
3668 auto staticHigh = ArrayRef<int64_t>(constHigh);
3669
3670 // Calculate the output sizes with the static information.
3671 SmallVector<int64_t> newOutDims;
3672 for (size_t i = 0; i < inputRank; i++) {
3673 if (outputDims[i] == ShapedType::kDynamic) {
3674 newOutDims.push_back(
3675 (staticLow[i] == ShapedType::kDynamic ||
3676 staticHigh[i] == ShapedType::kDynamic ||
3677 inputDims[i] == ShapedType::kDynamic
3678 ? ShapedType::kDynamic
3679 : inputDims[i] + staticLow[i] + staticHigh[i]));
3680 } else {
3681 newOutDims.push_back(outputDims[i]);
3682 }
3683 }
3684
3685 if (SmallVector<int64_t>(outputDims) == newOutDims ||
3686 llvm::all_of(newOutDims,
3687 [&](int64_t x) { return x == ShapedType::kDynamic; }))
3688 return failure();
3689
3690 // Rewrite the op using the new static type.
3691 auto newResultType = RankedTensorType::get(
3692 newOutDims, padTensorOp.getType().getElementType());
3693 auto newOp = PadOp::create(
3694 rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,
3695 staticHigh, newLows, newHighs, padTensorOp.getNofold(),
3696 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3697
3698 IRMapping mapper;
3699 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3700 rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3701 newOp);
3702
3703 return success();
3704 }
3705};
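// Illustrative example (not part of the original source): if the low/high
// pad operands of
//   %r = tensor.pad %t low[%c0] high[%c2] ... : tensor<4xf32> to tensor<?xf32>
// are arith.constant 0 and 2, the pattern rewrites the pad with static
// bounds low[0] high[2] and result type tensor<6xf32>, then casts the result
// back to the original tensor<?xf32> type.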
3706
3707 /// Folds a chain of `tensor.pad` ops with the same constant padding value.
3708 ///
3709 /// Example (shapes are illustrative):
3710 ///
3711 /// ```mlir
3712 ///   %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3713 ///       tensor.yield %val
3714 ///     } : tensor<1x2xf32> to tensor<1x5xf32>
3715 ///   %res = tensor.pad %1 low[0, 2] high[3, 0] {
3716 ///       tensor.yield %val
3717 ///     } : tensor<1x5xf32> to tensor<4x7xf32>
3718 /// ```
3719 ///
3720 /// folds into:
3721 ///
3722 /// ```mlir
3723 ///   %res = tensor.pad %0 low[0, 3] high[3, 2] {
3724 ///       tensor.yield %val
3725 ///     } : tensor<1x2xf32> to tensor<4x7xf32>
3726 /// ```
3727 struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
3728 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
3729
3730 LogicalResult matchAndRewrite(tensor::PadOp padOp,
3731 PatternRewriter &rewriter) const override {
3732 if (padOp.getNofold()) {
3733 return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
3734 }
3735
3736 auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3737 if (!producerPad || producerPad.getNofold()) {
3738 return rewriter.notifyMatchFailure(
3739 padOp, "producer is not a foldable tensor.pad op");
3740 }
3741
3742
3743 Value consumerPadValue = padOp.getConstantPaddingValue();
3744 Value producerPadValue = producerPad.getConstantPaddingValue();
3745 if (!consumerPadValue || !producerPadValue ||
3746 consumerPadValue != producerPadValue) {
3747 return rewriter.notifyMatchFailure(
3748 padOp,
3749 "cannot fold PadOps with different or non-constant padding values");
3750 }
3751
3752 Location loc = padOp.getLoc();
3753 AffineExpr d0, d1;
3754 bindDims(rewriter.getContext(), d0, d1);
3755
3756 // Sum the consumer and producer paddings per dimension.
3757 auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
3758 ArrayRef<OpFoldResult> producerPaddings) {
3759 SmallVector<OpFoldResult> sumPaddings;
3760 for (auto [consumerIndex, producerIndex] :
3761 llvm::zip_equal(consumerPaddings, producerPaddings)) {
3762 sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
3763 rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3764 }
3765 return sumPaddings;
3766 };
3767
3768 SmallVector<OpFoldResult> newHighPad =
3769 addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3770 SmallVector<OpFoldResult> newLowPad =
3771 addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3772
3773 auto newPadOp = tensor::PadOp::create(
3774 rewriter, padOp.getLoc(), padOp.getResultType(),
3775 producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),
3776 getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
3777 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3778 newPadOp.getRegion().begin());
3779 rewriter.replaceOp(padOp, newPadOp.getResult());
3780 return success();
3781 }
3782};
3783
3784 } // namespace
3785
3786 LogicalResult
3787 PadOp::reifyResultShapes(OpBuilder &b,
3788 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3789 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
3790 SmallVector<OpFoldResult> lp = getMixedLowPad();
3791 SmallVector<OpFoldResult> hp = getMixedHighPad();
3792 for (int64_t i = 0; i < getResultType().getRank(); ++i) {
3793 if (!getType().isDynamicDim(i)) {
3794 reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
3795 continue;
3796 }
3797 Location loc = getLoc();
3798 Value dim = b.createOrFold<tensor::DimOp>(
3799 loc, getSource(), i);
3800
3801 AffineExpr d0, d1, d2;
3802 bindDims(b.getContext(), d0, d1, d2);
3803 reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
3804 b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
3805 }
3806 return success();
3807 }
3808
3809void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3810 MLIRContext *context) {
3811 results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3812 FoldOrthogonalPaddings, FoldStaticPadding,
3813 FoldConsecutiveConstantPadding>(context);
3814}
3815
3816
3817 /// Return the padding value of the PadOp if it is a constant. In this
3818 /// context, "constant" means an actual constant or a value defined outside
3819 /// of the pad op's own block.
3820 ///
3821 /// The value is returned in two cases:
3822 ///  - The yielded value matches m_Constant.
3823 ///  - The yielded value is defined outside of the PadOp block.
3824 /// Otherwise a null value is returned.
3825Value PadOp::getConstantPaddingValue() {
3826 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3827 if (!yieldOp)
3828 return {};
3829 Value padValue = yieldOp.getValue();
3830 // Check if yield value is a constant.
3831 if (matchPattern(padValue, m_Constant()))
3832 return padValue;
3833 // Check if yield value is defined inside the PadOp block.
3834 if (padValue.getParentBlock() == &getRegion().front())
3835 return {};
3836 // Else: Yield value defined outside of the PadOp block.
3837 return padValue;
3838}
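// Illustrative example (not part of the original source): for
//   %cst = arith.constant 0.0 : f32
//   %p = tensor.pad %t low[1] high[1] {
//   ^bb0(%i: index):
//     tensor.yield %cst : f32
//   } : tensor<4xf32> to tensor<6xf32>
// getConstantPaddingValue() returns %cst; if the region instead yielded a
// value computed from the block argument %i, it would return a null Value.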
3839
3840OpFoldResult PadOp::fold(FoldAdaptor) {
3841 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3842 !getNofold())
3843 return getSource();
3844 return {};
3845}
3846
3847 //===----------------------------------------------------------------------===//
3848 // ParallelInsertSliceOp
3849 //===----------------------------------------------------------------------===//
3850
3851OpResult ParallelInsertSliceOp::getTiedOpResult() {
3852 InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
3853 for (const auto &it :
3854 llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3855 Operation &nextOp = it.value();
3856 if (&nextOp == getOperation())
3857 return parallelCombiningParent.getParentResult(it.index());
3858 }
3859 llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3860}
3861
3862 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3863 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3864 Value source, Value dest,
3865 ArrayRef<OpFoldResult> offsets,
3866 ArrayRef<OpFoldResult> sizes,
3867 ArrayRef<OpFoldResult> strides,
3868 ArrayRef<NamedAttribute> attrs) {
3869 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3870 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3871 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3872 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3873 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3874 result.addAttributes(attrs);
3875 build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3876 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3877 b.getDenseI64ArrayAttr(staticSizes),
3878 b.getDenseI64ArrayAttr(staticStrides));
3879}
3880
3881 /// Build a ParallelInsertSliceOp with mixed static and dynamic entries
3882 /// packed into a Range vector.
3883 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3884 Value source, Value dest,
3885 ArrayRef<Range> ranges,
3886 ArrayRef<NamedAttribute> attrs) {
3887 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3888 build(b, result, source, dest, offsets, sizes, strides, attrs);
3889}
3890
3891 // Build a ParallelInsertSliceOp with dynamic entries.
3892 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3893 Value source, Value dest, ValueRange offsets,
3894 ValueRange sizes, ValueRange strides,
3895 ArrayRef<NamedAttribute> attrs) {
3896 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3897 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3898 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3899 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3900 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3901 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3902 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3903}
3904
3905 // Build an InsertSliceOp with zero offsets, unit strides, and the given
3906 // sizes.
3907 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
3908 Value dest, ArrayRef<OpFoldResult> sizes,
3909 ArrayRef<NamedAttribute> attrs) {
3910 Attribute zeroIdxAttr = b.getIndexAttr(0);
3911 Attribute oneIdxAttr = b.getIndexAttr(1);
3912 SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr);
3913 SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr);
3914 build(b, result, source, dest, writeOffsets, sizes, writeStrides, attrs);
3915}
3916
3917 LogicalResult ParallelInsertSliceOp::verify() {
3918 if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
3919 return this->emitError("expected InParallelOpInterface parent, got:")
3920 << *(getOperation()->getParentOp());
3921
3922 // Verify result type against inferred type.
3923 RankedTensorType expectedType;
3924 SliceVerificationResult result =
3925 verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3926 getStaticSizes(), getStaticStrides(), &expectedType);
3927 if (result != SliceVerificationResult::Success)
3928 return produceSliceErrorMsg(result, *this, expectedType);
3929
3930 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3931 // to the destination tensor.
3932 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
3933 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
3934 getStaticStrides(), /*generateErrorMessage=*/true);
3935 if (!boundsResult.isValid)
3936 return getOperation()->emitError(boundsResult.errorMessage);
3937
3938 return success();
3939 }
3940
3941void ParallelInsertSliceOp::getCanonicalizationPatterns(
3942 RewritePatternSet &results, MLIRContext *context) {
3943 results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3944 InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3945 InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3946}
3947
3948 llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3949 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3950 }
3951
3952 // Return the destination operand that is updated by this op.
3953MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
3954 return getDestMutable();
3955}
3956
3957Operation *ParallelInsertSliceOp::getIteratingParent() {
3958 // The iterating parent is the parent of the parallel combining op.
3959 if (auto combiningOp =
3960 dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
3961 return combiningOp->getParentOp();
3962 return nullptr;
3963}
3964
3965 //===----------------------------------------------------------------------===//
3966 // ScatterOp
3967 //===----------------------------------------------------------------------===//
3968
3969void ScatterOp::getAsmResultNames(
3970 function_ref<void(Value, StringRef)> setNameFn) {
3971 setNameFn(getResult(), "scatter");
3972}
3973
3974LogicalResult ScatterOp::verify() {
3975 int64_t destRank = getDestType().getRank();
3976 ArrayRef<int64_t> scatterDims = getScatterDims();
3977 if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
3978 getIndicesType().getShape(), destRank,
3979 "scatter", "dest")))
3980 return failure();
3981
3982 if (!getUnique())
3983 return emitOpError("requires 'unique' attribute to be set");
3984
3985
3986 // The source type must match the inferred gather result type of the
3987 // destination and indices, either in full form or in its rank-reduced
3988 // variant (i.e. with the gathered dimensions of size 1 dropped).
3989
3990 RankedTensorType expectedSourceType = GatherOp::inferResultType(
3991 getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
3992 RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3993 getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
3994 if (getSourceType() != expectedSourceType &&
3995 getSourceType() != expectedRankReducedSourceType) {
3996 return emitOpError("source type "
3997 "mismatch: "
3998 "expected ")
3999 << expectedSourceType << " or its rank-reduced variant "
4000 << expectedRankReducedSourceType << " (got: " << getSourceType()
4001 << ")";
4002 }
4003
4004 return success();
4005 }
4006
4007 //===----------------------------------------------------------------------===//
4008 // SplatOp
4009 //===----------------------------------------------------------------------===//
4010
4011void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4012 Type aggregateType, ValueRange dynamicSizes) {
4013 build(builder, result, aggregateType, element, dynamicSizes);
4014}
4015
4016void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4017 ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
4018 auto aggregateType = RankedTensorType::get(staticShape, element.getType());
4019 build(builder, result, aggregateType, element, dynamicSizes);
4020}
4021
4022void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4023 ArrayRef<OpFoldResult> sizes) {
4024 SmallVector<int64_t> staticShape;
4025 SmallVector<Value> dynamicSizes;
4026 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
4027 build(builder, result, element, staticShape, dynamicSizes);
4028}
4029
4030void SplatOp::getAsmResultNames(
4031 function_ref<void(Value, StringRef)> setNameFn) {
4032 setNameFn(getResult(), "splat");
4033}
4034
4035 LogicalResult SplatOp::verify() {
4036 if (getType().getNumDynamicDims() != getDynamicSizes().size())
4037 return emitOpError("incorrect number of dynamic sizes, has ")
4038 << getDynamicSizes().size() << ", expected "
4039 << getType().getNumDynamicDims();
4040 return success();
4041 }
4042
4043 LogicalResult
4044 SplatOp::reifyResultShapes(OpBuilder &builder,
4045 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4046 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
4047 unsigned ctr = 0;
4048 for (int64_t i = 0; i < getType().getRank(); ++i) {
4049 if (getType().isDynamicDim(i)) {
4050 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
4051 } else {
4052 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
4053 }
4054 }
4055 return success();
4056 }
4057
4058OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
4059 auto constOperand = adaptor.getInput();
4060 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
4061 return {};
4062
4063 // Do not fold if the splat is not statically shaped.
4064 if (!getType().hasStaticShape())
4065 return {};
4066
4067 // SplatElementsAttr::get treats a single value for the second arg as being
4068 // a splat.
4069 return SplatElementsAttr::get(getType(), {constOperand});
4070 }
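// Illustrative example (not part of the original source):
//   %c1 = arith.constant 1.0 : f32
//   %t = tensor.splat %c1 : tensor<2x3xf32>
// folds to a dense splat attribute equivalent to
//   arith.constant dense<1.0> : tensor<2x3xf32>
// but only because tensor<2x3xf32> is statically shaped.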
4071
4072 //===----------------------------------------------------------------------===//
4073 // Common Canonicalizers and Folders.
4074 //===----------------------------------------------------------------------===//
4075 bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
4076
4077 // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
4078 // 2. Exclude DPS ops that are also LoopLikeOpInterface.
4079 if (isa<InsertSliceOp>(op.getOperation()) ||
4080 isa<LoopLikeOpInterface>(op.getOperation()))
4081 return false;
4082
4083 return hasFoldableTensorCastOperand(op);
4084 }
4085
4086 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4087 /// the `tensor.cast` has source that is more static than the consuming op.
4088 ///
4089 /// Example:
4090 /// ```mlir
4091 ///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4092 ///   %2 = consumer %1 ... : tensor<?x?xf32> ...
4093 /// ```
4094 ///
4095 /// folds into:
4096 ///
4097 /// ```mlir
4098 ///   %2 = consumer %0 ... : tensor<8x16xf32> ...
4099 /// ```
4100 /// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
4101 /// can add the pattern to their canonicalizers.
4102 struct FoldTensorCastProducerOp
4103 : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
4104 using OpInterfaceRewritePattern<
4105 DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4106
4107 LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4108 PatternRewriter &rewriter) const override {
4109
4110 // Reject ops whose casts must not be folded here; linalg relayout ops
4111 // (pack/unpack) provide their own cast-folding patterns.
4112 if (!foldTensorCastPrecondition(op) ||
4113 isa<linalg::RelayoutOpInterface>(*op))
4114 return failure();
4115
4116 SmallVector<Type> newResultTypes(op->getResultTypes());
4117 SmallVector<Value> newOperands =
4118 getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
4119
4120 // Clone the op with the new operands and result types.
4121 auto newOp = clone(rewriter, op, newResultTypes, newOperands);
4122
4123 SmallVector<Value, 4> replacements;
4124 replacements.reserve(newOp->getNumResults());
4125 for (auto [oldResult, newResult] :
4126 llvm::zip(op->getResults(), newOp->getResults())) {
4127 if (newResult.getType() != oldResult.getType()) {
4128 replacements.push_back(tensor::CastOp::create(
4129 rewriter, op->getLoc(), oldResult.getType(), newResult));
4130 } else {
4131 replacements.push_back(newResult);
4132 }
4133 }
4134 rewriter.replaceOp(op, replacements);
4135
4136 return success();
4137 }
4138};
4139
4140 //===----------------------------------------------------------------------===//
4141 // TensorDialect
4142 //===----------------------------------------------------------------------===//
4143
4144void TensorDialect::getCanonicalizationPatterns(
4145 RewritePatternSet &results) const {
4146 results.add<FoldTensorCastProducerOp>(getContext());
4147}
4148
4149 //===----------------------------------------------------------------------===//
4150 // TableGen'd op method definitions
4151 //===----------------------------------------------------------------------===//
4152
4153#define GET_OP_CLASSES
4154#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
Definition TensorOps.cpp:416
static Value foldExtractAfterInsert(ExtractOp extractOp)
If we have an ExtractOp consuming an InsertOp with the same indices, we can return the InsertOp's sca...
Definition TensorOps.cpp:1366
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, ArrayRef< int64_t > indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
Definition TensorOps.cpp:1552
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
Definition TensorOps.cpp:2428
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op)
Definition TensorOps.cpp:4075
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
Definition TensorOps.cpp:2893
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
Definition TensorOps.cpp:2733
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
Definition TensorOps.cpp:2755
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
Definition TensorOps.cpp:2842
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
Definition TensorOps.cpp:180
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
Definition TensorOps.cpp:136
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
Definition TensorOps.cpp:2913
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
Definition TensorOps.cpp:2037
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
Attributes are known-constant values of operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
AffineExpr getAffineDimExpr(unsigned position)
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
MLIRContext * getContext() const
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
MutableArrayRef< OpOperand > getOpOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Definition TensorOps.cpp:387
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
Definition TensorOps.cpp:356
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
Definition TensorOps.cpp:2694
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition TensorOps.cpp:349
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
Definition TensorOps.cpp:365
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition TensorOps.cpp:318
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Definition TensorOps.cpp:3185
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
Definition TensorOps.cpp:2780
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
Definition TensorOps.cpp:124
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition TensorOps.cpp:57
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns)
Patterns to fold extracts of a collapse_shaped tensor to an extract of the source tensor.
Definition TensorOps.cpp:1437
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition TensorOps.cpp:75
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Definition TensorOps.cpp:266
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:66
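A hypothetical helper showing the call shape and what comes back for static versus dynamic dimensions:
#include "mlir/Dialect/Tensor/IR/Tensor.h"
// All dimensions of a tensor value as OpFoldResults: static dims come
// back as IndexAttrs; for dynamic dims the builder materializes (or
// folds) tensor.dim ops.
static llvm::SmallVector<mlir::OpFoldResult>
sizesOf(mlir::OpBuilder &b, mlir::Location loc, mlir::Value tensor) {
  return mlir::tensor::getMixedSizes(b, loc, tensor);
}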
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition TensorOps.cpp:110
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control whether a constant is folded with its consuming extract_slice op.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
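A brief sketch of the matcher in use; isConstantOne is a hypothetical helper, not an MLIR API:
#include "llvm/ADT/APInt.h"
#include "mlir/IR/Matchers.h"
// Bind the integer behind a Value when it is produced by a
// constant-like op (including splat tensors/vectors).
static bool isConstantOne(mlir::Value v) {
  llvm::APInt cst;
  return mlir::matchPattern(v, mlir::m_ConstantInt(&cst)) && cst.isOne();
}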
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
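A sketch of the typical query, under a hypothetical helper name; the related isZeroInteger further down in this index covers the zero case directly:
#include <optional>
#include "mlir/Dialect/Utils/StaticValueUtils.h"
// An OpFoldResult carries either an Attribute or a Value;
// getConstantIntValue recovers the integer in both cases when it is
// statically known (e.g. a Value defined by arith.constant).
static bool isStaticallyZero(mlir::OpFoldResult ofr) {
  std::optional<int64_t> cst = mlir::getConstantIntValue(ofr);
  return cst && *cst == 0;
}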
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
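A worked sketch under the default generateErrorMessage=false; the shape and slice values are made up for illustration:
#include <cassert>
#include "mlir/Interfaces/ViewLikeInterface.h"
// The last touched index in each dim is offset + (size - 1) * stride,
// here 7 and 15, both within shape 16x16, so the slice is in-bounds.
static void sliceBoundsExample() {
  mlir::SliceBoundsVerificationResult res = mlir::verifyInBoundsSlice(
      /*shape=*/{16, 16}, /*staticOffsets=*/{0, 8},
      /*staticSizes=*/{8, 8}, /*staticStrides=*/{1, 1});
  assert(res.isValid && res.errorMessage.empty());
}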
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
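A worked example, with the arithmetic spelled out in the comment:
#include <cassert>
#include "mlir/Dialect/Utils/IndexingUtils.h"
// With row-major strides [4, 1], linear index 6 decomposes as
// 6 = 1 * 4 + 2 * 1, so the multi-dimensional offset is [1, 2].
static void delinearizeExample() {
  llvm::SmallVector<int64_t> offsets = mlir::delinearize(6, {4, 1});
  assert(offsets.size() == 2 && offsets[0] == 1 && offsets[1] == 2);
}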
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array>.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
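A hypothetical wrapper showing the normalization direction (OpFoldResult to Value); the inverse direction is getAsOpFoldResult, listed further down:
#include "mlir/Dialect/Arith/Utils/Utils.h"
// Normalize an OpFoldResult to an SSA Value, materializing an
// arith.constant of index type only when the OpFoldResult held an
// Attribute; an existing Value is returned unchanged.
static mlir::Value asIndexValue(mlir::OpBuilder &b, mlir::Location loc,
                                mlir::OpFoldResult ofr) {
  return mlir::getValueOrCreateConstantIndexOp(b, loc, ofr);
}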
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
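A small sketch of the compatibility rule using made-up shapes:
#include <cassert>
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
// Shapes are compatible when all statically-known dimensions agree;
// a dynamic dimension is compatible with anything.
static void compatibleShapeExample() {
  int64_t dyn = mlir::ShapedType::kDynamic;
  assert(mlir::succeeded(mlir::verifyCompatibleShape({2, dyn}, {2, 5})));
  assert(mlir::failed(mlir::verifyCompatibleShape({2, 4}, {2, 5})));
}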
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
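A sketch of the decomposition, with the expected contents noted in comments; the example values are hypothetical:
#include "mlir/Dialect/Utils/StaticValueUtils.h"
// Split mixed static/dynamic values into the two-array form ops store
// them in: constants go into the int64_t array, and each Value is
// replaced there by a ShapedType::kDynamic sentinel.
static void splitMixed(llvm::ArrayRef<mlir::OpFoldResult> mixed) {
  auto [staticVals, dynamicVals] = mlir::decomposeMixedValues(mixed);
  (void)staticVals;   // e.g. {4, ShapedType::kDynamic}
  (void)dynamicVals;  // e.g. {%size} for the dynamic entry
}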
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
Definition TensorOps.cpp:4103
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
Definition TensorOps.cpp:4107
A canonicalizer wrapper to replace ExtractSliceOps.
Definition TensorOps.cpp:2712
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Definition TensorOps.cpp:2713
Return the canonical type of the result of an extract_slice op.
Definition TensorOps.cpp:2701
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Definition TensorOps.cpp:2702
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.