MLIR: lib/Dialect/Affine/IR/AffineOps.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

23 #include "llvm/ADT/STLExtras.h"

24 #include "llvm/ADT/ScopeExit.h"

25 #include "llvm/ADT/SmallBitVector.h"

26 #include "llvm/ADT/SmallVectorExtras.h"

27 #include "llvm/ADT/TypeSwitch.h"

28 #include "llvm/Support/Debug.h"

29 #include "llvm/Support/MathExtras.h"

30 #include

31 #include

32

33 using namespace mlir;

35

36 using llvm::divideCeilSigned;

37 using llvm::divideFloorSigned;

38 using llvm::mod;

39

40 #define DEBUG_TYPE "affine-ops"

41

42 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"

43

44

45

46

47

49 if (auto arg = llvm::dyn_cast(value))

50 return arg.getParentRegion() == region;

52 }

53

54

55

56

57

58

59 static bool

63

64

65

66

68 return true;

69

70

71

72

73 if (llvm::isa(value))

74 return legalityCheck(mapping.lookup(value), dest);

75

76

77

78

79

81 bool isDimLikeOp = isa(value.getDefiningOp());

83 isDimLikeOp;

84 }

85

86

87

88 static bool

92 return llvm::all_of(values, [&](Value v) {

94 });

95 }

96

97

98

99 template

102 static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,

103 AffineWriteOpInterface>::value,

104 "only ops with affine read/write interface are supported");

105

106 AffineMap map = op.getAffineMap();

109 op.getMapOperands().take_back(map.getNumSymbols());

111 dimOperands, src, dest, mapping,

113 return false;

115 symbolOperands, src, dest, mapping,

117 return false;

118 return true;

119 }

120

121

122

123

124

125 template <>

129

132 op.getMapOperands(), src, dest, mapping,

134

135

137 op.getMapOperands(), src, dest, mapping,

139 }

140

141

142

143

144

145 namespace {

146

147

150

151

152

153

154

155

156

157

158

160 IRMapping &valueMapping) const final {

161

162

164 if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))

165 return false;

166

167

168

169 if (!llvm::hasSingleElement(*src))

170 return false;

171

172

173

176

177 if (auto iface = dyn_cast(op)) {

178 if (iface.hasNoEffect())

179 continue;

180 }

181

182

183

184 bool remainsValid =

186 .Case<AffineApplyOp, AffineReadOpInterface,

187 AffineWriteOpInterface>([&](auto op) {

189 })

191

192 return false;

193 });

194

195 if (!remainsValid)

196 return false;

197 }

198

199 return true;

200 }

201

202

203

205 IRMapping &valueMapping) const final {

206

207

208

209

212 isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);

213 }

214

215

216 bool shouldAnalyzeRecursively(Operation *op) const final { return true; }

217 };

218 }

219

220

221

222

223

224 void AffineDialect::initialize() {

226 #define GET_OP_LIST

227 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"

228 >();

229 addInterfaces();

230 declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,

231 AffineMinOp>();

232 }

233

234

235

239 if (auto poison = dyn_castub::PoisonAttr(value))

240 return builder.createub::PoisonOp(loc, type, poison);

241 return arith::ConstantOp::materialize(builder, value, type, loc);

242 }

243

244

245

246

247

249 if (auto arg = llvm::dyn_cast(value)) {

250

251

252

255 }

256

259 }

260

261

262

264 auto *curOp = op;

265 while (auto *parentOp = curOp->getParentOp()) {

268 curOp = parentOp;

269 }

270 return nullptr;

271 }

272

275 while (auto *parentOp = curOp->getParentOp()) {

276 if (!isa<AffineForOp, AffineIfOp, AffineParallelOp>(parentOp))

278 curOp = parentOp;

279 }

280 return nullptr;

281 }

282

283

284

285

286

287

289

291 return false;

292

295

296

297

298

300 return true;

301 auto *parentOp = llvm::cast(value).getOwner()->getParentOp();

303 }

304

305

306

307

308

309

310

311

313

315 return false;

316

317

319 return true;

320

322 if (!op) {

323

324

326 }

327

328

329 if (auto applyOp = dyn_cast(op))

330 return applyOp.isValidDim(region);

331

332

333 if (isa<AffineDelinearizeIndexOp, AffineLinearizeIndexOp>(op))

334 return llvm::all_of(op->getOperands(),

335 [&](Value arg) { return ::isValidDim(arg, region); });

336

337

338 if (auto dimOp = dyn_cast(op))

340 return false;

341 }

342

343

344

345

346 template

349 MemRefType memRefType = memrefDefOp.getType();

350

351

352 if (index >= memRefType.getRank()) {

353 return false;

354 }

355

356

357 if (!memRefType.isDynamicDim(index))

358 return true;

359

360 unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);

361 return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),

362 region);

363 }

364

365

367

369 return true;

370

371

372

373 if (llvm::isa(dimOp.getShapedValue()))

374 return false;

375

376

377

378 std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());

379

380

381 if (!index.has_value())

382 return false;

383

384

385 Operation *op = dimOp.getShapedValue().getDefiningOp();

386 while (auto castOp = dyn_castmemref::CastOp(op)) {

387

388 if (isa(castOp.getSource().getType()))

389 return false;

390 op = castOp.getSource().getDefiningOp();

391 if (!op)

392 return false;

393 }

394

395 int64_t i = index.value();

397 .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(

399 .Default([](Operation *) { return false; });

400 }

401

402

403

404

405

406

407

408

409

411 if (!value)

412 return false;

413

414

416 return false;

417

418

420 return true;

421

424

425 return false;

426 }

427

428

429

430

431

432

433

434

435

436

437

438

439

441

443 return false;

444

445

447 return true;

448

450 if (!defOp) {

451

452

457 return false;

458 }

459

460

463 return true;

464

465

466 if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {

467 return affine::isValidSymbol(operand, region);

468 })) {

469 return true;

470 }

471

472

473 if (auto dimOp = dyn_cast(defOp))

475

476

481

482 return false;

483 }

484

485

486

487

490 }

491

492

497 printer << '(' << operands.take_front(numDims) << ')';

498 if (operands.size() > numDims)

499 printer << '[' << operands.drop_front(numDims) << ']';

500 }

501

502

507 return failure();

508

509 numDims = opInfos.size();

510

511

516 }

517

518

519

520

521

522

523 template

524 static LogicalResult

526 unsigned numDims) {

527 unsigned opIt = 0;

528 for (auto operand : operands) {

529 if (opIt++ < numDims) {

531 return op.emitOpError("operand cannot be used as a dimension id");

533 return op.emitOpError("operand cannot be used as a symbol");

534 }

535 }

536 return success();

537 }

538

539

540

541

542

544 return AffineValueMap(getAffineMap(), getOperands(), getResult());

545 }

546

550

551 AffineMapAttr mapAttr;

552 unsigned numDims;

556 return failure();

557 auto map = mapAttr.getValue();

558

559 if (map.getNumDims() != numDims ||

560 numDims + map.getNumSymbols() != result.operands.size()) {

562 "dimension or symbol index mismatch");

563 }

564

565 result.types.append(map.getNumResults(), indexTy);

566 return success();

567 }

568

570 p << " " << getMapAttr();

572 getAffineMap().getNumDims(), p);

574 }

575

577

579

580

582 return emitOpError(

583 "operand count and affine map dimension and symbol count must match");

584

585

587 return emitOpError("mapping must produce one value");

588

589

590

591

593 for (Value operand : getMapOperands().drop_front(affineMap.getNumDims())) {

595 return emitError("dimensional operand cannot be used as a symbol");

596 }

597

598 return success();

599 }

600

601

602

604 return llvm::all_of(getOperands(),

606 }

607

608

609

610

612 return llvm::all_of(getOperands(),

614 }

615

616

617

619 return llvm::all_of(getOperands(),

621 }

622

623

624

626 return llvm::all_of(getOperands(), [&](Value operand) {

628 });

629 }

630

631 OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {

632 auto map = getAffineMap();

633

634

635 auto expr = map.getResult(0);

636 if (auto dim = dyn_cast(expr))

637 return getOperand(dim.getPosition());

638 if (auto sym = dyn_cast(expr))

639 return getOperand(map.getNumDims() + sym.getPosition());

640

641

643 bool hasPoison = false;

644 auto foldResult =

645 map.constantFold(adaptor.getMapOperands(), result, &hasPoison);

646 if (hasPoison)

648 if (failed(foldResult))

649 return {};

650 return result[0];

651 }

652

653

654

656

658

659

660

661

662

663

664

665 auto dimExpr = dyn_cast(e);

666

667 if (!dimExpr)

668 return div;

669

670

671

672

673

674

675 Value operand = operands[dimExpr.getPosition()];

676 int64_t operandDivisor = 1;

677

678

680 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {

681 operandDivisor = forOp.getStepAsInt();

682 } else {

683 uint64_t lbLargestKnownDivisor =

684 forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();

685 operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());

686 }

687 }

688 return operandDivisor;

689 }

690

691

692

694 int64_t k) {

695 if (auto constExpr = dyn_cast(e)) {

696 int64_t constVal = constExpr.getValue();

697 return constVal >= 0 && constVal < k;

698 }

699 auto dimExpr = dyn_cast(e);

700 if (!dimExpr)

701 return false;

702 Value operand = operands[dimExpr.getPosition()];

703

704

706 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&

707 forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {

708 return true;

709 }

710 }

711

712

713

714

715 return false;

716 }

717

718

719

720

723 auto bin = dyn_cast(e);

725 return false;

726

731 quotientTimesDiv = llhs;

732 rem = rlhs;

733 return true;

734 }

737 quotientTimesDiv = rlhs;

738 rem = llhs;

739 return true;

740 }

741 return false;

742 }

743

744

747 if (forOp && forOp.hasConstantLowerBound())

748 return forOp.getConstantLowerBound();

749 return std::nullopt;

750 }

751

752

755 if (!forOp || !forOp.hasConstantUpperBound())

756 return std::nullopt;

757

758

759

760 if (forOp.hasConstantLowerBound()) {

761 return forOp.getConstantUpperBound() - 1 -

762 (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %

763 forOp.getStepAsInt();

764 }

765 return forOp.getConstantUpperBound() - 1;

766 }

767

768

769

770

772 unsigned numSymbols,

774

776 constLowerBounds.reserve(operands.size());

777 constUpperBounds.reserve(operands.size());

778 for (Value operand : operands) {

779 constLowerBounds.push_back(getLowerBound(operand));

780 constUpperBounds.push_back(getUpperBound(operand));

781 }

782

783 if (auto constExpr = dyn_cast(expr))

784 return constExpr.getValue();

785

787 constUpperBounds,

788 true);

789 }

790

791

792

793

795 unsigned numSymbols,

797

799 constLowerBounds.reserve(operands.size());

800 constUpperBounds.reserve(operands.size());

801 for (Value operand : operands) {

802 constLowerBounds.push_back(getLowerBound(operand));

803 constUpperBounds.push_back(getUpperBound(operand));

804 }

805

806 std::optional<int64_t> lowerBound;

807 if (auto constExpr = dyn_cast(expr)) {

808 lowerBound = constExpr.getValue();

809 } else {

811 constLowerBounds, constUpperBounds,

812 false);

813 }

814 return lowerBound;

815 }

816

817

819 unsigned numSymbols,

821

822 auto binExpr = dyn_cast(expr);

823 if (!binExpr)

824 return;

825

826

832

833 binExpr = dyn_cast(expr);

837 return;

838 }

839

840

841 lhs = binExpr.getLHS();

842 rhs = binExpr.getRHS();

843 auto rhsConst = dyn_cast(rhs);

844 if (!rhsConst)

845 return;

846

847 int64_t rhsConstVal = rhsConst.getValue();

848

849 if (rhsConstVal <= 0)

850 return;

851

852

854 std::optional<int64_t> lhsLbConst =

855 getLowerBound(lhs, numDims, numSymbols, operands);

856 std::optional<int64_t> lhsUbConst =

857 getUpperBound(lhs, numDims, numSymbols, operands);

858 if (lhsLbConst && lhsUbConst) {

859 int64_t lhsLbConstVal = *lhsLbConst;

860 int64_t lhsUbConstVal = *lhsUbConst;

861

862

864 divideFloorSigned(lhsLbConstVal, rhsConstVal) ==

865 divideFloorSigned(lhsUbConstVal, rhsConstVal)) {

867 divideFloorSigned(lhsLbConstVal, rhsConstVal), context);

868 return;

869 }

870

871

873 divideCeilSigned(lhsLbConstVal, rhsConstVal) ==

874 divideCeilSigned(lhsUbConstVal, rhsConstVal)) {

876 context);

877 return;

878 }

879

881 lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {

882 expr = lhs;

883 return;

884 }

885 }

886

887

888

889

890

892 int64_t divisor;

893 if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {

894 if (rhsConstVal % divisor == 0 &&

896 expr = quotientTimesDiv.floorDiv(rhsConst);

897 } else if (divisor % rhsConstVal == 0 &&

899 expr = rem % rhsConst;

900 }

901 return;

902 }

903

904

905

906

907

913 }

914 }

915

916

917

918

919

920

923 bool isMax) {

924

925 if (operands.empty())

926 return;

927

928

929

931 constLowerBounds.reserve(operands.size());

932 constUpperBounds.reserve(operands.size());

933 for (Value operand : operands) {

934 constLowerBounds.push_back(getLowerBound(operand));

935 constUpperBounds.push_back(getUpperBound(operand));

936 }

937

938

939

940

941

942

947 if (auto constExpr = dyn_cast(e)) {

948 lowerBounds.push_back(constExpr.getValue());

949 upperBounds.push_back(constExpr.getValue());

950 } else {

951 lowerBounds.push_back(

953 constLowerBounds, constUpperBounds,

954 false));

955 upperBounds.push_back(

957 constLowerBounds, constUpperBounds,

958 true));

959 }

960 }

961

962

966 unsigned i = exprEn.index();

967

968 if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])

970

971

972 if (isMax) {

973 if (!upperBounds[i]) {

974 irredundantExprs.push_back(e);

975 continue;

976 }

977

978

979 if (!llvm::any_of(llvm::enumerate(lowerBounds), [&](const auto &en) {

980 auto otherLowerBound = en.value();

981 unsigned pos = en.index();

982 if (pos == i || !otherLowerBound)

983 return false;

984 if (*otherLowerBound > *upperBounds[i])

985 return true;

986 if (*otherLowerBound < *upperBounds[i])

987 return false;

988

989

990

991 if (upperBounds[pos] && lowerBounds[i] &&

992 lowerBounds[i] == upperBounds[i] &&

993 otherLowerBound == *upperBounds[pos] && i < pos)

994 return false;

995 return true;

996 }))

997 irredundantExprs.push_back(e);

998 } else {

999 if (!lowerBounds[i]) {

1000 irredundantExprs.push_back(e);

1001 continue;

1002 }

1003

1004 if (!llvm::any_of(llvm::enumerate(upperBounds), [&](const auto &en) {

1005 auto otherUpperBound = en.value();

1006 unsigned pos = en.index();

1007 if (pos == i || !otherUpperBound)

1008 return false;

1009 if (*otherUpperBound < *lowerBounds[i])

1010 return true;

1011 if (*otherUpperBound > *lowerBounds[i])

1012 return false;

1013 if (lowerBounds[pos] && upperBounds[i] &&

1014 lowerBounds[i] == upperBounds[i] &&

1015 otherUpperBound == lowerBounds[pos] && i < pos)

1016 return false;

1017 return true;

1018 }))

1019 irredundantExprs.push_back(e);

1020 }

1021 }

1022

1023

1026 }

1027

1028

1029

1030

1031 static void LLVM_ATTRIBUTE_UNUSED

1033 assert(map.getNumInputs() == operands.size() && "invalid operands for map");

1038 operands);

1039 newResults.push_back(expr);

1040 }

1043 }

1044

1045

1046

1047

1048

1049

1050

1051

1052

1053

1054

1056 unsigned dimOrSymbolPosition,

1060 bool isDimReplacement = (dimOrSymbolPosition < dims.size());

1061 unsigned pos = isDimReplacement ? dimOrSymbolPosition

1062 : dimOrSymbolPosition - dims.size();

1063 Value &v = isDimReplacement ? dims[pos] : syms[pos];

1064 if (!v)

1065 return failure();

1066

1067 auto affineApply = v.getDefiningOp();

1068 if (!affineApply)

1069 return failure();

1070

1071

1072

1073 v = nullptr;

1074

1075

1076 AffineMap composeMap = affineApply.getAffineMap();

1077 assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");

1078 SmallVector composeOperands(affineApply.getMapOperands().begin(),

1079 affineApply.getMapOperands().end());

1080

1081

1091

1092

1093 dims.append(composeDims.begin(), composeDims.end());

1094 syms.append(composeSyms.begin(), composeSyms.end());

1095 *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());

1096

1097 return success();

1098 }

1099

1100

1101

1102

1108 return;

1109 }

1110

1113 operands->begin() + map->getNumDims());

1115 operands->end());

1116

1117

1118

1119

1120

1121

1122 while (true) {

1124 for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)

