MLIR: lib/Dialect/Linalg/IR/LinalgOps.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

14

41

42 #include "llvm/ADT/DenseMap.h"

43 #include "llvm/ADT/STLExtras.h"

44 #include "llvm/ADT/SetOperations.h"

45 #include "llvm/ADT/SmallSet.h"

46 #include "llvm/ADT/SmallVector.h"

47 #include "llvm/ADT/StringSet.h"

48 #include "llvm/ADT/TypeSwitch.h"

49 #include "llvm/Support/FormatVariadic.h"

50 #include "llvm/Support/InterleavedRange.h"

51 #include "llvm/Support/LogicalResult.h"

52 #include "llvm/Support/MathExtras.h"

53 #include "llvm/Support/raw_ostream.h"

54 #include

55 #include

56

57 using namespace mlir;

59

60

62 int64_t dim) {

63 auto type = cast(v.getType());

64 if (!type.isDynamicDim(dim))

65 return builder.getIndexAttr(type.getDimSize(dim));

66

69 .Case([&](RankedTensorType t) -> Value {

70 return builder.createtensor::DimOp(loc, v, dim);

71 })

72 .Case([&](MemRefType t) -> Value {

73 return builder.creatememref::DimOp(loc, v, dim);

74 }));

75 }

76

77

78

84 .Case([&](RankedTensorType t) -> Operation * {

85 return b.createtensor::ExtractSliceOp(loc, source, offsets, sizes,

86 strides);

87 })

88 .Case([&](MemRefType type) -> Operation * {

89 return b.creatememref::SubViewOp(loc, source, offsets, sizes,

90 strides);

91 })

92 .Default([&](Type t) -> Operation * { return nullptr; });

93 }

94

95

96

97

98

100 int64_t dim) {

101 if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))

102 return b.createOrFoldmemref::DimOp(loc, source, dim);

103 if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))

104 return b.createOrFoldtensor::DimOp(loc, source, dim);

105 llvm_unreachable("Expected MemRefType or TensorType");

106 }

107

109 int64_t dim) {

110 auto shapedType = llvm::cast(source.getType());

111 if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))

113 return b.getIndexAttr(shapedType.getDimSize(dim));

114 }

115

116

117

118

119

122

123

124

125

126

127

134 for (auto containers : {inputTypes, outputTypes}) {

135 for (auto t : containers) {

136 argTypes.push_back(

138

139

141 }

142 }

143

144

147 opBuilder.createBlock(&region, {}, argTypes, argLocs);

148

151 regionBuilder(b, *body, attrs);

152

153

154

155

156 }

157

158

159

160

161

163 std::optional resultTensorTypes,

167

169 resultTensorTypes.value_or(TypeRange());

170 if (!resultTensorTypes)

171 copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),

172 llvm::IsaPred);

173

174 state.addOperands(inputs);

175 state.addOperands(outputs);

176 state.addTypes(derivedResultTypes);

177

178 state.addAttributes(attributes);

179 state.addAttribute(

180 "operandSegmentSizes",

182 static_cast<int32_t>(outputs.size())}));

183

184

185 Region &region = *state.addRegion();

187 state.attributes.getAttrs(), regionBuilder);

188 }

189

191 std::optional resultTensorTypes,

196

198 indexingMapsAttrVal = llvm::map_to_vector(

199 MatmulOp::getDefaultIndexingMaps(b.getContext()),

201 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));

202 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,

203 attributes, regionBuilder);

204 }

205

207 std::optional resultTensorTypes,

212

214 indexingMapsAttrVal =

217 });

218 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));

219 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,

220 attributes, regionBuilder);

221 }

222

224 std::optional resultTensorTypes,

229

231 indexingMapsAttrVal =

234 });

235 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));

236 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,

237 attributes, regionBuilder);

238 }

239

240

241

242 static ParseResult

246 bool addOperandSegmentSizes = true) {

247 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;

249 outputsOperands;

250

253 return failure();

254 }

257 return failure();

258

261 return failure();

262

266 return failure();

267 }

268

273 return failure();

274 }

275

276 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,

278 parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,

280 return failure();

281

282 if (addOperandSegmentSizes) {

283

284

285

286

287

288

291 attrs.append("operandSegmentSizes",

293 {static_cast<int32_t>(inputsOperands.size()),

294 static_cast<int32_t>(outputsOperands.size())}));

296 } else {

299 {static_cast<int32_t>(inputsOperands.size()),

300 static_cast<int32_t>(outputsOperands.size())}));

301 }

302 }

304 std::optional info =

306 if (info) {

307 if (failed(info->verifyInherentAttrs(result.attributes, [&]() {

308 return parser.emitError(attrsLoc)

309 << "'" << result.name.getStringRef() << "' op ";

310 })))

311 return failure();

312 }

313 }

314 return success();

315 }

316

319 if (!inputs.empty())

320 p << " ins(" << inputs << " : " << inputs.getTypes() << ")";

321 if (!outputs.empty())

322 p << " outs(" << outputs << " : " << outputs.getTypes() << ")";

323 }

324

325

326

327

328

333 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {

336 llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "

337 "region expects {0} args, got {1}",

338 numRegionArgs, inputTypes.size() + outputTypes.size()));

339 }

340

343 regionBuilder);

344 return success();

345 }

346

347 static ParseResult

351 return failure();

352 return success();

353 }

354

357 unsigned numRegionArgs,

359

362 return failure();

363

364

366 return failure();

367

368

369

372 return failure();

373 result.addTypes(outputTensorsTypes);

374

375 std::unique_ptr region = std::make_unique();

378 regionBuilder))

379 return failure();

380 result.addRegion(std::move(region));

381

382 return success();

383 }

384

387 if (resultTypes.empty())

388 return;

390 }

391

396

397

398

400

401

403

404

405 }

406

407

408

409

410

411

412

413

414

415

416

417

418

419

420

421

422

423

424

425

426

427

428

429

430 namespace {

431

432 class RegionBuilderHelper {

433 public:

434 RegionBuilderHelper(OpBuilder &builder, Block &block)

435 : builder(builder), block(block) {}

436

437

439 if (!isFloatingPoint(arg))

440 llvm_unreachable("unsupported non numeric type");

442 builder.setInsertionPointToEnd(&block);

444 case UnaryFn::exp:

445 return builder.createmath::ExpOp(arg.getLoc(), arg);

446 case UnaryFn:🪵

447 return builder.createmath::LogOp(arg.getLoc(), arg);

449 return builder.createmath::AbsFOp(arg.getLoc(), arg);

451 return builder.createmath::CeilOp(arg.getLoc(), arg);

453 return builder.createmath::FloorOp(arg.getLoc(), arg);

454 case UnaryFn::negf:

455 return builder.createarith::NegFOp(arg.getLoc(), arg);

456 case UnaryFn::reciprocal: {

458 auto one = builder.createarith::ConstantOp(arg.getLoc(),

459 ::cast(oneAttr));

460 return builder.createarith::DivFOp(arg.getLoc(), one, arg);

461 }

463 return builder.createmath::RoundOp(arg.getLoc(), arg);

464 case UnaryFn::sqrt:

465 return builder.createmath::SqrtOp(arg.getLoc(), arg);

466 case UnaryFn::rsqrt:

467 return builder.createmath::RsqrtOp(arg.getLoc(), arg);

468 case UnaryFn::square:

469 return builder.createarith::MulFOp(arg.getLoc(), arg, arg);

470 case UnaryFn::tanh:

471 return builder.createmath::TanhOp(arg.getLoc(), arg);

472 case UnaryFn::erf:

473 return builder.createmath::ErfOp(arg.getLoc(), arg);

474 }

475 llvm_unreachable("unsupported unary function");

476 }

477

478

480 bool allComplex = isComplex(arg0) && isComplex(arg1);

481 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);

482 bool allInteger = isInteger(arg0) && isInteger(arg1);

485 if (!allComplex && !allFloatingPoint && !allInteger)

486 llvm_unreachable("unsupported non numeric type");

488 builder.setInsertionPointToEnd(&block);

490 case BinaryFn::add:

491 if (allComplex)

492 return builder.createcomplex::AddOp(arg0.getLoc(), arg0, arg1);

493 if (allFloatingPoint)

494 return builder.createarith::AddFOp(arg0.getLoc(), arg0, arg1);

495 if (allBool)

496 return builder.createarith::OrIOp(arg0.getLoc(), arg0, arg1);

497 return builder.createarith::AddIOp(arg0.getLoc(), arg0, arg1);

498 case BinaryFn::sub:

499 if (allComplex)

500 return builder.createcomplex::SubOp(arg0.getLoc(), arg0, arg1);

501 if (allFloatingPoint)

502 return builder.createarith::SubFOp(arg0.getLoc(), arg0, arg1);

503 if (allBool)

504 llvm_unreachable("unsupported operation: sub with bools");

505 return builder.createarith::SubIOp(arg0.getLoc(), arg0, arg1);

506 case BinaryFn::mul:

507 if (allComplex)

508 return builder.createcomplex::MulOp(arg0.getLoc(), arg0, arg1);

509 if (allFloatingPoint)

510 return builder.createarith::MulFOp(arg0.getLoc(), arg0, arg1);

511 if (allBool)

512 return builder.createarith::AndIOp(arg0.getLoc(), arg0, arg1);

513 return builder.createarith::MulIOp(arg0.getLoc(), arg0, arg1);

514 case BinaryFn::div:

515 if (allComplex)

516 return builder.createcomplex::DivOp(arg0.getLoc(), arg0, arg1);

517 if (allFloatingPoint)

518 return builder.createarith::DivFOp(arg0.getLoc(), arg0, arg1);

519 if (allBool)

520 llvm_unreachable("unsupported operation: div with bools");

521 return builder.createarith::DivSIOp(arg0.getLoc(), arg0, arg1);

522 case BinaryFn::div_unsigned:

523 if (!allInteger || allBool)

524 llvm_unreachable("unsupported operation: unsigned div not on uint");

525 return builder.createarith::DivUIOp(arg0.getLoc(), arg0, arg1);

526 case BinaryFn::max_signed:

527 assert(!allComplex);

528 if (allFloatingPoint)

529 return builder.createarith::MaximumFOp(arg0.getLoc(), arg0, arg1);

530 return builder.createarith::MaxSIOp(arg0.getLoc(), arg0, arg1);

531 case BinaryFn::min_signed:

532 assert(!allComplex);

533 if (allFloatingPoint)

534 return builder.createarith::MinimumFOp(arg0.getLoc(), arg0, arg1);

535 return builder.createarith::MinSIOp(arg0.getLoc(), arg0, arg1);

536 case BinaryFn::max_unsigned:

537 assert(!allComplex);

538 if (allFloatingPoint)

539 return builder.createarith::MaximumFOp(arg0.getLoc(), arg0, arg1);

540 return builder.createarith::MaxUIOp(arg0.getLoc(), arg0, arg1);

541 case BinaryFn::min_unsigned:

542 assert(!allComplex);

543 if (allFloatingPoint)

544 return builder.createarith::MinimumFOp(arg0.getLoc(), arg0, arg1);

545 return builder.createarith::MinUIOp(arg0.getLoc(), arg0, arg1);

546 case BinaryFn::powf:

547 assert(allFloatingPoint);

548 return builder.createmath::PowFOp(arg0.getLoc(), arg0, arg1);

549 }

550 llvm_unreachable("unsupported binary function");

551 }

552

553

556 bool headBool =

558 bool tailFloatingPoint =

559 isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);

560 bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);

562 builder.setInsertionPointToEnd(&block);

564 case TernaryFn::select:

565 if (!headBool && !(tailFloatingPoint || tailInteger))

566 llvm_unreachable("unsupported non numeric type");

567 return builder.createarith::SelectOp(arg0.getLoc(), arg0, arg1, arg2);

568 }

569 llvm_unreachable("unsupported ternary function");

570 }

571

572

573 Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {

574 switch (typeFn) {

575 case TypeFn::cast_signed:

576 return cast(toType, operand, false);

577 case TypeFn::cast_unsigned:

578 return cast(toType, operand, true);

579 }

580 llvm_unreachable("unsupported type conversion function");

581 }

582

583 void yieldOutputs(ValueRange values) {

585 builder.setInsertionPointToEnd(&block);

586 Location loc = builder.getUnknownLoc();

587 builder.create(loc, values);

588 }

589

590 Value constant(const std::string &value) {

592 builder.setInsertionPointToEnd(&block);

593 Location loc = builder.getUnknownLoc();

595 return builder.createarith::ConstantOp(loc, ::cast(valueAttr));

596 }

597

598 Value index(int64_t dim) {

600 builder.setInsertionPointToEnd(&block);

601 return builder.create(builder.getUnknownLoc(), dim);

602 }

603

604 Type getIntegerType(unsigned width) {

606 }

607

610

611 private:

612

613

614

615

616 Value cast(Type toType, Value operand, bool isUnsignedCast) {

618 builder.setInsertionPointToEnd(&block);

619 auto loc = operand.getLoc();

621 }

622

623 bool isComplex(Value value) {

624 return llvm::isa(value.getType());

625 }

626 bool isFloatingPoint(Value value) {

627 return llvm::isa(value.getType());

628 }

629 bool isInteger(Value value) {

630 return llvm::isa(value.getType());

631 }

632

635 };

636

637 }

638

639

640

641

642

643 namespace {

644

647 LogicalResult matchAndRewrite(CopyOp copyOp,

649 if (copyOp.getInputs() != copyOp.getOutputs())

651 if (copyOp.hasPureBufferSemantics())

652 rewriter.eraseOp(copyOp);

653 else

654 rewriter.replaceOp(copyOp, copyOp.getInputs());

655

656 return success();

657 }

658 };

659

660 }

661

662 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,

664 results.add(context);

665 }

666

667

668

669

670

671 namespace {

672

673

674

675

676

677 template

678 struct FoldFillWithTensorReshape : OpRewritePattern {

680 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,

682 auto oldFill = reshapeOp.getSrc().template getDefiningOp();

683 if (!oldFill)

684 return failure();

685

686 Location loc = oldFill.getLoc();

687 TensorReshapeOp newInit;

688 if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {

689

690 newInit = rewriter.create(

691 loc, reshapeOp.getResultType(), oldFill.output(),

692 reshapeOp.getReassociation(), reshapeOp.getOutputShape(),

693 reshapeOp.getStaticOutputShape());

694 } else {

695 newInit = rewriter.create(loc, reshapeOp.getResultType(),

696 oldFill.output(),

697 reshapeOp.getReassociation());

698 }

701 return success();

702 }

703 };

704

705

706

707 struct FoldFillWithPad final : public OpRewritePatterntensor::PadOp {

709

710 LogicalResult matchAndRewrite(tensor::PadOp padOp,

712 auto fillOp = padOp.getSource().getDefiningOplinalg::FillOp();

713 if (!fillOp)

714 return failure();

715

716

717

718 Value padValue = padOp.getConstantPaddingValue();

719 if (!padValue || fillOp.value() != padValue)

720 return failure();

721

725 padOp, "failed to reify tensor.pad op result shape");

726

727 auto emptyTensor = rewriter.createtensor::EmptyOp(

728 padOp.getLoc(), reifiedShape.front(),

729 padOp.getResultType().getElementType());

730 Value replacement =

731 rewriter

734 .getResult(0);

735 if (replacement.getType() != padOp.getResultType()) {

736 replacement = rewriter.createtensor::CastOp(

737 fillOp.getLoc(), padOp.getResultType(), replacement);

738 }

739 rewriter.replaceOp(padOp, replacement);

740 return success();

741 }

742 };

743

744

745

746

747 struct FoldInsertPadIntoFill : public OpRewritePatterntensor::InsertSliceOp {

749

750 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,

752 auto srcPadOp = insertOp.getSource().getDefiningOptensor::PadOp();

753 if (!srcPadOp)

754 return failure();

755

756 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())

757 return failure();

758

759

760

761 Value firstDest = insertOp.getDest();

762 while (auto prevOp = firstDest.getDefiningOptensor::InsertSliceOp()) {

763 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())

764 return failure();

765

766

767

768 bool disjoint = false;

769 for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {

770

771

772 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||

773 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||

774 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))

775 continue;

776

777

778 int64_t prevStart = prevOp.getStaticOffset(i);

779 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *

780 prevOp.getStaticStride(i);

781 int64_t nextStart = insertOp.getStaticOffset(i);

782 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *

783 insertOp.getStaticStride(i);

784 if (prevEnd < nextStart || nextEnd < prevStart) {

785 disjoint = true;

786 break;

787 }

788 }

789

790 if (!disjoint)

791 break;

792 firstDest = prevOp.getDest();

