MLIR: lib/Transforms/Mem2Reg.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

21#include "llvm/ADT/STLExtras.h"

22#include "llvm/Support/DebugLog.h"

23#include "llvm/Support/GenericIteratedDominanceFrontier.h"

24

25namespace mlir {

26#define GEN_PASS_DEF_MEM2REG

27#include "mlir/Transforms/Passes.h.inc"

28}

29

30#define DEBUG_TYPE "mem2reg"

31

32using namespace mlir;

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99namespace {

100

// Ordered map from a user operation to the set of its operands (OpOperand
// pointers) recorded as blocking uses for that operation.
101using BlockingUsesMap =

102 llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;

103

104

105

// Result record produced by MemorySlotPromotionAnalyzer::computeInfo() and
// moved into MemorySlotPromoter.
// NOTE(review): other code in this file reads `info.mergePoints`, but no such
// member is visible here; the declaration appears to have been dropped from
// this listing (original line 108) — confirm against the upstream file.
106struct MemorySlotPromotionInfo {

107

109

110

111

112

113

// Blocking uses to remove, keyed by the operation that owns them.
114 BlockingUsesMap userToBlockingUses;

115};

116

117

118

119

// Analysis half of mem2reg for a single memory slot: decides whether the slot
// can be promoted and, if so, gathers the data the promoter needs.
// NOTE(review): the constructor signature, several private method
// declarations and the data members are missing from this listing; only the
// member-initializer list shows that `slot`, `dominance` and `dataLayout` are
// stored.
120class MemorySlotPromotionAnalyzer {

121public:

124 : slot(slot), dominance(dominance), dataLayout(dataLayout) {}

125

126

127

// Returns the promotion info, or an empty optional when one of the analysis
// steps fails (see the definition).
128 std::optional computeInfo();

129

130private:

131

132

133

134

135

136

137

// Fills `userToBlockingUses`; returns failure() when some blocking use cannot
// be removed.
138 LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses);

139

140

141

142

143

146

147

148

150

151

152

153

154

156

160};

161

163

164

165

166

// Transformation half of mem2reg for a single memory slot: consumes a
// MemorySlotPromotionInfo and rewrites the IR so the slot pointer is no
// longer used.
// NOTE(review): several constructor parameters and data members (e.g. the
// builder, dominance info, statistics, reaching-definition maps) are used by
// the method definitions but their declarations are missing from this
// listing.
167class MemorySlotPromoter {

168public:

169 MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,

171 const DataLayout &dataLayout, MemorySlotPromotionInfo info,

173 BlockIndexCache &blockIndexCache);

174

175

176

177

178

179

// Performs the promotion; returns whatever the allocator's
// handlePromotionComplete hook returns.
180 std::optional promoteSlot();

181

182private:

183

184

185

186

// Walks the operations of `block`, starting from `reachingDef`, and returns
// the definition that reaches the end of the block.
187 Value computeReachingDefInBlock(Block *block, Value reachingDef);

188

189

190

191

192

// Propagates reaching definitions through the blocks of `region`, starting
// from `reachingDef` at the entry block.
193 void computeReachingDefInRegion(Region *region, Value reachingDef);

194

195

// Removes the blocking uses recorded in `info` and erases operations that
// request deletion.
196 void removeBlockingUses();

197

198

199

// Lazily materializes (and caches) the slot's default value.
200 Value getOrCreateDefaultValue();

201

// Allocation operation that owns the slot being promoted.
203 PromotableAllocationOpInterface allocator;

205

206

// Cached result of getOrCreateDefaultValue(); null until first requested.
207 Value defaultValue;

208

209

// Analysis results this promoter acts upon.
214 MemorySlotPromotionInfo info;

216

217

// Cache of per-region block indices, held by reference and passed on to the
// sorting helper called from removeBlockingUses().
218 BlockIndexCache &blockIndexCache;

219};

220

221}

222

