MLIR: lib/Dialect/SparseTensor/Utils/Merger.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

14

16 #include "llvm/Support/Debug.h"

17 #include

18

19 namespace mlir {

20 namespace sparse_tensor {

21

26 };

27

29 switch (k) {

30

72

98 }

99 llvm_unreachable("unexpected kind");

100 }

101

102

103

104

105

108 : kind(k), val(v), op(o), attr(a) {

109 switch (kind) {

110

114 return;

117 return;

120 return;

124 return;

125

150 return;

165 return;

171 return;

173

174

178 return;

179

202 return;

208 return;

214 return;

219 return;

220 }

221 llvm_unreachable("unexpected kind");

222 }

223

224 Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,

225 unsigned maxLvlRank)

226 : outTensor(numInputOutputTensors - 1),

227 syntheticTensor(numInputOutputTensors),

228 numTensors(numInputOutputTensors + 1), numLoops(numLoops),

229 hasSparseOut(false),

230 lvlTypes(numTensors,

232 loopToLvl(numTensors,

233 std::vector<std::optional<Level>>(numLoops, std::nullopt)),

234 lvlToLoop(numTensors,

235 std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),

236 loopToUnresolvedLvls(numLoops, std::vector<std::optional<LvlLTPair>>(

237 numTensors, std::nullopt)),

238 levelToDependentLoop(numTensors,

241 loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}

242

243

244

245

246

248 assert(isValidTensorId(t));

249 const ExprId eNew(tensorExps.size());

251 Value(), nullptr, nullptr);

252 return eNew;

253 }

254

256 assert(isValidLoopId(i));

257 const ExprId eNew(tensorExps.size());

259 Value(), nullptr, nullptr);

260 return eNew;

261 }

262

264 const ExprId eNew(tensorExps.size());

267 return eNew;

268 }

269

271 const ExprId eNew(tensorExps.size());

274 return eNew;

275 }

276

280 const ExprId eNew(tensorExps.size());

281 tensorExps.emplace_back(k, e0, e1, Value(), op, attr);

282 return eNew;

283 }

284

288 const ExprId eNew(tensorExps.size());

290 return eNew;

291 }

292

294 const LatPointId pNew(latPoints.size());

295 const unsigned size = numLoops * numTensors;

297 latPoints.emplace_back(size, e);

298 latPoints[pNew].bits.set(b);

299 return pNew;

300 }

301

303 assert(bits.size() == numLoops * numTensors);

304 const LatPointId pNew(latPoints.size());

305 latPoints.emplace_back(bits, e);

306 return pNew;

307 }

308

310 const LatSetId sNew(latSets.size());

311 latSets.emplace_back();

312 return sNew;

313 }

314

319 const LatPointId pNew(latPoints.size());

320 const auto &point0 = lat(p0);

321 const auto &point1 = lat(p1);

322 BitVector bits(point0.bits);

323 bits |= point1.bits;

324 const ExprId ne = addExp(kind, point0.exp, point1.exp, op, attr);

325 latPoints.emplace_back(bits, ne);

326 return pNew;

327 }

328

331 auto &setNew = latSets[sNew];

334 setNew.push_back(conjLat(e, p0, p1, op));

335 return sNew;

336 }

337

341

342 latSets[sNew].append(latSets[s0]);

343

344

351

352 latSets[sNew].append(latSets[s1]);

353 return sNew;

354 }

355

360

365

367

368

369

370 return sNew;

371 }

372

375 latSets[sNew].append(latSets[lhsSet]);

376 latSets[sNew].append(latSets[rhsSet]);

377 return sNew;

378 }

379

382 Operation *opleft, bool includeRight,

386

387 if (includeLeft) {

388 if (opleft)

389 s0 = mapSet(ltrans, s0, Value(), opleft, a);

390 latSets[sNew].append(latSets[s0]);

391 }

392

393 if (includeRight) {

394 if (opright)

395 s1 = mapSet(rtrans, s1, Value(), opright, a);

396 latSets[sNew].append(latSets[s1]);

397 }

398 return sNew;

399 }

400

406 auto &setNew = latSets[sNew];

408 const auto &point = latPoints[p];

409 setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op, a)));

410 }

411 return sNew;

412 }

413

418

420 auto &setNew = latSets[sNew];

423 const auto &point = latPoints[p];

424 ExprId newExp = lhsZero ? addExp(kind, zeroExp, point.exp, nullptr, a)

425 : addExp(kind, point.exp, zeroExp, nullptr, a);

426 setNew.push_back(addLat(point.bits, newExp));

427 }

428 return sNew;

429 }

430

433 auto &setNew = latSets[sNew];

434 const auto &set0 = set(s0);

435 assert(!set0.empty());

438 bool add = true;

439 if (p0 != p1) {

440

442 continue;

443

445 assert(latGT(p1, p2));

447 add = false;

448 break;

449 }

450 }

451 assert(!add || latGT(p0, p1));

