MLIR: lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

15

24

25 namespace mlir {

26 namespace bufferization {

27 #define GEN_PASS_DEF_BUFFERHOISTINGPASS

28 #define GEN_PASS_DEF_BUFFERLOOPHOISTINGPASS

29 #define GEN_PASS_DEF_PROMOTEBUFFERSTOSTACKPASS

30 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"

31 }

32 }

33

34 using namespace mlir;

36

37

38

40 return isa<LoopLikeOpInterface, RegionBranchOpInterface>(op);

41 }

42

43

44

45

46

48

49

50 if (isa(op))

51 return true;

52

53

54

55 auto regionInterface = dyn_cast(op);

56 if (!regionInterface)

57 return false;

58

59 return regionInterface.hasLoop();

60 }

61

62

63

66 }

67

68

69

71 auto allocOp = dyn_cast(op);

72 return allocOp &&

74 }

75

76

77

79 auto allocOp = dyn_cast(op);

80 return allocOp &&

82 }

83

84

85

86

88 unsigned maxRankOfAllocatedMemRef) {

89 auto type = dyn_cast(alloc.getType());

90 if (!type || !alloc.getDefiningOpmemref::AllocOp())

91 return false;

92 if (!type.hasStaticShape()) {

93

94

95

96

97

98 if (type.getRank() <= maxRankOfAllocatedMemRef) {

100 [&](Value operand) {

101 return operand.getDefiningOpmemref::RankOp();

102 });

103 }

104 return false;

105 }

108 return type.getNumElements() * bitwidth <= maximumSizeInBytes * 8;

109 }

110

111

112 static bool

115 for (Value alias : aliases) {

116 for (auto *use : alias.getUsers()) {

117

118

119

120 if (isa(use) &&

121 use->getParentRegion() == parentRegion)

122 return true;

123 }

124 }

125 return false;

126 }

127

128

132 do {

134

135

136

139 return true;

140

141

142

144 break;

145 }

147 return false;

148 }

149

150 namespace {

151

152

153

154

155

156

// Common base of the two hoisting state structs in this file
// (BufferAllocationHoistingState and BufferAllocationLoopHoistingState).
// It bundles an allocation value with a placement block. The constructor also
// initializes a `dominators` member whose declaration is not visible here
// (listing line 159 is missing from this extraction).
157 struct BufferAllocationHoistingStateBase {

158

160

161 // The allocation value this state was created for.
162 Value allocValue;

163

164 // The block recorded as placement target; derived states overwrite it.
165 Block *placementBlock;

166

167 // Stores the three arguments in the members of the same name.
168 BufferAllocationHoistingStateBase(DominanceInfo *dominators, Value allocValue,

169 Block *placementBlock)

170 : dominators(dominators), allocValue(allocValue),

171 placementBlock(placementBlock) {}

172 };

173

174

175 template

177 public:

178 BufferAllocationHoisting(Operation *op)

180 postDominators(op), scopeOp(op) {}

181

182

// Moves every hoistable allocation found in `scopeOp` to a new position.
// NOTE(review): listing lines 184-185, 200 and 220 are missing from this
// extraction, so the declaration of `allocsAndAllocas`, the initializer of
// `dominatorBlock` and the declaration of `startOperation` are not visible.
183 void hoist() {

186 allocsAndAllocas.push_back(std::get<0>(entry));

187 scopeOp->walk([&](memref::AllocaOp op) {

188 allocsAndAllocas.push_back(op.getMemref());

189 });

190 // Process each collected value; skip those StateT does not want to hoist.

191 for (auto allocValue : allocsAndAllocas) {

192 if (!StateT::shouldHoistOpType(allocValue.getDefiningOp()))

193 continue;

194 Operation *definingOp = allocValue.getDefiningOp();

195 assert(definingOp && "No defining op");

196 auto operands = definingOp->getOperands();

197 auto resultAliases = aliases.resolve(allocValue);

198 // Initializer on listing line 200 is missing; presumably a common
// dominator computed from `resultAliases` — TODO confirm.

199 Block *dominatorBlock =

201

202 StateT state(&dominators, allocValue, allocValue.getParentBlock());

203

204 // Block constraint derived from the operands of the defining op.

205 Block *dependencyBlock = nullptr;

206 // For each operand, switch to its parent block when no block has been

207 // chosen yet or when the currently chosen block dominates it, so the

208 // most deeply dominated operand block is kept.

209 for (Value depValue : operands) {

210 Block *depBlock = depValue.getParentBlock();

211 if (!dependencyBlock || dominators.dominates(dependencyBlock, depBlock))

212 dependencyBlock = depBlock;

213 }

214

215

216 // The state decides the upper bound; findPlacementBlock then walks

217 // upwards from the state's initial placement block.

218 Block *placementBlock = findPlacementBlock(

219 state, state.computeUpperBound(dominatorBlock, dependencyBlock));

221 allocValue, placementBlock, liveness);

222

223 // Relocate the defining operation directly in front of `startOperation`.

224 Operation *allocOperation = allocValue.getDefiningOp();

225 allocOperation->moveBefore(startOperation);

226 }

227 }

228

229 private:

230

231

232

// Walks upwards from `state.placementBlock` and returns the placement block
// recorded in `state` once the walk stops. `upperBound` may be null, in which
// case the dominance check in the loop condition is not applied.
// NOTE(review): listing lines 242, 250, 253 and 267 are missing from this
// extraction (declarations of `parentOp` / `idom` and the first half of the
// break condition are not visible).
233 Block *findPlacementBlock(StateT &state, Block *upperBound) {

234 Block *currentBlock = state.placementBlock;

235

236

237

238

239

240

241

243 Block *parentBlock;

244 while ((parentOp = currentBlock->getParentOp()) &&

245 (parentBlock = parentOp->getBlock()) &&

246 (!upperBound ||

247 dominators.properlyDominates(upperBound, currentBlock))) {

248 // Continue only while the current block has a parent op that itself

249 // lives in a block and the upper bound still properly dominates it.

251

252

254 idom = dominators.getNode(currentBlock)->getIDom();

255 // Prefer the immediate dominator while it is properly dominated by the
// parent op's block.

256 if (idom && dominators.properlyDominates(parentBlock, idom->getBlock())) {

257

258

259 currentBlock = idom->getBlock();

260 state.recordMoveToDominator(currentBlock);

261 } else {

262

263

264

265

266 // First operand of this condition (listing line 267) is missing.

268 !state.isLegalPlacement(parentOp))

269 break;

270

271 // Otherwise step out to the block containing the parent operation.

272 currentBlock = parentBlock;

273 state.recordMoveToParent(currentBlock);

274 }

275 }

276 // The result is whatever the state recorded, not necessarily currentBlock.

277 return state.placementBlock;

278 }

279

280

281

283

284

285

287

288

290

291

292

294 };

295

296

297

298

// Hoisting state that records every move (to a dominator or to a parent
// block) as the new placement block.
299 struct BufferAllocationHoistingState : BufferAllocationHoistingStateBase {

300 using BufferAllocationHoistingStateBase::BufferAllocationHoistingStateBase;

301

302 // Chooses the block that is passed to findPlacementBlock as upper bound.
303 Block *computeUpperBound(Block *dominatorBlock, Block *dependencyBlock) {

304

305 // Without a dependency block the dominator block is the bound.

306 if (!dependencyBlock)

307 return dominatorBlock;

308

309 // Use the dependency block when the dominator block properly dominates

310 // it; in every other case use the dominator block.

311 return dominators->properlyDominates(dominatorBlock, dependencyBlock)

312 ? dependencyBlock

313 : dominatorBlock;

314 }

315

316 // NOTE(review): as shown this returns isLoop(op); a leading `!` may have
// been lost in extraction (upstream likely reads `!isLoop(op)`) — confirm.
317 bool isLegalPlacement(Operation *op) { return isLoop(op); }

318

319 // Body (listing line 321) is missing from this extraction.

320 static bool shouldHoistOpType(Operation *op) {

322 }

323

324 // Adopts `block` as the new placement block.

325 void recordMoveToDominator(Block *block) { placementBlock = block; }

326

327 // Moving to a parent block is handled exactly like moving to a dominator.

328 void recordMoveToParent(Block *block) { recordMoveToDominator(block); }

329 };

330

331

332

// Hoisting state that only updates the placement block on moves to a parent
// block; moves to a dominator are ignored.
333 struct BufferAllocationLoopHoistingState : BufferAllocationHoistingStateBase {

334 using BufferAllocationHoistingStateBase::BufferAllocationHoistingStateBase;

335

336 // Dominator block remembered by computeUpperBound for isLegalPlacement.

337 Block *aliasDominatorBlock = nullptr;

338

339 // Remembers `dominatorBlock` and returns `dependencyBlock` (possibly null).

340 Block *computeUpperBound(Block *dominatorBlock, Block *dependencyBlock) {

341 aliasDominatorBlock = dominatorBlock;

342

343 // Equivalent to returning dependencyBlock as is.

344 return dependencyBlock ? dependencyBlock : nullptr;

345 }

346

347

348

349

350 // Only the second half of the returned expression is visible (listing

351 // line 353 is missing): the remembered alias dominator block must not
// dominate the block containing `op`.
352 bool isLegalPlacement(Operation *op) {

354 !dominators->dominates(aliasDominatorBlock, op->getBlock());

355 }

356

357 // Body (listing line 359) is missing from this extraction.

358 static bool shouldHoistOpType(Operation *op) {

360 }

361

362

363 // Intentionally empty: moves to a dominator do not change the placement.

364 void recordMoveToDominator(Block *block) {}

365

366 // Adopts `block` as the new placement block.

367 void recordMoveToParent(Block *block) { placementBlock = block; }

368 };

369

370

371

372

373

374

376 public:

377 BufferPlacementPromotion(Operation *op)

379

380

383 Value alloc = std::get<0>(entry);

384 Operation *dealloc = std::get<1>(entry);

385

386

387

388

389 if (!isSmallAlloc(alloc) || dealloc ||

391 continue;

392

395

396

397 OpBuilder builder(startOperation);

399 if (auto allocInterface = dyn_cast(allocOp)) {

400 std::optional<Operation *> alloca =

401 allocInterface.buildPromotedAlloc(builder, alloc);

402 if (!alloca)

403 continue;

404

406 allocOp->erase();

407 }

408 }

409 }

410 };

411

412

413

414

415

416

417

// Pass wrapper that runs BufferAllocationHoisting on the pass's operation.
// NOTE(review): template arguments of the base class and of
// BufferAllocationHoisting were stripped by the extraction.
418 struct BufferHoistingPass

419 : public bufferization::impl::BufferHoistingPassBase {

420

421 void runOnOperation() override {

422 // Build the optimizer for the current operation and hoist.

423 BufferAllocationHoisting optimizer(

424 getOperation());

425 optimizer.hoist();

426 }

427 };

428

429

// Pass wrapper for the loop-hoisting variant.
430 struct BufferLoopHoistingPass

431 : public bufferization::impl::BufferLoopHoistingPassBase<

432 BufferLoopHoistingPass> {

433

434 void runOnOperation() override {

435 // NOTE(review): the single body statement (listing line 436) is missing
// from this extraction; presumably it forwards getOperation() to
// hoistBuffersFromLoops — TODO confirm.

437 }

438 };

439

440

441

// Pass that runs BufferPlacementPromotion with a caller-supplied or default
// `isSmallAlloc` predicate.
442 class PromoteBuffersToStackPass

443 : public bufferization::impl::PromoteBuffersToStackPassBase<

444 PromoteBuffersToStackPass> {

445 using Base::Base;

446

447 public:
// Takes ownership of a custom predicate deciding which allocations qualify.
448 explicit PromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc)

449 : isSmallAlloc(std::move(isSmallAlloc)) {}

450 // Installs a default predicate when none was supplied; always succeeds.

451 LogicalResult initialize(MLIRContext *context) override {

452 if (isSmallAlloc == nullptr) {
// The call on listing line 454 is missing; its last visible argument is
// maxRankOfAllocatedMemRef (presumably a pass option — TODO confirm).
453 isSmallAlloc = [=](Value alloc) {

455 maxRankOfAllocatedMemRef);

456 };

457 }

458 return success();

459 }

460

461 void runOnOperation() override {

462 // Promote within the current operation using the active predicate.

463 BufferPlacementPromotion optimizer(getOperation());

464 optimizer.promote(isSmallAlloc);

465 }

466

467 private:
// Predicate passed to promote(); null until set by the constructor or
// by initialize().
468 std::function<bool(Value)> isSmallAlloc;

469 };

470

471 }

472

474 BufferAllocationHoisting optimizer(op);

475 optimizer.hoist();

476 }

477

479 std::function<bool(Value)> isSmallAlloc) {

480 return std::make_unique(std::move(isSmallAlloc));

481 }

static bool leavesAllocationScope(Region *parentRegion, const BufferViewFlowAnalysis::ValueSetT &aliases)

Checks whether the given aliases leave the allocation scope.

static bool isKnownControlFlowInterface(Operation *op)

Returns true if the given operation implements a known high-level region-based control-flow interface.

static bool hasAllocationScope(Value alloc, const BufferViewFlowAnalysis &aliasAnalysis)

Checks whether an automated allocation scope for a given alloc value exists.

static bool isSequentialLoop(Operation *op)

Return whether the given operation is a loop with sequential execution semantics.

static bool isLoop(Operation *op)

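Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOpInterface or the RegionBranchOpInterface; in the latter case the region-based control flow must contain a loop.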

static bool allowAllocDominateBlockHoisting(Operation *op)

Returns true if the given operation implements the AllocationOpInterface and it supports the dominate block hoisting.

static bool allowAllocLoopHoisting(Operation *op)

Returns true if the given operation implements the AllocationOpInterface and it supports the loop hoisting.

static bool defaultIsSmallAlloc(Value alloc, unsigned maximumSizeInBytes, unsigned maxRankOfAllocatedMemRef)

Check if the size of the allocation is less than the given size.

Block represents an ordered list of Operations.

bool isEntryBlock()

Return if this block is the entry block in the parent region.

Operation * getParentOp()

Returns the closest surrounding operation that contains this block.

A straightforward alias analysis which ensures that all dependencies of all values will be determined.

ValueSetT resolve(Value value) const

Find all immediate and indirect views upon this value.

static DataLayout closest(Operation *op)

Returns the layout of the closest parent operation carrying layout info.

llvm::TypeSize getTypeSizeInBits(Type t) const

Returns the size in bits of the given type in the current scope.

A class for computing basic dominance information.

MLIRContext is the top-level object for a collection of MLIR operations.

This class helps build Operations.

A trait of region holding operations that define a new scope for automatic allocations, i.e., allocations that are freed when control is transferred back from the operation's region.

Operation is the basic unit of execution within MLIR.

bool hasTrait()

Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().

Block * getBlock()

Returns the operation block that contains this operation.

operand_range getOperands()

Returns an iterator on the underlying Value's.

void moveBefore(Operation *existingOp)

Unlink this operation from its current block and insert it right before existingOp, which may be in the same or another block in the same function.

void replaceAllUsesWith(ValuesT &&values)

Replace all uses of results of this operation with the provided 'values'.

void erase()

Remove this operation from its parent block and delete it.

A class for computing basic postdominance information.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

Region * getParentRegion()

Return the region containing this region or nullptr if the region is attached to a top-level operation.

Operation * getParentOp()

Return the parent operation this region is attached to.

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

Type getType() const

Return the type of this value.

Block * getParentBlock()

Return the Block in which this Value is defined.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

Region * getParentRegion()

Return the Region in which this Value is defined.

static Operation * getStartOperation(Value allocValue, Block *placementBlock, const Liveness &liveness)

Get the start operation to place the given alloc value within the specified placement block.

std::tuple< Value, Operation * > AllocEntry

Represents a tuple of allocValue and deallocOperation.

The base class for all BufferPlacement transformations.

void hoistBuffersFromLoops(Operation *op)

Within the given operation, hoist buffers from loops where possible.

std::unique_ptr< Pass > createPromoteBuffersToStackPass(std::function< bool(Value)> isSmallAlloc)

Creates a pass that promotes heap-based allocations to stack-based ones.

Block * findCommonDominator(Value value, const BufferViewFlowAnalysis::ValueSetT &values, const DominatorT &doms)

Finds a common dominator for the given value while taking the positions of the values in the value set into account.

void promote(RewriterBase &rewriter, scf::ForallOp forallOp)

Promotes the loop body of a scf::ForallOp to its containing block.

Include the generated interface declarations.

llvm::DomTreeNodeBase< Block > DominanceInfoNode