MLIR: lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
14 #include
15
17
30 #include "llvm/ADT/SetVector.h"
31 #include "llvm/Support/Debug.h"
32 #include
33
34 namespace mlir {
35 #define GEN_PASS_DEF_ASYNCTOASYNCRUNTIMEPASS
36 #define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIMEPASS
37 #include "mlir/Dialect/Async/Passes.h.inc"
38 }
39
40 using namespace mlir;
42
43 #define DEBUG_TYPE "async-to-async-runtime"
44
// Symbol name requested for each function created while outlining an
// `async.execute` op. The function is then inserted into the symbol table,
// which is what keeps the final names distinct.
static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
46
47 namespace {
48
49 class AsyncToAsyncRuntimePass
50 : public impl::AsyncToAsyncRuntimePassBase {
51 public:
52 AsyncToAsyncRuntimePass() = default;
53 void runOnOperation() override;
54 };
55
56 }
57
58 namespace {
59
60 class AsyncFuncToAsyncRuntimePass
61 : public impl::AsyncFuncToAsyncRuntimePassBase<
62 AsyncFuncToAsyncRuntimePass> {
63 public:
64 AsyncFuncToAsyncRuntimePass() = default;
65 void runOnOperation() override;
66 };
67
68 }
69
70
71
72
73
74
75 namespace {
76 struct CoroMachinery {
77 func::FuncOp func;
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92 std::optional asyncToken;
94
95 Value coroHandle;
96 Block *entry;
97 std::optional<Block *> setError;
98 Block *cleanup;
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120 Block *cleanupForDestroy;
121 Block *suspend;
122 };
123 }
124
126 std::shared_ptr<llvm::DenseMap<func::FuncOp, CoroMachinery>>;
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
173 assert(!func.getBlocks().empty() && "Function must have an entry block");
174
176 Block *entryBlock = &func.getBlocks().front();
177 Block *originalEntryBlock =
180
181
182
183
184
185
186
187 bool isStateful = isa(func.getResultTypes().front());
188
189 std::optional retToken;
190 if (isStateful)
191 retToken.emplace(builder.create(TokenType::get(ctx)));
192
195 isStateful ? func.getResultTypes().drop_front() : func.getResultTypes();
196 for (auto resType : resValueTypes)
197 retValues.emplace_back(
198 builder.create(resType).getResult());
199
200
201
202
203 auto coroIdOp = builder.create(CoroIdType::get(ctx));
204 auto coroHdlOp =
206 builder.createcf::BranchOp(originalEntryBlock);
207
208 Block *cleanupBlock = func.addBlock();
209 Block *cleanupBlockForDestroy = func.addBlock();
210 Block *suspendBlock = func.addBlock();
211
212
213
214
215 auto buildCleanupBlock = [&](Block *cb) {
216 builder.setInsertionPointToStart(cb);
217 builder.create(coroIdOp.getId(), coroHdlOp.getHandle());
218
219
220 builder.createcf::BranchOp(suspendBlock);
221 };
222 buildCleanupBlock(cleanupBlock);
223 buildCleanupBlock(cleanupBlockForDestroy);
224
225
226
227
228
229 builder.setInsertionPointToStart(suspendBlock);
230
231
232 builder.create(coroHdlOp.getHandle());
233
234
235
237 if (retToken)
238 ret.push_back(*retToken);
239 llvm::append_range(ret, retValues);
240 builder.createfunc::ReturnOp(ret);
241
242
243
244
245
246
247 func->setAttr("passthrough", builder.getArrayAttr(
249
250 CoroMachinery machinery;
251 machinery.func = func;
252 machinery.asyncToken = retToken;
253 machinery.returnValues = retValues;
254 machinery.coroHandle = coroHdlOp.getHandle();
255 machinery.entry = entryBlock;
256 machinery.setError = std::nullopt;
257 machinery.cleanup = cleanupBlock;
258 machinery.cleanupForDestroy = cleanupBlockForDestroy;
259 machinery.suspend = suspendBlock;
260 return machinery;
261 }
262
263
264
266 if (coro.setError)
267 return *coro.setError;
268
269 coro.setError = coro.func.addBlock();
270 (*coro.setError)->moveBefore(coro.cleanup);
271
272 auto builder =
274
275
276 if (coro.asyncToken)
277 builder.create(*coro.asyncToken);
278
279 for (Value retValue : coro.returnValues)
280 builder.create(retValue);
281
282
283 builder.createcf::BranchOp(coro.cleanup);
284
285 return *coro.setError;
286 }
287
288
289
290
291
292
293
294
295
296 static std::pair<func::FuncOp, CoroMachinery>
298 ModuleOp module = execute->getParentOfType();
299
301 Location loc = execute.getLoc();
302
303
304
306
307
309 execute.getDependencies());
310 functionInputs.insert_range(execute.getBodyOperands());
312
313
314 auto typesRange = llvm::map_range(
315 functionInputs, [](Value value) { return value.getType(); });
317 auto outputTypes = execute.getResultTypes();
318
321
322
323
324 func::FuncOp func =
325 func::FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
326 symbolTable.insert(func);
327
330
331
332 {
333 size_t numDependencies = execute.getDependencies().size();
334 size_t numOperands = execute.getBodyOperands().size();
335
336
337 for (size_t i = 0; i < numDependencies; ++i)
338 builder.create(func.getArgument(i));
339
340
342 for (size_t i = 0; i < numOperands; ++i) {
343 Value operand = func.getArgument(numDependencies + i);
344 unwrappedOperands[i] = builder.create(loc, operand).getResult();
345 }
346
347
348
350 valueMapping.map(functionInputs, func.getArguments());
351 valueMapping.map(execute.getBodyRegion().getArguments(), unwrappedOperands);
352
353
354
355 for (Operation &op : execute.getBodyRegion().getOps())
356 builder.clone(op, valueMapping);
357 }
358
359
361
362
363
364
365 {
366 cf::BranchOp branch = castcf::BranchOp(coro.entry->getTerminator());
367 builder.setInsertionPointToEnd(coro.entry);
368
369
370 auto coroSaveOp =
371 builder.create(CoroStateType::get(ctx), coro.coroHandle);
372
373
374
375 builder.create(coro.coroHandle);
376
377
378 builder.create(coroSaveOp.getState(), coro.suspend,
379 branch.getDest(), coro.cleanupForDestroy);
380
381 branch.erase();
382 }
383
384
385 {
387 auto callOutlinedFunc = callBuilder.createfunc::CallOp(
388 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
389 execute.replaceAllUsesWith(callOutlinedFunc.getResults());
390 execute.erase();
391 }
392
393 return {func, coro};
394 }
395
396
397
398
399
400 namespace {
401 class CreateGroupOpLowering : public OpConversionPattern {
402 public:
404
405 LogicalResult
406 matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
409 op, GroupType::get(op->getContext()), adaptor.getOperands());
410 return success();
411 }
412 };
413 }
414
415
416
417
418
419 namespace {
421 public:
423
424 LogicalResult
425 matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
428 op, rewriter.getIndexType(), adaptor.getOperands());
429 return success();
430 }
431 };
432 }
433
434
435
436
437
438
439 namespace {
440
441
442
443
444
446 public:
449
450 LogicalResult
451 matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
454
455 auto newFuncOp =
456 rewriter.createfunc::FuncOp(loc, op.getName(), op.getFunctionType());
457
460
461 for (const auto &namedAttr : op->getAttrs()) {
463 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
464 }
465
467 newFuncOp.end());
468
470 (*coros)[newFuncOp] = coro;
471
472
474 return success();
475 }
476
477 private:
479 };
480
481
482
483
484
486 public:
489
490 LogicalResult
491 matchAndRewrite(async::CallOp op, OpAdaptor adaptor,
494 op, op.getCallee(), op.getResultTypes(), op.getOperands());
495 return success();
496 }
497 };
498
499
500
501
502
503 class AsyncReturnOpLowering : public OpConversionPatternasync::ReturnOp {
504 public:
507
508 LogicalResult
509 matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
511 auto func = op->template getParentOfTypefunc::FuncOp();
512 auto funcCoro = coros->find(func);
513 if (funcCoro == coros->end())
515 op, "operation is not inside the async coroutine function");
516
518 const CoroMachinery &coro = funcCoro->getSecond();
520
521
522
523 for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
524 Value returnValue = std::get<0>(tuple);
525 Value asyncValue = std::get<1>(tuple);
526 rewriter.create(loc, returnValue, asyncValue);
527 rewriter.create(loc, asyncValue);
528 }
529
530 if (coro.asyncToken)
531
532 rewriter.create(loc, *coro.asyncToken);
533
535 rewriter.createcf::BranchOp(loc, coro.cleanup);
536 return success();
537 }
538
539 private:
541 };
542 }
543
544
545
546
547
548
549 namespace {
550 template <typename AwaitType, typename AwaitableType>
552 using AwaitAdaptor = typename AwaitType::Adaptor;
553
554 public:
556 bool shouldLowerBlockingWait)
558 shouldLowerBlockingWait(shouldLowerBlockingWait) {}
559
560 LogicalResult
561 matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
563
564
565 if (!isa(op.getOperand().getType()))
566 return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
567
568
569 auto func = op->template getParentOfTypefunc::FuncOp();
570 auto funcCoro = coros->find(func);
571 const bool isInCoroutine = funcCoro != coros->end();
572
574 Value operand = adaptor.getOperand();
575
577
578
579 if (!isInCoroutine && !shouldLowerBlockingWait)
580 return failure();
581
582
583
584 if (!isInCoroutine) {
586 builder.create(loc, operand);
587
588
589 Value isError = builder.create(i1, operand);
590 Value notError = builder.createarith::XOrIOp(
591 isError, builder.createarith::ConstantOp(
592 loc, i1, builder.getIntegerAttr(i1, 1)));
593
594 builder.createcf::AssertOp(notError,
595 "Awaited async operand is in error state");
596 }
597
598
599
600 if (isInCoroutine) {
601 CoroMachinery &coro = funcCoro->getSecond();
602 Block *suspended = op->getBlock();
603
606
607
608
609 auto coroSaveOp =
610 builder.create(CoroStateType::get(ctx), coro.coroHandle);
611 builder.create(operand, coro.coroHandle);
612
613
615
616
617 builder.setInsertionPointToEnd(suspended);
618 builder.create(coroSaveOp.getState(), coro.suspend, resume,
619 coro.cleanupForDestroy);
620
621
623
624
625 builder.setInsertionPointToStart(resume);
626 auto isError = builder.create(loc, i1, operand);
627 builder.createcf::CondBranchOp(isError,
630 continuation,
632
633
634
636 }
637
638
639 if (Value replaceWith = getReplacementValue(op, operand, rewriter))
640 rewriter.replaceOp(op, replaceWith);
641 else
643
644 return success();
645 }
646
647 virtual Value getReplacementValue(AwaitType op, Value operand,
650 }
651
652 private:
654 bool shouldLowerBlockingWait;
655 };
656
657
658 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
659 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
660
661 public:
662 using Base::Base;
663 };
664
665
666 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
667 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
668
669 public:
670 using Base::Base;
671
673 getReplacementValue(AwaitOp op, Value operand,
675
676 auto valueType = cast(operand.getType()).getValueType();
677 return rewriter.create(op->getLoc(), valueType, operand);
678 }
679 };
680
681
682 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
683 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
684
685 public:
686 using Base::Base;
687 };
688
689 }
690
691
692
693
694
696 public:
699
700 LogicalResult
703
704 auto func = op->template getParentOfTypefunc::FuncOp();
705 auto funcCoro = coros->find(func);
706 if (funcCoro == coros->end())
708 op, "operation is not inside the async coroutine function");
709
711 const CoroMachinery &coro = funcCoro->getSecond();
712
713
714
715 for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
716 Value yieldValue = std::get<0>(tuple);
717 Value asyncValue = std::get<1>(tuple);
718 rewriter.create(loc, yieldValue, asyncValue);
719 rewriter.create(loc, asyncValue);
720 }
721
722 if (coro.asyncToken)
723
724 rewriter.create(loc, *coro.asyncToken);
725
727 rewriter.createcf::BranchOp(loc, coro.cleanup);
728
729 return success();
730 }
731
732 private:
734 };
735
736
737
738
739
741 public:
744
745 LogicalResult
748
749 auto func = op->template getParentOfTypefunc::FuncOp();
750 auto funcCoro = coros->find(func);
751 if (funcCoro == coros->end())
753 op, "operation is not inside the async coroutine function");
754
756 CoroMachinery &coro = funcCoro->getSecond();
757
760 rewriter.createcf::CondBranchOp(loc, adaptor.getArg(),
761 cont,
766
767 return success();
768 }
769
770 private:
772 };
773
774
775 void AsyncToAsyncRuntimePass::runOnOperation() {
776 ModuleOp module = getOperation();
778
779
780
782 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
783
784 module.walk([&](ExecuteOp execute) {
786 });
787
788 LLVM_DEBUG({
789 llvm::dbgs() << "Outlined " << coros->size()
790 << " functions built from async.execute operations\n";
791 });
792
793
794 auto isInCoroutine = [&](Operation *op) -> bool {
795 auto parentFunc = op->getParentOfTypefunc::FuncOp();
796 return coros->contains(parentFunc);
797 };
798
799
802
803
804
805
806
808
809
810
811 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
812
813 asyncPatterns
814 .add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
815 ctx, coros, true);
816
817
819
820
822 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
823 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
824 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
825
826
827 runtimeTarget.addDynamicallyLegalDialectscf::SCFDialect([&](Operation *op) {
828 auto walkResult = op->walk([&](Operation *nested) {
829 bool isAsync = isaasync::AsyncDialect(nested->getDialect());
832 });
833 return !walkResult.wasInterrupted();
834 });
835 runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
836 func::ConstantOp, cf::BranchOp, cf::CondBranchOp>();
837
838
839 runtimeTarget.addDynamicallyLegalOpcf::AssertOp(
840 [&](cf::AssertOp op) -> bool {
841 auto func = op->getParentOfTypefunc::FuncOp();
842 return !coros->contains(func);
843 });
844
846 std::move(asyncPatterns)))) {
847 signalPassFailure();
848 return;
849 }
850 }
851
852
855
856
858 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
860
861 patterns.add(ctx);
862 patterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
863
864 patterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
865 ctx, coros, false);
867
870 auto exec = op->getParentOfType();
871 auto func = op->getParentOfTypefunc::FuncOp();
872 return exec || !coros->contains(func);
873 });
874 }
875
876 void AsyncFuncToAsyncRuntimePass::runOnOperation() {
877 ModuleOp module = getOperation();
878
879
883
884
886 runtimeTarget);
887
888 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
889 runtimeTarget.addIllegalOp<async::FuncOp, async::CallOp, async::ReturnOp>();
890
891 runtimeTarget.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
892 cf::BranchOp, cf::CondBranchOp>();
893
895 std::move(asyncPatterns)))) {
896 signalPassFailure();
897 return;
898 }
899 }
static Block * setupSetErrorBlock(CoroMachinery &coro)
std::shared_ptr< llvm::DenseMap< func::FuncOp, CoroMachinery > > FuncCoroMapPtr
static constexpr const char kAsyncFnPrefix[]
static std::pair< func::FuncOp, CoroMachinery > outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute)
Outline the body region attached to the async.execute op into a standalone function.
static CoroMachinery setupCoroMachinery(func::FuncOp func)
Utility to partially update the regular function CFG to the coroutine CFG compatible with LLVM corout...
AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
LogicalResult matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
LogicalResult matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
Block represents an ordered list of Operations.
OpListType::iterator iterator
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
OpListType & getOperations()
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
static ImplicitLocOpBuilder atBlockBegin(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult interrupt()
void cloneConstantsIntoTheRegion(Region &region)
Clone ConstantLike operations that are defined above the given region and have users in the region in...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
void populateAsyncFuncToAsyncRuntimeConversionPatterns(RewritePatternSet &patterns, ConversionTarget &target)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.