MLIR: lib/Dialect/Linalg/Transforms/HoistPadding.cpp Source File

//===- HoistPadding.cpp - Hoisting for tensor::PadOp ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-Exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with hoisting padding operations.
//
//===----------------------------------------------------------------------===//

#include "llvm/Support/Debug.h"

using llvm::dbgs;

#define DEBUG_TYPE "hoist-padding"

#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::tensor;

#ifndef NDEBUG
static bool debugPrintLoopInShortForm(Operation *op) {
  AsmState state(op->getParentOfType<func::FuncOp>());
  (void)state;
  if (auto forOp = dyn_cast<scf::ForOp>(op)) {
    forOp.getInductionVar().printAsOperand(dbgs(), state);
    dbgs() << " @ " << forOp.getOperation();
    return true;
  }
  return false;
}
#endif

static void debugPrintBackwardSlice(SetVector<Operation *> &backwardSlice) {
  LLVM_DEBUG(llvm::interleaveComma(backwardSlice, DBGS() << "--backwardSlice:",
                                   [](Operation *op) {
                                     dbgs() << "\n";
                                     DBGS() << "----";
                                     if (debugPrintLoopInShortForm(op)) {
                                       dbgs() << "\n";
                                       return;
                                     }
                                     dbgs() << *op << "\n";
                                   });
             DBGS() << "\n";);
}

/// Return at most nLevels of immediately enclosing scf::ForOp loops.
/// Stops at the first parent that is not an scf::ForOp.
/// Multi-loops such as scf.parallel are not modeled.
static void
getAtMostNEnclosingLoops(tensor::PadOp padOp, int nLevels,
                         SmallVector<scf::ForOp> &reverseEnclosingLoops) {
  scf::ForOp outermostEnclosingForOp = nullptr;
  Operation *nextEnclosingOp = padOp->getParentOp();
  while (nLevels-- > 0 &&
         (outermostEnclosingForOp = dyn_cast<scf::ForOp>(nextEnclosingOp))) {
    LLVM_DEBUG(DBGS() << "loops: ";
               debugPrintLoopInShortForm(outermostEnclosingForOp);
               dbgs() << "\n");
    reverseEnclosingLoops.push_back(outermostEnclosingForOp);
    nextEnclosingOp = outermostEnclosingForOp->getParentOp();
  }
}

/// Return the immediately enclosing scf::ForOp loops, innermost first, from
/// the loop surrounding padOp up to and including untilLoop.
/// Stops at the first parent that is not an scf::ForOp.
static void
getEnclosingLoopsUntil(tensor::PadOp padOp, scf::ForOp untilLoop,
                       SmallVector<scf::ForOp> &reverseEnclosingLoops) {
  scf::ForOp outermostEnclosingForOp = nullptr;
  Operation *nextEnclosingOp = padOp->getParentOp();
  while (outermostEnclosingForOp != untilLoop &&
         (outermostEnclosingForOp = dyn_cast<scf::ForOp>(nextEnclosingOp))) {
    LLVM_DEBUG(DBGS() << "loops: ";
               debugPrintLoopInShortForm(outermostEnclosingForOp);
               dbgs() << "\n");
    reverseEnclosingLoops.push_back(outermostEnclosingForOp);
    nextEnclosingOp = outermostEnclosingForOp->getParentOp();
  }
}

/// Compute the backward slice of padOp, i.e. the transitive defs needed to
/// recompute it, restricted to the ops dominated by outermostEnclosingForOp
/// and excluding the ops nested inside the padding region of padOp itself.
static void computeBackwardSlice(tensor::PadOp padOp,
                                 scf::ForOp outermostEnclosingForOp,
                                 SetVector<Operation *> &backwardSlice) {
  DominanceInfo domInfo(outermostEnclosingForOp);
  BackwardSliceOptions sliceOptions;
  sliceOptions.inclusive = true;
  sliceOptions.filter = [&](Operation *op) {
    return domInfo.dominates(outermostEnclosingForOp, op) &&
           !padOp->isProperAncestor(op);
  };

  // First, add the ops required to compute the padding region to the slice.
  SetVector<Value> valuesDefinedAbove;
  getUsedValuesDefinedAbove(padOp.getRegion(), padOp.getRegion(),
                            valuesDefinedAbove);
  for (Value v : valuesDefinedAbove) {
    LogicalResult result = getBackwardSlice(v, &backwardSlice, sliceOptions);
    assert(result.succeeded() && "expected a backward slice");
    (void)result;
  }
  // Then, add the backward slice rooted at padOp itself.
  LogicalResult result =
      getBackwardSlice(padOp.getOperation(), &backwardSlice, sliceOptions);
  assert(result.succeeded() && "expected a backward slice");
  (void)result;
}

namespace {

/// Analysis class to support tensor::PadOp hoisting across multiple enclosing
/// loops. Construction looks up the enclosing loops and the extract_slice op
/// that feeds the pad; finalizeHoistPaddingAnalysis computes the backward
/// slice and the packing loops. The `valid` flag records whether hoisting can
/// proceed.
struct HoistPaddingAnalysis {
  HoistPaddingAnalysis(tensor::PadOp padOp, int numLoops);
  HoistPaddingAnalysis(tensor::PadOp padOp, scf::ForOp outermostEnclosingForOp);

  bool isValid() { return valid.has_value() && valid.value(); }
  bool isInvalid() { return valid.has_value() && !valid.value(); }

  /// Footprint of the hoistedPackedTensor, computed from the packingLoops.
  SmallVector<Value> getHoistedPackedTensorSizes(RewriterBase &rewriter,
                                                 Location loc) const;

  /// Performs optional hoisting to enable hoist padding to occur. This may be
  /// necessary when the source of `sliceOp` is not defined outside of the
  /// outermost enclosing loop we want to hoist above.
  void enableHoistPadding(RewriterBase &rewriter);

  /// Common analysis builder that finalizes the construction of the analysis
  /// once the optional `enableHoistPadding` has run.
  void finalizeHoistPaddingAnalysis();

private:
  /// Encodes whether the analysis is valid and hoisting can proceed.
  std::optional<bool> valid;

  /// The padOp to hoist.
  tensor::PadOp opToHoist;

  /// Immediately enclosing loops considered for hoisting padding, innermost
  /// first.
  SmallVector<scf::ForOp> reverseEnclosingLoops;

  /// Drop any non-index dependencies of `opToHoist` and `sliceOp` from
  /// `backwardSlice`. The method follows the use-def chains of the index
  /// operands consumed by `opToHoist` and `sliceOp` and drops the operations
  /// that are not part of the index computation.
  LogicalResult dropNonIndexDependencies();

public:
  /// The outermost loop, determined by `nLevels` above, which `padOp` will be
  /// hoisted above.
  scf::ForOp outermostEnclosingForOp;

  /// Backward slice rooted at `opToHoist` and nested under
  /// `outermostEnclosingForOp`.
  SetVector<Operation *> backwardSlice;

  /// The scf::ForOps, nested under `outermostEnclosingForOp`, involved in the
  /// packing loop nest.
  SmallVector<scf::ForOp> packingLoops;

  /// The ExtractSliceOp that feeds the PadOp we want to hoist.
  tensor::ExtractSliceOp sliceOp;

  /// If non-null, the scf::ForOp that consumes the sliceOp through an
  /// iter_arg.
  scf::ForOp padConsumingForOp;
};

} // namespace

HoistPaddingAnalysis::HoistPaddingAnalysis(tensor::PadOp padOp, int numLoops)
    : valid(std::nullopt), opToHoist(padOp) {
  // Get at most `numLoops` of immediately enclosing loops.
  getAtMostNEnclosingLoops(padOp, numLoops, reverseEnclosingLoops);
  if (reverseEnclosingLoops.empty()) {
    LLVM_DEBUG(DBGS() << "--No immediately enclosing loop -> Skip\n");
    valid = false;
    return;
  }
  outermostEnclosingForOp = reverseEnclosingLoops.back();
  sliceOp = opToHoist.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp) {
    LLVM_DEBUG(DBGS() << "--Cannot find the extract slice op -> Skip\n");
    valid = false;
    return;
  }
}

HoistPaddingAnalysis::HoistPaddingAnalysis(tensor::PadOp padOp,
                                           scf::ForOp outermostEnclosingForOp)
    : valid(std::nullopt), opToHoist(padOp) {
  // Get all the enclosing loops until and including `outermostEnclosingForOp`.
  getEnclosingLoopsUntil(padOp, outermostEnclosingForOp,
                         reverseEnclosingLoops);
  if (reverseEnclosingLoops.empty()) {
    LLVM_DEBUG(DBGS() << "--No immediately enclosing loop -> Skip\n");
    valid = false;
    return;
  }
  this->outermostEnclosingForOp = reverseEnclosingLoops.back();
  if (this->outermostEnclosingForOp != outermostEnclosingForOp) {
    LLVM_DEBUG(DBGS() << "--Unexpected outermost enclosing loop -> Skip\n");
    valid = false;
    return;
  }
  sliceOp = opToHoist.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp) {
    LLVM_DEBUG(DBGS() << "--Cannot find the extract slice op -> Skip\n");
    valid = false;
    return;
  }
}

void HoistPaddingAnalysis::enableHoistPadding(RewriterBase &rewriter) {
  if (isInvalid())
    return;
  // If the padded data is not yet available before entering the outermost
  // enclosing loop, try to apply hoisting on this outermost loop.
  // TODO: we may want finer-grained hoisting of only that particular sliceOp.
  if (!outermostEnclosingForOp.isDefinedOutsideOfLoop(sliceOp.getSource())) {
    outermostEnclosingForOp = cast<scf::ForOp>(
        hoistLoopInvariantSubsets(rewriter, outermostEnclosingForOp));
  }
}

void HoistPaddingAnalysis::finalizeHoistPaddingAnalysis() {
  if (isInvalid())
    return;

  if (!outermostEnclosingForOp.isDefinedOutsideOfLoop(sliceOp.getSource())) {
    LLVM_DEBUG(DBGS() << "--outermostEnclosingForOp:\n"
                      << outermostEnclosingForOp << "\n"
                      << "--sliceOp: " << sliceOp << "\n"
                      << "--sliceOp.getSource(): " << sliceOp.getSource()
                      << "\n");
    LLVM_DEBUG(DBGS() << "----Source not defined outside of loops -> Skip\n");
    valid = false;
    return;
  }
  if (sliceOp->hasOneUse()) {
    padConsumingForOp = dyn_cast<scf::ForOp>(*(sliceOp->getUsers().begin()));
  }

  // Check the region of `opToHoist` depends on a constant only. Adding
  // hoisting support for arbitrary padding regions would require cloning all
  // dependencies captured by the padding region.
  Value paddingValue = opToHoist.getConstantPaddingValue();
  if (!paddingValue ||
      !isa_and_nonnull<arith::ConstantOp>(paddingValue.getDefiningOp())) {
    LLVM_DEBUG(DBGS() << "Cannot find constant padding value -> Skip\n");
    valid = false;
    return;
  }

  computeBackwardSlice(opToHoist, outermostEnclosingForOp, backwardSlice);
  if (backwardSlice.size() <= 1) {
    valid = false;
    return;
  }

  debugPrintBackwardSlice(backwardSlice);
  // Remove all ops in the backward slice that are not used to index the
  // padded tensor. In particular, keep `opToHoist`, `sliceOp`, and the loop
  // and affine operations used for the index computation.
  if (failed(dropNonIndexDependencies())) {
    LLVM_DEBUG(DBGS() << "--Cannot dropNonIndexDependencies -> Skip\n");
    valid = false;
    return;
  }
  debugPrintBackwardSlice(backwardSlice);

  // Add only the loops part of the filtered `backwardSlice` to the packing
  // loops. All other loops are not used to index the padded data and
  // consequently access the same data in every loop iteration. Adding them
  // to the packing loops would increase the cache footprint of the packed
  // data by storing the same data multiple times.
  for (scf::ForOp forOp : llvm::reverse(reverseEnclosingLoops))
    if (backwardSlice.contains(forOp))
      packingLoops.push_back(forOp);

  // TODO: support hoisting multiple loops through iter_args.
  if (packingLoops.size() > 1 && padConsumingForOp) {
    LLVM_DEBUG(DBGS() << "--Cannot hoist multiple loops through iter_args -> "
                         "Downgrade to 1 loop\n");
    packingLoops.resize(1);
  }

  // Note: at this point, packing loops may be empty but we would still like
  // to hoist the padding if so specified.

  // The analysis is valid and hoisting can occur.
  valid = true;
}

LogicalResult HoistPaddingAnalysis::dropNonIndexDependencies() {
  // Set of all values used for index computation.
  SetVector<Value> indexEdges;

  // Add all index operands of `operation` to `indexEdges`. An index operand
  // is an operand of type index.
  auto addIndexOperandsToIndexEdges = [&](Operation *operation) {
    for (Value operand : operation->getOperands())
      if (operand.getType().isIndex())
        indexEdges.insert(operand);
  };

  // Check if any of the results of `operation` is an index edge.
  auto hasIndexResult = [&](Operation *operation) {
    return llvm::any_of(operation->getResults(), [&](Value result) {
      return indexEdges.contains(result);
    });
  };

  // Walk the backward slice in reverse order, starting from `opToHoist` and
  // `sliceOp`, and collect the values used for index computation. Keep an
  // operation if it produces an index edge; otherwise mark it for removal.
  // Fail if a kept operation has non-index operands, regions, or memory
  // effects, since cloning it as part of the packing loop nest would be
  // unsafe.
  SetVector<Operation *> operationsToRemove;
  for (Operation *op : llvm::reverse(backwardSlice)) {
    // Add the index operands of `opToHoist` and `sliceOp` that are used to
    // index the padded data in the packed tensor.
    if (op == opToHoist || op == sliceOp) {
      addIndexOperandsToIndexEdges(op);
      continue;
    }
    // Add the index operands of a loop if its induction variable is used for
    // index computation.
    if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      if (!hasIndexResult(op) && indexEdges.contains(forOp.getInductionVar())) {
        addIndexOperandsToIndexEdges(op);
        continue;
      }
    }
    // Add the index operands of all other operations if at least one result
    // is used for index computation.
    if (hasIndexResult(op)) {
      addIndexOperandsToIndexEdges(op);
      // Check the operands of the remaining operations all have index type.
      if (llvm::any_of(op->getOperandTypes(),
                       [](Type type) { return !type.isIndex(); })) {
        LLVM_DEBUG(DBGS() << "Unsupported op with non index type operands: "
                          << op << " -> Skip\n");
        return failure();
      }
      // Check the remaining operations have no regions or memory effects.
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
      bool hasMemoryEffect = effectInterface && !effectInterface.hasNoEffect();
      if (hasMemoryEffect || op->getNumRegions() != 0) {
        LLVM_DEBUG(DBGS() << "Unsupported op with region or memory effect: "
                          << op << " -> Skip\n");
        return failure();
      }
      continue;
    }
    // Remove all other operations not used by the index computation. An
    // exception are constant operations that may be used by `opToHoist`.
    if (!isa<arith::ConstantOp>(op))
      operationsToRemove.insert(op);
  }
  backwardSlice.set_subtract(operationsToRemove);
  return success();
}

SmallVector<Value>
HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter,
                                                  Location loc) const {
  SmallVector<Value> dynamicTensorSizes;

  // Upper bound the packing loop lengths to size the packed tensor. Taking
  // upper bounds can make the sizes of the packed tensor independent of the
  // enclosing loops. This independence is a prerequisite for reusing the same
  // buffer for all enclosing loop iterations and hoisting its allocation out
  // of the enclosing loops.
  for (auto forOp : packingLoops) {
    // Compute an upper bound `ubVal` for the upper bound of `forOp`.
    FailureOr<OpFoldResult> loopUb = affine::reifyIndexValueBound(
        rewriter, loc, presburger::BoundType::UB, forOp.getUpperBound(),
        /*stopCondition=*/
        [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
          if (v == forOp.getUpperBound())
            return false;
          // Compute a bound that is independent of any affine op results.
          Operation *op = v.getDefiningOp();
          if (!op)
            return true;
          return !isa<affine::AffineMinOp, affine::AffineMaxOp,
                      affine::AffineApplyOp>(op);
        },
        /*closedUB=*/true);
    assert(succeeded(loopUb) && "could not get upper bound");
    Value ubVal = getValueOrCreateConstantIndexOp(rewriter, loc, *loopUb);

    // Compute the maximal packing loop length as (ub - lb).ceilDiv(step) and
    // store the result to `dynamicTensorSizes`.
    // TODO: instead of using the lower bound of `forOp` directly, implement a
    // lower bound computation similar to the upper bound computation.
    AffineExpr lb, ub, step;
    bindDims(rewriter.getContext(), lb, ub);
    bindSymbols(rewriter.getContext(), step);
    Value res = rewriter.createOrFold<affine::AffineApplyOp>(
        loc, (ub - lb).ceilDiv(step),
        ValueRange{forOp.getLowerBound(), ubVal,
                   cast<scf::ForOp>(forOp).getStep()});
    dynamicTensorSizes.push_back(res);
  }

  return dynamicTensorSizes;
}
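
// For example (illustrative, not from the original source): a packing loop
// `scf.for %i = %c0 to %n step %c4` whose upper bound %n has the
// loop-independent (closed) upper bound 13 contributes a packed dimension of
// size (13 - 0).ceilDiv(4) = 4, which no longer depends on %n.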

static bool isDefinedOutsideOrConstant(scf::ForOp outer, Value v) {
  return outer.isDefinedOutsideOfLoop(v) || matchPattern(v, m_Constant());
}

/// Return the current iteration number in the loop (iv - lb).ceilDiv(step).
/// The returned Value is guaranteed not to depend on any loop comprised in
/// [`outer`, `forOp`].
/// Return null if such a loop-independent quantity cannot be computed.
static Value buildLoopIterationCount(RewriterBase &rewriter, scf::ForOp outer,
                                     scf::ForOp forOp) {
  MLIRContext *ctx = forOp->getContext();
  AffineExpr iv, lb, step;
  bindDims(ctx, iv, lb);
  bindSymbols(ctx, step);
  if (!isDefinedOutsideOrConstant(outer, forOp.getLowerBound()) ||
      !isDefinedOutsideOrConstant(outer, forOp.getStep()))
    return Value();
  Value ivVal = forOp.getInductionVar(), lbVal = forOp.getLowerBound(),
        stepVal = forOp.getStep();
  auto loc = forOp->getLoc();
  return rewriter.createOrFold<affine::AffineApplyOp>(
      loc, (iv - lb).ceilDiv(step), ValueRange{ivVal, lbVal, stepVal});
}
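
// For example (illustrative): in a loop `scf.for %i = %c4 to %c16 step %c2`,
// the value computed when %i holds 8 is (8 - 4).ceilDiv(2) = 2, i.e. the loop
// is executing its third iteration (counting from 0).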

/// Helper to create the packing loop nest required to hoist `opToHoist` above
/// `outermostEnclosingForOp`, inserting the padded slices into `emptyOp`.
/// The IRMapping `bvm` is updated with the cloned operations.
static FailureOr<PackingResult> buildPackingLoopNestImpl(
    RewriterBase &rewriter, IRMapping &bvm, tensor::PadOp opToHoist,
    ArrayRef<int64_t> transposeVector, RankedTensorType transposedTensorType,
    tensor::EmptyOp emptyOp, const HoistPaddingAnalysis &analysis) {
  SmallVector<OpFoldResult> offsets, sizes, strides;
  SmallVector<Value> clonedLoopIvs, leadingHoistedPackedTensorIndexings;

  scf::ForOp outerLoop = analysis.outermostEnclosingForOp;

  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();
  int paddedRank = paddedTensorType.getRank();

  // Step 0. If the source of `opToHoist` is a block argument flowing through
  // enclosing loop iter_args, map it back to the corresponding init value.
  BlockArgument bbArg = dyn_cast<BlockArgument>(opToHoist.getSource());
  while (bbArg) {
    auto forOp = dyn_cast<scf::ForOp>(bbArg.getOwner()->getParentOp());
    if (!forOp)
      break;
    if (forOp != outerLoop && !outerLoop->isAncestor(forOp))
      break;
    OpOperand &operand = *forOp.getTiedLoopInit(bbArg);
    bvm.map(bbArg, operand.get());
    bbArg = dyn_cast<BlockArgument>(operand.get());
  }

  // Step 1. Iteratively clone loops and push `hoistedPackedTensor` through
  // the iter_args.
  Value hoistedPackedTensor = emptyOp.getResult();
  OpBuilder::InsertionGuard g(rewriter);
  for (Operation *op : analysis.backwardSlice) {
    // Specifically sit out in the extract_slice(hoistedPackedTensor) case:
    // this is the piece we seek to replace.
    if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
      if (bvm.lookupOrDefault(sliceOp.getSource()) == hoistedPackedTensor) {
        LLVM_DEBUG(DBGS() << "--Skip: " << sliceOp << "\n");
        continue;
      }
    }

    // Clone ops that are not scf::ForOp: only the enclosing loops are
    // recreated as part of the packing loop nest.
    auto forOp = dyn_cast<scf::ForOp>(op);
    if (!forOp) {
      rewriter.clone(*op, bvm);
      continue;
    }

    // Create a packing loop that takes `hoistedPackedTensor` as iteration
    // argument.
    auto clonedForOp = rewriter.create<scf::ForOp>(
        loc, bvm.lookupOrDefault(forOp.getLowerBound()),
        bvm.lookupOrDefault(forOp.getUpperBound()),
        bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor);

    // Map the induction var, region args and results to the `clonedForOp`.
    bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar());
    bvm.map(forOp.getRegionIterArgs(), clonedForOp.getRegionIterArgs());
    bvm.map(forOp.getResults(), clonedForOp.getResults());
    assert(clonedForOp->getNumRegions() == 1);
    clonedLoopIvs.push_back(clonedForOp.getInductionVar());

    // Compute the loop-independent iteration count at the start of the
    // cloned loop body; it indexes the leading packed dimensions.
    rewriter.setInsertionPointToStart(&clonedForOp->getRegion(0).front());
    Value loopIndependentIterationCount =
        buildLoopIterationCount(rewriter, outerLoop, clonedForOp);

    // Assert the loop-independent iteration count is available.
    if (!loopIndependentIterationCount)
      llvm_unreachable("loop independence prerequisite not met");
    leadingHoistedPackedTensorIndexings.push_back(
        loopIndependentIterationCount);
    hoistedPackedTensor = clonedForOp.getRegionIterArgs().front();
  }

  // Step 2. Compute the offsets, sizes and strides into the packed tensor.
  int64_t nPackedLoops = clonedLoopIvs.size();
  // offsets = [leading packed indexings, 0 .. 0].
  offsets =
      SmallVector<OpFoldResult>{leadingHoistedPackedTensorIndexings.begin(),
                                leadingHoistedPackedTensorIndexings.end()};
  offsets.append(paddedRank, rewriter.getIndexAttr(0));
  // sizes = [1 .. 1, transposedShape].
  sizes = SmallVector<OpFoldResult>(nPackedLoops, rewriter.getIndexAttr(1));
  for (int64_t sz : transposedTensorType.getShape()) {
    // TODO: go grab dims when needed; atm only supports static sizes.
    if (ShapedType::isDynamic(sz))
      return failure();
    sizes.push_back(rewriter.getIndexAttr(sz));
  }
  // strides = [1 .. 1].
  strides = SmallVector<OpFoldResult>(nPackedLoops + paddedRank,
                                      rewriter.getIndexAttr(1));

  // Step 3. Optionally transpose the padded tensor.
  TransposeOp maybeTransposeOp;
  Value paddedTensor = bvm.lookup(opToHoist.getResult());
  if (!transposeVector.empty()) {
    Value outputTensor = rewriter.create<tensor::ExtractSliceOp>(
        loc, transposedTensorType, hoistedPackedTensor, offsets, sizes,
        strides);
    maybeTransposeOp = rewriter.create<linalg::TransposeOp>(
        loc, paddedTensor, outputTensor, transposeVector);
    paddedTensor = maybeTransposeOp.getResult()[0];
  }

  // Innermost tensor.insert_slice and yields are optional / need loops.
  if (nPackedLoops > 0) {
    // Step 4. Insert the optionally transposed padded slice into the packed
    // tensor at the innermost loop level.
    Value inserted = rewriter.create<tensor::InsertSliceOp>(
        loc, paddedTensor, hoistedPackedTensor, offsets, sizes, strides);

    // Step 5. Iteratively pop the stack and propagate the yield.
    Value valueToYield = inserted;
    for (Value iv : llvm::reverse(clonedLoopIvs)) {
      auto forOp = scf::getForInductionVarOwner(iv);
      rewriter.setInsertionPointToEnd(&forOp.getRegion().front());
      rewriter.create<scf::YieldOp>(loc, valueToYield);
      valueToYield = forOp.getResult(0);
    }
  }

  return PackingResult{
      offsets,
      sizes,
      strides,
      clonedLoopIvs,
      leadingHoistedPackedTensorIndexings,
      maybeTransposeOp,
      cast<tensor::PadOp>(bvm.lookup(opToHoist.getResult()).getDefiningOp())};
}

/// Build the packing loop nest required to hoist `opToHoist` above
/// `outermostEnclosingForOp`.
/// The loop nest is built just before `outermostEnclosingForOp`.
static FailureOr<PackingResult> buildPackingLoopNestImpl(
    RewriterBase &rewriter, IRMapping &bvm, tensor::PadOp opToHoist,
    ArrayRef<int64_t> transposeVector, const HoistPaddingAnalysis &analysis) {
  // Update actual number of loops, which may be smaller.
  int nPackedLoops = analysis.packingLoops.size();
  LLVM_DEBUG(DBGS() << "\n";
             DBGS() << "Func:\n"
                    << *opToHoist->getParentOfType<func::FuncOp>() << "\n";
             DBGS() << "Start hoisting above " << nPackedLoops << " loops\n");

  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();

  // Compute the type of the transposed-and-packed tensor.
  FailureOr<RankedTensorType> transposedTensorType =
      tensor::computeTransposedType(paddedTensorType, transposeVector);
  if (failed(transposedTensorType)) {
    LLVM_DEBUG(DBGS() << "--Could not compute transposed type -> Skip\n");
    return failure();
  }

  // Create the packed tensor<?x?x..?xtransposedShape>.
  SmallVector<int64_t> packedShape(nPackedLoops, ShapedType::kDynamic);
  // TODO: go grab dims when needed; atm only supports static sizes.
  llvm::append_range(packedShape, transposedTensorType->getShape());
  auto hoistedPackedTensorType = RankedTensorType::get(
      packedShape, transposedTensorType->getElementType());

  // Set the insertion point right before the outer loop and create the empty
  // tensor.
  scf::ForOp outerLoop = analysis.outermostEnclosingForOp;
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(outerLoop);
  SmallVector<Value> dynamicTensorSizes =
      analysis.getHoistedPackedTensorSizes(rewriter, loc);
  auto emptyOp = rewriter.create<tensor::EmptyOp>(
      loc, hoistedPackedTensorType.getShape(),
      hoistedPackedTensorType.getElementType(), dynamicTensorSizes);

  return buildPackingLoopNestImpl(rewriter, bvm, opToHoist, transposeVector,
                                  *transposedTensorType, emptyOp, analysis);
}

/// Build the packing loop nest required to hoist `opToHoist` above
/// `outermostEnclosingForOp`.
/// The loop nest is built just before `outermostEnclosingForOp`.
FailureOr<PackingResult> mlir::linalg::detail::buildPackingLoopNest(
    RewriterBase &rewriter, tensor::PadOp opToHoist,
    scf::ForOp outermostEnclosingForOp, ArrayRef<int64_t> transposeVector) {
  HoistPaddingAnalysis analysis(opToHoist, outermostEnclosingForOp);
  analysis.enableHoistPadding(rewriter);
  analysis.finalizeHoistPaddingAnalysis();
  if (!analysis.isValid()) {
    LLVM_DEBUG(DBGS() << "--Analysis failed -> Skip\n");
    return failure();
  }
  IRMapping bvm;
  return buildPackingLoopNestImpl(rewriter, bvm, opToHoist, transposeVector,
                                  analysis);
}

/// Return true if we can walk back the use-def chain from `extractSliceOp` to
/// `expectedSource`, going through DestinationStyleOpInterface inits only.
/// This is a poor man's analysis that is sufficient to check that
/// `extractSliceOp` ultimately reads from `expectedSource`.
static bool tracesBackToExpectedValue(tensor::ExtractSliceOp extractSliceOp,
                                      Value expectedSource) {
  LLVM_DEBUG(DBGS() << "Start tracesBackToExpectedValue on: " << extractSliceOp
                    << "\n");
  LLVM_DEBUG(DBGS() << "--with extractSlice: " << extractSliceOp << "\n");
  Value source = extractSliceOp.getSource();
  LLVM_DEBUG(DBGS() << "--with starting source: " << source << "\n");
  while (source && source != expectedSource) {
    auto destOp =
        dyn_cast_or_null<DestinationStyleOpInterface>(source.getDefiningOp());
    if (!destOp)
      break;
    LLVM_DEBUG(DBGS() << "--step dest op: " << destOp << "\n");
    source = destOp.getDpsInitOperand(cast<OpResult>(source).getResultNumber())
                 ->get();
  }
  LLVM_DEBUG(DBGS() << "--final source: " << source << "\n");
  LLVM_DEBUG(DBGS() << "--expected source: " << expectedSource << "\n");
  return source == expectedSource;
}

/// If the original consumer of `outerSliceOp` was a `forOp` (i.e. through an
/// iter arg), propagate the `hoistedPackedTensor` value through the same
/// iter arg.
/// TODO: for multiple loops we need to track the use to the innermost loop.
///
/// Match:
///   %outerSliceOp = tensor.extract_slice ..
///   %f = scf.for ... iter_args(%arg0 = %outerSliceOp) {
///     %hoistedPackedTensor = tensor.pad %arg0
///     %1 = compute %hoistedPackedTensor
///     %2 = tensor.extract_slice %1
///     scf.yield %2
///   }
///
/// and rewrite as:
///   %outerSliceOp = tensor.extract_slice ..
///   %f = scf.for ... iter_args(%arg0 = %hoistedPackedTensor) {
///     %1 = compute %arg0
///     scf.yield %1
///   }
///   %2 = tensor.extract_slice %f
static tensor::ExtractSliceOp
padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
                      Value hoistedPackedTensor,
                      tensor::ExtractSliceOp outerSliceOp, scf::ForOp forOp) {
  LLVM_DEBUG(DBGS() << "Start padThroughLoopIterArg on: " << forOp << "\n");
  LLVM_DEBUG(DBGS() << "--paddedValueBeforeHoisting: "
                    << paddedValueBeforeHoisting << "\n");
  OpOperand *pUse = nullptr;
  for (OpOperand &use : outerSliceOp->getUses()) {
    if (use.getOwner() == forOp) {
      assert(!pUse && "Multiple slice uses in the for loop");
      pUse = &use;
    }
  }
  assert(pUse && "No slice use in the for loop");

  // Find the extract_slice that feeds the yield tied to this iter arg.
  unsigned iterArgNumber = forOp.getTiedLoopResult(pUse).getResultNumber();
  auto yieldingExtractSliceOp = forOp.getYieldedValues()[iterArgNumber]
                                    .getDefiningOp<tensor::ExtractSliceOp>();
  if (!yieldingExtractSliceOp)
    return tensor::ExtractSliceOp();

  // Poor man's analysis to ensure the yielded slice traces back to the padded
  // value before hoisting, so that yielding the source instead preserves the
  // semantics.
  if (!tracesBackToExpectedValue(yieldingExtractSliceOp,
                                 paddedValueBeforeHoisting))
    return tensor::ExtractSliceOp();

  SmallVector<Value> initArgs = forOp.getInitArgs();
  initArgs[iterArgNumber] = hoistedPackedTensor;
  SmallVector<Value> yieldOperands = llvm::to_vector(forOp.getYieldedValues());
  yieldOperands[iterArgNumber] = yieldingExtractSliceOp.getSource();

  int64_t numOriginalForOpResults = initArgs.size();
  LLVM_DEBUG(DBGS() << "numOriginalForOpResults: " << numOriginalForOpResults
                    << "\n");
  tensor::ExtractSliceOp extracted;
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(forOp);
    extracted = rewriter.create<tensor::ExtractSliceOp>(
        hoistedPackedTensor.getLoc(), hoistedPackedTensor,
        outerSliceOp.getMixedOffsets(), outerSliceOp.getMixedSizes(),
        outerSliceOp.getMixedStrides());
    rewriter.replaceAllUsesWith(forOp.getResult(iterArgNumber), extracted);
  }
  scf::ForOp newForOp = cast<scf::ForOp>(*forOp.replaceWithAdditionalYields(
      rewriter, initArgs, /*replaceInitOperandUsesInLoop=*/true,
      [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
        return yieldOperands;
      }));

  LLVM_DEBUG(DBGS() << "newForOp results: " << newForOp.getNumResults()
                    << "\n");
  LLVM_DEBUG(DBGS() << "replace source of: " << extracted << "\n");
  LLVM_DEBUG(DBGS() << "with result #"
                    << numOriginalForOpResults + iterArgNumber
                    << " of forOp, giving us: " << extracted << "\n");
  rewriter.startOpModification(extracted);
  extracted.getSourceMutable().assign(
      newForOp.getResult(numOriginalForOpResults + iterArgNumber));
  rewriter.finalizeOpModification(extracted);

  LLVM_DEBUG(DBGS() << "replace uses of: " << paddedValueBeforeHoisting
                    << "\n");
  LLVM_DEBUG(DBGS() << "with region iter arg #"
                    << numOriginalForOpResults + iterArgNumber << "\n");
  rewriter.replaceAllUsesWith(
      paddedValueBeforeHoisting,
      newForOp.getRegionIterArg(numOriginalForOpResults + iterArgNumber));

  return extracted;
}

/// Produce a tensor extracted from the packingResult. This can be used as a
/// replacement for `opToHoist` in callers.
static Value replaceByPackingResult(RewriterBase &rewriter,
                                    const IRMapping &bvm,
                                    tensor::PadOp opToHoist,
                                    RankedTensorType transposedTensorType,
                                    const HoistPaddingAnalysis &analysis,
                                    const PackingResult &packingResult) {
  // The replacement occurs under a single insertion point within the original
  // loop, just before opToHoist.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(opToHoist);

  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();
  int paddedRank = paddedTensorType.getRank();

  int64_t nPackedLoops = packingResult.clonedLoopIvs.size();
  LLVM_DEBUG(DBGS() << "nPackedLoops: " << nPackedLoops << " loops\n");

  scf::ForOp outerLoop = analysis.outermostEnclosingForOp;
  ArrayRef<scf::ForOp> packingLoops = analysis.packingLoops;

  Value hoistedPackedTensor;
  SmallVector<Value> loopIterationCounts;
  SmallVector<OpFoldResult> offsets(nPackedLoops + paddedRank,
                                    rewriter.getIndexAttr(0));
  if (nPackedLoops > 0) {
    loopIterationCounts =
        llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *loop) {
          return buildLoopIterationCount(rewriter, outerLoop,
                                         cast<scf::ForOp>(loop));
        }));
    // Assert all loop iteration counts can be computed.
    if (llvm::any_of(loopIterationCounts, [](Value v) { return !v; }))
      llvm_unreachable("loop independence prerequisite not met");

    // offsets = [originalLoopIterationCounts, 0 .. 0].
    std::copy(loopIterationCounts.begin(), loopIterationCounts.end(),
              offsets.begin());
    hoistedPackedTensor =
        scf::getForInductionVarOwner(packingResult.clonedLoopIvs.front())
            ->getResult(0);
  } else {
    // If no loops were created, this is just hoisting without packing.
    hoistedPackedTensor = bvm.lookup(opToHoist.getResult());
  }

  LLVM_DEBUG(DBGS() << "hoistedPackedTensor: " << hoistedPackedTensor << "\n");

  // If the consumer of `sliceOp` was a forOp, propagate through iter args.
  scf::ForOp forOp = analysis.padConsumingForOp;
  if (forOp) {
    return padThroughLoopIterArg(rewriter, opToHoist, hoistedPackedTensor,
                                 analysis.sliceOp, forOp);
  }

  // offsets = [originalLoopIterationCounts, 0 .. 0].
  // sizes = [1 .. 1, transposedShape] (as computed during packing).
  // strides = [1 .. 1] (as computed during packing).
  return rewriter.create<tensor::ExtractSliceOp>(
      loc, transposedTensorType, hoistedPackedTensor, offsets,
      packingResult.sizes, packingResult.strides);
}
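
// Illustrative example (not from the original source): hoisting a pad by one
// loop conceptually rewrites
//
//   scf.for %i = %lb to %ub step %st {
//     %s = tensor.extract_slice %t[%i] ... : tensor<?xf32> to tensor<4xf32>
//     %p = tensor.pad %s ... : tensor<4xf32> to tensor<8xf32>
//     ... = compute(%p)
//   }
//
// into a packing loop nest that materializes all padded slices up front,
// leaving only a tensor.extract_slice in the original loop:
//
//   %packed = scf.for %i ... iter_args(%arg = %empty) -> tensor<?x8xf32> {
//     %s = tensor.extract_slice %t[%i] ...
//     %p = tensor.pad %s ...
//     %ins = tensor.insert_slice %p into %arg[%iter, 0] ...
//     scf.yield %ins : tensor<?x8xf32>
//   }
//   scf.for %i = %lb to %ub step %st {
//     %p = tensor.extract_slice %packed[%iter, 0] ... to tensor<8xf32>
//     ... = compute(%p)
//   }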

FailureOr<Value> mlir::linalg::hoistPaddingOnTensors(
    RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops,
    ArrayRef<int64_t> transposeVector, tensor::PadOp &hoistedOp,
    SmallVectorImpl<TransposeOp> &transposeOps) {
  LLVM_DEBUG(DBGS() << "\n"; DBGS() << " Try to hoist " << *(opToHoist) << "\n";
             DBGS() << " by " << numLoops << " loops\n");

  HoistPaddingAnalysis analysis(opToHoist, numLoops);
  analysis.enableHoistPadding(rewriter);
  analysis.finalizeHoistPaddingAnalysis();
  if (!analysis.isValid()) {
    LLVM_DEBUG(DBGS() << "--Analysis failed -> Skip\n");
    return failure();
  }

  // Construct the packing loop nest.
  IRMapping bvm;
  FailureOr<PackingResult> packingResult = buildPackingLoopNestImpl(
      rewriter, bvm, opToHoist, transposeVector, analysis);
  if (failed(packingResult)) {
    LLVM_DEBUG(DBGS() << "--buildPackingLoopNestImpl failed -> Skip\n");
    return failure();
  }

  if (!transposeVector.empty())
    transposeOps.push_back(packingResult->maybeTransposeOp);

  FailureOr<RankedTensorType> transposedTensorType =
      tensor::computeTransposedType(opToHoist.getResultType(),
                                    transposeVector);
  assert(succeeded(transposedTensorType) && "unexpected failure in type");

  // The replacement occurs under a single insertion point within the original
  // loop, just before opToHoist.
  Value newResult =
      replaceByPackingResult(rewriter, bvm, opToHoist, *transposedTensorType,
                             analysis, *packingResult);

  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();
  if (!transposeVector.empty()) {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newResult.getDefiningOp());
    // Transpose the packed tensor back to the original storage order.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, paddedTensorType.getShape(), paddedTensorType.getElementType());
    TransposeOp unTransposeOp = rewriter.create<linalg::TransposeOp>(
        loc, newResult, emptyTensor, transposeVector);
    newResult = unTransposeOp.getResult()[0];
    transposeOps.push_back(unTransposeOp);
  }

  LLVM_DEBUG(DBGS() << "newResult: " << newResult << "\n");
  LLVM_DEBUG(
      DBGS() << "After hoisting: "
             << newResult.getDefiningOp()->getParentOfType<func::FuncOp>()
             << "\n");

  // Make the newly cloned `opToHoist` available to the caller.
  hoistedOp = packingResult->hoistedPadOp;

  LLVM_DEBUG(DBGS() << "--SUCCESS\n");
  return newResult;
}

FailureOr<Value> mlir::linalg::hoistPaddingOnTensors(
    tensor::PadOp opToHoist, int64_t numLoops,
    ArrayRef<int64_t> transposeVector, tensor::PadOp &hoistedOp,
    SmallVectorImpl<TransposeOp> &transposeOps) {
  IRRewriter rewriter(opToHoist.getContext());
  return hoistPaddingOnTensors(rewriter, opToHoist, numLoops, transposeVector,
                               hoistedOp, transposeOps);
}
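
A minimal usage sketch (illustrative, not part of this file; it assumes a tensor::PadOp named padOp has already been located, e.g. by a pattern or a transform op):

  IRRewriter rewriter(padOp.getContext());
  tensor::PadOp hoistedOp;
  SmallVector<linalg::TransposeOp> transposeOps;
  // Try to hoist the pad above (at most) two enclosing scf.for loops,
  // without transposing the packed tensor.
  FailureOr<Value> replacement = linalg::hoistPaddingOnTensors(
      rewriter, padOp, /*numLoops=*/2, /*transposeVector=*/{}, hoistedOp,
      transposeOps);
  if (succeeded(replacement))
    rewriter.replaceOp(padOp, *replacement);

Note that hoistPaddingOnTensors only returns the replacement value; the caller performs the final replacement of the original pad itself.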


static tensor::ExtractSliceOp padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting, Value hoistedPackedTensor, tensor::ExtractSliceOp outerSliceOp, scf::ForOp forOp)

If the original consumer of outerSliceOp was a forOp (i.e. through an iter arg), propagate the hoistedPackedTensor value through the same iter arg.

static Value buildLoopIterationCount(RewriterBase &rewriter, scf::ForOp outer, scf::ForOp forOp)

Return the current iteration number in the loop (iv - lb).ceilDiv(step).

static void getEnclosingLoopsUntil(tensor::PadOp padOp, scf::ForOp untilLoop, SmallVector<scf::ForOp> &reverseEnclosingLoops)

Return the immediately enclosing scf::ForOp loops, innermost first, until untilLoop is reached.

static bool debugPrintLoopInShortForm(Operation *op)

static bool tracesBackToExpectedValue(tensor::ExtractSliceOp extractSliceOp, Value expectedSource)

Return true if we can walk back the use-def chain from extractSliceOp to expectedSource going through DestinationStyleOpInterface inits only.

static bool isDefinedOutsideOrConstant(scf::ForOp outer, Value v)

static FailureOr<PackingResult> buildPackingLoopNestImpl(RewriterBase &rewriter, IRMapping &bvm, tensor::PadOp opToHoist, ArrayRef<int64_t> transposeVector, RankedTensorType transposedTensorType, tensor::EmptyOp emptyOp, const HoistPaddingAnalysis &analysis)

static void computeBackwardSlice(tensor::PadOp padOp, scf::ForOp outermostEnclosingForOp, SetVector<Operation *> &backwardSlice)

static Value replaceByPackingResult(RewriterBase &rewriter, const IRMapping &bvm, tensor::PadOp opToHoist, RankedTensorType transposedTensorType, const HoistPaddingAnalysis &analysis, const PackingResult &packingResult)

Produce a tensor extracted from the packingResult.

static void debugPrintBackwardSlice(SetVector<Operation *> &backwardSlice)

static void getAtMostNEnclosingLoops(tensor::PadOp padOp, int nLevels, SmallVector<scf::ForOp> &reverseEnclosingLoops)

Return at most nLevels of immediately enclosing scf::ForOp loops.

Base type for affine expression.

This class provides management for the lifetime of the state used when printing the IR.

This class represents an argument of a Block.

Block * getOwner() const

Returns the block that owns this argument.

Operation * getParentOp()

Returns the closest surrounding operation that contains this block.

IntegerAttr getIndexAttr(int64_t value)

MLIRContext * getContext() const

A class for computing basic dominance information.

bool dominates(Operation *a, Operation *b) const

Return true if operation A dominates operation B, i.e. if A and B are the same operation or A properly dominates B.

This is a utility class for mapping one set of IR entities to another.

auto lookupOrDefault(T from) const

Lookup a mapped value within the map.

auto lookup(T from) const

Lookup a mapped value within the map.

void map(Value from, Value to)

Inserts a new mapping for 'from' to 'to'.

IRValueT get() const

Return the current value being used by this operand.

This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

MLIRContext is the top-level object for a collection of MLIR operations.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

Operation * clone(Operation &op, IRMapping &mapper)

Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation via the provided map.

void setInsertionPointToStart(Block *block)

Sets the insertion point to the start of the specified block.

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

void setInsertionPointToEnd(Block *block)

Sets the insertion point to the end of the specified block.

void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)

Create an operation of specific op type at the current insertion point, and immediately try to fold it.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

void setInsertionPointAfter(Operation *op)

Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.

This class represents an operand of an operation.

Operation is the basic unit of execution within MLIR.

OpResult getResult(unsigned idx)

Get the 'idx'th result of this operation.

Operation * getParentOp()

Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.

OpTy getParentOfType()

Return the closest surrounding parent operation that is of type 'OpTy'.

This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.

void replaceAllUsesWith(Value from, Value to)

Find uses of from and replace them with to.

virtual void finalizeOpModification(Operation *op)

This method is used to signal the end of an in-place modification of the given operation.

virtual void startOpModification(Operation *op)

This method is used to notify the rewriter that an in-place operation modification is about to happen.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

A helper class to be used with ValueBoundsOpInterface.

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

Location getLoc() const

Return the location of this value.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

AffineForOp getForInductionVarOwner(Value val)

Returns the loop parent of an induction variable.

FailureOr<OpFoldResult> reifyIndexValueBound(OpBuilder &b, Location loc, presburger::BoundType type, Value value, ValueBoundsConstraintSet::StopConditionFn stopCondition=nullptr, bool closedUB=false)

Reify a bound for the given index-typed value in terms of SSA values for which stopCondition is met.
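
For example (illustrative): reifying presburger::BoundType::UB for a value defined by affine.min(%n, 16), with a stop condition that rejects affine op results, can produce the constant bound 16, which is independent of %n. This is how getHoistedPackedTensorSizes above obtains loop-independent sizes for the packed tensor.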

void bindDims(MLIRContext *ctx)

void bindSymbols(MLIRContext *ctx)

FailureOr<PackingResult> buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef<int64_t> transposeVector)

Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.

FailureOr<Value> hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef<int64_t> transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl<TransposeOp> &transposeOps)

Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.

FailureOr<RankedTensorType> computeTransposedType(RankedTensorType rankedTensorType, ArrayRef<int64_t> transposeVector)

Returns the transposed rankedTensorType if transposeVector is non-empty.
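
For example, transposeVector = {1, 0} maps tensor<3x4xf32> to the transposed type tensor<4x3xf32>.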

Include the generated interface declarations.

bool matchPattern(Value value, const Pattern &pattern)

Entry point for matching a pattern over a Value.

LogicalResult getBackwardSlice(Operation *op, SetVector<Operation *> *backwardSlice, const BackwardSliceOptions &options={})

Fills backwardSlice with the computed backward slice (i.e. all the transitive defs of op), without including op itself.

LoopLikeOpInterface hoistLoopInvariantSubsets(RewriterBase &rewriter, LoopLikeOpInterface loopLike)

Hoist loop-invariant tensor subsets (subset extraction and subset insertion ops) from loop-like ops.

void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector<Value> &values)

Fill values with a list of values defined at the ancestors of the limit region and used within region or its descendants.

Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)

Converts an OpFoldResult to a Value.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed; this helps unify some of the attribute construction methods.

detail::constant_op_matcher m_Constant()

Matches a constant foldable operation.

bool inclusive

Include the top level op in the slice.

Helper struct to hold the results of building a packing loop nest.

SmallVector<OpFoldResult> strides

SmallVector<Value> clonedLoopIvs

SmallVector<OpFoldResult> sizes