MLIR: include/mlir/IR/Matchers.h Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15 #ifndef MLIR_IR_MATCHERS_H

16 #define MLIR_IR_MATCHERS_H

17

22

23 namespace mlir {

24

25 namespace detail {

26

27

28

29 template <

30 typename AttrClass,

31

32

33 typename ValueType = typename std::enable_if_t<

34 std::is_base_of<Attribute, AttrClass>::value, AttrClass>::ValueType,

35

36 typename = std::enable_if_t<!std::is_void::value>>

39

40

42

44 if (auto intAttr = llvm::dyn_cast(attr)) {

46 return true;

47 }

48 return false;

49 }

50 };

51

52

55 };

56

57

61

63 };

64

65

69

71 };

72

73

74

75 template

78

79

80

82

84

87 return false;

88

89

91 LogicalResult result = op->fold(std::nullopt, foldedOp);

92 (void)result;

93 assert(succeeded(result) && "expected ConstantLike op to be foldable");

94

95 if (auto attr = llvm::dyn_cast(cast(foldedOp.front()))) {

98 return true;

99 }

100 return false;

101 }

102 };

103

104

105

108

111

113 auto inferIntRangeOp = dyn_cast(op);

114 if (!inferIntRangeOp)

115 return false;

116

117

120

121

122 bool matched = false;

123 auto setResultRanges = [&](Value value,

125 if (argRanges.isUninitialized())

126 return;

128 return;

130 matched = true;

131 };

132 inferIntRangeOp.inferResultRangesFromOptional(argRanges, setResultRanges);

133 return matched;

134 }

135 };

136

137

138

139 template

141

142

145

147

152 return true;

153 }

154 return false;

155 }

158 };

159

160

161

164

165

167

170 if (matcher.match(attr))

171 return true;

172

173 if (auto splatAttr = dyn_cast(attr))

174 return matcher.match(splatAttr.getSplatValue<Attribute>());

175

176 return false;

177 }

178

182 return false;

183

185 if (isa<FloatType, VectorType, RankedTensorType>(type))

186 return match(attr);

187

188 return false;

189 }

190 };

191

192

193

196

198 APFloat value(APFloat::Bogus());

200 }

201

203 APFloat value(APFloat::Bogus());

205 }

206 };

207

208

209

212

213

215

218 if (matcher.match(attr))

219 return true;

220

221 if (auto splatAttr = dyn_cast(attr))

222 return matcher.match(splatAttr.getSplatValue<Attribute>());

223

224 return false;

225 }

226

230 return false;

231

233 if (isa<IntegerType, IndexType, VectorType, RankedTensorType>(type))

234 return match(attr);

235

236 return false;

237 }

238 };

239

240

241

244

246 APInt value;

248 }

249

251 APInt value;

253 }

254 };

255

256

257

260

262 APInt value;

265 }

266

268

269 APInt value;

272

273

274

278 }

279 };

280

281

282 template

285 };

286

287

288

289 template <typename T, typename MatchTarget>

291 decltype(std::declval().match(std::declval()));

292

293

294 template

296 MatcherClass, Value>::value,

297 bool>

299 return matcher.match(op->getOperand(idx));

300 }

301

302

303 template

305 MatcherClass, Operation *>::value,

306 bool>

309 return matcher.match(defOp);

310 return false;

311 }

312

313

316 };

317

318

324 return true;

325 }

326 };

327

328

333 };

334

/// Calls `callback` once for every index in `Is`, in order. Each call receives
/// the index as a `std::integral_constant<std::size_t, I>` (so it is usable in
/// constant expressions inside the callback) followed by the `I`-th element of
/// `tuple`.
template <typename TupleT, class CallbackT, std::size_t... Is>
constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback,
                             std::index_sequence<Is...>) {
  // Fold over the comma operator: exactly one callback invocation per index,
  // evaluated left to right. `std::get<Is>` selects the matching element.
  (callback(std::integral_constant<std::size_t, Is>{}, std::get<Is>(tuple)),
   ...);
}

342

343 template <typename... Tys, typename CallbackT>

344 constexpr void enumerate(std::tuple<Tys...> &tuple, CallbackT &&callback) {

346 std::make_index_sequence<sizeof...(Tys)>{});

347 }

348

349

350 template <typename OpType, typename... OperandMatchers>

355 if (!isa(op) || op->getNumOperands() != sizeof...(OperandMatchers))

356 return false;

357 bool res = true;

360 });

361 return res;

362 }

364 };

365

366 }

367

368

371 }

