MLIR: lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp Source File

//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----*- C++ -*--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewrites and utilities to emulate narrow types that
// are not supported by the target hardware, e.g. i4, using wider types, e.g.
// i8.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include <optional>


using namespace mlir;

#define DEBUG_TYPE "vector-narrow-type-emulation"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")


using VectorValue = TypedValue<VectorType>;
using MemRefValue = TypedValue<MemRefType>;

/// Returns a compressed mask for the emulated vector. Given that each
/// container element packs `numSrcElemsPerDest` emulated elements, the
/// returned mask has one bit per container element, set whenever any of the
/// emulated elements it covers is set. `numFrontPadElems` accounts for
/// emulated elements that precede the first source element inside the first
/// container element.
static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                  Location loc, Value mask,
                                                  int numSrcElems,
                                                  int numSrcElemsPerDest,
                                                  int numFrontPadElems = 0) {


  assert(numFrontPadElems < numSrcElemsPerDest &&
         "numFrontPadElems must be less than numSrcElemsPerDest");

  auto numDestElems =
      (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
      numSrcElemsPerDest;

  Operation *maskOp = mask.getDefiningOp();
  SmallVector<vector::ExtractOp, 2> extractOps;
  // Walk through `vector.extract` ops to find the mask-creation op.
  while (maskOp &&
         !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
             maskOp)) {
    if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
      maskOp = extractOp.getVector().getDefiningOp();
      extractOps.push_back(extractOp);
    }
  }

  if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
          maskOp))
    return failure();

  // Compute the shape of the compressed mask; only the trailing dimension is
  // compressed.
  SmallVector<int64_t> maskShape(
      cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
  maskShape.back() = numDestElems;
  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
  std::optional<Operation *> newMask =
      TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
          .Case<vector::CreateMaskOp>(
              [&](auto createMaskOp) -> std::optional<Operation *> {
                OperandRange maskOperands = createMaskOp.getOperands();
                // Compress the trailing mask bound: pad it by
                // `numFrontPadElems`, then divide by `numSrcElemsPerDest`,
                // rounding up.
                AffineExpr s0;
                bindSymbols(rewriter.getContext(), s0);
                s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
                OpFoldResult origIndex = getAsOpFoldResult(maskOperands.back());
                OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
                    rewriter, loc, s0, origIndex);
                SmallVector<Value> newMaskOperands(maskOperands.drop_back());
                newMaskOperands.push_back(
                    getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
                return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                             newMaskOperands);
              })
          .Case<vector::ConstantMaskOp>(
              [&](auto constantMaskOp) -> std::optional<Operation *> {
                // Compress the trailing dimension of the constant mask.
                SmallVector<int64_t> maskDimSizes(
                    constantMaskOp.getMaskDimSizes());
                int64_t &maskIndex = maskDimSizes.back();
                maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
                                             numSrcElemsPerDest);
                return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
                                                               maskDimSizes);
              })
          .Case<arith::ConstantOp>([&](auto constantOp)
                                       -> std::optional<Operation *> {
            // TODO: Support multiple dimensions.
            if (maskShape.size() != 1)
              return std::nullopt;

            // Pad the mask values with `numFrontPadElems` zeros at the front
            // and enough zeros at the back so that the length is a multiple
            // of `numSrcElemsPerDest`, then OR-combine each group.
            auto originalMask =
                cast<DenseIntElementsAttr>(constantOp.getValue());
            SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
            paddedMaskValues.append(originalMask.template value_begin<bool>(),
                                    originalMask.template value_end<bool>());
            paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);

            // Compress by combining every `numSrcElemsPerDest` adjacent
            // values.
            SmallVector<bool> compressedMaskValues;
            for (size_t i = 0; i < paddedMaskValues.size();
                 i += numSrcElemsPerDest) {
              bool combinedValue = false;
              for (int j = 0; j < numSrcElemsPerDest; ++j) {
                combinedValue |= paddedMaskValues[i + j];
              }
              compressedMaskValues.push_back(combinedValue);
            }
            return rewriter.create<arith::ConstantOp>(
                loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
          });

  if (!newMask)
    return failure();

  while (!extractOps.empty()) {
    newMask = rewriter.create<vector::ExtractOp>(
        loc, (*newMask)->getResults()[0], extractOps.back().getMixedPosition());
    extractOps.pop_back();
  }

  return *newMask;
}
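
To make the compression concrete, here is a minimal standalone C++ sketch, independent of MLIR and using made-up values, that mirrors the arith::ConstantOp branch above (pad the mask, then OR-combine each group of `numSrcElemsPerDest` values):

    #include <cstdio>
    #include <vector>

    int main() {
      // Hypothetical example: a 6-element i4 mask, 2 emulated elements per
      // i8 container element, 1 element of front padding.
      std::vector<bool> mask = {false, true, false, true, false, false};
      int perDest = 2, frontPad = 1;
      int numDest = (frontPad + (int)mask.size() + perDest - 1) / perDest;

      // Pad at the front and back, then OR-combine each group.
      std::vector<bool> padded(frontPad, false);
      padded.insert(padded.end(), mask.begin(), mask.end());
      padded.resize(numDest * perDest, false);
      for (int i = 0; i < numDest; ++i) {
        bool any = false;
        for (int j = 0; j < perDest; ++j)
          any = any || padded[i * perDest + j];
        std::printf("%d", (int)any); // prints 0110: the compressed mask
      }
    }

A compressed-mask bit is set whenever its container byte holds at least one selected emulated element, which is exactly the condition under which the byte must be touched.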


/// Extracts a 1-D subvector from a 1-D vector using
/// `vector.extract_strided_slice`. `offset` and `numElemsToExtract` must be
/// static.
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
                                        Value src, int64_t offset,
                                        int64_t numElemsToExtract) {

  auto vectorType = cast<VectorType>(src.getType());
  assert(vectorType.getRank() == 1 && "expected source to be rank-1 vector");
  assert(offset + numElemsToExtract <= vectorType.getNumElements() &&
         "subvector out of bounds");

  // When extracting all available elements, just use the source vector as
  // the result.
  if (vectorType.getNumElements() == numElemsToExtract)
    return src;

  auto offsets = rewriter.getI64ArrayAttr({offset});
  auto sizes = rewriter.getI64ArrayAttr({numElemsToExtract});
  auto strides = rewriter.getI64ArrayAttr({1});

  auto resultVectorType =
      VectorType::get({numElemsToExtract}, vectorType.getElementType());
  return rewriter
      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, src,
                                             offsets, sizes, strides)
      ->getResult(0);
}


/// Inserts a 1-D subvector into a 1-D vector using
/// `vector.insert_strided_slice`. `offset` must be static.
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                       Value src, Value dest, int64_t offset) {

  [[maybe_unused]] auto srcVecTy = cast<VectorType>(src.getType());
  [[maybe_unused]] auto destVecTy = cast<VectorType>(dest.getType());
  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
         "expected source and dest to be rank-1 vector types");

  // If the slice is the full width of `dest`, simply replace it with `src`.
  if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0)
    return src;

  auto offsets = rewriter.getI64ArrayAttr({offset});
  auto strides = rewriter.getI64ArrayAttr({1});
  return rewriter.create<vector::InsertStridedSliceOp>(loc, destVecTy, src,
                                                       dest, offsets, strides);
}


/// Extracts a 1-D subvector from a 1-D vector with a dynamic `offset`. The
/// extracted elements are copied into `dest` one by one via
/// `vector.extract` + `vector.insert`.
static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
                                         Value src, Value dest,
                                         OpFoldResult offset,
                                         int64_t numElemsToExtract) {

  auto srcVecTy = cast<VectorType>(src.getType());
  assert(srcVecTy.getRank() == 1 && "expected source to be rank-1 vector");
  // The check below cannot account for the dynamic offset, so the slice may
  // still be out of bounds even when the assertion holds.
  assert(numElemsToExtract <= srcVecTy.getNumElements() &&
         "subvector out of bounds");

  // When extracting all available elements, just use the source vector as
  // the result.
  if (srcVecTy.getNumElements() == numElemsToExtract)
    return src;

  for (int i = 0; i < numElemsToExtract; ++i) {
    Value extractLoc =
        (i == 0) ? dyn_cast<Value>(offset)
                 : rewriter.create<arith::AddIOp>(
                       loc, rewriter.getIndexType(), dyn_cast<Value>(offset),
                       rewriter.create<arith::ConstantIndexOp>(loc, i));
    auto extractOp = rewriter.create<vector::ExtractOp>(loc, src, extractLoc);
    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
  }
  return dest;
}


/// Inserts a 1-D subvector into a 1-D vector with a dynamic `offset`. The
/// elements are copied one by one via `vector.extract` + `vector.insert`.
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
                                        Value src, Value dest,
                                        OpFoldResult offset,
                                        int64_t numElemsToInsert) {

  auto srcVecTy = cast<VectorType>(src.getType());
  auto destVecTy = cast<VectorType>(dest.getType());
  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
         "expected source and dest to be rank-1 vector types");
  (void)srcVecTy;
  (void)destVecTy;
  assert(numElemsToInsert > 0 &&
         "the number of elements to insert must be greater than 0");
  // The check below cannot account for the dynamic offset, so the slice may
  // still be out of bounds even when the assertion holds.
  assert(numElemsToInsert <= destVecTy.getNumElements() &&
         "subvector out of bounds");

  Value destOffsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
  for (int64_t i = 0; i < numElemsToInsert; ++i) {
    auto insertLoc = i == 0
                         ? destOffsetVal
                         : rewriter.create<arith::AddIOp>(
                               loc, rewriter.getIndexType(), destOffsetVal,
                               rewriter.create<arith::ConstantIndexOp>(loc, i));
    auto extractOp = rewriter.create<vector::ExtractOp>(loc, src, i);
    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
  }
  return dest;
}


/// Emulate a vector load for `emulatedElemTy` using `containerElemTy`: load
/// `numContainerElemsToLoad` container elements and bitcast them back to the
/// emulated element type.
static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
                                      Value base,
                                      OpFoldResult linearizedIndices,
                                      int64_t numContainerElemsToLoad,
                                      Type emulatedElemTy,
                                      Type containerElemTy) {

  auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() /
                                  emulatedElemTy.getIntOrFloatBitWidth();
  auto newLoad = rewriter.create<vector::LoadOp>(
      loc, VectorType::get(numContainerElemsToLoad, containerElemTy), base,
      getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
  return rewriter.create<vector::BitCastOp>(
      loc,
      VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
                      emulatedElemTy),
      newLoad);
}


/// Downcast two values to `downcastType`, select values based on `mask`, and
/// cast the result to `upcastType`.
static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
                                     VectorType downcastType,
                                     VectorType upcastType, Value mask,
                                     Value trueValue, Value falseValue) {

  assert(
      downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
          upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
      "expected input and output number of bits to match");
  if (trueValue.getType() != downcastType) {
    trueValue = builder.create<vector::BitCastOp>(loc, downcastType, trueValue);
  }
  if (falseValue.getType() != downcastType) {
    falseValue =
        builder.create<vector::BitCastOp>(loc, downcastType, falseValue);
  }
  Value selectedType =
      builder.create<arith::SelectOp>(loc, mask, trueValue, falseValue);
  // Upcast the selected value to the new type.
  return builder.create<vector::BitCastOp>(loc, upcastType, selectedType);
}


/// Emits a `memref.generic_atomic_rmw` op to store a sub-byte-sized value
/// into a byte in `linearizedMemref`, with a mask selecting which emulated
/// elements of the byte to overwrite.
static void atomicRMW(OpBuilder &builder, Location loc,
                      MemRefValue linearizedMemref, Value storeIdx,
                      VectorValue valueToStore, Value mask) {
  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

  // Create an atomic load-modify-write region using
  // `memref.generic_atomic_rmw`.
  auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
      loc, linearizedMemref, ValueRange{storeIdx});
  Value origValue = atomicOp.getCurrentValue();

  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(atomicOp.getBody());

  // Wrap the original value into a single-element vector so that it can be
  // bitcast to the emulated element type.
  auto oneElemVecType = VectorType::get({1}, origValue.getType());
  Value origVecValue = builder.create<vector::FromElementsOp>(
      loc, oneElemVecType, ValueRange{origValue});

  // Construct the final masked value and yield it.
  Value maskedValue =
      downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
                              oneElemVecType, mask, valueToStore, origVecValue);
  auto scalarMaskedValue =
      builder.create<vector::ExtractOp>(loc, maskedValue, 0);
  builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
}
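
The same read-modify-write idea, reduced to plain bytes: the following standalone sketch (illustrative only, not part of this file) stores an i4 value into one nibble of a byte buffer while preserving the other nibble, which is what the masked select above achieves per container byte:

    #include <cstdint>
    #include <cstdio>

    // Store the i4 value `v` into the low or high nibble of buf[byteIdx],
    // leaving the other nibble untouched. All names are illustrative.
    void storeI4(uint8_t *buf, int byteIdx, bool highNibble, uint8_t v) {
      uint8_t mask = highNibble ? 0xF0 : 0x0F;
      uint8_t shifted = highNibble ? (uint8_t)(v << 4) : (uint8_t)(v & 0x0F);
      // Read-modify-write: keep the unmasked bits, replace the masked ones.
      buf[byteIdx] = (uint8_t)((buf[byteIdx] & ~mask) | (shifted & mask));
    }

    int main() {
      uint8_t buf[1] = {0xAB};
      storeI4(buf, 0, /*highNibble=*/false, 0x5);
      std::printf("0x%02X\n", buf[0]); // 0xA5
    }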


/// Generate a non-atomic read-modify-write sequence for storing to the
/// emulated type. The logic mirrors `atomicRMW`, but without the atomicity.
static void nonAtomicRMW(OpBuilder &builder, Location loc,
                         MemRefValue linearizedMemref, Value linearizedIndex,
                         VectorValue valueToStore, Value mask) {
  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

  auto oneElemVecType =
      VectorType::get({1}, linearizedMemref.getType().getElementType());
  Value origVecValue = builder.create<vector::LoadOp>(
      loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
  origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
                                                   origVecValue);

  Value maskedValue =
      downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
                              oneElemVecType, mask, valueToStore, origVecValue);
  builder.create<vector::StoreOp>(loc, maskedValue, linearizedMemref,
                                  linearizedIndex);
}


/// Extract `sliceNumElements` from source `vector` at `extractOffset`, and
/// insert it into an empty vector at `insertOffset`. The extracted slice,
/// together with the padding, must fit into a single byte.
static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
                                  Location loc, VectorValue vector,
                                  int64_t extractOffset,
                                  int64_t sliceNumElements,
                                  int64_t insertOffset) {
  assert(vector.getType().getRank() == 1 && "expected 1-D vector");
  auto vectorElementType = vector.getType().getElementType();
  assert(
      sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
      "sliceNumElements * vector element size must be less than or equal to 8");
  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
         "vector element must be a valid sub-byte type");
  auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
      loc, VectorType::get({emulatedPerContainerElem}, vectorElementType),
      rewriter.getZeroAttr(
          VectorType::get({emulatedPerContainerElem}, vectorElementType)));
  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
                                              extractOffset, sliceNumElements);
  return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
                                   insertOffset);
}

namespace {

//===----------------------------------------------------------------------===//
// ConvertVectorStore
//===----------------------------------------------------------------------===//

// Emulate `vector.store` using a multi-byte container element type obtained
// through the type converter. Sub-byte elements are packed into container
// elements; container bytes that are only partially covered by the store are
// updated through an atomic or non-atomic read-modify-write.
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
      : OpConversionPattern<vector::StoreOp>(context),
        disableAtomicRMW(disableAtomicRMW) {}

553

554 LogicalResult

555 matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,

557

558

559 if (op.getValueToStore().getType().getRank() != 1)

561 "only 1-D vectors are supported ATM");

562

563 auto loc = op.getLoc();

564

565 auto valueToStore = cast(op.getValueToStore());

566 auto containerElemTy =

567 cast(adaptor.getBase().getType()).getElementType();

568 Type emulatedElemTy = op.getValueToStore().getType().getElementType();

570 int containerBits = containerElemTy.getIntOrFloatBitWidth();

571

572

573 if (containerBits % emulatedBits != 0) {

575 op, "impossible to pack emulated elements into container elements "

576 "(bit-wise misalignment)");

577 }

578 int emulatedPerContainerElem = containerBits / emulatedBits;


    // The emulation operates on linearized (1-D) memref accesses: the sizes
    // and strides from the strided metadata below are folded into a single
    // linear index expressed in container elements.

    auto origElements = valueToStore.getType().getNumElements();
    // Note: `isDivisibleInSize` alone does not guarantee alignment, which
    // also depends on the front padding.
    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    // The front padding can only be folded to a compile-time 0 when the
    // trailing dimension of the destination memref matches the number of
    // elements being stored (or is dynamic).
    auto trailingDim = op.getBase().getType().getShape().back();
    bool trailingDimsMatch =
        ShapedType::isDynamic(trailingDim) || trailingDim == origElements;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedNumFrontPadElems =
        (isDivisibleInSize && trailingDimsMatch)
            ? 0
            : getConstantIntValue(linearizedInfo.intraDataOffset);

    if (!foldedNumFrontPadElems) {
      return rewriter.notifyMatchFailure(
          op, "subbyte store emulation: dynamic front padding size is "
              "not yet implemented");
    }

    auto memrefBase = cast<MemRefValue>(adaptor.getBase());


    // The read-modify-write path can be avoided when the store covers whole
    // container elements, i.e. there is no front padding and the element
    // count is a multiple of `emulatedPerContainerElem`. In that case a
    // plain bitcast + vector.store suffices.
    bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;

    if (!emulationRequiresPartialStores) {
      // Basic case: storing full bytes.
      auto numElements = origElements / emulatedPerContainerElem;
      auto bitCast = rewriter.create<vector::BitCastOp>(
          loc, VectorType::get(numElements, containerElemTy),
          op.getValueToStore());
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          op, bitCast.getResult(), memrefBase,
          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
      return success();
    }


    // The sub-byte store is split into three parts:
    //   1. Partial-width store for the leading unaligned elements;
    //   2. Full-width store for the aligned middle section;
    //   3. Partial-width store for the trailing unaligned elements.
    // The partial-width stores are performed with a read-modify-write
    // (atomic by default, plain when `disableAtomicRMW` is set) so that
    // neighbouring emulated elements sharing the same container element are
    // preserved.

    Value currentDestIndex =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
    // The index into the source vector, in emulated elements.
    auto currentSourceIndex = 0;

    // The mask type covering one container element's worth of sub-elements.
    auto subWidthStoreMaskType =
        VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());

    auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;

    // 1. Partial-width store for the leading output: the first container
    // element may be only partially covered by the source elements.

    auto frontSubWidthStoreElem =
        (emulatedPerContainerElem - *foldedNumFrontPadElems) %
        emulatedPerContainerElem;
    if (frontSubWidthStoreElem > 0) {
      SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
      if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
                    origElements, true);
        frontSubWidthStoreElem = origElements;
      } else {
        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
                    *foldedNumFrontPadElems, true);
      }
      auto frontMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));

      currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
      auto value =
          extractSliceIntoByte(rewriter, loc, valueToStore, 0,
                               frontSubWidthStoreElem, *foldedNumFrontPadElems);

      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
                cast<VectorValue>(value), frontMask.getResult());
    }


    if (currentSourceIndex >= origElements) {
      rewriter.eraseOp(op);
      return success();
    }

    // Increment the destination index by 1 to align to the container type
    // boundary.
    auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    currentDestIndex = rewriter.create<arith::AddIOp>(
        loc, rewriter.getIndexType(), currentDestIndex, constantOne);


    // 2. Full-width store for the middle output. After the front partial
    // store, the destination index is aligned to the container boundary, so
    // full container elements can be stored directly.
    int64_t fullWidthStoreSize =
        (origElements - currentSourceIndex) / emulatedPerContainerElem;
    int64_t numNonFullWidthElements =
        fullWidthStoreSize * emulatedPerContainerElem;
    if (fullWidthStoreSize > 0) {
      auto fullWidthStorePart = staticallyExtractSubvector(
          rewriter, loc, valueToStore, currentSourceIndex,
          numNonFullWidthElements);

      auto originType = cast<VectorType>(fullWidthStorePart.getType());
      auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
      auto storeType = VectorType::get(
          {originType.getNumElements() / emulatedPerContainerElem},
          memrefElemType);
      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
                                                        fullWidthStorePart);
      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
                                       currentDestIndex);

      currentSourceIndex += numNonFullWidthElements;
      currentDestIndex = rewriter.create<arith::AddIOp>(
          loc, rewriter.getIndexType(), currentDestIndex,
          rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
    }


    // 3. Partial-width store for the trailing output. At this point the
    // remaining slice is guaranteed to fit into a single container element.
    auto remainingElements = origElements - currentSourceIndex;
    if (remainingElements != 0) {
      auto subWidthStorePart =
          extractSliceIntoByte(rewriter, loc, valueToStore,
                               currentSourceIndex, remainingElements, 0);

      // Generate the back mask.
      auto maskValues = SmallVector<bool>(emulatedPerContainerElem, 0);
      std::fill_n(maskValues.begin(), remainingElements, 1);
      auto backMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));

      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
                cast<VectorValue>(subWidthStorePart), backMask.getResult());
    }

    rewriter.eraseOp(op);
    return success();
  }


private:
  const bool disableAtomicRMW;
};
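
As a rough mental model for the three-part split implemented above, here is a standalone sketch (hypothetical helper; assumes 2 i4 elements per byte, element 0 in the low nibble) that stores i4 values into a byte buffer:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Store n i4 values starting at emulated-element offset `off`,
    // mirroring the front-partial / full / back-partial split.
    void storeI4s(std::vector<uint8_t> &buf, int off, const uint8_t *vals,
                  int n) {
      int i = 0;
      if (off % 2 != 0) { // 1. front partial byte: RMW the high nibble only.
        buf[off / 2] = (uint8_t)((buf[off / 2] & 0x0F) | (vals[i++] << 4));
        ++off;
      }
      for (; i + 1 < n; i += 2, off += 2) // 2. full bytes: plain stores.
        buf[off / 2] = (uint8_t)(vals[i] | (vals[i + 1] << 4));
      if (i < n) // 3. back partial byte: RMW the low nibble only.
        buf[off / 2] = (uint8_t)((buf[off / 2] & 0xF0) | vals[i]);
    }

    int main() {
      std::vector<uint8_t> buf(4, 0xFF);
      uint8_t vals[3] = {0x1, 0x2, 0x3};
      storeI4s(buf, 1, vals, 3);
      std::printf("%02X %02X\n", buf[0], buf[1]); // 1F 32
    }

Only the two partial bytes need read-modify-write; the middle section is a plain store, which is why the pattern above bumps `currentDestIndex` to the container boundary after the front store.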


//===----------------------------------------------------------------------===//
// ConvertVectorMaskedStore
//===----------------------------------------------------------------------===//

// Emulate `vector.maskedstore` using a multi-byte container type through a
// masked-load / select / masked-store sequence.
struct ConvertVectorMaskedStore final
    : OpConversionPattern<vector::MaskedStoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");


    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getValueToStore().getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }

    int emulatedPerContainerElem = containerBits / emulatedBits;
    int origElements = op.getValueToStore().getType().getNumElements();
    if (origElements % emulatedPerContainerElem != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    OpFoldResult linearizedIndicesOfr;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndicesOfr) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));
    Value linearizedIndices =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);


    // Approach: load the affected container elements with the compressed
    // mask, bitcast them to the emulated type, select between the loaded
    // values and the values to store using the original mask, bitcast back,
    // and store the result with the compressed mask. Masked-off emulated
    // elements that share a container element with masked-on ones are
    // thereby preserved.
    FailureOr<Operation *> newMask = getCompressedMaskOp(
        rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);

    if (failed(newMask))
      return failure();

    auto numElements = (origElements + emulatedPerContainerElem - 1) /
                       emulatedPerContainerElem;
    auto newType = VectorType::get(numElements, containerElemTy);
    auto passThru = rewriter.create<arith::ConstantOp>(
        loc, newType, rewriter.getZeroAttr(newType));

    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, newType, adaptor.getBase(), linearizedIndices,
        newMask.value()->getResult(0), passThru);

    auto newBitCastType =
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
    Value valueToStore =
        rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
    valueToStore = rewriter.create<arith::SelectOp>(
        loc, op.getMask(), op.getValueToStore(), valueToStore);
    valueToStore =
        rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);

    rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
        op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
        valueToStore);
    return success();
  }
};


//===----------------------------------------------------------------------===//
// ConvertVectorLoad
//===----------------------------------------------------------------------===//

// Emulate `vector.load` using a multi-byte container type.
struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }
    int emulatedPerContainerElem = containerBits / emulatedBits;


    // When the emulated vector is not aligned to the container element
    // boundary, more container elements than strictly necessary are loaded,
    // the wide result is bitcast to the emulated type, and the original
    // elements are extracted from it: statically when the front padding is a
    // known constant, dynamically otherwise.

    auto origElements = op.getVectorType().getNumElements();
    // Note: `isDivisibleInSize` alone does not guarantee alignment.
    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0
                          : getConstantIntValue(linearizedInfo.intraDataOffset);

    // Always load enough elements to cover the original elements.
    int64_t maxintraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
                                        emulatedPerContainerElem);
    Value result =
        emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
                           numElements, emulatedElemTy, containerElemTy);

    if (!foldedIntraVectorOffset) {
      auto resultVector = rewriter.create<arith::ConstantOp>(
          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
      result = dynamicallyExtractSubVector(
          rewriter, loc, result, resultVector, linearizedInfo.intraDataOffset,
          origElements);
    } else if (!isDivisibleInSize) {
      result = staticallyExtractSubvector(
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
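
The `divideCeil` arithmetic above is easy to sanity-check in isolation; this small sketch (made-up numbers) computes how many container elements an unaligned emulated load touches:

    #include <cstdio>

    int main() {
      // Hypothetical case: load 5 x i4 starting at emulated-element offset
      // 1, with 2 emulated elements per i8 container element.
      int origElements = 5, perContainer = 2, intraOffset = 1;
      int numContainers =
          (intraOffset + origElements + perContainer - 1) / perContainer;
      std::printf("%d\n", numContainers); // 3 bytes cover elements 1..5
    }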


//===----------------------------------------------------------------------===//
// ConvertVectorMaskedLoad
//===----------------------------------------------------------------------===//

// Emulate `vector.maskedload` using a multi-byte container type.
struct ConvertVectorMaskedLoad final
    : OpConversionPattern<vector::MaskedLoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();

    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }
    int emulatedPerContainerElem = containerBits / emulatedBits;


    // Approach: perform a wider masked load using the compressed mask,
    // bitcast the loaded container vector to the emulated type, and select
    // lane-wise between it and the original pass-through using the original
    // (widened) mask. The pass-through itself is first inserted into a wide
    // zero vector at the intra-container offset so that masked-off lanes
    // keep their pass-through values after the select.

    auto origType = op.getVectorType();
    auto origElements = origType.getNumElements();

    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0
                          : getConstantIntValue(linearizedInfo.intraDataOffset);

    int64_t maxIntraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    FailureOr<Operation *> newMask =
        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
                            emulatedPerContainerElem, maxIntraDataOffset);
    if (failed(newMask))
      return failure();

    Value passthru = op.getPassThru();

    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
                                        emulatedPerContainerElem);
    auto loadType = VectorType::get(numElements, containerElemTy);
    auto newBitcastType =
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);

    auto emptyVector = rewriter.create<arith::ConstantOp>(
        loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
    if (!foldedIntraVectorOffset) {
      passthru = dynamicallyInsertSubVector(
          rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
          origElements);
    } else if (!isDivisibleInSize) {
      passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
                                           *foldedIntraVectorOffset);
    }
    auto newPassThru =
        rewriter.create<vector::BitCastOp>(loc, loadType, passthru);

    // Generate the new masked load.
    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, loadType, adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newMask.value()->getResult(0), newPassThru);


    // Set the part that was not effectively loaded from memory to the
    // pass-through value.
    auto bitCast =
        rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);

    Value mask = op.getMask();
    auto newSelectMaskType = VectorType::get(
        numElements * emulatedPerContainerElem, rewriter.getI1Type());
    auto emptyMask = rewriter.create<arith::ConstantOp>(
        loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
    if (!foldedIntraVectorOffset) {
      mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
                                        linearizedInfo.intraDataOffset,
                                        origElements);
    } else if (!isDivisibleInSize) {
      mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
                                       *foldedIntraVectorOffset);
    }

    Value result =
        rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
    if (!foldedIntraVectorOffset) {
      result = dynamicallyExtractSubVector(
          rewriter, loc, result, op.getPassThru(),
          linearizedInfo.intraDataOffset, origElements);
    } else if (!isDivisibleInSize) {
      result = staticallyExtractSubvector(
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);

    return success();
  }
};


/// Check whether `subByteVecTy` fits within a vector of `multiByteScalarTy`.
///
/// "Fitting" means that `subByteVecTy` can be packed into a vector of
/// `multiByteScalarTy` with no partially-occupied container elements, i.e.
/// the trailing dimension is a multiple of the number of sub-byte elements
/// that fit in one multi-byte element.

static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
                                       Type multiByteScalarTy) {
  assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");

  int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
  int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();

  assert(subByteBits < 8 && "Not a sub-byte scalar type!");
  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
  assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");

  int elemsPerMultiByte = multiByteBits / subByteBits;

  return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
}

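A scalar sketch of the same check, with illustrative numbers (8 x i4 packs exactly into i32 elements, 6 x i4 does not):

    #include <cstdio>

    // A sub-byte vector fits when its trailing size is a multiple of the
    // number of sub-byte elements packed per container element.
    bool fits(int numElems, int subByteBits, int containerBits) {
      int elemsPerContainer = containerBits / subByteBits;
      return numElems % elemsPerContainer == 0;
    }

    int main() {
      std::printf("%d %d\n", fits(8, 4, 32), fits(6, 4, 32)); // 1 0
    }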


//===----------------------------------------------------------------------===//
// ConvertVectorTransferRead
//===----------------------------------------------------------------------===//

// Emulate `vector.transfer_read` using a multi-byte container type.
struct ConvertVectorTransferRead final
    : OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto containerElemTy =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    Type emulatedElemTy = op.getType().getElementType();
    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
    int containerBits = containerElemTy.getIntOrFloatBitWidth();

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }
    int emulatedPerContainerElem = containerBits / emulatedBits;

    auto origElements = op.getVectorType().getNumElements();

    // Note: `isDivisibleInSize` alone does not imply a zero front padding.
    bool isDivisibleInSize =
        fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);

    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
                                                      adaptor.getPadding());

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, emulatedBits, containerBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isDivisibleInSize ? 0
                          : getConstantIntValue(linearizedInfo.intraDataOffset);

    int64_t maxIntraDataOffset =
        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
                                        emulatedPerContainerElem);

    auto newRead = rewriter.create<vector::TransferReadOp>(
        loc, VectorType::get(numElements, containerElemTy), adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newPadding);

    auto bitCast = rewriter.create<vector::BitCastOp>(
        loc,
        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),
        newRead);

    Value result = bitCast->getResult(0);
    if (!foldedIntraVectorOffset) {
      auto zeros = rewriter.create<arith::ConstantOp>(
          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
      result = dynamicallyExtractSubVector(rewriter, loc, result, zeros,
                                           linearizedInfo.intraDataOffset,
                                           origElements);
    } else if (!isDivisibleInSize) {
      result = staticallyExtractSubvector(
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);

    return success();
  }
};
} // namespace


//===----------------------------------------------------------------------===//
// RewriteBitCastOfTruncI
//===----------------------------------------------------------------------===//

namespace {

/// Helper struct to keep track of the provenance of a contiguous set of bits
/// in a source vector element.
struct SourceElementRange {
  /// The index of the source vector element that contributes bits to this
  /// target element.
  int64_t sourceElementIdx;
  /// The range of bits in the source vector element that contribute.
  int64_t sourceBitBegin;
  int64_t sourceBitEnd;
};

struct SourceElementRangeList : public SmallVector<SourceElementRange> {
  /// Given the index of a SourceElementRange in this list, compute the
  /// amount the extracted bits must be shifted to the left to reach their
  /// final location: the sum of the bit widths of all preceding entries.
  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
    int64_t res = 0;
    for (int64_t i = 0; i < shuffleIdx; ++i)
      res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
    return res;
  }
};


/// Helper struct to enumerate, for each element of the target vector, the
/// source elements and bit ranges that contribute to it. This is the static
/// information that allows rewriting a `vector.bitcast` into shuffles and
/// bitwise ops for any 1-D vector shape and any source/target bitwidths.
struct BitCastBitsEnumerator {
  BitCastBitsEnumerator(VectorType sourceVectorType,
                        VectorType targetVectorType);

  int64_t getMaxNumberOfEntries() {
    int64_t numVectors = 0;
    for (const auto &l : sourceElementRanges)
      numVectors = std::max(numVectors, (int64_t)l.size());
    return numVectors;
  }

  VectorType sourceVectorType;
  VectorType targetVectorType;
  SmallVector<SourceElementRangeList> sourceElementRanges;
};


/// Rewrite `vector.bitcast` into a sequence of shuffles and bitwise ops,
/// taking advantage of the static bit-provenance information computed by
/// BitCastBitsEnumerator to avoid leaving LLVM to scramble with peephole
/// optimizations. Each step shuffles the contributing source elements into
/// position, masks the bits of interest, shifts them to their final location,
/// and ORs them into the running result.

struct BitCastRewriter {
  /// Helper metadata struct holding the static quantities of one step.
  struct Metadata {
    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
  };

  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);

  /// Verify that the preconditions for the rewrite are met.
  LogicalResult commonPrecondition(PatternRewriter &rewriter,
                                   VectorType preconditionType, Operation *op);

  /// Precompute the metadata for all the rewrite steps.
  SmallVector<BitCastRewriter::Metadata>
  precomputeMetadata(IntegerType shuffledElementType);

  /// Rewrite one step of the sequence:
  ///   (shuffle -> and -> shiftright -> shiftleft -> or).
  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
                           Value initialValue, Value runningResult,
                           const BitCastRewriter::Metadata &metadata);

private:
  /// Underlying enumerator that encodes the provenance of the bits in each
  /// element of the result vector.
  BitCastBitsEnumerator enumerator;
};


} // namespace


[[maybe_unused]] static raw_ostream &
operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
  for (const auto &l : vec) {
    for (auto it : llvm::enumerate(l)) {
      os << "{ " << it.value().sourceElementIdx << ": b@["
         << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
         << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
    }
    os << "\n";
  }
  return os;
}


BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
                                             VectorType targetVectorType)
    : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {

  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
  LDBG("sourceVectorType: " << sourceVectorType);

  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
  LDBG("targetVectorType: " << targetVectorType);

  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
  (void)mostMinorSourceDim;
  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
         "source and target bitwidths must match");

  // Prepopulate one source element range list per target element.
  sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
  for (int64_t resultBit = 0; resultBit < bitwidth;) {
    int64_t resultElement = resultBit / targetBitWidth;
    int64_t resultBitInElement = resultBit % targetBitWidth;
    int64_t sourceElementIdx = resultBit / sourceBitWidth;
    int64_t sourceBitInElement = resultBit % sourceBitWidth;
    int64_t step = std::min(sourceBitWidth - sourceBitInElement,
                            targetBitWidth - resultBitInElement);
    sourceElementRanges[resultElement].push_back(
        {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
    resultBit += step;
  }
}
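
The enumeration loop above can be exercised standalone; this sketch (a hypothetical 3 x i8 to 4 x i6 cast, chosen so element boundaries do not line up) prints which source bits feed each target element:

    #include <algorithm>
    #include <cstdio>

    int main() {
      const int srcBits = 8, tgtBits = 6, bitwidth = 24;
      for (int bit = 0; bit < bitwidth;) {
        int tgtElem = bit / tgtBits, srcElem = bit / srcBits;
        int step = std::min(srcBits - bit % srcBits, tgtBits - bit % tgtBits);
        std::printf("target %d takes bits [%d..%d) of source %d\n", tgtElem,
                    bit % srcBits, bit % srcBits + step, srcElem);
        bit += step;
      }
    }

Each printed line corresponds to one SourceElementRange; target elements that straddle a source-element boundary get two entries, which is why the rewrite may need more than one shuffle step.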


BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
                                 VectorType targetVectorType)
    : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
  LDBG("\n" << enumerator.sourceElementRanges);
}


/// Verify that the precondition type meets the common preconditions for any
/// conversion.
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
                                                  Operation *op) {
  if (!preconditionType || preconditionType.isScalable())
    return rewriter.notifyMatchFailure(op, "scalable vector");

  // TODO: consider relaxing this restriction in the future if we find ways
  // to really work with subbyte elements across the MLIR / LLVM boundary.
  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
  if (bitwidth % 8 != 0)
    return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");

  return success();
}


LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
                                                  Operation *op) {
  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
    return rewriter.notifyMatchFailure(op, "types are not vector");

  if (!preconditionType || preconditionType.getRank() != 1)
    return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");

  return commonConversionPrecondition(rewriter, preconditionType, op);
}


/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are
/// aligned.
///
/// Alignment means that `subByteVecTy` can be packed into a vector of
/// `containerTy` elements. More specifically:
///   1. The bitwidth of `containerTy` is a multiple of the bitwidth of the
///      sub-byte element type, and
///   2. the trailing dimension of `subByteVecTy` is a multiple of the number
///      of sub-byte elements that fit in one container element.
///
/// NOTE: This method assumes that the common conversion preconditions are
/// met.
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
                                                   VectorType subByteVecTy,
                                                   Type containerTy,
                                                   Operation *op) {
  assert(containerTy.isIntOrFloat() &&
         "container element type is not a scalar");


  if (!subByteVecTy)
    return rewriter.notifyMatchFailure(op, "not a vector!");

  unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
  unsigned containerBits = containerTy.getIntOrFloatBitWidth();

  // Enforced by the common preconditions.
  assert(containerBits % 8 == 0 && "Not a multi-byte scalar type!");

  // TODO: Add support for other sub-byte types.
  if (subByteBits != 2 && subByteBits != 4)
    return rewriter.notifyMatchFailure(
        op, "only 2-bit and 4-bit sub-byte type is supported at this moment");

  // Condition 1 ("per-element" alignment).
  if (containerBits % subByteBits != 0)
    return rewriter.notifyMatchFailure(op, "unaligned element types");

  // Condition 2 ("full" alignment).
  if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
    return rewriter.notifyMatchFailure(
        op, "not possible to fit this sub-byte vector type into a vector of "
            "the given multi-byte type");

  return success();
}


SmallVector<BitCastRewriter::Metadata>
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
  SmallVector<BitCastRewriter::Metadata> result;
  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
       shuffleIdx < e; ++shuffleIdx) {
    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;

    // Create the attribute quantities for the shuffle / mask / shift ops.
    for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
      int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
                                  ? srcEltRangeList[shuffleIdx].sourceElementIdx
                                  : 0;
      shuffles.push_back(sourceElement);

      int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitBegin
                          : 0;
      int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitEnd
                          : 0;
      IntegerAttr mask = IntegerAttr::get(
          shuffledElementType,
          llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
                                  bitLo, bitHi));
      masks.push_back(mask);

      int64_t shiftRight = bitLo;
      shiftRightAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftRight));

      int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
      shiftLeftAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftLeft));
    }

    result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
  }
  return result;
}


Value BitCastRewriter::genericRewriteStep(
    PatternRewriter &rewriter, Location loc, Value initialValue,
    Value runningResult, const BitCastRewriter::Metadata &metadata) {
  // Create vector.shuffle from the metadata.
  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
      loc, initialValue, initialValue, metadata.shuffles);

  // Intersect with the mask.
  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
  auto constOp = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);

  // Align right on 0.
  auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
  Value shiftedRight =
      rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);

  // Shift bits left into their final position.
  auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
  Value shiftedLeft =
      rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);

  runningResult =
      runningResult
          ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
          : shiftedLeft;

  return runningResult;
}


/// Bitcasts the aligned `subByteVec` vector to a vector of i8.
/// "Aligned" means it satisfies alignedConversionPrecondition.
static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
                                      Value subByteVec) {
  auto srcVecType = cast<VectorType>(subByteVec.getType());
  int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
  assert(8 % srcBitwidth == 0 &&
         "Unsupported sub-byte type (not a divisor of i8)");
  int64_t numSrcElemsPerByte = 8 / srcBitwidth;
  SmallVector<int64_t> vecShape(srcVecType.getShape());
  // Adjust the last dimension so that the total bit size stays the same.
  vecShape.back() = vecShape.back() / numSrcElemsPerByte;
  auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
  return rewriter.create<vector::BitCastOp>(loc, i8VecType, subByteVec);
}


/// Extracts a signed N-bit sequence from each element of a vector of bytes,
/// starting at the specified bit index `bitIdx` (0 <= bitIdx < 8). The sign
/// bit is extended by first shifting left (placing the sequence's sign bit
/// at the byte's MSB) and then performing an arithmetic shift right.
static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter,
                                                  Location loc, Value src,
                                                  int bitIdx, int numBits) {
  auto srcType = cast<VectorType>(src.getType());
  Value shl = src;
  int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
  assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
         "Invalid bitIdx range");
  if (bitsToShiftLeft != 0) {
    Value shiftLeftValues = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
    shl = rewriter.create<arith::ShLIOp>(loc, src, shiftLeftValues);
  }

  int8_t bitsToShiftRight = 8 - numBits;
  Value shiftRightValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
  Value shr = rewriter.create<arith::ShRSIOp>(loc, shl, shiftRightValues);
  return shr;
}
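
The vectorized shift pair above corresponds to the classic scalar sign-extending bitfield extract; a sketch with one example byte (values illustrative):

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Extract the signed 2-bit field at bit index 2 of one byte: shift the
      // field's sign bit up to bit 7, then arithmetic-shift it back down.
      uint8_t byte = 0b00001000;                     // field at [2,4) is 0b10
      int bitIdx = 2, numBits = 2;
      int8_t shl = (int8_t)(byte << (8 - numBits - bitIdx));
      int8_t field = (int8_t)(shl >> (8 - numBits)); // sign-extended: -2
      std::printf("%d\n", field);
    }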


/// Extracts an unsigned N-bit sequence from each element of a vector of
/// bytes, starting at the specified bit index `bitIdx`: logical-shift right
/// by `bitIdx`, then mask off the high bits (unless the sequence already
/// occupies the top of the byte).
static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter,
                                              Location loc, Value src,
                                              int bitIdx, int numBits) {
  assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
         "Invalid bitIdx range");
  auto srcType = cast<VectorType>(src.getType());
  int8_t bitsToShiftRight = bitIdx;
  Value shr = src;
  if (bitsToShiftRight != 0) {
    Value shiftRightValues = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
    shr = rewriter.create<arith::ShRUIOp>(loc, src, shiftRightValues);
  }
  if (bitIdx + numBits == 8) {
    return shr;
  }
  uint8_t lowBitsMask = (1 << numBits) - 1;
  Value lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(srcType, (int8_t)lowBitsMask));
  return rewriter.create<arith::AndIOp>(loc, shr, lowBitsMaskValues);
}
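
And the unsigned counterpart as a scalar sketch (again with made-up inputs): shift right, then mask the low bits:

    #include <cstdint>
    #include <cstdio>

    int main() {
      uint8_t byte = 0b01101100;
      int bitIdx = 2, numBits = 3; // field at bits [2,5) is 0b011
      uint8_t field = (uint8_t)((byte >> bitIdx) & ((1 << numBits) - 1));
      std::printf("%u\n", field);  // 3
    }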


using ExtractNBitsFn =
    std::function<Value(PatternRewriter &, Location, Value, int, int)>;

/// Rewrite the i4 -> i8 extension into a sequence of shuffles and bitwise
/// ops to avoid leaving LLVM to scramble with peephole optimizations.
static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
                              Value srcValue, const ExtractNBitsFn &extFn) {
  [[maybe_unused]] auto srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(4) &&
         "Expected i4 type");

  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);

  // 2. Extend i4 elements to i8 elements: the low i4 elements of each byte
  // go into one vector and the high i4 elements into another.
  Value low = extFn(rewriter, loc, i8Vector, 0, 4);
  Value high = extFn(rewriter, loc, i8Vector, 4, 4);

  // 3. Interleave the low and high i8 elements to restore element order.
  return rewriter.create<vector::InterleaveOp>(loc, low, high);
}


/// Rewrite the i2 -> i8 extension into a sequence of shuffles and bitwise
/// ops to avoid leaving LLVM to scramble with peephole optimizations.
static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
                              Value srcValue, const ExtractNBitsFn &extFn) {
  [[maybe_unused]] VectorType srcVecType =
      cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(2) &&
         "Expected i2 type");

  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);

  // 2. Extract each i2 element from its byte:
  // Positions 0, 4, 8, ...  (bits [0, 2))
  Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
  // Positions 1, 5, 9, ...  (bits [2, 4))
  Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
  // Positions 2, 6, 10, ... (bits [4, 6))
  Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
  // Positions 3, 7, 11, ... (bits [6, 8))
  Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);

  // 3. Interleave the four partial results to restore the original element
  // order: first combine the stride-2 pairs (vec0 with vec2, vec1 with
  // vec3), then interleave those two results.
  Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, vec0, vec2);
  Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, vec1, vec3);
  return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
}


/// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
/// ops to avoid leaving LLVM to scramble with peephole optimizations.
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
                                Value srcValue) {
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(8) &&
         "Expected i8 type");

  // 1. De-interleave low and high i8 elements.
  auto deinterleaveOp = rewriter.create<vector::DeinterleaveOp>(loc, srcValue);

  // 2. Zero out the upper side of each low i8 element.
  constexpr int8_t i8LowBitMask = 0x0F;
  VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
  Value zeroOutLow = rewriter.create<arith::AndIOp>(
      loc, deinterleaveOp.getRes1(), zeroOutMask);

  // 3. Move the low side of the high i8 element into the upper side of the
  // low i8 element.
  constexpr int8_t bitsToShift = 4;
  auto shiftValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
  Value shlHigh = rewriter.create<arith::ShLIOp>(loc, deinterleaveOp.getRes2(),
                                                 shiftValues);

  // 4. Merge high and low i8 elements.
  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);

  // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
}
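
At the level of a single byte, steps 2-4 above amount to the following sketch (illustrative values):

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Pack two i8 values whose payload fits in 4 bits into one byte: mask
      // the even element's low nibble, shift the odd element up, then OR.
      uint8_t even = 0xF3, odd = 0x05; // truncation keeps only low nibbles
      uint8_t packed = (uint8_t)((even & 0x0F) | (odd << 4));
      std::printf("0x%02X\n", packed); // 0x53
    }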


namespace {
/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that
/// take advantage of high-level information to avoid leaving LLVM to
/// scramble with peephole optimizations.
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
                                PatternRewriter &rewriter) const override {
    // The source must be a trunc op.
    auto truncOp =
        bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
    if (!truncOp)
      return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");


    // Set up the BitCastRewriter and verify the preconditions.
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
      return failure();

    // Perform the rewrite.
    Value truncValue = truncOp.getIn();
    auto shuffledElementType =
        cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
    Value runningResult;
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
    }


    // Finalize the rewrite.
    bool narrowing = targetVectorType.getElementTypeBitWidth() <=
                     shuffledElementType.getIntOrFloatBitWidth();
    if (narrowing) {
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
      } else {
        rewriter.replaceOpWithNewOp<arith::TruncIOp>(
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      }
    } else {
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
      } else {
        rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      }
    }

    return success();
  }
};
} // namespace


//===----------------------------------------------------------------------===//
// RewriteExtOfBitCast
//===----------------------------------------------------------------------===//

namespace {
/// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that
/// take advantage of high-level information to avoid leaving LLVM to
/// scramble with peephole optimizations.
template <typename ExtOpType>
struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
  using OpRewritePattern<ExtOpType>::OpRewritePattern;

  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
      : OpRewritePattern<ExtOpType>(context, benefit) {}

  LogicalResult matchAndRewrite(ExtOpType extOp,
                                PatternRewriter &rewriter) const final {
    // The source must be a bitcast op.
    auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
    if (!bitCastOp)
      return rewriter.notifyMatchFailure(extOp, "not a bitcast source");


    // Set up the BitCastRewriter and verify the preconditions.
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(
            rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
      return failure();


    // Perform the rewrite.
    Value runningResult;
    Value sourceValue = bitCastOp.getSource();
    auto shuffledElementType =
        cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
    }


    // Finalize the rewrite.
    bool narrowing =
        cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
        shuffledElementType.getIntOrFloatBitWidth();
    if (narrowing) {
      rewriter.replaceOpWithNewOp<arith::TruncIOp>(
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
    } else {
      rewriter.replaceOpWithNewOp<ExtOpType>(
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
    }

    return success();
  }
};


//===----------------------------------------------------------------------===//
// RewriteAlignedSubByteIntExt
//===----------------------------------------------------------------------===//

/// Rewrite the i4 -> i8 (or i2 -> i8) part of any conversion into a sequence
/// of shuffles and bitwise ops that take advantage of high-level information
/// to avoid leaving LLVM to scramble with peephole optimizations. Templated
/// to choose between signed and unsigned conversions.
template <typename ConversionOpType, bool isSigned>
struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
  using OpRewritePattern<ConversionOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
                                PatternRewriter &rewriter) const override {

    // Verify the preconditions.
    Value srcValue = conversionOp.getIn();
    VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
    VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());

    if (failed(
            commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
      return failure();

    // Check general alignment preconditions.
    if (failed(alignedConversionPrecondition(rewriter, srcVecType,
                                             rewriter.getI8Type(),
                                             conversionOp)))
      return failure();

2075 Location loc = conversionOp.getLoc();

2078 Value subByteExt;

2079 switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {

2080 case 2:

2081 subByteExt = rewriteI2ToI8Ext(rewriter, loc, srcValue, extFn);

2082 break;

2083 case 4:

2084 subByteExt = rewriteI4ToI8Ext(rewriter, loc, srcValue, extFn);

2085 break;

2086 default:

2087 return failure();

2088 }

2089

2090

2092 conversionOp, conversionOp.getType(), subByteExt);

2093 return success();

2094 }

2095 };


/// Rewrite the i8 -> i4 part of any truncation into a deinterleave and
/// bitwise ops to avoid leaving LLVM to scramble with peephole
/// optimizations.
struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
                                PatternRewriter &rewriter) const override {

    // Verify the preconditions.
    Value srcValue = truncOp.getIn();
    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
    auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
    if (!srcVecType || !dstVecType)
      return failure();

    if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
      return failure();

    // TODO: Add support for truncating to i2.
    if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
      return failure();

    // Check general alignment preconditions. The source/destination order is
    // inverted to reuse the existing precondition logic.
    if (failed(alignedConversionPrecondition(rewriter, dstVecType,
                                             rewriter.getI8Type(), truncOp)))
      return failure();


    // Create an i8 truncation followed by the i4 truncation implemented via
    // deinterleave + bitwise ops.
    Location loc = truncOp.getLoc();
    auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
    Value i8TruncVal =
        rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);

    // Rewrite the i8 -> i4 truncation part.
    Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);

    // Finalize the rewrite.
    rewriter.replaceOp(truncOp, subByteTrunc);
    return success();
  }
};


//===----------------------------------------------------------------------===//
// RewriteVectorTranspose
//===----------------------------------------------------------------------===//

/// Rewrite a sub-byte vector transpose into a transpose on a wider
/// (byte-aligned) element type: sign-extend to i8, transpose the i8 vector,
/// then truncate the result back to the sub-byte type.
struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}


  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Precondition: sub-byte integer transpose.
    constexpr unsigned minNativeBitwidth = 8;
    VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
    if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
        srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
      return rewriter.notifyMatchFailure(transposeOp,
                                         "not a sub-byte transpose");
    }


    // Perform the rewrite. The signedness of the extension does not affect
    // the transposed values, since the high bits are truncated away again;
    // ExtSIOp is used because the aligned sub-byte extension patterns above
    // can simplify it.
    Location loc = transposeOp.getLoc();
    auto srcNativeVecType = srcSubByteVecType.cloneWith(
        std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
    Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
                                                  transposeOp.getVector());
    Value newTranspose = rewriter.create<vector::TransposeOp>(
        loc, extOp, transposeOp.getPermutation());
    VectorType dstSubByteVecType = transposeOp.getResultVectorType();
    rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
                                                 newTranspose);
    return success();
  }
};


} // namespace

2204

2205

2206

2207

2208

2209

2210 void vector::populateVectorNarrowTypeEmulationPatterns(

2211 const arith::NarrowTypeEmulationConverter &typeConverter,

2213

2214

2215

2216 patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,

2217 ConvertVectorMaskedStore, ConvertVectorTransferRead>(

2218 typeConverter, patterns.getContext());

2219

2220

2221

2222

2223 patterns.insert(patterns.getContext(), disableAtomicRMW);

2224 }


void vector::populateVectorNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
               RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                    benefit);

  // Patterns for aligned cases. They are given a higher benefit as they are
  // expected to generate better performance for aligned cases.
  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
               RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
               RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                              benefit.getBenefit() + 1);
  patterns
      .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
           RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
          patterns.getContext(), benefit.getBenefit() + 1);
}


void vector::populateVectorTransposeNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
}
