MLIR: lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
10
13 #include
14
15 using namespace mlir;
17
18 namespace {
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84 struct DecomposeLinalgOp : public OpRewritePattern {
86
87 LogicalResult matchAndRewrite(GenericOp genericOp,
89
90 private:
91
92
93 GenericOp createPeeledGenericOp(GenericOp genericOp,
95
96
97
98 GenericOp createResidualGenericOp(GenericOp genericOp,
99 GenericOp peeledGenericOp,
101 };
102 }
103
104
106 GenericOp op) {
110 auto allShapesSizes =
111 cast(op.getOperation()).createFlatListOfOperandDims(b, loc);
112 AffineMap map = op.getShapesToLoopsMap();
115 allShapesSizes);
116 }
117
118
123 for (const auto &position :
125 return cast(expr).getPosition();
126 })))
127 permutedValues[position.value()] = values[position.index()];
128 return permutedValues;
129 }
130
131
134 "expected scalar type while computing zero value");
135 if (isa(elementType))
136 return b.createarith::ConstantIntOp(loc, 0, elementType);
137 if (elementType.isIndex())
138 return b.createarith::ConstantIndexOp(loc, 0);
139
140 auto floatType = cast(elementType);
141 return b.createarith::ConstantFloatOp(
142 loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
143 }
144
145 GenericOp
146 DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
148 Block *body = genericOp.getBody();
149 Operation *peeledScalarOperation = &(*body->begin());
151 genericOp.getIndexingMapsArray();
152
153
154
155 Location loc = genericOp.getLoc();
159
160
161 for (auto scalarOpResult : peeledScalarOperation->getResults()) {
162
163
164
165 std::optional resultNumber;
166 for (auto *user : scalarOpResult.getUsers()) {
167 if (auto yieldOp = dyn_cast(user)) {
168
169 for (OpOperand &yieldOperand : yieldOp->getOpOperands()) {
170 if (yieldOperand.get() == scalarOpResult) {
171 resultNumber = yieldOperand.getOperandNumber();
172 break;
173 }
174 }
175 assert(resultNumber && "unable to find use of a value in its user");
176 break;
177 }
178 }
179 if (resultNumber) {
180 newInitValues.push_back(
181 genericOp.getDpsInitOperand(*resultNumber)->get());
182 OpResult result = cast(genericOp.getResult(*resultNumber));
183 newResultTypes.push_back(result.getType());
184 peeledGenericOpIndexingMaps.push_back(
185 genericOp.getIndexingMapMatchingResult(result));
186 continue;
187 }
188
189
191 Value emptyTensor =
192 rewriter.createtensor::EmptyOp(loc, domain, scalarOpResult.getType());
193 newInitValues.push_back(emptyTensor);
194 newResultTypes.push_back(emptyTensor.getType());
195 peeledGenericOpIndexingMaps.push_back(indexingMap);
196 }
197
198
200 outsOperands.append(newInitValues.begin(), newInitValues.end());
201 SmallVector resultTypes = llvm::to_vector(genericOp.getResultTypes());
202 resultTypes.append(newResultTypes.begin(), newResultTypes.end());
203 auto indexingMapAttr =
205 return rewriter.create(
206 loc, resultTypes, genericOp.getInputs(), outsOperands, indexingMapAttr,
207 genericOp.getIteratorTypes(), nullptr, nullptr,
209 }
210
211 GenericOp
212 DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
213 GenericOp peeledGenericOp,
215
216
217 SmallVector residualGenericOpOperands = genericOp.getInputs();
218 unsigned origNumResults = genericOp.getNumResults();
219 unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults();
221 for (auto resultNum :
222 llvm::seq(origNumResults, peeledGenericOpNumResults))
223 extraIns.push_back(peeledGenericOp->getResult(resultNum));
224 residualGenericOpOperands.append(extraIns);
225
226
227
228 auto indexingMaps = llvm::to_vector(
229 llvm::map_range(genericOp.getDpsInputOperands(), [&](OpOperand *operand) {
230 return genericOp.getMatchingIndexingMap(operand);
231 }));
232 for (auto resultNum :
233 llvm::seq(origNumResults, peeledGenericOpNumResults)) {
234 OpResult result = cast(peeledGenericOp.getResult(resultNum));
235 indexingMaps.push_back(
236 peeledGenericOp.getIndexingMapMatchingResult(result));
237 }
238 for (OpOperand &outOperand : genericOp.getDpsInitsMutable())
239 indexingMaps.push_back(genericOp.getMatchingIndexingMap(&outOperand));
240
242 return rewriter.create(
243 genericOp->getLoc(), genericOp->getResultTypes(),
244 residualGenericOpOperands, genericOp.getOutputs(), indexingMapAttr,
245 genericOp.getIteratorTypes(), nullptr, nullptr,
247 }
248
249 LogicalResult
250 DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
252
253 if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
255 "unhandled decomposition of operation "
256 "with non-parallel iterator types");
257 }
258
259
260
261 if (!genericOp.hasPureTensorSemantics()) {
263 genericOp, "only operations with tensor semantics are handled");
264 }
265
266 if (llvm::any_of(genericOp.getDpsInitsMutable(), [&](OpOperand &outOperand) {
267 return !genericOp.getMatchingIndexingMap(&outOperand).isPermutation();
268 })) {
270 genericOp, "unhandled decomposition of generic op with out operand not "
271 "accessed using a permutation");
272 }
273
274
275 Block *body = genericOp.getBody();
278 "operation has less than 3 statements");
279 }
280
281
282 if (llvm::any_of(body->getOperations().begin()->getResultTypes(),
283 [](Type t) { return !t.isIntOrIndexOrFloat(); })) {
286 "expected return type to be only int, index or float");
287 }
288
289 GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
290 GenericOp residualGenericOp =
291 createResidualGenericOp(genericOp, peeledGenericOp, rewriter);
292
293
294
295 Block *peeledGenericOpBody = peeledGenericOp.getBody();
296 Block *residualGenericOpBody = residualGenericOp.getBody();
297 assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() &&
298 "expected split generic ops to have empty region");
301 residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(),
303
304 Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin());
305 auto *yieldOp = residualGenericOpBody->getTerminator();
306 {
307
311 for (auto origYield : yieldOp->getOperands()) {
312 if (origYield.getDefiningOp() == peeledScalarOperation) {
313 yieldedVals.push_back(origYield);
314 } else {
315
316
317
320 yieldedVals.push_back(
321 getZero(rewriter, genericOp.getLoc(), origYield.getType()));
322 }
323 }
324 yieldedVals.append(llvm::to_vector(
325 llvm::map_range(peeledScalarOperation->getResults(),
327 rewriter.create(genericOp.getLoc(), yieldedVals);
328 }
329
330
331
332 unsigned origNumInputs = genericOp.getNumDpsInputs();
333 for (const auto &inputBlockArg :
335 Value residualOpReplacementArg =
336 residualGenericOpBody->getArgument(inputBlockArg.index());
338 inputBlockArg.value(), residualOpReplacementArg, [&](OpOperand &use) {
339 return use.getOwner()->getBlock() == residualGenericOpBody;
340 });
341
342 Value peeledOpReplacementArg =
343 peeledGenericOpBody->getArgument(inputBlockArg.index());
345 inputBlockArg.value(), peeledOpReplacementArg, [&](OpOperand &use) {
346 return use.getOwner()->getBlock() == peeledGenericOpBody;
347 });
348 }
349
350
351
352
353
355 for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) {
356 OpResult opr = dyn_cast(yieldValue.value());
357 if (!opr || opr.getOwner() != peeledScalarOperation)
358 replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
359 else
360 replacements.push_back(peeledGenericOp->getResult(yieldValue.index()));
361 }
362
363
364
365 {
367 unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults();
368 scalarReplacements.reserve(peeledScalarOpNumResults);
369 for (auto num : llvm::seq(0, peeledScalarOpNumResults))
370 scalarReplacements.push_back(
371 residualGenericOpBody->getArgument(num + origNumInputs));
372 bool allUsesReplaced = false;
374 residualGenericOpBody, &allUsesReplaced);
375 assert(!allUsesReplaced &&
376 "peeled scalar operation is erased when it wasnt expected to be");
377 }
378
379
380 rewriter.replaceOp(genericOp, replacements);
381 return success();
382 }
383
387
388 if (removeDeadArgsAndResults)
390 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
SmallVector< OpFoldResult > permuteValues(ArrayRef< OpFoldResult > values, AffineMap map)
Helper method to permute the list of values based on the map.
static SmallVector< OpFoldResult > getGenericOpLoopRange(OpBuilder &b, GenericOp op)
Helper method to compute the range of a generic op.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
ArrayRef< AffineExpr > getResults() const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
AffineMap getMultiDimIdentityMap(unsigned rank)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation * getOwner() const
Returns the operation that owns this result.
Operation is the basic unit of execution within MLIR.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceOpUsesWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
Find uses of from within block and replace them with to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns, bool removeDeadArgsAndResults=true)
Populate patterns for splitting a LinalgOp with multiple statements within its payload into multiple ...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...