(original) (raw)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h index ce97847172197..91f77307ddf8b 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -392,18 +392,29 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns, void populateVectorTransposeNarrowTypeRewritePatterns( RewritePatternSet &patterns, PatternBenefit benefit = 1); -/// Populates patterns for ND vectors (N >= 2) linearization and sets up the -/// provided ConversionTarget with the appropriate legality configuration for -/// the ops to get converted properly. -void populateVectorLinearizeTypeConversionsAndLegality( - TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, unsigned targetBitWidth); - -/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D -/// vector shuffle operations. -void populateVectorLinearizeShuffleLikeOpsPatterns( - const TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, unsigned targetBitWidth); +/// Initialize `typeConverter` and `conversionTarget` for vector linearization. +/// This registers (1) which operations are legal and hence should not be +/// linearized, (2) what converted types are (rank-1 vectors) and how to +/// materialize the conversion (with shape_cast). +/// +/// Note: the set of legal operations can be extended by a user if for example +/// certain rank>1 vectors are considered valid, by adding additional +/// dynamically legal ops to `conversionTarget`. +void populateForVectorLinearize(TypeConverter &typeConverter, + ConversionTarget &conversionTarget); + +/// Populates `patterns` for ND vector (N >= 2) linearization. This currently +/// contains patterns for converting ConstantLike, Vectorizable, and +/// vector::BitCast ops. 
+void populateVectorLinearizeBasePatterns(const TypeConverter &, + const ConversionTarget &, + RewritePatternSet &patterns); + +/// Populates `patterns` for linearizing ND (N >= 2) vector operations +/// to 1D vector shuffle operations. +void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &, + const ConversionTarget &, + RewritePatternSet &patterns); } // namespace vector } // namespace mlir diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index a009aa03aaf64..67e15852dc5ea 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -10,7 +10,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" @@ -22,44 +21,16 @@ #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/ArrayRef.h" #include +#include #include +#include using namespace mlir; -static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) { - auto resultTypes = op->getResultTypes(); - for (auto resType : resultTypes) { - VectorType vecType = dyn_cast(resType); - // Reject index since getElementTypeBitWidth will abort for Index types. - if (!vecType || vecType.getElementType().isIndex()) - return false; - // There are no dimension to fold if it is a 0-D vector. - if (vecType.getRank() == 0) - return false; - unsigned trailingVecDimBitWidth = - vecType.getShape().back() * vecType.getElementTypeBitWidth(); - if (trailingVecDimBitWidth >= targetBitWidth) - return false; - } - return true; -} - -static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) { - VectorType vecType = dyn_cast(t); - // Reject index since getElementTypeBitWidth will abort for Index types. 
- if (!vecType || vecType.getElementType().isIndex()) - return false; - // There are no dimension to fold if it is a 0-D vector. - if (vecType.getRank() == 0) - return false; - unsigned trailingVecDimBitWidth = - vecType.getShape().back() * vecType.getElementTypeBitWidth(); - return trailingVecDimBitWidth <= targetBitWidth; -} - static FailureOr linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter, VectorType resType, Attribute value) { + if (auto dstElementsAttr = dyn_cast(value)) { if (resType.isScalable() && !isa(value)) return rewriter.notifyMatchFailure( @@ -76,16 +47,14 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter, } namespace { + struct LinearizeConstantLike final : OpTraitConversionPattern { using OpTraitConversionPattern::OpTraitConversionPattern; - LinearizeConstantLike( - const TypeConverter &typeConverter, MLIRContext *context, - unsigned targetVectBitWidth = std::numeric_limits::max(), - PatternBenefit benefit = 1) - : OpTraitConversionPattern(typeConverter, context, benefit), - targetVectorBitWidth(targetVectBitWidth) {} + LinearizeConstantLike(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) + : OpTraitConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { @@ -93,16 +62,10 @@ struct LinearizeConstantLike final if (op->getNumResults() != 1) return rewriter.notifyMatchFailure(loc, "expected 1 result"); - const TypeConverter &converter = *getTypeConverter(); + const TypeConverter &typeConverter = *getTypeConverter(); auto resType = - converter.convertType(op->getResult(0).getType()); - - if (!resType) - return rewriter.notifyMatchFailure(loc, "can't convert return type"); - - if (!isLessThanTargetBitWidth(op, targetVectorBitWidth)) - return rewriter.notifyMatchFailure( - loc, "Can't flatten since targetBitWidth <= OpSize"); + 
typeConverter.convertType(op->getResult(0).getType()); + assert(resType && "expected 1-D vector type"); StringAttr attrName = rewriter.getStringAttr("value"); Attribute value = op->getAttr(attrName); @@ -115,7 +78,7 @@ struct LinearizeConstantLike final return failure(); FailureOr convertResult = - convertOpResultTypes(op, /*operands=*/{}, converter, rewriter); + convertOpResultTypes(op, /*operands=*/{}, typeConverter, rewriter); if (failed(convertResult)) return failure(); @@ -124,9 +87,6 @@ struct LinearizeConstantLike final rewriter.replaceOp(op, newOp); return success(); } - -private: - unsigned targetVectorBitWidth; }; struct LinearizeVectorizable final @@ -134,18 +94,12 @@ struct LinearizeVectorizable final using OpTraitConversionPattern::OpTraitConversionPattern; public: - LinearizeVectorizable( - const TypeConverter &typeConverter, MLIRContext *context, - unsigned targetVectBitWidth = std::numeric_limits::max(), - PatternBenefit benefit = 1) - : OpTraitConversionPattern(typeConverter, context, benefit), - targetVectorBitWidth(targetVectBitWidth) {} + LinearizeVectorizable(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) + : OpTraitConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - if (!isLessThanTargetBitWidth(op, targetVectorBitWidth)) - return rewriter.notifyMatchFailure( - op->getLoc(), "Can't flatten since targetBitWidth <= OpSize"); FailureOr newOp = convertOpResultTypes(op, operands, *getTypeConverter(), rewriter); if (failed(newOp)) @@ -154,9 +108,6 @@ struct LinearizeVectorizable final rewriter.replaceOp(op, (*newOp)->getResults()); return success(); } - -private: - unsigned targetVectorBitWidth; }; /// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works @@ -173,12 +124,10 @@ struct LinearizeVectorizable final struct LinearizeVectorExtractStridedSlice final : public 
mlir::OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LinearizeVectorExtractStridedSlice( - const TypeConverter &typeConverter, MLIRContext *context, - unsigned targetVectBitWidth = std::numeric_limits::max(), - PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit), - targetVectorBitWidth(targetVectBitWidth) {} + LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter, + MLIRContext *context, + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, @@ -189,9 +138,6 @@ struct LinearizeVectorExtractStridedSlice final if (extractOp.getVector().getType().isScalable() || dstType.isScalable()) return rewriter.notifyMatchFailure(extractOp, "scalable vectors are not supported."); - if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth)) - return rewriter.notifyMatchFailure( - extractOp, "Can't flatten since targetBitWidth <= OpSize"); ArrayAttr offsets = extractOp.getOffsets(); ArrayAttr sizes = extractOp.getSizes(); @@ -268,9 +214,6 @@ struct LinearizeVectorExtractStridedSlice final extractOp, dstType, srcVector, srcVector, indices); return success(); } - -private: - unsigned targetVectorBitWidth; }; /// This pattern converts the ShuffleOp that works on nD (n > 1) @@ -291,8 +234,7 @@ struct LinearizeVectorShuffle final const TypeConverter &typeConverter, MLIRContext *context, unsigned targetVectBitWidth = std::numeric_limits::max(), PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit), - targetVectorBitWidth(targetVectBitWidth) {} + : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, @@ -300,15 +242,6 @@ struct LinearizeVectorShuffle final VectorType dstType = getTypeConverter()->convertType(shuffleOp.getType()); assert(dstType && "vector type destination 
expected."); - // The assert is used because vector.shuffle does not support scalable - // vectors. - assert(!(shuffleOp.getV1VectorType().isScalable() || - shuffleOp.getV2VectorType().isScalable() || - dstType.isScalable()) && - "scalable vectors are not supported."); - if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth)) - return rewriter.notifyMatchFailure( - shuffleOp, "Can't flatten since targetBitWidth <= OpSize"); Value vec1 = adaptor.getV1(); Value vec2 = adaptor.getV2(); @@ -327,7 +260,7 @@ struct LinearizeVectorShuffle final } // For each value in the mask, we generate the indices of the source vectors - // that needs to be shuffled to the destination vector. If shuffleSliceLen > + // that need to be shuffled to the destination vector. If shuffleSliceLen > // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of // elements) instead of scalars. ArrayRef mask = shuffleOp.getMask(); @@ -343,9 +276,6 @@ struct LinearizeVectorShuffle final vec2, indices); return success(); } - -private: - unsigned targetVectorBitWidth; }; /// This pattern converts the ExtractOp to a ShuffleOp that works on a @@ -364,23 +294,12 @@ struct LinearizeVectorExtract final const TypeConverter &typeConverter, MLIRContext *context, unsigned targetVectBitWidth = std::numeric_limits::max(), PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit), - targetVectorBitWidth(targetVectBitWidth) {} + : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type dstTy = getTypeConverter()->convertType(extractOp.getType()); - if (!dstTy) - return rewriter.notifyMatchFailure(extractOp, - "expected n-D vector type."); - - if (extractOp.getVector().getType().isScalable() || - cast(dstTy).isScalable()) - return rewriter.notifyMatchFailure(extractOp, - "scalable vectors are not supported."); - if 
(!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth)) - return rewriter.notifyMatchFailure( - extractOp, "Can't flatten since targetBitWidth <= OpSize"); + assert(dstTy && "expected 1-D vector type"); // Dynamic position is not supported. if (extractOp.hasDynamicPosition()) @@ -405,9 +324,6 @@ struct LinearizeVectorExtract final return success(); } - -private: - unsigned targetVectorBitWidth; }; /// This pattern converts the InsertOp to a ShuffleOp that works on a @@ -427,22 +343,13 @@ struct LinearizeVectorInsert final const TypeConverter &typeConverter, MLIRContext *context, unsigned targetVectBitWidth = std::numeric_limits::max(), PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit), - targetVectorBitWidth(targetVectBitWidth) {} + : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType dstTy = getTypeConverter()->convertType( insertOp.getDestVectorType()); assert(dstTy && "vector type destination expected."); - if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable()) - return rewriter.notifyMatchFailure(insertOp, - "scalable vectors are not supported."); - - if (!isLessThanOrEqualTargetBitWidth(insertOp.getValueToStoreType(), - targetVectorBitWidth)) - return rewriter.notifyMatchFailure( - insertOp, "Can't flatten since targetBitWidth < OpSize"); // dynamic position is not supported if (insertOp.hasDynamicPosition()) @@ -471,11 +378,11 @@ struct LinearizeVectorInsert final } llvm::SmallVector<int64_t, 2> indices(dstSize); - auto origValsUntil = indices.begin(); + auto *origValsUntil = indices.begin(); std::advance(origValsUntil, linearizedOffset); std::iota(indices.begin(), origValsUntil, 0); // original values that remain [0, offset) - auto newValsUntil = origValsUntil; + auto *newValsUntil = origValsUntil; std::advance(newValsUntil, srcSize); std::iota(origValsUntil, 
newValsUntil, dstSize); // new values [offset, offset+srcNumElements) @@ -488,9 +395,6 @@ struct LinearizeVectorInsert final return success(); } - -private: - unsigned targetVectorBitWidth; }; /// This pattern converts the BitCastOp that works on nD (n > 1) @@ -508,82 +412,111 @@ struct LinearizeVectorBitCast final const TypeConverter &typeConverter, MLIRContext *context, unsigned targetVectBitWidth = std::numeric_limits::max(), PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit), - targetVectorBitWidth(targetVectBitWidth) {} + : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Location loc = castOp.getLoc(); auto resType = getTypeConverter()->convertType(castOp.getType()); - if (!resType) - return rewriter.notifyMatchFailure(loc, "can't convert return type."); - - if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth)) - return rewriter.notifyMatchFailure( - loc, "Can't flatten since targetBitWidth <= OpSize"); - + assert(resType && "expected 1-D vector type"); rewriter.replaceOpWithNewOp(castOp, resType, adaptor.getSource()); return mlir::success(); } - -private: - unsigned targetVectorBitWidth; }; } // namespace -void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( - TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, unsigned targetBitWidth) { +/// Return true if the operation `op` does not support scalable vectors and +/// has at least 1 scalable vector result. These ops should all eventually +/// support scalable vectors, and this function should be removed. +static bool isNotLinearizableBecauseScalable(Operation *op) { + + bool unsupported = + isa<vector::ExtractStridedSliceOp, vector::ExtractOp, vector::InsertOp>( + op); + if (!unsupported) + return false; + + // Check if any of the results is a scalable vector type. 
+ auto types = op->getResultTypes(); + bool containsScalableResult = + std::any_of(types.begin(), types.end(), [](Type type) { + auto vecType = dyn_cast(type); + return vecType && vecType.isScalable(); + }); + + return containsScalableResult; +} + +static bool isNotLinearizable(Operation *op) { + + // Only ops that are in the vector dialect, are ConstantLike, or + // are Vectorizable might be linearized currently. + StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace(); + StringRef opDialect = op->getDialect()->getNamespace(); + bool unsupported = (opDialect != vectorDialect) && + !op->hasTrait() && + !op->hasTrait(); + if (unsupported) + return true; + + // Some ops currently don't support scalable vectors. + if (isNotLinearizableBecauseScalable(op)) + return true; + + return false; +} - typeConverter.addConversion([](VectorType type) -> std::optional { - if (!isLinearizableVector(type)) +void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter, + ConversionTarget &target) { + + auto convertType = [](Type type) -> std::optional { + VectorType vectorType = dyn_cast(type); + if (!vectorType || !isLinearizableVector(vectorType)) return type; - return VectorType::get(type.getNumElements(), type.getElementType(), - type.isScalable()); - }); + VectorType linearizedType = + VectorType::get(vectorType.getNumElements(), + vectorType.getElementType(), vectorType.isScalable()); + return linearizedType; + }; + typeConverter.addConversion(convertType); auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) -> Value { - if (inputs.size() != 1 || !isa(inputs.front().getType()) || - !isa(type)) + if (inputs.size() != 1) + return nullptr; + + Value value = inputs.front(); + if (!isa(type) || !isa(value.getType())) return nullptr; - return builder.create(loc, type, inputs.front()); + return builder.create(loc, type, value); }; typeConverter.addSourceMaterialization(materializeCast); 
typeConverter.addTargetMaterialization(materializeCast); + target.markUnknownOpDynamicallyLegal( [=](Operation *op) -> std::optional { - if ((isa(op) || - op->hasTrait() || - op->hasTrait())) { - return (isLessThanTargetBitWidth(op, targetBitWidth) - ? typeConverter.isLegal(op) - : true); - } - return std::nullopt; + if (isNotLinearizable(op)) + return true; + // This will return true if, for all operand and result types `t`, + // convertType(t) = t. This is true if there are no rank>=2 vectors. + return typeConverter.isLegal(op); }); +} +void mlir::vector::populateVectorLinearizeBasePatterns( + const TypeConverter &typeConverter, const ConversionTarget &target, + RewritePatternSet &patterns) { patterns.add<linearizeconstantlike, linearizevectorizable,="" -="" linearizevectorbitcast="">(typeConverter, patterns.getContext(), - targetBitWidth); + LinearizeVectorBitCast>(typeConverter, patterns.getContext()); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( - const TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, unsigned int targetBitWidth) { - target.addDynamicallyLegalOp( - [=](vector::ShuffleOp shuffleOp) -> bool { - return isLessThanTargetBitWidth(shuffleOp, targetBitWidth) - ? 
(typeConverter.isLegal(shuffleOp) && - cast(shuffleOp.getResult().getType()) - .getRank() == 1) - : true; - }); + const TypeConverter &typeConverter, const ConversionTarget &target, + RewritePatternSet &patterns) { patterns.add<linearizevectorshuffle, linearizevectorextract,="" linearizevectorinsert,="" linearizevectorextractstridedslice="">( - typeConverter, patterns.getContext(), targetBitWidth); + typeConverter, patterns.getContext()); } diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 9052c6440e6ac..06eaf58b225ae 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -1,6 +1,7 @@ // RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s --check-prefixes=ALL,DEFAULT -// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128 -// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0 + +// RUN: mlir-opt %s -split-input-file -test-bit-width-constrained-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128 +// RUN: mlir-opt %s -split-input-file -test-bit-width-constrained-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0 // ALL-LABEL: test_linearize // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>) @@ -97,7 +98,7 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32> // ALL-LABEL: test_index_no_linearize func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> { - // ALL: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex> + // BW-128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex> %0 = arith.addi %arg0, %arg1 : vector<2x2xindex> return %0 : vector<2x2xindex> } @@ -171,6 +172,7 @@ func.func 
@test_0d_vector() -> vector { } // ----- + // ALL-LABEL: test_extract_strided_slice_1 // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<4x8xf32>) -> vector<2x2xf32> { func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> { @@ -193,6 +195,8 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf return %0 : vector<2x2xf32> } +// ----- + // ALL-LABEL: func.func @test_extract_strided_slice_1_scalable( // ALL-SAME: %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> { func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> { @@ -205,6 +209,7 @@ func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> ve } // ----- + // ALL-LABEL: test_extract_strided_slice_2 // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> { func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> { @@ -228,6 +233,7 @@ func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4 } // ----- + // ALL-LABEL: test_vector_shuffle // ALL-SAME: (%[[ORIG_ARG0:.*]]: vector<4x2xf32>, %[[ORIG_ARG1:.*]]: vector<4x2xf32>) -> vector<8x2xf32> { func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -> vector<8x2xf32> { @@ -252,6 +258,7 @@ func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) - } // ----- + // ALL-LABEL: test_vector_extract // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> { func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> { @@ -273,6 +280,8 @@ func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> { return %0 : vector<8x2xf32> } +// ----- + // ALL-LABEL: func.func @test_vector_extract_scalable( // ALL-SAME: %[[VAL_0:.*]]: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> { func.func @test_vector_extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> { @@ -283,7 +292,9 @@ func.func 
@test_vector_extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x // ALL: return %[[RES]] : vector<8x[2]xf32> return %0 : vector<8x[2]xf32> } + // ----- + // ALL-LABEL: test_vector_insert // ALL-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> { func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> { @@ -312,6 +323,8 @@ func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) return %0 : vector<2x8x4xf32> } +// ----- + // ALL-LABEL: func.func @test_vector_insert_scalable( // ALL-SAME: %[[VAL_0:.*]]: vector<2x8x[4]xf32>, %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<2x8x[4]xf32> { func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector<8x[4]xf32>) -> vector<2x8x[4]xf32> { @@ -385,6 +398,7 @@ func.func @test_vector_bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> { } // ----- + // ALL-LABEL: test_vector_bitcast // ALL-SAME: %[[ARG_0:.*]]: vector<[4]x2xf32> func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> { diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index a54ae816570a8..17137819f03f1 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -7,17 +7,13 @@ //===----------------------------------------------------------------------===// #include -#include #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include 
"mlir/Dialect/SCF/IR/SCF.h" @@ -839,16 +835,98 @@ struct TestVectorEmulateMaskedLoadStore final } }; -struct TestVectorLinearize final - : public PassWrapper<testvectorlinearize, operationpass<="">> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize) +// TODO: move this code into the user project. +namespace vendor { - TestVectorLinearize() = default; - TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {} +/// Get the set of operand/result types to check for sufficiently +/// small inner-most dimension size. +static SmallVector<std::pair<type, unsigned="">> +getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) { - StringRef getArgument() const override { return "test-vector-linearize"; } + if (auto insertOp = dyn_cast(op)) { + unsigned w = targetBitWidth < std::numeric_limits::max() + ? targetBitWidth + 1 + : targetBitWidth; + return {{insertOp.getValueToStoreType(), w}}; + } + + auto resultTypes = op->getResultTypes(); + SmallVector<std::pair<type, unsigned="">> resultsWithBitWidth; + resultsWithBitWidth.reserve(resultTypes.size()); + for (Type type : resultTypes) { + resultsWithBitWidth.push_back({type, targetBitWidth}); + } + return resultsWithBitWidth; +} + +/// If `type` is VectorType with trailing dimension of (bit) size greater than +/// or equal to `targetBitWidth`, its defining op is considered legal. +static bool +isNotLinearizableBecauseLargeInnerDimension(Type type, + unsigned targetBitWidth) { + + VectorType vecType = dyn_cast(type); + + // Not linearizable for reasons other than what this function checks. + if (!vecType || vecType.getRank() == 0) + return false; + + // The width of the type 'index' is unbounded (and therefore potentially above + // the target width). 
+ if (vecType.getElementType().isIndex()) + return true; + + unsigned finalDimSize = vecType.getShape().back(); + unsigned nbBitsPerElm = vecType.getElementTypeBitWidth(); + unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm; + return trailingVecDimBitWidth >= targetBitWidth; +} + +static bool +isNotLinearizableBecauseLargeInnerDimension(Operation *op, + unsigned targetBitWidth) { + // Check on bitwidths. + SmallVector<std::pair<type, unsigned="">> toCheck = + getTypeBitWidthBoundPairs(op, targetBitWidth); + return std::any_of(toCheck.begin(), toCheck.end(), + [&](std::pair<type, unsigned=""> typeWidth) { + return isNotLinearizableBecauseLargeInnerDimension( + typeWidth.first, typeWidth.second); + }); +} + +void populateWithBitWidthConstraints(TypeConverter &typeConverter, + ConversionTarget &target, + unsigned targetBitWidth) { + + // The general purpose definition of what ops are legal must come first. + populateForVectorLinearize(typeConverter, target); + + // Extend the set of legal ops to include those with large inner-most + // dimensions on selected operands/results. 
+ target.markUnknownOpDynamicallyLegal( + [=](Operation *op) -> std::optional { + if (isNotLinearizableBecauseLargeInnerDimension(op, targetBitWidth)) { + return true; + } + return {}; + }); +} + +struct TestVectorBitWidthLinearize final + : public PassWrapper<TestVectorBitWidthLinearize, OperationPass<>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize) + + TestVectorBitWidthLinearize() = default; + TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass) + : PassWrapper(pass) {} + + StringRef getArgument() const override { + return "test-bit-width-constrained-vector-linearize"; + } StringRef getDescription() const override { - return "Linearizes ND vectors for N >= 2 into 1D vectors"; + return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints " + "in inner-most dimension's bit width."; } void getDependentDialects(DialectRegistry &registry) const override { registry.insert(); @@ -866,10 +944,49 @@ struct TestVectorLinearize final RewritePatternSet patterns(context); ConversionTarget target(*context); - vector::populateVectorLinearizeTypeConversionsAndLegality( - typeConverter, patterns, target, targetVectorBitwidth); - vector::populateVectorLinearizeShuffleLikeOpsPatterns( - typeConverter, patterns, target, targetVectorBitwidth); + populateWithBitWidthConstraints(typeConverter, target, + targetVectorBitwidth); + + vector::populateVectorLinearizeBasePatterns(typeConverter, target, + patterns); + + vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter, target, + patterns); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace vendor + +struct TestVectorLinearize final + : public PassWrapper<TestVectorLinearize, OperationPass<>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize) + + TestVectorLinearize() = default; + + StringRef getArgument() const override { return 
"test-vector-linearize"; } StringRef getDescription() const override { + return "Linearizes ND vectors for N >= 2 into 1D vectors"; + } + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + } + + void runOnOperation() override { + MLIRContext &context = getContext(); + TypeConverter converter; + RewritePatternSet patterns(&context); + ConversionTarget target(context); + + vector::populateForVectorLinearize(converter, target); + + vector::populateVectorLinearizeBasePatterns(converter, target, patterns); + vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target, + patterns); + if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); @@ -949,6 +1066,8 @@ void registerTestVectorLowerings() { PassRegistration(); + PassRegistration(); + PassRegistration(); } } // namespace test