MLIR: lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
10
16#include "llvm/ADT/TypeSwitch.h"
17
19#define GEN_PASS_DEF_AMDGPUFOLDMEMREFOPSPASS
20#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
21
30
34 Value &memrefBase, StringRef role) {
36 if (!defOp) {
37 return failure();
38 }
40 .Casememref::SubViewOp([&](memref::SubViewOp subviewOp) {
42 rewriter, loc, subviewOp.getMixedOffsets(),
43 subviewOp.getMixedStrides(), subviewOp.getDroppedDims(), indices,
44 resolvedIndices);
45 memrefBase = subviewOp.getSource();
47 })
48 .Casememref::ExpandShapeOp([&](memref::ExpandShapeOp expandShapeOp) {
50 loc, rewriter, expandShapeOp, indices, resolvedIndices,
51 false))) {
52 return failure();
53 }
54 memrefBase = expandShapeOp.getViewSource();
56 })
57 .Casememref::CollapseShapeOp(
58 [&](memref::CollapseShapeOp collapseShapeOp) {
60 loc, rewriter, collapseShapeOp, indices,
61 resolvedIndices))) {
62 return failure();
63 }
64 memrefBase = collapseShapeOp.getViewSource();
66 })
69 op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or "
70 "CollapseShapeOp")
71 .str());
72 });
73}
74
80
82 Value memrefSource, memrefDest;
83
84 auto foldSrcResult =
85 foldMemrefViewOp(rewriter, loc, op.getSrc(), op.getSrcIndices(),
86 sourceIndices, memrefSource, "source");
87
88 if (failed(foldSrcResult)) {
89 memrefSource = op.getSrc();
90 sourceIndices = op.getSrcIndices();
91 }
92
93 auto foldDstResult =
94 foldMemrefViewOp(rewriter, loc, op.getDst(), op.getDstIndices(),
95 destIndices, memrefDest, "destination");
96
97 if (failed(foldDstResult)) {
98 memrefDest = op.getDst();
99 destIndices = op.getDstIndices();
100 }
101
102 rewriter.replaceOpWithNewOp(op, memrefSource, sourceIndices,
103 memrefDest, destIndices,
104 op.getTransferType());
105
107 }
108};
109
114}
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason the failure occurred.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replace cannot fail). The result values of the two ops must match. The original op is erased.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > mixedSourceOffsets, ArrayRef< OpFoldResult > mixedSourceStrides, const llvm::SmallBitVector &rankReducedDims, ArrayRef< OpFoldResult > consumerIndices, SmallVectorImpl< Value > &resolvedIndices)
Given the 'consumerIndices' of a load/store operation operating on an op with offsets and strides, return the combined indices.
void populateAmdgpuFoldMemRefOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Definition FoldMemRefsOps.cpp:110
static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc, Value view, mlir::OperandRange indices, SmallVectorImpl< Value > &resolvedIndices, Value &memrefBase, StringRef role)
Definition FoldMemRefsOps.cpp:31
LogicalResult resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, memref::CollapseShapeOp collapseShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices)
Given the 'indices' of a load/store operation where the memref is a result of a collapse_shape op, returns the indices with respect to the source memref of the collapse_shape op.
LogicalResult resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, memref::ExpandShapeOp expandShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)
Given the 'indices' of a load/store operation where the memref is a result of an expand_shape op, returns the indices with respect to the source memref of the expand_shape op.
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern match and a list of generated ops.
Definition FoldMemRefsOps.cpp:23
void runOnOperation() override
Definition FoldMemRefsOps.cpp:24
Definition FoldMemRefsOps.cpp:75
LogicalResult matchAndRewrite(GatherToLDSOp op, PatternRewriter &rewriter) const override
Definition FoldMemRefsOps.cpp:77
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern match and a list of generated ops.