MLIR: lib/Dialect/Tensor/IR/TensorOps.cpp Source File

//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-Exception
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <optional>
#include <vector>

using namespace mlir;
using namespace mlir::tensor;

using llvm::divideCeilSigned;
using llvm::divideFloorSigned;
using llvm::mod;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *TensorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;
  if (complex::ConstantOp::isBuildableWith(value, type))
    return builder.create<complex::ConstantOp>(loc, type,
                                               llvm::cast<ArrayAttr>(value));
  return nullptr;
}

OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
                                  int64_t dim) {
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  if (tensorType.isDynamicDim(dim))
    return builder.createOrFold<tensor::DimOp>(loc, value, dim);

  return builder.getIndexAttr(tensorType.getDimSize(dim));
}

SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
                                                Location loc, Value value) {
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  SmallVector<OpFoldResult> result;
  for (int64_t i = 0; i < tensorType.getRank(); ++i)
    result.push_back(getMixedSize(builder, loc, value, i));
  return result;
}

FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
                                                OpResult opResult) {
  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
  assert(tensorType && "expected tensor type");

  // If the op has a destination, it implements DestinationStyleOpInterface and
  // we can query the destination operand from that interface.
  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
  if (destOp)
    return destOp.getTiedOpOperand(opResult)->get();

  // Otherwise, create a new destination tensor with the same shape.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(opResult.getDefiningOp());

  // Compute sizes.
  SmallVector<OpFoldResult> mixedSizes;
  if (!tensorType.hasStaticShape()) {
    // Dynamic shape: query ReifyRankedShapedTypeOpInterface.
    ReifiedRankedShapedTypeDims reifiedShapes;
    if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
      return failure();
    mixedSizes = reifiedShapes[opResult.getResultNumber()];
  } else {
    // Static shape: take static sizes directly.
    for (int64_t sz : tensorType.getShape())
      mixedSizes.push_back(b.getIndexAttr(sz));
  }

  // Create empty tensor.
  Value emptyTensor =
      b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
  return emptyTensor;
}

LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
                                              Operation *op,
                                              SmallVector<Value> &result) {
  for (OpResult opResult : op->getResults()) {
    if (llvm::isa<TensorType>(opResult.getType())) {
      FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
      if (failed(destination))
        return failure();
      result.push_back(*destination);
    }
  }
  return success();
}

bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
    if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
      return rtp1.getShape() == rtp2.getShape() &&
             rtp1.getElementType() == rtp2.getElementType();
    return false;
  }
  return tp1 == tp2;
}

/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op
/// or rank-extending tensor.insert_slice op.
static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
                                           ArrayRef<OpFoldResult> mixedSizes) {
  llvm::SmallBitVector droppedDims(mixedSizes.size());
  int64_t shapePos = reducedShape.size() - 1;

  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
    size_t idx = mixedSizes.size() - size.index() - 1;
    // Rank-reduced dims must have a static unit dimension.
    bool isStaticUnitSize =
        isa<Attribute>(size.value()) &&
        llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;

    if (shapePos < 0) {
      // There are no more dims in the reduced shape. All remaining sizes must
      // be rank-reduced dims.
      assert(isStaticUnitSize && "expected unit dim");
      droppedDims.set(idx);
      continue;
    }

    // Dim is preserved if the size is not a static 1.
    if (!isStaticUnitSize) {
      --shapePos;
      continue;
    }

    // Dim is preserved if the next dim in reducedShape is a static 1.
    if (reducedShape[shapePos] == 1) {
      --shapePos;
      continue;
    }

    // Otherwise: dim is dropped.
    droppedDims.set(idx);
  }

  assert(shapePos < 0 && "dimension mismatch");
  return droppedDims;
}

/// Given a ranked tensor type and a range of values that defines its dynamic
/// dimension sizes, turn all dynamic sizes that have a constant value into
/// static dimension sizes.
static RankedTensorType
foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
                            SmallVector<Value> &foldedDynamicSizes) {
  SmallVector<int64_t> staticShape(type.getShape());
  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
         "incorrect number of dynamic sizes");

  // Compute new static and dynamic sizes.
  unsigned ctr = 0;
  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
    if (type.isDynamicDim(i)) {
      Value dynamicSize = dynamicSizes[ctr++];
      std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
      if (cst.has_value()) {
        // Dynamic size must be non-negative.
        if (cst.value() < 0) {
          foldedDynamicSizes.push_back(dynamicSize);
          continue;
        }
        staticShape[i] = *cst;
      } else {
        foldedDynamicSizes.push_back(dynamicSize);
      }
    }
  }

  return RankedTensorType::get(staticShape, type.getElementType(),
                               type.getEncoding());
}

//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//

bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = dyn_cast<TensorType>(a);
  auto bT = dyn_cast<TensorType>(b);
  if (!aT || !bT)
    return false;

  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

namespace {

/// Replaces chains of two tensor.bitcast operations by a single
/// tensor.bitcast operation.
struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
  using OpRewritePattern<BitcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
                                PatternRewriter &rewriter) const final {
    auto tensorBitcastOperand =
        tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
    if (!tensorBitcastOperand)
      return failure();

    auto resultType = cast<TensorType>(tensorBitcast.getType());
    rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
                                           tensorBitcastOperand.getOperand());
    return success();
  }
};
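// Illustrative example (assumed tensor dialect syntax; not part of the
// upstream file): the pattern above rewrites
//
//   %a = tensor.bitcast %src : tensor<4xi32> to tensor<4xui32>
//   %b = tensor.bitcast %a : tensor<4xui32> to tensor<4xsi32>
//
// into a single bitcast from the original source:
//
//   %b = tensor.bitcast %src : tensor<4xi32> to tensor<4xsi32>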

255

256 }

257

258 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,

260 results.add(context);

261 }

262

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "cast");
}

/// Returns true if `target` is a ranked tensor type that preserves static
/// information available in the `source` ranked tensor type.
bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
  auto targetType = llvm::dyn_cast<RankedTensorType>(target);

  // Requires RankedTensorType.
  if (!sourceType || !targetType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != targetType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != targetType.getRank())
    return false;

  // Requires same encoding.
  if (sourceType.getEncoding() != targetType.getEncoding())
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
    if (!ShapedType::isDynamic(std::get<0>(t)) &&
        ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

  return true;
}
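// Illustrative examples (assumed semantics, derived from the loop above):
//
//   preservesStaticInformation(tensor<?x?xf32>, tensor<8x16xf32>) == true
//     (the target keeps, and may refine, every dimension of the source)
//   preservesStaticInformation(tensor<8x?xf32>, tensor<?x?xf32>) == false
//     (the static size 8 would be lost by the cast)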

/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of tensor.cast operations. Returns true when the
/// operation is foldable, i.e. when the source of the cast has at least as
/// much static information as the cast's result type.
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  if (!castOp)
    return false;

  // Can fold if the source of cast has at least as much static information as
  // its results.
  return preservesStaticInformation(castOp.getType(),
                                    castOp.getSource().getType());
}

/// Determines whether the tensor::CastOp casts to a more static version of
/// the source tensor. This is useful to fold a tensor.cast into a producing
/// op and implement canonicalization patterns with the tensor.cast op as the
/// root, but with a producer from a different dialect. Returns true when the
/// operation is foldable, i.e. when the cast's result type has at least as
/// much static information as its source type.
bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
  if (!castOp)
    return false;
  return preservesStaticInformation(castOp.getSource().getType(),
                                    castOp.getType());
}

bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) {
  return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
    if (llvm::isa<BlockArgument>(opOperand.get()))
      return false;
    auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
    return castOp && canFoldIntoConsumerOp(castOp);
  });
}

SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding(
    DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
  SmallVector<Value> newOperands;
  newOperands.reserve(op->getNumOperands());

  // Assumes that the result has dpsInits followed by nonDpsInits.
  int64_t dpsInitIdx = 0;
  for (OpOperand &opOperand : op->getOpOperands()) {
    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
    bool fold = canFoldIntoConsumerOp(tensorCastOp);
    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
    if (op.isDpsInit(&opOperand) &&
        !llvm::isa<MemRefType>(newOperands.back().getType()))
      newResTy[dpsInitIdx++] = newOperands.back().getType();
  }
  return newOperands;
}

/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<TensorType>(a);
  auto bT = llvm::dyn_cast<TensorType>(b);
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

/// Compute a TensorType that has the joined shape knowledge of the two given
/// TensorTypes. The element types need to match.
static TensorType joinShapes(TensorType one, TensorType two) {
  assert(one.getElementType() == two.getElementType());

  if (!one.hasRank())
    return two;
  if (!two.hasRank())
    return one;

  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
}
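// Illustrative examples (assumed semantics): the join keeps the most static
// knowledge available from either side, and fails on a conflict:
//
//   joinShapes(tensor<?x8xf32>, tensor<4x?xf32>) == tensor<4x8xf32>
//   joinShapes(tensor<4x8xf32>, tensor<4x9xf32>) == null (conflicting dims)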

namespace {

/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
    auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
    auto resultType = llvm::cast<TensorType>(tensorCast.getType());

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();

    // The newJoin always exists if the above join exists, it might just
    // contain less information. If so, we cannot drop the intermediate cast.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};
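// Illustrative example (assumed tensor dialect syntax): the chain
//
//   %1 = tensor.cast %0 : tensor<4x8xf32> to tensor<?x8xf32>
//   %2 = tensor.cast %1 : tensor<?x8xf32> to tensor<?x?xf32>
//
// folds to the single cast
//
//   %2 = tensor.cast %0 : tensor<4x8xf32> to tensor<?x?xf32>
//
// because dropping the intermediate type loses no shape information.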

/// Fold tensor.cast into its tensor.extract_slice producer when the cast only
/// adds static information.
///
/// Example:
///   %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
///          tensor<128x512xf32> to tensor<?x512xf32>
///   %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
///
/// folds into:
///
///   %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
///          tensor<128x512xf32> to tensor<16x512xf32>
struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto extractOperand =
        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();

    // Cannot fold cast to unranked tensor.
    auto rankedResultType =
        llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
    if (!rankedResultType)
      return failure();

    if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
        rankedResultType.getShape() ==
            llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
                .getShape())
      return failure();

    SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
    auto dimMask = computeRankReductionMask(
        extractOperand.getStaticSizes(), extractOperand.getType().getShape());
    size_t dimIndex = 0;
    for (size_t i = 0, e = sizes.size(); i < e; i++) {
      if (dimMask && dimMask->count(i))
        continue;
      int64_t dim = rankedResultType.getShape()[dimIndex++];
      if (ShapedType::isDynamic(dim))
        continue;
      sizes[i] = rewriter.getIndexAttr(dim);
    }

    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        tensorCast, rankedResultType, extractOperand.getSource(),
        extractOperand.getMixedOffsets(), sizes,
        extractOperand.getMixedStrides());
    return success();
  }
};

} // namespace

void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
}

//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//

RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
  auto tensorTypes =
      llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
        return llvm::cast<RankedTensorType>(type);
      }));
  int64_t concatRank = tensorTypes[0].getRank();

  // The concatenation dim must be in the range [0, rank).
  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");

  SmallVector<int64_t> sizes(concatRank);
  for (int64_t i = 0, e = concatRank; i < e; ++i) {
    if (i == dim)
      continue;
    SaturatedInteger size;
    for (auto tensorType : tensorTypes)
      size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
    sizes[i] = size.asInteger();
  }
  auto concatSize = SaturatedInteger::wrap(0);
  for (auto tensorType : tensorTypes)
    concatSize =
        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
  sizes[dim] = concatSize.asInteger();
  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
}

void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
                     ValueRange inputs) {
  FailureOr<RankedTensorType> resultType =
      inferResultType(dim, inputs.getTypes());
  assert(succeeded(resultType) && "failed to infer concatenation result type");
  build(builder, result, *resultType, dim, inputs);
}

LogicalResult ConcatOp::verify() {
  if (getInputs().size() < 1)
    return emitOpError("requires at least one input");

  SmallVector<RankedTensorType> inputTypes;
  for (auto input : getInputs())
    inputTypes.push_back(cast<RankedTensorType>(input.getType()));

  RankedTensorType resultType = getResultType();
  int64_t resultRank = getRank();
  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
        return type.getRank() != resultRank;
      }))
    return emitOpError("rank of concatenated inputs must match result rank");

  Type resultElementType = resultType.getElementType();
  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
        return type.getElementType() != resultElementType;
      }))
    return emitOpError("inputs and result element type must match");

  int64_t dim = getDim();
  if (dim >= resultRank)
    return emitOpError("concatenation dim must be less than the tensor rank");

  SmallVector<int64_t> sizes(resultRank);
  for (int64_t i = 0, e = resultRank; i < e; ++i) {
    if (i == dim)
      continue;
    SaturatedInteger size;
    for (auto tensorType : inputTypes) {
      FailureOr<SaturatedInteger> maybeSize =
          size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
      if (failed(maybeSize))
        return emitOpError("static concatenation size mismatch along ")
               << "non-concatenated dimension " << i;
      size = *maybeSize;
    }
    sizes[i] = size.asInteger();
  }
  auto concatSize = SaturatedInteger::wrap(0);
  for (auto tensorType : inputTypes)
    concatSize =
        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
  sizes[dim] = concatSize.asInteger();
  auto inferredResultType =
      RankedTensorType::get(sizes, inputTypes[0].getElementType());

  for (auto [inferredSize, actualSize] :
       llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
    bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
                      ShapedType::isDynamic(actualSize);
    if (!hasDynamic && inferredSize != actualSize)
      return emitOpError("result type ")
             << resultType << " does not match inferred shape "
             << inferredResultType << " static sizes";
  }

  return success();
}

FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
  size_t numInputs = getInputs().size();
  uint64_t concatDim = getDim();

  SmallVector<SmallVector<OpFoldResult>> inputShapes;
  inputShapes.reserve(numInputs);
  SmallVector<OpFoldResult> concatOffsets;
  concatOffsets.reserve(numInputs);
  SmallVector<OpFoldResult> outputShape;

  AffineExpr addExpr =
      builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
  OpFoldResult zero = builder.getIndexAttr(0);
  Location loc = getLoc();
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    SmallVector<OpFoldResult> inputShape =
        tensor::getMixedSizes(builder, input.getLoc(), input);
    if (index == 0) {
      outputShape = inputShape;
      concatOffsets.push_back(zero);
    } else {
      concatOffsets.push_back(outputShape[concatDim]);
      outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
          builder, loc, addExpr,
          {outputShape[concatDim], inputShape[concatDim]});
    }
    inputShapes.emplace_back(std::move(inputShape));
  }

  Value replacement = builder.create<tensor::EmptyOp>(
      loc, outputShape, getType().getElementType());

  int64_t rank = getType().getRank();
  OpFoldResult one = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> strides(rank, one);
  SmallVector<OpFoldResult> offsets(rank, zero);
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    offsets[concatDim] = concatOffsets[index];
    auto insertSlice = builder.create<tensor::InsertSliceOp>(
        loc, input, replacement, offsets, inputShapes[index], strides);
    replacement = insertSlice.getResult();
  }
  if (replacement.getType() != getType()) {
    replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
  }
  return SmallVector<Value>{replacement};
}

LogicalResult
ConcatOp::reifyResultShapes(OpBuilder &builder,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  ValueRange inputs = getInputs();
  int64_t dim = getDim();
  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());

  Value init = inputs[0];
  int64_t rank = getType().getRank();

  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));

  // For static dimensions, use the static value. For dynamic dimensions along
  // the non-concatenated axes, prefer an inferred static size; otherwise take
  // the size of the first input.
  for (int64_t i = 0; i < rank; ++i) {
    if (i == dim)
      continue;
    if (!getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    } else if (!inferredResultType.isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
          builder, getLoc(),
          builder.getIndexAttr(inferredResultType.getDimSize(i)));
    } else {
      reifiedReturnShapes[0][i] =
          builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
    }
  }

  if (getType().isDynamicDim(dim)) {
    // Sum up the sizes of the inputs along the concatenated dimension.
    AffineExpr sum = builder.getAffineDimExpr(0);
    SmallVector<OpFoldResult> sizes = {
        builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
    for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
      sum = sum + builder.getAffineDimExpr(idx + 1);
      sizes.push_back(
          builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
    }
    reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
        builder, getLoc(),
        affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
  } else {
    // If the result shape is static along the concatenated dim, use the
    // static shape.
    reifiedReturnShapes[0][dim] =
        builder.getIndexAttr(getType().getDimSize(dim));
  }
  return success();
}

void ConcatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "concat");
}

OpFoldResult ConcatOp::fold(FoldAdaptor) {
  ValueRange inputs = getInputs();
  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
    return inputs[0];
  return {};
}

namespace {

/// Fold a concat with a single input to a cast.
struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    if (concatOp.getInputs().size() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
                                        concatOp.getInputs()[0]);
    return success();
  }
};
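// Illustrative example (assumed tensor dialect syntax):
//
//   %0 = tensor.concat dim(0) %arg0 : (tensor<?x8xf32>) -> tensor<4x8xf32>
//
// becomes
//
//   %0 = tensor.cast %arg0 : tensor<?x8xf32> to tensor<4x8xf32>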

/// Propagate static shapes into the operands of a `tensor.concat`.
///
/// `tensor.concat` requires every operand to match on all dimensions except
/// the concatenation dimension. If one operand is already static in those
/// dimensions, the other operands may be refined with a cast:
///
/// ```mlir
///   %2 = tensor.concat dim(0) %0, %1
///       : (tensor<?x12xi32>, tensor<?x?xi32>) -> tensor<?x12xi32>
/// ```
/// ->
/// ```mlir
///   %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
///   %2 = tensor.concat dim(0) %0, %cast
///       : (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
/// ```
struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

    // Find operands for which a more static shape can be inferred.
    LogicalResult matched = failure();
    // Inferred operand shapes are identical in every dimension except the
    // concatenation dimension.
    SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
    for (auto [operandIdx, operandType] :
         llvm::enumerate(concatOp->getOperandTypes())) {
      // Compute the inferred type for this operand.
      inferredOperandShape[dim] =
          cast<RankedTensorType>(operandType).getDimSize(dim);
      auto inferredOperandType = RankedTensorType::get(
          inferredOperandShape, inferredResultType.getElementType());

      // Check if the inferred type is more static.
      if (!preservesStaticInformation(inferredOperandType, operandType)) {
        matched = success();

        // Use the refined operand type and create a cast from the original.
        auto castOp =
            rewriter.create<CastOp>(concatOp->getLoc(), inferredOperandType,
                                    concatOp.getOperand(operandIdx));
        rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
          concatOp->setOperand(operandIdx, castOp->getResult(0));
        });
      }
    }

    return matched;
  }
};

/// Replace a concat whose result type is less static than what can be
/// inferred from its operand types with a concat of the inferred type,
/// followed by a cast back to the original result type.
struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

    // The result type must be refinable by the inferred type.
    if (preservesStaticInformation(inferredResultType,
                                   concatOp.getResultType())) {
      return failure();
    }

    auto newConcatOp = rewriter.create<ConcatOp>(
        concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands());
    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
                                        newConcatOp);

    return success();
  }
};
} // namespace

void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results
      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
          context);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "dim");
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

std::optional<int64_t> DimOp::getConstantIndex() {
  return getConstantIntValue(getIndex());
}

Speculation::Speculatability DimOp::getSpeculatability() {
  auto constantIndex = getConstantIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
  if (!rankedSourceType)
    return Speculation::NotSpeculatable;

  if (rankedSourceType.getRank() <= constantIndex)
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}

void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
                                          SetIntLatticeFn setResultRange) {
  setResultRange(getResult(),
                 intrange::inferShapedDimOpInterface(*this, argRanges[0]));
}

OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types is not supported.
  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
  if (!tensorType)
    return {};

  // Out of bound indices produce undefined behavior but are still valid IR.
  // Don't choke on them.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= tensorType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!tensorType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
  }

  Operation *definingOp = getSource().getDefiningOp();

  // Fold dim to the operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        llvm::cast<RankedTensorType>(fromElements.getResult().getType());
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

    // Find the operand of the fromElements that corresponds to this index.
    auto dynExtents = fromElements.getDynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (ShapedType::isDynamic(dim))
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    // Fold only for non-rank-reduced ops. For the rank-reduced version, rely
    // on the `resolve-shaped-type-result-dims` pass.
    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(unsignedIndex)) {
      return {sliceOp.getDynamicSize(unsignedIndex)};
    }
  }

  // An in-place fold of a foldable cast operand still yields this op's result.
  if (succeeded(foldTensorCast(*this)))
    return getResult();

  return {};
}

namespace {

/// Fold dim of a cast into the dim of the source of the tensor.cast.
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
    return success();
  }
};
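// Illustrative example (assumed tensor dialect syntax):
//
//   %cast = tensor.cast %src : tensor<4x?xf32> to tensor<?x?xf32>
//   %d = tensor.dim %cast, %c1 : tensor<?x?xf32>
//
// becomes %d = tensor.dim %src, %c1 : tensor<4x?xf32>, since a cast never
// changes the runtime extents.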

/// Fold dim of a destination-passing-style op into the dim of the
/// corresponding init operand.
struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto source = dimOp.getSource();
    auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
    if (!destOp)
      return failure();

    auto resultIndex = cast<OpResult>(source).getResultNumber();
    auto *initOperand = destOp.getDpsInitOperand(resultIndex);

    rewriter.modifyOpInPlace(
        dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
    return success();
  }
};

/// Fold dim of a tensor.reshape into an extract from the reshape's shape
/// operand.
struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return failure();

    // Since tensors are immutable we don't need to worry about where to place
    // the extract call.
    rewriter.setInsertionPointAfter(dim);
    Location loc = dim.getLoc();
    Value extract =
        rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
    if (extract.getType() != dim.getType())
      extract =
          rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
    rewriter.replaceOp(dim, extract);
    return success();
  }
};
} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
}

//===----------------------------------------------------------------------===//
// EmptyOp
//===----------------------------------------------------------------------===//

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<int64_t> staticShape, Type elementType,
                    Attribute encoding) {
  assert(none_of(staticShape, ShapedType::isDynamic) &&
         "expected only static sizes");
  build(builder, result, staticShape, elementType, ValueRange{}, encoding);
}

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<int64_t> staticShape, Type elementType,
                    ValueRange dynamicSizes, Attribute encoding) {
  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
  build(builder, result, tensorType, dynamicSizes);
}

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<OpFoldResult> sizes, Type elementType,
                    Attribute encoding) {
  SmallVector<int64_t> staticShape;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
}

LogicalResult EmptyOp::verify() {
  if (getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitOpError("incorrect number of dynamic sizes, has ")
           << getDynamicSizes().size() << ", expected "
           << getType().getNumDynamicDims();
  return success();
}

LogicalResult
EmptyOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  unsigned ctr = 0;
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
    } else {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    }
  }
  return success();
}

Value EmptyOp::getDynamicSize(unsigned idx) {
  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
  unsigned ctr = 0;
  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
    if (getType().isDynamicDim(i))
      ++ctr;
  return getDynamicSizes()[ctr];
}

SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
  SmallVector<OpFoldResult> result;
  unsigned ctr = 0;
  OpBuilder b(getContext());
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      result.push_back(getDynamicSizes()[ctr++]);
    } else {
      result.push_back(b.getIndexAttr(getType().getShape()[i]));
    }
  }
  return result;
}

namespace {

/// Change the type of the result of a `tensor.empty` by making the result
/// type statically sized along dimensions that in the original operation were
/// defined as dynamic, but the size was defined using a `constant` op. For
/// example:
///
///  %c5 = arith.constant 5 : index
///  %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
///
/// into
///
///  %0 = tensor.empty(%arg0) : tensor<?x5xf32>
struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
  using OpRewritePattern<EmptyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(EmptyOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        op.getType(), op.getDynamicSizes(), foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedTensorType == op.getType())
      return failure();

    auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType,
                                          foldedDynamicSizes);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
    if (!emptyTensorOp || !maybeConstantIndex)
      return failure();
    auto emptyTensorType = emptyTensorOp.getType();
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= emptyTensorType.getRank() ||
        !emptyTensorType.isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(dimOp,
                       emptyTensorOp.getDynamicSize(*maybeConstantIndex));
    return success();
  }
};

/// Canonicalize
///
/// ```mlir
///   %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
/// ```
///
/// into
///
/// ```mlir
///   %0 = tensor.empty(%d1) : tensor<4x?xf32>
/// ```
///
/// This assumes the input program is correct in terms of its shape. So it is
/// safe to assume that `%d0` is in fact 4.
struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp castOp,
                                PatternRewriter &rewriter) const override {
    if (!canFoldIntoProducerOp(castOp))
      return failure();
    auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
    if (!producer)
      return failure();

    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    ArrayRef<int64_t> resultShape = resultType.getShape();
    SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
    SmallVector<OpFoldResult> newMixedSizes;
    newMixedSizes.reserve(currMixedSizes.size());
    assert(resultShape.size() == currMixedSizes.size() &&
           "mismatch in result shape and sizes of empty op");
    for (auto it : llvm::zip(resultShape, currMixedSizes)) {
      int64_t newDim = std::get<0>(it);
      OpFoldResult currDim = std::get<1>(it);
      // Case 1: The empty tensor dim is static. Check that the tensor cast
      // result dim matches.
      if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
        if (ShapedType::isDynamic(newDim) ||
            newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
          // Something is off, the cast result shape cannot be more dynamic
          // than the empty tensor result shape (enforced by
          // `canFoldIntoProducerOp`). Abort for now.
          return rewriter.notifyMatchFailure(
              producer, "mismatch in static value of shape of empty tensor "
                        "result and cast result");
        }
        newMixedSizes.push_back(attr);
        continue;
      }

      // Case 2: The tensor cast shape is static, but the empty tensor result
      // shape is dynamic.
      if (!ShapedType::isDynamic(newDim)) {
        newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
        continue;
      }

      // Case 3: Both the tensor cast shape and the empty tensor result shape
      // are dynamic. Use the dynamic value from the empty op.
      newMixedSizes.push_back(currDim);
    }

    // TODO: Do not drop tensor encoding.
    rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
                                         resultType.getElementType());
    return success();
  }
};

} // namespace

void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
              ReplaceEmptyTensorStaticShapeDims>(context);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

namespace {

/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();
    if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
      return failure();
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.getSource(), extract.getIndices());
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
///        tensor<12xf64>
/// %extracted_element = tensor.extract %val[%c10] : tensor<12xf64>
///
/// to
///
/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const final {
    auto collapseOp =
        extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    if (!collapseOp.getSrcType().hasStaticShape())
      return failure();

    auto sourceSizes = collapseOp.getSrcType().getShape();

    SmallVector<Value> indices(extractOp.getIndices().begin(),
                               extractOp.getIndices().end());
    SmallVector<Value> sourceIndices;
    for (auto [index, group] :
         llvm::zip(indices, collapseOp.getReassociationIndices())) {
      assert(!group.empty() && "association indices groups cannot be empty");
      auto groupSize = group.size();

      if (groupSize == 1) {
        sourceIndices.push_back(index);
        continue;
      }

      SmallVector<int64_t> basis =
          llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
      auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
          extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
      llvm::append_range(sourceIndices, delinearize.getResults());
    }
    if (collapseOp.getReassociationIndices().empty()) {
      auto zeroAffineMap = rewriter.getConstantAffineMap(0);
      int64_t srcRank =
          cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(), zeroAffineMap,
          ArrayRef<OpFoldResult>{});
      for (int64_t i = 0; i < srcRank; i++) {
        sourceIndices.push_back(
            getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extractOp, collapseOp.getSrc(), sourceIndices);
    return success();
  }
};

} // namespace

void ExtractOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted");
}

LogicalResult ExtractOp::verify() {
  // Verify the # indices match if we have a ranked type.
  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices for extract_element");
  return success();
}

/// If we have an ExtractOp consuming an InsertOp with the same indices, we
/// can return the InsertOp's scalar directly.
// TODO: This only checks the immediate producer; extend to go up the
// insert/extract chain if the slices are disjoint.
static Value foldExtractAfterInsert(ExtractOp extractOp) {
  auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();

  auto isSame = [](Value a, Value b) {
    return getAsOpFoldResult(a) == getAsOpFoldResult(b);
  };
  if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
      llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
    return insertOp.getScalar();

  return {};
}

OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  if (Attribute tensor = adaptor.getTensor()) {
    // If this is a splat elements attribute, simply return the value.
    // All of the elements of a splat attribute are the same.
    if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
      return splatTensor.getSplatValue<Attribute>();

    // If this is a dense resource elements attribute, return.
    if (isa<DenseResourceElementsAttr>(tensor))
      return {};
  }

  // Collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indice : adaptor.getIndices()) {
    if (!indice || !llvm::isa<IntegerAttr>(indice))
      return {};
    indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
  }

  // Fold extract(from_elements(...)).
  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
    auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
    auto rank = tensorType.getRank();
    assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
           "rank mismatch");
    int flatIndex = 0;
    int stride = 1;
    for (int i = rank - 1; i >= 0; --i) {
      flatIndex += indices[i] * stride;
      stride *= tensorType.getDimSize(i);
    }
    // Prevent out of bounds accesses. This can happen in invalid code that
    // will never execute.
    if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
        flatIndex < 0)
      return {};
    return fromElementsOp.getElements()[flatIndex];
  }

  // If this is an elements attribute, query the value at the given indices.
  if (Attribute tensor = adaptor.getTensor()) {
    auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
    if (elementsAttr && elementsAttr.isValidIndex(indices))
      return elementsAttr.getValues<Attribute>()[indices];
  }

  if (Value result = foldExtractAfterInsert(*this))
    return result;

  return {};
}
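// Illustrative fold example (assumed tensor dialect syntax; mirrors the
// from_elements case above):
//
//   %t = tensor.from_elements %a, %b, %c, %d : tensor<2x2xf32>
//   %e = tensor.extract %t[%c1, %c0] : tensor<2x2xf32>
//
// The flattened row-major index is 1 * 2 + 0 = 2, so the fold returns %c.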

void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractFromTensorCast>(context);
}

void mlir::tensor::populateFoldCollapseExtractPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ExtractFromCollapseShape>(patterns.getContext());
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

void FromElementsOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "from_elements");
}

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           ValueRange elements) {
  assert(!elements.empty() && "expected at least one element");
  Type resultType = RankedTensorType::get(
      {static_cast<int64_t>(elements.size())}, elements.front().getType());
  build(builder, result, resultType, elements);
}

OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
  if (!llvm::is_contained(adaptor.getElements(), nullptr))
    return DenseElementsAttr::get(getType(), adaptor.getElements());
  return {};
}

namespace {

// Pushes the index_casts that occur before extractions to after the extract.
// This minimizes type conversion in some cases and enables the extract
// canonicalizer. This changes:
//
// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
// %extract = tensor.extract %cast[%index] : tensor<1xindex>
//
// to the following:
//
// %extract = tensor.extract %tensor[%index] : tensor<1xi32>
// %cast = arith.index_cast %extract : i32 to index
//
// to just %element.
//
// Consider expanding this to a template and handling all tensor cast
// operations.
struct ExtractElementFromIndexCast
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    Location loc = extract.getLoc();
    auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
    if (!indexCast)
      return failure();

    Type elementTy = getElementTypeOrSelf(indexCast.getIn());

    auto newExtract = rewriter.create<tensor::ExtractOp>(
        loc, elementTy, indexCast.getIn(), extract.getIndices());

    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
                                                    newExtract);

    return success();
  }
};

} // namespace

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<ExtractElementFromIndexCast>(context);
}

//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}

/// Return the inferred result type for a gatherOp where:
///   - sourceType is the type of the source tensor gathered from
///   - indicesType is the type of the indices used to gather
///   - gatherDims are the dims along which the gather occurs.
/// Return a full rank or rank-reduced variant of the type depending on the
/// value of rankReduced.
///
/// The leading dimensions of the result type are equal to the leading
/// dimensions of the indicesType (minus its trailing coordinate dimension).
/// The trailing dimensions are those of the sourceType, with the gathered
/// dims either omitted (rank-reducing) or replaced by `1` (rank-preserving).
RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
                                           RankedTensorType indicesType,
                                           ArrayRef<int64_t> gatherDims,
                                           bool rankReduced) {
  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
  resultShape.reserve(resultShape.size() + sourceType.getRank());
  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
    if (llvm::binary_search(gatherDims, idx)) {
      if (!rankReduced)
        resultShape.push_back(1);
      continue;
    }
    resultShape.push_back(sourceType.getDimSize(idx));
  }
  return RankedTensorType::Builder(sourceType).setShape(resultShape);
}
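// Illustrative example (assumed semantics, following the loop above): for a
// source tensor<4x5x6xf32>, indices tensor<1x2x1xindex>, and
// gather_dims = [1], the inferred result is
//   - rank-preserving: tensor<1x2x4x1x6xf32>
//   - rank-reduced:    tensor<1x2x4x6xf32>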

static LogicalResult
verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
                          ArrayRef<int64_t> indices, int64_t rank,
                          StringRef gatherOrScatter, StringRef sourceOrDest) {
  if (dims.empty())
    return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";

  int64_t numGatherDims = dims.size();
  if (numGatherDims > rank)
    return op->emitOpError(gatherOrScatter)
           << "_dims overflow " << sourceOrDest << " rank";
  if (indices.empty() || indices.back() != numGatherDims)
    return op->emitOpError(gatherOrScatter)
           << "_dims length must match the size of last dimension of indices";
  for (int64_t val : dims) {
    if (val < 0)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be non-negative";
    if (val >= rank)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be smaller than " << sourceOrDest << " rank";
  }
  for (int64_t i = 1; i < numGatherDims; ++i) {
    if (dims[i - 1] >= dims[i])
      return op->emitOpError(gatherOrScatter)
             << "_dims values must be strictly increasing";
  }
  return success();
}

LogicalResult GatherOp::verify() {
  int64_t sourceRank = getSourceType().getRank();
  ArrayRef<int64_t> gatherDims = getGatherDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
                                       getIndicesType().getShape(), sourceRank,
                                       "gather", "source")))
    return failure();

  RankedTensorType expectedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
  if (getResultType() != expectedResultType &&
      getResultType() != expectedRankReducedResultType) {
    return emitOpError("result type mismatch: expected ")
           << expectedResultType << " or its rank-reduced variant "
           << expectedRankReducedResultType << " (got: " << getResultType()
           << ")";
  }

  return success();
}

OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  return {};
}

//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//

void InsertOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted");
}

LogicalResult InsertOp::verify() {
  // Verify the # indices match if we have a ranked type.
  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices");
  return success();
}

OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
  Attribute scalar = adaptor.getScalar();
  Attribute dest = adaptor.getDest();
  if (scalar && dest)
    if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
      if (scalar == splatDest.getSplatValue<Attribute>())
        return dest;
  return {};
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//

void GenerateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "generated");
}

LogicalResult GenerateOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  int idx = 0;
  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
    if (getType().isDynamicDim(dim)) {
      reifiedReturnShapes[0][dim] = getOperand(idx++);
    } else {
      reifiedReturnShapes[0][dim] =
          builder.getIndexAttr(getType().getDimSize(dim));
    }
  }
  return success();
}

LogicalResult GenerateOp::verify() {
  // Ensure that the tensor type has as many dynamic dimensions as are
  // specified by the operands.
  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
  if (getNumOperands() != resultType.getNumDynamicDims())
    return emitError("must have as many index operands as dynamic extents "
                     "in the result type");
  return success();
}

LogicalResult GenerateOp::verifyRegions() {
  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
  // Ensure that region arguments span the index space.
  if (!llvm::all_of(getBody().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return emitError("all body arguments must be index");
  if (getBody().getNumArguments() != resultTy.getRank())
    return emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());

  if (yieldOp.getValue().getType() != resultTy.getElementType())
    return emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

  return success();
}

void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate the body.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  SmallVector<Location, 2> argumentLocs(rank, result.location);
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}

namespace {

/// Canonicalizes tensor.generate operations with a constant operand into the
/// equivalent operation with the operand expressed in the result type
/// instead. We also insert a type cast to make sure that the resulting IR is
/// still well-typed.
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp generateOp,
                                PatternRewriter &rewriter) const final {
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        generateOp.getType(), generateOp.getDynamicExtents(),
        foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedTensorType == generateOp.getType())
      return failure();

    auto loc = generateOp.getLoc();
    auto newOp =
        rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
    rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
                                newOp.getBody().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
                                                generateOp.getType(), newOp);
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
///   ^bb0(%arg0: index):
///   <computation>
///   yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
    if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
      return failure();

    IRMapping mapping;
    Block *body = &tensorFromElements.getBody().front();
    mapping.map(body->getArguments(), extract.getIndices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());

    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
    return success();
  }
};

} // namespace

void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // TODO: Move extract pattern to tensor::ExtractOp.
  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "rank");
}

OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
  // Constant fold rank when the rank of the operand is known.
  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()),
                            shapedType.getRank());
  return IntegerAttr();
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}

static int64_t getNumElements(ShapedType type) {
  int64_t numElements = 1;
  for (auto dim : type.getShape())
    numElements *= dim;
  return numElements;
}

LogicalResult ReshapeOp::verify() {
  TensorType operandType = llvm::cast<TensorType>(getSource().getType());
  TensorType resultType = llvm::cast<TensorType>(getResult().getType());

  if (operandType.getElementType() != resultType.getElementType())
    return emitOpError("element types of source and destination tensor "
                       "types should be the same");

  int64_t shapeSize =
      llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);

  if (resultRankedType) {
    if (operandRankedType && resultRankedType.hasStaticShape() &&
        operandRankedType.hasStaticShape()) {
      if (getNumElements(operandRankedType) !=
          getNumElements(resultRankedType))
        return emitOpError("source and destination tensor should have the "
                           "same number of elements");
    }
    if (ShapedType::isDynamic(shapeSize))
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked tensor type");
    if (shapeSize != resultRankedType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's tensor rank");
  }
  return success();
}

OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;

  // If the producer of operand 'source' is another 'tensor.reshape' op, use
  // the producer's input instead as the original tensor to reshape. This
  // could render such a producer dead code.
  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
    getSourceMutable().assign(reshapeOpProducer.getSource());
    return getResult();
  }

  auto source = getSource();
  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
  auto resultTy = dyn_cast<RankedTensorType>(getType());
  if (!sourceTy || !resultTy || sourceTy != resultTy)
    return {};

  // If the source and result are both 1D tensors of the same type, the
  // reshape has no effect, even if the tensor is dynamically shaped.
  if (sourceTy.getRank() == 1)
    return source;

  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
    auto elements = fromElements.getElements();
    bool dynamicNoop =
        sourceTy.getRank() == static_cast<int64_t>(elements.size());
    for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
      auto element = elements[id];

      if (auto cst = getConstantIntValue(element)) {
        dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
        continue;
      }

      if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
        dynamicNoop &= dimOp.getSource() == source;

        auto cst = getConstantIntValue(dimOp.getIndex());
        dynamicNoop &=
            cst.has_value() && cst.value() == static_cast<int64_t>(id);
        continue;
      }

      dynamicNoop = false;
      break;
    }

    if (dynamicNoop)
      return source;
  }

  return {};
}
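// Illustrative no-op example (assumed tensor dialect syntax): the dynamic
// case folds when the target shape is rebuilt from the source's own dims:
//
//   %d0 = tensor.dim %src, %c0 : tensor<?x?xf32>
//   %d1 = tensor.dim %src, %c1 : tensor<?x?xf32>
//   %shape = tensor.from_elements %d0, %d1 : tensor<2xindex>
//   %r = tensor.reshape %src(%shape)
//       : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
//
// folds to %src.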

//===----------------------------------------------------------------------===//
// Reassociative reshape ops
//===----------------------------------------------------------------------===//

void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapsed");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expanded");
}

int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
         "invalid resultDim");
  for (const auto &it : llvm::enumerate(getReassociationIndices()))
    if (llvm::is_contained(it.value(), resultDim))
      return it.index();
  llvm_unreachable("could not find reassociation group");
}

FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
                                RankedTensorType expandedType,
                                ArrayRef<ReassociationIndices> reassociation,
                                ArrayRef<OpFoldResult> inputShape) {
  std::optional<SmallVector<OpFoldResult>> outputShape =
      inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
                                  inputShape);
  if (!outputShape)
    return failure();
  return *outputShape;
}

SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
  return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) {
  auto [staticOutputShape, dynamicOutputShape] =
      decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
  build(builder, result, cast<RankedTensorType>(resultType), src,
        getReassociationIndicesAttribute(builder, reassociation),
        dynamicOutputShape, staticOutputShape);
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<OpFoldResult> inputShape =
      getMixedSizes(builder, result.location, src);
  auto tensorResultTy = cast<RankedTensorType>(resultType);
  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
      builder, result.location, tensorResultTy, reassociation, inputShape);
  SmallVector<OpFoldResult> outputShapeOrEmpty;
  if (succeeded(outputShape)) {
    outputShapeOrEmpty = *outputShape;
  }
  build(builder, result, tensorResultTy, src, reassociation,
        outputShapeOrEmpty);
}

SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

RankedTensorType CollapseShapeOp::inferCollapsedType(
    RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
  return inferCollapsedType(
      type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
                type.getContext(), reassociation)));
}

/// Compute the RankedTensorType obtained by applying `reassociation` to
/// `type`.
RankedTensorType
CollapseShapeOp::inferCollapsedType(RankedTensorType type,
                                    ArrayRef<AffineMap> reassociation) {
  auto shape = type.getShape();
  SmallVector<int64_t, 4> newShape;
  newShape.reserve(reassociation.size());

  // Use the fact that reassociation is valid to simplify the logic: only use
  // each map's rank.
  assert(isReassociationValid(reassociation) && "invalid reassociation");
  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    auto band = shape.slice(currentDim, dim);
    int64_t size = 1;
    if (llvm::is_contained(band, ShapedType::kDynamic))
      size = ShapedType::kDynamic;
    else
      for (unsigned d = 0; d < dim; ++d)
        size *= shape[currentDim + d];
    newShape.push_back(size);
    currentDim += dim;
  }

  return RankedTensorType::get(newShape, type.getElementType());
}
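// Illustrative example (assumed semantics): collapsing tensor<2x3x4xf32> with
// reassociation [[0, 1], [2]] yields tensor<6x4xf32>. If any dim in a group
// is dynamic, the collapsed dim is dynamic, e.g. tensor<2x?x4xf32> with the
// same reassociation yields tensor<?x4xf32>.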

void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<NamedAttribute> attrs) {
  auto resultType = inferCollapsedType(
      llvm::cast<RankedTensorType>(src.getType()),
      getSymbolLessAffineMaps(
          convertReassociationIndicesToExprs(b.getContext(), reassociation)));
  result.addAttribute(getReassociationAttrStrName(),
                      getReassociationIndicesAttribute(b, reassociation));
  build(b, result, resultType, src, attrs);
}

template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                        TensorReshapeOp, ExpandShapeOp>::value>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
                                           RankedTensorType expandedType,
                                           RankedTensorType collapsedType) {
  if (failed(
          verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
    return failure();

  auto maps = op.getReassociationMaps();
  RankedTensorType expectedType =
      CollapseShapeOp::inferCollapsedType(expandedType, maps);
  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
    return op.emitOpError("expected collapsed type to be ")
           << expectedType << ", but got " << collapsedType;
  return success();
}

LogicalResult ExpandShapeOp::verify() {
  auto srcType = getSrcType();
  auto resultType = getResultType();

  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape dims to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()
           << " values";

  return verifyTensorReshapeOp(*this, resultType, srcType);
}

LogicalResult CollapseShapeOp::verify() {
  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
}

namespace {

/// Reshape of a splat constant can be replaced with a constant of the result
/// type.
template <typename TensorReshapeOp>
struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    DenseElementsAttr attr;
    if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
      return failure();
    if (!attr || !attr.isSplat())
      return failure();
    DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
        reshapeOp.getResultType(), attr.getRawData());
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
    return success();
  }
};

// Folds TensorReshapeOp(splat x : src_type) : res_type into
// splat x : res_type.
template <typename TensorReshapeOp>
class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
public:
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
    if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<tensor::SplatOp>(
        reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
    return success();
  }
};

/// Reshape of a FromElements can be replaced with a FromElements of the
/// result type.
template <typename TensorReshapeOp>
struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto fromElements =
        reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
    if (!fromElements)
      return failure();

    auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());

    if (!shapedTy.hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
                                                fromElements.getElements());
    return success();
  }
};

// Fold CastOp into CollapseShapeOp when adding static information.
struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
    if (!tensor::canFoldIntoConsumerOp(castOp))
      return failure();

    RankedTensorType srcType =
        llvm::cast<RankedTensorType>(castOp.getSource().getType());
    RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
        srcType, collapseShapeOp.getReassociationMaps());

    if (newResultType == collapseShapeOp.getResultType()) {
      rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
        collapseShapeOp.getSrcMutable().assign(castOp.getSource());
      });
    } else {
      auto newOp = rewriter.create<CollapseShapeOp>(
          collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
          collapseShapeOp.getReassociation());
      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          collapseShapeOp, collapseShapeOp.getResultType(), newOp);
    }
    return success();
  }
};

/// Fold a tensor.cast into a consuming tensor.expand_shape, materializing any
/// output sizes that become statically known.
struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
    SmallVector<ReassociationIndices, 4> reassoc =
        expandOp.getReassociationIndices();

    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
    SmallVector<Value> dynamicOutputShape;
    auto outputIt = expandOp.getOutputShape().begin();

    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
      for (uint64_t outDim : innerReassoc) {
        if (!ShapedType::isDynamic(newOutputShape[outDim]))
          continue;

        // If the cast's src type is dynamic, don't infer any of the
        // corresponding expanded dimensions. `tensor.expand_shape` requires
        // at least one of the expanded dimensions to be dynamic if the input
        // is dynamic.
        Value val = *outputIt;
        ++outputIt;
        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
          dynamicOutputShape.push_back(val);
          continue;
        }

        APInt cst;
        if (matchPattern(val, m_ConstantInt(&cst))) {
          newOutputShape[outDim] = cst.getSExtValue();
        } else {
          dynamicOutputShape.push_back(val);
        }
      }
    }

    // Couldn't match any values, nothing to change.
    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
      return failure();

    // Calculate the input shape from the output.
    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
    for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
      for (auto outDim : reassoc[inDim]) {
        auto ofr = newOutputShape[outDim];
        if (ShapedType::isDynamic(ofr)) {
          newInputShape[inDim] = ShapedType::kDynamic;
          break;
        }
        newInputShape[inDim] *= ofr;
      }
    }

    SmallVector<OpFoldResult> outputOfr =
        getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
    auto inputType = RankedTensorType::get(
        newInputShape, expandOp.getSrcType().getElementType());
    auto outputType = RankedTensorType::get(
        newOutputShape, expandOp.getSrcType().getElementType());
    auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
                                             expandOp.getSrc());
    auto newExpand = rewriter.create<ExpandShapeOp>(
        expandOp.getLoc(), outputType, inputCast.getResult(),
        expandOp.getReassociationIndices(), outputOfr);
    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
                                        newExpand.getResult());
    return success();
  }
};
} // namespace

void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
      ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
      FoldReshapeWithSplat<ExpandShapeOp>,
      FoldReshapeWithFromElements<ExpandShapeOp>>(context);
}

void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
                                tensor::DimOp, RankedTensorType>,
      FoldReshapeWithConstant<CollapseShapeOp>,
      FoldReshapeWithSplat<CollapseShapeOp>,
      FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
      context);
}

OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
}

//===----------------------------------------------------------------------===//
// ExtractSliceOp
//===----------------------------------------------------------------------===//

void ExtractSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted_slice");
}

/// An extract_slice result type can be inferred, when it is not
/// rank-reduced, from the source type and the static representation of
/// offsets, sizes and strides. Special sentinels encode the dynamic case.
RankedTensorType ExtractSliceOp::inferResultType(
    RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
    ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
  // An extract_slice op may specify only a leading subset of offsets/sizes/
  // strides, in which case we complete with offset=0, sizes from the source
  // type, and strides=1.
  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
                               sourceTensorType.getEncoding());
}

RankedTensorType ExtractSliceOp::inferResultType(
    RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
                                         staticSizes, staticStrides);
}

/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
/// number of sizes), drop as many size-1 dimensions as needed to produce an
/// inferred type with the desired rank.
///
/// Note that there may be multiple ways to compute this rank-reduced type:
///   e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
///
/// To disambiguate, this function always drops the first occurrences of
/// size-1 dimensions.
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
    ArrayRef<int64_t> strides) {
  // Type inferred in the absence of rank-reducing behavior.
  auto inferredType = llvm::cast<RankedTensorType>(
      inferResultType(sourceRankedTensorType, offsets, sizes, strides));
  int rankDiff = inferredType.getRank() - desiredResultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    llvm::SmallBitVector dimsToProject =
        getPositionsOfShapeOne(rankDiff, shape);
    SmallVector<int64_t> projectedShape;
    // Best effort rank-reducing: drop 1s in order.
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.test(pos))
        projectedShape.push_back(shape[pos]);
    inferredType =
        RankedTensorType::get(projectedShape, inferredType.getElementType());
  }
  return inferredType;
}
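// Illustrative example (assumed semantics, following the doc comment above):
// with inferred shape 1x6x1 and a desired rank of 2, the first unit dim is
// dropped, producing tensor<6x1xf32> rather than tensor<1x6xf32>.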

RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}

/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
/// result type. If the type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
  // Structuring the implementation this way avoids duplication between
  // builders.
  if (!resultType) {
    resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
        sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
  }
  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}

/// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
/// result type.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}

/// Build an ExtractSliceOp with mixed static and dynamic entries packed into
/// a Range vector.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ArrayRef<Range> ranges,
                           ArrayRef<NamedAttribute> attrs) {
  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}

/// Build an ExtractSliceOp with dynamic entries and custom result type. If
/// the type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
}

/// Build an ExtractSliceOp with dynamic entries and inferred result type.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}

static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
                                          Operation *op,
                                          RankedTensorType expectedType) {
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op->emitError("expected rank to be smaller or equal to ")
           << "the other rank. ";
  case SliceVerificationResult::SizeMismatch:
    return op->emitError("expected type to be ")
           << expectedType << " or a rank-reduced version. (size mismatch) ";
  case SliceVerificationResult::ElemTypeMismatch:
    return op->emitError("expected element type to be ")
           << expectedType.getElementType();
  default:
    llvm_unreachable("unexpected extract_slice op verification result");
  }
}

/// Verifier for ExtractSliceOp.
LogicalResult ExtractSliceOp::verify() {
  RankedTensorType sourceType = getSourceType();

  // Verify result type against inferred type.
  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
      sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());
  SliceVerificationResult result = isRankReducedType(expectedType, getType());
  if (result != SliceVerificationResult::Success)
    return produceSliceErrorMsg(result, *this, expectedType);

  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
  // to the source tensor.
  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}

llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
  return ::getDroppedDims(getType().getShape(), getMixedSizes());
}

FailureOr<Value>
ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
                                   ArrayRef<int64_t> desiredShape) {
  auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
  assert(sourceTensorType && "not a ranked tensor type");
  auto sourceShape = sourceTensorType.getShape();
  if (sourceShape.equals(desiredShape))
    return value;
  auto maybeRankReductionMask =
      mlir::computeRankReductionMask(sourceShape, desiredShape);
  if (!maybeRankReductionMask)
    return failure();
  return createCanonicalRankReducingExtractSliceOp(
      b, loc, value,
      RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
}

LogicalResult ExtractSliceOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1);
  reifiedReturnShapes[0].reserve(getType().getRank());
  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
  llvm::SmallBitVector droppedDims = getDroppedDims();
  for (const auto &size : enumerate(mixedSizes)) {
    if (droppedDims.test(size.index()))
      continue;
    reifiedReturnShapes[0].push_back(size.value());
  }
  return success();
}

namespace {

/// Pattern to rewrite an extract_slice op with tensor::CastOp arguments.
/// This essentially pushes the cast past its consuming slice when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
///   %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
///   %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
///        tensor<3x4xf32>
/// ```
/// is rewritten into:
/// ```
///   %1 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
///        tensor<3x4xf32>
/// ```
class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let the constant folder kick in.
    if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    // Pattern does not apply if the produced op would not verify.
    SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
        cast<ShapedType>(castOp.getSource().getType()).getShape(),
        sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
        sliceOp.getStaticStrides());
    if (!sliceResult.isValid)
      return failure();

    // Create folded extract.
    Location loc = sliceOp.getLoc();
    Value newResult = rewriter.create<ExtractSliceOp>(
        loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
        sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
        sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
    rewriter.replaceOp(sliceOp, newResult);
    return success();
  }
};

/// Slice elements from `values` into `outValues`. `counts` represents the
/// numbers of elements one should advance along each dimension to get to the
/// next element in that dimension.
template <typename IterTy, typename ElemTy>
static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
                          ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                          ArrayRef<int64_t> strides,
                          llvm::SmallVectorImpl<ElemTy> *outValues) {
  assert(offsets.size() == sizes.size());
  assert(offsets.size() == strides.size());
  if (offsets.empty())
    return;

  int64_t offset = offsets.front();
  int64_t size = sizes.front();
  int64_t stride = strides.front();
  if (offsets.size() == 1) {
    for (int64_t i = 0; i < size; ++i, offset += stride)
      outValues->push_back(*(values + offset));

    return;
  }

  for (int64_t i = 0; i < size; ++i, offset += stride) {
    auto begin = values + offset * counts.front();
    sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
                                  offsets.drop_front(), sizes.drop_front(),
                                  strides.drop_front(), outValues);
  }
}
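// Illustrative walk-through (hypothetical values): for a 3x4 source, counts
// is {4, 1}. Slicing with offsets = {1, 0}, sizes = {2, 2}, strides = {1, 2}
// visits rows 1 and 2; within each row it starts at column 0 and steps by 2,
// so the extracted flat positions are {4, 6, 8, 10}.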

2600

2601

2602

2603

2604 class ConstantOpExtractSliceFolder final

2606 public:

2608

2609 ConstantOpExtractSliceFolder(MLIRContext *context,

2612 controlFn(std::move(controlFn)) {}

2613

2614 LogicalResult matchAndRewrite(ExtractSliceOp op,

2618 return failure();

2619

2620

2622 return failure();

2623

2624

2625 auto sourceType = llvm::cast(op.getSource().getType());

2626 auto resultType = llvm::cast(op.getResult().getType());

2627 if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())

2628 return failure();

2629

2630

2631 if (!controlFn(op))

2632 return failure();

2633

2634 int64_t count = sourceType.getNumElements();

2635 if (count == 0)

2636 return failure();

2637

2638

2639 auto offsets = op.getStaticOffsets();

2640 if (llvm::is_contained(offsets, ShapedType::kDynamic))

2641 return failure();

2642 auto sizes = op.getStaticSizes();

2643 if (llvm::is_contained(sizes, ShapedType::kDynamic))

2644 return failure();

2645 auto strides = op.getStaticStrides();

2646 if (llvm::is_contained(strides, ShapedType::kDynamic))

2647 return failure();

2648

2649

2650 ArrayRef<int64_t> shape = sourceType.getShape();

2651 SmallVector<int64_t> counts;

2652 counts.reserve(shape.size());

2653 for (int64_t v : shape) {

2654 count = count / v;

2655 counts.push_back(count);

2656 }

2657

2658

2659 DenseElementsAttr newAttr;

2660

2661 if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {

2662 SmallVector<APInt> outValues;

2663 outValues.reserve(sourceType.getNumElements());

2664 sliceElements<DenseElementsAttr::IntElementIterator, APInt>(

2665 elems.begin(), counts, offsets, sizes, strides, &outValues);

2666 newAttr = DenseElementsAttr::get(resultType, outValues);

2667 } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {

2668 SmallVector<APFloat> outValues;

2669 outValues.reserve(sourceType.getNumElements());

2670 sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(

2671 elems.begin(), counts, offsets, sizes, strides, &outValues);

2672 newAttr = DenseElementsAttr::get(resultType, outValues);

2673 }

2674

2675 if (newAttr) {

2676 rewriter.replaceOpWithNewOparith::ConstantOp(op, resultType, newAttr);

2677 return success();

2678 }

2679

2680 return failure();

2681 }

2682

2683 private:

2684

2685

2686 ControlConstantExtractSliceFusionFn controlFn;

2687 };
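// Illustrative sketch (hypothetical IR, not part of this file): when the
// control function opts in, the pattern above rewrites a slice of a
// non-splat constant into a smaller constant, e.g.
//   %cst = arith.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
//   %0 = tensor.extract_slice %cst[0, 0] [1, 2] [1, 1]
//       : tensor<2x2xi32> to tensor<1x2xi32>
// becomes
//   %0 = arith.constant dense<[[0, 1]]> : tensor<1x2xi32>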

2688

2689 }

2690

2691 void mlir::tensor::populateFoldConstantExtractSlicePatterns(

2692 RewritePatternSet &patterns,

2693 const ControlConstantExtractSliceFusionFn &controlFn) {

2694 patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);

2695 }

2696

2697

2698 struct SliceReturnTypeCanonicalizer {

2699 RankedTensorType operator()(ExtractSliceOp op,

2700 ArrayRef<OpFoldResult> mixedOffsets,

2701 ArrayRef<OpFoldResult> mixedSizes,

2702 ArrayRef<OpFoldResult> mixedStrides) {

2703 return ExtractSliceOp::inferCanonicalRankReducedResultType(

2704 op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,

2705 mixedStrides);

2706 }

2707 };

2708

2709

2710 struct SliceCanonicalizer {

2711 void operator()(PatternRewriter &rewriter, ExtractSliceOp op,

2712 ExtractSliceOp newOp) {

2713 Value replacement = newOp.getResult();

2714 if (replacement.getType() != op.getType())

2715 replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),

2716 replacement);

2717 rewriter.replaceOp(op, replacement);

2718 }

2719 };

2720

2721 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,

2722 MLIRContext *context) {

2723 results.add<

2724 OpWithOffsetSizesAndStridesConstantArgumentFolder<

2725 ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,

2726 ExtractSliceOpCastFolder>(context);

2727 }

2728

2729

2730 static LogicalResult

2731 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,

2732 ShapedType shapedType) {

2733 OpBuilder b(op.getContext());

2734 for (OpFoldResult ofr : op.getMixedOffsets())

2735 if (getConstantIntValue(ofr) != static_cast<int64_t>(0))

2736 return failure();

2737

2738

2739 auto shape = shapedType.getShape();

2740 for (auto it : llvm::zip(op.getMixedSizes(), shape))

2741 if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))

2742 return failure();

2743 for (OpFoldResult ofr : op.getMixedStrides())

2744 if (getConstantIntValue(ofr) != static_cast<int64_t>(1))

2745 return failure();

2746 return success();

2747 }

2748

2749

2750

2751

2752

2754 auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();

2755

2756 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };

2757 if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&

2758 insertOp.isSameAs(extractOp, isSame))

2759 return insertOp.getSource();

2760

2761 return {};

2762 }
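// Illustrative sketch (hypothetical IR, not part of this file): the fold
// above turns a round-trip pair
//   %0 = tensor.insert_slice %src into %dst[0, 1] [2, 2] [1, 1]
//       : tensor<2x2xf32> into tensor<4x4xf32>
//   %1 = tensor.extract_slice %0[0, 1] [2, 2] [1, 1]
//       : tensor<4x4xf32> to tensor<2x2xf32>
// into a direct use of %src, since the extracted slice is exactly the one
// that was just inserted.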

2763

2764 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {

2765 if (OpFoldResult reshapedSource = reshapeConstantSource(

2766 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),

2767 getResult().getType()))

2768 return reshapedSource;

2769 if (getSourceType() == getType() &&

2770 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))

2771 return this->getSource();

2772 if (Value slice = foldExtractAfterInsertSlice(*this))

2773 return slice;

2774

2775 return OpFoldResult();

2776 }

2777

2778 Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(

2779 OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {

2780 auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());

2781 unsigned rank = rankedTensorType.getRank();

2782 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));

2783 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);

2784 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));

2785 return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,

2786 offsets, sizes, strides);

2787 }

2788

2789

2790

2791

2792

2793 void InsertSliceOp::getAsmResultNames(

2795 setNameFn(getResult(), "inserted_slice");

2796 }

2797

2798

2799 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,

2800 Value dest, ArrayRef<OpFoldResult> offsets,

2801 ArrayRef<OpFoldResult> sizes,

2802 ArrayRef<OpFoldResult> strides,

2803 ArrayRef<NamedAttribute> attrs) {

2804 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;

2805 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;

2806 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

2807 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);

2808 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

2809 result.addAttributes(attrs);

2810 build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,

2811 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),

2812 b.getDenseI64ArrayAttr(staticSizes),

2813 b.getDenseI64ArrayAttr(staticStrides));

2814 }

2815

2816

2817

2818 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,

2819 Value dest, ArrayRef<Range> ranges,

2820 ArrayRef<NamedAttribute> attrs) {

2821 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);

2822 build(b, result, source, dest, offsets, sizes, strides, attrs);

2823 }

2824

2825

2826 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,

2827 Value dest, ValueRange offsets, ValueRange sizes,

2828 ValueRange strides, ArrayRef<NamedAttribute> attrs) {

2829 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(

2830 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));

2831 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(

2832 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));

2833 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(

2834 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));

2835 build(b, result, source, dest, offsetValues, sizeValues, strideValues);

2836 }

2837

2838

2839

2840 static SliceVerificationResult verifyInsertSliceOp(

2841 RankedTensorType srcType, RankedTensorType dstType,

2842 ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,

2843 ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {

2844

2845

2846 RankedTensorType expected = ExtractSliceOp::inferResultType(

2847 dstType, staticOffsets, staticSizes, staticStrides);

2848 if (expectedType)

2849 *expectedType = expected;

2850 return isRankReducedType(expected, srcType);

2851 }

2852

2853

2854 LogicalResult InsertSliceOp::verify() {

2855

2856 RankedTensorType expectedType;

2857 SliceVerificationResult result =

2858 verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),

2859 getStaticSizes(), getStaticStrides(), &expectedType);

2860 if (result != SliceVerificationResult::Success)

2861 return produceSliceErrorMsg(result, *this, expectedType);

2862

2863

2864

2865 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(

2866 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),

2867 getStaticStrides(), true);

2868 if (!boundsResult.isValid)

2869 return getOperation()->emitError(boundsResult.errorMessage);

2870

2871 return success();

2872 }

2873

2874

2875

2876

2877

2878

2879

2880

2881

2882

2883

2884

2885

2886

2887

2888

2889

2890

2892 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();

2893

2894 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };

2895 if (!prevInsertOp ||

2896 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||

2897 !prevInsertOp.isSameAs(insertOp, isSame))

2898 return failure();

2899

2900 insertOp.getDestMutable().assign(prevInsertOp.getDest());

2901 return success();

2902 }
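// Illustrative sketch (hypothetical IR, not part of this file): given two
// consecutive writes of the same source type to the same slice
//   %0 = tensor.insert_slice %a into %dst[0, 0] [2, 2] [1, 1]
//       : tensor<2x2xf32> into tensor<4x4xf32>
//   %1 = tensor.insert_slice %b into %0[0, 0] [2, 2] [1, 1]
//       : tensor<2x2xf32> into tensor<4x4xf32>
// the first insertion is dead, so the fold above redirects the second op to
// insert directly into %dst.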

2903

2904

2905

2906

2907

2908

2909

2910

2912 auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();

2913

2914 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };

2915 if (!extractOp || extractOp.getSource() != insertOp.getDest() ||

2916 !extractOp.isSameAs(insertOp, isSame))

2917 return nullptr;

2918

2919 return extractOp.getSource();

2920 }
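// Illustrative sketch (hypothetical IR, not part of this file): re-inserting
// a slice that was just extracted from the same destination
//   %0 = tensor.extract_slice %dst[0, 1] [2, 2] [1, 1]
//       : tensor<4x4xf32> to tensor<2x2xf32>
//   %1 = tensor.insert_slice %0 into %dst[0, 1] [2, 2] [1, 1]
//       : tensor<2x2xf32> into tensor<4x4xf32>
// is a no-op, so the fold above replaces %1 with %dst.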

2921

2922 OpFoldResult InsertSliceOp::fold(FoldAdaptor) {

2923 if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&

2924 getSourceType() == getType() &&

2925 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))

2926 return this->getSource();

2927 if (succeeded(foldInsertAfterInsertSlice(*this)))

2928 return getResult();

2929 if (auto result = foldInsertAfterExtractSlice(*this))

2930 return result;

2931 if (llvm::any_of(getMixedSizes(), isZeroInteger))

2932 return getDest();

2933 return OpFoldResult();

2934 }

2935

2940 return success();

2941 }

2942

2943 namespace {

2944

2945

2946

2947 template <typename InsertOpTy>

2948 class InsertSliceOpConstantArgumentFolder final

2949 : public OpRewritePattern<InsertOpTy> {

2950 public:

2951 using OpRewritePattern<InsertOpTy>::OpRewritePattern;

2952

2953 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,

2954 PatternRewriter &rewriter) const override {

2955 SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());

2956 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());

2957 SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());

2958

2959 // No constant operands were folded, just return.

2960 if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&

2961 failed(foldDynamicOffsetSizeList(mixedSizes)) &&

2962 failed(foldDynamicStrideList(mixedStrides)))

2963 return failure();

2964

2965 // Pattern does not apply if the produced op would not verify.

2966 SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(

2967 insertSliceOp.getDest().getType().getShape(),

2968 mixedOffsets, mixedSizes, mixedStrides);

2969 if (!sliceResult.isValid)

2970 return failure();

2971

2972

2973 auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(

2974 insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),

2975 mixedOffsets, mixedSizes, mixedStrides);

2976 Value toInsert = insertSliceOp.getSource();

2977 if (sourceType != insertSliceOp.getSourceType()) {

2978 OpBuilder::InsertionGuard g(rewriter);

2979

2980

2981

2982 if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)

2983 rewriter.setInsertionPoint(insertSliceOp->getParentOp());

2984 toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),

2985 sourceType, toInsert);

2986 }

2987 rewriter.replaceOpWithNewOp<InsertOpTy>(

2988 insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,

2989 mixedSizes, mixedStrides);

2990 return success();

2991 }

2992 };

2993

2994

2995

2996

2997

2998

2999

3000

3001

3002

3003

3004

3005

3006

3007

3008

3009

3010

3011

3012

3013

3014 template <typename InsertOpTy>

3015 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {

3016 using OpRewritePattern<InsertOpTy>::OpRewritePattern;

3017

3018 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,

3019 PatternRewriter &rewriter) const override {

3020 if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {

3021 return matchPattern(operand, matchConstantIndex());

3022 }))

3023 return failure();

3024

3025 auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {

3026 auto castOp = v.getDefiningOp<tensor::CastOp>();

3027 if (!castOp || !canFoldIntoConsumerOp(castOp))

3028 return std::nullopt;

3029 return castOp.getSource();

3030 };

3031 std::optional<Value> sourceCastSource =

3032 getSourceOfCastOp(insertSliceOp.getSource());

3033 std::optional<Value> destCastSource =

3034 getSourceOfCastOp(insertSliceOp.getDest());

3035 if (!sourceCastSource && !destCastSource)

3036 return failure();

3037

3038 auto src =

3039 (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());

3040 auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());

3041 auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());

3042 auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());

3043 if (!srcType || !dstType)

3044 return failure();

3045

3046

3047

3048

3049 SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());

3050 auto rankReductionMask = computeRankReductionMask(

3051 staticSizes, srcType.getShape(), /*matchDynamic=*/true);

3052 if (!rankReductionMask.has_value())

3053 return failure();

3054

3055

3056

3057

3058

3059 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());

3060 int64_t rankReducedIdx = 0;

3061 for (auto [idx, size] : enumerate(staticSizes)) {

3062 if (!rankReductionMask.value().contains(idx) &&

3063 !srcType.isDynamicDim(rankReducedIdx)) {

3064 mixedSizes[idx] = getAsIndexOpFoldResult(

3065 rewriter.getContext(), srcType.getDimSize(rankReducedIdx));

3066 size = srcType.getDimSize(rankReducedIdx++);

3067 }

3068 }

3069

3070

3071 if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),

3072 staticSizes, insertSliceOp.getStaticStrides()) !=

3073 SliceVerificationResult::Success)

3074 return failure();

3075 SliceBoundsVerificationResult sliceResult =

3076 verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),

3077 mixedSizes, insertSliceOp.getMixedStrides());

3078 if (!sliceResult.isValid)

3079 return failure();

3080

3081 Operation *replacement = rewriter.create<InsertOpTy>(

3082 insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),

3083 mixedSizes, insertSliceOp.getMixedStrides());

3084

3085

3086 bool isParallelInsert =

3087 std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;

3088 if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {

3089 replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),

3090 insertSliceOp.getDestType(),

3091 replacement->getResult(0));

3092 }

3093 rewriter.replaceOp(insertSliceOp, replacement->getResults());

3094 return success();

3095 }

3096 };
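// Illustrative sketch (hypothetical IR, not part of this file): the cast
// folder above absorbs casts on either operand, e.g.
//   %0 = tensor.cast %src : tensor<2x2xf32> to tensor<?x?xf32>
//   %1 = tensor.insert_slice %0 into %dst[0, 0] [2, 2] [1, 1]
//       : tensor<?x?xf32> into tensor<4x4xf32>
// becomes a direct insertion of the uncast, more static source:
//   %1 = tensor.insert_slice %src into %dst[0, 0] [2, 2] [1, 1]
//       : tensor<2x2xf32> into tensor<4x4xf32>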

3097

3098

3099

3100

3101

3102

3103

3104

3105

3106

3107

3108

3109

3110

3111

3112

3113

3114

3115

3116

3117

3118

3119 template <typename InsertOpTy>

3120 struct InsertSliceOpSourceCastInserter final

3121 : public OpRewritePattern<InsertOpTy> {

3122 using OpRewritePattern<InsertOpTy>::OpRewritePattern;

3123

3124 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,

3125 PatternRewriter &rewriter) const override {

3126 RankedTensorType srcType = insertSliceOp.getSourceType();

3127 if (srcType.getRank() != insertSliceOp.getDestType().getRank())

3128 return failure();

3129 SmallVector<int64_t> newSrcShape(srcType.getShape());

3130 for (int64_t i = 0; i < srcType.getRank(); ++i) {

3131 if (std::optional<int64_t> constInt =

3132 getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {

3133

3134 if (*constInt < 0)

3135 return failure();

3136 newSrcShape[i] = *constInt;

3137 }

3138 }

3139 if (!hasValidSizesOffsets(newSrcShape))

3140 return failure();

3141

3142 RankedTensorType newSrcType = RankedTensorType::get(

3143 newSrcShape, srcType.getElementType(), srcType.getEncoding());

3144 if (srcType == newSrcType ||

3145 !preservesStaticInformation(srcType, newSrcType) ||

3146 !tensor::CastOp::areCastCompatible(srcType, newSrcType))

3147 return failure();

3148

3149

3150

3151

3152

3153

3154 OpBuilder::InsertionGuard g(rewriter);

3155

3156

3157

3158 if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)

3159 rewriter.setInsertionPoint(insertSliceOp->getParentOp());

3160 Value cast = rewriter.create<tensor::CastOp>(

3161 insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());

3162 rewriter.replaceOpWithNewOp<InsertOpTy>(

3163 insertSliceOp, cast, insertSliceOp.getDest(),

3164 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),

3165 insertSliceOp.getMixedStrides());

3166 return success();

3167 }

3168 };

3169 }

3170

3171 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {

3172 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());

3173 }

3174

3175 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,

3176 MLIRContext *context) {

3177 results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,

3178 InsertSliceOpCastFolder<InsertSliceOp>,

3179 InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);

3180 }

3181

3182 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,

3183 Location loc,

3184 Value tensor,

3185 Value dest) {

3186 auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());

3187 unsigned rank = rankedTensorType.getRank();

3188 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));

3189 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);

3190 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));

3191 return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,

3192 sizes, strides);

3193 }

3194

3195

3196

3197

3198

3199 void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {

3200 setNameFn(getResult(), "padded");

3201 }

3202

3203

3204

3205 void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,

3206 Type typeToInfer, Type typeToInferFrom) {}

3207

3208 ParseResult

3209 parseInferType(OpAsmParser &parser,

3210 std::optional<OpAsmParser::UnresolvedOperand> optOperand,

3211 Type &typeToInfer, Type typeToInferFrom) {

3212 if (optOperand)

3213 typeToInfer = typeToInferFrom;

3214 return success();

3215 }

3216

3217 LogicalResult PadOp::verify() {

3218 auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());

3219 auto resultType = llvm::cast<RankedTensorType>(getResult().getType());

3220 auto expectedType =

3221 PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());

3222 if (!expectedType) {

3223 return emitError("failed to infer expectedType from sourceType ")

3224 << sourceType << ", specified resultType is " << resultType;

3225 }

3226 if (resultType.getRank() != expectedType.getRank()) {

3227 return emitError("specified type ")

3228 << resultType << " does not match the inferred type "

3229 << expectedType;

3230 }

3231 for (int i = 0, e = sourceType.getRank(); i < e; ++i) {

3232 if (resultType.getDimSize(i) == expectedType.getDimSize(i))

3233 continue;

3234 if (expectedType.isDynamicDim(i))

3235 continue;

3236 return emitError("specified type ")

3237 << resultType << " does not match the inferred type "

3238 << expectedType;

3239 }

3240

3241 return success();

3242 }

3243

3244 LogicalResult PadOp::verifyRegions() {

3245 auto &region = getRegion();

3246 unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();

3247 Block &block = region.front();

3248 if (block.getNumArguments() != rank)

3249 return emitError("expected the block to have ") << rank << " arguments";

3250

3251

3252 for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {

3253 if (!en.value().isIndex())

3254 return emitOpError("expected block argument ")

3255 << (en.index() + 1) << " to be an index";

3256 }

3257

3258

3259 auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());

3260 if (yieldOp.getValue().getType() !=

3261 llvm::cast<ShapedType>(getType()).getElementType())

3262 return emitOpError("expected yield type to match shape element type");

3263

3264 return success();

3265 }

3266

3267 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,

3268 ArrayRef<int64_t> staticLow,

3269 ArrayRef<int64_t> staticHigh,

3270 ArrayRef<int64_t> resultShape) {

3271 unsigned rank = sourceType.getRank();

3272 if (staticLow.size() != rank)

3273 return RankedTensorType();

3274 if (staticHigh.size() != rank)

3275 return RankedTensorType();

3276 if (!resultShape.empty() && resultShape.size() != rank)

3277 return RankedTensorType();

3278

3279 SmallVector<int64_t, 4> inferredShape;

3280 for (auto i : llvm::seq<unsigned>(0, rank)) {

3281 if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||

3282 staticHigh[i] == ShapedType::kDynamic) {

3283 inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic

3284 : resultShape[i]);

3285 } else {

3286 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];

3287 assert((resultShape.empty() || size == resultShape[i] ||

3288 resultShape[i] == ShapedType::kDynamic) &&

3289 "mismatch between inferred shape and result shape");

3290 inferredShape.push_back(size);

3291 }

3292 }

3293

3294 return RankedTensorType::get(inferredShape, sourceType.getElementType());

3295 }
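// Illustrative sketch (hypothetical shapes, not part of this file): for a
// source tensor<4x?xf32> with staticLow = [1, 2] and staticHigh = [2, 3],
// the loop above infers dimension 0 as 4 + 1 + 2 = 7 and keeps dimension 1
// dynamic, so the inferred result type is tensor<7x?xf32>.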

3296

3297 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,

3298 Value source, ArrayRef<int64_t> staticLow,

3299 ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,

3300 bool nofold, ArrayRef<NamedAttribute> attrs) {

3301 auto sourceType = llvm::cast<RankedTensorType>(source.getType());

3302 if (!resultType)

3303 resultType = inferResultType(sourceType, staticLow, staticHigh);

3304 result.addAttributes(attrs);

3305 build(b, result, resultType, source, low, high,

3306 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),

3307 b.getBoolAttr(nofold));

3308 }

3309

3310 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,

3311 Value source, ValueRange low, ValueRange high, bool nofold,

3312 ArrayRef<NamedAttribute> attrs) {

3313 auto sourceType = llvm::cast<RankedTensorType>(source.getType());

3314 unsigned rank = sourceType.getRank();

3315 SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);

3316 build(b, result, resultType, source, staticVector, staticVector, low, high,

3317 nofold, attrs);

3318 }

3319

3324 auto sourceType = llvm::cast<RankedTensorType>(source.getType());

3327

3328

3329

3330

3333 if (!resultType) {

3334 resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);

3335 }

3336 assert(llvm::isa<RankedTensorType>(resultType));

3338 build(b, result, resultType, source, dynamicLow, dynamicHigh,

3341 }

3342

3343 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,

3344 Value source, ArrayRef<OpFoldResult> low,

3345 ArrayRef<OpFoldResult> high, Value constantPadValue,

3346 bool nofold, ArrayRef<NamedAttribute> attrs) {

3347 build(b, result, resultType, source, low, high, nofold, attrs);

3348

3349

3350 Region *region = result.regions[0].get();

3351 int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();

3352 SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());

3353 SmallVector<Location> blockArgLocs(sourceRank, result.location);

3354

3355

3356

3357 OpBuilder::InsertionGuard guard(b);

3358 b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);

3359 b.create<tensor::YieldOp>(result.location, constantPadValue);

3360 }

3361

3362 llvm::SmallBitVector PadOp::getPaddedDims() {

3363 llvm::SmallBitVector paddedDims(getSourceType().getRank());

3364 auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {

3365 for (const auto &en : enumerate(paddingWidths))

3366 if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))

3367 paddedDims.set(en.index());

3368 };

3369 extractPaddedDims(getMixedLowPad());

3370 extractPaddedDims(getMixedHighPad());

3371 return paddedDims;

3372 }

3373

3374 namespace {

3375

3376

3377 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {

3379

3380 LogicalResult matchAndRewrite(PadOp padTensorOp,

3381 PatternRewriter &rewriter) const override {

3382 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())

3383 return failure();

3384 if (padTensorOp.getNofold())

3385 return failure();

3386 rewriter.replaceOpWithNewOp<tensor::CastOp>(

3387 padTensorOp, padTensorOp.getResult().getType(),

3388 padTensorOp.getSource());

3389 return success();

3390 }

3391 };
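// Illustrative sketch (hypothetical IR, not part of this file): a foldable
// pad whose low and high padding are all zero, e.g.
//   %0 = tensor.pad %t low[0, 0] high[0, 0] { ... }
//       : tensor<4x4xf32> to tensor<?x?xf32>
// adds no elements, so the pattern above replaces it with a cast:
//   %0 = tensor.cast %t : tensor<4x4xf32> to tensor<?x?xf32>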

3392

3393

3394 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {

3396

3397 LogicalResult matchAndRewrite(PadOp padTensorOp,

3398 PatternRewriter &rewriter) const override {

3399 auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();

3400 if (!tensor::canFoldIntoConsumerOp(castOp))

3401 return failure();

3402

3403 auto newResultType = PadOp::inferResultType(

3404 llvm::cast<RankedTensorType>(castOp.getSource().getType()),

3405 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),

3406 padTensorOp.getResultType().getShape());

3407

3408 if (newResultType == padTensorOp.getResultType()) {

3409 rewriter.modifyOpInPlace(padTensorOp, [&]() {

3410 padTensorOp.getSourceMutable().assign(castOp.getSource());

3411 });

3412 } else {

3413 auto newOp = rewriter.create<PadOp>(

3414 padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),

3415 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),

3416 padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),

3417 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));

3418 IRMapping mapper;

3419 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

3420

3421 rewriter.replaceOpWithNewOp<tensor::CastOp>(

3422 padTensorOp, padTensorOp.getResultType(), newOp);

3423 }

3424 return success();

3425 }

3426 };

3427

3428

3429

3430 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {

3432

3433 LogicalResult matchAndRewrite(PadOp padTensorOp,

3435 if (!padTensorOp.getResult().hasOneUse())

3436 return failure();

3437 auto tensorCastOp =

3438 dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());

3439 if (!tensorCastOp)

3440 return failure();

3441 if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),

3442 tensorCastOp.getDest().getType()))

3443 return failure();

3444

3445 auto replacementOp = rewriter.create<PadOp>(

3446 padTensorOp.getLoc(), tensorCastOp.getDest().getType(),

3447 padTensorOp.getSource(), padTensorOp.getStaticLow(),

3448 padTensorOp.getStaticHigh(), padTensorOp.getLow(),

3449 padTensorOp.getHigh(), padTensorOp.getNofold(),

3450 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));

3451 replacementOp.getRegion().takeBody(padTensorOp.getRegion());

3452

3453 rewriter.replaceOp(padTensorOp, replacementOp.getResult());

3454 rewriter.replaceOp(tensorCastOp, replacementOp.getResult());

3455 return success();

3456 }

3457 };

3458

3459

3460

3461

3462

3463

3464

3465

3466

3467

3468

3469

3470

3471

3472

3473

3474

3475

3476

3477

3478

3479

3480

3481

3482

3483

3484

3485

3486

3487

3488

3489

3490

3491

3492

3493

3494 struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {

3496

3497 LogicalResult matchAndRewrite(PadOp padOp,

3498 PatternRewriter &rewriter) const override {

3499 auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();

3500 if (!innerSliceOp)

3501 return failure();

3502 auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();

3503 if (!outerPadOp || outerPadOp.getNofold())

3504 return failure();

3505 auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();

3506 if (!outerSliceOp)

3507 return failure();

3508

3509

3510 int64_t rank = padOp.getSourceType().getRank();

3511 if (outerSliceOp.getSourceType().getRank() != rank) {

3513 "cannot fold rank-reducing chain");

3514 }

3515

3516

3517 if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {

3519 padOp, "cannot fold non-unit stride ExtractSliceOps");

3520 }

3521

3522

3523 if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {

3525 "cannot fold PadOps with low padding");

3526 }

3527

3528

3529 Attribute innerAttr, outerAttr;

3530 Value innerValue = padOp.getConstantPaddingValue();

3531 Value outerValue = outerPadOp.getConstantPaddingValue();

3532 if (!innerValue || !outerValue ||

3533 !matchPattern(innerValue, m_Constant(&innerAttr)) ||

3534 !matchPattern(outerValue, m_Constant(&outerAttr)) ||

3535 innerAttr != outerAttr) {

3536 return rewriter.notifyMatchFailure(

3537 padOp, "cannot fold PadOps with different padding values");

3538 }

3539

3540

3541 llvm::SmallBitVector innerDims = padOp.getPaddedDims();

3542 llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();

3543 if (innerDims.anyCommon(outerDims)) {

3545 padOp, "cannot fold PadOps with common padding dimensions");

3546 }

3547

3548

3549

3550

3551

3552

3553 SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));

3554 for (auto en : enumerate(newOffsets)) {

3555 OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];

3556 OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];

3557 if (!innerDims.test(en.index()) &&

3558 (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {

3559 en.value() = outerOffset;

3560 continue;

3561 }

3562 if (!outerDims.test(en.index()) &&

3563 (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {

3564 en.value() = innerOffset;

3565 continue;

3566 }

3568 padOp, "cannot find zero-offset and zero-padding pair");

3569 }

3570

3571

3572

3573

3574

3575

3576 SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();

3577 for (auto en : enumerate(newSizes)) {

3578 if (!outerDims.test(en.index()))

3579 continue;

3580 OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];

3581 int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];

3582 assert(!ShapedType::isDynamic(sourceSize) &&

3583 "expected padded dimension to have a static size");

3586 padOp, "cannot fold since the inner ExtractSliceOp size does not "

3587 "match the size of the outer padding");

3588 }

3589 en.value() = outerSliceOp.getMixedSizes()[en.index()];

3590 }

3591

3592

3593 SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));

3594 for (auto en : enumerate(newHighPad)) {

3595 if (innerDims.test(en.index()))

3596 newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];

3597 if (outerDims.test(en.index()))

3598 newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];

3599 }

3600

3601

3602

3603 auto newSliceOp = rewriter.create<ExtractSliceOp>(

3604 padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,

3605 innerSliceOp.getMixedStrides());

3606 auto newPadOp = rewriter.create<PadOp>(

3607 padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),

3608 padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),

3609 getPrunedAttributeList(padOp, PadOp::getAttributeNames()));

3610 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),

3611 newPadOp.getRegion().begin());

3612 rewriter.replaceOp(padOp, newPadOp.getResult());

3613 return success();

3614 }

3615 };

3616

3617 struct FoldStaticPadding : public OpRewritePattern<PadOp> {

3618 using OpRewritePattern<PadOp>::OpRewritePattern;

3619

3620 LogicalResult matchAndRewrite(PadOp padTensorOp,

3621 PatternRewriter &rewriter) const override {

3622 Value input = padTensorOp.getSource();

3623 if (!llvm::isa<RankedTensorType>(input.getType()))

3624 return failure();

3625 auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();

3626 auto inputRank = inputDims.size();

3627

3628 auto oldResultType =

3629 dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());

3630 if (!oldResultType)

3631 return failure();

3632

3633 auto outputDims = oldResultType.getShape();

3634

3635

3636 SmallVector<int64_t> constOperandsLow;

3637 SmallVector<Value> newLows;

3638 for (auto operand : padTensorOp.getLow()) {

3639 APSInt intOp;

3640 if (!matchPattern(operand, m_ConstantInt(&intOp))) {

3641 constOperandsLow.push_back(ShapedType::kDynamic);

3642 newLows.push_back(operand);

3643 continue;

3644 }

3645 constOperandsLow.push_back(intOp.getExtValue());

3646 }

3647 SmallVector<int64_t> constOperandsHigh;

3648 SmallVector<Value> newHighs;

3649 for (auto operand : padTensorOp.getHigh()) {

3650 APSInt intOp;

3651 if (!matchPattern(operand, m_ConstantInt(&intOp))) {

3652 constOperandsHigh.push_back(ShapedType::kDynamic);

3653 newHighs.push_back(operand);

3654 continue;

3655 }

3656 constOperandsHigh.push_back(intOp.getExtValue());

3657 }

3658

3659 SmallVector<int64_t> constLow(padTensorOp.getStaticLow());

3660 SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());

3661

3662

3663 if (inputDims.size() != outputDims.size() ||

3664 inputDims.size() != constLow.size() ||

3665 inputDims.size() != constHigh.size())

3666 return failure();

3667

3668 auto lowCount = 0;

3669 auto highCount = 0;

3670 for (size_t i = 0; i < inputRank; i++) {

3671 if (constLow[i] == ShapedType::kDynamic)

3672 constLow[i] = constOperandsLow[lowCount++];

3673 if (constHigh[i] == ShapedType::kDynamic)

3674 constHigh[i] = constOperandsHigh[highCount++];

3675 }

3676

3677 auto staticLow = ArrayRef<int64_t>(constLow);

3678 auto staticHigh = ArrayRef<int64_t>(constHigh);

3679

3680

3681 SmallVector<int64_t> newOutDims;

3682 for (size_t i = 0; i < inputRank; i++) {

3683 if (outputDims[i] == ShapedType::kDynamic) {

3684 newOutDims.push_back(

3685 (staticLow[i] == ShapedType::kDynamic ||

3686 staticHigh[i] == ShapedType::kDynamic ||

3687 inputDims[i] == ShapedType::kDynamic

3688 ? ShapedType::kDynamic

3689 : inputDims[i] + staticLow[i] + staticHigh[i]));

3690 } else {

3691 newOutDims.push_back(outputDims[i]);

3692 }

3693 }

3694

3695 if (SmallVector<int64_t>(outputDims) == newOutDims ||

3696 llvm::all_of(newOutDims,

3697 [&](int64_t x) { return x == ShapedType::kDynamic; }))

3698 return failure();

3699

3700

3701 auto newResultType = RankedTensorType::get(

3702 newOutDims, padTensorOp.getType().getElementType());

3703 auto newOp = rewriter.create<PadOp>(

3704 padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,

3705 newLows, newHighs, padTensorOp.getNofold(),

3706 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));

3707

3708 IRMapping mapper;

3709 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

3710 rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,

3711 newOp);

3712

3713 return success();

3714 }

3715 };

3716

3717

3718

3719

3720

3721

3722

3723

3724

3725

3726

3727

3728

3729

3730

3731

3732

3733

3734

3735

3736

3737 struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {

3738 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

3739

3740 LogicalResult matchAndRewrite(tensor::PadOp padOp,

3741 PatternRewriter &rewriter) const override {

3742 if (padOp.getNofold()) {

3743 return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");

3744 }

3745

3746 auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();

3747 if (!producerPad || producerPad.getNofold()) {

3749 padOp, "producer is not a foldable tensor.pad op");

3750 }

3751

3752

3753 Value consumerPadValue = padOp.getConstantPaddingValue();

3754 Value producerPadValue = producerPad.getConstantPaddingValue();

3755 if (!consumerPadValue || !producerPadValue ||

3756 consumerPadValue != producerPadValue) {

3757 return rewriter.notifyMatchFailure(

3758 padOp,

3759 "cannot fold PadOps with different or non-constant padding values");

3760 }

3761

3762 Location loc = padOp.getLoc();

3763 AffineExpr d0, d1;

3764 bindDims(rewriter.getContext(), d0, d1);

3765

3766

3767 auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,

3768 ArrayRef<OpFoldResult> producerPaddings) {

3769 SmallVector<OpFoldResult> sumPaddings;

3770 for (auto [consumerIndex, producerIndex] :

3771 llvm::zip_equal(consumerPaddings, producerPaddings)) {

3772 sumPaddings.push_back(affine::makeComposedFoldedAffineApply(

3773 rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));

3774 }

3775 return sumPaddings;

3776 };

3777

3778 SmallVector<OpFoldResult> newHighPad =

3779 addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());

3780 SmallVector<OpFoldResult> newLowPad =

3781 addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());

3782

3783 auto newPadOp = rewriter.create<tensor::PadOp>(

3784 padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),

3785 newLowPad, newHighPad, padOp.getNofold(),

3786 getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));

3787 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),

3788 newPadOp.getRegion().begin());

3789 rewriter.replaceOp(padOp, newPadOp.getResult());

3790 return success();

3791 }

3792 };
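// Illustrative sketch (hypothetical IR, not part of this file): two stacked
// pads with the same constant padding value, e.g.
//   %0 = tensor.pad %t low[1, 0] high[0, 1] { tensor.yield %c0 } ...
//   %1 = tensor.pad %0 low[0, 2] high[3, 0] { tensor.yield %c0 } ...
// are merged by the pattern above into a single tensor.pad whose paddings
// are the per-dimension sums: low[1, 2] and high[3, 1].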

3793

3794 }

3795

3796 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,

3797 MLIRContext *context) {

3798 results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,

3799 FoldOrthogonalPaddings, FoldStaticPadding,

3800 FoldConsecutiveConstantPadding>(context);

3801 }

3802

3803

3804

3805

3806

3807

3808

3809

3810

3811

3812 Value PadOp::getConstantPaddingValue() {

3813 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());

3814 if (!yieldOp)

3815 return {};

3816 Value padValue = yieldOp.getValue();

3817

3818 if (matchPattern(padValue, m_Constant()))

3819 return padValue;

3820

3821 if (padValue.getParentBlock() == &getRegion().front())

3822 return {};

3823

3824 return padValue;

3825 }

3826

3827 OpFoldResult PadOp::fold(FoldAdaptor) {

3828 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&

3829 !getNofold())

3830 return getSource();

3831 return {};

3832 }

3833

3834

3835

3836

3837

3838 OpResult ParallelInsertSliceOp::getTiedOpResult() {

3839 ParallelCombiningOpInterface parallelCombiningParent =

3840 getParallelCombiningParent();

3841 for (const auto &it :

3842 llvm::enumerate(parallelCombiningParent.getYieldingOps())) {

3843 Operation &nextOp = it.value();

3844 if (&nextOp == getOperation())

3845 return parallelCombiningParent.getParentResult(it.index());

3846 }

3847 llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");

3848 }

3849

3850

3863 build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,

3867 }

3868

3869

3870

3876 build(b, result, source, dest, offsets, sizes, strides, attrs);

3877 }

3878

3879

3880 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,

3881 Value source, Value dest, ValueRange offsets,

3882 ValueRange sizes, ValueRange strides,

3883 ArrayRef<NamedAttribute> attrs) {

3884 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(

3885 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));

3886 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(

3887 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));

3888 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(

3889 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));

3890 build(b, result, source, dest, offsetValues, sizeValues, strideValues);

3891 }

3892

3894 if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))

3895 return this->emitError("expected ParallelCombiningOpInterface parent, got:")

3896 << *(getOperation()->getParentOp());

3897

3898

3899 RankedTensorType expectedType;

3900 SliceVerificationResult result =

3901 verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),

3902 getStaticSizes(), getStaticStrides(), &expectedType);

3903 if (result != SliceVerificationResult::Success)

3904 return produceSliceErrorMsg(result, *this, expectedType);

3905

3906

3907

3908 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(

3909 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),

3910 getStaticStrides(), true);

3911 if (!boundsResult.isValid)

3912 return getOperation()->emitError(boundsResult.errorMessage);

3913

3914 return success();

3915 }

3916

3917 void ParallelInsertSliceOp::getCanonicalizationPatterns(

3918 RewritePatternSet &results, MLIRContext *context) {

3919 results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,

3920 InsertSliceOpCastFolder<ParallelInsertSliceOp>,

3921 InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);

3922 }

3923

3924 llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {

3925 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());

3926 }

3927

3928

3929

3930

3931

3932 void ScatterOp::getAsmResultNames(

3934 setNameFn(getResult(), "scatter");

3935 }

3936

3938 int64_t destRank = getDestType().getRank();

3939 ArrayRef<int64_t> scatterDims = getScatterDims();

3940 if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,

3941 getIndicesType().getShape(), destRank,

3942 "scatter", "dest")))

3943 return failure();

3944

3945 if (!getUnique())

3946 return emitOpError("requires 'unique' attribute to be set");

3947

3948

3949

3950

3951

3952

3953 RankedTensorType expectedSourceType = GatherOp::inferResultType(

3954 getDestType(), getIndicesType(), scatterDims, false);

3955 RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(

3956 getDestType(), getIndicesType(), scatterDims, true);

3957 if (getSourceType() != expectedSourceType &&

3958 getSourceType() != expectedRankReducedSourceType) {

3959 return emitOpError("source type "

3960 "mismatch: "

3961 "expected ")

3962 << expectedSourceType << " or its rank-reduced variant "

3963 << expectedRankReducedSourceType << " (got: " << getSourceType()

3964 << ")";

3965 }

3966

3967 return success();

3968 }

3969

3970

3971

3972

3973

3974 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,

3975 Type aggregateType, ValueRange dynamicSizes) {

3976 build(builder, result, aggregateType, element, dynamicSizes);

3977 }

3978

3979 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,

3980 ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {

3981 auto aggregateType = RankedTensorType::get(staticShape, element.getType());

3982 build(builder, result, aggregateType, element, dynamicSizes);

3983 }

3984

3985 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,

3986 ArrayRef<OpFoldResult> sizes) {

3987 SmallVector<int64_t> staticShape;

3988 SmallVector<Value> dynamicSizes;

3989 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);

3990 build(builder, result, element, staticShape, dynamicSizes);

3991 }

3992

3993 void SplatOp::getAsmResultNames(

3995 setNameFn(getResult(), "splat");

3996 }

3997

4000 return emitOpError("incorrect number of dynamic sizes, has ")

4002 << getType().getNumDynamicDims();

4003 return success();

4004 }

4005

4006 LogicalResult

4007 SplatOp::reifyResultShapes(OpBuilder &builder,

4008 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {

4009 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));

4010 unsigned ctr = 0;

4011 for (int64_t i = 0; i < getType().getRank(); ++i) {

4012 if (getType().isDynamicDim(i)) {

4013 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];

4014 } else {

4015 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));

4016 }

4017 }

4018 return success();

4019 }

4020

4021 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {

4022 auto constOperand = adaptor.getInput();

4023 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))

4024 return {};

4025

4026

4027 if (!getType().hasStaticShape())

4028 return {};

4029

4030

4031

4032 return SplatElementsAttr::get(getType(), {constOperand});

4033 }
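// Illustrative sketch (hypothetical IR, not part of this file): with a
// constant input and a statically shaped result, the fold above turns
//   %c = arith.constant 1.0 : f32
//   %0 = tensor.splat %c : tensor<2x3xf32>
// into a splat elements attribute, i.e. the equivalent of
//   %0 = arith.constant dense<1.0> : tensor<2x3xf32>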

4034

4035

4036

4037

4039

4040

4041

4042 if (isa(op.getOperation()) ||

4043 isa(op.getOperation()))

4044 return false;

4045

4047 }

4048

4049

4050

4051

4052

4053

4054

4055

4056

4057

4058

4059

4060

4061

4062

4063

4064

4069

4070 LogicalResult matchAndRewrite(DestinationStyleOpInterface op,

4071 PatternRewriter &rewriter) const override {

4072

4073

4074

4075 if (!foldTensorCastPrecondition(op) ||

4076 isa<linalg::RelayoutOpInterface>(*op))

4077 return failure();

4078

4079 SmallVector<Type> newResultTypes(op->getResultTypes());

4080 SmallVector<Value> newOperands =

4081 getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);

4082

4083

4084 auto newOp = clone(rewriter, op, newResultTypes, newOperands);

4085

4086 SmallVector<Value> replacements;

4087 replacements.reserve(newOp->getNumResults());

4088 for (auto [oldResult, newResult] :

4089 llvm::zip(op->getResults(), newOp->getResults())) {

4090 if (newResult.getType() != oldResult.getType()) {

4091 replacements.push_back(rewriter.create<tensor::CastOp>(

4092 op->getLoc(), oldResult.getType(), newResult));

4093 } else {

4094 replacements.push_back(newResult);

4095 }

4096 }

4097 rewriter.replaceOp(op, replacements);

4098

4099 return success();

4100 }

4101 };

4102

4103

4104

4105

4106

4107 void TensorDialect::getCanonicalizationPatterns(

4108 RewritePatternSet &results) const {

4109 results.add<FoldTensorCastProducerOp>(getContext());

4110 }

4111

4112

4113

4114

4115

4116 #define GET_OP_CLASSES

4117 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"

static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)

A utility function used to materialize a constant for a given attribute and type.

static MLIRContext * getContext(OpFoldResult val)

static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)

Walks the given type hierarchy with the given indices, potentially down to component granularity,...

static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)

Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...

static TensorType joinShapes(TensorType one, TensorType two)

Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.

static Value foldExtractAfterInsert(ExtractOp extractOp)

If we have an ExtractOp consuming an InsertOp with the same indices, we can return the InsertOp's sca...

static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, ArrayRef< int64_t > indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)

static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)

ParseResult parseInferType(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > optOperand, Type &typeToInfer, Type typeToInferFrom)

static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)

If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...

static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)

static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)

If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...

static int64_t getNumElements(ShapedType type)

static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)

Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.

static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)

Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...

void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand, Type typeToInfer, Type typeToInferFrom)

static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)

Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....

static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)

Folds round-trip extract/insert slice op pairs.

static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)

bool foldTensorCastPrecondition(DestinationStyleOpInterface op)

static ArrayRef< int64_t > getShape(Type type)

Returns the shape of the given type.

Base type for affine expression.

A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.

Attributes are known-constant values of operations.

Block represents an ordered list of Operations.

ValueTypeRange< BlockArgListType > getArgumentTypes()

Return a range containing the types of the arguments for this block.

unsigned getNumArguments()

Operation * getTerminator()

Get the terminator operation of this block.

BlockArgListType getArguments()

iterator_range< iterator > without_terminator()

Return an iterator range over the operation within this block excluding the terminator operation at t...

This class is a general helper class for creating context-global objects like types,...

IntegerAttr getIndexAttr(int64_t value)

DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)

AffineExpr getAffineSymbolExpr(unsigned position)

AffineExpr getAffineDimExpr(unsigned position)

AffineMap getConstantAffineMap(int64_t val)

Returns a single constant result affine map with 0 dimensions and 0 symbols.

MLIRContext * getContext() const

An attribute that represents a reference to a dense vector or tensor object.

static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)

Construct a dense elements attribute from a raw buffer representing the data for this attribute.

bool isSplat() const

Returns true if this attribute corresponds to a splat, i.e.

ArrayRef< char > getRawData() const

Return the raw storage data held by this attribute.

static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)

Constructs a dense elements attribute from an array of element values.

This is a utility class for mapping one set of IR entities to another.

auto lookupOrDefault(T from) const

Lookup a mapped value within the map.

void map(Value from, Value to)

Inserts a new mapping for 'from' to 'to'.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...

This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

Operation * clone(Operation &op, IRMapping &mapper)

Creates a deep copy of the specified operation, remapping any operands that use values outside of the...

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)

Add new block with 'argTypes' arguments and set the insertion point to the end of it.

void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)

Create an operation of specific op type at the current insertion point, and immediately try to fold i...

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

void setInsertionPointAfter(Operation *op)

Sets the insertion point to the node after the specified operation, which will cause subsequent inser...

This class represents a single result from folding an operation.

This class represents an operand of an operation.

This is a value defined by a result of an operation.

unsigned getResultNumber() const

Returns the number of this result.

Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.

Operation is the basic unit of execution within MLIR.

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

InFlightDiagnostic emitError(const Twine &message={})

Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...

Region & getRegion(unsigned index)

Returns the region held by this operation at position 'index'.

MutableArrayRef< OpOperand > getOpOperands()

result_range getResults()

InFlightDiagnostic emitOpError(const Twine &message={})

Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

This is a builder type that keeps local references to arguments.

Builder & setShape(ArrayRef< int64_t > newShape)

This class contains a list of basic blocks and a link to the parent operation it is attached to.

void takeBody(Region &other)

Takes body of another region (that region will have no body after this operation completes).

RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)

Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure,...

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements...

void modifyOpInPlace(Operation *root, CallableT &&callable)

This method is a utility wrapper around an in-place modification of an operation.

void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)

Move the blocks that belong to "region" before the given position in another region "parent".

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...

bool hasRank() const

Returns if this type is ranked, i.e. it has a known number of dimensions.

Type getElementType() const

Returns the element type of this tensor type.

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

This class provides an abstraction over the different types of ranges over Values.

type_range getType() const

type_range getTypes() const

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

Block * getParentBlock()

Return the Block in which this Value is defined.

Location getLoc() const

Return the location of this value.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

Speculatability

This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...

constexpr auto Speculatable

constexpr auto NotSpeculatable

OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)

Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...

SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)

Return the dimensions of the given memref value.

Operation::operand_range getIndices(Operation *op)

Get the indices that the given load/store operation is operating on.

DynamicAPInt getIndex(const ConeV &cone)

Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...

Value constantIndex(OpBuilder &builder, Location loc, int64_t i)

Generates a constant of index type.

LogicalResult foldTensorCast(Operation *op)

Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.

bool hasFoldableTensorCastOperand(Operation *op)

Return true if any of the operands of op is a CastOp that can be folded into its consumer,...

void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})

Patterns to fold the extract slice op with its constant operand.

bool canFoldIntoProducerOp(CastOp castOp)

Determines whether the tensor::CastOp casts to a more static version of the source tensor.

SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)

Assuming that op contains at least one operand that is a foldable CastOp (i.e.

bool canFoldIntoConsumerOp(CastOp castOp)

Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.

Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)

Create a rank-reducing InsertSliceOp @[0 .

Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)

Create a rank-reducing ExtractSliceOp @[0 .

bool isSameTypeWithoutEncoding(Type tp1, Type tp2)

Tests if types are the same when ignoring encoding on ranked tensors.

OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)

Return the dimension of the given tensor value.

void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns)

Patterns to fold extracts of a collapse_shaped tensor to an extract of the source tensor.

FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)

This is a helper function for DestinationStyleOpInterface.

std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn

Function to control the folding of constant and extract slice.

bool preservesStaticInformation(Type source, Type target)

Returns true if target is a ranked tensor type that preserves static information available in the sou...

SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)

Return the dimensions of the given tensor value.

LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)

This is a helper function for DestinationStyleOpInterface.

Include the generated interface declarations.

bool matchPattern(Value value, const Pattern &pattern)

Entry point for matching a pattern over a Value.

detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)

Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...

OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)

Convert int64_t to integer attributes of index type and return them as OpFoldResult.

std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)

Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...

SliceVerificationResult

Enum that captures information related to verifier error conditions on slice insert/extract type of o...

std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)

If ofr is a constant integer or an IntegerAttr, return the integer.

LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)

Reify the shape of the result of an operation (typically in terms of the shape of its operands).

LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)

Returns "success" when any of the elements in strides is a constant value.

SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)

Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)

Bind a list of AffineExpr references to DimExpr at positions: [0 .

SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)

Given the strides together with a linear index in the dimension space, return the vector-space offset...

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.

SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)

Constructs affine maps out of Array<Array>.

static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, T collapsedType, bool isExpansion)

Common verifier for reshape-like types.

bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)

Helper function to check whether the passed in sizes or offsets are valid.

bool wouldOpBeTriviallyDead(Operation *op)

Return true if the given operation would be dead if unused, and has no side effects on memory that wo...

Type getElementTypeOrSelf(Type type)

Return the element type or return the type itself.

const FrozenRewritePatternSet & patterns

bool isZeroInteger(OpFoldResult v)

Return true if v is an IntegerAttr with value 0.

void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)

Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...

SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)

Convert reassociation indices to affine expressions.

std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)

Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.

bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)

Return true if the reassociation specification is valid, false otherwise.

Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)

Converts an OpFoldResult to a Value.

std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)

Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.

Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

OpFoldResult getAsOpFoldResult(Value val)

Given a value, try to extract a constant Attribute.

llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)

SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)

Return a vector of OpFoldResults with the same size as staticValues, but all elements for which ShapedType::isDynamic is true are replaced by the corresponding entry in dynamicValues.

std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)

Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...

LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)

Returns success if the given two shapes are compatible.

detail::constant_op_matcher m_Constant()

Matches a constant foldable operation.

SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)

Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...

LogicalResult verify(Operation *op, bool verifyRecursively=true)

Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...

ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)

Wraps a list of reassociations in an ArrayAttr.

SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)

LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)

Returns "success" when any of the elements in offsetsOrSizes is a constant value.

Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....

LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override

A canonicalizer wrapper to replace ExtractSliceOps.

void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)

Return the canonical type of the result of an extract_slice op.

RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)

Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).

Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...

OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...

This represents an operation in an abstracted form, suitable for use with the builder APIs.

void addAttributes(ArrayRef< NamedAttribute > newAttributes)

Add an array of named attributes.

void addAttribute(StringRef name, Attribute attr)

Add an attribute with the specified name.

SmallVector< std::unique_ptr< Region >, 1 > regions

Regions that the op will hold.

Idiomatic saturated operations on values like offsets, sizes, and strides.

static SaturatedInteger wrap(int64_t v)

FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)

Result for slice bounds verification;.

bool isValid

If set to "true", the slice bounds verification was successful.

std::string errorMessage

An error message that can be printed during op verification.