// Stores all promotion inputs in the corresponding members; `info` is moved
// in, the other arguments are copied or bound as written in the initializer
// list.
// NOTE(review): the parameter line declaring `builder`, `dominance` and
// `dataLayout` (original line 225) is missing from this listing.
223MemorySlotPromoter::MemorySlotPromoter(

224 MemorySlot slot, PromotableAllocationOpInterface allocator,

226 MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics,

227 BlockIndexCache &blockIndexCache)

228 : slot(slot), allocator(allocator), builder(builder), dominance(dominance),

229 dataLayout(dataLayout), info(std::move(info)), statistics(statistics),

230 blockIndexCache(blockIndexCache) {

// Debug-only sanity check of the slot/allocator relationship.
231#ifndef NDEBUG

// NOTE(review): parts of this lambda (original lines 233 and 235) are missing;
// the visible branch accepts a value `arg` whose owning block's parent op is
// the allocator. The assertion message below states the intended invariant.
232 auto isResultOrNewBlockArgument = [&]() {

234 return arg.getOwner()->getParentOp() == allocator;

236 };

237

238 assert(isResultOrNewBlockArgument() &&

239 "a slot must be a result of the allocator or an argument of the child "

240 "regions of the allocator");

241#endif

242}

243

// Returns the slot's default value, creating it on first use through
// allocator.getDefaultValue(slot, builder) and caching it in `defaultValue`.
244Value MemorySlotPromoter::getOrCreateDefaultValue() {

// Fast path: already created.
245 if (defaultValue)

246 return defaultValue;

247

// NOTE(review): the guard is presumably here to restore the builder's
// insertion point on scope exit; the statement that repositions the builder
// (original line 249) is missing from this listing.
248 OpBuilder::InsertionGuard guard(builder);

250 return defaultValue = allocator.getDefaultValue(slot, builder);

251}

252

// Computes, for every operation that transitively uses the slot pointer, the
// set of operands that must be removed for promotion to go through, and
// stores them in `userToBlockingUses`. Returns failure() as soon as one user
// cannot have its blocking uses removed.
// NOTE(review): several statements are missing from this listing (definition
// of `slotPtrRegion`, computation of `forwardSlice`, the second half of the
// final condition and the final return); template arguments of
// dyn_cast/isa were also stripped.
253LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(

254 BlockingUsesMap &userToBlockingUses) {

255

256

257

258

259

260

261

262

263

// Bail out when the region holding the slot pointer is a graph region.
265 auto slotPtrRegionOp =

266 dyn_cast(slotPtrRegion->getParentOp());

267 if (slotPtrRegionOp &&

268 slotPtrRegionOp.getRegionKind(slotPtrRegion->getRegionNumber()) ==

269 RegionKind::Graph)

270 return failure();

271

272

273

// Seed the map: every direct use of the slot pointer is a blocking use of the
// operation that owns the operand.
274 for (OpOperand &use : slot.ptr.getUses()) {

275 SmallPtrSet<OpOperand *, 4> &blockingUses =

276 userToBlockingUses[use.getOwner()];

277 blockingUses.insert(&use);

278 }

279

280

281

282

283

284

// Visit users in forward-slice order so that new blocking uses produced by a
// user are registered on operations that appear later in the slice.
287 for (Operation *user : forwardSlice) {

288

// Operations in the slice without any recorded blocking use are ignored.
289 auto *it = userToBlockingUses.find(user);

290 if (it == userToBlockingUses.end())

291 continue;

292

293 SmallPtrSet<OpOperand *, 4> &blockingUses = it->second;

294

// Filled by canUsesBeRemoved with uses that become blocking as a result of
// removing `blockingUses`.
295 SmallVector<OpOperand *> newBlockingUses;

296

297

// Two interface flavours are tried; the second overload additionally takes
// the slot. An operation implementing neither makes promotion impossible.
298 if (auto promotable = dyn_cast(user)) {

299 if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses,

300 dataLayout))

301 return failure();

302 } else if (auto promotable = dyn_cast(user)) {

303 if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses,

304 dataLayout))

