MLIR: lib/Dialect/SparseTensor/Utils/Merger.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
14
16 #include "llvm/Support/Debug.h"
17 #include
18
19 namespace mlir {
20 namespace sparse_tensor {
21
26 };
27
29 switch (k) {
30
72
98 }
99 llvm_unreachable("unexpected kind");
100 }
101
102
103
104
105
108 : kind(k), val(v), op(o), attr(a) {
109 switch (kind) {
110
114 return;
117 return;
120 return;
124 return;
125
150 return;
165 return;
171 return;
173
174
178 return;
179
202 return;
208 return;
214 return;
219 return;
220 }
221 llvm_unreachable("unexpected kind");
222 }
223
224 Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
225 unsigned maxLvlRank)
226 : outTensor(numInputOutputTensors - 1),
227 syntheticTensor(numInputOutputTensors),
228 numTensors(numInputOutputTensors + 1), numLoops(numLoops),
229 hasSparseOut(false),
230 lvlTypes(numTensors,
232 loopToLvl(numTensors,
233 std::vector<std::optional<Level>>(numLoops, std::nullopt)),
234 lvlToLoop(numTensors,
235 std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
236 loopToUnresolvedLvls(numLoops, std::vector<std::optional<LvlLTPair>>(
237 numTensors, std::nullopt)),
238 levelToDependentLoop(numTensors,
241 loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
242
243
244
245
246
248 assert(isValidTensorId(t));
249 const ExprId eNew(tensorExps.size());
251 Value(), nullptr, nullptr);
252 return eNew;
253 }
254
256 assert(isValidLoopId(i));
257 const ExprId eNew(tensorExps.size());
259 Value(), nullptr, nullptr);
260 return eNew;
261 }
262
264 const ExprId eNew(tensorExps.size());
267 return eNew;
268 }
269
271 const ExprId eNew(tensorExps.size());
274 return eNew;
275 }
276
280 const ExprId eNew(tensorExps.size());
281 tensorExps.emplace_back(k, e0, e1, Value(), op, attr);
282 return eNew;
283 }
284
288 const ExprId eNew(tensorExps.size());
290 return eNew;
291 }
292
294 const LatPointId pNew(latPoints.size());
295 const unsigned size = numLoops * numTensors;
297 latPoints.emplace_back(size, e);
298 latPoints[pNew].bits.set(b);
299 return pNew;
300 }
301
303 assert(bits.size() == numLoops * numTensors);
304 const LatPointId pNew(latPoints.size());
305 latPoints.emplace_back(bits, e);
306 return pNew;
307 }
308
310 const LatSetId sNew(latSets.size());
311 latSets.emplace_back();
312 return sNew;
313 }
314
319 const LatPointId pNew(latPoints.size());
320 const auto &point0 = lat(p0);
321 const auto &point1 = lat(p1);
322 BitVector bits(point0.bits);
323 bits |= point1.bits;
324 const ExprId ne = addExp(kind, point0.exp, point1.exp, op, attr);
325 latPoints.emplace_back(bits, ne);
326 return pNew;
327 }
328
331 auto &setNew = latSets[sNew];
334 setNew.push_back(conjLat(e, p0, p1, op));
335 return sNew;
336 }
337
341
342 latSets[sNew].append(latSets[s0]);
343
344
351
352 latSets[sNew].append(latSets[s1]);
353 return sNew;
354 }
355
360
365
367
368
369
370 return sNew;
371 }
372
375 latSets[sNew].append(latSets[lhsSet]);
376 latSets[sNew].append(latSets[rhsSet]);
377 return sNew;
378 }
379
382 Operation *opleft, bool includeRight,
386
387 if (includeLeft) {
388 if (opleft)
389 s0 = mapSet(ltrans, s0, Value(), opleft, a);
390 latSets[sNew].append(latSets[s0]);
391 }
392
393 if (includeRight) {
394 if (opright)
395 s1 = mapSet(rtrans, s1, Value(), opright, a);
396 latSets[sNew].append(latSets[s1]);
397 }
398 return sNew;
399 }
400
406 auto &setNew = latSets[sNew];
408 const auto &point = latPoints[p];
409 setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op, a)));
410 }
411 return sNew;
412 }
413
418
420 auto &setNew = latSets[sNew];
423 const auto &point = latPoints[p];
424 ExprId newExp = lhsZero ? addExp(kind, zeroExp, point.exp, nullptr, a)
425 : addExp(kind, point.exp, zeroExp, nullptr, a);
426 setNew.push_back(addLat(point.bits, newExp));
427 }
428 return sNew;
429 }
430
433 auto &setNew = latSets[sNew];
434 const auto &set0 = set(s0);
435 assert(!set0.empty());
438 bool add = true;
439 if (p0 != p1) {
440
442 continue;
443
445 assert((p1, p2));
447 add = false;
448 break;
449 }
450 }
451 assert(!add || latGT(p0, p1));
452 }
453 if (add)
454 setNew.push_back(p1);
455 }
458 return sNew;
459 }
460
462
463
464 bool isSingleton = true;
466 if (p0 != p1 && latGT(p0, p1)) {
467 isSingleton = false;
468 break;
469 }
470 }
471
472 BitVector simple(latPoints[p0].bits);
473 bool reset = isSingleton && hasAnySparse(simple);
475 TensorLoopId offset = 0;
476 if (!reset)
477
478
479 for (unsigned b = 0; b < be; b++) {
481 offset = be - b - 1;
482 break;
483 }
484 }
485
486
487
488 for (unsigned b = be - 1 - offset, i = 0; i < be;
489 b = b == 0 ? be - 1 : b - 1, i++) {
490
493 if (!lt.hasSparseSemantic()) {
494 if (reset)
495 simple.reset(b);
496 reset = true;
497 }
498 }
499 }
500 return simple;
501 }
502
504 const BitVector &bitsi = lat(i).bits;
505 const BitVector &bitsj = lat(j).bits;
506 assert(bitsi.size() == bitsj.size());
507 if (bitsi.count() > bitsj.count()) {
508 for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
509 if (bitsj[b] && !bitsi[b])
510 return false;
511 return true;
512 }
513 return false;
514 }
515
517 BitVector tmp(latPoints[j].bits);
518 tmp ^= latPoints[i].bits;
520 }
521
523 const auto &expr = exp(e);
524
526 return expr.tensor == t;
527
530 return false;
532 const ExprId e0 = expr.children.e0;
534 }
536 const ExprId e0 = expr.children.e0;
537 const ExprId e1 = expr.children.e1;
539 }
540 }
541 llvm_unreachable("unexpected arity");
542 }
543
545 const auto &expr = exp(e);
546 switch (expr.kind) {
560 return lhsNeg;
561 }
562 default: {
565 return false;
571 }
572 }
573 }
574 llvm_unreachable("unexpected kind");
575 }
576
578 assert(isValidTensorId(t));
579 const auto &expr = exp(e);
580 switch (expr.kind) {
581
583 return expr.tensor == t;
587 return false;
588
625 return false;
626
631 assert(!maybeZero(expr.children.e1));
636 assert(isInvariant(expr.children.e1));
645 isInvariant(expr.children.e1);
647 return isInvariant(expr.children.e0);
648 return false;
662 return false;
664
665
666 return true;
667 }
668 llvm_unreachable("unexpected kind");
669 }
670
674 if (lt.hasSparseSemantic())
675 return true;
676 }
678 }
679
683 return true;
684 return false;
685 }
686
687 #ifndef NDEBUG
688
689
690
691
692
694 switch (kind) {
695
697 return "tensor";
699 return "invariant";
701 return "index";
703 return "0";
704
708 return "abs";
710 return "ceil";
712 return "floor";
715 return "sqrt";
718 return "expm1";
721 return "log1p";
723 return "relu";
726 return "sin";
729 return "tanh";
733 return "-";
745 return "complex.im";
747 return "complex.re";
749 return "cast";
751 return "binary_branch";
753 return "unary";
755 return "select";
756
760 return "*";
765 return "/";
769 return "+";
773 return "-";
775 return "&";
777 return "|";
779 return "^";
781 return "a>>";
783 return ">>";
785 return "<<";
788 return "cmp";
790 return "binary";
792 return "reduce";
794 return "dense";
795 }
796 llvm_unreachable("unexpected kind for symbol");
797 }
798
800 const auto &expr = exp(e);
801 switch (expr.kind) {
802
804 if (expr.tensor == syntheticTensor)
805 llvm::dbgs() << "synthetic_";
806 else if (expr.tensor == outTensor)
807 llvm::dbgs() << "output_";
808 llvm::dbgs() << "tensor_" << expr.tensor;
809 break;
811 llvm::dbgs() << "invariant";
812 break;
814 llvm::dbgs() << "0";
815 break;
817 llvm::dbgs() << "loopvar_" << expr.loop;
818 break;
819
856 dumpExp(expr.children.e0);
857 break;
858
883 llvm::dbgs() << "(";
884 dumpExp(expr.children.e0);
886 if (expr.attr)
887 llvm::dbgs() << "{" << expr.attr << "}";
889 llvm::dbgs() << " ";
890 dumpExp(expr.children.e1);
891 llvm::dbgs() << ")";
892 } else {
894 }
895 break;
896 }
897 }
898
900 const auto &point = lat(p);
901 llvm::dbgs() << "lat(";
903 llvm::dbgs() << " :";
905 llvm::dbgs() << " : ";
907 llvm::dbgs() << " )\n";
908 }
909
911 const auto &ss = set(s);
912 llvm::dbgs() << "{ #" << ss.size() << "\n";
914 llvm::dbgs() << " ";
916 }
917 llvm::dbgs() << "}\n";
918 }
919
921 for (TensorLoopId b = 0, be = bits.size(); b < be; b++) {
922 if (bits[b]) {
925 const auto lt = lvlTypes[t][i];
927 llvm::dbgs() << " DEP_" << t << "_" << i;
928 else
929 llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(lt);
930 }
931 }
932 }
933
934 #endif
935
936
937
938
939
941
942
943
944
945 const auto &expr = exp(e);
947 switch (kind) {
948
953
954
955
956
957
959 TensorId t = syntheticTensor;
961 t = expr.tensor;
962 if (hasSparseOut && t == outTensor)
963 t = syntheticTensor;
964 }
965 latSets[s].push_back(addLat(t, i, e));
966 return s;
967 }
968
1001
1002
1003
1004
1005
1006
1007 {
1008 const ExprId e0 = expr.children.e0;
1009 const Value v = expr.val;
1012 }
1015
1016
1017 {
1018 const ExprId e0 = expr.children.e0;
1021 }
1023
1024
1025
1026
1027
1028 {
1029 const ExprId e0 = expr.children.e0;
1030 UnaryOp unop = cast(expr.op);
1032 Region &absentRegion = unop.getAbsentRegion();
1033 if (absentRegion.empty()) {
1034
1036 }
1037
1038
1039 Block &absentBlock = absentRegion.front();
1040 YieldOp absentYield = cast(absentBlock.getTerminator());
1041 const Value absentVal = absentYield.getSingleResult();
1044 }
1045
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059 {
1060 const ExprId e0 = expr.children.e0;
1061 const ExprId e1 = expr.children.e1;
1063 }
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081 {
1082 const ExprId e0 = expr.children.e0;
1083 const ExprId e1 = expr.children.e1;
1084 assert(!maybeZero(e1));
1086 }
1095
1096
1097
1098
1099
1100
1101
1102 {
1103 const ExprId e0 = expr.children.e0;
1104 const ExprId e1 = expr.children.e1;
1106 }
1109
1110
1111
1112
1113
1114
1115
1116 {
1117 const ExprId e0 = expr.children.e0;
1118 const ExprId e1 = expr.children.e1;
1120 }
1124
1125
1126
1127 {
1128 const ExprId e0 = expr.children.e0;
1129 const ExprId e1 = expr.children.e1;
1130 assert(isInvariant(e1));
1132 }
1134
1135
1136
1137
1138
1139
1140 {
1141 const ExprId e0 = expr.children.e0;
1142 const ExprId e1 = expr.children.e1;
1143 BinaryOp binop = cast(expr.op);
1146 Region &leftRegion = binop.getLeftRegion();
1147 Region &rightRegion = binop.getRightRegion();
1148
1150 if (!leftRegion.empty()) {
1151 Block &leftBlock = leftRegion.front();
1153 }
1154
1155 Operation *rightYield = nullptr;
1156 if (!rightRegion.empty()) {
1157 Block &rightBlock = rightRegion.front();
1159 }
1160 bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
1161 bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
1162 return combiSet(e, child0, child1, binop, includeLeft,
1165 }
1167
1168 {
1169 const ExprId e0 = expr.children.e0;
1170 const ExprId e1 = expr.children.e1;
1173 }
1175
1176
1177
1179 const ExprId e0 = expr.children.e0;
1182 }
1183
1184 const ExprId e0 = expr.children.e0;
1185 const ExprId e1 = expr.children.e1;
1188 }
1189 }
1190 llvm_unreachable("unexpected expression kind");
1191 }
1192
1194
1196 assert(isalinalg::YieldOp(yield));
1197 return buildTensorExp(op, yield->getOperand(0)).first;
1198 }
1199
1200
1202 if (auto c = val.getDefiningOpcomplex::ConstantOp()) {
1203 ArrayAttr arrayAttr = c.getValue();
1204 return cast(arrayAttr[0]).getValue().isZero() &&
1205 cast(arrayAttr[1]).getValue().isZero();
1206 }
1208 return c.value() == 0;
1210 return c.value().isZero();
1211 return false;
1212 }
1213
1214
1215 bool Merger::maybeZero(ExprId e) const {
1216 const auto &expr = exp(e);
1218
1219
1220 if (auto c = expr.val.getDefiningOpcomplex::ConstantOp()) {
1221 ArrayAttr arrayAttr = c.getValue();
1222 return cast(arrayAttr[0]).getValue().isZero() &&
1223 cast(arrayAttr[1]).getValue().isZero();
1224 }
1225 if (auto c = expr.val.getDefiningOparith::ConstantIntOp())
1226 return c.value() == 0;
1227 if (auto c = expr.val.getDefiningOparith::ConstantFloatOp())
1228 return c.value().isZero();
1229 }
1230 return true;
1231 }
1232
1233 Type Merger::inferType(ExprId e, Value src) const {
1234
1236
1237
1238 if (auto vtp = dyn_cast(src.getType()))
1239 return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());
1240 return dtp;
1241 }
1242
1243
1245
1246 if (isa(v))
1247 return true;
1248
1250 if (isalinalg::IndexOp(def))
1251 return true;
1252
1253 if (def->getBlock() != block)
1255
1256
1257 for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
1259 return false;
1260 return true;
1261 }
1262
1263
1265 if (region.empty())
1266 return true;
1267
1269 assert(isa(yield));
1271 }
1272
1273
1276 auto pred = llvm::castarith::CmpIPredicateAttr(attr).getValue();
1277 return pred == arith::CmpIPredicate::ugt ||
1278 pred == arith::CmpIPredicate::sgt;
1279 }
1281 auto pred = llvm::castarith::CmpFPredicateAttr(attr).getValue();
1282 return pred == arith::CmpFPredicate::UGT ||
1283 pred == arith::CmpFPredicate::OGT;
1284 }
1285 return false;
1286 }
1287
1288 std::pair<std::optional, bool>
1289 Merger::buildTensorExp(linalg::GenericOp op, Value v) {
1290
1291 if (auto arg = dyn_cast(v)) {
1293
1294
1295
1296 if (arg.getOwner()->getParentOp() == op) {
1297 OpOperand &t = op->getOpOperand(tid);
1299 if (!op.isScalar(&t))
1301 v = t.get();
1302 }
1303
1304
1306 }
1307
1308
1310 if (def->getBlock() != &op.getRegion().front())
1312
1313 if (def->getNumOperands() == 0) {
1314 if (auto indexOp = dyn_castlinalg::IndexOp(def))
1316 }
1317
1318
1319 if (def->getNumOperands() == 1) {
1320 const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0));
1321 if (x.has_value()) {
1322 const ExprId e = *x;
1323 if (isamath::AbsFOp(def))
1325 if (isacomplex::AbsOp(def))
1327 if (isamath::AbsIOp(def))
1329 if (isamath::CeilOp(def))
1331 if (isamath::FloorOp(def))
1333 if (isamath::SqrtOp(def))
1335 if (isacomplex::SqrtOp(def))
1337 if (isamath::ExpM1Op(def))
1339 if (isacomplex::Expm1Op(def))
1341 if (isamath::Log1pOp(def))
1343 if (isacomplex::Log1pOp(def))
1345 if (isamath::SinOp(def))
1347 if (isacomplex::SinOp(def))
1349 if (isamath::TanhOp(def))
1351 if (isacomplex::TanhOp(def))
1353 if (isaarith::NegFOp(def))
1355 if (isacomplex::NegOp(def))
1357 if (isaarith::TruncFOp(def))
1359 if (isaarith::ExtFOp(def))
1361 if (isaarith::FPToSIOp(def))
1363 if (isaarith::FPToUIOp(def))
1365 if (isaarith::SIToFPOp(def))
1367 if (isaarith::UIToFPOp(def))
1369 if (isaarith::ExtSIOp(def))
1371 if (isaarith::ExtUIOp(def))
1373 if (isaarith::IndexCastOp(def))
1375 if (isaarith::TruncIOp(def))
1377 if (isacomplex::ImOp(def))
1379 if (isacomplex::ReOp(def))
1381 if (isaarith::BitcastOp(def))
1383 if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
1387 }
1388 if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
1391 }
1392 }
1393 }
1394
1395
1396
1397
1398 if (def->getNumOperands() == 2) {
1399 const auto [x, xSpVals] = buildTensorExp(op, def->getOperand(0));
1400 const auto [y, ySpVals] = buildTensorExp(op, def->getOperand(1));
1401
1402
1403
1404 bool conjSpVals = xSpVals || ySpVals;
1405 bool disjSpVals = xSpVals && ySpVals;
1406 if (x.has_value() && y.has_value()) {
1407 const ExprId e0 = *x;
1408 const ExprId e1 = *y;
1409 if (isaarith::MulFOp(def))
1411 if (isacomplex::MulOp(def))
1413 if (isaarith::MulIOp(def))
1415 if (isaarith::DivFOp(def) && !maybeZero(e1))
1417 if (isacomplex::DivOp(def) && !maybeZero(e1))
1419 if (isaarith::DivSIOp(def) && !maybeZero(e1))
1421 if (isaarith::DivUIOp(def) && !maybeZero(e1))
1423 if (isaarith::AddFOp(def))
1425 if (isacomplex::AddOp(def))
1427 if (isaarith::AddIOp(def))
1429 if (isaarith::SubFOp(def))
1431 if (isacomplex::SubOp(def))
1433 if (isaarith::SubIOp(def))
1435 if (isaarith::AndIOp(def))
1437 if (isaarith::OrIOp(def))
1439 if (isaarith::XOrIOp(def))
1441 if (isaarith::ShRSIOp(def) && isInvariant(e1))
1443 if (isaarith::ShRUIOp(def) && isInvariant(e1))
1445 if (isaarith::ShLIOp(def) && isInvariant(e1))
1447 if (auto ci = dyn_castarith::CmpIOp(def)) {
1448 if (ci.getPredicate() == arith::CmpIPredicate::eq &&
1449 ci.getPredicate() == arith::CmpIPredicate::sle &&
1450 ci.getPredicate() == arith::CmpIPredicate::sge &&
1451 ci.getPredicate() == arith::CmpIPredicate::ule &&
1452 ci.getPredicate() == arith::CmpIPredicate::uge) {
1453
1454
1455 return {std::nullopt, false};
1456 }
1457
1459 ci.getPredicateAttr());
1460 return {e, conjSpVals};
1461 }
1462 if (auto cf = dyn_castarith::CmpFOp(def)) {
1463 if (cf.getPredicate() == arith::CmpFPredicate::OEQ &&
1464 cf.getPredicate() == arith::CmpFPredicate::OGE &&
1465 cf.getPredicate() == arith::CmpFPredicate::OLE &&
1466 cf.getPredicate() == arith::CmpFPredicate::ONE &&
1467 cf.getPredicate() == arith::CmpFPredicate::UEQ &&
1468 cf.getPredicate() == arith::CmpFPredicate::UGE &&
1469 cf.getPredicate() == arith::CmpFPredicate::ULE &&
1470 cf.getPredicate() == arith::CmpFPredicate::ORD &&
1471 cf.getPredicate() == arith::CmpFPredicate::UNO) {
1472
1473
1474 return {std::nullopt, false};
1475 }
1477 cf.getPredicateAttr());
1478 return {e, conjSpVals};
1479 }
1480 if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
1482 (binop.getLeftIdentity() ||
1484 (binop.getRightIdentity() ||
1487 }
1488 }
1489 }
1490
1491
1492 if (def->getNumOperands() == 3) {
1493 const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
1494 const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
1495 const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));
1496 bool hasSpDep = xDepSp || yDepSp || zDepSp;
1497 if (x.has_value() && y.has_value() && z.has_value()) {
1498 const ExprId e0 = *x;
1499 const ExprId e1 = *y;
1500 if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
1503 }
1504 if (auto selop = dyn_castarith::SelectOp(def)) {
1505
1506
1507
1508 const auto &cnd = exp(*x);
1509 if (isGreater(cnd.kind, cnd.attr) &&
1513 const auto &a = exp(cnd.children.e0);
1514 const auto &b = exp(cnd.children.e1);
1519 nullptr, cnd.attr),
1520 yDepSp};
1521 }
1522 }
1523 }
1524 }
1525 }
1526
1527
1528
1529
1530
1531 if (def->getNumResults() != 1)
1532 return {std::nullopt, false};
1533 SmallVector<std::pair<std::optional, bool>, 2> subExp;
1534
1535 for (Value operand : def->getOperands())
1536 subExp.push_back(buildTensorExp(op, operand));
1537
1538 if (llvm::all_of(subExp,
1539 [](auto e) { return e.first.has_value() && !e.second; })) {
1540
1541 if (subExp.size() == 2) {
1543 *subExp[1].first, def);
1544 return {e, false};
1545 }
1546 if (subExp.size() == 1) {
1549 return {e, false};
1550 }
1551 }
1552
1553
1554 return {std::nullopt, false};
1555 }
1556
1559
1562 region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
1563 Block &clonedBlock = tmpRegion.front();
1564 YieldOp clonedYield = cast(clonedBlock.getTerminator());
1565
1568 Value val = clonedYield.getSingleResult();
1569 rewriter.eraseOp(clonedYield);
1570 rewriter.eraseOp(placeholder);
1571 return val;
1572 }
1573
1576 if (!v0)
1577
1579 UnaryOp unop = cast(op);
1580 Region &presentRegion = unop.getPresentRegion();
1581 if (presentRegion.empty())
1582
1583
1585 return insertYieldOp(rewriter, loc, presentRegion, {v0});
1586 }
1587
1590 if (!v0 || !v1)
1591
1593 BinaryOp binop = cast(op);
1594 Region &overlapRegion = binop.getOverlapRegion();
1595 if (overlapRegion.empty())
1596
1597
1599 return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
1600 }
1601
1605 auto zero =
1606 rewriter.createarith::ConstantOp(loc, tp, rewriter.getZeroAttr(tp));
1608 if (isa(tp)) {
1609 auto pred = llvm::castarith::CmpFPredicateAttr(attr);
1610 cmp = rewriter.createarith::CmpFOp(loc, pred, v0, zero);
1611 } else {
1612 auto pred = llvm::castarith::CmpIPredicateAttr(attr);
1613 cmp = rewriter.createarith::CmpIOp(loc, pred, v0, zero);
1614 }
1615 return rewriter.createarith::SelectOp(loc, cmp, v0, zero);
1616 }
1617
1619 Value v1) const {
1620 const auto &expr = exp(e);
1621 switch (expr.kind) {
1622
1627 llvm_unreachable("unexpected non-op");
1628
1630 return rewriter.createmath::AbsFOp(loc, v0);
1632 auto type = cast(v0.getType());
1633 auto eltType = cast(type.getElementType());
1634 return rewriter.createcomplex::AbsOp(loc, eltType, v0);
1635 }
1637 return rewriter.createmath::AbsIOp(loc, v0);
1639 return rewriter.createmath::CeilOp(loc, v0);
1641 return rewriter.createmath::FloorOp(loc, v0);
1643 return rewriter.createmath::SqrtOp(loc, v0);
1645 return rewriter.createcomplex::SqrtOp(loc, v0);
1647 return rewriter.createmath::ExpM1Op(loc, v0);
1649 return rewriter.createcomplex::Expm1Op(loc, v0);
1651 return rewriter.createmath::Log1pOp(loc, v0);
1653 return rewriter.createcomplex::Log1pOp(loc, v0);
1655 return buildRelu(rewriter, loc, v0, expr.attr);
1657 return rewriter.createmath::SinOp(loc, v0);
1659 return rewriter.createcomplex::SinOp(loc, v0);
1661 return rewriter.createmath::TanhOp(loc, v0);
1663 return rewriter.createcomplex::TanhOp(loc, v0);
1665 return rewriter.createarith::NegFOp(loc, v0);
1667 return rewriter.createcomplex::NegOp(loc, v0);
1669 return rewriter.createarith::SubIOp(
1670 loc,
1671 rewriter.createarith::ConstantOp(loc, v0.getType(),
1673 v0);
1675 return rewriter.createarith::TruncFOp(loc, inferType(e, v0), v0);
1677 return rewriter.createarith::ExtFOp(loc, inferType(e, v0), v0);
1679 return rewriter.createarith::FPToSIOp(loc, inferType(e, v0), v0);
1681 return rewriter.createarith::FPToUIOp(loc, inferType(e, v0), v0);
1683 return rewriter.createarith::SIToFPOp(loc, inferType(e, v0), v0);
1685 return rewriter.createarith::UIToFPOp(loc, inferType(e, v0), v0);
1687 return rewriter.createarith::ExtSIOp(loc, inferType(e, v0), v0);
1689 return rewriter.createarith::ExtUIOp(loc, inferType(e, v0), v0);
1691 return rewriter.createarith::IndexCastOp(loc, inferType(e, v0), v0);
1693 return rewriter.createarith::TruncIOp(loc, inferType(e, v0), v0);
1695 auto type = cast(v0.getType());
1696 auto eltType = cast(type.getElementType());
1697 return rewriter.createcomplex::ImOp(loc, eltType, v0);
1698 }
1700 auto type = cast(v0.getType());
1701 auto eltType = cast(type.getElementType());
1702 return rewriter.createcomplex::ReOp(loc, eltType, v0);
1703 }
1705 return rewriter.createarith::BitcastOp(loc, inferType(e, v0), v0);
1706
1708 return rewriter.createarith::MulFOp(loc, v0, v1);
1710 return rewriter.createcomplex::MulOp(loc, v0, v1);
1712 return rewriter.createarith::MulIOp(loc, v0, v1);
1714 return rewriter.createarith::DivFOp(loc, v0, v1);
1716 return rewriter.createcomplex::DivOp(loc, v0, v1);
1718 return rewriter.createarith::DivSIOp(loc, v0, v1);
1720 return rewriter.createarith::DivUIOp(loc, v0, v1);
1722 return rewriter.createarith::AddFOp(loc, v0, v1);
1724 return rewriter.createcomplex::AddOp(loc, v0, v1);
1726 return rewriter.createarith::AddIOp(loc, v0, v1);
1728 return rewriter.createarith::SubFOp(loc, v0, v1);
1730 return rewriter.createcomplex::SubOp(loc, v0, v1);
1732 return rewriter.createarith::SubIOp(loc, v0, v1);
1734 return rewriter.createarith::AndIOp(loc, v0, v1);
1736 return rewriter.createarith::OrIOp(loc, v0, v1);
1738 return rewriter.createarith::XOrIOp(loc, v0, v1);
1740 return rewriter.createarith::ShRSIOp(loc, v0, v1);
1742 return rewriter.createarith::ShRUIOp(loc, v0, v1);
1744 return rewriter.createarith::ShLIOp(loc, v0, v1);
1746 auto predicate = llvm::castarith::CmpIPredicateAttr(expr.attr);
1747 return rewriter.createarith::CmpIOp(loc, predicate, v0, v1);
1748 }
1750 auto predicate = llvm::castarith::CmpFPredicateAttr(expr.attr);
1751 return rewriter.createarith::CmpFOp(loc, predicate, v0, v1);
1752 }
1755 {v0});
1760 cast<sparse_tensor::SelectOp>(expr.op).getRegion(),
1761 {v0});
1765 ReduceOp redOp = cast(expr.op);
1766 return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
1767 }
1774 return rewriter.clone(*actualOp, mapping)->getResult(0);
1775 }
1776 }
1777 llvm_unreachable("unexpected expression kind in build");
1778 }
1779
1780 }
1781 }
union mlir::linalg::@1204::ArityGroupAndKind::Kind kind
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
TypedAttr getZeroAttr(Type type)
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Block * getBlock() const
Returns the current block of the builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumOperands()
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void cloneInto(Region *dest, IRMapping &mapper)
Clone the internal blocks from this region into dest.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Specialization of arith.constant op that returns a floating point value.
Specialization of arith.constant op that returns an integer of index type.
Specialization of arith.constant op that returns an integer value.
LatPointId conjLat(ExprId e, LatPointId p0, LatPointId p1, Operation *op=nullptr)
Computes a single conjunction of two lattice points by taking the "union" of LoopId (effectively cons...
LatSetId disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)
Disjunctive merge of two lattice sets: (s0 /\_op s1, s0, s1).
bool isSingleCondition(TensorId t, ExprId e) const
Returns true if the given tensor iterates only in the given tensor expression.
bool hasSparseIdxReduction(const BitVector &bits) const
Returns true if bits contains a dependent index reduction condition on sparse levels.
bool expContainsTensor(ExprId e, TensorId t) const
Returns true if the expression contains the tensor as an operand.
LatSetId mapBinWithSynZeroSet(ExprId e, LatSetId s, bool lhsZero)
Maps the binary operator to the same operation but with one of its operands set to zero,...
bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const
Checks whether the TensorLoopId represents a sparse tensor level that contains a non-trivial index expression...
void dumpBits(const BitVector &bits) const
LatSetId addSet()
Constructs a new (initially empty) set, and returns its identifier.
BitVector simplifyCond(LatSetId s, LatPointId p)
Simplifies the conditions in a conjunction of a given lattice point within the given set using just t...
bool hasNegateOnOut(ExprId e) const
Returns true if the expression contains a negation on output tensor.
bool isLvlWithNonTrivialIdxExp(TensorLoopId b) const
Checks whether the TensorLoopId represents a tensor level that contains a non-trivial index expression.
LatSetId disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1)
Disjunctive merge of two lattice sets that also sets one of the operands to zero: (s0 /\_op s1 (e0 op e1...
void dumpSet(LatSetId s) const
void dumpLat(LatPointId p) const
LatSetId combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig, bool includeLeft, TensorExp::Kind ltrans, Operation *opleft, bool includeRight, TensorExp::Kind rtrans, Operation *opright)
Disjunctive merge of two lattice sets with custom handling of the overlap, left, and right regions.
ExprId addTensorExp(TensorId t)
Constructs a new tensor expression, and returns its identifier.
LatSetId buildLattices(ExprId e, LoopId i)
Builds the iteration lattices in a bottom-up traversal given the remaining tensor (sub)expression and...
LatSetId conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)
Conjunctive merge of two lattice sets: (s0 /\_op s1).
ExprId addExp(TensorExp::Kind k, ExprId e0, ExprId e1=detail::kInvalidId, Operation *op=nullptr, Attribute attr=nullptr)
Constructs a new unary or binary expression, and returns its identifier.
ExprId addSynZeroExp()
Constructs a new synthetic zero expression.
constexpr LoopId makeLoopId(unsigned i) const
Safely converts the argument to a loop identifier.
std::optional< ExprId > buildTensorExpFromLinalg(linalg::GenericOp op)
Builds a tensor expression from the given Linalg operation.
LatSetId mapSet(TensorExp::Kind kind, LatSetId s, Value v=Value(), Operation *op=nullptr, Attribute attr=nullptr)
Maps the unary operator over the lattice set of the operand, i.e.
ArrayRef< LatPointId > set(LatSetId s) const
LatSetId optimizeSet(LatSetId s)
Optimizes the iteration lattice points in the given set.
constexpr TensorId tensor(TensorLoopId b) const
Gets the tensor-identifier of the TensorLoopId.
void dumpExp(ExprId e) const
Print methods (for debugging).
Merger(unsigned numInputOutputTensors, unsigned numLoops, unsigned maxLvlRank)
Constructs a merger for the given number of tensors and loops.
bool hasAnySparse(const BitVector &bits) const
Returns true if any TensorLoopId in the bitvector corresponds to sparse level-type.
LatPointId addLat(TensorId t, LoopId i, ExprId e)
Constructs a new iteration lattice point, and returns its identifier.
ExprId addLoopVarExp(LoopId i)
Constructs a new loop-variable expression, and returns its identifier.
bool latGT(LatPointId p0, LatPointId p1) const
Returns true if p0 > p1.
const TensorExp & exp(ExprId e) const
Convenience getters to immediately access the stored nodes.
constexpr LoopId loop(TensorLoopId b) const
Gets the loop-identifier of the TensorLoopId.
const LatPoint & lat(LatPointId p) const
bool onlyDenseDiff(LatPointId p0, LatPointId p1) const
Returns true if p0 and p1 only differ in dense.
ExprId addInvariantExp(Value v)
Constructs a new invariant expression, and returns its identifier.
constexpr TensorId makeTensorId(unsigned t) const
Safely converts the argument to a tensor identifier.
LevelType getLvlType(TensorId t, LoopId i) const
Gets the level-type of the tth tensor on ith loop.
Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, Value v1) const
Rebuilds SSA format from a tensor expression.
constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const
Safely converts the arguments to a pair of (tensor,loop) identifiers.
bool expIsTensor(ExprId e, TensorId t) const
Returns true if the expression is (kTensor t).
@ Type
An inlay hint that is for a type annotation.
static constexpr unsigned kInvalidId
A constant serving as the canonically invalid identifier, regardless of the identifier type.
LevelFormat
This enum defines all supported storage formats without the level properties.
static bool isCertainZero(Value val)
Only returns true if we are certain this is a zero.
static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v)
Ensures that the sparsifier can generate code for expression.
unsigned LatSetId
LatSet identifiers.
static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region, ValueRange vals)
std::string toMLIRString(LevelType lt)
std::pair< Level, LevelType > LvlLTPair
A pair of level and its corresponding LevelType of a tensor.
unsigned TensorLoopId
A compressed representation of std::pair<TensorId, LoopId>.
static Value buildUnaryPresent(RewriterBase &rewriter, Location loc, Operation *op, Value v0)
uint64_t Level
The type of level identifiers and level-ranks.
static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc, Operation *op, Value v0, Value v1)
static bool isAdmissibleBranch(Operation *op, Region &region)
Ensures that the sparsifier can generate code for branch.
unsigned LoopId
Loop identifiers.
static const char * kindToOpSymbol(TensorExp::Kind kind)
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
static bool isGreater(TensorExp::Kind kind, Attribute attr)
unsigned ExprId
TensorExp identifiers.
static ExpArity getExpArity(TensorExp::Kind k)
static Value buildRelu(RewriterBase &rewriter, Location loc, Value v0, Attribute attr)
unsigned LatPointId
LatPoint identifiers.
std::pair< LoopId, unsigned > LoopCoeffPair
A pair of loop id and its coefficients.
unsigned TensorId
Tensor identifiers, chosen to be the BlockArgument::getArgNumber of the value passed to Merger::build...
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
BitVector bits
Conjunction of all TensorLoopIds involved in the tensor expression.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
LoopId loop
kLoopVar expressions simply have a loop identifier.
Value val
Direct link to IR for an invariant or the destination value (to infer destination type) of a cast ope...
Kind
Tensor expression kind.
Children children
All other expressions hold the ExprIds of their children.
Attribute attr
An optional attribute that is required to determine the semantics of the operations.
TensorId tensor
kTensor expressions simply have a tensor identifier.
Kind kind
Tensor expression kind.
TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *op, Attribute a)
The x parameter has different types depending on the value of the k parameter.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.