MLIR: lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp Source File
1
2
3
4
5
6
7
8
10
12
13 using namespace mlir;
15
16
19 return false;
20
23 if (!genericOp.payloadUsesValueFromOperand(outputOpOperand))
24 return true;
25
26
27
28
29
30
31
32
34 genericOp.getRegionOutputArgs()[result.getResultNumber()];
36 return false;
38
39
41 return false;
42
43
44 auto yieldOp = dyn_castlinalg::YieldOp(argUserOp);
45 if (!yieldOp)
46 return false;
47
48
49 if (yieldOp.getOperand(result.getResultNumber()) != outputArg)
50 return false;
51
52 return true;
53 }
54
55
56
57
58
59
60
61
62
67 llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
68 llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
69 for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
70 OpOperand *inputOpOperand = en.value();
71
72
73 if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
74
75
76 droppedOpOperands.push_back(inputOpOperand);
77 if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
78 continue;
79 droppedOpOperands.pop_back();
80 }
81
82
83 AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
84 auto it =
85 dedupedInputs.find(std::make_pair(inputOpOperand->get(), indexingMap));
86 if (it != dedupedInputs.end()) {
87 origToNewPos[en.index()] = it->second;
88 droppedOpOperands.push_back(inputOpOperand);
89 continue;
90 }
91
92
93 origToNewPos[en.index()] = newInputOperands.size();
94 dedupedInputs[{inputOpOperand->get(), indexingMap}] =
95 newInputOperands.size();
96 newInputOperands.push_back(inputOpOperand->get());
97 newIndexingMaps.push_back(indexingMap);
98 }
99 return origToNewPos;
100 }
101
102
103
104
105
110 llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
111 llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
112 dedupedOutpts;
113
114
115 if (!genericOp.hasPureTensorSemantics() || !removeOutputs) {
116 for (const auto &en : llvm::enumerate(genericOp.getDpsInitsMutable())) {
117 origToNewPos[en.index()] = newOutputOperands.size();
118 newOutputOperands.push_back(en.value().get());
119 newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(&en.value()));
120 }
121 return origToNewPos;
122 }
123
124
125
126
127
128 auto yieldOp = cast(genericOp.getBody()->getTerminator());
129 for (const auto &outputOpOperand :
131 OpResult result = genericOp.getTiedOpResult(&outputOpOperand.value());
133 genericOp.getMatchingIndexingMap(&outputOpOperand.value());
134 auto key = std::make_tuple(outputOpOperand.value().get(), indexingMap,
135 yieldOp->getOperand(outputOpOperand.index()));
137
138
139
140
141 droppedOpOperands.push_back(&outputOpOperand.value());
142 if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
143 continue;
144 }
145 droppedOpOperands.pop_back();
146 }
147
148 if (!genericOp.payloadUsesValueFromOperand(&outputOpOperand.value())) {
149
150
151
152
153
154 auto it = dedupedOutpts.find(key);
155 if (it != dedupedOutpts.end()) {
156 origToNewPos[outputOpOperand.index()] = it->second;
157 droppedOpOperands.push_back(&outputOpOperand.value());
158 continue;
159 }
160 }
161
162 origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
163 dedupedOutpts[key] = newOutputOperands.size();
164 newOutputOperands.push_back(outputOpOperand.value().get());
165 newIndexingMaps.push_back(
166 genericOp.getMatchingIndexingMap(&outputOpOperand.value()));
167 }
168 return origToNewPos;
169 }
170
171
173 GenericOp genericOp, GenericOp newOp,
174 const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
175 const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
177
178 Block *newOpBlock = &newOp.getRegion().front();
179 assert(newOpBlock->empty() && "expected new op to have an empty payload");
180 Block *origOpBlock = &genericOp.getRegion().front();
182
183
184
185 auto updateReplacements =
188 const llvm::SmallDenseMap<unsigned, unsigned> &map) {
189 for (const auto &origOperand : llvm::enumerate(origOperands)) {
190 auto it = map.find(origOperand.index());
191 if (it == map.end())
192 continue;
193 OpOperand *newOperand = newOperands[it->second];
194 replacements[origOperand.value()->getOperandNumber()] =
196 }
197 };
198
201 updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
202
204 genericOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
206 newOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
207 updateReplacements(origOutputOperands, newOutputOperands,
208 origOutsToNewOutsPos);
209
210
211 if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {
213 YieldOp origYieldOp = cast(origOpBlock->getTerminator());
215
217 for (const auto &yieldOpOperands :
219 auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
220 if (it == origOutsToNewOutsPos.end())
221 continue;
222 newYieldVals[it->second] = yieldOpOperands.value();
223 }
225 }
226
227 rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
228 }
229
230 FailureOrlinalg::GenericOp
232 RewriterBase &rewriter, linalg::GenericOp genericOp, bool removeOutputs) {
233
234
236
237
240
241
242 llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
244 newIndexingMaps);
245
246
247 llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
249 newIndexingMaps, removeOutputs);
250
251
252 if (newInputOperands.size() + newOutputOperands.size() ==
253 genericOp->getNumOperands())
254 return genericOp;
255
256
257 Location loc = genericOp.getLoc();
259 for (Value v : newOutputOperands)
260 if (isa(v.getType()))
261 newResultTypes.push_back(v.getType());
262 auto newOp = rewriter.create(
263 loc, newResultTypes, newInputOperands, newOutputOperands,
265 genericOp.getIteratorTypes(), genericOp.getDocAttr(),
266 genericOp.getLibraryCallAttr(),
268 return;
269 });
270
273 if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
274 newOp->setAttr(kv.getName(), kv.getValue());
275
276
277 populateOpPayload(genericOp, newOp, origInsToNewInsPos, origOutsToNewOutsPos,
278 rewriter);
279
280
281 SmallVector replacementsVals(genericOp->getNumResults(), nullptr);
282 for (const auto &result : llvm::enumerate(genericOp.getResults())) {
283 auto it = origOutsToNewOutsPos.find(result.index());
284 if (it == origOutsToNewOutsPos.end())
285 continue;
286 replacementsVals[result.index()] = newOp.getResult(it->second);
287 }
288 rewriter.replaceOp(genericOp, replacementsVals);
289 return newOp;
290 }
291
292 namespace {
293
294 struct DeduplicateAndRemoveDeadOperandsAndResults
296 DeduplicateAndRemoveDeadOperandsAndResults(MLIRContext *ctx,
297 bool removeOutputs)
298 : OpRewritePattern(ctx), removeOutputs(removeOutputs) {}
299
300 LogicalResult matchAndRewrite(GenericOp genericOp,
303 rewriter, genericOp, removeOutputs);
304 if (failed(newOp) || newOp.value() == genericOp) {
306 genericOp, "failed to dedup operands/remove dead results");
307 }
308 return success();
309 }
310
311 private:
312
313 bool removeOutputs;
314 };
315
316
317
318
319
320
321
322
323 struct RemoveUnusedCycleInGenericOp : public OpRewritePattern {
325
326 LogicalResult matchAndRewrite(GenericOp genericOp,
328
329
330 if (!genericOp.hasPureTensorSemantics())
331 return failure();
332
333 bool hasRemovedCycles = false;
334
335 for (const auto &outputOpOperand :
337
338
339 Value result = genericOp.getResult(outputOpOperand.index());
341 continue;
342
343
345 genericOp.getRegionOutputArgs()[outputOpOperand.index()];
347 continue;
348
349
352 continue;
353
354
356 if (!isalinalg::YieldOp(cycleUserOp))
357 continue;
358
359
360 if (cycleUserOp->getOperand(outputOpOperand.index()) !=
362 continue;
363
364
365
366 rewriter.replaceOp(cycleOp, outputArg);
368 hasRemovedCycles = true;
369 }
370
371 if (hasRemovedCycles) {
372 return success();
373 }
374
375 return failure();
376 }
377 };
378
379
380
381
382
383
384
385
386
387
388
389
390 struct FoldDuplicateInputBbArgs : public OpRewritePattern {
392
393 LogicalResult matchAndRewrite(GenericOp genericOp,
395
397 for (int i = 0; i < genericOp.getNumDpsInputs(); ++i) {
398
399 if (genericOp.getBody()->getArgument(i).getUses().empty())
400 continue;
401
402 for (int j = genericOp->getNumOperands() - 1; j > i; --j) {
403 if (genericOp->getOperand(i) == genericOp->getOperand(j) &&
404 genericOp.getIndexingMapsArray()[i] ==
405 genericOp.getIndexingMapsArray()[j]) {
406 replacements[i] = j;
407 break;
408 }
409 }
410 }
411
412
413 if (replacements.empty())
414 return failure();
415
416
418 for (auto [before, after] : replacements) {
419 BlockArgument bbArg = genericOp.getBody()->getArgument(before);
420 BlockArgument replacement = genericOp.getBody()->getArgument(after);
422 }
423 });
424
425 return success();
426 }
427 };
428
429 }
430
433 patterns.insert(
434 patterns.getContext(), true);
435 patterns.insert(patterns.getContext());
436 }
437
440 patterns.insert(
441 patterns.getContext(), false);
442 patterns.insert(patterns.getContext());
443 }
static llvm::SmallDenseMap< unsigned, unsigned > deduplicateOutputOperands(GenericOp genericOp, SmallVector< OpOperand * > &droppedOpOperands, SmallVector< Value > &newOutputOperands, SmallVector< AffineMap > &newIndexingMaps, bool removeOutputs)
static llvm::SmallDenseMap< unsigned, unsigned > deduplicateInputOperands(GenericOp genericOp, SmallVector< OpOperand * > &droppedOpOperands, SmallVector< Value > &newInputOperands, SmallVector< AffineMap > &newIndexingMaps)
static void populateOpPayload(GenericOp genericOp, GenericOp newOp, const llvm::SmallDenseMap< unsigned, unsigned > &origInsToNewInsPos, const llvm::SmallDenseMap< unsigned, unsigned > &origOutsToNewOutsPos, RewriterBase &rewriter)
static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result)
Return true if the given result of the operation `genericOp` is dead.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
Value getOperand(unsigned idx)
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
user_iterator user_begin()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of `from` and replace them with `to`.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
user_iterator user_begin() const
bool hasOneUse() const
Returns true if this value has exactly one use.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< linalg::GenericOp > deduplicateOperandsAndRemoveDeadResults(RewriterBase &rewriter, linalg::GenericOp genericOp, bool removeOutputs)
Method to deduplicate operands and remove dead results of linalg.generic operations.
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Patterns to remove dead operands and results of linalg.generic operations.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Eliminates the variable at the specified position using Fourier-Motzkin variable elimination.