MLIR: lib/Dialect/Vector/Transforms/VectorDistribute.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

20 #include "llvm/ADT/SetVector.h"

21 #include "llvm/Support/FormatVariadic.h"

22 #include

23

24 using namespace mlir;

27

28

29

30

31

32

33

34

35

36

37

38

40 VectorType distributedType) {

42 perm.reserve(1);

43

44

45

46 for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {

47 if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))

48 perm.push_back(getAffineDimExpr(i, distributedType.getContext()));

49 }

50 auto map = AffineMap::get(sequentialType.getRank(), 0, perm,

51 distributedType.getContext());

52 return map;

53 }

54

55 namespace {

56

57

58

59

60

61

62 struct DistributedLoadStoreHelper {

63 DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,

65 : sequentialVal(sequentialVal), distributedVal(distributedVal),

66 laneId(laneId), zero(zero) {

67 sequentialVectorType = dyn_cast(sequentialVal.getType());

68 distributedVectorType = dyn_cast(distributedVal.getType());

69 if (sequentialVectorType && distributedVectorType)

70 distributionMap =

72 }

73

75 int64_t distributedSize = distributedVectorType.getDimSize(index);

77 return b.createOrFoldaffine::AffineApplyOp(loc, tid * distributedSize,

79 }

80

81

82

83

84

85

86

87

90 assert((val == distributedVal || val == sequentialVal) &&

91 "Must store either the preregistered distributed or the "

92 "preregistered sequential value.");

93

94 if (!isa(val.getType()))

95 return b.creatememref::StoreOp(loc, val, buffer, zero);

96

97

98

99 int64_t rank = sequentialVectorType.getRank();

101 if (val == distributedVal) {

102 for (auto dimExpr : distributionMap.getResults()) {

103 int64_t index = cast(dimExpr).getPosition();

104 indices[index] = buildDistributedOffset(b, loc, index);

105 }

106 }

108 return b.createvector::TransferWriteOp(

109 loc, val, buffer, indices,

111 }

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

134

135

136 if (!isa(type))

137 return b.creatememref::LoadOp(loc, buffer, zero);

138

139

140

141

142 assert((type == distributedVectorType || type == sequentialVectorType) &&

143 "Must store either the preregistered distributed or the "

144 "preregistered sequential type.");

146 if (type == distributedVectorType) {

147 for (auto dimExpr : distributionMap.getResults()) {

148 int64_t index = cast(dimExpr).getPosition();

149 indices[index] = buildDistributedOffset(b, loc, index);

150 }

151 }

153 return b.createvector::TransferReadOp(

154 loc, cast(type), buffer, indices,

156 }

157

158 Value sequentialVal, distributedVal, laneId, zero;

159 VectorType sequentialVectorType, distributedVectorType;

161 };

162

163 }

164

165

166

173 return rewriter.create(res);

174 }

175

176 namespace {

177

178

179

180

181

182

183

184

185

186

187

188

189

190

191

192

193

194

195

196

197

198

199

200

201

202

203

204

205

207 WarpOpToScfIfPattern(MLIRContext *context,

211

212 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,

214 assert(warpOp.getBodyRegion().hasOneBlock() &&

215 "expected WarpOp with single block");

216 Block *warpOpBody = &warpOp.getBodyRegion().front();

217 Location loc = warpOp.getLoc();

218

219

222

223

224 Value c0 = rewriter.createarith::ConstantIndexOp(loc, 0);

225 Value isLane0 = rewriter.createarith::CmpIOp(

226 loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);

227 auto ifOp = rewriter.createscf::IfOp(loc, isLane0,

228 false);

229 rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

230

231

232

234 for (const auto &it : llvm::enumerate(warpOp.getArgs())) {

236 Value distributedVal = it.value();

237 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,

238 warpOp.getLaneid(), c0);

239

240

242 Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,

243 sequentialVal.getType());

244

245 helper.buildStore(rewriter, loc, distributedVal, buffer);

246

248 bbArgReplacements.push_back(

249 helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));

250 }

251

252

253 if (!warpOp.getArgs().empty()) {

255 options.warpSyncronizationFn(loc, rewriter, warpOp);

256 }

257

258

259 rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

260

261

262

263

264

266 auto yieldOp = castgpu::YieldOp(ifOp.thenBlock()->getTerminator());

267 Location yieldLoc = yieldOp.getLoc();

268 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {

269 Value sequentialVal = it.value();

270 Value distributedVal = warpOp->getResult(it.index());

271 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,

272 warpOp.getLaneid(), c0);

273

274

276 Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,

277 sequentialVal.getType());

278

279

280

282 helper.buildStore(rewriter, loc, sequentialVal, buffer);

283

284

286

287

288

289

290

291

292

293 replacements.push_back(

294 helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));

295 }

296

297

298 if (!yieldOp.getOperands().empty()) {

300 options.warpSyncronizationFn(loc, rewriter, warpOp);

301 }

302

303

304 rewriter.eraseOp(yieldOp);

306 rewriter.createscf::YieldOp(yieldLoc);

307

308

309 rewriter.replaceOp(warpOp, replacements);

310

311 return success();

312 }

313

314 private:

316 };

317

318

319

320

321

322

323

324

325

326 static VectorType getDistributedType(VectorType originalType, AffineMap map,

327 int64_t warpSize) {

329 for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {

331 if (targetShape[position] % warpSize != 0) {

332 if (warpSize % targetShape[position] != 0) {

333 return VectorType();

334 }

335 warpSize /= targetShape[position];

336 targetShape[position] = 1;

337 continue;

338 }

339 targetShape[position] = targetShape[position] / warpSize;

340 warpSize = 1;

341 break;

342 }

343 if (warpSize != 1) {

344 return VectorType();

345 }

346 VectorType targetType =

347 VectorType::get(targetShape, originalType.getElementType());

348 return targetType;

349 }

350

351

352

353

354

355

356

357

358

359

360

361

362

363

