//===- LowerVectorTransfer.cpp - Lower 'vector.transfer' operation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.transfer' operation.
//
//===----------------------------------------------------------------------===//

// Note: the include list below was reconstructed for this listing (the
// original headers were elided); it is limited to headers these patterns
// clearly depend on.
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"

using namespace mlir;
using namespace mlir::vector;

/// Transpose a vector transfer op's `in_bounds` attribute by applying the
/// reverse permutation based on the given indices.
static ArrayAttr
inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
                             const SmallVector<unsigned> &permutation) {
  SmallVector<bool> newInBoundsValues(permutation.size());
  size_t index = 0;
  for (unsigned pos : permutation)
    newInBoundsValues[pos] =
        cast<BoolAttr>(attr.getValue()[index++]).getValue();
  return builder.getBoolArrayAttr(newInBoundsValues);
}
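
// Illustrative worked example (added for this listing, not part of the
// original file): with attr = [true, false, true] and permutation = [2, 0, 1],
// the loop assigns newInBoundsValues[2] = true, newInBoundsValues[0] = false,
// and newInBoundsValues[1] = true, so the returned attribute is
// [false, true, true].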

/// Extend the rank of a vector Value by `addedRank` by adding outer unit
/// dimensions.
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
                              int64_t addedRank) {
  auto originalVecType = cast<VectorType>(vec.getType());
  SmallVector<int64_t> newShape(addedRank, 1);
  newShape.append(originalVecType.getShape().begin(),
                  originalVecType.getShape().end());

  SmallVector<bool> newScalableDims(addedRank, false);
  newScalableDims.append(originalVecType.getScalableDims().begin(),
                         originalVecType.getScalableDims().end());
  VectorType newVecType = VectorType::get(
      newShape, originalVecType.getElementType(), newScalableDims);
  return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
}

/// Extend the rank of a vector Value by `addedRank` by adding inner unit
/// dimensions: first broadcast to add outer unit dims, then rotate them to the
/// innermost positions with a transpose.
static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
                            int64_t addedRank) {
  Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
  SmallVector<int64_t> permutation;
  for (int64_t i = addedRank,
               e = cast<VectorType>(broadcasted.getType()).getRank();
       i < e; ++i)
    permutation.push_back(i);
  for (int64_t i = 0; i < addedRank; ++i)
    permutation.push_back(i);
  return builder.create<vector::TransposeOp>(loc, broadcasted, permutation);
}
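
// Illustrative sketch (added for this listing, not part of the original
// file): for a builder `b`, a location `loc`, and values
// %v : vector<4x8xf32> and %m : vector<4x8xi1>,
//   extendVectorRank(b, loc, %v, /*addedRank=*/1) yields vector<1x4x8xf32>
//     (a single vector.broadcast adding an outer unit dim), while
//   extendMaskRank(b, loc, %m, /*addedRank=*/1) yields vector<4x8x1xi1>
//     (broadcast to vector<1x4x8xi1>, then vector.transpose with [1, 2, 0]).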

namespace {
/// Lower a transfer_read op whose map is a permutation of a (broadcasting)
/// minor identity into a transfer_read that uses the minor identity map,
/// followed by a vector.transpose that restores the original layout.
/// Ex (illustrative):
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (d2, d1)
/// is rewritten into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (d1, d2)
///     vector.transpose %v, [1, 0]
struct TransferReadPermutationLowering
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support the 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    // TODO: support the masked case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.getNumResults() == 0)
      return rewriter.notifyMatchFailure(op, "0 result permutation map");
    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }

    // Build the permutation map composed on top of the original map so that
    // the new transfer_read uses a (broadcasting) minor identity.
    AffineMap permutationMap =
        AffineMap::getPermutationMap(permutation, op.getContext());
    AffineMap newMap = permutationMap.compose(map);

    // Apply the reverse transpose to deduce the type of the new transfer_read.
    ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
    SmallVector<int64_t> newVectorShape(originalShape.size());
    ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
    SmallVector<bool> newScalableDims(originalShape.size());
    for (auto pos : llvm::enumerate(permutation)) {
      newVectorShape[pos.value()] = originalShape[pos.index()];
      newScalableDims[pos.value()] = originalScalableDims[pos.index()];
    }

    // Transpose the in_bounds attribute accordingly.
    ArrayAttr newInBoundsAttr = inverseTransposeInBoundsAttr(
        rewriter, op.getInBoundsAttr(), permutation);

    // Generate the new transfer_read operation.
    VectorType newReadType = VectorType::get(
        newVectorShape, op.getVectorType().getElementType(), newScalableDims);
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getBase(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);

    // Transpose the result of the new transfer_read back into the original
    // layout.
    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
    return rewriter
        .create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm)
        .getResult();
  }
};

/// Lower a transfer_write op whose map is a permutation of a minor identity
/// into a transfer_write with a minor identity map, by transposing the written
/// vector first.
/// Ex (illustrative):
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2) -> (d2, d1)
/// is rewritten into:
///     %tmp = vector.transpose %v, [1, 0]
///     vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2) -> (d1, d2)
struct TransferWritePermutationLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support the 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    // TODO: support the masked case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isMinorIdentity())
      return rewriter.notifyMatchFailure(op, "map is already minor identity");

    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }

    // Remove unused dims from the permutation map, e.g.:
    //   map  = (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
    //   comp = (d0, d1, d2) -> (d2, d0, d1)
    auto comp = compressUnusedDims(map);
    AffineMap permutationMap = inversePermutation(comp);
    // The positions of the remaining result dims form the transpose
    // permutation applied to the written vector.
    SmallVector<int64_t> indices;
    llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
                    [](AffineExpr expr) {
                      return dyn_cast<AffineDimExpr>(expr).getPosition();
                    });

    // Transpose the in_bounds attribute accordingly.
    ArrayAttr newInBoundsAttr = inverseTransposeInBoundsAttr(
        rewriter, op.getInBoundsAttr(), permutation);

    // Transpose the written vector and generate the new transfer_write with a
    // minor identity map.
    Value newVec = rewriter.create<vector::TransposeOp>(
        op.getLoc(), op.getVector(), indices);
    auto newMap = AffineMap::getMinorIdentityMap(
        map.getNumDims(), map.getNumResults(), rewriter.getContext());
    auto newWrite = rewriter.create<vector::TransferWriteOp>(
        op.getLoc(), newVec, op.getBase(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);
    if (newWrite.hasPureTensorSemantics())
      return newWrite.getResult();

    // In the memref case there is no return value; signal success with an
    // empty Value.
    return Value();
  }
};

/// Convert a transfer_write op whose map is not a permutation of a minor
/// identity into a vector.broadcast + transfer_write whose map is, by
/// materializing the missing inner dimensions as unit dims.
/// Ex (illustrative):
///     vector.transfer_write %v ... : vector<8x16xf32>
///         permutation_map: (d0, d1, d2, d3) -> (d1, d2)
/// is rewritten into:
///     %v1 = vector.broadcast %v : vector<8x16xf32> to vector<1x8x16xf32>
///     vector.transfer_write %v1 ... : vector<1x8x16xf32>
///         permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)
struct TransferWriteNonPermutationLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support the 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    // TODO: support the masked case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op,
          "map is already permutable to minor identity, apply another pattern");
    }

    // Missing outer dimensions are allowed: find the outermost dimension that
    // appears in the map, then collect the missing inner dimensions after it.
    SmallVector<bool> foundDim(map.getNumDims(), false);
    for (AffineExpr exp : map.getResults())
      foundDim[cast<AffineDimExpr>(exp).getPosition()] = true;
    SmallVector<int64_t> missingInnerDim;
    bool foundFirstDim = false;
    for (size_t i = 0; i < foundDim.size(); i++) {
      if (foundDim[i]) {
        foundFirstDim = true;
        continue;
      }
      if (!foundFirstDim)
        continue;
      // Once a dimension present in the map has been seen, every later missing
      // dimension is an inner dimension that must be materialized.
      missingInnerDim.push_back(i);
    }

    // Vector: add outer unit dims for the materialized dimensions.
    Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
                                    missingInnerDim.size());
    // Mask: add inner unit dims correspondingly.
    Value newMask;
    if (op.getMask())
      newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
                               missingInnerDim.size());

    // New map: the leading results index the materialized unit dimensions,
    // followed by the original results.
    SmallVector<AffineExpr> exprs;
    for (int64_t dim : missingInnerDim)
      exprs.push_back(rewriter.getAffineDimExpr(dim));
    exprs.append(map.getResults().begin(), map.getResults().end());
    AffineMap newMap =
        AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());

    // The materialized unit dimensions are always in-bounds.
    SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
    for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
      newInBoundsValues.push_back(op.isDimInBounds(i));
    }
    ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
    auto newWrite = rewriter.create<vector::TransferWriteOp>(
        op.getLoc(), newVec, op.getBase(), op.getIndices(),
        AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
    if (newWrite.hasPureTensorSemantics())
      return newWrite.getResult();

    // In the memref case there is no return value; signal success with an
    // empty Value.
    return Value();
  }
};

/// Lower a transfer_read op with broadcasts in the leading dimensions into a
/// transfer_read of lower rank followed by a vector.broadcast.
/// Ex (illustrative):
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
/// is rewritten into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
///     vector.broadcast %v
struct TransferOpReduceRank
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support the 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    // TODO: support the masked case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    AffineMap map = op.getPermutationMap();
    unsigned numLeadingBroadcast = 0;
    for (auto expr : map.getResults()) {
      auto dimExpr = dyn_cast<AffineConstantExpr>(expr);
      if (!dimExpr || dimExpr.getValue() != 0)
        break;
      numLeadingBroadcast++;
    }

    if (numLeadingBroadcast == 0)
      return rewriter.notifyMatchFailure(op, "no leading broadcasts in map");

    VectorType originalVecType = op.getVectorType();
    unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
    // Calculate the map of the lower-rank read by dropping the leading
    // broadcast results.
    AffineMap newMap = AffineMap::get(
        map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
        op.getContext());

    // The reduced map must still be handled by the other lowering patterns.
    if (!newMap.isMinorIdentityWithBroadcasting()) {
      return rewriter.notifyMatchFailure(
          op, "map is not a minor identity with broadcasting");
    }

    SmallVector<int64_t> newShape = llvm::to_vector(
        originalVecType.getShape().take_back(reducedShapeRank));
    SmallVector<bool> newScalableDims = llvm::to_vector(
        originalVecType.getScalableDims().take_back(reducedShapeRank));

    VectorType newReadType = VectorType::get(
        newShape, originalVecType.getElementType(), newScalableDims);
    ArrayAttr newInBoundsAttr =
        op.getInBounds()
            ? rewriter.getArrayAttr(
                  op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
            : ArrayAttr();
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getBase(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);
    return rewriter
        .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
        .getVector();
  }
};


} // namespace

void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
           TransferOpReduceRank, TransferWriteNonPermutationLowering>(
          patterns.getContext(), benefit);
}
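
// Note on how these patterns compose (summary added for this listing, not
// part of the original file): TransferReadPermutationLowering and
// TransferWritePermutationLowering rewrite permuted maps into minor identities
// plus vector.transpose; TransferWriteNonPermutationLowering materializes
// missing inner dimensions as unit dims; TransferOpReduceRank then peels
// leading broadcast dimensions into a vector.broadcast. After repeated
// application, the remaining transfers use (broadcasting) minor-identity maps,
// which the load/store lowerings below can handle directly.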


//===----------------------------------------------------------------------===//
// populateVectorTransferLoweringPatterns
//===----------------------------------------------------------------------===//

namespace {
/// Progressive lowering of transfer_read: lower it to a vector.load (plus a
/// vector.broadcast if needed) when all of the following hold:
/// - the source is a memref whose most minor dimension has unit stride,
/// - no out-of-bounds masking is required,
/// - if the memref's element type is a vector type, it matches the result
///   type,
/// - the permutation map is a minor identity with (optional) broadcasting.
struct TransferReadToVectorLoadLowering
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  TransferReadToVectorLoadLowering(MLIRContext *context,
                                   std::optional<unsigned> maxRank,
                                   PatternBenefit benefit = 1)
      : MaskableOpRewritePattern<vector::TransferReadOp>(context, benefit),
        maxTransferRank(maxRank) {}

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp read,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          read, "vector type is greater than max transfer rank");
    }

    // TODO: support the masked case.
    if (maskOp)
      return rewriter.notifyMatchFailure(read, "Masked case not supported");

    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    // The 0-d corner case passes through as it is supported.
    SmallVector<unsigned> broadcastedDims;
    if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
            &broadcastedDims))
      return rewriter.notifyMatchFailure(read, "not minor identity + bcast");

    auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(read, "not a memref source");

    // Non-unit strides are handled by VectorToSCF.
    if (!memRefType.isLastDimUnitStride())
      return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF");

    // If broadcasting is involved, first load the unbroadcasted vector, then
    // broadcast it with `vector.broadcast`.
    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
    SmallVector<int64_t> unbroadcastedVectorShape =
        llvm::to_vector(vectorShape);
    for (unsigned i : broadcastedDims)
      unbroadcastedVectorShape[i] = 1;
    VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
        unbroadcastedVectorShape, read.getVectorType().getElementType());

    // `vector.load` supports vector types as memref elements only when the
    // resulting vector type is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
      return rewriter.notifyMatchFailure(read, "incompatible element type");

    // Otherwise, the element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != read.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(read, "non-matching element type");

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (read.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");

    // Create the vector load op.
    Operation *res = nullptr;
    if (read.getMask()) {
      // vector.maskedload operates on 1-D vectors only.
      if (read.getVectorType().getRank() != 1)
        return rewriter.notifyMatchFailure(
            read, "vector type is not rank 1, can't create masked load, needs "
                  "VectorToSCF");

      Value fill = rewriter.create<vector::SplatOp>(
          read.getLoc(), unbroadcastedVectorType, read.getPadding());
      res = rewriter.create<vector::MaskedLoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getBase(),
          read.getIndices(), read.getMask(), fill);
    } else {
      res = rewriter.create<vector::LoadOp>(read.getLoc(),
                                            unbroadcastedVectorType,
                                            read.getBase(), read.getIndices());
    }

    // Insert a broadcasting op if required.
    if (!broadcastedDims.empty())
      res = rewriter.create<vector::BroadcastOp>(
          read.getLoc(), read.getVectorType(), res->getResult(0));
    return res->getResult(0);
  }

  std::optional<unsigned> maxTransferRank;
};

/// Progressive lowering of transfer_write: lower it to a vector.store when all
/// of the following hold:
/// - the destination is a memref whose most minor dimension has unit stride,
/// - no out-of-bounds masking is required,
/// - if the memref's element type is a vector type, it matches the type of the
///   written value,
/// - the permutation map is the minor identity (neither permutation nor
///   broadcasting).
struct TransferWriteToVectorStoreLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  TransferWriteToVectorStoreLowering(MLIRContext *context,
                                     std::optional<unsigned> maxRank,
                                     PatternBenefit benefit = 1)
      : MaskableOpRewritePattern<vector::TransferWriteOp>(context, benefit),
        maxTransferRank(maxRank) {}

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp write,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          write, "vector type is greater than max transfer rank");
    }
    // TODO: support the masked case.
    if (maskOp)
      return rewriter.notifyMatchFailure(write, "Masked case not supported");

    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    if ( // The 0-d corner case passes through as it is supported.
        !write.getPermutationMap().isMinorIdentity())
      return rewriter.notifyMatchFailure(write, [=](Diagnostic &diag) {
        diag << "permutation map is not minor identity: " << write;
      });

    auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(write, [=](Diagnostic &diag) {
        diag << "not a memref type: " << write;
      });

    // Non-unit strides are handled by VectorToSCF.
    if (!memRefType.isLastDimUnitStride())
      return rewriter.notifyMatchFailure(write, [=](Diagnostic &diag) {
        diag << "most minor stride is not 1: " << write;
      });

    // `vector.store` supports vector types as memref elements only when the
    // written vector type is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
      return rewriter.notifyMatchFailure(write, [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    // Otherwise, the element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != write.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(write, [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (write.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(write, [=](Diagnostic &diag) {
        diag << "out of bounds dim: " << write;
      });
    if (write.getMask()) {
      // vector.maskedstore operates on 1-D vectors only.
      if (write.getVectorType().getRank() != 1)
        return rewriter.notifyMatchFailure(write, [=](Diagnostic &diag) {
          diag << "vector type is not rank 1, can't create masked store, "
                  "needs VectorToSCF: "
               << write;
        });

      rewriter.create<vector::MaskedStoreOp>(
          write.getLoc(), write.getBase(), write.getIndices(), write.getMask(),
          write.getVector());
    } else {
      rewriter.create<vector::StoreOp>(write.getLoc(), write.getVector(),
                                       write.getBase(), write.getIndices());
    }
    // There is no return value for store ops; signal success with an empty
    // Value.
    return Value();
  }

  std::optional<unsigned> maxTransferRank;
};

} // namespace

void mlir::vector::populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
    PatternBenefit benefit) {
  patterns.add<TransferReadToVectorLoadLowering,
               TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                   maxTransferRank, benefit);
}
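
// Illustrative usage sketch (added for this listing, not part of the original
// file). It shows how a client might combine the two pattern sets defined
// above: first simplify permutation maps, then lower qualifying transfers to
// vector.load / vector.store, capping the lowered transfer rank at 1 and
// leaving the rest to other lowerings such as VectorToSCF. The helper name
// `lowerVectorTransfers` is hypothetical, and `applyPatternsGreedily` assumes
// mlir/Transforms/GreedyPatternRewriteDriver.h is available (older releases
// spell it applyPatternsAndFoldGreedily).
static LogicalResult lowerVectorTransfers(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
  vector::populateVectorTransferLoweringPatterns(patterns,
                                                 /*maxTransferRank=*/1);
  // Apply the combined set until fixpoint.
  return applyPatternsGreedily(root, std::move(patterns));
}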
