MLIR: lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

10

12

13 using namespace mlir;

15

16

19 return false;

20

23 if (!genericOp.payloadUsesValueFromOperand(outputOpOperand))

24 return true;

25

26

27

28

29

30

31

32

34 genericOp.getRegionOutputArgs()[result.getResultNumber()];

36 return false;

38

39

41 return false;

42

43

44 auto yieldOp = dyn_castlinalg::YieldOp(argUserOp);

45 if (!yieldOp)

46 return false;

47

48

49 if (yieldOp.getOperand(result.getResultNumber()) != outputArg)

50 return false;

51

52 return true;

53 }

54

55

56

57

58

59

60

61

62

67 llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;

68 llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;

69 for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {

70 OpOperand *inputOpOperand = en.value();

71

72

73 if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {

74

75

76 droppedOpOperands.push_back(inputOpOperand);

77 if (genericOp.canOpOperandsBeDropped(droppedOpOperands))

78 continue;

79 droppedOpOperands.pop_back();

80 }

81

82

83 AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);

84 auto it =

85 dedupedInputs.find(std::make_pair(inputOpOperand->get(), indexingMap));

86 if (it != dedupedInputs.end()) {

87 origToNewPos[en.index()] = it->second;

88 droppedOpOperands.push_back(inputOpOperand);

89 continue;

90 }

91

92

93 origToNewPos[en.index()] = newInputOperands.size();

94 dedupedInputs[{inputOpOperand->get(), indexingMap}] =

95 newInputOperands.size();

96 newInputOperands.push_back(inputOpOperand->get());

97 newIndexingMaps.push_back(indexingMap);

98 }

99 return origToNewPos;

100 }

101

102

103

104

105

110 llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;

111 llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>

112 dedupedOutpts;

113

114

115 if (!genericOp.hasPureTensorSemantics() || !removeOutputs) {

116 for (const auto &en : llvm::enumerate(genericOp.getDpsInitsMutable())) {

117 origToNewPos[en.index()] = newOutputOperands.size();

118 newOutputOperands.push_back(en.value().get());

119 newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(&en.value()));

120 }

121 return origToNewPos;

122 }

123

124

125

126

127

128 auto yieldOp = cast(genericOp.getBody()->getTerminator());

129 for (const auto &outputOpOperand :

131 OpResult result = genericOp.getTiedOpResult(&outputOpOperand.value());

133 genericOp.getMatchingIndexingMap(&outputOpOperand.value());

134 auto key = std::make_tuple(outputOpOperand.value().get(), indexingMap,

135 yieldOp->getOperand(outputOpOperand.index()));

137

138

139

140

141 droppedOpOperands.push_back(&outputOpOperand.value());

142 if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {

143 continue;

144 }

145 droppedOpOperands.pop_back();

146 }

147

148 if (!genericOp.payloadUsesValueFromOperand(&outputOpOperand.value())) {

149

150

151

152

153

154 auto it = dedupedOutpts.find(key);

155 if (it != dedupedOutpts.end()) {

156 origToNewPos[outputOpOperand.index()] = it->second;

157 droppedOpOperands.push_back(&outputOpOperand.value());

158 continue;

159 }

160 }

161

162 origToNewPos[outputOpOperand.index()] = newOutputOperands.size();

163 dedupedOutpts[key] = newOutputOperands.size();

164 newOutputOperands.push_back(outputOpOperand.value().get());

165 newIndexingMaps.push_back(

166 genericOp.getMatchingIndexingMap(&outputOpOperand.value()));

167 }

168 return origToNewPos;

169 }

170

171

173 GenericOp genericOp, GenericOp newOp,

174 const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,

175 const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,

177

178 Block *newOpBlock = &newOp.getRegion().front();

179 assert(newOpBlock->empty() && "expected new op to have an empty payload");

180 Block *origOpBlock = &genericOp.getRegion().front();

182

183

184

185 auto updateReplacements =

