LLVM 23.0.0git
SPIRVMergeRegionExitTargets.cpp
Go to the documentation of this file.
1//===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Merge the multiple exit targets of a convergence region into a single block.
10// Each exit target will be assigned a constant value, and a phi node + switch
11// will allow the new exit target to re-route to the correct basic block.
12//
13//===----------------------------------------------------------------------===//
14
17#include "SPIRV.h"
18#include "SPIRVSubtarget.h"
19#include "SPIRVUtils.h"
20#include "llvm/ADT/DenseMap.h"
23#include "llvm/IR/Dominators.h"
24#include "llvm/IR/IRBuilder.h"
25#include "llvm/IR/Intrinsics.h"
30
31using namespace llvm;
32
33namespace {
34
35/// Create a value in BB set to the value associated with the branch the block
36/// terminator will take.
37static llvm::Value *
38createExitVariable(BasicBlock *BB,
39 const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
40 auto *T = BB->getTerminator();
41 if (isa<ReturnInst>(T))
42 return nullptr;
43 if (auto *BI = dyn_cast<UncondBrInst>(T))
44 return TargetToValue.lookup(BI->getSuccessor());
45
46 IRBuilder<> Builder(BB);
47 Builder.SetInsertPoint(T);
48
49 if (auto *BI = dyn_cast<CondBrInst>(T)) {
50 Value *LHS = TargetToValue.lookup(BI->getSuccessor(0));
51 Value *RHS = TargetToValue.lookup(BI->getSuccessor(1));
52
53 if (LHS == nullptr || RHS == nullptr)
54 return LHS == nullptr ? RHS : LHS;
55 return Builder.CreateSelect(BI->getCondition(), LHS, RHS);
56 }
57
58 // TODO: add support for switch cases.
59 llvm_unreachable("Unhandled terminator type.");
60}
61
62static AllocaInst *createVariable(Function &F, Type *Type,
63 BasicBlock::iterator Position) {
64 const DataLayout &DL = F.getDataLayout();
65 return new AllocaInst(Type, DL.getAllocaAddrSpace(), nullptr, "reg",
66 Position);
67}
68
69// Run the pass on the given convergence region, ignoring the sub-regions.
70// Returns true if the CFG changed, false otherwise.
//
// When the region has two or more distinct exit targets, a block named
// "new.exit" is created. Each exit block stores a selector into a stack slot,
// the exit edges are redirected to the new block, and the new block loads the
// slot and switches to the original target.
//
// NOTE(review): the rest of this signature and the declarations of
// ExitTargets and TargetToValue are not present in this listing. CR is
// presumably the SPIRV::ConvergenceRegion being processed -- confirm against
// the original file.
71static bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
73 // Gather all the exit targets for this region.
75 for (BasicBlock *Exit : CR->Exits) {
76 for (BasicBlock *Target : successors(Exit)) {
77 if (CR->Blocks.count(Target) == 0)
78 ExitTargets.insert(Target);
79 }
80 }
81
82 // If we have zero or one exit target, nothing to do.
83 if (ExitTargets.size() <= 1)
84 return false;
85
86 // Create the new single exit target.
87 auto F = CR->Entry->getParent();
88 auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F);
89 IRBuilder<> Builder(NewExitTarget);
90
 // The i32 slot holding the selector is allocated at the first insertion
 // point of the function's first block.
91 AllocaInst *Variable = createVariable(*F, Builder.getInt32Ty(),
92 F->begin()->getFirstInsertionPt());
93
94 // CodeGen output needs to be stable. Using the set as-is would order
95 // the targets differently depending on the allocation pattern.
96 // Sorting per basic-block ordering in the function.
97 std::vector<BasicBlock *> SortedExitTargets;
98 std::vector<BasicBlock *> SortedExits;
99 for (BasicBlock &BB : *F) {
100 if (ExitTargets.count(&BB) != 0)
101 SortedExitTargets.push_back(&BB);
102 if (CR->Exits.count(&BB) != 0)
103 SortedExits.push_back(&BB);
104 }
105
106 // Creating one constant per distinct exit target. This will be routed to the
107 // correct target.
 // Constants are numbered 0, 1, 2, ... following SortedExitTargets order.
