LLVM: lib/Target/SPIRV/SPIRVPrepareFunctions.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

31#include "llvm/IR/IntrinsicsSPIRV.h"

34#include

35

36using namespace llvm;

37

38namespace llvm {

40}

41

42namespace {

43

44class SPIRVPrepareFunctions : public ModulePass {

46 bool substituteIntrinsicCalls(Function *F);

48

49public:

50 static char ID;

53 }

54

56

58

61 }

62};

63

64}

65

66char SPIRVPrepareFunctions::ID = 0;

67

69 "SPIRV prepare functions", false, false)

70

72 Function *IntrinsicFunc = II->getCalledFunction();

73 assert(IntrinsicFunc && "Missing function");

74 std::string FuncName = IntrinsicFunc->getName().str();

75 std::replace(FuncName.begin(), FuncName.end(), '.', '_');

76 FuncName = "spirv." + FuncName;

77 return FuncName;

78}

79

85 if (F && F->getFunctionType() == FT)

86 return F;

88 if (F)

91 return NewF;

92}

93

95

96

97

98

99 if (auto *MSI = dyn_cast(Intrinsic))

100 if (isa(MSI->getValue()) && isa(MSI->getLength()))

101 return false;

102

103 Module *M = Intrinsic->getModule();

104 std::string FuncName = lowerLLVMIntrinsicName(Intrinsic);

105 if (Intrinsic->isVolatile())

106 FuncName += ".volatile";

107

108 Function *F = M->getFunction(FuncName);

109 if (F) {

110 Intrinsic->setCalledFunction(F);

111 return true;

112 }

113

115 M->getOrInsertFunction(FuncName, Intrinsic->getFunctionType());

116 auto IntrinsicID = Intrinsic->getIntrinsicID();

117 Intrinsic->setCalledFunction(FC);

118

119 F = dyn_cast(FC.getCallee());

120 assert(F && "Callee must be a function");

121

122 switch (IntrinsicID) {

123 case Intrinsic::memset: {

124 auto *MSI = static_cast<MemSetInst *>(Intrinsic);

128 Argument *IsVolatile = F->getArg(3);

131 Len->setName("len");

132 IsVolatile->setName("isvolatile");

135 auto *MemSet = IRB.CreateMemSet(Dest, Val, Len, MSI->getDestAlign(),

136 MSI->isVolatile());

139 MemSet->eraseFromParent();

140 break;

141 }

142 case Intrinsic::bswap: {

145 auto *BSwap = IRB.CreateIntrinsic(Intrinsic::bswap, Intrinsic->getType(),

146 F->getArg(0));

150 break;

151 }

152 default:

153 break;

154 }

155 return true;

156}

157

159 if (auto *Ref = dyn_cast_or_null(AnnoVal))

160 AnnoVal = Ref->getOperand(0);

161 if (auto *Ref = dyn_cast_or_null(OptAnnoVal))

162 OptAnnoVal = Ref->getOperand(0);

163

164 std::string Anno;

165 if (auto *C = dyn_cast_or_null(AnnoVal)) {

168 Anno = Str;

169 }

170

171

172 if (auto *C = dyn_cast_or_null(OptAnnoVal);

173 C && C->getNumOperands()) {

174 Value *MaybeStruct = C->getOperand(0);

175 if (auto *Struct = dyn_cast(MaybeStruct)) {

176 for (unsigned I = 0, E = Struct->getNumOperands(); I != E; ++I) {

177 if (auto *CInt = dyn_cast(Struct->getOperand(I)))

178 Anno += (I == 0 ? ": " : ", ") +

179 std::to_string(CInt->getType()->getIntegerBitWidth() == 1

180 ? CInt->getZExtValue()

181 : CInt->getSExtValue());

182 }

183 } else if (auto *Struct = dyn_cast(MaybeStruct)) {

184

185 for (unsigned I = 0, E = Struct->getType()->getStructNumElements();

186 I != E; ++I)

187 Anno += I == 0 ? ": 0" : ", 0";

188 }

189 }

190 return Anno;

191}

192

194 const std::string &Anno,