188 const llvm::SmallDenseMap<unsigned, unsigned> &map) {

189 for (const auto &origOperand : llvm::enumerate(origOperands)) {

190 auto it = map.find(origOperand.index());

191 if (it == map.end())

192 continue;

193 OpOperand *newOperand = newOperands[it->second];

194 replacements[origOperand.value()->getOperandNumber()] =

196 }

197 };

198

201 updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);

202

204 genericOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));

206 newOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));

207 updateReplacements(origOutputOperands, newOutputOperands,

208 origOutsToNewOutsPos);

209

210

211 if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {

213 YieldOp origYieldOp = cast(origOpBlock->getTerminator());

215

217 for (const auto &yieldOpOperands :

219 auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());

220 if (it == origOutsToNewOutsPos.end())

221 continue;

222 newYieldVals[it->second] = yieldOpOperands.value();

223 }

225 }

226

227 rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);

228 }

229

230 FailureOrlinalg::GenericOp

232 RewriterBase &rewriter, linalg::GenericOp genericOp, bool removeOutputs) {

233

234

236

237

240

241

242 llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =

244 newIndexingMaps);

245

246

247 llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =

249 newIndexingMaps, removeOutputs);

250

251

252 if (newInputOperands.size() + newOutputOperands.size() ==

253 genericOp->getNumOperands())

254 return genericOp;

255

256

257 Location loc = genericOp.getLoc();

259 for (Value v : newOutputOperands)

260 if (isa(v.getType()))

261 newResultTypes.push_back(v.getType());

262 auto newOp = rewriter.create(

263 loc, newResultTypes, newInputOperands, newOutputOperands,

265 genericOp.getIteratorTypes(), genericOp.getDocAttr(),

266 genericOp.getLibraryCallAttr(),

268 return;

269 });

270

273 if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))

274 newOp->setAttr(kv.getName(), kv.getValue());

275

276

277 populateOpPayload(genericOp, newOp, origInsToNewInsPos, origOutsToNewOutsPos,

278 rewriter);

279

280

281 SmallVector replacementsVals(genericOp->getNumResults(), nullptr);

282 for (const auto &result : llvm::enumerate(genericOp.getResults())) {

283 auto it = origOutsToNewOutsPos.find(result.index());

284 if (it == origOutsToNewOutsPos.end())

285 continue;

286 replacementsVals[result.index()] = newOp.getResult(it->second);

287 }

288 rewriter.replaceOp(genericOp, replacementsVals);

289 return newOp;

290 }

291

292 namespace {

293

294 struct DeduplicateAndRemoveDeadOperandsAndResults

296 DeduplicateAndRemoveDeadOperandsAndResults(MLIRContext *ctx,

297 bool removeOutputs)

298 : OpRewritePattern(ctx), removeOutputs(removeOutputs) {}

299

300 LogicalResult matchAndRewrite(GenericOp genericOp,

303 rewriter, genericOp, removeOutputs);

304 if (failed(newOp) || newOp.value() == genericOp) {

306 genericOp, "failed to dedup operands/remove dead results");

307 }

308 return success();

309 }

310

311 private:

312

313 bool removeOutputs;

314 };

315

316

317

318

319

320

321

322

323 struct RemoveUnusedCycleInGenericOp : public OpRewritePattern {

325

326 LogicalResult matchAndRewrite(GenericOp genericOp,

328

329

330 if (!genericOp.hasPureTensorSemantics())

331 return failure();

332

333 bool hasRemovedCycles = false;

334

335 for (const auto &outputOpOperand :

337

338

339 Value result = genericOp.getResult(outputOpOperand.index());

341 continue;

342

343

345 genericOp.getRegionOutputArgs()[outputOpOperand.index()];

347 continue;

348

349

352 continue;

353

354

356 if (!isalinalg::YieldOp(cycleUserOp))

357 continue;

358

359

360 if (cycleUserOp->getOperand(outputOpOperand.index()) !=

362 continue;

363

364

365

366 rewriter.replaceOp(cycleOp, outputArg);

368 hasRemovedCycles = true;

369 }

370

371 if (hasRemovedCycles) {

372 return success();

373 }

374

375 return failure();

376 }

377 };

378

379

380

381

382

383

384

385

386

387

388

389

390 struct FoldDuplicateInputBbArgs : public OpRewritePattern {

392

393 LogicalResult matchAndRewrite(GenericOp genericOp,

395

397 for (int i = 0; i < genericOp.getNumDpsInputs(); ++i) {

398

399 if (genericOp.getBody()->getArgument(i).getUses().empty())

400 continue;

401

402 for (int j = genericOp->getNumOperands() - 1; j > i; --j) {

403 if (genericOp->getOperand(i) == genericOp->getOperand(j) &&

404 genericOp.getIndexingMapsArray()[i] ==

405 genericOp.getIndexingMapsArray()[j]) {

406 replacements[i] = j;

407 break;

408 }

409 }

410 }

411

412

413 if (replacements.empty())

414 return failure();

415

416

418 for (auto [before, after] : replacements) {

419 BlockArgument bbArg = genericOp.getBody()->getArgument(before);

420 BlockArgument replacement = genericOp.getBody()->getArgument(after);

422 }

423 });

424

425 return success();

426 }

