MLIR: include/mlir/IR/Matchers.h Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 #ifndef MLIR_IR_MATCHERS_H
16 #define MLIR_IR_MATCHERS_H
17
22
23 namespace mlir {
24
25 namespace detail {
26
27
28
29 template <
30 typename AttrClass,
31
32
33 typename ValueType = typename std::enable_if_t<
34 std::is_base_of<Attribute, AttrClass>::value, AttrClass>::ValueType,
35
36 typename = std::enable_if_t<!std::is_void::value>>
39
40
42
44 if (auto intAttr = llvm::dyn_cast(attr)) {
46 return true;
47 }
48 return false;
49 }
50 };
51
52
55 };
56
57
61
63 };
64
65
69
71 };
72
73
74
75 template
78
79
80
82
84
87 return false;
88
89
91 LogicalResult result = op->fold(std::nullopt, foldedOp);
92 (void)result;
93 assert(succeeded(result) && "expected ConstantLike op to be foldable");
94
95 if (auto attr = llvm::dyn_cast(cast(foldedOp.front()))) {
98 return true;
99 }
100 return false;
101 }
102 };
103
104
105
108
111
113 auto inferIntRangeOp = dyn_cast(op);
114 if (!inferIntRangeOp)
115 return false;
116
117
120
121
122 bool matched = false;
123 auto setResultRanges = [&](Value value,
125 if (argRanges.isUninitialized())
126 return;
128 return;
130 matched = true;
131 };
132 inferIntRangeOp.inferResultRangesFromOptional(argRanges, setResultRanges);
133 return matched;
134 }
135 };
136
137
138
139 template
141
142
145
147
152 return true;
153 }
154 return false;
155 }
158 };
159
160
161
164
165
167
170 if (matcher.match(attr))
171 return true;
172
173 if (auto splatAttr = dyn_cast(attr))
174 return matcher.match(splatAttr.getSplatValue<Attribute>());
175
176 return false;
177 }
178
182 return false;
183
185 if (isa<FloatType, VectorType, RankedTensorType>(type))
186 return match(attr);
187
188 return false;
189 }
190 };
191
192
193
196
198 APFloat value(APFloat::Bogus());
200 }
201
203 APFloat value(APFloat::Bogus());
205 }
206 };
207
208
209
212
213
215
218 if (matcher.match(attr))
219 return true;
220
221 if (auto splatAttr = dyn_cast(attr))
222 return matcher.match(splatAttr.getSplatValue<Attribute>());
223
224 return false;
225 }
226
230 return false;
231
233 if (isa<IntegerType, IndexType, VectorType, RankedTensorType>(type))
234 return match(attr);
235
236 return false;
237 }
238 };
239
240
241
244
246 APInt value;
248 }
249
251 APInt value;
253 }
254 };
255
256
257
260
262 APInt value;
265 }
266
268
269 APInt value;
272
273
274
278 }
279 };
280
281
282 template
285 };
286
287
288
289 template <typename T, typename MatchTarget>
291 decltype(std::declval().match(std::declval()));
292
293
294 template
296 MatcherClass, Value>::value,
297 bool>
299 return matcher.match(op->getOperand(idx));
300 }
301
302
303 template
305 MatcherClass, Operation *>::value,
306 bool>
309 return matcher.match(defOp);
310 return false;
311 }
312
313
316 };
317
318
324 return true;
325 }
326 };
327
328
333 };
334
/// Implementation helper for `enumerate`.
///
/// Invokes `callback(index, element)` once for each index in `Is...`, where
/// `index` is a `std::integral_constant<std::size_t, I>` (so the callback can
/// use it as a compile-time constant) and `element` is `std::get<I>(tuple)`.
template <typename TupleT, class CallbackT, std::size_t... Is>
constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback,
                             std::index_sequence<Is...>) {
  // Unary right fold over the comma operator: the callback is invoked for
  // Is... strictly left to right. `std::get` needs the explicit index `Is`
  // to select the element of the tuple.
  (callback(std::integral_constant<std::size_t, Is>{}, std::get<Is>(tuple)),
   ...);
}

/// Calls `callback(index, element)` for every element of `tuple`, in order.
///
/// \param tuple    The tuple whose elements are visited; elements are passed
///                 to the callback as lvalues, so the callback may modify them.
/// \param callback Invocable accepting a
///                 `std::integral_constant<std::size_t, I>` and the I-th
///                 element. It is not invoked at all for an empty tuple.
template <typename... Tys, typename CallbackT>
constexpr void enumerate(std::tuple<Tys...> &tuple, CallbackT &&callback) {
  enumerateImpl(tuple, std::forward<CallbackT>(callback),
                std::make_index_sequence<sizeof...(Tys)>{});
}
348
349
350 template <typename OpType, typename... OperandMatchers>
355 if (!isa(op) || op->getNumOperands() != sizeof...(OperandMatchers))
356 return false;
357 bool res = true;
360 });
361 return res;
362 }
364 };
365
366 }
367
368
371 }
372
373
376 }
377
378
381 }
382
383
384
385 template
388 }
389
390
391 template
393 AttrT *bindValue) {
395 }
396
397
398
400 return {[](const APFloat &value) { return value.isZero(); }};
401 }
402
403
405 return {[](const APFloat &value) { return value.isPosZero(); }};
406 }
407
408
410 return {[](const APFloat &value) { return value.isNegZero(); }};
411 }
412
413
415 return {[](const APFloat &value) {
416 return APFloat(value.getSemantics(), 1) == value;
417 }};
418 }
419
420
422 return {[](const APFloat &value) { return value.isNaN(); }};
423 }
424
425
426
428 return {[](const APFloat &value) {
429 return !value.isNegative() && value.isInfinity();
430 }};
431 }
432
433
434
436 return {[](const APFloat &value) {
437 return value.isNegative() && value.isInfinity();
438 }};
439 }
440
441
443 return {[](const APInt &value) { return 0 == value; }};
444 }
445
446
447
449 return {[](const APInt &value) { return 0 != value; }};
450 }
451
452
453
454
456 return {[](const ConstantIntRanges &range) { return range.umin().ugt(0); }};
457 }
458
459
460
461
464 return range.smin().sgt(0) || range.smax().slt(0);
465 }};
466 }
467
468
469
470
473 return range.smin().sgt(-1) || range.smax().slt(-1);
474 }};
475 }
476
477
479 return {[](const APInt &value) { return 1 == value; }};
480 }
481
482
483 template
486 }
487
488
489 template
491 assert(value);
492
494 return const_cast<Pattern &>(pattern).match(op);
495 return false;
496 }
497
498
499 template
501 assert(op);
502 return const_cast<Pattern &>(pattern).match(op);
503 }
504
505
506
507 template
511 "Pattern does not support matching Attributes");
512 if (!attr)
513 return false;
514 return const_cast<Pattern &>(pattern).match(attr);
515 }
516
517
518
519 inline detail::constant_float_value_binder
522 }
523
524
525
526 inline detail::constant_int_value_binder
529 }
530
531 template <typename OpType, typename... Matchers>
532 auto m_Op(Matchers... matchers) {
534 }
535
536 namespace matchers {
540 }
541
542 }
543
544 #endif
Attributes are known-constant values of operations.
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges constant(const APInt &value)
Create a ConstantIntRanges with a constant value - that is, with the bounds [value,...
This lattice value represents the integer range of an SSA value.
const ConstantIntRanges & getValue() const
Get the known integer value range.
static IntegerValueRange getMaxRange(Value value)
Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) range that is used to mark the v...
This class provides the API for a sub-set of ops that are known to be constant-like.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
LogicalResult fold(ArrayRef< Attribute > operands, SmallVectorImpl< OpFoldResult > &results)
Attempt to fold this operation with the specified constant operand values.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
AttrClass getAttrOfType(StringAttr name)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumOperands()
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback, std::index_sequence< Is... >)
std::enable_if_t< llvm::is_detected< detail::has_compatible_matcher_t, MatcherClass, Value >::value, bool > matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher)
Statically switch to a Value matcher.
decltype(std::declval< T >().match(std::declval< MatchTarget >())) has_compatible_matcher_t
Trait to check whether T provides a 'match' method with type MatchTarget (Value, Operation,...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float NaN.
detail::AttrOpMatcher m_Attr(StringRef attrName)
Matches a named attribute operation.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
detail::NameOpMatcher m_Op(StringRef opName)
Matches a named operation.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
detail::constant_int_predicate_matcher m_NonZero()
Matches a constant scalar / vector splat / tensor splat integer that is any non-zero value.
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
detail::constant_float_value_binder m_ConstantFloat(FloatAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor float (splat) and writes the float value to bind_va...
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or an unsigned integer range that does...
Terminal matcher, always returns true.
AnyCapturedValueMatcher(Value *what)
bool match(Value op) const
Terminal matcher, always returns true.
bool match(Value op) const
The matcher that matches operations that have the specified attribute name, and binds the attribute v...
AttrOpBinder(StringRef attrName, AttrT *bindValue)
Creates a matcher instance that binds the attribute value to bindValue if match succeeds.
bool match(Operation *op)
AttrOpBinder(StringRef attrName)
Creates a matcher instance that doesn't bind if match succeeds.
The matcher that matches operations that have the specified attribute name.
bool match(Operation *op)
AttrOpMatcher(StringRef attrName)
The matcher that matches operations that have the specified op name.
NameOpMatcher(StringRef name)
bool match(Operation *op)
Binds to a specific value and matches it.
bool match(Value val) const
PatternMatcherValue(Value val)
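RecursivePatternMatcher that composes operand matchers: it matches an operation of type OpType whose operand count equals the number of operand matchers, checking each operand against the matcher at the same position.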
RecursivePatternMatcher(OperandMatchers... matchers)
std::tuple< OperandMatchers... > operandMatchers
bool match(Operation *op)
The matcher that matches a certain kind of Attribute and binds the value inside the Attribute.
attr_value_binder(ValueType *bv)
Creates a matcher instance that binds the value to bv if match succeeds.
bool match(Attribute attr)
The matcher that matches a given target constant scalar / vector splat / tensor splat float value tha...
bool match(Operation *op)
bool(* predicate)(const APFloat &)
bool match(Attribute attr)
The matcher that matches a constant scalar / vector splat / tensor splat float Attribute or Operation...
bool match(Attribute attr)
constant_float_value_binder(FloatAttr::ValueType *bv)
Creates a matcher instance that binds the value to bv if match succeeds.
bool match(Operation *op)
FloatAttr::ValueType * bind_value
The matcher that matches a given target constant scalar / vector splat / tensor splat integer value t...
bool match(Operation *op)
bool(* predicate)(const APInt &)
bool match(Attribute attr)
A matcher that matches a given constant scalar / vector splat / tensor splat integer value or a con...
bool match(Attribute attr)
bool match(Operation *op)
bool(* predicate)(const ConstantIntRanges &)
The matcher that matches a constant scalar / vector splat / tensor splat integer Attribute or Operati...
constant_int_value_binder(IntegerAttr::ValueType *bv)
Creates a matcher instance that binds the value to bv if match succeeds.
bool match(Attribute attr)
IntegerAttr::ValueType * bind_value
bool match(Operation *op)
The matcher that matches operations that have the ConstantLike trait, and binds the folded attribute ...
constant_op_binder()
Creates a matcher instance that doesn't bind if match succeeds.
constant_op_binder(AttrT *bind_value)
Creates a matcher instance that binds the constant attribute value to bind_value if match succeeds.
bool match(Operation *op)
The matcher that matches operations that have the ConstantLike trait.
bool match(Operation *op)
A matcher that matches operations that implement the InferIntRangeInterface interface,...
IntegerValueRange * bind_value
infer_int_range_op_binder(IntegerValueRange *bind_value)
bool match(Operation *op)
The matcher that matches a certain kind of op.
bool match(Operation *op)