364

365

366

367

368

369

372 unsigned maxNumElementsToExtract, PatternBenefit b = 1)

374 maxNumElementsToExtract(maxNumElementsToExtract) {}

375

376

377

378 LogicalResult tryDistributeOp(RewriterBase &rewriter,

379 vector::TransferWriteOp writeOp,

380 WarpExecuteOnLane0Op warpOp) const {

381 VectorType writtenVectorType = writeOp.getVectorType();

382

383

384

385 if (writtenVectorType.getRank() == 0)

386 return failure();

387

388

389 AffineMap map = distributionMapFn(writeOp.getVector());

390 VectorType targetType =

391 getDistributedType(writtenVectorType, map, warpOp.getWarpSize());

392 if (!targetType)

393 return failure();

394

395

396 VectorType maskType;

397 if (writeOp.getMask()) {

398

399

400

401

402

403

404 if (!writeOp.getPermutationMap().isMinorIdentity())

405 return failure();

406 maskType =

407 getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());

408 }

409

410

411

412 vector::TransferWriteOp newWriteOp =

413 cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);

414

415

416 auto newWarpOp =

417 newWriteOp.getVector().getDefiningOp();

418

419

420

421

424 for (auto [seqSize, distSize] :

425 llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {

426 assert(seqSize % distSize == 0 && "Invalid distributed vector shape");

427 delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));

428 }

431 delinearized = rewriter

432 .createmlir::affine::AffineDelinearizeIndexOp(

433 newWarpOp.getLoc(), newWarpOp.getLaneid(),

434 delinearizedIdSizes)

435 .getResults();

436 } else {

437

438

439 delinearized.append(targetType.getRank(), newWarpOp.getLaneid());

440 }

441

442 AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());

443 Location loc = newWriteOp.getLoc();

445 newWriteOp.getIndices().end());

448 bindDims(newWarpOp.getContext(), d0, d1);

449 auto indexExpr = dyn_cast(std::get<0>(it));

450 if (!indexExpr)

451 continue;

452 unsigned indexPos = indexExpr.getPosition();

453 unsigned vectorPos = cast(std::get<1>(it)).getPosition();

454 Value laneId = delinearized[vectorPos];

455 auto scale =

458 rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});

459 }

460 newWriteOp.getIndicesMutable().assign(indices);

461

462 return success();

463 }

464

465

466 LogicalResult tryExtractOp(RewriterBase &rewriter,

467 vector::TransferWriteOp writeOp,

468 WarpExecuteOnLane0Op warpOp) const {

469 Location loc = writeOp.getLoc();

470 VectorType vecType = writeOp.getVectorType();

471

472 if (vecType.getNumElements() > maxNumElementsToExtract) {

474 warpOp,

475 llvm::formatv(

476 "writes more elements ({0}) than allowed to extract ({1})",

477 vecType.getNumElements(), maxNumElementsToExtract));

478 }

479

480

481 if (llvm::all_of(warpOp.getOps(),

482 llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))

483 return failure();

484

488 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(

489 rewriter, warpOp, yieldValues, retTypes, newRetIndices);

491

492

493 auto secondWarpOp = rewriter.create(

494 loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());

495 Block &body = secondWarpOp.getBodyRegion().front();

497 auto newWriteOp =

498 castvector::TransferWriteOp(rewriter.clone(*writeOp.getOperation()));

499 newWriteOp.getValueToStoreMutable().assign(

500 newWarpOp.getResult(newRetIndices[0]));

501 rewriter.eraseOp(writeOp);

502 rewriter.creategpu::YieldOp(newWarpOp.getLoc());

503 return success();

504 }

505

506 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,

508 auto yield = castgpu::YieldOp(

509 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());

510 Operation *lastNode = yield->getPrevNode();

511 auto writeOp = dyn_cast_or_nullvector::TransferWriteOp(lastNode);

512 if (!writeOp)

513 return failure();

514

515 Value maybeMask = writeOp.getMask();

516 if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {

517 return writeOp.getVector() == value ||

518 (maybeMask && maybeMask == value) ||

519 warpOp.isDefinedOutsideOfRegion(value);

520 }))

521 return failure();

522

523 if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))

524 return success();

525

526

527 if (writeOp.getMask())

528 return failure();

529

530 if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))

531 return success();

532

533 return failure();

534 }

535

536 private:

537

538

539

540 vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,

541 WarpExecuteOnLane0Op warpOp,

542 vector::TransferWriteOp writeOp,

543 VectorType targetType,

544 VectorType maybeMaskType) const {

545 assert(writeOp->getParentOp() == warpOp &&

546 "write must be nested immediately under warp");

549 WarpExecuteOnLane0Op newWarpOp;

550 if (maybeMaskType) {

551 newWarpOp = moveRegionToNewWarpOpAndAppendReturns(

552 rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},

553 TypeRange{targetType, maybeMaskType}, newRetIndices);

554 } else {

555 newWarpOp = moveRegionToNewWarpOpAndAppendReturns(

556 rewriter, warpOp, ValueRange{{writeOp.getVector()}},

557 TypeRange{targetType}, newRetIndices);

558 }

560 auto newWriteOp =

561 castvector::TransferWriteOp(rewriter.clone(*writeOp.getOperation()));

562 rewriter.eraseOp(writeOp);

563 newWriteOp.getValueToStoreMutable().assign(

564 newWarpOp.getResult(newRetIndices[0]));

565 if (maybeMaskType)

566 newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));

567 return newWriteOp;

568 }

569

571 unsigned maxNumElementsToExtract = 1;

572 };

573

574

575

576

577

578

579

580

581

582

583

584

585

586

587

588

589

590

591

593 using Base::Base;

594 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,

598 });

599 if (!yieldOperand)

600 return failure();

601

604 Value distributedVal = warpOp.getResult(operandIndex);

607 Location loc = warpOp.getLoc();

609 Type targetType;

