LLVM  10.0.0svn
AMDGPUUnifyDivergentExitNodes.cpp
Go to the documentation of this file.
1 //===- AMDGPUUnifyDivergentExitNodes.cpp ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a variant of the UnifyDivergentExitNodes pass. Rather than ensuring
10 // there is at most one ret and one unreachable instruction, it ensures there is
11 // at most one divergent exiting block.
12 //
13 // StructurizeCFG can't deal with multi-exit regions formed by branches to
14 // multiple return nodes. It is not desirable to structurize regions with
15 // uniform branches, so unifying those to the same return block as divergent
16 // branches inhibits use of scalar branching. It still can't deal with the case
17 // where one branch goes to return, and one unreachable. Replace unreachable in
18 // this case with a return.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "AMDGPU.h"
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/SmallPtrSet.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringRef.h"
31 #include "llvm/IR/BasicBlock.h"
32 #include "llvm/IR/CFG.h"
33 #include "llvm/IR/Constants.h"
34 #include "llvm/IR/Function.h"
35 #include "llvm/IR/InstrTypes.h"
36 #include "llvm/IR/Instructions.h"
37 #include "llvm/IR/Intrinsics.h"
38 #include "llvm/IR/Type.h"
39 #include "llvm/Pass.h"
40 #include "llvm/Support/Casting.h"
41 #include "llvm/Transforms/Scalar.h"
42 #include "llvm/Transforms/Utils.h"
43 
44 using namespace llvm;
45 
46 #define DEBUG_TYPE "amdgpu-unify-divergent-exit-nodes"
47 
48 namespace {
49 
// FunctionPass subclass implementing the divergent-exit unification described
// in the file header. Only the interface is declared here; getAnalysisUsage()
// and runOnFunction() are defined out of line later in this file.
// NOTE(review): this listing carries embedded doxygen line numbers and skips
// embedded line 55, so the constructor body shown below may be incomplete -
// confirm against the upstream source.
50 class AMDGPUUnifyDivergentExitNodes : public FunctionPass {
51 public:
52  static char ID; // Pass identification, replacement for typeid
53 
54  AMDGPUUnifyDivergentExitNodes() : FunctionPass(ID) {
56  }
57 
58  // Declares the analyses this pass requires and the properties it preserves
58  // (among them the non-critical-edge property).
59  void getAnalysisUsage(AnalysisUsage &AU) const override;
  // Entry point invoked per function; the bool result is the value produced
  // by the out-of-line definition below.
60  bool runOnFunction(Function &F) override;
61 };
62 
63 } // end anonymous namespace
64 
66 
68 
// Pass registration. Both macros receive the pass class, DEBUG_TYPE as the
// second argument, and the human-readable description string.
// NOTE(review): embedded lines 71-72 are absent from this listing, so anything
// declared between BEGIN and END (e.g. pass dependencies) is not visible here
// - confirm against the upstream source.
69 INITIALIZE_PASS_BEGIN(AMDGPUUnifyDivergentExitNodes, DEBUG_TYPE,
70  "Unify divergent function exit nodes", false, false)
73 INITIALIZE_PASS_END(AMDGPUUnifyDivergentExitNodes, DEBUG_TYPE,
74  "Unify divergent function exit nodes", false, false)
75 
// Declares the analyses runOnFunction() depends on and the ones it leaves
// valid.
//   Required:  PostDominatorTreeWrapperPass, LegacyDivergenceAnalysis,
//              TargetTransformInfoWrapperPass.
//   Preserved: LegacyDivergenceAnalysis, BreakCriticalEdgesID, LowerSwitchID.
// NOTE(review): embedded line 90 is absent from this listing; a statement may
// be missing between the LowerSwitchID line and the TTI requirement.
76 void AMDGPUUnifyDivergentExitNodes::getAnalysisUsage(AnalysisUsage &AU) const{
77  // TODO: Preserve dominator tree.
  // The post-dominator tree roots are what runOnFunction() iterates over.
78  AU.addRequired<PostDominatorTreeWrapperPass>();
79 
  // Divergence information drives the isUniformlyReached() queries.
80  AU.addRequired<LegacyDivergenceAnalysis>();
81 
82  // No divergent values are changed, only blocks and branch edges.
83  AU.addPreserved<LegacyDivergenceAnalysis>();
84 
85  // We preserve the non-critical-edgeness property
86  AU.addPreservedID(BreakCriticalEdgesID);
87 
88  // This is a cluster of orthogonal Transforms
89  AU.addPreservedID(LowerSwitchID);
91 
  // runOnFunction() fetches TTI and hands it to unifyReturnBlockSet().
92  AU.addRequired<TargetTransformInfoWrapperPass>();
93 }
94 
95 /// \returns true if \p BB is reachable through only uniform branches.
95 /// Walks the predecessors of \p BB transitively with an explicit worklist and
95 /// gives up as soon as one of them has a terminator that the divergence
95 /// analysis does not report as uniform.
96 /// XXX - Is there a more efficient way to find this?
// NOTE(review): the first line of the signature (embedded line 97) and the
// declarations of `Stack` and `Visited` (embedded lines 99-100) are absent
// from this listing - confirm against the upstream source.
98  BasicBlock &BB) {
101 
  // Seed the worklist with the direct predecessors of BB.
102  for (BasicBlock *Pred : predecessors(&BB))
103  Stack.push_back(Pred);
104 
105  while (!Stack.empty()) {
106  BasicBlock *Top = Stack.pop_back_val();
  // A single non-uniform terminator on any path into BB is enough to fail.
107  if (!DA.isUniform(Top->getTerminator()))
108  return false;
109 
  // Only enqueue predecessors that were not already recorded in Visited, so
  // cycles in the CFG do not make the walk loop forever.
110  for (BasicBlock *Pred : predecessors(Top)) {
111  if (Visited.insert(Pred).second)
112  Stack.push_back(Pred);
113  }
114  }
115 
  // Every block examined ended in a uniform terminator.
116  return true;
117 }
118 
// Creates one new block named `Name` in F that performs the return, and
// redirects every block in ReturningBlocks to it with an unconditional branch.
// For non-void functions a PHI node in the new block merges the individual
// return values. Returns the newly created block.
// NOTE(review): the first line of the signature (embedded line 119) is absent
// from this listing - confirm against the upstream source.
120  ArrayRef<BasicBlock *> ReturningBlocks,
121  const TargetTransformInfo &TTI,
122  StringRef Name) {
123  // Insert a new basic block into the function, add a PHI node (if the
124  // function returns a value), and convert all of the return
125  // instructions into unconditional branches.
126  BasicBlock *NewRetBlock = BasicBlock::Create(F.getContext(), Name, &F);
127 
  // PN stays null for void functions; the loop below keys off that.
128  PHINode *PN = nullptr;
129  if (F.getReturnType()->isVoidTy()) {
130  ReturnInst::Create(F.getContext(), nullptr, NewRetBlock);
131  } else {
132  // If the function doesn't return void... add a PHI node to the block...
133  PN = PHINode::Create(F.getReturnType(), ReturningBlocks.size(),
134  "UnifiedRetVal");
135  NewRetBlock->getInstList().push_back(PN);
136  ReturnInst::Create(F.getContext(), PN, NewRetBlock);
137  }
138 
139  // Loop over all of the blocks, replacing the return instruction with an
140  // unconditional branch.
141  for (BasicBlock *BB : ReturningBlocks) {
142  // Add an incoming element to the PHI node for every return instruction that
143  // is merging into this new block...
  // Operand 0 of the terminator is read before the terminator is erased.
144  if (PN)
145  PN->addIncoming(BB->getTerminator()->getOperand(0), BB);
146 
147  // Remove and delete the return inst.
148  BB->getTerminator()->eraseFromParent();
149  BranchInst::Create(NewRetBlock, BB);
150  }
151 
152  for (BasicBlock *BB : ReturningBlocks) {
153  // Cleanup possible branch to unconditional branch to the return.
  // NOTE(review): `{2}` initializes the SimplifyCFGOptions argument; the
  // meaning of the value 2 is not visible in this file - confirm.
154  simplifyCFG(BB, TTI, {2});
155  }
156 
157  return NewRetBlock;
158 }
159 
// Pass body. Collects the post-dominator-tree roots that are not uniformly
// reached, gives branch-terminated roots a dummy return edge, folds multiple
// unreachable blocks into one, turns the unreachable block into a return when
// returns also exist, and finally merges all collected returning blocks via
// unifyReturnBlockSet().
// NOTE(review): the function signature (embedded line 160), the declaration of
// `BoolTrue` (embedded line 184) and embedded line 205 are absent from this
// listing - confirm against the upstream source.
// NOTE(review): the two `return false` statements near the end are reachable
// after the IR has already been modified above (dummy return block and
// rewritten branches, or a newly created unified unreachable block). A pass is
// expected to report true when it changed the function - verify and consider
// tracking a `Changed` flag.
161  auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
  // With at most one root there is nothing to unify.