109 for (BasicBlock *Target : SortedExitTargets)
110 TargetToValue.insert(
111 std::make_pair(Target, Builder.getInt32(TargetToValue.size())));
112
113 // Creating one variable per exit node, set to the constant matching the
114 // targeted external block.
 // NOTE(review): createExitVariable returns nullptr for a return-terminated
 // block (and when no successor is in TargetToValue); the result is passed to
 // CreateStore unchecked -- confirm every exit here branches to a mapped
 // target.
 // NOTE(review): the store is placed at the block's first insertion point,
 // while a select built by createExitVariable is inserted before the
 // terminator -- confirm the store cannot end up ahead of the value it stores.
 // NOTE(review): ExitToVariable is filled below but not read in the visible
 // part of this function.
115 std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
116 for (auto Exit : SortedExits) {
117 llvm::Value *Value = createExitVariable(Exit, TargetToValue);
118 IRBuilder<> B2(Exit);
119 B2.SetInsertPoint(Exit->getFirstInsertionPt());
120 B2.CreateStore(Value, Variable);
121 ExitToVariable.emplace_back(std::make_pair(Exit, Value));
122 }
123
124 llvm::Value *Load = Builder.CreateLoad(Builder.getInt32Ty(), Variable);
125
126 // Creating the switch to jump to the correct exit target.
 // The first sorted target (constant 0) is the switch default; every other
 // target gets an explicit case.
127 llvm::SwitchInst *Sw = Builder.CreateSwitch(Load, SortedExitTargets[0],
128 SortedExitTargets.size() - 1);
129 for (size_t i = 1; i < SortedExitTargets.size(); i++) {
130 BasicBlock *BB = SortedExitTargets[i];
131 Sw->addCase(TargetToValue[BB], BB);
132 }
133
134 // Fix exit branches to redirect to the new exit.
135 for (auto Exit : CR->Exits) {
136 Instruction *T = Exit->getTerminator();
137 for (auto I = succ_begin(T), E = succ_end(T); I != E; ++I)
138 if (ExitTargets.contains(*I))
139 I.getUse()->set(NewExitTarget);
140 }
141
 // Register the new block with every ancestor region. The walk starts at the
 // parent, so the Blocks set of the region being processed is not updated.
