MLIR: lib/Dialect/SPIRV/IR/ControlFlowOps.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

17

18 #include "llvm/Support/InterleavedRange.h"

19

22

24

26

27

28

29 template <typename EnumAttrClass, typename EnumClass>

30 static ParseResult

32 StringRef attrName = spirv::attributeName()) {

34 EnumClass control;

36 spirv::parseEnumKeywordAttr(control, parser, state) ||

38 return failure();

39 return success();

40 }

41

43 state.addAttribute(attrName,

44 builder.getAttr(static_cast<EnumClass>(0)));

45 return success();

46 }

47

48

49

50

51

53 assert(index == 0 && "invalid successor index");

55 }

56

57

58

59

60

62 assert(index < 2 && "invalid successor index");

63 return SuccessorOperands(index == kTrueIndex

64 ? getTrueTargetOperandsMutable()

65 : getFalseTargetOperandsMutable());

66 }

67

69 OperationState &result) {

70 auto &builder = parser.getBuilder();

71 OpAsmParser::UnresolvedOperand condInfo;

73

74

75 Type boolTy = builder.getI1Type();

76 if (parser.parseOperand(condInfo) ||

77 parser.resolveOperand(condInfo, boolTy, result.operands))

78 return failure();

79

80

81 if (succeeded(parser.parseOptionalLSquare())) {

82 IntegerAttr trueWeight, falseWeight;

83 NamedAttrList weights;

84

85 auto i32Type = builder.getIntegerType(32);

86 if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||

87 parser.parseComma() ||

88 parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||

89 parser.parseRSquare())

90 return failure();

91

92 StringAttr branchWeightsAttrName =

93 BranchConditionalOp::getBranchWeightsAttrName(result.name);

94 result.addAttribute(branchWeightsAttrName,

95 builder.getArrayAttr({trueWeight, falseWeight}));

96 }

97

98

99 SmallVector<Value, 4> trueOperands;

100 if (parser.parseComma() ||

101 parser.parseSuccessorAndUseList(dest, trueOperands))

102 return failure();

103 result.addSuccessors(dest);

104 result.addOperands(trueOperands);

105

106

107 SmallVector<Value, 4> falseOperands;

108 if (parser.parseComma() ||

109 parser.parseSuccessorAndUseList(dest, falseOperands))

110 return failure();

111 result.addSuccessors(dest);

112 result.addOperands(falseOperands);

113 result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),

114 builder.getDenseI32ArrayAttr(

115 {1, static_cast<int32_t>(trueOperands.size()),

116 static_cast<int32_t>(falseOperands.size())}));

117

118 return success();

119 }

120

122 printer << ' ' << getCondition();

123

124 if (std::optional weights = getBranchWeights()) {

125 printer << ' '

126 << llvm::interleaved_array(weights->getAsValueRange());

127 }

128

129 printer << ", ";

130 printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());

131 printer << ", ";

132 printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());

133 }

134

136 if (auto weights = getBranchWeights()) {

137 if (weights->getValue().size() != 2) {

138 return emitOpError("must have exactly two branch weights");

139 }

140 if (llvm::all_of(*weights, [](Attribute attr) {

141 return llvm::cast(attr).getValue().isZero();

142 }))

143 return emitOpError("branch weights cannot both be zero");

144 }

145

146 return success();

147 }

148

149

150

151

152

154 auto fnName = getCalleeAttr();

155

