MLIR: lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

26 #include "llvm/ADT/STLExtras.h"

27 #include "llvm/Support/Debug.h"

28

29 using namespace mlir;

31

32

33

37 OperandRange::iterator &elementIt,

39 if (dim == static_cast<int>(shape.size()) - 1) {

40 for (int i = 0; i < shape.back(); ++i) {

41 indices.back() = constants[i];

42 destination = rewriter.createtensor::InsertOp(loc, *elementIt,

43 destination, indices);

44 ++elementIt;

45 }

46 return destination;

47 }

48 for (int i = 0; i < shape[dim]; ++i) {

49 indices[dim] = constants[i];

50 destination = createInserts(rewriter, loc, dim + 1, destination, shape,

51 constants, elementIt, indices);

52 }

53 return destination;

54 }

55

56

57

61 auto tensorType = dyn_cast(tensorSource.getType());

62 assert(tensorType && "expected ranked tensor");

63 assert(isa(memrefDest.getType()) && "expected ranked memref");

64

65 switch (options.memcpyOp) {

68

69

70 auto materializeOp = b.createbufferization::MaterializeInDestinationOp(

71 loc, tensorSource, memrefDest);

72 materializeOp.setWritable(true);

73 } break;

75

76

77

78 Value toBuffer = b.createbufferization::ToBufferOp(

80 tensorSource, true);

81 b.creatememref::CopyOp(loc, toBuffer, memrefDest);

82 } break;

84

85

86

87 Value toBuffer = b.createbufferization::ToBufferOp(

89 tensorSource, true);

90 b.createlinalg::CopyOp(loc, toBuffer, memrefDest);

91 } break;

92 };

93 }

94

99 RankedTensorType resultType = padOp.getResultType();

100

101

102

103 Value yieldedValue =

104 casttensor::YieldOp(padOp.getBody()->getTerminator()).getValue();

106

107 bool outsideBbArg =

108 isa(yieldedValue) &&

109 cast(yieldedValue).getOwner()->getParentOp() !=

110 padOp.getOperation();

111

112 bool outsideOpResult =

113 isa(yieldedValue) &&

115 bool invariantYieldedValue = outsideBbArg || outsideOpResult;

117

120 Value fillValue =

121 arithDialect

124 ->getResult(0);

125 auto fillOp = rewriter.createlinalg::FillOp(loc, ValueRange(fillValue),

127 return fillOp;

128 }

129

130 if (invariantYieldedValue) {

131

132 auto fillOp = rewriter.createlinalg::FillOp(loc, ValueRange(yieldedValue),

134 return fillOp;

135 }

136

137

139 utils::IteratorType::parallel);

142 auto genericOp = rewriter.createlinalg::GenericOp(

143 loc, resultType, ValueRange(),

144 ValueRange{dest},

145 indexingMaps, iteratorTypes);

146 Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},

147 resultType.getElementType(), loc);

150 for (int64_t i = 0; i < resultType.getRank(); ++i)

151 bbArgReplacements.push_back(rewriter.createlinalg::IndexOp(loc, i));

152 rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements);

153

154

155 auto yieldOp = casttensor::YieldOp(body->getTerminator());

156 rewriter.replaceOpWithNewOplinalg::YieldOp(yieldOp, yieldOp.getValue());

157 return genericOp;

158 }

159

162 auto tensorType = cast(value.getType());

163 if (tensorType.hasStaticShape())

164 return {};

165

166

168 if (isa(value) &&

171 for (int64_t i = 0; i < tensorType.getRank(); ++i) {

172 if (tensorType.isDynamicDim(i))

173 dynSizes.push_back(cast(

174 reifiedShape[cast(value).getResultNumber()][i]));

175 }

176 return dynSizes;

177 }

178

179

181 for (int64_t i = 0; i < tensorType.getRank(); ++i) {

182 if (tensorType.isDynamicDim(i))

183 dynSizes.push_back(

185 b.createarith::ConstantIndexOp(value.getLoc(), i)));

186 }

187 return dynSizes;

188 }

189

195 auto tensorType = cast(value.getType());

196

197

198 auto memrefType =

200 tensorType, memorySpace));

202

206 alloc = rewriter.creatememref::AllocOp(loc, memrefType, dynamicSizes);

207 if (options.emitDealloc) {

208

210 rewriter.creatememref::DeallocOp(loc, alloc);

211 }

212 } else if (options.allocOp ==

214 alloc = rewriter.creatememref::AllocaOp(loc, memrefType, dynamicSizes);

215

216 }

217

218 return alloc;

219 }

220

224

225 assert(options.bufferizeDestinationOnly && "invalid options");

226