372

373

376 }

377

378

381 }

382

383

384

385 template

388 }

389

390

391 template

393 AttrT *bindValue) {

395 }

396

397

398

400 return {[](const APFloat &value) { return value.isZero(); }};

401 }

402

403

405 return {[](const APFloat &value) { return value.isPosZero(); }};

406 }

407

408

410 return {[](const APFloat &value) { return value.isNegZero(); }};

411 }

412

413

415 return {[](const APFloat &value) {

416 return APFloat(value.getSemantics(), 1) == value;

417 }};

418 }

419

420

422 return {[](const APFloat &value) { return value.isNaN(); }};

423 }

424

425

426

428 return {[](const APFloat &value) {

429 return !value.isNegative() && value.isInfinity();

430 }};

431 }

432

433

434

436 return {[](const APFloat &value) {

437 return value.isNegative() && value.isInfinity();

438 }};

439 }

440

441

443 return {[](const APInt &value) { return 0 == value; }};

444 }

445

446

447

449 return {[](const APInt &value) { return 0 != value; }};

450 }

451

452

453

454

456 return {[](const ConstantIntRanges &range) { return range.umin().ugt(0); }};

457 }

458

459

460

461

464 return range.smin().sgt(0) || range.smax().slt(0);

465 }};

466 }

467

468

469

470

473 return range.smin().sgt(-1) || range.smax().slt(-1);

474 }};

475 }

476

477

479 return {[](const APInt &value) { return 1 == value; }};

480 }

481

482

483 template

486 }

487

488

489 template

491 assert(value);

492

494 return const_cast<Pattern &>(pattern).match(op);

495 return false;

496 }

497

498

499 template

501 assert(op);

502 return const_cast<Pattern &>(pattern).match(op);

503 }

504

505

506

507 template

511 "Pattern does not support matching Attributes");

512 if (!attr)

513 return false;

514 return const_cast<Pattern &>(pattern).match(attr);

515 }

516

517

518

519 inline detail::constant_float_value_binder

522 }

523

524

525

526 inline detail::constant_int_value_binder

529 }

530

531 template <typename OpType, typename... Matchers>

532 auto m_Op(Matchers... matchers) {

534 }

535

536 namespace matchers {

540 }

541

542 }

543

544 #endif

Attributes are known-constant values of operations.

A set of arbitrary-precision integers representing bounds on a given integer value.

static ConstantIntRanges constant(const APInt &value)

Create a ConstantIntRanges with a constant value - that is, with the bounds [value,...

This lattice value represents the integer range of an SSA value.

const ConstantIntRanges & getValue() const

Get the known integer value range.

static IntegerValueRange getMaxRange(Value value)

Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) range that is used to mark the v...

This class provides the API for a sub-set of ops that are known to be constant-like.

StringRef getStringRef() const

Return the name of this operation. This always succeeds.

Operation is the basic unit of execution within MLIR.

LogicalResult fold(ArrayRef< Attribute > operands, SmallVectorImpl< OpFoldResult > &results)

Attempt to fold this operation with the specified constant operand values.

Value getOperand(unsigned idx)

bool hasTrait()

Returns true if the operation was registered with a particular trait, e.g.

AttrClass getAttrOfType(StringAttr name)

bool hasAttr(StringAttr name)

Return true if the operation has an attribute with the provided name, false otherwise.

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

unsigned getNumOperands()

OperationName getName()

The name of an operation is the key identifier for it.

operand_range getOperands()

Returns an iterator on the underlying Value's.

This class contains all of the data related to a pattern, but does not contain any methods or logic f...

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback, std::index_sequence< Is... >)

std::enable_if_t< llvm::is_detected< detail::has_compatible_matcher_t, MatcherClass, Value >::value, bool > matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher)

Statically switch to a Value matcher.

decltype(std::declval< T >().match(std::declval< MatchTarget >())) has_compatible_matcher_t

Trait to check whether T provides a 'match' method with type MatchTarget (Value, Operation,...

Include the generated interface declarations.

bool matchPattern(Value value, const Pattern &pattern)

Entry point for matching a pattern over a Value.

detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)

Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...

detail::constant_float_predicate_matcher m_NaNFloat()

Matches a constant scalar / vector splat / tensor splat float NaN.

detail::AttrOpMatcher m_Attr(StringRef attrName)

Matches a named attribute operation.

detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()

Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...

detail::NameOpMatcher m_Op(StringRef opName)

Matches a named operation.

detail::constant_float_predicate_matcher m_PosZeroFloat()