1126 break;

1128 break;

1129 }

1130

1131

1132 operands->clear();

1133

1134

1135

1136 unsigned nDims = 0, nSyms = 0;

1138 dimReplacements.reserve(dims.size());

1139 symReplacements.reserve(syms.size());

1140 for (auto *container : {&dims, &syms}) {

1141 bool isDim = (container == &dims);

1142 auto &repls = isDim ? dimReplacements : symReplacements;

1144 Value v = en.value();

1145 if (!v) {

1148 "map is function of unexpected expr@pos");

1150 continue;

1151 }

1154 operands->push_back(v);

1155 }

1156 }

1158 nSyms);

1159

1160

1163 }

1164

1167 while (llvm::any_of(*operands, [](Value v) {

1168 return isa_and_nonnull(v.getDefiningOp());

1169 })) {

1171 }

1172 }

1173

1174 AffineApplyOp

1180 assert(map);

1181 return b.create(loc, map, valueOperands);

1182 }

1183

1184 AffineApplyOp

1188 b, loc,

1190 .front(),

1191 operands);

1192 }

1193

1194

1195

1198

1199

1200

1203 for (unsigned i : llvm::seq(0, map.getNumResults())) {

1204 SmallVector submapOperands(operands.begin(), operands.end());

1208 unsigned numNewDims = submap.getNumDims();

1210 llvm::append_range(dims,

1211 ArrayRef(submapOperands).take_front(numNewDims));

1212 llvm::append_range(symbols,

1213 ArrayRef(submapOperands).drop_front(numNewDims));

1214 exprs.push_back(submap.getResult(0));

1215 }

1216

1217

1218

1219 operands = llvm::to_vector(llvm::concat(dims, symbols));

1222 }

1223

1228 assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");

1229

1230

1231

1232

1233

1236

1237

1238 AffineApplyOp applyOp =

1240

1241

1243 for (unsigned i = 0, e = constOperands.size(); i != e; ++i)

1245

1246

1248 if (failed(applyOp->fold(constOperands, foldResults)) ||

1249 foldResults.empty()) {

1251 listener->notifyOperationInserted(applyOp, {});

1252 return applyOp.getResult();

1253 }

1254

1255 applyOp->erase();

1256 return llvm::getSingleElement(foldResults);

1257 }

1258

1264 b, loc,

1266 .front(),

1267 operands);

1268 }

1269

1274 return llvm::map_to_vector(llvm::seq(0, map.getNumResults()),

1275 [&](unsigned i) {

1276 return makeComposedFoldedAffineApply(

1277 b, loc, map.getSubMap({i}), operands);

1278 });

1279 }

1280

1281 template

1288 }

1289

1290 AffineMinOp

1293 return makeComposedMinMax(b, loc, map, operands);

1294 }

1295

1296 template

1300

1301

1302

1303

1306

1307

1308 auto minMaxOp = makeComposedMinMax(newBuilder, loc, map, operands);

1309

1310

1312 for (unsigned i = 0, e = constOperands.size(); i != e; ++i)

1314

1315

1317 if (failed(minMaxOp->fold(constOperands, foldResults)) ||

1318 foldResults.empty()) {

1320 listener->notifyOperationInserted(minMaxOp, {});

1321 return minMaxOp.getResult();

1322 }

1323

1324 minMaxOp->erase();

1325 return llvm::getSingleElement(foldResults);

1326 }

1327

1332 return makeComposedFoldedMinMax(b, loc, map, operands);

1333 }

1334

1339 return makeComposedFoldedMinMax(b, loc, map, operands);

1340 }

1341

1342

1343

1344 template

1347 if (!mapOrSet || operands->empty())

1348 return;

1349

1350 assert(mapOrSet->getNumInputs() == operands->size() &&

1351 "map/set inputs must match number of operands");

1352

1353 auto *context = mapOrSet->getContext();

1355 resultOperands.reserve(operands->size());

1357 remappedSymbols.reserve(operands->size());

1358 unsigned nextDim = 0;

1359 unsigned nextSym = 0;

1360 unsigned oldNumSyms = mapOrSet->getNumSymbols();

1362 for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {

1363 if (i < mapOrSet->getNumDims()) {

1365

1367 remappedSymbols.push_back((*operands)[i]);

1368 } else {

1370 resultOperands.push_back((*operands)[i]);

1371 }

1372 } else {

1373 resultOperands.push_back((*operands)[i]);

1374 }

1375 }

1376

1377 resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());

1378 *operands = resultOperands;

1379 *mapOrSet = mapOrSet->replaceDimsAndSymbols(

1380 dimRemapping, {}, nextDim, oldNumSyms + nextSym);

1381

1382 assert(mapOrSet->getNumInputs() == operands->size() &&

1383 "map/set inputs must match number of operands");

1384 }

1385

1386

1387

1388

1389

1390

1391

1392 template

1395 if (!mapOrSet || operands.empty())

1396 return;

1397

1398 unsigned numOperands = operands.size();

1399

1400 assert(mapOrSet.getNumInputs() == numOperands &&

1401 "map/set inputs must match number of operands");

1402

1403 auto *context = mapOrSet.getContext();

1405 resultOperands.reserve(numOperands);

1407 remappedDims.reserve(numOperands);

1409 symOperands.reserve(mapOrSet.getNumSymbols());

1410 unsigned nextSym = 0;

1411 unsigned nextDim = 0;

1412 unsigned oldNumDims = mapOrSet.getNumDims();

1414 resultOperands.assign(operands.begin(), operands.begin() + oldNumDims);

1415 for (unsigned i = oldNumDims, e = mapOrSet.getNumInputs(); i != e; ++i) {

1417

1418 symRemapping[i - oldNumDims] =

1420 remappedDims.push_back(operands[i]);

1421 } else {

1423 symOperands.push_back(operands[i]);

1424 }

1425 }

1426

1427 append_range(resultOperands, remappedDims);

1428 append_range(resultOperands, symOperands);

1429 operands = resultOperands;

1430 mapOrSet = mapOrSet.replaceDimsAndSymbols(

1431 {}, symRemapping, oldNumDims + nextDim, nextSym);

1432

1433 assert(mapOrSet.getNumInputs() == operands.size() &&

1434 "map/set inputs must match number of operands");

1435 }

1436

1437

1438 template

1441 static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,

1442 "Argument must be either of AffineMap or IntegerSet type");

1443

1444 if (!mapOrSet || operands->empty())

1445 return;

1446

1447 assert(mapOrSet->getNumInputs() == operands->size() &&

1448 "map/set inputs must match number of operands");

1449

1450 canonicalizePromotedSymbols(mapOrSet, operands);

1451 legalizeDemotedDims(*mapOrSet, *operands);

1452

1453

1454 llvm::SmallBitVector usedDims(mapOrSet->getNumDims());

1455 llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());

1456 mapOrSet->walkExprs([&](AffineExpr expr) {

1457 if (auto dimExpr = dyn_cast(expr))

1458 usedDims[dimExpr.getPosition()] = true;

1459 else if (auto symExpr = dyn_cast(expr))

1460 usedSyms[symExpr.getPosition()] = true;

1461 });

1462

1463 auto *context = mapOrSet->getContext();

1464

1466 resultOperands.reserve(operands->size());

1467

1468 llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;

1470 unsigned nextDim = 0;

1471 for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {

1472 if (usedDims[i]) {

1473

1474 auto it = seenDims.find((*operands)[i]);

1475 if (it == seenDims.end()) {

1477 resultOperands.push_back((*operands)[i]);

1478 seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));

1479 } else {

1480 dimRemapping[i] = it->second;

1481 }

1482 }

1483 }

1484 llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;

1486 unsigned nextSym = 0;

1487 for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {

1488 if (!usedSyms[i])

1489 continue;

1490

1491

1492

1493 IntegerAttr operandCst;

1494 if (matchPattern((*operands)[i + mapOrSet->getNumDims()],

1496 symRemapping[i] =

1498 continue;

1499 }

1500

1501 auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);

1502 if (it == seenSymbols.end()) {

1504 resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);

1505 seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],

1506 symRemapping[i]));

1507 } else {

1508 symRemapping[i] = it->second;

1509 }

1510 }

1511 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,

1512 nextDim, nextSym);

1513 *operands = resultOperands;

1514 }

1515

1518 canonicalizeMapOrSetAndOperands(map, operands);

1519 }

1520

1523 canonicalizeMapOrSetAndOperands(set, operands);

1524 }

1525

1526 namespace {

1527

1528

1529

1530 template

1531 struct SimplifyAffineOp : public OpRewritePattern {

1533

1534

1535

1536 void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,

1538

1539 LogicalResult matchAndRewrite(AffineOpTy affineOp,

1541 static_assert(

1542 llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,

1543 AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,

1544 AffineVectorStoreOp, AffineVectorLoadOp>::value,

1545 "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "

1546 "expected");

1547 auto map = affineOp.getAffineMap();

1549 auto oldOperands = affineOp.getMapOperands();

1554 if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),

1555 resultOperands.begin()))

1556 return failure();

1557

1558 replaceAffineOp(rewriter, affineOp, map, resultOperands);

1559 return success();

1560 }

1561 };

1562

1563

1564

1565 template <>

1566 void SimplifyAffineOp::replaceAffineOp(

1569 rewriter.replaceOpWithNewOp(load, load.getMemRef(), map,

1570 mapOperands);

1571 }

1572 template <>

1573 void SimplifyAffineOp::replaceAffineOp(

1577 prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),

1578 prefetch.getLocalityHint(), prefetch.getIsDataCache());

1579 }

1580 template <>

1581 void SimplifyAffineOp::replaceAffineOp(

1585 store, store.getValueToStore(), store.getMemRef(), map, mapOperands);

1586 }

1587 template <>

1588 void SimplifyAffineOp::replaceAffineOp(

1592 vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,

1593 mapOperands);

1594 }

1595 template <>

1596 void SimplifyAffineOp::replaceAffineOp(

1600 vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,

1601 mapOperands);

1602 }

1603

1604

1605 template

1606 void SimplifyAffineOp::replaceAffineOp(

1610 }

1611 }

1612

1613 void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,

1615 results.add<SimplifyAffineOp>(context);

1616 }

1617

1618

1619

1620

1621

1622

1629 Value stride, Value elementsPerStride) {

1640 if (stride) {

1641 result.addOperands({stride, elementsPerStride});

1642 }

1643 }

1644

1646 p << " " << getSrcMemRef() << '[';

1648 p << "], " << getDstMemRef() << '[';

1650 p << "], " << getTagMemRef() << '[';

1653 if (isStrided()) {

1655 p << ", " << getNumElementsPerStride();

1656 }

1657 p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "

1658 << getTagMemRefType();

1659 }

1660

1661

1662

1663

1664

1665

1666

1670 AffineMapAttr srcMapAttr;

1673 AffineMapAttr dstMapAttr;

1676 AffineMapAttr tagMapAttr;

1680

1683

1684

1685

1686

1687

1688

1691 getSrcMapAttrStrName(),

1695 getDstMapAttrStrName(),

1699 getTagMapAttrStrName(),

1702 return failure();

1703

1704

1706 return failure();

1707

1708 if (!strideInfo.empty() && strideInfo.size() != 2) {

1710 "expected two stride related operands");

1711 }

1712 bool isStrided = strideInfo.size() == 2;

1713

1715 return failure();

1716

1717 if (types.size() != 3)

1719

1727 return failure();

1728

1729 if (isStrided) {

1731 return failure();

1732 }

1733

1734

1735 if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||

1736 dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||

1737 tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())

1739 "memref operand count not equal to map.numInputs");

1740 return success();

1741 }

1742

1743 LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {

1744 if (!llvm::isa(getOperand(getSrcMemRefOperandIndex()).getType()))

1745 return emitOpError("expected DMA source to be of memref type");

1746 if (!llvm::isa(getOperand(getDstMemRefOperandIndex()).getType()))

1747 return emitOpError("expected DMA destination to be of memref type");

1748 if (!llvm::isa(getOperand(getTagMemRefOperandIndex()).getType()))

1749 return emitOpError("expected DMA tag to be of memref type");

1750

1751 unsigned numInputsAllMaps = getSrcMap().getNumInputs() +

1752 getDstMap().getNumInputs() +

1753 getTagMap().getNumInputs();

1754 if (getNumOperands() != numInputsAllMaps + 3 + 1 &&

1755 getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {

1756 return emitOpError("incorrect number of operands");

1757 }

1758

1760 for (auto idx : getSrcIndices()) {

1761 if (!idx.getType().isIndex())

1762 return emitOpError("src index to dma_start must have 'index' type");

1764 return emitOpError(

1765 "src index must be a valid dimension or symbol identifier");

1766 }

1767 for (auto idx : getDstIndices()) {

1768 if (!idx.getType().isIndex())

1769 return emitOpError("dst index to dma_start must have 'index' type");

1771 return emitOpError(

1772 "dst index must be a valid dimension or symbol identifier");

1773 }

1774 for (auto idx : getTagIndices()) {

1775 if (!idx.getType().isIndex())

1776 return emitOpError("tag index to dma_start must have 'index' type");

1778 return emitOpError(

1779 "tag index must be a valid dimension or symbol identifier");

1780 }

1781 return success();

1782 }

1783

1786

1788 }

1789

1790 void AffineDmaStartOp::getEffects(

1792 &effects) {

1799 }

1800

1801

1802

1803

1804

1805

1813 }

1814

1816 p << " " << getTagMemRef() << '[';

1819 p << "], ";

1821 p << " : " << getTagMemRef().getType();

1822 }

1823

1824

1825

1826

1827

1828

1832 AffineMapAttr tagMapAttr;

1837

1838

1841 getTagMapAttrStrName(),

1848 return failure();

1849

1850 if (!llvm::isa(type))

1852 "expected tag to be of memref type");

1853

1854 if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())

1856 "tag memref operand count != to map.numInputs");

1857 return success();

1858 }

1859

1860 LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {

1861 if (!llvm::isa(getOperand(0).getType()))

1862 return emitOpError("expected DMA tag to be of memref type");

1864 for (auto idx : getTagIndices()) {

1865 if (!idx.getType().isIndex())

1866 return emitOpError("index to dma_wait must have 'index' type");

1868 return emitOpError(

1869 "index must be a valid dimension or symbol identifier");

1870 }

1871 return success();

1872 }

1873

1876

1878 }

1879

1880 void AffineDmaWaitOp::getEffects(

1882 &effects) {

1885 }

1886

1887

1888

1889

1890

1891

1892

1896 ValueRange iterArgs, BodyBuilderFn bodyBuilder) {

1897 assert(((!lbMap && lbOperands.empty()) ||

1898 lbOperands.size() == lbMap.getNumInputs()) &&

1899 "lower bound operand count does not match the affine map");

1900 assert(((!ubMap && ubOperands.empty()) ||

1901 ubOperands.size() == ubMap.getNumInputs()) &&

1902 "upper bound operand count does not match the affine map");

1903 assert(step > 0 && "step has to be a positive integer constant");

1904

1906

1907

1909 getOperandSegmentSizeAttr(),

1911 static_cast<int32_t>(ubOperands.size()),

1912 static_cast<int32_t>(iterArgs.size())}));

1913

1914 for (Value val : iterArgs)

1915 result.addTypes(val.getType());

1916

1917

1920

1921

1925

1926

1930

1932

1933

1936 Value inductionVar =

1938 for (Value val : iterArgs)

1939 bodyBlock->addArgument(val.getType(), val.getLoc());

1940

1941

1942

1943

1944 if (iterArgs.empty() && !bodyBuilder) {

1945 ensureTerminator(*bodyRegion, builder, result.location);

1946 } else if (bodyBuilder) {

1949 bodyBuilder(builder, result.location, inductionVar,

1951 }

1952 }

1953

1955 int64_t ub, int64_t step, ValueRange iterArgs,

1956 BodyBuilderFn bodyBuilder) {

1959 return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,

1960 bodyBuilder);

1961 }

1962

1963 LogicalResult AffineForOp::verifyRegions() {

1964

1965

1966 auto *body = getBody();

1967 if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())

1968 return emitOpError("expected body to have a single index argument for the "

1969 "induction variable");

1970

1971

1972

1973 if (getLowerBoundMap().getNumInputs() > 0)

1975 getLowerBoundMap().getNumDims())))

1976 return failure();

1977

1978 if (getUpperBoundMap().getNumInputs() > 0)

1980 getUpperBoundMap().getNumDims())))

1981 return failure();

1982 if (getLowerBoundMap().getNumResults() < 1)

1983 return emitOpError("expected lower bound map to have at least one result");

1984 if (getUpperBoundMap().getNumResults() < 1)

1985 return emitOpError("expected upper bound map to have at least one result");

1986

1987 unsigned opNumResults = getNumResults();

1988 if (opNumResults == 0)

1989 return success();

1990

1991

1992

1993

1994 if (getNumIterOperands() != opNumResults)

1995 return emitOpError(

1996 "mismatch between the number of loop-carried values and results");

1997 if (getNumRegionIterArgs() != opNumResults)

1998 return emitOpError(

1999 "mismatch between the number of basic block args and results");