228 rewriter.setInsertionPoint(insertionPoint ? insertionPoint : padOp);

229 Location loc = padOp.getLoc();

230

231

235

236 if (!padOp.hasZeroLowPad() || !padOp.hasZeroHighPad()) {

237

241 }

242

243

248 Value subview = rewriter.creatememref::SubViewOp(

249 loc, alloc, padOp.getMixedLowPad(), sizes, strides);

251

252

253

254 Value toTensorOp = rewriter.createbufferization::ToTensorOp(

255 loc, alloc, true, true);

256 rewriter.replaceOp(padOp, toTensorOp);

257 return alloc;

258 }

259

262 vector::MaskOp maskOp, Attribute memorySpace, Operation *insertionPoint) {

263 assert(llvm::range_size(maskOp.getMaskBlock()->without_terminator()) == 1 &&

264 "expected single masked op");

266

267

270

271 Operation *yieldOp = maskOp.getMaskRegion().front().getTerminator();

272 assert(isavector::YieldOp(yieldOp) && "expected yield op terminator");

273

274

275

277 rewriter, options, maskOp.getMaskableOp(), memorySpace,

278 insertionPoint ? insertionPoint : maskOp);

279

280 if (options.bufferizeDestinationOnly)

281 return alloc;

282

283

285 if (failed(castbufferization::BufferizableOpInterface(yieldOp).bufferize(

286 rewriter, bufferizationOptions, bufferizationState)))

287 return nullptr;

288

289

290

291

293 maskOp.walk([&](bufferization::ToTensorOp toTensorOp) {

294 if (toTensorOp->getUses().empty())

295 toTensorOps.push_back(toTensorOp.getOperation());

296 });

297 for (Operation *op : toTensorOps)

299

300

302 for (Value result : maskOp.getResults())

303 if (isa(result.getType()))

304 for (OpOperand &use : result.getUses())

305 resultUses.push_back(&use);

307 if (failed(

308 castbufferization::BufferizableOpInterface(maskOp.getOperation())

309 .bufferize(rewriter, bufferizationOptions, bufferizationState)))

310 return nullptr;

311

312

313

314 for (OpOperand *resultUse : resultUses) {

315 auto toTensorOp =

316 resultUse->get().getDefiningOpbufferization::ToTensorOp();

317 assert(toTensorOp && "expected to_tensor op");

319 toTensorOp.setRestrict(true);

320 toTensorOp.setWritable(true);

321 });

322 }

323

324 return alloc;

325 }

326

329 bufferization::AllocTensorOp allocTensorOp, Attribute memorySpace,

331 Location loc = allocTensorOp.getLoc();

333 rewriter.setInsertionPoint(insertionPoint ? insertionPoint : allocTensorOp);

335

336

338 rewriter, loc, allocTensorOp.getResult(), options, memorySpace);

339

340

341

342 Value toTensorOp = rewriter.createbufferization::ToTensorOp(

343 loc, alloc, true, true);

344 rewriter.replaceOp(allocTensorOp, toTensorOp);

345 return alloc;

346 }

347

348

350 RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) {

351 Location loc = fromElementsOp.getLoc();

352 RankedTensorType tensorType =

353 cast(fromElementsOp.getType());

354 auto shape = tensorType.getShape();

355

356

357 auto emptyOp = rewriter.create(loc, tensorType, ValueRange());

358

359

360 if (shape.empty()) {

362 fromElementsOp, fromElementsOp.getElements().front(),

364 return res;

365 }

366

367

368 auto maxDim = *llvm::max_element(shape);

370 constants.reserve(maxDim);

371 for (int i = 0; i < maxDim; ++i)

373

374

375 auto elementIt = fromElementsOp.getElements().begin();

377 Value result = createInserts(rewriter, loc, 0, emptyOp.getResult(),

378 shape, constants, elementIt, indices);

379

380

381 rewriter.replaceOp(fromElementsOp, result);

383 }

384

385

386 FailureOr<Operation *>

388 tensor::GenerateOp generateOp) {

389

390 if (!generateOp.getBody().hasOneBlock())

391 return failure();

392

393 Location loc = generateOp.getLoc();

394 RankedTensorType tensorType = cast(generateOp.getType());

395

396

397 auto emptyOp =

398 rewriter.create(loc, tensorType, generateOp.getDynamicExtents());

399

400

402 utils::IteratorType::parallel);

405 auto genericOp = rewriter.createlinalg::GenericOp(

406 loc, tensorType, ValueRange(),

407 ValueRange{emptyOp.getResult()},

408 indexingMaps, iteratorTypes);

409 Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},

410 tensorType.getElementType(), loc);

413 for (int64_t i = 0; i < tensorType.getRank(); ++i)

