LLVM: lib/Target/DirectX/DXILFlattenArrays.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
24#include
25#include
26#include
27#include
28
29#define DEBUG_TYPE "dxil-flatten-arrays"
30
31using namespace llvm;
32namespace {
33
34class DXILFlattenArraysLegacy : public ModulePass {
35
36public:
39
40 static char ID;
41};
42
43struct GEPData {
45 Value *ParendOperand;
48 bool AllIndicesAreConstInt;
49};
50
51class DXILFlattenArraysVisitor
52 : public InstVisitor<DXILFlattenArraysVisitor, bool> {
53public:
54 DXILFlattenArraysVisitor() {}
56
57
76 static bool isMultiDimensionalArray(Type *T);
77 static std::pair<unsigned, Type *> getElementCountAndType(Type *ArrayTy);
78
79private:
82 bool finish();
89 void
92 unsigned &GEPChainUseCount,
95 bool AllIndicesAreConstInt = true);
97 bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo,
99};
100}
101
102bool DXILFlattenArraysVisitor::finish() {
104 return true;
105}
106
107bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) {
108 if (ArrayType *ArrType = dyn_cast(T))
109 return isa(ArrType->getElementType());
110 return false;
111}
112
113std::pair<unsigned, Type *>
114DXILFlattenArraysVisitor::getElementCountAndType(Type *ArrayTy) {
115 unsigned TotalElements = 1;
116 Type *CurrArrayTy = ArrayTy;
117 while (auto *InnerArrayTy = dyn_cast(CurrArrayTy)) {
118 TotalElements *= InnerArrayTy->getNumElements();
119 CurrArrayTy = InnerArrayTy->getElementType();
120 }
121 return std::make_pair(TotalElements, CurrArrayTy);
122}
123
// Folds a set of constant per-dimension indices into one flat index and
// returns it as a 32-bit constant created through Builder.
// NOTE(review): the parameter list and the start of the assert are not visible
// in this listing; Indices, Dims and Builder are presumably parameters —
// confirm against the full source.
124ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices(
127 "Indicies and dimmensions should be the same");
128 unsigned FlatIndex = 0;
129 unsigned Multiplier = 1;
130
// Walk from the last index to the first. Multiplier holds the product of the
// sizes of the dimensions already visited, so each index is scaled by the
// extent of everything to its right.
131 for (int I = Indices.size() - 1; I >= 0; --I) {
132 unsigned DimSize = Dims[I];
133 ConstantInt *CIndex = dyn_cast(Indices[I]);
134 assert(CIndex && "This function expects all indicies to be ConstantInt");
// The zero-extended index value is multiplied by Multiplier and accumulated
// into the unsigned FlatIndex.
135 FlatIndex += CIndex->getZExtValue() * Multiplier;
136 Multiplier *= DimSize;
137 }
138 return Builder.getInt32(FlatIndex);
139}
140
// Builds the flat index as a chain of multiply/add instructions emitted
// through Builder, for the case where the indices are not all constants.
// NOTE(review): the parameter list is not visible in this listing; Indices,
// Dims and Builder are presumably parameters — confirm against the full source.
141Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices(
// A single index needs no arithmetic; it is returned unchanged.
143 if (Indices.size() == 1)
144 return Indices[0];
145
// NOTE(review): the declaration and initial value of FlatIndex are not visible
// in this listing — confirm what the accumulation below is seeded with.
147 unsigned Multiplier = 1;
148
// Walk from the last index to the first. Multiplier is the running product of
// the dimension sizes already visited and is materialised as an i32 constant
// for each multiply.
149 for (int I = Indices.size() - 1; I >= 0; --I) {
150 unsigned DimSize = Dims[I];
151 Value *VMultiplier = Builder.getInt32(Multiplier);
152 Value *ScaledIndex = Builder.CreateMul(Indices[I], VMultiplier);
153 FlatIndex = Builder.CreateAdd(FlatIndex, ScaledIndex);
154 Multiplier *= DimSize;
155 }
156 return FlatIndex;
157}
158
159bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) {
161 for (unsigned I = 0; I < NumOperands; ++I) {
163 ConstantExpr *CE = dyn_cast(CurrOpperand);
164 if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
166 cast(CE->getAsInstruction());
168
175 visitGetElementPtrInst(*OldGEP);
176 return true;
177 }
178 }
179 return false;
180}
181
182bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) {
183 unsigned NumOperands = SI.getNumOperands();
184 for (unsigned I = 0; I < NumOperands; ++I) {
185 Value *CurrOpperand = SI.getOperand(I);
186 ConstantExpr *CE = dyn_cast(CurrOpperand);
187 if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
189 cast(CE->getAsInstruction());
191
195 SI.replaceAllUsesWith(NewStore);
196 SI.eraseFromParent();
197 visitGetElementPtrInst(*OldGEP);
198 return true;
199 }
200 }
201 return false;
202}
203
204bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
206 return false;
207
210 auto [TotalElements, BaseType] = getElementCountAndType(ArrType);
211
212 ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements);
214 Builder.CreateAlloca(FattenedArrayType, nullptr, AI.getName() + ".flat");
218 return true;
219}
220
221void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
226 AllIndicesAreConstInt &= isa(LastIndex);
231 bool IsMultiDimArr = isMultiDimensionalArray(CurrGEP.getSourceElementType());
232 if (!IsMultiDimArr) {
234 GEPChainMap.insert(
235 {&CurrGEP,
236 {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
237 std::move(Dims), AllIndicesAreConstInt}});
238 return;
239 }
240 bool GepUses = false;
241 for (auto *User : CurrGEP.users()) {
243 recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand,
244 ++GEPChainUseCount, Indices, Dims,
245 AllIndicesAreConstInt);
246 GepUses = true;
247 }
248 }
249
250 if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
251 GEPChainMap.insert(
252 {&CurrGEP,
253 {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
254 std::move(Dims), AllIndicesAreConstInt}});
255 }
256}
257
258bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(
260 GEPData GEPInfo = GEPChainMap.at(&GEP);
261 return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
262}
263bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
266 Value *FlatIndex;
267 if (GEPInfo.AllIndicesAreConstInt)
268 FlatIndex = genConstFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
269 else
270 FlatIndex =
271 genInstructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
272
273 ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType;
275 Builder.CreateGEP(FlattenedArrayType, GEPInfo.ParendOperand, FlatIndex,
276 GEP.getName() + ".flat", GEP.isInBounds());
277
278 GEP.replaceAllUsesWith(FlatGEP);
279 GEP.eraseFromParent();
280 return true;
281}
282
283bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
284 auto It = GEPChainMap.find(&GEP);
285 if (It != GEPChainMap.end())
286 return visitGetElementPtrInstInGEPChain(GEP);
287 if (!isMultiDimensionalArray(GEP.getSourceElementType()))
288 return false;
289
290 ArrayType *ArrType = cast(GEP.getSourceElementType());
292 auto [TotalElements, BaseType] = getElementCountAndType(ArrType);
293 ArrayType *FlattenedArrayType = ArrayType::get(BaseType, TotalElements);
294
295 Value *PtrOperand = GEP.getPointerOperand();
296
297 unsigned GEPChainUseCount = 0;
298 recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
299
300
301
302
303
304 if (GEPChainUseCount == 0) {
307 bool AllIndicesAreConstInt = isa(Indices[0]);
308 GEPData GEPInfo{std::move(FlattenedArrayType), PtrOperand,
309 std::move(Indices), std::move(Dims), AllIndicesAreConstInt};
310 return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
311 }
312
313 PotentiallyDeadInstrs.emplace_back(&GEP);
314 return false;
315}
316
317bool DXILFlattenArraysVisitor::visit(Function &F) {
318 bool MadeChange = false;
323 }
324 finish();
325 return MadeChange;
326}
327
330
331 auto *ArrayTy = dyn_cast(Init->getType());
332 if (!ArrayTy) {
333 Elements.push_back(Init);
334 return;
335 }
336 unsigned ArrSize = ArrayTy->getNumElements();
337 if (isa(Init)) {
338 for (unsigned I = 0; I < ArrSize; ++I)
340 return;
341 }
342
343
344 if (auto *ArrayConstant = dyn_cast(Init)) {
345 for (unsigned I = 0; I < ArrayConstant->getNumOperands(); ++I) {
347 }
348 } else if (auto *DataArrayConstant = dyn_cast(Init)) {
349 for (unsigned I = 0; I < DataArrayConstant->getNumElements(); ++I) {
350 collectElements(DataArrayConstant->getElementAsConstant(I), Elements);
351 }
352 } else {
354 "Expected a ConstantArray or ConstantDataArray for array initializer!");
355 }
356}
357
361
362 if (isa(Init))
364
365
366 if (isa(Init))
368
369 if (!isa(OrigType))
371
374 assert(FlattenedType->getNumElements() == FlattenedElements.size() &&
375 "The number of collected elements should match the FlattenedType");
377}
378
379static void
384 Type *OrigType = G.getValueType();
385 if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType))
386 continue;
387
388 ArrayType *ArrType = cast(OrigType);
389 auto [TotalElements, BaseType] =
390 DXILFlattenArraysVisitor::getElementCountAndType(ArrType);
391 ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements);
392
393
394
396 new GlobalVariable(M, FattenedArrayType, G.isConstant(), G.getLinkage(),
397 nullptr, G.getName() + ".1dim", &G,
398 G.getThreadLocalMode(), G.getAddressSpace(),
399 G.isExternallyInitialized());
400
401
403 if (G.getAlignment() > 0) {
405 }
406
407 if (G.hasInitializer()) {
412 }
413 GlobalMap[&G] = NewGlobal;
414 }
415}
416
418 bool MadeChange = false;
419 DXILFlattenArraysVisitor Impl;
423 if (F.isDeclaration())
424 continue;
425 MadeChange |= Impl.visit(F);
426 }
427 for (auto &[Old, New] : GlobalMap) {
428 Old->replaceAllUsesWith(New);
429 Old->eraseFromParent();
430 MadeChange = true;
431 }
432 return MadeChange;
433}
434
437 if (!MadeChanges)
440 return PA;
441}
442
443bool DXILFlattenArraysLegacy::runOnModule(Module &M) {
445}
446
447char DXILFlattenArraysLegacy::ID = 0;
448
450 "DXIL Array Flattener", false, false)
453
455 return new DXILFlattenArraysLegacy();
456}
static void collectElements(Constant *Init, SmallVectorImpl< Constant * > &Elements)
static bool flattenArrays(Module &M)
static void flattenGlobalArrays(Module &M, DenseMap< GlobalVariable *, GlobalVariable * > &GlobalMap)
static Constant * transformInitializer(Constant *Init, Type *OrigType, ArrayType *FlattenedType, LLVMContext &Ctx)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static unsigned getNumElements(Type *Ty)
an instruction to allocate memory on the stack
Align getAlign() const
Return the alignment of the memory that is being allocated by the instruction.
Type * getAllocatedType() const
Return the type that is being allocated by the instruction.
void setAlignment(Align Align)
A container for analyses that lazily runs them and caches their results.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
size_t size() const
size - Get the array size.
LLVM Basic Block Representation.
This class represents a no-op cast from one type to another.
This class represents a function call, abstracting a target machine's calling convention.
This is the base class for all instructions that perform data casts.
static ConstantAggregateZero * get(Type *Ty)
static Constant * get(ArrayType *T, ArrayRef< Constant * > V)
A constant value that is initialized with an expression using other constant values.
This is the shared class of boolean and integer constants.
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
This is an important base class in LLVM.
static Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
PreservedAnalyses run(Module &M, ModuleAnalysisManager &)
This instruction compares its operands according to the predicate given to the constructor.
This class represents a freeze function that returns random concrete value if an operand is either a ...
an instruction for type-safe pointer arithmetic to access elements of arrays and structs
Type * getSourceElementType() const
void setAlignment(Align Align)
Sets the alignment attribute of the GlobalObject.
void setUnnamedAddr(UnnamedAddr Val)
void setInitializer(Constant *InitVal)
setInitializer - Sets the initializer for this global variable, removing any existing initializer if ...
This instruction compares its operands according to the predicate given to the constructor.
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())
ConstantInt * getInt32(uint32_t C)
Get a constant 32-bit value.
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
This instruction inserts a single (scalar) element into a VectorType value.
Base class for instruction visitors.
RetTy visitFreezeInst(FreezeInst &I)
RetTy visitFCmpInst(FCmpInst &I)
RetTy visitExtractElementInst(ExtractElementInst &I)
RetTy visitShuffleVectorInst(ShuffleVectorInst &I)
RetTy visitBitCastInst(BitCastInst &I)
void visit(Iterator Start, Iterator End)
RetTy visitPHINode(PHINode &I)
RetTy visitUnaryOperator(UnaryOperator &I)
RetTy visitStoreInst(StoreInst &I)
RetTy visitInsertElementInst(InsertElementInst &I)
RetTy visitAllocaInst(AllocaInst &I)
RetTy visitBinaryOperator(BinaryOperator &I)
RetTy visitICmpInst(ICmpInst &I)
RetTy visitCallInst(CallInst &I)
RetTy visitCastInst(CastInst &I)
RetTy visitSelectInst(SelectInst &I)
RetTy visitGetElementPtrInst(GetElementPtrInst &I)
void visitInstruction(Instruction &I)
RetTy visitLoadInst(LoadInst &I)
void insertBefore(Instruction *InsertPos)
Insert an unlinked instruction into a basic block immediately before the specified instruction.
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
This is an important class for using LLVM in a threaded context.
An instruction for reading from memory.
void setAlignment(Align Align)
Align getAlign() const
Return the alignment of the access that is being performed.
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
virtual bool runOnModule(Module &M)=0
runOnModule - Virtual method overridden by subclasses to process the module being operated on.
A Module instance is used to store all the information related to an LLVM module.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
This class represents the LLVM 'select' instruction.
This instruction constructs a fixed permutation of two input vectors.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
void setAlignment(Align Align)
The instances of the Type class are immutable: once they are created, they are never changed.
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Value * getOperand(unsigned i) const
unsigned getNumOperands() const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
iterator_range< user_iterator > users()
StringRef getName() const
Return a constant reference to the value's name.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ CE
Windows NT (Windows on ARM)
This is an optimization pass for GlobalISel generic memory operations.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
ModulePass * createDXILFlattenArraysLegacyPass()
Pass to flatten arrays into a one dimensional DXIL legal form.
bool RecursivelyDeleteTriviallyDeadInstructionsPermissive(SmallVectorImpl< WeakTrackingVH > &DeadInsts, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
Same functionality as RecursivelyDeleteTriviallyDeadInstructions, but allow instructions that are not...