610 if (auto vecType = dyn_cast(distributedVal.getType())) {

611

612 auto operandType = cast(operand.get().getType());

613 targetType =

614 VectorType::get(vecType.getShape(), operandType.getElementType());

615 } else {

616 auto operandType = operand.get().getType();

617 assert(!isa(operandType) &&

618 "unexpected yield of vector from op with scalar result type");

619 targetType = operandType;

620 }

621 retTypes.push_back(targetType);

622 yieldValues.push_back(operand.get());

623 }

625 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(

626 rewriter, warpOp, yieldValues, retTypes, newRetIndices);

630 for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {

631 newOperands[i] = newWarpOp.getResult(newRetIndices[i]);

632 }

636 rewriter, loc, elementWise, newOperands,

637 {newWarpOp.getResult(operandIndex).getType()});

640 return success();

641 }

642 };

643

644

645

646

647

648

649

650

651

652

653

654

655

656

657

659 using Base::Base;

660 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,

663 getWarpResult(warpOp, llvm::IsaPredarith::ConstantOp);

664 if (!yieldOperand)

665 return failure();

666 auto constantOp = yieldOperand->get().getDefiningOparith::ConstantOp();

667 auto dense = dyn_cast(constantOp.getValue());

668 if (!dense)

669 return failure();

670

671

676 cast(warpOp.getResult(operandIndex).getType()), scalarAttr);

677 Location loc = warpOp.getLoc();

679 Value distConstant = rewriter.createarith::ConstantOp(loc, newAttr);

680 rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);

682 return success();

683 }

684 };

685

686

687

688

689

690

691

692

693

694

695

696

697

698

699

700

701

702

703

705 using Base::Base;

706 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,

708

709

710

712

713 return isavector::TransferReadOp(op) && op->hasOneUse();

714 });

715 if (!operand)

717 warpOp, "warp result is not a vector.transfer_read op");

718 auto read = operand->get().getDefiningOpvector::TransferReadOp();

719

720

721 if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))

723 read, "source must be defined outside of the region");

724

726 Value distributedVal = warpOp.getResult(operandIndex);

727

729 read.getIndices().end());

730 auto sequentialType = cast(read.getResult().getType());

731 auto distributedType = cast(distributedVal.getType());

734

735

736

738 if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),

739 distributedType.getShape(), warpOp.getWarpSize(),

740 warpOp.getLaneid(), delinearizedIds)) {

742 read, "cannot delinearize lane ID for distribution");

743 }

744 assert(!delinearizedIds.empty() || map.getNumResults() == 0);

745

746

748 SmallVector additionalResults(indices.begin(), indices.end());

751 additionalResults.push_back(read.getPadding());

752 additionalResultTypes.push_back(read.getPadding().getType());

753

754 bool hasMask = false;

755 if (read.getMask()) {

756 hasMask = true;

757

758

759

760

761

762

765 read, "non-trivial permutation maps not supported");

766 VectorType maskType =

767 getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());

768 additionalResults.push_back(read.getMask());

769 additionalResultTypes.push_back(maskType);

770 }

771

773 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(

774 rewriter, warpOp, additionalResults, additionalResultTypes,

775 newRetIndices);

776 distributedVal = newWarpOp.getResult(operandIndex);

777

778

780 for (int64_t i = 0, e = indices.size(); i < e; ++i)

781 newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));

782

786 bindDims(read.getContext(), d0, d1);

787 auto indexExpr = dyn_cast(std::get<0>(it));

788 if (!indexExpr)

789 continue;

790 unsigned indexPos = indexExpr.getPosition();

791 unsigned vectorPos = cast(std::get<1>(it)).getPosition();

792 int64_t scale = distributedType.getDimSize(vectorPos);

794 rewriter, read.getLoc(), d0 + scale * d1,

795 {newIndices[indexPos], delinearizedIds[vectorPos]});

796 }

797

798

799 Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);

800

802 hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])

804 auto newRead = rewriter.createvector::TransferReadOp(

805 read.getLoc(), distributedVal.getType(), read.getBase(), newIndices,

806 read.getPermutationMapAttr(), newPadding, newMask,

807 read.getInBoundsAttr());

808

810 return success();

811 }

812 };

813

814

815

817 using Base::Base;

818 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,

821 newResultTypes.reserve(warpOp->getNumResults());

823 newYieldValues.reserve(warpOp->getNumResults());

826 auto yield = castgpu::YieldOp(

827 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());

828

829

830

831

832

833

834

835

836

837 for (OpResult result : warpOp.getResults()) {

838 Value yieldOperand = yield.getOperand(result.getResultNumber());

839 auto it = dedupYieldOperandPositionMap.insert(

840 std::make_pair(yieldOperand, newResultTypes.size()));

841 dedupResultPositionMap.insert(std::make_pair(result, it.first->second));

842 if (result.use_empty() || !it.second)

843 continue;

844 newResultTypes.push_back(result.getType());

845 newYieldValues.push_back(yieldOperand);

846 }

847

848 if (yield.getNumOperands() == newYieldValues.size())

849 return failure();

850

851 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(

852 rewriter, warpOp, newYieldValues, newResultTypes);

853

854

855 newWarpOp.getBody()->walk([&](Operation *op) {

858 });

859

860

862 newValues.reserve(warpOp->getNumResults());

863 for (OpResult result : warpOp.getResults()) {

864 if (result.use_empty())

865 newValues.push_back(Value());

866 else

867 newValues.push_back(

868 newWarpOp.getResult(dedupResultPositionMap.lookup(result)));

869 }

870 rewriter.replaceOp(warpOp, newValues);

871 return success();

872 }

873 };

874

875

876

878 using Base::Base;

879 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,

881 auto yield = castgpu::YieldOp(

882 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());

883 Value valForwarded;

884 unsigned resultIndex;

885 for (OpOperand &operand : yield->getOpOperands()) {

888 continue;

889

890

891 if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {

893 continue;

894 valForwarded = operand.get();

896 break;

897 }

898 auto arg = dyn_cast(operand.get());

899 if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())

