MLIR: lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
11
21
22using namespace mlir;
24
25namespace {
26
27
28
29
30
31
32template <typename SubClass, typename SourceOp>
34 using OpRewritePattern::OpRewritePattern;
35 using OpAdaptor = typename SourceOp::Adaptor;
36
37 LogicalResult matchAndRewrite(SourceOp op,
38 PatternRewriter &rewriter) const override {
39 Location loc = op.getLoc();
40
41
43 SmallVector deMappedIns(op->getOperands());
44 for (Value &in : deMappedIns) {
46 in =
47 ReinterpretMapOp::create(rewriter, loc, stt->getDemappedType(), in);
49 }
50 }
51
52
53 OpAdaptor adaptor(deMappedIns, op);
54 LogicalResult status =
55 static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
57 }
58};
59
60
61struct AffineDimCollector : public AffineExprVisitor {
62 explicit AffineDimCollector(unsigned dimNum) : dims(dimNum) {};
63 void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
64 BitVector dims;
65};
66
67
68struct AffineExprAdmissibleVisitor
70 explicit AffineExprAdmissibleVisitor(bool isOutput) : isOutput(isOutput) {};
71
72
73 void visitAddExpr(AffineBinaryOpExpr expr) {
74 if (isOutput)
75 admissible = false;
76 }
77 void visitMulExpr(AffineBinaryOpExpr expr) {
78 if (isOutput)
79 admissible = false;
80 }
81
82
83 void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
84 void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
85 void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
86 operator bool() { return admissible; }
87
88private:
89 bool admissible = true;
90 bool isOutput;
91};
92
93
94
95
96using InadmissInfo = std::pair<BitVector, BitVector>;
97
98}
99
100
101
102
103
104
106 auto ret = std::make_pair(BitVector(map.getNumResults()),
108 AffineDimCollector collector(map.getNumDims());
109 for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
110 AffineExprAdmissibleVisitor admissible(isOutput);
111 admissible.walkPostOrder(map.getResult(lvl));
112 if (!admissible) {
113
114 ret.first.set(lvl);
115
116 collector.walkPostOrder(map.getResult(lvl));
117 }
118 }
119 ret.second = collector.dims;
120 return ret;
121}
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
154 auto [inAdLvls, usedDims] = info;
155
156
157
158
159
161
162 assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
163 if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
164
165
166
167
168
170 AffineDimCollector usedInLvl(idxMap.getNumDims());
172 usedInLvl.walkPostOrder(e);
173
174 unsigned curUsedDimID = 0;
175 unsigned curUnusedDimID = lvl2Idx.getNumDims();
176
177 BitVector unused = usedInLvl.dims.flip();
178 for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
179 if (unused.test(i))
181 else
182 results.push_back(lvl2Idx.getResult(curUsedDimID++));
183 }
184 lvl2Idx =
185 AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
186 }
187 assert(lvl2Idx.getNumResults() == idxMap.getNumDims());
188
189
190
191
192
193
194 unsigned curRepID = 0;
195 unsigned curOriID = inAdLvls.count();
199
200 for (unsigned l : inAdLvls.set_bits()) {
201
202
203
204
206
207
208
210 AffineDimCollector collector(idxMap.getNumDims());
211 collector.walkPostOrder(lvlExp);
212
213 assert(collector.dims.count() == 1);
214 transItTps.push_back(itTps[collector.dims.find_first()]);
215 }
216
217 for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
218 if (usedDims.test(d)) {
219
220
221
222 results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
223 } else {
224
225
226
228 transItTps.push_back(itTps[d]);
229 }
230 }
231 unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
232
233 itTps.assign(transItTps.begin(), transItTps.end());
235}
236
237
238
239
240
241static std::optional<std::pair<ArrayAttr, ArrayAttr>>
243
247 for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
248 Value tensor = op->getOpOperand(i).get();
250 if (stt && !stt->isIdentity()) {
251 AffineMap dim2Lvl = stt->getDimToLvl();
252
253 idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
254 }
255 }
256
257
258
260 unsigned pos, int64_t lvlSz) {
261 if (ShapedType::isStatic(lvlSz)) {
265
266
267 auto divExp =
269 cstMapping.try_emplace(divExp, c0);
270
271
273 cstMapping.try_emplace(modExp, lvlExp);
274 }
275 };
276
277 unsigned boundedNum = 0;
278
282 for (OpOperand &operand : op->getOpOperands()) {
284
285 if (!stt || !stt->getEncoding())
286 continue;
287
288 unsigned tid = operand.getOperandNumber();
289 bool isOutput = &operand == op.getDpsInitOperand(0);
290 AffineMap idxMap = idxMapArray[tid];
292 auto [inAdLvls, dimExprs] = inAdInfo;
293 for (unsigned d : dimExprs.set_bits()) {
294
295
296
297 if (d < boundedNum)
298 return std::nullopt;
299 }
300
301 if (inAdLvls.count() != 0) {
302
303
306 unsigned position = 0;
307 for (unsigned lvl : inAdLvls.set_bits()) {
308 int64_t lvlSz = lvlShape[lvl];
309 populateCstMapping(cstMapping, position, lvlSz);
310 position++;
311 }
312
314
315
316 for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
317 AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
318 idxMapArray[tid] = transMap.replace(
319 cstMapping, transMap.getNumDims(),
320 0);
321 }
323 boundedNum += inAdLvls.count();
324 }
325 }
326 };
327
329 llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
330 return linalg::IteratorTypeAttr::get(ctx, itTp);
331 });
332
335}
336
337
340 return ReinterpretMapOp::create(builder, val.getLoc(), enc.withoutDimToLvl(),
341 val);
342}
343
344
347 return ReinterpretMapOp::create(builder, val.getLoc(), enc, val);
348}
349
353 assert(outs.size() == types.size());
354 for (auto [r, t] : llvm::zip(ret, types))
355 if (r.getType() != t)
356 r = ReinterpretMapOp::create(rewriter, r.getLoc(), t, r);
357 return ret;
358}
359
360namespace {
361
362
363
364
365
366
367struct GenericOpReinterpretMap
368 : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
369public:
370 using DemapInsRewriter::DemapInsRewriter;
371 LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
372 PatternRewriter &rewriter) const {
373
374
375 if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
378 return failure();
379
380
381 auto transMap = translateMap(linalgOp, rewriter);
382 if (!transMap)
384 linalgOp, "the sparse kernel can not be sparsified.");
385
386
387 Value res = linalgOp.getResult(0);
389 auto [idxMap, itTp] = *transMap;
390
392 linalgOp.setIndexingMapsAttr(idxMap);
393 linalgOp.setIteratorTypesAttr(itTp);
394
395 linalgOp.getInputsMutable().assign(adaptor.getInputs());
396 linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
397 res.setType(adaptor.getOutputs()[0].getType());
399
401 if (stt && stt->hasEncoding()) {
402 Value t = genRemap(rewriter, stt->getEncoding(), res);
404 }
406 }
407};
408
409struct GenericOpScheduler : public OpRewritePatternlinalg::GenericOp {
410 GenericOpScheduler(MLIRContext *context,
412 : OpRewritePatternlinalg::GenericOp(context), strategy(strategy) {}
413
414 LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
415 PatternRewriter &rewriter) const override {
416 if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
419 return failure();
420 }
421
422 const StringRef sorted = "sorted";
423 if (linalgOp->hasAttr(sorted))
424 return failure();
425
426
428 bool isAdmissible = false;
429 AffineMap order;
430
431
432
433
434 const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
435 SortMask::kIncludeDenseInput,
436 SortMask::kIncludeDenseOutput,
437 SortMask::kSparseOnly};
438 for (const SortMask mask : allMasks) {
439 order = scheduler.sort(mask);
440 if (order) {
441 if (isAdmissibleOrder(linalgOp, order)) {
442 isAdmissible = true;
443 break;
444 }
445
446 }
447 }
448
449 if (!order) {
450
451 if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
453 linalgOp, "the sparse kernel can not be scheduled: loop detected.");
454 }
456 }
457
458 if (!isAdmissible) {
460 linalgOp, "the sparse kernel can not be scheduled.");
461 }
462
463
465 linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
466 });
467
468
471
473
474 ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
475 SmallVector curItTypes;
476 curItTypes.reserve(preItTypes.size());
477 for (AffineExpr expr : order.getResults()) {
478 unsigned loopID = llvm::cast(expr).getPosition();
479 curItTypes.push_back(preItTypes[loopID]);
480 }
481
482
484 SmallVector idxMaps = linalgOp.getIndexingMapsArray();
485 for (AffineMap &idxMap : idxMaps)
486 idxMap = idxMap.compose(order);
487
490 linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
492
494 }
495
496private:
497
498 static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
500 return true;
501
502 OpOperand *lhs = linalgOp.getDpsInitOperand(0);
503 unsigned nest = 0;
504 const auto iteratorTypes = linalgOp.getIteratorTypesArray();
505 for (const AffineExpr l : order.getResults()) {
506 unsigned loopId = llvm::cast(l).getPosition();
507 auto itTp =
508 castlinalg::IteratorTypeAttr(linalgOp.getIteratorTypes()[loopId]);
510 break;
511 nest++;
512 }
513
514
515
516 return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
517 };
518
519
520 static LogicalResult resolveCycle(IterationGraphSorter &scheduler,
521 linalg::LinalgOp linalgOp,
522 PatternRewriter &rewriter) {
523
524
525 for (OpOperand *t : linalgOp.getDpsInputOperands()) {
526 Value tval = t->get();
528
529
530 AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
531 bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
532 return !llvm::isa(exp);
533 });
534 if (!srcEnc || hasCompExpr)
535 continue;
536
537
538 AffineMap order = scheduler.sort(SortMask::kSparseOnly, tval);
539 if (!order)
540 continue;
541
542
543
544
546 assert(stt.isIdentity());
548
549 idxMap = idxMap.compose(order);
550
551
552
553
554
555
556
557
558 SmallVector<std::pair<unsigned, unsigned>> lvlSeq;
559 for (AffineExpr expr : idxMap.getResults()) {
560 unsigned lvl = llvm::cast(expr).getPosition();
561 lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
562 }
563 llvm::sort(lvlSeq, llvm::less_first());
564 SmallVector perm =
565 llvm::to_vector(llvm::make_second_range(lvlSeq));
567
568 assert(!dimToLvl.isIdentity());
569
570
572 RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
573 Value dst = ConvertOp::create(rewriter, tval.getLoc(), dstTp, tval);
575 linalgOp->setOperand(t->getOperandNumber(), dst);
576 });
577
578
579
581 bufferization::DeallocTensorOp::create(rewriter, dst.getLoc(), dst);
582
584 }
585
586
587 return failure();
588 }
589
590private:
592};
593
594
595
596
597
598template
599struct TensorAllocDemapper : public OpRewritePattern {
600 using OpRewritePattern::OpRewritePattern;
601 LogicalResult matchAndRewrite(AllocOp op,
602 PatternRewriter &rewriter) const override {
604 return failure();
605
606 Location loc = op.getLoc();
608
609 SmallVector maxDimCrds;
610 maxDimCrds.reserve(stt.getDimRank());
611 ValueRange dynSz = op.getDynamicSizes();
612 for (int64_t dimSz : stt.getDimShape()) {
613 if (ShapedType::isDynamic(dimSz)) {
614 Value maxCrd = arith::SubIOp::create(rewriter, loc, dynSz.front(),
616 maxDimCrds.push_back(maxCrd);
617 dynSz = dynSz.drop_front();
618 } else {
619 maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
620 }
621 }
622
623 ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
624 CrdTransDirectionKind::dim2lvl);
625 auto lvlShape = stt.getLvlShape();
626 SmallVector dynLvlSzs;
627 for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
628 if (ShapedType::isDynamic(lvlShape[i])) {
629 Value sz = arith::AddIOp::create(rewriter, loc, maxLvlCrds[i],
631 dynLvlSzs.push_back(sz);
632 }
633 }
634
635 assert(dynSz.empty());
637 op->setOperands(dynLvlSzs);
638 op.getResult().setType(stt.getDemappedType());
641
642 Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
645 }
646};
647
648struct TensorInsertDemapper
649 : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
650 using DemapInsRewriter::DemapInsRewriter;
651 LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
652 PatternRewriter &rewriter) const {
654 return failure();
655
656 Location loc = op.getLoc();
658 ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
659 CrdTransDirectionKind::dim2lvl);
660 auto insertOp = tensor::InsertOp::create(rewriter, loc, op.getScalar(),
661 adaptor.getDest(), lvlCrd);
662
663 Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
666 }
667};
668
669struct SparseAssembleDemapper : public OpRewritePattern {
671 LogicalResult matchAndRewrite(AssembleOp op,
672 PatternRewriter &rewriter) const override {
674 return failure();
675
679 op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); });
681 Value out = genRemap(rewriter, stt.getEncoding(), op.getResult());
684 }
685};
686
687struct SparseDisassembleDemapper
688 : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
689 using DemapInsRewriter::DemapInsRewriter;
690 LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
691 PatternRewriter &rewriter) const {
693 return failure();
694
697 op.getTensorMutable().assign(adaptor.getTensor());
698 });
700 }
701};
702
703struct ForeachOpDemapper
704 : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
705 using DemapInsRewriter::DemapInsRewriter;
706 LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
707 PatternRewriter &rewriter) const {
708
709
711 return failure();
712
713
714 if (auto constOp = op.getTensor().getDefiningOparith::ConstantOp())
715 if (auto attr = dyn_cast(constOp.getValue()))
716 return failure();
717
718 Location loc = op.getLoc();
719
721 SmallVector prevRetTps(op.getResultTypes());
722
724 op.getTensorMutable().assign(adaptor.getTensor());
725 op.getInitArgsMutable().assign(adaptor.getInitArgs());
726
727 for (auto r : op.getResults())
729 r.setType(stt->getDemappedType());
730
732
733 SmallVector blockArgTps(lvlRank, rewriter.getIndexType());
734 blockArgTps.push_back(srcStt.getElementType());
735 blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
736 adaptor.getInitArgs().getTypes().end());
737 Block *body = op.getBody();
738
740 for (Type t : blockArgTps)
742
743
746
747 ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
748 CrdTransDirectionKind::lvl2dim);
750 body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
752
753 unsigned numInitArgs = op.getInitArgs().size();
755 body->getArgument(lvlRank + numInitArgs + 1));
757
760
761 SmallVector reMappedArgs =
765
766
767
768 if (numInitArgs != 0) {
770 auto yield = llvm::cast(body->getTerminator());
772 stt && !stt->isIdentity()) {
773 Value y =
774 genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
775 YieldOp::create(rewriter, loc, y);
777 }
778 }
780
782 SmallVector outs =
784
785
786
787 for (auto [from, to] : llvm::zip(op.getResults(), outs))
789
791 }
792};
793
794}
795
801 patterns.add(patterns.getContext());
802 patterns.add(patterns.getContext(), strategy);
803 }
806 patterns.add<TensorAllocDemapperbufferization::AllocTensorOp,
807 TensorAllocDemappertensor::EmptyOp, SparseAssembleDemapper,
808 SparseDisassembleDemapper, TensorInsertDemapper,
809 ForeachOpDemapper>(patterns.getContext());
810 }
811}
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc, Value val)
Definition SparseReinterpretMap.cpp:338
static SmallVector< Value > remapValueRange(OpBuilder &rewriter, TypeRange types, ValueRange outs)
Definition SparseReinterpretMap.cpp:350
static AffineMap genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap, SmallVector< utils::IteratorType > &itTps)
Definition SparseReinterpretMap.cpp:151
static std::optional< std::pair< ArrayAttr, ArrayAttr > > translateMap(linalg::GenericOp op, PatternRewriter &rewriter)
Definition SparseReinterpretMap.cpp:242
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc, Value val)
Definition SparseReinterpretMap.cpp:345
static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput)
Definition SparseReinterpretMap.cpp:105
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
MLIRContext * getContext() const
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
BoolAttr getBoolAttr(bool value)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents an operand of an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen.
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static IterationGraphSorter fromGenericOp(linalg::GenericOp genericOp, sparse_tensor::LoopOrderingStrategy strategy)
Factory method that constructs an iteration graph sorter for the given linalg.generic operation with ...
AffineMap sort(SortMask mask, Value ignored=nullptr)
Returns a permutation that represents the scheduled loop order.
Level getLvlRank() const
Returns the level-rank.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool hasAnySparseOperandOrResult(Operation *op)
Returns true iff the MLIR operation has any sparse operand or result.
uint64_t Level
The type of level identifiers and level-ranks.
LoopOrderingStrategy
Defines a strategy for loop ordering during sparse code generation.
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context)
Given the dimToLvl map, infers the lvlToDim map, or returns empty Affine map when inference fails.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
std::optional< SparseTensorType > tryGetSparseTensorType(Value val)
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SortMask
Iteration graph sorting mask.
bool hasAnySparseResult(Operation *op)
Returns true iff the MLIR operation has any sparse result.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
ReinterpretMapScope
Defines a scope for reinterpret map pass.
const FrozenRewritePatternSet & patterns
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
void populateSparseReinterpretMap(RewritePatternSet &patterns, ReinterpretMapScope scope, sparse_tensor::LoopOrderingStrategy strategy=sparse_tensor::LoopOrderingStrategy::kDefault)
Definition SparseReinterpretMap.cpp:796
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...