MLIR: lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

10

29

30 using namespace mlir;

32

33

34

35

36

37 void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns(

40 }

41

42 void transform::ApplySCFStructuralConversionPatternsOp::populatePatterns(

45 }

46

47 void transform::ApplySCFStructuralConversionPatternsOp::

48 populateConversionTargetRules(const TypeConverter &typeConverter,

51 conversionTarget);

52 }

53

54 void transform::ApplySCFToControlFlowPatternsOp::populatePatterns(

57 }

58

59

60

61

62

67 auto payload = state.getPayloadOps(getTarget());

68 if (!llvm::hasSingleElement(payload))

69 return emitSilenceableError() << "expected a single payload op";

70

71 auto target = dyn_castscf::ForallOp(*payload.begin());

72 if (!target) {

74 emitSilenceableError() << "expected the payload to be scf.forall";

75 diag.attachNote((*payload.begin())->getLoc()) << "payload op";

77 }

78

79 if (!target.getOutputs().empty()) {

80 return emitSilenceableError()

81 << "unsupported shared outputs (didn't bufferize?)";

82 }

83

85

86 if (getNumResults() != lbs.size()) {

88 emitSilenceableError()

89 << "op expects as many results (" << getNumResults()

90 << ") as payload has induction variables (" << lbs.size() << ")";

91 diag.attachNote(target.getLoc()) << "payload op";

93 }

94

98 << "failed to convert forall into for";

100 }

101

103 results.set(cast(getTransformed()[i]), {res});

104 }

106 }

107

108

109

110

111

116 auto payload = state.getPayloadOps(getTarget());

117 if (!llvm::hasSingleElement(payload))

118 return emitSilenceableError() << "expected a single payload op";

119

120 auto target = dyn_castscf::ForallOp(*payload.begin());

121 if (!target) {

123 emitSilenceableError() << "expected the payload to be scf.forall";

124 diag.attachNote((*payload.begin())->getLoc()) << "payload op";

126 }

127

128 if (!target.getOutputs().empty()) {

129 return emitSilenceableError()

130 << "unsupported shared outputs (didn't bufferize?)";

131 }

132

133 if (getNumResults() != 1) {

135 << "op expects one result, given "

136 << getNumResults();

137 diag.attachNote(target.getLoc()) << "payload op";

139 }

140

141 scf::ParallelOp opResult;

144 emitSilenceableError() << "failed to convert forall into parallel";

146 }

147

148 results.set(cast(getTransformed()[0]), {opResult});

150 }

151

152

153

154

155

156

157

158

162 return nullptr;

165 scf::ExecuteRegionOp executeRegionOp =

167 {

172 assert(clonedRegion.empty() && "expected empty region");

174 clonedRegion.end());

176 }

178 return executeRegionOp;

179 }

180

188 for (Operation *target : state.getPayloadOps(getTarget())) {

189 Location location = target->getLoc();

192 if (!exec) {

194 << "failed to outline";

195 diag.attachNote(target->getLoc()) << "target op";

197 }

198 func::CallOp call;

200 rewriter, location, exec.getRegion(), getFuncName(), &call);

201

202 if (failed(outlined))

203 return emitDefaultDefiniteFailure(target);

204

205 if (symbolTableOp) {

207 symbolTables.try_emplace(symbolTableOp, symbolTableOp)

208 .first->getSecond();

209 symbolTable.insert(*outlined);

211 }

212 functions.push_back(*outlined);

213 calls.push_back(call);

214 }

215 results.set(cast(getFunction()), functions);

216 results.set(cast(getCall()), calls);

218 }

219

220

221

222

223

226 scf::ForOp target,

229 scf::ForOp result;

230 if (getPeelFront()) {

231 LogicalResult status =

233 if (failed(status)) {

235 emitSilenceableError() << "failed to peel the first iteration";

237 }

238 } else {

239 LogicalResult status =

241 if (failed(status)) {

243 << "failed to peel the last iteration";

245 }

246 }

247

250

252 }

253

254

255

256

257

258

259

260

261 static void

263 std::vector<std::pair<Operation *, unsigned>> &schedule,

