MLIR: lib/Dialect/Tensor/IR/TensorOps.cpp Source File

//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-Exceptions
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>

using namespace mlir;
using namespace mlir::tensor;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *TensorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;
  if (complex::ConstantOp::isBuildableWith(value, type))
    return complex::ConstantOp::create(builder, loc, type,
                                       llvm::cast<ArrayAttr>(value));
  return nullptr;
}

OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc,
                                  Value value, int64_t dim) {
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  if (tensorType.isDynamicDim(dim))
    return builder.createOrFold<tensor::DimOp>(loc, value, dim);

  return builder.getIndexAttr(tensorType.getDimSize(dim));
}

SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
                                                Location loc, Value value) {
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  SmallVector<OpFoldResult> result;
  for (int64_t i = 0; i < tensorType.getRank(); ++i)
    result.push_back(getMixedSize(builder, loc, value, i));
  return result;
}

FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
                                                OpResult opResult) {
  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
  assert(tensorType && "expected tensor type");

  // If the op has a destination, it implements DestinationStyleOpInterface and
  // we can query the destination operand from that interface.
  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
  if (destOp)
    return destOp.getTiedOpOperand(opResult)->get();

  // Otherwise, create a new destination tensor with the same shape.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(opResult.getDefiningOp());

  // Compute sizes.
  SmallVector<OpFoldResult> mixedSizes;
  if (!tensorType.hasStaticShape()) {
    // Dynamic shape: query ReifyRankedShapedTypeOpInterface.
    ReifiedRankedShapedTypeDims reifiedShapes;
    if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
      return failure();
    mixedSizes = reifiedShapes[opResult.getResultNumber()];
  } else {
    // Static shape: take static sizes directly.
    for (int64_t sz : tensorType.getShape())
      mixedSizes.push_back(b.getIndexAttr(sz));
  }

  // Create empty tensor.
  Value emptyTensor =
      tensor::EmptyOp::create(b, loc, mixedSizes, tensorType.getElementType());
  return emptyTensor;
}

LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
                                              Operation *op,
                                              SmallVector<Value> &result) {
  for (OpResult opResult : op->getResults()) {
    if (llvm::isa<TensorType>(opResult.getType())) {
      FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
      if (failed(destination))
        return failure();
      result.push_back(*destination);
    }
  }
  return success();
}

bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
    if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
      return rtp1.getShape() == rtp2.getShape() &&
             rtp1.getElementType() == rtp2.getElementType();
    return false;
  }
  return tp1 == tp2;
}

/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op
/// or rank-extending tensor.insert_slice op, given the reduced result shape
/// and the mixed sizes of the slice.
static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
                                           ArrayRef<OpFoldResult> mixedSizes) {
  llvm::SmallBitVector droppedDims(mixedSizes.size());
  int64_t shapePos = reducedShape.size() - 1;

  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
    size_t idx = mixedSizes.size() - size.index() - 1;
    // Only static unit sizes can be dropped.
    bool isStaticUnitSize =
        isa<Attribute>(size.value()) &&
        llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;

    if (shapePos < 0) {
      // There are no more dims in the reduced shape. All remaining sizes must
      // be rank-reduced dims.
      assert(isStaticUnitSize && "expected unit dim");
      droppedDims.set(idx);
      continue;
    }

    // Dim is preserved if the size is not a static 1.
    if (!isStaticUnitSize) {
      --shapePos;
      continue;
    }

    // Dim is preserved if the reduced shape dim is also 1.
    if (reducedShape[shapePos] == 1) {
      --shapePos;
      continue;
    }

    // Otherwise: dim is dropped.
    droppedDims.set(idx);
  }

  assert(shapePos < 0 && "dimension mismatch");
  return droppedDims;
}
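// Illustrative walkthrough (editorial note, not from the upstream source):
// for mixed sizes [1, %d, 1, 4] and reduced shape [?, 4], the backward scan
// matches 4 and %d against the two reduced dims, so the static unit sizes at
// positions 0 and 2 have no counterpart and bits {0, 2} are set.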

/// Given a ranked tensor type and a range of values that defines its dynamic
/// dimension sizes, turn all dynamic sizes that have a constant value into
/// static dimension sizes.
static RankedTensorType
foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
                            SmallVector<Value> &foldedDynamicSizes) {
  SmallVector<int64_t> staticShape(type.getShape());
  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
         "incorrect number of dynamic sizes");

  // Compute new static and dynamic sizes.
  unsigned ctr = 0;
  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
    if (type.isDynamicDim(i)) {
      Value dynamicSize = dynamicSizes[ctr++];
      std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
      if (cst.has_value()) {
        // Dynamic size must be non-negative.
        if (cst.value() < 0) {
          foldedDynamicSizes.push_back(dynamicSize);
          continue;
        }
        staticShape[i] = *cst;
      } else {
        foldedDynamicSizes.push_back(dynamicSize);
      }
    }
  }

  return RankedTensorType::get(staticShape, type.getElementType(),
                               type.getEncoding());
}
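// Illustrative example (editorial note, not from the upstream source): for
// tensor<?x?xf32> with dynamic sizes (%c5, %n), where %c5 is a constant 5,
// the folded type is tensor<5x?xf32> and foldedDynamicSizes holds only %n.
// The tensor.empty and tensor.generate canonicalizations below rely on this.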

//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//

bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = dyn_cast<TensorType>(a);
  auto bT = dyn_cast<TensorType>(b);
  if (!aT || !bT)
    return false;

  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}
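// Illustrative example (editorial note, not from the upstream source):
// tensor<4xi32> to tensor<4xf32> is bitcast-compatible (both element types
// are 32 bits wide and the shapes agree), while tensor<4xi32> to
// tensor<4xi64> fails the element bit-width check above.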

namespace {

/// Replaces chains of two tensor.bitcast operations by a single
/// tensor.bitcast operation.
struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
  using OpRewritePattern<BitcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
                                PatternRewriter &rewriter) const final {
    auto tensorBitcastOperand =
        tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
    if (!tensorBitcastOperand)
      return failure();

    auto resultType = cast<TensorType>(tensorBitcast.getType());
    rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
                                           tensorBitcastOperand.getOperand());
    return success();
  }
};

} // namespace

void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ChainedTensorBitcast>(context);
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "cast");
}

/// Returns true if `target` is a ranked tensor type that preserves static
/// information available in the `source` ranked tensor type.
bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
  auto targetType = llvm::dyn_cast<RankedTensorType>(target);

  // Requires RankedTensorType.
  if (!sourceType || !targetType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != targetType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != targetType.getRank())
    return false;

  // Requires same encoding.
  if (sourceType.getEncoding() != targetType.getEncoding())
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
    if (ShapedType::isStatic(std::get<0>(t)) &&
        ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

  return true;
}
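// Illustrative examples (editorial note, not from the upstream source):
// preservesStaticInformation(tensor<?x8xf32>, tensor<4x8xf32>) is true,
// since no static source extent becomes dynamic in the target, whereas
// preservesStaticInformation(tensor<4x8xf32>, tensor<?x8xf32>) is false
// because the static extent 4 would be lost.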

/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of tensor.cast operations.
///
/// Returns true when the tensor.cast op loses static information, i.e. when
/// the cast goes from a more static to a more dynamic type.
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  if (!castOp)
    return false;

  // Can fold if the source of cast has at least as much static information as
  // its results.
  return preservesStaticInformation(castOp.getType(),
                                    castOp.getSource().getType());
}

/// Determines whether the tensor.cast casts to a more static version of the
/// source tensor. This is useful to fold into a producing op and implement
/// canonicalization patterns with the `tensor.cast` op as the root, but the
/// producer being from a different dialect. Returns true when the cast adds
/// static information that the producer may legally rely on.
bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
  if (!castOp)
    return false;
  return preservesStaticInformation(castOp.getSource().getType(),
                                    castOp.getType());
}

bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) {
  return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
    if (llvm::isa<BlockArgument>(opOperand.get()))
      return false;
    auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
    return castOp && canFoldIntoConsumerOp(castOp);
  });
}

SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding(
    DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
  SmallVector<Value> newOperands;
  newOperands.reserve(op->getNumOperands());

  assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");

  // Assumes that the result has dpsInits followed by nonDpsInits.
  int64_t dpsInitIdx = 0;
  for (OpOperand &opOperand : op->getOpOperands()) {
    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
    bool fold = canFoldIntoConsumerOp(tensorCastOp);
    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
    if (op.isDpsInit(&opOperand) &&
        !llvm::isa<MemRefType>(newOperands.back().getType()))
      newResTy[dpsInitIdx++] = newOperands.back().getType();
  }
  return newOperands;
}

/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<TensorType>(a);
  auto bT = llvm::dyn_cast<TensorType>(b);
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

/// Compute a TensorType that has the joined shape knowledge of the two given
/// TensorTypes. The element types need to match.
static TensorType joinShapes(TensorType one, TensorType two) {
  assert(one.getElementType() == two.getElementType());

  if (!one.hasRank())
    return two;
  if (!two.hasRank())
    return one;

  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
}
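// Illustrative example (editorial note, not from the upstream source):
// joinShapes(tensor<?x8xf32>, tensor<4x?xf32>) yields tensor<4x8xf32>,
// taking the static extent from whichever side provides one; joining
// tensor<4x8xf32> with tensor<5x8xf32> returns a null type because the
// static extents conflict.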

namespace {

/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
    auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
    auto resultType = llvm::cast<TensorType>(tensorCast.getType());

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();

    // The newJoin always exists if the above join exists, it might just
    // contain less information. If so, we cannot drop the intermediate cast.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};

/// Fold tensor.cast into tensor::ExtractSliceOp.
///
/// Example:
/// ```
///   %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
///     tensor<128x512xf32> to tensor<?x512xf32>
///   %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
/// ```
/// is rewritten into:
/// ```
///   %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
///     tensor<128x512xf32> to tensor<16x512xf32>
/// ```
struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto extractOperand =
        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();

    // Cannot fold cast to unranked tensor.
    auto rankedResultType =
        llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
    if (!rankedResultType)
      return failure();

    if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
        rankedResultType.getShape() ==
            llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
                .getShape())
      return failure();

    SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
    auto dimMask = computeRankReductionMask(
        extractOperand.getStaticSizes(), extractOperand.getType().getShape());
    size_t dimIndex = 0;
    for (size_t i = 0, e = sizes.size(); i < e; i++) {
      if (dimMask && dimMask->count(i))
        continue;
      int64_t dim = rankedResultType.getShape()[dimIndex++];
      if (ShapedType::isDynamic(dim))
        continue;
      sizes[i] = rewriter.getIndexAttr(dim);
    }

    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        tensorCast, rankedResultType, extractOperand.getSource(),
        extractOperand.getMixedOffsets(), sizes,
        extractOperand.getMixedStrides());
    return success();
  }
};

} // namespace

void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
}

//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//

RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
  auto tensorTypes =
      llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>);
  int64_t concatRank = tensorTypes[0].getRank();

  // The concatenation dim must be in the range [0, rank).
  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");

  SmallVector<int64_t> sizes(concatRank);
  for (int64_t i = 0, e = concatRank; i < e; ++i) {
    if (i == dim)
      continue;
    SaturatedInteger size;
    for (auto tensorType : tensorTypes)
      size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
    sizes[i] = size.asInteger();
  }
  auto concatSize = SaturatedInteger::wrap(0);
  for (auto tensorType : tensorTypes)
    concatSize =
        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
  sizes[dim] = concatSize.asInteger();
  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
}
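// Illustrative example (editorial note, not from the upstream source):
// concatenating tensor<3x?xf32> and tensor<5x4xf32> along dim 0 infers
// tensor<8x4xf32>: sizes on the concat dim add (3 + 5), and the other dim
// desaturates ? against 4 to the more static 4. A dynamic extent on the
// concat dim of any input makes the summed extent dynamic.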

void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
                     ValueRange inputs) {
  FailureOr<RankedTensorType> resultType =
      inferResultType(dim, inputs.getTypes());
  assert(succeeded(resultType) && "failed to infer concatenation result type");
  build(builder, result, *resultType, dim, inputs);
}

LogicalResult ConcatOp::verify() {
  if (getInputs().size() < 1)
    return emitOpError("requires at least one input");

  SmallVector<RankedTensorType> inputTypes;
  for (auto input : getInputs())
    inputTypes.push_back(cast<RankedTensorType>(input.getType()));

  RankedTensorType resultType = getResultType();
  int64_t resultRank = getRank();
  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
        return type.getRank() != resultRank;
      }))
    return emitOpError("rank of concatenated inputs must match result rank");

  Type resultElementType = resultType.getElementType();
  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
        return type.getElementType() != resultElementType;
      }))
    return emitOpError("inputs and result element type must match");

  int64_t dim = getDim();
  if (dim >= resultRank)
    return emitOpError("concatenation dim must be less than the tensor rank");

  SmallVector<int64_t> sizes(resultRank);
  for (int64_t i = 0, e = resultRank; i < e; ++i) {
    if (i == dim)
      continue;
    SaturatedInteger size;
    for (auto tensorType : inputTypes) {
      FailureOr<SaturatedInteger> maybeSize =
          size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
      if (failed(maybeSize))
        return emitOpError("static concatenation size mismatch along ")
               << "non-concatenated dimension " << i;
      size = *maybeSize;
    }
    sizes[i] = size.asInteger();
  }
  auto concatSize = SaturatedInteger::wrap(0);
  for (auto tensorType : inputTypes)
    concatSize =
        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
  sizes[dim] = concatSize.asInteger();
  auto inferredResultType =
      RankedTensorType::get(sizes, inputTypes[0].getElementType());

  for (auto [inferredSize, actualSize] :
       llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
    bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
                      ShapedType::isDynamic(actualSize);
    if (!hasDynamic && inferredSize != actualSize)
      return emitOpError("result type ")
             << resultType << " does not match inferred shape "
             << inferredResultType << " static sizes";
  }

  return success();
}

FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
  size_t numInputs = getInputs().size();
  uint64_t concatDim = getDim();

  SmallVector<SmallVector<OpFoldResult>> inputShapes;
  inputShapes.reserve(numInputs);
  SmallVector<OpFoldResult> concatOffsets;
  concatOffsets.reserve(numInputs);
  SmallVector<OpFoldResult> outputShape;

  AffineExpr addExpr =
      builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
  OpFoldResult zero = builder.getIndexAttr(0);
  Location loc = getLoc();
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    SmallVector<OpFoldResult> inputShape =
        tensor::getMixedSizes(builder, input.getLoc(), input);
    if (index == 0) {
      outputShape = inputShape;
      concatOffsets.push_back(zero);
    } else {
      concatOffsets.push_back(outputShape[concatDim]);
      outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
          builder, loc, addExpr,
          {outputShape[concatDim], inputShape[concatDim]});
    }
    inputShapes.emplace_back(std::move(inputShape));
  }

  Value replacement = tensor::EmptyOp::create(builder, loc, outputShape,
                                              getType().getElementType());

  int64_t rank = getType().getRank();
  OpFoldResult one = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> strides(rank, one);
  SmallVector<OpFoldResult> offsets(rank, zero);
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    offsets[concatDim] = concatOffsets[index];
    auto insertSlice = tensor::InsertSliceOp::create(
        builder, loc, input, replacement, offsets, inputShapes[index], strides);
    replacement = insertSlice.getResult();
  }
  if (replacement.getType() != getType()) {
    replacement = tensor::CastOp::create(builder, loc, getType(), replacement);
  }
  return SmallVector<Value>{replacement};
}

LogicalResult
ConcatOp::reifyResultShapes(OpBuilder &builder,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  ValueRange inputs = getInputs();
  int64_t dim = getDim();
  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());

  Value init = inputs[0];
  int64_t rank = getType().getRank();

  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));

  // Reify the result shape for all non-concatenated dimensions; the
  // concatenated dimension is handled separately below.
  for (int64_t i = 0; i < rank; ++i) {
    if (i == dim)
      continue;
    if (!getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    } else if (!inferredResultType.isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
          builder, getLoc(),
          builder.getIndexAttr(inferredResultType.getDimSize(i)));
    } else {
      reifiedReturnShapes[0][i] =
          tensor::DimOp::create(builder, init.getLoc(), init, i).getResult();
    }
  }

  if (getType().isDynamicDim(dim)) {
    // Take the sum of the input sizes along the concatenated dim.
    AffineExpr sum = builder.getAffineDimExpr(0);
    SmallVector<OpFoldResult> sizes = {
        builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
    for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
      sum = sum + builder.getAffineDimExpr(idx + 1);
      sizes.push_back(
          builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
    }
    reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
        builder, getLoc(),
        affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
  } else {
    // If the result shape is static along the concatenated dim, use the
    // static shape.
    reifiedReturnShapes[0][dim] =
        builder.getIndexAttr(getType().getDimSize(dim));
  }
  return success();
}

void ConcatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "concat");
}

OpFoldResult ConcatOp::fold(FoldAdaptor) {
  ValueRange inputs = getInputs();
  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
    return inputs[0];
  return {};
}

namespace {

struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    if (concatOp.getInputs().size() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
                                        concatOp.getInputs()[0]);
    return success();
  }
};

/// Propagate static shapes into the operands of a `tensor.concat`.
///
/// `tensor.concat` requires every operand to match on all dimensions except
/// the concatenation dimension. If one operand is already static in those
/// dimensions, the other operands may be refined towards that static shape
/// by inserting `tensor.cast` ops on the less static operands.
struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

    // Find operands for which a more static shape can be inferred.
    LogicalResult matched = failure();
    // Inferred operand shapes are identical to the inferred result shape with
    // the exception of the concat dimension.
    SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
    for (auto [operandIdx, operandType] :
         llvm::enumerate(concatOp->getOperandTypes())) {
      // Compute the inferred type for this operand.
      inferredOperandShape[dim] =
          cast<RankedTensorType>(operandType).getDimSize(dim);
      auto inferredOperandType = RankedTensorType::get(
          inferredOperandShape, inferredResultType.getElementType());

      // Check if the inferred type is more static.
      if (!preservesStaticInformation(inferredOperandType, operandType)) {
        matched = success();

        // Use the refined operand type and create a cast from the original.
        auto castOp =
            CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
                           concatOp.getOperand(operandIdx));
        rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
          concatOp->setOperand(operandIdx, castOp->getResult(0));
        });
      }
    }

    return matched;
  }
};

/// Propagate static shapes from the operands to the result of a
/// `tensor.concat`.
///
/// If the result type is less static than the inferred type, refine the
/// result type by inferring it from the operands and cast the new concat
/// back to the original result type.
struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

    // The result type should be at least as static as the inferred type.
    if (preservesStaticInformation(inferredResultType,
                                   concatOp.getResultType())) {
      return failure();
    }

    auto newConcatOp =
        ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
                         concatOp->getOperands());
    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
                                        newConcatOp);

    return success();
  }
};
} // namespace

void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results
      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
          context);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "dim");
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
  build(builder, result, source, indexValue);
}

std::optional<int64_t> DimOp::getConstantIndex() {
  return getConstantIntValue(getIndex());
}

Speculation::Speculatability DimOp::getSpeculatability() {
  auto constantIndex = getConstantIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
  if (!rankedSourceType)
    return Speculation::NotSpeculatable;

  if (rankedSourceType.getRank() <= constantIndex)
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}

void DimOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  setResultRange(getResult(),
                 intrange::inferShapedDimOpInterface(*this, argRanges[0]));
}

OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types is not supported.
  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
  if (!tensorType)
    return {};

  // Out of bound indices produce undefined behavior but are still valid IR.
  // Don't choke on them.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= tensorType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!tensorType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
  }

  Operation *definingOp = getSource().getDefiningOp();

  // Fold dim to the operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        llvm::cast<RankedTensorType>(fromElements.getResult().getType());
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

    // Find the operand of the generate op that corresponds to this index.
    auto dynExtents = fromElements.getDynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (ShapedType::isDynamic(dim))
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    // Fold only for non-rank-reduced ops. For the rank-reduced version, rely
    // on the canonicalization patterns.
    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(unsignedIndex)) {
      return {sliceOp.getDynamicSize(unsignedIndex)};
    }
  }

  // dim(cast) -> dim
  if (succeeded(foldTensorCast(*this)))
    return getResult();

  return {};
}

namespace {
/// Fold dim of a cast into the dim of the source of the tensor cast.
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
    return success();
  }
};

/// Fold dim of a destination passing style op into the dim of the
/// corresponding init operand.
struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto source = dimOp.getSource();
    auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
    if (!destOp)
      return failure();

    auto resultIndex = cast<OpResult>(source).getResultNumber();
    auto *initOperand = destOp.getDpsInitOperand(resultIndex);

    rewriter.modifyOpInPlace(
        dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
    return success();
  }
};

/// Fold dim of a tensor.reshape operation to an extract from the reshape's
/// shape operand.
struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return failure();

    // Since tensors are immutable we don't need to worry about where to place
    // the extract call.
    rewriter.setInsertionPointAfter(dim);
    Location loc = dim.getLoc();
    Value extract =
        ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
    if (extract.getType() != dim.getType())
      extract =
          arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);
    rewriter.replaceOp(dim, extract);
    return success();
  }
};
} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
}

//===----------------------------------------------------------------------===//
// EmptyOp
//===----------------------------------------------------------------------===//

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<int64_t> staticShape, Type elementType,
                    Attribute encoding) {
  assert(none_of(staticShape, ShapedType::isDynamic) &&
         "expected only static sizes");
  build(builder, result, staticShape, elementType, ValueRange{}, encoding);
}

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<int64_t> staticShape, Type elementType,
                    ValueRange dynamicSizes, Attribute encoding) {
  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
  build(builder, result, tensorType, dynamicSizes);
}

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<OpFoldResult> sizes, Type elementType,
                    Attribute encoding) {
  SmallVector<int64_t> staticShape;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
}

LogicalResult EmptyOp::verify() {
  if (getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitOpError("incorrect number of dynamic sizes, has ")
           << getDynamicSizes().size() << ", expected "
           << getType().getNumDynamicDims();
  return success();
}

LogicalResult
EmptyOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  unsigned ctr = 0;
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
    } else {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    }
  }
  return success();
}

Value EmptyOp::getDynamicSize(unsigned idx) {
  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
  unsigned ctr = 0;
  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
    if (getType().isDynamicDim(i))
      ++ctr;
  return getDynamicSizes()[ctr];
}

SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
  SmallVector<OpFoldResult> result;
  unsigned ctr = 0;
  OpBuilder b(getContext());
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      result.push_back(getDynamicSizes()[ctr++]);
    } else {
      result.push_back(b.getIndexAttr(getType().getShape()[i]));
    }
  }
  return result;
}
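// Illustrative example (editorial note, not from the upstream source): for
// %0 = tensor.empty(%d0, %d1) : tensor<?x8x?xf32>, getMixedSizes() returns
// {%d0, 8, %d1}: dynamic dims map to the dynamic size operands in order and
// static dims become index attributes.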

namespace {
/// Change the type of the result of a `tensor.empty` by making the result
/// type statically sized along dimensions that in the original operation were
/// defined as dynamic, but the size was defined using a constant op. For
/// example:
///
///  %c5 = arith.constant 5 : index
///  %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
///
/// to
///
///  %0 = tensor.empty(%arg0) : tensor<?x5xf32>
struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
  using OpRewritePattern<EmptyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(EmptyOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        op.getType(), op.getDynamicSizes(), foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedTensorType == op.getType())
      return failure();

    auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
                                 foldedDynamicSizes);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
    if (!emptyTensorOp || !maybeConstantIndex)
      return failure();
    auto emptyTensorType = emptyTensorOp.getType();
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= emptyTensorType.getRank() ||
        !emptyTensorType.isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(dimOp,
                       emptyTensorOp.getDynamicSize(*maybeConstantIndex));
    return success();
  }
};

/// Canonicalize
///
/// ```mlir
///   %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
/// ```
///
/// into
///
/// ```mlir
///   %0 = tensor.empty(%d1) : tensor<4x?xf32>
/// ```
///
/// This assumes the input program is correct in terms of its shape, so it is
/// safe to assume that `%d0` is in fact 4.
struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp castOp,
                                PatternRewriter &rewriter) const override {
    if (!canFoldIntoProducerOp(castOp))
      return failure();
    auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
    if (!producer)
      return failure();

    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    ArrayRef<int64_t> resultShape = resultType.getShape();
    SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
    SmallVector<OpFoldResult> newMixedSizes;
    newMixedSizes.reserve(currMixedSizes.size());
    assert(resultShape.size() == currMixedSizes.size() &&
           "mismatch in result shape and sizes of empty op");
    for (auto it : llvm::zip(resultShape, currMixedSizes)) {
      int64_t newDim = std::get<0>(it);
      OpFoldResult currDim = std::get<1>(it);
      // Case 1: The empty tensor dim is static. Check that the tensor cast
      // result dim matches.
      if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
        if (ShapedType::isDynamic(newDim) ||
            newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
          // Something is off, the cast result shape cannot be more dynamic
          // than the empty tensor result shape (enforced by
          // `canFoldIntoProducerOp`). Abort for now.
          return rewriter.notifyMatchFailure(
              producer, "mismatch in static value of shape of empty tensor "
                        "result and cast result");
        }
        newMixedSizes.push_back(attr);
        continue;
      }

      // Case 2: The tensor cast shape is static, but the empty tensor result
      // shape is dynamic.
      if (ShapedType::isStatic(newDim)) {
        newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
        continue;
      }

      // Case 3: Both the tensor cast shape and the empty tensor result shape
      // are dynamic. Use the dynamic value from the empty tensor op.
      newMixedSizes.push_back(currDim);
    }

    // TODO: Do not drop tensor encoding.
    rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
                                         resultType.getElementType());
    return success();
  }
};

} // namespace

void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
              ReplaceEmptyTensorStaticShapeDims>(context);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

namespace {

/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();
    if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
      return failure();
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.getSource(), extract.getIndices());
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
///        tensor<12xf64>
/// %extracted_element = tensor.extract %val[%c10] : tensor<12xf64>
///
/// to
///
/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const final {
    auto collapseOp =
        extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    if (!collapseOp.getSrcType().hasStaticShape())
      return failure();

    auto sourceSizes = collapseOp.getSrcType().getShape();

    SmallVector<Value> indices(extractOp.getIndices().begin(),
                               extractOp.getIndices().end());
    SmallVector<Value> sourceIndices;
    for (auto [index, group] :
         llvm::zip(indices, collapseOp.getReassociationIndices())) {
      assert(!group.empty() && "association indices groups cannot be empty");
      auto groupSize = group.size();

      if (groupSize == 1) {
        sourceIndices.push_back(index);
        continue;
      }

      SmallVector<int64_t> basis =
          llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
      auto delinearize = affine::AffineDelinearizeIndexOp::create(
          rewriter, extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
      llvm::append_range(sourceIndices, delinearize.getResults());
    }
    if (collapseOp.getReassociationIndices().empty()) {
      auto zeroAffineMap = rewriter.getConstantAffineMap(0);
      int64_t srcRank =
          cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(), zeroAffineMap,
          ArrayRef<OpFoldResult>{});
      for (int64_t i = 0; i < srcRank; i++) {
        sourceIndices.push_back(
            getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extractOp, collapseOp.getSrc(), sourceIndices);
    return success();
  }
};

} // namespace

void ExtractOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted");
}

LogicalResult ExtractOp::verify() {
  // Verify the # indices match if we have a ranked type.
  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices for extract_element");
  return success();
}

/// If we have an ExtractOp consuming an InsertOp with the same indices, we
/// can return the InsertOp's scalar directly.
// TODO: This only checks the immediate producer; extend to go up the
// insert/extract chain if the slices are disjoint.
static Value foldExtractAfterInsert(ExtractOp extractOp) {
  auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();

  auto isSame = [](Value a, Value b) {
    return getAsOpFoldResult(a) == getAsOpFoldResult(b);
  };
  if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
      llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
    return insertOp.getScalar();

  return {};
}

OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  if (Attribute tensor = adaptor.getTensor()) {
    // If this is a splat elements attribute, simply return the value.
    // All of the elements of a splat attribute are the same.
    if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
      return splatTensor.getSplatValue<Attribute>();

    // Reading from a dense resource elements attribute is not supported.
    if (isa<DenseResourceElementsAttr>(tensor))
      return {};
  }

  // Collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indice : adaptor.getIndices()) {
    if (!indice || !llvm::isa<IntegerAttr>(indice))
      return {};
    indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
  }

  // Fold extract(from_elements(...)).
  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
    auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
    auto rank = tensorType.getRank();
    assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
           "rank mismatch");
    int flatIndex = 0;
    int stride = 1;
    for (int i = rank - 1; i >= 0; --i) {
      flatIndex += indices[i] * stride;
      stride *= tensorType.getDimSize(i);
    }
    // Prevent out of bounds accesses. This can happen in invalid code that
    // will never execute.
    if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
        flatIndex < 0)
      return {};
    return fromElementsOp.getElements()[flatIndex];
  }

  // If this is an elements attribute, query the value at the given indices.
  if (Attribute tensor = adaptor.getTensor()) {
    auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
    if (elementsAttr && elementsAttr.isValidIndex(indices))
      return elementsAttr.getValues<Attribute>()[indices];
  }

  if (Value result = foldExtractAfterInsert(*this))
    return result;

  return {};
}
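// Illustrative example of the from_elements folding above (editorial note,
// not from the upstream source): for a tensor<2x3xf32> built by
// tensor.from_elements, extracting at indices [1, 2] computes the row-major
// flat index 1 * 3 + 2 = 5 and returns the sixth element operand directly.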

void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractFromTensorCast>(context);
}

void mlir::tensor::populateFoldCollapseExtractPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ExtractFromCollapseShape>(patterns.getContext());
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

void FromElementsOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "from_elements");
}

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           ValueRange elements) {
  assert(!elements.empty() && "expected at least one element");
  Type resultType = RankedTensorType::get(
      {static_cast<int64_t>(elements.size())}, elements.front().getType());
  build(builder, result, resultType, elements);
}

OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
  if (!llvm::is_contained(adaptor.getElements(), nullptr))
    return DenseElementsAttr::get(getType(), adaptor.getElements());
  return {};
}

namespace {

// Pushes the index_casts that occur before extractions to after the extract.
// This minimizes type conversion in some cases and enables the extract
// canonicalizer. This changes:
//
// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
// %extract = tensor.extract %cast[%index] : tensor<1xindex>
//
// to the following:
//
// %extract = tensor.extract %tensor[%index] : tensor<1xi32>
// %cast = arith.index_cast %extract : i32 to index
//
// to just the extract plus a scalar cast.
struct ExtractElementFromIndexCast
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    Location loc = extract.getLoc();
    auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
    if (!indexCast)
      return failure();

    Type elementTy = getElementTypeOrSelf(indexCast.getIn());

    auto newExtract = tensor::ExtractOp::create(
        rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());

    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
                                                    newExtract);

    return success();
  }
};

} // namespace

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<ExtractElementFromIndexCast>(context);
}

//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}

/// Return the inferred result type for a gatherOp where:
///   - sourceType is the type of the source tensor gathered from
///   - indicesType is the type of the indices used to gather
///   - gatherDims are the dims along which the gather occurs.
/// Return a full rank or rank-reduced variant of the type depending on the
/// value of rankReduced.
///
/// The leading dimensions of the index tensor give the result tensor its
/// leading dimensions. The trailing dimensions of the result tensor are
/// obtained from the source tensor by setting the dimensions specified in
/// gather_dims to `1` (if rankReduced is false), or skipping them
/// (otherwise).
RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
                                           RankedTensorType indicesType,
                                           ArrayRef<int64_t> gatherDims,
                                           bool rankReduced) {
  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
  resultShape.reserve(resultShape.size() + sourceType.getRank());
  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
    if (llvm::binary_search(gatherDims, idx)) {
      if (!rankReduced)
        resultShape.push_back(1);
      continue;
    }
    resultShape.push_back(sourceType.getDimSize(idx));
  }
  return RankedTensorType::Builder(sourceType).setShape(resultShape);
}
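// Illustrative example (editorial note, not from the upstream source): with
// source tensor<4x5x6xf32>, indices tensor<2x3x1xindex> and gather_dims =
// [1], the leading dims 2x3 come from the index tensor, and the source dims
// contribute 4x1x6 (or 4x6 when rankReduced is true), giving
// tensor<2x3x4x1x6xf32> or tensor<2x3x4x6xf32>.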

static LogicalResult
verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
                          ArrayRef<int64_t> indices, int64_t rank,
                          StringRef gatherOrScatter, StringRef sourceOrDest) {
  if (dims.empty())
    return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";

  int64_t numGatherDims = dims.size();
  if (numGatherDims > rank)
    return op->emitOpError(gatherOrScatter)
           << "_dims overflow " << sourceOrDest << " rank";
  if (indices.empty() || indices.back() != numGatherDims)
    return op->emitOpError(gatherOrScatter)
           << "_dims length must match the size of last dimension of indices";
  for (int64_t val : dims) {
    if (val < 0)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be non-negative";
    if (val >= rank)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be smaller than " << sourceOrDest << " rank";
  }
  for (int64_t i = 1; i < numGatherDims; ++i) {
    if (dims[i - 1] >= dims[i])
      return op->emitOpError(gatherOrScatter)
             << "_dims values must be strictly increasing";
  }
  return success();
}

LogicalResult GatherOp::verify() {
  int64_t sourceRank = getSourceType().getRank();
  ArrayRef<int64_t> gatherDims = getGatherDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
                                       getIndicesType().getShape(), sourceRank,
                                       "gather", "source")))
    return failure();

  RankedTensorType expectedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
  if (getResultType() != expectedResultType &&
      getResultType() != expectedRankReducedResultType) {
    return emitOpError("result type "
                       "mismatch: "
                       "expected ")
           << expectedResultType << " or its rank-reduced variant "
           << expectedRankReducedResultType << " (got: " << getResultType()
           << ")";
  }

  return success();
}

OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  return {};
}

//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//

void InsertOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted");
}

LogicalResult InsertOp::verify() {
  // Verify the # indices match if we have a ranked type.
  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices");
  return success();
}

OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
  Attribute scalar = adaptor.getScalar();
  Attribute dest = adaptor.getDest();
  if (scalar && dest)
    if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
      if (scalar == splatDest.getSplatValue<Attribute>())
        return dest;
  return {};
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//

void GenerateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "generated");
}

LogicalResult GenerateOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  int idx = 0;
  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
    if (getType().isDynamicDim(dim)) {
      reifiedReturnShapes[0][dim] = getOperand(idx++);
    } else {
      reifiedReturnShapes[0][dim] =
          builder.getIndexAttr(getType().getDimSize(dim));
    }
  }
  return success();
}

LogicalResult GenerateOp::verify() {
  // Ensure that the tensor type has as many dynamic dimensions as are
  // specified by the operands.
  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
  if (getNumOperands() != resultType.getNumDynamicDims())
    return emitError("must have as many index operands as dynamic extents "
                     "in the result type");
  return success();
}

LogicalResult GenerateOp::verifyRegions() {
  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
  // Ensure that region arguments span the index space.
  if (!llvm::all_of(getBody().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return emitError("all body arguments must be index");
  if (getBody().getNumArguments() != resultTy.getRank())
    return emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());

  if (yieldOp.getValue().getType() != resultTy.getElementType())
    return emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

  return success();
}

void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate body.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  SmallVector<Location, 2> argumentLocs(rank, result.location);
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}

namespace {

/// Canonicalizes tensor.generate operations with a constant
/// operand into the equivalent operation with the operand expressed in the
/// result type, instead. We also insert a type cast to make sure that the
/// resulting IR is still well-typed.
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp generateOp,
                                PatternRewriter &rewriter) const final {
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        generateOp.getType(), generateOp.getDynamicExtents(),
        foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedTensorType == generateOp.getType())
      return failure();

    auto loc = generateOp.getLoc();
    auto newOp =
        GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);
    rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
                                newOp.getBody().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
                                                generateOp.getType(), newOp);
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
///   ^bb0(%arg0: index):
///   <computation>
///   yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
    if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
      return failure();

    IRMapping mapping;
    Block *body = &tensorFromElements.getBody().front();
    mapping.map(body->getArguments(), extract.getIndices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());

    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
    return success();
  }
};

} // namespace

void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // TODO: Move the extract pattern to tensor::ExtractOp.
  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "rank");
}

OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
  // Constant fold rank when the rank of the operand is known.
  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
  return IntegerAttr();
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}

static int64_t getNumElements(ShapedType type) {
  int64_t numElements = 1;
  for (auto dim : type.getShape())
    numElements *= dim;
  return numElements;
}
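// Worked example (editorial note, not from the upstream source):
// tensor<2x3x4xf32> has 2 * 3 * 4 = 24 elements, so a reshape to
// tensor<6x4xf32> passes the element-count check in the verifier below.
// The helper assumes a fully static shape; the verifier only calls it when
// both operand and result shapes are static.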

LogicalResult ReshapeOp::verify() {
  TensorType operandType = llvm::cast<TensorType>(getSource().getType());
  TensorType resultType = llvm::cast<TensorType>(getResult().getType());

  if (operandType.getElementType() != resultType.getElementType())
    return emitOpError("element types of source and destination tensor "
                       "types should be the same");

  int64_t shapeSize =
      llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);

  if (resultRankedType) {
    if (operandRankedType && resultRankedType.hasStaticShape() &&
        operandRankedType.hasStaticShape()) {
      if (getNumElements(operandRankedType) !=
          getNumElements(resultRankedType))
        return emitOpError("source and destination tensor should have the "
                           "same number of elements");
    }
    if (ShapedType::isDynamic(shapeSize))
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked tensor type");
    if (shapeSize != resultRankedType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's tensor rank");
  }
  return success();
}

OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;

  // If the producer of operand 'source' is another 'tensor.reshape' op, use
  // the producer's input instead as the original tensor to reshape. This
  // could render such producer dead code.
  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
    getSourceMutable().assign(reshapeOpProducer.getSource());
    return getResult();
  }

  auto source = getSource();
  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
  auto resultTy = dyn_cast<RankedTensorType>(getType());
  if (!sourceTy || !resultTy || sourceTy != resultTy)
    return {};

  // If the source and result are both rank <= 1 tensors with the same type,
  // the reshape has no effect, even when dynamically shaped.
  if (sourceTy.getRank() <= 1)
    return source;

  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
    auto elements = fromElements.getElements();
    bool dynamicNoop =
        sourceTy.getRank() == static_cast<int64_t>(elements.size());
    for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
      auto element = elements[id];

      if (auto cst = getConstantIntValue(element)) {
        dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
        continue;
      }

      if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
        dynamicNoop &= dimOp.getSource() == source;

        auto cst = getConstantIntValue(dimOp.getIndex());
        dynamicNoop &=
            cst.has_value() && cst.value() == static_cast<int64_t>(id);
        continue;
      }

      dynamicNoop = false;
      break;
    }

    if (dynamicNoop)
      return source;
  }

  return {};
}

//===----------------------------------------------------------------------===//
// Reassociative reshape ops
//===----------------------------------------------------------------------===//

void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapsed");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expanded");
}

int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
         "invalid resultDim");
  for (const auto &it : llvm::enumerate(getReassociationIndices()))
    if (llvm::is_contained(it.value(), resultDim))
      return it.index();
  llvm_unreachable("could not find reassociation group");
}

FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
                                RankedTensorType expandedType,
                                ArrayRef<ReassociationIndices> reassociation,
                                ArrayRef<OpFoldResult> inputShape) {
  std::optional<SmallVector<OpFoldResult>> outputShape =
      inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
                                  inputShape);
  if (!outputShape)
    return failure();
  return *outputShape;
}

SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
  return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) {
  auto [staticOutputShape, dynamicOutputShape] =
      decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
  build(builder, result, cast<RankedTensorType>(resultType), src,
        getReassociationIndicesAttribute(builder, reassociation),
        dynamicOutputShape, staticOutputShape);
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<OpFoldResult> inputShape =
      getMixedSizes(builder, result.location, src);
  auto tensorResultTy = cast<RankedTensorType>(resultType);
  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
      builder, result.location, tensorResultTy, reassociation, inputShape);
  SmallVector<OpFoldResult> outputShapeOrEmpty;
  if (succeeded(outputShape)) {
    outputShapeOrEmpty = *outputShape;
  }
  build(builder, result, tensorResultTy, src, reassociation,
        outputShapeOrEmpty);
}

SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

RankedTensorType CollapseShapeOp::inferCollapsedType(
    RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
  return inferCollapsedType(
      type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
                type.getContext(), reassociation)));
}

/// Compute the RankedTensorType obtained by applying `reassociation` to
/// `type`.
RankedTensorType
CollapseShapeOp::inferCollapsedType(RankedTensorType type,
                                    ArrayRef<AffineMap> reassociation) {
  auto shape = type.getShape();
  SmallVector<int64_t, 4> newShape;
  newShape.reserve(reassociation.size());

  // Use the fact that reassociation is valid to simplify the logic: only use
  // each map's rank.
  assert(isReassociationValid(reassociation) && "invalid reassociation");
  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    auto band = shape.slice(currentDim, dim);
    int64_t size = 1;
    if (llvm::is_contained(band, ShapedType::kDynamic))
      size = ShapedType::kDynamic;
    else
      for (unsigned d = 0; d < dim; ++d)
        size *= shape[currentDim + d];
    newShape.push_back(size);
    currentDim += dim;
  }

  return RankedTensorType::get(newShape, type.getElementType());
}
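// Illustrative example (editorial note, not from the upstream source):
// collapsing tensor<2x3x?x5xf32> with reassociation [[0, 1], [2, 3]] infers
// tensor<6x?xf32>: the first group multiplies 2 * 3 = 6 and the second
// group becomes dynamic because it contains a dynamic extent.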

void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<NamedAttribute> attrs) {
  auto resultType = inferCollapsedType(
      llvm::cast<RankedTensorType>(src.getType()),
      getSymbolLessAffineMaps(
          convertReassociationIndicesToExprs(b.getContext(), reassociation)));
  result.addAttribute(getReassociationAttrStrName(),
                      getReassociationIndicesAttribute(b, reassociation));
  build(b, result, resultType, src, attrs);
}

template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                        TensorReshapeOp, ExpandShapeOp>::value>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
                                           RankedTensorType expandedType,
                                           RankedTensorType collapsedType) {
  if (failed(
          verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
    return failure();

  auto maps = op.getReassociationMaps();
  RankedTensorType expectedType =
      CollapseShapeOp::inferCollapsedType(expandedType, maps);
  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
    return op.emitOpError("expected collapsed type to be ")
           << expectedType << ", but got " << collapsedType;
  return success();
}

LogicalResult ExpandShapeOp::verify() {
  auto srcType = getSrcType();
  auto resultType = getResultType();

  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape dims to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()
           << " values";

  return verifyTensorReshapeOp(*this, resultType, srcType);
}

LogicalResult CollapseShapeOp::verify() {
  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
}

namespace {

/// Reshape of a splat constant can be replaced with a constant of the result
/// type.
template <typename TensorReshapeOp>
struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    DenseElementsAttr attr;
    if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
      return failure();
    if (!attr || !attr.isSplat())
      return failure();
    DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
        reshapeOp.getResultType(), attr.getRawData());
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
    return success();
  }
};

// Folds TensorReshapeOp(splat x : src_type) : res_type into splat x :
// res_type.
template <typename TensorReshapeOp>
class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
public:
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
    if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<tensor::SplatOp>(
        reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
    return success();
  }
};

2114};

2115

2116

2117

2118template

2119struct FoldReshapeWithFromElements : OpRewritePattern {

2120 using OpRewritePattern::OpRewritePattern;

2121 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,

2122 PatternRewriter &rewriter) const override {

2123 auto fromElements =

2124 reshapeOp.getSrc().template getDefiningOp();

2125 if (!fromElements)

2126 return failure();

2127

2128 auto shapedTy = llvm::cast(reshapeOp.getType());

2129

2130 if (!shapedTy.hasStaticShape())

2131 return failure();

2132

2133 rewriter.replaceOpWithNewOp(reshapeOp, reshapeOp.getType(),

2134 fromElements.getElements());

2136 }

2137};

/// Fold CastOp into CollapseShapeOp when adding static information.
struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
    if (!tensor::canFoldIntoConsumerOp(castOp))
      return failure();

    RankedTensorType srcType =
        llvm::cast<RankedTensorType>(castOp.getSource().getType());
    RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
        srcType, collapseShapeOp.getReassociationMaps());

    if (newResultType == collapseShapeOp.getResultType()) {
      rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
        collapseShapeOp.getSrcMutable().assign(castOp.getSource());
      });
    } else {
      auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
                                           newResultType, castOp.getSource(),
                                           collapseShapeOp.getReassociation());
      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          collapseShapeOp, collapseShapeOp.getResultType(), newOp);
    }
    return success();
  }
};

/// Sink a `tensor.cast` producer into a `tensor.expand_shape`: rewrite the
/// expand_shape on a more static source type when the cast source provides
/// additional static information, then cast the result back to the original
/// result type.
struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
    SmallVector<ReassociationIndices, 4> reassoc =
        expandOp.getReassociationIndices();

    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
    SmallVector<Value> dynamicOutputShape;
    auto outputIt = expandOp.getOutputShape().begin();

    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
      for (uint64_t outDim : innerReassoc) {
        if (ShapedType::isStatic(newOutputShape[outDim]))
          continue;

        // If the cast's src type is dynamic, don't infer any of the
        // corresponding expanded dimensions. `tensor.expand_shape` requires
        // at least one of the expanded dimensions to be dynamic if the input
        // is dynamic.
        Value val = *outputIt;
        ++outputIt;
        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
          dynamicOutputShape.push_back(val);
          continue;
        }

        APInt cst;
        if (matchPattern(val, m_ConstantInt(&cst))) {
          newOutputShape[outDim] = cst.getSExtValue();
        } else {
          dynamicOutputShape.push_back(val);
        }
      }
    }

    // Couldn't match any values, nothing to change.
    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
      return failure();

    // Calculate the input shape from the output.
    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
    for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
      for (auto outDim : reassoc[inDim]) {
        auto ofr = newOutputShape[outDim];
        if (ShapedType::isDynamic(ofr)) {
          newInputShape[inDim] = ShapedType::kDynamic;
          break;
        }
        newInputShape[inDim] *= ofr;
      }
    }

    SmallVector<OpFoldResult> outputOfr =
        getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
    auto inputType = RankedTensorType::get(
        newInputShape, expandOp.getSrcType().getElementType());
    auto outputType = RankedTensorType::get(
        newOutputShape, expandOp.getSrcType().getElementType());
    auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,
                                    expandOp.getSrc());
    auto newExpand = ExpandShapeOp::create(
        rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
        expandOp.getReassociationIndices(), outputOfr);
    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
                                        newExpand.getResult());
    return success();
  }
};
} // namespace

void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
      ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
      FoldReshapeWithSplat<ExpandShapeOp>,
      FoldReshapeWithFromElements<ExpandShapeOp>>(context);
}

void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
                                tensor::DimOp, RankedTensorType>,
      FoldReshapeWithConstant<CollapseShapeOp>,
      FoldReshapeWithSplat<CollapseShapeOp>,
      FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
      context);
}

OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
}

//===----------------------------------------------------------------------===//
// ExtractSliceOp
//===----------------------------------------------------------------------===//

void ExtractSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted_slice");
}

/// An extract_slice result type can be inferred, when it is not rank-reduced,
/// from the source type and the static representation of offsets, sizes and
/// strides. Special sentinels encode the dynamic case.
RankedTensorType
ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
                                ArrayRef<int64_t> staticSizes) {
  // An extract_slice op may specify only a leading subset of offsets/sizes/
  // strides in which case we complete with offset=0, sizes from the source
  // type and strides=1.
  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
                               sourceTensorType.getEncoding());
}

RankedTensorType
ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
                                ArrayRef<OpFoldResult> sizes) {
  SmallVector<int64_t> staticSizes;
  std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);

  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
                               sourceTensorType.getEncoding());
}

/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
/// number of sizes), drop as many size-1 dims as needed to produce an
/// inferred type with the desired rank.
///
/// Note that there may be multiple ways to compute this rank-reduced type:
///   e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
///
/// To disambiguate, this function always drops the first size-1 occurrences.
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<int64_t> sizes) {
  // Type inferred in the absence of rank-reducing behavior.
  auto inferredType = llvm::cast<RankedTensorType>(
      inferResultType(sourceRankedTensorType, sizes));
  int rankDiff = inferredType.getRank() - desiredResultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    llvm::SmallBitVector dimsToProject =
        getPositionsOfShapeOne(rankDiff, shape);
    SmallVector<int64_t> projectedShape;
    // Best effort rank-reducing: drop 1s in order.
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.test(pos))
        projectedShape.push_back(shape[pos]);
    inferredType =
        RankedTensorType::get(projectedShape, inferredType.getElementType());
  }
  return inferredType;
}

RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<OpFoldResult> sizes) {
  SmallVector<int64_t> staticSizes;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      desiredResultRank, sourceRankedTensorType, staticSizes);
}

2360

2361

2362

2363void ExtractSliceOp::build(OpBuilder &b, OperationState &result,

2364 RankedTensorType resultType, Value source,

2365 ArrayRef<OpFoldResult> offsets,

2366 ArrayRef<OpFoldResult> sizes,

2367 ArrayRef<OpFoldResult> strides,

2368 ArrayRef<NamedAttribute> attrs) {

2369 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;

2370 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;

2371 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

2372 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);

2373 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

2374 auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());

2375

2376 if (!resultType) {

2377 resultType = llvm::cast<RankedTensorType>(

2378 ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes));

2379 }

2380 result.addAttributes(attrs);

2381 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,

2382 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),

2383 b.getDenseI64ArrayAttr(staticSizes),

2384 b.getDenseI64ArrayAttr(staticStrides));

2385}

2386

2387

2388

2389void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,

2390 ArrayRef<OpFoldResult> offsets,

2391 ArrayRef<OpFoldResult> sizes,

2392 ArrayRef<OpFoldResult> strides,

2393 ArrayRef<NamedAttribute> attrs) {

2394 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);

2395}

2396

2397

2398

2399void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,

2400 ArrayRef<Range> ranges,

2401 ArrayRef<NamedAttribute> attrs) {

2402 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);

2403 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);

2404}

2405

2406

2407

2408void ExtractSliceOp::build(OpBuilder &b, OperationState &result,

2409 RankedTensorType resultType, Value source,

2410 ValueRange offsets, ValueRange sizes,

2411 ValueRange strides, ArrayRef<NamedAttribute> attrs) {

2412 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(

2413 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));

2414 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(

2415 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));

2416 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(

2417 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));

2418 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);

2419}

2420

2421

2422void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,

2423 ValueRange offsets, ValueRange sizes,

2424 ValueRange strides, ArrayRef<NamedAttribute> attrs) {

2425 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);

2426}

2427

2428static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,

2429 Operation *op,

2430 RankedTensorType expectedType) {

2431 switch (result) {

2432 case SliceVerificationResult::Success:

2433 return success();

2434 case SliceVerificationResult::RankTooLarge:

2435 return op->emitError("expected rank to be smaller or equal to ")

2436 << "the other rank. ";

2437 case SliceVerificationResult::SizeMismatch:

2438 return op->emitError("expected type to be ")

2439 << expectedType << " or a rank-reduced version. (size mismatch) ";

2440 case SliceVerificationResult::ElemTypeMismatch:

2441 return op->emitError("expected element type to be ")

2442 << expectedType.getElementType();

2443 default:

2444 llvm_unreachable("unexpected extract_slice op verification result");

2445 }

2446}

2447

2448

2449

2450void ExtractSliceOp::build(OpBuilder &b, OperationState &result,

2451 RankedTensorType resultType, Value source,

2452 ArrayRef<OpFoldResult> sizes,

2453 ArrayRef<NamedAttribute> attrs) {

2454 Attribute zeroIdxAttr = b.getIndexAttr(0);

2455 Attribute oneIdxAttr = b.getIndexAttr(1);

2456 SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr);

2457 SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr);

2458 build(b, result, resultType, source, readOffsets, sizes, readStrides, attrs);

2459}
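A usage sketch for the convenience builder above (b, loc, src, n, and resultType are assumed caller values): passing only sizes produces a slice that reads from offset 0 with unit strides, the common "shrink to a prefix" case.

  // Reads src[0..2)[0..n) as a tensor<2x?xf32>; offsets and strides are
  // filled in as 0 and 1 by the builder above.
  SmallVector<OpFoldResult> sizes = {b.getIndexAttr(2), n};
  auto slice = ExtractSliceOp::create(b, loc, resultType, src, sizes, /*attrs=*/{});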

2460

2461

2462LogicalResult ExtractSliceOp::verify() {

2463 RankedTensorType sourceType = getSourceType();

2464

2465

2466 RankedTensorType expectedType =

2467 ExtractSliceOp::inferResultType(sourceType, getMixedSizes());

2468 SliceVerificationResult result = isRankReducedType(expectedType, getType());

2469 if (result != SliceVerificationResult::Success)

2470 return produceSliceErrorMsg(result, *this, expectedType);

2471

2472 // Verify that offsets, sizes, strides do not run out-of-bounds.

2474 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(

2475 sourceType.getShape(), getStaticOffsets(), getStaticSizes(),

2476 getStaticStrides(), /*generateErrorMessage=*/true);

2477 if (!boundsResult.isValid)

2478 return getOperation()->emitError(boundsResult.errorMessage);

2479

2480 return success();

2481}

2482

2483llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {

2484 return ::getDroppedDims(getType().getShape(), getMixedSizes());

2485}

2486

2487FailureOr<Value>

2488ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,

2489 ArrayRef<int64_t> desiredShape) {

2490 auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());

2491 assert(sourceTensorType && "not a ranked tensor type");

2492 auto sourceShape = sourceTensorType.getShape();

2493 if (sourceShape.equals(desiredShape))

2494 return value;

2495 auto maybeRankReductionMask =

2496 mlir::computeRankReductionMask(sourceShape, desiredShape);

2497 if (!maybeRankReductionMask)

2498 return failure();

2499 return createCanonicalRankReducingExtractSliceOp(

2500 b, loc, value,

2501 RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));

2502}

2503

2504LogicalResult ExtractSliceOp::reifyResultShapes(

2505 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {

2506 reifiedReturnShapes.resize(1);

2507 reifiedReturnShapes[0].reserve(getType().getRank());

2508 SmallVector<OpFoldResult> mixedSizes = getMixedSizes();

2509 llvm::SmallBitVector droppedDims = getDroppedDims();

2510 for (const auto &size : enumerate(mixedSizes)) {

2511 if (droppedDims.test(size.index()))

2512 continue;

2513 reifiedReturnShapes[0].push_back(size.value());

2514 }

2515 return success();

2516}

2517

2518namespace {

2519

2520

2521

2522

2523

2524

2525

2526

2527

2528

2529

2530

2531

2532

2533

2534class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {

2535public:

2536 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

2537

2538 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,

2539 PatternRewriter &rewriter) const override {

2540

2541 if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {

2542 return matchPattern(operand, matchConstantIndex());

2543 }))

2544 return failure();

2545

2546 auto castOp = sliceOp.getSource().getDefiningOp<tensor::CastOp>();

2547 if (!castOp)

2548 return failure();

2549

2550 if (!canFoldIntoConsumerOp(castOp))

2551 return failure();

2552

2553

2554 SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(

2555 cast<RankedTensorType>(castOp.getSource().getType()).getShape(),

2556 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),

2557 sliceOp.getStaticStrides());

2558 if (!sliceResult.isValid)

2559 return failure();

2560

2561

2562 Location loc = sliceOp.getLoc();

2563 Value newResult = ExtractSliceOp::create(

2564 rewriter, loc, sliceOp.getType(), castOp.getSource(),

2565 sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),

2566 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),

2567 sliceOp.getStaticStrides());

2568 rewriter.replaceOp(sliceOp, newResult);

2569 return success();

2570 }

2571};

2572

2573

2574

2575

2576template <typename IterTy, typename ElemTy>

2577static void sliceElements(IterTy values, ArrayRef<int64_t> counts,

2578 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,

2579 ArrayRef<int64_t> strides,

2580 llvm::SmallVectorImpl<ElemTy> *outValues) {

2581 assert(offsets.size() == sizes.size());

2582 assert(offsets.size() == strides.size());

2583 if (offsets.empty())

2584 return;

2585

2586 int64_t offset = offsets.front();

2587 int64_t size = sizes.front();

2588 int64_t stride = strides.front();

2589 if (offsets.size() == 1) {

2590 for (int64_t i = 0; i < size; ++i, offset += stride)

2591 outValues->push_back(*(values + offset));

2592

2593 return;

2594 }

2595

2596 for (int64_t i = 0; i < size; ++i, offset += stride) {

2597 auto begin = values + offset * counts.front();

2598 sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),

2599 offsets.drop_front(), sizes.drop_front(),

2600 strides.drop_front(), outValues);

2601 }

2602}
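A worked trace of the recursion above, with assumed inputs: for a source constant of shape {4, 5}, the caller passes counts = {5, 1} (running products of the trailing dimensions). Slicing with offsets {1, 0}, sizes {2, 3}, strides {1, 2} then proceeds as sketched here.

  // Outer level: rows 1 and 2, i.e. begin = values + 1*5, then values + 2*5.
  // Base case per row: columns 0, 2, 4 (offset 0, size 3, stride 2).
  // Collected linear indices: 5, 7, 9, 10, 12, 14 - the 2x3 result in
  // row-major order, which ConstantOpExtractSliceFolder below relies on.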

2603

2604

2605

2606

2607class ConstantOpExtractSliceFolder final

2608 : public OpRewritePattern<ExtractSliceOp> {

2609public:

2610 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

2611

2612 ConstantOpExtractSliceFolder(MLIRContext *context,

2613 ControlConstantExtractSliceFusionFn controlFn)

2614 : OpRewritePattern<ExtractSliceOp>(context),

2615 controlFn(std::move(controlFn)) {}

2616

2617 LogicalResult matchAndRewrite(ExtractSliceOp op,

2618 PatternRewriter &rewriter) const override {

2619 DenseElementsAttr attr;

2620 if (!matchPattern(op.getSource(), m_Constant(&attr)))

2621 return failure();

2622

2623

2624 if (attr.isSplat())

2625 return failure();

2626

2627

2628 auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());

2629 auto resultType = llvm::cast<ShapedType>(op.getResult().getType());

2630 if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())

2631 return failure();

2632

2633

2634 if (!controlFn(op))

2635 return failure();

2636

2637 int64_t count = sourceType.getNumElements();

2638 if (count == 0)

2639 return failure();

2640

2641

2642 auto offsets = op.getStaticOffsets();

2643 if (llvm::is_contained(offsets, ShapedType::kDynamic))

2644 return failure();

2645 auto sizes = op.getStaticSizes();

2646 if (llvm::is_contained(sizes, ShapedType::kDynamic))

2647 return failure();

2648 auto strides = op.getStaticStrides();

2649 if (llvm::is_contained(strides, ShapedType::kDynamic))

2650 return failure();

2651

2652

2653 SmallVector<int64_t> counts;

2654 ArrayRef<int64_t> shape = sourceType.getShape();

2655 counts.reserve(shape.size());

2656 for (int64_t v : shape) {

2657 count = count / v;

2658 counts.push_back(count);

2659 }

2660

2661

2662 DenseElementsAttr newAttr;

2663

2664 if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {

2665 SmallVector outValues;

2666 outValues.reserve(sourceType.getNumElements());

2667 sliceElements<DenseElementsAttr::IntElementIterator, APInt>(

2668 elems.begin(), counts, offsets, sizes, strides, &outValues);

2669 newAttr = DenseElementsAttr::get(resultType, outValues);

2670 } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {

2671 SmallVector outValues;

2672 outValues.reserve(sourceType.getNumElements());

2673 sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(

2674 elems.begin(), counts, offsets, sizes, strides, &outValues);

2675 newAttr = DenseElementsAttr::get(resultType, outValues);

2676 }

2677

2678 if (newAttr) {

2679 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);

2680 return success();

2681 }

2682

2683 return failure();

2684 }

2685

2686private:

2687

2688

2689 ControlConstantExtractSliceFusionFn controlFn;

2690};

2691

2692}

2693

2694void mlir::tensor::populateFoldConstantExtractSlicePatterns(

2695 RewritePatternSet &patterns,

2696 const ControlConstantExtractSliceFusionFn &controlFn) {

2697 patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);

2698}
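A hedged usage sketch of the entry point above: the control function lets a client bound how much constant data may be duplicated, since slicing a constant materializes a new arith.constant. The 1024-element threshold is an arbitrary example value.

  RewritePatternSet patterns(ctx);
  tensor::populateFoldConstantExtractSlicePatterns(
      patterns, [](tensor::ExtractSliceOp op) {
        // Only fold slices of reasonably small, fully static constants.
        auto type = cast<RankedTensorType>(op.getSource().getType());
        return type.hasStaticShape() && type.getNumElements() <= 1024;
      });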

2699

2700

2706 return ExtractSliceOp::inferCanonicalRankReducedResultType(

2707 op.getType().getRank(), op.getSourceType(), mixedSizes);

2708 }

2709};

2710

2711

2712struct SliceCanonicalizer {

2713 void operator()(PatternRewriter &rewriter, ExtractSliceOp op,

2714 ExtractSliceOp newOp) {

2715 Value replacement = newOp.getResult();

2716 if (replacement.getType() != op.getType())

2717 replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),

2718 replacement);

2719 rewriter.replaceOp(op, replacement);

2720 }

2721};

2722

2723void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,

2724 MLIRContext *context) {

2725 results.add<

2726 OpWithOffsetSizesAndStridesConstantArgumentFolder<

2727 ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,

2728 ExtractSliceOpCastFolder>(context);

2729}

2730

2731

2732static LogicalResult

2733foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,

2734 ShapedType shapedType) {

2736 for (OpFoldResult ofr : op.getMixedOffsets())

2737 if (getConstantIntValue(ofr) != static_cast<int64_t>(0))

2738 return failure();

2740

2741 auto shape = shapedType.getShape();

2742 for (auto it : llvm::zip(op.getMixedSizes(), shape))

2743 if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))

2744 return failure();

2745 for (OpFoldResult ofr : op.getMixedStrides())

2746 if (getConstantIntValue(ofr) != static_cast<int64_t>(1))

2747 return failure();

2748 return success();

2749}

2750

2751

2752

2753

2754

2755static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {

2756 auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();

2757

2758 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };

2759 if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&

2760 insertOp.isSameAs(extractOp, isSame))

2761 return insertOp.getSource();

2762

2763 return {};

2764}

2765

2766OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {

2767 if (OpFoldResult reshapedSource = reshapeConstantSource(

2768 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),

2769 getResult().getType()))

2770 return reshapedSource;

2771 if (getSourceType() == getType() &&

2772 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))

2773 return this->getSource();

2774 if (Value slice = foldExtractAfterInsertSlice(*this))

2775 return slice;

2776

2777 return OpFoldResult();

2778}

2779

2780Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(

2781 OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {

2782 auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());

2783 unsigned rank = rankedTensorType.getRank();

2784 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));

2785 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);

2786 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));

2787 return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,

2788 offsets, sizes, strides);

2789}

2790

2791

2792

2793

2794

2795void InsertSliceOp::getAsmResultNames(

2797 setNameFn(getResult(), "inserted_slice");

2798}

2799

2800

2801void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,

2802 Value dest, ArrayRef<OpFoldResult> offsets,

2803 ArrayRef<OpFoldResult> sizes,

2804 ArrayRef<OpFoldResult> strides,

2805 ArrayRef<NamedAttribute> attrs) {

2806 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;

2807 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;

2808 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

2809 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);

2810 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

2811 result.addAttributes(attrs);

2812 build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,

2813 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),

2814 b.getDenseI64ArrayAttr(staticSizes),

2815 b.getDenseI64ArrayAttr(staticStrides));

2816}

2817

2818

2819

2820void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,

2821 Value dest, ArrayRef<Range> ranges,

2822 ArrayRef<NamedAttribute> attrs) {

2823 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);

2824 build(b, result, source, dest, offsets, sizes, strides, attrs);

2825}

2826

2827

2828void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,

2829 Value dest, ValueRange offsets, ValueRange sizes,

2830 ValueRange strides, ArrayRef<NamedAttribute> attrs) {

2831 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(

2832 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));

2833 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(

2834 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));

2835 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(

2836 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));

2837 build(b, result, source, dest, offsetValues, sizeValues, strideValues);

2838}

2839

2840

2841

2842static SliceVerificationResult verifyInsertSliceOp(

2843 RankedTensorType srcType, RankedTensorType dstType,

2844 ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,

2845 ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {

2846

2847 // insert_slice is the inverse of extract_slice: infer the expected type.

2848 RankedTensorType expected =

2849 ExtractSliceOp::inferResultType(dstType, staticSizes);

2850 if (expectedType)

2851 *expectedType = expected;

2852 return isRankReducedType(expected, srcType);

2853}

2854

2855

2856LogicalResult InsertSliceOp::verify() {

2857

2858 RankedTensorType expectedType;

2859 SliceVerificationResult result =

2860 verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),

2861 getStaticSizes(), getStaticStrides(), &expectedType);

2862 if (result != SliceVerificationResult::Success)

2863 return produceSliceErrorMsg(result, *this, expectedType);

2864

2865

2866

2867 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(

2868 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),

2869 getStaticStrides(), true);

2870 if (!boundsResult.isValid)

2871 return getOperation()->emitError(boundsResult.errorMessage);

2872

2873 return success();

2874}

2875

2876

2877

2878

2879

2880

2881

2882

2883

2884

2885

2886

2887

2888

2889

2890

2891

2892

2894 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();

2895

2897 if (!prevInsertOp ||

2898 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||

2899 !prevInsertOp.isSameAs(insertOp, isSame))

2900 return failure();

2901

2902 insertOp.getDestMutable().assign(prevInsertOp.getDest());

2903 return success();

2904}

2905

2906

2907

2908

2909

2910

2911

2912

2914 auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();

2915

2917 if (!extractOp || extractOp.getSource() != insertOp.getDest() ||

2918 !extractOp.isSameAs(insertOp, isSame))

2919 return nullptr;

2920

2921 return extractOp.getSource();

2922}

2923

2924OpFoldResult InsertSliceOp::fold(FoldAdaptor) {

2925 if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&

2926 getSourceType() == getType() &&

2927 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))

2928 return this->getSource();

2929 if (succeeded(foldInsertAfterInsertSlice(*this)))

2930 return getResult();

2931 if (auto result = foldInsertAfterExtractSlice(*this))

2932 return result;

2933 if (llvm::any_of(getMixedSizes(), isZeroInteger))

2934 return getDest();

2935 return OpFoldResult();

2936}
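A hedged IR-level reading of the folds above: a full-extent insert with zero offsets and unit strides folds to its source; re-inserting a value just extracted from the same slice of the same destination folds to the destination; and a zero-extent slice cannot change the destination, so it also folds to the dest operand.

  // %0 = tensor.extract_slice %d[0, 1] [2, 3] [1, 1] : ...
  // %1 = tensor.insert_slice %0 into %d[0, 1] [2, 3] [1, 1] : ...
  // folds: %1 -> %d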

2937

2938LogicalResult InsertSliceOp::reifyResultShapes(

2939 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {

2940 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));

2941 reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());

2942 return success();

2943}

2944

2945namespace {

2946

2947

2948

2949template <typename InsertOpTy>

2950class InsertSliceOpConstantArgumentFolder final

2951 : public OpRewritePattern<InsertOpTy> {

2952public:

2953 using OpRewritePattern<InsertOpTy>::OpRewritePattern;

2954

2955 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,

2956 PatternRewriter &rewriter) const override {

2957 SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());

2958 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());

2959 SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());

2960

2961

2962 if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&

2963 failed(foldDynamicOffsetSizeList(mixedSizes)) &&

2964 failed(foldDynamicStrideList(mixedStrides)))

2965 return failure();

2966

2967

2968 SliceBoundsVerificationResult sliceResult =

2969 verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),

2970 mixedOffsets, mixedSizes, mixedStrides);

2971 if (!sliceResult.isValid)

2972 return failure();

2973

2974

2975 auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(

2976 insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),

2977 mixedSizes);

2978 Value toInsert = insertSliceOp.getSource();

2979 if (sourceType != insertSliceOp.getSourceType()) {

2980 OpBuilder::InsertionGuard g(rewriter);

2981

2982

2983

2984 if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))

2985 rewriter.setInsertionPoint(insertSliceOp->getParentOp());

2986 toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),

2987 sourceType, toInsert);

2988 }

2989 rewriter.replaceOpWithNewOp<InsertOpTy>(

2990 insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,

2991 mixedSizes, mixedStrides);

2992 return success();

2993 }

2994};

2995

2996

2997

2998

2999

3000

3001

3002

3003

3004

3005

3006

3007

3008

3009

3010

3011

3012

3013

3014

3015

3016template <typename InsertOpTy>

3017struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {

3018 using OpRewritePattern<InsertOpTy>::OpRewritePattern;

3019

3020 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,

3021 PatternRewriter &rewriter) const override {

3022 if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {

3023 return matchPattern(operand, matchConstantIndex());

3024 }))

3025 return failure();

3026

3027 auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {

3028 auto castOp = v.getDefiningOp<tensor::CastOp>();

3029 if (!castOp || !canFoldIntoConsumerOp(castOp))

3030 return std::nullopt;

3031 return castOp.getSource();

3032 };

3033 std::optional<Value> sourceCastSource =

3034 getSourceOfCastOp(insertSliceOp.getSource());

3035 std::optional<Value> destCastSource =

3036 getSourceOfCastOp(insertSliceOp.getDest());

3037 if (!sourceCastSource && !destCastSource)

3038 return failure();

3039

3040 auto src =

3041 (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());

3042 auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());

3043 auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());

3044 auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());

3045 if (!srcType || !dstType)

3046 return failure();

3047

3048

3049

3050

3051 SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());

3052 auto rankReductionMask = computeRankReductionMask(

3053 staticSizes, srcType.getShape(), /*matchDynamic=*/true);

3054 if (!rankReductionMask.has_value())

3055 return failure();

3056

3057

3058

3059

3060

3061 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());

3062 int64_t rankReducedIdx = 0;

3063 for (auto [idx, size] : enumerate(staticSizes)) {

3064 if (!rankReductionMask.value().contains(idx) &&

3065 !srcType.isDynamicDim(rankReducedIdx)) {

3066 mixedSizes[idx] = getAsIndexOpFoldResult(

3067 rewriter.getContext(), srcType.getDimSize(rankReducedIdx));

3068 size = srcType.getDimSize(rankReducedIdx++);

3069 }

3070 }

3071

3072

3073 if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),

3074 staticSizes, insertSliceOp.getStaticStrides()) !=

3075 SliceVerificationResult::Success)

3076 return failure();

3077 SliceBoundsVerificationResult sliceResult =

3078 verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),

3079 mixedSizes, insertSliceOp.getMixedStrides());

3080 if (!sliceResult.isValid)

3081 return failure();

3082

3083 Value replacement =

3084 InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,

3085 insertSliceOp.getMixedOffsets(), mixedSizes,

3086 insertSliceOp.getMixedStrides());

3087

3088

3089 bool isParallelInsert =

3090 std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;

3091 if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {

3092 replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),

3093 insertSliceOp.getDestType(),

3094 replacement);

3095 }

3096 rewriter.replaceOp(insertSliceOp, replacement);

3097 return success();

3098 }

3099};

3100

3101

3102

3103

3104

3105

3106

3107

3108

3109

3110

3111

3112

3113

3114

3115

3116

3117

3118

3119

3120

3121

3122template <typename InsertOpTy>

3123struct InsertSliceOpSourceCastInserter final

3124 : public OpRewritePattern<InsertOpTy> {

3125 using OpRewritePattern<InsertOpTy>::OpRewritePattern;

3126

3127 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,

3128 PatternRewriter &rewriter) const override {

3129 RankedTensorType srcType = insertSliceOp.getSourceType();

3130 if (srcType.getRank() != insertSliceOp.getDestType().getRank())

3131 return failure();

3132 SmallVector<int64_t> newSrcShape(srcType.getShape());

3133 for (int64_t i = 0; i < srcType.getRank(); ++i) {

3134 if (std::optional<int64_t> constInt =

3135 getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {

3136 // Bail on invalid IR.

3137 if (*constInt < 0)

3138 return failure();

3139 newSrcShape[i] = *constInt;

3140 }

3141 }

3142 if (!hasValidSizesOffsets(newSrcShape))

3143 return failure();

3144

3145 RankedTensorType newSrcType = RankedTensorType::get(

3146 newSrcShape, srcType.getElementType(), srcType.getEncoding());

3147 if (srcType == newSrcType ||

3149 !tensor::CastOp::areCastCompatible(srcType, newSrcType))

3150 return failure();

3151

3152

3153

3154

3155

3156

3157 OpBuilder::InsertionGuard g(rewriter);

3158

3159

3160

3161 if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))

3162 rewriter.setInsertionPoint(insertSliceOp->getParentOp());

3163 Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),

3164 newSrcType, insertSliceOp.getSource());

3165 rewriter.replaceOpWithNewOp<InsertOpTy>(

3166 insertSliceOp, cast, insertSliceOp.getDest(),

3167 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),

3168 insertSliceOp.getMixedStrides());

3169 return success();

3170 }

3171};

3172}

3173

3174llvm::SmallBitVector InsertSliceOp::getDroppedDims() {

3175 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());

3176}

3177

3178void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,

3179 MLIRContext *context) {

3180 results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,

3181 InsertSliceOpCastFolder<InsertSliceOp>,

3182 InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);

3183}

3184

3185Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,

3186 Location loc,

3187 Value tensor,

3188 Value dest) {

3189 auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());

3190 unsigned rank = rankedTensorType.getRank();

3191 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));

3192 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);

3193 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));

3194 return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,

3195 sizes, strides);

3196}

3197

3198

3199

3200

3201

3202void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {

3203 setNameFn(getResult(), "padded");

3204}

3205

3206LogicalResult PadOp::verify() {

3207 auto sourceType = llvm::cast(getSource().getType());

3208 auto resultType = llvm::cast(getResult().getType());

3209 auto expectedType =

3210 PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());

3211 if (!expectedType) {

3212 return emitError("failed to infer expectedType from sourceType ")

3213 << sourceType << ", specified resultType is " << resultType;

3214 }

3215 if (resultType.getRank() != expectedType.getRank()) {

3216 return emitError("specified type ")

3217 << resultType << " does not match the inferred type "

3218 << expectedType;

3219 }

3220 for (int i = 0, e = sourceType.getRank(); i < e; ++i) {

3221 if (resultType.getDimSize(i) == expectedType.getDimSize(i))

3222 continue;

3223 if (expectedType.isDynamicDim(i))

3224 continue;

3225 return emitError("specified type ")

3226 << resultType << " does not match the inferred type "

3227 << expectedType;

3228 }

3229

3231}

3232

3233LogicalResult PadOp::verifyRegions() {

3234 auto &region = getRegion();

3235 unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();

3236 Block &block = region.front();

3237 if (block.getNumArguments() != rank)

3238 return emitError("expected the block to have ") << rank << " arguments";

3239

3240

3241 for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {

3242 if (!en.value().isIndex())

3243 return emitOpError("expected block argument ")

3244 << (en.index() + 1) << " to be an index";

3245 }

3246

3247

3248 auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());

3249 if (yieldOp.getValue().getType() !=

3250 llvm::cast<ShapedType>(getType()).getElementType())

3251 return emitOpError("expected yield type to match shape element type");

3252

3253 return success();

3254}

3255

3256RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,

3257 ArrayRef<int64_t> staticLow,

3258 ArrayRef<int64_t> staticHigh,

3259 ArrayRef<int64_t> resultShape) {

3260 unsigned rank = sourceType.getRank();

3261 if (staticLow.size() != rank)

3262 return RankedTensorType();

3263 if (staticHigh.size() != rank)

3264 return RankedTensorType();

3265 if (!resultShape.empty() && resultShape.size() != rank)

3266 return RankedTensorType();

3267

3268 SmallVector<int64_t, 4> inferredShape;

3269 for (auto i : llvm::seq(0, rank)) {

3270 if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||

3271 staticHigh[i] == ShapedType::kDynamic) {

3272 inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic

3273 : resultShape[i]);

3274 } else {

3275 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];

3276 assert((resultShape.empty() || size == resultShape[i] ||

3277 resultShape[i] == ShapedType::kDynamic) &&

3278 "mismatch between inferred shape and result shape");

3279 inferredShape.push_back(size);

3280 }

3281 }

3282

3283 return RankedTensorType::get(inferredShape, sourceType.getElementType());

3284}
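A worked example for the inference above (f32 and ctx as in the earlier sketches): each result dim is source + low + high when all three are static, and dynamic otherwise, with resultShape consulted only to pin the dynamic cases.

  auto src = RankedTensorType::get({4, ShapedType::kDynamic}, f32);
  // low {1, 2}, high {3, kDynamic}: dim 0 is 4 + 1 + 3 = 8, dim 1 stays
  // dynamic, so the inferred type is tensor<8x?xf32>.
  auto padded = PadOp::inferResultType(
      src, /*staticLow=*/{1, 2}, /*staticHigh=*/{3, ShapedType::kDynamic});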

3285

3286void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,

3287 Value source, ArrayRef<int64_t> staticLow,

3288 ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,

3289 bool nofold, ArrayRef<NamedAttribute> attrs) {

3290 auto sourceType = llvm::cast(source.getType());

3291 if (!resultType)

3292 resultType = inferResultType(sourceType, staticLow, staticHigh);

3293 result.addAttributes(attrs);

3294 build(b, result, resultType, source, low, high,

3295 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),

3296 nofold ? b.getUnitAttr() : UnitAttr());

3297}

3298

3299void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,

3300 Value source, ValueRange low, ValueRange high, bool nofold,

3301 ArrayRef<NamedAttribute> attrs) {

3302 auto sourceType = llvm::cast(source.getType());

3303 unsigned rank = sourceType.getRank();

3304 SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);

3305 build(b, result, resultType, source, staticVector, staticVector, low, high,

3306 nofold, attrs);

3307}

3308

3309void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,

3310 Value source, ArrayRef<OpFoldResult> low,

3311 ArrayRef<OpFoldResult> high, bool nofold,

3312 ArrayRef<NamedAttribute> attrs) {

3313 auto sourceType = llvm::cast(source.getType());

3314 SmallVector<Value, 4> dynamicLow, dynamicHigh;

3315 SmallVector<int64_t, 4> staticLow, staticHigh;

3316

3317

3318

3319

3320 dispatchIndexOpFoldResults(low, dynamicLow, staticLow);

3321 dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);

3322 if (!resultType) {

3323 resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);

3324 }

3325 assert(llvm::isa<RankedTensorType>(resultType));

3326 result.addAttributes(attrs);

3327 build(b, result, resultType, source, dynamicLow, dynamicHigh,

3328 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),

3329 nofold ? b.getUnitAttr() : UnitAttr());

3330}

3331

3332void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,

3333 Value source, ArrayRef<OpFoldResult> low,

3334 ArrayRef<OpFoldResult> high, Value constantPadValue,

3335 bool nofold, ArrayRef attrs) {

3336 build(b, result, resultType, source, low, high, nofold, attrs);

3337

3338

3339 Region *region = result.regions[0].get();

3340 int sourceRank = llvm::cast(source.getType()).getRank();

3341 SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());

3342 SmallVector<Location> blockArgLocs(sourceRank, result.location);

3343

3344

3345

3346 OpBuilder::InsertionGuard guard(b);

3347 b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);

3348 tensor::YieldOp::create(b, result.location, constantPadValue);

3349}
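A minimal construction sketch for the overload above (b, loc, source, f32, and the zero constant are assumptions of the example): the builder creates the region and its tensor.yield, so callers need not populate the body by hand.

  Value zero = arith::ConstantOp::create(b, loc, b.getZeroAttr(f32));
  SmallVector<OpFoldResult> low = {b.getIndexAttr(0), b.getIndexAttr(0)};
  SmallVector<OpFoldResult> high = {b.getIndexAttr(2), b.getIndexAttr(3)};
  // Passing a null result type lets the builder infer it from low/high.
  auto pad = tensor::PadOp::create(b, loc, /*resultType=*/Type(), source,
                                   low, high, zero, /*nofold=*/false);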

3350

3351llvm::SmallBitVector PadOp::getPaddedDims() {

3352 llvm::SmallBitVector paddedDims(getSourceType().getRank());

3353 auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {

3354 for (const auto &en : enumerate(paddingWidths))

3355 if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))

3356 paddedDims.set(en.index());

3357 };

3358 extractPaddedDims(getMixedLowPad());

3359 extractPaddedDims(getMixedHighPad());

3360 return paddedDims;

3361}

3362

3363namespace {

3364

3365

3366struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {

3367 using OpRewritePattern<PadOp>::OpRewritePattern;

3368

3369 LogicalResult matchAndRewrite(PadOp padTensorOp,

3370 PatternRewriter &rewriter) const override {

3371 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())

3372 return failure();

3373 if (padTensorOp.getNofold())

3374 return failure();

3375 rewriter.replaceOpWithNewOp<tensor::CastOp>(

3376 padTensorOp, padTensorOp.getResult().getType(),

3377 padTensorOp.getSource());

3378 return success();

3379 }

3380};

3381

3382

3383struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {

3384 using OpRewritePattern<PadOp>::OpRewritePattern;

3385

3386 LogicalResult matchAndRewrite(PadOp padTensorOp,

3387 PatternRewriter &rewriter) const override {

3388 auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();

3389 if (!tensor::canFoldIntoConsumerOp(castOp))

3390 return failure();

3391

3392 auto newResultType = PadOp::inferResultType(

3393 llvm::cast(castOp.getSource().getType()),

3394 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),

3395 padTensorOp.getResultType().getShape());

3396

3397 if (newResultType == padTensorOp.getResultType()) {

3398 rewriter.modifyOpInPlace(padTensorOp, [&]() {

3399 padTensorOp.getSourceMutable().assign(castOp.getSource());

3400 });

3401 } else {

3402 auto newOp = PadOp::create(

3403 rewriter, padTensorOp->getLoc(), newResultType,

3404 padTensorOp.getSource(), padTensorOp.getStaticLow(),

3405 padTensorOp.getStaticHigh(), padTensorOp.getLow(),

3406 padTensorOp.getHigh(), padTensorOp.getNofold(),

3407 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));

3408 IRMapping mapper;

3409 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

3410

3411 rewriter.replaceOpWithNewOp<tensor::CastOp>(

3412 padTensorOp, padTensorOp.getResultType(), newOp);

3413 }

3414 return success();

3415 }

3416};

3417

3418

3419

3420struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {

3421 using OpRewritePattern<PadOp>::OpRewritePattern;

3422

3423 LogicalResult matchAndRewrite(PadOp padTensorOp,

3424 PatternRewriter &rewriter) const override {

3425 if (!padTensorOp.getResult().hasOneUse())

3426 return failure();

3427 auto tensorCastOp =

3428 dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());

3429 if (!tensorCastOp)

3430 return failure();

3431 if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),

3432 tensorCastOp.getDest().getType()))

3433 return failure();

3434

3435 auto replacementOp = PadOp::create(

3436 rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),

3437 padTensorOp.getSource(), padTensorOp.getStaticLow(),

3438 padTensorOp.getStaticHigh(), padTensorOp.getLow(),

3439 padTensorOp.getHigh(), padTensorOp.getNofold(),

3440 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));

3441 replacementOp.getRegion().takeBody(padTensorOp.getRegion());

3442

3443 rewriter.replaceOp(padTensorOp, replacementOp.getResult());

3444 rewriter.replaceOp(tensorCastOp, replacementOp.getResult());

3445 return success();

3446 }

3447};

3448

3449

3450

3451

3452

3453

3454

3455

3456

3457

3458

3459

3460

3461

3462

3463

3464

3465

3466

3467

3468

3469

3470

3471

3472

3473

3474

3475

3476

3477

3478

3479

3480

3481

3482

3483

3484struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {

3485 using OpRewritePattern<PadOp>::OpRewritePattern;

3486

3487 LogicalResult matchAndRewrite(PadOp padOp,

3488 PatternRewriter &rewriter) const override {

3489 auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();

3490 if (!innerSliceOp)

3491 return failure();

3492 auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();

3493 if (!outerPadOp || outerPadOp.getNofold())

3494 return failure();

3495 auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();

3496 if (!outerSliceOp)

3497 return failure();

3498

3499

3500 int64_t rank = padOp.getSourceType().getRank();

3501 if (outerSliceOp.getSourceType().getRank() != rank) {

3502 return rewriter.notifyMatchFailure(padOp,

3503 "cannot fold rank-reducing chain");

3504 }

3505

3506

3507 if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {

3508 return rewriter.notifyMatchFailure(

3509 padOp, "cannot fold non-unit stride ExtractSliceOps");

3510 }

3511

3512

3513 if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {

3514 return rewriter.notifyMatchFailure(padOp,

3515 "cannot fold PadOps with low padding");

3516 }

3517

3518

3519 Attribute innerAttr, outerAttr;

3520 Value innerValue = padOp.getConstantPaddingValue();

3521 Value outerValue = outerPadOp.getConstantPaddingValue();

3522 if (!innerValue || !outerValue ||

3523 !matchPattern(innerValue, m_Constant(&innerAttr)) ||

3524 !matchPattern(outerValue, m_Constant(&outerAttr)) ||

3525 innerAttr != outerAttr) {

3526 return rewriter.notifyMatchFailure(

3527 padOp, "cannot fold PadOps with different padding values");

3528 }

3529

3530

3531 llvm::SmallBitVector innerDims = padOp.getPaddedDims();

3532 llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();

3533 if (innerDims.anyCommon(outerDims)) {

3534 return rewriter.notifyMatchFailure(

3535 padOp, "cannot fold PadOps with common padding dimensions");

3536 }

3537

3538

3539

3540

3541

3542

3543 SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));

3544 for (auto en : enumerate(newOffsets)) {

3545 OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];

3546 OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];

3547 if (!innerDims.test(en.index()) &&

3548 (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {

3549 en.value() = outerOffset;

3550 continue;

3551 }

3552 if (!outerDims.test(en.index()) &&

3553 (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {

3554 en.value() = innerOffset;

3555 continue;

3556 }

3557 return rewriter.notifyMatchFailure(

3558 padOp, "cannot find zero-offset and zero-padding pair");

3559 }

3560

3561

3562

3563

3564

3565

3566 SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();

3567 for (auto en : enumerate(newSizes)) {

3568 if (!outerDims.test(en.index()))

3569 continue;

3570 OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];

3571 int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];

3572 assert(ShapedType::isStatic(sourceSize) &&

3573 "expected padded dimension to have a static size");

3574 if (getConstantIntValue(sliceSize) != sourceSize) {

3575 return rewriter.notifyMatchFailure(

3576 padOp, "cannot fold since the inner ExtractSliceOp size does not "

3577 "match the size of the outer padding");

3578 }

3579 en.value() = outerSliceOp.getMixedSizes()[en.index()];

3580 }

3581

3582

3583 SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));

3584 for (auto en : enumerate(newHighPad)) {

3585 if (innerDims.test(en.index()))

3586 newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];

3587 if (outerDims.test(en.index()))

3588 newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];

3589 }

3590

3591

3592

3593 auto newSliceOp = ExtractSliceOp::create(

3594 rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,

3595 newSizes, innerSliceOp.getMixedStrides());

3596 auto newPadOp = PadOp::create(

3597 rewriter, padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),

3598 padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),

3599 getPrunedAttributeList(padOp, PadOp::getAttributeNames()));

3600 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),

3601 newPadOp.getRegion().begin());

3602 rewriter.replaceOp(padOp, newPadOp.getResult());

3603 return success();

3604 }

3605};

3606

3607struct FoldStaticPadding : public OpRewritePattern<PadOp> {

3608 using OpRewritePattern<PadOp>::OpRewritePattern;

3609

3610 LogicalResult matchAndRewrite(PadOp padTensorOp,

3611 PatternRewriter &rewriter) const override {

3612 Value input = padTensorOp.getSource();

3613 if (!llvm::isa<RankedTensorType>(input.getType()))

3614 return failure();

3615 auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();

3616 auto inputRank = inputDims.size();

3617

3618 auto oldResultType =

3619 dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());

3620 if (!oldResultType)

3621 return failure();

3622

3623 auto outputDims = oldResultType.getShape();

3624

3625

3626 SmallVector<int64_t> constOperandsLow;

3627 SmallVector<Value> newLows;

3628 for (auto operand : padTensorOp.getLow()) {

3629 APSInt intOp;

3630 if (!matchPattern(operand, m_ConstantInt(&intOp))) {

3631 constOperandsLow.push_back(ShapedType::kDynamic);

3632 newLows.push_back(operand);

3633 continue;

3634 }

3635 constOperandsLow.push_back(intOp.getExtValue());

3636 }

3637 SmallVector<int64_t> constOperandsHigh;

3638 SmallVector<Value> newHighs;

3639 for (auto operand : padTensorOp.getHigh()) {

3640 APSInt intOp;

3641 if (!matchPattern(operand, m_ConstantInt(&intOp))) {

3642 constOperandsHigh.push_back(ShapedType::kDynamic);

3643 newHighs.push_back(operand);

3644 continue;

3645 }

3646 constOperandsHigh.push_back(intOp.getExtValue());

3647 }

3648

3649 SmallVector<int64_t> constLow(padTensorOp.getStaticLow());

3650 SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());

3651

3652

3653 if (inputDims.size() != outputDims.size() ||

3654 inputDims.size() != constLow.size() ||

3655 inputDims.size() != constHigh.size())

3656 return failure();

3657

3658 auto lowCount = 0;

3659 auto highCount = 0;

3660 for (size_t i = 0; i < inputRank; i++) {

3661 if (constLow[i] == ShapedType::kDynamic)

3662 constLow[i] = constOperandsLow[lowCount++];

3663 if (constHigh[i] == ShapedType::kDynamic)

3664 constHigh[i] = constOperandsHigh[highCount++];

3665 }

3666

3667 auto staticLow = ArrayRef<int64_t>(constLow);

3668 auto staticHigh = ArrayRef<int64_t>(constHigh);

3669

3670

3671 SmallVector<int64_t> newOutDims;

3672 for (size_t i = 0; i < inputRank; i++) {

3673 if (outputDims[i] == ShapedType::kDynamic) {

3674 newOutDims.push_back(

3675 (staticLow[i] == ShapedType::kDynamic ||

3676 staticHigh[i] == ShapedType::kDynamic ||

3677 inputDims[i] == ShapedType::kDynamic

3678 ? ShapedType::kDynamic

3679 : inputDims[i] + staticLow[i] + staticHigh[i]));

3680 } else {

3681 newOutDims.push_back(outputDims[i]);

3682 }

3683 }

3684

3685 if (SmallVector<int64_t>(outputDims) == newOutDims ||

3686 llvm::all_of(newOutDims,

3687 [&](int64_t x) { return x == ShapedType::kDynamic; }))

3688 return failure();

3689

3690

3691 auto newResultType = RankedTensorType::get(

3692 newOutDims, padTensorOp.getType().getElementType());

3693 auto newOp = PadOp::create(

3694 rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,

3695 staticHigh, newLows, newHighs, padTensorOp.getNofold(),

3696 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));

3697

3698 IRMapping mapper;

3699 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

3700 rewriter.replaceOpWithNewOptensor::CastOp(padTensorOp, oldResultType,

3701 newOp);

3702

3703 return success();

3704 }

3705};

3706

3707

3708

3709

3710

3711

3712

3713

3714

3715

3716

3717

3718

3719

3720

3721

3722

3723

3724

3725

3726

3727struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {

3728 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

3729

3730 LogicalResult matchAndRewrite(tensor::PadOp padOp,

3731 PatternRewriter &rewriter) const override {

3732 if (padOp.getNofold()) {

3733 return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");

3734 }

3735

3736 auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();

3737 if (!producerPad || producerPad.getNofold()) {

3738 return rewriter.notifyMatchFailure(

3739 padOp, "producer is not a foldable tensor.pad op");

3740 }

3741

3742

3743 Value consumerPadValue = padOp.getConstantPaddingValue();

3744 Value producerPadValue = producerPad.getConstantPaddingValue();

3745 if (!consumerPadValue || !producerPadValue ||

3746 consumerPadValue != producerPadValue) {

3748 padOp,

3749 "cannot fold PadOps with different or non-constant padding values");

3750 }

3751

3752 Location loc = padOp.getLoc();

3753 AffineExpr d0, d1;

3754 bindDims(rewriter.getContext(), d0, d1);

3755

3756

3757 auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,

3758 ArrayRef<OpFoldResult> producerPaddings) {

3759 SmallVector<OpFoldResult> sumPaddings;

3760 for (auto [consumerIndex, producerIndex] :

3761 llvm::zip_equal(consumerPaddings, producerPaddings)) {

3762 sumPaddings.push_back(affine::makeComposedFoldedAffineApply(

3763 rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));

3764 }

3765 return sumPaddings;

3766 };

3767

3768 SmallVector<OpFoldResult> newHighPad =

3769 addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());

3770 SmallVector<OpFoldResult> newLowPad =

3771 addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());

3772

3773 auto newPadOp = tensor::PadOp::create(

3774 rewriter, padOp.getLoc(), padOp.getResultType(),

3775 producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),

3776 getPrunedAttributeList(padOp, PadOp::getAttributeNames()));

3777 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),

3778 newPadOp.getRegion().begin());

3779 rewriter.replaceOp(padOp, newPadOp.getResult());

3780 return success();

3781 }

3782};

3783

3784}

3785

3786LogicalResult

3787PadOp::reifyResultShapes(OpBuilder &b,

3788 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {

3789 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));

3790 SmallVector<OpFoldResult> lp = getMixedLowPad();

3791 SmallVector<OpFoldResult> hp = getMixedHighPad();

3792 for (int64_t i = 0; i < getResultType().getRank(); ++i) {

3793 if (!getType().isDynamicDim(i)) {

3794 reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));

3795 continue;

3796 }

3797 Location loc = getLoc();

3798 Value dim = b.createOrFold<tensor::DimOp>(

3799 loc, getSource(), arith::ConstantIndexOp::create(b, loc, i));

3800

3801 AffineExpr d0, d1, d2;

3802 bindDims(b.getContext(), d0, d1, d2);

3803 reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(

3804 b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});

3805 }

3806 return success();

3807}

3808

3809void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,

3810 MLIRContext *context) {

3811 results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,

3812 FoldOrthogonalPaddings, FoldStaticPadding,

3813 FoldConsecutiveConstantPadding>(context);

3814}

3815

3816

3817

3818

3819

3820

3821

3822

3823

3824

3825Value PadOp::getConstantPaddingValue() {

3826 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());

3827 if (!yieldOp)

3828 return {};

3829 Value padValue = yieldOp.getValue();

3830

3831 if (matchPattern(padValue, m_Constant()))

3832 return padValue;

3833

3834 if (padValue.getParentBlock() == &getRegion().front())

3835 return {};

3836

3837 return padValue;

3838}

3839

3840OpFoldResult PadOp::fold(FoldAdaptor) {

3841 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&

3842 !getNofold())

3843 return getSource();

3844 return {};

3845}

3846

3847

3848

3849

3850

3851OpResult ParallelInsertSliceOp::getTiedOpResult() {

3852 InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();

3853 for (const auto &it :

3854 llvm::enumerate(parallelCombiningParent.getYieldingOps())) {

3855 Operation &nextOp = it.value();

3856 if (&nextOp == getOperation())

3857 return parallelCombiningParent.getParentResult(it.index());

3858 }

3859 llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");

3860}

3861

3862

3863void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,

3864 Value source, Value dest,

3865 ArrayRef<OpFoldResult> offsets,

3866 ArrayRef<OpFoldResult> sizes,

3867 ArrayRef<OpFoldResult> strides,

3868 ArrayRef<NamedAttribute> attrs) {

3869 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;

3870 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;

3871 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

3872 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);

3873 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

3874 result.addAttributes(attrs);

3875 build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,

3876 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),

3877 b.getDenseI64ArrayAttr(staticSizes),

3878 b.getDenseI64ArrayAttr(staticStrides));

3879}

3880

3881

3882

3883void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,

3884 Value source, Value dest,

3885 ArrayRef<Range> ranges,

3886 ArrayRef<NamedAttribute> attrs) {

3887 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);

3888 build(b, result, source, dest, offsets, sizes, strides, attrs);

3889}

3890

3891

3892void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,

3893 Value source, Value dest, ValueRange offsets,

3894 ValueRange sizes, ValueRange strides,

3895 ArrayRef<NamedAttribute> attrs) {

3896 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(

3897 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));

3898 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(

3899 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));

3900 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(

3901 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));

3902 build(b, result, source, dest, offsetValues, sizeValues, strideValues);

3903}

3904

3905

3906

3907void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,

3908 Value dest, ArrayRef<OpFoldResult> sizes,

3909 ArrayRef<NamedAttribute> attrs) {

3910 Attribute zeroIdxAttr = b.getIndexAttr(0);

3911 Attribute oneIdxAttr = b.getIndexAttr(1);

3912 SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr);

3913 SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr);

3914 build(b, result, source, dest, writeOffsets, sizes, writeStrides, attrs);

3915}

3916

3917LogicalResult ParallelInsertSliceOp::verify() {

3918 if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))

3919 return this->emitError("expected InParallelOpInterface parent, got:")

3920 << *(getOperation()->getParentOp());

3921

3922

3923 RankedTensorType expectedType;

3924 SliceVerificationResult result =

3925 verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),

3926 getStaticSizes(), getStaticStrides(), &expectedType);

3927 if (result != SliceVerificationResult::Success)

3928 return produceSliceErrorMsg(result, *this, expectedType);

3929

3930

3931

3932 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(

3933 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),

3934 getStaticStrides(), true);

3935 if (!boundsResult.isValid)

3936 return getOperation()->emitError(boundsResult.errorMessage);

3937

3938 return success();

3939}

3940

3941void ParallelInsertSliceOp::getCanonicalizationPatterns(

3942 RewritePatternSet &results, MLIRContext *context) {

3943 results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,

3944 InsertSliceOpCastFolder<ParallelInsertSliceOp>,

3945 InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);

3946}

3947

3948llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {

3949 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());

3950}

3951

3952

3953MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {

3954 return getDestMutable();

3955}

3956

3957Operation *ParallelInsertSliceOp::getIteratingParent() {

3958

3959 if (auto combiningOp =

3960 dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))

3961 return combiningOp->getParentOp();

3962 return nullptr;

3963}

3964

3965

3966

3967

3968

3969void ScatterOp::getAsmResultNames(

3970 function_ref<void(Value, StringRef)> setNameFn) {

3971 setNameFn(getResult(), "scatter");

3972}

3973

3974LogicalResult ScatterOp::verify() {

3975 int64_t destRank = getDestType().getRank();

3976 ArrayRef<int64_t> scatterDims = getScatterDims();

3977 if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,

3978 getIndicesType().getShape(), destRank,

3979 "scatter", "dest")))

3980 return failure();

3981

3982 if (!getUnique())

3983 return emitOpError("requires 'unique' attribute to be set");

3984

3985

3986

3987

3988

3989

3990 RankedTensorType expectedSourceType = GatherOp::inferResultType(

3991 getDestType(), getIndicesType(), scatterDims, false);

3992 RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(

3993 getDestType(), getIndicesType(), scatterDims, true);

3994 if (getSourceType() != expectedSourceType &&

3995 getSourceType() != expectedRankReducedSourceType) {

3997 "mismatch: "

3998 "expected ")

3999 << expectedSourceType << " or its rank-reduced variant "

4000 << expectedRankReducedSourceType << " (got: " << getSourceType()

4001 << ")";

4002 }

4003

4004 return success();

4005}

4006

4007

4008

4009

4010

4011void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,

4012 Type aggregateType, ValueRange dynamicSizes) {

4013 build(builder, result, aggregateType, element, dynamicSizes);

4014}

4015

4016void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,

4017 ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {

4018 auto aggregateType = RankedTensorType::get(staticShape, element.getType());

4019 build(builder, result, aggregateType, element, dynamicSizes);

4020}
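A hedged usage sketch for the builder above, with one dynamic dimension supplied in order through the dynamicSizes operand (scalar and n are assumed caller values).

  // %splat = tensor.splat %scalar[%n] : tensor<?x4xf32>
  auto splatOp = tensor::SplatOp::create(
      b, loc, scalar, /*staticShape=*/{ShapedType::kDynamic, 4},
      /*dynamicSizes=*/ValueRange{n});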

4021

4022void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,

4023 ArrayRef sizes) {

4024 SmallVector<int64_t> staticShape;

4025 SmallVector dynamicSizes;

4027 build(builder, result, element, staticShape, dynamicSizes);

4028}

4029

4030void SplatOp::getAsmResultNames(

4031 function_ref<void(Value, StringRef)> setNameFn) {

4032 setNameFn(getResult(), "splat");

4033}

4034

4035LogicalResult SplatOp::verify() {

4036 if (getType().getNumDynamicDims() != getDynamicSizes().size())

4037 return emitOpError("incorrect number of dynamic sizes, has ")

4038 << getDynamicSizes().size() << ", expected "

4039 << getType().getNumDynamicDims();

4040 return success();

4041}

4042

4043LogicalResult

4044SplatOp::reifyResultShapes(OpBuilder &builder,

4045 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {

4046 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));

4047 unsigned ctr = 0;

4048 for (int64_t i = 0; i < getType().getRank(); ++i) {

4049 if (getType().isDynamicDim(i)) {

4050 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];

4051 } else {

4052 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));

4053 }

4054 }

4055 return success();

4056}

4057

4058OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {

4059 auto constOperand = adaptor.getInput();

4060 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))

4061 return {};

4062

4063

4064 if (!getType().hasStaticShape())

4065 return {};

4066

4067

4068

4069 return SplatElementsAttr::get(getType(), {constOperand});

4070}

4071

4072

4073

4074

4076

4077

4078

4079 if (isa<InsertSliceOp>(op.getOperation()) ||

4080 isa<LoopLikeOpInterface>(op.getOperation()))

4081 return false;

4082

4083 return true;

4084}

4085

4086

4087

4088

4089

4090

4091

4092

4093

4094

4095

4096

4097

4098

4099

4100

4101

4106

4109

4110

4111

4112 if (!foldTensorCastPrecondition(op) ||

4113 isa<linalg::RelayoutOpInterface>(*op))

4114 return failure();

4115

4119

4120

4121 auto newOp = clone(rewriter, op, newResultTypes, newOperands);

4122

4124 replacements.reserve(newOp->getNumResults());

4125 for (auto [oldResult, newResult] :

4126 llvm::zip(op->getResults(), newOp->getResults())) {

4127 if (newResult.getType() != oldResult.getType()) {

4128 replacements.push_back(tensor::CastOp::create(

4129 rewriter, op->getLoc(), oldResult.getType(), newResult));

4130 } else {

4131 replacements.push_back(newResult);

4132 }

4133 }

4134 rewriter.replaceOp(op, replacements);

4135

4137 }

4138};

4139

4140

4141

4142

4143

4144void TensorDialect::getCanonicalizationPatterns(

4145 RewritePatternSet &results) const {

4146 results.add<FoldTensorCastProducerOp>(getContext());

4147}

4148

4149

4150

4151

4152

4153#define GET_OP_CLASSES

4154#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"

p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")

Given a list of lists of parsed operands, populates uniqueOperands with unique operands.

static Type getElementType(Type type)

Determine the element type of type.

static int64_t getNumElements(Type t)

Compute the total number of elements in the given type, also taking into account nested types.

b

Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...

*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`

static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)

Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...

static TensorType joinShapes(TensorType one, TensorType two)

Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.

Definition TensorOps.cpp:416

static Value foldExtractAfterInsert(ExtractOp extractOp)

If we have an ExtractOp consuming an InsertOp with the same indices, we can return the InsertOp's sca...

Definition TensorOps.cpp:1366

static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, ArrayRef< int64_t > indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)

Definition TensorOps.cpp:1552

static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)

Definition TensorOps.cpp:2428

static bool foldTensorCastPrecondition(DestinationStyleOpInterface op)

Definition TensorOps.cpp:4075

static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)

If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...

Definition TensorOps.cpp:2893

static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)

Definition TensorOps.cpp:2733

static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)

If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...

Definition TensorOps.cpp:2755

static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)

Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.

Definition TensorOps.cpp:2842

static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)

Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...

Definition TensorOps.cpp:180

static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)

Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....

Definition TensorOps.cpp:136

static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)

Folds round-trip extract/insert slice op pairs.

Definition TensorOps.cpp:2913

static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)

Definition TensorOps.cpp:2037

static ArrayRef< int64_t > getShape(Type type)

Returns the shape of the given type.

Base type for affine expression.

Attributes are known-constant values of operations.

ValueTypeRange< BlockArgListType > getArgumentTypes()

Return a range containing the types of the arguments for this block.

unsigned getNumArguments()

Operation * getTerminator()

Get the terminator operation of this block.

BlockArgListType getArguments()

iterator_range< iterator > without_terminator()

Return an iterator range over the operation within this block excluding the terminator operation at t...

This class is a general helper class for creating context-global objects like types,...

IntegerAttr getIndexAttr(int64_t value)

AffineExpr getAffineSymbolExpr(unsigned position)

Ty getType(Args &&...args)

Get or construct an instance of the type Ty with provided arguments.

AffineExpr getAffineDimExpr(unsigned position)

AffineMap getConstantAffineMap(int64_t val)

Returns a single constant result affine map with 0 dimensions and 0 symbols.

MLIRContext * getContext() const

static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)

Construct a dense elements attribute from a raw buffer representing the data for this attribute.

bool isSplat() const

Returns true if this attribute corresponds to a splat, i.e.

ArrayRef< char > getRawData() const

Return the raw storage data held by this attribute.

static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)

Constructs a dense elements attribute from an array of element values.

auto lookupOrDefault(T from) const

Lookup a mapped value within the map.

void map(Value from, Value to)

Inserts a new mapping for 'from' to 'to'.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

Operation * clone(Operation &op, IRMapping &mapper)

Creates a deep copy of the specified operation, remapping any operands that use values outside of the...

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)

Create an operation of specific op type at the current insertion point, and immediately try to fold i...

void setInsertionPointAfter(Operation *op)

Sets the insertion point to the node after the specified operation, which will cause subsequent inser...

This class represents a single result from folding an operation.

This class represents an operand of an operation.

This is a value defined by a result of an operation.

unsigned getResultNumber() const

Returns the number of this result.

Operation is the basic unit of execution within MLIR.

MutableArrayRef< OpOperand > getOpOperands()

InFlightDiagnostic emitError(const Twine &message={})

Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...

result_range getResults()

InFlightDiagnostic emitOpError(const Twine &message={})

Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)

Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements...

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure,...

void modifyOpInPlace(Operation *root, CallableT &&callable)

This method is a utility wrapper around an in-place modification of an operation.

void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)

Move the blocks that belong to "region" before the given position in another region "parent".

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...

bool hasRank() const

Returns if this type is ranked, i.e. it has a known number of dimensions.

Type getElementType() const

Returns the element type of this tensor type.

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

This class provides an abstraction over the different types of ranges over Values.

type_range getType() const

type_range getTypes() const

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

Block * getParentBlock()

Return the Block in which this Value is defined.

Location getLoc() const

Return the location of this value.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.
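
A small sketch of the Value introspection entries above, assuming 'val' is a Value in scope:

if (Operation *def = val.getDefiningOp()) {
  // 'val' is an OpResult of 'def'; for block arguments, getDefiningOp()
  // returns nullptr and this branch is skipped.
  Location loc = val.getLoc();
  Type type = val.getType();
}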

static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)

Speculatability

This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...

constexpr auto Speculatable

constexpr auto NotSpeculatable

OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)

Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)

Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...

Operation::operand_range getIndices(Operation *op)

Get the indices that the given load/store operation is operating on.

DynamicAPInt getIndex(const ConeV &cone)

Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...

Value constantIndex(OpBuilder &builder, Location loc, int64_t i)

Generates a constant of index type.

LogicalResult foldTensorCast(Operation *op)

Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.

Definition TensorOps.cpp:387

bool hasFoldableTensorCastOperand(Operation *op)

Return true if any of the operands of op is a CastOp that can be folded into its consumer,...

Definition TensorOps.cpp:356

void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})

Patterns to fold the extract slice op with its constant operand.

Definition TensorOps.cpp:2694

bool canFoldIntoProducerOp(CastOp castOp)

Determines whether the tensor::CastOp casts to a more static version of the source tensor.

Definition TensorOps.cpp:349

SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)

Assuming that op contains at least one operand that is a foldable CastOp (i.e.

Definition TensorOps.cpp:365

bool canFoldIntoConsumerOp(CastOp castOp)

Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.

Definition TensorOps.cpp:318
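
A sketch of how a rewrite pattern might use this query, assuming 'opOperand' is an OpOperand of the consuming op and 'rewriter' is a PatternRewriter (not code from this file):

if (auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>())
  if (tensor::canFoldIntoConsumerOp(castOp))
    rewriter.modifyOpInPlace(op, [&] { opOperand.set(castOp.getSource()); });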

Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)

Create a rank-reducing InsertSliceOp @[0 .

Definition TensorOps.cpp:3185

Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)

Create a rank-reducing ExtractSliceOp @[0 .

Definition TensorOps.cpp:2780

bool isSameTypeWithoutEncoding(Type tp1, Type tp2)

Tests if types are the same when ignoring encoding on ranked tensors.

Definition TensorOps.cpp:124

OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)

Return the dimension of the given tensor value.

Definition TensorOps.cpp:57

void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns)

Patterns to fold extracts of a collapse_shaped tensor to an extract of the source tensor.

Definition TensorOps.cpp:1437

FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)

This is a helper function for DestinationStyleOpInterface.

Definition TensorOps.cpp:75

bool preservesStaticInformation(Type source, Type target)

Returns true if target is a ranked tensor type that preserves static information available in the sou...

Definition TensorOps.cpp:266

SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)

Return the dimensions of the given tensor value.

Definition TensorOps.cpp:66
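
Typical use of the two size helpers above, assuming 'builder', 'loc', and a tensor-typed 'value' are in scope:

// Static dims come back as index attributes, dynamic dims as tensor.dim values.
SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(builder, loc, value);
OpFoldResult dim0 = tensor::getMixedSize(builder, loc, value, 0);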

LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)

This is a helper function for DestinationStyleOpInterface.

Definition TensorOps.cpp:110
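
A sketch of the destination-materialization flow, assuming 'b', 'loc', and an 'op' whose tensor results need destinations:

SmallVector<Value> destinations;
if (failed(tensor::getOrCreateDestinations(b, loc, op, destinations)))
  return failure(); // a dynamic result shape could not be reified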

std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn

Function to control the folding of constant and extract slice.

Include the generated interface declarations.

bool matchPattern(Value value, const Pattern &pattern)

Entry point for matching a pattern over a Value.

SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)

Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...

detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)

Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...

OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)

Convert int64_t to integer attributes of index type and return them as OpFoldResult.

std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)

Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...

SliceVerificationResult

Enum that captures information related to verifier error conditions on slice insert/extract type of o...

std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)

If ofr is a constant integer or an IntegerAttr, return the integer.
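
Sketch: an OpFoldResult carries either a Value or an Attribute, and this helper collapses both cases into a statically known integer when possible ('ofr' is assumed to be in scope):

if (std::optional<int64_t> cst = getConstantIntValue(ofr)) {
  // *cst is known statically; e.g. a unit-dimension check:
  bool isUnit = (*cst == 1);
}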

LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)

Reify the shape of the result of an operation (typically in terms of the shape of its operands).

LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)

Returns "success" when any of the elements in strides is a constant value.

llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn

Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.

static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)

SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)

Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
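
A sketch of a verifier using this helper together with the isValid/errorMessage fields documented further below (the shape/offset/size/stride arrays are assumed):

SliceBoundsVerificationResult res = verifyInBoundsSlice(
    shape, staticOffsets, staticSizes, staticStrides,
    /*generateErrorMessage=*/true);
if (!res.isValid)
  return op->emitOpError(res.errorMessage);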

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)

Bind a list of AffineExpr references to DimExpr at positions: [0 .

SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)

Given the strides together with a linear index in the dimension space, return the vector-space offset...

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.

SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)

Constructs affine maps out of Array<Array>.

bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)

Helper function to check whether the passed-in sizes or offsets are valid.

bool wouldOpBeTriviallyDead(Operation *op)

Return true if the given operation would be dead if unused, and has no side effects on memory that wo...

SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims

Type getElementTypeOrSelf(Type type)

Return the element type or return the type itself.

const FrozenRewritePatternSet & patterns

bool isZeroInteger(OpFoldResult v)

Return true if v is an IntegerAttr with value 0.

void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)

Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...

SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)

Convert reassociation indices to affine expressions.

bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)

Return true if the reassociation specification is valid, false otherwise.

Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)

Converts an OpFoldResult to a Value.

std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)

Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.

Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)

OpFoldResult getAsOpFoldResult(Value val)

Given a value, try to extract a constant Attribute.
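
The two conversions above are inverses in spirit; a sketch assuming 'b', 'loc', and a Value 'val':

OpFoldResult ofr = getAsOpFoldResult(val);                // an Attribute if 'val' is constant
Value back = getValueOrCreateConstantIndexOp(b, loc, ofr); // materializes a Value again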

llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)

std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)

Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...

LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)

Returns success if the given two shapes are compatible.

detail::constant_op_matcher m_Constant()

Matches a constant foldable operation.

SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)

Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...

ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)

Wraps a list of reassociations in an ArrayAttr.

llvm::function_ref< Fn > function_ref

SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)

LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)

Returns "success" when any of the elements in offsetsOrSizes is a constant value.

std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)

Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
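
Sketch of splitting mixed values back into static/dynamic form, assuming a 'mixedSizes' vector of OpFoldResults (dynamic positions are marked in the static array):

SmallVector<int64_t> staticSizes;
SmallVector<Value> dynamicSizes;
dispatchIndexOpFoldResults(mixedSizes, dynamicSizes, staticSizes);
// Or, equivalently, as a returned pair of arrays:
auto [staticVals, dynamicVals] = decomposeMixedValues(mixedSizes);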

Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....

Definition TensorOps.cpp:4103

LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override

Definition TensorOps.cpp:4107

A canonicalizer wrapper to replace ExtractSliceOps.

Definition TensorOps.cpp:2712

void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)

Definition TensorOps.cpp:2713

Return the canonical type of the result of an extract_slice op.

Definition TensorOps.cpp:2701

RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)

Definition TensorOps.cpp:2702

OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...

This represents an operation in an abstracted form, suitable for use with the builder APIs.

Idiomatic saturated operations on values like offsets, sizes, and strides.

static SaturatedInteger wrap(int64_t v)

FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)

bool isValid

If set to "true", the slice bounds verification was successful.

std::string errorMessage

An error message that can be printed during op verification.