MLIR: lib/Dialect/GPU/Transforms/EliminateBarriers.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

27 #include "llvm/ADT/TypeSwitch.h"

28 #include "llvm/Support/Debug.h"

29

30 namespace mlir {

31 #define GEN_PASS_DEF_GPUELIMINATEBARRIERS

32 #include "mlir/Dialect/GPU/Transforms/Passes.h.inc"

33 }

34

35 using namespace mlir;

37

38 #define DEBUG_TYPE "gpu-erase-barriers"

39 #define DEBUG_TYPE_ALIAS "gpu-erase-barries-alias"

40

41 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

42 #define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")

43

44

45

46

47

48

50 if (op->hasAttr("__parallel_region_boundary_for_test"))

51 return true;

52

53 return isa<GPUFuncOp, LaunchOp>(op);

54 }

55

56

57

59

60

61

62

64 return isa<FunctionOpInterface, scf::IfOp, memref::AllocaScopeOp>(op);

65 }

66

67

68

69

71 return isa_and_nonnull<memref::AllocOp, memref::AllocaOp>(op);

72 }

73

74

75

78 effects.emplace_back(MemoryEffects::Effect::getMemoryEffects::Read());

79 effects.emplace_back(MemoryEffects::Effect::getMemoryEffects::Write());

80 effects.emplace_back(MemoryEffects::Effect::getMemoryEffects::Allocate());

81 effects.emplace_back(MemoryEffects::Effect::getMemoryEffects::Free());

82 }

83

84

85

86

87

88 static bool

91 bool ignoreBarriers = true) {

92

93

94 if (ignoreBarriers && isa(op))

95 return true;

96

97

98

99

100

101 if (auto iface = dyn_cast(op)) {

103 iface.getEffects(localEffects);

104 llvm::append_range(effects, localEffects);

105 return true;

106 }

108 for (auto &region : op->getRegions()) {

109 for (auto &block : region) {

110 for (auto &innerOp : block)

111 if (collectEffects(&innerOp, effects, ignoreBarriers))

112 return false;

113 }

114 }

115 return true;

116 }

117

118

119

121 return false;

122 }

123

124

125

126 static bool

129 bool stopAtBarrier) {

131 return true;

132

133 for (Operation *it = op->getPrevNode(); it != nullptr;

134 it = it->getPrevNode()) {

135 if (isa(it)) {

136 if (stopAtBarrier)

137 return true;

138 continue;

139 }

140

142 return false;

143 }

144 return true;

145 }

146

147

148

149

150

151

152

153 static bool

156 bool stopAtBarrier) {

158 return true;

159

160

162 if (region && !llvm::hasSingleElement(region->getBlocks())) {

164 return false;

165 }

166

167

169

170

172 return true;

173

175

178 return false;

179

180

181

182

183

184

185

186

187

188

189

190

191

192

193

195

197 true);

198 }

199

200

201

202 bool conservative = false;

205 if (conservative)

208 conservative = true;

210 }

212 });

213

214 return !conservative;

215 }

216

217

218

219 static bool

222 bool stopAtBarrier) {

224 return true;

225

226 for (Operation *it = op->getNextNode(); it != nullptr;

227 it = it->getNextNode()) {

228 if (isa(it)) {

229 if (stopAtBarrier)

230 return true;

231 continue;

232 }

234 return false;

235 }

236 return true;

237 }

238

239

240

241

242

243

244

245 static bool

248 bool stopAtBarrier) {

250 return true;

251

252

254 if (region && !llvm::hasSingleElement(region->getBlocks())) {

256 return false;

257 }

258

259

261

263

265 return true;

266

267

268

271 return false;

272

273

274

275

276

277

278

279

280

281

282

283

284

285

286

289 return true;

290

293 true) &&

294 exact;

295 }

296

297

298

299 bool conservative = false;

302 if (conservative)

305 conservative = true;

307 }

309 });

310

311 return !conservative;

312 }

313

314

316 while (true) {

318 if (!definingOp)

319 break;

320

321 bool shouldContinue =

323 .Case<memref::CastOp, memref::SubViewOp, memref::ViewOp>(

324 [&](auto op) {

325 v = op.getSource();

326 return true;

327 })

328 .Casememref::TransposeOp([&](auto op) {

329 v = op.getIn();

330 return true;

331 })

332 .Case<memref::CollapseShapeOp, memref::ExpandShapeOp>([&](auto op) {

333 v = op.getSrc();

334 return true;

335 })

336 .Default([](Operation *) { return false; });

337 if (!shouldContinue)

338 break;

339 }

340 return v;

341 }

342

343

345 auto arg = dyn_cast(v);

346 return arg && isa(arg.getOwner()->getParentOp());

347 }

348

349

350

351