264 unsigned iterationInterval, unsigned readLatency) {

265 auto getLatency = [&](Operation *op) -> unsigned {

266 if (isavector::TransferReadOp(op))

267 return readLatency;

268 return 1;

269 };

270

271 std::optional<int64_t> ubConstant =

273 std::optional<int64_t> lbConstant =

276 std::map<unsigned, std::vector<Operation *>> wrappedSchedule;

277 for (Operation &op : forOp.getBody()->getOperations()) {

278 if (isascf::YieldOp(op))

279 continue;

280 unsigned earlyCycle = 0;

281 for (Value operand : op.getOperands()) {

282 Operation *def = operand.getDefiningOp();

283 if (!def)

284 continue;

285 if (ubConstant && lbConstant) {

286 unsigned ubInt = ubConstant.value();

287 unsigned lbInt = lbConstant.value();

288 auto minLatency = std::min(ubInt - lbInt - 1, getLatency(def));

289 earlyCycle = std::max(earlyCycle, opCycles[def] + minLatency);

290 } else {

291 earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def));

292 }

293 }

294 opCycles[&op] = earlyCycle;

295 wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);

296 }

297 for (const auto &it : wrappedSchedule) {

298 for (Operation *op : it.second) {

299 unsigned cycle = opCycles[op];

300 schedule.emplace_back(op, cycle / iterationInterval);

301 }

302 }

303 }

304

307 scf::ForOp target,

312 [this](scf::ForOp forOp,

313 std::vector<std::pair<Operation *, unsigned>> &schedule) mutable {

314 loopScheduling(forOp, schedule, getIterationInterval(),

315 getReadLatency());

316 };

319 FailureOrscf::ForOp patternResult =

321 if (succeeded(patternResult)) {

322 results.push_back(*patternResult);

324 }

325 return emitDefaultSilenceableFailure(target);

326 }

327

328

329

330

331

336 (void)target.promoteIfSingleIteration(rewriter);

338 }

339

340 void transform::LoopPromoteIfOneIterationOp::getEffects(

344 }

345

346

347

348

349

355 LogicalResult result(failure());

356 if (scf::ForOp scfFor = dyn_castscf::ForOp(op))

358 else if (AffineForOp affineFor = dyn_cast(op))

360 else

361 return emitSilenceableError()

362 << "failed to unroll, incorrect type of payload";

363

364 if (failed(result))

365 return emitSilenceableError() << "failed to unroll";

366

368 }

369

370

371

372

373

378 LogicalResult result(failure());

379 if (scf::ForOp scfFor = dyn_castscf::ForOp(op))

381 else if (AffineForOp affineFor = dyn_cast(op))

383 else

384 return emitSilenceableError()

385 << "failed to unroll and jam, incorrect type of payload";

386

387 if (failed(result))

388 return emitSilenceableError() << "failed to unroll and jam";

389

391 }

392

393

394

395

396

402 LogicalResult result(failure());

403 if (scf::ForOp scfForOp = dyn_castscf::ForOp(op))

405 else if (AffineForOp affineForOp = dyn_cast(op))

407

409 if (failed(result)) {

411 << "failed to coalesce";

413 }

415 }

416

417

418

419

420

421

424 assert(llvm::hasSingleElement(region) && "expected single-region block");

430 rewriter.eraseOp(terminator);

431 }

432

439 getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion();

440 if (!llvm::hasSingleElement(region)) {

442 << "requires an scf.if op with a single-block "

443 << ((getTakeElseBranch()) ? "`else`" : "`then`") << " region";

444 }

447 }

448

449 void transform::TakeAssumedBranchOp::getEffects(

453 }

454

455

456

457

458

459

460

461

462

463

464

467

468 if (target == source)

470 << "target and source need to be different loops";

471

472

475 << "target and source are not in the same block";

476

477

480

481

483 if (!domInfo.properlyDominates(source, user, false)) {

485 << "user of results of target should be properly dominated by "

486 "source";

487 }

488 }

489 } else {

490

491

492

493

495 Operation *operandOp = operand.getDefiningOp();

496

497

498 if (!operandOp)

499 continue;

500

501

503 false))

505 << "operands of target should be properly dominated by source";

506 }

507

508

509 bool failed = false;

510 OpOperand *failedValue = nullptr;

512 Operation *operandOp = operand->get().getDefiningOp();

