LLVM: lib/Target/DirectX/DXILFlattenArrays.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
25#include
26#include
27#include
28#include
29
30#define DEBUG_TYPE "dxil-flatten-arrays"
31
32using namespace llvm;
33namespace {
34
// Legacy pass-manager (ModulePass) wrapper for the array-flattening
// transform. NOTE(review): some member declarations (e.g. source lines 36
// and 39-40) appear to be missing from this copy.
35class DXILFlattenArraysLegacy : public ModulePass {
36
37public:
38 bool runOnModule(Module &M) override;
40
41 static char ID;
42};
43
// Per-GEP record describing how a GEP indexes into a flattened root array.
// visitGetElementPtrInst copies a parent GEP's record into a child's record,
// so a chain of GEPs can be rewritten as a single flat GEP.
44struct GEPInfo {
 // Flattened (1-D) array type that the GEP chain ultimately indexes into.
45 ArrayType *RootFlattenedArrayType;
 // Pointer the rewritten flat GEP is based on (the chain's root operand).
46 Value *RootPointerOperand;
 // NOTE(review): source line 47 (the VariableOffsets member, read and
 // extended in visitGetElementPtrInst) is missing from this copy.
 // Constant byte offset accumulated along the GEP chain.
48 APInt ConstantOffset;
49};
50
// Instruction visitor that rewrites multi-dimensional array allocas and GEPs
// (plus loads/stores whose operands are GEP constant expressions) so they
// address one-dimensional arrays. Each visit method returns true when it
// changed the IR; instruction kinds this pass does not handle return false.
51class DXILFlattenArraysVisitor
52 : public InstVisitor<DXILFlattenArraysVisitor, bool> {
53public:
54 DXILFlattenArraysVisitor(
56 : GlobalMap(GlobalMap) {}
58
59
 // Instruction kinds that are never transformed: report "no change".
62 bool visitInstruction(Instruction &I) { return false; }
63 bool visitSelectInst(SelectInst &SI) { return false; }
64 bool visitICmpInst(ICmpInst &ICI) { return false; }
65 bool visitFCmpInst(FCmpInst &FCI) { return false; }
66 bool visitUnaryOperator(UnaryOperator &UO) { return false; }
67 bool visitBinaryOperator(BinaryOperator &BO) { return false; }
68 bool visitCastInst(CastInst &CI) { return false; }
69 bool visitBitCastInst(BitCastInst &BCI) { return false; }
70 bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
71 bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
72 bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
73 bool visitPHINode(PHINode &PHI) { return false; }
74 bool visitLoadInst(LoadInst &LI);
76 bool visitCallInst(CallInst &ICI) { return false; }
77 bool visitFreezeInst(FreezeInst &FI) { return false; }
 // Type helpers; static so global-variable flattening can reuse them.
78 static bool isMultiDimensionalArray(Type *T);
79 static std::pair<unsigned, Type *> getElementCountAndType(Type *ArrayTy);
80
81private:
 // NOTE(review): several private members (GEPChainInfoMap,
 // PotentiallyDeadInstrs, GlobalMap, index helpers) are declared on source
 // lines missing from this copy.
85 bool finish();
92};
93}
94
// Resets per-function state once visit(Function&) has walked a function.
// Clears the GEP chain records and always returns true.
// NOTE(review): source line 97 is missing from this copy. It presumably
// disposes of the GEPs queued in PotentiallyDeadInstrs; confirm upstream.
95bool DXILFlattenArraysVisitor::finish() {
96 GEPChainInfoMap.clear();
98 return true;
99}
100
// Predicate used to decide whether a type needs flattening (callers skip
// globals for which it is false, and assert it is false for already-flattened
// roots). NOTE(review): the actual type test (source lines 102-103) is
// missing from this copy. Presumably it returns true when T is an array whose
// element type is also an array; only the fall-through `return false` is
// visible here.
101bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) {
104 return false;
105}
106
// Peels nested array types off ArrayTy, multiplying their element counts.
// Returns {total element count, innermost element type}. Callers build the
// flattened type as ArrayType::get(BaseType, TotalElements).
// NOTE(review): the loop header (source line 111) is missing from this copy.
// It presumably iterates while CurrArrayTy is an ArrayType, binding
// InnerArrayTy. The count is accumulated in a 32-bit `unsigned`.
107std::pair<unsigned, Type *>
108DXILFlattenArraysVisitor::getElementCountAndType(Type *ArrayTy) {
109 unsigned TotalElements = 1;
110 Type *CurrArrayTy = ArrayTy;
112 TotalElements *= InnerArrayTy->getNumElements();
113 CurrArrayTy = InnerArrayTy->getElementType();
114 }
115 return std::make_pair(TotalElements, CurrArrayTy);
116}
117
// Folds constant per-dimension indices into one row-major flat index.
// Walks from the innermost dimension outward, scaling each index by the
// product of the sizes of the dimensions inside it. The result is emitted
// as an i32 constant.
// Asserts that Indices and Dims have the same length and that every index
// is a ConstantInt. Arithmetic is in `unsigned`, so very large arrays would
// wrap silently.
// NOTE(review): the parameter list (source lines 119-120) and the cast
// producing CIndex (line 127) are missing from this copy.
118ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices(
121 "Indicies and dimmensions should be the same");
122 unsigned FlatIndex = 0;
123 unsigned Multiplier = 1;
124
125 for (int I = Indices.size() - 1; I >= 0; --I) {
126 unsigned DimSize = Dims[I];
128 assert(CIndex && "This function expects all indicies to be ConstantInt");
129 FlatIndex += CIndex->getZExtValue() * Multiplier;
130 Multiplier *= DimSize;
131 }
132 return Builder.getInt32(FlatIndex);
133}
134
// Emits IR computing a row-major flat index from (possibly non-constant)
// per-dimension indices: the sum over I of Indices[I] * (product of the
// sizes of the dimensions inside I). Multipliers are materialized as i32
// constants. A single index is returned unchanged, and no IR is emitted
// in that case.
// NOTE(review): the signature (source line 136) and the initial value of
// FlatIndex (line 140) are missing from this copy.
135Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices(
137 if (Indices.size() == 1)
138 return Indices[0];
139
141 unsigned Multiplier = 1;
142
143 for (int I = Indices.size() - 1; I >= 0; --I) {
144 unsigned DimSize = Dims[I];
145 Value *VMultiplier = Builder.getInt32(Multiplier);
146 Value *ScaledIndex = Builder.CreateMul(Indices[I], VMultiplier);
147 FlatIndex = Builder.CreateAdd(FlatIndex, ScaledIndex);
148 Multiplier *= DimSize;
149 }
150 return FlatIndex;
151}
152
// If an operand of the load is a GEP constant expression, rewrites the load
// to go through an equivalent GEP instruction (OldGEP), then flattens that
// GEP via visitGetElementPtrInst. Returns true if the load was rewritten,
// false if no operand was a GEP constant expression.
// NOTE(review): most of this body (source lines 154, 156-157, 160-161, 163,
// 165-168) is missing from this copy. It presumably converts the constant
// expression into an instruction and replaces/erases the original load,
// mirroring visitStoreInst.
153bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) {
155 for (unsigned I = 0; I < NumOperands; ++I) {
158 if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
159 GetElementPtrInst *OldGEP =
162
164 LoadInst *NewLoad =
169 visitGetElementPtrInst(*OldGEP);
170 return true;
171 }
172 }
173 return false;
174}
175
// If an operand of the store is a GEP constant expression, replaces the
// store with a new store of the same value through an equivalent GEP
// instruction (OldGEP). It then erases the original store and flattens
// OldGEP via visitGetElementPtrInst. Returns true if the store was
// rewritten, false otherwise.
// NOTE(review): source lines 180, 183-184, 186 and 188 are missing from this
// copy (the dyn_cast producing CE, creation of OldGEP and Builder, and
// presumably copying the original store's alignment).
176bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) {
177 unsigned NumOperands = SI.getNumOperands();
178 for (unsigned I = 0; I < NumOperands; ++I) {
179 Value *CurrOpperand = SI.getOperand(I);
181 if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
182 GetElementPtrInst *OldGEP =
185
 // NOTE(review): OldGEP always becomes the new store's *pointer* operand,
 // even when the GEP constant expression was found at operand index 0 (the
 // stored value). Storing a GEP-CE pointer value would therefore redirect
 // the store. Consider examining only SI.getPointerOperand().
187 StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
189 SI.replaceAllUsesWith(NewStore);
190 SI.eraseFromParent();
191 visitGetElementPtrInst(*OldGEP);
192 return true;
193 }
194 }
195 return false;
196}
197
// Replaces a multi-dimensional array alloca with an alloca of the equivalent
// one-dimensional array type, named "<original name>.1dim". Returns true
// when a flattened alloca was created, false for allocas that are left
// alone.
// NOTE(review): the qualifying check (source line 199) and the lines that
// obtain ArrType, set up Builder and replace/erase the old alloca (202-203,
// 209-211) are missing from this copy.
198bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
200 return false;
201
204 auto [TotalElements, BaseType] = getElementCountAndType(ArrType);
205
206 ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements);
207 AllocaInst *FlatAlloca =
208 Builder.CreateAlloca(FattenedArrayType, nullptr, AI.getName() + ".1dim");
212 return true;
213}
214
// Rewrites GEPs that index into flattened (1-D) arrays.
//
// - If the pointer operand is a GEP constant expression, it is converted to
//   a GEP instruction (OldGEPI). This GEP is recreated on top of it
//   (NewGEPI), the old GEP is erased, and both new GEPs are visited.
//   Returns true.
// - Otherwise a GEPInfo is built for this GEP: its constant and variable
//   byte offsets are collected with GEP.collectOffset. Chain data is then
//   either inherited from the parent GEP's GEPChainInfoMap entry (returning
//   false if the parent has no entry), or a new chain is started with this
//   GEP's pointer operand as the root.
// - When ReplaceThisGEP is set (at least whenever the GEP has no users), the
//   whole chain is emitted as one GEP of Info.RootFlattenedArrayType off
//   Info.RootPointerOperand, indexed by {ZeroIndex, FlattenedIndex}. Returns
//   true.
// - Otherwise the GEP is queued in PotentiallyDeadInstrs and false is
//   returned.
//
// NOTE(review): many lines of this function are missing from this copy,
// e.g. the early guard before source line 218, the user test at 308 and the
// index-scaling code around 333-340. This summary covers visible code only.
215bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
216
218 return false;
219
220 Value *PtrOperand = GEP.getPointerOperand();
221
222
223
225 "Pointer operand of GEP should not be a PHI Node");
226
227
228
230 PtrOpGEPCE && PtrOpGEPCE->getOpcode() == Instruction::GetElementPtr) {
231 GetElementPtrInst *OldGEPI =
234
238 Builder.CreateGEP(GEP.getSourceElementType(), OldGEPI, Indices,
239 GEP.getName(), GEP.getNoWrapFlags());
241 "Expected newly-created GEP to be an instruction");
243
244 GEP.replaceAllUsesWith(NewGEPI);
245 GEP.eraseFromParent();
246 visitGetElementPtrInst(*OldGEPI);
247 visitGetElementPtrInst(*NewGEPI);
248 return true;
249 }
250
251
 // Build this GEP's chain record: its own byte offsets plus (below) either
 // the parent GEP's accumulated offsets or a fresh root.
