MLIR: lib/Dialect/Linalg/Transforms/Vectorization.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

13

36 #include "llvm/ADT/STLExtras.h"

37 #include "llvm/ADT/Sequence.h"

38 #include "llvm/ADT/SmallVector.h"

39 #include "llvm/ADT/TypeSwitch.h"

40 #include "llvm/ADT/iterator_range.h"

41 #include "llvm/Support/Debug.h"

42 #include "llvm/Support/MathExtras.h"

43 #include "llvm/Support/raw_ostream.h"

44 #include

45 #include <type_traits>

46

47 using namespace mlir;

49

50 #define DEBUG_TYPE "linalg-vectorization"

51

52 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

53 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

54

55

56 static FailureOr<Operation *>

60 bool flatten1DDepthwiseConv = false);

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80 static LogicalResult

84

85

86

87

88

89

90

92

93

94

95 template

97 OpType res;

98 block.walk([&](OpType op) {

99 if (res) {

100 res = nullptr;

102 }

103 res = op;

105 });

106 return res;

107 }

108

109

110

113 int64_t nSize, int64_t wSize, int64_t cSize,

114 int64_t kwSize, int strideW, int dilationW,

115 int64_t wSizeStep, bool isSingleChanneled) {

117 if (isSingleChanneled) {

118

119

122 for (int64_t kw = 0; kw < kwSize; ++kw) {

123 for (int64_t w = 0; w < wSize; w += wSizeStep) {

124 result.push_back(rewriter.createvector::ExtractStridedSliceOp(

125 loc, input, ArrayRef<int64_t>{w + kw}, sizes, strides));

126 }

127 }

128 } else {

129

130

133 for (int64_t kw = 0; kw < kwSize; ++kw) {

134 for (int64_t w = 0; w < wSize; w += wSizeStep) {

135 result.push_back(rewriter.createvector::ExtractStridedSliceOp(

136 loc, input,

137 ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},

138 sizes, strides));

139 }

140 }

141 }

142 return result;

143 }

144

145

146

149 int64_t kwSize) {

151

152

153 for (int64_t kw = 0; kw < kwSize; ++kw) {

154 result.push_back(rewriter.createvector::ExtractOp(

156 }

157 return result;

158 }

159

160

161

164 int64_t nSize, int64_t wSize, int64_t fSize,

165 int64_t wSizeStep, bool isSingleChanneled) {

167 if (isSingleChanneled) {

168

171 for (int64_t w = 0; w < wSize; w += wSizeStep) {

172 result.push_back(rewriter.createvector::ExtractStridedSliceOp(

174 }

175 } else {

176

177

180 for (int64_t w = 0; w < wSize; w += wSizeStep) {

181 result.push_back(rewriter.createvector::ExtractStridedSliceOp(

182 loc, res, ArrayRef<int64_t>{0, w, 0}, sizes, strides));

183 }

184 }

185 return result;

186 }

187

188

190 Value res, int64_t wSize, int64_t wSizeStep,

192 bool isSingleChanneled) {

193

194 if (isSingleChanneled) {

195

196

198 for (int64_t w = 0; w < wSize; w += wSizeStep) {

199 res = rewriter.createvector::InsertStridedSliceOp(

200 loc, resVals[w], res, ArrayRef<int64_t>{w}, strides);

201 }

202 } else {

203

204

206 for (int64_t w = 0; w < wSize; w += wSizeStep) {

207 res = rewriter.createvector::InsertStridedSliceOp(

209 strides);

210 }

211 }

212 return res;

213 }

214

215

216

219

220

221

222 LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,

225

226

228

229

230

232

233

234

235

236

238 Type elementType,

239 std::optional dimPermutation = std::nullopt) const {

242 if (dimPermutation.has_value()) {

244 applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);

245 scalableDims =

246 applyPermutationMap(*dimPermutation, scalableVecDims);

247 } else {

248 vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());

249 scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());

250 }

251

253 }

254

255

256

257

258

261 std::optional maybeIndexingMap = std::nullopt);

262

263 private:

264

265

266 void initIterSpaceStaticSizes(LinalgOp linalgOp) {

267 iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());

268 }

269

270

271

272

273 LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,

274 LinalgOp linalgOp);

275

276

277

278

279

281 LinalgOp linalgOp,

282 std::optional maybeMaskingMap);

283

284

285

286

287 bool isValidMaskingMap(AffineMap maskingMap) {

289 }

290

291

292

293

294

295

296

297

298

299

300

301

302

303

304

305

308 }

309

310

311

313

314

315

316

318

319

321

322

323

325

326

327

329

330

331

333 };

334

335 LogicalResult

336 VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,

337 LinalgOp linalgOp) {

338

339 for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {

340 if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {

341

342 iterSpaceValueSizes.push_back(rewriter.createarith::ConstantIndexOp(

343 linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));

344 continue;

345 }

346

347

348

350 unsigned operandDimPos;

351 if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,

352 operandDimPos)))

353 return failure();

354

355 Value dynamicDim = linalgOp.hasPureTensorSemantics()

357 linalgOp.getLoc(), operand, operandDimPos)

359 linalgOp.getLoc(), operand, operandDimPos);

360 iterSpaceValueSizes.push_back(dynamicDim);

361 }

362

363 return success();

364 }

365

366

367

368

369 LogicalResult

373

375

376 if (!inputVectorSizes.empty()) {

377

378

379

380 canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());

381 scalableVecDims.append(inputScalableVecDims.begin(),

382 inputScalableVecDims.end());

383 } else {

384

385

386

387 canonicalVecShape = linalgOp.getStaticLoopRanges();

388 scalableVecDims.append(linalgOp.getNumLoops(), false);

389 }

390

391 LDBG("Canonical vector shape: ");

392 LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));

393 LLVM_DEBUG(llvm::dbgs() << "\n");

394 LDBG("Scalable vector dims: ");

395 LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));

396 LLVM_DEBUG(llvm::dbgs() << "\n");

397

398 if (ShapedType::isDynamicShape(canonicalVecShape))

399 return failure();

400

401

402 initIterSpaceStaticSizes(linalgOp);

403

404

405

406

407 if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))

408 return failure();

409

410 return success();

411 }

412

413

414

415

416

417 Value VectorizationState::getOrCreateMaskFor(

419 std::optional maybeMaskingMap) {

420

421 assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&

422 "Ill-formed masking map.");

423

424

425 auto maskableOp = dyn_castvector::MaskableOpInterface(opToMask);

426 if (!maskableOp)

428

429 assert(!maskableOp.isMasked() &&

430 "Masking an operation that is already masked");

431

432

433 assert((!maybeMaskingMap || *maybeMaskingMap) &&

434 "Unexpected null mask permutation map");

436 maybeMaskingMap ? *maybeMaskingMap

438 linalgOp.getNumLoops(), rewriter.getContext());

439

440 LDBG("Masking map: " << maskingMap << "\n");

441

442

443

444 auto activeMaskIt = activeMaskCache.find(maskingMap);

445 if (activeMaskIt != activeMaskCache.end()) {

446 Value mask = activeMaskIt->second;

447 LDBG("Reusing mask: " << mask << "\n");

448 return mask;

449 }

450

451

452

453

454

455

456

458 applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);

459 auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);

460 auto maskShape = maskType.getShape();

461

462 LDBG("Mask shape: ");

463 LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));

464 LLVM_DEBUG(llvm::dbgs() << "\n");

465

466 if (permutedStaticSizes == maskShape) {

467 LDBG("Masking is not needed for masking map: " << maskingMap << "\n");

468 activeMaskCache[maskingMap] = Value();

470 }

471

472

475 assert(!maskShape.empty() && !upperBounds.empty() &&

476 "Masked 0-d vectors are not supported yet");

477

478

479 Value mask = rewriter.createvector::CreateMaskOp(linalgOp.getLoc(),

480 maskType, upperBounds);

481 LDBG("Creating new mask: " << mask << "\n");

482 activeMaskCache[maskingMap] = mask;

483 return mask;

484 }

485

488 LinalgOp linalgOp,

489 std::optional maybeIndexingMap) {

490 LDBG("Trying to mask: " << *opToMask << "\n");

491

492 std::optional maybeMaskingMap = std::nullopt;

493 if (maybeIndexingMap)

494 maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);

495

496

498 getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);

499

500 if (!mask) {

501 LDBG("No mask required\n");

502 return opToMask;

503 }

504

505

506 assert(opToMask && "Expected a valid operation to mask");

507 auto maskOp = castvector::MaskOp(

509 Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();

510

513 maskOpTerminator);

514

515 LDBG("Masked operation: " << *maskOp << "\n");

516 return maskOp;

517 }

518

519

520

521

522

523

524

525

526

527

528

529

530

531

532

533

534

535

538 "expected projected permutation");

540 assert(res.getNumDims() ==

541 (res.getNumResults() - res.getNumOfZeroResults()) &&

542 "expected reindexed map with same number of dims and results");

543 return res;

544 }

545

546

548 W,

549 Ncw,

550 Nwc

551 };

552

553

554

556

558

560

561

563

564

567

569

570

572 };

573

574 std::optionalvector::CombiningKind

576 using ::mlir::vector::CombiningKind;

577

578 if (!combinerOp)

579 return std::nullopt;

581 .Case<arith::AddIOp, arith::AddFOp>(

582 [&](auto op) { return CombiningKind::ADD; })

583 .Casearith::AndIOp([&](auto op) { return CombiningKind::AND; })

584 .Casearith::MaxSIOp([&](auto op) { return CombiningKind::MAXSI; })

585 .Casearith::MaxUIOp([&](auto op) { return CombiningKind::MAXUI; })

586 .Casearith::MaximumFOp([&](auto op) { return CombiningKind::MAXIMUMF; })

587 .Casearith::MaxNumFOp([&](auto op) { return CombiningKind::MAXNUMF; })

588 .Casearith::MinSIOp([&](auto op) { return CombiningKind::MINSI; })

590 .Casearith::MinimumFOp([&](auto op) { return CombiningKind::MINIMUMF; })

591 .Casearith::MinNumFOp([&](auto op) { return CombiningKind::MINNUMF; })

592 .Case<arith::MulIOp, arith::MulFOp>(

593 [&](auto op) { return CombiningKind::MUL; })

594 .Casearith::OrIOp([&](auto op) { return CombiningKind::OR; })

595 .Casearith::XOrIOp([&](auto op) { return CombiningKind::XOR; })

596 .Default([&](auto op) { return std::nullopt; });

597 }

598

599

600

601

602

603

604

605

607 auto linalgOp = cast(outputOperand->getOwner());

608 unsigned outputPos =

609 outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();

610

612 if (matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||

613 combinerOps.size() != 1)

614 return nullptr;

615

616

617 return combinerOps[0];

618 }

619

620

621

623 auto dstVecType = dyn_cast(dstType);

624

625 if (dstVecType.getRank() == 0)

626 return value;

629 return value;

631 return b.createOrFoldvector::BroadcastOp(loc, dstVecType, value);

632 }

633

634

635

636

637

638

643 assert(maybeKind && "Failed precondition: could not get reduction kind");

644 return b.createvector::MultiDimReductionOp(

645 reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);

646 }

647

649 return llvm::to_vector(

651 }

652

653

654

656 return isalinalg::ReduceOp(op) ||

657 (isalinalg::GenericOp(op) &&

659 }

660

661

662

663

664

665

666

671 auto linalgOp = cast(outputOperand->getOwner());

672 AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);

673

674

675

676

677

681 return llvm::is_contained(opOperandMap.getResults(), dimExpr);

682 });

683 auto vectorType = state.getCanonicalVecType(

685

687 if (vectorType.getRank() > 0) {

690 rewriter.createarith::ConstantIndexOp(loc, 0));

692 assert(value.getType() == vectorType && "Incorrect type");

693 write = rewriter.createvector::TransferWriteOp(

694 loc, value, outputOperand->get(), indices, writeMap);

695 } else {

696

697 if (!isa(value.getType()))

698 value = rewriter.createvector::BroadcastOp(loc, vectorType, value);

699 assert(value.getType() == vectorType && "Incorrect type");

700 write = rewriter.createvector::TransferWriteOp(

702 }

703

704 write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);

705

706

707

708 if (auto maskOp = dyn_castvector::MaskingOpInterface(write)) {

709 auto maskedWriteOp = castvector::TransferWriteOp(maskOp.getMaskableOp());

710 SmallVector inBounds(maskedWriteOp.getVectorType().getRank(), true);

711 maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));

712 }

713

714 LDBG("vectorized op: " << *write << "\n");

718 }

719

720

721

722

724 std::function<LogicalResult(Operation *, bool)>;

725

726

727

728

731

732

733

734

735

736

737

738

743 auto yieldOp = dyn_castlinalg::YieldOp(op);

744 if (!yieldOp)

746 for (const auto &output : llvm::enumerate(yieldOp.getValues())) {

747

748

749 Value vectorValue = bvm.lookup(output.value());

750 Value newResult =

752 linalgOp.getDpsInitOperand(output.index()), state);

753 if (newResult)

754 newResults.push_back(newResult);

755 }

756

758 }

759

760

761

762

763

767 LinalgOp linalgOp) {

768 IndexOp indexOp = dyn_castlinalg::IndexOp(op);

769 if (!indexOp)

771 auto loc = indexOp.getLoc();

772

774 auto dim = indexOp.getDim();

775

776 auto indexVectorType =

778 state.getScalableVecDims()[dim]);

779 auto indexSteps = rewriter.createvector::StepOp(loc, indexVectorType);

780

781

782

783 if (dim == targetShape.size() - 1)

785

786

787

788 auto permPattern =

789 llvm::to_vector(llvm::seq(0, targetShape.size()));

790 std::swap(permPattern[dim], permPattern.back());

791 auto permMap =

793

794 auto broadCastOp = rewriter.createvector::BroadcastOp(

795 loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),

796 indexSteps);

798 llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));

799 std::swap(transposition.back(), transposition[dim]);

800 auto transposeOp =

801 rewriter.createvector::TransposeOp(loc, broadCastOp, transposition);

803 }

804

805

806

807 static LogicalResult

809 tensor::ExtractOp extractOp = dyn_casttensor::ExtractOp(op);

810 if (!extractOp)

811 return failure();

812

813 if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)

814 return failure();

815

816

817

818 if (not extractOp.getIndices().empty()) {

819 if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))

820 return failure();

821 }

822

823 if (!llvm::all_of(extractOp->getResultTypes(),

824 VectorType::isValidElementType)) {

825 return failure();

826 }

827

828 return success();

829 }

830

831

832

833

834

835

836

837

838

839

840

843 tensor::ExtractOp extractOp,

845

846 auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());

847 auto loc = extractOp.getLoc();

848

850 rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);

851

852 const size_t numIndices = extractOp.getIndices().size();

853 for (size_t i = 1; i < numIndices; i++) {

854 Value dimIdx = rewriter.createarith::ConstantIndexOp(loc, i);

855

857 rewriter,

858 rewriter.createtensor::DimOp(loc, extractOp.getTensor(), dimIdx),

859 indexVecType);

860

861 offset = rewriter.createarith::MulIOp(loc, offset, dimSize);

862

864 rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);

865

866 offset = rewriter.createarith::AddIOp(loc, extractOpIndex, offset);

867 }

868

869 return offset;

870 }

871

873

874

875

876

877

878

879

880

881

882

883

884

885

886

887

888

891 assert(

892 (linalgOp.hasDynamicShape() ||

893 llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&

894 "For statically shaped Linalg Ops, only one "

895 "non-unit loop dim is expected");

896 assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");

897

898 size_t idx = loopRanges.size() - 1;

899 for (; idx != 0; idx--)

900 if (loopRanges[idx] != 1)

901 break;

902

903 return idx;

904 }

905

906

908 VectorType resType) {

909

910 assert(((llvm::count_if(resType.getShape(),

911 [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&

912 "n-D vectors are not yet supported");

913

914

915

916

917

918 auto *block = linalgOp.getBlock();

919 if (isa(val))

920 return llvm::all_of(block->getArguments(),

921 [&val](Value v) { return (v != val); });

922

924 assert(defOp && "This is neither a block argument nor an operation result");

925

926

927

928

929 if (auto indexOp = dyn_castlinalg::IndexOp(defOp)) {

930 return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;

931 }

932

933 auto *ancestor = block->findAncestorOpInBlock(*defOp);

934

935

936 if (!ancestor)

937 return true;

938

939

940 if (isaarith::ConstantOp(ancestor))

941 return true;

942

943 bool result = true;

944 for (auto op : ancestor->getOperands())

946

947 return result;

948 }

949

950

951

952

953

954

955

956

957

958

959

960

961

962

963

964

965

966

968 bool &foundIndexOp, VectorType resType) {

969

970 assert(((llvm::count_if(resType.getShape(),

971 [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&

972 "n-D vectors are not yet supported");

973

974

975

976

977

978 auto *block = linalgOp.getBlock();

979 if (isa(val))

980 return llvm::all_of(block->getArguments(),

981 [&val](Value v) { return (v != val); });

982

984 assert(defOp && "This is neither a block argument nor an operation result");

985

986 if (auto indexOp = dyn_castlinalg::IndexOp(defOp)) {

988

989 foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);

990 return true;

991 }

992

993 auto *ancestor = block->findAncestorOpInBlock(*defOp);

994

995 if (!ancestor)

996 return false;

997

998

999

1000 if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))

1001 return false;

1002

1003 bool result = false;

1004 for (auto op : ancestor->getOperands())

1006

1007 return result;

1008 }

1009

1010

1011

1012

1013

1014

1015

1016

1017

1018

1019

1020

1021

1024 LinalgOp &linalgOp, VectorType resType) {

1025

1026 auto inputShape = cast(extractOp.getTensor().getType());

1027

1028

1029 if (inputShape.getShape().empty())

1031

1032

1033

1034 bool isOutput1DVector =

1035 (llvm::count_if(resType.getShape(),

1036 [](int64_t dimSize) { return dimSize > 1; }) == 1);

1037

1038 if (!isOutput1DVector)

1040

1041 bool leadingIdxsLoopInvariant = true;

1042

1043

1044

1045

1046

1047 auto indices = extractOp.getIndices();

1048 auto leadIndices = indices.drop_back(1);

1049

1050 for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {

1051 if (inputShape.getShape()[i] == 1)

1052 continue;

1053

1054 leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);

1055 }

1056

1057 if (!leadingIdxsLoopInvariant) {

1058 LDBG("Found gather load: " << extractOp);

1060 }

1061

1062

1063

1064

1065

1066 auto extractOpTrailingIdx = indices.back();

1067

1068

1069

1070 if (leadingIdxsLoopInvariant &&

1072 LDBG("Found scalar broadcast load: " << extractOp);

1073

1075 }

1076

1077

1078

1079

1080

1081 bool foundIndexOp = false;

1082 bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,

1083 foundIndexOp, resType);

1084

1085

1086 bool isRowVector = resType.getShape().back() != 1;

1087 isContiguousLoad &= (foundIndexOp && isRowVector);

1088

1089 if (isContiguousLoad) {

1090 LDBG("Found contigous load: " << extractOp);

1092 }

1093

1094

1095 LDBG("Found gather load: " << extractOp);

1097 }

1098

1099

1100

1101

1102

1106 tensor::ExtractOp extractOp = dyn_casttensor::ExtractOp(op);

1107 if (!extractOp)

1109 auto loc = extractOp.getLoc();

1110

1111

1112 auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());

1113 auto maskConstantOp = rewriter.createarith::ConstantOp(

1114 loc,

1116 true));

1117 auto passThruConstantOp =

1118 rewriter.createarith::ConstantOp(loc, rewriter.getZeroAttr(resultType));

1119

1120

1121

1123 extractOp.getIndices().size(),

1124 rewriter.createarith::ConstantIndexOp(loc, 0));

1125

1128

1129

1132

1133

1134 Operation *gatherOp = rewriter.createvector::GatherOp(

1135 loc, resultType, extractOp.getTensor(), baseIndices, offset,

1136 maskConstantOp, passThruConstantOp);

1137 gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);

1138

1139 LDBG("Vectorised as gather load: " << extractOp << "\n");

1141 }

1142

1143

1144

1145

1146

1147

1148

1149

1150

1151

1152

1153

1154

1155

1156

1157

1158

1159

1160

1162 for (size_t i = 0; i < extractOp.getIndices().size(); i++) {

1163 Value idx = bvm.lookup(extractOp.getIndices()[i]);

1165 transferReadIdxs.push_back(idx);

1166 continue;

1167 }

1168

1169 auto indexAs1dVector = rewriter.createvector::ShapeCastOp(

1170 loc,

1172 resultType.getScalableDims().back()),

1173 idx);

1174 transferReadIdxs.push_back(

1175 rewriter.createvector::ExtractOp(loc, indexAs1dVector, 0));

1176 }

1177

1178

1179 auto dstRank = resultType.getRank();

1180 auto srcRank = extractOp.getTensor().getType().getRank();

1182

1183

1187 auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);

1188

1189 auto transferReadOp = rewriter.createvector::TransferReadOp(

1190 loc, resultType, extractOp.getTensor(), transferReadIdxs,

1191 permutationMap, inBounds);

1192

1193

1194

1195

1198 auto allTrue = rewriter.createvector::ConstantMaskOp(

1200 auto *maskedReadOp =

1202

1203 LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");

1205 }

1206

1207

1210

1211 int32_t rankDiff = dstRank - srcRank;

1212

1213

1214

1215

1216

1217

1218

1219 while (rankDiff > 0) {

1220 permutationMap = permutationMap.insertResult(

1222 rankDiff--;

1223 }

1224

1225 auto transferReadOp = rewriter.createvector::TransferReadOp(

1226 loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,

1227 inBounds);

1228

1229 LDBG("Vectorised as contiguous load: " << extractOp);

1231 }

1232

1233

1234

1235

1236

1238 Value reduceValue, Value initialValue,

1240 Value reduceVec = bvm.lookup(reduceValue);

1241 Value outputVec = bvm.lookup(initialValue);

1242 auto reduceType = dyn_cast(reduceVec.getType());

1243 auto outputType = dyn_cast(outputVec.getType());

1244

1245

1246 if (!reduceType ||

1247 (outputType && reduceType.getShape() == outputType.getShape()))

1248 return nullptr;

1251 }

1252

1253

1254

1255

1256

1257

1258

1259

1260

1261

1262

1263

1264

1265

1266

1267

1268

1269

1270

1271

1276 LDBG("vectorize op " << *op << "\n");

1277

1278

1279 if (!customVectorizationHooks.empty()) {

1280 for (auto &customFunc : customVectorizationHooks) {

1283 continue;

1284 return result;

1285 }

1286 }

1287

1288

1289

1290 if (isa<arith::ConstantOp, func::ConstantOp>(op))

1292

1293

1296

1297

1300 auto blockArg = dyn_cast(operand);

1301 if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||

1302 blockArg.getArgNumber() < linalgOp.getNumDpsInputs())

1303 continue;

1306 linalgOp.getRegionOutputArgs(),

1307 blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);

1308 if (!reduceValue)

1309 continue;

1310 reductionOperands.push_back(std::make_pair(reduceValue, operand));

1311 }

1312 if (!reductionOperands.empty()) {

1313 assert(reductionOperands.size() == 1);

1315 reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,

1316 reductionOperands[0].second, bvm);

1317 if (reduceOp)

1319 }

1320

1321

1322

1323 VectorType firstMaxRankedType;

1325 auto vecOperand = bvm.lookup(operand);

1326 assert(vecOperand && "Vector operand couldn't be found");

1327

1328 auto vecType = dyn_cast(vecOperand.getType());

1329 if (vecType && (!firstMaxRankedType ||

1330 firstMaxRankedType.getRank() < vecType.getRank()))

1331 firstMaxRankedType = vecType;

1332 }

1333

1336 Value vecOperand = bvm.lookup(scalarOperand);

1337 assert(vecOperand && "Vector operand couldn't be found");

1338

1339 if (firstMaxRankedType) {

1340 auto vecType = VectorType::get(firstMaxRankedType.getShape(),

1342 firstMaxRankedType.getScalableDims());

1343 vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));

1344 } else {

1345 vecOperands.push_back(vecOperand);

1346 }

1347 }

1348

1351 resultTypes.push_back(

1352 firstMaxRankedType

1353 ? VectorType::get(firstMaxRankedType.getShape(), resultType,

1354 firstMaxRankedType.getScalableDims())

1355 : resultType);

1356 }

1357

1361 resultTypes, op->getAttrs())};

1362 }

1363

1364

1365

1366

1367

1368

1369

1370

1371

1372

1373

1374

1375

1376

1377

1378

1379

1380

1381

1382

1383

1384

1385

1386 static LogicalResult

1388 LinalgOp linalgOp,

1390 LDBG("Vectorizing operation as linalg generic\n");

1391 Block *block = linalgOp.getBlock();

1392

1393

1394

1398 bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());

1399

1400 if (linalgOp.getNumDpsInits() == 0)

1401 return failure();

1402

1403

1404 Location loc = linalgOp.getLoc();

1405 Value zero = rewriter.createarith::ConstantIndexOp(loc, 0);

1406 for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {

1407 BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);

1408 if (linalgOp.isScalar(opOperand)) {

1409 bvm.map(bbarg, opOperand->get());

1410 continue;

1411 }

1412

1413

1414

1415 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);

1416

1418 VectorType readType;

1420 if (linalgOp.isDpsInput(opOperand)) {

1421

1423 readType = state.getCanonicalVecType(elemType);

1424 } else {

1425

1426

1427

1429 readType =

1430 state.getCanonicalVecType(elemType, readMap.compose(indexingMap));

1431 }

1432

1433 SmallVector indices(linalgOp.getShape(opOperand).size(), zero);

1434

1435 Operation *read = rewriter.createvector::TransferReadOp(

1436 loc, readType, opOperand->get(), indices, readMap);

1437 read = state.maskOperation(rewriter, read, linalgOp, indexingMap);

1439

1440

1441

1442 if (auto maskOp = dyn_castvector::MaskingOpInterface(read)) {

1444 castvector::TransferReadOp(maskOp.getMaskableOp())

1446 }

1447

1448

1449

1450 if (readType.getRank() == 0)

1453

1455 << "\n");

1458 }

1459

1461

1464 return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);

1465 };

1466 hooks.push_back(vectorizeYield);

1467

1468

1472 };

1473 hooks.push_back(vectorizeIndex);

1474

1475

1479 };

1480 hooks.push_back(vectorizeExtract);

1481

1482

1485 vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);

1487 LDBG("failed to vectorize: " << op << "\n");

1488 return failure();

1489 }

1492 state.maskOperation(rewriter, result.newOp, linalgOp);

1493 LDBG("New vector op: " << *maybeMaskedOp << "\n");

1495 }

1496 }

1497

1498 return success();

1499 }

1500

1501

1502

1506 }

1507

1508

1509

1510

1511

1512

1513

1514

1515

1516

1517

1518

1519

1520

1521

1522

1523

1524

1525

1526

1527

1528

1529

1530

1531

1532

1533

1534

1535

1536

1537

1538

1539

1540

1541

1542

1543

1544

1545

1546

1547

1548

1549

1550

1551

1552

1553

1558

1559 if (ShapedType::isDynamicShape(destShape))

1560 return false;

1561

1562

1566 cstMaskSizes.push_back(*intSize);

1567 }

1568 }

1569

1570

1571 if (cstMaskSizes.size() != maskShape.size())

1572 return false;

1573

1574

1577 APSInt intVal;

1579 cstWriteIdxs.push_back(intVal.getSExtValue());

1580 }

1581 }

1582

1583

1584 if (cstWriteIdxs.size() != destShape.size())

1585 return false;

1586

1587

1588

1589

1590

1591

1592

1593 int64_t rankDiff = destShape.size() - cstMaskSizes.size();

1595 if ( maskShape[i] > destShape[rankDiff + i] ||

1596 destShape[rankDiff + i] <

1597 (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +

1598 cstWriteIdxs[i]))

1599 return false;

1600 }

1601

1602 return true;

1603 }

1604

1605

1606

1607

1608

1609

1610

1611

1612

1613

1614

1615

1616

1617

1618

1619

1620

1621

1622

1623

1624

1625

1626

1627

1628

1629

1633 bool useInBoundsInsteadOfMasking = false) {

1634

1635 ShapedType destType = cast(dest.getType());

1636 int64_t destRank = destType.getRank();

1637 auto destShape = destType.getShape();

1638

1639 VectorType vecToStoreType = cast(vecToStore.getType());

1640 int64_t vecToStoreRank = vecToStoreType.getRank();

1641 auto vecToStoreShape = vecToStoreType.getShape();

1642

1643

1645 if (useInBoundsInsteadOfMasking) {

1646

1647

1648 for (unsigned i = 0; i < vecToStoreRank; i++)

1649 inBoundsVal[i] =

1650 (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&

1651 !ShapedType::isDynamic(destShape[destRank - vecToStoreRank + i]);

1652 }

1653

1654

1655 assert(writeIndices.empty() ||

1656 writeIndices.size() == static_cast<size_t>(destRank) &&

1657 "Invalid number of write indices!");

1658 if (writeIndices.empty()) {

1659 auto zero = builder.createarith::ConstantIndexOp(loc, 0);

1660 writeIndices.assign(destRank, zero);

1661 }

1662

1663

1665 builder.createvector::TransferWriteOp(loc,

1666 vecToStore,

1667 dest,

1668 writeIndices,

1669 inBoundsVal);

1670

1671

1672 if (useInBoundsInsteadOfMasking)

1673 return write;

1674

1675

1676 if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))

1677 return write;

1678

1679

1681

1685 destSizes.end());

1686

1688 vecToStoreShape))

1689 return write;

1690

1691 Value maskForWrite =

1692 builder.createOrFoldvector::CreateMaskOp(loc, writeMaskType, maskSizes);

1694 }

1695

1696

1697

1698

1699

1700

1701

1702

1703

1704

1705

1706

1707

1708

1709

1710

1711

1712

1713

1714

1715

1716

1717

1718

1719

1720

1721

1722

1723

1724

1725

1726

1727

1728

1729

1730 static LogicalResult

1734

1737

1738 Location loc = packOp.getLoc();

1739 auto padValue = packOp.getPaddingValue();

1740 if (!padValue) {

1741 padValue = rewriter.createarith::ConstantOp(

1742 loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));

1743 }

1745 LogicalResult status =

1746 cast(packOp.getOperation())

1747 .reifyResultShapes(rewriter, reifiedReturnShapes);

1748 (void)status;

1749 assert(succeeded(status) && "failed to reify result shapes");

1750

1751

1752

1753

1754 bool useInBoundsInsteadOfMasking = false;

1755 if (inputVectorSizes.empty()) {

1756 ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();

1757 inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());

1758 useInBoundsInsteadOfMasking = true;

1759 }

1760

1761

1763 auto innerTiles = packOp.getStaticInnerTiles();

1764 auto innerDimsPos = packOp.getInnerDimsPos();

1772 rewriter, loc, packOp.getSource(), inputShape, padValue,

1773 useInBoundsInsteadOfMasking);

1774

1775

1779 packOp.getDestType().getElementType());

1780 auto shapeCastOp =

1781 rewriter.createvector::ShapeCastOp(loc, tiledPackType, maskedRead);

1782

1783

1784 auto destPermutation =

1786 auto transposeOp = rewriter.createvector::TransposeOp(

1787 loc, shapeCastOp.getResult(), destPermutation);

1788

1789

1790 Value dest = rewriter.createtensor::EmptyOp(

1791 loc, reifiedReturnShapes[0],

1792 transposeOp.getResult().getType().getElementType());

1795 newResults.push_back(write->getResult(0));

1796 return success();

1797 }

1798

1799

1800

1801

1802

1803

1804

1805

1806

1807

1808 static LogicalResult

1812

1813

1816

1817 RankedTensorType unpackTensorType = unpackOp.getSourceType();

1818

1822 bool useInBoundsInsteadOfMasking = false;

1824

1825 auto destSize = unpackOp.getDestRank();

1826

1827 if (!inputVectorSizes.empty())

1828 assert(inputVectorSizes.size() == destSize &&

1829 "Incorrect number of input vector sizes");

1830

1831

1832

1833

1834

1835

1836

1837

1838

1840 if (vectorSizes.empty()) {

1841 llvm::append_range(vectorSizes, sourceShape.take_front(destSize));

1846

1847 useInBoundsInsteadOfMasking = true;

1848 }

1849

1850

1851

1852

1853

1854

1855

1856

1857

1858

1859

1860

1861

1862

1863

1864

1865

1866

1867

1868

1869 SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());

1870

1872 readVectorSizes[innerDimPos[index]] =

1873 llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);

1874 }

1877 }

1878 readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),

1879 sourceShape.end());

1880

1882 LogicalResult status =

1883 cast(unpackOp.getOperation())

1884 .reifyResultShapes(rewriter, reifiedRetShapes);

1885 if (status.failed()) {

1886 LDBG("Unable to reify result shapes of " << unpackOp);

1887 return failure();

1888 }

1889 Location loc = unpackOp->getLoc();

1890

1891 auto padValue = rewriter.createarith::ConstantOp(

1892 loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));

1893

1894

1895

1897 rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,

1898 false);

1899

1900 PackingMetadata packMetadata;

1903 ShapedType maskedOpShapedType = cast(readResult.getType());

1905 mlir::Type stripMineElemType = maskedOpShapedType.getElementType();

1907 RankedTensorType stripMineTensorType =

1909

1910 vector::TransposeOp transposeOp = rewriter.createvector::TransposeOp(

1911 loc, readResult, lastDimToInsertPosPerm);

1912

1913

1914 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(

1915 stripMineTensorType, packMetadata.reassociations);

1916 mlir::VectorType vecCollapsedType =

1917 VectorType::get(collapsedType.getShape(), collapsedType.getElementType());

1918 vector::ShapeCastOp shapeCastOp = rewriter.createvector::ShapeCastOp(

1919 loc, vecCollapsedType, transposeOp->getResult(0));

1920

1921

1922

1924 unpackOp.getDestType().hasStaticShape()

1925 ? vectorSizes

1926 : shapeCastOp.getResultVectorType().getShape());

1927 Value dest = rewriter.createtensor::EmptyOp(

1928 loc, reifiedRetShapes[0],

1929 shapeCastOp.getResult().getType().getElementType());

1931 rewriter, loc, shapeCastOp.getResult(), dest,

1932 {}, useInBoundsInsteadOfMasking);

1933 newResults.push_back(write->getResult(0));

1934 return success();

1935 }

1936

1937

1938

1939

1940 static LogicalResult

1944 auto padValue = padOp.getConstantPaddingValue();

1945 Location loc = padOp.getLoc();

1946

1947

1950

1952 LogicalResult status =

1953 cast(padOp.getOperation())

1954 .reifyResultShapes(rewriter, reifiedReturnShapes);

1955 (void)status;

1956 assert(succeeded(status) && "failed to reify result shapes");

1958 rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,

1959 false);

1960

1961

1962 Value dest = rewriter.createtensor::EmptyOp(

1963 loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());

1965 newResults.push_back(write->getResult(0));

1966 return success();

1967 }

1968

1969

1970

1973 LDBG("reduction precondition failed: no reduction iterator\n");

1974 return failure();

1975 }

1976 for (OpOperand &opOperand : op.getDpsInitsMutable()) {

1977 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);

1979 continue;

1980

1983 LDBG("reduction precondition failed: reduction detection failed\n");

1984 return failure();

1985 }

1986 }

1987 return success();

1988 }

1989

1990 static LogicalResult

1992 bool flatten1DDepthwiseConv) {

1993 if (flatten1DDepthwiseConv) {

1994 LDBG("Vectorization of flattened convs with dynamic shapes is not "

1995 "supported\n");

1996 return failure();

1997 }

1998

1999 if (!isalinalg::DepthwiseConv1DNwcWcOp(conv)) {

2000 LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");

2001 return failure();

2002 }

2003

2004

2005

2006 Value lhs = conv.getDpsInputOperand(0)->get();

2008 auto shapeWithoutCh = lhsShape.drop_back(1);

2009 if (ShapedType::isDynamicShape(shapeWithoutCh)) {

2010 LDBG("Dynamically-shaped op vectorization precondition failed: only "

2011 "channel dim can be dynamic\n");

2012 return failure();

2013 }

2014

2015 return success();

2016 }

2017

2018 static LogicalResult

2020 bool flatten1DDepthwiseConv) {

2021 if (isa(op.getOperation()))

2023

2026

2027

2028

2030 !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(

2031 op.getOperation()))

2032 return failure();

2033

2034 LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");

2035 return success();

2036 }

2037

2038

2039 static LogicalResult

2042

2043 if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {

2044 return !getConstantIntValue(res).has_value();

2045 })) {

2046 LDBG("Inner-tiles must be constant: " << unpackOp << "\n");

2047 return failure();

2048 }

2049 ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();

2050 bool satisfyEmptyCond = inputVectorSizes.empty() &&

2051 unpackOp.getDestType().hasStaticShape() &&

2052 unpackOp.getSourceType().hasStaticShape();

2053 if (!satisfyEmptyCond &&

2055 return failure();

2056

2057 return success();

2058 }

2059

2060 static LogicalResult

2063

2065 auto sourceType = source.getType();

2066 if (!VectorType::isValidElementType(sourceType.getElementType()))

2067 return failure();

2068

2069

2070

2071

2072

2073

2074

2075

2076

2077

2078

2079

2080

2082 bool isOutOfBoundsRead =

2083 !sourceType.hasStaticShape() && inputVectorSizes.empty();

2084

2085 if (!padValue && isOutOfBoundsRead) {

2086 LDBG("Failed to get a pad value for out-of-bounds read access\n");

2087 return failure();

2088 }

2089 return success();

2090 }

2091

2092 namespace {

2093 enum class ConvOperationKind { Conv, Pool };

2094 }

2095

2097 return isa(op) && op->getNumOperands() == 1 &&

2098 isa(op->getOperand(0));

2099 }

2100

2101

2102

2103

2104

2105

2106

2107

2108

2109

2110

2111 static std::optional

2113 int numBlockArguments =

2114 llvm::count_if(reduceOp->getOperands(), llvm::IsaPred);

2115

2116 switch (numBlockArguments) {

2117 case 1: {

2118

2119

2120

2121

2122 auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),

2123 llvm::IsaPred);

2124 assert(feedValIt != reduceOp->operand_end() &&

2125 "Expected a non-block argument operand");

2126 Operation *feedOp = (*feedValIt).getDefiningOp();

2128 return ConvOperationKind::Pool;

2129 }

2130

2131 if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||

2132 (isaarith::AndIOp(feedOp) &&

2135 if (isa(v))

2136 return true;

2137 if (Operation *op = v.getDefiningOp())

2138 return isCastOfBlockArgument(op);

2139 return false;

2140 }))) {

2141 return std::nullopt;

2142 }

2143

2144 return ConvOperationKind::Conv;

2145 }

2146 case 2:

2147

2148 return ConvOperationKind::Pool;

2149 default:

2150 return std::nullopt;

2151 }

2152 }

2153

2155 switch (kind) {

2156 case vector::CombiningKind::ADD:

2157 case vector::CombiningKind::MAXNUMF:

2158 case vector::CombiningKind::MAXIMUMF:

2159 case vector::CombiningKind::MAXSI:

2160 case vector::CombiningKind::MAXUI:

2161 case vector::CombiningKind::MINNUMF:

2162 case vector::CombiningKind::MINIMUMF:

2163 case vector::CombiningKind::MINSI:

2165 return true;

2166 default:

2167 return false;

2168 }

2169 }

2170

2172 auto getOperandType = [&](auto operand) {

2173 return dyn_cast((operand->get()).getType());

2174 };

2175 ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));

2176 ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));

2177 ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));

2178

2179

2180

2181 if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&

2182 (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))

2183 return failure();

2184

2186 if (!reduceOp)

2187 return failure();

2188

2190 if (!maybeOper.has_value())

2191 return failure();

2192

2194

2195

2196

2197 if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&

2198 *maybeKind != vector::CombiningKind::OR) &&

2199 (*maybeOper != ConvOperationKind::Pool ||

2201 return failure();

2202 }

2203

2204 auto rhsRank = rhsShapedType.getRank();

2205 if (*maybeOper == ConvOperationKind::Pool) {

2206 if (rhsRank != 1)

2207 return failure();

2208 } else {

2209 if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)

2210 return failure();

2211 }

2212

2213 return success();

2214 }

2215

2218 bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {

2219

2220 if (llvm::is_contained(linalgOp.getStaticShape(), 0))

2221 return failure();

2222

2223 if (!inputVectorSizes.empty() &&

2225 inputVectorSizes)))

2226 return failure();

2227

2229 linalgOp, flatten1DDepthwiseConv))) {

2230 LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");

2231 return failure();

2232 }

2233

2235

2236

2238

2239

2241

2242 if (llvm::any_of(

2243 customPreconditions,

2245 return succeeded(

2246 customPrecondition(&innerOp, vectorizeNDExtract));

2247 })) {

2248 continue;

2249 }

2250 if (!llvm::all_of(innerOp.getOperandTypes(),

2251 VectorType::isValidElementType)) {

2252 return failure();

2253 }

2254 if (!llvm::all_of(innerOp.getResultTypes(),

2255 VectorType::isValidElementType)) {

2256 return failure();

2257 }

2258 }

2260 return success();

2261

2262

2263

2264

2265 if (isa(linalgOp.getOperation()))

2267

2268

2269

2270

2272 LDBG("precondition failed: not projected permutations\n");

2273 return failure();

2274 }

2276 LDBG("precondition failed: reduction preconditions\n");

2277 return failure();

2278 }

2279 return success();

2280 }

2281

2282 static LogicalResult

2285 auto padValue = packOp.getPaddingValue();

2288 LDBG("pad value is not constant: " << packOp << "\n");

2289 return failure();

2290 }

2291 ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();

2292 bool satisfyEmptyCond = true;

2293 if (inputVectorSizes.empty()) {

2294 if (!packOp.getDestType().hasStaticShape() ||

2295 !packOp.getSourceType().hasStaticShape())

2296 satisfyEmptyCond = false;

2297 }

2298

2299 if (!satisfyEmptyCond &&

2301 resultTensorShape.take_front(packOp.getSourceRank()),

2302 inputVectorSizes)))

2303 return failure();

2304

2305 if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {

2306 return !getConstantIntValue(v).has_value();

2307 })) {

2308 LDBG("inner_tiles must be constant: " << packOp << "\n");

2309 return failure();

2310 }

2311

2312 return success();

2313 }

2314

2315 static LogicalResult

2318 auto padValue = padOp.getConstantPaddingValue();

2319 if (!padValue) {

2320 LDBG("pad value is not constant: " << padOp << "\n");

2321 return failure();

2322 }

2323

2324 ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();

2326 inputVectorSizes)))

2327 return failure();

2328

2329

2330

2331

2332

2333

2334

2335

2336

2337

2338

2339

2340 if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) {

2341 Value padValue = en.value();

2342 unsigned pos = en.index();

2343 std::optional<int64_t> pad = getConstantIntValue(padValue);

2344 return (!pad.has_value() || pad.value() != 0) &&

2345 resultTensorShape[pos] != 1;

2346 })) {

2347 LDBG("low pad must all be zero for all non unit dims: " << padOp << "\n");

2348 return failure();

2349 }

2350

2351 return success();

2352 }

2353

2354

2355

2356 static LogicalResult

2360 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&

2361 "Number of input vector sizes and scalable dims doesn't match");

2362

2363 size_t numOfScalableDims =

2364 llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });

2365

2366 if (numOfScalableDims == 0)

2367 return success();

2368

2369 auto linalgOp = dyn_cast(op);

2370

2371

2372

2373 if (!linalgOp)

2374 return failure();

2375

2376

2377 if (numOfScalableDims > 2)

2378 return failure();

2379

2380

2381

2382

2383

2384

2385

2386

2387

2388

2389

2390

2391

2392

2393

2394

2395

2396

2397 bool seenNonUnitParallel = false;

2398 auto iterators = linalgOp.getIteratorTypesArray();

2400 int64_t idx = scalableFlags.size() - 1;

2401 while (!scalableFlags[idx]) {

2402 bool isNonUnitDim = (inputVectorSizes[idx] != 1);

2403 seenNonUnitParallel |=

2404 (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);

2405

2406 iterators.pop_back();

2407 scalableFlags.pop_back();

2408 --idx;

2409 }

2410

2411

2412 switch (iterators.back()) {

2413 case utils::IteratorType::reduction: {

2414

2415 if (iterators.size() != inputVectorSizes.size()) {

2416 LDBG("Non-trailing reduction dim requested for scalable "

2417 "vectorization\n");

2418 return failure();

2419 }

2420 if (isalinalg::MatmulOp(op) || isalinalg::MatmulTransposeAOp(op)) {

2421 LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "

2422 "is not supported\n");

2423 return failure();

2424 }

2425 break;

2426 }

2427 case utils::IteratorType::parallel: {

2428

2429 if (seenNonUnitParallel) {

2430 LDBG("Inner parallel dim not requested for scalable "

2431 "vectorization\n");

2432 return failure();

2433 }

2434 break;

2435 }

2436 }

2437

2438

2439

2440

2441

2442 if (numOfScalableDims == 2) {

2443

2444

2445

2446 if (iterators.back() == utils::IteratorType::reduction) {

2447 LDBG("Higher dim than the trailing reduction dim requested for scalable "

2448 "vectorization\n");

2449 return failure();

2450 }

2451 scalableFlags.pop_back();

2452 iterators.pop_back();

2453

2454 if (!scalableFlags.back() ||

2455 (iterators.back() != utils::IteratorType::parallel))

2456 return failure();

2457 }

2458

2459

2460

2461 if (linalgOp.hasUserDefinedMaps())

2462 return failure();

2463

2464

2465

2466 return success(isElementwise(linalgOp) || isalinalg::MatmulOp(op) ||

2467 isalinalg::MatmulTransposeAOp(op) ||

2468 isalinalg::DepthwiseConv1DNwcWcOp(op) ||

2470 }

2471

2474 ArrayRef inputScalableVecDims, bool vectorizeNDExtract,

2475 bool flatten1DDepthwiseConv) {

2476

2478 return failure();

2479

2481 inputScalableVecDims)))

2482 return failure();

2483

2485 .Caselinalg::LinalgOp([&](auto linalgOp) {

2487 vectorizeNDExtract,

2488 flatten1DDepthwiseConv);

2489 })

2490 .Casetensor::PadOp([&](auto padOp) {

2492 })

2493 .Caselinalg::PackOp([&](auto packOp) {

2495 })

2496 .Caselinalg::UnPackOp([&](auto unpackOp) {

2498 })

2499 .Casetensor::InsertSliceOp([&](auto sliceOp) {

2501 })

2502 .Default([](auto) { return failure(); });

2503 }

2504

2505

2508 auto toReplace = linalgOp.getBlock()->getOpsaffine::AffineApplyOp();

2509

2510 for (auto op : make_early_inc_range(toReplace)) {

2513 rewriter, op->getLoc(), op.getAffineMap().getResult(0),

2514 op.getOperands().take_front(op.getAffineMap().getNumDims()),

2515 op.getOperands().take_back(op.getAffineMap().getNumSymbols()));

2516 rewriter.replaceOp(op, expanded);

2517 }

2518 }

2519

2521 return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,

2522 tensor::InsertSliceOp>(op);

2523 }

2524

2525

2526

2527

2528

2529

2530

2534 bool vectorizeNDExtract,

2535 bool flatten1DDepthwiseConv) {

2536 LDBG("Attempting to vectorize:\n" << *op << "\n");

2537 LDBG("Input vector sizes: ");

2538 LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));

2539 LLVM_DEBUG(llvm::dbgs() << "\n");

2540 LDBG("Input scalable vector dims: ");

2541 LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));

2542 LLVM_DEBUG(llvm::dbgs() << "\n");

2543

2545 vectorizeNDExtract,

2546 flatten1DDepthwiseConv))) {

2547 LDBG("Vectorization pre-conditions failed\n");

2548 return failure();

2549 }

2550

2551

2553 if (auto linalgOp = dyn_castlinalg::LinalgOp(op)) {

2554 if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,

2555 inputScalableVecDims))) {

2556 LDBG("Vectorization state couldn't be initialized\n");

2557 return failure();

2558 }

2559 }

2560

2562 auto vectorizeResult =

2564 .Caselinalg::LinalgOp([&](auto linalgOp) {

2565

2566

2567

2568 if (isa(linalgOp.getOperation())) {

2570 rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,

2571 flatten1DDepthwiseConv);

2572 if (succeeded(convOr)) {

2573 llvm::append_range(results, (*convOr)->getResults());

2574 return success();

2575 }

2576

2577 LDBG("Unsupported convolution can't be vectorized.\n");

2578 return failure();

2579 }

2580

2581 LDBG("Vectorize generic by broadcasting to the canonical vector "

2582 "shape\n");

2583

2584

2586

2587

2588

2589

2590

2591

2593 })

2594 .Casetensor::PadOp([&](auto padOp) {

2596 results);

2597 })

2598 .Caselinalg::PackOp([&](auto packOp) {

2600 results);

2601 })

2602 .Caselinalg::UnPackOp([&](auto unpackOp) {

2604 inputVectorSizes, results);

2605 })

2606 .Casetensor::InsertSliceOp([&](auto sliceOp) {

2608 results);

2609 })

2610 .Default([](auto) { return failure(); });

2611

2612 if (failed(vectorizeResult)) {

2613 LDBG("Vectorization failed\n");

2614 return failure();

2615 }

2616

2617 if (!results.empty())

2618 rewriter.replaceOp(op, results);

2619 else

2621

2622 return success();

2623 }

2624

2626 memref::CopyOp copyOp) {

2627 auto srcType = cast(copyOp.getSource().getType());

2628 auto dstType = cast(copyOp.getTarget().getType());

2629 if (!srcType.hasStaticShape() || !dstType.hasStaticShape())

2630 return failure();

2631

2634 if (!VectorType::isValidElementType(srcElementType) ||

2635 !VectorType::isValidElementType(dstElementType))

2636 return failure();

2637

2638 auto readType = VectorType::get(srcType.getShape(), srcElementType);

2639 auto writeType = VectorType::get(dstType.getShape(), dstElementType);

2640

2641 Location loc = copyOp->getLoc();

2644

2646 loc, readType, copyOp.getSource(), indices,

2648 if (cast(readValue.getType()).getRank() == 0) {

2652 }

2653 Operation *writeValue = rewriter.createvector::TransferWriteOp(

2654 loc, readValue, copyOp.getTarget(), indices,

2657 return success();

2658 }

2659

2660

2661

2662

2663

2664

2665 template

2668

2672

2673 for (auto *user : llvm::to_vector<4>(padOp->getUsers()))

2674 if (auto op = dyn_cast(user))

2675 changed |= rewriteUser(rewriter, padOp, op).succeeded();

2676 return success(changed);

2677 }

2678

2679 protected:

2681 tensor::PadOp padOp, OpTy op) const = 0;

2682 };

2683

2684

2685

2686

2687

2688

2689

2690

2691

2692

2693

2694

2695

2696

2697

2698

2699

2700

2701

2702

2707

2709 vector::TransferReadOp xferOp) const override {

2710

2711 if (!padOp.hasZeroLowPad())

2712 return failure();

2713

2714 auto padValue = padOp.getConstantPaddingValue();

2715 if (!padValue)

2716 return failure();

2717

2718 if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())

2719 return failure();

2720

2722 SmallVector inBounds(xferOp.getVectorType().getRank(), false);

2723 xferOp->setAttr(xferOp.getInBoundsAttrName(),

2725 xferOp.getBaseMutable().assign(padOp.getSource());

2726 xferOp.getPaddingMutable().assign(padValue);

2727 });

2728

2729 return success();

2730 }

2731 };

2732

2733

2734

2735

2736

2737

2738

2739

2740

2741

2742

2743

2744

2745

2746

2747

2748

2749

2750

2751

2752

2753

2754

2755

2756

2757

2758

2759

2760

2761

2762

2763

2764

2769

2771 vector::TransferWriteOp xferOp) const override {

2772

2773 if (xferOp.getTransferRank() == 0)

2774 return failure();

2775

2776

2777 if (!padOp.hasZeroLowPad())

2778 return failure();

2779

2780 auto padValue = padOp.getConstantPaddingValue();

2781 if (!padValue)

2782 return failure();

2783

2784 if (!xferOp->hasOneUse())

2785 return failure();

2786 auto trimPadding = dyn_casttensor::ExtractSliceOp(*xferOp->user_begin());

2787 if (!trimPadding)

2788 return failure();

2789

2790 if (!trimPadding.hasZeroOffset())

2791 return failure();

2792

2793 if (!hasSameTensorSize(padOp.getSource(), trimPadding))

2794 return failure();

2795

2796

2798

2799 SmallVector inBounds(xferOp.getVectorType().getRank(), false);

2800 auto newXferOp = rewriter.replaceOpWithNewOpvector::TransferWriteOp(

2801 xferOp, padOp.getSource().getType(), xferOp.getVector(),

2802 padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),

2804 rewriter.replaceOp(trimPadding, newXferOp->getResult(0));

2805

2806 return success();

2807 }

2808

2809

2810

2811

2812

2813

2814

2815

2816

2817

2818

2820 tensor::ExtractSliceOp afterTrimming) const {

2821

2822

2823 if (auto castOp = beforePadding.getDefiningOptensor::CastOp())

2824 if (hasSameTensorSize(castOp.getSource(), afterTrimming))

2825 return true;

2826

2827 auto t1 = dyn_cast(beforePadding.getType());

2828 auto t2 = dyn_cast(afterTrimming.getType());

2829

2830 if (!t1 || !t2)

2831 return false;

2832

2833 if (t1.getRank() != t2.getRank())

2834 return false;

2835

2836

2837

2838 for (unsigned i = 0; i < t1.getRank(); ++i) {

2839 if (t1.isDynamicDim(i) != t2.isDynamicDim(i))

2840 return false;

2841 if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))

2842 return false;

2843 }

2844

2845

2846 if (t1.getNumDynamicDims() == 0)

2847 return true;

2848

2849

2850

2851

2852

2853

2854 auto beforeSlice = beforePadding.getDefiningOptensor::ExtractSliceOp();

2855 if (!beforeSlice)

2856 return false;

2857

2858 assert(static_cast<size_t>(t1.getRank()) ==

2859 beforeSlice.getMixedSizes().size());

2860 assert(static_cast<size_t>(t2.getRank()) ==

2861 afterTrimming.getMixedSizes().size());

2862

2863 for (unsigned i = 0; i < t1.getRank(); ++i) {

2864

2865 if (!t1.isDynamicDim(i))

2866 continue;

2867 auto size1 = beforeSlice.getMixedSizes()[i];

2868 auto size2 = afterTrimming.getMixedSizes()[i];

2869

2870

2872 continue;

2873

2874

2875 auto v1 = llvm::dyn_cast_if_present(size1);

2876 auto v2 = llvm::dyn_cast_if_present(size2);

2877 if (!v1 || !v2)

2878 return false;

2879

2880

2881

2882 auto minOp1 = v1.getDefiningOpaffine::AffineMinOp();

2883 auto minOp2 = v2.getDefiningOpaffine::AffineMinOp();

2884 if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&

2885 minOp1.getOperands() == minOp2.getOperands())

2886 continue;

2887

2888

2889 }

2890

2891

2892 return true;

2893 }

2894 };

2895

2896

2897

2898

2899

2900

2901

2902

2903

2904

2906 if (!op)

2907 return {};

2908

2909

2910

2911 if (auto bcast = llvm::dyn_castvector::BroadcastOp(op)) {

2912 auto source = bcast.getSource();

2913 if (llvm::dyn_cast(source.getType()))

2914 return {};

2915

2916 return source;

2917 }

2918

2919

2920

2921 if (auto fill = llvm::dyn_castlinalg::FillOp(op)) {

2922 return fill.getInputs()[0];

2923 }

2924

2925

2926

2927 if (auto generate = llvm::dyn_casttensor::GenerateOp(op)) {

2928 return {};

2929 }

2930

2931

2932

2933

2934 if (auto xferWrite = llvm::dyn_castvector::TransferWriteOp(op))

2935 return getStaticPadVal(xferWrite.getVector().getDefiningOp());

2936

2937

2938

2939

2940

2941

2942 if (auto slice = llvm::dyn_casttensor::InsertSliceOp(op))

2943 return getStaticPadVal(slice.getDest().getDefiningOp());

2944

2945 return {};

2946 }

2947

2948 static LogicalResult

2952

2955

2957 auto sourceType = source.getType();

2958 auto resultType = sliceOp.getResultType();

2959

2961

2962 if (!padValue) {

2963 auto elemType = sourceType.getElementType();

2964 padValue = rewriter.createarith::ConstantOp(

2965 sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));

2966 }

2967

2968

2970 size_t rankDiff = resultType.getRank() - sourceType.getRank();

2971 for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {

2972 if (!inputVectorSizes.empty()) {

2973 vecShape.push_back(inputVectorSizes[i]);

2974 } else if (!sourceType.isDynamicDim(i)) {

2975 vecShape.push_back(sourceType.getDimSize(i));

2976 } else if (!resultType.isDynamicDim(i)) {

2977

2978

2979

2980

2981

2982 vecShape.push_back(resultType.getDimSize(rankDiff + i));

2983 } else {

2984

2985

2986 return failure();

2987 }

2988 }

2989 auto vecType = VectorType::get(vecShape, sourceType.getElementType());

2990

2991

2992 auto loc = sliceOp.getLoc();

2993

2994

2996 vecType.getRank(), rewriter.createarith::ConstantIndexOp(loc, 0));

2998 rewriter, loc, source, vecType.getShape(), padValue,

2999 inputVectorSizes.empty());

3000

3001

3002 auto writeIndices =

3006 writeIndices, inputVectorSizes.empty());

3007

3008

3009 newResults.push_back(write->getResult(0));

3010

3011 return success();

3012 }

3013

3014

3015

3016

3017

3018

3019

3020

3021

3022

3023

3024

3025

3026

3027

3028

3029

3030

3031

3032

3033

3034

3035

3036

3041

3043 tensor::InsertSliceOp insertOp) const override {

3044

3045 if (!padOp.hasZeroLowPad())

3046 return failure();

3047

3048 if (!insertOp.hasUnitStride())

3049 return failure();

3050

3051 auto padValue = padOp.getConstantPaddingValue();

3052 if (!padValue)

3053 return failure();

3054

3055 if (!cast(padOp.getResult().getType()).hasStaticShape())

3056 return failure();

3057

3058 if (insertOp.getDest() == padOp.getResult())

3059 return failure();

3060

3061 auto vecType = VectorType::get(padOp.getType().getShape(),

3062 padOp.getType().getElementType());

3063 unsigned vecRank = vecType.getRank();

3064 unsigned tensorRank = insertOp.getType().getRank();

3065

3066 // Check that the insert writes the entire tensor into the most minor dims

3067 // of the destination (no permutations allowed).

3068 SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);

3069 expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());

3070 if (!llvm::all_of(

3071 llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {

3072 return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);

3073 }))

3074 return failure();

3075

3076 // Insert the TransferReadOp and TransferWriteOp at the position of the

3077 // InsertSliceOp.

3078 rewriter.setInsertionPoint(insertOp);

3079

3080 // Generate TransferReadOp: read the entire source tensor and add high

3081 // padding.

3082 SmallVector<Value> readIndices(

3083 vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));

3084 auto read = rewriter.create<vector::TransferReadOp>(

3085 padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);

3086

3087

3088

3089 // Generate the in-bounds TransferWriteOp into InsertSliceOp's destination.

3090 auto writeIndices = getValueOrCreateConstantIndexOp(

3091 rewriter, padOp.getLoc(), insertOp.getMixedOffsets());

3092 SmallVector<bool> inBounds(vecRank, true);

3093 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(

3094 insertOp, read, insertOp.getDest(), writeIndices,

3095 ArrayRef<bool>{inBounds});

3096

3097 return success();

3098 }

3099 };
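
// Illustrative before/after sketch (hypothetical IR) for the pattern above:
//
//   %0 = tensor.pad %src low[0, 0] high[5, 7] { ... yield %cst } : ...
//   %1 = tensor.insert_slice %0 into %dest[%a, %b] [10, 12] [1, 1] : ...
//
// becomes a single in-bounds transfer_read (padded with %cst) followed by a
// transfer_write into %dest at [%a, %b], with no materialized pad tensor.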

3100

3101 void mlir::linalg::populatePadOpVectorizationPatterns(

3102 RewritePatternSet &patterns, PatternBenefit baseBenefit) {

3103 patterns.add<PadOpVectorizationWithTransferReadPattern,

3104 PadOpVectorizationWithTransferWritePattern,

3105 PadOpVectorizationWithInsertSlicePattern>(

3106 patterns.getContext(), baseBenefit.getBenefit() + 1);

3107 }

3108

3109

3110

3111

3112

3113

3114

3115

3116 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,

3117 ValueRange values) {

3118 if (firstOp->getBlock() != secondOp->getBlock() ||

3119 !firstOp->isBeforeInBlock(secondOp)) {

3120 LDBG("interleavedUses precondition failed, firstOp: "

3121 << *firstOp << ", second op: " << *secondOp << "\n");

3122 return true;

3123 }

3124 for (auto v : values) {

3125 for (auto &u : v.getUses()) {

3126 Operation *owner = u.getOwner();

3127 if (owner == firstOp || owner == secondOp)

3128 continue;

3129 // TODO: this is too conservative, use dominance info in the future.

3130 if (owner->getBlock() == firstOp->getBlock() &&

3131 (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))

3132 continue;

3133 LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp

3134 << ", second op: " << *secondOp << "\n");

3135 return true;

3136 }

3137 }

3138 return false;

3139 }
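
// Example (hypothetical IR): given `%sv = memref.subview %A ...;
// memref.copy %B, %sv; %v = vector.transfer_read %sv[...]`, any other user of
// %A or %sv sitting between the copy and the read makes
// mayExistInterleavedUses return true, which blocks the forwarding rewrites
// below.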

3140

3141

3142

3143 static memref::SubViewOp getSubViewUseIfUnique(Value v) {

3144 memref::SubViewOp subViewOp;

3145 for (auto &u : v.getUses()) {

3146 if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {

3147 if (subViewOp)

3148 return memref::SubViewOp();

3149 subViewOp = newSubViewOp;

3150 }

3151 }

3152 return subViewOp;

3153 }

3154

3155

3156

3157 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(

3158 vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {

3159

3160

3161 if (xferOp.getMask())

3162 return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

3163

3164

3165 Value viewOrAlloc = xferOp.getBase();

3166 if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&

3167 !viewOrAlloc.getDefiningOp<memref::AllocOp>())

3168 return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");

3169

3170 // Find the unique subview use of the view/alloc, if any.

3171 memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);

3172 if (!subViewOp)

3173 return rewriter.notifyMatchFailure(xferOp, "no subview found");

3174 Value subView = subViewOp.getResult();

3175

3176

3177 memref::CopyOp copyOp;

3178 for (auto &u : subView.getUses()) {

3179 if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {

3180 assert(isa<MemRefType>(newCopyOp.getTarget().getType()));

3181 if (newCopyOp.getTarget() != subView)

3182 continue;

3183 if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))

3184 continue;

3185 copyOp = newCopyOp;

3186 break;

3187 }

3188 }

3189 if (!copyOp)

3190 return rewriter.notifyMatchFailure(xferOp, "no copy found");

3191

3192

3193

3194 FillOp maybeFillOp;

3195 for (auto &u : viewOrAlloc.getUses()) {

3196 if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {

3197 assert(isa<MemRefType>(newFillOp.output().getType()));

3198 if (newFillOp.output() != viewOrAlloc)

3199 continue;

3200 if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc}))

3201 continue;

3202 maybeFillOp = newFillOp;

3203 break;

3204 }

3205 }

3206

3207 if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())

3208 return rewriter.notifyMatchFailure(xferOp,

3209 "padding value does not match fill");

3210

3211

3212 Value in = copyOp.getSource();

3213

3214

3215

3216

3217

3218 auto vectorType = xferOp.getVectorType();

3219 Value res = rewriter.create<vector::TransferReadOp>(

3220 xferOp.getLoc(), vectorType, in, xferOp.getIndices(),

3221 xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),

3222 rewriter.getBoolArrayAttr(

3223 SmallVector<bool>(vectorType.getRank(), false)));

3224

3225 if (maybeFillOp)

3226 rewriter.eraseOp(maybeFillOp);

3227 rewriter.eraseOp(copyOp);

3228 rewriter.replaceOp(xferOp, res);

3229

3230 return success();

3231 }
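
// Illustrative before/after sketch (hypothetical IR) for the read forwarding:
//
//   linalg.fill ins(%pad : f32) outs(%alloc : memref<...>)
//   %sv = memref.subview %alloc [...]
//   memref.copy %in, %sv
//   %v = vector.transfer_read %sv[...], %pad
//
// becomes
//
//   %v = vector.transfer_read %in[...], %pad
//
// with the fill and copy erased; the transfer's padding value re-creates the
// out-of-bounds values the fill used to provide.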

3232

3233

3234

3235 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(

3236 vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {

3237

3238 if (xferOp.getMask())

3239 return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

3240

3241

3242 Value viewOrAlloc = xferOp.getBase();

3243 if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&

3244 !viewOrAlloc.getDefiningOp<memref::AllocOp>())

3245 return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");

3246

3247 // Find the unique subview use of the view/alloc, if any.

3248 memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);

3249 if (!subViewOp)

3250 return rewriter.notifyMatchFailure(xferOp, "no subview found");

3251 Value subView = subViewOp.getResult();

3252

3253

3254 memref::CopyOp copyOp;

3255 for (auto &u : subViewOp.getResult().getUses()) {

3256 if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {

3257 if (newCopyOp.getSource() != subView)

3258 continue;

3259 if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))

3260 continue;

3261 copyOp = newCopyOp;

3262 break;

3263 }

3264 }

3265 if (!copyOp)

3266 return rewriter.notifyMatchFailure(xferOp, "no copy found");

3267

3268

3269 assert(isa<MemRefType>(copyOp.getTarget().getType()));

3270 Value out = copyOp.getTarget();

3271

3272

3273

3274

3275

3276

3277 auto vector = xferOp.getVector();

3278 rewriter.create<vector::TransferWriteOp>(

3279 xferOp.getLoc(), vector, out, xferOp.getIndices(),

3280 xferOp.getPermutationMapAttr(), xferOp.getMask(),

3281 rewriter.getBoolArrayAttr(SmallVector<bool>(

3282 dyn_cast<VectorType>(vector.getType()).getRank(), false)));

3283

3284 rewriter.eraseOp(copyOp);

3285 rewriter.eraseOp(xferOp);

3286

3287 return success();

3288 }
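
// The symmetric write-side sketch (hypothetical IR):
//
//   vector.transfer_write %v, %sv[...]
//   memref.copy %sv, %out
//
// becomes a single `vector.transfer_write %v, %out[...]`, with the copy
// erased. Note the in_bounds attribute is reset to all-false, since %out may
// be larger than the original subview.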

3289

3290

3291

3292

3293

3294 template <int N>

3295 static void bindShapeDims(ShapedType shapedType) {}

3296

3297 template <int N, typename IntTy, typename... IntTy2>

3298 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {

3299 val = shapedType.getShape()[N];

3300 bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);

3301 }

3302

3303

3304 template <typename... IntTy>

3305 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {

3306 bindShapeDims<0>(shapedType, vals...);

3307 }
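
// Usage sketch (assuming a 3-D shaped type): each reference is bound to the
// corresponding leading dimension of the shape, so after
//
//   int64_t n, w, c;
//   bindShapeDims(resShapedType, n, w, c);
//
// n == shape[0], w == shape[1] and c == shape[2].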

3308

3309 namespace {

3310

3311

3312

3313

3314

3315

3316

3317

3318

3319

3320

3321

3322

3323

3324

3325

3326

3327

3328

3329

3330

3331

3332

3333

3334

3335

3336

3337

3338

3339

3340

3341

3342

3343

3344 struct Conv1DGenerator

3345 : public StructuredGenerator<LinalgOp, utils::IteratorType> {

3346 Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)

3347 : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {

3348

3349 lhsShaped = linalgOp.getDpsInputOperand(0)->get();

3350 rhsShaped = linalgOp.getDpsInputOperand(1)->get();

3351 resShaped = linalgOp.getDpsInitOperand(0)->get();

3352 lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());

3353 rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());

3354 resShapedType = dyn_cast<ShapedType>(resShaped.getType());

3355

3356 Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));

3357 redOp = reduceOp->getName().getIdentifier();

3358

3359 setConvOperationKind(reduceOp);

3360

3361 auto maybeKind = getCombinerOpKind(reduceOp);

3362 reductionKind = maybeKind.value();

3363

3364

3365

3366

3367

3368 auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");

3369 auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");

3370 strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;

3371 dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;

3372 }

3373

3374

3375

3376

3377

3378

3379

3380

3381

3382

3383

3384

3385

3386

3387

3388

3389

3390

3391

3392 FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {

3393 int64_t nSize, wSize, cSize, kwSize, fSize;

3394 SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;

3395 bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);

3396 switch (conv1DOpOrder) {

3397 case Conv1DOpOrder::W:

3398 // Initialize unused dimensions.

3399 nSize = fSize = cSize = 0;

3400 // out{W}

3401 bindShapeDims(resShapedType, wSize);

3402 // kernel{kw}

3403 bindShapeDims(rhsShapedType, kwSize);

3404 lhsShape = {

3405

3406 (wSize + kwSize - 1)};

3407 rhsShape = {kwSize};

3408 resShape = {wSize};

3409 break;

3410 case Conv1DOpOrder::Nwc:

3411 // out{n, w, f}

3412 bindShapeDims(resShapedType, nSize, wSize, fSize);

3413 switch (oper) {

3414 case ConvOperationKind::Conv:

3415 // kernel{kw, c, f}

3416 bindShapeDims(rhsShapedType, kwSize, cSize, fSize);

3417 break;

3418 case ConvOperationKind::Pool:

3419 // kernel{kw}

3420 bindShapeDims(rhsShapedType, kwSize);

3421 cSize = fSize;

3422 break;

3423 }

3424 lhsShape = {nSize,

3425

3426

3427

3428 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -

3429 1,

3430 cSize};

3431 switch (oper) {

3432 case ConvOperationKind::Conv:

3433 rhsShape = {kwSize, cSize, fSize};

3434 break;

3435 case ConvOperationKind::Pool:

3436 rhsShape = {kwSize};

3437 break;

3438 }

3439 resShape = {nSize, wSize, fSize};

3440 break;

3441 case Conv1DOpOrder::Ncw:

3442 // out{n, f, w}

3443 bindShapeDims(resShapedType, nSize, fSize, wSize);

3444 switch (oper) {

3445 case ConvOperationKind::Conv:

3446

3447 bindShapeDims(rhsShapedType, fSize, cSize, kwSize);

3448 break;

3449 case ConvOperationKind::Pool:

3450 // kernel{kw}

3451 bindShapeDims(rhsShapedType, kwSize);

3452 cSize = fSize;

3453 break;

3454 }

3455 lhsShape = {nSize, cSize,

3456

3457

3458

3459 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -

3460 1};

3461 switch (oper) {

3462 case ConvOperationKind::Conv:

3463 rhsShape = {fSize, cSize, kwSize};

3464 break;

3465 case ConvOperationKind::Pool:

3466 rhsShape = {kwSize};

3467 break;

3468 }

3469 resShape = {nSize, fSize, wSize};

3470 break;

3471 }

3472

3473 vector::TransferWriteOp write;

3474 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

3475

3476

3477

3478

3479 int64_t wSizeStep = strideW == 1 ? wSize : 1;

3480

3481 Type lhsEltType = lhsShapedType.getElementType();

3482 Type rhsEltType = rhsShapedType.getElementType();

3483 Type resEltType = resShapedType.getElementType();

3484 VectorType lhsType = VectorType::get(lhsShape, lhsEltType);

3485 VectorType rhsType = VectorType::get(rhsShape, rhsEltType);

3486 VectorType resType = VectorType::get(resShape, resEltType);

3487 // Zero padding with the corresponding dimensions for lhs, rhs and res.

3488 SmallVector<Value> lhsPadding(lhsShape.size(), zero);

3489 SmallVector<Value> rhsPadding(rhsShape.size(), zero);

3490 SmallVector<Value> resPadding(resShape.size(), zero);

3491

3492

3493 Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,

3494 lhsPadding);

3495

3496 Value rhs = nullptr;

3497 if (oper == ConvOperationKind::Conv)

3498 rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,

3499 rhsPadding);

3500 Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,

3501 resPadding);

3502

3503

3504

3505

3506 switch (conv1DOpOrder) {

3507 case Conv1DOpOrder::W:

3508 case Conv1DOpOrder::Nwc:

3509 // Base case, no transposition needed.

3510 break;

3511 case Conv1DOpOrder::Ncw: {

3512 // To reuse the base {n, w, c} vectorization case, pre-transpose.

3513 // lhs{n, c, w} -> lhs{n, w, c}

3514 static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};

3515 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);

3516

3517 static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};

3518

3519

3520 if (oper == ConvOperationKind::Conv)

3521 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);

3522

3523 static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};

3524 res = rewriter.create<vector::TransposeOp>(loc, res, permRes);

3525 break;

3526 }

3527 }

3528

3529

3530

3531

3532 // Unroll along kw and read slices of lhs and rhs.

3533 SmallVector<Value> lhsVals =

3534 extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,

3535 kwSize, strideW, dilationW, wSizeStep,

3536 isSingleChanneled);

3537 SmallVector<Value> rhsVals;

3538 if (oper == ConvOperationKind::Conv)

3539 rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);

3540 SmallVector<Value> resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,

3541 wSizeStep, isSingleChanneled);

3542

3543 auto linearIndex = [&](int64_t kw, int64_t w) {

3544 return kw * (wSize / wSizeStep) + w;

3545 };

3546

3547

3548

3549

3550 for (int64_t kw = 0; kw < kwSize; ++kw) {

3551 for (int64_t w = 0; w < wSize; w += wSizeStep) {

3552 switch (oper) {

3553 case ConvOperationKind::Conv:

3554 if (isSingleChanneled) {

3555 resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,

3556 lhsVals[linearIndex(kw, w)],

3557 rhsVals[kw], resVals[w]);

3558 } else {

3559 resVals[w] = conv1dSliceAsContraction(rewriter, loc,

3560 lhsVals[linearIndex(kw, w)],

3561 rhsVals[kw], resVals[w]);

3562 }

3563 break;

3564 case ConvOperationKind::Pool:

3565 resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],

3566 resVals[w]);

3567 break;

3568 }

3569 }

3570 }

3571 // Write back res slices.

3572 res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,

3573 isSingleChanneled);

3574

3575

3576

3577

3578

3579

3580

3581 switch (conv1DOpOrder) {

3582 case Conv1DOpOrder::W:

3583 case Conv1DOpOrder::Nwc:

3584 // Base case, no transposition needed.

3585 break;

3586 case Conv1DOpOrder::Ncw: {

3587 // res{n, w, f} -> res{n, f, w}

3588 static constexpr std::array<int64_t, 3> perm = {0, 2, 1};

3589 res = rewriter.create<vector::TransposeOp>(loc, res, perm);

3590 break;

3591 }

3592 }

3593

3594 return rewriter

3595 .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)

3596 .getOperation();

3597 }
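
// Worked example of the unrolled kernel loop above (assuming W = 4, KW = 2,
// stride 1, dilation 1, single channel): each iteration contributes
//
//   res[w] += lhs[w + kw] * rhs[kw]   for kw in {0, 1}, w in {0, ..., 3}
//
// which is exactly the per-(kw, w) slice emitted by
// conv1dSliceAsOuterProduct / conv1dSliceAsContraction below.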

3598

3599 // Take a value and widen it to have the same element type as `ty`.

3600 Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {

3601 const Type srcElementType = getElementTypeOrSelf(val.getType());

3602 const Type dstElementType = getElementTypeOrSelf(ty);

3603 assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));

3604 if (srcElementType == dstElementType)

3605 return val;

3606

3607 const uint64_t srcWidth = srcElementType.getIntOrFloatBitWidth();

3608 const uint64_t dstWidth = dstElementType.getIntOrFloatBitWidth();

3609 const Type dstType =

3610 cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);

3611

3612 if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {

3613 return rewriter.create<arith::SIToFPOp>(loc, dstType, val);

3614 }

3615

3616 if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&

3617 srcWidth < dstWidth)

3618 return rewriter.create<arith::ExtFOp>(loc, dstType, val);

3619

3620 if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&

3621 srcWidth < dstWidth)

3622 return rewriter.create<arith::ExtSIOp>(loc, dstType, val);

3623

3624 assert(false && "unhandled promotion case");

3625 return nullptr;

3626 }
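
// Example (hypothetical values): promoting a vector<8xi8> operand to match a
// vector<8xi32> accumulator takes the integer branch above and emits
//
//   %w = arith.extsi %v : vector<8xi8> to vector<8xi32>
//
// while an i32 -> f32 mismatch would instead emit arith.sitofp.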

3627

3628 // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}

3629 Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,

3630 Value lhs, Value rhs, Value res) {

3631 vector::IteratorType par = vector::IteratorType::parallel;

3632 vector::IteratorType red = vector::IteratorType::reduction;

3633 AffineExpr n, w, f, c;

3634 bindDims(ctx, n, w, f, c);

3635 lhs = promote(rewriter, loc, lhs, res.getType());

3636 rhs = promote(rewriter, loc, rhs, res.getType());

3637 auto contractionOp = rewriter.create<vector::ContractionOp>(

3638 loc, lhs, rhs, res,

3639 /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},

3640 /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});

3641 contractionOp.setKind(reductionKind);

3642 return contractionOp;

3643 }

3644

3645 // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single-channel

3646 // convolution.

3647 Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,

3648 Value lhs, Value rhs, Value res) {

3649 return rewriter.create<vector::OuterProductOp>(

3650 loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);

3651 }

3652

3653 // Create a reduction: lhs{n, w, c} -> res{n, w, c}

3654 Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,

3655 Value res) {

3656 if (isPoolExt)

3658 return rewriter

3661 }

3662

3663

3664

3665

3666

3667

3668

3669

3670

3671

3672 FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,

3673 bool channelDimScalableFlag,

3674 bool flatten) {

3675 bool scalableChDim = false;

3676 bool useMasking = false;

3677 int64_t nSize, wSize, cSize, kwSize;

3678 // kernel{kw, c}

3679 bindShapeDims(rhsShapedType, kwSize, cSize);

3680 if (ShapedType::isDynamic(cSize)) {

3681 assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");

3682 cSize = channelDimVecSize;

3683

3684

3685

3686 scalableChDim = channelDimScalableFlag;

3687 useMasking = true;

3688 }

3689

3690 assert(!(useMasking && flatten) &&

3691 "Unsupported flattened conv with dynamic shapes");

3692

3693 // out{n, w, c}

3694 bindShapeDims(resShapedType, nSize, wSize);

3695

3696 vector::TransferWriteOp write;

3697 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

3698

3699

3700

3701

3702 int64_t wSizeStep = strideW == 1 ? wSize : 1;

3703

3704 Type lhsEltType = lhsShapedType.getElementType();

3705 Type rhsEltType = rhsShapedType.getElementType();

3706 Type resEltType = resShapedType.getElementType();

3707 VectorType lhsType = VectorType::get(

3708 {nSize,

3709 // iw = (ow * sw + kw * dw - 1) * c

3710 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)

3711 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,

3712 cSize},

3713 lhsEltType, /*scalableDims=*/{false, false, scalableChDim});

3714 VectorType rhsType =

3715 VectorType::get({kwSize, cSize}, rhsEltType, /*scalableDims=*/

3716 {false, scalableChDim});

3717 VectorType resType =

3718 VectorType::get({nSize, wSize, cSize}, resEltType, /*scalableDims=*/

3719 {false, false, scalableChDim});

3720

3721

3722 // Masks the created xfer op if masking is needed; returns the op otherwise.

3723 auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,

3724 ArrayRef<bool> scalableDims,

3725 Operation *opToMask) {

3726 if (!useMasking)

3727 return opToMask;

3728 auto maskType =

3729 VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);

3730

3731 SmallVector<bool> inBounds(maskShape.size(), true);

3732 auto xferOp = cast<VectorTransferOpInterface>(opToMask);

3733 xferOp->setAttr(xferOp.getInBoundsAttrName(),

3734 rewriter.getBoolArrayAttr(inBounds));

3735

3736 SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(

3737 cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);

3738

3739 Value maskOp =

3740 rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);

3741

3742 return mlir::vector::maskOperation(rewriter, opToMask, maskOp);

3743 };

3744

3745

3746

3747 Value lhs = rewriter.create<vector::TransferReadOp>(

3748 loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});

3749 auto maybeMaskedLhs = maybeMaskXferOp(

3750 lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());

3751

3752

3753 Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,

3754 ValueRange{zero, zero});

3755 auto maybeMaskedRhs = maybeMaskXferOp(

3756 rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());

3757

3758

3759 Value res = rewriter.create<vector::TransferReadOp>(

3760 loc, resType, resShaped, ValueRange{zero, zero, zero});

3761 auto maybeMaskedRes = maybeMaskXferOp(

3762 resType.getShape(), resType.getScalableDims(), res.getDefiningOp());

3763

3764

3765

3766

3767

3768 SmallVector<Value> lhsVals, rhsVals, resVals;

3769 auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize};

3770 auto inOutStrides = SmallVector<int64_t>{1, 1, 1};

3771

3772 // Extract lhs slices of size {n, wSizeStep, c}

3773 // @ [0, sw * w + dw * kw, 0].

3774 for (int64_t kw = 0; kw < kwSize; ++kw) {

3775 for (int64_t w = 0; w < wSize; w += wSizeStep) {

3776 lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(

3777 loc, maybeMaskedLhs->getResult(0),

3778 ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},

3779 inOutSliceSizes, inOutStrides));

3780 }

3781 }

3782

3783 for (int64_t kw = 0; kw < kwSize; ++kw) {

3784 rhsVals.push_back(rewriter.create<vector::ExtractOp>(

3785 loc, maybeMaskedRhs->getResult(0),

3786 /*offsets=*/ArrayRef<int64_t>{kw}));

3787 }

3788

3789 for (int64_t w = 0; w < wSize; w += wSizeStep) {

3790 resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(

3791 loc, maybeMaskedRes->getResult(0),

3792 /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,

3793 inOutStrides));

3794 }

3795

3796 auto linearIndex = [&](int64_t kw, int64_t w) {

3797 return kw * (wSize / wSizeStep) + w;

3798 };

3799

3800 // Note - the scalable flags are ignored as flattening combined with

3801 // scalable vectorization is not supported.

3802 auto inOutFlattenSliceSizes = SmallVector<int64_t>{nSize, wSizeStep * cSize};

3803 auto lhsTypeAfterFlattening =

3804 VectorType::get(inOutFlattenSliceSizes, lhsEltType);

3805 auto resTypeAfterFlattening =

3806 VectorType::get(inOutFlattenSliceSizes, resEltType);

3807

3808

3809 for (int64_t kw = 0; kw < kwSize; ++kw) {

3810 for (int64_t w = 0; w < wSize; w += wSizeStep) {

3811 Value lhsVal = lhsVals[linearIndex(kw, w)];

3812 Value resVal = resVals[w];

3813 if (flatten) {

3814

3815

3816 lhsVal = rewriter.create<vector::ShapeCastOp>(

3817 loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);

3818 resVal = rewriter.create<vector::ShapeCastOp>(

3819 loc, resTypeAfterFlattening, resVals[w]);

3820 }

3821 resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,

3822 rhsVals[kw], resVal, flatten);

3823 if (flatten) {

3824

3825 resVals[w] = rewriter.create<vector::ShapeCastOp>(

3826 loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);

3827 }

3828 }

3829 }

3830

3831

3832 if (!llvm::all_of(resVals, [](Value v) { return v; })) {

3833 // Manually revert the vector/insert_slice operations created so far.

3834 for (auto &collection :

3835 {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})

3836 for (Value v : collection)

3837 rewriter.eraseOp(v.getDefiningOp());

3838 return rewriter.notifyMatchFailure(op, "failed to create FMA");

3839 }

3840

3841

3842

3843 for (int64_t w = 0; w < wSize; w += wSizeStep) {

3844 maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(

3845 loc, resVals[w], maybeMaskedRes->getResult(0),

3846 /*offsets=*/ArrayRef<int64_t>{0, w, 0},

3847 /*strides=*/ArrayRef<int64_t>{1, 1, 1});

3848 }

3849

3850

3851

3852

3853

3854 Operation *resOut = rewriter.create<vector::TransferWriteOp>(

3855 loc, maybeMaskedRes->getResult(0), resShaped,

3856 ValueRange{zero, zero, zero});

3857 return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),

3858 resOut);

3859 }

3860

3861

3862

3863 // Multiply-accumulate one kw slice:

3864 // lhs{n, w, c} * rhs{c} -> res{n, w, c} (or the flattened variants).

3865 Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,

3866 Value lhs, Value rhs, Value res,

3867 bool flatten) {

3868 auto rhsTy = cast<ShapedType>(rhs.getType());

3869 auto resTy = cast<ShapedType>(res.getType());

3870

3871

3872 lhs = promote(rewriter, loc, lhs, resTy);

3873

3874 if (flatten) {

3875

3876

3877

3878

3879

3880

3881

3882

3883 auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];

3884 auto resSize = cast<VectorType>(res.getType()).getShape()[1];

3885

3886 SmallVector<int64_t, 16> indices;

3887 for (int i = 0; i < resSize / rhsSize; ++i) {

3888 for (int j = 0; j < rhsSize; ++j)

3889 indices.push_back(j);

3890 }

3891

3892 rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);

3893 }

3894 // Broadcast the filter to match the output vector.

3895 rhs = rewriter.create<vector::BroadcastOp>(

3896 loc, resTy.clone(rhsTy.getElementType()), rhs);

3897

3898 rhs = promote(rewriter, loc, rhs, resTy);

3899

3900 if (!lhs || !rhs)

3901 return nullptr;

3902

3903 if (isa<FloatType>(resTy.getElementType()))

3904 return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);

3905

3906 auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);

3907 return rewriter.create<arith::AddIOp>(loc, mul, res);

3908 }
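
// Worked example of the flattened-filter shuffle above (assuming rhsSize = 4,
// resSize = 8): the index vector becomes [0, 1, 2, 3, 0, 1, 2, 3], i.e. the
// c-sized filter is tiled resSize / rhsSize = 2 times so it lines up with the
// flattened {n, w * c} slice before the multiply-accumulate.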

3909

3910

3911

3912 FailureOr<Operation *> generateNonChanneledConv() {

3913 AffineExpr w, kw;

3914 bindDims(ctx, w, kw);

3915 if (!iters({Par(), Red()}))

3916 return rewriter.notifyMatchFailure(op,

3917 "failed to match conv::W 1-par 1-red");

3918

3919 // No transposition needed.

3920 if (layout({/*lhsIndex*/ {w + kw},

3921 /*rhsIndex*/ {kw},

3922 /*resIndex*/ {w}}))

3923 return conv(Conv1DOpOrder::W);

3924

3925 return rewriter.notifyMatchFailure(op, "not a conv::W layout");

3926 }

3927

3928

3929

3930 FailureOr<Operation *> generateNwcConv() {

3931 AffineExpr n, w, f, kw, c;

3932 bindDims(ctx, n, w, f, kw, c);

3933 if (!iters({Par(), Par(), Par(), Red(), Red()}))

3934 return rewriter.notifyMatchFailure(

3935 op, "failed to match conv::Nwc 3-par 2-red");

3936

3937 // No transposition needed.

3938 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},

3939 /*rhsIndex*/ {kw, c, f},

3940 /*resIndex*/ {n, w, f}}))

3941 return conv(Conv1DOpOrder::Nwc);

3942

3943 return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");

3944 }

3945

3946

3947

3948 FailureOr<Operation *> generateNcwConv() {

3949 AffineExpr n, w, f, kw, c;

3950 bindDims(ctx, n, f, w, c, kw);

3951 if (!iters({Par(), Par(), Par(), Red(), Red()}))

3952 return rewriter.notifyMatchFailure(

3953 op, "failed to match conv::Ncw 3-par 2-red");

3954

3955 if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},

3956 /*rhsIndex*/ {f, c, kw},

3957 /*resIndex*/ {n, f, w}}))

3958 return conv(Conv1DOpOrder::Ncw);

3959

3960 return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");

3961 }

3962

3963

3964

3965 FailureOr<Operation *> generateNwcPooling() {

3968 if (!iters({Par(), Par(), Par(), Red()}))

3970 "failed to match pooling 3-par 1-red");

3971

3972

3973 if (layout({ {n, strideW * w + dilationW * kw, c},

3974 {kw},

3975 {n, w, c}}))

3977

3978 return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");

3979 }

3980

3981

3982

3983 FailureOr<Operation *> generateNcwPooling() {

3984 AffineExpr n, w, c, kw;

3985 bindDims(ctx, n, c, w, kw);

3986 if (!iters({Par(), Par(), Par(), Red()}))

3987 return rewriter.notifyMatchFailure(op,

3988 "failed to match pooling 3-par 1-red");

3989

3990 if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},

3991 /*rhsIndex*/ {kw},

3992 /*resIndex*/ {n, c, w}}))

3993 return conv(Conv1DOpOrder::Ncw);

3994

3995 return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");

3996 }

3997

3998

3999

4000 FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,

4001 bool vecChDimScalableFlag = false,

4002 bool flatten = false) {

4003 AffineExpr n, w, c, kw;

4004 bindDims(ctx, n, w, c, kw);

4005 if (!iters({Par(), Par(), Par(), Red()}))

4006 return rewriter.notifyMatchFailure(

4007 op, "failed to match depthwise::Nwc conv 3-par 1-red");

4008

4009

4010 if (layout({ {n, strideW * w + dilationW * kw, c},

4011 {kw, c},

4012 {n, w, c}}))

4013 return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);

4014

4015 return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");

4016 }

4017

4018 private:

4019 ConvOperationKind oper = ConvOperationKind::Conv;

4020 StringAttr redOp;

4021 StringAttr poolExtOp;

4022 bool isPoolExt = false;

4023 int strideW, dilationW;

4024 Value lhsShaped, rhsShaped, resShaped;

4025 ShapedType lhsShapedType, rhsShapedType, resShapedType;

4026 vector::CombiningKind reductionKind;

4027

4028

4029 void setConvOperationKind(Operation *reduceOp) {

4030 int numBlockArguments =

4031 llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);

4032 if (numBlockArguments == 1) {

4033

4034

4035

4036

4037 auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),

4038 llvm::IsaPred<BlockArgument>);

4039 Operation *feedOp = (*feedValIt).getDefiningOp();

4040 if (isCastOfBlockArgument(feedOp)) {

4041 oper = ConvOperationKind::Pool;

4042 isPoolExt = true;

4043 poolExtOp = feedOp->getName().getIdentifier();

4044 return;

4045 }

4046 oper = ConvOperationKind::Conv;

4047 return;

4048 }

4049

4050 oper = ConvOperationKind::Pool;

4051 isPoolExt = false;

4052 }

4053 };

4054 } // namespace

4055

4056 /// Helper function to vectorize a LinalgOp with convolution semantics.

4057 static FailureOr<Operation *>

4058 vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,

4059 ArrayRef<int64_t> inputVecSizes,

4060 ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {

4061 Conv1DGenerator conv1dGen(rewriter, op);

4062 auto res = conv1dGen.generateNonChanneledConv();

4063 if (succeeded(res))

4064 return res;

4065 res = conv1dGen.generateNwcConv();

4066 if (succeeded(res))

4067 return res;

4068 res = conv1dGen.generateNcwConv();

4069 if (succeeded(res))

4070 return res;

4071 res = conv1dGen.generateNwcPooling();

4072 if (succeeded(res))

4073 return res;

4074 res = conv1dGen.generateNcwPooling();

4075 if (succeeded(res))

4076 return res;

4077

4078

4079

4080

4081 uint64_t vecChDimSize = ShapedType::kDynamic;

4082 bool vecChDimScalableFlag = false;

4083 if (!inputVecSizes.empty()) {

4084

4085

4086 assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||

4087 isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&

4088 "Not a 1D depthwise conv!");

4089 size_t chDimIdx =

4090 TypeSwitch<Operation *, size_t>(op)

4091 .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })

4092 .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });

4093

4094 vecChDimSize = inputVecSizes[chDimIdx];

4095 vecChDimScalableFlag = inputScalableVecDims[chDimIdx];

4096 }

4097 return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,

4098 flatten1DDepthwiseConv);

4099 }
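
// Dispatch sketch: the generators above are tried most-specific first. For a
// hypothetical linalg.conv_1d_nwc_wcf op, generateNonChanneledConv fails the
// iterator match, and generateNwcConv matches the
// {n, sw * w + dw * kw, c} / {kw, c, f} / {n, w, f} layout and emits the
// vector.contract form.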

4100

4101 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {

4102 using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

4103

4104 LogicalResult matchAndRewrite(LinalgOp op,

4105 PatternRewriter &rewriter) const override {

4106 FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);

4107 if (failed(resultOrFail))

4108 return failure();

4109 Operation *newOp = *resultOrFail;

4110 if (newOp->getNumResults() == 0) {

4111 rewriter.eraseOp(op.getOperation());

4112 return success();

4113 }

4114 assert(newOp->getNumResults() == 1 && "expected single result");

4115 rewriter.replaceOp(op.getOperation(), newOp->getResult(0));

4116 return success();

4117 }

4118 };

4119

4120 void mlir::linalg::populateConvolutionVectorizationPatterns(

4121 RewritePatternSet &patterns, PatternBenefit benefit) {

4122 patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);

4123 }
