[mlir][vector] Add verification for incorrect vector.extract by dcaballe · Pull Request #115824 · llvm/llvm-project (original) (raw)

Expand Up

@@ -1339,6 +1339,83 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {

return l == r;

}

// Common verification rules for `InsertOp` and `ExtractOp` involving indices

// and shapes. `indexedType` is the vector type being indexed by the operation,

// i.e., the destination type in `InsertOp` and the source type in `ExtractOp`.

// `nonIndexedType` is the inserted or extracted type by an `InsertOp` or and

// `ExtractOp`, respectively.

static LogicalResult verifyInsertExtractIndicesAndShapes(Operation *op,

VectorType indexedType,

int64_t numIndices,

Type nonIndexedType) {

assert((isa(op) || isa(op)) &&

"Expected InsertOp or ExtractOp");

std::string nonIndexedStr = isa(op) ? "inserted" : "extracted";

std::string indexedStr = isa(op) ? "destination" : "source";

int64_t indexedRank = indexedType.getRank();

if (numIndices > indexedRank) {

return op->emitOpError()

<< "expected a number of indices no greater than the " << indexedStr

<< " vector rank";

}

if (auto nonIndexedVecType = dyn_cast(nonIndexedType)) {

// Vector case, including meaningful cases such as:

// * 0-D vector:

// * vector.extract %src[2]: vector from vector<8xf32)

// * vector.insert %src, %dst[3]: vector into vector<8xf32>

// * One-element vector:

// * vector.extract %src[4]: vector<1xf32> from vector<8xf32>

// * vector.insert %src, %dst[1]: vector<1xf32> into vector<8xf32>

// * vector.extract %src[7]: vector<1xf32> from vector<8x1xf32>

// * vector.insert %src, %dst[5]: vector<1xf32> into vector<8x1xf32>

int64_t nonIndexedRank = nonIndexedVecType.getRank();

bool isSingleElem1DNonIndexedVec =

(nonIndexedRank == 1 && nonIndexedVecType.getDimSize(0) == 1);

bool isSingleElem1DIndexedVec =

(indexedRank == 1 && indexedType.getDimSize(0) == 1);

// Verify 0-D -> single-element 1-D supported cases.

if ((indexedRank == 0 && isSingleElem1DNonIndexedVec) ||

(nonIndexedRank == 0 && isSingleElem1DIndexedVec)) {

return op->emitOpError("expected source and destination vectors with "

"different number of elements");

}

// Verify indices for all the cases.

int64_t indexedRankMinusIndices = indexedRank - numIndices;

if (indexedRankMinusIndices != nonIndexedRank &&

(!isSingleElem1DNonIndexedVec || indexedRankMinusIndices != 0)) {

return op->emitOpError()

<< "expected " << indexedStr

<< " vector rank minus number of indices to match the rank of the "

<< nonIndexedStr << " vector";

}

// Check that if we are inserting or extracting a sub-vector, the

// corresponding source and destination shapes match.

if (indexedRankMinusIndices > 0) {

auto indexedShape = indexedType.getShape();

if (indexedShape.drop_front(numIndices) != nonIndexedVecType.getShape()) {

return op->emitOpError() << "expected " << nonIndexedStr

<< " vector shape to match the sub-vector "

"shape of the "

<< indexedStr << " vector";

}

}

return success();

}

// Scalar case.

if (indexedRank != numIndices) {

return op->emitOpError()

<< "expected " << indexedStr

<< " vector rank to match the number of indices for scalar cases";

}

return success();

}

LogicalResult vector::ExtractOp::verify() {

// Note: This check must come before getMixedPosition() to prevent a crash.

auto dynamicMarkersCount =

Expand All

@@ -1348,14 +1425,16 @@ LogicalResult vector::ExtractOp::verify() {

"mismatch between dynamic and static positions (kDynamic marker but no "

"corresponding dynamic position) -- this can only happen due to an "

"incorrect fold/rewrite");

auto position = getMixedPosition();

if (position.size() > static_cast(getSourceVectorType().getRank()))

return emitOpError(

"expected position attribute of rank no greater than vector rank");

for (auto [idx, pos] : llvm::enumerate(position)) {

auto srcVecType = getSourceVectorType();

if (failed(verifyInsertExtractIndicesAndShapes(

*this, srcVecType, getNumIndices(), getResult().getType()))) {

return failure();

}

for (auto [idx, pos] : llvm::enumerate(getMixedPosition())) {

if (pos.is()) {

int64_t constIdx = cast(pos.get()).getInt();

if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {

if (constIdx < 0 || constIdx >= srcVecType.getDimSize(idx)) {

return emitOpError("expected position attribute #")

<< (idx + 1)

<< " to be a non-negative integer smaller than the "

Expand Down Expand Up

@@ -2861,25 +2940,16 @@ void vector::InsertOp::build(OpBuilder &builder, OperationState &result,

}

LogicalResult InsertOp::verify() {

SmallVector position = getMixedPosition();

auto destVectorType = getDestVectorType();

if (position.size() > static_cast(destVectorType.getRank()))

return emitOpError(

"expected position attribute of rank no greater than dest vector rank");

auto srcVectorType = llvm::dyn_cast(getSourceType());

if (srcVectorType &&

(static_cast(srcVectorType.getRank()) + position.size() !=

static_cast(destVectorType.getRank())))

return emitOpError("expected position attribute rank + source rank to "

"match dest vector rank");

if (!srcVectorType &&

(position.size() != static_cast(destVectorType.getRank())))

return emitOpError(

"expected position attribute rank to match the dest vector rank");

for (auto [idx, pos] : llvm::enumerate(position)) {

auto dstVecType = getDestVectorType();

if (failed(verifyInsertExtractIndicesAndShapes(

*this, dstVecType, getNumIndices(), getSourceType()))) {

return failure();

}

for (auto [idx, pos] : llvm::enumerate(getMixedPosition())) {

if (auto attr = pos.dyn_cast()) {

int64_t constIdx = cast(attr).getInt();

if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {

if (constIdx < 0 || constIdx >= dstVecType.getDimSize(idx)) {

return emitOpError("expected position attribute #")

<< (idx + 1)

<< " to be a non-negative integer smaller than the "

Expand Down