MLIR: lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h Source File (original) (raw)
1
2
3
4
5
6
7
8
9 #ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
10 #define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
11
13
15 #include "llvm/ADT/STLForwardCompat.h"
16
17 namespace mlir {
18 namespace sparse_tensor {
19 namespace ir_detail {
20
21
23
25 using VK = std::underlying_type_t;
26 return VarKind{2 * static_cast<VK>(!llvm::to_underlying(ek))};
27 }
30
31
33 private:
36
37 public:
39
40
41
42
44 return kind == other.kind && expr == other.expr;
45 }
47 return !(*this == other);
48 }
49 explicit operator bool() const { return static_cast<bool>(expr); }
50
51
52
53
54 template
55 constexpr bool isa() const;
56 template
57 constexpr U cast() const;
58 template
60
61
62
63
67 }
70 assert(expr);
72 }
74 return expr ? expr.getContext() : nullptr;
75 }
76
77
78
79
84 std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr> unpackBinop() const;
85
86
87
89
90 protected:
91
93 };
94 static_assert(IsZeroCostAbstraction);
95
99
100 public:
104 }
106
110 return var ? std::make_optional(var->cast<LvlVar>()) : std::nullopt;
111 }
112 };
113 static_assert(IsZeroCostAbstraction);
114
118
119 public:
123 }
125
129 return var ? std::make_optional(var->cast<DimVar>()) : std::nullopt;
130 }
131 };
132 static_assert(IsZeroCostAbstraction);
133
134 template
136 if constexpr (std::is_same_v<U, DimExpr>)
138 if constexpr (std::is_same_v<U, LvlExpr>)
140 }
141
142 template
144 assert(isa());
145 return U(*this);
146 }
147
148 template
150 return isa() ? U(*this) : U();
151 }
152
153
154
156
158
159
160
162
163
164
165 bool elideExpr = false;
166
167 SparseTensorDimSliceAttr slice;
168
169 public:
171
173
175 bool hasExpr() const { return static_cast<bool>(expr); }
179 expr = newExpr;
180 }
181 constexpr bool canElideExpr() const { return elideExpr; }
183 constexpr SparseTensorDimSliceAttr getSlice() const { return slice; }
184
185
186
187
188
189 [[nodiscard]] bool isValid(Ranks const &ranks) const;
190 };
191
192 static_assert(IsZeroCostAbstraction);
193
194
195
197
199
200
201 bool elideVar = false;
202
204
206
207 public:
209
212 assert(ctx);
213 return ctx;
214 }
215
217 constexpr bool canElideVar() const { return elideVar; }
221
222
223
224 [[nodiscard]] bool isValid(Ranks const &ranks) const;
225 };
226
227 static_assert(IsZeroCostAbstraction);
228
229
231 public:
234
236 unsigned getDimRank() const { return dimSpecs.size(); }
237 unsigned getLvlRank() const { return lvlSpecs.size(); }
240
245 }
246
250
253
254 private:
255
256
257 [[nodiscard]] bool isWF() const;
258
259
260
261
263 assert(expr && getRanks().isValid(expr));
264 dimSpecs[dim].setExpr(expr);
265 }
266
267
268 unsigned symRank;
271 bool mustPrintLvlVars;
272 };
273
274
275
276 }
277 }
278 }
279
280 #endif
union mlir::linalg::@1203::ArityGroupAndKind::Kind kind
Base type for affine expressions.
AffineExprKind getKind() const
Return the classification for this type.
MLIRContext * getContext() const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
MLIRContext is the top-level object for a collection of MLIR operations.
LvlVar castLvlVar() const
static constexpr ExprKind Kind
static constexpr bool classof(DimLvlExpr const *expr)
std::optional< LvlVar > dyn_castLvlVar() const
constexpr DimExpr(AffineExpr expr)
SymVar castSymVar() const
constexpr U dyn_cast() const
bool isValid(Ranks const &ranks) const
Checks whether the variables bound/used by this spec are valid with respect to the given ranks.
constexpr bool operator!=(DimLvlExpr other) const
std::optional< SymVar > dyn_castSymVar() const
AffineExprKind getAffineKind() const
Var castDimLvlVar() const
constexpr AffineExpr getAffineExpr() const
constexpr ExprKind getExprKind() const
constexpr bool isa() const
constexpr bool operator==(DimLvlExpr other) const
std::tuple< DimLvlExpr, AffineExprKind, DimLvlExpr > unpackBinop() const
constexpr DimLvlExpr(ExprKind ek, AffineExpr expr)
constexpr VarKind getAllowedVarKind() const
std::optional< Var > dyn_castDimLvlVar() const
MLIRContext * tryGetContext() const
DimLvlMap(unsigned symRank, ArrayRef< DimSpec > dimSpecs, ArrayRef< LvlSpec > lvlSpecs)
SparseTensorDimSliceAttr getDimSlice(Dimension dim) const
AffineMap getDimToLvlMap(MLIRContext *context) const
unsigned getSymRank() const
unsigned getDimRank() const
ArrayRef< LvlSpec > getLvls() const
AffineMap getLvlToDimMap(MLIRContext *context) const
unsigned getRank(VarKind vk) const
const DimSpec & getDim(Dimension dim) const
ArrayRef< DimSpec > getDims() const
unsigned getLvlRank() const
LevelType getLvlType(Level lvl) const
const LvlSpec & getLvl(Level lvl) const
The full dimVar = dimExpr : dimSlice specification for a given dimension.
void setElideExpr(bool b)
constexpr DimExpr getExpr() const
bool isValid(Ranks const &ranks) const
Checks whether the variables bound/used by this spec are valid with respect to the given ranks.
void setExpr(DimExpr newExpr)
constexpr bool canElideExpr() const
MLIRContext * tryGetContext() const
constexpr SparseTensorDimSliceAttr getSlice() const
DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice)
constexpr DimVar getBoundVar() const
constexpr LvlExpr(AffineExpr expr)
static constexpr bool classof(DimLvlExpr const *expr)
static constexpr ExprKind Kind
std::optional< DimVar > dyn_castDimVar() const
DimVar castDimVar() const
The full lvlVar = lvlExpr : lvlType specification for a given level.
bool isValid(Ranks const &ranks) const
Checks whether the variables bound/used by this spec are valid with respect to the given ranks.
LvlSpec(LvlVar var, LvlExpr expr, LevelType type)
constexpr LvlExpr getExpr() const
constexpr bool canElideVar() const
MLIRContext * getContext() const
constexpr LvlVar getBoundVar() const
constexpr LevelType getType() const
constexpr unsigned getRank(VarKind vk) const
A concrete variable, to be used in our variant of AffineExpr.
VarKind
The three kinds of variables that Var can be.
constexpr VarKind getVarKindAllowedInExpr(ExprKind ek)
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
uint64_t Level
The type of level identifiers and level-ranks.
Include the generated interface declarations.
This enum defines all the sparse representations supportable by the SparseTensor dialect.