427 };

428

429 }

430

433 patterns.insert(

434 patterns.getContext(), true);

435 patterns.insert(patterns.getContext());

436 }

437

440 patterns.insert(

441 patterns.getContext(), false);

442 patterns.insert(patterns.getContext());

443 }

static llvm::SmallDenseMap< unsigned, unsigned > deduplicateOutputOperands(GenericOp genericOp, SmallVector< OpOperand * > &droppedOpOperands, SmallVector< Value > &newOutputOperands, SmallVector< AffineMap > &newIndexingMaps, bool removeOutputs)

static llvm::SmallDenseMap< unsigned, unsigned > deduplicateInputOperands(GenericOp genericOp, SmallVector< OpOperand * > &droppedOpOperands, SmallVector< Value > &newInputOperands, SmallVector< AffineMap > &newIndexingMaps)

static void populateOpPayload(GenericOp genericOp, GenericOp newOp, const llvm::SmallDenseMap< unsigned, unsigned > &origInsToNewInsPos, const llvm::SmallDenseMap< unsigned, unsigned > &origOutsToNewOutsPos, RewriterBase &rewriter)

static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result)

Return true if the result of an operation genericOp is dead.

A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.

This class represents an argument of a Block.

Block represents an ordered list of Operations.

BlockArgument getArgument(unsigned i)

unsigned getNumArguments()

Operation * getTerminator()

Get the terminator operation of this block.

ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)

IRValueT get() const

Return the current value being used by this operand.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.

MLIRContext is the top-level object for a collection of MLIR operations.

NamedAttribute represents a combination of a name and an Attribute value.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

This class represents an operand of an operation.

unsigned getOperandNumber()

Return which operand this is in the OpOperand list of the Operation.

This is a value defined by a result of an operation.

unsigned getResultNumber() const

Returns the number of this result.

Operation is the basic unit of execution within MLIR.

bool use_empty()

Returns true if this operation has no uses.

Value getOperand(unsigned idx)

bool hasOneUse()

Returns true if this operation has exactly one use.

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

user_iterator user_begin()

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.

This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason the failure occurred.

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements).

void replaceAllUsesWith(Value from, Value to)

Find uses of from and replace them with to.

void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)

Inline the operations of block 'source' into the end of block 'dest'.

void modifyOpInPlace(Operation *root, CallableT &&callable)

This method is a utility wrapper around an in-place modification of an operation.

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

bool use_empty() const

Returns true if this value has no uses.

user_iterator user_begin() const

bool hasOneUse() const

Returns true if this value has exactly one use.

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

FailureOr< linalg::GenericOp > deduplicateOperandsAndRemoveDeadResults(RewriterBase &rewriter, linalg::GenericOp genericOp, bool removeOutputs)

Method to deduplicate operands and remove dead results of linalg.generic operations.

void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)

Pattern to remove dead operands and results of linalg.generic operations.

void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)

Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.

Include the generated interface declarations.

const FrozenRewritePatternSet & patterns

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class, as opposed to a raw Operation.

Eliminates the variable at the specified position using Fourier-Motzkin variable elimination.