900 continue;

901 Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];

903 continue;

904 valForwarded = warpOperand;

906 break;

907 }

908 if (!valForwarded)

909 return failure();

910

911

913 rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);

915 return success();

916 }

917 };

918

920 using Base::Base;

921 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,

924 getWarpResult(warpOp, llvm::IsaPredvector::BroadcastOp);

925 if (!operand)

926 return failure();

928 auto broadcastOp = operand->get().getDefiningOpvector::BroadcastOp();

929 Location loc = broadcastOp.getLoc();

930 auto destVecType =

931 cast(warpOp->getResultTypes()[operandNumber]);

932 Value broadcastSrc = broadcastOp.getSource();

933 Type broadcastSrcType = broadcastSrc.getType();

934

935

936

937

938

941 return failure();

943 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(

944 rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);

946 Value broadcasted = rewriter.createvector::BroadcastOp(

947 loc, destVecType, newWarpOp->getResult(newRetIndices[0]));

949 broadcasted);

950 return success();

951 }

952 };

953

954

955

957 using Base::Base;

958 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,

961 getWarpResult(warpOp, llvm::IsaPredvector::ShapeCastOp);

962 if (!operand)

963 return failure();

964

965 auto oldCastOp = operand->get().getDefiningOpvector::ShapeCastOp();

966

968 auto castDistributedType =

969 cast(warpOp->getResultTypes()[operandNumber]);

970 VectorType castOriginalType = oldCastOp.getSourceVectorType();

971 VectorType castResultType = castDistributedType;

972

973

974

975 unsigned castDistributedRank = castDistributedType.getRank();

976 unsigned castOriginalRank = castOriginalType.getRank();

977 if (castDistributedRank < castOriginalRank) {

979 llvm::append_range(shape, castDistributedType.getShape());

980 castDistributedType =

981 VectorType::get(shape, castDistributedType.getElementType());

982 }

983

985 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(

986 rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},

987 newRetIndices);

989 Value newCast = rewriter.createvector::ShapeCastOp(

990 oldCastOp.getLoc(), castResultType,

991 newWarpOp->getResult(newRetIndices[0]));

992 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);

993 return success();

994 }

995 };

996

997

998

999

1000

1001

1002

1003

1004

1005

1006

1007

1008

1009

1010

1011

1012

1013

1014

1016 using Base::Base;

1017 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,

1020 getWarpResult(warpOp, llvm::IsaPredvector::CreateMaskOp);

1021 if (!yieldOperand)

1022 return failure();

1023

1024 auto mask = yieldOperand->get().getDefiningOpvector::CreateMaskOp();

1025

1026

1027

1028 if (!llvm::all_of(mask->getOperands(), [&](Value value) {

1029 return warpOp.isDefinedOutsideOfRegion(value);

1030 }))

1031 return failure();

1032

1033 Location loc = mask.getLoc();

1035

1036 auto distType = cast(warpOp.getResult(operandIndex).getType());

1037 VectorType seqType = mask.getVectorType();

1040

1042

1043

1045 if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,

1046 warpOp.getWarpSize(), warpOp.getLaneid(),

1047 delinearizedIds))

1049 mask, "cannot delinearize lane ID for distribution");

1050 assert(!delinearizedIds.empty());

1051

1052

1053

1055

1059 for (int i = 0, e = distShape.size(); i < e; ++i) {

1060

1061

1062

1063

1064

1066 rewriter, loc, s1 - s0 * distShape[i],

1067 {delinearizedIds[i], mask.getOperand(i)});

1068 newOperands.push_back(maskDimIdx);

1069 }

1070

1071 auto newMask =

1072 rewriter.createvector::CreateMaskOp(loc, distType, newOperands);

1073 rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);

1075 return success();

1076 }

1077 };

1078

1079

1080

1082 using Base::Base;

1083 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,

1086 getWarpResult(warpOp, llvm::IsaPredvector::ExtractOp);

1087 if (!operand)

1088 return failure();

1090 auto extractOp = operand->get().getDefiningOpvector::ExtractOp();

1091 VectorType extractSrcType = extractOp.getSourceVectorType();

1092 Location loc = extractOp.getLoc();

1093

1094

1095 if (extractSrcType.getRank() <= 1) {

1096 return failure();

1097 }

1098

1099

1100

1101 if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {

1102

1103

1104

1105

1106

1108 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(

1109 rewriter, warpOp, {extractOp.getVector()},

1110 {extractOp.getSourceVectorType()}, newRetIndices);

1112 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);

1113

1114 Value newExtract = rewriter.createvector::ExtractOp(

1115 loc, distributedVec, extractOp.getMixedPosition());

1117 newExtract);

1118 return success();

1119 }

1120

1121

1122 auto distributedType =

1123 cast(warpOp.getResult(operandNumber).getType());

1124 auto yieldedType = cast(operand->get().getType());

1125 int64_t distributedDim = -1;

1126 for (int64_t i = 0; i < yieldedType.getRank(); ++i) {

1127 if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {

1128

1129

1130 assert(distributedDim == -1 && "found multiple distributed dims");

1131 distributedDim = i;

1132 }

1133 }

1134 assert(distributedDim != -1 && "could not find distributed dimension");

1135 (void)distributedDim;

1136

1137

1139 for (int i = 0; i < distributedType.getRank(); ++i)

1140 newDistributedShape[i + extractOp.getNumIndices()] =

1141 distributedType.getDimSize(i);

1142 auto newDistributedType =

1143 VectorType::get(newDistributedShape, distributedType.getElementType());

1145 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(

1146 rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},

1147 newRetIndices);

1149 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);

1150

1151 Value newExtract = rewriter.createvector::ExtractOp(

1152 loc, distributedVec, extractOp.getMixedPosition());

1154 newExtract);

1155 return success();

1156 }

1157 };

1158

1159

1160

1162 WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,

