MLIR: include/mlir/Dialect/Transform/Interfaces/MatchInterfaces.h Source File (original) (raw)

1

2

3

4

5

6

7

8

#ifndef MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
#define MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H

#include <optional>
#include <type_traits>

#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/STLExtras.h"


19 namespace mlir {

20 namespace transform {

21 class MatchOpInterface;

22

23 namespace detail {

24

25

26 template

30 if constexpr (std::is_same_v<

31 typename llvm::function_traits<

32 decltype(&OpTy::matchOperation)>::template arg_t<0>,

34 return op.matchOperation(nullptr, results, state);

35 } else {

36 return op.matchOperation(std::nullopt, results, state);

37 }

38 }

39 }

40

41 template

44 template

45 using has_get_operand_handle =

46 decltype(std::declval<T &>().getOperandHandle());

47 template

48 using has_match_operation_ptr = decltype(std::declval<T &>().matchOperation(

49 std::declval<Operation *>(), std::declval<TransformResults &>(),

50 std::declval<TransformState &>()));

51 template

52 using has_match_operation_optional =

53 decltype(std::declval<T &>().matchOperation(

54 std::declval<std::optional<Operation *>>(),

55 std::declval<TransformResults &>(),

56 std::declval<TransformState &>()));

57

58 public:

60 static_assert(llvm::is_detected<has_get_operand_handle, OpTy>::value,

61 "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expects "

62 "operation type to have the getOperandHandle() method");

63 static_assert(

64 llvm::is_detected<has_match_operation_ptr, OpTy>::value ||

65 llvm::is_detected<has_match_operation_optional, OpTy>::value,

66 "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expected operation "

67 "type to have either the matchOperation(Operation *, TransformResults "

68 "&, TransformState &) or the matchOperation(std::optional<Operation*>, "

69 "TransformResults &, TransformState &) method");

70

71

72 assert(

73 isa(op) &&

74 "AtMostOneOpMatcherOpTrait/SingleOpMatchOpTrait is only available on "

75 "operations with MatchOpInterface");

76 Value operandHandle = cast(op).getOperandHandle();

77 if (!isa(operandHandle.getType())) {

78 return op->emitError() << "AtMostOneOpMatcherOpTrait/"

79 "SingleOpMatchOpTrait requires the op handle "

80 "to be of TransformHandleTypeInterface";

81 }

82

83 return success();

84 }

85

89 Value operandHandle = cast(this->getOperation()).getOperandHandle();

90 auto payload = state.getPayloadOps(operandHandle);

91 if (!llvm::hasNItemsOrLess(payload, 1)) {

93 << "AtMostOneOpMatcherOpTrait requires the operand handle to "

94 "point to at most one payload op";

95 }

96 if (payload.empty()) {

98 results, state);

99 }

101 .matchOperation(*payload.begin(), results, state);

102 }

103

108 }

109 };

110

111 template

113

114 public:

118 Value operandHandle = cast(this->getOperation()).getOperandHandle();

119 auto payload = state.getPayloadOps(operandHandle);

120 if (!llvm::hasSingleElement(payload)) {

122 << "SingleOpMatchOpTrait requires the operand handle to point to "

123 "a single payload op";

124 }

126 rewriter, results, state);

127 }

128 };

129

130 template

133 public:

135

136

137 assert(isa(op) &&

138 "SingleValueMatchOpTrait is only available on operations with "

139 "MatchOpInterface");

140

141 Value operandHandle = cast(op).getOperandHandle();

142 if (!isa(operandHandle.getType())) {

143 return op->emitError() << "SingleValueMatchOpTrait requires an operand "

144 "of TransformValueHandleTypeInterface";

145 }

146

147 return success();

148 }

149

153 Value operandHandle = cast(this->getOperation()).getOperandHandle();

154 auto payload = state.getPayloadValues(operandHandle);

155 if (!llvm::hasSingleElement(payload)) {

157 << "SingleValueMatchOpTrait requires the value handle to point "

158 "to a single payload value";

159 }

160

162 .matchValue(*payload.begin(), results, state);

163 }

164

169 }

170 };

171

172

173

174

175

176

177

178

179

180

181

182

185 UnitAttr &isInverted, UnitAttr &isAll);

186

187

190 UnitAttr isAll);

191

192

193

194

195

196

197

199 bool inverted, bool all);

200

201

202

203

204

205

206

207

208

209

210

211

212

217

218 }

219 }

220

221 #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h.inc"

222

223 #endif

The result of a transform IR operation application.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.

The OpAsmParser has methods for interacting with the asm parser: parsing things from it, emitting errors etc.

This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom print() method.

Helper class for implementing traits.

Operation * getOperation()

Return the ultimate Operation being worked on.

Operation is the basic unit of execution within MLIR.

InFlightDiagnostic emitError(const Twine &message={})

Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

Type getType() const

Return the type of this value.

Base class for DenseArrayAttr that is instantiated and specialized for each supported element type below.

void getEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

static LogicalResult verifyTrait(Operation *op)

DiagnosedSilenceableFailure apply(TransformRewriter &rewriter, TransformResults &results, TransformState &state)

DiagnosedSilenceableFailure apply(TransformRewriter &rewriter, TransformResults &results, TransformState &state)

static LogicalResult verifyTrait(Operation *op)

void getEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

DiagnosedSilenceableFailure apply(TransformRewriter &rewriter, TransformResults &results, TransformState &state)

Local mapping between values defined by a specific op implementing the TransformOpInterface and the payload IR ops they correspond to.

This is a special rewriter to be used in transform op implementations, providing additional helper functions to update the transform state, etc.

The state maintained across applications of various ops implementing the TransformOpInterface.

DiagnosedSilenceableFailure matchOptionalOperation(OpTy op, TransformResults &results, TransformState &state)

Dispatch matchOperation based on Operation* or std::optional<Operation*> first operand.

LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)

Checks if the positional specification defined is valid and reports errors otherwise.

void printTransformMatchDims(OpAsmPrinter &printer, Operation *op, DenseI64ArrayAttr rawDimList, UnitAttr isInverted, UnitAttr isAll)

Prints a positional index specification for transform match operations.

void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)

Populates result with the positional identifiers relative to maxNumber.

ParseResult parseTransformMatchDims(OpAsmParser &parser, DenseI64ArrayAttr &rawDimList, UnitAttr &isInverted, UnitAttr &isAll)

Parses a positional index specification for transform match operations.

void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

Include the generated interface declarations.

DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})

Emits a definite failure with the given message.