MLIR: include/mlir/IR/Visitors.h Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13 #ifndef MLIR_IR_VISITORS_H

14 #define MLIR_IR_VISITORS_H

15

17 #include "llvm/ADT/STLExtras.h"

18

19 namespace mlir {

20 class Diagnostic;

21 class InFlightDiagnostic;

22 class Operation;

24 class Region;

25

26

27

28

29

30

31

32

34 enum ResultEnum { Interrupt, Advance, Skip } result;

35

36 public:

37 WalkResult(ResultEnum result = Advance) : result(result) {}

38

39

41 : result(failed(result) ? Interrupt : Advance) {}

42

43

46

49

53

54

56

57

58 bool wasSkipped() const { return result == Skip; }

59 };

60

61

63

64

66

68

69

70 template

72 return range;

73 }

74 };

75

76

77

78

79

80

81

82

83

84

85

86

88 public:

90

91

93

94

95 bool isBeforeRegion(int region) const { return nextRegion == region; }

96

97

98 bool isAfterRegion(int region) const { return nextRegion == region + 1; }

99

101

103

105

106 private:

107 const int numRegions;

108 int nextRegion;

109 };

110

111 namespace detail {

112

113 template <typename Ret, typename Arg, typename... Rest>

115 template <typename Ret, typename F, typename Arg, typename... Rest>

117 template <typename Ret, typename F, typename Arg, typename... Rest>

119 template

121

122

123 template

125

126

127

128

129

130

131

132

133

134

135 template

138

139

140 for (auto &region : Iterator::makeIterable(*op)) {

142 callback(&region);

143 for (auto &block : Iterator::makeIterable(region)) {

144 for (auto &nestedOp : Iterator::makeIterable(block))

145 walk(&nestedOp, callback, order);

146 }

148 callback(&region);

149 }

150 }

151

152 template

155 for (auto &region : Iterator::makeIterable(*op)) {

156

157 for (auto &block :

158 llvm::make_early_inc_range(Iterator::makeIterable(region))) {

160 callback(&block);

161 for (auto &nestedOp : Iterator::makeIterable(block))

162 walk(&nestedOp, callback, order);

164 callback(&block);

165 }

166 }

167 }

168

169 template

173 callback(op);

174

175

176 for (auto &region : Iterator::makeIterable(*op)) {

177 for (auto &block : Iterator::makeIterable(region)) {

178

179 for (auto &nestedOp :

180 llvm::make_early_inc_range(Iterator::makeIterable(block)))

181 walk(&nestedOp, callback, order);

182 }

183 }

184

186 callback(op);

187 }

188

189

190

191

192

193

194

195

196

197

198

199 template

202

203

204 for (auto &region : Iterator::makeIterable(*op)) {

206 WalkResult result = callback(&region);

208 continue;

211 }

212 for (auto &block : Iterator::makeIterable(region)) {

213 for (auto &nestedOp : Iterator::makeIterable(block))

214 if (walk(&nestedOp, callback, order).wasInterrupted())

216 }

218 if (callback(&region).wasInterrupted())

220

221

222 }

223 }

225 }

226

227 template

230 for (auto &region : Iterator::makeIterable(*op)) {

231

232 for (auto &block :

233 llvm::make_early_inc_range(Iterator::makeIterable(region))) {

235 WalkResult result = callback(&block);

237 continue;

240 }

241 for (auto &nestedOp : Iterator::makeIterable(block))

242 if (walk(&nestedOp, callback, order).wasInterrupted())

245 if (callback(&block).wasInterrupted())

247

248

249 }

250 }

251 }

253 }

254

255 template

260

265 }

266

267

268 for (auto &region : Iterator::makeIterable(*op)) {

269 for (auto &block : Iterator::makeIterable(region)) {

270

271 for (auto &nestedOp :

272 llvm::make_early_inc_range(Iterator::makeIterable(block))) {

273 if (walk(&nestedOp, callback, order).wasInterrupted())

275 }

276 }

277 }

278

280 return callback(op);

282 }

283

284

285

286

287

288

289

290

291

292

293

294

295

296

297

298

299

300

301

302

303

304

305 template <

308 typename RetT = decltype(std::declval()(std::declval()))>

309 std::enable_if_t<llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value,

310 RetT>

312 return detail::walk(op, function_ref<RetT(ArgT)>(callback), Order);

313 }

314

315

316

317

318

319

320

321

322

323

324

325

326

327 template <

330 typename RetT = decltype(std::declval()(std::declval()))>

331 std::enable_if_t<

332 !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&

333 std::is_same<RetT, void>::value,

334 RetT>

336 auto wrapperFn = [&](Operation *op) {

337 if (auto derivedOp = dyn_cast(op))

338 callback(derivedOp);

339 };

341 Order);

342 }

343

344

345

346

347

348

349

350

351

352

353

354

355

356

357

358

359

360

361

362

363

364 template <

367 typename RetT = decltype(std::declval()(std::declval()))>

368 std::enable_if_t<

369 !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&

370 std::is_same<RetT, WalkResult>::value,

371 RetT>

373 auto wrapperFn = [&](Operation *op) {

374 if (auto derivedOp = dyn_cast(op))

375 return callback(derivedOp);

377 };

379 Order);

380 }

381

382

383

384

385

386

387

388

391

392

393

394

395

396

400

401

402

403

404

405

406

407 template <typename FuncTy, typename ArgT = detail::first_argument,

408 typename RetT = decltype(std::declval()(

409 std::declval(), std::declval<const WalkStage &>()))>

410 std::enable_if_t<std::is_same<ArgT, Operation *>::value, RetT>

414 }

415

416

417

418

419

420

421

422 template <typename FuncTy, typename ArgT = detail::first_argument,

423 typename RetT = decltype(std::declval()(

424 std::declval(), std::declval<const WalkStage &>()))>

425 std::enable_if_t<!std::is_same<ArgT, Operation *>::value &&

426 std::is_same<RetT, void>::value,

427 RetT>

430 if (auto derivedOp = dyn_cast(op))

431 callback(derivedOp, stage);

432 };

435 }

436

437

438

439

440

441

442

443

444

445

446

447 template <typename FuncTy, typename ArgT = detail::first_argument,

448 typename RetT = decltype(std::declval()(

449 std::declval(), std::declval<const WalkStage &>()))>

450 std::enable_if_t<!std::is_same<ArgT, Operation *>::value &&

451 std::is_same<RetT, WalkResult>::value,

452 RetT>

455 if (auto derivedOp = dyn_cast(op))

456 return callback(derivedOp, stage);

458 };

461 }

462

463

464 template

466 }

467

468 }

469

470 #endif

Block represents an ordered list of Operations.

This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.

This class represents a diagnostic that is inflight and set to be reported.

Operation is the basic unit of execution within MLIR.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

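A utility result that is used to signal how to proceed with an ongoing walk: interrupt the walk, advance to the next element, or skip the elements nested under the current one.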

WalkResult(InFlightDiagnostic &&)

bool operator==(const WalkResult &rhs) const

WalkResult(LogicalResult result)

Allow LogicalResult to interrupt the walk on failure.

WalkResult(ResultEnum result=Advance)

bool wasSkipped() const

Returns true if the walk was skipped.

static WalkResult advance()

bool wasInterrupted() const

Returns true if the walk was interrupted.

WalkResult(Diagnostic &&)

Allow diagnostics to interrupt the walk.

static WalkResult interrupt()

bool operator!=(const WalkResult &rhs) const

A utility class to encode the current walk stage for "generic" walkers.

void advance()

Advance the walk stage.

int getNextRegion() const

Returns the next region that will be visited.

bool isBeforeRegion(int region) const

Returns true if the parent operation is being visited just before visiting the region with index `region`.

bool isAfterAllRegions() const

Returns true if the parent operation is being visited after all regions.

bool isAfterRegion(int region) const

Returns true if the parent operation is being visited just after visiting the region with index `region`.

bool isBeforeAllRegions() const

Returns true if the parent operation is being visited before all regions.

void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)

Walk all of the regions, blocks, or operations nested under (and including) the given operation.

decltype(walk(nullptr, std::declval< FnT >())) walkResultType

Utility to provide the return type of a templated walk method.

decltype(first_argument_type(&F::operator())) first_argument_type(F)

decltype(first_argument_type(std::declval< T >())) first_argument

Type definition of the first argument to the given callable 'T'.

Include the generated interface declarations.

WalkOrder

Traversal order for region, block and operation walk utilities.

This iterator enumerates the elements in "forward" order.

static MutableArrayRef< Region > makeIterable(Operation &range)

Make operations iterable: return the list of regions.

static constexpr T & makeIterable(T &range)

Regions and blocks are already iterable.