MLIR: include/mlir/Dialect/Transform/Interfaces/MatchInterfaces.h Source File (original) (raw)
1
2
3
4
5
6
7
8
9 #ifndef MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
10 #define MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
11
12 #include
13 #include <type_traits>
14
17 #include "llvm/ADT/STLExtras.h"
18
19 namespace mlir {
20 namespace transform {
21 class MatchOpInterface;
22
23 namespace detail {
24
25
/// Dispatches `OpTy::matchOperation` with an "empty" first argument.
///
/// The first-parameter type of `OpTy::matchOperation` is inspected at compile
/// time via `llvm::function_traits<...>::arg_t<0>`. One branch passes `nullptr`
/// and the other passes `std::nullopt`, each forwarding `results` and `state`
/// unchanged. Because this is `if constexpr` inside a template, the discarded
/// branch is not instantiated. It is therefore fine for its call to be
/// ill-formed for a given OpTy.
///
/// NOTE(review): this extract is lossy. It is missing the template parameter
/// list, the function signature (original lines 27-29) and the type that
/// `arg_t<0>` is compared against (original line 33). The page's doxygen
/// summary gives the signature as
///   DiagnosedSilenceableFailure matchOptionalOperation(OpTy op,
///       TransformResults &results, TransformState &state)
/// The compared type is presumably `Operation *`, given that the matching
/// branch passes `nullptr`. Confirm against upstream.
26 template
30 if constexpr (std::is_same_v<
31 typename llvm::function_traits<
32 decltype(&OpTy::matchOperation)>::template arg_t<0>,
34 return op.matchOperation(nullptr, results, state);
35 } else {
// Otherwise the first parameter is presumably std::optional<Operation *>;
// pass an empty optional.
36 return op.matchOperation(std::nullopt, results, state);
37 }
38 }
39 }
40
/// AtMostOneOpMatcherOpTrait is a trait for match ops whose operand handle may
/// be associated with zero or one payload op. The name is taken from the
/// diagnostics below. The class head (original lines 42-43) is missing from
/// this extract.
41 template
// Detection aliases, consumed by llvm::is_detected in the verifier below.
// Each alias is well-formed only if T provides the corresponding member.
// Detects a `getOperandHandle()` member.
44 template
45 using has_get_operand_handle =
46 decltype(std::declval<T &>().getOperandHandle());
// Detects `matchOperation(Operation *, TransformResults &, TransformState &)`.
47 template
48 using has_match_operation_ptr = decltype(std::declval<T &>().matchOperation(
49 std::declval<Operation *>(), std::declval<TransformResults &>(),
50 std::declval<TransformState &>()));
// Detects `matchOperation(std::optional<Operation *>, TransformResults &,
// TransformState &)`.
51 template
52 using has_match_operation_optional =
53 decltype(std::declval<T &>().matchOperation(
54 std::declval<std::optional<Operation *>>(),
55 std::declval<TransformResults &>(),
56 std::declval<TransformState &>()));
57
58 public:
// Trait verifier. Its signature (original line 59) is missing; it is
// presumably `static LogicalResult verifyTrait(Operation *op)` per the page's
// doxygen summary.
// - The static_asserts reject, at compile time, op types that lack
//   getOperandHandle() or lack both matchOperation overloads.
// - The runtime checks then require the op to implement MatchOpInterface.
// - They also require the operand handle's type to pass the `isa` check below.
// - On success it returns success().
60 static_assert(llvm::is_detected<has_get_operand_handle, OpTy>::value,
61 "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expects "
62 "operation type to have the getOperandHandle() method");
63 static_assert(
64 llvm::is_detected<has_match_operation_ptr, OpTy>::value ||
65 llvm::is_detected<has_match_operation_optional, OpTy>::value,
66 "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expected operation "
67 "type to have either the matchOperation(Operation *, TransformResults "
68 "&, TransformState &) or the matchOperation(std::optional<Operation*>, "
69 "TransformResults &, TransformState &) method");
70
71
// NOTE(review): the isa/cast template arguments below were stripped by the
// extraction. Judging by the messages, they are presumably MatchOpInterface,
// OpTy and TransformHandleTypeInterface. Confirm upstream.
// NOTE(review): the messages below name the sibling trait
// "SingleOpMatchOpTrait", while the static_assert messages above say
// "SingleOpMatcherOpTrait". One spelling is likely stale; unify it with the
// actual class name.
72 assert(
73 isa(op) &&
74 "AtMostOneOpMatcherOpTrait/SingleOpMatchOpTrait is only available on "
75 "operations with MatchOpInterface");
76 Value operandHandle = cast(op).getOperandHandle();
77 if (!isa(operandHandle.getType())) {
78 return op->emitError() << "AtMostOneOpMatcherOpTrait/"
79 "SingleOpMatchOpTrait requires the op handle "
80 "to be of TransformHandleTypeInterface";
81 }
82
83 return success();
84 }
85
// apply. Its signature (original lines 86-88) is missing. The page's doxygen
// summary lists `DiagnosedSilenceableFailure apply(TransformRewriter &,
// TransformResults &, TransformState &)`.
// It looks up the payload ops associated with the operand handle and forwards
// to OpTy::matchOperation.
89 Value operandHandle = cast(this->getOperation()).getOperandHandle();
90 auto payload = state.getPayloadOps(operandHandle);
// More than one payload op is rejected. The head of this diagnostic
// (original line 92) is missing; it is presumably an emitDefiniteFailure(...).
91 if (!llvm::hasNItemsOrLess(payload, 1)) {
93 << "AtMostOneOpMatcherOpTrait requires the operand handle to "
94 "point to at most one payload op";
95 }
// No payload op: the callee (original line 97) is missing. Given the trailing
// `results, state);`, it presumably goes through
// detail::matchOptionalOperation. Confirm.
96 if (payload.empty()) {
98 results, state);
99 }
// Exactly one payload op: match it directly.
101 .matchOperation(*payload.begin(), results, state);
102 }
103
// The remaining member (original lines 104-107) is missing. The page's doxygen
// summary lists a `getEffects(...)` member, presumably this one. Confirm.
108 }
109 };
110
/// A trait for match ops whose operand handle must be associated with exactly
/// one payload op. The name is presumably SingleOpMatcherOpTrait, as spelled in
/// the static_assert messages of the at-most-one trait earlier in this header.
/// The class head (original line 112) is missing from this extract.
111 template
113
114 public:
// apply. Its signature (original lines 115-117) is missing.
// It rejects handles that do not map to exactly one payload op, then
// delegates.
118 Value operandHandle = cast(this->getOperation()).getOperandHandle();
119 auto payload = state.getPayloadOps(operandHandle);
// Zero or several payload ops are rejected. The head of this diagnostic
// (original line 121) is missing.
// NOTE(review): this message says "SingleOpMatchOpTrait", while other
// diagnostics in this header say "SingleOpMatcherOpTrait". Unify with the
// actual class name.
120 if (!llvm::hasSingleElement(payload)) {
122 << "SingleOpMatchOpTrait requires the operand handle to point to "
123 "a single payload op";
124 }
// The delegation target (original line 125) is missing. It receives
// `rewriter, results, state`. It is presumably the at-most-one trait's apply,
// which then sees the now-guaranteed single payload op. Confirm upstream.
126 rewriter, results, state);
127 }
128 };
129
/// SingleValueMatchOpTrait is a trait for match ops whose operand is a value
/// handle that must be associated with exactly one payload value. The name is
/// taken from its diagnostics. The class head (original lines 131-132) is
/// missing from this extract.
130 template
133 public:
// Trait verifier. Its signature (original line 134) is missing; it is
// presumably `static LogicalResult verifyTrait(Operation *op)`.
// It requires the op to implement MatchOpInterface and requires the operand
// handle's type to pass the `isa` check below. On success it returns
// success().
135
136
// NOTE(review): the isa/cast template arguments below were stripped by the
// extraction. Judging by the messages, they are presumably MatchOpInterface,
// OpTy and TransformValueHandleTypeInterface. Confirm upstream.
137 assert(isa(op) &&
138 "SingleValueMatchOpTrait is only available on operations with "
139 "MatchOpInterface");
140
141 Value operandHandle = cast(op).getOperandHandle();
142 if (!isa(operandHandle.getType())) {
143 return op->emitError() << "SingleValueMatchOpTrait requires an operand "
144 "of TransformValueHandleTypeInterface";
145 }
146
147 return success();
148 }
149
// apply. Its signature (original lines 150-152) is missing.
// It looks up the payload values associated with the operand handle and
// forwards the single one to matchValue.
153 Value operandHandle = cast(this->getOperation()).getOperandHandle();
154 auto payload = state.getPayloadValues(operandHandle);
// Zero or several payload values are rejected. The head of this diagnostic
// (original line 156) is missing.
155 if (!llvm::hasSingleElement(payload)) {
157 << "SingleValueMatchOpTrait requires the value handle to point "
158 "to a single payload value";
159 }
160
// The receiver of this call (original line 161) is missing. It is presumably
// the op itself cast to OpTy. Confirm.
162 .matchValue(*payload.begin(), results, state);
163 }
164
// The remaining member (original lines 165-168) is missing. The page's doxygen
// summary lists a `getEffects(...)` member, presumably this one. Confirm.
169 }
170 };
171
172
173
174
175
176
177
178
179
180
181
182
185 UnitAttr &isInverted, UnitAttr &isAll);
186
187
190 UnitAttr isAll);
191
192
193
194
195
196
197
199 bool inverted, bool all);
200
201
202
203
204
205
206
207
208
209
210
211
212
217
218 }
219 }
220
221 #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h.inc"
222
223 #endif
The result of a transform IR operation application.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Helper class for implementing traits.
Operation * getOperation()
Return the ultimate Operation being worked on.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
void getEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
static LogicalResult verifyTrait(Operation *op)
DiagnosedSilenceableFailure apply(TransformRewriter &rewriter, TransformResults &results, TransformState &state)
DiagnosedSilenceableFailure apply(TransformRewriter &rewriter, TransformResults &results, TransformState &state)
static LogicalResult verifyTrait(Operation *op)
void getEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure apply(TransformRewriter &rewriter, TransformResults &results, TransformState &state)
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
DiagnosedSilenceableFailure matchOptionalOperation(OpTy op, TransformResults &results, TransformState &state)
Dispatch matchOperation based on Operation* or std::optional<Operation*> first operand.
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void printTransformMatchDims(OpAsmPrinter &printer, Operation *op, DenseI64ArrayAttr rawDimList, UnitAttr isInverted, UnitAttr isAll)
Prints a positional index specification for transform match operations.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
ParseResult parseTransformMatchDims(OpAsmParser &parser, DenseI64ArrayAttr &rawDimList, UnitAttr &isInverted, UnitAttr &isAll)
Parses a positional index specification for transform match operations.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Include the generated interface declarations.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.