[mlir][vector] Linearization: push 'bit width' logic out of patterns by newling · Pull Request #136581 · llvm/llvm-project (original) (raw)

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: James Newling (newling)

Changes

Vector linearization is a collection of rewrite patterns that reduce the rank of vector operands and results of operations.

In #83314, an option to ignore (legalize) operations with large inner-most dimensions was added. This PR is a step towards making that option live outside of upstream MLIR. The motivation is to reduce non-core functionality (I would like to use this pass, but would prefer not to deal with `targetVectorBitWidth` at all).

As a follow-up to this PR, I propose that the user(s) of `targetVectorBitWidth` move `legalBecauseOfBitwidth` into their own code bases, after which it can be removed from upstream.

The approach I've used is to move the logic pertaining to `targetVectorBitWidth` out of the patterns and into the conversion target, which the end user can control outside of core MLIR.


Patch is 26.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136581.diff

4 Files Affected:

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h index ce97847172197..d9a0791cdea33 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -392,18 +392,24 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns, void populateVectorTransposeNarrowTypeRewritePatterns( RewritePatternSet &patterns, PatternBenefit benefit = 1); -/// Populates patterns for ND vectors (N >= 2) linearization and sets up the -/// provided ConversionTarget with the appropriate legality configuration for -/// the ops to get converted properly. -void populateVectorLinearizeTypeConversionsAndLegality( - TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, unsigned targetBitWidth);

-static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) { - VectorType vecType = dyn_cast(t); - // Reject index since getElementTypeBitWidth will abort for Index types. - if (!vecType || vecType.getElementType().isIndex()) - return false; - // There are no dimension to fold if it is a 0-D vector. - if (vecType.getRank() == 0) - return false; - unsigned trailingVecDimBitWidth = - vecType.getShape().back() * vecType.getElementTypeBitWidth(); - return trailingVecDimBitWidth <= targetBitWidth; -}

static FailureOr linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter, VectorType resType, Attribute value) { + if (auto dstElementsAttr = dyn_cast(value)) { if (resType.isScalable() && !isa(value)) return rewriter.notifyMatchFailure( @@ -76,16 +47,14 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter, } namespace { + struct LinearizeConstantLike final : OpTraitConversionPatternOpTrait::ConstantLike { using OpTraitConversionPattern::OpTraitConversionPattern; - LinearizeConstantLike( - const TypeConverter &typeConverter, MLIRContext *context, - unsigned targetVectBitWidth = std::numeric_limits::max(), - PatternBenefit benefit = 1) - : OpTraitConversionPattern(typeConverter, context, benefit), - targetVectorBitWidth(targetVectBitWidth) {} + LinearizeConstantLike(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) + : OpTraitConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { @@ -100,10 +69,6 @@ struct LinearizeConstantLike final if (!resType) return rewriter.notifyMatchFailure(loc, "can't convert return type"); - if (!isLessThanTargetBitWidth(op, targetVectorBitWidth)) - return rewriter.notifyMatchFailure( - loc, "Can't flatten since targetBitWidth <= OpSize");

 StringAttr attrName = rewriter.getStringAttr("value");
 Attribute value = op->getAttr(attrName);
 if (!value)

@@ -124,9 +89,6 @@ struct LinearizeConstantLike final rewriter.replaceOp(op, newOp); return success(); }

-private: - unsigned targetVectorBitWidth; }; struct LinearizeVectorizable final @@ -134,18 +96,12 @@ struct LinearizeVectorizable final using OpTraitConversionPattern::OpTraitConversionPattern; public: - LinearizeVectorizable( - const TypeConverter &typeConverter, MLIRContext *context, - unsigned targetVectBitWidth = std::numeric_limits::max(), - PatternBenefit benefit = 1) - : OpTraitConversionPattern(typeConverter, context, benefit), - targetVectorBitWidth(targetVectBitWidth) {} + LinearizeVectorizable(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) + : OpTraitConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - if (!isLessThanTargetBitWidth(op, targetVectorBitWidth)) - return rewriter.notifyMatchFailure( - op->getLoc(), "Can't flatten since targetBitWidth <= OpSize"); FailureOr<Operation *> newOp = convertOpResultTypes(op, operands, *getTypeConverter(), rewriter); if (failed(newOp)) @@ -154,9 +110,6 @@ struct LinearizeVectorizable final rewriter.replaceOp(op, (*newOp)->getResults()); return success(); }

-private: - unsigned targetVectorBitWidth; }; /// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works @@ -173,12 +126,10 @@ struct LinearizeVectorizable final struct LinearizeVectorExtractStridedSlice final : public mlir::OpConversionPatternmlir::vector::ExtractStridedSliceOp { using OpConversionPattern::OpConversionPattern; - LinearizeVectorExtractStridedSlice( - const TypeConverter &typeConverter, MLIRContext *context, - unsigned targetVectBitWidth = std::numeric_limits::max(), - PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit), - targetVectorBitWidth(targetVectBitWidth) {} + LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter, + MLIRContext *context, + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, @@ -189,9 +140,6 @@ struct LinearizeVectorExtractStridedSlice final if (extractOp.getVector().getType().isScalable() || dstType.isScalable()) return rewriter.notifyMatchFailure(extractOp, "scalable vectors are not supported."); - if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth)) - return rewriter.notifyMatchFailure( - extractOp, "Can't flatten since targetBitWidth <= OpSize"); ArrayAttr offsets = extractOp.getOffsets(); ArrayAttr sizes = extractOp.getSizes(); @@ -268,9 +216,6 @@ struct LinearizeVectorExtractStridedSlice final extractOp, dstType, srcVector, srcVector, indices); return success(); }

-private: - unsigned targetVectorBitWidth; }; /// This pattern converts the ShuffleOp that works on nD (n > 1) @@ -291,8 +236,7 @@ struct LinearizeVectorShuffle final const TypeConverter &typeConverter, MLIRContext *context, unsigned targetVectBitWidth = std::numeric_limits::max(), PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit), - targetVectorBitWidth(targetVectBitWidth) {} + : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, @@ -302,13 +246,12 @@ struct LinearizeVectorShuffle final assert(dstType && "vector type destination expected."); // The assert is used because vector.shuffle does not support scalable // vectors. - assert(!(shuffleOp.getV1VectorType().isScalable() || - shuffleOp.getV2VectorType().isScalable() || - dstType.isScalable()) && - "scalable vectors are not supported."); - if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth)) - return rewriter.notifyMatchFailure( - shuffleOp, "Can't flatten since targetBitWidth <= OpSize"); + bool scalable = shuffleOp.getV1VectorType().isScalable() || + shuffleOp.getV2VectorType().isScalable() || + dstType.isScalable(); + if (scalable) + return rewriter.notifyMatchFailure(shuffleOp, + "scalable vectors are not supported."); Value vec1 = adaptor.getV1(); Value vec2 = adaptor.getV2(); @@ -343,9 +286,6 @@ struct LinearizeVectorShuffle final vec2, indices); return success(); }