793 }

794

795

796

797 auto dstFillOp = firstDest.getDefiningOplinalg::FillOp();

798 if (!dstFillOp)

799 return failure();

800

801

802

803 Value padValue = srcPadOp.getConstantPaddingValue();

804 if (!padValue || dstFillOp.value() != padValue)

805 return failure();

806

809

810 Location loc = insertOp.getLoc();

812

815 auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

816

817

818

820 for (const auto &p : llvm::zip(lowPads, oldOffsets)) {

822 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));

823 }

824

825 RankedTensorType srcPadType = srcPadOp.getSourceType();

827 for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {

828 if (srcPadType.isDynamicDim(i)) {

829 newSizes.push_back(

830 rewriter.createtensor::DimOp(loc, srcPadOp.getSource(), i)

831 .getResult());

832 } else {

833 newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));

834 }

835 }

836

838 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,

839 newSizes, insertOp.getMixedStrides());

840 return success();

841 }

842 };

843

844

845 struct FoldFillWithTensorExtract : public OpRewritePatterntensor::ExtractOp {

846 public:

848

849 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,

851

852

853 auto fillOp = extractOp.getTensor().getDefiningOplinalg::FillOp();

854 if (!fillOp)

855 return failure();

856

857

858 Value extractedScalar = fillOp.getInputs()[0];

859

860

861 rewriter.replaceOp(extractOp, extractedScalar);

862 return success();

863 }

864 };

865

866

867

868

869 static FailureOr foldFillPackIntoFillOp(RewriterBase &rewriter,

870 linalg::PackOp packOp) {

871 auto fillOp = packOp.getSource().getDefiningOp();

872 if (!fillOp)

873 return failure();

874

875 if (auto paddingValue = packOp.getPaddingValue())

877 return failure();

878

879 Value packOpDest = packOp.getDest();

881 return failure();

882

883 return rewriter.createlinalg::FillOp(packOp.getLoc(), fillOp.getInputs(),

884 packOp.getDest());

885 }

886

887

888 struct FoldFillWithPack : public OpRewritePatternlinalg::PackOp {

889 public:

892

893 LogicalResult matchAndRewrite(linalg::PackOp packOp,

895 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);

896 if (failed(fillOp))

897 return failure();

898 rewriter.replaceOp(packOp, fillOp.value().result());

899 return success();

900 }

901 };

902

903

906

907 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,

909 if (auto fillOp = copyOp.getInputs().front().getDefiningOp()) {

911 fillOp.getInputs(),

912 copyOp.getOutputs());

913 return success();

914 }

915 if (auto fillOp = copyOp.getOutputs().front().getDefiningOp()) {

916 rewriter.replaceOpWithNewOplinalg::CopyOp(copyOp, copyOp.getInputs(),

917 fillOp.getOutputs());

918 return success();

919 }

920 return failure();

921 }

922 };

923

924

925 struct FoldFillWithTranspose : OpRewritePatternlinalg::TransposeOp {

927

928 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,

930 if (auto fillOp = transposeOp.getInput().getDefiningOp()) {

932 transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),

933 transposeOp.getDpsInitOperand(0)->get());

934 return success();

935 }

936 return failure();

937 }

938 };

939

940

941

942 struct FoldConcatsOfFill : public OpRewritePatterntensor::ConcatOp {

944

945 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,

947 auto concatOperands = concatOp.getInputs();

948 if (concatOperands.empty()) {

949 return failure();

950 }

951

952 auto firstFillOp = concatOperands.front().getDefiningOplinalg::FillOp();

953 if (!firstFillOp) {

954 return failure();

955 }

956

959

961 allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

962

963 auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {

964 auto fillOp = v.getDefiningOplinalg::FillOp();

965 if (!fillOp) {

966 return false;

967 }

968

971 if (fillVal != firstFillVal)

972 return false;

973

974 allOuts.push_back(fillOp.getDpsInitOperand(0)->get());

975 return true;

976 };

977 if (!llvm::all_of(concatOperands.drop_front(),

978 isDefinedByCompatibleFillOp)) {

980 concatOp, "not all operands are defined by a compatible fill op");

981 }

982

983 Value outsConcat = rewriter.createtensor::ConcatOp(

984 concatOp.getLoc(), concatOp.getDim(), allOuts);

986 concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);

987 return success();

988 }

989 };

990

991 }

992

993 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,

995 results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,

996 FoldFillWithPack, FoldFillWithPad,

997 FoldFillWithTensorReshapetensor::CollapseShapeOp,

998 FoldFillWithTensorReshapetensor::ExpandShapeOp,

999 FoldInsertPadIntoFill, FoldFillWithTranspose>(context);

1000 }

1001

1002

1003

1004

1005

1012 for (ValueRange container : {inputs, outputs}) {

1013 for (Value v : container) {

1014 Type t = v.getType();

1015 blockArgTypes.push_back(

1017 blockArgLocs.push_back(v.getLoc());

1018 }

1019 }

1020

1022 Block *bodyBlock =

1023 builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);

1024 bodyBuild(builder, loc, bodyBlock->getArguments());

1025 }

1026

1027 void GenericOp::getAsmBlockArgumentNames(Region &region,

1029 for (Value v : getRegionInputArgs())

1030 setNameFn(v, "in");

1031 for (Value v : getRegionOutputArgs())

1032 setNameFn(v, "out");

1033 }

1034

1035 void GenericOp::build(

1038 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,

1041 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,

1042 iteratorTypes, doc, libraryCall);

1044 if (bodyBuild)

1046 inputs, outputs, bodyBuild);

1047 }

1048

1049 void GenericOp::build(

1053 StringRef libraryCall,

1056 build(builder, result, resultTensorTypes, inputs, outputs,

1058 builder.getArrayAttr(llvm::to_vector(llvm::map_range(

1059 iteratorTypes,

1061 return IteratorTypeAttr::get(builder.getContext(), iter);

1062 }))),

1063 doc.empty() ? StringAttr() : builder.getStringAttr(doc),

1064 libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),

1065 bodyBuild, attributes);

1066 }

1067

1068 void GenericOp::build(

1072 StringRef libraryCall,

1075 build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,

1076 iteratorTypes, doc, libraryCall, bodyBuild, attributes);

1077 }

1078

1079 void GenericOp::build(

1085 build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,

1086 "",

1087 "", bodyBuild, attributes);

1088 }

1089

1090 void GenericOp::build(

1096 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,

1097 iteratorTypes,

1098 "",

1099 "", bodyBuild, attributes);

1100 }

1101

1103 p << " ";

1104

1105

1106 auto genericAttrNames = linalgTraitAttrNames();

1107

1109 genericAttrNamesSet.insert_range(genericAttrNames);

1111 for (auto attr : (*this)->getAttrs()) {

1112 if (attr.getName() == getIteratorTypesAttrName()) {

1113 auto iteratorTypes =

1114 llvm::cast(attr.getValue())

1115 .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();

1116

1117

1118

1119

1121 llvm::to_vector(llvm::map_range(

1122 iteratorTypes, [&](utils::IteratorType t) -> Attribute {

1124 }));

1125

1126 genericAttrs.emplace_back(

1127 getIteratorTypesAttrName(),

1129 } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {

1130 genericAttrs.push_back(attr);

1131 }

1132 }

1133 if (!genericAttrs.empty()) {

1135 p << genericDictAttr;

1136 }

1137

1138

1140

1141 genericAttrNames.push_back("operandSegmentSizes");

1142 genericAttrNamesSet.insert(genericAttrNames.back());

1143

1144 bool hasExtraAttrs = false;

1146 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))

1147 break;

1148 }

1149 if (hasExtraAttrs) {

1150 p << " attrs = ";

1152 genericAttrNames);

1153 }

1154

1155

1156 if (!getRegion().empty()) {

1157 p << ' ';

1159 }

1160

1161

1163 }

1164

1166 DictionaryAttr dictAttr;

1167

1168

1169

1170

1173 return failure();

1175 dictAttr.getValue().end());

1176

1177

1178

1179

1180

1181 auto iteratorTypes = dyn_cast_or_null(

1183 if (!iteratorTypes) {

1184 return parser.emitError(attributeLocation)

1185 << "expected " << getIteratorTypesAttrName(result.name)

1186 << " array attribute";

1187 }

1188

1190

1191 for (StringRef s : iteratorTypes.getAsValueRange()) {

1192 auto maybeIteratorType = utils::symbolizeIteratorType(s);

1193 if (!maybeIteratorType.has_value())

1195 << "unexpected iterator_type (" << s << ")";

1196

1197 iteratorTypeAttrs.push_back(

1199 }

1202

1203

1206 return failure();

1207

1208

1212 return failure();

1213

1214 std::unique_ptr region = std::make_unique();

1216 return failure();

1217 result.addRegion(std::move(region));

1218

1219

1220

1221

1222

1225 return failure();

1226 result.addTypes(outputTensorsTypes);

1227

1228 return success();

1229 }

1230

1233 &effects,

1234 LinalgOp linalgOp) {

1235 for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {

1236 if (!llvm::isa(operand.getType()))

1237 continue;

1238 effects.emplace_back(

1241 }

1242

1243 for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {

1244 if (!llvm::isa(operand.get().getType()))

1245 continue;

1246 if (linalgOp.payloadUsesValueFromOperand(&operand)) {

1248 true,

1250 }

1252 true,

1254 }

1255 }

1256

1257 void GenericOp::getEffects(

1259 &effects) {

1261 }

1262

1265

1266

1267 if (!linalgOp.hasPureTensorSemantics())

1269

1271 }

1272

1275 }

1276

1278

1279 namespace {

1280

1281

1282

1283

1284

1285

1286

1287 template

1288 struct EraseIdentityLinalgOp : public OpRewritePattern {

1290

1291 LogicalResult matchAndRewrite(OpTy linalgOp,

1293

1294 if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))

1295 return failure();

1296

1297

1298

1299 Block &body = linalgOp->getRegion(0).front();

1300 if (!llvm::hasSingleElement(body))

1301 return failure();

1302 auto yieldOp = dyn_castlinalg::YieldOp(body.getTerminator());

1303 if (!yieldOp)

1304 return failure();

1305

1306

1307 if (linalgOp.hasPureBufferSemantics()) {

1308 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||

1309 linalgOp.getDpsInputOperand(0)->get() !=

1310 linalgOp.getDpsInitOperand(0)->get()) {

1312 linalgOp, "expected single input and output to be the same value");

1313 }

1314

1315 auto yieldArg = dyn_cast(yieldOp.getOperand(0));

1316 if (!yieldArg || yieldArg.getOwner() != &body) {

1318 "cannot fold fill-like op");

1319 }

1320

1321 rewriter.eraseOp(linalgOp);

1322 return success();

1323 }

1324

1325 if (!linalgOp.hasPureTensorSemantics()) {

1327 linalgOp, "mixed semantics is not supported yet");

1328 }

1329

1330

1331

1333 for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {

1334 auto yieldArg = llvm::dyn_cast(yieldVal.value());

1335 if (!yieldArg || yieldArg.getOwner() != &body)

1336 return failure();

1337 unsigned argumentNumber = yieldArg.getArgNumber();

1338 Value returnedArg = linalgOp->getOperand(argumentNumber);

1339 Type resultType = linalgOp->getResult(yieldVal.index()).getType();

1340

1341

1342 Type returnType = returnedArg.getType();

1343 if (returnType != resultType) {

1344

1345

1348 returnedArg = rewriter.create<sparse_tensor::ConvertOp>(

1349 linalgOp.getLoc(), resultType, returnedArg);

1350 else {

1351 if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),

1352 resultType))

1353 return failure();

1354 returnedArg = rewriter.createtensor::CastOp(

1355 linalgOp.getLoc(), resultType, returnedArg);

1356 }

1357 }

1358 returnedArgs.push_back(returnedArg);

1359 }

1360

1361 if (returnedArgs.size() != linalgOp->getNumResults())

1362 return failure();

1363 rewriter.replaceOp(linalgOp, returnedArgs);

1364 return success();

1365 }

1366 };

1367

1368 }

1369

1370 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,

1372 results.add<EraseIdentityLinalgOp>(context);

1373 }

1374

1377 }

1378

1379

1380

1381

1382

1386 nullptr) {

1387

1390 false))

1391 return failure();

1392

1393

1394 for (Type outputType : outputTypes) {

1395 if (llvm::isa(outputType))

1396 result.addTypes(outputType);

1397 }

1398

1399

1400 if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))

1401 return failure();

1402

1403

1405 return failure();

1406 return success();

1407 }

1408

1409 void MapOp::getAsmBlockArgumentNames(Region &region,

1411 for (Value v : getRegionInputArgs())

1412 setNameFn(v, "in");

1413 }

1414

1415 void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {

1416 if (!getResults().empty())

1417 setNameFn(getResults().front(), "mapped");

1418 }

1419

1420 void MapOp::build(

1424 build(builder, result, TypeRange{}, inputs, init);

1426

1427

1429 if (llvm::isa(initType))

1431

1432 if (bodyBuild)

1434 inputs, {}, bodyBuild);

1435 }

1436

1441 bool initFirst = false) {

1446 for (auto &operand : operands) {

1448 llvm::cast(operand.getType()).getElementType(),

1450 }

1452

1453

1454 if (initFirst) {

1455 payloadOpOperands.push_back(block.getArguments().back());

1456 for (const auto &arg : block.getArguments().drop_back())

1457 payloadOpOperands.push_back(arg);

1458 } else {

1459 payloadOpOperands = {block.getArguments().begin(),

1461 }

1462

1465 payloadOpOperands,

1466 TypeRange{llvm::cast(result.operands.back().getType())

1467 .getElementType()},

1468 payloadOpAttrs);

1470 }

1471

1473 std::optional payloadOpName;

1477 if (failed(operationName))

1478 return failure();

1480 return failure();

1481 payloadOpName = operationName.value();

1483 return failure();

1484 }

1485

1487 return failure();

1488

1489 if (payloadOpName.has_value()) {

1490 if (!result.operands.empty())

1492 payloadOpAttrs,

1494 else

1496 } else {

1499 true, true)) {

1500 return failure();

1501 }

1503 if (parser.parseRegion(*body, regionArgs))

1504 return failure();

1505 }

1506 return success();

1507 }

1508

1509

1510

1511

1512

1515 return nullptr;

1517 assert(isa(body->getOperations().back()));

1518

1521 return nullptr;

1522 if (initFirst) {

1523

1525 return nullptr;

1526

1527 for (const auto &[operand, bbArg] :

1529 if (bbArg != operand)

1530 return nullptr;

1531 }

1532 } else {

1533 for (const auto &[operand, bbArg] :

1535 if (bbArg != operand)

1536 return nullptr;

1537 }

1538 }

1539 return &payload;

1540 }

1541

1544 std::string attrToElide;

1546 for (const auto &attr : payloadOp->getAttrs()) {

1547 auto fastAttr =

1548 llvm::dyn_castmlir::arith::FastMathFlagsAttr(attr.getValue());

1549 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {

1550 attrToElide = attr.getName().str();

1551 elidedAttrs.push_back(attrToElide);

1552 break;

1553 }

1554 }

1556 p << " }";

1557 }

1558

1560 Block *mapper = getBody();

1562 if (payloadOp) {

1564 }

1565

1568

1569 if (!payloadOp) {

1570

1573 p << "(";

1574 llvm::interleaveComma(mapper->getArguments(), p,

1575 [&](auto arg) { p.printRegionArgument(arg); });

1576 p << ") ";

1577

1578 p.printRegion(getMapper(), false);

1580 }

1581 }

1582

1584 auto *bodyBlock = getBody();

1585 auto blockArgs = bodyBlock->getArguments();

1586

1587

1588 if (getInputs().size() != blockArgs.size())

1589 return emitOpError() << "expects number of operands to match the arity of "

1590 "mapper, but got: "

1591 << getInputs().size() << " and " << blockArgs.size();

1592

1593

1594 for (const auto &[bbArgType, inputArg] :

1595 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {

1596 auto inputElemType =

1597 llvm::cast(inputArg.getType()).getElementType();

1598 if (bbArgType != inputElemType) {

1599 return emitOpError() << "expected element type of input " << inputElemType

1600 << " to match bbArg type " << bbArgType;

1601 }

1602 }

1603

1604

1605 auto outputShape = getInit().getType().getShape();

1606 for (Type inputArgType : TypeRange{getInputs()}) {

1607 auto inputElemShape = llvm::cast(inputArgType).getShape();

1608 if (inputElemShape != outputShape) {

1609 return emitOpError() << "expected shape of input (" << inputElemShape

1610 << ") to match shape of output (" << outputShape

1611 << ")";

1612 }

1613 }

1614

1615 return success();

1616 }

