//===- MemRefOps.cpp - MemRef dialect operations --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::memref;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common utility used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && operand.get() != inner &&
        !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}
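// Illustrative fold enabled by foldMemRefCast (added for exposition, not part
// of the original file):
//
//   %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
//   memref.dealloc %0 : memref<?xf32>
//
// folds to:
//
//   memref.dealloc %arg : memref<8xf32>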

/// Return an unranked/ranked tensor type for the given unranked/ranked memref
/// type.
Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
  if (auto memref = llvm::dyn_cast<MemRefType>(type))
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}

OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
                                  int64_t dim) {
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  if (memrefType.isDynamicDim(dim))
    return builder.createOrFold<memref::DimOp>(loc, value, dim);

  return builder.getIndexAttr(memrefType.getDimSize(dim));
}

SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
                                                Location loc, Value value) {
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  SmallVector<OpFoldResult> result;
  for (int64_t i = 0; i < memrefType.getRank(); ++i)
    result.push_back(getMixedSize(builder, loc, value, i));
  return result;
}

/// Helper function that sets values[i] to constValues[i] if the latter is a
/// static value, as indicated by ShapedType::kDynamic.
///
/// If constValues[i] is dynamic, tries to extract a constant value from
/// values[i] to allow additional folding opportunities. Also converts all
/// existing attributes to index attributes (they may be i64 attributes).
static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
                                ArrayRef<int64_t> constValues) {
  assert(constValues.size() == values.size() &&
         "incorrect number of const values");
  for (auto [i, cstVal] : llvm::enumerate(constValues)) {
    Builder builder(values[i].getContext());
    if (ShapedType::isStatic(cstVal)) {
      // Constant value is known, use it directly.
      values[i] = builder.getIndexAttr(cstVal);
      continue;
    }
    if (std::optional<int64_t> cst = getConstantIntValue(values[i])) {
      // Try to extract a constant and convert it to an index attribute.
      values[i] = builder.getIndexAttr(*cst);
    }
  }
}

/// Given the result type of an op and a pointer-like operand, return the
/// promotable memory-space cast defining that operand together with the
/// target and source clone types, or an empty tuple on failure.
/// NOTE: the helper's name is an assumption.
static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
getMemorySpaceCastInfo(PtrLikeTypeInterface resultTy, Value src) {
  MemorySpaceCastOpInterface castOp =
      MemorySpaceCastOpInterface::getIfPromotableCast(src);

  // Bail if the operand is not produced by a promotable cast.
  if (!castOp)
    return {};

  // Clone the result type with the memory spaces of the cast's source and
  // target pointers.
  FailureOr<PtrLikeTypeInterface> srcTy = resultTy.clonePtrWith(
      castOp.getSourcePtr().getType().getMemorySpace(), std::nullopt);
  if (failed(srcTy))
    return {};

  FailureOr<PtrLikeTypeInterface> tgtTy = resultTy.clonePtrWith(
      castOp.getTargetPtr().getType().getMemorySpace(), std::nullopt);
  if (failed(tgtTy))
    return {};

  // Check that the cast between the clone types is valid.
  if (!castOp.isValidMemorySpaceCast(*tgtTy, *srcTy))
    return {};

  return std::make_tuple(castOp, *tgtTy, *srcTy);
}

/// Implementation for bubbling down memory-space casts through "passthrough"
/// ops: the op is recreated on the cast's source pointer and the result is
/// cast back to the original memory space.
/// NOTE: the parameter names and the operand-rewiring line are assumptions;
/// only the op-creation and cast-cloning lines survive in this listing.
template <typename ConcreteOpTy>
static FailureOr<std::optional<SmallVector<Value>>>
bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
                                 OpOperand &src) {
  auto [castOp, tgtTy, resTy] =
      getMemorySpaceCastInfo(op.getType(), src.get());
  if (!castOp)
    return failure();

  SmallVector<Value> operands;
  llvm::append_range(operands, op->getOperands());
  operands[src.getOperandNumber()] = castOp.getSourcePtr();

  // Create the op using the source memory space.
  auto newOp = ConcreteOpTy::create(
      builder, op.getLoc(), TypeRange(resTy), operands, op.getProperties(),
      llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));

  // Cast the result back to the original memory space.
  MemorySpaceCastOpInterface result = castOp.cloneMemorySpaceCastOp(
      builder, tgtTy,
      cast<TypedValue<PtrLikeTypeInterface>>(newOp.getResult()));
  return std::optional<SmallVector<Value>>(
      SmallVector<Value>({result.getTargetPtr()}));
}

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

void AllocOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloc");
}

void AllocaOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloca");
}

template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (op.getSymbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.getSymbolOperands().size();

  return success();
}

LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }

LogicalResult AllocaOp::verify() {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(*this);
}
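// Illustrative operands checked by verifyAllocLikeOp (sketch, not from the
// original file): one operand per dynamic dimension, one per layout symbol:
//
//   %0 = memref.alloc(%d) : memref<8x?xf32>
//   %1 = memref.alloc()[%s]
//       : memref<8x64xf32, affine_map<(d0, d1)[s0] -> (d0 * 64 + d1 + s0)>>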

namespace {
/// Fold constant dimensions into an alloc like operation.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants. If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
          APInt constSizeArg;
          if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
            return false;
          return constSizeArg.isNonNegative();
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands. Collect the non-constant ones
    // and keep track of the resultant memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (ShapedType::isStatic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
      APInt constSizeArg;
      if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
          constSizeArg.isNonNegative()) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constSizeArg.getZExtValue());
      } else {
        // Dynamic shape dimension not folded; copy the dynamic size to the new
        // operand list.
        newShapeConstants.push_back(ShapedType::kDynamic);
        dynamicSizes.push_back(dynamicSize);
      }
      dynamicDimPos++;
    }

    // Create the new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType,
                                        dynamicSizes, alloc.getSymbolOperands(),
                                        alloc.getAlignmentAttr());
    // Insert a cast so we have the same type as the old alloc.
    rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
    return success();
  }
};
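// Illustrative rewrite performed by SimplifyAllocConst (added for exposition):
//
//   %c8 = arith.constant 8 : index
//   %0 = memref.alloc(%c8) : memref<?xf32>
//
// becomes:
//
//   %1 = memref.alloc() : memref<8xf32>
//   %0 = memref.cast %1 : memref<8xf32> to memref<?xf32>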

/// Fold alloc operations with no users or only store and dealloc uses.
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.getValue() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
    return success();
  }
};
} // namespace

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// ReallocOp
//===----------------------------------------------------------------------===//

LogicalResult ReallocOp::verify() {
  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
  MemRefType resultType = getType();

  // The source memref should have identity layout (or none).
  if (!sourceType.getLayout().isIdentity())
    return emitError("unsupported layout for source memref type ")
           << sourceType;

  // The result memref should have identity layout (or none).
  if (!resultType.getLayout().isIdentity())
    return emitError("unsupported layout for result memref type ")
           << resultType;

  // The source and result memrefs should be in the same memory space.
  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  // The source and result memrefs should have the same element type.
  if (sourceType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  // Verify that we have the dynamic dimension operand when it is needed.
  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
    return emitError("missing dimension operand for result type ")
           << resultType;
  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
    return emitError("unnecessary dimension operand for result type ")
           << resultType;

  return success();
}

void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
}

//===----------------------------------------------------------------------===//
// AllocaScopeOp
//===----------------------------------------------------------------------===//

void AllocaScopeOp::print(OpAsmPrinter &p) {
  bool printBlockTerminators = false;

  p << ' ';
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    printBlockTerminators = true;
  }
  p << ' ';
  p.printRegion(getBodyRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);
  p.printOptionalAttrDict((*this)->getAttrs());
}

ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create a region for the body.
  result.regions.reserve(1);
  Region *bodyRegion = result.addRegion();

  // Parse the optional result type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the body region.
  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
    return failure();
  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

void AllocaScopeOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  if (!point.isParent()) {
    regions.push_back(RegionSuccessor(getOperation(), getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&getBodyRegion()));
}

/// Given an operation, return whether this op is guaranteed to
/// allocate an AutomaticAllocationScopeResource.
static bool isGuaranteedAutomaticAllocation(Operation *op) {
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!interface)
    return false;
  for (auto res : op->getResults()) {
    if (auto effect =
            interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
      if (isa<SideEffects::AutomaticAllocationScopeResource>(
              effect->getResource()))
        return true;
    }
  }
  return false;
}

/// Given an operation, return whether this op itself could
/// allocate an AutomaticAllocationScopeResource. Note that
/// this will not check whether an operation contained within
/// the op can allocate.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
  // This op itself doesn't create a stack allocation;
  // the inner allocation should be handled separately.
  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
    return false;
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!interface)
    return true;
  for (auto res : op->getResults()) {
    if (auto effect =
            interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
      if (isa<SideEffects::AutomaticAllocationScopeResource>(
              effect->getResource()))
        return true;
    }
  }
  return false;
}

/// Return whether this op is the last non-terminating op
/// in a region. That is to say, it is in a one-block region
/// and is only followed by a terminator. This prevents
/// extending the lifetime of allocations.
static bool lastNonTerminatorInRegion(Operation *op) {
  return op->getNextNode() == op->getBlock()->getTerminator() &&
         llvm::hasSingleElement(op->getParentRegion()->getBlocks());
}

/// Inline an AllocaScopeOp if either the direct parent is an allocation scope
/// or it contains no allocation.
struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocaScopeOp op,
                                PatternRewriter &rewriter) const override {
    bool hasPotentialAlloca =
        op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
            if (alloc == op)
              return WalkResult::advance();
            if (isOpItselfPotentialAutomaticAllocation(alloc))
              return WalkResult::interrupt();
            if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
              return WalkResult::skip();
            return WalkResult::advance();
          }).wasInterrupted();

    // If this contains no potential allocation, it is always legal to
    // inline. Otherwise, check the two lifetime conditions below.
    if (hasPotentialAlloca) {
      if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
        return failure();
      if (!lastNonTerminatorInRegion(op))
        return failure();
    }

    Block *block = &op.getRegion().front();
    Operation *terminator = block->getTerminator();
    ValueRange results = terminator->getOperands();
    rewriter.inlineBlockBefore(block, op);
    rewriter.replaceOp(op, results);
    rewriter.eraseOp(terminator);
    return success();
  }
};

/// Move allocations into an allocation scope, if it is legal to
/// move them (e.g. their operands are available at the location
/// the op would be moved to).
struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocaScopeOp op,
                                PatternRewriter &rewriter) const override {
    if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
      return failure();

    Operation *lastParentWithoutScope = op->getParentOp();

    if (!lastParentWithoutScope ||
        lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
      return failure();

    // Only apply if this is the last non-terminator op in the block
    // (lest lifetimes be extended) of a one-block region.
    if (!lastNonTerminatorInRegion(op) ||
        !lastNonTerminatorInRegion(lastParentWithoutScope))
      return failure();

    while (!lastParentWithoutScope->getParentOp()
                ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
      lastParentWithoutScope = lastParentWithoutScope->getParentOp();
      if (!lastParentWithoutScope ||
          !lastNonTerminatorInRegion(lastParentWithoutScope))
        return failure();
    }
    assert(lastParentWithoutScope->getParentOp()
               ->hasTrait<OpTrait::AutomaticAllocationScope>());

    Region *containingRegion = nullptr;
    for (auto &r : lastParentWithoutScope->getRegions()) {
      if (r.isAncestor(op->getParentRegion())) {
        assert(containingRegion == nullptr &&
               "only one region can contain the op");
        containingRegion = &r;
      }
    }
    assert(containingRegion && "op must be contained in a region");

    SmallVector<Operation *> toHoist;
    op->walk([&](Operation *alloc) {
      if (!isGuaranteedAutomaticAllocation(alloc))
        return WalkResult::skip();

      // If any operand is defined inside the region we would hoist out of,
      // the allocation cannot be hoisted.
      if (llvm::any_of(alloc->getOperands(), [&](Value v) {
            return containingRegion->isAncestor(v.getParentRegion());
          }))
        return WalkResult::skip();
      toHoist.push_back(alloc);
      return WalkResult::advance();
    });

    if (toHoist.empty())
      return failure();
    rewriter.setInsertionPoint(lastParentWithoutScope);
    for (auto *op : toHoist) {
      auto *cloned = rewriter.clone(*op);
      rewriter.replaceOp(op, cloned->getResults());
    }
    return success();
  }
};

void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

LogicalResult AssumeAlignmentOp::verify() {
  if (!llvm::isPowerOf2_32(getAlignment()))
    return emitOpError("alignment must be power of 2");
  return success();
}

void AssumeAlignmentOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "assume_align");
}

OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
  auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
  if (!source)
    return {};
  if (source.getAlignment() != getAlignment())
    return {};
  return getMemref();
}
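// Illustrative fold (added for exposition): a repeated assume_alignment with
// the same alignment is redundant and folds to its source:
//
//   %0 = memref.assume_alignment %m, 16 : memref<?xf32>
//   %1 = memref.assume_alignment %0, 16 : memref<?xf32>   // folds to %0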

FailureOr<std::optional<SmallVector<Value>>>
AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
  // NOTE: body assumed to delegate to the shared passthrough helper above.
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
}

//===----------------------------------------------------------------------===//
// DistinctObjectsOp
//===----------------------------------------------------------------------===//

LogicalResult DistinctObjectsOp::verify() {
  if (getOperandTypes() != getResultTypes())
    return emitOpError("operand types and result types must match");

  if (getOperandTypes().empty())
    return emitOpError("expected at least one operand");

  return success();
}

LogicalResult DistinctObjectsOp::inferReturnTypes(
    MLIRContext * /*context*/, std::optional<Location> /*location*/,
    ValueRange operands, DictionaryAttr /*attributes*/,
    OpaqueProperties /*properties*/, RegionRange /*regions*/,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  llvm::copy(operands.getTypes(), std::back_inserter(inferredReturnTypes));
  return success();
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "cast");
}

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable
/// memref.cast operations are typically inserted as `view` and `subview` ops
/// are canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
///    element type and rank.
/// 2. each of the source's size, offset or stride has more static information
///    than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```mlir
///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType =
      llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
      failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
    return false;

  // If the cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
        return false;
  }

  // If the cast is towards a more static offset, don't fold.
  if (sourceOffset != resultOffset)
    if (ShapedType::isDynamic(sourceOffset) &&
        ShapedType::isStatic(resultOffset))
      return false;

  // If the cast is towards more static strides along any dimension, don't
  // fold.
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
        return false;
  }

  return true;
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
          failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same. They are also compatible if either one is dynamic (see
      // the description of MemRefCastOp for details).
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (const auto &aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&
          aDim != bDim)
        return false;
    }
    return true;
  } else {
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported.
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    return aMemSpace == bMemSpace;
  }

  return false;
}
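// Illustrative casts accepted by areCastCompatible (sketch, not from the
// original file):
//
//   // Erasing static shape information:
//   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
//   // Ranked <-> unranked with matching element type and memory space:
//   %2 = memref.cast %0 : memref<8x16xf32> to memref<*xf32>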

OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

FailureOr<std::optional<SmallVector<Value>>>
CastOp::bubbleDownCasts(OpBuilder &builder) {
  // NOTE: body assumed to delegate to the shared passthrough helper above.
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {
/// Fold memref.copy(%x, %x).
struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getSource() != copyOp.getTarget())
      return failure();

    rewriter.eraseOp(copyOp);
    return success();
  }
};

struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  static bool isEmptyMemRef(BaseMemRefType type) {
    return type.hasRank() && llvm::is_contained(type.getShape(), 0);
  }

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (isEmptyMemRef(copyOp.getSource().getType()) ||
        isEmptyMemRef(copyOp.getTarget().getType())) {
      rewriter.eraseOp(copyOp);
      return success();
    }

    return failure();
  }
};
} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldEmptyCopy, FoldSelfCopy>(context);
}

/// If an operand of a CopyOp is produced by a foldable memref.cast, fold the
/// cast into the copy. NOTE: the helper's name is an assumption.
static LogicalResult foldCopyOfCast(CopyOp op) {
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      return success();
    }
  }
  return failure();
}

LogicalResult CopyOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  /// copy(memrefcast) -> copy
  return foldCopyOfCast(*this);
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "dim");
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
  build(builder, result, source, indexValue);
}

std::optional<int64_t> DimOp::getConstantIndex() {
  return getConstantIntValue(getIndex());
}

Speculation::Speculatability DimOp::getSpeculatability() {
  auto constantIndex = getConstantIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
  if (!rankedSourceType)
    return Speculation::NotSpeculatable;

  if (rankedSourceType.getRank() <= constantIndex)
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}

void DimOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  setResultRange(getResult(),
                 intrange::inferShapedDimOpInterface(*this, argRanges[0]));
}

/// Return a map with key being elements in `vals` and data being number of
/// occurences of it. Use std::map, since the `vals` here are strides and the
/// dynamic stride value is the same as the tombstone value for
/// `llvm::DenseMap<int64_t>`.
static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}

/// Given the `originalType` and a `candidateReducedType` whose shape is
/// assumed to be a subset of `originalType` with some `1` entries erased,
/// return the set of indices that specifies which of the entries of
/// `originalShape` are dropped to obtain `reducedShape`.
/// This accounts for cases where there are multiple unit-dims, but only a
/// subset of them are dropped. For MemRefTypes these can be disambiguated
/// using the strides. If a dimension is dropped, the stride must be dropped
/// too.
static FailureOr<llvm::SmallBitVector>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayRef<OpFoldResult> sizes) {
  llvm::SmallBitVector unusedDims(originalType.getRank());
  if (originalType.getRank() == reducedType.getRank())
    return unusedDims;

  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
        unusedDims.set(dim.index());

  // Early exit for the case where the number of unused dims matches the number
  // of ranks reduced.
  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
      originalType.getRank())
    return unusedDims;

  SmallVector<int64_t> originalStrides, candidateStrides;
  int64_t originalOffset, candidateOffset;
  if (failed(
          originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
      failed(
          reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
    return failure();

  // For memrefs, a dimension is truly dropped iff its corresponding stride is
  // also dropped. This is particularly important when more than one of the
  // dims is 1. Track the number of occurences of the strides in the original
  // type and the candidate type. For each unused dim, that stride should not
  // be present in the candidate type. Note that there could be multiple
  // dimensions that have the same size; we don't need to figure out exactly
  // which dim corresponds to which stride, only that the occurence counts are
  // consistent.
  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
      continue;
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      // This dim can be treated as dropped.
      currUnaccountedStrides[originalStride]--;
      continue;
    }
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      // The stride for this dim is not dropped. Keep as is.
      unusedDims.reset(dim);
      continue;
    }
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
      // This should never happen: the reduced rank type cannot have a stride
      // that wasn't in the original one.
      return failure();
    }
  }

  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
      originalType.getRank())
    return failure();
  return unusedDims;
}

llvm::SmallBitVector SubViewOp::getDroppedDims() {
  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();
  FailureOr<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
  return *unusedDims;
}

OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types (UnrankedMemRefType) is not supported.
  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
  if (!memrefType)
    return {};

  // Out of bound indices produce undefined behavior but are still valid IR.
  // Don't choke on them.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= memrefType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  Operation *definingOp = getSource().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    llvm::SmallBitVector unusedDims = subview.getDroppedDims();
    unsigned resultIndex = 0;
    unsigned sourceRank = subview.getSourceType().getRank();
    unsigned sourceIndex = 0;
    for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
      if (unusedDims.test(i))
        continue;
      if (resultIndex == unsignedIndex) {
        sourceIndex = i;
        break;
      }
      resultIndex++;
    }
    assert(subview.isDynamicSize(sourceIndex) &&
           "expected dynamic subview size");
    return subview.getDynamicSize(sourceIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}
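// Illustrative folds (added for exposition):
//
//   %c0 = arith.constant 0 : index
//   %d0 = memref.dim %m, %c0 : memref<4x?xf32>   // folds to: constant 4
//
//   %a = memref.alloc(%n) : memref<4x?xf32>
//   %c1 = arith.constant 1 : index
//   %d1 = memref.dim %a, %c1 : memref<4x?xf32>   // folds to: %n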

namespace {

/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return rewriter.notifyMatchFailure(
          dim, "Dim op is not defined by a reshape op.");

    // dim of a memref reshape can be folded if dim.getIndex() dominates the
    // reshape. Instead of using `DominanceInfo` (which is usually costly) we
    // cheaply check that either of the following conditions hold:
    //      1. dim.getIndex() is defined in the same block as reshape but
    //         before reshape.
    //      2. dim.getIndex() is defined in a parent block of reshape.

    // Check condition 1.
    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
        if (reshape->isBeforeInBlock(definingOp)) {
          return rewriter.notifyMatchFailure(
              dim,
              "dim.getIndex is not defined before reshape in the same block.");
        }
      } // else dim.getIndex is a block argument to reshape's block and
        // dominates reshape.
    } // Check condition 2.
    else if (dim->getBlock() != reshape->getBlock() &&
             !dim.getIndex().getParentRegion()->isProperAncestor(
                 reshape->getParentRegion())) {
      // If dim and reshape are in the same block but dim.getIndex() isn't, we
      // already know dim.getIndex() dominates reshape without calling
      // `isBeforeInBlock`.
      return rewriter.notifyMatchFailure(
          dim, "dim.getIndex does not dominate reshape.");
    }

    // Place the load directly after the reshape to ensure that the shape
    // memref was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    Value load =
        LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
    if (load.getType() != dim.getType())
      load = arith::IndexCastOp::create(rewriter, loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }
};

} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape>(context);
}
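// Illustrative rewrite performed by DimOfMemRefReshape (sketch):
//
//   %r = memref.reshape %src(%shape)
//       : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
//   %d = memref.dim %r, %i : memref<?x?xf32>
//
// becomes a load from the shape operand:
//
//   %d = memref.load %shape[%i] : memref<2xindex>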

//===----------------------------------------------------------------------===//
// DmaStartOp
//===----------------------------------------------------------------------===//

void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
  if (stride)
    result.addOperands({stride, elementsPerStride});
}

void DmaStartOp::print(OpAsmPrinter &p) {
  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
    << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
    << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}

// Parse DmaStartOp.
// Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//                       %tag[%index], %stride, %num_elt_per_stride
//                     : memref<3076 x f32, 0>,
//                       memref<1024 x f32, 2>,
//                       memref<1 x i32>
//
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand srcMemRefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
  OpAsmParser::UnresolvedOperand dstMemRefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
  OpAsmParser::UnresolvedOperand numElementsInfo;
  OpAsmParser::UnresolvedOperand tagMemrefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) dma size followed by the tag memref and its indices.
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  if (parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}

LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref
  // and the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}

LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DmaWaitOp
//===----------------------------------------------------------------------===//

LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

LogicalResult DmaWaitOp::verify() {
  // Check that the number of tag indices matches the tagMemRef rank.
  unsigned numTagIndices = getTagIndices().size();
  unsigned tagMemRefRank = getTagMemRefRank();
  if (numTagIndices != tagMemRefRank)
    return emitOpError() << "expected tagIndices to have the same number of "
                            "elements as the tagMemRef rank, expected "
                         << tagMemRefRank << ", but got " << numTagIndices;
  return success();
}

//===----------------------------------------------------------------------===//
// ExtractAlignedPointerAsIndexOp
//===----------------------------------------------------------------------===//

void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "intptr");
}

//===----------------------------------------------------------------------===//
// ExtractStridedMetadataOp
//===----------------------------------------------------------------------===//

/// The number and type of the results are inferred from the
/// shape of the source.
LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ExtractStridedMetadataOp::Adaptor adaptor,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();

  unsigned sourceRank = sourceType.getRank();
  IndexType indexType = IndexType::get(context);
  auto memrefType =
      MemRefType::get({}, sourceType.getElementType(),
                      MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
  // Base.
  inferredReturnTypes.push_back(memrefType);
  // Offset.
  inferredReturnTypes.push_back(indexType);
  // Sizes and strides.
  for (unsigned i = 0; i < sourceRank * 2; ++i)
    inferredReturnTypes.push_back(indexType);
  return success();
}

void ExtractStridedMetadataOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getBaseBuffer(), "base_buffer");
  setNameFn(getOffset(), "offset");
  // For multi-result to work properly with pretty names and packed syntax
  // `x:3`, we can only give a pretty name to the first value in the pack.
  if (!getSizes().empty()) {
    setNameFn(getSizes().front(), "sizes");
    setNameFn(getStrides().front(), "strides");
  }
}
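// Illustrative result packing for a rank-2 source (sketch, not from the
// original file): base, offset, two sizes, and two strides:
//
//   %base, %offset, %sizes:2, %strides:2 =
//       memref.extract_strided_metadata %m : memref<10x?xf32>
//       -> memref<f32>, index, index, index, index, index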

/// Helper function to perform the replacement of all constant uses of `values`
/// by a materialized constant extracted from `maybeConstants`.
/// Returns true if some replacement happened, false otherwise.
template <typename Container>
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
                                  Container values,
                                  ArrayRef<OpFoldResult> maybeConstants) {
  assert(values.size() == maybeConstants.size() &&
         " expected values and maybeConstants of the same size");
  bool atLeastOneReplacement = false;
  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
    // Don't materialize a constant if there are no uses: this would induce
    // infinite loops in the driver.
    if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
      continue;
    assert(isa<Attribute>(maybeConstant) &&
           "The constified value should be either unchanged (i.e., == result) "
           "or a constant");
    Value constantVal = arith::ConstantIndexOp::create(
        rewriter, loc,
        llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
    for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
      // modifyOpInPlace: lambda cannot capture structured bindings in C++17
      // yet.
      op->replaceUsesOfWith(result, constantVal);
      atLeastOneReplacement = true;
    }
  }
  return atLeastOneReplacement;
}

LogicalResult
ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  OpBuilder builder(*this);

  bool atLeastOneReplacement = replaceConstantUsesOf(
      builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
      getConstifiedMixedOffset());
  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
                                                 getConstifiedMixedSizes());
  atLeastOneReplacement |= replaceConstantUsesOf(
      builder, getLoc(), getStrides(), getConstifiedMixedStrides());

  // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x).
  if (auto prev = getSource().getDefiningOp<CastOp>())
    if (isa<MemRefType>(prev.getSource().getType())) {
      getSourceMutable().assign(prev.getSource());
      atLeastOneReplacement = true;
    }

  return success(atLeastOneReplacement);
}

SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
  constifyIndexValues(values, getSource().getType().getShape());
  return values;
}

SmallVector<OpFoldResult>
ExtractStridedMetadataOp::getConstifiedMixedStrides() {
  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
  SmallVector<int64_t> staticValues;
  int64_t unused;
  LogicalResult status =
      getSource().getType().getStridesAndOffset(staticValues, unused);
  (void)status;
  assert(succeeded(status) && "could not get strides from type");
  constifyIndexValues(values, staticValues);
  return values;
}

OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
  SmallVector<OpFoldResult> values(1, offsetOfr);
  SmallVector<int64_t> staticValues, unused;
  int64_t offset;
  LogicalResult status =
      getSource().getType().getStridesAndOffset(unused, offset);
  (void)status;
  assert(succeeded(status) && "could not get offset from type");
  staticValues.push_back(offset);
  constifyIndexValues(values, staticValues);
  return values[0];
}

//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp
//===----------------------------------------------------------------------===//

void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
                               Value memref, ValueRange ivs) {
  OpBuilder::InsertionGuard g(builder);
  result.addOperands(memref);
  result.addOperands(ivs);

  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
    Type elementType = memrefType.getElementType();
    result.addTypes(elementType);

    Region *bodyRegion = result.addRegion();
    builder.createBlock(bodyRegion, {}, elementType, memref.getLoc());
  }
}

LogicalResult GenericAtomicRMWOp::verify() {
  auto &body = getRegion();
  if (body.getNumArguments() != 1)
    return emitOpError("expected single number of entry block arguments");

  if (getResult().getType() != body.getArgument(0).getType())
    return emitOpError("expected block argument of the same type result type");

  bool hasSideEffects =
      body.walk([&](Operation *nestedOp) {
            if (isMemoryEffectFree(nestedOp))
              return WalkResult::advance();
            nestedOp->emitError(
                "body of 'memref.generic_atomic_rmw' should contain "
                "only operations with no side effects");
            return WalkResult::interrupt();
          })
          .wasInterrupted();
  return hasSideEffects ? failure() : success();
}

ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
  OpAsmParser::UnresolvedOperand memref;
  Type memrefType;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.parseOperand(memref) ||
      parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
      parser.parseColonType(memrefType) ||
      parser.resolveOperand(memref, memrefType, result.operands) ||
      parser.resolveOperands(ivs, indexType, result.operands))
    return failure();

  Region *body = result.addRegion();
  if (parser.parseRegion(*body, {}) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();
  result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
  return success();
}

void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
  p << ' ' << getMemref() << "[" << getIndices()
    << "] : " << getMemref().getType() << ' ';
  p.printRegion(getRegion());
  p.printOptionalAttrDict((*this)->getAttrs());
}
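// Illustrative op in the syntax printed/parsed above (sketch):
//
//   %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
//   ^bb0(%current_value : f32):
//     %c1 = arith.constant 1.0 : f32
//     %inc = arith.addf %c1, %current_value : f32
//     memref.atomic_yield %inc : f32
//   }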

//===----------------------------------------------------------------------===//
// AtomicYieldOp
//===----------------------------------------------------------------------===//

LogicalResult AtomicYieldOp::verify() {
  Type parentType = (*this)->getParentOp()->getResultTypes().front();
  Type resultType = getResult().getType();
  if (parentType != resultType)
    return emitOpError() << "types mismatch between yield op: " << resultType
                         << " and its parent: " << parentType;
  return success();
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//

static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
                                                   TypeAttr type,
                                                   Attribute initialValue) {
  p << type;
  if (!op.isExternal()) {
    p << " = ";
    if (op.isUninitialized())
      p << "uninitialized";
    else
      p.printAttributeWithoutType(initialValue);
  }
}

static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
                                       Attribute &initialValue) {
  Type type;
  if (parser.parseType(type))
    return failure();

  auto memrefType = llvm::dyn_cast<MemRefType>(type);
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

  if (parser.parseOptionalEqual())
    return success();

  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
    initialValue = UnitAttr::get(parser.getContext());
    return success();
  }

  Type tensorType = getTensorTypeFromMemRefType(memrefType);
  if (parser.parseAttribute(initialValue, tensorType))
    return failure();
  if (!llvm::isa<ElementsAttr>(initialValue))
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
  return success();
}

LogicalResult GlobalOp::verify() {
  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
  if (!memrefType || !memrefType.hasStaticShape())
    return emitOpError("type should be static shaped memref, but got ")
           << getType();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (getInitialValue().has_value()) {
    Attribute initValue = getInitialValue().value();
    if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
      return emitOpError("initial value should be a unit or elements "
                         "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
      // Check that the element types match.
      auto initElementType =
          cast<TensorType>(elementsAttr.getType()).getElementType();
      auto memrefElementType = memrefType.getElementType();

      if (initElementType != memrefElementType)
        return emitOpError("initial value element expected to be of type ")
               << memrefElementType << ", but was of type " << initElementType;

      // Check that the shapes match, given that memref globals can only
      // produce statically shaped memrefs and an elements literal type must
      // have a static shape.
      auto initShape = elementsAttr.getShapedType().getShape();
      auto memrefShape = memrefType.getShape();
      if (initShape != memrefShape)
        return emitOpError("initial value shape expected to be ")
               << memrefShape << " but was " << initShape;
    }
  }

  return success();
}
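// Illustrative globals accepted by the verifier (sketch): the elements
// attribute must match the memref's element type and shape:
//
//   memref.global "private" @workspace : memref<4xf32> = uninitialized
//   memref.global "private" constant @c : memref<2xf32> = dense<[1.0, 2.0]>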

ElementsAttr GlobalOp::getConstantInitValue() {
  auto initVal = getInitialValue();
  if (getConstant() && initVal.has_value())
    return llvm::cast<ElementsAttr>(initVal.value());
  return {};
}

//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//

LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is the same as the type of the referenced
  // memref.global op.
  auto global =
      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
  if (!global)
    return emitOpError("'")
           << getName() << "' does not reference a valid global memref";

  Type resultType = getResult().getType();
  if (global.getType() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.getType()
           << " of the global memref @" << getName();
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

LogicalResult LoadOp::verify() {
  if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
    return emitOpError("incorrect number of indices for load, expected ")
           << getMemRefType().getRank() << " but got " << getIndices().size();
  }
  return success();
}

OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
  /// load(memrefcast) -> load
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

FailureOr<std::optional<SmallVector<Value>>>
LoadOp::bubbleDownCasts(OpBuilder &builder) {
  // NOTE: the callee name is an assumption; only the trailing argument is
  // certain.
  return bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
                                              getResult());
}

//===----------------------------------------------------------------------===//
// MemorySpaceCastOp
//===----------------------------------------------------------------------===//

void MemorySpaceCastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "memspacecast");
}

bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout())
      return false;
    if (aT.getShape() != bT.getShape())
      return false;
    return true;
  }
  if (uaT && ubT) {
    return uaT.getElementType() == ubT.getElementType();
  }
  return false;
}

OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
  // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
  // t2)
  if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
    getSourceMutable().assign(parentCast.getSource());
    return getResult();
  }
  return Value{};
}
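// Illustrative fold (added for exposition): chained memory-space casts
// collapse into a single cast from the original source:
//
//   %1 = memref.memory_space_cast %0 : memref<?xf32, 1> to memref<?xf32, 2>
//   %2 = memref.memory_space_cast %1 : memref<?xf32, 2> to memref<?xf32>
//
// folds to:
//
//   %2 = memref.memory_space_cast %0 : memref<?xf32, 1> to memref<?xf32>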

TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
  return getSource();
}

TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
  return getDest();
}

bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
                                               PtrLikeTypeInterface src) {
  return isa<BaseMemRefType>(tgt) &&
         tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
}

MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
    OpBuilder &b, PtrLikeTypeInterface tgt,
    TypedValue<PtrLikeTypeInterface> src) {
  assert(isValidMemorySpaceCast(tgt, src.getType()) && "invalid arguments");
  return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
}

/// A cast into the default (generic) memory space is promotable.
bool MemorySpaceCastOp::isSourcePromotable() {
  return getDest().getType().getMemorySpace() == nullptr;
}

//===----------------------------------------------------------------------===//
// PrefetchOp
//===----------------------------------------------------------------------===//

void PrefetchOp::print(OpAsmPrinter &p) {
  p << " " << getMemref() << '[';
  p.printOperands(getIndices());
  p << ']' << ", " << (getIsWrite() ? "write" : "read");
  p << ", locality<" << getLocalityHint();
  p << ">, " << (getIsDataCache() ? "data" : "instr");
  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
  p << " : " << getMemRefType();
}

ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand memrefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
  IntegerAttr localityHint;
  MemRefType type;
  StringRef readOrWrite, cacheType;

  auto indexTy = parser.getBuilder().getIndexType();
  auto i32Type = parser.getBuilder().getIntegerType(32);
  if (parser.parseOperand(memrefInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
      parser.parseComma() || parser.parseKeyword("locality") ||
      parser.parseLess() ||
      parser.parseAttribute(localityHint, i32Type, "localityHint",
                            result.attributes) ||
      parser.parseGreater() || parser.parseComma() ||
      parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(indexInfo, indexTy, result.operands))
    return failure();

  if (readOrWrite != "read" && readOrWrite != "write")
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
                      parser.getBuilder().getBoolAttr(readOrWrite == "write"));

  if (cacheType != "data" && cacheType != "instr")
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");

  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
                      parser.getBuilder().getBoolAttr(cacheType == "data"));

  return success();
}

LogicalResult PrefetchOp::verify() {
  if (getNumOperands() != 1 + getMemRefType().getRank())
    return emitOpError("too few indices");

  return success();
}

LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  // prefetch(memrefcast) -> prefetch
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
  // Constant fold rank when the rank of the operand is known.
  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
  return IntegerAttr();
}

//===----------------------------------------------------------------------===//
// ReinterpretCastOp
//===----------------------------------------------------------------------===//

void ReinterpretCastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reinterpret_cast");
}

/// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
/// `staticSizes` and `staticStrides` are automatically filled with
/// sentinel values that encode dynamic entries.
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
                              ArrayRef<OpFoldResult> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              Value source, OpFoldResult offset,
                              ArrayRef<OpFoldResult> sizes,
                              ArrayRef<OpFoldResult> strides,
                              ArrayRef<NamedAttribute> attrs) {
  auto sourceType = cast<BaseMemRefType>(source.getType());
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  auto stridedLayout = StridedLayoutAttr::get(
      b.getContext(), staticOffsets.front(), staticStrides);
  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
                                    stridedLayout, sourceType.getMemorySpace());
  build(b, result, resultType, source, offset, sizes, strides, attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              int64_t offset, ArrayRef<int64_t> sizes,
                              ArrayRef<int64_t> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
        strideValues, attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source, Value offset,
                              ValueRange sizes, ValueRange strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
}

LogicalResult ReinterpretCastOp::verify() {
  // The source and result memrefs should be in the same memory space.
  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;
  if (srcType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source type ")
           << srcType << " and result memref type " << resultType;

  // Match sizes in the result memref type and in the static_sizes attribute.
  for (auto [idx, resultSize, expectedSize] :
       llvm::enumerate(resultType.getShape(), getStaticSizes())) {
    if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
      return emitError("expected result type with size = ")
             << (ShapedType::isDynamic(expectedSize)
                     ? std::string("dynamic")
                     : std::to_string(expectedSize))
             << " instead of " << resultSize << " in dim = " << idx;
  }

  // Match the offset and strides in the static_offset and static_strides
  // attributes against the result type's strided layout.
  int64_t resultOffset;
  SmallVector<int64_t, 4> resultStrides;
  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
    return emitError("expected result type to have strided layout but found ")
           << resultType;

  // Match the offset in the result memref type and the static_offsets
  // attribute.
  int64_t expectedOffset = getStaticOffsets().front();
  if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
    return emitError("expected result type with offset = ")
           << (ShapedType::isDynamic(expectedOffset)
                   ? std::string("dynamic")
                   : std::to_string(expectedOffset))
           << " instead of " << resultOffset;

  // Match strides in the result memref type and the static_strides attribute.
  for (auto [idx, resultStride, expectedStride] :
       llvm::enumerate(resultStrides, getStaticStrides())) {
    if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
      return emitError("expected result type with stride = ")
             << (ShapedType::isDynamic(expectedStride)
                     ? std::string("dynamic")
                     : std::to_string(expectedStride))
             << " instead of " << resultStride << " in dim = " << idx;
  }

  return success();
}

OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
  Value src = getSource();
  auto getPrevSrc = [&]() -> Value {
    // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
    if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
      return prev.getSource();

    // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
    if (auto prev = src.getDefiningOp<CastOp>())
      return prev.getSource();

    // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if the subview
    // offsets are all 0.
    if (auto prev = src.getDefiningOp<SubViewOp>())
      if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger))
        return prev.getSource();

    return nullptr;
  };

  if (auto prevSrc = getPrevSrc()) {
    getSourceMutable().assign(prevSrc);
    return getResult();
  }

  // reinterpret_cast(x) w/o offset/shape/stride changes -> x
  if (!ShapedType::isDynamicShape(getType().getShape()) &&
      src.getType() == getType() && getStaticOffsets().front() == 0) {
    return src;
  }

  return nullptr;
}
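// Illustrative no-op fold (sketch): a reinterpret_cast that reproduces its
// source's static type with a zero offset folds away entirely:
//
//   %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [4, 4],
//       strides: [4, 1] : memref<4x4xf32> to memref<4x4xf32>   // folds to %0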

SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
  SmallVector<OpFoldResult> values = getMixedSizes();
  constifyIndexValues(values, getType().getShape());
  return values;
}

SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
  SmallVector<OpFoldResult> values = getMixedStrides();
  SmallVector<int64_t> staticValues;
  int64_t unused;
  LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
  (void)status;
  assert(succeeded(status) && "could not get strides from type");
  constifyIndexValues(values, staticValues);
  return values;
}

OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
  SmallVector<OpFoldResult> values = getMixedOffsets();
  assert(values.size() == 1 &&
         "reinterpret_cast must have one and only one offset");
  SmallVector<int64_t> staticValues, unused;
  int64_t offset;
  LogicalResult status = getType().getStridesAndOffset(unused, offset);
  (void)status;
  assert(succeeded(status) && "could not get offset from type");
  staticValues.push_back(offset);
  constifyIndexValues(values, staticValues);
  return values[0];
}

namespace {
/// Replace the sequence:
/// ```
/// base, offset, sizes, strides = extract_strided_metadata src
/// dst = reinterpret_cast base to offset, sizes, strides
/// ```
/// with
/// ```
/// dst = memref.cast src
/// ```
///
/// Note: The cast operation is only inserted when the types of dst and src
/// differ, e.g., when going from <4xf32> to <?xf32>.
///
/// This pattern also matches when the offset, sizes, and strides don't come
/// directly from the `extract_strided_metadata`'s results but it can be
/// statically proven that they would hold the same values.
///
/// For instance, the following sequence would be replaced:
/// ```
/// base, offset, sizes, strides =
///   extract_strided_metadata memref : memref<3x4xty>
/// dst = reinterpret_cast base to 0, [3, 4], strides
/// ```
/// because we know (thanks to the type of the input memref) that `offset`
/// and `sizes` will respectively hold 0 and [3, 4].
///
/// If the pattern above does not match, the input of the
/// extract_strided_metadata is still folded into the input of the
/// reinterpret_cast, which allows dead code elimination to remove the
/// extract_strided_metadata in some cases.
struct ReinterpretCastOpExtractStridedMetadataFolder
    : public OpRewritePattern<ReinterpretCastOp> {
public:
  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                PatternRewriter &rewriter) const override {
    auto extractStridedMetadata =
        op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
    if (!extractStridedMetadata)
      return failure();

    // Check if the reinterpret cast reconstructs a memref with the exact same
    // properties as the extract strided metadata.
    auto isReinterpretCastNoop = [&]() -> bool {
      // First, check that the strides are the same.
      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
                       op.getConstifiedMixedStrides()))
        return false;

      // Second, check the sizes.
      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
                       op.getConstifiedMixedSizes()))
        return false;

      // Finally, check the offset.
      assert(op.getMixedOffsets().size() == 1 &&
             "reinterpret_cast with more than one offset should have been "
             "rejected by the verifier");
      return extractStridedMetadata.getConstifiedMixedOffset() ==
             op.getConstifiedMixedOffset();
    };

    if (!isReinterpretCastNoop()) {
      // If the extract_strided_metadata / reinterpret_cast pair can't be
      // completely folded, we can still fold the input of the
      // extract_strided_metadata into the input of the reinterpret_cast:
      //
      //   reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x)
      //
      // This is always valid because both point to the same base memory, and
      // the reinterpret_cast only uses its input's base pointer, not its
      // layout.
      rewriter.modifyOpInPlace(op, [&]() {
        op.getSourceMutable().assign(extractStridedMetadata.getSource());
      });
      return success();
    }

    // At this point, we know that the back-and-forth between extract strided
    // metadata and reinterpret cast is a noop. However, the final type of the
    // reinterpret cast may not be exactly the same as the original memref,
    // e.g., it could be changing a dimension from static to dynamic. Check
    // that here and add a cast if necessary.
    Type srcTy = extractStridedMetadata.getSource().getType();
    if (srcTy == op.getResult().getType())
      rewriter.replaceOp(op, extractStridedMetadata.getSource());
    else
      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
                                          extractStridedMetadata.getSource());

    return success();
  }
};

/// Canonicalize reinterpret_cast(x, offset, sizes, strides) to use the most
/// static information available, i.e., replace values by constants when
/// possible.
struct ReinterpretCastOpConstantFolder
    : public OpRewritePattern<ReinterpretCastOp> {
public:
  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                PatternRewriter &rewriter) const override {
    unsigned srcStaticCount = llvm::count_if(
        llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
                                   op.getMixedStrides()),
        [](OpFoldResult ofr) { return isa<Attribute>(ofr); });

    SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
    SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
    SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();

    // Bail out if constification did not make anything more static.
    if (srcStaticCount ==
        llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
                       [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
      return failure();

    auto newReinterpretCast = ReinterpretCastOp::create(
        rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);

    rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
    return success();
  }
};
} // namespace

void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  results.add<ReinterpretCastOpExtractStridedMetadataFolder,
              ReinterpretCastOpConstantFolder>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
ReinterpretCastOp::bubbleDownCasts(OpBuilder &builder) {
  // NOTE: body assumed to delegate to the shared passthrough helper above.
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}

//===----------------------------------------------------------------------===//
// Reassociative reshape ops
//===----------------------------------------------------------------------===//

void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapse_shape");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expand_shape");
}

LogicalResult ExpandShapeOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
  reifiedResultShapes = {
      getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
  return success();
}

/// Helper function for verifying the shape of ExpandShapeOp and CollapseShapeOp
/// result and operand. Layout maps are verified separately.
///
/// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
/// allowed in a reassociation group.
static LogicalResult
verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
                     ArrayRef<int64_t> expandedShape,
                     ArrayRef<ReassociationIndices> reassociation,
                     bool allowMultipleDynamicDimsPerGroup) {
  // There must be one reassociation group per collapsed dimension.
  if (collapsedShape.size() != reassociation.size())
    return op->emitOpError("invalid number of reassociation groups: found ")
           << reassociation.size() << ", expected " << collapsedShape.size();

  // The next expected expanded dimension index (while iterating over
  // reassociation indices).
  int64_t nextDim = 0;
  for (const auto &it : llvm::enumerate(reassociation)) {
    ReassociationIndices group = it.value();
    int64_t collapsedDim = it.index();

    bool foundDynamic = false;
    for (int64_t expandedDim : group) {
      if (expandedDim != nextDim++)
        return op->emitOpError("reassociation indices must be contiguous");

      if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
        return op->emitOpError("reassociation index ")
               << expandedDim << " is out of bounds";

      // Check if there are multiple dynamic dims in a reassociation group.
      if (ShapedType::isDynamic(expandedShape[expandedDim])) {
        if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
          return op->emitOpError(
              "at most one dimension in a reassociation group may be dynamic");
        foundDynamic = true;
      }
    }

    // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
    if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
      return op->emitOpError("collapsed dim (")
             << collapsedDim
             << ") must be dynamic if and only if reassociation group is "
                "dynamic";

    // If all dims in the reassociation group are static, the size of the
    // collapsed dim can be verified.
    if (!foundDynamic) {
      int64_t groupSize = 1;
      for (int64_t expandedDim : group)
        groupSize *= expandedShape[expandedDim];
      if (groupSize != collapsedShape[collapsedDim])
        return op->emitOpError("collapsed dim size (")
               << collapsedShape[collapsedDim]
               << ") must equal reassociation group size (" << groupSize << ")";
    }
  }

  if (collapsedShape.empty()) {
    // Rank 0: all expanded dimensions must be 1.
    for (int64_t d : expandedShape)
      if (d != 1)
        return op->emitOpError(
            "rank 0 memrefs can only be extended/collapsed with/from ones");
  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
    // Rank >= 1: the number of dimensions among all reassociation groups must
    // match the expanded memref rank.
    return op->emitOpError("expanded rank (")
           << expandedShape.size()
           << ") inconsistent with number of reassociation indices (" << nextDim
           << ")";
  }

  return success();
}

SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}

SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}

SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

2324

2325

2326static FailureOr

2331 if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))

2332 return failure();

2333 assert(srcStrides.size() == reassociation.size() && "invalid reassociation");

2334

2335

2336

2337

2338

2339

2340

2341

2342

2343

2344

2345

2346

2348 reverseResultStrides.reserve(resultShape.size());

2349 unsigned shapeIndex = resultShape.size() - 1;

2350 for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {

2352 int64_t currentStrideToExpand = std::get<1>(it);

2353 for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {

2354 reverseResultStrides.push_back(currentStrideToExpand);

2355 currentStrideToExpand =

2358 .asInteger();

2359 }

2360 }

2361 auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));

2362 resultStrides.resize(resultShape.size(), 1);

2363 return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);

2364}

FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
    MemRefType srcType, ArrayRef<int64_t> resultShape,
    ArrayRef<ReassociationIndices> reassociation) {
  if (srcType.getLayout().isIdentity()) {
    // If the source is contiguous (i.e., no layout map specified), so is the
    // result.
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(resultShape, srcType.getElementType(), layout,
                           srcType.getMemorySpace());
  }

  // Source may not be contiguous. Compute the layout map.
  FailureOr<StridedLayoutAttr> computedLayout =
      computeExpandedLayoutMap(srcType, resultShape, reassociation);
  if (failed(computedLayout))
    return failure();
  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
                         srcType.getMemorySpace());
}

FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
                                MemRefType expandedType,
                                ArrayRef<ReassociationIndices> reassociation,
                                ArrayRef<OpFoldResult> inputShape) {
  std::optional<SmallVector<OpFoldResult>> outputShape =
      inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
                                  inputShape);
  if (!outputShape)
    return failure();
  return *outputShape;
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) {
  auto [staticOutputShape, dynamicOutputShape] =
      decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
  build(builder, result, llvm::cast<MemRefType>(resultType), src,
        getReassociationIndicesAttribute(builder, reassociation),
        dynamicOutputShape, staticOutputShape);
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<OpFoldResult> inputShape =
      getMixedSizes(builder, result.location, src);
  MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
      builder, result.location, memrefResultTy, reassociation, inputShape);
  // Failure of this assertion usually indicates presence of multiple
  // dynamic dimensions in the same reassociation group.
  assert(succeeded(outputShape) && "unable to infer output shape");
  build(builder, result, memrefResultTy, src, reassociation, *outputShape);
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          ArrayRef<int64_t> resultShape, Value src,
                          ArrayRef<ReassociationIndices> reassociation) {
  // Only ranked memref source values are supported.
  auto srcType = llvm::cast<MemRefType>(src.getType());
  FailureOr<MemRefType> resultType =
      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
  // Failure of this assertion usually indicates a problem with the source
  // type, e.g., could not get strides/offset.
  assert(succeeded(resultType) && "could not compute layout");
  build(builder, result, *resultType, src, reassociation);
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          ArrayRef<int64_t> resultShape, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) {
  // Only ranked memref source values are supported.
  auto srcType = llvm::cast<MemRefType>(src.getType());
  FailureOr<MemRefType> resultType =
      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
  // Failure of this assertion usually indicates a problem with the source
  // type, e.g., could not get strides/offset.
  assert(succeeded(resultType) && "could not compute layout");
  build(builder, result, *resultType, src, reassociation, outputShape);
}

LogicalResult ExpandShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  if (srcType.getRank() > resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
    return emitOpError("has source rank ")
           << r0 << " and result rank " << r1 << ". This is not an expansion ("
           << r0 << " > " << r1 << ").";
  }

  // Verify result shape.
  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
                                  resultType.getShape(),
                                  getReassociationIndices(),
                                  /*allowMultipleDynamicDimsPerGroup=*/true)))
    return failure();

  // Compute expected result type (including layout map).
  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
      srcType, resultType.getShape(), getReassociationIndices());
  if (failed(expectedResultType))
    return emitOpError("invalid source layout map");

  // Check actual result type.
  if (*expectedResultType != resultType)
    return emitOpError("expected expanded type to be ")
           << *expectedResultType << " but found " << resultType;

  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape bounds to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()
           << " values";

  // Verify that provided output shapes are in agreement with the output type.
  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
  ArrayRef<int64_t> resShape = getResult().getType().getShape();
  for (auto [pos, shape] : llvm::enumerate(resShape)) {
    if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
      return emitOpError("invalid output shape provided at pos ") << pos;
    }
  }

  return success();
}
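// Illustrative op in the form checked by this verifier (sketch): the
// output_shape operands provide the dynamic extents of the expanded type:
//
//   %e = memref.expand_shape %m [[0, 1]] output_shape [%n, 4]
//       : memref<?xf32> into memref<?x4xf32>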

void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
}

FailureOr<std::optional<SmallVector<Value>>>
ExpandShapeOp::bubbleDownCasts(OpBuilder &builder) {
  // NOTE: body assumed to delegate to the shared passthrough helper above.
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
}

/// Compute the layout map after collapsing a given source MemRef type with the
/// specified reassociation indices.
///
/// Note: All collapsed dims in a reassociation group must be contiguous. It is
/// not possible to check this by inspecting a MemRefType in the general case.
/// If non-contiguity cannot be checked statically, the collapse is assumed to
/// be valid (and thus accepted by this function) unless `strict = true`.
static FailureOr<StridedLayoutAttr>
computeCollapsedLayoutMap(MemRefType srcType,
                          ArrayRef<ReassociationIndices> reassociation,
                          bool strict = false) {
  int64_t srcOffset;
  SmallVector<int64_t> srcStrides;
  auto srcShape = srcType.getShape();
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();

  // The result stride of a reassociation group is the stride of the last
  // entry of the reassociation. (TODO: Should be the minimum stride in the
  // reassociation because strides are not necessarily sorted, e.g., when
  // using memref.transpose.) Dimensions of size 1 should be skipped, because
  // their strides are meaningless and could have any arbitrary value.
  SmallVector<int64_t> resultStrides;
  resultStrides.reserve(reassociation.size());
  for (const ReassociationIndices &reassoc : reassociation) {
    ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
      resultStrides.push_back(srcStrides[ref.back()]);
    } else {
      // Dynamically-sized dims may turn out to be dims of size 1 at runtime,
      // so the corresponding stride may have to be skipped. (See above
      // comment.) Therefore, the result stride cannot be statically
      // determined and must be dynamic.
      resultStrides.push_back(ShapedType::kDynamic);
    }
  }

  // Validate that each reassociation group is contiguous.
  unsigned resultStrideIndex = resultStrides.size() - 1;
  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
    auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
    auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
    for (int64_t idx : llvm::reverse(trailingReassocs)) {
      stride = stride * SaturatedInteger::wrap(srcShape[idx]);

      // Both the source and result stride must have the same static value.
      // In that case, we can be sure that the dimensions are collapsible
      // (because they are contiguous).
      // If `strict = false` (default during op verification), we accept cases
      // where one or both strides are dynamic. This is best effort: we reject
      // ops where obviously non-contiguous dims are collapsed, but accept ops
      // where we cannot be sure statically. Such ops may fail at runtime. See
      // the op documentation for details.
      auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
      if (strict && (stride.saturated || srcStride.saturated))
        return failure();

      // Dimensions of size 1 should be skipped, because their strides are
      // meaningless and could have any arbitrary value.
      if (srcShape[idx - 1] == 1)
        continue;

      if (!stride.saturated && !srcStride.saturated && stride != srcStride)
        return failure();
    }
  }
  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
}

bool CollapseShapeOp::isGuaranteedCollapsible(
    MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
  // MemRefs with identity layout are always collapsible.
  if (srcType.getLayout().isIdentity())
    return true;

  return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
                                             /*strict=*/true));
}

2599

2600MemRefType CollapseShapeOp::computeCollapsedType(

2601 MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {

2602 SmallVector<int64_t> resultShape;

2603 resultShape.reserve(reassociation.size());

2604 for (const ReassociationIndices &group : reassociation) {

2605 auto groupSize = SaturatedInteger::wrap(1);

2606 for (int64_t srcDim : group)

2607 groupSize =

2608 groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));

2609 resultShape.push_back(groupSize.asInteger());

2610 }

2611

2612 if (srcType.getLayout().isIdentity()) {

2613

2614

2615 MemRefLayoutAttrInterface layout;

2616 return MemRefType::get(resultShape, srcType.getElementType(), layout,

2617 srcType.getMemorySpace());

2618 }

2619

2620

2621

2622

2623 FailureOr<StridedLayoutAttr> computedLayout =

2624 computeCollapsedLayoutMap(srcType, reassociation);

2625 assert(succeeded(computedLayout) &&

2626 "invalid source layout map or collapsing non-contiguous dims");

2627 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,

2628 srcType.getMemorySpace());

2629}
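The collapsed extent computation above reduces each reassociation group with a saturating product; a minimal standalone equivalent (illustrative name, kDynamic stand-in) is:

#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = INT64_MIN; // stand-in for ShapedType::kDynamic

// Product of the group's extents, saturating to dynamic if any member is
// dynamic: {4, kDynamic, 2} -> kDynamic, {4, 2} -> 8.
int64_t collapseGroupExtent(const std::vector<int64_t> &extents) {
  int64_t product = 1;
  for (int64_t e : extents) {
    if (e == kDynamic || product == kDynamic)
      product = kDynamic;
    else
      product *= e;
  }
  return product;
}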

2630

2631void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,

2632 ArrayRef<ReassociationIndices> reassociation,

2633 ArrayRef<NamedAttribute> attrs) {

2634 auto srcType = llvm::cast<MemRefType>(src.getType());

2635 MemRefType resultType =

2636 CollapseShapeOp::computeCollapsedType(srcType, reassociation);

2637 result.addAttribute(getReassociationAttrStrName(),

2638 getReassociationIndicesAttribute(b, reassociation));

2639 build(b, result, resultType, src, attrs);

2640}

2641

2642LogicalResult CollapseShapeOp::verify() {

2643 MemRefType srcType = getSrcType();

2644 MemRefType resultType = getResultType();

2645

2646 if (srcType.getRank() < resultType.getRank()) {

2647 auto r0 = srcType.getRank();

2648 auto r1 = resultType.getRank();

2649 return emitOpError("has source rank ")

2650 << r0 << " and result rank " << r1 << ". This is not a collapse ("

2651 << r0 << " < " << r1 << ").";

2652 }

2653

2654

2655 if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),

2656 srcType.getShape(), getReassociationIndices(),

2657 /*allowMultipleDynamicDimsPerGroup=*/true)))

2658 return failure();

2659

2660

2661 MemRefType expectedResultType;

2662 if (srcType.getLayout().isIdentity()) {

2663

2664

2665 MemRefLayoutAttrInterface layout;

2666 expectedResultType =

2667 MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,

2668 srcType.getMemorySpace());

2669 } else {

2670

2671

2672

2673 FailureOr<StridedLayoutAttr> computedLayout =

2674 computeCollapsedLayoutMap(srcType, getReassociationIndices());

2675 if (failed(computedLayout))

2676 return emitOpError(

2677 "invalid source layout map or collapsing non-contiguous dims");

2678 expectedResultType =

2679 MemRefType::get(resultType.getShape(), srcType.getElementType(),

2680 *computedLayout, srcType.getMemorySpace());

2681 }

2682

2683 if (expectedResultType != resultType)

2684 return emitOpError("expected collapsed type to be ")

2685 << expectedResultType << " but found " << resultType;

2686

2687 return success();

2688}

2689

2690class CollapseShapeOpMemRefCastFolder final

2691 : public OpRewritePattern<CollapseShapeOp> {

2692public:

2693 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;

2694

2695 LogicalResult matchAndRewrite(CollapseShapeOp op,

2696 PatternRewriter &rewriter) const override {

2697 auto cast = op.getOperand().getDefiningOp<CastOp>();

2698 if (!cast)

2699 return failure();

2700

2701 if (!CastOp::canFoldIntoConsumerOp(cast))

2702 return failure();

2703

2704 Type newResultType = CollapseShapeOp::computeCollapsedType(

2705 llvm::cast<MemRefType>(cast.getOperand().getType()),

2706 op.getReassociationIndices());

2707

2708 if (newResultType == op.getResultType()) {

2709 rewriter.modifyOpInPlace(

2710 op, [&]() { op.getSrcMutable().assign(cast.getSource()); });

2711 } else {

2712 Value newOp =

2713 CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(),

2714 op.getReassociationIndices());

2715 rewriter.replaceOpWithNewOp<CastOp>(op, op.getResultType(), newOp);

2716 }

2717 return success();

2718 }

2719};

2720

2721void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,

2722 MLIRContext *context) {

2723 results.add<

2724 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,

2725 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,

2726 memref::DimOp, MemRefType>,

2727 CollapseShapeOpMemRefCastFolder>(context);

2728}

2729

2730OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {

2731 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,

2732 adaptor.getOperands());

2733}

2734

2735OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {

2736 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,

2737 adaptor.getOperands());

2738}

2739

2740FailureOr<std::optional<SmallVector<Value>>>

2741CollapseShapeOp::bubbleDownCasts(OpBuilder &builder) {

2742 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());

2743}

2744

2745

2746

2747

2748

2749void ReshapeOp::getAsmResultNames(

2750 function_ref<void(Value, StringRef)> setNameFn) {

2751 setNameFn(getResult(), "reshape");

2752}

2753

2754LogicalResult ReshapeOp::verify() {

2755 Type operandType = getSource().getType();

2756 Type resultType = getResult().getType();

2757

2758 Type operandElementType =

2759 llvm::cast<ShapedType>(operandType).getElementType();

2760 Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();

2761 if (operandElementType != resultElementType)

2762 return emitOpError("element types of source and destination memref "

2763 "types should be the same");

2764

2765 if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))

2766 if (!operandMemRefType.getLayout().isIdentity())

2767 return emitOpError("source memref type should have identity affine map");

2768

2769 int64_t shapeSize =

2770 llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);

2771 auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);

2772 if (resultMemRefType) {

2773 if (!resultMemRefType.getLayout().isIdentity())

2774 return emitOpError("result memref type should have identity affine map");

2775 if (shapeSize == ShapedType::kDynamic)

2776 return emitOpError("cannot use shape operand with dynamic length to "

2777 "reshape to statically-ranked memref type");

2778 if (shapeSize != resultMemRefType.getRank())

2780 "length of shape operand differs from the result's memref rank");

2781 }

2782 return success();

2783}
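For illustration only (hypothetical helper, kDynamic stand-in), the shape-operand checks above reduce to the following standalone predicate: the shape operand is a rank-1 memref whose single extent gives the result rank, so a ranked result requires that extent to be static and to match.

#include <cstdint>

constexpr int64_t kDynamic = INT64_MIN; // stand-in for ShapedType::kDynamic

bool shapeOperandMatchesResult(int64_t shapeOperandDim, int64_t resultRank) {
  if (shapeOperandDim == kDynamic)
    return false; // cannot reshape to a statically ranked type
  return shapeOperandDim == resultRank;
}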

2784

2785FailureOr<std::optional<SmallVector<Value>>>

2786ReshapeOp::bubbleDownCasts(OpBuilder &builder) {

2787 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());

2788}

2789

2790

2791

2792

2793

2794LogicalResult StoreOp::verify() {

2795 if (getNumOperands() != 2 + getMemRefType().getRank())

2796 return emitOpError("store index operand count not equal to memref rank");

2797

2798 return success();

2799}

2800

2801LogicalResult StoreOp::fold(FoldAdaptor adaptor,

2802 SmallVectorImpl<OpFoldResult> &results) {

2803

2804 return foldMemRefCast(*this, getValueToStore());

2805}

2806

2807FailureOr<std::optional<SmallVector<Value>>>

2808StoreOp::bubbleDownCasts(OpBuilder &builder) {

2809 return bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),

2810 ValueRange());

2811}

2812

2813

2814

2815

2816

2817void SubViewOp::getAsmResultNames(

2818 function_ref<void(Value, StringRef)> setNameFn) {

2819 setNameFn(getResult(), "subview");

2820}

2821

2822

2823

2824

2825MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,

2826 ArrayRef<int64_t> staticOffsets,

2827 ArrayRef<int64_t> staticSizes,

2828 ArrayRef<int64_t> staticStrides) {

2829 unsigned rank = sourceMemRefType.getRank();

2830 (void)rank;

2831 assert(staticOffsets.size() == rank && "staticOffsets length mismatch");

2832 assert(staticSizes.size() == rank && "staticSizes length mismatch");

2833 assert(staticStrides.size() == rank && "staticStrides length mismatch");

2834

2835

2836 auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();

2837

2838

2839

2840 int64_t targetOffset = sourceOffset;

2841 for (auto it : llvm::zip(staticOffsets, sourceStrides)) {

2842 auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);

2843 targetOffset = (SaturatedInteger::wrap(targetOffset) +

2844 SaturatedInteger::wrap(staticOffset) *

2845 SaturatedInteger::wrap(sourceStride))

2846 .asInteger();

2847 }

2848

2849

2850

2851 SmallVector<int64_t, 4> targetStrides;

2852 targetStrides.reserve(staticOffsets.size());

2853 for (auto it : llvm::zip(sourceStrides, staticStrides)) {

2854 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);

2855 targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *

2856 SaturatedInteger::wrap(staticStride))

2857 .asInteger());

2858 }

2859

2860

2861 return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),

2862 StridedLayoutAttr::get(sourceMemRefType.getContext(),

2863 targetOffset, targetStrides),

2864 sourceMemRefType.getMemorySpace());

2865}
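The two zips above implement plain strided-layout arithmetic. A standalone sketch of the same math, with saturating helpers standing in for SaturatedInteger (all names hypothetical):

#include <cstddef>
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = INT64_MIN; // stand-in for ShapedType::kDynamic

// Saturating helpers: any dynamic operand makes the result dynamic.
int64_t satMul(int64_t a, int64_t b) {
  return (a == kDynamic || b == kDynamic) ? kDynamic : a * b;
}
int64_t satAdd(int64_t a, int64_t b) {
  return (a == kDynamic || b == kDynamic) ? kDynamic : a + b;
}

// targetOffset   = srcOffset + sum_i offsets[i] * srcStrides[i]
// targetStride_i = srcStrides[i] * strides[i]
void inferSubViewLayout(int64_t srcOffset,
                        const std::vector<int64_t> &srcStrides,
                        const std::vector<int64_t> &offsets,
                        const std::vector<int64_t> &strides,
                        int64_t &targetOffset,
                        std::vector<int64_t> &targetStrides) {
  targetOffset = srcOffset;
  for (size_t i = 0; i < srcStrides.size(); ++i)
    targetOffset = satAdd(targetOffset, satMul(offsets[i], srcStrides[i]));
  targetStrides.clear();
  for (size_t i = 0; i < srcStrides.size(); ++i)
    targetStrides.push_back(satMul(srcStrides[i], strides[i]));
}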

2866

2867MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,

2868 ArrayRef<OpFoldResult> offsets,

2869 ArrayRef<OpFoldResult> sizes,

2870 ArrayRef<OpFoldResult> strides) {

2871 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;

2872 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;

2873 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

2874 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);

2875 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

2876 if (!hasValidSizesOffsets(staticOffsets))

2877 return {};

2878 if (!hasValidSizesOffsets(staticSizes))

2879 return {};

2880 if (!hasValidStrides(staticStrides))

2881 return {};

2882 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,

2883 staticSizes, staticStrides);

2884}

2885

2886MemRefType SubViewOp::inferRankReducedResultType(

2887 ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,

2888 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,

2889 ArrayRef<int64_t> strides) {

2890 MemRefType inferredType =

2891 inferResultType(sourceRankedTensorType, offsets, sizes, strides);

2892 assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&

2893 "expected ");

2894 if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))

2895 return inferredType;

2896

2897

2898 std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =

2899 computeRankReductionMask(inferredType.getShape(), resultShape);

2900 assert(dimsToProject.has_value() && "invalid rank reduction");

2901

2902

2903 auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());

2904 SmallVector<int64_t> rankReducedStrides;

2905 rankReducedStrides.reserve(resultShape.size());

2906 for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {

2907 if (!dimsToProject->contains(idx))

2908 rankReducedStrides.push_back(value);

2909 }

2910 return MemRefType::get(resultShape, inferredType.getElementType(),

2911 StridedLayoutAttr::get(inferredLayout.getContext(),

2912 inferredLayout.getOffset(),

2913 rankReducedStrides),

2914 inferredType.getMemorySpace());

2915}
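A simplified, standalone take on the rank-reduction mask computed via computeRankReductionMask: match the reduced shape greedily and allow only unit dims to be dropped. This is a sketch, not the exact MLIR algorithm (which also supports dynamic matching), and all names are illustrative.

#include <cstddef>
#include <cstdint>
#include <optional>
#include <set>
#include <vector>

std::optional<std::set<size_t>>
rankReductionMask(const std::vector<int64_t> &originalShape,
                  const std::vector<int64_t> &reducedShape) {
  std::set<size_t> dropped;
  size_t j = 0;
  for (size_t i = 0; i < originalShape.size(); ++i) {
    if (j < reducedShape.size() && originalShape[i] == reducedShape[j]) {
      ++j; // this dim is kept
      continue;
    }
    if (originalShape[i] != 1)
      return std::nullopt; // only unit dims may be dropped
    dropped.insert(i);
  }
  if (j != reducedShape.size())
    return std::nullopt; // reduced shape is not a subsequence
  return dropped;
}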

2916

2917MemRefType SubViewOp::inferRankReducedResultType(

2918 ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,

2919 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,

2920 ArrayRef<OpFoldResult> strides) {

2921 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;

2922 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;

2923 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

2924 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);

2925 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

2926 return SubViewOp::inferRankReducedResultType(

2927 resultShape, sourceRankedTensorType, staticOffsets, staticSizes,

2928 staticStrides);

2929}

2930

2931

2932

2933void SubViewOp::build(OpBuilder &b, OperationState &result,

2934 MemRefType resultType, Value source,

2935 ArrayRef<OpFoldResult> offsets,

2936 ArrayRef<OpFoldResult> sizes,

2937 ArrayRef<OpFoldResult> strides,

2938 ArrayRef<NamedAttribute> attrs) {

2939 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;

2940 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;

2941 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

2942 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);

2943 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

2944 auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());

2945

2946 if (!resultType) {

2947 resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,

2948 staticSizes, staticStrides);

2949 }

2950 result.addAttributes(attrs);

2951 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,

2952 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),

2953 b.getDenseI64ArrayAttr(staticSizes),

2954 b.getDenseI64ArrayAttr(staticStrides));

2955}

2956

2957

2958

2959void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,

2960 ArrayRef<OpFoldResult> offsets,

2961 ArrayRef<OpFoldResult> sizes,

2962 ArrayRef<OpFoldResult> strides,

2963 ArrayRef<NamedAttribute> attrs) {

2964 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);

2965}

2966

2967

2968void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,

2969 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,

2970 ArrayRef<int64_t> strides,

2971 ArrayRef<NamedAttribute> attrs) {

2972 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(

2973 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {

2974 return b.getI64IntegerAttr(v);

2975 }));

2976 SmallVector<OpFoldResult> sizeValues =

2977 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {

2978 return b.getI64IntegerAttr(v);

2979 }));

2980 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(

2981 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {

2982 return b.getI64IntegerAttr(v);

2983 }));

2984 build(b, result, source, offsetValues, sizeValues, strideValues, attrs);

2985}

2986

2987

2988

2989void SubViewOp::build(OpBuilder &b, OperationState &result,

2990 MemRefType resultType, Value source,

2991 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,

2992 ArrayRef<int64_t> strides,

2993 ArrayRef<NamedAttribute> attrs) {

2994 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(

2995 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {

2996 return b.getI64IntegerAttr(v);

2997 }));

2998 SmallVector<OpFoldResult> sizeValues =

2999 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {

3000 return b.getI64IntegerAttr(v);

3001 }));

3002 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(

3003 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {

3004 return b.getI64IntegerAttr(v);

3005 }));

3006 build(b, result, resultType, source, offsetValues, sizeValues, strideValues,

3007 attrs);

3008}

3009

3010

3011

3012void SubViewOp::build(OpBuilder &b, OperationState &result,

3013 MemRefType resultType, Value source, ValueRange offsets,

3014 ValueRange sizes, ValueRange strides,

3015 ArrayRef<NamedAttribute> attrs) {

3016 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(

3017 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));

3018 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(

3019 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));

3020 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(

3021 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));

3022 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);

3023}

3024

3025

3026void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,

3027 ValueRange offsets, ValueRange sizes, ValueRange strides,

3028 ArrayRef<NamedAttribute> attrs) {

3029 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);

3030}

3031

3032

3033Value SubViewOp::getViewSource() { return getSource(); }

3034

3035

3036

3037static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {

3038 int64_t t1Offset, t2Offset;

3039 SmallVector<int64_t> t1Strides, t2Strides;

3040 auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);

3041 auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);

3042 return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;

3043}

3044

3045

3046

3047

3048static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,

3049 const llvm::SmallBitVector &droppedDims) {

3050 assert(size_t(t1.getRank()) == droppedDims.size() &&

3051 "incorrect number of bits");

3052 assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&

3053 "incorrect number of dropped dims");

3054 int64_t t1Offset, t2Offset;

3055 SmallVector<int64_t> t1Strides, t2Strides;

3056 auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);

3057 auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);

3058 if (failed(res1) || failed(res2))

3059 return false;

3060 for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {

3061 if (droppedDims[i])

3062 continue;

3063 if (t1Strides[i] != t2Strides[j])

3064 return false;

3065 ++j;

3066 }

3067 return true;

3068}

3069

3070static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,

3071 SubViewOp op, Type expectedType) {

3072 auto memrefType = llvm::cast<ShapedType>(expectedType);

3073 switch (result) {

3074 case SliceVerificationResult::Success:

3075 return success();

3076 case SliceVerificationResult::RankTooLarge:

3077 return op->emitError("expected result rank to be smaller or equal to ")

3078 << "the source rank, but got " << op.getType();

3080 return op->emitError("expected result type to be ")

3081 << expectedType

3082 << " or a rank-reduced version. (mismatch of result sizes), but got "

3083 << op.getType();

3085 return op->emitError("expected result element type to be ")

3086 << memrefType.getElementType() << ", but got " << op.getType();

3087 case SliceVerificationResult::MemSpaceMismatch:

3088 return op->emitError(

3089 "expected result and source memory spaces to match, but got ")

3090 << op.getType();

3092 return op->emitError("expected result type to be ")

3093 << expectedType

3094 << " or a rank-reduced version. (mismatch of result layout), but "

3095 "got "

3096 << op.getType();

3097 }

3098 llvm_unreachable("unexpected subview verification result");

3099}

3100

3101

3102LogicalResult SubViewOp::verify() {

3103 MemRefType baseType = getSourceType();

3104 MemRefType subViewType = getType();

3105 ArrayRef<int64_t> staticOffsets = getStaticOffsets();

3106 ArrayRef<int64_t> staticSizes = getStaticSizes();

3107 ArrayRef<int64_t> staticStrides = getStaticStrides();

3108

3109

3110 if (baseType.getMemorySpace() != subViewType.getMemorySpace())

3111 return emitError("different memory spaces specified for base memref "

3112 "type ")

3113 << baseType << " and subview memref type " << subViewType;

3114

3115

3116 if (!baseType.isStrided())

3117 return emitError("base type ") << baseType << " is not strided";

3118

3119

3120

3121 MemRefType expectedType = SubViewOp::inferResultType(

3122 baseType, staticOffsets, staticSizes, staticStrides);

3123

3124

3125

3126 SliceVerificationResult result = isRankReducedType(

3127 expectedType, subViewType);

3128 if (result != SliceVerificationResult::Success)

3129 return produceSubViewErrorMsg(result, *this, expectedType);

3130

3131

3132 if (expectedType.getMemorySpace() != subViewType.getMemorySpace())

3133 return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,

3134 *this, expectedType);

3135

3136

3137 if (!haveCompatibleOffsets(expectedType, subViewType))

3138 return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,

3139 *this, expectedType);

3140

3141

3142

3143

3144

3145 FailureOr<llvm::SmallBitVector> unusedDims =

3146 computeMemRefRankReductionMask(expectedType, subViewType, getMixedSizes());

3147 if (failed(unusedDims))

3148 return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,

3149 *this, expectedType);

3150

3151

3152 if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))

3153 return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,

3154 *this, expectedType);

3155

3156

3157

3158 SliceBoundsVerificationResult boundsResult =

3159 verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,

3160 staticStrides, /*generateErrorMessage=*/true);

3161 if (!boundsResult.isValid)

3162 return getOperation()->emitError(boundsResult.errorMessage);

3163

3164 return success();

3165}
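The in-bounds rule that verifyInBoundsSlice checks can be sketched standalone as follows; this is an approximation assuming the last-accessed-element rule offset + (size - 1) * stride < dimSize, with dynamic values deferred to runtime (names hypothetical).

#include <cstddef>
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = INT64_MIN; // stand-in for ShapedType::kDynamic

bool sliceInBounds(const std::vector<int64_t> &shape,
                   const std::vector<int64_t> &offsets,
                   const std::vector<int64_t> &sizes,
                   const std::vector<int64_t> &strides) {
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == kDynamic || offsets[i] == kDynamic ||
        sizes[i] == kDynamic || strides[i] == kDynamic)
      continue; // dynamic values can only be checked at runtime
    if (offsets[i] < 0 || sizes[i] < 0)
      return false;
    if (sizes[i] > 0 && offsets[i] + (sizes[i] - 1) * strides[i] >= shape[i])
      return false; // last accessed element is out of bounds
  }
  return true;
}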

3166

3168 return os << "range " << range.offset << ":" << range.size << ":"

3170}

3171

3172

3173

3174

3175SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,

3176 OpBuilder &b, Location loc) {

3177 std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();

3178 assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");

3179 assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");

3180 SmallVector<Range, 8> res;

3181 unsigned rank = ranks[0];

3182 res.reserve(rank);

3183 for (unsigned idx = 0; idx < rank; ++idx) {

3184 Value offset =

3185 op.isDynamicOffset(idx)

3186 ? op.getDynamicOffset(idx)

3187 : arith::ConstantIndexOp::create(b, loc, op.getStaticOffset(idx));

3188 Value size =

3189 op.isDynamicSize(idx)

3190 ? op.getDynamicSize(idx)

3191 : arith::ConstantIndexOp::create(b, loc, op.getStaticSize(idx));

3192 Value stride =

3193 op.isDynamicStride(idx)

3194 ? op.getDynamicStride(idx)

3195 : arith::ConstantIndexOp::create(b, loc, op.getStaticStride(idx));

3196 res.emplace_back(Range{offset, size, stride});

3197 }

3198 return res;

3199}

3200

3201

3202

3203

3204

3205

3206

3207

3208static MemRefType getCanonicalSubViewResultType(

3209 MemRefType currentResultType, MemRefType currentSourceType,

3210 MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,

3211 ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {

3212 MemRefType nonRankReducedType = SubViewOp::inferResultType(

3213 sourceType, mixedOffsets, mixedSizes, mixedStrides);

3214 FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(

3215 currentSourceType, currentResultType, mixedSizes);

3216 if (failed(unusedDims))

3217 return nullptr;

3218

3219 auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());

3220 SmallVector<int64_t> shape, strides;

3221 unsigned numDimsAfterReduction =

3222 nonRankReducedType.getRank() - unusedDims->count();

3223 shape.reserve(numDimsAfterReduction);

3224 strides.reserve(numDimsAfterReduction);

3225 for (const auto &[idx, size, stride] :

3226 llvm::zip(llvm::seq(0, nonRankReducedType.getRank()),

3227 nonRankReducedType.getShape(), layout.getStrides())) {

3228 if (unusedDims->test(idx))

3229 continue;

3230 shape.push_back(size);

3231 strides.push_back(stride);

3232 }

3233

3234 return MemRefType::get(shape, nonRankReducedType.getElementType(),

3235 StridedLayoutAttr::get(sourceType.getContext(),

3236 layout.getOffset(), strides),

3237 nonRankReducedType.getMemorySpace());

3238}

3239

3240Value mlir::memref::createCanonicalRankReducingSubViewOp(

3241 OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {

3242 auto memrefType = llvm::cast<MemRefType>(memref.getType());

3243 unsigned rank = memrefType.getRank();

3244 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));

3245 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);

3246 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));

3247 MemRefType targetType = SubViewOp::inferRankReducedResultType(

3248 targetShape, memrefType, offsets, sizes, strides);

3249 return b.createOrFoldmemref::SubViewOp(loc, targetType, memref, offsets,

3250 sizes, strides);

3251}

3252

3253FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,

3254 Value value,

3255 ArrayRef<int64_t> desiredShape) {

3256 auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());

3257 assert(sourceMemrefType && "not a ranked memref type");

3258 auto sourceShape = sourceMemrefType.getShape();

3259 if (sourceShape.equals(desiredShape))

3260 return value;

3261 auto maybeRankReductionMask =

3262 computeRankReductionMask(sourceShape, desiredShape);

3263 if (!maybeRankReductionMask)

3264 return failure();

3265 return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);

3266}

3267

3268

3269

3270

3271

3273 if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())

3274 return false;

3275

3276 auto mixedOffsets = subViewOp.getMixedOffsets();

3277 auto mixedSizes = subViewOp.getMixedSizes();

3278 auto mixedStrides = subViewOp.getMixedStrides();

3279

3280

3281 if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {

3282 std::optional<int64_t> intValue = getConstantIntValue(ofr);

3283 return !intValue || intValue.value() != 0;

3284 }))

3285 return false;

3286

3287

3288 if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {

3289 std::optional<int64_t> intValue = getConstantIntValue(ofr);

3290 return !intValue || intValue.value() != 1;

3291 }))

3292 return false;

3293

3294

3295 ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();

3296 for (const auto &size : llvm::enumerate(mixedSizes)) {

3297 std::optional<int64_t> intValue = getConstantIntValue(size.value());

3298 if (!intValue || *intValue != sourceShape[size.index()])

3299 return false;

3300 }

3301

3302 return true;

3303}
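Stripped of the OpFoldResult plumbing, the predicate is simply (standalone sketch, illustrative names):

#include <cstddef>
#include <cstdint>
#include <vector>

// No rank reduction, zero offsets, unit strides, full extents: the subview
// returns exactly its source.
bool isTrivialSlice(const std::vector<int64_t> &sourceShape,
                    const std::vector<int64_t> &offsets,
                    const std::vector<int64_t> &sizes,
                    const std::vector<int64_t> &strides) {
  for (size_t i = 0; i < sourceShape.size(); ++i)
    if (offsets[i] != 0 || strides[i] != 1 || sizes[i] != sourceShape[i])
      return false;
  return true;
}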

3304

3305namespace {

3306

3307

3308

3309

3310

3311

3312

3313

3314

3315

3316

3317

3318

3319

3320

3321

3322class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {

3323public:

3324 using OpRewritePattern<SubViewOp>::OpRewritePattern;

3325

3326 LogicalResult matchAndRewrite(SubViewOp subViewOp,

3327 PatternRewriter &rewriter) const override {

3328

3329

3330 if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {

3331 return matchPattern(operand, matchConstantIndex());

3332 }))

3333 return failure();

3334

3335 auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();

3336 if (!castOp)

3337 return failure();

3338

3339 if (!CastOp::canFoldIntoConsumerOp(castOp))

3340 return failure();

3341

3342

3343

3344

3345

3346 MemRefType resultType = getCanonicalSubViewResultType(

3347 subViewOp.getType(), subViewOp.getSourceType(),

3348 llvm::cast<MemRefType>(castOp.getSource().getType()),

3349 subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),

3350 subViewOp.getMixedStrides());

3351 if (!resultType)

3352 return failure();

3353

3354 Value newSubView = SubViewOp::create(

3355 rewriter, subViewOp.getLoc(), resultType, castOp.getSource(),

3356 subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),

3357 subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),

3358 subViewOp.getStaticStrides());

3359 rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),

3360 newSubView);

3361 return success();

3362 }

3363};

3364

3365

3366

3367class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {

3368public:

3369 using OpRewritePattern<SubViewOp>::OpRewritePattern;

3370

3371 LogicalResult matchAndRewrite(SubViewOp subViewOp,

3372 PatternRewriter &rewriter) const override {

3373 if (!isTrivialSubViewOp(subViewOp))

3374 return failure();

3375 if (subViewOp.getSourceType() == subViewOp.getType()) {

3376 rewriter.replaceOp(subViewOp, subViewOp.getSource());

3377 return success();

3378 }

3379 rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),

3380 subViewOp.getSource());

3381 return success();

3382 }

3383};

3384}

3385

3386

3387struct SubViewReturnTypeCanonicalizer {

3388 MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,

3389 ArrayRef<OpFoldResult> mixedSizes,

3390 ArrayRef<OpFoldResult> mixedStrides) {

3391

3392 MemRefType resTy = SubViewOp::inferResultType(

3393 op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);

3394 if (!resTy)

3395 return {};

3396 MemRefType nonReducedType = resTy;

3397

3398

3399 llvm::SmallBitVector droppedDims = op.getDroppedDims();

3400 if (droppedDims.none())

3401 return nonReducedType;

3402

3403

3404 auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();

3405

3406

3409 for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {

3410 if (droppedDims.test(i))

3411 continue;

3412 targetStrides.push_back(nonReducedStrides[i]);

3413 targetShape.push_back(nonReducedType.getDimSize(i));

3414 }

3415

3416 return MemRefType::get(targetShape, nonReducedType.getElementType(),

3417 StridedLayoutAttr::get(nonReducedType.getContext(),

3418 offset, targetStrides),

3419 nonReducedType.getMemorySpace());

3420 }

3421};

3422

3423

3424struct SubViewCanonicalizer {

3425 void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {

3426 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);

3427 }

3428};

3429

3430void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,

3431 MLIRContext *context) {

3432 results

3433 .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<

3434 SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,

3435 SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);

3436}

3437

3438OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {

3439 MemRefType sourceMemrefType = getSource().getType();

3440 MemRefType resultMemrefType = getResult().getType();

3441 auto resultLayout =

3442 dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());

3443

3444 if (resultMemrefType == sourceMemrefType &&

3445 resultMemrefType.hasStaticShape() &&

3446 (!resultLayout || resultLayout.hasStaticLayout())) {

3447 return getViewSource();

3448 }

3449

3450

3451

3452

3453 if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {

3454 auto srcSizes = srcSubview.getMixedSizes();

3455 auto sizes = getMixedSizes();

3456 auto offsets = getMixedOffsets();

3457 bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);

3458 auto strides = getMixedStrides();

3459 bool allStridesOne = llvm::all_of(strides, isOneInteger);

3460 bool allSizesSame = llvm::equal(sizes, srcSizes);

3461 if (allOffsetsZero && allStridesOne && allSizesSame &&

3462 resultMemrefType == sourceMemrefType)

3463 return getViewSource();

3464 }

3465

3466 return {};

3467}

3468

3469FailureOr<std::optional<SmallVector<Value>>>

3470SubViewOp::bubbleDownCasts(OpBuilder &builder) {

3471 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());

3472}

3473

3474void SubViewOp::inferStridedMetadataRanges(

3475 ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange,

3476 SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) {

3477 auto isUninitialized =

3478 +[](IntegerValueRange range) { return range.isUninitialized(); };

3479

3480

3481 SmallVector<IntegerValueRange> offsetOperands =

3482 getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth);

3483 if (llvm::any_of(offsetOperands, isUninitialized))

3484 return;

3485

3486 SmallVector<IntegerValueRange> sizeOperands =

3487 getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth);

3488 if (llvm::any_of(sizeOperands, isUninitialized))

3489 return;

3490

3491 SmallVector<IntegerValueRange> stridesOperands =

3492 getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth);

3493 if (llvm::any_of(stridesOperands, isUninitialized))

3494 return;

3495

3496 StridedMetadataRange sourceRange =

3497 ranges[getSourceMutable().getOperandNumber()];

3499 return;

3500

3501 ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides();

3502

3503

3504 llvm::SmallBitVector droppedDims = getDroppedDims();

3505

3506

3507 ConstantIntRanges offset = sourceRange.getOffsets()[0];

3508 SmallVector<ConstantIntRanges> strides, sizes;

3509

3510 for (size_t i = 0, e = droppedDims.size(); i < e; ++i) {

3511 bool dropped = droppedDims.test(i);

3512

3513 ConstantIntRanges off =

3516

3517

3518 if (dropped)

3519 continue;

3520

3521 strides.push_back(

3522 intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]}));

3523

3524 sizes.push_back(sizeOperands[i].getValue());

3525 }

3526

3527 setMetadata(getResult(),

3528 StridedMetadataRange::getRanked(

3529 SmallVector<ConstantIntRanges>({std::move(offset)}),

3530 std::move(sizes), std::move(strides)));

3531}

3532

3533

3534

3535

3536

3537void TransposeOp::getAsmResultNames(

3538 function_ref<void(Value, StringRef)> setNameFn) {

3539 setNameFn(getResult(), "transpose");

3540}

3541

3542

3543static MemRefType inferTransposeResultType(MemRefType memRefType,

3544 AffineMap permutationMap) {

3545 auto originalSizes = memRefType.getShape();

3546 auto [originalStrides, offset] = memRefType.getStridesAndOffset();

3547 assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));

3548

3549

3550 auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);

3551 auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);

3552

3553 return MemRefType::Builder(memRefType)

3554 .setShape(sizes)

3555 .setLayout(

3556 StridedLayoutAttr::get(memRefType.getContext(), offset, strides));

3557}
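A standalone sketch of the size/stride permutation performed above (using the convention perm[i] = source dim of result dim i; all names illustrative):

#include <cstdint>
#include <vector>

// Sizes and strides move together while the offset is untouched, so a
// transpose never moves data; it only relabels the layout.
void permuteLayout(const std::vector<unsigned> &perm,
                   const std::vector<int64_t> &sizes,
                   const std::vector<int64_t> &strides,
                   std::vector<int64_t> &outSizes,
                   std::vector<int64_t> &outStrides) {
  outSizes.clear();
  outStrides.clear();
  for (unsigned srcDim : perm) {
    outSizes.push_back(sizes[srcDim]);
    outStrides.push_back(strides[srcDim]);
  }
}
// e.g. sizes {2, 3}, strides {3, 1}, perm {1, 0} -> sizes {3, 2}, strides {1, 3}.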

3558

3559void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,

3560 AffineMapAttr permutation,

3561 ArrayRef attrs) {

3562 auto permutationMap = permutation.getValue();

3563 assert(permutationMap);

3564

3565 auto memRefType = llvm::cast<MemRefType>(in.getType());

3566

3567 auto resultType = inferTransposeResultType(memRefType, permutationMap);

3568

3569 result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);

3570 build(b, result, resultType, in, attrs);

3571}

3572

3573

3574void TransposeOp::print(OpAsmPrinter &p) {

3575 p << " " << getIn() << " " << getPermutation();

3577 p << " : " << getIn().getType() << " to " << getType();

3578}

3579

3580ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {

3581 OpAsmParser::UnresolvedOperand in;

3582 AffineMap permutation;

3583 MemRefType srcType, dstType;

3584 if (parser.parseOperand(in) || parser.parseKeyword("permutation") ||

3585 parser.parseAffineMap(permutation) ||

3586 parser.parseOptionalAttrDict(result.attributes) ||

3587 parser.parseColonType(srcType) ||

3588 parser.resolveOperand(in, srcType, result.operands) ||

3589 parser.parseKeywordType("to", dstType) || parser.addTypeToList(dstType, result.types))

3590 return failure();

3591

3592 result.addAttribute(TransposeOp::getPermutationAttrStrName(),

3593 AffineMapAttr::get(permutation));

3594 return success();

3595}

3596

3597LogicalResult TransposeOp::verify() {

3599 return emitOpError("expected a permutation map");

3600 if (getPermutation().getNumDims() != getIn().getType().getRank())

3601 return emitOpError("expected a permutation map of same rank as the input");

3602

3603 auto srcType = llvm::cast<MemRefType>(getIn().getType());

3604 auto resultType = llvm::cast<MemRefType>(getType());

3605 auto canonicalResultType = inferTransposeResultType(srcType, getPermutation())

3606 .canonicalizeStridedLayout();

3607

3608 if (resultType.canonicalizeStridedLayout() != canonicalResultType)

3609 return emitOpError("result type ")

3610 << resultType

3611 << " is not equivalent to the canonical transposed input type "

3612 << canonicalResultType;

3613 return success();

3614}

3615

3616OpFoldResult TransposeOp::fold(FoldAdaptor) {

3617

3618

3619 if (getPermutation().isIdentity() && getType() == getIn().getType())

3620 return getIn();

3621

3622

3623 if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {

3624 AffineMap composedPermutation =

3625 getPermutation().compose(otherTransposeOp.getPermutation());

3626 getInMutable().assign(otherTransposeOp.getIn());

3627 setPermutation(composedPermutation);

3628 return getResult();

3629 }

3630 return {};

3631}
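The transpose-of-transpose fold composes the two permutation maps. In plain index vectors (same convention as the previous sketch, names illustrative), composition looks like:

#include <cstddef>
#include <vector>

// With perm[i] = source dim of result dim i, result dim i of the outer
// transpose reads dim inner[outer[i]] of the original value.
std::vector<unsigned> composePerms(const std::vector<unsigned> &outer,
                                   const std::vector<unsigned> &inner) {
  std::vector<unsigned> composed(outer.size());
  for (size_t i = 0; i < outer.size(); ++i)
    composed[i] = inner[outer[i]];
  return composed;
}
// e.g. {1, 0} composed with {1, 0} yields the identity {0, 1}, which the
// identity-permutation fold above then eliminates entirely.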

3632

3633FailureOr<std::optional<SmallVector<Value>>>

3634TransposeOp::bubbleDownCasts(OpBuilder &builder) {

3635 return bubbleDownCastsPassthroughOpImpl(*this, builder, getInMutable());

3636}

3637

3638

3639

3640

3641

3642void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {

3643 setNameFn(getResult(), "view");

3644}

3645

3646LogicalResult ViewOp::verify() {

3647 auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());

3648 auto viewType = getType();

3649

3650

3651 if (!baseType.getLayout().isIdentity())

3652 return emitError("unsupported map for base memref type ") << baseType;

3653

3654

3655 if (!viewType.getLayout().isIdentity())

3656 return emitError("unsupported map for result memref type ") << viewType;

3657

3658

3659 if (baseType.getMemorySpace() != viewType.getMemorySpace())

3660 return emitError("different memory spaces specified for base memref "

3661 "type ")

3662 << baseType << " and view memref type " << viewType;

3663

3664

3665 unsigned numDynamicDims = viewType.getNumDynamicDims();

3666 if (getSizes().size() != numDynamicDims)

3667 return emitError("incorrect number of size operands for type ") << viewType;

3668

3669 return success();

3670}

3671

3672Value ViewOp::getViewSource() { return getSource(); }

3673

3674OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {

3675 MemRefType sourceMemrefType = getSource().getType();

3676 MemRefType resultMemrefType = getResult().getType();

3677

3678 if (resultMemrefType == sourceMemrefType && resultMemrefType.hasStaticShape())

3679 return getViewSource();

3680

3681 return {};

3682}

3683

3684namespace {

3685

3686struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {

3687 using OpRewritePattern<ViewOp>::OpRewritePattern;

3688

3689 LogicalResult matchAndRewrite(ViewOp viewOp,

3690 PatternRewriter &rewriter) const override {

3691

3692 if (llvm::none_of(viewOp.getOperands(), [](Value operand) {

3693 return matchPattern(operand, matchConstantIndex());

3694 }))

3695 return failure();

3696

3697

3698 auto memrefType = viewOp.getType();

3699

3700

3701 int64_t oldOffset;

3702 SmallVector<int64_t, 4> oldStrides;

3703 if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))

3704 return failure();

3705 assert(oldOffset == 0 && "Expected 0 offset");

3706

3707 SmallVector<Value, 4> newOperands;

3708

3709

3710

3711

3712 SmallVector<int64_t, 4> newShapeConstants;

3713 newShapeConstants.reserve(memrefType.getRank());

3714

3715 unsigned dynamicDimPos = 0;

3716 unsigned rank = memrefType.getRank();

3717 for (unsigned dim = 0, e = rank; dim < e; ++dim) {

3718 int64_t dimSize = memrefType.getDimSize(dim);

3719

3720 if (ShapedType::isStatic(dimSize)) {

3721 newShapeConstants.push_back(dimSize);

3722 continue;

3723 }

3724 auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();

3725 if (auto constantIndexOp =

3726 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {

3727

3728 newShapeConstants.push_back(constantIndexOp.value());

3729 } else {

3730

3731 newShapeConstants.push_back(dimSize);

3732 newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);

3733 }

3734 dynamicDimPos++;

3735 }

3736

3737

3738 MemRefType newMemRefType =

3739 MemRefType::Builder(memrefType).setShape(newShapeConstants);

3740

3741 if (newMemRefType == memrefType)

3742 return failure();

3743

3744

3745 auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), newMemRefType,

3746 viewOp.getOperand(0), viewOp.getByteShift(),

3747 newOperands);

3748

3749 rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);

3750 return success();

3751 }

3752};
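The shape folding done by this pattern, reduced to a standalone sketch (kDynamic stand-in, hypothetical names): static extents pass through, and each dynamic extent is replaced when its size operand folded to a constant, exactly mirroring the loop above.

#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

constexpr int64_t kDynamic = INT64_MIN; // stand-in for ShapedType::kDynamic

std::vector<int64_t>
foldConstantSizes(const std::vector<int64_t> &shape,
                  const std::vector<std::optional<int64_t>> &dynSizes) {
  std::vector<int64_t> folded;
  size_t dynPos = 0;
  for (int64_t d : shape) {
    if (d != kDynamic) {
      folded.push_back(d); // static dims are kept as-is
      continue;
    }
    const std::optional<int64_t> &v = dynSizes[dynPos++];
    folded.push_back(v ? *v : kDynamic); // constant operand folds in
  }
  return folded;
}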

3753

3754struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {

3755 using OpRewritePattern<ViewOp>::OpRewritePattern;

3756

3757 LogicalResult matchAndRewrite(ViewOp viewOp,

3758 PatternRewriter &rewriter) const override {

3759 Value memrefOperand = viewOp.getOperand(0);

3760 CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();

3761 if (!memrefCastOp)

3762 return failure();

3763 Value allocOperand = memrefCastOp.getOperand();

3764 AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();

3765 if (!allocOp)

3766 return failure();

3767 rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,

3768 viewOp.getByteShift(),

3769 viewOp.getSizes());

3770 return success();

3771 }

3772};

3773

3774}

3775

3776void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,

3777 MLIRContext *context) {

3778 results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);

3779}

3780

3781FailureOr<std::optional<SmallVector<Value>>>

3782ViewOp::bubbleDownCasts(OpBuilder &builder) {

3783 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());

3784}

3785

3786

3787

3788

3789

3790LogicalResult AtomicRMWOp::verify() {

3791 if (getMemRefType().getRank() != getNumOperands() - 2)

3793 "expects the number of subscripts to be equal to memref rank");

3794 switch (getKind()) {

3795 case arith::AtomicRMWKind::addf:

3796 case arith::AtomicRMWKind::maximumf:

3797 case arith::AtomicRMWKind::minimumf:

3798 case arith::AtomicRMWKind::mulf:

3799 if (!llvm::isa<FloatType>(getValue().getType()))

3800 return emitOpError() << "with kind '"

3801 << arith::stringifyAtomicRMWKind(getKind())

3802 << "' expects a floating-point type";

3803 break;

3804 case arith::AtomicRMWKind::addi:

3805 case arith::AtomicRMWKind::maxs:

3806 case arith::AtomicRMWKind::maxu:

3807 case arith::AtomicRMWKind::mins:

3808 case arith::AtomicRMWKind::minu:

3809 case arith::AtomicRMWKind::muli:

3810 case arith::AtomicRMWKind::ori:

3811 case arith::AtomicRMWKind::xori:

3812 case arith::AtomicRMWKind::andi:

3813 if (!llvm::isa<IntegerType>(getValue().getType()))

3814 return emitOpError() << "with kind '"

3815 << arith::stringifyAtomicRMWKind(getKind())

3816 << "' expects an integer type";

3817 break;

3818 default:

3819 break;

3820 }

3821 return success();

3822}
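The kind dispatch above partitions arith::AtomicRMWKind by operand type; a standalone sketch of the same classification (hypothetical local RMWKind mirror; the real kinds live in the arith dialect):

enum class RMWKind { addf, maximumf, minimumf, mulf,
                     addi, maxs, maxu, mins, minu, muli, ori, xori, andi };

bool kindIsFloatingPoint(RMWKind k) {
  switch (k) {
  case RMWKind::addf:
  case RMWKind::maximumf:
  case RMWKind::minimumf:
  case RMWKind::mulf:
    return true; // requires a floating-point operand
  default:
    return false; // integer kinds require an integer operand
  }
}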

3823

3824OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {

3825

3826 if (succeeded(foldMemRefCast(*this, getValue())))

3827 return getResult();

3828 return OpFoldResult();

3829}

3830

3831FailureOr<std::optional<SmallVector<Value>>>

3832AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {

3833 return bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),

3834 getResult());

3835}

3836

3837

3838

3839

3840

3841#define GET_OP_CLASSES

3842#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"

p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")

Given a list of lists of parsed operands, populates uniqueOperands with unique operands.

static bool hasSideEffects(Operation *op)

static bool isPermutation(const std::vector< PermutationTy > &permutation)

static Type getElementType(Type type)

Determine the element type of type.

static int64_t getNumElements(Type t)

Compute the total number of elements in the given type, also taking into account nested types.


Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...

static void constifyIndexValues(SmallVectorImpl< OpFoldResult > &values, ArrayRef< int64_t > constValues)

Helper function that sets values[i] to constValues[i] if the latter is a static value,...

Definition MemRefOps.cpp:96

static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)

Definition MemRefOps.cpp:1563

static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociation, bool allowMultipleDynamicDimsPerGroup)

Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp result and operand.

Definition MemRefOps.cpp:2234

static bool isOpItselfPotentialAutomaticAllocation(Operation *op)

Given an operation, return whether this op itself could allocate an AutomaticAllocationScopeResource.

Definition MemRefOps.cpp:436

static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap)

Build a strided memref type by applying permutationMap to memRefType.

Definition MemRefOps.cpp:3543

static bool isGuaranteedAutomaticAllocation(Operation *op)

Given an operation, return whether this op is guaranteed to allocate an AutomaticAllocationScopeResou...

Definition MemRefOps.cpp:417

static FailureOr< StridedLayoutAttr > computeExpandedLayoutMap(MemRefType srcType, ArrayRef< int64_t > resultShape, ArrayRef< ReassociationIndices > reassociation)

Compute the layout map after expanding a given source MemRef type with the specified reassociation in...

Definition MemRefOps.cpp:2327

static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2)

Return true if t1 and t2 have equal offsets (both dynamic or of same static value).

Definition MemRefOps.cpp:3037

static LogicalResult FoldCopyOfCast(CopyOp op)

If the source/target of a CopyOp is a CastOp that does not modify the shape and element type,...

Definition MemRefOps.cpp:853

static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, Container values, ArrayRef< OpFoldResult > maybeConstants)

Helper function to perform the replacement of all constant uses of values by a materialized constant ...

Definition MemRefOps.cpp:1393

static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, SubViewOp op, Type expectedType)

Definition MemRefOps.cpp:3070

static MemRefType getCanonicalSubViewResultType(MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)

Compute the canonical result type of a SubViewOp.

Definition MemRefOps.cpp:3208

static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)

Definition MemRefOps.cpp:1577

static std::tuple< MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type > getMemorySpaceCastInfo(BaseMemRefType resultTy, Value src)

Helper function to retrieve a lossless memory-space cast, and the corresponding new result memref typ...

Definition MemRefOps.cpp:117

static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)

Given the originalType and a candidateReducedType whose shape is assumed to be a subset of originalTy...

Definition MemRefOps.cpp:940

static bool isTrivialSubViewOp(SubViewOp subViewOp)

Helper method to check if a subview operation is trivially a no-op.

Definition MemRefOps.cpp:3272

static bool lastNonTerminatorInRegion(Operation *op)

Return whether this op is the last non terminating op in a region.

Definition MemRefOps.cpp:459

static std::map< int64_t, unsigned > getNumOccurences(ArrayRef< int64_t > vals)

Return a map with key being elements in vals and data being number of occurences of it.

Definition MemRefOps.cpp:925

static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, const llvm::SmallBitVector &droppedDims)

Return true if t1 and t2 have equal strides (both dynamic or of same static value).

Definition MemRefOps.cpp:3048

static FailureOr< StridedLayoutAttr > computeCollapsedLayoutMap(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation, bool strict=false)

Compute the layout map after collapsing a given source MemRef type with the specified reassociation i...

Definition MemRefOps.cpp:2527

static FailureOr< std::optional< SmallVector< Value > > > bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder, OpOperand &src)

Implementation of bubbleDownCasts method for memref operations that return a single memref result.

Definition MemRefOps.cpp:148

static LogicalResult verifyAllocLikeOp(AllocLikeOp op)

Definition MemRefOps.cpp:188

static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)

Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....

static ArrayRef< int64_t > getShape(Type type)

Returns the shape of the given type.

A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.

@ Square

Square brackets surrounding zero or more operands.

virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0

Parse a colon followed by a type list, which must have at least one type.

virtual Builder & getBuilder() const =0

Return a builder which provides useful access to MLIRContext, global objects like types and attribute...

virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0

Parse a named dictionary into 'result' if it is present.

virtual ParseResult parseOptionalEqual()=0

Parse a = token if present.

virtual ParseResult parseOptionalKeyword(StringRef keyword)=0

Parse the given keyword if present.

MLIRContext * getContext() const

virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0

Emit a diagnostic at the specified location and return failure.

virtual ParseResult parseAffineMap(AffineMap &map)=0

Parse an affine map instance into 'map'.

ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)

Add the specified type to the end of the specified type list and return success.

virtual ParseResult parseLess()=0

Parse a '<' token.

virtual ParseResult parseColonType(Type &result)=0

Parse a colon followed by a type.

virtual SMLoc getNameLoc() const =0

Return the location of the original name token.

virtual ParseResult parseGreater()=0

Parse a '>' token.

virtual ParseResult parseType(Type &result)=0

Parse a type.

virtual ParseResult parseComma()=0

Parse a , token.

virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0

Parse an optional arrow followed by a type list.

ParseResult parseKeywordType(const char *keyword, Type &result)

Parse a keyword followed by a type.

ParseResult parseKeyword(StringRef keyword)

Parse a given keyword.

virtual ParseResult parseAttribute(Attribute &result, Type type={})=0

Parse an arbitrary attribute of a given type and return it in result.

virtual void printAttributeWithoutType(Attribute attr)

Print the given attribute without its type.

Attributes are known-constant values of operations.

This class provides a shared interface for ranked and unranked memref types.

ArrayRef< int64_t > getShape() const

Returns the shape of this memref type.

FailureOr< PtrLikeTypeInterface > clonePtrWith(Attribute memorySpace, std::optional< Type > elementType) const

Clone this type with the given memory space and element type.

bool hasRank() const

Returns if this type is ranked, i.e. it has a known number of dimensions.

Block represents an ordered list of Operations.

Operation * getTerminator()

Get the terminator operation of this block.

bool mightHaveTerminator()

Return "true" if this block might have a terminator.

This class is a general helper class for creating context-global objects like types,...

IntegerAttr getIndexAttr(int64_t value)

IntegerType getIntegerType(unsigned width)

BoolAttr getBoolAttr(bool value)

IRValueT get() const

Return the current value being used by this operand.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

This is a builder type that keeps local references to arguments.

Builder & setShape(ArrayRef< int64_t > newShape)

Builder & setLayout(MemRefLayoutAttrInterface newLayout)

The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...

virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0

Parses a region.

ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)

Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...

virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0

Resolve an operand to an SSA value, emitting an error on failure.

ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)

Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...

virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0

Parse a single SSA value operand name along with a result number if allowResultNumber is true.

virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0

Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...

This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...

void printOperands(const ContainerType &container)

Print a comma separated list of operands.

virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0

If the specified operation has attributes, print out an attribute dictionary with their values.

virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0

Prints a region.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})

Add new block with 'argTypes' arguments and set the insertion point to the end of it.

Operation * clone(Operation &op, IRMapping &mapper)

Creates a deep copy of the specified operation, remapping any operands that use values outside of the...

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)

Create an operation of specific op type at the current insertion point, and immediately try to fold i...

void setInsertionPointAfter(Operation *op)

Sets the insertion point to the node after the specified operation, which will cause subsequent inser...

This class represents a single result from folding an operation.

This class represents an operand of an operation.

unsigned getOperandNumber()

Return which operand this is in the OpOperand list of the Operation.

A trait of region holding operations that define a new scope for automatic allocations,...

This trait indicates that the memory effects of an operation includes the effects of operations neste...

Simple wrapper around a void* in order to express generically how to pass in op properties through AP...

type_range getType() const

Operation is the basic unit of execution within MLIR.

void replaceUsesOfWith(Value from, Value to)

Replace any uses of 'from' with 'to' within this operation.

bool hasTrait()

Returns true if the operation was registered with a particular trait, e.g.

Block * getBlock()

Returns the operation block that contains this operation.

Operation * getParentOp()

Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...

MutableArrayRef< OpOperand > getOpOperands()

InFlightDiagnostic emitError(const Twine &message={})

Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...

MutableArrayRef< Region > getRegions()

Returns the regions held by this operation.

operand_range getOperands()

Returns an iterator on the underlying Value's.

result_range getResults()

Region * getParentRegion()

Returns the region to which the instruction belongs.

InFlightDiagnostic emitOpError(const Twine &message={})

Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

This class represents a point being branched from in the methods of the RegionBranchOpInterface.

bool isParent() const

Returns true if branching from the parent op.

This class provides an abstraction over the different types of ranges over Regions.

This class represents a successor of a region.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

BlockArgument addArgument(Type type, Location loc)

Add one value to the argument list.

bool hasOneBlock()

Return true if this region has exactly one block.

RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)

Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements...

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})

Inline the operations of block 'source' into block 'dest' before the given position.

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure,...

void modifyOpInPlace(Operation *root, CallableT &&callable)

This method is a utility wrapper around an in-place modification of an operation.

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

This class represents a collection of SymbolTables.

virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)

Returns the operation registered with the given symbol name within the closest parent operation of,...

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

MLIRContext * getContext() const

Return the MLIRContext in which this type was uniqued.

This class provides an abstraction over the different types of ranges over Values.

type_range getTypes() const

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

static WalkResult advance()

static WalkResult interrupt()

static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)

Speculatability

This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...

constexpr auto Speculatable

constexpr auto NotSpeculatable

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)

Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.

ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)

ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)

ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)

Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...

Type getTensorTypeFromMemRefType(Type type)

Return an unranked/ranked tensor type for the given unranked/ranked memref type.

Definition MemRefOps.cpp:60

OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)

Return the dimension of the given memref value.

Definition MemRefOps.cpp:68

LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)

This is a common utility used for patterns of the form "someop(memref.cast) -> someop".

Definition MemRefOps.cpp:45

SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)

Return the dimensions of the given memref value.

Definition MemRefOps.cpp:77

Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)

Create a rank-reducing SubViewOp @[0 .

Definition MemRefOps.cpp:3240

Operation::operand_range getIndices(Operation *op)

Get the indices that the given load/store operation is operating on.

DynamicAPInt getIndex(const ConeV &cone)

Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...

Value constantIndex(OpBuilder &builder, Location loc, int64_t i)

Generates a constant of index type.

MemRefType getMemRefType(T &&t)

Convenience method to abbreviate casting getType().

SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)

Return the dimensions of the given tensor value.

Include the generated interface declarations.

bool matchPattern(Value value, const Pattern &pattern)

Entry point for matching a pattern over a Value.

SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)

Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...

detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)

Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...

SliceVerificationResult

Enum that captures information related to verifier error conditions on slice insert/extract type of o...

constexpr StringRef getReassociationAttrName()

Attribute name for the ArrayAttr which encodes reassociation indices.

detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr

std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)

If ofr is a constant integer or an IntegerAttr, return the integer.

raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)

llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn

Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.

static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)

SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)

Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)

Return the list of Range (i.e.

Definition MemRefOps.cpp:3175

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.

SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)

Constructs affine maps out of Array<Array>.

bool isMemoryEffectFree(Operation *op)

Returns true if the given operation is free of memory effects.

bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)

Helper function to check whether the passed in sizes or offsets are valid.

SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims

SmallVector< IntegerValueRange > getIntValueRanges(ArrayRef< OpFoldResult > values, GetIntRangeFn getIntRange, int32_t indexBitwidth)

Helper function to collect the integer range values of an array of op fold results.

std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue

If Ty is mlir::Type this will select Value instead of having a wrapper around it.

bool isZeroInteger(OpFoldResult v)

Return true if v is an IntegerAttr with value 0.

bool hasValidStrides(SmallVector< int64_t > strides)

Helper function to check whether the passed in strides are valid.

void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)

Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...

SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)

Convert reassociation indices to affine expressions.

std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)

Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.

SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)

Apply a permutation from map to source and return the result.

OpFoldResult getAsOpFoldResult(Value val)

Given a value, try to extract a constant Attribute.

function_ref< void(Value, const StridedMetadataRange &)> SetStridedMetadataRangeFn

Callback function type for setting the strided metadata of a value.

std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)

Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...

SmallVector< int64_t, 2 > ReassociationIndices

SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)

Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...

ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)

Wraps a list of reassociations in an ArrayAttr.

llvm::function_ref< Fn > function_ref

bool isOneInteger(OpFoldResult v)

Return true if v is an IntegerAttr with value 1.

std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)

Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.

function_ref< IntegerValueRange(Value)> GetIntRangeFn

Helper callback type to get the integer range of a value.

Move allocations into an allocation scope, if it is legal to move them (e.g.

Definition MemRefOps.cpp:507

LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override

Definition MemRefOps.cpp:510

Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocati...

Definition MemRefOps.cpp:467

LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override

Definition MemRefOps.cpp:470

Definition MemRefOps.cpp:2691

LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override

Definition MemRefOps.cpp:2695

A canonicalizer wrapper to replace SubViewOps.

Definition MemRefOps.cpp:3424

void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)

Definition MemRefOps.cpp:3425

Return the canonical type of the result of a subview.

Definition MemRefOps.cpp:3387

MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)

Definition MemRefOps.cpp:3388

This is the representation of an operand reference.

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...

OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})

This represents an operation in an abstracted form, suitable for use with the builder APIs.

Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...

static SaturatedInteger wrap(int64_t v)

bool isValid

If set to "true", the slice bounds verification was successful.

std::string errorMessage

An error message that can be printed during op verification.