// Distributes a warp-op result produced by a vector.extract from a 0-D or
// 1-D source vector. When the source has more than one element it is
// distributed across the warp; the element is extracted from a lane's slice
// and passed to `warpShuffleFromIdxFn` together with a source lane id
// (`broadcastFromTid`). Single-element sources are extracted directly.
1165 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,

1168 getWarpResult(warpOp, llvm::IsaPredvector::ExtractOp);

1169 if (!operand)

1170 return failure();

1172 auto extractOp = operand->get().getDefiningOpvector::ExtractOp();

1173 VectorType extractSrcType = extractOp.getSourceVectorType();

1174

1175 if (extractSrcType.getRank() > 1) {

1177 extractOp, "only 0-D or 1-D source supported for now");

1178 }

1179

1180

// Only f32 and i32 element types are accepted by this pattern.
1181 if (!extractSrcType.getElementType().isF32() &&

1182 !extractSrcType.getElementType().isInteger(32))

1184 extractOp, "only f32/i32 element types are supported");

1185 bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;

1186 Type elType = extractSrcType.getElementType();

1187 VectorType distributedVecType;

// Multi-element source: split the 1-D vector evenly across the warp and
// bail out when its length is not a multiple of the warp size.
1188 if (!is0dOrVec1Extract) {

1189 assert(extractSrcType.getRank() == 1 &&

1190 "expected that extract src rank is 0 or 1");

1191 if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)

1192 return failure();

1193 int64_t elementsPerLane =

1194 extractSrcType.getShape()[0] / warpOp.getWarpSize();

1195 distributedVecType = VectorType::get({elementsPerLane}, elType);

1196 } else {

1197 distributedVecType = extractSrcType;

1198 }

1199

1202 additionalResults.append(

1204 additionalResultTypes.append(

1206

1207 Location loc = extractOp.getLoc();

1209 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(

1210 rewriter, warpOp, additionalResults, additionalResultTypes,

1211 newRetIndices);

1213 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);

1214

1215

1216

// Single-element source: the distributed type equals the source type, so the
// extract is rebuilt directly on the warp-op result and no shuffle is needed.
1217 if (is0dOrVec1Extract) {

1218 Value newExtract;

1220 newExtract =

1221 rewriter.createvector::ExtractOp(loc, distributedVec, indices);

1223 newExtract);

1224 return success();

1225 }

1226

1227 int64_t staticPos = extractOp.getStaticPosition()[0];

1228 OpFoldResult pos = ShapedType::isDynamic(staticPos)

1229 ? (newWarpOp->getResult(newRetIndices[1]))

1231

1232

// NOTE(review): the lane that owns element `pos` is computed below as
// `sym0.ceilDiv(elementsPerLane)`, while the in-lane index is
// `pos % elementsPerLane`. If each lane holds the contiguous block
// [lane * elementsPerLane, (lane + 1) * elementsPerLane), as the
// `tid * distributedSize` offset in
// DistributedLoadStoreHelper::buildDistributedOffset suggests, then the owning
// lane is floor(pos / elementsPerLane). ceilDiv differs whenever pos is not
// a multiple of elementsPerLane: pos = 1 with elementsPerLane = 2 selects
// lane 1 instead of lane 0, which looks like the wrong element would be
// extracted. The rank > 1 insert pattern uses integer `/` for the same
// quantity. TODO: confirm and switch to floorDiv.
1233 int64_t elementsPerLane = distributedVecType.getShape()[0];

1235

1237 rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);

1238

1240 elementsPerLane == 1

1241 ? rewriter.createarith::ConstantIndexOp(loc, 0).getResult()

1243 sym0 % elementsPerLane, pos);

1244 Value extracted =

1245 rewriter.createvector::ExtractOp(loc, distributedVec, newPos);

1246

1247

// Hand the locally extracted value and its source lane id to the
// user-provided shuffle callback; its result replaces the warp-op result.
1248 Value shuffled = warpShuffleFromIdxFn(

1249 loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());

1250 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);

1251 return success();

1252 }

1253

1254 private:

1255 WarpShuffleFromIdxFn warpShuffleFromIdxFn;

1256 };

1257

1258

1260 using Base::Base;

1261 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,

1264 getWarpResult(warpOp, llvm::IsaPredvector::ExtractElementOp);

1265 if (!operand)

1266 return failure();

1267 auto extractOp = operand->get().getDefiningOpvector::ExtractElementOp();

1269 if (auto pos = extractOp.getPosition()) {

1270 indices.push_back(pos);

1271 }

1274 extractOp, extractOp.getVector(), indices);

1275 return success();

1276 }

1277 };

1278

1279

1280

1282 using Base::Base;

// Distributes a warp-op result produced by a vector.insert into a 0-D or
// 1-D destination vector. If the destination is not actually distributed
// (same type inside and outside the warp op), the insert is rebuilt directly.
// Otherwise only the lane selected by `insertingLane` performs the insert
// inside an scf.if, and every other lane yields its slice of the destination
// unchanged.
1283 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,

1285 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPredvector::InsertOp);

1286 if (!operand)

1287 return failure();

1289 auto insertOp = operand->get().getDefiningOpvector::InsertOp();

1290 VectorType vecType = insertOp.getDestVectorType();

1291 VectorType distrType =

1292 cast(warpOp.getResult(operandNumber).getType());

1293

1294

1295 if (vecType.getRank() > 1) {

1297 insertOp, "only 0-D or 1-D source supported for now");

1298 }

1299

1300

1302 insertOp.getValueToStore()};

1304 distrType, insertOp.getValueToStore().getType()};

1305 additionalResults.append(SmallVector(insertOp.getDynamicPosition()));

1306 additionalResultTypes.append(

1308

1309 Location loc = insertOp.getLoc();

1311 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(

1312 rewriter, warpOp, additionalResults, additionalResultTypes,

1313 newRetIndices);

1315 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);

1316 Value newSource = newWarpOp->getResult(newRetIndices[1]);