1617

1619 int64_t rank = getInit().getType().getRank();

1621 }

1622

1623 ArrayAttr MapOp::getIndexingMaps() {

1625 int64_t rank = getInit().getType().getRank();

1626 int64_t numIndexingMaps = getOperands().size();

1629 }

1630

1631 void MapOp::getEffects(

1633 &effects) {

1635 }

1636

1639 }

1640

1641

1642

1643

1644

1645 void ReduceOp::getAsmBlockArgumentNames(Region &region,

1647 for (Value v : getRegionInputArgs())

1648 setNameFn(v, "in");

1649 for (Value v : getRegionOutputArgs())

1650 setNameFn(v, "init");

1651 }

1652

1653 void ReduceOp::getAsmResultNames(

1655 if (!getResults().empty())

1656 setNameFn(getResults().front(), "reduced");

1657 }

1658

1659 void ReduceOp::build(

1664 build(builder, result, TypeRange{}, inputs, inits, dimensions);

1666

1667

1668 for (Value init : inits) {

1670 if (llvm::isa(initType))

1672 }

1673

1674 if (bodyBuild)

1676 inputs, inits, bodyBuild);

1677 }

1678

1680 int64_t inputRank =

1681 llvm::cast(getInputs()[0].getType()).getRank();

1683 utils::IteratorType::parallel);

1684 for (int64_t reductionDim : getDimensions())

1685 iteratorTypes[reductionDim] = utils::IteratorType::reduction;

1686 return iteratorTypes;

1687 }

1688

1689 ArrayAttr ReduceOp::getIndexingMaps() {

1690 int64_t inputRank =

1691 llvm::cast(getInputs()[0].getType()).getRank();

1693 getNumDpsInputs(),

1698 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)

1699 affineMaps.push_back(resultMap);

1701 }

1702

1703 void ReduceOp::getEffects(

1705 &effects) {

1707 }

1708

1711 }

1712

1715 StringRef attributeName) {

1717 return failure();

1718

1720 return success();

1721 }

1722

1724 std::optional payloadOpName;

1728 if (failed(operationName))

1729 return failure();

1731 return failure();

1732 payloadOpName = operationName.value();

1734 return failure();

1735 }

1736

1740 }))

1741 return failure();

1742

1743 if (payloadOpName.has_value()) {

1746 } else {

1749 true, true)) {

1750 return failure();

1751 }

1752

1754 if (parser.parseRegion(*body, regionArgs))

1755 return failure();

1756 }

1757

1758 return success();

1759 }

1760

1763 p << ' ' << attributeName << " = [" << attributeValue << "] ";

1764 }

1765

1767 Block *mapper = getBody();

1769 if (payloadOp) {

1771 }

1772

1776 if (!payloadOp) {

1777

1780 p << "(";

1781 llvm::interleaveComma(mapper->getArguments(), p,

1782 [&](auto arg) { p.printRegionArgument(arg); });

1783 p << ") ";

1784

1785 p.printRegion(getCombiner(), false);

1787 }

1788 }

1789

1792

1793 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {

1794 if (llvm::cast(getInputs()[i].getType()).getShape() !=

1795 llvm::cast(getInputs()[0].getType()).getShape()) {

1796 return emitOpError() << "expects all inputs to have the same shapes. "

1797 "Shape at input-index "

1798 << i

1799 << " is not equal to the shape at input-index 0.";

1800 }

1801 }

1802 for (int64_t i = 1; i < getNumDpsInits(); ++i) {

1803 if (llvm::cast(getInits()[i].getType()).getShape() !=

1804 llvm::cast(getInits()[0].getType()).getShape()) {

1805 return emitOpError() << "expects all outputs to have the same shapes. "

1806 "Shape at output-index "

1807 << i

1808 << " is not equal to the shape at output-index 0.";

1809 }

1810 }

1811 auto inputType = llvm::cast(getInputs()[0].getType());

1812 auto initType = llvm::cast(getInits()[0].getType());

1813

1815 for (int64_t dimension : dimensionsRef) {

1816 if (dimension < 0 || dimension >= inputType.getRank()) {

1817 return emitOpError()

1818 << "dimensions for reduction should be in the range [0, "

1819 << inputType.getRank() - 1 << "].";

1820 }

1821 dimensionsToReduce.insert(dimension);

1822 }

1823

1824 auto inputDims = inputType.getShape();

1825 auto initDims = initType.getShape();

1826

1827

1830 if (!dimensionsToReduce.count(en.index()))

1831 reducedInputDims.push_back(en.value());

1832 }

1833

1834 if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {

1835 return emitOpError() << "number of dimensions after reduction "

1836 << reducedInputDims.size()

1837 << " doesn't match the init rank "

1838 << initType.getRank();

1839 }

1840

1841 if (reducedInputDims != initDims)

1842 return emitOpError() << "init dimensions [" << initDims

1843 << "] doesn't match input dimensions after reduction ["

1844 << reducedInputDims << "]";

1845

1846 Block *block = getBody();

1848 return emitOpError()

1849 << "mismatching number of operands and block arguments";

1850

1851

1852 for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {

1853 Type inputElementType =

1854 llvm::cast(input.getType()).getElementType();

1855 if (inputElementType != bbArg.getType())

1856 return emitOpError()

1857 << "input element type " << inputElementType

1858 << " does not match corresponding block argument type "

1859 << bbArg.getType();

1860 }

1861

1862

1863 for (auto [output, bbArg] : llvm::zip(

1864 getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {

1865 auto outputElementType =

1866 llvm::cast(output.getType()).getElementType();

1867 if (outputElementType != bbArg.getType())

1868 return emitOpError()

1869 << "output element type " << outputElementType

1870 << " does not match corresponding block argument type "

1871 << bbArg.getType();

1872 }

1873 return success();

1874 }

1875

1876

1877

1878

1879

1885 if (!args.empty())

1886 b.createlinalg::YieldOp(loc, args[0]);

1887 });

1888 }

1889

1896 result.addAttribute(getPermutationAttrName(result.name), permutation);

1898

1899

1901 if (llvm::isa(initType))

1903

1905 init);

1906 }

1907

1912 build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),

1913 attributes);

1914 }

1915

1920 })))

1921 return failure();

1922

1926 {});

1927 return success();

1928 }

1929

1930 void TransposeOp::getAsmResultNames(

1932 if (!getResults().empty())

1933 setNameFn(getResults().front(), "transposed");

1934 }

1935

1940 }

1941

1944

1946 return emitOpError("permutation is not valid");

1947

1948 auto inputType = getInput().getType();

1949 auto initType = getInit().getType();

1950

1951 int64_t rank = inputType.getRank();

1952

1953 if (rank != initType.getRank())

1954 return emitOpError() << "input rank " << rank

1955 << " does not match init rank " << initType.getRank();

1956

1957 if (rank != static_cast<int64_t>(permutationRef.size()))

1958 return emitOpError() << "size of permutation " << permutationRef.size()

1959 << " does not match the argument rank " << rank;

1960

1961 auto inputDims = inputType.getShape();

1962 auto initDims = initType.getShape();

1963

1964 for (int64_t i = 0; i < rank; ++i) {

1965 int64_t inputDim = inputDims[permutationRef[i]];

1966 int64_t initDim = initDims[i];

1967

1968 if (inputDim != initDim) {

1969 return emitOpError() << "dim(result, " << i << ") = " << initDim

1970 << " doesn't match dim(input, permutation[" << i

1971 << "]) = " << inputDim;

1972 }

1973 }

1974

1975 return success();

1976 }

1977

1979 int64_t rank = getInit().getType().getRank();

1981 }

1982

1983 ArrayAttr TransposeOp::getIndexingMaps() {

1985 int64_t rank = getInit().getType().getRank();

1988 llvm::to_vector_of(getPermutation()), getContext())),

1990 }

1991

1992 void TransposeOp::getEffects(

1994 &effects) {

1996 }

1997

2000 }

2001

2002 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,

2004

2005 if (!isa(getInput().getType()))

2006 return failure();

2007

2008

2009 if (getPermutation().size() == 0) {

2010 result.push_back(getInput());

2011 return success();

2012 }

2013

2015 result.push_back(getInput());

2016 return success();

2017 }

2018

2019 return failure();

2020 }

2021

2022

2025

2028 auto defTransposeOp = transposeOp.getInput().getDefiningOp();

2029 if (!defTransposeOp)

2030 return failure();

2034 foldedPerms.reserve(perms.size());

2035 for (int64_t perm : perms)

2036 foldedPerms.push_back(defPerms[perm]);

2037

2039 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),

2040 foldedPerms);

2041 return success();

2042 }

2043 };

2044

2045

2046

2047

2050

2053 Value input = transposeOp.getInput();

2054 BroadcastOp broadcastOp = input.getDefiningOp();

2055 if (!input.hasOneUse() || !broadcastOp)

2056 return failure();

2057

2060

2061

2065 unsigned dimensionSize = dimensions.size();

2066 for (unsigned i = 0; i < dimensionSize; ++i)

2067 resultDimensions.push_back(invertPerm[dimensions[i]]);

2068

2069

2070 Value broadcastInput = broadcastOp.getInput();

2071 Location loc = transposeOp.getLoc();

2072 MLIRContext *ctx = transposeOp.getContext();

2074 auto broadcastInputTy =

2075 mlir::cast(broadcastInput.getType());

2076 unsigned inputRank = broadcastInputTy.getRank();

2077 for (unsigned i = 0; i < inputRank; ++i) {

2078 if (broadcastInputTy.isDynamicDim(i)) {

2079 dims.push_back(rewriter.createtensor::DimOp(loc, broadcastInput, i)

2081 } else {

2083 broadcastInputTy.getDimSize(i)));

2084 }

2085 }

2088 Value transposeInit = rewriter.createtensor::EmptyOp(

2089 transposeOp.getLoc(), transposeResultShapes,

2090 broadcastInputTy.getElementType());

2091

2092

2093 Value transposeResult =

2094 rewriter

2095 .create(loc, broadcastOp.getInput(), transposeInit,

2096 resultPerms)

2097 ->getResult(0);

2099 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);

2100 return success();

2101 }

2102 };

2103

2104 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,

2107 }

2108

2109

2110

2111

2112

2119 result.addAttribute(getDimensionsAttrName(result.name), dimensions);

2121

2122

2124 if (llvm::isa(initType))

2126

2128 init);

2129 }

2130

2135 build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),

2136 attributes);

2137 }

2138

2143 })))

2144 return failure();

2145

2149 {});

2150 return success();

2151 }

2152

2153 void BroadcastOp::getAsmResultNames(

2155 if (!getResults().empty())

2156 setNameFn(getResults().front(), "broadcasted");

2157 }

2158

2163 }

2164

2167

2168 auto inputType = getInput().getType();

2169 auto initType = getInit().getType();

2170

2171 int64_t inputRank = inputType.getRank();

2172 int64_t initRank = initType.getRank();

2173

2174 auto inputShape = inputType.getShape();

2175 auto initShape = initType.getShape();

2176

2177 if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)

2178 return emitOpError() << "input rank plus added dimensions does not "

2179 "match init rank. input rank: "

2180 << inputRank

2181 << ", dimensions size: " << dimensionsRef.size()

2182 << ", init rank: " << initRank;

2183

2184 for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {

2185 if (dim < 0 || dim >= initRank)

2186 return emitOpError() << "dimension " << idx

2187 << " is out of range. expected range: [0, "

2188 << initRank - 1 << "], got: " << dim;

2189 }

2190

2191

2193 for (auto dim : llvm::seq<int64_t>(0, initRank)) {

2194 if (!llvm::is_contained(dimensionsRef, dim))

2195 dimMap.push_back(dim);

2196 }

2197

2198 for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {

2199

2200

2201 if (inputShape[inputDimIdx] != initShape[initDimIdx])

2202 return emitOpError() << "input dim " << inputDimIdx

2203 << " should match init dim " << initDimIdx

2204 << ". input: " << inputShape[inputDimIdx]

2205 << ", init: " << initShape[initDimIdx];

2206 }

2207

2208 return success();

2209 }

2210

2212 int64_t rank = getInit().getType().getRank();

2214 }

2215

2216 ArrayAttr BroadcastOp::getIndexingMaps() {

2218 int64_t rank = getInit().getType().getRank();

2222 }

2223

2224 void BroadcastOp::getEffects(

2226 &effects) {

2228 }

2229

2232 }

2233

2234 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,

2236 results.add<EraseIdentityLinalgOp>(context);

2237 }

2238

2239

2240

2241

2242

2244 if (getNumOperands() > 0)

2245 p << ' ' << getOperands();

2247 if (getNumOperands() > 0)

2248 p << " : " << getOperandTypes();

2249 }

2250

2259 }

2260

2261

2262

2263 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {

2264 if (op.getNumOperands() != linalgOp.getNumDpsInits())

2265 return op.emitOpError("expected number of yield values (")

2266 << op.getNumOperands()

2267 << ") to match the number of inits / outs operands of the enclosing "

2268 << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";

2269

2270 for (OpOperand &opOperand : op->getOpOperands()) {

2272 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());

2274 if (isa<MemRefType, RankedTensorType>(elementType))

2276 if (opOperand.get().getType() != elementType)

2277 return op.emitOpError("type of yield operand ")

2278 << (opOperand.getOperandNumber() + 1) << " ("

2279 << opOperand.get().getType() << ") doesn't match "

2280 << "the element type of the enclosing linalg.generic op ("

2281 << elementType << ")";

2282 }

2283 return success();

2284 }

2285

2287 auto *parentOp = (*this)->getParentOp();

2288 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())

2289 return emitOpError("expected single non-empty parent region");

2290

2291 if (auto linalgOp = dyn_cast(parentOp))

2293

2294 return emitOpError("expected parent op with LinalgOp interface");

2295 }

2296

2297

2298

2299

2300

2302 auto linalgOp = dyn_cast((*this)->getParentOp());

2303 if (!linalgOp)

2304 return emitOpError("expected parent op with LinalgOp interface");

2305 if (linalgOp.getNumLoops() <= getDim())

2306 return emitOpError("expected dim (")

2307 << getDim() << ") to be lower than the number of loops ("

2308 << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";

2309 return success();

2310 }

2311

2312 OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {

2313 auto linalgOp = dyn_cast_or_null((*this)->getParentOp());

2314

2315

2316

2317 if (!linalgOp)

2319

2320

2322 uint64_t dim = getDim();

2323 assert(dim < loopBounds.size() && "Dim is out of bounds");

2324 if (loopBounds[dim] == 1)

2326

2328 }

2329

2330

2331

2332 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"

2333

2334 #define GET_OP_CLASSES

2335 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"

2336

2337 #define GET_OP_CLASSES

2338 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"

2339 #define GET_OP_CLASSES

2340 #include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"

2341

2343 unsigned rank,

2345 if (maybeMap)

2346 return *maybeMap;

2347 if (rank == 0)

2350 }

2351

2356 res.reserve(num);

2357 for (unsigned i = 0; i < num; ++i)

2359 return res;

2360 }

2361

2364 auto rangeA = llvm::make_range(a.begin(), a.end());

2365 auto rangeB = llvm::make_range(b.begin(), b.end());

2366 auto concatRanges = llvm::concat(rangeA, rangeB);

2367 return llvm::to_vector<4>(concatRanges);

2368 }

2369

2371 if (auto memref = llvm::dyn_cast(t)) {

2372 ss << "view";

2373 for (auto size : memref.getShape())

2374 if (size < 0)

2375 ss << "sx";

2376 else

2377 ss << size << "x";

2379 return failure();

2380 if (auto as = memref.getMemorySpace()) {

2381 if (auto attr = llvm::dyn_cast(as))

2382 ss << "as" << attr.getInt();

2383 else

2384 return failure();

2385 }

2386 return success();

2387 }

2388 if (auto vec = llvm::dyn_cast(t)) {

2389 ss << "vector";

2390 llvm::interleave(

2391 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });

2393 return failure();

2394 return success();

2395 }

2397 ss << t;

2398 return success();

2399 }

2400 return failure();

2401 }

2402

2404 assert(isa(op));

2406 std::string fun = "";

2408 if (UnaryFnAttr ufa = llvm::dyn_cast(kv.getValue())) {

2409 fun = stringifyEnum(ufa.getValue()).str() + "_";

2410 } else if (BinaryFnAttr bfa = llvm::dyn_cast(kv.getValue())) {

2411 fun = stringifyEnum(bfa.getValue()).str() + "_";

2412 }

2413 }

2414 name.reserve(128);

