LLVM: lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
29
30using namespace llvm;
31
32namespace {
33
34class SPIRVMergeRegionExitTargets : public FunctionPass {
35public:
36 static char ID;
37
38 SPIRVMergeRegionExitTargets() : FunctionPass(ID) {}
39
40
41
42 std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) {
43 std::unordered_set<BasicBlock *> output;
45
47 output.insert(BI->getSuccessor(0));
48 if (BI->isConditional())
49 output.insert(BI->getSuccessor(1));
50 return output;
51 }
52
54 output.insert(SI->getDefaultDest());
55 for (auto &Case : SI->cases())
56 output.insert(Case.getCaseSuccessor());
57 return output;
58 }
59
61 return output;
62 }
63
64
65
67 BasicBlock *BB,
68 const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
71 return nullptr;
72
74 Builder.SetInsertPoint(T);
75
77
78 BasicBlock *LHSTarget = BI->getSuccessor(0);
80 BI->isConditional() ? BI->getSuccessor(1) : nullptr;
81
84
85 if (LHS == nullptr || RHS == nullptr)
86 return LHS == nullptr ? RHS : LHS;
87 return Builder.CreateSelect(BI->getCondition(), LHS, RHS);
88 }
89
90
92 }
93
94
96 const SmallPtrSet<BasicBlock *, 4> &ToReplace,
97 BasicBlock *NewTarget) {
100 return;
101
103 for (size_t i = 0; i < BI->getNumSuccessors(); i++) {
104 if (ToReplace.count(BI->getSuccessor(i)) != 0)
105 BI->setSuccessor(i, NewTarget);
106 }
107 return;
108 }
109
111 for (size_t i = 0; i < SI->getNumSuccessors(); i++) {
112 if (ToReplace.count(SI->getSuccessor(i)) != 0)
113 SI->setSuccessor(i, NewTarget);
114 }
115 return;
116 }
117
118 assert(false && "Unhandled terminator type.");
119 }
120
121 AllocaInst *CreateVariable(Function &F, Type *Type,
123 const DataLayout &DL = F.getDataLayout();
124 return new AllocaInst(Type, DL.getAllocaAddrSpace(), nullptr, "reg",
125 Position);
126 }
127
128
129
130 bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
131 SPIRV::ConvergenceRegion *CR) {
132
133 SmallPtrSet<BasicBlock *, 4> ExitTargets;
134 for (BasicBlock *Exit : CR->Exits) {
135 for (BasicBlock *Target : gatherSuccessors(Exit)) {
137 ExitTargets.insert(Target);
138 }
139 }
140
141
142 if (ExitTargets.size() <= 1)
143 return false;
144
145
149
150 AllocaInst *Variable = CreateVariable(*F, Builder.getInt32Ty(),
151 F->begin()->getFirstInsertionPt());
152
153
154
155
156 std::vector<BasicBlock *> SortedExitTargets;
157 std::vector<BasicBlock *> SortedExits;
158 for (BasicBlock &BB : *F) {
159 if (ExitTargets.count(&BB) != 0)
160 SortedExitTargets.push_back(&BB);
161 if (CR->Exits.count(&BB) != 0)
162 SortedExits.push_back(&BB);
163 }
164
165
166
167 DenseMap<BasicBlock *, ConstantInt *> TargetToValue;
168 for (BasicBlock *Target : SortedExitTargets)
169 TargetToValue.insert(
170 std::make_pair(Target, Builder.getInt32(TargetToValue.size())));
171
172
173
174 std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
175 for (auto Exit : SortedExits) {
176 llvm::Value *Value = createExitVariable(Exit, TargetToValue);
178 B2.SetInsertPoint(Exit->getFirstInsertionPt());
179 B2.CreateStore(Value, Variable);
180 ExitToVariable.emplace_back(std::make_pair(Exit, Value));
181 }
182
183 llvm::Value *Load = Builder.CreateLoad(Builder.getInt32Ty(), Variable);
184
185
186 llvm::SwitchInst *Sw = Builder.CreateSwitch(Load, SortedExitTargets[0],
187 SortedExitTargets.size() - 1);
188 for (size_t i = 1; i < SortedExitTargets.size(); i++) {
189 BasicBlock *BB = SortedExitTargets[i];
190 Sw->addCase(TargetToValue[BB], BB);
191 }
192
193
194 for (auto Exit : CR->Exits)
196
198 while (CR) {
201 }
202
203 return true;
204 }
205
206
207
208
209 bool runOnConvergenceRegion(LoopInfo &LI, SPIRV::ConvergenceRegion *CR) {
210 for (auto *Child : CR->Children)
211 if (runOnConvergenceRegion(LI, Child))
212 return true;
213
214 return runOnConvergenceRegionNoRecurse(LI, CR);
215 }
216
217#if !NDEBUG
218
219
220 void validateRegionExits(const SPIRV::ConvergenceRegion *CR) {
221 for (auto *Child : CR->Children)
222 validateRegionExits(Child);
223
224 std::unordered_set<BasicBlock *> ExitTargets;
225 for (auto *Exit : CR->Exits) {
226 auto Set = gatherSuccessors(Exit);
227 for (auto *BB : Set) {
229 ExitTargets.insert(BB);
230 }
231 }
232
233 assert(ExitTargets.size() <= 1);
234 }
235#endif
236
238 LoopInfo &LI = getAnalysis().getLoopInfo();
239 auto *TopLevelRegion =
240 getAnalysis()
241 .getRegionInfo()
242 .getWritableTopLevelRegion();
243
244
245
246
247
248 bool modified = false;
249 while (runOnConvergenceRegion(LI, TopLevelRegion)) {
250 modified = true;
251 }
252
253#if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS)
254 validateRegionExits(TopLevelRegion);
255#endif
256 return modified;
257 }
258
259 void getAnalysisUsage(AnalysisUsage &AU) const override {
260 AU.addRequired();
262 AU.addRequired();
263
264 AU.addPreserved();
265 FunctionPass::getAnalysisUsage(AU);
266 }
267};
268}
269
// Storage for the class's static `ID` member, which the constructor passes to
// FunctionPass.
270char SPIRVMergeRegionExitTargets::ID = 0;
271
// NOTE(review): this appears to be the tail of the pass-registration macro
// invocation (name, cfg, analysis arguments). Its opening line(s) and the
// remaining registration lines are missing from this copy.
273 "SPIRV split region exit blocks", false, false)
278
281
// NOTE(review): body of the pass factory. Its signature line
// (createSPIRVMergeRegionExitTargetsPass()) is missing from this copy.
// It returns a freshly allocated instance of the pass.
283 return new SPIRVMergeRegionExitTargets();
284}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
This file defines the DenseMap class.
static bool runOnFunction(Function &F, bool PostInlining)
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
static void replaceBranchTargets(BasicBlock *BB, BasicBlock *OldTarget, BasicBlock *NewTarget)
This file defines the SmallPtrSet class.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
const Function * getParent() const
Return the enclosing method, or null if none.
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
InstListType::iterator iterator
Instruction iterators...
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Legacy analysis pass which computes a DominatorTree.
FunctionPass class - This class is used to implement most global optimizations.
The legacy pass manager's analysis pass to compute loop information.
SmallVector< ConvergenceRegion * > Children
SmallPtrSet< BasicBlock *, 2 > Exits
ConvergenceRegion * Parent
SmallPtrSet< BasicBlock *, 8 > Blocks
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
LLVM_ABI void addCase(ConstantInt *OnVal, BasicBlock *Dest)
Add an entry to the switch instruction.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
@ BasicBlock
Various leaf nodes.
This is an optimization pass for GlobalISel generic memory operations.
FunctionAddr VTableAddr Value
decltype(auto) dyn_cast(const From &Val)
dyn_cast - Return the argument parameter cast to the specified type.
bool isa(const From &Val)
isa - Return true if the parameter to the template is an instance of one of the template type argu...
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
FunctionPass * createSPIRVMergeRegionExitTargetsPass()
Definition SPIRVMergeRegionExitTargets.cpp:282