MLIR: lib/Dialect/Linalg/IR/LinalgInterfaces.cpp Source File
1 //===- LinalgInterfaces.cpp - Linalg interfaces implementation -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
10
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SetOperations.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/Support/Casting.h"
27 #include "llvm/Support/raw_ostream.h"
28 #include <algorithm>
29 #include <numeric>
30 #include <optional>
31
32 using namespace mlir;
33 using namespace mlir::linalg;
34
35
36 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
37
38
39
40
41
42 bool linalg::detail::canOpOperandsBeDroppedImpl(
43 linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
44 SmallVector<AffineMap> indexingMaps;
45 for (auto &opOperand : linalgOp->getOpOperands()) {
46 if (llvm::is_contained(droppedOperands, &opOperand))
47 continue;
48 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
49 }
50 if (indexingMaps.empty()) {
51
52
53 return linalgOp.getNumLoops() == 0;
54 }
55 return inversePermutation(concatAffineMaps(
56 indexingMaps, linalgOp.getContext())) != AffineMap();
57 }
58
59
60
61
62
63 bool linalg::isaCopyOpInterface(LinalgOp op) {
64
65 if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
66 return false;
67
68 auto mapRange = op.getIndexingMapsArray();
69 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
70 !mapRange.back().isIdentity()) {
71 return false;
72 }
73
74 return llvm::hasSingleElement(op.getBlock()->getOperations());
75 }
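Illustrative usage sketch (not part of the listing; the helper name and includes are assumptions): a specialization or matching pass could use the predicate above to detect copy-like generics before rewriting them.

// Hypothetical helper built on linalg::isaCopyOpInterface.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "llvm/Support/raw_ostream.h"

static bool reportIfCopyLike(mlir::linalg::LinalgOp op) {
  // True for an all-parallel, single input/output op whose body only yields.
  if (!mlir::linalg::isaCopyOpInterface(op))
    return false;
  llvm::outs() << "found a copy-like linalg op\n";
  return true;
}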
76
77
78
79
80 std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
81
82 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
83 !op.isSingleYieldOp())
84 return std::nullopt;
85
86
87 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
88 op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
89 return std::nullopt;
90
91 OpOperand *value = op.getDpsInputOperand(0);
92 if (!op.isScalar(value))
93 return std::nullopt;
94 return value->get();
95 }
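Illustrative sketch (not part of the listing; the helper name is hypothetical): the optional returned above is the scalar being splatted, which a caller can forward to a linalg.fill rewrite.

// Hypothetical helper built on linalg::isaFillOpInterface.
#include <optional>
#include "mlir/Dialect/Linalg/IR/Linalg.h"

static mlir::Value getFillValueOrNull(mlir::linalg::GenericOp genericOp) {
  // A fill-like generic yields its single scalar input unchanged.
  if (std::optional<mlir::Value> fillValue =
          mlir::linalg::isaFillOpInterface(genericOp))
    return *fillValue;
  return mlir::Value();
}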
96
97
98
99
100 std::optional<SmallVector<int64_t>>
101 linalg::isaBroadcastOpInterface(GenericOp op) {
102
103 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
104 !op.isSingleYieldOp())
105 return std::nullopt;
106
107 auto srcTy = op.getDpsInputOperand(0)->get().getType();
108 auto dstTy = op.getDpsInitOperand(0)->get().getType();
109 if (!isa<MemRefType, RankedTensorType>(srcTy) ||
110 !isa<MemRefType, RankedTensorType>(dstTy))
111 return std::nullopt;
112
113
114
115
116 auto dstMap = op.getIndexingMapsArray()[1];
117 if (!dstMap.isIdentity())
118 return std::nullopt;
119
120 SmallVector<int64_t> position;
121 auto srcMap = op.getIndexingMapsArray()[0];
122
123 if (srcMap.getResults().size() >= dstMap.getResults().size())
124 return std::nullopt;
125
126
127 for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
128 auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
129 if (!expr)
130 return std::nullopt;
131 int64_t pos = expr.getPosition();
132 if (i > 0 && pos <= position[i - 1])
133 return std::nullopt;
134 position.push_back(expr.getPosition());
135 }
136
137 SmallVector<int64_t> broadcastedDims;
138 auto numDims = srcMap.getNumDims();
139
140 for (auto dim : llvm::seq<int64_t>(0, numDims)) {
141 if (!llvm::is_contained(position, dim))
142 broadcastedDims.push_back(dim);
143 }
144 return broadcastedDims;
145 }
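Illustrative sketch (not part of the listing; the helper name is an assumption): the returned vector lists the output dimensions that the input does not index, i.e. the dimensions a linalg.broadcast would add.

// Hypothetical helper built on linalg::isaBroadcastOpInterface.
#include <optional>
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"

static void printBroadcastedDims(mlir::linalg::GenericOp genericOp) {
  std::optional<llvm::SmallVector<int64_t>> dims =
      mlir::linalg::isaBroadcastOpInterface(genericOp);
  if (!dims)
    return;
  llvm::outs() << "broadcasted dims: ";
  llvm::interleaveComma(*dims, llvm::outs());
  llvm::outs() << "\n";
}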
146
147
148
149
150 std::optional<SmallVector<int64_t>>
151 linalg::isaTransposeOpInterface(GenericOp op) {
152
153
154
155 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
156 !op.isSingleYieldOp())
157 return std::nullopt;
158
159 auto mapRange = op.getIndexingMapsArray();
160 if (mapRange.size() != 2)
161 return std::nullopt;
162
163 auto mapOfInput = mapRange.front();
164 auto mapOfResult = mapRange.back();
165
166
167
168 if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
169 return std::nullopt;
170
171 SmallVector<int64_t> permutation(mapOfInput.getNumDims());
172 for (unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
173 auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
174 permutation[expr.getPosition()] = i;
175 }
176 return permutation;
177 }
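Illustrative sketch (not part of the listing; the helper name is an assumption): the recovered vector plays the role of linalg.transpose's permutation.

// Hypothetical helper built on linalg::isaTransposeOpInterface.
#include <optional>
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"

static void printTransposePermutation(mlir::linalg::GenericOp genericOp) {
  if (std::optional<llvm::SmallVector<int64_t>> perm =
          mlir::linalg::isaTransposeOpInterface(genericOp)) {
    llvm::outs() << "transpose permutation: ";
    llvm::interleaveComma(*perm, llvm::outs());
    llvm::outs() << "\n";
  }
}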
178
179
180
181
182 static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
183 unsigned arity) {
184
185 if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
186 return false;
187
188
189 if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
190 !llvm::all_of(op.getIndexingMapsArray(),
191 [](AffineMap map) { return map.isIdentity(); }))
192 return false;
193
194
195 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
196 return false;
197
198
199
200
201
202 Block *body = op.getBody();
203 if (body->getOperations().size() != 2)
204 return false;
205
206 Operation *oper = &body->front();
207 if (oper->getNumOperands() != arity || oper->getNumResults() != 1)
208 return false;
209
210 auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
211 if (!yieldOp || yieldOp.getNumOperands() != 1 ||
212 yieldOp->getOperand(0).getDefiningOp() != oper)
213 return false;
214 return true;
215 }
216
217 bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op) {
218 // All basic elementwise checks.
219 if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 1))
220 return false;
221
222
223 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
224 return false;
225 return true;
226 }
227
228 bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
229 if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 2))
230 return false;
231
232
233 OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
234 OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
235 if (!op.payloadUsesValueFromOperand(inputOpOperand0) ||
236 !op.payloadUsesValueFromOperand(inputOpOperand1))
237 return false;
238 return true;
239 }
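Illustrative sketch (not part of the listing; the helper name is an assumption): the two predicates above can be combined to dispatch on the arity of a single-op elementwise generic.

// Hypothetical helper built on the elementwise predicates.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "llvm/ADT/StringRef.h"

static llvm::StringRef classifyElementwise(mlir::linalg::GenericOp genericOp) {
  if (mlir::linalg::isaElemwiseSingleUnaryOpInterface(genericOp))
    return "unary";
  if (mlir::linalg::isaElemwiseSingleBinaryOpInterface(genericOp))
    return "binary";
  return "other";
}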
240
241
242
243
244
245
246
247
248
249 static Value getSourceSkipUnary(Value value) {
250 Operation *op = value.getDefiningOp();
251 while (op && op->getNumOperands() == 1) {
252 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
253 if (!iface || !iface.hasNoEffect())
254 break;
255 value = op->getOperand(0);
256 op = value.getDefiningOp();
257 }
258 return value;
259 }
260
261 bool mlir::linalg::detail::isContractionBody(
262 Block &block, function_ref<bool(Operation *, Operation *)> isaPair,
263 llvm::raw_ostream &errs) {
264 if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) {
265 errs << "no terminator in the block";
266 return false;
267 }
268
269 if (block.getNumArguments() != 3) {
270 errs << "expected block with 3 arguments";
271 return false;
272 }
273
274 Operation *terminator = block.getTerminator();
275 if (terminator->getNumOperands() != 1) {
276 errs << "expected terminator with 1 operand";
277 return false;
278 }
279
280 Value yielded = getSourceSkipUnary(terminator->getOperand(0));
281 Operation *reductionOp = yielded.getDefiningOp();
282 if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
283 errs << "expected reduction op to be binary";
284 return false;
285 }
286
287 Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
288 Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));
289
290 if (reductionLHS != block.getArgument(2) &&
291 reductionRHS != block.getArgument(2)) {
292 errs << "expected reduction to take block argument #2 as one of the "
293 "operands (modulo unary casts)";
294 return false;
295 }
296
297 Value contributed = getSourceSkipUnary(
298 isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
299 Operation *elementwiseOp = contributed.getDefiningOp();
300 if (!elementwiseOp || elementwiseOp->getNumResults() != 1 ||
301 elementwiseOp->getNumOperands() != 2) {
302 errs << "expected elementwise op to be binary";
303 return false;
304 }
305
306 if (!isaPair(elementwiseOp, reductionOp)) {
307 errs << "expected reduction/elementwise op kind not satisfied";
308 return false;
309 }
310
311 Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0));
312 Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1));
313 if ((elementwiseLHS == block.getArgument(0) &&
314 elementwiseRHS == block.getArgument(1)) ||
315 (elementwiseLHS == block.getArgument(1) &&
316 elementwiseRHS == block.getArgument(0))) {
317 return true;
318 }
319
320 errs << "expected elementwise op to apply to block arguments (modulo unary "
321 "casts)";
322 return false;
323 }
324
325
326
327 template <typename AddOpTy, typename MulOpTy, typename... Args>
328 static bool isPairTemplateImpl(Operation *add, Operation *mul) {
329 static_assert(sizeof...(Args) % 2 == 0,
330 "expected an even number of template arguments");
331 if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
332 return true;
333
334 if constexpr (sizeof...(Args) > 0)
335 return isPairTemplateImpl<Args...>(add, mul);
336 else
337 return false;
338 }
339
340
341
342 template <typename... Args>
343 static bool isContractionBody(Block &block) {
344 return linalg::detail::isContractionBody(block, &isPairTemplateImpl<Args...>);
345 }
346
347
348
349
350
351
352
353
354 static llvm::SmallDenseSet<int64_t>
355 findPermutationsIndexingOperand(AffineMap indexingMap,
356 ArrayRef<utils::IteratorType> iterators,
357 utils::IteratorType iter) {
358 assert(iterators.size() == indexingMap.getNumDims());
359 llvm::SmallDenseSet<int64_t> res;
360 for (AffineExpr e : indexingMap.getResults()) {
361 if (auto d = dyn_cast<AffineDimExpr>(e)) {
362 if (iterators[d.getPosition()] == iter &&
363 llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
364 return e.isFunctionOfDim(d.getPosition());
365 }) == 1)
366 res.insert(d.getPosition());
367 }
368 }
369 return res;
370 }
371
372 namespace {
373 auto par = utils::IteratorType::parallel;
374 auto red = utils::IteratorType::reduction;
375 }
376
377
378
379
380
381 static FailureOr<SmallVector<utils::IteratorType>>
382 inferIteratorsFromOutMap(AffineMap map) {
383 if (!map.isProjectedPermutation())
384 return failure();
385 SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
386 for (auto expr : map.getResults())
387 if (auto dim = dyn_cast<AffineDimExpr>(expr))
388 iterators[dim.getPosition()] = par;
389 return iterators;
390 }
391
392
393
394
395
396
397
398
399
400
401
402
403 static FailureOr<ContractionDimensions>
404 inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
405 ArrayRef<utils::IteratorType> iterators) {
406 llvm::SmallDenseSet<int64_t> a =
407 findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
408 llvm::SmallDenseSet<int64_t> b =
409 findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
410 llvm::SmallDenseSet<int64_t> c =
411 findPermutationsIndexingOperand(indexingMaps[2], iterators, par);
412
413
414 llvm::SmallDenseSet<int64_t> ac = a;
415 llvm::set_intersect(ac, c);
416 llvm::set_subtract(ac, b);
417
418 llvm::SmallDenseSet<int64_t> bc = b;
419 llvm::set_intersect(bc, c);
420 llvm::set_subtract(bc, a);
421
422 llvm::SmallDenseSet<int64_t> batches = a;
423 llvm::set_intersect(batches, b);
424 llvm::set_intersect(batches, c);
425
426
427 llvm::SmallDenseSet<int64_t> ra =
428 findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
429 llvm::SmallDenseSet<int64_t> rb =
430 findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
431 llvm::set_intersect(ra, rb);
432
433
434 ContractionDimensions dimensions{
435 SmallVector<unsigned, 2>(batches.begin(), batches.end()),
436 SmallVector<unsigned, 2>(ac.begin(), ac.end()),
437 SmallVector<unsigned, 2>(bc.begin(), bc.end()),
438 SmallVector<unsigned, 2>(ra.begin(), ra.end())};
439 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
440 llvm::sort(dimensions.m.begin(), dimensions.m.end());
441 llvm::sort(dimensions.n.begin(), dimensions.n.end());
442 llvm::sort(dimensions.k.begin(), dimensions.k.end());
443 return dimensions;
444 }
445
446 FailureOr<ContractionDimensions>
447 mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
448 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
449 return failure();
450 return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
451 linalgOp.getIteratorTypesArray());
452 }
453
454 FailureOr<ContractionDimensions>
455 mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) {
456 if (indexingMaps.size() != 3)
457 return failure();
458 auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
459 if (failed(iterators))
460 return failure();
461 return inferContractionDimsImpl(indexingMaps, iterators.value());
462 }
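Illustrative sketch (not part of the listing; the helper name is an assumption): reading the classification produced by inferContractionDims. For a plain linalg.matmul with maps (m, k), (k, n) -> (m, n) this should report m = {0}, n = {1}, k = {2} and no batch dimensions.

// Hypothetical helper built on linalg::inferContractionDims.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"

static void printContractionDims(mlir::linalg::LinalgOp linalgOp) {
  mlir::FailureOr<mlir::linalg::ContractionDimensions> dims =
      mlir::linalg::inferContractionDims(linalgOp);
  if (mlir::failed(dims))
    return;
  auto printSet = [](llvm::StringRef name, llvm::ArrayRef<unsigned> set) {
    llvm::outs() << name << ": ";
    llvm::interleaveComma(set, llvm::outs());
    llvm::outs() << "\n";
  };
  printSet("batch", dims->batch);
  printSet("m", dims->m);
  printSet("n", dims->n);
  printSet("k", dims->k);
}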
463
472 };
473 }
474
475 mlir::linalg::detail::MatchContractionResult
476 mlir::linalg::detail::isContractionInterfaceImpl(
477 Operation *op, mlir::linalg::ContractionDimensions *dimensions) {
478 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
479 if (!linalgOp)
480 return MatchContractionResult::NotLinalgOp;
481 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
482 return MatchContractionResult::WrongNumOperands;
483 auto mapRange = linalgOp.getIndexingMapsArray();
484 if (linalgOp.getNumReductionLoops() == 0)
485 return MatchContractionResult::NoReduction;
486 if (llvm::any_of(mapRange,
487 [](AffineMap m) { return !m.isProjectedPermutation(); }))
488 return MatchContractionResult::NotProjectedPermutations;
489
490
491 if (!::isContractionBody<
492 arith::MulFOp, arith::AddFOp,
493 arith::MulIOp, arith::AddIOp,
494 complex::MulOp, complex::AddOp,
495 arith::AndIOp, arith::OrIOp>(
496 *linalgOp.getBlock())) {
497 return MatchContractionResult::NotAddMul;
498 }
499
500
501 if (dimensions) {
502 FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
503 assert(succeeded(res) && "unexpected failure to infer contraction dims");
504 *dimensions = *res;
505 }
506 return MatchContractionResult::Success;
507 }
508
509 StringRef
510 mlir::linalg::detail::getMatchContractionMessage(MatchContractionResult res) {
511 switch (res) {
512 case MatchContractionResult::NotLinalgOp:
513 return "expected a LinalgOp";
514 case MatchContractionResult::WrongNumOperands:
515 return "expected op with 2 inputs and 1 output";
516 case MatchContractionResult::NoReduction:
517 return "expected at least 1 reduction";
518 case MatchContractionResult::NotProjectedPermutations:
519 return "expected indexing maps to be projected permutations";
520 case MatchContractionResult::NotAddMul:
521 return "expected add/mul op in the body";
522 case MatchContractionResult::Success:
523 return "";
524 }
525 llvm_unreachable("unhandled MatchContractionResult case");
526 }
527
528 bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) {
529 if (!linalgOp)
530 return false;
531 Operation *op = linalgOp.getOperation();
532 return isa<ContractionOpInterface>(op) ||
533 (mlir::linalg::detail::isContractionInterfaceImpl(op) ==
534 mlir::linalg::detail::MatchContractionResult::Success);
535 }
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550 LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) {
551 auto res = isContractionInterfaceImpl(op);
552 if (res != MatchContractionResult::Success)
553 return op->emitError(getMatchContractionMessage(res));
554 return success();
555 }
556
557
558
559
560
561
562
563 template <typename T>
564 static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs) {
565 return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr);
566 }
567
568 namespace {
569
570
571
572
573
574
575
576
577 struct ConvAccessExprWalker
578 : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
579
580 llvm::SmallDenseSet<int64_t> convolvedDims;
581
582 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
583
584 llvm::SmallDenseSet<int64_t> unConvolvedDims;
585
586 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
587
588
589
590 void clearMultiUseDims(AffineMap map) {
591 for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
592 if (llvm::count_if(map.getResults(), [dimPos](AffineExpr e) {
593 return e.isFunctionOfDim(dimPos);
594 }) > 1) {
595 convolvedDims.erase(dimPos);
596 unConvolvedDims.erase(dimPos);
597
598
599 auto it = convolvedDimMapping.find(dimPos);
600 if (it != convolvedDimMapping.end()) {
601 int64_t pairedDim = it->second;
602 convolvedDims.erase(pairedDim);
603 unConvolvedDims.erase(pairedDim);
604 strideAndDilationMapping.erase(pairedDim);
605 convolvedDimMapping.erase(dimPos);
606 convolvedDimMapping.erase(pairedDim);
607 }
608 }
609 }
610 }
611
612 LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
613 unsigned position = dimExpr.getPosition();
614 if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
615 return failure();
616 }
617 unConvolvedDims.insert(position);
618 return success();
619 }
620
621 LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }
622
623 LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }
624
625 LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
626 // In a pre-order visit, the top-level op has to be an add op.
627 if (binaryExpr.getKind() != AffineExprKind::Add)
628 return failure();
629 auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS());
630 auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS());
631 if (failed(lhsDimPos) || failed(rhsDimPos))
632 return failure();
633 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
634 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
635 return success();
636 }
637
638 FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
639 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
640 int64_t dim = dimExpr.getPosition();
641 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
642 return failure();
643
644 strideAndDilationMapping[dim] =
645 getAffineConstantExpr(1, expr.getContext());
646 convolvedDims.insert(dim);
647 return dim;
648 }
649 if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
650 if (symbolMulExpr.getKind() != AffineExprKind::Mul)
651 return failure();
652 auto lhsExpr = symbolMulExpr.getLHS();
653 auto rhsExpr = symbolMulExpr.getRHS();
654 // Prefer a symbol expression for the multiplier.
655 AffineExpr mulExpr =
656 getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
657
658 if (!mulExpr) {
659 mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
660 }
661 auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
662 if (!mulExpr || !dimExpr)
663 return failure();
664 unsigned dim = dimExpr.getPosition();
665 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
666 return failure();
667 strideAndDilationMapping[dim] = mulExpr;
668 convolvedDims.insert(dim);
669 return dim;
670 }
671 return failure();
672 }
673 };
674 }
675
676 static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
677 assert(map.isProjectedPermutation() &&
678 "expected map to have projected permutations");
679 llvm::SmallDenseSet<int64_t> preservedDims;
680 for (auto expr : map.getResults())
681 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
682 return preservedDims;
683 }
684
685 static SmallVector<int64_t, 2>
686 getConstantsFromExprList(const SmallVector<AffineExpr, 2> &exprs) {
687 SmallVector<int64_t, 2> vals;
688 for (auto e : exprs) {
689 auto constantExpr = dyn_cast<AffineConstantExpr>(e);
690 assert(constantExpr && "Found non-constant stride/dilation");
691 vals.push_back(constantExpr.getValue());
692 }
693 return vals;
694 }
695
696
697
698
699
700
701
702
703 static FailureOr<ConvolutionDimensions>
704 inferConvolutionDimsImpl(LinalgOp linalgOp,
705 ConvAccessExprWalker &inputExprWalker,
706 bool allowEmptyConvolvedDims) {
707 auto filterMap =
708 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
709 auto outputMap =
710 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
711 llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
712 filterMap, linalgOp.getIteratorTypesArray(), par);
713 llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
714 outputMap, linalgOp.getIteratorTypesArray(), par);
715
716
717 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
718 llvm::set_intersect(batch, outputDims);
719 llvm::set_subtract(batch, filterDims);
720
721
722 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
723 llvm::set_intersect(oi, outputDims);
724
725
726 llvm::SmallDenseSet<int64_t> oc = filterDims;
727 llvm::set_intersect(oc, outputDims);
728 llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
729
730
731 llvm::SmallDenseSet<int64_t> depth = filterDims;
732 llvm::set_intersect(depth, outputDims);
733 llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
734
735 llvm::SmallDenseSet<int64_t> filterReducedDims =
736 findPermutationsIndexingOperand(filterMap,
737 linalgOp.getIteratorTypesArray(), red);
738
739
740 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
741 llvm::set_intersect(fl, filterReducedDims);
742
743
744 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
745 llvm::set_intersect(ic, filterReducedDims);
746
747 if (oi.empty() && !allowEmptyConvolvedDims)
748 return failure();
749
750
751 ConvolutionDimensions dimensions{
752 SmallVector<unsigned, 2>(batch.begin(), batch.end()),
753 SmallVector<unsigned, 2>(oi.begin(), oi.end()),
754 SmallVector<unsigned, 2>(oc.begin(), oc.end()),
755 SmallVector<unsigned, 2>(fl.begin(), fl.end()),
756 SmallVector<unsigned, 2>(ic.begin(), ic.end()),
757 SmallVector<unsigned, 2>(depth.begin(), depth.end()),
758 /*strides=*/SmallVector<int64_t, 2>{},
759 /*dilations=*/SmallVector<int64_t, 2>{}};
760 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
761 llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
762 llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
763 llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
764 llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
765 llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
766
767
768 auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
769 if (!nativeStrides) {
770 SmallVector<AffineExpr, 2> strideExprs;
771 for (unsigned oiDim : dimensions.outputImage)
772 strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
773 dimensions.strides = getConstantsFromExprList(strideExprs);
774 } else {
775 dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
776 }
777 auto nativeDilations =
778 linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
779 if (!nativeDilations) {
780 SmallVector<AffineExpr, 2> dilationExprs;
781 for (unsigned flDim : dimensions.filterLoop)
782 dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
783 dimensions.dilations = getConstantsFromExprList(dilationExprs);
784 } else {
785 dimensions.dilations =
786 llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
787 }
788 return dimensions;
789 }
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815 FailureOr<ConvolutionDimensions>
816 mlir::linalg::inferConvolutionDims(LinalgOp linalgOp) {
817 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
818 return failure();
819
820 auto indexingMaps = linalgOp.getIndexingMapsArray();
821
822
823 ConvAccessExprWalker inputExprWalker;
824 for (AffineExpr expr : indexingMaps[0].getResults())
825 (void)inputExprWalker.visit(expr);
826 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
827
828 return inferConvolutionDimsImpl(linalgOp, inputExprWalker,
829 /*allowEmptyConvolvedDims=*/false);
830 }
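Illustrative sketch (not part of the listing; the helper name is an assumption): inspecting the loop classification and the recovered strides/dilations of a convolution-like op.

// Hypothetical helper built on linalg::inferConvolutionDims.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"

static void printConvolutionDims(mlir::linalg::LinalgOp linalgOp) {
  mlir::FailureOr<mlir::linalg::ConvolutionDimensions> dims =
      mlir::linalg::inferConvolutionDims(linalgOp);
  if (mlir::failed(dims))
    return;
  llvm::outs() << "output image dims: ";
  llvm::interleaveComma(dims->outputImage, llvm::outs());
  llvm::outs() << "\nfilter loop dims: ";
  llvm::interleaveComma(dims->filterLoop, llvm::outs());
  llvm::outs() << "\nstrides: ";
  llvm::interleaveComma(dims->strides, llvm::outs());
  llvm::outs() << "\ndilations: ";
  llvm::interleaveComma(dims->dilations, llvm::outs());
  llvm::outs() << "\n";
}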
831
843 };
844 }
845
846 mlir::linalg::detail::MatchConvolutionResult
847 mlir::linalg::detail::isConvolutionInterfaceImpl(
848 Operation *op, ConvolutionDimensions *dimensions,
849 bool allowEmptyConvolvedDims) {
850 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
851 if (!linalgOp)
852 return MatchConvolutionResult::NotLinalgOp;
853 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
854 return MatchConvolutionResult::WrongNumOperands;
855
856 auto indexingMaps = linalgOp.getIndexingMapsArray();
857
858
859 ConvAccessExprWalker inputExprWalker;
860 if (llvm::any_of(indexingMaps[0].getResults(),
861 [&inputExprWalker](AffineExpr expr) {
862 return failed(inputExprWalker.visit(expr));
863 })) {
864 return MatchConvolutionResult::WrongInputIndexingMap;
865 }
866
867
868 if (!indexingMaps[1].isProjectedPermutation() ||
869 !indexingMaps.back().isProjectedPermutation())
870 return MatchConvolutionResult::NotProjectedPermutations;
871
872 auto iteratorTypes = linalgOp.getIteratorTypesArray();
873
874 llvm::SmallDenseSet<int64_t> outputDims =
875 getPreservedDims(indexingMaps.back());
876 llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
877
878
879
880
881
882
883
884
885
886
887
888
889
890 llvm::SmallDenseSet<int64_t> allLoopDims;
891 for (auto outputExpr : indexingMaps.back().getResults()) {
892 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
893 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
894 !filterDims.count(outputDim)) {
895
896 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
897 return MatchConvolutionResult::OutputDimsNotParallel;
898 allLoopDims.insert(outputDim);
899 continue;
900 }
901 if (inputExprWalker.convolvedDims.count(outputDim) &&
902 !filterDims.count(outputDim)) {
903
904 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
905 return MatchConvolutionResult::OutputDimsNotParallel;
906 allLoopDims.insert(outputDim);
907 continue;
908 }
909 if (!inputExprWalker.convolvedDims.count(outputDim) &&
910 !inputExprWalker.unConvolvedDims.count(outputDim) &&
911 filterDims.count(outputDim)) {
912
913 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
914 return MatchConvolutionResult::OutputDimsNotParallel;
915 allLoopDims.insert(outputDim);
916 continue;
917 }
918 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
919 filterDims.count(outputDim)) {
920
921 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
922 return MatchConvolutionResult::OutputDimsNotParallel;
923 allLoopDims.insert(outputDim);
924 continue;
925 }
926 return MatchConvolutionResult::NonConvolutionLoop;
927 }
928 for (auto filterExpr : indexingMaps[1].getResults()) {
929 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
930 if (outputDims.count(filterDim) &&
931 !inputExprWalker.unConvolvedDims.count(filterDim) &&
932 !inputExprWalker.convolvedDims.count(filterDim)) {
933
934 continue;
935 }
936 if (inputExprWalker.convolvedDims.count(filterDim) &&
937 !outputDims.count(filterDim)) {
938
939 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
940 return MatchConvolutionResult::NonOutputDimNotReduction;
941 if (allLoopDims.count(filterDim))
942 return MatchConvolutionResult::NonConvolutionLoop;
943 allLoopDims.insert(filterDim);
944 continue;
945 }
946 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
947 !outputDims.count(filterDim)) {
948
949 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
950 return MatchConvolutionResult::NonOutputDimNotReduction;
951 if (allLoopDims.count(filterDim))
952 return MatchConvolutionResult::NonConvolutionLoop;
953 allLoopDims.insert(filterDim);
954 continue;
955 }
956 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
957 outputDims.count(filterDim)) {
958
959 continue;
960 }
961 return MatchConvolutionResult::NonConvolutionLoop;
962 }
963
964 if (allLoopDims.size() != linalgOp.getNumLoops())
965 return MatchConvolutionResult::NonConvolutionLoop;
966
967 if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
968 return MatchConvolutionResult::EmptyConvolvedDims;
969
970 if (dimensions) {
971 FailureOr<ConvolutionDimensions> res = inferConvolutionDimsImpl(
972 linalgOp, inputExprWalker, allowEmptyConvolvedDims);
973 assert(succeeded(res) && "unexpected failure to infer convolution dims");
974 *dimensions = *res;
975 }
976
977 return MatchConvolutionResult::Success;
978 }
979
980 StringRef
981 mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) {
982 switch (res) {
983 case MatchConvolutionResult::NotLinalgOp:
984 return "expected a LinalgOp";
985 case MatchConvolutionResult::WrongNumOperands:
986 return "expected op with 2 inputs and 1 output";
987 case MatchConvolutionResult::WrongInputIndexingMap:
988 return "unexpected input index map for convolutions";
989 case MatchConvolutionResult::NotProjectedPermutations:
990 return "expected output/filter indexing maps to be projected permutations";
991 case MatchConvolutionResult::NonConvolutionLoop:
992 return "unexpected loop dimension for convolution op";
993 case MatchConvolutionResult::OutputDimsNotParallel:
994 return "expected all iterators used to access outputs to be parallel";
995 case MatchConvolutionResult::NonOutputDimNotReduction:
996 return "expected all iterators not used to access outputs to be reduction";
997 case MatchConvolutionResult::EmptyConvolvedDims:
998 return "expected convolved dim to be non-empty";
999 case MatchConvolutionResult::Success:
1000 return "";
1001 }
1002 llvm_unreachable("unhandled MatchConvolutionResult case");
1003 }
1004
1005 bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp,
1006 bool allowEmptyConvolvedDims) {
1007 return linalg::detail::isConvolutionInterfaceImpl(
1008 linalgOp.getOperation(), nullptr, allowEmptyConvolvedDims) ==
1009 linalg::detail::MatchConvolutionResult::Success;
1010 }
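Illustrative sketch (not part of the listing; the helper name is an assumption): the allowEmptyConvolvedDims flag relaxes the matcher so that degenerate convolutions without any convolved dimension still qualify.

// Hypothetical helper built on linalg::isaConvolutionOpInterface.
#include "mlir/Dialect/Linalg/IR/Linalg.h"

static bool isConvLike(mlir::linalg::LinalgOp linalgOp) {
  return mlir::linalg::isaConvolutionOpInterface(
      linalgOp, /*allowEmptyConvolvedDims=*/true);
}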
1011
1012 LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
1013 MatchConvolutionResult res = isConvolutionInterfaceImpl(op);
1014 if (res != MatchConvolutionResult::Success)
1015 return op->emitError(getMatchConvolutionMessage(res));
1016 return success();
1017 }
1018
1019
1020
1021
1022
1023 enum class MatchFillResult {
1024 Success = 0,
1025 NotLinalgOp,
1026 WrongNumOperands,
1027 NotScalarInput
1028 };
1029
1030 static MatchFillResult isFillInterfaceImpl(Operation *op) {
1031 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1032 if (!linalgOp)
1033 return MatchFillResult::NotLinalgOp;
1034 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1035 return MatchFillResult::WrongNumOperands;
1036
1037 OpOperand *value = linalgOp.getDpsInputOperand(0);
1038 if (!linalgOp.isScalar(value))
1039 return MatchFillResult::NotScalarInput;
1040
1041 return MatchFillResult::Success;
1042 }
1043
1044 LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
1045 auto res = isFillInterfaceImpl(op);
1046 if (res == MatchFillResult::NotLinalgOp)
1047 return op->emitError("expected a LinalgOp");
1048 if (res == MatchFillResult::WrongNumOperands)
1049 return op->emitError("expected op with 1 input and 1 output");
1050 if (res == MatchFillResult::NotScalarInput)
1051 return op->emitError("expected op with scalar input");
1052
1053 return success();
1054 }
1055
1056
1057
1058
1059
1060 SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
1061 Location loc) {
1062 SmallVector<OpFoldResult> res;
1063 for (OpOperand &opOperand : getOperation()->getOpOperands()) {
1064 for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1065 res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
1066 }
1067 return res;
1068 }
1069
1070 SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
1071 SmallVector<int64_t, 4> res;
1072 assert(!hasDynamicShape() && "expected operands to have static shapes");
1073 for (OpOperand &opOperand : getOperation()->getOpOperands())
1074 llvm::append_range(res, getShape(&opOperand));
1075 return res;
1076 }
1077
1078 SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
1079 AffineMap map = getLoopsToShapesMap();
1080 unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
1081 auto viewSizes = createFlatListOfOperandDims(b, loc);
1082 SmallVector<Range, 4> res(numDims);
1083 for (unsigned idx = 0; idx < numRes; ++idx) {
1084 auto result = map.getResult(idx);
1085 if (auto d = dyn_cast<AffineDimExpr>(result)) {
1086 if (res[d.getPosition()].offset)
1087 continue;
1088 res[d.getPosition()] =
1089 Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)};
1090 }
1091 }
1092 return res;
1093 }
1094
1095
1096
1097 struct HasAffineDimExprVisitor
1098 : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
1099 HasAffineDimExprVisitor(llvm::SmallBitVector positions)
1100 : positions(std::move(positions)) {}
1101
1102 bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
1103 return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
1104 }
1105
1106 bool visitDimExpr(AffineDimExpr dimExpr) {
1107 return positions.test(dimExpr.getPosition());
1108 }
1109
1110 bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
1111
1112 bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
1113
1114 private:
1115 llvm::SmallBitVector positions;
1116 };
1117
1118 static std::pair<int64_t, int64_t>
1119 getResultsPositionInLoopsToShapeMap(LinalgOp &op) {
1120 int64_t inputRankSum = 0;
1121 int64_t outputRankSum = 0;
1122 for (OpOperand *input : op.getDpsInputOperands())
1123 inputRankSum += op.getRank(input);
1124 for (OpOperand &output : op.getDpsInitsMutable())
1125 outputRankSum += op.getRank(&output);
1126 return {inputRankSum, inputRankSum + outputRankSum};
1127 }
1128
1129 LogicalResult
1130 LinalgOp::reifyResultShapes(OpBuilder &b,
1131 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141 AffineMap loopsToShapesMap = getLoopsToShapesMap();
1142
1143
1144
1145 auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);
1146
1147
1148
1149 AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
1150 resultShapesSubMapPos.first,
1151 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1152 AffineMap resultShapesFromInputShapesMap =
1153 loopToResultsShapeMap.compose(getShapesToLoopsMap());
1154
1155
1156
1157 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
1158 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1159 HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
1160 Location loc = getOperation()->getLoc();
1161 IRRewriter rewriter(b);
1162 SmallVector<OpFoldResult> allResultDimValues =
1163 affine::makeComposedFoldedMultiResultAffineApply(
1164 rewriter, loc, resultShapesFromInputShapesMap,
1165 createFlatListOfOperandDims(b, loc));
1166 int64_t pos = 0;
1167 ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
1168 for (OpOperand &opOperand : getDpsInitsMutable()) {
1169 SmallVector<OpFoldResult> shapes;
1170 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1171 auto shapedType = llvm::cast(opOperand.get().getType());
1172 if (!shapedType.isDynamicDim(dim)) {
1173
1174 shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
1175 } else {
1176
1177 OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
1178 ? createOrFoldDimOp(b, loc, opOperand.get(), dim)
1179 : allResultDimValues[pos];
1180 shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
1181 }
1182 pos++;
1183 }
1184 reifiedReturnShapes.emplace_back(std::move(shapes));
1185 }
1186 return success();
1187 }
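Illustrative sketch (not part of the listing; the helper name is an assumption): callers usually reach the reification above through the generic reifyResultShapes entry point rather than calling the interface method directly.

// Hypothetical helper built on the ReifyRankedShapedTypeOpInterface utility.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

static mlir::LogicalResult
getReifiedShapes(mlir::OpBuilder &b, mlir::linalg::LinalgOp linalgOp,
                 mlir::ReifiedRankedShapedTypeDims &shapes) {
  // Each entry holds one OpFoldResult per dimension of the matching result.
  return mlir::reifyResultShapes(b, linalgOp.getOperation(), shapes);
}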
1188
1189
1190
1191 int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
1192 auto operandNumber = opOperand->getOperandNumber();
1193 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1194 if (!dpsIface.isDpsInput(opOperand))
1195 return operandNumber;
1196 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1197 assert(!dpsIface.isDpsInit(opOperand));
1198
1199
1200 return cast<DestinationStyleOpInterface>(*this->getOperation())
1201 .getNumDpsInputs() +
1202 operandNumber - start;
1203 }
1204
1205 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
1206 LinalgOp linalgOp = cast<LinalgOp>(op);
1207
1208 if (!linalgOp.hasPureTensorSemantics() &&
1209 !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
1210 return op->emitOpError("expected to have pure tensor or buffer semantics");
1211
1212
1213
1214 if (linalgOp.hasDynamicIndexingMaps())
1215 if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1216 return failure();
1217
1218
1219 if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
1220 linalgOp->getNumOperands())
1221 return op->emitOpError("expected the number of indexing_map (")
1222 << linalgOp.getIndexingMapsArray().size()
1223 << ") to be equal to the number of input/output operands ("
1224 << linalgOp->getNumOperands() << ")";
1225
1226
1227
1228 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
1229 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1230
1231 // Symbols are not allowed in indexing maps.
1232 if (indexingMap.getNumSymbols() != 0)
1233 return op->emitOpError("unexpected symbols in indexing_map #")
1234 << opOperand.getOperandNumber();
1235
1236
1237 unsigned numLoops = linalgOp.getNumLoops();
1238 if (indexingMap.getNumDims() != numLoops)
1239 return op->emitOpError("expected indexing_map #")
1240 << opOperand.getOperandNumber() << " to have " << numLoops
1241 << " dim(s) to match the number of loops";
1242
1243 int64_t rank = linalgOp.getRank(&opOperand);
1244
1245 if (indexingMap.getNumResults() != rank)
1246 return op->emitOpError("expected operand rank (")
1247 << rank << ") to match the result rank of indexing_map #"
1248 << opOperand.getOperandNumber() << " ("
1249 << indexingMap.getNumResults() << ")";
1250 }
1251 SmallVector<unsigned> redDims;
1252 linalgOp.getReductionDims(redDims);
1253
1254 if (!linalgOp.getShapesToLoopsMap())
1255 return op->emitOpError("expected the shape-to-loops map to be non-null");
1256
1257
1258 SmallVector<int64_t> startLoopRangeValues(linalgOp.getNumLoops(), 0);
1259 SmallVector<int64_t> endLoopRangeValues = linalgOp.getStaticLoopRanges();
1260
1261
1262 if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
1263 for (int64_t &range : endLoopRangeValues)
1264 range -= 1;
1265 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
1266 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1267 SmallVector<int64_t, 4> startIndices =
1268 indexingMap.compose(startLoopRangeValues);
1269 SmallVector<int64_t, 4> endIndices =
1270 indexingMap.compose(endLoopRangeValues);
1271 ArrayRef<int64_t> shape = linalgOp.getShape(&opOperand);
1272 for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
1273
1274 if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
1275 continue;
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287 int64_t inferredDimSize =
1288 std::max(startIndices[dim], endIndices[dim]) + 1;
1289 if (std::min(startIndices[dim], endIndices[dim]) < 0) {
1290 std::string mapStr;
1291 {
1292 llvm::raw_string_ostream os(mapStr);
1293 os << indexingMap;
1294 }
1295 return op->emitOpError(
1296 "unexpected result less than 0 at expression #")
1297 << dim << " in " << mapStr;
1298 }
1299 if (isa<AffineDimExpr>(indexingMap.getResult(dim))) {
1300 if (inferredDimSize != shape[dim]) {
1301 return op->emitOpError("inferred input/output operand #")
1302 << opOperand.getOperandNumber() << " has shape's dimension #"
1303 << dim << " to be " << inferredDimSize << ", but found "
1304 << shape[dim];
1305 }
1306 } else {
1307 if (inferredDimSize > shape[dim]) {
1308 return op->emitOpError("inferred input/output operand #")
1309 << opOperand.getOperandNumber() << " has shape's dimension #"
1310 << dim << " to be greater than or equal to "
1311 << inferredDimSize << ", but found " << shape[dim];
1312 }
1313 }
1314 }
1315 }
1316 }
1317
1318
1319 if (linalgOp->getNumRegions() != 1 ||
1320 !llvm::hasSingleElement(linalgOp->getRegion(0)))
1321 return op->emitOpError("expects to have 1 region with 1 block");
1322
1323
1324
1325
1326
1327
1328
1329 Block &block = linalgOp->getRegion(0).front();
1330
1331 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
1332 return op->emitOpError("expected as many non-induction variable region "
1333 "arguments as the number of input/output operands");
1334
1335 for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1336 Type elementType = opOperand->get().getType();
1337 if (isa<MemRefType, RankedTensorType>(elementType))
1338 elementType = getElementTypeOrSelf(opOperand->get().getType());
1339 Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
1340 if (elementType != argType)
1341 return op->emitOpError("expected type of bb argument #")
1342 << opOperand->getOperandNumber() << " (" << argType << ")"
1343 << " to match element or self type of the corresponding operand ("
1344 << elementType << ")";
1345 }
1346
1347 return success();
1348 }