2415 llvm::replace(name, '.', '_');

2416 llvm::raw_string_ostream ss(name);

2417 ss << "_" << fun;

2420 return std::string();

2421 ss << "_";

2422 }

2423 name.pop_back();

2424 return name;

2425 }

2426

2427

2428

2429

2430

2431 namespace {

2434

2435 LogicalResult matchAndRewrite(LinalgOp op,

2437 for (OpOperand &opOperand : op->getOpOperands()) {

2438

2439

2440

2441 auto mt = llvm::dyn_cast(opOperand.get().getType());

2442 if (!mt)

2443 continue;

2444 if (llvm::is_contained(op.getShape(&opOperand), 0)) {

2446 return success();

2447 }

2448 }

2449 return failure();

2450 }

2451 };

2452

2453

2454

2455 struct FoldTensorCastConsumerOp : public OpRewritePatterntensor::CastOp {

2457

2458 LogicalResult matchAndRewrite(tensor::CastOp castOp,

2461 return failure();

2462

2463 auto linalgOp = castOp.getSource().getDefiningOp();

2464 if (!linalgOp)

2465 return failure();

2466

2467

2468

2469

2470 if (castOp->getBlock() != linalgOp->getBlock())

2471 return failure();

2472

2475

2476 Location loc = linalgOp.getLoc();

2477 OpResult resultValue = llvm::cast(castOp.getSource());

2479 auto resultType =

2480 llvm::cast(castOp->getResult(0).getType());

2481

2482

2483

2484

2485

2486 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);

2487 Value newOperand =

2488 rewriter.createtensor::CastOp(loc, resultType, outOperand->get());

2491 linalgOp.getDpsInits().end());

2492 outputOperands[resultNumber] = newOperand;

2493 newOperands.append(outputOperands.begin(), outputOperands.end());

2494

2496 linalgOp->result_type_end());

2497 resultTypes[resultNumber] = resultType;

2498 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);

2499

2500

2501 Value castBack = rewriter.createtensor::CastOp(

2502 loc, resultValue.getType(), newOp->getResult(resultNumber));

2503

2505 results[resultNumber] = castBack;

2506 rewriter.replaceOp(linalgOp, results);

2508 return success();

2509 }

2510 };

2511

2512

2513

2516 for (OpOperand &opOperand : operands) {

2517 if (linalgOp.isScalar(&opOperand))

2518 continue;

2519 Value src = opOperand.get();

2520 auto sourceType = llvm::cast(src.getType());

2521 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);

2522

2523

2524

2525

2528 if (parentOp) {

2529 if (auto castOp = dyn_casttensor::CastOp(parentOp)) {

2530 Value castSource = castOp.getSource();

2531 auto castSourceType =

2532 llvm::dyn_cast(castSource.getType());

2533 if (castSourceType && castSourceType.hasStaticShape())

2534 sourceShape = castSourceType.getShape();

2535 }

2536 }

2537

2538

2539

2540 for (unsigned i = 0; i < sourceShape.size(); i++) {

2541 if (sourceType.isDynamicDim(i))

2542 continue;

2543 if (auto affineDimExpr = dyn_cast(sourceMap.getResult(i)))

2544 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);

2545 }

2546 }

2547 }

2548

2549

2550

2551

2552

2553

2554 static void createNewOperandWithStaticSizes(

2558 bool &changeNeeded) {

2559 Value src = opOperand->get();

2560 newOperands.push_back(src);

2561 if (linalgOp.isScalar(opOperand))

2562 return;

2563 auto sourceType = llvm::cast(src.getType());

2564 Type resultType = sourceType;

2565 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {

2566 resultTypes.push_back(resultType);

2567 return;

2568 }

2570 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);

2572

2573

2574 bool newOperandNeeded = false;

2575 for (unsigned i = 0; i < sourceShape.size(); i++) {

2576 int64_t dimShape = sourceShape[i];

2578 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {

2579 newShape.push_back(dimShape);

2580 continue;

2581 }

2582

2583

2584

2585 newShape.push_back(affineExprToSize[dimExpr]);

2586 newOperandNeeded = true;

2587 }

2589 sourceType.getEncoding());

2590 if (newOperandNeeded) {

2591 changeNeeded = true;

2592

2593

2594 Value newOperand = rewriter.createtensor::CastOp(loc, resultType, src);

2596 newOperands[index] = newOperand;

2597 }

2598 if (linalgOp.isDpsInit(opOperand))

2599 resultTypes.push_back(resultType);

2600 }

2601

2602

2603

2604

2607

2608 LogicalResult matchAndRewrite(LinalgOp linalgOp,

2610 if (!linalgOp.hasPureTensorSemantics())

2611 return failure();

2612

2613

2614 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {

2615 return !map.isProjectedPermutation();

2616 }))

2617 return failure();

2618

2619

2621 Location loc = linalgOp.getLoc();

2622

2623

2624

2625 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);

2626

2629

2630

2631

2632 bool changeNeeded = false;

2633 newOperands.reserve(linalgOp->getNumOperands());

2634 resultTypes.reserve(linalgOp.getNumDpsInits());

2635

2636

2637 for (OpOperand &opOperand : linalgOp->getOpOperands()) {

2638 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,

2639 affineExprToSize, linalgOp, newOperands,

2640 resultTypes, changeNeeded);

2641 }

2642

2643

2644

2645 if (!changeNeeded)

2646 return failure();

2647

2648

2649 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);

2652 for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {

2653 Value newResult = std::get<1>(it);

2654 Value oldResult = std::get<0>(it);

2657 replacements.push_back(

2658 (newType != oldType)

2659 ? rewriter.createtensor::CastOp(loc, oldType, newResult)

2660 : newResult);

2661 }

2662 rewriter.replaceOp(linalgOp, replacements);

2663 return success();

2664 }

2665 };

2666

2667 }

2668

2669

2670

2671

2672

2673

2674

2675

2677 ShapedType inputType = getInputOperandType();

2678 ShapedType outputType = getOutputOperandType();

2679

2683 return emitOpError("incompatible output shape");

2684

2685 int64_t inputRank = getInputOperandRank();

2686 int64_t dimension = getDimension();

2687 if ((dimension < 0) || (dimension >= inputRank))

2688 return emitOpError("incorrect dimension specified");

2689

2690 return success();

2691 }

2692

2694 int64_t operandRank = getInputOperandRank();

2697 Value zero = builder.createarith::ConstantIndexOp(loc, 0);

2698 Value one = builder.createarith::ConstantIndexOp(loc, 1);

2699 Value source = getInput();

2700 for (auto dim : llvm::seq<int64_t>(0, operandRank)) {

2701 loopBounds[dim].offset = zero;

2702 loopBounds[dim].size = getDimValue(builder, loc, source, dim);

2703 loopBounds[dim].stride = one;

2704 }

2705 return loopBounds;

2706 }

2707

2710 utils::IteratorType::parallel);

2711 iteratorTypes[getDimension()] = utils::IteratorType::reduction;

2712 return iteratorTypes;

2713 }

2714

2715 FailureOr

2719 int64_t rank = getInputOperandRank();

2724 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);

2725 if (!inputSlice) {

2726 return emitOpError("failed to compute input slice");

2727 }

2728 tiledOperands.emplace_back(inputSlice->getResult(0));

2730 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);

2731 if (!outputSlice) {

2732 return emitOpError("failed to compute output slice");

2733 }

2734 tiledOperands.emplace_back(outputSlice->getResult(0));

2735

2737 if (hasPureTensorSemantics())

2738 resultTypes.push_back(tiledOperands[1].getType());

2740 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

2741

2743 {tiledOp},

2746 }

2747

2752 if (resultNumber == 0) {

2753 resultOffsets.assign(offsets.begin(), offsets.end());

2754 resultSizes.assign(sizes.begin(), sizes.end());

2755 return success();

2756 }

2757 return failure();

2758 }

2759

2760

2763 }

2764

2765 LogicalResult

2769 Location loc = getOperation()->getLoc();

2771 auto inputShapedType = llvm::cast(getInputOperandType());

2772 auto outputShapedType = llvm::cast(getOutputOperandType());

2773 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {

2774 if (!outputShapedType.isDynamicDim(dim)) {

2775

2776 shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));

2777 } else {

2778

2781 }

2782 }

2783 reifiedReturnShapes.emplace_back(std::move(shapes));

2784 return success();

2785 }

2786

2787 void SoftmaxOp::getEffects(

2789 &effects) {

2790 for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {

2791 if (!llvm::isa(operand.getType()))

2792 continue;

2794 &getOperation()->getOpOperand(index), 0,

2795 true,

2797 }

2798

2799 for (OpOperand &operand : getDpsInitsMutable()) {

2800 if (!llvm::isa(operand.get().getType()))

2801 continue;

2803 true,

2806 true,

2808 }

2809 }

2810

2811

2812

2813

2814

2815

2816

2817

2818

2819

2820

2821

2822

2823

2824

2825

2826

2827

2828

2829

2830

2833 int64_t dim, bool allParallel = false) {

2835 utils::IteratorType::parallel);

2836 if (!allParallel)

2837 iteratorTypes[dim] = utils::IteratorType::reduction;

2841 for (int i = 0; i < inputRank; i++) {

2842 if (i != dim)

2844 }

2845 auto reductionMap =

2846 AffineMap::get(inputRank, 0, affineExprs, ctxt);

2848 return std::make_tuple(iteratorTypes, indexingMaps);

2849 }

2850

2851

2852

2853 template

2855 int64_t dim) {

2856 auto inputType = cast(input.getType());

2858 int64_t inputRank = inputShape.size();

2859 auto [iteratorTypes, indexingMaps] =

2861 assert(indexingMaps.size() == 2 &&

2862 "We should have two maps: 1 for the input, 1 for the output");

2863 assert(indexingMaps[0].isIdentity() && "input map should be identity");

2864

2865 auto genericOp = builder.createlinalg::GenericOp(

2866 loc, output.getType(), input, output, indexingMaps, iteratorTypes,

2868 Value result = b.create(loc, args[0], args[1]);

2869 b.createlinalg::YieldOp(loc, result);

2870 });

2872 }

2873

2874

2875

2876

2879 auto inputType = cast(input.getType());

2881 int64_t inputRank = inputShape.size();

2883 builder, inputRank, dim, true);

2884 assert(indexingMaps.size() == 2 && "We should have one map for each input");

2885 assert(indexingMaps[0].isIdentity() && "input map should be identity");

2886

2887 indexingMaps.push_back(indexingMaps[0]);

2888 auto genericOp = builder.createlinalg::GenericOp(

2891 Value diff = b.createarith::SubFOp(loc, args[0], args[1]);

2892 Value result = b.createmath::ExpOp(loc, diff);

2893 b.createlinalg::YieldOp(loc, result);

2894 });

2896 }

2897

2898

2899

2900

2901

2902

2904 Value denominator, Value output, int64_t dim) {

2905 auto inputType = cast(numerator.getType());

2907 int64_t inputRank = inputShape.size();

2909 builder, inputRank, dim, true);

2910 assert(indexingMaps.size() == 2 &&

2911 "We should have one map for each input (2)");

2912 assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");

2913

2914 indexingMaps.push_back(indexingMaps[0]);

2915 auto genericOp = builder.createlinalg::GenericOp(

2916 loc, numerator.getType(), ValueRange{numerator, denominator}, output,

2917 indexingMaps, iteratorTypes,

2919 Value result = b.createarith::DivFOp(loc, args[0], args[1]);

2920 b.createlinalg::YieldOp(loc, result);

2921 });

2923 }

2924

2925

2926

2927

2928

2929

2930

2931

2932

2933

2934

2935

2936

2937

2938

2939

2940

2941

2942

2943

2944 FailureOr<SmallVector> SoftmaxOp::decomposeOperation(OpBuilder &b) {

2948 Value input = getInput();

2949 ShapedType inputType = getInputOperandType();

2950 Type elementType = inputType.getElementType();

2951 int64_t reductionDim = getDimension();

2953 Value output = getOutput();

2954 dims.erase(dims.begin() + reductionDim);

2955

2956 Value outputReduce = b.createtensor::EmptyOp(loc, dims, elementType);

2958 elementType, b, loc,

2959 true);

2960 Value neutralForMaxFInit =

2961 b.createlinalg::FillOp(loc, Value{neutralForMaxF}, outputReduce)

2962 .result();

2964 reducearith::MaxNumFOp(b, loc, input, neutralForMaxFInit, reductionDim);

2965

2966

2968

2969

2971 b, loc, true);

2972 Value zeroInit =

2973 b.createlinalg::FillOp(loc, Value{zero}, outputReduce).result();

2974 Value denominator =

2975 reducearith::AddFOp(b, loc, numerator, zeroInit, reductionDim);

2976

2977

2979 buildDivOp(b, loc, numerator, denominator, output, reductionDim);

2981 }

2982

2983

2984

2985

2986

2988 auto filterType = cast(getFilter().getType());

2990 int64_t filterH = filterShape[getFilterHDim()];

2991 int64_t filterW = filterShape[getFilterWDim()];

2992 int64_t r = getR();

2993 int64_t m = getM();

2994

2995 if (filterH != r && filterH != 1)

2996 return emitOpError("expect filter height either equals to r or 1");

2997 if (filterW != r && filterW != 1)

2998 return emitOpError("expect filter width either equals to r or 1");

2999 if (filterH == 1 && filterW == 1)

3000 return emitOpError("expect either filter height or width equals to r");

3001

3003 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);

3004 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);

3005 expectedOutputShape.push_back(filterShape[getFilterCDim()]);

3006 expectedOutputShape.push_back(filterShape[getFilterFDim()]);

3007

3008 auto outputType = cast(getOutput().getType());

3011 return emitOpError("the output shape is not expected");

3012 }

3013 return success();

3014 }

3015

3017 WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {

3019 IntegerAttr zeroAttr = builder.getIndexAttr(0);

3020 IntegerAttr oneAttr = builder.getIndexAttr(1);

3021 Value filter = getFilter();

3022 int64_t filterRank = getFilterOperandRank();

3024 for (unsigned dim = 0; dim < filterRank; ++dim) {

3025 loopBounds[dim].offset = zeroAttr;

3026 loopBounds[dim].size = getDimValue(builder, loc, filter, dim);

3027 loopBounds[dim].stride = oneAttr;

3028 }

3029 return loopBounds;

3030 }

3031

3033 WinogradFilterTransformOp::getLoopIteratorTypes() {

3034 int64_t filterRank = getFilterOperandRank();

3036 utils::IteratorType::parallel);

3037 return iteratorTypes;

3038 }

3039

3045 ShapedType filterType = getFilterOperandType();

3047 int64_t filterH = filterShape[getFilterHDim()];

3048 int64_t filterW = filterShape[getFilterWDim()];

3049 int64_t m = getM();

3050 int64_t r = getR();

3051 int64_t alpha = m + r - 1;

3052 int64_t alphaH = filterH != 1 ? alpha : 1;

3053 int64_t alphaW = filterW != 1 ? alpha : 1;

3056

3057 resultOffsets.append(

3058 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});

3059 resultSizes.append(

3060 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});

3061

3062 return success();

3063 }

3064

3065

3066

3067

3068

3069

3070

3076 ShapedType filterType = getFilterOperandType();

3078 int64_t filterH = filterShape[getFilterHDim()];

3079 int64_t filterW = filterShape[getFilterWDim()];

3084

3085 sliceOffsets.append(

3086 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});

3087 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,

3088 sizes[getFilterCDim()]});

3089 int64_t filterRank = getFilterOperandRank();

3092 auto filterSlice = builder.createtensor::ExtractSliceOp(

3093 loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);

3094 tiledOperands.emplace_back(filterSlice);

3095

3098 resultSizes)))

3099 return failure();

3100

3101 int64_t outputRank = getOutputOperandRank();

3103 auto outputSlice = builder.createtensor::ExtractSliceOp(

3104 loc, getOutput(), resultOffsets, resultSizes, outputStrides);

3105 tiledOperands.emplace_back(outputSlice);

3106

3108 resultTypes.push_back(tiledOperands[1].getType());

3110 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

3111

3113 {tiledOp},

3116 }

3117

3118

3119

3120

3121

3123 auto inputType = cast(getInput().getType());

3125 int64_t inputH = inputShape[getInputHDim()];

3126 int64_t inputW = inputShape[getInputWDim()];

3127 int m = getM();

3128 int r = getR();

3129 int64_t tileSize = m + r - 1;

3130

3131 auto outputType = cast(getOutput().getType());

3133 bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;

3134 bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;

3135

