MLIR: lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp Source File


#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

namespace {
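/// Pattern that breaks up a `linalg.generic` op whose payload contains more
/// than one statement (besides the `linalg.yield`) into two `linalg.generic`
/// ops:
///   - a "peeled" op whose payload is just the first statement, with extra
///     results added so that every value produced by that statement is
///     yielded, and
///   - a "residual" op that takes the peeled op's extra results as additional
///     `ins` operands and computes the remaining statements.
/// Because the residual op still matches while it has more than one payload
/// statement, repeated application (e.g. under a greedy rewrite driver)
/// decomposes the payload one statement at a time.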

struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override;

private:
  /// Helper method to create the peeled generic op (with an empty body) that
  /// computes the first statement of the original payload.
  GenericOp createPeeledGenericOp(GenericOp genericOp,
                                  PatternRewriter &rewriter) const;

  /// Helper method to create the residual generic op (with an empty body) that
  /// computes the remaining statements of the original payload.
  GenericOp createResidualGenericOp(GenericOp genericOp,
                                    GenericOp peeledGenericOp,
                                    PatternRewriter &rewriter) const;
};
} // namespace

/// Helper method to compute the range of a generic op.
static SmallVector<OpFoldResult> getGenericOpLoopRange(OpBuilder &b,
                                                       GenericOp op) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);
  Location loc = op.getLoc();
  auto allShapesSizes =
      cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc);
  AffineMap map = op.getShapesToLoopsMap();
  return affine::makeComposedFoldedMultiResultAffineApply(b, loc, map,
                                                          allShapesSizes);
}

/// Helper method to permute the list of `values` based on the `map`.
SmallVector<OpFoldResult> permuteValues(ArrayRef<OpFoldResult> values,
                                        AffineMap map) {
  assert(map.isPermutation());
  SmallVector<OpFoldResult> permutedValues(values.size());
  for (const auto &position :
       llvm::enumerate(llvm::map_range(map.getResults(), [](AffineExpr expr) {
         return cast<AffineDimExpr>(expr).getPosition();
       })))
    permutedValues[position.value()] = values[position.index()];
  return permutedValues;
}

/// Get zero value for an element type.
static Value getZero(OpBuilder &b, Location loc, Type elementType) {
  assert(elementType.isIntOrIndexOrFloat() &&
         "expected scalar type while computing zero value");
  if (isa<IntegerType>(elementType))
    return b.create<arith::ConstantIntOp>(loc, 0, elementType);
  if (elementType.isIndex())
    return b.create<arith::ConstantIndexOp>(loc, 0);
  // Only float types are left at this point.
  auto floatType = cast<FloatType>(elementType);
  return b.create<arith::ConstantFloatOp>(
      loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
}

GenericOp
DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
                                         PatternRewriter &rewriter) const {
  Block *body = genericOp.getBody();
  Operation *peeledScalarOperation = &(*body->begin());
  SmallVector<AffineMap> peeledGenericOpIndexingMaps =
      genericOp.getIndexingMapsArray();

  // Compute the loop ranges of the generic op; these give the shape of the
  // extra results added for the peeled scalar operation.
  Location loc = genericOp.getLoc();
  SmallVector<OpFoldResult> domain = getGenericOpLoopRange(rewriter, genericOp);
  SmallVector<Value> newInitValues;
  SmallVector<Type> newResultTypes;

  // Add as many new results as the number of results of the peeled scalar op.
  for (auto scalarOpResult : peeledScalarOperation->getResults()) {
    // If the result is yielded by the original op, reuse the corresponding
    // init operand, result type, and indexing map of the original op.
    std::optional<unsigned> resultNumber;
    for (auto *user : scalarOpResult.getUsers()) {
      if (auto yieldOp = dyn_cast<YieldOp>(user)) {
        // Find the first use of the scalar result in the yield op.
        for (OpOperand &yieldOperand : yieldOp->getOpOperands()) {
          if (yieldOperand.get() == scalarOpResult) {
            resultNumber = yieldOperand.getOperandNumber();
            break;
          }
        }
        assert(resultNumber && "unable to find use of a value in its user");
        break;
      }
    }
    if (resultNumber) {
      newInitValues.push_back(
          genericOp.getDpsInitOperand(*resultNumber)->get());
      OpResult result = cast<OpResult>(genericOp.getResult(*resultNumber));
      newResultTypes.push_back(result.getType());
      peeledGenericOpIndexingMaps.push_back(
          genericOp.getIndexingMapMatchingResult(result));
      continue;
    }

    // Otherwise back the new result with a `tensor.empty` of the loop-range
    // shape, accessed through the identity indexing map.
    AffineMap indexingMap = rewriter.getMultiDimIdentityMap(domain.size());
    Value emptyTensor =
        rewriter.create<tensor::EmptyOp>(loc, domain, scalarOpResult.getType());
    newInitValues.push_back(emptyTensor);
    newResultTypes.push_back(emptyTensor.getType());
    peeledGenericOpIndexingMaps.push_back(indexingMap);
  }

  // Create the peeled generic op with an empty body.
  SmallVector<Value> outsOperands = genericOp.getOutputs();
  outsOperands.append(newInitValues.begin(), newInitValues.end());
  SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
  resultTypes.append(newResultTypes.begin(), newResultTypes.end());
  auto indexingMapAttr =
      rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps);
  return rewriter.create<GenericOp>(
      loc, resultTypes, genericOp.getInputs(), outsOperands, indexingMapAttr,
      genericOp.getIteratorTypes(), nullptr, nullptr,
      [](OpBuilder, Location, ValueRange) {});
}

GenericOp
DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
                                           GenericOp peeledGenericOp,
                                           PatternRewriter &rewriter) const {
  // Append all extra results of the peeled generic op as `ins` operands of
  // the residual generic op.
  SmallVector<Value> residualGenericOpOperands = genericOp.getInputs();
  unsigned origNumResults = genericOp.getNumResults();
  unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults();
  SmallVector<Value> extraIns;
  for (auto resultNum :
       llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults))
    extraIns.push_back(peeledGenericOp->getResult(resultNum));
  residualGenericOpOperands.append(extraIns);

  // Add indexing maps for the newly added operands. These reuse the maps of
  // the corresponding results of the peeled generic op.
  auto indexingMaps = llvm::to_vector(
      llvm::map_range(genericOp.getDpsInputOperands(), [&](OpOperand *operand) {
        return genericOp.getMatchingIndexingMap(operand);
      }));
  for (auto resultNum :
       llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) {
    OpResult result = cast<OpResult>(peeledGenericOp.getResult(resultNum));
    indexingMaps.push_back(
        peeledGenericOp.getIndexingMapMatchingResult(result));
  }
  for (OpOperand &outOperand : genericOp.getDpsInitsMutable())
    indexingMaps.push_back(genericOp.getMatchingIndexingMap(&outOperand));

  auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps);
  return rewriter.create<GenericOp>(
      genericOp->getLoc(), genericOp->getResultTypes(),
      residualGenericOpOperands, genericOp.getOutputs(), indexingMapAttr,
      genericOp.getIteratorTypes(), nullptr, nullptr,
      [](OpBuilder, Location, ValueRange) {});
}

LogicalResult
DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
                                   PatternRewriter &rewriter) const {
  // For now, only handle ops where all iterator types are parallel.
  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "unhandled decomposition of operation "
                                       "with non-parallel iterator types");
  }

  // Only ops with tensor semantics are handled.
  if (!genericOp.hasPureTensorSemantics()) {
    return rewriter.notifyMatchFailure(
        genericOp, "only operations with tensor semantics are handled");
  }

  if (llvm::any_of(genericOp.getDpsInitsMutable(), [&](OpOperand &outOperand) {
        return !genericOp.getMatchingIndexingMap(&outOperand).isPermutation();
      })) {
    return rewriter.notifyMatchFailure(
        genericOp, "unhandled decomposition of generic op with out operand not "
                   "accessed using a permutation");
  }

  // The payload must have at least two statements besides the terminator;
  // otherwise there is nothing to decompose.
  Block *body = genericOp.getBody();
  if (body->getOperations().size() <= 2) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "operation has less than 3 statements");
  }

  // The results of the peeled statement must all be scalars (int, index or
  // float), since they are yielded from the peeled generic op.
  if (llvm::any_of(body->getOperations().begin()->getResultTypes(),
                   [](Type t) { return !t.isIntOrIndexOrFloat(); })) {
    return rewriter.notifyMatchFailure(
        genericOp, "expected return type to be only int, index or float");
  }

  GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
  GenericOp residualGenericOp =
      createResidualGenericOp(genericOp, peeledGenericOp, rewriter);

  // Move the first payload statement into the peeled generic op and the rest
  // of the payload into the residual generic op.
  Block *peeledGenericOpBody = peeledGenericOp.getBody();
  Block *residualGenericOpBody = residualGenericOp.getBody();
  assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() &&
         "expected split generic ops to have empty region");
  peeledGenericOpBody->getOperations().splice(
      peeledGenericOpBody->begin(), body->getOperations(), body->begin());
  residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(),
                                                body->getOperations());

  Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin());
  auto *yieldOp = residualGenericOpBody->getTerminator();
  {
    // Yield all results of the peeled scalar operation from the peeled
    // generic op, in addition to the values it originally yielded.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(peeledGenericOpBody);
    SmallVector<Value> yieldedVals;
    for (auto origYield : yieldOp->getOperands()) {
      if (origYield.getDefiningOp() == peeledScalarOperation) {
        yieldedVals.push_back(origYield);
      } else {
        // Yield a placeholder zero for results not produced by the peeled
        // statement. Create the constant outside the generic op so it does not
        // become another payload statement (which would retrigger the
        // pattern).
        OpBuilder::InsertionGuard innerGuard(rewriter);
        rewriter.setInsertionPoint(peeledGenericOp);
        yieldedVals.push_back(
            getZero(rewriter, genericOp.getLoc(), origYield.getType()));
      }
    }
    yieldedVals.append(llvm::to_vector(
        llvm::map_range(peeledScalarOperation->getResults(),
                        [](OpResult opr) -> Value { return opr; })));
    rewriter.create<YieldOp>(genericOp.getLoc(), yieldedVals);
  }

  // Replace uses of the original block arguments with the arguments of the
  // peeled/residual generic ops, depending on which body the use now lives in.
  unsigned origNumInputs = genericOp.getNumDpsInputs();
  for (const auto &inputBlockArg :
       llvm::enumerate(genericOp.getBody()->getArguments())) {
    Value residualOpReplacementArg =
        residualGenericOpBody->getArgument(inputBlockArg.index());
    rewriter.replaceUsesWithIf(
        inputBlockArg.value(), residualOpReplacementArg, [&](OpOperand &use) {
          return use.getOwner()->getBlock() == residualGenericOpBody;
        });

    Value peeledOpReplacementArg =
        peeledGenericOpBody->getArgument(inputBlockArg.index());
    rewriter.replaceUsesWithIf(
        inputBlockArg.value(), peeledOpReplacementArg, [&](OpOperand &use) {
          return use.getOwner()->getBlock() == peeledGenericOpBody;
        });
  }

  // Collect the replacement values for the original op: results produced by
  // the peeled scalar statement come from the peeled generic op, all others
  // from the residual generic op.
  SmallVector<Value> replacements;
  for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) {
    OpResult opr = dyn_cast<OpResult>(yieldValue.value());
    if (!opr || opr.getOwner() != peeledScalarOperation)
      replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
    else
      replacements.push_back(peeledGenericOp->getResult(yieldValue.index()));
  }

  // Within the residual body, replace uses of the peeled scalar operation's
  // results with the block arguments that carry those results into the
  // residual generic op.
  {
    SmallVector<Value> scalarReplacements;
    unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults();
    scalarReplacements.reserve(peeledScalarOpNumResults);
    for (auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults))
      scalarReplacements.push_back(
          residualGenericOpBody->getArgument(num + origNumInputs));
    bool allUsesReplaced = false;
    rewriter.replaceOpUsesWithinBlock(peeledScalarOperation, scalarReplacements,
                                      residualGenericOpBody, &allUsesReplaced);
    assert(!allUsesReplaced &&
           "peeled scalar operation is erased when it wasn't expected to be");
  }

  // Replace the original op with the results of the two new generic ops.
  rewriter.replaceOp(genericOp, replacements);
  return success();
}

void mlir::linalg::populateDecomposeLinalgOpsPattern(
    RewritePatternSet &patterns, bool removeDeadArgsAndResults) {
  patterns.insert<DecomposeLinalgOp>(patterns.getContext());
  // Add the patterns that clean up dead operands and results.
  if (removeDeadArgsAndResults)
    populateEraseUnusedOperandsAndResultsPatterns(patterns);
}

static Value getZero(OpBuilder &b, Location loc, Type elementType)

Get zero value for an element type.

SmallVector< OpFoldResult > permuteValues(ArrayRef< OpFoldResult > values, AffineMap map)

Helper method to permute the list of values based on the map.

static SmallVector< OpFoldResult > getGenericOpLoopRange(OpBuilder &b, GenericOp op)

Helper method to compute the range of a generic op.

Base type for affine expression.

A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.

ArrayRef< AffineExpr > getResults() const

bool isPermutation() const

Returns true if the AffineMap represents a symbol-less permutation map.

Block represents an ordered list of Operations.

BlockArgument getArgument(unsigned i)

Operation * getTerminator()

Get the terminator operation of this block.

OpListType & getOperations()

AffineMap getMultiDimIdentityMap(unsigned rank)

ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)

This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

void setInsertionPointToEnd(Block *block)

Sets the insertion point to the end of the specified block.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

This class represents an operand of an operation.

This is a value defined by a result of an operation.

Operation * getOwner() const

Returns the operation that owns this result.

Operation is the basic unit of execution within MLIR.

result_range getResults()

unsigned getNumResults()

Return the number of results held by this operation.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure,...

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements...

void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)

Find uses of from and replace them with to if the functor returns true.

void replaceOpUsesWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)

Find uses of from within block and replace them with to.
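
A hedged sketch of how these two block-scoped replacement helpers are used in the pattern above; the names `from`, `to`, `block`, `op`, and `newValues` are illustrative stand-ins, not identifiers from this file:

// Replace uses of `from` with `to`, but only for uses whose owner sits
// inside `block`.
rewriter.replaceUsesWithIf(from, to, [&](OpOperand &use) {
  return use.getOwner()->getBlock() == block;
});

// Same idea for all results of `op` at once, restricted to `block`.
bool allUsesReplaced = false;
rewriter.replaceOpUsesWithinBlock(op, newValues, block, &allUsesReplaced);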

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

bool isIntOrIndexOrFloat() const

Return true if this is an integer (of any signedness), index, or float type.

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
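
A sketch mirroring how getGenericOpLoopRange above uses this helper, assuming the function lives in the affine namespace as in recent MLIR versions (the names `b`, `loc`, and `op` come from that function):

// One OpFoldResult per loop: feed the flattened operand sizes through the
// op's shapes-to-loops map.
SmallVector<OpFoldResult> sizes =
    cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc);
SmallVector<OpFoldResult> loopRanges =
    affine::makeComposedFoldedMultiResultAffineApply(
        b, loc, op.getShapesToLoopsMap(), sizes);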

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)

Pattern to remove dead operands and results of linalg.generic operations.

void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns, bool removeDeadArgsAndResults=true)

Populate patterns for splitting a LinalgOp with multiple statements within its payload into multiple ...
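
A minimal sketch of how these patterns might be driven from a pass; the enclosing function and the use of the greedy rewrite driver are illustrative assumptions, not part of this file:

// Illustrative only: decompose all linalg.generic payloads in `funcOp`.
static void decomposeGenericOps(func::FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  linalg::populateDecomposeLinalgOpsPattern(patterns,
                                            /*removeDeadArgsAndResults=*/true);
  // The greedy driver reapplies DecomposeLinalgOp until every payload has a
  // single statement (plus the linalg.yield).
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}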

Include the generated interface declarations.

const FrozenRewritePatternSet & patterns

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...