MLIR: lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

10

24 #include

25

26 using namespace mlir;

28

29

30

31

32

33

34

40 for (auto result : indexingMap.getResults()) {

43 Value v = b.createaffine::AffineApplyOp(loc, m, ivs);

44 indices.push_back(v);

45 }

46 return indices;

47 }

48

49

50

53 Block *body = linalgOp.getBlock();

57 if (auto indexOp = dyn_cast(&op)) {

58 map.map(indexOp.getResult(), ivs[indexOp.getDim()]);

59 continue;

60 }

62 }

63

68 OpOperand *storeInto = linalgOp.getDpsInitOperand(operand.index());

70 b, loc, linalgOp.getMatchingIndexingMap(storeInto), ivs);

71 b.creatememref::StoreOp(

72 loc, toStore, linalgOp.getDpsInitOperand(operand.index())->get(),

73 indices);

74 }

75 return success();

76 }

77

78

79

80

81

82 namespace {

83

84

85

86

87

88 template

89 struct LinalgOpTilingInterface

90 : public TilingInterface::ExternalModel<LinalgOpTilingInterface,

91 LinalgOpTy> {

92

94 LinalgOpTy concreteOp = cast(op);

95 return concreteOp.getIteratorTypesArray();

96 }

97

98

103 LinalgOp linalgOp = cast(op);

105 linalgOp.createFlatListOfOperandDims(b, loc);

106 AffineMap map = linalgOp.getShapesToLoopsMap();

107

108 return llvm::to_vector(

110 OpFoldResult ofr = affine::makeComposedFoldedAffineApply(

111 b, loc, loopExpr, allShapesSizes);

112 return Range{b.getIndexAttr(0), ofr, b.getIndexAttr(1)};

113 }));

114 }

115

116

117 FailureOr

121

122

124 LinalgOp linalgOp = cast(op);

127 b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);

129 llvm::make_filter_range(

130 tiledOperands,

131 [](Value v) -> bool {

132 return isa_and_nonnull<tensor::ExtractSliceOp, memref::SubViewOp>(

134 }),

136

139

140 Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);

141 offsetIndices(b, cast(tiledOp), offsets);

142

145 }

146

147

148

149

150 void

151 getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,

156 unsigned numLoops = linalgOp.getNumLoops();

157 auto tilingInterfaceOp = cast(linalgOp.getOperation());

158 mappedOffsets.resize(numLoops);

159 mappedSizes.resize(numLoops);

162 tilingInterfaceOp.getIterationDomain(b);

163 for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {

164 mappedOffsets[index] = value.offset;

165 mappedSizes[index] = value.size;

166 }

167 }

168 for (const auto &&[index, value] :

170 unsigned dimPosition = cast(value).getPosition();

171 mappedOffsets[dimPosition] = offsets[index];

172 mappedSizes[dimPosition] = sizes[index];

173 }

174 }

175

176

177

178 LogicalResult getIterationDomainTileFromOperandTile(

183 auto linalgOp = cast(op);

184

185

186

187

188

190 linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));

193 << "unhandled get iter domain position when operand is not "

194 "accessed using a permuted projection";

195 }

196

197 getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,

198 iterDomainOffsets, iterDomainSizes);

199 return success();

200 }

201

202

203

204 LogicalResult

211 LinalgOp linalgOp = cast(op);

212

216 llvm::to_vector(llvm::map_range(sizes, [&](OpFoldResult ofr) {

218 }));

219

220 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);

222 b, loc, outOperand->get(), sizes,

223 linalgOp.getMatchingIndexingMap(outOperand), offsets,

224 {}, subShapeSizes, true);

225 resultOffsets = sliceParams.offsets;

226 resultSizes = sliceParams.sizes;

227 return success();

228 }

229

230 LogicalResult getIterationDomainTileFromResultTile(

235 auto linalgOp = cast(op);

236

237

238

239

240

242 linalgOp.getIndexingMapMatchingResult(op->getResult(resultNumber));

245 "unhandled tiled implementation generation when result is not "

246 "accessed using a permuted projection");

247 }

248

249 getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,

250 iterDomainOffsets, iterDomainSizes);

251 return success();

252 }

253

254 FailureOr

