MLIR: lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
22
// Pull in the generated definitions for this pass from Passes.h.inc. Per the
// MLIR pass-definition convention, defining GEN_PASS_DEF_<PASS> before the
// include selects that pass's section, which provides the base class used
// below as bufferization::impl::BufferDeallocationSimplificationPassBase.
// NOTE(review): this copy shows a single `namespace mlir {` but two closing
// braces. The inner namespace opener (original line 24, presumably
// `namespace bufferization {`) appears to have been lost — TODO confirm.
23namespace mlir {
25#define GEN_PASS_DEF_BUFFERDEALLOCATIONSIMPLIFICATIONPASS
26#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
27}
28}
29
30using namespace mlir;
32
33
34
35
36
37
38
// getViewBase: given a memref value, return its "base" value. It walks up the
// chain of ViewLikeOpInterface defining ops (if any), moving from each view
// result to that op's view source.
// NOTE(review): the signature line (original line 39,
// `static Value getViewBase(Value value)` per the declaration) is missing from
// this copy. Also, `getDefiningOp()` appears to have lost its template argument
// (presumably `<ViewLikeOpInterface>`, since the result is used via
// getViewDest()/getViewSource()) — TODO confirm.
40 while (auto viewLikeOp = value.getDefiningOp()) {
// Only step through the op when `value` is the op's view result; otherwise
// `value` is already the base.
41 if (value != viewLikeOp.getViewDest()) {
42 break;
43 }
44 value = viewLikeOp.getViewSource();
45 }
46 return value;
47}
48
// updateDeallocIfChanged: overwrite the memref and condition operand lists of
// `deallocOp` with `memrefs` / `conditions`. If both lists already equal the
// op's current operands, it returns failure() and leaves the op untouched, so
// a caller can tell "no change" apart from an actual update.
// NOTE(review): several lines are missing from this copy — TODO confirm:
//   - the signature (original lines 49-52);
//   - the opener of the in-place modification (original line 57, presumably
//     `rewriter.modifyOpInPlace(deallocOp, [&]() {`, matching the `});`
//     below);
//   - the final return (original line 61, presumably `return success();`).
53 if (deallocOp.getMemrefs() == memrefs &&
54 deallocOp.getConditions() == conditions)
55 return failure();
56
58 deallocOp.getMemrefsMutable().assign(memrefs);
59 deallocOp.getConditionsMutable().assign(conditions);
60 });
62}
63
64
65
66
67
68
69
// distinctAllocAndBlockArgument(v1, v2): return "true" only if the two values
// are guaranteed to be different (and non-aliasing) allocations. "false" means
// "not proven distinct", not "aliasing".
// NOTE(review): several lines are missing from this copy — TODO confirm
// against upstream:
//   - original lines 70-72: the signature and the `v1Base`/`v2Base` locals,
//     presumably computed with getViewBase;
//   - original lines 74-75: these bind `op` and, judging by the `hasEffect`
//     tooltip, check that it allocates `v1`.
// `areDistinct(a, b)` holds when `b` is a block argument whose owning block
// contains `op` or an ancestor of `op`.
73 auto areDistinct = [](Value v1, Value v2) {
// NOTE(review): `dyn_cast(v2)` has lost its template argument (presumably
// `dyn_cast<BlockArgument>`, given the `bbArg.getOwner()` use below).
76 if (auto bbArg = dyn_cast(v2))
77 if (bbArg.getOwner()->findAncestorOpInBlock(*op))
78 return true;
79 return false;
80 };
// `areDistinct` is not symmetric, so try both argument orders.
81 return areDistinct(v1Base, v2Base) || areDistinct(v2Base, v1Base);
82}
83
84
85
86
// potentiallyAliasesMemref: conservatively check whether `memref` may alias
// any value in `otherList`. An inconclusive analysis result (empty optional)
// is treated as "may alias".
// NOTE(review): the signature (original lines 87-88) and the condition that
// guards the `continue` (original line 90) are missing from this copy — TODO
// confirm against upstream.
89 for (auto other : otherList) {
91 continue;
// NOTE(review): `std::optional` has lost its template argument (presumably
// `<bool>`, given the `== true` comparison below).
92 std::optional analysisResult =
93 analysis.isSameAllocation(other, memref);
// Anything other than a definite "different allocation" counts as aliasing.
94 if (!analysisResult.has_value() || analysisResult == true)
95 return true;
96 }
97 return false;
98}
99
100
101
102
103
104namespace {
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// Pattern: drop a (memref, cond) pair from a dealloc op when the
// BufferOriginAnalysis proves that the memref is the same allocation as at
// least one retained value, and gives a definite answer for every retained
// value. `cond` is then ORed into the updated condition of each retained
// value that matches.
// NOTE(review): the base-class line (original line 136, presumably
// `: public OpRewritePattern<DeallocOp> {`) is missing from this copy.
135struct RemoveDeallocMemrefsContainedInRetained
137 RemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
138 BufferOriginAnalysis &analysis)
139 : OpRewritePattern(context), analysis(analysis) {}
140
141
142
143
144
145
146
// Try to remove `memref` (guarded by `cond`) from `deallocOp`. The visible
// code returns failure() before creating any op if a retained value gives an
// inconclusive result, or if no retained value is definitely the same
// allocation as `memref`.
// NOTE(review): original lines 149, 176 and 181 are missing here. The
// `disjunction);` tail at line 177 suggests that line 176 redirects uses of
// `updatedCondition` to `disjunction` (cf. replaceAllUsesExcept). The
// function presumably ends with `return success();` — TODO confirm.
147 LogicalResult handleOneMemref(DeallocOp deallocOp, Value memref, Value cond,
148 PatternRewriter &rewriter) const {
150
151
152
153
// Phase 1 (queries only): every retained value must give a definite answer,
// and at least one must be the same allocation as `memref`.
154 bool atLeastOneMustAlias = false;
155 for (Value retained : deallocOp.getRetained()) {
156 std::optional analysisResult =
157 analysis.isSameAllocation(retained, memref);
158 if (!analysisResult.has_value())
159 return failure();
160 if (analysisResult == true)
161 atLeastOneMustAlias = true;
162 }
163 if (!atLeastOneMustAlias)
164 return failure();
165
166
167
168
// Phase 2 (creates IR): for each retained value that is the same allocation,
// combine its current updated condition with `cond` using arith.ori.
169 for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) {
170 Value updatedCondition = deallocOp.getUpdatedConditions()[i];
171 std::optional analysisResult =
172 analysis.isSameAllocation(retained, memref);
173 if (analysisResult == true) {
174 auto disjunction = arith::OrIOp::create(rewriter, deallocOp.getLoc(),
175 updatedCondition, cond);
177 disjunction);
178 }
179 }
180
182 }
183
184 LogicalResult matchAndRewrite(DeallocOp deallocOp,
185 PatternRewriter &rewriter) const override {
186
187
// Bail out when the retained list contains the same value more than once
// (the DenseSet is then smaller than the operand list).
188 DenseSet retained(deallocOp.getRetained().begin(),
189 deallocOp.getRetained().end());
190 if (retained.size() != deallocOp.getRetained().size())
191 return failure();
192
// For each (memref, cond) pair, first try the memref itself. If that fails
// and the memref is produced by memref.extract_strided_metadata, try that
// op's source operand. Pairs that cannot be removed are kept.
193 SmallVector newMemrefs, newConditions;
194 for (auto [memref, cond] :
195 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
196
197 if (succeeded(handleOneMemref(deallocOp, memref, cond, rewriter)))
198 continue;
199
// NOTE(review): the template brackets were stripped here (presumably
// `getDefiningOp<memref::ExtractStridedMetadataOp>()`).
200 if (auto extractOp =
201 memref.getDefiningOpmemref::ExtractStridedMetadataOp())
202 if (succeeded(handleOneMemref(deallocOp, extractOp.getOperand(), cond,
203 rewriter)))
204 continue;
205
206 newMemrefs.push_back(memref);
207 newConditions.push_back(cond);
208 }
// NOTE(review): original line 212 is missing. Given the `rewriter);` tail
// and the helper's declared signature, it presumably reads
// `return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,`
// — TODO confirm.
209
210
211
213 rewriter);
214 }
215
216private:
// Shared analysis owned by the pass; this pattern only holds a reference.
217 BufferOriginAnalysis &analysis;
218};
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
// Pattern: drop retained values that fail the alias check on original
// line 247. That check is presumably
// `potentiallyAliasesMemref(analysis, deallocOp.getMemrefs(), retainedMemref)`,
// i.e. the dropped values cannot alias any memref being deallocated.
// Each dropped retained value's updated condition is replaced by a constant
// `false`. The dealloc is rebuilt with the kept retained values, whose updated
// conditions are forwarded in the original order.
// NOTE(review): original lines 237 (base class), 247 (the alias check
// described above) and 271 (presumably `return success();`) are missing from
// this copy — TODO confirm.
236struct RemoveRetainedMemrefsGuaranteedToNotAlias
238 RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context,
239 BufferOriginAnalysis &analysis)
240 : OpRewritePattern(context), analysis(analysis) {}
241
242 LogicalResult matchAndRewrite(DeallocOp deallocOp,
243 PatternRewriter &rewriter) const override {
// `replacements` has one slot per original retained value. A null Value marks
// a kept value (filled from the rebuilt op below); otherwise the slot holds
// the constant `false`.
244 SmallVector newRetainedMemrefs, replacements;
245
246 for (auto retainedMemref : deallocOp.getRetained()) {
248 retainedMemref)) {
249 newRetainedMemrefs.push_back(retainedMemref);
250 replacements.push_back({});
251 continue;
252 }
253
254 replacements.push_back(arith::ConstantOp::create(
255 rewriter, deallocOp.getLoc(), rewriter.getBoolAttr(false)));
256 }
257
// Nothing was dropped. In that case no constant was created either, so
// failing here leaves the IR untouched.
258 if (newRetainedMemrefs.size() == deallocOp.getRetained().size())
259 return failure();
260
261 auto newDeallocOp =
262 DeallocOp::create(rewriter, deallocOp.getLoc(), deallocOp.getMemrefs(),
263 deallocOp.getConditions(), newRetainedMemrefs);
// Fill the null slots, in order, with the rebuilt op's updated conditions
// (one per entry of newRetainedMemrefs).
264 int i = 0;
265 for (auto &repl : replacements) {
266 if (!repl)
267 repl = newDeallocOp.getUpdatedConditions()[i++];
268 }
269
270 rewriter.replaceOp(deallocOp, replacements);
272 }
273
274private:
// Shared analysis owned by the pass; this pattern only holds a reference.
275 BufferOriginAnalysis &analysis;
276};
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
// Pattern: move each (memref, cond) pair whose memref passes the alias check
// on original line 325 into its own single-memref dealloc with the same
// retained list. The check is presumably "does not alias any other memref of
// this dealloc". The remaining pairs stay together in one rebuilt dealloc.
// Each retained value's final updated condition is the arith.ori of its
// updated conditions across all the resulting deallocs.
// NOTE(review): several lines are missing from this copy — TODO confirm:
//   - original line 306: the base class;
//   - original line 325: the guard of the "keep" branch, presumably
//     `if (potentiallyAliasesMemref(analysis, otherMemrefs, memref)) {`;
//   - original line 360: presumably `return success();`.
305struct SplitDeallocWhenNotAliasingAnyOther
307 SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context,
308 BufferOriginAnalysis &analysis)
309 : OpRewritePattern(context), analysis(analysis) {}
310
311 LogicalResult matchAndRewrite(DeallocOp deallocOp,
312 PatternRewriter &rewriter) const override {
313 Location loc = deallocOp.getLoc();
// A dealloc with zero or one memref has nothing to split.
314 if (deallocOp.getMemrefs().size() <= 1)
315 return failure();
316
317 SmallVector remainingMemrefs, remainingConditions;
// Updated conditions of every split-off dealloc, one inner vector per new op.
318 SmallVector<SmallVector> updatedConditions;
319 for (int64_t i = 0, e = deallocOp.getMemrefs().size(); i < e; ++i) {
320 Value memref = deallocOp.getMemrefs()[i];
321 Value cond = deallocOp.getConditions()[i];
// All memrefs of the op except the i-th one.
322 SmallVector otherMemrefs(deallocOp.getMemrefs());
323 otherMemrefs.erase(otherMemrefs.begin() + i);
324
326
327 remainingMemrefs.push_back(memref);
328 remainingConditions.push_back(cond);
329 continue;
330 }
331
332
// This pair is split off into a dedicated dealloc. Record that dealloc's
// updated conditions so they can be merged below.
333 auto newDeallocOp = DeallocOp::create(rewriter, loc, memref, cond,
334 deallocOp.getRetained());
335 updatedConditions.push_back(
336 llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions())));
337 }
338
339
// Nothing was split off. No dealloc was created in the loop, so failing here
// leaves the IR unchanged.
340 if (remainingMemrefs.size() == deallocOp.getMemrefs().size())
341 return failure();
342
343
// Rebuild the dealloc from the pairs that were not split off.
344 auto newDeallocOp =
345 DeallocOp::create(rewriter, loc, remainingMemrefs, remainingConditions,
346 deallocOp.getRetained());
347
348
// For each retained value, OR together the updated conditions of the rebuilt
// op and of every split-off dealloc.
// NOTE(review): `auto additionalConditions` copies each inner SmallVector on
// every iteration; `const auto &additionalConditions` would avoid the copy.
349 SmallVector replacements =
350 llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions()));
351 for (auto additionalConditions : updatedConditions) {
352 assert(replacements.size() == additionalConditions.size() &&
353 "expected same number of updated conditions");
354 for (int64_t i = 0, e = replacements.size(); i < e; ++i) {
355 replacements[i] = arith::OrIOp::create(rewriter, loc, replacements[i],
356 additionalConditions[i]);
357 }
358 }
359 rewriter.replaceOp(deallocOp, replacements);
361 }
362
363private:
// Shared analysis owned by the pass; this pattern only holds a reference.
364 BufferOriginAnalysis &analysis;
365};
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
// Pattern: consider a memref whose condition passes the guard on original
// line 405 (presumably "is constant true"). If it is the same allocation as a
// retained value, either directly or via the source of a
// memref.extract_strided_metadata op, the memref is removed from the dealloc.
// The rewrite only proceeds when every retained value was matched this way
// (`aliasesWithConstTrueMemref.all()`).
// NOTE(review): several lines are missing from this copy — TODO confirm:
//   - original line 391: the base class;
//   - original line 405: the guard before the first `continue`, presumably
//     `if (!matchPattern(cond, m_One()))`;
//   - original lines 411 and 427: presumably
//     `rewriter.replaceAllUsesWith(res, cond);`;
//   - original line 441: presumably
//     `return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,`.
390struct RetainedMemrefAliasingAlwaysDeallocatedMemref
392 RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context,
393 BufferOriginAnalysis &analysis)
394 : OpRewritePattern(context), analysis(analysis) {}
395
396 LogicalResult matchAndRewrite(DeallocOp deallocOp,
397 PatternRewriter &rewriter) const override {
// Bit i is set once retained value i has been matched to a removable memref.
398 BitVector aliasesWithConstTrueMemref(deallocOp.getRetained().size());
399 SmallVector newMemrefs, newConditions;
400 for (auto [memref, cond] :
401 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
402 bool canDropMemref = false;
403 for (auto [i, retained, res] : llvm::enumerate(
404 deallocOp.getRetained(), deallocOp.getUpdatedConditions())) {
406 continue;
407
408 std::optional analysisResult =
409 analysis.isSameAllocation(retained, memref);
410 if (analysisResult == true) {
// NOTE(review): if the missing lines 411/427 do rewrite uses of `res` as
// presumed, the IR is mutated here, before the `.all()` check below. The
// pattern could then return failure() after changing the IR, which violates
// the rewrite-pattern contract (a failed match must leave the IR untouched).
// Consider collecting the (res, cond) pairs and applying the replacements only
// after `.all()` succeeds.
412 aliasesWithConstTrueMemref[i] = true;
413 canDropMemref = true;
414 continue;
415 }
416
417
418
// Fallback: if the memref comes from memref.extract_strided_metadata, compare
// the retained value against that op's source operand instead.
// NOTE(review): template brackets stripped below (presumably
// `getDefiningOp<memref::ExtractStridedMetadataOp>()`).
419 auto extractOp =
420 memref.getDefiningOpmemref::ExtractStridedMetadataOp();
421 if (!extractOp)
422 continue;
423
424 std::optional extractAnalysisResult =
425 analysis.isSameAllocation(retained, extractOp.getOperand());
426 if (extractAnalysisResult == true) {
428 aliasesWithConstTrueMemref[i] = true;
429 canDropMemref = true;
430 }
431 }
432
433 if (!canDropMemref) {
434 newMemrefs.push_back(memref);
435 newConditions.push_back(cond);
436 }
437 }
// Only proceed when every retained value was matched.
438 if (!aliasesWithConstTrueMemref.all())
439 return failure();
440
442 rewriter);
443 }
444
445private:
// Shared analysis owned by the pass; this pattern only holds a reference.
446 BufferOriginAnalysis &analysis;
447};
448
449}
450
451
452
453
454
455namespace {
456
457
458
459
// Pass driver. It builds one BufferOriginAnalysis for the current operation
// and greedily applies the patterns registered in `patterns.add<...>` with
// normal region simplification. It signals pass failure if the greedy driver
// fails. The patterns hold a reference to `analysis`, which stays alive
// because both the analysis and the greedy rewrite live in this function.
// NOTE(review): several lines are missing from this copy — TODO confirm:
//   - original line 465: presumably the `RewritePatternSet patterns` declaration;
//   - original line 470: presumably the trailing `analysis);` argument of
//     patterns.add;
//   - original line 472: presumably a call adding the dealloc canonicalization
//     patterns, cf. populateDeallocOpCanonicalizationPatterns;
//   - original line 476: presumably the `if (failed(applyPatternsGreedily(`
//     opener.
460struct BufferDeallocationSimplificationPass
461 : public bufferization::impl::BufferDeallocationSimplificationPassBase<
462 BufferDeallocationSimplificationPass> {
463 void runOnOperation() override {
464 BufferOriginAnalysis analysis(getOperation());
466 patterns.add<RemoveDeallocMemrefsContainedInRetained,
467 RemoveRetainedMemrefsGuaranteedToNotAlias,
468 SplitDeallocWhenNotAliasingAnyOther,
469 RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
471
473
474
475
477 getOperation(), std::move(patterns),
478 GreedyRewriteConfig().setRegionSimplificationLevel(
479 GreedySimplifyRegionLevel::Normal))))
480 signalPassFailure();
481 }
482};
483
484}
static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis, ValueRange otherList, Value memref)
Checks if memref may potentially alias a MemRef in otherList.
Definition BufferDeallocationSimplification.cpp:87
static Value getViewBase(Value value)
Given a memref value, return the "base" value by skipping over all ViewLikeOpInterface ops (if any) i...
Definition BufferDeallocationSimplification.cpp:39
static bool distinctAllocAndBlockArgument(Value v1, Value v2)
Return "true" if the given values are guaranteed to be different (and non-aliasing) allocations based...
Definition BufferDeallocationSimplification.cpp:70
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, ValueRange memrefs, ValueRange conditions, PatternRewriter &rewriter)
Definition BufferDeallocationSimplification.cpp:49
An is-same-buffer analysis that checks if two SSA values belong to the same buffer allocation or not.
BoolAttr getBoolAttr(bool value)
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void populateDeallocOpCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context)
Add the canonicalization patterns for bufferization.dealloc to the given pattern set to make them ava...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
bool hasEffect(Operation *op)
Returns "true" if op has an effect of type EffectTy.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...