MLIR: lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9 #include
10
12
17
27 #include "llvm/ADT/Bitset.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 #include "llvm/Support/FormatVariadic.h"
30
31 #define GET_ATTRDEF_CLASSES
32 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
33 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
34
35
36
42
43 #define GET_TYPEDEF_CLASSES
44 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
45
46 using namespace mlir;
48
49
50
54 }
55 }
56
57
58
59
60
62 switch (bitWidth) {
63 case 0:
64 case 8:
65 case 16:
66 case 32:
67 case 64:
68 return true;
69 default:
70 return false;
71 }
72 }
73
77 assert(enc);
78
79
80 SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
81 if (dimShape.has_value()) {
82
83
85 enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
86 memrefShape.assign(lvlShape.begin(),
87 lvlShape.begin() + enc.getBatchLvlRank());
88 }
89
90 memrefShape.push_back(ShapedType::kDynamic);
91 return memrefShape;
92 }
93
94
95
96
97
101
105 callback) const {
106 const auto lvlTypes = enc.getLvlTypes();
107 const Level lvlRank = enc.getLvlRank();
110
111 ArrayRef cooSegsRef = cooSegs;
112
113 for (Level l = 0; l < lvlRank; ) {
114 const auto lt = lvlTypes[l];
117 return;
118 }
121 return;
122 }
123 if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
124 if (!cooSegsRef.front().isSoA) {
125
126
127 l = cooSegsRef.front().lvlRange.second;
128 } else {
129
130 l++;
131 }
132
133 cooSegsRef = cooSegsRef.drop_front();
134 } else {
135
136 l++;
137 }
138 }
139
142 return;
143
146 return;
147 }
148
153 callback) {
155
158
160
162
164
166
171 switch (fieldKind) {
173 return callback(specType, fieldIdx, fieldKind, lvl, lt);
175 return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
177 return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
179 return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
180 };
181 llvm_unreachable("unrecognized field kind");
182 });
183 }
184
186 unsigned numFields = 0;
189 numFields++;
190 return true;
191 });
192 return numFields;
193 }
194
196 unsigned numFields = 0;
200 numFields++;
201 return true;
202 });
203 numFields -= 1;
205 return numFields;
206 }
207
208 std::pair<FieldIndex, unsigned>
210 std::optional lvl) const {
212 unsigned stride = 1;
214 assert(lvl.has_value());
215 const Level cooStart = enc.getAoSCOOStart();
216 const Level lvlRank = enc.getLvlRank();
217 if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
218 lvl = cooStart;
219 stride = lvlRank - cooStart;
220 }
221 }
225 if ((lvl && fLvl == lvl.value() && kind == fKind) ||
227 fieldIdx = fIdx;
228
229 return false;
230 }
231 return true;
232 });
234 return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
235 }
236
237
238
239
240
241 std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
242 return isDynamic(v) ? std::nullopt
243 : std::make_optional(static_cast<uint64_t>(v));
244 }
245
246 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
247 return getStatic(getOffset());
248 }
249
250 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
252 }
253
254 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
255 return getStatic(getSize());
256 }
257
258 bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
259 return isDynamic(getOffset()) && isDynamic(getStride()) &&
260 isDynamic(getSize());
261 }
262
263 std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
264 return isDynamic(v) ? "?" : std::to_string(v);
265 }
266
268 assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
269 os << '(';
270 os << getStaticString(getOffset());
271 os << ", ";
272 os << getStaticString(getSize());
273 os << ", ";
274 os << getStaticString(getStride());
275 os << ')';
276 }
277
280 }
281
285 if (parseResult.has_value()) {
286 if (parseResult.value().succeeded() && result < 0) {
289 "expect positive value or ? for slice offset/size/stride");
290 return failure();
291 }
292 return parseResult.value();
293 }
294
295
296 result = SparseTensorDimSliceAttr::kDynamic;
298 }
299
301 int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
302
310 return {};
311
313 offset, size, stride);
314 }
315
316 LogicalResult
318 int64_t offset, int64_t size, int64_t stride) {
319 if (!isDynamic(offset) && offset < 0)
320 return emitError() << "expect non-negative value or ? for slice offset";
321 if (!isDynamic(size) && size <= 0)
322 return emitError() << "expect positive value or ? for slice size";
323 if (!isDynamic(stride) && stride <= 0)
324 return emitError() << "expect positive value or ? for slice stride";
325 return success();
326 }
327
328 SparseTensorEncodingAttr
329 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
330 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
333 getCrdWidth(), getExplicitVal(), getImplicitVal());
334 }
335
336 SparseTensorEncodingAttr
337 SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
338 return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
339 }
340
341 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
342 return withDimToLvl(AffineMap());
343 }
344
345 SparseTensorEncodingAttr
346 SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
347 unsigned crdWidth) const {
348 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
350 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
351 crdWidth, getExplicitVal(), getImplicitVal());
352 }
353
354 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
355 return withBitWidths(0, 0);
356 }
357
358 SparseTensorEncodingAttr
359 SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const {
360 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
362 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
363 getCrdWidth(), explicitVal, getImplicitVal());
364 }
365
366 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const {
367 return withExplicitVal(Attribute());
368 }
369
370 SparseTensorEncodingAttr
371 SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const {
372 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
374 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
375 getCrdWidth(), getExplicitVal(), implicitVal);
376 }
377
378 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const {
379 return withImplicitVal(Attribute());
380 }
381
382 SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
385 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
386 getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
387 }
388
389 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
391 }
392
393 uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
395 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
396 return std::distance(lastBatch, lvlTypes.rend());
397 }
398
400 return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
401 }
402
403 bool SparseTensorEncodingAttr::isAllOrdered() const {
404 return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
405 }
406
407 Type SparseTensorEncodingAttr::getCrdElemType() const {
408 if (!getImpl())
409 return nullptr;
410 if (getCrdWidth())
413 }
414
415 Type SparseTensorEncodingAttr::getPosElemType() const {
416 if (!getImpl())
417 return nullptr;
418 if (getPosWidth())
421 }
422
423 MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
427 }
428
429 MemRefType SparseTensorEncodingAttr::getPosMemRefType(
433 }
434
435 bool SparseTensorEncodingAttr::isIdentity() const {
436 return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
437 }
438
440 return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
441 }
442
443 Dimension SparseTensorEncodingAttr::getDimRank() const {
444 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
445 const auto dimToLvl = getDimToLvl();
446 return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
447 }
448
449 Level SparseTensorEncodingAttr::getLvlRank() const {
450 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
451 return getLvlTypes().size();
452 }
453
454 LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
455 if (!getImpl())
457 assert(l < getLvlRank() && "Level is out of bounds");
458 return getLvlTypes()[l];
459 }
460
461 bool SparseTensorEncodingAttr::isSlice() const {
462 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
463 return !getDimSlices().empty();
464 }
465
466 SparseTensorDimSliceAttr
467 SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
468 assert(isSlice() && "Is not a slice");
469 const auto dimSlices = getDimSlices();
470 assert(dim < dimSlices.size() && "Dimension is out of bounds");
471 return dimSlices[dim];
472 }
473
474 std::optional<uint64_t>
475 SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
476 return getDimSlice(dim).getStaticOffset();
477 }
478
479 std::optional<uint64_t>
480 SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
481 return getDimSlice(dim).getStaticStride();
482 }
483
484 std::optional<uint64_t>
485 SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
486 return getStaticDimSliceOffset(toDim(*this, lvl));
487 }
488
489 std::optional<uint64_t>
490 SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
491 return getStaticDimSliceStride(toDim(*this, lvl));
492 }
493
495 SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
496 CrdTransDirectionKind dir) const {
497 if (isIdentity())
499
501 unsigned rank =
502 dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
503 ret.reserve(rank);
504
506 for (unsigned r = 0; r < rank; r++) {
507 unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
508 : toLvl(*this, r);
509 ret.push_back(srcShape[trans]);
510 }
511 return ret;
512 }
513
514
516 dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
517
519 dimRep.reserve(srcShape.size());
520 for (int64_t sz : srcShape) {
521 if (!ShapedType::isDynamic(sz)) {
522
524 } else {
525
527 }
528 };
529
531
534
535 if (auto c = llvm::dyn_cast(evalExp)) {
536 ret.push_back(c.getValue() + 1);
537 } else {
538 if (auto mod = llvm::dyn_cast(evalExp);
540
541
542 if (auto bound = llvm::dyn_cast(mod.getRHS())) {
543 ret.push_back(bound.getValue());
544 continue;
545 }
546 }
547 ret.push_back(ShapedType::kDynamic);
548 }
549 }
550 assert(ret.size() == rank);
551 return ret;
552 }
553
555 SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
557 CrdTransDirectionKind dir) const {
558 if (!getImpl())
559 return crds;
560
562 dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
564 auto transOp = builder.create(loc, retType, crds, dir, *this);
565 return transOp.getOutCrds();
566 }
567
569
571 return {};
573 return {};
574
575
580 unsigned posWidth = 0;
581 unsigned crdWidth = 0;
584 StringRef attrName;
586 "explicitVal", "implicitVal"};
588
589 auto *it = find(keys, attrName);
590 if (it == keys.end()) {
592 return {};
593 }
594 unsigned keyWordIndex = it - keys.begin();
595
597 return {};
598
599 switch (keyWordIndex) {
600 case 0: {
602 auto res = cParser.parseDimLvlMap();
603 if (failed(res))
604 return {};
605 const auto &dlm = *res;
606
607 const Level lvlRank = dlm.getLvlRank();
608 for (Level lvl = 0; lvl < lvlRank; lvl++)
609 lvlTypes.push_back(dlm.getLvlType(lvl));
610
611 const Dimension dimRank = dlm.getDimRank();
612 for (Dimension dim = 0; dim < dimRank; dim++)
613 dimSlices.push_back(dlm.getDimSlice(dim));
614
615
616
617 const auto isDefined = [](SparseTensorDimSliceAttr slice) {
618 return static_cast<bool>(slice.getImpl());
619 };
620 if (llvm::any_of(dimSlices, isDefined)) {
621 const auto defaultSlice =
623 for (Dimension dim = 0; dim < dimRank; dim++)
624 if (!isDefined(dimSlices[dim]))
625 dimSlices[dim] = defaultSlice;
626 } else {
627 dimSlices.clear();
628 }
629
630 dimToLvl = dlm.getDimToLvlMap(parser.getContext());
631 lvlToDim = dlm.getLvlToDimMap(parser.getContext());
632 break;
633 }
634 case 1: {
637 return {};
638 auto intAttr = llvm::dyn_cast(attr);
639 if (!intAttr) {
641 "expected an integral position bitwidth");
642 return {};
643 }
644 posWidth = intAttr.getInt();
645 break;
646 }
647 case 2: {
650 return {};
651 auto intAttr = llvm::dyn_cast(attr);
652 if (!intAttr) {
654 "expected an integral index bitwidth");
655 return {};
656 }
657 crdWidth = intAttr.getInt();
658 break;
659 }
660 case 3: {
663 return {};
664 if (auto result = llvm::dyn_cast(attr)) {
665 explicitVal = result;
666 } else if (auto result = llvm::dyn_cast(attr)) {
667 explicitVal = result;
668 } else if (auto result = llvm::dyn_castcomplex::NumberAttr(attr)) {
669 explicitVal = result;
670 } else {
672 "expected a numeric value for explicitVal");
673 return {};
674 }
675 break;
676 }
677 case 4: {
680 return {};
681 if (auto result = llvm::dyn_cast(attr)) {
682 implicitVal = result;
683 } else if (auto result = llvm::dyn_cast(attr)) {
684 implicitVal = result;
685 } else if (auto result = llvm::dyn_castcomplex::NumberAttr(attr)) {
686 implicitVal = result;
687 } else {
689 "expected a numeric value for implicitVal");
690 return {};
691 }
692 break;
693 }
694 }
695
697 break;
698 }
699
700
702 return {};
704 return {};
705
706
707 if (!lvlToDim || lvlToDim.isEmpty()) {
709 }
710 return parser.getChecked(
711 parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
712 explicitVal, implicitVal, dimSlices);
713 }
714
716 auto map = static_cast<AffineMap>(getDimToLvl());
717
718 if (!map)
720 printer << "<{ map = ";
721 printSymbols(map, printer);
722 printer << '(';
723 printDimensions(map, printer, getDimSlices());
724 printer << ") -> (";
725 printLevels(map, printer, getLvlTypes());
726 printer << ')';
727
728 if (getPosWidth())
729 printer << ", posWidth = " << getPosWidth();
730 if (getCrdWidth())
731 printer << ", crdWidth = " << getCrdWidth();
732 if (getExplicitVal()) {
733 printer << ", explicitVal = " << getExplicitVal();
734 }
735 if (getImplicitVal())
736 printer << ", implicitVal = " << getImplicitVal();
737 printer << " }>";
738 }
739
740 void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
743 return;
744 printer << '[';
745 for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
746 printer << 's' << i << ", ";
749 printer << ']';
750 }
751
752 void SparseTensorEncodingAttr::printDimensions(
755 if (!dimSlices.empty()) {
756 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
757 printer << 'd' << i << " : " << dimSlices[i] << ", ";
759 printer << 'd' << map.getNumDims() - 1 << " : "
761 }
762 } else {
763 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
764 printer << 'd' << i << ", ";
766 printer << 'd' << map.getNumDims() - 1;
767 }
768 }
769
770 void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
772 for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
774 printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
775 }
779 printer << " : " << toMLIRString(lvlTypes[lastIndex]);
780 }
781 }
782
789 return emitError() << "unexpected position bitwidth: " << posWidth;
791 return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
792
793
794 auto *it = llvm::find_if(lvlTypes, isSingletonLT);
795 while (it != lvlTypes.end()) {
796 if (it == lvlTypes.begin() ||
798 return emitError() << "expected compressed or loose_compressed level "
799 "before singleton level";
800
801 auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT);
803 return emitError() << "expected all singleton lvlTypes "
804 "following a singleton level";
805
806 if (!std::all_of(it, curCOOEnd, [it](LevelType i) {
809 })) {
810 return emitError() << "expected all singleton lvlTypes stored in the "
811 "same memory layout (SoA vs AoS).";
812 }
813 it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT);
814 }
815
816 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
817 if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
818 return emitError() << "Batch lvlType can only be leading levels.";
819
820
821 auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
823 });
824 if (llvm::any_of(soaLvls, [](LevelType lt) {
826 })) {
827 return emitError() << "SoA is only applicable to singleton lvlTypes.";
828 }
829
830
831 if (auto it = llvm::find_if(lvlTypes, isNOutOfMLT);
832 it != std::end(lvlTypes)) {
833 if (it != lvlTypes.end() - 1)
834 return emitError() << "expected n_out_of_m to be the last level type";
835 if (!std::all_of(lvlTypes.begin(), it, isDenseLT))
836 return emitError() << "expected all dense lvlTypes "
837 "before a n_out_of_m level";
841 << "expected 1xm block structure for n_out_of_m level";
842 }
844 unsigned coefficient = 0;
845 for (const auto &elem : sizes) {
846 if (elem != 0) {
847 if (elem != coefficient && coefficient != 0) {
848 return emitError() << "expected only one blocked level "
849 "with the same coefficients";
850 }
851 coefficient = elem;
852 }
853 }
854 if (coefficient != getM(*it)) {
855 return emitError() << "expected coeffiencts of Affine expressions "
856 "to be equal to m of n_out_of_m level";
857 }
858 }
859 }
860
861
862
863
864
865 const Level lvlRank = lvlTypes.size();
866 if (lvlRank == 0)
867 return emitError() << "expected a non-empty array for lvlTypes";
868
870 if (dimToLvl) {
873 << "level-rank mismatch between dimToLvl and lvlTypes: "
874 << dimToLvl.getNumResults() << " != " << lvlRank;
876
878 return emitError() << "failed to infer lvlToDim from dimToLvl";
879 if (lvlToDim && (inferRes != lvlToDim))
880 return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
881 if (dimRank > lvlRank)
882 return emitError() << "unexpected dimToLvl mapping from " << dimRank
883 << " to " << lvlRank;
884 }
885 if (!dimSlices.empty()) {
886 if (dimSlices.size() != dimRank)
888 << "dimension-rank mismatch between dimSlices and dimToLvl: "
889 << dimSlices.size() << " != " << dimRank;
890
891
892 if (dimRank != lvlRank)
894 << "dimSlices expected dimension-rank to match level-rank: "
895 << dimRank << " != " << lvlRank;
896 }
897 return success();
898 }
899
900 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
903
904
905 if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
906 getPosWidth(), getCrdWidth(), getExplicitVal(),
907 getImplicitVal(), getDimSlices())))
908 return failure();
909
910
911
912 const Dimension dimRank = dimShape.size();
913 if (dimRank == 0)
914 return emitError() << "expected non-scalar sparse tensor";
915 if (getDimRank() != dimRank)
917 << "dimension-rank mismatch between encoding and tensor shape: "
918 << getDimRank() << " != " << dimRank;
919 if (auto expVal = getExplicitVal()) {
920 Type attrType = llvm::dyn_cast(expVal).getType();
921 if (attrType != elementType) {
922 return emitError() << "explicit value type mismatch between encoding and "
923 << "tensor element type: " << attrType
924 << " != " << elementType;
925 }
926 }
927 if (auto impVal = getImplicitVal()) {
928 Type attrType = llvm::dyn_cast(impVal).getType();
929 if (attrType != elementType) {
930 return emitError() << "implicit value type mismatch between encoding and "
931 << "tensor element type: " << attrType
932 << " != " << elementType;
933 }
934
935 auto impFVal = llvm::dyn_cast(impVal);
936 auto impIntVal = llvm::dyn_cast(impVal);
937 auto impComplexVal = llvm::dyn_castcomplex::NumberAttr(impVal);
938 if ((impFVal && impFVal.getValue().isNonZero()) ||
939 (impIntVal && !impIntVal.getValue().isZero()) ||
940 (impComplexVal && (impComplexVal.getImag().isNonZero() ||
941 impComplexVal.getReal().isNonZero()))) {
942 return emitError() << "implicit value must be zero";
943 }
944 }
945 return success();
946 }
947
948 Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const {
950 assert(coo.size() == 1 || coo.empty());
951 if (!coo.empty() && coo.front().isAoS()) {
952 return coo.front().lvlRange.first;
953 }
954 return getLvlRank();
955 }
956
958 mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const {
960 if (getLvlRank() <= 1)
961 return ret;
962
965 while (l < getLvlRank()) {
966 auto lt = lts[l];
968 auto cur = lts.begin() + l;
969 auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
970 return !lt.isaLevelFormat::Singleton();
971 });
972 unsigned cooLen = std::distance(cur, end);
973 if (cooLen > 1) {
974
975
976
977
978 ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
980 }
981 l += cooLen;
982 } else {
983 l++;
984 }
985 }
986 return ret;
987 }
988
989
990
991
992
995 if (!hasEncoding())
996 return false;
997 if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
998 return false;
999 for (Level l = startLvl + 1; l < lvlRank; ++l)
1000 if (!isSingletonLvl(l))
1001 return false;
1002
1003
1004
1005 return || isUniqueLvl(lvlRank - 1);
1006 }
1007
1008 RankedTensorType
1011 lvlTypes.reserve(lvlRank);
1012
1013
1014 lvlTypes.push_back(
1016 if (lvlRank > 1) {
1017
1018 std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
1020
1022 }
1024 getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(),
1025 getCrdWidth(), getExplicitVal(), getImplicitVal());
1027 }
1028
1029
1030
1031
1032
1033 SparseTensorEncodingAttr
1035 if (auto ttp = llvm::dyn_cast(type))
1036 return llvm::dyn_cast_or_null(ttp.getEncoding());
1037 if (auto mdtp = llvm::dyn_cast(type))
1038 return mdtp.getEncoding();
1039 return nullptr;
1040 }
1041
1044 auto map = static_cast<AffineMap>(dimToLvl);
1046
1053 }
1054 return lvlToDim;
1055 }
1056
1061 lvlExprs.reserve(numLvls);
1062
1063
1064 std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
1065 for (unsigned i = 0, n = numLvls; i < n; i++) {
1066 auto result = dimToLvl.getResult(i);
1067 if (auto binOp = dyn_cast(result)) {
1069
1070 auto pos = dyn_cast(binOp.getLHS()).getPosition();
1071 assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
1072 "expected only one floordiv for each dimension");
1074
1076
1077 components.push_back(binOp.getRHS());
1078
1079 lvlExprComponents[pos] = components;
1081 auto pos = dyn_cast(binOp.getLHS()).getPosition();
1082 assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
1083 "expected floordiv before mod");
1084
1085
1086 lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
1087 } else {
1088 assert(false && "expected floordiv or mod");
1089 }
1090 } else {
1092 }
1093 }
1094
1095
1096
1097
1098 for (auto &components : lvlExprComponents) {
1099 assert(components.second.size() == 3 &&
1100 "expected 3 components to build lvlExprs");
1103 auto addOp =
1105 lvlExprs.push_back(addOp);
1106 }
1107 return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
1108 }
1109
1112 "expected dimToLvl to be block sparsity for calling getBlockSize");
1114 for (auto result : dimToLvl.getResults()) {
1115 if (auto binOp = dyn_cast(result)) {
1117 blockSize.push_back(
1118 dyn_cast(binOp.getRHS()).getValue());
1119 }
1120 } else {
1121 blockSize.push_back(0);
1122 }
1123 }
1124 return blockSize;
1125 }
1126
1128 if (!dimToLvl)
1129 return false;
1130 std::map<unsigned, int64_t> coeffientMap;
1131 bool hasBlock = false;
1132 for (auto result : dimToLvl.getResults()) {
1133 if (auto binOp = dyn_cast(result)) {
1134
1135 auto dimOp = dyn_cast(binOp.getLHS());
1136 auto conOp = dyn_cast(binOp.getRHS());
1137 if (!dimOp || !conOp || conOp.getValue() <= 0)
1138 return false;
1139
1140 auto pos = dimOp.getPosition();
1142
1143 auto [it, inserted] = coeffientMap.try_emplace(pos);
1144 if (!inserted)
1145 return false;
1146
1147 it->second = conOp.getValue();
1149
1150 auto it = coeffientMap.find(pos);
1151 if (it == coeffientMap.end())
1152 return false;
1153
1154 if (conOp.getValue() != it->second)
1155 return false;
1156 hasBlock = true;
1157 } else {
1158 return false;
1159 }
1160 } else if (auto dimOp = dyn_cast(result)) {
1161 auto pos = dimOp.getPosition();
1162
1163 if (!coeffientMap.try_emplace(pos, 0).second)
1164 return false;
1165 } else {
1166 return false;
1167 }
1168 }
1169 return hasBlock;
1170 }
1171
1173 auto hasNonIdentityMap = [](Value v) {
1176 };
1177
1178 return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
1179 llvm::any_of(op->getResults(), hasNonIdentityMap);
1180 }
1181
1183 if (enc) {
1184 assert(enc.isPermutation() && "Non permutation map not supported");
1185 if (const auto dimToLvl = enc.getDimToLvl())
1187 }
1188 return l;
1189 }
1190
1192 if (enc) {
1193 assert(enc.isPermutation() && "Non permutation map not supported");
1194 if (const auto lvlToDim = enc.getLvlToDim())
1196 }
1197 return d;
1198 }
1199
1200
1201
1202
1203
1204 static SparseTensorEncodingAttr
1207 for (auto lt : enc.getLvlTypes())
1209
1211 enc.getContext(), lts,
1212 AffineMap(),
1213 AffineMap(),
1214
1215
1216
1217
1218 0, 0,
1219 Attribute(),
1220 Attribute(),
1221 enc.getDimSlices());
1222 }
1223
1224 StorageSpecifierType
1227 }
1228
1229 StorageSpecifierType
1232 SparseTensorEncodingAttr encoding) {
1233 return Base::getChecked(emitError, ctx,
1235 }
1236
1237
1238
1239
1240
1243 }
1244
1247 return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
1248 }
1249
1251 StorageSpecifierKind mdKind, std::optional lvl,
1253 if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
1255 "redundant level argument for querying value memory size");
1256 }
1257
1258 const auto enc = md.getType().getEncoding();
1259 const Level lvlRank = enc.getLvlRank();
1260
1261 if (mdKind == StorageSpecifierKind::DimOffset ||
1262 mdKind == StorageSpecifierKind::DimStride)
1263 if (!enc.isSlice())
1264 return op->emitError("requested slice data on non-slice tensor");
1265
1266 if (mdKind != StorageSpecifierKind::ValMemSize) {
1267 if (!lvl)
1268 return op->emitError("missing level argument");
1269
1270 const Level l = lvl.value();
1271 if (l >= lvlRank)
1272 return op->emitError("requested level is out of bounds");
1273
1274 if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
1276 "requested position memory size on a singleton level");
1277 }
1278 return success();
1279 }
1280
1282 switch (kind) {
1290 return nullptr;
1291 }
1292 llvm_unreachable("Unrecognizable FieldKind");
1293 }
1294
1297 RankedTensorType valTp,
1300 return op->emitError("the sparse-tensor must have static shape");
1302 return op->emitError("the sparse-tensor must have an encoding attribute");
1303
1304
1306 if (cooStartLvl < stt.getLvlRank()) {
1307
1308 auto cooTp = llvm::cast(lvlTps.back());
1309
1310 unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
1311 if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
1312 return op->emitError("input/output trailing COO level-ranks don't match");
1313 }
1314 }
1315
1316
1318 if (layout.getNumDataFields() != lvlTps.size() + 1)
1319 return op->emitError("inconsistent number of fields between input/output");
1320
1321 unsigned idx = 0;
1322 bool misMatch = false;
1323 layout.foreachField([&idx, &misMatch, stt, valTp,
1327 return true;
1328
1329 Type inputTp = nullptr;
1331 inputTp = valTp;
1332 } else {
1333 assert(fid == idx && stt.getLvlType(lvl) == lt);
1334 inputTp = lvlTps[idx++];
1335 }
1336
1337 Type inpElemTp = llvm::cast(inputTp).getElementType();
1339 if (inpElemTp != expElemTp) {
1340 misMatch = true;
1341 return false;
1342 }
1343 return true;
1344 });
1345
1346 if (misMatch)
1347 return op->emitError("input/output element-types don't match");
1348 return success();
1349 }
1350
1352 RankedTensorType valuesTp = getValues().getType();
1353 const auto lvlsTp = getLevels().getTypes();
1355 return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
1356 }
1357
1359 if (getOutValues().getType() != getRetValues().getType())
1360 return emitError("output values and return value type mismatch");
1361
1362 for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
1363 if (ot.getType() != rt.getType())
1364 return emitError("output levels and return levels type mismatch");
1365
1366 RankedTensorType valuesTp = getRetValues().getType();
1367 const auto lvlsTp = getRetLevels().getTypes();
1369 return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
1370 }
1371
1373 RankedTensorType tp1 = getSource().getType();
1374 RankedTensorType tp2 = getDest().getType();
1375 if (tp1.getRank() != tp2.getRank())
1376 return emitError("unexpected conversion mismatch in rank");
1377 auto dstEnc =
1378 llvm::dyn_cast_or_null(tp2.getEncoding());
1379 if (dstEnc && dstEnc.isSlice())
1380 return emitError("cannot convert to a sparse tensor slice");
1381
1382 auto shape1 = tp1.getShape();
1383 auto shape2 = tp2.getShape();
1384
1385
1386
1387 for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
1388 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
1389 return emitError("unexpected conversion mismatch in dimension ") << d;
1390 return success();
1391 }
1392
1393 OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
1395 return getSource();
1396 return {};
1397 }
1398
1399 bool ConvertOp::needsExtraSort() {
1402
1403
1404
1406 return false;
1407
1410 return false;
1411 }
1412
1413
1414
1415
1416
1417
1418 if (auto constOp = getSource().getDefiningOparith::ConstantOp())
1419 if (isa(constOp.getValue()))
1420 return false;
1421
1422 return true;
1423 }
1424
1426 uint64_t inRank = getEncoder().getLvlRank();
1427 uint64_t outRank = getEncoder().getDimRank();
1428
1429 if (getDirection() == CrdTransDirectionKind::dim2lvl)
1430 std::swap(inRank, outRank);
1431
1432 if (inRank != getInCrds().size() || outRank != getOutCrds().size())
1433 return emitError("Coordinate rank mismatch with encoding");
1434
1435 return success();
1436 }
1437
1438 LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
1440 if (getEncoder().isIdentity()) {
1441 results.assign(getInCrds().begin(), getInCrds().end());
1442 return success();
1443 }
1445 AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
1446 ? getEncoder().getDimToLvl()
1447 : getEncoder().getLvlToDim();
1449 results.push_back(getInCrds()[cast(exp).getPosition()]);
1450 return success();
1451 }
1452
1453
1454 auto def = getInCrds()[0].getDefiningOp();
1455 bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
1457 });
1458 if (!sameDef)
1459 return failure();
1460
1461 bool oppositeDir = def.getDirection() != getDirection();
1462 bool sameOracle =
1463 def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
1464 bool sameCount = def.getNumResults() == getInCrds().size();
1465 if (!oppositeDir || !sameOracle || !sameCount)
1466 return failure();
1467
1468
1469
1470 bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
1471 [](auto valuePair) {
1472 auto [lhs, rhs] = valuePair;
1473 return lhs == rhs;
1474 });
1475
1476 if (!sameOrder)
1477 return failure();
1478
1479
1480 results.append(def.getInCrds().begin(), def.getInCrds().end());
1481 return success();
1482 }
1483
1485 int64_t index) {
1486 Value val = builder.createarith::ConstantIndexOp(state.location, index);
1487 return build(builder, state, source, val);
1488 }
1489
1491 if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
1493 if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
1495 "Level index exceeds the rank of the input sparse tensor");
1496 }
1497 return success();
1498 }
1499
1500 std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
1502 }
1503
1508
1510 cast(getSource().getType()).getRank());
1512 }
1513
1514 OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
1515 auto lvlIndex = llvm::dyn_cast_if_present(adaptor.getIndex());
1516 if (!lvlIndex)
1517 return {};
1518
1519 Level lvl = lvlIndex.getAPSInt().getZExtValue();
1522
1523
1524
1525 return {};
1526 }
1527
1528
1529 auto getIndexAttr = [this](int64_t lvlSz) {
1531 };
1532
1534 if (!ShapedType::isDynamic(lvlShape[lvl]))
1535 return getIndexAttr(lvlShape[lvl]);
1536
1537 return {};
1538 }
1539
1541 SparseTensorEncodingAttr dstEnc, Value source) {
1545 dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1546 auto dstTp =
1548 return build(odsBuilder, odsState, dstTp, source);
1549 }
1550
1556
1557 if (srcLvlTps.size() != dstLvlTps.size())
1558 return emitError("Level rank mismatch between source/dest tensors");
1559
1560 for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1561 if (srcLvlTp != dstLvlTp)
1562 return emitError("Level type mismatch between source/dest tensors");
1563
1566 return emitError("Crd/Pos width mismatch between source/dest tensors");
1567 }
1568
1570 return emitError("Element type mismatch between source/dest tensors");
1571
1574 for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1575 if (srcLvlSz != dstLvlSz) {
1576
1577
1578
1579 return emitError("Level size mismatch between source/dest tensors");
1580 }
1581 }
1582
1583 return success();
1584 }
1585
1586 OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1588 return getSource();
1589
1590 if (auto def = getSource().getDefiningOp()) {
1591
1592 if (def.getSource().getType() == getDest().getType())
1593 return def.getSource();
1594 }
1595 return {};
1596 }
1597
1598 template
1603 typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
1605 Type elemTp = nullptr;
1606 bool withStride = false;
1607 if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
1609 } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
1610 std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
1612 if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
1613 withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
1614 } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
1616 }
1617
1618 assert(elemTp && "unhandled operation.");
1620 bufShape.push_back(ShapedType::kDynamic);
1621
1623 stt.getContext(), ShapedType::kDynamic,
1624 {ShapedType::kDynamic})
1625 : StridedLayoutAttr();
1626 ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
1627 return success();
1628 }
1629
1632 if (failed(lvlIsInBounds(getLevel(), getTensor())))
1633 return emitError("requested level is out of bounds");
1635 return emitError("unexpected type for positions");
1636 return success();
1637 }
1638
1639 LogicalResult
1640 ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional loc,
1641 ValueRange ops, DictionaryAttr attr,
1644 return inferSparseBufferType(ops, attr, prop, region, ret);
1645 }
1646
1649 if (failed(lvlIsInBounds(getLevel(), getTensor())))
1650 return emitError("requested level is out of bounds");
1652 return emitError("unexpected type for coordinates");
1653 return success();
1654 }
1655
1656 LogicalResult
1657 ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional loc,
1658 ValueRange ops, DictionaryAttr attr,
1661 return inferSparseBufferType(ops, attr, prop, region, ret);
1662 }
1663
1667 return emitError("expected sparse tensor with a COO region");
1668 return success();
1669 }
1670
1671 LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
1675 return inferSparseBufferType(ops, attr, prop, region,
1676 ret);
1677 }
1678
1683 return emitError("unexpected mismatch in element types");
1684 return success();
1685 }
1686
1687 LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
1688 std::optional loc,
1689 ValueRange ops, DictionaryAttr attr,
1693 return inferSparseBufferType(ops, attr, prop, region, ret);
1694 }
1695
1697 auto rank = getSlice().getType().getRank();
1698 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1699 return emitError("requested dimension out of bound");
1700 return success();
1701 }
1702
1704 auto rank = getSlice().getType().getRank();
1705 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1706 return emitError("requested dimension out of bound");
1707 return success();
1708 }
1709
1712 getSpecifier(), getOperation());
1713 }
1714
1715 template
1717 return op.getSpecifier().template getDefiningOp();
1718 }
1719
1720 OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1721 const StorageSpecifierKind kind = getSpecifierKind();
1722 const auto lvl = getLevel();
1724 if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1725 return op.getValue();
1726 return {};
1727 }
1728
1731 getSpecifier(), getOperation());
1732 }
1733
1734 template
1736 const char *regionName,
1739 unsigned expectedNum = inputTypes.size();
1740 if (numArgs != expectedNum)
1741 return op->emitError() << regionName << " region must have exactly "
1742 << expectedNum << " arguments";
1743
1744 for (unsigned i = 0; i < numArgs; i++) {
1746 if (typ != inputTypes[i])
1747 return op->emitError() << regionName << " region argument " << (i + 1)
1748 << " type mismatch";
1749 }
1751 YieldOp yield = dyn_cast(term);
1752 if (!yield)
1753 return op->emitError() << regionName
1754 << " region must end with sparse_tensor.yield";
1755 if (!yield.hasSingleResult() ||
1756 yield.getSingleResult().getType() != outputType)
1757 return op->emitError() << regionName << " region yield type mismatch";
1758
1759 return success();
1760 }
1761
1764 Type leftType = getX().getType();
1765 Type rightType = getY().getType();
1766 Type outputType = getOutput().getType();
1767 Region &overlap = getOverlapRegion();
1768 Region &left = getLeftRegion();
1769 Region &right = getRightRegion();
1770
1771
1772
1773 if (!overlap.empty()) {
1775 TypeRange{leftType, rightType}, outputType)))
1776 return failure();
1777 }
1778 if (!left.empty()) {
1780 outputType)))
1781 return failure();
1782 } else if (getLeftIdentity()) {
1783 if (leftType != outputType)
1784 return emitError("left=identity requires first argument to have the same "
1785 "type as the output");
1786 }
1787 if (!right.empty()) {
1789 outputType)))
1790 return failure();
1791 } else if (getRightIdentity()) {
1792 if (rightType != outputType)
1793 return emitError("right=identity requires second argument to have the "
1794 "same type as the output");
1795 }
1796 return success();
1797 }
1798
1800 Type inputType = getX().getType();
1801 Type outputType = getOutput().getType();
1802
1803
1804
1805 Region &present = getPresentRegion();
1806 if (!present.empty()) {
1808 TypeRange{inputType}, outputType)))
1809 return failure();
1810 }
1811 Region &absent = getAbsentRegion();
1812 if (!absent.empty()) {
1814 outputType)))
1815 return failure();
1816
1817 Block *absentBlock = &absent.front();
1818 Block *parent = getOperation()->getBlock();
1819 Value absentVal =
1820 cast(absentBlock->getTerminator()).getSingleResult();
1821 if (auto arg = dyn_cast(absentVal)) {
1822 if (arg.getOwner() == parent)
1823 return emitError("absent region cannot yield linalg argument");
1825 if (!isaarith::ConstantOp(def) &&
1826 (def->getBlock() == absentBlock || def->getBlock() == parent))
1827 return emitError("absent region cannot yield locally computed value");
1828 }
1829 }
1830 return success();
1831 }
1832
1833 bool ConcatenateOp::needsExtraSort() {
1836 return false;
1837
1838 bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1840 });
1841
1842
1843
1844
1845 bool directLowerable =
1846 allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1847 return !directLowerable;
1848 }
1849
1852 const Dimension concatDim = getDimension();
1853 const Dimension dimRank = dstTp.getDimRank();
1854
1855 if (getInputs().size() <= 1)
1856 return emitError("Need at least two tensors to concatenate.");
1857
1858 if (concatDim >= dimRank)
1860 "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1861 concatDim, dimRank));
1862
1864 const auto i = it.index();
1866 if (srcTp.hasDynamicDimShape())
1867 return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
1868 const Dimension srcDimRank = srcTp.getDimRank();
1869 if (srcDimRank != dimRank)
1871 llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
1872 "from the output tensor (rank={2}).",
1873 i, srcDimRank, dimRank));
1874 }
1875
1876 for (Dimension d = 0; d < dimRank; d++) {
1877 const Size dstSh = dstTp.getDimShape()[d];
1878 if (d == concatDim) {
1879 if (!ShapedType::isDynamic(dstSh)) {
1880
1881
1882
1883 Size sumSz = 0;
1884 for (const auto src : getInputs())
1886
1887
1888 if (sumSz != dstSh)
1890 "The concatenation dimension of the output tensor should be the "
1891 "sum of all the concatenation dimensions of the input tensors.");
1892 }
1893 } else {
1894 Size prev = dstSh;
1895 for (const auto src : getInputs()) {
1897 if (!ShapedType::isDynamic(prev) && sh != prev)
1898 return emitError("All dimensions (expect for the concatenating one) "
1899 "should be equal.");
1900 prev = sh;
1901 }
1902 }
1903 }
1904
1905 return success();
1906 }
1907
1910 build(builder, result, curSize, inBuffer, value, Value());
1911 }
1912
1916 if (nValue && nValue.value() < 1)
1917 return emitOpError("n must be not less than 1");
1918 }
1919 return success();
1920 }
1921
1924 if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
1925 return emitOpError("incorrect number of coordinates");
1926 return success();
1927 }
1928
1929 void ForeachOp::build(
1931 ValueRange initArgs, AffineMapAttr order,
1933 bodyBuilder) {
1934 build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
1935
1936 if (!bodyBuilder)
1937 return;
1940
1941
1943
1945
1946 blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());
1947
1949
1951 auto ®ion = *result.regions.front();
1952 Block *bodyBlock =
1953 builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs);
1954 bodyBuilder(builder, result.location,
1955 bodyBlock->getArguments().slice(0, dimRank),
1957 bodyBlock->getArguments().drop_front(dimRank + 1));
1958 }
1959
1962 const Dimension dimRank = t.getDimRank();
1963 const auto args = getBody()->getArguments();
1964
1965 if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
1966 return emitError("Level traverse order does not match tensor's level rank");
1967
1968 if (dimRank + 1 + getInitArgs().size() != args.size())
1969 return emitError("Unmatched number of arguments in the block");
1970
1971 if (getNumResults() != getInitArgs().size())
1972 return emitError("Mismatch in number of init arguments and results");
1973
1974 if (getResultTypes() != getInitArgs().getTypes())
1975 return emitError("Mismatch in types of init arguments and results");
1976
1977
1978 auto yield = cast(getBody()->getTerminator());
1979 if (yield.getNumOperands() != getNumResults() ||
1980 yield.getOperands().getTypes() != getResultTypes())
1981 return emitError("Mismatch in types of yield values and results");
1982
1984 for (Dimension d = 0; d < dimRank; d++)
1985 if (args[d].getType() != iTp)
1987 llvm::formatv("Expecting Index type for argument at index {0}", d));
1988
1989 const auto elemTp = t.getElementType();
1990 const auto valueTp = args[dimRank].getType();
1991 if (elemTp != valueTp)
1993 llvm::formatv("Unmatched element type between input tensor and "
1994 "block argument, expected:{0}, got: {1}",
1995 elemTp, valueTp));
1996 return success();
1997 }
1998
1999 OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
2002 return getInputCoo();
2003
2004 return {};
2005 }
2006
2010
2012 return emitError("Expected COO sparse tensors only");
2013
2015 return emitError("Unmatched dim2lvl map between input and result COO");
2016
2020 return emitError("Unmatched storage format between input and result COO");
2021
2022 return success();
2023 }
2024
2026 Type inputType = getX().getType();
2027 Region &formula = getRegion();
2029 TypeRange{inputType, inputType}, inputType);
2030 }
2031
2034 Type inputType = getX().getType();
2035 Type boolType = b.getI1Type();
2036 Region &formula = getRegion();
2038 boolType);
2039 }
2040
2044 if (nx < 1)
2045 return emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
2046
2049 llvm::formatv("Expected a permutation map, got {0}", xPerm));
2050
2051
2052
2054 if (!cn)
2055 return success();
2056
2057
2058 const auto checkDim = [&](Value v, Size minSize,
2059 const char *message) -> LogicalResult {
2061 if (!ShapedType::isDynamic(sh) && sh < minSize)
2063 llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
2064 return success();
2065 };
2066 uint64_t n = cn.value();
2067 uint64_t ny = 0;
2068 if (auto nyAttr = getNyAttr())
2069 ny = nyAttr.getInt();
2070 if (failed(checkDim(getXy(), n * (nx + ny),
2071 "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
2072 return failure();
2073 for (Value opnd : getYs())
2074 if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
2075 return failure();
2076
2077 return success();
2078 }
2079
2080
2081
2082
2083
2084 IterSpaceType IteratorType::getIterSpaceType() const {
2086 getHiLvl());
2087 }
2088
2089 IteratorType IterSpaceType::getIteratorType() const {
2091 }
2092
2093
2094
2098 return failure();
2099
2102 return failure();
2103 } else {
2104 lvlHi = lvlLo + 1;
2105 }
2106
2107 if (lvlHi <= lvlLo)
2109 "expect larger level upper bound than lower bound");
2110
2111 return success();
2112 }
2113
2114
2115
2117 IntegerAttr &lvlHiAttr) {
2118 Level lvlLo, lvlHi;
2120 return failure();
2121
2124 return success();
2125 }
2126
2127
2128
2130
2131 if (lo + 1 == hi)
2132 p << lo;
2133 else
2134 p << lo << " to " << hi;
2135 }
2136
2137
2138
2140 IntegerAttr lvlHi) {
2141 unsigned lo = lvlLo.getValue().getZExtValue();
2142 unsigned hi = lvlHi.getValue().getZExtValue();
2144 }
2145
2146
2147
2148
2149
2155 unsigned cnt = 0;
2156 ParseResult crdList =
2159 if (parser.parseArgument(definedArgs.emplace_back()))
2160 return failure();
2161 definedSet.set(cnt);
2162 }
2163 cnt += 1;
2164 return success();
2165 });
2166
2167 if (cnt > maxCnt)
2169 "parsed more value than expected.");
2170
2171 if (failed(crdList)) {
2174 "expecting SSA value or \"_\" for level coordinates");
2175 }
2176 assert(definedArgs.size() == definedSet.count());
2177 return success();
2178 }
2179
2183 if (definedSet.empty())
2184 return;
2185
2186 for (unsigned i = 0; i < size; i++) {
2187 if (definedSet[i]) {
2188 p << blocksArgs.front();
2189 blocksArgs = blocksArgs.drop_front();
2190 } else {
2191 p << "_";
2192 }
2193 if (i != size - 1)
2194 p << ", ";
2195 }
2196 assert(blocksArgs.empty());
2197 }
2198
2199 static ParseResult
2202
2206 return failure();
2207
2208
2209 for (auto &coord : coords)
2211
2212
2213 state.addAttribute("crdUsedLvls",
2215 return success();
2216 }
2217
2218 static ParseResult
2224
2225
2228 return failure();
2229
2230 if (iterators.size() != spaces.size())
2233 "mismatch in number of sparse iterators and sparse spaces");
2234
2237 return failure();
2238 size_t numCrds = coords.size();
2239
2240
2242 if (hasIterArgs)
2244 return failure();
2245
2246 blockArgs.append(coords);
2247
2249
2251 return failure();
2252 if (iterSpaceTps.size() != spaces.size())
2254 "mismatch in number of iteration space operands "
2255 "and iteration space types");
2256
2257 for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
2258 IterSpaceType spaceTp = llvm::dyn_cast(tp);
2259 if (!spaceTp)
2261 "expected sparse_tensor.iter_space type for "
2262 "iteration space operands");
2263 it.type = spaceTp.getIteratorType();
2264 }
2265
2266 if (hasIterArgs)
2268 return failure();
2269
2270
2272 state.operands))
2273 return failure();
2274
2275 if (hasIterArgs) {
2276
2278 if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2281 "mismatch in number of iteration arguments and return values");
2282 }
2283
2284 for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2285 it.type = tp;
2286 if (parser.resolveOperand(init, tp, state.operands))
2287 return failure();
2288 }
2289 }
2290 return success();
2291 }
2292
2293 static ParseResult
2297
2298
2301 return failure();
2302
2305 return failure();
2306 size_t numCrds = coords.size();
2307
2308
2311 if (hasIterArgs)
2313 return failure();
2314 blockArgs.append(coords);
2315
2317
2320 return failure();
2321
2322 if (iterSpaceTps.size() != spaces.size())
2324 "mismatch in number of iteration space operands "
2325 "and iteration space types");
2326
2327 if (hasIterArgs)
2329 return failure();
2330
2331
2333 spacesVals))
2334 return failure();
2335 state.operands.append(spacesVals);
2336
2337 if (hasIterArgs) {
2338
2340 if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2343 "mismatch in number of iteration arguments and return values");
2344 }
2345
2346 for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2347 it.type = tp;
2348 if (parser.resolveOperand(init, tp, state.operands))
2349 return failure();
2350 }
2351 }
2352 return success();
2353 }
2354
2355 LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2359
2360 ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2363 adaptor.getHiLvl()));
2364 return success();
2365 }
2366
2368 if (getLoLvl() >= getHiLvl())
2369 return emitOpError("expected smaller level low than level high");
2370
2372 if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2373 return emitOpError(
2374 "parent iterator should be specified iff level lower bound equals 0");
2375 }
2376
2377 if (pIter) {
2378 IterSpaceType spaceTp = getExtractedSpace().getType();
2379 if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2380 return emitOpError(
2381 "mismatch in parent iterator encoding and iteration space encoding.");
2382
2383 if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2384 return emitOpError("parent iterator should be used to extract an "
2385 "iteration space from a consecutive level.");
2386 }
2387
2388 return success();
2389 }
2390
2393 auto itTp = getIterator().getType();
2394
2395 if (stt.getEncoding() != itTp.getEncoding())
2396 return emitOpError("mismatch in tensor encoding and iterator encoding.");
2397
2398 if (stt.getLvlRank() != itTp.getHiLvl())
2399 return emitOpError("must use last-level iterator to extract values. ");
2400
2401 return success();
2402 }
2403
2406
2410 llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
2411 for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
2412 if (auto crd = iterateOp.getLvlCrd(i)) {
2413 if (crd->getUsers().empty())
2414 toRemove.set(crd->getArgNumber());
2415 else
2416 newUsedLvls.set(i);
2417 }
2418 }
2419
2420
2421 if (toRemove.none())
2422 return failure();
2423
2425 iterateOp.setCrdUsedLvls(newUsedLvls);
2426 iterateOp.getBody()->eraseArguments(toRemove);
2428 return success();
2429 }
2430 };
2431
2435 }
2436
2439 unsigned rank = llvm::cast(iterSpace.getType()).getSpaceDim();
2440
2442 return build(builder, odsState, iterSpace, initArgs, set);
2443 }
2444
2449
2457
2458
2459 for (Value v : initArgs)
2461
2462
2463 for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
2465
2466
2468 llvm::cast(iterSpace.getType()).getIteratorType(),
2470 }
2471
2475
2478 return failure();
2479 if (iters.size() != 1)
2481 "expected only one iterator/iteration space");
2482
2483 iterArgs.append(iters);
2485 if (parser.parseRegion(*body, iterArgs))
2486 return failure();
2487
2488 IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2489
2490
2492 return failure();
2493
2494 return success();
2495 }
2496
2497
2498
2499
2500
2504 StringRef prefix = "") {
2505 assert(blocksArgs.size() == initializers.size() &&
2506 "expected same length of arguments and initializers");
2507 if (initializers.empty())
2508 return;
2509
2510 p << prefix << '(';
2511 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
2512 p << std::get<0>(it) << " = " << std::get<1>(it);
2513 });
2514 p << ")";
2515 }
2516
2517 template
2519 if (op.getInitArgs().size() != op.getNumResults()) {
2520 return op.emitOpError(
2521 "mismatch in number of loop-carried values and defined values");
2522 }
2523 if (op.getCrdUsedLvls().max() > op.getSpaceDim())
2524 return op.emitOpError("required out-of-bound coordinates");
2525
2526 return success();
2527 }
2528
2531
2533 p << " " << getIterator() << " in " << getIterSpace();
2534 if (!getCrdUsedLvls().empty()) {
2535 p << " at(";
2537 p << ")";
2538 }
2540
2541 p << " : " << getIterSpace().getType() << " ";
2542 if (!getInitArgs().empty())
2544
2545 p << " ";
2546 p.printRegion(getRegion(), false,
2547 !getInitArgs().empty());
2548 }
2549
2550 LogicalResult IterateOp::verifyRegions() {
2551 if (getIterator().getType() != getIterSpace().getType().getIteratorType())
2552 return emitOpError("mismatch in iterator and iteration space type");
2553 if (getNumRegionIterArgs() != getNumResults())
2554 return emitOpError(
2555 "mismatch in number of basic block args and defined values");
2556
2557 auto initArgs = getInitArgs();
2558 auto iterArgs = getRegionIterArgs();
2559 auto yieldVals = getYieldedValues();
2560 auto opResults = getResults();
2561 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2562 opResults.size()})) {
2563 return emitOpError() << "number mismatch between iter args and results.";
2564 }
2565
2566 for (auto [i, init, iter, yield, ret] :
2567 llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2568 if (init.getType() != ret.getType())
2569 return emitOpError() << "types mismatch between " << i
2570 << "th iter operand and defined value";
2571 if (iter.getType() != ret.getType())
2572 return emitOpError() << "types mismatch between " << i
2573 << "th iter region arg and defined value";
2574 if (yield.getType() != ret.getType())
2575 return emitOpError() << "types mismatch between " << i
2576 << "th yield value and defined value";
2577 }
2578
2579 return success();
2580 }
2581
2582
2584
2586 return getInitArgsMutable();
2587 }
2588
2590 return getRegion().getArguments().take_front(getNumRegionIterArgs());
2591 }
2592
2593 std::optional<MutableArrayRef> IterateOp::getYieldedValuesMutable() {
2594 return cast<sparse_tensor::YieldOp>(
2595 getRegion().getBlocks().front().getTerminator())
2596 .getResultsMutable();
2597 }
2598
2599 std::optional IterateOp::getLoopResults() { return getResults(); }
2600
2602 return getInitArgs();
2603 }
2604
2607
2608
2609 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2610
2612 }
2613
2616 unsigned numCases) {
2617 unsigned rank =
2618 cast(iterSpaces.front().getType()).getSpaceDim();
2619
2621
2622
2623
2624
2627 return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces,
2628 initArgs, set, cases,
2629 numCases);
2630 }
2631
2633
2635
2636
2639 return failure();
2640
2643 {static_cast<int32_t>(spaces.size()),
2644 static_cast<int32_t>(result.types.size())}));
2645
2648
2653 return failure();
2654
2656
2658
2659 auto spaceTp = llvm::cast(spaces[definedIdx].getType());
2660 definedIts[i].type = spaceTp.getIteratorType();
2661 }
2662 definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end());
2664 if (parser.parseRegion(*body, definedIts))
2665 return failure();
2666
2667 CoIterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2668 }
2669
2671
2672
2674 return failure();
2675
2676 return success();
2677 }
2678
2680 p << " (";
2681 llvm::interleaveComma(getIterSpaces(), p, [&](auto s) { p << s; });
2682 p << ")";
2683
2684 if (!getCrdUsedLvls().empty()) {
2685 p << " at(";
2687 p << ")";
2688 }
2689
2691
2692 p << " : (" << getIterSpaces().getTypes() << ")";
2693 if (!getInitArgs().empty())
2694 p.printArrowTypeList(getInitArgs().getTypes());
2695
2696 for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) {
2697 p.printNewline();
2698 p << "case ";
2700 getRegionDefinedSpace(idx));
2701 p << " ";
2702 p.printRegion(getRegion(idx), false,
2703 !getInitArgs().empty());
2704 }
2705 }
2706
2707 ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {
2708 return cast<sparse_tensor::YieldOp>(
2709 getRegion(regionIdx).getBlocks().front().getTerminator())
2710 .getResults();
2711 }
2712
2713 LogicalResult CoIterateOp::verifyRegions() {
2714 for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
2715 if (getNumRegionIterArgs() != getNumResults())
2716 return emitOpError(
2717 "mismatch in number of basic block args and defined values");
2718
2719 auto initArgs = getInitArgs();
2720 auto iterArgs = getRegionIterArgs(r);
2721 auto yieldVals = getYieldedValues(r);
2722 auto opResults = getResults();
2723 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2724 opResults.size()})) {
2725 return emitOpError()
2726 << "number mismatch between iter args and results on " << r
2727 << "th region";
2728 }
2729
2730 for (auto [i, init, iter, yield, ret] :
2731 llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2732 if (init.getType() != ret.getType())
2733 return emitOpError()
2734 << "types mismatch between " << i
2735 << "th iter operand and defined value on " << r << "th region";
2736 if (iter.getType() != ret.getType())
2737 return emitOpError() << "types mismatch between " << i
2738 << "th iter region arg and defined value on " << r
2739 << "th region";
2740 if (yield.getType() != ret.getType())
2741 return emitOpError()
2742 << "types mismatch between " << i
2743 << "th yield value and defined value on " << r << "th region";
2744 }
2745 }
2746
2747 auto cases = getRegionDefinedSpaces();
2748 llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
2749 if (set.size() != getNumRegions())
2750 return emitOpError("contains duplicated cases.");
2751
2752 return success();
2753 }
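The duplicate-case check above compares the size of a SmallSetVector (which deduplicates on insertion) against the region count. The same idiom in isolation, with illustrative case bit sets:

#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <cstdint>

int main() {
  llvm::SmallVector<uint64_t> cases = {0b01, 0b10, 0b01}; // last is a dup
  llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
  // Duplicates collapse in the set, so the sizes differ.
  assert(set.size() != cases.size());
  return 0;
}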
2754
2755 SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {
2756 SmallVector<Region *> ret;
2757 I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
2758 for (Region &r : getCaseRegions())
2759 if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
2760 ret.push_back(&r);
2761
2762 return ret;
2763 }
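getSubCasesOf collects every region whose defined-space bit set is a subset of the given case's. On 64-bit level masks the subset test reduces to simple bit arithmetic; a standalone sketch using plain uint64_t rather than the dialect's I64BitSet:

#include <cassert>
#include <cstdint>

// a is a subset of b iff every bit set in a is also set in b.
static bool isSubSetOf(uint64_t a, uint64_t b) { return (a & b) == a; }

int main() {
  assert(isSubSetOf(0b01, 0b11));  // {0} is a sub-case of {0,1}
  assert(!isSubSetOf(0b11, 0b01)); // {0,1} is not a sub-case of {0}
  return 0;
}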
2764
2765
2766
2767
2768
2769
2770
2771 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
2772 Attribute value, Type type,
2773 Location loc) {
2774 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2775 return op;
2776 return nullptr;
2777 }
2778
2779 void SparseTensorDialect::initialize() {
2780 addAttributes<
2781 #define GET_ATTRDEF_LIST
2782 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
2783 >();
2784 addTypes<
2785 #define GET_TYPEDEF_LIST
2786 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
2787 >();
2788 addOperations<
2789 #define GET_OP_LIST
2790 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2791 >();
2792 declarePromisedInterfaces<
2793 bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
2794 NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
2795 ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
2796 }
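For completeness, a minimal sketch of how a client makes the attributes, types, and ops registered above available in a context (standard MLIR usage, not part of this file):

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;
  // Loading the dialect runs SparseTensorDialect::initialize() above.
  ctx.loadDialect<mlir::sparse_tensor::SparseTensorDialect>();
  return 0;
}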
2797
2798 #define GET_OP_CLASSES
2799 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2800
2801 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"