MLIR: lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

17

18 using namespace mlir;

20

21

22

23

24

25

26

27

28

29

30

31

32

33

36 public:

38

41 auto srcType = op.getSourceVectorType();

42 auto dstType = op.getDestVectorType();

43

44 if (op.getOffsets().getValue().empty())

45 return failure();

46

47 auto loc = op.getLoc();

48 int64_t rankDiff = dstType.getRank() - srcType.getRank();

49 assert(rankDiff >= 0);

50 if (rankDiff == 0)

51 return failure();

52

53 int64_t rankRest = dstType.getRank() - rankDiff;

54

55

56 Value extracted = rewriter.create(

57 loc, op.getDest(),

59 rankRest));

60

61

62

63 auto stridedSliceInnerOp = rewriter.create(

64 loc, op.getValueToStore(), extracted,

65 getI64SubArray(op.getOffsets(), rankDiff),

67

69 op, stridedSliceInnerOp.getResult(), op.getDest(),

71 rankRest));

72 return success();

73 }

74 };

75

76

77

78

79

80

81

82

83

86 public:

88

90

91

92 setHasBoundedRewriteRecursion();

93 }

94

97 auto srcType = op.getSourceVectorType();

98 auto dstType = op.getDestVectorType();

99 int64_t srcRank = srcType.getRank();

100

101

102 if ((srcType.isScalable() || dstType.isScalable()) && srcRank == 1)

103 return failure();

104

105 if (op.getOffsets().getValue().empty())

106 return failure();

107

108 int64_t dstRank = dstType.getRank();

109 assert(dstRank >= srcRank);

110 if (dstRank != srcRank)

111 return failure();

112

113 if (srcType == dstType) {

114 rewriter.replaceOp(op, op.getValueToStore());

115 return success();

116 }

117

118 int64_t offset =

119 cast(op.getOffsets().getValue().front()).getInt();

120 int64_t size = srcType.getShape().front();

121 int64_t stride =

122 cast(op.getStrides().getValue().front()).getInt();

123

124 auto loc = op.getLoc();

125 Value res = op.getDest();

126

127 if (srcRank == 1) {

128 int nSrc = srcType.getShape().front();

129 int nDest = dstType.getShape().front();

130

132 for (int64_t i = 0; i < nSrc; ++i)

133 offsets[i] = i;

134 Value scaledSource = rewriter.create(

135 loc, op.getValueToStore(), op.getValueToStore(), offsets);

136

137

138

139 offsets.clear();

140 for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {

141 if (i < offset || i >= e || (i - offset) % stride != 0)

142 offsets.push_back(nDest + i);

143 else

144 offsets.push_back((i - offset) / stride);

145 }

146

147

148 rewriter.replaceOpWithNewOp(op, scaledSource, op.getDest(),

149 offsets);

150

151 return success();

152 }

153

154

155 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;

156 off += stride, ++idx) {

157

158 Value extractedSource =

159 rewriter.create(loc, op.getValueToStore(), idx);

160 if (isa(extractedSource.getType())) {

161

162

163 Value extractedDest =

164 rewriter.create(loc, op.getDest(), off);

165

166

167 extractedSource = rewriter.create(

168 loc, extractedSource, extractedDest,

171 }

172

173 res = rewriter.create(loc, extractedSource, res, off);

174 }

175

177 return success();

178 }

179 };

180

181

182

185 public:

187

190 auto dstType = op.getType();

191 auto srcType = op.getSourceVectorType();

192

193

194 if (dstType.isScalable() || srcType.isScalable())

195 return failure();

196

197 assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");

198

199 int64_t offset =

200 cast(op.getOffsets().getValue().front()).getInt();

201 int64_t size = cast(op.getSizes().getValue().front()).getInt();

202 int64_t stride =

203 cast(op.getStrides().getValue().front()).getInt();

204

205 assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());

206

207

208 if (op.getOffsets().getValue().size() != 1)

209 return failure();

210

212 offsets.reserve(size);

213 for (int64_t off = offset, e = offset + size * stride; off < e;

214 off += stride)

215 offsets.push_back(off);

217 op.getVector(), offsets);

218 return success();

219 }

220 };

221

222

223

224

227 public:

230 std::function<bool(ExtractStridedSliceOp)> controlFn,

232 : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

233

236 if (controlFn && !controlFn(op))

237 return failure();

238

239

240 if (op.getOffsets().getValue().size() != 1)

241 return failure();

242

243 int64_t offset =

244 cast(op.getOffsets().getValue().front()).getInt();

245 int64_t size = cast(op.getSizes().getValue().front()).getInt();

246 int64_t stride =

247 cast(op.getStrides().getValue().front()).getInt();

248

251 elements.reserve(size);

252 for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)

253 elements.push_back(rewriter.create(loc, op.getVector(), i));

254

255 Value result = rewriter.createarith::ConstantOp(

256 loc, rewriter.getZeroAttr(op.getType()));

257 for (int64_t i = 0; i < size; ++i)

258 result = rewriter.create(loc, elements[i], result, i);

259

261 return success();

262 }

263

264 private:

265 std::function<bool(ExtractStridedSliceOp)> controlFn;

266 };

267

268

269

270

273 public:

275

277

278

279 setHasBoundedRewriteRecursion();

280 }

281

284 auto dstType = op.getType();

285

286 assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");

287

288 int64_t offset =

289 cast(op.getOffsets().getValue().front()).getInt();

290 int64_t size = cast(op.getSizes().getValue().front()).getInt();

291 int64_t stride =

292 cast(op.getStrides().getValue().front()).getInt();

293

294 auto loc = op.getLoc();

295 auto elemType = dstType.getElementType();

296 assert(elemType.isSignlessIntOrIndexOrFloat());

297

298

299

300 if (op.getOffsets().getValue().size() == 1)

301 return failure();

302

303

304 Value zero = rewriter.createarith::ConstantOp(

305 loc, elemType, rewriter.getZeroAttr(elemType));

306 Value res = rewriter.create(loc, dstType, zero);

307 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;

308 off += stride, ++idx) {

309 Value one = rewriter.create(loc, op.getVector(), off);

310 Value extracted = rewriter.create(

311 loc, one, getI64SubArray(op.getOffsets(), 1),

314 res = rewriter.create(loc, extracted, res, idx);

315 }

317 return success();

318 }

319 };

320

321

322

323 void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(

327 }

328

329 void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(

331 std::function<bool(ExtractStridedSliceOp)> controlFn,

334 patterns.getContext(), std::move(controlFn), benefit);

335 }

336

337

338 void vector::populateVectorInsertExtractStridedSliceTransforms(

340 populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns,

341 benefit);

344 benefit);

345

346

347 populateVectorExtractStridedSliceToExtractInsertChainPatterns(

349

350 [](ExtractStridedSliceOp op) {

351 return op.getType().isScalable() ||

352 op.getSourceVectorType().isScalable();

353 },

354 benefit);

355 }

RewritePattern for InsertStridedSliceOp where source and destination vectors have the same rank.

LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override

RewritePattern for InsertStridedSliceOp where source and destination vectors have different ranks.

LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override

TypedAttr getZeroAttr(Type type)

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.

MLIRContext is the top-level object for a collection of MLIR operations.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements).

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (replacement).

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

Type getType() const

Return the type of this value.

Include the generated interface declarations.

SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)

Helper to return a subset of arrayAttr as a vector of int64_t.

const FrozenRewritePatternSet & patterns

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.