MLIR: include/mlir/Dialect/SparseTensor/Utils/Merger.h Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13 #ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_

14 #define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_

15

20 #include "llvm/ADT/BitVector.h"

21

22 #include

23

24 namespace mlir {

25 namespace sparse_tensor {

26

27 namespace detail {

28

29

31 }

32

33

34

36

37

39

40

41

42

43

45

46

47

49

50

51

53

54

55

56

58

59

60 using LvlLTPair = std::pair<Level, LevelType>;

61

62

63

65

66

68 enum class Kind;

69

70

74 };

75

76

77

78

79

80

81

82

83

84

85

87

88

90

91 union {

92

94

95

97

98

100 };

101

102

103

104

106

107

108

109

110

111

112

113

115

116

117

119 };

120

121

122

123

124

125

126

127

128

130

135

165 kCIm,

166 kCRe,

168 kBinaryBranch,

169 kUnary,

170 kSelect,

171

176 kDivC,

178 kDivU,

191 kShrU,

193 kBinary,

194 kReduce,

195 kDenseOp,

196 };

197

198

199

200

201

203

205

206

208

209

211

212

213

214

216

217

219 };

220

221

222

223

224

226 public:

227

228

229

230

231

232

233

234

235 Merger(unsigned numInputOutputTensors, unsigned numLoops,

236 unsigned maxLvlRank);

237

238

239

240

241

242

244 assert(isValidTensorId(t));

245 return t;

246 }

247

248

250 assert(isValidLoopId(i));

251 return i;

252 }

253

254

256 assert(isValidTensorId(t) && isValidLoopId(i));

257 return numTensors * i + t;

258 }

259

260

261

262

263

264

266

268

270

272

275

276

277

280

281

284

285

287

288

289

290

291

294

295

296

298

299

300

302

303

304

305

307

308

309

310

315

316

317

318

321

322

323

324

325

327

328

329

330

332

333

334

335

336

338

339

341

342

344

345

347

349

350

351

/// Gets the total number of tensors (including the output-tensor and
/// synthetic-tensor).
352 constexpr unsigned getNumTensors() const { return numTensors; }

353

354

/// Gets the total number of loops (native loops + filter loops).
355 constexpr unsigned getNumLoops() const { return numLoops; }

356

357

360 }

361

362

364

365

366

368

369

371 const auto &expr = exp(e);

373 }

374

375

377

378

379

380

381

383

384

385

386

387

389

390

391

392 bool hasAnySparse(const BitVector &bits) const;

393

394

395

397

398

400 assert(isValidTensorId(t) && isValidLoopId(i));

401 return lvlTypes[t][i];

402 }

403

404

407 }

408

409

411 assert(isValidLevel(t, lvl));

412 return lvlToLoop[t][lvl];

413 }

414

415

417 assert(isValidTensorId(t) && isValidLoopId(i));

418 return loopToLvl[t][i];

419 }

422 }

423

424

425

427 assert(isValidLevel(t, lvl) && isValidLoopId(i) && isValidLT(lt));

428 lvlTypes[t][i] = lt;

429 loopToLvl[t][i] = lvl;

430 lvlToLoop[t][lvl] = i;

431

432 loopBounds[i] = std::make_pair(t, lvl);

433 }

434

437

438

439

440

441

444

445

447 }

450 const auto &point = lat(p);

451 const auto &bits = simple ? point.simple : point.bits;

452 for (const TensorLoopId b : bits.set_bits()) {

454 const auto optLvl = getLvl(b);

457

458 assert(!optLvl.has_value());

459

461 true);

462 } else {

463 callback(b, t, optLvl, lvlTp, false);

464 }

465 }

466 }

467

468

470

471

473 LevelType lt, unsigned coefficient) {

474 assert(isValidLoopId(i) && isValidLevel(t, lvl));

475 assert(!loopToUnresolvedLvls[i][t].has_value());

476 loopToUnresolvedLvls[i][t] = std::make_pair(lvl, lt);

477 levelToDependentLoop[t][lvl].emplace_back(i, coefficient);

478 }

479

480

482 assert(isValidTensorId(t) && isValidLoopId(i));

483 return loopToUnresolvedLvls[i][t].has_value();

484 }

485

486

487

489 assert(isValidLevel(t, lvl));