162  if (PDT.getRoots().size() <= 1)
163  return false;
164 
165  LegacyDivergenceAnalysis &DA = getAnalysis<LegacyDivergenceAnalysis>();
166 
167  // Walk the post-dominator tree roots, tracking the ones that return or end
168  // in unreachable and are not uniformly reached.
169  SmallVector<BasicBlock *, 4> ReturningBlocks;
170  SmallVector<BasicBlock *, 4> UnreachableBlocks;
171 
172  // Dummy return block for infinite loop.
173  BasicBlock *DummyReturnBB = nullptr;
174 
175  for (BasicBlock *BB : PDT.getRoots()) {
176  if (isa<ReturnInst>(BB->getTerminator())) {
  // Uniformly reached returns are deliberately left alone.
177  if (!isUniformlyReached(DA, *BB))
178  ReturningBlocks.push_back(BB);
179  } else if (isa<UnreachableInst>(BB->getTerminator())) {
180  if (!isUniformlyReached(DA, *BB))
181  UnreachableBlocks.push_back(BB);
182  } else if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator())) {
183 
  // Lazily create a single shared dummy return block; it returns undef for
  // non-void functions and is itself recorded as a returning block.
185  if (DummyReturnBB == nullptr) {
186  DummyReturnBB = BasicBlock::Create(F.getContext(),
187  "DummyReturnBlock", &F);
188  Type *RetTy = F.getReturnType();
189  Value *RetVal = RetTy->isVoidTy() ? nullptr : UndefValue::get(RetTy);
190  ReturnInst::Create(F.getContext(), RetVal, DummyReturnBB);
191  ReturningBlocks.push_back(DummyReturnBB);
192  }
193 
194  if (BI->isUnconditional()) {
195  BasicBlock *LoopHeaderBB = BI->getSuccessor(0);
196  BI->eraseFromParent(); // Delete the unconditional branch.
197  // Add a new conditional branch with a dummy edge to the return block.
198  BranchInst::Create(LoopHeaderBB, DummyReturnBB, BoolTrue, BB);
199  } else { // Conditional branch.
200  // Create a new transition block to hold the conditional branch.
201  BasicBlock *TransitionBB = BB->splitBasicBlock(BI, "TransitionBlock");
202 
203  // Create a branch that will always branch to the transition block and
204  // references DummyReturnBB.
206  BranchInst::Create(TransitionBB, DummyReturnBB, BoolTrue, BB);
207  }
208  }
209  }
210 
211  if (!UnreachableBlocks.empty()) {
212  BasicBlock *UnreachableBlock = nullptr;
213 
  // A lone unreachable block is reused as is; several are funneled into one
  // newly created block that holds the only unreachable instruction.
214  if (UnreachableBlocks.size() == 1) {
215  UnreachableBlock = UnreachableBlocks.front();
216  } else {
217  UnreachableBlock = BasicBlock::Create(F.getContext(),
218  "UnifiedUnreachableBlock", &F);
219  new UnreachableInst(F.getContext(), UnreachableBlock);
220 
221  for (BasicBlock *BB : UnreachableBlocks) {
222  // Remove and delete the unreachable inst.
223  BB->getTerminator()->eraseFromParent();
224  BranchInst::Create(UnreachableBlock, BB);
225  }
226  }
227 
228  if (!ReturningBlocks.empty()) {
229  // Don't create a new unreachable inst if we have a return. The
230  // structurizer/annotator can't handle the multiple exits
231 
232  Type *RetTy = F.getReturnType();
233  Value *RetVal = RetTy->isVoidTy() ? nullptr : UndefValue::get(RetTy);
234  // Remove and delete the unreachable inst.
235  UnreachableBlock->getTerminator()->eraseFromParent();
236 
237  Function *UnreachableIntrin =
238  Intrinsic::getDeclaration(F.getParent(), Intrinsic::amdgcn_unreachable);
239 
240  // Insert a call to an intrinsic tracking that this is an unreachable
241  // point, in case we want to kill the active lanes or something later.
242  CallInst::Create(UnreachableIntrin, {}, "", UnreachableBlock);
243 
244  // Don't create a scalar trap. We would only want to trap if this code was
245  // really reached, but a scalar trap would happen even if no lanes
246  // actually reached here.
247  ReturnInst::Create(F.getContext(), RetVal, UnreachableBlock);
  // The former unreachable block now returns, so it joins the set to unify.
248  ReturningBlocks.push_back(UnreachableBlock);
249  }
250  }
251 
252  // Now handle return blocks.
253  if (ReturningBlocks.empty())
254  return false; // No blocks return
255 
256  if (ReturningBlocks.size() == 1)
257  return false; // Already has a single return block
258 
259  const TargetTransformInfo &TTI
260  = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
261 
  // The returned block pointer is not needed here.