1318

1320 if (vecType.getRank() != 0) {

1321 int64_t staticPos = insertOp.getStaticPosition()[0];

1322 pos = ShapedType::isDynamic(staticPos)

1323 ? (newWarpOp->getResult(newRetIndices[2]))

1325 }

1326

1327

// Destination not distributed: every lane holds the full vector, so the
// insert is rebuilt directly on the warp-op results.
1328 if (vecType == distrType) {

1329 Value newInsert;

1331 if (pos) {

1332 indices.push_back(pos);

1333 }

1334 newInsert = rewriter.createvector::InsertOp(loc, newSource,

1335 distributedVec, indices);

1336

1338 newInsert);

1339 return success();

1340 }

1341

1342

// NOTE(review): the lane that owns position `pos` is computed below as
// `sym0.ceilDiv(elementsPerLane)`, while the in-lane position is
// `pos % elementsPerLane`. If each lane holds the contiguous block
// [lane * elementsPerLane, (lane + 1) * elementsPerLane), as the
// `tid * distributedSize` offset in
// DistributedLoadStoreHelper::buildDistributedOffset suggests, then the owning
// lane is floor(pos / elementsPerLane). ceilDiv differs whenever pos is not
// a multiple of elementsPerLane: pos = 1 with elementsPerLane = 2 selects
// lane 1 instead of lane 0, which looks like the value would be written to
// the wrong element. The rank > 1 insert pattern uses integer `/` for the
// same quantity. TODO: confirm and switch to floorDiv.
1343 int64_t elementsPerLane = distrType.getShape()[0];

1345

1347 rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);

1348

1350 rewriter, loc, sym0 % elementsPerLane, pos);

// Only the lane whose id equals `insertingLane` performs the insert; all
// other lanes yield `distributedVec` unchanged from the scf.if.
1351 Value isInsertingLane = rewriter.createarith::CmpIOp(

1352 loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);

1353 Value newResult =

1354 rewriter

1356 loc, isInsertingLane,

1357

1359 Value newInsert = builder.createvector::InsertOp(

1360 loc, newSource, distributedVec, newPos);

1361 builder.createscf::YieldOp(loc, newInsert);

1362 },

1363

1365 builder.createscf::YieldOp(loc, distributedVec);

1366 })

1367 .getResult(0);

1368 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);

1369 return success();

1370 }

1371 };

1372

1374 using Base::Base;

1375 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,

1377 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPredvector::InsertOp);

1378 if (!operand)

1379 return failure();

1381 auto insertOp = operand->get().getDefiningOpvector::InsertOp();

1382 Location loc = insertOp.getLoc();

1383

1384

1385 if (insertOp.getDestVectorType().getRank() <= 1) {

1386 return failure();

1387 }

1388

1389

1390

1391 if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {

1392

1393

1395 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(

1396 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},

1397 {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},

1398 newRetIndices);

1400 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);

1401 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);

1402 Value newResult = rewriter.createvector::InsertOp(

1403 loc, distributedSrc, distributedDest, insertOp.getMixedPosition());

1405 newResult);

1406 return success();

1407 }

1408

1409

1410 auto distrDestType =

1411 cast(warpOp.getResult(operandNumber).getType());

1412 auto yieldedType = cast(operand->get().getType());

1413 int64_t distrDestDim = -1;

1414 for (int64_t i = 0; i < yieldedType.getRank(); ++i) {

1415 if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {

1416

1417

1418 assert(distrDestDim == -1 && "found multiple distributed dims");

1419 distrDestDim = i;

1420 }

1421 }

1422 assert(distrDestDim != -1 && "could not find distributed dimension");

1423

1424

1425 VectorType srcVecType = cast(insertOp.getValueToStoreType());

1427

1428

1429

1430

1431

1432

1433 int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();

1434 if (distrSrcDim >= 0)

1435 distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);

1436 auto distrSrcType =

1437 VectorType::get(distrSrcShape, distrDestType.getElementType());

1438

1439

1441 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(

1442 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},

1443 {distrSrcType, distrDestType}, newRetIndices);

1445 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);

1446 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);

1447

1448

1449 Value newResult;

1450 if (distrSrcDim >= 0) {

1451

1452 newResult = rewriter.createvector::InsertOp(

1453 loc, distributedSrc, distributedDest, insertOp.getMixedPosition());

1454 } else {

1455

1456 int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);

1459

1460 Value insertingLane = rewriter.createarith::ConstantIndexOp(

1461 loc, newPos[distrDestDim] / elementsPerLane);

1462 Value isInsertingLane = rewriter.createarith::CmpIOp(

1463 loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);

1464

1465 newPos[distrDestDim] %= elementsPerLane;

1467 Value newInsert = builder.createvector::InsertOp(

1468 loc, distributedSrc, distributedDest, newPos);

1469 builder.createscf::YieldOp(loc, newInsert);

1470 };

1471 auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {

1472 builder.createscf::YieldOp(loc, distributedDest);

1473 };

1474 newResult = rewriter

1475 .createscf::IfOp(loc, isInsertingLane,

1476 insertingBuilder,

1477 nonInsertingBuilder)

1478 .getResult(0);

1479 }

1480

1481 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);

1482 return success();

1483 }

1484 };

1485

1487 using Base::Base;

1488 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,

1491 getWarpResult(warpOp, llvm::IsaPredvector::InsertElementOp);

1492 if (!operand)

1493 return failure();

1494 auto insertOp = operand->get().getDefiningOpvector::InsertElementOp();

1496 if (auto pos = insertOp.getPosition()) {

1497 indices.push_back(pos);

1498 }

1501 insertOp, insertOp.getSource(), insertOp.getDest(), indices);

1502 return success();

1503 }

1504 };

1505

1506

1507

1508

1509

1510

1511

1512

1513

1514

1515

1516

1517

1518

1519

1520

1521

1522

1523

1524

1525

1526

1527

1528

1529

