MLIR: lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

42

43 #include

44 #include

45

59 #include "llvm/ADT/DenseSet.h"

60 #include "llvm/ADT/SetVector.h"

61

63

64

65

66 #define DEBUG_TYPE "one-shot-analysis"

67

68 using namespace mlir;

70

71 static bool isaTensor(Type t) { return isa(t); }

72

73

74

75

76

77

78

79

80

81

83

85 "__opresult_alias_set_attr__";

86

88

89

95 cast(attr).getAsValueRange()));

96 } else {

99 if (isa(opOperand.get().getType()))

101 }

102 inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false";

104 OpBuilder(op).getStrArrayAttr(inPlaceVector));

105 }

106

107

108

109

110

114

117 if (isa(v.getType()))

120 for (Block &b : r.getBlocks())

121 for (auto bbArg : b.getArguments())

122 if (isa(bbArg.getType()))

124 });

125

126

127 op->walk([&](BufferizableOpInterface bufferizableOp) {

128 if (!options.isOpAllowed(bufferizableOp))

130 for (OpOperand &opOperand : bufferizableOp->getOpOperands())

131 if (isa(opOperand.get().getType()))

132 if (bufferizableOp.mustBufferizeInPlace(opOperand, *this))

135 });

136 }

137

140 auto leaderIt = equivalentInfo.findLeader(v);

141 for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;

142 ++mit) {

143 fun(*mit);

144 }

145 }

146

149 auto leaderIt = aliasInfo.findLeader(v);

150 for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {

151 fun(*mit);

152 }

153 }

154

156 Value v2) const {

157 return equivalentInfo.isEquivalent(v1, v2);

158 }

159

161 Value v2) const {

162 return aliasInfo.isEquivalent(v1, v2);

163 }

164

166 if (inplaceBufferized.contains(&operand))

167 return;

168 inplaceBufferized.insert(&operand);

170 aliasInfo.unionSets(alias.value, operand.get());

171 ++statNumTensorInPlace;

172 }

173

175 assert(!inplaceBufferized.contains(&operand) &&

176 "OpOperand was already decided to bufferize inplace");

177 ++statNumTensorOutOfPlace;

178 }

179

181 aliasInfo.insert(v);

182 equivalentInfo.insert(v);

183 }

184

187

189 if (!bufferizableOp)

191

192

194 if (!isa(opResult.getType()))

195 continue;

196

197

198

199 if (opResult.getUses().empty())

200 continue;

201

202

203 OpOperand *opOperand = &(*opResult.getUses().begin());

205 for (OpOperand &use : opResult.getUses())

206 undefinedTensorUses.insert(&use);

207 }

208

210 });

211 }

212

214 return undefinedTensorUses.contains(opOperand);

215 }

216

218 return inplaceBufferized.contains(&opOperand);

219 }

220

222 bool isWritten = false;

226 isWritten = true;

227 });

228 return isWritten;

229 }

230

232

233

234 if (auto bufferizableOp =

236 return bufferizableOp.isWritable(value, *this);

237

238

239 return false;

240 }

241

243 aliasInfo.unionSets(v1, v2);

244 }

245

247 equivalentInfo.unionSets(v1, v2);

248 }

249

251

252

253

254

255

256

259

260 if (!state.bufferizesToMemoryWrite(opOperand))

261 return false;

262

263 return state.isInPlace(opOperand);

264 }

265

266

267

270 do {

271

272

274 return false;

276 return true;

278 return false;

279 }

280

281

282

283

284

285

286

287

288

289

290

291

292

293

294

295

296

297

298

299

300

301

302

303

304

305

306

307

308

309

310

311

312

313

314

315

316

317

318

319

320

321

322

323

324

325

326

327

328

329

330

331

332

333

334

335

336

337

338

339

340

341

342

343

344

345

346

347

348

349

350

351

352

353

354

355

360 for (Value def : definitions) {

362 state.getEnclosingRepetitiveRegion(uRead->getOwner(), options);

363 Region *rDef = state.getEnclosingRepetitiveRegion(def, options);

364

365

366

367 if (rRead == rDef)

368 continue;

369

370

371

372 while (true) {

374 if (nextRegion == rDef)

375 break;

376 assert(nextRegion && "expected to find another repetitive region");

377 rRead = nextRegion;

378 }

379

380

382 return false;

383 }

384

385 return true;

386 }

387

388

389

390

391