2000

2001 return success();

2002 }

2003

2004

2007

2008

2009 bool failedToParsedMinMax =

2011

2013 auto boundAttrStrName =

2014 isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)

2015 : AffineForOp::getUpperBoundMapAttrName(result.name);

2016

2017

2020 return failure();

2021

2022 if (!boundOpInfos.empty()) {

2023

2024 if (boundOpInfos.size() > 1)

2026 "expected only one loop bound operand");

2027

2028

2029

2032 return failure();

2033

2034

2035

2036

2039 return success();

2040 }

2041

2042

2044

2048 return failure();

2049

2050

2051 if (auto affineMapAttr = llvm::dyn_cast(boundAttr)) {

2052 unsigned currentNumOperands = result.operands.size();

2053 unsigned numDims;

2055 return failure();

2056

2057 auto map = affineMapAttr.getValue();

2061 "dim operand count and affine map dim count must match");

2062

2063 unsigned numDimAndSymbolOperands =

2064 result.operands.size() - currentNumOperands;

2065 if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)

2068 "symbol operand count and affine map symbol count must match");

2069

2070

2071

2072 if (map.getNumResults() > 1 && failedToParsedMinMax) {

2073 if (isLower) {

2074 return p.emitError(attrLoc, "lower loop bound affine map with "

2075 "multiple results requires 'max' prefix");

2076 }

2077 return p.emitError(attrLoc, "upper loop bound affine map with multiple "

2078 "results requires 'min' prefix");

2079 }

2080 return success();

2081 }

2082

2083

2084 if (auto integerAttr = llvm::dyn_cast(boundAttr)) {

2087 boundAttrStrName,

2089 return success();

2090 }

2091

2094 "expected valid affine map representation for loop bounds");

2095 }

2096

2098 auto &builder = parser.getBuilder();

2101

2103 return failure();

2104

2105

2106 int64_t numOperands = result.operands.size();

2107 if (parseBound(true, result, parser))

2108 return failure();

2109 int64_t numLbOperands = result.operands.size() - numOperands;

2110 if (parser.parseKeyword("to", " between bounds"))

2111 return failure();

2112 numOperands = result.operands.size();

2113 if (parseBound(false, result, parser))

2114 return failure();

2115 int64_t numUbOperands = result.operands.size() - numOperands;

2116

2117

2120 getStepAttrName(result.name),

2122 } else {

2124 IntegerAttr stepAttr;

2126 getStepAttrName(result.name).data(),

2128 return failure();

2129

2130 if (stepAttr.getValue().isNegative())

2132 stepLoc,

2133 "expected step to be representable as a positive signed integer");

2134 }

2135

2136

2139

2140

2141 regionArgs.push_back(inductionVariable);

2142

2144

2147 return failure();

2148

2149 for (auto argOperandType :

2150 llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {

2151 Type type = std::get<2>(argOperandType);

2152 std::get<0>(argOperandType).type = type;

2153 if (parser.resolveOperand(std::get<1>(argOperandType), type,

2155 return failure();

2156 }

2157 }

2158

2160 getOperandSegmentSizeAttr(),

2162 static_cast<int32_t>(numUbOperands),

2163 static_cast<int32_t>(operands.size())}));

2164

2165

2167 if (regionArgs.size() != result.types.size() + 1)

2170 "mismatch between the number of loop-carried values and results");

2171 if (parser.parseRegion(*body, regionArgs))

2172 return failure();

2173

2174 AffineForOp::ensureTerminator(*body, builder, result.location);

2175

2176

2178 }

2179

2183 AffineMap map = boundMap.getValue();

2184

2185

2186

2187

2188

2189

2190

2193

2194

2196 if (auto constExpr = dyn_cast(expr)) {

2197 p << constExpr.getValue();

2198 return;

2199 }

2200 }

2201

2202

2203

2205 if (isa(expr)) {

2207 return;

2208 }

2209 }

2210 } else {

2211

2212 p << prefix << ' ';

2213 }

2214

2215

2216 p << boundMap;

2219 }

2220

2221 unsigned AffineForOp::getNumIterOperands() {

2222 AffineMap lbMap = getLowerBoundMapAttr().getValue();

2223 AffineMap ubMap = getUpperBoundMapAttr().getValue();

2224

2226 }

2227

2228 std::optional<MutableArrayRef>

2229 AffineForOp::getYieldedValuesMutable() {

2230 return cast(getBody()->getTerminator()).getOperandsMutable();

2231 }

2232

2234 p << ' ';

2236 true);

2237 p << " = ";

2239 p << " to ";

2241

2242 if (getStepAsInt() != 1)

2243 p << " step " << getStepAsInt();

2244

2245 bool printBlockTerminators = false;

2246 if (getNumIterOperands() > 0) {

2247 p << " iter_args(";

2248 auto regionArgs = getRegionIterArgs();

2249 auto operands = getInits();

2250

2251 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {

2252 p << std::get<0>(it) << " = " << std::get<1>(it);

2253 });

2254 p << ") -> (" << getResultTypes() << ")";

2255 printBlockTerminators = true;

2256 }

2257

2258 p << ' ';

2259 p.printRegion(getRegion(), false,

2260 printBlockTerminators);

2262 (*this)->getAttrs(),

2263 {getLowerBoundMapAttrName(getOperation()->getName()),

2264 getUpperBoundMapAttrName(getOperation()->getName()),

2265 getStepAttrName(getOperation()->getName()),

2266 getOperandSegmentSizeAttr()});

2267 }

2268

2269

2271 auto foldLowerOrUpperBound = [&forOp](bool lower) {

2272

2273

2275 auto boundOperands =

2276 lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();

2277 for (auto operand : boundOperands) {

2280 operandConstants.push_back(operandCst);

2281 }

2282

2284 lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();

2286 "bound maps should have at least one result");

2288 if (failed(boundMap.constantFold(operandConstants, foldedResults)))

2289 return failure();

2290

2291

2292 assert(!foldedResults.empty() && "bounds should have at least one result");

2293 auto maxOrMin = llvm::cast(foldedResults[0]).getValue();

2294 for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {

2295 auto foldedResult = llvm::cast(foldedResults[i]).getValue();

2296 maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)

2297 : llvm::APIntOps::smin(maxOrMin, foldedResult);

2298 }

2299 lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())

2300 : forOp.setConstantUpperBound(maxOrMin.getSExtValue());

2301 return success();

2302 };

2303

2304

2305 bool folded = false;

2306 if (!forOp.hasConstantLowerBound())

2307 folded |= succeeded(foldLowerOrUpperBound(true));

2308

2309

2310 if (!forOp.hasConstantUpperBound())

2311 folded |= succeeded(foldLowerOrUpperBound(false));

2312 return success(folded);

2313 }

2314

2315

2319

2320 auto lbMap = forOp.getLowerBoundMap();

2321 auto ubMap = forOp.getUpperBoundMap();

2322 auto prevLbMap = lbMap;

2323 auto prevUbMap = ubMap;

2324

2330

2334

2335

2336 if (lbMap == prevLbMap && ubMap == prevUbMap)

2337 return failure();

2338

2339 if (lbMap != prevLbMap)

2340 forOp.setLowerBound(lbOperands, lbMap);

2341 if (ubMap != prevUbMap)

2342 forOp.setUpperBound(ubOperands, ubMap);

2343 return success();

2344 }

2345

2346 namespace {

2347

2348 static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {

2349 int64_t step = forOp.getStepAsInt();

2350 if (!forOp.hasConstantBounds() || step <= 0)

2351 return std::nullopt;

2352 int64_t lb = forOp.getConstantLowerBound();

2353 int64_t ub = forOp.getConstantUpperBound();

2354 return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;

2355 }

2356

2357

2358

2359 struct AffineForEmptyLoopFolder : public OpRewritePattern {

2361

2362 LogicalResult matchAndRewrite(AffineForOp forOp,

2364

2365 if (!llvm::hasSingleElement(*forOp.getBody()))

2366 return failure();

2367 if (forOp.getNumResults() == 0)

2368 return success();

2369 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);

2370 if (tripCount == 0) {

2371

2372

2373 rewriter.replaceOp(forOp, forOp.getInits());

2374 return success();

2375 }

2377 auto yieldOp = cast(forOp.getBody()->getTerminator());

2378 auto iterArgs = forOp.getRegionIterArgs();

2379 bool hasValDefinedOutsideLoop = false;

2380 bool iterArgsNotInOrder = false;

2381 for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {

2382 Value val = yieldOp.getOperand(i);

2383 auto *iterArgIt = llvm::find(iterArgs, val);

2384

2385

2386 if (val == forOp.getInductionVar())

2387 return failure();

2388 if (iterArgIt == iterArgs.end()) {

2389

2390 assert(forOp.isDefinedOutsideOfLoop(val) &&

2391 "must be defined outside of the loop");

2392 hasValDefinedOutsideLoop = true;

2393 replacements.push_back(val);

2394 } else {

2395 unsigned pos = std::distance(iterArgs.begin(), iterArgIt);

2396 if (pos != i)

2397 iterArgsNotInOrder = true;

2398 replacements.push_back(forOp.getInits()[pos]);

2399 }

2400 }

2401

2402

2403 if (!tripCount.has_value() &&

2404 (hasValDefinedOutsideLoop || iterArgsNotInOrder))

2405 return failure();

2406

2407

2408 if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)

2409 return failure();

2410 rewriter.replaceOp(forOp, replacements);

2411 return success();

2412 }

2413 };

2414 }

2415

2416 void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,

2418 results.add(context);

2419 }

2420

2422 assert((point.isParent() || point == getRegion()) && "invalid region point");

2423

2424

2425

2426 return getInits();

2427 }

2428

2429 void AffineForOp::getSuccessorRegions(

2431 assert((point.isParent() || point == getRegion()) && "expected loop region");

2432

2433

2434

2435

2436 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);

2437 if (point.isParent() && tripCount.has_value()) {

2438 if (tripCount.value() > 0) {

2439 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));

2440 return;

2441 }

2442 if (tripCount.value() == 0) {

2444 return;

2445 }

2446 }

2447

2448

2449

2450 if (!point.isParent() && tripCount == 1) {

2452 return;

2453 }

2454

2455

2456

2457 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));

2459 }

2460

2461

2463 return getTrivialConstantTripCount(op) == 0;

2464 }

2465

2466 LogicalResult AffineForOp::fold(FoldAdaptor adaptor,

2471

2472

2473

2474

2475

2476 results.assign(getInits().begin(), getInits().end());

2477 folded = true;

2478 }

2479 return success(folded);

2480 }

2481

2484 }

2485

2488 }

2489

2491 assert(lbOperands.size() == map.getNumInputs());

2492 assert(map.getNumResults() >= 1 && "bound map has at least one result");

2493 getLowerBoundOperandsMutable().assign(lbOperands);

2494 setLowerBoundMap(map);

2495 }

2496

2498 assert(ubOperands.size() == map.getNumInputs());

2499 assert(map.getNumResults() >= 1 && "bound map has at least one result");

2500 getUpperBoundOperandsMutable().assign(ubOperands);

2501 setUpperBoundMap(map);

2502 }

2503

2504 bool AffineForOp::hasConstantLowerBound() {

2505 return getLowerBoundMap().isSingleConstant();

2506 }

2507

2508 bool AffineForOp::hasConstantUpperBound() {

2509 return getUpperBoundMap().isSingleConstant();

2510 }

2511

2512 int64_t AffineForOp::getConstantLowerBound() {

2513 return getLowerBoundMap().getSingleConstantResult();

2514 }

2515

2516 int64_t AffineForOp::getConstantUpperBound() {

2517 return getUpperBoundMap().getSingleConstantResult();

2518 }

2519

2520 void AffineForOp::setConstantLowerBound(int64_t value) {

2522 }

2523

2524 void AffineForOp::setConstantUpperBound(int64_t value) {

2526 }

2527

2528 AffineForOp::operand_range AffineForOp::getControlOperands() {

2531 }

2532

2533 bool AffineForOp::matchingBoundOperandList() {

2534 auto lbMap = getLowerBoundMap();

2535 auto ubMap = getUpperBoundMap();

2538 return false;

2539

2540 unsigned numOperands = lbMap.getNumInputs();

2541 for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {

2542

2543 if (getOperand(i) != getOperand(numOperands + i))

2544 return false;

2545 }

2546 return true;

2547 }

2548

2550

2551 std::optional<SmallVector> AffineForOp::getLoopInductionVars() {

2553 }

2554

2555 std::optional<SmallVector> AffineForOp::getLoopLowerBounds() {

2556 if (!hasConstantLowerBound())

2557 return std::nullopt;

2560 OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};

2561 }

2562

2563 std::optional<SmallVector> AffineForOp::getLoopSteps() {

2566 OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};

2567 }

2568

2569 std::optional<SmallVector> AffineForOp::getLoopUpperBounds() {

2570 if (!hasConstantUpperBound())

2571 return {};

2574 OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};

2575 }

2576

2577 FailureOr AffineForOp::replaceWithAdditionalYields(

2579 bool replaceInitOperandUsesInLoop,

2581

2584 auto inits = llvm::to_vector(getInits());

2585 inits.append(newInitOperands.begin(), newInitOperands.end());

2586 AffineForOp newLoop = rewriter.create(

2589

2590

2591 auto yieldOp = cast(getBody()->getTerminator());

2593 newLoop.getBody()->getArguments().take_back(newInitOperands.size());

2594 {

2598 newYieldValuesFn(rewriter, getLoc(), newIterArgs);

2599 assert(newInitOperands.size() == newYieldedValues.size() &&

2600 "expected as many new yield values as new iter operands");

2602 yieldOp.getOperandsMutable().append(newYieldedValues);

2603 });

2604 }

2605

2606

2607 rewriter.mergeBlocks(getBody(), newLoop.getBody(),

2608 newLoop.getBody()->getArguments().take_front(

2609 getBody()->getNumArguments()));

2610

2611 if (replaceInitOperandUsesInLoop) {

2612

2613

2614 for (auto it : llvm::zip(newInitOperands, newIterArgs)) {

2619 });

2620 }

2621 }

2622

2623

2624 rewriter.replaceOp(getOperation(),

2625 newLoop->getResults().take_front(getNumResults()));

2626 return cast(newLoop.getOperation());

2627 }

2628

2630

2631

2632

2633

2634

2637 }

2638

2639

2640

2643 }

2644

2647 }

2648

2651 }

2652

2654 auto ivArg = llvm::dyn_cast(val);

2655 if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())

2656 return AffineForOp();

2657 if (auto forOp =

2658 ivArg.getOwner()->getParent()->getParentOfType())

2659

2660 return forOp.getInductionVar() == val ? forOp : AffineForOp();

2661 return AffineForOp();

2662 }

2663

2665 auto ivArg = llvm::dyn_cast(val);

2666 if (!ivArg || !ivArg.getOwner())

2667 return nullptr;

2669 auto parallelOp = dyn_cast_if_present(containingOp);

2670 if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))

2671 return parallelOp;

2672 return nullptr;

2673 }

2674

2675

2676

2679 ivs->reserve(forInsts.size());

2680 for (auto forInst : forInsts)

2681 ivs->push_back(forInst.getInductionVar());

2682 }

2683

2686 ivs.reserve(affineOps.size());

2687 for (Operation *op : affineOps) {

2688

2689 if (auto forOp = dyn_cast(op))

2690 ivs.push_back(forOp.getInductionVar());

2691 else if (auto parallelOp = dyn_cast(op))

2692 for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)

2693 ivs.push_back(parallelOp.getBody()->getArgument(i));

2694 }

2695 }

2696

2697

2698

2699 template <typename BoundListTy, typename LoopCreatorTy>

2701 OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,

2704 LoopCreatorTy &&loopCreatorFn) {

2705 assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");

2706 assert(lbs.size() == steps.size() && "Mismatch in number of arguments");

2707

2708

2710 if (lbs.empty()) {

2711 if (bodyBuilderFn)

2712 bodyBuilderFn(builder, loc, ValueRange());

2713 return;

2714 }

2715

2716

2718 ivs.reserve(lbs.size());

2719 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {

2720

2723 ivs.push_back(iv);

2724

2725 if (i == e - 1 && bodyBuilderFn) {

2727 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);

2728 }

2729 nestedBuilder.create(nestedLoc);

2730 };

2731

2732

2733

2734 auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);

2736 }

2737 }

2738

2739

2740 static AffineForOp

2742 int64_t ub, int64_t step,

2743 AffineForOp::BodyBuilderFn bodyBuilderFn) {

2744 return builder.create(loc, lb, ub, step,

2745 std::nullopt, bodyBuilderFn);

2746 }

2747

2748

2749 static AffineForOp

2751 int64_t step,

2752 AffineForOp::BodyBuilderFn bodyBuilderFn) {

2755 if (lbConst && ubConst)

2757 ubConst.value(), step, bodyBuilderFn);

2760 std::nullopt, bodyBuilderFn);

2761 }

2762

2769 }

2770

2777 }

2778

2779

2780

2781

2782