196 Type *Int32Ty) {

197

198

199

200

201 static const std::regex R(

202 "\\{(\\d+)(?:[:,](\\d+|\"[^\"]*\")(?:,(\\d+|\"[^\"]*\"))*)?\\}");

204 int Pos = 0;

205 for (std::sregex_iterator

206 It = std::sregex_iterator(Anno.begin(), Anno.end(), R),

207 ItEnd = std::sregex_iterator();

208 It != ItEnd; ++It) {

209 if (It->position() != Pos)

211 Pos = It->position() + It->length();

212 std::smatch Match = *It;

214 for (std::size_t i = 1; i < Match.size(); ++i) {

215 std::ssub_match SMatch = Match[i];

216 std::string Item = SMatch.str();

217 if (Item.length() == 0)

218 break;

219 if (Item[0] == '"') {

220 Item = Item.substr(1, Item.length() - 2);

221

222 static const std::regex RStr("^(\\d+)(?:,(\\d+))*$");

223 if (std::smatch MatchStr; std::regex_match(Item, MatchStr, RStr)) {

224 for (std::size_t SubIdx = 1; SubIdx < MatchStr.size(); ++SubIdx)

225 if (std::string SubStr = MatchStr[SubIdx].str(); SubStr.length())

227 ConstantInt::get(Int32Ty, std::stoi(SubStr))));

228 } else {

230 }

231 } else if (int32_t Num; llvm::to_integer(StringRef(Item), Num, 10)) {

234 } else {

236 }

237 }

238 if (MDsItem.size() == 0)

241 }

242 return Pos == static_cast<int>(Anno.length()) ? MDs

244}

245

249

250

251 Value *PtrArg = nullptr;

252 if (auto *BI = dyn_cast(II->getArgOperand(0)))

253 PtrArg = BI->getOperand(0);

254 else

255 PtrArg = II->getOperand(0);

256 std::string Anno =

258 4 < II->arg_size() ? II->getArgOperand(4) : nullptr);

259

260

262

263

264

265

266

267 if (MDs.size() == 0) {

269 Int32Ty, static_cast<uint32_t>(SPIRV::Decoration::UserSemantic)));

271 }

272

273

277 Intrinsic::spv_assign_decoration, {PtrArg->getType()},

279 II->replaceAllUsesWith(II->getOperand(0));

280}

281

283

284

285

286

289 Type *FSHRetTy = FSHFuncTy->getReturnType();

290 const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic);

293

294 if (!FSHFunc->empty()) {

296 return;

297 }

301

302

303

304 FixedVectorType *VectorTy = dyn_cast(Ty);

308 Value *BitWidthForInsts =

309 VectorTy

311 : BitWidthConstant;

312 Value *RotateModVal =

313 IRB.CreateURem( FSHFunc->getArg(2), BitWidthForInsts);

314 Value *FirstShift = nullptr, *SecShift = nullptr;

315 if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {

316

317

318 FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal);

319 } else {

320

321

322 FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal);

323 }

324

325

326

327 Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal);

328 if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {

329

330

331 SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal);

332 } else {

333

334

336 }

337

339

341}

342

344

345

346

347

348

349

350

351

352

353 if (II->getIntrinsicID() == Intrinsic::assume) {

355 II->getModule(), Intrinsic::SPVIntrinsics::spv_assume);

356 II->setCalledFunction(F);

357 } else if (II->getIntrinsicID() == Intrinsic::expect) {

359 II->getModule(), Intrinsic::SPVIntrinsics::spv_expect,

360 {II->getOperand(0)->getType()});

361 II->setCalledFunction(F);

362 } else {

364 }

365

366 return;

367}

368

372 if (OpNos.empty()) {

374 } else {

376 for (unsigned OpNo : OpNos)

377 Tys.push_back(II->getOperand(OpNo)->getType());

379 }

380 II->setCalledFunction(F);

381 return true;

382}

383

384

385

386bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {

387 bool Changed = false;

390 auto Call = dyn_cast(&I);

391 if (!Call)

392 continue;

395 continue;

396 auto *II = cast(Call);

397 switch (II->getIntrinsicID()) {

398 case Intrinsic::memset:

399 case Intrinsic::bswap:

401 break;

402 case Intrinsic::fshl:

403 case Intrinsic::fshr:

405 Changed = true;

406 break;

407 case Intrinsic::assume:

408 case Intrinsic::expect: {

410 if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume))

412 Changed = true;

413 } break;

414 case Intrinsic::lifetime_start:

416 II, Intrinsic::SPVIntrinsics::spv_lifetime_start, {1});

417 break;

418 case Intrinsic::lifetime_end:

420 II, Intrinsic::SPVIntrinsics::spv_lifetime_end, {1});