305 return failure();

306 } else {

307

308

309 return failure();

310 }

311

312

// Register the newly reported blocking uses on their owning operations; each
// must be a use of one of `user`'s results.
313 for (OpOperand *blockingUse : newBlockingUses) {

314 assert(llvm::is_contained(user->getResults(), blockingUse->get()));

315

316 SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =

317 userToBlockingUses[blockingUse->getOwner()];

318 newUserBlockingUseSet.insert(blockingUse);

319 }

320 }

321

322

323

324

325

// Final rejection check over all recorded users.
// NOTE(review): the rest of the condition (original line 328) is missing.
326 for (auto &[toPromote, _] : userToBlockingUses)

327 if (isa(toPromote) &&

329 return failure();

330

332}

333

// Returns the set of blocks collected by a backward walk that starts at the
// blocks where the slot is loaded before being stored, and stops at blocks
// contained in `definingBlocks`.
334SmallPtrSet<Block *, 16> MemorySlotPromotionAnalyzer::computeSlotLiveIn(

335 SmallPtrSetImpl<Block *> &definingBlocks) {

336 SmallPtrSet<Block *, 16> liveIn;

337

338

339

340

// Blocks still to be added to `liveIn` and expanded.
341 SmallVector<Block *> liveInWorkList;

342

343

344

345

// Each block containing a user of the slot pointer is scanned only once.
346 SmallPtrSet<Block *, 16> visited;

347 for (Operation *user : slot.ptr.getUsers()) {

348 if (!visited.insert(user->getBlock()).second)

349 continue;

350

// Scan the block top-down: the first memory operation touching the slot
// decides whether the block is a seed of the worklist.
351 for (Operation &op : user->getBlock()->getOperations()) {

352 if (auto memOp = dyn_cast(op)) {

353

354

// A load seen first: the block is queued.
355 if (memOp.loadsFrom(slot)) {

356 liveInWorkList.push_back(user->getBlock());

357 break;

358 }

359

360

361

// A store seen first: stop scanning, the block is not queued.
362 if (memOp.storesTo(slot))

363 break;

364 }

365 }

366 }

367

368

369

// Worklist expansion; blocks already in `liveIn` are not processed again.
370 while (!liveInWorkList.empty()) {

371 Block *liveInBlock = liveInWorkList.pop_back_val();

372

373 if (!liveIn.insert(liveInBlock).second)

374 continue;

375

376

377

378

379

380

381

// NOTE(review): the loop header introducing `pred` (original line 382) is
// missing; presumably it iterates over the predecessors of `liveInBlock`.
// Blocks in `definingBlocks` are not queued.
383 if (!definingBlocks.contains(pred))

384 liveInWorkList.push_back(pred);

385 }

386

387 return liveIn;

388}

389

// Computes the merge-point blocks for the slot using an IDF calculator fed
// with the blocks that store to the slot and the slot's live-in blocks, and
// inserts the result into `mergePoints`.
// NOTE(review): the parameter list, the early-exit condition, and the
// declarations of `definingBlocks`, `idfCalculator`, `liveIn` and
// `mergePointsVec` are missing from this listing.
391void MemorySlotPromotionAnalyzer::computeMergePoints(

394 return;

395

397

// A block is a defining block when it contains a user that stores to the slot.
400 if (auto storeOp = dyn_cast(user))

401 if (storeOp.storesTo(slot))

402 definingBlocks.insert(user->getBlock());

403

404 idfCalculator.setDefiningBlocks(definingBlocks);

405

407 idfCalculator.setLiveInBlocks(liveIn);

408

410 idfCalculator.calculate(mergePointsVec);

411

// Copy the calculator's output vector into the caller-provided set.
412 mergePoints.insert_range(mergePointsVec);

413}

414