252 GEPInfo Info;
253
254
255 const DataLayout &DL = GEP.getDataLayout();
256 unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP.getType());
258 [[maybe_unused]] bool Success = GEP.collectOffset(
260 assert(Success && "Failed to collect offsets for GEP");
261
262
263
264
266
267
268
269
270 if (!GEPChainInfoMap.contains(PtrOpGEP))
271 return false;
272
273 GEPInfo &PGEPInfo = GEPChainInfoMap[PtrOpGEP];
274 Info.RootFlattenedArrayType = PGEPInfo.RootFlattenedArrayType;
275 Info.RootPointerOperand = PGEPInfo.RootPointerOperand;
276 for (auto &VariableOffset : PGEPInfo.VariableOffsets)
277 Info.VariableOffsets.insert(VariableOffset);
278 Info.ConstantOffset += PGEPInfo.ConstantOffset;
279 } else {
280 Info.RootPointerOperand = PtrOperand;
281
282
283
284
285 Type *RootTy = GEP.getSourceElementType();
287 if (GlobalMap.contains(GlobalVar))
290 RootTy = GlobalVar->getValueType();
292 RootTy = Alloca->getAllocatedType();
293 assert(!isMultiDimensionalArray(RootTy) &&
294 "Expected root array type to be flattened");
295
296
298 return false;
299
301 }
302
303
304
305
306 bool ReplaceThisGEP = GEP.users().empty();
307 for (Value *User : GEP.users())
309 ReplaceThisGEP = true;
310
 // Emit a single flat GEP off the chain root. Byte offsets collected above
 // are turned into element indices using BytesPerElem; variable multipliers
 // divisible by BytesPerElem are divided directly.
