MLIR: lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14 #include

15 #include <type_traits>

16

23

28

29 #include "llvm/ADT/DenseSet.h"

30 #include "llvm/ADT/MapVector.h"

31 #include "llvm/ADT/STLExtras.h"

32 #include "llvm/Support/CommandLine.h"

33 #include "llvm/Support/Debug.h"

34 #include "llvm/Support/raw_ostream.h"

35

36 #define DEBUG_TYPE "vector-transfer-split"

37

38 using namespace mlir;

40

41

42

44 VectorTransferOpInterface xferOp) {

45 assert(xferOp.getPermutationMap().isMinorIdentity() &&

46 "Expected minor identity map");

47 Value inBoundsCond;

48 xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {

49

50

51

52 if (xferOp.isDimInBounds(resultIdx))

53 return;

54

55 Location loc = xferOp.getLoc();

56 int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);

59 {xferOp.getIndices()[indicesIdx]});

64 if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)

65 return;

67 b.createarith::CmpIOp(loc, arith::CmpIPredicate::sle,

70

71 if (inBoundsCond)

72 inBoundsCond = b.createarith::AndIOp(loc, inBoundsCond, cond);

73 else

74 inBoundsCond = cond;

75 });

76 return inBoundsCond;

77 }

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112 static LogicalResult

114

115 if (xferOp.getTransferRank() == 0)

116 return failure();

117

118

119 if (!xferOp.getPermutationMap().isMinorIdentity())

120 return failure();

121

122 if (!xferOp.hasOutOfBoundsDim())

123 return failure();

124

125

126

127 if (isascf::IfOp(xferOp->getParentOp()))

128 return failure();

129 return success();

130 }

131

132

133

134

135

136

137

138

139

140

142 if (memref::CastOp::areCastCompatible(aT, bT))

143 return aT;

144 if (aT.getRank() != bT.getRank())

145 return MemRefType();

146 int64_t aOffset, bOffset;

148 if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||

149 failed(bT.getStridesAndOffset(bStrides, bOffset)) ||

150 aStrides.size() != bStrides.size())

151 return MemRefType();

152

153 ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();

154 int64_t resOffset;

156 resStrides(bT.getRank(), 0);

157 for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {

158 resShape[idx] =

159 (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;

160 resStrides[idx] =

161 (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;

162 }

163 resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;

165 resShape, aT.getElementType(),

167 }

168

169

170

171

173 MemRefType compatibleMemRefType) {

174 MemRefType sourceType = cast(memref.getType());

175 Value res = memref;

176 if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {

178 sourceType.getShape(), sourceType.getElementType(),

179 sourceType.getLayout(), compatibleMemRefType.getMemorySpace());

180 res = b.creatememref::MemorySpaceCastOp(memref.getLoc(), sourceType, res);

181 }

182 if (sourceType == compatibleMemRefType)

183 return res;

184 return b.creatememref::CastOp(memref.getLoc(), compatibleMemRefType, res);

185 }

186

187

188

189

190 static std::pair<Value, Value>

193 Location loc = xferOp.getLoc();

194 int64_t memrefRank = xferOp.getShapedType().getRank();

195

196 assert(memrefRank == cast(alloc.getType()).getRank() &&

197 "Expected memref rank to match the alloc rank");

199 xferOp.getIndices().take_front(xferOp.getLeadingShapedRank());

201 sizes.append(leadingIndices.begin(), leadingIndices.end());

202 auto isaWrite = isavector::TransferWriteOp(xferOp);

203 xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {

205 Value dimMemRef =

206 b.creatememref::DimOp(xferOp.getLoc(), xferOp.getBase(), indicesIdx);

207 Value dimAlloc = b.creatememref::DimOp(loc, alloc, resultIdx);

208 Value index = xferOp.getIndices()[indicesIdx];

210 bindDims(xferOp.getContext(), i, j, k);

213

214 Value affineMin = b.createaffine::AffineMinOp(

215 loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});

216 sizes.push_back(affineMin);

217 });

218

220 xferOp.getIndices(), [](Value idx) -> OpFoldResult { return idx; }));

223 auto copySrc = b.creatememref::SubViewOp(

224 loc, isaWrite ? alloc : xferOp.getBase(), srcIndices, sizes, strides);

225 auto copyDest = b.creatememref::SubViewOp(

226 loc, isaWrite ? xferOp.getBase() : alloc, destIndices, sizes, strides);

227 return std::make_pair(copySrc, copyDest);

228 }

229

230

231

232

233

234

235

236

237

238

239

240

241

242

243

244

245

246

247

248

249 static scf::IfOp

252 MemRefType compatibleMemRefType, Value alloc) {

253 Location loc = xferOp.getLoc();

254 Value zero = b.createarith::ConstantIndexOp(loc, 0);

255 Value memref = xferOp.getBase();

256 return b.createscf::IfOp(

257 loc, inBoundsCond,

261 llvm::append_range(viewAndIndices, xferOp.getIndices());

262 b.createscf::YieldOp(loc, viewAndIndices);

263 },

265 b.createlinalg::FillOp(loc, ValueRange{xferOp.getPadding()},

267

268

271 rewriter, cast(xferOp.getOperation()),

272 alloc);

273 b.creatememref::CopyOp(loc, copyArgs.first, copyArgs.second);

277 viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),

