MLIR: lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
10
14
15using namespace mlir;
17
18
22 if (isa(ofr))
23 return ofr;
28 true)))
29 return failure();
31}
32
34 memref::AllocaOp allocaOp,
37 b.setInsertionPoint(allocaOp);
38 Location loc = allocaOp.getLoc();
39
41 for (OpFoldResult ofr : allocaOp.getMixedSizes()) {
44 return failure();
45 newSizes.push_back(*ub);
46 }
47
48
49 if (llvm::equal(allocaOp.getMixedSizes(), newSizes))
50 return allocaOp.getResult();
51
52
53 Value newAllocaOp =
54 AllocaOp::create(b, loc, newSizes, allocaOp.getType().getElementType());
55
56
59 return SubViewOp::create(b, loc, newAllocaOp, offsets,
60 allocaOp.getMixedSizes(), strides)
61 .getResult();
62}
63
64
65static UnrealizedConversionCastOp
67 UnrealizedConversionCastOp conversionOp, SubViewOp op) {
70 MemRefType newResultType = SubViewOp::inferRankReducedResultType(
71 op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
72 op.getMixedSizes(), op.getMixedStrides());
73 Value newSubview = SubViewOp::create(
74 rewriter, op.getLoc(), newResultType, conversionOp.getOperand(0),
75 op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());
76 auto newConversionOp = UnrealizedConversionCastOp::create(
77 rewriter, op.getLoc(), op.getType(), newSubview);
78 rewriter.replaceAllUsesWith(op.getResult(), newConversionOp->getResult(0));
79 return newConversionOp;
80}
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
99 "expected same number of results");
102
103
104
106 for (const auto &it :
108 unrealizedConversions.push_back(UnrealizedConversionCastOp::create(
109 rewriter, to->getLoc(), std::get<0>(it.value()).getType(),
110 std::get<1>(it.value())));
112 unrealizedConversions.back()->getResult(0));
113 }
114
115
116
117 for (int i = 0; i < static_cast(unrealizedConversions.size()); ++i) {
118 UnrealizedConversionCastOp conversion = unrealizedConversions[i];
119 assert(conversion->getNumOperands() == 1 &&
120 conversion->getNumResults() == 1 &&
121 "expected single operand and single result");
124
125
126 if (auto subviewOp = dyn_cast(user)) {
127 unrealizedConversions.push_back(
129 continue;
130 }
131
132
133
134
135
136
137 if (llvm::any_of(user->getResultTypes(),
138 [](Type t) { return isa(t); }))
139 continue;
140 if (llvm::any_of(user->getRegions(), [](Region &r) {
141 return llvm::any_of(r.getArguments(), [](BlockArgument bbArg) {
142 return isa(bbArg.getType());
143 });
144 }))
145 continue;
146
147
148
149
150 for (OpOperand &operand : user->getOpOperands()) {
151 if ([[maybe_unused]] auto castOp =
152 operand.get().getDefiningOp()) {
154 user, [&]() { operand.set(conversion->getOperand(0)); });
155 }
156 }
157 }
158 }
159
160
161 for (auto op : unrealizedConversions)
162 if (op->getUses().empty())
163 rewriter.eraseOp(op);
164}
165
167 memref::AllocaOp allocaOp,
172 return failure();
176}
177
179 RewriterBase &rewriter, memref::AllocOp alloc,
180 function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter) {
181 memref::DeallocOp dealloc = nullptr;
183 llvm::make_range(alloc->getIterator(), alloc->getBlock()->end())) {
184 dealloc = dyn_castmemref::DeallocOp(candidate);
185 if (dealloc && dealloc.getMemref() == alloc.getMemref() &&
186 (!filter || filter(alloc, dealloc))) {
187 break;
188 }
189 }
190
191 if (!dealloc)
192 return nullptr;
193
197 alloc, alloc.getMemref().getType(), alloc.getOperands());
198 rewriter.eraseOp(dealloc);
199 return alloca;
200}
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
…if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be inserted. […] The output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`, […]
static FailureOr< OpFoldResult > makeIndependent(OpBuilder &b, Location loc, OpFoldResult ofr, ValueRange independencies)
Make the given OpFoldResult independent of all independencies.
Definition IndependenceTransforms.cpp:19
static void replaceAndPropagateMemRefType(RewriterBase &rewriter, Operation *from, Operation *to)
Given an original op and a new, modified op with the same number of results, whose memref return type...
Definition IndependenceTransforms.cpp:96
static UnrealizedConversionCastOp propagateSubViewOp(RewriterBase &rewriter, UnrealizedConversionCastOp conversionOp, SubViewOp op)
Push down an UnrealizedConversionCastOp past a SubViewOp.
Definition IndependenceTransforms.cpp:66
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static LogicalResult computeIndependentBound(AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, const Variable &var, ValueRange independencies, bool closedUB=false)
Compute a bound that is independent of all values in independencies.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
OpFoldResult materializeComputedBound(OpBuilder &b, Location loc, AffineMap boundMap, ArrayRef< std::pair< Value, std::optional< int64_t > > > mapOperands)
Materialize an already computed bound with Affine dialect ops.
FailureOr< Value > replaceWithIndependentOp(RewriterBase &rewriter, memref::AllocaOp allocaOp, ValueRange independencies)
Build a new memref::AllocaOp whose dynamic sizes are independent of all given independencies.
Definition IndependenceTransforms.cpp:166
FailureOr< Value > buildIndependentOp(OpBuilder &b, AllocaOp allocaOp, ValueRange independencies)
Build a new memref::AllocaOp whose dynamic sizes are independent of all given independencies.
memref::AllocaOp allocToAlloca(RewriterBase &rewriter, memref::AllocOp alloc, function_ref< bool(memref::AllocOp, memref::DeallocOp)> filter=nullptr)
Replaces the given alloc with the corresponding alloca and returns it if the following conditions are...
Definition IndependenceTransforms.cpp:178
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
SmallVector< std::pair< Value, std::optional< int64_t > > > ValueDimList
llvm::function_ref< Fn > function_ref