MLIR: lib/Dialect/Linalg/Transforms/Hoisting.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
33 #include "llvm/ADT/StringRef.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Debug.h"
36
37 using llvm::dbgs;
38
39 #define DEBUG_TYPE "linalg-hoisting"
40
41 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
42
43 using namespace mlir;
45
46
47
48
49
50
51
53 scf::ForOp loop,
54 Value newInitOperand,
55 unsigned index,
56 Value newYieldValue) {
59 auto inits = llvm::to_vector(loop.getInits());
60
61
62 assert(index < inits.size());
63 inits[index] = newInitOperand;
64
65 scf::ForOp newLoop = rewriter.createscf::ForOp(
66 loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
68
69
70 auto yieldOp = castscf::YieldOp(loop.getBody()->getTerminator());
71 yieldOp.setOperand(index, newYieldValue);
72
73
74 rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(),
75 newLoop.getBody()->getArguments());
76
77
78 rewriter.replaceOp(loop.getOperation(), newLoop->getResults());
79 return newLoop;
80 }
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
102
103
106
107 root->walk([&](vector::ExtractOp extractOp) {
108 LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
109 << *extractOp.getOperation() << "\n");
110
111 auto loop = dyn_castscf::ForOp(extractOp->getParentOp());
112 if (!loop)
114
115
116 auto blockArg = dyn_cast(extractOp.getVector());
117 if (!blockArg)
119
120
121 OpOperand *initArg = loop.getTiedLoopInit(blockArg);
122 if (!initArg)
124
125
126
127 if (!blockArg.hasOneUse())
129
130 unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars();
131
132
134 loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp();
135 auto broadcast = dyn_castvector::BroadcastOp(yieldedVal);
138
139 LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
140
141 Type broadcastInputType = broadcast.getSourceType();
142 if (broadcastInputType != extractOp.getType())
144
145
146
147 for (auto operand : extractOp.getDynamicPosition())
148 if (!loop.isDefinedOutsideOfLoop(operand))
150
152 extractOp.getVectorMutable().assign(initArg->get());
153 });
154 loop.moveOutOfLoop(extractOp);
156
158 rewriter, loop, extractOp.getResult(), index, broadcast.getSource());
159
160 LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
161
165
168 });
169 }
170 }
171
173 LoopLikeOpInterface loop) {
174 Value source = transferRead.getBase();
175
176
177 while (auto srcOp =
178 dyn_cast_or_null(source.getDefiningOp()))
179 source = srcOp.getViewSource();
180
183 llvm::SmallDenseSet<Operation *, 32> processed;
184 while (!users.empty()) {
185 Operation *user = users.pop_back_val();
186
187 if (!processed.insert(user).second)
188 continue;
189 if (auto viewLike = dyn_cast(user)) {
190 users.append(viewLike->getUsers().begin(), viewLike->getUsers().end());
191 continue;
192 }
194 continue;
195 if (!loop->isAncestor(user))
196 continue;
197 return false;
198 }
199 return true;
200 }
201
203 bool verifyNonZeroTrip) {
207
208
211
212
213
214
215
217 if (verifyNonZeroTrip) {
218 root->walk([&](LoopLikeOpInterface loopLike) {
219 std::optional<SmallVector> lbs =
220 loopLike.getLoopLowerBounds();
221 std::optional<SmallVector> ubs =
222 loopLike.getLoopUpperBounds();
223
224 if (!lbs || !ubs)
225 return;
226
227
228
229
230 for (auto [lb, ub] : llvm::zip_equal(lbs.value(), ubs.value())) {
231 FailureOr<int64_t> maxLb =
234 nullptr, true);
235 if (failed(maxLb))
236 return;
237 FailureOr<int64_t> minUb =
240 if (failed(minUb))
241 return;
242 if (minUb.value() <= maxLb.value())
243 return;
244 definiteNonZeroTripCountLoops.insert(loopLike);
245 }
246 });
247 }
248
249 root->walk([&](vector::TransferReadOp transferRead) {
250 if (!isa(transferRead.getShapedType()))
252
253 LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
254 << *transferRead.getOperation() << "\n");
255 auto loop = dyn_cast(transferRead->getParentOp());
256 LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
257 << "\n");
258 if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop))
260
261 if (verifyNonZeroTrip && !definiteNonZeroTripCountLoops.contains(loop)) {
262 LLVM_DEBUG(DBGS() << "Loop may have zero trip count: " << *loop
263 << "\n");
265 }
266
267 LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
268 << "\n");
269
271 getForwardSlice(transferRead.getOperation(), &forwardSlice);
272
273
274
275 vector::TransferWriteOp transferWrite;
276 for (auto *sliceOp : llvm::reverse(forwardSlice)) {
277 auto candidateWrite = dyn_castvector::TransferWriteOp(sliceOp);
278 if (!candidateWrite ||
279 candidateWrite.getBase() != transferRead.getBase())
280 continue;
281 transferWrite = candidateWrite;
282 }
283
284
285 for (auto operand : transferRead.getOperands())
286 if (!loop.isDefinedOutsideOfLoop(operand))
288
289
290
291 if (!transferWrite) {
292
293
295 loop.moveOutOfLoop(transferRead);
297 }
298
299 LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
300 << "\n");
301
302
303
304
305
306
307
308
309
310 if (transferRead.getIndices() != transferWrite.getIndices() ||
311 transferRead.getVectorType() != transferWrite.getVectorType() ||
312 transferRead.getPermutationMap() != transferWrite.getPermutationMap())
314
315 auto *source = transferRead.getBase().getDefiningOp();
316 if (source && isa_and_nonnull(source))
318
319 source = transferWrite.getBase().getDefiningOp();
320 if (source && isa_and_nonnull(source))
322
323
324
326 if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
328 for (auto &use : transferRead.getBase().getUses()) {
329 if (!loop->isAncestor(use.getOwner()))
330 continue;
331 if (use.getOwner() == transferRead.getOperation() ||
332 use.getOwner() == transferWrite.getOperation())
333 continue;
334 if (auto transferWriteUse =
335 dyn_castvector::TransferWriteOp(use.getOwner())) {
337 cast(*transferWrite),
338 cast(*transferWriteUse),
339 true))
341 } else if (auto transferReadUse =
342 dyn_castvector::TransferReadOp(use.getOwner())) {
344 cast(*transferWrite),
345 cast(*transferReadUse),
346 true))
348 } else {
349
350
352 }
353 }
354
355
356 loop.moveOutOfLoop(transferRead);
357
358
359 transferWrite->moveAfter(loop);
360
361
362 IRRewriter rewriter(transferRead.getContext());
366 };
367
368 auto maybeNewLoop = loop.replaceWithAdditionalYields(
369 rewriter, transferRead.getVector(),
370 true, yieldFn);
371 if (failed(maybeNewLoop))
373
374 transferWrite.getValueToStoreMutable().assign(
375 maybeNewLoop->getOperation()->getResults().back());
377
378
380 });
381 }
382 }
static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, LoopLikeOpInterface loop)
static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop, Value newInitOperand, unsigned index, Value newYieldValue)
Replace loop with a new loop that has a different init operand at position index.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static FailureOr< int64_t > computeConstantBound(presburger::BoundType type, const Variable &var, StopConditionFn stopCondition=nullptr, bool closedUB=false)
Compute a constant bound for the given variable.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
user_range getUsers() const
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.