// Returns true when, for every merge point, the terminator of each
// predecessor block passes the `isa` check below; returns false at the first
// predecessor terminator that does not.
// NOTE(review): the template argument of `isa` was stripped from this
// listing, so the exact interface being required cannot be read here.
415bool MemorySlotPromotionAnalyzer::areMergePointsUsable(

416 SmallPtrSetImpl<Block *> &mergePoints) {

417 for (Block *mergePoint : mergePoints)

418 for (Block *pred : mergePoint->getPredecessors())

419 if (!isa(pred->getTerminator()))

420 return false;

421

422 return true;

423}

424

// Runs the promotion analysis in three steps and returns the populated
// MemorySlotPromotionInfo, or an empty optional when the blocking uses
// cannot be computed or the merge points are not usable.
425std::optional

426MemorySlotPromotionAnalyzer::computeInfo() {

427 MemorySlotPromotionInfo info;

428

429

430

431

432

// Step 1: collect blocking uses; failure means the slot is not promotable.
433 if (failed(computeBlockingUses(info.userToBlockingUses)))

434 return {};

435

436

437

438

// Step 2: determine the merge-point blocks.
439 computeMergePoints(info.mergePoints);

440

441

442

443

// Step 3: give up if a merge point has a predecessor terminator that fails
// the usability check.
444 if (!areMergePointsUsable(info.mergePoints))

445 return {};

446

447 return info;

448}

449

// Walks a snapshot of the operations of `block` in order, tracking the
// current definition of the slot. Records the definition reaching each
// memory operation that has blocking uses, and returns the definition that
// is current after the last operation.
450Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,

451 Value reachingDef) {

// Operations are first copied into a vector and the vector is iterated.
// NOTE(review): the loop header over the block's operations (original line
// 453) is missing from this listing.
452 SmallVector<Operation *> blockOps;

454 blockOps.push_back(&op);

455 for (Operation *op : blockOps) {

456 if (auto memOp = dyn_cast(op)) {

// Remember the definition seen by this operation *before* its own store (if
// any) takes effect.
457 if (info.userToBlockingUses.contains(memOp))

458 reachingDefs.insert({memOp, reachingDef});

459

// A store provides the new current definition; it is also recorded in
// `replacedValuesMap`, keyed by the storing operation.
// NOTE(review): original line 461 is missing from this listing.
460 if (memOp.storesTo(slot)) {

462 Value stored = memOp.getStored(slot, builder, reachingDef, dataLayout);

463 assert(stored && "a memory operation storing to a slot must provide a "

464 "new definition of the slot");

465 reachingDef = stored;

466 replacedValuesMap[memOp] = stored;

467 }

468 }

469 }

470

471 return reachingDef;

472}

473

// Propagates reaching definitions of the slot through `region`, starting
// from `reachingDef` at the entry block. Blocks are visited through an
// explicit stack of dominator-tree nodes; each child node inherits the
// definition current at the end of its parent node's block.
// NOTE(review): several lines are missing from this listing (the condition
// guarding the early return, the definition of `domTree`, the creation of the
// merge-point block argument and the statements at original lines 505-506).
474void MemorySlotPromoter::computeReachingDefInRegion(Region *region,

475 Value reachingDef) {

476 assert(reachingDef && "expected an initial reaching def to be provided");

// Shortcut path: only the entry block is processed.
478 computeReachingDefInBlock(&region->front(), reachingDef);

479 return;

480 }

481

// One pending visit: a dominator-tree node plus the definition live on entry
// to its block.
482 struct DfsJob {

483 llvm::DomTreeNodeBase *block;

484 Value reachingDef;

485 };

486

487 SmallVector dfsStack;

488

490

// Start from the dominator-tree node of the region's entry block.
491 dfsStack.emplace_back(

492 {domTree.getNode(&region->front()), reachingDef});

493

