LLVM: lib/Target/SPIRV/SPIRVISelLowering.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

21#include "llvm/IR/IntrinsicsSPIRV.h"

22

23#define DEBUG_TYPE "spirv-lower"

24

25using namespace llvm;

26

30

31

32

36 return false;

37

39 return false;

40

41 if (Ty1->getOpcode() == SPIRV::OpTypeArray) {

42

44 return false;

45

48 return ElemType1 == ElemType2 ||

50 }

51

52 if (Ty1->getOpcode() == SPIRV::OpTypeStruct) {

58 if (ElemType1 != ElemType2 &&

60 return false;

61 }

62 return true;

63 }

64 return false;

65}

66

69

70

71

75 return 1;

77 return 1;

79}

80

83 EVT VT) const {

84

85

86

89 return MVT::v4i1;

91 return MVT::v4i8;

92 }

94}

95

100 unsigned AlignIdx = 3;

102 case Intrinsic::spv_load:

103 AlignIdx = 2;

104 [[fallthrough]];

105 case Intrinsic::spv_store: {

106 if (I.getNumOperands() >= AlignIdx + 1) {

108 Info.align = Align(AlignOp->getZExtValue());

109 }

112 Info.memVT = MVT::i64;

113

114

115 return true;

116 break;

117 }

118 default:

119 break;

120 }

121 return false;

122}

123

124std::pair<unsigned, const TargetRegisterClass *>

127 MVT VT) const {

130 return std::make_pair(0u, RC);

131

133 RC = VT.isVector() ? &SPIRV::vfIDRegClass : &SPIRV::fIDRegClass;

135 RC = VT.isVector() ? &SPIRV::vIDRegClass : &SPIRV::iIDRegClass;

136 else

137 RC = &SPIRV::iIDRegClass;

138

139 return std::make_pair(0u, RC);

140}

141

143 SPIRVType *TypeInst = MRI->getVRegDef(OpReg);

144 return TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter

146 : OpReg;

147}

148

155 bool Res = MIB.buildInstr(SPIRV::OpBitcast)

161 if (!Res)

162 report_fatal_error("insert validation bitcast: cannot constrain all uses");

163 I.getOperand(OpIdx).setReg(NewReg);

164}

165

167 SPIRVType *OpType, bool ReuseType,

169 SPIRV::StorageClass::StorageClass SC =

170 static_castSPIRV::StorageClass::StorageClass\(

171 OpType->getOperand(1).getImm());

174 ReuseType ? ResType

176 ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, false);

178}

179

180

181

185 SPIRVType *ResType, const Type *ResTy = nullptr) {

186

191 if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)

192 return;

193

194 Register ElemTypeReg = OpType->getOperand(2).getReg();

196 if (!ElemType)

197 return;

198

200 bool IsEqualTypes = IsSameMF ? ElemType == ResType

202 if (IsEqualTypes)

203 return;

204

205

210 "insert validation bitcast: incompatible result and operand types");

212}

213

214

215

220 constexpr unsigned OpIdx = 2;

225 if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)

226 return;

228 if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent)

229 return;

230

236}

237

241 Register PtrReg = I.getOperand(0).getReg();

246 if (!PonteeElemType || PonteeElemType->getOpcode() == SPIRV::OpTypeVoid ||

247 (PonteeElemType->getOpcode() == SPIRV::OpTypeInt &&

249 return;

250

251 SPIRV::StorageClass::StorageClass SC =

252 static_castSPIRV::StorageClass::StorageClass\(

259}

260

269 if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)

270 return;

272 if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeStruct ||

274 return;

275

278 if (!MemberType)

279 return;

280 unsigned MemberTypeOp = MemberType->getOpcode();

281 if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt &&

282 MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool)

283 return;

284

285

286 SPIRV::StorageClass::StorageClass SC =

287 static_castSPIRV::StorageClass::StorageClass\(

288 OpType->getOperand(1).getImm());

