MLIR: lib/Dialect/Vector/Transforms/VectorDistribute.cpp Source File
1 //===- VectorDistribute.cpp - patterns to do vector distribution ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for more information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/Support/FormatVariadic.h"
22 #include <utility>
23
24 using namespace mlir;
25 using namespace mlir::vector;
26 using namespace mlir::gpu;
27
28
29
30
31
32
33
34
35
36
37
38
39 static AffineMap calculateImplicitMap(VectorType sequentialType,
40 VectorType distributedType) {
41 SmallVector<AffineExpr> perm;
42 perm.reserve(1);
43
44
45
46 for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
47 if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
48 perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
49 }
50 auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
51 distributedType.getContext());
52 return map;
53 }
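// Worked example (a sketch derived from the loop above, not a quote from the
// original comments): for sequentialType = vector<32x16x64xf32> and
// distributedType = vector<1x16x2xf32>, dims 0 and 2 differ, so the implicit
// map returned is (d0, d1, d2) -> (d0, d2).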
54
55 namespace {
56
57
58
59
60
61
62 struct DistributedLoadStoreHelper {
63 DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
64 Value laneId, Value zero)
65 : sequentialVal(sequentialVal), distributedVal(distributedVal),
66 laneId(laneId), zero(zero) {
67 sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
68 distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
69 if (sequentialVectorType && distributedVectorType)
70 distributionMap =
71 calculateImplicitMap(sequentialVectorType, distributedVectorType);
72 }
73
74 Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
75 int64_t distributedSize = distributedVectorType.getDimSize(index);
76 AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
77 return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
78 ArrayRef<Value>{laneId});
79 }
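// Sketch of the offset computed above: the affine.apply folds
// `laneId * distributedSize`, so with a distributed dim size of 2, lane 5
// addresses the sequential buffer at offset 10 along that dimension.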
80
81
82
83
84
85
86
87
88 Operation *buildStore(RewriterBase &b, Location loc, Value val,
89 Value buffer) {
90 assert((val == distributedVal || val == sequentialVal) &&
91 "Must store either the preregistered distributed or the "
92 "preregistered sequential value.");
93
94 if (!isa<VectorType>(val.getType()))
95 return b.create<memref::StoreOp>(loc, val, buffer, zero);
96
97
98
99 int64_t rank = sequentialVectorType.getRank();
100 SmallVector<Value> indices(rank, zero);
101 if (val == distributedVal) {
102 for (auto dimExpr : distributionMap.getResults()) {
103 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
104 indices[index] = buildDistributedOffset(b, loc, index);
105 }
106 }
108 return b.create<vector::TransferWriteOp>(
109 loc, val, buffer, indices,
111 }
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133 Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {
134
135
136 if (!isa<VectorType>(type))
137 return b.create<memref::LoadOp>(loc, buffer, zero);
138
139
140
141
142 assert((type == distributedVectorType || type == sequentialVectorType) &&
143 "Must store either the preregistered distributed or the "
144 "preregistered sequential type.");
145 SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
146 if (type == distributedVectorType) {
147 for (auto dimExpr : distributionMap.getResults()) {
148 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
149 indices[index] = buildDistributedOffset(b, loc, index);
150 }
151 }
153 return b.create<vector::TransferReadOp>(
154 loc, cast<VectorType>(type), buffer, indices,
156 }
157
158 Value sequentialVal, distributedVal, laneId, zero;
159 VectorType sequentialVectorType, distributedVectorType;
160 AffineMap distributionMap;
161 };
162
163 }
164
165
166
167 static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
168 Location loc, Operation *op,
169 ArrayRef<Value> operands,
170 ArrayRef<Type> resultTypes) {
171 OperationState res(loc, op->getName().getStringRef(), operands,
172 resultTypes, op->getAttrs());
173 return rewriter.create(res);
174 }
175
176 namespace {
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
207 WarpOpToScfIfPattern(MLIRContext *context,
211
212 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
213 PatternRewriter &rewriter) const override {
214 assert(warpOp.getBodyRegion().hasOneBlock() &&
215 "expected WarpOp with single block");
216 Block *warpOpBody = &warpOp.getBodyRegion().front();
217 Location loc = warpOp.getLoc();
218
219
222
223
224 Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
225 Value isLane0 = rewriter.create<arith::CmpIOp>(
226 loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
227 auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
228 false);
229 rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
230
231
232
234 for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
236 Value distributedVal = it.value();
237 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
238 warpOp.getLaneid(), c0);
239
240
242 Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
243 sequentialVal.getType());
244
245 helper.buildStore(rewriter, loc, distributedVal, buffer);
246
248 bbArgReplacements.push_back(
249 helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
250 }
251
252
253 if (!warpOp.getArgs().empty()) {
255 options.warpSyncronizationFn(loc, rewriter, warpOp);
256 }
257
258
259 rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
260
261
262
263
264
265 SmallVector<Value> replacements;
266 auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
267 Location yieldLoc = yieldOp.getLoc();
268 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
269 Value sequentialVal = it.value();
270 Value distributedVal = warpOp->getResult(it.index());
271 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
272 warpOp.getLaneid(), c0);
273
274
276 Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
277 sequentialVal.getType());
278
279
280
282 helper.buildStore(rewriter, loc, sequentialVal, buffer);
283
284
286
287
288
289
290
291
292
293 replacements.push_back(
294 helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
295 }
296
297
298 if (!yieldOp.getOperands().empty()) {
300 options.warpSyncronizationFn(loc, rewriter, warpOp);
301 }
302
303
304 rewriter.eraseOp(yieldOp);
306 rewriter.create<scf::YieldOp>(yieldLoc);
307
308
309 rewriter.replaceOp(warpOp, replacements);
310
311 return success();
312 }
313
314 private:
315 const WarpExecuteOnLane0LoweringOptions &options;
316 };
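// Minimal sketch of this lowering (illustrative IR, not verbatim; the buffer
// and synchronization details come from the user-provided options):
//
//   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
//     ...
//     gpu.yield %v : vector<32xf32>
//   }
//
// becomes, roughly:
//
//   %buf = ...                       // options.warpAllocationFn, memref<32xf32>
//   scf.if %laneid == %c0 {
//     ...                            // original body, executed by lane 0 only
//     vector.transfer_write %v, %buf[%c0]
//   }
//   ...                              // options.warpSyncronizationFn
//   %r = vector.transfer_read %buf[%laneid], ... : memref<32xf32>, vector<1xf32>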
317
318
319
320
321
322
323
324
325
326 static VectorType getDistributedType(VectorType originalType, AffineMap map,
327 int64_t warpSize) {
329 for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
330 unsigned position = map.getDimPosition(i);
331 if (targetShape[position] % warpSize != 0) {
332 if (warpSize % targetShape[position] != 0) {
333 return VectorType();
334 }
335 warpSize /= targetShape[position];
336 targetShape[position] = 1;
337 continue;
338 }
339 targetShape[position] = targetShape[position] / warpSize;
340 warpSize = 1;
341 break;
342 }
343 if (warpSize != 1) {
344 return VectorType();
345 }
346 VectorType targetType =
347 VectorType::get(targetShape, originalType.getElementType());
348 return targetType;
349 }
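// Worked example (sketch): originalType = vector<16x32xf32>,
// map = (d0, d1) -> (d1), warpSize = 32 gives vector<16x1xf32>. With
// warpSize = 64 the same dim (32) is not divisible, 64 % 32 == 0 so the dim
// becomes 1, but the leftover factor 2 is never consumed and an empty
// VectorType is returned to signal failure.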
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
372 unsigned maxNumElementsToExtract, PatternBenefit b = 1)
374 maxNumElementsToExtract(maxNumElementsToExtract) {}
375
376
377
378 LogicalResult tryDistributeOp(RewriterBase &rewriter,
379 vector::TransferWriteOp writeOp,
380 WarpExecuteOnLane0Op warpOp) const {
381 VectorType writtenVectorType = writeOp.getVectorType();
382
383
384
385 if (writtenVectorType.getRank() == 0)
386 return failure();
387
388
389 AffineMap map = distributionMapFn(writeOp.getVector());
390 VectorType targetType =
391 getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
392 if (!targetType)
393 return failure();
394
395
396 VectorType maskType;
397 if (writeOp.getMask()) {
398
399
400
401
402
403
404 if (!writeOp.getPermutationMap().isMinorIdentity())
405 return failure();
406 maskType =
407 getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
408 }
409
410
411
412 vector::TransferWriteOp newWriteOp =
413 cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
414
415
416 auto newWarpOp =
417 newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
418
419
420
421
424 for (auto [seqSize, distSize] :
425 llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
426 assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
427 delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
428 }
431 delinearized = rewriter
432 .create<mlir::affine::AffineDelinearizeIndexOp>(
433 newWarpOp.getLoc(), newWarpOp.getLaneid(),
434 delinearizedIdSizes)
435 .getResults();
436 } else {
437
438
439 delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
440 }
441
442 AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
443 Location loc = newWriteOp.getLoc();
444 SmallVector<Value> indices(newWriteOp.getIndices().begin(),
445 newWriteOp.getIndices().end());
448 bindDims(newWarpOp.getContext(), d0, d1);
449 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
450 if (!indexExpr)
451 continue;
452 unsigned indexPos = indexExpr.getPosition();
453 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
454 Value laneId = delinearized[vectorPos];
455 auto scale =
458 rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
459 }
460 newWriteOp.getIndicesMutable().assign(indices);
461
462 return success();
463 }
464
465
466 LogicalResult tryExtractOp(RewriterBase &rewriter,
467 vector::TransferWriteOp writeOp,
468 WarpExecuteOnLane0Op warpOp) const {
469 Location loc = writeOp.getLoc();
470 VectorType vecType = writeOp.getVectorType();
471
472 if (vecType.getNumElements() > maxNumElementsToExtract) {
473 return rewriter.notifyMatchFailure(
474 warpOp,
475 llvm::formatv(
476 "writes more elements ({0}) than allowed to extract ({1})",
477 vecType.getNumElements(), maxNumElementsToExtract));
478 }
479
480
481 if (llvm::all_of(warpOp.getOps(),
482 llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
483 return failure();
484
488 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
489 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
491
492
493 auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
494 loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
495 Block &body = secondWarpOp.getBodyRegion().front();
497 auto newWriteOp =
498 cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
499 newWriteOp.getValueToStoreMutable().assign(
500 newWarpOp.getResult(newRetIndices[0]));
501 rewriter.eraseOp(writeOp);
502 rewriter.create<gpu::YieldOp>(newWarpOp.getLoc());
503 return success();
504 }
505
506 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
507 PatternRewriter &rewriter) const override {
508 auto yield = cast<gpu::YieldOp>(
509 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
510 Operation *lastNode = yield->getPrevNode();
511 auto writeOp = dyn_cast_or_nullvector::TransferWriteOp(lastNode);
512 if (!writeOp)
513 return failure();
514
515 Value maybeMask = writeOp.getMask();
516 if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
517 return writeOp.getVector() == value ||
518 (maybeMask && maybeMask == value) ||
519 warpOp.isDefinedOutsideOfRegion(value);
520 }))
521 return failure();
522
523 if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
524 return success();
525
526
527 if (writeOp.getMask())
528 return failure();
529
530 if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
531 return success();
532
533 return failure();
534 }
535
536 private:
537
538
539
540 vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
541 WarpExecuteOnLane0Op warpOp,
542 vector::TransferWriteOp writeOp,
543 VectorType targetType,
544 VectorType maybeMaskType) const {
545 assert(writeOp->getParentOp() == warpOp &&
546 "write must be nested immediately under warp");
549 WarpExecuteOnLane0Op newWarpOp;
550 if (maybeMaskType) {
551 newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
552 rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
553 TypeRange{targetType, maybeMaskType}, newRetIndices);
554 } else {
555 newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
556 rewriter, warpOp, ValueRange{{writeOp.getVector()}},
557 TypeRange{targetType}, newRetIndices);
558 }
560 auto newWriteOp =
561 cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
562 rewriter.eraseOp(writeOp);
563 newWriteOp.getValueToStoreMutable().assign(
564 newWarpOp.getResult(newRetIndices[0]));
565 if (maybeMaskType)
566 newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
567 return newWriteOp;
568 }
569
570 DistributionMapFn distributionMapFn;
571 unsigned maxNumElementsToExtract = 1;
572 };
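// Minimal sketch of tryDistributeOp (illustrative IR, assuming warp size 32
// and a distribution map (d0) -> (d0)):
//
//   gpu.warp_execute_on_lane_0(%laneid)[32] {
//     ...
//     vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
//     gpu.yield
//   }
//
// becomes, roughly:
//
//   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
//     ...
//     gpu.yield %v : vector<32xf32>
//   }
//   vector.transfer_write %r, %A[%laneid] : vector<1xf32>, memref<128xf32>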
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
593 using Base::Base;
594 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
598 });
599 if (!yieldOperand)
600 return failure();
601
604 Value distributedVal = warpOp.getResult(operandIndex);
607 Location loc = warpOp.getLoc();
609 Type targetType;
610 if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
611
612 auto operandType = cast<VectorType>(operand.get().getType());
613 targetType =
614 VectorType::get(vecType.getShape(), operandType.getElementType());
615 } else {
616 auto operandType = operand.get().getType();
617 assert(!isa<VectorType>(operandType) &&
618 "unexpected yield of vector from op with scalar result type");
619 targetType = operandType;
620 }
621 retTypes.push_back(targetType);
622 yieldValues.push_back(operand.get());
623 }
624 SmallVector<size_t> newRetIndices;
625 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
626 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
630 for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
631 newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
632 }
636 rewriter, loc, elementWise, newOperands,
637 {newWarpOp.getResult(operandIndex).getType()});
640 return success();
641 }
642 };
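// Minimal sketch (illustrative IR): an elementwise op feeding the yield is
// rebuilt outside the warp op on the distributed operands.
//
//   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
//     %0 = "some_def"() : () -> vector<32xf32>
//     %1 = "some_def"() : () -> vector<32xf32>
//     %2 = arith.addf %0, %1 : vector<32xf32>
//     gpu.yield %2 : vector<32xf32>
//   }
//
// becomes, roughly:
//
//   %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32]
//       -> (vector<1xf32>, vector<1xf32>) {
//     ...
//     gpu.yield %0, %1 : vector<32xf32>, vector<32xf32>
//   }
//   %sum = arith.addf %r#0, %r#1 : vector<1xf32>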
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
659 using Base::Base;
660 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
661 PatternRewriter &rewriter) const override {
662 OpOperand *yieldOperand =
663 getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
664 if (!yieldOperand)
665 return failure();
666 auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
667 auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
668 if (!dense)
669 return failure();
670
671
676 cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
677 Location loc = warpOp.getLoc();
679 Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
680 rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
682 return success();
683 }
684 };
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
705 using Base::Base;
706 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
707 PatternRewriter &rewriter) const override {
708
709
710
712
713 return isa<vector::TransferReadOp>(op) && op->hasOneUse();
714 });
715 if (!operand)
716 return rewriter.notifyMatchFailure(
717 warpOp, "warp result is not a vector.transfer_read op");
718 auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
719
720
721 if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
722 return rewriter.notifyMatchFailure(
723 read, "source must be defined outside of the region");
724
725 unsigned operandIndex = operand->getOperandNumber();
726 Value distributedVal = warpOp.getResult(operandIndex);
727
728 SmallVector<Value> indices(read.getIndices().begin(),
729 read.getIndices().end());
730 auto sequentialType = cast<VectorType>(read.getResult().getType());
731 auto distributedType = cast<VectorType>(distributedVal.getType());
734
735
736
737 SmallVector<Value> delinearizedIds;
738 if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
739 distributedType.getShape(), warpOp.getWarpSize(),
740 warpOp.getLaneid(), delinearizedIds)) {
741 return rewriter.notifyMatchFailure(
742 read, "cannot delinearize lane ID for distribution");
743 }
744 assert(!delinearizedIds.empty() || map.getNumResults() == 0);
745
746
748 SmallVector<Value> additionalResults(indices.begin(), indices.end());
751 additionalResults.push_back(read.getPadding());
752 additionalResultTypes.push_back(read.getPadding().getType());
753
754 bool hasMask = false;
755 if (read.getMask()) {
756 hasMask = true;
757
758
759
760
761
762
765 read, "non-trivial permutation maps not supported");
766 VectorType maskType =
767 getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
768 additionalResults.push_back(read.getMask());
769 additionalResultTypes.push_back(maskType);
770 }
771
772 SmallVector<size_t> newRetIndices;
773 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
774 rewriter, warpOp, additionalResults, additionalResultTypes,
775 newRetIndices);
776 distributedVal = newWarpOp.getResult(operandIndex);
777
778
779 SmallVector<Value> newIndices;
780 for (int64_t i = 0, e = indices.size(); i < e; ++i)
781 newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));
782
786 bindDims(read.getContext(), d0, d1);
787 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
788 if (!indexExpr)
789 continue;
790 unsigned indexPos = indexExpr.getPosition();
791 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
792 int64_t scale = distributedType.getDimSize(vectorPos);
794 rewriter, read.getLoc(), d0 + scale * d1,
795 {newIndices[indexPos], delinearizedIds[vectorPos]});
796 }
797
798
799 Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
800
801 Value newMask =
802 hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
803 : Value();
804 auto newRead = rewriter.create<vector::TransferReadOp>(
805 read.getLoc(), distributedVal.getType(), read.getBase(), newIndices,
806 read.getPermutationMapAttr(), newPadding, newMask,
807 read.getInBoundsAttr());
808
809 rewriter.replaceAllUsesWith(distributedVal, newRead);
810 return success();
811 }
812 };
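// Minimal sketch (illustrative IR): a transfer_read whose source is defined
// outside the region is moved after the warp op and indexed by the
// (delinearized) lane id.
//
//   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
//     %0 = vector.transfer_read %A[%c0], %cst : memref<128xf32>, vector<32xf32>
//     gpu.yield %0 : vector<32xf32>
//   }
//
// becomes, roughly:
//
//   %r = vector.transfer_read %A[%laneid], %cst : memref<128xf32>, vector<1xf32>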
813
814
815
817 using Base::Base;
818 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
821 newResultTypes.reserve(warpOp->getNumResults());
823 newYieldValues.reserve(warpOp->getNumResults());
826 auto yield = cast<gpu::YieldOp>(
827 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
828
829
830
831
832
833
834
835
836
837 for (OpResult result : warpOp.getResults()) {
838 Value yieldOperand = yield.getOperand(result.getResultNumber());
839 auto it = dedupYieldOperandPositionMap.insert(
840 std::make_pair(yieldOperand, newResultTypes.size()));
841 dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
842 if (result.use_empty() || !it.second)
843 continue;
844 newResultTypes.push_back(result.getType());
845 newYieldValues.push_back(yieldOperand);
846 }
847
848 if (yield.getNumOperands() == newYieldValues.size())
849 return failure();
850
851 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
852 rewriter, warpOp, newYieldValues, newResultTypes);
853
854
855 newWarpOp.getBody()->walk([&](Operation *op) {
858 });
859
860
862 newValues.reserve(warpOp->getNumResults());
863 for (OpResult result : warpOp.getResults()) {
864 if (result.use_empty())
865 newValues.push_back(Value());
866 else
867 newValues.push_back(
868 newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
869 }
870 rewriter.replaceOp(warpOp, newValues);
871 return success();
872 }
873 };
874
875
876
878 using Base::Base;
879 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
880 PatternRewriter &rewriter) const override {
881 auto yield = cast<gpu::YieldOp>(
882 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
883 Value valForwarded;
884 unsigned resultIndex;
885 for (OpOperand &operand : yield->getOpOperands()) {
888 continue;
889
890
891 if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
893 continue;
894 valForwarded = operand.get();
896 break;
897 }
898 auto arg = dyn_cast<BlockArgument>(operand.get());
899 if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
900 continue;
901 Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
903 continue;
904 valForwarded = warpOperand;
906 break;
907 }
908 if (!valForwarded)
909 return failure();
910
911
913 rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
915 return success();
916 }
917 };
918
920 using Base::Base;
921 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
922 PatternRewriter &rewriter) const override {
923 OpOperand *operand =
924 getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
925 if (!operand)
926 return failure();
928 auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
929 Location loc = broadcastOp.getLoc();
930 auto destVecType =
931 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
932 Value broadcastSrc = broadcastOp.getSource();
933 Type broadcastSrcType = broadcastSrc.getType();
934
935
936
937
938
941 return failure();
942 SmallVector<size_t> newRetIndices;
943 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
944 rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
946 Value broadcasted = rewriter.create<vector::BroadcastOp>(
947 loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
948 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
949 broadcasted);
950 return success();
951 }
952 };
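// Sketch: the broadcast source is returned from the warp op and the
// vector.broadcast is recreated after it with the distributed result type,
// e.g. a broadcast to vector<32xf32> inside the region becomes a broadcast to
// vector<1xf32> outside.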
953
954
955
957 using Base::Base;
958 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
959 PatternRewriter &rewriter) const override {
960 OpOperand *operand =
961 getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
962 if (!operand)
963 return failure();
964
965 auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
966
968 auto castDistributedType =
969 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
970 VectorType castOriginalType = oldCastOp.getSourceVectorType();
971 VectorType castResultType = castDistributedType;
972
973
974
975 unsigned castDistributedRank = castDistributedType.getRank();
976 unsigned castOriginalRank = castOriginalType.getRank();
977 if (castDistributedRank < castOriginalRank) {
978 SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
979 llvm::append_range(shape, castDistributedType.getShape());
980 castDistributedType =
981 VectorType::get(shape, castDistributedType.getElementType());
982 }
983
984 SmallVector<size_t> newRetIndices;
985 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
986 rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
987 newRetIndices);
989 Value newCast = rewriter.create<vector::ShapeCastOp>(
990 oldCastOp.getLoc(), castResultType,
991 newWarpOp->getResult(newRetIndices[0]));
992 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
993 return success();
994 }
995 };
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1016 using Base::Base;
1017 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1018 PatternRewriter &rewriter) const override {
1019 OpOperand *yieldOperand =
1020 getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
1021 if (!yieldOperand)
1022 return failure();
1023
1024 auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
1025
1026
1027
1028 if (!llvm::all_of(mask->getOperands(), [&](Value value) {
1029 return warpOp.isDefinedOutsideOfRegion(value);
1030 }))
1031 return failure();
1032
1033 Location loc = mask.getLoc();
1035
1036 auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
1037 VectorType seqType = mask.getVectorType();
1040
1042
1043
1044 SmallVector<Value> delinearizedIds;
1045 if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
1046 warpOp.getWarpSize(), warpOp.getLaneid(),
1047 delinearizedIds))
1048 return rewriter.notifyMatchFailure(
1049 mask, "cannot delinearize lane ID for distribution");
1050 assert(!delinearizedIds.empty());
1051
1052
1053
1055
1056 AffineExpr s0, s1;
1057 bindSymbols(rewriter.getContext(), s0, s1);
1058 SmallVector<Value> newOperands;
1059 for (int i = 0, e = distShape.size(); i < e; ++i) {
1060
1061
1062
1063
1064
1066 rewriter, loc, s1 - s0 * distShape[i],
1067 {delinearizedIds[i], mask.getOperand(i)});
1068 newOperands.push_back(maskDimIdx);
1069 }
1070
1071 auto newMask =
1072 rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
1073 rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
1075 return success();
1076 }
1077 };
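// Sketch of the per-lane mask bound computed above: for a yielded
// vector.create_mask %sz : vector<32xi1> distributed as vector<1xi1>, lane i
// rebuilds its slice with bound %sz - i * 1; create_mask clamps the operand to
// the local vector size, so lanes below the original bound get a set bit and
// the rest get zero.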
1078
1079
1080
1082 using Base::Base;
1083 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1084 PatternRewriter &rewriter) const override {
1085 OpOperand *operand =
1086 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1087 if (!operand)
1088 return failure();
1090 auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
1091 VectorType extractSrcType = extractOp.getSourceVectorType();
1092 Location loc = extractOp.getLoc();
1093
1094
1095 if (extractSrcType.getRank() <= 1) {
1096 return failure();
1097 }
1098
1099
1100
1101 if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1102
1103
1104
1105
1106
1107 SmallVector<size_t> newRetIndices;
1108 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1109 rewriter, warpOp, {extractOp.getVector()},
1110 {extractOp.getSourceVectorType()}, newRetIndices);
1112 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1113
1114 Value newExtract = rewriter.create<vector::ExtractOp>(
1115 loc, distributedVec, extractOp.getMixedPosition());
1116 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1117 newExtract);
1118 return success();
1119 }
1120
1121
1122 auto distributedType =
1123 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1124 auto yieldedType = cast<VectorType>(operand->get().getType());
1125 int64_t distributedDim = -1;
1126 for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1127 if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
1128
1129
1130 assert(distributedDim == -1 && "found multiple distributed dims");
1131 distributedDim = i;
1132 }
1133 }
1134 assert(distributedDim != -1 && "could not find distributed dimension");
1135 (void)distributedDim;
1136
1137
1139 for (int i = 0; i < distributedType.getRank(); ++i)
1140 newDistributedShape[i + extractOp.getNumIndices()] =
1141 distributedType.getDimSize(i);
1142 auto newDistributedType =
1143 VectorType::get(newDistributedShape, distributedType.getElementType());
1144 SmallVector<size_t> newRetIndices;
1145 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1146 rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
1147 newRetIndices);
1149 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1150
1151 Value newExtract = rewriter.create<vector::ExtractOp>(
1152 loc, distributedVec, extractOp.getMixedPosition());
1153 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1154 newExtract);
1155 return success();
1156 }
1157 };
1158
1159
1160
1162 WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1165 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1166 PatternRewriter &rewriter) const override {
1167 OpOperand *operand =
1168 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1169 if (!operand)
1170 return failure();
1172 auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
1173 VectorType extractSrcType = extractOp.getSourceVectorType();
1174
1175 if (extractSrcType.getRank() > 1) {
1176 return rewriter.notifyMatchFailure(
1177 extractOp, "only 0-D or 1-D source supported for now");
1178 }
1179
1180
1181 if (!extractSrcType.getElementType().isF32() &&
1182 !extractSrcType.getElementType().isInteger(32))
1183 return rewriter.notifyMatchFailure(
1184 extractOp, "only f32/i32 element types are supported");
1185 bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1186 Type elType = extractSrcType.getElementType();
1187 VectorType distributedVecType;
1188 if (!is0dOrVec1Extract) {
1189 assert(extractSrcType.getRank() == 1 &&
1190 "expected that extract src rank is 0 or 1");
1191 if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1192 return failure();
1193 int64_t elementsPerLane =
1194 extractSrcType.getShape()[0] / warpOp.getWarpSize();
1195 distributedVecType = VectorType::get({elementsPerLane}, elType);
1196 } else {
1197 distributedVecType = extractSrcType;
1198 }
1199
1202 additionalResults.append(
1204 additionalResultTypes.append(
1206
1207 Location loc = extractOp.getLoc();
1208 SmallVector<size_t> newRetIndices;
1209 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1210 rewriter, warpOp, additionalResults, additionalResultTypes,
1211 newRetIndices);
1213 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1214
1215
1216
1217 if (is0dOrVec1Extract) {
1218 Value newExtract;
1220 newExtract =
1221 rewriter.create<vector::ExtractOp>(loc, distributedVec, indices);
1223 newExtract);
1224 return success();
1225 }
1226
1227 int64_t staticPos = extractOp.getStaticPosition()[0];
1228 OpFoldResult pos = ShapedType::isDynamic(staticPos)
1229 ? (newWarpOp->getResult(newRetIndices[1]))
1231
1232
1233 int64_t elementsPerLane = distributedVecType.getShape()[0];
1235
1237 rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
1238
1240 elementsPerLane == 1
1241 ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
1243 sym0 % elementsPerLane, pos);
1244 Value extracted =
1245 rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos);
1246
1247
1248 Value shuffled = warpShuffleFromIdxFn(
1249 loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
1250 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
1251 return success();
1252 }
1253
1254 private:
1255 WarpShuffleFromIdxFn warpShuffleFromIdxFn;
1256 };
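// Sketch: extracting a scalar at position %pos from a vector<64xf32>
// distributed over 32 lanes (2 elements per lane) extracts locally at
// %pos % 2 on every lane and then uses warpShuffleFromIdxFn to broadcast the
// value held by the owning lane (derived from %pos and elementsPerLane above)
// to the whole warp.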
1257
1258
1260 using Base::Base;
1261 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1262 PatternRewriter &rewriter) const override {
1263 OpOperand *operand =
1264 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
1265 if (!operand)
1266 return failure();
1267 auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
1269 if (auto pos = extractOp.getPosition()) {
1270 indices.push_back(pos);
1271 }
1274 extractOp, extractOp.getVector(), indices);
1275 return success();
1276 }
1277 };
1278
1279
1280
1282 using Base::Base;
1283 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1284 PatternRewriter &rewriter) const override {
1285 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1286 if (!operand)
1287 return failure();
1289 auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1290 VectorType vecType = insertOp.getDestVectorType();
1291 VectorType distrType =
1292 cast(warpOp.getResult(operandNumber).getType());
1293
1294
1295 if (vecType.getRank() > 1) {
1296 return rewriter.notifyMatchFailure(
1297 insertOp, "only 0-D or 1-D source supported for now");
1298 }
1299
1300
1302 insertOp.getValueToStore()};
1304 distrType, insertOp.getValueToStore().getType()};
1305 additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
1306 additionalResultTypes.append(
1308
1309 Location loc = insertOp.getLoc();
1310 SmallVector<size_t> newRetIndices;
1311 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1312 rewriter, warpOp, additionalResults, additionalResultTypes,
1313 newRetIndices);
1315 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1316 Value newSource = newWarpOp->getResult(newRetIndices[1]);
1318
1320 if (vecType.getRank() != 0) {
1321 int64_t staticPos = insertOp.getStaticPosition()[0];
1322 pos = ShapedType::isDynamic(staticPos)
1323 ? (newWarpOp->getResult(newRetIndices[2]))
1325 }
1326
1327
1328 if (vecType == distrType) {
1329 Value newInsert;
1331 if (pos) {
1332 indices.push_back(pos);
1333 }
1334 newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
1335 distributedVec, indices);
1336
1337 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1338 newInsert);
1339 return success();
1340 }
1341
1342
1343 int64_t elementsPerLane = distrType.getShape()[0];
1345
1347 rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
1348
1350 rewriter, loc, sym0 % elementsPerLane, pos);
1351 Value isInsertingLane = rewriter.create<arith::CmpIOp>(
1352 loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1353 Value newResult =
1354 rewriter
1356 loc, isInsertingLane,
1357
1359 Value newInsert = builder.create<vector::InsertOp>(
1360 loc, newSource, distributedVec, newPos);
1361 builder.create<scf::YieldOp>(loc, newInsert);
1362 },
1363
1365 builder.create<scf::YieldOp>(loc, distributedVec);
1366 })
1367 .getResult(0);
1368 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1369 return success();
1370 }
1371 };
1372
1374 using Base::Base;
1375 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1376 PatternRewriter &rewriter) const override {
1377 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1378 if (!operand)
1379 return failure();
1381 auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1382 Location loc = insertOp.getLoc();
1383
1384
1385 if (insertOp.getDestVectorType().getRank() <= 1) {
1386 return failure();
1387 }
1388
1389
1390
1391 if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1392
1393
1394 SmallVector<size_t> newRetIndices;
1395 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1396 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1397 {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
1398 newRetIndices);
1400 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1401 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1402 Value newResult = rewriter.create<vector::InsertOp>(
1403 loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1404 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1405 newResult);
1406 return success();
1407 }
1408
1409
1410 auto distrDestType =
1411 cast(warpOp.getResult(operandNumber).getType());
1412 auto yieldedType = cast(operand->get().getType());
1413 int64_t distrDestDim = -1;
1414 for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1415 if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1416
1417
1418 assert(distrDestDim == -1 && "found multiple distributed dims");
1419 distrDestDim = i;
1420 }
1421 }
1422 assert(distrDestDim != -1 && "could not find distributed dimension");
1423
1424
1425 VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
1427
1428
1429
1430
1431
1432
1433 int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1434 if (distrSrcDim >= 0)
1435 distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1436 auto distrSrcType =
1437 VectorType::get(distrSrcShape, distrDestType.getElementType());
1438
1439
1440 SmallVector<size_t> newRetIndices;
1441 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1442 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1443 {distrSrcType, distrDestType}, newRetIndices);
1445 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1446 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1447
1448
1449 Value newResult;
1450 if (distrSrcDim >= 0) {
1451
1452 newResult = rewriter.create<vector::InsertOp>(
1453 loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1454 } else {
1455
1456 int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1459
1460 Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
1461 loc, newPos[distrDestDim] / elementsPerLane);
1462 Value isInsertingLane = rewriter.create<arith::CmpIOp>(
1463 loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1464
1465 newPos[distrDestDim] %= elementsPerLane;
1466 auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
1467 Value newInsert = builder.create<vector::InsertOp>(
1468 loc, distributedSrc, distributedDest, newPos);
1469 builder.create<scf::YieldOp>(loc, newInsert);
1470 };
1471 auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
1472 builder.create<scf::YieldOp>(loc, distributedDest);
1473 };
1474 newResult = rewriter
1475 .create<scf::IfOp>(loc, isInsertingLane,
1476 insertingBuilder,
1477 nonInsertingBuilder)
1478 .getResult(0);
1479 }
1480
1481 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1482 return success();
1483 }
1484 };
1485
1487 using Base::Base;
1488 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1489 PatternRewriter &rewriter) const override {
1490 OpOperand *operand =
1491 getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
1492 if (!operand)
1493 return failure();
1494 auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
1496 if (auto pos = insertOp.getPosition()) {
1497 indices.push_back(pos);
1498 }
1501 insertOp, insertOp.getSource(), insertOp.getDest(), indices);
1502 return success();
1503 }
1504 };
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1539
1542 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1543 PatternRewriter &rewriter) const override {
1544 auto yield = cast<gpu::YieldOp>(
1545 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1546
1547 Operation *lastNode = yield->getPrevNode();
1548 auto forOp = dyn_cast_or_nullscf::ForOp(lastNode);
1549 if (!forOp)
1550 return failure();
1551
1552
1553
1554 llvm::SmallSetVector<Value, 32> escapingValues;
1558 forOp.getBodyRegion(), [&](OpOperand *operand) {
1559 Operation *parent = operand->get().getParentRegion()->getParentOp();
1560 if (warpOp->isAncestor(parent)) {
1561 if (!escapingValues.insert(operand->get()))
1562 return;
1563 Type distType = operand->get().getType();
1564 if (auto vecType = dyn_cast<VectorType>(distType)) {
1565 AffineMap map = distributionMapFn(operand->get());
1566 distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1567 }
1568 inputTypes.push_back(operand->get().getType());
1569 distTypes.push_back(distType);
1570 }
1571 });
1572
1573 if (llvm::is_contained(distTypes, Type{}))
1574 return failure();
1575
1576 SmallVector<size_t> newRetIndices;
1577 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1578 rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
1579 newRetIndices);
1580 yield = cast<gpu::YieldOp>(
1581 newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1582
1585
1586 for (OpOperand &yieldOperand : yield->getOpOperands()) {
1587 if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
1588 continue;
1589 auto forResult = cast<OpResult>(yieldOperand.get());
1590 newOperands.push_back(
1592 yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
1594 }
1595
1598
1599
1600
1601 auto newForOp = rewriter.create<scf::ForOp>(
1602 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1603 forOp.getStep(), newOperands);
1605
1607 newForOp.getRegionIterArgs().end());
1609 forOp.getResultTypes().end());
1610 llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1611 for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
1612 warpInput.push_back(newWarpOp.getResult(retIdx));
1613 argIndexMapping[escapingValues[i]] = warpInputType.size();
1614 warpInputType.push_back(inputTypes[i]);
1615 }
1616 auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
1617 newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
1618 newWarpOp.getWarpSize(), warpInput, warpInputType);
1619
1620 SmallVector<Value> argMapping;
1621 argMapping.push_back(newForOp.getInductionVar());
1622 for (Value args : innerWarp.getBody()->getArguments()) {
1623 argMapping.push_back(args);
1624 }
1625 argMapping.resize(forOp.getBody()->getNumArguments());
1626 SmallVector<Value> yieldOperands;
1627 for (Value operand : forOp.getBody()->getTerminator()->getOperands())
1628 yieldOperands.push_back(operand);
1629 rewriter.eraseOp(forOp.getBody()->getTerminator());
1630 rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1632 rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
1634 if (!innerWarp.getResults().empty())
1635 rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
1636 rewriter.eraseOp(forOp);
1637
1640 newForOp.getResult(res.index()));
1641 newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
1642 }
1643 newForOp.walk([&](Operation *op) {
1644 for (OpOperand &operand : op->getOpOperands()) {
1645 auto it = argIndexMapping.find(operand.get());
1646 if (it == argIndexMapping.end())
1647 continue;
1648 operand.set(innerWarp.getBodyRegion().getArgument(it->second));
1649 }
1650 });
1651
1652
1653 mlir::vector::moveScalarUniformCode(innerWarp);
1654 return success();
1655 }
1656
1657 private:
1658 DistributionMapFn distributionMapFn;
1659 };
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1682 DistributedReductionFn distributedReductionFn,
1685 distributedReductionFn(std::move(distributedReductionFn)) {}
1686
1687 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1688 PatternRewriter &rewriter) const override {
1689 OpOperand *yieldOperand =
1690 getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
1691 if (!yieldOperand)
1692 return failure();
1693
1694 auto reductionOp =
1695 cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
1696 auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
1697
1698 if (vectorType.getRank() != 1)
1699 return rewriter.notifyMatchFailure(
1700 warpOp, "Only rank 1 reductions can be distributed.");
1701
1702 if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
1703 return rewriter.notifyMatchFailure(
1704 warpOp, "Reduction vector dimension must match was size.");
1705 if (!reductionOp.getType().isIntOrFloat())
1706 return rewriter.notifyMatchFailure(
1707 warpOp, "Reduction distribution currently only supports floats and "
1708 "integer types.");
1709
1710 int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
1711
1716 if (reductionOp.getAcc()) {
1717 yieldValues.push_back(reductionOp.getAcc());
1718 retTypes.push_back(reductionOp.getAcc().getType());
1719 }
1720 SmallVector<size_t> newRetIndices;
1721 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1722 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
1724
1725
1726 Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
1727
1728 Value fullReduce =
1729 distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
1730 reductionOp.getKind(), newWarpOp.getWarpSize());
1731 if (reductionOp.getAcc()) {
1732 fullReduce = vector::makeArithReduction(
1733 rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
1734 newWarpOp.getResult(newRetIndices[1]));
1735 }
1736 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
1737 return success();
1738 }
1739
1740 private:
1741 DistributedReductionFn distributedReductionFn;
1742 };
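// Sketch: for vector.reduction <add> over vector<64xf32> with a warp of 32
// lanes, the warp op yields each lane's vector<2xf32>; distributedReductionFn
// then combines the 32 per-lane values across the warp (e.g. with
// gpu.shuffle-based exchanges), and an accumulator, if present, is folded in
// with makeArithReduction.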
1743
1744 }
1745
1746 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
1747 RewritePatternSet &patterns,
1748 const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
1749 patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
1750 }
1751
1752 void mlir::vector::populateDistributeTransferWriteOpPatterns(
1754 unsigned maxNumElementsToExtract, PatternBenefit benefit) {
1755 patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
1756 maxNumElementsToExtract, benefit);
1757 }
1758
1759 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
1761 const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
1763 patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
1764 patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1765 WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
1766 WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
1767 WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
1768 patterns.getContext(), benefit);
1769 patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
1770 benefit);
1771 patterns.add<WarpOpScfForPattern>(patterns.getContext(), distributionMapFn,
1772 benefit);
1773 }
1774
1775 void mlir::vector::populateDistributeReduction(
1777 const DistributedReductionFn &distributedReductionFn,
1779 patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
1780 benefit);
1781 }
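// Typical wiring of these entry points in a pass (a sketch under assumed
// names; the callbacks and the pattern driver call are whatever the caller
// provides):
//
//   RewritePatternSet patterns(ctx);
//   vector::populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
//   vector::populateDistributeTransferWriteOpPatterns(
//       patterns, distributionMapFn, /*maxNumElementsToExtract=*/1);
//   vector::populatePropagateWarpVectorDistributionPatterns(
//       patterns, distributionMapFn, warpShuffleFromIdxFn);
//   vector::populateDistributeReduction(patterns, distributedReductionFn);
//   // ...then apply `patterns` with a greedy pattern driver.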
1782
1783
1784 static bool canBeHoisted(Operation *op,
1785 function_ref<bool(Value)> definedOutside) {
1786 return llvm::all_of(op->getOperands(), definedOutside) &&
1787 isMemoryEffectFree(op) && op->getNumRegions() == 0;
1788 }
1789
1790 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
1791 Block *body = warpOp.getBody();
1792
1793
1794 llvm::SmallSetVector<Operation *, 8> opsToMove;
1795
1796
1797 auto isDefinedOutsideOfBody = [&](Value value) {
1798 auto *definingOp = value.getDefiningOp();
1799 return (definingOp && opsToMove.count(definingOp)) ||
1800 warpOp.isDefinedOutsideOfRegion(value);
1801 };
1802
1803
1804
1805 for (Operation &op : body->without_terminator()) {
1806 bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
1807 return isa<VectorType>(result.getType());
1808 });
1809 if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
1810 opsToMove.insert(&op);
1811 }
1812
1813
1814 for (Operation *op : opsToMove)
1815 op->moveBefore(warpOp);
1816 }