494 while (!dfsStack.empty()) {

495 DfsJob job = dfsStack.pop_back_val();

496 Block *block = job.block->getBlock();

497

// At a merge point the incoming definition is replaced by a block argument,
// which the allocator is notified about.
498 if (info.mergePoints.contains(block)) {

499 BlockArgument blockArgument =

502 allocator.handleBlockArgument(slot, blockArgument, builder);

503 job.reachingDef = blockArgument;

504

507 }

508

509 job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);

510 assert(job.reachingDef);

511

// Forward the definition live at the end of this block to every successor
// that is a merge point, as an extra successor operand.
512 if (auto terminator = dyn_cast(block->getTerminator())) {

513 for (BlockOperand &blockOperand : terminator->getBlockOperands()) {

514 if (info.mergePoints.contains(blockOperand.get())) {

515 terminator.getSuccessorOperands(blockOperand.getOperandNumber())

516 .append(job.reachingDef);

517 }

518 }

519 }

520

// Children in the dominator tree start from this block's final definition.
521 for (auto *child : job.block->children())

522 dfsStack.emplace_back({child, job.reachingDef});

523 }

524}

525

526

529 auto [it, inserted] = blockIndexCache.try_emplace(region);

531 return it->second;

532

535 for (auto [index, block] : llvm::enumerate(topologicalOrder))

536 blockIndices[block] = index;

537 return blockIndices;

538}

539

540

541

542

544 BlockIndexCache &blockIndexCache) {

545

546

549

550

551

552

554 size_t lhsBlockIndex = topoBlockIndices.at(lhs->getBlock());

555 size_t rhsBlockIndex = topoBlockIndices.at(rhs->getBlock());

556 if (lhsBlockIndex == rhsBlockIndex)

557 return lhs->isBeforeInBlock(rhs);

558 return lhsBlockIndex < rhsBlockIndex;

559 });

560}

561

// Asks every operation recorded in `info.userToBlockingUses` to remove its
// blocking uses, then erases the operations that requested deletion.
// NOTE(review): the call that sorts `usersToRemoveUses` (original line 567),
// the builder-positioning statements (lines 583, 595, 603) and the opening of
// the final assertion (line 610) are missing from this listing.
562void MemorySlotPromoter::removeBlockingUses() {

// Snapshot of the map's keys (the user operations).
563 llvm::SmallVector<Operation *> usersToRemoveUses(

564 llvm::make_first_range(info.userToBlockingUses));

565

566

568 blockIndexCache);

569

// Erasure is deferred until every operation has been processed.
570 llvm::SmallVector<Operation *> toErase;

571

// (storing operation, stored value) pairs later handed to visitReplacedValues.
572 llvm::SmallVector<std::pair<Operation *, Value>> replacedValuesList;

573

// Operations that asked to be shown the replaced values.
574 llvm::SmallVector toVisit;

// Users are processed in reverse of the order of `usersToRemoveUses`.
575 for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {

// Memory operations receive the slot and the definition reaching them.
576 if (auto toPromoteMemOp = dyn_cast(toPromote)) {

577 Value reachingDef = reachingDefs.lookup(toPromoteMemOp);

578

579

// No recorded reaching definition: fall back to the slot's default value.
580 if (!reachingDef)

581 reachingDef = getOrCreateDefaultValue();

582

584 if (toPromoteMemOp.removeBlockingUses(

585 slot, info.userToBlockingUses[toPromote], builder, reachingDef,

586 dataLayout) == DeletionKind::Delete)

587 toErase.push_back(toPromote);

// Only stores with a non-null entry in `replacedValuesMap` are reported.
588 if (toPromoteMemOp.storesTo(slot))

589 if (Value replacedValue = replacedValuesMap[toPromoteMemOp])

590 replacedValuesList.push_back({toPromoteMemOp, replacedValue});

591 continue;

592 }

593

// Every other user is cast unconditionally to the second interface.
594 auto toPromoteBasic = cast(toPromote);

596 if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],

597 builder) == DeletionKind::Delete)

598 toErase.push_back(toPromote);

599 if (toPromoteBasic.requiresReplacedValues())

