MLIR: lib/Dialect/SMT/IR/SMTOps.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

12#include "llvm/ADT/APSInt.h"

13

14using namespace mlir;

15using namespace smt;

16using namespace mlir;

17

18

19

20

21

22LogicalResult BVConstantOp::inferReturnTypes(

23 mlir::MLIRContext *context, std::optionalmlir::Location location,

27 inferredReturnTypes.push_back(

28 properties.as<Properties *>()->getValue().getType());

30}

31

32void BVConstantOp::getAsmResultNames(

35 llvm::raw_svector_ostream specialName(specialNameBuffer);

36 specialName << "c" << getValue().getValue() << "_bv"

37 << getValue().getValue().getBitWidth();

38 setNameFn(getResult(), specialName.str());

39}

40

41OpFoldResult BVConstantOp::fold(FoldAdaptor adaptor) {

42 assert(adaptor.getOperands().empty() && "constant has no operands");

43 return getValueAttr();

44}

45

46

47

48

49

50void DeclareFunOp::getAsmResultNames(

52 setNameFn(getResult(), getNamePrefix().has_value() ? *getNamePrefix() : "");

53}

54

55

56

57

58

59LogicalResult SolverOp::verifyRegions() {

60 if (getBody()->getTerminator()->getOperands().getTypes() != getResultTypes())

61 return emitOpError() << "types of yielded values must match return values";

62 if (getBody()->getArgumentTypes() != getInputs().getTypes())

64 << "block argument types must match the types of the 'inputs'";

65

67}

68

69

70

71

72

73LogicalResult CheckOp::verifyRegions() {

74 if (getSatRegion().front().getTerminator()->getOperands().getTypes() !=

75 getResultTypes())

76 return emitOpError() << "types of yielded values in 'sat' region must "

77 "match return values";

78 if (getUnknownRegion().front().getTerminator()->getOperands().getTypes() !=

79 getResultTypes())

80 return emitOpError() << "types of yielded values in 'unknown' region must "

81 "match return values";

82 if (getUnsatRegion().front().getTerminator()->getOperands().getTypes() !=

83 getResultTypes())

84 return emitOpError() << "types of yielded values in 'unsat' region must "

85 "match return values";

86

88}

89

90

91

92

93

94static LogicalResult

100

104 return failure();

105

108 loc, result.operands))

109 return failure();

110

112}

113

116}

117

119 printer << ' ' << getInputs();

121 printer << " : " << getInputs().front().getType();

122}

123

124LogicalResult EqOp::verify() {

125 if (getInputs().size() < 2)

126 return emitOpError() << "'inputs' must have at least size 2, but got "

127 << getInputs().size();

128

130}

131

132

133

134

135

138}

139

140void DistinctOp::print(OpAsmPrinter &printer) {

141 printer << ' ' << getInputs();

143 printer << " : " << getInputs().front().getType();

144}

145

146LogicalResult DistinctOp::verify() {

147 if (getInputs().size() < 2)

148 return emitOpError() << "'inputs' must have at least size 2, but got "

149 << getInputs().size();

150

152}

153

154

155

156

157

158LogicalResult ExtractOp::verify() {

159 unsigned rangeWidth = getType().getWidth();

160 unsigned inputWidth = cast(getInput().getType()).getWidth();

161 if (getLowBit() + rangeWidth > inputWidth)

162 return emitOpError("range to be extracted is too big, expected range "

163 "starting at index ")

164 << getLowBit() << " of length " << rangeWidth

165 << " requires input width of at least " << (getLowBit() + rangeWidth)

166 << ", but the input width is only " << inputWidth;

168}

169

170

171

172

173

174LogicalResult ConcatOp::inferReturnTypes(

178 inferredReturnTypes.push_back(BitVectorType::get(

179 context, cast(operands[0].getType()).getWidth() +

180 cast(operands[1].getType()).getWidth()));

182}

183

184

185

186

187

188LogicalResult RepeatOp::verify() {

189 unsigned inputWidth = cast(getInput().getType()).getWidth();

190 unsigned resultWidth = getType().getWidth();

191 if (resultWidth % inputWidth != 0)

192 return emitOpError() << "result bit-vector width must be a multiple of the "

193 "input bit-vector width";

194

196}

197

198unsigned RepeatOp::getCount() {

199 unsigned inputWidth = cast(getInput().getType()).getWidth();

200 unsigned resultWidth = getType().getWidth();

201 return resultWidth / inputWidth;

202}

203

206 unsigned inputWidth = cast(input.getType()).getWidth();

