MLIR: lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

10

16#include "llvm/ADT/TypeSwitch.h"

17

19#define GEN_PASS_DEF_AMDGPUFOLDMEMREFOPSPASS

20#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"

21

30

34 Value &memrefBase, StringRef role) {

36 if (!defOp) {

37 return failure();

38 }

40 .Casememref::SubViewOp([&](memref::SubViewOp subviewOp) {

42 rewriter, loc, subviewOp.getMixedOffsets(),

43 subviewOp.getMixedStrides(), subviewOp.getDroppedDims(), indices,

44 resolvedIndices);

45 memrefBase = subviewOp.getSource();

47 })

48 .Casememref::ExpandShapeOp([&](memref::ExpandShapeOp expandShapeOp) {

50 loc, rewriter, expandShapeOp, indices, resolvedIndices,

51 false))) {

52 return failure();

53 }

54 memrefBase = expandShapeOp.getViewSource();

56 })

57 .Casememref::CollapseShapeOp(

58 [&](memref::CollapseShapeOp collapseShapeOp) {

60 loc, rewriter, collapseShapeOp, indices,

61 resolvedIndices))) {

62 return failure();

63 }

64 memrefBase = collapseShapeOp.getViewSource();

66 })

69 op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or "

70 "CollapseShapeOp")

71 .str());

72 });

73}

74

80

82 Value memrefSource, memrefDest;

83

84 auto foldSrcResult =

85 foldMemrefViewOp(rewriter, loc, op.getSrc(), op.getSrcIndices(),

86 sourceIndices, memrefSource, "source");

87

88 if (failed(foldSrcResult)) {

89 memrefSource = op.getSrc();

90 sourceIndices = op.getSrcIndices();

91 }

92

93 auto foldDstResult =

94 foldMemrefViewOp(rewriter, loc, op.getDst(), op.getDstIndices(),

95 destIndices, memrefDest, "destination");

96

97 if (failed(foldDstResult)) {

98 memrefDest = op.getDst();

99 destIndices = op.getDstIndices();

100 }

101

102 rewriter.replaceOpWithNewOp(op, memrefSource, sourceIndices,

103 memrefDest, destIndices,

104 op.getTransferType());

105

107 }

108};

109

114}

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

This class implements the operand iterators for the Operation class.

Operation is the basic unit of execution within MLIR.

This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure,...

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

void resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > mixedSourceOffsets, ArrayRef< OpFoldResult > mixedSourceStrides, const llvm::SmallBitVector &rankReducedDims, ArrayRef< OpFoldResult > consumerIndices, SmallVectorImpl< Value > &resolvedIndices)

Given the 'consumerIndices' of a load/store operation operating on an op with offsets and strides,...

void populateAmdgpuFoldMemRefOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)

Definition FoldMemRefsOps.cpp:110

static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc, Value view, mlir::OperandRange indices, SmallVectorImpl< Value > &resolvedIndices, Value &memrefBase, StringRef role)

Definition FoldMemRefsOps.cpp:31

LogicalResult resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, memref::CollapseShapeOp collapseShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices)

Given the 'indices' of a load/store operation where the memref is a result of a collapse_shape op,...

LogicalResult resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, memref::ExpandShapeOp expandShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)

Given the 'indices' of a load/store operation where the memref is a result of a expand_shape op,...

const FrozenRewritePatternSet & patterns

void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)

A fast walk-based pattern rewrite driver.

OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})

Patterns must specify the root operation name they match against, and can also specify the benefit of...

Definition FoldMemRefsOps.cpp:23

void runOnOperation() override

Definition FoldMemRefsOps.cpp:24

Definition FoldMemRefsOps.cpp:75

LogicalResult matchAndRewrite(GatherToLDSOp op, PatternRewriter &rewriter) const override

Definition FoldMemRefsOps.cpp:77

OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})

Patterns must specify the root operation name they match against, and can also specify the benefit of...