MLIR: lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

11

16

17 namespace mlir {

18 namespace bufferization {

19 #define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMSPASS

20 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"

21 }

22 }

23

24 using namespace mlir;

27

28

30 int64_t offset;

32 if (failed(type.getStridesAndOffset(strides, offset)))

33 return false;

34 if (!llvm::all_of(strides, ShapedType::isDynamic))

35 return false;

36 if (!ShapedType::isDynamic(offset))

37 return false;

38 return true;

39 }

40

41

42

44 return type.getLayout().isIdentity();

45 }

46

47

48

49

50

51

52 static LogicalResult

55 bool addResultAttribute) {

56 auto functionType = func.getFunctionType();

57

58

60 BitVector erasedResultIndices(functionType.getNumResults());

61 for (const auto &resultType : llvm::enumerate(functionType.getResults())) {

62 if (auto memrefType = dyn_cast(resultType.value())) {

65

66

67

68 return func->emitError()

69 << "cannot create out param for result with unsupported layout";

70 }

71 erasedResultIndices.set(resultType.index());

72 erasedResultTypes.push_back(memrefType);

73 }

74 }

75

76

77 auto newArgTypes = llvm::to_vector<6>(

78 llvm::concat(functionType.getInputs(), erasedResultTypes));

79 auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes,

80 functionType.getResults());

81 func.setType(newFunctionType);

82

83

84 auto erasedIndicesIt = erasedResultIndices.set_bits_begin();

85 for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {

86 func.setArgAttrs(functionType.getNumInputs() + i,

87 func.getResultAttrs(*erasedIndicesIt));

88 if (addResultAttribute)

89 func.setArgAttr(functionType.getNumInputs() + i,

92 }

93

94

95 if (failed(func.eraseResults(erasedResultIndices)))

96 return failure();

97

98

99 if (func.isExternal())

100 return success();

101 Location loc = func.getLoc();

102 for (Type type : erasedResultTypes)

103 appendedEntryArgs.push_back(func.front().addArgument(type, loc));

104

105 return success();

106 }

107

108

109

110

111 static LogicalResult

114 auto res = func.walk([&](func::ReturnOp op) {

117 for (Value operand : op.getOperands()) {

118 if (isa(operand.getType()))

119 copyIntoOutParams.push_back(operand);

120 else

121 keepAsReturnOperands.push_back(operand);

122 }

124 for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {

125 if (options.hoistStaticAllocs &&

126 isa_and_nonnullbufferization::AllocationOpInterface(

127 orig.getDefiningOp()) &&

128 mlir::cast(orig.getType()).hasStaticShape()) {

129 orig.replaceAllUsesWith(arg);

130 orig.getDefiningOp()->erase();

131 } else {

132 if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg)))

133 return WalkResult::interrupt();

134 }

135 }

136 builder.createfunc::ReturnOp(op.getLoc(), keepAsReturnOperands);

137 op.erase();

139 });

140 return failure(res.wasInterrupted());

141 }

142

143

144

145 static LogicalResult

148 bool didFail = false;

150 module.walk([&](func::CallOp op) {

151 auto callee = symtab.lookupfunc::FuncOp(op.getCallee());

152 if (!callee) {

153 op.emitError() << "cannot find callee '" << op.getCallee() << "' in "

154 << "symbol table";

155 didFail = true;

156 return;

157 }

158 if (options.filterFn(&callee))

159 return;

162 for (OpResult result : op.getResults()) {

163 if (isa(result.getType()))

164 replaceWithOutParams.push_back(result);

165 else

166 replaceWithNewCallResults.push_back(result);

167 }

170 for (Value memref : replaceWithOutParams) {

171 if (!cast(memref.getType()).hasStaticShape()) {

172 op.emitError()

173 << "cannot create out param for dynamically shaped result";

174 didFail = true;

175 return;

176 }

177 auto memrefType = cast(memref.getType());

178 auto allocType =

179 MemRefType::get(memrefType.getShape(), memrefType.getElementType(),

180 AffineMap(), memrefType.getMemorySpace());

181 auto maybeOutParam =

182 options.allocationFn(builder, op.getLoc(), allocType);

183 if (failed(maybeOutParam)) {

184 op.emitError() << "failed to create allocation op";

185 didFail = true;

186 return;

187 }

188 Value outParam = maybeOutParam.value();

190

192 "layout map not supported");

193 outParam =

194 builder.creatememref::CastOp(op.getLoc(), memrefType, outParam);

195 }

197 outParams.push_back(outParam);

198 }

199

200 auto newOperands = llvm::to_vector<6>(op.getOperands());

201 newOperands.append(outParams.begin(), outParams.end());

202 auto newResultTypes = llvm::to_vector<6>(llvm::map_range(

203 replaceWithNewCallResults, [](Value v) { return v.getType(); }));

204 auto newCall = builder.createfunc::CallOp(op.getLoc(), op.getCalleeAttr(),

205 newResultTypes, newOperands);

206 for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))

208 op.erase();

209 });

210

211 return failure(didFail);

212 }

213

215 ModuleOp module,

217 for (auto func : module.getOpsfunc::FuncOp()) {

218 if (options.filterFn(&func))

219 continue;

221 if (failed(

223 return failure();

224 if (func.isExternal())

225 continue;

227 return failure();

228 }

229 }

231 return failure();

232 return success();

233 }

234

235 namespace {

236 struct BufferResultsToOutParamsPass

237 : bufferization::impl::BufferResultsToOutParamsPassBase<

238 BufferResultsToOutParamsPass> {

239 using Base::Base;

240

241 void runOnOperation() override {

242

243 if (addResultAttribute)

244 options.addResultAttribute = true;

245 if (hoistStaticAllocs)

246 options.hoistStaticAllocs = true;

247

250 return signalPassFailure();

251 }

252

253 private:

255 };

256 }

static LogicalResult updateReturnOps(func::FuncOp func, ArrayRef< BlockArgument > appendedEntryArgs, const bufferization::BufferResultsToOutParamsOpts &options)

bufferization::BufferResultsToOutParamsOpts::AllocationFn AllocationFn

bufferization::BufferResultsToOutParamsOpts::MemCpyFn MemCpyFn

static LogicalResult updateFuncOp(func::FuncOp func, SmallVectorImpl< BlockArgument > &appendedEntryArgs, bool addResultAttribute)

static bool hasStaticIdentityLayout(MemRefType type)

Return true if the given MemRef type has a static identity layout (i.e., no layout).

static LogicalResult updateCalls(ModuleOp module, const bufferization::BufferResultsToOutParamsOpts &options)

static bool hasFullyDynamicLayoutMap(MemRefType type)

Return true if the given MemRef type has a fully dynamic layout.

static llvm::ManagedStatic< PassManagerOptions > options

A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...

This class helps build Operations.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

This is a value defined by a result of an operation.

void replaceAllUsesWith(ValuesT &&values)

Replace all uses of results of this operation with the provided 'values'.

This class allows for representing and managing the symbol table used by operations with the 'SymbolT...

Operation * lookup(StringRef name) const

Look up a symbol with the specified name, returning null if no such name exists.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

static WalkResult advance()

LogicalResult promoteBufferResultsToOutParams(ModuleOp module, const BufferResultsToOutParamsOpts &options)

Replace buffers that are returned from a function with an out parameter.

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

Include the generated interface declarations.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

std::function< LogicalResult(OpBuilder &, Location, Value, Value)> MemCpyFn

Memcpy function: Generate a memcpy between two memrefs.

std::function< FailureOr< Value >(OpBuilder &, Location, MemRefType)> AllocationFn

Allocator function: Generate a memref allocation with the given type.