255 generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,

259 if (failed(getIterationDomainTileFromResultTile(

260 op, b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) {

261 return failure();

262 }

263 auto tilingInterfaceOp = cast(op);

264 FailureOr tilingResult =

265 tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);

266

267 if (failed(tilingResult))

268 return failure();

269

270 if (tilingResult->tiledOps.size() != 1)

271 return op->emitOpError("failed to generate tiled implementation");

272

276 tilingResult->generatedSlices};

277 }

278

279

280

281 FailureOr getTiledImplementationFromOperandTile(

285 if (failed(getIterationDomainTileFromOperandTile(

286 op, b, operandNumber, offsets, sizes, mappedOffsets,

287 mappedSizes))) {

288 return failure();

289 }

291 }

292

293 LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,

296 auto linalgOp = cast(op);

297 if (!linalgOp.hasPureBufferSemantics())

298 return op->emitOpError("expected operation to have buffer semantics");

299

301 indexedValues.reserve(linalgOp->getNumOperands());

303

304

305 for (OpOperand &operand : linalgOp->getOpOperands()) {

306 if (!linalgOp.payloadUsesValueFromOperand(&operand)) {

307 indexedValues.push_back(nullptr);

308 continue;

309 }

310 if (linalgOp.isScalar(&operand)) {

311 indexedValues.push_back(operand.get());

312 continue;

313 }

315 builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand), ivs);

317 builder.creatememref::LoadOp(linalgOpLoc, operand.get(), indices);

318 indexedValues.push_back(load);

319 }

320

321

322 return inlinePayload(builder, linalgOp, ivs, indexedValues);

323 }

324 };

325

326

327

328

329

330

331

332

333

334

335

336

337 static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,

339 unsigned resultNumber) {

341 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));

342 for (int redPos : reductionDims) {

345 }

346 return map;

347 }

348

349

350

351 template

352 struct LinalgOpPartialReductionInterface

353 : public PartialReductionOpInterface::ExternalModel<

354 LinalgOpPartialReductionInterface, LinalgOpTy> {

355 FailureOr<SmallVector> generateInitialTensorForPartialReduction(

358 auto linalgOp = cast(op);

360

361 if (linalgOp.hasPureBufferSemantics())

362 return op->emitOpError("expected operation to have tensor semantics");

363

364

365 auto tilingInterfaceOp = cast(linalgOp.getOperation());

367 llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b),

368 [](Range x) { return x.size; });

369

371 for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {

373 tiledShape.push_back(dimSize);

374 } else {

375 tiledShape.push_back(tileSize);

376 }

377 }

378

380 for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;

381 ++initIdx) {

383 if (matchReduction(linalgOp.getRegionOutputArgs(), initIdx,

384 combinerOps) ||

385 combinerOps.size() != 1)

386 return op->emitOpError("Failed to anaysis the reduction operation.");

387

388 Operation *reductionOp = combinerOps[0];

390 if (!identity.has_value())

392 "Failed to get an identity value for the reduction operation.");

393

394

396 getPartialResultAffineMap(linalgOp, reductionDims, initIdx);

399 auto dim = cast(dimExpr);

400 partialResultShape.push_back(tiledShape[dim.getPosition()]);

401 }

402

403 Type elType =

405 Value emptyTensor =

406 b.createtensor::EmptyOp(loc, partialResultShape, elType);

407 Value constantOp = b.createarith::ConstantOp(loc, *identity);

408 auto identityTensor =

409 b.createlinalg::FillOp(loc, constantOp, emptyTensor);

410 inits.push_back(identityTensor.getResult(0));

411 }

412

413 return inits;

414 }

415

416 FailureOr

422 auto linalgOp = cast(op);

423

424

425

427 newInitMaps.reserve(linalgOp.getNumDpsInits());

428 for (int idx : llvm::seq(0, linalgOp.getNumDpsInits())) {

429

430

432 getPartialResultAffineMap(linalgOp, reductionDims, idx);

433 newInitMaps.push_back(newMap);

434 }

435

436

438 b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true);

