MLIR: lib/Dialect/SparseTensor/Transforms/Sparsification.cpp Source File

//===- Sparsification.cpp - Implementation of sparsification -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements converting sparse tensor types to actual sparse code.
//
//===----------------------------------------------------------------------===//

// ... (project includes) ...

#include "llvm/ADT/SmallBitVector.h"

#include <optional>

using namespace mlir;

using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Sparsifier analysis methods.
//===----------------------------------------------------------------------===//

/// Returns true iff affine expression is invariant. Sets the parameter
/// `isCurrentLoop` when the expression just became invariant.
static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    const LoopId i = cast<AffineDimExpr>(a).getPosition();
    if (i + 1 == curr) {
      isCurrentLoop = true;
      return true; // becomes invariant at current loop
    }
    return i < curr; // invariant when already generated
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return isInvariantAffine(binOp.getLHS(), curr, isCurrentLoop) &&
           isInvariantAffine(binOp.getRHS(), curr, isCurrentLoop);
  }
  default: {
    assert(isa<AffineConstantExpr>(a));
    return true;
  }
  }
}
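For intuition: a dimension d_i is treated as invariant once the loop for i has
been emitted, and `isCurrentLoop` flags the case where the innermost such loop
is exactly the current one. A minimal standalone probe (hypothetical driver
code, not part of this file):

  // Probe invariance of d0 + d1 right after the second loop was generated.
  MLIRContext ctx;
  AffineExpr d0 = getAffineDimExpr(0, &ctx);
  AffineExpr d1 = getAffineDimExpr(1, &ctx);
  bool isCurrentLoop = false;
  bool inv = isInvariantAffine(d0 + d1, /*curr=*/2, isCurrentLoop);
  // inv == true (both dims are bound) and isCurrentLoop == true (d1 + 1 == 2).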

/// Helper method to inspect affine expressions. Rejects cases where the
/// same index is used more than once. Also rejects compound affine
/// expressions in sparse levels.
static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
                       LevelType lt, bool setLvlFormat = true) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
    if (!isUndefLT(merger.getLvlType(tid, idx)))
      return false; // used more than once
    if (setLvlFormat)
      merger.setLevelAndType(tid, idx, lvl, lt);
    return true;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul:
  case AffineExprKind::Constant: {
    assert(lt.hasDenseSemantic());
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
      // We do not set the level format for affine expressions like d0 + d1 on
      // either loop index; the recursion merely checks whether the current
      // affine expression is admissible or not.
      return findAffine(merger, tid, lvl, binOp.getLHS(), lt, false) &&
             findAffine(merger, tid, lvl, binOp.getRHS(), lt, false);
    }
    // Falls through when it is a constant affine expression.
    return true;
  }
  default:
    return false;
  }
}
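As an illustration (hypothetical driver code, not from this file), an indexing
map whose first result is the compound expression d0 + d1 is only admitted
when that result addresses a dense-like level; compound expressions landing on
sparse levels are rejected elsewhere (see getNumNonTrivialIdxExpOnSparseLvls
below for the output side):

  // Build (d0, d1) -> (d0 + d1, d1); findAffine recurses into d0 + d1 with
  // setLvlFormat == false and only checks admissibility of the sub-exprs.
  MLIRContext ctx;
  AffineExpr d0 = getAffineDimExpr(0, &ctx);
  AffineExpr d1 = getAffineDimExpr(1, &ctx);
  AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                                 {d0 + d1, d1}, &ctx);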

/// Helper method to inspect affine expressions for index variable reduction
/// based codegen. It finds the dependent index set for all tensor levels in
/// the current expression we are generating. For example, when handling
/// A[i+j][j+k], we build the two-way mapping in the merger between
/// (tensor, level) pairs and their dependent index variable sets:
/// A_0 <=> [i, j] and A_1 <=> [j, k].
static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
                          AffineExpr a, LevelType lt, bool isSubExp = false,
                          int64_t coefficient = 1) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    // Only allow positive coefficients on AffineDimExpr.
    if (coefficient <= 0)
      return false;

    const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
    if (!isUndefLT(merger.getLvlType(tensor, idx)))
      return false; // used more than once, e.g., A[i][i]

    if (!isSubExp) {
      assert(coefficient == 1);
      merger.setLevelAndType(tensor, idx, lvl, lt);
    }

    if (isSubExp) {
      // The current loop appears in more than one affine expression on the
      // same tensor. We can not handle this case, e.g., A[i+j][i+k], where
      // `i` is used twice.
      if (merger.hasDependentLvl(idx, tensor)) {
        return false;
      }
      merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient);
    }
    return true;
  }
  case AffineExprKind::Constant:
  case AffineExprKind::Mul: {
    // TODO: Support constant AffineExpr for slice-based codegen.
    if (isa<AffineConstantExpr>(a))
      llvm_unreachable("Not yet implemented");

    auto binOp = cast<AffineBinaryOpExpr>(a);
    auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
    if (isa<AffineConstantExpr>(rhs))
      std::swap(lhs, rhs);
    // Must be in form of `constant * d`.
    assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs));
    int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue();
    return findDepIdxSet(merger, tensor, lvl, rhs, lt, isSubExp, coefficient);
  }
  case AffineExprKind::Add: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), lt, true) &&
           findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), lt, true);
  }
  default:
    return false;
  }
}
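The Mul case relies on MLIR's canonical form where the constant of a
multiplication sits on the RHS (see the AffineExprKind notes at the end of
this page), then swaps it to the left before extracting the coefficient. A
standalone sketch of that step (hypothetical driver code):

  // Extract the coefficient of 2 * d0 (canonicalized by MLIR to d0 * 2).
  MLIRContext ctx;
  AffineExpr d0 = getAffineDimExpr(0, &ctx);
  auto mul = cast<AffineBinaryOpExpr>(d0 * 2);
  AffineExpr lhs = mul.getLHS(), rhs = mul.getRHS();
  if (isa<AffineConstantExpr>(rhs))
    std::swap(lhs, rhs); // ensure `constant * dim` form
  int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue(); // == 2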

/// Gets the total number of compound affine expressions in the
/// `getMatchingIndexingMap` for the given tensor.
static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
                                                   Value tensor) {
  // The `tensor` is not guaranteed to have `RankedTensorType`, therefore
  // we can't use `getRankedTensorType`/`getSparseTensorType` here.
  const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
  if (!rtp)
    return 0;
  const SparseTensorType stt(rtp);
  const Level lvlRank = stt.getLvlRank();
  const auto exprs = map.getResults();
  assert(static_cast<Dimension>(exprs.size()) == lvlRank &&
         "AffineMap does not have dimension-rank many results");
  unsigned num = 0;
  for (Level l = 0; l < lvlRank; l++) {
    if (!isa<AffineDimExpr>(exprs[l]) && !stt.getLvlType(l).hasDenseSemantic())
      num++;
  }
  return num;
}

/// Gets the total number of sparse levels with compound affine
/// expressions, summed over all operands of the `GenericOp`.
static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
  unsigned num = 0;
  for (OpOperand &t : op->getOpOperands())
    num += getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(&t),
                                              t.get());
  return num;
}

static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
  OpOperand *out = op.getDpsInitOperand(0);
  if (getSparseTensorType(out->get()).isAllDense())
    return false;
  return getNumNonTrivialIdxExpOnSparseLvls(op.getMatchingIndexingMap(out),
                                            out->get());
}
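A cheap way to see what these helpers reject is to count the non-trivial
results of an indexing map directly. A hedged sketch (hypothetical helper
name; it ignores level types, so it over-approximates the real check above):

  // Count output results that are not plain dimension identifiers.
  unsigned countCompoundResults(linalg::GenericOp op) {
    OpOperand *out = op.getDpsInitOperand(0);
    AffineMap map = op.getMatchingIndexingMap(out);
    unsigned n = 0;
    for (AffineExpr e : map.getResults())
      if (!isa<AffineDimExpr>(e))
        n++; // compound (add/mul/constant) index expression
    return n;
  }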

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-level sparsity information for all tensors.
/// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissible.
static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
  bool annotated = false;
  for (OpOperand &t : env.op()->getOpOperands()) {
    const TensorId tid = env.makeTensorId(t.getOperandNumber());
    const auto map = env.op().getMatchingIndexingMap(&t);
    const auto enc = getSparseTensorEncoding(t.get().getType());
    if (enc)
      annotated = true;
    const Level lvlRank = map.getNumResults();
    assert(!enc || lvlRank == enc.getLvlRank());
    assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
    // We only need to do index reduction if there is at least one
    // non-trivial index expression on sparse levels. If all non-trivial
    // index expressions are on dense levels, we can efficiently rely on
    // the random access to locate the element.
    bool needIdxReduc =
        enc && getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) != 0;
    // If the current tensor requires index reduction, levels are resolved
    // through the dependent index set instead.
    for (Level l = 0; l < lvlRank; l++) {
      const AffineExpr a = map.getResult(l);
      const LevelType lt = enc.getLvlType(l);
      if (idxReducBased && needIdxReduc) {
        if (!findDepIdxSet(env.merger(), tid, l, a, lt))
          return false; // inadmissible affine expression
      } else {
        if (!findAffine(env.merger(), tid, l, a, lt))
          return false; // inadmissible affine expression
      }
    }
  }
  return annotated;
}

//===----------------------------------------------------------------------===//
// Sparsifier synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===//

/// Local bufferization of all dense and sparse data structures.
static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  assert(op.getNumOperands() == op.getNumDpsInputs() + 1);

  SmallVector<Range, 4> loopRange =
      llvm::cast<linalg::LinalgOp>(op.getOperation())
          .createLoopRanges(builder, loc);

  env.emitter().initializeLoopEmit(
      builder, loc,
      /*updater=*/[&op](OpBuilder &builder, Location loc, Value memref,
                        Value tensor) {
        // Must not be a sparse tensor.
        assert(!getSparseTensorEncoding(tensor.getType()));
        // Two output tensor references should point to the same object.
        OpOperand *lhs = op.getDpsInitOperand(0);
        assert(lhs->get() == tensor);
        // An output tensor can simply materialize from the buffer of the
        // tensor that appears in the outs() clause. For updates, this has
        // the advantage that it amortizes the cost of memory allocation and
        // initialization of the buffer over many loads/stores.
        bool isInit = op.isInitTensor(lhs);
        Value init = memref;
        if (!isInit) {
          Value zero = constantZero(builder, loc,
                                    getElementTypeOrSelf(tensor.getType()));
          builder.create<linalg::FillOp>(loc, ValueRange{zero},
                                         ValueRange{init});
        }
        return init;
      },
      /*synSetter=*/[&loopRange](OpBuilder &b, Location loc, Level l) {
        assert(l < loopRange.size());
        return getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size);
      });
}

/// Generates index for load/store on sparse tensor.
static Value genIndex(CodegenEnv &env, OpOperand *t) {
  const auto map = env.op().getMatchingIndexingMap(t);
  const auto stt = getSparseTensorType(t->get());
  const Level lvlRank = stt.getLvlRank();
  assert(static_cast<Level>(map.getNumResults()) == lvlRank);
  const AffineExpr a = map.getResult(lvlRank - 1);
  assert(a.getKind() == AffineExprKind::DimId);
  const LoopId idx = env.makeLoopId(cast<AffineDimExpr>(a).getPosition());
  return env.getLoopVar(idx);
}

/// Generates subscript for load/store on a dense or sparse tensor.
static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                          SmallVectorImpl<Value> &args) {
  const Location loc = env.op().getLoc();
  const TensorId tid = env.makeTensorId(t->getOperandNumber());
  const auto map = env.op().getMatchingIndexingMap(t);
  const auto stt = getSparseTensorType(t->get());
  if (stt.hasEncoding()) {
    // For sparse tensors we only push the last-level's position onto `args`.
    const auto pos = env.emitter().getValPosits(tid);
    assert(!pos.empty());
    args.append(pos);
    // Simply returns the tensor to extract value using iterators.
    if (env.options().sparseEmitStrategy == SparseEmitStrategy::kSparseIterator)
      return t->get();
  } else {
    // For dense tensors we push all level's coordinates onto `args`.
    const Level lvlRank = stt.getLvlRank();
    assert(static_cast<Level>(map.getNumResults()) == lvlRank);
    for (Level l = 0; l < lvlRank; l++) {
      const auto lvlExpr = map.getResult(l);
      const auto lvlCrd = env.emitter().genAffine(builder, loc, lvlExpr);
      args.push_back(lvlCrd);
    }
  }
  return env.emitter().getValBuffer()[tid];
}

/// Generates insertion code to implement dynamic tensor load.
static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder,
                              OpOperand *t) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  // Direct lexicographic coordinate order, tensor loads as zero.
  if (!env.isExpand()) {
    Type tp = getElementTypeOrSelf(t->get().getType());
    return constantZero(builder, loc, tp);
  }
  // Load from expanded access pattern.
  Value index = genIndex(env, t);
  return builder.create<memref::LoadOp>(loc, env.getExpandValues(), index);
}

/// Generates insertion code to implement dynamic tensor load for reduction.
static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
                                    OpOperand *t) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  Value identity = env.getCustomRedId();
  // Direct lexicographic coordinate order, tensor loads as the identity.
  if (!env.isExpand())
    return identity;
  // Load from expanded access pattern if filled, identity otherwise.
  Value values = env.getExpandValues();
  Value filled = env.getExpandFilled();
  Value index = genIndex(env, t);
  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
  Value valAtIndex = builder.create<memref::LoadOp>(loc, values, index);
  return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
}

static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
                                  Value sparseOut, ValueRange ivs, Value v) {
  scf::IfOp condInsert =
      builder.create<scf::IfOp>(loc, sparseOut.getType(), cond, true);
  // True branch.
  builder.setInsertionPointToStart(condInsert.thenBlock());
  Value res = builder.create<tensor::InsertOp>(loc, v, sparseOut, ivs);
  builder.create<scf::YieldOp>(loc, res);
  // False branch.
  builder.setInsertionPointToStart(condInsert.elseBlock());
  builder.create<scf::YieldOp>(loc, sparseOut);
  // Value assignment.
  builder.setInsertionPointAfter(condInsert);
  return condInsert.getResult(0);
}
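The same two-armed scf.if idiom recurs throughout this file: create the op
with result types, fill each region, then continue after it. A hedged
standalone sketch of the pattern (builder, loc, cond, a, b assumed in scope):

  // Select between two index values with an scf.if.
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond,
                                             /*withElseRegion=*/true);
  builder.setInsertionPointToStart(ifOp.thenBlock());
  builder.create<scf::YieldOp>(loc, a); // yield `a` when cond holds
  builder.setInsertionPointToStart(ifOp.elseBlock());
  builder.create<scf::YieldOp>(loc, b); // yield `b` otherwise
  builder.setInsertionPointAfter(ifOp);
  Value result = ifOp.getResult(0);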

/// Generates insertion code to implement dynamic tensor store.
static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                              Value rhs) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  // Direct insertion in lexicographic coordinate order.
  if (!env.isExpand()) {
    const LoopId numLoops = op.getRank(t);
    // Retrieves the first `numLoops` induction variables.
    SmallVector<Value> ivs = llvm::to_vector(
        llvm::drop_end(env.emitter().getLoopIVsRange(),
                       env.getCurrentDepth() - numLoops));
    Value chain = env.getInsertionChain();
    if (env.isValidLexInsert()) {
      // Generates a runtime check for a valid lex insert during reduction:
      //   if (validLexInsert) then
      //     insert the value
      //   else
      //     return invariant sparse tensor
      Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(),
                                       chain, ivs, rhs);
      env.updateInsertionChain(out);
    } else {
      Value sparseOut;
      if (!hasAnySparseType(env.op().getInputs().getTypes())) {
        // This is an all-dense -> sparse kernel, test rhs != 0 before
        // insertion.
        Value nz = genIsNonzero(builder, loc, rhs);
        sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs);
      } else {
        sparseOut = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
      }
      env.updateInsertionChain(sparseOut);
    }
    return;
  }
  // Generates insertion code along expanded access pattern.
  //   if (!expFilled[i]) then
  //     expFilled[i] = true
  //     expAdded[inserts++] = i
  //   endif
  //   values[i] = rhs
  Value values = env.getExpandValues();
  Value filled = env.getExpandFilled();
  Value added = env.getExpandAdded();
  Value count = env.getExpandCount();
  Value index = genIndex(env, t);
  Value fval = constantI1(builder, loc, false);
  Value tval = constantI1(builder, loc, true);
  // If statement.
  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                             isFilled, fval);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond,
                                             /*else=*/true);
  // True branch.
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  builder.create<memref::StoreOp>(loc, tval, filled, index);
  builder.create<memref::StoreOp>(loc, index, added, count);
  Value one = constantIndex(builder, loc, 1);
  Value add = builder.create<arith::AddIOp>(loc, count, one);
  builder.create<scf::YieldOp>(loc, add);
  // False branch.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, count);
  builder.setInsertionPointAfter(ifOp);
  // Value assignment.
  env.updateExpandCount(ifOp.getResult(0));
  builder.create<memref::StoreOp>(loc, rhs, values, index);
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = env.exp(exp).val;
  if (val)
    return val;
  // Get tensor operand.
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
  const auto stt = getSparseTensorType(t->get());
  // Fold binary-valued tensor into explicit value.
  if (auto explVal = stt.getExplicitVal())
    return genValFromAttr(builder, loc, explVal);
  // Load during insertion.
  if (env.isSparseOutput(t)) {
    if (env.isCustomReduc())
      return genInsertionLoadReduce(env, builder, t);
    return genInsertionLoad(env, builder, t);
  }
  // Actual load.
  SmallVector<Value> args;
  Value ptr = genSubscript(env, builder, t, args);
  if (llvm::isa<TensorType>(ptr.getType())) {
    // Extracts the value from the sparse iteration space (iterator-based
    // code generation).
    return builder.create<ExtractValOp>(loc, ptr, llvm::getSingleElement(args));
  }
  return builder.create<memref::LoadOp>(loc, ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                           Value rhs) {
  // Only unary and binary are allowed to return an uninitialized rhs
  // to indicate missing output, or otherwise a custom reduction that
  // has not been completed.
  if (!rhs) {
    assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
           env.exp(exp).kind == TensorExp::Kind::kBinary ||
           env.isCustomReduc());
    return;
  }
  // Test if this is a scalarized reduction.
  if (env.isReduc()) {
    env.updateReduc(rhs);
    return;
  }
  // Regular store.
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  OpOperand *t = op.getDpsInitOperand(0);
  if (!env.isSparseOutput(t)) {
    SmallVector<Value> args;
    Value ptr = genSubscript(env, builder, t, args);
    builder.create<memref::StoreOp>(loc, rhs, ptr, args);
    return;
  }
  // Store during sparse insertion.
  if (env.exp(exp).kind != TensorExp::Kind::kSelect) {
    genInsertionStore(env, builder, t, rhs);
    return;
  }
  // Select operation insertion.
  Value chain = env.getInsertionChain();
  scf::IfOp ifOp =
      builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  // Existing value was preserved to be used here.
  assert(env.exp(exp).val);
  Value v0 = env.exp(exp).val;
  genInsertionStore(env, builder, t, v0);
  env.merger().clearExprValue(exp);
  // Yield modified insertion chain along true branch.
  Value mchain = env.getInsertionChain();
  builder.create<scf::YieldOp>(op.getLoc(), mchain);
  // Yield original insertion chain along false branch.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, chain);
  // End if-statement.
  builder.setInsertionPointAfter(ifOp);
  env.updateInsertionChain(ifOp->getResult(0));
}

/// Generates an invariant value.
inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
  return env.exp(exp).val;
}

/// Semi-ring branches are simply inlined by the sparsifier. Prior
/// analysis has verified that all computations are "local" to the inlined
/// branch or otherwise invariantly defined outside the loop nest, with the
/// exception of index computations, which need to be relinked to actual
/// inlined cloned code.
static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
                          Value e) {
  if (auto arg = dyn_cast<BlockArgument>(e)) {
    // Direct arguments of the original linalg op must be converted
    // into dense tensor loads. Note that we should not encounter
    // anything else.
    linalg::GenericOp op = env.op();
    if (arg.getOwner()->getParentOp() == op) {
      const TensorId tid = env.makeTensorId(arg.getArgNumber());
      OpOperand *t = &op->getOpOperand(tid);
      assert(!getSparseTensorType(t->get()).hasEncoding()); // dense!
      SmallVector<Value> args;
      Value ptr = genSubscript(env, rewriter, t, args);
      return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
    }
  } else if (Operation *def = e.getDefiningOp()) {
    // Handle index computation.
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return env.getLoopVar(env.makeLoopId(indexOp.getDim()));
    // When still defined in new body, recurse into operands.
    if (def->getBlock() == block) {
      rewriter.setInsertionPoint(def);
      for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
        rewriter.modifyOpInPlace(def, [&]() {
          def->setOperand(
              i, relinkBranch(env, rewriter, block, def->getOperand(i)));
        });
      }
    }
  }
  return e;
}

/// Recursively generates tensor expression.
static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
  if (e == ::mlir::sparse_tensor::detail::kInvalidId)
    return Value();

  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  const TensorExp &exp = env.exp(e);
  const auto kind = exp.kind;
  if (kind == TensorExp::Kind::kTensor)
    return genTensorLoad(env, rewriter, e);
  if (kind == TensorExp::Kind::kInvariant)
    return genInvariantValue(env, e);
  if (kind == TensorExp::Kind::kLoopVar)
    return env.getLoopVar(exp.loop);

  if (kind == TensorExp::Kind::kReduce)
    env.startCustomReduc(e); // enter custom

  // If either operand is a synthetic zero, infer the type for the zero value
  // based on the type of the other operand.
  Value v0, v1;
  if (exp.children.first != ::mlir::sparse_tensor::detail::kInvalidId &&
      env.exp(exp.children.first).kind == TensorExp::Kind::kSynZero) {
    v1 = genExp(env, rewriter, exp.children.second);
    v0 = constantZero(rewriter, loc, v1.getType());
  } else if (exp.children.second != ::mlir::sparse_tensor::detail::kInvalidId &&
             env.exp(exp.children.second).kind == TensorExp::Kind::kSynZero) {
    v0 = genExp(env, rewriter, exp.children.first);
    v1 = constantZero(rewriter, loc, v0.getType());
  } else {
    v0 = genExp(env, rewriter, exp.children.first);
    v1 = genExp(env, rewriter, exp.children.second);
  }

  Value ee;
  if (kind == TensorExp::Kind::kReduce && (!v0 || !v1)) {
    // custom reduce did not receive a value
  } else {
    ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
    if (ee &&
        (kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary ||
         kind == TensorExp::Kind::kBinaryBranch ||
         kind == TensorExp::Kind::kReduce ||
         kind == TensorExp::Kind::kSelect)) {
      OpBuilder::InsertionGuard guard(rewriter);
      ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
    }
  }

  if (kind == TensorExp::Kind::kReduce)
    env.endCustomReduc(); // exit custom

  if (kind == TensorExp::Kind::kSelect)
    env.merger().setExprValue(e, v0); // Preserve value for later use.

  return ee;
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                          LoopId curr, bool isStart) {
  if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
    return;
  if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
    // Inspect tensor indices.
    linalg::GenericOp op = env.op();
    OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
    const auto map = op.getMatchingIndexingMap(&t);
    const auto stt = getSparseTensorType(t.get());
    const Level lvlRank = stt.getLvlRank();
    assert(static_cast<Level>(map.getNumResults()) == lvlRank);
    bool isCurrentLoop = curr == 0; // invariant surely happens at curr == 0
    for (Level l = 0; l < lvlRank; l++) {
      const AffineExpr a = map.getResult(l);
      if (!isInvariantAffine(a, curr, /*out*/ isCurrentLoop))
        return; // still in play
    }
    // All exhausted at current level.
    if (!isCurrentLoop)
      return;
    // Generate code for a scalarized reduction or invariant. Note that
    // because custom reduction lhs may occur several times in the IR,
    // we have a built-in safety for only initializing and wrapping-up
    // the scalarized reduction once.
    OpOperand *lhs = op.getDpsInitOperand(0);
    if (lhs == &t) {
      // Start or end a scalarized reduction.
      if (isStart) {
        Value load = env.isCustomReduc() ? env.getCustomRedId()
                                         : genTensorLoad(env, builder, exp);
        env.startReduc(exp, load);
      } else {
        genTensorStore(env, builder, exp, env.endReduc());
      }
      if (env.hasSparseOutput()) {
        if (isStart)
          env.startValidLexInsert(
              constantI1(builder, env.op().getLoc(), false));
        else
          env.endValidLexInsert();
      }
    } else {
      // Start or end loop invariant hoisting of a tensor load.
      if (isStart) {
        env.merger().setExprValue(exp, genTensorLoad(env, builder, exp));
      } else {
        env.merger().clearExprValue(exp);
      }
    }
  } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant &&
             env.exp(exp).kind != TensorExp::Kind::kLoopVar &&
             env.exp(exp).kind != TensorExp::Kind::kSynZero) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    if (env.exp(exp).kind == TensorExp::Kind::kReduce)
      env.startCustomReduc(exp); // enter custom
    const ExprId e0 = env.exp(exp).children.first;
    const ExprId e1 = env.exp(exp).children.second;
    genInvariants(env, builder, e0, curr, isStart);
    genInvariants(env, builder, e1, curr, isStart);
    if (env.exp(exp).kind == TensorExp::Kind::kReduce)
      env.endCustomReduc(); // exit custom
  }
}

/// Generates an expanded access pattern in innermost dimension.
static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                      bool isStart) {
  linalg::GenericOp op = env.op();
  OpOperand *lhs = op.getDpsInitOperand(0);
  if (!env.atExpandLevel(lhs, op.getRank(lhs), curr))
    return; // not needed at current level
  assert(!env.isReduc());
  // Generate start or end of an expanded access pattern. Note that because
  // an expansion does not rely on the ongoing contents of the sparse storage
  // scheme, we can use the original tensor as incoming SSA value.
  Value tensor = lhs->get();
  Location loc = op.getLoc();
  if (isStart) {
    auto dynShape = {ShapedType::kDynamic};
    Type etp = cast<ShapedType>(tensor.getType()).getElementType();
    Type t1 = MemRefType::get(dynShape, etp);
    Type t2 = MemRefType::get(dynShape, builder.getI1Type());
    Type t3 = MemRefType::get(dynShape, builder.getIndexType());
    Type t4 = builder.getIndexType();
    auto r = builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
    assert(r.getNumResults() == 4);
    env.startExpand(r.getResult(0), r.getResult(1), r.getResult(2),
                    r.getResult(3));
  } else {
    SmallVector<Value> indices;
    for (LoopId i = 0; i < curr; i++)
      indices.push_back(env.emitter().getLoopIV(i));
    Value values = env.getExpandValues();
    Value filled = env.getExpandFilled();
    Value added = env.getExpandAdded();
    Value count = env.getExpandCount();
    Value chain = env.getInsertionChain();
    Value compress = builder.create<CompressOp>(loc, values, filled, added,
                                                count, chain, indices);
    env.updateInsertionChain(compress);
    env.endExpand();
  }
}
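The expand/compress pair implements the classic sparse-accumulator (SPA)
idiom: a dense scratchpad of values plus a "filled" bitmap and a compact list
of touched coordinates. A plain C++ model of the same bookkeeping (purely
illustrative, no MLIR):

  #include <cstdint>
  #include <vector>

  struct SparseAccumulator {
    std::vector<double> values;   // dense scratch values
    std::vector<bool> filled;     // dense "is set" flags
    std::vector<int64_t> added;   // compact list of touched coordinates
    explicit SparseAccumulator(int64_t n) : values(n), filled(n) {}

    void insert(int64_t i, double v) {
      if (!filled[i]) {           // mirrors the scf.if in genInsertionStore
        filled[i] = true;
        added.push_back(i);       // the `count` update happens here
      }
      values[i] = v;
    }

    template <typename Fn> void compress(Fn emit) {
      for (int64_t i : added) {   // mirrors sparse_tensor.compress
        emit(i, values[i]);
        filled[i] = false;        // reset for the next outer iteration
      }
      added.clear();
    }
  };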

/// Returns parallelization strategy. Any implicit loop in the Linalg
/// operation that is marked "parallel" is a candidate; whether it is
/// actually converted to a parallel operation depends on the requested
/// strategy.
static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
  // Reject parallelization of sparse output.
  if (env.hasSparseOutput())
    return false;
  // Parallel loops on tensor expansion can cause data races.
  if (env.isExpand())
    return false;
  // Inspect strategy.
  switch (env.options().parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return true;
  }
  llvm_unreachable("unexpected parallelization strategy");
}
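A caller picks the strategy through SparsificationOptions; a hedged sketch
(assuming the three-argument constructor of the options struct, with the
strategy enumerators shown in the switch above):

  // Request parallel outer loops for any storage; functional emit strategy
  // and runtime library enabled are assumed defaults here.
  SparsificationOptions options(
      SparseParallelizationStrategy::kAnyStorageOuterLoop,
      SparseEmitStrategy::kFunctional, /*enableRuntimeLibrary=*/true);
  // With these options, isParallelFor(env, /*isOuter=*/true, /*isSparse=*/true)
  // returns true, since kAnyStorageOuterLoop only requires an outer loop.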

/// Whether or not the current loop being generated should be parallelized
/// (if possible) according to the configuration.
static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
                               ArrayRef<TensorLevel> tidLvls) {
  linalg::GenericOp op = env.op();
  auto iteratorTypes = op.getIteratorTypesArray();
  bool isSparse = llvm::any_of(tidLvls, [curr, &env](TensorLevel tidLvl) {
    // Queries the level-type based on the tensor and loop id; the returned
    // LT from CodegenEnv should be consistent with the LT indexed by
    // <TensorId, Level>.
    const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, curr);
    return lt.hasSparseSemantic();
  });
  return isParallelFor(env, /*isOuter=*/curr == 0, isSparse);
}

/// Emits a loop to co-iterate over the list of tensor levels. The generated
/// loop can either be a for-loop or while-loop depending on whether there is
/// at most one sparse level in the list.
static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
                                 ArrayRef<TensorLevel> tidLvls,
                                 unsigned numCases, bool tryParallel,
                                 bool needsUniv) {
  Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
    // Construct the while-loop with a parameter for each coordinate.
    return env.emitter().enterCoIterationOverTensorsAtLvls(
        builder, env.op().getLoc(), tidLvls, numCases, reduc, tryParallel,
        needsUniv);
  });
  assert(loop);
  return loop;
}

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                          unsigned numCases, bool needsUniv,
                          ArrayRef<TensorLevel> tidLvls) {
  bool tryParallel = shouldTryParallize(env, curr, tidLvls);
  return genCoIteration(env, builder, tidLvls, numCases, tryParallel,
                        needsUniv);
}

/// Generates the induction structure for a while-loop.
static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
                            bool needsUniv) {
  Location loc = env.op().getLoc();
  // Finalize each else branch of all if statements.
  if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
    while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
               builder.getInsertionBlock()->getParentOp())) {
      // Break on IfOp for slicing filtering.
      if (ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ==
          StringAttr::get(ifOp->getContext(), "slice"))
        break;

      unsigned y = 0;
      SmallVector<Value> yields;
      if (env.isReduc()) {
        yields.push_back(env.getReduc());
        env.updateReduc(ifOp.getResult(y++));
        if (env.isValidLexInsert()) {
          yields.push_back(env.getValidLexInsert());
          env.updateValidLexInsert(ifOp.getResult(y++));
        }
      }
      if (env.isExpand()) {
        yields.push_back(env.getExpandCount());
        env.updateExpandCount(ifOp->getResult(y++));
      }
      if (env.getInsertionChain()) {
        yields.push_back(env.getInsertionChain());
        env.updateInsertionChain(ifOp->getResult(y++));
      }
      assert(y == yields.size());
      builder.create<scf::YieldOp>(loc, yields);
      builder.setInsertionPointAfter(ifOp);
    }
  }
}

/// Generates a case region in the coiterate operation.
static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder,
                               unsigned caseIdx, LatPointId allCase,
                               LatPointId curCase,
                               MutableArrayRef<Value> reduc) {
  assert(allCase == curCase || env.merger().latGT(allCase, curCase));
  const BitVector &allCaseBits = env.merger().lat(allCase).simple;
  const BitVector &curCaseBits = env.merger().lat(curCase).simple;

  // Computes the subset of iterators that are valid in the current case
  // being generated.
  I64BitSet caseBit(0);
  for (auto [idx, set] : llvm::enumerate(allCaseBits.set_bits()))
    if (curCaseBits.test(set))
      caseBit.set(idx);

  env.emitter().enterCurrentCoIterationCase(builder, env.op().getLoc(), caseBit,
                                            caseIdx, reduc);
}
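The set_bits() enumeration above renumbers the surviving conditions: bit k of
`caseBit` says whether the k-th iterator of the whole coiteration is active in
the current case. A tiny standalone model (hypothetical data, only
llvm/ADT/BitVector.h and llvm/ADT/STLExtras.h needed):

  // All-case bits {1,3,4} vs. current-case bits {1,4} yields renumbered
  // positions {0,2} in the case bitset.
  llvm::BitVector all(8), cur(8);
  all.set(1); all.set(3); all.set(4);
  cur.set(1); cur.set(4);
  uint64_t caseBits = 0;
  for (auto [idx, bit] : llvm::enumerate(all.set_bits()))
    if (cur.test(bit))
      caseBits |= (1ull << idx); // sets idx 0 (bit 1) and idx 2 (bit 4)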

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                       LatPointId p) {
  Location loc = env.op().getLoc();
  SmallVector<Type> types;
  Value cond;
  env.merger().foreachTensorLoopId(
      p, /*simple=*/true,
      [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt,
          bool isIdxRed) {
        if (isIdxRed) {
          // Since there is no 1:1 mapping from loop to level (multiple loops
          // are required to resolve one level with a non-trivial index
          // expression), we need to reconstruct the tensor level types if
          // this loop requires index reduction condition.
          assert(lvl.has_value() && isUndefLT(lt));
          const auto stt = getSparseTensorType(env.op().getInputs()[tid]);
          lt = stt.getLvlType(*lvl);
        }
        assert(curr == env.merger().loop(b));
        Value clause;
        if (lt.hasSparseSemantic()) {
          assert(lvl.has_value());
          const Value crd = env.emitter().getCoord(tid, *lvl);
          const Value lvar = env.getLoopVar(curr);
          clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 crd, lvar);
        } else {
          assert(lt.hasDenseSemantic() || isUndefLT(lt));
          clause = constantI1(builder, loc, true);
        }
        cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
      });
  if (env.isReduc()) {
    types.push_back(env.getReduc().getType());
    if (env.isValidLexInsert())
      types.push_back(env.getValidLexInsert().getType());
  }
  if (env.isExpand())
    types.push_back(builder.getIndexType());
  if (env.getInsertionChain())
    types.push_back(env.getInsertionChain().getType());
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  return ifOp;
}

/// Generates end of true branch of if-statement within a while-loop.
static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
                  Value redInput, Value cntInput, Value insInput,
                  Value validIns) {
  SmallVector<Value> operands;
  if (env.isReduc()) {
    operands.push_back(env.getReduc());
    env.updateReduc(redInput);
    if (env.isValidLexInsert()) {
      // Any overlapping indices during a reduction creates a valid lex insert.
      operands.push_back(constantI1(builder, env.op().getLoc(), true));
      env.updateValidLexInsert(validIns);
    }
  }
  if (env.isExpand()) {
    operands.push_back(env.getExpandCount());
    env.updateExpandCount(cntInput);
  }
  if (env.getInsertionChain()) {
    operands.push_back(env.getInsertionChain());
    env.updateInsertionChain(insInput);
  }
  if (!operands.empty())
    builder.create<scf::YieldOp>(env.op().getLoc(), operands);
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
}

//===----------------------------------------------------------------------===//
// Sparsifier synthesis methods (loop sequence).
//===----------------------------------------------------------------------===//

static bool getAllTidLvlsInLatPoints(
    CodegenEnv &env, LatPointId li, LoopId curr,
    llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
  const BitVector &simple = env.lat(li).simple;
  const TensorId outTid = env.merger().getOutTensorID();
  const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);

  unsigned numloopCond = 0;
  bool hasNonUnique = false;
  env.merger().foreachTensorLoopId(
      li, [&, curr](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
                    LevelType lt, bool isIdxReduc) {
        if (simple[b]) {
          if (isIdxReduc) {
            callback(env.makeTensorLevel(tid, *lvl), nullptr);
            numloopCond++;
            return;
          }
          if (isUndefLT(lt)) {
            // An undefined lt in the lattices probably means we need to
            // generate a dense loop according to the synthetic tensor (for
            // invariants and sparse output tensor).
            if (env.merger().getSynTensorID() == tid) {
              // Coiterating with an invariant, e.g.,
              //   out = prod(in[i][j] op invariant)
              // or a broadcast, e.g.,
              //   out[i][j] = in[i] (j is undef for input).
              // The level of the synthetic tensor is the current loop depth;
              // the rank of the synthetic tensor equals the number of loops.
              assert(curr == env.getCurrentDepth());
              lvl = curr;
            } else if (!lvl) {
              // Skips invalid lvl (e.g., when this is a zero ranked tensor).
              return;
            }
          }
          hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
          callback(env.makeTensorLevel(tid, *lvl), nullptr);
          numloopCond++;
        } else if (lt.hasDenseSemantic() || isIdxReduc) {
          callback(env.makeTensorLevel(tid, *lvl), nullptr);
        } else {
          assert(isUndefLT(lt));
          linalg::GenericOp op = env.op();
          if (tid >= op.getNumDpsInputs())
            // We only handle affine expression on input tensors (for now).
            return;
          OpOperand *operand = &op->getOpOperand(tid);
          const auto stt = getSparseTensorType(operand->get());
          // Non-annotated dense tensors requires no special handling.
          if (!stt.hasEncoding())
            return;

          ArrayRef<AffineExpr> affines =
              op.getMatchingIndexingMap(operand).getResults();
          const Level lvlRank = stt.getLvlRank();
          assert(affines.size() == static_cast<size_t>(lvlRank));
          for (Level l = 0; l < lvlRank; l++) {
            AffineExpr exp = affines[l];
            // Skip simple affine expression and non-dense levels (which
            // have their own filter loop).
            LevelType lt = stt.getLvlType(l);
            if (isa<AffineDimExpr>(exp) || !lt.hasDenseSemantic())
              continue;

            // Constant affine expression are handled in genLoop.
            if (!isa<AffineConstantExpr>(exp)) {
              bool isCurrentLoop = false;
              assert(curr == env.getCurrentDepth());
              if (isInvariantAffine(exp, curr + 1, /*out*/ isCurrentLoop) &&
                  isCurrentLoop) {
                // If the compound affine is invariant and we are right at the
                // level, we need to generate the address for the level.
                callback(env.makeTensorLevel(tid, l), exp);
              }
            }
          }
        }
      });

  if (isDenseLT(env.lt(outTid, curr))) {
    auto stt = getSparseTensorType(env.op().getOutputs().front());
    // Note that we generate dense indices of the output tensor
    // unconditionally, since they may not appear in the lattice, but may be
    // needed for linearized codegen.
    if (stt.hasEncoding() && stt.isAllDense())
      callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
  }

  if (numloopCond == 0) {
    // Corner case where the loop bound is defined by an *unused* operand;
    // generate a dense "fake" loop by iterating over the synthetic tensor.
    callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
    numloopCond++;
  }
  // If we just need one loop condition and the condition is not imposed on
  // a non-unique level, the loop can be generated by a for loop.
  // Or, if we are generating sparse-iterator-based loops, we always generate
  // `sparse_tensor.iterate` regardless of whether the level is unique or not.
  return numloopCond == 1 &&
         (!hasNonUnique || env.options().sparseEmitStrategy ==
                               SparseEmitStrategy::kSparseIterator);
}

/// Starts a loop sequence at given level. Returns true if
/// the universal loop index must be maintained at this level.
static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                         LoopId curr, LatSetId lts) {
  assert(!env.getLoopVar(curr));
  // Emit invariants at this loop sequence level.
  genInvariants(env, builder, exp, curr, /*isStart=*/true);
  // Emit access pattern expansion for sparse tensor output.
  genExpand(env, builder, curr, /*isStart=*/true);
  // Emit further initialization at this loop sequence level.
  const LatPointId l0 = env.set(lts)[0];

  SmallVector<TensorLevel> tidLvls;
  getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
    // TODO: remove this! The same tensor level might be added multiple
    // times due to the special handling for all-dense "sparse" output tensor.
    if (llvm::is_contained(tidLvls, tl))
      return;
    tidLvls.emplace_back(tl);
  });

  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);

  // Maintain the universal index only if it is actually
  // consumed by a subsequent lattice point.
  for (const LatPointId li : env.set(lts).drop_front())
    if (!env.merger().hasAnySparse(env.lat(li).simple))
      return true;

  return false;
}

// Generates dense affine addresses for the given tensor, starting at the
// given level, for all subsequent dense levels with constant affine index
// expressions.
static void genConstantDenseAddressFromLevel(CodegenEnv &env,
                                             OpBuilder &builder, TensorId tid,
                                             Level startLvl) {
  // TODO: Handle affine expression on output tensor.
  linalg::GenericOp op = env.op();
  assert(tid < op.getNumDpsInputs());
  OpOperand *input = op.getDpsInputOperands()[tid];
  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
  const auto enc = getSparseTensorEncoding(input->get().getType());
  if (enc) {
    const Location loc = op.getLoc();
    const Level lvlRank = enc.getLvlRank();
    assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
    for (Level l = startLvl; l < lvlRank; l++) {
      AffineExpr lvlExpr = lvlExprs[l];
      if (enc.getLvlType(l).hasDenseSemantic() &&
          isa<AffineConstantExpr>(lvlExpr))
        env.emitter().locateLvlAtAffineAddress(
            builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
      else
        return; // break on first non-dense non-constant level
    }
  }
}

// We can generate address for constant affine expression before any loops
// starting from the first level as they do not depend on anything.
// E.g., in[0][0] is loop invariant and can be resolved before loop emission.
static void genInitConstantDenseAddress(CodegenEnv &env,
                                        RewriterBase &rewriter) {
  for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
    genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
}

/// Returns true if the lattice bit can be iterated by a for loop.
static bool translateBitsToTidLvlPairs(
    CodegenEnv &env, LatPointId li, LoopId curr,
    SmallVectorImpl<TensorLevel> &tidLvls,
    SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
  return getAllTidLvlsInLatPoints(env, li, curr,
                                  [&](TensorLevel tl, AffineExpr exp) {
                                    if (exp)
                                      affineTidLvls.emplace_back(tl, exp);
                                    else
                                      tidLvls.emplace_back(tl);
                                  });
}

/// Starts a single loop in current sequence.
static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
                                              OpBuilder &builder, LoopId curr,
                                              LatPointId li, unsigned numCases,
                                              bool needsUniv) {
  // The set of (tensor, level) pairs to generate loops on.
  SmallVector<TensorLevel> tidLvls;
  // The set of dense tensors with non-trivial affine expression that just
  // becomes invariant and whose addresses are generated at the current level.
  SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
  bool isSingleCond =
      translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls);

  // Emit the for/while-loop control.
  Operation *loop = genLoop(env, builder, curr, numCases, needsUniv, tidLvls);
  Location loc = env.op().getLoc();
  for (auto [tidLvl, exp] : affineTidLvls) {
    env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
  }

  // Until now, we have entered every <tid, lvl> pair in {cond, extra,
  // affine}Tids/Lvls. The addresses of the upcoming levels which are
  // dependent on constant affine expressions may now be determined.
  auto allTidLvls =
      llvm::concat<TensorLevel>(tidLvls, llvm::make_first_range(affineTidLvls));
  for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) {
    if (tid != env.merger().getOutTensorID() &&
        tid != env.merger().getSynTensorID())
      genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
  }

  return std::make_pair(loop, isSingleCond);
}

/// Ends a single loop in current sequence. Returns new values for needsUniv.
static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
                    LatPointId li, bool needsUniv, bool isSingleCond) {
  // Either a for-loop or a while-loop that iterates over a slice.
  if (isSingleCond) {
    // Any iteration creates a valid lex insert.
    if (env.isReduc() && env.isValidLexInsert())
      env.updateValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
  } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
    // End a while-loop.
    finalizeWhileOp(env, rewriter, needsUniv);
  } else {
    needsUniv = false;
  }
  env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
    env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
    return std::nullopt;
  });
  return needsUniv;
}

/// Ends a loop sequence at given level.
static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
                       unsigned at) {
  assert(!env.getLoopVar(at));
  env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
  // Unmark bookkeeping of invariants and loop index.
  genInvariants(env, builder, exp, at, /*isStart=*/false);
  // Finalize access pattern expansion for sparse tensor output.
  genExpand(env, builder, at, /*isStart=*/false);
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration along tensor
/// dimensions.
static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
                    LoopId curr) {
  assert(curr == env.getCurrentDepth());

  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (curr == env.getLoopNum()) {
    Value rhs = genExp(env, rewriter, exp);
    genTensorStore(env, rewriter, exp, rhs);
    return;
  }

  // Construct iteration lattices for current loop index.
  const LatSetId lts =
      env.merger().optimizeSet(env.merger().buildLattices(exp, curr));

  // Start a loop sequence.
  bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts);

  // When using sparse-iterator-based loops, we only need one loop, as
  // opposed to a loop sequence, to cover all the iterator spaces.
  const unsigned lsize = env.set(lts).size();
  if (env.generatingSparseIterator()) {
    // Get the largest lattice point and start a loop.
    const LatPointId li = env.set(lts)[0];
    auto [loop, isSingleCond] =
        startLoop(env, rewriter, curr, li, lsize, needsUniv);
    assert(isSingleCond == llvm::isa<IterateOp>(loop));
    // We cannot change this to `for (const LatPointId li : env.set(lts))`
    // because the loop body causes data-movement which invalidates
    // the iterator.
    for (unsigned j = 0; j < lsize; j++) {
      const LatPointId lj = env.set(lts)[j];
      const ExprId ej = env.lat(lj).exp;
      // Recurse into body of each branch.
      if (!isSingleCond) {
        env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
          genCoIterationCase(env, rewriter, /*caseIdx=*/j, li, lj, reduc);
          genStmt(env, rewriter, ej, curr + 1);
          // TODO: handle yield values.
          assert(reduc.empty() && "Not Implemented");
          rewriter.create<sparse_tensor::YieldOp>(env.op().getLoc());
          return std::nullopt;
        });
      } else {
        genStmt(env, rewriter, ej, curr + 1);
      }
    }

    needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
  } else {
    // Emit a loop for every lattice point L0 >= Li in this loop sequence.
    for (unsigned i = 0; i < lsize; i++) {
      const LatPointId li = env.set(lts)[i];
      // Start a loop.
      auto [loop, isSingleCond] =
          startLoop(env, rewriter, curr, li, lsize, needsUniv);

      // Visit all lattices points with Li >= Lj to generate the
      // loop-body, possibly with if statements for coiteration.
      Value redInput = env.getReduc();
      Value cntInput = env.getExpandCount();
      Value insInput = env.getInsertionChain();
      Value validIns = env.getValidLexInsert();
      // We cannot change this to `for (const LatPointId lj : env.set(lts))`
      // because the loop body causes data-movement which invalidates the
      // iterator.
      for (unsigned j = 0; j < lsize; j++) {
        const LatPointId lj = env.set(lts)[j];
        const ExprId ej = env.lat(lj).exp;
        if (li == lj || env.merger().latGT(li, lj)) {
          // Recurse into body of each branch.
          if (!isSingleCond) {
            scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
            genStmt(env, rewriter, ej, curr + 1);
            endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
          } else {
            genStmt(env, rewriter, ej, curr + 1);
          }
        }
      }

      // End a loop.
      needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
    }
  }

  // End a loop sequence.
  endLoopSeq(env, rewriter, exp, curr);
  assert(curr == env.getCurrentDepth());
}

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
  linalg::GenericOp op = env.op();
  OpOperand *lhs = op.getDpsInitOperand(0);
  Value tensor = lhs->get();
  Type resType = tensor.getType();
  if (getSparseTensorEncoding(resType)) {
    // The sparse tensor rematerializes from the original sparse tensor's
    // underlying sparse storage format. For an insertion chain, the
    // tensor materializes from the chain with 'hasInserts' enabled.
    bool hasInserts = false;
    if (Value chain = env.getInsertionChain()) {
      hasInserts = true;
      tensor = chain;
    }
    rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts);
  } else {
    // To rematerialize a non-annotated tensor, simply load it
    // from the bufferized value.
    Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()];
    rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
  }
}

//===----------------------------------------------------------------------===//
// Sparsifier rewriting methods.
//===----------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Only accept single output operations with pure tensor semantics.
    if (op.getNumDpsInits() != 1 || !op.hasPureTensorSemantics())
      return failure();

    // Only accept trivial affine indices on the sparse output.
    if (hasNonTrivialAffineOnSparseOut(op))
      return failure();

    // Only accept scheduled loops.
    if (!op->hasAttr("sorted")) {
      return rewriter.notifyMatchFailure(
          op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
              "before sparsification.");
    }

    // Must have been demapped as well if the generic op is sorted.
    assert(!hasAnyNonIdentityOperandsOrResults(op));

    // Sets up a code generation environment.
    const unsigned numTensors = op->getNumOperands();
    const unsigned numLoops = op.getNumLoops();
    bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls(op) != 0;
    // If we have an indexing map like (d0) -> (0, d0), there might be more
    // levels than loops because of the constant index, which means we can not
    // use numLoops as the upper bound for the ranks of all tensors.
    Level maxLvlRank = 0;
    for (auto operand : op.getOperands()) {
      if (auto rtp = dyn_cast<RankedTensorType>(operand.getType())) {
        maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
      }
    }

    // Detects sparse annotations and translates the per-level sparsity
    // information for all tensors to loop indices in the kernel.
    CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
    if (!findSparseAnnotations(env, needIdxRed))
      return failure();

    // Only standard reduction operations (add, sub, or, xor) that can be
    // sparsified by merely reducing the stored values are admissible. More
    // elaborate reduction operations (such as mul, and, min, max) would need
    // to know whether implicit zeros occur as well. They can still be
    // handled using a custom reduction operation, though.
    if (op.getNumReductionLoops() > 0) {
      Operation *yield = op.getRegion().front().getTerminator();
      assert(isa<linalg::YieldOp>(yield));
      Operation *redop = yield->getOperand(0).getDefiningOp();
      if (!isa<arith::AddFOp>(redop) && !isa<complex::AddOp>(redop) &&
          !isa<arith::AddIOp>(redop) && !isa<arith::SubFOp>(redop) &&
          !isa<complex::SubOp>(redop) && !isa<arith::SubIOp>(redop) &&
          !isa<arith::OrIOp>(redop) && !isa<arith::XOrIOp>(redop) &&
          !isa<ReduceOp>(redop)) {
        return failure();
      }
    }

    // Constructs the tensor expressions and iteration lattices for the
    // kernel; return failure if the lattices could not be built.
    if (failed(env.initTensorExp()))
      return failure();

    // Recursively generates code if admissible.
    env.startEmit(options.sparseEmitStrategy);
    genBuffers(env, rewriter);
    genInitConstantDenseAddress(env, rewriter);
    genStmt(env, rewriter, env.getExprId(), 0);
    genResult(env, rewriter);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
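A client pass would typically drive this pattern to a fixed point. A minimal
hedged sketch (hypothetical function; assumes a module with registered
dialects and loops already scheduled by --sparse-reinterpret-map):

  #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  LogicalResult sparsify(ModuleOp module) {
    RewritePatternSet patterns(module.getContext());
    SparsificationOptions options; // default strategies
    populateSparsificationPatterns(patterns, options);
    // Apply GenericOpSparsifier greedily until no more matches.
    return applyPatternsAndFoldGreedily(module, std::move(patterns));
  }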


static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map, Value tensor)

Gets the total number of compound affine expressions in the getMatchingIndexingMap for the given tensor.

static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)

Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-iteration over the given conjunction.

static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder, OpOperand *t)

Generates insertion code to implement dynamic tensor load for reduction.

static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop)

Returns true iff affine expression is invariant.

static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl, AffineExpr a, LevelType lt, bool isSubExp=false, int64_t coefficient=1)

Helper method to inspect affine expressions for index variable reduction based codegen.

static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr, LatPointId p)

Generates a single if-statement within a while-loop.

static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t, SmallVectorImpl< Value > &args)

Generates subscript for load/store on a dense or sparse tensor.

static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond, Value sparseOut, ValueRange ivs, Value v)

static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr, bool isStart)

Generates an expanded access pattern in innermost dimension.

static void genConstantDenseAddressFromLevel(CodegenEnv &env, OpBuilder &builder, TensorId tid, Level startLvl)

static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp, LoopId curr, LatSetId lts)

Starts a loop sequence at given level.

static std::pair< Operation *, bool > startLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, LatPointId li, unsigned numCases, bool needsUniv)

Starts a single loop in current sequence.

static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp, LoopId curr, bool isStart)

Hoists loop invariant tensor loads for which indices have been exhausted.

static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp, unsigned at)

Ends a loop sequence at given level.

static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse)

Returns parallelization strategy.

static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a, LevelType lt, bool setLvlFormat=true)

Helper method to inspect affine expressions.

static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased)

Helper method to inspect sparse encodings in the tensor types.

static bool getAllTidLvlsInLatPoints(CodegenEnv &env, LatPointId li, LoopId curr, llvm::function_ref< void(TensorLevel, AffineExpr)> callback)

static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder, OpOperand *t)

Generates insertion code to implement dynamic tensor load.

static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op)

static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp, Value redInput, Value cntInput, Value insInput, Value validIns)

Generates end of true branch of if-statement within a while-loop.

static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp, LoopId curr)

Recursively generates code while computing iteration lattices in order to manage the complexity of implementing co-iteration along tensor levels.

static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t, Value rhs)

Generates insertion code to implement dynamic tensor store.

static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp, Value rhs)

Generates a store on a dense or sparse tensor.

static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block, Value e)

Semi-ring branches are simply inlined by the sparsifier.

static void genBuffers(CodegenEnv &env, OpBuilder &builder)

Local bufferization of all dense and sparse data structures.

static void genResult(CodegenEnv &env, RewriterBase &rewriter)

Converts the result computed by the sparse kernel into the required form.

static bool shouldTryParallize(CodegenEnv &env, LoopId curr, ArrayRef< TensorLevel > tidLvls)

Whether or not the current loop being generated should be parallelized (if possible) according to the configuration.

static bool translateBitsToTidLvlPairs(CodegenEnv &env, LatPointId li, LoopId curr, SmallVectorImpl< TensorLevel > &tidLvls, SmallVectorImpl< std::pair< TensorLevel, AffineExpr >> &affineTidLvls)

Returns true if the lattice bit can be iterated by a for loop.

static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e)

Recursively generates tensor expression.

static void genInitConstantDenseAddress(CodegenEnv &env, RewriterBase &rewriter)

static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp)

Generates a load on a dense or sparse tensor.

static Value genInvariantValue(CodegenEnv &env, ExprId exp)

Generates an invariant value.

static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder, unsigned caseIdx, LatPointId allCase, LatPointId curCase, MutableArrayRef< Value > reduc)

Generates a case region in the coiterate operation.

static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop, LatPointId li, bool needsUniv, bool isSingleCond)

Ends a single loop in current sequence. Returns new values for needsUniv.

static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, bool needsUniv)

Generates the induction structure for a while-loop.

static Value genIndex(CodegenEnv &env, OpOperand *t)

Generates index for load/store on sparse tensor.

Base type for affine expression.

AffineExprKind getKind() const

Return the classification for this type.

A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.

ArrayRef< AffineExpr > getResults() const

Block represents an ordered list of Operations.

Operation * getTerminator()

Get the terminator operation of this block.

Operation * getParentOp()

Returns the closest surrounding operation that contains this block.

IRValueT get() const

Return the current value being used by this operand.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.

MLIRContext is the top-level object for a collection of MLIR operations.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

void setInsertionPointToStart(Block *block)

Sets the insertion point to the start of the specified block.

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

void setInsertionPointAfter(Operation *op)

Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.

Block * getInsertionBlock() const

Return the block the current insertion point belongs to.

This class represents an operand of an operation.

unsigned getOperandNumber()

Return which operand this is in the OpOperand list of the Operation.

Operation is the basic unit of execution within MLIR.

Value getOperand(unsigned idx)

Region & getRegion(unsigned index)

Returns the region held by this operation at position 'index'.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched.

This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure, and provides a callback to populate a diagnostic with the reason why the failure occurred.

void modifyOpInPlace(Operation *root, CallableT &&callable)

This method is a utility wrapper around an in-place modification of an operation.

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

Type getType() const

Return the type of this value.

Block * getParentBlock()

Return the Block in which this Value is defined.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

The code generation environment class aggregates a number of data structures that are needed during the code generation of a sparse kernel.

void startReduc(ExprId exp, Value val)

void updateValidLexInsert(Value val)

const SparsificationOptions & options() const

Value getInsertionChain() const

std::optional< Operation * > genLoopBoundary(function_ref< std::optional< Operation * >(MutableArrayRef< Value > parameters)> callback)

Generates loop boundary statements (entering/exiting loops).

ArrayRef< LatPointId > set(LatSetId s) const

bool atExpandLevel(OpOperand *o, unsigned rank, LoopId n) const

bool isCustomReduc() const

unsigned getCurrentDepth() const

std::pair< TensorId, Level > unpackTensorLevel(TensorLevel tl) const

Value getExpandValues() const

TensorLevel makeTensorLevel(TensorId t, Level l) const

const LatPoint & lat(LatPointId l) const

constexpr TensorId makeTensorId(unsigned t) const

void startExpand(Value values, Value filled, Value added, Value count)

bool hasSparseOutput() const

unsigned getLoopNum() const

void updateInsertionChain(Value chain)

bool generatingSparseIterator() const

Value getExpandCount() const

void startCustomReduc(ExprId exp)

linalg::GenericOp op() const

Value getLoopVar(LoopId i) const

Returns the induction-variable for the given loop.

Value getExpandFilled() const

LogicalResult initTensorExp()

void startEmit(SparseEmitStrategy emitStrategy)

auto unpackTensorLevelRange(ContainerTy &&c) const

Value getExpandAdded() const

const TensorExp & exp(ExprId e) const

void updateExpandCount(Value count)

void updateReduc(Value val)

Value getValidLexInsert() const

bool isSparseOutput(OpOperand *o) const

void startValidLexInsert(Value val)

constexpr LoopId makeLoopId(unsigned i) const

Value getCustomRedId() const

LevelType lt(TensorId t, LoopId i) const

bool isValidLexInsert() const

A simple wrapper to encode a bitset of (at most 64) levels, currently used by the sparse_tensor dialect.

I64BitSet & set(unsigned i)

constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName()

void locateLvlAtAffineAddress(OpBuilder &builder, Location loc, TensorLevel tidLvl, AffineExpr lvlExpr)

Emits the address for a dense level based on the value evaluated by the provided affine expression.

const std::vector< Value > & getValBuffer() const

void enterNewLoopSeq(OpBuilder &builder, Location loc, ArrayRef< TensorLevel > tidLvls)

Enters a new loop sequence; the loops within the same sequence start from the break points of the previous loop instead of from the beginning.

Value genAffine(OpBuilder &builder, Location loc, AffineExpr a)

Generates code to compute an affine expression whose variables are LoopIds (i.e., cast<AffineDimExpr>(a).getPosition() is a valid LoopId).

Region * enterCurrentCoIterationCase(OpBuilder &builder, Location loc, I64BitSet caseBit, unsigned caseIdx, MutableArrayRef< Value > reduc)

Value getLoopIV(LoopId n) const

Gets loop induction variable for the given loop.

SmallVector< Value > getValPosits(TensorId tid) const

Getters.

auto getLoopIVsRange() const

Get the range of values for all induction variables.

void initializeLoopEmit(OpBuilder &builder, Location loc, OutputUpdater updater=nullptr, SynTensorBoundSetter synSetter=nullptr)

Starts a loop emitting session by generating all the buffers needed for iterating over the tensors.

void exitCurrentLoopSeq(OpBuilder &builder, Location loc)

Exits the current loop sequence, this will reset universal index to 0.

Value getCoord(TensorId tid, Level lvl) const

A class to handle all iteration lattice operations.

std::optional< Level > getLvl(TensorId t, LoopId i) const

Gets the level number of the tth tensor on the ith loop.

LatSetId buildLattices(ExprId e, LoopId i)

Builds the iteration lattices in a bottom-up traversal given the remaining tensor (sub)expression and the next loop index in the iteration graph.
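For example, for x(i) = a(i) + b(i) with both inputs sparse along i, the lattice built for loop i contains the points { a(i) and b(i) } emitting a + b, { a(i) } emitting a, and { b(i) } emitting b; the first point drives a co-iterating while-loop and the remaining points become the leftover for-loops, with optimizeSet pruning points whose conjunction is already covered by a stronger one.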

constexpr LoopId makeLoopId(unsigned i) const

Safely converts the argument to a loop identifier.

void setLevelAndType(TensorId t, LoopId i, Level lvl, LevelType lt)

Sets the level number and level-type of the tth tensor on ith loop.

void foreachTensorLoopId(LatPointId p, ForeachTensorLoopIdCallback callback) const

Iterates over a set of TensorLoopIds, invoking the callback for each TensorLoopId and passing it the corresponding tensor identifier, level, and level-type.

LatSetId optimizeSet(LatSetId s)

Optimizes the iteration lattice points in the given set.

void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl, LevelType lt, unsigned coefficient)

Establishes the two-way map that i <-> <t, lvl, lt>.

bool hasAnySparse(const BitVector &bits) const

Returns true if any TensorLoopId in the bitvector corresponds to sparse level-type.

void clearExprValue(ExprId e)

Clears the value associated with the expression.

constexpr TensorId getSynTensorID() const

Gets the synthetic tensor's identifier (used for all invariant tensor expressions).

bool latGT(LatPointId p0, LatPointId p1) const

Returns true if p0 > p1.

constexpr LoopId loop(TensorLoopId b) const

Gets the loop-identifier of the TensorLoopId.

const LatPoint & lat(LatPointId p) const

constexpr TensorId getOutTensorID() const

Gets the output tensor's identifier.

LevelType getLvlType(TensorId t, LoopId i) const

Gets the level-type of the tth tensor on ith loop.

Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, Value v1) const

Rebuilds SSA format from a tensor expression.

void setExprValue(ExprId e, Value v)

Sets the expression to have the associated value.

bool hasDependentLvl(LoopId i, TensorId t)

Whether the loop has dependent slice.

A wrapper around RankedTensorType, which has three goals:

bool hasEncoding() const

Returns true for tensors which have an encoding, and false for those which do not.

bool isAllDense() const

Returns true for tensors where every level is dense.

Level getLvlRank() const

Returns the level-rank.

LevelType getLvlType(Level l) const

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

static constexpr unsigned kInvalidId

A constant serving as the canonically invalid identifier, regardless of the identifier type.

bool isUniqueLT(LevelType lt)

Value constantIndex(OpBuilder &builder, Location loc, int64_t i)

Generates a constant of index type.

Value constantZero(OpBuilder &builder, Location loc, Type tp)

Generates a 0-valued constant of the given type.

unsigned LatSetId

LatSet identifiers.

uint64_t Dimension

The type of dimension identifiers and dimension-ranks.

unsigned TensorLoopId

A compressed representation of std::pair<TensorId, LoopId>.

uint64_t Level

The type of level identifiers and level-ranks.

unsigned LoopId

Loop identifiers.

Value constantI1(OpBuilder &builder, Location loc, bool b)

Generates a constant of i1 type.

Value genValFromAttr(OpBuilder &builder, Location loc, Attribute attr)

SparseTensorEncodingAttr getSparseTensorEncoding(Type type)

Convenience method to get a sparse encoding attribute from a type.

bool hasAnySparseType(TypeRange types)

Returns true iff the type range has any sparse tensor type.

Value genIsNonzero(OpBuilder &builder, Location loc, Value v)

Generates the comparison v != 0 where v is of numeric type.

bool isUndefLT(LevelType lt)

std::pair< Operation *, Value > genCoIteration(OpBuilder &builder, Location loc, ArrayRef< SparseIterator * > iters, MutableArrayRef< Value > reduc, Value uniIdx, bool userReducFirst=false)

bool isDenseLT(LevelType lt)

bool hasAnyNonIdentityOperandsOrResults(Operation *op)

Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.

SparseTensorType getSparseTensorType(Value val)

Convenience methods to obtain a SparseTensorType from a Value.

unsigned ExprId

TensorExp identifiers.

unsigned LatPointId

LatPoint identifiers.

unsigned TensorId

Tensor identifiers, chosen to be the BlockArgument::getArgNumber of the value passed to Merger::buildTensorExp.

Include the generated interface declarations.

@ Mul

RHS of mul is always a constant or a symbolic expression.

@ DimId

Dimensional identifier.

@ Constant

Constant integer.

Type getElementTypeOrSelf(Type type)

Return the element type or return the type itself.

void populateSparsificationPatterns(RewritePatternSet &patterns, const SparsificationOptions &options=SparsificationOptions())

Sets up sparsification rewriting rules with the given options.

const FrozenRewritePatternSet & patterns

Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)

Converts an OpFoldResult to a Value.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed; this helps unify some of the attribute construction methods.

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...

Options for the Sparsification pass.

SparseEmitStrategy sparseEmitStrategy

SparseParallelizationStrategy parallelizationStrategy

ExprId exp

Identifier of the tensor expression.

BitVector simple

Simplified conjunction of TensorLoopId as bitvector.

This enum defines all the sparse representations supportable by the SparseTensor dialect.

constexpr bool hasSparseSemantic() const

Check if the LevelType is considered to be sparse.

constexpr bool hasDenseSemantic() const

Check if the LevelType is considered to be dense-like.

Tensor expression. Represents an MLIR expression in tensor index notation.

LoopId loop

kLoopVar expressions simply have a loop identifier.

Value val

Direct link to IR for an invariant or the destination value (to infer destination type) of a cast operation.

Children children

All other expressions hold the ExprIds of their children.

TensorId tensor

kTensor expressions simply have a tensor identifier.

Kind kind

Tensor expression kind.