3137 if (ShapedType::isDynamic(inputH)) {

3138 expectedOutputShape[getOutputAlphaHDim()] = tileSize;

3139 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;

3140 } else {

3141 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;

3142 expectedOutputShape[getOutputTileHDim()] =

3143 leftTransform ? (inputH - (r - 1)) / m : inputH;

3144 }

3145 if (ShapedType::isDynamic(inputW)) {

3146 expectedOutputShape[getOutputAlphaWDim()] = tileSize;

3147 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;

3148 } else {

3149 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;

3150 expectedOutputShape[getOutputTileWDim()] =

3151 rightTransform ? (inputW - (r - 1)) / m : inputW;

3152 }

3153 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];

3154 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];

3155

3157 return emitOpError("the output shape is not expected");

3158 }

3159 return success();

3160 }

3161

3163 WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {

3165 IntegerAttr zeroAttr = builder.getIndexAttr(0);

3166 IntegerAttr oneAttr = builder.getIndexAttr(1);

3167 Value output = getOutput();

3168 int64_t outputRank = getOutputOperandRank();

3170 for (unsigned dim = 0; dim < outputRank; ++dim) {

3171 loopBounds[dim].offset = zeroAttr;

3172

3173 loopBounds[dim].size = getDimValue(builder, loc, output, dim);

3174 loopBounds[dim].stride = oneAttr;

3175 }

3176 return loopBounds;

3177 }

3178

3180 WinogradInputTransformOp::getLoopIteratorTypes() {

3181 int64_t outputRank = getOutputOperandRank();

3183 utils::IteratorType::parallel);

3184 return iteratorTypes;

3185 }

3186

3192 ShapedType outputType = getOutputOperandType();

3194 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];

3195 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];

3196

3197 int64_t m = getM();

3198 int64_t r = getR();

3199 int64_t alpha = m + r - 1;

3200 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;

3201 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;

3202

3205

3206 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],

3207 offsets[getOutputTileWDim()], offsets[getOutputNDim()],

3208 offsets[getOutputCDim()]});

3209 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],

3210 sizes[getOutputTileWDim()], sizes[getOutputNDim()],

3211 sizes[getOutputCDim()]});

3212

3213 return success();

3214 }

3215

3216

3217

3218

3219

3220

3221

3222 FailureOr

3227 int64_t m = getM();

3228 int64_t r = getR();

3229

3230 ShapedType outputType = getOutputOperandType();

3232 int64_t alphaH = outputShape[getOutputAlphaHDim()];

3233 int64_t alphaW = outputShape[getOutputAlphaWDim()];

3234

3237 auto identityAffineMap =

3239 auto offsetAffineMap =

3242 builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),

3243 offsets[getOutputTileHDim()]);

3245 builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),

3246 offsets[getOutputTileWDim()]);

3248 1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);

3250 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);

3252 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);

3253

3256

3259 sliceOffsets.append(

3260 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});

3265 sliceSizes.append(

3266 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});

3267 int64_t inputRank = getInputOperandRank();

3269 auto inputSlice = builder.createtensor::ExtractSliceOp(

3270 loc, getInput(), sliceOffsets, sliceSizes, inputStrides);

3271 tiledOperands.emplace_back(inputSlice);

3272

3275 resultSizes)))

3276 return failure();

3277

3278 int64_t outputRank = getOutputOperandRank();

3280 auto outputSlice = builder.createtensor::ExtractSliceOp(

3281 loc, getOutput(), resultOffsets, resultSizes, outputStrides);

3282 tiledOperands.emplace_back(outputSlice);

3283

3285 resultTypes.push_back(tiledOperands[1].getType());

3287 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

3288

3290 {tiledOp},

3293 }

3294

3295

3296

3297

3298

3300 auto valueType = cast(getValue().getType());

3302 int64_t valueH = valueShape[getValueAlphaHDim()];

3303 int64_t valueW = valueShape[getValueAlphaWDim()];

3304 int64_t valueTileH = valueShape[getValueTileHDim()];

3305 int64_t valueTileW = valueShape[getValueTileWDim()];

3306 int m = getM();

3307 int r = getR();

3308 bool leftTransform = valueH != 1;

3309 bool rightTransform = valueW != 1;

3310

3311 int64_t outputRank = getOutputOperandRank();

3313 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {

3314 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;

3315 } else {

3316 if (valueH != (leftTransform ? m + r - 1 : 1))

3317 return emitOpError("expect input height equals to input tile size");

3318 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;

3319 }

3320 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {

3321 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;

3322 } else {

3323 if (valueW != (rightTransform ? m + r - 1 : 1))

3324 return emitOpError("expect input width equals to input tile size");

3325 expectedOutputShape[getOutputWDim()] =

3326 (rightTransform ? m : 1) * valueTileW;

3327 }

3328 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];

3329 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];

3330

3331 auto outputType = cast(getOutput().getType());

3334 return emitOpError("the output shape is not expected");

3335 }

3336 return success();

3337 }

3338

3340 WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {

3342 IntegerAttr zeroAttr = builder.getIndexAttr(0);

3343 IntegerAttr oneAttr = builder.getIndexAttr(1);

3344 Value value = getValue();

3345 int64_t valueRank = getValueOperandRank();

3347 for (unsigned dim = 0; dim < valueRank; ++dim) {

3348 loopBounds[dim].offset = zeroAttr;

3349

3350 loopBounds[dim].size = getDimValue(builder, loc, value, dim);

3351 loopBounds[dim].stride = oneAttr;

3352 }

3353 return loopBounds;

3354 }

3355

3357 WinogradOutputTransformOp::getLoopIteratorTypes() {

3358 int64_t valueRank = getValueOperandRank();

3360 utils::IteratorType::parallel);

3361 return iteratorTypes;

3362 }

3363

3368 int64_t m = getM();

3369

3372 auto identityAffineMap =

3374 auto affineMap =

3376

3377 ShapedType valueType = getValueOperandType();

3379 int64_t valueH = valueShape[0];

3380 int64_t valueW = valueShape[1];

3382 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),

3383 offsets[getValueTileHDim()]);

3385 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),

3386 offsets[getValueTileWDim()]);

3388 builder, loc, affineMap, sizes[getValueTileHDim()]);

3390 builder, loc, affineMap, sizes[getValueTileWDim()]);

3391

3399

3400 resultOffsets.append(

3401 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});

3402 resultSizes.append(

3403 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});

3404 return success();

3405 }

3406

3407

3408

3409

3410

3411

3412

3421

3422 ShapedType valueType = getValueOperandType();

3424 int64_t alphaH = valueShape[getValueAlphaHDim()];

3425 int64_t alphaW = valueShape[getValueAlphaWDim()];

3428

3429 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],

3430 offsets[getValueTileWDim()], offsets[getValueNDim()],

3431 offsets[getValueFDim()]});

3432 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],

3433 sizes[getValueTileWDim()], sizes[getValueNDim()],

3434 sizes[getValueFDim()]});

3435 int64_t valueRank = getValueOperandRank();

3437 auto valueSlice = builder.createtensor::ExtractSliceOp(

3438 loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);

3439 tiledOperands.emplace_back(valueSlice);

3440

3443 resultSizes)))

3444 return failure();

3445

3446 int64_t outputRank = getOutputOperandRank();

3448 auto outputSlice = builder.createtensor::ExtractSliceOp(

3449 loc, getOutput(), resultOffsets, resultSizes, strides);

3450 tiledOperands.emplace_back(outputSlice);

3451

3453 resultTypes.push_back(tiledOperands[1].getType());

3455 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

3456

3458 {tiledOp},

3461 }

3462

3463

3464

3465

3466

3467

3468

3470 auto explicitRange = subMap.getResults();

3471 auto defaultRange = fullMap.getResults();

3472 DenseSet explicitSet(explicitRange.begin(), explicitRange.end());

3474 llvm::set_union(explicitSet, defaultSet);

3475 return explicitSet == defaultSet;

3476 }

3477

3478

3479

3480

3481

3482

3483

3486 }

3487

3488

3489

3490

3492 unsigned opIndex) {

3495 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());

3496

3497 auto opIndexingMap = opIndexingMaps[opIndex];

3498 auto defaultIndexingMap = defaultIndexingMaps[opIndex];

3499

3501 return matmulOp->emitOpError()

3502 << "Unexpected dim expression in map result.";

3503

3504 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {

3505 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {

3506 return matmulOp->emitOpError()

3507 << "Invalid broadcast requested, should be (d2).";

3508 }

3509 return success();

3510 }

3511 return success();

3512 }

3513

3514

3515

3516 template

3519 AffineMap defaultIndexingMap, bool isLHS) {

3520 assert((isa(batchVariantMatmulOp) ||

3521 isa(batchVariantMatmulOp)) &&

3522 "Expected BatchMatmulOp or BatchReduceMatmulOp");

3523

3525 return batchVariantMatmulOp->emitOpError()

3526 << "Unexpected result dim expression (outside the set of default "

3527 "result dims).";

3528

3529

3531 return batchVariantMatmulOp->emitOpError()

3532 << "no. of result dim expressions exceeds 3.";

3533

3534 auto hasValidBatchDim = [](AffineMap map) {

3537 };

3538

3539

3540 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {

3541 if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))

3542 return batchVariantMatmulOp->emitOpError()

3543 << "Invalid broadcast requested.";

3544 } else if (!hasValidBatchDim(opIndexingMap)) {

3545 return batchVariantMatmulOp->emitOpError()

3546 << "Invalid batch dimension expression.";

3547 }

3548 return success();

3549 }

3550

3551

3552

3553

3554 template

3557 assert((isa(batchVariantMatmulOp) ||

3558 isa(batchVariantMatmulOp)) &&

3559 "Expected BatchMatmulOp or BatchReduceMatmulOp");

3560 if (isa(batchVariantMatmulOp) &&

3562

3563 return batchVariantMatmulOp->emitOpError()

3564 << "expects 3 dims, but got (" << opIndexingMap.getNumResults()

3565 << ").";

3566 }

3567 if (isa(batchVariantMatmulOp) &&

3569 return batchVariantMatmulOp->emitOpError()

3570 << "expects 2 dims, but got (" << opIndexingMap.getNumResults()

3571 << ").";

3572 }

3573

3574 auto areValidOutputResultDim = [&](AffineMap outputMap) {

3575 return isa(batchVariantMatmulOp)

3576 ? outputMap.getResult(0).isFunctionOfDim(0) &&

3577 outputMap.getResult(1).isFunctionOfDim(1) &&

3578 outputMap.getResult(2).isFunctionOfDim(2)

3579 : outputMap.getResult(0).isFunctionOfDim(1) &&

3580 outputMap.getResult(1).isFunctionOfDim(2);

3581 };

3582

3583 if (!areValidOutputResultDim(opIndexingMap)) {

3584 return batchVariantMatmulOp->emitOpError()

3585 << "Invalid output map result dimension.";

3586 }

3587

3588 return success();

3589 }

3590

3591

3592

3593

3594 template

3595 static LogicalResult

3597 unsigned opIndex) {

3599 batchVariantMatmulOp.getIndexingMapsArray();

3601 batchVariantMatmulOp.getDefaultIndexingMaps(

3602 batchVariantMatmulOp->getContext());

3603

3604 if (opIndexingMaps.size() != 3)

3605 return batchVariantMatmulOp->emitOpError()

3606 << "Indexing_map attribute must have 3 affine maps.";

3607

3608 auto opIndexingMap = opIndexingMaps[opIndex];

3609 auto defaultIndexingMap = defaultIndexingMaps[opIndex];

3610

3611 if (opIndex == 2 &&

3612 failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))

3613 return failure();

3614

3615 if (opIndex != 2 &&

3616 failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,

3617 defaultIndexingMap, opIndex == 0)))

3618 return failure();

3619

3620 return success();

3621 }

3622

3623 namespace mlir {

3624 namespace linalg {

3625

3626

3627

3628

3629

3630

3634 bindDims(context, d0, d1, d2);

3635 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));

3636 indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));

3637 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));

3638 return indexingMaps;

3639 }

3640

3643 utils::IteratorType::parallel,

3644 utils::IteratorType::reduction};

3645 }

3646

3647 unsigned MatmulOp::getNumRegionArgs() { return 3; }

3648

3649 std::string MatmulOp::getLibraryCallName() {

3651 }

3652

3653 bool MatmulOp::hasDynamicIndexingMaps() { return true; }

3654

3655

3656

3657 bool MatmulOp::hasUserDefinedMaps() {

3659 getDefaultIndexingMaps(this->getContext());

3661 return defaultMaps != explicitMaps;

3662 }

3663

3664

3665

3669 "MatmulOp regionBuilder expects 3 (>=0) args");

3670 RegionBuilderHelper helper(b, block);

3672

3673 TypeFn castVal = TypeFn::cast_signed;

3674 const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {

3675 return attr.getName() == "cast";

3676 });

3677 if (castIter != attrs.end()) {

3678 if (auto attr = llvm::dyn_cast(castIter->getValue()))

3680 }

3681

3686 Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);

3688 helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);

3689 yields.push_back(value4);

3690 helper.yieldOutputs(yields);

3691 }

3692

3693

3694

3695

3696

3697

3698

3699

3700 bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {

3701 assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");

3703

3705 }

3706

3709 return ArrayAttr{

3710 nullptr};

3711

3712 ArrayAttr arrayAttr;

3714 return failure();

3715

3716 if (llvm::any_of(arrayAttr,

3717 [](auto elt) { return !dyn_cast(elt); }))

3719 << "element of indexing_maps array is not an affine_map";

3720

3721 return arrayAttr;

3722 }

3723

3726 if (failed(indexingMapsAttr))

3727 return failure();

3728

3729 if (*indexingMapsAttr == nullptr) {

3730 auto indexingMapAttrs = llvm::map_to_vector(

3731 MatmulOp::getDefaultIndexingMaps(parser.getContext()),

3732 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });

3734 }

3735

3736 result.addAttribute("indexing_maps", *indexingMapsAttr);

3738 MatmulOp::getRegionBuilder());

3739 }

3740

3743 MatmulOp::getDefaultIndexingMaps(getContext()),

3745 if (!llvm::equal(getIndexingMaps(), indexingMaps))

3746 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());

3747

3748 std::array<StringRef, 3> elidedAttrs = {

3749 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};

3751 elidedAttrs);

3752 }

3753

3754

3756

3757 if (!hasUserDefinedMaps())

3758 return success();

3759

3760 for (unsigned opIndex = 0; opIndex < 2; opIndex++) {

3762 return failure();

3763 }

3764 return success();

3765 }

3766

3769 }

3770

3771 void MatmulOp::getEffects(

3773 &effects) {

3774 if (hasPureTensorSemantics())

3775 return;

3777 }

3778

3781 }

3782

3783

3784

3785

3786

3788 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();

3789

3790

3791

3792

3793

3794

3795

3796

3798 for (auto result : outAffineMap.getResults()) {

3799 auto dimExpr = dyn_cast(result);

3800 assert(dimExpr && "affine_map is a projected permutation");

3801 dimsInOutput[dimExpr.getPosition()] = true;

3802 }

3803

3805 for (auto dimOccursInOutput : dimsInOutput)

3806 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel

3807 : utils::IteratorType::reduction);

3808

3809 return iteratorTypes;

3810 }

3811

3812 unsigned ContractOp::getNumRegionArgs() { return 3; }

3813

3814

3818 "ContractOp regionBuilder expects 3 args");

3819 RegionBuilderHelper helper(b, block);

3820

3821 TypeFn castSignedness = TypeFn::cast_signed;

3822 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {

3823 return attr.getName() == "cast";

3824 });

3825 if (castIter != attrs.end()) {

3826 if (auto attr = llvm::dyn_cast(castIter->getValue()))

3827 castSignedness = attr.getValue();

3828 }

3829

3830

3832 Value lhsAtOutType =

3833 helper.buildTypeFn(castSignedness, outType, block.getArgument(0));

3834 Value rhsAtOutType =

3835 helper.buildTypeFn(castSignedness, outType, block.getArgument(1));

3836 Value productAtOutType =

3837 helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);

3838 Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),

3839 productAtOutType);

3840 helper.yieldOutputs({result});

3841 }

3842

3845 if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)

3847 "expected 'indexing_maps' attribute");

3848 result.addAttribute("indexing_maps", *indexingMapsAttr);

3849

3851 regionBuilder);

3852 }

3853

3855 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());

3857 p, getOperation(), getInputs(), getOutputs(),

3858 {"indexing_maps", "operandSegmentSizes"});

3859 }

3860

3862 int iterationSpaceDims = -1;

3863

3864

3865

3866

3869

3870

3871 auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,