490 return levelToDependentLoop[t][lvl];

491 }

492

493

495 assert(isValidLoopId(i));

496 return loopBounds[i];

497 }

498

499

500

504 assert(isValidTensorId(t) && isValidLoopId(i));

505 return loopToUnresolvedLvls[i][t].has_value();

506 }

507

508

509

513 return lt.hasSparseSemantic();

514 }

515 return false;

516 }

517

520 return loopToUnresolvedLvls[loop(b)][tensor(b)]->first;

521 }

522

525 return loopToUnresolvedLvls[loop(b)][tensor(b)]->second;

526 }

527

528

529

530

531

532

533

534

535

536

537

538

539

540

542 assert(isValidExprId(e));

543 return tensorExps[e];

544 }

546 assert(isValidLatPointId(p));

547 return latPoints[p];

548 }

550 assert(isValidLatSetId(s));

551 return latSets[s];

552 }

553

554

556

557

558

560 assert(exp(e).val && "Expression already has an associated value");

561 assert(v && "Trying to assign an undefined value");

562 tensorExps[e].val = v;

563 }

564

565

566

568 assert(exp(e).val && "Expression does not have an associated value");

569 tensorExps[e].val = Value();

570 }

571

572 #ifndef NDEBUG

573

577 void dumpBits(const BitVector &bits) const;

578 #endif

579

580

581

582

584

585

586

588

589

592

593 private:

594

/// Returns true if `t` is a valid tensor identifier, i.e. it is strictly
/// less than the total number of tensors (`numTensors`).
595 constexpr bool isValidTensorId(TensorId t) const { return t < numTensors; }

596 constexpr bool isValidLoopId(LoopId i) const {

598 }

600 assert(levelToDependentLoop[t].size() == lvlToLoop[t].size());

601 return isValidTensorId(t) && lvl < lvlToLoop[t].size();

602 }

603 bool isValidExprId(ExprId e) const {

605 }

606 bool isValidLatPointId(LatPointId p) const {

608 }

609 bool isValidLatSetId(LatSetId s) const {

611 }

612 bool maybeZero(ExprId e) const;

613 bool isInvariant(ExprId e) const {

615 }

616 Type inferType(ExprId e, Value src) const;

617

618

619

620

621

622 std::pair<std::optional, bool> buildTensorExp(linalg::GenericOp op,

623 Value v);

624

625

627 const TensorId syntheticTensor;

628 const unsigned numTensors;

629 const unsigned numLoops;

630 bool hasSparseOut;

631

632

633

634

635

636

637

638

639 std::vector<std::vector> lvlTypes;

640

641

642 std::vector<std::vector<std::optional>> loopToLvl;

643

644

645 std::vector<std::vector<std::optional>> lvlToLoop;

646

647

648

649

650

651

652 std::vector<std::vector<std::optional>> loopToUnresolvedLvls;

653

654

655

656

657

658 std::vector<std::vector<std::vector>> levelToDependentLoop;

659

660

661 std::vector<std::pair<TensorId, Level>> loopBounds;

662

666 };

667

668 }

669 }

670

671 #endif

union mlir::linalg::@1203::ArityGroupAndKind::Kind kind

Attributes are known-constant values of operations.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

Operation is the basic unit of execution within MLIR.

This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

A class to handle all iteration lattice operations.

void setHasSparseOut(bool s)

Sets whether the output tensor is sparse or not.

constexpr unsigned getNumLoops() const

Gets the total number of loops (native loops + filter loops).

LatPointId conjLat(ExprId e, LatPointId p0, LatPointId p1, Operation *op=nullptr)

Computes a single conjunction of two lattice points by taking the "union" of LoopId (effectively cons...

LatSetId disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)

Disjunctive merge of two lattice sets: (s0 /\_op s1, s0, s1).

Level getLoopDependentLevel(TensorLoopId b) const

std::optional< Level > getLvl(TensorId t, LoopId i) const

Gets the level number of the t-th tensor on the i-th loop.

constexpr bool isOutTensor(TensorLoopId b, LoopId i) const

Returns true if b is the ith loop of the output tensor.

bool isSingleCondition(TensorId t, ExprId e) const

Returns true if the given tensor iterates only in the given tensor expression.