2783 namespace {

2784

2785 struct SimplifyDeadElse : public OpRewritePattern {

2787

2788 LogicalResult matchAndRewrite(AffineIfOp ifOp,

2790 if (ifOp.getElseRegion().empty() ||

2791 !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())

2792 return failure();

2793

2795 rewriter.eraseBlock(ifOp.getElseBlock());

2797 return success();

2798 }

2799 };

2800

2801

2802

2803 struct AlwaysTrueOrFalseIf : public OpRewritePattern {

2805

2806 LogicalResult matchAndRewrite(AffineIfOp op,

2808

2809 auto isTriviallyFalse = [](IntegerSet iSet) {

2810 return iSet.isEmptyIntegerSet();

2811 };

2812

2813 auto isTriviallyTrue = [](IntegerSet iSet) {

2814 return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&

2815 iSet.getConstraint(0) == 0);

2816 };

2817

2818 IntegerSet affineIfConditions = op.getIntegerSet();

2819 Block *blockToMove;

2820 if (isTriviallyFalse(affineIfConditions)) {

2821

2822

2823

2824 if (op.getNumResults() == 0 && !op.hasElse()) {

2825

2826

2828 return success();

2829 }

2830 blockToMove = op.getElseBlock();

2831 } else if (isTriviallyTrue(affineIfConditions)) {

2832 blockToMove = op.getThenBlock();

2833 } else {

2834 return failure();

2835 }

2837

2838

2840

2841

2842

2843

2844

2845

2846

2848

2849

2850 rewriter.eraseOp(blockToMoveTerminator);

2851 return success();

2852 }

2853 };

2854 }

2855

2856

2857

2858 void AffineIfOp::getSuccessorRegions(

2860

2861

2863 regions.reserve(2);

2864 regions.push_back(

2865 RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));

2866

2867 if (getElseRegion().empty()) {

2868 regions.push_back(getResults());

2869 } else {

2870 regions.push_back(

2871 RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));

2872 }

2873 return;

2874 }

2875

2876

2877

2879 }

2880

2882

2883

2884 auto conditionAttr =

2885 (*this)->getAttrOfType(getConditionAttrStrName());

2886 if (!conditionAttr)

2887 return emitOpError("requires an integer set attribute named 'condition'");

2888

2889

2890 IntegerSet condition = conditionAttr.getValue();

2891 if (getNumOperands() != condition.getNumInputs())

2892 return emitOpError("operand count and condition integer set dimension and "

2893 "symbol count must match");

2894

2895

2898 return failure();

2899

2900 return success();

2901 }

2902

2904

2905 IntegerSetAttr conditionAttr;

2906 unsigned numDims;

2908 AffineIfOp::getConditionAttrStrName(),

2911 return failure();

2912

2913

2914 auto set = conditionAttr.getValue();

2915 if (set.getNumDims() != numDims)

2918 "dim operand count and integer set dim count must match");

2919 if (numDims + set.getNumSymbols() != result.operands.size())

2922 "symbol operand count and integer set symbol count must match");

2923

2925 return failure();

2926

2927

2928

2929 result.regions.reserve(2);

2932

2933

2934 if (parser.parseRegion(*thenRegion, {}, {}))

2935 return failure();

2936 AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),

2938

2939

2941 if (parser.parseRegion(*elseRegion, {}, {}))

2942 return failure();

2943 AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),

2945 }

2946

2947

2949 return failure();

2950

2951 return success();

2952 }

2953

2955 auto conditionAttr =

2956 (*this)->getAttrOfType(getConditionAttrStrName());

2957 p << " " << conditionAttr;

2959 conditionAttr.getValue().getNumDims(), p);

2961 p << ' ';

2962 p.printRegion(getThenRegion(), false,

2963 getNumResults());

2964

2965

2966 auto &elseRegion = this->getElseRegion();

2967 if (!elseRegion.empty()) {

2968 p << " else ";

2970 false,

2971 getNumResults());

2972 }

2973

2974

2976 getConditionAttrStrName());

2977 }

2978

2979 IntegerSet AffineIfOp::getIntegerSet() {

2980 return (*this)

2981 ->getAttrOfType(getConditionAttrStrName())

2982 .getValue();

2983 }

2984

2985 void AffineIfOp::setIntegerSet(IntegerSet newSet) {

2986 (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));

2987 }

2988

2990 setIntegerSet(set);

2991 (*this)->setOperands(operands);

2992 }

2993

2996 bool withElseRegion) {

2997 assert(resultTypes.empty() || withElseRegion);

2999

3000 result.addTypes(resultTypes);

3003

3006 if (resultTypes.empty())

3007 AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);

3008

3010 if (withElseRegion) {

3012 if (resultTypes.empty())

3013 AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);

3014 }

3015 }

3016

3019 AffineIfOp::build(builder, result, {}, set, args,

3020 withElseRegion);

3021 }

3022

3023

3024

3025

3028

3029

3030

3033

3034 if (llvm::none_of(operands,

3036 return;

3037

3041 }

3042

3043

3045 auto set = getIntegerSet();

3049

3050

3051 if (getIntegerSet() == set && llvm::equal(operands, getOperands()))

3052 return failure();

3053

3054 setConditional(set, operands);

3055 return success();

3056 }

3057

3058 void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,

3060 results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);

3061 }

3062

3063

3064

3065

3066

3069 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");

3071 if (map)

3073 auto memrefType = llvm::cast(operands[0].getType());

3074 result.types.push_back(memrefType.getElementType());

3075 }

3076

3079 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");

3082 auto memrefType = llvm::cast(memref.getType());

3084 result.types.push_back(memrefType.getElementType());

3085 }

3086

3089 auto memrefType = llvm::cast(memref.getType());

3090 int64_t rank = memrefType.getRank();

3091

3092

3093 auto map =

3095 build(builder, result, memref, map, indices);

3096 }

3097

3099 auto &builder = parser.getBuilder();

3101

3102 MemRefType type;

3104 AffineMapAttr mapAttr;

3106 return failure(

3109 AffineLoadOp::getMapAttrStrName(),

3116 }

3117

3119 p << " " << getMemRef() << '[';

3120 if (AffineMapAttr mapAttr =

3121 (*this)->getAttrOfType(getMapAttrStrName()))

3123 p << ']';

3125 {getMapAttrStrName()});

3127 }

3128

3129

3130

3131 template

3132 static LogicalResult

3135 MemRefType memrefType, unsigned numIndexOperands) {

3136 AffineMap map = mapAttr.getValue();

3137 if (map.getNumResults() != memrefType.getRank())

3138 return op->emitOpError("affine map num results must equal memref rank");

3140 return op->emitOpError("expects as many subscripts as affine map inputs");

3141

3142 for (auto idx : mapOperands) {

3143 if (!idx.getType().isIndex())

3144 return op->emitOpError("index to load must have 'index' type");

3145 }

3147 return failure();

3148

3149 return success();

3150 }

3151

3154 if (getType() != memrefType.getElementType())

3155 return emitOpError("result type must match element type of memref");

3156

3158 *this, (*this)->getAttrOfType(getMapAttrStrName()),

3159 getMapOperands(), memrefType,

3160 getNumOperands() - 1)))

3161 return failure();

3162

3163 return success();

3164 }

3165

3166 void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,

3168 results.add<SimplifyAffineOp>(context);

3169 }

3170

3171 OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {

3172

3174 return getResult();

3175

3176

3177 auto getGlobalOp = getMemref().getDefiningOpmemref::GetGlobalOp();

3178 if (!getGlobalOp)

3179 return {};

3180

3181 auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();

3182 if (!symbolTableOp)

3183 return {};

3184 auto global = dyn_cast_or_nullmemref::GlobalOp(

3186 if (!global)

3187 return {};

3188

3189

3190 auto cstAttr =

3191 llvm::dyn_cast_or_null(global.getConstantInitValue());

3192 if (!cstAttr)

3193 return {};

3194

3195 if (auto splatAttr = llvm::dyn_cast(cstAttr))

3196 return splatAttr.getSplatValue<Attribute>();

3197

3198 if (!getAffineMap().isConstant())

3199 return {};

3200 auto indices = llvm::to_vector<4>(

3201 llvm::map_range(getAffineMap().getConstantResults(),

3202 [](int64_t v) -> uint64_t { return v; }));

3203 return cstAttr.getValues<Attribute>()[indices];

3204 }

3205

3206

3207

3208

3209

3213 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");

3218 }

3219

3220

3224 auto memrefType = llvm::cast(memref.getType());

3225 int64_t rank = memrefType.getRank();

3226

3227

3228 auto map =

3230 build(builder, result, valueToStore, memref, map, indices);

3231 }

3232

3235

3236 MemRefType type;

3239 AffineMapAttr mapAttr;

3244 mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),

3248 parser.resolveOperand(storeValueInfo, type.getElementType(),

3252 }

3253

3255 p << " " << getValueToStore();

3256 p << ", " << getMemRef() << '[';

3257 if (AffineMapAttr mapAttr =

3258 (*this)->getAttrOfType(getMapAttrStrName()))

3260 p << ']';

3262 {getMapAttrStrName()});

3264 }

3265

3267

3269 if (getValueToStore().getType() != memrefType.getElementType())

3270 return emitOpError(

3271 "value to store must have the same type as memref element type");

3272

3274 *this, (*this)->getAttrOfType(getMapAttrStrName()),

3275 getMapOperands(), memrefType,

3276 getNumOperands() - 2)))

3277 return failure();

3278

3279 return success();

3280 }

3281

3282 void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,

3284 results.add<SimplifyAffineOp>(context);

3285 }

3286

3287 LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,

3289

3291 }

3292

3293

3294

3295

3296

3297 template

3299

3300 if (op.getNumOperands() !=

3301 op.getMap().getNumDims() + op.getMap().getNumSymbols())

3302 return op.emitOpError(

3303 "operand count and affine map dimension and symbol count must match");

3304

3305 if (op.getMap().getNumResults() == 0)

3306 return op.emitOpError("affine map expect at least one result");

3307 return success();

3308 }

3309

3310 template

3312 p << ' ' << op->getAttr(T::getMapAttrStrName());

3313 auto operands = op.getOperands();

3314 unsigned numDims = op.getMap().getNumDims();

3315 p << '(' << operands.take_front(numDims) << ')';

3316

3317 if (operands.size() != numDims)

3318 p << '[' << operands.drop_front(numDims) << ']';

3320 {T::getMapAttrStrName()});

3321 }

3322

3323 template

3326 auto &builder = parser.getBuilder();

3330 AffineMapAttr mapAttr;

3331 return failure(

3332 parser.parseAttribute(mapAttr, T::getMapAttrStrName(),

3341 }

3342

3343

3344

3345

3346 template

3348 static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,

3349 "expected affine min or max op");

3350

3351

3352

3353

3355 auto foldedMap = op.getMap().partialConstantFold(operands, &results);

3356

3357 if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())

3358 return op.getOperand(0);

3359

3360

3361 if (results.empty()) {

3362

3363 if (foldedMap == op.getMap())

3364 return {};

3366 return op.getResult();

3367 }

3368

3369

3370 auto resultIt = std::is_same<T, AffineMinOp>::value

3371 ? llvm::min_element(results)

3372 : llvm::max_element(results);

3373 if (resultIt == results.end())

3374 return {};

3376 }

3377

3378

3379 template

3382

3385 AffineMap oldMap = affineOp.getAffineMap();

3386

3389

3390

3391 if (!llvm::is_contained(newExprs, expr))

3392 newExprs.push_back(expr);

3393 }

3394

3396 return failure();

3397

3400 rewriter.replaceOpWithNewOp(affineOp, newMap, affineOp.getMapOperands());

3401

3402 return success();

3403 }

3404 };

3405

3406

3407

3408

3409

3410

3411

3412

3413

3414

3415

3416

3417

3418

3419

3420

3421

3422 template

3425

3428 AffineMap oldMap = affineOp.getAffineMap();

3430 affineOp.getMapOperands().take_front(oldMap.getNumDims());

3432 affineOp.getMapOperands().take_back(oldMap.getNumSymbols());

3433

3434 auto newDimOperands = llvm::to_vector<8>(dimOperands);

3435 auto newSymOperands = llvm::to_vector<8>(symOperands);

3438

3439

3440

3441

3443 if (auto symExpr = dyn_cast(expr)) {

3444 Value symValue = symOperands[symExpr.getPosition()];

3445 if (auto producerOp = symValue.getDefiningOp()) {

3446 producerOps.push_back(producerOp);

3447 continue;

3448 }

3449 } else if (auto dimExpr = dyn_cast(expr)) {

3450 Value dimValue = dimOperands[dimExpr.getPosition()];

3451 if (auto producerOp = dimValue.getDefiningOp()) {

3452 producerOps.push_back(producerOp);

3453 continue;

3454 }

3455 }

3456

3457

3458

3459 newExprs.push_back(expr);

3460 }

3461

3462 if (producerOps.empty())

3463 return failure();

3464

3465 unsigned numUsedDims = oldMap.getNumDims();

3467

3468

3469 for (T producerOp : producerOps) {

3470 AffineMap producerMap = producerOp.getAffineMap();

3471 unsigned numProducerDims = producerMap.getNumDims();

3472 unsigned numProducerSyms = producerMap.getNumSymbols();

3473

3474

3476 producerOp.getMapOperands().take_front(numProducerDims);

3478 producerOp.getMapOperands().take_back(numProducerSyms);

3479 newDimOperands.append(dimValues.begin(), dimValues.end());

3480 newSymOperands.append(symValues.begin(), symValues.end());

3481

3482

3484 newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)

3485 .shiftSymbols(numProducerSyms, numUsedSyms));

3486 }

3487

3488 numUsedDims += numProducerDims;

3489 numUsedSyms += numProducerSyms;

3490 }

3491

3492 auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs,

3494 auto newOperands =

3495 llvm::to_vector<8>(llvm::concat(newDimOperands, newSymOperands));

3497

3498 return success();

3499 }

3500 };

3501

3502

3503

3504

3505

3506

3507

3508

3509

3513

3514 if (!resultExpr.isPureAffine())

3515 return failure();

3516

3518 auto flattenResult = flattener.walkPostOrder(resultExpr);

3519 if (failed(flattenResult))

3520 return failure();

3521

3522

3525 return failure();

3526

3527 flattenedExprs.emplace_back(flattener.operandExprStack.back().begin(),

3529 }

3530

3531

3532 if (llvm::is_sorted(flattenedExprs))

3533 return failure();

3534

3535

3537 llvm::to_vector(llvm::seq(0, map.getNumResults()));

3538 llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {

3539 return flattenedExprs[lhs] < flattenedExprs[rhs];

3540 });

3542 for (unsigned idx : resultPermutation)

3543 newExprs.push_back(map.getResult(idx));

3544

3547 return success();

3548 }

3549

3550

3551

3552

3553

3554

3555

3556

3557

3558

3559

3560

3561

3562

3563 template

3566

3569 AffineMap map = affineOp.getAffineMap();

3571 return failure();

3572 rewriter.replaceOpWithNewOp(affineOp, map, affineOp.getMapOperands());

3573 return success();

3574 }

3575 };

3576

3577 template

3580

3583 if (affineOp.getMap().getNumResults() != 1)

3584 return failure();

3585 rewriter.replaceOpWithNewOp(affineOp, affineOp.getMap(),

3586 affineOp.getOperands());

3587 return success();

3588 }

3589 };

3590

3591

3592

3593

3594

3595

3596

3597

3598 OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {

3599 return foldMinMaxOp(*this, adaptor.getOperands());

3600 }

3601

3608 context);

3609 }

3610

3612

3614 return parseAffineMinMaxOp(parser, result);

3615 }

3616

3618

3619

3620

3621

3622

3623

3624

3625

3626 OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {

3627 return foldMinMaxOp(*this, adaptor.getOperands());

3628 }

3629

3636 context);

3637 }

3638

3640

3642 return parseAffineMinMaxOp(parser, result);

3643 }

3644

3646

3647

3648

3649

3650

3651

3652

3653

3656 auto &builder = parser.getBuilder();

3658

3659 MemRefType type;

3661 IntegerAttr hintInfo;

3663 StringRef readOrWrite, cacheType;

3664

3665 AffineMapAttr mapAttr;

3669 AffinePrefetchOp::getMapAttrStrName(),

3675 AffinePrefetchOp::getLocalityHintAttrStrName(),

3683 return failure();

3684

3685 if (readOrWrite != "read" && readOrWrite != "write")

3687 "rw specifier has to be 'read' or 'write'");

3688 result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),

3690

3691 if (cacheType != "data" && cacheType != "instr")

3693 "cache type has to be 'data' or 'instr'");

3694

3695 result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),

3697

3698 return success();

3699 }

3700

3702 p << " " << getMemref() << '[';

3703 AffineMapAttr mapAttr =

3704 (*this)->getAttrOfType(getMapAttrStrName());

3705 if (mapAttr)

3707 p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "

3708 << "locality<" << getLocalityHint() << ">, "

3709 << (getIsDataCache() ? "data" : "instr");

3711 (*this)->getAttrs(),

3712 {getMapAttrStrName(), getLocalityHintAttrStrName(),

3713 getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});

3715 }

3716

