MLIR: lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallVector.h"
33#include "llvm/ADT/StringExtras.h"
34#include "llvm/Support/Debug.h"
35#include "llvm/Support/MathExtras.h"
36
37#include
38
39#define DEBUG_TYPE "mlir-spirv-conversion"
40
41using namespace mlir;
42
43namespace {
44
45
46
47
48
49static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
50 LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
51 if (vecType.isScalable()) {
52 LLVM_DEBUG(llvm::dbgs()
53 << "--scalable vectors are not supported -> BAIL\n");
54 return std::nullopt;
55 }
59 if (!targetShape) {
60 LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
61 return std::nullopt;
62 }
63 auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
64 if (!maybeShapeRatio) {
65 LLVM_DEBUG(llvm::dbgs()
66 << "--could not compute integral shape ratio -> BAIL\n");
67 return std::nullopt;
68 }
69 if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
70 LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
71 return std::nullopt;
72 }
73 LLVM_DEBUG(llvm::dbgs()
74 << "--found an integral shape ratio to unroll to -> SUCCESS\n");
75 return targetShape;
76}
77
78
79
80
81
82
83
84template
85static LogicalResult checkExtensionRequirements(
88 for (const auto &ors : candidates) {
89 if (targetEnv.allows(ors))
90 continue;
91
92 LLVM_DEBUG({
94 for (spirv::Extension ext : ors)
95 extStrings.push_back(spirv::stringifyExtension(ext));
96
97 llvm::dbgs() << label << " illegal: requires at least one extension in ["
98 << llvm::join(extStrings, ", ")
99 << "] but none allowed in target environment\n";
100 });
101 return failure();
102 }
104}
105
106
107
108
109
110
111
112template
113static LogicalResult checkCapabilityRequirements(
116 for (const auto &ors : candidates) {
117 if (targetEnv.allows(ors))
118 continue;
119
120 LLVM_DEBUG({
122 for (spirv::Capability cap : ors)
123 capStrings.push_back(spirv::stringifyCapability(cap));
124
125 llvm::dbgs() << label << " illegal: requires at least one capability in ["
126 << llvm::join(capStrings, ", ")
127 << "] but none allowed in target environment\n";
128 });
129 return failure();
130 }
132}
133
134
135
136static bool needsExplicitLayout(spirv::StorageClass storageClass) {
137 switch (storageClass) {
138 case spirv::StorageClass::PhysicalStorageBuffer:
139 case spirv::StorageClass::PushConstant:
140 case spirv::StorageClass::StorageBuffer:
141 case spirv::StorageClass::Uniform:
142 return true;
143 default:
144 return false;
145 }
146}
147
148
149
151wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
152 auto structType = needsExplicitLayout(storageClass)
156}
157
158
159
160
161
164 return castspirv::ScalarType(
165 IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));
166}
167
168
169
170static std::optional<int64_t>
172 if (isaspirv::ScalarType(type)) {
174
175
176
177
178
179
180 if (bitWidth == 1)
181 return std::nullopt;
182 return bitWidth / 8;
183 }
184
185
186 if (options.emulateUnsupportedFloatTypes && isa(type)) {
188 if (bitWidth == 8)
189 return bitWidth / 8;
190 return std::nullopt;
191 }
192
193 if (auto complexType = dyn_cast(type)) {
194 auto elementSize = getTypeNumBytes(options, complexType.getElementType());
195 if (!elementSize)
196 return std::nullopt;
197 return 2 * *elementSize;
198 }
199
200 if (auto vecType = dyn_cast(type)) {
201 auto elementSize = getTypeNumBytes(options, vecType.getElementType());
202 if (!elementSize)
203 return std::nullopt;
204 return vecType.getNumElements() * *elementSize;
205 }
206
207 if (auto memRefType = dyn_cast(type)) {
208
209
212 if (!memRefType.hasStaticShape() ||
213 failed(memRefType.getStridesAndOffset(strides, offset)))
214 return std::nullopt;
215
216
217
218
219 auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
220 if (!elementSize)
221 return std::nullopt;
222
223 if (memRefType.getRank() == 0)
224 return elementSize;
225
226 auto dims = memRefType.getShape();
227 if (llvm::is_contained(dims, ShapedType::kDynamic) ||
228 ShapedType::isDynamic(offset) ||
229 llvm::is_contained(strides, ShapedType::kDynamic))
230 return std::nullopt;
231
233 for (const auto &shape : enumerate(dims))
234 memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
235
236 return (offset + memrefSize) * *elementSize;
237 }
238
239 if (auto tensorType = dyn_cast(type)) {
240 if (!tensorType.hasStaticShape())
241 return std::nullopt;
242
243 auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
244 if (!elementSize)
245 return std::nullopt;
246
247 int64_t size = *elementSize;
248 for (auto shape : tensorType.getShape())
250
251 return size;
252 }
253
254
255 return std::nullopt;
256}
257
258
262 std::optionalspirv::StorageClass storageClass = {}) {
263
266 type.getExtensions(extensions, storageClass);
267 type.getCapabilities(capabilities, storageClass);
268
269
270 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
271 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
272 return type;
273
274
275
276 if (.emulateLT32BitScalarTypes)
277 return nullptr;
278
279
281 LLVM_DEBUG(llvm::dbgs()
282 << type
283 << " not converted to 32-bit for SPIR-V to avoid truncation\n");
284 return nullptr;
285 }
286
287 if (auto floatType = dyn_cast(type)) {
288 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
290 }
291
292 auto intType = cast(type);
293 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
294 return IntegerType::get(targetEnv.getContext(), 32,
295 intType.getSignedness());
296}
297
298
299
300
301
302
303
304
305
307 IntegerType type) {
308 if (type.getWidth() > 8) {
309 LLVM_DEBUG(llvm::dbgs() << "not a subbyte type\n");
310 return nullptr;
311 }
313 LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
314 return nullptr;
315 }
316
317 if (!llvm::isPowerOf2_32(type.getWidth())) {
318 LLVM_DEBUG(llvm::dbgs()
319 << "unsupported non-power-of-two bitwidth in sub-byte" << type
320 << "\n");
321 return nullptr;
322 }
323
324 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
325 return IntegerType::get(type.getContext(), 32,
326 type.getSignedness());
327}
328
329
330
332 FloatType type) {
333 if (.emulateUnsupportedFloatTypes)
334 return nullptr;
335
336 if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
337 Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
338 Float8E8M0FNUType>(type))
339 return IntegerType::get(type.getContext(), type.getWidth());
340 LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type: " << type << "\n");
341 return nullptr;
342}
343
344
345
346
347static ShapedType
348convertShaped8BitFloatType(ShapedType type,
350 if (.emulateUnsupportedFloatTypes)
351 return type;
352 Type srcElementType = type.getElementType();
353 Type convertedElementType = nullptr;
354
355 if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
356 Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
357 Float8E8M0FNUType>(srcElementType))
358 convertedElementType = IntegerType::get(
360
361 if (!convertedElementType)
362 return type;
363
364 return type.clone(convertedElementType);
365}
366
367
368
369
370static ShapedType
371convertIndexElementType(ShapedType type,
373 Type indexType = dyn_cast(type.getElementType());
374 if (!indexType)
375 return type;
376
378}
379
380
384 std::optionalspirv::StorageClass storageClass = {}) {
385 type = cast(convertIndexElementType(type, options));
386 type = cast(convertShaped8BitFloatType(type, options));
387 auto scalarType = dyn_cast_or_nullspirv::ScalarType(type.getElementType());
388 if (!scalarType) {
389
390
391 auto intType = dyn_cast(type.getElementType());
392 if (!intType) {
393 LLVM_DEBUG(llvm::dbgs()
394 << type
395 << " illegal: cannot convert non-scalar element type\n");
396 return nullptr;
397 }
398
399 Type elementType = convertSubByteIntegerType(options, intType);
400 if (!elementType)
401 return nullptr;
402
403 if (type.getRank() <= 1 && type.getNumElements() == 1)
404 return elementType;
405
406 if (type.getNumElements() > 4) {
407 LLVM_DEBUG(llvm::dbgs()
408 << type << " illegal: > 4-element unimplemented\n");
409 return nullptr;
410 }
411
412 return VectorType::get(type.getShape(), elementType);
413 }
414
415 if (type.getRank() <= 1 && type.getNumElements() == 1)
416 return convertScalarType(targetEnv, options, scalarType, storageClass);
417
419 LLVM_DEBUG(llvm::dbgs()
420 << type << " illegal: not a valid composite type\n");
421 return nullptr;
422 }
423
424
427 castspirv::CompositeType(type).getExtensions(extensions, storageClass);
428 castspirv::CompositeType(type).getCapabilities(capabilities, storageClass);
429
430
431 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
432 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
433 return type;
434
435 auto elementType =
436 convertScalarType(targetEnv, options, scalarType, storageClass);
437 if (elementType)
438 return VectorType::get(type.getShape(), elementType);
439 return nullptr;
440}
441
445 std::optionalspirv::StorageClass storageClass = {}) {
446 auto scalarType = dyn_cast_or_nullspirv::ScalarType(type.getElementType());
447 if (!scalarType) {
448 LLVM_DEBUG(llvm::dbgs()
449 << type << " illegal: cannot convert non-scalar element type\n");
450 return nullptr;
451 }
452
453 auto elementType =
454 convertScalarType(targetEnv, options, scalarType, storageClass);
455 if (!elementType)
456 return nullptr;
457 if (elementType != type.getElementType()) {
458 LLVM_DEBUG(llvm::dbgs()
459 << type << " illegal: complex type emulation unsupported\n");
460 return nullptr;
461 }
462
463 return VectorType::get(2, elementType);
464}
465
466
467
468
469
470
471
475
476 if (!type.hasStaticShape()) {
477 LLVM_DEBUG(llvm::dbgs()
478 << type << " illegal: dynamic shape unimplemented\n");
479 return nullptr;
480 }
481
482 type = cast(convertIndexElementType(type, options));
483 type = cast(convertShaped8BitFloatType(type, options));
484 auto scalarType = dyn_cast_or_nullspirv::ScalarType(type.getElementType());
485 if (!scalarType) {
486 LLVM_DEBUG(llvm::dbgs()
487 << type << " illegal: cannot convert non-scalar element type\n");
488 return nullptr;
489 }
490
491 std::optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
492 std::optional<int64_t> tensorSize = getTypeNumBytes(options, type);
493 if (!scalarSize || !tensorSize) {
494 LLVM_DEBUG(llvm::dbgs()
495 << type << " illegal: cannot deduce element count\n");
496 return nullptr;
497 }
498
499 int64_t arrayElemCount = *tensorSize / *scalarSize;
500 if (arrayElemCount == 0) {
501 LLVM_DEBUG(llvm::dbgs()
502 << type << " illegal: cannot handle zero-element tensors\n");
503 return nullptr;
504 }
505 if (arrayElemCount > std::numeric_limits::max()) {
506 LLVM_DEBUG(llvm::dbgs()
507 << type << " illegal: cannot fit tensor into target type\n");
508 return nullptr;
509 }
510
511 Type arrayElemType = convertScalarType(targetEnv, options, scalarType);
512 if (!arrayElemType)
513 return nullptr;
514 std::optional<int64_t> arrayElemSize =
515 getTypeNumBytes(options, arrayElemType);
516 if (!arrayElemSize) {
517 LLVM_DEBUG(llvm::dbgs()
518 << type << " illegal: cannot deduce converted element size\n");
519 return nullptr;
520 }
521
523}
524
527 MemRefType type,
528 spirv::StorageClass storageClass) {
529 unsigned numBoolBits = options.boolNumBits;
530 if (numBoolBits != 8) {
531 LLVM_DEBUG(llvm::dbgs()
532 << "using non-8-bit storage for bool types unimplemented");
533 return nullptr;
534 }
535 auto elementType = dyn_castspirv::ScalarType(
536 IntegerType::get(type.getContext(), numBoolBits));
537 if (!elementType)
538 return nullptr;
539 Type arrayElemType =
540 convertScalarType(targetEnv, options, elementType, storageClass);
541 if (!arrayElemType)
542 return nullptr;
543 std::optional<int64_t> arrayElemSize =
544 getTypeNumBytes(options, arrayElemType);
545 if (!arrayElemSize) {
546 LLVM_DEBUG(llvm::dbgs()
547 << type << " illegal: cannot deduce converted element size\n");
548 return nullptr;
549 }
550
551 if (!type.hasStaticShape()) {
552
553
554 if (targetEnv.allows(spirv::Capability::Kernel))
556 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
558
559
560 return wrapInStructAndGetPointer(arrayType, storageClass);
561 }
562
563 if (type.getNumElements() == 0) {
564 LLVM_DEBUG(llvm::dbgs()
565 << type << " illegal: zero-element memrefs are not supported\n");
566 return nullptr;
567 }
568
569 int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
570 int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
571 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
573 if (targetEnv.allows(spirv::Capability::Kernel))
575 return wrapInStructAndGetPointer(arrayType, storageClass);
576}
577
580 MemRefType type,
581 spirv::StorageClass storageClass) {
582 IntegerType elementType = cast(type.getElementType());
583 Type arrayElemType = convertSubByteIntegerType(options, elementType);
584 if (!arrayElemType)
585 return nullptr;
586 int64_t arrayElemSize = *getTypeNumBytes(options, arrayElemType);
587
588 if (!type.hasStaticShape()) {
589
590
591 if (targetEnv.allows(spirv::Capability::Kernel))
593 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
595
596
597 return wrapInStructAndGetPointer(arrayType, storageClass);
598 }
599
600 if (type.getNumElements() == 0) {
601 LLVM_DEBUG(llvm::dbgs()
602 << type << " illegal: zero-element memrefs are not supported\n");
603 return nullptr;
604 }
605
607 llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);
608 int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize);
609 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
611 if (targetEnv.allows(spirv::Capability::Kernel))
613 return wrapInStructAndGetPointer(arrayType, storageClass);
614}
615
616static spirv::Dim convertRank(int64_t rank) {
617 switch (rank) {
618 case 1:
619 return spirv::Dim::Dim1D;
620 case 2:
621 return spirv::Dim::Dim2D;
622 case 3:
623 return spirv::Dim::Dim3D;
624 default:
625 llvm_unreachable("Invalid memref rank!");
626 }
627}
628
629static spirv::ImageFormat getImageFormat(Type elementType) {
631 .Case([](Float16Type) { return spirv::ImageFormat::R16f; })
632 .Case([](Float32Type) { return spirv::ImageFormat::R32f; })
633 .Case([](IntegerType intType) {
634 auto const isSigned = intType.isSigned() || intType.isSignless();
635#define BIT_WIDTH_CASE(BIT_WIDTH) \
636 case BIT_WIDTH: \
637 return isSigned ? spirv::ImageFormat::R##BIT_WIDTH##i \
638 : spirv::ImageFormat::R##BIT_WIDTH##ui
639
640 switch (intType.getWidth()) {
643 default:
644 llvm_unreachable("Unhandled integer type!");
645 }
646 })
647 .DefaultUnreachable("Unhandled element type!");
648#undef BIT_WIDTH_CASE
649}
650
653 MemRefType type) {
654 auto attr = dyn_cast_or_nullspirv::StorageClassAttr(type.getMemorySpace());
655 if (!attr) {
656 LLVM_DEBUG(
657 llvm::dbgs()
658 << type
659 << " illegal: expected memory space to be a SPIR-V storage class "
660 "attribute; please use MemorySpaceToStorageClassConverter to map "
661 "numeric memory spaces beforehand\n");
662 return nullptr;
663 }
664 spirv::StorageClass storageClass = attr.getValue();
665
666
667
668
669 if (storageClass == spirv::StorageClass::Image) {
670 const int64_t rank = type.getRank();
671 if (rank < 1 || rank > 3) {
672 LLVM_DEBUG(llvm::dbgs()
673 << type << " illegal: cannot lower memref of rank " << rank
674 << " to a SPIR-V Image\n");
675 return nullptr;
676 }
677
678
679
680 auto elementType = type.getElementType();
681 if (!isaspirv::ScalarType(elementType)) {
682 LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot lower memref of "
683 << elementType << " to a SPIR-V Image\n");
684 return nullptr;
685 }
686
687
688
689
691 elementType, convertRank(rank), spirv::ImageDepthInfo::DepthUnknown,
692 spirv::ImageArrayedInfo::NonArrayed,
693 spirv::ImageSamplingInfo::SingleSampled,
694 spirv::ImageSamplerUseInfo::NeedSampler, getImageFormat(elementType));
697 spvSampledImageType, spirv::StorageClass::UniformConstant);
698 return imagePtrType;
699 }
700
701 if (isa(type.getElementType())) {
702 if (type.getElementTypeBitWidth() == 1)
703 return convertBoolMemrefType(targetEnv, options, type, storageClass);
704 if (type.getElementTypeBitWidth() < 8)
705 return convertSubByteMemrefType(targetEnv, options, type, storageClass);
706 }
707
708 Type arrayElemType;
709 Type elementType = type.getElementType();
710 if (auto vecType = dyn_cast(elementType)) {
711 arrayElemType =
712 convertVectorType(targetEnv, options, vecType, storageClass);
713 } else if (auto complexType = dyn_cast(elementType)) {
714 arrayElemType =
715 convertComplexType(targetEnv, options, complexType, storageClass);
716 } else if (auto scalarType = dyn_castspirv::ScalarType(elementType)) {
717 arrayElemType =
718 convertScalarType(targetEnv, options, scalarType, storageClass);
719 } else if (auto indexType = dyn_cast(elementType)) {
720 type = cast(convertIndexElementType(type, options));
721 arrayElemType = type.getElementType();
722 } else if (auto floatType = dyn_cast(elementType)) {
723
724 type = cast(convertShaped8BitFloatType(type, options));
725 arrayElemType = type.getElementType();
726 } else {
727 LLVM_DEBUG(
728 llvm::dbgs()
729 << type
730 << " unhandled: can only convert scalar or vector element type\n");
731 return nullptr;
732 }
733 if (!arrayElemType)
734 return nullptr;
735
736 std::optional<int64_t> arrayElemSize =
737 getTypeNumBytes(options, arrayElemType);
738 if (!arrayElemSize) {
739 LLVM_DEBUG(llvm::dbgs()
740 << type << " illegal: cannot deduce converted element size\n");
741 return nullptr;
742 }
743
744 if (!type.hasStaticShape()) {
745
746
747 if (targetEnv.allows(spirv::Capability::Kernel))
749 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
751
752
753 return wrapInStructAndGetPointer(arrayType, storageClass);
754 }
755
756 std::optional<int64_t> memrefSize = getTypeNumBytes(options, type);
757 if (!memrefSize) {
758 LLVM_DEBUG(llvm::dbgs()
759 << type << " illegal: cannot deduce element count\n");
760 return nullptr;
761 }
762
763 if (*memrefSize == 0) {
764 LLVM_DEBUG(llvm::dbgs()
765 << type << " illegal: zero-element memrefs are not supported\n");
766 return nullptr;
767 }
768
769 int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
770 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
772 if (targetEnv.allows(spirv::Capability::Kernel))
774 return wrapInStructAndGetPointer(arrayType, storageClass);
775}
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
797
798 if (inputs.size() != 1) {
799 auto castOp =
800 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
801 return castOp.getResult(0);
802 }
803 Value input = inputs.front();
804
805
806 if (!isa(type)) {
807 auto castOp =
808 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
809 return castOp.getResult(0);
810 }
811 auto inputType = cast(input.getType());
812
813 auto scalarType = dyn_castspirv::ScalarType(type);
814 if (!scalarType) {
815 auto castOp =
816 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
817 return castOp.getResult(0);
818 }
819
820
821
822
823 if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
824 auto castOp =
825 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
826 return castOp.getResult(0);
827 }
828
829
831 Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
832 return spirv::IEqualOp::create(builder, loc, input, one);
833 }
834
835
838 scalarType.getExtensions(exts);
839 scalarType.getCapabilities(caps);
840 if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
841 failed(checkExtensionRequirements(type, targetEnv, exts))) {
842 auto castOp =
843 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
844 return castOp.getResult(0);
845 }
846
847
848
849
851 return spirv::SConvertOp::create(builder, loc, type, input);
852 }
853 return spirv::UConvertOp::create(builder, loc, type, input);
854}
855
856
857
858
859
860static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
861 spirv::BuiltIn builtin) {
862
863
864 for (auto varOp : body.getOpsspirv::GlobalVariableOp()) {
865 if (auto builtinAttr = varOp->getAttrOfType(
866 spirv::SPIRVDialect::getAttributeName(
867 spirv::Decoration::BuiltIn))) {
868 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
869 if (varBuiltIn == builtin) {
870 return varOp;
871 }
872 }
873 }
874 return nullptr;
875}
876
877
878std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
879 StringRef suffix) {
880 return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
881}
882
883
884static spirv::GlobalVariableOp
887 StringRef prefix, StringRef suffix) {
888 if (auto varOp = getBuiltinVariable(body, builtin))
889 return varOp;
890
893
894 spirv::GlobalVariableOp newVarOp;
896 case spirv::BuiltIn::NumWorkgroups:
897 case spirv::BuiltIn::WorkgroupSize:
898 case spirv::BuiltIn::WorkgroupId:
899 case spirv::BuiltIn::LocalInvocationId:
900 case spirv::BuiltIn::GlobalInvocationId: {
902 spirv::StorageClass::Input);
903 std::string name = getBuiltinVarName(builtin, prefix, suffix);
904 newVarOp =
905 spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin);
906 break;
907 }
908 case spirv::BuiltIn::SubgroupId:
909 case spirv::BuiltIn::NumSubgroups:
910 case spirv::BuiltIn::SubgroupSize:
911 case spirv::BuiltIn::SubgroupLocalInvocationId: {
912 auto ptrType =
914 std::string name = getBuiltinVarName(builtin, prefix, suffix);
915 newVarOp =
916 spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin);
917 break;
918 }
919 default:
920 emitError(loc, "unimplemented builtin variable generation for ")
921 << stringifyBuiltIn(builtin);
922 }
923 return newVarOp;
924}
925
926
927
928
929
930
931
932static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
934 Type indexType) {
936 4);
939}
940
941
942
943static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
944 unsigned elementCount) {
945 for (auto varOp : body.getOpsspirv::GlobalVariableOp()) {
946 auto ptrType = dyn_castspirv::PointerType(varOp.getType());
947 if (!ptrType)
948 continue;
949
950
951
952
953 if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
954 auto numElements = castspirv::ArrayType(
955 castspirv::StructType(ptrType.getPointeeType())
956 .getElementType(0))
957 .getNumElements();
958 if (numElements == elementCount)
959 return varOp;
960 }
961 }
962 return nullptr;
963}
964
965
966
967static spirv::GlobalVariableOp
968getOrInsertPushConstantVariable(Location loc, Block &block,
970 Type indexType) {
971 if (auto varOp = getPushConstantVariable(block, elementCount))
972 return varOp;
973
975 auto type = getPushConstantStorageType(elementCount, builder, indexType);
976 const char *name = "__push_constant_var__";
977 return spirv::GlobalVariableOp::create(builder, loc, type, name,
978 nullptr);
979}
980
981
982
983
984
985
986
987struct FuncOpConversion final : OpConversionPatternfunc::FuncOp {
988 using Base::Base;
989
990 LogicalResult
991 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
992 ConversionPatternRewriter &rewriter) const override {
993 FunctionType fnType = funcOp.getFunctionType();
994 if (fnType.getNumResults() > 1)
995 return failure();
996
997 TypeConverter::SignatureConversion signatureConverter(
998 fnType.getNumInputs());
999 for (const auto &argType : enumerate(fnType.getInputs())) {
1000 auto convertedType = getTypeConverter()->convertType(argType.value());
1001 if (!convertedType)
1002 return failure();
1003 signatureConverter.addInputs(argType.index(), convertedType);
1004 }
1005
1006 Type resultType;
1007 if (fnType.getNumResults() == 1) {
1008 resultType = getTypeConverter()->convertType(fnType.getResult(0));
1009 if (!resultType)
1010 return failure();
1011 }
1012
1013
1014 auto newFuncOp = spirv::FuncOp::create(
1015 rewriter, funcOp.getLoc(), funcOp.getName(),
1016 rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
1017 resultType ? TypeRange(resultType)
1019
1020
1021 for (const auto &namedAttr : funcOp->getAttrs()) {
1022 if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
1024 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1025 }
1026
1027 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1028 newFuncOp.end());
1029 if (failed(rewriter.convertRegionTypes(
1030 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
1031 return failure();
1032 rewriter.eraseOp(funcOp);
1034 }
1035};
1036
1037
1038
1039struct FuncOpVectorUnroll final : OpRewritePatternfunc::FuncOp {
1040 using Base::Base;
1041
1042 LogicalResult matchAndRewrite(func::FuncOp funcOp,
1044 FunctionType fnType = funcOp.getFunctionType();
1045
1046
1047 if (funcOp.isDeclaration()) {
1048 LLVM_DEBUG(llvm::dbgs()
1049 << fnType << " illegal: declarations are unsupported\n");
1050 return failure();
1051 }
1052
1053
1054 auto newFuncOp = func::FuncOp::create(rewriter, funcOp.getLoc(),
1055 funcOp.getName(), fnType);
1057 newFuncOp.end());
1058
1059 Location loc = newFuncOp.getBody().getLoc();
1060
1061 Block &entryBlock = newFuncOp.getBlocks().front();
1064
1065 TypeConverter::SignatureConversion oneToNTypeMapping(
1066 fnType.getInputs().size());
1067
1068
1069
1070
1072 size_t newInputNo = 0;
1073
1074
1075
1076
1077
1078 llvm::SmallDenseMap<Operation *, size_t> tmpOps;
1079
1080
1081 size_t newOpCount = 0;
1082
1083
1084 for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
1085
1086 auto origVecType = dyn_cast(origType);
1087 if (!origVecType) {
1088
1089 Value result = arith::ConstantOp::create(
1090 rewriter, loc, origType, rewriter.getZeroAttr(origType));
1092 tmpOps.insert({result.getDefiningOp(), newInputNo});
1093 oneToNTypeMapping.addInputs(origInputNo, origType);
1094 ++newInputNo;
1095 ++newOpCount;
1096 continue;
1097 }
1098
1100 if (!targetShape) {
1101
1102 Value result = arith::ConstantOp::create(
1103 rewriter, loc, origType, rewriter.getZeroAttr(origType));
1105 tmpOps.insert({result.getDefiningOp(), newInputNo});
1106 oneToNTypeMapping.addInputs(origInputNo, origType);
1107 ++newInputNo;
1108 ++newOpCount;
1109 continue;
1110 }
1111 VectorType unrolledType =
1112 VectorType::get(*targetShape, origVecType.getElementType());
1113 auto originalShape =
1114 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1115
1116
1117 Value result = arith::ConstantOp::create(
1118 rewriter, loc, origVecType, rewriter.getZeroAttr(origVecType));
1119 ++newOpCount;
1120
1121 Value dummy = arith::ConstantOp::create(
1122 rewriter, loc, unrolledType, rewriter.getZeroAttr(unrolledType));
1123 ++newOpCount;
1124
1125
1130 result = vector::InsertStridedSliceOp::create(rewriter, loc, dummy,
1131 result, offsets, strides);
1132 newTypes.push_back(unrolledType);
1133 unrolledInputNums.push_back(newInputNo);
1134 ++newInputNo;
1135 ++newOpCount;
1136 }
1138 oneToNTypeMapping.addInputs(origInputNo, newTypes);
1139 }
1140
1141
1142 auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
1143 auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
1145 [&] { newFuncOp.setFunctionType(newFnType); });
1146
1147
1148 entryBlock.eraseArguments(0, fnType.getNumInputs());
1150 entryBlock.addArguments(convertedTypes, locs);
1151
1152
1153
1154 for (auto &[placeholderOp, argIdx] : tmpOps) {
1155 if (!placeholderOp)
1156 continue;
1159 }
1160
1161
1162
1163
1164
1165 size_t unrolledInputIdx = 0;
1166 for (auto [count, op] : enumerate(entryBlock.getOperations())) {
1168
1169
1170
1171 if (count >= newOpCount)
1172 continue;
1173 if (auto vecOp = dyn_castvector::InsertStridedSliceOp(op)) {
1174 size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
1176 curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
1177 });
1178 ++unrolledInputIdx;
1179 }
1180 }
1181
1182
1183
1184 rewriter.eraseOp(funcOp);
1186 }
1187};
1188
1189
1190
1191
1192
1193
1194
1195struct ReturnOpVectorUnroll final : OpRewritePatternfunc::ReturnOp {
1196 using Base::Base;
1197
1198 LogicalResult matchAndRewrite(func::ReturnOp returnOp,
1200
1201 auto funcOp = dyn_castfunc::FuncOp(returnOp->getParentOp());
1202 if (!funcOp)
1203 return failure();
1204
1205 FunctionType fnType = funcOp.getFunctionType();
1206 TypeConverter::SignatureConversion oneToNTypeMapping(
1207 fnType.getResults().size());
1208 Location loc = returnOp.getLoc();
1209
1210
1212
1213
1214 for (auto [origResultNo, origType] : enumerate(fnType.getResults())) {
1215
1216 auto origVecType = dyn_cast(origType);
1217 if (!origVecType) {
1218 oneToNTypeMapping.addInputs(origResultNo, origType);
1219 newOperands.push_back(returnOp.getOperand(origResultNo));
1220 continue;
1221 }
1222
1224 if (!targetShape) {
1225
1226 oneToNTypeMapping.addInputs(origResultNo, origType);
1227 newOperands.push_back(returnOp.getOperand(origResultNo));
1228 continue;
1229 }
1230 VectorType unrolledType =
1231 VectorType::get(*targetShape, origVecType.getElementType());
1232
1233
1234
1235 auto originalShape =
1236 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1239 extractShape.back() = targetShape->back();
1241 Value returnValue = returnOp.getOperand(origResultNo);
1244 Value result = vector::ExtractStridedSliceOp::create(
1245 rewriter, loc, returnValue, offsets, extractShape, strides);
1246 if (originalShape.size() > 1) {
1249 vector::ExtractOp::create(rewriter, loc, result, extractIndices);
1250 }
1251 newOperands.push_back(result);
1252 newTypes.push_back(unrolledType);
1253 }
1254 oneToNTypeMapping.addInputs(origResultNo, newTypes);
1255 }
1256
1257
1258 auto newFnType =
1260 TypeRange(oneToNTypeMapping.getConvertedTypes()));
1262 [&] { funcOp.setFunctionType(newFnType); });
1263
1264
1265
1267 func::ReturnOp::create(rewriter, loc, newOperands));
1268
1270 }
1271};
1272
1273}
1274
1275
1276
1277
1278
1282 StringRef prefix, StringRef suffix) {
1284 if (!parent) {
1285 op->emitError("expected operation to be within a module-like op");
1286 return nullptr;
1287 }
1288
1289 spirv::GlobalVariableOp varOp =
1291 builtin, integerType, builder, prefix, suffix);
1292 Value ptr = spirv::AddressOfOp::create(builder, op->getLoc(), varOp);
1293 return spirv::LoadOp::create(builder, op->getLoc(), ptr);
1294}
1295
1296
1297
1298
1299
1301 unsigned offset, Type integerType,
1305 if (!parent) {
1306 op->emitError("expected operation to be within a module-like op");
1307 return nullptr;
1308 }
1309
1310 spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
1311 loc, parent->getRegion(0).front(), elementCount, builder, integerType);
1312
1313 Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
1314 Value offsetOp = spirv::ConstantOp::create(builder, loc, integerType,
1316 auto addrOp = spirv::AddressOfOp::create(builder, loc, varOp);
1317 auto acOp = spirv::AccessChainOp::create(builder, loc, addrOp,
1319 return spirv::LoadOp::create(builder, loc, acOp);
1320}
1321
1322
1323
1324
1325
1329 assert(indices.size() == strides.size() &&
1330 "must provide indices for all dimensions");
1331
1332
1333
1334
1335
1336
1337 Value linearizedIndex = builder.createOrFoldspirv::ConstantOp(
1338 loc, integerType, IntegerAttr::get(integerType, offset));
1339 for (const auto &index : llvm::enumerate(indices)) {
1341 loc, integerType,
1342 IntegerAttr::get(integerType, strides[index.index()]));
1344 builder.createOrFoldspirv::IMulOp(loc, index.value(), strideVal);
1345 linearizedIndex =
1346 builder.createOrFoldspirv::IAddOp(loc, update, linearizedIndex);
1347 }
1348 return linearizedIndex;
1349}
1350
1352 MemRefType baseType, Value basePtr,
1355
1356
1359 if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1360 llvm::is_contained(strides, ShapedType::kDynamic) ||
1361 ShapedType::isDynamic(offset)) {
1362 return nullptr;
1363 }
1364
1365 auto indexType = typeConverter.getIndexType();
1366
1368 auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
1369
1370
1371 linearizedIndices.push_back(zero);
1372
1373 if (baseType.getRank() == 0) {
1374 linearizedIndices.push_back(zero);
1375 } else {
1376 linearizedIndices.push_back(
1378 }
1379 return spirv::AccessChainOp::create(builder, loc, basePtr, linearizedIndices);
1380}
1381
1383 MemRefType baseType, Value basePtr,
1386
1387
1390 if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1391 llvm::is_contained(strides, ShapedType::kDynamic) ||
1392 ShapedType::isDynamic(offset)) {
1393 return nullptr;
1394 }
1395
1396 auto indexType = typeConverter.getIndexType();
1397
1399 Value linearIndex;
1400 if (baseType.getRank() == 0) {
1401 linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
1402 } else {
1403 linearIndex =
1405 }
1406 Type pointeeType =
1407 castspirv::PointerType(basePtr.getType()).getPointeeType();
1408 if (isaspirv::ArrayType(pointeeType)) {
1409 linearizedIndices.push_back(linearIndex);
1410 return spirv::AccessChainOp::create(builder, loc, basePtr,
1411 linearizedIndices);
1412 }
1413 return spirv::PtrAccessChainOp::create(builder, loc, basePtr, linearIndex,
1414 linearizedIndices);
1415}
1416
1418 MemRefType baseType, Value basePtr,
1421
1422 if (typeConverter.allows(spirv::Capability::Kernel)) {
1424 builder);
1425 }
1426
1428 builder);
1429}
1430
1431
1432
1433
1434
1436 for (int i : {4, 3, 2}) {
1437 if (size % i == 0)
1438 return i;
1439 }
1440 return 1;
1441}
1442
1445 VectorType srcVectorType = op.getSourceVectorType();
1446 assert(srcVectorType.getRank() == 1);
1449 return {vectorSize};
1450}
1451
1454 VectorType vectorType = op.getResultVectorType();
1456 nativeSize.back() =
1458 return nativeSize;
1459}
1460
1461std::optional<SmallVector<int64_t>>
1464 if (auto vecType = dyn_cast(op->getResultTypes()[0])) {
1466 nativeSize.back() =
1468 return nativeSize;
1469 }
1470 }
1471
1473 .Case<vector::ReductionOp, vector::TransposeOp>(
1475 .Default(std::nullopt);
1476}
1477
1490
1493
1494
1495 {
1501 return failure();
1502 }
1503
1504
1505
1506 {
1509 patterns, vector::VectorTransposeLowering::EltWise);
1512 return failure();
1513 }
1514
1515
1516 {
1518
1519
1520 vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
1521 vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
1522 vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
1523
1524
1525
1526 vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
1528 vector::InsertOp::getCanonicalizationPatterns(patterns, context);
1529 vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
1530
1531
1532
1533 vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
1534 vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
1535
1537 return failure();
1538 }
1540}
1541
1542
1543
1544
1545
1548 : targetEnv(targetAttr), options(options) {
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1560
1561 addConversion([this](IndexType ) { return getIndexType(); });
1562
1563 addConversion([this](IntegerType intType) -> std::optional {
1564 if (auto scalarType = dyn_castspirv::ScalarType(intType))
1565 return convertScalarType(this->targetEnv, this->options, scalarType);
1566 if (intType.getWidth() < 8)
1567 return convertSubByteIntegerType(this->options, intType);
1568 return Type();
1569 });
1570
1571 addConversion([this](FloatType floatType) -> std::optional {
1572 if (auto scalarType = dyn_castspirv::ScalarType(floatType))
1573 return convertScalarType(this->targetEnv, this->options, scalarType);
1574 if (floatType.getWidth() == 8)
1575 return convert8BitFloatType(this->options, floatType);
1576 return Type();
1577 });
1578
1579 addConversion([this](ComplexType complexType) {
1580 return convertComplexType(this->targetEnv, this->options, complexType);
1581 });
1582
1583 addConversion([this](VectorType vectorType) {
1584 return convertVectorType(this->targetEnv, this->options, vectorType);
1585 });
1586
1587 addConversion([this](TensorType tensorType) {
1588 return convertTensorType(this->targetEnv, this->options, tensorType);
1589 });
1590
1591 addConversion([this](MemRefType memRefType) {
1592 return convertMemrefType(this->targetEnv, this->options, memRefType);
1593 });
1594
1595
1596 addSourceMaterialization(
1598 return castToSourceType(this->targetEnv, builder, type, inputs, loc);
1599 });
1602 auto cast = UnrealizedConversionCastOp::create(builder, loc, type, inputs);
1603 return cast.getResult(0);
1604 });
1605}
1606
1608 return ::getIndexType(getContext(), options);
1609}
1610
1611MLIRContext *SPIRVTypeConverter::getContext() const {
1612 return targetEnv.getAttr().getContext();
1613}
1614
1616 return targetEnv.allows(capability);
1617}
1618
1619
1620
1621
1622
1623std::unique_ptr
1625 std::unique_ptr target(
1626
1627 new SPIRVConversionTarget(targetAttr));
1628 SPIRVConversionTarget *targetPtr = target.get();
1629 target->addDynamicallyLegalDialectspirv::SPIRVDialect(
1630
1631
1632 [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
1634}
1635
1636SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
1638
1639bool SPIRVConversionTarget::isLegalOp(Operation *op) {
1640
1641
1642
1643 if (auto minVersionIfx = dyn_castspirv::QueryMinVersionInterface(op)) {
1644 std::optionalspirv::Version minVersion = minVersionIfx.getMinVersion();
1645 if (minVersion && *minVersion > this->targetEnv.getVersion()) {
1646 LLVM_DEBUG(llvm::dbgs()
1647 << op->getName() << " illegal: requiring min version "
1648 << spirv::stringifyVersion(*minVersion) << "\n");
1649 return false;
1650 }
1651 }
1652 if (auto maxVersionIfx = dyn_castspirv::QueryMaxVersionInterface(op)) {
1653 std::optionalspirv::Version maxVersion = maxVersionIfx.getMaxVersion();
1654 if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
1655 LLVM_DEBUG(llvm::dbgs()
1656 << op->getName() << " illegal: requiring max version "
1657 << spirv::stringifyVersion(*maxVersion) << "\n");
1658 return false;
1659 }
1660 }
1661
1662
1663
1664
1665 if (auto extensions = dyn_castspirv::QueryExtensionInterface(op))
1666 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1667 extensions.getExtensions())))
1668 return false;
1669
1670
1671
1672
1673 if (auto capabilities = dyn_castspirv::QueryCapabilityInterface(op))
1674 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1675 capabilities.getCapabilities())))
1676 return false;
1677
1678 SmallVector<Type, 4> valueTypes;
1681
1682
1683 if (llvm::any_of(valueTypes,
1684 [](Type t) { return !isaspirv::SPIRVType(t); }))
1685 return false;
1686
1687
1688
1689 if (auto globalVar = dyn_castspirv::GlobalVariableOp(op))
1690 valueTypes.push_back(globalVar.getType());
1691
1692
1693
1694 SmallVector<ArrayRefspirv::Extension, 4> typeExtensions;
1695 SmallVector<ArrayRefspirv::Capability, 8> typeCapabilities;
1696 for (Type valueType : valueTypes) {
1697 typeExtensions.clear();
1698 castspirv::SPIRVType(valueType).getExtensions(typeExtensions);
1699 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1700 typeExtensions)))
1701 return false;
1702
1703 typeCapabilities.clear();
1704 castspirv::SPIRVType(valueType).getCapabilities(typeCapabilities);
1705 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1706 typeCapabilities)))
1707 return false;
1708 }
1709
1710 return true;
1711}
1712
1713
1714
1715
1716
1719 patterns.add(typeConverter, patterns.getContext());
1720}
1721
1725
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static llvm::ManagedStatic< PassManagerOptions > options
#define BIT_WIDTH_CASE(BIT_WIDTH)
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
Block represents an ordered list of Operations.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
OpListType & getOperations()
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
This class allows control over how the GreedyPatternRewriteDriver works.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
void setOperand(unsigned idx, Value value)
operand_type_iterator operand_type_end()
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
result_type_iterator result_type_end()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_type_iterator result_type_begin()
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumResults()
Return the number of results held by this operation.
operand_type_iterator operand_type_begin()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Definition SPIRVConversion.cpp:1624
Type conversion from builtin types to SPIR-V types for shader interface.
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
Definition SPIRVConversion.cpp:1607
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options={})
Definition SPIRVConversion.cpp:1546
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
Definition SPIRVConversion.cpp:1615
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ArrayType get(Type elementType, unsigned elementCount)
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
SmallVectorImpl< ArrayRef< Capability > > CapabilityArrayRefVector
The capability requirements for each type are following the ((Capability::A OR Extension::B) AND (Cap...
SmallVectorImpl< ArrayRef< Extension > > ExtensionArrayRefVector
The extension requirements for each type are following the ((Extension::A OR Extension::B) AND (Exten...
static SampledImageType get(Type imageType)
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
TargetEnvAttr getAttr() const
MLIRContext * getContext() const
Returns the MLIRContext.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Definition SPIRVConversion.cpp:1417
Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Definition SPIRVConversion.cpp:1382
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerT...
Definition SPIRVConversion.cpp:1300
std::optional< SmallVector< int64_t > > getNativeVectorShape(Operation *op)
Definition SPIRVConversion.cpp:1462
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
Definition SPIRVConversion.cpp:1326
LogicalResult unrollVectorsInFuncBodies(Operation *op)
Definition SPIRVConversion.cpp:1491
Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Definition SPIRVConversion.cpp:1351
SmallVector< int64_t > getNativeVectorShapeImpl(vector::ReductionOp op)
Definition SPIRVConversion.cpp:1444
int getComputeVectorSize(int64_t size)
Definition SPIRVConversion.cpp:1435
LogicalResult unrollVectorsInSignatures(Operation *op)
Definition SPIRVConversion.cpp:1478
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns)
Definition SPIRVConversion.cpp:1722
void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns)
Definition SPIRVConversion.cpp:1726
@ Packed
Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void populateBuiltinFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
Definition SPIRVConversion.cpp:1717
const FrozenRewritePatternSet & patterns
llvm::TypeSwitch< T, ResultT > TypeSwitch
@ ExistingOps
Only pre-existing ops are processed.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options that control the vector unrolling.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)