392

393

394

395

396

397

398

399

400

401

405

406

407

410 return true;

411

414 for (Value def : definitions) {

415 Block *defBlock = def.getParentBlock();

416 if (readBlock->isReachable(writeBlock, {defBlock}) &&

417 writeBlock->isReachable(readBlock, {defBlock}))

418 return false;

419 }

420

421 return true;

422 }

423

429 }

430

431

433 Value definition) {

434 static uint64_t counter = 0;

437

439 std::string id = "C_" + std::to_string(counter++);

440

441 std::string conflictingWriteAttr =

442 id +

443 "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) +

444 "]";

445 conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr());

446

447 std::string readAttr =

448 id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";

450

451 if (auto opResult = dyn_cast(definition)) {

452 std::string defAttr =

453 id + "[DEF: result " + std::to_string(opResult.getResultNumber()) + "]";

454 opResult.getDefiningOp()->setAttr(defAttr, b.getUnitAttr());

455 } else {

456 auto bbArg = cast(definition);

457 std::string defAttr =

458 id + "[DEF: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";

459 bbArg.getOwner()->getParentOp()->setAttr(defAttr, b.getUnitAttr());

460 }

461 }

462

463

464

465

466

467

468

469

470

475 config.followEquivalentOnly = true;

476 config.alwaysIncludeLeaves = false;

477 config.followSameTypeOrCastsOnly = true;

478 return !state

479 .findValueInReverseUseDefChain(

480 start, [&](Value v) { return v == other; }, config)

481 .empty();

482 }

483

484

485

488 SubsetInsertionOpInterface subsetOp) {

489 auto matchingSubset = [&](Value val) {

490 if (auto opResult = dyn_cast(val))

491 if (subsetOp.isEquivalentSubset(opResult, [&](Value v1, Value v2) {

492 return state.areEquivalentBufferizedValues(v1, v2);

493 }))

494 return true;

495 return false;

496 };

497

498

500 state.findValueInReverseUseDefChain(opOperand, matchingSubset);

501 return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));

502 }

503

504

505

506

507

508

514

515

516

517 if (auto subsetOp = dyn_cast(readingOp)) {

518

519

520

521

522

523

524

525 if (uRead == &subsetOp.getDestinationOperand() &&

527

528

529

530

531

532

533

534

535

536

537

538

539 return true;

540

541 if (uRead == &subsetOp.getSourceOperand() &&

542 uConflictingWrite == &subsetOp.getDestinationOperand() &&

544

545

546

547

548

549

550

551

552 return true;

553 }

554

555

556 if (auto subsetOp =

557 dyn_cast(conflictingWritingOp))

558

559

560

561

562

563

564

565

566

567

568

569

570

571

572

573

574 if (uConflictingWrite == &subsetOp.getDestinationOperand() &&

575 state.areEquivalentBufferizedValues(

576 uRead->get(), subsetOp.getSourceOperand().get()) &&

578 return true;

579

580 return false;

581 }

582

583

584

585

586

587

588

589

590 static bool

596

597

598

599

600 if (options.checkParallelRegions && !usesRead.empty()) {

601 for (OpOperand *uConflictingWrite : usesWrite) {

602

603

604

605

606

607

609 state.findValueInReverseUseDefChain(uConflictingWrite, [&](Value v) {

610 return state.bufferizesToMemoryWrite(v);

611 });

612 assert(!definitionsOrLeaves.empty() &&

613 "expected at least one definition or leaf");

614

615

616

617 for (Value def : definitionsOrLeaves) {

619 getParallelRegion(uConflictingWrite->getOwner()->getParentRegion(),

621 LLVM_DEBUG(

622 llvm::dbgs()

623 << "\n- bufferizes out-of-place due to parallel region:\n");

624 LLVM_DEBUG(llvm::dbgs()

625 << " unConflictingWrite = operand "

626 << uConflictingWrite->getOperandNumber() << " of "

627 << *uConflictingWrite->getOwner() << "\n");

628 return true;

629 }

630 }

631 }

632 }

633

634 for (OpOperand *uRead : usesRead) {

635 Operation *readingOp = uRead->getOwner();

636 LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n");

637 LLVM_DEBUG(llvm::dbgs() << " uRead = operand " << uRead->getOperandNumber()

638 << " of " << *readingOp << "\n");

639

640

641

642

643

644

645

646

647

648

649

650 const SetVector &definitions = state.findDefinitionsCached(uRead);

651 if (definitions.empty()) {

652

653 LLVM_DEBUG(llvm::dbgs()

654 << " no conflict: read value has no definitions\n");

655 continue;

656 }

657

658

659

660 for (OpOperand *uConflictingWrite : usesWrite) {

661 LLVM_DEBUG(llvm::dbgs() << " unConflictingWrite = operand "

662 << uConflictingWrite->getOperandNumber() << " of "

663 << *uConflictingWrite->getOwner() << "\n");

664

665

666

667 bool useDominance =

669 LLVM_DEBUG(llvm::dbgs() << "\n- useDominance = " << useDominance << "\n");

670

671

672

673 Operation *conflictingWritingOp = uConflictingWrite->getOwner();

674

675

676

677 if (useDominance) {

678

679

680

681

682

683

684 if (happensBefore(readingOp, conflictingWritingOp, domInfo)) {

685 LLVM_DEBUG(llvm::dbgs()

686 << " no conflict: read happens before write\n");

687 continue;

688 }

689

690

691

692

693

694

695

696

697 if (uConflictingWrite == uRead) {

698 LLVM_DEBUG(llvm::dbgs()

699 << " no conflict: read and write are same use\n");

700 continue;

701 }

702

703

704

705

706

707

708 if (state.insideMutuallyExclusiveRegions(readingOp,

709 conflictingWritingOp)) {

710 LLVM_DEBUG(llvm::dbgs() << " no conflict: read and write are in "

711 "mutually exclusive regions\n");

712 continue;

713 }

714

715

716

717

718 if (conflictingWritingOp == readingOp) {

719 if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {

720 if (bufferizableOp.bufferizesToElementwiseAccess(

721 state, {uRead, uConflictingWrite})) {

723 state, uRead, uConflictingWrite->get()) ||

725 state, uConflictingWrite, uRead->get())) {

726 LLVM_DEBUG(

727 llvm::dbgs()

728 << " no conflict: op bufferizes to element-wise access\n");

729 continue;

730 }

731 }

732 }

733 }

734 }

735

736

738 LLVM_DEBUG(llvm::dbgs() << " no conflict: non-conflicting subsets\n");

739 continue;

740 }