3718 auto mapAttr = (*this)->getAttrOfType(getMapAttrStrName());

3719 if (mapAttr) {

3720 AffineMap map = mapAttr.getValue();

3722 return emitOpError("affine.prefetch affine map num results must equal"

3723 " memref rank");

3724 if (map.getNumInputs() + 1 != getNumOperands())

3725 return emitOpError("too few operands");

3726 } else {

3727 if (getNumOperands() != 1)

3728 return emitOpError("too few operands");

3729 }

3730

3732 for (auto idx : getMapOperands()) {

3734 return emitOpError(

3735 "index must be a valid dimension or symbol identifier");

3736 }

3737 return success();

3738 }

3739

3740 void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,

3742

3743 results.add<SimplifyAffineOp>(context);

3744 }

3745

3746 LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,

3748

3750 }

3751

3752

3753

3754

3755

3761 auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {

3763 }));

3765 build(builder, result, resultTypes, reductions, lbs, {}, ubs,

3766 {}, steps);

3767 }

3768

3775 assert(llvm::all_of(lbMaps,

3777 return m.getNumDims() == lbMaps[0].getNumDims() &&

3778 m.getNumSymbols() == lbMaps[0].getNumSymbols();

3779 }) &&

3780 "expected all lower bounds maps to have the same number of dimensions "

3781 "and symbols");

3782 assert(llvm::all_of(ubMaps,

3784 return m.getNumDims() == ubMaps[0].getNumDims() &&

3785 m.getNumSymbols() == ubMaps[0].getNumSymbols();

3786 }) &&

3787 "expected all upper bounds maps to have the same number of dimensions "

3788 "and symbols");

3789 assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&

3790 "expected lower bound maps to have as many inputs as lower bound "

3791 "operands");

3792 assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&

3793 "expected upper bound maps to have as many inputs as upper bound "

3794 "operands");

3795

3797 result.addTypes(resultTypes);

3798

3799

3801 for (arith::AtomicRMWKind reduction : reductions)

3802 reductionAttrs.push_back(

3804 result.addAttribute(getReductionsAttrStrName(),

3806

3807

3808

3811 if (maps.empty())

3814 groups.reserve(groups.size() + maps.size());

3815 exprs.reserve(maps.size());

3817 llvm::append_range(exprs, m.getResults());

3818 groups.push_back(m.getNumResults());

3819 }

3820 return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,

3822 };

3823

3824

3826 AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);

3827 AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);

3828 result.addAttribute(getLowerBoundsMapAttrStrName(),

3830 result.addAttribute(getLowerBoundsGroupsAttrStrName(),

3832 result.addAttribute(getUpperBoundsMapAttrStrName(),

3834 result.addAttribute(getUpperBoundsGroupsAttrStrName(),

3839

3840

3841 auto *bodyRegion = result.addRegion();

3843

3844

3845 for (unsigned i = 0, e = steps.size(); i < e; ++i)

3847 if (resultTypes.empty())

3848 ensureTerminator(*bodyRegion, builder, result.location);

3849 }

3850

3852 return {&getRegion()};

3853 }

3854

3855 unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }

3856

3857 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {

3858 return getOperands().take_front(getLowerBoundsMap().getNumInputs());

3859 }

3860

3861 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {

3862 return getOperands().drop_front(getLowerBoundsMap().getNumInputs());

3863 }

3864

3865 AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {

3866 auto values = getLowerBoundsGroups().getValues<int32_t>();

3867 unsigned start = 0;

3868 for (unsigned i = 0; i < pos; ++i)

3869 start += values[i];

3870 return getLowerBoundsMap().getSliceMap(start, values[pos]);

3871 }

3872

3873 AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {

3874 auto values = getUpperBoundsGroups().getValues<int32_t>();

3875 unsigned start = 0;

3876 for (unsigned i = 0; i < pos; ++i)

3877 start += values[i];

3878 return getUpperBoundsMap().getSliceMap(start, values[pos]);

3879 }

3880

3881 AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {

3882 return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());

3883 }

3884

3885 AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {

3886 return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());

3887 }

3888

3889 std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {

3890 if (hasMinMaxBounds())

3891 return std::nullopt;

3892

3893

3896 AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),

3897 &rangesValueMap);

3899 for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {

3900 auto expr = rangesValueMap.getResult(i);

3901 auto cst = dyn_cast(expr);

3902 if (!cst)

3903 return std::nullopt;

3904 out.push_back(cst.getValue());

3905 }

3906 return out;

3907 }

3908

3909 Block *AffineParallelOp::getBody() { return &getRegion().front(); }

3910

3911 OpBuilder AffineParallelOp::getBodyBuilder() {

3912 return OpBuilder(getBody(), std::prev(getBody()->end()));

3913 }

3914

3915 void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {

3916 assert(lbOperands.size() == map.getNumInputs() &&

3917 "operands to map must match number of inputs");

3918

3919 auto ubOperands = getUpperBoundsOperands();

3920

3922 newOperands.append(ubOperands.begin(), ubOperands.end());

3923 (*this)->setOperands(newOperands);

3924

3926 }

3927

3928 void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {

3929 assert(ubOperands.size() == map.getNumInputs() &&

3930 "operands to map must match number of inputs");

3931

3933 newOperands.append(ubOperands.begin(), ubOperands.end());

3934 (*this)->setOperands(newOperands);

3935

3937 }

3938

3940 setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));

3941 }

3942

3943

3945 arith::AtomicRMWKind op) {

3946 switch (op) {

3947 case arith::AtomicRMWKind::addf:

3948 return isa(resultType);

3949 case arith::AtomicRMWKind::addi:

3950 return isa(resultType);

3951 case arith::AtomicRMWKind::assign:

3952 return true;

3953 case arith::AtomicRMWKind::mulf:

3954 return isa(resultType);

3955 case arith::AtomicRMWKind::muli:

3956 return isa(resultType);

3957 case arith::AtomicRMWKind::maximumf:

3958 return isa(resultType);

3959 case arith::AtomicRMWKind::minimumf:

3960 return isa(resultType);

3961 case arith::AtomicRMWKind::maxs: {

3962 auto intType = llvm::dyn_cast(resultType);

3963 return intType && intType.isSigned();

3964 }

3965 case arith::AtomicRMWKind::mins: {

3966 auto intType = llvm::dyn_cast(resultType);

3967 return intType && intType.isSigned();

3968 }

3969 case arith::AtomicRMWKind::maxu: {

3970 auto intType = llvm::dyn_cast(resultType);

3971 return intType && intType.isUnsigned();

3972 }

3973 case arith::AtomicRMWKind::minu: {

3974 auto intType = llvm::dyn_cast(resultType);

3975 return intType && intType.isUnsigned();

3976 }

3977 case arith::AtomicRMWKind::ori:

3978 return isa(resultType);

3979 case arith::AtomicRMWKind::andi:

3980 return isa(resultType);

3981 default:

3982 return false;

3983 }

3984 }

3985

3987 auto numDims = getNumDims();

3988 if (getLowerBoundsGroups().getNumElements() != numDims ||

3989 getUpperBoundsGroups().getNumElements() != numDims ||

3990 getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {

3991 return emitOpError() << "the number of region arguments ("

3992 << getBody()->getNumArguments()

3993 << ") and the number of map groups for lower ("

3994 << getLowerBoundsGroups().getNumElements()

3995 << ") and upper bound ("

3996 << getUpperBoundsGroups().getNumElements()

3997 << "), and the number of steps (" << getSteps().size()

3998 << ") must all match";

3999 }

4000

4001 unsigned expectedNumLBResults = 0;

4002 for (APInt v : getLowerBoundsGroups()) {

4003 unsigned results = v.getZExtValue();

4004 if (results == 0)

4005 return emitOpError()

4006 << "expected lower bound map to have at least one result";

4007 expectedNumLBResults += results;

4008 }

4009 if (expectedNumLBResults != getLowerBoundsMap().getNumResults())

4010 return emitOpError() << "expected lower bounds map to have "

4011 << expectedNumLBResults << " results";

4012 unsigned expectedNumUBResults = 0;

4013 for (APInt v : getUpperBoundsGroups()) {

4014 unsigned results = v.getZExtValue();

4015 if (results == 0)

4016 return emitOpError()

4017 << "expected upper bound map to have at least one result";

4018 expectedNumUBResults += results;

4019 }

4020 if (expectedNumUBResults != getUpperBoundsMap().getNumResults())

4021 return emitOpError() << "expected upper bounds map to have "

4022 << expectedNumUBResults << " results";

4023

4024 if (getReductions().size() != getNumResults())

4025 return emitOpError("a reduction must be specified for each output");

4026

4027

4028

4031 auto intAttr = llvm::dyn_cast(attr);

4032 if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))

4033 return emitOpError("invalid reduction attribute");

4034 auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();

4036 return emitOpError("result type cannot match reduction attribute");

4037 }

4038

4039

4040

4042 getLowerBoundsMap().getNumDims())))

4043 return failure();

4044

4046 getUpperBoundsMap().getNumDims())))

4047 return failure();

4048 return success();

4049 }

4050

4051 LogicalResult AffineValueMap::canonicalize() {

4053 auto newMap = getAffineMap();

4055 if (newMap == getAffineMap() && newOperands == operands)

4056 return failure();

4057 reset(newMap, newOperands);

4058 return success();

4059 }

4060

4061

4064 bool lbCanonicalized = succeeded(lb.canonicalize());

4065

4067 bool ubCanonicalized = succeeded(ub.canonicalize());

4068

4069

4070 if (!lbCanonicalized && !ubCanonicalized)

4071 return failure();

4072

4073 if (lbCanonicalized)

4075 if (ubCanonicalized)

4077

4078 return success();

4079 }

4080

4081 LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,

4084 }

4085

4086

4087

4088

4089

4090

4093 StringRef keyword) {

4094 AffineMap map = mapAttr.getValue();

4095 unsigned numDims = map.getNumDims();

4096 ValueRange dimOperands = operands.take_front(numDims);

4097 ValueRange symOperands = operands.drop_front(numDims);

4098 unsigned start = 0;

4099 for (llvm::APInt groupSize : group) {

4100 if (start != 0)

4101 p << ", ";

4102

4103 unsigned size = groupSize.getZExtValue();

4104 if (size == 1) {

4106 ++start;

4107 } else {

4108 p << keyword << '(';

4111 p << ')';

4112 start += size;

4113 }

4114 }

4115 }

4116

4118 p << " (" << getBody()->getArguments() << ") = (";

4119 printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(),

4120 getLowerBoundsOperands(), "max");

4121 p << ") to (";

4122 printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),

4123 getUpperBoundsOperands(), "min");

4124 p << ')';

4126 bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });

4127 if (!elideSteps) {

4128 p << " step (";

4129 llvm::interleaveComma(steps, p);

4130 p << ')';

4131 }

4132 if (getNumResults()) {

4133 p << " reduce (";

4134 llvm::interleaveComma(getReductions(), p, [&](auto &attr) {

4135 arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(

4136 llvm::cast(attr).getInt());

4137 p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";

4138 });

4139 p << ") -> (" << getResultTypes() << ")";

4140 }

4141

4142 p << ' ';

4143 p.printRegion(getRegion(), false,

4144 getNumResults());

4146 (*this)->getAttrs(),

4147 {AffineParallelOp::getReductionsAttrStrName(),

4148 AffineParallelOp::getLowerBoundsMapAttrStrName(),

4149 AffineParallelOp::getLowerBoundsGroupsAttrStrName(),

4150 AffineParallelOp::getUpperBoundsMapAttrStrName(),

4151 AffineParallelOp::getUpperBoundsGroupsAttrStrName(),

4152 AffineParallelOp::getStepsAttrStrName()});

4153 }

4154

4155

4156

4157

4158

4165 "expected operands to be dim or symbol expression");

4166

4168 for (const auto &list : operands) {

4170 if (parser.resolveOperands(list, indexType, valueOperands))

4171 return failure();

4172 for (Value operand : valueOperands) {

4173 unsigned pos = std::distance(uniqueOperands.begin(),

4174 llvm::find(uniqueOperands, operand));

4175 if (pos == uniqueOperands.size())

4176 uniqueOperands.push_back(operand);

4177 replacements.push_back(

4181 }

4182 }

4183 return success();

4184 }

4185

namespace {
/// Selects which affine.parallel bound the min/max bound parser handles:
/// `Min` targets the upper-bound map/groups attributes, `Max` the lower-bound
/// ones.
enum class MinMaxKind : int { Min, Max };
} // namespace

4189

4190

4191

4192

4193

4194

4195

4196

4197

4198

4199

4200

4201

4202

4203

4204

4205

4208 MinMaxKind kind) {

4209

4210

4211 const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";

4212

4213 StringRef mapName = kind == MinMaxKind::Min

4214 ? AffineParallelOp::getUpperBoundsMapAttrStrName()

4215 : AffineParallelOp::getLowerBoundsMapAttrStrName();

4216 StringRef groupsName =

4217 kind == MinMaxKind::Min

4218 ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()

4219 : AffineParallelOp::getLowerBoundsGroupsAttrStrName();

4220

4222 return failure();

4223

4228 return success();

4229 }

4230

4236 auto parseOperands = [&]() {

4238 kind == MinMaxKind::Min ? "min" : "max"))) {

4239 mapOperands.clear();

4240 AffineMapAttr map;

4244 return failure();

4246 llvm::append_range(flatExprs, map.getValue().getResults());

4248 auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());

4250 auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());

4252 flatDimOperands.append(map.getValue().getNumResults(), dims);

4253 flatSymOperands.append(map.getValue().getNumResults(), syms);

4254 numMapsPerGroup.push_back(map.getValue().getNumResults());

4255 } else {

4257 flatSymOperands.emplace_back(),

4258 flatExprs.emplace_back())))

4259 return failure();

4260 numMapsPerGroup.push_back(1);

4261 }

4262 return success();

4263 };

4265 return failure();

4266

4267 unsigned totalNumDims = 0;

4268 unsigned totalNumSyms = 0;

4269 for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {

4270 unsigned numDims = flatDimOperands[i].size();

4271 unsigned numSyms = flatSymOperands[i].size();

4272 flatExprs[i] = flatExprs[i]

4273 .shiftDims(numDims, totalNumDims)

4274 .shiftSymbols(numSyms, totalNumSyms);

4275 totalNumDims += numDims;

4276 totalNumSyms += numSyms;

4277 }

4278

4279

4286 return failure();

4287

4288 result.operands.append(dimOperands.begin(), dimOperands.end());

4289 result.operands.append(symOperands.begin(), symOperands.end());

4290

4292 auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,

4294 flatMap = flatMap.replaceDimsAndSymbols(

4295 dimRplacements, symRepacements, dimOperands.size(), symOperands.size());

4296

4299 return success();

4300 }

4301

4302

4303

4304

4305

4306

4309 auto &builder = parser.getBuilder();

4317 return failure();

4318

4319 AffineMapAttr stepsMapAttr;

4324 result.addAttribute(AffineParallelOp::getStepsAttrStrName(),

4326 } else {

4328 AffineParallelOp::getStepsAttrStrName(),

4329 stepsAttrs,

4331 return failure();

4332

4333

4335 auto stepsMap = stepsMapAttr.getValue();

4336 for (const auto &result : stepsMap.getResults()) {

4337 auto constExpr = dyn_cast(result);

4338 if (!constExpr)

4340 "steps must be constant integers");

4341 steps.push_back(constExpr.getValue());

4342 }

4343 result.addAttribute(AffineParallelOp::getStepsAttrStrName(),

4345 }

4346

4347

4348

4352 return failure();

4353 auto parseAttributes = [&]() -> ParseResult {

4354

4355

4356

4357 StringAttr attrVal;

4361 attrStorage))

4362 return failure();

4363 std::optionalarith::AtomicRMWKind reduction =

4364 arith::symbolizeAtomicRMWKind(attrVal.getValue());

4365 if (!reduction)

4366 return parser.emitError(loc, "invalid reduction value: ") << attrVal;

4367 reductions.push_back(

4368 builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));

4369

4370 return success();

4371 };

4373 return failure();

4374 }

4375 result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),

4377

4378

4380 return failure();

4381

4382

4384 for (auto &iv : ivs)

4385 iv.type = indexType;

4388 return failure();

4389

4390

4391 AffineParallelOp::ensureTerminator(*body, builder, result.location);

4392 return success();

4393 }

4394

4395

4396

4397

4398

4400 auto *parentOp = (*this)->getParentOp();

4401 auto results = parentOp->getResults();

4402 auto operands = getOperands();

4403

4404 if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))

4405 return emitOpError() << "only terminates affine.if/for/parallel regions";

4406 if (parentOp->getNumResults() != getNumOperands())

4407 return emitOpError() << "parent of yield must have same number of "

4408 "results as the yield operands";

4409 for (auto it : llvm::zip(results, operands)) {

4410 if (std::get<0>(it).getType() != std::get<1>(it).getType())

4411 return emitOpError() << "types mismatch between yield op and its parent";

4412 }

4413

4414 return success();

4415 }

4416

4417

4418

4419

4420

4422 VectorType resultType, AffineMap map,

4424 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");