440 llvm::make_filter_range(

443

444

446 for (auto [valueMap, valueToTile] : llvm::zip_equal(newInitMaps, init)) {

447 int64_t initRank = valueMap.getNumResults();

451 for (AffineExpr dimExpr : valueMap.getResults()) {

452 auto dim = cast(dimExpr);

453 initSizes.push_back(sizes[dim.getPosition()]);

454 }

455

456 auto extractSlice = b.createtensor::ExtractSliceOp(

457 loc, valueToTile, initOffset, initSizes, initStride);

458 tiledInits.push_back(extractSlice);

459 generatedSlices.push_back(extractSlice);

460 }

461

462

464

465 for (int idx : llvm::seq(0, linalgOp.getNumDpsInits())) {

466

467

468 OpOperand *initOperand = linalgOp.getDpsInitOperand(idx);

469 int64_t mapIdx = linalgOp.getIndexingMapIndex(initOperand);

470 newMaps[mapIdx] = newInitMaps[idx];

471 }

472

473

475 linalgOp.getIteratorTypesArray();

476 for (int dim : reductionDims)

477 newIteratorTypes[dim] = utils::IteratorType::parallel;

478

479

480 auto genericOp =

482 tiledInits, newMaps, newIteratorTypes);

485 genericOp.getRegion().begin(), mapping);

487 {genericOp.getOperation()},

488 llvm::map_to_vector(genericOp->getResults(),

490 generatedSlices};

491 }

492

496 auto linalgOp = cast(op);

497

498

499

500 int64_t numInits = linalgOp.getNumDpsInits();

503 for (int idx : llvm::seq(numInits)) {

504

505

506

507

509 getPartialResultAffineMap(linalgOp, reductionDims, idx);

511 for (auto [resultNum, dimExpr] :

513 unsigned dim = cast(dimExpr).getPosition();

514 if (llvm::is_contained(reductionDims, dim)) {

515 partialReductionDims.push_back(resultNum);

516 }

517 }

518

519 Value partialResult = partialReduce[idx];

520 Value init = linalgOp.getDpsInits()[idx];

521

522 auto reduction = b.createlinalg::ReduceOp(

523 loc, partialResult, init, partialReductionDims,

525

527 matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);

528 Operation *clonedReductionOp = b.clone(*combinerOps[0]);

529

530 clonedReductionOp->setOperand(0, inputs[0]);

531 clonedReductionOp->setOperand(1, inputs[1]);

532 b.createlinalg::YieldOp(loc, clonedReductionOp->getResult(0));

533 });

534

535 mergeOperations.push_back(reduction);

536 replacements.push_back(reduction->getResult(0));

537 }

538

539 return MergeResult{mergeOperations, replacements};

540 }

541

542 LogicalResult getPartialResultTilePosition(

548 auto linalgOp = cast(op);

549

551 getPartialResultAffineMap(linalgOp, reductionDims, resultNumber);

553 unsigned dim = cast(dimExpr).getPosition();

554 resultSizes.push_back(sizes[dim]);

555

556 if (llvm::is_contained(reductionDims, dim)) {

557

558

560 } else {

561 resultOffsets.push_back(offsets[dim]);

562 }

563 }

564

565 return success();

566 }

567 };

568

569 template

572 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,

573 "applies to only pack or unpack operations");

575 int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()

576 : op.getDestRank();

582 for (auto dim : llvm::seq<int64_t>(0, rank)) {

583 loopBounds[dim].offset = zero;

584 loopBounds[dim].stride = one;

585 loopBounds[dim].size = resultShape[0][dim];

586 }

587 return loopBounds;

588 }

589

593 if (permutation.empty())

594 return;

595 applyPermutationToVector(offsets, permutation);

596 applyPermutationToVector(sizes, permutation);

597 }

598

599 struct PackOpTiling

600 : public TilingInterface::ExternalModel<PackOpTiling, linalg::PackOp> {

601

603

604

605

606 auto packOp = cast(op);

608 packOp.getSourceRank(), utils::IteratorType::parallel);

609 return iteratorTypes;

610 }

611

613 return getPackUnPackIterationDomain(cast(op), b);

614 }

615

616 FailureOr

620 auto packOp = cast(op);

621 Location loc = packOp.getLoc();

622

623

624

625 int64_t inputRank = packOp.getSourceRank();

