MLIR: lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

25 #include "llvm/ADT/TypeSwitch.h"

26 #include "llvm/Support/Debug.h"

27 #include "llvm/Support/FormatVariadic.h"

28

29 #define DEBUG_TYPE "spirv-to-llvm-pattern"

30

31 using namespace mlir;

32

33

34

35

36

37

40 return true;

41 if (auto vecType = dyn_cast(type))

42 return vecType.getElementType().isSignedInteger();

43 return false;

44 }

45

46

49 return true;

50 if (auto vecType = dyn_cast(type))

51 return vecType.getElementType().isUnsignedInteger();

52 return false;

53 }

54

55

56

58 if (auto intType = dyn_cast(type))

59 return intType.getWidth();

60 if (auto vecType = dyn_cast(type))

61 if (auto intType = dyn_cast(vecType.getElementType()))

62 return intType.getWidth();

63 return std::nullopt;

64 }

65

66

68 assert((type.isIntOrFloat() || isa(type)) &&

69 "bitwidth is not supported for this type");

72 auto vecType = dyn_cast(type);

73 auto elementType = vecType.getElementType();

74 assert(elementType.isIntOrFloat() &&

75 "only integers and floats have a bitwidth");

76 return elementType.getIntOrFloatBitWidth();

77 }

78

79

81 if (auto vecTy = dyn_cast(type))

82 type = vecTy.getElementType();

83 return cast(type).getWidth();

84 }

85

86

88 if (auto vecType = dyn_cast(type)) {

89 auto integerType = cast(vecType.getElementType());

91 }

92 auto integerType = cast(type);

94 }

95

96

99 if (isa(srcType)) {

100 return rewriter.createLLVM::ConstantOp(

101 loc, dstType,

104 }

105 return rewriter.createLLVM::ConstantOp(

107 }

108

109

112 if (auto vecType = dyn_cast(srcType)) {

113 auto floatType = cast(vecType.getElementType());

114 return rewriter.createLLVM::ConstantOp(

115 loc, dstType,

118 }

119 auto floatType = cast(srcType);

120 return rewriter.createLLVM::ConstantOp(

121 loc, dstType, rewriter.getFloatAttr(floatType, value));

122 }

123

124

125

126

127

128

129

131 Type llvmType,

133 auto srcType = value.getType();

138

139 if (valueBitWidth < targetBitWidth)

140 return rewriter.createLLVM::ZExtOp(loc, llvmType, value);

141

142

143

144

145 if (valueBitWidth > targetBitWidth)

146 return rewriter.createLLVM::TruncOp(loc, llvmType, value);

147 return value;

148 }

149

150

155 auto llvmVectorType = typeConverter.convertType(vectorType);

157 Value broadcasted = rewriter.createLLVM::PoisonOp(loc, llvmVectorType);

158 for (unsigned i = 0; i < numElements; ++i) {

159 auto index = rewriter.createLLVM::ConstantOp(

161 broadcasted = rewriter.createLLVM::InsertElementOp(

162 loc, llvmVectorType, broadcasted, toBroadcast, index);

163 }

164 return broadcasted;

165 }

166

167

171 if (auto vectorType = dyn_cast(srcType)) {

172 unsigned numElements = vectorType.getNumElements();

173 return broadcast(loc, value, numElements, typeConverter, rewriter);

174 }

175 return value;

176 }

177

178

179

180

181

182

183

184

185

186

187

191 Value broadcasted =

194 }

195

196

197

201 return nullptr;

202

205 return nullptr;

206 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,

207 false);

208 }

209

210

215 return nullptr;

216 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,

217 true);

218 }

219

220

222 unsigned value) {

223 return rewriter.createLLVM::ConstantOp(

226 }

227

228

232 unsigned alignment, bool isVolatile,

233 bool isNonTemporal) {

234 if (auto loadOp = dyn_castspirv::LoadOp(op)) {

235 auto dstType = typeConverter.convertType(loadOp.getType());

236 if (!dstType)

239 loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,

240 isVolatile, isNonTemporal);

241 return success();

242 }

243 auto storeOp = castspirv::StoreOp(op);

244 spirv::StoreOpAdaptor adaptor(operands);

245 rewriter.replaceOpWithNewOpLLVM::StoreOp(storeOp, adaptor.getValue(),

246 adaptor.getPtr(), alignment,

247 isVolatile, isNonTemporal);

248 return success();

249 }

250

251

252

253

254

255

256

257

262 auto sizeInBytes = castspirv::SPIRVType(elementType).getSizeInBytes();

263 if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))

264 return std::nullopt;

265

266 auto llvmElementType = converter.convertType(elementType);

269 }

270

271

272

275 spirv::ClientAPI clientAPI) {

276 unsigned addressSpace =

279 }

280

281

282

283

287 return std::nullopt;

290 }

291

292

293

298 if (!memberDecorations.empty())

299 return nullptr;

303 }

304

305

306

307

308

309 namespace {

310

312 public:

314

315 LogicalResult

316 matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,

318 auto dstType =

319 getTypeConverter()->convertType(op.getComponentPtr().getType());

320 if (!dstType)

322

323 auto indices = llvm::to_vector<4>(adaptor.getIndices());

324 Type indexType = op.getIndices().front().getType();

325 auto llvmIndexType = getTypeConverter()->convertType(indexType);

326 if (!llvmIndexType)

328 Value zero = rewriter.createLLVM::ConstantOp(

329 op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));

330 indices.insert(indices.begin(), zero);

331

332 auto elementType = getTypeConverter()->convertType(

333 castspirv::PointerType(op.getBasePtr().getType()).getPointeeType());

334 if (!elementType)

337 adaptor.getBasePtr(), indices);

338 return success();

339 }

340 };

341

343 public:

345

346 LogicalResult

347 matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,

349 auto dstType = getTypeConverter()->convertType(op.getPointer().getType());

350 if (!dstType)

353 op.getVariable());

354 return success();

355 }

356 };

357

358 class BitFieldInsertPattern

360 public:

362

363 LogicalResult

364 matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,

366 auto srcType = op.getType();

367 auto dstType = getTypeConverter()->convertType(srcType);

368 if (!dstType)

371

372

374 *getTypeConverter(), rewriter);

376 *getTypeConverter(), rewriter);

377

378

380 Value maskShiftedByCount =

381 rewriter.createLLVM::ShlOp(loc, dstType, minusOne, count);

382 Value negated = rewriter.createLLVM::XOrOp(loc, dstType,

383 maskShiftedByCount, minusOne);

384 Value maskShiftedByCountAndOffset =

385 rewriter.createLLVM::ShlOp(loc, dstType, negated, offset);

386 Value mask = rewriter.createLLVM::XOrOp(

387 loc, dstType, maskShiftedByCountAndOffset, minusOne);

388

389

390

391 Value baseAndMask =

392 rewriter.createLLVM::AndOp(loc, dstType, op.getBase(), mask);

393 Value insertShiftedByOffset =

394 rewriter.createLLVM::ShlOp(loc, dstType, op.getInsert(), offset);

396 insertShiftedByOffset);

397 return success();

398 }

399 };

400

401

402 class ConstantScalarAndVectorPattern

404 public:

406

407 LogicalResult

408 matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,

410 auto srcType = constOp.getType();

411 if (!isa(srcType) && !srcType.isIntOrFloat())

412 return failure();

413

414 auto dstType = getTypeConverter()->convertType(srcType);

415 if (!dstType)

416 return rewriter.notifyMatchFailure(constOp, "type conversion failed");

417

418

419

420

421

422

426

427 if (isa(srcType)) {

428 auto dstElementsAttr = cast(constOp.getValue());

430 constOp, dstType,

431 dstElementsAttr.mapValues(

432 signlessType, [&](const APInt &value) { return value; }));

433 return success();

434 }

435 auto srcAttr = cast(constOp.getValue());

436 auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());

437 rewriter.replaceOpWithNewOpLLVM::ConstantOp(constOp, dstType, dstAttr);

438 return success();

439 }

441 constOp, dstType, adaptor.getOperands(), constOp->getAttrs());

442 return success();

443 }

444 };

445

446 class BitFieldSExtractPattern

448 public:

450

451 LogicalResult

452 matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,

454 auto srcType = op.getType();

455 auto dstType = getTypeConverter()->convertType(srcType);

456 if (!dstType)

459

460

462 *getTypeConverter(), rewriter);

464 *getTypeConverter(), rewriter);

465

466

467 IntegerType integerType;

468 if (auto vecType = dyn_cast(srcType))

469 integerType = cast(vecType.getElementType());

470 else

471 integerType = cast(srcType);

472

475 isa(srcType)

476 ? rewriter.createLLVM::ConstantOp(

477 loc, dstType,

479 : rewriter.createLLVM::ConstantOp(loc, dstType, baseSize);

480

481

482

483 Value countPlusOffset =

484 rewriter.createLLVM::AddOp(loc, dstType, count, offset);

485 Value amountToShiftLeft =

486 rewriter.createLLVM::SubOp(loc, dstType, size, countPlusOffset);

487 Value baseShiftedLeft = rewriter.createLLVM::ShlOp(

488 loc, dstType, op.getBase(), amountToShiftLeft);

489

490

491 Value amountToShiftRight =

492 rewriter.createLLVM::AddOp(loc, dstType, offset, amountToShiftLeft);

493 rewriter.replaceOpWithNewOpLLVM::AShrOp(op, dstType, baseShiftedLeft,

494 amountToShiftRight);

495 return success();

496 }

497 };

498

499 class BitFieldUExtractPattern

501 public:

503

504 LogicalResult

505 matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,

507 auto srcType = op.getType();

508 auto dstType = getTypeConverter()->convertType(srcType);

509 if (!dstType)

512

513

515 *getTypeConverter(), rewriter);

517 *getTypeConverter(), rewriter);

518

519

521 Value maskShiftedByCount =

522 rewriter.createLLVM::ShlOp(loc, dstType, minusOne, count);

523 Value mask = rewriter.createLLVM::XOrOp(loc, dstType, maskShiftedByCount,

524 minusOne);

525

526

527 Value shiftedBase =

528 rewriter.createLLVM::LShrOp(loc, dstType, op.getBase(), offset);

529 rewriter.replaceOpWithNewOpLLVM::AndOp(op, dstType, shiftedBase, mask);

530 return success();

531 }

532 };

533

535 public:

537

538 LogicalResult

539 matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,

541 rewriter.replaceOpWithNewOpLLVM::BrOp(branchOp, adaptor.getOperands(),

542 branchOp.getTarget());

543 return success();

544 }

545 };

546

547 class BranchConditionalConversionPattern

549 public:

552

553 LogicalResult

554 matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,

556

558 if (auto weights = op.getBranchWeights()) {

560 for (auto weight : weights->getAsRange())

561 weightValues.push_back(weight.getInt());

563 }

564

566 op, op.getCondition(), op.getTrueBlockArguments(),

567 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),

568 op.getFalseBlock());

569 return success();

570 }

571 };

572

573

574

575

576 class CompositeExtractPattern

578 public:

580

581 LogicalResult

582 matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,

584 auto dstType = this->getTypeConverter()->convertType(op.getType());

585 if (!dstType)

587

588 Type containerType = op.getComposite().getType();

589 if (isa(containerType)) {

591 IntegerAttr value = cast(op.getIndices()[0]);

594 op, dstType, adaptor.getComposite(), index);

595 return success();

596 }

597

599 op, adaptor.getComposite(),

601 return success();

602 }

603 };

604

605

606

607

608 class CompositeInsertPattern

610 public:

612

613 LogicalResult

614 matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,

616 auto dstType = this->getTypeConverter()->convertType(op.getType());

617 if (!dstType)

619

620 Type containerType = op.getComposite().getType();

621 if (isa(containerType)) {

623 IntegerAttr value = cast(op.getIndices()[0]);

626 op, dstType, adaptor.getComposite(), adaptor.getObject(), index);

627 return success();

628 }

629

631 op, adaptor.getComposite(), adaptor.getObject(),

633 return success();

634 }

635 };

636

637

638

639 template <typename SPIRVOp, typename LLVMOp>

641 public:

643

644 LogicalResult

645 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,

647 auto dstType = this->getTypeConverter()->convertType(op.getType());

648 if (!dstType)

650 rewriter.template replaceOpWithNewOp(

651 op, dstType, adaptor.getOperands(), op->getAttrs());

652 return success();

653 }

654 };

655

656

657

658 class ExecutionModePattern

660 public:

662

663 LogicalResult

664 matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,

666

667

668

669 ModuleOp module = op->getParentOfType();

670 spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();

671 std::string moduleName;

672 if (module.getName().has_value())

673 moduleName = "_" + module.getName()->str();

674 else

675 moduleName = "";

676 std::string executionModeInfoName = llvm::formatv(

677 "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),

678 static_cast<uint32_t>(executionModeAttr.getValue()));

679

683

684

685

686

687

688

691 fields.push_back(llvmI32Type);

692 ArrayAttr values = op.getValues();

693 if (!values.empty()) {

695 fields.push_back(arrayType);

696 }

697 auto structType = LLVM::LLVMStructType::getLiteral(context, fields);

698

699

700 auto global = rewriter.createLLVM::GlobalOp(

701 UnknownLoc::get(context), structType, true,

702 LLVM::Linkage::External, executionModeInfoName, Attribute(),

703 0);

704 Location loc = global.getLoc();

705 Region &region = global.getInitializerRegion();

707

708

710 Value structValue = rewriter.createLLVM::PoisonOp(loc, structType);

711 Value executionMode = rewriter.createLLVM::ConstantOp(

712 loc, llvmI32Type,

714 static_cast<uint32_t>(executionModeAttr.getValue())));

715 structValue = rewriter.createLLVM::InsertValueOp(loc, structValue,

716 executionMode, 0);

717

718

719 for (unsigned i = 0, e = values.size(); i < e; ++i) {

720 auto attr = values.getValue()[i];

721 Value entry = rewriter.createLLVM::ConstantOp(loc, llvmI32Type, attr);

722 structValue = rewriter.createLLVM::InsertValueOp(

724 }

727 return success();

728 }

729 };

730

731

732

733

734

735 class GlobalVariablePattern

737 public:

738 template <typename... Args>

739 GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)

741 std::forward(args)...),

742 clientAPI(clientAPI) {}

743

744 LogicalResult

745 matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,

747

748

749 if (op.getInitializer())

750 return failure();

751

752 auto srcType = castspirv::PointerType(op.getType());

753 auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());

754 if (!dstType)

756

757

758

759

760 auto storageClass = srcType.getStorageClass();

761 switch (storageClass) {

762 case spirv::StorageClass::Input:

763 case spirv::StorageClass::Private:

764 case spirv::StorageClass::Output:

765 case spirv::StorageClass::StorageBuffer:

766 case spirv::StorageClass::UniformConstant:

767 break;

768 default:

769 return failure();

770 }

771

772

773

774

775 bool isConstant = (storageClass == spirv::StorageClass::Input) ||

776 (storageClass == spirv::StorageClass::UniformConstant);

777

778

779

780

781

782 auto linkage = storageClass == spirv::StorageClass::Private

783 ? LLVM::Linkage::Private

784 : LLVM::Linkage::External;

786 op, dstType, isConstant, linkage, op.getSymName(), Attribute(),

788

789

790 if (op.getLocationAttr())

791 newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());

792

793 return success();

794 }

795

796 private:

797 spirv::ClientAPI clientAPI;

798 };

799

800

801

802 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>

804 public:

806

807 LogicalResult

808 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,

810

811 Type fromType = op.getOperand().getType();

812 Type toType = op.getType();

813

814 auto dstType = this->getTypeConverter()->convertType(toType);

815 if (!dstType)

817

819 rewriter.template replaceOpWithNewOp(op, dstType,

820 adaptor.getOperands());

821 return success();

822 }

824 rewriter.template replaceOpWithNewOp(op, dstType,

825 adaptor.getOperands());

826 return success();

827 }

828 return failure();

829 }

830 };

831

832 class FunctionCallPattern

834 public:

836

837 LogicalResult

838 matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,

840 if (callOp.getNumResults() == 0) {

842 callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());

843 newOp.getProperties().operandSegmentSizes = {

844 static_cast<int32_t>(adaptor.getOperands().size()), 0};

846 return success();

847 }

848

849

850 auto dstType = getTypeConverter()->convertType(callOp.getType(0));

851 if (!dstType)

852 return rewriter.notifyMatchFailure(callOp, "type conversion failed");

854 callOp, dstType, adaptor.getOperands(), callOp->getAttrs());

855 newOp.getProperties().operandSegmentSizes = {

856 static_cast<int32_t>(adaptor.getOperands().size()), 0};

858 return success();

859 }

860 };

861

862

863 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>

865 public:

867

868 LogicalResult

869 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,

871

872 auto dstType = this->getTypeConverter()->convertType(op.getType());

873 if (!dstType)

875

876 rewriter.template replaceOpWithNewOpLLVM::FCmpOp(

877 op, dstType, predicate, op.getOperand1(), op.getOperand2());

878 return success();

879 }

880 };

881

882

883 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>

885 public:

887

888 LogicalResult

889 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,

891

892 auto dstType = this->getTypeConverter()->convertType(op.getType());

893 if (!dstType)

895

896 rewriter.template replaceOpWithNewOpLLVM::ICmpOp(

897 op, dstType, predicate, op.getOperand1(), op.getOperand2());

898 return success();

899 }

900 };

901

902 class InverseSqrtPattern

904 public:

906

907 LogicalResult

908 matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,

910 auto srcType = op.getType();

911 auto dstType = getTypeConverter()->convertType(srcType);

912 if (!dstType)

914

917 Value sqrt = rewriter.createLLVM::SqrtOp(loc, dstType, op.getOperand());

919 return success();

920 }

921 };

922

923

924 template

926 public:

928

929 LogicalResult

930 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,

932 if (!op.getMemoryAccess()) {

934 *this->getTypeConverter(), 0,

935 false,

936 false);

937 }

938 auto memoryAccess = *op.getMemoryAccess();

939 switch (memoryAccess) {

940 case spirv::MemoryAccess::Aligned:

942 case spirv::MemoryAccess::Nontemporal:

943 case spirv::MemoryAccess::Volatile: {

944 unsigned alignment =

945 memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;

946 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;

947 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;

949 *this->getTypeConverter(), alignment,

950 isVolatile, isNonTemporal);

951 }

952 default:

953

954 return failure();

955 }

956 }

957 };

958

959

960 template

962 public:

964

965 LogicalResult

966 matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,

968 auto srcType = notOp.getType();

969 auto dstType = this->getTypeConverter()->convertType(srcType);

970 if (!dstType)

972

973 Location loc = notOp.getLoc();

975 auto mask =

976 isa(srcType)

977 ? rewriter.createLLVM::ConstantOp(

978 loc, dstType,

980 : rewriter.createLLVM::ConstantOp(loc, dstType, minusOne);

981 rewriter.template replaceOpWithNewOpLLVM::XOrOp(notOp, dstType,

982 notOp.getOperand(), mask);

983 return success();

984 }

985 };

986

987

988 template

990 public:

992

993 LogicalResult

994 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,

997 return success();

998 }

999 };

1000

1002 public:

1004

1005 LogicalResult

1006 matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,

1010 return success();

1011 }

1012 };

1013

1015 public:

1017

1018 LogicalResult

1019 matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,

1022 adaptor.getOperands());

1023 return success();

1024 }

1025 };

1026

1028 StringRef name,

1030 Type resultType,

1031 bool convergent = true) {

1032 auto func = dyn_cast_or_nullLLVM::LLVMFuncOp(

1034 if (func)

1035 return func;

1036

1038 func = b.createLLVM::LLVMFuncOp(

1039 symbolTable->getLoc(), name,

1041 func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);

1042 func.setConvergent(convergent);

1043 func.setNoUnwind(true);

1044 func.setWillReturn(true);

1045 return func;

1046 }

1047

1049 LLVM::LLVMFuncOp func,

1051 auto call = builder.createLLVM::CallOp(loc, func, args);

1052 call.setCConv(func.getCConv());

1053 call.setConvergentAttr(func.getConvergentAttr());

1054 call.setNoUnwindAttr(func.getNoUnwindAttr());

1055 call.setWillReturnAttr(func.getWillReturnAttr());

1056 return call;

1057 }

1058

1059 template

1061 public:

1063

1065

1066 static constexpr StringRef getFuncName();

1067

1068 LogicalResult

1069 matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor,

1071 constexpr StringRef funcName = getFuncName();

1073 controlBarrierOp->template getParentWithTraitOpTrait::SymbolTable();

1074

1076

1077 Type voidTy = rewriter.getTypeLLVM::LLVMVoidType();

1078 LLVM::LLVMFuncOp func =

1080

1081 Location loc = controlBarrierOp->getLoc();

1082 Value execution = rewriter.createLLVM::ConstantOp(

1083 loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));

1084 Value memory = rewriter.createLLVM::ConstantOp(

1085 loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));

1086 Value semantics = rewriter.createLLVM::ConstantOp(

1087 loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));

1088

1090 {execution, memory, semantics});

1091

1092 rewriter.replaceOp(controlBarrierOp, call);

1093 return success();

1094 }

1095 };

1096

1097 namespace {

1098

1099 StringRef getTypeMangling(Type type, bool isSigned) {

1101 .Case([](auto) { return "Dh"; })

1102 .Case([](auto) { return "f"; })

1103 .Case([](auto) { return "d"; })

1104 .Case([isSigned](IntegerType intTy) {

1105 switch (intTy.getWidth()) {

1106 case 1:

1107 return "b";

1108 case 8:

1109 return (isSigned) ? "a" : "c";

1110 case 16:

1111 return (isSigned) ? "s" : "t";

1112 case 32:

1113 return (isSigned) ? "i" : "j";

1114 case 64:

1115 return (isSigned) ? "l" : "m";

1116 default:

1117 llvm_unreachable("Unsupported integer width");

1118 }

1119 })

1120 .Default([](auto) {

1121 llvm_unreachable("No mangling defined");

1122 return "";

1123 });

1124 }

1125

1126 template

1127 constexpr StringLiteral getGroupFuncName();

1128

1129 template <>

1130 constexpr StringLiteral getGroupFuncNamespirv::GroupIAddOp() {

1131 return "_Z17__spirv_GroupIAddii";

1132 }

1133 template <>

1134 constexpr StringLiteral getGroupFuncNamespirv::GroupFAddOp() {

1135 return "_Z17__spirv_GroupFAddii";

1136 }

1137 template <>

1138 constexpr StringLiteral getGroupFuncNamespirv::GroupSMinOp() {

1139 return "_Z17__spirv_GroupSMinii";

1140 }

1141 template <>

1142 constexpr StringLiteral getGroupFuncNamespirv::GroupUMinOp() {

1143 return "_Z17__spirv_GroupUMinii";

1144 }

1145 template <>

1146 constexpr StringLiteral getGroupFuncNamespirv::GroupFMinOp() {

1147 return "_Z17__spirv_GroupFMinii";

1148 }

1149 template <>

1150 constexpr StringLiteral getGroupFuncNamespirv::GroupSMaxOp() {

1151 return "_Z17__spirv_GroupSMaxii";

1152 }

1153 template <>

1154 constexpr StringLiteral getGroupFuncNamespirv::GroupUMaxOp() {

1155 return "_Z17__spirv_GroupUMaxii";

1156 }

1157 template <>

1158 constexpr StringLiteral getGroupFuncNamespirv::GroupFMaxOp() {

1159 return "_Z17__spirv_GroupFMaxii";

1160 }

1161 template <>

1162 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformIAddOp() {

1163 return "_Z27__spirv_GroupNonUniformIAddii";

1164 }

1165 template <>

1166 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformFAddOp() {

1167 return "_Z27__spirv_GroupNonUniformFAddii";

1168 }

1169 template <>

1170 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformIMulOp() {

1171 return "_Z27__spirv_GroupNonUniformIMulii";

1172 }

1173 template <>

1174 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformFMulOp() {

1175 return "_Z27__spirv_GroupNonUniformFMulii";

1176 }

1177 template <>

1178 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformSMinOp() {

1179 return "_Z27__spirv_GroupNonUniformSMinii";

1180 }

1181 template <>

1182 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformUMinOp() {

1183 return "_Z27__spirv_GroupNonUniformUMinii";

1184 }

1185 template <>

1186 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformFMinOp() {

1187 return "_Z27__spirv_GroupNonUniformFMinii";

1188 }

1189 template <>

1190 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformSMaxOp() {

1191 return "_Z27__spirv_GroupNonUniformSMaxii";

1192 }

1193 template <>

1194 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformUMaxOp() {

1195 return "_Z27__spirv_GroupNonUniformUMaxii";

1196 }

1197 template <>

1198 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformFMaxOp() {

1199 return "_Z27__spirv_GroupNonUniformFMaxii";

1200 }

1201 template <>

1202 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformBitwiseAndOp() {

1203 return "_Z33__spirv_GroupNonUniformBitwiseAndii";

1204 }

1205 template <>

1206 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformBitwiseOrOp() {

1207 return "_Z32__spirv_GroupNonUniformBitwiseOrii";

1208 }

1209 template <>

1210 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformBitwiseXorOp() {

1211 return "_Z33__spirv_GroupNonUniformBitwiseXorii";

1212 }

1213 template <>

1214 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformLogicalAndOp() {

1215 return "_Z33__spirv_GroupNonUniformLogicalAndii";

1216 }

1217 template <>

1218 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformLogicalOrOp() {

1219 return "_Z32__spirv_GroupNonUniformLogicalOrii";

1220 }

1221 template <>

1222 constexpr StringLiteral getGroupFuncNamespirv::GroupNonUniformLogicalXorOp() {

1223 return "_Z33__spirv_GroupNonUniformLogicalXorii";

1224 }

1225 }

1226

1227 template <typename ReduceOp, bool Signed = false, bool NonUniform = false>

1229 public:

1231

1232 LogicalResult

1233 matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor,

1235

1236 Type retTy = op.getResult().getType();

1238 return failure();

1239 }

1240 SmallString<36> funcName = getGroupFuncName();

1241 funcName += getTypeMangling(retTy, false);

1242

1245 if constexpr (NonUniform) {

1246 if (adaptor.getClusterSize()) {

1247 funcName += "j";

1248 paramTypes.push_back(i32Ty);

1249 }

1250 }

1251

1253 op->template getParentWithTraitOpTrait::SymbolTable();

1254

1255 LLVM::LLVMFuncOp func =

1257

1259 Value scope = rewriter.createLLVM::ConstantOp(

1260 loc, i32Ty, static_cast<int32_t>(adaptor.getExecutionScope()));

1261 Value groupOp = rewriter.createLLVM::ConstantOp(

1262 loc, i32Ty, static_cast<int32_t>(adaptor.getGroupOperation()));

1264 operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());

1265

1268 return success();

1269 }

1270 };

1271

1272 template <>

1273 constexpr StringRef

1274 ControlBarrierPatternspirv::ControlBarrierOp::getFuncName() {

1275 return "_Z22__spirv_ControlBarrieriii";

1276 }

1277

1278 template <>

1279 constexpr StringRef

1280 ControlBarrierPatternspirv::INTELControlBarrierArriveOp::getFuncName() {

1281 return "_Z33__spirv_ControlBarrierArriveINTELiii";

1282 }

1283

1284 template <>

1285 constexpr StringRef

1286 ControlBarrierPatternspirv::INTELControlBarrierWaitOp::getFuncName() {

1287 return "_Z31__spirv_ControlBarrierWaitINTELiii";

1288 }

1289

1290

1291

1292

1293

1294

1295

1296

1297

1298

1299

1300

1301

1302

1303

1304

1305

1306

1307

1308

1309

1310

1311

1312

1313

1314

1315

1316

1317

1318

1319

1320

1321

1322

1323

1324

1325

1326

1327

1328

1329

1330

1331

1332

1333

1334

1335

1336

1337

1339 public:

1341

1342 LogicalResult

1343 matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,

1345

1347 return failure();

1348

1349

1350 if (loopOp.getBody().empty()) {

1351 rewriter.eraseOp(loopOp);

1352 return success();

1353 }

1354

1355 Location loc = loopOp.getLoc();

1356

1357

1358

1361 Block *endBlock = rewriter.splitBlock(currentBlock, position);

1362

1363

1364

1365 Block *entryBlock = loopOp.getEntryBlock();

1366 assert(entryBlock->getOperations().size() == 1);

1367 auto brOp = dyn_castspirv::BranchOp(entryBlock->getOperations().front());

1368 if (!brOp)

1369 return failure();

1370 Block *headerBlock = loopOp.getHeaderBlock();

1372 rewriter.createLLVM::BrOp(loc, brOp.getBlockArguments(), headerBlock);

1374

1375

1376 Block *mergeBlock = loopOp.getMergeBlock();

1380 rewriter.createLLVM::BrOp(loc, terminatorOperands, endBlock);

1381

1384 return success();

1385 }

1386 };

1387

1388

1389

1390

1392 public:

1394

1395 LogicalResult

1396 matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,

1398

1399

1400

1402 return failure();

1403

1404

1405

1406

1407

1408 if (op.getBody().getBlocks().size() <= 2) {

1410 return success();

1411 }

1412

1414

1415

1416

1420 auto *continueBlock = rewriter.splitBlock(currentBlock, position);

1421

1422

1423

1424

1425

1426 auto *headerBlock = op.getHeaderBlock();

1427 assert(headerBlock->getOperations().size() == 1);

1428 auto condBrOp = dyn_castspirv::BranchConditionalOp(

1430 if (!condBrOp)

1431 return failure();

1433

1434

1435 auto *mergeBlock = op.getMergeBlock();

1439 rewriter.createLLVM::BrOp(loc, terminatorOperands, continueBlock);

1440

1441

1442 Block *trueBlock = condBrOp.getTrueBlock();

1443 Block *falseBlock = condBrOp.getFalseBlock();

1445 rewriter.createLLVM::CondBrOp(loc, condBrOp.getCondition(), trueBlock,

1446 condBrOp.getTrueTargetOperands(),

1447 falseBlock,

1448 condBrOp.getFalseTargetOperands());

1449

1451 rewriter.replaceOp(op, continueBlock->getArguments());

1452 return success();

1453 }

1454 };

1455

1456

1457

1458

1459

1460 template <typename SPIRVOp, typename LLVMOp>

1462 public:

1464

1465 LogicalResult

1466 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,

1468

1469 auto dstType = this->getTypeConverter()->convertType(op.getType());

1470 if (!dstType)

1472

1473 Type op1Type = op.getOperand1().getType();

1474 Type op2Type = op.getOperand2().getType();

1475

1476 if (op1Type == op2Type) {

1477 rewriter.template replaceOpWithNewOp(op, dstType,

1478 adaptor.getOperands());

1479 return success();

1480 }

1481

1482 std::optional<uint64_t> dstTypeWidth =

1484 std::optional<uint64_t> op2TypeWidth =

1486

1487 if (!dstTypeWidth || !op2TypeWidth)

1488 return failure();

1489

1492 if (op2TypeWidth < dstTypeWidth) {

1494 extended = rewriter.template createLLVM::ZExtOp(

1495 loc, dstType, adaptor.getOperand2());

1496 } else {

1497 extended = rewriter.template createLLVM::SExtOp(

1498 loc, dstType, adaptor.getOperand2());

1499 }

1500 } else if (op2TypeWidth == dstTypeWidth) {

1501 extended = adaptor.getOperand2();

1502 } else {

1503 return failure();

1504 }

1505

1506 Value result = rewriter.template create(

1507 loc, dstType, adaptor.getOperand1(), extended);

1509 return success();

1510 }

1511 };

1512

1514 public:

1516

1517 LogicalResult

1518 matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,

1520 auto dstType = getTypeConverter()->convertType(tanOp.getType());

1521 if (!dstType)

1522 return rewriter.notifyMatchFailure(tanOp, "type conversion failed");

1523

1524 Location loc = tanOp.getLoc();

1525 Value sin = rewriter.createLLVM::SinOp(loc, dstType, tanOp.getOperand());

1526 Value cos = rewriter.createLLVM::CosOp(loc, dstType, tanOp.getOperand());

1527 rewriter.replaceOpWithNewOpLLVM::FDivOp(tanOp, dstType, sin, cos);

1528 return success();

1529 }

1530 };

1531

1532

1533

1534

1535

1536

1537

// Converts spirv.GL.Tanh by expanding
//   tanh(x) = (exp(2 * x) - 1) / (exp(2 * x) + 1)
// with LLVM::FMulOp, LLVM::ExpOp, LLVM::FSubOp, LLVM::FAddOp and LLVM::FDivOp.
// NOTE(review): the class declaration (1538), constructor (1540), parameter
// continuation (1544) and the definitions of `two` (1551) and `one` (1555)
// are missing from this listing; presumably they are FP constants 2.0 and 1.0.
// NOTE(review): for large positive x, exp(2 * x) overflows to +inf, so the
// quotient becomes inf / inf = NaN instead of 1; consider a numerically
// stable formulation if such inputs matter.
1539 public:

1541

1542 LogicalResult

1543 matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,

1545 auto srcType = tanhOp.getType();

1546 auto dstType = getTypeConverter()->convertType(srcType);

1547 if (!dstType)

1548 return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");

1549

1550 Location loc = tanhOp.getLoc();

1552 Value multiplied =

1553 rewriter.createLLVM::FMulOp(loc, dstType, two, tanhOp.getOperand());

1554 Value exponential = rewriter.createLLVM::ExpOp(loc, dstType, multiplied);

1556 Value numerator =

1557 rewriter.createLLVM::FSubOp(loc, dstType, exponential, one);

1558 Value denominator =

1559 rewriter.createLLVM::FAddOp(loc, dstType, exponential, one);

1560 rewriter.replaceOpWithNewOpLLVM::FDivOp(tanhOp, dstType, numerator,

1561 denominator);

1562 return success();

1563 }

1564 };

1565

// Converts spirv.Variable into an LLVM::AllocaOp of the converted pointee
// type. If the variable has an initializer, the adapted initializer value is
// stored into the allocation with LLVM::StoreOp. Initialized variables whose
// pointee is neither int/float nor the type tested by the `isa` on line 1577
// (its template argument is missing from this listing) are rejected.
// NOTE(review): the class declaration (1566), constructor (1568), parameter
// continuation (1572) and the definition of `size` (1585) are missing here.
1567 public:

1569

1570 LogicalResult

1571 matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,

1573 auto srcType = varOp.getType();

1574

1575 auto pointerTo = castspirv::PointerType(srcType).getPointeeType();

1576 auto init = varOp.getInitializer();

1577 if (init && !pointerTo.isIntOrFloat() && !isa(pointerTo))

1578 return failure();

1579

1580 auto dstType = getTypeConverter()->convertType(srcType);

1581 if (!dstType)

1582 return rewriter.notifyMatchFailure(varOp, "type conversion failed");

1583

1584 Location loc = varOp.getLoc();

// No initializer: the alloca itself replaces the variable.
1586 if (!init) {

1587 auto elementType = getTypeConverter()->convertType(pointerTo);

1588 if (!elementType)

1589 return rewriter.notifyMatchFailure(varOp, "type conversion failed");

1590 rewriter.replaceOpWithNewOpLLVM::AllocaOp(varOp, dstType, elementType,

1591 size);

1592 return success();

1593 }

// With an initializer: allocate, store the initial value, then replace.
// NOTE(review): the element-type conversion duplicates the one in the branch
// above and could be hoisted before the `if`.
1594 auto elementType = getTypeConverter()->convertType(pointerTo);

1595 if (!elementType)

1596 return rewriter.notifyMatchFailure(varOp, "type conversion failed");

1597 Value allocated =

1598 rewriter.createLLVM::AllocaOp(loc, dstType, elementType, size);

1599 rewriter.createLLVM::StoreOp(loc, adaptor.getInitializer(), allocated);

1600 rewriter.replaceOp(varOp, allocated);

1601 return success();

1602 }

1603 };

1604

1605

1606

1607

1608

// Converts spirv.Bitcast. When the converted result type is an LLVM pointer,
// the op is replaced by its adapted operand unchanged (no cast is emitted).
// Otherwise it is replaced by a new op built from the converted type, the
// adapted operands and the original attributes.
// NOTE(review): the base-class line (1610), constructor (1612), parameter
// continuation (1616) and the op created on 1627 are missing from this
// listing; presumably the latter is an LLVM bitcast.
1609 class BitcastConversionPattern

1611 public:

1613

1614 LogicalResult

1615 matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,

1617 auto dstType = getTypeConverter()->convertType(bitcastOp.getType());

1618 if (!dstType)

1619 return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");

1620

1621

1622 if (isaLLVM::LLVMPointerType(dstType)) {

1623 rewriter.replaceOp(bitcastOp, adaptor.getOperand());

1624 return success();

1625 }

1626

1628 bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());

1629 return success();

1630 }

1631 };

1632

1633

1634

1635

1636

// Converts spirv.func into LLVM::LLVMFuncOp:
// - the function type is converted with
//   LLVMTypeConverter::convertFunctionSignature (failure if it returns null);
// - spirv::FunctionControl is mapped to LLVM function attributes:
//   Inline -> setAlwaysInline, DontInline -> setNoInline, Pure/Const -> a
//   "passthrough" attribute, anything else -> nothing;
// - the body is presumably moved into the new function (the call on the
//   missing line 1690 ends with `newFuncOp.end());`), its region types are
//   converted with the signature conversion, and the original op is erased.
// NOTE(review): the class declaration (1637), constructor (1639), parameter
// continuation (1643), the SignatureConversion declaration (1648), the
// Pure/Const attribute arguments (1678/1680) and lines 1690/1692 are missing
// from this listing.
// NOTE(review): if region type conversion fails, newFuncOp has already been
// created and filled before `failure()` is returned; confirm that the
// conversion driver in use tolerates IR changes on a failed pattern.
1638 public:

1640

1641 LogicalResult

1642 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,

1644

1645

1646

1647 auto funcType = funcOp.getFunctionType();

1649 funcType.getNumInputs());

1650 auto llvmType = static_cast<const LLVMTypeConverter *>(getTypeConverter())

1651 ->convertFunctionSignature(

1652 funcType, false,

1653 false, signatureConverter);

1654 if (!llvmType)

1655 return failure();

1656

1657

1658 Location loc = funcOp.getLoc();

1659 StringRef name = funcOp.getName();

1660 auto newFuncOp = rewriter.createLLVM::LLVMFuncOp(loc, name, llvmType);

1661

1662

1663 MLIRContext *context = funcOp.getContext();

// Translate the function-control value into LLVM function attributes.
// NOTE(review): SPIR-V FunctionControl is a bit mask; a combination of flags
// matches none of the cases and falls through to `default`.
1664 switch (funcOp.getFunctionControl()) {

1665 case spirv::FunctionControl::Inline:

1666 newFuncOp.setAlwaysInline(true);

1667 break;

1668 case spirv::FunctionControl::DontInline:

1669 newFuncOp.setNoInline(true);

1670 break;

1671

1672 #define DISPATCH(functionControl, llvmAttr) \

1673 case functionControl: \

1674 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \

1675 break;

1676

1677 DISPATCH(spirv::FunctionControl::Pure,

1679 DISPATCH(spirv::FunctionControl::Const,

1681

1682 #undef DISPATCH

1683

1684

1685

1686 default:

1687 break;

1688 }

1689

1691 newFuncOp.end());

1693 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {

1694 return failure();

1695 }

1696 rewriter.eraseOp(funcOp);

1697 return success();

1698 }

1699 };

1700

1701

1702

1703

1704

// Converts spirv.module by creating a new module (op type missing from this
// listing) with the same location and name, moving the SPIR-V module's region
// into it, erasing the last block of the new module's body region, and
// erasing the original spirv.module.
// NOTE(review): the class declaration (1705), constructor (1707) and
// parameter continuation (1711) are missing from this listing.
1706 public:

1708

1709 LogicalResult

1710 matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,

1712

1713 auto newModuleOp =

1714 rewriter.create(spvModuleOp.getLoc(), spvModuleOp.getName());

1715 rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());

1716

1717

// Presumably removes the block the new module was created with, which ends
// up last after the SPIR-V region was inlined before it.
1718 rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());

1719 rewriter.eraseOp(spvModuleOp);

1720 return success();

1721 }

1722 };

1723

1724

1725

1726

1727

// Converts spirv.VectorShuffle.
// - If both input vectors have the same number of elements, the op is
//   replaced in one step using the component indices converted with
//   LLVM::convertArrayToIndices<int32_t>.
// - Otherwise the result is built element by element, starting from an
//   LLVM::PoisonOp vector. Each component index selects vector1 when it is
//   below vector1Size, and vector2 otherwise (with offsetVal = vector1Size).
//   The selected element is extracted (ExtractElementOp) and inserted into
//   the result (InsertElementOp).
// - Components equal to -1 are skipped, leaving that result lane poison.
// NOTE(review): the base-class line (1729), constructor (1731), parameter
// continuation (1734-1735), the replacement op on 1742, the failure return on
// 1750, line 1754 (`loc`/`llvmI32Type`) and the constant values on 1772/1775
// are missing from this listing.
1728 class VectorShufflePattern

1730 public:

1732 LogicalResult

1733 matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,

1736 auto components = adaptor.getComponents();

1737 auto vector1 = adaptor.getVector1();

1738 auto vector2 = adaptor.getVector2();

1739 int vector1Size = cast(vector1.getType()).getNumElements();

1740 int vector2Size = cast(vector2.getType()).getNumElements();

1741 if (vector1Size == vector2Size) {

1743 op, vector1, vector2,

1744 LLVM::convertArrayToIndices<int32_t>(components));

1745 return success();

1746 }

1747

1748 auto dstType = getTypeConverter()->convertType(op.getType());

1749 if (!dstType)

1751 auto scalarType = cast(dstType).getElementType();

1752 auto componentsArray = components.getValue();

1753 auto *context = rewriter.getContext();

1755 Value targetOp = rewriter.createLLVM::PoisonOp(loc, dstType);

1756 for (unsigned i = 0; i < componentsArray.size(); i++) {

// NOTE(review): this error can be emitted after the poison vector (and
// earlier extract/insert ops) were already created; validating all
// components up front, and using notifyMatchFailure, would be cleaner.
1757 if (!isa(componentsArray[i]))

1758 return op.emitError("unable to support non-constant component");

1759

1760 int indexVal = cast(componentsArray[i]).getInt();

1761 if (indexVal == -1)

1762 continue;

1763

// Pick the source vector; offsetVal rebases indices that refer to vector2.
1764 int offsetVal = 0;

1765 Value baseVector = vector1;

1766 if (indexVal >= vector1Size) {

1767 offsetVal = vector1Size;

1768 baseVector = vector2;

1769 }

1770

1771 Value dstIndex = rewriter.createLLVM::ConstantOp(

1773 Value index = rewriter.createLLVM::ConstantOp(

1774 loc, llvmI32Type,

1776

1777 auto extractOp = rewriter.createLLVM::ExtractElementOp(

1778 loc, scalarType, baseVector, index);

1779 targetOp = rewriter.createLLVM::InsertElementOp(loc, dstType, targetOp,

1780 extractOp, dstIndex);

1781 }

1782 rewriter.replaceOp(op, targetOp);

1783 return success();

1784 }

1785 };

1786 }

1787

1788

1789

1790

1791

// NOTE(review): only fragments of this function survive in the listing. The
// start of the signature (1791-1792) is missing, and so are the bodies of the
// four callbacks whose closing `});` lines remain. Given the trailing
// `spirv::ClientAPI clientAPI` parameter and the doxygen index at the end of
// this page, this is presumably populateSPIRVToLLVMTypeConversion; restore it
// from upstream.
1793 spirv::ClientAPI clientAPI) {

1796 });

1799 });

1802 });

1805 });

1806 }

1807

// Registers the SPIR-V -> LLVM conversion patterns listed below on `patterns`.
// Each is constructed with (patterns.getContext(), typeConverter). One
// further pattern, registered at the end, also receives `clientAPI`.
// NOTE(review): the first signature line(s) (1808-1809) and the opening of
// the pattern list on 1811 (presumably `patterns.add<`) are missing from this
// listing. Per the doxygen index this is presumably
// populateSPIRVToLLVMConversionPatterns.
1810 spirv::ClientAPI clientAPI) {

1812

// Arithmetic ops
1813 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,

1814 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,

1815 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,

1816 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,

1817 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,

1818 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,

1819 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,

1820 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,

1821 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,

1822 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,

1823 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,

1824 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,

1825 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,

1826

1827

// Bitwise and bit-field ops
1828 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,

1829 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,

1830 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,

1831 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,

1832 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,

1833 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,

1834 NotPatternspirv::NotOp,

1835

1836

// Cast ops
1837 BitcastConversionPattern,

1838 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,

1839 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,

1840 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,

1841 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,

1842 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,

1843 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,

1844 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,

1845

1846

// Comparison ops
1847 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,

1848 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,

1849 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,

1850 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,

1851 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,

1852 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,

1853 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,

1854 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,

1855 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,

1856 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,

1857 FComparePattern<spirv::FUnordGreaterThanEqualOp,

1858 LLVM::FCmpPredicate::uge>,

1859 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,

1860 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,

1861 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,

1862 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,

1863 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,

1864 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,

1865 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,

1866 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,

1867 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,

1868 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,

1869 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,

1870

1871

// Constant op
1872 ConstantScalarAndVectorPattern,

1873

1874

// Control-flow ops
1875 BranchConversionPattern, BranchConditionalConversionPattern,

1876 FunctionCallPattern, LoopPattern, SelectionPattern,

1877 ErasePatternspirv::MergeOp,

1878

1879

// Entry points and execution modes
1880 ErasePatternspirv::EntryPointOp, ExecutionModePattern,

1881

1882

// GLSL extended instruction set ops
1883 DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,

1884 DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,

1885 DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,

1886 DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,

1887 DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,

1888 DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,

1889 DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,

1890 DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,

1891 DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,

1892 DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,

1893 DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,

1894 DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,

1895 InverseSqrtPattern, TanPattern, TanhPattern,

1896

1897

// Logical ops
1898 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,

1899 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,

1900 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,

1901 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,

1902 NotPatternspirv::LogicalNotOp,

1903

1904

// Memory ops
1905 AccessChainPattern, AddressOfPattern, LoadStorePatternspirv::LoadOp,

1906 LoadStorePatternspirv::StoreOp, VariablePattern,

1907

1908

// Miscellaneous ops
1909 CompositeExtractPattern, CompositeInsertPattern,

1910 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,

1911 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,

1912 VectorShufflePattern,

1913

1914

// Shift ops
1915 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,

1916 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,

1917 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,

1918

1919

// Return ops
1920 ReturnPattern, ReturnValuePattern,

1921

1922

// Barrier ops
1923 ControlBarrierPatternspirv::ControlBarrierOp,

1924 ControlBarrierPatternspirv::INTELControlBarrierArriveOp,

1925 ControlBarrierPatternspirv::INTELControlBarrierWaitOp,

1926

1927

// Group reduction ops. The second template argument is `true` only for the
// signed min/max variants. The third is `true` for the GroupNonUniform*
// variants. Presumably these select signed semantics and the non-uniform
// form; confirm against the GroupReducePattern definition.
1928 GroupReducePatternspirv::GroupIAddOp,

1929 GroupReducePatternspirv::GroupFAddOp,

1930 GroupReducePatternspirv::GroupFMinOp,

1931 GroupReducePatternspirv::GroupUMinOp,

1932 GroupReducePattern<spirv::GroupSMinOp, true>,

1933 GroupReducePatternspirv::GroupFMaxOp,

1934 GroupReducePatternspirv::GroupUMaxOp,

1935 GroupReducePattern<spirv::GroupSMaxOp, true>,

1936 GroupReducePattern<spirv::GroupNonUniformIAddOp, false,

1937 true>,

1938 GroupReducePattern<spirv::GroupNonUniformFAddOp, false,

1939 true>,

1940 GroupReducePattern<spirv::GroupNonUniformIMulOp, false,

1941 true>,

1942 GroupReducePattern<spirv::GroupNonUniformFMulOp, false,

1943 true>,

1944 GroupReducePattern<spirv::GroupNonUniformSMinOp, true,

1945 true>,

1946 GroupReducePattern<spirv::GroupNonUniformUMinOp, false,

1947 true>,

1948 GroupReducePattern<spirv::GroupNonUniformFMinOp, false,

1949 true>,

1950 GroupReducePattern<spirv::GroupNonUniformSMaxOp, true,

1951 true>,

1952 GroupReducePattern<spirv::GroupNonUniformUMaxOp, false,

1953 true>,

1954 GroupReducePattern<spirv::GroupNonUniformFMaxOp, false,

1955 true>,

1956 GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, false,

1957 true>,

1958 GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, false,

1959 true>,

1960 GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, false,

1961 true>,

1962 GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, false,

1963 true>,

1964 GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, false,

1965 true>,

1966 GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, false,

1967 true>>(patterns.getContext(),

1968 typeConverter);

1969

// A pattern that additionally needs the client API. Its template argument is
// missing from this listing (line 1970).
1970 patterns.add(clientAPI, patterns.getContext(),

1971 typeConverter);

1972 }

1973

// Registers a single pattern constructed with (context, typeConverter).
// NOTE(review): the signature (1974-1975) and the pattern's template argument
// are missing from this listing. Per the doxygen index this is presumably
// populateSPIRVToLLVMFunctionConversionPatterns registering the function
// conversion pattern; confirm upstream.
1976 patterns.add(patterns.getContext(), typeConverter);

1977 }

1978

// Registers a single pattern constructed with (context, typeConverter).
// NOTE(review): the signature (1979-1980) and the pattern's template argument
// are missing from this listing. Per the doxygen index this is presumably
// populateSPIRVToLLVMModuleConversionPatterns registering the module
// conversion pattern; confirm upstream.
1981 patterns.add(patterns.getContext(), typeConverter);

1982 }

1983

1984

1985

1986

1987

1988

// Name of the attribute on spirv.GlobalVariable ops that holds the binding
// number; it is looked up with op->getAttrOfType(kBinding) when encoding
// descriptor set and binding into global variable names.
1989 static constexpr StringRef kBinding = "binding";

// Encodes descriptor set and binding into the names of global variables. It
// visits every spirv.module directly in `module`'s body and walks it for
// spirv.GlobalVariable ops. Each op that carries both a descriptor-set
// attribute and a `kBinding` attribute gets a new name of the form
//   "<module>_<sym>_descriptor_set<set>_binding<binding>",
// where the "<module>_" prefix is present only when the spirv.module has a
// name. An error is emitted when the symbol's uses cannot all be replaced.
// NOTE(review): the signature (1990-1991), the descriptor-set attribute lookup
// (1996) and the renaming lines (2015, 2017-2019) are missing from this
// listing.
1992 auto spvModules = module.getOpsspirv::ModuleOp();

1993 for (auto spvModule : spvModules) {

1994 spvModule.walk([&](spirv::GlobalVariableOp op) {

1995 IntegerAttr descriptorSet =

1997 IntegerAttr binding = op->getAttrOfType(kBinding);

1998

1999

// Only variables that carry both attributes are renamed.
2000 if (descriptorSet && binding) {

2001

2002

2003 auto moduleAndName =

2004 spvModule.getName().has_value()

2005 ? spvModule.getName()->str() + "_" + op.getSymName().str()

2006 : op.getSymName().str();

// NOTE(review): formatv can format the integers directly; the
// std::to_string calls are redundant.
2007 std::string name =

2008 llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,

2009 std::to_string(descriptorSet.getInt()),

2010 std::to_string(binding.getInt()));

2011 auto nameAttr = StringAttr::get(op->getContext(), name);

2012

2013

2014

2016 op.emitError("unable to replace all symbol uses for ") << name;

2020 }

2021 });

2022 }

2023 }

static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)

static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isMemNone, bool isConvergent)

static MLIRContext * getContext(OpFoldResult val)

static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)

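Utility function for bitfield ops: truncates or extends the value to the bit width of llvmType; if the widths already match, the value remains unchanged.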

static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)

Creates llvm.mlir.constant with a floating-point scalar or vector value.

static constexpr StringRef kDescriptorSet

static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)

Creates LLVM dialect constant with the given value.

static Type convertPointerType(spirv::PointerType type, const TypeConverter &converter, spirv::ClientAPI clientAPI)

Converts SPIR-V pointer type to LLVM pointer.

static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, const TypeConverter &converter, ConversionPatternRewriter &rewriter)

Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.

static unsigned getBitWidth(Type type)

Returns the bit width of integer, float or vector of float or integer values.

static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)

Utility for spirv.Load and spirv.Store conversion.

static Type convertStructTypePacked(spirv::StructType type, const TypeConverter &converter)

Converts SPIR-V struct with no offset to packed LLVM struct.

static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)

Converts SPIR-V runtime array to LLVM array.

static bool isSignedIntegerOrVector(Type type)

Returns true if the given type is a signed integer or vector type.

static bool isUnsignedIntegerOrVector(Type type)

Returns true if the given type is an unsigned integer or vector type.

static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)

Converts SPIR-V array type to LLVM array.

static constexpr StringRef kBinding

Hook for descriptor set and binding number encoding.

static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)

Creates IntegerAttribute with all bits set for given type.

static Value optionallyBroadcast(Location loc, Value value, Type srcType, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)

Broadcasts the value. If srcType is a scalar, the value remains unchanged.

static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)

Creates llvm.mlir.constant with all bits set for the given type.

static unsigned getLLVMTypeBitWidth(Type type)

Returns the bit width of LLVMType integer or vector.

#define DISPATCH(functionControl, llvmAttr)

static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)

Returns the width of an integer or of the element type of an integer vector, if applicable.

static Type convertStructTypeWithOffset(spirv::StructType type, const TypeConverter &converter)

Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.

static Type convertStructType(spirv::StructType type, const TypeConverter &converter)

Converts SPIR-V struct to LLVM struct.

static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)

Broadcasts the value to vector with numElements number of elements.

Attributes are known-constant values of operations.

Block represents an ordered list of Operations.

OpListType::iterator iterator

Operation * getTerminator()

Get the terminator operation of this block.

OpListType & getOperations()

BlockArgListType getArguments()

This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.

IntegerAttr getI32IntegerAttr(int32_t value)

DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)

IntegerAttr getIntegerAttr(Type type, int64_t value)

FloatAttr getFloatAttr(Type type, double value)

IntegerType getIntegerType(unsigned width)

Ty getType(Args &&...args)

Get or construct an instance of the type Ty with provided arguments.

MLIRContext * getContext() const

This class implements a pattern rewriter for use with ConversionPatterns.

void replaceOp(Operation *op, ValueRange newValues) override

Replace the given operation with the new values.

FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)

Apply a signature conversion to each block in the given region.

void eraseOp(Operation *op) override

PatternRewriter hook for erasing a dead operation.

void eraseBlock(Block *block) override

PatternRewriter hook for erase all operations in a block.

static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)

Constructs a dense elements attribute from an array of element values.

Conversion from types to the LLVM IR dialect.

This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.

MLIRContext is the top-level object for a collection of MLIR operations.

RAII guard to reset the insertion point of the builder when destroyed.

This class helps build Operations.

Block::iterator getInsertionPoint() const

Returns the current insertion point of the builder.

void setInsertionPointToStart(Block *block)

Sets the insertion point to the start of the specified block.

void setInsertionPointToEnd(Block *block)

Sets the insertion point to the end of the specified block.

Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)

Add new block with 'argTypes' arguments and set the insertion point to the end of it.

Block * getBlock() const

Returns the current block of the builder.

Operation * create(const OperationState &state)

Creates an operation given the fields represented as an OperationState.

void setInsertionPointAfter(Operation *op)

Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.

Block * getInsertionBlock() const

Return the block the current insertion point belongs to.

Operation is the basic unit of execution within MLIR.

Location getLoc()

The source location the operation was defined or derived from.

InFlightDiagnostic emitError(const Twine &message={})

Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.

Region & getRegion(unsigned index)

Returns the region held by this operation at position 'index'.

operand_range getOperands()

Returns an iterator on the underlying Value's.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.

This class contains a list of basic blocks and a link to the parent operation it is attached to.

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.

Block * splitBlock(Block *block, Block::iterator before)

Split the operations starting at "before" (inclusive) out of the given block into a new block, and return it.

void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)

Move the blocks that belong to "region" before the given position in another region "parent".

OpTy replaceOpWithNewOp(Operation *op, Args &&...args)

Replace the results of the given (original) op with a new op that is created without verification (re...

static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)

Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that are nested within the given operation 'from'.

static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)

Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.

static void setSymbolName(Operation *symbol, StringAttr name)

Sets the name of the given symbol operation.

This class provides all of the information necessary to convert a type signature.

void addConversion(FnT &&callback)

Register a conversion function.

LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const

Convert the given type.

LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const

Convert the given set of types, filling 'results' as necessary.

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.

bool isSignedInteger() const

Return true if this is a signed integer type (with the specified width).

bool isUnsignedInteger() const

Return true if this is an unsigned integer type (with the specified width).

bool isIntOrFloat() const

Return true if this is an integer (of any signedness) or a float type.

unsigned getIntOrFloatBitWidth() const

Return the bit width of an integer or a float type, assert failure on other types.

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.

Type getType() const

Return the type of this value.

static spirv::StructType decorateType(spirv::StructType structType)

Returns a new StructType with layout decoration.

static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)

Builder from ArrayRef.

Type getElementType() const

unsigned getArrayStride() const

Returns the array stride in bytes.

unsigned getNumElements() const

StorageClass getStorageClass() const

Type getElementType() const

unsigned getArrayStride() const

Returns the array stride in bytes.

void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const

TypeRange getElementTypes() const

bool isCompatibleType(Type type)

Returns true if the given type is compatible with the LLVM dialect.

SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)

Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM operations.

Include the generated interface declarations.

unsigned storageClassToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)

void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)

Populates type conversions with additional SPIR-V types.

void populateSPIRVToLLVMFunctionConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)

Populates the given list with patterns for function conversion from SPIR-V to LLVM.

const FrozenRewritePatternSet & patterns

void populateSPIRVToLLVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)

Populates the given list with patterns that convert from SPIR-V to LLVM.

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed; this helps unify some of the attribute construction methods.

void encodeBindAttribute(ModuleOp module)

Encodes global variable's descriptor set and binding into its name if they both exist.

void populateSPIRVToLLVMModuleConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)

Populates the given patterns for module conversion from SPIR-V to LLVM.