741

742

743 if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {

744 if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) {

745 LLVM_DEBUG(llvm::dbgs()

746 << " no conflict: op interace of reading op says 'no'\n");

747 continue;

748 }

749 }

750

751 if (conflictingWritingOp != readingOp) {

752 if (auto bufferizableOp =

753 options.dynCastBufferizableOp(conflictingWritingOp)) {

754 if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,

755 state)) {

756 LLVM_DEBUG(

757 llvm::dbgs()

758 << " no conflict: op interace of writing op says 'no'\n");

759 continue;

760 }

761 }

762 }

763

764

765 for (Value definition : definitions) {

766 LLVM_DEBUG(llvm::dbgs() << " * definition = " << definition << "\n");

767

768

769 if (Operation *defOp = definition.getDefiningOp()) {

770 if (happensBefore(conflictingWritingOp, defOp, domInfo)) {

771

772 LLVM_DEBUG(llvm::dbgs()

773 << " no conflict: write happens before definition\n");

774 continue;

775 }

776

777 if (defOp->isProperAncestor(conflictingWritingOp)) {

778 LLVM_DEBUG(

779 llvm::dbgs()

780 << " no conflict: write is contained in definition\n");

781 continue;

782 }

783 } else {

784 auto bbArg = cast(definition);

785 Block *block = bbArg.getOwner();

787 LLVM_DEBUG(llvm::dbgs() << " no conflict: definition is bbArg "

788 "and write happens outside of block\n");

789

790

791 continue;

792 }

793 }

794

795

796

797 AliasingValueList aliases = state.getAliasingValues(*uConflictingWrite);

799 aliases.getAliases()[0].value == definition) {

800 LLVM_DEBUG(llvm::dbgs()

801 << " no conflict: definition and write are same\n");

802 continue;

803 }

804

805

806

807 if (options.printConflicts)

809 LLVM_DEBUG(llvm::dbgs() << " => RaW CONFLICT FOUND\n");

810 return true;

811 }

812 }

813 }

814

815 return false;

816 }

817

818

821 state.applyOnAliases(root, [&](Value alias) {

822 for (auto &use : alias.getUses())

823

825 res.insert(&use);

826 });