628 applyPermToRange(origOffsets, origSizes,

630

632 packOp.getDimAndTileMapping();

636 for (auto dim : llvm::seq<int64_t>(0, inputRank)) {

642 if (dimAndTileMapping.count(dim)) {

643

644

645

646 auto avOffset = AV(dim0).bind(origOffsets[dim]);

647 auto avSize = AV(dim0).bind(origSizes[dim]);

648 auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);

649 inputIndices.push_back(ab.mul(avOffset, avTileSize));

650 inputSizes.push_back(ab.mul(avSize, avTileSize));

651 } else {

652 inputIndices.push_back(origOffsets[dim]);

653 inputSizes.push_back(origSizes[dim]);

654 }

655

656

657 if (packOp.getPaddingValue()) {

659 auto avDimSize = AV(dim0).bind(dimSize);

660 auto avInputIdx = AV(dim1).bind(inputIndices.back());

661 inputSizes.back() =

662 ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});

663 }

664 }

665

668

670 auto sourceSlice = b.createtensor::ExtractSliceOp(

671 loc, packOp.getSource(), inputIndices, inputSizes, strides);

672 tiledOperands.push_back(sourceSlice);

673

676 outputSizes)))

677 return {};

678

679 strides.append(packOp.getDestRank() - inputRank, oneAttr);

680 auto outSlice = b.createtensor::ExtractSliceOp(

681 loc, packOp.getDest(), outputOffsets, outputSizes, strides);

682 tiledOperands.push_back(outSlice);

683

684 if (auto val = packOp.getPaddingValue())

685 tiledOperands.push_back(val);

686 for (auto tile : packOp.getInnerTiles())

687 tiledOperands.push_back(tile);

688

690 loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());

691

693 {tiledPackOp},

696 }

697

698 LogicalResult

704

705

706

707

708 auto packOp = cast(op);

709 int64_t inputRank = packOp.getSourceRank();

710 int64_t outputRank = packOp.getDestRank();

712 resultOffsets.assign(offsets.begin(), offsets.end());

713 resultOffsets.append(outputRank - inputRank, zeroAttr);

714

717 resultSizes.assign(sizes.begin(), sizes.end());

718 for (auto dataTileDim : llvm::seq(inputRank, outputRank))

719 resultSizes.push_back(outputShape[0][dataTileDim]);

720

721 return success();

722 }

723

724 FailureOr

725 generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,

728 auto packOp = cast(op);

729 int64_t numTiles = packOp.getInnerDimsPos().size();

730

731

732

733

734 for (auto offset : offsets.take_back(numTiles))

736 return failure();

737

738 for (auto iter :

739 llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))

741 return failure();

742

744 op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));

745 if (failed(tilingResult))

746 return failure();

747 return tilingResult.value();

748 }

749

750

751

752

753 LogicalResult getIterationDomainTileFromOperandTile(

758 if (operandNumber != 0)

759 return failure();

760

761 auto packOp = cast(op);

762

763

764 if (packOp.getPaddingValue())

765 return failure();

766

767 Location loc = packOp.getLoc();

768

771 packOp.getDimAndTileMapping();

772 for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {

773 if (dimAndTileMapping.count(dim)) {

774 FailureOr<int64_t> cstSize =

777 nullptr, true);

778 std::optional<int64_t> cstInnerSize =

780

781

782

783

784

785

786

787

788

789

790

791

792

793

794 if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {

795 return failure();

796 }

797

803 auto avOffset = AV(dim0).bind(offsets[dim]);

804 auto avSize = AV(dim0).bind(sizes[dim]);

805 auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);

806 outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));

807 outerDimSizes.push_back(ab.ceil(avSize, avTileSize));

808 } else {

809 outerDimOffsets.push_back(offsets[dim]);

810 outerDimSizes.push_back(sizes[dim]);

811 }

812 }

813 applyPermToRange(outerDimOffsets, outerDimSizes, packOp.getOuterDimsPerm());

814 resultOffsets = outerDimOffsets;

815 resultSizes = outerDimSizes;

816 return success();

817 }

818

819