311 if (ReplaceThisGEP) {
312 unsigned BytesPerElem =
313 DL.getTypeAllocSize(Info.RootFlattenedArrayType->getArrayElementType());
315 "Bytes per element should be a power of 2");
316
317
318
321 uint64_t ConstantOffset =
323 assert(ConstantOffset < UINT32_MAX &&
324 "Constant byte offset for flat GEP index must fit within 32 bits");
325 Value *FlattenedIndex = Builder.getInt32(ConstantOffset);
326 for (auto [VarIndex, Multiplier] : Info.VariableOffsets) {
327 assert(Multiplier.getActiveBits() <= 32 &&
328 "The multiplier for a flat GEP index must fit within 32 bits");
329 assert(VarIndex->getType()->isIntegerTy(32) &&
330 "Expected i32-typed GEP indices");
332 if (Multiplier.getZExtValue() % BytesPerElem != 0) {
333
334
335
337 Builder.getInt32(Multiplier.getZExtValue()));
339 } else
341 VarIndex,
342 Builder.getInt32(Multiplier.getZExtValue() / BytesPerElem));
343 FlattenedIndex = Builder.CreateAdd(FlattenedIndex, VI);
344 }
345
346
348 Info.RootFlattenedArrayType, Info.RootPointerOperand,
349 {ZeroIndex, FlattenedIndex}, GEP.getName(), GEP.getNoWrapFlags());
350
351
352
353
354
357 Info.RootFlattenedArrayType, Info.RootPointerOperand,
358 {ZeroIndex, FlattenedIndex}, GEP.getNoWrapFlags(), GEP.getName(),
360
361
362
364 GEP.replaceAllUsesWith(NewGEP);
365 GEP.eraseFromParent();
366 return true;
367 }
368
369
370
371
 // Not replaced: later GEPs in the chain are expected to fold this one in
 // (the GEPChainInfoMap update on source line 372 is missing here), so queue
 // it in case it ends up dead.