513 if (operandOp && !domInfo.properlyDominates(operandOp, source,

514 false)) {

515

516

517 failed = true;

518 failedValue = operand;

519 }

520 });

521

522 if (failed)

524 << "values used inside regions of target should be properly "

525 "dominated by source";

526 }

527

529 }

530

531

532

533

534

535

538 auto targetOp = dyn_castscf::ForallOp(target);

539 auto sourceOp = dyn_castscf::ForallOp(source);

540 if (!targetOp || !sourceOp)

541 return false;

542

543 return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&

544 targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&

545 targetOp.getMixedStep() == sourceOp.getMixedStep() &&

546 targetOp.getMapping() == sourceOp.getMapping();

547 }

548

549

550

551

552

553

556 auto targetOp = dyn_castscf::ForOp(target);

557 auto sourceOp = dyn_castscf::ForOp(source);

558 if (!targetOp || !sourceOp)

559 return false;

560

561 return targetOp.getLowerBound() == sourceOp.getLowerBound() &&

562 targetOp.getUpperBound() == sourceOp.getUpperBound() &&

563 targetOp.getStep() == sourceOp.getStep();

564 }

565

570 auto targetOps = state.getPayloadOps(getTarget());

571 auto sourceOps = state.getPayloadOps(getSource());

572

573 if (!llvm::hasSingleElement(targetOps) ||

574 !llvm::hasSingleElement(sourceOps)) {

576 << "requires exactly one target handle (got "

577 << llvm::range_size(targetOps) << ") and exactly one "

578 << "source handle (got " << llvm::range_size(sourceOps) << ")";

579 }

580

581 Operation *target = *targetOps.begin();

582 Operation *source = *sourceOps.begin();

583

584

586 if (diag.succeeded())

588

590

593 castscf::ForOp(target), castscf::ForOp(source), rewriter);

596 castscf::ForallOp(target), castscf::ForallOp(source), rewriter);

597 } else {

599 << "operations cannot be fused";

600 }

601

602 assert(fusedLoop && "failed to fuse operations");

603

604 results.set(cast(getFusedLoop()), {fusedLoop});

606 }

607

608

609

610

611

612 namespace {

613 class SCFTransformDialectExtension

615 SCFTransformDialectExtension> {

616 public:

618

619 using Base::Base;

620

621 void init() {

622 declareGeneratedDialectaffine::AffineDialect();

623 declareGeneratedDialectfunc::FuncDialect();

624

625 registerTransformOps<

626 #define GET_OP_LIST

627 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"

628 >();

629 }

630 };

631 }

632

633 #define GET_OP_CLASSES

634 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"

635

637 registry.addExtensions();

638 }

static std::string diag(const llvm::Value &value)

static llvm::ManagedStatic< PassManagerOptions > options

static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)

static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)

static bool isForWithIdenticalConfiguration(Operation *target, Operation *source)

Check if target scf.for can be fused into source scf.for.

static DiagnosedSilenceableFailure isOpSibling(Operation *target, Operation *source)

Check if target and source are siblings, in the context that target is being fused into source.

static void loopScheduling(scf::ForOp forOp, std::vector< std::pair< Operation *, unsigned >> &schedule, unsigned iterationInterval, unsigned readLatency)

Callback for PipeliningOption.

static bool isForallWithIdenticalConfiguration(Operation *target, Operation *source)

Check if target scf.forall can be fused into source scf.forall.

static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op, Region &region)

Replaces the given op with the contents of the given single-block region, using the operands of the b...

static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b, Operation *op)

Wraps the given operation op into an scf.execute_region operation.

#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)

Block represents an ordered list of Operations.

Operation * getTerminator()

Get the terminator operation of this block.

This class describes a specific conversion target.

The result of a transform IR operation application.

static DiagnosedSilenceableFailure success()

Constructs a DiagnosedSilenceableFailure in the success state.

The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.

void addExtensions()

Add the given extensions to the registry.

A class for computing basic dominance information.

bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const

Return true if operation A properly dominates operation B, i.e.

static FlatSymbolRefAttr get(StringAttr value)

Construct a symbol reference for the given value name.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

RAII guard to reset the insertion point of the builder when destroyed.

void setInsertionPointToStart(Block *block)

Sets the insertion point to the start of the specified block.

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)

