MLIR: lib/Interfaces/ControlFlowInterfaces.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9 #include

10

13 #include "llvm/ADT/SmallPtrSet.h"

14

15 using namespace mlir;

16

17

18

19

20

21 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"

22

24 : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {

25 }

26

29 : producedOperandCount(producedOperandCount),

30 forwardedOperands(std::move(forwardedOperands)) {}

31

32

33

34

35

36

37

38

39 std::optional

41 unsigned operandIndex, Block *successor) {

43

44 if (forwardedOperands.empty())

45 return std::nullopt;

46

47

49 if (operandIndex < operandsStart ||

50 operandIndex >= (operandsStart + forwardedOperands.size()))

51 return std::nullopt;

52

53

54 unsigned argIndex =

57 }

58

59

60 LogicalResult

63

64 unsigned operandCount = operands.size();

67 return op->emitError() << "branch has " << operandCount

68 << " operands for successor #" << succNo

69 << ", but target block has "

71

72

74 ++i) {

75 if (!cast(op).areTypesCompatible(

77 return op->emitError() << "type mismatch for bb argument #" << i

78 << " of successor #" << succNo;

79 }

80 return success();

81 }

82

83

84

85

86

90 diag << "from ";

92 diag << "Region #" << region->getRegionNumber();

93 else

94 diag << "parent operands";

95

96 diag << " to ";

98 diag << "Region #" << region->getRegionNumber();

99 else

100 diag << "parent results";

102 }

103

104

105

106

107 static LogicalResult

110 getInputsTypesForRegion) {

111 auto regionInterface = cast(op);

112

114 regionInterface.getSuccessorRegions(sourcePoint, successors);

115

117 FailureOr sourceTypes = getInputsTypesForRegion(succ);

118 if (failed(sourceTypes))

119 return failure();

120

121 TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();

122 if (sourceTypes->size() != succInputsTypes.size()) {

125 << ": source has " << sourceTypes->size()

126 << " operands, but target successor needs "

127 << succInputsTypes.size();

128 }

129

130 for (const auto &typesIdx :

131 llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {

132 Type sourceType = std::get<0>(typesIdx.value());

133 Type inputType = std::get<1>(typesIdx.value());

134 if (!regionInterface.areTypesCompatible(sourceType, inputType)) {

137 << ": source type #" << typesIdx.index() << " " << sourceType

138 << " should match input type #" << typesIdx.index() << " "

139 << inputType;

140 }

141 }

142 }

143 return success();

144 }

145

146

148 auto regionInterface = cast(op);

149

151 return regionInterface.getEntrySuccessorOperands(point).getTypes();

152 };

153

154

156 inputTypesFromParent)))

157 return failure();

158

160 if (lhs.size() != rhs.size())

161 return false;

162 for (auto types : llvm::zip(lhs, rhs)) {

163 if (!regionInterface.areTypesCompatible(std::get<0>(types),

164 std::get<1>(types))) {

165 return false;

166 }

167 }

168 return true;

169 };

170

171

173

174

175

176

177

179 for (Block &block : region)

180 if (!block.empty())

181 if (auto terminator =

182 dyn_cast(block.back()))

183 regionReturnOps.push_back(terminator);

184

185

186

187 if (regionReturnOps.empty())

188 continue;

189

190 auto inputTypesForRegion =

192 std::optional regionReturnOperands;

193 for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {

194 auto terminatorOperands = regionReturnOp.getSuccessorOperands(point);

195

196 if (!regionReturnOperands) {

197 regionReturnOperands = terminatorOperands;

198 continue;

199 }

200

201

202

203 if (!areTypesCompatible(regionReturnOperands->getTypes(),

204 terminatorOperands.getTypes())) {

207 << " operands mismatch between return-like terminators";

208 }

209 }

210

211

212 return TypeRange(regionReturnOperands->getTypes());

213 };

214

216 return failure();

217 }

218

219 return success();

220 }

221

222

223

224

225

227

228

229

230

231

234 auto op = cast(begin->getParentOp());

237

238

240 auto enqueueAllSuccessors = [&](Region *region) {

242 op.getSuccessorRegions(region, successors);

244 if (!successor.isParent())

245 worklist.push_back(successor.getSuccessor());

246 };

247 enqueueAllSuccessors(begin);

248

249

250 while (!worklist.empty()) {

251 Region *nextRegion = worklist.pop_back_val();

252 if (stopConditionFn(nextRegion, visited))

253 return true;

255 continue;

257 enqueueAllSuccessors(nextRegion);

258 }

259

260 return false;

261 }

262

263

264

266 assert(begin->getParentOp() == r->getParentOp() &&

267 "expected that both regions belong to the same op");

270

271 return nextRegion == r;

272 });

273 }

274

275

276

277

278

279

280

281

282

283

285 assert(a && "expected non-empty operation");

286 assert(b && "expected non-empty operation");

287

288 auto branchOp = a->getParentOfType();

289 while (branchOp) {

290

291 if (!branchOp->isProperAncestor(b)) {

292

293 branchOp = branchOp->getParentOfType();

294 continue;

295 }

296

297

298

299 Region *regionA = nullptr, *regionB = nullptr;

300 for (Region &r : branchOp->getRegions()) {

301 if (r.findAncestorOpInRegion(*a)) {

302 assert(!regionA && "already found a region for a");

303 regionA = &r;

304 }

305 if (r.findAncestorOpInRegion(*b)) {

306 assert(!regionB && "already found a region for b");

307 regionB = &r;

308 }

309 }

310 assert(regionA && regionB && "could not find region of op");

311

312

313

314 return regionA != regionB && isRegionReachable(regionA, regionB) &&

316 }

317

318

319

320 return false;

321 }

322

324 Region *region = &getOperation()->getRegion(index);

326 }

327

328 bool RegionBranchOpInterface::hasLoop() {

332 if (!successor.isParent() &&

335

336

337 return visited[nextRegion->getRegionNumber()];

338 }))

339 return true;

340 return false;

341 }

342

346 if (auto branchOp = dyn_cast(op))

347 if (branchOp.isRepetitiveRegion(region->getRegionNumber()))

348 return region;

349 }

350 return nullptr;

351 }

352

355 while (region) {

357 if (auto branchOp = dyn_cast(op))

358 if (branchOp.isRepetitiveRegion(region->getRegionNumber()))

359 return region;

361 }

362 return nullptr;

363 }

static bool isRepetitiveRegion(Region *region, const BufferizationOptions &options)

static bool traverseRegionGraph(Region *begin, StopConditionFn stopConditionFn)

Traverse the region graph starting at begin.

static InFlightDiagnostic & printRegionEdgeName(InFlightDiagnostic &diag, RegionBranchPoint sourceNo, RegionBranchPoint succRegionNo)

static LogicalResult verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, function_ref< FailureOr< TypeRange >(RegionBranchPoint)> getInputsTypesForRegion)

Verify that types match along all region control flow edges originating from sourcePoint.

static bool isRegionReachable(Region *begin, Region *r)

Return true if region r is reachable from region begin according to the RegionBranchOpInterface (by taking one or more region branches).

static std::string diag(const llvm::Value &value)

Block represents an ordered list of Operations.

BlockArgument getArgument(unsigned i)

unsigned getNumArguments()

This class represents a diagnostic that is inflight and set to be reported.

This class provides a mutable adaptor for a range of operands.

This class implements the operand iterators for the Operation class.

unsigned getBeginOperandIndex() const

Return the operand index of the first element of this range.

Operation is the basic unit of execution within MLIR.

Block * getSuccessor(unsigned index)

unsigned getNumRegions()

Returns the number of regions held by this operation.

InFlightDiagnostic emitError(const Twine &message={})

Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.

OpTy getParentOfType()

Return the closest surrounding parent operation that is of type 'OpTy'.

MutableArrayRef< Region > getRegions()

Returns the regions held by this operation.

Region * getParentRegion()

Returns the region to which the instruction belongs.

InFlightDiagnostic emitOpError(const Twine &message={})

Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.

This class represents a point being branched from in the methods of the RegionBranchOpInterface.

static constexpr RegionBranchPoint parent()

Returns an instance of RegionBranchPoint representing the parent operation.

Region * getRegionOrNull() const

Returns the region if branching from a region.

This class represents a successor of a region.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

unsigned getRegionNumber()

Return the number of this region in the parent operation.

Operation * getParentOp()

Return the parent operation this region is attached to.

This class models how operands are forwarded to block arguments in control flow.

SuccessorOperands(MutableOperandRange forwardedOperands)

Constructs a SuccessorOperands with no produced operands that simply forwards operands to the successor.

unsigned getProducedOperandCount() const

Returns the number of operands that are produced internally by the operation.

unsigned size() const

Returns the number of operands passed to the successor.

OperandRange getForwardedOperands() const

Get the range of operands that are simply forwarded to the successor.

This class provides an abstraction over the various different ranges of value types.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

Type getType() const

Return the type of this value.

Region * getParentRegion()

Return the Region in which this Value is defined.

std::optional< BlockArgument > getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor)

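Return the BlockArgument corresponding to operand operandIndex in some successor if operandIndex is within the range of forwarded operands, or std::nullopt otherwise (including when no operands are forwarded).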

LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo, const SuccessorOperands &operands)

Verify that the given operands match those of the given successor block.

constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)

LogicalResult verifyTypesAlongControlFlowEdges(Operation *op)

Verify that types match along control flow edges described by the given op.

Include the generated interface declarations.

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

bool insideMutuallyExclusiveRegions(Operation *a, Operation *b)

Return true if a and b are in mutually exclusive regions as per RegionBranchOpInterface.

Region * getEnclosingRepetitiveRegion(Operation *op)

Return the first enclosing region of the given op that may be executed repetitively as per RegionBranchOpInterface, or nullptr if there is no such region.