24#include "llvm/IR/IntrinsicsSPIRV.h"
37using Edge = std::pair<BasicBlock *, BasicBlock *>;
44 V.partialOrderVisit(Start, std::move(
Op));
51 if (
Node->Entry == BB)
54 for (
auto *Child :
Node->Children) {
74 if (ExitTargets.
size() == 0)
77 return *ExitTargets.
begin();
87 if (
II->getIntrinsicID() != Intrinsic::spv_loop_merge &&
88 II->getIntrinsicID() != Intrinsic::spv_selection_merge)
102 if (
II->getIntrinsicID() != Intrinsic::spv_loop_merge)
112 for (
auto &
I : Header) {
165 std::vector<Instruction *> Output;
168 Output.push_back(&
I);
189 std::stack<BasicBlock *> ToVisit;
192 ToVisit.push(&Start);
193 Seen.
insert(ToVisit.top());
194 while (ToVisit.size() != 0) {
218 for (
size_t i = 0; i < BI->getNumSuccessors(); i++) {
219 if (BI->getSuccessor(i) == OldTarget)
220 BI->setSuccessor(i, NewTarget);
224 if (BI->getSuccessor(0) != BI->getSuccessor(1))
230 Builder.SetInsertPoint(BI);
231 Builder.CreateBr(BI->getSuccessor(0));
232 BI->eraseFromParent();
243 if (!
II ||
II->getIntrinsicID() != Intrinsic::spv_selection_merge)
247 II->eraseFromParent();
248 if (!
C->isConstantUsed())
249 C->destroyConstant();
262 if (BI->getSuccessor() == OldTarget)
263 BI->setSuccessor(NewTarget);
271 for (
size_t i = 0; i <
SI->getNumSuccessors(); i++) {
272 if (
SI->getSuccessor(i) == OldTarget)
273 SI->setSuccessor(i, NewTarget);
278 assert(
false &&
"Unhandled terminator type.");
285 struct DivergentConstruct;
289 using ConstructList = std::vector<std::unique_ptr<DivergentConstruct>>;
295 struct DivergentConstruct {
300 DivergentConstruct *Parent =
nullptr;
301 ConstructList Children;
314 Splitter(Function &F, LoopInfo &LI) : F(F), LI(LI) { invalidate(); }
323 std::vector<BasicBlock *> getLoopConstructBlocks(BasicBlock *Header,
326 std::vector<BasicBlock *> Output;
330 if (DT.dominates(
Merge, BB) || !DT.dominates(Header, BB))
332 Output.push_back(BB);
339 std::vector<BasicBlock *>
340 getSelectionConstructBlocks(DivergentConstruct *Node) {
345 for (DivergentConstruct *It =
Node->Parent; It !=
nullptr;
347 OutsideBlocks.
insert(It->Merge);
349 OutsideBlocks.
insert(It->Continue);
352 std::vector<BasicBlock *> Output;
354 if (OutsideBlocks.count(BB) != 0)
356 if (DT.dominates(Node->Merge, BB) || !DT.dominates(Node->Header, BB))
358 Output.push_back(BB);
365 std::vector<BasicBlock *> getSwitchConstructBlocks(BasicBlock *Header,
369 std::vector<BasicBlock *> Output;
372 if (!DT.dominates(Header, BB))
378 Output.push_back(BB);
385 std::vector<BasicBlock *> getCaseConstructBlocks(BasicBlock *Target,
389 std::vector<BasicBlock *> Output;
393 if (!DT.dominates(Target, BB))
399 Output.push_back(BB);
428 createAliasBlocksForComplexEdges(std::vector<Edge> Edges) {
429 SmallPtrSet<BasicBlock *, 0> Seen;
430 std::vector<Edge> Output;
431 Output.reserve(Edges.size());
433 for (
auto &[Src, Dst] : Edges) {
439 F.getContext(), Src->getName() +
".new.src", &F);
442 Builder.CreateBr(Dst);
446 Output.emplace_back(Src, Dst);
454 BasicBlock *createSingleExitNode(BasicBlock *Header,
455 std::vector<Edge> &Edges) {
457 std::vector<Edge> FixedEdges = createAliasBlocksForComplexEdges(Edges);
459 std::vector<BasicBlock *> Dsts;
460 DenseMap<BasicBlock *, ConstantInt *> DstToIndex;
462 Header->getName() +
".new.exit", &F);
464 for (
auto &[Src, Dst] : FixedEdges) {
465 if (DstToIndex.
count(Dst) != 0)
471 if (Dsts.size() == 1) {
472 for (
auto &[Src, Dst] : FixedEdges) {
475 ExitBuilder.CreateBr(Dsts[0]);
480 for (
auto &[Src, Dst] : FixedEdges) {
482 B2.SetInsertPoint(Src->getFirstInsertionPt());
483 B2.CreateStore(DstToIndex[Dst], Variable);
487 Value *
Load = ExitBuilder.CreateLoad(ExitBuilder.getInt32Ty(), Variable);
492 if (Dsts.size() == 2) {
495 ExitBuilder.CreateCondBr(Condition, Dsts[0], Dsts[1]);
499 SwitchInst *Sw = ExitBuilder.CreateSwitch(Load, Dsts[0], Dsts.size() - 1);
501 Sw->
addCase(DstToIndex[BB], BB);
510 Builder.CreateUnreachable();
515 bool addMergeForLoops(Function &
F) {
516 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
517 auto *TopLevelRegion =
518 getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
520 .getTopLevelRegion();
544 if (
Merge ==
nullptr) {
546 Merge = CreateUnreachable(
F);
547 Builder.SetInsertPoint(Br);
549 Br->eraseFromParent();
557 SmallVector<Value *, 2>
Args = {MergeAddress, ContinueAddress};
558 SmallVector<unsigned, 1> LoopControlImms =
560 for (
unsigned Imm : LoopControlImms)
561 Args.emplace_back(ConstantInt::get(Builder.getInt32Ty(), Imm));
562 Builder.CreateIntrinsic(Intrinsic::spv_loop_merge, {
Args});
572 bool addMergeForNodesWithMultiplePredecessors(Function &
F) {
591 Builder.SetInsertPoint(Header->getTerminator());
594 createOpSelectMerge(&Builder, MergeAddress);
607 bool sortSelectionMerge(Function &
F, BasicBlock &
Block) {
608 std::vector<Instruction *> MergeInstructions;
609 for (Instruction &
I :
Block)
611 MergeInstructions.push_back(&
I);
613 if (MergeInstructions.size() <= 1)
616 Instruction *InsertionPoint = *MergeInstructions.begin();
618 PartialOrderingVisitor Visitor(
F);
619 std::sort(MergeInstructions.begin(), MergeInstructions.end(),
620 [&Visitor](Instruction *
Left, Instruction *
Right) {
623 BasicBlock *RightMerge = getDesignatedMergeBlock(Right);
624 BasicBlock *LeftMerge = getDesignatedMergeBlock(Left);
625 return !Visitor.compare(RightMerge, LeftMerge);
628 for (Instruction *
I : MergeInstructions) {
639 bool sortSelectionMergeHeaders(Function &
F) {
641 for (BasicBlock &BB :
F) {
649 bool splitBlocksWithMultipleHeaders(Function &
F) {
650 std::stack<BasicBlock *> Work;
653 if (MergeInstructions.size() <= 1)
658 const bool Modified = Work.size() > 0;
659 while (Work.size() > 0) {
663 std::vector<Instruction *> MergeInstructions =
665 for (
unsigned i = 1; i < MergeInstructions.size(); i++) {
667 Header->splitBasicBlock(MergeInstructions[i],
"new.header");
674 Builder.SetInsertPoint(Term);
675 Builder.CreateCondBr(Builder.getTrue(), NewBlock, Unreachable);
676 Term->eraseFromParent();
688 bool addMergeForDivergentBlocks(Function &
F) {
700 std::vector<BasicBlock *> Candidates;
709 if (Candidates.size() <= 1)
718 createOpSelectMerge(&Builder, MergeAddress);
726 std::vector<Edge> getExitsFrom(
const BlockSet &Construct,
727 BasicBlock &Header) {
728 std::vector<Edge> Output;
729 visit(Header, [&](BasicBlock *Item) {
730 if (Construct.
count(Item) == 0)
745 void constructDivergentConstruct(
BlockSet &Visited, Splitter &S,
746 BasicBlock *BB, DivergentConstruct *Parent) {
747 if (Visited.
count(BB) != 0)
752 if (MIS.size() == 0) {
754 constructDivergentConstruct(Visited, S,
Successor, Parent);
764 auto Output = std::make_unique<DivergentConstruct>();
766 Output->Merge =
Merge;
768 Output->Parent = Parent;
770 constructDivergentConstruct(Visited, S,
Merge, Parent);
772 constructDivergentConstruct(Visited, S,
Continue, Output.get());
775 constructDivergentConstruct(Visited, S,
Successor, Output.get());
778 Parent->Children.emplace_back(std::move(Output));
782 BlockSet getConstructBlocks(Splitter &S, DivergentConstruct *Node) {
785 if (
Node->Continue) {
786 auto LoopBlocks = S.getLoopConstructBlocks(
Node->Header,
Node->Merge);
787 return BlockSet(LoopBlocks.begin(), LoopBlocks.end());
790 auto SelectionBlocks = S.getSelectionConstructBlocks(Node);
791 return BlockSet(SelectionBlocks.begin(), SelectionBlocks.end());
796 bool fixupConstruct(Splitter &S, DivergentConstruct *Node) {
798 for (
auto &Child :
Node->Children)
799 Modified |= fixupConstruct(S, Child.get());
803 if (
Node->Parent ==
nullptr)
809 if (
Node->Parent->Header ==
nullptr)
816 BlockSet ConstructBlocks = getConstructBlocks(S, Node);
817 auto Edges = getExitsFrom(ConstructBlocks, *
Node->Header);
820 if (Edges.size() < 1)
823 bool HasBadEdge =
Node->Merge ==
Node->Parent->Merge ||
824 Node->Merge ==
Node->Parent->Continue;
826 for (
auto &[Src, Dst] : Edges) {
833 if (
Node->Merge == Dst)
838 if (
Node->Continue == Dst)
850 BasicBlock *NewExit = S.createSingleExitNode(
Node->Header, Edges);
857 assert(MergeInstructions.size() == 1);
862 I->setOperand(0, MergeAddress);
870 Node->Merge = NewExit;
877 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
880 DivergentConstruct Root;
882 constructDivergentConstruct(Visited, S, &*
F.begin(), &Root);
883 return fixupConstruct(S, &Root);
891 bool simplifyBranches(Function &
F) {
894 for (BasicBlock &BB :
F) {
898 if (
SI->getNumCases() > 1)
903 Builder.SetInsertPoint(SI);
905 if (
SI->getNumCases() == 0) {
906 Builder.CreateBr(
SI->getDefaultDest());
910 SI->case_begin()->getCaseValue());
911 Builder.CreateCondBr(Condition,
SI->case_begin()->getCaseSuccessor(),
912 SI->getDefaultDest());
914 SI->eraseFromParent();
923 bool splitSwitchCases(Function &
F) {
926 for (BasicBlock &BB :
F) {
934 auto It =
SI->case_begin();
935 while (It !=
SI->case_end()) {
937 if (Seen.
count(Target) == 0) {
947 Builder.CreateBr(Target);
948 SI->addCase(It->getCaseValue(), NewTarget);
949 It =
SI->removeCase(It);
958 bool removeUselessBlocks(Function &
F) {
964 for (BasicBlock &BB :
F) {
971 if (MergeBlocks.count(&BB) != 0 || ContinueBlocks.count(&BB) != 0)
980 for (BasicBlock *Predecessor : Predecessors)
991 bool addHeaderToRemainingDivergentDAG(Function &
F) {
1003 for (BasicBlock &BB :
F) {
1004 if (HeaderBlocks.count(&BB) != 0)
1009 size_t CandidateEdges = 0;
1011 if (MergeBlocks.count(
Successor) != 0 ||
1016 CandidateEdges += 1;
1019 if (CandidateEdges <= 1)
1025 bool HasBadBlock =
false;
1026 visit(*Header, [&](
const BasicBlock *Node) {
1031 if (Node == Header || Node ==
Merge)
1034 HasBadBlock |= MergeBlocks.count(Node) != 0 ||
1035 ContinueBlocks.count(Node) != 0 ||
1036 HeaderBlocks.count(Node) != 0;
1037 return !HasBadBlock;
1045 if (
Merge ==
nullptr) {
1048 Builder.SetInsertPoint(Header->getTerminator());
1051 createOpSelectMerge(&Builder, MergeAddress);
1057 SplitInstruction = SplitInstruction->
getPrevNode();
1059 Merge->splitBasicBlockBefore(SplitInstruction,
"new.merge");
1062 Builder.SetInsertPoint(Header->getTerminator());
1065 createOpSelectMerge(&Builder, MergeAddress);
1074 SPIRVStructurizer() : FunctionPass(ID) {}
1094 Modified |= addMergeForNodesWithMultiplePredecessors(
F);
1099 Modified |= sortSelectionMergeHeaders(
F);
1104 Modified |= splitBlocksWithMultipleHeaders(
F);
1110 Modified |= addMergeForDivergentBlocks(
F);
1134 Modified |= addHeaderToRemainingDivergentDAG(
F);
1142 void getAnalysisUsage(AnalysisUsage &AU)
const override {
1145 AU.
addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
1147 AU.
addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
1148 FunctionPass::getAnalysisUsage(AU);
1151 void createOpSelectMerge(
IRBuilder<> *Builder, BlockAddress *MergeAddress) {
1154 MDNode *MDNode = BBTerminatorInst->
getMetadata(
"hlsl.controlflow.hint");
1156 ConstantInt *BranchHint = ConstantInt::get(Builder->
getInt32Ty(), 0);
1160 "invalid metadata hlsl.controlflow.hint");
1164 SmallVector<Value *, 2>
Args = {MergeAddress, BranchHint};
1172char SPIRVStructurizer::ID = 0;
1175 "structurize SPIRV",
false,
false)
1185 return new SPIRVStructurizer();
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
ReachingDefInfo InstSet & ToRemove
This file defines the DenseMap class.
static bool runOnFunction(Function &F, bool PostInlining)
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
static bool splitCriticalEdges(CallBrInst *CBR, DominatorTree *DT)
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
static BasicBlock * getDesignatedMergeBlock(Instruction *I)
static void visit(BasicBlock &Start, std::function< bool(BasicBlock *)> op)
static std::vector< Instruction * > getMergeInstructions(BasicBlock &BB)
SmallPtrSet< BasicBlock *, 0 > BlockSet
static BasicBlock * getDesignatedContinueBlock(Instruction *I)
static const ConvergenceRegion * getRegionForHeader(const ConvergenceRegion *Node, BasicBlock *BB)
static bool hasLoopMergeInstruction(BasicBlock &BB)
static SmallPtrSet< BasicBlock *, 2 > getContinueBlocks(Function &F)
static SmallPtrSet< BasicBlock *, 2 > getMergeBlocks(Function &F)
static SmallPtrSet< BasicBlock *, 2 > getHeaderBlocks(Function &F)
static bool isDefinedAsSelectionMergeBy(BasicBlock &Header, BasicBlock &Merge)
static void replaceBranchTargets(BasicBlock *BB, BasicBlock *OldTarget, BasicBlock *NewTarget)
static void partialOrderVisit(BasicBlock &Start, std::function< bool(BasicBlock *)> Op)
static bool isMergeInstruction(Instruction *I)
static BasicBlock * getExitFor(const ConvergenceRegion *CR)
static void replaceIfBranchTargets(BasicBlock *BB, BasicBlock *OldTarget, BasicBlock *NewTarget)
This file defines the SmallPtrSet class.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
LLVM Basic Block Representation.
const Function * getParent() const
Return the enclosing method, or null if none.
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
LLVM_ABI const BasicBlock * getUniqueSuccessor() const
Return the successor of this block if it has a unique successor.
LLVM_ABI SymbolTableList< BasicBlock >::iterator eraseFromParent()
Unlink 'this' from the containing function and delete it.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction; assumes that the block is well-formed.
The address of a basic block.
BasicBlock * getBasicBlock() const
static LLVM_ABI BlockAddress * get(Function *F, BasicBlock *BB)
Return a BlockAddress for the specified function and basic block.
Represents analyses that only rely on functions' control flow.
This is an important base class in LLVM.
LLVM_ABI bool isConstantUsed() const
Return true if the constant has users other than constant expressions and other dangling things.
LLVM_ABI void destroyConstant()
Called if some element of this constant is no longer valid.
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
bool dominates(const DomTreeNodeBase< NodeT > *A, const DomTreeNodeBase< NodeT > *B) const
dominates - Returns true iff A dominates B.
void recalculate(ParentType &Func)
recalculate - compute a dominator tree for the given function
DomTreeNodeBase< NodeT > * getNode(const NodeT *BB) const
getNode - return the (Post)DominatorTree node for the specified basic block.
Legacy analysis pass which computes a DominatorTree.
FunctionPass class - This class is used to implement most global optimizations.
IntegerType * getInt32Ty()
Fetch the type representing a 32-bit integer.
BasicBlock * GetInsertBlock() const
LLVM_ABI Value * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > OverloadTypes, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="", ArrayRef< OperandBundleDef > OpBundles={}, function_ref< void(CallInst *)> SetFn=[](CallInst *) {})
Variant to create a possibly constant-folded intrinsic.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
MDNode * getMetadata(unsigned KindID) const
Get the metadata of given kind attached to this Instruction.
A wrapper class for inspecting calls to intrinsic functions.
bool isLoopHeader(const BlockT *BB) const
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
const MDOperand & getOperand(unsigned I) const
unsigned getNumOperands() const
Return number of MDNode operands.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses & preserveSet()
Mark an analysis set as preserved.
PreservedAnalyses run(Function &M, FunctionAnalysisManager &AM)
SmallPtrSet< BasicBlock *, 2 > Exits
SmallPtrSet< BasicBlock *, 8 > Blocks
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
LLVM_ABI void addCase(ConstantInt *OnVal, BasicBlock *Dest)
Add an entry to the switch instruction.
BasicBlock * getSuccessor(unsigned i=0) const
Type * getType() const
All values are typed, get the type of this value.
self_iterator getIterator()
FunctionPassManager manages FunctionPasses.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
@ C
The default llvm calling convention, compatible with C.
PostDomTreeBase< BasicBlock > BBPostDomTree
DomTreeBase< BasicBlock > BBDomTree
@ BasicBlock
Various leaf nodes.
std::enable_if_t< detail::IsValidPointer< X, Y >::value, X * > extract(Y &&MD)
Extract a Value from Metadata.
NodeAddr< NodeBase * > Node
friend class Instruction
Iterator for Instructions in a `BasicBlock.
LLVM_ABI iterator begin() const
This is an optimization pass for GlobalISel generic memory operations.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
FunctionAddr VTableAddr Value
FunctionPass * createSPIRVStructurizerPass()
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
auto successors(const MachineBasicBlock *BB)
bool sortBlocks(Function &F)
auto pred_size(const MachineBasicBlock *BB)
AllocaInst * createVariable(Function &F, Type *Type)
auto dyn_cast_or_null(const Y &Val)
SmallVector< unsigned, 1 > getSpirvLoopControlOperandsFromLoopMetadata(MDNode *LoopMD)
auto succ_size(const MachineBasicBlock *BB)
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
DWARFExpression::Operation Op
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
auto predecessors(const MachineBasicBlock *BB)
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.