LLVM: lib/Target/DirectX/DXILFlattenArrays.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

24#include

25#include

26#include

27#include

28

29#define DEBUG_TYPE "dxil-flatten-arrays"

30

31using namespace llvm;

32namespace {

33

34class DXILFlattenArraysLegacy : public ModulePass {

35

36public:

39

40 static char ID;

41};

42

43struct GEPData {

45 Value *ParendOperand;

48 bool AllIndicesAreConstInt;

49};

50

51class DXILFlattenArraysVisitor

52 : public InstVisitor<DXILFlattenArraysVisitor, bool> {

53public:

54 DXILFlattenArraysVisitor() {}

56

57

76 static bool isMultiDimensionalArray(Type *T);

77 static std::pair<unsigned, Type *> getElementCountAndType(Type *ArrayTy);

78

79private:

82 bool finish();

89 void

92 unsigned &GEPChainUseCount,

95 bool AllIndicesAreConstInt = true);

97 bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo,

99};

100}

101

102bool DXILFlattenArraysVisitor::finish() {

104 return true;

105}

106

107bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) {

108 if (ArrayType *ArrType = dyn_cast(T))

109 return isa(ArrType->getElementType());

110 return false;

111}

112

113std::pair<unsigned, Type *>

114DXILFlattenArraysVisitor::getElementCountAndType(Type *ArrayTy) {

115 unsigned TotalElements = 1;

116 Type *CurrArrayTy = ArrayTy;

117 while (auto *InnerArrayTy = dyn_cast(CurrArrayTy)) {

118 TotalElements *= InnerArrayTy->getNumElements();

119 CurrArrayTy = InnerArrayTy->getElementType();

120 }

121 return std::make_pair(TotalElements, CurrArrayTy);

122}

123

124ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices(

127 "Indicies and dimmensions should be the same");

128 unsigned FlatIndex = 0;

129 unsigned Multiplier = 1;

130

131 for (int I = Indices.size() - 1; I >= 0; --I) {

132 unsigned DimSize = Dims[I];

133 ConstantInt *CIndex = dyn_cast(Indices[I]);

134 assert(CIndex && "This function expects all indicies to be ConstantInt");

135 FlatIndex += CIndex->getZExtValue() * Multiplier;

136 Multiplier *= DimSize;

137 }

138 return Builder.getInt32(FlatIndex);

139}

140

141Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices(

143 if (Indices.size() == 1)

144 return Indices[0];

145

147 unsigned Multiplier = 1;

148

149 for (int I = Indices.size() - 1; I >= 0; --I) {

150 unsigned DimSize = Dims[I];

151 Value *VMultiplier = Builder.getInt32(Multiplier);

152 Value *ScaledIndex = Builder.CreateMul(Indices[I], VMultiplier);

153 FlatIndex = Builder.CreateAdd(FlatIndex, ScaledIndex);

154 Multiplier *= DimSize;

155 }

156 return FlatIndex;

157}

158

159bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) {

161 for (unsigned I = 0; I < NumOperands; ++I) {

163 ConstantExpr *CE = dyn_cast(CurrOpperand);

164 if (CE && CE->getOpcode() == Instruction::GetElementPtr) {

166 cast(CE->getAsInstruction());

168

175 visitGetElementPtrInst(*OldGEP);

176 return true;

177 }

178 }

179 return false;

180}

181

182bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) {

183 unsigned NumOperands = SI.getNumOperands();

184 for (unsigned I = 0; I < NumOperands; ++I) {

185 Value *CurrOpperand = SI.getOperand(I);

186 ConstantExpr *CE = dyn_cast(CurrOpperand);

187 if (CE && CE->getOpcode() == Instruction::GetElementPtr) {

189 cast(CE->getAsInstruction());

191

195 SI.replaceAllUsesWith(NewStore);

196 SI.eraseFromParent();

197 visitGetElementPtrInst(*OldGEP);

198 return true;

199 }

200 }

201 return false;

202}

203

204bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {

206 return false;

207

210 auto [TotalElements, BaseType] = getElementCountAndType(ArrType);

211

212 ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements);

214 Builder.CreateAlloca(FattenedArrayType, nullptr, AI.getName() + ".flat");

218 return true;

219}

220

221void DXILFlattenArraysVisitor::recursivelyCollectGEPs(

226 AllIndicesAreConstInt &= isa(LastIndex);

231 bool IsMultiDimArr = isMultiDimensionalArray(CurrGEP.getSourceElementType());

232 if (!IsMultiDimArr) {

234 GEPChainMap.insert(

235 {&CurrGEP,

236 {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),

237 std::move(Dims), AllIndicesAreConstInt}});

238 return;

239 }

240 bool GepUses = false;

241 for (auto *User : CurrGEP.users()) {

243 recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand,

244 ++GEPChainUseCount, Indices, Dims,

245 AllIndicesAreConstInt);

246 GepUses = true;

247 }

248 }

249

250 if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {

251 GEPChainMap.insert(

252 {&CurrGEP,

253 {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),

254 std::move(Dims), AllIndicesAreConstInt}});

255 }

256}

257

258bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(

260 GEPData GEPInfo = GEPChainMap.at(&GEP);

261 return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);

262}

263bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(

266 Value *FlatIndex;

267 if (GEPInfo.AllIndicesAreConstInt)

268 FlatIndex = genConstFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);

269 else

270 FlatIndex =

271 genInstructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);

272

273 ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType;

275 Builder.CreateGEP(FlattenedArrayType, GEPInfo.ParendOperand, FlatIndex,

276 GEP.getName() + ".flat", GEP.isInBounds());

277

278 GEP.replaceAllUsesWith(FlatGEP);

279 GEP.eraseFromParent();

280 return true;

281}

282

283bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {

284 auto It = GEPChainMap.find(&GEP);

285 if (It != GEPChainMap.end())

286 return visitGetElementPtrInstInGEPChain(GEP);

287 if (!isMultiDimensionalArray(GEP.getSourceElementType()))

288 return false;

289

290 ArrayType *ArrType = cast(GEP.getSourceElementType());

292 auto [TotalElements, BaseType] = getElementCountAndType(ArrType);

293 ArrayType *FlattenedArrayType = ArrayType::get(BaseType, TotalElements);

294

295 Value *PtrOperand = GEP.getPointerOperand();

296

297 unsigned GEPChainUseCount = 0;

298 recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);

299

300

301

302

303

304 if (GEPChainUseCount == 0) {

307 bool AllIndicesAreConstInt = isa(Indices[0]);

308 GEPData GEPInfo{std::move(FlattenedArrayType), PtrOperand,

309 std::move(Indices), std::move(Dims), AllIndicesAreConstInt};

310 return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);

311 }

312

313 PotentiallyDeadInstrs.emplace_back(&GEP);

314 return false;

315}

316

317bool DXILFlattenArraysVisitor::visit(Function &F) {

318 bool MadeChange = false;

323 }

324 finish();

325 return MadeChange;

326}

327

330

331 auto *ArrayTy = dyn_cast(Init->getType());

332 if (!ArrayTy) {

333 Elements.push_back(Init);

334 return;

335 }

336 unsigned ArrSize = ArrayTy->getNumElements();

337 if (isa(Init)) {

338 for (unsigned I = 0; I < ArrSize; ++I)

340 return;

341 }

342

343

344 if (auto *ArrayConstant = dyn_cast(Init)) {

345 for (unsigned I = 0; I < ArrayConstant->getNumOperands(); ++I) {

347 }

348 } else if (auto *DataArrayConstant = dyn_cast(Init)) {

349 for (unsigned I = 0; I < DataArrayConstant->getNumElements(); ++I) {

350 collectElements(DataArrayConstant->getElementAsConstant(I), Elements);

351 }

352 } else {

354 "Expected a ConstantArray or ConstantDataArray for array initializer!");

355 }

356}

357

361

362 if (isa(Init))

364

365

366 if (isa(Init))

368

369 if (!isa(OrigType))

371

374 assert(FlattenedType->getNumElements() == FlattenedElements.size() &&

375 "The number of collected elements should match the FlattenedType");

377}

378

379static void

384 Type *OrigType = G.getValueType();

385 if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType))

386 continue;

387

388 ArrayType *ArrType = cast(OrigType);

389 auto [TotalElements, BaseType] =

390 DXILFlattenArraysVisitor::getElementCountAndType(ArrType);

391 ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements);

392

393

394

396 new GlobalVariable(M, FattenedArrayType, G.isConstant(), G.getLinkage(),

397 nullptr, G.getName() + ".1dim", &G,

398 G.getThreadLocalMode(), G.getAddressSpace(),

399 G.isExternallyInitialized());

400

401

403 if (G.getAlignment() > 0) {

405 }

406

407 if (G.hasInitializer()) {

412 }

413 GlobalMap[&G] = NewGlobal;

414 }

415}

416

418 bool MadeChange = false;

419 DXILFlattenArraysVisitor Impl;

423 if (F.isDeclaration())

424 continue;

425 MadeChange |= Impl.visit(F);

426 }

427 for (auto &[Old, New] : GlobalMap) {

428 Old->replaceAllUsesWith(New);

429 Old->eraseFromParent();

430 MadeChange = true;

431 }

432 return MadeChange;

433}

434

437 if (!MadeChanges)

440 return PA;

441}

442

443bool DXILFlattenArraysLegacy::runOnModule(Module &M) {

445}

446

447char DXILFlattenArraysLegacy::ID = 0;

448

450 "DXIL Array Flattener", false, false)

453

455 return new DXILFlattenArraysLegacy();

456}

static void collectElements(Constant *Init, SmallVectorImpl< Constant * > &Elements)

static bool flattenArrays(Module &M)

static void flattenGlobalArrays(Module &M, DenseMap< GlobalVariable *, GlobalVariable * > &GlobalMap)

static Constant * transformInitializer(Constant *Init, Type *OrigType, ArrayType *FlattenedType, LLVMContext &Ctx)

#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)

#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)

This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.

assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())

static unsigned getNumElements(Type *Ty)

an instruction to allocate memory on the stack

