[mlir][vector] Add verification for incorrect vector.extract by dcaballe · Pull Request #115824 · llvm/llvm-project (original) (raw)
@@ -1339,6 +1339,83 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
return l == r;
}
// Common verification rules for `InsertOp` and `ExtractOp` involving indices
// and shapes. `indexedType` is the vector type being indexed by the operation,
// i.e., the destination type in `InsertOp` and the source type in `ExtractOp`.
// `nonIndexedType` is the inserted or extracted type by an `InsertOp` or and
// `ExtractOp`, respectively.
static LogicalResult verifyInsertExtractIndicesAndShapes(Operation *op,
VectorType indexedType,
int64_t numIndices,
Type nonIndexedType) {
assert((isa(op) || isa(op)) &&
"Expected InsertOp or ExtractOp");
std::string nonIndexedStr = isa(op) ? "inserted" : "extracted";
std::string indexedStr = isa(op) ? "destination" : "source";
int64_t indexedRank = indexedType.getRank();
if (numIndices > indexedRank) {
return op->emitOpError()
<< "expected a number of indices no greater than the " << indexedStr
<< " vector rank";
}
if (auto nonIndexedVecType = dyn_cast(nonIndexedType)) {
// Vector case, including meaningful cases such as:
// * 0-D vector:
// * vector.extract %src[2]: vector from vector<8xf32)
// * vector.insert %src, %dst[3]: vector into vector<8xf32>
// * One-element vector:
// * vector.extract %src[4]: vector<1xf32> from vector<8xf32>
// * vector.insert %src, %dst[1]: vector<1xf32> into vector<8xf32>
// * vector.extract %src[7]: vector<1xf32> from vector<8x1xf32>
// * vector.insert %src, %dst[5]: vector<1xf32> into vector<8x1xf32>
int64_t nonIndexedRank = nonIndexedVecType.getRank();
bool isSingleElem1DNonIndexedVec =
(nonIndexedRank == 1 && nonIndexedVecType.getDimSize(0) == 1);
bool isSingleElem1DIndexedVec =
(indexedRank == 1 && indexedType.getDimSize(0) == 1);
// Verify 0-D -> single-element 1-D supported cases.
if ((indexedRank == 0 && isSingleElem1DNonIndexedVec) ||
(nonIndexedRank == 0 && isSingleElem1DIndexedVec)) {
return op->emitOpError("expected source and destination vectors with "
"different number of elements");
}
// Verify indices for all the cases.
int64_t indexedRankMinusIndices = indexedRank - numIndices;
if (indexedRankMinusIndices != nonIndexedRank &&
(!isSingleElem1DNonIndexedVec || indexedRankMinusIndices != 0)) {
return op->emitOpError()
<< "expected " << indexedStr
<< " vector rank minus number of indices to match the rank of the "
<< nonIndexedStr << " vector";
}
// Check that if we are inserting or extracting a sub-vector, the
// corresponding source and destination shapes match.
if (indexedRankMinusIndices > 0) {
auto indexedShape = indexedType.getShape();
if (indexedShape.drop_front(numIndices) != nonIndexedVecType.getShape()) {
return op->emitOpError() << "expected " << nonIndexedStr
<< " vector shape to match the sub-vector "
"shape of the "
<< indexedStr << " vector";
}
}
return success();
}
// Scalar case.
if (indexedRank != numIndices) {
return op->emitOpError()
<< "expected " << indexedStr
<< " vector rank to match the number of indices for scalar cases";
}
return success();
}
LogicalResult vector::ExtractOp::verify() {
// Note: This check must come before getMixedPosition() to prevent a crash.
auto dynamicMarkersCount =
@@ -1348,14 +1425,16 @@ LogicalResult vector::ExtractOp::verify() {
"mismatch between dynamic and static positions (kDynamic marker but no "
"corresponding dynamic position) -- this can only happen due to an "
"incorrect fold/rewrite");
auto position = getMixedPosition();
if (position.size() > static_cast(getSourceVectorType().getRank()))
return emitOpError(
"expected position attribute of rank no greater than vector rank");
for (auto [idx, pos] : llvm::enumerate(position)) {
auto srcVecType = getSourceVectorType();
if (failed(verifyInsertExtractIndicesAndShapes(
*this, srcVecType, getNumIndices(), getResult().getType()))) {
return failure();
}
for (auto [idx, pos] : llvm::enumerate(getMixedPosition())) {
if (pos.is()) {
int64_t constIdx = cast(pos.get()).getInt();
if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
if (constIdx < 0 || constIdx >= srcVecType.getDimSize(idx)) {
return emitOpError("expected position attribute #")
<< (idx + 1)
<< " to be a non-negative integer smaller than the "
@@ -2861,25 +2940,16 @@ void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
}
LogicalResult InsertOp::verify() {
SmallVector position = getMixedPosition();
auto destVectorType = getDestVectorType();
if (position.size() > static_cast(destVectorType.getRank()))
return emitOpError(
"expected position attribute of rank no greater than dest vector rank");
auto srcVectorType = llvm::dyn_cast(getSourceType());
if (srcVectorType &&
(static_cast(srcVectorType.getRank()) + position.size() !=
static_cast(destVectorType.getRank())))
return emitOpError("expected position attribute rank + source rank to "
"match dest vector rank");
if (!srcVectorType &&
(position.size() != static_cast(destVectorType.getRank())))
return emitOpError(
"expected position attribute rank to match the dest vector rank");
for (auto [idx, pos] : llvm::enumerate(position)) {
auto dstVecType = getDestVectorType();
if (failed(verifyInsertExtractIndicesAndShapes(
*this, dstVecType, getNumIndices(), getSourceType()))) {
return failure();
}
for (auto [idx, pos] : llvm::enumerate(getMixedPosition())) {
if (auto attr = pos.dyn_cast()) {
int64_t constIdx = cast(attr).getInt();
if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
if (constIdx < 0 || constIdx >= dstVecType.getDimSize(idx)) {
return emitOpError("expected position attribute #")
<< (idx + 1)
<< " to be a non-negative integer smaller than the "