600 toVisit.push_back(toPromoteBasic);

601 }

// Second pass: hand the collected replaced values to the operations that
// requested them.
602 for (PromotableOpInterface op : toVisit) {

604 op.visitReplacedValues(replacedValuesList, builder);

605 }

606

607 for (Operation *toEraseOp : toErase)

608 toEraseOp->erase();

609

611 "after promotion, the slot pointer should not be used anymore");

612}

613

// Drives the promotion of the slot: removes the blocking uses, completes the
// successor operands feeding merge points, logs the promotion and finally
// returns the result of the allocator's handlePromotionComplete hook.
// NOTE(review): the start of the first statement (original line 616), the
// condition guarding the `append` below (lines 629-631) and the statements at
// lines 638-639 are missing from this listing.
614std::optional

615MemorySlotPromoter::promoteSlot() {

617 getOrCreateDefaultValue());

618

619

620 removeBlockingUses();

621

622

623

// For each use of a merge-point block by a terminator, the default value is
// appended to the corresponding successor operands.
// NOTE(review): presumably only when the operand list is shorter than the
// merge point's argument list — the guarding condition is not visible here.
624 for (Block *mergePoint : info.mergePoints) {

625 for (BlockOperand &use : mergePoint->getUses()) {

626 auto user = cast(use.getOwner());

627 SuccessorOperands succOperands =

628 user.getSuccessorOperands(use.getOperandNumber());

632 succOperands.append(getOrCreateDefaultValue());

633 }

634 }

635

636 LDBG() << "Promoted memory slot: " << slot.ptr;

637

640

// Passes the `defaultValue` member directly, which may still be null if no
// default value was ever requested.
641 return allocator.handlePromotionComplete(slot, defaultValue, builder);

642}

643

648 bool promotedAny = false;

649

650

651

652

653 BlockIndexCache blockIndexCache;

654

656

658 newWorkList.reserve(workList.size());

659 while (true) {

660 bool changesInThisRound = false;

661 for (PromotableAllocationOpInterface allocator : workList) {

662 bool changedAllocator = false;

663 for (MemorySlot slot : allocator.getPromotableSlots()) {

665 continue;

666

667 MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);

668 std::optional info = analyzer.computeInfo();

669 if (info) {

670 std::optional newAllocator =

671 MemorySlotPromoter(slot, allocator, builder, dominance,

672 dataLayout, std::move(*info), statistics,

673 blockIndexCache)

674 .promoteSlot();

675 changedAllocator = true;

676

677

678 if (newAllocator)

679 newWorkList.push_back(*newAllocator);

680

681

682

683 break;

684 }

685 }

686 if (!changedAllocator)

687 newWorkList.push_back(allocator);

688 changesInThisRound |= changedAllocator;

689 }

690 if (!changesInThisRound)

691 break;

692 promotedAny = true;

693

694

695

696 workList.swap(newWorkList);

697 newWorkList.clear();

698 }

699

700 return success(promotedAny);

701}

702

703namespace {

704

707

// Pass entry point: for each region of the operation the pass runs on,
// collects every PromotableAllocationOpInterface operation nested in the
// region and attempts to promote the collected allocators.
// NOTE(review): several lines are missing from this listing (original lines
// 713, 720, 732, 734 and 736), including the region-skipping condition, the
// start of the promotion call and the statement guarding
// markAllAnalysesPreserved().
708 void runOnOperation() override {

709 Operation *scopeOp = getOperation();

710

// Statistics object built from the addresses of the pass's two counters.
711 Mem2RegStatistics statistics{&promotedAmount, &newBlockArgumentAmount};

712

714

// Data layout is resolved at (or above) the scope operation; template
// arguments of getAnalysis were stripped from this listing.
715 auto &dataLayoutAnalysis = getAnalysis();

716 const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);

717 auto &dominance = getAnalysis();

718

719 for (Region &region : scopeOp->getRegions()) {

721 continue;

722

// Builder initially positioned at the beginning of the region's first block.
723 OpBuilder builder(&region.front(), region.front().begin());

724

725 SmallVector allocators;

726

// Gather all allocation operations nested anywhere in the region.
727 region.walk([&](PromotableAllocationOpInterface allocator) {

728 allocators.emplace_back(allocator);

729 });

730

731

733 dominance, statistics)))