820 FailureOr getTiledImplementationFromOperandTile(

823 if (operandNumber != 0)

824 return failure();

825

826 auto packOp = cast(op);

827 Location loc = packOp.getLoc();

828

829 int64_t inputRank = packOp.getSourceRank();

832

834 auto sourceSlice = b.createtensor::ExtractSliceOp(

835 loc, packOp.getSource(), offsets, sizes, strides);

836 tiledOperands.push_back(sourceSlice);

837

839 if (failed(getIterationDomainTileFromOperandTile(

840 op, b, 0, offsets, sizes, outerDimOffsets,

841 outerDimSizes)))

842 return failure();

843

846 outputOffsets, outputSizes)))

847 return failure();

848

849 strides.append(packOp.getDestRank() - inputRank, oneAttr);

850 auto outSlice = b.createtensor::ExtractSliceOp(

851 loc, packOp.getDest(), outputOffsets, outputSizes, strides);

852 tiledOperands.push_back(outSlice);

853

854 assert(!packOp.getPaddingValue() && "Expect no padding semantic");

855 for (auto tile : packOp.getInnerTiles())

856 tiledOperands.push_back(tile);

857

859 loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());

860

862 {tiledPackOp},

865 }

866 };

867

868 struct UnpackTileDimInfo {

869 bool isAlignedToInnerTileSize;

874 };

875

876

877

878

879 static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,

880 int64_t tileDim,

883 UnpackTileDimInfo info;

887 unpackOp.getDimAndTileMapping();

888

889 if (!dimAndTileMapping.count(tileDim)) {

890 info.isAlignedToInnerTileSize = true;

891 info.sourceOffset = tileOffset;

892 info.sourceSize = tileSize;

893 info.resultOffset = zeroAttr;

894 info.destExpandedSize = tileSize;

895 return info;

896 }

897

898 Location loc = unpackOp.getLoc();

904

905 OpFoldResult innerTileSize = dimAndTileMapping[tileDim];

906

907 info.isAlignedToInnerTileSize = false;

910 nullptr, true);

911 std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);

912 if (!failed(cstSize) && cstInnerSize) {

913 if (*cstSize % *cstInnerSize == 0)

914 info.isAlignedToInnerTileSize = true;

915

916

917

918 if (*cstInnerSize == *cstSize) {

919 auto lhs = AV(dim0).bind(tileOffset);

920 auto rhs = AV(dim1).bind(innerTileSize);

921 info.sourceOffset = ab.floor(lhs, rhs);

922 info.sourceSize = oneAttr;

923 info.resultOffset = zeroAttr;

924 info.destExpandedSize = tileSize;

925 return info;

926 }

927 }

928

929 if (info.isAlignedToInnerTileSize) {

930 info.sourceOffset =

931 ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));

932 info.resultOffset = zeroAttr;

933 info.destExpandedSize = tileSize;

934

935

936

937

938

939

940

941 info.sourceSize =

942 ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));

943 return info;

944 }

945

950 ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));

952 b, loc,

954 b, loc,

955 ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),

957

959 AV(dim1).bind(firstCoord.quotient));

960 info.sourceSize =

961 ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));

962 info.sourceOffset = firstCoord.quotient;

963 info.resultOffset = firstCoord.remainder;

964

965

966 info.destExpandedSize = b.createOrFoldarith::MulIOp(

969 return info;

970 }

971

972 struct UnPackOpTiling

973 : public TilingInterface::ExternalModel<UnPackOpTiling, linalg::UnPackOp> {

974

976 auto unpackOp = cast(op);

978 unpackOp.getDestRank(), utils::IteratorType::parallel);

979 return iteratorTypes;

980 }

981

983 return getPackUnPackIterationDomain(cast(op), b);

984 }

985

986

987

988

989

990

991

992

993

994

995

996

997

998

999

1000 FailureOr

1004 auto unpackOp = cast(op);

1005 int64_t srcRank = unpackOp.getSourceRank();

1006 int64_t destRank = unpackOp.getDestRank();

1007 int64_t numInnerTiles = srcRank - destRank;

1008 Location loc = unpackOp.getLoc();

1009

1010

1011

1012

1013 bool isPerfectTilingCase = true;

1018 for (auto dim : llvm::seq<int64_t>(0, destRank)) {

1019 UnpackTileDimInfo info =

1020 getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);

