MLIR: lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

18#include

19

20namespace mlir {

21

24 for (Block &b : funcOp.getBody())

25 if (auto returnOp = dyn_castfunc::ReturnOp(b.getTerminator()))

26 result.push_back(returnOp);

28}

29

31namespace func_ext {

32

36 auto createdAliasingResults =

40 (void)createdEquiv;

41 (void)createdAliasingResults;

42 (void)createdRead;

43 (void)createdWritten;

44#ifndef NDEBUG

45 assert(createdEquiv.second && "equivalence info exists already");

46 assert(createdAliasingResults.second && "aliasing info exists already");

47 assert(createdRead.second && "bbarg access info exists already");

48 assert(createdWritten.second && "bbarg access info exists already");

49#endif

50}

51

52

53

56 TensorLikeType type) {

57 if (auto tensorType = dyn_cast(type)) {

58 return *options.defaultMemorySpaceFn(tensorType);

59 }

60 return nullptr;

61}

62

63

64

65

66static BufferLikeType

69 auto type =

70 dyn_cast(funcOp.getFunctionType().getInput(index));

71 assert(type && "expected TensorLikeType");

72

73

74 if (auto tensorType = dyn_cast(type)) {

75 BufferLikeType memrefType = options.functionArgTypeConverterFn(

76 type, *options.defaultMemorySpaceFn(tensorType), funcOp, options);

77

78 auto layoutAttr = funcOp.getArgAttrOfType(

79 index, BufferizationDialect::kBufferLayoutAttrName);

80 if (!layoutAttr)

81 return memrefType;

82

83 auto rankedMemrefType = dyn_cast(memrefType);

84 assert(rankedMemrefType &&

85 "buffer layout not supported on unranked tensors");

86 return cast(MemRefType::get(

87 rankedMemrefType.getShape(), rankedMemrefType.getElementType(),

88 layoutAttr, rankedMemrefType.getMemorySpace()));

89 }

90

91 return options.functionArgTypeConverterFn(type, nullptr, funcOp,

93}

94

95

98 return dyn_cast_or_null(callOp.resolveCallableInTable(&symbolTables));

99}

100

101

104 auto &oneShotAnalysisState = static_cast<const OneShotAnalysisState &>(state);

105

106 if (auto *funcAnalysisState =

108

109 return getCalledFunction(callOp, funcAnalysisState->symbolTables);

110 }

111

114}

115

116

117static const FuncAnalysisState &

119 assert(isa(state) && "expected OneShotAnalysisState");

121 .getExtension();

122 assert(result && "FuncAnalysisState does not exist");

124}

125

126

128 FuncOp funcOp) {

129 if (!isa(state))

132 .getExtension();

133 if (!funcState)

135 const auto &analyzedFuncOps = funcState->analyzedFuncOps;

136 auto it = analyzedFuncOps.find(funcOp);

137 if (it == analyzedFuncOps.end())

139 return it->second;

140}

141

142

143

144static std::optional<int64_t>

149

150 return std::nullopt;

151

152 auto retValIt = funcOpIt->getSecond().find(returnValIdx);

153 if (retValIt == funcOpIt->getSecond().end())

154

155 return std::nullopt;

156

157 return retValIt->getSecond();

158}

159

161 : public BufferizableOpInterface::ExternalModel<CallOpInterface,

162 func::CallOp> {

165 func::CallOp callOp = castfunc::CallOp(op);

167 assert(funcOp && "expected CallOp to a FuncOp");

168

170

171 return true;

172

174 return funcState.readBbArgs.lookup(funcOp).contains(

176 }

177

180 func::CallOp callOp = castfunc::CallOp(op);

182 assert(funcOp && "expected CallOp to a FuncOp");

183

185

186 return true;

187

189 return funcState.writtenBbArgs.lookup(funcOp).contains(

191 }

192

195 func::CallOp callOp = castfunc::CallOp(op);

197 assert(funcOp && "expected CallOp to a FuncOp");

199

200 return detail::unknownGetAliasingValues(opOperand);

201

202

204 auto aliasingReturnVals =

207

208

209 std::optional<int64_t> equivalent = {};

210 if (aliasingReturnVals.size() == 1) {

212 aliasingReturnVals.front());

213 assert((!equivalent.has_value() ||

215 "inconsistent analysis state");

216 }

217 AliasingValueList result;

218 for (int64_t resultIdx : aliasingReturnVals)

219 result.addAlias({callOp->getOpResult(resultIdx),

220 equivalent.has_value() ? BufferRelation::Equivalent

221 : BufferRelation::Unknown,

222 equivalent.has_value()});

224 }