Align getAlign() const

Return the alignment of the memory that is being allocated by the instruction.

Type * getAllocatedType() const

Return the type that is being allocated by the instruction.

void setAlignment(Align Align)

A container for analyses that lazily runs them and caches their results.

ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...

size_t size() const

size - Get the array size.

LLVM Basic Block Representation.

This class represents a no-op cast from one type to another.

This class represents a function call, abstracting a target machine's calling convention.

This is the base class for all instructions that perform data casts.

static ConstantAggregateZero * get(Type *Ty)

static Constant * get(ArrayType *T, ArrayRef< Constant * > V)

A constant value that is initialized with an expression using other constant values.

This is the shared class of boolean and integer constants.

uint64_t getZExtValue() const

Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...

This is an important base class in LLVM.

static Constant * getNullValue(Type *Ty)

Constructor to create a '0' constant of arbitrary type.

PreservedAnalyses run(Module &M, ModuleAnalysisManager &)

This instruction compares its operands according to the predicate given to the constructor.

This class represents a freeze function that returns random concrete value if an operand is either a ...

an instruction for type-safe pointer arithmetic to access elements of arrays and structs

Type * getSourceElementType() const

void setAlignment(Align Align)

Sets the alignment attribute of the GlobalObject.

void setUnnamedAddr(UnnamedAddr Val)

void setInitializer(Constant *InitVal)

setInitializer - Sets the initializer for this global variable, removing any existing initializer if ...

This instruction compares its operands according to the predicate given to the constructor.

AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")

Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())

ConstantInt * getInt32(uint32_t C)

Get a constant 32-bit value.

LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)

Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...

StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)

Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)

Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)

This provides a uniform API for creating instructions and inserting them into a basic block: either a...

This instruction inserts a single (scalar) element into a VectorType value.

Base class for instruction visitors.

RetTy visitFreezeInst(FreezeInst &I)

RetTy visitFCmpInst(FCmpInst &I)

RetTy visitExtractElementInst(ExtractElementInst &I)

RetTy visitShuffleVectorInst(ShuffleVectorInst &I)

RetTy visitBitCastInst(BitCastInst &I)

void visit(Iterator Start, Iterator End)

RetTy visitPHINode(PHINode &I)

RetTy visitUnaryOperator(UnaryOperator &I)

RetTy visitStoreInst(StoreInst &I)

RetTy visitInsertElementInst(InsertElementInst &I)

RetTy visitAllocaInst(AllocaInst &I)

RetTy visitBinaryOperator(BinaryOperator &I)

RetTy visitICmpInst(ICmpInst &I)

RetTy visitCallInst(CallInst &I)

RetTy visitCastInst(CastInst &I)

RetTy visitSelectInst(SelectInst &I)

RetTy visitGetElementPtrInst(GetElementPtrInst &I)

void visitInstruction(Instruction &I)

RetTy visitLoadInst(LoadInst &I)

void insertBefore(Instruction *InsertPos)

Insert an unlinked instruction into a basic block immediately before the specified instruction.

InstListType::iterator eraseFromParent()

This method unlinks 'this' from the containing basic block and deletes it.

This is an important class for using LLVM in a threaded context.

An instruction for reading from memory.

void setAlignment(Align Align)

Align getAlign() const

Return the alignment of the access that is being performed.

ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...

virtual bool runOnModule(Module &M)=0

runOnModule - Virtual method overridden by subclasses to process the module being operated on.

A Module instance is used to store all the information related to an LLVM module.

A set of analyses that are preserved following a run of a transformation pass.

static PreservedAnalyses all()

Construct a special preserved set that preserves all passes.

This class represents the LLVM 'select' instruction.

This instruction constructs a fixed permutation of two input vectors.

This class consists of common code factored out of the SmallVector class to reduce code duplication b...

void push_back(const T &Elt)

This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.

An instruction for storing to memory.

void setAlignment(Align Align)

The instances of the Type class are immutable: once they are created, they are never changed.

static UndefValue * get(Type *T)

Static factory methods - Return an 'undef' object of the specified type.

Value * getOperand(unsigned i) const

unsigned getNumOperands() const

LLVM Value Representation.

Type * getType() const

All values are typed, get the type of this value.

void replaceAllUsesWith(Value *V)

Change all uses of this to point to a new Value.

iterator_range< user_iterator > users()

StringRef getName() const

Return a constant reference to the value's name.

#define llvm_unreachable(msg)

Marks that the current location is not supposed to be reachable.

unsigned ID

LLVM IR allows arbitrary numbers to be used as calling convention identifiers.

@ CE

Windows NT (Windows on ARM)

This is an optimization pass for GlobalISel generic memory operations.

iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)

Make a range that does early increment to allow mutation of the underlying range without disrupting i...

ModulePass * createDXILFlattenArraysLegacyPass()

Pass to flatten arrays into a one dimensional DXIL legal form.

bool RecursivelyDeleteTriviallyDeadInstructionsPermissive(SmallVectorImpl< WeakTrackingVH > &DeadInsts, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())

Same functionality as RecursivelyDeleteTriviallyDeadInstructions, but allows instructions that are not...