156 auto funcOp = dyn_cast_or_nullspirv::FuncOp(

158 if (!funcOp) {

159 return emitOpError("callee function '")

160 << fnName.getValue() << "' not found in nearest symbol table";

161 }

162

163 auto functionType = funcOp.getFunctionType();

164

165 if (getNumResults() > 1) {

166 return emitOpError(

167 "expected callee function to have 0 or 1 result, but provided ")

168 << getNumResults();

169 }

170

171 if (functionType.getNumInputs() != getNumOperands()) {

172 return emitOpError("has incorrect number of operands for callee: expected ")

173 << functionType.getNumInputs() << ", but provided "

174 << getNumOperands();

175 }

176

177 for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {

178 if (getOperand(i).getType() != functionType.getInput(i)) {

179 return emitOpError("operand type mismatch: expected operand type ")

180 << functionType.getInput(i) << ", but provided "

181 << getOperand(i).getType() << " for operand number " << i;

182 }

183 }

184

185 if (functionType.getNumResults() != getNumResults()) {

186 return emitOpError(

187 "has incorrect number of results has for callee: expected ")

188 << functionType.getNumResults() << ", but provided "

189 << getNumResults();

190 }

191

192 if (getNumResults() &&

193 (getResult(0).getType() != functionType.getResult(0))) {

194 return emitOpError("result type mismatch: expected ")

195 << functionType.getResult(0) << ", but provided "

196 << getResult(0).getType();

197 }

198

199 return success();

200 }

201

202 CallInterfaceCallable FunctionCallOp::getCallableForCallee() {

203 return (*this)->getAttrOfType(getCalleeAttrName());

204 }

205

206 void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {

207 (*this)->setAttr(getCalleeAttrName(), cast(callee));

208 }

209

211 return getArguments();

212 }

213

214 MutableOperandRange FunctionCallOp::getArgOperandsMutable() {

215 return getArgumentsMutable();

216 }

217

218

219

220

221

222 void LoopOp::build(OpBuilder &builder, OperationState &state) {

223 state.addAttribute("loop_control", builder.getAttrspirv::LoopControlAttr(

225 state.addRegion();

226 }

227

228 ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {

229 if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,

230 result))

231 return failure();

232

233 if (succeeded(parser.parseOptionalArrow()))

234 if (parser.parseTypeList(result.types))

235 return failure();

236

237 return parser.parseRegion(*result.addRegion(), {});

238 }

239

241 auto control = getLoopControl();

243 printer << " control(" << spirv::stringifyLoopControl(control) << ")";

244 if (getNumResults() > 0) {

245 printer << " -> ";

246 printer << getResultTypes();

247 }

248 printer << ' ';

249 printer.printRegion(getRegion(), false,

250 true);

251 }

252

253

254

256

257 if (!llvm::hasSingleElement(srcBlock))

258 return false;

259

260 auto branchOp = dyn_castspirv::BranchOp(srcBlock.back());

261 return branchOp && branchOp.getSuccessor() == &dstBlock;

262 }

263

264

266 return llvm::hasSingleElement(block) && isaspirv::MergeOp(block.front());

267 }

268

269

272 return isaspirv::MergeOp(op) && op.getBlock() != &region.back();

273 });

274 }

275

276 LogicalResult LoopOp::verifyRegions() {

277 auto *op = getOperation();

278

279

280

281

282

283

284

285

286

287

288

289

290

291

292

293

294

295

296

297

298

299

300

301

302

303

304 auto &region = op->getRegion(0);

305

306

307 if (region.empty())

308 return success();

309

310

313 return emitOpError("last block must be the merge block with only one "

314 "'spirv.mlir.merge' op");

316 return emitOpError(

317 "should not have 'spirv.mlir.merge' op outside the merge block");

318

319 if (region.hasOneBlock())

320 return emitOpError(

321 "must have an entry block branching to the loop header block");

322

324

325 if (std::next(region.begin(), 2) == region.end())

326 return emitOpError(

327 "must have a loop header block branched from the entry block");

328

329 Block &header = *std::next(region.begin(), 1);

330

332 return emitOpError(

333 "entry block must only have one 'spirv.Branch' op to the second block");

334

335 if (std::next(region.begin(), 3) == region.end())

336 return emitOpError(

337 "requires a loop continue block branching to the loop header block");

338

339 Block &cont = *std::prev(region.end(), 2);

340

341

342

343 if (llvm::none_of(

345 [&](unsigned index) { return cont.getSuccessor(index) == &header; }))

346 return emitOpError("second to last block must be the loop continue "

347 "block that branches to the loop header block");

348

349

350

351 for (auto &block : llvm::make_range(std::next(region.begin(), 2),

352 std::prev(region.end(), 2))) {

353 for (auto i : llvm::seq(0, block.getNumSuccessors())) {

354 if (block.getSuccessor(i) == &header) {

355 return emitOpError("can only have the entry and loop continue "

356 "block branching to the loop header block");

357 }

358 }

359 }

360

361 return success();

362 }