827 }

828

829

832 state.applyOnAliases(root, [&](Value alias) {

833 for (auto &use : alias.getUses()) {

834

835 if (state.bufferizesToMemoryRead(use)) {

836 res.insert(&use);

837 continue;

838 }

839

840

841

842

843

844

845

846

847

848

849

850

851

852

853

854

855 if (!state.bufferizesToMemoryWrite(use)) {

856 AliasingValueList aliases = state.getAliasingValues(use);

857 if (llvm::any_of(aliases, [&](AliasingValue a) {

858 return state.isValueRead(a.value);

859 }))

860 res.insert(&use);

861 }

862 }

863 });

864 }

865

866

867

868

869

870

871

872

873

874

875

876

877

878

879

880

881

882

883

884

885

886

887

888

889

890

891

892

893

894

898

902 for (AliasingValue alias : state.getAliasingValues(operand)) {

905 }

906 if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))

907 usesWrite.insert(&operand);

908

910 }

911

912

914 static int64_t counter = 0;

916 std::string id = "W_" + std::to_string(counter++);

917 if (auto opResult = dyn_cast(value)) {

918 std::string attr = id + "[NOT-WRITABLE: result " +

919 std::to_string(opResult.getResultNumber()) + "]";

920 opResult.getDefiningOp()->setAttr(attr, b.getUnitAttr());

921 } else {

922 auto bbArg = cast(value);

923 std::string attr = id + "[NOT-WRITABLE: bbArg " +

924 std::to_string(bbArg.getArgNumber()) + "]";

925 bbArg.getOwner()->getParentOp()->setAttr(attr, b.getUnitAttr());

926 }

927 }

928

929

930

931 static bool

934 bool checkConsistencyOnly = false) {

935 bool foundWrite =

936 !checkConsistencyOnly && state.bufferizesToMemoryWrite(operand);

937

938 if (!foundWrite) {

939

942 for (AliasingValue alias : state.getAliasingValues(operand))

944 foundWrite = !usesWrite.empty();

945 }

946

947 if (!foundWrite)

948 return false;

949

950

951 bool foundReadOnly = false;

952 auto checkReadOnly = [&](Value v) {

953 if (!state.isWritable(v)) {

954 foundReadOnly = true;

955 if (state.getOptions().printConflicts)

957 }

958 };

959 state.applyOnAliases(operand.get(), checkReadOnly);

960 for (AliasingValue alias : state.getAliasingValues(operand))

961 state.applyOnAliases(alias.value, checkReadOnly);

962 if (foundReadOnly) {

963 LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");

964 return true;

965 }

966

967 return false;

968 }

969

970

971

972

973

974

976 OneShotAnalysisState::findDefinitionsCached(OpOperand *opOperand) {

977 Value value = opOperand->get();

978 if (!cachedDefinitions.count(value))

980 return cachedDefinitions[value];

981 }

982

985 cachedDefinitions.clear();

986 }

987

988

989 static LogicalResult

992 LLVM_DEBUG(

993 llvm::dbgs() << "//===-------------------------------------------===//\n"

995 << " of " << *operand.getOwner() << "\n");

996

997 bool foundInterference =

1000

1001 if (foundInterference)

1002 state.bufferizeOutOfPlace(operand);

1003 else

1004 state.bufferizeInPlace(operand);

1005

1006 LLVM_DEBUG(llvm::dbgs()

1007 << "//===-------------------------------------------===//\n");

1008 return success();

1009 }

1010

1011 LogicalResult

1015 if (isa(opOperand.get().getType()))

1017 return failure();

1018 return success();

1019 }

1020

1021

1025 if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) {

1026 for (OpResult opResult : op->getOpResults()) {

1027 if (!isa(opResult.getType()))

1028 continue;

1031

1032 continue;

1033

1034 Value firstOperand = aliases.begin()->opOperand->get();

1035 bool allEquivalent = true;

1038 bool isInPlace = state.isInPlace(*alias.opOperand);

1039 Value operand = alias.opOperand->get();

1040 if (isEquiv && isInPlace && alias.isDefinite) {

1041

1042

1043 state.unionEquivalenceClasses(opResult, operand);

1044 allEquivalent = false;

1045 break;

1046 }

1047 if (!isEquiv || !isInPlace)

1048 allEquivalent = false;

1049 if (!state.areEquivalentBufferizedValues(operand, firstOperand))

1050 allEquivalent = false;

1051 }

1052

1053

1054

1055

1056

1057

1058

1059

1060

1061

1062

1063 if (allEquivalent && !bufferizableOp.bufferizesToAllocation(opResult))

1064 state.unionEquivalenceClasses(opResult, firstOperand);

1065 }

