MLIR: lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/TypeSwitch.h"
33
34 #include
35
// Instantiate the tablegen-generated pass definition inside mlir::memref.
// The generated .inc presumably provides
// mlir::memref::impl::FlattenMemrefsPassBase, which FlattenMemrefsPass
// derives from further down. (Inferred from the matching namespace and the
// GEN_PASS_DEF_* macro name — confirm against Passes.td.)
36 namespace mlir {
37 namespace memref {
38 #define GEN_PASS_DEF_FLATTENMEMREFSPASS
39 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
40 }
41 }
42
43 using namespace mlir;
44
/// getValueFromOpFoldResult(OpBuilder &rewriter, Location loc, OpFoldResult in)
/// (signature lines are elided from this listing; see the tooltip list at the
/// end of the page).
///
/// Turns an OpFoldResult into an SSA Value:
///  - If `in` holds an Attribute, it emits an arith index constant carrying
///    that attribute's integer value.
///  - Otherwise it returns the Value held by `in` unchanged.
///
/// NOTE(review): the template arguments of `create` and `cast` were stripped
/// by the page scrape. The attribute cast is presumably to IntegerAttr — confirm.
47 if (Attribute offsetAttr = dyn_cast(in)) {
48 return rewriter.createarith::ConstantIndexOp(
49 loc, cast(offsetAttr).getInt());
50 }
51 return cast(in);
52 }
53
54
55
/// getFlattenMemrefAndOffset(OpBuilder &rewriter, Location loc, Value source,
/// ValueRange indices) -> std::pair<Value, Value>
///
/// Returns a collapsed memref and the linearized index used to access the
/// element at `indices`.
///
/// Several original lines of this function (signature, the `sourceStrides`
/// declaration, the declarations of `linearizedInfo`/`linearizedIndices`, and
/// parts of the call/return expressions) are missing from this listing; see
/// the gaps in the embedded line numbers.
60 int64_t sourceOffset;
62 auto sourceType = cast(source.getType());
// NOTE(review): a failed stride/offset query is caught only by this assert.
// In NDEBUG builds the failure is silently ignored and execution continues.
// Neither value is read again in the visible lines, so the effective guard
// is presumably the caller's layout check — confirm.
63 if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset))) {
64 assert(false);
65 }
66
// Materialize the source's offset, sizes and strides in IR. The
// "Constified" accessors produce OpFoldResults, which are what the
// linearization helper takes.
67 memref::ExtractStridedMetadataOp stridedMetadata =
68 rewriter.creatememref::ExtractStridedMetadataOp(loc, source);
69
// The element bit width is passed as both the source and destination bit
// width. The call on elided line 74 is presumably
// getLinearizedMemRefOffsetAndSize(builder, loc, srcBits, dstBits, ...), so
// the linearized index stays in units of the original element (no packing).
// NOTE(review): getIntOrFloatBitWidth presumably requires an integer or
// float element type. Confirm that other element types (e.g. index) are
// rejected before reaching this point.
70 auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
73 std::tie(linearizedInfo, linearizedIndices) =
75 rewriter, loc, typeBit, typeBit,
76 stridedMetadata.getConstifiedMixedOffset(),
77 stridedMetadata.getConstifiedMixedSizes(),
78 stridedMetadata.getConstifiedMixedStrides(),
80
// Returns (reinterpret_cast of `source`, linearized index). The cast's
// remaining operands are on elided lines.
81 return std::make_pair(
82 rewriter.creatememref::ReinterpretCastOp(
83 loc, source,
85
87
90 }
91
/// needFlattening(Value val): true when `val`'s (shaped) type has rank > 1.
/// Rank-0 and rank-1 memrefs are left alone. The signature line is elided
/// from this listing.
93 auto type = cast(val.getType());
94 return type.getRank() > 1;
95 }
96
/// checkLayout(Value val): accepts a memref whose layout is either identity
/// or of one specific layout-attribute kind. The `isa` template argument was
/// stripped by the page scrape; it is presumably StridedLayoutAttr — confirm.
/// The signature line is elided from this listing.
98 auto type = cast(val.getType());
99 return type.getLayout().isIdentity() ||
100 isa(type.getLayout());
101 }
102
103 namespace {
// getTargetMemref(op): returns the memref that `op` accesses or produces.
//  - memref load/store/alloca/alloc: getMemref().
//  - vector load/store/masked load/masked store/transfer read/write: getBase().
//  - anything else: a null Value.
// The TypeSwitch header lines (original 104-105) are elided from this listing.
// NOTE(review): the null Value from Default is not checked before being used
// by the caller. This is harmless as long as patterns are only registered for
// the op types listed here, which is the case in
// populateFlattenMemrefsPatterns.
106 .template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
107 memref::AllocOp>([](auto op) { return op.getMemref(); })
108 .template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
109 vector::MaskedStoreOp, vector::TransferReadOp,
110 vector::TransferWriteOp>(
111 [](auto op) { return op.getBase(); })
112 .Default([](auto) { return Value{}; });
113 }
114
/// castAllocResult(oper, newOper, loc, rewriter): after a flattened
/// allocation `newOper` has been created, replaces the original allocation
/// `oper` so its users keep seeing the original memref type.
///
/// Sizes and strides are extracted from `oper` via ExtractStridedMetadataOp
/// and fed, together with `newOper`, into the (elided, line 120) replacement
/// call. That call is presumably replaceOpWithNewOp<memref::ReinterpretCastOp>
/// targeting `cast(oper.getType())`. The elided line 122 presumably supplies
/// the offset — confirm.
/// NOTE(review): the metadata is read from `oper`, the op being replaced. If
/// any size/stride is dynamic, confirm the replacement does not leave the
/// new cast depending on metadata derived from its own result.
115 template
116 static void castAllocResult(T oper, T newOper, Location loc,
118 memref::ExtractStridedMetadataOp stridedMetadata =
119 rewriter.creatememref::ExtractStridedMetadataOp(loc, oper);
121 oper, cast(oper.getType()), newOper,
123 stridedMetadata.getConstifiedMixedSizes(),
124 stridedMetadata.getConstifiedMixedStrides());
125 }
126
/// replaceOp(op, rewriter, flatMemref, offset): rewrites `op` into the same
/// kind of operation, addressing `flatMemref` with the single linearized
/// index `offset`. The signature (original lines 128-131) is elided from this
/// listing; the call site passes (op, rewriter, flatMemref, offset).
127 template
// Alloc/Alloca: allocate with the flattened type, keep the alignment
// attribute, then cast back to the original type via castAllocResult.
// NOTE(review): only the type and alignment are passed to the new
// allocation. Any dynamic-size operands of the original op are not
// forwarded; confirm dynamic shapes cannot reach here.
132 .template Casememref::AllocOp([&](auto oper) {
133 auto newAlloc = rewriter.creatememref::AllocOp(
134 loc, cast(flatMemref.getType()),
135 oper.getAlignmentAttr());
136 castAllocResult(oper, newAlloc, loc, rewriter);
137 })
138 .template Casememref::AllocaOp([&](auto oper) {
139 auto newAlloca = rewriter.creatememref::AllocaOp(
140 loc, cast(flatMemref.getType()),
141 oper.getAlignmentAttr());
142 castAllocResult(oper, newAlloca, loc, rewriter);
143 })
// memref/vector loads and stores: rebuild on the flat memref and copy every
// attribute of the original op onto the replacement.
144 .template Casememref::LoadOp([&](auto op) {
145 auto newLoad = rewriter.creatememref::LoadOp(
146 loc, op->getResultTypes(), flatMemref, ValueRange{offset});
147 newLoad->setAttrs(op->getAttrs());
148 rewriter.replaceOp(op, newLoad.getResult());
149 })
// Operand 0 of the store is the value being stored.
150 .template Casememref::StoreOp([&](auto op) {
151 auto newStore = rewriter.creatememref::StoreOp(
152 loc, op->getOperands().front(), flatMemref, ValueRange{offset});
153 newStore->setAttrs(op->getAttrs());
154 rewriter.replaceOp(op, newStore);
155 })
156 .template Casevector::LoadOp([&](auto op) {
157 auto newLoad = rewriter.createvector::LoadOp(
158 loc, op->getResultTypes(), flatMemref, ValueRange{offset});
159 newLoad->setAttrs(op->getAttrs());
160 rewriter.replaceOp(op, newLoad.getResult());
161 })
162 .template Casevector::StoreOp([&](auto op) {
163 auto newStore = rewriter.createvector::StoreOp(
164 loc, op->getOperands().front(), flatMemref, ValueRange{offset});
165 newStore->setAttrs(op->getAttrs());
166 rewriter.replaceOp(op, newStore);
167 })
// Masked variants also forward the mask (and pass-through / stored value).
168 .template Casevector::MaskedLoadOp([&](auto op) {
169 auto newMaskedLoad = rewriter.createvector::MaskedLoadOp(
170 loc, op.getType(), flatMemref, ValueRange{offset}, op.getMask(),
171 op.getPassThru());
172 newMaskedLoad->setAttrs(op->getAttrs());
173 rewriter.replaceOp(op, newMaskedLoad.getResult());
174 })
175 .template Casevector::MaskedStoreOp([&](auto op) {
176 auto newMaskedStore = rewriter.createvector::MaskedStoreOp(
177 loc, flatMemref, ValueRange{offset}, op.getMask(),
178 op.getValueToStore());
179 newMaskedStore->setAttrs(op->getAttrs());
180 rewriter.replaceOp(op, newMaskedStore);
181 })
// Transfer read/write: only the padding (read) or the vector (write) is
// forwarded.
// NOTE(review): unlike every case above, the original op's attributes are NOT
// copied. canBeFlattened requires all-true in_bounds, but the rebuilt op does
// not carry that information over. Any other operands of the original (e.g.
// an optional mask) are also not forwarded — confirm this is intended.
182 .template Casevector::TransferReadOp([&](auto op) {
183 auto newTransferRead = rewriter.createvector::TransferReadOp(
184 loc, op.getType(), flatMemref, ValueRange{offset}, op.getPadding());
185 rewriter.replaceOp(op, newTransferRead.getResult());
186 })
187 .template Casevector::TransferWriteOp([&](auto op) {
188 auto newTransferWrite = rewriter.createvector::TransferWriteOp(
189 loc, op.getVector(), flatMemref, ValueRange{offset});
190 rewriter.replaceOp(op, newTransferWrite);
191 })
// Unhandled op: emits an error, but replaceOp returns void and the calling
// pattern still returns success().
192 .Default([&](auto op) {
193 op->emitOpError("unimplemented: do not know how to replace op.");
194 });
195 }
196
/// getIndices(op): the index operands used to address the memref.
/// Alloc/Alloca have no indices. The branch body for them (elided line 201)
/// presumably returns an empty range — confirm. All other supported ops
/// expose getIndices().
197 template
199 if constexpr (std::is_same_v<T, memref::AllocaOp> ||
200 std::is_same_v<T, memref::AllocOp>) {
202 } else {
203 return op.getIndices();
204 }
205 }
206
/// canBeFlattened(op, rewriter): precondition check run before rewriting.
/// For vector transfer read/write ops:
///  - the permutation map must be identity or minor identity;
///  - every in_bounds entry must be true.
/// Any failure is reported via notifyMatchFailure. All other op kinds are
/// accepted unconditionally.
207 template
208 static LogicalResult canBeFlattened(T op, PatternRewriter &rewriter) {
210 .template Case<vector::TransferReadOp, vector::TransferWriteOp>(
211 [&](auto oper) {
212
213
214
// NOTE(review): minor-identity maps are also accepted here, so the
// diagnostic text "only identity" understates what passes.
215 auto permutationMap = oper.getPermutationMap();
216 if (!permutationMap.isIdentity() &&
217 !permutationMap.isMinorIdentity()) {
218 return rewriter.notifyMatchFailure(
219 oper, "only identity permutation map is supported");
220 }
221 mlir::ArrayAttr inbounds = oper.getInBounds();
222 if (llvm::any_of(inbounds, [](Attribute attr) {
223 return !cast(attr).getValue();
224 })) {
226 "only inbounds are supported");
227 }
228 return success();
229 })
230 .Default([&](auto op) { return success(); });
231 }
232
/// MemRefRewritePattern<T>: one rewrite pattern per supported op type.
/// The struct header (original lines 234-235) is elided; presumably it
/// derives from OpRewritePattern<T> — confirm.
233 template
236 LogicalResult matchAndRewrite(T op,
// Bail out (with the recorded reason) if the op fails the preconditions.
238 LogicalResult canFlatten = canBeFlattened(op, rewriter);
239 if (failed(canFlatten)) {
240 return canFlatten;
241 }
242
// The guard on elided line 244 presumably checks
// needFlattening(memref) / checkLayout(memref) — confirm.
243 Value memref = getTargetMemref(op);
245 return failure();
// Elided line 246 presumably binds the (flatMemref, offset) pair returned
// by getFlattenMemrefAndOffset.
247 rewriter, op->getLoc(), memref, getIndices(op));
// Reports success even if replaceOp falls into its Default branch.
248 replaceOp(op, rewriter, flatMemref, offset);
249 return success();
250 }
251 };
252
/// Pass wrapper around the flatten-memrefs rewrite patterns.
/// NOTE(review): the base-class template argument was stripped by the scrape
/// (presumably the CRTP `<FlattenMemrefsPass>`). `®istry` below is a
/// mis-encoded `&registry`, produced when the HTML entity `&reg` was rendered
/// as ®.
253 struct FlattenMemrefsPass
254 : public mlir::memref::impl::FlattenMemrefsPassBase {
255 using Base::Base;
256
// The rewrite creates arith, memref and vector ops. Affine is presumably
// needed by the linearization helper — confirm.
257 void getDependentDialects(DialectRegistry ®istry) const override {
258 registry.insert<affine::AffineDialect, arith::ArithDialect,
259 memref::MemRefDialect, vector::VectorDialect>();
260 }
261
// Most of the body is elided. The visible line signals pass failure,
// presumably when greedy pattern application fails.
262 void runOnOperation() override {
264
266
268 return signalPassFailure();
269 }
270 };
271
272 }
273
// populateFlattenMemrefsPatterns(RewritePatternSet &patterns): registers one
// MemRefRewritePattern per supported op. The list matches the op types
// handled by getTargetMemref and replaceOp. The signature (line 274) and the
// constructor argument (line 285) are elided from this listing.
275 patterns.insert<MemRefRewritePatternmemref::LoadOp,
276 MemRefRewritePatternmemref::StoreOp,
277 MemRefRewritePatternmemref::AllocOp,
278 MemRefRewritePatternmemref::AllocaOp,
279 MemRefRewritePatternvector::LoadOp,
280 MemRefRewritePatternvector::StoreOp,
281 MemRefRewritePatternvector::TransferReadOp,
282 MemRefRewritePatternvector::TransferWriteOp,
283 MemRefRewritePatternvector::MaskedLoadOp,
284 MemRefRewritePatternvector::MaskedStoreOp>(
286 }
static std::pair< Value, Value > getFlattenMemrefAndOffset(OpBuilder &rewriter, Location loc, Value source, ValueRange indices)
Returns a collapsed memref and the linearized index to access the element at the specified indices.
static bool checkLayout(Value val)
static bool needFlattening(Value val)
static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc, OpFoldResult in)
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})
void populateFlattenMemrefsPatterns(RewritePatternSet &patterns)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
For a memref with offset, sizes and strides, returns the offset, size, and potentially the size padde...
OpFoldResult linearizedOffset