1021 if (!info.isAlignedToInnerTileSize)

1022 isPerfectTilingCase = false;

1023 sliceSrcIndices.push_back(info.sourceOffset);

1024 sliceSrcSizes.push_back(info.sourceSize);

1025 destExpandedSizes.push_back(info.destExpandedSize);

1026 resultOffsetsFromDest.push_back(info.resultOffset);

1027 }

1028

1029

1030

1031 applyPermToRange(sliceSrcIndices, sliceSrcSizes,

1032 unpackOp.getOuterDimsPerm());

1034 sliceSrcIndices.append(numInnerTiles, zeroAttr);

1035 sliceSrcSizes.append(unpackOp.getMixedTiles());

1036 sliceSrcStrides.append(numInnerTiles, oneAttr);

1038 tensor::ExtractSliceOp sliceSource = b.createtensor::ExtractSliceOp(

1039 loc, unpackOp.getSource(), sliceSrcIndices, sliceSrcSizes,

1040 sliceSrcStrides);

1041 generatedSlices.push_back(sliceSource);

1042

1044 Value sliceDest;

1045 if (isPerfectTilingCase) {

1046 auto destSliceOp = b.createtensor::ExtractSliceOp(

1047 loc, unpackOp.getDest(), offsets, sizes, destStrides);

1048 sliceDest = destSliceOp;

1049 generatedSlices.push_back(destSliceOp);

1050 } else {

1051 sliceDest = b.createtensor::EmptyOp(

1052 loc, destExpandedSizes, unpackOp.getDestType().getElementType());

1053 }

1054

1055 SmallVector tiledOperands = {sliceSource.getResult(), sliceDest};

1056 for (auto tile : unpackOp.getInnerTiles())

1057 tiledOperands.push_back(tile);

1058

1060 loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());

1061

1062 if (isPerfectTilingCase)

1065 generatedSlices};

1066

1067 auto extractSlice = b.createtensor::ExtractSliceOp(

1068 loc, tiledUnpackOp->getResult(0), resultOffsetsFromDest, sizes,

1069 destStrides);

1071 {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};

1072 }

1073

1074 LogicalResult

1080 resultOffsets = llvm::to_vector(offsets);

1081 resultSizes = llvm::to_vector(sizes);

1082 return success();

1083 }

1084

1085 FailureOr

1086 generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,

1089 FailureOr tilingResult =

1091 if (failed(tilingResult))

1092 return failure();

1093 return tilingResult.value();

1094 }

1095

1096

1097

1098 LogicalResult getIterationDomainTileFromOperandTile(

1103 auto unPackOp = cast(op);

1104

1105 if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {

1106 resultOffsets = llvm::to_vector(offsets);

1107 resultSizes = llvm::to_vector(sizes);

1108 return success();

1109 }

1110 Location loc = unPackOp.getLoc();

1111

1112 int64_t numTiles = unPackOp.getInnerDimsPos().size();

1113 auto destOffsets = offsets.drop_back(numTiles);

1114 auto destSizes = sizes.drop_back(numTiles);

1115

1116

1117 int64_t outputRank = unPackOp.getDestRank();

1119 if (failed(reifyResultShapes(b, unPackOp, reifiedReturnShapes)))

1120 return failure();

1124 applyPermToRange(origOffsets, origSizes,

1126

1128 unPackOp.getDimAndTileMapping();

1129

1130 for (auto dim : llvm::seq<int64_t>(0, outputRank)) {

1136 if (dimAndTileMapping.count(dim)) {

1137

1138

1139

1140 auto avOffset = AV(dim0).bind(origOffsets[dim]);

1141 auto avSize = AV(dim0).bind(origSizes[dim]);

1142 auto avTileSize = AV(sym0).bind(dimAndTileMapping[dim]);

1143 auto avResultSize = AV(dim0).bind(outputMixedSizes[dim]);

1144 resultOffsets.push_back(ab.mul(avOffset, avTileSize));

1145 auto avResultOffset = AV(dim1).bind(resultOffsets.back());

1146 resultSizes.push_back(ab.min({ab.mul(avSize, avTileSize),

1147 ab.sub(avResultSize, avResultOffset)}));

1148 } else {

1149 resultOffsets.push_back(origOffsets[dim]);

1150 resultSizes.push_back(origSizes[dim]);

1151 }

1152 }

1153 return success();

1154 }

1155

1156

1157 FailureOr getTiledImplementationFromOperandTile(

1160 auto unPackOp = cast(op);

1161

1162

1163 int64_t numTiles = unPackOp.getInnerDimsPos().size();

1164 for (auto iter :

1165 llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) {

1167 return failure();

1168 }

1169

1170 Location loc = unPackOp.getLoc();

1171

1172

1173

1175 if (failed(getIterationDomainTileFromOperandTile(

1176 op, b, 0, offsets, sizes, outputOffsets,

1177 outputSizes)))

1178 return failure();

1179

1181 int64_t outputRank = unPackOp.getDestRank();

1183

1185

1186 auto extractDestSlice = b.createtensor::ExtractSliceOp(

1187 loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);

1188 tiledOperands.push_back(extractDestSlice);

1189

1190 strides.append(unPackOp.getSourceRank() - outputRank, oneAttr);

1191

1192 auto extractSourceSlice = b.createtensor::ExtractSliceOp(

1193 loc, unPackOp.getSource(), offsets, sizes, strides);

1194 tiledOperands.insert(tiledOperands.begin(), extractSourceSlice);

1195 for (auto tile : unPackOp.getInnerTiles())

1196 tiledOperands.push_back(tile);

1197

1198

1200 b.create(loc, TypeRange{extractDestSlice.getType()},

1201 tiledOperands, op->getAttrs());

1202

1206 extractSourceSlice, extractDestSlice})};

