MLIR: include/mlir/Dialect/SparseTensor/Utils/Merger.h Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13 #ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
14 #define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
15
20 #include "llvm/ADT/BitVector.h"
21
22 #include <optional>
23
24 namespace mlir {
25 namespace sparse_tensor {
26
27 namespace detail {
28
29
31 }
32
33
34
36
37
39
40
41
42
43
45
46
47
49
50
51
53
54
55
56
58
59
/// A pair of a level and its corresponding LevelType for a tensor.
60 using LvlLTPair = std::pair<Level, LevelType>;
61
62
63
65
66
68 enum class Kind;
69
70
74 };
75
76
77
78
79
80
81
82
83
84
85
87
88
90
91 union {
92
94
95
97
98
100 };
101
102
103
104
106
107
108
109
110
111
112
113
115
116
117
119 };
120
121
122
123
124
125
126
127
128
130
135
165 kCIm,
166 kCRe,
168 kBinaryBranch,
169 kUnary,
170 kSelect,
171
176 kDivC,
178 kDivU,
191 kShrU,
193 kBinary,
194 kReduce,
195 kDenseOp,
196 };
197
198
199
200
201
203
205
206
208
209
211
212
213
214
216
217
219 };
220
221
222
223
224
226 public:
227
228
229
230
231
232
233
234
/// Constructs a merger for the given number of tensors and loops.
/// NOTE(review): `maxLvlRank` is presumably the maximum level-rank over all
/// tensors, used for sizing the per-level tables -- confirm against the
/// out-of-line definition, which is not visible here.
235 Merger(unsigned numInputOutputTensors, unsigned numLoops,
236 unsigned maxLvlRank);
237
238
239
240
241
242
244 assert(isValidTensorId(t));
245 return t;
246 }
247
248
250 assert(isValidLoopId(i));
251 return i;
252 }
253
254
256 assert(isValidTensorId(t) && isValidLoopId(i));
257 return numTensors * i + t;
258 }
259
260
261
262
263
264
266
268
270
272
275
276
277
280
281
284
285
287
288
289
290
291
294
295
296
298
299
300
302
303
304
305
307
308
309
310
315
316
317
318
321
322
323
324
325
327
328
329
330
332
333
334
335
336
338
339
341
342
344
345
347
349
350
351
/// Gets the total number of tensors (including the output-tensor and
/// synthetic-tensor). Simply returns the stored `numTensors` member.
352 constexpr unsigned getNumTensors() const { return numTensors; }
353
354
/// Gets the total number of loops (native loops + filter loops). Simply
/// returns the stored `numLoops` member.
355 constexpr unsigned getNumLoops() const { return numLoops; }
356
357
360 }
361
362
364
365
366
368
369
371 const auto &expr = exp(e);
373 }
374
375
377
378
379
380
381
383
384
385
386
387
389
390
391
/// Returns true if any TensorLoopId in the bitvector corresponds to a sparse
/// level-type.
392 bool hasAnySparse(const BitVector &bits) const;
393
394
395
397
398
400 assert(isValidTensorId(t) && isValidLoopId(i));
401 return lvlTypes[t][i];
402 }
403
404
407 }
408
409
411 assert(isValidLevel(t, lvl));
412 return lvlToLoop[t][lvl];
413 }
414
415
417 assert(isValidTensorId(t) && isValidLoopId(i));
418 return loopToLvl[t][i];
419 }
422 }
423
424
425
427 assert(isValidLevel(t, lvl) && isValidLoopId(i) && isValidLT(lt));
428 lvlTypes[t][i] = lt;
429 loopToLvl[t][i] = lvl;
430 lvlToLoop[t][lvl] = i;
431
432 loopBounds[i] = std::make_pair(t, lvl);
433 }
434
437
438
439
440
441
444
445
447 }
450 const auto &point = lat(p);
451 const auto &bits = simple ? point.simple : point.bits;
452 for (const TensorLoopId b : bits.set_bits()) {
454 const auto optLvl = getLvl(b);
457
458 assert(!optLvl.has_value());
459
461 true);
462 } else {
463 callback(b, t, optLvl, lvlTp, false);
464 }
465 }
466 }
467
468
470
471
473 LevelType lt, unsigned coefficient) {
474 assert(isValidLoopId(i) && isValidLevel(t, lvl));
475 assert(!loopToUnresolvedLvls[i][t].has_value());
476 loopToUnresolvedLvls[i][t] = std::make_pair(lvl, lt);
477 levelToDependentLoop[t][lvl].emplace_back(i, coefficient);
478 }
479
480
482 assert(isValidTensorId(t) && isValidLoopId(i));
483 return loopToUnresolvedLvls[i][t].has_value();
484 }
485
486
487
489 assert(isValidLevel(t, lvl));
490 return levelToDependentLoop[t][lvl];
491 }
492
493
495 assert(isValidLoopId(i));
496 return loopBounds[i];
497 }
498
499
500
504 assert(isValidTensorId(t) && isValidLoopId(i));
505 return loopToUnresolvedLvls[i][t].has_value();
506 }
507
508
509
513 return lt.hasSparseSemantic();
514 }
515 return false;
516 }
517
520 return loopToUnresolvedLvls[loop(b)][tensor(b)]->first;
521 }
522
525 return loopToUnresolvedLvls[loop(b)][tensor(b)]->second;
526 }
527
528
529
530
531
532
533
534
535
536
537
538
539
540
542 assert(isValidExprId(e));
543 return tensorExps[e];
544 }
546 assert(isValidLatPointId(p));
547 return latPoints[p];
548 }
550 assert(isValidLatSetId(s));
551 return latSets[s];
552 }
553
554
556
557
558
560 assert((e).val && "Expression already has an associated value");
561 assert(v && "Trying to assign an undefined value");
562 tensorExps[e].val = v;
563 }
564
565
566
568 assert(exp(e).val && "Expression does not have an associated value");
569 tensorExps[e].val = Value();
570 }
571
572 #ifndef NDEBUG
573
/// Debugging helper; only declared when NDEBUG is not defined.
/// NOTE(review): presumably prints the contents of `bits` -- the output
/// format and destination are set by the out-of-line definition.
577 void dumpBits(const BitVector &bits) const;
578 #endif
579
580
581
582
584
585
586
588
589
592
593 private:
594
/// Returns true iff `t` is strictly less than `numTensors`, i.e. it can be
/// used as a tensor identifier with this merger.
595 constexpr bool isValidTensorId(TensorId t) const { return t < numTensors; }
596 constexpr bool isValidLoopId(LoopId i) const {
598 }
600 assert(levelToDependentLoop[t].size() == lvlToLoop[t].size());
601 return isValidTensorId(t) && lvl < lvlToLoop[t].size();
602 }
603 bool isValidExprId(ExprId e) const {
605 }
606 bool isValidLatPointId(LatPointId p) const {
608 }
609 bool isValidLatSetId(LatSetId s) const {
611 }
/// NOTE(review): presumably a conservative check for whether expression `e`
/// may evaluate to zero; the definition is not visible here -- confirm.
612 bool maybeZero(ExprId e) const;
613 bool isInvariant(ExprId e) const {
615 }
/// NOTE(review): presumably infers the result type of expression `e` given
/// the source value `src`; the definition is not visible here -- confirm.
616 Type inferType(ExprId e, Value src) const;
617
618
619
620
621
/// Private helper that builds a tensor expression for value `v` of the given
/// Linalg operation. Returns a pair whose first member is the optional
/// identifier of the built expression (the same `std::optional<ExprId>` shape
/// that `buildTensorExpFromLinalg` returns) and whose second member is a
/// boolean flag.
/// NOTE(review): the exact meaning of the boolean and the conditions under
/// which the optional is empty are set by the out-of-line definition --
/// confirm there.
622 std::pair<std::optional<ExprId>, bool> buildTensorExp(linalg::GenericOp op,
623 Value v);
624
625
/// Identifier of the synthetic tensor (used for all invariant tensor
/// expressions). NOTE(review): presumably the value returned by
/// getSynTensorID() -- confirm.
627 const TensorId syntheticTensor;
/// Total number of tensors (including the output-tensor and
/// synthetic-tensor); returned by getNumTensors().
628 const unsigned numTensors;
/// Total number of loops (native loops + filter loops); returned by
/// getNumLoops().
629 const unsigned numLoops;
/// Whether the output tensor is sparse. NOTE(review): presumably written by
/// setHasSparseOut() -- confirm.
630 bool hasSparseOut;
631
632
633
634
635
636
637
638
/// Level-type of each tensor on each loop, indexed as `lvlTypes[t][i]`
/// (tensor first, then loop). Read by getLvlType() and written by
/// setLevelAndType().
639 std::vector<std::vector<LevelType>> lvlTypes;
640
641
/// Level (if any) of each tensor on each loop, indexed as `loopToLvl[t][i]`.
/// Read by getLvl() and written by setLevelAndType().
642 std::vector<std::vector<std::optional<Level>>> loopToLvl;
643
644
/// Loop (if any) for each level of each tensor, indexed as
/// `lvlToLoop[t][lvl]`. Read by getLoopId() and written by setLevelAndType().
645 std::vector<std::vector<std::optional<LoopId>>> lvlToLoop;
646
647
648
649
650
651
/// For each loop and tensor, the not-yet-resolved (level, level-type) pair the
/// loop depends on, if any; indexed as `loopToUnresolvedLvls[i][t]` (loop
/// first, then tensor). Written by setLoopDependentTensorLevel().
652 std::vector<std::vector<std::optional<LvlLTPair>>> loopToUnresolvedLvls;
653
654
655
656
657
/// For each level of each tensor, the list of (loop, coefficient) pairs
/// appended by setLoopDependentTensorLevel(); indexed as
/// `levelToDependentLoop[t][lvl]` and returned by getDependentLoops().
658 std::vector<std::vector<std::vector<LoopCoeffPair>>> levelToDependentLoop;
659
660
/// The defining (tensor, level) pair for each loop, indexed by loop id.
/// Written by setLevelAndType() and returned by getLoopDefiningLvl().
661 std::vector<std::pair<TensorId, Level>> loopBounds;
662
666 };
667
668 }
669 }
670
671 #endif
union mlir::linalg::@1203::ArityGroupAndKind::Kind kind
Attributes are known-constant values of operations.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation is the basic unit of execution within MLIR.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A class to handle all iteration lattice operations.
void setHasSparseOut(bool s)
Sets whether the output tensor is sparse or not.
constexpr unsigned getNumLoops() const
Gets the total number of loops (native loops + filter loops).
LatPointId conjLat(ExprId e, LatPointId p0, LatPointId p1, Operation *op=nullptr)
Computes a single conjunction of two lattice points by taking the "union" of LoopId (effectively cons...
LatSetId disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)
Disjunctive merge of two lattice sets: (s0 /\_op s1, s0, s1).
Level getLoopDependentLevel(TensorLoopId b) const
std::optional< Level > getLvl(TensorId t, LoopId i) const
Gets the level number of the tth tensor on the ith loop.
constexpr bool isOutTensor(TensorLoopId b, LoopId i) const
Returns true if b is the ith loop of the output tensor.
bool isSingleCondition(TensorId t, ExprId e) const
Returns true if given tensor iterates only in the given tensor expression.
bool hasSparseIdxReduction(const BitVector &bits) const
Returns true if bits contains a dependent index reduction condition on sparse levels.
bool expContainsTensor(ExprId e, TensorId t) const
Returns true if the expression contains the tensor as an operand.
LatSetId mapBinWithSynZeroSet(ExprId e, LatSetId s, bool lhsZero)
Maps the binary operator to the same operation but with one of its operands set to zero,...
bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const
Checks whether the TensorLoopId represents a sparse tensor level that contains a non-trivial index expression.
void dumpBits(const BitVector &bits) const
bool hasExprValue(ExprId e) const
Checks whether the given expression has an associated value.
void foreachTensorLoopId(LatPointId p, bool simple, ForeachTensorLoopIdCallback callback) const
LatSetId addSet()
Constructs a new (initially empty) set, and returns its identifier.
std::optional< LoopId > getLoopId(TensorId t, Level lvl) const
Gets the loop identifier for the lvlth level of the tth tensor.
std::pair< TensorId, Level > getLoopDefiningLvl(LoopId i) const
Returns the defining [tid, lvl] for the loop.
BitVector simplifyCond(LatSetId s, LatPointId p)
Simplifies the conditions in a conjunction of a given lattice point within the given set using just t...
bool hasNegateOnOut(ExprId e) const
Returns true if the expression contains a negation on output tensor.
constexpr unsigned getNumTensors() const
Gets the total number of tensors (including the output-tensor and synthetic-tensor).
bool isLvlWithNonTrivialIdxExp(TensorLoopId b) const
Checks whether the TensorLoopId represents a tensor level that contains a non-trivial index expression.
LatSetId disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1)
Disjunctive merge of two lattice sets that also sets one of the operands to zero: (s0 /\_op s1 (e0 op e1...
void dumpSet(LatSetId s) const
void dumpLat(LatPointId p) const
LatSetId combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig, bool includeLeft, TensorExp::Kind ltrans, Operation *opleft, bool includeRight, TensorExp::Kind rtrans, Operation *opright)
Disjunctive merge of two lattice sets with custom handling of the overlap, left, and right regions.
ExprId addTensorExp(TensorId t)
Constructs a new tensor expression, and returns its identifier.
LatSetId buildLattices(ExprId e, LoopId i)
Builds the iteration lattices in a bottom-up traversal given the remaining tensor (sub)expression and...
LatSetId conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)
Conjunctive merge of two lattice sets: (s0 /\_op s1).
ExprId addExp(TensorExp::Kind k, ExprId e0, ExprId e1=detail::kInvalidId, Operation *op=nullptr, Attribute attr=nullptr)
Constructs a new unary or binary expression, and returns its identifier.
ExprId addSynZeroExp()
Constructs a new synthetic zero expression.
constexpr LoopId makeLoopId(unsigned i) const
Safely converts the argument to a loop identifier.
std::optional< ExprId > buildTensorExpFromLinalg(linalg::GenericOp op)
Builds a tensor expression from the given Linalg operation.
void setLevelAndType(TensorId t, LoopId i, Level lvl, LevelType lt)
Sets the level number and level-type of the tth tensor on ith loop.
LatSetId mapSet(TensorExp::Kind kind, LatSetId s, Value v=Value(), Operation *op=nullptr, Attribute attr=nullptr)
Maps the unary operator over the lattice set of the operand, i.e.
void foreachTensorLoopId(LatPointId p, ForeachTensorLoopIdCallback callback) const
Iterates over a set of TensorLoopIds, invoking the callback for each TensorLoopId and passing it the ...
std::optional< Level > getLvl(TensorLoopId b) const
ArrayRef< LatPointId > set(LatSetId s) const
LatSetId optimizeSet(LatSetId s)
Optimizes the iteration lattice points in the given set.
constexpr TensorId tensor(TensorLoopId b) const
Gets the tensor-identifier of the TensorLoopId.
void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl, LevelType lt, unsigned coefficient)
Establishes the two-way map that i <-> <t, lvl, lt>.
void dumpExp(ExprId e) const
Print methods (for debugging).
LevelType getLvlType(TensorLoopId b) const
Gets the level-type of the TensorLoopId.
Merger(unsigned numInputOutputTensors, unsigned numLoops, unsigned maxLvlRank)
Constructs a merger for the given number of tensors and loops.
bool hasAnySparse(const BitVector &bits) const
Returns true if any TensorLoopId in the bitvector corresponds to sparse level-type.
void clearExprValue(ExprId e)
Clears the value associated with the expression.
std::vector< LoopCoeffPair > & getDependentLoops(TensorId t, Level lvl)
Returns the list of loop indices which appear in the non-trivial index expression on t_l,...
LatPointId addLat(TensorId t, LoopId i, ExprId e)
Constructs a new iteration lattice point, and returns its identifier.
ExprId addLoopVarExp(LoopId i)
Constructs a new loop-variable expression, and returns its identifier.
constexpr TensorId getSynTensorID() const
Gets the synthetic tensor's identifier (used for all invariant tensor expressions).
bool latGT(LatPointId p0, LatPointId p1) const
Returns true if p0 > p1.
const TensorExp & exp(ExprId e) const
Convenience getters to immediately access the stored nodes.
constexpr LoopId loop(TensorLoopId b) const
Gets the loop-identifier of the TensorLoopId.
const LatPoint & lat(LatPointId p) const
constexpr TensorId getOutTensorID() const
Gets the output tensor's identifier.
bool onlyDenseDiff(LatPointId p0, LatPointId p1) const
Returns true if p0 and p1 only differ in dense.
ExprId addInvariantExp(Value v)
Constructs a new invariant expression, and returns its identifier.
constexpr TensorId makeTensorId(unsigned t) const
Safely converts the argument to a tensor identifier.
LevelType getLoopDependentLevelType(TensorLoopId b) const
LevelType getLvlType(TensorId t, LoopId i) const
Gets the level-type of the tth tensor on ith loop.
Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, Value v1) const
Rebuilds SSA format from a tensor expression.
constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const
Safely converts the arguments to a pair of (tensor,loop) identifiers.
bool expIsTensor(ExprId e, TensorId t) const
Returns true if the expression is (kTensor t).
void setExprValue(ExprId e, Value v)
Sets the expression to have the associated value.
bool hasDependentLvl(LoopId i, TensorId t)
Whether the loop has dependent slice.
@ Type
An inlay hint for a type annotation.
static constexpr unsigned kInvalidId
A constant serving as the canonically invalid identifier, regardless of the identifier type.
unsigned LatSetId
LatSet identifiers.
std::pair< Level, LevelType > LvlLTPair
A pair of level and its corresponding LevelType of a tensor.
unsigned TensorLoopId
A compressed representation of std::pair<TensorId, LoopId>.
uint64_t Level
The type of level identifiers and level-ranks.
unsigned LoopId
Loop identifiers.
bool isValidLT(LevelType lt)
unsigned ExprId
TensorExp identifiers.
unsigned LatPointId
LatPoint identifiers.
std::pair< LoopId, unsigned > LoopCoeffPair
A pair of loop id and its coefficients.
unsigned TensorId
Tensor identifiers, chosen to be the BlockArgument::getArgNumber of the value passed to Merger::build...
Include the generated interface declarations.
LatPoint(const BitVector &bits, ExprId e)
Construct a lattice point from the given set of TensorLoopIds.
ExprId exp
Identifier of the tensor expression.
BitVector bits
Conjunction of all TensorLoopIds involved in the tensor expression.
BitVector simple
Simplified conjunction of TensorLoopId as bitvector.
LatPoint(unsigned size, ExprId e)
Construct a lattice point with the empty set of TensorLoopIds.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Child subexpressions for non-leaf expressions.
Tensor expression. Represents an MLIR expression in tensor index notation.
LoopId loop
kLoopVar expressions simply have a loop identifier.
Value val
Direct link to IR for an invariant or the destination value (to infer destination type) of a cast ope...
Kind
Tensor expression kind.
Children children
All other expressions hold the ExprIds of their children.
Attribute attr
An optional attribute that is required to determine the semantics of the operations.
TensorId tensor
kTensor expressions simply have a tensor identifier.
Kind kind
Tensor expression kind.
Operation * op
Code blocks used by semirings.
TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *op, Attribute a)
The x parameter has different types depending on the value of the k parameter.