363

364 Block *LoopOp::getEntryBlock() {

365 assert(!getBody().empty() && "op region should not be empty!");

366 return &getBody().front();

367 }

368

369 Block *LoopOp::getHeaderBlock() {

370 assert(!getBody().empty() && "op region should not be empty!");

371

372 return &*std::next(getBody().begin());

373 }

374

375 Block *LoopOp::getContinueBlock() {

376 assert(!getBody().empty() && "op region should not be empty!");

377

378 return &*std::prev(getBody().end(), 2);

379 }

380

381 Block *LoopOp::getMergeBlock() {

382 assert(!getBody().empty() && "op region should not be empty!");

383

384 return &getBody().back();

385 }

386

387 void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {

388 assert(getBody().empty() && "entry and merge block already exist");

389 OpBuilder::InsertionGuard g(builder);

390 builder.createBlock(&getBody());

391 builder.createBlock(&getBody());

392

393

394 builder.createspirv::MergeOp(getLoc());

395 }

396

397

398

399

400

402

403 return success();

404 }

405

406

407

408

409

411

412 return success();

413 }

414

415

416

417

418

420 if (auto conditionTy = llvm::dyn_cast(getCondition().getType())) {

421 auto resultVectorTy = llvm::dyn_cast(getResult().getType());

422 if (!resultVectorTy) {

423 return emitOpError("result expected to be of vector type when "

424 "condition is of vector type");

425 }

426 if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {

427 return emitOpError("result should have the same number of elements as "

428 "the condition when condition is of vector type");

429 }

430 }

431 return success();

432 }

433

434

435

436 SmallVector<ArrayRefspirv::Extension, 1> SelectOp::getExtensions() {

437 return {};

438 }

439 SmallVector<ArrayRefspirv::Capability, 1> SelectOp::getCapabilities() {

440 return {};

441 }

442 std::optionalspirv::Version SelectOp::getMinVersion() {

443

444

445 if (isaspirv::ScalarType(getCondition().getType()) &&

446 isaspirv::CompositeType(getType()))

447 return Version::V_1_4;

448

449 return Version::V_1_0;

450 }

451 std::optionalspirv::Version SelectOp::getMaxVersion() {

452 return Version::V_1_6;

453 }

454

455

456

457

458

459 ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &result) {

461 spirv::SelectionControl>(parser, result))

462 return failure();

463

464 if (succeeded(parser.parseOptionalArrow()))

465 if (parser.parseTypeList(result.types))

466 return failure();

467

468 return parser.parseRegion(*result.addRegion(), {});

469 }

470

472 auto control = getSelectionControl();

474 printer << " control(" << spirv::stringifySelectionControl(control) << ")";

475 if (getNumResults() > 0) {

476 printer << " -> ";

477 printer << getResultTypes();

478 }

479 printer << ' ';

480 printer.printRegion(getRegion(), false,

481 true);

482 }

483

484 LogicalResult SelectionOp::verifyRegions() {

485 auto *op = getOperation();

486

487

488

489

490

491

492

493

494

495

496

497

498

499

500

501

502

503

504

505

506

507

508 auto &region = op->getRegion(0);

509

510

511 if (region.empty())

512 return success();

513

514

516 return emitOpError("last block must be the merge block with only one "

517 "'spirv.mlir.merge' op");

519 return emitOpError(

520 "should not have 'spirv.mlir.merge' op outside the merge block");

521

522 if (region.hasOneBlock())

523 return emitOpError("must have a selection header block");

524

525 return success();

526 }

