MLIR: lib/Dialect/SparseTensor/IR/Detail/Var.h Source File (original) (raw)

1

2

3

4

5

6

7

8

9 #ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H

10 #define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H

11

13

15 #include "llvm/ADT/EnumeratedArray.h"

16 #include "llvm/ADT/STLForwardCompat.h"

17 #include "llvm/ADT/SmallBitVector.h"

18 #include "llvm/ADT/StringMap.h"

19

20 namespace mlir {

21 namespace sparse_tensor {

22 namespace ir_detail {

23

24

25

26

27

28

29

30

31

33

35 const auto vk_ = llvm::to_underlying(vk);

36 return 0 <= vk_ && vk_ <= 2;

37 }

38

39

41

42

43

44

45 const auto vk_ = static_cast<int_fast8_t>(llvm::to_underlying(vk));

46 return static_cast<char>(100 + vk_ * (26 - vk_ * 11));

47 }

51

52

53

54 template

55 using VarKindArray = llvm::EnumeratedArray<T, VarKind, VarKind::Level>;

56

57

58

59

60

62 public:

63

64 using Num = unsigned;

65

66 private:

67

68 using Storage = unsigned;

69

70

71

72

73

74

75 static constexpr Num kMaxNum =

77

78 public:

79

80

81

82

83 [[nodiscard]] static constexpr bool isWF_Num(Num n) { return n <= kMaxNum; }

84

85 protected:

86

87

88

89

91 Storage data;

92

93 public:

95 : data((static_cast(n) << 2) |

96 static_cast(llvm::to_underlying(vk))) {

97 assert(isWF(vk) && "unknown VarKind");

98 assert(isWF_Num(n) && "Var::Num is too large");

99 }

100 constexpr bool operator==(Impl other) const { return data == other.data; }

101 constexpr bool operator!=(Impl other) const { return !(*this == other); }

103 constexpr Num getNum() const { return static_cast<Num>(data >> 2); }

104 };

105 static_assert(IsZeroCostAbstraction);

106

107 private:

109

110 protected:

111

113

114 public:

119 }

120

122 constexpr bool operator!=(Var other) const { return !(*this == other); }

123

126

127 template

128 constexpr bool isa() const;

129 template

130 constexpr U cast() const;

131 template

132 constexpr std::optional dyn_cast() const;

133

134 std::string str() const;

135 void print(llvm::raw_ostream &os) const;

137 void dump() const;

138 };

139 static_assert(IsZeroCostAbstraction);

140

142 using Var::Var;

143 public:

147 }

150 };

151 static_assert(IsZeroCostAbstraction);

152

154 using Var::Var;

155 public:

159 }

162 };

163 static_assert(IsZeroCostAbstraction);

164

166 using Var::Var;

167 public:

171 }

174 };

175 static_assert(IsZeroCostAbstraction);

176

177 template

179 if constexpr (std::is_same_v<U, SymVar>)

181 if constexpr (std::is_same_v<U, DimVar>)

183 if constexpr (std::is_same_v<U, LvlVar>)

185 }

186

187 template

189 assert(isa());

190

191 return U(impl);

192 }

193

194 template

196

197 return isa() ? std::make_optional(U(impl)) : std::nullopt;

198 }

199

200

201

203

204

206

207 unsigned impl[3];

208

209 static constexpr unsigned to_index(VarKind vk) {

210 assert(isWF(vk) && "unknown VarKind");

211 return static_cast<unsigned>(llvm::to_underlying(vk));

212 }

213

214 public:

215 constexpr Ranks(unsigned symRank, unsigned dimRank, unsigned lvlRank)

220 }

224

226 bool operator!=(Ranks const &other) const { return !(*this == other); }

227

232

233 [[nodiscard]] constexpr bool isValid(Var var) const {

235 }

237 };

238 static_assert(IsZeroCostAbstraction);

239

240

241

244

245 public:

247

254 }

255

256

257

259

260

261

265 };

266

267

268

269

270

271

272

273

274

276 public:

277

278

279 enum class ID : unsigned {};

280

281 private:

282 StringRef name;

283 llvm::SMLoc loc;

284 ID id;

285 std::optionalVar::Num num;

286 VarKind kind;

287

288 public:

290 std::optionalVar::Num n = {})

291 : name(name), loc(loc), id(id), num(n), kind(vk) {

292 assert(!name.empty() && "null StringRef");

293 assert(loc.isValid() && "null SMLoc");

294 assert(isWF(vk) && "unknown VarKind");

295 assert((!n || Var::isWF_Num(*n)) && "Var::Num is too large");

296 }

297

298 constexpr StringRef getName() const { return name; }

299 constexpr llvm::SMLoc getLoc() const { return loc; }

302 }

303 constexpr ID getID() const { return id; }

305 constexpr std::optionalVar::Num getNum() const { return num; }

306 constexpr bool hasNum() const { return num.has_value(); }

311 }

312 };

313

314

316

317

319

321

323

324 llvm::StringMapVarInfo::ID ids;

325

327

328 public:

330

331

332

333

334

335

336

337

339

340 return vars[llvm::to_underlying(id)];

341 }

342 VarInfo const *access(std::optionalVarInfo::ID oid) const {

343 return oid ? &access(*oid) : nullptr;

344 }

345

346 private:

348 return const_cast<VarInfo &>(std::as_const(*this).access(id));

349 }

350 VarInfo *access(std::optionalVarInfo::ID oid) {

351 return const_cast<VarInfo *>(std::as_const(*this).access(oid));

352 }