4426 if (map)

4428 result.types.push_back(resultType);

4429 }

4430

4432 VectorType resultType, Value memref,

4434 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");

4438 result.types.push_back(resultType);

4439 }

4440

4442 VectorType resultType, Value memref,

4444 auto memrefType = llvm::cast(memref.getType());

4445 int64_t rank = memrefType.getRank();

4446

4447

4448 auto map =

4450 build(builder, result, resultType, memref, map, indices);

4451 }

4452

4453 void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,

4455 results.add<SimplifyAffineOp>(context);

4456 }

4457

4460 auto &builder = parser.getBuilder();

4462

4463 MemRefType memrefType;

4464 VectorType resultType;

4466 AffineMapAttr mapAttr;

4468 return failure(

4471 AffineVectorLoadOp::getMapAttrStrName(),

4479 }

4480

4482 p << " " << getMemRef() << '[';

4483 if (AffineMapAttr mapAttr =

4484 (*this)->getAttrOfType(getMapAttrStrName()))

4486 p << ']';

4488 {getMapAttrStrName()});

4490 }

4491

4492

4494 VectorType vectorType) {

4495

4496 if (memrefType.getElementType() != vectorType.getElementType())

4498 "requires memref and vector types of the same elemental type");

4499 return success();

4500 }

4501

4505 *this, (*this)->getAttrOfType(getMapAttrStrName()),

4506 getMapOperands(), memrefType,

4507 getNumOperands() - 1)))

4508 return failure();

4509

4511 return failure();

4512

4513 return success();

4514 }

4515

4516

4517

4518

4519

4523 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");

4528 }

4529

4530

4534 auto memrefType = llvm::cast(memref.getType());

4535 int64_t rank = memrefType.getRank();

4536

4537

4538 auto map =

4540 build(builder, result, valueToStore, memref, map, indices);

4541 }

4542 void AffineVectorStoreOp::getCanonicalizationPatterns(

4544 results.add<SimplifyAffineOp>(context);

4545 }

4546

4550

4551 MemRefType memrefType;

4552 VectorType resultType;

4555 AffineMapAttr mapAttr;

4557 return failure(

4561 AffineVectorStoreOp::getMapAttrStrName(),

4569 }

4570

4572 p << " " << getValueToStore();

4573 p << ", " << getMemRef() << '[';

4574 if (AffineMapAttr mapAttr =

4575 (*this)->getAttrOfType(getMapAttrStrName()))

4577 p << ']';

4579 {getMapAttrStrName()});

4580 p << " : " << getMemRefType() << ", " << getValueToStore().getType();

4581 }

4582

4586 *this, (*this)->getAttrOfType(getMapAttrStrName()),

4587 getMapOperands(), memrefType,

4588 getNumOperands() - 2)))

4589 return failure();

4590

4592 return failure();

4593

4594 return success();

4595 }

4596

4597

4598

4599

4600

4601 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,

4605 bool hasOuterBound) {

4606 SmallVector returnTypes(hasOuterBound ? staticBasis.size()

4607 : staticBasis.size() + 1,

4609 build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,

4610 staticBasis);

4611 }

4612

4613 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,

4616 bool hasOuterBound) {

4617 if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {

4618 hasOuterBound = false;

4619 basis = basis.drop_front();

4620 }

4624 staticBasis);

4625 build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,

4626 hasOuterBound);

4627 }

4628

4629 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,

4631 Value linearIndex,

4633 bool hasOuterBound) {

4634 if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {

4635 hasOuterBound = false;

4636 basis = basis.drop_front();

4637 }

4641 build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,

4642 hasOuterBound);

4643 }

4644

4645 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,

4648 bool hasOuterBound) {

4649 build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);

4650 }

4651

4654 if (getNumResults() != staticBasis.size() &&

4655 getNumResults() != staticBasis.size() + 1)

4656 return emitOpError("should return an index for each basis element and up "

4657 "to one extra index");

4658

4659 auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);

4660 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())

4661 return emitOpError(

4662 "mismatch between dynamic and static basis (kDynamic marker but no "

4663 "corresponding dynamic basis entry) -- this can only happen due to an "

4664 "incorrect fold/rewrite");

4665

4666 if (!llvm::all_of(staticBasis, [](int64_t v) {

4667 return v > 0 || ShapedType::isDynamic(v);

4668 }))

4669 return emitOpError("no basis element may be statically non-positive");

4670

4671 return success();

4672 }

4673

4674

4675

4676

4677

4678 static std::optional<SmallVector<int64_t>>

4682 uint64_t dynamicBasisIndex = 0;

4684 if (basis) {

4685 mutableDynamicBasis.erase(dynamicBasisIndex);

4686 } else {

4687 ++dynamicBasisIndex;

4688 }

4689 }

4690

4691

4692 if (dynamicBasisIndex == dynamicBasis.size())

4693 return std::nullopt;

4694

4698 if (!basisVal)

4699 staticBasis.push_back(ShapedType::kDynamic);

4700 else

4701 staticBasis.push_back(*basisVal);

4702 }

4703

4704 return staticBasis;

4705 }

4706

4707 LogicalResult

4708 AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,

4710 std::optional<SmallVector<int64_t>> maybeStaticBasis =

4712 adaptor.getDynamicBasis());

4713 if (maybeStaticBasis) {

4714 setStaticBasis(*maybeStaticBasis);

4715 return success();

4716 }

4717

4718

4719 if (getNumResults() == 1) {

4720 result.push_back(getLinearIndex());

4721 return success();

4722 }

4723

4724 if (adaptor.getLinearIndex() == nullptr)

4725 return failure();

4726

4727 if (!adaptor.getDynamicBasis().empty())

4728 return failure();

4729

4730 int64_t highPart = cast(adaptor.getLinearIndex()).getInt();

4731 Type attrType = getLinearIndex().getType();

4732

4734 if (hasOuterBound())

4735 staticBasis = staticBasis.drop_front();

4736 for (int64_t modulus : llvm::reverse(staticBasis)) {

4737 result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));

4738 highPart = llvm::divideFloorSigned(highPart, modulus);

4739 }

4741 std::reverse(result.begin(), result.end());

4742 return success();

4743 }

4744

4747 if (hasOuterBound()) {

4748 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)

4749 return getMixedValues(getStaticBasis().drop_front(),

4750 getDynamicBasis().drop_front(), builder);

4751

4752 return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),

4753 builder);

4754 }

4755

4756 return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);

4757 }

4758

4761 if (!hasOuterBound())

4763 return ret;

4764 }

4765

4766 namespace {

4767

4768

4769 struct DropUnitExtentBasis

4770 : public OpRewritePatternaffine::AffineDelinearizeIndexOp {

4772

4773 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,

4775 SmallVector replacements(delinearizeOp->getNumResults(), nullptr);

4776 std::optional zero = std::nullopt;

4777 Location loc = delinearizeOp->getLoc();

4779 if (!zero)

4780 zero = rewriter.createarith::ConstantIndexOp(loc, 0);

4781 return zero.value();

4782 };

4783

4784

4785

4787 for (auto [index, basis] :

4789 std::optional<int64_t> basisVal =

4791 if (basisVal == 1)

4792 replacements[index] = getZero();

4793 else

4794 newBasis.push_back(basis);

4795 }

4796

4797 if (newBasis.size() == delinearizeOp.getNumResults())

4799 "no unit basis elements");

4800

4801 if (!newBasis.empty()) {

4802

4803 auto newDelinearizeOp = rewriter.createaffine::AffineDelinearizeIndexOp(

4804 loc, delinearizeOp.getLinearIndex(), newBasis);

4805 int newIndex = 0;

4806

4807 for (auto &replacement : replacements) {

4808 if (replacement)

4809 continue;

4810 replacement = newDelinearizeOp->getResult(newIndex++);

4811 }

4812 }

4813

4814 rewriter.replaceOp(delinearizeOp, replacements);

4815 return success();

4816 }

4817 };

4818

4819

4820

4821

4822

4823

4824

4825

4826

4827

4828

4829 struct CancelDelinearizeOfLinearizeDisjointExactTail

4830 : public OpRewritePatternaffine::AffineDelinearizeIndexOp {

4832

4833 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,

4835 auto linearizeOp = delinearizeOp.getLinearIndex()

4836 .getDefiningOpaffine::AffineLinearizeIndexOp();

4837 if (!linearizeOp)

4839 "index doesn't come from linearize");

4840

4841 if (!linearizeOp.getDisjoint())

4843

4844 ValueRange linearizeIns = linearizeOp.getMultiIndex();

4845

4848 size_t numMatches = 0;

4849 for (auto [linSize, delinSize] : llvm::zip(

4850 llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {

4851 if (linSize != delinSize)

4852 break;

4853 ++numMatches;

4854 }

4855

4856 if (numMatches == 0)

4858 delinearizeOp, "final basis element doesn't match linearize");

4859

4860

4861 if (numMatches == linearizeBasis.size() &&

4862 numMatches == delinearizeBasis.size() &&

4863 linearizeIns.size() == delinearizeOp.getNumResults()) {

4864 rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());

4865 return success();

4866 }

4867

4868 Value newLinearize = rewriter.createaffine::AffineLinearizeIndexOp(

4869 linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),

4871 linearizeOp.getDisjoint());

4872 auto newDelinearize = rewriter.createaffine::AffineDelinearizeIndexOp(

4873 delinearizeOp.getLoc(), newLinearize,

4875 delinearizeOp.hasOuterBound());

4877 mergedResults.append(linearizeIns.take_back(numMatches).begin(),

4878 linearizeIns.take_back(numMatches).end());

4879 rewriter.replaceOp(delinearizeOp, mergedResults);

4880 return success();

4881 }

4882 };

4883

4884

4885

4886

4887

4888

4889

4890

4891

4892

4893

4894

4895

4896

4897 struct SplitDelinearizeSpanningLastLinearizeArg final

4900

4901 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,

4903 auto linearizeOp = delinearizeOp.getLinearIndex()

4904 .getDefiningOpaffine::AffineLinearizeIndexOp();

4905 if (!linearizeOp)

4907 "index doesn't come from linearize");

4908

4909 if (!linearizeOp.getDisjoint())

4911 "linearize isn't disjoint");

4912

4913 int64_t target = linearizeOp.getStaticBasis().back();

4914 if (ShapedType::isDynamic(target))

4916 linearizeOp, "linearize ends with dynamic basis value");

4917

4918 int64_t sizeToSplit = 1;

4919 size_t elemsToSplit = 0;

4921 for (int64_t basisElem : llvm::reverse(basis)) {

4922 if (ShapedType::isDynamic(basisElem))

4924 delinearizeOp, "dynamic basis element while scanning for split");

4925 sizeToSplit *= basisElem;

4926 elemsToSplit += 1;

4927

4928 if (sizeToSplit > target)

4930 "overshot last argument size");

4931 if (sizeToSplit == target)

4932 break;

4933 }

4934

4935 if (sizeToSplit < target)

4937 delinearizeOp, "product of known basis elements doesn't exceed last "

4938 "linearize argument");

4939

4940 if (elemsToSplit < 2)

4942 delinearizeOp,

4943 "need at least two elements to form the basis product");

4944

4945 Value linearizeWithoutBack =

4946 rewriter.createaffine::AffineLinearizeIndexOp(

4947 linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),

4948 linearizeOp.getDynamicBasis(),

4949 linearizeOp.getStaticBasis().drop_back(),

4950 linearizeOp.getDisjoint());

4951 auto delinearizeWithoutSplitPart =

4952 rewriter.createaffine::AffineDelinearizeIndexOp(

4953 delinearizeOp.getLoc(), linearizeWithoutBack,

4954 delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),

4955 delinearizeOp.hasOuterBound());

4956 auto delinearizeBack = rewriter.createaffine::AffineDelinearizeIndexOp(

4957 delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),

4958 basis.take_back(elemsToSplit), true);

4960 llvm::concat(delinearizeWithoutSplitPart.getResults(),

4961 delinearizeBack.getResults()));

4962 rewriter.replaceOp(delinearizeOp, results);

4963

4964 return success();

4965 }

4966 };

4967 }

4968

4969 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(

4972 .insert<CancelDelinearizeOfLinearizeDisjointExactTail,

4973 DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(

4974 context);

4975 }

4976

4977

4978

4979

4980

/// Builds an affine.linearize_index whose basis is given as SSA values.
/// NOTE(review): the middle parameters (original lines 4982-4983) and the
/// code that declares and fills dynamicBasis/staticBasis from `basis`
/// (lines 4987-4989) are missing from this rendering.
4981 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,

4984 bool disjoint) {

// A leading null Value is dropped from the basis; presumably it stands in
// for an omitted outer bound -- confirm against the op definition.
4985 if (!basis.empty() && basis.front() == Value())

4986 basis = basis.drop_front();

4990 staticBasis);

// Forward to the builder that takes separate dynamic and static basis lists.
4991 build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);

4992 }

4993

/// Builds an affine.linearize_index whose basis is given as OpFoldResults
/// (a mix of constants and SSA values).
/// NOTE(review): the middle parameters (original lines 4995-4997) and the
/// code producing dynamicBasis/staticBasis (lines 5001-5003) are missing from
/// this rendering.
4994 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,

4998 bool disjoint) {

// A leading null OpFoldResult is dropped, mirroring the Value-based builder;
// presumably it marks an omitted outer bound -- confirm.
4999 if (!basis.empty() && basis.front() == OpFoldResult())

5000 basis = basis.drop_front();

// Forward to the builder that takes separate dynamic and static basis lists.
5004 build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);

5005 }

5006

/// Convenience builder: forwards `basis` as the static basis together with an
/// empty dynamic basis (ValueRange{}).
/// NOTE(review): the parameter lines (original 5008-5010) are missing;
/// presumably `basis` is an all-static list (e.g. ArrayRef<int64_t>) -- confirm.
5007 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,

5011 build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);

5012 }

5013

/// Op verifier body. NOTE(review): the signature line (original 5014) is
/// missing from this rendering; presumably AffineLinearizeIndexOp::verify().
/// Checks two things:
///  - there is one static basis entry per index, or one fewer when the first
///    index has no basis entry (numIndexes == numBasisElems or
///    numBasisElems + 1);
///  - the number of ShapedType::kDynamic markers in the static basis equals
///    the number of dynamic basis operands.
5015 size_t numIndexes = getMultiIndex().size();

5016 size_t numBasisElems = getStaticBasis().size();

5017 if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)

5018 return emitOpError("should be passed a basis element for each index except "

5019 "possibly the first");

5020

// Every kDynamic marker must be paired with exactly one dynamic operand.
5021 auto dynamicMarkersCount =

5022 llvm::count_if(getStaticBasis(), ShapedType::isDynamic);

5023 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())

5024 return emitOpError(

5025 "mismatch between dynamic and static basis (kDynamic marker but no "

5026 "corresponding dynamic basis entry) -- this can only happen due to an "

5027 "incorrect fold/rewrite");

5028

5029 return success();

5030 }

5031

/// Folds affine.linearize_index, trying these steps in order:
///  1. Constant dynamic-basis operands are moved into the static basis in
///     place (setStaticBasis), and the op's own result is returned to signal
///     an in-place fold. NOTE(review): the callee on original line 5034 is
///     missing; presumably foldCstValueToCstAttrBasis.
///  2. An empty multi-index folds to the value returned on missing line 5042.
///  3. A single index folds to that index.
///  4. If every index is a constant and no dynamic basis operands remain, the
///     result is computed row-major: sum_i index_i * stride_i. The stride is
///     1 for the innermost index and is multiplied by each basis length
///     moving outward.
5032 OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {

5033 std::optional<SmallVector<int64_t>> maybeStaticBasis =

5035 adaptor.getDynamicBasis());

5036 if (maybeStaticBasis) {

5037 setStaticBasis(*maybeStaticBasis);

5038 return getResult();

5039 }

5040

// NOTE(review): the value returned for an empty multi-index (original line
// 5042) is missing from this rendering.
5041 if (getMultiIndex().empty())

5043

5044

5045 if (getMultiIndex().size() == 1)

5046 return getMultiIndex().front();

5047

// Constant folding needs every index to be a known constant...
5048 if (llvm::is_contained(adaptor.getMultiIndex(), nullptr))

5049 return nullptr;

5050

// ...and an entirely static basis.
5051 if (!adaptor.getDynamicBasis().empty())

5052 return nullptr;

5053

// NOTE(review): this int64_t arithmetic is unchecked. Very large basis or
// index constants could overflow, and signed overflow is UB. Consider
// llvm::checkedMul/checkedAdd and bail out of the fold on overflow.
5054 int64_t result = 0;

5055 int64_t stride = 1;

// Walk from the innermost index outward. zip_first iterates for the length
// of its first range, the static basis.
5056 for (auto [length, indexAttr] :

5057 llvm::zip_first(llvm::reverse(getStaticBasis()),

5058 llvm::reverse(adaptor.getMultiIndex()))) {

5059 result = result + cast(indexAttr).getInt() * stride;

5060 stride = stride * length;

5061 }

5062

// Without an outer bound the first index has no basis entry, so the loop
// above did not visit it. Scale it by the full product of the basis.
5063 if (!hasOuterBound())

