MLIR: lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp Source File

//===- LoopEmitter.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// ... (#include directives not recovered in this excerpt) ...

using namespace mlir;
using namespace mlir::sparse_tensor;
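// Shorthand macros used throughout this file to keep the IR-building code
// compact; they all assume an OpBuilder named `builder` and a Location named
// `loc` to be in scope at the expansion site.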

#define CMPI(p, l, r)                                                          \
  (builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::p, (l), (r))      \
       .getResult())

#define C_IDX(v) (constantIndex(builder, loc, (v)))
#define YIELD(vs) (builder.create<scf::YieldOp>(loc, (vs)))
#define ADDI(lhs, rhs) (builder.create<arith::AddIOp>(loc, (lhs), (rhs)))
#define ANDI(lhs, rhs) (builder.create<arith::AndIOp>(loc, (lhs), (rhs)))
#define SUBI(lhs, rhs) (builder.create<arith::SubIOp>(loc, (lhs), (rhs)))
#define MULI(lhs, rhs) (builder.create<arith::MulIOp>(loc, (lhs), (rhs)))
#define REMUI(lhs, rhs) (builder.create<arith::RemUIOp>(loc, (lhs), (rhs)))
#define DIVUI(lhs, rhs) (builder.create<arith::DivUIOp>(loc, (lhs), (rhs)))
#define SELECT(c, l, r) (builder.create<arith::SelectOp>(loc, (c), (l), (r)))

#ifndef NDEBUG
// Debug-only helper that prints an index-typed memref.
LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder,
                                                  Location loc, Value memref) {
  memref = builder.create<memref::CastOp>(loc, /*...*/);
  // ... (runtime print call elided in this excerpt)
}
#endif

static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
                            Level lvl) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  return createOrFoldSliceOffsetOp(builder, loc, tensor, toDim(enc, lvl));
}

static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
                            Level lvl) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  return createOrFoldSliceStrideOp(builder, loc, tensor, toDim(enc, lvl));
}

static bool isIntOrFPZero(Attribute attr) {
  if (auto f = llvm::dyn_cast<FloatAttr>(attr); f && f.getValue().isZero())
    return true;
  if (auto i = llvm::dyn_cast<IntegerAttr>(attr); i && i.getValue().isZero())
    return true;
  return false;
}

static Value unFoldOpIntResult(OpBuilder &builder, Location loc,
                               OpFoldResult ofr) {
  if (std::optional<int64_t> i = getConstantIntValue(ofr); i.has_value())
    return constantIndex(builder, loc, *i);
  return cast<Value>(ofr);
}

static Value tryFoldTensors(Value t) {
  // Try to fold a tensor.pad with a statically-zero padding value back to its
  // source, so that loops can be emitted over the un-padded operand.
  auto stt = tryGetSparseTensorType(t);
  auto padOp = t.getDefiningOp<tensor::PadOp>();
  if (padOp && stt.has_value() && stt->hasEncoding() &&
      padOp.getSourceType().getEncoding() == stt->getEncoding() &&
      stt->getEncoding().isIdentity()) {
    // Try fusing padOp with zeros.
    Attribute padCst;
    if (matchPattern(padOp.getBody()->getTerminator(),
                     m_Op<tensor::YieldOp>(m_Constant(&padCst))) &&
        isIntOrFPZero(padCst)) {
      return padOp.getSource();
    }
  }
  return t;
}

LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
                         bool isSparseOut, unsigned numLoops,
                         DependentLvlGetter dimGetter) {
  initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, dimGetter);
}
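// LoopEmitter::initialize takes the array of input tensors the generated
// loops will iterate over and sets up, per tensor and per level, the
// bookkeeping used below (level objects, iterators, dependent-level maps),
// plus one trailing synthetic tensor that models the universal index.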

void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
                             bool isSparseOut, unsigned numLoops,
                             DependentLvlGetter dimGetter,
                             SparseEmitStrategy emitStrategy) {
  // First initialize the top-level fields.
  this->loopTag = loopTag;
  this->hasOutput = hasOutput;
  this->isSparseOut = isSparseOut;
  this->emitStrategy = emitStrategy;

  const unsigned numManifestTensors = ts.size();
  const unsigned synTensorId = numManifestTensors;
  const unsigned numTensors = numManifestTensors + 1;

  this->tensors.assign(ts.begin(), ts.end());
  // Per-tensor arrays; the last slot belongs to the synthetic tensor.
  this->valBuffer.assign(numTensors, nullptr);
  this->lvls.resize(numTensors);
  this->iters.resize(numTensors);
  this->spIterVals.resize(numTensors);

  // Loop stacks are reserved up front for the maximum number of loops.
  this->loopStack.reserve(numLoops);
  this->loopSeqStack.reserve(numLoops);

  // Index-reduction related fields.
  this->dependentLvlMap.assign(
      numTensors, std::vector<std::vector<std::pair<TensorLevel, unsigned>>>());
  this->sliceMeta.assign(
      numTensors, std::vector<std::vector<std::pair<Value, unsigned>>>());
  this->levelReducedDep.assign(numTensors, std::vector<unsigned>());

  // Initialize nested types of `TensorId`-indexed fields.
  for (TensorId tid = 0; tid < numTensors; tid++) {
    Level lvlRank;
    if (tid == synTensorId) {
      // The synthetic tensor (used for the universal index) is conceptually
      // an all-dense tensor with one level per loop.
      lvlRank = numLoops;
    } else {
      const Value t = tensors[tid];
      // A scalar or 0-dimension tensor has no levels to iterate over.
      if (isZeroRankedTensorOrScalar(t.getType()))
        continue;

      lvlRank = getSparseTensorType(t).getLvlRank();
    }

    lvls[tid].resize(lvlRank);
    iters[tid].resize(lvlRank);
    spIterVals[tid].resize(lvlRank);
    loopHighs.assign(numLoops, nullptr);

    // Slice-driven loop related initialization.
    levelReducedDep[tid].assign(lvlRank, 0);
    dependentLvlMap[tid].assign(
        lvlRank, std::vector<std::pair<TensorLevel, unsigned>>());
    sliceMeta[tid].assign(lvlRank, std::vector<std::pair<Value, unsigned>>());
    if (dimGetter && !isSynTensor(tid)) {
      for (Level l = 0; l < lvlRank; l++) {
        std::vector<std::pair<LoopId, unsigned>> deps = dimGetter(tid, l);
        // Sort the loop dependencies on this level by loop id.
        llvm::sort(deps, llvm::less_first());

        dependentLvlMap[tid][l] = std::move(deps);
        unsigned depends = dependentLvlMap[tid][l].size();
        if (depends == 0)
          continue;
        sliceMeta[tid][l].reserve(depends);
      }
    }
  }
}
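// makeLevelIterator creates the SparseIterator for level `l` of tensor `t`,
// wrapping it in a padded iterator when the level stems from a zero-padding
// tensor.pad, or in a sliced iterator when the tensor encoding is a slice.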

std::unique_ptr<SparseIterator>
LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
                               Level l) {
  Value tensor = tensors[t];
  auto stt = getSparseTensorType(tensor);
  auto it = makeSimpleIterator(*lvls[t][l], emitStrategy);

  Value folded = tryFoldTensors(tensor);
  if (folded != tensor) {
    auto padOp = tensor.getDefiningOp<tensor::PadOp>();
    assert(padOp);
    if (padOp.getPaddedDims().test(l)) {
      Value low = unFoldOpIntResult(builder, loc, padOp.getMixedLowPad()[l]);
      Value high = unFoldOpIntResult(builder, loc, padOp.getMixedHighPad()[l]);
      auto padIt = makePaddedIterator(std::move(it), low, high, emitStrategy);
      return padIt;
    }
  }

  if (stt.hasEncoding() && stt.getEncoding().isSlice()) {
    Value offset = genSliceOffset(builder, loc, tensor, l);
    Value stride = genSliceStride(builder, loc, tensor, l);
    auto slicedIt = makeSlicedLevelIterator(
        std::move(it), offset, stride, lvls[t][l]->getSize(), emitStrategy);
    return slicedIt;
  }

  return it;
}
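// initializeLoopEmit starts a loop-emitting session: it generates all the
// buffers needed for iterating over the tensors (dense tensors are
// bufferized, sparse tensors expose their values buffer), then creates the
// level objects and level iterators.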

void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
                                     LoopEmitter::OutputUpdater updater,
                                     LoopEmitter::SynTensorBoundSetter synSetter) {
  // Prepare the value buffer of every manifest tensor.
  for (TensorId t = 0, e = getNumManifestTensors(); t < e;
       t++) {
    const Value tensor = tryFoldTensors(tensors[t]);
    const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
    // Skip scalars; zero-ranked tensors still need to be bufferized (and
    // probably filled) by their users.
    if (!rtp)
      continue;

    const SparseTensorType stt(rtp);
    const auto shape = rtp.getShape();

    // Perform the required bufferization: dense inputs materialize from the
    // input tensor, sparse inputs use sparse primitives to obtain the values
    // buffer, and extra output initialization is delegated to the client
    // through `updater`.
    bool isOutput = isOutputTensor(t);
    Type elementType = stt.getElementType();
    if (!stt.hasEncoding()) {
      // Non-annotated dense tensors.
      BaseMemRefType denseTp = MemRefType::get(shape, elementType);
      // Tensor slices need a memref type with a fully dynamic layout.
      if (llvm::isa_and_nonnull<tensor::ExtractSliceOp>(tensor.getDefiningOp()))
        denseTp = bufferization::getMemRefTypeWithFullyDynamicLayout(rtp);
      Value denseVal =
          builder.create<bufferization::ToBufferOp>(loc, denseTp, tensor);
      // Dense outputs need special handling.
      if (isOutput && updater)
        denseVal = updater(builder, loc, denseVal, tensor);

      valBuffer[t] = denseVal;
    } else {
      // Annotated sparse tensors: use the values buffer directly.
      valBuffer[t] = builder.create<ToValuesOp>(loc, tensor);
    }
  }

  // Sparse-iterator-based loops manage iteration through
  // sparse_tensor.iterate/coiterate ops and do not need the level and
  // iterator objects created below.
  if (emitStrategy == SparseEmitStrategy::kSparseIterator)
    return;

  // The upper bound of the synthetic tensor (the universal index) for each
  // loop is provided by the client through `synSetter`.
  if (synSetter) {
    TensorId synId = getSynTensorId();
    for (unsigned i = 0, e = loopHighs.size(); i < e; i++) {
      Value sz = loopHighs[i] = synSetter(builder, loc, i);
      auto [stl, it] = makeSynLevelAndIterator(sz, synId, i, emitStrategy);
      lvls[synId][i] = std::move(stl);
      iters[synId][i].emplace_back(std::move(it));
    }
  }

  // Create the level objects and the simple level iterators for every
  // manifest tensor level without a non-trivial index dependency.
  for (TensorId t = 0, e = getNumManifestTensors(); t < e;
       t++) {
    const Value tensor = tryFoldTensors(tensors[t]);
    const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
    if (!rtp)
      // Skips, e.g., zero-ranked tensors.
      continue;

    const SparseTensorType stt(rtp);
    const Level lvlRank = stt.getLvlRank();

    // Scan all levels of the current tensor.
    for (Level l = 0; l < lvlRank; l++) {
      lvls[t][l] = makeSparseTensorLevel(builder, loc, tensor, t, l);
      // Levels with non-trivial index dependencies get their iterators from
      // initSubSectIterator instead.
      if (!dependentLvlMap[t][l].empty())
        continue;

      auto it = makeLevelIterator(builder, loc, t, l);
      iters[t][l].emplace_back(std::move(it));
    }
  }

  initSubSectIterator(builder, loc);
}
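// initSubSectIterator builds the extra (sub-section) iterators for levels
// whose coordinates depend on more than one loop, processing the recorded
// dependencies in loop order and chaining each new iterator to its parent.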

void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
  for (TensorId t = 0, e = tensors.size(); t < e; t++) {
    auto rtp = dyn_cast<RankedTensorType>(tensors[t].getType());
    if (!rtp)
      continue;

    const Level lvlRank = SparseTensorType(rtp).getLvlRank();

    // Compute the dependency reduction order.
    auto remDepStack = dependentLvlMap;
    std::vector<std::tuple<LoopId, TensorId, Level>> depRedOrder;
    for (Level lvl = 0; lvl < lvlRank; lvl++) {
      // Treat the per-level dependency list as a stack (last processed first).
      std::reverse(remDepStack[t][lvl].begin(), remDepStack[t][lvl].end());
      for (auto [loop, coeff] : dependentLvlMap[t][lvl])
        depRedOrder.emplace_back(std::make_tuple(loop, t, lvl));
    }

    if (depRedOrder.empty())
      continue;

    llvm::sort(depRedOrder, llvm::less_first());

    SmallVector<SparseIterator *> lastIter(tensors.size(), nullptr);
    for (auto [loop, t, lvl] : depRedOrder) {
      std::pair<LoopId, unsigned> curDep = remDepStack[t][lvl].back();
      assert(curDep.first == loop);
      remDepStack[t][lvl].pop_back();

      auto lvlIt = makeLevelIterator(builder, loc, t, lvl);
      const SparseIterator *parent = lastIter[t];
      if (!parent && lvl > 0) {
        if (dependentLvlMap[t][lvl - 1].empty()) {
          parent = iters[t][lvl - 1].back().get();
        }
      }

      std::unique_ptr<SparseIterator> it;
      if (!remDepStack[t][lvl].empty()) {
        // Compute the subsection size spanned by the remaining dependencies.
        Value size = C_IDX(0);
        for (auto [loop, stride] : remDepStack[t][lvl]) {
          // ... (size accumulation elided in this excerpt)
        }
        it = makeNonEmptySubSectIterator(builder, loc, parent, loopHighs[loop],
                                         std::move(lvlIt), size, curDep.second,
                                         emitStrategy);
      } else {
        const SparseIterator &subSectIter = *iters[t][lvl].back();
        it = makeTraverseSubSectIterator(builder, loc, subSectIter, *parent,
                                         std::move(lvlIt), loopHighs[loop],
                                         curDep.second, emitStrategy);
      }
      lastIter[t] = it.get();
      iters[t][lvl].emplace_back(std::move(it));
    }
  }
}
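// categorizeIterators splits the iterators of the given tensor levels into
// random-accessible ones (dense-like) and sparse ones; the sparse iterators
// are then stably sorted by kind.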

void LoopEmitter::categorizeIterators(
    ArrayRef<TensorLevel> tidLvls, SmallVectorImpl<SparseIterator *> &raIters,
    SmallVectorImpl<SparseIterator *> &spIters) {
  // Finds out the tensor levels that we should use to generate loops.
  for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
    SparseIterator *it = &getCurIterator(t, l);
    if (it->randomAccessible())
      raIters.push_back(it);
    else
      spIters.push_back(it);
  }

  llvm::stable_sort(spIters, [](auto lhs, auto rhs) {
    // Sort sparse iterators by kind (more specialized kinds first).
    return static_cast<uint8_t>(lhs->kind) > static_cast<uint8_t>(rhs->kind);
  });
}
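// enterNewLoopSeq enters a new loop sequence; the loops within one sequence
// continue from the break points of the previous loop, and the universal
// index of the sequence starts at 0.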

void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
                                  ArrayRef<TensorLevel> tidLvls) {
  assert(loopSeqStack.size() == loopStack.size());

  if (emitStrategy != SparseEmitStrategy::kSparseIterator) {
    // Prepare for all the tensors used in the current loop sequence.
    for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
      levelReducedDep[tid][lvl]++;
      prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
    }
  }

  // The universal index starts from 0.
  loopSeqStack.emplace_back(C_IDX(0), tidLvls.vec());
}

void LoopEmitter::exitCurrentLoopSeq(OpBuilder &builder, Location loc) {
  assert(loopSeqStack.size() == loopStack.size() + 1);

  // Un-count the dependencies accumulated by enterNewLoopSeq.
  for (auto [tid, lvl] : unpackTensorLevelRange(loopSeqStack.back().second))
    levelReducedDep[tid][lvl]--;

  loopSeqStack.pop_back();
}
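// genAffine generates code to evaluate an affine expression whose variables
// are LoopIds, i.e. the induction variables of the enclosing loops.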

Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    const auto loopId = cast<AffineDimExpr>(a).getPosition();
    return loopStack[loopId].iv;
  }
  case AffineExprKind::Add: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return ADDI(genAffine(builder, loc, binOp.getLHS()),
                genAffine(builder, loc, binOp.getRHS()));
  }
  case AffineExprKind::Mul: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return MULI(genAffine(builder, loc, binOp.getLHS()),
                genAffine(builder, loc, binOp.getRHS()));
  }
  case AffineExprKind::Constant: {
    int64_t c = cast<AffineConstantExpr>(a).getValue();
    return C_IDX(c);
  }
  default:
    llvm_unreachable("unexpected affine subscript");
  }
}

std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
    OpBuilder &builder, Location loc, SparseIterator &iter,
    MutableArrayRef<Value> reduc, bool isParallel) {
  Value step = C_IDX(1);
  auto [lo, hi] = iter.genForCond(builder, loc);

  Operation *loop = nullptr;
  Value iv;
  if (isParallel) {
    scf::ParallelOp parOp =
        builder.create<scf::ParallelOp>(loc, lo, hi, step, reduc);
    builder.setInsertionPointToStart(parOp.getBody());
    assert(parOp.getNumReductions() == reduc.size());
    iv = parOp.getInductionVars()[0];

    // In-place update on the reduction variable vector. The init values are
    // relinked to scf.reduce block arguments when exiting the loop (see
    // exitForLoop).
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = parOp.getInitVals()[i];
    loop = parOp;
  } else {
    scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, reduc);
    builder.setInsertionPointToStart(forOp.getBody());
    iv = forOp.getInductionVar();

    // In-place update on the reduction variable vector.
    assert(forOp.getNumRegionIterArgs() == reduc.size());
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = forOp.getRegionIterArg(i);
    loop = forOp;
  }
  assert(loop && iv);

  Value crd = iv;
  if (!iter.randomAccessible()) {
    // For non random-accessible iterators the loop induction variable is the
    // iterator's position; dereference it to obtain the coordinate.
    crd = iter.deref(builder, loc);
  } else {
    iter.locate(builder, loc, iv);
  }

  return {loop, crd};
}

std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
    OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
    MutableArrayRef<Value> reduc, bool needsUniv) {
  return genCoIteration(builder, loc, spIters, reduc,
                        needsUniv ? loopSeqStack.back().first : nullptr);
}

bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters) {
  // Co-iterating multiple sparse tensors requires a while loop.
  if (spIters.size() > 1)
    return false;

  if (spIters.size() == 1)
    return spIters.front()->iteratableByFor();

  return true;
}
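// enterCurrentCoIterationCase materializes the `caseIdx`-th region of the
// innermost sparse_tensor.coiterate op: it records which iteration spaces the
// case covers, creates the case block, and rebinds the per-level iterator
// values (or nullptr for levels absent from the case).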

Region *LoopEmitter::enterCurrentCoIterationCase(OpBuilder &builder,
                                                 Location loc,
                                                 I64BitSet caseBit,
                                                 unsigned caseIdx,
                                                 MutableArrayRef<Value> reduc) {
  auto coIterOp = cast<CoIterateOp>(loopStack.back().loop);
  // Update the `caseIdx`-th case to cover exactly the levels in `caseBit`.
  SmallVector<Attribute> cases(coIterOp.getCases().getAsRange<Attribute>());
  cases[caseIdx] = builder.getI64IntegerAttr(caseBit);

  coIterOp.setCasesAttr(builder.getArrayAttr(cases));
  Region &caseRegion = coIterOp.getRegion(caseIdx);
  assert(caseRegion.getBlocks().empty() &&
         "re-initialize the same coiteration case region.");

  // The user-provided iteration arguments of the op.
  TypeRange iterArgsTps = coIterOp.getInitArgs().getTypes();
  // Each case block starts with the used coordinates (of index type), ...
  SmallVector<Type> blockArgTps;
  // ... (coordinate block-argument types elided in this excerpt) ...
  // ... followed by the user-provided iteration arguments, ...
  blockArgTps.append(iterArgsTps.begin(), iterArgsTps.end());
  // ... and ends with the iterators selected by the case.
  for (auto i : caseBit.bits()) {
    blockArgTps.push_back(
        cast<IterSpaceType>(coIterOp.getIterSpaces()[i].getType())
            .getIteratorType());
  }
  SmallVector<Location> blockArgLocs(blockArgTps.size(), loc);
  builder.createBlock(&caseRegion, caseRegion.end(), blockArgTps, blockArgLocs);

  // Update the current induction variable.
  loopStack.back().iv = coIterOp.getCrds(caseIdx).front();

  ValueRange iterArgs = coIterOp.getRegionIterArgs(caseIdx);
  llvm::copy(iterArgs, reduc.begin());

  ValueRange iters = coIterOp.getRegionIterators(caseIdx);
  for (auto [i, tl] :
       llvm::enumerate(unpackTensorLevelRange(loopStack.back().tidLvls))) {
    if (caseBit[i]) {
      spIterVals[tl.first][tl.second] = iters.front();
      iters = iters.drop_front();
    } else {
      spIterVals[tl.first][tl.second] = nullptr;
    }
  }

  // All the iterators must have been consumed.
  assert(iters.empty());
  return &caseRegion;
}
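// enterCoIterationOverTensorsAtLvls emits a co-iteration loop over a set of
// tensor levels. With the sparse-iterator strategy it emits
// sparse_tensor.iterate/coiterate ops; otherwise it emits an scf.for (or
// scf.parallel) when a single for-compatible iterator suffices, and an
// scf.while co-iteration otherwise.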

Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
    OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
    unsigned numCases, MutableArrayRef<Value> reduc, bool tryParallel,
    bool needsUniv) {
  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
    // Single tensor level: emit a `sparse_tensor.iterate` op.
    if (tidLvls.size() == 1) {
      auto [tid, lvl] = unpackTensorLevel(tidLvls.front());
      Value t = tensors[tid];

      // Extract the iteration space for the level (nested inside the parent
      // iterator when lvl > 0).
      ExtractIterSpaceOp extractSpaceOp =
          lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
                   : builder.create<ExtractIterSpaceOp>(
                         loc, t, spIterVals[tid][lvl - 1], lvl);

      IterateOp iterOp = builder.create<IterateOp>(
          loc, extractSpaceOp.getExtractedSpace(), reduc);
      spIterVals[tid][lvl] = iterOp.getIterator();

      // In-place update on the reduction variables.
      llvm::copy(iterOp.getRegionIterArgs(), reduc.begin());
      // Move the insertion point to the beginning of the loop body.
      builder.setInsertionPointToStart(&iterOp.getRegion().front());
      loopStack.emplace_back(tidLvls, iterOp, builder.getInsertionBlock(),
                             iterOp.getCrds().front(), loopTag);
      return iterOp;
    }

    // Multiple tensor levels: emit a `sparse_tensor.coiterate` op over the
    // extracted iteration spaces.
    SmallVector<Value> spaces;
    for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
      Value t = tensors[tid];
      ExtractIterSpaceOp extractSpaceOp =
          lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
                   : builder.create<ExtractIterSpaceOp>(
                         loc, t, spIterVals[tid][lvl - 1], lvl);
      spaces.push_back(extractSpaceOp.getExtractedSpace());
    }
    auto coIterOp = builder.create<CoIterateOp>(loc, spaces, reduc, numCases);
    // The CoIterateOp cases are instantiated later through
    // enterCurrentCoIterationCase, so no insertion block or induction
    // variable is recorded here yet.
    loopStack.emplace_back(tidLvls, coIterOp, nullptr,
                           nullptr, loopTag);
    return coIterOp;
  }

  // Parallel loops only support a single reduction for now.
  tryParallel = tryParallel && reduc.size() <= 1;

  SmallVector<SparseIterator *> raIters;
  SmallVector<SparseIterator *> spIters;
  categorizeIterators(tidLvls, raIters, spIters);

  // Only when there is at least one sparse condition do we really need the
  // universal index.
  needsUniv = !spIters.empty() && needsUniv;

  Operation *l = nullptr;
  Value iv = nullptr;

  // Generate the loop differently depending on whether a simple for loop
  // suffices or a while-based co-iteration is needed.
  if (shouldIteratedByForLoop(spIters) && !needsUniv) {
    assert(spIters.size() <= 1);
    SparseIterator &it = spIters.empty() ? *raIters.front() : *spIters.front();
    std::tie(l, iv) =
        emitForLoopOverTensorAtLvl(builder, loc, it, reduc, tryParallel);
  } else {
    for (auto *it : spIters) {
      // ... (per-iterator preparation elided in this excerpt)
    }

    if (needsUniv)
      for (auto *it : raIters) {
        // ... (per-iterator preparation elided in this excerpt)
      }

    std::tie(l, iv) =
        emitWhileLoopOverTensorsAtLvls(builder, loc, spIters, reduc, needsUniv);
  }

  // Locate the random-access (dense) iterators according to the loop
  // induction variable.
  for (SparseIterator *it : raIters)
    it->locate(builder, loc, iv);

  // Push the loop onto the stack.
  // ... (`tls` holds the tensor levels associated with this loop; its
  // computation is elided in this excerpt)
  loopStack.emplace_back(tls, l, builder.getInsertionBlock(), iv, loopTag);
  return l;
}
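// locateLvlAtAffineAddress emits the address for a dense level based on the
// value evaluated from the provided affine expression.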

void LoopEmitter::locateLvlAtAffineAddress(OpBuilder &builder, Location loc,
                                           TensorLevel tidLvl,
                                           AffineExpr lvlExpr) {
  auto [tid, lvl] = unpackTensorLevel(tidLvl);

  const SparseIterator *parent =
      lvl == 0 ? nullptr : iters[tid][lvl - 1].back().get();
  auto &it = getCurIterator(tid, lvl);
  it.genInit(builder, loc, parent);

  // The iterator for a dense level is random accessible: locate it directly
  // at the address computed from the affine expression.
  Value lvlCrd = genAffine(builder, loc, lvlExpr);
  it.locate(builder, loc, lvlCrd);
}

708

709 void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,

711

712

713

714

715 bool hasParent = lvl == 0 || !dependentLvlMap[tid][lvl].empty();

716

718 hasParent ? nullptr : iters[tid][lvl - 1].back().get();

719 auto &it = getCurIterator(tid, lvl);

720 it.genInit(builder, loc, parent);

721

722

725 }
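// exitForLoop closes the innermost for-style loop: it yields the reduction
// values and, for scf.parallel, re-creates the user's reduction expression
// inside an scf.reduce block so the parallel reduction is well-formed.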

void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
                              MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
    auto iterateOp = llvm::cast<IterateOp>(loopInfo.loop);
    assert(reduc.size() == iterateOp.getNumResults());
    rewriter.create<sparse_tensor::YieldOp>(loc, reduc);
    // Exit the loop.
    rewriter.setInsertionPointAfter(iterateOp);
    // In-place update on the reduction variables.
    llvm::copy(iterateOp.getResults(), reduc.begin());
    return;
  }
  if (auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop)) {
    if (!reduc.empty()) {
      assert(reduc.size() == forOp.getNumResults());
      rewriter.create<scf::YieldOp>(loc, reduc);
    }
    // Exit the loop.
    rewriter.setInsertionPointAfter(forOp);
    // In-place update on the reduction variables.
    llvm::copy(forOp.getResults(), reduc.begin());
  } else {
    auto parOp = llvm::cast<scf::ParallelOp>(loopInfo.loop);
    if (!reduc.empty()) {
      assert(reduc.size() == parOp.getInitVals().size() && reduc.size() == 1);
      Operation *redExp = reduc.front().getDefiningOp();
      // The reduction expression should have no other use.
      assert(redExp->getUses().empty());
      // The reduction expression is expected to be a binary operation with
      // one operand being the previous reduction value (the init value).
      Value redVal = parOp.getInitVals().front();
      Value curVal;
      if (redExp->getOperand(0) == redVal)
        curVal = redExp->getOperand(1);
      else if (redExp->getOperand(1) == redVal)
        curVal = redExp->getOperand(0);
      // One of the operands must be the init value.
      assert(curVal);
#ifndef NDEBUG
      // The reduction expression should be the only user of the reduction
      // value within the parallel loop.
      unsigned numUsers = 0;
      for (Operation *op : curVal.getUsers()) {
        if (op->getParentOp() == parOp)
          numUsers++;
      }
      assert(numUsers == 1);
#endif

      rewriter.setInsertionPointAfter(redExp);
      auto redOp = rewriter.create<scf::ReduceOp>(loc, curVal);
      // Attach to the reduction op.
      Block *redBlock = &redOp.getReductions().front().front();
      rewriter.setInsertionPointToEnd(redBlock);
      Operation *newRed = rewriter.clone(*redExp);
      // Replace the operands of the cloned reduction operation with the block
      // arguments of the scf.reduce block.
      rewriter.modifyOpInPlace(
          newRed, [&]() { newRed->setOperands(redBlock->getArguments()); });
      // Erase the out-dated reduction expression.
      rewriter.eraseOp(redExp);
      rewriter.setInsertionPointToEnd(redBlock);
      rewriter.create<scf::ReduceReturnOp>(loc, newRed->getResult(0));
    }
    rewriter.setInsertionPointAfter(parOp);
    // In-place update on the reduction variables.
    for (unsigned i = 0, e = parOp.getResults().size(); i < e; i++)
      reduc[i] = parOp.getResult(i);
  }
}
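// exitWhileLoop closes the innermost while-style co-iteration loop: it
// forwards the sparse iterators that were visited at the current coordinate,
// yields their cursors together with the reduction values, and advances the
// optional universal index.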

void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
                                MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
  Value iv = loopInfo.iv;

  // Finalize the induction: forward every sparse iterator whose coordinate
  // matches the just-visited coordinate, and collect the forwarded cursors as
  // the yielded operands of the while loop.
  SmallVector<Value> operands;
  ValueRange whileRes = whileOp.getResults();

  for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.tidLvls)) {
    SparseIterator &it = getCurIterator(tid, lvl);
    if (!it.randomAccessible()) {
      // Forward the sparse iterator when it was visited in this iteration.
      Value cmp = CMPI(eq, it.getCrd(), iv);
      it.forwardIf(builder, loc, cmp);
      operands.append(it.getCursor().begin(), it.getCursor().end());
      // Following loops continue iteration from the break point of the
      // current while loop.
      whileRes = it.linkNewScope(whileRes);
    } else {
      // Make sure the randomly accessible (dense) iterator is set to the
      // right position according to the universal index.
      Value uniIdx = whileOp.getResults().back();
      it.locate(builder, loc, uniIdx);
    }
  }

  // Reduction values from users.
  for (auto &i : reduc) {
    operands.push_back(i);
    // Update user reduction variables.
    i = whileRes.front();
    whileRes = whileRes.drop_front();
  }

  // An (optional) universal index.
  if (operands.size() < whileOp.getNumResults()) {
    assert(operands.size() + 1 == whileOp.getNumResults());
    // The last operand is the universal index, advanced by one.
    Value one = C_IDX(1);
    operands.push_back(ADDI(iv, one));
    // Update the universal index for the next loop in the sequence.
    loopSeqStack.back().first = whileOp->getResults().back();
  }

  if (!operands.empty())
    YIELD(operands);

  builder.setInsertionPointAfter(whileOp);
}
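// exitCurrentLoop generates the code to exit the current loop (yields,
// forwarding of loop induction variables, in-place update of the caller's
// reduction variables) and pops it off the loop stack.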

void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
                                  MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
    Operation *p = loopInfo.loop;
    if (isa<IterateOp>(p))
      rewriter.create<sparse_tensor::YieldOp>(loc, reduc);

    // Exit the loop and update the reduction variables in place.
    rewriter.setInsertionPointAfter(p);
    llvm::copy(p->getResults(), reduc.begin());

    loopStack.pop_back();
    return;
  }

  // Set the insertion point to the right position.
  rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock);
  if (!loopInfo.userCodeBlock->empty() &&
      llvm::isa<scf::YieldOp>(&loopInfo.userCodeBlock->back())) {
    // scf::While/For inserts an implicit yield op when there are no loop
    // iter args; in that case the exit code must be inserted before it.
    assert(loopInfo.userCodeBlock->back().getNumResults() == 0);
    rewriter.setInsertionPoint(&loopInfo.userCodeBlock->back());
  }

  if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
    exitWhileLoop(rewriter, loc, reduc);
  } else {
    exitForLoop(rewriter, loc, reduc);
  }

  assert(loopStack.size() == loopSeqStack.size());
  loopStack.pop_back();
}
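// genCoIteration is the shared utility that builds an scf.while loop
// co-iterating the given sparse iterators; it returns the loop operation and
// the minimum coordinate (or the universal index) for the current iteration.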

std::pair<Operation *, Value> sparse_tensor::genCoIteration(
    OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
    MutableArrayRef<Value> reduc, Value uniIdx, bool userReducFirst) {
  // The set of induction variables for the while loop.
  SmallVector<Value> ivs;

  // Depending on the caller's convention, the user-provided reduction values
  // appear either before or after the iterator cursor values.
  if (userReducFirst)
    ivs.append(reduc.begin(), reduc.end());

  // Construct the while loop with a parameter for each cursor value.
  for (SparseIterator *it : spIters) {
    ValueRange itVals = it->getCursor();
    ivs.append(itVals.begin(), itVals.end());
  }

  if (!userReducFirst)
    ivs.append(reduc.begin(), reduc.end());

  // A dedicated slot for the universal index (if used).
  if (uniIdx)
    ivs.push_back(uniIdx);

  // Ensure all the induction variables are initialized.
  assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
  TypeRange types = ValueRange(ivs).getTypes();
  auto whileOp = builder.create<scf::WhileOp>(loc, types, ivs);

  SmallVector<Location> locs(types.size(), loc);
  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);

  // Generate the loop conditions.
  builder.setInsertionPointToStart(before);
  ValueRange bArgs = before->getArguments();
  Value whileCond = nullptr;

  for (SparseIterator *it : spIters) {
    auto [cond, remArgs] = it->genWhileCond(builder, loc, bArgs);
    whileCond = !whileCond ? cond : ANDI(whileCond, cond);
    bArgs = remArgs;
  }
  // The remaining block arguments are the reduction values and an optional
  // universal index.
  assert(bArgs.size() == reduc.size() + (uniIdx ? 1 : 0));
  builder.create<scf::ConditionOp>(loc, whileCond, before->getArguments());

  // Generate the loop body.
  builder.setInsertionPointToStart(after);
  ValueRange aArgs = after->getArguments();

  for (SparseIterator *it : spIters) {
    aArgs = it->linkNewScope(aArgs);
    // Dereference the iterator to cache the coordinate.
    it->deref(builder, loc);
  }

  // In-place update on the reduction variables.
  for (unsigned i = 0, e = reduc.size(); i < e; i++)
    reduc[i] = aArgs[i];

  Value min;
  // Find the minimum coordinate.
  if (!uniIdx) {
    for (SparseIterator *it : spIters) {
      if (min) {
        Value cmp = CMPI(ult, it->getCrd(), min);
        min = SELECT(cmp, it->getCrd(), min);
      } else {
        min = it->getCrd();
      }
    }
  } else {
    // Otherwise the universal index is the minimal coordinate.
    min = whileOp.getAfterArguments().back();
  }

  return {whileOp, min};
}

987

988 #undef CMPI

989 #undef C_IDX

990 #undef YIELD

991 #undef ADDI

992 #undef ANDI

993 #undef SUBI

994 #undef MULI

995 #undef SELECT
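For orientation, the sketch below shows the order in which a client typically drives the LoopEmitter API defined in this file. It is not part of LoopEmitter.cpp: the driver name, the `schedule` parameter, and the assumption that LoopEmitter is default-constructible followed by initialize() (as the separate initialize() entry point suggests) are all illustrative.

// Hypothetical driver, for illustration only.
void emitLoops(RewriterBase &rewriter, Location loc, ValueRange inputs,
               unsigned numLoops,
               ArrayRef<SmallVector<TensorLevel>> schedule) {
  LoopEmitter emitter;
  emitter.initialize(inputs, /*loopTag=*/nullptr, /*hasOutput=*/true,
                     /*isSparseOut=*/false, numLoops);
  // Materialize value buffers, level objects and level iterators.
  emitter.initializeLoopEmit(rewriter, loc);

  for (ArrayRef<TensorLevel> tidLvls : schedule) {
    // Each loop sequence groups the loops that co-iterate these levels.
    emitter.enterNewLoopSeq(rewriter, loc, tidLvls);
    SmallVector<Value> reduc; // reduction values carried by the loop, if any
    emitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls,
                                              /*numCases=*/1, reduc);
    // ... emit the loop-body computation here ...
    emitter.exitCurrentLoop(rewriter, loc, reduc);
    emitter.exitCurrentLoopSeq(rewriter, loc);
  }
}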