354 .Case(

355 [](ViewLikeOpInterface viewLike) { return viewLike.getViewSource(); })

356 .Case([](CastOpInterface castLike) { return castLike->getOperand(0); })

357 .Case([](memref::TransposeOp transpose) { return transpose.getIn(); })

358 .Case<memref::ExpandShapeOp, memref::CollapseShapeOp>(

359 [](auto op) { return op.getSrc(); })

361 }

362

363

364

365

368

369

370 .Case<memref::StoreOp, vector::TransferWriteOp>(

371 [&](auto op) { return op.getValue() == v; })

372 .Case<vector::StoreOp, vector::MaskedStoreOp>(

373 [&](auto op) { return op.getValueToStore() == v; })

374

375 .Case([](memref::DeallocOp) { return false; })

376

377 .Default([](Operation *) { return std::nullopt; });

378 }

379

380

381

382

383

386 while (!todo.empty()) {

387 Value v = todo.pop_back_val();

389

390 auto iface = dyn_cast(user);

391 if (iface) {

393 iface.getEffects(effects);

394 if (llvm::all_of(effects,

396 return isaMemoryEffects::Read(effect.getEffect());

397 })) {

398 continue;

399 }

400 }

401

402

403

405 todo.push_back(v);

406 continue;

407 }

408

410 if (!knownCaptureStatus || *knownCaptureStatus)

411 return true;

412 }

413 }

414

415 return false;

416 }

417

418

419

420

421

422

423

424

425

426

429 DBGS_ALIAS() << "checking aliasing between ";

433 });

434

436 second = getBase(second);

437

443 });

444

445

446

447 if (first == second) {

449 return true;

450 }

451

452

453 if (auto globFirst = first.getDefiningOpmemref::GetGlobalOp()) {

454 if (auto globSecond = second.getDefiningOpmemref::GetGlobalOp()) {

455 return globFirst.getNameAttr() == globSecond.getNameAttr();

456 }

457 }

458

459

460 auto isNoaliasFuncArgument = [](Value value) {

461 auto bbArg = dyn_cast(value);

462 if (!bbArg)

463 return false;

464 auto iface = dyn_cast(bbArg.getOwner()->getParentOp());

465 if (!iface)

466 return false;

467

468 return iface.getArgAttr(bbArg.getArgNumber(), "llvm.noalias") != nullptr;

469 };

470 if (isNoaliasFuncArgument(first) && isNoaliasFuncArgument(second))

471 return false;

472

475 bool isGlobal[] = {first.getDefiningOpmemref::GetGlobalOp() != nullptr,

476 second.getDefiningOpmemref::GetGlobalOp() != nullptr};

477

478

479

480

481 if ((isDistinct[0] || isGlobal[0]) && (isDistinct[1] || isGlobal[1]))

482 return false;

483

485

486

487 if ((isDistinct[0] && isArg[1]) || (isDistinct[1] && isArg[0]))

488 return false;

489

490

492 return false;

494 return false;

495

496

498 return true;

499 }

500

501

502

503

507 }

508 return true;

509 }

510

511

512

513

514

518 return false;

523 }

524 return true;

525 }

526

527

528

529

530

531

532

533

534

535 static bool

540

541 if (mayAlias(before, after))

542 continue;

543

544

545 if (isaMemoryEffects::Read(before.getEffect()) &&

546 isaMemoryEffects::Read(after.getEffect())) {

547 continue;

548 }

549

550

551

552

553

554 if (isaMemoryEffects::Allocate(before.getEffect()) ||

555 isaMemoryEffects::Allocate(after.getEffect())) {

556 continue;

557 }

558

559

560

561

562

563

564

565

566

567 if (isaMemoryEffects::Free(before.getEffect()))

568 continue;

569

570

571 LLVM_DEBUG(

572 DBGS() << "found a conflict between (before): " << before.getValue()

573 << " read:" << isaMemoryEffects::Read(before.getEffect())

574 << " write:" << isaMemoryEffects::Write(before.getEffect())

575 << " alloc:"

576 << isaMemoryEffects::Allocate(before.getEffect()) << " free:"

577 << isaMemoryEffects::Free(before.getEffect()) << "\n");

578 LLVM_DEBUG(

579 DBGS() << "and (after): " << after.getValue()

580 << " read:" << isaMemoryEffects::Read(after.getEffect())

581 << " write:" << isaMemoryEffects::Write(after.getEffect())

582 << " alloc:" << isaMemoryEffects::Allocate(after.getEffect())

583 << " free:" << isaMemoryEffects::Free(after.getEffect())

584 << "\n");

585 return true;

586 }

587 }

588

589 return false;

590 }

591

