MLIR: lib/Dialect/Linalg/IR/LinalgOps.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
14
41
42 #include "llvm/ADT/DenseMap.h"
43 #include "llvm/ADT/STLExtras.h"
44 #include "llvm/ADT/SetOperations.h"
45 #include "llvm/ADT/SmallSet.h"
46 #include "llvm/ADT/SmallVector.h"
47 #include "llvm/ADT/StringSet.h"
48 #include "llvm/ADT/TypeSwitch.h"
49 #include "llvm/Support/FormatVariadic.h"
50 #include "llvm/Support/InterleavedRange.h"
51 #include "llvm/Support/LogicalResult.h"
52 #include "llvm/Support/MathExtras.h"
53 #include "llvm/Support/raw_ostream.h"
54 #include
55 #include
56
57 using namespace mlir;
59
60
62 int64_t dim) {
63 auto type = cast(v.getType());
64 if (!type.isDynamicDim(dim))
65 return builder.getIndexAttr(type.getDimSize(dim));
66
69 .Case([&](RankedTensorType t) -> Value {
70 return builder.createtensor::DimOp(loc, v, dim);
71 })
72 .Case([&](MemRefType t) -> Value {
73 return builder.creatememref::DimOp(loc, v, dim);
74 }));
75 }
76
77
78
84 .Case([&](RankedTensorType t) -> Operation * {
85 return b.createtensor::ExtractSliceOp(loc, source, offsets, sizes,
86 strides);
87 })
88 .Case([&](MemRefType type) -> Operation * {
89 return b.creatememref::SubViewOp(loc, source, offsets, sizes,
90 strides);
91 })
92 .Default([&](Type t) -> Operation * { return nullptr; });
93 }
94
95
96
97
98
100 int64_t dim) {
101 if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
102 return b.createOrFoldmemref::DimOp(loc, source, dim);
103 if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
104 return b.createOrFoldtensor::DimOp(loc, source, dim);
105 llvm_unreachable("Expected MemRefType or TensorType");
106 }
107
109 int64_t dim) {
110 auto shapedType = llvm::cast(source.getType());
111 if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
113 return b.getIndexAttr(shapedType.getDimSize(dim));
114 }
115
116
117
118
119
122
123
124
125
126
127
134 for (auto containers : {inputTypes, outputTypes}) {
135 for (auto t : containers) {
136 argTypes.push_back(
138
139
141 }
142 }
143
144
147 opBuilder.createBlock(®ion, {}, argTypes, argLocs);
148
151 regionBuilder(b, *body, attrs);
152
153
154
155
156 }
157
158
159
160
161
163 std::optional resultTensorTypes,
167
169 resultTensorTypes.value_or(TypeRange());
170 if (!resultTensorTypes)
171 copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
172 llvm::IsaPred);
173
174 state.addOperands(inputs);
175 state.addOperands(outputs);
176 state.addTypes(derivedResultTypes);
177
178 state.addAttributes(attributes);
179 state.addAttribute(
180 "operandSegmentSizes",
182 static_cast<int32_t>(outputs.size())}));
183
184
185 Region ®ion = *state.addRegion();
187 state.attributes.getAttrs(), regionBuilder);
188 }
189
191 std::optional resultTensorTypes,
196
198 indexingMapsAttrVal = llvm::map_to_vector(
199 MatmulOp::getDefaultIndexingMaps(b.getContext()),
201 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
202 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
203 attributes, regionBuilder);
204 }
205
207 std::optional resultTensorTypes,
212
214 indexingMapsAttrVal =
217 });
218 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
219 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
220 attributes, regionBuilder);
221 }
222
224 std::optional resultTensorTypes,
229
231 indexingMapsAttrVal =
234 });
235 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
236 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
237 attributes, regionBuilder);
238 }
239
240
241
242 static ParseResult
246 bool addOperandSegmentSizes = true) {
247 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
249 outputsOperands;
250
253 return failure();
254 }
257 return failure();
258
261 return failure();
262
266 return failure();
267 }
268
273 return failure();
274 }
275
276 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
278 parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
280 return failure();
281
282 if (addOperandSegmentSizes) {
283
284
285
286
287
288
291 attrs.append("operandSegmentSizes",
293 {static_cast<int32_t>(inputsOperands.size()),
294 static_cast<int32_t>(outputsOperands.size())}));
296 } else {
299 {static_cast<int32_t>(inputsOperands.size()),
300 static_cast<int32_t>(outputsOperands.size())}));
301 }
302 }
304 std::optional info =
306 if (info) {
307 if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
308 return parser.emitError(attrsLoc)
309 << "'" << result.name.getStringRef() << "' op ";
310 })))
311 return failure();
312 }
313 }
314 return success();
315 }
316
319 if (!inputs.empty())
320 p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
321 if (!outputs.empty())
322 p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
323 }
324
325
326
327
328
333 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
336 llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
337 "region expects {0} args, got {1}",
338 numRegionArgs, inputTypes.size() + outputTypes.size()));
339 }
340
343 regionBuilder);
344 return success();
345 }
346
347 static ParseResult
351 return failure();
352 return success();
353 }
354
357 unsigned numRegionArgs,
359
362 return failure();
363
364
366 return failure();
367
368
369
372 return failure();
373 result.addTypes(outputTensorsTypes);
374
375 std::unique_ptr region = std::make_unique();
378 regionBuilder))
379 return failure();
380 result.addRegion(std::move(region));
381
382 return success();
383 }
384
387 if (resultTypes.empty())
388 return;
390 }
391
396
397
398
400
401
403
404
405 }
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430 namespace {
431
432 class RegionBuilderHelper {
433 public:
434 RegionBuilderHelper(OpBuilder &builder, Block &block)
435 : builder(builder), block(block) {}
436
437
439 if (!isFloatingPoint(arg))
440 llvm_unreachable("unsupported non numeric type");
442 builder.setInsertionPointToEnd(&block);
444 case UnaryFn::exp:
445 return builder.createmath::ExpOp(arg.getLoc(), arg);
446 case UnaryFn:🪵
447 return builder.createmath::LogOp(arg.getLoc(), arg);
449 return builder.createmath::AbsFOp(arg.getLoc(), arg);
451 return builder.createmath::CeilOp(arg.getLoc(), arg);
453 return builder.createmath::FloorOp(arg.getLoc(), arg);
454 case UnaryFn::negf:
455 return builder.createarith::NegFOp(arg.getLoc(), arg);
456 case UnaryFn::reciprocal: {
458 auto one = builder.createarith::ConstantOp(arg.getLoc(),
459 ::cast(oneAttr));
460 return builder.createarith::DivFOp(arg.getLoc(), one, arg);
461 }
463 return builder.createmath::RoundOp(arg.getLoc(), arg);
464 case UnaryFn::sqrt:
465 return builder.createmath::SqrtOp(arg.getLoc(), arg);
466 case UnaryFn::rsqrt:
467 return builder.createmath::RsqrtOp(arg.getLoc(), arg);
468 case UnaryFn::square:
469 return builder.createarith::MulFOp(arg.getLoc(), arg, arg);
470 case UnaryFn::tanh:
471 return builder.createmath::TanhOp(arg.getLoc(), arg);
472 case UnaryFn::erf:
473 return builder.createmath::ErfOp(arg.getLoc(), arg);
474 }
475 llvm_unreachable("unsupported unary function");
476 }
477
478
480 bool allComplex = isComplex(arg0) && isComplex(arg1);
481 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
482 bool allInteger = isInteger(arg0) && isInteger(arg1);
485 if (!allComplex && !allFloatingPoint && !allInteger)
486 llvm_unreachable("unsupported non numeric type");
488 builder.setInsertionPointToEnd(&block);
490 case BinaryFn::add:
491 if (allComplex)
492 return builder.createcomplex::AddOp(arg0.getLoc(), arg0, arg1);
493 if (allFloatingPoint)
494 return builder.createarith::AddFOp(arg0.getLoc(), arg0, arg1);
495 if (allBool)
496 return builder.createarith::OrIOp(arg0.getLoc(), arg0, arg1);
497 return builder.createarith::AddIOp(arg0.getLoc(), arg0, arg1);
498 case BinaryFn::sub:
499 if (allComplex)
500 return builder.createcomplex::SubOp(arg0.getLoc(), arg0, arg1);
501 if (allFloatingPoint)
502 return builder.createarith::SubFOp(arg0.getLoc(), arg0, arg1);
503 if (allBool)
504 llvm_unreachable("unsupported operation: sub with bools");
505 return builder.createarith::SubIOp(arg0.getLoc(), arg0, arg1);
506 case BinaryFn::mul:
507 if (allComplex)
508 return builder.createcomplex::MulOp(arg0.getLoc(), arg0, arg1);
509 if (allFloatingPoint)
510 return builder.createarith::MulFOp(arg0.getLoc(), arg0, arg1);
511 if (allBool)
512 return builder.createarith::AndIOp(arg0.getLoc(), arg0, arg1);
513 return builder.createarith::MulIOp(arg0.getLoc(), arg0, arg1);
514 case BinaryFn::div:
515 if (allComplex)
516 return builder.createcomplex::DivOp(arg0.getLoc(), arg0, arg1);
517 if (allFloatingPoint)
518 return builder.createarith::DivFOp(arg0.getLoc(), arg0, arg1);
519 if (allBool)
520 llvm_unreachable("unsupported operation: div with bools");
521 return builder.createarith::DivSIOp(arg0.getLoc(), arg0, arg1);
522 case BinaryFn::div_unsigned:
523 if (!allInteger || allBool)
524 llvm_unreachable("unsupported operation: unsigned div not on uint");
525 return builder.createarith::DivUIOp(arg0.getLoc(), arg0, arg1);
526 case BinaryFn::max_signed:
527 assert(!allComplex);
528 if (allFloatingPoint)
529 return builder.createarith::MaximumFOp(arg0.getLoc(), arg0, arg1);
530 return builder.createarith::MaxSIOp(arg0.getLoc(), arg0, arg1);
531 case BinaryFn::min_signed:
532 assert(!allComplex);
533 if (allFloatingPoint)
534 return builder.createarith::MinimumFOp(arg0.getLoc(), arg0, arg1);
535 return builder.createarith::MinSIOp(arg0.getLoc(), arg0, arg1);
536 case BinaryFn::max_unsigned:
537 assert(!allComplex);
538 if (allFloatingPoint)
539 return builder.createarith::MaximumFOp(arg0.getLoc(), arg0, arg1);
540 return builder.createarith::MaxUIOp(arg0.getLoc(), arg0, arg1);
541 case BinaryFn::min_unsigned:
542 assert(!allComplex);
543 if (allFloatingPoint)
544 return builder.createarith::MinimumFOp(arg0.getLoc(), arg0, arg1);
545 return builder.createarith::MinUIOp(arg0.getLoc(), arg0, arg1);
546 case BinaryFn::powf:
547 assert(allFloatingPoint);
548 return builder.createmath::PowFOp(arg0.getLoc(), arg0, arg1);
549 }
550 llvm_unreachable("unsupported binary function");
551 }
552
553
556 bool headBool =
558 bool tailFloatingPoint =
559 isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
560 bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
562 builder.setInsertionPointToEnd(&block);
564 case TernaryFn::select:
565 if (!headBool && !(tailFloatingPoint || tailInteger))
566 llvm_unreachable("unsupported non numeric type");
567 return builder.createarith::SelectOp(arg0.getLoc(), arg0, arg1, arg2);
568 }
569 llvm_unreachable("unsupported ternary function");
570 }
571
572
573 Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
574 switch (typeFn) {
575 case TypeFn::cast_signed:
576 return cast(toType, operand, false);
577 case TypeFn::cast_unsigned:
578 return cast(toType, operand, true);
579 }
580 llvm_unreachable("unsupported type conversion function");
581 }
582
583 void yieldOutputs(ValueRange values) {
585 builder.setInsertionPointToEnd(&block);
586 Location loc = builder.getUnknownLoc();
587 builder.create(loc, values);
588 }
589
590 Value constant(const std::string &value) {
592 builder.setInsertionPointToEnd(&block);
593 Location loc = builder.getUnknownLoc();
595 return builder.createarith::ConstantOp(loc, ::cast(valueAttr));
596 }
597
598 Value index(int64_t dim) {
600 builder.setInsertionPointToEnd(&block);
601 return builder.create(builder.getUnknownLoc(), dim);
602 }
603
604 Type getIntegerType(unsigned width) {
606 }
607
610
611 private:
612
613
614
615
616 Value cast(Type toType, Value operand, bool isUnsignedCast) {
618 builder.setInsertionPointToEnd(&block);
619 auto loc = operand.getLoc();
621 }
622
623 bool isComplex(Value value) {
624 return llvm::isa(value.getType());
625 }
626 bool isFloatingPoint(Value value) {
627 return llvm::isa(value.getType());
628 }
629 bool isInteger(Value value) {
630 return llvm::isa(value.getType());
631 }
632
635 };
636
637 }
638
639
640
641
642
643 namespace {
644
647 LogicalResult matchAndRewrite(CopyOp copyOp,
649 if (copyOp.getInputs() != copyOp.getOutputs())
651 if (copyOp.hasPureBufferSemantics())
652 rewriter.eraseOp(copyOp);
653 else
654 rewriter.replaceOp(copyOp, copyOp.getInputs());
655
656 return success();
657 }
658 };
659
660 }
661
// Registers the canonicalization pattern(s) for linalg.copy.
// NOTE(review): this listing is a lossy extraction — the second parameter
// line (original line 663, presumably `MLIRContext *context`) and the
// pattern template argument were dropped; verify against upstream.
662 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
664 results.add(context);
665 }
666
667
668
669
670
671 namespace {
672
673
674
675
676
// Folds a tensor reshape (expand_shape/collapse_shape) of a linalg.fill into
// a fill of the reshaped init: fill(init) -> reshape becomes
// reshape(init) -> fill. TensorReshapeOp is instantiated with
// tensor::ExpandShapeOp / tensor::CollapseShapeOp (see the registration in
// FillOp::getCanonicalizationPatterns).
// NOTE(review): template parameter list, `using OpRewritePattern...`, the
// rewriter parameter line, and the final replaceOp call (original lines
// 679/681/699-700) were dropped by extraction — confirm against upstream.
677 template
678 struct FoldFillWithTensorReshape : OpRewritePattern {
680 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
// Bail out unless the reshape's source is produced by a fill.
682 auto oldFill = reshapeOp.getSrc().template getDefiningOp();
683 if (!oldFill)
684 return failure();
685
686 Location loc = oldFill.getLoc();
687 TensorReshapeOp newInit;
// expand_shape carries an output-shape operand list that collapse_shape
// does not, hence the compile-time branch on the op kind.
688 if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
689
690 newInit = rewriter.create(
691 loc, reshapeOp.getResultType(), oldFill.output(),
692 reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
693 reshapeOp.getStaticOutputShape());
694 } else {
695 newInit = rewriter.create(loc, reshapeOp.getResultType(),
696 oldFill.output(),
697 reshapeOp.getReassociation());
698 }
701 return success();
702 }
703 };
704
705
706
// Folds tensor.pad(linalg.fill) into a fill of a correctly-sized empty
// tensor when the pad's constant padding value equals the fill value — the
// padding then adds nothing the fill would not already write.
// NOTE(review): extraction dropped lines 708/711 and 721-724/732-733
// (`using`/rewriter parameter/shape-reification call/fill creation);
// verify against upstream before relying on exact behavior.
707 struct FoldFillWithPad final : public OpRewritePatterntensor::PadOp {
709
710 LogicalResult matchAndRewrite(tensor::PadOp padOp,
712 auto fillOp = padOp.getSource().getDefiningOplinalg::FillOp();
713 if (!fillOp)
714 return failure();
715
716
717
// Only a *constant* padding value can be compared against the fill value.
718 Value padValue = padOp.getConstantPaddingValue();
719 if (!padValue || fillOp.value() != padValue)
720 return failure();
721
725 padOp, "failed to reify tensor.pad op result shape");
726
// Build a new destination with the padded result shape and element type.
727 auto emptyTensor = rewriter.createtensor::EmptyOp(
728 padOp.getLoc(), reifiedShape.front(),
729 padOp.getResultType().getElementType());
730 Value replacement =
731 rewriter
734 .getResult(0);
// Insert a cast if the reified shape is more static/dynamic than the
// pad's declared result type.
735 if (replacement.getType() != padOp.getResultType()) {
736 replacement = rewriter.createtensor::CastOp(
737 fillOp.getLoc(), padOp.getResultType(), replacement);
738 }
739 rewriter.replaceOp(padOp, replacement);
740 return success();
741 }
742 };
743
744
745
746
// Folds tensor.insert_slice(tensor.pad(x)) into a fill-backed destination:
// when the (possibly chained) destination ultimately comes from a
// linalg.fill whose value equals the pad's constant padding value, the pad
// is redundant and the unpadded source can be inserted directly at offsets
// shifted by the pad's low padding.
// NOTE(review): lossy extraction — lines 748/751, 806-808, 811-814,
// 819/821, 826 and 837 (using/rewriter param/lowPads & oldOffsets setup/
// newOffsets vector decl/final replaceOpWithNewOp) were dropped; confirm
// against upstream LinalgOps.cpp.
747 struct FoldInsertPadIntoFill : public OpRewritePatterntensor::InsertSliceOp {
749
750 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
752 auto srcPadOp = insertOp.getSource().getDefiningOptensor::PadOp();
753 if (!srcPadOp)
754 return failure();
755
// Rank-changing insert_slices are not handled.
756 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
757 return failure();
758
759
760
// Walk backwards through a chain of insert_slice ops, but only past ones
// whose (statically provable) regions are disjoint from this insert.
761 Value firstDest = insertOp.getDest();
762 while (auto prevOp = firstDest.getDefiningOptensor::InsertSliceOp()) {
763 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
764 return failure();
765
766
767
768 bool disjoint = false;
769 for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
770
// Dynamic offset/size/stride in either op means disjointness cannot be
// proven along this dimension — skip it.
771
772 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
773 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
774 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
775 continue;
776
777
// Compare the closed static intervals [start, end] of both slices.
778 int64_t prevStart = prevOp.getStaticOffset(i);
779 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
780 prevOp.getStaticStride(i);
781 int64_t nextStart = insertOp.getStaticOffset(i);
782 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
783 insertOp.getStaticStride(i);
784 if (prevEnd < nextStart || nextEnd < prevStart) {
785 disjoint = true;
786 break;
787 }
788 }
789
// Overlapping (or unprovable) insert: stop walking the chain here.
790 if (!disjoint)
791 break;
792 firstDest = prevOp.getDest();
793 }
794
795
796
// The chain must bottom out at a linalg.fill.
797 auto dstFillOp = firstDest.getDefiningOplinalg::FillOp();
798 if (!dstFillOp)
799 return failure();
800
801
802
// The pad must pad with exactly the filled value, otherwise dropping the
// pad would change the destination contents.
803 Value padValue = srcPadOp.getConstantPaddingValue();
804 if (!padValue || dstFillOp.value() != padValue)
805 return failure();
806
809
810 Location loc = insertOp.getLoc();
812
// newOffset[i] = lowPad[i] + oldOffset[i], folded via an affine map.
815 auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);
816
817
818
820 for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
822 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
823 }
824
// New sizes come from the *unpadded* pad source (dynamic dims via
// tensor.dim, static dims as index attributes).
825 RankedTensorType srcPadType = srcPadOp.getSourceType();
827 for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
828 if (srcPadType.isDynamicDim(i)) {
829 newSizes.push_back(
830 rewriter.createtensor::DimOp(loc, srcPadOp.getSource(), i)
831 .getResult());
832 } else {
833 newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
834 }
835 }
836
838 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
839 newSizes, insertOp.getMixedStrides());
840 return success();
841 }
842 };
843
844
// Folds tensor.extract(linalg.fill) to the fill's scalar input: every
// element of a filled tensor is that scalar, so the extract is redundant.
// NOTE(review): lines 847/850 (`using OpRewritePattern...` and the
// rewriter parameter) were dropped by extraction.
845 struct FoldFillWithTensorExtract : public OpRewritePatterntensor::ExtractOp {
846 public:
848
849 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
851
852
// Only fires when the extracted tensor is directly produced by a fill.
853 auto fillOp = extractOp.getTensor().getDefiningOplinalg::FillOp();
854 if (!fillOp)
855 return failure();
856
857
// The fill's sole input is the scalar written to every element.
858 Value extractedScalar = fillOp.getInputs()[0];
859
860
861 rewriter.replaceOp(extractOp, extractedScalar);
862 return success();
863 }
864 };
865
866
867
868
// Rewrites linalg.pack(linalg.fill) as a fill of the pack's destination,
// when it is safe to drop the pack (no padding value involved).
// NOTE(review): extraction stripped template arguments (e.g.
// `FailureOr<linalg::FillOp>`, `getDefiningOp<linalg::FillOp>()`) and
// dropped lines 876/880 (padding-value comparison and the
// destination-shape guard); verify against upstream.
869 static FailureOr foldFillPackIntoFillOp(RewriterBase &rewriter,
870 linalg::PackOp packOp) {
871 auto fillOp = packOp.getSource().getDefiningOp();
872 if (!fillOp)
873 return failure();
874
// A pack with a padding value writes extra elements; bail out (the
// dropped line presumably compares it against the fill value).
875 if (auto paddingValue = packOp.getPaddingValue())
877 return failure();
878
879 Value packOpDest = packOp.getDest();
881 return failure();
882
// Fill the pack's destination directly with the same scalar input(s).
883 return rewriter.createlinalg::FillOp(packOp.getLoc(), fillOp.getInputs(),
884 packOp.getDest());
885 }
886
887
// Pattern wrapper over foldFillPackIntoFillOp: replaces a linalg.pack whose
// source is a linalg.fill with the fill of the pack destination.
// NOTE(review): lines 890-891/894 (`using`/rewriter parameter) were
// dropped by extraction.
888 struct FoldFillWithPack : public OpRewritePatternlinalg::PackOp {
889 public:
892
893 LogicalResult matchAndRewrite(linalg::PackOp packOp,
895 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
896 if (failed(fillOp))
897 return failure();
// On success the helper returned the newly created fill; use its result.
898 rewriter.replaceOp(packOp, fillOp.value().result());
899 return success();
900 }
901 };
902
903
906
907 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
909 if (auto fillOp = copyOp.getInputs().front().getDefiningOp()) {
911 fillOp.getInputs(),
912 copyOp.getOutputs());
913 return success();
914 }
915 if (auto fillOp = copyOp.getOutputs().front().getDefiningOp()) {
916 rewriter.replaceOpWithNewOplinalg::CopyOp(copyOp, copyOp.getInputs(),
917 fillOp.getOutputs());
918 return success();
919 }
920 return failure();
921 }
922 };
923
924
// Folds linalg.transpose(linalg.fill) into a fill of the transpose's init:
// transposing a constant-filled tensor yields the same filled tensor.
// NOTE(review): lines 926/929/931 (`using`/rewriter parameter/the
// replaceOpWithNewOp<FillOp> call head) were dropped by extraction, and
// template arguments are stripped throughout.
925 struct FoldFillWithTranspose : OpRewritePatternlinalg::TransposeOp {
927
928 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
930 if (auto fillOp = transposeOp.getInput().getDefiningOp()) {
// Re-create the fill over the transpose's own init operand.
932 transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
933 transposeOp.getDpsInitOperand(0)->get());
934 return success();
935 }
936 return failure();
937 }
938 };
939
940
941
// Folds tensor.concat where every operand is a linalg.fill with the same
// fill value into: concat the fills' init tensors, then a single fill of
// the concatenated init.
// NOTE(review): lines 943/946, 957-960, 969-970, 979 and 985
// (using/rewriter param/first fill value capture/allOuts decl/
// notifyMatchFailure head/replaceOpWithNewOp head) were dropped by
// extraction; confirm details upstream.
942 struct FoldConcatsOfFill : public OpRewritePatterntensor::ConcatOp {
944
945 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
947 auto concatOperands = concatOp.getInputs();
948 if (concatOperands.empty()) {
949 return failure();
950 }
951
// The first operand anchors the fill value all others must match.
952 auto firstFillOp = concatOperands.front().getDefiningOplinalg::FillOp();
953 if (!firstFillOp) {
954 return failure();
955 }
956
959
// Collect every fill's init (DPS out) tensor; these get concatenated.
961 allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
962
// A concat operand is compatible if it is a fill with the same value;
// as a side effect its init is appended to allOuts.
963 auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
964 auto fillOp = v.getDefiningOplinalg::FillOp();
965 if (!fillOp) {
966 return false;
967 }
968
971 if (fillVal != firstFillVal)
972 return false;
973
974 allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
975 return true;
976 };
977 if (!llvm::all_of(concatOperands.drop_front(),
978 isDefinedByCompatibleFillOp)) {
980 concatOp, "not all operands are defined by a compatible fill op");
981 }
982
// Concat the inits, then fill the result once with the shared value.
983 Value outsConcat = rewriter.createtensor::ConcatOp(
984 concatOp.getLoc(), concatOp.getDim(), allOuts);
986 concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
987 return success();
988 }
989 };
990
991 }
992
// Registers all fill-folding canonicalization patterns defined above.
// NOTE(review): line 994 (presumably `MLIRContext *context)`) was dropped
// by extraction, and the two FoldFillWithTensorReshape entries lost their
// template angle brackets.
993 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
995 results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
996 FoldFillWithPack, FoldFillWithPad,
997 FoldFillWithTensorReshapetensor::CollapseShapeOp,
998 FoldFillWithTensorReshapetensor::ExpandShapeOp,
999 FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
1000 }
1001
1002
1003
1004
1005
1012 for (ValueRange container : {inputs, outputs}) {
1013 for (Value v : container) {
1014 Type t = v.getType();
1015 blockArgTypes.push_back(
1017 blockArgLocs.push_back(v.getLoc());
1018 }
1019 }
1020
1022 Block *bodyBlock =
1023 builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs);
1024 bodyBuild(builder, loc, bodyBlock->getArguments());
1025 }
1026
// Suggests assembly names for region block arguments: "in" for input args,
// "out" for output (init) args.
// NOTE(review): line 1028 (the setNameFn parameter) was dropped by
// extraction, and `&region` was entity-mangled to `®ion`.
1027 void GenericOp::getAsmBlockArgumentNames(Region ®ion,
1029 for (Value v : getRegionInputArgs())
1030 setNameFn(v, "in");
1031 for (Value v : getRegionOutputArgs())
1032 setNameFn(v, "out");
1033 }
1034
1035 void GenericOp::build(
1038 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
1041 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1042 iteratorTypes, doc, libraryCall);
1044 if (bodyBuild)
1046 inputs, outputs, bodyBuild);
1047 }
1048
1049 void GenericOp::build(
1053 StringRef libraryCall,
1056 build(builder, result, resultTensorTypes, inputs, outputs,
1058 builder.getArrayAttr(llvm::to_vector(llvm::map_range(
1059 iteratorTypes,
1061 return IteratorTypeAttr::get(builder.getContext(), iter);
1062 }))),
1063 doc.empty() ? StringAttr() : builder.getStringAttr(doc),
1064 libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
1065 bodyBuild, attributes);
1066 }
1067
1068 void GenericOp::build(
1072 StringRef libraryCall,
1075 build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
1076 iteratorTypes, doc, libraryCall, bodyBuild, attributes);
1077 }
1078
1079 void GenericOp::build(
1085 build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
1086 "",
1087 "", bodyBuild, attributes);
1088 }
1089
1090 void GenericOp::build(
1096 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1097 iteratorTypes,
1098 "",
1099 "", bodyBuild, attributes);
1100 }
1101
1103 p << " ";
1104
1105
1106 auto genericAttrNames = linalgTraitAttrNames();
1107
1109 genericAttrNamesSet.insert_range(genericAttrNames);
1111 for (auto attr : (*this)->getAttrs()) {
1112 if (attr.getName() == getIteratorTypesAttrName()) {
1113 auto iteratorTypes =
1114 llvm::cast(attr.getValue())
1115 .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
1116
1117
1118
1119
1121 llvm::to_vector(llvm::map_range(
1122 iteratorTypes, [&](utils::IteratorType t) -> Attribute {
1124 }));
1125
1126 genericAttrs.emplace_back(
1127 getIteratorTypesAttrName(),
1129 } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
1130 genericAttrs.push_back(attr);
1131 }
1132 }
1133 if (!genericAttrs.empty()) {
1135 p << genericDictAttr;
1136 }
1137
1138
1140
1141 genericAttrNames.push_back("operandSegmentSizes");
1142 genericAttrNamesSet.insert(genericAttrNames.back());
1143
1144 bool hasExtraAttrs = false;
1146 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
1147 break;
1148 }
1149 if (hasExtraAttrs) {
1150 p << " attrs = ";
1152 genericAttrNames);
1153 }
1154
1155
1156 if (!getRegion().empty()) {
1157 p << ' ';
1159 }
1160
1161
1163 }
1164
1166 DictionaryAttr dictAttr;
1167
1168
1169
1170
1173 return failure();
1175 dictAttr.getValue().end());
1176
1177
1178
1179
1180
1181 auto iteratorTypes = dyn_cast_or_null(
1183 if (!iteratorTypes) {
1184 return parser.emitError(attributeLocation)
1185 << "expected " << getIteratorTypesAttrName(result.name)
1186 << " array attribute";
1187 }
1188
1190
1191 for (StringRef s : iteratorTypes.getAsValueRange()) {
1192 auto maybeIteratorType = utils::symbolizeIteratorType(s);
1193 if (!maybeIteratorType.has_value())
1195 << "unexpected iterator_type (" << s << ")";
1196
1197 iteratorTypeAttrs.push_back(
1199 }
1202
1203
1206 return failure();
1207
1208
1212 return failure();
1213
1214 std::unique_ptr region = std::make_unique();
1216 return failure();
1217 result.addRegion(std::move(region));
1218
1219
1220
1221
1222
1225 return failure();
1226 result.addTypes(outputTensorsTypes);
1227
1228 return success();
1229 }
1230
1233 &effects,
1234 LinalgOp linalgOp) {
1235 for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
1236 if (!llvm::isa(operand.getType()))
1237 continue;
1238 effects.emplace_back(
1241 }
1242
1243 for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1244 if (!llvm::isa(operand.get().getType()))
1245 continue;
1246 if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1248 true,
1250 }
1252 true,
1254 }
1255 }
1256
1257 void GenericOp::getEffects(
1259 &effects) {
1261 }
1262
1265
1266
1267 if (!linalgOp.hasPureTensorSemantics())
1269
1271 }
1272
1275 }
1276
1278
1279 namespace {
1280
1281
1282
1283
1284
1285
1286
1287 template
// Erases/forwards identity linalg ops: an op whose indexing maps are all
// identical and whose body just yields its block arguments computes
// nothing. Buffer-semantics ops (with input == init) are erased outright;
// tensor-semantics ops are replaced by their forwarded operands (with a
// cast where result types differ).
// NOTE(review): lossy extraction — lines 1289/1292, 1311, 1317, 1326,
// 1332 (using/rewriter param/notifyMatchFailure heads/returnedArgs decl)
// and the sparse-tensor condition at 1346-1347 were dropped; template
// arguments (e.g. `dyn_cast<BlockArgument>`) are stripped. Confirm
// against upstream.
1287 template
1288 struct EraseIdentityLinalgOp : public OpRewritePattern {
1290
1291 LogicalResult matchAndRewrite(OpTy linalgOp,
1293
// All indexing maps must be equal (identity-like access pattern).
1294 if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
1295 return failure();
1296
1297
1298
// The body must be a single block containing only the yield terminator.
1299 Block &body = linalgOp->getRegion(0).front();
1300 if (!llvm::hasSingleElement(body))
1301 return failure();
1302 auto yieldOp = dyn_castlinalg::YieldOp(body.getTerminator());
1303 if (!yieldOp)
1304 return failure();
1305
1306
// Buffer case: in-place identity copy (input aliases init) can be erased.
1307 if (linalgOp.hasPureBufferSemantics()) {
1308 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
1309 linalgOp.getDpsInputOperand(0)->get() !=
1310 linalgOp.getDpsInitOperand(0)->get()) {
1312 linalgOp, "expected single input and output to be the same value");
1313 }
1314
// The yield must forward a block argument of this body — otherwise the
// op behaves like a fill and cannot be dropped.
1315 auto yieldArg = dyn_cast(yieldOp.getOperand(0));
1316 if (!yieldArg || yieldArg.getOwner() != &body) {
1318 "cannot fold fill-like op");
1319 }
1320
1321 rewriter.eraseOp(linalgOp);
1322 return success();
1323 }
1324
1325 if (!linalgOp.hasPureTensorSemantics()) {
1327 linalgOp, "mixed semantics is not supported yet");
1328 }
1329
1330
1331
// Tensor case: each yielded value must be a body block argument; map it
// back to the corresponding op operand and forward that as the result.
1333 for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
1334 auto yieldArg = llvm::dyn_cast(yieldVal.value());
1335 if (!yieldArg || yieldArg.getOwner() != &body)
1336 return failure();
1337 unsigned argumentNumber = yieldArg.getArgNumber();
1338 Value returnedArg = linalgOp->getOperand(argumentNumber);
1339 Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1340
1341
// Insert a conversion when operand and result types disagree (sparse
// conversion per the dropped condition, tensor.cast otherwise).
1342 Type returnType = returnedArg.getType();
1343 if (returnType != resultType) {
1344
1345
1348 returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
1349 linalgOp.getLoc(), resultType, returnedArg);
1350 else {
1351 if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
1352 resultType))
1353 return failure();
1354 returnedArg = rewriter.createtensor::CastOp(
1355 linalgOp.getLoc(), resultType, returnedArg);
1356 }
1357 }
1358 returnedArgs.push_back(returnedArg);
1359 }
1360
1361 if (returnedArgs.size() != linalgOp->getNumResults())
1362 return failure();
1363 rewriter.replaceOp(linalgOp, returnedArgs);
1364 return success();
1365 }
1366 };
1367
1368 }
1369
// Registers the identity-erasure pattern for linalg.generic.
// NOTE(review): line 1371 (presumably `MLIRContext *context)`) and the
// pattern's template argument (likely `<GenericOp>`) were dropped/stripped
// by extraction.
1370 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
1372 results.add<EraseIdentityLinalgOp>(context);
1373 }
1374
1377 }
1378
1379
1380
1381
1382
1386 nullptr) {
1387
1390 false))
1391 return failure();
1392
1393
1394 for (Type outputType : outputTypes) {
1395 if (llvm::isa(outputType))
1396 result.addTypes(outputType);
1397 }
1398
1399
1400 if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
1401 return failure();
1402
1403
1405 return failure();
1406 return success();
1407 }
1408
// Suggests the name "in" for each region input block argument of linalg.map
// (map has no output block arguments to name).
// NOTE(review): line 1410 (the setNameFn parameter) was dropped by
// extraction, and `&region` was entity-mangled to `®ion`.
1409 void MapOp::getAsmBlockArgumentNames(Region ®ion,
1411 for (Value v : getRegionInputArgs())
1412 setNameFn(v, "in");
1413 }
1414
1415 void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1416 if (!getResults().empty())
1417 setNameFn(getResults().front(), "mapped");
1418 }
1419
1420 void MapOp::build(
1424 build(builder, result, TypeRange{}, inputs, init);
1426
1427
1429 if (llvm::isa(initType))
1431
1432 if (bodyBuild)
1434 inputs, {}, bodyBuild);
1435 }
1436
1441 bool initFirst = false) {
1446 for (auto &operand : operands) {
1448 llvm::cast(operand.getType()).getElementType(),
1450 }
1452
1453
1454 if (initFirst) {
1455 payloadOpOperands.push_back(block.getArguments().back());
1456 for (const auto &arg : block.getArguments().drop_back())
1457 payloadOpOperands.push_back(arg);
1458 } else {
1459 payloadOpOperands = {block.getArguments().begin(),
1461 }
1462
1465 payloadOpOperands,
1466 TypeRange{llvm::cast(result.operands.back().getType())
1467 .getElementType()},
1468 payloadOpAttrs);
1470 }
1471
1473 std::optional payloadOpName;
1477 if (failed(operationName))
1478 return failure();
1480 return failure();
1481 payloadOpName = operationName.value();
1483 return failure();
1484 }
1485
1487 return failure();
1488
1489 if (payloadOpName.has_value()) {
1490 if (!result.operands.empty())
1492 payloadOpAttrs,
1494 else
1496 } else {
1499 true, true)) {
1500 return failure();
1501 }
1503 if (parser.parseRegion(*body, regionArgs))
1504 return failure();
1505 }
1506 return success();
1507 }
1508
1509
1510
1511
1512
1515 return nullptr;
1517 assert(isa(body->getOperations().back()));
1518
1521 return nullptr;
1522 if (initFirst) {
1523
1525 return nullptr;
1526
1527 for (const auto &[operand, bbArg] :
1529 if (bbArg != operand)
1530 return nullptr;
1531 }
1532 } else {
1533 for (const auto &[operand, bbArg] :
1535 if (bbArg != operand)
1536 return nullptr;
1537 }
1538 }
1539 return &payload;
1540 }
1541
1544 std::string attrToElide;
1546 for (const auto &attr : payloadOp->getAttrs()) {
1547 auto fastAttr =
1548 llvm::dyn_castmlir::arith::FastMathFlagsAttr(attr.getValue());
1549 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1550 attrToElide = attr.getName().str();
1551 elidedAttrs.push_back(attrToElide);
1552 break;
1553 }
1554 }
1556 p << " }";
1557 }
1558
1560 Block *mapper = getBody();
1562 if (payloadOp) {
1564 }
1565
1568
1569 if (!payloadOp) {
1570
1573 p << "(";
1574 llvm::interleaveComma(mapper->getArguments(), p,
1575 [&](auto arg) { p.printRegionArgument(arg); });
1576 p << ") ";
1577
1578 p.printRegion(getMapper(), false);
1580 }
1581 }
1582
1584 auto *bodyBlock = getBody();
1585 auto blockArgs = bodyBlock->getArguments();
1586
1587
1588 if (getInputs().size() != blockArgs.size())
1589 return emitOpError() << "expects number of operands to match the arity of "
1590 "mapper, but got: "
1591 << getInputs().size() << " and " << blockArgs.size();
1592
1593
1594 for (const auto &[bbArgType, inputArg] :
1595 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1596 auto inputElemType =
1597 llvm::cast(inputArg.getType()).getElementType();
1598 if (bbArgType != inputElemType) {
1599 return emitOpError() << "expected element type of input " << inputElemType
1600 << " to match bbArg type " << bbArgType;
1601 }
1602 }
1603
1604
1605 auto outputShape = getInit().getType().getShape();
1606 for (Type inputArgType : TypeRange{getInputs()}) {
1607 auto inputElemShape = llvm::cast(inputArgType).getShape();
1608 if (inputElemShape != outputShape) {
1609 return emitOpError() << "expected shape of input (" << inputElemShape
1610 << ") to match shape of output (" << outputShape
1611 << ")";
1612 }
1613 }
1614
1615 return success();
1616 }
1617
1619 int64_t rank = getInit().getType().getRank();
1621 }
1622
1623 ArrayAttr MapOp::getIndexingMaps() {
1625 int64_t rank = getInit().getType().getRank();
1626 int64_t numIndexingMaps = getOperands().size();
1629 }
1630
1631 void MapOp::getEffects(
1633 &effects) {
1635 }
1636
1639 }
1640
1641
1642
1643
1644
// Suggests assembly names for linalg.reduce region block arguments: "in"
// for input args, "init" for accumulator (output) args.
// NOTE(review): line 1646 (the setNameFn parameter) was dropped by
// extraction, and `&region` was entity-mangled to `®ion`.
1645 void ReduceOp::getAsmBlockArgumentNames(Region ®ion,
1647 for (Value v : getRegionInputArgs())
1648 setNameFn(v, "in");
1649 for (Value v : getRegionOutputArgs())
1650 setNameFn(v, "init");
1651 }
1652
1653 void ReduceOp::getAsmResultNames(
1655 if (!getResults().empty())
1656 setNameFn(getResults().front(), "reduced");
1657 }
1658
1659 void ReduceOp::build(
1664 build(builder, result, TypeRange{}, inputs, inits, dimensions);
1666
1667
1668 for (Value init : inits) {
1670 if (llvm::isa(initType))
1672 }
1673
1674 if (bodyBuild)
1676 inputs, inits, bodyBuild);
1677 }
1678
1680 int64_t inputRank =
1681 llvm::cast(getInputs()[0].getType()).getRank();
1683 utils::IteratorType::parallel);
1684 for (int64_t reductionDim : getDimensions())
1685 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1686 return iteratorTypes;
1687 }
1688
1689 ArrayAttr ReduceOp::getIndexingMaps() {
1690 int64_t inputRank =
1691 llvm::cast(getInputs()[0].getType()).getRank();
1693 getNumDpsInputs(),
1698 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1699 affineMaps.push_back(resultMap);
1701 }
1702
1703 void ReduceOp::getEffects(
1705 &effects) {
1707 }
1708
1711 }
1712
1715 StringRef attributeName) {
1717 return failure();
1718
1720 return success();
1721 }
1722
1724 std::optional payloadOpName;
1728 if (failed(operationName))
1729 return failure();
1731 return failure();
1732 payloadOpName = operationName.value();
1734 return failure();
1735 }
1736
1740 }))
1741 return failure();
1742
1743 if (payloadOpName.has_value()) {
1746 } else {
1749 true, true)) {
1750 return failure();
1751 }
1752
1754 if (parser.parseRegion(*body, regionArgs))
1755 return failure();
1756 }
1757
1758 return success();
1759 }
1760
1763 p << ' ' << attributeName << " = [" << attributeValue << "] ";
1764 }
1765
1767 Block *mapper = getBody();
1769 if (payloadOp) {
1771 }
1772
1776 if (!payloadOp) {
1777
1780 p << "(";
1781 llvm::interleaveComma(mapper->getArguments(), p,
1782 [&](auto arg) { p.printRegionArgument(arg); });
1783 p << ") ";
1784
1785 p.printRegion(getCombiner(), false);
1787 }
1788 }
1789
1792
1793 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1794 if (llvm::cast(getInputs()[i].getType()).getShape() !=
1795 llvm::cast(getInputs()[0].getType()).getShape()) {
1796 return emitOpError() << "expects all inputs to have the same shapes. "
1797 "Shape at input-index "
1798 << i
1799 << " is not equal to the shape at input-index 0.";
1800 }
1801 }
1802 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1803 if (llvm::cast(getInits()[i].getType()).getShape() !=
1804 llvm::cast(getInits()[0].getType()).getShape()) {
1805 return emitOpError() << "expects all outputs to have the same shapes. "
1806 "Shape at output-index "
1807 << i
1808 << " is not equal to the shape at output-index 0.";
1809 }
1810 }
1811 auto inputType = llvm::cast(getInputs()[0].getType());
1812 auto initType = llvm::cast(getInits()[0].getType());
1813
1815 for (int64_t dimension : dimensionsRef) {
1816 if (dimension < 0 || dimension >= inputType.getRank()) {
1817 return emitOpError()
1818 << "dimensions for reduction should be in the range [0, "
1819 << inputType.getRank() - 1 << "].";
1820 }
1821 dimensionsToReduce.insert(dimension);
1822 }
1823
1824 auto inputDims = inputType.getShape();
1825 auto initDims = initType.getShape();
1826
1827
1830 if (!dimensionsToReduce.count(en.index()))
1831 reducedInputDims.push_back(en.value());
1832 }
1833
1834 if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1835 return emitOpError() << "number of dimensions after reduction "
1836 << reducedInputDims.size()
1837 << " doesn't match the init rank "
1838 << initType.getRank();
1839 }
1840
1841 if (reducedInputDims != initDims)
1842 return emitOpError() << "init dimensions [" << initDims
1843 << "] doesn't match input dimensions after reduction ["
1844 << reducedInputDims << "]";
1845
1846 Block *block = getBody();
1848 return emitOpError()
1849 << "mismatching number of operands and block arguments";
1850
1851
1852 for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1853 Type inputElementType =
1854 llvm::cast(input.getType()).getElementType();
1855 if (inputElementType != bbArg.getType())
1856 return emitOpError()
1857 << "input element type " << inputElementType
1858 << " does not match corresponding block argument type "
1859 << bbArg.getType();
1860 }
1861
1862
1863 for (auto [output, bbArg] : llvm::zip(
1864 getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
1865 auto outputElementType =
1866 llvm::cast(output.getType()).getElementType();
1867 if (outputElementType != bbArg.getType())
1868 return emitOpError()
1869 << "output element type " << outputElementType
1870 << " does not match corresponding block argument type "
1871 << bbArg.getType();
1872 }
1873 return success();
1874 }
1875
1876
1877
1878
1879
1885 if (!args.empty())
1886 b.createlinalg::YieldOp(loc, args[0]);
1887 });
1888 }
1889
1896 result.addAttribute(getPermutationAttrName(result.name), permutation);
1898
1899
1901 if (llvm::isa(initType))
1903
1905 init);
1906 }
1907
1912 build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
1913 attributes);
1914 }
1915
1920 })))
1921 return failure();
1922
1926 {});
1927 return success();
1928 }
1929
1930 void TransposeOp::getAsmResultNames(
1932 if (!getResults().empty())
1933 setNameFn(getResults().front(), "transposed");
1934 }
1935
1940 }
1941
1944
// NOTE(review): extracted listing — the signature (original ~1942) and the
// guard condition on line 1945 (presumably
// `if (!isPermutationVector(permutationRef))` after binding
// `ArrayRef<int64_t> permutationRef = getPermutation();`) are elided from
// this chunk. Code lines below are byte-identical; only comments are added.
// Verifies linalg.transpose: permutation validity, matching ranks, and that
// each result dim equals the permuted input dim.
1946 return emitOpError("permutation is not valid");
1947
1948 auto inputType = getInput().getType();
1949 auto initType = getInit().getType();
1950
1951 int64_t rank = inputType.getRank();
1952
// Input and init must have the same rank.
1953 if (rank != initType.getRank())
1954 return emitOpError() << "input rank " << rank
1955 << " does not match init rank " << initType.getRank();
1956
// The permutation must have one entry per dimension.
1957 if (rank != static_cast<int64_t>(permutationRef.size()))
1958 return emitOpError() << "size of permutation " << permutationRef.size()
1959 << " does not match the argument rank " << rank;
1960
1961 auto inputDims = inputType.getShape();
1962 auto initDims = initType.getShape();
1963
// dim(init, i) must equal dim(input, permutation[i]) for every i.
1964 for (int64_t i = 0; i < rank; ++i) {
1965 int64_t inputDim = inputDims[permutationRef[i]];
1966 int64_t initDim = initDims[i];
1967
1968 if (inputDim != initDim) {
1969 return emitOpError() << "dim(result, " << i << ") = " << initDim
1970 << " doesn't match dim(input, permutation[" << i
1971 << "]) = " << inputDim;
1972 }
1973 }
1974
1975 return success();
1976 }
1977
1979 int64_t rank = getInit().getType().getRank();
1981 }
1982
1983 ArrayAttr TransposeOp::getIndexingMaps() {
1985 int64_t rank = getInit().getType().getRank();
1988 llvm::to_vector_of(getPermutation()), getContext())),
1990 }
1991
1992 void TransposeOp::getEffects(
1994 &effects) {
1996 }
1997
2000 }
2001
2002 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2004
2005 if (!isa(getInput().getType()))
2006 return failure();
2007
2008
2009 if (getPermutation().size() == 0) {
2010 result.push_back(getInput());
2011 return success();
2012 }
2013
2015 result.push_back(getInput());
2016 return success();
2017 }
2018
2019 return failure();
2020 }
2021
2022
2025
2028 auto defTransposeOp = transposeOp.getInput().getDefiningOp();
2029 if (!defTransposeOp)
2030 return failure();
2034 foldedPerms.reserve(perms.size());
2035 for (int64_t perm : perms)
2036 foldedPerms.push_back(defPerms[perm]);
2037
2039 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2040 foldedPerms);
2041 return success();
2042 }
2043 };
2044
2045
2046
2047
2050
2053 Value input = transposeOp.getInput();
2054 BroadcastOp broadcastOp = input.getDefiningOp();
2055 if (!input.hasOneUse() || !broadcastOp)
2056 return failure();
2057
2060
2061
2065 unsigned dimensionSize = dimensions.size();
2066 for (unsigned i = 0; i < dimensionSize; ++i)
2067 resultDimensions.push_back(invertPerm[dimensions[i]]);
2068
2069
2070 Value broadcastInput = broadcastOp.getInput();
2071 Location loc = transposeOp.getLoc();
2072 MLIRContext *ctx = transposeOp.getContext();
2074 auto broadcastInputTy =
2075 mlir::cast(broadcastInput.getType());
2076 unsigned inputRank = broadcastInputTy.getRank();
2077 for (unsigned i = 0; i < inputRank; ++i) {
2078 if (broadcastInputTy.isDynamicDim(i)) {
2079 dims.push_back(rewriter.createtensor::DimOp(loc, broadcastInput, i)
2081 } else {
2083 broadcastInputTy.getDimSize(i)));
2084 }
2085 }
2088 Value transposeInit = rewriter.createtensor::EmptyOp(
2089 transposeOp.getLoc(), transposeResultShapes,
2090 broadcastInputTy.getElementType());
2091
2092
2093 Value transposeResult =
2094 rewriter
2095 .create(loc, broadcastOp.getInput(), transposeInit,
2096 resultPerms)
2097 ->getResult(0);
2099 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2100 return success();
2101 }
2102 };
2103
2104 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2107 }
2108
2109
2110
2111
2112
2119 result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2121
2122
2124 if (llvm::isa(initType))
2126
2128 init);
2129 }
2130
2135 build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2136 attributes);
2137 }
2138
2143 })))
2144 return failure();
2145
2149 {});
2150 return success();
2151 }
2152
2153 void BroadcastOp::getAsmResultNames(
2155 if (!getResults().empty())
2156 setNameFn(getResults().front(), "broadcasted");
2157 }
2158
2163 }
2164
2167
// NOTE(review): extracted listing — the signature (original ~2165–2166,
// presumably `LogicalResult BroadcastOp::verify() {` plus the binding of
// `dimensionsRef` to getDimensions()) and line 2192 (presumably the
// declaration `SmallVector<int64_t> dimMap;` used below) are elided from
// this chunk. Code lines are byte-identical; only comments are added.
// Verifies linalg.broadcast: rank arithmetic, dimension-index bounds, and
// that non-broadcast dims of input and init agree.
2168 auto inputType = getInput().getType();
2169 auto initType = getInit().getType();
2170
2171 int64_t inputRank = inputType.getRank();
2172 int64_t initRank = initType.getRank();
2173
2174 auto inputShape = inputType.getShape();
2175 auto initShape = initType.getShape();
2176
// Every added dimension increases the rank by one: inputRank + #dims == initRank.
2177 if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2178 return emitOpError() << "input rank plus added dimensions does not "
2179 "match init rank. input rank: "
2180 << inputRank
2181 << ", dimensions size: " << dimensionsRef.size()
2182 << ", init rank: " << initRank;
2183
// Each broadcast dimension index must be a valid init dimension.
2184 for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2185 if (dim < 0 || dim >= initRank)
2186 return emitOpError() << "dimension " << idx
2187 << " is out of range. expected range: [0, "
2188 << initRank - 1 << "], got: " << dim;
2189 }
2190
2191
// Build the mapping from input dims to init dims: init dims that are NOT in
// `dimensionsRef` correspond 1:1, in order, to the input dims.
2193 for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2194 if (!llvm::is_contained(dimensionsRef, dim))
2195 dimMap.push_back(dim);
2196 }
2197
// Non-broadcast dimensions must have identical static sizes.
2198 for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2199
2200
2201 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2202 return emitOpError() << "input dim " << inputDimIdx
2203 << " should match init dim " << initDimIdx
2204 << ". input: " << inputShape[inputDimIdx]
2205 << ", init: " << initShape[initDimIdx];
2206 }
2207
2208 return success();
2209 }
2210
2212 int64_t rank = getInit().getType().getRank();
2214 }
2215
2216 ArrayAttr BroadcastOp::getIndexingMaps() {
2218 int64_t rank = getInit().getType().getRank();
2222 }
2223
2224 void BroadcastOp::getEffects(
2226 &effects) {
2228 }
2229
2232 }
2233
2234 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2236 results.add<EraseIdentityLinalgOp>(context);
2237 }
2238
2239
2240
2241
2242
2244 if (getNumOperands() > 0)
2245 p << ' ' << getOperands();
2247 if (getNumOperands() > 0)
2248 p << " : " << getOperandTypes();
2249 }
2250
2259 }
2260
2261
2262
// Verifies a linalg.yield against its enclosing LinalgOp: one yielded value
// per init/out operand, each matching that operand's element type.
// NOTE(review): extracted listing — original lines 2271, 2273 and 2275 are
// elided here (presumably the lookup of the matching DPS init operand and
// the computation of `elementType`, unwrapping shaped types to their element
// type). Code lines below are byte-identical; only comments are added.
2263 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2264 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2265 return op.emitOpError("expected number of yield values (")
2266 << op.getNumOperands()
2267 << ") to match the number of inits / outs operands of the enclosing "
2268 << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2269
2270 for (OpOperand &opOperand : op->getOpOperands()) {
2272 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
// Shaped init operands contribute their element type; the yielded scalar
// must match it exactly.
2274 if (isa<MemRefType, RankedTensorType>(elementType))
2276 if (opOperand.get().getType() != elementType)
2277 return op.emitOpError("type of yield operand ")
2278 << (opOperand.getOperandNumber() + 1) << " ("
2279 << opOperand.get().getType() << ") doesn't match "
2280 << "the element type of the enclosing linalg.generic op ("
2281 << elementType << ")";
2282 }
2283 return success();
2284 }
2285
// NOTE(review): extracted listing — the signature (original ~2286,
// presumably `LogicalResult linalg::YieldOp::verify() {`) and line 2292
// (presumably `return verifyYield(*this, linalgOp);`) are elided; the
// template argument of dyn_cast<> on line 2291 (likely LinalgOp) was
// stripped by HTML extraction. Code lines are byte-identical.
// A yield must be the terminator of the single region of a LinalgOp.
2287 auto *parentOp = (*this)->getParentOp();
2288 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2289 return emitOpError("expected single non-empty parent region");
2290
// Delegate the operand/element-type checks to verifyYield (call elided).
2291 if (auto linalgOp = dyn_cast(parentOp))
2293
2294 return emitOpError("expected parent op with LinalgOp interface");
2295 }
2296
2297
2298
2299
2300
// NOTE(review): extracted listing — the signature (original ~2301,
// presumably `LogicalResult IndexOp::verify() {`) is elided and the
// dyn_cast<> template argument (likely LinalgOp) was stripped by HTML
// extraction. Code lines are byte-identical; only comments are added.
// linalg.index is only valid inside a LinalgOp and must reference a loop
// dimension strictly below the op's loop count.
2302 auto linalgOp = dyn_cast((*this)->getParentOp());
2303 if (!linalgOp)
2304 return emitOpError("expected parent op with LinalgOp interface");
2305 if (linalgOp.getNumLoops() <= getDim())
2306 return emitOpError("expected dim (")
2307 << getDim() << ") to be lower than the number of loops ("
2308 << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2309 return success();
2310 }
2311
// Folds linalg.index when the enclosing loop dimension is statically known
// to have extent 1 (so the index can only be 0).
// NOTE(review): extracted listing — original lines 2318, 2321, 2325 and
// 2327 are elided (presumably: the early `return` failure when there is no
// parent LinalgOp, the computation of `loopBounds` from the parent's static
// loop ranges, the constant-zero fold result, and the fallthrough failure
// return). The dyn_cast_or_null<> template argument (likely LinalgOp) was
// stripped by HTML extraction. Code lines are byte-identical.
2312 OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2313 auto linalgOp = dyn_cast_or_null((*this)->getParentOp());
2314
2315
2316
2317 if (!linalgOp)
2319
2320
2322 uint64_t dim = getDim();
// The verifier guarantees dim < numLoops; assert rather than re-verify.
2323 assert(dim < loopBounds.size() && "Dim is out of bounds");
2324 if (loopBounds[dim] == 1)
2326
2328 }
2329
2330
2331
2332 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2333
2334 #define GET_OP_CLASSES
2335 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2336
2337 #define GET_OP_CLASSES
2338 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2339 #define GET_OP_CLASSES
2340 #include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2341
2343 unsigned rank,
2345 if (maybeMap)
2346 return *maybeMap;
2347 if (rank == 0)
2350 }
2351
2356 res.reserve(num);
2357 for (unsigned i = 0; i < num; ++i)
2359 return res;
2360 }
2361
2364 auto rangeA = llvm::make_range(a.begin(), a.end());
2365 auto rangeB = llvm::make_range(b.begin(), b.end());
2366 auto concatRanges = llvm::concat(rangeA, rangeB);
2367 return llvm::to_vector<4>(concatRanges);
2368 }
2369
2371 if (auto memref = llvm::dyn_cast(t)) {
2372 ss << "view";
2373 for (auto size : memref.getShape())
2374 if (size < 0)
2375 ss << "sx";
2376 else
2377 ss << size << "x";
2379 return failure();
2380 if (auto as = memref.getMemorySpace()) {
2381 if (auto attr = llvm::dyn_cast(as))
2382 ss << "as" << attr.getInt();
2383 else
2384 return failure();
2385 }
2386 return success();
2387 }
2388 if (auto vec = llvm::dyn_cast(t)) {
2389 ss << "vector";
2390 llvm::interleave(
2391 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2393 return failure();
2394 return success();
2395 }
2397 ss << t;
2398 return success();
2399 }
2400 return failure();
2401 }
2402
2404 assert(isa(op));
2406 std::string fun = "";
2408 if (UnaryFnAttr ufa = llvm::dyn_cast(kv.getValue())) {
2409 fun = stringifyEnum(ufa.getValue()).str() + "_";
2410 } else if (BinaryFnAttr bfa = llvm::dyn_cast(kv.getValue())) {
2411 fun = stringifyEnum(bfa.getValue()).str() + "_";
2412 }
2413 }
2414 name.reserve(128);
2415 llvm::replace(name, '.', '_');
2416 llvm::raw_string_ostream ss(name);
2417 ss << "_" << fun;
2420 return std::string();
2421 ss << "_";
2422 }
2423 name.pop_back();
2424 return name;
2425 }
2426
2427
2428
2429
2430
2431 namespace {
2434
2435 LogicalResult matchAndRewrite(LinalgOp op,
2437 for (OpOperand &opOperand : op->getOpOperands()) {
2438
2439
2440
2441 auto mt = llvm::dyn_cast(opOperand.get().getType());
2442 if (!mt)
2443 continue;
2444 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2446 return success();
2447 }
2448 }
2449 return failure();
2450 }
2451 };
2452
2453
2454
2455 struct FoldTensorCastConsumerOp : public OpRewritePatterntensor::CastOp {
2457
2458 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2461 return failure();
2462
2463 auto linalgOp = castOp.getSource().getDefiningOp();
2464 if (!linalgOp)
2465 return failure();
2466
2467
2468
2469
2470 if (castOp->getBlock() != linalgOp->getBlock())
2471 return failure();
2472
2475
2476 Location loc = linalgOp.getLoc();
2477 OpResult resultValue = llvm::cast(castOp.getSource());
2479 auto resultType =
2480 llvm::cast(castOp->getResult(0).getType());
2481
2482
2483
2484
2485
2486 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2487 Value newOperand =
2488 rewriter.createtensor::CastOp(loc, resultType, outOperand->get());
2491 linalgOp.getDpsInits().end());
2492 outputOperands[resultNumber] = newOperand;
2493 newOperands.append(outputOperands.begin(), outputOperands.end());
2494
2496 linalgOp->result_type_end());
2497 resultTypes[resultNumber] = resultType;
2498 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2499
2500
2501 Value castBack = rewriter.createtensor::CastOp(
2502 loc, resultValue.getType(), newOp->getResult(resultNumber));
2503
2505 results[resultNumber] = castBack;
2506 rewriter.replaceOp(linalgOp, results);
2508 return success();
2509 }
2510 };
2511
2512
2513
2516 for (OpOperand &opOperand : operands) {
2517 if (linalgOp.isScalar(&opOperand))
2518 continue;
2519 Value src = opOperand.get();
2520 auto sourceType = llvm::cast(src.getType());
2521 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2522
2523
2524
2525
2528 if (parentOp) {
2529 if (auto castOp = dyn_casttensor::CastOp(parentOp)) {
2530 Value castSource = castOp.getSource();
2531 auto castSourceType =
2532 llvm::dyn_cast(castSource.getType());
2533 if (castSourceType && castSourceType.hasStaticShape())
2534 sourceShape = castSourceType.getShape();
2535 }
2536 }
2537
2538
2539
2540 for (unsigned i = 0; i < sourceShape.size(); i++) {
2541 if (sourceType.isDynamicDim(i))
2542 continue;
2543 if (auto affineDimExpr = dyn_cast(sourceMap.getResult(i)))
2544 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2545 }
2546 }
2547 }
2548
2549
2550
2551
2552
2553
2554 static void createNewOperandWithStaticSizes(
2558 bool &changeNeeded) {
2559 Value src = opOperand->get();
2560 newOperands.push_back(src);
2561 if (linalgOp.isScalar(opOperand))
2562 return;
2563 auto sourceType = llvm::cast(src.getType());
2564 Type resultType = sourceType;
2565 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2566 resultTypes.push_back(resultType);
2567 return;
2568 }
2570 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2572
2573
2574 bool newOperandNeeded = false;
2575 for (unsigned i = 0; i < sourceShape.size(); i++) {
2576 int64_t dimShape = sourceShape[i];
2578 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2579 newShape.push_back(dimShape);
2580 continue;
2581 }
2582
2583
2584
2585 newShape.push_back(affineExprToSize[dimExpr]);
2586 newOperandNeeded = true;
2587 }
2589 sourceType.getEncoding());
2590 if (newOperandNeeded) {
2591 changeNeeded = true;
2592
2593
2594 Value newOperand = rewriter.createtensor::CastOp(loc, resultType, src);
2596 newOperands[index] = newOperand;
2597 }
2598 if (linalgOp.isDpsInit(opOperand))
2599 resultTypes.push_back(resultType);
2600 }
2601
2602
2603
2604
2607
2608 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2610 if (!linalgOp.hasPureTensorSemantics())
2611 return failure();
2612
2613
2614 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2615 return !map.isProjectedPermutation();
2616 }))
2617 return failure();
2618
2619
2621 Location loc = linalgOp.getLoc();
2622
2623
2624
2625 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2626
2629
2630
2631
2632 bool changeNeeded = false;
2633 newOperands.reserve(linalgOp->getNumOperands());
2634 resultTypes.reserve(linalgOp.getNumDpsInits());
2635
2636
2637 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2638 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2639 affineExprToSize, linalgOp, newOperands,
2640 resultTypes, changeNeeded);
2641 }
2642
2643
2644
2645 if (!changeNeeded)
2646 return failure();
2647
2648
2649 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2652 for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2653 Value newResult = std::get<1>(it);
2654 Value oldResult = std::get<0>(it);
2657 replacements.push_back(
2658 (newType != oldType)
2659 ? rewriter.createtensor::CastOp(loc, oldType, newResult)
2660 : newResult);
2661 }
2662 rewriter.replaceOp(linalgOp, replacements);
2663 return success();
2664 }
2665 };
2666
2667 }
2668
2669
2670
2671
2672
2673
2674
2675
2677 ShapedType inputType = getInputOperandType();
2678 ShapedType outputType = getOutputOperandType();
2679
2683 return emitOpError("incompatible output shape");
2684
2685 int64_t inputRank = getInputOperandRank();
2686 int64_t dimension = getDimension();
2687 if ((dimension < 0) || (dimension >= inputRank))
2688 return emitOpError("incorrect dimension specified");
2689
2690 return success();
2691 }
2692
2694 int64_t operandRank = getInputOperandRank();
2697 Value zero = builder.createarith::ConstantIndexOp(loc, 0);
2698 Value one = builder.createarith::ConstantIndexOp(loc, 1);
2699 Value source = getInput();
2700 for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2701 loopBounds[dim].offset = zero;
2702 loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2703 loopBounds[dim].stride = one;
2704 }
2705 return loopBounds;
2706 }
2707
2710 utils::IteratorType::parallel);
2711 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2712 return iteratorTypes;
2713 }
2714
2715 FailureOr
2719 int64_t rank = getInputOperandRank();
2724 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2725 if (!inputSlice) {
2726 return emitOpError("failed to compute input slice");
2727 }
2728 tiledOperands.emplace_back(inputSlice->getResult(0));
2730 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2731 if (!outputSlice) {
2732 return emitOpError("failed to compute output slice");
2733 }
2734 tiledOperands.emplace_back(outputSlice->getResult(0));
2735
2737 if (hasPureTensorSemantics())
2738 resultTypes.push_back(tiledOperands[1].getType());
2740 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2741
2743 {tiledOp},
2746 }
2747
2752 if (resultNumber == 0) {
2753 resultOffsets.assign(offsets.begin(), offsets.end());
2754 resultSizes.assign(sizes.begin(), sizes.end());
2755 return success();
2756 }
2757 return failure();
2758 }
2759
2760
2763 }
2764
2765 LogicalResult
2769 Location loc = getOperation()->getLoc();
2771 auto inputShapedType = llvm::cast(getInputOperandType());
2772 auto outputShapedType = llvm::cast(getOutputOperandType());
2773 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2774 if (!outputShapedType.isDynamicDim(dim)) {
2775
2776 shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2777 } else {
2778
2781 }
2782 }
2783 reifiedReturnShapes.emplace_back(std::move(shapes));
2784 return success();
2785 }
2786
2787 void SoftmaxOp::getEffects(
2789 &effects) {
2790 for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2791 if (!llvm::isa(operand.getType()))
2792 continue;
2794 &getOperation()->getOpOperand(index), 0,
2795 true,
2797 }
2798
2799 for (OpOperand &operand : getDpsInitsMutable()) {
2800 if (!llvm::isa(operand.get().getType()))
2801 continue;
2803 true,
2806 true,
2808 }
2809 }
2810
2811
2812
2813
2814
2815
2816
2817
2818
2819
2820
2821
2822
2823
2824
2825
2826
2827
2828
2829
2830
2833 int64_t dim, bool allParallel = false) {
2835 utils::IteratorType::parallel);
2836 if (!allParallel)
2837 iteratorTypes[dim] = utils::IteratorType::reduction;
2841 for (int i = 0; i < inputRank; i++) {
2842 if (i != dim)
2844 }
2845 auto reductionMap =
2846 AffineMap::get(inputRank, 0, affineExprs, ctxt);
2848 return std::make_tuple(iteratorTypes, indexingMaps);
2849 }
2850
2851
2852
2853 template
2855 int64_t dim) {
2856 auto inputType = cast(input.getType());
2858 int64_t inputRank = inputShape.size();
2859 auto [iteratorTypes, indexingMaps] =
2861 assert(indexingMaps.size() == 2 &&
2862 "We should have two maps: 1 for the input, 1 for the output");
2863 assert(indexingMaps[0].isIdentity() && "input map should be identity");
2864
2865 auto genericOp = builder.createlinalg::GenericOp(
2866 loc, output.getType(), input, output, indexingMaps, iteratorTypes,
2868 Value result = b.create(loc, args[0], args[1]);
2869 b.createlinalg::YieldOp(loc, result);
2870 });
2872 }
2873
2874
2875
2876
2879 auto inputType = cast(input.getType());
2881 int64_t inputRank = inputShape.size();
2883 builder, inputRank, dim, true);
2884 assert(indexingMaps.size() == 2 && "We should have one map for each input");
2885 assert(indexingMaps[0].isIdentity() && "input map should be identity");
2886
2887 indexingMaps.push_back(indexingMaps[0]);
2888 auto genericOp = builder.createlinalg::GenericOp(
2891 Value diff = b.createarith::SubFOp(loc, args[0], args[1]);
2892 Value result = b.createmath::ExpOp(loc, diff);
2893 b.createlinalg::YieldOp(loc, result);
2894 });
2896 }
2897
2898
2899
2900
2901
2902
2904 Value denominator, Value output, int64_t dim) {
2905 auto inputType = cast(numerator.getType());
2907 int64_t inputRank = inputShape.size();
2909 builder, inputRank, dim, true);
2910 assert(indexingMaps.size() == 2 &&
2911 "We should have one map for each input (2)");
2912 assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
2913
2914 indexingMaps.push_back(indexingMaps[0]);
2915 auto genericOp = builder.createlinalg::GenericOp(
2916 loc, numerator.getType(), ValueRange{numerator, denominator}, output,
2917 indexingMaps, iteratorTypes,
2919 Value result = b.createarith::DivFOp(loc, args[0], args[1]);
2920 b.createlinalg::YieldOp(loc, result);
2921 });
2923 }
2924
2925
2926
2927
2928
2929
2930
2931
2932
2933
2934
2935
2936
2937
2938
2939
2940
2941
2942
2943
2944 FailureOr<SmallVector> SoftmaxOp::decomposeOperation(OpBuilder &b) {
2948 Value input = getInput();
2949 ShapedType inputType = getInputOperandType();
2950 Type elementType = inputType.getElementType();
2951 int64_t reductionDim = getDimension();
2953 Value output = getOutput();
2954 dims.erase(dims.begin() + reductionDim);
2955
2956 Value outputReduce = b.createtensor::EmptyOp(loc, dims, elementType);
2958 elementType, b, loc,
2959 true);
2960 Value neutralForMaxFInit =
2961 b.createlinalg::FillOp(loc, Value{neutralForMaxF}, outputReduce)
2962 .result();
2964 reducearith::MaxNumFOp(b, loc, input, neutralForMaxFInit, reductionDim);
2965
2966
2968
2969
2971 b, loc, true);
2972 Value zeroInit =
2973 b.createlinalg::FillOp(loc, Value{zero}, outputReduce).result();
2974 Value denominator =
2975 reducearith::AddFOp(b, loc, numerator, zeroInit, reductionDim);
2976
2977
2979 buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2981 }
2982
2983
2984
2985
2986
2988 auto filterType = cast(getFilter().getType());
2990 int64_t filterH = filterShape[getFilterHDim()];
2991 int64_t filterW = filterShape[getFilterWDim()];
2992 int64_t r = getR();
2993 int64_t m = getM();
2994
2995 if (filterH != r && filterH != 1)
2996 return emitOpError("expect filter height either equals to r or 1");
2997 if (filterW != r && filterW != 1)
2998 return emitOpError("expect filter width either equals to r or 1");
2999 if (filterH == 1 && filterW == 1)
3000 return emitOpError("expect either filter height or width equals to r");
3001
3003 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3004 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3005 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3006 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3007
3008 auto outputType = cast(getOutput().getType());
3011 return emitOpError("the output shape is not expected");
3012 }
3013 return success();
3014 }
3015
3017 WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3019 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3020 IntegerAttr oneAttr = builder.getIndexAttr(1);
3021 Value filter = getFilter();
3022 int64_t filterRank = getFilterOperandRank();
3024 for (unsigned dim = 0; dim < filterRank; ++dim) {
3025 loopBounds[dim].offset = zeroAttr;
3026 loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
3027 loopBounds[dim].stride = oneAttr;
3028 }
3029 return loopBounds;
3030 }
3031
3033 WinogradFilterTransformOp::getLoopIteratorTypes() {
3034 int64_t filterRank = getFilterOperandRank();
3036 utils::IteratorType::parallel);
3037 return iteratorTypes;
3038 }
3039
3045 ShapedType filterType = getFilterOperandType();
3047 int64_t filterH = filterShape[getFilterHDim()];
3048 int64_t filterW = filterShape[getFilterWDim()];
3049 int64_t m = getM();
3050 int64_t r = getR();
3051 int64_t alpha = m + r - 1;
3052 int64_t alphaH = filterH != 1 ? alpha : 1;
3053 int64_t alphaW = filterW != 1 ? alpha : 1;
3056
3057 resultOffsets.append(
3058 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3059 resultSizes.append(
3060 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3061
3062 return success();
3063 }
3064
3065
3066
3067
3068
3069
3070
3076 ShapedType filterType = getFilterOperandType();
3078 int64_t filterH = filterShape[getFilterHDim()];
3079 int64_t filterW = filterShape[getFilterWDim()];
3084
3085 sliceOffsets.append(
3086 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3087 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3088 sizes[getFilterCDim()]});
3089 int64_t filterRank = getFilterOperandRank();
3092 auto filterSlice = builder.createtensor::ExtractSliceOp(
3093 loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3094 tiledOperands.emplace_back(filterSlice);
3095
3098 resultSizes)))
3099 return failure();
3100
3101 int64_t outputRank = getOutputOperandRank();
3103 auto outputSlice = builder.createtensor::ExtractSliceOp(
3104 loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3105 tiledOperands.emplace_back(outputSlice);
3106
3108 resultTypes.push_back(tiledOperands[1].getType());
3110 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3111
3113 {tiledOp},
3116 }
3117
3118
3119
3120
3121
3123 auto inputType = cast(getInput().getType());
3125 int64_t inputH = inputShape[getInputHDim()];
3126 int64_t inputW = inputShape[getInputWDim()];
3127 int m = getM();
3128 int r = getR();
3129 int64_t tileSize = m + r - 1;
3130
3131 auto outputType = cast(getOutput().getType());
3133 bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3134 bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3135
3137 if (ShapedType::isDynamic(inputH)) {
3138 expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3139 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3140 } else {
3141 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3142 expectedOutputShape[getOutputTileHDim()] =
3143 leftTransform ? (inputH - (r - 1)) / m : inputH;
3144 }
3145 if (ShapedType::isDynamic(inputW)) {
3146 expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3147 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3148 } else {
3149 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3150 expectedOutputShape[getOutputTileWDim()] =
3151 rightTransform ? (inputW - (r - 1)) / m : inputW;
3152 }
3153 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3154 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3155
3157 return emitOpError("the output shape is not expected");
3158 }
3159 return success();
3160 }
3161
3163 WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3165 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3166 IntegerAttr oneAttr = builder.getIndexAttr(1);
3167 Value output = getOutput();
3168 int64_t outputRank = getOutputOperandRank();
3170 for (unsigned dim = 0; dim < outputRank; ++dim) {
3171 loopBounds[dim].offset = zeroAttr;
3172
3173 loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3174 loopBounds[dim].stride = oneAttr;
3175 }
3176 return loopBounds;
3177 }
3178
3180 WinogradInputTransformOp::getLoopIteratorTypes() {
3181 int64_t outputRank = getOutputOperandRank();
3183 utils::IteratorType::parallel);
3184 return iteratorTypes;
3185 }
3186
3192 ShapedType outputType = getOutputOperandType();
3194 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3195 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3196
3197 int64_t m = getM();
3198 int64_t r = getR();
3199 int64_t alpha = m + r - 1;
3200 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3201 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3202
3205
3206 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3207 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3208 offsets[getOutputCDim()]});
3209 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3210 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3211 sizes[getOutputCDim()]});
3212
3213 return success();
3214 }
3215
3216
3217
3218
3219
3220
3221
3222 FailureOr
3227 int64_t m = getM();
3228 int64_t r = getR();
3229
3230 ShapedType outputType = getOutputOperandType();
3232 int64_t alphaH = outputShape[getOutputAlphaHDim()];
3233 int64_t alphaW = outputShape[getOutputAlphaWDim()];
3234
3237 auto identityAffineMap =
3239 auto offsetAffineMap =
3242 builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3243 offsets[getOutputTileHDim()]);
3245 builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3246 offsets[getOutputTileWDim()]);
3248 1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
3250 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3252 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3253
3256
3259 sliceOffsets.append(
3260 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3265 sliceSizes.append(
3266 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3267 int64_t inputRank = getInputOperandRank();
3269 auto inputSlice = builder.createtensor::ExtractSliceOp(
3270 loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3271 tiledOperands.emplace_back(inputSlice);
3272
3275 resultSizes)))
3276 return failure();
3277
3278 int64_t outputRank = getOutputOperandRank();
3280 auto outputSlice = builder.createtensor::ExtractSliceOp(
3281 loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3282 tiledOperands.emplace_back(outputSlice);
3283
3285 resultTypes.push_back(tiledOperands[1].getType());
3287 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3288
3290 {tiledOp},
3293 }
3294
3295
3296
3297
3298
3300 auto valueType = cast(getValue().getType());
3302 int64_t valueH = valueShape[getValueAlphaHDim()];
3303 int64_t valueW = valueShape[getValueAlphaWDim()];
3304 int64_t valueTileH = valueShape[getValueTileHDim()];
3305 int64_t valueTileW = valueShape[getValueTileWDim()];
3306 int m = getM();
3307 int r = getR();
3308 bool leftTransform = valueH != 1;
3309 bool rightTransform = valueW != 1;
3310
3311 int64_t outputRank = getOutputOperandRank();
3313 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3314 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3315 } else {
3316 if (valueH != (leftTransform ? m + r - 1 : 1))
3317 return emitOpError("expect input height equals to input tile size");
3318 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3319 }
3320 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3321 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3322 } else {
3323 if (valueW != (rightTransform ? m + r - 1 : 1))
3324 return emitOpError("expect input width equals to input tile size");
3325 expectedOutputShape[getOutputWDim()] =
3326 (rightTransform ? m : 1) * valueTileW;
3327 }
3328 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3329 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3330
3331 auto outputType = cast(getOutput().getType());
3334 return emitOpError("the output shape is not expected");
3335 }
3336 return success();
3337 }
3338
3340 WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3342 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3343 IntegerAttr oneAttr = builder.getIndexAttr(1);
3344 Value value = getValue();
3345 int64_t valueRank = getValueOperandRank();
3347 for (unsigned dim = 0; dim < valueRank; ++dim) {
3348 loopBounds[dim].offset = zeroAttr;
3349
3350 loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3351 loopBounds[dim].stride = oneAttr;
3352 }
3353 return loopBounds;
3354 }
3355
3357 WinogradOutputTransformOp::getLoopIteratorTypes() {
3358 int64_t valueRank = getValueOperandRank();
3360 utils::IteratorType::parallel);
3361 return iteratorTypes;
3362 }
3363
3368 int64_t m = getM();
3369
3372 auto identityAffineMap =
3374 auto affineMap =
3376
3377 ShapedType valueType = getValueOperandType();
3379 int64_t valueH = valueShape[0];
3380 int64_t valueW = valueShape[1];
3382 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3383 offsets[getValueTileHDim()]);
3385 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3386 offsets[getValueTileWDim()]);
3388 builder, loc, affineMap, sizes[getValueTileHDim()]);
3390 builder, loc, affineMap, sizes[getValueTileWDim()]);
3391
3399
3400 resultOffsets.append(
3401 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3402 resultSizes.append(
3403 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3404 return success();
3405 }
3406
3407
3408
3409
3410
3411
3412
3421
3422 ShapedType valueType = getValueOperandType();
3424 int64_t alphaH = valueShape[getValueAlphaHDim()];
3425 int64_t alphaW = valueShape[getValueAlphaWDim()];
3428
3429 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3430 offsets[getValueTileWDim()], offsets[getValueNDim()],
3431 offsets[getValueFDim()]});
3432 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3433 sizes[getValueTileWDim()], sizes[getValueNDim()],
3434 sizes[getValueFDim()]});
3435 int64_t valueRank = getValueOperandRank();
3437 auto valueSlice = builder.createtensor::ExtractSliceOp(
3438 loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3439 tiledOperands.emplace_back(valueSlice);
3440
3443 resultSizes)))
3444 return failure();
3445
3446 int64_t outputRank = getOutputOperandRank();
3448 auto outputSlice = builder.createtensor::ExtractSliceOp(
3449 loc, getOutput(), resultOffsets, resultSizes, strides);
3450 tiledOperands.emplace_back(outputSlice);
3451
3453 resultTypes.push_back(tiledOperands[1].getType());
3455 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3456
3458 {tiledOp},
3461 }
3462
3463
3464
3465
3466
3467
3468
3470 auto explicitRange = subMap.getResults();
3471 auto defaultRange = fullMap.getResults();
3472 DenseSet explicitSet(explicitRange.begin(), explicitRange.end());
3474 llvm::set_union(explicitSet, defaultSet);
3475 return explicitSet == defaultSet;
3476 }
3477
3478
3479
3480
3481
3482
3483
3486 }
3487
3488
3489
3490
3492 unsigned opIndex) {
3495 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3496
3497 auto opIndexingMap = opIndexingMaps[opIndex];
3498 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3499
3501 return matmulOp->emitOpError()
3502 << "Unexpected dim expression in map result.";
3503
3504 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3505 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3506 return matmulOp->emitOpError()
3507 << "Invalid broadcast requested, should be (d2).";
3508 }
3509 return success();
3510 }
3511 return success();
3512 }
3513
3514
3515
3516 template
3519 AffineMap defaultIndexingMap, bool isLHS) {
3520 assert((isa(batchVariantMatmulOp) ||
3521 isa(batchVariantMatmulOp)) &&
3522 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3523
3525 return batchVariantMatmulOp->emitOpError()
3526 << "Unexpected result dim expression (outside the set of default "
3527 "result dims).";
3528
3529
3531 return batchVariantMatmulOp->emitOpError()
3532 << "no. of result dim expressions exceeds 3.";
3533
3534 auto hasValidBatchDim = [](AffineMap map) {
3537 };
3538
3539
3540 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3541 if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3542 return batchVariantMatmulOp->emitOpError()
3543 << "Invalid broadcast requested.";
3544 } else if (!hasValidBatchDim(opIndexingMap)) {
3545 return batchVariantMatmulOp->emitOpError()
3546 << "Invalid batch dimension expression.";
3547 }
3548 return success();
3549 }
3550
3551
3552
3553
3554 template
3557 assert((isa(batchVariantMatmulOp) ||
3558 isa(batchVariantMatmulOp)) &&
3559 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3560 if (isa(batchVariantMatmulOp) &&
3562
3563 return batchVariantMatmulOp->emitOpError()
3564 << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
3565 << ").";
3566 }
3567 if (isa(batchVariantMatmulOp) &&
3569 return batchVariantMatmulOp->emitOpError()
3570 << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
3571 << ").";
3572 }
3573
3574 auto areValidOutputResultDim = [&](AffineMap outputMap) {
3575 return isa(batchVariantMatmulOp)
3576 ? outputMap.getResult(0).isFunctionOfDim(0) &&
3577 outputMap.getResult(1).isFunctionOfDim(1) &&
3578 outputMap.getResult(2).isFunctionOfDim(2)
3579 : outputMap.getResult(0).isFunctionOfDim(1) &&
3580 outputMap.getResult(1).isFunctionOfDim(2);
3581 };
3582
3583 if (!areValidOutputResultDim(opIndexingMap)) {
3584 return batchVariantMatmulOp->emitOpError()
3585 << "Invalid output map result dimension.";
3586 }
3587
3588 return success();
3589 }
3590
3591
3592
3593
3594 template
3595 static LogicalResult
3597 unsigned opIndex) {
3599 batchVariantMatmulOp.getIndexingMapsArray();
3601 batchVariantMatmulOp.getDefaultIndexingMaps(
3602 batchVariantMatmulOp->getContext());
3603
3604 if (opIndexingMaps.size() != 3)
3605 return batchVariantMatmulOp->emitOpError()
3606 << "Indexing_map attribute must have 3 affine maps.";
3607
3608 auto opIndexingMap = opIndexingMaps[opIndex];
3609 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3610
3611 if (opIndex == 2 &&
3612 failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
3613 return failure();
3614
3615 if (opIndex != 2 &&
3616 failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
3617 defaultIndexingMap, opIndex == 0)))
3618 return failure();
3619
3620 return success();
3621 }
3622
3623 namespace mlir {
3624 namespace linalg {
3625
3626
3627
3628
3629
3630
3634 bindDims(context, d0, d1, d2);
3635 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
3636 indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
3637 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
3638 return indexingMaps;
3639 }
3640
3643 utils::IteratorType::parallel,
3644 utils::IteratorType::reduction};
3645 }
3646
3647 unsigned MatmulOp::getNumRegionArgs() { return 3; }
3648
3649 std::string MatmulOp::getLibraryCallName() {
3651 }
3652
3653 bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3654
3655
3656
3657 bool MatmulOp::hasUserDefinedMaps() {
3659 getDefaultIndexingMaps(this->getContext());
3661 return defaultMaps != explicitMaps;
3662 }
3663
3664
3665
3669 "MatmulOp regionBuilder expects 3 (>=0) args");
3670 RegionBuilderHelper helper(b, block);
3672
3673 TypeFn castVal = TypeFn::cast_signed;
3674 const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3675 return attr.getName() == "cast";
3676 });
3677 if (castIter != attrs.end()) {
3678 if (auto attr = llvm::dyn_cast(castIter->getValue()))
3680 }
3681
3686 Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
3688 helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
3689 yields.push_back(value4);
3690 helper.yieldOutputs(yields);
3691 }
3692
3693
3694
3695
3696
3697
3698
3699
3700 bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3701 assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
3703
3705 }
3706
3709 return ArrayAttr{
3710 nullptr};
3711
3712 ArrayAttr arrayAttr;
3714 return failure();
3715
3716 if (llvm::any_of(arrayAttr,
3717 [](auto elt) { return !dyn_cast(elt); }))
3719 << "element of indexing_maps array is not an affine_map";
3720
3721 return arrayAttr;
3722 }
3723
3726 if (failed(indexingMapsAttr))
3727 return failure();
3728
3729 if (*indexingMapsAttr == nullptr) {
3730 auto indexingMapAttrs = llvm::map_to_vector(
3731 MatmulOp::getDefaultIndexingMaps(parser.getContext()),
3732 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3734 }
3735
3736 result.addAttribute("indexing_maps", *indexingMapsAttr);
3738 MatmulOp::getRegionBuilder());
3739 }
3740
3743 MatmulOp::getDefaultIndexingMaps(getContext()),
3745 if (!llvm::equal(getIndexingMaps(), indexingMaps))
3746 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3747
3748 std::array<StringRef, 3> elidedAttrs = {
3749 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
3751 elidedAttrs);
3752 }
3753
3754
3756
3757 if (!hasUserDefinedMaps())
3758 return success();
3759
3760 for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
3762 return failure();
3763 }
3764 return success();
3765 }
3766
3769 }
3770
3771 void MatmulOp::getEffects(
3773 &effects) {
3774 if (hasPureTensorSemantics())
3775 return;
3777 }
3778
3781 }
3782
3783
3784
3785
3786
3788 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
3789
3790
3791
3792
3793
3794
3795
3796
3798 for (auto result : outAffineMap.getResults()) {
3799 auto dimExpr = dyn_cast(result);
3800 assert(dimExpr && "affine_map is a projected permutation");
3801 dimsInOutput[dimExpr.getPosition()] = true;
3802 }
3803
3805 for (auto dimOccursInOutput : dimsInOutput)
3806 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
3807 : utils::IteratorType::reduction);
3808
3809 return iteratorTypes;
3810 }
3811
3812 unsigned ContractOp::getNumRegionArgs() { return 3; }
3813
3814
3818 "ContractOp regionBuilder expects 3 args");
3819 RegionBuilderHelper helper(b, block);
3820
3821 TypeFn castSignedness = TypeFn::cast_signed;
3822 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3823 return attr.getName() == "cast";
3824 });
3825 if (castIter != attrs.end()) {
3826 if (auto attr = llvm::dyn_cast(castIter->getValue()))
3827 castSignedness = attr.getValue();
3828 }
3829
3830
3832 Value lhsAtOutType =
3833 helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
3834 Value rhsAtOutType =
3835 helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
3836 Value productAtOutType =
3837 helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
3838 Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
3839 productAtOutType);
3840 helper.yieldOutputs({result});
3841 }
3842
3845 if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
3847 "expected 'indexing_maps' attribute");
3848 result.addAttribute("indexing_maps", *indexingMapsAttr);
3849
3851 regionBuilder);
3852 }
3853
3855 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3857 p, getOperation(), getInputs(), getOutputs(),
3858 {"indexing_maps", "operandSegmentSizes"});
3859 }
3860
3862 int iterationSpaceDims = -1;
3863
3864
3865
3866
3869
3870
3871 auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
3872 bool isInput) -> LogicalResult {
3873
3875 return emitError("provided affine_map is not a projected permutation");
3876
3877
3878 if (auto shapedType = dyn_cast(operandType)) {
3879 if (affineMap.getNumResults() != shapedType.getRank())
3880 return emitError("ranks of shaped operand and results of corresponding "
3881 "affine_map differ");
3883 return emitError("affine_map specifies shaped access while operand has "
3884 "non-shaped type");
3885 }
3886
3887
3888 if (iterationSpaceDims == -1) {
3889 iterationSpaceDims = affineMap.getNumDims();
3892 } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
3893 return emitError("iteration spaces of provided affine_maps differ");
3894 }
3895
3896
3898 auto affineDimExpr = dyn_cast(affineExpr);
3899 if (!affineDimExpr)
3900 llvm_unreachable("affine_map is a projected permutation");
3901
3902 if (isInput)
3903 inOccurrences[affineDimExpr.getPosition()] += 1;
3904 else
3905 outOccurrences[affineDimExpr.getPosition()] += 1;
3906 }
3907
3908 return success();
3909 };
3910
3911 for (auto &&[affineMap, operandType, isInput] :
3912 llvm::zip(getIndexingMapsArray(), getOperandTypes(),
3914 if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
3915 return failure();
3916 }
3917
3918 bool hasContractingDim = false;
3919 for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
3920 size_t inOccCount = inOccurrences[dimIndex];
3921 size_t outOccCount = outOccurrences[dimIndex];
3922
3923
3924 hasContractingDim |= inOccCount == 2 && outOccCount == 0;
3925
3926 if (inOccCount == 0 && outOccCount == 0)
3927 return emitError() << "iteration space dim at index " << dimIndex
3928 << " not used to access any operand";
3929
3930
3931
3932
3933
3934
3935
3936
3937
3938
3939 if (inOccCount == 1 && outOccCount != 1)
3941 << "iteration space dim at index " << dimIndex
3942 << " is neither a contracting dim nor of parallel iteration type";
3943 }
3944
3945 if (!hasContractingDim)
3946 return emitError("'indexing_maps' do not specify a contracting dimension");
3947
3948 return success();
3949 }
3950
3953 }
3954
3955 void ContractOp::getEffects(
3957 &effects) {
3958 if (hasPureTensorSemantics())
3959 return;
3961 }
3962
3965 }
3966
3967
3968
3969
3971 BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3974 bindDims(context, d0, d1, d2, d3);
3975 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
3976 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
3977 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
3978 return indexingMaps;
3979 }
3980
3983 utils::IteratorType::parallel, utils::IteratorType::parallel,
3984 utils::IteratorType::parallel, utils::IteratorType::reduction};
3985 }
3986
3987 unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
3988
3989 std::string BatchMatmulOp::getLibraryCallName() {
3991 }
3992
3993
3994
3995 bool BatchMatmulOp::hasUserDefinedMaps() {
3997 getDefaultIndexingMaps(this->getContext());
3999 return defaultMaps != explicitMaps;
4000 }
4001
4002
4003
4004
4005
4006
4007
4008
4009 bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
4011 "Expected less than 3 result dim expr.");
4012 bool isValid = false;
4013 enum Indices { batchPos, mPos, nPos, kPos };
4020 isValid =
4027 }
4028 return isValid;
4029 }
4030
4034 "BatchMatmulOp regionBuilder expects 3 (>=0) args");
4035 RegionBuilderHelper helper(b, block);
4037
4038 TypeFn castVal = TypeFn::cast_signed;
4039 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4040 return attr.getName() == "cast";
4041 });
4042 if (castIter != attrs.end()) {
4043 if (auto attr = llvm::dyn_cast(castIter->getValue()))
4045 }
4046
4048 Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
4049 Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
4050 Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
4052 helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
4053 yields.push_back(addVal);
4054 helper.yieldOutputs(yields);
4055 }
4056
4062 return failure();
4063
4065 return failure();
4066
4067 do {
4069 return failure();
4070 if (!isa(mapAttr)) {
4072 "expected affine map attribute");
4073 }
4074 indexingMapsAttr.push_back(mapAttr);
4075
4077 break;
4078 } while (true);
4079
4081 return failure();
4082 }
4083
4084 if (indexingMapsAttr.empty()) {
4085 indexingMapsAttr = llvm::map_to_vector(
4086 BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
4087 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4088 }
4091
4093 BatchMatmulOp::getNumRegionArgs(),
4094 BatchMatmulOp::getRegionBuilder());
4095 }
4096
4099 BatchMatmulOp::getDefaultIndexingMaps(getContext()),
4101 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4102 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4103
4104 std::array<StringRef, 3> elidedAttrs = {
4105 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4107 elidedAttrs);
4108 }
4109
4110
4112
4113
4114 if (!hasUserDefinedMaps())
4115 return success();
4116
4117 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
4119 return failure();
4120 }
4121 return success();
4122 }
4123
4124 LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4127 }
4128
4129 void BatchMatmulOp::getEffects(
4131 &effects) {
4132 if (hasPureTensorSemantics())
4133 return;
4135 }
4136
4139 }
4140
4141
4142
4143
4144
4145 namespace {
4146 struct ArityGroupAndKind {
4147
4149
4150
4151 union Kind {
4156 };
4157
4158 unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
4159 return static_cast<unsigned>(arityGroup);
4160 }
4161 }
4162
4164 constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
4165 constexpr int lastBinary =
4166 static_cast<int>(ElementwiseCaseLimits::LastBinary);
4167 constexpr int lastTernary =
4168 static_cast<int>(ElementwiseCaseLimits::LastTernary);
4169
4170 int val = static_cast<int>(kind);
4171 ArityGroupAndKind result;
4172
4173 if (val < lastUnary) {
4174 result.arityGroup = ElementwiseArityGroup::Unary;
4175 result.kind.unaryFn = static_cast<UnaryFn>(val);
4176 return result;
4177 }
4178 if (val < lastBinary) {
4179 result.arityGroup = ElementwiseArityGroup::Binary;
4180 result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
4181 return result;
4182 }
4183 if (val >= lastTernary) {
4184 llvm_unreachable("unhandled ElementwiseFn");
4185 }
4186 result.arityGroup = ElementwiseArityGroup::Ternary;
4187 result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
4188 return result;
4189 }
4190
4192 auto rank = getResultRank();
4194 }
4195
4197 ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
4201 }
4202
4204
4206 mlir::linalg::ElementwiseKind elemwiseKindVal;
4208 return failure();
4209
4211 auto elemwiseKindAttr = dyn_cast(attr);
4212 if (!elemwiseKindAttr)
4214 "expected ElementwiseKind attribute");
4215 elemwiseKindVal = elemwiseKindAttr.getValue();
4216 } else {
4218 "expected operation 'kind' attribute");
4219 }
4222
4223
4228 return failure();
4230 return failure();
4231 do {
4233 return failure();
4234 if (!isa(mapAttr))
4236 "expected affine map attribute");
4237 indexingMapsAttr.push_back(mapAttr);
4239 break;
4240 } while (true);
4242 return failure();
4243 }
4244
4245
4247 int numRegionArgs =
4248 getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 ;
4250 ElementwiseOp::getRegionBuilder())) {
4252 "unable to parse elemwise op");
4253 }
4254
4255
4256 if (indexingMapsAttr.empty()) {
4257
4258
4259 auto resultType = result.operands[result.operands.size() - 1].getType();
4260 auto shapedType = llvm::dyn_cast(resultType);
4261 if (!shapedType)
4263 "return type needs to be shaped type");
4264 auto numDims = shapedType.getRank();
4265 indexingMapsAttr = llvm::map_to_vector(
4266 ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4268 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4269 }
4270
4273 return success();
4274 }
4275
4277 p << " kind=";
4280 "indexing_maps"};
4281 unsigned arity =
4283 unsigned numDims = getResultRank();
4284
4286 ElementwiseOp::getDefaultIndexingMaps(arity + 1 , numDims,
4289
4290 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4291 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4292
4294 elidedAttrs);
4295 }
4296
4298
4299
4300
4301 return success();
4302 }
4303
4304
4305
4308 ElementwiseKind elemwiseKind;
4309 for (auto attr : attrs) {
4310 if (attr.getName() == b.getStringAttr("kind")) {
4311 auto kindAttr = dyn_cast(attr.getValue());
4312 assert(kindAttr && "op kind attribute incorrectly set");
4313 elemwiseKind = kindAttr.getValue();
4314 break;
4315 }
4316 }
4317
4319 auto arityGroup = groupAndKind.arityGroup;
4320 auto kind = groupAndKind.kind;
4322 getArityGroupAsUInt(arityGroup) + 1
4323 && "Elementwise regionBuilder number of block args mismatch");
4324
4325 RegionBuilderHelper helper(b, block);
4328
4329 if (arityGroup == ElementwiseArityGroup::Unary) {
4330 result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
4331
4332 } else if (arityGroup == ElementwiseArityGroup::Binary) {
4333 result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
4335
4336 } else if (arityGroup == ElementwiseArityGroup::Ternary) {
4337 result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
4339
4340 } else {
4341 assert(false && "found unhandled category in elemwise");
4342 }
4343
4344 yields.push_back(result);
4345 helper.yieldOutputs(yields);
4346 }
4347
4348 LogicalResult ElementwiseOp::fold(FoldAdaptor,
4351 }
4352
4353 void ElementwiseOp::getEffects(
4355 &effects) {
4356 if (hasPureTensorSemantics())
4357 return;
4359 }
4360
4363 }
4364
4365
4366
4367
4368
4369
4370
4371
4372
4373
4374
4379 for (auto it : llvm::zip(cast(newPackedTy)
4381 .take_back(mixedTiles.size()),
4382 mixedTiles)) {
4383 int64_t shape = std::get<0>(it);
4384 if (shape == ShapedType::kDynamic) {
4385 newMixedTileSizes.push_back(std::get<1>(it));
4386 continue;
4387 }
4388
4389
4390
4392 if (Attribute attr = llvm::dyn_cast_if_present(tile)) {
4393
4394 newMixedTileSizes.push_back(tile);
4395 } else {
4397 "tile size and dim size don't match!");
4398 newMixedTileSizes.push_back(
4400 }
4401 }
4402
4403 return newMixedTileSizes;
4404 }
4405
4406 template
4407 static LogicalResult
4410 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4411 "applies to only pack or unpack operations");
4412 int64_t destRank = op.getDestRank();
4414 reifiedReturnShapes[0] =
4416 return success();
4417 }
4418
4419 template
4421 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4422 "applies to only pack or unpack operations");
4426 assert(tiles.size() == dimsToTile.size() &&
4427 "tiles must match indices of dimension to block");
4428
4429 for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
4430 dimAndTileMapping[dimsToTile[i]] = tiles[i];
4431 return dimAndTileMapping;
4432 }
4433
4434 template
4436 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4437 "applies to only pack or unpack operations");
4440 unsigned dynamicValIndex = 0;
4441 for (int64_t staticTile : op.getStaticInnerTiles()) {
4442 if (!ShapedType::isDynamic(staticTile))
4443 mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
4444 else
4445 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
4446 }
4447 return mixedInnerTiles;
4448 }
4449
4450 template
4452 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4453 "applies to only pack or unpack operations");
4457 return staticTiles;
4458 }
4459
4460
4461
4462
4463
4465 size_t rank) {
4466 size_t dimsPosSize = dimsPos.size();
4467 if (dimsPosSize > rank)
4468 return true;
4470 if (dimsPosSize != uniqued.size())
4471 return true;
4472 return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
4473 return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
4474 });
4475 }
4476
4477
4478
4481 assert(
4482 sourceShape.size() == limitShape.size() &&
4483 "expected source shape rank, and limit of the shape to have same rank");
4484 return llvm::all_of(
4485 llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
4486 int64_t sourceExtent = std::get<0>(it);
4487 int64_t limit = std::get<1>(it);
4488 return ShapedType::isDynamic(sourceExtent) ||
4489 ShapedType::isDynamic(limit) || sourceExtent <= limit;
4490 });
4491 }
4492
4493 template
4495 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4496 "applies to only pack or unpack operations");
4497 Operation *op = packOrUnPack.getOperation();
4498
4499
4502 };
4503
4504
4506 if (hasZeros(mixedTiles))
4507 return op->emitError("invalid zero tile factor");
4508
4509
4510 RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
4511 ? packOrUnPack.getSourceType()
4512 : packOrUnPack.getDestType();
4513 size_t unpackedRank = unpackedType.getRank();
4515 ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
4517 return op->emitError("invalid inner_dims_pos vector");
4519 return op->emitError("invalid outer_dims_perm vector");
4520 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
4521 return op->emitError("outer_dims_perm must be a permutation or empty");
4522
4523
4524
4525 if (mixedTiles.size() > unpackedRank) {
4526 return op->emitError("tiling factors must be less than or equal to the "
4527 "input rank for pack or output rank for unpack");
4528 }
4529 if (mixedTiles.size() != innerDimsPos.size()) {
4531 "tiling factors must equal the number of dimensions to tile");
4532 }
4533
4534 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4535 ? packOrUnPack.getDestType()
4536 : packOrUnPack.getSourceType();
4537 size_t packedRank = packedType.getRank();
4538
4539 size_t expectedPackedRank = unpackedRank + mixedTiles.size();
4540 if (expectedPackedRank != packedRank) {
4542 "packed rank != (unpacked rank + num tiling factors), got ")
4543 << packedRank << " != " << expectedPackedRank;
4544 }
4545
4546
4547
4548
4549 RankedTensorType expectedPackedType = PackOp::inferPackedType(
4550 unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
4551 if ((expectedPackedType.getShape(), packedType.getShape())) {
4552 return op->emitError("the shape of output is not large enough to hold the "
4553 "packed data. Expected at least ")
4554 << expectedPackedType << ", got " << packedType;
4555 }
4556 if (!llvm::all_of(
4557 llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
4558 mixedTiles),
4559 [](std::tuple<int64_t, OpFoldResult> it) {
4560 int64_t shape = std::get<0>(it);
4561 if (Attribute attr =
4562 llvm::dyn_cast_if_present(std::get<1>(it))) {
4563 IntegerAttr intAttr = dyn_cast_or_null(attr);
4564 int64_t staticTileSize = intAttr.getValue().getSExtValue();
4565 return shape == staticTileSize;
4566 }
4567 return ShapedType::isDynamic(shape);
4568 })) {
4569 return op->emitError("mismatch in inner tile sizes specified and shaped of "
4570 "tiled dimension in the packed type");
4571 }
4572 return success();
4573 }
4574
4575 namespace {
4576
4577
4578
4579
4580
4581
4582 struct PackOrUnPackTransposeResult {
4586 };
4587 }
4588
4589 template
4590 static PackOrUnPackTransposeResult
4594 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4595 "applies to only pack or unpack operations");
4596 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
4597 "some permutation must be non-empty");
4598 PackOrUnPackTransposeResult metadata;
4599 metadata.innerDimsPos =
4601 metadata.innerTiles =
4603 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
4604 ? packOrUnPackOp.getSourceRank()
4605 : packOrUnPackOp.getDestRank();
4606 metadata.outerDimsPerm =
4607 packOrUnPackOp.getOuterDimsPerm().empty()
4608 ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
4610 if (!innerPermutation.empty()) {
4611 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
4613 "invalid inner permutation");
4616 }
4617 if (!outerPermutation.empty()) {
4618 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
4620 "invalid outer permutation");
4622 }
4623 return metadata;
4624 }
4625
4626
4627
4628
4629
4630 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
4631 setNameFn(getResult(), "pack");
4632 }
4633
4637 std::optional paddingValue,
4640 "number of tile sizes specified must match the specified number of "
4641 "original dimensions to be tiled");
4645 build(builder, state, dest.getType(), source, dest,
4646 paddingValue ? *paddingValue : nullptr,
4651 }
4652
4653 LogicalResult
4657 }
4658
4661 }
4662
4665 }
4666
4669 }
4670
4672 ShapedType inputType = getSourceType();
4673 int64_t inputRank = inputType.getRank();
4674 return getDestType().getShape().take_front(inputRank);
4675 }
4676
4679 auto packedShape = getDestType().getShape();
4681
4683 res.push_back(packedShape[index]);
4684
4685 return res;
4686 }
4687
4694 outputShape.take_front(inputShape.size()));
4696 assert(outerDimsPerm.size() == outputTileSizes.size() &&
4697 "expected output and outer_dims_perm to have same size");
4700 }
4702 if (ShapedType::isDynamic(inputShape[pos]))
4703 continue;
4705
4706 if (!constantTile) {
4707 if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
4708 (inputShape[pos] % outputTileSizes[pos] != 0))
4709 return true;
4710 } else if (inputShape[pos] % (*constantTile) != 0) {
4711 return true;
4712 }
4713 }
4714 return false;
4715 }
4716
4719 return failure();
4720
4721
4722
4723
4724 auto paddingValue = getPaddingValue();
4725 if (paddingValue &&
4726 paddingValue.getType() != getSourceType().getElementType()) {
4727 return emitOpError("expected padding_value has ")
4728 << getSourceType().getElementType()
4729 << " but got: " << paddingValue.getType();
4730 }
4731
4732 if (!paddingValue &&
4733 requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
4734 getDestType().getShape(), getOuterDimsPerm(),
4735 getMixedTiles())) {
4736 return emitOpError(
4737 "invalid tile factor or output size provided. Only full tiles are "
4738 "supported when padding_value is not set");
4739 }
4740 return success();
4741 }
4742
4743
4744
4748 for (auto o : ofrs) {
4749
4750 if (llvm::dyn_cast_if_present(o))
4751 result.push_back(ShapedType::kDynamic);
4752 else
4753 result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
4754 }
4755 return result;
4756 }
4757
4758
4759
4760
4766 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
4767 continue;
4768 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
4769 resultShape[tiledDim.value()] = ShapedType::kDynamic;
4770 continue;
4771 }
4772 resultShape[tiledDim.value()] = llvm::divideCeilSigned(
4773 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
4774 }
4775
4776
4779
4780
4781 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
4782 return resultShape;
4783 }
4784
4790
4796 builder, loc, ceilDivExpr,
4797 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
4798 }
4801 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
4802
4807
4808
4809
4810
4811
4812 for (unsigned i = 0; i < resultDims.size(); ++i) {
4813 if (!ShapedType::isDynamic(resultTypeShape[i]))
4814 continue;
4815 resultDims[i] =
4817 }
4818
4819 return resultDims;
4820 }
4821
4822
4823
4824 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
4831 }
4832
4841 {v1, v2});
4842 };
4843
4846 llvm::cast(source.getType()).getShape())) {
4847 if (ShapedType::isDynamic(value))
4848 mixedSizes.push_back(
4849 b.createtensor::DimOp(loc, source, index).getResult());
4850 else
4851 mixedSizes.push_back(b.getIndexAttr(value));
4852 }
4853 for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
4854 int64_t dimPos = std::get<0>(it);
4856 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
4857 }
4859 applyPermutationToVector(mixedSizes, outerDimsPerm);
4860
4861 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
4862 auto elemType = llvm::cast(source.getType()).getElementType();
4863 return b.createtensor::EmptyOp(loc, mixedSizes, elemType);
4864 }
4865
// Builds a clone of this pack op whose inner tiles / inner dims / outer
// dims permutation have been transposed by the given permutations; a fresh
// destination tensor matching the transposed layout is created, and the
// original padding value is carried over.
// NOTE(review): the parameter list and the call that computes `metadata`
// (original lines 4867-4869) are missing from this extraction.
4866 PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
4870 *this, innerPermutation, outerPermutation);
4871 Value transposedDest =
4872 createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
4873 metadata.innerDimsPos, metadata.outerDimsPerm);
4874 return b.create(loc, getSource(), transposedDest,
4875 metadata.innerDimsPos, metadata.innerTiles,
4876 getPaddingValue(), metadata.outerDimsPerm);
4877 }
4878
4879
4880 template
4882 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4883 "applies to only pack or unpack operations");
4884 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4885 ? op.getDestType()
4886 : op.getSourceType();
4888 for (auto [dimDest, tile] : llvm::zip(
4889 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
4891 if (!constTileSize || ShapedType::isDynamic(dimDest))
4892 return false;
4893 }
4894 return true;
4895 }
4896
4898 if (getPaddingValue())
4900
4901
4902
4903
4906
4908 }
4909
4910
4911
4913 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
4914 return false;
4915 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
4916 return true;
4917
4918
4919
4922 }
4923
4924
4925
4927 auto packTiles = packOp.getMixedTiles();
4928 auto unPackTiles = unPackOp.getMixedTiles();
4929 if (packTiles.size() != unPackTiles.size())
4930 return false;
4931 for (size_t i = 0, e = packTiles.size(); i < e; i++) {
4933 return false;
4934 }
4935 return true;
4936 }
4937
4938
4940 auto srcType = op.getSourceType();
4941 if (llvm::any_of(op.getInnerDimsPos(),
4942 [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
4943 return false;
4944 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
4945 return false;
4946 return !PackOp::requirePaddingValue(
4947 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
4948 op.getOuterDimsPerm(), op.getMixedTiles());
4949 }
4950
4951
4952
4955 bool changeNeeded = false;
4956 srcShape.assign(packOp.getSourceType().getShape().begin(),
4957 packOp.getSourceType().getShape().end());
4958 destShape.assign(packOp.getDestType().getShape().begin(),
4959 packOp.getDestType().getShape().end());
4960 llvm::SmallSetVector<int64_t, 4> innerDims;
4961 innerDims.insert_range(packOp.getInnerDimsPos());
4963 if (!packOp.getOuterDimsPerm().empty())
4965 int srcRank = packOp.getSourceRank();
4966 for (auto i : llvm::seq<int64_t>(0, srcRank)) {
4967 if (innerDims.contains(i))
4968 continue;
4969 int64_t srcPos = i;
4970 int64_t destPos = i;
4971 if (!inverseOuterDimsPerm.empty())
4972 destPos = inverseOuterDimsPerm[srcPos];
4973 if (ShapedType::isDynamic(srcShape[srcPos]) ==
4974 ShapedType::isDynamic(destShape[destPos])) {
4975 continue;
4976 }
4977 int64_t size = srcShape[srcPos];
4978 if (ShapedType::isDynamic(size))
4979 size = destShape[destPos];
4980 srcShape[srcPos] = size;
4981 destShape[destPos] = size;
4982 changeNeeded = true;
4983 }
4984 return changeNeeded;
4985 }
4986
// Canonicalization patterns for linalg.pack. NOTE(review): several guard
// lines (e.g. original 4993-4994, 5001-5002, 5009-5010) and template
// arguments were dropped by this extraction; hedged comments below.
4987 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
4988
// Pattern 1: fold pack(unpack(x)) -> x when the round-trip is exact:
// matching types and no padding value (plus further dropped guards).
4989 if (auto unPackOp = packOp.getSource().getDefiningOp()) {
4990 if (unPackOp.getSourceType() != packOp.getDestType())
4991 return failure();
4992 if (packOp.getPaddingValue() ||
4995 return failure();
4996 rewriter.replaceOp(packOp, unPackOp.getSource());
4997 return success();
4998 }
4999
5000
// Pattern 2 (guard dropped): remove a padding value that is provably
// unneeded by clearing the mutable padding operand in place.
5003 packOp.getPaddingValueMutable().clear();
5005 return success();
5006 }
5007
5008
// Pattern 3 (entry guard dropped -- presumably an inferStaticShape call):
// refine dynamic dims of source/dest to inferred static shapes, inserting
// tensor.cast ops so the surrounding IR keeps its original types.
5011 Location loc = packOp.getLoc();
5012 Value source = packOp.getSource();
5013 if (srcShape != packOp.getSourceType().getShape()) {
5014 auto newSrcType = packOp.getSourceType().clone(srcShape);
5015 source =
5016 rewriter.createtensor::CastOp(loc, newSrcType, packOp.getSource());
5017 }
5018 Value dest = packOp.getDest();
5019 RankedTensorType originalResultType = packOp.getDestType();
5020 bool needUpdateDestType = (destShape != originalResultType.getShape());
5021 if (needUpdateDestType) {
5022 auto newDestType = packOp.getDestType().clone(destShape);
5023 dest =
5024 rewriter.createtensor::CastOp(loc, newDestType, packOp.getDest());
5025 }
// In-place update: swap in the (possibly cast) operands and retype the
// result to the refined destination type.
5027 packOp.getSourceMutable().assign(source);
5028 packOp.getDestMutable().assign(dest);
5029 packOp.getResult().setType(cast(dest.getType()));
5030 });
5031
// If the result type changed, cast back to the original type so existing
// users remain valid.
5032 if (needUpdateDestType) {
5034 auto castOp =
5035 rewriter.createtensor::CastOp(loc, originalResultType, packOp);
5037 }
5038 return success();
5039 }
5040
5041 return failure();
5042 }
5043
5044 template
5046 RankedTensorType packedTensorType) {
5047 static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
5048 std::is_same<PackOrUnpackOp, UnPackOp>::value,
5049 "Function meant for pack/unpack");
5050
5051
5052
5054 int64_t numPackedDims = innerDimsPos.size();
5055 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
5057
5058 return false;
5059 }
5060
5062 int64_t packedRank = packedTensorType.getRank();
5063
5064
5065
5066
5067
5068
5069
5070
5071
5072 return llvm::all_of(
5073 llvm::seq<int64_t>(0, packedRank - numPackedDims),
5074 [&packedShape](int64_t i) { return packedShape[i] == 1; });
5075 }
5076
// Returns whether this pack op behaves like a pad on its result type.
// NOTE(review): the cast's template argument and the return statement
// (original line 5080, presumably delegating to the isLikePadUnPad helper
// above) were dropped by this extraction.
5077 bool PackOp::isLikePad() {
5078 auto packedTensorType =
5079 llvm::cast((*this)->getResultTypes().front());
5081 }
5082
// Folder: when the source is a compile-time constant attribute, fold the
// pack into a constant reshaped to the destination type, threading the
// padding value (if any) through as an optional attribute.
// NOTE(review): template arguments of std::optional and dyn_cast_if_present
// were stripped by this extraction.
5083 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
5084 std::optional paddingValue;
5085 if (auto pad = adaptor.getPaddingValue())
5086 paddingValue = pad;
5087 if (OpFoldResult reshapedSource = reshapeConstantSource(
5088 llvm::dyn_cast_if_present(adaptor.getSource()),
5089 getDestType(), paddingValue))
5090 return reshapedSource;
5091 return {};
5092 }
5093
5094
5095
5096
5097
5098
5099
5100
5101
5102
5103
5104
5105
5106
5107
5110
5114 return failure();
5115
5119
5120
5123
5124
5125
5126
5127
5128 PackOp newOp = rewriter.create(
5129 op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
5130 newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
5132
5133
5134 Value oldResult = op.getResult();
5135 Value newResult = newOp.getResult();
5137 ? rewriter.createtensor::CastOp(
5138 op->getLoc(), oldResult.getType(), newResult)
5139 : newResult;
5140
5141 rewriter.replaceOp(op, {replacement});
5142
5143 return success();
5144 }
5145 };
5146
5147
5148
5149
5150
// Suggest "unpack" as the SSA-name prefix for this op's result when
// printing. NOTE(review): the setNameFn parameter line (original 5152)
// was dropped by this extraction.
5151 void UnPackOp::getAsmResultNames(
5153 setNameFn(getResult(), "unpack");
5154 }
5155
5156 LogicalResult
5160 }
5161
5164 }
5165
5168 }
5169
5172 }
5173
5175 ShapedType destType = getDestType();
5176 int64_t destRank = destType.getRank();
5177 return getSourceType().getShape().take_front(destRank);
5178 }
5179
5182 auto packedShape = getSourceType().getShape();
5184
5186 res.push_back(packedShape[index]);
5187
5188 return res;
5189 }
5190
5193 }
5194
5196
5199
5201 }
5202
5208 "number of tile sizes specified must match the specified number of "
5209 "original dimensions to be tiled");
5213 build(builder, state, dest.getType(), source, dest,
5218 }
5219
5229 };
5230
5232 auto srcType = llvm::cast(source.getType());
5233 for (auto i :
5234 llvm::seq(0, srcType.getRank() - innerTileSizes.size())) {
5235 if (srcType.isDynamicDim(i))
5236 mixedSizes.push_back(b.createtensor::DimOp(loc, source, i).getResult());
5237 else
5238 mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
5239 }
5241 applyPermutationToVector(
5243 }
5244
5245 for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
5246 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
5247
5248 auto elemType = srcType.getElementType();
5249 return b.createtensor::EmptyOp(loc, mixedSizes, elemType);
5250 }
5251
// Builds a clone of this unpack op that reads the given already-transposed
// source, with inner tiles / inner dims / outer dims permutation transposed
// by the given permutations; the destination is reused unchanged.
// NOTE(review): the permutation parameters and the call computing
// `metadata` (original lines 5254-5256) are missing from this extraction.
5252 UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
5253 Value transposedSource,
5257 *this, innerPermutation, outerPermutation);
5258 return b.create(loc, transposedSource, getDest(),
5259 metadata.innerDimsPos, metadata.innerTiles,
5260 metadata.outerDimsPerm);
5261 }
5262
5263
5264
5267 bool changeNeeded = false;
5268 srcShape.assign(op.getSourceType().getShape().begin(),
5269 op.getSourceType().getShape().end());
5270 destShape.assign(op.getDestType().getShape().begin(),
5271 op.getDestType().getShape().end());
5272 llvm::SmallSetVector<int64_t, 4> innerDims;
5273 innerDims.insert_range(op.getInnerDimsPos());
5275 if (!op.getOuterDimsPerm().empty())
5277 int destRank = op.getDestRank();
5278 for (auto i : llvm::seq<int64_t>(0, destRank)) {
5279 if (innerDims.contains(i))
5280 continue;
5281 int64_t srcPos = i;
5282 int64_t destPos = i;
5283 if (!inverseOuterDimsPerm.empty())
5284 srcPos = inverseOuterDimsPerm[destPos];
5285 if (ShapedType::isDynamic(srcShape[srcPos]) ==
5286 ShapedType::isDynamic(destShape[destPos])) {
5287 continue;
5288 }
5289 int64_t size = srcShape[srcPos];
5290 if (ShapedType::isDynamic(size))
5291 size = destShape[destPos];
5292 srcShape[srcPos] = size;
5293 destShape[destPos] = size;
5294 changeNeeded = true;
5295 }
5296 return changeNeeded;
5297 }
5298
// Canonicalization patterns for linalg.unpack. NOTE(review): guard lines
// (e.g. original 5306-5307, 5326-5327, 5346-5347) and template arguments
// were dropped by this extraction; hedged comments below.
5299 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
5301
// Pattern 1: fold unpack(pack(x)) -> x when the round-trip is exact:
// matching types and no padding value (plus further dropped guards).
5302 if (PackOp packOp = unPackOp.getSource().getDefiningOp()) {
5303 if (packOp.getSourceType() != unPackOp.getDestType())
5304 return failure();
5305 if (packOp.getPaddingValue() ||
5308 return failure();
5309 rewriter.replaceOp(unPackOp, packOp.getSource());
5310 return success();
5311 }
5312
// Pattern 2: if the destination is produced by a destination-style op,
// bypass it and read the corresponding init operand directly (the unpack
// overwrites the dest anyway).
5313 if (auto dstStyleOp =
5314 unPackOp.getDest().getDefiningOp()) {
5315 auto destValue = cast(unPackOp.getDest());
5316 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
5318 [&]() { unPackOp.setDpsInitOperand(0, newDest); });
5319 return success();
5320 }
5321
// Pattern 3: fold a sole rank-preserving tensor.extract_slice user into
// the unpack by slicing the destination instead, then replacing the slice
// with the retyped unpack result.
5322 if (unPackOp->hasOneUse()) {
5323 auto extractSliceUser =
5324 dyn_casttensor::ExtractSliceOp(*unPackOp->getUsers().begin());
5325 if (extractSliceUser &&
5328 extractSliceUser.getSourceType().getRank() ==
5329 extractSliceUser.getResultType().getRank()) {
5332 auto newDest = rewriter.createtensor::ExtractSliceOp(
5333 unPackOp->getLoc(), unPackOp.getDest(),
5334 extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
5335 extractSliceUser.getMixedStrides());
5337 unPackOp.setDpsInitOperand(0, newDest);
5338 unPackOp.getResult().setType(newDest.getType());
5339 });
5340 rewriter.replaceOp(extractSliceUser, unPackOp);
5341 return success();
5342 }
5343 }
5344
5345
// Pattern 4 (entry guard dropped -- presumably an inferStaticShape call):
// refine dynamic dims of source/dest to inferred static shapes, inserting
// tensor.cast ops, then rebuild the unpack and cast its result back to the
// original type.
5348 Location loc = unPackOp.getLoc();
5349 Value source = unPackOp.getSource();
5350 if (srcShape != unPackOp.getSourceType().getShape()) {
5351 auto newSrcType = unPackOp.getSourceType().clone(srcShape);
5352 source = rewriter.createtensor::CastOp(loc, newSrcType,
5353 unPackOp.getSource());
5354 }
5355 Value dest = unPackOp.getDest();
5356 if (destShape != unPackOp.getDestType().getShape()) {
5357 auto newDestType = unPackOp.getDestType().clone(destShape);
5358 dest =
5359 rewriter.createtensor::CastOp(loc, newDestType, unPackOp.getDest());
5360 }
5361 Value newOp = rewriter.create(
5362 loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
5363 unPackOp.getOuterDimsPerm());
5365 unPackOp, unPackOp.getResult().getType(), newOp);
5366 return success();
5367 }
5368
5369 return failure();
5370 }
5371
// Returns whether this unpack op behaves like an un-pad on its source type.
// NOTE(review): the return statement (original line 5374, presumably
// delegating to the isLikePadUnPad helper above) was dropped by this
// extraction.
5372 bool UnPackOp::isLikeUnPad() {
5373 RankedTensorType packedTensorType = getSourceType();
5375 }
5376
// Folder: when the source is a compile-time constant attribute, fold the
// unpack into a constant reshaped to the result type (no padding value,
// unlike PackOp::fold). NOTE(review): the dyn_cast_if_present template
// argument and the line carrying the target type (original 5380) were
// dropped by this extraction.
5377 OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
5378 if (OpFoldResult reshapedSource = reshapeConstantSource(
5379 llvm::dyn_cast_if_present(adaptor.getSource()),
5381 return reshapedSource;
5382 return {};
5383 }
5384
5385
5386
5387
5388
5389
5390
5391
5392
5393
5394
5395
5396
5397
5398
5401
5405 return failure();
5406
5410 Value sourceTensor = newOperands[0];
5411
5412
5414 rewriter, sourceTensor.getType(), op.getMixedTiles());
5415
5416
5417
5418
5419
5420 UnPackOp newOp = rewriter.create(
5421 op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
5422 newMixedTileSizes, op.getOuterDimsPerm());
5424
5425
5426 Value oldResult = op.getResult();
5427 Value newResult = newOp.getResult();
5429 ? rewriter.createtensor::CastOp(
5430 op->getLoc(), oldResult.getType(), newResult)
5431 : newResult;
5432
5433 rewriter.replaceOp(op, {replacement});
5434
5435 return success();
5436 }
5437 };
5438
5439
5440
5441
5444 utils::IteratorType::reduction, utils::IteratorType::parallel,
5445 utils::IteratorType::parallel, utils::IteratorType::reduction};
5446 }
5447
// Default indexing maps with dims (d0=batch, d1=m, d2=n, d3=k):
//   LHS (d0, d1, d3), RHS (d0, d3, d2), output (d1, d2) -- the batch dim
// does not appear in the output, i.e. it is reduced away.
// NOTE(review): the return type line and the local declarations of
// indexingMaps / d0..d3 (original lines 5448, 5450-5451) were dropped by
// this extraction.
5449 BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
5452 bindDims(context, d0, d1, d2, d3);
5453 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
5454 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
5455 indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
5456 return indexingMaps;
5457 }
5458
5459 unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
5460
5461 std::string BatchReduceMatmulOp::getLibraryCallName() {
5463 }
5464
5465
5466
// Returns true when the op carries indexing maps that differ from the
// canonical defaults, i.e. the user overrode them explicitly.
// NOTE(review): the local declarations of defaultMaps / explicitMaps
// (original lines 5468, 5470) were dropped by this extraction.
5467 bool BatchReduceMatmulOp::hasUserDefinedMaps() {
5469 getDefaultIndexingMaps(this->getContext());
5471 return defaultMaps != explicitMaps;
5472 }
5473
5474
5475
5476
5477
5478
5479
5480
// Checks whether a user-supplied broadcast map for the LHS (isLHS=true) or
// RHS operand is one of the allowed shapes; an assert (message visible at
// original 5484) restricts the map to fewer than 3 result dim exprs.
// NOTE(review): the actual validity conditions (original lines 5487-5499)
// were dropped by this extraction -- confirm against upstream.
5481 bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
5482 bool isLHS) {
5484 "Expected less than 3 result dim expr.");
5485 bool isValid = false;
// Symbolic positions of the batch/m/n/k iteration dims.
5486 enum Indices { batchPos, mPos, nPos, kPos };
5493 isValid =
5500 }
5501 return isValid;
5502 }
5503
5507 "BatchReduceMatmulOp regionBuilder expects 3 (>=0) args");
5508 RegionBuilderHelper helper(b, block);
5510
5512 Value castValA =
5513 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
5514 Value castValB =
5515 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
5516 Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
5518 helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
5519 yields.push_back(addVal);
5520 helper.yieldOutputs(yields);
5521 }
5522
5529 return failure();
5531 return failure();
5532
5533 do {
5535 return failure();
5536 if (!isa(mapAttr)) {
5538 "expected affine map attribute");
5539 }
5540 indexingMapsAttr.push_back(mapAttr);
5541
5543 break;
5544 } while (true);
5545
5547 return failure();
5548 }
5549
5550 if (indexingMapsAttr.empty()) {
5551 indexingMapsAttr = llvm::map_to_vector(
5552 BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
5553 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
5554 }
5558 BatchReduceMatmulOp::getNumRegionArgs(),
5559 BatchReduceMatmulOp::getRegionBuilder());
5560 }
5561
5564 BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
5566
5567 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
5568 p << " indexing_maps = [";
5569 llvm::interleaveComma(getIndexingMaps(), p,
5571 p << "]";
5572 }
5573
5575 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
5577 elidedAttrs);
5578 }
5579
5580
5582
5583
5584 if (!hasUserDefinedMaps())
5585 return success();
5586
5587 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
5589 return failure();
5590 }
5591 return success();
5592 }
// Folder hook. NOTE(review): the parameter list tail and the body
// (original lines 5594-5595) were dropped by this extraction; sibling named
// ops typically delegate to memref::foldMemRefCast -- confirm upstream.
5593 LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
5596 }
// Reports memory effects: none when the op operates purely on tensors;
// otherwise (buffer semantics) the effects are populated by the dropped
// line (original 5602), presumably via the shared getGenericEffectsImpl
// helper -- confirm upstream.
5597 void BatchReduceMatmulOp::getEffects(
5599 &effects) {
5600 if (hasPureTensorSemantics())
5601 return;
5603 }
5604
5607 }
5608
5609 }
5610 }
5611
5612
5613
5614
5615
5616 void LinalgDialect::getCanonicalizationPatterns(
5620 }
5621
5625 return arith::ConstantOp::materialize(builder, value, type, loc);
5626 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantics specified by the explicit indexing map for the MatmulO...
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs)
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
SmallVector< int64_t > outerDimsPerm
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
SmallVector< OpFoldResult > innerTiles
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp/BatchReduceMatmulOp has...
static Operation * findPayloadOp(Block *body, bool initFirst=false)
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
ElementwiseArityGroup arityGroup
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, LinalgOp linalgOp)
union mlir::linalg::@1203::ArityGroupAndKind::Kind kind
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
SmallVector< int64_t > innerDimsPos
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static LogicalResult getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize, const scf::SCFTilingOptions &options)
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, const scf::SCFTilingOptions &options)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap dropResults(ArrayRef< int64_t > positions) const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
result_iterator result_begin()
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_iterator result_end()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
void setDiscardableAttrs(DictionaryAttr newAttrs)
Set the discardable attribute dictionary on this operation.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static bool isLikePadUnPad(PackOrUnpackOp packOp, RankedTensorType packedTensorType)
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static bool areAllInBound(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > limitShape)
Returns true if the dimension of sourceShape is smaller than the dimension of the limitShape.
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static SmallVector< int64_t > getPackOpResultTypeShape(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > innerTileSizes, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
Helper for PackOp::{getResultShape,inferPackedType}.
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Kind
An enumeration of the kinds of predicates.
DynamicAPInt floor(const Fraction &f)
DynamicAPInt ceil(const Fraction &f)
DynamicAPInt round(const Fraction &f)
Fraction abs(const Fraction &f)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
uint64_t getM(LevelType lt)
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. N-1].
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. N-1].
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drop the input dims in dropPositions from inputPerm.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation vector.
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalize transpose by swapping the order of broadcast and transpose: transpose(broad...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
Container for result values of tiling.
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has source that is more static t...
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has source that is more static...
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override