(original) (raw)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 58ca84c8d7bca..a9a07c323c735 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1096,43 +1096,55 @@ class VectorExtractOpConversion SmallVector positionVec = getMixedValues( adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter); - // Extract entire vector. Should be handled by folder, but just to be safe. - ArrayRef position(positionVec); - if (position.empty()) { - rewriter.replaceOp(extractOp, adaptor.getVector()); - return success(); - } - - // One-shot extraction of vector from array (only requires extractvalue). - // Except for extracting 1-element vectors. - if (isa(resultType) && - position.size() != - static_cast(extractOp.getSourceVectorType().getRank())) { - if (extractOp.hasDynamicPosition()) - return failure(); - - Value extracted = rewriter.create( - loc, adaptor.getVector(), getAsIntegers(position)); - rewriter.replaceOp(extractOp, extracted); - return success(); + // The Vector -> LLVM lowering models N-D vectors as nested aggregates of + // 1-d vectors. This nesting is modeled using arrays. We do this conversion + // from a N-d vector extract to a nested aggregate vector extract in two + // steps: + // - Extract a member from the nested aggregate. The result can be + // a lower rank nested aggregate or a vector (1-D). This is done using + // `llvm.extractvalue`. + // - Extract a scalar out of the vector if needed. This is done using + // `llvm.extractelement`. + + // Determine if we need to extract a member out of the aggregate. We + // always need to extract a member if the input rank >= 2. + bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2; + // Determine if we need to extract a scalar as the result. We extract + // a scalar if the extract is full rank, i.e., the number of indices is + // equal to source vector rank. + bool extractsScalar = static_cast(positionVec.size()) == + extractOp.getSourceVectorType().getRank(); + + // Since the LLVM type converter converts 0-d vectors to 1-d vectors, we + // need to add a position for this change. + if (extractOp.getSourceVectorType().getRank() == 0) { + Type idxType = typeConverter->convertType(rewriter.getIndexType()); + positionVec.push_back(rewriter.getZeroAttr(idxType)); } - // Potential extraction of 1-D vector from array. Value extracted = adaptor.getVector(); - if (position.size() > 1) { - if (extractOp.hasDynamicPosition()) + if (extractsAggregate) { + ArrayRef position(positionVec); + if (extractsScalar) { + // If we are extracting a scalar from the extracted member, we drop + // the last index, which will be used to extract the scalar out of the + // vector. + position = position.drop_back(); + } + // llvm.extractvalue does not support dynamic dimensions. + if (!llvm::all_of(position, llvm::IsaPred)) { return failure(); + } + extracted = rewriter.create( + loc, extracted, getAsIntegers(position)); + } - SmallVector nMinusOnePosition = - getAsIntegers(position.drop_back()); - extracted = rewriter.create(loc, extracted, - nMinusOnePosition); + if (extractsScalar) { + extracted = rewriter.create( + loc, extracted, getAsLLVMValue(rewriter, loc, positionVec.back())); } - Value lastPosition = getAsLLVMValue(rewriter, loc, position.back()); - // Remaining extraction of element from 1-D LLVM vector. - rewriter.replaceOpWithNewOp(extractOp, extracted, - lastPosition); + rewriter.replaceOp(extractOp, extracted); return success(); } }; diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 1c42538cf8591..6e8a9018d0a25 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1290,26 +1290,68 @@ func.func @extract_scalar_from_vec_1d_f32_dynamic_idx_scalable(%arg0: vector<[16 // ----- -func.func @extract_scalar_from_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 { +func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 { %0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x16xf32> return %0 : f32 } -// Multi-dim vectors are not supported but this test shouldn't crash. +// Lowering supports extracting from multi-dim vectors with dynamic indices +// provided that only the trailing index is dynamic. -// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idx( -// CHECK: vector.extract +// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx( +// CHECK: llvm.extractvalue +// CHECK: llvm.extractelement -func.func @extract_scalar_from_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 { +func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 { %0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x[16]xf32> return %0 : f32 } -// Multi-dim vectors are not supported but this test shouldn't crash. +// Lowering supports extracting from multi-dim vectors with dynamic indices +// provided that only the trailing index is dynamic. + +// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable( +// CHECK: llvm.extractvalue +// CHECK: llvm.extractelement + +// ----- -// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idx_scalable( +func.func @extract_scalar_from_vec_2d_f32_outer_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 { + %0 = vector.extract %arg0[%arg1, 0]: f32 from vector<1x16xf32> + return %0 : f32 +} + +// Lowering supports extracting from multi-dim vectors with dynamic indices +// provided that only the trailing index is dynamic. + +// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_outer_dynamic_idx( // CHECK: vector.extract +func.func @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 { + %0 = vector.extract %arg0[%arg1, 0]: f32 from vector<1x[16]xf32> + return %0 : f32 +} + +// Lowering does not support extracting from multi-dim vectors with non trailing +// dynamic index, but it shouldn't crash. + +// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable( +// CHECK: vector.extract + +// ----- + +func.func @extract_scalar_from_vec_0d_index(%arg0: vector) -> index { + %0 = vector.extract %arg0[]: index from vector+ return %0 : index +} +// CHECK-LABEL: @extract_scalar_from_vec_0d_index( +// CHECK-SAME: %[[A:.*]]: vector) +// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector to vector<1xi64> +// CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[T2:.*]] = llvm.extractelement %[[T0]][%[[T1]] : i64] : vector<1xi64> +// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : i64 to index +// CHECK: return %[[T3]] : index + // ----- func.func @insertelement_into_vec_0d_f32(%arg0: f32, %arg1: vector) -> vector {