592 namespace {

593 class BarrierElimination final : public OpRewritePattern {

594 public:

596

597 LogicalResult matchAndRewrite(BarrierOp barrier,

599 LLVM_DEBUG(DBGS() << "checking the necessity of: " << barrier << " "

600 << barrier.getLoc() << "\n");

601

603 getEffectsBefore(barrier, beforeEffects, true);

604

606 getEffectsAfter(barrier, afterEffects, true);

607

609 LLVM_DEBUG(DBGS() << "the surrounding barriers are sufficient, removing "

610 << barrier << "\n");

611 rewriter.eraseOp(barrier);

612 return success();

613 }

614

615 LLVM_DEBUG(DBGS() << "barrier is necessary: " << barrier << " "

616 << barrier.getLoc() << "\n");

617 return failure();

618 }

619 };

620

621 class GpuEliminateBarriersPass

622 : public impl::GpuEliminateBarriersBase {

623 void runOnOperation() override {

624 auto funcOp = getOperation();

628 return signalPassFailure();

629 }

630 }

631 };

632

633 }

634

637 }

static bool isSequentialLoopLike(Operation *op)

Returns true if the op behaves like a sequential loop, e.g., the control flow "wraps around" from the...

static bool isFunctionArgument(Value v)

Returns true if the value is defined as a function argument.

static Value getBase(Value v)

Looks through known "view-like" ops to find the base memref.

static Value propagatesCapture(Operation *op)

Returns the operand that the operation "propagates" through it for capture purposes.

static bool hasSingleExecutionBody(Operation *op)

Returns true if the regions of the op are guaranteed to be executed at most once.

static bool producesDistinctBase(Operation *op)

Returns true if the operation is known to produce a pointer-like object distinct from any other objec...

static bool mayAlias(Value first, Value second)

Returns true if two values may be referencing aliasing memory.

static bool getEffectsBeforeInBlock(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects, bool stopAtBarrier)

Get all effects before the given operation caused by other operations in the same block.

static bool isParallelRegionBoundary(Operation *op)

Returns true if the op defines the parallel region that is subject to barrier synchronization.

static bool getEffectsAfter(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects, bool stopAtBarrier)

Collects memory effects from operations that may be executed after op in a trivial structured control...

static std::optional< bool > getKnownCapturingStatus(Operation *op, Value v)

Returns true if the given operation is known to capture the given value, false if it is known not to ...

static bool collectEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects, bool ignoreBarriers=true)

Collect the memory effects of the given op in 'effects'.

static bool haveConflictingEffects(ArrayRef< MemoryEffects::EffectInstance > beforeEffects, ArrayRef< MemoryEffects::EffectInstance > afterEffects)

Returns true if any of the "before" effect instances has a conflict with any "after" instance for the...

static bool getEffectsAfterInBlock(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects, bool stopAtBarrier)

Get all effects after the given operation caused by other operations in the same block.

static void addAllValuelessEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

Populates effects with all memory effects without associating them to a specific value.

static bool maybeCaptured(Value v)

Returns true if the value may be captured by any of its users, i.e., if the user may be storing this ...

static bool getEffectsBefore(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects, bool stopAtBarrier)

Collects memory effects from operations that may be executed before op in a trivial structured contro...

static MLIRContext * getContext(OpFoldResult val)

Region * getParent() const

Provide a 'getParent' method for ilist_node_with_parent methods.

Operation * getTerminator()

Get the terminator operation of this block.

This trait indicates that the memory effects of an operation includes the effects of operations neste...

This class provides the API for ops that are known to be isolated from above.

Operation is the basic unit of execution within MLIR.

bool hasTrait()

Returns true if the operation was registered with a particular trait, e.g.

bool hasAttr(StringAttr name)

Return true if the operation has an attribute with the provided name, false otherwise.

std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)

Walk the operation by calling the callback for each nested operation (including this one),...

Operation * getParentOp()

Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...

Block * getBlock()

Returns the operation block that contains this operation.

MutableArrayRef< Region > getRegions()

Returns the regions held by this operation.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

This class contains a list of basic blocks and a link to the parent operation it is attached to.

BlockListType & getBlocks()

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

This class represents a specific instance of an effect.

Resource * getResource() const

Return the resource that the effect applies to.

EffectT * getEffect() const

Return the effect being applied.

Value getValue() const

Return the value the effect is applied on, or nullptr if there isn't a known value being affected.

TypeID getResourceID() const

Return the unique identifier for the base resource class.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

user_range getUsers() const

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

static WalkResult advance()

static WalkResult interrupt()

static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)

Include the generated interface declarations.

LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)

Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...

const FrozenRewritePatternSet & patterns

void populateGpuEliminateBarriersPatterns(RewritePatternSet &patterns)

Erase barriers that do not enforce conflicting memory side effects.

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...