225

226 FailureOr

228 const BufferizationState &state,

230 auto callOp = castfunc::CallOp(op);

231

232

234

236 assert(funcOp && "expected CallOp to a FuncOp");

237

238

239

240 FunctionType funcType = funcOp.getFunctionType();

241 Type resultType =

242 funcType.getResult(cast(value).getResultNumber());

243 if (auto bufferizedType = dyn_cast(resultType))

244 return bufferizedType;

245

246

247 auto tensorType = cast(resultType);

248 return cast(options.functionArgTypeConverterFn(

251 }

252

253

254

257 BufferizationState &state) const {

258 func::CallOp callOp = castfunc::CallOp(op);

259

260

262 for (Value result : callOp.getResults()) {

264 if (!isa(returnType)) {

265

266 resultTypes.push_back(returnType);

267 continue;

268 }

269

270

271 FailureOr resultType =

272 bufferization::getBufferType(result, options, state);

273 if (failed(resultType))

274 return failure();

275 resultTypes.push_back(*resultType);

276 }

277

278

279

281

282 FuncOp funcOp = getCalledFunction(callOp, state.getSymbolTables());

283 assert(funcOp && "expected CallOp to a FuncOp");

284 FunctionType funcType = funcOp.getFunctionType();

285

286 for (OpOperand &opOperand : callOp->getOpOperands()) {

287

288 if (!isa(opOperand.get().getType())) {

289 newOperands.push_back(opOperand.get());

290 continue;

291 }

292

293

294 FailureOr maybeBuffer =

295 getBuffer(rewriter, opOperand.get(), options, state);

296 if (failed(maybeBuffer))

297 return failure();

298 Value buffer = *maybeBuffer;

299

300

301 auto bufferType = funcType.getInput(opOperand.getOperandNumber());

302 if (!isa(bufferType)) {

303

304

305

306 FailureOr maybeBufferType =

307 bufferization::getBufferType(

308 funcOp.getArgument(opOperand.getOperandNumber()), options,

309 state);

310 if (failed(maybeBufferType))

311 return failure();

312 bufferType = *maybeBufferType;

313 }

314

315

316

317

318

319

320

321 if (buffer.getType() != bufferType) {

322 auto memrefDstType = dyn_cast(bufferType);

323 assert(memrefDstType &&

324 "buffer layout not supported on unranked tensors");

326 rewriter, buffer, memrefDstType, options);

328 return failure();

330 }

331 newOperands.push_back(buffer);

332 }

333

334

336 func::CallOp::create(rewriter, callOp.getLoc(), funcOp.getSymName(),

337 resultTypes, newOperands);

338 newCallOp->setAttrs(callOp->getAttrs());

339

340

341 replaceOpWithBufferizedValues(rewriter, callOp, newCallOp->getResults());

342

344 }

345};

346

348 : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,

349 func::ReturnOp> {

354

359

364

367 BufferizationState &state) const {

368#ifndef NDEBUG

369 auto returnOp = castfunc::ReturnOp(op);

370 assert(isa(returnOp->getParentOp()) &&

371 "only support FuncOp parent for ReturnOp");

372#endif

373

374

376 }

377};

378

381 FuncOpInterface, FuncOp> {

382

384

386 auto isaTensor = llvm::IsaPred;

387

388

389 auto funcOp = cast(op);

390 bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);

391 bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);

392 if (hasTensorArg || hasTensorResult)

393 return true;

394

395

396

397

398

399

400 for (Block &block : funcOp.getBody())

401 if (any_of(block.getArgumentTypes(), isaTensor))

402 return true;

403

404 return false;

405 }

406

407 AliasingOpOperandList

412

413 FailureOr

415 const BufferizationState &state,

417 auto funcOp = cast(op);

418 auto bbArg = cast(value);

419

420

421 if (bbArg.getOwner() == &funcOp.getBody().front())

424

427 }

428

429

430

431

432

433

