MLIR: lib/Dialect/GPU/IR/GPUDialect.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
14
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/TypeSwitch.h"
36 #include "llvm/Support/CommandLine.h"
37 #include "llvm/Support/ErrorHandling.h"
38 #include "llvm/Support/FormatVariadic.h"
39 #include "llvm/Support/InterleavedRange.h"
40 #include "llvm/Support/StringSaver.h"
41 #include
42 #include
43
44 using namespace mlir;
46
47 #include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
48
49
50
51
52
53 int64_t GPUBlockMappingAttr::getMappingId() const {
54 return static_cast<int64_t>(getBlock());
55 }
56
57 bool GPUBlockMappingAttr::isLinearMapping() const {
58 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
59 }
60
61 int64_t GPUBlockMappingAttr::getRelativeIndex() const {
62 return isLinearMapping()
63 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
64 : getMappingId();
65 }
66
67 int64_t GPUWarpgroupMappingAttr::getMappingId() const {
68 return static_cast<int64_t>(getWarpgroup());
69 }
70
71 bool GPUWarpgroupMappingAttr::isLinearMapping() const {
72 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
73 }
74
75 int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
76 return isLinearMapping()
77 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
78 : getMappingId();
79 }
80
81 int64_t GPUWarpMappingAttr::getMappingId() const {
82 return static_cast<int64_t>(getWarp());
83 }
84
85 bool GPUWarpMappingAttr::isLinearMapping() const {
86 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
87 }
88
89 int64_t GPUWarpMappingAttr::getRelativeIndex() const {
90 return isLinearMapping()
91 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
92 : getMappingId();
93 }
94
95 int64_t GPUThreadMappingAttr::getMappingId() const {
96 return static_cast<int64_t>(getThread());
97 }
98
99 bool GPUThreadMappingAttr::isLinearMapping() const {
100 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
101 }
102
103 int64_t GPUThreadMappingAttr::getRelativeIndex() const {
104 return isLinearMapping()
105 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
106 : getMappingId();
107 }
108
109 int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
110 return static_cast<int64_t>(getAddressSpace());
111 }
112
113 bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
114 llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
115 }
116
117 int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
118 llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
119 }
120
121
122
123
124
126 StringRef operand) {
128 }
129
133 StringRef operand) {
135 elementType, operand);
136 }
137
139
141 return getImpl()->getShape();
142 }
143
145
147
149 return elementType.isF16() || elementType.isF32() ||
152 }
153
154 LogicalResult
157 StringRef operand) {
158 if (operand != "AOp" && operand != "BOp" && operand != "COp")
159 return emitError() << "operand expected to be one of AOp, BOp or COp";
160
161 if (shape.size() != 2)
162 return emitError() << "MMAMatrixType must have exactly two dimensions";
163
166 << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
167
168 return success();
169 }
170
171
172
173
174
175 bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
176 if (!memorySpace)
177 return false;
178 if (auto gpuAttr = llvm::dyn_castgpu::AddressSpaceAttr(memorySpace))
179 return gpuAttr.getValue() == getWorkgroupAddressSpace();
180 return false;
181 }
182
183 bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
184 Attribute memorySpace = type.getMemorySpace();
185 return isWorkgroupMemoryAddressSpace(memorySpace);
186 }
187
188 bool GPUDialect::isKernel(Operation *op) {
189 UnitAttr isKernelAttr = op->getAttrOfType(getKernelFuncAttrName());
190 return static_cast<bool>(isKernelAttr);
191 }
192
193 namespace {
194
195
198
199
201 return true;
202 }
203 };
204 }
205
// Dialect initialization hook: registers the dialect's types, operations,
// attributes and interfaces, and declares the interfaces that are promised
// but implemented elsewhere.
// NOTE(review): the template argument lists of the five `addTypes` calls and
// of `addInterfaces` are not present in this listing (they look like they
// were lost during extraction); confirm the registered type/interface names
// against the upstream file before relying on this text.
206 void GPUDialect::initialize() {
207 addTypes();
208 addTypes();
209 addTypes();
210 addTypes();
211 addTypes();
// Operation list comes from the generated GPUOps.cpp.inc (GET_OP_LIST).
212 addOperations<
213 #define GET_OP_LIST
214 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
215 >();
// Attribute list comes from the generated GPUOpsAttributes.cpp.inc
// (GET_ATTRDEF_LIST).
216 addAttributes<
217 #define GET_ATTRDEF_LIST
218 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
219 >();
220 addInterfaces();
// Promise BufferDeallocationOpInterface for TerminatorOp.
221 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
222 TerminatorOp>();
// Promise ValueBoundsOpInterface for the id/dimension/size query ops and for
// LaunchOp.
223 declarePromisedInterfaces<
224 ValueBoundsOpInterface, ClusterDimOp, ClusterDimBlocksOp, ClusterIdOp,
225 ClusterBlockIdOp, BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp, LaneIdOp,
226 SubgroupIdOp, GlobalIdOp, NumSubgroupsOp, SubgroupSizeOp, LaunchOp>();
227 }
228
230 switch (kind) {
232 return "sparse.dntensor_handle";
234 return "sparse.spmat_handle";
236 return "sparse.spgemmop_handle";
237 }
238 llvm_unreachable("unknown sparse handle kind");
239 return "";
240 }
241
243
244 StringRef keyword;
246 return Type();
248
249
250 if (keyword == "async.token")
252
253 if (keyword == "mma_matrix") {
254 SMLoc beginLoc = parser.getNameLoc();
255
256
258 return nullptr;
259
260
262 Type elementType;
265 return nullptr;
266
267
269 return nullptr;
270
271
272 std::string operand;
274 return nullptr;
275
276
278 return nullptr;
279
282 shape, elementType, operand);
283 }
284
291
293 return Type();
294 }
295
296
300 .Case([&](Type) {
302 })
303 .Case(
305 .Case([&](Type) {
307 })
309 os << "mma_matrix<";
310 auto shape = fragTy.getShape();
311 for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
312 os << *dim << 'x';
313 os << shape.back() << 'x' << fragTy.getElementType();
314 os << ", \"" << fragTy.getOperand() << "\"" << '>';
315 })
316 .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
317 }
318
321 auto array = dyn_cast(attr.getValue());
322 if (!array)
324 " must be a dense i32 array");
325 if (array.size() != 3)
327 " must contain exactly 3 elements");
328 return success();
329 }
330
331 LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
333 if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
335 if (attr.getName() == getKnownGridSizeAttrHelper().getName())
337 if (!llvm::isa(attr.getValue()) ||
338 attr.getName() != getContainerModuleAttrName())
339 return success();
340
341 auto module = dyn_cast(op);
342 if (!module)
343 return op->emitError("expected '")
344 << getContainerModuleAttrName() << "' attribute to be attached to '"
345 << ModuleOp::getOperationName() << '\'';
346
347 auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
348
349
350 if (!launchOp->getParentOp() ||
351 launchOp->getParentOp()->getParentOp() != module)
352 return success();
353
354
355
356 if (!launchOp->getAttrOfType(
357 LaunchFuncOp::getKernelAttrName(launchOp->getName())))
358 return success();
359
360
361 StringAttr kernelContainerName = launchOp.getKernelModuleName();
362 Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
363 if (!kernelContainer)
365 << "kernel container '" << kernelContainerName.getValue()
366 << "' is undefined";
367
368
369 if (isa(kernelContainer))
370 return success();
371
372 auto kernelModule = dyn_cast(kernelContainer);
373 if (!kernelModule)
374 return launchOp.emitOpError()
375 << "kernel module '" << kernelContainerName.getValue()
376 << "' is undefined";
377
378
379 Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
380 if (!kernelFunc)
381 return launchOp.emitOpError("kernel function '")
382 << launchOp.getKernel() << "' is undefined";
383 auto kernelConvertedFunction = dyn_cast(kernelFunc);
384 if (!kernelConvertedFunction) {
386 << "referenced kernel '" << launchOp.getKernel()
387 << "' is not a function";
388 diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
390 }
391
393 GPUDialect::getKernelFuncAttrName()))
394 return launchOp.emitOpError("kernel function is missing the '")
395 << GPUDialect::getKernelFuncAttrName() << "' attribute";
396
397
398
399
400 auto kernelGPUFunction = dyn_castgpu::GPUFuncOp(kernelFunc);
401 if (!kernelGPUFunction)
402 return success();
403
404 unsigned actualNumArguments = launchOp.getNumKernelOperands();
405 unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
406 if (expectedNumArguments != actualNumArguments)
407 return launchOp.emitOpError("got ")
408 << actualNumArguments << " kernel operands but expected "
409 << expectedNumArguments;
410
411 auto functionType = kernelGPUFunction.getFunctionType();
412 for (unsigned i = 0; i < expectedNumArguments; ++i) {
413 if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
414 return launchOp.emitOpError("type of function argument ")
415 << i << " does not match";
416 }
417 }
418
419 return success();
420 });
421
422 return walkResult.wasInterrupted() ? failure() : success();
423 }
424
425
426
427
428
435 return parser.emitError(loc, "needs to be named when marked 'async'");
437 }
440 }
441
442
443
444
446 Type asyncTokenType,
448 if (asyncTokenType)
449 printer << "async";
450 if (asyncDependencies.empty())
451 return;
452 if (asyncTokenType)
453 printer << ' ';
454 printer << llvm::interleaved_array(asyncDependencies);
455 }
456
457
458
459
460
461
462
463
464
465 static ParseResult
468
470 return success();
471
473 true);
474 }
475
476
479 if (values.empty())
480 return;
481
483 return llvm::formatv("{} : {}", v, v.getType());
484 };
485 p << ' ' << keyword << '('
486 << llvm::interleaved(llvm::map_range(values, printBlockArg)) << ')';
487 }
488
489
492 gpu::AddressSpace memorySpace) {
493 for (Value v : attributions) {
494 auto type = llvm::dyn_cast(v.getType());
495 if (!type)
496 return op->emitOpError() << "expected memref type in attribution";
497
498
499
500 auto addressSpace =
501 llvm::dyn_cast_or_nullgpu::AddressSpaceAttr(type.getMemorySpace());
502 if (!addressSpace)
503 continue;
504 if (addressSpace.getValue() != memorySpace)
506 << "expected memory space " << stringifyAddressSpace(memorySpace)
507 << " in attribution";
508 }
509 return success();
510 }
511
512
513
514
515
517 Type resType) {
518 using Kind = gpu::AllReduceOperation;
519 if (llvm::is_contained(
520 {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
521 opName)) {
522 if (!isa(resType))
523 return failure();
524 }
525
526 if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
527 Kind::AND, Kind::OR, Kind::XOR},
528 opName)) {
529 if (!isa(resType))
530 return failure();
531 }
532
533 return success();
534 }
535
// Verifies the body region of gpu::AllReduceOp.
// Exactly one of {`op` attribute, non-empty body} must be present. When a body
// is given it must take two arguments of the op's result type, and every
// gpu::YieldOp terminator must yield a single value of that type; at least one
// such yield is required.
536 LogicalResult gpu::AllReduceOp::verifyRegions() {
// An empty body requires the `op` attribute, and a non-empty body forbids it.
537 if (getBody().empty() != getOp().has_value())
538 return emitError("expected either an op attribute or a non-empty body");
539 if (!getBody().empty()) {
540 if (getBody().getNumArguments() != 2)
541 return emitError("expected two region arguments");
542 for (auto argument : getBody().getArguments()) {
543 if (argument.getType() != getType())
544 return emitError("incorrect region argument type");
545 }
// Count the blocks that end in gpu::YieldOp; blocks with other terminators
// are not checked here.
546 unsigned yieldCount = 0;
547 for (Block &block : getBody()) {
548 if (auto yield = dyn_castgpu::YieldOp(block.getTerminator())) {
549 if (yield.getNumOperands() != 1)
550 return emitError("expected one gpu.yield operand");
551 if (yield.getOperand(0).getType() != getType())
552 return emitError("incorrect gpu.yield type");
553 ++yieldCount;
554 }
555 }
556 if (yieldCount == 0)
557 return emitError("expected gpu.yield op in region");
558 } else {
// Attribute form: the named reduction must be compatible with the result
// type.
// NOTE(review): the line holding the compatibility check (listing line 560)
// and the tail of the diagnostic (listing line 563) are missing from this
// listing; confirm against the upstream file.
559 gpu::AllReduceOperation opName = *getOp();
561 return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
562 << "` reduction operation is not compatible with type "
564 }
565 }
566
567 return success();
568 }
569
571 auto launchOp = dyn_castgpu::LaunchOp(op->getParentOp());
572 if (!launchOp)
573 return false;
574
575 Region &body = launchOp.getBody();
576 assert(!body.empty() && "Invalid region");
577
578
580 }
581
// Folder for gpu::AllReduceOp. The adaptor argument is unused.
// NOTE(review): the line that opens the conditional guarding the
// `setUniform(true)` branch (listing line 583) is missing from this listing,
// so the condition under which the op is marked uniform is not documented
// here; confirm against the upstream file.
582 OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor ) {
// Mark the op uniform and return its own result.
584 setUniform(true);
585 return getResult();
586 }
587
// Otherwise nothing is folded.
588 return nullptr;
589 }
590
591
593 AllReduceOperationAttr &attr) {
594 StringRef enumStr;
596 std::optional op =
597 gpu::symbolizeAllReduceOperation(enumStr);
598 if (!op)
601 }
602 return success();
603 }
604
606 AllReduceOperationAttr attr) {
607 if (attr)
608 attr.print(printer);
609 }
610
611
612
613
614
617 if (auto vecTy = dyn_cast(elemType)) {
618 if (vecTy.isScalable())
619 return emitOpError() << "is not compatible with scalable vector types";
620
621 elemType = vecTy.getElementType();
622 }
623
624 gpu::AllReduceOperation opName = getOp();
626 return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
627 << "` reduction operation is not compatible with type "
629 }
630
631 auto clusterSize = getClusterSize();
632 if (clusterSize) {
633 uint32_t size = *clusterSize;
634 if (!llvm::isPowerOf2_32(size)) {
635 return emitOpError() << "cluster size " << size
636 << " is not a power of two";
637 }
638 }
639
640 uint32_t stride = getClusterStride();
641 if (stride != 1 && !clusterSize) {
642 return emitOpError() << "cluster stride can only be specified if cluster "
643 "size is specified";
644 }
645 if (!llvm::isPowerOf2_32(stride)) {
646 return emitOpError() << "cluster stride " << stride
647 << " is not a power of two";
648 }
649
650 return success();
651 }
652
// Folder for gpu::SubgroupReduceOp. The adaptor argument is unused.
653 OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor ) {
// With a cluster size of 1 the op folds to its input value.
654 if (getClusterSize() == 1)
655 return getValue();
656
// NOTE(review): the line that opens the conditional guarding the
// `setUniform(true)` branch (listing line 657) is missing from this listing;
// confirm the condition against the upstream file.
// Mark the op uniform and return its own result.
658 setUniform(true);
659 return getResult();
660 }
661
// Otherwise nothing is folded.
662 return nullptr;
663 }
664
665
666
667
668
671 if (!op->template hasTraitOpTrait::AttrSizedOperandSegments())
672 return;
673 auto attrName =
675 auto sizeAttr = op->template getAttrOfType(attrName);
676
677
678 if (!sizeAttr)
679 return;
680
682 ++sizes.front();
684 }
685
686
687
688
689
692 Value getBlockSizeX, Value getBlockSizeY,
693 Value getBlockSizeZ, Value dynamicSharedMemorySize,
697 Value clusterSizeY, Value clusterSizeZ) {
699
700
701
702 result.addAttribute(getNumWorkgroupAttributionsAttrName(),
704
705
707 if (asyncTokenType)
709
710
711 result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
712 getBlockSizeY, getBlockSizeZ});
713 if (clusterSizeX)
715 if (clusterSizeY)
717 if (clusterSizeZ)
719 if (dynamicSharedMemorySize)
720 result.addOperands(dynamicSharedMemorySize);
721
722
723
724
727
728 for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
730
731 for (Type argTy : workgroupAttributions)
733 for (Type argTy : privateAttributions)
735
737 segmentSizes.front() = asyncDependencies.size();
738 segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
739 segmentSizes[7] = clusterSizeX ? 1 : 0;
740 segmentSizes[8] = clusterSizeY ? 1 : 0;
741 segmentSizes[9] = clusterSizeZ ? 1 : 0;
742 result.addAttribute(getOperandSegmentSizeAttr(),
744 }
745
747 assert(!getBody().empty() && "LaunchOp body must not be empty.");
748 auto args = getBody().getArguments();
749 return KernelDim3{args[0], args[1], args[2]};
750 }
751
752 KernelDim3 LaunchOp::getThreadIds() {
753 assert(!getBody().empty() && "LaunchOp body must not be empty.");
754 auto args = getBody().getArguments();
755 return KernelDim3{args[3], args[4], args[5]};
756 }
757
759 assert(!getBody().empty() && "LaunchOp body must not be empty.");
760 auto args = getBody().getArguments();
761 return KernelDim3{args[6], args[7], args[8]};
762 }
763
765 assert(!getBody().empty() && "LaunchOp body must not be empty.");
766 auto args = getBody().getArguments();
767 return KernelDim3{args[9], args[10], args[11]};
768 }
769
770 std::optional LaunchOp::getClusterIds() {
771 assert(!getBody().empty() && "LaunchOp body must not be empty.");
772 if (!hasClusterSize())
773 return std::nullopt;
774 auto args = getBody().getArguments();
775 return KernelDim3{args[12], args[13], args[14]};
776 }
777
778 std::optional LaunchOp::getClusterSize() {
779 assert(!getBody().empty() && "LaunchOp body must not be empty.");
780 if (!hasClusterSize())
781 return std::nullopt;
782 auto args = getBody().getArguments();
783 return KernelDim3{args[15], args[16], args[17]};
784 }
785
786 KernelDim3 LaunchOp::getGridSizeOperandValues() {
787 auto operands = getOperands().drop_front(getAsyncDependencies().size());
788 return KernelDim3{operands[0], operands[1], operands[2]};
789 }
790
791 KernelDim3 LaunchOp::getBlockSizeOperandValues() {
792 auto operands = getOperands().drop_front(getAsyncDependencies().size());
793 return KernelDim3{operands[3], operands[4], operands[5]};
794 }
795
796 std::optional LaunchOp::getClusterSizeOperandValues() {
797 auto operands = getOperands().drop_front(getAsyncDependencies().size());
798 if (!hasClusterSize())
799 return std::nullopt;
800 return KernelDim3{operands[6], operands[7], operands[8]};
801 }
802
804 if (!(hasClusterSize()) &&
805 (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
806 return emitOpError() << "cluster size must be all present";
807 return success();
808 }
809
810 LogicalResult LaunchOp::verifyRegions() {
811
812
813
814 if (!getBody().empty()) {
815 if (getBody().getNumArguments() <
816 kNumConfigRegionAttributes + getNumWorkgroupAttributions())
817 return emitOpError("unexpected number of region arguments");
818 }
819
820
821 if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
822 GPUDialect::getWorkgroupAddressSpace())) ||
824 GPUDialect::getPrivateAddressSpace())))
825 return failure();
826
827
828
829 for (Block &block : getBody()) {
830 if (block.empty())
831 continue;
832 if (block.back().getNumSuccessors() != 0)
833 continue;
834 if (!isagpu::TerminatorOp(&block.back())) {
835 return block.back()
836 .emitError()
837 .append("expected '", gpu::TerminatorOp::getOperationName(),
838 "' or a terminator with successors")
839 .attachNote(getLoc())
840 .append("in '", LaunchOp::getOperationName(), "' body region");
841 }
842 }
843
844 if (getNumResults() == 0 && getAsyncToken())
845 return emitOpError("needs to be named when async keyword is specified");
846
847 return success();
848 }
849
850
851
852
853
856 p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
857 p << size.x << " = " << operands.x << ", ";
858 p << size.y << " = " << operands.y << ", ";
859 p << size.z << " = " << operands.z << ')';
860 }
861
863 if (getAsyncToken()) {
864 p << " async";
865 if (!getAsyncDependencies().empty())
866 p << " [" << getAsyncDependencies() << ']';
867 }
868
869 if (hasClusterSize()) {
870 p << ' ' << getClustersKeyword();
872 getClusterSizeOperandValues().value(),
873 getClusterIds().value());
874 }
875 p << ' ' << getBlocksKeyword();
877 getBlockIds());
878 p << ' ' << getThreadsKeyword();
880 getThreadIds());
881 if (getDynamicSharedMemorySize())
882 p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
883 << getDynamicSharedMemorySize();
884
885 printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
886 printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
887
888 p << ' ';
889
890 p.printRegion(getBody(), false);
892 LaunchOp::getOperandSegmentSizeAttr(),
893 getNumWorkgroupAttributionsAttrName()});
894 }
895
896
897
898
899
900
901
902 static ParseResult
907 assert(indices.size() == 3 && "space for three indices expected");
910 false) ||
912 return failure();
913 std::move(args.begin(), args.end(), indices.begin());
914
915 for (int i = 0; i < 3; ++i) {
917 return failure();
918 if (parser.parseOperand(regionSizes[i], false) ||
920 return failure();
921 }
922
924 }
925
926
927
928
929
930
931
932
933
935
937 sizes(LaunchOp::kNumConfigOperands);
938
939
941 LaunchOp::kNumConfigRegionAttributes);
942
943
945 Type asyncTokenType;
946 if (failed(
948 parser.resolveOperands(asyncDependencies, asyncTokenType,
950 return failure();
952 result.types.push_back(asyncTokenType);
953
954 bool hasCluster = false;
955 if (succeeded(
957 hasCluster = true;
958 sizes.resize(9);
959 regionArgs.resize(18);
960 }
963
964
965
966 if (hasCluster) {
968 regionArgsRef.slice(15, 3),
969 regionArgsRef.slice(12, 3)))
970 return failure();
971 }
972
973
974
975
976
977 if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
979 regionArgsRef.slice(6, 3),
980 regionArgsRef.slice(0, 3)) ||
981 parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
983 regionArgsRef.slice(9, 3),
984 regionArgsRef.slice(3, 3)) ||
987 return failure();
988
990 bool hasDynamicSharedMemorySize = false;
992 LaunchOp::getDynamicSharedMemorySizeKeyword())) {
993 hasDynamicSharedMemorySize = true;
994 if (parser.parseOperand(dynamicSharedMemorySize) ||
998 return failure();
999 }
1000
1001
1002
1003
1004
1005
1006
1009 LaunchOp::kNumConfigRegionAttributes + 6, index);
1010
1012 for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
1014 arg.ssaName = std::get<0>(ssaValueAndType);
1015 arg.type = std::get<1>(ssaValueAndType);
1016 regionArguments.push_back(arg);
1017 }
1018
1020
1021 if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
1022 regionArguments)))
1023 return failure();
1024
1025
1026
1027 unsigned numWorkgroupAttrs = regionArguments.size() -
1028 LaunchOp::kNumConfigRegionAttributes -
1029 (hasCluster ? 6 : 0);
1030 result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
1032
1033
1034 if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
1035 regionArguments)))
1036 return failure();
1037
1038
1039
1040
1042 if (parser.parseRegion(*body, regionArguments) ||
1044 return failure();
1045
1047 segmentSizes.front() = asyncDependencies.size();
1048
1049 if (!hasCluster) {
1050 segmentSizes[7] = 0;
1051 segmentSizes[8] = 0;
1052 segmentSizes[9] = 0;
1053 }
1054 segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
1055 result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
1057 return success();
1058 }
1059
1060
1061
1066
1067
1069 bool simplified = false;
1070 auto constPropIdUses = [&](Value id, Value size) {
1071
1073 return;
1074 if (id.getUses().empty())
1075 return;
1076 if (!simplified) {
1077
1080 zero =
1081 rewriter.createarith::ConstantIndexOp(op.getLoc(), 0);
1082 }
1084 simplified = true;
1085 };
1086 constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
1087 constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
1088 constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
1089 constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
1090 constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
1091 constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());
1092
1093 return success(simplified);
1094 }
1095 };
1096
1097 void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
1100 }
1101
1102
1103
1105 auto attrName = getNumWorkgroupAttributionsAttrName();
1106 auto attr = (*this)->getAttrOfType(attrName);
1107 (*this)->setAttr(attrName,
1109 return getBody().insertArgument(
1110 LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
1111 }
1112
1113
1114
1116
1117
1118 return getBody().addArgument(type, loc);
1119 }
1120
1121
1122
1123
1124
1126 SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
1130 std::optional clusterSize) {
1131 assert(kernelSymbol.getNestedReferences().size() == 1 &&
1132 "expected a symbol reference with a single nested reference");
1134 if (asyncTokenType)
1136
1137
1140 if (clusterSize.has_value())
1141 result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1142 if (dynamicSharedMemorySize)
1143 result.addOperands(dynamicSharedMemorySize);
1145
1147 prop.kernel = kernelSymbol;
1148 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1149
1150 for (auto &sz : prop.operandSegmentSizes)
1151 sz = 1;
1152 prop.operandSegmentSizes[0] = asyncDependencies.size();
1153 if (!clusterSize.has_value()) {
1154 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1155 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1156 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1157 }
1158 prop.operandSegmentSizes[segmentSizesLen - 3] =
1159 dynamicSharedMemorySize ? 1 : 0;
1160 prop.operandSegmentSizes[segmentSizesLen - 2] =
1161 static_cast<int32_t>(kernelOperands.size());
1162 prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
1163 }
1164
1166 GPUFuncOp kernelFunc, KernelDim3 gridSize,
1170 std::optional clusterSize) {
1171 auto kernelModule = kernelFunc->getParentOfType();
1172 auto kernelSymbol =
1174 {SymbolRefAttr::get(kernelFunc.getNameAttr())});
1175 build(builder, result, kernelSymbol, gridSize, getBlockSize,
1176 dynamicSharedMemorySize, kernelOperands, asyncTokenType,
1177 asyncDependencies, clusterSize);
1178 }
1179
1181 SymbolRefAttr kernel, KernelDim3 gridSize,
1184 std::optional clusterSize) {
1185
1188 if (clusterSize.has_value())
1189 result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1190 if (dynamicSharedMemorySize)
1191 result.addOperands(dynamicSharedMemorySize);
1193 if (asyncObject)
1196 prop.kernel = kernel;
1197 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1198
1199 for (auto &sz : prop.operandSegmentSizes)
1200 sz = 1;
1201 prop.operandSegmentSizes[0] = 0;
1202 if (!clusterSize.has_value()) {
1203 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1204 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1205 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1206 }
1207 prop.operandSegmentSizes[segmentSizesLen - 3] =
1208 dynamicSharedMemorySize ? 1 : 0;
1209 prop.operandSegmentSizes[segmentSizesLen - 2] =
1210 static_cast<int32_t>(kernelOperands.size());
1211 prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
1212 }
1213
1214 StringAttr LaunchFuncOp::getKernelModuleName() {
1215 return getKernel().getRootReference();
1216 }
1217
1218 StringAttr LaunchFuncOp::getKernelName() {
1219 return getKernel().getLeafReference();
1220 }
1221
1222 unsigned LaunchFuncOp::getNumKernelOperands() {
1223 return getKernelOperands().size();
1224 }
1225
1226 Value LaunchFuncOp::getKernelOperand(unsigned i) {
1227 return getKernelOperands()[i];
1228 }
1229
1230 KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
1231 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1232 return KernelDim3{operands[0], operands[1], operands[2]};
1233 }
1234
1235 KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
1236 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1237 return KernelDim3{operands[3], operands[4], operands[5]};
1238 }
1239
1240 KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
1241 assert(hasClusterSize() &&
1242 "cluster size is not set, check hasClusterSize() first");
1243 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1244 return KernelDim3{operands[6], operands[7], operands[8]};
1245 }
1246
1248 auto module = (*this)->getParentOfType();
1249 if (!module)
1250 return emitOpError("expected to belong to a module");
1251
1252 if (!module->getAttrOfType(
1253 GPUDialect::getContainerModuleAttrName()))
1254 return emitOpError("expected the closest surrounding module to have the '" +
1255 GPUDialect::getContainerModuleAttrName() +
1256 "' attribute");
1257
1258 if (hasClusterSize()) {
1259 if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
1260 getClusterSizeZ().getType() != getClusterSizeX().getType())
1261 return emitOpError()
1262 << "expects types of the cluster dimensions must be the same";
1263 }
1264
1265 return success();
1266 }
1267
1268 static ParseResult
1270 std::optionalOpAsmParser::UnresolvedOperand clusterValue,
1271 Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
1274 return failure();
1275 } else {
1277 }
1278 if (clusterValue.has_value()) {
1279 clusterXTy = clusterYTy = clusterZTy = dimTy;
1280 }
1281 return success();
1282 }
1283
1285 Value clusterValue, Type clusterXTy,
1286 Type clusterYTy, Type clusterZTy) {
1288 printer << ": " << dimTy;
1289 }
1290
1296 return success();
1297
1298 auto parseElement = [&]() -> ParseResult {
1299 return failure(parser.parseOperand(argNames.emplace_back()) ||
1301 };
1302
1304 parseElement, " in argument list");
1305 }
1306
1309 if (operands.empty())
1310 return;
1311 printer << "args(";
1312 llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
1313 [&](const auto &pair) {
1314 auto [operand, type] = pair;
1315 printer << operand << " : " << type;
1316 });
1317 printer << ")";
1318 }
1319
1320
1321
1322
1323
1325 int32_t offset, int32_t width, ShuffleMode mode) {
1326 build(builder, result, value,
1331 mode);
1332 }
1333
1334
1335
1336
1337
1338 namespace {
1339
1340
1341 LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
1343 if (isa_and_nonnull(op->getNextNode())) {
1345 return success();
1346 }
1347 return failure();
1348 }
1349
1350 }
1351
1352 void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
1354 results.add(eraseRedundantGpuBarrierOps);
1355 }
1356
1357
1358
1359
1360
1361
1362
1364 auto attrName = getNumWorkgroupAttributionsAttrName();
1365 auto attr = (*this)->getAttrOfType(attrName);
1366 (*this)->setAttr(attrName,
1368 return getBody().insertArgument(
1369 getFunctionType().getNumInputs() + attr.getInt(), type, loc);
1370 }
1371
1372
1373
1375
1376
1377 return getBody().addArgument(type, loc);
1378 }
1379
1381 StringRef name, FunctionType type,
1382 TypeRange workgroupAttributions,
1386
1391 result.addAttribute(getNumWorkgroupAttributionsAttrName(),
1396
1397
1398 for (Type argTy : type.getInputs())
1400 for (Type argTy : workgroupAttributions)
1402 for (Type argTy : privateAttributions)
1404 }
1405
1406
1407
1408
1409
1410
1411
1412
1413 static ParseResult
1417
1419 return success();
1420
1421 size_t existingArgs = args.size();
1422 ParseResult result =
1424 true, true);
1425 if (failed(result))
1426 return result;
1427
1428 bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
1430 return arg.attrs && !arg.attrs.empty();
1431 });
1432 if (!hadAttrs) {
1433 attributionAttrs = nullptr;
1434 return result;
1435 }
1436
1439 for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
1440 if (!argument.attrs)
1442 else
1443 attributionAttrsVec.push_back(argument.attrs);
1444 }
1445 attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
1446 return result;
1447 }
1448
1449
1450
1451
1452
1453
1458 bool isVariadic;
1459
1460
1461 StringAttr nameAttr;
1464 return failure();
1465
1468 parser, false, entryArgs, isVariadic, resultTypes,
1469 resultAttrs)))
1470 return failure();
1471
1472 if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1473 return parser.emitError(signatureLocation)
1474 << "gpu.func requires named arguments";
1475
1476
1477
1479
1481 for (auto &arg : entryArgs)
1482 argTypes.push_back(arg.type);
1483 auto type = builder.getFunctionType(argTypes, resultTypes);
1486
1488 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
1489 getResAttrsAttrName(result.name));
1490
1491 Attribute workgroupAttributionAttrs;
1492
1493 if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
1494 entryArgs, workgroupAttributionAttrs)))
1495 return failure();
1496
1497
1498
1499 unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1500 result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
1502 if (workgroupAttributionAttrs)
1503 result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
1504 workgroupAttributionAttrs);
1505
1506 Attribute privateAttributionAttrs;
1507
1508 if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
1509 entryArgs, privateAttributionAttrs)))
1510 return failure();
1511 if (privateAttributionAttrs)
1512 result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
1513 privateAttributionAttrs);
1514
1515
1517 result.addAttribute(GPUDialect::getKernelFuncAttrName(),
1519
1520
1522 return failure();
1523
1524
1525
1526 auto *body = result.addRegion();
1527 return parser.parseRegion(*body, entryArgs);
1528 }
1529
// NOTE(review): the first signature line (embedded line 1531) and embedded
// lines 1539 and 1547 are missing from this listing. From the call sites
// further down this helper is named printAttributions -- confirm the full
// signature upstream.
// Prints ` keyword(v0 : type0, v1 : type1, ...)`; nothing is printed when
// `values` is empty.
1532 ArrayAttr attributes) {
1533 if (values.empty())
1534 return;
1535
1536 p << ' ' << keyword << '(';
1537 llvm::interleaveComma(
1538 llvm::enumerate(values), p, [&p, attributes](auto pair) {
1540 p << v << " : " << v.getType();
1541
// Look up the per-attribution attributes by position; `attributes` may be
// null or shorter than `values`.
1542 size_t attributionIndex = pair.index();
1543 DictionaryAttr attrs;
1544 if (attributes && attributionIndex < attributes.size())
1545 attrs = llvm::cast(attributes[attributionIndex]);
1546 if (attrs)
1548 });
1549 p << ')';
1550 }
1551
// NOTE(review): the signature (embedded line 1552) and embedded lines 1554,
// 1557 and 1568 are missing from this listing. Presumably this is the custom
// printer for gpu.func -- confirm upstream.
1553 p << ' ';
1555
1556 FunctionType type = getFunctionType();
1558 false,
1559 type.getResults());
1560
// Workgroup attributions are printed before private ones, then the kernel
// keyword when this function is a kernel.
1561 printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
1562 getWorkgroupAttribAttrs().value_or(nullptr));
1563 printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
1564 getPrivateAttribAttrs().value_or(nullptr));
1565 if (isKernel())
1566 p << ' ' << getKernelKeyword();
1567
// The attribute names listed here are passed alongside the op to the call
// whose first line is missing (presumably the attributes to elide -- TODO
// confirm).
1569 p, *this,
1570 {getNumWorkgroupAttributionsAttrName(),
1571 GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1572 getArgAttrsAttrName(), getResAttrsAttrName(),
1573 getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1574 p << ' ';
1575 p.printRegion(getBody(), false);
1576 }
1577
// NOTE(review): the first signature line (embedded line 1578) is missing and
// the template arguments of the casts were stripped by extraction.
// Returns the element at `index` of the array stored under `attrName` on `op`,
// or a null DictionaryAttr when the attribute is absent or too short.
1579 StringAttr attrName) {
1580 auto allAttrs = llvm::dyn_cast_or_null(op->getAttr(attrName));
1581 if (!allAttrs || index >= allAttrs.size())
1582 return DictionaryAttr();
1583 return llvm::cast(allAttrs[index]);
1584 }
1585
1586 DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
1587 return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
1588 }
1589
1590 DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
1591 return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
1592 }
1593
// NOTE(review): the first signature line and embedded lines 1596, 1598, 1602,
// 1604 and 1607 are missing from this listing; the branch bodies for the
// padding loop and the null-`value` case are therefore not visible.
// Rebuilds the array stored under `attrName` on `op` with slot `index`
// updated, then writes it back with setAttr.
1595 DictionaryAttr value, StringAttr attrName) {
1597 auto allAttrs = llvm::dyn_cast_or_null(op->getAttr(attrName));
// Start from the existing elements, if any.
1599 if (allAttrs)
1600 elements.append(allAttrs.begin(), allAttrs.end());
// Grow the list until `index` is a valid position.
1601 while (elements.size() <= index)
1603 if (!value)
1605 else
1606 elements[index] = value;
1608 op->setAttr(attrName, newValue);
1609 }
1610
1611 void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
1612 DictionaryAttr value) {
1613 setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
1614 }
1615
1616 void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
1617 DictionaryAttr value) {
1618 setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
1619 }
1620
// NOTE(review): the first signature line and embedded lines 1623 and 1625 are
// missing from this listing, so how `dict` is obtained and what is returned
// when it is null are not visible here.
// Returns the entry called `name` from `dict`.
1622 StringAttr name, StringAttr attrsName) {
1624 if (!dict)
1626 return dict.get(name);
1627 }
1628
// Looks up the attribute `name` on the workgroup attribution at `index`.
// `index` must be smaller than getNumWorkgroupAttributions() (asserted).
// NOTE(review): the first line of the return statement (embedded line 1633)
// is missing from this listing.
1629 Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
1630 StringAttr name) {
1631 assert(index < getNumWorkgroupAttributions() &&
1632 "index must map to a workgroup attribution");
1634 getWorkgroupAttribAttrsAttrName());
1635 }
1636
// Looks up the attribute `name` on the private attribution at `index`.
// `index` must be smaller than getNumPrivateAttributions() (asserted).
// NOTE(review): the first line of the return statement (embedded line 1641)
// is missing from this listing.
1637 Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
1638 StringAttr name) {
1639 assert(index < getNumPrivateAttributions() &&
1640 "index must map to a private attribution");
1642 getPrivateAttribAttrsAttrName());
1643 }
1644
// NOTE(review): the first signature line and embedded lines 1647-1649 and
// 1677 are missing from this listing (the declarations of `ctx`, `oldDict`
// and `elems`, and the statement that stores `newDict`).
// Sets, replaces or removes (when `value` is null) the entry `name` in a copy
// of the existing dictionary's entries and rebuilds a sorted DictionaryAttr.
1646 Attribute value, StringAttr attrsName) {
1650 if (oldDict)
1651 elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1652
1653 bool found = false;
1654 bool mustSort = true;
1655 for (unsigned i = 0, e = elems.size(); i < e; ++i) {
1656 if (elems[i].getName() == name) {
1657 found = true;
1658 if (!value) {
// Removal swaps the entry with the last one and pops it, which can break the
// ordering; mustSort therefore stays true on this path.
1659 std::swap(elems[i], elems[elems.size() - 1]);
1660 elems.pop_back();
1661 } else {
// In-place replacement keeps the existing order, so no re-sort is needed.
1662 mustSort = false;
1663 elems[i] = NamedAttribute(elems[i].getName(), value);
1664 }
1665 break;
1666 }
1667 }
1668 if (!found) {
// Removing an entry that does not exist is a no-op.
1669 if (!value)
1670 return;
1671 elems.emplace_back(name, value);
1672 }
1673 if (mustSort) {
1674 DictionaryAttr::sortInPlace(elems);
1675 }
1676 auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1678 }
1679
// Sets the attribute `name` on the workgroup attribution at `index`.
// `index` must be smaller than getNumWorkgroupAttributions() (asserted).
// NOTE(review): the end of the signature (embedded line 1681) and the first
// line of the forwarding call (embedded line 1684) are missing here.
1680 void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
1682 assert(index < getNumWorkgroupAttributions() &&
1683 "index must map to a workgroup attribution");
1685 getWorkgroupAttribAttrsAttrName());
1686 }
1687
// Sets the attribute `name` on the private attribution at `index`.
// `index` must be smaller than getNumPrivateAttributions() (asserted).
// NOTE(review): the end of the signature (embedded line 1689) and the first
// line of the forwarding call (embedded line 1692) are missing here.
1688 void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
1690 assert(index < getNumPrivateAttributions() &&
1691 "index must map to a private attribution");
1693 getPrivateAttribAttrsAttrName());
1694 }
1695
1696 LogicalResult GPUFuncOp::verifyType() {
1697 if (isKernel() && getFunctionType().getNumResults() != 0)
1698 return emitOpError() << "expected void return type for kernel function";
1699
1700 return success();
1701 }
1702
1703
// Verifies the body region: it must be non-empty, its entry block must have at
// least as many arguments as function arguments plus workgroup attributions,
// the leading block arguments must match the function input types, and the
// attributions must pass verifyAttributions for their address space.
// NOTE(review): embedded line 1726 (the start of the second verifyAttributions
// call) is missing and the element type of `funcArgTypes` was stripped.
1704 LogicalResult GPUFuncOp::verifyBody() {
1705 if (empty())
1706 return emitOpError() << "expected body with at least one block";
1707 unsigned numFuncArguments = getNumArguments();
1708 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1709 unsigned numBlockArguments = front().getNumArguments();
// Private attributions are not counted here, hence "at least".
1710 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1711 return emitOpError() << "expected at least "
1712 << numFuncArguments + numWorkgroupAttributions
1713 << " arguments to body region";
1714
1715 ArrayRef funcArgTypes = getFunctionType().getInputs();
1716 for (unsigned i = 0; i < numFuncArguments; ++i) {
1717 Type blockArgType = front().getArgument(i).getType();
1718 if (funcArgTypes[i] != blockArgType)
1719 return emitOpError() << "expected body region argument #" << i
1720 << " to be of type " << funcArgTypes[i] << ", got "
1721 << blockArgType;
1722 }
1723
1724 if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
1725 GPUDialect::getWorkgroupAddressSpace())) ||
1727 GPUDialect::getPrivateAddressSpace())))
1728 return failure();
1729
1730 return success();
1731 }
1732
1733
1734
1735
1736
// NOTE(review): the signature (embedded line 1737) and embedded line 1748
// (the start of the loop header) are missing; presumably this is the verifier
// of the gpu return op -- confirm upstream.
// Checks that the operand count and operand types match the result types of
// the enclosing GPUFuncOp.
1738 GPUFuncOp function = (*this)->getParentOfType();
1739
1740 FunctionType funType = function.getFunctionType();
1741
1742 if (funType.getNumResults() != getOperands().size())
1743 return emitOpError()
1744 .append("expected ", funType.getNumResults(), " result operands")
1745 .attachNote(function.getLoc())
1746 .append("return type declared here");
1747
1749 llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1750 auto [type, operand] = pair.value();
1751 if (type != operand.getType())
1752 return emitOpError() << "unexpected type `" << operand.getType()
1753 << "' for operand #" << pair.index();
1754 }
1755 return success();
1756 }
1757
1758
1759
1760
1761
// NOTE(review): the first signature line and embedded lines 1764-1766 and
// 1769 are missing from this listing; presumably a GPUModuleOp builder.
// `targets` is only stored into the properties when non-null.
1763 StringRef name, ArrayAttr targets,
1767 if (targets)
1768 props.targets = targets;
1770 props.offloadingHandler = offloadingHandler;
1771 }
1772
// NOTE(review): the signature (embedded lines 1773-1775) is missing from this
// listing. The visible body forwards to another build overload, passing a
// null ArrayAttr when `targets` is empty.
1776 build(builder, result, name,
1777 targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
1778 offloadingHandler);
1779 }
1780
1781 bool GPUModuleOp::hasTarget(Attribute target) {
1782 if (ArrayAttr targets = getTargetsAttr())
1783 return llvm::count(targets.getValue(), target);
1784 return false;
1785 }
1786
// NOTE(review): the signature (embedded line 1787) and the statements on
// embedded lines 1789-1790 are missing from this listing; only the reference
// into the op properties is visible.
1788 ArrayAttr &targetsAttr = getProperties().targets;
1791 }
1792
// NOTE(review): the signature (embedded line 1793) is missing and the
// template arguments of getAttrOfType/dyn_cast were stripped by extraction.
// Succeeds when there is no "targets" attribute; otherwise every element that
// casts to the (stripped) interface type must pass verifyTarget on this op.
1794 auto targets = getOperation()->getAttrOfType("targets");
1795
1796 if (!targets)
1797 return success();
1798
1799 for (auto target : targets) {
1800 if (auto verifyTargetAttr =
1801 llvm::dyn_cast(target)) {
1802 if (verifyTargetAttr.verifyTarget(getOperation()).failed())
1803 return failure();
1804 }
1805 }
1806 return success();
1807 }
1808
1809
1810
1811
// NOTE(review): the first signature line and embedded lines 1814-1816 are
// missing from this listing; presumably a builder for the binary op.
// When no offloading handler is supplied, a default one is created through
// builder.getAttr (its attribute type was stripped by extraction).
1813 Attribute offloadingHandler, ArrayAttr objects) {
1817 properties.objects = objects;
1818 if (offloadingHandler)
1819 properties.offloadingHandler = offloadingHandler;
1820 else
1821 properties.offloadingHandler = builder.getAttr(nullptr);
1822 }
1823
// NOTE(review): the signature (embedded lines 1824-1825) is missing from this
// listing. The visible body forwards to another build overload, passing a
// null ArrayAttr when `objects` is empty.
1826 build(builder, result, name, offloadingHandler,
1827 objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
1828 }
1829
// NOTE(review): the signature and the parsing conditions (embedded lines
// 1830-1833 and 1835) are missing from this listing.
// If nothing set `offloadingHandler`, a default attribute is created through
// the parser's builder (its attribute type was stripped by extraction).
1834 return failure();
1836 return failure();
1837 }
1838 if (!offloadingHandler)
1839 offloadingHandler = parser.getBuilder().getAttr(nullptr);
1840 return success();
1841 }
1842
// NOTE(review): the signature and the guarding condition (embedded lines
// 1843-1845) are missing from this listing.
// Prints the handler wrapped in angle brackets.
1846 printer << '<' << offloadingHandler << '>';
1847 }
1848
1849 //===----------------------------------------------------------------------===//
1850 // GPUMemcpyOp
1851 //===----------------------------------------------------------------------===//
1852
1853 LogicalResult MemcpyOp::verify() {
1854 auto srcType = getSrc().getType();
1855 auto dstType = getDst().getType();
1856
1857 if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1858 return emitOpError("arguments have incompatible element type");
1859
1860 if (failed(verifyCompatibleShape(srcType, dstType)))
1861 return emitOpError("arguments have incompatible shape");
1862
1863 return success();
1864 }
1865
1866 namespace {
1867
1868
1869
1870 struct EraseTrivialCopyOp : public OpRewritePattern {
1871 using OpRewritePattern::OpRewritePattern;
1872
1873 LogicalResult matchAndRewrite(MemcpyOp op,
1874 PatternRewriter &rewriter) const override {
1875 Value dest = op.getDst();
1876 Operation *destDefOp = dest.getDefiningOp();
1877 // `dest` must be defined by an op having Allocate memory effect in order to
1878 // perform the folding.
1879 if (!destDefOp ||
1880 !hasSingleEffectMemoryEffects::Allocate(destDefOp, dest))
1881 return failure();
1882 // We can erase `op` iff `dest` has no other use apart from its
1883 // use by `op` and dealloc ops.
1884 if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
1885 return user != op &&
1886 !hasSingleEffectMemoryEffects::Free(user, dest);
1887 }))
1888 return failure();
1889 // We can perform the folding if and only if op has a single async
1890 // dependency and produces an async token as result, or if it does not have
1891 // any async dependency and does not produce any async token result.
1892 if (op.getAsyncDependencies().size() > 1 ||
1893 ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
1894 (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
1895 return failure();
1896 rewriter.replaceOp(op, op.getAsyncDependencies());
1897 return success();
1898 }
1899 };
1900
1901 } // end anonymous namespace
1902
1903 void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1904 MLIRContext *context) {
1905 results.add(context);
1906 }
1907
1908 //===----------------------------------------------------------------------===//
1909 // GPU_SubgroupMmaLoadMatrixOp
1910 //===----------------------------------------------------------------------===//
1911
1912 LogicalResult SubgroupMmaLoadMatrixOp::verify() {
1913 auto srcType = getSrcMemref().getType();
1914 auto resType = getRes().getType();
1915 auto resMatrixType = llvm::castgpu::MMAMatrixType(resType);
1916 auto operand = resMatrixType.getOperand();
1917 auto srcMemrefType = llvm::cast(srcType);
1918
1919 if (!srcMemrefType.isLastDimUnitStride())
1920 return emitError(
1921 "expected source memref most minor dim must have unit stride");
1922
1923 if (operand != "AOp" && operand != "BOp" && operand != "COp")
1924 return emitError("only AOp, BOp and COp can be loaded");
1925
1926 return success();
1927 }
1928
1929 //===----------------------------------------------------------------------===//
1930 // GPU_SubgroupMmaStoreMatrixOp
1931 //===----------------------------------------------------------------------===//
1932
1933 LogicalResult SubgroupMmaStoreMatrixOp::verify() {
1934 auto srcType = getSrc().getType();
1935 auto dstType = getDstMemref().getType();
1936 auto srcMatrixType = llvm::castgpu::MMAMatrixType(srcType);
1937 auto dstMemrefType = llvm::cast(dstType);
1938
1939 if (!dstMemrefType.isLastDimUnitStride())
1940 return emitError(
1941 "expected destination memref most minor dim must have unit stride");
1942
1943 if (srcMatrixType.getOperand() != "COp")
1944 return emitError(
1945 "expected the operand matrix being stored to have 'COp' operand type");
1946
1947 return success();
1948 }
1949
1950 //===----------------------------------------------------------------------===//
1951 // GPU_SubgroupMmaComputeOp
1952 //===----------------------------------------------------------------------===//
1953
1954 LogicalResult SubgroupMmaComputeOp::verify() {
1955 enum OperandMap { A, B, C };
1956 SmallVector<MMAMatrixType, 3> opTypes;
1957 opTypes.push_back(llvm::cast(getOpA().getType()));
1958 opTypes.push_back(llvm::cast(getOpB().getType()));
1959 opTypes.push_back(llvm::cast(getOpC().getType()));
1960
1961 if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
1962 opTypes[C].getOperand() != "COp")
1963 return emitError("operands must be in the order AOp, BOp, COp");
1964
1965 ArrayRef<int64_t> aShape, bShape, cShape;
1966 aShape = opTypes[A].getShape();
1967 bShape = opTypes[B].getShape();
1968 cShape = opTypes[C].getShape();
1969
1970 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
1971 bShape[1] != cShape[1])
1972 return emitError("operand shapes do not satisfy matmul constraints");
1973
1974 return success();
1975 }
1976
1977 LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
1978 SmallVectorImpl<::mlir::OpFoldResult> &results) {
1979 return memref::foldMemRefCast(*this);
1980 }
1981
1982 LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
1983 SmallVectorImpl<::mlir::OpFoldResult> &results) {
1984 return memref::foldMemRefCast(*this);
1985 }
1986
1987 //===----------------------------------------------------------------------===//
1988 // GPU_WaitOp
1989 //===----------------------------------------------------------------------===//
1990
1991 namespace {
1992
1993
1994
1995
1996 struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern {
1997 public:
1998 using OpRewritePattern::OpRewritePattern;
1999
2000 LogicalResult matchAndRewrite(WaitOp op,
2001 PatternRewriter &rewriter) const final {
2002 auto predicate = [](Value value) {
2003 auto waitOp = value.getDefiningOp();
2004 return waitOp && waitOp->getNumOperands() == 0;
2005 };
2006 if (llvm::none_of(op.getAsyncDependencies(), predicate))
2007 return failure();
2008 SmallVector validOperands;
2009 for (Value operand : op->getOperands()) {
2010 if (predicate(operand))
2011 continue;
2012 validOperands.push_back(operand);
2013 }
2014 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2015 return success();
2016 }
2017 };
2018
2019
2020
2021
2022
2023
2024
2025
2026 struct SimplifyGpuWaitOp : public OpRewritePattern {
2027 public:
2028 using OpRewritePattern::OpRewritePattern;
2029
2030 LogicalResult matchAndRewrite(WaitOp op,
2031 PatternRewriter &rewriter) const final {
2032 // Erase gpu.wait ops that neither have any async dependencies nor return
2033 // any async token.
2034 if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2035 rewriter.eraseOp(op);
2036 return success();
2037 }
2038 // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2039 if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2040 op.getAsyncToken()) {
2041 rewriter.replaceOp(op, op.getAsyncDependencies());
2042 return success();
2043 }
2044 // Erase %t = gpu.wait async ... ops, where %t has no uses.
2045 if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2046 rewriter.eraseOp(op);
2047 return success();
2048 }
2049 return failure();
2050 }
2051 };
2052
2053 } // end anonymous namespace
2054
2055 void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2056 MLIRContext *context) {
2057 results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2058 }
2059
2060 //===----------------------------------------------------------------------===//
2061 // GPU_AllocOp
2062 //===----------------------------------------------------------------------===//
2063
2064 LogicalResult AllocOp::verify() {
2065 auto memRefType = llvm::cast(getMemref().getType());
2066
2067 if (getDynamicSizes().size() != memRefType.getNumDynamicDims())
2068 return emitOpError("dimension operand count does not equal memref "
2069 "dynamic dimension count");
2070
2071 unsigned numSymbols = 0;
2072 if (!memRefType.getLayout().isIdentity())
2073 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2074 if (getSymbolOperands().size() != numSymbols) {
2075 return emitOpError(
2076 "symbol operand count does not equal memref symbol count");
2077 }
2078
2079 return success();
2080 }
2081
2082 namespace {
2083
2084
2085
2086 struct SimplifyDimOfAllocOp : public OpRewritePatternmemref::DimOp {
2087 using OpRewritePatternmemref::DimOp::OpRewritePattern;
2088
2089 LogicalResult matchAndRewrite(memref::DimOp dimOp,
2090 PatternRewriter &rewriter) const override {
2091 std::optional<int64_t> index = dimOp.getConstantIndex();
2092 if (!index)
2093 return failure();
2094
2095 auto memrefType = llvm::dyn_cast(dimOp.getSource().getType());
2096 if (!memrefType || index.value() >= memrefType.getRank() ||
2097 !memrefType.isDynamicDim(index.value()))
2098 return failure();
2099
2100 auto alloc = dimOp.getSource().getDefiningOp();
2101 if (!alloc)
2102 return failure();
2103
2104 Value substituteOp = *(alloc.getDynamicSizes().begin() +
2105 memrefType.getDynamicDimIndex(index.value()));
2106 rewriter.replaceOp(dimOp, substituteOp);
2107 return success();
2108 }
2109 };
2110
2111 } // namespace
2112
2113 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2114 MLIRContext *context) {
2115 results.add(context);
2116 }
2117
2118 //===----------------------------------------------------------------------===//
2119 // GPU object attribute
2120 //===----------------------------------------------------------------------===//
2121
2122 LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2123 Attribute target, CompilationTarget format,
2124 StringAttr object, DictionaryAttr properties,
2125 KernelTableAttr kernels) {
2126 if (!target)
2127 return emitError() << "the target attribute cannot be null";
2128 if (target.hasPromiseOrImplementsInterface())
2129 return success();
2130 return emitError() << "the target attribute must implement or promise the "
2131 "`gpu::TargetAttrInterface`";
2132 }
2133
2134 namespace {
2135 ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2136 StringAttr &object) {
2137 std::optional formatResult;
2138 StringRef enumKeyword;
2139 auto loc = odsParser.getCurrentLocation();
2140 if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2141 formatResult = CompilationTarget::Fatbin;
2142 if (!formatResult &&
2143 (formatResult =
2144 gpu::symbolizeEnumgpu::CompilationTarget(enumKeyword)) &&
2145 odsParser.parseEqual())
2146 return odsParser.emitError(loc, "expected an equal sign");
2147 if (!formatResult)
2148 return odsParser.emitError(loc, "expected keyword for GPU object format");
2149 FailureOr objectResult =
2150 FieldParser::parse(odsParser);
2151 if (failed(objectResult))
2152 return odsParser.emitError(odsParser.getCurrentLocation(),
2153 "failed to parse GPU_ObjectAttr parameter "
2154 "'object' which is to be a `StringAttr`");
2155 format = *formatResult;
2156 object = *objectResult;
2157 return success();
2158 }
2159
2160 void printObject(AsmPrinter &odsParser, CompilationTarget format,
2161 StringAttr object) {
2162 if (format != CompilationTarget::Fatbin)
2163 odsParser << stringifyEnum(format) << " = ";
2164 odsParser << object;
2165 }
2166 } // namespace
2167
2168 //===----------------------------------------------------------------------===//
2169 // GPU select object attribute
2170 //===----------------------------------------------------------------------===//
2171
2172 LogicalResult
2173 gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2174 Attribute target) {
2175 // Check `target`, it can be null, an integer attr or a GPU Target attribute.
2176 if (target) {
2177 if (auto intAttr = mlir::dyn_cast(target)) {
2178 if (intAttr.getInt() < 0) {
2179 return emitError() << "the object index must be positive";
2180 }
2181 } else if (!target.hasPromiseOrImplementsInterface()) {
2182 return emitError()
2183 << "the target attribute must be a GPU Target attribute";
2184 }
2185 }
2186 return success();
2187 }
2188
2189 //===----------------------------------------------------------------------===//
2190 // DynamicSharedMemoryOp
2191 //===----------------------------------------------------------------------===//
2192
2193 LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2194 if (!getOperation()->getParentWithTraitOpTrait::SymbolTable())
2195 return emitOpError() << "must be inside an op with symbol table";
2196
2197 MemRefType memrefType = getResultMemref().getType();
2198 // Check address space
2199 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2200 return emitOpError() << "address space must be "
2201 << gpu::AddressSpaceAttr::getMnemonic() << "<"
2202 << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2203 }
2204 if (memrefType.hasStaticShape()) {
2205 return emitOpError() << "result memref type must be memref<?xi8, "
2206 "#gpu.address_space>";
2207 }
2208 return success();
2209 }
2210
2211 //===----------------------------------------------------------------------===//
2212 // GPU WarpExecuteOnLane0Op
2213 //===----------------------------------------------------------------------===//
2214
2215 void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
2216 p << "(" << getLaneid() << ")";
2217
2218 SmallVector coreAttr = {getWarpSizeAttrName()};
2219 auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
2220 p << "[" << llvm::cast(warpSizeAttr).getInt() << "]";
2221
2222 if (!getArgs().empty())
2223 p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
2224 if (!getResults().empty())
2225 p << " -> (" << getResults().getTypes() << ')';
2226 p << " ";
2227 p.printRegion(getRegion(),
2228 /*printEntryBlockArgs=*/true,
2229 /*printBlockTerminators=*/!getResults().empty());
2230 p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
2231 }
2232
2233 ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
2234 OperationState &result) {
2235 // Create the region.
2236 result.regions.reserve(1);
2237 Region *warpRegion = result.addRegion();
2238
2239 auto &builder = parser.getBuilder();
2240 OpAsmParser::UnresolvedOperand laneId;
2241
2242 // Parse predicate operand.
2243 if (parser.parseLParen() ||
2244 parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
2245 parser.parseRParen())
2246 return failure();
2247
2248 int64_t warpSize;
2249 if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
2250 parser.parseRSquare())
2251 return failure();
2252 result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
2253 builder.getContext())),
2254 builder.getI64IntegerAttr(warpSize));
2255
2256 if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
2257 return failure();
2258
2259 llvm::SMLoc inputsOperandsLoc;
2260 SmallVectorOpAsmParser::UnresolvedOperand inputsOperands;
2261 SmallVector inputTypes;
2262 if (succeeded(parser.parseOptionalKeyword("args"))) {
2263 if (parser.parseLParen())
2264 return failure();
2265
2266 inputsOperandsLoc = parser.getCurrentLocation();
2267 if (parser.parseOperandList(inputsOperands) ||
2268 parser.parseColonTypeList(inputTypes) || parser.parseRParen())
2269 return failure();
2270 }
2271 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
2272 result.operands))
2273 return failure();
2274
2275 // Parse optional results type list.
2276 if (parser.parseOptionalArrowTypeList(result.types))
2277 return failure();
2278 // Parse the region.
2279 if (parser.parseRegion(*warpRegion, /*arguments=*/{},
2280 /*argTypes=*/{}))
2281 return failure();
2282 WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
2283
2284 // Parse the optional attribute list.
2285 if (parser.parseOptionalAttrDict(result.attributes))
2286 return failure();
2287 return success();
2288 }
2289
2290 void WarpExecuteOnLane0Op::getSuccessorRegions(
2291 RegionBranchPoint point, SmallVectorImpl ®ions) {
2292 if (!point.isParent()) {
2293 regions.push_back(RegionSuccessor(getResults()));
2294 return;
2295 }
2296
2297 // The warp region is always executed
2298 regions.push_back(RegionSuccessor(&getWarpRegion()));
2299 }
2300
2301 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2302 TypeRange resultTypes, Value laneId,
2303 int64_t warpSize) {
2304 build(builder, result, resultTypes, laneId, warpSize,
2305 /*operands=*/std::nullopt, /*argTypes=*/std::nullopt);
2306 }
2307
2308 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2309 TypeRange resultTypes, Value laneId,
2310 int64_t warpSize, ValueRange args,
2311 TypeRange blockArgTypes) {
2312 result.addOperands(laneId);
2313 result.addAttribute(getAttributeNames()[0],
2314 builder.getI64IntegerAttr(warpSize));
2315 result.addTypes(resultTypes);
2316 result.addOperands(args);
2317 assert(args.size() == blockArgTypes.size());
2318 OpBuilder::InsertionGuard guard(builder);
2319 Region *warpRegion = result.addRegion();
2320 Block *block = builder.createBlock(warpRegion);
2321 for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
2322 block->addArgument(type, arg.getLoc());
2323 }
2324
2325
2326
2327 static LogicalResult verifyDistributedType(Type expanded, Type distributed,
2328 int64_t warpSize, Operation *op) {
2329 // If the types matches there is no distribution.
2330 if (expanded == distributed)
2331 return success();
2332 auto expandedVecType = llvm::dyn_cast(expanded);
2333 auto distributedVecType = llvm::dyn_cast(distributed);
2334 if (!expandedVecType || !distributedVecType)
2335 return op->emitOpError("expected vector type for distributed operands.");
2336 if (expandedVecType.getRank() != distributedVecType.getRank() ||
2337 expandedVecType.getElementType() != distributedVecType.getElementType())
2338 return op->emitOpError(
2339 "expected distributed vectors to have same rank and element type.");
2340
2341 SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
2342 for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2343 int64_t eDim = expandedVecType.getDimSize(i);
2344 int64_t dDim = distributedVecType.getDimSize(i);
2345 if (eDim == dDim)
2346 continue;
2347 if (eDim % dDim != 0)
2348 return op->emitOpError()
2349 << "expected expanded vector dimension #" << i << " (" << eDim
2350 << ") to be a multipler of the distributed vector dimension ("
2351 << dDim << ")";
2352 scales[i] = eDim / dDim;
2353 }
2354 if (std::accumulate(scales.begin(), scales.end(), 1,
2355 std::multiplies<int64_t>()) != warpSize)
2356 return op->emitOpError()
2357 << "incompatible distribution dimensions from " << expandedVecType
2358 << " to " << distributedVecType << " with warp size = " << warpSize;
2359
2360 return success();
2361 }
2362
2363 LogicalResult WarpExecuteOnLane0Op::verify() {
2364 if (getArgs().size() != getWarpRegion().getNumArguments())
2365 return emitOpError(
2366 "expected same number op arguments and block arguments.");
2367 auto yield =
2368 cast(getWarpRegion().getBlocks().begin()->getTerminator());
2369 if (yield.getNumOperands() != getNumResults())
2370 return emitOpError(
2371 "expected same number of yield operands and return values.");
2372 int64_t warpSize = getWarpSize();
2373 for (auto [regionArg, arg] :
2374 llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
2375 if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
2376 warpSize, getOperation())))
2377 return failure();
2378 }
2379 for (auto [yieldOperand, result] :
2380 llvm::zip_equal(yield.getOperands(), getResults())) {
2381 if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
2382 warpSize, getOperation())))
2383 return failure();
2384 }
2385 return success();
2386 }
2387 bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
2388 return succeeded(
2389 verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
2390 }
2391
2392 //===----------------------------------------------------------------------===//
2393 // GPU KernelMetadataAttr
2394 //===----------------------------------------------------------------------===//
2395
2396 KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2397 DictionaryAttr metadata) {
2398 assert(kernel && "invalid kernel");
2399 return get(kernel.getNameAttr(), kernel.getFunctionType(),
2400 kernel.getAllArgAttrs(), metadata);
2401 }
2402
2403 KernelMetadataAttr
2404 KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2405 FunctionOpInterface kernel,
2406 DictionaryAttr metadata) {
2407 assert(kernel && "invalid kernel");
2408 return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
2409 kernel.getAllArgAttrs(), metadata);
2410 }
2411
2412 KernelMetadataAttr
2413 KernelMetadataAttr::appendMetadata(ArrayRef attrs) const {
2414 if (attrs.empty())
2415 return *this;
2416 NamedAttrList attrList;
2417 if (DictionaryAttr dict = getMetadata())
2418 attrList.append(dict);
2419 attrList.append(attrs);
2420 return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
2421 attrList.getDictionary(getContext()));
2422 }
2423
2424 LogicalResult
2425 KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2426 StringAttr name, Type functionType,
2427 ArrayAttr argAttrs, DictionaryAttr metadata) {
2428 if (name.empty())
2429 return emitError() << "the kernel name can't be empty";
2430 if (argAttrs) {
2431 if (llvm::any_of(argAttrs, [](Attribute attr) {
2432 return !llvm::isa(attr);
2433 }))
2434 return emitError()
2435 << "all attributes in the array must be a dictionary attribute";
2436 }
2437 return success();
2438 }
2439
2440 //===----------------------------------------------------------------------===//
2441 // GPU KernelTableAttr
2442 //===----------------------------------------------------------------------===//
2443
2444 KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2445 ArrayRef kernels,
2446 bool isSorted) {
2447 // Note that `is_sorted` is always only invoked once even with assertions ON.
2448 assert((!isSorted || llvm::is_sorted(kernels)) &&
2449 "expected a sorted kernel array");
2450 // Immediately return the attribute if the array is sorted.
2451 if (isSorted || llvm::is_sorted(kernels))
2452 return Base::get(context, kernels);
2453 // Sort the array.
2454 SmallVector kernelsTmp(kernels);
2455 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2456 return Base::get(context, kernelsTmp);
2457 }
2458
2459 KernelTableAttr KernelTableAttr::getChecked(
2460 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
2461 ArrayRef kernels, bool isSorted) {
2462 // Note that `is_sorted` is always only invoked once even with assertions ON.
2463 assert((!isSorted || llvm::is_sorted(kernels)) &&
2464 "expected a sorted kernel array");
2465 // Immediately return the attribute if the array is sorted.
2466 if (isSorted || llvm::is_sorted(kernels))
2467 return Base::getChecked(emitError, context, kernels);
2468 // Sort the array.
2469 SmallVector kernelsTmp(kernels);
2470 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2471 return Base::getChecked(emitError, context, kernelsTmp);
2472 }
2473
2474 LogicalResult
2475 KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2476 ArrayRef kernels) {
2477 if (kernels.size() < 2)
2478 return success();
2479 // Check that the kernels are uniquely named.
2480 if (std::adjacent_find(kernels.begin(), kernels.end(),
2481 [](KernelMetadataAttr l, KernelMetadataAttr r) {
2482 return l.getName() == r.getName();
2483 }) != kernels.end()) {
2484 return emitError() << "expected all kernels to be uniquely named";
2485 }
2486 return success();
2487 }
2488
2489 KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
2490 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2491 return found ? *iterator : KernelMetadataAttr();
2492 }
2493
2494 KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
2495 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2496 return found ? *iterator : KernelMetadataAttr();
2497 }
2498
2499 //===----------------------------------------------------------------------===//
2500 // GPU target options
2501 //===----------------------------------------------------------------------===//
2502
2503 TargetOptions::TargetOptions(
2504 StringRef toolkitPath, ArrayRef librariesToLink,
2505 StringRef cmdOptions, StringRef elfSection,
2506 CompilationTarget compilationTarget,
2507 function_ref<SymbolTable *()> getSymbolTableCallback,
2508 function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2509 function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2510 function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2511 function_ref<void(StringRef)> isaCallback)
2512 : TargetOptions(TypeID::get(), toolkitPath, librariesToLink,
2513 cmdOptions, elfSection, compilationTarget,
2514 getSymbolTableCallback, initialLlvmIRCallback,
2515 linkedLlvmIRCallback, optimizedLlvmIRCallback,
2516 isaCallback) {}
2517
2518 TargetOptions::TargetOptions(
2519 TypeID typeID, StringRef toolkitPath, ArrayRef librariesToLink,
2520 StringRef cmdOptions, StringRef elfSection,
2521 CompilationTarget compilationTarget,
2522 function_ref<SymbolTable *()> getSymbolTableCallback,
2523 function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2524 function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2525 function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2526 function_ref<void(StringRef)> isaCallback)
2527 : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
2528 cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
2529 compilationTarget(compilationTarget),
2530 getSymbolTableCallback(getSymbolTableCallback),
2531 initialLlvmIRCallback(initialLlvmIRCallback),
2532 linkedLlvmIRCallback(linkedLlvmIRCallback),
2533 optimizedLlvmIRCallback(optimizedLlvmIRCallback),
2534 isaCallback(isaCallback), typeID(typeID) {}
2535
2536 TypeID TargetOptions::getTypeID() const { return typeID; }
2537
2538 StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2539
2540 ArrayRef TargetOptions::getLibrariesToLink() const {
2541 return librariesToLink;
2542 }
2543
2544 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2545
2546 StringRef TargetOptions::getELFSection() const { return elfSection; }
2547
2548 SymbolTable *TargetOptions::getSymbolTable() const {
2549 return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2550 }
2551
2552 function_ref<void(llvm::Module &)>
2553 TargetOptions::getInitialLlvmIRCallback() const {
2554 return initialLlvmIRCallback;
2555 }
2556
2557 function_ref<void(llvm::Module &)>
2558 TargetOptions::getLinkedLlvmIRCallback() const {
2559 return linkedLlvmIRCallback;
2560 }
2561
2562 function_ref<void(llvm::Module &)>
2563 TargetOptions::getOptimizedLlvmIRCallback() const {
2564 return optimizedLlvmIRCallback;
2565 }
2566
2567 function_ref<void(StringRef)> TargetOptions::getISACallback() const {
2568 return isaCallback;
2569 }
2570
2571 CompilationTarget TargetOptions::getCompilationTarget() const {
2572 return compilationTarget;
2573 }
2574
2575 CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2576 return CompilationTarget::Fatbin;
2577 }
2578
2579 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2580 TargetOptions::tokenizeCmdOptions(const std::string &cmdOptions) {
2581 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2582 llvm::StringSaver stringSaver(options.first);
2583 StringRef opts = cmdOptions;
2584 // For a correct tokenization of the command line options `opts` must be
2585 // unquoted, otherwise the tokenization function returns a single string: the
2586 // unquoted `cmdOptions` -which is not the desired behavior.
2587 // Remove any quotes if they are at the beginning and end of the string:
2588 if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2589 opts.consume_front("\""), opts.consume_back("\"");
2590 if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2591 opts.consume_front("'"), opts.consume_back("'");
2592 #ifdef _WIN32
2593 llvm:๐:TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2594 false);
2595 #else
2596 llvm:๐:TokenizeGNUCommandLine(opts, stringSaver, options.second,
2597 false);
2598 #endif
2600 }
2601
2602 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2605 }
2606
2607 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2609 size_t startPos = cmdOptions.find(startsWith);
2610 if (startPos == std:๐งต:npos)
2612
2613 auto tokenized =
2616 return tokenized;
2617 }
2618
2620
2621 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2622 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2623
2624 #define GET_ATTRDEF_CLASSES
2625 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2626
2627 #define GET_OP_CLASSES
2628 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2629
2630 #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices)
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values)
Prints a GPU function memory attribution.
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
static std::string getSparseHandleKeyword(SparseHandleKind kind)
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
union mlir::linalg::@1203::ArityGroupAndKind::Kind kind
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This is a utility class for mapping one set of IR entities to another.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
static StringRef getOperandSegmentSizeAttr()
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static ConcreteT get(MLIRContext *ctx, Args &&...args)
Get or create a new ConcreteT instance within the ctx.
ImplType * getImpl() const
Utility for easy access to the storage instance.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is a valid MMAMatrixType elementType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction Invariants.
unsigned getNumDims() const
Get number of dims.
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substr of the command line options that starts with startsWith and ends...
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the con...
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
llvm::StringMap< llvm::SmallString< 8 > > dictionary
A dictionary stores a mapping of template variable names to their assigned string values.
Kind
An enumeration of the kinds of predicates.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Utility class for the GPU dialect to represent triples of Values accessible through ....