Creates a deep copy of this operation but keep the operation regions empty.

This class represents an operand of an operation.

Operation is the basic unit of execution within MLIR.

bool isBeforeInBlock(Operation *other)

Given an operation 'other' that is within the same parent block, return whether the current operation...

unsigned getNumRegions()

Returns the number of regions held by this operation.

Location getLoc()

The source location the operation was defined or derived from.

Block * getBlock()

Returns the operation block that contains this operation.

MutableArrayRef< Region > getRegions()

Returns the regions held by this operation.

result_type_range getResultTypes()

operand_range getOperands()

Returns an iterator on the underlying Value's.

user_range getUsers()

Returns a range of all users.

result_range getResults()

This class contains a list of basic blocks and a link to the parent operation it is attached to.

This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements...

virtual void eraseOp(Operation *op)

This method erases an operation that is known to have no uses.

void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)

Move the blocks that belong to "region" before the given position in another region "parent".

virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)

Inline the operations of block 'source' into block 'dest' before the given position.

This class allows for representing and managing the symbol table used by operations with the 'SymbolT...

StringAttr insert(Operation *symbol, Block::iterator insertPt={})

Insert a new symbol into the table, and rename it as necessary to avoid collisions.

static Operation * getNearestSymbolTable(Operation *from)

Returns the nearest symbol table from a given operation from.

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...

void push_back(Operation *op)

Appends an element to the list.

Base class for extensions of the Transform dialect that supports injecting operations into the Transf...

Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...

void set(OpResult value, Range &&ops)

Indicates that the result of the transform IR op at the given position corresponds to the given list ...

This is a special rewriter to be used in transform op implementations, providing additional helper fu...

The state maintained across applications of various ops implementing the TransformOpInterface.

LogicalResult loopUnrollByFactor(AffineForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr, bool cleanUpUnroll=false)

Unrolls this for operation by the specified unroll factor.

LogicalResult loopUnrollJamByFactor(AffineForOp forOp, uint64_t unrollJamFactor)

Unrolls and jams this loop by the specified factor.

LogicalResult coalescePerfectlyNestedAffineLoops(AffineForOp op)

Walk an affine.for to find a band to coalesce.

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

void registerTransformDialectExtension(DialectRegistry &registry)

LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)

Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...

LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp, SmallVectorImpl< Operation * > *results=nullptr)

Try converting scf.forall into a set of nested scf.for loops.

LogicalResult peelForLoopFirstIteration(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)

Peel the first iteration out of the scf.for loop.

LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp, ParallelOp *result=nullptr)

Try converting scf.forall into an scf.parallel loop.

void populateSCFStructuralTypeConversions(const TypeConverter &typeConverter, RewritePatternSet &patterns)

Similar to populateSCFStructuralTypeConversionsAndLegality but does not populate the conversion targe...

void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns)

Populate patterns for canonicalizing operations inside SCF loop bodies.

FailureOr< ForOp > pipelineForLoop(RewriterBase &rewriter, ForOp forOp, const PipeliningOption &options, bool *modifiedIR=nullptr)

Generate a pipelined version of the scf.for loop based on the schedule given as option.

void populateSCFStructuralTypeConversionTarget(const TypeConverter &typeConverter, ConversionTarget &target)

Updates the ConversionTarget with dynamic legality of SCF operations based on the provided type conve...

void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

Populates effects with the memory effects indicating the operation on the given handle value:

void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)

Populates effects with the memory effects indicating the access to payload IR resource.

Include the generated interface declarations.

std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)

If ofr is a constant integer or an IntegerAttr, return the integer.

LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op)

Walk an affine.for to find a band to coalesce.

DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})

Emits a silenceable failure with the given message.

DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})

Emits a definite failure with the given message.

const FrozenRewritePatternSet & patterns

void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)

Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...

FailureOr< func::FuncOp > outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region, StringRef funcName, func::CallOp *callOp=nullptr)

Outline a region with a single block into a new FuncOp.

scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter)

Given two scf.forall loops, target and source, fuses target into source.

scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter)

Given two scf.for loops, target and source, fuses target into source.

void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)

Calls callback for each use of a value within region or its descendants that was defined at the ances...

Options to dictate how loops should be pipelined.