3872 bool isInput) -> LogicalResult {

3873

3875 return emitError("provided affine_map is not a projected permutation");

3876

3877

3878 if (auto shapedType = dyn_cast(operandType)) {

3879 if (affineMap.getNumResults() != shapedType.getRank())

3880 return emitError("ranks of shaped operand and results of corresponding "

3881 "affine_map differ");

3883 return emitError("affine_map specifies shaped access while operand has "

3884 "non-shaped type");

3885 }

3886

3887

3888 if (iterationSpaceDims == -1) {

3889 iterationSpaceDims = affineMap.getNumDims();

3892 } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {

3893 return emitError("iteration spaces of provided affine_maps differ");

3894 }

3895

3896

3898 auto affineDimExpr = dyn_cast(affineExpr);

3899 if (!affineDimExpr)

3900 llvm_unreachable("affine_map is a projected permutation");

3901

3902 if (isInput)

3903 inOccurrences[affineDimExpr.getPosition()] += 1;

3904 else

3905 outOccurrences[affineDimExpr.getPosition()] += 1;

3906 }

3907

3908 return success();

3909 };

3910

3911 for (auto &&[affineMap, operandType, isInput] :

3912 llvm::zip(getIndexingMapsArray(), getOperandTypes(),

3914 if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))

3915 return failure();

3916 }

3917

3918 bool hasContractingDim = false;

3919 for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {

3920 size_t inOccCount = inOccurrences[dimIndex];

3921 size_t outOccCount = outOccurrences[dimIndex];

3922

3923

3924 hasContractingDim |= inOccCount == 2 && outOccCount == 0;

3925

3926 if (inOccCount == 0 && outOccCount == 0)

3927 return emitError() << "iteration space dim at index " << dimIndex

3928 << " not used to access any operand";

3929

3930

3931

3932

3933

3934

3935

3936

3937

3938

3939 if (inOccCount == 1 && outOccCount != 1)

3941 << "iteration space dim at index " << dimIndex

3942 << " is neither a contracting dim nor of parallel iteration type";

3943 }

3944

3945 if (!hasContractingDim)

3946 return emitError("'indexing_maps' do not specify a contracting dimension");

3947

3948 return success();

3949 }

3950

3953 }

3954

3955 void ContractOp::getEffects(

3957 &effects) {

3958 if (hasPureTensorSemantics())

3959 return;

3961 }

3962

3965 }

3966

3967

3968

3969

3971 BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {

3974 bindDims(context, d0, d1, d2, d3);

3975 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));

3976 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));

3977 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));

3978 return indexingMaps;

3979 }

3980

3983 utils::IteratorType::parallel, utils::IteratorType::parallel,

3984 utils::IteratorType::parallel, utils::IteratorType::reduction};

3985 }

3986

3987 unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }

3988

3989 std::string BatchMatmulOp::getLibraryCallName() {

3991 }

3992

3993

3994

3995 bool BatchMatmulOp::hasUserDefinedMaps() {

3997 getDefaultIndexingMaps(this->getContext());

3999 return defaultMaps != explicitMaps;

4000 }

4001

4002

4003

4004

4005

4006

4007

4008

4009 bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {

4011 "Expected less than 3 result dim expr.");

4012 bool isValid = false;

4013 enum Indices { batchPos, mPos, nPos, kPos };

4020 isValid =

4027 }

4028 return isValid;

4029 }

4030

4034 "BatchMatmulOp regionBuilder expects 3 (>=0) args");

4035 RegionBuilderHelper helper(b, block);

4037

4038 TypeFn castVal = TypeFn::cast_signed;

4039 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {

4040 return attr.getName() == "cast";

4041 });

4042 if (castIter != attrs.end()) {

4043 if (auto attr = llvm::dyn_cast(castIter->getValue()))

4045 }

4046

4048 Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));

4049 Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));

4050 Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);

4052 helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);

4053 yields.push_back(addVal);

4054 helper.yieldOutputs(yields);

4055 }

4056

4062 return failure();

4063

4065 return failure();

4066

4067 do {

4069 return failure();

4070 if (!isa(mapAttr)) {

4072 "expected affine map attribute");

4073 }

4074 indexingMapsAttr.push_back(mapAttr);

4075

4077 break;

4078 } while (true);

4079

4081 return failure();

4082 }

4083

4084 if (indexingMapsAttr.empty()) {

4085 indexingMapsAttr = llvm::map_to_vector(

4086 BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),

4087 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });

4088 }

4091

4093 BatchMatmulOp::getNumRegionArgs(),

4094 BatchMatmulOp::getRegionBuilder());

4095 }

4096

4099 BatchMatmulOp::getDefaultIndexingMaps(getContext()),

4101 if (!llvm::equal(getIndexingMaps(), indexingMaps))

4102 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());

4103

4104 std::array<StringRef, 3> elidedAttrs = {

4105 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};

4107 elidedAttrs);

4108 }

4109

4110

4112

4113

4114 if (!hasUserDefinedMaps())

4115 return success();

4116

4117 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {

4119 return failure();

4120 }

4121 return success();

4122 }

4123

4124 LogicalResult BatchMatmulOp::fold(FoldAdaptor,

4127 }

4128

4129 void BatchMatmulOp::getEffects(

4131 &effects) {

4132 if (hasPureTensorSemantics())

4133 return;

4135 }

4136

4139 }

4140

4141

4142

4143

4144

4145 namespace {

4146 struct ArityGroupAndKind {

4147

4149

4150

4151 union Kind {

4156 };

4157

4158 unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {

4159 return static_cast<unsigned>(arityGroup);

4160 }

4161 }

4162

4164 constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);

4165 constexpr int lastBinary =

4166 static_cast<int>(ElementwiseCaseLimits::LastBinary);

4167 constexpr int lastTernary =

4168 static_cast<int>(ElementwiseCaseLimits::LastTernary);

4169

4170 int val = static_cast<int>(kind);

4171 ArityGroupAndKind result;

4172

4173 if (val < lastUnary) {

4174 result.arityGroup = ElementwiseArityGroup::Unary;

4175 result.kind.unaryFn = static_cast<UnaryFn>(val);

4176 return result;

4177 }

4178 if (val < lastBinary) {

4179 result.arityGroup = ElementwiseArityGroup::Binary;

4180 result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);

4181 return result;

4182 }

4183 if (val >= lastTernary) {

4184 llvm_unreachable("unhandled ElementwiseFn");

4185 }

4186 result.arityGroup = ElementwiseArityGroup::Ternary;

4187 result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);

4188 return result;

4189 }

4190

4192 auto rank = getResultRank();

4194 }

4195

4197 ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,

4201 }

4202

4204

4206 mlir::linalg::ElementwiseKind elemwiseKindVal;

4208 return failure();

4209

4211 auto elemwiseKindAttr = dyn_cast(attr);

4212 if (!elemwiseKindAttr)

4214 "expected ElementwiseKind attribute");

4215 elemwiseKindVal = elemwiseKindAttr.getValue();

4216 } else {

4218 "expected operation 'kind' attribute");

4219 }

4222

4223

4228 return failure();

4230 return failure();

4231 do {

4233 return failure();

4234 if (!isa(mapAttr))

4236 "expected affine map attribute");

4237 indexingMapsAttr.push_back(mapAttr);

4239 break;

4240 } while (true);

4242 return failure();

4243 }

4244

4245

4247 int numRegionArgs =

4248 getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 ;

4250 ElementwiseOp::getRegionBuilder())) {

4252 "unable to parse elemwise op");

4253 }

4254

4255

4256 if (indexingMapsAttr.empty()) {

4257

4258

4259 auto resultType = result.operands[result.operands.size() - 1].getType();

4260 auto shapedType = llvm::dyn_cast(resultType);

4261 if (!shapedType)

4263 "return type needs to be shaped type");

4264 auto numDims = shapedType.getRank();

4265 indexingMapsAttr = llvm::map_to_vector(

4266 ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,

4268 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });

4269 }

4270

4273 return success();

4274 }

4275

4277 p << " kind=";

4280 "indexing_maps"};

4281 unsigned arity =

4283 unsigned numDims = getResultRank();

4284

4286 ElementwiseOp::getDefaultIndexingMaps(arity + 1 , numDims,

4289

4290 if (!llvm::equal(getIndexingMaps(), indexingMaps))

4291 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());

4292

4294 elidedAttrs);

4295 }

4296

4298

4299

4300

4301 return success();

4302 }

4303

4304

4305

4308 ElementwiseKind elemwiseKind;

4309 for (auto attr : attrs) {

4310 if (attr.getName() == b.getStringAttr("kind")) {

4311 auto kindAttr = dyn_cast(attr.getValue());

4312 assert(kindAttr && "op kind attribute incorrectly set");

4313 elemwiseKind = kindAttr.getValue();

4314 break;

4315 }

4316 }

4317

4319 auto arityGroup = groupAndKind.arityGroup;

4320 auto kind = groupAndKind.kind;

4322 getArityGroupAsUInt(arityGroup) + 1

4323 && "Elementwise regionBuilder number of block args mismatch");

4324

4325 RegionBuilderHelper helper(b, block);

4328

4329 if (arityGroup == ElementwiseArityGroup::Unary) {

4330 result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));

4331

4332 } else if (arityGroup == ElementwiseArityGroup::Binary) {

4333 result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),

4335

4336 } else if (arityGroup == ElementwiseArityGroup::Ternary) {

4337 result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),

4339

4340 } else {

4341 assert(false && "found unhandled category in elemwise");

4342 }

4343

4344 yields.push_back(result);

4345 helper.yieldOutputs(yields);

4346 }

4347

4348 LogicalResult ElementwiseOp::fold(FoldAdaptor,

4351 }

4352

4353 void ElementwiseOp::getEffects(

4355 &effects) {

4356 if (hasPureTensorSemantics())

4357 return;

4359 }

4360

4363 }

4364

4365

4366

4367

4368

4369

4370

4371

4372

4373

4374

4379 for (auto it : llvm::zip(cast(newPackedTy)

4381 .take_back(mixedTiles.size()),

4382 mixedTiles)) {

4383 int64_t shape = std::get<0>(it);

4384 if (shape == ShapedType::kDynamic) {

4385 newMixedTileSizes.push_back(std::get<1>(it));

4386 continue;

4387 }

4388

4389

4390

4392 if (Attribute attr = llvm::dyn_cast_if_present(tile)) {

4393

4394 newMixedTileSizes.push_back(tile);

4395 } else {

4397 "tile size and dim size don't match!");

4398 newMixedTileSizes.push_back(

4400 }

4401 }

4402

4403 return newMixedTileSizes;

4404 }

4405

4406 template

4407 static LogicalResult

4410 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,

4411 "applies to only pack or unpack operations");

4412 int64_t destRank = op.getDestRank();

4414 reifiedReturnShapes[0] =

4416 return success();

4417 }

4418

4419 template

4421 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,

4422 "applies to only pack or unpack operations");

4426 assert(tiles.size() == dimsToTile.size() &&

4427 "tiles must match indices of dimension to block");

4428

4429 for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))

4430 dimAndTileMapping[dimsToTile[i]] = tiles[i];

4431 return dimAndTileMapping;

4432 }

4433

4434 template

4436 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,

4437 "applies to only pack or unpack operations");

4440 unsigned dynamicValIndex = 0;

4441 for (int64_t staticTile : op.getStaticInnerTiles()) {

4442 if (!ShapedType::isDynamic(staticTile))

4443 mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));

4444 else

4445 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);

4446 }

4447 return mixedInnerTiles;

4448 }

4449

4450 template

4452 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,

4453 "applies to only pack or unpack operations");

4457 return staticTiles;

4458 }

4459

4460

4461

4462

4463

4465 size_t rank) {

4466 size_t dimsPosSize = dimsPos.size();

4467 if (dimsPosSize > rank)

4468 return true;

4470 if (dimsPosSize != uniqued.size())

4471 return true;

4472 return llvm::any_of(dimsPos, [rank](int64_t dimPos) {

4473 return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);

4474 });

4475 }

4476

4477

4478

4481 assert(

4482 sourceShape.size() == limitShape.size() &&

4483 "expected source shape rank, and limit of the shape to have same rank");

4484 return llvm::all_of(

4485 llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {

4486 int64_t sourceExtent = std::get<0>(it);

4487 int64_t limit = std::get<1>(it);

4488 return ShapedType::isDynamic(sourceExtent) ||

4489 ShapedType::isDynamic(limit) || sourceExtent <= limit;

4490 });

4491 }

4492

4493 template

4495 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,

4496 "applies to only pack or unpack operations");

4497 Operation *op = packOrUnPack.getOperation();

4498

4499

4502 };

4503

4504

4506 if (hasZeros(mixedTiles))

4507 return op->emitError("invalid zero tile factor");

4508

4509

4510 RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)

4511 ? packOrUnPack.getSourceType()

4512 : packOrUnPack.getDestType();

4513 size_t unpackedRank = unpackedType.getRank();

4515 ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();

4517 return op->emitError("invalid inner_dims_pos vector");

4519 return op->emitError("invalid outer_dims_perm vector");

4520 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)

4521 return op->emitError("outer_dims_perm must be a permutation or empty");

4522

4523

4524

4525 if (mixedTiles.size() > unpackedRank) {

4526 return op->emitError("tiling factors must be less than or equal to the "

4527 "input rank for pack or output rank for unpack");

4528 }

4529 if (mixedTiles.size() != innerDimsPos.size()) {

4531 "tiling factors must equal the number of dimensions to tile");

4532 }

4533

4534 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)

4535 ? packOrUnPack.getDestType()

4536 : packOrUnPack.getSourceType();

4537 size_t packedRank = packedType.getRank();

4538

4539 size_t expectedPackedRank = unpackedRank + mixedTiles.size();

4540 if (expectedPackedRank != packedRank) {

4542 "packed rank != (unpacked rank + num tiling factors), got ")

4543 << packedRank << " != " << expectedPackedRank;

4544 }

4545

4546

4547

4548

4549 RankedTensorType expectedPackedType = PackOp::inferPackedType(

4550 unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);

4551 if (areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {

4552 return op->emitError("the shape of output is not large enough to hold the "

4553 "packed data. Expected at least ")

4554 << expectedPackedType << ", got " << packedType;

4555 }

4556 if (!llvm::all_of(

4557 llvm::zip(packedType.getShape().take_back(mixedTiles.size()),

4558 mixedTiles),

4559 [](std::tuple<int64_t, OpFoldResult> it) {

4560 int64_t shape = std::get<0>(it);

4561 if (Attribute attr =

4562 llvm::dyn_cast_if_present(std::get<1>(it))) {

4563 IntegerAttr intAttr = dyn_cast_or_null(attr);

4564 int64_t staticTileSize = intAttr.getValue().getSExtValue();

4565 return shape == staticTileSize;

4566 }

4567 return ShapedType::isDynamic(shape);

4568 })) {

4569 return op->emitError("mismatch in inner tile sizes specified and shaped of "

4570 "tiled dimension in the packed type");

4571 }

4572 return success();

4573 }

4574

4575 namespace {

4576

4577

4578

4579

4580

4581

4582 struct PackOrUnPackTransposeResult {

4586 };

4587 }

4588

4589 template

4590 static PackOrUnPackTransposeResult

4594 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,

4595 "applies to only pack or unpack operations");

4596 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&

4597 "some permutation must be non-empty");

4598 PackOrUnPackTransposeResult metadata;

4599 metadata.innerDimsPos =

4601 metadata.innerTiles =

4603 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value

4604 ? packOrUnPackOp.getSourceRank()

4605 : packOrUnPackOp.getDestRank();

4606 metadata.outerDimsPerm =

4607 packOrUnPackOp.getOuterDimsPerm().empty()

4608 ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))

4610 if (!innerPermutation.empty()) {

4611 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&

4613 "invalid inner permutation");

4616 }

4617 if (!outerPermutation.empty()) {

4618 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&

4620 "invalid outer permutation");

4622 }

4623 return metadata;

4624 }

4625

4626

4627

4628

4629

4630 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {

4631 setNameFn(getResult(), "pack");

4632 }

4633

4637 std::optional paddingValue,

4640 "number of tile sizes specified must match the specified number of "

4641 "original dimensions to be tiled");

4645 build(builder, state, dest.getType(), source, dest,

4646 paddingValue ? *paddingValue : nullptr,

4651 }

4652

4653 LogicalResult

4657 }

4658

4661 }

4662

4665 }

4666

4669 }

4670

4672 ShapedType inputType = getSourceType();

4673 int64_t inputRank = inputType.getRank();

4674 return getDestType().getShape().take_front(inputRank);

4675 }

4676

4679 auto packedShape = getDestType().getShape();