434

435

438 BufferizationState &state) const {

439 auto funcOp = cast(op);

440 FunctionType funcType = funcOp.getFunctionType();

441

442

444 for (const auto &it : llvm::enumerate(funcType.getInputs())) {

445 Type argType = it.value();

446 if (isa(argType)) {

447 argTypes.push_back(

449 continue;

450 }

451 argTypes.push_back(argType);

452 }

453

454

456 for (Type resultType : funcType.getResults()) {

457 if (auto tensorType = dyn_cast(resultType)) {

458 BufferLikeType resultType = options.functionArgTypeConverterFn(

461 retTypes.push_back(resultType);

462 continue;

463 }

464 retTypes.push_back(resultType);

465 }

466

467

468 auto newFuncType = FunctionType::get(op->getContext(), argTypes, retTypes);

469

470

471 if (funcOp.isExternal()) {

472 funcOp.setType(newFuncType);

474 }

475

476

477 for (Block &block : funcOp.getBody())

480 return failure();

481

482

483 for (func::ReturnOp returnOp : getReturnOps(funcOp)) {

484 assert(returnOp->getNumOperands() == retTypes.size() &&

485 "incorrect number of return values");

487 for (auto [returnVal, bufferizedType] :

488 llvm::zip_equal(returnOp->getOperands(), retTypes)) {

489 auto tensorType = dyn_cast(returnVal.getType());

491

492

493 if (!tensorType) {

494 returnValues.push_back(returnVal);

495 continue;

496 }

497

498

499

500 Value toBufferOp = bufferization::ToBufferOp::create(

501 rewriter, returnOp.getLoc(), bufferizedType, returnVal);

502 returnValues.push_back(toBufferOp);

503 }

504

505 returnOp.getOperandsMutable().assign(returnValues);

506 }

507

508

509 funcOp.setType(newFuncType);

511 }

512

513

516 auto funcOp = cast(op);

517 BlockArgument bbArg = dyn_cast(value);

518 assert(bbArg && "expected BlockArgument");

519

520

521

522 if (bbArg.getOwner() != &funcOp.getBody().front())

523 return true;

524

525

526

528 bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))

529 return writable.getValue();

530

531

532 return true;

533 }

534};

535

536}

537}

538}

539

543 func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);

544 func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);

545 func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);

546 });

547}

b

Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...

