// MLIR: lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
// NOTE(review): this file is a scrape of an HTML source listing of the
// original file. The leading number on each remaining line is gutter
// residue from that page, many source lines were lost in extraction
// (the gutter numbers jump), and template angle brackets / HTML entities
// were mangled. Treat as a damaged copy, not buildable source.
17
41
42 #include "llvm/ADT/STLExtras.h"
43 #include "llvm/Support/Error.h"
44 #include "llvm/Support/FormatVariadic.h"
45
46 #define DEBUG_TYPE "gpu-to-llvm"
47
48 namespace mlir {
49 #define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
50 #include "mlir/Conversion/Passes.h.inc"
51 }
52
53 using namespace mlir;
54
55 namespace {
// Pass that lowers host-side GPU dialect operations to calls into the GPU
// runtime wrapper library (see the mgpu* call builders below). Generated
// base class comes from GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS above.
// NOTE(review): "®istry" is an HTML-entity mangling of "&registry", and the
// base-class template argument plus one registry line (gutter 62) were lost
// in extraction -- confirm against the upstream file before editing.
56 class GpuToLLVMConversionPass
57 : public impl::GpuToLLVMConversionPassBase {
58 public:
59 using Base::Base;
// Registers dialects produced during lowering, on top of what the
// generated Base declares.
60 void getDependentDialects(DialectRegistry &registry) const final {
61 Base::getDependentDialects(registry);
63 }
64
// Defined further down in this file.
65 void runOnOperation() override;
66 };
67
68 template
70 public:
71 explicit ConvertOpToGpuRuntimeCallPattern(
74
75 protected:
79 if (type.hasStaticShape())
81 rewriter, loc, indexType, type.getNumElements());
82
83 uint64_t rank = type.getRank();
84 Value numElements = desc.size(rewriter, loc, 0);
85 for (unsigned i = 1; i < rank; i++)
86 numElements = rewriter.createLLVM::MulOp(
87 loc, numElements, desc.size(rewriter, loc, i));
88 return numElements;
89 }
90
91 MLIRContext *context = &this->getTypeConverter()->getContext();
92
101 context, this->getTypeConverter()->getPointerBitwidth(0));
102
104 "mgpuStreamCreate", llvmPointerType , {}};
106 "mgpuStreamDestroy", llvmVoidType, {llvmPointerType }};
108 "mgpuStreamSynchronize",
109 llvmVoidType,
110 {llvmPointerType }};
112 "mgpuStreamWaitEvent",
113 llvmVoidType,
114 {llvmPointerType , llvmPointerType }};
116 "mgpuEventCreate", llvmPointerType , {}};
118 "mgpuEventDestroy", llvmVoidType, {llvmPointerType }};
120 "mgpuEventSynchronize",
121 llvmVoidType,
122 {llvmPointerType }};
124 "mgpuEventRecord",
125 llvmVoidType,
126 {llvmPointerType , llvmPointerType }};
128 "mgpuMemHostRegisterMemRef",
129 llvmVoidType,
130 {llvmIntPtrType ,
131 llvmPointerType ,
132 llvmIntPtrType }};
134 "mgpuMemHostUnregisterMemRef",
135 llvmVoidType,
136 {llvmIntPtrType ,
137 llvmPointerType ,
138 llvmIntPtrType }};
140 "mgpuMemAlloc",
141 llvmPointerType ,
142 {llvmIntPtrType ,
143 llvmPointerType ,
144 llvmInt8Type }};
146 "mgpuMemFree",
147 llvmVoidType,
148 {llvmPointerType , llvmPointerType }};
150 "mgpuMemcpy",
151 llvmVoidType,
152 {llvmPointerType , llvmPointerType ,
153 llvmIntPtrType ,
154 llvmPointerType }};
156 "mgpuMemset16",
157 llvmVoidType,
158 {llvmPointerType ,
159 llvmInt16Type ,
160 llvmIntPtrType ,
161 llvmPointerType }};
163 "mgpuMemset32",
164 llvmVoidType,
165 {llvmPointerType , llvmInt32Type ,
166 llvmIntPtrType ,
167 llvmPointerType }};
169 "mgpuSetDefaultDevice",
170 llvmVoidType,
171 {llvmInt32Type }};
173 "mgpuCreateDnVec",
174 llvmPointerType,
175 {llvmIntPtrType, llvmPointerType, llvmInt32Type,
176 llvmPointerType }};
178 "mgpuDestroyDnVec",
179 llvmVoidType,
180 {llvmPointerType, llvmPointerType }};
182 "mgpuCreateDnMat",
183 llvmPointerType,
184 {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
185 llvmPointerType }};
187 "mgpuDestroyDnMat",
188 llvmVoidType,
189 {llvmPointerType, llvmPointerType }};
191 "mgpuCreateCoo",
192 llvmPointerType,
193 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
194 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
195 llvmPointerType }};
197 "mgpuCreateCooAoS",
198 llvmPointerType,
199 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
200 llvmPointerType, llvmInt32Type, llvmInt32Type,
201 llvmPointerType }};
203 "mgpuCreateCsr",
204 llvmPointerType,
205 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
206 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
207 llvmInt32Type, llvmPointerType }};
209 "mgpuCreateCsc",
210 llvmPointerType,
211 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
212 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
213 llvmInt32Type, llvmPointerType }};
215 "mgpuCreateBsr",
216 llvmPointerType,
217 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
218 llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
219 llvmInt32Type, llvmInt32Type, llvmInt32Type,
220 llvmPointerType }};
222 "mgpuDestroySpMat",
223 llvmVoidType,
224 {llvmPointerType, llvmPointerType }};
226 "mgpuSpMVBufferSize",
227 llvmIntPtrType,
228 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
229 llvmInt32Type, llvmPointerType }};
231 "mgpuSpMV",
232 llvmVoidType,
233 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
234 llvmInt32Type, llvmPointerType, llvmPointerType }};
236 "mgpuSpMMBufferSize",
237 llvmIntPtrType,
238 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
239 llvmPointerType, llvmInt32Type, llvmPointerType }};
241 "mgpuSpMM",
242 llvmVoidType,
243 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
244 llvmPointerType, llvmInt32Type, llvmPointerType,
245 llvmPointerType }};
247 "mgpuSDDMMBufferSize",
248 llvmIntPtrType,
249 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
250 llvmPointerType, llvmInt32Type, llvmPointerType }};
252 "mgpuSDDMM",
253 llvmVoidType,
254 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
255 llvmPointerType, llvmInt32Type, llvmPointerType,
256 llvmPointerType }};
258 "mgpuCreateCuSparseLtDnMat",
259 llvmVoidType,
260 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
261 llvmInt32Type, llvmPointerType }};
263 "mgpuDestroyCuSparseLtSpMat",
264 llvmVoidType,
265 {llvmPointerType, llvmPointerType }};
267 "mgpuDestroyCuSparseLtDnMat",
268 llvmVoidType,
269 {llvmPointerType, llvmPointerType }};
271 "mgpuCusparseLtCreate2To4SpMat",
272 llvmVoidType,
273 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
274 llvmInt32Type, llvmPointerType }};
276 "mgpuCuSparseLtSpMMBufferSize",
277 llvmVoidType,
278 {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
279 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
280 llvmPointerType }};
282 "mgpuCuSparseLtSpMM",
283 llvmVoidType,
284 {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
285 llvmPointerType, llvmPointerType, llvmPointerType }};
287 "mgpuSpGEMMCreateDescr",
288 llvmPointerType,
289 {llvmPointerType }};
291 "mgpuSpGEMMDestroyDescr",
292 llvmVoidType,
293 {llvmPointerType , llvmPointerType }};
295 "mgpuSpGEMMWorkEstimation",
296 llvmIntPtrType,
297 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
298 llvmPointerType , llvmPointerType , llvmPointerType ,
299 llvmInt32Type , llvmIntPtrType , llvmPointerType ,
300 llvmPointerType }};
302 "mgpuSpGEMMCompute",
303 llvmIntPtrType,
304 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
305 llvmPointerType , llvmPointerType , llvmPointerType ,
306 llvmInt32Type , llvmIntPtrType , llvmPointerType ,
307 llvmPointerType }};
309 "mgpuSpGEMMCopy",
310 llvmVoidType,
311 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
312 llvmPointerType , llvmPointerType , llvmPointerType ,
313 llvmInt32Type , llvmPointerType }};
315 "mgpuSpMatGetSize",
316 llvmVoidType,
317 {llvmPointerType , llvmPointerType , llvmPointerType ,
318 llvmPointerType , llvmPointerType }};
320 "mgpuSetCsrPointers",
321 llvmVoidType,
322 {llvmPointerType , llvmPointerType ,
323 llvmPointerType , llvmPointerType ,
324 llvmPointerType }};
325 };
326
327
328
329 class ConvertHostRegisterOpToGpuRuntimeCallPattern
330 : public ConvertOpToGpuRuntimeCallPatterngpu::HostRegisterOp {
331 public:
332 ConvertHostRegisterOpToGpuRuntimeCallPattern(
334 : ConvertOpToGpuRuntimeCallPatterngpu::HostRegisterOp(typeConverter) {}
335
336 private:
337 LogicalResult
338 matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
340 };
341
342 class ConvertHostUnregisterOpToGpuRuntimeCallPattern
343 : public ConvertOpToGpuRuntimeCallPatterngpu::HostUnregisterOp {
344 public:
345 ConvertHostUnregisterOpToGpuRuntimeCallPattern(
347 : ConvertOpToGpuRuntimeCallPatterngpu::HostUnregisterOp(typeConverter) {
348 }
349
350 private:
351 LogicalResult
352 matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
354 };
355
356
357
358 class ConvertAllocOpToGpuRuntimeCallPattern
359 : public ConvertOpToGpuRuntimeCallPatterngpu::AllocOp {
360 public:
361 ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
362 : ConvertOpToGpuRuntimeCallPatterngpu::AllocOp(typeConverter) {}
363
364 private:
365 LogicalResult
366 matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
368 };
369
370
371
372 class ConvertDeallocOpToGpuRuntimeCallPattern
373 : public ConvertOpToGpuRuntimeCallPatterngpu::DeallocOp {
374 public:
375 ConvertDeallocOpToGpuRuntimeCallPattern(
377 : ConvertOpToGpuRuntimeCallPatterngpu::DeallocOp(typeConverter) {}
378
379 private:
380 LogicalResult
381 matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
383 };
384
385 class ConvertAsyncYieldToGpuRuntimeCallPattern
386 : public ConvertOpToGpuRuntimeCallPatternasync::YieldOp {
387 public:
388 ConvertAsyncYieldToGpuRuntimeCallPattern(
390 : ConvertOpToGpuRuntimeCallPatternasync::YieldOp(typeConverter) {}
391
392 private:
393 LogicalResult
394 matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
396 };
397
398
399
400 class ConvertWaitOpToGpuRuntimeCallPattern
401 : public ConvertOpToGpuRuntimeCallPatterngpu::WaitOp {
402 public:
403 ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
404 : ConvertOpToGpuRuntimeCallPatterngpu::WaitOp(typeConverter) {}
405
406 private:
407 LogicalResult
408 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
410 };
411
412
413
414 class ConvertWaitAsyncOpToGpuRuntimeCallPattern
415 : public ConvertOpToGpuRuntimeCallPatterngpu::WaitOp {
416 public:
417 ConvertWaitAsyncOpToGpuRuntimeCallPattern(
419 : ConvertOpToGpuRuntimeCallPatterngpu::WaitOp(typeConverter) {}
420
421 private:
422 LogicalResult
423 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
425 };
426
427
428 class LegalizeLaunchFuncOpPattern
429 : public ConvertOpToGpuRuntimeCallPatterngpu::LaunchFuncOp {
430 public:
431 LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
432 bool kernelBarePtrCallConv,
433 bool kernelIntersperseSizeCallConv)
434 : ConvertOpToGpuRuntimeCallPatterngpu::LaunchFuncOp(typeConverter),
435 kernelBarePtrCallConv(kernelBarePtrCallConv),
436 kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}
437
438 private:
439 LogicalResult
440 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
442
443 bool kernelBarePtrCallConv;
444 bool kernelIntersperseSizeCallConv;
445 };
446
447
448
449 class ConvertMemcpyOpToGpuRuntimeCallPattern
450 : public ConvertOpToGpuRuntimeCallPatterngpu::MemcpyOp {
451 public:
452 ConvertMemcpyOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
453 : ConvertOpToGpuRuntimeCallPatterngpu::MemcpyOp(typeConverter) {}
454
455 private:
456 LogicalResult
457 matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
459 };
460
461
462
463 class ConvertMemsetOpToGpuRuntimeCallPattern
464 : public ConvertOpToGpuRuntimeCallPatterngpu::MemsetOp {
465 public:
466 ConvertMemsetOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
467 : ConvertOpToGpuRuntimeCallPatterngpu::MemsetOp(typeConverter) {}
468
469 private:
470 LogicalResult
471 matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
473 };
474
475
476
477 class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
478 : public ConvertOpToGpuRuntimeCallPatterngpu::SetDefaultDeviceOp {
479 public:
480 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
482 : ConvertOpToGpuRuntimeCallPatterngpu::SetDefaultDeviceOp(
483 typeConverter) {}
484
485 LogicalResult
486 matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
488 };
489
490
491
/// Declares a conversion pattern class `Convert<op_name>ToGpuRuntimeCallPattern`
/// for the given GPU dialect op. Only the class shell is declared here; each
/// matchAndRewrite body is defined out-of-line later in this file.
/// Fix(review): the HTML extraction stripped the `<gpu::op_name>` template
/// arguments and escaped the underscores (`op\_name`); restored here.
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)                \
  class Convert##op_name##ToGpuRuntimeCallPattern                              \
      : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> {                \
  public:                                                                      \
    Convert##op_name##ToGpuRuntimeCallPattern(                                 \
        const LLVMTypeConverter &typeConverter)                                \
        : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {}     \
                                                                               \
  private:                                                                     \
    LogicalResult                                                              \
    matchAndRewrite(gpu::op_name op, OpAdaptor adaptor,                        \
                    ConversionPatternRewriter &rewriter) const override;       \
  };
505
527
528 }
529
530 void GpuToLLVMConversionPass::runOnOperation() {
532
533
534 {
536
538 1);
540 return signalPassFailure();
541 }
542
544 options.useBarePtrCallConv = hostBarePtrCallConv;
547 target.addLegalDialectLLVM::LLVMDialect();
549
550
551
553 auto iface = dyn_cast(dialect);
554 if (!iface)
555 continue;
556 iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
557 }
558
559
560
561 target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
562
563 target.addDynamicallyLegalOpgpu::LaunchFuncOp(
564 [&](gpu::LaunchFuncOp op) -> bool { return converter.isLegal(op); });
565
566
570 target);
572 kernelBarePtrCallConv,
573 kernelIntersperseSizeCallConv);
574
575 if (failed(
577 signalPassFailure();
578 }
579
583 auto function = [&] {
584 if (auto function = module.lookupSymbolLLVM::LLVMFuncOp(functionName))
585 return function;
588 }();
589 return builder.createLLVM::CallOp(loc, function, arguments);
590 }
591
592
595 return 1;
597 return 2;
598 return 3;
599 }
600
602 if (type.isF16())
603 return 0;
605 return 1;
606 llvm_unreachable("unsupported type");
607
608 }
609
610
612 if (llvm::isa(type)) {
613
614 auto elementType = cast(type).getElementType();
615 if (elementType.isBF16())
616 return 15;
617 if (elementType.isF16())
618 return 6;
619 if (elementType.isF32())
620 return 4;
621 if (elementType.isF64())
622 return 5;
623 if (elementType.isInteger(8))
624 return 7;
625 if (elementType.isInteger(16))
626 return 21;
627 if (elementType.isInteger(32))
628 return 11;
629 }
631 return 14;
632 if (type.isF16())
633 return 2;
634 if (type.isF32())
635 return 0;
636 if (type.isF64())
637 return 1;
639 return 3;
641 return 20;
643 return 10;
644
645 llvm_unreachable("unsupported element type");
646 }
647
649 return spMat.getDefiningOpgpu::Create2To4SpMatOp().getPruneFlag();
650 }
651
652
653
654
655
656
657
658
660 if (auto op = spMat.getDefiningOpgpu::Create2To4SpMatOp())
661 return true;
662 if (auto op = spMat.getDefiningOpgpu::CreateCooOp())
663 return false;
664 if (auto op = spMat.getDefiningOpgpu::CreateCooAoSOp())
665 return false;
666 if (auto op = spMat.getDefiningOpgpu::CreateCsrOp())
667 return false;
668 if (auto op = spMat.getDefiningOpgpu::CreateCscOp())
669 return false;
670 if (auto op = spMat.getDefiningOpgpu::CreateBsrOp())
671 return false;
672
674 llvm_unreachable("cannot find spmat def");
675 }
676
679 auto spmmOp = dyn_castgpu::SpMMOp(user);
680
681 if (!spmmOp)
682 continue;
684 return true;
685 }
686 return false;
687 }
688
689
692 if (!llvm::all_of(operands, [](Value value) {
694 }))
696 op, "Cannot convert if operands aren't of LLVM type.");
697 return success();
698 }
699
700 static LogicalResult
702 gpu::AsyncOpInterface op) {
703 if (op.getAsyncDependencies().size() != 1)
705 op, "Can only convert with exactly one async dependency.");
706
707 if (!op.getAsyncToken())
708 return rewriter.notifyMatchFailure(op, "Can convert only async version.");
709
710 return success();
711 }
712
// Lowers gpu.host_register to a runtime call (hostRegisterCallBuilder; the
// builder table above maps it to "mgpuMemHostRegisterMemRef"): promotes the
// memref operand(s) to runtime-call arguments and appends the element size.
// NOTE(review): gutter lines 715/720/731 are missing from this scrape -- the
// rewriter parameter, the `loc` definition used below, and (presumably) the
// final erase/replace of the op; confirm against upstream before editing.
713 LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
714 gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
716 auto *op = hostRegisterOp.getOperation();
// Bail out unless all operands have already been converted to LLVM types.
717 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
718 return failure();
719
721
722 auto memRefType = hostRegisterOp.getValue().getType();
723 auto elementType = cast(memRefType).getElementType();
724 auto elementSize = getSizeInBytes(loc, elementType, rewriter);
725
// Expand the memref descriptor into individual call arguments, then pass
// the element size as the trailing argument expected by the runtime shim.
726 auto arguments = getTypeConverter()->promoteOperands(
727 loc, op->getOperands(), adaptor.getOperands(), rewriter);
728 arguments.push_back(elementSize);
729 hostRegisterCallBuilder.create(loc, rewriter, arguments);
730
732 return success();
733 }
734
735 LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
736 gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
738 Operation *op = hostUnregisterOp.getOperation();
739 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
740 return failure();
741
743
744 auto memRefType = hostUnregisterOp.getValue().getType();
745 auto elementType = cast(memRefType).getElementType();
746 auto elementSize = getSizeInBytes(loc, elementType, rewriter);
747
748 auto arguments = getTypeConverter()->promoteOperands(
749 loc, op->getOperands(), adaptor.getOperands(), rewriter);
750 arguments.push_back(elementSize);
751 hostUnregisterCallBuilder.create(loc, rewriter, arguments);
752
754 return success();
755 }
756
757 LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
758 gpu::AllocOp allocOp, OpAdaptor adaptor,
760
761 MemRefType memRefType = allocOp.getType();
762
763 if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
764 !isConvertibleAndHasIdentityMaps(memRefType))
765 return failure();
766
767 auto loc = allocOp.getLoc();
768
769 bool isShared = allocOp.getHostShared();
770
771 if (isShared && allocOp.getAsyncToken())
773 allocOp, "Host Shared allocation cannot be done async");
775 return failure();
776
777
778
782 getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
783 shape, strides, sizeBytes);
784
785
786
787 auto nullPtr = rewriter.createmlir::LLVM::ZeroOp(loc, llvmPointerType);
788 Value stream = adaptor.getAsyncDependencies().empty()
789 ? nullPtr
790 : adaptor.getAsyncDependencies().front();
791
792 auto isHostShared = rewriter.createmlir::LLVM::ConstantOp(
794
795 Value allocatedPtr =
796 allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
797 .getResult();
798
799
800 Value alignedPtr = allocatedPtr;
801
802
803 auto memRefDescriptor = this->createMemRefDescriptor(
804 loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
805
806 if (allocOp.getAsyncToken()) {
807
808 rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
809 } else {
810 rewriter.replaceOp(allocOp, {memRefDescriptor});
811 }
812
813 return success();
814 }
815
// Lowers gpu.dealloc to a runtime free call (deallocCallBuilder; mapped to
// "mgpuMemFree" in the builder table above) on the op's async stream, then
// forwards that stream as the result token.
// NOTE(review): gutter lines 818/820/825-826 are missing from this scrape --
// they include the rewriter parameter, the single-async-dependency check,
// and the extraction of `pointer` from the memref descriptor; confirm
// against upstream before editing.
816 LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
817 gpu::DeallocOp deallocOp, OpAdaptor adaptor,
819 if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
821 return failure();
822
823 Location loc = deallocOp.getLoc();
824
// The dealloc is issued on the stream of its (single) async dependency.
827 Value stream = adaptor.getAsyncDependencies().front();
828 deallocCallBuilder.create(loc, rewriter, {pointer, stream});
829
// The stream doubles as the produced async token.
830 rewriter.replaceOp(deallocOp, {stream});
831 return success();
832 }
833
835 return isagpu::AsyncTokenType(value.getType());
836 }
837
838
839
840
841
842 LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
843 async::YieldOp yieldOp, OpAdaptor adaptor,
846 return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");
847
848 Location loc = yieldOp.getLoc();
850 llvm::SmallDenseSet streams;
851 for (auto &operand : yieldOp->getOpOperands()) {
853 continue;
854 auto idx = operand.getOperandNumber();
855 auto stream = adaptor.getOperands()[idx];
856 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
857 eventRecordCallBuilder.create(loc, rewriter, {event, stream});
858 newOperands[idx] = event;
859 streams.insert(stream);
860 }
861 for (auto stream : streams)
862 streamDestroyCallBuilder.create(loc, rewriter, {stream});
863
864 rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
865 return success();
866 }
867
868
870 assert(isaLLVM::LLVMPointerType(value.getType()));
871 if (auto defOp = value.getDefiningOpLLVM::CallOp())
872 return *defOp.getCallee() == functionName;
873 return false;
874 }
875
876
877
878
879
880 LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
881 gpu::WaitOp waitOp, OpAdaptor adaptor,
883 if (waitOp.getAsyncToken())
884 return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");
885
886 Location loc = waitOp.getLoc();
887
888 for (auto operand : adaptor.getOperands()) {
889 if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
890
891 streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
892 streamDestroyCallBuilder.create(loc, rewriter, {operand});
893 } else {
894
895
896 eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
897 eventDestroyCallBuilder.create(loc, rewriter, {operand});
898 }
899 }
900
901 rewriter.eraseOp(waitOp);
902 return success();
903 }
904
905
906
907
908
909
910 LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
911 gpu::WaitOp waitOp, OpAdaptor adaptor,
913 if (!waitOp.getAsyncToken())
914 return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");
915
916 Location loc = waitOp.getLoc();
917
920 for (auto pair :
921 llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
922 auto operand = std::get<1>(pair);
923 if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
924
925
926 auto *defOp = std::get<0>(pair).getDefiningOp();
928 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
929 eventRecordCallBuilder.create(loc, rewriter, {event, operand});
930 events.push_back(event);
931 } else {
932
933
934 events.push_back(operand);
935 }
936 }
938 auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
939 for (auto event : events)
940 streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
941 for (auto event : events)
942 eventDestroyCallBuilder.create(loc, rewriter, {event});
943 rewriter.replaceOp(waitOp, {stream});
944
945 return success();
946 }
947
948
949 LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
950 gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
952 if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
953 return failure();
954
955 if (launchOp.getAsyncDependencies().size() > 1)
957 launchOp, "Cannot convert with more than one async dependency.");
958
959
960
961
962 if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
964 launchOp, "Cannot convert non-async op with async dependencies.");
965
966 Location loc = launchOp.getLoc();
967
969 if (!adaptor.getAsyncDependencies().empty())
970 stream = adaptor.getAsyncDependencies().front();
971
972
973 else if (launchOp.getAsyncToken())
974 stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
975
976
977
978
979 OperandRange origArguments = launchOp.getKernelOperands();
981 loc, origArguments, adaptor.getKernelOperands(), rewriter,
982 kernelBarePtrCallConv);
984
985
986 if (kernelIntersperseSizeCallConv) {
987 if (origArguments.size() != llvmArguments.size()) {
988
990 launchOp,
991 "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
992 }
993
994 llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
995 for (auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
996 auto memrefTy = dyn_cast(origArg.getType());
997 if (!memrefTy) {
999 launchOp, "Operand to launch op is not a memref.");
1000 }
1001
1002 if (!memrefTy.hasStaticShape() ||
1003 !memrefTy.getElementType().isIntOrFloat()) {
1005 launchOp, "Operand to launch op is not a memref with a static "
1006 "shape and an integer or float element type.");
1007 }
1008
1009 unsigned bitwidth = memrefTy.getElementTypeBitWidth();
1010 if (bitwidth % 8 != 0) {
1012 launchOp, "Operand to launch op is not a memref with a "
1013 "byte-aligned element type.");
1014 }
1015
1016 uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) *
1017 static_cast<uint64_t>(memrefTy.getNumElements());
1018
1019 Value sizeArg = rewriter.createLLVM::ConstantOp(
1020 loc, getIndexType(), rewriter.getIndexAttr(staticSize));
1021 llvmArgumentsWithSizes.push_back(llvmArg);
1022 llvmArgumentsWithSizes.push_back(sizeArg);
1023 }
1024 }
1025
1026 std::optionalgpu::KernelDim3 clusterSize = std::nullopt;
1027 if (launchOp.hasClusterSize()) {
1028 clusterSize =
1029 gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
1030 adaptor.getClusterSizeZ()};
1031 }
1032 rewriter.creategpu::LaunchFuncOp(
1033 launchOp.getLoc(), launchOp.getKernelAttr(),
1034 gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
1035 adaptor.getGridSizeZ()},
1036 gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
1037 adaptor.getBlockSizeZ()},
1038 adaptor.getDynamicSharedMemorySize(),
1039 llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
1040 stream, clusterSize);
1041 if (launchOp.getAsyncToken())
1042 rewriter.replaceOp(launchOp, {stream});
1043 else
1044 rewriter.eraseOp(launchOp);
1045 return success();
1046 }
1047
1050 LLVM::LLVMPointerType destinationType,
1051 Value sourcePtr,
1053 auto sourceTy = castLLVM::LLVMPointerType(sourcePtr.getType());
1054 if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
1055 sourcePtr = rewriter.createLLVM::AddrSpaceCastOp(
1056 loc,
1058 destinationType.getAddressSpace()),
1059 sourcePtr);
1060 return sourcePtr;
1061 }
1062
1063 LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
1064 gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
1066 auto memRefType = cast(memcpyOp.getSrc().getType());
1067
1068 if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
1069 !isConvertibleAndHasIdentityMaps(memRefType) ||
1071 return failure();
1072
1073 auto loc = memcpyOp.getLoc();
1074
1077
1079 Value nullPtr = rewriter.createLLVM::ZeroOp(loc, elementPtrType);
1080 Value gepPtr = rewriter.createLLVM::GEPOp(
1081 loc, elementPtrType,
1082 typeConverter->convertType(memRefType.getElementType()), nullPtr,
1083 numElements);
1084 auto sizeBytes =
1085 rewriter.createLLVM::PtrToIntOp(loc, getIndexType(), gepPtr);
1086
1088 srcDesc.alignedPtr(rewriter, loc),
1089 *getTypeConverter());
1091 loc, rewriter, llvmPointerType,
1093 *getTypeConverter());
1094
1095 auto stream = adaptor.getAsyncDependencies().front();
1096 memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
1097
1098 rewriter.replaceOp(memcpyOp, {stream});
1099
1100 return success();
1101 }
1102
1103 LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
1104 gpu::MemsetOp memsetOp, OpAdaptor adaptor,
1106 auto memRefType = cast(memsetOp.getDst().getType());
1107
1108 if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
1109 !isConvertibleAndHasIdentityMaps(memRefType) ||
1111 return failure();
1112
1113 auto loc = memsetOp.getLoc();
1114
1115 Type valueType = adaptor.getValue().getType();
1117
1118 if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
1120 memsetOp, "value must be a 16 or 32 bit int or float");
1121 }
1122
1124 Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
1125
1128
1129 auto value =
1130 rewriter.createLLVM::BitcastOp(loc, bitCastType, adaptor.getValue());
1132 dstDesc.alignedPtr(rewriter, loc),
1133 *getTypeConverter());
1134
1135 auto stream = adaptor.getAsyncDependencies().front();
1137 valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
1138 builder.create(loc, rewriter, {dst, value, numElements, stream});
1139
1140 rewriter.replaceOp(memsetOp, {stream});
1141 return success();
1142 }
1143
// Lowers gpu.set_default_device to a runtime call (mapped to
// "mgpuSetDefaultDevice" in the builder table above) taking the device index.
// NOTE(review): gutter lines 1146-1147/1150 are missing from this scrape --
// the rewriter parameter, the `loc` definition used below, and (presumably)
// the replacement/erasure of `op` with `call`; confirm against upstream.
1144 LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
1145 gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
1148 auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
1149 {adaptor.getDevIndex()});
1151 return success();
1152 }
1153
1154 template
1157 return builder.createLLVM::ConstantOp(loc, llvmInt32Type,
1158 static_cast<int32_t>(tValue));
1159 }
1160
1161 template
1164 return builder.createLLVM::ConstantOp(
1165 loc, llvmFloat32Type,
1167 }
1168
1169 LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1170 gpu::CreateDnTensorOp op, OpAdaptor adaptor,
1172 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1174 return failure();
1176 auto stream = adaptor.getAsyncDependencies().front();
1179 Type dType = op.getMemref().getType().getElementType();
1181
1183 for (Value dim : adaptor.getDims()) {
1184 dims.push_back(dim);
1185 }
1186
1188
1189
1190
1191
1192
1193
1194 if (dims.size() == 2) {
1196 auto handleSz = rewriter.createLLVM::ConstantOp(
1197 loc, getIndexType(), rewriter.getIndexAttr(11032));
1198 handle = rewriter.createLLVM::AllocaOp(
1199 loc, llvmPointerType, llvmInt8Type, handleSz, 16);
1200 handle = rewriter.createLLVM::BitcastOp(loc, llvmPointerType, handle);
1201
1202 createLtDnMatCallBuilder
1203 .create(loc, rewriter,
1204 {handle, dims[0], dims[1], pTensor, dtp, stream})
1205 .getResult();
1206 } else {
1207 handle =
1208 createDnMatCallBuilder
1209 .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
1210 .getResult();
1211 }
1212 } else {
1213 assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1214 handle = createDnVecCallBuilder
1215 .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
1216 .getResult();
1217 }
1218 rewriter.replaceOp(op, {handle, stream});
1219 return success();
1220 }
1221
1222 LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1223 gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1225 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1227 return failure();
1229 auto stream = adaptor.getAsyncDependencies().front();
1230 auto definingOp = op.getDnTensor().getDefiningOpgpu::CreateDnTensorOp();
1232 for (Value dim : definingOp.getDims()) {
1233 dims.push_back(dim);
1234 }
1235 if (dims.size() == 2) {
1236
1237
1239 destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
1240 {adaptor.getDnTensor(), stream});
1241 } else {
1242 destroyDnMatCallBuilder.create(loc, rewriter,
1243 {adaptor.getDnTensor(), stream});
1244 }
1245 } else {
1246 assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1247 destroyDnVecCallBuilder.create(loc, rewriter,
1248 {adaptor.getDnTensor(), stream});
1249 }
1250 rewriter.replaceOp(op, {stream});
1251 return success();
1252 }
1253
1254 LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1255 gpu::CreateCooOp op, OpAdaptor adaptor,
1257 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1259 return failure();
1261 auto stream = adaptor.getAsyncDependencies().front();
1262 Value pRowIdxs =
1264 Value pColIdxs =
1268 Type iType =
1269 llvm::cast(op.getColIdxs().getType()).getElementType();
1270 Type dType =
1271 llvm::cast(op.getValues().getType()).getElementType();
1274 auto handle =
1275 createCooCallBuilder
1276 .create(loc, rewriter,
1277 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1278 pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1279 .getResult();
1280 rewriter.replaceOp(op, {handle, stream});
1281 return success();
1282 }
1283
1284 LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
1285 gpu::CreateCooAoSOp op, OpAdaptor adaptor,
1287 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1289 return failure();
1291 auto stream = adaptor.getAsyncDependencies().front();
1295 Type iType = llvm::cast(op.getIdxs().getType()).getElementType();
1296 Type dType =
1297 llvm::cast(op.getValues().getType()).getElementType();
1300 auto handle =
1301 createCooAoSCallBuilder
1302 .create(loc, rewriter,
1303 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1304 pIdxs, pValues, itp, dtp, stream})
1305 .getResult();
1306 rewriter.replaceOp(op, {handle, stream});
1307 return success();
1308 }
1309
// Lowers gpu.create_csr to a runtime call via createCsrCallBuilder. CSR needs
// three buffers (row positions, column indices, values) and three element-type
// codes (ptp/itp/dtp); the op is replaced by the handle and the stream.
// NOTE(review): extraction dropped lines (template args of llvm::cast,
// definitions of loc, pRowPos, pValues, ptp, itp, dtp); restore from upstream.
1310 LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1311 gpu::CreateCsrOp op, OpAdaptor adaptor,
1313 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1315 return failure();
1317 auto stream = adaptor.getAsyncDependencies().front();
1320 Value pColIdxs =
1324 Type pType =
1325 llvm::cast(op.getRowPos().getType()).getElementType();
1326 Type iType =
1327 llvm::cast(op.getColIdxs().getType()).getElementType();
1328 Type dType =
1329 llvm::cast(op.getValues().getType()).getElementType();
1333 auto handle =
1334 createCsrCallBuilder
1335 .create(loc, rewriter,
1336 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1337 pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1338 .getResult();
1339 rewriter.replaceOp(op, {handle, stream});
1340 return success();
1341 }
1342
// Lowers gpu.create_2to4_spmat (2:4 structured sparsity) to a runtime call.
// Unlike the other create ops, the handle is not returned by the runtime:
// a 44104-byte, 16-aligned stack slot is allocated here and passed in for the
// runtime to fill — presumably sized for the cuSPARSELt descriptor set (TODO
// confirm against the runtime wrappers).
// NOTE(review): extraction dropped template args on rewriter.create<...> and
// the definitions of loc, pMat, dtp; restore from upstream.
1343 LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1344 gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
1346 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1348 return failure();
1350 auto stream = adaptor.getAsyncDependencies().front();
1353 Type dType =
1354 llvm::cast(op.getMemref().getType()).getElementType();
1356
1357
// Caller-allocated storage that the runtime call populates in place.
1358 auto handleSz = rewriter.createLLVM::ConstantOp(
1359 loc, getIndexType(), rewriter.getIndexAttr(44104));
1360 Value handle = rewriter.createLLVM::AllocaOp(
1361 loc, llvmPointerType, llvmInt8Type, handleSz, 16);
1362 handle = rewriter.createLLVM::BitcastOp(loc, llvmPointerType, handle);
1363
1364 create2To4SpMatCallBuilder
1365 .create(loc, rewriter,
1366 {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
1367 .getResult();
1368 rewriter.replaceOp(op, {handle, stream});
1369 return success();
1370 }
1371
// Lowers gpu.destroy_sp_mat. Dispatches on how the handle was created: 2:4
// (cuSPARSELt) matrices use a dedicated destroy entry point, everything else
// the generic one. The op is replaced by just the stream.
// NOTE(review): extraction dropped the condition line of the if (an
// is2To4Sparsity(...) check, per the helper index below — confirm upstream).
1372 LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1373 gpu::DestroySpMatOp op, OpAdaptor adaptor,
1375 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1377 return failure();
1379 auto stream = adaptor.getAsyncDependencies().front();
1380
1382 destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
1383 {adaptor.getSpmat(), stream});
1384
1385 } else {
1386 destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
1387 }
1388 rewriter.replaceOp(op, {stream});
1389 return success();
1390 }
1391
// Lowers gpu.spmv_buffer_size to a runtime query returning the workspace size
// needed for a subsequent SpMV; the op is replaced by that size and the stream.
// NOTE(review): extraction dropped the definitions of loc, modeA and
// computeType used below; restore from upstream.
1392 LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1393 gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
1395 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1397 return failure();
1402 auto stream = adaptor.getAsyncDependencies().front();
1403 auto bufferSize = spMVBufferSizeCallBuilder
1404 .create(loc, rewriter,
1405 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1406 adaptor.getDnY(), computeType, stream})
1407 .getResult();
1408 rewriter.replaceOp(op, {bufferSize, stream});
1409 return success();
1410 }
1411
// Lowers gpu.spmv to the runtime SpMV call (sparse matrix x dense vector),
// passing the transpose mode, operand handles, compute type, and the
// externally-allocated workspace pointer; replaced by the stream only.
// NOTE(review): extraction dropped the definitions of loc, computeType and
// pBuf; restore from upstream.
1412 LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
1413 gpu::SpMVOp op, OpAdaptor adaptor,
1415 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1417 return failure();
1419 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1422 auto stream = adaptor.getAsyncDependencies().front();
1425 spMVCallBuilder.create(loc, rewriter,
1426 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1427 adaptor.getDnY(), computeType, pBuf, stream});
1428 rewriter.replaceOp(op, {stream});
1429 return success();
1430 }
1431
// Lowers gpu.spmm_buffer_size. Two paths, selected by a condition dropped from
// this listing (the 2:4-sparsity check, per the helper index below — confirm
// upstream): the cuSPARSELt path writes THREE int64 sizes through a 16-aligned
// stack buffer and replaces the op with all three plus the stream; the generic
// cuSPARSE path gets a single size straight from the call's result.
// NOTE(review): extraction also dropped template args on rewriter.create<...>,
// the definitions of loc/computeType, the pruneFlag initializer, and the first
// GEP index operand on each bufferSizePtr; restore from upstream. The inner
// `auto bufferSize` shadows the outer `Value bufferSize` by design — the outer
// one is only assigned on the else path.
1432 LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1433 gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
1435 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1437 return failure();
1439 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1440 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1441 auto stream = adaptor.getAsyncDependencies().front();
1442 Value bufferSize;
1444 auto pruneFlag =
1448 auto three = rewriter.createLLVM::ConstantOp(loc, getIndexType(),
1450 auto bufferSize = rewriter.createLLVM::AllocaOp(
1451 loc, llvmPointerType, llvmPointerType, three, 16);
1452 createCuSparseLtSpMMBufferSizeBuilder
1453 .create(loc, rewriter,
1454 {bufferSize, modeA, modeB, adaptor.getSpmatA(),
1455 adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
1456 pruneFlag, stream})
1457 .getResult();
1458
// Read back the three sizes the runtime stored into the stack buffer.
1459 auto bufferSizePtr1 = rewriter.createLLVM::GEPOp(
1460 loc, llvmPointerType, llvmPointerType, bufferSize,
1462 loc, getIndexType(), rewriter.getIndexAttr(1))});
1463 auto bufferSizePtr2 = rewriter.createLLVM::GEPOp(
1464 loc, llvmPointerType, llvmPointerType, bufferSize,
1466 loc, getIndexType(), rewriter.getIndexAttr(2))});
1467 auto bufferSize0 =
1468 rewriter.createLLVM::LoadOp(loc, llvmInt64Type, bufferSize);
1469 auto bufferSize1 =
1470 rewriter.createLLVM::LoadOp(loc, llvmInt64Type, bufferSizePtr1);
1471 auto bufferSize2 =
1472 rewriter.createLLVM::LoadOp(loc, llvmInt64Type, bufferSizePtr2);
1473
1474 rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
1475 } else {
1478 bufferSize =
1479 createSpMMBufferSizeCallBuilder
1480 .create(loc, rewriter,
1481 {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
1482 adaptor.getDnmatC(), computeType, stream})
1483 .getResult();
1484 rewriter.replaceOp(op, {bufferSize, stream});
1485 }
1486 return success();
1487 }
1488
// Lowers gpu.sddmm_buffer_size to a runtime workspace-size query for SDDMM
// (sampled dense-dense matmul into a sparse output); replaced by the size and
// the stream.
// NOTE(review): extraction dropped the definitions of loc and computeType;
// restore from upstream.
1489 LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1490 gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
1492 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1494 return failure();
1496 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1497 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1500 auto stream = adaptor.getAsyncDependencies().front();
1501 auto bufferSize =
1502 createSDDMMBufferSizeCallBuilder
1503 .create(loc, rewriter,
1504 {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
1505 adaptor.getSpmatC(), computeType, stream})
1506 .getResult();
1507 rewriter.replaceOp(op, {bufferSize, stream});
1508 return success();
1509 }
1510
// Lowers gpu.spmm. Two paths (the selecting condition was dropped from this
// listing — the 2:4/cuSPARSELt check, per the helper index below; confirm
// upstream): the cuSPARSELt path collects the three workspace buffer pointers
// from adaptor.getBuffers() and passes them individually; the generic path
// passes modes, compute type and a single workspace pointer. Replaced by the
// stream only.
// NOTE(review): extraction also dropped the definitions of loc, computeType,
// pBufs, pBuf; restore from upstream.
1511 LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1512 gpu::SpMMOp op, OpAdaptor adaptor,
1514 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1516 return failure();
1518 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1519 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1522
1523 auto stream = adaptor.getAsyncDependencies().front();
1524
1525
1528 for (Value buffer : adaptor.getBuffers()) {
1530 pBufs.push_back(pBuf);
1531 }
1532 createCuSparseLtSpMMBuilder.create(
1533 loc, rewriter,
1534 {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
1535 pBufs[0], pBufs[1], pBufs[2], stream});
1536 } else {
1539 createSpMMCallBuilder.create(loc, rewriter,
1540 {modeA, modeB, adaptor.getSpmatA(),
1541 adaptor.getDnmatB(), adaptor.getDnmatC(),
1542 computeType, pBuf, stream});
1543 }
1544 rewriter.replaceOp(op, {stream});
1545 return success();
1546 }
1547
// NOTE(review): truncated template helper — only the `template` keyword and
// closing braces survived extraction (its body and signature were dropped);
// restore from the upstream file. Cannot be documented further from here.
1548 template
1552 });
1553 }
1554
// Lowers gpu.sddmm to the runtime SDDMM call with transpose modes, dense A/B
// and sparse C handles, compute type, and the workspace pointer; replaced by
// the stream only.
// NOTE(review): extraction dropped the definitions of loc, computeType and
// pBuf; restore from upstream.
1555 LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1556 gpu::SDDMMOp op, OpAdaptor adaptor,
1558 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1560 return failure();
1564 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1565 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1566 auto stream = adaptor.getAsyncDependencies().front();
1569 createSDDMMCallBuilder.create(loc, rewriter,
1570 {modeA, modeB, adaptor.getDnmatA(),
1571 adaptor.getDnmatB(), adaptor.getSpmatC(),
1572 computeType, pBuf, stream});
1573 rewriter.replaceOp(op, {stream});
1574 return success();
1575 }
1576
// Lowers gpu.spgemm_create_descr: asks the runtime for a fresh SpGEMM
// descriptor and replaces the op with it plus the stream.
// NOTE(review): the definition of loc was dropped from this listing.
1577 LogicalResult
1578 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1579 gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
1581 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1583 return failure();
1585 auto stream = adaptor.getAsyncDependencies().front();
1586 Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
1587 .getResult();
1588 rewriter.replaceOp(op, {descr, stream});
1589 return success();
1590 }
1591
// Lowers gpu.spgemm_destroy_descr: releases the descriptor through the
// runtime and replaces the op with just the stream.
// NOTE(review): the definition of loc was dropped from this listing.
1592 LogicalResult
1593 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1594 gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
1596 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1598 return failure();
1600 auto stream = adaptor.getAsyncDependencies().front();
1601 createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
1602 {adaptor.getDesc(), stream});
1603 rewriter.replaceOp(op, {stream});
1604 return success();
1605 }
1606
// Lowers gpu.spgemm_work_estimation_or_compute. The op's `kind` attribute
// selects which runtime phase to invoke (work estimation vs. compute); both
// take identical argument lists — descriptor, modes, the three sparse-matrix
// handles, compute type, current buffer size, workspace pointer, stream — and
// return an updated required buffer size, which replaces the op together with
// the stream.
// NOTE(review): extraction dropped the definitions of loc, computeType and
// pBuf; restore from upstream.
1607 LogicalResult
1608 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
1609 gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
1611 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1613 return failure();
1617 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1618 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1619 auto stream = adaptor.getAsyncDependencies().front();
1620
1623 Value bufferSizeNew;
1624
1625 if (adaptor.getKind() ==
1626 gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
1627 bufferSizeNew =
1628 createSpGEMMWorkEstimationBuilder
1629 .create(loc, rewriter,
1630 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1631 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1632 adaptor.getBufferSz(), pBuf, stream})
1633 .getResult();
1634 } else {
1635 bufferSizeNew =
1636 createSpGEMMComputeBuilder
1637 .create(loc, rewriter,
1638 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1639 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1640 adaptor.getBufferSz(), pBuf, stream})
1641 .getResult();
1642 }
1643 rewriter.replaceOp(op, {bufferSizeNew, stream});
1644 return success();
1645 }
1646
// Lowers gpu.spgemm_copy: the final SpGEMM phase that copies the computed
// product into the C matrix; replaced by the stream only.
// NOTE(review): extraction dropped the definitions of loc and computeType.
1647 LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
1648 gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
1650 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1652 return failure();
1656 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1657 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1658 auto stream = adaptor.getAsyncDependencies().front();
1659 createSpGEMMCopyBuilder.create(loc, rewriter,
1660 {adaptor.getDesc(), modeA, modeB,
1661 adaptor.getSpmatA(), adaptor.getSpmatB(),
1662 adaptor.getSpmatC(), computeType, stream});
1663 rewriter.replaceOp(op, {stream});
1664 return success();
1665 }
1666
// Lowers gpu.spmat_get_size: allocates a 3-slot int64 stack buffer (16-aligned),
// hands the runtime pointers to its three slots (rows/cols/nnz), then loads the
// values back.
// NOTE(review): extraction dropped template args on rewriter.create<...>, the
// definition of loc, the constant `3` and the GEP index constants, and the
// final rewriter.replaceOp line (which, from the loads above, replaces the op
// with rows/cols/nnzs plus the stream — confirm upstream).
1667 LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1668 gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
1670 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1672 return failure();
1674 auto stream = adaptor.getAsyncDependencies().front();
1675
1676 auto three = rewriter.createLLVM::ConstantOp(loc, getIndexType(),
1678 auto buffer = rewriter.createLLVM::AllocaOp(
1679 loc, llvmPointerType, llvmInt64Type, three, 16);
1680
1681 auto rowsPtr = rewriter.createLLVM::GEPOp(
1682 loc, llvmPointerType, llvmPointerType, buffer,
1683 ValueRange{rewriter.createLLVM::ConstantOp(loc, getIndexType(),
1685 auto colsPtr = rewriter.createLLVM::GEPOp(
1686 loc, llvmPointerType, llvmPointerType, buffer,
1687 ValueRange{rewriter.createLLVM::ConstantOp(loc, getIndexType(),
1689 auto nnzsPtr = rewriter.createLLVM::GEPOp(
1690 loc, llvmPointerType, llvmPointerType, buffer,
1691 ValueRange{rewriter.createLLVM::ConstantOp(loc, getIndexType(),
1693 createSpMatGetSizeBuilder.create(
1694 loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
1695 auto rows = rewriter.createLLVM::LoadOp(loc, llvmInt64Type, rowsPtr);
1696 auto cols = rewriter.createLLVM::LoadOp(loc, llvmInt64Type, colsPtr);
1697 auto nnzs = rewriter.createLLVM::LoadOp(loc, llvmInt64Type, nnzsPtr);
1698
1700 return success();
1701 }
1702
// Lowers gpu.set_csr_pointers: points an existing sparse-matrix handle at new
// positions/coordinates/values buffers via the runtime; replaced by the stream.
// NOTE(review): extraction dropped the definitions of loc, pPos, pCrd, pVal;
// restore from upstream.
1703 LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
1704 gpu::SetCsrPointersOp op, OpAdaptor adaptor,
1706 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1708 return failure();
1710 auto stream = adaptor.getAsyncDependencies().front();
1717 createSetCsrPointersBuilder.create(
1718 loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
1719 rewriter.replaceOp(op, {stream});
1720 return success();
1721 }
1722
// Lowers gpu.create_csc to a runtime call via createCscCallBuilder. Mirrors
// the CSR pattern with column positions / row indices swapped in; replaced by
// the handle and the stream.
// NOTE(review): extraction dropped lines (template args of llvm::cast,
// definitions of loc, pColPos, pValues, ptp, itp, dtp); restore from upstream.
1723 LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
1724 gpu::CreateCscOp op, OpAdaptor adaptor,
1726 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1728 return failure();
1730 auto stream = adaptor.getAsyncDependencies().front();
1733 Value pRowIdxs =
1737 Type pType =
1738 llvm::cast(op.getColPos().getType()).getElementType();
1739 Type iType =
1740 llvm::cast(op.getRowIdxs().getType()).getElementType();
1741 Type dType =
1742 llvm::cast(op.getValues().getType()).getElementType();
1746 auto handle =
1747 createCscCallBuilder
1748 .create(loc, rewriter,
1749 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1750 pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
1751 .getResult();
1752 rewriter.replaceOp(op, {handle, stream});
1753 return success();
1754 }
1755
// Lowers gpu.create_bsr to a runtime call via createBsrCallBuilder. Block
// variant of CSR: takes block-row/col/nnz counts plus the row/column block
// sizes in addition to the three buffers and type codes; replaced by the
// handle and the stream.
// NOTE(review): extraction dropped lines (template args of llvm::cast,
// definitions of loc, pRowPos, pValues, ptp, itp, dtp); restore from upstream.
1756 LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1757 gpu::CreateBsrOp op, OpAdaptor adaptor,
1759 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1761 return failure();
1763 auto stream = adaptor.getAsyncDependencies().front();
1766 Value pColIdxs =
1770 Type pType =
1771 llvm::cast(op.getBRowPos().getType()).getElementType();
1772 Type iType =
1773 llvm::cast(op.getBColIdxs().getType()).getElementType();
1774 Type dType =
1775 llvm::cast(op.getValues().getType()).getElementType();
1779 auto handle =
1780 createBsrCallBuilder
1781 .create(loc, rewriter,
1782 {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
1783 adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
1784 pColIdxs, pValues, ptp, itp, dtp, stream})
1785 .getResult();
1786 rewriter.replaceOp(op, {handle, stream});
1787 return success();
1788 }
1789
// Body of populateGpuToLLVMConversionPatterns (its opening signature lines
// were dropped by extraction — per the index below it is
// `void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &,
//  RewritePatternSet &, bool, bool)`; confirm upstream). Registers opaque-
// pointer type conversions for the GPU handle types, then installs every
// GPU-to-runtime-call pattern defined above, and finally the kernel-launch
// pattern with its two calling-convention flags.
// NOTE(review): the template argument brackets on addOpaquePointerConversion
// and patterns.add were eaten by extraction; restore from upstream.
1792 bool kernelBarePtrCallConv, bool kernelIntersperseSizeCallConv) {
1793 addOpaquePointerConversiongpu::AsyncTokenType(converter);
1794 addOpaquePointerConversiongpu::SparseDnTensorHandleType(converter);
1795 addOpaquePointerConversiongpu::SparseSpMatHandleType(converter);
1796 addOpaquePointerConversiongpu::SparseSpGEMMOpHandleType(converter);
1797
1798 patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
1799 ConvertDeallocOpToGpuRuntimeCallPattern,
1800 ConvertHostRegisterOpToGpuRuntimeCallPattern,
1801 ConvertHostUnregisterOpToGpuRuntimeCallPattern,
1802 ConvertMemcpyOpToGpuRuntimeCallPattern,
1803 ConvertMemsetOpToGpuRuntimeCallPattern,
1804 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
1805 ConvertWaitAsyncOpToGpuRuntimeCallPattern,
1806 ConvertWaitOpToGpuRuntimeCallPattern,
1807 ConvertAsyncYieldToGpuRuntimeCallPattern,
1808 ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
1809 ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
1810 ConvertCreateCooOpToGpuRuntimeCallPattern,
1811 ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
1812 ConvertCreateCsrOpToGpuRuntimeCallPattern,
1813 ConvertCreateCscOpToGpuRuntimeCallPattern,
1814 ConvertCreateBsrOpToGpuRuntimeCallPattern,
1815 ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
1816 ConvertDestroySpMatOpToGpuRuntimeCallPattern,
1817 ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
1818 ConvertSpMVOpToGpuRuntimeCallPattern,
1819 ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
1820 ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
1821 ConvertSpMMOpToGpuRuntimeCallPattern,
1822 ConvertSDDMMOpToGpuRuntimeCallPattern,
1823 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
1824 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
1825 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
1826 ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
1827 ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
1828 ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
1829 patterns.add(converter, kernelBarePtrCallConv,
1830 kernelIntersperseSizeCallConv);
1831 }
1832
1833
1834
1835
1836
// External model attaching ConvertToLLVMOpInterface to gpu.module, so the
// generic convert-to-llvm driver can pick up per-module conversion attributes.
// NOTE(review): the parameter list of getConvertToLLVMConversionAttrs was
// dropped by extraction (see its out-of-line definition below).
1837 namespace {
1838 struct GPUModuleOpConvertToLLVMInterface
1839 : public ConvertToLLVMOpInterface::ExternalModel<
1840 GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
1841
1842 void getConvertToLLVMConversionAttrs(
1844 };
1845 }
1846
// Collects conversion attributes from a gpu.module: only when the module has
// exactly one target, and that target is of the expected attribute kind, is it
// appended to `attrs`; otherwise nothing is contributed.
// NOTE(review): extraction dropped the parameter line (op + the attrs output
// vector) and the template args of cast/dyn_cast; restore from upstream.
1847 void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
1849 auto module = castgpu::GPUModuleOp(op);
1850 ArrayAttr targetsAttr = module.getTargetsAttr();
1851
// Ambiguous multi-target modules contribute no conversion attributes.
1852 if (!targetsAttr || targetsAttr.size() != 1)
1853 return;
1854 if (auto patternAttr = dyn_cast(targetsAttr[0]))
1855 attrs.push_back(patternAttr);
1856 }
1857
// NOTE(review): tail of registerConvertGpuToLLVMInterface — its signature and
// the addExtension lambda header were dropped by extraction. What remains
// attaches the GPUModuleOpConvertToLLVMInterface external model to
// gpu.module inside a registry extension callback; restore from upstream.
1860 gpu::GPUModuleOp::attachInterface(*ctx);
1861 });
1862 }
static void addOpaquePointerConversion(LLVMTypeConverter &converter)
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue)
static int32_t getCuSparseDataTypeFrom(Type type)
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter)
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue)
static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat)
static bool isGpuAsyncTokenType(Value value)
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)
Generic rewriting rule for operation on sparse matrices.
static int32_t getCuSparseLtDataTypeFrom(Type type)
static bool isDefinedByCallTo(Value value, StringRef functionName)
static Value bitAndAddrspaceCast(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMPointerType destinationType, Value sourcePtr, const LLVMTypeConverter &typeConverter)
static bool isSpMMCusparseLtOp(Value op)
static int32_t getCuSparseIndexTypeFrom(Type type)
static bool is2To4Sparsity(Value spMat)
static LogicalResult isAsyncWithOneDependency(ConversionPatternRewriter &rewriter, gpu::AsyncOpInterface op)
static MLIRContext * getContext(OpFoldResult val)
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
IntegerAttr getIndexAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
FloatAttr getF32FloatAttr(float value)
IntegerAttr getI8IntegerAttr(int8_t value)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the entire group.
Conversion from types to the LLVM IR dialect.
MLIRContext & getContext() const
Returns the MLIR context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
This class helps build Operations.
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still inside the block.
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Block * getBlock() const
Returns the current block of the builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
Location getLoc()
The source location the operation was defined or derived from.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
operand_range getOperands()
Returns an iterator on the underlying Value's.
ParentT getParentOfType()
Find the first parent operation of the given type, or nullptr if there is no ancestor operation.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void registerConvertGpuToLLVMInterface(DialectRegistry ®istry)
Registers the ConvertToLLVMOpInterface interface on the gpu::GPUModuleOP operation.
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool kernelBarePtrCallConv=false, bool kernelIntersperseSizeCallConv=false)
Collect a set of patterns to convert from the GPU dialect to LLVM and populate converter for gpu type...
const FrozenRewritePatternSet & patterns
void registerConvertToLLVMDependentDialectLoading(DialectRegistry ®istry)
Register the extension that will load dependent dialects for LLVM conversion.
void populateAsyncStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for async structural type conversions.
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
LLVM::LLVMFunctionType functionType
LLVM::CallOp create(Location loc, OpBuilder &builder, ArrayRef< Value > arguments) const
Utility class for the GPU dialect to represent triples of Values accessible through ....