207 Type resultTy = BitVectorType::get(builder.getContext(), inputWidth * count);

208 build(builder, state, resultTy, input);

209}

210

213 Type inputType;

215

216 APInt count;

218 return failure();

219

220 if (count.isNonPositive())

221 return parser.emitError(countLoc) << "integer must be positive";

222

227 return failure();

228

230 return failure();

231

232 auto bvInputTy = dyn_cast(inputType);

233 if (!bvInputTy)

234 return parser.emitError(inputLoc) << "input must have bit-vector type";

235

236

237

238 const unsigned maxBw = 63;

239 if (count.getActiveBits() > maxBw)

240 return parser.emitError(countLoc)

241 << "integer must fit into " << maxBw << " bits";

242

243

244

245

246 APInt resultBw = bvInputTy.getWidth() * count.zext(2 * maxBw);

247 if (resultBw.getActiveBits() > maxBw)

248 return parser.emitError(countLoc)

249 << "result bit-width (provided integer times bit-width of the input "

250 "type) must fit into "

251 << maxBw << " bits";

252

253 Type resultTy =

254 BitVectorType::get(parser.getContext(), resultBw.getZExtValue());

255 result.addTypes(resultTy);

257}

258

260 printer << " " << getCount() << " times " << getInput();

262 printer << " : " << getInput().getType();

263}

264

265

266

267

268

269void BoolConstantOp::getAsmResultNames(

271 setNameFn(getResult(), getValue() ? "true" : "false");

272}

273

274OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {

275 assert(adaptor.getOperands().empty() && "constant has no operands");

276 return getValueAttr();

277}

278

279

280

281

282

283void IntConstantOp::getAsmResultNames(

286 llvm::raw_svector_ostream specialName(specialNameBuffer);

287 specialName << "c" << getValue();

288 setNameFn(getResult(), specialName.str());

289}

290

291OpFoldResult IntConstantOp::fold(FoldAdaptor adaptor) {

292 assert(adaptor.getOperands().empty() && "constant has no operands");

293 return getValueAttr();

294}

295

297 p << " " << getValue();

299}

300

302 APInt value;

304 return failure();

305

306 result.getOrAddProperties().setValue(

307 IntegerAttr::get(parser.getContext(), APSInt(value)));

308

310 return failure();

311

314}

315

316

317

318

319

320template

322 if (op.getBoundVarNames() &&

323 op.getBody().getNumArguments() != op.getBoundVarNames()->size())

324 return op.emitOpError(

325 "number of bound variable names must match number of block arguments");

327 return op.emitOpError()

328 << "bound variables must by any non-function SMT value";

329

330 if (op.getBody().front().getTerminator()->getNumOperands() != 1)

331 return op.emitOpError("must have exactly one yielded value");

332 if (!isa(

333 op.getBody().front().getTerminator()->getOperand(0).getType()))

334 return op.emitOpError("yielded value must be of '!smt.bool' type");

335

336 for (auto regionWithIndex : llvm::enumerate(op.getPatterns())) {

337 unsigned i = regionWithIndex.index();

338 Region &region = regionWithIndex.value();

339

340 if (op.getBody().getArgumentTypes() != region.getArgumentTypes())

341 return op.emitOpError()

342 << "block argument number and types of the 'body' "

343 "and 'patterns' region #"

344 << i << " must match";

346 return op.emitOpError() << "'patterns' region #" << i

347 << " must have at least one yielded value";

348

349

351 if (!isa(childOp->getDialect())) {

352 auto diag = op.emitOpError()

353 << "the 'patterns' region #" << i

354 << " may only contain SMT dialect operations";

355 diag.attachNote(childOp->getLoc()) << "first non-SMT operation here";

357 }

358

359

360

361 if (isa<ForallOp, ExistsOp>(childOp)) {

362 auto diag = op.emitOpError() << "the 'patterns' region #" << i

363 << " must not contain "

364 "any variable binding operations";

365 diag.attachNote(childOp->getLoc()) << "first violating operation here";

367 }

368

370 });

371 if (result.wasInterrupted())

372 return failure();

373 }

374

376}

377

378template

384 uint32_t weight, bool noPattern) {

386 if (weight != 0)

389 if (noPattern)

392 if (boundVarNames.has_value()) {

394 for (StringRef str : *boundVarNames)

395 boundVarNamesList.emplace_back(odsBuilder.getStringAttr(str));

398 }

399 {

404 boundVarTypes,

406 Value returnVal =

408 smt::YieldOp::create(odsBuilder, odsState.location, returnVal);

409 }

410 if (patternBuilder) {

415 boundVarTypes,

419 smt::YieldOp::create(odsBuilder, odsState.location, returnVals);

420 }

421}

