MLIR: lib/Dialect/SparseTensor/IR/Detail/Var.h Source File (original) (raw)
1
2
3
4
5
6
7
8
9 #ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H
10 #define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H
11
13
15 #include "llvm/ADT/EnumeratedArray.h"
16 #include "llvm/ADT/STLForwardCompat.h"
17 #include "llvm/ADT/SmallBitVector.h"
18 #include "llvm/ADT/StringMap.h"
19
20 namespace mlir {
21 namespace sparse_tensor {
22 namespace ir_detail {
23
24
25
26
27
28
29
30
31
33
35 const auto vk_ = llvm::to_underlying(vk);
36 return 0 <= vk_ && vk_ <= 2;
37 }
38
39
41
42
43
44
45 const auto vk_ = static_cast<int_fast8_t>(llvm::to_underlying(vk));
46 return static_cast<char>(100 + vk_ * (26 - vk_ * 11));
47 }
51
52
53
54 template
55 using VarKindArray = llvm::EnumeratedArray<T, VarKind, VarKind::Level>;
56
57
58
59
60
62 public:
63
65
66 private:
67
68 using Storage = unsigned;
69
70
71
72
73
74
75 static constexpr Num kMaxNum =
77
78 public:
79
80
81
82
83 [[nodiscard]] static constexpr bool isWF_Num(Num n) { return n <= kMaxNum; }
84
85 protected:
86
87
88
89
91 Storage data;
92
93 public:
95 : data((static_cast(n) << 2) |
96 static_cast(llvm::to_underlying(vk))) {
97 assert(isWF(vk) && "unknown VarKind");
98 assert(isWF_Num(n) && "Var::Num is too large");
99 }
100 constexpr bool operator==(Impl other) const { return data == other.data; }
101 constexpr bool operator!=(Impl other) const { return !(*this == other); }
103 constexpr Num getNum() const { return static_cast<Num>(data >> 2); }
104 };
105 static_assert(IsZeroCostAbstraction);
106
107 private:
109
110 protected:
111
113
114 public:
119 }
120
122 constexpr bool operator!=(Var other) const { return !(*this == other); }
123
126
127 template
128 constexpr bool isa() const;
129 template
130 constexpr U cast() const;
131 template
132 constexpr std::optional dyn_cast() const;
133
134 std::string str() const;
135 void print(llvm::raw_ostream &os) const;
137 void dump() const;
138 };
139 static_assert(IsZeroCostAbstraction);
140
142 using Var::Var;
143 public:
147 }
150 };
151 static_assert(IsZeroCostAbstraction);
152
154 using Var::Var;
155 public:
159 }
162 };
163 static_assert(IsZeroCostAbstraction);
164
166 using Var::Var;
167 public:
171 }
174 };
175 static_assert(IsZeroCostAbstraction);
176
177 template
179 if constexpr (std::is_same_v<U, SymVar>)
181 if constexpr (std::is_same_v<U, DimVar>)
183 if constexpr (std::is_same_v<U, LvlVar>)
185 }
186
187 template
189 assert(isa());
190
191 return U(impl);
192 }
193
194 template
196
197 return isa() ? std::make_optional(U(impl)) : std::nullopt;
198 }
199
200
201
203
204
206
207 unsigned impl[3];
208
209 static constexpr unsigned to_index(VarKind vk) {
210 assert(isWF(vk) && "unknown VarKind");
211 return static_cast<unsigned>(llvm::to_underlying(vk));
212 }
213
214 public:
215 constexpr Ranks(unsigned symRank, unsigned dimRank, unsigned lvlRank)
220 }
224
226 bool operator!=(Ranks const &other) const { return !(*this == other); }
227
232
233 [[nodiscard]] constexpr bool isValid(Var var) const {
235 }
237 };
238 static_assert(IsZeroCostAbstraction);
239
240
241
244
245 public:
247
254 }
255
256
257
259
260
261
265 };
266
267
268
269
270
271
272
273
274
276 public:
277
278
/// Newtype for unique identifiers of `VarInfo` records, to ensure they
/// aren't confused with `Var::Num`. It has no enumerators; values are
/// obtained by conversion from the underlying `unsigned`.
enum class ID : unsigned {};
280
281 private:
282 StringRef name;
283 llvm::SMLoc loc;
284 ID id;
285 std::optionalVar::Num num;
286 VarKind kind;
287
288 public:
290 std::optionalVar::Num n = {})
291 : name(name), loc(loc), id(id), num(n), kind(vk) {
292 assert(!name.empty() && "null StringRef");
293 assert(loc.isValid() && "null SMLoc");
294 assert(isWF(vk) && "unknown VarKind");
295 assert((!n || Var::isWF_Num(*n)) && "Var::Num is too large");
296 }
297
298 constexpr StringRef getName() const { return name; }
299 constexpr llvm::SMLoc getLoc() const { return loc; }
302 }
303 constexpr ID getID() const { return id; }
305 constexpr std::optionalVar::Num getNum() const { return num; }
306 constexpr bool hasNum() const { return num.has_value(); }
311 }
312 };
313
314
316
317
319
321
323
324 llvm::StringMapVarInfo::ID ids;
325
327
328 public:
330
331
332
333
334
335
336
337
339
340 return vars[llvm::to_underlying(id)];
341 }
342 VarInfo const *access(std::optionalVarInfo::ID oid) const {
343 return oid ? &access(*oid) : nullptr;
344 }
345
346 private:
348 return const_cast<VarInfo &>(std::as_const(*this).access(id));
349 }
350 VarInfo *access(std::optionalVarInfo::ID oid) {
351 return const_cast<VarInfo *>(std::as_const(*this).access(oid));
352 }
353
354 public:
355
356 std::optionalVarInfo::ID lookup(StringRef name) const;
357
358
359
360
361
362
363
364 std::optional<std::pair<VarInfo::ID, bool>>
365 create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage = false);
366
367
368
369
370
371
372
373 std::optional<std::pair<VarInfo::ID, bool>>
376
377
379
380
381
382
384
386
387
388
389
391
392
393
395 };
396
397
398
399 }
400 }
401 }
402
403 #endif
union mlir::linalg::@1203::ArityGroupAndKind::Kind kind
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
A dimensional identifier appearing in an affine expression.
A symbolic identifier appearing in an affine expression.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
This base class exposes generic asm printer hooks, usable across the various derived printers.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
static constexpr bool classof(Var const *var)
static constexpr VarKind Kind
constexpr DimVar(Num dim)
DimVar(AffineDimExpr dimExpr)
constexpr LvlVar(Num lvl)
static constexpr VarKind Kind
static constexpr bool classof(Var const *var)
LvlVar(AffineDimExpr lvlExpr)
constexpr unsigned getRank(VarKind vk) const
constexpr unsigned getLvlRank() const
bool operator==(Ranks const &other) const
Ranks(VarKindArray< unsigned > const &ranks)
bool operator!=(Ranks const &other) const
constexpr unsigned getDimRank() const
constexpr unsigned getSymRank() const
constexpr Ranks(unsigned symRank, unsigned dimRank, unsigned lvlRank)
constexpr bool isValid(Var var) const
constexpr SymVar(Num sym)
SymVar(AffineSymbolExpr symExpr)
static constexpr VarKind Kind
static constexpr bool classof(Var const *var)
Var bindUnusedVar(VarKind vk)
Creates a new variable of the given kind and immediately binds it.
VarInfo const & access(VarInfo::ID id) const
Gets the underlying storage for the VarInfo identified by the VarInfo::ID.
Ranks getRanks() const
Returns the current ranks of bound variables.
InFlightDiagnostic emitErrorIfAnyUnbound(AsmParser &parser) const
std::optional< std::pair< VarInfo::ID, bool > > lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc, VarKind vk)
Looks up or creates a variable according to the given Policy.
Var getVar(VarInfo::ID id) const
Gets the Var identified by the VarInfo::ID, raising an assertion failure if the variable is not bound...
std::optional< VarInfo::ID > lookup(StringRef name) const
Looks up the variable with the given name.
VarInfo const * access(std::optional< VarInfo::ID > oid) const
std::optional< std::pair< VarInfo::ID, bool > > create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage=false)
Creates a new currently-unbound variable.
Var bindVar(VarInfo::ID id)
Binds the given variable to the next free Var::Num for its VarKind.
A record of metadata for/about a variable, used by VarEnv.
constexpr VarKind getKind() const
constexpr ID getID() const
constexpr llvm::SMLoc getLoc() const
constexpr Var getVar() const
constexpr StringRef getName() const
ID
Newtype for unique identifiers of VarInfo records, to ensure they aren't confused with Var::Num.
constexpr bool hasNum() const
constexpr std::optional< Var::Num > getNum() const
Location getLocation(AsmParser &parser) const
constexpr VarInfo(ID id, StringRef name, llvm::SMLoc loc, VarKind vk, std::optional< Var::Num > n={})
Efficient representation of a set of Var.
bool contains(Var var) const
For the contains method: if variables occurring in the method parameter are OOB for the VarSet,...
unsigned getRank(VarKind vk) const
unsigned getDimRank() const
unsigned getLvlRank() const
unsigned getSymRank() const
VarSet(Ranks const &ranks)
void add(Var var)
For the add methods: OOB parameters cause undefined behavior.
The underlying implementation of Var.
constexpr Num getNum() const
constexpr bool operator!=(Impl other) const
constexpr VarKind getKind() const
constexpr Impl(VarKind vk, Num n)
constexpr bool operator==(Impl other) const
A concrete variable, to be used in our variant of AffineExpr.
constexpr Num getNum() const
constexpr Var(Impl impl)
Protected ctor for the RTTI methods to use.
constexpr std::optional< U > dyn_cast() const
Var(VarKind vk, AffineDimExpr var)
constexpr VarKind getKind() const
void print(llvm::raw_ostream &os) const
constexpr bool operator!=(Var other) const
constexpr Var(VarKind vk, Num n)
static constexpr bool isWF_Num(Num n)
Checks whether the number would be accepted by Var(VarKind,Var::Num).
constexpr bool isa() const
Var(AffineSymbolExpr sym)
constexpr bool operator==(Var other) const
unsigned Num
Typedef for the type of variable numbers.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
constexpr bool isWF(VarKind vk)
VarKind
The three kinds of variables that Var can be.
llvm::EnumeratedArray< T, VarKind, VarKind::Level > VarKindArray
The type of arrays indexed by VarKind.
constexpr char toChar(VarKind vk)
Gets the ASCII character used as the prefix when printing Var.
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
uint64_t Level
The type of level identifiers and level-ranks.
Include the generated interface declarations.