292}

293

294

295

296

297

298

299

300

301

307 if (FunDef->getOpcode() != SPIRV::OpFunction)

308 return;

309 unsigned OpIdx = 3;

311 FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter &&

316 DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer

319 : nullptr;

320 if (DefElemType) {

322

323

324

325

329 DefElemTy);

331 }

332 }

333}

334

335

336

337

346 if (!FunDef)

347 return F;

350 return nullptr;

351}

352

353

354

362 &FunCall->getParent()->getParent()->getRegInfo();

364 }

365}

366

367

371 if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) {

375 }

376}

377

378

379

381

382

383 if (ProcessedMF.find(&MF) != ProcessedMF.end())

384 return;

385

393 MBBI != MBBE;) {

395 switch (MI.getOpcode()) {

396 case SPIRV::OpAtomicLoad:

397 case SPIRV::OpAtomicExchange:

398 case SPIRV::OpAtomicCompareExchange:

399 case SPIRV::OpAtomicCompareExchangeWeak:

400 case SPIRV::OpAtomicIIncrement:

401 case SPIRV::OpAtomicIDecrement:

402 case SPIRV::OpAtomicIAdd:

403 case SPIRV::OpAtomicISub:

404 case SPIRV::OpAtomicSMin:

405 case SPIRV::OpAtomicUMin:

406 case SPIRV::OpAtomicSMax:

407 case SPIRV::OpAtomicUMax:

408 case SPIRV::OpAtomicAnd:

409 case SPIRV::OpAtomicOr:

410 case SPIRV::OpAtomicXor:

411

412

413

414 case SPIRV::OpLoad:

415

417 break;

418

421 break;

422 case SPIRV::OpAtomicStore:

423

424

427 break;

428 case SPIRV::OpStore:

429

432 break;

433 case SPIRV::OpPtrCastToGeneric:

434 case SPIRV::OpGenericCastToPtr:

435 case SPIRV::OpGenericCastToPtrExplicit:

437 break;

438 case SPIRV::OpPtrAccessChain:

439 case SPIRV::OpInBoundsPtrAccessChain:

440 if (MI.getNumOperands() == 4)

442 break;

443

444 case SPIRV::OpFunctionCall:

445

446

447 if (MI.getNumOperands() > 3)

450 break;

451 case SPIRV::OpFunction:

452

453

455 break;

456

457

458

459 case SPIRV::OpIAddS:

460 case SPIRV::OpIAddV:

461 case SPIRV::OpISubS:

462 case SPIRV::OpISubV:

464 SPIRV::OpTypeBool))

465 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));

466 break;

467

468

469

470 case SPIRV::OpBitwiseOrS:

471 case SPIRV::OpBitwiseOrV:

473 SPIRV::OpTypeBool))

474 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalOr));

475 break;

476 case SPIRV::OpBitwiseAndS:

477 case SPIRV::OpBitwiseAndV:

479 SPIRV::OpTypeBool))

480 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalAnd));

481 break;

482 case SPIRV::OpBitwiseXorS:

483 case SPIRV::OpBitwiseXorV:

485 SPIRV::OpTypeBool))

486 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));

487 break;

488 case SPIRV::OpLifetimeStart:

489 case SPIRV::OpLifetimeStop:

490 if (MI.getOperand(1).getImm() > 0)

492 break;

493 case SPIRV::OpGroupAsyncCopy:

496 break;

497 case SPIRV::OpGroupWaitEvents:

498

500 break;

501 case SPIRV::OpConstantI: {

503 if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() &&

504 MI.getOperand(2).getImm() == 0) {

505

506 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));

507 for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)

508 MI.removeOperand(i);

509 }

510 } break;

511 case SPIRV::OpPhi: {

512

513

514

515

518 if (Type->getParent() == Curr && !Curr->pred_empty())

520 } break;