452 }

453 if (add)

454 setNew.push_back(p1);

455 }

458 return sNew;

459 }

460

462

463

464 bool isSingleton = true;

466 if (p0 != p1 && latGT(p0, p1)) {

467 isSingleton = false;

468 break;

469 }

470 }

471

472 BitVector simple(latPoints[p0].bits);

473 bool reset = isSingleton && hasAnySparse(simple);

475 TensorLoopId offset = 0;

476 if (!reset)

477

478

479 for (unsigned b = 0; b < be; b++) {

481 offset = be - b - 1;

482 break;

483 }

484 }

485

486

487

488 for (unsigned b = be - 1 - offset, i = 0; i < be;

489 b = b == 0 ? be - 1 : b - 1, i++) {

490

493 if (!lt.hasSparseSemantic()) {

494 if (reset)

495 simple.reset(b);

496 reset = true;

497 }

498 }

499 }

500 return simple;

501 }

502

504 const BitVector &bitsi = lat(i).bits;

505 const BitVector &bitsj = lat(j).bits;

506 assert(bitsi.size() == bitsj.size());

507 if (bitsi.count() > bitsj.count()) {

508 for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++)

509 if (bitsj[b] && !bitsi[b])

510 return false;

511 return true;

512 }

513 return false;

514 }

515

517 BitVector tmp(latPoints[j].bits);

518 tmp ^= latPoints[i].bits;

520 }

521

523 const auto &expr = exp(e);

524

526 return expr.tensor == t;

527

530 return false;

532 const ExprId e0 = expr.children.e0;

534 }

536 const ExprId e0 = expr.children.e0;

537 const ExprId e1 = expr.children.e1;

539 }

540 }

541 llvm_unreachable("unexpected arity");

542 }

543

545 const auto &expr = exp(e);

546 switch (expr.kind) {

560 return lhsNeg;

561 }

562 default: {

565 return false;

571 }

572 }

573 }

574 llvm_unreachable("unexpected kind");

575 }

576

578 assert(isValidTensorId(t));

579 const auto &expr = exp(e);

580 switch (expr.kind) {

581

583 return expr.tensor == t;

587 return false;

588

625 return false;

626

631 assert(!maybeZero(expr.children.e1));

636 assert(isInvariant(expr.children.e1));

645 isInvariant(expr.children.e1);

647 return isInvariant(expr.children.e0);

648 return false;

662 return false;

664

665

666 return true;

667 }

668 llvm_unreachable("unexpected kind");

669 }

670

674 if (lt.hasSparseSemantic())

675 return true;

676 }

678 }

679

683 return true;

684 return false;

685 }

686

687 #ifndef NDEBUG

688

689

690

691

692

694 switch (kind) {

695

697 return "tensor";

699 return "invariant";

701 return "index";

703 return "0";

704

708 return "abs";

710 return "ceil";

712 return "floor";

715 return "sqrt";

718 return "expm1";

721 return "log1p";

723 return "relu";

726 return "sin";

729 return "tanh";

733 return "-";

745 return "complex.im";

747 return "complex.re";

749 return "cast";

751 return "binary_branch";

753 return "unary";

755 return "select";

756

760 return "*";

765 return "/";

769 return "+";

773 return "-";

775 return "&";

777 return "|";

779 return "^";

781 return "a>>";

783 return ">>";

785 return "<<";

788 return "cmp";

790 return "binary";

792 return "reduce";

794 return "dense";

795 }

796 llvm_unreachable("unexpected kind for symbol");

797 }

798

800 const auto &expr = exp(e);

801 switch (expr.kind) {

802

804 if (expr.tensor == syntheticTensor)

805 llvm::dbgs() << "synthetic_";

806 else if (expr.tensor == outTensor)

807 llvm::dbgs() << "output_";

808 llvm::dbgs() << "tensor_" << expr.tensor;

809 break;

811 llvm::dbgs() << "invariant";

812 break;

814 llvm::dbgs() << "0";

815 break;

817 llvm::dbgs() << "loopvar_" << expr.loop;

818 break;

819

856 dumpExp(expr.children.e0);

857 break;

858

883 llvm::dbgs() << "(";

884 dumpExp(expr.children.e0);

886 if (expr.attr)

887 llvm::dbgs() << "{" << expr.attr << "}";

889 llvm::dbgs() << " ";

890 dumpExp(expr.children.e1);

891 llvm::dbgs() << ")";

892 } else {

894 }

895 break;

896 }

897 }

898

900 const auto &point = lat(p);

901 llvm::dbgs() << "lat(";

903 llvm::dbgs() << " :";

905 llvm::dbgs() << " : ";

907 llvm::dbgs() << " )\n";

908 }

909

911 const auto &ss = set(s);

912 llvm::dbgs() << "{ #" << ss.size() << "\n";

914 llvm::dbgs() << " ";

916 }

917 llvm::dbgs() << "}\n";

918 }

919

