MLIR: lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp Source File

//===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewriting rules that are specific to sparse tensors.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::linalg;
using namespace mlir::sparse_tensor;

//===---------------------------------------------------------------------===//
// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===//

// Helper method to match any typed zero.
static bool isZeroValue(Value val) {
  return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
}

// Helper to detect a sparse tensor type operand.
static bool isSparseTensor(Value v) {
  auto enc = getSparseTensorEncoding(v.getType());
  return enc && !llvm::all_of(enc.getLvlTypes(),
                              [](auto lt) { return lt == LevelFormat::Dense; });
}
static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }

// Helper method to find zero/uninitialized tensor materialization.
static bool isMaterializing(OpOperand *op, bool isZero) {
  Value val = op->get();
  // Check allocation, with zero alloc when required.
  if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
    Value copy = alloc.getCopy();
    if (isZero)
      return copy && isZeroValue(copy);
    return !copy;
  }
  // Check for empty tensor materialization.
  if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
    return !isZero;
  // Last resort for zero alloc: the whole value is zero.
  return isZero && isZeroValue(val);
}

// Helper to detect a sampling operation.
static bool isSampling(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
      // Both scalar input arguments used exactly once.
      Value s1 = op.getBlock()->getArgument(0);
      Value s2 = op.getBlock()->getArgument(1);
      return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
             (def->getOperand(1) == s1 && def->getOperand(0) == s2);
    }
  }
  return false;
}
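// For illustration (not taken from this file), isSampling matches a
// linalg.generic body that directly multiplies the two scalar input
// arguments before yielding, e.g.:
//
//   ^bb0(%s: f64, %t: f64, %out: f64):
//     %0 = arith.mulf %s, %t : f64
//     linalg.yield %0 : f64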

// Helper to detect a chain of multiplications that do not involve x.
static bool isMulChain(Value val, Value x) {
  if (auto arg = dyn_cast<BlockArgument>(val))
    return arg != x;
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
      return isMulChain(def->getOperand(0), x) &&
             isMulChain(def->getOperand(1), x);
  }
  return false;
}

// Helper to detect x = x + <multiplications>.
static bool isSumOfMul(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
      Value x = op.getBlock()->getArguments().back();
      return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
             (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
    }
  }
  return false;
}
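// For illustration (not taken from this file), isSumOfMul matches a body
// that accumulates a multiplication chain into the output argument %x:
//
//   ^bb0(%a: f64, %b: f64, %x: f64):
//     %0 = arith.mulf %a, %b : f64
//     %1 = arith.addf %x, %0 : f64
//     linalg.yield %1 : f64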

// Helper to detect direct yield of a zero value.
static bool isZeroYield(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
    if (arg.getOwner()->getParentOp() == op) {
      return isZeroValue(op->getOperand(arg.getArgNumber()));
    }
  }
  return isZeroValue(yieldOp.getOperand(0));
}

/// Populates the given sizes array from type (for static sizes) and from
/// the tensor (for dynamic sizes).
static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
                           Location loc, ShapedType stp, Value tensor) {
  for (const auto &d : enumerate(stp.getShape())) {
    Value dim;
    if (d.value() == ShapedType::kDynamic)
      dim = builder.create<tensor::DimOp>(loc, tensor, d.index());
    else
      dim = constantIndex(builder, loc, d.value());
    sizes.push_back(dim);
  }
}

static RankedTensorType getBufferType(const SparseTensorType &stt,
                                      bool needTmpCOO) {
  return needTmpCOO ? stt.getCOOType(/*ordered=*/false)
                    : stt.getRankedTensorType();
}

/// Collects the dynamic dimension sizes for `tp` with the assumption that
/// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
/// sizes to dynSizes.
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes,
                            SmallVectorImpl<Value> &dynSizes) {
  for (const auto &d : enumerate(tp.getShape())) {
    if (d.value() == ShapedType::kDynamic)
      dynSizes.push_back(sizes[d.index()]);
  }
}

static LogicalResult genForeachOnSparseConstant(ForeachOp op,
                                                RewriterBase &rewriter,
                                                SparseElementsAttr attr) {
  auto loc = op.getLoc();
  SmallVector<Value> reduc = op.getInitArgs();

  // Foreach on constant.
  foreachInSparseConstant(
      rewriter, loc, attr, op.getOrder().value_or(AffineMap()),
      [&reduc, &rewriter, op](ArrayRef<Value> cvs, Value v) mutable {
        SmallVector<Value> args;
        args.append(cvs.begin(), cvs.end());
        args.push_back(v);
        args.append(reduc);
        // Clones the foreach op to get a copy of the loop body.
        auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
        assert(args.size() == cloned.getBody()->getNumArguments());
        Operation *yield = cloned.getBody()->getTerminator();
        rewriter.inlineBlockBefore(cloned.getBody(), op, args);
        // Clean up.
        rewriter.eraseOp(cloned);
        reduc = yield->getOperands();
        rewriter.eraseOp(yield);
      });

  rewriter.replaceOp(op, reduc);
  return success();
}
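// For illustration (not taken from this file): a foreach over a sparse
// constant such as
//
//   sparse<[[0, 0], [1, 2]], [1.0, 2.0]> : tensor<2x3xf64>
//
// is fully unrolled by the helper above, cloning the loop body once per
// stored element, with (%c0, %c0, 1.0) and (%c1, %c2, 2.0) as arguments.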

/// Populates the given sizes array for concatenation from types (for static
/// sizes) and from the source tensors (for dynamic sizes).
static void concatSizesFromInputs(OpBuilder &builder,
                                  SmallVectorImpl<Value> &sizes, Location loc,
                                  ShapedType dstTp, ValueRange srcs,
                                  unsigned dim) {
  auto dstShape = dstTp.getShape();
  sizesFromSrc(builder, sizes, loc, srcs[0]);

  // Sum up on the `dim` if the dimension is dynamic.
  if (dstShape[dim] != ShapedType::kDynamic) {
    // Faithfully take the static size.
    sizes[dim] = constantIndex(builder, loc, dstShape[dim]);
  } else {
    // Else, compute the size dynamically.
    for (const auto &src : srcs.drop_front()) {
      Value srcSz = createOrFoldDimOp(builder, loc, src, dim);
      // Sum up all the sizes.
      sizes[dim] = builder.create<arith::AddIOp>(loc, sizes[dim], srcSz);
    }
  }
}
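// For illustration (not taken from this file): concatenating two
// tensor<4x?xf64> values along dim 1 into a dynamically sized destination
// emits
//
//   %d0 = tensor.dim %src0, %c1
//   %d1 = tensor.dim %src1, %c1
//   %sz = arith.addi %d0, %d1
//
// whereas a static destination size is taken directly as a constant.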

//===---------------------------------------------------------------------===//
// The actual sparse tensor rewriting rules.
//===---------------------------------------------------------------------===//

namespace {

/// TODO: move it to the tensor dialect instead.
///
/// Fold `tensor.concat` and `tensor.extract_slice`
///
/// %concat = tensor.concat dim(2) %t0, %t1
///   : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
/// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
/// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
///
/// Becomes
///
/// %extract0, %extract1 = %t0, %t1
struct FuseExtractSliceWithConcat
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
    if (!concatOp)
      return failure();

    Location loc = extractOp.getLoc();
    int64_t dim = concatOp.getDim();
    int64_t rank = extractOp.getResultType().getRank();

    SmallVector<OpFoldResult> srcStrides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> srcOffsets(rank, rewriter.getIndexAttr(0));

    // Compute the partial sums for the slice offsets.
    AffineExpr sum = rewriter.getAffineDimExpr(0);
    SmallVector<AffineExpr> partialSums = {sum};
    SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
    for (auto [idx, input] :
         llvm::enumerate(concatOp.getInputs().drop_back())) {
      sum = sum + rewriter.getAffineDimExpr(idx + 1);
      partialSums.push_back(sum);
      offsetStrides.push_back(
          rewriter.createOrFold<tensor::DimOp>(loc, input, dim));
    }
    auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
                                        partialSums, rewriter.getContext());
    SmallVector<OpFoldResult> dimOffsets =
        affine::makeComposedFoldedMultiResultAffineApply(
            rewriter, loc, partialSumMap, offsetStrides);

    auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
      for (auto [l, r] : llvm::zip(lhs, rhs)) {
        std::optional<int64_t> staticVal = getConstantIntValue(l);
        if (!staticVal.has_value() || staticVal != getConstantIntValue(r))
          return false;
      }
      return lhs.size() == rhs.size();
    };

    for (auto [i, input, offset] :
         llvm::enumerate(concatOp.getInputs(), dimOffsets)) {
      SmallVector<OpFoldResult> srcSizes =
          tensor::getMixedSizes(rewriter, loc, input);
      srcOffsets[dim] = offset;

      SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes();
      SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets();
      SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides();

      if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
          allEqual(srcStrides, dstStrides)) {
        Value operand = concatOp.getOperand(i);
        if (operand.getType() == extractOp.getResultType())
          rewriter.replaceOp(extractOp, operand);
        break;
      }
    }

    return success();
  }
};

/// Rewriting rule that fuses sparse_tensor.convert into producer.
struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ConvertOp op,
                                PatternRewriter &rewriter) const override {
    auto producer = op.getSource().getDefiningOp<GenericOp>();
    if (!producer || producer.getDpsInits().size() != 1 ||
        !isMaterializing(producer.getDpsInitOperand(0), false) ||
        !producer.getResult(0).hasOneUse()) {
      return failure();
    }
    // Clone the materialization operation, but update the result to sparse.
    rewriter.setInsertionPoint(producer);
    Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp();
    Operation *cloned = rewriter.clone(*init);
    cloned->getResult(0).setType(op.getResult().getType());

    rewriter.modifyOpInPlace(producer, [&]() {
      producer.getDpsInitsMutable().assign(cloned->getResults());
      producer.getResult(0).setType(op.getResult().getType());
    });

    rewriter.replaceAllOpUsesWith(op, producer);
    op->erase();

    return success();
  }
};
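// For illustration (not taken from this file): the pattern above turns
//
//   %0 = tensor.empty() : tensor<8xf32>
//   %1 = linalg.generic ... outs(%0 : tensor<8xf32>) -> tensor<8xf32>
//   %2 = sparse_tensor.convert %1 : tensor<8xf32> to tensor<8xf32, #SV>
//
// into a generic that materializes its result sparse directly, retyping the
// empty init to tensor<8xf32, #SV> and dropping the convert.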

/// Rewriting rule that converts direct yield of zero with initial allocation.
struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.hasPureTensorSemantics() || op.getNumResults() != 1 ||
        !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
        !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
      return failure();
    auto outputType = getRankedTensorType(op.getResult(0));

    // Yielding zero on a newly materialized sparse tensor can be
    // optimized directly (regardless of dynamic or static size).
    if (getSparseTensorEncoding(outputType)) {
      rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
      return success();
    }

    // Use static zero value directly instead of materialization.
    if (!outputType.hasStaticShape())
      return failure();
    Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
    rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
    rewriter.eraseOp(def);
    return success();
  }
};
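// For illustration (not taken from this file): a generic that only yields
// zero into a freshly materialized sparse output
//
//   %0 = bufferization.alloc_tensor() : tensor<8xf32, #SV>
//   %1 = linalg.generic ... outs(%0) { linalg.yield %cst_0 }
//          -> tensor<8xf32, #SV>
//
// folds to %0 itself, since an all-zero sparse tensor stores nothing.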

/// Rewriting rule that converts two kernels:
///
///      T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
///      X(i,j) = S(i,j) * T(i,j)
///
/// into a single kernel, using distributive law:
///
///      X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
///
/// This kind of fusion (merging two ops into one but using arithmetic
/// equalities that may not hold for floating-point computations) would
/// be undesirable in the dense case, since we distribute the multiplication
/// into the reduction loop. However, for sparse sampling tensor S, such
/// a fusion may actually reduce the asymptotic complexity of the kernel,
/// since intermediate results may be nullified.
struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Check consumer.
    if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 ||
        op.getNumResults() != 1 ||
        op.getNumParallelLoops() != op.getNumLoops() ||
        !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
        !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
        !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
      return failure();
    // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
    // operand can be sparse or dense, since the point of this rewriting rule
    // is detecting a situation in which *more* sparsity is introduced into
    // a computation, be it already sparse or still dense.
    unsigned other = 0;
    if (isSparseTensor(op.getDpsInputOperand(0)))
      other = 1;
    else if (!isSparseTensor(op.getDpsInputOperand(1)))
      return failure();
    // Check producer.
    auto prod = dyn_cast_or_null<GenericOp>(
        op.getDpsInputOperand(other)->get().getDefiningOp());
    if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 ||
        !prod.getResult(0).hasOneUse())
      return failure();
    // Sampling consumer and sum of multiplication chain producer.
    if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
        !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) ||
        !isSampling(op) || !isSumOfMul(prod))
      return failure();
    // Modify operand structure of producer and consumer.
    Location loc = prod.getLoc();
    SmallVector<Value> inputOps = prod.getInputs();
    SmallVector<Value> outputOps = op.getOutputs();
    SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
    inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
    fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
    // Fuse producer and consumer into a new generic op.
    auto fusedOp = rewriter.create<GenericOp>(
        loc, op.getResult(0).getType(), inputOps, outputOps,
        rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(),
        /*doc=*/nullptr, /*library_call=*/nullptr);
    Block &prodBlock = prod.getRegion().front();
    Block &consBlock = op.getRegion().front();
    IRMapping mapper;
    Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
    unsigned num = prodBlock.getNumArguments();
    for (unsigned i = 0; i < num - 1; i++)
      addArg(mapper, fusedBlock, prodBlock.getArgument(i));
    addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
    addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
    // Clone bodies of the producer and consumer in new evaluation order.
    auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
    auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
    Value last;
    for (auto &op : prodBlock.without_terminator())
      if (&op != acc) {
        last = op.getResult(0);
        rewriter.clone(op, mapper);
      }
    mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
    mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
    last = rewriter.clone(*acc, mapper)->getResult(0);
    rewriter.create<linalg::YieldOp>(loc, last);
    // Force initial value on merged allocation for dense outputs.
    // TODO: deal with non alloc tensor here one day
    if (!getSparseTensorEncoding(op.getResult(0).getType())) {
      Value init = prod.getDpsInitOperand(0)
                       ->get()
                       .getDefiningOp<AllocTensorOp>()
                       .getCopy();
      AllocTensorOp a =
          op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
      rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
    }
    // Replace consumer with fused operation. Old producer
    // and consumer ops will be removed by DCE.
    rewriter.replaceOp(op, fusedOp->getResults());
    return success();
  }

private:
  // Helper to add an argument and record the mapping.
  static void addArg(IRMapping &mapper, Block *b, BlockArgument a) {
    mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
  }
};

// Fuse a tensor cast into producing operation. Note that a tensor.cast
// should really not be used to convert between sparse encodings. Since
// the pattern currently appears as a result of some prior rewriting
// we make an attempt to repair very obvious cases.
// TODO: audit the pure tensor dialect rewriting rules
struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
public:
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    Type srcType = op.getSource().getType();
    Type dstType = op.getDest().getType();
    // A nop cast simply folds away.
    if (srcType == dstType) {
      rewriter.replaceOp(op, op->getResults());
      return success();
    }
    // See if a sparsity changing cast can be fused into producer.
    if (Operation *def = op.getSource().getDefiningOp()) {
      if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
        rewriter.modifyOpInPlace(def, [&]() {
          def->getResult(0).setType(op->getResultTypes()[0]);
        });
        rewriter.replaceOp(op, def->getResult(0));
        return success();
      }
    }
    // Repair tensor casts with at least one sparse operand into the
    // properly supported sparse_tensor.convert.
    if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
      rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource());
      return success();
    }
    // Fail otherwise.
    return failure();
  }
};

/// Rewrites a sequence of operations for sparse tensor selections into
/// semi-ring operations such that they can be compiled correctly by the
/// sparsifier. E.g., transforming the following sequence
///
/// %sel = arith.select %cond, %sp1, %sp2
///
/// to
///
/// %sel = binary %sp1, %sp2:
///          both  (%l, %r) {yield select %cond, %l, %r}
///          left  (%l)     {yield select %cond, %l,  0}
///          right (%r)     {yield select %cond,  0, %r}
///
/// TODO: We require the tensor used for extracting conditions to be dense
/// to sparsify the code. To support a sparse condition tensor, we need a
/// trickier scheme.
struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Rejects non sparse kernels.
    if (!op.hasPureTensorSemantics() || !hasAnySparseOperand(op))
      return failure();

    Location loc = op.getLoc();
    SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
    for (Operation &inst : *op.getBody()) {
      // Matches pattern.
      auto matched = isRewritablePattern(op, &inst);
      if (!matched.has_value())
        continue;

      // Found a rewritable pattern.
      auto [c, t, f] = matched.value();
      assert(t.getType() == f.getType());
      auto selTp = t.getType();
      auto c0 = constantZero(rewriter, loc, selTp);
      auto binOp = rewriter.create<sparse_tensor::BinaryOp>(loc, selTp, t, f);
      // Initializes all the blocks.
      rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
                           {t.getLoc(), f.getLoc()});
      rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
      rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());

      for (auto *r : binOp.getRegions()) {
        Block *b = &r->front();
        rewriter.setInsertionPointToStart(b);

        IRMapping irMap;
        // Clones the cmp operations into the region to make the binary op
        // admissible.
        Value newC = c;
        if (auto *def = c.getDefiningOp())
          newC = rewriter.clone(*def, irMap)->getResult(0);

        irMap.map(c, newC);
        if (r == &binOp.getLeftRegion()) {
          irMap.map(t, b->getArgument(0));
          irMap.map(f, c0);
        } else if (r == &binOp.getRightRegion()) {
          irMap.map(t, c0);
          irMap.map(f, b->getArgument(0));
        } else {
          irMap.map(t, b->getArgument(0));
          irMap.map(f, b->getArgument(1));
        }
        auto y = rewriter.clone(inst, irMap)->getResult(0);
        rewriter.create<sparse_tensor::YieldOp>(loc, y);
      }

      // We successfully rewrote the operation. We cannot do the replacement
      // here, because it would invalidate the iterator used by the current
      // loop to traverse the instructions.
      semiRings.emplace_back(&inst, binOp);
    }

    // Finalizes the replacement.
    for (auto [sel, semi] : semiRings)
      rewriter.replaceOp(sel, semi->getResults());

    return success(!semiRings.empty());
  }

private:
  static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
  isRewritablePattern(GenericOp op, Operation *v) {
    auto sel = dyn_cast<arith::SelectOp>(v);
    if (!sel)
      return std::nullopt;

    auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
    auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
    // TODO: For simplicity, we only handle cases where both true/false values
    // are directly loaded from the input tensors. We can probably admit more
    // cases in theory.
    if (!tVal || !fVal)
      return std::nullopt;

    // Helper lambda to determine whether the value is loaded from a dense
    // input or is a loop invariant.
    auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
      if (auto bArg = dyn_cast<BlockArgument>(v);
          bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
        return true;
      // If the value is defined outside the loop, it is a loop invariant.
      return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
    };

    // If the condition value is load directly from a dense tensor or
    // loop invariants, we can sparsify the kernel.
    auto cond = sel.getCondition();
    if (isValFromDenseInputOrInvariant(cond))
      return std::make_tuple(cond, tVal, fVal);

    // Otherwise, we can not sparsify the kernel unless the condition value is
    // an invariant cmp operation on dense inputs or invariants.
    Value cmpL, cmpR;
    if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
                                               matchers::m_Any(&cmpR))) ||
        matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
                                               matchers::m_Any(&cmpR)))) {
      // TODO: we could do this recursively to check whether all the leaf
      // values are loaded from dense tensors or are loop invariants.
      if (isValFromDenseInputOrInvariant(cmpL) ||
          isValFromDenseInputOrInvariant(cmpR))
        return std::make_tuple(cond, tVal, fVal);
    }

    return std::nullopt;
  };
};

/// Rewrites a sparse reduction that would not sparsify directly, since
/// doing so would only iterate over the stored elements while ignoring the
/// implicit zeros, into a semi-ring. Applies to prod/and/min/max style
/// reductions (for which the implicit zeros may change the result), by
/// lowering the reduction into a sparse_tensor.unary followed by a
/// sparse_tensor.reduce with an explicit identity value taken from the
/// initialization tensor.
struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Reject non-reductions.
    if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
        op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
      return failure();
    auto *inp = op.getDpsInputOperand(0);
    auto *init = op.getDpsInitOperand(0);
    if (!isSparseTensor(inp))
      return failure();
    // Look for direct x = x OP y for semi-ring ready reductions.
    auto *red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator())
                    .getOperand(0)
                    .getDefiningOp();
    if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
             arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
             arith::MaxUIOp>(red))
      return failure();
    Value s0 = op.getBlock()->getArgument(0);
    Value s1 = op.getBlock()->getArgument(1);
    if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
        (red->getOperand(0) != s1 || red->getOperand(1) != s0))
      return failure();
    // Identity.
    Location loc = op.getLoc();
    Value identity =
        rewriter.create<tensor::ExtractOp>(loc, init->get(), ValueRange());
    // Unary {
    //    present -> value
    //    absent  -> zero.
    // }
    Type rtp = s0.getType();
    rewriter.setInsertionPointToStart(&op.getRegion().front());
    auto semiring = rewriter.create<sparse_tensor::UnaryOp>(loc, rtp, s0);
    Block *present =
        rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
    rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front());
    rewriter.create<sparse_tensor::YieldOp>(loc, present->getArgument(0));
    rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
    rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front());
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(rtp));
    rewriter.create<sparse_tensor::YieldOp>(loc, zero);
    rewriter.setInsertionPointAfter(semiring);
    // Reduce {
    //    x = x REDUC y, identity
    // }
    auto custom = rewriter.create<sparse_tensor::ReduceOp>(
        loc, rtp, semiring.getResult(), s1, identity);
    Block *region =
        rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
    rewriter.setInsertionPointToStart(&custom.getRegion().front());
    IRMapping irMap;
    irMap.map(red->getOperand(0), region->getArgument(0));
    irMap.map(red->getOperand(1), region->getArgument(1));
    auto *cloned = rewriter.clone(*red, irMap);
    rewriter.create<sparse_tensor::YieldOp>(loc, cloned->getResult(0));
    rewriter.setInsertionPointAfter(custom);
    rewriter.replaceOp(red, custom.getResult());
    return success();
  }
};
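// For illustration (not taken from this file): a product reduction over a
// sparse vector is rewritten so that stored values pass through a unary op
// (whose absent branch yields 0) and are folded by a custom reduce whose
// identity is read from the init tensor:
//
//   %u = sparse_tensor.unary %s0 : f64 to f64
//          present = { ^bb0(%x: f64): sparse_tensor.yield %x : f64 }
//          absent  = { %cst = arith.constant 0.0 : f64
//                      sparse_tensor.yield %cst : f64 }
//   %r = sparse_tensor.reduce %u, %s1, %identity : f64 { ... }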

/// Sparse rewriting rule for the print operator. This operation is mainly used
/// for debugging and testing. As such, it lowers to the vector.print operation
/// which only requires very light-weight runtime support.
struct PrintRewriter : public OpRewritePattern<PrintOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(PrintOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto tensor = op.getTensor();
    auto stt = getSparseTensorType(tensor);
    // Header with NSE.
    auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
    rewriter.create<vector::PrintOp>(
        loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
    rewriter.create<vector::PrintOp>(loc, nse);
    // Print run-time contents for dim/lvl sizes.
    rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("dim = "));
    printSizes(rewriter, loc, tensor, stt.getDimRank(), /*isDim=*/true);
    rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("lvl = "));
    printSizes(rewriter, loc, tensor, stt.getLvlRank(), /*isDim=*/false);
    // Use the "codegen" foreach loop construct to iterate over
    // all typical sparse tensor components for printing.
    foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
                                            &stt](Type, FieldIndex,
                                                  SparseTensorFieldKind kind,
                                                  Level l, LevelType) {
      switch (kind) {
      case SparseTensorFieldKind::StorageSpec: {
        break;
      }
      case SparseTensorFieldKind::PosMemRef: {
        auto lvl = constantIndex(rewriter, loc, l);
        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
        rewriter.create<vector::PrintOp>(
            loc, lvl, vector::PrintPunctuation::NoPunctuation);
        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
        auto pos = rewriter.create<ToPositionsOp>(loc, tensor, l);
        printContents(rewriter, loc, pos);
        break;
      }
      case SparseTensorFieldKind::CrdMemRef: {
        auto lvl = constantIndex(rewriter, loc, l);
        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("crd["));
        rewriter.create<vector::PrintOp>(
            loc, lvl, vector::PrintPunctuation::NoPunctuation);
        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
        Value crd = nullptr;
        // For COO AoS storage, we want to print a single, linear view of
        // the full coordinate storage at this level. For any other storage,
        // we show the coordinate storage for every individual level.
        if (stt.getAoSCOOStart() == l)
          crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
        else
          crd = rewriter.create<ToCoordinatesOp>(loc, tensor, l);
        printContents(rewriter, loc, crd);
        break;
      }
      case SparseTensorFieldKind::ValMemRef: {
        rewriter.create<vector::PrintOp>(loc,
                                         rewriter.getStringAttr("values : "));
        auto val = rewriter.create<ToValuesOp>(loc, tensor);
        printContents(rewriter, loc, val);
        break;
      }
      }
      return true;
    });
    rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("----\n"));
    rewriter.eraseOp(op);
    return success();
  }

private:
  // Helper to print the contents of a single memref:
  //
  //   ( a0, a1, ... )
  //
  // Higher-dimensional memrefs yield nested brackets.
  static void printContents(PatternRewriter &rewriter, Location loc,
                            Value vec) {
    auto shape = cast<ShapedType>(vec.getType()).getShape();
    SmallVector<Value> idxs;
    printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
  }

  // Helper to the helper.
  static void printContentsLevel(PatternRewriter &rewriter, Location loc,
                                 Value vec, unsigned i, ArrayRef<int64_t> shape,
                                 SmallVectorImpl<Value> &idxs) {
    // Open bracket.
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
    // Generate for loop.
    auto zero = constantIndex(rewriter, loc, 0);
    auto index = constantIndex(rewriter, loc, i);
    auto size = rewriter.create<memref::DimOp>(loc, vec, index);
    auto step = constantIndex(rewriter, loc, 1);
    auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
    idxs.push_back(forOp.getInductionVar());
    rewriter.setInsertionPointToStart(forOp.getBody());
    if (i < shape.size() - 1) {
      // Enter deeper loop nest.
      printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
    } else {
      // Actual contents printing.
      auto val = rewriter.create<memref::LoadOp>(loc, vec, idxs);
      if (llvm::isa<ComplexType>(val.getType())) {
        // Since the vector dialect does not support complex types in any op,
        // print each part separately.
        Value real = rewriter.create<complex::ReOp>(loc, val);
        Value imag = rewriter.create<complex::ImOp>(loc, val);
        rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
        rewriter.create<vector::PrintOp>(loc, real,
                                         vector::PrintPunctuation::Comma);
        rewriter.create<vector::PrintOp>(loc, imag,
                                         vector::PrintPunctuation::Close);
      } else {
        rewriter.create<vector::PrintOp>(
            loc, val, vector::PrintPunctuation::NoPunctuation);
      }
      // Terminating comma (except at end).
      auto bound = rewriter.create<arith::AddIOp>(loc, idxs.back(), step);
      Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                                  bound, size);
      scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
    }
    idxs.pop_back();
    rewriter.setInsertionPointAfter(forOp);
    // Close bracket.
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
  }

  // Helper method to print run-time lvl/dim sizes.
  static void printSizes(PatternRewriter &rewriter, Location loc, Value tensor,
                         unsigned size, bool isDim) {
    // Open bracket.
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
    // Print unrolled contents (dimop requires constant value).
    for (unsigned i = 0; i < size; i++) {
      auto idx = constantIndex(rewriter, loc, i);
      Value val;
      if (isDim)
        val = rewriter.create<tensor::DimOp>(loc, tensor, idx);
      else
        val = rewriter.create<LvlOp>(loc, tensor, idx);
      rewriter.create<vector::PrintOp>(
          loc, val,
          i != size - 1 ? vector::PrintPunctuation::Comma
                        : vector::PrintPunctuation::NoPunctuation);
    }
    // Close bracket and end of line.
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
  }
};
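// For illustration (not taken from this file): for a CSR matrix, the lowered
// print produces output along the lines of
//
//   ---- Sparse Tensor ----
//   nse = 3
//   dim = (4, 8)
//   lvl = (4, 8)
//   pos[1] : (0, 1, 3, 3, 3)
//   crd[1] : (2, 0, 7)
//   values : (1, 2, 3)
//   ----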

/// Sparse rewriting rule for the tensor reshape operator.
struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
public:
  using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value srcTensor = op.getSource();
    const auto srcTp = tryGetSparseTensorType(srcTensor);
    const auto dstTp = tryGetSparseTensorType(op.getResult());
    if (!srcTp || !dstTp)
      return failure();

    if (!srcTp->hasEncoding() || !dstTp->hasEncoding() ||
        !dstTp->hasStaticDimShape())
      return failure();

    SmallVector<Value> srcSizes;
    sizesForTensor(rewriter, srcSizes, loc, *srcTp, srcTensor);
    SmallVector<Value> dstSizes;
    for (Dimension d : dstTp->getDimShape())
      dstSizes.push_back(constantIndex(rewriter, loc, d));

    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
    // Only need an unordered COO buffer if input and output are not sorted
    // in the same way.
    Type bufferTp = getBufferType(
        dstTp->withoutDimToLvl(),
        !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
    SmallVector<Value> dynSizes;
    Value buffer = rewriter
                       .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
                                              nnz, Attribute())
                       .getResult();

    // Convert src coordinates to dst coordinates by first collapsing to 1D
    // and then expanding to match the rank of the destination tensor.
    // Implemented as follows:
    //   foreach srcCoords %srcTensor
    //     collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank])
    //     expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank])
    //     insert expandedCoords, %buffer
    //
    // followed by an optional
    //   %t = sparse_tensor.convert %tmp
    // depending on whether the input/output are sorted in the same way.
    const auto encSrc = srcTp->getEncoding();
    ForeachOp foreachOp = rewriter.create<ForeachOp>(
        loc, srcTensor, buffer,
        [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
            ValueRange reduc) {
          const Dimension srcRank = srcTp->getDimRank();
          SmallVector<Value> srcDcvs;
          srcDcvs.reserve(srcRank);
          for (Dimension d = 0; d < srcRank; d++) {
            Level lvl = toLvl(encSrc, d);
            srcDcvs.push_back(srcLcvs[lvl]);
          }

          Value collapseSize = constantIndex(builder, loc, 1);
          for (Dimension d = 0; d < srcRank; d++)
            collapseSize =
                builder.create<arith::MulIOp>(loc, collapseSize, srcSizes[d]);
          SmallVector<Value, 1> collapsedSizes = {collapseSize};

          ReassociationIndices collapseIdx;
          for (Dimension i = 0; i < srcRank; i++)
            collapseIdx.push_back(i);
          SmallVector<ReassociationIndices, 1> collapseReass = {collapseIdx};
          SmallVector<Value, 1> collapsedDcvs;
          reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
                     collapsedSizes, collapsedDcvs);

          ReassociationIndices expandIdx;
          for (Dimension i = 0; i < dstTp->getDimRank(); i++)
            expandIdx.push_back(i);
          SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
          SmallVector<Value> dstDcvs;
          reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
                     dstSizes, dstDcvs);

          auto t =
              builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
          builder.create<sparse_tensor::YieldOp>(loc, t);
        });

    Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
    if (bufferTp != *dstTp) {
      auto dstRTT = dstTp->getRankedTensorType();
      Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
      rewriter.create<DeallocTensorOp>(loc, t);
      t = converted;
    }
    rewriter.replaceOp(op, t);
    return success();
  }
};

/// Sparse rewriting rule for sparse-to-sparse reshape operator.
template <typename ReshapeOp>
struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value srcTensor = op.getSrc();
    const auto srcTp = getSparseTensorType(srcTensor);
    const auto dstTp = getSparseTensorType(op.getResult());
    if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
      return failure();

    // Generate code to represent the static dimension constants or compute
    // the dynamic dimension values.
    SmallVector<Value> srcSizes;
    sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
    SmallVector<Value> dstSizes;
    SmallVector<Value> dstDynSizes;
    if (dstTp.hasStaticDimShape()) {
      for (Dimension d : dstTp.getDimShape())
        dstSizes.push_back(constantIndex(rewriter, loc, d));
    } else {
      ArrayRef<Size> dstShape = dstTp.getDimShape();
      genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape,
                         op.getReassociationIndices());
      for (auto [idx, shape] : llvm::enumerate(dstShape)) {
        if (shape == ShapedType::kDynamic)
          dstDynSizes.push_back(dstSizes[idx]);
      }
    }
    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
    // Only need an unordered COO buffer if input and output are not sorted
    // in the same way.
    Type bufferTp = getBufferType(
        dstTp.withoutDimToLvl(),
        !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());

    Value buffer =
        rewriter
            .create<AllocTensorOp>(loc, bufferTp, dstDynSizes, Value(),
                                   /*sizeHint=*/nnz, Attribute())
            .getResult();

    // Implement the sparse2sparse reshape as follows:
    //   foreach srcCoords %srcTensor
    //     insert reshapeCvs(srcCoords), %buffer
    //
    // followed by an optional
    //   %t = sparse_tensor.convert %tmp
    // depending on whether the input/output are sorted in the same way.
    const auto encSrc = srcTp.getEncoding();
    ForeachOp foreachOp = rewriter.create<ForeachOp>(
        loc, srcTensor, buffer,
        [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
            ValueRange reduc) {
          const Dimension dimRank = srcTp.getDimRank();
          SmallVector<Value> srcDcvs;
          srcDcvs.reserve(dimRank);
          for (Dimension d = 0; d < dimRank; d++) {
            Level lvl = toLvl(encSrc, d);
            srcDcvs.push_back(srcLcvs[lvl]);
          }
          SmallVector<Value> dstDcvs;
          reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
                     srcDcvs, dstSizes, dstDcvs);
          auto t =
              builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
          builder.create<sparse_tensor::YieldOp>(loc, t);
        });

    Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
    if (bufferTp != dstTp) {
      auto dstRTT = dstTp.getRankedTensorType();
      Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
      rewriter.create<DeallocTensorOp>(loc, t);
      t = converted;
    }
    rewriter.replaceOp(op, t);
    return success();
  }
};

/// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape
/// operator.
template <typename ReshapeOp>
struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto encDst = getSparseTensorEncoding(op.getResult().getType());
    auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
    // Since a pure dense expansion is very cheap (change of view), for
    // a sparse2dense or dense2sparse, we can simply unfuse a sparse
    // conversion from the reshape operation itself.
    // All other cases are handled elsewhere.
    if (encDst && encSrc) {
      return failure();
    }
    if (encSrc) {
      auto rtp = getRankedTensorType(op.getSrc());
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
      rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
      return success();
    }
    if (encDst) {
      auto rtp = getRankedTensorType(op.getResult());
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      ReshapeOp reshape;
      if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
        reshape = rewriter.create<ReshapeOp>(
            loc, denseTp, op.getSrc(), op.getReassociation(),
            op.getOutputShape(), op.getStaticOutputShape());
      } else {
        reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
                                             op.getReassociation());
      }
      Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
      rewriter.replaceOp(op, convert);
      return success();
    }
    return failure();
  }
};

// A trivial wrapper to help generate different operations for dense/sparse
// tensors.
struct TensorLike {
  TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
             ValueRange sizes) {
    SmallVector<Value> dynSzs;
    getDynamicSizes(rtt, sizes, dynSzs);

    val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
    if (!isSparse()) {
      Value c0 = constantZero(builder, loc, rtt.getElementType());
      val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
    }
  }

  void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
    val = builder.create<tensor::InsertOp>(loc, v, val, crds);
  }

  Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
    if (isSparse())
      return builder.create<LoadOp>(loc, val, true);
    return val;
  }

  bool isSparse() const {
    return getSparseTensorEncoding(val.getType()) != nullptr;
  }

  Value val;
};
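// Usage note (illustrative): TensorLike lets the rewriters below build one
// insertion loop for both dense and sparse destinations. Dense buffers are
// zero-filled up front, while sparse buffers chain tensor.insert values and
// are finalized with a sparse_tensor.load that ends the insertion chain.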

1138

1139 struct SparseTensorDimOpRewriter : public OpRewritePatterntensor::DimOp {

1141 LogicalResult matchAndRewrite(tensor::DimOp op,

1143 std::optional<int64_t> dim = op.getConstantIndex();

1145 if (!dim || !stt || !stt->hasEncoding())

1146 return failure();

1147

1148 if (stt->isPermutation()) {

1150 toLvl(stt->getEncoding(), *dim));

1151 return success();

1152 }

1153

1154

1155

1156

1157

1158

1159

1162 for (Level l = 0; l < stt->getLvlRank(); l++) {

1163 Value lvlSz = rewriter.create(loc, op.getSource(), l);

1164 Value maxLvlCrd = rewriter.createarith::SubIOp(

1166 maxLvlCrds.push_back(maxLvlCrd);

1167 }

1168

1169 AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim);

1170 Value maxDimCrd = rewriter.createaffine::AffineApplyOp(

1171 op.getLoc(), AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp),

1172 maxLvlCrds);

1173

1174 Value dimSz = rewriter.createarith::AddIOp(

1177 return success();

1178 }

1179 };
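// For illustration (not taken from this file): for a non-permutation
// lvlToDim map such as (i, j, k, l) -> (i * 2 + k, j * 3 + l) (a 2x3 block
// layout), the size of dim 0 is recovered as
//
//   affine.apply affine_map<(i, j, k, l) -> (i * 2 + k)>(%m0, %m1, %m2, %m3) + 1
//
// where each %mN is the corresponding level size minus one.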

struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ConcatenateOp op,
                                PatternRewriter &rewriter) const override {
    if (op.needsExtraSort())
      op.emitError("ConcatenateOp not staged");

    const Location loc = op.getLoc();
    const auto dstTp = getSparseTensorType(op);
    const Dimension conDim = op.getDimension();
    SmallVector<Value> sizes;
    concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);

    // %t = concatenate %s1, %s2, %s3 {dim = 1}
    // ==>
    // if (isSparseDst)
    //   if (allDense)
    //     %tmp = bufferization.alloc_tensor dstTp
    //   else
    //     %tmp = bufferization.alloc_tensor : unordered COO
    // else
    //   %tmp = memref.alloc : dense tensor
    // foreach in %s1 : insert d0, d1, %tmp
    // foreach in %s2 : insert d0, d1 + size(s1), %tmp
    // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp

    TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
    Value offset = constantIndex(rewriter, loc, 0);
    Value iterArg = dstBuf.val;

    ForeachOp foreachOp;
    for (Value input : op.getInputs()) {
      // Builds a for op for each input tensor to append new values into the
      // output tensor.
      foreachOp = rewriter.create<ForeachOp>(
          loc, input, iterArg,
          [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
              ValueRange reduc) {
            SmallVector<Value> offDimCrd(dcvs);
            offDimCrd[conDim] =
                builder.create<arith::AddIOp>(loc, offDimCrd[conDim], offset);

            // Enters foreach, updates the SSA chain.
            dstBuf.val = reduc.front();
            if (!dstTp.isAllDense()) {
              Value cond = genIsNonzero(builder, loc, v);
              auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                    /*else*/ true);
              builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
              builder.create<scf::YieldOp>(loc, dstBuf.val);

              builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
              dstBuf.insert(builder, loc, v, offDimCrd);
              builder.create<scf::YieldOp>(loc, dstBuf.val);

              // Exits the ifOp, update the sparse tensor SSA value.
              builder.setInsertionPointAfter(ifOp);
              dstBuf.val = ifOp.getResult(0);
            } else {
              dstBuf.insert(builder, loc, v, offDimCrd);
            }
            builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
          });

      // Accumulates the offset. Note that only static-shaped inputs are
      // allowed by the concatenate op verifier, which saves us from
      // computing the offset dynamically.
      const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim);
      assert(!ShapedType::isDynamic(sz));
      offset = rewriter.create<arith::AddIOp>(loc, offset,
                                              constantIndex(rewriter, loc, sz));
      iterArg = foreachOp.getResult(0);
      dstBuf.val = iterArg;
    }

    dstBuf.val = iterArg;
    Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
    rewriter.replaceOp(op, ret);
    return success();
  }
};

struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ConvertOp op,
                                PatternRewriter &rewriter) const override {
    if (op.needsExtraSort())
      return op.emitError("ConvertOp not staged.");

    // TODO: Maybe we want a different operation for this too.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (encDst && encSrc && !encSrc.isSlice() &&
        encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
      // Trivial tensor conversion and simple element type conversion is
      // handled in codegen.
      return failure();
    }

    Location loc = op.getLoc();
    Value src = op.getSource();

    SparseTensorType srcStt = getSparseTensorType(op.getSource());
    SparseTensorType dstStt = getSparseTensorType(op.getDest());

    bool fromSparseConst = false;
    if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
      if (isa<SparseElementsAttr>(constOp.getValue()))
        fromSparseConst = true;

    const AffineMapAttr foreachOrder =
        (!dstStt.isIdentity() && fromSparseConst)
            ? AffineMapAttr::get(dstStt.getExpandedDimToLvl())
            : nullptr;

    bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst;

    SmallVector<Value> sizes;
    sizesFromSrc(rewriter, sizes, loc, src);
    TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);

    auto foreachOp = rewriter.create<ForeachOp>(
        loc, src, dstBuf.val, foreachOrder,
        [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
            ValueRange reduc) {
          // Enters the loop, update the SSA value for insertion chain.
          dstBuf.val = reduc.front();
          if (!skipZeroCheck) {
            Value cond = genIsNonzero(builder, loc, v);
            auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                  /*else*/ true);
            builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
            builder.create<scf::YieldOp>(loc, dstBuf.val);

            builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
            dstBuf.insert(builder, loc, v, dcvs);
            builder.create<scf::YieldOp>(loc, dstBuf.val);

            // Exits the ifOp, update the sparse tensor SSA value.
            builder.setInsertionPointAfter(ifOp);
            dstBuf.val = ifOp.getResult(0);
          } else {
            dstBuf.insert(builder, loc, v, dcvs);
          }
          builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
        });

    rewriter.setInsertionPointAfter(foreachOp);

    // Exits the for loop, links the SSA chain.
    dstBuf.val = foreachOp.getResult(0);

    Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
    rewriter.replaceOp(op, ret);
    return success();
  }
};

struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(CrdTranslateOp op,
                                PatternRewriter &rewriter) const override {
    AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
                        ? op.getEncoder().getDimToLvl()
                        : op.getEncoder().getLvlToDim();

    SmallVector<Value> outCrds;
    for (AffineExpr result : map.getResults()) {
      // TODO: we should probably expand the affine map to IR using our own
      // rules, since affine.apply assumes signed values, while the
      // coordinates we provide must always be signless.
      Value trans = rewriter.create<affine::AffineApplyOp>(
          op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
          op.getInCrds());
      outCrds.push_back(trans);
    }
    rewriter.replaceOp(op, outCrds);
    return success();
  }
};
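// For illustration (not taken from this file): with the CSC-style dimToLvl
// map (d0, d1) -> (d1, d0), a dim2lvl translation of coordinates (%i, %j)
// lowers to two affine.apply ops that simply swap the operands, which later
// folding typically removes altogether.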

/// Sparse rewriting rule for the foreach operator.
struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ForeachOp op,
                                PatternRewriter &rewriter) const override {

    auto loc = op.getLoc();
    Value input = op.getTensor();
    SmallVector<Value> reduc = op.getInitArgs();
    const auto stt = getSparseTensorType(input);
    const Level lvlRank = stt.getLvlRank();

    // Special-case: a foreach over a sparse constant uses its own rewriting
    // rule.
    if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
        return genForeachOnSparseConstant(op, rewriter, attr);
      }
    }

    // Otherwise, use the loop emitter to generate loops.
    const auto enc = stt.getEncoding();

    // 1. Generates a loop for the sparse input.
    LoopEmitter loopEmitter(
        ValueRange{input},
        StringAttr::get(getContext(), ForeachOp::getOperationName()));
    loopEmitter.initializeLoopEmit(rewriter, loc);
    for (Level l = 0; l < lvlRank; l++) {
      // TODO: provide utility function for loop sequences that only contain
      // one for loop?
      const SmallVector<TensorLevel, 1> tidLvls{
          loopEmitter.makeTensorLevel(0, l)};
      loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
      // Note that reduc will be taken care of by the loop emitter and gets
      // updated in place.
      loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, 1,
                                                    reduc);
    }

    SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
    if (op.getOrder()) {
      // TODO: Support it so that we can do direct, yet efficient, sparse
      // tensor output.
      llvm_unreachable(
          "Level order not yet implemented on non-constant input tensors.");
    }

    Value vals = loopEmitter.getValBuffer()[0];
    SmallVector<Value> pos = loopEmitter.getValPosits(0);
    // Loads the value from the sparse tensor using the position index;
    // loads the value from the dense tensor using coordinates.
    Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pos)
                    : rewriter.create<memref::LoadOp>(loc, vals, lcvs);

    // 2. Inline the block in the foreach operator.
    Block *srcBlock = op.getBody();

    // Remap coordinates.
    SmallVector<Value> args =
        enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);

    // Remap value.
    args.push_back(val);
    // Remap reduction variables.
    args.append(reduc);

    // Remove sparse_tensor.yield.
    SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
    rewriter.eraseOp(srcBlock->getTerminator());

    Operation &last = rewriter.getBlock()->back();
    if (llvm::isa<scf::YieldOp>(last)) {
      // Because `scf.for` inserts an implicit yield op when there is no
      // reduction variable upon creation, we reset the insertion point such
      // that the block is inlined right before the yield op.
      rewriter.setInsertionPoint(&last);
    }

    rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(),
                               rewriter.getInsertionPoint(), args);
    rewriter.setInsertionPointToEnd(rewriter.getBlock());
    for (Level l = 0; l < lvlRank; l++) {
      // Link the reduction chain. Note that the loop emitter updates the
      // reducValue in place.
      loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
      loopEmitter.exitCurrentLoopSeq(rewriter, loc);
    }

    // Replace the foreach operator with the value returned by the outermost
    // for loop.
    rewriter.replaceOp(op, reducValue);
    return success();
  }
};

/// Sparse rewriting rule for the new operator.
struct NewRewriter : public OpRewritePattern<NewOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(NewOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
      return failure();

    // Implement the NewOp as follows:
    //   %orderedCoo = sparse_tensor.new %filename
    //   %t = sparse_tensor.convert %orderedCoo
    // with enveloping reinterpreted_map ops for non-permutations.
    RankedTensorType dstTp = stt.getRankedTensorType();
    RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
    Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
    Value convert = cooTensor;
    auto enc = stt.getEncoding();
    if (!stt.isPermutation()) { // demap coo, demap dstTp
      auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
      convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
      dstTp = getSparseTensorType(convert)
                  .withEncoding(enc.withoutDimToLvl())
                  .getRankedTensorType();
    }
    convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
    if (!stt.isPermutation()) // remap to original enc
      convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
    rewriter.replaceOp(op, convert);

    // Release the temporary ordered COO tensor.
    rewriter.setInsertionPointAfterValue(convert);
    rewriter.create<DeallocTensorOp>(loc, cooTensor);

    return success();
  }
};

/// Sparse rewriting rule for the out operator.
struct OutRewriter : public OpRewritePattern<OutOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(OutOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Calculate NNZ.
    Value src = op.getTensor();
    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);

    // Allocate a temporary buffer for storing dimension-sizes/coordinates.
    const auto srcTp = getSparseTensorType(src);
    const Dimension dimRank = srcTp.getDimRank();
    Type indexTp = rewriter.getIndexType();
    Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);

    // Generate code to calculate dimension size values and store the values
    // to the buffer.
    SmallVector<Value> dims;
    sizesForTensor(rewriter, dims, loc, srcTp, src);
    for (Dimension d = 0; d < dimRank; d++) {
      rewriter.create<memref::StoreOp>(loc, dims[d], dimSizes,
                                       constantIndex(rewriter, loc, d));
    }

    // Create a sparse tensor writer and output meta data.
    Type opaqueTp = getOpaquePointerType(rewriter.getContext());
    Value writer =
        createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
                       {op.getDest()}, EmitCInterface::Off)
            .getResult(0);
    Value rankValue = constantIndex(rewriter, loc, dimRank);
    createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
                   {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);

    Value dimCoords = dimSizes; // Reuse the dimSizes buffer for dimCoords.
    Type eltTp = srcTp.getElementType();
    SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
                                    primaryTypeFunctionSuffix(eltTp)};
    Value value = genAllocaScalar(rewriter, loc, eltTp);
    ModuleOp module = op->getParentOfType<ModuleOp>();

    // For each element in the source tensor, output the element.
    rewriter.create<ForeachOp>(
        loc, src, std::nullopt,
        [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
            ValueRange reduc) {
          for (Dimension d = 0; d < dimRank; d++) {
            rewriter.create<memref::StoreOp>(loc, dcvs[d], dimCoords,
                                             constantIndex(builder, loc, d));
          }
          rewriter.create<memref::StoreOp>(loc, v, value);
          SmallVector<Value> operands{writer, rankValue, dimCoords, value};
          FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
                                         EmitCInterface::On);
          builder.create<func::CallOp>(loc, TypeRange(), fn, operands);
          builder.create<sparse_tensor::YieldOp>(loc);
        });

    // Release the writer.
    createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
                   EmitCInterface::Off);

    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//

void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
  patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
               FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
               GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
      patterns.getContext());
}

void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
                                                   bool enableRT,
                                                   bool enableConvert) {
  patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
               ReshapeRewriter<tensor::CollapseShapeOp>,
               Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
               Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
               SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
      patterns.getContext());
  if (enableConvert)
    patterns.add<DirectConvertRewriter>(patterns.getContext());
  if (!enableRT)
    patterns.add<NewRewriter>(patterns.getContext());
}

void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
  // Run CrdTranslateRewriter later in the pipeline so that the operation
  // can be folded before lowering to affine.apply.
  patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());
}
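// A minimal sketch (not part of this file) of how a custom pass might apply
// the pre-sparsification patterns above; the surrounding pass class is
// hypothetical, while the populate function and greedy driver are real APIs:
//
//   void runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     populatePreSparsificationRewriting(patterns);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }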
