(original) (raw)
diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h index c0b286494996b..ef5ff54a2f470 100644 --- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h @@ -146,6 +146,12 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef values, // Map strings to float types. std::optional parseFloatType(MLIRContext *ctx, StringRef name); +// Map strings to Int types. +std::optional parseIntType(MLIRContext *ctx, StringRef name); + +// Map strings to int or float types. +std::optional parseIntOrFloatType(MLIRContext *ctx, StringRef name); + } // namespace arith } // namespace mlir diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h index 6cd6f03253aea..0b7339a94b274 100644 --- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h @@ -16,6 +16,8 @@ #include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/GPU/Utils/GPUUtils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include @@ -87,6 +89,24 @@ void populateGpuLowerClusteredSubgroupReduceToDPPPatterns( RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset, PatternBenefit benefit = 1); +/// Set up a type converter to convert unsupported source types to +/// supported target types. +void populateImitateUnsupportedTypesTypeConverter(TypeConverter &typeConverter, + ArrayRef sourceTypes, + ArrayRef targetTypes); + +/// Collect a set of pattern needed to imitate unsupported source types +/// using supported target types. +void populateImitateUnsupportedTypesConversionPatterns( + RewritePatternSet &patterns, TypeConverter &typeConverter, + ArrayRef sourceTypes, ArrayRef targetTypes, + DenseMap<stringattr, functiontype=""> &convertedFuncTypes); + +/// Set up a dialect conversion to reject operations on unsupported +/// float types. +void configureImitateUnsupportedTypesLegality(ConversionTarget &target, + TypeConverter &typeConverter); + /// Collect all patterns to rewrite ops within the GPU dialect. inline void populateGpuRewritePatterns(RewritePatternSet &patterns) { populateGpuAllReducePatterns(patterns); diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td index 3766eb16e9429..feb1b2820abd6 100644 --- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td @@ -258,4 +258,57 @@ def GpuSPIRVAttachTarget: Pass<"spirv-attach-target", ""> { ]; } +def GpuImitateUnsupportedTypes : Pass<"imitate-unsupported-types", "::mlir::ModuleOp"> { + let summary = "Imitate unsupported types with supported types of same bitwidth."; + let description = [{ + This pass imitates (bitcast/reinterpret_cast) unsupported types + with supported types of same bitwidth. The imitation is done + by bitcasting the unspported types to the supported types of same bitwidth. + Therefore, the source type and destination type must have the same bitwidth. + The imitation is done by using the following operations: arith.bitcast. + + The imitation is often needed when the GPU target (dialect/IR) does not + support a certain type but the underlying architecture does. Take SPIR-V for + example, it does not support bf16, but an underlying architecture (e.g., + intel pvc gpu) that uses SPIR-V for code-generation does. + Therefore, bf16 is neither a valid data type to pass to gpu kernel, nor to + be used inside the kernel. To use bf16 data type in a SPIR-V kernel (as a + kernel parameter or inside the kernel), bf16 have to be bitcasted (similar + to C++ reinterpret_cast) to a supported type (e.g., i16 for Intel GPUs). The + SPIR-V kernel can then use the imitated type (i16) in the computation. + However, i16 is not the same as bf16 (integer vs float), so the computation + can not readily use the imitated type (i16). + + Therefore, this transformation pass is intended to be used in conjuction + with other transformation passes such as `EmulateUnsupportedFloats` and + `ExtendUnsupportedTypes` that extend the bitwidth of bf16 to f32 and + vice-versa. + + Finally, usually, there are instructions available in the target + (dialect/IR) that can take advantage of these generated patterns + (bf16->i16->f32, f32->bf16->i16), and convert them to the supported + types. + For example, Intel provides SPIR-V extension ops that can + take imitated bf16 (i16) and convert them to f32 and vice-versa. + https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV\_INTEL\_bfloat16\_conversion.asciidoc + https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop + https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op + + }]; + + let options = [ + ListOption<"sourceTypeStrs", "source-types", "std::string", + "MLIR types without type support on a given target">, + ListOption<"targetTypeStrs", "target-types", "std::string", + "MLIR types to convert the unsupported source types to">, + ]; + + let dependentDialects = [ + "::mlir::gpu::GPUDialect", + "::mlir::arith::ArithDialect", + "::mlir::memref::MemRefDialect" + ]; +} + + #endif // MLIR_DIALECT_GPU_PASSES diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp index 6b1074e454bd5..6f2e054a34620 100644 --- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp @@ -380,4 +380,29 @@ std::optional parseFloatType(MLIRContext *ctx, StringRef name) { .Default(std::nullopt); } +/// Map strings to Int types. +std::optional parseIntType(MLIRContext *ctx, StringRef name) { + Builder b(ctx); + return llvm::StringSwitchstd::optional\(name) + .Case("i1", b.getIntegerType(1)) + .Case("i2", b.getIntegerType(2)) + .Case("i4", b.getIntegerType(4)) + .Case("i6", b.getIntegerType(6)) + .Case("i8", b.getIntegerType(8)) + .Case("i16", b.getIntegerType(16)) + .Case("i32", b.getIntegerType(32)) + .Case("i64", b.getIntegerType(64)) + .Case("i80", b.getIntegerType(80)) + .Case("i128", b.getIntegerType(128)) + .Default(std::nullopt); +} +/// Map strings to Int or Float types. +std::optional parseIntOrFloatType(MLIRContext *ctx, StringRef name) { + if (auto floatTy = parseFloatType(ctx, name)) + return *floatTy; + if (auto intTy = parseIntType(ctx, name)) + return *intTy; + return std::nullopt; +} + } // namespace mlir::arith diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt index e21fa501bae6b..6d63f0d79e7d2 100644 --- a/mlir/lib/Dialect/GPU/CMakeLists.txt +++ b/mlir/lib/Dialect/GPU/CMakeLists.txt @@ -23,7 +23,7 @@ add_mlir_dialect_library(MLIRGPUDialect MLIRMemRefDialect MLIRSideEffectInterfaces MLIRSupport - ) +) add_mlir_dialect_library(MLIRGPUTransforms Transforms/AllReduceLowering.cpp @@ -42,6 +42,7 @@ add_mlir_dialect_library(MLIRGPUTransforms Transforms/SPIRVAttachTarget.cpp Transforms/SubgroupIdRewriter.cpp Transforms/SubgroupReduceLowering.cpp + Transforms/ImitateUnsupportedTypes.cpp OBJECT @@ -76,7 +77,7 @@ add_mlir_dialect_library(MLIRGPUTransforms MLIRROCDLTarget MLIRTransformUtils MLIRVectorDialect - ) +) add_subdirectory(TransformOps) add_subdirectory(Pipelines) diff --git a/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp b/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp new file mode 100644 index 0000000000000..8330214b873a2 --- /dev/null +++ b/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp @@ -0,0 +1,676 @@ +//===- ImitateUnsupportedTypes.cpp - Unsupported Type Imitation ----*- C++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// +/// \file +/// This pass imitates (bitcast/reinterpret_cast) unsupported types +/// with supported types of same bitwidth. The imitation is done +/// by bitcasting the unspported types to the supported types of same bitwidth. +/// Therefore, the source type and destination type must have the same bitwidth. +/// The imitation is done by using the following operations: arith.bitcast. +/// +/// The imitation is often needed when the GPU target (dialect/IR) does not +/// support a certain type but the underlying architecture does. Take SPIR-V for +/// example, it does not support bf16, but an underlying architecture (e.g., +/// intel pvc gpu) that uses SPIR-V for code-generation does. +/// Therefore, bf16 is neither a valid data type to pass to gpu kernel, nor to +/// be used inside the kernel. To use bf16 data type in a SPIR-V kernel (as a +/// kernel parameter or inside the kernel), bf16 have to be bitcasted (similar +/// to C++ reinterpret_cast) to a supported type (e.g., i16 for Intel GPUs). The +/// SPIR-V kernel can then use the imitated type (i16) in the computation. +/// However, i16 is not the same as bf16 (integer vs float), so the computation +/// can not readily use the imitated type (i16). +/// +/// Therefore, this transformation pass is intended to be used in conjuction +/// with other transformation passes such as `EmulateUnsupportedFloats` and +/// `ExtendUnsupportedTypes` that extend the bitwidth of bf16 to f32 and +/// vice-versa. +/// +/// Finally, usually, there are instructions available in the target +/// (dialect/IR) that can take advantage of these generated patterns +/// (bf16->i16->f32, f32->bf16->i16), and convert them to the supported +/// types. +/// For example, Intel provides SPIR-V extension ops that can +/// take imitated bf16 (i16) and convert them to f32 and vice-versa. +/// https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV\_INTEL\_bfloat16\_conversion.asciidoc +/// https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop +/// https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/GPU/Transforms/Passes.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/ErrorHandling.h" + +#include +#include +#include + +using namespace mlir; +using namespace mlir::gpu; + +namespace mlir { +#define GEN_PASS_DEF_GPUIMITATEUNSUPPORTEDTYPES +#include "mlir/Dialect/GPU/Transforms/Passes.h.inc" +} // namespace mlir + +//===----------------------------------------------------------------------===// +// Utility functions +//===----------------------------------------------------------------------===// + +APFloat bitcastAPIntToAPFloat(const APInt &intValue, + const llvm::fltSemantics &semantics) { + // Get the bit width of the APInt. + unsigned intBitWidth = intValue.getBitWidth(); + // Get the total bit size required for the APFloat based on the semantics. + unsigned floatBitWidth = APFloat::getSizeInBits(semantics); + // Ensure the bit widths match for a direct bitcast. + assert(intBitWidth == floatBitWidth && + "Bitwidth of APInt and APFloat must match for bitcast"); + + // Get the raw bit representation of the APInt as a byte vector. + auto intWords = intValue.getRawData(); + // Create an APFloat with the specified semantics and the raw integer bits. + APFloat floatValue(semantics, APInt(intBitWidth, *intWords)); + return floatValue; +} + +// Get FloatAttr from IntegerAttr. +FloatAttr getFloatAttrFromIntegerAttr(IntegerAttr intAttr, Type dstType, + ConversionPatternRewriter &rewriter) { + APInt intVal = intAttr.getValue(); + auto floatVal = bitcastAPIntToAPFloat( + intVal, cast(dstType).getFloatSemantics()); + return rewriter.getFloatAttr(dstType, floatVal); +} +// Get IntegerAttr from FloatAttr. +IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, + ConversionPatternRewriter &rewriter) { + APFloat floatVal = floatAttr.getValue(); + APInt intVal = floatVal.bitcastToAPInt(); + return rewriter.getIntegerAttr(dstType, intVal); +} + +//===----------------------------------------------------------------------===// +// Convertion patterns +//===----------------------------------------------------------------------===// +namespace { + +//===----------------------------------------------------------------------===// +// FunctionOp conversion pattern +//===----------------------------------------------------------------------===// +template +struct ConvertFuncOp final : public OpConversionPattern { + ConvertFuncOp(MLIRContext *context, TypeConverter &typeConverter, + ArrayRef sourceTypes, ArrayRef targetTypes, + DenseMap<stringattr, functiontype=""> &convertedFuncTypes) + : OpConversionPattern(context), + typeConverter(typeConverter), // Store the reference + sourceTypes(sourceTypes), targetTypes(targetTypes), + convertedFuncTypes(convertedFuncTypes) {} + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(FuncLikeOp op, typename FuncLikeOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Only handle functions a gpu.module + if (!op->template getParentOfType()) + return failure(); + FunctionType oldFuncType = op.getFunctionType(); + + // Convert function signature + TypeConverter::SignatureConversion signatureConverter( + oldFuncType.getNumInputs()); + for (const auto &argType : + llvm::enumerate(op.getFunctionType().getInputs())) { + auto convertedType = typeConverter.convertType(argType.value()); + if (!convertedType) + return failure(); + signatureConverter.addInputs(argType.index(), convertedType); + } + SmallVector<type, 4=""> newResultTypes; + for (const auto &resultType : llvm::enumerate(oldFuncType.getResults())) { + auto convertedType = typeConverter.convertType(resultType.value()); + if (!convertedType) + return failure(); + newResultTypes.push_back(convertedType); + } + + // Convert function signature + FunctionType newFuncType = rewriter.getFunctionType( + signatureConverter.getConvertedTypes(), newResultTypes); + + if (!newFuncType) + return rewriter.notifyMatchFailure(op, "could not convert function " + "type"); + + // Create new GPU function with converted type + auto newFuncOp = + rewriter.create(op.getLoc(), op.getName(), newFuncType); + + newFuncOp.setVisibility(op.getVisibility()); + // Copy attributes + for (auto attr : op->getAttrs()) { + // Skip the function_type attribute since it is already set by + // the newFuncType and we don't want to overwrite it. + if (attr.getName() != op.getFunctionTypeAttrName() && + attr.getName() != SymbolTable::getSymbolAttrName()) + newFuncOp->setAttr(attr.getName(), attr.getValue()); + } + + newFuncOp.getRegion().getBlocks().clear(); + // Inline region approach + rewriter.inlineRegionBefore(op.getBody(), newFuncOp.getBody(), + newFuncOp.end()); + // Convert block argument types using the type converter + if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, + &signatureConverter))) { + return rewriter.notifyMatchFailure(op, "could not convert region " + "types"); + } + + if (!op.use_empty()) { + op.emitError("Cannot erase func: still has uses"); + } + for (Operation *user : op->getUsers()) { + user->emitRemark() << "User of function " << op.getName(); + } + rewriter.eraseOp(op); + // Add the converted function type to the map + newFuncOp.getNameAttr().getValue(); + convertedFuncTypes[newFuncOp.getNameAttr()] = newFuncType; + return success(); + } + +private: + TypeConverter &typeConverter; // Store a reference + ArrayRef sourceTypes; + ArrayRef targetTypes; + DenseMap<stringattr, functiontype=""> &convertedFuncTypes; +}; + +//===----------------------------------------------------------------------===// +// CallOp conversion pattern +//===----------------------------------------------------------------------===// +struct ConvertCallOp : OpConversionPattern { + ConvertCallOp(MLIRContext *context, TypeConverter &typeConverter, + const DenseMap<stringattr, functiontype=""> &convertedFuncTypes) + : OpConversionPattern(context), convertedFuncTypes(convertedFuncTypes) {} + + LogicalResult + matchAndRewrite(func::CallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto callee = op.getCalleeAttr(); + + auto it = convertedFuncTypes.find( + StringAttr::get(callee.getContext(), callee.getValue())); + if (it == convertedFuncTypes.end()) + return rewriter.notifyMatchFailure( + op, "Callee signature not converted. Perhaps the callee is not in " + "the same gpu module as the caller."); + + auto newResultTypes = it->second.getResults(); + rewriter.replaceOpWithNewOp( + op, callee.getValue(), newResultTypes, adaptor.getOperands()); + + return success(); + } + +private: + const DenseMap<stringattr, functiontype=""> &convertedFuncTypes; +}; + +//===----------------------------------------------------------------------===// +// GPULaunchFuncOp conversion pattern +//===----------------------------------------------------------------------===// +struct ConvertGPULaunchFuncOp : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(gpu::LaunchFuncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + std::optional clusterSizeOpernads = + op.hasClusterSize() + ? std::optional(op.getClusterSizeOperandValues()) + : std::nullopt; + + // Create the new launch_func. + auto newOp = rewriter.create( + op.getLoc(), adaptor.getKernel(), op.getGridSizeOperandValues(), + op.getBlockSizeOperandValues(), op.getDynamicSharedMemorySize(), + adaptor.getKernelOperands(), op.getAsyncObject(), clusterSizeOpernads); + + // Copy block size and grid size attributes + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ArithConstantOp conversion pattern +//===----------------------------------------------------------------------===// +struct ConvertArithConstantOp : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + ConvertArithConstantOp(MLIRContext *context, TypeConverter &typeConverter, + ArrayRef sourceTypes, ArrayRef targetTypes) + : OpConversionPattern(context), + typeConverter(typeConverter), // Store the reference. + sourceTypes(sourceTypes), targetTypes(targetTypes) {} + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type srcType = op.getType(); + Type dstType = typeConverter.convertType(srcType); + if (!dstType || dstType == srcType) + return failure(); + + Attribute value = op.getValue(); + Value newConstOp = nullptr; + + // When source is IntegerAttr. + if (auto intAttr = dyn_cast(value)) { + APInt intVal = intAttr.getValue(); + if (isa(dstType)) { + auto newAttr = getFloatAttrFromIntegerAttr(intAttr, dstType, rewriter); + newConstOp = + rewriter.create(op.getLoc(), dstType, newAttr); + } else if (isa(dstType)) { + auto newAttr = rewriter.getIntegerAttr(dstType, intVal); + newConstOp = + rewriter.create(op.getLoc(), dstType, newAttr); + } else { + return rewriter.notifyMatchFailure( + op, "expected integer or float target type for constant"); + } + } + + // When source is FloatAttr. + else if (auto floatAttr = dyn_cast(value)) { + if (llvm::isa(dstType)) { + auto newAttr = + getIntegerAttrFromFloatAttr(floatAttr, dstType, rewriter); + newConstOp = + rewriter.create(op.getLoc(), dstType, newAttr); + } else if (llvm::isa(dstType)) { + auto newAttr = rewriter.getFloatAttr(dstType, floatAttr.getValue()); + newConstOp = + rewriter.create(op.getLoc(), dstType, newAttr); + } else { + return rewriter.notifyMatchFailure( + op, "expected integer or float target type for constant"); + } + } + // Handle DenseElementsAttr. + else if (auto denseAttr = dyn_cast(value)) { + Type newEltType; + if (auto shapedType = dyn_cast(dstType)) + newEltType = shapedType.getElementType(); + else + return rewriter.notifyMatchFailure( + op, "expected shaped type for dense constant"); + + SmallVector newValues; + for (Attribute attr : denseAttr.getValues()) { + if (auto intAttr = dyn_cast(attr)) { + if (llvm::isa(newEltType)) { + auto newAttr = + getFloatAttrFromIntegerAttr(intAttr, newEltType, rewriter); + newValues.push_back(newAttr); + } else if (llvm::isa(newEltType)) { + newValues.push_back( + rewriter.getIntegerAttr(newEltType, intAttr.getValue())); + } else { + return rewriter.notifyMatchFailure( + op, "unsupported target element type in dense constant"); + } + } else if (auto floatAttr = dyn_cast(attr)) { + if (llvm::isa(newEltType)) { + auto newAttr = + getIntegerAttrFromFloatAttr(floatAttr, newEltType, rewriter); + newValues.push_back(newAttr); + } else if (llvm::isa(newEltType)) + newValues.push_back( + rewriter.getFloatAttr(newEltType, floatAttr.getValue())); + else + return rewriter.notifyMatchFailure( + op, "unsupported target element type in dense constant"); + } else { + return rewriter.notifyMatchFailure( + op, "unsupported target element type in dense constant"); + } + } + + auto newAttr = + DenseElementsAttr::get(cast(dstType), newValues); + newConstOp = + rewriter.create(op.getLoc(), dstType, newAttr); + } + if (!newConstOp) + return rewriter.notifyMatchFailure( + op, "unsupported constant type for source to target conversion"); + + auto bitcastOp = + rewriter.create(op.getLoc(), srcType, newConstOp); + rewriter.replaceOp(op, bitcastOp.getResult()); + return success(); + } + +private: + TypeConverter &typeConverter; // Store a reference. + ArrayRef sourceTypes; + ArrayRef targetTypes; +}; + +//===----------------------------------------------------------------------===// +// GenericOp conversion pattern +//===----------------------------------------------------------------------===// +struct ConvertOpWithSourceType final : ConversionPattern { + ConvertOpWithSourceType(MLIRContext *context, + const TypeConverter &typeConverter, + ArrayRef sourceTypes, + ArrayRef targetTypes) + : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 1, context), + sourceTypes(sourceTypes), targetTypes(targetTypes) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + SmallVector<type, 4=""> newResultTypes; + for (Type t : op->getResultTypes()) { + Type converted = typeConverter->convertType(t); + if (!converted) + return failure(); + newResultTypes.push_back(converted); + } + + // Clone the op manually with the converted result types + OperationState state(op->getLoc(), op->getName().getStringRef()); + state.addOperands(operands); + state.addTypes(newResultTypes); + state.addAttributes(op->getAttrs()); + + for ([[maybe_unused]] auto ®ion : op->getRegions()) + state.regions.emplace_back(); + + Operation *newOp = rewriter.create(state); + // Transfer regions and convert them + for (auto [oldRegion, newRegion] : + llvm::zip(op->getRegions(), newOp->getRegions())) { + if (!oldRegion.empty()) { + newRegion.takeBody(oldRegion); + if (failed(rewriter.convertRegionTypes(&newRegion, *typeConverter))) { + return rewriter.notifyMatchFailure(op, + "region type conversion failed"); + } + } + } + + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } + +private: + ArrayRef sourceTypes; + ArrayRef targetTypes; +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Type Converter +//===----------------------------------------------------------------------===// + +void mlir::populateImitateUnsupportedTypesTypeConverter( + TypeConverter &typeConverter, ArrayRef sourceTypes, + ArrayRef targetTypes) { + auto srcTypes = SmallVector(sourceTypes); + auto tgtTypes = SmallVector(targetTypes); + + assert(sourceTypes.size() == targetTypes.size() && + "Source and target types must have same size"); + + typeConverter.addConversion([srcTypes, tgtTypes](Type type) -> Type { + if (type.isIntOrIndexOrFloat()) { + for (auto [src, tgt] : llvm::zip_equal(srcTypes, tgtTypes)) { + if (type == src) + return tgt; + } + } else if (auto memref = llvm::dyn_cast(type)) { + Type elemType = memref.getElementType(); + for (auto [src, tgt] : llvm::zip_equal(srcTypes, tgtTypes)) { + if (elemType == src) + return MemRefType::get(memref.getShape(), tgt, memref.getLayout(), + memref.getMemorySpace()); + } + } else if (auto vec = llvm::dyn_cast(type)) { + Type elemType = vec.getElementType(); + for (auto [src, tgt] : llvm::zip_equal(srcTypes, tgtTypes)) { + if (elemType == src) + return VectorType::get(vec.getShape(), tgt); + } + } + return type; + }); + + auto materializeCast = [](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) -> Value { + assert(inputs.size() == 1 && "Expected single input"); + Type inputType = inputs[0].getType(); + if ((resultType.isIntOrIndexOrFloat() || isa(resultType) || + isa(resultType)) && + (inputType.isIntOrIndexOrFloat() || isa(inputType) || + isa(inputType))) { + return builder.create(loc, resultType, inputs[0]) + .getResult(); + } + return nullptr; + }; + + typeConverter.addSourceMaterialization(materializeCast); + typeConverter.addTargetMaterialization(materializeCast); +} + +//===----------------------------------------------------------------------===// +// Pattern Population +//===----------------------------------------------------------------------===// + +void mlir::populateImitateUnsupportedTypesConversionPatterns( + RewritePatternSet &patterns, TypeConverter &typeConverter, + ArrayRef sourceTypes, ArrayRef targetTypes, + DenseMap<stringattr, functiontype=""> &convertedFuncTypes) { + auto ctx = patterns.getContext(); + auto srcTypes = SmallVector(sourceTypes); + auto tgtTypes = SmallVector(targetTypes); + assert(srcTypes.size() == tgtTypes.size() && + "Source and target types must have same size"); + + patterns.add(ctx, typeConverter, srcTypes, tgtTypes); + patterns.add<convertfuncop, ConvertFuncOp>( + ctx, typeConverter, srcTypes, tgtTypes, convertedFuncTypes); + patterns.add(ctx, typeConverter, convertedFuncTypes); + patterns.add(ctx, typeConverter, srcTypes, tgtTypes); + patterns.add(ctx); +} + +//===----------------------------------------------------------------------===// +// Conversion Legality configuration +//===----------------------------------------------------------------------===// + +void mlir::configureImitateUnsupportedTypesLegality( + ConversionTarget &target, TypeConverter &typeConverter) { + target.addLegalDialect(); + target.addLegalDialect(); + // Make Memref, func dialect legal for all ops in host code + target.addDynamicallyLegalDialect([&](Operation *op) { + if (op->getParentOfType()) + return typeConverter.isLegal(op); + else + return true; + }); + + target.addDynamicallyLegalDialect([&](Operation *op) { + if (op->getParentOfType()) + return typeConverter.isLegal(op); + return true; + }); + + target.addDynamicallyLegalDialect([&](Operation *op) { + if (op->getParentOfType()) + return typeConverter.isLegal(op); + else + return true; + }); + + target.addLegalOp(); + // Manually mark arithmetic-performing vector instructions. + target.addLegalOp<vector::contractionop, vector::reductionop,="" +="" vector::multidimreductionop,="" vector::fmaop,="" vector::outerproductop,="" vector::matmulop,="" vector::scanop,="" vector::splatop="">(); + target.addDynamicallyLegalOp([&](arith::ConstantOp op) { + return typeConverter.isLegal(op.getType()); + }); + target.addDynamicallyLegalOp([&](gpu::GPUFuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()); + }); + target.addDynamicallyLegalOp( + [&](gpu::LaunchFuncOp op) { return typeConverter.isLegal(op); }); + // Only convert functions and function calls in gpu.module + target.addDynamicallyLegalOp([&](func::FuncOp op) { + if (op->getParentOfType()) + return typeConverter.isSignatureLegal(op.getFunctionType()); + return true; + }); + target.addDynamicallyLegalOp([&](func::CallOp op) { + if (op->getParentOfType()) + return typeConverter.isSignatureLegal(op.getCalleeType()); + return true; + }); + + // Mark unknown ops that are inside gpu.module, and one of its's operand is + // a memref type as dynamically legal. + target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool { + // Check if the operation is inside a gpu.module. + if (op->getParentOfType()) { + return typeConverter.isLegal(op); + } + return true; // If not in gpu.module, mark it as legal. + }); +} + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +namespace { + +struct GpuImitateUnsupportedTypesPass + : public impl::GpuImitateUnsupportedTypesBase< + GpuImitateUnsupportedTypesPass> { + using Base::Base; + + SmallVector sourceTypes; + SmallVector targetTypes; + TypeConverter typeConverter; + + LogicalResult initialize(MLIRContext *ctx) override { + // Parse source types. + for (StringRef sourceTypeStr : sourceTypeStrs) { + std::optional maybeSourceType = + arith::parseIntOrFloatType(ctx, sourceTypeStr); + + if (!maybeSourceType) { + emitError(UnknownLoc::get(ctx), + "could not map source type '" + sourceTypeStr + + "' to a known integer or floating-point type."); + return failure(); + } + sourceTypes.push_back(*maybeSourceType); + } + if (sourceTypes.empty()) { + (void)emitOptionalWarning(std::nullopt, "no source types " + "specified, type " + "imitation will do " + "nothing"); + } + + // Parse target types. + for (StringRef targetTypeStr : targetTypeStrs) { + std::optional maybeTargetType = + arith::parseIntOrFloatType(ctx, targetTypeStr); + + if (!maybeTargetType) { + emitError(UnknownLoc::get(ctx), + "could not map target type '" + targetTypeStr + + "' to a known integer or floating-point type"); + return failure(); + } + targetTypes.push_back(*maybeTargetType); + + if (llvm::is_contained(sourceTypes, *maybeTargetType)) { + emitError(UnknownLoc::get(ctx), + "target type cannot be an unsupported source type"); + return failure(); + } + } + if (targetTypes.empty()) { + (void)emitOptionalWarning( + std::nullopt, + "no target types specified, type imitation will do nothing"); + } + + if (sourceTypes.size() != targetTypes.size()) { + emitError(UnknownLoc::get(ctx), + "source and target types must have the same size"); + return failure(); + } + // Set up the type converter. + populateImitateUnsupportedTypesTypeConverter(typeConverter, sourceTypes, + targetTypes); + return success(); + } + + void runOnOperation() override { + MLIRContext *ctx = &getContext(); + Operation *op = getOperation(); + + // Populate the conversion patterns. + RewritePatternSet patterns(ctx); + DenseMap<stringattr, functiontype=""> convertedFuncTypes; + populateImitateUnsupportedTypesConversionPatterns( + patterns, typeConverter, sourceTypes, targetTypes, convertedFuncTypes); + + // Set up conversion target and configure the legality of the conversion. + ConversionTarget target(*ctx); + configureImitateUnsupportedTypesLegality(target, typeConverter); + + // Apply the conversion. + if (failed(applyPartialConversion(op, target, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace diff --git a/mlir/test/Dialect/GPU/imitate-unsupported-types.mlir b/mlir/test/Dialect/GPU/imitate-unsupported-types.mlir new file mode 100644 index 0000000000000..db4d692241023 --- /dev/null +++ b/mlir/test/Dialect/GPU/imitate-unsupported-types.mlir @@ -0,0 +1,206 @@ +// RUN: mlir-opt -verify-diagnostics -imitate-unsupported-types="source-types=bf16 target-types=i16" --canonicalize -split-input-file %s | FileCheck %s + +// CHECK: module @builtin_module +module @builtin_module { + // CHECK: gpu.module @gpu_func_module + gpu.module @gpu_func_module { + // CHECK: gpu.func @arith_and_vector_ops + // CHECK-SAME: (%[[ARG0:.*]]: memref<10x10xi16>, %[[ARG1:.*]]: memref<10x10xf32>, %[[ARG2:.*]]: vector<10x10xi16>, %[[ARG3:.*]]: memref<10x10xi16>, %[[ARG4:.*]]: vector<10x10xi16>) kernel + gpu.func @arith_and_vector_ops( + %arg0: memref<10x10xbf16>, + %arg1: memref<10x10xf32>, + %arg2: vector<10x10xbf16>, + %arg3: memref<10x10xi16>, + %arg4: vector<10x10xi16> + ) kernel { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + + // CHECK: %[[CAST_ARG2:.*]] = arith.bitcast %[[ARG2]] : vector<10x10xi16> to vector<10x10xbf16> + // CHECK: %[[LOAD_ARG0:.*]] = vector.load %[[ARG0]][%[[C0]], %[[C0]]] : memref<10x10xi16>, vector<10x10xi16> + // CHECK: %[[CAST_LOAD:.*]] = arith.bitcast %[[LOAD_ARG0]] : vector<10x10xi16> to vector<10x10xbf16> + %0 = vector.load %arg0[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16> + + // CHECK: %[[ADDF:.*]] = arith.addf %[[CAST_LOAD]], %[[CAST_ARG2]] : vector<10x10xbf16> + %1 = arith.addf %0, %arg2 : vector<10x10xbf16> + + // CHECK: %[[EXT0:.*]] = arith.extf %[[CAST_LOAD]] : vector<10x10xbf16> to vector<10x10xf32> + %2 = arith.extf %0 : vector<10x10xbf16> to vector<10x10xf32> + + // CHECK: %[[EXT1:.*]] = arith.extf %[[ADDF]] : vector<10x10xbf16> to vector<10x10xf32> + %3 = arith.extf %1 : vector<10x10xbf16> to vector<10x10xf32> + + // CHECK: %[[FADD:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<10x10xf32> + %4 = arith.addf %2, %3 : vector<10x10xf32> + + // CHECK: %[[TRUNC:.*]] = arith.truncf %[[FADD]] : vector<10x10xf32> to vector<10x10xbf16> + %5 = arith.truncf %4 : vector<10x10xf32> to vector<10x10xbf16> + + // CHECK: %[[CAST_TRUNC:.*]] = arith.bitcast %[[TRUNC]] : vector<10x10xbf16> to vector<10x10xi16> + // CHECK: vector.store %[[CAST_TRUNC]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<10x10xi16>, vector<10x10xi16> + vector.store %5, %arg0[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16> + + // CHECK: %[[LOAD2:.*]] = vector.load %[[ARG3]][%[[C0]], %[[C0]]] : memref<10x10xi16>, vector<10x10xi16> + %6 = vector.load %arg3[%c0, %c0] : memref<10x10xi16>, vector<10x10xi16> + + // CHECK: %[[ADDI:.*]] = arith.addi %[[LOAD2]], %[[ARG4]] : vector<10x10xi16> + %7 = arith.addi %6, %arg4 : vector<10x10xi16> + + // CHECK: vector.store %[[ADDI]], %[[ARG3]][%[[C0]], %[[C0]]] : memref<10x10xi16>, vector<10x10xi16> + vector.store %7, %arg3[%c0, %c0] : memref<10x10xi16>, vector<10x10xi16> + + // CHECK: gpu.return + gpu.return + } + } +} + +// ----- + + +// CHECK: module @caller_callee_launch_func_module attributes {gpu.container_module} +module @caller_callee_launch_func_module attributes {gpu.container_module} { + // CHECK: gpu.module @caller_callee_gpu_module { + gpu.module @caller_callee_gpu_module attributes{} { + // CHECK: gpu.func @caller_func + // CHECK-SAME: (%[[ARG0:.*]]: memref<10x10xi16>, %[[ARG1:.*]]: vector<10x10xi16>) kernel + gpu.func @caller_func(%arg0: memref<10x10xbf16>, %arg1: vector<10x10xbf16>) kernel attributes {} { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + + // CHECK: %[[RET:.*]] = func.call @callee_constant_return() : () -> vector<10x10xi16> + %func_result = func.call @callee_constant_return() : () -> vector<10x10xbf16> + + // CHECK: vector.store %[[RET]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<10x10xi16>, vector<10x10xi16> + vector.store %func_result, %arg0[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16> + + // CHECK: func.call @callee_func(%[[RET]]) : (vector<10x10xi16>) -> () + func.call @callee_func(%func_result) : (vector<10x10xbf16>) -> () + + // CHECK: gpu.return + gpu.return + } + + // CHECK: func.func @callee_constant_return() -> vector<10x10xi16> { + func.func @callee_constant_return() -> vector<10x10xbf16> { + // CHECK: %[[CST:.*]] = arith.constant dense<16128> : vector<10x10xi16> + %dense_const = arith.constant dense<5.000000e-01> : vector<10x10xbf16> + // CHECK: return %[[CST]] : vector<10x10xi16> + func.return %dense_const : vector<10x10xbf16> + } + + // CHECK: func.func @callee_func(%[[ARG:.*]]: vector<10x10xi16>) { + func.func @callee_func(%arg0: vector<10x10xbf16>) { + return + } + } + + // CHECK: func.func @gpu_launch_func( + func.func @gpu_launch_func(%arg0: memref<10x10xbf16>, %arg1: vector<10x10xbf16>) { + + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[C1:.*]] = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + // Handling bf16 constants, dealing with constants for both cases: + // - not used in gpu.launch_func (no conversion) + // - used in gpu.launch_func (needs conversion to i16) + + // CHECK: %[[BF16_CONST:.*]] = arith.constant dense<5.000000e-01> : vector<10x10xbf16> + // CHECK: %[[I16_CONST:.*]] = arith.constant dense<16128> : vector<10x10xi16> + %dense_const = arith.constant dense<5.000000e-01> : vector<10x10xbf16> + + // CHECK: %[[BF16_CONST_2:.*]] = arith.constant dense<1.500000e+00> : vector<10x10xbf16> + %dense_const_2 = arith.constant dense<1.500000e+00> : vector<10x10xbf16> + + // CHECK: %[[ADDF:.*]] = arith.addf %arg1, %[[BF16_CONST]] : vector<10x10xbf16> + %add = arith.addf %dense_const, %arg1 : vector<10x10xbf16> + + // CHECK: vector.store %[[ADDF]], %arg0[%[[C0]], %[[C0]]] : memref<10x10xbf16>, vector<10x10xbf16> + vector.store %add, %arg0[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16> + + // CHECK: %[[ALLOC:.*]] = gpu.alloc () : memref<10x10xbf16> + %alloc = gpu.alloc () : memref<10x10xbf16> + // CHECK: %[[BITCAST:.*]] = arith.bitcast %[[ALLOC]] : memref<10x10xbf16> to memref<10x10xi16> + // CHECK: vector.store %[[BF16_CONST_2]], %[[ALLOC]][%[[C0]], %[[C0]]] : memref<10x10xbf16>, vector<10x10xbf16> + vector.store %dense_const_2, %alloc[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16> + + + // CHECK: gpu.launch_func @caller_callee_gpu_module::@caller_func + // CHECK-SAME: args(%[[BITCAST]] : memref<10x10xi16>, %[[I16_CONST]] : vector<10x10xi16>) + gpu.launch_func @caller_callee_gpu_module::@caller_func + blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) + args(%alloc: memref<10x10xbf16>, %dense_const: vector<10x10xbf16>) + return + } +} + +// ----- + + +// CHECK: #map = affine_map<(d0, d1) -> (d1, d0)> +#map = affine_map<(d0, d1) -> (d1, d0)> + +// CHECK-LABEL: module @module_multi_level_call attributes {gpu.container_module} { +module @module_multi_level_call attributes {gpu.container_module} { + // CHECK: gpu.module @gpu_module_multi_level_call { + gpu.module @gpu_module_multi_level_call { + // CHECK: gpu.func @kernel(%[[K_ARG:.*]]: memref<10x10xi16>) kernel { + gpu.func @kernel(%arg0: memref<10x10xi16>) kernel { + // CHECK: gpu.return + gpu.return + } + + // CHECK: gpu.func @affine_memref_arg(%[[AFF_ARG:.*]]: memref<100x100xi16, #map, 2>) kernel { + gpu.func @affine_memref_arg(%arg0: memref<100x100xi16, #map, 2>) kernel { + // CHECK: gpu.return + gpu.return + } + } + + // CHECK-LABEL: func.func @gpu_launch_func + func.func @gpu_launch_func(%arg0: memref<10x10xbf16>, %arg1: memref<100x100xbf16, #map, 2>) { + // CHECK: %[[C1:.*]] = arith.constant 1 : index + %c1 = arith.constant 1 : index + + // CHECK: %[[AFF_CAST:.*]] = arith.bitcast %[[ARG1:.*]] : memref<100x100xbf16, #map, 2> to memref<100x100xi16, #map, 2> + %0 = arith.bitcast %arg1 : memref<100x100xbf16, #map, 2> to memref<100x100xi16, #map, 2> + + // CHECK: %[[BF16_CAST:.*]] = arith.bitcast %[[ARG0:.*]] : memref<10x10xbf16> to memref<10x10xi16> + %1 = arith.bitcast %arg0 : memref<10x10xbf16> to memref<10x10xi16> + + // CHECK: gpu.launch_func @gpu_module_multi_level_call::@kernel + // CHECK-SAME: args(%[[BF16_CAST]] : memref<10x10xi16>) + gpu.launch_func @gpu_module_multi_level_call::@kernel + blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) + args(%1 : memref<10x10xi16>) + + // CHECK: gpu.launch_func @gpu_module_multi_level_call::@affine_memref_arg + // CHECK-SAME: args(%[[AFF_CAST]] : memref<100x100xi16, #map, 2>) + gpu.launch_func @gpu_module_multi_level_call::@affine_memref_arg + blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) + args(%0 : memref<100x100xi16, #map, 2>) + // CHECK: return + return + } + + // CHECK-LABEL: func.func @main + func.func @main() { + // CHECK: %[[ALLOC0:.*]] = memref.alloc() : memref<10x10xbf16> + %alloc = memref.alloc() : memref<10x10xbf16> + // CHECK: %[[ALLOC1:.*]] = memref.alloc() : memref<100x100xbf16, #map, 2> + %alloc_0 = memref.alloc() : memref<100x100xbf16, #map, 2> + // CHECK: call @gpu_launch_func(%[[ALLOC0]], %[[ALLOC1]]) + call @gpu_launch_func(%alloc, %alloc_0) : (memref<10x10xbf16>, memref<100x100xbf16, #map, 2>) -> () + // CHECK: memref.dealloc %[[ALLOC0]] + memref.dealloc %alloc : memref<10x10xbf16> + // CHECK: memref.dealloc %[[ALLOC1]] + memref.dealloc %alloc_0 : memref<100x100xbf16, #map, 2> + // CHECK: return + return + } +} + + + </stringattr,></vector::contractionop,></convertfuncop</stringattr,></type,></stringattr,></stringattr,></stringattr,></type,></stringattr,></std::optional</stringattr,>