921 for (TensorLoopId b = 0, be = bits.size(); b < be; b++) {

922 if (bits[b]) {

925 const auto lt = lvlTypes[t][i];

927 llvm::dbgs() << " DEP_" << t << "_" << i;

928 else

929 llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(lt);

930 }

931 }

932 }

933

934 #endif

935

936

937

938

939

941

942

943

944

945 const auto &expr = exp(e);

947 switch (kind) {

948

953

954

955

956

957

959 TensorId t = syntheticTensor;

961 t = expr.tensor;

962 if (hasSparseOut && t == outTensor)

963 t = syntheticTensor;

964 }

965 latSets[s].push_back(addLat(t, i, e));

966 return s;

967 }

968

1001

1002

1003

1004

1005

1006

1007 {

1008 const ExprId e0 = expr.children.e0;

1009 const Value v = expr.val;

1012 }

1015

1016

1017 {

1018 const ExprId e0 = expr.children.e0;

1021 }

1023

1024

1025

1026

1027

1028 {

1029 const ExprId e0 = expr.children.e0;

1030 UnaryOp unop = cast(expr.op);

1032 Region &absentRegion = unop.getAbsentRegion();

1033 if (absentRegion.empty()) {

1034

1036 }

1037

1038

1039 Block &absentBlock = absentRegion.front();

1040 YieldOp absentYield = cast(absentBlock.getTerminator());

1041 const Value absentVal = absentYield.getSingleResult();

1044 }

1045

1050

1051

1052

1053

1054

1055

1056

1057

1058

1059 {

1060 const ExprId e0 = expr.children.e0;

1061 const ExprId e1 = expr.children.e1;

1063 }

1068

1069

1070

1071

1072

1073

1074

1075

1076

1077

1078

1079

1080

1081 {

1082 const ExprId e0 = expr.children.e0;

1083 const ExprId e1 = expr.children.e1;

1084 assert(!maybeZero(e1));

1086 }

1095

1096

1097

1098

1099

1100

1101

1102 {

1103 const ExprId e0 = expr.children.e0;

1104 const ExprId e1 = expr.children.e1;

1106 }

1109

1110

1111

1112

1113

1114

1115

1116 {

1117 const ExprId e0 = expr.children.e0;

1118 const ExprId e1 = expr.children.e1;

1120 }

1124

1125

1126

1127 {

1128 const ExprId e0 = expr.children.e0;

1129 const ExprId e1 = expr.children.e1;

1130 assert(isInvariant(e1));

1132 }

1134

1135

1136

1137

1138

1139

1140 {

1141 const ExprId e0 = expr.children.e0;

1142 const ExprId e1 = expr.children.e1;

1143 BinaryOp binop = cast(expr.op);

1146 Region &leftRegion = binop.getLeftRegion();

1147 Region &rightRegion = binop.getRightRegion();

1148

1150 if (!leftRegion.empty()) {

1151 Block &leftBlock = leftRegion.front();

1153 }

1154

1155 Operation *rightYield = nullptr;

1156 if (!rightRegion.empty()) {

1157 Block &rightBlock = rightRegion.front();

1159 }

1160 bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();

1161 bool includeRight = binop.getRightIdentity() || !rightRegion.empty();

1162 return combiSet(e, child0, child1, binop, includeLeft,

1165 }

1167

1168 {

1169 const ExprId e0 = expr.children.e0;

1170 const ExprId e1 = expr.children.e1;

1173 }

1175

1176

1177

1179 const ExprId e0 = expr.children.e0;

1182 }

1183

1184 const ExprId e0 = expr.children.e0;

1185 const ExprId e1 = expr.children.e1;

1188 }

1189 }

1190 llvm_unreachable("unexpected expression kind");

1191 }

1192

1194

1196 assert(isalinalg::YieldOp(yield));

1197 return buildTensorExp(op, yield->getOperand(0)).first;

1198 }

1199

1200

1202 if (auto c = val.getDefiningOpcomplex::ConstantOp()) {

1203 ArrayAttr arrayAttr = c.getValue();

1204 return cast(arrayAttr[0]).getValue().isZero() &&

1205 cast(arrayAttr[1]).getValue().isZero();

1206 }

1208 return c.value() == 0;

1210 return c.value().isZero();

1211 return false;

1212 }

1213

1214

1215 bool Merger::maybeZero(ExprId e) const {

1216 const auto &expr = exp(e);

1218

1219

1220 if (auto c = expr.val.getDefiningOpcomplex::ConstantOp()) {

1221 ArrayAttr arrayAttr = c.getValue();

1222 return cast(arrayAttr[0]).getValue().isZero() &&

1223 cast(arrayAttr[1]).getValue().isZero();

1224 }

1225 if (auto c = expr.val.getDefiningOparith::ConstantIntOp())

1226 return c.value() == 0;

1227 if (auto c = expr.val.getDefiningOparith::ConstantFloatOp())

1228 return c.value().isZero();

1229 }

