MLIR: lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

15

20 #include "llvm/ADT/SmallVector.h"

21 #include "llvm/ADT/TypeSwitch.h"

22

23 namespace mlir {

24 #define GEN_PASS_DEF_LINALGNAMEDOPCONVERSIONPASS

25 #include "mlir/Dialect/Linalg/Passes.h.inc"

26 }

27

28 using namespace mlir;

30

32 return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));

33 }

34

35 static LogicalResult

40 auto linalgOp = dyn_cast(operation);

41

42 if (!linalgOp || !linalgOp.hasPureTensorSemantics())

43 return failure();

44

45 auto result = operation->getResult(0);

46

47 auto kernelTy = dyn_cast(kernel.getType());

48 auto initTy = dyn_cast(init.getType());

49 auto resultTy = dyn_cast(result.getType());

50 if (!kernelTy || !initTy || !resultTy)

51 return failure();

52

53 if (kernelTy.getDimSize(3) != 1)

54 return failure();

55

56

60 {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},

61 kernelTy.getElementType());

62 auto collapsedKernel = rewriter.createtensor::CollapseShapeOp(

63 loc, newKernelTy, kernel, collapsedKernelDims);

64

65

69 auto newInitTy =

71 initTy.getDimSize(2), initTy.getDimSize(3)},

72 initTy.getElementType());

73 auto collapsedInit = rewriter.createtensor::CollapseShapeOp(

74 loc, newInitTy, init, collapsedInitDims);

75

79 .Case([&](auto op) {

81 return rewriter.create(

82 loc, newInitTy, ValueRange{input, collapsedKernel},

83 ValueRange{collapsedInit}, stride, dilation);

84 })

85 .Case([&](auto op) {

87 return rewriter.create(

88 loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},

89 ValueRange{collapsedInit}, stride, dilation);

90 })

91 .Default([](Operation *op) { return nullptr; });

92 if (!newConv)

93 return failure();

94 for (auto attr : preservedAttrs)

95 newConv->setAttr(attr.getName(), attr.getValue());

96

97

99 operation, resultTy, newConv->getResult(0), collapsedInitDims);

100 return success();

101 }

102

103 namespace {

104 struct SimplifyDepthwiseConvOp

107

108 LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,

110 Operation *operation = op.getOperation();

111 Value input = op.getDpsInputOperand(0)->get();

112 Value kernel = op.getDpsInputOperand(1)->get();

113 Value init = op.getDpsInitOperand(0)->get();

114

115 auto stride = op.getStrides();

116 auto dilation = op.getDilations();

117

119 nullptr, init, stride, dilation,

120 rewriter);

121 }

122 };

123

124 struct SimplifyDepthwiseConvQOp

127

128 LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,

130 Operation *operation = op.getOperation();

131 Value input = op.getDpsInputOperand(0)->get();

132 Value kernel = op.getDpsInputOperand(1)->get();

133 Value iZp = op.getDpsInputOperand(2)->get();

134 Value kZp = op.getDpsInputOperand(3)->get();

135 Value init = op.getDpsInitOperand(0)->get();

136

137 auto stride = op.getStrides();

138 auto dilation = op.getDilations();

139

141 init, stride, dilation, rewriter);

142 }

143 };

144

145 struct LinalgNamedOpConversionPass

146 : public impl::LinalgNamedOpConversionPassBase<

147 LinalgNamedOpConversionPass> {

148 using impl::LinalgNamedOpConversionPassBase<

149 LinalgNamedOpConversionPass>::LinalgNamedOpConversionPassBase;

150

151 void runOnOperation() override {

156 return signalPassFailure();

157 }

158 };

159 }

160

163 patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(

165 }

static llvm::SmallVector< int64_t > getIndicesVector(int start, int end)

static LogicalResult matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel, Value iZp, Value kZp, Value init, Attribute stride, Attribute dilation, PatternRewriter &rewriter)

Attributes are known-constant values of operations.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

Operation is the basic unit of execution within MLIR.

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

MLIRContext * getContext()

Return the context this operation is associated with.

Location getLoc()

The source location the operation was defined or derived from.

void setAttr(StringAttr name, Attribute value)

If an attribute with the specified name exists, change it to the new value.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns)

Patterns to convert from one named op to another.

SmallVector< NamedAttribute > getPrunedAttributeList(OpTy op)

Returns an attribute list that excludes pre-defined attributes.

Include the generated interface declarations.

LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)

Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...

const FrozenRewritePatternSet & patterns

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed; this helps unify some of the attribute constructio...

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...