4681

4683 res.push_back(packedShape[index]);

4684

4685 return res;

4686 }

4687

4694 outputShape.take_front(inputShape.size()));

4696 assert(outerDimsPerm.size() == outputTileSizes.size() &&

4697 "expected output and outer_dims_perm to have same size");

4700 }

4702 if (ShapedType::isDynamic(inputShape[pos]))

4703 continue;

4705

4706 if (!constantTile) {

4707 if (!ShapedType::isDynamic(outputTileSizes[pos]) &&

4708 (inputShape[pos] % outputTileSizes[pos] != 0))

4709 return true;

4710 } else if (inputShape[pos] % (*constantTile) != 0) {

4711 return true;

4712 }

4713 }

4714 return false;

4715 }

4716

4719 return failure();

4720

4721

4722

4723

4724 auto paddingValue = getPaddingValue();

4725 if (paddingValue &&

4726 paddingValue.getType() != getSourceType().getElementType()) {

4727 return emitOpError("expected padding_value has ")

4728 << getSourceType().getElementType()

4729 << " but got: " << paddingValue.getType();

4730 }

4731

4732 if (!paddingValue &&

4733 requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),

4734 getDestType().getShape(), getOuterDimsPerm(),

4735 getMixedTiles())) {

4736 return emitOpError(

4737 "invalid tile factor or output size provided. Only full tiles are "

4738 "supported when padding_value is not set");

4739 }

4740 return success();

4741 }

4742

4743

4744

4748 for (auto o : ofrs) {

4749

4750 if (llvm::dyn_cast_if_present(o))

4751 result.push_back(ShapedType::kDynamic);

4752 else

4753 result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));

4754 }

4755 return result;

4756 }

4757

4758

4759

4760

4766 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))

4767 continue;

4768 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {

4769 resultShape[tiledDim.value()] = ShapedType::kDynamic;

4770 continue;

4771 }

4772 resultShape[tiledDim.value()] = llvm::divideCeilSigned(

4773 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);

4774 }

4775

4776

4779

4780

4781 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());

4782 return resultShape;

4783 }

4784

4790

4796 builder, loc, ceilDivExpr,

4797 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});

4798 }

4801 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());

4802

4807

4808

4809

4810

4811

4812 for (unsigned i = 0; i < resultDims.size(); ++i) {

4813 if (!ShapedType::isDynamic(resultTypeShape[i]))

4814 continue;

4815 resultDims[i] =

4817 }

4818

4819 return resultDims;

4820 }

4821

4822

4823

4824 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,

4831 }

4832

4841 {v1, v2});

4842 };

4843

4846 llvm::cast(source.getType()).getShape())) {

4847 if (ShapedType::isDynamic(value))

4848 mixedSizes.push_back(

4849 b.createtensor::DimOp(loc, source, index).getResult());

4850 else

4851 mixedSizes.push_back(b.getIndexAttr(value));

4852 }

4853 for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {

4854 int64_t dimPos = std::get<0>(it);

4856 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);

4857 }

4859 applyPermutationToVector(mixedSizes, outerDimsPerm);

4860

4861 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());

4862 auto elemType = llvm::cast(source.getType()).getElementType();

4863 return b.createtensor::EmptyOp(loc, mixedSizes, elemType);

4864 }

4865

4866 PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,

4870 *this, innerPermutation, outerPermutation);

4871 Value transposedDest =

4872 createDestinationTensor(b, loc, getSource(), metadata.innerTiles,

4873 metadata.innerDimsPos, metadata.outerDimsPerm);

4874 return b.create(loc, getSource(), transposedDest,

4875 metadata.innerDimsPos, metadata.innerTiles,

4876 getPaddingValue(), metadata.outerDimsPerm);

4877 }

4878

4879

4880 template

4882 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,

4883 "applies to only pack or unpack operations");

4884 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)

4885 ? op.getDestType()

4886 : op.getSourceType();

4888 for (auto [dimDest, tile] : llvm::zip(

4889 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {

4891 if (!constTileSize || ShapedType::isDynamic(dimDest))

4892 return false;

4893 }

4894 return true;

4895 }

4896

4898 if (getPaddingValue())

4900

4901

4902

4903

4906

4908 }

4909

4910

4911

4913 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())

4914 return false;

4915 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())

4916 return true;

4917

4918

4919

4922 }

4923

4924

4925

4927 auto packTiles = packOp.getMixedTiles();

4928 auto unPackTiles = unPackOp.getMixedTiles();

4929 if (packTiles.size() != unPackTiles.size())

4930 return false;

4931 for (size_t i = 0, e = packTiles.size(); i < e; i++) {

4933 return false;

4934 }

4935 return true;

4936 }

4937

4938

4940 auto srcType = op.getSourceType();

4941 if (llvm::any_of(op.getInnerDimsPos(),

4942 [&](int64_t pos) { return srcType.isDynamicDim(pos); }))

4943 return false;

4944 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))

4945 return false;

4946 return !PackOp::requirePaddingValue(

4947 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),

4948 op.getOuterDimsPerm(), op.getMixedTiles());

4949 }

4950

4951

4952

4955 bool changeNeeded = false;

4956 srcShape.assign(packOp.getSourceType().getShape().begin(),

4957 packOp.getSourceType().getShape().end());

4958 destShape.assign(packOp.getDestType().getShape().begin(),

4959 packOp.getDestType().getShape().end());

4960 llvm::SmallSetVector<int64_t, 4> innerDims;

4961 innerDims.insert_range(packOp.getInnerDimsPos());

4963 if (!packOp.getOuterDimsPerm().empty())

4965 int srcRank = packOp.getSourceRank();

4966 for (auto i : llvm::seq<int64_t>(0, srcRank)) {

4967 if (innerDims.contains(i))

4968 continue;

4969 int64_t srcPos = i;

4970 int64_t destPos = i;

4971 if (!inverseOuterDimsPerm.empty())

4972 destPos = inverseOuterDimsPerm[srcPos];

4973 if (ShapedType::isDynamic(srcShape[srcPos]) ==

4974 ShapedType::isDynamic(destShape[destPos])) {

4975 continue;

4976 }

4977 int64_t size = srcShape[srcPos];

4978 if (ShapedType::isDynamic(size))

4979 size = destShape[destPos];

4980 srcShape[srcPos] = size;

4981 destShape[destPos] = size;

4982 changeNeeded = true;

4983 }

4984 return changeNeeded;

4985 }

4986

4987 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {

4988

4989 if (auto unPackOp = packOp.getSource().getDefiningOp()) {

4990 if (unPackOp.getSourceType() != packOp.getDestType())

4991 return failure();

4992 if (packOp.getPaddingValue() ||

4995 return failure();

4996 rewriter.replaceOp(packOp, unPackOp.getSource());

4997 return success();

4998 }

4999

5000

5003 packOp.getPaddingValueMutable().clear();

5005 return success();

5006 }

5007

5008

5011 Location loc = packOp.getLoc();

5012 Value source = packOp.getSource();

5013 if (srcShape != packOp.getSourceType().getShape()) {

5014 auto newSrcType = packOp.getSourceType().clone(srcShape);

5015 source =

5016 rewriter.createtensor::CastOp(loc, newSrcType, packOp.getSource());

5017 }

5018 Value dest = packOp.getDest();

5019 RankedTensorType originalResultType = packOp.getDestType();

5020 bool needUpdateDestType = (destShape != originalResultType.getShape());

5021 if (needUpdateDestType) {

5022 auto newDestType = packOp.getDestType().clone(destShape);

5023 dest =

5024 rewriter.createtensor::CastOp(loc, newDestType, packOp.getDest());

5025 }

5027 packOp.getSourceMutable().assign(source);

5028 packOp.getDestMutable().assign(dest);

5029 packOp.getResult().setType(cast(dest.getType()));

5030 });

5031

5032 if (needUpdateDestType) {

5034 auto castOp =

5035 rewriter.createtensor::CastOp(loc, originalResultType, packOp);

5037 }

5038 return success();

5039 }

5040

5041 return failure();

5042 }

5043

5044 template

5046 RankedTensorType packedTensorType) {

5047 static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||

5048 std::is_same<PackOrUnpackOp, UnPackOp>::value,

5049 "Function meant for pack/unpack");

5050

5051

5052

5054 int64_t numPackedDims = innerDimsPos.size();

5055 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));

5057

5058 return false;

5059 }

5060

5062 int64_t packedRank = packedTensorType.getRank();

5063

5064

5065

5066

5067

5068

5069

5070

5071

5072 return llvm::all_of(

5073 llvm::seq<int64_t>(0, packedRank - numPackedDims),

5074 [&packedShape](int64_t i) { return packedShape[i] == 1; });

5075 }

5076

5077 bool PackOp::isLikePad() {

5078 auto packedTensorType =

5079 llvm::cast((*this)->getResultTypes().front());

5081 }

5082

5083 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {

5084 std::optional paddingValue;

5085 if (auto pad = adaptor.getPaddingValue())

5086 paddingValue = pad;

5087 if (OpFoldResult reshapedSource = reshapeConstantSource(

5088 llvm::dyn_cast_if_present(adaptor.getSource()),

5089 getDestType(), paddingValue))

5090 return reshapedSource;

5091 return {};

5092 }

5093

5094

5095

5096

5097

5098

5099

5100

5101

5102

5103

5104

5105

5106

5107

5110

5114 return failure();

5115

5119

5120

5123

5124

5125

5126

5127

5128 PackOp newOp = rewriter.create(

5129 op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),

5130 newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());

5132

5133

5134 Value oldResult = op.getResult();

5135 Value newResult = newOp.getResult();

5137 ? rewriter.createtensor::CastOp(

5138 op->getLoc(), oldResult.getType(), newResult)

5139 : newResult;

5140

5141 rewriter.replaceOp(op, {replacement});

5142

5143 return success();

5144 }

5145 };

5146

5147

5148

5149

5150

5151 void UnPackOp::getAsmResultNames(

5153 setNameFn(getResult(), "unpack");

5154 }

5155

5156 LogicalResult

5160 }

5161

5164 }

5165

5168 }

5169

5172 }

5173

5175 ShapedType destType = getDestType();

5176 int64_t destRank = destType.getRank();

5177 return getSourceType().getShape().take_front(destRank);

5178 }

5179

5182 auto packedShape = getSourceType().getShape();

5184

5186 res.push_back(packedShape[index]);

5187

5188 return res;

5189 }

5190

5193 }

5194

5196

5199

5201 }

5202

5208 "number of tile sizes specified must match the specified number of "

5209 "original dimensions to be tiled");

5213 build(builder, state, dest.getType(), source, dest,

5218 }

5219

5229 };

5230

5232 auto srcType = llvm::cast(source.getType());

5233 for (auto i :

5234 llvm::seq(0, srcType.getRank() - innerTileSizes.size())) {

5235 if (srcType.isDynamicDim(i))

5236 mixedSizes.push_back(b.createtensor::DimOp(loc, source, i).getResult());

5237 else

5238 mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));

5239 }

5241 applyPermutationToVector(

5243 }

5244

5245 for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))

5246 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);

5247

5248 auto elemType = srcType.getElementType();

5249 return b.createtensor::EmptyOp(loc, mixedSizes, elemType);

5250 }

5251

5252 UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,

5253 Value transposedSource,

5257 *this, innerPermutation, outerPermutation);

5258 return b.create(loc, transposedSource, getDest(),

5259 metadata.innerDimsPos, metadata.innerTiles,

5260 metadata.outerDimsPerm);

5261 }

5262

5263

5264

5267 bool changeNeeded = false;

5268 srcShape.assign(op.getSourceType().getShape().begin(),

5269 op.getSourceType().getShape().end());

5270 destShape.assign(op.getDestType().getShape().begin(),

5271 op.getDestType().getShape().end());

5272 llvm::SmallSetVector<int64_t, 4> innerDims;

5273 innerDims.insert_range(op.getInnerDimsPos());

5275 if (!op.getOuterDimsPerm().empty())

5277 int destRank = op.getDestRank();

5278 for (auto i : llvm::seq<int64_t>(0, destRank)) {

5279 if (innerDims.contains(i))

5280 continue;

5281 int64_t srcPos = i;

5282 int64_t destPos = i;

5283 if (!inverseOuterDimsPerm.empty())

5284 srcPos = inverseOuterDimsPerm[destPos];

5285 if (ShapedType::isDynamic(srcShape[srcPos]) ==

5286 ShapedType::isDynamic(destShape[destPos])) {

5287 continue;

5288 }

5289 int64_t size = srcShape[srcPos];

5290 if (ShapedType::isDynamic(size))

5291 size = destShape[destPos];

5292 srcShape[srcPos] = size;

5293 destShape[destPos] = size;

5294 changeNeeded = true;

5295 }

5296 return changeNeeded;

5297 }

5298

5299 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,

5301

5302 if (PackOp packOp = unPackOp.getSource().getDefiningOp()) {

5303 if (packOp.getSourceType() != unPackOp.getDestType())

5304 return failure();

5305 if (packOp.getPaddingValue() ||

5308 return failure();

5309 rewriter.replaceOp(unPackOp, packOp.getSource());

5310 return success();

5311 }

5312

5313 if (auto dstStyleOp =

5314 unPackOp.getDest().getDefiningOp()) {

5315 auto destValue = cast(unPackOp.getDest());

5316 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];

5318 [&]() { unPackOp.setDpsInitOperand(0, newDest); });

5319 return success();

5320 }

5321

5322 if (unPackOp->hasOneUse()) {

5323 auto extractSliceUser =

5324 dyn_casttensor::ExtractSliceOp(*unPackOp->getUsers().begin());

5325 if (extractSliceUser &&

5328 extractSliceUser.getSourceType().getRank() ==

5329 extractSliceUser.getResultType().getRank()) {

5332 auto newDest = rewriter.createtensor::ExtractSliceOp(

5333 unPackOp->getLoc(), unPackOp.getDest(),

5334 extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),

5335 extractSliceUser.getMixedStrides());

5337 unPackOp.setDpsInitOperand(0, newDest);

5338 unPackOp.getResult().setType(newDest.getType());

5339 });

5340 rewriter.replaceOp(extractSliceUser, unPackOp);

5341 return success();

5342 }

5343 }

5344

5345

5348 Location loc = unPackOp.getLoc();

5349 Value source = unPackOp.getSource();

5350 if (srcShape != unPackOp.getSourceType().getShape()) {

5351 auto newSrcType = unPackOp.getSourceType().clone(srcShape);

5352 source = rewriter.createtensor::CastOp(loc, newSrcType,

5353 unPackOp.getSource());

5354 }

5355 Value dest = unPackOp.getDest();

5356 if (destShape != unPackOp.getDestType().getShape()) {

5357 auto newDestType = unPackOp.getDestType().clone(destShape);

5358 dest =

5359 rewriter.createtensor::CastOp(loc, newDestType, unPackOp.getDest());

5360 }

5361 Value newOp = rewriter.create(

5362 loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),

5363 unPackOp.getOuterDimsPerm());

5365 unPackOp, unPackOp.getResult().getType(), newOp);

5366 return success();

5367 }

5368

5369 return failure();

5370 }

5371

5372 bool UnPackOp::isLikeUnPad() {

5373 RankedTensorType packedTensorType = getSourceType();

5375 }

5376

5377 OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {

5378 if (OpFoldResult reshapedSource = reshapeConstantSource(

5379 llvm::dyn_cast_if_present(adaptor.getSource()),

5381 return reshapedSource;

5382 return {};

5383 }

5384

5385

5386

5387

5388

5389

5390

5391

5392

5393

5394

5395

5396

5397

5398

5401

5405 return failure();

5406

5410 Value sourceTensor = newOperands[0];

5411

5412

5414 rewriter, sourceTensor.getType(), op.getMixedTiles());

5415

5416

5417

5418

5419

5420 UnPackOp newOp = rewriter.create(

5421 op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),

5422 newMixedTileSizes, op.getOuterDimsPerm());

5424

5425

5426 Value oldResult = op.getResult();

5427 Value newResult = newOp.getResult();

5429 ? rewriter.createtensor::CastOp(

5430 op->getLoc(), oldResult.getType(), newResult)

5431 : newResult;

5432

5433 rewriter.replaceOp(op, {replacement});

5434

5435 return success();

5436 }

5437 };

5438

5439

5440

5441

5444 utils::IteratorType::reduction, utils::IteratorType::parallel,

5445 utils::IteratorType::parallel, utils::IteratorType::reduction};

5446 }

5447

5449 BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {

5452 bindDims(context, d0, d1, d2, d3);

5453 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));

5454 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));

5455 indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));

