MLIR: lib/Dialect/Vector/Transforms/LowerVectorScan.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

34

35 #define DEBUG_TYPE "vector-broadcast-lowering"

36

37 using namespace mlir;

39

40

41

43 using vector::CombiningKind;

44 enum class KindType { FLOAT, INT, INVALID };

45 KindType type{KindType::INVALID};

46 switch (kind) {

47 case CombiningKind::MINNUMF:

48 case CombiningKind::MINIMUMF:

49 case CombiningKind::MAXNUMF:

50 case CombiningKind::MAXIMUMF:

51 type = KindType::FLOAT;

52 break;

54 case CombiningKind::MINSI:

55 case CombiningKind::MAXUI:

56 case CombiningKind::MAXSI:

57 case CombiningKind::AND:

58 case CombiningKind::OR:

59 case CombiningKind::XOR:

60 type = KindType::INT;

61 break;

62 case CombiningKind::ADD:

63 case CombiningKind::MUL:

64 type = isInt ? KindType::INT : KindType::FLOAT;

65 break;

66 }

67 bool isValidIntKind = (type == KindType::INT) && isInt;

68 bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);

69 return (isValidIntKind || isValidFloatKind);

70 }

71

72 namespace {

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111 struct ScanToArithOps : public OpRewritePatternvector::ScanOp {

113

114 LogicalResult matchAndRewrite(vector::ScanOp scanOp,

116 auto loc = scanOp.getLoc();

117 VectorType destType = scanOp.getDestType();

119 auto elType = destType.getElementType();

120 bool isInt = elType.isIntOrIndex();

121 if (isValidKind(isInt, scanOp.getKind()))

122 return failure();

123

125 Value result = rewriter.createarith::ConstantOp(

126 loc, resType, rewriter.getZeroAttr(resType));

127 int64_t reductionDim = scanOp.getReductionDim();

128 bool inclusive = scanOp.getInclusive();

129 int64_t destRank = destType.getRank();

130 VectorType initialValueType = scanOp.getInitialValueType();

131 int64_t initialValueRank = initialValueType.getRank();

132

134 reductionShape[reductionDim] = 1;

135 VectorType reductionType = VectorType::get(reductionShape, elType);

139 sizes[reductionDim] = 1;

141 ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);

142

143 Value lastOutput, lastInput;

144 for (int i = 0; i < destShape[reductionDim]; i++) {

145 offsets[reductionDim] = i;

146 ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);

147 Value input = rewriter.createvector::ExtractStridedSliceOp(

148 loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,

149 scanStrides);

151 if (i == 0) {

152 if (inclusive) {

153 output = input;

154 } else {

155 if (initialValueRank == 0) {

156

157 output = rewriter.createvector::BroadcastOp(

158 loc, input.getType(), scanOp.getInitialValue());

159 } else {

160 output = rewriter.createvector::ShapeCastOp(

161 loc, input.getType(), scanOp.getInitialValue());

162 }

163 }

164 } else {

165 Value y = inclusive ? input : lastInput;

167 lastOutput, y);

168 }

169 result = rewriter.createvector::InsertStridedSliceOp(

170 loc, output, result, offsets, strides);

171 lastOutput = output;

172 lastInput = input;

173 }

174

176 if (initialValueRank == 0) {

177 Value v = rewriter.createvector::ExtractOp(loc, lastOutput, 0);

178 reduction =

179 rewriter.createvector::BroadcastOp(loc, initialValueType, v);

180 } else {

181 reduction = rewriter.createvector::ShapeCastOp(loc, initialValueType,

182 lastOutput);

183 }

184

185 rewriter.replaceOp(scanOp, {result, reduction});

186 return success();

187 }

188 };

189 }

190

193 patterns.add(patterns.getContext(), benefit);

194 }

union mlir::linalg::@1203::ArityGroupAndKind::Kind kind

static bool isValidKind(bool isInt, vector::CombiningKind kind)

This function checks whether the vector combining kind is consistent with the integer or floating-point element type.

TypedAttr getZeroAttr(Type type)

ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements).

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

Type getType() const

Return the type of this value.

Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)

Returns the result value of reducing two scalar/vector values with the corresponding arith operation.

void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)

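Populate the pattern set with the following patterns: ScanToArithOps, which lowers vector.scan into arith ops plus vector.extract_strided_slice / vector.insert_strided_slice.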

Include the generated interface declarations.

const FrozenRewritePatternSet & patterns

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.

OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})

Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.