353

354 public:

355

356 std::optionalVarInfo::ID lookup(StringRef name) const;

357

358

359

360

361

362

363

364 std::optional<std::pair<VarInfo::ID, bool>>

365 create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage = false);

366

367

368

369

370

371

372

373 std::optional<std::pair<VarInfo::ID, bool>>

376

377

379

380

381

382

384

386

387

388

389

391

392

393

395 };

396

397

398

399 }

400 }

401 }

402

403 #endif

union mlir::linalg::@1203::ArityGroupAndKind::Kind kind

static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)

A dimensional identifier appearing in an affine expression.

A symbolic identifier appearing in an affine expression.

This base class exposes generic asm parser hooks, usable across the various derived parsers.

virtual Location getEncodedSourceLoc(SMLoc loc)=0

Re-encode the given source location as an MLIR location and return it.

This base class exposes generic asm printer hooks, usable across the various derived printers.

This class represents a diagnostic that is inflight and set to be reported.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.

static constexpr bool classof(Var const *var)

static constexpr VarKind Kind

constexpr DimVar(Num dim)

DimVar(AffineDimExpr dimExpr)

constexpr LvlVar(Num lvl)

static constexpr VarKind Kind

static constexpr bool classof(Var const *var)

LvlVar(AffineDimExpr lvlExpr)

constexpr unsigned getRank(VarKind vk) const

constexpr unsigned getLvlRank() const

bool operator==(Ranks const &other) const

Ranks(VarKindArray< unsigned > const &ranks)

bool operator!=(Ranks const &other) const

constexpr unsigned getDimRank() const

constexpr unsigned getSymRank() const

constexpr Ranks(unsigned symRank, unsigned dimRank, unsigned lvlRank)

constexpr bool isValid(Var var) const

constexpr SymVar(Num sym)

SymVar(AffineSymbolExpr symExpr)

static constexpr VarKind Kind

static constexpr bool classof(Var const *var)

Var bindUnusedVar(VarKind vk)

Creates a new variable of the given kind and immediately binds it.

VarInfo const & access(VarInfo::ID id) const

Gets the underlying storage for the VarInfo identified by the VarInfo::ID.

Ranks getRanks() const

Returns the current ranks of bound variables.

InFlightDiagnostic emitErrorIfAnyUnbound(AsmParser &parser) const

std::optional< std::pair< VarInfo::ID, bool > > lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc, VarKind vk)

Looks up or creates a variable according to the given Policy.

Var getVar(VarInfo::ID id) const

Gets the Var identified by the VarInfo::ID, raising an assertion failure if the variable is not bound.

std::optional< VarInfo::ID > lookup(StringRef name) const

Looks up the variable with the given name.

VarInfo const * access(std::optional< VarInfo::ID > oid) const

std::optional< std::pair< VarInfo::ID, bool > > create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage=false)

Creates a new currently-unbound variable.

Var bindVar(VarInfo::ID id)

Binds the given variable to the next free Var::Num for its VarKind.

A record of metadata for/about a variable, used by VarEnv.

constexpr VarKind getKind() const

constexpr ID getID() const

constexpr llvm::SMLoc getLoc() const

constexpr Var getVar() const

constexpr StringRef getName() const

ID

Newtype for unique identifiers of VarInfo records, to ensure they aren't confused with Var::Num.

constexpr bool hasNum() const

constexpr std::optional< Var::Num > getNum() const

Location getLocation(AsmParser &parser) const

constexpr VarInfo(ID id, StringRef name, llvm::SMLoc loc, VarKind vk, std::optional< Var::Num > n={})

Efficient representation of a set of Var.

bool contains(Var var) const

For the contains method: if variables occurring in the method parameter are OOB for the VarSet,...

unsigned getRank(VarKind vk) const

unsigned getDimRank() const

unsigned getLvlRank() const

unsigned getSymRank() const

VarSet(Ranks const &ranks)

void add(Var var)

For the add methods: OOB parameters cause undefined behavior.

The underlying implementation of Var.

constexpr Num getNum() const

constexpr bool operator!=(Impl other) const

constexpr VarKind getKind() const

constexpr Impl(VarKind vk, Num n)

constexpr bool operator==(Impl other) const

A concrete variable, to be used in our variant of AffineExpr.

constexpr Num getNum() const

constexpr Var(Impl impl)

Protected ctor for the RTTI methods to use.

constexpr std::optional< U > dyn_cast() const

Var(VarKind vk, AffineDimExpr var)

constexpr VarKind getKind() const

void print(llvm::raw_ostream &os) const

constexpr bool operator!=(Var other) const

constexpr Var(VarKind vk, Num n)

static constexpr bool isWF_Num(Num n)

Checks whether the number would be accepted by Var(VarKind,Var::Num).

constexpr bool isa() const

Var(AffineSymbolExpr sym)

constexpr bool operator==(Var other) const

unsigned Num

Typedef for the type of variable numbers.

The OpAsmOpInterface, see OpAsmInterface.td for more details.

constexpr bool isWF(VarKind vk)

VarKind

The three kinds of variables that Var can be.

llvm::EnumeratedArray< T, VarKind, VarKind::Level > VarKindArray

The type of arrays indexed by VarKind.

constexpr char toChar(VarKind vk)

Gets the ASCII character used as the prefix when printing Var.

uint64_t Dimension

The type of dimension identifiers and dimension-ranks.

uint64_t Level

The type of level identifiers and level-ranks.

Include the generated interface declarations.