1230 return true;

1231 }

1232

1233 Type Merger::inferType(ExprId e, Value src) const {

1234

1236

1237

1238 if (auto vtp = dyn_cast(src.getType()))

1239 return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());

1240 return dtp;

1241 }

1242

1243

1245

1246 if (isa(v))

1247 return true;

1248

1250 if (isalinalg::IndexOp(def))

1251 return true;

1252

1253 if (def->getBlock() != block)

1255

1256

1257 for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)

1259 return false;

1260 return true;

1261 }

1262

1263

1265 if (region.empty())

1266 return true;

1267

1269 assert(isa(yield));

1271 }

1272

1273

1276 auto pred = llvm::castarith::CmpIPredicateAttr(attr).getValue();

1277 return pred == arith::CmpIPredicate::ugt ||

1278 pred == arith::CmpIPredicate::sgt;

1279 }

1281 auto pred = llvm::castarith::CmpFPredicateAttr(attr).getValue();

1282 return pred == arith::CmpFPredicate::UGT ||

1283 pred == arith::CmpFPredicate::OGT;

1284 }

1285 return false;

1286 }

1287

1288 std::pair<std::optional, bool>

1289 Merger::buildTensorExp(linalg::GenericOp op, Value v) {

1290

1291 if (auto arg = dyn_cast(v)) {

1293

1294

1295

1296 if (arg.getOwner()->getParentOp() == op) {

1297 OpOperand &t = op->getOpOperand(tid);

1299 if (!op.isScalar(&t))

1301 v = t.get();

1302 }

1303

1304

1306 }

1307

1308

1310 if (def->getBlock() != &op.getRegion().front())

1312

1313 if (def->getNumOperands() == 0) {

1314 if (auto indexOp = dyn_castlinalg::IndexOp(def))

1316 }

1317

1318

1319 if (def->getNumOperands() == 1) {

1320 const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0));

1321 if (x.has_value()) {

1322 const ExprId e = *x;

1323 if (isamath::AbsFOp(def))

1325 if (isacomplex::AbsOp(def))

1327 if (isamath::AbsIOp(def))

1329 if (isamath::CeilOp(def))

1331 if (isamath::FloorOp(def))

1333 if (isamath::SqrtOp(def))

1335 if (isacomplex::SqrtOp(def))

1337 if (isamath::ExpM1Op(def))

1339 if (isacomplex::Expm1Op(def))

1341 if (isamath::Log1pOp(def))

1343 if (isacomplex::Log1pOp(def))

1345 if (isamath::SinOp(def))

1347 if (isacomplex::SinOp(def))

1349 if (isamath::TanhOp(def))

1351 if (isacomplex::TanhOp(def))

1353 if (isaarith::NegFOp(def))

1355 if (isacomplex::NegOp(def))

1357 if (isaarith::TruncFOp(def))

1359 if (isaarith::ExtFOp(def))

1361 if (isaarith::FPToSIOp(def))

1363 if (isaarith::FPToUIOp(def))

1365 if (isaarith::SIToFPOp(def))

1367 if (isaarith::UIToFPOp(def))

1369 if (isaarith::ExtSIOp(def))

1371 if (isaarith::ExtUIOp(def))

1373 if (isaarith::IndexCastOp(def))

1375 if (isaarith::TruncIOp(def))

1377 if (isacomplex::ImOp(def))

1379 if (isacomplex::ReOp(def))

1381 if (isaarith::BitcastOp(def))

1383 if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {

1387 }

1388 if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {

1391 }

1392 }

1393 }

1394

1395

1396

1397