1207 }

1208 };

1209

1210 }

1211

1212 template

1214 OpType::template attachInterface<LinalgOpTilingInterface>(*ctx);

1215 OpType::template attachInterface<LinalgOpPartialReductionInterface>(

1216 *ctx);

1217 }

1218

1219

1220 template <typename... OpTypes>

1222 (registerOne(ctx), ...);

1223 }

1224

1225 #define GET_OP_LIST

1226

1230 registerOnelinalg::GenericOp(ctx);

1231 linalg::PackOp::attachInterface(*ctx);

1232 linalg::UnPackOp::attachInterface(*ctx);

1234 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"

1235 >(ctx);

1236 });

1237 }

1238

1242 linalg::PackOp::attachInterface(*ctx);

1243 linalg::UnPackOp::attachInterface(*ctx);

1244 });

1245 }

static LogicalResult getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize, const scf::SCFTilingOptions &options)

static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, const scf::SCFTilingOptions &options)

static SmallVector< Value > getIndicesForAccess(OpBuilder &b, Location loc, AffineMap indexingMap, ValueRange ivs)

Return the SSA values that represent the data point accessed using a given indexingMap for a given point in the iteration space.

static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp, ValueRange ivs, ValueRange argValues)

Method to inline the payload of a linalgOp given the iteration space point and values for the arguments of the payload.

static void registerAll(MLIRContext *ctx)

Variadic helper function.

static void registerOne(MLIRContext *ctx)

Base type for affine expression.

A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.

static AffineMap get(MLIRContext *context)

Returns a zero result affine map with no dimensions or symbols: () -> ().

AffineMap insertResult(AffineExpr expr, unsigned pos) const

Returns a new AffineMap with the same number of dims and symbols and an extra result inserted at pos.

bool isProjectedPermutation(bool allowZeroInResults=false) const

Returns true if the AffineMap represents a subset (i.e.

unsigned getNumSymbols() const

unsigned getNumDims() const

ArrayRef< AffineExpr > getResults() const

unsigned getNumResults() const

bool isPermutation() const

Returns true if the AffineMap represents a symbol-less permutation map.

Attributes are known-constant values of operations.

Block represents an ordered list of Operations.

Operation * getTerminator()

Get the terminator operation of this block.

BlockArgListType getArguments()

iterator_range< iterator > without_terminator()

Return an iterator range over the operations within this block, excluding the terminator operation at the end.

IntegerAttr getIndexAttr(int64_t value)

IntegerAttr getI64IntegerAttr(int64_t value)

MLIRContext * getContext() const

The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.

bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)

Add the given extension to the registry.

