MLIR: lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
15
32
33 using namespace mlir;
37
38
39
40
41
42
45 }
46
47
50 return enc && !llvm::all_of(enc.getLvlTypes(),
51 [](auto lt) { return lt == LevelFormat::Dense; });
52 }
54
55
58
59 if (auto alloc = val.getDefiningOp()) {
61 if (isZero)
64 }
65
66 if (auto empty = val.getDefiningOptensor::EmptyOp())
67 return !isZero;
68
70 }
71
72
74 auto yieldOp = castlinalg::YieldOp(op.getRegion().front().getTerminator());
75 if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
76 if (isaarith::MulFOp(def) || isaarith::MulIOp(def)) {
77
78 Value s1 = op.getBlock()->getArgument(0);
79 Value s2 = op.getBlock()->getArgument(1);
80 return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
81 (def->getOperand(1) == s1 && def->getOperand(0) == s2);
82 }
83 }
84 return false;
85 }
86
87
89 if (auto arg = dyn_cast(val))
90 return arg != x;
92 if (isaarith::MulFOp(def) || isaarith::MulIOp(def))
93 return isMulChain(def->getOperand(0), x) &&
95 }
96 return false;
97 }
98
99
101 auto yieldOp = castlinalg::YieldOp(op.getRegion().front().getTerminator());
102 if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
103 if (isaarith::AddFOp(def) || isaarith::AddIOp(def)) {
104 Value x = op.getBlock()->getArguments().back();
105 return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
106 (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
107 }
108 }
109 return false;
110 }
111
112
114 auto yieldOp = castlinalg::YieldOp(op.getRegion().front().getTerminator());
115 if (auto arg = dyn_cast(yieldOp.getOperand(0))) {
116 if (arg.getOwner()->getParentOp() == op) {
117 return isZeroValue(op->getOperand(arg.getArgNumber()));
118 }
119 }
120 return isZeroValue(yieldOp.getOperand(0));
121 }
122
123
124
127 for (const auto &d : enumerate(stp.getShape())) {
129 if (d.value() == ShapedType::kDynamic)
130 dim = builder.createtensor::DimOp(loc, tensor, d.index());
131 else
133 sizes.push_back(dim);
134 }
135 }
136
138 bool needTmpCOO) {
139 return needTmpCOO ? stt.getCOOType(false)
141 }
142
143
144
145
148 for (const auto &d : enumerate(tp.getShape())) {
149 if (d.value() == ShapedType::kDynamic)
150 dynSizes.push_back(sizes[d.index()]);
151 }
152 }
153
156 SparseElementsAttr attr) {
157 auto loc = op.getLoc();
159
160
162 rewriter, loc, attr, op.getOrder().value_or(AffineMap()),
165 args.append(cvs.begin(), cvs.end());
166 args.push_back(v);
167 args.append(reduc);
168
169 auto cloned = cast(rewriter.clone(*op.getOperation()));
170 assert(args.size() == cloned.getBody()->getNumArguments());
171 Operation *yield = cloned.getBody()->getTerminator();
173
174 rewriter.eraseOp(cloned);
175 reduc = yield->getOperands();
177 });
178
180 return success();
181 }
182
183
184
188 unsigned dim) {
189 auto dstShape = dstTp.getShape();
191
192
193 if (dstShape[dim] != ShapedType::kDynamic) {
194
195 sizes[dim] = constantIndex(builder, loc, dstShape[dim]);
196 } else {
197
198 for (const auto &src : srcs.drop_front()) {
200
201 sizes[dim] = builder.createarith::AddIOp(loc, sizes[dim], srcSz);
202 }
203 }
204 }
205
206
207
208
209
210 namespace {
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
// Folds tensor.extract_slice(tensor.concat(...)) to the concat input that the
// slice covers exactly. For each concat input, the input's sizes, offsets
// (offset along `dim` taken from `partialSums`, presumably the running sum of
// the preceding inputs' sizes along `dim` -- built from tensor.dim ops) and
// unit strides are compared against the slice's. When all three match and the
// input's type equals the slice result type, the slice is replaced by that
// input.
226 struct FuseExtractSliceWithConcat
229
230 LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
232 auto concatOp = extractOp.getSource().getDefiningOptensor::ConcatOp();
233 if (!concatOp)
234 return failure();
235
236 Location loc = extractOp.getLoc();
237 int64_t dim = concatOp.getDim();
238 int64_t rank = extractOp.getResultType().getRank();
239
242
243
// Per-input offsets along `dim`: dim sizes of the inputs are collected and
// combined through `partialSumMap` (construction partly on hidden lines).
247 for (auto [idx, input] :
250 partialSums.push_back(sum);
251 offsetStrides.push_back(
252 rewriter.createOrFoldtensor::DimOp(loc, input, dim));
253 }
254 auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
258 rewriter, loc, partialSumMap, offsetStrides);
259
// Element-wise comparison of two OpFoldResult lists (per-element test is on
// hidden lines 262-263); lists of different length compare unequal.
261 for (auto [l, r] : llvm::zip(lhs, rhs)) {
264 return false;
265 }
266 return lhs.size() == rhs.size();
267 };
268
269 for (auto [i, input, offset] :
273 srcOffsets[dim] = offset;
274
278
279 if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
280 allEqual(srcStrides, dstStrides)) {
281 Value operand = concatOp.getOperand(i);
// NOTE(review): when the geometry matches but the type does not, the loop
// still breaks without performing any replacement.
282 if (operand.getType() == extractOp.getResultType())
283 rewriter.replaceOp(extractOp, operand);
284 break;
285 }
286 }
287
// NOTE(review): success() is returned even when no input was substituted.
// The createOrFold calls above may already have materialized dim/affine ops,
// so a no-op "success" can make the greedy driver re-apply this pattern
// repeatedly and fail to converge. Consider deciding the match before
// building IR and returning failure() when nothing was replaced -- TODO
// confirm against the hidden lines before changing.
288 return success();
289 }
290 };
291
292
// Folds a sparse_tensor.convert into the destination-passing-style op that
// produces its source. The producer must have a single init and a result
// whose only use is this convert. The producer's inits are replaced with
// `cloned` results and its result is retyped to the convert's result type.
// After that, the convert op is removed.
293 struct FoldConvertIntoProducer : public OpRewritePattern {
294 public:
296
297 LogicalResult matchAndRewrite(ConvertOp op,
299 auto producer = op.getSource().getDefiningOp();
300 if (!producer || producer.getDpsInits().size() != 1 ||
301 (producer.getDpsInitOperand(0), false) ||
302 !producer.getResult(0).hasOneUse()) {
303 return failure();
304 }
305
307 Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp();
310
312 producer.getDpsInitsMutable().assign(cloned->getResults());
313 producer.getResult(0).setType(op.getResult().getType());
314 });
315
// Erase through the rewriter rather than Operation::erase(). A direct erase
// bypasses the PatternRewriter listener, so the driver is never told the op
// is gone and may later visit a dangling pointer. The precondition (no
// remaining uses) is the same as for Operation::erase().
// NOTE(review): assumes the hidden line 298 names the rewriter parameter
// `rewriter`, as the other patterns in this file do.
317 rewriter.eraseOp(op);
318
319 return success();
320 }
321 };
322
323
324 struct FoldInvariantYield : public OpRewritePattern {
325 public:
327
328 LogicalResult matchAndRewrite(GenericOp op,
330 if (!op.hasPureTensorSemantics() || op.getNumResults() != 1 ||
331 (op.getDpsInitOperand(0), false) ||
332 (op) || !op.getDpsInitOperand(0)->get().hasOneUse())
333 return failure();
335
336
338 rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
339 return success();
340 }
341
342 if (!outputType.hasStaticShape())
343 return failure();
344 Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
347 return success();
348 }
349 };
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366 struct FuseSparseMultiplyOverAdd : public OpRewritePattern {
367 public:
369
370 LogicalResult matchAndRewrite(GenericOp op,
372
373 if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 ||
374 op.getNumResults() != 1 ||
375 op.getNumParallelLoops() != op.getNumLoops() ||
376 !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
377 !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
378 !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
379 return failure();
380
381
382
383
384 unsigned other = 0;
386 other = 1;
388 return failure();
389
390 auto prod = dyn_cast_or_null(
391 op.getDpsInputOperand(other)->get().getDefiningOp());
392 if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 ||
393 !prod.getResult(0).hasOneUse())
394 return failure();
395
396 if ((op.getDpsInitOperand(0), false) ||
397 (prod.getDpsInitOperand(0), true) ||
399 return failure();
400
401 Location loc = prod.getLoc();
405 inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
406 fusedIndexMaps.push_back(fusedIndexMaps.back());
407
408 auto fusedOp = rewriter.create(
409 loc, op.getResult(0).getType(), inputOps, outputOps,
411 nullptr, nullptr);
412 Block &prodBlock = prod.getRegion().front();
413 Block &consBlock = op.getRegion().front();
415 Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
417 for (unsigned i = 0; i < num - 1; i++)
418 addArg(mapper, fusedBlock, prodBlock.getArgument(i));
419 addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
420 addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
421
422 auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
425 for (auto &op : prodBlock.without_terminator())
426 if (&op != acc) {
427 last = op.getResult(0);
428 rewriter.clone(op, mapper);
429 }
431 mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
433 rewriter.createlinalg::YieldOp(loc, last);
434
435
437 Value init = prod.getDpsInitOperand(0)
438 ->get()
439 .getDefiningOp()
440 .getCopy();
441 AllocTensorOp a =
442 op.getDpsInitOperand(0)->get().getDefiningOp();
443 rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
444 }
445
446
447 rewriter.replaceOp(op, fusedOp->getResults());
448 return success();
449 }
450
451 private:
452
455 }
456 };
457
458
459
460
461
462
463 struct FuseTensorCast : public OpRewritePatterntensor::CastOp {
464 public:
466
467 LogicalResult matchAndRewrite(tensor::CastOp op,
469 Type srcType = op.getSource().getType();
470 Type dstType = op.getDest().getType();
471
472 if (srcType == dstType) {
473 rewriter.replaceOp(op, op->getResults());
474 return success();
475 }
476
478 if (Operation *def = op.getSource().getDefiningOp()) {
479 if (def->hasOneUse() && isatensor::ExtractSliceOp(def)) {
481 def->getResult(0).setType(op->getResultTypes()[0]);
482 });
484 return success();
485 }
486 }
487 }
488
489
492 return success();
493 }
494
495 return failure();
496 }
497 };
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515 struct GenSemiRingSelect : public OpRewritePattern {
516 public:
518 LogicalResult matchAndRewrite(GenericOp op,
520
522 return failure();
523
526 for (Operation &inst : *op.getBody()) {
527
528 auto matched = isRewritablePattern(op, &inst);
529 if (!matched.has_value())
530 continue;
531
533 auto [c, t, f] = matched.value();
534 assert(t.getType() == f.getType());
535 auto selTp = t.getType();
536 auto c0 = constantZero(rewriter, loc, selTp);
537 auto binOp = rewriter.create<sparse_tensor::BinaryOp>(loc, selTp, t, f);
538
539 rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
540 {t.getLoc(), f.getLoc()});
541 rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
542 rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());
543
544 for (auto *r : binOp.getRegions()) {
545 Block *b = &r->front();
547
549
550
552 if (auto *def = c.getDefiningOp())
554
555 irMap.map(c, newC);
556 if (r == &binOp.getLeftRegion()) {
558 irMap.map(f, c0);
559 } else if (r == &binOp.getRightRegion()) {
560 irMap.map(t, c0);
562 } else {
565 }
567 rewriter.create<sparse_tensor::YieldOp>(loc, y);
568 }
569
570
571
572
573 semiRings.emplace_back(&inst, binOp);
574 }
575
576
577 for (auto [sel, semi] : semiRings)
578 rewriter.replaceOp(sel, semi->getResults());
579
580 return success(!semiRings.empty());
581 }
582
583 private:
584 static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
585 isRewritablePattern(GenericOp op, Operation *v) {
586 auto sel = dyn_castarith::SelectOp(v);
587 if (!sel)
588 return std::nullopt;
589
590 auto tVal = dyn_cast(sel.getTrueValue());
591 auto fVal = dyn_cast(sel.getFalseValue());
592
593
594
595 if (!tVal || !fVal)
596 return std::nullopt;
597
598
599
600 auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
601 if (auto bArg = dyn_cast(v);
602 bArg && (op.getDpsInputOperand(bArg.getArgNumber())))
603 return true;
604
605 return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
606 };
607
608
609
610 auto cond = sel.getCondition();
611 if (isValFromDenseInputOrInvariant(cond))
612 return std::make_tuple(cond, tVal, fVal);
613
614 Value cmpL, cmpR;
619
620
621 if (isValFromDenseInputOrInvariant(cmpL) ||
622 isValFromDenseInputOrInvariant(cmpR))
623 return std::make_tuple(cond, tVal, fVal);
624 }
625
626 return std::nullopt;
627 };
628 };
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
// Rewrites a single-input linalg.generic reduction whose yielded value is one
// of the binary ops listed below, applied to the two body block arguments.
// The result is a sparse_tensor.unary (present region passes the value
// through, absent region yields zero) feeding a sparse_tensor.reduce. The
// reduce uses the init tensor's scalar as identity, and its region is a clone
// of the original binary op.
647 struct GenSemiRingReduction : public OpRewritePattern {
648 public:
650
651 LogicalResult matchAndRewrite(GenericOp op,
653
654 if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
655 op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
656 return failure();
657 auto *inp = op.getDpsInputOperand(0);
658 auto *init = op.getDpsInitOperand(0);
660 return failure();
661
662 auto *red = castlinalg::YieldOp(op.getRegion().front().getTerminator())
663 .getOperand(0)
664 .getDefiningOp();
// The yielded value may be a block argument (e.g. `linalg.yield %out`), in
// which case getDefiningOp() returns null. llvm::isa<> asserts on a null
// pointer, so isa_and_nonnull<> is used to reject that case with failure().
665 if (!isa_and_nonnull<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
666 arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
667 arith::MaxUIOp>(red))
668 return failure();
669 Value s0 = op.getBlock()->getArgument(0);
670 Value s1 = op.getBlock()->getArgument(1);
671 if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
672 (red->getOperand(0) != s1 || red->getOperand(1) != s0))
673 return failure();
674
677 rewriter.createtensor::ExtractOp(loc, init->get(), ValueRange());
678
679
680
681
684 auto semiring = rewriter.create<sparse_tensor::UnaryOp>(loc, rtp, s0);
686 rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
688 rewriter.create<sparse_tensor::YieldOp>(loc, present->getArgument(0));
689 rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
691 auto zero =
692 rewriter.createarith::ConstantOp(loc, rewriter.getZeroAttr(rtp));
693 rewriter.create<sparse_tensor::YieldOp>(loc, zero);
695
696
697
698 auto custom = rewriter.create<sparse_tensor::ReduceOp>(
699 loc, rtp, semiring.getResult(), s1, identity);
701 rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
704 irMap.map(red->getOperand(0), region->getArgument(0));
705 irMap.map(red->getOperand(1), region->getArgument(1));
706 auto *cloned = rewriter.clone(*red, irMap);
707 rewriter.create<sparse_tensor::YieldOp>(loc, cloned->getResult(0));
709 rewriter.replaceOp(red, custom.getResult());
710 return success();
711 }
712 };
713
714
715
716
718 public:
720 LogicalResult matchAndRewrite(PrintOp op,
723 auto tensor = op.getTensor();
725
726 auto nse = rewriter.create(loc, tensor);
727 rewriter.createvector::PrintOp(
728 loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
729 rewriter.createvector::PrintOp(loc, nse);
730
732 printSizes(rewriter, loc, tensor, stt.getDimRank(), true);
734 printSizes(rewriter, loc, tensor, stt.getLvlRank(), false);
735
736
741 switch (kind) {
742 case SparseTensorFieldKind::StorageSpec: {
743 break;
744 }
745 case SparseTensorFieldKind::PosMemRef: {
748 rewriter.createvector::PrintOp(
749 loc, lvl, vector::PrintPunctuation::NoPunctuation);
751 auto pos = rewriter.create(loc, tensor, l);
752 printContents(rewriter, loc, pos);
753 break;
754 }
755 case SparseTensorFieldKind::CrdMemRef: {
758 rewriter.createvector::PrintOp(
759 loc, lvl, vector::PrintPunctuation::NoPunctuation);
761 Value crd = nullptr;
762
763
764
765 if (stt.getAoSCOOStart() == l)
766 crd = rewriter.create(loc, tensor);
767 else
768 crd = rewriter.create(loc, tensor, l);
769 printContents(rewriter, loc, crd);
770 break;
771 }
772 case SparseTensorFieldKind::ValMemRef: {
773 rewriter.createvector::PrintOp(loc,
775 auto val = rewriter.create(loc, tensor);
776 printContents(rewriter, loc, val);
777 break;
778 }
779 }
780 return true;
781 });
784 return success();
785 }
786
787 private:
788
789
790
791
792
793
794
797 auto shape = cast(vec.getType()).getShape();
799 printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
800 rewriter.createvector::PrintOp(loc, vector::PrintPunctuation::NewLine);
801 }
802
803
807
808 rewriter.createvector::PrintOp(loc, vector::PrintPunctuation::Open);
809
812 auto size = rewriter.creatememref::DimOp(loc, vec, index);
814 auto forOp = rewriter.createscf::ForOp(loc, zero, size, step);
815 idxs.push_back(forOp.getInductionVar());
817 if (i < shape.size() - 1) {
818
819 printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
820 } else {
821
822 auto val = rewriter.creatememref::LoadOp(loc, vec, idxs);
823 if (llvm::isa(val.getType())) {
824
825
826 Value real = rewriter.createcomplex::ReOp(loc, val);
827 Value imag = rewriter.createcomplex::ImOp(loc, val);
828 rewriter.createvector::PrintOp(loc, vector::PrintPunctuation::Open);
829 rewriter.createvector::PrintOp(loc, real,
830 vector::PrintPunctuation::Comma);
831 rewriter.createvector::PrintOp(loc, imag,
832 vector::PrintPunctuation::Close);
833 } else {
834 rewriter.createvector::PrintOp(
835 loc, val, vector::PrintPunctuation::NoPunctuation);
836 }
837
838 auto bound = rewriter.createarith::AddIOp(loc, idxs.back(), step);
839 Value cond = rewriter.createarith::CmpIOp(loc, arith::CmpIPredicate::ne,
840 bound, size);
841 scf::IfOp ifOp = rewriter.createscf::IfOp(loc, cond, false);
843 rewriter.createvector::PrintOp(loc, vector::PrintPunctuation::Comma);
844 }
845 idxs.pop_back();
847
848 rewriter.createvector::PrintOp(loc, vector::PrintPunctuation::Close);
849 }
850
851
853 unsigned size, bool isDim) {
854
855 rewriter.createvector::PrintOp(loc, vector::PrintPunctuation::Open);
856
857 for (unsigned i = 0; i < size; i++) {
860 if (isDim)
861 val = rewriter.createtensor::DimOp(loc, tensor, idx);
862 else
863 val = rewriter.create(loc, tensor, idx);
864 rewriter.createvector::PrintOp(
865 loc, val,
866 i != size - 1 ? vector::PrintPunctuation::Comma
867 : vector::PrintPunctuation::NoPunctuation);
868 }
869
870 rewriter.createvector::PrintOp(loc, vector::PrintPunctuation::Close);
871 rewriter.createvector::PrintOp(loc, vector::PrintPunctuation::NewLine);
872 }
873 };
874
875
876 struct TensorReshapeRewriter : public OpRewritePatterntensor::ReshapeOp {
877 public:
879
880 LogicalResult matchAndRewrite(tensor::ReshapeOp op,
883 Value srcTensor = op.getSource();
886 if (!srcTp || !dstTp)
887 return failure();
888
889 if (!srcTp->hasEncoding() || !dstTp->hasEncoding() ||
890 !dstTp->hasStaticDimShape())
891 return failure();
892
894 sizesForTensor(rewriter, srcSizes, loc, *srcTp, srcTensor);
896 for (Dimension d : dstTp->getDimShape())
897 dstSizes.push_back(constantIndex(rewriter, loc, d));
898
899 Value nnz = rewriter.create(loc, srcTensor);
900
901
903 dstTp->withoutDimToLvl(),
904 !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
906 Value buffer = rewriter
907 .create(loc, bufferTp, dynSizes, Value(),
909 .getResult();
910
911
912
913
914
915
916
917
918
919
920
921
922 const auto encSrc = srcTp->getEncoding();
923 ForeachOp foreachOp = rewriter.create(
924 loc, srcTensor, buffer,
927 const Dimension srcRank = srcTp->getDimRank();
929 srcDcvs.reserve(srcRank);
930 for (Dimension d = 0; d < srcRank; d++) {
932 srcDcvs.push_back(srcLcvs[lvl]);
933 }
934
936 for (Dimension d = 0; d < srcRank; d++)
937 collapseSize =
938 builder.createarith::MulIOp(loc, collapseSize, srcSizes[d]);
940
942 for (Dimension i = 0; i < srcRank; i++)
943 collapseIdx.push_back(i);
946 reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
947 collapsedSizes, collapsedDcvs);
948
950 for (Dimension i = 0; i < dstTp->getDimRank(); i++)
951 expandIdx.push_back(i);
954 reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
955 dstSizes, dstDcvs);
956
957 auto t =
958 builder.createtensor::InsertOp(loc, v, reduc.front(), dstDcvs);
959 builder.create<sparse_tensor::YieldOp>(loc, t);
960 });
961
962 Value t = rewriter.create(loc, foreachOp.getResult(0), true);
963 if (bufferTp != *dstTp) {
964 auto dstRTT = dstTp->getRankedTensorType();
965 Value converted = rewriter.create(loc, dstRTT, t).getResult();
966 rewriter.create(loc, t);
967 t = converted;
968 }
970 return success();
971 }
972 };
973
974
975 template
976 struct Sparse2SparseReshapeRewriter : public OpRewritePattern {
977 public:
979
980 LogicalResult matchAndRewrite(ReshapeOp op,
983 Value srcTensor = op.getSrc();
986 if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
987 return failure();
988
989
990
992 sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
995 if (dstTp.hasStaticDimShape()) {
996 for (Dimension d : dstTp.getDimShape())
997 dstSizes.push_back(constantIndex(rewriter, loc, d));
998 } else {
1001 op.getReassociationIndices());
1003 if (shape == ShapedType::kDynamic)
1004 dstDynSizes.push_back(dstSizes[idx]);
1005 }
1006 }
1007 Value nnz = rewriter.create(loc, srcTensor);
1008
1009
1011 dstTp.withoutDimToLvl(),
1012 !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
1013
1015 rewriter
1016 .create(loc, bufferTp, dstDynSizes, Value(),
1018 .getResult();
1019
1020
1021
1022
1023
1024
1025
1026
1027 const auto encSrc = srcTp.getEncoding();
1028 ForeachOp foreachOp = rewriter.create(
1029 loc, srcTensor, buffer,
1032 const Dimension dimRank = srcTp.getDimRank();
1034 srcDcvs.reserve(dimRank);
1035 for (Dimension d = 0; d < dimRank; d++) {
1037 srcDcvs.push_back(srcLcvs[lvl]);
1038 }
1040 reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
1041 srcDcvs, dstSizes, dstDcvs);
1042 auto t =
1043 builder.createtensor::InsertOp(loc, v, reduc.front(), dstDcvs);
1044 builder.create<sparse_tensor::YieldOp>(loc, t);
1045 });
1046
1047 Value t = rewriter.create(loc, foreachOp.getResult(0), true);
1048 if (bufferTp != dstTp) {
1049 auto dstRTT = dstTp.getRankedTensorType();
1050 Value converted = rewriter.create(loc, dstRTT, t).getResult();
1051 rewriter.create(loc, t);
1052 t = converted;
1053 }
1055 return success();
1056 }
1057 };
1058
1059
1060
1061 template
1062 struct ReshapeRewriter : public OpRewritePattern {
1063 public:
1065
1066 LogicalResult matchAndRewrite(ReshapeOp op,
1068 Location loc = op->getLoc();
1071
1072
1073
1074
1075 if (encDst && encSrc) {
1076 return failure();
1077 }
1078 if (encSrc) {
1080 auto denseTp =
1082 auto convert = rewriter.create(loc, denseTp, op.getSrc());
1083 rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
1084 return success();
1085 }
1086 if (encDst) {
1088 auto denseTp =
1090 ReshapeOp reshape;
1091 if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
1092 reshape = rewriter.create(
1093 loc, denseTp, op.getSrc(), op.getReassociation(),
1094 op.getOutputShape(), op.getStaticOutputShape());
1095 } else {
1096 reshape = rewriter.create(loc, denseTp, op.getSrc(),
1097 op.getReassociation());
1098 }
1099 Value convert = rewriter.create(loc, rtp, reshape);
1100 rewriter.replaceOp(op, convert);
1101 return success();
1102 }
1103 return failure();
1104 }
1105 };
1106
1107
1108
1109 struct TensorLike {
1110 TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
1114
1115 val = builder.create(loc, rtt, dynSzs);
1116 if (!isSparse()) {
1118 val = builder.createlinalg::FillOp(loc, c0, val).getResult(0);
1119 }
1120 }
1121
1123 val = builder.createtensor::InsertOp(loc, v, val, crds);
1124 }
1125
1127 if (isSparse())
1128 return builder.create(loc, val, true);
1129 return val;
1130 }
1131
1132 bool isSparse() const {
1134 }
1135
1137 };
1138
1139 struct SparseTensorDimOpRewriter : public OpRewritePatterntensor::DimOp {
1141 LogicalResult matchAndRewrite(tensor::DimOp op,
1143 std::optional<int64_t> dim = op.getConstantIndex();
1145 if (!dim || !stt || !stt->hasEncoding())
1146 return failure();
1147
1148 if (stt->isPermutation()) {
1150 toLvl(stt->getEncoding(), *dim));
1151 return success();
1152 }
1153
1154
1155
1156
1157
1158
1159
1162 for (Level l = 0; l < stt->getLvlRank(); l++) {
1163 Value lvlSz = rewriter.create(loc, op.getSource(), l);
1164 Value maxLvlCrd = rewriter.createarith::SubIOp(
1166 maxLvlCrds.push_back(maxLvlCrd);
1167 }
1168
1169 AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim);
1170 Value maxDimCrd = rewriter.createaffine::AffineApplyOp(
1171 op.getLoc(), AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp),
1172 maxLvlCrds);
1173
1174 Value dimSz = rewriter.createarith::AddIOp(
1177 return success();
1178 }
1179 };
1180
// Lowers sparse_tensor.concatenate to one sparse_tensor.foreach per input.
// Each element is inserted into a shared destination buffer (TensorLike) at
// its coordinate, shifted along the concatenation dimension by the running
// `offset`. For destinations that are not all-dense, the insert is guarded by
// an scf.if whose condition is computed on a line not visible here. The
// buffer is finalized after the last input.
1181 struct ConcatenateRewriter : public OpRewritePattern {
1183 LogicalResult matchAndRewrite(ConcatenateOp op,
// Bail out when the op still needs an extra sort, returning the diagnostic
// as failure() (as DirectConvertRewriter does). Previously the error was
// emitted and the op was then lowered anyway.
1185 if (op.needsExtraSort())
1186 return op.emitError("ConcatenateOp not staged");
1187
1188 const Location loc = op.getLoc();
1190 const Dimension conDim = op.getDimension();
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207 TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
1209 Value iterArg = dstBuf.val;
1210
1211 ForeachOp foreachOp;
1212 for (Value input : op.getInputs()) {
1213
1214
1215 foreachOp = rewriter.create(
1216 loc, input, iterArg,
1220 offDimCrd[conDim] =
1221 builder.createarith::AddIOp(loc, offDimCrd[conDim], offset);
1222
1223
1224 dstBuf.val = reduc.front();
1225 if (!dstTp.isAllDense()) {
1227 auto ifOp = builder.createscf::IfOp(loc, reduc.getTypes(), cond,
1228 true);
1229 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1230 builder.createscf::YieldOp(loc, dstBuf.val);
1231
1232 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1233 dstBuf.insert(builder, loc, v, offDimCrd);
1234 builder.createscf::YieldOp(loc, dstBuf.val);
1235
1236
1237 builder.setInsertionPointAfter(ifOp);
1238 dstBuf.val = ifOp.getResult(0);
1239 } else {
1240 dstBuf.insert(builder, loc, v, offDimCrd);
1241 }
1242 builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1243 });
1244
1245
1246
1248 assert(!ShapedType::isDynamic(sz));
1249 offset = rewriter.createarith::AddIOp(loc, offset,
1251 iterArg = foreachOp.getResult(0);
1252 dstBuf.val = iterArg;
1253 }
1254
1255 dstBuf.val = iterArg;
1256 Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
1258 return success();
1259 }
1260 };
1261
1262 struct DirectConvertRewriter : public OpRewritePattern {
1264 LogicalResult matchAndRewrite(ConvertOp op,
1266 if (op.needsExtraSort())
1267 return op.emitError("ConvertOp not staged.");
1268
1269
1272 if (encDst && encSrc && !encSrc.isSlice() &&
1273 encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
1274
1275
1276 return failure();
1277 }
1278
1280 Value src = op.getSource();
1281
1284
1285 bool fromSparseConst = false;
1286 if (auto constOp = op.getSource().getDefiningOparith::ConstantOp())
1287 if (isa(constOp.getValue()))
1288 fromSparseConst = true;
1289
1290 const AffineMapAttr foreachOrder =
1291 (!dstStt.isIdentity() && fromSparseConst)
1293 : nullptr;
1294
1295 bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst;
1296
1301
1302 auto foreachOp = rewriter.create(
1303 loc, src, dstBuf.val, foreachOrder,
1306
1307 dstBuf.val = reduc.front();
1308 if (!skipZeroCheck) {
1310 auto ifOp = builder.createscf::IfOp(loc, reduc.getTypes(), cond,
1311 true);
1312 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1313 builder.createscf::YieldOp(loc, dstBuf.val);
1314
1315 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1316 dstBuf.insert(builder, loc, v, dcvs);
1317 builder.createscf::YieldOp(loc, dstBuf.val);
1318
1319
1320 builder.setInsertionPointAfter(ifOp);
1321 dstBuf.val = ifOp.getResult(0);
1322 } else {
1323 dstBuf.insert(builder, loc, v, dcvs);
1324 }
1325 builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1326 });
1327
1329
1330
1331 dstBuf.val = foreachOp.getResult(0);
1332
1335 return success();
1336 }
1337 };
1338
1339 struct CrdTranslateRewriter : public OpRewritePattern {
1341 LogicalResult matchAndRewrite(CrdTranslateOp op,
1343 AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
1344 ? op.getEncoder().getDimToLvl()
1345 : op.getEncoder().getLvlToDim();
1346
1349
1350
1351
1352 Value trans = rewriter.createaffine::AffineApplyOp(
1354 op.getInCrds());
1355 outCrds.push_back(trans);
1356 }
1357 rewriter.replaceOp(op, outCrds);
1358 return success();
1359 }
1360 };
1361
1362
1363 struct ForeachRewriter : public OpRewritePattern {
1364 public:
1366
1367 LogicalResult matchAndRewrite(ForeachOp op,
1369
1370 auto loc = op.getLoc();
1371 Value input = op.getTensor();
1374 const Level lvlRank = stt.getLvlRank();
1375
1376
1377
1378 if (auto constOp = input.getDefiningOparith::ConstantOp()) {
1379 if (auto attr = dyn_cast(constOp.getValue())) {
1381 }
1382 }
1383
1384
1385 const auto enc = stt.getEncoding();
1386
1387
1392 for (Level l = 0; l < lvlRank; l++) {
1393
1394
1396 loopEmitter.makeTensorLevel(0, l)};
1397 loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
1398
1399
1400 loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, 1,
1401 reduc);
1402 }
1403
1405 if (op.getOrder()) {
1406
1407 llvm_unreachable(
1408 "Level order not yet implemented on non-constant input tensors.");
1409 }
1410
1411 Value vals = loopEmitter.getValBuffer()[0];
1413
1414
1415 Value val = enc ? rewriter.creatememref::LoadOp(loc, vals, pos)
1416 : rewriter.creatememref::LoadOp(loc, vals, lcvs);
1417
1418
1419 Block *srcBlock = op.getBody();
1420
1421
1423 enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
1424
1425
1426 args.push_back(val);
1427
1428 args.append(reduc);
1429
1430
1433
1435 if (llvm::isascf::YieldOp(last)) {
1436
1437
1438
1440 }
1441
1445 for (Level l = 0; l < lvlRank; l++) {
1446
1447
1448 loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
1449 loopEmitter.exitCurrentLoopSeq(rewriter, loc);
1450 }
1451
1452
1453
1454 rewriter.replaceOp(op, reducValue);
1455 return success();
1456 }
1457 };
1458
1459
1462 LogicalResult matchAndRewrite(NewOp op,
1466 if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
1467 return failure();
1468
1469
1470
1471
1472
1473 RankedTensorType dstTp = stt.getRankedTensorType();
1474 RankedTensorType cooTp = stt.getCOOType(true);
1475 Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
1476 Value convert = cooTensor;
1477 auto enc = stt.getEncoding();
1478 if (!stt.isPermutation()) {
1480 convert = rewriter.create(loc, coo, convert);
1482 }
1483 convert = rewriter.create(loc, dstTp, convert);
1484 if (!stt.isPermutation())
1485 convert = rewriter.create(loc, enc, convert);
1486 rewriter.replaceOp(op, convert);
1487
1488
1490 rewriter.create(loc, cooTensor);
1491
1492 return success();
1493 }
1494 };
1495
1496
1499 LogicalResult matchAndRewrite(OutOp op,
1502
1503 Value src = op.getTensor();
1504 Value nnz = rewriter.create(loc, src);
1505
1506
1508 const Dimension dimRank = srcTp.getDimRank();
1510 Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);
1511
1512
1513
1516 for (Dimension d = 0; d < dimRank; d++) {
1517 rewriter.creatememref::StoreOp(loc, dims[d], dimSizes,
1519 }
1520
1521
1524 createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
1525 {op.getDest()}, EmitCInterface::Off)
1526 .getResult(0);
1528 createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
1529 {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);
1530
1531 Value dimCoords = dimSizes;
1532 Type eltTp = srcTp.getElementType();
1533 SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
1536 ModuleOp module = op->getParentOfType();
1537
1538
1539 rewriter.create(
1540 loc, src, std::nullopt,
1543 for (Dimension d = 0; d < dimRank; d++) {
1544 rewriter.creatememref::StoreOp(loc, dcvs[d], dimCoords,
1546 }
1547 rewriter.creatememref::StoreOp(loc, v, value);
1550 EmitCInterface::On);
1551 builder.createfunc::CallOp(loc, TypeRange(), fn, operands);
1552 builder.create<sparse_tensor::YieldOp>(loc);
1553 });
1554
1555
1556 createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
1557 EmitCInterface::Off);
1558
1560 return success();
1561 }
1562 };
1563
1564 }
1565
1566
1567
1568
1569
1571 patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
1572 FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
1573 GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
1575 }
1576
1578 bool enableRT,
1579 bool enableConvert) {
1580 patterns.add<ConcatenateRewriter, ReshapeRewritertensor::ExpandShapeOp,
1581 ReshapeRewritertensor::CollapseShapeOp,
1582 Sparse2SparseReshapeRewritertensor::ExpandShapeOp,
1583 Sparse2SparseReshapeRewritertensor::CollapseShapeOp,
1584 SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
1586
1587 if (enableConvert)
1589 if (!enableRT)
1591 }
1592
1594
1595
1596 patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());
1597 }
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1203::ArityGroupAndKind::Kind kind
static bool isMulChain(Value val, Value x)
static bool isSampling(GenericOp op)
static bool isSumOfMul(GenericOp op)
static bool isZeroValue(Value val)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static LogicalResult genForeachOnSparseConstant(ForeachOp op, RewriterBase &rewriter, SparseElementsAttr attr)
static bool isMaterializing(OpOperand *op, bool isZero)
static void concatSizesFromInputs(OpBuilder &builder, SmallVectorImpl< Value > &sizes, Location loc, ShapedType dstTp, ValueRange srcs, unsigned dim)
Populates the given sizes array for concatenation from types (for static sizes) and from the source t...
static bool isSparseTensor(Value v)
static bool isZeroYield(GenericOp op)
static void sizesForTensor(OpBuilder &builder, SmallVectorImpl< Value > &sizes, Location loc, ShapedType stp, Value tensor)
Populates given sizes array from type (for static sizes) and from the tensor (for dynamic sizes).
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
IntegerAttr getIndexAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
TypedAttr getZeroAttr(Type type)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Block * getBlock() const
Returns the current block of the builder.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Block * getBlock()
Returns the operation block that contains this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void initializeLoopEmit(OpBuilder &builder, Location loc, OutputUpdater updater=nullptr, SynTensorBoundSetter synSetter=nullptr)
Starts a loop emitting session by generating all the buffers needed for iterating over the tensors.
A wrapper around RankedTensorType, which has three goals:
Size getDynamicDimSize(Dimension d) const
Safely looks up the requested dimension-DynSize.
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
SparseTensorType withEncoding(SparseTensorEncodingAttr newEnc) const
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
RankedTensorType getRankedTensorType() const
Explicitly convert to RankedTensorType.
AffineMap getExpandedDimToLvl() const
Returns the dimToLvl mapping, where the identity map is expanded out into a full AffineMap.
RankedTensorType getCOOType(bool ordered) const
Returns [un]ordered COO type for this sparse tensor type.
SparseTensorEncodingAttr getEncoding() const
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options, const BufferizationState &state)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
FlatSymbolRefAttr getFunc(ModuleOp module, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface)
Returns a function reference (first hit also inserts into module).
Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp)
Generates an uninitialized temporary buffer with room for one value of the given type,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
void foreachInSparseConstant(OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order, function_ref< void(ArrayRef< Value >, Value)> callback)
Iterate over a sparse constant, generates constantOp for value and coordinates.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
uint64_t Level
The type of level identifiers and level-ranks.
std::optional< SparseTensorType > tryGetSparseTensorType(Value val)
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
RankedTensorType getRankedTensorType(T &&t)
Convenience method to abbreviate casting getType().
Type getOpaquePointerType(MLIRContext *ctx)
Returns the equivalent of void* for opaque arguments to the execution engine.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
Value genIsNonzero(OpBuilder &builder, Location loc, Value v)
Generates the comparison v != 0 where v is of numeric type.
Level toLvl(SparseTensorEncodingAttr enc, Dimension d)
Convenience method to translate the given dimension to the corresponding level.
Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp)
Generates an uninitialized temporary buffer of the given size and type, but returns it as type memref...
void genReshapeDstShape(OpBuilder &builder, Location loc, SmallVectorImpl< Value > &dstShape, ArrayRef< Value > srcShape, ArrayRef< Size > staticDstShape, ArrayRef< ReassociationIndices > reassociation)
Computes the shape of destination tensor of a reshape operator.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
void reshapeCvs(OpBuilder &builder, Location loc, ArrayRef< ReassociationIndices > reassociation, ValueRange srcSizes, ValueRange srcCvs, ValueRange dstSizes, SmallVectorImpl< Value > &dstCvs)
Reshape coordinates during a reshaping operation.
bool hasAnySparseOperand(Operation *op)
Returns true iff MLIR operand has any sparse operand.
SparseTensorFieldKind
===-------------------------------------------------------------------—===// The sparse tensor storag...
func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface)
Creates a CallOp to the function reference returned by getFunc() in the builder's module.
StringRef primaryTypeFunctionSuffix(PrimaryType pt)
Convert PrimaryType to its function-name suffix.
void sizesFromSrc(OpBuilder &builder, SmallVectorImpl< Value > &sizes, Location loc, Value src)
Populates given sizes array from dense tensor or sparse tensor constant.
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void populatePreSparsificationRewriting(RewritePatternSet &patterns)
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
const FrozenRewritePatternSet & patterns
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
void populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns, bool enableRT, bool enableConvert)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This enum defines all the sparse representations supportable by the SparseTensor dialect.