MLIR: lib/Dialect/Affine/Analysis/LoopAnalysis.cpp Source File
//===- LoopAnalysis.cpp - misc. loop analysis routines -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous loop analysis routines.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "llvm/Support/MathExtras.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/Debug.h"
#include <numeric>
#include <optional>
#include <type_traits>

#define DEBUG_TYPE "affine-loop-analysis"

using namespace mlir;
using namespace mlir::affine;

namespace {

/// A directed graph to model relationships between MLIR Operations.
class DirectedOpGraph {
public:
  /// Add a node to the graph.
  void addNode(Operation *op) {
    assert(!hasNode(op) && "node already added");
    nodes.emplace_back(op);
    edges[op] = {};
  }

  /// Add an edge from `src` to `dest`.
  void addEdge(Operation *src, Operation *dest) {
    // This is a multi-graph.
    assert(hasNode(src) && "src node does not exist in graph");
    assert(hasNode(dest) && "dest node does not exist in graph");
    edges[src].push_back(getNode(dest));
  }

  /// Returns true if the graph has a (directed) cycle.
  bool hasCycle() { return dfs(/*cycleCheck=*/true); }

  void printEdges() {
    for (auto &en : edges) {
      llvm::dbgs() << *en.first << " (" << en.first << ")"
                   << " has " << en.second.size() << " edges:\n";
      for (auto *node : en.second) {
        llvm::dbgs() << '\t' << *node->op << '\n';
      }
    }
  }

private:
  /// A node of a directed graph between MLIR Operations to model various
  /// relationships. This is meant to be used internally.
  struct DGNode {
    DGNode(Operation *op) : op(op) {};
    Operation *op;

    // Start and finish visit numbers are standard in DFS to implement things
    // like finding strongly connected components. These numbers are modified
    // during analyses on the graph and so seemingly const API methods will
    // be non-const.

    /// Start visit number.
    int vn = -1;

    /// Finish visit number.
    int fn = -1;
  };

  /// Get the internal node corresponding to `op`.
  DGNode *getNode(Operation *op) {
    auto *value =
        llvm::find_if(nodes, [&](const DGNode &node) { return node.op == op; });
    assert(value != nodes.end() && "node doesn't exist in graph");
    return &*value;
  }

  /// Returns true if `key` is in the graph.
  bool hasNode(Operation *key) const {
    return llvm::find_if(nodes, [&](const DGNode &node) {
             return node.op == key;
           }) != nodes.end();
  }

  /// Perform a depth-first traversal of the graph setting visited and
  /// finished numbers. If `cycleCheck` is set, detects cycles and returns
  /// true as soon as the first cycle is detected, and false if there are no
  /// cycles. If `cycleCheck` is not set, completes the DFS and the return
  /// value doesn't have a meaning.
  bool dfs(bool cycleCheck = false) {
    for (DGNode &node : nodes) {
      node.vn = 0;
      node.fn = -1;
    }

    unsigned time = 0;
    for (DGNode &node : nodes) {
      if (node.vn == 0) {
        bool ret = dfsNode(node, cycleCheck, time);
        // Check if a cycle was already found.
        if (cycleCheck && ret)
          return true;
      } else if (cycleCheck && node.fn == -1) {
        // We have encountered a node whose visit has started but isn't
        // finished, so we have a cycle.
        return true;
      }
    }
    return false;
  }

  /// Perform depth-first traversal starting at `node`, updating `time` as the
  /// traversal proceeds. Returns true as soon as a cycle is found if
  /// `cycleCheck` was set.
  bool dfsNode(DGNode &node, bool cycleCheck, unsigned &time) const {
    auto nodeEdges = edges.find(node.op);
    assert(nodeEdges != edges.end() && "missing node in graph");
    node.vn = ++time;

    for (auto &neighbour : nodeEdges->second) {
      if (neighbour->vn == 0) {
        bool ret = dfsNode(*neighbour, cycleCheck, time);
        if (cycleCheck && ret)
          return true;
      } else if (cycleCheck && neighbour->fn == -1) {
        // We have encountered a node whose visit has started but isn't
        // finished, so we have a cycle.
        return true;
      }
    }

    // Update the finish time.
    node.fn = ++time;

    return false;
  }

  // The list of nodes. The storage is owned by this class.
  SmallVector<DGNode> nodes;

  // Edges as an adjacency list.
  DenseMap<Operation *, SmallVector<DGNode *>> edges;
};

} // namespace
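
// Usage sketch (illustrative, not part of the upstream file): `readOp` and
// `writeOp` stand in for two hypothetical affine memory operations that
// depend on each other in both directions.
//
//   DirectedOpGraph graph;
//   graph.addNode(readOp);
//   graph.addNode(writeOp);
//   graph.addEdge(readOp, writeOp); // e.g., a RAW dependence
//   graph.addEdge(writeOp, readOp); // e.g., a WAR dependence
//   bool cyclic = graph.hasCycle(); // true: readOp -> writeOp -> readOp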

/// Returns the trip count of the loop as an affine expression if the latter
/// is expressible as an affine expression, and nullptr otherwise. The trip
/// count expression is simplified before returning. This method only utilizes
/// map composition to construct lower and upper bounds before computing the
/// trip count expressions.
void mlir::affine::getTripCountMapAndOperands(
    AffineForOp forOp, AffineMap *tripCountMap,
    SmallVectorImpl<Value> *tripCountOperands) {
  MLIRContext *context = forOp.getContext();
  int64_t step = forOp.getStepAsInt();
  int64_t loopSpan;
  if (forOp.hasConstantBounds()) {
    int64_t lb = forOp.getConstantLowerBound();
    int64_t ub = forOp.getConstantUpperBound();
    loopSpan = ub - lb;
    if (loopSpan < 0)
      loopSpan = 0;
    *tripCountMap = AffineMap::getConstantMap(
        llvm::divideCeilSigned(loopSpan, step), context);
    tripCountOperands->clear();
    return;
  }
  auto lbMap = forOp.getLowerBoundMap();
  auto ubMap = forOp.getUpperBoundMap();
  if (lbMap.getNumResults() != 1) {
    *tripCountMap = AffineMap();
    return;
  }

  // Difference of each upper bound expression from the single lower bound
  // expression (divided by the step) provides the expressions for the trip
  // count map.
  AffineValueMap ubValueMap(ubMap, forOp.getUpperBoundOperands());

  SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(),
                                         lbMap.getResult(0));
  auto lbMapSplat = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
                                   lbSplatExpr, context);
  AffineValueMap lbSplatValueMap(lbMapSplat, forOp.getLowerBoundOperands());

  AffineValueMap tripCountValueMap;
  AffineValueMap::difference(ubValueMap, lbSplatValueMap, &tripCountValueMap);
  for (unsigned i = 0, e = tripCountValueMap.getNumResults(); i < e; ++i)
    tripCountValueMap.setResult(i,
                                tripCountValueMap.getResult(i).ceilDiv(step));

  *tripCountMap = tripCountValueMap.getAffineMap();
  tripCountOperands->assign(tripCountValueMap.getOperands().begin(),
                            tripCountValueMap.getOperands().end());
}
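
// Worked illustration (assumed, not from the upstream file): for
//   affine.for %i = 0 to 9 step 2
// the constant-bounds fast path above computes loopSpan = 9 and returns the
// constant map () -> (5), since divideCeilSigned(9, 2) = 5 and the loop runs
// for %i in {0, 2, 4, 6, 8}.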

/// Returns the trip count of the loop if it's a constant, std::nullopt
/// otherwise. This method uses affine expression analysis (in turn using
/// getTripCount) and is able to determine constant trip count in non-trivial
/// cases.
std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
  SmallVector<Value, 4> operands;
  AffineMap map;
  getTripCountMapAndOperands(forOp, &map, &operands);

  if (!map)
    return std::nullopt;

  // Take the min if all trip counts are constant.
  std::optional<uint64_t> tripCount;
  for (auto resultExpr : map.getResults()) {
    if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
      if (tripCount.has_value())
        tripCount =
            std::min(*tripCount, static_cast<uint64_t>(constExpr.getValue()));
      else
        tripCount = constExpr.getValue();
    } else {
      return std::nullopt;
    }
  }
  return tripCount;
}
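
// Usage sketch (assumed): callers typically gate transformations such as
// unrolling on the constant trip count; loopUnrollByFactor here is the
// affine utility from LoopUtils.h.
//
//   if (std::optional<uint64_t> tc = getConstantTripCount(forOp))
//     if (*tc % 4 == 0)
//       (void)loopUnrollByFactor(forOp, /*unrollFactor=*/4);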

/// Returns the greatest known integral divisor of the trip count. Affine
/// expression analysis is used (indirectly through getTripCount), and this
/// method is thus able to determine non-trivial divisors.
uint64_t mlir::affine::getLargestDivisorOfTripCount(AffineForOp forOp) {
  SmallVector<Value, 4> operands;
  AffineMap map;
  getTripCountMapAndOperands(forOp, &map, &operands);

  if (!map)
    return 1;

  // The largest divisor of the trip count is the GCD of the individual
  // largest divisors.
  assert(map.getNumResults() >= 1 && "expected one or more results");
  std::optional<uint64_t> gcd;
  for (auto resultExpr : map.getResults()) {
    uint64_t thisGcd;
    if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
      uint64_t tripCount = constExpr.getValue();
      // 0 iteration loops (greatest divisor is 2^64 - 1).
      if (tripCount == 0)
        thisGcd = std::numeric_limits<uint64_t>::max();
      else
        // The greatest divisor is the trip count.
        thisGcd = tripCount;
    } else {
      // Trip count is not a known constant; return its largest known divisor.
      thisGcd = resultExpr.getLargestKnownDivisor();
    }
    if (gcd.has_value())
      gcd = std::gcd(*gcd, thisGcd);
    else
      gcd = thisGcd;
  }
  assert(gcd.has_value() && "value expected per above logic");
  return *gcd;
}
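
// Worked illustration (assumed): if the trip count map is
// (d0) -> (d0 * 8, 16), the largest known divisor of the symbolic result
// d0 * 8 is 8 and that of the constant 16 is 16, so the function returns
// gcd(8, 16) = 8. A constant trip count of 0 contributes 2^64 - 1 and hence
// never lowers the GCD.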

/// Given an affine.for `iv` and an access `index` of type index, returns
/// `true` if `index` is independent of `iv` and false otherwise.
///
/// Prerequisites: `iv` and `index` of the proper type;
static bool isAccessIndexInvariant(Value iv, Value index) {
  assert(isAffineForInductionVar(iv) && "iv must be an affine.for iv");
  assert(isa<IndexType>(index.getType()) && "index must be of 'index' type");
  auto map = AffineMap::getMultiDimIdentityMap(/*numDims=*/1, iv.getContext());
  SmallVector<Value> operands = {index};
  AffineValueMap avm(map, operands);
  avm.composeSimplifyAndCanonicalize();
  return !avm.isFunctionOf(0, iv);
}
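
// Illustration (assumed): composing
//   %idx = affine.apply affine_map<(d0) -> (d0 floordiv 8)>(%iv)
// into the identity map yields a map that is still a function of %iv, so the
// index is reported variant; an %idx computed only from values defined above
// the loop composes to a map independent of %iv and is reported invariant.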

// Pre-requisite: Loop bounds should be in canonical form.
template <typename LoadOrStoreOp>
bool mlir::affine::isInvariantAccess(LoadOrStoreOp memOp, AffineForOp forOp) {
  AffineValueMap avm(memOp.getAffineMap(), memOp.getMapOperands());
  avm.composeSimplifyAndCanonicalize();
  return !llvm::is_contained(avm.getOperands(), forOp.getInductionVar());
}

// Explicitly instantiate the template so that the compiler knows we need them.
template bool mlir::affine::isInvariantAccess(AffineReadOpInterface,
                                              AffineForOp);
template bool mlir::affine::isInvariantAccess(AffineWriteOpInterface,
                                              AffineForOp);
template bool mlir::affine::isInvariantAccess(AffineLoadOp, AffineForOp);
template bool mlir::affine::isInvariantAccess(AffineStoreOp, AffineForOp);

DenseSet<Value> mlir::affine::getInvariantAccesses(Value iv,
                                                   ArrayRef<Value> indices) {
  DenseSet<Value> res;
  for (Value index : indices) {
    if (isAccessIndexInvariant(iv, index))
      res.insert(index);
  }
  return res;
}
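
// Illustration (assumed): for accesses %A[%i, %j] inside an affine.for whose
// induction variable is %i, getInvariantAccesses(%i, {%i, %j}) returns the
// singleton set {%j}.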

// TODO: check access stride.
template <typename LoadOrStoreOp>
bool mlir::affine::isContiguousAccess(Value iv, LoadOrStoreOp memoryOp,
                                      int *memRefDim) {
  static_assert(llvm::is_one_of<LoadOrStoreOp, AffineReadOpInterface,
                                AffineWriteOpInterface>::value,
                "Must be called on either an affine read or write op");
  assert(memRefDim && "memRefDim == nullptr");
  auto memRefType = memoryOp.getMemRefType();

  if (!memRefType.getLayout().isIdentity())
    return memoryOp.emitError("NYI: non-trivial layout map"), false;

  int uniqueVaryingIndexAlongIv = -1;
  auto accessMap = memoryOp.getAffineMap();
  SmallVector<Value, 4> mapOperands(memoryOp.getMapOperands());
  unsigned numDims = accessMap.getNumDims();
  for (unsigned i = 0, e = memRefType.getRank(); i < e; ++i) {
    // Gather map operands used in result expression 'i' in 'exprOperands'.
    SmallVector<Value, 4> exprOperands;
    auto resultExpr = accessMap.getResult(i);
    resultExpr.walk([&](AffineExpr expr) {
      if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
        exprOperands.push_back(mapOperands[dimExpr.getPosition()]);
      else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
        exprOperands.push_back(mapOperands[numDims + symExpr.getPosition()]);
    });
    // Check access invariance of each operand in 'exprOperands'.
    for (Value exprOperand : exprOperands) {
      if (!isAccessIndexInvariant(iv, exprOperand)) {
        if (uniqueVaryingIndexAlongIv != -1) {
          // 2+ varying indices -> do not vectorize along iv.
          return false;
        }
        uniqueVaryingIndexAlongIv = i;
      }
    }
  }

  if (uniqueVaryingIndexAlongIv == -1)
    *memRefDim = -1;
  else
    *memRefDim = memRefType.getRank() - (uniqueVaryingIndexAlongIv + 1);
  return true;
}

template bool mlir::affine::isContiguousAccess(Value iv,
                                               AffineReadOpInterface loadOp,
                                               int *memRefDim);
template bool mlir::affine::isContiguousAccess(Value iv,
                                               AffineWriteOpInterface loadOp,
                                               int *memRefDim);
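
// Illustration (assumed): on a rank-2 memref with identity layout, analyzed
// against the loop over %j, the access %A[%i, %j] varies only in the
// fastest-varying dimension, so *memRefDim = 2 - (1 + 1) = 0; the transposed
// access %A[%j, %i] yields *memRefDim = 1; and %A[%j, %j] varies in two
// result expressions, so the function returns false.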

template <typename LoadOrStoreOp>
static bool isVectorElement(LoadOrStoreOp memoryOp) {
  auto memRefType = memoryOp.getMemRefType();
  return isa<VectorType>(memRefType.getElementType());
}

using VectorizableOpFun = std::function<bool(AffineForOp, Operation &)>;

static bool
isVectorizableLoopBodyWithOpCond(AffineForOp loop,
                                 const VectorizableOpFun &isVectorizableOp,
                                 NestedPattern &vectorTransferMatcher) {
  auto *forOp = loop.getOperation();

  // No vectorization across conditionals for now.
  auto conditionals = matcher::If();
  SmallVector<NestedMatch, 8> conditionalsMatched;
  conditionals.match(forOp, &conditionalsMatched);
  if (!conditionalsMatched.empty()) {
    return false;
  }

  // No vectorization for ops with operand or result types that are not
  // vectorizable.
  auto types = matcher::Op([](Operation &op) -> bool {
    if (llvm::any_of(op.getOperandTypes(), [](Type type) {
          if (MemRefType t = dyn_cast<MemRefType>(type))
            return !VectorType::isValidElementType(t.getElementType());
          return !VectorType::isValidElementType(type);
        }))
      return true;
    return !llvm::all_of(op.getResultTypes(), VectorType::isValidElementType);
  });
  SmallVector<NestedMatch, 8> opsMatched;
  types.match(forOp, &opsMatched);
  if (!opsMatched.empty()) {
    return false;
  }

  // No vectorization across unknown regions.
  auto regions = matcher::Op([](Operation &op) -> bool {
    return op.getNumRegions() != 0 && !isa<AffineIfOp, AffineForOp>(op);
  });
  SmallVector<NestedMatch, 8> regionsMatched;
  regions.match(forOp, &regionsMatched);
  if (!regionsMatched.empty()) {
    return false;
  }

  SmallVector<NestedMatch, 8> vectorTransfersMatched;
  vectorTransferMatcher.match(forOp, &vectorTransfersMatched);
  if (!vectorTransfersMatched.empty()) {
    return false;
  }

  auto loadAndStores = matcher::Op(matcher::isLoadOrStore);
  SmallVector<NestedMatch, 8> loadAndStoresMatched;
  loadAndStores.match(forOp, &loadAndStoresMatched);
  for (auto ls : loadAndStoresMatched) {
    auto *op = ls.getMatchedOperation();
    auto load = dyn_cast<AffineLoadOp>(op);
    auto store = dyn_cast<AffineStoreOp>(op);
    // Only scalar types are considered vectorizable, all load/store must be
    // vectorizable for a loop to qualify as vectorizable.
    // TODO: ponder whether we want to be more general here.
    bool vector = load ? isVectorElement(load) : isVectorElement(store);
    if (vector) {
      return false;
    }
    if (isVectorizableOp && !isVectorizableOp(loop, *op)) {
      return false;
    }
  }
  return true;
}

bool mlir::affine::isVectorizableLoopBody(
    AffineForOp loop, int *memRefDim, NestedPattern &vectorTransferMatcher) {
  *memRefDim = -1;
  VectorizableOpFun fun([memRefDim](AffineForOp loop, Operation &op) {
    auto load = dyn_cast<AffineReadOpInterface>(op);
    auto store = dyn_cast<AffineWriteOpInterface>(op);
    int thisOpMemRefDim = -1;
    bool isContiguous =
        load ? isContiguousAccess(loop.getInductionVar(),
                                  cast<AffineReadOpInterface>(*load),
                                  &thisOpMemRefDim)
             : isContiguousAccess(loop.getInductionVar(),
                                  cast<AffineWriteOpInterface>(*store),
                                  &thisOpMemRefDim);
    if (thisOpMemRefDim != -1) {
      // If memory accesses vary across different dimensions then the loop is
      // not vectorizable.
      if (*memRefDim != -1 && *memRefDim != thisOpMemRefDim)
        return false;
      *memRefDim = thisOpMemRefDim;
    }
    return isContiguous;
  });
  return isVectorizableLoopBodyWithOpCond(loop, fun, vectorTransferMatcher);
}

bool mlir::affine::isVectorizableLoopBody(
    AffineForOp loop, NestedPattern &vectorTransferMatcher) {
  return isVectorizableLoopBodyWithOpCond(loop, nullptr,
                                          vectorTransferMatcher);
}
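
// Usage sketch (assumed): a vectorizer would query the innermost loop and
// vectorize along it only when all accesses agree on the varying memref
// dimension; the trivial pattern below matches nothing.
//
//   int memRefDim = -1;
//   NestedPattern noTransfers =
//       matcher::Op([](Operation &) { return false; });
//   if (isVectorizableLoopBody(innerLoop, &memRefDim, noTransfers) &&
//       memRefDim == 0) {
//     // Contiguous along the fastest-varying dimension; vectorize here.
//   }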

/// Checks whether SSA dominance would be violated if a for op's body
/// operations are shifted by the specified shifts. This method checks if a
/// 'def' and all its uses have the same shift factor.
// TODO: extend this to check for memory-based dependence violation when we
// have the support.
bool mlir::affine::isOpwiseShiftValid(AffineForOp forOp,
                                      ArrayRef<uint64_t> shifts) {
  auto *forBody = forOp.getBody();
  assert(shifts.size() == forBody->getOperations().size());

  // Work backwards over the body of the block so that the shift of a use's
  // ancestor operation in the block gets recorded before it's looked up.
  DenseMap<Operation *, uint64_t> forBodyShift;
  for (const auto &it :
       llvm::enumerate(llvm::reverse(forBody->getOperations()))) {
    auto &op = it.value();

    // Get the index of the current operation; note that we are iterating in
    // reverse so we need to fix it up.
    size_t index = shifts.size() - it.index() - 1;

    // Remember the shift of this operation.
    uint64_t shift = shifts[index];
    forBodyShift.try_emplace(&op, shift);

    // Validate the results of this operation if it were to be shifted.
    for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
      Value result = op.getResult(i);
      for (auto *user : result.getUsers()) {
        // If an ancestor operation doesn't lie in the block of forOp,
        // there is no shift to check.
        if (auto *ancOp = forBody->findAncestorOpInBlock(*user)) {
          assert(forBodyShift.count(ancOp) > 0 && "ancestor expected in map");
          if (shift != forBodyShift[ancOp])
            return false;
        }
      }
    }
  }
  return true;
}
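
// Illustration (assumed): for a body {%a = ...; %b = f(%a); affine.store %b},
// a uniform shift such as {1, 1, 1} is valid, whereas {0, 1, 0} is rejected:
// %b would be shifted relative both to its def %a and to its user, the store.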

bool mlir::affine::isTilingValid(ArrayRef<AffineForOp> loops) {
  assert(!loops.empty() && "no original loops provided");

  // We first find out all dependences we intend to check.
  SmallVector<Operation *, 8> loadAndStoreOps;
  loops[0]->walk([&](Operation *op) {
    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
      loadAndStoreOps.push_back(op);
  });

  unsigned numOps = loadAndStoreOps.size();
  unsigned numLoops = loops.size();
  for (unsigned d = 1; d <= numLoops + 1; ++d) {
    for (unsigned i = 0; i < numOps; ++i) {
      Operation *srcOp = loadAndStoreOps[i];
      MemRefAccess srcAccess(srcOp);
      for (unsigned j = 0; j < numOps; ++j) {
        Operation *dstOp = loadAndStoreOps[j];
        MemRefAccess dstAccess(dstOp);

        SmallVector<DependenceComponent, 2> depComps;
        DependenceResult result = checkMemrefAccessDependence(
            srcAccess, dstAccess, d, /*dependenceConstraints=*/nullptr,
            &depComps);

        // Skip if there is no dependence in this case.
        if (!hasDependence(result))
          continue;

        // Check whether there is any negative direction vector in the
        // dependence components found above, which means that dependence is
        // violated by the default hyper-rectangular tiling method.
        LLVM_DEBUG(llvm::dbgs() << "Checking whether tiling legality violated "
                                   "for dependence at depth: "
                                << Twine(d) << " between:\n";);
        LLVM_DEBUG(srcAccess.opInst->dump());
        LLVM_DEBUG(dstAccess.opInst->dump());
        for (const DependenceComponent &depComp : depComps) {
          if (depComp.lb.has_value() && depComp.ub.has_value() &&
              *depComp.lb < *depComp.ub && *depComp.ub < 0) {
            LLVM_DEBUG(llvm::dbgs()
                       << "Dependence component lb = " << Twine(*depComp.lb)
                       << " ub = " << Twine(*depComp.ub)
                       << " is negative at depth: " << Twine(d)
                       << " and thus violates the legality rule.\n");
            return false;
          }
        }
      }
    }
  }

  return true;
}
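
// Usage sketch (assumed): guard a tiling transform on the legality check;
// tilePerfectlyNested is the affine utility from LoopUtils.h.
//
//   SmallVector<AffineForOp, 2> band = {outerLoop, innerLoop};
//   if (isTilingValid(band)) {
//     SmallVector<AffineForOp, 4> tiledNest;
//     (void)tilePerfectlyNested(band, /*tileSizes=*/{32, 32}, &tiledNest);
//   }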

bool mlir::affine::hasCyclicDependence(AffineForOp root) {
  // Collect all the memory accesses in the affine nest rooted at `root`,
  // creating a graph node per access.
  DirectedOpGraph graph;
  SmallVector<MemRefAccess, 4> accesses;
  root->walk([&](Operation *op) {
    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
      accesses.emplace_back(op);
      graph.addNode(op);
    }
  });

  // Construct the dependence graph for all the collected accesses.
  unsigned rootDepth = getNestingDepth(root);
  for (const auto &accA : accesses) {
    for (const auto &accB : accesses) {
      if (accA.memref != accB.memref)
        continue;
      // Perform the dependence check at all depths within the nest.
      unsigned numCommonLoops =
          getNumCommonSurroundingLoops(*accA.opInst, *accB.opInst);
      for (unsigned d = rootDepth + 1; d <= numCommonLoops + 1; ++d) {
        if (!noDependence(checkMemrefAccessDependence(accA, accB, d)))
          graph.addEdge(accA.opInst, accB.opInst);
      }
    }
  }
  return graph.hasCycle();
}
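
// Illustration (assumed): in a nest where each iteration reads %A[%i + 1]
// and writes %A[%i], dependences run in both directions between the two
// accesses, the graph gets a two-node cycle, and hasCyclicDependence returns
// true; accesses to disjoint memrefs produce no edges and no cycle.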