5064 result =

5065 result +

5066 cast(adaptor.getMultiIndex().front()).getInt() * stride;

5067

// NOTE(review): the return of the folded constant (original line 5068) is
// missing from this rendering.
5069 }

5070

/// Returns the basis as OpFoldResults, omitting the outer bound (the first
/// basis entry) when the op has one.
/// NOTE(review): the signature and the declaration of `builder` (original
/// lines 5071-5072) are missing from this rendering.
5073 if (hasOuterBound()) {

// If the dropped outer bound was dynamic, drop its SSA operand as well so
// the static and dynamic lists stay aligned.
5074 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)

5075 return getMixedValues(getStaticBasis().drop_front(),

5076 getDynamicBasis().drop_front(), builder);

5077

5078 return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),

5079 builder);

5080 }

5081

// No outer bound: the stored basis already excludes it.
5082 return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);

5083 }

5084

/// NOTE(review): most of this function is missing from this rendering: the
/// signature and the lines that declare `ret` and adjust it (original
/// 5085-5086 and 5088). The visible code only shows that `ret` gets an extra
/// step when the op has no outer bound, and is then returned.
5087 if (!hasOuterBound())

5089 return ret;

5090 }

5091

5092 namespace {

5093

5094

5095

5096

5097

5098

5099

5100

5101

5102

/// Rewrites affine.linearize_index by removing (index, basis) pairs whose
/// basis element is the constant 1, provided the op is `disjoint` or the
/// index is the constant 0.
/// Such a component contributes index * stride, which is zero for index 0.
/// `disjoint` is presumably relied on to imply the index is in bounds and
/// hence 0 -- confirm against the op documentation.
/// Fails to match when no component can be dropped.
5103 struct DropLinearizeUnitComponentsIfDisjointOrZero final

5106

5107 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,

5109 ValueRange multiIndex = op.getMultiIndex();

5110 size_t numIndices = multiIndex.size();

5112 newIndices.reserve(numIndices);

5114 newBasis.reserve(numIndices);

5115

// Without an outer bound the first index has no basis element, so it is
// always kept.
5116 if (!op.hasOuterBound()) {

5117 newIndices.push_back(multiIndex.front());

5118 multiIndex = multiIndex.drop_front();

5119 }

5120

5122 for (auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {

// Keep components whose basis element is not a known constant 1.
// `basisEntry` is computed on missing line 5123, presumably as the constant
// value of basisElem.
5124 if (!basisEntry || *basisEntry != 1) {

5125 newIndices.push_back(index);

5126 newBasis.push_back(basisElem);

5127 continue;

5128 }

5129

// A unit-extent component is dropped only if the op is disjoint or the
// index is known to be 0. `indexValue` comes from missing line 5130.
5131 if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {

5132 newIndices.push_back(index);

5133 newBasis.push_back(basisElem);

5134 continue;

5135 }

5136 }

5137 if (newIndices.size() == numIndices)

5139 "no unit basis entries to replace");

5140

// Every component was dropped. The replacement (missing line 5142) is
// presumably a constant 0.
5141 if (newIndices.size() == 0) {

5143 return success();

5144 }

// Otherwise rebuild the op over the surviving components, keeping `disjoint`.
5146 op, newIndices, newBasis, op.getDisjoint());

5147 return success();

5148 }

5149 };

5150

/// Product helper. It is called as computeProduct(loc, rewriter, terms) by
/// CancelLinearizeOfDelinearizePortion.
/// NOTE(review): the signature (original lines 5151-5152) and several body
/// lines are missing from this rendering. From the visible code:
///  - a null term is returned immediately, so an absent bound makes the whole
///    product absent;
///  - constant terms are presumably folded into `result` (missing line 5161);
///  - non-constant terms are collected into `dynamicPart`;
///  - if `result` ends up a constant, a constant is returned (missing line
///    5168); otherwise an op is created from `result` and `dynamicPart`,
///    presumably an affine.apply of the product expression.
5153 int64_t nDynamic = 0;

5157 if (!term)

5158 return term;

5160 if (maybeConst) {

5162 } else {

5163 dynamicPart.push_back(cast(term));

5165 }

5166 }

5167 if (auto constant = dyn_cast(result))

5169 return builder.create(loc, result, dynamicPart).getResult();

5170 }

5171

5172

5173

5174

5175

5176

5177

5178

5179

5180

5181

5182

5183

5184

5185

5186

5187

5188

5189

5190

5191

5192

5193

5194

5195

5196

5197

5198

/// Canonicalizes affine.linearize_index operands that form a contiguous run
/// of results of one affine.delinearize_index, where the basis elements of
/// the two ops agree.
/// Each run of length >= 2 collapses into one linearize operand, whose basis
/// element is the product of the run's basis elements:
///  - if the run covers all of the delinearize's results, the delinearize's
///    linear index is used directly;
///  - otherwise the delinearize is rebuilt with the run's basis elements
///    merged into one. A residual delinearize then splits the merged value
///    back out, so other users of the old results still get the same values.
5199 struct CancelLinearizeOfDelinearizePortion final

5202

5203 private:

5204

5205

5206

5207

// A match is `length` consecutive linearize operands starting at
// `linStart`. They equal the delinearize results starting at `delinStart`.
// (The `delinearize` field itself is on missing lines 5208-5209.)
5210 unsigned linStart = 0;

5211 unsigned delinStart = 0;

5212 unsigned length = 0;

5213 };

5214

5215 public:

5216 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,

5219

5222

5223 ValueRange multiIndex = linearizeOp.getMultiIndex();

5224 unsigned numLinArgs = multiIndex.size();

5225 unsigned linArgIdx = 0;

5226

5227

// Greedily scan the linearize operands, left to right, for runs of results
// of a single delinearize op.
5229 while (linArgIdx < numLinArgs) {

5230 auto asResult = dyn_cast(multiIndex[linArgIdx]);

5231 if (!asResult) {

5232 linArgIdx++;

5233 continue;

5234 }

5235

5236 auto delinearizeOp =

5237 dyn_cast(asResult.getOwner());

5238 if (!delinearizeOp) {

5239 linArgIdx++;

5240 continue;

5241 }

5242

5243

5244

5245

5246

5247

5248

5249

5250

5251

5252

5253

5254

// The first element of a run is accepted in any of three cases:
//  - its bounds agree;
//  - both values sit at the front of their ops;
//  - the linearize is disjoint and the delinearize result is its first
//    output with no bound.
// delinBasis/linBasis are declared on missing lines; presumably they are
// padded bases with a null entry for an omitted outer bound.
5255 unsigned delinArgIdx = asResult.getResultNumber();

5257 OpFoldResult firstDelinBound = delinBasis[delinArgIdx];

5258 OpFoldResult firstLinBound = linBasis[linArgIdx];

5259 bool boundsMatch = firstDelinBound == firstLinBound;

5260 bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;

5261 bool knownByDisjoint =

5262 linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;

5263 if (!boundsMatch && !bothAtFront && !knownByDisjoint) {

5264 linArgIdx++;

5265 continue;

5266 }

5267

// Extend the run while the next linearize operand is the next delinearize
// result and their basis elements are equal.
5268 unsigned j = 1;

5269 unsigned numDelinOuts = delinearizeOp.getNumResults();

5270 for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;

5271 ++j) {

5272 if (multiIndex[linArgIdx + j] !=

5273 delinearizeOp.getResult(delinArgIdx + j))

5274 break;

5275 if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])

5276 break;

5277 }

5278

5279

5280

// Only runs of at least two elements are rewritten. Each delinearize op is
// matched at most once per linearize, so it is replaced at most once below.
5281 if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {

5282 linArgIdx++;

5283 continue;

5284 }

5285 matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j});

5286 linArgIdx += j;

5287 }

5288

5289 if (matches.empty())

5291 linearizeOp, "no run of delinearize outputs to deal with");

5292

5293

5294

5295

5297

// Rebuild the linearize operand and basis lists. Gaps between matches are
// copied through unchanged; each match collapses to a single element.
5299 newIndex.reserve(numLinArgs);

5301 newBasis.reserve(numLinArgs);

5302 unsigned prevMatchEnd = 0;

5303 for (Match m : matches) {

5304 unsigned gap = m.linStart - prevMatchEnd;

5305 llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));

5306 llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));

5307

5308 prevMatchEnd = m.linStart + m.length;

5309

// NOTE(review): the insertion point set under this guard (missing line
// 5311) determines where the new ops below are created.
5310 PatternRewriter::InsertionGuard g(rewriter);

5312

5314 linBasisRef.slice(m.linStart, m.length);

5315

5316

// newSize is the product of the matched basis elements.
5318 computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);

5319

5320

// The run spans every delinearize result: use the delinearize's input
// directly. Missing line 5325 presumably records an empty replacement so the
// zip_equal over delinearizeReplacements below stays aligned.
5321 if (m.length == m.delinearize.getNumResults()) {

5322 newIndex.push_back(m.delinearize.getLinearIndex());

5323 newBasis.push_back(newSize);

5324

5326 continue;

5327 }

5328

// Partial run: rebuild the delinearize with the matched basis elements
// replaced by their product.
5331 newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,

5332 newDelinBasis.begin() + m.delinStart + m.length);

5333 newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);

5334 auto newDelinearize = rewriter.create(

5335 m.delinearize.getLoc(), m.delinearize.getLinearIndex(),

5336 newDelinBasis);

5337

5338

5339

5340

// Split the merged result back into the original components, for other
// users of the old delinearize's results.
5341 Value combinedElem = newDelinearize.getResult(m.delinStart);

5342 auto residualDelinearize = rewriter.create(

5343 m.delinearize.getLoc(), combinedElem, basisToMerge);

5344

5345

5346

5347

// Replacement results for the old delinearize, in order:
//  1. the new delinearize's results before the merged element;
//  2. the residual split of the merged element;
//  3. the new delinearize's results after the merged element.
5348 llvm::append_range(newDelinResults,

5349 newDelinearize.getResults().take_front(m.delinStart));

5350 llvm::append_range(newDelinResults, residualDelinearize.getResults());

5351 llvm::append_range(

5352 newDelinResults,

5353 newDelinearize.getResults().drop_front(m.delinStart + 1));

5354

5355 delinearizeReplacements.push_back(newDelinResults);

5356 newIndex.push_back(combinedElem);

5357 newBasis.push_back(newSize);

5358 }

5359 llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));

5360 llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));

5362 linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());

5363

// Replace only the partially consumed delinearize ops. Full-run matches have
// an empty replacement list and are skipped.
5364 for (auto [m, newResults] :

5365 llvm::zip_equal(matches, delinearizeReplacements)) {

5366 if (newResults.empty())

5367 continue;

5368 rewriter.replaceOp(m.delinearize, newResults);

5369 }

5370

5371 return success();

5372 }

5373 };

5374

5375

5376

5377

5378

/// Drops the leading index of affine.linearize_index, presumably when it is
/// the constant 0. NOTE(review): the matching condition (original line 5386)
/// is missing from this rendering.
/// A single-index op is replaced by that index. Otherwise the op is rebuilt
/// without its first index, and also without the first basis element when
/// the op has an outer bound.
5379 struct DropLinearizeLeadingZero final

5382

5383 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,

5385 Value leadingIdx = op.getMultiIndex().front();

5387 return failure();

5388

5389 if (op.getMultiIndex().size() == 1) {

5390 rewriter.replaceOp(op, leadingIdx);

5391 return success();

5392 }

5393

// With an outer bound, the first basis element belongs to the dropped index,
// so it is removed as well.
5396 if (op.hasOuterBound())

5397 newMixedBasis = newMixedBasis.drop_front();

5398

5400 op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());

5401 return success();

5402 }

5403 };

5404 }

5405

/// Registers the canonicalization patterns for affine.linearize_index:
/// CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero and
/// DropLinearizeUnitComponentsIfDisjointOrZero.
/// NOTE(review): the parameter line (original 5407) is missing from this
/// rendering.
5406 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(

5408 patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,

5409 DropLinearizeUnitComponentsIfDisjointOrZero>(context);

5410 }

5411

5412

5413

5414

5415

5416 #define GET_OP_CLASSES

5417 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"

static Value getStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)

Maps the 2-dim memref shape to the 64-bit stride.

static AffineForOp buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)

Creates an affine loop from the bounds known to be constants.

static bool hasTrivialZeroTripCount(AffineForOp op)

Returns true if the affine.for has zero iterations in trivial cases.

static void composeMultiResultAffineMap(AffineMap &map, SmallVectorImpl< Value > &operands)

Composes the given affine map with the given list of operands, pulling in the maps from any affine....

static LogicalResult verifyMemoryOpIndexing(AffineMemOpTy op, AffineMapAttr mapAttr, Operation::operand_range mapOperands, MemRefType memrefType, unsigned numIndexOperands)

Verify common indexing invariants of affine.load, affine.store, affine.vector_load and affine....

static void printAffineMinMaxOp(OpAsmPrinter &p, T op)

static bool isResultTypeMatchAtomicRMWKind(Type resultType, arith::AtomicRMWKind op)

static bool remainsLegalAfterInline(Value value, Region *src, Region *dest, const IRMapping &mapping, function_ref< bool(Value, Region *)> legalityCheck)

Checks if value known to be a legal affine dimension or symbol in src region remains legal if the ope...

static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, DenseIntElementsAttr group, ValueRange operands, StringRef keyword)

Prints a lower(upper) bound of an affine parallel loop with max(min) conditions in it.

static void LLVM_ATTRIBUTE_UNUSED simplifyMapWithOperands(AffineMap &map, ArrayRef< Value > operands)

Simplify the map while exploiting information on the values in operands.

static OpFoldResult foldMinMaxOp(T op, ArrayRef< Attribute > operands)

Fold an affine min or max operation with the given operands.

static LogicalResult canonicalizeLoopBounds(AffineForOp forOp)

Canonicalize the bounds of the given loop.

static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims, unsigned numSymbols, ArrayRef< Value > operands)

Simplify expr while exploiting information from the values in operands.

static bool isValidAffineIndexOperand(Value value, Region *region)

static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)

static void composeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)

Iterate over operands and fold away all those produced by an AffineApplyOp iteratively.

static std::optional< int64_t > getUpperBound(Value iv)

Gets the constant upper bound on an affine.for iv.

static ParseResult parseBound(bool isLower, OperationState &result, OpAsmParser &p)

Parse a for operation loop bounds.

static std::optional< int64_t > getLowerBound(Value iv)

Gets the constant lower bound on an iv.

static void composeSetAndOperands(IntegerSet &set, SmallVectorImpl< Value > &operands)

Compose any affine.apply ops feeding into operands of the integer set set by composing the maps of su...

static LogicalResult replaceDimOrSym(AffineMap *map, unsigned dimOrSymbolPosition, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms)

Replace all occurrences of AffineExpr at position pos in map by the defining AffineApplyOp expression...

static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)

static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType)

Verify common invariants of affine.vector_load and affine.vector_store.

static void simplifyMinOrMaxExprWithOperands(AffineMap &map, ArrayRef< Value > operands, bool isMax)

Simplify the expressions in map while making use of lower or upper bounds of its operands.

static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, OperationState &result)

static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, Region *region)

Returns true if the 'index' dimension of the memref defined by memrefDefOp is a statically shaped one...

static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef< Value > operands, int64_t k)

Check if e is known to be: 0 <= e < k.

static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser, OperationState &result, MinMaxKind kind)

Parses an affine map that can contain a min/max for groups of its results, e.g., max(expr-1,...

static AffineForOp buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)

Creates an affine loop from the bounds that may or may not be constants.

static void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &printer)

Prints dimension and symbol list.

static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef< Value > operands)

Returns the largest known divisor of e.

static void legalizeDemotedDims(MapOrSet &mapOrSet, SmallVectorImpl< Value > &operands)

A valid affine dimension may appear as a symbol in affine.apply operations.

static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

static void buildAffineLoopNestImpl(OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, LoopCreatorTy &&loopCreatorFn)

Builds an affine loop nest, using "loopCreatorFn" to create individual loop operations.

static LogicalResult foldLoopBounds(AffineForOp forOp)

Fold the constant bounds of a loop.

static LogicalResult verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, unsigned numDims)

Utility function to verify that a set of operands are valid dimension and symbol identifiers.

static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region)

Returns true if the result of the dim op is a valid symbol for region.

static bool isQTimesDPlusR(AffineExpr e, ArrayRef< Value > operands, int64_t &div, AffineExpr &quotientTimesDiv, AffineExpr &rem)

Check if expression e is of the form d*e_1 + e_2 where 0 <= e_2 < d.

static ParseResult deduplicateAndResolveOperands(OpAsmParser &parser, ArrayRef< SmallVector< OpAsmParser::UnresolvedOperand >> operands, SmallVectorImpl< Value > &uniqueOperands, SmallVectorImpl< AffineExpr > &replacements, AffineExprKind kind)

Given a list of lists of parsed operands, populates uniqueOperands with unique operands.

static LogicalResult verifyAffineMinMaxOp(T op)