bool hasSparseIdxReduction(const BitVector &bits) const

Returns true if bits contains a dependent index reduction condition on sparse levels.

bool expContainsTensor(ExprId e, TensorId t) const

Returns true if the expression contains the tensor as an operand.

LatSetId mapBinWithSynZeroSet(ExprId e, LatSetId s, bool lhsZero)

Maps the binary operator to the same operation but with one of its operands set to zero,...

bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const

Checks whether the TensorLoopId represents a sparse tensor level that contains a non-trivial index expression.

void dumpBits(const BitVector &bits) const

bool hasExprValue(ExprId e) const

Checks whether the given expression has an associated value.

void foreachTensorLoopId(LatPointId p, bool simple, ForeachTensorLoopIdCallback callback) const

LatSetId addSet()

Constructs a new (initially empty) set, and returns its identifier.

std::optional< LoopId > getLoopId(TensorId t, Level lvl) const

Gets the loop identifier for the lvl-th level of the t-th tensor.

std::pair< TensorId, Level > getLoopDefiningLvl(LoopId i) const

Returns the defining [tid, lvl] for the loop.

BitVector simplifyCond(LatSetId s, LatPointId p)

Simplifies the conditions in a conjunction of a given lattice point within the given set using just t...

bool hasNegateOnOut(ExprId e) const

Returns true if the expression contains a negation on output tensor.

constexpr unsigned getNumTensors() const

Gets the total number of tensors (including the output-tensor and synthetic-tensor).

bool isLvlWithNonTrivialIdxExp(TensorLoopId b) const

Checks whether the TensorLoopId represents a tensor level that contains a non-trivial index expression.

LatSetId disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1)

Disjunctive merge of two lattice sets and also set one of the operands to zero: (s0 /\_op s1 (e0 op e1...

void dumpSet(LatSetId s) const

void dumpLat(LatPointId p) const

LatSetId combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig, bool includeLeft, TensorExp::Kind ltrans, Operation *opleft, bool includeRight, TensorExp::Kind rtrans, Operation *opright)

Disjunctive merge of two lattice sets with custom handling of the overlap, left, and right regions.

ExprId addTensorExp(TensorId t)

Constructs a new tensor expression, and returns its identifier.

LatSetId buildLattices(ExprId e, LoopId i)

Builds the iteration lattices in a bottom-up traversal given the remaining tensor (sub)expression and...

LatSetId conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)

Conjunctive merge of two lattice sets: (s0 /\_op s1).

ExprId addExp(TensorExp::Kind k, ExprId e0, ExprId e1=detail::kInvalidId, Operation *op=nullptr, Attribute attr=nullptr)

Constructs a new unary or binary expression, and returns its identifier.

ExprId addSynZeroExp()

Constructs a new synthetic zero expression.

constexpr LoopId makeLoopId(unsigned i) const

Safely converts the argument to a loop identifier.

std::optional< ExprId > buildTensorExpFromLinalg(linalg::GenericOp op)

Builds a tensor expression from the given Linalg operation.

void setLevelAndType(TensorId t, LoopId i, Level lvl, LevelType lt)

Sets the level number and level-type of the t-th tensor on the i-th loop.

LatSetId mapSet(TensorExp::Kind kind, LatSetId s, Value v=Value(), Operation *op=nullptr, Attribute attr=nullptr)

Maps the unary operator over the lattice set of the operand, i.e.

void foreachTensorLoopId(LatPointId p, ForeachTensorLoopIdCallback callback) const

Iterates over a set of TensorLoopIds, invoking the callback for each TensorLoopId and passing it the ...

std::optional< Level > getLvl(TensorLoopId b) const

ArrayRef< LatPointId > set(LatSetId s) const

LatSetId optimizeSet(LatSetId s)

Optimizes the iteration lattice points in the given set.

constexpr TensorId tensor(TensorLoopId b) const

Gets the tensor-identifier of the TensorLoopId.

void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl, LevelType lt, unsigned coefficient)

Establishes the two-way map that i <-> <t, lvl, lt>.

void dumpExp(ExprId e) const

Print methods (for debugging).

LevelType getLvlType(TensorLoopId b) const

Gets the level-type of the TensorLoopId.

Merger(unsigned numInputOutputTensors, unsigned numLoops, unsigned maxLvlRank)