1530

1531

1532

1533

1534

1535

1536

1537

1539

1542 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,

1544 auto yield = castgpu::YieldOp(

1545 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());

1546

1547 Operation *lastNode = yield->getPrevNode();

1548 auto forOp = dyn_cast_or_nullscf::ForOp(lastNode);

1549 if (!forOp)

1550 return failure();

1551

1552

1553

1554 llvm::SmallSetVector<Value, 32> escapingValues;

1558 forOp.getBodyRegion(), [&](OpOperand *operand) {

1559 Operation *parent = operand->get().getParentRegion()->getParentOp();

1560 if (warpOp->isAncestor(parent)) {

1561 if (!escapingValues.insert(operand->get()))

1562 return;

1563 Type distType = operand->get().getType();

1564 if (auto vecType = dyn_cast(distType)) {

1565 AffineMap map = distributionMapFn(operand->get());

1566 distType = getDistributedType(vecType, map, warpOp.getWarpSize());

1567 }

1568 inputTypes.push_back(operand->get().getType());

1569 distTypes.push_back(distType);

1570 }

1571 });

1572

1573 if (llvm::is_contained(distTypes, Type{}))

1574 return failure();

1575

1577 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(

1578 rewriter, warpOp, escapingValues.getArrayRef(), distTypes,

1579 newRetIndices);

1580 yield = castgpu::YieldOp(

1581 newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());

1582

1585

1586 for (OpOperand &yieldOperand : yield->getOpOperands()) {

1587 if (yieldOperand.get().getDefiningOp() != forOp.getOperation())

1588 continue;

1589 auto forResult = cast(yieldOperand.get());

1590 newOperands.push_back(

1592 yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);

1594 }

1595

1598

1599

1600

1601 auto newForOp = rewriter.createscf::ForOp(

1602 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),

1603 forOp.getStep(), newOperands);

1605

1607 newForOp.getRegionIterArgs().end());

1609 forOp.getResultTypes().end());

1610 llvm::SmallDenseMap<Value, int64_t> argIndexMapping;

1611 for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {

1612 warpInput.push_back(newWarpOp.getResult(retIdx));

1613 argIndexMapping[escapingValues[i]] = warpInputType.size();

1614 warpInputType.push_back(inputTypes[i]);

1615 }

1616 auto innerWarp = rewriter.create(

1617 newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),

1618 newWarpOp.getWarpSize(), warpInput, warpInputType);

1619

1621 argMapping.push_back(newForOp.getInductionVar());

1622 for (Value args : innerWarp.getBody()->getArguments()) {

1623 argMapping.push_back(args);

1624 }

1625 argMapping.resize(forOp.getBody()->getNumArguments());

1627 for (Value operand : forOp.getBody()->getTerminator()->getOperands())

1628 yieldOperands.push_back(operand);

1629 rewriter.eraseOp(forOp.getBody()->getTerminator());

1630 rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);

1632 rewriter.creategpu::YieldOp(innerWarp.getLoc(), yieldOperands);

1634 if (!innerWarp.getResults().empty())

1635 rewriter.createscf::YieldOp(forOp.getLoc(), innerWarp.getResults());

1636 rewriter.eraseOp(forOp);

1637

1640 newForOp.getResult(res.index()));

1641 newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));

1642 }

1643 newForOp.walk([&](Operation *op) {

1645 auto it = argIndexMapping.find(operand.get());

1646 if (it == argIndexMapping.end())

1647 continue;

1648 operand.set(innerWarp.getBodyRegion().getArgument(it->second));

1649 }

1650 });

1651

1652

1653 mlir::vector::moveScalarUniformCode(innerWarp);

1654 return success();

1655 }

1656

1657 private:

1659 };

1660

1661

1662

1663

1664

1665

1666

1667

1668

1669

1670

1671

1672

1673

1674

1675

1676

1677

1678

1679

1682 DistributedReductionFn distributedReductionFn,

1685 distributedReductionFn(std::move(distributedReductionFn)) {}

1686

1687 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,

1690 getWarpResult(warpOp, llvm::IsaPredvector::ReductionOp);

1691 if (!yieldOperand)

1692 return failure();

1693

1694 auto reductionOp =

1695 castvector::ReductionOp(yieldOperand->get().getDefiningOp());

1696 auto vectorType = cast(reductionOp.getVector().getType());

1697

1698 if (vectorType.getRank() != 1)

1700 warpOp, "Only rank 1 reductions can be distributed.");

1701

1702 if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)

1704 warpOp, "Reduction vector dimension must match was size.");

1705 if (!reductionOp.getType().isIntOrFloat())

1707 warpOp, "Reduction distribution currently only supports floats and "

1708 "integer types.");

1709

1710 int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();

1711

1716 if (reductionOp.getAcc()) {

1717 yieldValues.push_back(reductionOp.getAcc());

1718 retTypes.push_back(reductionOp.getAcc().getType());

1719 }

1721 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(

1722 rewriter, warpOp, yieldValues, retTypes, newRetIndices);

1724

1725

1726 Value laneValVec = newWarpOp.getResult(newRetIndices[0]);

1727

1728 Value fullReduce =

1729 distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,

1730 reductionOp.getKind(), newWarpOp.getWarpSize());

1731 if (reductionOp.getAcc()) {

1733 rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,

1734 newWarpOp.getResult(newRetIndices[1]));

1735 }

1736 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);

1737 return success();

1738 }

1739

1740 private:

1741 DistributedReductionFn distributedReductionFn;

1742 };

1743

1744 }

1745

1750 }

1751

1752 void mlir::vector::populateDistributeTransferWriteOpPatterns(

1754 unsigned maxNumElementsToExtract, PatternBenefit benefit) {

1755 patterns.add(patterns.getContext(), distributionMapFn,

1756 maxNumElementsToExtract, benefit);

1757 }

1758