Matches a constant scalar / vector splat / tensor splat float positive zero.

detail::constant_int_predicate_matcher m_Zero()

Matches a constant scalar / vector splat / tensor splat integer zero.

detail::constant_float_predicate_matcher m_AnyZeroFloat()

Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.

detail::constant_int_predicate_matcher m_One()

Matches a constant scalar / vector splat / tensor splat integer one.

detail::constant_int_predicate_matcher m_NonZero()

Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.

detail::constant_float_predicate_matcher m_NegInfFloat()

Matches a constant scalar / vector splat / tensor splat float negative infinity.

detail::constant_float_predicate_matcher m_NegZeroFloat()

Matches a constant scalar / vector splat / tensor splat float negative zero.

detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()

Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...

detail::constant_op_matcher m_Constant()

Matches a constant foldable operation.

detail::constant_float_predicate_matcher m_PosInfFloat()

Matches a constant scalar / vector splat / tensor splat float positive infinity.

detail::constant_float_value_binder m_ConstantFloat(FloatAttr::ValueType *bind_value)

Matches a constant holding a scalar/vector/tensor float (splat) and writes the float value to bind_va...

detail::constant_float_predicate_matcher m_OneFloat()

Matches a constant scalar / vector splat / tensor splat float ones.

detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()

Matches a constant scalar / vector splat / tensor splat integer or an unsigned integer range that does...

Terminal matcher, always returns true.

AnyCapturedValueMatcher(Value *what)

bool match(Value op) const

Terminal matcher, always returns true.

bool match(Value op) const

The matcher that matches operations that have the specified attribute name, and binds the attribute v...

AttrOpBinder(StringRef attrName, AttrT *bindValue)

Creates a matcher instance that binds the attribute value to bind_value if match succeeds.

bool match(Operation *op)

AttrOpBinder(StringRef attrName)

Creates a matcher instance that doesn't bind if match succeeds.

The matcher that matches operations that have the specified attribute name.

bool match(Operation *op)

AttrOpMatcher(StringRef attrName)

The matcher that matches operations that have the specified op name.

NameOpMatcher(StringRef name)

bool match(Operation *op)

Binds to a specific value and matches it.

bool match(Value val) const

PatternMatcherValue(Value val)

RecursivePatternMatcher that composes.

RecursivePatternMatcher(OperandMatchers... matchers)

std::tuple< OperandMatchers... > operandMatchers

bool match(Operation *op)

The matcher that matches a certain kind of Attribute and binds the value inside the Attribute.

attr_value_binder(ValueType *bv)

Creates a matcher instance that binds the value to bv if match succeeds.

bool match(Attribute attr)

The matcher that matches a given target constant scalar / vector splat / tensor splat float value tha...

bool match(Operation *op)

bool(* predicate)(const APFloat &)

bool match(Attribute attr)

The matcher that matches a constant scalar / vector splat / tensor splat float Attribute or Operation...

bool match(Attribute attr)

constant_float_value_binder(FloatAttr::ValueType *bv)

Creates a matcher instance that binds the value to bv if match succeeds.

bool match(Operation *op)

FloatAttr::ValueType * bind_value

The matcher that matches a given target constant scalar / vector splat / tensor splat integer value t...

bool match(Operation *op)

bool(* predicate)(const APInt &)

bool match(Attribute attr)

A matcher that matches a given constant scalar / vector splat / tensor splat integer value or a con...

bool match(Attribute attr)

bool match(Operation *op)

bool(* predicate)(const ConstantIntRanges &)

The matcher that matches a constant scalar / vector splat / tensor splat integer Attribute or Operati...

constant_int_value_binder(IntegerAttr::ValueType *bv)

Creates a matcher instance that binds the value to bv if match succeeds.

bool match(Attribute attr)

IntegerAttr::ValueType * bind_value

bool match(Operation *op)

The matcher that matches operations that have the ConstantLike trait, and binds the folded attribute ...

constant_op_binder()

Creates a matcher instance that doesn't bind if match succeeds.

constant_op_binder(AttrT *bind_value)

Creates a matcher instance that binds the constant attribute value to bind_value if match succeeds.

bool match(Operation *op)

The matcher that matches operations that have the ConstantLike trait.

bool match(Operation *op)

A matcher that matches operations that implement the InferIntRangeInterface interface,...

IntegerValueRange * bind_value

infer_int_range_op_binder(IntegerValueRange *bind_value)

bool match(Operation *op)

The matcher that matches a certain kind of op.

bool match(Operation *op)