278 zero);

279 b.createscf::YieldOp(loc, viewAndIndices);

280 });

281 }

282

283

284

285

286

287

288

289

290

291

292

293

294

295

296

297

298

299

300

301

304 Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {

305 Location loc = xferOp.getLoc();

306 scf::IfOp fullPartialIfOp;

307 Value zero = b.createarith::ConstantIndexOp(loc, 0);

308 Value memref = xferOp.getBase();

309 return b.createscf::IfOp(

310 loc, inBoundsCond,

314 llvm::append_range(viewAndIndices, xferOp.getIndices());

315 b.createscf::YieldOp(loc, viewAndIndices);

316 },

318 Operation *newXfer = b.clone(*xferOp.getOperation());

319 Value vector = cast(newXfer).getVector();

320 b.creatememref::StoreOp(

321 loc, vector,

322 b.createvector::TypeCastOp(

324

328 viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),

329 zero);

330 b.createscf::YieldOp(loc, viewAndIndices);

331 });

332 }

333

334

335

336

337

338

339

340

341

342

343

344

345

346

347

348

352 MemRefType compatibleMemRefType, Value alloc) {

353 Location loc = xferOp.getLoc();

354 Value zero = b.createarith::ConstantIndexOp(loc, 0);

355 Value memref = xferOp.getBase();

356 return b

358 loc, inBoundsCond,

363 llvm::append_range(viewAndIndices, xferOp.getIndices());

364 b.createscf::YieldOp(loc, viewAndIndices);

365 },

370 viewAndIndices.insert(viewAndIndices.end(),

371 xferOp.getTransferRank(), zero);

372 b.createscf::YieldOp(loc, viewAndIndices);

373 })

374 ->getResults();

375 }

376

377

378

379

380

381

382

383

384

385

386

387

388

389

391 vector::TransferWriteOp xferOp,

393 Location loc = xferOp.getLoc();

394 auto notInBounds = b.createarith::XOrIOp(

395 loc, inBoundsCond, b.createarith::ConstantIntOp(loc, true, 1));

399 rewriter, cast(xferOp.getOperation()),

400 alloc);

401 b.creatememref::CopyOp(loc, copyArgs.first, copyArgs.second);

403 });

404 }

405

406

407

408

409

410

411

412

413

414

415

416

417

419 vector::TransferWriteOp xferOp,

420 Value inBoundsCond,

422 Location loc = xferOp.getLoc();

423 auto notInBounds = b.createarith::XOrIOp(

424 loc, inBoundsCond, b.createarith::ConstantIntOp(loc, true, 1));

428 loc,

429 b.createvector::TypeCastOp(

432 mapping.map(xferOp.getVector(), load);

433 b.clone(*xferOp.getOperation(), mapping);

435 });

436 }

437

438

440

441

442

447 scope = parent;

448 if (!isa<scf::ForOp, affine::AffineForOp>(parent))

449 break;

450 }

451 assert(scope && "Expected op to be inside automatic allocation scope");

452 return scope;

453 }

454

455

456

457

458

459

460

461

462

463

464

465

466

467

468

469

470

471

472

473

474

475

476

477

478

479

480

481

482

483

484

485

486

487

488

489

490

491

492

493

494

495

496

497

498

499

500

501

502

503

504

505

506

507

508

509

510

511

512

513

514

516 RewriterBase &b, VectorTransferOpInterface xferOp,

519 return failure();

520

523 if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {

525 xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);

526 });

527 return success();

528 }

529

530

531

532 {

534 "Expected splitFullAndPartialTransferPrecondition to hold");

535

536 auto xferReadOp = dyn_castvector::TransferReadOp(xferOp.getOperation());

537 auto xferWriteOp = dyn_castvector::TransferWriteOp(xferOp.getOperation());

538

539 if (!(xferReadOp || xferWriteOp))

540 return failure();

541 if (xferWriteOp && xferWriteOp.getMask())

542 return failure();

543 if (xferReadOp && xferReadOp.getMask())

544 return failure();

545 }

546

550 b, cast(xferOp.getOperation()));

551 if (!inBoundsCond)

552 return failure();

553

554

556 {

560 "AutomaticAllocationScope with >1 regions");

562 auto shape = xferOp.getVectorType().getShape();

563 Type elementType = xferOp.getVectorType().getElementType();

564 alloc = b.creatememref::AllocaOp(scope->getLoc(),

567 }

568

569 MemRefType compatibleMemRefType =

571 cast(alloc.getType()));

572 if (!compatibleMemRefType)

573 return failure();

574

577 returnTypes[0] = compatibleMemRefType;

578

579 if (auto xferReadOp =

580 dyn_castvector::TransferReadOp(xferOp.getOperation())) {

581

582 scf::IfOp fullPartialIfOp =

583 options.vectorTransferSplit == VectorTransferSplit::VectorTransfer

585 inBoundsCond,

586 compatibleMemRefType, alloc)

