MLIR: lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp Source File
using namespace mlir;
#define CMPI(p, l, r)                                                          \
  (builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::p, (l), (r))       \
       .getResult())

#define C_IDX(v) (constantIndex(builder, loc, (v)))
#define YIELD(vs) (builder.create<scf::YieldOp>(loc, (vs)))
#define ADDI(lhs, rhs) (builder.create<arith::AddIOp>(loc, (lhs), (rhs)))
#define ANDI(lhs, rhs) (builder.create<arith::AndIOp>(loc, (lhs), (rhs)))
#define SUBI(lhs, rhs) (builder.create<arith::SubIOp>(loc, (lhs), (rhs)))
#define MULI(lhs, rhs) (builder.create<arith::MulIOp>(loc, (lhs), (rhs)))
#define REMUI(lhs, rhs) (builder.create<arith::RemUIOp>(loc, (lhs), (rhs)))
#define DIVUI(lhs, rhs) (builder.create<arith::DivUIOp>(loc, (lhs), (rhs)))
#define SELECT(c, l, r) (builder.create<arith::SelectOp>(loc, (c), (l), (r)))
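
// The macros above are shorthands around an `OpBuilder builder` and a
// `Location loc` that every function below has in scope. For example (an
// illustrative snippet, not part of any function in this file):
//
//   Value next = ADDI(iv, C_IDX(1));     // %next = arith.addi %iv, %c1
//   Value done = CMPI(uge, next, hi);    // %done = arith.cmpi uge, %next, %hi
//   Value pick = SELECT(done, hi, next); // %pick = arith.select %done, ...
//
// i.e. each macro creates exactly one arith/scf operation at `loc`.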

#ifndef NDEBUG
LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder,
                                                  Location loc, Value memref) {
  memref = builder.create<memref::CastOp>(
      loc, UnrankedMemRefType::get(builder.getIndexType(), 0), memref);
  createFuncCall(builder, loc, "printMemrefInd", TypeRange{},
                 ValueRange{memref}, EmitCInterface::On);
}
#endif // NDEBUG

static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
                            Level lvl) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  return createOrFoldSliceOffsetOp(builder, loc, tensor, toDim(enc, lvl));
}

static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
                            Level lvl) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  return createOrFoldSliceStrideOp(builder, loc, tensor, toDim(enc, lvl));
}

static bool isIntOrFPZero(Attribute attr) {
  if (auto f = llvm::dyn_cast<FloatAttr>(attr); f && f.getValue().isZero())
    return true;
  if (auto i = llvm::dyn_cast<IntegerAttr>(attr); i && i.getValue().isZero())
    return true;
  return false;
}

static Value unFoldOpIntResult(OpBuilder &builder, Location loc,
                               OpFoldResult ofr) {
  if (std::optional<int64_t> i = getConstantIntValue(ofr); i.has_value())
    return constantIndex(builder, loc, *i);
  return cast<Value>(ofr);
}

// Folds a tensor.pad with an all-zero padding value into its source, provided
// that the source and the result share the same (identity) sparse encoding.
static Value tryFoldTensors(Value t) {
  auto stt = tryGetSparseTensorType(t);
  auto padOp = t.getDefiningOp<tensor::PadOp>();
  if (padOp && stt.has_value() && stt->hasEncoding() &&
      padOp.getSourceType().getEncoding() == stt->getEncoding() &&
      stt->getEncoding().isIdentity()) {
    // Try fusing the padding with a constant-zero yield.
    Attribute padCst;
    if (matchPattern(padOp.getBody()->getTerminator(),
                     m_Op<tensor::YieldOp>(m_Constant(&padCst))) &&
        isIntOrFPZero(padCst)) {
      return padOp.getSource();
    }
  }
  return t;
}
LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
                         bool isSparseOut, unsigned numLoops,
                         DependentLvlGetter dimGetter) {
  initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, dimGetter);
}

void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
                             bool isSparseOut, unsigned numLoops,
                             DependentLvlGetter dimGetter,
                             SparseEmitStrategy emitStrategy) {
  this->loopTag = loopTag;
  this->hasOutput = hasOutput;
  this->isSparseOut = isSparseOut;
  this->emitStrategy = emitStrategy;

  const unsigned numManifestTensors = ts.size();
  const unsigned synTensorId = numManifestTensors;
  const unsigned numTensors = numManifestTensors + 1;

  this->tensors.assign(ts.begin(), ts.end());

  this->valBuffer.assign(numTensors, nullptr);
  this->lvls.resize(numTensors);
  this->iters.resize(numTensors);
  this->spIterVals.resize(numTensors);

  this->loopStack.reserve(numLoops);
  this->loopSeqStack.reserve(numLoops);

  // Index-reduction (slice-driven loop) bookkeeping.
  this->dependentLvlMap.assign(
      numTensors, std::vector<std::vector<std::pair<TensorLevel, unsigned>>>());
  this->sliceMeta.assign(
      numTensors, std::vector<std::vector<std::pair<Value, unsigned>>>());
  this->levelReducedDep.assign(numTensors, std::vector<unsigned>());

  for (TensorId tid = 0; tid < numTensors; tid++) {
    Level lvlRank;
    if (tid == synTensorId) {
      // The synthetic tensor is (conceptually) an all-dense tensor with rank
      // equal to the total number of loops; each level can potentially be
      // mapped to one of the loops being generated.
      lvlRank = numLoops;
    } else {
      const Value t = tensors[tid];
      // A scalar or zero-ranked tensor needs no level structure.
      if (isZeroRankedTensorOrScalar(t.getType()))
        continue;

      auto stt = getSparseTensorType(t);
      lvlRank = stt.getLvlRank();
    }

    lvls[tid].resize(lvlRank);
    iters[tid].resize(lvlRank);
    spIterVals[tid].resize(lvlRank);
    loopHighs.assign(numLoops, nullptr);

    // Slice-driven loop bookkeeping, per level.
    levelReducedDep[tid].assign(lvlRank, 0);
    dependentLvlMap[tid].assign(
        lvlRank, std::vector<std::pair<TensorLevel, unsigned>>());
    sliceMeta[tid].assign(lvlRank, std::vector<std::pair<Value, unsigned>>());
    if (dimGetter && !isSynTensor(tid)) {
      for (Level l = 0; l < lvlRank; l++) {
        std::vector<std::pair<LoopId, unsigned>> deps = dimGetter(tid, l);
        // Sort the dependent loops by loop id.
        llvm::sort(deps, llvm::less_first());

        dependentLvlMap[tid][l] = std::move(deps);
        unsigned depends = dependentLvlMap[tid][l].size();
        if (depends == 0)
          continue;
        sliceMeta[tid][l].reserve(depends);
      }
    }
  }
}

std::unique_ptr<SparseIterator>
LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
                               Level l) {
  Value tensor = tensors[t];
  auto stt = getSparseTensorType(tensor);
  auto it = makeSimpleIterator(*lvls[t][l], emitStrategy);

  Value folded = tryFoldTensors(tensor);
  if (folded != tensor) {
    auto padOp = tensor.getDefiningOp<tensor::PadOp>();
    assert(padOp);
    if (padOp.getPaddedDims().test(l)) {
      Value low = unFoldOpIntResult(builder, loc, padOp.getMixedLowPad()[l]);
      Value high = unFoldOpIntResult(builder, loc, padOp.getMixedHighPad()[l]);
      auto padIt = makePaddedIterator(std::move(it), low, high, emitStrategy);
      return padIt;
    }
  }

  if (stt.hasEncoding() && stt.getEncoding().isSlice()) {
    Value offset = genSliceOffset(builder, loc, tensor, l);
    Value stride = genSliceStride(builder, loc, tensor, l);
    auto slicedIt = makeSlicedLevelIterator(
        std::move(it), offset, stride, lvls[t][l]->getSize(), emitStrategy);
    return slicedIt;
  }

  return it;
}
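
// The iterator returned above is a small decorator chain: the plain level
// iterator may be wrapped by a padded iterator (when the level comes from a
// zero-padding tensor.pad recognized by tryFoldTensors) or by a sliced
// iterator (when the encoding describes a tensor slice, in which case stored
// coordinates are filtered and remapped according to the slice offset and
// stride). Callers treat all of these uniformly through SparseIterator.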

void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
                                     LoopEmitter::OutputUpdater updater,
                                     LoopEmitter::SynTensorBoundSetter synSetter) {
  // For every manifest tensor, set up the value buffer the loops will use.
  for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
       t++) {
    const Value tensor = tryFoldTensors(tensors[t]);
    const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
    // Skips scalars; zero-ranked tensors still need to be bufferized (and
    // probably filled with zeros) by clients.
    if (!rtp)
      continue;

    auto stt = getSparseTensorType(tensor);
    const auto shape = rtp.getShape();

    // Perform the required bufferization. Dense inputs materialize from the
    // input tensors; sparse inputs use their values buffer. Extra output
    // initialization is delegated to clients.
    bool isOutput = isOutputTensor(t);
    Type elementType = stt.getElementType();
    if (!stt.hasEncoding()) {
      // Non-annotated dense tensors.
      BaseMemRefType denseTp = MemRefType::get(shape, elementType);

      // TODO: if we unconditionally use fully dynamic layout here, it breaks
      // some vectorization passes which require static stride = 1.
      if (llvm::isa_and_nonnull<tensor::ExtractSliceOp>(tensor.getDefiningOp()))
        denseTp = bufferization::getMemRefTypeWithFullyDynamicLayout(rtp);

      Value denseVal =
          builder.create<bufferization::ToBufferOp>(loc, denseTp, tensor);
      // Dense outputs need special handling.
      if (isOutput && updater)
        denseVal = updater(builder, loc, denseVal, tensor);

      valBuffer[t] = denseVal;
    } else {
      // Annotated sparse tensors: the values buffer is also needed for
      // all-dense annotated "sparse" tensors.
      valBuffer[t] = builder.create<ToValuesOp>(loc, tensor);
    }
  }

  // The sparse iterator values will only be available after the loop is
  // constructed.
  if (emitStrategy == SparseEmitStrategy::kSparseIterator)
    return;

  // For every synthetic-tensor level, set the upper bound via the callback.
  if (synSetter) {
    TensorId synId = getSynTensorId();
    for (unsigned i = 0, e = loopHighs.size(); i < e; i++) {
      Value sz = loopHighs[i] = synSetter(builder, loc, i);
      auto [stl, it] = makeSynLevelAndIterator(sz, synId, i, emitStrategy);
      lvls[synId][i] = std::move(stl);
      iters[synId][i].emplace_back(std::move(it));
    }
  }

  // For every manifest tensor, materialize the level structures and create an
  // iterator for each level.
  for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
       t++) {
    const Value tensor = tryFoldTensors(tensors[t]);
    const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
    if (!rtp)
      // Skips scalars; zero-ranked tensors still need to be bufferized (and
      // probably filled with zeros) by clients.
      continue;

    auto stt = getSparseTensorType(tensor);
    const Level lvlRank = stt.getLvlRank();

    // Scan all levels of the current tensor.
    for (Level l = 0; l < lvlRank; l++) {
      lvls[t][l] = makeSparseTensorLevel(builder, loc, tensor, t, l);
      if (!dependentLvlMap[t][l].empty())
        continue;

      auto it = makeLevelIterator(builder, loc, t, l);
      iters[t][l].emplace_back(std::move(it));
    }
  }

  initSubSectIterator(builder, loc);
}

void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
  Value c0 = C_IDX(0);
  for (TensorId t = 0, e = tensors.size(); t < e; t++) {
    auto rtp = dyn_cast<RankedTensorType>(tensors[t].getType());
    if (!rtp)
      continue;

    Level lvlRank = SparseTensorType(rtp).getLvlRank();

    // Compute the dependency-reduction order.
    auto remDepStack = dependentLvlMap;
    std::vector<std::tuple<LoopId, TensorId, Level>> depRedOrder;
    for (Level lvl = 0; lvl < lvlRank; lvl++) {
      // Reverse the queue into a stack.
      std::reverse(remDepStack[t][lvl].begin(), remDepStack[t][lvl].end());
      for (auto [loop, coeff] : dependentLvlMap[t][lvl])
        depRedOrder.emplace_back(std::make_tuple(loop, t, lvl));
    }

    if (depRedOrder.empty())
      continue;

    llvm::sort(depRedOrder, llvm::less_first());

    SmallVector<SparseIterator *> lastIter(tensors.size(), nullptr);
    for (auto [loop, t, lvl] : depRedOrder) {
      std::pair<LoopId, unsigned> curDep = remDepStack[t][lvl].back();
      assert(curDep.first == loop);
      remDepStack[t][lvl].pop_back();

      auto lvlIt = makeLevelIterator(builder, loc, t, lvl);
      const SparseIterator *parent = lastIter[t];
      if (!parent && lvl > 0) {
        if (dependentLvlMap[t][lvl - 1].empty()) {
          parent = iters[t][lvl - 1].back().get();
        }
      }

      std::unique_ptr<SparseIterator> it;
      if (!remDepStack[t][lvl].empty()) {
        // Compute the subsection size.
        Value size = c0;
        for (auto [loop, stride] : remDepStack[t][lvl]) {
          Value idxMax = SUBI(loopHighs[loop], C_IDX(1));
          size = ADDI(size, ADDI(MULI(idxMax, C_IDX(stride)), C_IDX(1)));
        }
        it = makeNonEmptySubSectIterator(builder, loc, parent, loopHighs[loop],
                                         std::move(lvlIt), size, curDep.second,
                                         emitStrategy);
      } else {
        const SparseIterator &subSectIter = *iters[t][lvl].back();
        it = makeTraverseSubSectIterator(builder, loc, subSectIter, *parent,
                                         std::move(lvlIt), loopHighs[loop],
                                         curDep.second, emitStrategy);
      }
      lastIter[t] = it.get();
      iters[t][lvl].emplace_back(std::move(it));
    }
  }
}

void LoopEmitter::categorizeIterators(
    ArrayRef<TensorLevel> tidLvls, SmallVectorImpl<SparseIterator *> &raIters,
    SmallVectorImpl<SparseIterator *> &spIters) {
  // Finds out the tensor levels that we should use to generate loops.
  for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
    SparseIterator *it = &getCurIterator(t, l);
    if (it->randomAccessible())
      raIters.push_back(it);
    else
      spIters.push_back(it);
  }

  // Keep the non-random-access (sparse) iterators in a deterministic order;
  // iterators of a "larger" kind are placed first.
  llvm::stable_sort(spIters, [](auto lhs, auto rhs) {
    return static_cast<uint8_t>(lhs->kind) > static_cast<uint8_t>(rhs->kind);
  });
}

void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
                                  ArrayRef<TensorLevel> tidLvls) {
  assert(loopSeqStack.size() == loopStack.size());

  if (emitStrategy != SparseEmitStrategy::kSparseIterator) {
    // Prepares for all the tensors used in the current loop sequence.
    for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
      levelReducedDep[tid][lvl]++;
      prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
    }
  }

  // The universal index starts from 0.
  loopSeqStack.emplace_back(C_IDX(0), tidLvls.vec());
}

void LoopEmitter::exitCurrentLoopSeq(OpBuilder &builder, Location loc) {
  assert(loopSeqStack.size() == loopStack.size() + 1);

  // Undo the dependence-reduction bookkeeping done in enterNewLoopSeq.
  for (auto [tid, lvl] : unpackTensorLevelRange(loopSeqStack.back().second))
    levelReducedDep[tid][lvl]--;

  loopSeqStack.pop_back();
}

Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    // The affine dimension is a LoopId; return the induction variable of the
    // corresponding loop.
    const auto loopId = cast<AffineDimExpr>(a).getPosition();
    return loopStack[loopId].iv;
  }
  case AffineExprKind::Add: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return ADDI(genAffine(builder, loc, binOp.getLHS()),
                genAffine(builder, loc, binOp.getRHS()));
  }
  case AffineExprKind::Mul: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return MULI(genAffine(builder, loc, binOp.getLHS()),
                genAffine(builder, loc, binOp.getRHS()));
  }
  case AffineExprKind::Constant: {
    int64_t c = cast<AffineConstantExpr>(a).getValue();
    return C_IDX(c);
  }
  default:
    llvm_unreachable("unexpected affine subscript");
  }
}
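
// For illustration, lowering the affine expression `d0 * 2 + 3` recurses as
//
//   genAffine(d0 * 2 + 3)
//     = ADDI(genAffine(d0 * 2), genAffine(3))
//     = ADDI(MULI(loopStack[0].iv, C_IDX(2)), C_IDX(3))
//
// i.e. it emits `arith.muli` and `arith.addi` over the induction variable of
// the loop bound to d0. Only DimId, Add, Mul and Constant expressions are
// expected here; anything else hits the llvm_unreachable above.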

std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
    OpBuilder &builder, Location loc, SparseIterator &iter,
    MutableArrayRef<Value> reduc, bool isParallel) {
  Value step = C_IDX(1);
  auto [lo, hi] = iter.genForCond(builder, loc);

  Operation *loop = nullptr;
  Value iv;
  if (isParallel) {
    scf::ParallelOp parOp =
        builder.create<scf::ParallelOp>(loc, lo, hi, step, reduc);
    builder.setInsertionPointToStart(parOp.getBody());
    assert(parOp.getNumReductions() == reduc.size());
    iv = parOp.getInductionVars()[0];

    // In-place update on the reduction variable vector. Note that, unlike
    // scf.for, the init values of scf.parallel are not region iter args; the
    // actual combination happens through scf.reduce regions created when the
    // loop is exited (see exitForLoop).
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = parOp.getInitVals()[i];
    loop = parOp;
  } else {
    scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, reduc);
    builder.setInsertionPointToStart(forOp.getBody());
    iv = forOp.getInductionVar();

    // In-place update on the reduction variable vector.
    assert(forOp.getNumRegionIterArgs() == reduc.size());
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = forOp.getRegionIterArg(i);
    loop = forOp;
  }
  assert(loop && iv);

  Value crd = iv;
  if (!iter.randomAccessible()) {
    // Position-driven loop: dereference the iterator to get the coordinate.
    crd = iter.deref(builder, loc);
  } else {
    iter.locate(builder, loc, iv);
  }

  return {loop, crd};
}

std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
    OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
    MutableArrayRef<Value> reduc, bool needsUniv) {
  return genCoIteration(builder, loc, spIters, reduc,
                        needsUniv ? loopSeqStack.back().first : nullptr);
}

static bool shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters) {
  // Co-iterating over two or more sparse tensors requires a while loop.
  if (spIters.size() > 1)
    return false;

  if (spIters.size() == 1)
    return spIters.front()->iteratableByFor();

  return true;
}

Region *LoopEmitter::enterCurrentCoIterationCase(OpBuilder &builder,
                                                 Location loc,
                                                 I64BitSet caseBit,
                                                 unsigned caseIdx,
                                                 MutableArrayRef<Value> reduc) {
  auto coIterOp = cast<CoIterateOp>(loopStack.back().loop);
  SmallVector<Attribute> cases(coIterOp.getCasesAttr().getValue());
  cases[caseIdx] = builder.getI64IntegerAttr(caseBit);

  coIterOp.setCasesAttr(builder.getArrayAttr(cases));
  Region &caseRegion = coIterOp.getRegion(caseIdx);
  assert(caseRegion.getBlocks().empty() &&
         "re-initialize the same coiteration case region.");

  // Each case block starts with the used coordinates (of index type), ...
  TypeRange iterArgsTps = coIterOp.getInitArgs().getTypes();
  // ... followed by the user-provided reduction values, ...
  SmallVector<Type> blockArgTps(coIterOp.getCrdUsedLvls().count(),
                                builder.getIndexType());
  blockArgTps.append(iterArgsTps.begin(), iterArgsTps.end());
  // ... and ends with the iterators over the spaces selected by the case bits.
  for (auto i : caseBit.bits()) {
    blockArgTps.push_back(
        cast<IterSpaceType>(coIterOp.getIterSpaces()[i].getType())
            .getIteratorType());
  }
  SmallVector<Location> blockArgLocs(blockArgTps.size(), loc);
  Block *caseBlock =
      builder.createBlock(&caseRegion, caseRegion.end(), blockArgTps,
                          blockArgLocs);
  builder.setInsertionPointToStart(caseBlock);

  // The coordinate of the current case.
  loopStack.back().iv = coIterOp.getCrds(caseIdx).front();

  // In-place update on the user-provided reduction variables.
  ValueRange iterArgs = coIterOp.getRegionIterArgs(caseIdx);
  llvm::copy(iterArgs, reduc.begin());

  ValueRange iters = coIterOp.getRegionIterators(caseIdx);
  for (auto [i, tl] : llvm::enumerate(
           unpackTensorLevelRange(loopStack.back().tidLvls))) {
    if (caseBit[i]) {
      spIterVals[tl.first][tl.second] = iters.front();
      iters = iters.drop_front();
    } else {
      spIterVals[tl.first][tl.second] = nullptr;
    }
  }

  // All iterators for this case should have been consumed.
  assert(iters.empty());
  return &caseRegion;
}

Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
    OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
    unsigned numCases, MutableArrayRef<Value> reduc, bool tryParallel,
    bool needsUniv) {
  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
    if (tidLvls.size() == 1) {
      auto [tid, lvl] = unpackTensorLevel(tidLvls.front());
      Value t = tensors[tid];

      // Extract and iterate over the iteration space.
      ExtractIterSpaceOp extractSpaceOp =
          lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
                   : builder.create<ExtractIterSpaceOp>(
                         loc, t, spIterVals[tid][lvl - 1], lvl);

      IterateOp iterOp = builder.create<IterateOp>(
          loc, extractSpaceOp.getExtractedSpace(), reduc);
      spIterVals[tid][lvl] = iterOp.getIterator();

      // In-place update on the reduction variables.
      llvm::copy(iterOp.getRegionIterArgs(), reduc.begin());

      builder.setInsertionPointToStart(iterOp.getBody());
      loopStack.emplace_back(tidLvls, iterOp, builder.getInsertionBlock(),
                             iterOp.getCrds().front(), loopTag);
      return iterOp;
    }

    // Co-iteration loops: extract one iteration space per tensor level.
    SmallVector<Value> spaces;
    for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
      Value t = tensors[tid];
      ExtractIterSpaceOp extractSpaceOp =
          lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
                   : builder.create<ExtractIterSpaceOp>(
                         loc, t, spIterVals[tid][lvl - 1], lvl);
      spaces.push_back(extractSpaceOp.getExtractedSpace());
    }
    auto coIterOp = builder.create<CoIterateOp>(loc, spaces, reduc, numCases);
    // The CoIterateOp has neither an insertion block nor an induction variable
    // at this point; its case regions are filled in later.
    loopStack.emplace_back(tidLvls, coIterOp, /*userCodeBlock=*/nullptr,
                           /*iv=*/nullptr, loopTag);
    return coIterOp;
  }

  // TODO: support multiple returns on parallel for?
  tryParallel = tryParallel && reduc.size() <= 1;

  SmallVector<SparseIterator *> raIters;
  SmallVector<SparseIterator *> spIters;
  categorizeIterators(tidLvls, raIters, spIters);

  // Only when there is at least one sparse condition do we really need the
  // universal index.
  needsUniv = !spIters.empty() && needsUniv;

  // If there is any sparse level, the loop condition must be driven by the
  // sparse iterators; if all levels are dense, any of them can drive a simple
  // for-loop.
  Operation *l = nullptr;
  Value iv = nullptr;
  SmallVector<TensorLevel> tls;

  if (shouldIteratedByForLoop(spIters) && !needsUniv) {
    assert(spIters.size() <= 1);
    SparseIterator &it = spIters.empty() ? *raIters.front() : *spIters.front();
    std::tie(l, iv) =
        emitForLoopOverTensorAtLvl(builder, loc, it, reduc, tryParallel);
    tls.push_back(makeTensorLevel(it.tid, it.lvl));
  } else {
    for (auto *it : spIters) {
      tls.push_back(makeTensorLevel(it->tid, it->lvl));
    }

    if (needsUniv)
      for (auto *it : raIters)
        tls.push_back(makeTensorLevel(it->tid, it->lvl));

    std::tie(l, iv) =
        emitWhileLoopOverTensorsAtLvls(builder, loc, spIters, reduc, needsUniv);
  }

  // Enter dense tensor levels.
  for (SparseIterator *it : raIters)
    it->locate(builder, loc, iv);

  // Pushes the loop onto the stack.
  loopStack.emplace_back(tls, l, builder.getInsertionBlock(), iv, loopTag);
  return l;
}
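
// A typical client drives the emitter roughly as follows (illustrative
// sketch; the actual tensor-level sets and reduction values come from the
// sparsifier):
//
//   emitter.initializeLoopEmit(builder, loc);
//   emitter.enterNewLoopSeq(builder, loc, tidLvls);
//   Operation *loop = emitter.enterCoIterationOverTensorsAtLvls(
//       builder, loc, tidLvls, /*numCases=*/1, reduc);
//   // ... emit the loop body at the current insertion point ...
//   emitter.exitCurrentLoop(rewriter, loc, reduc);
//   emitter.exitCurrentLoopSeq(builder, loc);
//
// Loop sequences and loops are strictly nested and unwound in LIFO order,
// which is what the loopStack/loopSeqStack asserts enforce.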

void LoopEmitter::locateLvlAtAffineAddress(OpBuilder &builder, Location loc,
                                           TensorLevel tidLvl,
                                           AffineExpr lvlExpr) {
  auto [tid, lvl] = unpackTensorLevel(tidLvl);

  const SparseIterator *parent =
      lvl == 0 ? nullptr : iters[tid][lvl - 1].back().get();
  auto &it = getCurIterator(tid, lvl);
  it.genInit(builder, loc, parent);

  Value lvlCrd = genAffine(builder, loc, lvlExpr);
  it.locate(builder, loc, lvlCrd);
}

void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
                                             TensorId tid, Level lvl) {
  // If this is the first level, there is no parent iterator for the current
  // iterator; if the current iterator is a subsection-based iterator, the
  // parent iterator is memorized by the iterator itself.
  bool hasParent = lvl == 0 || !dependentLvlMap[tid][lvl].empty();

  const SparseIterator *parent =
      hasParent ? nullptr : iters[tid][lvl - 1].back().get();
  auto &it = getCurIterator(tid, lvl);
  it.genInit(builder, loc, parent);

  // Locates a random-access iterator at the first coordinate.
  if (it.randomAccessible())
    it.locate(builder, loc, C_IDX(0));
}

void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
                              MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
    auto iterateOp = llvm::cast<IterateOp>(loopInfo.loop);
    assert(reduc.size() == iterateOp.getNumResults());
    rewriter.create<sparse_tensor::YieldOp>(loc, reduc);

    // Exit the loop.
    rewriter.setInsertionPointAfter(iterateOp);
    // In-place update on the reduction variables.
    llvm::copy(iterateOp.getResults(), reduc.begin());
    return;
  }
  if (auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop)) {
    if (!reduc.empty()) {
      assert(reduc.size() == forOp.getNumResults());
      rewriter.create<scf::YieldOp>(loc, reduc);
    }
    // Exit the loop.
    rewriter.setInsertionPointAfter(forOp);
    // In-place update on the reduction variables.
    llvm::copy(forOp.getResults(), reduc.begin());
  } else {
    auto parOp = llvm::cast<scf::ParallelOp>(loopInfo.loop);
    if (!reduc.empty()) {
      assert(reduc.size() == parOp.getInitVals().size() && reduc.size() == 1);
      Operation *redExp = reduc.front().getDefiningOp();
      // The reduction expression should have no use yet.
      assert(redExp->getUses().empty());
      // It must be a binary operation; it is the users' responsibility to
      // ensure the operation is commutative.
      assert(redExp->getNumOperands() == 2 && redExp->getNumResults() == 1);

      Value redVal = parOp.getInitVals().front();
      Value curVal;
      if (redExp->getOperand(0) == redVal)
        curVal = redExp->getOperand(1);
      else if (redExp->getOperand(1) == redVal)
        curVal = redExp->getOperand(0);
      // One of the operands must be the init value (which is also the
      // previous reduction value).
      assert(curVal);
#ifndef NDEBUG
      // The reduction expression should be the only user of the reduction
      // value inside the parallel loop.
      unsigned numUsers = 0;
      for (Operation *op : redVal.getUsers()) {
        if (op->getParentOp() == parOp)
          numUsers++;
      }
      assert(numUsers == 1);
#endif // NDEBUG

      rewriter.setInsertionPointAfter(redExp);
      auto redOp = rewriter.create<scf::ReduceOp>(loc, curVal);
      // Attach to the reduction op.
      Block *redBlock = &redOp.getReductions().front().front();
      rewriter.setInsertionPointToEnd(redBlock);
      Operation *newRed = rewriter.clone(*redExp);
      // Replaces the operands of the reduction expression by the block
      // arguments of the scf.reduce region.
      rewriter.modifyOpInPlace(
          newRed, [&]() { newRed->setOperands(redBlock->getArguments()); });
      // Erases the out-dated reduction expression.
      rewriter.eraseOp(redExp);
      rewriter.setInsertionPointToEnd(redBlock);
      rewriter.create<scf::ReduceReturnOp>(loc, newRed->getResult(0));
    }
    rewriter.setInsertionPointAfter(parOp);
    // In-place update on the reduction variables.
    for (unsigned i = 0, e = parOp.getResults().size(); i < e; i++)
      reduc[i] = parOp.getResult(i);
  }
}
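
// Illustrative effect of the scf.parallel branch above: a loop body that
// reduced through a binary op such as
//
//   %sum = arith.addf %init, %x : f32
//
// (where %init is the loop's init value) is rewritten so that the binary op
// is cloned into an scf.reduce region and re-wired to the region's block
// arguments, roughly:
//
//   scf.reduce(%x : f32) {
//   ^bb0(%lhs: f32, %rhs: f32):
//     %r = arith.addf %lhs, %rhs : f32
//     scf.reduce.return %r : f32
//   }
//
// while the original, now dead, arith.addf is erased.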

void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
                                MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
  Value iv = loopInfo.iv;
  Value one = C_IDX(1);

  // Finalize the induction after the if-branches: forwarding inside the
  // branches would re-evaluate the conditions and produce a rather elaborate
  // forest of yields, so the iterators are forwarded here instead.
  SmallVector<Value> operands;
  ValueRange whileRes = whileOp.getResults();

  for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.tidLvls)) {
    SparseIterator &it = getCurIterator(tid, lvl);
    if (!it.randomAccessible()) {
      // Forward the sparse iterator when its coordinate matches the minimum.
      Value cmp = CMPI(eq, it.getCrd(), iv);
      it.forwardIf(builder, loc, cmp);
      operands.append(it.getCursor().begin(), it.getCursor().end());
      // Following loops continue iteration from the break point of the
      // current while loop.
      whileRes = it.linkNewScope(whileRes);
    } else {
      // Make sure the randomly accessible (dense) iterator is set to the right
      // position according to the universal index.
      Value uniIdx = whileOp.getResults().back();
      it.locate(builder, loc, uniIdx);
    }
  }

  // Reduction values from users.
  for (auto &i : reduc) {
    operands.push_back(i);
    // In-place update on the reduction variable.
    i = whileRes.front();
    whileRes = whileRes.drop_front();
  }

  // An (optional) universal index.
  if (operands.size() < whileOp.getNumResults()) {
    assert(operands.size() + 1 == whileOp.getNumResults());
    // The last operand is the universal index.
    operands.push_back(ADDI(iv, one));
    // Update the loop starting point of the current loop sequence.
    loopSeqStack.back().first = whileOp->getResults().back();
  }

  if (!operands.empty())
    YIELD(operands);

  builder.setInsertionPointAfter(whileOp);
}

void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
                                  MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
    Operation *p = loopInfo.loop;
    if (isa<IterateOp>(p))
      rewriter.create<sparse_tensor::YieldOp>(loc, reduc);

    // Exit the loop.
    rewriter.setInsertionPointAfter(p);
    // In-place update on the reduction variables.
    llvm::copy(p->getResults(), reduc.begin());
    loopStack.pop_back();
    return;
  }

  // Sets the insertion point to the right position.
  rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock);
  if (!loopInfo.userCodeBlock->empty() &&
      llvm::isa<scf::YieldOp>(&loopInfo.userCodeBlock->back())) {
    // scf::While/For inserts an implicit yield op when there are no loop
    // iter args; in that case the exit code must be inserted before the yield.
    assert(loopInfo.userCodeBlock->back().getNumResults() == 0);
    rewriter.setInsertionPoint(&loopInfo.userCodeBlock->back());
  }

  if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
    exitWhileLoop(rewriter, loc, reduc);
  } else {
    exitForLoop(rewriter, loc, reduc);
  }

  assert(loopStack.size() == loopSeqStack.size());
  loopStack.pop_back();
}

std::pair<Operation *, Value> sparse_tensor::genCoIteration(
    OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
    MutableArrayRef<Value> reduc, Value uniIdx, bool userReducFirst) {
  // The set of induction variables carried by the while-loop: the cursors of
  // all co-iterated sparse iterators, the user reduction values, and an
  // optional universal index.
  SmallVector<Value> ivs;

  // `sparse_tensor.coiterate` puts user-provided reduction values at the front
  // of the block arguments, while the LoopEmitter puts them at the end.
  if (userReducFirst)
    ivs.append(reduc.begin(), reduc.end());

  // Construct the while-loop with a parameter for each iterator cursor.
  for (SparseIterator *it : spIters) {
    ValueRange itVals = it->getCursor();
    ivs.append(itVals.begin(), itVals.end());
  }

  if (!userReducFirst)
    ivs.append(reduc.begin(), reduc.end());

  // Update universal index.
  if (uniIdx)
    ivs.push_back(uniIdx);

  // Ensures all the parameters of the loop are up-to-date.
  assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
  TypeRange types = ValueRange(ivs).getTypes();
  auto whileOp = builder.create<scf::WhileOp>(loc, types, ivs);

  SmallVector<Location> locs(types.size(), loc);
  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);

  // Generates loop conditions.
  builder.setInsertionPointToStart(before);
  ValueRange bArgs = before->getArguments();
  Value whileCond = nullptr; // bool values for loop condition.
  for (SparseIterator *it : spIters) {
    auto [cond, remArgs] = it->genWhileCond(builder, loc, bArgs);
    whileCond = !whileCond ? cond : ANDI(whileCond, cond);
    bArgs = remArgs;
  }
  // The remaining block arguments are user-provided reduction values and an
  // optional universal index; make sure the sizes are consistent.
  assert(bArgs.size() == reduc.size() + (uniIdx ? 1 : 0));
  builder.create<scf::ConditionOp>(loc, whileCond, before->getArguments());

  // Generates loop body.
  builder.setInsertionPointToStart(after);
  ValueRange aArgs = after->getArguments();
  for (SparseIterator *it : spIters) {
    aArgs = it->linkNewScope(aArgs);
    // Dereference the iterator to cache the coordinate.
    it->deref(builder, loc);
  }

  // In-place update on the reduction variables.
  for (unsigned i = 0, e = reduc.size(); i < e; i++)
    reduc[i] = aArgs[i];

  // Finds the minimum coordinate.
  Value min;
  if (!uniIdx) {
    for (SparseIterator *it : spIters) {
      if (min) {
        Value cmp = CMPI(ult, it->getCrd(), min);
        min = SELECT(cmp, it->getCrd(), min);
      } else {
        min = it->getCrd();
      }
    }
  } else {
    // Otherwise, the universal index is the minimal coordinate.
    min = whileOp.getAfterArguments().back();
  }

  return {whileOp, min};
}
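
// The co-iteration op built above has the usual shape (illustrative sketch):
//
//   %res:N = scf.while (%cursor0 = ..., %red = ..., %univ = ...) {
//     %c0   = ...                      // per-iterator condition (genWhileCond)
//     %cond = arith.andi %c0, %c1
//     scf.condition(%cond) %cursor0, ..., %red, %univ
//   } do {
//   ^bb0(%a0: index, ...):             // iterators deref'ed, reduc updated
//     ...                              // caller emits the body and yields
//   }
//
// and the returned `min` is the smallest current coordinate among the sparse
// iterators, or the universal index when one is carried.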

#undef CMPI
#undef C_IDX
#undef YIELD
#undef ADDI
#undef ANDI
#undef SUBI
#undef MULI
#undef SELECT