static void printBound(AffineMapAttr boundMap, Operation::operand_range boundOperands, const char *prefix, OpAsmPrinter &p)

static std::optional< SmallVector< int64_t > > foldCstValueToCstAttrBasis(ArrayRef< OpFoldResult > mixedBasis, MutableOperandRange mutableDynamicBasis, ArrayRef< Attribute > dynamicBasis)

Given mixed basis of affine.delinearize_index/linearize_index replace constant SSA values with the co...

static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map)

Canonicalize the result expression order of an affine map and return success if the order changed.

static Value getZero(OpBuilder &b, Location loc, Type elementType)

Get zero value for an element type.

static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)

A utility function used to materialize a constant for a given attribute and type.

static MLIRContext * getContext(OpFoldResult val)

static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)

Utility to check that all of the operations within 'src' can be inlined.

static int64_t getNumElements(Type t)

Compute the total number of elements in the given type, also taking into account nested types.

union mlir::linalg::@1203::ArityGroupAndKind::Kind kind

static Operation::operand_range getLowerBoundOperands(AffineForOp forOp)

static Operation::operand_range getUpperBoundOperands(AffineForOp forOp)

static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)

static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)

Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...

RetTy walkPostOrder(AffineExpr expr)

Base type for affine expression.

AffineExpr floorDiv(uint64_t v) const

AffineExprKind getKind() const

Return the classification for this type.

int64_t getLargestKnownDivisor() const

Returns the greatest known integral divisor of this affine expression.

MLIRContext * getContext() const

A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.

AffineMap getSliceMap(unsigned start, unsigned length) const

Returns the map consisting of length expressions starting from start.

MLIRContext * getContext() const

bool isFunctionOfDim(unsigned position) const

Return true if any affine expression involves AffineDimExpr position.

static AffineMap get(MLIRContext *context)

Returns a zero result affine map with no dimensions or symbols: () -> ().

AffineMap shiftDims(unsigned shift, unsigned offset=0) const

Replace dims[offset ...

unsigned getNumSymbols() const

unsigned getNumDims() const

ArrayRef< AffineExpr > getResults() const

bool isFunctionOfSymbol(unsigned position) const

Return true if any affine expression involves AffineSymbolExpr position.

unsigned getNumResults() const

AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const

This method substitutes any uses of dimensions and symbols (e.g.

unsigned getNumInputs() const

AffineMap shiftSymbols(unsigned shift, unsigned offset=0) const

Replace symbols[offset ...

AffineExpr getResult(unsigned idx) const

AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const

Sparse replace method.

static AffineMap getConstantMap(int64_t val, MLIRContext *context)

Returns a single constant result affine map.

AffineMap getSubMap(ArrayRef< unsigned > resultPos) const

Returns the map consisting of the resultPos subset.

LogicalResult constantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< Attribute > &results, bool *hasPoison=nullptr) const

Folds the results of the application of an affine map on the provided operands to a constant if possi...

static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)

Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...

@ Paren

Parens surrounding zero or more operands.

@ OptionalSquare

Square brackets supporting zero or more ops, or nothing.

virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0

Parse a colon followed by a type list, which must have at least one type.

virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0

Parse a list of comma-separated items with an optional delimiter.

virtual Builder & getBuilder() const =0

Return a builder which provides useful access to MLIRContext, global objects like types and attribute...

virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0

Parse a named dictionary into 'result' if it is present.

virtual ParseResult parseOptionalKeyword(StringRef keyword)=0

Parse the given keyword if present.

MLIRContext * getContext() const

virtual ParseResult parseRParen()=0

Parse a ) token.

virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0

Emit a diagnostic at the specified location and return failure.

ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)

Add the specified type to the end of the specified type list and return success.

virtual ParseResult parseOptionalRParen()=0

Parse a ) token if present.

virtual ParseResult parseLess()=0

Parse a '<' token.

virtual ParseResult parseEqual()=0

Parse a = token.

virtual ParseResult parseColonType(Type &result)=0

Parse a colon followed by a type.

virtual SMLoc getCurrentLocation()=0

Get the location of the next token and store it into the argument.

virtual SMLoc getNameLoc() const =0

Return the location of the original name token.

virtual ParseResult parseGreater()=0

Parse a '>' token.

virtual ParseResult parseLParen()=0

Parse a ( token.

virtual ParseResult parseType(Type &result)=0

Parse a type.

virtual ParseResult parseComma()=0

Parse a , token.

virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0

Parse an optional arrow followed by a type list.

virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0

Parse an arrow followed by a type list.

ParseResult parseKeyword(StringRef keyword)

Parse a given keyword.

virtual ParseResult parseAttribute(Attribute &result, Type type={})=0

Parse an arbitrary attribute of a given type and return it in result.

void printOptionalArrowTypeList(TypeRange &&types)

Print an optional arrow followed by a type list.

Attributes are known-constant values of operations.

Block represents an ordered list of Operations.

Operation * getTerminator()

Get the terminator operation of this block.

BlockArgument addArgument(Type type, Location loc)

Add one value to the argument list.

BlockArgListType getArguments()

This class is a general helper class for creating context-global objects like types,...

DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)

IntegerAttr getIntegerAttr(Type type, int64_t value)

AffineMap getDimIdentityMap()

AffineMap getMultiDimIdentityMap(unsigned rank)

AffineExpr getAffineSymbolExpr(unsigned position)

AffineExpr getAffineConstantExpr(int64_t constant)

DenseIntElementsAttr getI32TensorAttr(ArrayRef< int32_t > values)

Tensor-typed DenseIntElementsAttr getters.

IntegerAttr getI64IntegerAttr(int64_t value)

IntegerType getIntegerType(unsigned width)

BoolAttr getBoolAttr(bool value)

AffineMap getEmptyAffineMap()

Returns a zero result affine map with no dimensions or symbols: () -> ().

AffineMap getConstantAffineMap(int64_t val)

Returns a single constant result affine map with 0 dimensions and 0 symbols.

MLIRContext * getContext() const

AffineMap getSymbolIdentityMap()

ArrayAttr getArrayAttr(ArrayRef< Attribute > value)

ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)

An attribute that represents a reference to a dense integer vector or tensor object.

This is the interface that must be implemented by the dialects of operations to be inlined.

DialectInlinerInterface(Dialect *dialect)

This is a utility class for mapping one set of IR entities to another.

auto lookup(T from) const

Lookup a mapped value within the map.

An integer set representing a conjunction of one or more affine equalities and inequalities.

unsigned getNumDims() const

static IntegerSet get(unsigned dimCount, unsigned symbolCount, ArrayRef< AffineExpr > constraints, ArrayRef< bool > eqFlags)

MLIRContext * getContext() const

unsigned getNumInputs() const

ArrayRef< AffineExpr > getConstraints() const

ArrayRef< bool > getEqFlags() const

Returns the equality bits, which specify whether each of the constraints is an equality or inequality...

unsigned getNumSymbols() const

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

This class provides a mutable adaptor for a range of operands.

void erase(unsigned subStart, unsigned subLen=1)

Erase the operands within the given sub-range.

NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...

void pop_back()

Pop last element from list.

Attribute erase(StringAttr name)

Erase the attribute with the given name from the list.

The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...

virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0

Parses a region.

virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0

Parse a single argument with the following syntax:

ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)

Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...

virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0

Parse zero or more arguments with a specified surrounding delimiter.

virtual ParseResult parseAffineMapOfSSAIds(SmallVectorImpl< UnresolvedOperand > &operands, Attribute &map, StringRef attrName, NamedAttrList &attrs, Delimiter delimiter=Delimiter::Square)=0

Parses an affine map attribute where dims and symbols are SSA operands.

ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)

Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)

virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0

Resolve an operand to an SSA value, emitting an error on failure.

ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)

Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...

virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0

Parse a single SSA value operand name along with a result number if allowResultNumber is true.

virtual ParseResult parseAffineExprOfSSAIds(SmallVectorImpl< UnresolvedOperand > &dimOperands, SmallVectorImpl< UnresolvedOperand > &symbOperands, AffineExpr &expr)=0

Parses an affine expression where dims and symbols are SSA operands.

virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0

Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...

This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...

virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0

If the specified operation has attributes, print out an attribute dictionary with their values.

virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands)=0

Prints an affine expression of SSA ids with SSA id names used instead of dims and symbols.

virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands)=0

Prints an affine map of SSA ids, where SSA id names are used in place of dims/symbols.

virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0

Prints a region.

virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0

Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...

virtual void printOperand(Value value)=0

Print implementations for various things an operation contains.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

Block::iterator getInsertionPoint() const

Returns the current insertion point of the builder.

void setInsertionPointToStart(Block *block)

Sets the insertion point to the start of the specified block.

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

Listener * getListener() const

Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.

Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)

Add new block with 'argTypes' arguments and set the insertion point to the end of it.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

Block * getInsertionBlock() const

Return the block the current insertion point belongs to.

This class represents a single result from folding an operation.

This class represents an operand of an operation.

A trait of region holding operations that defines a new scope for polyhedral optimization purposes.

This class provides the API for ops that are known to be isolated from above.

A trait used to provide symbol table functionalities to a region operation.

This class implements the operand iterators for the Operation class.

Operation is the basic unit of execution within MLIR.

bool hasTrait()

Returns true if the operation was registered with a particular trait, e.g.

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

Operation * getParentOp()

Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...

operand_range getOperands()

Returns an iterator on the underlying Value's.

Region * getParentRegion()

Returns the region to which the instruction belongs.

bool isProperAncestor(Operation *other)

Return true if this operation is a proper ancestor of the other operation.

operand_range::iterator operand_iterator

InFlightDiagnostic emitOpError(const Twine &message={})

Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

This class represents a point being branched from in the methods of the RegionBranchOpInterface.

bool isParent() const

Returns true if branching from the parent op.

This class represents a successor of a region.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

Operation * getParentOp()

Return the parent operation this region is attached to.

RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)

Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.

This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure,...

virtual void eraseBlock(Block *block)

This method erases all operations in a block.

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements...

void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)

Inline the operations of block 'source' into the end of block 'dest'.

virtual void finalizeOpModification(Operation *op)

This method is used to signal the end of an in-place modification of the given operation.

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)

Find uses of from and replace them with to if the functor returns true.

void modifyOpInPlace(Operation *root, CallableT &&callable)

This method is a utility wrapper around an in-place modification of an operation.

virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)

Inline the operations of block 'source' into block 'dest' before the given position.

virtual void startOpModification(Operation *op)

This method is used to notify the rewriter that an in-place operation modification is about to happen...

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

This class represents a specific instance of an effect.

static DerivedEffect * get()

Returns a unique instance for the derived effect class.

static DefaultResource * get()

Returns a unique instance for the given effect class.

std::vector< SmallVector< int64_t, 8 > > operandExprStack

static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)

Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

AffineBound represents a lower or upper bound in the for operation.

AffineDmaStartOp starts a non-blocking DMA operation that transfers data from a source memref to a de...

AffineDmaWaitOp blocks until the completion of a DMA operation associated with the tag element 'tag[i...

An AffineValueMap is an affine map plus its ML value operands and results for analysis purposes.

LogicalResult canonicalize()

Attempts to canonicalize the map and operands.

ArrayRef< Value > getOperands() const

AffineExpr getResult(unsigned i)

AffineMap getAffineMap() const

unsigned getNumResults() const

Operation * getOwner() const

Return the owner of this operand.

constexpr auto RecursivelySpeculatable

Speculatability

This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...

constexpr auto NotSpeculatable

void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)

Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only ano...

void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)

Given an affine map map and its input operands, this method composes into map, maps of AffineApplyOps...

void extractForInductionVars(ArrayRef< AffineForOp > forInsts, SmallVectorImpl< Value > *ivs)

Extracts the induction variables from a list of AffineForOps and places them in the output argument i...

bool isValidDim(Value value)

Returns true if the given Value can be used as a dimension id in the region of the closest surroundin...

SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Variant of makeComposedFoldedAffineApply suitable for multi-result maps.

bool isAffineInductionVar(Value val)

Returns true if the provided value is the induction variable of an AffineForOp or AffineParallelOp.

AffineForOp getForInductionVarOwner(Value val)

Returns the loop parent of an induction variable.

AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...

void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)

Modifies both map and operands in-place so as to:

OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...

bool isAffineForInductionVar(Value val)

Returns true if the provided value is the induction variable of an AffineForOp.

OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...

bool isTopLevelValue(Value value)

A utility function to check if a value is defined at the top level of an op with trait AffineScope or...

Region * getAffineAnalysisScope(Operation *op)

Returns the closest region enclosing op that is held by a non-affine operation; nullptr if there is n...

void canonicalizeSetAndOperands(IntegerSet *set, SmallVectorImpl< Value > *operands)

Canonicalizes an integer set the same way canonicalizeMapAndOperands does for affine maps.

void extractInductionVars(ArrayRef< Operation * > affineOps, SmallVectorImpl< Value > &ivs)

Extracts the induction variables from a list of either AffineForOp or AffineParallelOp and places the...

bool isValidSymbol(Value value)

Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...

OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...

AffineParallelOp getAffineParallelInductionVarOwner(Value val)

Returns the AffineParallelOp whose induction variables include the provided value.

Region * getAffineScope(Operation *op)

Returns the closest region enclosing op that is held by an operation with trait AffineScope; nullptr ...

ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< Value > &operands, unsigned &numDims)

Parses dimension and symbol list.

bool isAffineParallelInductionVar(Value val)

Returns true if val is the induction variable of an AffineParallelOp.

AffineMinOp makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)

Returns an AffineMinOp obtained by composing map and operands with AffineApplyOps supplying those ope...

BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)

Return a MemRefType to which the type of the given value can be bufferized.

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)

This is a common utility used for patterns of the form "someop(memref.cast) -> someop".

QueryRef parse(llvm::StringRef line, const QuerySession &qs)

Include the generated interface declarations.

AffineMap simplifyAffineMap(AffineMap map)

Simplifies an affine map by simplifying its underlying AffineExpr results.

bool matchPattern(Value value, const Pattern &pattern)

Entry point for matching a pattern over a Value.

OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)

Convert int64_t to integer attributes of index type and return them as OpFoldResult.

const FrozenRewritePatternSet GreedyRewriteConfig bool * changed

AffineMap removeDuplicateExprs(AffineMap map)

Returns a map with the same dimension and symbol count as map, but whose results are the unique affin...

std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)

If ofr is a constant integer or an IntegerAttr, return the integer.

std::optional< int64_t > getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t >> constLowerBounds, ArrayRef< std::optional< int64_t >> constUpperBounds, bool isUpper)

Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)

Given the strides together with a linear index in the dimension space, return the vector-space offset...

InFlightDiagnostic emitError(Location loc)

Utility method to emit an error message using this location.

bool isPure(Operation *op)

Returns true if the given operation is pure, i.e., is speculatable that does not touch memory.

int64_t computeProduct(ArrayRef< int64_t > basis)

Self-explanatory.

@ CeilDiv

RHS of ceildiv is always a constant or a symbolic expression.

@ Mod

RHS of mod is always a constant or a symbolic expression with a positive value.

@ DimId

Dimensional identifier.

@ FloorDiv

RHS of floordiv is always a constant or a symbolic expression.

@ SymbolId

Symbolic identifier.

AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)

std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn

A function that returns the additional yielded values during replaceWithAdditionalYields.

detail::constant_int_predicate_matcher m_Zero()

Matches a constant scalar / vector splat / tensor splat integer zero.

const FrozenRewritePatternSet & patterns

void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)

Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...

AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

OpFoldResult getAsOpFoldResult(Value val)

Given a value, try to extract a constant Attribute.

SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)

Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...

detail::constant_op_matcher m_Constant()

Matches a constant foldable operation.

AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)

These free functions allow clients of the API to not use classes in detail.

LogicalResult verify(Operation *op, bool verifyRecursively=true)

Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...

AffineMap foldAttributesIntoMap(Builder &b, AffineMap map, ArrayRef< OpFoldResult > operands, SmallVector< Value > &remainingValues)

Fold all attributes among the given operands into the affine map.

AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)

Canonicalize the affine map result expression order of an affine min/max operation.

LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override

LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override

Remove duplicated expressions in affine min/max ops.

LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override

Merge an affine min/max op to its consumers if its consumer is also an affine min/max op.

LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override

This is the representation of an operand reference.

This class represents a listener that may be used to hook into various actions within an OpBuilder.

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...

OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})

Patterns must specify the root operation name they match against, and can also specify the benefit of...

This represents an operation in an abstracted form, suitable for use with the builder APIs.

T & getOrAddProperties()

Get (or create) a properties of the provided type to be set on the operation on creation.

SmallVector< Value, 4 > operands

void addOperands(ValueRange newOperands)

void addAttribute(StringRef name, Attribute attr)

Add an attribute with the specified name.

void addTypes(ArrayRef< Type > newTypes)

SmallVector< std::unique_ptr< Region >, 1 > regions

Regions that the op will hold.

SmallVector< Type, 4 > types

Types of the results of this operation.

Region * addRegion()

Create a region that should be attached to the operation.

Eliminates variable at the specified position using Fourier-Motzkin variable elimination.