MLIR: lib/Conversion/LLVMCommon/TypeConverter.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/Support/Threading.h"
16 #include
17 #include
18 #include
19
20 using namespace mlir;
21
23 {
24
26 std::defer_lock);
27 if (getContext().isMultithreadingEnabled())
28 lock.lock();
31 return *recursiveStack->second;
32 }
33
34
35
39 return *recursiveStackInserted.first->second;
40 }
41
42
46
47
49 return values.size() == 1 &&
50 isaLLVM::LLVMPointerType(values.front().getType());
51 }
52
53
58
59
63 inputs);
64 }
65
66
70 assert(resultType && "expected non-null result type");
73 resultType, inputs[0]);
76 true))
78
79
81 }
82
83
88
89
90
93 if (!packed)
95 return builder.create(loc, resultType, packed)
96 .getResult(0);
97 }
98
99
101 MemRefType resultType,
104
105
106
109 if (!packed)
111 return builder.create(loc, resultType, packed)
112 .getResult(0);
113 }
114
115
119 : llvmDialect(ctx->getOrLoadDialectLLVM::LLVMDialect()), options(options),
120 dataLayoutAnalysis(analysis) {
121 assert(llvmDialect && "LLVM IR dialect is not registered");
122
123
124 addConversion([&](ComplexType type) { return convertComplexType(type); });
125 addConversion([&](FloatType type) { return convertFloatType(type); });
126 addConversion([&](FunctionType type) { return convertFunctionType(type); });
127 addConversion([&](IndexType type) { return convertIndexType(type); });
128 addConversion([&](IntegerType type) { return convertIntegerType(type); });
129 addConversion([&](MemRefType type) { return convertMemRefType(type); });
131 [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
132 addConversion([&](VectorType type) -> std::optional {
133 FailureOr llvmType = convertVectorType(type);
134 if (failed(llvmType))
135 return std::nullopt;
136 return llvmType;
137 });
138
139
140
141
144 : std::nullopt;
145 });
146
148 -> std::optional {
149
151 results.push_back(type);
152 return success();
153 }
154
155 if (type.isIdentified()) {
156 auto convertedType = LLVM::LLVMStructType::getIdentified(
157 type.getContext(), ("_Converted." + type.getName()).str());
158
160 if (llvm::count(recursiveStack, type)) {
161 results.push_back(convertedType);
162 return success();
163 }
164 recursiveStack.push_back(type);
165 auto popConversionCallStack = llvm::make_scope_exit(
166 [&recursiveStack]() { recursiveStack.pop_back(); });
167
169 convertedElemTypes.reserve(type.getBody().size());
170 if (failed(convertTypes(type.getBody(), convertedElemTypes)))
171 return std::nullopt;
172
173
174
175 if (!convertedType.isInitialized()) {
176 if (failed(
177 convertedType.setBody(convertedElemTypes, type.isPacked()))) {
178 return failure();
179 }
180 results.push_back(convertedType);
181 return success();
182 }
183
184
185
186
187 if (TypeRange(convertedType.getBody()) == TypeRange(convertedElemTypes) &&
188 convertedType.isPacked() == type.isPacked()) {
189 results.push_back(convertedType);
190 return success();
191 }
192
193 return failure();
194 }
195
197 convertedSubtypes.reserve(type.getBody().size());
198 if (failed(convertTypes(type.getBody(), convertedSubtypes)))
199 return std::nullopt;
200
201 results.push_back(LLVM::LLVMStructType::getLiteral(
202 type.getContext(), convertedSubtypes, type.isPacked()));
203 return success();
204 });
205 addConversion([&](LLVM::LLVMArrayType type) -> std::optional {
206 if (auto element = convertType(type.getElementType()))
208 return std::nullopt;
209 });
210 addConversion([&](LLVM::LLVMFunctionType type) -> std::optional {
211 Type convertedResType = convertType(type.getReturnType());
212 if (!convertedResType)
213 return std::nullopt;
214
216 convertedArgTypes.reserve(type.getNumParams());
217 if (failed(convertTypes(type.getParams(), convertedArgTypes)))
218 return std::nullopt;
219
221 type.isVarArg());
222 });
223
224
225
228 return builder.create(loc, resultType, inputs)
229 .getResult(0);
230 });
233 return builder.create(loc, resultType, inputs)
234 .getResult(0);
235 });
236
237
238
239
244 *this);
245 });
249 });
250
251
255
256
257
258 if (!originalType)
260 if (resultType != convertType(originalType))
262 if (auto memrefType = dyn_cast(originalType))
264 if (auto unrankedMemrefType = dyn_cast(originalType))
266 *this);
268 });
269
270
272 [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
273 }
274
275
278 }
279
282 }
283
285 return options.dataLayout.getPointerSizeInBits(addressSpace);
286 }
287
288 Type LLVMTypeConverter::convertIndexType(IndexType type) const {
290 }
291
292 Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
294 }
295
296 Type LLVMTypeConverter::convertFloatType(FloatType type) const {
297
299 return type;
300
301
302 if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
303 Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
304 Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
305 Float8E8M0FNUType>(type))
307
308
309
310 return Type();
311 }
312
313
314
315
316
317 Type LLVMTypeConverter::convertComplexType(ComplexType type) const {
318 auto elementType = convertType(type.getElementType());
319 return LLVM::LLVMStructType::getLiteral(&getContext(),
320 {elementType, elementType});
321 }
322
323
324
325 Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
327 }
328
329
330
331
332 static void
334 SmallVectorImpl<std::optional> &result) {
335 assert(result.empty() && "Unexpected non-empty output");
336 result.resize(funcOp.getNumArguments(), std::nullopt);
337 bool foundByValByRefAttrs = false;
338 for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
339 for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
340 if ((namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
341 namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName())) {
342 foundByValByRefAttrs = true;
343 result[argIdx] = namedAttr;
344 break;
345 }
346 }
347 }
348
349 if (!foundByValByRefAttrs)
350 result.clear();
351 }
352
353
354
355
356
357
358
359
360
361 Type LLVMTypeConverter::convertFunctionSignatureImpl(
362 FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
363 LLVMTypeConverter::SignatureConversion &result,
364 SmallVectorImpl<std::optional> *byValRefNonPtrAttrs) const {
365
366 useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
369
370 for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
372 if (failed(funcArgConverter(*this, type, converted)))
373 return {};
374
375
376
377 if (byValRefNonPtrAttrs != nullptr && !byValRefNonPtrAttrs->empty() &&
378 converted.size() == 1 && (*byValRefNonPtrAttrs)[idx].has_value()) {
379
380
381 if (isaLLVM::LLVMPointerType(converted[0]))
382 (*byValRefNonPtrAttrs)[idx] = std::nullopt;
383 else
385 }
386
387 result.addInputs(idx, converted);
388 }
389
390
391
392
393 Type resultType =
394 funcTy.getNumResults() == 0
397 if (!resultType)
398 return {};
400 isVariadic);
401 }
402
404 FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
406 return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
407 result,
408 nullptr);
409 }
410
412 FunctionOpInterface funcOp, bool isVariadic, bool useBarePtrCallConv,
414 SmallVectorImpl<std::optional> &byValRefNonPtrAttrs) const {
415
416
417
419 auto funcTy = cast(funcOp.getFunctionType());
420 return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
421 result, &byValRefNonPtrAttrs);
422 }
423
424
425
426 std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
429
430 Type resultType = type.getNumResults() == 0
433 if (!resultType)
434 return {};
435
437 auto structType = dyn_castLLVM::LLVMStructType(resultType);
438 if (structType) {
439
440
441 inputs.push_back(ptrType);
443 }
444
445 for (Type t : type.getInputs()) {
448 return {};
449 if (isa<MemRefType, UnrankedMemRefType>(t))
450 converted = ptrType;
451 inputs.push_back(converted);
452 }
453
455 }
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
487 bool unpackAggregates) const {
488 if (!type.isStrided()) {
491 "conversion to strided form failed either due to non-strided layout "
492 "maps (which should have been normalized away) or other reasons");
493 return {};
494 }
495
497 if (!elementType)
498 return {};
499
501 if (failed(addressSpace)) {
503 "conversion of memref memory space ")
504 << type.getMemorySpace()
505 << " to integer address space "
506 "failed. Consider adding memory space conversions.";
507 return {};
508 }
510
512
514 auto rank = type.getRank();
515 if (rank == 0)
516 return results;
517
518 if (unpackAggregates)
519 results.insert(results.end(), 2 * rank, indexTy);
520 else
522 return results;
523 }
524
525 unsigned
528
532 }
533
534
535
536 Type LLVMTypeConverter::convertMemRefType(MemRefType type) const {
537
538
541 if (types.empty())
542 return {};
543 return LLVM::LLVMStructType::getLiteral(&getContext(), types);
544 }
545
546
547
548
549
550
551
552
556 }
557
560
564 }
565
566 Type LLVMTypeConverter::convertUnrankedMemRefType(
569 return {};
570 return LLVM::LLVMStructType::getLiteral(&getContext(),
572 }
573
574 FailureOr
576 if (!type.getMemorySpace())
577 return 0;
578 std::optional converted =
580 if (!converted)
581 return failure();
582 if (!(*converted))
583 return 0;
584 if (auto explicitSpace = dyn_cast_if_present(*converted)) {
585 if (explicitSpace.getType().isIndex() ||
586 explicitSpace.getType().isSignlessInteger())
587 return explicitSpace.getInt();
588 }
589 return failure();
590 }
591
592
594 if (isa(type))
595
596 return false;
597
598
599
600 auto memrefTy = cast(type);
601 if (!memrefTy.hasStaticShape())
602 return false;
603
604 int64_t offset = 0;
606 if (failed(memrefTy.getStridesAndOffset(strides, offset)))
607 return false;
608
609 for (int64_t stride : strides)
610 if (ShapedType::isDynamic(stride))
611 return false;
612
613 return !ShapedType::isDynamic(offset);
614 }
615
616
617 Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
619 return {};
621 if (!elementType)
622 return {};
624 if (failed(addressSpace))
625 return {};
627 }
628
629
630
631
632
633
634
635
636 FailureOr LLVMTypeConverter::convertVectorType(VectorType type) const {
637 auto elementType = convertType(type.getElementType());
638 if (!elementType)
639 return {};
640 if (type.getShape().empty())
643 type.getScalableDims().back());
645 "expected vector type compatible with the LLVM dialect");
646
647
648
649 if (llvm::is_contained(type.getScalableDims().drop_back(), true))
650 return failure();
651 auto shape = type.getShape();
652 for (int i = shape.size() - 2; i >= 0; --i)
654 return vectorType;
655 }
656
657
658
659
660
661
663 Type type, bool useBarePtrCallConv) const {
664 if (useBarePtrCallConv)
665 if (auto memrefTy = dyn_cast(type))
666 return convertMemRefToBarePtr(memrefTy);
667
669 }
670
671
672
673
677 assert(stdTypes.size() == values.size() &&
678 "The number of types and values doesn't match");
679 for (unsigned i = 0, end = values.size(); i < end; ++i)
680 if (auto memrefTy = dyn_cast(stdTypes[i]))
682 memrefTy, values[i]);
683 }
684
685
686
687
688
690 assert(!types.empty() && "expected non-empty list of type");
691 if (types.size() == 1)
693
695 resultTypes.reserve(types.size());
696 for (Type type : types) {
699 return {};
700 resultTypes.push_back(converted);
701 }
702
703 return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
704 }
705
706
707
708
709
711 bool useBarePtrCallConv) const {
712 assert(!types.empty() && "expected non-empty list of type");
713
715 if (types.size() == 1)
717
719 resultTypes.reserve(types.size());
720 for (auto t : types) {
723 return {};
724 resultTypes.push_back(converted);
725 }
726
727 return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
728 }
729
732
733
737 Value allocated =
738 builder.createLLVM::AllocaOp(loc, ptrType, operand.getType(), one);
739
740 builder.createLLVM::StoreOp(loc, operand, allocated);
741 return allocated;
742 }
743
747 bool useBarePtrCallConv) const {
749 promotedOperands.reserve(operands.size());
751 for (auto it : llvm::zip(opOperands, operands)) {
752 auto operand = std::get<0>(it);
753 auto llvmOperand = std::get<1>(it);
754
755 if (useBarePtrCallConv) {
756
757
758 if (isa(operand.getType())) {
760 llvmOperand = desc.alignedPtr(builder, loc);
761 } else if (isa(operand.getType())) {
762 llvm_unreachable("Unranked memrefs are not supported");
763 }
764 } else {
765 if (isa(operand.getType())) {
767 promotedOperands);
768 continue;
769 }
770 if (auto memrefType = dyn_cast(operand.getType())) {
772 promotedOperands);
773 continue;
774 }
775 }
776
777 promotedOperands.push_back(llvmOperand);
778 }
779 return promotedOperands;
780 }
781
782
783
784
785
786 LogicalResult
789 if (auto memref = dyn_cast(type)) {
790
791
792 auto converted =
794 if (converted.empty())
795 return failure();
796 result.append(converted.begin(), converted.end());
797 return success();
798 }
799 if (isa(type)) {
801 if (converted.empty())
802 return failure();
803 result.append(converted.begin(), converted.end());
804 return success();
805 }
806 auto converted = converter.convertType(type);
807 if (!converted)
808 return failure();
809 result.push_back(converted);
810 return success();
811 }
812
813
814
815 LogicalResult
819 type, true);
820 if (!llvmTy)
821 return failure();
822
823 result.push_back(llvmTy);
824 return success();
825 }
static llvm::ManagedStatic< PassManagerOptions > options
static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
Pack SSA values into a ranked memref descriptor struct.
static Value unrankedMemRefMaterialization(OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
MemRef descriptor elements -> UnrankedMemRefType.
static Value packUnrankedMemRefDesc(OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
Pack SSA values into an unranked memref descriptor struct.
static Value rankedMemRefMaterialization(OpBuilder &builder, MemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
MemRef descriptor elements -> MemRefType.
static bool isBarePointer(ValueRange values)
Helper function that checks if the given value range is a bare pointer.
static void filterByValRefArgAttrs(FunctionOpInterface funcOp, SmallVectorImpl< std::optional< NamedAttribute >> &result)
Returns the llvm.byval or llvm.byref attributes that are present in the function arguments.
This class provides a shared interface for ranked and unranked memref types.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
LLVM::LLVMDialect * llvmDialect
Pointer to the LLVM dialect.
llvm::sys::SmartRWMutex< true > callStackMutex
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Type packFunctionResults(TypeRange types, bool useBarePointerCallConv=false) const
Convert a non-empty list of types to be returned from a function into an LLVM-compatible type.
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
SmallVector< Type, 5 > getMemRefDescriptorFields(MemRefType type, bool unpackAggregates) const
Convert a memref type into a list of LLVM IR types that will form the memref descriptor.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter, Location loc, ArrayRef< Type > stdTypes, SmallVectorImpl< Value > &values) const
Promote the bare pointers in 'values' that resulted from memrefs to descriptors.
DenseMap< uint64_t, std::unique_ptr< SmallVector< Type > > > conversionCallStack
SmallVector< Type > & getCurrentThreadRecursiveStack()
Value promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) const
Promote the LLVM struct representation of one MemRef descriptor to stack and use pointer to struct to...
Type convertCallingConventionType(Type type, bool useBarePointerCallConv=false) const
Convert a type in the context of the default or bare pointer calling convention.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
SmallVector< Value, 4 > promoteOperands(Location loc, ValueRange opOperands, ValueRange operands, OpBuilder &builder, bool useBarePtrCallConv=false) const
Promote the LLVM representation of all operands including promoting MemRef descriptors to stack and u...
LLVMTypeConverter(MLIRContext *ctx, const DataLayoutAnalysis *analysis=nullptr)
Create an LLVMTypeConverter using the default LowerToLLVMOptions.
unsigned getPointerBitwidth(unsigned addressSpace=0) const
Gets the pointer bitwidth.
SmallVector< Type, 2 > getUnrankedMemRefDescriptorFields() const
Convert an unranked memref type into a list of non-aggregate LLVM IR types that will form the unranke...
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type `type`, or failure if the memory space cannot be converted to an integer address space.
static bool canConvertToBarePtr(BaseMemRefType type)
Check if a memref type can be converted to a bare pointer.
MLIRContext & getContext() const
Returns the MLIR context.
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
std::pair< LLVM::LLVMFunctionType, LLVM::LLVMStructType > convertFunctionTypeCWrapper(FunctionType type) const
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
unsigned getIndexTypeBitwidth() const
Gets the bitwidth of the index type when converted to LLVM.
Type getIndexType() const
Gets the LLVM representation of the index type.
friend LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Give structFuncArgTypeConverter access to memref-specific functions.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
llvm::DataLayout dataLayout
The data layout of the module to produce.
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
NamedAttribute represents a combination of a name and an Attribute value.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class provides all of the information necessary to convert a type signature.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert the attribute `attr`, present within the type `type`, using the registered conversion functions...
void addConversion(FnT &&callback)
Register a conversion function.
void addSourceMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Include the generated interface declarations.
LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...