735 }

737 markAllAnalysesPreserved();

738 }

739};

740

741}

… if copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming copies and outgoing copies should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method, …

static void dominanceSort(SmallVector< Operation * > &ops, Region &region, BlockIndexCache &blockIndexCache)

Sorts ops according to dominance.

Definition Mem2Reg.cpp:543

static const DenseMap< Block *, size_t > & getOrCreateBlockIndices(BlockIndexCache &blockIndexCache, Region *region)

Gets or creates a block index mapping for region.

Definition Mem2Reg.cpp:528

llvm::IDFCalculatorBase< Block, false > IDFCalculator

Definition Mem2Reg.cpp:390

This class represents an argument of a Block.

Block represents an ordered list of Operations.

unsigned getNumArguments()

iterator_range< pred_iterator > getPredecessors()

Region * getParent() const

Provide a 'getParent' method for ilist_node_with_parent methods.

OpListType & getOperations()

Operation * getTerminator()

Get the terminator operation of this block.

BlockArgument addArgument(Type type, Location loc)

Add one value to the argument list.

The main mechanism for performing data layout queries.

A class for computing basic dominance information.

use_range getUses() const

Returns a range of all uses, which is useful for iterating over all uses.

This class helps build Operations.

void setInsertionPointToStart(Block *block)

Sets the insertion point to the start of the specified block.

void setInsertionPointAfter(Operation *op)

Sets the insertion point to the node after the specified operation, which will cause subsequent inser...

Operation is the basic unit of execution within MLIR.

MutableArrayRef< Region > getRegions()

Returns the regions held by this operation.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

unsigned getRegionNumber()

Return the number of this region in the parent operation.

Operation * getParentOp()

Return the parent operation this region is attached to.

BlockListType & getBlocks()

bool hasOneBlock()

Return true if this region has exactly one block.

RetT walk(FnT &&callback)

Walk all nested operations, blocks or regions (including this region), depending on the type of callb...

void append(ValueRange valueRange)

Add new operands that are forwarded to the successor.

unsigned size() const

Returns the amount of operands passed to the successor.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

bool use_empty() const

Returns true if this value has no uses.

use_range getUses() const

Returns a range of all uses, which is useful for iterating over all uses.

Block * getParentBlock()

Return the Block in which this Value is defined.

user_range getUsers() const

Location getLoc() const

Return the location of this value.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

Region * getParentRegion()

Return the Region in which this Value is defined.

DomTree & getDomTree(Region *region) const

Definition Mem2Reg.cpp:831

Include the generated interface declarations.

const FrozenRewritePatternSet GreedyRewriteConfig bool * changed

SetVector< Block * > getBlocksSortedByDominance(Region &region)

Gets a list of blocks that is sorted according to dominance.

llvm::SetVector< T, Vector, Set, N > SetVector

llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap

LogicalResult tryToPromoteMemorySlots(ArrayRef< PromotableAllocationOpInterface > allocators, OpBuilder &builder, const DataLayout &dataLayout, DominanceInfo &dominance, Mem2RegStatistics statistics={})

Attempts to promote the memory slots of the provided allocators.

Definition Mem2Reg.cpp:644

void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})

Fills forwardSlice with the computed forward slice (i.e.

Statistics collected while applying mem2reg.

llvm::Statistic * promotedAmount

Total amount of memory slots promoted.

llvm::Statistic * newBlockArgumentAmount

Total amount of new block arguments inserted in blocks.

Represents a slot in memory.

Value ptr

Pointer to the memory slot, used by operations to refer to it.

Type elemType

Type of the value contained in the slot.