521 case SPIRV::OpExtInst: {

522

523 if (MI.getOperand(2).isImm() || MI.getOperand(3).isImm() ||

524 MI.getOperand(2).getImm() != SPIRV::InstructionSet::OpenCL_std)

525 continue;

526 switch (MI.getOperand(3).getImm()) {

527 case SPIRV::OpenCLExtInst::frexp:

528 case SPIRV::OpenCLExtInst::lgamma_r:

529 case SPIRV::OpenCLExtInst::remquo: {

530

531

534 SPIRVType *RetType = MRI->getVRegDef(MI.getOperand(1).getReg());

535 assert(RetType && "Expected return type");

537 RetType->getOpcode() != SPIRV::OpTypeVector

538 ? Int32Type

541 MIB, false));

542 } break;

543 case SPIRV::OpenCLExtInst::fract:

544 case SPIRV::OpenCLExtInst::modf:

545 case SPIRV::OpenCLExtInst::sincos:

546

547

548 assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&

549 "Expected v-reg");

551 STI, MRI, GR, MI, MI.getNumOperands() - 1,

553 MI.getOperand(MI.getNumOperands() - 2).getReg()));

554 break;

555 case SPIRV::OpenCLExtInst::prefetch:

556

557

558 assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&

559 "Expected v-reg");

561 MI.getNumOperands() - 2);

562 break;

563 }

564 } break;

565 }

566 }

570 Pred->insert(Pred->getFirstTerminator(), Curr->remove_instr(MI));

571 }

572 }

573 ProcessedMF.insert(&MF);

575}

576

577

578

579

586

587 if (PointeeType == OpType)

588 return true;

589

591

592 if (I.getOperand(OpIdx).isDef() &&

594 return true;

595 }

596

598 return false;

599 }

600

601 return false;

602}

603

608

612

618

619 OldResult.setReg(NewResultReg);

620 OldType.setReg(NewTypeReg);

621

623 return MIB.buildInstr(SPIRV::OpCopyLogical)

624 .addDef(OldResultReg)

626 .addUse(NewResultReg)

627 .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),

628 *STI.getRegBankInfo());

629}

unsigned const MachineRegisterInfo * MRI

assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")

MachineBasicBlock MachineBasicBlock::iterator MBBI

Register const TargetRegisterInfo * TRI

MachineInstr unsigned OpIdx

static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I, Register OpReg, unsigned OpIdx, SPIRVType *NewPtrType)

Definition SPIRVISelLowering.cpp:149

static SPIRVType * createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I, SPIRVType *OpType, bool ReuseType, SPIRVType *ResType, const Type *ResTy)

Definition SPIRVISelLowering.cpp:166

static void validateLifetimeStart(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I)

Definition SPIRVISelLowering.cpp:238

static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I)

Definition SPIRVISelLowering.cpp:216

static void validatePtrUnwrapStructField(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I, unsigned OpIdx)

Definition SPIRVISelLowering.cpp:261

Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg)

Definition SPIRVISelLowering.cpp:142

void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I)

Definition SPIRVISelLowering.cpp:368

void validateFunCallMachineDef(const SPIRVSubtarget &STI, MachineRegisterInfo *DefMRI, MachineRegisterInfo *CallMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunCall, MachineInstr *FunDef)

Definition SPIRVISelLowering.cpp:302

void validateForwardCalls(const SPIRVSubtarget &STI, MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunDef)

Definition SPIRVISelLowering.cpp:355

const Function * validateFunCall(const SPIRVSubtarget &STI, MachineRegisterInfo *CallMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunCall)

Definition SPIRVISelLowering.cpp:338

static bool typesLogicallyMatch(const SPIRVType *Ty1, const SPIRVType *Ty2, SPIRVGlobalRegistry &GR)

Definition SPIRVISelLowering.cpp:33

static void validatePtrTypes(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I, unsigned OpIdx, SPIRVType *ResType, const Type *ResTy=nullptr)

Definition SPIRVISelLowering.cpp:182

Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...

