MLIR: lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp Source File
//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-Exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "llvm/Support/Debug.h"
#include <cassert>
#include <optional>

#define DEBUG_TYPE "memref-to-spirv-pattern"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns the offset of the value in `targetBits` representation.
///
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
/// When the array is treated as having `targetBits`-wide elements, several
/// source elements are packed into one target element; the returned value is
/// the bit offset of `srcIdx` within that target element. For example, if
/// `sourceBits` equals 8 and `targetBits` equals 32, the x-th element is
/// located at (x % 4) * 8, because four 8-bit elements fit in one i32.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
                                  int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  Type type = srcIdx.getType();
  IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
  IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
  auto srcBitsValue =
      builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
  auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
  return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
}

/// Returns an adjusted spirv::AccessChainOp. Based on the
/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
/// supported. During conversion, if a memref of an unsupported type is used,
/// loads/stores to this memref need to be modified to use a supported higher
/// bitwidth `targetBits` and extract the required bits. For accessing a 1-D
/// array (spirv.array or spirv.rtarray), the last index is modified to load
/// the bits needed. The extraction of the actual bits needed is handled
/// separately. Note that this only works for a 1-D tensor.
static Value
adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
                             spirv::AccessChainOp op, int sourceBits,
                             int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  const auto loc = op.getLoc();
  Value lastDim = op->getOperand(op.getNumOperands() - 1);
  Type type = lastDim.getType();
  IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
  auto indices = llvm::to_vector<4>(op.getIndices());
  // There are two indices if this is a 1-D array access.
  assert(indices.size() == 2);
  indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
  Type t = typeConverter.convertType(op.getComponentPtr().getType());
  return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
}
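
// For example, with sourceBits = 8 and targetBits = 32, an access to element
// i of an i8 array is redirected to element i / 4 of the corresponding i32
// array; getOffsetForBitwidth() then gives the bit offset (i % 4) * 8 of the
// requested byte inside that i32 word.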

/// Casts the given `srcBool` into an integer of `dstType`.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
                            OpBuilder &builder) {
  assert(srcBool.getType().isInteger(1));
  if (dstType.isInteger(1))
    return srcBool;
  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
  return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
                                               zero);
}

/// Returns the `targetBits`-bit value obtained by masking `value` with `mask`,
/// casting it to the destination (mask) type, and shifting it left by the
/// given `offset`.
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                        OpBuilder &builder) {
  IntegerType dstType = cast<IntegerType>(mask.getType());
  int targetBits = static_cast<int>(dstType.getWidth());
  int valueBits = value.getType().getIntOrFloatBitWidth();
  assert(valueBits <= targetBits);

  if (valueBits == 1) {
    value = castBoolToIntN(loc, value, dstType, builder);
  } else {
    if (valueBits < targetBits) {
      value = builder.create<spirv::UConvertOp>(
          loc, builder.getIntegerType(targetBits), value);
    }

    value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
  }
  return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
                                                         value, offset);
}
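
// For example, storing an i8 value v at bit offset 8 of an i32 word yields
// (zext(v) & mask) << 8, where mask covers the low 8 bits.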

/// Returns true if the allocations of memref `type` generated from `allocOp`
/// can be lowered to SPIR-V.
static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
  if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
      return false;
  } else if (isa<memref::AllocaOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Function)
      return false;
  } else {
    return false;
  }

  // Currently only support static shape and int or float or vector of int or
  // float element type.
  if (!type.hasStaticShape())
    return false;

  Type elementType = type.getElementType();
  if (auto vecType = dyn_cast<VectorType>(elementType))
    elementType = vecType.getElementType();
  return elementType.isIntOrFloat();
}

/// Returns the scope to use for atomic operations used for emulating stores
/// of unsupported integer bitwidths, based on the memref type.
static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
  switch (sc.getValue()) {
  case spirv::StorageClass::StorageBuffer:
    return spirv::Scope::Device;
  case spirv::StorageClass::Workgroup:
    return spirv::Scope::Workgroup;
  default:
    break;
  }
  return {};
}
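
// StorageBuffer memory is visible device-wide, so atomics on it use Device
// scope; Workgroup memory is only shared within one workgroup.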

/// Casts the given `srcInt` into a boolean value.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
  if (srcInt.getType().isInteger(1))
    return srcInt;

  auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
  return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, one);
}

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires ConversionPattern; DRR generates
// normal RewritePatterns.

namespace {

/// Converts memref.alloca to SPIR-V Function variables.
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts an allocation operation to SPIR-V. Currently only supports
/// lowering to Workgroup memory when the size is constant. Note that this
/// pattern needs to be applied in a pass that runs at least at spirv.module
/// scope since it adds global variables into the spirv.module.
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
class AtomicRMWOpPattern final
    : public OpConversionPattern<memref::AtomicRMWOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Removes a deallocation if it is a supported allocation. Currently only
/// removes deallocations of workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.memory_space_cast to the appropriate SPIR-V cast ops.
class MemorySpaceCastOpPattern final
    : public OpConversionPattern<memref::MemorySpaceCastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ReinterpretCastPattern final
    : public OpConversionPattern<memref::ReinterpretCastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class CastPattern final : public OpConversionPattern<memref::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = adaptor.getSource();
    Type srcType = src.getType();

    const TypeConverter *converter = getTypeConverter();
    Type dstType = converter->convertType(op.getType());
    if (srcType != dstType)
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "types doesn't match: " << srcType << " and " << dstType;
      });

    rewriter.replaceOp(op, src);
    return success();
  }
};

/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
class ExtractAlignedPointerAsIndexOpPattern final
    : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = allocaOp.getType();
  if (!isAllocationSupported(allocaOp, allocType))
    return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");

  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
                                                 spirv::StorageClass::Function,
                                                 /*initializer=*/nullptr);
  return success();
}

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = operation.getType();
  if (!isAllocationSupported(operation, allocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(operation, "type conversion failed");

  // Insert spirv.GlobalVariable for this allocation.
  Operation *parent =
      SymbolTable::getNearestSymbolTable(operation->getParentOp());
  if (!parent)
    return failure();
  Location loc = operation.getLoc();
  spirv::GlobalVariableOp varOp;
  {
    OpBuilder::InsertionGuard guard(rewriter);
    Block &entryBlock = *parent->getRegion(0).begin();
    rewriter.setInsertionPointToStart(&entryBlock);
    auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
    std::string varName =
        std::string("__workgroup_mem__") +
        std::to_string(std::distance(varOps.begin(), varOps.end()));
    varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
                                                     /*initializer=*/nullptr);
  }

  // Get a pointer to the global variable at the current scope.
  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
  return success();
}

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

LogicalResult
AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (isa<FloatType>(atomicOp.getType()))
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unimplemented floating-point case");

  auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unsupported memref memory space");

  auto &typeConverter = *getTypeConverter();
  Type resultType = typeConverter.convertType(atomicOp.getType());
  if (!resultType)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "failed to convert result type");

  auto loc = atomicOp.getLoc();
  Value ptr =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!ptr)
    return failure();

#define ATOMIC_CASE(kind, spirvOp)                                             \
  case arith::AtomicRMWKind::kind:                                             \
    rewriter.replaceOpWithNewOp<spirv::spirvOp>(                               \
        atomicOp, resultType, ptr, *scope,                                     \
        spirv::MemorySemantics::AcquireRelease, adaptor.getValue());           \
    break

  switch (atomicOp.getKind()) {
    ATOMIC_CASE(addi, AtomicIAddOp);
    ATOMIC_CASE(maxs, AtomicSMaxOp);
    ATOMIC_CASE(maxu, AtomicUMaxOp);
    ATOMIC_CASE(mins, AtomicSMinOp);
    ATOMIC_CASE(minu, AtomicUMinOp);
    ATOMIC_CASE(ori, AtomicOrOp);
    ATOMIC_CASE(andi, AtomicAndOp);
  default:
    return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
  }

#undef ATOMIC_CASE

  return success();
}
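
// For example, an atomic `addi` on a StorageBuffer memref becomes a
// spirv.AtomicIAdd on the element's access chain, using Device scope and
// AcquireRelease memory semantics.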

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
  if (!isAllocationSupported(operation, deallocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
  rewriter.eraseOp(operation);
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

struct MemoryRequirements {
  spirv::MemoryAccessAttr memoryAccess;
  IntegerAttr alignment;
};

/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
/// any.
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
  MLIRContext *ctx = accessedPtr.getContext();

  auto memoryAccess = spirv::MemoryAccess::None;
  if (isNontemporal) {
    memoryAccess = spirv::MemoryAccess::Nontemporal;
  }

  auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
  if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
    if (memoryAccess == spirv::MemoryAccess::None) {
      return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
    }
    return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
                              IntegerAttr{}};
  }

  // PhysicalStorageBuffer accesses require the `Aligned` memory access flag.
  auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
  if (!pointeeType)
    return failure();

  // For scalar types, the alignment is determined by their size.
  std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
  if (!sizeInBytes.has_value())
    return failure();

  memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
  auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
  return MemoryRequirements{memAccessAttr, alignment};
}

/// Given an accessed SPIR-V pointer and the memref load/store `loadOrStoreOp`
/// accessing it, calculates the memory access flags and, if any, the alignment
/// for the generated SPIR-V load/store.
template <class LoadOrStoreOp>
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
  static_assert(
      llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
      "Must be called on either memref::LoadOp or memref::StoreOp");

  Operation *memrefAccessOp = loadOrStoreOp.getOperation();
  auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
      spirv::attributeName<spirv::MemoryAccess>());
  auto memrefAlignment =
      memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
  if (memrefMemAccess && memrefAlignment)
    return MemoryRequirements{memrefMemAccess, memrefAlignment};

  return calculateMemoryRequirements(accessedPtr,
                                     loadOrStoreOp.getNontemporal());
}

LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  auto loc = loadOp.getLoc();
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  const auto &typeConverter = *getTypeConverter();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return failure();

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  Type dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = arrayType.getElementType();
    else
      dstType = pointeeType;
  } else {
    // For Vulkan we need to extract the element from the wrapping struct and
    // array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = arrayType.getElementType();
    else
      dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
  }
  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bit width, use the loaded value
  // directly.
  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          loadOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
                                                   memoryAccess, alignment);
    if (isBool)
      loadVal = castIntNToBool(loc, loadVal, rewriter);
    rewriter.replaceOp(loadOp, loadVal);
    return success();
  }

  // The bit-extraction emulation below is not implemented for the Kernel
  // capability.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Assume that getElementPtr() returns a linearized access chain; even for a
  // scalar, the chain still has two indices. If the access were not
  // linearized, the offset computation below would be wrong.
  assert(accessChainOp.getIndices().size() == 2);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
                                                   memoryAccess, alignment);

  // Shift the bits occupied by the source element to the rightmost position.
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract the corresponding bits.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result =
      rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);

  // Apply sign extension on the loaded value unconditionally. The signedness
  // semantics are carried in the operator itself; we rely on other patterns to
  // handle the casting.
  IntegerAttr shiftValueAttr =
      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
  Value shiftValue =
      rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
  result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
                                                            result, shiftValue);
  result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, dstType, result, shiftValue);

  rewriter.replaceOp(loadOp, result);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}
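
// In summary, a narrow integer load (e.g. i8) is emulated by loading the
// containing 32-bit word through an adjusted access chain, shifting the word
// right by the element's bit offset, masking out the other elements, and then
// sign-extending via a shift-left/arithmetic-shift-right pair.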

LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return failure();
  auto loadPtr = spirv::getElementPtr(
      *getTypeConverter(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), loadOp.getLoc(), rewriter);

  if (!loadPtr)
    return failure();

  auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
                                             alignment);
  return success();
}

LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp,
                                       "element type is not a signless int");

  auto loc = storeOp.getLoc();
  auto &typeConverter = *getTypeConverter();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to convert element pointer type");

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();

  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(storeOp,
                                       "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  IntegerType dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(pointeeType);
  } else {
    // For Vulkan we need to extract the element from the wrapping struct and
    // array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(
          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
  }

  if (!dstType)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine destination element type");

  int dstBits = static_cast<int>(dstType.getWidth());
  assert(dstBits % srcBits == 0);

  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          storeOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value storeVal = adaptor.getValue();
    if (isBool)
      storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
                                                memoryAccess, alignment);
    return success();
  }

  // The bit-packing emulation below is not implemented for the Kernel
  // capability.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Since multiple threads may be writing neighboring elements packed into the
  // same word, the emulation is done with atomic operations. E.g., if the
  // stored value is i8, rewrite the StoreOp to:
  // 1) clear the 8 destination bits in the containing 32-bit word, and
  // 2) set the 8 bits to the (shifted) stored value.
  // Step 1 is done with AtomicAnd of the inverted, shifted mask; step 2 is
  // done with AtomicOr of the shifted store value. The last index of the
  // access chain is also adjusted to address the containing 32-bit word.
  assert(accessChainOp.getIndices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // an i32, 0xFFFF00FF is created.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
      loc, dstType, mask, offset);
  clearBitsMask =
      rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);

  Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");

  Value result = rewriter.create<spirv::AtomicAndOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      clearBitsMask);
  result = rewriter.create<spirv::AtomicOrOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      storeVal);

  // The atomic op results are not used; since the atomic ops are already
  // inserted, just erase the original StoreOp. Note that rewriter.replaceOp()
  // doesn't work here because the numbers of results don't match.
  rewriter.eraseOp(storeOp);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

//===----------------------------------------------------------------------===//
// MemorySpaceCastOp
//===----------------------------------------------------------------------===//

LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
    memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = addrCastOp.getLoc();
  auto &typeConverter = *getTypeConverter();
  if (!typeConverter.allows(spirv::Capability::Kernel))
    return rewriter.notifyMatchFailure(
        loc, "address space casts require kernel capability");

  auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
  if (!sourceType)
    return rewriter.notifyMatchFailure(
        loc, "SPIR-V lowering requires ranked memref types");
  auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());

  auto sourceStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
  if (!sourceStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
      diag << "source address space " << sourceType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });
  auto resultStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
  if (!resultStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
      diag << "result address space " << resultType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });

  spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
  spirv::StorageClass resultSc = resultStorageClassAttr.getValue();

  Value result = adaptor.getSource();
  Type resultPtrType = typeConverter.convertType(resultType);
  if (!resultPtrType)
    return rewriter.notifyMatchFailure(addrCastOp,
                                       "failed to convert memref type");

  Type genericPtrType = resultPtrType;
  // SPIR-V doesn't have a general address space cast operation. Instead, it
  // has conversions to and from generic pointers. To implement the general
  // case, we use specific-to-generic conversions when the source class is not
  // generic. Then, when the result storage class is not generic, we convert
  // the generic pointer (either the input or an intermediate result) to that
  // class. This also means that we need the intermediate generic pointer type
  // if neither the source nor the destination is generic.
  if (sourceSc != spirv::StorageClass::Generic &&
      resultSc != spirv::StorageClass::Generic) {
    Type intermediateType =
        MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
                        sourceType.getLayout(),
                        rewriter.getAttr<spirv::StorageClassAttr>(
                            spirv::StorageClass::Generic));
    genericPtrType = typeConverter.convertType(intermediateType);
  }
  if (sourceSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
  }
  if (resultSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
  }
  rewriter.replaceOp(addrCastOp, result);
  return success();
}
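
// For example, casting a memref from the Function storage class to
// CrossWorkgroup goes through the Generic class: spirv.PtrCastToGeneric
// followed by spirv.GenericCastToPtr.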

LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp, "signless int");
  auto storePtr = spirv::getElementPtr(
      *getTypeConverter(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), storeOp.getLoc(), rewriter);

  if (!storePtr)
    return rewriter.notifyMatchFailure(storeOp, "type conversion failed");

  auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::StoreOp>(
      storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
  return success();
}

LogicalResult ReinterpretCastPattern::matchAndRewrite(
    memref::ReinterpretCastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value src = adaptor.getSource();
  auto srcType = dyn_cast<spirv::PointerType>(src.getType());

  if (!srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid src type " << src.getType();
    });

  const TypeConverter *converter = getTypeConverter();

  auto dstType = converter->convertType<spirv::PointerType>(op.getType());
  if (dstType != srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid dst type " << op.getType();
    });

  OpFoldResult offset =
      getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
          .front();
  if (isZeroInteger(offset)) {
    rewriter.replaceOp(op, src);
    return success();
  }

  Type intType = converter->convertType(rewriter.getIndexType());
  if (!intType)
    return rewriter.notifyMatchFailure(op, "failed to convert index type");

  Location loc = op.getLoc();
  auto offsetValue = [&]() -> Value {
    if (auto val = dyn_cast<Value>(offset))
      return val;

    int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
    IntegerAttr attr = rewriter.getIntegerAttr(intType, attrVal);
    return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
  }();

  rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
      op, src, offsetValue, std::nullopt);
  return success();
}

//===----------------------------------------------------------------------===//
// ExtractAlignedPointerAsIndexOp
//===----------------------------------------------------------------------===//

LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
    memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto &typeConverter = *getTypeConverter();
  Type indexType = typeConverter.getIndexType();
  rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
                                                      adaptor.getSource());
  return success();
}

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

namespace mlir {
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                   RewritePatternSet &patterns) {
  patterns
      .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
           DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
           MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
           CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
          typeConverter, patterns.getContext());
}
} // namespace mlir
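
// Typical usage from a conversion pass (a minimal sketch; the surrounding pass
// boilerplate is assumed and not defined in this file):
//
//   auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
//   std::unique_ptr<ConversionTarget> target =
//       SPIRVConversionTarget::get(targetAttr);
//   SPIRVTypeConverter typeConverter(targetAttr);
//   RewritePatternSet patterns(op->getContext());
//   populateMemRefToSPIRVPatterns(typeConverter, patterns);
//   if (failed(applyPartialConversion(op, *target, std::move(patterns))))
//     signalPassFailure();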