Constructs a merger for the given number of tensors and loops.

bool hasAnySparse(const BitVector &bits) const

Returns true if any TensorLoopId in the bitvector corresponds to a sparse level-type.

void clearExprValue(ExprId e)

Clears the value associated with the expression.

std::vector< LoopCoeffPair > & getDependentLoops(TensorId t, Level lvl)

Returns the list of loop indices which appear in the non-trivial index expression on t_l,...

LatPointId addLat(TensorId t, LoopId i, ExprId e)

Constructs a new iteration lattice point, and returns its identifier.

ExprId addLoopVarExp(LoopId i)

Constructs a new loop-variable expression, and returns its identifier.

constexpr TensorId getSynTensorID() const

Gets the synthetic tensor's identifier (used for all invariant tensor expressions).

bool latGT(LatPointId p0, LatPointId p1) const

Returns true if p0 > p1.

const TensorExp & exp(ExprId e) const

Convenience getters to immediately access the stored nodes.

constexpr LoopId loop(TensorLoopId b) const

Gets the loop-identifier of the TensorLoopId.

const LatPoint & lat(LatPointId p) const

constexpr TensorId getOutTensorID() const

Gets the output tensor's identifier.

bool onlyDenseDiff(LatPointId p0, LatPointId p1) const

Returns true if p0 and p1 only differ in dense.

ExprId addInvariantExp(Value v)

Constructs a new invariant expression, and returns its identifier.

constexpr TensorId makeTensorId(unsigned t) const

Safely converts the argument to a tensor identifier.

LevelType getLoopDependentLevelType(TensorLoopId b) const

LevelType getLvlType(TensorId t, LoopId i) const

Gets the level-type of the t-th tensor on the i-th loop.

Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, Value v1) const

Rebuilds SSA format from a tensor expression.

constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const

Safely converts the arguments to a pair of (tensor,loop) identifiers.

bool expIsTensor(ExprId e, TensorId t) const

Returns true if the expression is (kTensor t).

void setExprValue(ExprId e, Value v)

Sets the expression to have the associated value.

bool hasDependentLvl(LoopId i, TensorId t)

Whether the loop has a dependent slice.

@ Type

An inlay hint for a type annotation.

static constexpr unsigned kInvalidId

A constant serving as the canonically invalid identifier, regardless of the identifier type.

unsigned LatSetId

LatSet identifiers.

std::pair< Level, LevelType > LvlLTPair

A pair of level and its corresponding LevelType of a tensor.

unsigned TensorLoopId

A compressed representation of std::pair<TensorId, LoopId>.

uint64_t Level

The type of level identifiers and level-ranks.

unsigned LoopId

Loop identifiers.

bool isValidLT(LevelType lt)

unsigned ExprId

TensorExp identifiers.

unsigned LatPointId

LatPoint identifiers.

std::pair< LoopId, unsigned > LoopCoeffPair

A pair of a loop id and its coefficient.

unsigned TensorId

Tensor identifiers, chosen to be the BlockArgument::getArgNumber of the value passed to Merger::build...

Include the generated interface declarations.

LatPoint(const BitVector &bits, ExprId e)

Construct a lattice point from the given set of TensorLoopIds.

ExprId exp

Identifier of the tensor expression.

BitVector bits

Conjunction of all TensorLoopIds involved in the tensor expression.

BitVector simple

Simplified conjunction of TensorLoopId as bitvector.

LatPoint(unsigned size, ExprId e)

Construct a lattice point with the empty set of TensorLoopIds.

This enum defines all the sparse representations supportable by the SparseTensor dialect.

Child subexpressions for non-leaf expressions.

Tensor expression. Represents an MLIR expression in tensor index notation.

LoopId loop

kLoopVar expressions simply have a loop identifier.

Value val

Direct link to IR for an invariant or the destination value (to infer destination type) of a cast ope...

Kind

Tensor expression kind.

Children children

All other expressions hold the ExprIds of their children.

Attribute attr

An optional attribute that is required to determine the semantics of the operations.

TensorId tensor

kTensor expressions simply have a tensor identifier.

Kind kind

Tensor expression kind.

Operation * op

Code blocks used by semirings.

TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *op, Attribute a)

The x parameter has different types depending on the value of the k parameter.