262  unifyReturnBlockSet(F, ReturningBlocks, TTI, "UnifiedReturnBlock");
263  return true;
264 }
SymbolTableList< Instruction >::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:67
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
This class represents lattice values for constants.
Definition: AllocatorList.h:23
amdgpu Simplify well known AMD library false FunctionCallee Value const Twine & Name
static CallInst * Create(FunctionType *Ty, Value *F, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
F(f)
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.cpp:144
static ReturnInst * Create(LLVMContext &C, Value *retVal=nullptr, Instruction *InsertBefore=nullptr)
static bool isUniformlyReached(const LegacyDivergenceAnalysis &DA, BasicBlock &BB)
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:50
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
Definition: Pass.cpp:96
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Definition: APInt.h:32
Function * getDeclaration(Module *M, ID id, ArrayRef< Type *> Tys=None)
Create or insert an LLVM Function declaration for an intrinsic, and return it.
Definition: Function.cpp:1093
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:141
static bool runOnFunction(Function &F, bool PostInlining)
Type * getReturnType() const
Returns the type of the ret val.
Definition: Function.h:168
Wrapper pass for TargetTransformInfo.
LLVM Basic Block Representation.
Definition: BasicBlock.h:57
The instances of the Type class are immutable: once they are created, they are never changed...
Definition: Type.h:46
Conditional or Unconditional Branch instruction.
char & BreakCriticalEdgesID
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:148
This function has undefined behavior.
This file contains the declarations for the subclasses of Constant, which represent the different fla...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:370
Represent the analysis usage information of a pass.
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:284
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:99
LLVMContext & getContext() const
getContext - Return a reference to the LLVMContext associated with this function. ...
Definition: Function.cpp:205
char & LowerSwitchID
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1446
size_t size() const
Definition: SmallVector.h:52
INITIALIZE_PASS_END(RegBankSelect, DEBUG_TYPE, "Assign register bank of generic virtual registers", false, false) RegBankSelect
const InstListType & getInstList() const
Return the underlying instruction list container.
Definition: BasicBlock.h:338
INITIALIZE_PASS_BEGIN(AMDGPUUnifyDivergentExitNodes, DEBUG_TYPE, "Unify divergent function exit nodes", false, false) INITIALIZE_PASS_END(AMDGPUUnifyDivergentExitNodes
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements...
Definition: SmallPtrSet.h:417
This is the shared class of boolean and integer constants.
Definition: Constants.h:83
char & AMDGPUUnifyDivergentExitNodesID
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small...
Definition: SmallVector.h:837
LLVM_NODISCARD T pop_back_val()
Definition: SmallVector.h:374
static BranchInst * Create(BasicBlock *IfTrue, Instruction *InsertBefore=nullptr)
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
pred_range predecessors(BasicBlock *BB)
Definition: CFG.h:124
static ConstantInt * getTrue(LLVMContext &Context)
Definition: Constants.cpp:609
void push_back(pointer val)
Definition: ilist.h:311
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
iterator_range< typename GraphTraits< GraphType >::nodes_iterator > nodes(const GraphType &G)
Definition: GraphTraits.h:108
LLVM_NODISCARD bool empty() const
Definition: SmallVector.h:55
SymbolTableList< BasicBlock >::iterator eraseFromParent()
Unlink 'this' from the containing function and delete it.
Definition: BasicBlock.cpp:121
BasicBlock * splitBasicBlock(iterator I, const Twine &BBName="")
Split the basic block into two basic blocks at the specified instruction.
Definition: BasicBlock.cpp:414
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:575
LLVM Value Representation.
Definition: Value.h:74
bool isUniform(const Value *V) const
print Print MemDeps of function
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:48
bool simplifyCFG(BasicBlock *BB, const TargetTransformInfo &TTI, const SimplifyCFGOptions &Options={}, SmallPtrSetImpl< BasicBlock *> *LoopHeaders=nullptr)
This function is used to do simplification of a CFG.
This pass exposes codegen information to IR-level passes.
void initializeAMDGPUUnifyDivergentExitNodesPass(PassRegistry &)
static BasicBlock * unifyReturnBlockSet(Function &F, ArrayRef< BasicBlock *> ReturningBlocks, const TargetTransformInfo &TTI, StringRef Name)