5456 return indexingMaps;

5457 }

5458

5459 unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }

5460

5461 std::string BatchReduceMatmulOp::getLibraryCallName() {

5463 }

5464

5465

5466

5467 bool BatchReduceMatmulOp::hasUserDefinedMaps() {

5469 getDefaultIndexingMaps(this->getContext());

5471 return defaultMaps != explicitMaps;

5472 }

5473

5474

5475

5476

5477

5478

5479

5480

5481 bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,

5482 bool isLHS) {

5484 "Expected less than 3 result dim expr.");

5485 bool isValid = false;

5486 enum Indices { batchPos, mPos, nPos, kPos };

5493 isValid =

5500 }

5501 return isValid;

5502 }

5503

5507 "BatchReduceMatmulOp regionBuilder expects 3 (>=0) args");

5508 RegionBuilderHelper helper(b, block);

5510

5512 Value castValA =

5513 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));

5514 Value castValB =

5515 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));

5516 Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);

5518 helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);

5519 yields.push_back(addVal);

5520 helper.yieldOutputs(yields);

5521 }

5522

5529 return failure();

5531 return failure();

5532

5533 do {

5535 return failure();

5536 if (!isa(mapAttr)) {

5538 "expected affine map attribute");

5539 }

5540 indexingMapsAttr.push_back(mapAttr);

5541

5543 break;

5544 } while (true);

5545

5547 return failure();

5548 }

5549

5550 if (indexingMapsAttr.empty()) {

5551 indexingMapsAttr = llvm::map_to_vector(

5552 BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),

5553 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });

5554 }

5558 BatchReduceMatmulOp::getNumRegionArgs(),

5559 BatchReduceMatmulOp::getRegionBuilder());

5560 }

5561

5564 BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),

5566

5567 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {

5568 p << " indexing_maps = [";

5569 llvm::interleaveComma(getIndexingMaps(), p,

5571 p << "]";

5572 }

5573

5575 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};

5577 elidedAttrs);

5578 }

5579

5580

5582

5583

5584 if (!hasUserDefinedMaps())

5585 return success();

5586

5587 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {

5589 return failure();

5590 }

5591 return success();

5592 }

5593 LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,

5596 }

5597 void BatchReduceMatmulOp::getEffects(

5599 &effects) {

5600 if (hasPureTensorSemantics())

5601 return;

5603 }

5604

5607 }

5608

5609 }

5610 }

5611

5612

5613

5614

5615

5616 void LinalgDialect::getCanonicalizationPatterns(

5620 }

5621

5625 return arith::ConstantOp::materialize(builder, value, type, loc);

5626 }

static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)

A utility function used to materialize a constant for a given attribute and type.

static MLIRContext * getContext(OpFoldResult val)

static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)

Verifies the broadcast and transpose semantic sepecified by the explicit indexing map for the MatmulO...

static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)

static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)

static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)

static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)

Produce a linalg generic that computes the final step of the softmax decomposition.

static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)

SmallVector< int64_t > outerDimsPerm

static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)

Fills the region of a structured operation using the provided regionBuilder.

static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)

SmallVector< OpFoldResult > innerTiles

static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)

Check if the user defined map is valid broadcast map.

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)

static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)

static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)

static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)

Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...

static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)

static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap)

This function checks if the given AffineMap for the output of a BatchMatmulOp/BatchReduceMatmulOp has...

static Operation * findPayloadOp(Block *body, bool initFirst=false)

static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)

Creates a structured operation given inputs, outputs, and attributes.

static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)

static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)

static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)

static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)

ElementwiseArityGroup arityGroup

static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)

static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, LinalgOp linalgOp)

union mlir::linalg::@1203::ArityGroupAndKind::Kind kind

static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)

SmallVector< int64_t > innerDimsPos

static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)

static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)

static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)

static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)

Return a memref.dim or tensor.dim for the shape of v at dim.

static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)

static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)

void printShortForm(OpAsmPrinter &p, Operation *payloadOp)

static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)

static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)

Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})

static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)

Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.

static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)

static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)

static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)

Walks the given type hierarchy with the given indices, potentially down to component granularity,...

static LogicalResult getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize, const scf::SCFTilingOptions &options)

static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, const scf::SCFTilingOptions &options)

static ArrayRef< int64_t > getShape(Type type)

Returns the shape of the given type.

Base type for affine expression.

bool isFunctionOfDim(unsigned position) const

Return true if the affine expression involves AffineDimExpr position.

AffineExpr ceilDiv(uint64_t v) const

A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.

AffineMap dropResults(ArrayRef< int64_t > positions) const

static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)

Returns an AffineMap with 'numDims' identity result dim exprs.

static AffineMap get(MLIRContext *context)

Returns a zero result affine map with no dimensions or symbols: () -> ().

bool isProjectedPermutation(bool allowZeroInResults=false) const

Returns true if the AffineMap represents a subset (i.e.

unsigned getNumDims() const

ArrayRef< AffineExpr > getResults() const

unsigned getNumResults() const

AffineExpr getResult(unsigned idx) const

static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)

Returns an AffineMap representing a permutation.

@ Paren

Parens surrounding zero or more operands.

virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0

Parse a colon followed by a type list, which must have at least one type.

virtual Builder & getBuilder() const =0

Return a builder which provides useful access to MLIRContext, global objects like types and attribute...

virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0

Parse a named dictionary into 'result' if it is present.

virtual ParseResult parseOptionalKeyword(StringRef keyword)=0

Parse the given keyword if present.

MLIRContext * getContext() const

virtual ParseResult parseRParen()=0

Parse a ) token.

virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0

Emit a diagnostic at the specified location and return failure.

virtual ParseResult parseLSquare()=0

Parse a [ token.

virtual ParseResult parseRSquare()=0

Parse a ] token.

virtual ParseResult parseRBrace()=0

Parse a } token.

virtual ParseResult parseEqual()=0

Parse a = token.

virtual SMLoc getCurrentLocation()=0

Get the location of the next token and store it into the argument.

virtual ParseResult parseOptionalComma()=0

Parse a , token if present.

virtual ParseResult parseOptionalLess()=0

Parse a '<' token if present.

virtual ParseResult parseGreater()=0

Parse a '>' token.

virtual ParseResult parseLParen()=0

Parse a ( token.

virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0

Parse an optional arrow followed by a type list.

ParseResult parseKeyword(StringRef keyword)

Parse a given keyword.

virtual ParseResult parseAttribute(Attribute &result, Type type={})=0

Parse an arbitrary attribute of a given type and return it in result.

virtual ParseResult parseOptionalLBrace()=0

Parse a { token if present.

void printOptionalArrowTypeList(TypeRange &&types)

Print an optional arrow followed by a type list.

virtual void printAttribute(Attribute attr)

Attributes are known-constant values of operations.

Block represents an ordered list of Operations.

BlockArgument getArgument(unsigned i)

unsigned getNumArguments()

Operation * getTerminator()

Get the terminator operation of this block.

BlockArgument addArgument(Type type, Location loc)

Add one value to the argument list.

OpListType & getOperations()

BlockArgListType getArguments()

This class is a general helper class for creating context-global objects like types,...

IntegerAttr getIndexAttr(int64_t value)

DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)

IntegerAttr getIntegerAttr(Type type, int64_t value)

DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)

AffineMap getMultiDimIdentityMap(unsigned rank)

IntegerAttr getI64IntegerAttr(int64_t value)

StringAttr getStringAttr(const Twine &bytes)

AffineExpr getAffineDimExpr(unsigned position)

MLIRContext * getContext() const

ArrayAttr getArrayAttr(ArrayRef< Attribute > value)

ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)

IRValueT get() const

Return the current value being used by this operand.

This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...

ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...

void assign(const_iterator inStart, const_iterator inEnd)

Replaces the attributes with new list of attributes.

ArrayRef< NamedAttribute > getAttrs() const

Return all of the attributes on this operation.

DictionaryAttr getDictionary(MLIRContext *context) const

Return a dictionary attribute for the underlying dictionary.

Attribute get(StringAttr name) const

Return the specified attribute if present, null otherwise.

void append(StringRef name, Attribute attr)

Add an attribute with the specified name.

Attribute set(StringAttr name, Attribute value)

If an attribute exists with the specified name, change it to the new value.

NamedAttribute represents a combination of a name and an Attribute value.

StringAttr getName() const

Return the name of the attribute.

Attribute getValue() const

Return the value of the attribute.

The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...

virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0

Parses a region.

virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0

Parse zero or more arguments with a specified surrounding delimiter.

virtual FailureOr< OperationName > parseCustomOperationName()=0

Parse the name of an operation, in the custom form.

ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)

Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...

virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0

Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...

This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...

virtual void printNewline()=0

Print a newline and indent the printer to the start of the current operation.

virtual void increaseIndent()=0

Increase indentation.

virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0

If the specified operation has attributes, print out an attribute dictionary with their values.

virtual void decreaseIndent()=0

Decrease indentation.

virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0

Prints a region.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

void setInsertionPointToStart(Block *block)

Sets the insertion point to the start of the specified block.

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)

Add new block with 'argTypes' arguments and set the insertion point to the end of it.

void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)

Create an operation of specific op type at the current insertion point, and immediately try to fold i...

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

void setInsertionPointAfter(Operation *op)

Sets the insertion point to the node after the specified operation, which will cause subsequent inser...

This class represents a single result from folding an operation.

This class represents an operand of an operation.

unsigned getOperandNumber()

Return which operand this is in the OpOperand list of the Operation.

This is a value defined by a result of an operation.

unsigned getResultNumber() const

Returns the number of this result.

StringRef getStringRef() const

Return the name of this operation. This always succeeds.

std::optional< RegisteredOperationName > getRegisteredInfo() const

If this operation is registered, returns the registered information, std::nullopt otherwise.

Operation is the basic unit of execution within MLIR.

result_iterator result_begin()

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

unsigned getNumOperands()

ArrayRef< NamedAttribute > getAttrs()

Return all of the attributes on this operation.

InFlightDiagnostic emitError(const Twine &message={})

Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...

OperationName getName()

The name of an operation is the key identifier for it.

operand_type_range getOperandTypes()

result_iterator result_end()

result_type_range getResultTypes()

operand_range getOperands()

Returns an iterator on the underlying Values.

result_range getResults()

void setDiscardableAttrs(DictionaryAttr newAttrs)

Set the discardable attribute dictionary on this operation.

unsigned getNumResults()

Return the number of results held by this operation.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

This class contains a list of basic blocks and a link to the parent operation it is attached to.

RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)

Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.

This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure,...

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements...

virtual void finalizeOpModification(Operation *op)

This method is used to signal the end of an in-place modification of the given operation.

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)

Find uses of from and replace them with to except if the user is exceptedUser.

void modifyOpInPlace(Operation *root, CallableT &&callable)

This method is a utility wrapper around an in-place modification of an operation.

virtual void startOpModification(Operation *op)

This method is used to notify the rewriter that an in-place operation modification is about to happen...

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

This class represents a specific instance of an effect.

static DerivedEffect * get()

Returns a unique instance for the derived effect class.

static DefaultResource * get()

Returns a unique instance for the given effect class.

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

unsigned getIntOrFloatBitWidth() const

Return the bit width of an integer or a float type, assert failure on other types.

bool isSignlessIntOrIndexOrFloat() const

Return true if this is a signless integer, index, or float type.

This class provides an abstraction over the different types of ranges over Values.

type_range getTypes() const

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

bool hasOneUse() const

Returns true if this value has exactly one use.

Location getLoc() const

Return the location of this value.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...

static Attribute parse(AsmParser &parser, Type type)

Parse the short form [42, 100, -1] without any type prefix.

constexpr auto RecursivelySpeculatable

Speculatability

This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...

constexpr auto Speculatable

constexpr auto NotSpeculatable

AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...

OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...

Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)

Returns the identity value associated with an AtomicRMWKind op.

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)

static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)

Returns true if the srcShape or destShape is different from the one in packOp and populates each with...

static bool isLikePadUnPad(PackOrUnpackOp packOp, RankedTensorType packedTensorType)

static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)

Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...

static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)

Returns true if dimsPos is invalid.

static bool areAllInBound(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > limitShape)

Returns true if the dimension of sourceShape is smaller than the dimension of the limitShape.

static SmallVector< int64_t > getStaticTilesImpl(OpTy op)

static SmallVector< int64_t > getPackOpResultTypeShape(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > innerTileSizes, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)

Helper for PackOp::{getResultShape,inferPackedType}.

SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)

Return the vector that is the concatenation of a and b.

static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)

static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)

OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)

Create one memref::DimOp or tensor::DimOp depending on the type of val.

std::string generateLibraryCallName(Operation *op)

Returns the name mangled library call name to disambiguate between different overloads at the C level...

static bool paddingIsNotNeeded(PackOp op)

Returns true if the pack op does not need a padding value.

AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)

Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.

SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)

Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...

static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)

static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)

static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)

Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)

Create one memref::DimOp or tensor::DimOp depending on the type of val.

static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)

bool areTilesAndTiledDimsAllConstant(OpTy op)

Returns true if the tiles and the tiled dims are constant.

static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)

static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)

FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)

LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)

This is a common utility used for patterns of the form "someop(memref.cast) -> someop".

Kind

An enumeration of the kinds of predicates.

DynamicAPInt floor(const Fraction &f)

DynamicAPInt ceil(const Fraction &f)

DynamicAPInt round(const Fraction &f)

Fraction abs(const Fraction &f)

QueryRef parse(llvm::StringRef line, const QuerySession &qs)

SparseTensorEncodingAttr getSparseTensorEncoding(Type type)

Convenience method to get a sparse encoding attribute from a type.

uint64_t getM(LevelType lt)

bool hasFoldableTensorCastOperand(Operation *op)

Return true if any of the operands of op is a CastOp that can be folded into its consumer,...

bool canFoldIntoProducerOp(CastOp castOp)

Determines whether the tensor::CastOp casts to a more static version of the source tensor.

SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)

Assuming that op contains at least one operand that is a foldable CastOp (i.e.

SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)

Return the dimensions of the given tensor value.

Include the generated interface declarations.

Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)

Converts a scalar value operand to type toType.

std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)

If ofr is a constant integer or an IntegerAttr, return the integer.

LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)

Reify the shape of the result of an operation (typically in terms of the shape of its operands).

bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)

Return true if all of ofrs are constant integers equal to value.

bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)

Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)

Bind a list of AffineExpr references to DimExpr at positions: [0 .

SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.

AffineMap inversePermutation(AffineMap map)

Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...

Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)

This parses a single MLIR attribute to an MLIR context if it was valid.

bool isIdentityPermutation(ArrayRef< int64_t > permutation)

Returns true if permutation is an identity permutation.

Type getElementTypeOrSelf(Type type)

Return the element type or return the type itself.

bool isZeroInteger(OpFoldResult v)

Return true if v is an IntegerAttr with value 0.

void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)

Bind a list of AffineExpr references to SymbolExpr at positions: [0 .

void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)

Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...

Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)

Converts an OpFoldResult to a Value.

Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)

SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)

Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

OpFoldResult getAsOpFoldResult(Value val)

Given a value, try to extract a constant Attribute.

LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)

Returns success if the given two shapes are compatible.

SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})

Iteratively computes backward slices and forward slices until a fixed point is reached.

void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)

Apply the permutation defined by permutation to inVec.

AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)

These free functions allow clients of the API to not use classes in detail.

LogicalResult verify(Operation *op, bool verifyRecursively=true)

Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...

SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)

Returns a permutation vector that drop the input dims in dropPositions from inputPerm.

bool isPermutationVector(ArrayRef< int64_t > interchange)

Method to check if an interchange vector is a permutation.

SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)

Helper method to apply to inverse a permutation.

Fold transpose with transpose.

LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override

This pattern canonicalize transpose by swapping the order of broadcast and transpose: transpose(broad...

LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override

OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...

OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})

Patterns must specify the root operation name they match against, and can also specify the benefit of...

This represents an operation in an abstracted form, suitable for use with the builder APIs.

SmallVector< Value, 4 > operands

void addOperands(ValueRange newOperands)

void addAttributes(ArrayRef< NamedAttribute > newAttributes)

Add an array of named attributes.

void addAttribute(StringRef name, Attribute attr)

Add an attribute with the specified name.

void addTypes(ArrayRef< Type > newTypes)

SmallVector< std::unique_ptr< Region >, 1 > regions

Regions that the op will hold.

Attribute propertiesAttr

This Attribute is used to opaquely construct the properties of the operation.

Region * addRegion()

Create a region that should be attached to the operation.

Container for result values of tiling.

Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has source that is more static t...

LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override

Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has source that is more static...

LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override