1066 }

1067 }

1068 }

1069

1070

1071

1073

1076

1078 return;

1079 ops.push_back(op);

1080 });

1081

1083 }

1084

1085

1090

1091

1093 if (!traversedOps.insert(term))

1094 return;

1095

1096

1098 for (Value v : term->getOperands()) {

1099 if (!isa(v.getType()))

1100 continue;

1101 auto opResult = dyn_cast(v);

1102 if (!opResult)

1103 continue;

1104 worklist.push_back(opResult);

1105 }

1106 while (!worklist.empty()) {

1107 OpResult opResult = worklist.pop_back_val();

1109 if (!traversedOps.insert(defOp))

1110 continue;

1111 if (!term->getParentRegion()->findAncestorOpInRegion(*defOp))

1112 continue;

1114 for (auto alias : aliases) {

1115 Value v = alias.opOperand->get();

1116 if (!isa(v.getType()))

1117 continue;

1118 auto opResult = dyn_cast(v);

1119 if (!opResult)

1120 continue;

1121 worklist.push_back(opResult);

1122 }

1123 }

1124 });

1125

1126

1130 result.push_back(op);

1131 });

1132 return result;

1133 }

1134

1139

1141 if (heuristic ==

1144 } else {

1146

1148 return;

1149 orderedOps.push_back(op);

1150 });

1151 switch (heuristic) {

1153

1154 std::reverse(orderedOps.begin(), orderedOps.end());

1155 break;

1156 }

1158

1159 break;

1160 }

1162 assert(getOptions().analysisFuzzerSeed &&

1163 "expected that fuzzer seed it set");

1164

1165

1166

1167

1168 std::mt19937 g(getOptions().analysisFuzzerSeed);

1169 llvm::shuffle(orderedOps.begin(), orderedOps.end(), g);

1170 break;

1171 }

1172 default: {

1173 llvm_unreachable("unsupported heuristic");

1174 }

1175 }

1176 }

1177

1178

1179 for (Operation *op : orderedOps)

1181 return failure();

1182

1184 return success();

1185 }

1186

1187

1188

1189 static LogicalResult

1193

1194

1195

1196

1197 WalkResult walkResult = op->walk([&](BufferizableOpInterface op) {

1198

1199 if (options.isOpAllowed(op.getOperation()))

1201

1202

1203 if (!op.supportsUnstructuredControlFlow()) {

1204 for (Region &r : op->getRegions()) {

1205 if (r.getBlocks().size() > 1) {

1206 op->emitOpError("op or BufferizableOpInterface implementation does "

1207 "not support unstructured control flow, but at least "

1208 "one region has multiple blocks");

1209 return WalkResult::interrupt();

1210 }

1211 }

1212 }

1213

1215 });

1216 if (walkResult.wasInterrupted())

1217 return failure();

1218

1219 walkResult = op->walk([&](BufferizableOpInterface op) {

1220

1221 if (options.isOpAllowed(op.getOperation()))

1223

1224

1225

1226

1227 if (auto toTensorOp = dyn_cast(op.getOperation())) {

1228 if (!toTensorOp.getRestrict() && !toTensorOp->getUses().empty()) {

1229 op->emitOpError("to_tensor ops without `restrict` are not supported by "

1230 "One-Shot Analysis");

1232 }

1233 }

1234

1235 for (OpOperand &opOperand : op->getOpOperands()) {

1236 if (isa(opOperand.get().getType())) {

1238 opOperand, domInfo, state,

1239 true)) {

1240

1241

1242

1243

1244

1245 op->emitOpError("not bufferizable under the given constraints: "

1246 "cannot avoid RaW conflict");

1248 }

1249

1250 if (state.isInPlace(opOperand) &&

1252 opOperand, state, true)) {

1253 op->emitOpError("not bufferizable under the given constraints: would "

1254 "write to read-only buffer");

1256 }

1257 }

1258 }

1259

1261 });

1262

1263 return success(!walkResult.wasInterrupted());

1264 }

1265

1266

1267 static void

1270

1273 if (isa(opOperand.get().getType()))