1398 if (def->getNumOperands() == 2) {

1399 const auto [x, xSpVals] = buildTensorExp(op, def->getOperand(0));

1400 const auto [y, ySpVals] = buildTensorExp(op, def->getOperand(1));

1401

1402

1403

1404 bool conjSpVals = xSpVals || ySpVals;

1405 bool disjSpVals = xSpVals && ySpVals;

1406 if (x.has_value() && y.has_value()) {

1407 const ExprId e0 = *x;

1408 const ExprId e1 = *y;

1409 if (isaarith::MulFOp(def))

1411 if (isacomplex::MulOp(def))

1413 if (isaarith::MulIOp(def))

1415 if (isaarith::DivFOp(def) && !maybeZero(e1))

1417 if (isacomplex::DivOp(def) && !maybeZero(e1))

1419 if (isaarith::DivSIOp(def) && !maybeZero(e1))

1421 if (isaarith::DivUIOp(def) && !maybeZero(e1))

1423 if (isaarith::AddFOp(def))

1425 if (isacomplex::AddOp(def))

1427 if (isaarith::AddIOp(def))

1429 if (isaarith::SubFOp(def))

1431 if (isacomplex::SubOp(def))

1433 if (isaarith::SubIOp(def))

1435 if (isaarith::AndIOp(def))

1437 if (isaarith::OrIOp(def))

1439 if (isaarith::XOrIOp(def))

1441 if (isaarith::ShRSIOp(def) && isInvariant(e1))

1443 if (isaarith::ShRUIOp(def) && isInvariant(e1))

1445 if (isaarith::ShLIOp(def) && isInvariant(e1))

1447 if (auto ci = dyn_castarith::CmpIOp(def)) {

1448 if (ci.getPredicate() == arith::CmpIPredicate::eq ||

1449 ci.getPredicate() == arith::CmpIPredicate::sle ||

1450 ci.getPredicate() == arith::CmpIPredicate::sge ||

1451 ci.getPredicate() == arith::CmpIPredicate::ule ||

1452 ci.getPredicate() == arith::CmpIPredicate::uge) {

1453 // Must be a disjunction: a predicate equals exactly one enumerator, so joining these tests with && was always false and this bail-out was dead code.

1454 // NOTE(review): presumably these predicates are rejected because comparing two zeros yields true - confirm.

1455 return {std::nullopt, false};

1456 }

1457

1459 ci.getPredicateAttr());

1460 return {e, conjSpVals};

1461 }

1462 if (auto cf = dyn_castarith::CmpFOp(def)) {

1463 if (cf.getPredicate() == arith::CmpFPredicate::OEQ ||

1464 cf.getPredicate() == arith::CmpFPredicate::OGE ||

1465 cf.getPredicate() == arith::CmpFPredicate::OLE ||

1466 cf.getPredicate() == arith::CmpFPredicate::ONE ||

1467 cf.getPredicate() == arith::CmpFPredicate::UEQ ||

1468 cf.getPredicate() == arith::CmpFPredicate::UGE ||

1469 cf.getPredicate() == arith::CmpFPredicate::ULE ||

1470 cf.getPredicate() == arith::CmpFPredicate::ORD ||

1471 cf.getPredicate() == arith::CmpFPredicate::UNO) {

1472 // Must be a disjunction: a predicate equals exactly one enumerator, so joining these tests with && was always false and this bail-out was dead code.

1473 // NOTE(review): predicate list kept exactly as it was; confirm each entry (e.g. ONE) really needs rejecting.

1474 return {std::nullopt, false};

1475 }

1477 cf.getPredicateAttr());

1478 return {e, conjSpVals};

1479 }

1480 if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {

1482 (binop.getLeftIdentity() ||

1484 (binop.getRightIdentity() ||

1487 }

1488 }

1489 }

1490

1491

1492 if (def->getNumOperands() == 3) {

1493 const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));

1494 const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));

1495 const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));

1496 bool hasSpDep = xDepSp || yDepSp || zDepSp;

1497 if (x.has_value() && y.has_value() && z.has_value()) {

1498 const ExprId e0 = *x;

1499 const ExprId e1 = *y;

1500 if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {

1503 }

1504 if (auto selop = dyn_castarith::SelectOp(def)) {

1505

1506

1507

1508 const auto &cnd = exp(*x);

1509 if (isGreater(cnd.kind, cnd.attr) &&

1513 const auto &a = exp(cnd.children.e0);

1514 const auto &b = exp(cnd.children.e1);

1519 nullptr, cnd.attr),

1520 yDepSp};

1521 }

1522 }

1523 }

1524 }

1525 }

1526

1527

1528

1529

1530

1531 if (def->getNumResults() != 1)

1532 return {std::nullopt, false};

1533 SmallVector<std::pair<std::optional, bool>, 2> subExp;

1534

1535 for (Value operand : def->getOperands())

1536 subExp.push_back(buildTensorExp(op, operand));

1537

1538 if (llvm::all_of(subExp,

1539 [](auto e) { return e.first.has_value() && !e.second; })) {

1540

1541 if (subExp.size() == 2) {

1543 *subExp[1].first, def);

1544 return {e, false};

1545 }

1546 if (subExp.size() == 1) {

1549 return {e, false};

1550 }

1551 }

1552

1553

1554 return {std::nullopt, false};

1555 }

1556

1559

1562 region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);

1563 Block &clonedBlock = tmpRegion.front();

1564 YieldOp clonedYield = cast(clonedBlock.getTerminator());

1565

1568 Value val = clonedYield.getSingleResult();

1569 rewriter.eraseOp(clonedYield);

1570 rewriter.eraseOp(placeholder);

1571 return val;

1572 }

1573

1576 if (!v0)

1577

1579 UnaryOp unop = cast(op);

1580 Region &presentRegion = unop.getPresentRegion();

1581 if (presentRegion.empty())

1582

1583

1585 return insertYieldOp(rewriter, loc, presentRegion, {v0});

1586 }

1587

1590 if (!v0 || !v1)

1591

1593 BinaryOp binop = cast(op);

1594 Region &overlapRegion = binop.getOverlapRegion();

1595 if (overlapRegion.empty())

1596

1597

1599 return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});

1600 }

1601

1605 auto zero =

