24#include "llvm/IR/IntrinsicsSPIRV.h"
32#include <unordered_set>
37using BlockSet = std::unordered_set<BasicBlock *>;
38using Edge = std::pair<BasicBlock *, BasicBlock *>;
45 V.partialOrderVisit(Start, std::move(
Op));
52 if (
Node->Entry == BB)
55 for (
auto *Child :
Node->Children) {
66 std::unordered_set<BasicBlock *> ExitTargets;
74 assert(ExitTargets.size() <= 1);
75 if (ExitTargets.size() == 0)
78 return *ExitTargets.begin();
88 if (
II->getIntrinsicID() != Intrinsic::spv_loop_merge &&
89 II->getIntrinsicID() != Intrinsic::spv_selection_merge)
103 if (
II->getIntrinsicID() != Intrinsic::spv_loop_merge)
113 for (
auto &
I : Header) {
166 std::vector<Instruction *> Output;
169 Output.push_back(&
I);
190 std::stack<BasicBlock *> ToVisit;
193 ToVisit.push(&Start);
194 Seen.
insert(ToVisit.top());
195 while (ToVisit.size() != 0) {
219 for (
size_t i = 0; i < BI->getNumSuccessors(); i++) {
220 if (BI->getSuccessor(i) == OldTarget)
221 BI->setSuccessor(i, NewTarget);
225 if (BI->isUnconditional())
229 if (BI->getSuccessor(0) != BI->getSuccessor(1))
235 Builder.SetInsertPoint(BI);
236 Builder.CreateBr(BI->getSuccessor(0));
237 BI->eraseFromParent();
248 if (!
II ||
II->getIntrinsicID() != Intrinsic::spv_selection_merge)
252 II->eraseFromParent();
253 if (!
C->isConstantUsed())
254 C->destroyConstant();
271 for (
size_t i = 0; i <
SI->getNumSuccessors(); i++) {
272 if (
SI->getSuccessor(i) == OldTarget)
273 SI->setSuccessor(i, NewTarget);
278 assert(
false &&
"Unhandled terminator type.");
285 struct DivergentConstruct;
289 using ConstructList = std::vector<std::unique_ptr<DivergentConstruct>>;
295 struct DivergentConstruct {
300 DivergentConstruct *Parent =
nullptr;
301 ConstructList Children;
314 Splitter(Function &F, LoopInfo &LI) : F(F), LI(LI) { invalidate(); }
323 std::vector<BasicBlock *> getLoopConstructBlocks(BasicBlock *Header,
326 std::vector<BasicBlock *> Output;
330 if (DT.dominates(
Merge, BB) || !DT.dominates(Header, BB))
332 Output.push_back(BB);
339 std::vector<BasicBlock *>
340 getSelectionConstructBlocks(DivergentConstruct *Node) {
343 OutsideBlocks.insert(
Node->Merge);
345 for (DivergentConstruct *It =
Node->Parent; It !=
nullptr;
347 OutsideBlocks.insert(It->Merge);
349 OutsideBlocks.insert(It->Continue);
352 std::vector<BasicBlock *> Output;
354 if (OutsideBlocks.count(BB) != 0)
356 if (DT.dominates(Node->Merge, BB) || !DT.dominates(Node->Header, BB))
358 Output.push_back(BB);
365 std::vector<BasicBlock *> getSwitchConstructBlocks(BasicBlock *Header,
369 std::vector<BasicBlock *> Output;
372 if (!DT.dominates(Header, BB))
378 Output.push_back(BB);
385 std::vector<BasicBlock *> getCaseConstructBlocks(BasicBlock *Target,
389 std::vector<BasicBlock *> Output;
393 if (!DT.dominates(Target, BB))
399 Output.push_back(BB);
428 createAliasBlocksForComplexEdges(std::vector<Edge> Edges) {
429 std::unordered_set<BasicBlock *> Seen;
430 std::vector<Edge> Output;
431 Output.reserve(Edges.size());
433 for (
auto &[Src, Dst] : Edges) {
434 auto [Iterator,
Inserted] = Seen.insert(Src);
439 F.getContext(), Src->getName() +
".new.src", &F);
442 Builder.CreateBr(Dst);
446 Output.emplace_back(Src, Dst);
452 AllocaInst *CreateVariable(Function &F,
Type *
Type,
454 const DataLayout &
DL = F.getDataLayout();
455 return new AllocaInst(
Type,
DL.getAllocaAddrSpace(),
nullptr,
"reg",
461 BasicBlock *createSingleExitNode(BasicBlock *Header,
462 std::vector<Edge> &Edges) {
464 std::vector<Edge> FixedEdges = createAliasBlocksForComplexEdges(Edges);
466 std::vector<BasicBlock *> Dsts;
467 std::unordered_map<BasicBlock *, ConstantInt *> DstToIndex;
469 Header->getName() +
".new.exit", &F);
471 for (
auto &[Src, Dst] : FixedEdges) {
472 if (DstToIndex.count(Dst) != 0)
474 DstToIndex.emplace(Dst, ExitBuilder.getInt32(DstToIndex.size()));
478 if (Dsts.size() == 1) {
479 for (
auto &[Src, Dst] : FixedEdges) {
482 ExitBuilder.CreateBr(Dsts[0]);
486 AllocaInst *
Variable = CreateVariable(F, ExitBuilder.getInt32Ty(),
487 F.begin()->getFirstInsertionPt());
488 for (
auto &[Src, Dst] : FixedEdges) {
490 B2.SetInsertPoint(Src->getFirstInsertionPt());
491 B2.CreateStore(DstToIndex[Dst], Variable);
495 Value *
Load = ExitBuilder.CreateLoad(ExitBuilder.getInt32Ty(), Variable);
500 if (Dsts.size() == 2) {
503 ExitBuilder.CreateCondBr(Condition, Dsts[0], Dsts[1]);
507 SwitchInst *Sw = ExitBuilder.CreateSwitch(Load, Dsts[0], Dsts.size() - 1);
509 Sw->
addCase(DstToIndex[BB], BB);
516 Value *createExitVariable(
518 const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
524 Builder.SetInsertPoint(
T);
530 BI->isConditional() ? BI->getSuccessor(1) :
nullptr;
535 if (
LHS ==
nullptr ||
RHS ==
nullptr)
537 return Builder.CreateSelect(BI->getCondition(),
LHS,
RHS);
548 Builder.CreateUnreachable();
553 bool addMergeForLoops(Function &
F) {
554 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
555 auto *TopLevelRegion =
556 getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
558 .getTopLevelRegion();
582 if (
Merge ==
nullptr) {
586 Merge = CreateUnreachable(
F);
587 Builder.SetInsertPoint(Br);
597 SmallVector<Value *, 2>
Args = {MergeAddress, ContinueAddress};
600 for (
unsigned Imm : LoopControlImms)
601 Args.emplace_back(ConstantInt::get(Builder.getInt32Ty(), Imm));
602 Builder.CreateIntrinsic(Intrinsic::spv_loop_merge, {
Args});
612 bool addMergeForNodesWithMultiplePredecessors(Function &
F) {
631 Builder.SetInsertPoint(Header->getTerminator());
634 createOpSelectMerge(&Builder, MergeAddress);
647 bool sortSelectionMerge(Function &
F, BasicBlock &
Block) {
648 std::vector<Instruction *> MergeInstructions;
649 for (Instruction &
I :
Block)
651 MergeInstructions.push_back(&
I);
653 if (MergeInstructions.size() <= 1)
656 Instruction *InsertionPoint = *MergeInstructions.begin();
658 PartialOrderingVisitor Visitor(
F);
659 std::sort(MergeInstructions.begin(), MergeInstructions.end(),
660 [&Visitor](Instruction *
Left, Instruction *
Right) {
663 BasicBlock *RightMerge = getDesignatedMergeBlock(Right);
664 BasicBlock *LeftMerge = getDesignatedMergeBlock(Left);
665 return !Visitor.compare(RightMerge, LeftMerge);
668 for (Instruction *
I : MergeInstructions) {
679 bool sortSelectionMergeHeaders(Function &
F) {
681 for (BasicBlock &BB :
F) {
689 bool splitBlocksWithMultipleHeaders(Function &
F) {
690 std::stack<BasicBlock *> Work;
693 if (MergeInstructions.size() <= 1)
698 const bool Modified = Work.size() > 0;
699 while (Work.size() > 0) {
703 std::vector<Instruction *> MergeInstructions =
705 for (
unsigned i = 1; i < MergeInstructions.size(); i++) {
707 Header->splitBasicBlock(MergeInstructions[i],
"new.header");
714 Builder.SetInsertPoint(BI);
715 Builder.CreateCondBr(Builder.getTrue(), NewBlock, Unreachable);
728 bool addMergeForDivergentBlocks(Function &
F) {
740 std::vector<BasicBlock *> Candidates;
749 if (Candidates.size() <= 1)
758 createOpSelectMerge(&Builder, MergeAddress);
766 std::vector<Edge> getExitsFrom(
const BlockSet &Construct,
767 BasicBlock &Header) {
768 std::vector<Edge> Output;
769 visit(Header, [&](BasicBlock *Item) {
770 if (Construct.count(Item) == 0)
785 void constructDivergentConstruct(
BlockSet &Visited, Splitter &S,
786 BasicBlock *BB, DivergentConstruct *Parent) {
787 if (Visited.count(BB) != 0)
792 if (MIS.size() == 0) {
794 constructDivergentConstruct(Visited, S,
Successor, Parent);
804 auto Output = std::make_unique<DivergentConstruct>();
806 Output->Merge =
Merge;
808 Output->Parent = Parent;
810 constructDivergentConstruct(Visited, S,
Merge, Parent);
812 constructDivergentConstruct(Visited, S,
Continue, Output.get());
815 constructDivergentConstruct(Visited, S,
Successor, Output.get());
818 Parent->Children.emplace_back(std::move(Output));
822 BlockSet getConstructBlocks(Splitter &S, DivergentConstruct *Node) {
825 if (
Node->Continue) {
826 auto LoopBlocks = S.getLoopConstructBlocks(
Node->Header,
Node->Merge);
827 return BlockSet(LoopBlocks.begin(), LoopBlocks.end());
830 auto SelectionBlocks = S.getSelectionConstructBlocks(Node);
831 return BlockSet(SelectionBlocks.begin(), SelectionBlocks.end());
836 bool fixupConstruct(Splitter &S, DivergentConstruct *Node) {
838 for (
auto &Child :
Node->Children)
839 Modified |= fixupConstruct(S, Child.get());
843 if (
Node->Parent ==
nullptr)
849 if (
Node->Parent->Header ==
nullptr)
856 BlockSet ConstructBlocks = getConstructBlocks(S, Node);
857 auto Edges = getExitsFrom(ConstructBlocks, *
Node->Header);
860 if (Edges.size() < 1)
863 bool HasBadEdge =
Node->Merge ==
Node->Parent->Merge ||
864 Node->Merge ==
Node->Parent->Continue;
866 for (
auto &[Src, Dst] : Edges) {
873 if (
Node->Merge == Dst)
878 if (
Node->Continue == Dst)
890 BasicBlock *NewExit = S.createSingleExitNode(
Node->Header, Edges);
897 assert(MergeInstructions.size() == 1);
902 I->setOperand(0, MergeAddress);
910 Node->Merge = NewExit;
916 bool splitCriticalEdges(Function &
F) {
917 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
920 DivergentConstruct Root;
922 constructDivergentConstruct(Visited, S, &*
F.begin(), &Root);
923 return fixupConstruct(S, &Root);
931 bool simplifyBranches(Function &
F) {
934 for (BasicBlock &BB :
F) {
938 if (
SI->getNumCases() > 1)
943 Builder.SetInsertPoint(SI);
945 if (
SI->getNumCases() == 0) {
946 Builder.CreateBr(
SI->getDefaultDest());
950 SI->case_begin()->getCaseValue());
951 Builder.CreateCondBr(Condition,
SI->case_begin()->getCaseSuccessor(),
952 SI->getDefaultDest());
954 SI->eraseFromParent();
963 bool splitSwitchCases(Function &
F) {
966 for (BasicBlock &BB :
F) {
972 Seen.insert(
SI->getDefaultDest());
974 auto It =
SI->case_begin();
975 while (It !=
SI->case_end()) {
977 if (Seen.count(Target) == 0) {
987 Builder.CreateBr(Target);
988 SI->addCase(It->getCaseValue(), NewTarget);
989 It =
SI->removeCase(It);
998 bool removeUselessBlocks(Function &
F) {
1004 for (BasicBlock &BB :
F) {
1011 if (MergeBlocks.count(&BB) != 0 || ContinueBlocks.count(&BB) != 0)
1020 for (BasicBlock *Predecessor : Predecessors)
1031 bool addHeaderToRemainingDivergentDAG(Function &
F) {
1043 for (BasicBlock &BB :
F) {
1044 if (HeaderBlocks.count(&BB) != 0)
1049 size_t CandidateEdges = 0;
1051 if (MergeBlocks.count(
Successor) != 0 ||
1056 CandidateEdges += 1;
1059 if (CandidateEdges <= 1)
1065 bool HasBadBlock =
false;
1066 visit(*Header, [&](
const BasicBlock *Node) {
1071 if (Node == Header || Node ==
Merge)
1074 HasBadBlock |= MergeBlocks.count(Node) != 0 ||
1075 ContinueBlocks.count(Node) != 0 ||
1076 HeaderBlocks.count(Node) != 0;
1077 return !HasBadBlock;
1085 if (
Merge ==
nullptr) {
1088 Builder.SetInsertPoint(Header->getTerminator());
1091 createOpSelectMerge(&Builder, MergeAddress);
1097 SplitInstruction = SplitInstruction->
getPrevNode();
1099 Merge->splitBasicBlockBefore(SplitInstruction,
"new.merge");
1102 Builder.SetInsertPoint(Header->getTerminator());
1105 createOpSelectMerge(&Builder, MergeAddress);
1114 SPIRVStructurizer() : FunctionPass(ID) {}
1134 Modified |= addMergeForNodesWithMultiplePredecessors(
F);
1139 Modified |= sortSelectionMergeHeaders(
F);
1144 Modified |= splitBlocksWithMultipleHeaders(
F);
1150 Modified |= addMergeForDivergentBlocks(
F);
1174 Modified |= addHeaderToRemainingDivergentDAG(
F);
1182 void getAnalysisUsage(AnalysisUsage &AU)
const override {
1185 AU.
addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
1187 AU.
addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
1188 FunctionPass::getAnalysisUsage(AU);
1191 void createOpSelectMerge(
IRBuilder<> *Builder, BlockAddress *MergeAddress) {
1194 MDNode *MDNode = BBTerminatorInst->
getMetadata(
"hlsl.controlflow.hint");
1196 ConstantInt *BranchHint = ConstantInt::get(Builder->
getInt32Ty(), 0);
1200 "invalid metadata hlsl.controlflow.hint");
1204 SmallVector<Value *, 2>
Args = {MergeAddress, BranchHint};
1212char SPIRVStructurizer::ID = 0;
1215 "structurize SPIRV",
false,
false)
1225 return new SPIRVStructurizer();
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
ReachingDefInfo InstSet & ToRemove
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
This file defines the DenseMap class.
static bool runOnFunction(Function &F, bool PostInlining)
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
static BasicBlock * getDesignatedMergeBlock(Instruction *I)
static void visit(BasicBlock &Start, std::function< bool(BasicBlock *)> op)
static std::vector< Instruction * > getMergeInstructions(BasicBlock &BB)
static BasicBlock * getDesignatedContinueBlock(Instruction *I)
std::unordered_set< BasicBlock * > BlockSet
static const ConvergenceRegion * getRegionForHeader(const ConvergenceRegion *Node, BasicBlock *BB)
static bool hasLoopMergeInstruction(BasicBlock &BB)
static SmallPtrSet< BasicBlock *, 2 > getContinueBlocks(Function &F)
static SmallPtrSet< BasicBlock *, 2 > getMergeBlocks(Function &F)
static SmallPtrSet< BasicBlock *, 2 > getHeaderBlocks(Function &F)
static bool isDefinedAsSelectionMergeBy(BasicBlock &Header, BasicBlock &Merge)
static void replaceBranchTargets(BasicBlock *BB, BasicBlock *OldTarget, BasicBlock *NewTarget)
static void partialOrderVisit(BasicBlock &Start, std::function< bool(BasicBlock *)> Op)
static bool isMergeInstruction(Instruction *I)
static BasicBlock * getExitFor(const ConvergenceRegion *CR)
static void replaceIfBranchTargets(BasicBlock *BB, BasicBlock *OldTarget, BasicBlock *NewTarget)
This file defines the SmallPtrSet class.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
LLVM Basic Block Representation.
const Function * getParent() const
Return the enclosing method, or null if none.
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
LLVM_ABI const BasicBlock * getUniqueSuccessor() const
Return the successor of this block if it has a unique successor.
LLVM_ABI SymbolTableList< BasicBlock >::iterator eraseFromParent()
Unlink 'this' from the containing function and delete it.
InstListType::iterator iterator
Instruction iterators...
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
The address of a basic block.
BasicBlock * getBasicBlock() const
static LLVM_ABI BlockAddress * get(Function *F, BasicBlock *BB)
Return a BlockAddress for the specified function and basic block.
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Represents analyses that only rely on functions' control flow.
This is an important base class in LLVM.
LLVM_ABI bool isConstantUsed() const
Return true if the constant has users other than constant expressions and other dangling things.
LLVM_ABI void destroyConstant()
Called if some element of this constant is no longer valid.
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
bool dominates(const DomTreeNodeBase< NodeT > *A, const DomTreeNodeBase< NodeT > *B) const
dominates - Returns true iff A dominates B.
void recalculate(ParentType &Func)
recalculate - compute a dominator tree for the given function
DomTreeNodeBase< NodeT > * getNode(const NodeT *BB) const
getNode - return the (Post)DominatorTree node for the specified basic block.
Legacy analysis pass which computes a DominatorTree.
FunctionPass class - This class is used to implement most global optimizations.
IntegerType * getInt32Ty()
Fetch the type representing a 32-bit integer.
BasicBlock * GetInsertBlock() const
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
MDNode * getMetadata(unsigned KindID) const
Get the metadata of given kind attached to this Instruction.
A wrapper class for inspecting calls to intrinsic functions.
bool isLoopHeader(const BlockT *BB) const
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
const MDOperand & getOperand(unsigned I) const
unsigned getNumOperands() const
Return number of MDNode operands.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses & preserveSet()
Mark an analysis set as preserved.
PreservedAnalyses run(Function &M, FunctionAnalysisManager &AM)
SmallPtrSet< BasicBlock *, 2 > Exits
SmallPtrSet< BasicBlock *, 8 > Blocks
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
LLVM_ABI void addCase(ConstantInt *OnVal, BasicBlock *Dest)
Add an entry to the switch instruction.
Type * getType() const
All values are typed, get the type of this value.
self_iterator getIterator()
FunctionPassManager manages FunctionPasses.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
@ C
The default llvm calling convention, compatible with C.
PostDomTreeBase< BasicBlock > BBPostDomTree
DomTreeBase< BasicBlock > BBDomTree
@ BasicBlock
Various leaf nodes.
std::enable_if_t< detail::IsValidPointer< X, Y >::value, X * > extract(Y &&MD)
Extract a Value from Metadata.
NodeAddr< NodeBase * > Node
friend class Instruction
Iterator for Instructions in a `BasicBlock.
LLVM_ABI iterator begin() const
This is an optimization pass for GlobalISel generic memory operations.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
FunctionAddr VTableAddr Value
FunctionPass * createSPIRVStructurizerPass()
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
auto successors(const MachineBasicBlock *BB)
bool sortBlocks(Function &F)
auto pred_size(const MachineBasicBlock *BB)
SmallVector< unsigned, 1 > getSpirvLoopControlOperandsFromLoopMetadata(Loop *L)
auto dyn_cast_or_null(const Y &Val)
auto succ_size(const MachineBasicBlock *BB)
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
DWARFExpression::Operation Op
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
auto predecessors(const MachineBasicBlock *BB)
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.