MLIR: lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
18#include
19
20namespace mlir {
21
24 for (Block &b : funcOp.getBody())
25 if (auto returnOp = dyn_castfunc::ReturnOp(b.getTerminator()))
26 result.push_back(returnOp);
28}
29
31namespace func_ext {
32
36 auto createdAliasingResults =
40 (void)createdEquiv;
41 (void)createdAliasingResults;
42 (void)createdRead;
43 (void)createdWritten;
44#ifndef NDEBUG
45 assert(createdEquiv.second && "equivalence info exists already");
46 assert(createdAliasingResults.second && "aliasing info exists already");
47 assert(createdRead.second && "bbarg access info exists already");
48 assert(createdWritten.second && "bbarg access info exists already");
49#endif
50}
51
52
53
56 TensorLikeType type) {
57 if (auto tensorType = dyn_cast(type)) {
58 return *options.defaultMemorySpaceFn(tensorType);
59 }
60 return nullptr;
61}
62
63
64
65
66static BufferLikeType
69 auto type =
70 dyn_cast(funcOp.getFunctionType().getInput(index));
71 assert(type && "expected TensorLikeType");
72
73
74 if (auto tensorType = dyn_cast(type)) {
75 BufferLikeType memrefType = options.functionArgTypeConverterFn(
76 type, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
77
78 auto layoutAttr = funcOp.getArgAttrOfType(
79 index, BufferizationDialect::kBufferLayoutAttrName);
80 if (!layoutAttr)
81 return memrefType;
82
83 auto rankedMemrefType = dyn_cast(memrefType);
84 assert(rankedMemrefType &&
85 "buffer layout not supported on unranked tensors");
86 return cast(MemRefType::get(
87 rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
88 layoutAttr, rankedMemrefType.getMemorySpace()));
89 }
90
91 return options.functionArgTypeConverterFn(type, nullptr, funcOp,
93}
94
95
98 return dyn_cast_or_null(callOp.resolveCallableInTable(&symbolTables));
99}
100
101
104 auto &oneShotAnalysisState = static_cast<const OneShotAnalysisState &>(state);
105
106 if (auto *funcAnalysisState =
108
109 return getCalledFunction(callOp, funcAnalysisState->symbolTables);
110 }
111
114}
115
116
117static const FuncAnalysisState &
119 assert(isa(state) && "expected OneShotAnalysisState");
121 .getExtension();
122 assert(result && "FuncAnalysisState does not exist");
124}
125
126
128 FuncOp funcOp) {
129 if (!isa(state))
132 .getExtension();
133 if (!funcState)
135 const auto &analyzedFuncOps = funcState->analyzedFuncOps;
136 auto it = analyzedFuncOps.find(funcOp);
137 if (it == analyzedFuncOps.end())
139 return it->second;
140}
141
142
143
144static std::optional<int64_t>
149
150 return std::nullopt;
151
152 auto retValIt = funcOpIt->getSecond().find(returnValIdx);
153 if (retValIt == funcOpIt->getSecond().end())
154
155 return std::nullopt;
156
157 return retValIt->getSecond();
158}
159
161 : public BufferizableOpInterface::ExternalModel<CallOpInterface,
162 func::CallOp> {
165 func::CallOp callOp = castfunc::CallOp(op);
167 assert(funcOp && "expected CallOp to a FuncOp");
168
170
171 return true;
172
174 return funcState.readBbArgs.lookup(funcOp).contains(
176 }
177
180 func::CallOp callOp = castfunc::CallOp(op);
182 assert(funcOp && "expected CallOp to a FuncOp");
183
185
186 return true;
187
189 return funcState.writtenBbArgs.lookup(funcOp).contains(
191 }
192
195 func::CallOp callOp = castfunc::CallOp(op);
197 assert(funcOp && "expected CallOp to a FuncOp");
199
200 return detail::unknownGetAliasingValues(opOperand);
201
202
204 auto aliasingReturnVals =
207
208
209 std::optional<int64_t> equivalent = {};
210 if (aliasingReturnVals.size() == 1) {
212 aliasingReturnVals.front());
213 assert((!equivalent.has_value() ||
215 "inconsistent analysis state");
216 }
217 AliasingValueList result;
218 for (int64_t resultIdx : aliasingReturnVals)
219 result.addAlias({callOp->getOpResult(resultIdx),
220 equivalent.has_value() ? BufferRelation::Equivalent
221 : BufferRelation::Unknown,
222 equivalent.has_value()});
224 }
225
226 FailureOr
228 const BufferizationState &state,
230 auto callOp = castfunc::CallOp(op);
231
232
234
236 assert(funcOp && "expected CallOp to a FuncOp");
237
238
239
240 FunctionType funcType = funcOp.getFunctionType();
241 Type resultType =
242 funcType.getResult(cast(value).getResultNumber());
243 if (auto bufferizedType = dyn_cast(resultType))
244 return bufferizedType;
245
246
247 auto tensorType = cast(resultType);
248 return cast(options.functionArgTypeConverterFn(
251 }
252
253
254
257 BufferizationState &state) const {
258 func::CallOp callOp = castfunc::CallOp(op);
259
260
262 for (Value result : callOp.getResults()) {
264 if (!isa(returnType)) {
265
266 resultTypes.push_back(returnType);
267 continue;
268 }
269
270
271 FailureOr resultType =
272 bufferization::getBufferType(result, options, state);
273 if (failed(resultType))
274 return failure();
275 resultTypes.push_back(*resultType);
276 }
277
278
279
281
282 FuncOp funcOp = getCalledFunction(callOp, state.getSymbolTables());
283 assert(funcOp && "expected CallOp to a FuncOp");
284 FunctionType funcType = funcOp.getFunctionType();
285
286 for (OpOperand &opOperand : callOp->getOpOperands()) {
287
288 if (!isa(opOperand.get().getType())) {
289 newOperands.push_back(opOperand.get());
290 continue;
291 }
292
293
294 FailureOr maybeBuffer =
295 getBuffer(rewriter, opOperand.get(), options, state);
296 if (failed(maybeBuffer))
297 return failure();
298 Value buffer = *maybeBuffer;
299
300
301 auto bufferType = funcType.getInput(opOperand.getOperandNumber());
302 if (!isa(bufferType)) {
303
304
305
306 FailureOr maybeBufferType =
307 bufferization::getBufferType(
308 funcOp.getArgument(opOperand.getOperandNumber()), options,
309 state);
310 if (failed(maybeBufferType))
311 return failure();
312 bufferType = *maybeBufferType;
313 }
314
315
316
317
318
319
320
321 if (buffer.getType() != bufferType) {
322 auto memrefDstType = dyn_cast(bufferType);
323 assert(memrefDstType &&
324 "buffer layout not supported on unranked tensors");
326 rewriter, buffer, memrefDstType, options);
328 return failure();
330 }
331 newOperands.push_back(buffer);
332 }
333
334
336 func::CallOp::create(rewriter, callOp.getLoc(), funcOp.getSymName(),
337 resultTypes, newOperands);
338 newCallOp->setAttrs(callOp->getAttrs());
339
340
341 replaceOpWithBufferizedValues(rewriter, callOp, newCallOp->getResults());
342
344 }
345};
346
348 : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
349 func::ReturnOp> {
354
359
364
367 BufferizationState &state) const {
368#ifndef NDEBUG
369 auto returnOp = castfunc::ReturnOp(op);
370 assert(isa(returnOp->getParentOp()) &&
371 "only support FuncOp parent for ReturnOp");
372#endif
373
374
376 }
377};
378
381 FuncOpInterface, FuncOp> {
382
384
386 auto isaTensor = llvm::IsaPred;
387
388
389 auto funcOp = cast(op);
390 bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
391 bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
392 if (hasTensorArg || hasTensorResult)
393 return true;
394
395
396
397
398
399
400 for (Block &block : funcOp.getBody())
401 if (any_of(block.getArgumentTypes(), isaTensor))
402 return true;
403
404 return false;
405 }
406
407 AliasingOpOperandList
412
413 FailureOr
415 const BufferizationState &state,
417 auto funcOp = cast(op);
418 auto bbArg = cast(value);
419
420
421 if (bbArg.getOwner() == &funcOp.getBody().front())
424
427 }
428
429
430
431
432
433
434
435
438 BufferizationState &state) const {
439 auto funcOp = cast(op);
440 FunctionType funcType = funcOp.getFunctionType();
441
442
444 for (const auto &it : llvm::enumerate(funcType.getInputs())) {
445 Type argType = it.value();
446 if (isa(argType)) {
447 argTypes.push_back(
449 continue;
450 }
451 argTypes.push_back(argType);
452 }
453
454
456 for (Type resultType : funcType.getResults()) {
457 if (auto tensorType = dyn_cast(resultType)) {
458 BufferLikeType resultType = options.functionArgTypeConverterFn(
461 retTypes.push_back(resultType);
462 continue;
463 }
464 retTypes.push_back(resultType);
465 }
466
467
468 auto newFuncType = FunctionType::get(op->getContext(), argTypes, retTypes);
469
470
471 if (funcOp.isExternal()) {
472 funcOp.setType(newFuncType);
474 }
475
476
477 for (Block &block : funcOp.getBody())
480 return failure();
481
482
483 for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
484 assert(returnOp->getNumOperands() == retTypes.size() &&
485 "incorrect number of return values");
487 for (auto [returnVal, bufferizedType] :
488 llvm::zip_equal(returnOp->getOperands(), retTypes)) {
489 auto tensorType = dyn_cast(returnVal.getType());
491
492
493 if (!tensorType) {
494 returnValues.push_back(returnVal);
495 continue;
496 }
497
498
499
500 Value toBufferOp = bufferization::ToBufferOp::create(
501 rewriter, returnOp.getLoc(), bufferizedType, returnVal);
502 returnValues.push_back(toBufferOp);
503 }
504
505 returnOp.getOperandsMutable().assign(returnValues);
506 }
507
508
509 funcOp.setType(newFuncType);
511 }
512
513
516 auto funcOp = cast(op);
517 BlockArgument bbArg = dyn_cast(value);
518 assert(bbArg && "expected BlockArgument");
519
520
521
522 if (bbArg.getOwner() != &funcOp.getBody().front())
523 return true;
524
525
526
528 bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
529 return writable.getValue();
530
531
532 return true;
533 }
534};
535
536}
537}
538}
539
543 func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
544 func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
545 func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
546 });
547}
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
... if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming copies and outgoing copies, respectively, should be inserted. The output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`, ...
static bool isaTensor(Type t)
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block * getOwner() const
Returns the block that owns this argument.
Block represents an ordered list of Operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
result_range getResults()
MLIRContext * getContext()
Return the context this operation is associated with.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
State for analysis-enabled bufferization.
static BufferLikeType getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, const BufferizationOptions &options)
Return the index-th bufferized function argument type.
Definition FuncBufferizableOpInterfaceImpl.cpp:67
FuncOpAnalysisState
The state of analysis of a FuncOp.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
Definition FuncBufferizableOpInterfaceImpl.cpp:541
static FuncOp getCalledFunction(CallOpInterface callOp, SymbolTableCollection &symbolTables)
Return the FuncOp called by callOp.
Definition FuncBufferizableOpInterfaceImpl.cpp:96
static std::optional< int64_t > getEquivalentFuncArgIdx(FuncOp funcOp, const FuncAnalysisState &state, int64_t returnValIdx)
Return the index of the bbArg in the given FuncOp that is equivalent to the specified return value (if any).
Definition FuncBufferizableOpInterfaceImpl.cpp:145
static mlir::Attribute getDefaultMemorySpace(const BufferizationOptions &options, TensorLikeType type)
Definition FuncBufferizableOpInterfaceImpl.cpp:55
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state, FuncOp funcOp)
Return the state (phase) of analysis of the FuncOp.
Definition FuncBufferizableOpInterfaceImpl.cpp:127
static const FuncAnalysisState & getFuncAnalysisState(const AnalysisState &state)
Get FuncAnalysisState.
Definition FuncBufferizableOpInterfaceImpl.cpp:118
FailureOr< Value > castOrReallocMemRefValue(OpBuilder &b, Value value, MemRefType type, const BufferizationOptions &options)
Try to cast the given ranked MemRef-typed value to the given ranked MemRef type.
SmallVector< func::ReturnOp > getReturnOps(func::FuncOp funcOp)
Helper function that returns all func.return ops in the given function.
Definition FuncBufferizableOpInterfaceImpl.cpp:22
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state)
Bufferize the signature of block and its callers (i.e., ops that have the given block as a successor)...
Include the generated interface declarations.
A template that provides a default implementation of getAliasingOpOperands for ops that support unstr...
AliasingOpOperandList getAliasingBranchOpOperands(Operation *op, BlockArgument bbArg, const AnalysisState &state) const
FailureOr< BufferLikeType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, const BufferizationState &state, SmallVector< Value > &invocationStack) const
FailureOr< BufferLikeType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, const BufferizationState &state, SmallVector< Value > &invocationStack) const
Definition FuncBufferizableOpInterfaceImpl.cpp:227
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
Definition FuncBufferizableOpInterfaceImpl.cpp:193
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state) const
All function arguments are writable.
Definition FuncBufferizableOpInterfaceImpl.cpp:255
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
Definition FuncBufferizableOpInterfaceImpl.cpp:178
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
Definition FuncBufferizableOpInterfaceImpl.cpp:163
Extra analysis state that is required for bufferization of function boundaries.
DenseMap< FuncOp, IndexMapping > equivalentFuncArgs
A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg indices.
DenseMap< int64_t, SmallVector< int64_t > > IndexToIndexListMapping
A mapping of indices to a list of indices.
DenseMap< FuncOp, IndexToIndexListMapping > aliasingReturnVals
A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
DenseMap< FuncOp, BbArgIndexSet > readBbArgs
A set of all read BlockArguments of FuncOps.
DenseSet< int64_t > BbArgIndexSet
A set of block argument indices.
DenseMap< FuncOp, BbArgIndexSet > writtenBbArgs
A set of all written-to BlockArguments of FuncOps.
DenseMap< FuncOp, FuncOpAnalysisState > analyzedFuncOps
Keep track of which FuncOps are fully analyzed or currently being analyzed.
void startFunctionAnalysis(FuncOp funcOp)
This function is called right before analyzing the given FuncOp.
Definition FuncBufferizableOpInterfaceImpl.cpp:33
DenseMap< int64_t, int64_t > IndexMapping
A mapping of indices to indices.
FailureOr< BufferLikeType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, const BufferizationState &state, SmallVector< Value > &invocationStack) const
Definition FuncBufferizableOpInterfaceImpl.cpp:414
AliasingOpOperandList getAliasingOpOperands(Operation *op, Value value, const AnalysisState &state) const
Definition FuncBufferizableOpInterfaceImpl.cpp:408
bool isWritable(Operation *op, Value value, const AnalysisState &state) const
Return true if the given function argument is writable.
Definition FuncBufferizableOpInterfaceImpl.cpp:514
static bool supportsUnstructuredControlFlow()
Definition FuncBufferizableOpInterfaceImpl.cpp:383
bool hasTensorSemantics(Operation *op) const
Definition FuncBufferizableOpInterfaceImpl.cpp:385
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state) const
Rewrite function bbArgs and return values into buffer form.
Definition FuncBufferizableOpInterfaceImpl.cpp:436
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
Definition FuncBufferizableOpInterfaceImpl.cpp:360
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
Definition FuncBufferizableOpInterfaceImpl.cpp:355
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
Definition FuncBufferizableOpInterfaceImpl.cpp:350
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state) const
Definition FuncBufferizableOpInterfaceImpl.cpp:365