1275 });

1276 }

1277

1282

1283 auto buildAliasesArray = [&](Value v) {

1285 state.applyOnAliases(v, [&](Value alias) {

1286 std::string buffer;

1287 llvm::raw_string_ostream stream(buffer);

1290 });

1292 };

1293

1295

1298 if (llvm::isa(opResult.getType())) {

1299 opResultAliasSets.push_back(buildAliasesArray(opResult));

1300 }

1301 }

1302 if (!opResultAliasSets.empty())

1304

1305

1307 bool hasTensorBbArg = false;

1310 for (Block &block : r.getBlocks()) {

1312 for (BlockArgument bbArg : block.getArguments()) {

1313 if (llvm::isa(bbArg.getType())) {

1314 bbArgAliasSets.push_back(buildAliasesArray(bbArg));

1315 hasTensorBbArg = true;

1316 }

1317 }

1318 blockAliasSets.push_back(b.getArrayAttr(bbArgAliasSets));

1319 }

1320 regionAliasSets.push_back(b.getArrayAttr(blockAliasSets));

1321 }

1322 if (hasTensorBbArg)

1324 });

1325 }

1326

1332

1334 return failure();

1335

1336

1337 if (failed(state.analyzeOp(op, domInfo)))

1338 return failure();

1339

1340 if (statistics) {

1341 statistics->numTensorInPlace = state.getStatNumTensorInPlace();

1343 }

1344

1345 bool failedAnalysis = false;

1346

1347

1348 state.gatherUndefinedTensorUses(op);

1349

1350

1351

1352

1354 if (BufferizableOpInterface bufferizableOp =

1355 options.dynCastBufferizableOp(op))

1356 failedAnalysis |= failed(bufferizableOp.verifyAnalysis(state));

1357 });

1358

1359

1360 if (options.testAnalysisOnly)

1362 if (options.dumpAliasSets)

1364

1365 return success(!failedAnalysis);

1366 }

1367

1371

1372

1373 assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&

1374 "invalid combination of bufferization flags");

1375

1376 if (options.copyBeforeWrite) {

1377

1378 } else {

1379

1380

1381

1383 return failure();

1384

1385

1386

1387 if (options.testAnalysisOnly)

1388 return success();

1389 }

1390

1391

1392

1394 }

static bool hasReadAfterWriteInterference(const DenseSet< OpOperand * > &usesRead, const DenseSet< OpOperand * > &usesWrite, const DominanceInfo &domInfo, OneShotAnalysisState &state)

Given sets of uses and writes, return true if there is a RaW conflict under the assumption that all g...

static void getAliasingReads(DenseSet< OpOperand * > &res, Value root, const OneShotAnalysisState &state)

static void equivalenceAnalysis(SmallVector< Operation * > &ops, OneShotAnalysisState &state)

Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.

static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace)

Mark whether OpOperand will be bufferized inplace.

static SmallVector< Operation * > bottomUpFromTerminatorsHeuristic(Operation *op, const OneShotAnalysisState &state)

"Bottom-up from terminators" heuristic.

constexpr StringLiteral kInPlaceOperandsAttrName

Attribute marker to specify op operands that bufferize in-place.

static bool isaTensor(Type t)

static void annotateNonWritableTensor(Value value)

Annotate IR with details about the detected non-writability conflict.

static bool canUseOpDominanceDueToRegions(OpOperand *uRead, OpOperand *uWrite, const SetVector< Value > &definitions, AnalysisState &state)

Return true if op dominance can be used to rule out read-after-write conflicts based on the ordering...

static LogicalResult bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state, const DominanceInfo &domInfo)

Determine if operand can be bufferized in-place.

constexpr StringLiteral kOpResultAliasSetAttrName

static bool happensBefore(Operation *a, Operation *b, const DominanceInfo &domInfo)

Return true if a happens before b, i.e., a or one of its ancestors properly dominates b and b is not ...

static bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite, const SetVector< Value > &definitions, AnalysisState &state)

static bool matchesInsertDestination(const AnalysisState &state, OpOperand *opOperand, SubsetInsertionOpInterface subsetOp)

Return "true" if the given operand's value is originating from a subset that is equivalent to the sub...

static bool wouldCreateWriteToNonWritableBuffer(OpOperand &operand, OneShotAnalysisState &state, bool checkConsistencyOnly=false)

Return true if bufferizing operand inplace would create a write to a non-writable buffer.

