MLIR: lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14 #include

15

17

30 #include "llvm/ADT/SetVector.h"

31 #include "llvm/Support/Debug.h"

32 #include

33

34 namespace mlir {

35 #define GEN_PASS_DEF_ASYNCTOASYNCRUNTIMEPASS

36 #define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIMEPASS

37 #include "mlir/Dialect/Async/Passes.h.inc"

38 }

39

40 using namespace mlir;

42

// Debug tag for the LLVM_DEBUG output emitted by this file.
43 #define DEBUG_TYPE "async-to-async-runtime"
44
// Base symbol name for functions outlined from async.execute bodies. The
// outlined function is created with this name and then inserted into a
// SymbolTable, which renames it as necessary to avoid collisions.
45 static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";

46

47 namespace {

48

49 class AsyncToAsyncRuntimePass

50 : public impl::AsyncToAsyncRuntimePassBase {

51 public:

52 AsyncToAsyncRuntimePass() = default;

53 void runOnOperation() override;

54 };

55

56 }

57

58 namespace {

59

60 class AsyncFuncToAsyncRuntimePass

61 : public impl::AsyncFuncToAsyncRuntimePassBase<

62 AsyncFuncToAsyncRuntimePass> {

63 public:

64 AsyncFuncToAsyncRuntimePass() = default;

65 void runOnOperation() override;

66 };

67

68 }

69

70

71

72

73

74

75 namespace {

76 struct CoroMachinery {

77 func::FuncOp func;

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92 std::optional asyncToken;

94

95 Value coroHandle;

96 Block *entry;

97 std::optional<Block *> setError;

98 Block *cleanup;

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120 Block *cleanupForDestroy;

121 Block *suspend;

122 };

123 }

124

126 std::shared_ptr<llvm::DenseMap<func::FuncOp, CoroMachinery>>;

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

159

160

161

162

163

164

165

166

167

168

169

170

171

173 assert(!func.getBlocks().empty() && "Function must have an entry block");

174

176 Block *entryBlock = &func.getBlocks().front();

177 Block *originalEntryBlock =

180

181

182

183

184

185

186

187 bool isStateful = isa(func.getResultTypes().front());

188

189 std::optional retToken;

190 if (isStateful)

191 retToken.emplace(builder.create(TokenType::get(ctx)));

192

195 isStateful ? func.getResultTypes().drop_front() : func.getResultTypes();

196 for (auto resType : resValueTypes)

197 retValues.emplace_back(

198 builder.create(resType).getResult());

199

200

201

202

203 auto coroIdOp = builder.create(CoroIdType::get(ctx));

204 auto coroHdlOp =

206 builder.createcf::BranchOp(originalEntryBlock);

207

208 Block *cleanupBlock = func.addBlock();

209 Block *cleanupBlockForDestroy = func.addBlock();

210 Block *suspendBlock = func.addBlock();

211

212

213

214

215 auto buildCleanupBlock = [&](Block *cb) {

216 builder.setInsertionPointToStart(cb);

217 builder.create(coroIdOp.getId(), coroHdlOp.getHandle());

218

219

220 builder.createcf::BranchOp(suspendBlock);

221 };

222 buildCleanupBlock(cleanupBlock);

223 buildCleanupBlock(cleanupBlockForDestroy);

224

225

226

227

228

229 builder.setInsertionPointToStart(suspendBlock);

230

231

232 builder.create(coroHdlOp.getHandle());

233

234

235

237 if (retToken)

238 ret.push_back(*retToken);

239 llvm::append_range(ret, retValues);

240 builder.createfunc::ReturnOp(ret);

241

242

243

244

245

246

247 func->setAttr("passthrough", builder.getArrayAttr(

249

250 CoroMachinery machinery;

251 machinery.func = func;

252 machinery.asyncToken = retToken;

253 machinery.returnValues = retValues;

254 machinery.coroHandle = coroHdlOp.getHandle();

255 machinery.entry = entryBlock;

256 machinery.setError = std::nullopt;

257 machinery.cleanup = cleanupBlock;

258 machinery.cleanupForDestroy = cleanupBlockForDestroy;

259 machinery.suspend = suspendBlock;

260 return machinery;

261 }

262

263

264

266 if (coro.setError)

267 return *coro.setError;

268

269 coro.setError = coro.func.addBlock();

270 (*coro.setError)->moveBefore(coro.cleanup);

271

272 auto builder =

274

275

276 if (coro.asyncToken)

277 builder.create(*coro.asyncToken);

278

279 for (Value retValue : coro.returnValues)

280 builder.create(retValue);

281

282

283 builder.createcf::BranchOp(coro.cleanup);

284

285 return *coro.setError;

286 }

287

288

289

290

291

292

293

294

295

296 static std::pair<func::FuncOp, CoroMachinery>

298 ModuleOp module = execute->getParentOfType();

299

301 Location loc = execute.getLoc();

302

303

304

306

307

309 execute.getDependencies());

310 functionInputs.insert_range(execute.getBodyOperands());

312

313

314 auto typesRange = llvm::map_range(

315 functionInputs, [](Value value) { return value.getType(); });

317 auto outputTypes = execute.getResultTypes();

318

321

322

323

324 func::FuncOp func =

325 func::FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);

326 symbolTable.insert(func);

327

330

331

332 {

333 size_t numDependencies = execute.getDependencies().size();

334 size_t numOperands = execute.getBodyOperands().size();

335

336

337 for (size_t i = 0; i < numDependencies; ++i)

338 builder.create(func.getArgument(i));

339

340

342 for (size_t i = 0; i < numOperands; ++i) {

343 Value operand = func.getArgument(numDependencies + i);

344 unwrappedOperands[i] = builder.create(loc, operand).getResult();

345 }

346

347

348

350 valueMapping.map(functionInputs, func.getArguments());

351 valueMapping.map(execute.getBodyRegion().getArguments(), unwrappedOperands);

352

353

354

355 for (Operation &op : execute.getBodyRegion().getOps())

356 builder.clone(op, valueMapping);

357 }

358

359

361

362

363

364

365 {

366 cf::BranchOp branch = castcf::BranchOp(coro.entry->getTerminator());

367 builder.setInsertionPointToEnd(coro.entry);

368

369

370 auto coroSaveOp =

371 builder.create(CoroStateType::get(ctx), coro.coroHandle);

372

373

374

375 builder.create(coro.coroHandle);

376

377

378 builder.create(coroSaveOp.getState(), coro.suspend,

379 branch.getDest(), coro.cleanupForDestroy);

380

381 branch.erase();

382 }

383

384

385 {

387 auto callOutlinedFunc = callBuilder.createfunc::CallOp(

388 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());

389 execute.replaceAllUsesWith(callOutlinedFunc.getResults());

390 execute.erase();

391 }

392

393 return {func, coro};

394 }

395

396

397

398

399

// Anonymous namespace holding the conversion pattern for CreateGroupOp.
400 namespace {
// Conversion pattern for CreateGroupOp.
// NOTE(review): this listing lost the pattern's template argument, its
// constructor (line 403) and the head of the rewrite call (lines 407-408).
// The visible tail shows the op being replaced (presumably via
// replaceOpWithNewOp) by a new op whose result type is GroupType and whose
// operands are the converted adaptor operands.
401 class CreateGroupOpLowering : public OpConversionPattern {
402 public:
404
405 LogicalResult
406 matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
409 op, GroupType::get(op->getContext()), adaptor.getOperands());
// No failure path is visible in this listing.
410 return success();
411 }
412 };
413 }

414

415

416

417

418

419 namespace {

421 public:

423

424 LogicalResult

425 matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,

428 op, rewriter.getIndexType(), adaptor.getOperands());

429 return success();

430 }

431 };

432 }

433

434

435

436

437

438

439 namespace {

440

441

442

443

444

446 public:

449

450 LogicalResult

451 matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,

454

455 auto newFuncOp =

456 rewriter.createfunc::FuncOp(loc, op.getName(), op.getFunctionType());

457

460

461 for (const auto &namedAttr : op->getAttrs()) {

463 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());

464 }

465

467 newFuncOp.end());

468

470 (*coros)[newFuncOp] = coro;

471

472

474 return success();

475 }

476

477 private:

479 };

480

481

482

483

484

486 public:

489

490 LogicalResult

491 matchAndRewrite(async::CallOp op, OpAdaptor adaptor,

494 op, op.getCallee(), op.getResultTypes(), op.getOperands());

495 return success();

496 }

497 };

498

499

500

501

502

// Conversion pattern for async::ReturnOp in a function registered in the
// shared `coros` map.
// NOTE(review): this listing dropped the template-argument brackets and
// several lines:
//   - the constructor (505-506);
//   - the rewriter parameter (510);
//   - the head of the match-failure return (514);
//   - the `loc` definition (517);
//   - the line just before the final branch (534);
//   - the private member declaration (540), presumably `coros`, which the
//     body dereferences.
503 class AsyncReturnOpLowering : public OpConversionPatternasync::ReturnOp {
504 public:
507
508 LogicalResult
509 matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
// Look up the coroutine bookkeeping for the enclosing function. Returns
// outside a registered coroutine produce a match failure with the message
// below.
511 auto func = op->template getParentOfTypefunc::FuncOp();
512 auto funcCoro = coros->find(func);
513 if (funcCoro == coros->end())
515 op, "operation is not inside the async coroutine function");
516
518 const CoroMachinery &coro = funcCoro->getSecond();
520
521
522
// Pair each converted return operand with the async value that carries it
// back to the caller. Two ops are created per pair; their names are stripped
// in this listing (presumably a store of the operand into the async value,
// then marking the value available).
523 for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
524 Value returnValue = std::get<0>(tuple);
525 Value asyncValue = std::get<1>(tuple);
526 rewriter.create(loc, returnValue, asyncValue);
527 rewriter.create(loc, asyncValue);
528 }
529
// If the coroutine returns a completion token, one more op is created on it
// (presumably marking the token available).
530 if (coro.asyncToken)
531
532 rewriter.create(loc, *coro.asyncToken);
533
// Jump to the coroutine cleanup block, which in turn branches to the suspend
// block.
535 rewriter.createcf::BranchOp(loc, coro.cleanup);
536 return success();
537 }
538
539 private:
541 };

542 }

543

544

545

546

547

548

549 namespace {

550 template <typename AwaitType, typename AwaitableType>

552 using AwaitAdaptor = typename AwaitType::Adaptor;

553

554 public:

556 bool shouldLowerBlockingWait)

558 shouldLowerBlockingWait(shouldLowerBlockingWait) {}

559

560 LogicalResult

561 matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,

563

564

565 if (!isa(op.getOperand().getType()))

566 return rewriter.notifyMatchFailure(op, "unsupported awaitable type");

567

568

569 auto func = op->template getParentOfTypefunc::FuncOp();

570 auto funcCoro = coros->find(func);

571 const bool isInCoroutine = funcCoro != coros->end();

572

574 Value operand = adaptor.getOperand();

575

577

578

579 if (!isInCoroutine && !shouldLowerBlockingWait)

580 return failure();

581

582

583

584 if (!isInCoroutine) {

586 builder.create(loc, operand);

587

588

589 Value isError = builder.create(i1, operand);

590 Value notError = builder.createarith::XOrIOp(

591 isError, builder.createarith::ConstantOp(

592 loc, i1, builder.getIntegerAttr(i1, 1)));

593

594 builder.createcf::AssertOp(notError,

595 "Awaited async operand is in error state");

596 }

597

598

599

600 if (isInCoroutine) {

601 CoroMachinery &coro = funcCoro->getSecond();

602 Block *suspended = op->getBlock();

603

606

607

608

609 auto coroSaveOp =

610 builder.create(CoroStateType::get(ctx), coro.coroHandle);

611 builder.create(operand, coro.coroHandle);

612

613

615

616

617 builder.setInsertionPointToEnd(suspended);

618 builder.create(coroSaveOp.getState(), coro.suspend, resume,

619 coro.cleanupForDestroy);

620

621

623

624

625 builder.setInsertionPointToStart(resume);

626 auto isError = builder.create(loc, i1, operand);

627 builder.createcf::CondBranchOp(isError,

630 continuation,

632

633

634

636 }

637

638

639 if (Value replaceWith = getReplacementValue(op, operand, rewriter))

640 rewriter.replaceOp(op, replaceWith);

641 else

643

644 return success();

645 }

646

647 virtual Value getReplacementValue(AwaitType op, Value operand,

650 }

651

652 private:

654 bool shouldLowerBlockingWait;

655 };

656

657

658 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {

659 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;

660

661 public:

662 using Base::Base;

663 };

664

665

666 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {

667 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;

668

669 public:

670 using Base::Base;

671

673 getReplacementValue(AwaitOp op, Value operand,

675

676 auto valueType = cast(operand.getType()).getValueType();

677 return rewriter.create(op->getLoc(), valueType, operand);

678 }

679 };

680

681

682 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {

683 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;

684

685 public:

686 using Base::Base;

687 };

688

689 }

690

691

692

693

694

696 public:

699

700 LogicalResult

703

704 auto func = op->template getParentOfTypefunc::FuncOp();

705 auto funcCoro = coros->find(func);

706 if (funcCoro == coros->end())

708 op, "operation is not inside the async coroutine function");

709

711 const CoroMachinery &coro = funcCoro->getSecond();

712

713

714

715 for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {

716 Value yieldValue = std::get<0>(tuple);

717 Value asyncValue = std::get<1>(tuple);

718 rewriter.create(loc, yieldValue, asyncValue);

719 rewriter.create(loc, asyncValue);

720 }

721

722 if (coro.asyncToken)

723

724 rewriter.create(loc, *coro.asyncToken);

725

727 rewriter.createcf::BranchOp(loc, coro.cleanup);

728

729 return success();

730 }

731

732 private:

734 };

735

736

737

738

739

741 public:

744

745 LogicalResult

748

749 auto func = op->template getParentOfTypefunc::FuncOp();

750 auto funcCoro = coros->find(func);

751 if (funcCoro == coros->end())

753 op, "operation is not inside the async coroutine function");

754

756 CoroMachinery &coro = funcCoro->getSecond();

757

760 rewriter.createcf::CondBranchOp(loc, adaptor.getArg(),

761 cont,

766

767 return success();

768 }

769

770 private:

772 };

773

774

// Pass entry point. It visits every async.execute in the module (presumably
// outlining each into a coroutine function), then runs a partial conversion
// that lowers the remaining high-level async ops.
// NOTE(review): lines 777, 781, 785, 800-801, 807, 818, 821, 830-831 and 845
// of this function are missing from this listing. The comments below
// describe only the visible code.
775 void AsyncToAsyncRuntimePass::runOnOperation() {
776 ModuleOp module = getOperation();
778
779
780
// Shared map from coroutine functions to their CoroMachinery. The same
// pointer is handed to the await patterns below.
782 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
783
// Visit every ExecuteOp. The callback body (line 785) is missing here; the
// debug message below reports coros->size() as the number of outlined
// functions.
784 module.walk([&](ExecuteOp execute) {
786 });
787
788 LLVM_DEBUG({
789 llvm::dbgs() << "Outlined " << coros->size()
790 << " functions built from async.execute operations\n";
791 });
792
793
// True when `op` is nested in a function registered in `coros`.
794 auto isInCoroutine = [&](Operation *op) -> bool {
795 auto parentFunc = op->getParentOfTypefunc::FuncOp();
796 return coros->contains(parentFunc);
797 };
798
799
802
803
804
805
806
808
809
810
// Group creation / membership ops are lowered unconditionally.
811 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
812
// Await lowerings. The trailing `true` is AwaitOpLoweringBase's
// shouldLowerBlockingWait flag: with it set, awaits outside coroutines are
// lowered instead of failing to match.
813 asyncPatterns
814 .add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
815 ctx, coros, true);
816
817
819
820
// The async and func dialects stay legal in general, but the listed async
// ops must all be converted away.
822 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
823 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
824 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
825
826
// An SCF op is legal only if walking it (including nested ops) is not
// interrupted. The interrupt condition (lines 830-831) is missing here; it
// starts from whether the nested op belongs to the async dialect.
827 runtimeTarget.addDynamicallyLegalDialectscf::SCFDialect([&](Operation *op) {
828 auto walkResult = op->walk([&](Operation *nested) {
829 bool isAsync = isaasync::AsyncDialect(nested->getDialect());
832 });
833 return !walkResult.wasInterrupted();
834 });
// Ops created by the await lowering outside coroutines (cf.assert,
// arith.xori, arith.constant) plus func constants and branches are legal.
835 runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
836 func::ConstantOp, cf::BranchOp, cf::CondBranchOp>();
837
838
// Registered after the addLegalOp above: cf.assert is legal only outside
// coroutine functions, so asserts inside coroutines must be converted.
839 runtimeTarget.addDynamicallyLegalOpcf::AssertOp(
840 [&](cf::AssertOp op) -> bool {
841 auto func = op->getParentOfTypefunc::FuncOp();
842 return !coros->contains(func);
843 });
844
// Tail of the conversion call (its head, line 845, is missing); a failed
// conversion aborts the pass.
846 std::move(asyncPatterns)))) {
847 signalPassFailure();
848 return;
849 }
850 }

851

852

855

856

858 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();

860

861 patterns.add(ctx);

862 patterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);

863

864 patterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(

865 ctx, coros, false);

867

870 auto exec = op->getParentOfType();

871 auto func = op->getParentOfTypefunc::FuncOp();

872 return exec || !coros->contains(func);

873 });

874 }

875

// Pass entry point: lowers async::FuncOp, async::CallOp and async::ReturnOp
// through a partial conversion of the module.
// NOTE(review): lines 880-882, 885 and 894 are missing from this listing.
876 void AsyncFuncToAsyncRuntimePass::runOnOperation() {
877 ModuleOp module = getOperation();
878
879
883
884
// Tail of a call that receives the conversion target; presumably
// populateAsyncFuncToAsyncRuntimeConversionPatterns(asyncPatterns,
// runtimeTarget) — confirm against the full source.
886 runtimeTarget);
887
888 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
// The function-level async ops must be converted away.
889 runtimeTarget.addIllegalOp<async::FuncOp, async::CallOp, async::ReturnOp>();
890
// Constants and branch ops are declared legal.
891 runtimeTarget.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
892 cf::BranchOp, cf::CondBranchOp>();
893
// Tail of the conversion call (its head, line 894, is missing); a failed
// conversion aborts the pass.
895 std::move(asyncPatterns)))) {
896 signalPassFailure();
897 return;
898 }
899 }

static Block * setupSetErrorBlock(CoroMachinery &coro)

std::shared_ptr< llvm::DenseMap< func::FuncOp, CoroMachinery > > FuncCoroMapPtr

static constexpr const char kAsyncFnPrefix[]

static std::pair< func::FuncOp, CoroMachinery > outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute)

Outline the body region attached to the async.execute op into a standalone function.

static CoroMachinery setupCoroMachinery(func::FuncOp func)

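Utility to partially update the regular function CFG to the coroutine CFG compatible with LLVM coroutines switched-resume lowering using `async.runtime.*` and `async.coro.*` operations. Adds a new entry block that branches into the preexisting entry block.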

AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)

LogicalResult matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override

Methods that operate on the SourceOp type.

LogicalResult matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override

Methods that operate on the SourceOp type.

YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)

Block represents an ordered list of Operations.

OpListType::iterator iterator

Block * splitBlock(iterator splitBefore)

Split the block into two blocks before the specified operation or iterator.

OpListType & getOperations()

This class implements a pattern rewriter for use with ConversionPatterns.

void replaceOp(Operation *op, ValueRange newValues) override

Replace the given operation with the new values.

void eraseOp(Operation *op) override

PatternRewriter hook for erasing a dead operation.

This class describes a specific conversion target.

void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)

Register the given operation as dynamically legal and set the dynamic legalization callback to the one provided.

This is a utility class for mapping one set of IR entities to another.

void map(Value from, Value to)

Inserts a new mapping for 'from' to 'to'.

ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.

static ImplicitLocOpBuilder atBlockBegin(Location loc, Block *block, Listener *listener=nullptr)

Create a builder and set the insertion point to before the first operation in the block but still inside the block.

OpTy create(Args &&...args)

Create an operation of specific op type at the current insertion point and location.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.

MLIRContext is the top-level object for a collection of MLIR operations.

void setInsertionPointToStart(Block *block)

Sets the insertion point to the start of the specified block.

void setInsertionPointToEnd(Block *block)

Sets the insertion point to the end of the specified block.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

void setInsertionPointAfter(Operation *op)

Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.

OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.

OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)

Operation is the basic unit of execution within MLIR.

Dialect * getDialect()

Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.

Block * splitBlock(Block *block, Block::iterator before)

Split the operations starting at "before" (inclusive) out of the given block into a new block, and return it.

void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)

Move the blocks that belong to "region" before the given position in another region "parent".

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.

static Visibility getSymbolVisibility(Operation *symbol)

Returns the visibility of the given symbol operation.

static StringRef getSymbolAttrName()

Return the name of the attribute used for symbol names.

static void setSymbolVisibility(Operation *symbol, Visibility vis)

Sets the visibility of the given symbol operation.

@ Private

The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the current symbol table.

StringAttr insert(Operation *symbol, Block::iterator insertPt={})

Insert a new symbol into the table, and rename it as necessary to avoid collisions.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

Type getType() const

Return the type of this value.

A utility result that is used to signal how to proceed with an ongoing walk:

static WalkResult interrupt()

void cloneConstantsIntoTheRegion(Region &region)

Clone ConstantLike operations that are defined above the given region and have users in the region in...

Include the generated interface declarations.

const FrozenRewritePatternSet & patterns

void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)

Fill `values` with a list of values defined at the ancestors of the `limit` region and used within `region` or its descendants.

void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)

Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...

void populateAsyncFuncToAsyncRuntimeConversionPatterns(RewritePatternSet &patterns, ConversionTarget &target)

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed; this helps unify some of the attribute construction methods.

LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())

Below we define several entry points for operation conversion.