MLIR: lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp Source File

//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/IterationGraphSorter.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// File Local Helper classes.
//===----------------------------------------------------------------------===//

// CRTP to help implementing a rewriter that demaps all its inputs, i.e.,
// rewrites sparse operands with non-identity dim2lvl maps into their
// demapped (level-space) counterparts before dispatching to the subclass.
template <typename SubClass, typename SourceOp>
struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;
  using OpAdaptor = typename SourceOp::Adaptor;

  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Demaps non-trivial inputs.
    SmallVector<Value> deMappedIns(op->getOperands());
    for (Value &in : deMappedIns) {
      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
        in =
            ReinterpretMapOp::create(rewriter, loc, stt->getDemappedType(), in);
      }
    }

    // CRTP call: dispatch to the subclass with the demapped operands.
    OpAdaptor adaptor(deMappedIns, op);
    LogicalResult status =
        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
    return status;
  }
};
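// Illustrative sketch of the demapping effect (encoding names below are made
// up for the example): for a 2x2 block-sparse (BSR) tensor
//
//   #BSR = #sparse_tensor.encoding<{
//     map = (i, j) -> (i floordiv 2 : dense,
//                      j floordiv 2 : compressed,
//                      i mod 2      : dense,
//                      j mod 2      : dense)
//   }>
//
// demapping an input `%t : tensor<8x8xf64, #BSR>` inserts roughly
//
//   %d = sparse_tensor.reinterpret_map %t
//          : tensor<8x8xf64, #BSR> to tensor<4x4x2x2xf64, #DemappedBSR>
//
// where #DemappedBSR keeps the level types but uses an identity dim2lvl map,
// so the subclass pattern can operate directly in level space.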

// A helper that collects all the AffineDimExprs used in an affine expression.
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
  void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
  BitVector dims;
};

// A visitor that determines whether an affine expression is admissible for
// sparse compilation.
struct AffineExprAdmissibleVisitor
    : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
  explicit AffineExprAdmissibleVisitor(bool isOutput) : isOutput(isOutput){};

  // We only allow plain AffineDimExprs on the output.
  void visitAddExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }
  void visitMulExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }

  // We disallow mod, floordiv and ceildiv on both inputs and outputs.
  void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  operator bool() { return admissible; }

private:
  bool admissible = true;
  bool isOutput;
};
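// For instance (illustrative): on an input tensor, `d0 + d1` and `d0 * 2`
// are admissible while `d0 floordiv 2`, `d0 mod 2`, and `d0 ceildiv 2` are
// not; on the output tensor only plain AffineDimExprs such as `d0` are
// admissible, since add and mul are rejected there as well.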

// The first BitVector stores the levels on which inadmissible expressions
// are used. The second BitVector stores the AffineDimExprs that are used by
// the inadmissible expressions.
using InadmissInfo = std::pair<BitVector, BitVector>;

} // namespace

//===----------------------------------------------------------------------===//
// File Local Helper methods.
//===----------------------------------------------------------------------===//

// Collects the inadmissible affine expressions imposed on levels. The result
// records both the inadmissible levels and the loop dimensions they use.
static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
  auto ret = std::make_pair(BitVector(map.getNumResults()),
                            BitVector(map.getNumDims()));
  AffineDimCollector collector(map.getNumDims());
  for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
    AffineExprAdmissibleVisitor admissible(isOutput);
    admissible.walkPostOrder(map.getResult(lvl));
    if (!admissible) {
      // Record the inadmissible level.
      ret.first.set(lvl);
      // Record the AffineDimExprs that are used by the inadmissible
      // expression.
      collector.walkPostOrder(map.getResult(lvl));
    }
  }
  ret.second = collector.dims;
  return ret;
}
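// For example (illustrative): with the 2x2 BSR map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
// as an input map, all four level expressions are inadmissible, so the
// returned pair is (levels {0,1,2,3} set, dims {d0,d1} set).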

// Builds the AffineMap to replace the inadmissible dimensions in idxMap.
//
// Each inadmissible level becomes a fresh dimension that iterates directly
// over the level coordinates; loop dimensions that only occur in admissible
// expressions are kept but renumbered after the new level dimensions. The
// function returns a map from the new dimensions to the original ones, so
// composing idxMap with the result re-expresses the kernel in terms of the
// new dimensions. `itTps` is updated in place to the iterator types of the
// new dimension order.
static AffineMap
genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
                      SmallVector<utils::IteratorType> &itTps) {
  MLIRContext *ctx = idxMap.getContext();
  auto [inAdLvls, usedDims] = info;
  // Note that idxMap does not equal the dim2lvl map: it is computed by
  // composing idx2Dim(dim2Lvl), and the two only coincide when idx2Dim is
  // an identity map.
  AffineMap lvl2Idx = inferLvlToDim(idxMap, ctx);

  assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
  if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
    // This can happen when some dimensions are projected out, e.g., for
    //   idxMap = (i, j, k) -> (i, j)
    // the inferred lvl2Idx only recovers (i, j). Append the unused
    // dimensions at the end so that every dimension has an expression.
    SmallVector<AffineExpr> results;
    AffineDimCollector usedInLvl(idxMap.getNumDims());
    for (AffineExpr e : idxMap.getResults())
      usedInLvl.walkPostOrder(e);

    unsigned curUsedDimID = 0;
    unsigned curUnusedDimID = lvl2Idx.getNumDims();

    BitVector unused = usedInLvl.dims.flip();
    for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
      if (unused.test(i))
        results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
      else
        results.push_back(lvl2Idx.getResult(curUsedDimID++));
    }
    lvl2Idx =
        AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
  }
  assert(lvl2Idx.getNumResults() == idxMap.getNumDims());

  // The first inAdLvls.count() new dimensions replace the inadmissible
  // levels; the remaining new dimensions keep the untouched loop
  // dimensions.
  SmallVector<AffineExpr> results;
  SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
  SmallVector<utils::IteratorType> transItTps;

  unsigned curRepID = 0;
  unsigned curOriID = inAdLvls.count();

  for (unsigned l : inAdLvls.set_bits()) {
    // The new dimension `curRepID` iterates directly over the inadmissible
    // level `l`.
    AffineExpr lvlExp = idxMap.getResult(l);
    dimRep[l] = getAffineDimExpr(curRepID++, ctx);

    // The iterator type of the new dimension is inherited from the single
    // loop dimension used by the original level expression.
    AffineDimCollector collector(idxMap.getNumDims());
    collector.walkPostOrder(lvlExp);
    assert(collector.dims.count() == 1);
    transItTps.push_back(itTps[collector.dims.find_first()]);
  }

  for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
    if (usedDims.test(d)) {
      // Dimension `d` feeds an inadmissible level expression: express it
      // in terms of the new level-based dimensions through lvl2Idx.
      results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
    } else {
      // Dimension `d` is untouched: keep it, renumbered after the new
      // level-based dimensions.
      results.push_back(getAffineDimExpr(curOriID++, ctx));
      transItTps.push_back(itTps[d]);
    }
  }
  unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
  // Publish the iterator types for the new dimension order.
  itTps.assign(transItTps.begin(), transItTps.end());
  return AffineMap::get(numDim, 0, results, ctx);
}
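// Worked example (illustrative): for the 2x2 BSR index map
//   idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
// all four levels are inadmissible, so four fresh dimensions l0..l3 are
// introduced and the returned replacement map is
//   (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 2 + l3)
// Composing idxMap with it yields level expressions such as
// (l0 * 2 + l2) floordiv 2, which the constant propagation in translateMap
// below simplifies to l0 (given the static level size 2).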

// Translates the index maps in the linalg::GenericOp from dim-based to
// level-based, together with the iterator types. Returns std::nullopt if
// the maps can not be rewritten into an admissible form.
static std::optional<std::pair<ArrayAttr, ArrayAttr>>
translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
  MLIRContext *ctx = op.getContext();
  SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
  SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
  for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
    Value tensor = op->getOpOperand(i).get();
    auto stt = tryGetSparseTensorType(tensor);
    if (stt && !stt->isIdentity()) {
      AffineMap dim2Lvl = stt->getDimToLvl();
      // By composing dim2Lvl(idx2dim), we get an idx2lvl map.
      idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
    }
  }

  // A naive constant propagation for the common expressions created by the
  // dim2lvl composition, sufficient to simplify block-sparsity maps: for a
  // level dimension `lvl` with static size `lvlSz`, we know that
  //   lvl floordiv lvlSz == 0   and   lvl mod lvlSz == lvl.
  auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
                                  unsigned pos, int64_t lvlSz) {
    if (ShapedType::isStatic(lvlSz)) {
      auto c0 = getAffineConstantExpr(0, ctx);
      auto lvlExp = getAffineDimExpr(pos, ctx);
      auto szExp = getAffineConstantExpr(lvlSz, ctx);

      // lvl floordiv lvlSz = 0
      auto divExp =
          getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
      cstMapping.try_emplace(divExp, c0);

      // lvl mod lvlSz = lvl
      auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
      cstMapping.try_emplace(modExp, lvlExp);
    }
  };

  unsigned boundedNum = 0;
  // A fixed-point algorithm: repeatedly eliminate inadmissible expressions
  // until all maps are admissible (or until we discover that they can not
  // be made admissible).
  bool changed = true;
  while (changed) {
    changed = false;
    for (OpOperand &operand : op->getOpOperands()) {
      auto stt = tryGetSparseTensorType(operand.get());
      // Skip dense operands.
      if (!stt || !stt->getEncoding())
        continue;

      unsigned tid = operand.getOperandNumber();
      bool isOutput = &operand == op.getDpsInitOperand(0);
      AffineMap idxMap = idxMapArray[tid];
      InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
      auto [inAdLvls, dimExprs] = inAdInfo;
      for (unsigned d : dimExprs.set_bits()) {
        // The first `boundedNum` dimensions were introduced to resolve
        // previous inadmissible expressions; replacing them again could
        // bring those inadmissible expressions back.
        if (d < boundedNum)
          return std::nullopt;
      }

      if (inAdLvls.count() != 0) {
        // Collect the constant mapping imposed by the inadmissible levels.
        ArrayRef<int64_t> lvlShape = stt->getLvlShape();
        DenseMap<AffineExpr, AffineExpr> cstMapping;
        unsigned position = 0;
        for (unsigned lvl : inAdLvls.set_bits()) {
          int64_t lvlSz = lvlShape[lvl];
          populateCstMapping(cstMapping, position, lvlSz);
          position++;
        }

        AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
        // Compose the lvl2Idx map with all the index maps to eliminate the
        // inadmissible expressions.
        for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
          AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
          idxMapArray[tid] = transMap.replace(
              cstMapping, /*numResultDims=*/transMap.getNumDims(),
              /*numResultSyms=*/0);
        }
        changed = true;
        boundedNum += inAdLvls.count();
      }
    }
  }

  SmallVector<Attribute> iterAttr =
      llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
        return linalg::IteratorTypeAttr::get(ctx, itTp);
      });

  return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
                        rewriter.getArrayAttr(iterAttr));
}
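// End-to-end effect (illustrative sketch): a linalg.generic over a 2x2 BSR
// operand with indexing map (d0, d1) -> (d0, d1) first becomes, after
// composing dim2lvl,
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
// and after the fixed-point rewrite above it is expressed over four loops as
//   (l0, l1, l2, l3) -> (l0, l1, l2, l3)
// with iterator types inherited from d0 and d1, while a dense operand's map
// (d0, d1) -> (d0, d1) becomes (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 2 + l3).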

// Generates a "de"mapping reinterpretation of the map.
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return ReinterpretMapOp::create(builder, val.getLoc(), enc.withoutDimToLvl(),
                                  val);
}

// Generates a "re"mapping reinterpretation of the map.
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return ReinterpretMapOp::create(builder, val.getLoc(), enc, val);
}

// Remaps the values in `outs` whose types differ from the target `types`
// through reinterpret_map operations.
static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
                                          ValueRange outs) {
  SmallVector<Value> ret(outs);
  assert(outs.size() == types.size());
  for (auto [r, t] : llvm::zip(ret, types))
    if (r.getType() != t)
      r = ReinterpretMapOp::create(rewriter, r.getLoc(), t, r);
  return ret;
}

namespace {

//===----------------------------------------------------------------------===//
// Rewriting rules for linalg generic ops.
//===----------------------------------------------------------------------===//

/// Sparse rewriting rule for the generic `linalg` operation: rewrites the
/// operation to use demapped (level-space) operands and level-based index
/// maps.
struct GenericOpReinterpretMap
    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
public:
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only rewrite single-output operations with pure tensor semantics that
    // carry at least one sparse operand with a non-identity dim2lvl map.
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        !hasAnySparseOperandOrResult(linalgOp) ||
        !hasAnyNonIdentityOperandsOrResults(linalgOp))
      return failure();

    // Try translating the index maps.
    auto transMap = translateMap(linalgOp, rewriter);
    if (!transMap)
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel can not be sparsified.");

    // On success, update the linalg op in place.
    Value res = linalgOp.getResult(0);
    auto stt = tryGetSparseTensorType(res);
    auto [idxMap, itTp] = *transMap;

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(idxMap);
    linalgOp.setIteratorTypesAttr(itTp);
    // Use the demapped arguments.
    linalgOp.getInputsMutable().assign(adaptor.getInputs());
    linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
    res.setType(adaptor.getOutputs()[0].getType());
    rewriter.finalizeOpModification(linalgOp);

    // Remap the result back to the original type.
    rewriter.setInsertionPointAfter(linalgOp);
    if (stt && stt->hasEncoding()) {
      Value t = genRemap(rewriter, stt->getEncoding(), res);
      rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
    }
    return success();
  }
};
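// IR effect (illustrative sketch; types abbreviated): the pattern rewrites
//
//   %0 = linalg.generic {indexing_maps = [...dim-based maps...]}
//          ins(%t : tensor<8x8xf64, #BSR>) outs(...)
//
// into a generic over demapped (level-space) operands,
//
//   %d = sparse_tensor.reinterpret_map %t : ... to tensor<4x4x2x2xf64, ...>
//   %1 = linalg.generic {indexing_maps = [...lvl-based maps...]}
//          ins(%d : ...) outs(...)
//   %2 = sparse_tensor.reinterpret_map %1 : ... back to the original type
//
// so that all later sparsification passes only see identity dim2lvl maps.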

struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
  GenericOpScheduler(MLIRContext *context,
                     sparse_tensor::LoopOrderingStrategy strategy)
      : OpRewritePattern<linalg::GenericOp>(context), strategy(strategy) {}

  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
        !hasAnySparseOperandOrResult(linalgOp)) {
      return failure();
    }

    const StringRef sorted = "sorted";
    if (linalgOp->hasAttr(sorted))
      return failure();

    auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp, strategy);
    bool isAdmissible = false;
    AffineMap order;
    // A list of all sort masks used for iteration-graph computation, ordered
    // from the least to the most strict constraints: the earlier a mask can
    // be satisfied, the faster the generated kernel is expected to be.
    const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
                           SortMask::kIncludeDenseInput,
                           SortMask::kIncludeDenseOutput,
                           SortMask::kSparseOnly};
    for (const SortMask mask : allMasks) {
      order = scheduler.sort(mask);
      if (order) {
        if (isAdmissibleOrder(linalgOp, order)) {
          isAdmissible = true;
          break;
        }
        // Else: try a set of less strict constraints.
      }
    }

    if (!order) {
      // Give it one last shot to resolve the cycle.
      if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
        return rewriter.notifyMatchFailure(
            linalgOp, "the sparse kernel can not be scheduled: loop detected.");
      }
      return success();
    }

    if (!isAdmissible) {
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel can not be scheduled.");
    }

    // Marks the GenericOp to avoid recursive matching.
    rewriter.modifyOpInPlace(linalgOp, [&]() {
      linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
    });

    // Already sorted.
    if (order.isIdentity())
      return failure();

    assert(order.isPermutation());
    // `order` lists the original loop at every sorted position; use it to
    // permute the iterator types.
    ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
    SmallVector<Attribute> curItTypes;
    curItTypes.reserve(preItTypes.size());
    for (AffineExpr expr : order.getResults()) {
      unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
      curItTypes.push_back(preItTypes[loopID]);
    }

    // Invert `order` to get a map from the original to the sorted loops,
    // then rewrite every indexing map into sorted-loop space.
    order = inversePermutation(order);
    SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
    for (AffineMap &idxMap : idxMaps)
      idxMap = idxMap.compose(order); // sorted loop -> lvl map

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
    linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
    rewriter.finalizeOpModification(linalgOp);

    return success();
  }

private:
  /// Whether the loop order is admissible for a (possibly) sparse output.
  static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
    if (!hasAnySparseResult(linalgOp))
      return true;

    OpOperand *lhs = linalgOp.getDpsInitOperand(0);
    unsigned nest = 0;
    const auto iteratorTypes = linalgOp.getIteratorTypesArray();
    for (const AffineExpr l : order.getResults()) {
      unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
      auto itTp =
          cast<linalg::IteratorTypeAttr>(linalgOp.getIteratorTypes()[loopId]);
      if (linalg::isReductionIterator(itTp.getValue()))
        break; // terminate at the first reduction
      nest++;
    }
    // Determine admissible dynamic insertion situations: either the output
    // is fully covered by leading parallel loops (fully injective), or only
    // the innermost output dimension is expanded dynamically.
    return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
  }
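  // For example (illustrative): for a rank-2 sparse output, an order whose
  // first loop is parallel (nest >= 1) is admissible, since at most the
  // innermost output dimension is then expanded dynamically, whereas an
  // order that starts with a reduction loop (nest == 0) is rejected.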

  // Last-resort cycle resolution.
  static LogicalResult resolveCycle(IterationGraphSorter &scheduler,
                                    linalg::LinalgOp linalgOp,
                                    PatternRewriter &rewriter) {
    // Compute a topological sort while leaving out each sparse input tensor
    // in succession until an acyclic iteration graph results.
    for (OpOperand *t : linalgOp.getDpsInputOperands()) {
      Value tval = t->get();
      auto srcEnc = getSparseTensorEncoding(tval.getType());
      // The constraints introduced by compound index expressions are
      // complicated; skip them.
      AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
      bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
        return !llvm::isa<AffineDimExpr>(exp);
      });
      if (!srcEnc || hasCompExpr)
        continue;

      // Try scheduling the loops without the constraints imposed by `tval`.
      AffineMap order = scheduler.sort(SortMask::kSparseOnly, tval);
      if (!order) // still cyclic
        continue;

      // Found a candidate tensor that resolves the cycle: insert a
      // conversion into a sparse tensor whose level order adheres to the
      // iteration graph order.
      auto stt = getSparseTensorType(tval);
      assert(stt.isIdentity());
      // Gets the sorted loop -> lvl map for `tval`.
      order = inversePermutation(order);
      idxMap = idxMap.compose(order);

      // Infer a permutation (dimToLvl map) such that the level order of the
      // converted tensor matches the scheduled loop order: sort the levels
      // by the loop position that drives them.
      SmallVector<std::pair<unsigned, unsigned>> lvlSeq;
      for (AffineExpr expr : idxMap.getResults()) {
        unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
        lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
      }
      llvm::sort(lvlSeq, llvm::less_first());
      SmallVector<unsigned> perm =
          llvm::to_vector(llvm::make_second_range(lvlSeq));
      auto dimToLvl = AffineMap::getPermutationMap(perm, linalgOp.getContext());
      // The results of the idxMap must be unsorted (otherwise there would
      // have been no cycle to resolve).
      assert(!dimToLvl.isIdentity());

      // Insert the conversion to the permuted tensor before the linalg op.
      rewriter.setInsertionPoint(linalgOp);
      RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
      Value dst = ConvertOp::create(rewriter, tval.getLoc(), dstTp, tval);
      rewriter.modifyOpInPlace(linalgOp, [&]() {
        linalgOp->setOperand(t->getOperandNumber(), dst);
      });

      // Release the temporary tensor afterwards.
      rewriter.setInsertionPointAfter(linalgOp);
      bufferization::DeallocTensorOp::create(rewriter, dst.getLoc(), dst);

      return success();
    }

    // The cycle can not be resolved by inserting a single conversion.
    return failure();
  }

private:
  sparse_tensor::LoopOrderingStrategy strategy;
};
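// Cycle resolution example (illustrative): for a kernel such as
//   C(i, j) = A(i, j) * B(j, i)
// with both A and B sparse, the iteration graph imposed by the two level
// orders is cyclic (i before j, and j before i). Leaving B out yields the
// order (i, j), under which B's levels are visited as (j, i), i.e.,
// unsorted; the rewriter therefore materializes something like
//   %b2 = sparse_tensor.convert %b   // dimToLvl permutes (i, j) -> (j, i)
// before the kernel and a matching bufferization.dealloc_tensor after it.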

//===----------------------------------------------------------------------===//
// Reinterpret Map Rewriters for operations other than linalg.generics.
//===----------------------------------------------------------------------===//

template <typename AllocOp>
struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
  using OpRewritePattern<AllocOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AllocOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());

    // Compute the maximal coordinate for every dimension.
    SmallVector<Value> maxDimCrds;
    maxDimCrds.reserve(stt.getDimRank());
    ValueRange dynSz = op.getDynamicSizes();
    for (int64_t dimSz : stt.getDimShape()) {
      if (ShapedType::isDynamic(dimSz)) {
        Value maxCrd = arith::SubIOp::create(rewriter, loc, dynSz.front(),
                                             constantIndex(rewriter, loc, 1));
        maxDimCrds.push_back(maxCrd);
        dynSz = dynSz.drop_front();
      } else {
        maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
      }
    }

    // Translate them into maximal level coordinates and derive the dynamic
    // level sizes.
    ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
                                              CrdTransDirectionKind::dim2lvl);
    auto lvlShape = stt.getLvlShape();
    SmallVector<Value> dynLvlSzs;
    for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
      if (ShapedType::isDynamic(lvlShape[i])) {
        Value sz = arith::AddIOp::create(rewriter, loc, maxLvlCrds[i],
                                         constantIndex(rewriter, loc, 1));
        dynLvlSzs.push_back(sz);
      }
    }

    assert(dynSz.empty()); // should have consumed all dynamic sizes.
    rewriter.startOpModification(op);
    op->setOperands(dynLvlSzs);
    op.getResult().setType(stt.getDemappedType());
    rewriter.finalizeOpModification(op);
    rewriter.setInsertionPointAfter(op);

    Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
    return success();
  }
};
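// Example (illustrative): for tensor.empty(%n) : tensor<?x8xf64, #BSR>, the
// maximal dimension coordinates are (%n - 1, 7); translating them through
// dim2lvl gives the maximal level coordinates
//   ((%n - 1) floordiv 2, 3, (%n - 1) mod 2, 1)
// so the demapped allocation gets the single dynamic level size
// ((%n - 1) floordiv 2) + 1 for its tensor<?x4x2x2xf64, ...> result.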

struct TensorInsertDemapper
    : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnySparseResult(op) || !hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                          CrdTransDirectionKind::dim2lvl);
    auto insertOp = tensor::InsertOp::create(rewriter, loc, op.getScalar(),
                                             adaptor.getDest(), lvlCrd);

    Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
    rewriter.replaceOp(op, out);
    return success();
  }
};
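// Example (illustrative): `tensor.insert %v into %t[%i, %j]` on a BSR tensor
// becomes an insertion at the translated level coordinates
//   [%i floordiv 2, %j floordiv 2, %i mod 2, %j mod 2]
// into the demapped destination, followed by a reinterpret_map back to the
// original BSR type.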

struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AssembleOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    assert(hasAnySparseResult(op));
    auto stt = getSparseTensorType(op.getResult());
    rewriter.modifyOpInPlace(
        op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); });
    rewriter.setInsertionPointAfter(op);
    Value out = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op.getResult(), out, out.getDefiningOp());
    return success();
  }
};

struct SparseDisassembleDemapper
    : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    assert(hasAnySparseOperandOrResult(op));
    rewriter.modifyOpInPlace(op, [&op, &adaptor]() {
      op.getTensorMutable().assign(adaptor.getTensor());
    });
    return success();
  }
};

struct ForeachOpDemapper
    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only handle operations with a sparse input/output that carries a
    // non-identity dim2lvl map.
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    // TODO: demap constants as well.
    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
        return failure();

    Location loc = op.getLoc();
    // Cache the type information, since the foreach op is updated in-place.
    auto srcStt = getSparseTensorType(op.getTensor());
    SmallVector<Type> prevRetTps(op.getResultTypes());

    rewriter.startOpModification(op);
    op.getTensorMutable().assign(adaptor.getTensor());
    op.getInitArgsMutable().assign(adaptor.getInitArgs());
    // Update the result types to the demapped types.
    for (auto r : op.getResults())
      if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
        r.setType(stt->getDemappedType());

    Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
    // Update the body block arguments: add level coordinates, the value and
    // the iteration arguments for the demapped tensor.
    SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
    blockArgTps.push_back(srcStt.getElementType());
    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
                       adaptor.getInitArgs().getTypes().end());
    Block *body = op.getBody();
    // Block arguments are: [dimCrds, val, initArgs].
    unsigned preArgNum = body->getNumArguments();
    for (Type t : blockArgTps)
      body->addArgument(t, loc);

    // Block arguments are now [dimCrds, val, initArgs, lvlCrds, val,
    // initArgs]: rewrite all uses of `dimCrds` through the lvl2dim
    // translation of `lvlCrds`.
    rewriter.setInsertionPointToStart(body);
    ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);
    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
                                              CrdTransDirectionKind::lvl2dim);
    rewriter.replaceAllUsesWith(
        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
    body->eraseArguments(0, srcStt.getDimRank());
    // Block arguments are now [val, initArgs, lvlCrds, val, initArgs].
    unsigned numInitArgs = op.getInitArgs().size();
    rewriter.replaceAllUsesWith(body->getArgument(0),
                                body->getArgument(lvlRank + numInitArgs + 1));
    body->eraseArgument(0);
    // Block arguments are now [initArgs, lvlCrds, val, initArgs]: remap the
    // new iteration arguments back before replacing the old ones.
    ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
    ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
    SmallVector<Value> reMappedArgs =
        remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
    rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
    body->eraseArguments(0, numInitArgs);
    // Block arguments are now [lvlCrds, val, initArgs].

    // Update the yield value if the body yields a sparse tensor that needs
    // demapping.
    if (numInitArgs != 0) {
      rewriter.setInsertionPointToEnd(body);
      auto yield = llvm::cast<YieldOp>(body->getTerminator());
      if (auto stt = tryGetSparseTensorType(yield.getSingleResult());
          stt && !stt->isIdentity()) {
        Value y =
            genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
        YieldOp::create(rewriter, loc, y);
        rewriter.eraseOp(yield);
      }
    }
    rewriter.finalizeOpModification(op);

    rewriter.setInsertionPointAfter(op);
    SmallVector<Value> outs =
        remapValueRange(rewriter, prevRetTps, op.getResults());

    // Replace all uses of the foreach results, except the uses in the
    // reinterpret_maps that perform the remapping.
    for (auto [from, to] : llvm::zip(op.getResults(), outs))
      rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());

    return success();
  }
};
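// Example (illustrative): for a foreach over a rank-2 BSR tensor with one
// init argument, the body signature changes from
//   ^bb0(%i, %j, %v, %arg)               // dimension coordinates
// to
//   ^bb0(%l0, %l1, %l2, %l3, %v, %arg)   // level coordinates
// with the old uses of %i and %j rewritten to the lvl2dim translation of
// the new level coordinates.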

} // namespace

//===----------------------------------------------------------------------===//
// Public method implementation.
//===----------------------------------------------------------------------===//

void mlir::populateSparseReinterpretMap(
    RewritePatternSet &patterns, ReinterpretMapScope scope,
    sparse_tensor::LoopOrderingStrategy strategy) {
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kGenericOnly) {
    patterns.add<GenericOpReinterpretMap>(patterns.getContext());
    patterns.add<GenericOpScheduler>(patterns.getContext(), strategy);
  }
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kExceptGeneric) {
    patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
                 TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper,
                 SparseDisassembleDemapper, TensorInsertDemapper,
                 ForeachOpDemapper>(patterns.getContext());
  }
}
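// Typical driver usage (a sketch; the driver function name and setup are
// assumptions, not part of this file):
//
//   RewritePatternSet patterns(ctx);
//   populateSparseReinterpretMap(patterns, ReinterpretMapScope::kAll);
//   (void)applyPatternsGreedily(op, std::move(patterns));
//
// running the demapping patterns to a fixed point before sparsification.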
