MLIR: lib/Conversion/TosaToSCF/TosaToSCF.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

20

21 using namespace mlir;

22 using namespace tosa;

23

28

29 Block *headBlock = &dstRegion.front();

30 for (auto it : llvm::zip(headBlock->getArguments(), operands))

32

33 auto yield = cast(headBlock->getTerminator());

35 rewriter.createscf::YieldOp(yield.getLoc(), yield.getInputs());

37

39 }

40

45

46 Block *headBlock = &dstRegion.front();

47

48 auto yield = cast(headBlock->getTerminator());

50 if (isCond) {

51 auto condition =

52 rewriter.createtensor::ExtractOp(yield.getLoc(), yield.getOperand(0));

53 rewriter.createscf::ConditionOp(yield.getLoc(), condition,

55 } else {

57 rewriter.createscf::YieldOp(yield.getLoc(), yield.getInputs());

58 }

60 }

61

62 namespace {

63

65 public:

67

68 LogicalResult matchAndRewrite(tosa::IfOp op,

70 auto condition =

71 rewriter.createtensor::ExtractOp(op.getLoc(), op.getCondition());

72 auto newIf = rewriter.createscf::IfOp(op.getLoc(), op.getResultTypes(),

73 condition, true);

74

75 inlineIfCase(op.getThenGraph(), newIf.getThenRegion(), op.getInputList(),

76 rewriter);

77 inlineIfCase(op.getElseGraph(), newIf.getElseRegion(), op.getInputList(),

78 rewriter);

79

80 rewriter.replaceOp(op, newIf.getResults());

81 return success();

82 }

83 };

84

85 class ScatterOpConverter : public OpRewritePatterntosa::ScatterOp {

87 int64_t dim) {

88 return builder.createOrFoldtensor::DimOp(loc, tensor, dim);

89 }

90

92 int64_t value) {

93 return builder.createarith::ConstantIndexOp(loc, value);

94 }

95

96 public:

98

99 LogicalResult matchAndRewrite(tosa::ScatterOp scatter,

101 auto valuesIn = scatter.getValuesIn();

102 auto indices = scatter.getIndices();

103 auto input = scatter.getInput();

104 auto loc = scatter.getLoc();

105

106

107 auto dimN = createTensorDim(rewriter, loc, input, 0);

108 auto dimW = createTensorDim(rewriter, loc, input, 1);

109 auto dimC = createTensorDim(rewriter, loc, input, 2);

110

111 auto zero = createIndexConst(rewriter, loc, 0);

112 auto one = createIndexConst(rewriter, loc, 1);

113

114

118

121 auto n = ivs[0];

122

123

124 auto index = builder.createtensor::ExtractOp(loc, indices, ivs);

125 auto castIndex = builder.createarith::IndexCastOp(

127

128

129 auto inputOffset = llvm::to_vector(ivs);

130 inputOffset.push_back(zero);

131

134

135 auto slice = builder.createtensor::ExtractSliceOp(

136 loc, input, inputOffset, sizes, strides);

137

138

140 auto updated = builder.createtensor::InsertSliceOp(

141 loc, slice, args[0], outputOffset, sizes, strides);

142

143 return {updated};

144 };

145

148 rewriter.replaceOp(scatter, loops.results);

149

150 return success();

151 }

152 };

153

154 class WhileOpConverter : public OpRewritePatterntosa::WhileOp {

155 public:

157

158 LogicalResult matchAndRewrite(tosa::WhileOp op,

160 auto newWhile = rewriter.createscf::WhileOp(

161 op.getLoc(), op.getResultTypes(), op.getInputList());

162 rewriter.createBlock(&newWhile.getBefore());

163 rewriter.createBlock(&newWhile.getAfter());

164

165 inlineWhileCase(op.getCondGraph(), newWhile.getBefore(), rewriter, true);

166 inlineWhileCase(op.getBodyGraph(), newWhile.getAfter(), rewriter, false);

167

168 rewriter.replaceOp(op, newWhile.getResults());

169

170 return success();

171 }

172 };

173

174 }

175

178 patterns->add<IfOpConverter, ScatterOpConverter, WhileOpConverter>(

180 }

static void inlineIfCase(Region &srcRegion, Region &dstRegion, OperandRange operands, PatternRewriter &rewriter)

static void inlineWhileCase(Region &srcRegion, Region &dstRegion, PatternRewriter &rewriter, bool isCond)

Block represents an ordered list of Operations.

unsigned getNumArguments()

Operation * getTerminator()

Get the terminator operation of this block.

void eraseArguments(unsigned start, unsigned num)

Erases 'num' arguments from the index 'start'.

BlockArgListType getArguments()

void replaceAllUsesWith(ValueT &&newValue)

Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to use the new value instead.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.

This class helps build Operations.

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)

Clone the blocks that belong to "region" before the given position in another region "parent".

void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)

Create an operation of specific op type at the current insertion point, and immediately try to fold it; the results of the (possibly folded) operation are written into 'results'.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

This class implements the operand iterators for the Operation class.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

static std::unique_ptr< T > create(Args &&...args)

This method provides a convenient interface for creating and initializing derived rewrite patterns of the given type `T`.

virtual void eraseBlock(Block *block)

This method erases all operations in a block.

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)

Creates a perfect nest of "for" loops, i.e. all loops but the innermost contain only another loop and a terminator.

SmallVector< Value > ValueVector

An owning vector of values, handy to return from functions.

void populateTosaToSCFConversionPatterns(RewritePatternSet *patterns)

Include the generated interface declarations.

const FrozenRewritePatternSet & patterns

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class, as opposed to a raw Operation.