414 bbArgReplacements.push_back(rewriter.createlinalg::IndexOp(loc, i));

415 rewriter.mergeBlocks(&generateOp.getBody().front(), body, bbArgReplacements);

416

417

418 auto yieldOp = casttensor::YieldOp(body->getTerminator());

419 rewriter.replaceOpWithNewOplinalg::YieldOp(yieldOp, yieldOp.getValue());

420

421

422 rewriter.replaceOp(generateOp, genericOp->getResult(0));

423 return genericOp.getOperation();

424 }

425

426

427 FailureOr<Operation *>

429 tensor::PadOp padOp) {

430

431 if (!padOp.getBodyRegion().hasOneBlock())

432 return failure();

433

434

435 Location loc = padOp.getLoc();

436 RankedTensorType resultType = padOp.getResultType();

440 padOp, "failed to reify tensor.pad op result shape");

442 for (int64_t i = 0; i < resultType.getRank(); ++i)

443 if (resultType.isDynamicDim(i))

444 dynamicSizes.push_back(cast(reifiedShape[0][i]));

445

446

447

448 if (padOp.getNofoldAttr() &&

449 llvm::all_of(padOp.getMixedLowPad(), isZeroInteger) &&

450 llvm::all_of(padOp.getMixedHighPad(), isZeroInteger)) {

451 using bufferization::AllocTensorOp;

452 Value allocated =

453 rewriter.create(loc, resultType, dynamicSizes);

455 padOp, padOp.getSource(), allocated);

456 return copyOp.getOperation();

457 }

458

459 Value empty = rewriter.create(loc, resultType, dynamicSizes);

460

463

464

469 auto insertSliceOp = rewriter.replaceOpWithNewOptensor::InsertSliceOp(

470 padOp, padOp.getSource(), fillOp->getResult(0),

471 padOp.getMixedLowPad(), sliceSizes, sliceStrides);

472 return insertSliceOp.getOperation();

473 }

474

478 using namespace bufferization;

479

480

481 if (auto padOp = dyn_casttensor::PadOp(op))

483 if (auto maskOp = dyn_castvector::MaskOp(op))

485 if (auto allocTensorOp = dyn_castbufferization::AllocTensorOp(op))

487

488

489 auto bufferizableOp = dyn_cast(op);

490 if (!bufferizableOp)

491 return nullptr;

492

493

494 BufferizationOptions bufferizationOptions;

495 AnalysisState analysisState(bufferizationOptions);

496 BufferizationState bufferizationState;

497

498 #ifndef NDEBUG

499 if (options.bufferizeDestinationOnly) {

500

501

503 if (op == nestedOp)

504 return;

505 if (llvm::any_of(nestedOp->getOperands(),

506 [](Value v) { return isa(v.getType()); }))

507 llvm_unreachable("ops with nested tensor ops are not supported yet");

508 if (llvm::any_of(nestedOp->getResults(),

509 [](Value v) { return isa(v.getType()); }))

510 llvm_unreachable("ops with nested tensor ops are not supported yet");

511 });

512 }

513 #endif

514

515

518 if (!isa(result.getType()))

519 continue;

520

521 if (!isa(result.getType()))

522 return nullptr;

523

524 if (bufferizableOp.bufferizesToAllocation(result))

525 return nullptr;

526 tensorResults.push_back(result);

527 }

528

529

530

532 auto addOutOfPlaceOperand = [&](OpOperand *operand) {

533 if (!llvm::is_contained(outOfPlaceOperands, operand))

534 outOfPlaceOperands.push_back(operand);

535 };

536 for (OpResult result : tensorResults) {

538 analysisState.getAliasingOpOperands(result);

539 for (const AliasingOpOperand &operand : aliasingOperands) {

540 addOutOfPlaceOperand(operand.opOperand);

541 for (OpOperand &resultUse : result.getUses())

542 resultUses.push_back(&resultUse);

543 }

544 }

546 if (!analysisState.bufferizesToMemoryWrite(operand))

547 continue;

548 if (!isa(operand.get().getType()))

549 continue;

550 addOutOfPlaceOperand(&operand);

551 }

552

553 if (outOfPlaceOperands.size() != 1)

554 return nullptr;

555

556

558 rewriter.setInsertionPoint(insertionPoint ? insertionPoint : op);

560 for (OpOperand *operand : outOfPlaceOperands) {

562 rewriter, op->getLoc(), operand->get(), options, memorySpace);

563 allocs.push_back(alloc);

564 if (!analysisState.findDefinitions(operand).empty()) {

565

566

568 }

570 auto toTensorOp = rewriter.create(op->getLoc(), alloc);

571 operand->set(toTensorOp);