527

528 Block *SelectionOp::getHeaderBlock() {

529 assert(!getBody().empty() && "op region should not be empty!");

530

531 return &getBody().front();

532 }

533

534 Block *SelectionOp::getMergeBlock() {

535 assert(!getBody().empty() && "op region should not be empty!");

536

537 return &getBody().back();

538 }

539

540 void SelectionOp::addMergeBlock(OpBuilder &builder) {

541 assert(getBody().empty() && "entry and merge block already exist");

542 OpBuilder::InsertionGuard guard(builder);

543 builder.createBlock(&getBody());

544

545

546 builder.createspirv::MergeOp(getLoc());

547 }

548

549 SelectionOp

550 SelectionOp::createIfThen(Location loc, Value condition,

551 function_ref<void(OpBuilder &builder)> thenBody,

552 OpBuilder &builder) {

553 auto selectionOp =

555

556 selectionOp.addMergeBlock(builder);

557 Block *mergeBlock = selectionOp.getMergeBlock();

558 Block *thenBlock = nullptr;

559

560

561 {

562 OpBuilder::InsertionGuard guard(builder);

563 thenBlock = builder.createBlock(mergeBlock);

564 thenBody(builder);

565 builder.createspirv::BranchOp(loc, mergeBlock);

566 }

567

568

569 {

570 OpBuilder::InsertionGuard guard(builder);

571 builder.createBlock(thenBlock);

572 builder.createspirv::BranchConditionalOp(

573 loc, condition, thenBlock,

574 ArrayRef(), mergeBlock,

575 ArrayRef());

576 }

577

578 return selectionOp;

579 }

580

581

582

583

584

586 auto *block = (*this)->getBlock();

587

588

589 if (block->isEntryBlock())

590 return emitOpError("cannot be used in reachable block");

591 if (block->hasNoPredecessors())

592 return success();

593

594

595

596

597 return success();

598 }

599

600 }

static OperandRange getSuccessorOperands(Block *block, unsigned successorIndex)

Return the operand range used to transfer operands from block to its successor with the given index.

static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)

virtual Builder & getBuilder() const =0

Return a builder which provides useful access to MLIRContext, global objects like types and attribute...

virtual ParseResult parseOptionalKeyword(StringRef keyword)=0

Parse the given keyword if present.

virtual ParseResult parseRParen()=0

Parse a ) token.

virtual ParseResult parseLParen()=0

Parse a ( token.

Block represents an ordered list of Operations.

unsigned getNumSuccessors()

This class is a general helper class for creating context-global objects like types,...

Attr getAttr(Args &&...args)

Get or construct an instance of the attribute Attr with provided arguments.

The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...

Operation is the basic unit of execution within MLIR.

OperandRange operand_range

This class contains a list of basic blocks and a link to the parent operation it is attached to.

iterator_range< OpIterator > getOps()

This class models how operands are forwarded to block arguments in control flow.

static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)

Returns the operation registered with the given symbol name within the closest parent operation of,...

@ Type

An inlay hint for a type annotation.

QueryRef parse(llvm::StringRef line, const QuerySession &qs)

constexpr char kControl[]

static bool hasOtherMerge(Region &region)

Returns true if there is a spirv.mlir.merge op outside the merge block.

static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock)

Returns true if the given srcBlock contains only one spirv.Branch to the given dstBlock.

static ParseResult parseControlAttribute(OpAsmParser &parser, OperationState &state, StringRef attrName=spirv::attributeName< EnumClass >())

Parses Function, Selection and Loop control attributes.

static bool isMergeBlock(Block &block)

Returns true if the given block only contains one spirv.mlir.merge op.

llvm::function_ref< Fn > function_ref

Type getType(OpFoldResult ofr)

Returns the int type of the integer in ofr.

LogicalResult verify(Operation *op, bool verifyRecursively=true)

Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...

This represents an operation in an abstracted form, suitable for use with the builder APIs.