1759 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(

1761 const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,

1763 patterns.add(patterns.getContext(), readBenefit);

1764 patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,

1765 WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,

1766 WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,

1767 WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(

1768 patterns.getContext(), benefit);

1769 patterns.add(patterns.getContext(), warpShuffleFromIdxFn,

1770 benefit);

1771 patterns.add(patterns.getContext(), distributionMapFn,

1772 benefit);

1773 }

1774

1775 void mlir::vector::populateDistributeReduction(

1777 const DistributedReductionFn &distributedReductionFn,

1779 patterns.add(patterns.getContext(), distributedReductionFn,

1780 benefit);

1781 }

1782

1783

1786 return llvm::all_of(op->getOperands(), definedOutside) &&

1788 }

1789

1790 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {

1791 Block *body = warpOp.getBody();

1792

1793

1794 llvm::SmallSetVector<Operation *, 8> opsToMove;

1795

1796

1797 auto isDefinedOutsideOfBody = [&](Value value) {

1799 return (definingOp && opsToMove.count(definingOp)) ||

1800 warpOp.isDefinedOutsideOfRegion(value);

1801 };

1802

1803

1804

1806 bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {

1807 return isa(result.getType());

1808 });

1809 if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))

1810 opsToMove.insert(&op);

1811 }

1812

1813

1816 }

static llvm::ManagedStatic< PassManagerOptions > options

static Operation * cloneOpWithOperandsAndTypes(RewriterBase &rewriter, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)

static AffineMap calculateImplicitMap(VectorType sequentialType, VectorType distributedType)

Currently the distribution map is implicit based on the vector shape.

static bool canBeHoisted(Operation *op, function_ref< bool(Value)> definedOutside)

Helper to know if an op can be hoisted out of the region.

Base type for affine expression.

AffineExpr ceilDiv(uint64_t v) const

A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.

unsigned getDimPosition(unsigned idx) const

Extracts the position of the dimensional expression at the given result, when the caller knows it is ...

static AffineMap get(MLIRContext *context)

Returns a zero result affine map with no dimensions or symbols: () -> ().

ArrayRef< AffineExpr > getResults() const

unsigned getNumResults() const

AffineMap compose(AffineMap map) const

Returns the AffineMap resulting from composing this with map.

bool isIdentity() const

Returns true if this affine map is an identity affine map.

Attributes are known-constant values of operations.

Block represents an ordered list of Operations.

BlockArgument getArgument(unsigned i)

iterator_range< iterator > without_terminator()

Return an iterator range over the operation within this block excluding the terminator operation at t...

IntegerAttr getIndexAttr(int64_t value)

AffineExpr getAffineConstantExpr(int64_t constant)

MLIRContext * getContext() const

static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)

Constructs a dense elements attribute from an array of element values.

IRValueT get() const

Return the current value being used by this operand.

void set(IRValueT newValue)

Set the current value being used by this operand.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

Operation * clone(Operation &op, IRMapping &mapper)

Creates a deep copy of the specified operation, remapping any operands that use values outside of the...

void setInsertionPointToStart(Block *block)

Sets the insertion point to the start of the specified block.

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

void setInsertionPointToEnd(Block *block)

Sets the insertion point to the end of the specified block.

void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)

Create an operation of specific op type at the current insertion point, and immediately try to fold i...

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

void setInsertionPointAfter(Operation *op)

Sets the insertion point to the node after the specified operation, which will cause subsequent inser...

This class represents a single result from folding an operation.

This class represents an operand of an operation.

unsigned getOperandNumber()

Return which operand this is in the OpOperand list of the Operation.

This is a value defined by a result of an operation.

StringRef getStringRef() const

Return the name of this operation. This always succeeds.

Operation is the basic unit of execution within MLIR.

bool hasOneUse()

Returns true if this operation has exactly one use.

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

unsigned getNumRegions()

Returns the number of regions held by this operation.

unsigned getNumOperands()

ArrayRef< NamedAttribute > getAttrs()

Return all of the attributes on this operation.

OperationName getName()

The name of an operation is the key identifier for it.

MutableArrayRef< OpOperand > getOpOperands()

operand_range getOperands()

Returns an iterator on the underlying Value's.

void moveBefore(Operation *existingOp)

Unlink this operation from its current block and insert it right before existingOp which may be in th...

result_range getResults()

This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure,...

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements...

void replaceAllUsesWith(Value from, Value to)

Find uses of from and replace them with to.

void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)

Inline the operations of block 'source' into the end of block 'dest'.

virtual void finalizeOpModification(Operation *op)

This method is used to signal the end of an in-place modification of the given operation.

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

virtual void startOpModification(Operation *op)

This method is used to notify the rewriter that an in-place operation modification is about to happen...

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

bool use_empty() const

Returns true if this value has no uses.

Type getType() const

Return the type of this value.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

Region * getParentRegion()

Return the Region in which this Value is defined.

bool hasElementwiseMappableTraits(Operation *op)

Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...

AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...

OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)

Returns the result value of reducing two scalar/vector values with the corresponding arith operation.

BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)

void populateWarpExecuteOnLane0OpToScfForPattern(RewritePatternSet &patterns, const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit=1)

SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)

Returns the integer numbers in values.

std::function< AffineMap(Value)> DistributionMapFn

Include the generated interface declarations.

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)

Bind a list of AffineExpr references to DimExpr at positions: [0 .

bool isMemoryEffectFree(Operation *op)

Returns true if the given operation is free of memory effects.

bool isOpTriviallyDead(Operation *op)

Return true if the given operation is unused, and has no side effects on memory that prevent erasing.

const FrozenRewritePatternSet & patterns

void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)

Bind a list of AffineExpr references to SymbolExpr at positions: [0 .

AffineMap compressUnusedDims(AffineMap map)

Drop the dims that are not used.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)

These free functions allow clients of the API to not use classes in detail.

void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)

Calls callback for each use of a value within region or its descendants that was defined at the ances...

AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)

This represents an operation in an abstracted form, suitable for use with the builder APIs.