This is a utility class for mapping one set of IR entities to another.

auto lookupOrDefault(T from) const

Lookup a mapped value within the map.

void map(Value from, Value to)

Inserts a new mapping for 'from' to 'to'.

IRValueT get() const

Return the current value being used by this operand.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.

MLIRContext is the top-level object for a collection of MLIR operations.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

Operation * clone(Operation &op, IRMapping &mapper)

Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation via the provided mapper.

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)

Create an operation of specific op type at the current insertion point, and immediately try to fold i...

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

This class represents a single result from folding an operation.

This class represents an operand of an operation.

This is a value defined by a result of an operation.

Operation is the basic unit of execution within MLIR.

OpOperand & getOpOperand(unsigned idx)

void setOperand(unsigned idx, Value value)

Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())

Create a deep copy of this operation, remapping any operands that use values outside of the operation via the provided mapper.

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

Location getLoc()

The source location the operation was defined or derived from.

ArrayRef< NamedAttribute > getAttrs()

Return all of the attributes on this operation.

InFlightDiagnostic emitError(const Twine &message={})

Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.

Region & getRegion(unsigned index)

Returns the region held by this operation at position 'index'.

operand_range getOperands()

Returns an iterator on the underlying Value's.

result_range getResults()

InFlightDiagnostic emitOpError(const Twine &message={})

Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.

void cloneInto(Region *dest, IRMapping &mapper)

Clone the internal blocks from this region into dest.

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

static FailureOr< int64_t > computeConstantBound(presburger::BoundType type, const Variable &var, StopConditionFn stopCondition=nullptr, bool closedUB=false)

Compute a constant bound for the given variable.

This class provides an abstraction over the different types of ranges over Values.

type_range getTypes() const

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any other AffineApplyOp supplying the operands.

DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs)

Create IR to calculate (div lhs, rhs) and (mod lhs, rhs).

std::optional< TypedAttr > getNeutralElement(Operation *op)

Return the identity numeric value associated to the give op.

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

SmallVector< Value > makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)

Creates extract_slice/subview ops for all valuesToTile of the given linalgOp with builder,...

void registerTilingInterfaceExternalModelsForPackUnPackOps(DialectRegistry &registry)

Similar to the above registration, but it is only for tensor.pack and tensor.unpack ops.

void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef< OpFoldResult > offests)

Add the specified offsets to any linalg.index ops contained in the given linalgOp.

void registerTilingInterfaceExternalModels(DialectRegistry &registry)

SmallVector< Type > getTensorOutputTypes(LinalgOp op, ValueRange operands)

Returns the list of tensor output types produced when the given structured operation op is applied to the given operands.

SliceParameters computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile, ArrayRef< OpFoldResult > tileSizes, AffineMap map, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > subShapeSizes, bool omitPartialTileCheck)

Computes SliceParameters for a single valueToTile, assuming that its user is being tiled with the given loop bounds lbs and ubs and the tile sizes tileSizes.

SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)

Return the dimensions of the given tensor value.

Include the generated interface declarations.

std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)

If ofr is a constant integer or an IntegerAttr, return the integer.

LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)

Reify the shape of the result of an operation (typically in terms of the shape of its operands).

bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)

Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.

void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)

Bind a list of AffineExpr references to DimExpr at positions: [0 .. numDims).

Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)

Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs, and the position of the potential reduction argument within the list, redPos.

Type getElementTypeOrSelf(Type type)

Return the element type or return the type itself.

bool isZeroInteger(OpFoldResult v)

Return true if v is an IntegerAttr with value 0.

void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)

Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. numSymbols).

Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)

Converts an OpFoldResult to a Value.

Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)

SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)

Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by the given sizes and sinking them under the target loops.

AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)

These free functions allow clients of the API to not use classes in detail.

SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)

Helper method to compute the inverse of a permutation.

Container for the result of merge operation of tiling.

Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.

Container for result values of tiling.

SmallVector< Operation * > tiledOps

Helper struct to build simple AffineValueExprs with minimal type inference support.

Holds the result of (div a, b) and (mod a, b).

A struct containing the offsets-sizes-strides arguments of the tiled shape.

SmallVector< OpFoldResult > sizes

SmallVector< OpFoldResult > offsets