[mlir][Vector] Add vector bitwidth target to xfer op flattening by dcaballe · Pull Request #81966 · llvm/llvm-project
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir
Author: Diego Caballero (dcaballe)
Changes
This PR adds an optional bitwidth parameter to the vector xfer op flattening transformation so that flattening is skipped when the trailing dimension of the read/written vector is larger than this bitwidth (i.e., when the trailing dimension alone is already able to fill at least one vector register of that size).
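For context, here is a minimal sketch of the rewrite these patterns perform, following the memref.collapse_shape + vector.shape_cast recipe documented in the header change below. The shapes and function names are hypothetical; with a target bitwidth of, say, 512, this case would still be flattened because the trailing dimension spans only 8 × 8 = 64 bits:

```mlir
// Hypothetical input: a contiguous 2-D read whose trailing dimension is
// 8 x i8 = 64 bits, i.e. well below a 512-bit target.
func.func @before(%m : memref<4x8xi8>) -> vector<4x8xi8> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i8
  %v = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true, true]} :
    memref<4x8xi8>, vector<4x8xi8>
  return %v : vector<4x8xi8>
}

// Roughly what the flattening produces: collapse the memref, perform one
// 1-D transfer, and shape_cast the result back to the original shape.
func.func @after(%m : memref<4x8xi8>) -> vector<4x8xi8> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i8
  %collapsed = memref.collapse_shape %m [[0, 1]]
      : memref<4x8xi8> into memref<32xi8>
  %v1d = vector.transfer_read %collapsed[%c0], %pad {in_bounds = [true]} :
    memref<32xi8>, vector<32xi8>
  %v = vector.shape_cast %v1d : vector<32xi8> to vector<4x8xi8>
  return %v : vector<4x8xi8>
}
```

A read like this whose trailing dimension spans 512 bits or more would now be left in its n-D form.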
Full diff: https://github.com/llvm/llvm-project/pull/81966.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+7-2)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+40-5)
- (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+34-2)
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+2-1)
```diff
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index f5941d32e683fc..cb3b3de8051d6f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -328,8 +328,13 @@ void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns,
 /// These patterns insert memref.collapse_shape + vector.shape_cast patterns
 /// to transform multiple small n-D transfers into a larger 1-D transfer where
 /// the memref contiguity properties allow it.
-void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns,
-                                           PatternBenefit benefit = 1);
+///
+/// Flattening is only applied if the bitwidth of the trailing vector dimension
+/// is smaller or equal to `targetVectorBitwidth`.
+void populateFlattenVectorTransferPatterns(
+    RewritePatternSet &patterns,
+    unsigned targetVectorBitwidth = std::numeric_limits<unsigned>::max(),
+    PatternBenefit benefit = 1);
 
 /// Collect a set of patterns that bubble up/down bitcast ops.
 ///
```
```diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index b761d1ed888973..04e5a816dd91e6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -19,7 +19,6 @@
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/ADT/STLExtras.h"
@@ -535,9 +534,17 @@ namespace {
 /// memref.collapse_shape on the source so that the resulting
 /// vector.transfer_read has a 1D source. Requires the source shape to be
 /// already reduced i.e. without unit dims.
+/// If `targetVectorBitwidth` is provided, the flattening will only happen if
+/// the trailing dimension of the vector read is smaller than the provided
+/// bitwidth.
 class FlattenContiguousRowMajorTransferReadPattern
     : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
+public:
+  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
+                                               unsigned vectorBitwidth,
+                                               PatternBenefit benefit)
+      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
+        targetVectorBitwidth(vectorBitwidth) {}
 
   LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                 PatternRewriter &rewriter) const override {
@@ -554,6 +561,12 @@ class FlattenContiguousRowMajorTransferReadPattern
     // If this is already 0D/1D, there's nothing to do.
     if (vectorType.getRank() <= 1)
       return failure();
+    if (!vectorType.getElementType().isSignlessIntOrFloat())
+      return failure();
+    unsigned trailingVectorDimBitwidth =
+        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
+    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
+      return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
     // TODO: generalize this pattern, relax the requirements here.
@@ -642,6 +655,11 @@ class FlattenContiguousRowMajorTransferReadPattern
         transferReadOp, cast<VectorType>(vector.getType()), flatRead);
     return success();
   }
+
+private:
+  // Minimum bitwidth that the trailing vector dimension should have after
+  // flattening.
+  unsigned targetVectorBitwidth;
 };
 
 /// Rewrites contiguous row-major vector.transfer_write ops by inserting
@@ -650,7 +668,12 @@ class FlattenContiguousRowMajorTransferReadPattern
 /// already reduced i.e. without unit dims.
 class FlattenContiguousRowMajorTransferWritePattern
     : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern::OpRewritePattern;
+public:
+  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
+                                                unsigned vectorBitwidth,
+                                                PatternBenefit benefit)
+      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
+        targetVectorBitwidth(vectorBitwidth) {}
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                 PatternRewriter &rewriter) const override {
@@ -665,6 +688,12 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
+    if (!vectorType.getElementType().isSignlessIntOrFloat())
+      return failure();
+    unsigned trailingVectorDimBitwidth =
+        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
+    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
+      return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
     int64_t firstContiguousInnerDim =
@@ -702,6 +731,11 @@ class FlattenContiguousRowMajorTransferWritePattern
     rewriter.eraseOp(transferWriteOp);
     return success();
   }
+
+private:
+  // Minimum bitwidth that the trailing vector dimension should have after
+  // flattening.
+  unsigned targetVectorBitwidth;
 };
 
 /// Base class for vector.extract/vector.extract_element(vector.transfer_read)
@@ -917,10 +951,11 @@ void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
 }
 
 void mlir::vector::populateFlattenVectorTransferPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
+    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
+    PatternBenefit benefit) {
   patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
-      patterns.getContext(), benefit);
+      patterns.getContext(), targetVectorBitwidth, benefit);
   populateShapeCastFoldingPatterns(patterns, benefit);
  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
 }
```
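To make the new guard concrete with a worked example: in the tests added below, the transferred vector is `vector<5x4x3x20xi32>`, whose trailing dimension spans 20 × 32 = 640 bits. That meets the 512-bit target the test pass now uses (see the TestVectorTransforms.cpp hunk at the end), so both patterns return failure() and the transfers stay multi-dimensional, while the existing `vector<1x2x6xi32>` cases (6 × 32 = 192 bits) keep flattening.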
```diff
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 9976048a3320b6..5ba3ac824770ce 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -66,7 +66,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
     %m_out: memref<1x2x6xi32>) {
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x43x4x6xi32>, vector<1x2x6xi32>
   vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
     vector<1x2x6xi32>, memref<1x2x6xi32>
@@ -99,7 +99,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
     %m_out: memref<1x2x6xi32>) {
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x?x4x6xi32>, vector<1x2x6xi32>
   vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
     vector<1x2x6xi32>, memref<1x2x6xi32>
@@ -389,3 +389,35 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // CHECK:           %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
 // CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
 // CHECK:           return %[[VAL_4]] : vector<8xi32>
+
+// -----
+
+func.func @trailing_dim_larger_than_target_vector_bitwidth_read(
+  %arg : memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>) -> vector<5x4x3x20xi32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i32
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>, vector<5x4x3x20xi32>
+  return %v : vector<5x4x3x20xi32>
+}
+
+// CHECK-LABEL: func.func @trailing_dim_larger_than_target_vector_bitwidth_read(
+// CHECK-NOT: tensor.collapse_shape
+
+// -----
+
+func.func @trailing_dim_larger_than_target_vector_bitwidth_write(
+  %arg0 : memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>,
+  %arg1 : vector<5x4x3x20xi32>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] :
+    vector<5x4x3x20xi32>, memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>
+  return
+}
+
+// CHECK-LABEL: func.func @trailing_dim_larger_than_target_vector_bitwidth_write(
+// CHECK-NOT: tensor.collapse_shape
+
+
+
```
```diff
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 126d65b1b8487f..57d104e80d7243 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -480,7 +480,8 @@ struct TestFlattenVectorTransferPatterns
   }
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    populateFlattenVectorTransferPatterns(patterns);
+    constexpr unsigned targetVectorBitwidth = 512;
+    populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
```
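One consequence worth noting: the new parameter defaults to `std::numeric_limits<unsigned>::max()`, so existing callers of `populateFlattenVectorTransferPatterns` keep the old always-flatten behavior unless they opt in to a target bitwidth. For anyone exercising the updated patterns locally, the lit RUN line below should work; the flag name is inferred from the existing header of vector-transfer-flatten.mlir, so treat it as an assumption and verify it against your checkout:

```mlir
// Assumed RUN line for the flattening tests (flag name unverified):
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
```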