[mlir][IntRangeAnalysis] Handle unstructured loop arguments correctly by krzysz00 · Pull Request #119459 · llvm/llvm-project (original) (raw)
@llvm/pr-subscribers-mlir-arith
@llvm/pr-subscribers-mlir
Author: Krzysztof Drewniak (krzysz00)
Changes
The integer range analysis currently has a bug where, because of how it interacts with dead code analysis, it will sometimes declare code dead that isn't dead, becaues it hasn't seen the edge that loops an incremented value back to itself yet.
This commit fixes the issue by overriding the join method on lattice values in order to detect these back-edges on non-entry blocks and then snapping the passed-around value to its maximum possible range, just like we do for loop-varying values in region control flow.
Fixes #119045
Full diff: https://github.com/llvm/llvm-project/pull/119459.diff
3 Files Affected:
- (modified) mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h (+10)
- (modified) mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp (+46)
- (modified) mlir/test/Dialect/Arith/int-range-opts.mlir (+60)
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h index f99eae379596b6..464a47355b4207 100644 --- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h @@ -31,6 +31,16 @@ class IntegerValueRangeLattice : public Lattice { public: using Lattice::Lattice;
- /// Override the join logic so that arguments to non-entry blocks
- /// whose arguments come from later in the program get set to
- /// a maximal value so that we don't prematurely declare code to be
- /// deade.
- ChangeResult join(const AbstractSparseLattice &rhs) override;
- ChangeResult join(const IntegerValueRange &range) {
- return Lattice::join(range);
- }
- /// If the range can be narrowed to an integer constant, update the constant /// value of the SSA value. void onUpdate(DataFlowSolver *solver) const override; diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp index a97e43708d9a37..a45fcee345e91d 100644 --- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp @@ -37,6 +37,50 @@ using namespace mlir; using namespace mlir::dataflow;
+/// Return true if block
is a non-entry block with a predecessor that's
+/// defined after the block. This allows us to detect loop-varying values
+/// in unstructured control flow.
+static bool isLoopLikeBlock(Block *block) {
- if (!block || block->isEntryBlock())
- return false;
- Region *parent = block->getParent();
- if (!parent)
- return false;
- SmallPtrSet<Block *, 4> preds;
- for (Block *pred : block->getPredecessors())
- preds.insert(pred);
- if (preds.size() <= 1)
- return false;
- for (Block ®ionBlock : parent->getBlocks()) {
- if (®ionBlock == block)
break;
- preds.erase(®ionBlock);
- }
- // The block loops back on itself or has an edge from further in the program.
- return !preds.empty(); +}
- +ChangeResult IntegerValueRangeLattice::join(const AbstractSparseLattice &rhs) {
- Value lhsAnchor = getAnchor();
- Block *lhsBlock = lhsAnchor.getParentBlock();
- unsigned width = ConstantIntRanges::getStorageBitwidth(lhsAnchor.getType());
- /// Special-case: we're in unstructured control flow and one of the
- /// predecessors of this block argument is defined in a block that comes after
- /// the argument. So we conservatively conclude that the value could be
- /// anything.
- if (width > 0 && isa(lhsAnchor) && isLoopLikeBlock(lhsBlock)) {
- LLVM_DEBUG(llvm::dbgs() << "Found loop-varying block argument " << lhsAnchor
<< " from " << rhs.getAnchor() << "\n");
- LLVM_DEBUG(llvm::dbgs() << "Inferring maximum range\n");
- IntegerValueRange maxRange = IntegerValueRange::getMaxRange(lhsAnchor);
- return join(maxRange);
- }
- return Lattice::join(rhs); +}
- void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const { Lattice::onUpdate(solver);
@@ -206,6 +250,8 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( if (max.sge(min)) { IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv); auto ivRange = ConstantIntRanges::fromSigned(min, max);
LLVM_DEBUG(llvm::dbgs()
} return; diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir index ea5969a1002580..e312cf175f5b56 100644 --- a/mlir/test/Dialect/Arith/int-range-opts.mlir +++ b/mlir/test/Dialect/Arith/int-range-opts.mlir @@ -132,3 +132,63 @@ func.func @wraps() -> i8 { %mod = arith.remsi %val, %c64 : i8 return %mod : i8 }<< "Inferred loop bound range: " << ivRange << "\n"); propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
- +// -----
- +// Note: I wish I had a simpler example than this, but getting rid of a +// bunch of the arithmetic made the issue go away. +// CHECK-LABEL: @blocks_prematurely_declared_dead_bug +// CHECK-NOT: arith.constant true +func.func @blocks_prematurely_declared_dead_bug(%mem: memref<?xf16>) {
- %cst = arith.constant dense : vector<1xi1>
- %c1 = arith.constant 1 : index
- %cst_0 = arith.constant dense<0.000000e+00> : vector<1xf16>
- %cst_1 = arith.constant 0.000000e+00 : f16
- %c16 = arith.constant 16 : index
- %c0 = arith.constant 0 : index
- %c64 = arith.constant 64 : index
- %thread_id_x = gpu.thread_id x upper_bound 64
- %6 = test.with_bounds { smin = 16 : index, smax = 112 : index, umin = 16 : index, umax = 112 : index } : index
- %8 = arith.divui %6, %c16 : index
- %9 = arith.muli %8, %c16 : index
- cf.br ^bb1(%c0 : index) +^bb1(%12: index): // 2 preds: ^bb0, ^bb7
- %13 = arith.cmpi slt, %12, %9 : index
- cf.cond_br %13, ^bb2, ^bb8 +^bb2: // pred: ^bb1
- %14 = arith.subi %9, %12 : index
- %15 = arith.minsi %14, %c64 : index
- %16 = arith.subi %15, %thread_id_x : index
- %17 = vector.constant_mask [1] : vector<1xi1>
- %18 = arith.cmpi sgt, %16, %c0 : index
- %19 = arith.select %18, %17, %cst : vector<1xi1>
- %20 = vector.extract %19[0] : i1 from vector<1xi1>
- %21 = vector.insert %20, %cst [0] : i1 into vector<1xi1>
- %22 = arith.addi %12, %thread_id_x : index
- cf.br ^bb3(%c0, %cst_0 : index, vector<1xf16>) +^bb3(%23: index, %24: vector<1xf16>): // 2 preds: ^bb2, ^bb6
- %25 = arith.cmpi slt, %23, %c1 : index
- cf.cond_br %25, ^bb4, ^bb7 +^bb4: // pred: ^bb3
- %26 = vector.extractelement %21[%23 : index] : vector<1xi1>
- cf.cond_br %26, ^bb5, ^bb6(%24 : vector<1xf16>) +^bb5: // pred: ^bb4
- %27 = arith.addi %22, %23 : index
- %28 = memref.load %mem[%27] : memref<?xf16>
- %29 = vector.insertelement %28, %24[%23 : index] : vector<1xf16>
- cf.br ^bb6(%29 : vector<1xf16>) +^bb6(%30: vector<1xf16>): // 2 preds: ^bb4, ^bb5
- %31 = arith.addi %23, %c1 : index
- cf.br ^bb3(%31, %30 : index, vector<1xf16>) +^bb7: // pred: ^bb3
- %37 = arith.addi %12, %c64 : index
- cf.br ^bb1(%37 : index) +^bb8: // pred: ^bb1
- %70 = arith.cmpi eq, %thread_id_x, %c0 : index
- cf.cond_br %70, ^bb9, ^bb10 +^bb9: // pred: ^bb8
- memref.store %cst_1, %mem[%c0] : memref<?xf16>
- cf.br ^bb10 +^bb10: // 2 preds: ^bb8, ^bb9
- return +}