421 break;

422 case Intrinsic::ptr_annotation:

424 Changed = true;

425 break;

426 }

427 }

428 }

429 return Changed;

430}

431

432

433

434

436SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {

437 bool IsRetAggr = F->getReturnType()->isAggregateType();

438

439 if (F->isIntrinsic() && IsRetAggr)

440 return F;

441

443

444 bool HasAggrArg =

445 std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) {

446 return Arg.getType()->isAggregateType();

447 });

448 bool DoClone = IsRetAggr || HasAggrArg;

449 if (!DoClone)

450 return F;

452 Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();

453 if (IsRetAggr)

454 ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType()));

456 for (const auto &Arg : F->args()) {

457 if (Arg.getType()->isAggregateType()) {

460 std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));

461 } else

462 ArgTypes.push_back(Arg.getType());

463 }

465 FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());

467 Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());

468

470 auto NewFArgIt = NewF->arg_begin();

471 for (auto &Arg : F->args()) {

472 StringRef ArgName = Arg.getName();

473 NewFArgIt->setName(ArgName);

474 VMap[&Arg] = &(*NewFArgIt++);

475 }

477

478 CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,

479 Returns);

481

483 F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");

486 for (auto &ChangedTyP : ChangedTypes)

488 B.getContext(),

489 {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),

490 ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));

493

495 if (auto *CI = dyn_cast(U))

497 U->replaceUsesOfWith(F, NewF);

498 }

499

500

501 if (RetType != F->getReturnType())

503 NewF, F->getReturnType());

504 return NewF;

505}

506

507bool SPIRVPrepareFunctions::runOnModule(Module &M) {

508 bool Changed = false;

510 Changed |= substituteIntrinsicCalls(&F);

512 }

513

514 std::vector<Function *> FuncsWorklist;

515 for (auto &F : M)

516 FuncsWorklist.push_back(&F);

517

518 for (auto *F : FuncsWorklist) {

519 Function *NewF = removeAggregateTypesFromSignature(F);

520

521 if (NewF != F) {

522 F->eraseFromParent();

523 Changed = true;

524 }

525 }

526 return Changed;

527}

528

531 return new SPIRVPrepareFunctions(TM);

532}

static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")

uint64_t IntrinsicInst * II

#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)

assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())

static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic)

static std::string getAnnotation(Value *AnnoVal, Value *OptAnnoVal)

static bool toSpvOverloadedIntrinsic(IntrinsicInst *II, Intrinsic::ID NewID, ArrayRef< unsigned > OpNos)

static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic)

static void lowerPtrAnnotation(IntrinsicInst *II)

static SmallVector< Metadata * > parseAnnotation(Value *I, const std::string &Anno, LLVMContext &Ctx, Type *Int32Ty)

static void lowerExpectAssume(IntrinsicInst *II)

static Function * getOrCreateFunction(Module *M, Type *RetTy, ArrayRef< Type * > ArgTypes, StringRef Name)

Represent the analysis usage information of a pass.

This class represents an incoming formal argument to a Function.

ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...

bool empty() const

empty - Check if the array is empty.

LLVM Basic Block Representation.

static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)

Creates a new BasicBlock.

FunctionType * getFunctionType() const

void setCalledFunction(Function *Fn)

Sets the function called, including updating the function type.

This is the shared class of boolean and integer constants.

Class to represent fixed width SIMD vectors.

unsigned getNumElements() const

A handy container for a FunctionType+Callee-pointer pair, which can be passed around as a single enti...

static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)

FunctionType * getFunctionType() const

Returns the FunctionType for me.

bool isIntrinsic() const

isIntrinsic - Returns true if the function's name starts with "llvm.".

Type * getReturnType() const

Returns the type of the ret val.

void setCallingConv(CallingConv::ID CC)

Argument * getArg(unsigned i) const

void setDSOLocal(bool Local)

@ ExternalLinkage

Externally visible function.

Value * CreateVectorSplat(unsigned NumElts, Value *V, const Twine &Name="")

Return a vector value that contains V broadcast to NumElts elements.

CallInst * CreateMemSet(Value *Ptr, Value *Val, uint64_t Size, MaybeAlign Align, bool isVolatile=false, MDNode *TBAATag=nullptr, MDNode *ScopeTag=nullptr, MDNode *NoAliasTag=nullptr)

Create and insert a memset to the specified pointer and the specified value.