-private: - unsigned targetVectorBitWidth; }; /// This pattern converts the ExtractOp to a ShuffleOp that works on a @@ -364,8 +304,7 @@ struct LinearizeVectorExtract final const TypeConverter &typeConverter, MLIRContext *context, unsigned targetVectBitWidth = std::numeric_limits::max(), PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit), - targetVectorBitWidth(targetVectBitWidth) {} + : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -378,9 +317,6 @@ struct LinearizeVectorExtract final cast(dstTy).isScalable()) return rewriter.notifyMatchFailure(extractOp, "scalable vectors are not supported."); - if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth)) - return rewriter.notifyMatchFailure( - extractOp, "Can't flatten since targetBitWidth <= OpSize"); // Dynamic position is not supported. if (extractOp.hasDynamicPosition()) @@ -405,9 +341,6 @@ struct LinearizeVectorExtract final return success(); }

-private: - unsigned targetVectorBitWidth; }; /// This pattern converts the InsertOp to a ShuffleOp that works on a @@ -427,8 +360,7 @@ struct LinearizeVectorInsert final const TypeConverter &typeConverter, MLIRContext *context, unsigned targetVectBitWidth = std::numeric_limits::max(), PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit), - targetVectorBitWidth(targetVectBitWidth) {} + : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -439,11 +371,6 @@ struct LinearizeVectorInsert final return rewriter.notifyMatchFailure(insertOp, "scalable vectors are not supported."); - if (!isLessThanOrEqualTargetBitWidth(insertOp.getValueToStoreType(), - targetVectorBitWidth)) - return rewriter.notifyMatchFailure( - insertOp, "Can't flatten since targetBitWidth < OpSize");

 // dynamic position is not supported
 if (insertOp.hasDynamicPosition())
   return rewriter.notifyMatchFailure(insertOp,

@@ -471,11 +398,11 @@ struct LinearizeVectorInsert final } llvm::SmallVector<int64_t, 2> indices(dstSize); - auto origValsUntil = indices.begin(); + auto *origValsUntil = indices.begin(); std::advance(origValsUntil, linearizedOffset); std::iota(indices.begin(), origValsUntil, 0); // original values that remain [0, offset) - auto newValsUntil = origValsUntil; + auto *newValsUntil = origValsUntil; std::advance(newValsUntil, srcSize); std::iota(origValsUntil, newValsUntil, dstSize); // new values [offset, offset+srcNumElements) @@ -488,9 +415,6 @@ struct LinearizeVectorInsert final return success(); }

-private: - unsigned targetVectorBitWidth; }; /// This pattern converts the BitCastOp that works on nD (n > 1) @@ -508,8 +432,7 @@ struct LinearizeVectorBitCast final const TypeConverter &typeConverter, MLIRContext *context, unsigned targetVectBitWidth = std::numeric_limits::max(), PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit), - targetVectorBitWidth(targetVectBitWidth) {} + : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -518,24 +441,103 @@ struct LinearizeVectorBitCast final if (!resType) return rewriter.notifyMatchFailure(loc, "can't convert return type."); - if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth)) - return rewriter.notifyMatchFailure( - loc, "Can't flatten since targetBitWidth <= OpSize");

 rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
                                                adaptor.getSource());
 return mlir::success();

}

 return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());

}; + typeConverter.addSourceMaterialization(materializeCast); typeConverter.addTargetMaterialization(materializeCast); + target.markUnknownOpDynamicallyLegal( [=](Operation *op) -> std::optional {

+}

+void mlir::vector::populateVectorLinearizeBasePatterns(

[truncated]