1606 rewriter.createarith::ConstantOp(loc, tp, rewriter.getZeroAttr(tp));

1608 if (isa(tp)) {

1609 auto pred = llvm::castarith::CmpFPredicateAttr(attr);

1610 cmp = rewriter.createarith::CmpFOp(loc, pred, v0, zero);

1611 } else {

1612 auto pred = llvm::castarith::CmpIPredicateAttr(attr);

1613 cmp = rewriter.createarith::CmpIOp(loc, pred, v0, zero);

1614 }

1615 return rewriter.createarith::SelectOp(loc, cmp, v0, zero);

1616 }

1617

1619 Value v1) const {

1620 const auto &expr = exp(e);

1621 switch (expr.kind) {

1622

1627 llvm_unreachable("unexpected non-op");

1628

1630 return rewriter.createmath::AbsFOp(loc, v0);

1632 auto type = cast(v0.getType());

1633 auto eltType = cast(type.getElementType());

1634 return rewriter.createcomplex::AbsOp(loc, eltType, v0);

1635 }

1637 return rewriter.createmath::AbsIOp(loc, v0);

1639 return rewriter.createmath::CeilOp(loc, v0);

1641 return rewriter.createmath::FloorOp(loc, v0);

1643 return rewriter.createmath::SqrtOp(loc, v0);

1645 return rewriter.createcomplex::SqrtOp(loc, v0);

1647 return rewriter.createmath::ExpM1Op(loc, v0);

1649 return rewriter.createcomplex::Expm1Op(loc, v0);

1651 return rewriter.createmath::Log1pOp(loc, v0);

1653 return rewriter.createcomplex::Log1pOp(loc, v0);

1655 return buildRelu(rewriter, loc, v0, expr.attr);

1657 return rewriter.createmath::SinOp(loc, v0);

1659 return rewriter.createcomplex::SinOp(loc, v0);

1661 return rewriter.createmath::TanhOp(loc, v0);

1663 return rewriter.createcomplex::TanhOp(loc, v0);

1665 return rewriter.createarith::NegFOp(loc, v0);

1667 return rewriter.createcomplex::NegOp(loc, v0);

1669 return rewriter.createarith::SubIOp(

1670 loc,

1671 rewriter.createarith::ConstantOp(loc, v0.getType(),

1673 v0);

1675 return rewriter.createarith::TruncFOp(loc, inferType(e, v0), v0);

1677 return rewriter.createarith::ExtFOp(loc, inferType(e, v0), v0);

1679 return rewriter.createarith::FPToSIOp(loc, inferType(e, v0), v0);

1681 return rewriter.createarith::FPToUIOp(loc, inferType(e, v0), v0);

1683 return rewriter.createarith::SIToFPOp(loc, inferType(e, v0), v0);

1685 return rewriter.createarith::UIToFPOp(loc, inferType(e, v0), v0);

1687 return rewriter.createarith::ExtSIOp(loc, inferType(e, v0), v0);

1689 return rewriter.createarith::ExtUIOp(loc, inferType(e, v0), v0);

1691 return rewriter.createarith::IndexCastOp(loc, inferType(e, v0), v0);

1693 return rewriter.createarith::TruncIOp(loc, inferType(e, v0), v0);

1695 auto type = cast(v0.getType());

1696 auto eltType = cast(type.getElementType());

1697 return rewriter.createcomplex::ImOp(loc, eltType, v0);

1698 }

1700 auto type = cast(v0.getType());

1701 auto eltType = cast(type.getElementType());

1702 return rewriter.createcomplex::ReOp(loc, eltType, v0);

1703 }

1705 return rewriter.createarith::BitcastOp(loc, inferType(e, v0), v0);

1706

1708 return rewriter.createarith::MulFOp(loc, v0, v1);

1710 return rewriter.createcomplex::MulOp(loc, v0, v1);

1712 return rewriter.createarith::MulIOp(loc, v0, v1);

1714 return rewriter.createarith::DivFOp(loc, v0, v1);

1716 return rewriter.createcomplex::DivOp(loc, v0, v1);

1718 return rewriter.createarith::DivSIOp(loc, v0, v1);

1720 return rewriter.createarith::DivUIOp(loc, v0, v1);

1722 return rewriter.createarith::AddFOp(loc, v0, v1);

1724 return rewriter.createcomplex::AddOp(loc, v0, v1);

1726 return rewriter.createarith::AddIOp(loc, v0, v1);

1728 return rewriter.createarith::SubFOp(loc, v0, v1);

1730 return rewriter.createcomplex::SubOp(loc, v0, v1);

1732 return rewriter.createarith::SubIOp(loc, v0, v1);

1734 return rewriter.createarith::AndIOp(loc, v0, v1);

1736 return rewriter.createarith::OrIOp(loc, v0, v1);

1738 return rewriter.createarith::XOrIOp(loc, v0, v1);

1740 return rewriter.createarith::ShRSIOp(loc, v0, v1);