373 PotentiallyDeadInstrs.emplace_back(&GEP);
374 return false;
375}
376
// Runs the visitor over function F, walking its basic blocks in reverse
// post-order, and reports whether any visit method changed the IR. Per-
// function state is reset via finish() before returning.
// NOTE(review): the loop body (source lines 380-382) is missing from this
// copy. Because visitors erase instructions, it presumably iterates each
// block with an early-increment range.
377bool DXILFlattenArraysVisitor::visit(Function &F) {
378 bool MadeChange = false;
379 ReversePostOrderTraversal<Function *> RPOT(&F);
383 }
384 finish();
385 return MadeChange;
386}
387
390
392 if (!ArrayTy) {
393 Elements.push_back(Init);
394 return;
395 }
396 unsigned ArrSize = ArrayTy->getNumElements();
398 for (unsigned I = 0; I < ArrSize; ++I)
400 return;
401 }
402
403
405 for (unsigned I = 0; I < ArrayConstant->getNumOperands(); ++I) {
407 }
409 for (unsigned I = 0; I < DataArrayConstant->getNumElements(); ++I) {
410 collectElements(DataArrayConstant->getElementAsConstant(I), Elements);
411 }
412 } else {
414 "Expected a ConstantArray or ConstantDataArray for array initializer!");
415 }
416}
417
421
424
425
428
431
434 assert(FlattenedType->getNumElements() == FlattenedElements.size() &&
435 "The number of collected elements should match the FlattenedType");
437}
438
443 Type *OrigType = G.getValueType();
444 if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType))
445 continue;
446
448 auto [TotalElements, BaseType] =
449 DXILFlattenArraysVisitor::getElementCountAndType(ArrType);
451
452
453
455 new GlobalVariable(M, FattenedArrayType, G.isConstant(), G.getLinkage(),
456 nullptr, G.getName() + ".1dim", &G,
457 G.getThreadLocalMode(), G.getAddressSpace(),
458 G.isExternallyInitialized());
459
460
462 if (G.getAlignment() > 0) {
464 }
465
466 if (G.hasInitializer()) {
471 }
472 GlobalMap[&G] = NewGlobal;
473 }
474}
475
477 bool MadeChange = false;
480 DXILFlattenArraysVisitor Impl(GlobalMap);
482 if (F.isDeclaration())
483 continue;
484 MadeChange |= Impl.visit(F);
485 }
486 for (auto &[Old, New] : GlobalMap) {
487 Old->replaceAllUsesWith(New);
488 Old->eraseFromParent();
489 MadeChange = true;
490 }
491 return MadeChange;
492}
493
496 if (!MadeChanges)
499 return PA;
500}
501
// Legacy pass-manager entry point for module M.
// NOTE(review): the body (source line 503) is missing from this copy;
// presumably it forwards to flattenArrays(M). Confirm upstream.
502bool DXILFlattenArraysLegacy::runOnModule(Module &M) {
504}
505
// Definition of the static pass identifier declared in DXILFlattenArraysLegacy.
506char DXILFlattenArraysLegacy::ID = 0;
507
509 "DXIL Array Flattener", false, false)
512
514 return new DXILFlattenArraysLegacy();
515}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Analysis containing CSE Info
static Constant * transformInitializer(Constant *Init, Type *OrigType, Type *NewType, LLVMContext &Ctx)
static void collectElements(Constant *Init, SmallVectorImpl< Constant * > &Elements)
Definition DXILFlattenArrays.cpp:388
static bool flattenArrays(Module &M)
Definition DXILFlattenArrays.cpp:476
static void flattenGlobalArrays(Module &M, SmallDenseMap< GlobalVariable *, GlobalVariable * > &GlobalMap)
Definition DXILFlattenArrays.cpp:439
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
BaseType
A given derived pointer can have multiple base pointers through phi/selects.
Class for arbitrary precision integers.
an instruction to allocate memory on the stack
Align getAlign() const
Return the alignment of the memory that is being allocated by the instruction.
Type * getAllocatedType() const
Return the type that is being allocated by the instruction.
void setAlignment(Align Align)
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
size_t size() const
size - Get the array size.
static LLVM_ABI ArrayType * get(Type *ElementType, uint64_t NumElements)
This static method is the primary way to construct an ArrayType.
This class represents a no-op cast from one type to another.
This class represents a function call, abstracting a target machine's calling convention.
This is the base class for all instructions that perform data casts.
static LLVM_ABI ConstantAggregateZero * get(Type *Ty)
static LLVM_ABI Constant * get(ArrayType *T, ArrayRef< Constant * > V)
This is the shared class of boolean and integer constants.
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
This is an important base class in LLVM.
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
PreservedAnalyses run(Module &M, ModuleAnalysisManager &)
Definition DXILFlattenArrays.cpp:494
This instruction compares its operands according to the predicate given to the constructor.
This class represents a freeze function that returns random concrete value if an operand is either a ...
an instruction for type-safe pointer arithmetic to access elements of arrays and structs
static GetElementPtrInst * Create(Type *PointeeType, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
void setUnnamedAddr(UnnamedAddr Val)
LLVM_ABI void setInitializer(Constant *InitVal)
setInitializer - Sets the initializer for this global variable, removing any existing initializer if ...
void setAlignment(Align Align)
Sets the alignment attribute of the GlobalVariable.
This instruction compares its operands according to the predicate given to the constructor.
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
BasicBlock::iterator GetInsertPoint() const
Value * CreateLShr(Value *LHS, Value *RHS, const Twine &Name="", bool isExact=false)
Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())
ConstantInt * getInt32(uint32_t C)
Get a constant 32-bit value.
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
This instruction inserts a single (scalar) element into a VectorType value.
Base class for instruction visitors.
void visit(Iterator Start, Iterator End)
LLVM_ABI void insertBefore(InstListType::iterator InsertPos)
Insert an unlinked instruction into a basic block immediately before the specified position.
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
This is an important class for using LLVM in a threaded context.
An instruction for reading from memory.
void setAlignment(Align Align)
Align getAlign() const
Return the alignment of the access that is being performed.
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
A Module instance is used to store all the information related to an LLVM module.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
This class represents the LLVM 'select' instruction.
This instruction constructs a fixed permutation of two input vectors.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
void setAlignment(Align Align)
The instances of the Type class are immutable: once they are created, they are never changed.
static LLVM_ABI UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Value * getOperand(unsigned i) const
unsigned getNumOperands() const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
self_iterator getIterator()
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ CE
Windows NT (Windows on ARM)
This is an optimization pass for GlobalISel generic memory operations.
FunctionAddr VTableAddr Value
decltype(auto) dyn_cast(const From &Val)
dyn_cast - Return the argument parameter cast to the specified type.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
ModulePass * createDXILFlattenArraysLegacyPass()
Pass to flatten arrays into a one dimensional DXIL legal form.
Definition DXILFlattenArrays.cpp:513
unsigned Log2_32(uint32_t Value)
Return the floor log base 2 of the specified value, -1 if the value is zero.
constexpr bool isPowerOf2_32(uint32_t Value)
Return true if the argument is a power of two > 0.
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa - Return true if the parameter to the template is an instance of one of the template type argu...
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructionsPermissive(SmallVectorImpl< WeakTrackingVH > &DeadInsts, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
Same functionality as RecursivelyDeleteTriviallyDeadInstructions, but allow instructions that are not...
constexpr unsigned BitWidth
decltype(auto) cast(const From &Val)
cast - Return the argument parameter cast to the specified type.
AnalysisManager< Module > ModuleAnalysisManager
Convenience typedef for the Module analysis manager.
A MapVector that performs no allocations if smaller than a certain size.