static void annotateOpsWithAliasSets(Operation *op, const OneShotAnalysisState &state)

static LogicalResult checkPreBufferizationAssumptions(Operation *op, const DominanceInfo &domInfo, OneShotAnalysisState &state)

Perform various checks on the input IR to see if it contains IR constructs that are unsupported by On...

static void annotateOpsWithBufferizationMarkers(Operation *op, const OneShotAnalysisState &state)

Annotate the IR with the result of the analysis. For testing/debugging only.

static bool wouldCreateReadAfterWriteInterference(OpOperand &operand, const DominanceInfo &domInfo, OneShotAnalysisState &state, bool checkConsistencyOnly=false)

Return true if bufferizing operand inplace would create a conflict.

constexpr StringLiteral kBbArgAliasSetAttrName

static bool canUseOpDominanceDueToBlocks(OpOperand *uRead, OpOperand *uWrite, const SetVector< Value > &definitions, AnalysisState &state)

Return true if op dominance can be used to rule out read-after-write conflicts based on the ordering...

static void getAliasingInplaceWrites(DenseSet< OpOperand * > &res, Value root, const OneShotAnalysisState &state)

static bool areNonConflictingSubsets(OpOperand *uRead, OpOperand *uConflictingWrite, const AnalysisState &state)

Return "true" if the given "read" and potentially conflicting "write" are not conflicting due to thei...

static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite, Value definition)

Annotate IR with details about the detected RaW conflict.

static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state, OpOperand *start, Value other)

Return 'true' if a tensor that is equivalent to other can be found in the reverse use-def chain of st...

static bool isInplaceMemoryWrite(OpOperand &opOperand, const OneShotAnalysisState &state)

Return true if opOperand bufferizes to a memory write and has been decided to bufferize in-place.

static llvm::ManagedStatic< PassManagerOptions > options

#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)

Base class for generic analysis states.

This class provides management for the lifetime of the state used when printing the IR.

This class represents an argument of a Block.

Block represents an ordered list of Operations.

Operation * findAncestorOpInBlock(Operation &op)

Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...

bool isReachable(Block *other, SmallPtrSet< Block *, 16 > &&except={})

Return "true" if there is a path from this block to the given block (according to the successors rela...

This class is a general helper class for creating context-global objects like types,...

StringAttr getStringAttr(const Twine &bytes)

ArrayAttr getArrayAttr(ArrayRef< Attribute > value)

A class for computing basic dominance information.

bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const

Return true if operation A properly dominates operation B, i.e.

IRValueT get() const

Return the current value being used by this operand.

This class helps build Operations.

This class represents an operand of an operation.

unsigned getOperandNumber()

Return which operand this is in the OpOperand list of the Operation.

This is a value defined by a result of an operation.

Operation is the basic unit of execution within MLIR.

Attribute getAttr(StringAttr name)

Return the specified attribute if present, null otherwise.

std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)

Walk the operation by calling the callback for each nested operation (including this one),...

MLIRContext * getContext()

Return the context this operation is associated with.

unsigned getNumOperands()

Operation * getParentOp()

Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...

Block * getBlock()

Returns the operation block that contains this operation.

void setAttr(StringAttr name, Attribute value)

If an attribute exists with the specified name, change it to the new value.

MutableArrayRef< Region > getRegions()

Returns the regions held by this operation.

MutableArrayRef< OpOperand > getOpOperands()

result_type_range getResultTypes()

bool isAncestor(Operation *other)

Return true if this operation is an ancestor of the other operation.

result_range getOpResults()

Region * getParentRegion()

Returns the region to which the operation belongs.

result_range getResults()

bool isProperAncestor(Operation *other)

Return true if this operation is a proper ancestor of the other operation.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

Operation * getParentOp()

Return the parent operation this region is attached to.

This class provides an efficient unique identifier for a specific C++ type.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

MLIRContext * getContext() const

Utility to get the associated MLIRContext that this value is defined in.

Type getType() const

Return the type of this value.

use_range getUses() const

Returns a range of all uses, which is useful for iterating over all uses.

void printAsOperand(raw_ostream &os, AsmState &state) const

Print this value as if it were an operand.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

A utility result that is used to signal how to proceed with an ongoing walk:

static WalkResult advance()

static WalkResult interrupt()

size_t getNumAliases() const

ArrayRef< T > getAliases() const

