MLIR: lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/FormatVariadic.h"
28
29 #define DEBUG_TYPE "spirv-to-llvm-pattern"
30
31 using namespace mlir;
32
33
34
35
36
37
40 return true;
41 if (auto vecType = dyn_cast(type))
42 return vecType.getElementType().isSignedInteger();
43 return false;
44 }
45
46
49 return true;
50 if (auto vecType = dyn_cast(type))
51 return vecType.getElementType().isUnsignedInteger();
52 return false;
53 }
54
55
56
58 if (auto intType = dyn_cast(type))
59 return intType.getWidth();
60 if (auto vecType = dyn_cast(type))
61 if (auto intType = dyn_cast(vecType.getElementType()))
62 return intType.getWidth();
63 return std::nullopt;
64 }
65
66
68 assert((type.isIntOrFloat() || isa(type)) &&
69 "bitwidth is not supported for this type");
72 auto vecType = dyn_cast(type);
73 auto elementType = vecType.getElementType();
74 assert(elementType.isIntOrFloat() &&
75 "only integers and floats have a bitwidth");
76 return elementType.getIntOrFloatBitWidth();
77 }
78
79
81 if (auto vecTy = dyn_cast(type))
82 type = vecTy.getElementType();
83 return cast(type).getWidth();
84 }
85
86
88 if (auto vecType = dyn_cast(type)) {
89 auto integerType = cast(vecType.getElementType());
91 }
92 auto integerType = cast(type);
94 }
95
96
99 if (isa(srcType)) {
100 return rewriter.createLLVM::ConstantOp(
101 loc, dstType,
104 }
105 return rewriter.createLLVM::ConstantOp(
107 }
108
109
112 if (auto vecType = dyn_cast(srcType)) {
113 auto floatType = cast(vecType.getElementType());
114 return rewriter.createLLVM::ConstantOp(
115 loc, dstType,
118 }
119 auto floatType = cast(srcType);
120 return rewriter.createLLVM::ConstantOp(
121 loc, dstType, rewriter.getFloatAttr(floatType, value));
122 }
123
124
125
126
127
128
129
131 Type llvmType,
133 auto srcType = value.getType();
138
139 if (valueBitWidth < targetBitWidth)
140 return rewriter.createLLVM::ZExtOp(loc, llvmType, value);
141
142
143
144
145 if (valueBitWidth > targetBitWidth)
146 return rewriter.createLLVM::TruncOp(loc, llvmType, value);
147 return value;
148 }
149
150
155 auto llvmVectorType = typeConverter.convertType(vectorType);
157 Value broadcasted = rewriter.createLLVM::PoisonOp(loc, llvmVectorType);
158 for (unsigned i = 0; i < numElements; ++i) {
159 auto index = rewriter.createLLVM::ConstantOp(
161 broadcasted = rewriter.createLLVM::InsertElementOp(
162 loc, llvmVectorType, broadcasted, toBroadcast, index);
163 }
164 return broadcasted;
165 }
166
167
171 if (auto vectorType = dyn_cast(srcType)) {
172 unsigned numElements = vectorType.getNumElements();
173 return broadcast(loc, value, numElements, typeConverter, rewriter);
174 }
175 return value;
176 }
177
178
179
180
181
182
183
184
185
186
187
191 Value broadcasted =
194 }
195
196
197
201 return nullptr;
202
205 return nullptr;
206 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
207 false);
208 }
209
210
215 return nullptr;
216 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
217 true);
218 }
219
220
222 unsigned value) {
223 return rewriter.createLLVM::ConstantOp(
226 }
227
228
232 unsigned alignment, bool isVolatile,
233 bool isNonTemporal) {
234 if (auto loadOp = dyn_castspirv::LoadOp(op)) {
235 auto dstType = typeConverter.convertType(loadOp.getType());
236 if (!dstType)
239 loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
240 isVolatile, isNonTemporal);
241 return success();
242 }
243 auto storeOp = castspirv::StoreOp(op);
244 spirv::StoreOpAdaptor adaptor(operands);
245 rewriter.replaceOpWithNewOpLLVM::StoreOp(storeOp, adaptor.getValue(),
246 adaptor.getPtr(), alignment,
247 isVolatile, isNonTemporal);
248 return success();
249 }
250
251
252
253
254
255
256
257
262 auto sizeInBytes = castspirv::SPIRVType(elementType).getSizeInBytes();
263 if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
264 return std::nullopt;
265
266 auto llvmElementType = converter.convertType(elementType);
269 }
270
271
272
275 spirv::ClientAPI clientAPI) {
276 unsigned addressSpace =
279 }
280
281
282
283
287 return std::nullopt;
290 }
291
292
293
298 if (!memberDecorations.empty())
299 return nullptr;
303 }
304
305
306
307
308
309 namespace {
310
312 public:
314
315 LogicalResult
316 matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
318 auto dstType =
319 getTypeConverter()->convertType(op.getComponentPtr().getType());
320 if (!dstType)
322
323 auto indices = llvm::to_vector<4>(adaptor.getIndices());
324 Type indexType = op.getIndices().front().getType();
325 auto llvmIndexType = getTypeConverter()->convertType(indexType);
326 if (!llvmIndexType)
328 Value zero = rewriter.createLLVM::ConstantOp(
329 op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
330 indices.insert(indices.begin(), zero);
331
332 auto elementType = getTypeConverter()->convertType(
333 castspirv::PointerType(op.getBasePtr().getType()).getPointeeType());
334 if (!elementType)
337 adaptor.getBasePtr(), indices);
338 return success();
339 }
340 };
341
343 public:
345
346 LogicalResult
347 matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
349 auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
350 if (!dstType)
353 op.getVariable());
354 return success();
355 }
356 };
357
358 class BitFieldInsertPattern
360 public:
362
363 LogicalResult
364 matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
366 auto srcType = op.getType();
367 auto dstType = getTypeConverter()->convertType(srcType);
368 if (!dstType)
371
372
374 *getTypeConverter(), rewriter);
376 *getTypeConverter(), rewriter);
377
378
380 Value maskShiftedByCount =
381 rewriter.createLLVM::ShlOp(loc, dstType, minusOne, count);
382 Value negated = rewriter.createLLVM::XOrOp(loc, dstType,
383 maskShiftedByCount, minusOne);
384 Value maskShiftedByCountAndOffset =
385 rewriter.createLLVM::ShlOp(loc, dstType, negated, offset);
386 Value mask = rewriter.createLLVM::XOrOp(
387 loc, dstType, maskShiftedByCountAndOffset, minusOne);
388
389
390
391 Value baseAndMask =
392 rewriter.createLLVM::AndOp(loc, dstType, op.getBase(), mask);
393 Value insertShiftedByOffset =
394 rewriter.createLLVM::ShlOp(loc, dstType, op.getInsert(), offset);
396 insertShiftedByOffset);
397 return success();
398 }
399 };
400
401
402 class ConstantScalarAndVectorPattern
404 public:
406
407 LogicalResult
408 matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
410 auto srcType = constOp.getType();
411 if (!isa(srcType) && !srcType.isIntOrFloat())
412 return failure();
413
414 auto dstType = getTypeConverter()->convertType(srcType);
415 if (!dstType)
416 return rewriter.notifyMatchFailure(constOp, "type conversion failed");
417
418
419
420
421
422
426
427 if (isa(srcType)) {
428 auto dstElementsAttr = cast(constOp.getValue());
430 constOp, dstType,
431 dstElementsAttr.mapValues(
432 signlessType, [&](const APInt &value) { return value; }));
433 return success();
434 }
435 auto srcAttr = cast(constOp.getValue());
436 auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
437 rewriter.replaceOpWithNewOpLLVM::ConstantOp(constOp, dstType, dstAttr);
438 return success();
439 }
441 constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
442 return success();
443 }
444 };
445
446 class BitFieldSExtractPattern
448 public:
450
451 LogicalResult
452 matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
454 auto srcType = op.getType();
455 auto dstType = getTypeConverter()->convertType(srcType);
456 if (!dstType)
459
460
462 *getTypeConverter(), rewriter);
464 *getTypeConverter(), rewriter);
465
466
467 IntegerType integerType;
468 if (auto vecType = dyn_cast(srcType))
469 integerType = cast(vecType.getElementType());
470 else
471 integerType = cast(srcType);
472
475 isa(srcType)
476 ? rewriter.createLLVM::ConstantOp(
477 loc, dstType,
479 : rewriter.createLLVM::ConstantOp(loc, dstType, baseSize);
480
481
482
483 Value countPlusOffset =
484 rewriter.createLLVM::AddOp(loc, dstType, count, offset);
485 Value amountToShiftLeft =
486 rewriter.createLLVM::SubOp(loc, dstType, size, countPlusOffset);
487 Value baseShiftedLeft = rewriter.createLLVM::ShlOp(
488 loc, dstType, op.getBase(), amountToShiftLeft);
489
490
491 Value amountToShiftRight =
492 rewriter.createLLVM::AddOp(loc, dstType, offset, amountToShiftLeft);
493 rewriter.replaceOpWithNewOpLLVM::AShrOp(op, dstType, baseShiftedLeft,
494 amountToShiftRight);
495 return success();
496 }
497 };
498
499 class BitFieldUExtractPattern
501 public:
503
504 LogicalResult
505 matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
507 auto srcType = op.getType();
508 auto dstType = getTypeConverter()->convertType(srcType);
509 if (!dstType)
512
513
515 *getTypeConverter(), rewriter);
517 *getTypeConverter(), rewriter);
518
519
521 Value maskShiftedByCount =
522 rewriter.createLLVM::ShlOp(loc, dstType, minusOne, count);
523 Value mask = rewriter.createLLVM::XOrOp(loc, dstType, maskShiftedByCount,
524 minusOne);
525
526
527 Value shiftedBase =
528 rewriter.createLLVM::LShrOp(loc, dstType, op.getBase(), offset);
529 rewriter.replaceOpWithNewOpLLVM::AndOp(op, dstType, shiftedBase, mask);
530 return success();
531 }
532 };
533
535 public:
537
538 LogicalResult
539 matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
541 rewriter.replaceOpWithNewOpLLVM::BrOp(branchOp, adaptor.getOperands(),
542 branchOp.getTarget());
543 return success();
544 }
545 };
546
547 class BranchConditionalConversionPattern
549 public:
552
553 LogicalResult
554 matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
556
558 if (auto weights = op.getBranchWeights()) {
560 for (auto weight : weights->getAsRange())
561 weightValues.push_back(weight.getInt());
563 }
564
566 op, op.getCondition(), op.getTrueBlockArguments(),
567 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
568 op.getFalseBlock());
569 return success();
570 }
571 };
572
573
574
575
576 class CompositeExtractPattern
578 public:
580
581 LogicalResult
582 matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
584 auto dstType = this->getTypeConverter()->convertType(op.getType());
585 if (!dstType)
587
588 Type containerType = op.getComposite().getType();
589 if (isa(containerType)) {
591 IntegerAttr value = cast(op.getIndices()[0]);
594 op, dstType, adaptor.getComposite(), index);
595 return success();
596 }
597
599 op, adaptor.getComposite(),
601 return success();
602 }
603 };
604
605
606
607
608 class CompositeInsertPattern
610 public:
612
613 LogicalResult
614 matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
616 auto dstType = this->getTypeConverter()->convertType(op.getType());
617 if (!dstType)
619
620 Type containerType = op.getComposite().getType();
621 if (isa(containerType)) {
623 IntegerAttr value = cast(op.getIndices()[0]);
626 op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
627 return success();
628 }
629
631 op, adaptor.getComposite(), adaptor.getObject(),
633 return success();
634 }
635 };
636
637
638
639 template <typename SPIRVOp, typename LLVMOp>
641 public:
643
644 LogicalResult
645 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
647 auto dstType = this->getTypeConverter()->convertType(op.getType());
648 if (!dstType)
650 rewriter.template replaceOpWithNewOp(
651 op, dstType, adaptor.getOperands(), op->getAttrs());
652 return success();
653 }
654 };
655
656
657
658 class ExecutionModePattern
660 public:
662
663 LogicalResult
664 matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
666
667
668
669 ModuleOp module = op->getParentOfType();
670 spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
671 std::string moduleName;
672 if (module.getName().has_value())
673 moduleName = "_" + module.getName()->str();
674 else
675 moduleName = "";
676 std::string executionModeInfoName = llvm::formatv(
677 "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
678 static_cast<uint32_t>(executionModeAttr.getValue()));
679
683
684
685
686
687
688
691 fields.push_back(llvmI32Type);
692 ArrayAttr values = op.getValues();
693 if (!values.empty()) {
695 fields.push_back(arrayType);
696 }
697 auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
698
699
700 auto global = rewriter.createLLVM::GlobalOp(
701 UnknownLoc::get(context), structType, true,
702 LLVM::Linkage::External, executionModeInfoName, Attribute(),
703 0);
704 Location loc = global.getLoc();
705 Region ®ion = global.getInitializerRegion();
707
708
710 Value structValue = rewriter.createLLVM::PoisonOp(loc, structType);
711 Value executionMode = rewriter.createLLVM::ConstantOp(
712 loc, llvmI32Type,
714 static_cast<uint32_t>(executionModeAttr.getValue())));
715 structValue = rewriter.createLLVM::InsertValueOp(loc, structValue,
716 executionMode, 0);
717
718
719 for (unsigned i = 0, e = values.size(); i < e; ++i) {
720 auto attr = values.getValue()[i];
721 Value entry = rewriter.createLLVM::ConstantOp(loc, llvmI32Type, attr);
722 structValue = rewriter.createLLVM::InsertValueOp(
724 }
727 return success();
728 }
729 };
730
731
732
733
734
735 class GlobalVariablePattern
737 public:
738 template <typename... Args>
739 GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
741 std::forward(args)...),
742 clientAPI(clientAPI) {}
743
744 LogicalResult
745 matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
747
748
749 if (op.getInitializer())
750 return failure();
751
752 auto srcType = castspirv::PointerType(op.getType());
753 auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
754 if (!dstType)
756
757
758
759
760 auto storageClass = srcType.getStorageClass();
761 switch (storageClass) {
762 case spirv::StorageClass::Input:
763 case spirv::StorageClass::Private:
764 case spirv::StorageClass::Output:
765 case spirv::StorageClass::StorageBuffer:
766 case spirv::StorageClass::UniformConstant:
767 break;
768 default:
769 return failure();
770 }
771
772
773
774
775 bool isConstant = (storageClass == spirv::StorageClass::Input) ||
776 (storageClass == spirv::StorageClass::UniformConstant);
777
778
779
780
781
782 auto linkage = storageClass == spirv::StorageClass::Private
783 ? LLVM::Linkage::Private
784 : LLVM::Linkage::External;
786 op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
788
789
790 if (op.getLocationAttr())
791 newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
792
793 return success();
794 }
795
796 private:
797 spirv::ClientAPI clientAPI;
798 };
799
800
801
802 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
804 public:
806
807 LogicalResult
808 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
810
811 Type fromType = op.getOperand().getType();
812 Type toType = op.getType();
813
814 auto dstType = this->getTypeConverter()->convertType(toType);
815 if (!dstType)
817
819 rewriter.template replaceOpWithNewOp(op, dstType,
820 adaptor.getOperands());
821 return success();
822 }
824 rewriter.template replaceOpWithNewOp(op, dstType,
825 adaptor.getOperands());
826 return success();
827 }
828 return failure();
829 }
830 };
831
832 class FunctionCallPattern
834 public:
836
837 LogicalResult
838 matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
840 if (callOp.getNumResults() == 0) {
842 callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
843 newOp.getProperties().operandSegmentSizes = {
844 static_cast<int32_t>(adaptor.getOperands().size()), 0};
846 return success();
847 }
848
849
850 auto dstType = getTypeConverter()->convertType(callOp.getType(0));
851 if (!dstType)
852 return rewriter.notifyMatchFailure(callOp, "type conversion failed");
854 callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
855 newOp.getProperties().operandSegmentSizes = {
856 static_cast<int32_t>(adaptor.getOperands().size()), 0};
858 return success();
859 }
860 };
861
862
863 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
865 public:
867
868 LogicalResult
869 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
871
872 auto dstType = this->getTypeConverter()->convertType(op.getType());
873 if (!dstType)
875
876 rewriter.template replaceOpWithNewOpLLVM::FCmpOp(
877 op, dstType, predicate, op.getOperand1(), op.getOperand2());
878 return success();
879 }
880 };
881
882
883 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
885 public:
887
888 LogicalResult
889 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
891
892 auto dstType = this->getTypeConverter()->convertType(op.getType());
893 if (!dstType)
895
896 rewriter.template replaceOpWithNewOpLLVM::ICmpOp(
897 op, dstType, predicate, op.getOperand1(), op.getOperand2());
898 return success();
899 }
900 };
901
902 class InverseSqrtPattern
904 public:
906
907 LogicalResult
908 matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
910 auto srcType = op.getType();
911 auto dstType = getTypeConverter()->convertType(srcType);
912 if (!dstType)
914
917 Value sqrt = rewriter.createLLVM::SqrtOp(loc, dstType, op.getOperand());
919 return success();
920 }
921 };
922
923
924 template
926 public:
928
929 LogicalResult
930 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
932 if (!op.getMemoryAccess()) {
934 *this->getTypeConverter(), 0,
935 false,
936 false);
937 }
938 auto memoryAccess = *op.getMemoryAccess();
939 switch (memoryAccess) {
940 case spirv::MemoryAccess::Aligned:
942 case spirv::MemoryAccess::Nontemporal:
943 case spirv::MemoryAccess::Volatile: {
944 unsigned alignment =
945 memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
946 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
947 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
949 *this->getTypeConverter(), alignment,
950 isVolatile, isNonTemporal);
951 }
952 default:
953
954 return failure();
955 }
956 }
957 };
958
959
960 template
962 public:
964
965 LogicalResult
966 matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
968 auto srcType = notOp.getType();
969 auto dstType = this->getTypeConverter()->convertType(srcType);
970 if (!dstType)
972
973 Location loc = notOp.getLoc();
975 auto mask =
976 isa(srcType)
977 ? rewriter.createLLVM::ConstantOp(
978 loc, dstType,
980 : rewriter.createLLVM::ConstantOp(loc, dstType, minusOne);
981 rewriter.template replaceOpWithNewOpLLVM::XOrOp(notOp, dstType,
982 notOp.getOperand(), mask);
983 return success();
984 }
985 };
986
987
988 template
990 public:
992
993 LogicalResult
994 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
997 return success();
998 }
999 };
1000
1002 public:
1004
1005 LogicalResult
1006 matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1010 return success();
1011 }
1012 };
1013
1015 public:
1017
1018 LogicalResult
1019 matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1022 adaptor.getOperands());
1023 return success();
1024 }
1025 };
1026
1028 StringRef name,
1030 Type resultType,
1031 bool convergent = true) {
1032 auto func = dyn_cast_or_nullLLVM::LLVMFuncOp(
1034 if (func)
1035 return func;
1036
1038 func = b.createLLVM::LLVMFuncOp(
1039 symbolTable->getLoc(), name,
1041 func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1042 func.setConvergent(convergent);
1043 func.setNoUnwind(true);
1044 func.setWillReturn(true);
1045 return func;
1046 }
1047
1049 LLVM::LLVMFuncOp func,
1051 auto call = builder.createLLVM::CallOp(loc, func, args);
1052 call.setCConv(func.getCConv());
1053 call.setConvergentAttr(func.getConvergentAttr());
1054 call.setNoUnwindAttr(func.getNoUnwindAttr());
1055 call.setWillReturnAttr(func.getWillReturnAttr());
1056 return call;
1057 }
1058
1059 template
1061 public:
1063
1065
1066 static constexpr StringRef getFuncName();
1067
1068 LogicalResult
1069 matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
1071 constexpr StringRef funcName = getFuncName();
1073 controlBarrierOp->template getParentWithTraitOpTrait::SymbolTable();
1074
1076
1077 Type voidTy = rewriter.getTypeLLVM::LLVMVoidType();
1078 LLVM::LLVMFuncOp func =
1080
1081 Location loc = controlBarrierOp->getLoc();
1082 Value execution = rewriter.createLLVM::ConstantOp(
1083 loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
1084 Value memory = rewriter.createLLVM::ConstantOp(
1085 loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));
1086 Value semantics = rewriter.createLLVM::ConstantOp(
1087 loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));
1088
1090 {execution, memory, semantics});
1091
1092 rewriter.replaceOp(controlBarrierOp, call);
1093 return success();
1094 }
1095 };
1096
1097 namespace {
1098
// Returns the short string used as the type suffix when building a mangled
// builtin name; the one visible caller appends the result to the name obtained
// from getGroupFuncName().
//
// NOTE(review): the line that opens the dispatch chain (listing line 1100) and
// the type arguments of the `.Case` calls are not visible in this listing, so
// the concrete types matched by the first three cases cannot be confirmed
// from here.
//
// `isSigned` is only consulted by the integer case.
1099 StringRef getTypeMangling(Type type, bool isSigned) {
// The first three cases each return a fixed code: "Dh", "f" and "d".
1101 .Case([](auto) { return "Dh"; })
1102 .Case([](auto) { return "f"; })
1103 .Case([](auto) { return "d"; })
// Integer case: the code depends on the bit width and, for widths 8..64, on
// `isSigned` (signed code first, unsigned code second). Width 1 always maps
// to "b".
1104 .Case([isSigned](IntegerType intTy) {
1105 switch (intTy.getWidth()) {
1106 case 1:
1107 return "b";
1108 case 8:
1109 return (isSigned) ? "a" : "c";
1110 case 16:
1111 return (isSigned) ? "s" : "t";
1112 case 32:
1113 return (isSigned) ? "i" : "j";
1114 case 64:
1115 return (isSigned) ? "l" : "m";
1116 default:
// Any other integer width is treated as a programming error.
1117 llvm_unreachable("Unsupported integer width");
1118 }
1119 })
// Any type not handled by a case above is treated as a programming error; the
// trailing `return ""` only gives the lambda a return value.
1120 .Default([](auto) {
1121 llvm_unreachable("No mangling defined");
1122 return "";
1123 });
1124 }
1125
// Primary declaration: maps a SPIR-V group operation type to the fixed prefix
// of the function name used for it. It is declared only; each supported op
// provides an explicit specialization below.
//
// Every specialization returns a literal of the form
// `_Z<N>__spirv_<OpName>ii`, where <N> is the character count of the
// `__spirv_<OpName>` identifier (17 for all the `Group*` ops in this group).
// NOTE(review): the trailing `ii` presumably encodes two `int` parameters in
// Itanium-style mangling — confirm against the call site.
1126 template
1127 constexpr StringLiteral getGroupFuncName();
1128
// Group (uniform) integer add.
1129 template <>
1130 constexpr StringLiteral getGroupFuncNamespirv::GroupIAddOp() {
1131 return "_Z17__spirv_GroupIAddii";
1132 }
// Group (uniform) floating-point add.
1133 template <>
1134 constexpr StringLiteral getGroupFuncNamespirv::GroupFAddOp() {
1135 return "_Z17__spirv_GroupFAddii";
1136 }
// Group (uniform) signed minimum.
1137 template <>
1138 constexpr StringLiteral getGroupFuncNamespirv::GroupSMinOp() {
1139 return "_Z17__spirv_GroupSMinii";
1140 }
// Group (uniform) unsigned minimum.
1141 template <>
1142 constexpr StringLiteral getGroupFuncNamespirv::GroupUMinOp() {
1143 return "_Z17__spirv_GroupUMinii";
1144 }
// Group (uniform) floating-point minimum.
1145 template <>
1146 constexpr StringLiteral getGroupFuncNamespirv::GroupFMinOp() {
1147 return "_Z17__spirv_GroupFMinii";
1148 }
// Group (uniform) signed maximum.
1149 template <>
1150 constexpr StringLiteral getGroupFuncNamespirv::GroupSMaxOp() {
1151 return "_Z17__spirv_GroupSMaxii";
1152 }
// Group (uniform) unsigned maximum.
1153 template <>
1154 constexpr StringLiteral getGroupFuncNamespirv::GroupUMaxOp() {
1155 return "_Z17__spirv_GroupUMaxii";
1156 }
// Group (uniform) floating-point maximum.
1157 template <>
1158 constexpr StringLiteral getGroupFuncNamespirv::GroupFMaxOp() {
1159 return "_Z17__spirv_GroupFMaxii";
1160 }
// Name prefixes for the GroupNonUniform arithmetic and min/max operations.
// Each literal has the form `_Z27__spirv_GroupNonUniform<Op>ii`; 27 is the
// character count of the `__spirv_GroupNonUniform<Op>` identifier, which is
// the same for every op in this group because each <Op> is four characters.
// Non-uniform integer add.
1161 template <>
1162 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformIAddOp() {
1163 return "_Z27__spirv_GroupNonUniformIAddii";
1164 }
// Non-uniform floating-point add.
1165 template <>
1166 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformFAddOp() {
1167 return "_Z27__spirv_GroupNonUniformFAddii";
1168 }
// Non-uniform integer multiply.
1169 template <>
1170 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformIMulOp() {
1171 return "_Z27__spirv_GroupNonUniformIMulii";
1172 }
// Non-uniform floating-point multiply.
1173 template <>
1174 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformFMulOp() {
1175 return "_Z27__spirv_GroupNonUniformFMulii";
1176 }
// Non-uniform signed minimum.
1177 template <>
1178 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformSMinOp() {
1179 return "_Z27__spirv_GroupNonUniformSMinii";
1180 }
// Non-uniform unsigned minimum.
1181 template <>
1182 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformUMinOp() {
1183 return "_Z27__spirv_GroupNonUniformUMinii";
1184 }
// Non-uniform floating-point minimum.
1185 template <>
1186 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformFMinOp() {
1187 return "_Z27__spirv_GroupNonUniformFMinii";
1188 }
// Non-uniform signed maximum.
1189 template <>
1190 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformSMaxOp() {
1191 return "_Z27__spirv_GroupNonUniformSMaxii";
1192 }
// Non-uniform unsigned maximum.
1193 template <>
1194 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformUMaxOp() {
1195 return "_Z27__spirv_GroupNonUniformUMaxii";
1196 }
// Non-uniform floating-point maximum.
1197 template <>
1198 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformFMaxOp() {
1199 return "_Z27__spirv_GroupNonUniformFMaxii";
1200 }
// Name prefixes for the GroupNonUniform bitwise and logical operations.
// The length field differs between entries because the identifier length
// does: the `...And` / `...Xor` names are 33 characters, the `...Or` names 32.
// Keep the number in sync with the identifier if a name is ever edited.
// Non-uniform bitwise AND.
1201 template <>
1202 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformBitwiseAndOp() {
1203 return "_Z33__spirv_GroupNonUniformBitwiseAndii";
1204 }
// Non-uniform bitwise OR.
1205 template <>
1206 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformBitwiseOrOp() {
1207 return "_Z32__spirv_GroupNonUniformBitwiseOrii";
1208 }
// Non-uniform bitwise XOR.
1209 template <>
1210 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformBitwiseXorOp() {
1211 return "_Z33__spirv_GroupNonUniformBitwiseXorii";
1212 }
// Non-uniform logical AND.
1213 template <>
1214 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformLogicalAndOp() {
1215 return "_Z33__spirv_GroupNonUniformLogicalAndii";
1216 }
// Non-uniform logical OR.
1217 template <>
1218 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformLogicalOrOp() {
1219 return "_Z32__spirv_GroupNonUniformLogicalOrii";
1220 }
// Non-uniform logical XOR.
1221 template <>
1222 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformLogicalXorOp() {
1223 return "_Z33__spirv_GroupNonUniformLogicalXorii";
1224 }
1225 }
1226
1227 template <typename ReduceOp, bool Signed = false, bool NonUniform = false>
1229 public:
1231
1232 LogicalResult
1233 matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor,
1235
1236 Type retTy = op.getResult().getType();
1238 return failure();
1239 }
1240 SmallString<36> funcName = getGroupFuncName();
1241 funcName += getTypeMangling(retTy, false);
1242
1245 if constexpr (NonUniform) {
1246 if (adaptor.getClusterSize()) {
1247 funcName += "j";
1248 paramTypes.push_back(i32Ty);
1249 }
1250 }
1251
1253 op->template getParentWithTraitOpTrait::SymbolTable();
1254
1255 LLVM::LLVMFuncOp func =
1257
1259 Value scope = rewriter.createLLVM::ConstantOp(
1260 loc, i32Ty, static_cast<int32_t>(adaptor.getExecutionScope()));
1261 Value groupOp = rewriter.createLLVM::ConstantOp(
1262 loc, i32Ty, static_cast<int32_t>(adaptor.getGroupOperation()));
1264 operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
1265
1268 return success();
1269 }
1270 };
1271
// Explicit specializations of ControlBarrierPattern<...>::getFuncName(), one
// per supported barrier op. Each returns the complete name of the function
// the pattern calls: `_Z<N>__spirv_<Name>iii`, where <N> is the character
// count of the `__spirv_<Name>` identifier.
// NOTE(review): the `iii` suffix presumably encodes three `int` parameters,
// which would match the three i32 constants (execution scope, memory scope,
// memory semantics) the pattern passes to the call — confirm.
// Plain control barrier.
1272 template <>
1273 constexpr StringRef
1274 ControlBarrierPatternspirv::ControlBarrierOp::getFuncName() {
1275 return "_Z22__spirv_ControlBarrieriii";
1276 }
1277
// INTEL split barrier: arrive half.
1278 template <>
1279 constexpr StringRef
1280 ControlBarrierPatternspirv::INTELControlBarrierArriveOp::getFuncName() {
1281 return "_Z33__spirv_ControlBarrierArriveINTELiii";
1282 }
1283
// INTEL split barrier: wait half.
1284 template <>
1285 constexpr StringRef
1286 ControlBarrierPatternspirv::INTELControlBarrierWaitOp::getFuncName() {
1287 return "_Z31__spirv_ControlBarrierWaitINTELiii";
1288 }
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1339 public:
1341
1342 LogicalResult
1343 matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1345
1347 return failure();
1348
1349
1350 if (loopOp.getBody().empty()) {
1351 rewriter.eraseOp(loopOp);
1352 return success();
1353 }
1354
1355 Location loc = loopOp.getLoc();
1356
1357
1358
1361 Block *endBlock = rewriter.splitBlock(currentBlock, position);
1362
1363
1364
1365 Block *entryBlock = loopOp.getEntryBlock();
1366 assert(entryBlock->getOperations().size() == 1);
1367 auto brOp = dyn_castspirv::BranchOp(entryBlock->getOperations().front());
1368 if (!brOp)
1369 return failure();
1370 Block *headerBlock = loopOp.getHeaderBlock();
1372 rewriter.createLLVM::BrOp(loc, brOp.getBlockArguments(), headerBlock);
1374
1375
1376 Block *mergeBlock = loopOp.getMergeBlock();
1380 rewriter.createLLVM::BrOp(loc, terminatorOperands, endBlock);
1381
1384 return success();
1385 }
1386 };
1387
1388
1389
1390
1392 public:
1394
1395 LogicalResult
1396 matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1398
1399
1400
1402 return failure();
1403
1404
1405
1406
1407
1408 if (op.getBody().getBlocks().size() <= 2) {
1410 return success();
1411 }
1412
1414
1415
1416
1420 auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1421
1422
1423
1424
1425
1426 auto *headerBlock = op.getHeaderBlock();
1427 assert(headerBlock->getOperations().size() == 1);
1428 auto condBrOp = dyn_castspirv::BranchConditionalOp(
1430 if (!condBrOp)
1431 return failure();
1433
1434
1435 auto *mergeBlock = op.getMergeBlock();
1439 rewriter.createLLVM::BrOp(loc, terminatorOperands, continueBlock);
1440
1441
1442 Block *trueBlock = condBrOp.getTrueBlock();
1443 Block *falseBlock = condBrOp.getFalseBlock();
1445 rewriter.createLLVM::CondBrOp(loc, condBrOp.getCondition(), trueBlock,
1446 condBrOp.getTrueTargetOperands(),
1447 falseBlock,
1448 condBrOp.getFalseTargetOperands());
1449
1451 rewriter.replaceOp(op, continueBlock->getArguments());
1452 return success();
1453 }
1454 };
1455
1456
1457
1458
1459
1460 template <typename SPIRVOp, typename LLVMOp>
1462 public:
1464
1465 LogicalResult
1466 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
1468
1469 auto dstType = this->getTypeConverter()->convertType(op.getType());
1470 if (!dstType)
1472
1473 Type op1Type = op.getOperand1().getType();
1474 Type op2Type = op.getOperand2().getType();
1475
1476 if (op1Type == op2Type) {
1477 rewriter.template replaceOpWithNewOp(op, dstType,
1478 adaptor.getOperands());
1479 return success();
1480 }
1481
1482 std::optional<uint64_t> dstTypeWidth =
1484 std::optional<uint64_t> op2TypeWidth =
1486
1487 if (!dstTypeWidth || !op2TypeWidth)
1488 return failure();
1489
1492 if (op2TypeWidth < dstTypeWidth) {
1494 extended = rewriter.template createLLVM::ZExtOp(
1495 loc, dstType, adaptor.getOperand2());
1496 } else {
1497 extended = rewriter.template createLLVM::SExtOp(
1498 loc, dstType, adaptor.getOperand2());
1499 }
1500 } else if (op2TypeWidth == dstTypeWidth) {
1501 extended = adaptor.getOperand2();
1502 } else {
1503 return failure();
1504 }
1505
1506 Value result = rewriter.template create(
1507 loc, dstType, adaptor.getOperand1(), extended);
1509 return success();
1510 }
1511 };
1512
1514 public:
1516
1517 LogicalResult
1518 matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1520 auto dstType = getTypeConverter()->convertType(tanOp.getType());
1521 if (!dstType)
1522 return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
1523
1524 Location loc = tanOp.getLoc();
1525 Value sin = rewriter.createLLVM::SinOp(loc, dstType, tanOp.getOperand());
1526 Value cos = rewriter.createLLVM::CosOp(loc, dstType, tanOp.getOperand());
1527 rewriter.replaceOpWithNewOpLLVM::FDivOp(tanOp, dstType, sin, cos);
1528 return success();
1529 }
1530 };
1531
1532
1533
1534
1535
1536
1537
1539 public:
1541
1542 LogicalResult
1543 matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1545 auto srcType = tanhOp.getType();
1546 auto dstType = getTypeConverter()->convertType(srcType);
1547 if (!dstType)
1548 return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
1549
1550 Location loc = tanhOp.getLoc();
1552 Value multiplied =
1553 rewriter.createLLVM::FMulOp(loc, dstType, two, tanhOp.getOperand());
1554 Value exponential = rewriter.createLLVM::ExpOp(loc, dstType, multiplied);
1556 Value numerator =
1557 rewriter.createLLVM::FSubOp(loc, dstType, exponential, one);
1558 Value denominator =
1559 rewriter.createLLVM::FAddOp(loc, dstType, exponential, one);
1560 rewriter.replaceOpWithNewOpLLVM::FDivOp(tanhOp, dstType, numerator,
1561 denominator);
1562 return success();
1563 }
1564 };
1565
// Conversion pattern for spirv::VariableOp. The variable becomes an LLVM
// alloca of the converted pointee type; when the op has an initializer, the
// adapted initializer value is additionally stored into the allocation.
// NOTE(review): this listing is missing the class declaration line, the final
// line of the matchAndRewrite signature, and the line that defines size.
1567 public:
1569
// Returns failure() for an unsupported initialized pointee type, a match
// failure when a type conversion fails, and success() otherwise.
1570 LogicalResult
1571 matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1573 auto srcType = varOp.getType();
1574
// The op result is a SPIR-V pointer; pointerTo is the type it points at.
1575 auto pointerTo = castspirv::PointerType(srcType).getPointeeType();
1576 auto init = varOp.getInitializer();
// With an initializer present, the pointee must be an int or float type, or
// pass the isa check (its type argument was stripped from this listing);
// anything else is rejected.
1577 if (init && !pointerTo.isIntOrFloat() && !isa(pointerTo))
1578 return failure();
1579
1580 auto dstType = getTypeConverter()->convertType(srcType);
1581 if (!dstType)
1582 return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1583
1584 Location loc = varOp.getLoc();
// No initializer: the op is replaced directly by the alloca.
// NOTE(review): size comes from a line not shown here; presumably a constant
// element count for the alloca. TODO confirm.
1586 if (!init) {
1587 auto elementType = getTypeConverter()->convertType(pointerTo);
1588 if (!elementType)
1589 return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1590 rewriter.replaceOpWithNewOpLLVM::AllocaOp(varOp, dstType, elementType,
1591 size);
1592 return success();
1593 }
// Initializer present: allocate, store the adapted initializer into the
// allocation, then replace the op with the allocated pointer.
1594 auto elementType = getTypeConverter()->convertType(pointerTo);
1595 if (!elementType)
1596 return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1597 Value allocated =
1598 rewriter.createLLVM::AllocaOp(loc, dstType, elementType, size);
1599 rewriter.createLLVM::StoreOp(loc, adaptor.getInitializer(), allocated);
1600 rewriter.replaceOp(varOp, allocated);
1601 return success();
1602 }
1603 };
1604
1605
1606
1607
1608
// Conversion pattern for spirv::BitcastOp.
// NOTE(review): the base-class line of the declaration, the final line of the
// matchAndRewrite signature, and the first line of the last replacement call
// are missing from this listing.
1609 class BitcastConversionPattern
1611 public:
1613
// Returns success() after replacing bitcastOp, or a match failure when the
// result type cannot be converted.
1614 LogicalResult
1615 matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1617 auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
1618 if (!dstType)
1619 return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");
1620
1621
// When the converted result type is an LLVM pointer, no new op is created: the
// bitcast is replaced by its adapted operand. Presumably this is because LLVM
// pointers are opaque and need no pointer-to-pointer cast; TODO confirm.
1622 if (isaLLVM::LLVMPointerType(dstType)) {
1623 rewriter.replaceOp(bitcastOp, adaptor.getOperand());
1624 return success();
1625 }
1626
// Otherwise the op is replaced by a new op built from dstType, the adapted
// operands and the attributes of the original op. The call head naming the
// new op type is not visible here.
1628 bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1629 return success();
1630 }
1631 };
1632
1633
1634
1635
1636
// Conversion pattern for spirv::FuncOp. It converts the function signature,
// creates an LLVM::LLVMFuncOp with the same name, maps the SPIR-V function
// control onto the new function, converts the region types and erases the
// original op.
// NOTE(review): several lines are missing from this listing, including the
// class declaration, the declaration of signatureConverter, the attribute
// arguments of the DISPATCH uses, and the start of the region-moving call.
1638 public:
1640
// Returns failure() when the signature or the region types cannot be
// converted, success() otherwise.
1641 LogicalResult
1642 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1644
1645
1646
// Convert the function type through the LLVM type converter. The two boolean
// arguments are passed as false; their parameter names are not visible here.
1647 auto funcType = funcOp.getFunctionType();
1649 funcType.getNumInputs());
1650 auto llvmType = static_cast<const LLVMTypeConverter *>(getTypeConverter())
1651 ->convertFunctionSignature(
1652 funcType, false,
1653 false, signatureConverter);
1654 if (!llvmType)
1655 return failure();
1656
1657
// The new LLVM function keeps the location and name of the original.
1658 Location loc = funcOp.getLoc();
1659 StringRef name = funcOp.getName();
1660 auto newFuncOp = rewriter.createLLVM::LLVMFuncOp(loc, name, llvmType);
1661
1662
// Function control mapping: Inline sets always-inline, DontInline sets
// no-inline, and the DISPATCH cases attach a single-element passthrough array
// attribute. Any other value leaves the new function unchanged.
1663 MLIRContext *context = funcOp.getContext();
1664 switch (funcOp.getFunctionControl()) {
1665 case spirv::FunctionControl::Inline:
1666 newFuncOp.setAlwaysInline(true);
1667 break;
1668 case spirv::FunctionControl::DontInline:
1669 newFuncOp.setNoInline(true);
1670 break;
1671
1672 #define DISPATCH(functionControl, llvmAttr) \
1673 case functionControl: \
1674 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1675 break;
1676
// NOTE(review): the second macro argument of each DISPATCH use below is on a
// line missing from this listing.
1677 DISPATCH(spirv::FunctionControl::Pure,
1679 DISPATCH(spirv::FunctionControl::Const,
1681
1682 #undef DISPATCH
1683
1684
1685
1686 default:
1687 break;
1688 }
1689
// The lines below are the visible remainder of moving the body into newFuncOp
// and converting its region types with the signature conversion computed
// above; a conversion failure aborts the pattern.
1691 newFuncOp.end());
1693 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {
1694 return failure();
1695 }
// The original SPIR-V function is erased once the new one is in place.
1696 rewriter.eraseOp(funcOp);
1697 return success();
1698 }
1699 };
1700
1701
1702
1703
1704
// Conversion pattern for spirv::ModuleOp. A new module op is created with the
// location and name of the SPIR-V module, the SPIR-V module region is inlined
// into it, and the SPIR-V module is erased.
// NOTE(review): the class declaration line, the final line of the
// matchAndRewrite signature, and the template argument of rewriter.create are
// missing from this listing.
1706 public:
1708
// Always returns success() in the visible code.
1709 LogicalResult
1710 matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1712
1713 auto newModuleOp =
1714 rewriter.create(spvModuleOp.getLoc(), spvModuleOp.getName());
// The blocks of the SPIR-V module region are moved in front of the body block
// of the new module.
1715 rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1716
1717
// The last block of the new module body region is erased; presumably this is
// the block that was created together with the new module op. TODO confirm.
1718 rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1719 rewriter.eraseOp(spvModuleOp);
1720 return success();
1721 }
1722 };
1723
1724
1725
1726
1727
// Conversion pattern for spirv::VectorShuffleOp. When both input vectors have
// the same number of elements the op is replaced in one step; otherwise the
// result is assembled element by element with extract and insert ops.
// NOTE(review): this listing is missing the base-class line, the final line of
// the matchAndRewrite signature, the definitions of loc and llvmI32Type, the
// head of the fast-path replacement call, and the return that follows the
// dstType check.
1728 class VectorShufflePattern
1730 public:
1732 LogicalResult
1733 matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1736 auto components = adaptor.getComponents();
1737 auto vector1 = adaptor.getVector1();
1738 auto vector2 = adaptor.getVector2();
1739 int vector1Size = cast(vector1.getType()).getNumElements();
1740 int vector2Size = cast(vector2.getType()).getNumElements();
// Fast path: equal element counts. The component attributes are converted to
// int32 indices and passed to the replacement op together with both vectors.
1741 if (vector1Size == vector2Size) {
1743 op, vector1, vector2,
1744 LLVM::convertArrayToIndices<int32_t>(components));
1745 return success();
1746 }
1747
// Slow path: differing element counts.
1748 auto dstType = getTypeConverter()->convertType(op.getType());
1749 if (!dstType)
1751 auto scalarType = cast(dstType).getElementType();
1752 auto componentsArray = components.getValue();
1753 auto *context = rewriter.getContext();
// The result starts as a poison value of the destination type and is filled
// in one element per component.
1755 Value targetOp = rewriter.createLLVM::PoisonOp(loc, dstType);
1756 for (unsigned i = 0; i < componentsArray.size(); i++) {
// Each component must pass the isa check (its type argument was stripped from
// this listing); otherwise an error is emitted on the op.
1757 if (!isa(componentsArray[i]))
1758 return op.emitError("unable to support non-constant component");
1759
// A component of -1 is skipped, so that result element keeps its poison
// value.
1760 int indexVal = cast(componentsArray[i]).getInt();
1761 if (indexVal == -1)
1762 continue;
1763
// Indices at or beyond vector1Size select from vector2; offsetVal records the
// size of vector1 for that case.
1764 int offsetVal = 0;
1765 Value baseVector = vector1;
1766 if (indexVal >= vector1Size) {
1767 offsetVal = vector1Size;
1768 baseVector = vector2;
1769 }
1770
// dstIndex and index are constants whose value arguments are on missing
// lines. NOTE(review): presumably dstIndex is i and index is indexVal minus
// offsetVal; TODO confirm against the full source.
1771 Value dstIndex = rewriter.createLLVM::ConstantOp(
1773 Value index = rewriter.createLLVM::ConstantOp(
1774 loc, llvmI32Type,
1776
// Extract the selected scalar from the chosen vector and insert it into the
// result at dstIndex.
1777 auto extractOp = rewriter.createLLVM::ExtractElementOp(
1778 loc, scalarType, baseVector, index);
1779 targetOp = rewriter.createLLVM::InsertElementOp(loc, dstType, targetOp,
1780 extractOp, dstIndex);
1781 }
1782 rewriter.replaceOp(op, targetOp);
1783 return success();
1784 }
1785 };
1786 }
1787
1788
1789
1790
1791
1793 spirv::ClientAPI clientAPI) {
1796 });
1799 });
1802 });
1805 });
1806 }
1807
1810 spirv::ClientAPI clientAPI) {
1812
1813 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1814 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1815 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1816 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1817 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1818 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1819 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1820 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1821 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1822 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1823 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1824 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1825 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1826
1827
1828 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1829 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1830 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1831 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1832 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1833 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1834 NotPatternspirv::NotOp,
1835
1836
1837 BitcastConversionPattern,
1838 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1839 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1840 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1841 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1842 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1843 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1844 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1845
1846
1847 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1848 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1849 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1850 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1851 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1852 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1853 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1854 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1855 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1856 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1857 FComparePattern<spirv::FUnordGreaterThanEqualOp,
1858 LLVM::FCmpPredicate::uge>,
1859 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1860 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1861 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1862 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1863 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1864 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1865 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1866 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1867 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1868 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1869 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1870
1871
1872 ConstantScalarAndVectorPattern,
1873
1874
1875 BranchConversionPattern, BranchConditionalConversionPattern,
1876 FunctionCallPattern, LoopPattern, SelectionPattern,
1877 ErasePatternspirv::MergeOp,
1878
1879
1880 ErasePatternspirv::EntryPointOp, ExecutionModePattern,
1881
1882
1883 DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1884 DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1885 DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1886 DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1887 DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1888 DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1889 DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1890 DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1891 DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1892 DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1893 DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1894 DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1895 InverseSqrtPattern, TanPattern, TanhPattern,
1896
1897
1898 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1899 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1900 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1901 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1902 NotPatternspirv::LogicalNotOp,
1903
1904
1905 AccessChainPattern, AddressOfPattern, LoadStorePatternspirv::LoadOp,
1906 LoadStorePatternspirv::StoreOp, VariablePattern,
1907
1908
1909 CompositeExtractPattern, CompositeInsertPattern,
1910 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1911 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1912 VectorShufflePattern,
1913
1914
1915 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1916 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1917 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1918
1919
1920 ReturnPattern, ReturnValuePattern,
1921
1922
1923 ControlBarrierPatternspirv::ControlBarrierOp,
1924 ControlBarrierPatternspirv::INTELControlBarrierArriveOp,
1925 ControlBarrierPatternspirv::INTELControlBarrierWaitOp,
1926
1927
1928 GroupReducePatternspirv::GroupIAddOp,
1929 GroupReducePatternspirv::GroupFAddOp,
1930 GroupReducePatternspirv::GroupFMinOp,
1931 GroupReducePatternspirv::GroupUMinOp,
1932 GroupReducePattern<spirv::GroupSMinOp, true>,
1933 GroupReducePatternspirv::GroupFMaxOp,
1934 GroupReducePatternspirv::GroupUMaxOp,
1935 GroupReducePattern<spirv::GroupSMaxOp, true>,
1936 GroupReducePattern<spirv::GroupNonUniformIAddOp, false,
1937 true>,
1938 GroupReducePattern<spirv::GroupNonUniformFAddOp, false,
1939 true>,
1940 GroupReducePattern<spirv::GroupNonUniformIMulOp, false,
1941 true>,
1942 GroupReducePattern<spirv::GroupNonUniformFMulOp, false,
1943 true>,
1944 GroupReducePattern<spirv::GroupNonUniformSMinOp, true,
1945 true>,
1946 GroupReducePattern<spirv::GroupNonUniformUMinOp, false,
1947 true>,
1948 GroupReducePattern<spirv::GroupNonUniformFMinOp, false,
1949 true>,
1950 GroupReducePattern<spirv::GroupNonUniformSMaxOp, true,
1951 true>,
1952 GroupReducePattern<spirv::GroupNonUniformUMaxOp, false,
1953 true>,
1954 GroupReducePattern<spirv::GroupNonUniformFMaxOp, false,
1955 true>,
1956 GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, false,
1957 true>,
1958 GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, false,
1959 true>,
1960 GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, false,
1961 true>,
1962 GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, false,
1963 true>,
1964 GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, false,
1965 true>,
1966 GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, false,
1967 true>>(patterns.getContext(),
1968 typeConverter);
1969
1970 patterns.add(clientAPI, patterns.getContext(),
1971 typeConverter);
1972 }
1973
1976 patterns.add(patterns.getContext(), typeConverter);
1977 }
1978
1981 patterns.add(patterns.getContext(), typeConverter);
1982 }
1983
1984
1985
1986
1987
1988
1989 static constexpr StringRef kBinding = "binding";
// Body of the descriptor set and binding encoding hook. For every
// spirv::ModuleOp in the given module, each spirv::GlobalVariableOp that has
// both a descriptor set and a binding attribute gets a new name of the form
//   <module name>_<symbol name>_descriptor_set<N>_binding<M>
// where the module name prefix is omitted when the SPIR-V module has no name.
// NOTE(review): the function header, the descriptor set attribute lookup and
// the lines that apply the new name are missing from this listing; only the
// error emission for a failed symbol replacement is visible.
1992 auto spvModules = module.getOpsspirv::ModuleOp();
1993 for (auto spvModule : spvModules) {
1994 spvModule.walk([&](spirv::GlobalVariableOp op) {
1995 IntegerAttr descriptorSet =
1997 IntegerAttr binding = op->getAttrOfType(kBinding);
1998
1999
// Variables lacking either attribute are left untouched.
2000 if (descriptorSet && binding) {
2001
2002
// Prefix the symbol name with the SPIR-V module name when one exists.
2003 auto moduleAndName =
2004 spvModule.getName().has_value()
2005 ? spvModule.getName()->str() + "_" + op.getSymName().str()
2006 : op.getSymName().str();
// Append the descriptor set and binding numbers to form the new name.
2007 std::string name =
2008 llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
2009 std::to_string(descriptorSet.getInt()),
2010 std::to_string(binding.getInt()));
2011 auto nameAttr = StringAttr::get(op->getContext(), name);
2012
2013
2014
// Emitted when replacing the symbol uses fails; the condition guarding this
// statement is on a line not shown here.
2016 op.emitError("unable to replace all symbol uses for ") << name;
2020 }
2021 });
2022 }
2023 }
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isMemNone, bool isConvergent)
static MLIRContext * getContext(OpFoldResult val)
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
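Utility function for bitfield ops (BitFieldInsert, BitFieldSExtract and BitFieldUExtract): truncates or extends the value to the bit width of llvmType.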
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertPointerType(spirv::PointerType type, const TypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Type convertStructTypePacked(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
static Value optionallyBroadcast(Location loc, Value value, Type srcType, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
#define DISPATCH(functionControl, llvmAttr)
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
static Type convertStructTypeWithOffset(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Type convertStructType(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
OpListType::iterator iterator
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Block * getBlock() const
Returns the current block of the builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
StorageClass getStorageClass() const
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Include the generated interface declarations.
unsigned storageClassToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMFunctionConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
const FrozenRewritePatternSet & patterns
void populateSPIRVToLLVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
void populateSPIRVToLLVMModuleConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.