572 if (options.bufferizeDestinationOnly) {

574 toTensorOp.setRestrict(true);

575 toTensorOp.setWritable(true);

576 });

577 }

578 });

579 }

580

581 if (options.bufferizeDestinationOnly)

582 return allocs.front();

583

584

586 if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions,

587 bufferizationState)))

588 return nullptr;

589

590

591

592 for (OpOperand *resultUse : resultUses) {

593 auto toTensorOp = resultUse->get().getDefiningOp();

594 assert(toTensorOp && "expected to_tensor op");

596 toTensorOp.setRestrict(true);

597 toTensorOp.setWritable(true);

598 });

599 }

600 return allocs.front();

601 }

602

603 namespace {

604

605 template

606 LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,

609 }

610

611 }

612

615 patterns.add(rewriteOpInDestinationPassingStyletensor::FromElementsOp);

616 patterns.add(rewriteOpInDestinationPassingStyletensor::GenerateOp);

617 patterns.add(rewriteOpInDestinationPassingStyletensor::PadOp);

618 }

static Operation * movePaddingToFillOrGenericOp(RewriterBase &rewriter, Location loc, PadOp padOp, Value dest)

static Value createAllocationForTensor(RewriterBase &rewriter, Location loc, Value value, const linalg::BufferizeToAllocationOptions &options, Attribute memorySpace={})

static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource, Value memrefDest, const linalg::BufferizeToAllocationOptions &options)

Create a memcpy from the given source tensor to the given destination memref.

static SmallVector< Value > reifyOrComputeDynamicSizes(OpBuilder &b, Value value)

static Value createInserts(RewriterBase &rewriter, Location loc, int dim, Value destination, ArrayRef< int64_t > shape, ArrayRef< Value > constants, OperandRange::iterator &elementIt, SmallVectorImpl< Value > &indices)

static llvm::ManagedStatic< PassManagerOptions > options

Base class for generic analysis states.

Attributes are known-constant values of operations.

Block represents an ordered list of Operations.

Operation * getTerminator()

Get the terminator operation of this block.

IntegerAttr getIndexAttr(int64_t value)

AffineMap getMultiDimIdentityMap(unsigned rank)

MLIRContext * getContext() const

Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the entire group.

virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)

Registered hook to materialize a single constant operation from a given attribute value with the desired resultant type.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.

Dialect * getLoadedDialect(StringRef name)

Get a registered IR dialect with the given namespace.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

void setInsertionPointToStart(Block *block)

Sets the insertion point to the start of the specified block.

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)

Add new block with 'argTypes' arguments and set the insertion point to the end of it.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

void setInsertionPointAfter(Operation *op)

Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.

Block * getInsertionBlock() const

Return the block the current insertion point belongs to.

This class represents an operand of an operation.

This is a value defined by a result of an operation.

Operation is the basic unit of execution within MLIR.

std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)

Walk the operation by calling the callback for each nested operation (including this one),...

Location getLoc()

The source location the operation was defined or derived from.

Operation * getParentOp()

Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.

MutableArrayRef< OpOperand > getOpOperands()

operand_range getOperands()

Returns an iterator on the underlying Value's.

result_range getResults()

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure,...

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements).

void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)

Inline the operations of block 'source' into the end of block 'dest'.

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

void modifyOpInPlace(Operation *root, CallableT &&callable)

This method is a utility wrapper around an in-place modification of an operation.

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

Type getType() const

Return the type of this value.

Location getLoc() const

Return the location of this value.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

Specialization of arith.constant op that returns an integer of index type.

BufferizationState provides information about the state of the IR during the bufferization process.

BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace=nullptr)

Return a MemRef type with a static identity layout (i.e., no layout map).

AliasList< AliasingOpOperand > AliasingOpOperandList

A list of possible aliasing OpOperands.

BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)

Return a MemRef type with fully dynamic layout.

Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)

Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....

FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)

Rewrite tensor.from_elements to linalg.generic.

void populateConvertToDestinationStylePatterns(RewritePatternSet &patterns)

Populate patterns that convert non-destination-style ops to destination style ops.

SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)

Return the dimensions of the given tensor value.

Include the generated interface declarations.

bool matchPattern(Value value, const Pattern &pattern)

Entry point for matching a pattern over a Value.

LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)

Reify the shape of the result of an operation (typically in terms of the shape of its operands).

const FrozenRewritePatternSet & patterns

bool isZeroInteger(OpFoldResult v)

Return true if v is an IntegerAttr with value 0.

detail::constant_op_matcher m_Constant()

Matches a constant foldable operation.

Options for BufferizableOpInterface-based bufferization.

@ MaterializeInDestination