AnalysisState provides a variety of helper functions for dealing with tensor values.

AliasingValueList getAliasingValues(OpOperand &opOperand) const

Determine which Value will alias with opOperand if the op is bufferized in place.

bool bufferizesToMemoryWrite(OpOperand &opOperand) const

Return true if opOperand bufferizes to a memory write.

SetVector< Value > findDefinitions(OpOperand *opOperand) const

Find the values that may define the contents of the given value at runtime.

virtual void resetCache()

BufferizationState provides information about the state of the IR during the bufferization process.

virtual ~Extension()

Base virtual destructor.

State for analysis-enabled bufferization.

void bufferizeOutOfPlace(OpOperand &operand)

Mark the given OpOperand as out-of-place.

bool isWritable(Value value) const

Return true if the buffer of the given tensor value is writable.

const SetVector< Value > & findDefinitionsCached(OpOperand *opOperand)

Find the definitions of the given operand's value or retrieve them from the cache.

bool isInPlace(OpOperand &opOperand) const override

Return true if the given OpOperand has been decided to bufferize in-place.

LogicalResult analyzeOp(Operation *op, const DominanceInfo &domInfo)

Analyze the given op and its nested ops.

bool isValueWritten(Value value) const

Return true if the buffer of the given tensor value is written to.

const OneShotBufferizationOptions & getOptions() const

Return a reference to the BufferizationOptions.

void unionEquivalenceClasses(Value v1, Value v2)

Union the equivalence classes of v1 and v2.

void gatherUndefinedTensorUses(Operation *op)

Find all tensor values in the given operation that have undefined contents and store their uses in undefinedTensorUses.

void resetCache() override

Reset cached data structures.

LogicalResult analyzeSingleOp(Operation *op, const DominanceInfo &domInfo)

Analyze a single op (without nested ops).

void applyOnEquivalenceClass(Value v, function_ref< void(Value)> fun) const

Apply fun to all the members of the equivalence class of v.

bool hasUndefinedContents(OpOperand *opOperand) const override

Return true if the given tensor has undefined contents.

void bufferizeInPlace(OpOperand &operand)

Mark the given OpOperand as in-place and merge the results' and operand's aliasing sets.

void applyOnAliases(Value v, function_ref< void(Value)> fun) const

Apply fun to all aliases of v.

bool areEquivalentBufferizedValues(Value v1, Value v2) const override

Return true if v1 and v2 bufferize to equivalent buffers.

OneShotAnalysisState(Operation *op, const OneShotBufferizationOptions &options)

bool areAliasingBufferizedValues(Value v1, Value v2) const override

Return true if v1 and v2 may bufferize to aliasing buffers.

void unionAliasSets(Value v1, Value v2)

Union the alias sets of v1 and v2.

void createAliasInfoEntry(Value v)

Add a new entry for v in the aliasInfo and equivalentInfo.

Operation * getOwner() const

Return the owner of this operand.

LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)

Bufferize op and its nested ops that implement BufferizableOpInterface.

LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)

Analyze op and its nested ops.

Operation * getOwnerOfValue(Value value)

Return the owner of the given value.

LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, const BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)

Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.

LogicalResult runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)

Run One-Shot Bufferize on the given op: Analysis + Bufferization.

Region * getParallelRegion(Region *region, const BufferizationOptions &options)

If region is a parallel region, return region.

Region * getNextEnclosingRepetitiveRegion(Region *region, const BufferizationOptions &options)

Assuming that the given region is repetitive, find the next enclosing repetitive region.

bool hasTensorSemantics(Operation *op)

Return "true" if the given op has tensor semantics and should be bufferized.

Include the generated interface declarations.

const FrozenRewritePatternSet GreedyRewriteConfig config

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

This iterator enumerates elements in "reverse" order.

A maybe aliasing OpOperand.

Options for BufferizableOpInterface-based bufferization.

BufferizableOpInterface dynCastBufferizableOp(Operation *op) const

Try to cast the given op to BufferizableOpInterface if the op is allow listed.

bool isOpAllowed(Operation *op) const

Return true if the given op should be bufferized.

Bufferization statistics for debugging.

int64_t numTensorOutOfPlace

Options for analysis-enabled bufferization.

AnalysisHeuristic analysisHeuristic

The heuristic controls the order in which ops are traversed during the analysis.

@ BottomUpFromTerminators

Traversal parameters for findValueInReverseUseDefChain.