422

423LogicalResult ForallOp::verify() {

424 if (!getPatterns().empty() && getNoPattern())

425 return emitOpError() << "patterns and the no_pattern attribute must not be "

426 "specified at the same time";

427

429}

430

431LogicalResult ForallOp::verifyRegions() {

433}

434

435void ForallOp::build(

436 OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,

438 std::optional<ArrayRef> boundVarNames,

440 uint32_t weight, bool noPattern) {

442 boundVarNames, patternBuilder, weight, noPattern);

443}

444

445

446

447

448

449LogicalResult ExistsOp::verify() {

450 if (!getPatterns().empty() && getNoPattern())

451 return emitOpError() << "patterns and the no_pattern attribute must not be "

452 "specified at the same time";

453

455}

456

457LogicalResult ExistsOp::verifyRegions() {

459}

460

461void ExistsOp::build(

462 OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,

464 std::optional<ArrayRef> boundVarNames,

466 uint32_t weight, bool noPattern) {

468 boundVarNames, patternBuilder, weight, noPattern);

469}

470

471#define GET_OP_CLASSES

472#include "mlir/Dialect/SMT/IR/SMT.cpp.inc"

p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")

Given a list of lists of parsed operands, populates uniqueOperands with unique operands.

static std::string diag(const llvm::Value &value)

static LogicalResult verifyQuantifierRegions(QuantifierOp op)

Definition SMTOps.cpp:321

static LogicalResult parseSameOperandTypeVariadicToBoolOp(OpAsmParser &parser, OperationState &result)

Definition SMTOps.cpp:95

static void buildQuantifier(OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes, function_ref< Value(OpBuilder &, Location, ValueRange)> bodyBuilder, std::optional< ArrayRef< StringRef > > boundVarNames, function_ref< ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder, uint32_t weight, bool noPattern)

Definition SMTOps.cpp:379

virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0

Parse a named dictionary into 'result' if it is present.

MLIRContext * getContext() const

virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0

Emit a diagnostic at the specified location and return failure.

ParseResult parseInteger(IntT &result)

Parse an integer value from the stream.

virtual SMLoc getCurrentLocation()=0

Get the location of the next token and store it into the argument.

virtual ParseResult parseColon()=0

Parse a : token.

virtual ParseResult parseType(Type &result)=0

Parse a type.

ParseResult parseKeyword(StringRef keyword)

Parse a given keyword.

Block represents an ordered list of Operations.

iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)

Add one argument to the argument list for each type specified in the list.

Operation * getTerminator()

Get the terminator operation of this block.

BlockArgListType getArguments()

IntegerAttr getIntegerAttr(Type type, int64_t value)

IntegerType getIntegerType(unsigned width)

StringAttr getStringAttr(const Twine &bytes)

ArrayAttr getArrayAttr(ArrayRef< Attribute > value)

MLIRContext * getContext() const

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...

virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0

Resolve an operand to an SSA value, emitting an error on failure.

ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)

Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...

virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0

Parse a single SSA value operand name along with a result number if allowResultNumber is true.

virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0

Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...

This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...

virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0

If the specified operation has attributes, print out an attribute dictionary with their values.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})

Add new block with 'argTypes' arguments and set the insertion point to the end of it.

This class represents a single result from folding an operation.

Simple wrapper around a void* in order to express generically how to pass in op properties through AP...

Operation is the basic unit of execution within MLIR.

Dialect * getDialect()

Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...

Location getLoc()

The source location the operation was defined or derived from.

unsigned getNumOperands()

This class provides an abstraction over the different types of ranges over Regions.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

ValueTypeRange< BlockArgListType > getArgumentTypes()

Returns the argument types of the first block within the region.

RetT walk(FnT &&callback)

Walk all nested operations, blocks or regions (including this region), depending on the type of callb...

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

static WalkResult advance()

static WalkResult interrupt()

bool isAnyNonFuncSMTValueType(mlir::Type type)

Returns whether the given type is an SMT value type (excluding functions).

Include the generated interface declarations.

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

llvm::function_ref< Fn > function_ref

This is the representation of an operand reference.

This represents an operation in an abstracted form, suitable for use with the builder APIs.

T & getOrAddProperties()

Get (or create) a properties of the provided type to be set on the operation on creation.

void addTypes(ArrayRef< Type > newTypes)

Region * addRegion()

Create a region that should be attached to the operation.