…if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be inserted. The output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`…

static bool isaTensor(Type t)

static llvm::ManagedStatic< PassManagerOptions > options

Base class for generic analysis states.

Attributes are known-constant values of operations.

This class represents an argument of a Block.

unsigned getArgNumber() const

Returns the number of this argument.

Block * getOwner() const

Returns the block that owns this argument.

Block represents an ordered list of Operations.

Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.

The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.

bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)

Add the given extension to the registry.

MLIRContext is the top-level object for a collection of MLIR operations.

void setInsertionPoint(Block *block, Block::iterator insertPoint)

Set the insertion point to the specified location.

This class represents an operand of an operation.

unsigned getOperandNumber()

Return which operand this is in the OpOperand list of the Operation.

Operation is the basic unit of execution within MLIR.

void setAttrs(DictionaryAttr newAttrs)

Set the attributes from a dictionary on this operation.

result_range getResults()

MLIRContext * getContext()

Return the context this operation is associated with.

This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...

This class represents a collection of SymbolTables.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

State for analysis-enabled bufferization.

static BufferLikeType getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, const BufferizationOptions &options)

Return the index-th bufferized function argument type.

Definition FuncBufferizableOpInterfaceImpl.cpp:67

FuncOpAnalysisState

The state of analysis of a FuncOp.

void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)

Definition FuncBufferizableOpInterfaceImpl.cpp:541

static FuncOp getCalledFunction(CallOpInterface callOp, SymbolTableCollection &symbolTables)

Return the FuncOp called by callOp.

Definition FuncBufferizableOpInterfaceImpl.cpp:96

static std::optional< int64_t > getEquivalentFuncArgIdx(FuncOp funcOp, const FuncAnalysisState &state, int64_t returnValIdx)

Return the index of the bbArg in the given FuncOp that is equivalent to the specified return value (i...

Definition FuncBufferizableOpInterfaceImpl.cpp:145

static mlir::Attribute getDefaultMemorySpace(const BufferizationOptions &options, TensorLikeType type)

Definition FuncBufferizableOpInterfaceImpl.cpp:55

static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state, FuncOp funcOp)

Return the state (phase) of analysis of the FuncOp.

Definition FuncBufferizableOpInterfaceImpl.cpp:127

static const FuncAnalysisState & getFuncAnalysisState(const AnalysisState &state)

Get FuncAnalysisState.

Definition FuncBufferizableOpInterfaceImpl.cpp:118

FailureOr< Value > castOrReallocMemRefValue(OpBuilder &b, Value value, MemRefType type, const BufferizationOptions &options)

Try to cast the given ranked MemRef-typed value to the given ranked MemRef type.

SmallVector< func::ReturnOp > getReturnOps(func::FuncOp funcOp)

Helper function that returns all func.return ops in the given function.

Definition FuncBufferizableOpInterfaceImpl.cpp:22

LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state)

Bufferize the signature of block and its callers (i.e., ops that have the given block as a successor)...

Include the generated interface declarations.

A template that provides a default implementation of getAliasingOpOperands for ops that support unstr...

AliasingOpOperandList getAliasingBranchOpOperands(Operation *op, BlockArgument bbArg, const AnalysisState &state) const

FailureOr< BufferLikeType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, const BufferizationState &state, SmallVector< Value > &invocationStack) const

FailureOr< BufferLikeType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, const BufferizationState &state, SmallVector< Value > &invocationStack) const

Definition FuncBufferizableOpInterfaceImpl.cpp:227

AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const

Definition FuncBufferizableOpInterfaceImpl.cpp:193

LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state) const

All function arguments are writable.

Definition FuncBufferizableOpInterfaceImpl.cpp:255

bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const

Definition FuncBufferizableOpInterfaceImpl.cpp:178

bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const

Definition FuncBufferizableOpInterfaceImpl.cpp:163

Extra analysis state that is required for bufferization of function boundaries.

DenseMap< FuncOp, IndexMapping > equivalentFuncArgs

A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg indices.

DenseMap< int64_t, SmallVector< int64_t > > IndexToIndexListMapping

A mapping of indices to a list of indices.

DenseMap< FuncOp, IndexToIndexListMapping > aliasingReturnVals

A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.

DenseMap< FuncOp, BbArgIndexSet > readBbArgs

A set of all read BlockArguments of FuncOps.

DenseSet< int64_t > BbArgIndexSet

A set of block argument indices.

DenseMap< FuncOp, BbArgIndexSet > writtenBbArgs

A set of all written-to BlockArguments of FuncOps.

DenseMap< FuncOp, FuncOpAnalysisState > analyzedFuncOps

Keep track of which FuncOps are fully analyzed or currently being analyzed.

void startFunctionAnalysis(FuncOp funcOp)

This function is called right before analyzing the given FuncOp.

Definition FuncBufferizableOpInterfaceImpl.cpp:33

DenseMap< int64_t, int64_t > IndexMapping

A mapping of indices to indices.

FailureOr< BufferLikeType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, const BufferizationState &state, SmallVector< Value > &invocationStack) const

Definition FuncBufferizableOpInterfaceImpl.cpp:414

AliasingOpOperandList getAliasingOpOperands(Operation *op, Value value, const AnalysisState &state) const

Definition FuncBufferizableOpInterfaceImpl.cpp:408

bool isWritable(Operation *op, Value value, const AnalysisState &state) const

Return true if the given function argument is writable.

Definition FuncBufferizableOpInterfaceImpl.cpp:514

static bool supportsUnstructuredControlFlow()

Definition FuncBufferizableOpInterfaceImpl.cpp:383

bool hasTensorSemantics(Operation *op) const

Definition FuncBufferizableOpInterfaceImpl.cpp:385

LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state) const

Rewrite function bbArgs and return values into buffer form.

Definition FuncBufferizableOpInterfaceImpl.cpp:436

AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const

Definition FuncBufferizableOpInterfaceImpl.cpp:360

bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const

Definition FuncBufferizableOpInterfaceImpl.cpp:355

bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const

Definition FuncBufferizableOpInterfaceImpl.cpp:350

LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state) const

Definition FuncBufferizableOpInterfaceImpl.cpp:365