MLIR: lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

10

16 #include "llvm/ADT/SetOperations.h"

17 #include "llvm/ADT/SetVector.h"

18

19 using namespace mlir;

21

22

23

24

25

26

28

33 queue.push_back(value);

34 while (!queue.empty()) {

35 Value currentValue = queue.pop_back_val();

36 if (result.insert(currentValue).second) {

37 auto it = map.find(currentValue);

38 if (it != map.end()) {

39 for (Value aliasValue : it->second)

40 queue.push_back(aliasValue);

41 }

42 }

43 }

44 return result;

45 }

46

47

48

49

53 }

54

57 return resolveValues(reverseDependencies, rootValue);

58 }

59

60

62 for (auto &entry : dependencies)

63 llvm::set_subtract(entry.second, aliasValues);

64 }

65

67 dependencies[to] = dependencies[from];

68 dependencies.erase(from);

69

70 for (auto &[_, value] : dependencies) {

71 if (value.contains(from)) {

72 value.insert(to);

73 value.erase(from);

74 }

75 }

76 }

77

78

79

80

81

82

83 void BufferViewFlowAnalysis::build(Operation *op) {

84

86 for (auto [value, dep] : llvm::zip_equal(values, dependencies)) {

87 this->dependencies[value].insert(dep);

88 this->reverseDependencies[dep].insert(value);

89 }

90 };

91

92

93

94 auto populateTerminalValues = [&](Operation *op) {

96 if (isa(v.getType()))

97 this->terminals.insert(v);

100 if (isa(v.getType()))

101 this->terminals.insert(v);

102 };

103

105

106

107

108 if (auto bufferViewFlowOp = dyn_cast(op)) {

109 bufferViewFlowOp.populateDependencies(registerDependencies);

111 if (isa(v.getType()) &&

112 bufferViewFlowOp.mayBeTerminalBuffer(v))

113 this->terminals.insert(v);

116 if (isa(v.getType()) &&

117 bufferViewFlowOp.mayBeTerminalBuffer(v))

118 this->terminals.insert(v);

120 }

121

122

123 if (auto viewInterface = dyn_cast(op)) {

124 registerDependencies(viewInterface.getViewSource(),

125 viewInterface->getResult(0));

127 }

128

129 if (auto branchInterface = dyn_cast(op)) {

130

131 Block *parentBlock = branchInterface->getBlock();

132 for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();

133 it != e; ++it) {

134

135 auto successorOperands =

136 branchInterface.getSuccessorOperands(it.getIndex());

137

138 registerDependencies(successorOperands.getForwardedOperands(),

139 (*it)->getArguments().drop_front(

140 successorOperands.getProducedOperandCount()));

141 }

143 }

144

145 if (auto regionInterface = dyn_cast(op)) {

146

147

150 entrySuccessors);

151 for (RegionSuccessor &entrySuccessor : entrySuccessors) {

152

153

154 registerDependencies(

155 regionInterface.getEntrySuccessorOperands(entrySuccessor),

156 entrySuccessor.getSuccessorInputs());

157 }

158

159

160 for (Region &region : regionInterface->getRegions()) {

161

162

164 regionInterface.getSuccessorRegions(region, successorRegions);

165 for (RegionSuccessor &successorRegion : successorRegions) {

166

167

168 for (Block &block : region)

169 if (auto terminator = dyn_cast(

170 block.getTerminator()))

171 registerDependencies(

172 terminator.getSuccessorOperands(successorRegion),

173 successorRegion.getSuccessorInputs());

174 }

175 }

176

178 }

179

180

181 if (isa(op))

183

184 if (isa(op)) {

185

186

187

188

189 populateTerminalValues(op);

192 registerDependencies({operand}, {result});

194 }

195

196

197 populateTerminalValues(op);

198

200 });

201 }

202

204 assert(isa(value.getType()) && "expected memref");

205 return terminals.contains(value);

206 }

207

208

209

210

211

212

215 if (!op)

216 return false;

217 return hasEffectMemoryEffects::Allocate(op, v);

218 }

219

220

222 auto bbArg = dyn_cast(v);

223 if (!bbArg)

224 return false;

225 Block *b = bbArg.getOwner();

226 auto funcOp = dyn_cast(b->getParentOp());