1742 return rewriter.createarith::ShRUIOp(loc, v0, v1);

1744 return rewriter.createarith::ShLIOp(loc, v0, v1);

1746 auto predicate = llvm::castarith::CmpIPredicateAttr(expr.attr);

1747 return rewriter.createarith::CmpIOp(loc, predicate, v0, v1);

1748 }

1750 auto predicate = llvm::castarith::CmpFPredicateAttr(expr.attr);

1751 return rewriter.createarith::CmpFOp(loc, predicate, v0, v1);

1752 }

1755 {v0});

1760 cast<sparse_tensor::SelectOp>(expr.op).getRegion(),

1761 {v0});

1765 ReduceOp redOp = cast(expr.op);

1766 return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});

1767 }

1774 return rewriter.clone(*actualOp, mapping)->getResult(0);

1775 }

1776 }

1777 llvm_unreachable("unexpected expression kind in build");

1778 }

1779

1780 }

1781 }

union mlir::linalg::@1204::ArityGroupAndKind::Kind kind

Attributes are known-constant values of operations.

Block represents an ordered list of Operations.

Region * getParent() const

Provide a 'getParent' method for ilist_node_with_parent methods.

Operation * getTerminator()

Get the terminator operation of this block.

TypedAttr getZeroAttr(Type type)

This is a utility class for mapping one set of IR entities to another.

void map(Value from, Value to)

Inserts a new mapping for 'from' to 'to'.

IRValueT get() const

Return the current value being used by this operand.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

Operation * clone(Operation &op, IRMapping &mapper)

Creates a deep copy of the specified operation, remapping any operands that use values outside of the...

Block * getBlock() const

Returns the current block of the builder.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

This class represents an operand of an operation.

Operation is the basic unit of execution within MLIR.

Value getOperand(unsigned idx)

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

unsigned getNumOperands()

Block * getBlock()

Returns the operation block that contains this operation.

Region & getRegion(unsigned index)

Returns the region held by this operation at position 'index'.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

void cloneInto(Region *dest, IRMapping &mapper)

Clone the internal blocks from this region into dest.

This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)

Inline the operations of block 'source' into block 'dest' before the given position.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

Specialization of arith.constant op that returns a floating point value.

Specialization of arith.constant op that returns an integer of index type.

Specialization of arith.constant op that returns an integer value.

LatPointId conjLat(ExprId e, LatPointId p0, LatPointId p1, Operation *op=nullptr)

Computes a single conjunction of two lattice points by taking the "union" of LoopId (effectively cons...

LatSetId disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)

Disjunctive merge of two lattice sets: (s0 /\_op s1, s0, s1).

bool isSingleCondition(TensorId t, ExprId e) const

Returns true if the given tensor iterates only in the given tensor expression.

bool hasSparseIdxReduction(const BitVector &bits) const

Returns true if bits contains a dependent index reduction condition on sparse levels.

bool expContainsTensor(ExprId e, TensorId t) const

Returns true if the expression contains the tensor as an operand.

LatSetId mapBinWithSynZeroSet(ExprId e, LatSetId s, bool lhsZero)

Maps the binary operator to the same operation but with one of its operands set to zero,...

bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const

Checks whether the TensorLoopId represents a sparse tensor level that contains a non-trivial index expression.

void dumpBits(const BitVector &bits) const

LatSetId addSet()

Constructs a new (initially empty) set, and returns its identifier.

BitVector simplifyCond(LatSetId s, LatPointId p)

Simplifies the conditions in a conjunction of a given lattice point within the given set using just t...

bool hasNegateOnOut(ExprId e) const

Returns true if the expression contains a negation on output tensor.

bool isLvlWithNonTrivialIdxExp(TensorLoopId b) const

Checks whether the TensorLoopId represents a tensor level that contains a non-trivial index expression.

LatSetId disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1)

Disjunctive merge of two lattice sets that also sets one of the operands to zero: (s0 /\_op s1 (e0 op e1...

void dumpSet(LatSetId s) const

void dumpLat(LatPointId p) const

LatSetId combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig, bool includeLeft, TensorExp::Kind ltrans, Operation *opleft, bool includeRight, TensorExp::Kind rtrans, Operation *opright)

Disjunctive merge of two lattice sets with custom handling of the overlap, left, and right regions.

ExprId addTensorExp(TensorId t)

Constructs a new tensor expression, and returns its identifier.

LatSetId buildLattices(ExprId e, LoopId i)

Builds the iteration lattices in a bottom-up traversal given the remaining tensor (sub)expression and...

LatSetId conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)

Conjunctive merge of two lattice sets: (s0 /\_op s1).

ExprId addExp(TensorExp::Kind k, ExprId e0, ExprId e1=detail::kInvalidId, Operation *op=nullptr, Attribute attr=nullptr)

Constructs a new unary or binary expression, and returns its identifier.

ExprId addSynZeroExp()

Constructs a new synthetic zero expression.

