MLIR: lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
17
18 using namespace mlir;
20
21
22
23
24
25
26
27
28
29
30
31
32
33
// Rewrite for an InsertStridedSliceOp whose destination vector has a higher
// rank than its source vector. The insertion is split into three steps:
// extract the destination sub-vector addressed by the leading offsets, insert
// the source into it with the trailing offsets, then write it back.
//
// NOTE(review): this listing is a lossy extract. The class header, the
// matchAndRewrite signature, the template arguments of `rewriter.create` and
// several argument lines are missing. The op kinds described below are
// inferred from the visible arguments; confirm against the full source.
36 public:
38
41 auto srcType = op.getSourceVectorType();
42 auto dstType = op.getDestVectorType();
43
// Bail out when the op carries no offsets.
44 if (op.getOffsets().getValue().empty())
45 return failure();
46
47 auto loc = op.getLoc();
// Number of leading destination dimensions not covered by the source.
48 int64_t rankDiff = dstType.getRank() - srcType.getRank();
49 assert(rankDiff >= 0);
// Equal-rank insertions are not handled by this pattern.
50 if (rankDiff == 0)
51 return failure();
52
// Equals the source rank: dstRank - (dstRank - srcRank).
53 int64_t rankRest = dstType.getRank() - rankDiff;
54
55
// Step 1: take the destination sub-vector addressed by the leading offsets.
// Presumably the missing line 58 passes the offsets with the trailing
// `rankRest` entries dropped; confirm.
56 Value extracted = rewriter.create(
57 loc, op.getDest(),
59 rankRest));
60
61
62
// Step 2: insert the source into that sub-vector, using the offsets with the
// first `rankDiff` entries dropped.
63 auto stridedSliceInnerOp = rewriter.create(
64 loc, op.getValueToStore(), extracted,
65 getI64SubArray(op.getOffsets(), rankDiff),
67
// Step 3: write the updated sub-vector back into the destination. The call
// head (line 68) is missing; the leading `op` argument suggests it replaces
// the original op.
69 op, stridedSliceInnerOp.getResult(), op.getDest(),
71 rankRest));
72 return success();
73 }
74 };
75
76
77
78
79
80
81
82
83
// Rewrite for an InsertStridedSliceOp whose source and destination vectors
// have the same rank.
// - Rank 1: presumably two vector shuffles. The visible calls take two
//   operands plus an index mask.
// - Higher ranks: the leading dimension is walked row by row. Vector rows are
//   inserted with a lower-rank strided insert into the matching destination
//   row.
//
// NOTE(review): this listing is a lossy extract. Missing pieces include:
// - the class header and the constructor / matchAndRewrite signatures;
// - the `offsets` declaration (line 131);
// - template arguments of `rewriter.create`;
// - the final replacement of `op` with `res` (line 176).
// Op kinds below are inferred; confirm against the full source.
86 public:
88
90
91
// The rewrite can emit ops this pattern matches again; declare that the
// recursion is bounded.
92 setHasBoundedRewriteRecursion();
93 }
94
97 auto srcType = op.getSourceVectorType();
98 auto dstType = op.getDestVectorType();
99 int64_t srcRank = srcType.getRank();
100
101
// Scalable vectors are rejected only in the 1-D case, i.e. the case that
// takes the mask-based path below.
102 if ((srcType.isScalable() || dstType.isScalable()) && srcRank == 1)
103 return failure();
104
105 if (op.getOffsets().getValue().empty())
106 return failure();
107
// Only same-rank insertions are handled here.
108 int64_t dstRank = dstType.getRank();
109 assert(dstRank >= srcRank);
110 if (dstRank != srcRank)
111 return failure();
112
// When source and destination types match, the result is the stored value.
113 if (srcType == dstType) {
114 rewriter.replaceOp(op, op.getValueToStore());
115 return success();
116 }
117
// Offset, extent and stride along the leading dimension.
118 int64_t offset =
119 cast(op.getOffsets().getValue().front()).getInt();
120 int64_t size = srcType.getShape().front();
121 int64_t stride =
122 cast(op.getStrides().getValue().front()).getInt();
123
124 auto loc = op.getLoc();
125 Value res = op.getDest();
126
// 1-D case.
127 if (srcRank == 1) {
128 int nSrc = srcType.getShape().front();
129 int nDest = dstType.getShape().front();
130
// NOTE(review): `offsets` is declared on a line missing from this extract
// (131). The indexed writes below require it to hold at least nSrc entries;
// presumably it is pre-sized to nDest.
132 for (int64_t i = 0; i < nSrc; ++i)
133 offsets[i] = i;
// Widen the source to the destination length, using the source as both
// operands; its first nSrc lanes are the source elements in order.
134 Value scaledSource = rewriter.create(
135 loc, op.getValueToStore(), op.getValueToStore(), offsets);
136
137
138
// Build the final mask. Lanes on the strided slice take element
// (i - offset) / stride of the widened source (first operand). Every other
// lane keeps destination element i, addressed as nDest + i in the second
// operand.
139 offsets.clear();
140 for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
141 if (i < offset || i >= e || (i - offset) % stride != 0)
142 offsets.push_back(nDest + i);
143 else
144 offsets.push_back((i - offset) / stride);
145 }
146
147
148 rewriter.replaceOpWithNewOp(op, scaledSource, op.getDest(),
149 offsets);
150
151 return success();
152 }
153
154
// Rank > 1: visit destination rows offset, offset + stride, ...; `idx` is
// the matching source row.
155 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
156 off += stride, ++idx) {
157
158 Value extractedSource =
159 rewriter.create(loc, op.getValueToStore(), idx);
// Presumably checks whether the extracted row is itself a vector.
160 if (isa(extractedSource.getType())) {
161
162
// Insert the source row into the matching destination row with a
// lower-rank strided insert. The remaining offset/stride arguments are on
// missing lines 169-170.
163 Value extractedDest =
164 rewriter.create(loc, op.getDest(), off);
165
166
167 extractedSource = rewriter.create(
168 loc, extractedSource, extractedDest,
171 }
172
// Place the (possibly updated) row at position `off` of the running result.
173 res = rewriter.create(loc, extractedSource, res, off);
174 }
175
// NOTE(review): line 176, presumably the replacement of `op` with `res`, is
// missing from this extract.
177 return success();
178 }
179 };
180
181
182
// Rewrite for a 1-D ExtractStridedSliceOp on fixed-size vectors. The slice
// becomes a single index mask over the input vector; the replacement op is
// presumably a shuffle, whose name is on a missing line.
//
// NOTE(review): this listing is a lossy extract. Missing pieces: the class
// header, the matchAndRewrite signature, the `offsets` declaration (line
// 211) and the head of the replacement call (line 216). Confirm against the
// full source.
185 public:
187
190 auto dstType = op.getType();
191 auto srcType = op.getSourceVectorType();
192
193
// Scalable vectors are not handled by this lowering.
194 if (dstType.isScalable() || srcType.isScalable())
195 return failure();
196
197 assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
198
// Offset, size and stride of the leading dimension.
199 int64_t offset =
200 cast(op.getOffsets().getValue().front()).getInt();
201 int64_t size = cast(op.getSizes().getValue().front()).getInt();
202 int64_t stride =
203 cast(op.getStrides().getValue().front()).getInt();
204
205 assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
206
207
// Only 1-D slices (a single offset) are handled here.
208 if (op.getOffsets().getValue().size() != 1)
209 return failure();
210
// Mask entries: offset, offset + stride, ..., `size` entries in total.
212 offsets.reserve(size);
213 for (int64_t off = offset, e = offset + size * stride; off < e;
214 off += stride)
215 offsets.push_back(off);
217 op.getVector(), offsets);
218 return success();
219 }
220 };
221
222
223
224
// Rewrite for a 1-D ExtractStridedSliceOp as a chain of per-element
// operations. Each selected element is extracted from the input and inserted
// into a zero-initialized vector of the result type. An optional `controlFn`
// lets callers restrict which ops are rewritten.
//
// NOTE(review): this listing is a lossy extract. Missing pieces:
// - the class header and constructor signature lines (228-229, 231);
// - the matchAndRewrite signature;
// - the `loc`/`elements` declarations (249-250);
// - the final replacement of `op` (260).
// Confirm against the full source.
227 public:
230 std::function<bool(ExtractStridedSliceOp)> controlFn,
232 : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
233
// Respect the caller-supplied filter, if one was given.
236 if (controlFn && !controlFn(op))
237 return failure();
238
239
// Only 1-D slices (a single offset) are handled here.
240 if (op.getOffsets().getValue().size() != 1)
241 return failure();
242
243 int64_t offset =
244 cast(op.getOffsets().getValue().front()).getInt();
245 int64_t size = cast(op.getSizes().getValue().front()).getInt();
246 int64_t stride =
247 cast(op.getStrides().getValue().front()).getInt();
248
// Extract the selected elements: positions offset, offset + stride, ...
251 elements.reserve(size);
252 for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)
253 elements.push_back(rewriter.create(loc, op.getVector(), i));
254
// Start from an all-zero vector of the result type and insert element i at
// position i.
255 Value result = rewriter.createarith::ConstantOp(
256 loc, rewriter.getZeroAttr(op.getType()));
257 for (int64_t i = 0; i < size; ++i)
258 result = rewriter.create(loc, elements[i], result, i);
259
// NOTE(review): line 260, presumably the replacement of `op` with `result`,
// is missing from this extract.
261 return success();
262 }
263
264 private:
// Optional filter. When empty, the pattern is not restricted by it.
265 std::function<bool(ExtractStridedSliceOp)> controlFn;
266 };
267
268
269
270
// Rewrite for an ExtractStridedSliceOp with more than one offset. The slice
// is decomposed along the leading dimension:
// - each selected input row is extracted;
// - its remaining dimensions are sliced with a lower-rank strided extract;
// - the result is inserted into a zero-filled result vector.
//
// NOTE(review): this listing is a lossy extract. Missing pieces:
// - the class header and the constructor / matchAndRewrite signatures;
// - template arguments of `rewriter.create`;
// - the sizes/strides arguments (312-313);
// - the final replacement of `op` (316).
// Op kinds below are inferred; confirm against the full source.
273 public:
275
277
278
// The rewrite presumably emits lower-rank ops of the same kind, which this
// pattern may match again; declare that the recursion is bounded.
279 setHasBoundedRewriteRecursion();
280 }
281
284 auto dstType = op.getType();
285
286 assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
287
// Offset, size and stride along the leading dimension.
288 int64_t offset =
289 cast(op.getOffsets().getValue().front()).getInt();
290 int64_t size = cast(op.getSizes().getValue().front()).getInt();
291 int64_t stride =
292 cast(op.getStrides().getValue().front()).getInt();
293
294 auto loc = op.getLoc();
295 auto elemType = dstType.getElementType();
296 assert(elemType.isSignlessIntOrIndexOrFloat());
297
298
299
// Slices with a single offset are not decomposed by this pattern.
300 if (op.getOffsets().getValue().size() == 1)
301 return failure();
302
303
// Zero-filled result vector of the destination type, built from a scalar
// zero.
304 Value zero = rewriter.createarith::ConstantOp(
305 loc, elemType, rewriter.getZeroAttr(elemType));
306 Value res = rewriter.create(loc, dstType, zero);
// For each selected input row `off`:
// - slice its remaining dimensions using the offsets with the first entry
//   dropped;
// - write the slice to row `idx` of the result.
307 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
308 off += stride, ++idx) {
309 Value one = rewriter.create(loc, op.getVector(), off);
310 Value extracted = rewriter.create(
311 loc, one, getI64SubArray(op.getOffsets(), 1),
314 res = rewriter.create(loc, extracted, res, idx);
315 }
// NOTE(review): line 316, presumably the replacement of `op` with `res`, is
// missing from this extract.
317 return success();
318 }
319 };
320
321
322
// Judging by its name, registers the strided-slice decomposition patterns.
// NOTE(review): the parameter list and body (lines 324-326) are missing from
// this extract. Confirm which patterns are added against the full source.
323 void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
327 }
328
// Adds a pattern built from the patterns' context, the optional `controlFn`
// filter (moved in) and `benefit`. This presumably adds the extract/insert
// chain pattern that takes a `controlFn`.
// NOTE(review): lines 330 and 332-333 (remaining parameters and the call
// head) are missing from this extract.
329 void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
331 std::function<bool(ExtractStridedSliceOp)> controlFn,
334 patterns.getContext(), std::move(controlFn), benefit);
335 }
336
337
// Populates the strided-slice lowerings. It registers:
// - the decomposition patterns;
// - a further pattern set whose call head is missing;
// - the extract/insert chain lowering, restricted to scalable vectors.
// NOTE(review): lines 339, 342-343 and 348 are missing from this extract.
338 void vector::populateVectorInsertExtractStridedSliceTransforms(
340 populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns,
341 benefit);
344 benefit);
345
346
// Only ops whose result or source vector is scalable go through the
// extract/insert chain. The 1-D extract pattern earlier in this file bails
// out on scalable types.
347 populateVectorExtractStridedSliceToExtractInsertChainPatterns(
349
350 [](ExtractStridedSliceOp op) {
351 return op.getType().isScalable() ||
352 op.getSourceVectorType().isScalable();
353 },
354 benefit);
355 }
RewritePattern for InsertStridedSliceOp where source and destination vectors have the same rank.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for InsertStridedSliceOp where source and destination vectors have different ranks.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
TypedAttr getZeroAttr(Type type)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Include the generated interface declarations.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...