Value * CreateLShr(Value *LHS, Value *RHS, const Twine &Name="", bool isExact=false)

ReturnInst * CreateRet(Value *V)

Create a 'ret <val>' instruction.

CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")

Create a call to intrinsic ID with Args, mangled using Types.

Value * CreateSub(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)

Value * CreateShl(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)

ReturnInst * CreateRetVoid()

Create a 'ret void' instruction.

Value * CreateOr(Value *LHS, Value *RHS, const Twine &Name="")

void SetInsertPoint(BasicBlock *TheBB)

This specifies that created instructions should be appended to the end of the specified block.

ConstantInt * getInt(const APInt &AI)

Get a constant integer value.

Value * CreateURem(Value *LHS, Value *RHS, const Twine &Name="")

This provides a uniform API for creating instructions and inserting them into a basic block: either a...

const Module * getModule() const

Return the module owning the function this instruction belongs to, or nullptr if the function does not have a module.

A wrapper class for inspecting calls to intrinsic functions.

Intrinsic::ID getIntrinsicID() const

Return the intrinsic ID of this intrinsic.

void LowerIntrinsicCall(CallInst *CI)

Replace a call to the specified intrinsic function.

This is an important class for using LLVM in a threaded context.

static MDTuple * get(LLVMContext &Context, ArrayRef< Metadata * > MDs)

static MDString * get(LLVMContext &Context, StringRef Str)

This class wraps the llvm.memset and llvm.memset.inline intrinsics.

ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...

virtual bool runOnModule(Module &M)=0

runOnModule - Virtual method overridden by subclasses to process the module being operated on.

A Module instance is used to store all the information related to an LLVM module.

void addOperand(MDNode *M)

PassRegistry - This class manages the registration and initialization of the pass subsystem as appli...

static PassRegistry * getPassRegistry()

getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...

virtual void getAnalysisUsage(AnalysisUsage &) const

getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...

virtual StringRef getPassName() const

getPassName - Return a nice clean name for a pass.

void addMutated(Value *Val, Type *Ty)

SPIRVGlobalRegistry * getSPIRVGlobalRegistry() const

bool canUseExtension(SPIRV::Extension::Extension E) const

void push_back(const T &Elt)

This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.

StringRef - Represent a constant reference to a string, i.e.

std::string str() const

str - Get the contents as an std::string.

The instances of the Type class are immutable: once they are created, they are never changed.

unsigned getIntegerBitWidth() const

static IntegerType * getInt32Ty(LLVMContext &C)

LLVM Value Representation.

Type * getType() const

All values are typed, get the type of this value.

void setName(const Twine &Name)

Change the name of the value.

StringRef getName() const

Return a constant reference to the value's name.

void takeName(Value *V)

Transfer the name from V to this value.

Type * getElementType() const

#define llvm_unreachable(msg)

Marks that the current location is not supposed to be reachable.

@ SPIR_FUNC

Used for SPIR non-kernel device functions.

@ C

The default llvm calling convention, compatible with C.

unsigned ID

LLVM IR allows arbitrary numbers to be used as calling convention identifiers.

Function * getOrInsertDeclaration(Module *M, ID id, ArrayRef< Type * > Tys={})

Look up the Function declaration of the intrinsic id in the Module M.

This is an optimization pass for GlobalISel generic memory operations.

void initializeSPIRVPrepareFunctionsPass(PassRegistry &)

bool getConstantStringInfo(const Value *V, StringRef &Str, bool TrimAtNul=true)

This function computes the length of a null-terminated C string pointed to by V.

iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)

Make a range that does early increment to allow mutation of the underlying range without disrupting i...

bool sortBlocks(Function &F)

@ Ref

The access may reference the value stored in memory.

constexpr unsigned BitWidth

void CloneFunctionInto(Function *NewFunc, const Function *OldFunc, ValueToValueMapTy &VMap, CloneFunctionChangeType Changes, SmallVectorImpl< ReturnInst * > &Returns, const char *NameSuffix="", ClonedCodeInfo *CodeInfo=nullptr, ValueMapTypeRemapper *TypeMapper=nullptr, ValueMaterializer *Materializer=nullptr)

Clone OldFunc into NewFunc, transforming the old arguments into references to VMap values.

ModulePass * createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM)

void expandMemSetAsLoop(MemSetInst *MemSet)

Expand MemSet as a loop. MemSet is not deleted.

Implement std::hash so that hash_code can be used in STL containers.