constexpr LoopId makeLoopId(unsigned i) const

Safely converts the argument to a loop identifier.

std::optional< ExprId > buildTensorExpFromLinalg(linalg::GenericOp op)

Builds a tensor expression from the given Linalg operation.

LatSetId mapSet(TensorExp::Kind kind, LatSetId s, Value v=Value(), Operation *op=nullptr, Attribute attr=nullptr)

Maps the unary operator over the lattice set of the operand, i.e.

ArrayRef< LatPointId > set(LatSetId s) const

LatSetId optimizeSet(LatSetId s)

Optimizes the iteration lattice points in the given set.

constexpr TensorId tensor(TensorLoopId b) const

Gets the tensor-identifier of the TensorLoopId.

void dumpExp(ExprId e) const

Print methods (for debugging).

Merger(unsigned numInputOutputTensors, unsigned numLoops, unsigned maxLvlRank)

Constructs a merger for the given number of tensors and loops.

bool hasAnySparse(const BitVector &bits) const

Returns true if any TensorLoopId in the bitvector corresponds to sparse level-type.

LatPointId addLat(TensorId t, LoopId i, ExprId e)

Constructs a new iteration lattice point, and returns its identifier.

ExprId addLoopVarExp(LoopId i)

Constructs a new loop-variable expression, and returns its identifier.

bool latGT(LatPointId p0, LatPointId p1) const

Returns true if p0 > p1.

const TensorExp & exp(ExprId e) const

Convenience getters to immediately access the stored nodes.

constexpr LoopId loop(TensorLoopId b) const

Gets the loop-identifier of the TensorLoopId.

const LatPoint & lat(LatPointId p) const

bool onlyDenseDiff(LatPointId p0, LatPointId p1) const

Returns true if p0 and p1 only differ in dense.

ExprId addInvariantExp(Value v)

Constructs a new invariant expression, and returns its identifier.

constexpr TensorId makeTensorId(unsigned t) const

Safely converts the argument to a tensor identifier.

LevelType getLvlType(TensorId t, LoopId i) const

Gets the level-type of the tth tensor on ith loop.

Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, Value v1) const

Rebuilds SSA format from a tensor expression.

constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const

Safely converts the arguments to a pair of (tensor,loop) identifiers.

bool expIsTensor(ExprId e, TensorId t) const

Returns true if the expression is (kTensor t).

@ Type

An inlay hint that is for a type annotation.

static constexpr unsigned kInvalidId

A constant serving as the canonically invalid identifier, regardless of the identifier type.

LevelFormat

This enum defines all supported storage formats without the level properties.

static bool isCertainZero(Value val)

Only returns true if we are certain this is a zero.

static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v)

Ensures that the sparsifier can generate code for expression.

unsigned LatSetId

LatSet identifiers.

static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region, ValueRange vals)

std::string toMLIRString(LevelType lt)

std::pair< Level, LevelType > LvlLTPair

A pair of level and its corresponding LevelType of a tensor.

unsigned TensorLoopId

A compressed representation of std::pair<TensorId, LoopId>.

static Value buildUnaryPresent(RewriterBase &rewriter, Location loc, Operation *op, Value v0)

uint64_t Level

The type of level identifiers and level-ranks.

static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc, Operation *op, Value v0, Value v1)

static bool isAdmissibleBranch(Operation *op, Region &region)

Ensures that the sparsifier can generate code for branch.

unsigned LoopId

Loop identifiers.

static const char * kindToOpSymbol(TensorExp::Kind kind)

SparseTensorEncodingAttr getSparseTensorEncoding(Type type)

Convenience method to get a sparse encoding attribute from a type.

static bool isGreater(TensorExp::Kind kind, Attribute attr)

unsigned ExprId

TensorExp identifiers.

static ExpArity getExpArity(TensorExp::Kind k)

static Value buildRelu(RewriterBase &rewriter, Location loc, Value v0, Attribute attr)

unsigned LatPointId

LatPoint identifiers.

std::pair< LoopId, unsigned > LoopCoeffPair

A pair of loop id and its coefficients.

unsigned TensorId

Tensor identifiers, chosen to be the BlockArgument::getArgNumber of the value passed to Merger::build...

Include the generated interface declarations.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

BitVector bits

Conjunction of all TensorLoopIds involved in the tensor expression.

This enum defines all the sparse representations supportable by the SparseTensor dialect.

LoopId loop

kLoopVar expressions simply have a loop identifier.

Value val

Direct link to IR for an invariant or the destination value (to infer destination type) of a cast ope...

Kind

Tensor expression kind.

Children children

All other expressions hold the ExprIds of their children.

Attribute attr

An optional attribute that is required to determine the semantics of the operations.

TensorId tensor

kTensor expressions simply have a tensor identifier.

Kind kind

Tensor expression kind.

TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *op, Attribute a)

The x parameter has different types depending on the value of the k parameter.

Eliminates variable at the specified position using Fourier-Motzkin variable elimination.