227 if (!funcOp)

228 return false;

229 return bbArg.getOwner() == &funcOp.getFunctionBody().front();

230 }

231

232

233

235 while (auto viewLikeOp = value.getDefiningOp())

236 value = viewLikeOp.getViewSource();

237 return value;

238 }

239

241

243 assert(isa(v1.getType()) && "expected buffer");

244 assert(isa(v2.getType()) && "expected buffer");

245

246

249

250

251

252 if (v1 == v2)

253 return true;

254

255

258

259

260

261

262

263

264

266

267

268

269 bool allAllocs1 = true, allAllocs2 = true;

270 bool allAllocsOrFuncEntryArgs1 = true, allAllocsOrFuncEntryArgs2 = true;

271

272

275 bool &allAllocs,

276 bool &allAllocsOrFuncEntryArgs) {

277 for (Value v : origin) {

278 if (isa(v.getType()) && analysis.mayBeTerminalBuffer(v)) {

279 terminal.insert(v);

281 allAllocsOrFuncEntryArgs &=

283 }

284 }

285 assert(!terminal.empty() && "expected non-empty terminal set");

286 };

287

288

289 gatherTerminalBuffers(origin1, terminal1, allAllocs1,

290 allAllocsOrFuncEntryArgs1);

291 gatherTerminalBuffers(origin2, terminal2, allAllocs2,

292 allAllocsOrFuncEntryArgs2);

293

294

295

296 if (llvm::hasSingleElement(terminal1) && llvm::hasSingleElement(terminal2) &&

297 *terminal1.begin() == *terminal2.begin())

298 return true;

299

300

301

302

303 bool distinctTerminalSets = true;

304 for (Value v : terminal1)

305 distinctTerminalSets &= !terminal2.contains(v);

306

307

308 if (!distinctTerminalSets)

309 return std::nullopt;

310

311

312

313

314

315

316 bool isolatedAlloc1 = allAllocs1 && (allAllocs2 || allAllocsOrFuncEntryArgs2);

317 bool isolatedAlloc2 = (allAllocs1 || allAllocsOrFuncEntryArgs1) && allAllocs2;

318 if (isolatedAlloc1 || isolatedAlloc2)

319 return false;

320

321

322

323

324

325

326

327

328

329 return std::nullopt;

330 }

static bool isFunctionArgument(Value v)

Return "true" if the given value is a function block argument.

static Value getViewBase(Value value)

Given a memref value, return the "base" value by skipping over all ViewLikeOpInterface ops (if any) i...

static BufferViewFlowAnalysis::ValueSetT resolveValues(const BufferViewFlowAnalysis::ValueMapT &map, Value value)

static bool hasAllocateSideEffect(Value v)

Return "true" if the given value is the result of a memory allocation.

This class represents an argument of a Block.

Block represents an ordered list of Operations.

succ_iterator succ_begin()

Operation * getParentOp()

Returns the closest surrounding operation that contains this block.

BufferOriginAnalysis(Operation *op)

std::optional< bool > isSameAllocation(Value v1, Value v2)

Return "true" if v1 and v2 originate from the same buffer allocation.

BufferViewFlowAnalysis(Operation *op)

Constructs a new alias analysis using the op provided.

void remove(const SetVector< Value > &aliasValues)

Removes the given values from all alias sets.

ValueSetT resolve(Value value) const

Find all immediate and indirect views upon this value.

void rename(Value from, Value to)

Replaces all occurrences of 'from' in the internal datastructures with 'to'.

bool mayBeTerminalBuffer(Value value) const

Returns "true" if the given value may be a terminal.

ValueSetT resolveReverse(Value value) const

Operation is the basic unit of execution within MLIR.

std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)

Walk the operation by calling the callback for each nested operation (including this one),...

MutableArrayRef< Region > getRegions()

Returns the regions held by this operation.

operand_range getOperands()

Returns an iterator on the underlying Value's.

result_range getResults()

static constexpr RegionBranchPoint parent()

Returns an instance of RegionBranchPoint representing the parent operation.

This class represents a successor of a region.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

Operation * getDefiningOp() const

If this value is the result of an operation, return the operation that defines it.

static WalkResult advance()

Include the generated interface declarations.