LLVMContext & getContext() const

getContext - Return a reference to the LLVMContext associated with this function.

This is an important class for using LLVM in a threaded context.

bool isVector() const

Return true if this is a vector value type.

bool isInteger() const

Return true if this is an integer or a vector integer type.

bool isFloatingPoint() const

Return true if this is a FP or a vector FP type.

LLVM_ABI MachineInstr * remove_instr(MachineInstr *I)

Remove the possibly bundled instruction from the instruction list without deleting it.

pred_iterator pred_begin()

const MachineFunction * getParent() const

Return the MachineFunction containing this basic block.

MachineInstrBundleIterator< MachineInstr > iterator

MachineRegisterInfo & getRegInfo()

getRegInfo - Return information about the registers currently in use.

Function & getFunction()

Return the LLVM function that this machine code represents.

BasicBlockListType::iterator iterator

void insert(iterator MBBI, MachineBasicBlock *MBB)

Helper class to build MachineInstr.

MachineInstrBuilder buildInstr(unsigned Opcode)

Build and insert = Opcode .

MachineFunction & getMF()

Getter for the function we currently build.

bool constrainAllUses(const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, const RegisterBankInfo &RBI) const

const MachineInstrBuilder & addUse(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const

Add a virtual register use operand.

const MachineInstrBuilder & addDef(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const

Add a virtual register definition operand.

Representation of each machine instruction.

unsigned getOpcode() const

Returns the opcode of this MachineInstr.

const MachineBasicBlock * getParent() const

unsigned getNumOperands() const

Retuns the total number of operands.

const MachineOperand & getOperand(unsigned i) const

Flags

Flags values. These may be or'd together.

MachineOperand class - Representation of each machine instruction operand.

const GlobalValue * getGlobal() const

LLVM_ABI void setReg(Register Reg)

Change the register this operand corresponds to.

Register getReg() const

getReg - Returns the register number.

MachineRegisterInfo - Keep track of information for virtual and physical registers,...

LLVM_ABI MachineInstr * getVRegDef(Register Reg) const

getVRegDef - Return the machine instr that defines the specified virtual register or null if none is ...

Wrapper class representing virtual and physical registers.

SPIRVType * getSPIRVTypeForVReg(Register VReg, const MachineFunction *MF=nullptr) const

void addForwardCall(const Function *F, MachineInstr *MI)

SPIRVType * getResultType(Register VReg, MachineFunction *MF=nullptr)

const Type * getTypeForSPIRVType(const SPIRVType *Ty) const

bool isBitcastCompatible(const SPIRVType *Type1, const SPIRVType *Type2) const

SPIRVType * getOrCreateSPIRVType(const Type *Type, MachineInstr &I, SPIRV::AccessQualifier::AccessQualifier AQ, bool EmitIR)

SPIRVType * getOrCreateSPIRVPointerType(const Type *BaseType, MachineIRBuilder &MIRBuilder, SPIRV::StorageClass::StorageClass SC)

const MachineInstr * getFunctionDefinition(const Function *F)

SPIRVType * getPointeeType(SPIRVType *PtrType)

Register getSPIRVTypeID(const SPIRVType *SpirvType) const

SmallPtrSet< MachineInstr *, 8 > * getForwardCalls(const Function *F)

SPIRVType * getOrCreateSPIRVVectorType(SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder, bool EmitIR)

bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const

MachineFunction * setCurrentFunc(MachineFunction &MF)

SPIRVType * getOrCreateSPIRVIntegerType(unsigned BitWidth, MachineIRBuilder &MIRBuilder)

const Function * getFunctionByDefinition(const MachineInstr *MI)

const SPIRVInstrInfo * getInstrInfo() const override

const SPIRVRegisterInfo * getRegisterInfo() const override

const RegisterBankInfo * getRegBankInfo() const override

bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallBase &I, MachineFunction &MF, unsigned Intrinsic) const override

Given an intrinsic, checks if on the target the intrinsic will need to map to a MemIntrinsicNode (tou...

Definition SPIRVISelLowering.cpp:96

bool enforcePtrTypeCompatibility(MachineInstr &I, unsigned PtrOpIdx, unsigned OpIdx) const

Definition SPIRVISelLowering.cpp:580

unsigned getNumRegisters(LLVMContext &Context, EVT VT, std::optional< MVT > RegisterVT=std::nullopt) const override

Return the number of registers that this ValueType will eventually require.

unsigned getNumRegistersForCallingConv(LLVMContext &Context, CallingConv::ID CC, EVT VT) const override

Certain targets require unusual breakdowns of certain types.

Definition SPIRVISelLowering.cpp:67

MVT getRegisterTypeForCallingConv(LLVMContext &Context, CallingConv::ID CC, EVT VT) const override

Certain combinations of ABIs, Targets and features require that types are legal for some operations a...

Definition SPIRVISelLowering.cpp:81

void finalizeLowering(MachineFunction &MF) const override

Execute target specific actions to finalize target lowering.

Definition SPIRVISelLowering.cpp:380

bool insertLogicalCopyOnResult(MachineInstr &I, SPIRVType *NewResultType) const

Definition SPIRVISelLowering.cpp:604

SPIRVTargetLowering(const TargetMachine &TM, const SPIRVSubtarget &ST)

Definition SPIRVISelLowering.cpp:27

std::pair< unsigned, const TargetRegisterClass * > getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, StringRef Constraint, MVT VT) const override

Given a physical register constraint (e.g.

Definition SPIRVISelLowering.cpp:125

std::pair< iterator, bool > insert(PtrType Ptr)

Inserts Ptr if and only if there is no element in the container equal to Ptr.

SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.

StringRef - Represent a constant reference to a string, i.e.

bool starts_with(StringRef Prefix) const

Check if this string starts with the given Prefix.

static LLVM_ABI TargetExtType * get(LLVMContext &Context, StringRef Name, ArrayRef< Type * > Types={}, ArrayRef< unsigned > Ints={})

Return a target extension type having the specified name and optional type and integer parameters.

virtual void finalizeLowering(MachineFunction &MF) const

Execute target specific actions to finalize target lowering.

MVT getRegisterType(MVT VT) const

Return the type of registers that this ValueType will eventually require.

TargetLowering(const TargetLowering &)=delete

Primary interface to the complete machine description for the target machine.

TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...

The instances of the Type class are immutable: once they are created, they are never changed.

static LLVM_ABI IntegerType * getInt8Ty(LLVMContext &C)

NodeTy * getNextNode()

Get the next node, or nullptr for the list tail.

#define llvm_unreachable(msg)

Marks that the current location is not supposed to be reachable.

unsigned ID

LLVM IR allows to use arbitrary numbers as calling convention identifiers.

This namespace contains an enum with a value for every intrinsic/builtin function known by LLVM.

This is an optimization pass for GlobalISel generic memory operations.

auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)

Get the size of a range.

decltype(auto) dyn_cast(const From &Val)

dyn_cast - Return the argument parameter cast to the specified type.

Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR, MachineRegisterInfo *MRI, const MachineFunction &MF)

MachineInstr * getImm(const MachineOperand &MO, const MachineRegisterInfo *MRI)

LLVM_ABI void report_fatal_error(Error Err, bool gen_crash_diag=true)

const MachineInstr SPIRVType

decltype(auto) cast(const From &Val)

cast - Return the argument parameter cast to the specified type.

This struct is a compact representation of a valid (non-zero power of two) alignment.

TypeSize getSizeInBits() const

Return the size of the specified value type in bits.

bool isVector() const

Return true if this is a vector value type.

EVT getVectorElementType() const

Given a vector type, return the type of each element.

unsigned getVectorNumElements() const

Given a vector type, return the number of elements it contains.

bool isInteger() const

Return true if this is an integer or a vector integer type.