MLIR: lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

22

23namespace mlir {

25#define GEN_PASS_DEF_BUFFERDEALLOCATIONSIMPLIFICATIONPASS

26#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"

27}

28}

29

30using namespace mlir;

32

33

34

35

36

37

38

40 while (auto viewLikeOp = value.getDefiningOp()) {

41 if (value != viewLikeOp.getViewDest()) {

42 break;

43 }

44 value = viewLikeOp.getViewSource();

45 }

46 return value;

47}

48

53 if (deallocOp.getMemrefs() == memrefs &&

54 deallocOp.getConditions() == conditions)

55 return failure();

56

58 deallocOp.getMemrefsMutable().assign(memrefs);

59 deallocOp.getConditionsMutable().assign(conditions);

60 });

62}

63

64

65

66

67

68

69

/// Fragment of `static bool distinctAllocAndBlockArgument(Value v1, Value v2)`.
/// Per the tooltip list at the end of this page, it is defined at source line
/// 70. It returns "true" if the given values are guaranteed to be different
/// (and non-aliasing) allocations. The result is symmetric: `areDistinct` is
/// tried in both argument orders on `v1Base` and `v2Base`.
///
/// `areDistinct(a, b)` returns true only when both of these hold:
///   - `b` is a block argument;
///   - `findAncestorOpInBlock(*op)` on its owner block finds an op, i.e. (as
///     the name suggests) `op` is nested within the block that owns `b`.
/// `op` is bound on a missing line, presumably to the defining op of `a`.
///
/// NOTE(review): this is a scrape of a Doxygen listing. Missing source lines:
///   - 70-72: the signature and, presumably, the declarations of
///     `v1Base`/`v2Base`. Going by the names, these are probably the view
///     bases of `v1`/`v2` via getViewBase().
///   - 74-75: inside the lambda; presumably the binding of `op` and an
///     allocation-effect check (the tooltip list mentions `hasEffect`).
/// The `dyn_cast` template argument is also stripped; presumably it is
/// `dyn_cast<BlockArgument>`, given `bbArg.getOwner()`. Restore this from
/// upstream rather than by hand.
/// NOTE(review): no call to this helper appears in the visible lines of this
/// listing. Confirm it is still used; otherwise it is dead code.
73 auto areDistinct = [](Value v1, Value v2) {

76 if (auto bbArg = dyn_cast(v2))

77 if (bbArg.getOwner()->findAncestorOpInBlock(*op))

78 return true;

79 return false;

80 };

81 return areDistinct(v1Base, v2Base) || areDistinct(v2Base, v1Base);

82}

83

84

85

86

/// Fragment of `static bool potentiallyAliasesMemref(BufferOriginAnalysis
/// &analysis, ValueRange otherList, Value memref)`. Per the tooltip list at the
/// end of this page, it is defined at source line 87. It checks whether
/// `memref` may potentially alias a memref in `otherList`.
///
/// Returns true as soon as `analysis.isSameAllocation` answers "unknown" (an
/// empty optional) or "same allocation" for some checked element of
/// `otherList`. Returns false only when every checked element is definitely a
/// different allocation.
///
/// NOTE(review): this is a scrape of a Doxygen listing. Missing source lines:
///   - 87-88: the signature;
///   - 90: the guard of the `continue;` below.
/// The template argument of `std::optional` also appears stripped; presumably
/// it is `std::optional<bool>`, given the comparison with `true`. Restore this
/// from upstream.
89 for (auto other : otherList) {

// NOTE(review): the condition that skips an element (source line 90) is
// missing from this listing. Which elements are exempt from the check cannot
// be determined here.
91 continue;

92 std::optional analysisResult =

93 analysis.isSameAllocation(other, memref);

// An undecidable relation is treated conservatively, like a definite alias.
94 if (!analysisResult.has_value() || analysisResult == true)

95 return true;

96 }

97 return false;

98}

99

100

101

102

103

// NOTE(review): this listing is a text scrape of a Doxygen "Source File" page.
// The number at the start of each line is the original source line number.
// Gaps in that numbering mark source lines that are missing here. Several
// `<...>` template arguments also appear to have been stripped, e.g.
// `SmallVector newMemrefs` and `getDefiningOpmemref::ExtractStridedMetadataOp()`.
// The code below is therefore not compilable as shown. Restore it from the
// upstream file rather than patching the gaps by hand.
104namespace {

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

/// Rewrite pattern on `DeallocOp` that removes (memref, condition) pairs whose
/// memref is definitely the same allocation as at least one value in the op's
/// `retained` list.
///
/// A pair is only removed when `analysis` can decide the relation between the
/// memref and *every* retained value. For each retained value that is the same
/// allocation, the pair's condition is combined with that value's updated
/// condition through an `arith::OrIOp`; how the disjunction is wired in is on
/// missing source line 176. A memref defined by `memref.extract_strided_metadata`
/// is also tried through that op's source operand.
///
/// NOTE(review): missing source lines in this block:
///   - 136: presumably the base clause `: public OpRewritePattern<DeallocOp> {`,
///     given the `OpRewritePattern(context)` initializer and the `override`
///     below;
///   - 149, 176, 181 and 212: see the notes at each gap.
135struct RemoveDeallocMemrefsContainedInRetained

137 RemoveDeallocMemrefsContainedInRetained(MLIRContext *context,

138 BufferOriginAnalysis &analysis)

139 : OpRewritePattern(context), analysis(analysis) {}

140

141

142

143

144

145

146

/// Tries to account for a single (`memref`, `cond`) pair through the retained
/// list.
///
/// Returns failure() in two cases:
///   - `analysis.isSameAllocation` cannot decide the relation between `memref`
///     and some retained value;
///   - no retained value is definitely the same allocation.
/// Otherwise, for every retained value that is the same allocation, it creates
/// `arith::OrIOp(updatedCondition, cond)`. Both failure exits are taken before
/// any `arith::OrIOp` is created.
147 LogicalResult handleOneMemref(DeallocOp deallocOp, Value memref, Value cond,

148 PatternRewriter &rewriter) const {

// NOTE(review): source line 149 is missing. The tooltip list at the end of the
// page mentions `setInsertionPointAfter`. The `arith::OrIOp`s below use results
// of `deallocOp`, so that line presumably moves the insertion point after
// `deallocOp`. Confirm upstream.
150

151

152

153

// First decide the relation to every retained value; an unknown answer aborts.
154 bool atLeastOneMustAlias = false;

155 for (Value retained : deallocOp.getRetained()) {

156 std::optional analysisResult =

157 analysis.isSameAllocation(retained, memref);

158 if (!analysisResult.has_value())

159 return failure();

160 if (analysisResult == true)

161 atLeastOneMustAlias = true;

162 }

163 if (!atLeastOneMustAlias)

164 return failure();

165

166

167

168

// Combine `cond` with the updated condition of each retained value that is the
// same allocation as `memref`.
169 for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) {

170 Value updatedCondition = deallocOp.getUpdatedConditions()[i];

171 std::optional analysisResult =

172 analysis.isSameAllocation(retained, memref);

173 if (analysisResult == true) {

174 auto disjunction = arith::OrIOp::create(rewriter, deallocOp.getLoc(),

175 updatedCondition, cond);

// NOTE(review): source line 176 is missing; `disjunction);` below is the tail
// of that statement. The tooltip list mentions `replaceAllUsesExcept`, so it
// presumably redirects uses of `updatedCondition` to `disjunction`. The
// exception would be `disjunction` itself, which still has to read the old
// value. Confirm upstream.
177 disjunction);

178 }

179 }

180

// NOTE(review): source line 181 (the success return) is missing; presumably
// `return success();`.
182 }

183

/// Drops every (memref, cond) pair that handleOneMemref() accounts for and
/// keeps the rest. A pair can be accounted for in two ways:
///   - directly, through the memref itself;
///   - through the source operand of a `memref.extract_strided_metadata` op
///     that defines the memref.
184 LogicalResult matchAndRewrite(DeallocOp deallocOp,

185 PatternRewriter &rewriter) const override {

186

187

// Bail out when the retained list contains duplicate values (the set is then
// smaller than the list).
188 DenseSet retained(deallocOp.getRetained().begin(),

189 deallocOp.getRetained().end());

190 if (retained.size() != deallocOp.getRetained().size())

191 return failure();

192

193 SmallVector newMemrefs, newConditions;

194 for (auto [memref, cond] :

195 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {

196

197 if (succeeded(handleOneMemref(deallocOp, memref, cond, rewriter)))

198 continue;

199

// Retry through a defining `memref.extract_strided_metadata` op. The template
// argument of `getDefiningOp` is garbled in this listing.
200 if (auto extractOp =

201 memref.getDefiningOpmemref::ExtractStridedMetadataOp())

202 if (succeeded(handleOneMemref(deallocOp, extractOp.getOperand(), cond,

203 rewriter)))

204 continue;

205

206 newMemrefs.push_back(memref);

207 newConditions.push_back(cond);

208 }

209

210

211

// NOTE(review): source line 212 is missing; `rewriter);` below ends that
// statement. It is presumably `return updateDeallocIfChanged(deallocOp,
// newMemrefs, newConditions,` (the tooltip list includes that helper).
// Confirm upstream.
213 rewriter);

214 }

215

216private:

217 BufferOriginAnalysis &analysis;

218};

219

220

221

222

223

224

225

226

227

228

229

230

231

232

233

234

235

/// Rewrite pattern on `DeallocOp` that shrinks the `retained` list.
///
/// Retained values for which the condition on (missing) source line 247 holds
/// are kept. Each other retained value is dropped, and its result is replaced
/// by a constant `false` (an `arith::ConstantOp` of `getBoolAttr(false)`).
/// The op is recreated with the same memrefs and conditions and the shorter
/// retained list. The kept values' results are mapped onto the new op in order.
///
/// NOTE(review): missing source lines in this block:
///   - 237: presumably the base clause `: public OpRewritePattern<DeallocOp> {`;
///   - 247: the start of the `if` that ends in `retainedMemref)) {`. Going by
///     the pattern name and the tooltip list, it is presumably
///     `potentiallyAliasesMemref(analysis, deallocOp.getMemrefs(),`;
///   - 271: presumably `return success();`.
236struct RemoveRetainedMemrefsGuaranteedToNotAlias

238 RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context,

239 BufferOriginAnalysis &analysis)

240 : OpRewritePattern(context), analysis(analysis) {}

241

242 LogicalResult matchAndRewrite(DeallocOp deallocOp,

243 PatternRewriter &rewriter) const override {

244 SmallVector newRetainedMemrefs, replacements;

245

246 for (auto retainedMemref : deallocOp.getRetained()) {

// NOTE(review): source line 247 (the start of this `if` condition) is missing.
// A null placeholder is pushed for each kept value and filled in later.
248 retainedMemref)) {

249 newRetainedMemrefs.push_back(retainedMemref);

250 replacements.push_back({});

251 continue;

252 }

253

254 replacements.push_back(arith::ConstantOp::create(

255 rewriter, deallocOp.getLoc(), rewriter.getBoolAttr(false)));

256 }

257

// Nothing was dropped: report failure. Constants are only created for dropped
// values, so no IR has been created in this case.
258 if (newRetainedMemrefs.size() == deallocOp.getRetained().size())

259 return failure();

260

261 auto newDeallocOp =

262 DeallocOp::create(rewriter, deallocOp.getLoc(), deallocOp.getMemrefs(),

263 deallocOp.getConditions(), newRetainedMemrefs);

// Kept values keep their relative order, so fill the null placeholders, in
// order, with the new op's updated conditions.
264 int i = 0;

265 for (auto &repl : replacements) {

266 if (!repl)

267 repl = newDeallocOp.getUpdatedConditions()[i++];

268 }

269

270 rewriter.replaceOp(deallocOp, replacements);

// NOTE(review): source line 271 (the success return) is missing; presumably
// `return success();`.
272 }

273

274private:

275 BufferOriginAnalysis &analysis;

276};

277

278

279

280

281

282

283

284

285

286

287

288

289

290

291

292

293

294

295

296

297

298

299

300

301

302

303

304

/// Rewrite pattern on `DeallocOp`s with at least two memrefs.
///
/// Each memref that satisfies the (missing) check on source line 325 stays in
/// a remaining `DeallocOp`. Every other memref is split off into its own
/// single-memref `DeallocOp` with the same retained list. Each original result
/// is replaced by the `arith::OrIOp` of the corresponding results of the
/// remaining op and of every split-off op.
///
/// NOTE(review): missing source lines in this block:
///   - 306: presumably the base clause `: public OpRewritePattern<DeallocOp> {`;
///   - 325: the opening `if` whose body keeps `memref` in the remaining op.
///     `otherMemrefs` is built right before it and used nowhere else, so it is
///     presumably `potentiallyAliasesMemref(analysis, otherMemrefs, memref)`;
///   - 360: presumably `return success();`.
305struct SplitDeallocWhenNotAliasingAnyOther

307 SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context,

308 BufferOriginAnalysis &analysis)

309 : OpRewritePattern(context), analysis(analysis) {}

310

311 LogicalResult matchAndRewrite(DeallocOp deallocOp,

312 PatternRewriter &rewriter) const override {

313 Location loc = deallocOp.getLoc();

314 if (deallocOp.getMemrefs().size() <= 1)

315 return failure();

316

317 SmallVector remainingMemrefs, remainingConditions;

318 SmallVector<SmallVector> updatedConditions;

319 for (int64_t i = 0, e = deallocOp.getMemrefs().size(); i < e; ++i) {

320 Value memref = deallocOp.getMemrefs()[i];

321 Value cond = deallocOp.getConditions()[i];

322 SmallVector otherMemrefs(deallocOp.getMemrefs());

323 otherMemrefs.erase(otherMemrefs.begin() + i);

324

// NOTE(review): source line 325 is missing here. It is the `if` opening the
// block that ends at `330 }`.
326

327 remainingMemrefs.push_back(memref);

328 remainingConditions.push_back(cond);

329 continue;

330 }

331

332

333 auto newDeallocOp = DeallocOp::create(rewriter, loc, memref, cond,

334 deallocOp.getRetained());

335 updatedConditions.push_back(

336 llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions())));

337 }

338

339

// No memref was split off, so nothing has changed. Split-off ops are only
// created for memrefs that leave the remaining list.
340 if (remainingMemrefs.size() == deallocOp.getMemrefs().size())

341 return failure();

342

343

344 auto newDeallocOp =

345 DeallocOp::create(rewriter, loc, remainingMemrefs, remainingConditions,

346 deallocOp.getRetained());

347

348

// All new ops share `deallocOp.getRetained()`, so they yield the same number
// of updated conditions. OR them together position by position.
349 SmallVector replacements =

350 llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions()));

351 for (auto additionalConditions : updatedConditions) {

352 assert(replacements.size() == additionalConditions.size() &&

353 "expected same number of updated conditions");

354 for (int64_t i = 0, e = replacements.size(); i < e; ++i) {

355 replacements[i] = arith::OrIOp::create(rewriter, loc, replacements[i],

356 additionalConditions[i]);

357 }

358 }

359 rewriter.replaceOp(deallocOp, replacements);

// NOTE(review): source line 360 (the success return) is missing; presumably
// `return success();`.
361 }

362

363private:

364 BufferOriginAnalysis &analysis;

365};

366

367

368

369

370

371

372

373

374

375

376

377

378

379

380

381

382

383

384

385

386

387

388

389

/// Rewrite pattern on `DeallocOp`. For each (memref, cond) pair that passes the
/// (missing) guard on source line 405, it drops the memref when some retained
/// value is definitely the same allocation as either:
///   - the memref itself, or
///   - the source of a `memref.extract_strided_metadata` op defining it.
/// The rewrite is only committed if every retained value was matched this way,
/// i.e. the `aliasesWithConstTrueMemref` bit vector is all ones.
///
/// NOTE(review): missing source lines in this block:
///   - 391: presumably the base clause;
///   - 405, 411, 427 and 441: see the notes at each gap.
390struct RetainedMemrefAliasingAlwaysDeallocatedMemref

392 RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context,

393 BufferOriginAnalysis &analysis)

394 : OpRewritePattern(context), analysis(analysis) {}

395

396 LogicalResult matchAndRewrite(DeallocOp deallocOp,

397 PatternRewriter &rewriter) const override {

398 BitVector aliasesWithConstTrueMemref(deallocOp.getRetained().size());

399 SmallVector newMemrefs, newConditions;

400 for (auto [memref, cond] :

401 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {

402 bool canDropMemref = false;

403 for (auto [i, retained, res] : llvm::enumerate(

404 deallocOp.getRetained(), deallocOp.getUpdatedConditions())) {

// NOTE(review): source line 405 (the guard of this `continue;`) is missing.
// The tooltip list mentions `matchPattern` and `m_One`, and the bit vector's
// name refers to a constant-true memref. It therefore presumably skips pairs
// whose `cond` is not a constant one.
406 continue;

407

408 std::optional analysisResult =

409 analysis.isSameAllocation(retained, memref);

410 if (analysisResult == true) {

// NOTE(review): source line 411 is missing. `res` is otherwise unused and the
// tooltip list mentions `replaceAllUsesWith`, so it presumably replaces the
// uses of `res` with `cond`.
412 aliasesWithConstTrueMemref[i] = true;

413 canDropMemref = true;

414 continue;

415 }

416

417

418

// Fall back to the source operand of a defining
// `memref.extract_strided_metadata` op. The `getDefiningOp` template argument
// is garbled in this listing.
419 auto extractOp =

420 memref.getDefiningOpmemref::ExtractStridedMetadataOp();

421 if (!extractOp)

422 continue;

423

424 std::optional extractAnalysisResult =

425 analysis.isSameAllocation(retained, extractOp.getOperand());

426 if (extractAnalysisResult == true) {

// NOTE(review): source line 427 is missing; presumably the same replacement as
// on source line 411.
428 aliasesWithConstTrueMemref[i] = true;

429 canDropMemref = true;

430 }

431 }

432

433 if (!canDropMemref) {

434 newMemrefs.push_back(memref);

435 newConditions.push_back(cond);

436 }

437 }

// Only commit when every retained value was matched.
// NOTE(review): source lines 411/427 may rewrite IR (see above). If they do,
// that happens before this check, and returning failure() here would leave
// modified IR behind a failed pattern. The rewrite-pattern contract forbids
// this. Verify upstream; if confirmed, defer the replacements until after this
// check.
438 if (!aliasesWithConstTrueMemref.all())

439 return failure();

440

// NOTE(review): source line 441 is missing. It is presumably
// `return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,`.
442 rewriter);

443 }

444

445private:

446 BufferOriginAnalysis &analysis;

447};

448

449}

450

451

452

453

454

// NOTE(review): text scrape of a Doxygen listing. The leading numbers are
// original source line numbers. Source lines 465, 470, 472 and 476 are
// missing; see the notes below.
455namespace {

456

457

458

459

/// The BufferDeallocationSimplification pass:
///   1. builds a BufferOriginAnalysis for the operation it runs on;
///   2. registers the four patterns named in `patterns.add<...>`, each
///      constructed with the MLIR context and, presumably, that analysis;
///   3. applies the patterns greedily with region simplification set to
///      `GreedySimplifyRegionLevel::Normal`, signalling pass failure if the
///      greedy driver fails.
460struct BufferDeallocationSimplificationPass

461 : public bufferization::impl::BufferDeallocationSimplificationPassBase<

462 BufferDeallocationSimplificationPass> {

463 void runOnOperation() override {

// The analysis is computed once, before any rewriting, and shared by reference
// with the patterns.
464 BufferOriginAnalysis analysis(getOperation());

// NOTE(review): source line 465 is missing. It is presumably the declaration of
// `patterns` (a `RewritePatternSet`), which is used below but not declared in
// this listing.
466 patterns.add<RemoveDeallocMemrefsContainedInRetained,

467 RemoveRetainedMemrefsGuaranteedToNotAlias,

468 SplitDeallocWhenNotAliasingAnyOther,

469 RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),

// NOTE(review): source line 470 is missing; it is the end of the `patterns.add`
// call, presumably passing `analysis`. Line 472 is missing too. The tooltip
// list mentions `populateDeallocOpCanonicalizationPatterns`, so line 472
// presumably adds the dealloc canonicalization patterns to `patterns`.
471

473

474

475

// NOTE(review): source line 476 is missing. It is presumably
// `if (failed(applyPatternsGreedily(` (see the tooltip list), which matches the
// four closing parentheses on source line 479. The `Normal` simplification
// level presumably avoids aggressive region restructuring while `analysis` is
// live. Confirm upstream.
477 getOperation(), std::move(patterns),

478 GreedyRewriteConfig().setRegionSimplificationLevel(

479 GreedySimplifyRegionLevel::Normal))))

480 signalPassFailure();

481 }

482};

483

484}

static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis, ValueRange otherList, Value memref)

Checks if memref may potentially alias a MemRef in otherList.

Definition BufferDeallocationSimplification.cpp:87

static Value getViewBase(Value value)

Given a memref value, return the "base" value by skipping over all ViewLikeOpInterface ops (if any) in the reverse use-def chain.

Definition BufferDeallocationSimplification.cpp:39

static bool distinctAllocAndBlockArgument(Value v1, Value v2)

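Return "true" if the given values are guaranteed to be different (and non-aliasing) allocations based on the fact that one value is the result of an allocation and the other value is a block argument of a parent block.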

Definition BufferDeallocationSimplification.cpp:70

static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, ValueRange memrefs, ValueRange conditions, PatternRewriter &rewriter)

Definition BufferDeallocationSimplification.cpp:49

An is-same-buffer analysis that checks if two SSA values belong to the same buffer allocation or not.

BoolAttr getBoolAttr(bool value)

void setInsertionPointAfter(Operation *op)

Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.

Operation is the basic unit of execution within MLIR.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements).

void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)

Find uses of from and replace them with to except if the user is exceptedUser.

void modifyOpInPlace(Operation *root, CallableT &&callable)

This method is a utility wrapper around an in-place modification of an operation.

virtual void replaceAllUsesWith(Value from, Value to)

Find uses of from and replace them with to.

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

void populateDeallocOpCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context)

Add the canonicalization patterns for bufferization.dealloc to the given pattern set to make them ava...

Include the generated interface declarations.

bool matchPattern(Value value, const Pattern &pattern)

Entry point for matching a pattern over a Value.

LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)

Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.

llvm::DenseSet< ValueT, ValueInfoT > DenseSet

const FrozenRewritePatternSet & patterns

detail::constant_int_predicate_matcher m_One()

Matches a constant scalar / vector splat / tensor splat integer one.

bool hasEffect(Operation *op)

Returns "true" if op has an effect of type EffectTy.

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.