MLIR: lib/Dialect/Shape/IR/Shape.cpp Source File

//===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/Dialect/Shape/IR/Shape.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::shape;

#include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"

namespace {
#include "ShapeCanonicalization.inc"
} // namespace

RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
  return RankedTensorType::get({rank}, IndexType::get(ctx));
}

bool shape::isExtentTensorType(Type type) {
  auto ranked = llvm::dyn_cast<RankedTensorType>(type);
  return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
}

LogicalResult shape::getShapeVec(Value input,
                                 SmallVectorImpl<int64_t> &shapeValues) {
  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
    auto type = llvm::cast<ShapedType>(inputOp.getArg().getType());
    if (!type.hasRank())
      return failure();
    llvm::append_range(shapeValues, type.getShape());
    return success();
  }
  DenseIntElementsAttr attr;
  if (matchPattern(input, m_Constant(&attr))) {
    llvm::append_range(shapeValues, attr.getValues<int64_t>());
    return success();
  }
  return failure();
}
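
// Illustrative note (added; not part of the upstream file): `getShapeVec`
// succeeds either for a `shape.shape_of` of a ranked operand or for a
// constant shape. A sketch of the first case:
//
//   %s = shape.shape_of %t : tensor<2x3xf32> -> tensor<2xindex>
//
// yields shapeValues = {2, 3}, while an unranked operand makes it fail.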

static bool isErrorPropagationPossible(TypeRange operandTypes) {
  return llvm::any_of(operandTypes,
                      llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
}

static LogicalResult verifySizeOrIndexOp(Operation *op) {
  assert(op != nullptr && op->getNumResults() == 1);
  Type resultTy = op->getResultTypes().front();
  if (isErrorPropagationPossible(op->getOperandTypes())) {
    if (!llvm::isa<SizeType>(resultTy))
      return op->emitOpError()
             << "if at least one of the operands can hold error values then "
                "the result must be of type `size` to propagate them";
  }
  return success();
}

static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
  assert(op != nullptr && op->getNumResults() == 1);
  Type resultTy = op->getResultTypes().front();
  if (isErrorPropagationPossible(op->getOperandTypes())) {
    if (!llvm::isa<ShapeType>(resultTy))
      return op->emitOpError()
             << "if at least one of the operands can hold error values then "
                "the result must be of type `shape` to propagate them";
  }
  return success();
}

template <typename... Ty>
static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
  return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front());
}

template <typename... Ty, typename... ranges>
static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
  return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
}

//===----------------------------------------------------------------------===//
// InlinerInterface
//===----------------------------------------------------------------------===//

namespace {
/// This class defines the interface for inlining shape dialect ops.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       IRMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       IRMapping &) const final {
    return true;
  }
};
} // namespace

void ShapeDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
      >();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving it makes it simple to start with an unregistered op and try
  // different variants before actually defining the op.
  allowUnknownOperations();
  declarePromisedInterfaces<bufferization::BufferizableOpInterface, AssumingOp,
                            AssumingYieldOp>();
}

Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
    return builder.create<ub::PoisonOp>(loc, type, poison);

  if (llvm::isa<ShapeType>(type) || isExtentTensorType(type))
    return builder.create<ConstShapeOp>(
        loc, type, llvm::cast<DenseIntElementsAttr>(value));
  if (llvm::isa<SizeType>(type))
    return builder.create<ConstSizeOp>(loc, type,
                                       llvm::cast<IntegerAttr>(value));
  if (llvm::isa<WitnessType>(type))
    return builder.create<ConstWitnessOp>(loc, type,
                                          llvm::cast<BoolAttr>(value));

  return arith::ConstantOp::materialize(builder, value, type, loc);
}
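
// Illustrative sketch (added; not part of the upstream file): when folding
// produces a DenseIntElementsAttr for a value of shape or extent tensor type,
// the hook above materializes it as, e.g.,
//
//   %0 = shape.const_shape [2, 3] : tensor<2xindex>
//
// while an IntegerAttr for a `!shape.size` value becomes `shape.const_size`
// and a BoolAttr for a witness becomes `shape.const_witness`.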

LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
                                                     NamedAttribute attribute) {
  // Verify shape.lib attribute.
  if (attribute.getName() == "shape.lib") {
    if (!op->hasTrait<OpTrait::SymbolTable>())
      return op->emitError(
          "shape.lib attribute may only be on op implementing SymbolTable");

    if (auto symbolRef = llvm::dyn_cast<SymbolRefAttr>(attribute.getValue())) {
      Operation *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
      if (!symbol)
        return op->emitError("shape function library ")
               << symbolRef << " not found";
      return isa<shape::FunctionLibraryOp>(symbol)
                 ? success()
                 : op->emitError()
                       << symbolRef << " required to be shape function library";
    }

    if (auto arr = llvm::dyn_cast<ArrayAttr>(attribute.getValue())) {
      // Verify all entries are function libraries and mappings in them refer
      // to unique ops.
      DenseSet<StringAttr> key;
      for (auto it : arr) {
        if (!llvm::isa<SymbolRefAttr>(it))
          return op->emitError(
              "only SymbolRefAttr allowed in shape.lib attribute array");

        auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
            SymbolTable::lookupSymbolIn(op, llvm::cast<SymbolRefAttr>(it)));
        if (!shapeFnLib)
          return op->emitError()
                 << it << " does not refer to FunctionLibraryOp";
        for (auto mapping : shapeFnLib.getMapping()) {
          if (!key.insert(mapping.getName()).second) {
            return op->emitError("only one op to shape mapping allowed, found "
                                 "multiple for `")
                   << mapping.getName() << "`";
          }
        }
      }
      return success();
    }

    return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
                         "allowed as shape.lib attribute");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AnyOp
//===----------------------------------------------------------------------===//

// TODO: Canonicalization should be implemented for shapes that can be
// determined through mixtures of the known dimensions of the inputs.
OpFoldResult AnyOp::fold(FoldAdaptor adaptor) {
  // Only the last operand is checked because AnyOp is commutative.
  if (adaptor.getInputs().back())
    return adaptor.getInputs().back();

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AssumingOp
//===----------------------------------------------------------------------===//

ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, {}, {}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void AssumingOp::print(OpAsmPrinter &p) {
  bool yieldsResults = !getResults().empty();

  p << " " << getWitness();
  if (yieldsResults)
    p << " -> (" << getResultTypes() << ")";
  p << ' ';
  p.printRegion(getDoRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/yieldsResults);
  p.printOptionalAttrDict((*this)->getAttrs());
}

namespace {
// Removes AssumingOp with a passing witness and inlines the region.
struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
    if (!witness || !witness.getPassingAttr())
      return failure();

    AssumingOp::inlineRegionIntoParent(op, rewriter);
    return success();
  }
};
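
// Illustrative sketch (added for clarity; IR syntax is a sketch, not taken
// from the upstream tests): with a statically passing witness,
//
//   %w = shape.const_witness true
//   %0 = shape.assuming %w -> (index) {
//     ...
//     shape.assuming_yield %r : index
//   }
//
// the pattern above inlines the body into the parent block and replaces %0
// with %r.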

// Removes results of an AssumingOp that have no uses by rebuilding the op
// with only the yielded values that are still needed.
struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());

    // Find used values.
    SmallVector<Value, 4> newYieldOperands;
    for (auto [opResult, yieldOperand] :
         llvm::zip(op.getResults(), yieldOp.getOperands())) {
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);
      }
    }

    // Rewrite only if redundant results exist.
    if (newYieldOperands.size() == yieldOp->getNumOperands())
      return failure();

    // Replace the yield op in the old assuming op's body and move the entire
    // region to the new assuming op.
    rewriter.setInsertionPoint(yieldOp);
    auto newYieldOp =
        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
    rewriter.setInsertionPoint(op);
    auto newOp = rewriter.create<AssumingOp>(
        op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
    newOp.getDoRegion().takeBody(op.getDoRegion());

    // Use the new results to replace the previously used ones.
    SmallVector<Value, 4> replacementValues;
    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(*src++);
    }
    rewriter.replaceOp(op, replacementValues);
    return success();
  }
};
} // namespace

void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
}

// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
void AssumingOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  // AssumingOp has unconditional control flow into the region and back to the
  // parent, so return the correct RegionSuccessor purely based on whether the
  // branch point is the parent or the region.
  if (!point.isParent()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&getDoRegion()));
}

void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.setInsertionPointToEnd(blockBeforeAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}

void AssumingOp::build(
    OpBuilder &builder, OperationState &result, Value witness,
    function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
  OpBuilder::InsertionGuard g(builder);

  result.addOperands(witness);
  Region *bodyRegion = result.addRegion();
  builder.createBlock(bodyRegion);

  // Build body.
  SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
  builder.create<AssumingYieldOp>(result.location, yieldValues);

  SmallVector<Type, 2> assumingTypes;
  for (Value v : yieldValues)
    assumingTypes.push_back(v.getType());
  result.addTypes(assumingTypes);
}

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//

LogicalResult mlir::shape::AddOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    AddOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
      llvm::isa<SizeType>(adaptor.getRhs().getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
  // add(x, 0) -> x
  if (matchPattern(getRhs(), m_Zero()))
    return getLhs();

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) + b; });
}
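
// Illustrative sketch (added; IR syntax is a sketch under the dialect's usual
// conventions): `add(x, 0)` folds to `x`, and two constant operands fold via
// `constFoldBinaryOp`, e.g.
//
//   %c3 = shape.const_size 3
//   %c4 = shape.const_size 4
//   %r = shape.add %c3, %c4 : !shape.size, !shape.size -> !shape.size
//
// folds to a constant size 7.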

//===----------------------------------------------------------------------===//
// AssumingAllOp
//===----------------------------------------------------------------------===//

namespace {

// Merge multiple `shape.assuming_all` operations together.
//
//   %0 = shape.assuming_all %w0, %w1
//   %1 = shape.assuming_all %w2, %0
//
// to:
//
//   %0 = shape.assuming_all %w0, %w1, %w2
struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> operands;

    for (Value operand : op.getInputs()) {
      if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
        operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
      else
        operands.push_back(operand);
    }

    // We didn't find any other `assuming_all` ops to merge with.
    if (operands.size() == op.getNumOperands())
      return failure();

    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
    return success();
  }
};

// Eliminate `cstr_broadcastable` operands from `assuming_all` operation that
// are subsumed by others.
//
//   %0 = shape.cstr_broadcastable %shape0, %shape1
//   %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
//
//   %2 = shape.cstr_broadcastable %shape3, %shape4
//   %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
//
//   %4 = shape.assuming_all %0, %1, %2, %3
//
// to:
//
//   %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
//   %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
//   %2 = shape.assuming_all %0, %1
//
// In this example if shapes [0, 1, 2] are broadcastable, then it means that
// shapes [0, 1] are broadcastable too, and that constraint can be removed.
// If shapes [0, 1, 2] are not broadcastable, then it doesn't matter whether
// shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    // Collect all `CstrBroadcastableOp` operands first.
    SetVector<CstrBroadcastableOp> operands;
    for (Value operand : op.getInputs()) {
      // TODO: Apply this optimization if some of the witnesses are not
      // produced by the `cstr_broadcastable`.
      auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
      if (!broadcastable)
        return failure();

      operands.insert(broadcastable);
    }

    // Skip trivial cases.
    if (operands.size() <= 1)
      return failure();

    // Collect the shapes checked by each `cstr_broadcastable` operand.
    SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes;
    for (auto cstr : operands) {
      DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
      shapes.emplace_back(cstr, std::move(shapesSet));
    }

    // Sort by the number of shape operands (larger to smaller).
    llvm::sort(shapes, [](auto a, auto b) {
      return a.first.getNumOperands() > b.first.getNumOperands();
    });

    // Starting from the constraints with the most shape operands, remove
    // every constraint whose shapes form a subset of the shapes of an
    // earlier constraint; those are implied by it.
    SmallVector<CstrBroadcastableOp> markedForErase;

    for (unsigned i = 0; i < shapes.size(); ++i) {
      auto isSubset = [&](auto pair) {
        return llvm::set_is_subset(pair.second, shapes[i].second);
      };

      // Keep track of the redundant `cstr_broadcastable` ops to erase.
      auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
      for (auto *it0 = it; it0 < shapes.end(); ++it0)
        markedForErase.push_back(it0->first);
      shapes.erase(it, shapes.end());
    }

    // We didn't find any redundant constraints.
    if (markedForErase.empty())
      return failure();

    // Collect the remaining, non-overlapping constraints.
    SmallVector<Value> uniqueConstraints;
    for (auto &shape : shapes)
      uniqueConstraints.push_back(shape.first.getResult());

    // Replace with an `assuming_all` over only the unique constraints ...
    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);

    // ... and erase the redundant `cstr_broadcastable` ops if now unused.
    for (auto &op : markedForErase)
      if (op->use_empty())
        rewriter.eraseOp(op);

    return success();
  }
};

struct AssumingAllToCstrEqCanonicalization
    : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> shapes;
    for (Value w : op.getInputs()) {
      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
      if (!cstrEqOp)
        return failure();
      bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
        return llvm::is_contained(shapes, s);
      });
      if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
        return failure();
      shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
    }
    rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
    return success();
  }
};

template <typename OpTy>
struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Find unique operands.
    SetVector<Value> unique(op.operand_begin(), op.operand_end());

    // Reduce op to equivalent with unique operands.
    if (unique.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
                                        unique.takeVector(), op->getAttrs());
      return success();
    }

    return failure();
  }
};
} // namespace

void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns
      .add<MergeAssumingAllOps, AssumingAllOneOp,
           AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
           RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}

OpFoldResult AssumingAllOp::fold(FoldAdaptor adaptor) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
    Attribute a = adaptor.getInputs()[idx];
    // Cannot fold if any inputs are not constant.
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false.
    if (!llvm::cast<BoolAttr>(a).getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}
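
// Worked example (added for clarity): folding with inputs (%unknown, true,
// false) visits the constant tail in reverse, erases each known operand, and
// returns false as soon as the statically failing witness is seen, without
// ever inspecting %unknown; with all-true constant inputs it folds to a
// passing witness instead.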

LogicalResult AssumingAllOp::verify() {
  // Ensure that AssumingAllOp contains at least one operand.
  if (getNumOperands() == 0)
    return emitOpError("no operands specified");

  return success();
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
  if (getShapes().size() == 1) {
    // Otherwise we need a cast, which would be a canonicalization, not
    // folding.
    if (getShapes().front().getType() != getType())
      return nullptr;
    return getShapes().front();
  }

  if (!adaptor.getShapes().front())
    return nullptr;

  SmallVector<int64_t, 6> resultShape(
      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
          .getValues<int64_t>());

  for (auto next : adaptor.getShapes().drop_front()) {
    if (!next)
      return nullptr;
    auto nextShape = llvm::to_vector<6>(
        llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());

    SmallVector<int64_t, 6> tmpShape;
    // If the shapes are not compatible, we can't fold it.
    if (!OpTrait::util::getBroadcastedShape(resultShape, nextShape, tmpShape))
      return nullptr;

    resultShape.clear();
    std::copy(tmpShape.begin(), tmpShape.end(),
              std::back_inserter(resultShape));
  }

  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}

LogicalResult BroadcastOp::verify() {
  return verifyShapeOrExtentTensorOp(*this);
}
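
// Illustrative sketch (added; IR syntax is a sketch): with all-constant
// operands the fold above computes the broadcasted shape directly, e.g.
//
//   %a = shape.const_shape [1, 4] : tensor<2xindex>
//   %b = shape.const_shape [3, 1] : tensor<2xindex>
//   %r = shape.broadcast %a, %b
//       : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
//
// folds to the constant shape [3, 4].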

namespace {
template <typename OpTy>
struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    auto isPotentiallyNonEmptyShape = [](Value shape) {
      if (auto extentTensorTy =
              llvm::dyn_cast<RankedTensorType>(shape.getType())) {
        if (extentTensorTy.getDimSize(0) == 0)
          return false;
      }
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        if (constShape.getShape().empty())
          return false;
      }
      return true;
    };
    auto newOperands = llvm::filter_to_vector<8>(op->getOperands(),
                                                 isPotentiallyNonEmptyShape);

    // Replace the op with an empty constant shape if all operands were
    // reduced away.
    if (newOperands.empty()) {
      rewriter.replaceOpWithNewOp<ConstShapeOp>(
          op, op->getResultTypes().front(), rewriter.getIndexTensorAttr({}));
      return success();
    }

    // Reduce op to equivalent without empty shape operands.
    if (newOperands.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
                                        op->getAttrs());
      return success();
    }

    return failure();
  }
};

struct BroadcastForwardSingleOperandPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 1)
      return failure();
    Value replacement = op.getShapes().front();

    // Insert cast if needed.
    if (replacement.getType() != op.getType()) {
      auto loc = op.getLoc();
      if (llvm::isa<ShapeType>(op.getType())) {
        replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
      } else {
        assert(!llvm::isa<ShapeType>(op.getType()) &&
               !llvm::isa<ShapeType>(replacement.getType()) &&
               "expect extent tensor cast");
        replacement =
            rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
      }
    }

    rewriter.replaceOp(op, replacement);
    return success();
  }
};

struct BroadcastFoldConstantOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t, 8> foldedConstantShape;
    SmallVector<Value, 8> newShapeOperands;
    for (Value shape : op.getShapes()) {
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        SmallVector<int64_t, 8> newFoldedConstantShape;
        if (OpTrait::util::getBroadcastedShape(
                foldedConstantShape,
                llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
                newFoldedConstantShape)) {
          foldedConstantShape = newFoldedConstantShape;
          continue;
        }
      }
      newShapeOperands.push_back(shape);
    }

    // Need at least two constant operands to fold anything.
    if (op.getNumOperands() - newShapeOperands.size() < 2)
      return failure();

    auto foldedConstantOperandsTy = RankedTensorType::get(
        {static_cast<int64_t>(foldedConstantShape.size())},
        rewriter.getIndexType());
    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
        op.getLoc(), foldedConstantOperandsTy,
        rewriter.getIndexTensorAttr(foldedConstantShape)));
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getResultTypes(),
                                             newShapeOperands);
    return success();
  }
};

template <typename OpTy>
struct CanonicalizeCastExtentTensorOperandsPattern
    : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Canonicalize operands.
    bool anyChange = false;
    auto canonicalizeOperand = [&](Value operand) -> Value {
      if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
        // Only eliminate the cast if it holds no shape information.
        bool isInformationLoosingCast =
            llvm::cast<RankedTensorType>(castOp.getType()).isDynamicDim(0);
        if (isInformationLoosingCast) {
          anyChange = true;
          return castOp.getSource();
        }
      }
      return operand;
    };
    auto newOperands = llvm::to_vector<8>(
        llvm::map_range(op.getOperands(), canonicalizeOperand));

    // Rewrite op if any change was made.
    if (!anyChange)
      return failure();
    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
    return success();
  }
};

struct BroadcastConcretizeResultTypePattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // Only concretize dynamic extent tensor result types.
    auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
    if (!resultTy || !resultTy.isDynamicDim(0))
      return failure();

    // Infer the resulting shape rank if possible.
    int64_t maxRank = 0;
    for (Value shape : op.getShapes()) {
      if (auto extentTensorTy =
              llvm::dyn_cast<RankedTensorType>(shape.getType())) {
        // Cannot infer the resulting shape rank if any operand is dynamically
        // ranked.
        if (extentTensorTy.isDynamicDim(0))
          return failure();
        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
      }
    }

    auto newOp = rewriter.create<BroadcastOp>(
        op.getLoc(), getExtentTensorType(getContext(), maxRank),
        op.getShapes());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
} // namespace

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<BroadcastConcretizeResultTypePattern,
               BroadcastFoldConstantOperandsPattern,
               BroadcastForwardSingleOperandPattern,
               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
               RemoveDuplicateOperandsPattern<BroadcastOp>,
               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}

//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//

OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  if (!adaptor.getLhs() || !adaptor.getRhs())
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;
  resultShape.append(lhsShape.begin(), lhsShape.end());
  resultShape.append(rhsShape.begin(), rhsShape.end());
  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}
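
// Worked example (added for clarity): concatenating the constant shapes
// [2, 3] and [4] simply appends the extents, so the fold yields the index
// tensor attribute [2, 3, 4].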

//===----------------------------------------------------------------------===//
// ConstShapeOp
//===----------------------------------------------------------------------===//

void ConstShapeOp::print(OpAsmPrinter &p) {
  p << " ";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(getShape().getValues<int64_t>(), p);
  p << "] : ";
  p.printType(getType());
}

ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = llvm::dyn_cast<ArrayAttr>(extentsRaw);
  if (!extentsArray)
    return failure();
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = llvm::dyn_cast<IntegerAttr>(extent);
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}

OpFoldResult ConstShapeOp::fold(FoldAdaptor) { return getShapeAttr(); }

void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<TensorCastConstShape>(context);
}

LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ConstShapeOp::Adaptor adaptor,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Builder b(context);
  const Properties prop = adaptor.getProperties();
  inferredReturnTypes.assign({RankedTensorType::get(
      {static_cast<int64_t>(prop.shape.size())}, b.getIndexType())});
  return success();
}

bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
                                                        TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;

  Type lhs = l.front();
  Type rhs = r.front();

  if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
    // Shape type is compatible with all other valid return types.
    return true;
  return lhs == rhs;
}

//===----------------------------------------------------------------------===//
// CstrBroadcastableOp
//===----------------------------------------------------------------------===//

void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point
  // that does not result in folding.
  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}

// Return true if there is exactly one attribute not representing a scalar
// broadcast.
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
  bool nonScalarSeen = false;
  for (Attribute a : attributes) {
    if (!a || llvm::cast<DenseIntElementsAttr>(a).getNumElements() != 0) {
      if (nonScalarSeen)
        return false;
      nonScalarSeen = true;
    }
  }
  return true;
}

OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
  // No broadcasting is needed if all operands but one are scalar.
  if (hasAtMostSingleNonScalar(adaptor.getShapes()))
    return BoolAttr::get(getContext(), true);

  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (const auto &operand : adaptor.getShapes()) {
          if (!operand)
            return false;
          extents.push_back(llvm::to_vector<6>(
              llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Lastly, see if folding can be completed based on what constraints are
  // known on the input shapes.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (auto shapeValue : getShapes()) {
          extents.emplace_back();
          if (failed(getShapeVec(shapeValue, extents.back())))
            return false;
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}
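
// Illustrative note (added): shapes [1, 4] and [3, 1] are statically known to
// be broadcastable, so the fold returns a passing witness; shapes [2] and [3]
// are statically incompatible, but the fold deliberately returns nullptr
// rather than a failing witness, so the eventual runtime assertion remains.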

LogicalResult CstrBroadcastableOp::verify() {
  // Ensure that CstrBroadcastableOp contains at least two operands.
  if (getNumOperands() < 2)
    return emitOpError("required at least 2 input shapes");
  return success();
}

//===----------------------------------------------------------------------===//
// CstrEqOp
//===----------------------------------------------------------------------===//

void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness.
  patterns.add<CstrEqEqOps>(context);
}

OpFoldResult CstrEqOp::fold(FoldAdaptor adaptor) {
  if (llvm::all_of(adaptor.getShapes(), [&](Attribute a) {
        return a && a == adaptor.getShapes().front();
      }))
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not try to replace it with a constant witness. Similarly,
  // we cannot if there are any non-const inputs.
  return nullptr;
}

//===----------------------------------------------------------------------===//
// ConstSizeOp
//===----------------------------------------------------------------------===//

void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}

OpFoldResult ConstSizeOp::fold(FoldAdaptor) { return getValueAttr(); }

void ConstSizeOp::getAsmResultNames(
    llvm::function_ref<void(Value, StringRef)> setNameFn) {
  SmallString<4> buffer;
  llvm::raw_svector_ostream os(buffer);
  os << "c" << getValue();
  setNameFn(getResult(), os.str());
}

//===----------------------------------------------------------------------===//
// ConstWitnessOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstWitnessOp::fold(FoldAdaptor) { return getPassingAttr(); }

//===----------------------------------------------------------------------===//
// CstrRequireOp
//===----------------------------------------------------------------------===//

OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) {
  return adaptor.getPred();
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

std::optional<int64_t> DimOp::getConstantIndex() {
  if (auto constSizeOp = getIndex().getDefiningOp<ConstSizeOp>())
    return constSizeOp.getValue().getLimitedValue();
  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
    return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
  return std::nullopt;
}

OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  Type valType = getValue().getType();
  auto valShapedType = llvm::dyn_cast<ShapedType>(valType);
  if (!valShapedType || !valShapedType.hasRank())
    return nullptr;
  std::optional<int64_t> index = getConstantIndex();
  if (!index.has_value())
    return nullptr;
  if (index.value() < 0 || index.value() >= valShapedType.getRank())
    return nullptr;
  auto extent = valShapedType.getDimSize(*index);
  if (ShapedType::isDynamic(extent))
    return nullptr;
  return IntegerAttr::get(IndexType::get(getContext()), extent);
}

LogicalResult mlir::shape::DimOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    DimOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.assign({adaptor.getIndex().getType()});
  return success();
}

bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

//===----------------------------------------------------------------------===//
// DivOp
//===----------------------------------------------------------------------===//

OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
  if (!lhs)
    return nullptr;
  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
  if (!rhs || rhs.getValue().isZero())
    return nullptr;

  // Division in APInt does not follow floor(lhs, rhs) when the result is
  // negative. Rather, APInt rounds toward zero.
  APInt quotient, remainder;
  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
  if (quotient.isNegative() && !remainder.isZero()) {
    quotient -= 1;
  }

  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, quotient);
}

LogicalResult mlir::shape::DivOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    DivOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
      llvm::isa<SizeType>(adaptor.getRhs().getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
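
// Worked example (added for clarity): for lhs = -7 and rhs = 2,
// APInt::sdivrem computes quotient -3 and remainder -1; because the quotient
// is negative and the remainder nonzero, the fold adjusts the quotient to -4,
// which matches floor(-7 / 2) as required by the op's semantics.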

//===----------------------------------------------------------------------===//
// ShapeEqOp
//===----------------------------------------------------------------------===//

OpFoldResult ShapeEqOp::fold(FoldAdaptor adaptor) {
  bool allSame = true;
  if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
    return {};
  for (Attribute operand : adaptor.getShapes().drop_front()) {
    if (!operand)
      return {};
    allSame = allSame && operand == adaptor.getShapes().front();
  }
  return BoolAttr::get(getContext(), allSame);
}

//===----------------------------------------------------------------------===//
// IndexToSizeOp
//===----------------------------------------------------------------------===//

OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) {
  // Constant values of both types, `shape.size` and `index`, are represented
  // as `IntegerAttr`s which makes constant folding simple.
  if (Attribute arg = adaptor.getArg())
    return arg;
  return {};
}

void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<SizeToIndexToSizeCanonicalization>(context);
}

//===----------------------------------------------------------------------===//
// FromExtentsOp
//===----------------------------------------------------------------------===//

OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {
  if (llvm::any_of(adaptor.getExtents(), [](Attribute a) { return !a; }))
    return nullptr;
  SmallVector<int64_t, 6> extents;
  for (auto attr : adaptor.getExtents())
    extents.push_back(llvm::cast<IntegerAttr>(attr).getInt());
  Builder builder(getContext());
  return builder.getIndexTensorAttr(extents);
}

//===----------------------------------------------------------------------===//
// FunctionLibraryOp
//===----------------------------------------------------------------------===//

void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}

FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
  auto attr = llvm::dyn_cast_or_null<FlatSymbolRefAttr>(
      getMapping().get(op->getName().getIdentifier()));
  if (!attr)
    return nullptr;
  return lookupSymbol<FuncOp>(attr);
}

ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  // Parse the op name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  auto *bodyRegion = result.addRegion();
  if (parser.parseRegion(*bodyRegion))
    return failure();

  if (parser.parseKeyword("mapping"))
    return failure();

  DictionaryAttr mappingAttr;
  if (parser.parseAttribute(mappingAttr,
                            parser.getBuilder().getType<NoneType>(), "mapping",
                            result.attributes))
    return failure();
  return success();
}

void FunctionLibraryOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());
  p.printOptionalAttrDictWithKeyword(
      (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(getMappingAttr());
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder builder(location->getContext());
  OperationState state(location, getOperationName());
  FuncOp::build(builder, state, name, type, attrs);
  return cast<FuncOp>(Operation::create(state));
}
FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      Operation::dialect_attr_range attrs) {
  SmallVector<NamedAttribute, 8> attrRef(attrs);
  return create(location, name, type, llvm::ArrayRef(attrRef));
}
FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      ArrayRef<NamedAttribute> attrs,
                      ArrayRef<DictionaryAttr> argAttrs) {
  FuncOp func = create(location, name, type, attrs);
  func.setAllArgAttrs(argAttrs);
  return func;
}

void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
                   FunctionType type, ArrayRef<NamedAttribute> attrs,
                   ArrayRef<DictionaryAttr> argAttrs) {
  state.addAttribute(FuncOp::getSymNameAttrName(state.name),
                     builder.getStringAttr(name));
  state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
                     TypeAttr::get(type));
  state.attributes.append(attrs.begin(), attrs.end());
  state.addRegion();

  if (argAttrs.empty())
    return;
  assert(type.getNumInputs() == argAttrs.size());
  function_interface_impl::addArgAndResultAttrs(
      builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
      getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
}

ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  auto buildFuncType =
      [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
         function_interface_impl::VariadicFlag,
         std::string &) { return builder.getFunctionType(argTypes, results); };

  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false,
      getFunctionTypeAttrName(result.name), buildFuncType,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}

void FuncOp::print(OpAsmPrinter &p) {
  function_interface_impl::printFunctionOp(
      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
      getArgAttrsAttrName(), getResAttrsAttrName());
}

//===----------------------------------------------------------------------===//
// GetExtentOp
//===----------------------------------------------------------------------===//

std::optional<int64_t> GetExtentOp::getConstantDim() {
  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
    return constSizeOp.getValue().getLimitedValue();
  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
    return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
  return std::nullopt;
}

OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
  auto elements =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
  if (!elements)
    return nullptr;
  std::optional<int64_t> dim = getConstantDim();
  if (!dim.has_value())
    return nullptr;
  if (dim.value() >= elements.getNumElements())
    return nullptr;
  return elements.getValues<Attribute>()[(uint64_t)dim.value()];
}

void GetExtentOp::build(OpBuilder &builder, OperationState &result,
                        Value shape, int64_t dim) {
  auto loc = result.location;
  auto dimAttr = builder.getIndexAttr(dim);
  if (llvm::isa<ShapeType>(shape.getType())) {
    Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
    build(builder, result, builder.getType<SizeType>(), shape, dim);
  } else {
    Value dim =
        builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
    build(builder, result, builder.getIndexType(), shape, dim);
  }
}

LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    GetExtentOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
                                                       TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
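
// Illustrative sketch (added; IR syntax is a sketch): with a constant shape
// and a constant dim the fold reads the extent directly, e.g.
//
//   %s = shape.const_shape [2, 5, 7] : !shape.shape
//   %c1 = shape.const_size 1
//   %e = shape.get_extent %s, %c1 : !shape.shape, !shape.size -> !shape.size
//
// folds to a constant size 5; an out-of-range dim simply fails to fold.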

//===----------------------------------------------------------------------===//
// IsBroadcastableOp
//===----------------------------------------------------------------------===//

void IsBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}

OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
  // Can always broadcast fewer than two shapes.
  if (adaptor.getShapes().size() < 2) {
    return BoolAttr::get(getContext(), true);
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// MeetOp
//===----------------------------------------------------------------------===//

LogicalResult mlir::shape::MeetOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MeetOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (adaptor.getOperands().empty())
    return failure();

  auto isShapeType = [](Type arg) {
    if (llvm::isa<ShapeType>(arg))
      return true;
    return isExtentTensorType(arg);
  };

  auto types = adaptor.getOperands().getTypes();
  Type acc = types.front();
  for (auto t : drop_begin(types)) {
    Type l = acc, r = t;
    if (!llvm::isa<ShapeType, SizeType>(l))
      std::swap(l, r);

    // Handle sizes, propagate error type if present.
    if (llvm::isa<SizeType>(l)) {
      if (llvm::isa<SizeType, IndexType>(r))
        acc = l;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (llvm::isa<IndexType>(l)) {
      if (llvm::isa<IndexType>(r))
        acc = r;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (llvm::isa<ShapeType>(l)) {
      // Handle shapes, propagate error type if present.
      if (isShapeType(r))
        acc = l;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (isExtentTensorType(l)) {
      auto rank1 = llvm::cast<RankedTensorType>(l).getShape()[0];
      auto rank2 = llvm::cast<RankedTensorType>(r).getShape()[0];
      if (ShapedType::isDynamic(rank1))
        acc = l;
      else if (ShapedType::isDynamic(rank2))
        acc = r;
      else if (rank1 != rank2)
        return emitOptionalError(location, "unequal shape cardinality");
      else
        acc = l;
    }
  }
  inferredReturnTypes.assign({acc});
  return success();
}

bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  if (!llvm::isa<ShapeType, SizeType>(lhs))
    std::swap(lhs, rhs);

  if (llvm::isa<SizeType>(lhs))
    return llvm::isa<SizeType, IndexType>(rhs);
  if (llvm::isa<ShapeType>(lhs))
    return llvm::isa<ShapeType, TensorType>(rhs);

  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
  auto shape =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
  if (!shape)
    return {};
  int64_t rank = shape.getNumElements();
  Builder builder(getContext());
  return builder.getIndexAttr(rank);
}

/// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
/// Constant folding fails in cases where only the rank is constant, not the
/// shape itself.
/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
///
/// Example:
///
/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
/// %rank = shape.rank %shape
///
/// becomes
///
/// %rank = shape.const_size 3
namespace {
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    auto rankedTensorType =
        llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    if (llvm::isa<IndexType>(op.getType())) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
                                                          rank);
    } else if (llvm::isa<shape::SizeType>(op.getType())) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
} // namespace

void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
}

LogicalResult mlir::shape::RankOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    RankOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

//===----------------------------------------------------------------------===//
// NumElementsOp
//===----------------------------------------------------------------------===//

OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
  // Fold only when argument constant.
  Attribute shape = adaptor.getShape();
  if (!shape)
    return {};

  APInt product(64, 1);
  for (auto value : llvm::cast<DenseIntElementsAttr>(shape))
    product *= value;
  Builder builder(getContext());
  return builder.getIndexAttr(product.getLimitedValue());
}

LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    NumElementsOp::Adaptor adaptor,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
                                                         TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult shape::NumElementsOp::verify() {
  return verifySizeOrIndexOp(*this);
}

//===----------------------------------------------------------------------===//
// MaxOp
//===----------------------------------------------------------------------===//

OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
  // If operands are equal, just propagate one.
  if (getLhs() == getRhs())
    return getLhs();
  return nullptr;
}

LogicalResult mlir::shape::MaxOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MaxOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
    inferredReturnTypes.assign({adaptor.getLhs().getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}

bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
    return true;
  if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
    return true;
  return false;
}

//===----------------------------------------------------------------------===//
// MinOp
//===----------------------------------------------------------------------===//

OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
  // If operands are equal, just propagate one.
  if (getLhs() == getRhs())
    return getLhs();
  return nullptr;
}

LogicalResult mlir::shape::MinOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MinOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
    inferredReturnTypes.assign({adaptor.getLhs().getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}

bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
    return true;
  if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
    return true;
  return false;
}

//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
  if (!lhs)
    return nullptr;
  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
  if (!rhs)
    return nullptr;
  APInt folded = lhs.getValue() * rhs.getValue();
  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, folded);
}

LogicalResult mlir::shape::MulOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MulOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
      llvm::isa<SizeType>(adaptor.getRhs().getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

//===----------------------------------------------------------------------===//
// ShapeOfOp
//===----------------------------------------------------------------------===//

namespace {
/// Replace shape.shape_of for statically shaped tensors with a constant shape.
struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
    if (!type || !type.hasStaticShape())
      return failure();
    Location loc = op.getLoc();
    Value constShape =
        rewriter
            .create<ConstShapeOp>(loc,
                                  rewriter.getIndexTensorAttr(type.getShape()))
            .getResult();
    if (constShape.getType() != op.getResult().getType())
      constShape = rewriter.create<tensor::CastOp>(
          loc, op.getResult().getType(), constShape);
    rewriter.replaceOp(op, constShape);
    return success();
  }
};

// Canonicalize
//
//   %0 = tensor.reshape %input(%shape)
//   %1 = shape.shape_of %0
//
// to
//
//   %1 = %shape (possibly cast to the type of %1)
//
struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
    if (!tensorReshapeOp)
      return rewriter.notifyMatchFailure(op, "producer is not tensor.reshape");
    if (!isa<TensorType>(op.getType()))
      return rewriter.notifyMatchFailure(op, "result is not a tensor");

    // Operand 'shape' of 'tensor.reshape' may now be used as the result of
    // 'shape.shape_of'. While its type is guaranteed to be compatible in
    // well-formed IR, it may nonetheless be different (dynamically vs
    // statically shaped), in which case it needs to be cast first using
    // 'tensor.cast'. It may also differ in element type (e.g. i32 vs index)
    // while having an identical shape, in which case it needs to be cast
    // first using 'arith.index_cast'. Note: 'shape.shape_of' op result must
    // be a shape or extent tensor.
    Value shape = tensorReshapeOp.getShape();

    auto opTensorTy = cast<RankedTensorType>(op.getType());
    auto shapeTensorTy = cast<RankedTensorType>(shape.getType());

    if (opTensorTy != shapeTensorTy) {
      if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
        shape = rewriter.create<tensor::CastOp>(op.getLoc(), opTensorTy, shape);
      else
        shape =
            rewriter.create<arith::IndexCastOp>(op.getLoc(), opTensorTy, shape);
    }

    rewriter.replaceOp(op, shape);
    return success();
  }
};

// Canonicalize
//
//   %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
//   %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
//
// to
//
//   %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    auto ty = llvm::dyn_cast<RankedTensorType>(op.getType());
    if (!ty || ty.getRank() != 1)
      return failure();

    auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();

    // Argument type must be ranked and must not conflict.
    auto argTy = llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
    if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
      return failure();

    rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
    return success();
  }
};
} // namespace

void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
               ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
      context);
}

LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ShapeOfOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<ValueShapeType>(adaptor.getArg().getType()))
    inferredReturnTypes.assign({ShapeType::get(context)});
  else {
    auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType());
    int64_t rank =
        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
    Type extentTensorTy =
        RankedTensorType::get({rank}, IndexType::get(context));
    inferredReturnTypes.assign({extentTensorTy});
  }
  return success();
}

bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  if (!llvm::isa<ShapeType, ShapedType>(lhs) ||
      !llvm::isa<ShapeType, ShapedType>(rhs))
    return false;

  if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
    // Shape type is compatible with all other valid return types.
    return true;

  if (succeeded(verifyCompatibleShapes(lhs, rhs)))
    return true;
  return false;
}

LogicalResult shape::ShapeOfOp::verify() {
  return verifyShapeOrExtentTensorOp(*this);
}

//===----------------------------------------------------------------------===//
// SizeToIndexOp
//===----------------------------------------------------------------------===//

OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) {
  // Constant values of both types, `shape.size` and `index`, are represented
  // as `IntegerAttr`s which makes constant folding simple.
  if (Attribute arg = adaptor.getArg())
    return arg;
  return {};
}

void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<IndexToSizeToIndexCanonicalization>(context);
}

bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  return llvm::isa<IndexType, SizeType>(inputs[0]) &&
         llvm::isa<IndexType>(outputs[0]);
}

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//

LogicalResult shape::YieldOp::verify() {
  auto *parentOp = (*this)->getParentOp();
  auto results = parentOp->getResults();
  auto operands = getOperands();

  if (parentOp->getNumResults() != getNumOperands())
    return emitOpError() << "number of operands does not match number of "
                            "results of its parent";
  for (auto e : llvm::zip(results, operands))
    if (std::get<0>(e).getType() != std::get<1>(e).getType())
      return emitOpError() << "types mismatch between yield op and its parent";

  return success();
}

//===----------------------------------------------------------------------===//
// SplitAtOp
//===----------------------------------------------------------------------===//

LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  if (!adaptor.getOperand() || !adaptor.getIndex())
    return failure();
  auto shapeVec = llvm::to_vector<6>(
      llvm::cast<DenseIntElementsAttr>(adaptor.getOperand())
          .getValues<int64_t>());
  auto shape = llvm::ArrayRef(shapeVec);
  auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
  // Verify that the split point is in the correct range.
  // TODO: Constant fold to an "error".
  int64_t rank = shape.size();
  if (-rank > splitPoint || splitPoint > rank)
    return failure();
  if (splitPoint < 0)
    splitPoint += shape.size();
  Builder builder(adaptor.getOperand().getContext());
  results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
  results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
  return success();
}
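
// Worked example (added for clarity): splitting the constant shape [2, 3, 4]
// at index -1 first normalizes the split point to rank + (-1) = 2, producing
// head [2, 3] and tail [4]; a split point outside [-rank, rank] makes the
// fold fail rather than produce an error value.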

//===----------------------------------------------------------------------===//
// ToExtentTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
  if (!adaptor.getInput())
    return OpFoldResult();
  Builder builder(getContext());
  auto shape = llvm::to_vector<6>(
      llvm::cast<DenseIntElementsAttr>(adaptor.getInput())
          .getValues<int64_t>());
  auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
                                    builder.getIndexType());
  return DenseIntElementsAttr::get(type, shape);
}

bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  if (auto inputTensor = llvm::dyn_cast<RankedTensorType>(inputs[0])) {
    if (!llvm::isa<IndexType>(inputTensor.getElementType()) ||
        inputTensor.getRank() != 1)
      return false;
  } else if (!llvm::isa<ShapeType>(inputs[0])) {
    return false;
  }

  TensorType outputTensor = llvm::dyn_cast<TensorType>(outputs[0]);
  return outputTensor && llvm::isa<IndexType>(outputTensor.getElementType());
}

//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//

void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
                     ValueRange initVals) {
  OpBuilder::InsertionGuard g(builder);
  result.addOperands(shape);
  result.addOperands(initVals);

  Region *bodyRegion = result.addRegion();
  Block *bodyBlock = builder.createBlock(
      bodyRegion, /*insertPt=*/{}, builder.getIndexType(), result.location);

  Type elementType;
  if (auto tensorType = llvm::dyn_cast<TensorType>(shape.getType()))
    elementType = tensorType.getElementType();
  else
    elementType = SizeType::get(builder.getContext());
  bodyBlock->addArgument(elementType, shape.getLoc());

  for (Value initVal : initVals) {
    bodyBlock->addArgument(initVal.getType(), initVal.getLoc());
    result.addTypes(initVal.getType());
  }
}

LogicalResult ReduceOp::verify() {
  // Verify block arg types.
  Block &block = getRegion().front();

  // The block takes index, extent, and aggregated values as arguments.
  auto blockArgsCount = getInitVals().size() + 2;
  if (block.getNumArguments() != blockArgsCount)
    return emitOpError() << "ReduceOp body is expected to have "
                         << blockArgsCount << " arguments";

  // The first block argument is the index and must always be of index type.
  if (!llvm::isa<IndexType>(block.getArgument(0).getType()))
    return emitOpError(
        "argument 0 of ReduceOp body is expected to be of IndexType");

  // The second block argument is the extent and must be of size or index
  // type, depending on whether the reduce operation is applied to a shape or
  // to an extent tensor.
  Type extentTy = block.getArgument(1).getType();
  if (llvm::isa<ShapeType>(getShape().getType())) {
    if (!llvm::isa<SizeType>(extentTy))
      return emitOpError("argument 1 of ReduceOp body is expected to be of "
                         "SizeType if the ReduceOp operates on a ShapeType");
  } else {
    if (!llvm::isa<IndexType>(extentTy))
      return emitOpError(
          "argument 1 of ReduceOp body is expected to be of IndexType if the "
          "ReduceOp operates on an extent tensor");
  }

  for (const auto &type : llvm::enumerate(getInitVals()))
    if (block.getArgument(type.index() + 2).getType() != type.value().getType())
      return emitOpError() << "type mismatch between argument "
                           << type.index() + 2
                           << " of ReduceOp body and initial value "
                           << type.index();
  return success();
}

ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse operands.
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
  Type shapeOrExtentTensorType;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
                              OpAsmParser::Delimiter::Paren) ||
      parser.parseColonType(shapeOrExtentTensorType) ||
      parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Resolve operands.
  auto initVals = llvm::ArrayRef(operands).drop_front();
  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
                            result.operands) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Parse the body.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, {}, {}))
    return failure();

  // Parse attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

void ReduceOp::print(OpAsmPrinter &p) {
  p << '(' << getShape() << ", " << getInitVals()
    << ") : " << getShape().getType();
  p.printOptionalArrowTypeList(getResultTypes());
  p << ' ';
  p.printRegion(getRegion());
  p.printOptionalAttrDict((*this)->getAttrs());
}

#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