142 CR = CR->Parent;
143 while (CR) {
144 CR->Blocks.insert(NewExitTarget);
145 CR = CR->Parent;
146 }
147
148 return true;
149}
150
151/// Run the pass on the given convergence region and sub-regions (DFS).
152/// Returns true if a region/sub-region was modified, false otherwise.
153/// This returns as soon as one region/sub-region has been modified.
154static bool runOnConvergenceRegion(LoopInfo &LI, SPIRV::ConvergenceRegion *CR) {
155 for (auto *Child : CR->Children)
156 if (runOnConvergenceRegion(LI, Child))
157 return true;
158
159 return runOnConvergenceRegionNoRecurse(LI, CR);
160}
161
162#if !NDEBUG
163/// Validates each edge exiting the region has the same destination basic
164/// block.
165static void validateRegionExits(const SPIRV::ConvergenceRegion *CR) {
166 for (auto *Child : CR->Children)
167 validateRegionExits(Child);
168
169 std::unordered_set<BasicBlock *> ExitTargets;
170 for (auto *Exit : CR->Exits) {
171 for (auto *BB : successors(Exit)) {
172 if (CR->Blocks.count(BB) == 0)
173 ExitTargets.insert(BB);
174 }
175 }
176
177 assert(ExitTargets.size() <= 1);
178}
179#endif
180
/// Apply runOnConvergenceRegion to the top-level region until it reports no
/// further change. Returns true if at least one iteration reported a change.
/// NOTE(review): the second line of this signature (presumably declaring the
/// RegionInfo parameter used below) is not present in this listing.
181static bool runImpl(Function &F, LoopInfo &LI,
183 auto *TopLevelRegion = RegionInfo.getWritableTopLevelRegion();
184
185 // FIXME: very inefficient method: each time a region is modified, we bubble
186 // back up, and recompute the whole convergence region tree. Once the
187 // algorithm is completed and test coverage good enough, rewrite this pass
188 // to be efficient instead of simple.
189 bool Modified = false;
 // runOnConvergenceRegion returns after the first modified region, so keep
 // calling it until a full traversal changes nothing.
190 while (runOnConvergenceRegion(LI, TopLevelRegion)) {
191 Modified = true;
192 }
193
 // Sanity check of the final region tree, only in checked builds.
194#if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS)
195 validateRegionExits(TopLevelRegion);
196#endif
197 return Modified;
198}
199
200class SPIRVMergeRegionExitTargetsLegacy : public FunctionPass {
201public:
202 static char ID;
203
204 SPIRVMergeRegionExitTargetsLegacy() : FunctionPass(ID) {}
205
206 bool runOnFunction(Function &F) override {
207 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
208 auto &RegionInfo = getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
209 .getRegionInfo();
210 return runImpl(F, LI, RegionInfo);
211 }
212
213 void getAnalysisUsage(AnalysisUsage &AU) const override {
214 AU.addRequired<DominatorTreeWrapperPass>();
215 AU.addRequired<LoopInfoWrapperPass>();
216 AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
217
218 AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
219 FunctionPass::getAnalysisUsage(AU);
220 }
221};
222} // namespace
223
231
// Out-of-class definition of the static pass identifier.
232char SPIRVMergeRegionExitTargetsLegacy::ID = 0;
233
// Legacy pass registration.
// NOTE(review): the argument string and description say "split region exit
// blocks" although this pass merges region exit targets; this looks copied
// from another pass. The string is presumably the externally visible pass
// name, so confirm nothing depends on it before renaming.
234INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargetsLegacy,
235 "split-region-exit-blocks",
236 "SPIRV split region exit blocks", false, false)
237INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
241
242INITIALIZE_PASS_END(SPIRVMergeRegionExitTargetsLegacy,
243 "split-region-exit-blocks",
244 "SPIRV split region exit blocks", false, false)
245
// Body of the factory function returning a new instance of the legacy pass.
// NOTE(review): the signature line is not present in this listing.
247 return new SPIRVMergeRegionExitTargetsLegacy();
248}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
This file defines the DenseMap class.
static bool runOnFunction(Function &F, bool PostInlining)
static bool runImpl(Function &F, const TargetLowering &TLI, const LibcallLoweringInfo &Libcalls, AssumptionCache *AC)
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
#define T
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
This file defines the SmallPtrSet class.
Value * RHS
Value * LHS
an instruction to allocate memory on the stack
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
LLVM Basic Block Representation.
Definition BasicBlock.h:62
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition BasicBlock.h:206
InstListType::iterator iterator
Instruction iterators...
Definition BasicBlock.h:170
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction; assumes that the block is well-formed.
Definition BasicBlock.h:237
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:64
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exists.
Definition DenseMap.h:205
unsigned size() const
Definition DenseMap.h:110
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:241
Legacy analysis pass which computes a DominatorTree.
Definition Dominators.h:316
FunctionPass class - This class is used to implement most global optimizations.
Definition Pass.h:314
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition IRBuilder.h:2847
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:569
The legacy pass manager's analysis pass to compute loop information.
Definition LoopInfo.h:596
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
Definition Analysis.h:115
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
SmallVector< ConvergenceRegion * > Children
size_type size() const
Definition SmallPtrSet.h:99
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Multiway switch.
LLVM_ABI void addCase(ConstantInt *OnVal, BasicBlock *Dest)
Add an entry to the switch instruction.
Target - Wrapper for Target specific information.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:46
LLVM Value Representation.
Definition Value.h:75
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
This is an optimization pass for GlobalISel generic memory operations.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
auto successors(const MachineBasicBlock *BB)
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type arguments.
Definition Casting.h:547
RNSuccIterator< NodeRef, BlockT, RegionT > succ_begin(NodeRef Node)
RNSuccIterator< NodeRef, BlockT, RegionT > succ_end(NodeRef Node)
FunctionPass * createSPIRVMergeRegionExitTargetsPass()
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.