588 inBoundsCond, compatibleMemRefType,

589 alloc);

590 if (ifOp)

591 *ifOp = fullPartialIfOp;

592

593

594 for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)

595 xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

596

598 xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);

599 });

600

601 return success();

602 }

603

604 auto xferWriteOp = castvector::TransferWriteOp(xferOp.getOperation());

605

606

608 b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

609

610

611

612

614 mapping.map(xferWriteOp.getBase(), memrefAndIndices.front());

615 mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());

616 auto *clone = b.clone(*xferWriteOp, mapping);

617 clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

618

619

620

621 if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)

623 else

625

627

628 return success();

629 }

630

631 namespace {

632

633

634 struct VectorTransferFullPartialRewriter : public RewritePattern {

635 using FilterConstraintType =

636 std::function<LogicalResult(VectorTransferOpInterface op)>;

637

638 explicit VectorTransferFullPartialRewriter(

641 FilterConstraintType filter =

642 [](VectorTransferOpInterface op) { return success(); },

645 filter(std::move(filter)) {}

646

647

648 LogicalResult matchAndRewrite(Operation *op,

650

651 private:

653 FilterConstraintType filter;

654 };

655

656 }

657

658 LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(

660 auto xferOp = dyn_cast(op);

662 failed(filter(xferOp)))

663 return failure();

665 }

666

669 patterns.add(patterns.getContext(),

671 }

static llvm::ManagedStatic< PassManagerOptions > options

static scf::IfOp createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc)

Given an xferOp for which:

static void createFullPartialVectorTransferWrite(RewriterBase &b, vector::TransferWriteOp xferOp, Value inBoundsCond, Value alloc)

Given an xferOp for which:

static ValueRange getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp, TypeRange returnTypes, Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc)

Given an xferOp for which:

static scf::IfOp createFullPartialVectorTransferRead(RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc)

Given an xferOp for which:

static LogicalResult splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp)

Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds masking) fast path and a slow path.

static std::pair< Value, Value > createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp, Value alloc)

Operates under a scoped context to build the intersection between the view xferOp....

static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT)

Given two MemRefTypes aT and bT, return a MemRefType to which both can be cast.

static Value createInBoundsCond(RewriterBase &b, VectorTransferOpInterface xferOp)

Build the condition to ensure that a particular VectorTransferOpInterface is in-bounds.

static Operation * getAutomaticAllocationScope(Operation *op)

static Value castToCompatibleMemRefType(OpBuilder &b, Value memref, MemRefType compatibleMemRefType)

Casts the given memref to a compatible memref type.

Base type for affine expression.

static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)

Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the largest dim in exprs and as many symbols as the largest symbol in exprs.

IntegerAttr getIndexAttr(int64_t value)

AffineExpr getAffineConstantExpr(int64_t constant)

IntegerAttr getI64IntegerAttr(int64_t value)

AffineExpr getAffineDimExpr(unsigned position)

MLIRContext * getContext() const

ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)

This is a utility class for mapping one set of IR entities to another.

void map(Value from, Value to)

Inserts a new mapping for 'from' to 'to'.

This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

Operation * clone(Operation &op, IRMapping &mapper)

Creates a deep copy of the specified operation, remapping any operands that use values outside of the...

void setInsertionPointToStart(Block *block)

Sets the insertion point to the start of the specified block.

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

This class represents a single result from folding an operation.

A trait of region holding operations that define a new scope for automatic allocations,...

Operation is the basic unit of execution within MLIR.

unsigned getNumRegions()

Returns the number of regions held by this operation.

Location getLoc()

The source location the operation was defined or derived from.

Operation * getParentOp()

Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.

Region & getRegion(unsigned index)

Returns the region held by this operation at position 'index'.

void setAttr(StringAttr name, Attribute value)

If an attribute exists with the specified name, change it to the new value.

This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

RewritePattern is the common base class for all DAG to DAG replacements.

This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

void modifyOpInPlace(Operation *root, CallableT &&callable)

This method is a utility wrapper around an in-place modification of an operation.

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

Location getLoc() const

Return the location of this value.

OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...

OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)

Return the dimension of the given memref value.

SmallVector< Value > ValueVector

An owning vector of values, handy to return from functions.

LogicalResult splitFullAndPartialTransfer(RewriterBase &b, VectorTransferOpInterface xferOp, VectorTransformsOptions options=VectorTransformsOptions(), scf::IfOp *ifOp=nullptr)

Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds masking) fastpath and a slowpath.

void populateVectorTransferFullPartialPatterns(RewritePatternSet &patterns, const VectorTransformsOptions &options)

Populate patterns with the following patterns.

Include the generated interface declarations.

std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)

If ofr is a constant integer or an IntegerAttr, return the integer.

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)

Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].

const FrozenRewritePatternSet & patterns

Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)

Converts an OpFoldResult to a Value.

Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

Structure to control the behavior of vector transform patterns.

Eliminates variable at the specified position using Fourier-Motzkin variable elimination.