#ifndef LLVM_ADT_GENERICUNIFORMITYIMPL_H
#define LLVM_ADT_GENERICUNIFORMITYIMPL_H

// ...

#define DEBUG_TYPE "uniformity"
template <typename Range> auto unique(Range &&R) {
  return std::unique(adl_begin(R), adl_end(R));
}
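// Illustrative usage of the helper above (an example added here, not part of
// the original header): together with llvm::sort it enables the usual
// sort-unique-erase idiom on any range type.
//
//   SmallVector<int> V = {3, 1, 3, 2, 1};
//   llvm::sort(V);
//   V.erase(unique(V), V.end()); // V == {1, 2, 3}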
/// Construct a specially modified post-order traversal of cycles.
template <typename ContextT> class ModifiedPostOrder {
public:
  using BlockT = typename ContextT::BlockT;
  // ...

  /// Appends \p BB and records its post-order index.
  void appendBlock(const BlockT &BB, bool isReducibleCycleHeader = false) {
    POIndex[&BB] = m_order.size();
    m_order.push_back(&BB);
    LLVM_DEBUG(dbgs() << "ModifiedPO(" << POIndex[&BB]
                      << "): " << Context.print(&BB) << "\n");
    if (isReducibleCycleHeader)
      ReducibleCycleHeaders.insert(&BB);
  }

  unsigned getIndex(const BlockT *BB) const { return POIndex.lookup(BB); }

  bool isReducibleCycleHeader(const BlockT *BB) const {
    return ReducibleCycleHeaders.contains(BB);
  }

private:
  // ...
  const ContextT &Context;
};
/// Locate join blocks for disjoint paths starting at a divergent branch.
template <typename ContextT> class GenericSyncDependenceAnalysis {
  using BlockT = typename ContextT::BlockT;
  // ...
  DenseMap<const BlockT *, std::unique_ptr<DivergenceDescriptor>>
      CachedControlDivDescs;
};
/// Core implementation of uniformity/divergence propagation, templated over
/// the SSA context.
template <typename ContextT> class GenericUniformityAnalysisImpl {
public:
  using BlockT = typename ContextT::BlockT;
  // ...
  using UseT = typename ContextT::UseT;
  // ...
  using DivergenceDescriptorT =
      typename SyncDependenceAnalysisT::DivergenceDescriptor;
  // ...

  /// Mark the outputs of \p Instr as divergent.
  bool markDefsDivergent(const InstructionT &Instr,
                         bool AllDefsDivergent = true);

  bool isDivergent(const InstructionT &I) const {
    if (I.isTerminator()) {
      return DivergentTermBlocks.contains(I.getParent());
    }
    // ...
  }
  // ...

private:
  // ...
  void taintAndPushAllDefs(const BlockT &JoinBlock);
  void taintAndPushPhiNodes(const BlockT &JoinBlock);
  void propagateCycleExitDivergence(const BlockT &DivExit,
                                    const CycleT &InnerDivCycle);
  void analyzeCycleExitDivergence(const CycleT &OuterDivCycle);
  void analyzeTemporalDivergence(const InstructionT &I,
                                 const CycleT &OuterDivCycle);
  bool isTemporalDivergent(const BlockT &ObservingBlock,
                           const InstructionT &Def) const;
  // ...
};
template <typename ImplT>
// ...
/// Compute divergence starting with a divergent branch.
template <typename ContextT> class DivergencePropagator {
public:
  using BlockT = typename ContextT::BlockT;
  // ...
  using DivergenceDescriptorT =
      typename SyncDependenceAnalysisT::DivergenceDescriptor;
  // ...
  std::unique_ptr<DivergenceDescriptorT> DivDesc;
  void printDefs(raw_ostream &Out) {
    Out << "Propagator::BlockLabels {\n";
    for (int BlockIdx = (int)CyclePOT.size() - 1; BlockIdx >= 0; --BlockIdx) {
      const auto *Block = CyclePOT[BlockIdx];
      const auto *Label = BlockLabels[Block];
      Out << Context.print(Block) << "(" << BlockIdx << ") : ";
      // ...
      Out << Context.print(Label) << "\n";
    }
    Out << "}\n";
  }
  // Push a definition (\p PushedLabel) to \p SuccBlock and return whether
  // this causes a divergent join.
  bool computeJoin(const BlockT &SuccBlock, const BlockT &PushedLabel) {
    const auto *OldLabel = BlockLabels[&SuccBlock];
    LLVM_DEBUG(dbgs() << "labeling " << Context.print(&SuccBlock) << "\n"
                      << "\tpushed label: " << Context.print(&PushedLabel)
                      << "\n"
                      << "\told label: " << Context.print(OldLabel) << "\n");

    // Early exit if the label does not change.
    if (OldLabel == &PushedLabel)
      return false;

    if (OldLabel != &SuccBlock) {
      auto SuccIdx = CyclePOT.getIndex(&SuccBlock);
      // ...
    }
    // ...
  }

  bool visitCycleExitEdge(const BlockT &ExitBlock, const BlockT &Label) {
    if (!computeJoin(ExitBlock, Label))
      return false;
    // Identified a divergent cycle exit.
    DivDesc->CycleDivBlocks.insert(&ExitBlock);
    return true;
  }

  bool visitEdge(const BlockT &SuccBlock, const BlockT &Label) {
    if (!computeJoin(SuccBlock, Label))
      return false;
    // Divergent, disjoint paths join here.
    DivDesc->JoinDivBlocks.insert(&SuccBlock);
    return true;
  }
  std::unique_ptr<DivergenceDescriptorT> computeJoinPoints() {
    // ...
    const BlockT *FloorLabel = nullptr;
    // ...

    // Bootstrap with the successors of the divergent terminator.
    const auto *DivTermCycle = CI.getCycle(&DivTermBlock);
    for (const auto *SuccBlock : successors(&DivTermBlock)) {
      if (DivTermCycle && !DivTermCycle->contains(SuccBlock)) {
        // If the terminator exits the cycle immediately, the label may never
        // reach this exit with a different value; record it up front.
        DivDesc->CycleDivBlocks.insert(SuccBlock);
        LLVM_DEBUG(dbgs() /* ... */ << Context.print(SuccBlock) << "\n");
      }
      auto SuccIdx = CyclePOT.getIndex(SuccBlock);
      // ...
      FloorIdx = std::min<int>(FloorIdx, SuccIdx);
    }

    // Propagate labels in reverse post-order until the floor is reached.
    while (true) {
      // ...
      if (BlockIdx == -1 || BlockIdx < FloorIdx)
        break;
      // ...
      if (BlockIdx == DivTermIdx) {
        // ...
      }
      LLVM_DEBUG(dbgs() /* ... */ << BlockIdx << "\n");

      bool CausedJoin = false;
      int LoweredFloorIdx = FloorIdx;
      // ...

      // A reducible cycle header whose cycle contains the divergent branch
      // propagates its label directly to the cycle exits.
      if (const auto *BlockCycle = getReducibleParent(Block)) {
        SmallVector<BlockT *, 4> BlockCycleExits;
        BlockCycle->getExitBlocks(BlockCycleExits);
        for (auto *BlockCycleExit : BlockCycleExits) {
          CausedJoin |= visitCycleExitEdge(*BlockCycleExit, *Label);
          LoweredFloorIdx = std::min<int>(LoweredFloorIdx,
                                          CyclePOT.getIndex(BlockCycleExit));
        }
      } else {
        for (const auto *SuccBlock : successors(Block)) {
          CausedJoin |= visitEdge(*SuccBlock, *Label);
          LoweredFloorIdx = std::min<int>(LoweredFloorIdx,
                                          CyclePOT.getIndex(SuccBlock));
        }
      }

      // Floor update.
      if (CausedJoin) {
        // New joins were found; don't lower the floor past them.
        FloorIdx = LoweredFloorIdx;
      } else if (FloorLabel != Label) {
        // No join, but a label different from the previous iteration's was
        // pushed; the floor still has to be lowered.
        FloorIdx = LoweredFloorIdx;
        FloorLabel = Label;
      }
    }

    // ...
    for (const auto *Exit : Exits) {
      // ...
      DivDesc->CycleDivBlocks.insert(Exit);
      // ...
    }
    // ...
    return std::move(DivDesc);
  }
};
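// Worked example (added for illustration; not part of the original header):
// for a divergent branch in block B with successors A and C that re-converge
// at J,
//
//        B
//       / \
//      A   C
//       \ /
//        J
//
// the propagator seeds A with label A and C with label C. Whichever label
// reaches J first becomes its old label; when the other one arrives,
// computeJoin() observes the conflict and visitEdge() records J in
// JoinDivBlocks. Labels that leave a cycle with conflicting values are
// recorded in CycleDivBlocks by visitCycleExitEdge() instead.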
template <typename ContextT>
typename llvm::GenericSyncDependenceAnalysis<ContextT>::DivergenceDescriptor
    llvm::GenericSyncDependenceAnalysis<ContextT>::EmptyDivergenceDesc;
template <typename ContextT>
llvm::GenericSyncDependenceAnalysis<ContextT>::GenericSyncDependenceAnalysis(
    const ContextT &Context, const DominatorTreeT &DT, const CycleInfoT &CI)
    : CyclePO(Context), DT(DT), CI(CI) {
  CyclePO.compute(CI);
}
template <typename ContextT>
auto llvm::GenericSyncDependenceAnalysis<ContextT>::getJoinBlocks(
    const BlockT *DivTermBlock) -> const DivergenceDescriptor & {
  // Trivial case: no divergent paths possible.
  if (succ_size(DivTermBlock) <= 1)
    return EmptyDivergenceDesc;

  // Already available in the cache?
  auto ItCached = CachedControlDivDescs.find(DivTermBlock);
  if (ItCached != CachedControlDivDescs.end())
    return *ItCached->second;

  // Compute the join points from scratch.
  DivergencePropagator<ContextT> Propagator(CyclePO, DT, CI, *DivTermBlock);
  auto DivDesc = Propagator.computeJoinPoints();

  LLVM_DEBUG(
      auto printBlockSet = [&](ConstBlockSet &Blocks) {
        return Printable([&](raw_ostream &Out) {
          Out << "[";
          ListSeparator LS;
          for (const auto *BB : Blocks) {
            Out << LS << CI.getSSAContext().print(BB);
          }
          Out << "]";
        });
      };

      dbgs() << "\nResult (" << CI.getSSAContext().print(DivTermBlock)
             << "):\n  JoinDivBlocks: " << printBlockSet(DivDesc->JoinDivBlocks)
             << "  CycleDivBlocks: " << printBlockSet(DivDesc->CycleDivBlocks)
             << "\n";);

  auto ItInserted =
      CachedControlDivDescs.try_emplace(DivTermBlock, std::move(DivDesc));
  assert(ItInserted.second);
  return *ItInserted.first->second;
}
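// The lookup above follows a find / compute / try_emplace caching pattern.
// Reduced sketch (added for illustration; `computeDesc` is a hypothetical
// stand-in for running the propagator):
//
//   const DivergenceDescriptor &get(const BlockT *Key) {
//     auto It = Cache.find(Key);
//     if (It != Cache.end())
//       return *It->second;
//     auto ItIns = Cache.try_emplace(Key, computeDesc(Key));
//     assert(ItIns.second);
//     return *ItIns.first->second;
//   }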
template <typename ContextT>
bool GenericUniformityAnalysisImpl<ContextT>::markDivergent(
    const InstructionT &I) {
  if (I.isTerminator()) {
    if (DivergentTermBlocks.insert(I.getParent()).second) {
      LLVM_DEBUG(dbgs() << "marked divergent term block: "
                        << Context.print(I.getParent()) << "\n");
      return true;
    }
    return false;
  }

  if (isAlwaysUniform(I))
    return false;

  return markDefsDivergent(I);
}
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::markDivergent(ConstValueRefT Val) {
  if (DivergentValues.insert(Val).second) {
    // ...
  }
}
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::addUniformOverride(
    const InstructionT &Instr) {
  UniformOverrides.insert(&Instr);
}
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::analyzeTemporalDivergence(
    const InstructionT &I, const CycleT &OuterDivCycle) {
  // ...
  if (isAlwaysUniform(I))
    return;
  if (!usesValueFromCycle(I, OuterDivCycle))
    return;
  if (markDivergent(I))
    Worklist.push_back(&I);
}
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::analyzeCycleExitDivergence(
    const CycleT &OuterDivCycle) {
  // Compute the region that can observe values escaping the cycle: the
  // cycle's exit blocks seed the dominance frontier, which is then refined
  // iteratively.
  // ...
  OuterDivCycle.getExitBlocks(DomFrontier);

  // A block belongs to the dominated region iff all paths to it pass
  // through the cycle.
  auto isInDomRegion = [&](BlockT *BB) {
    for (auto *P : predecessors(BB)) {
      if (OuterDivCycle.contains(P))
        continue;
      // ...
    }
    // ...
  };

  // Grow the region until it stabilizes.
  bool Promoted = false;
  // ...
  for (auto *W : DomFrontier) {
    if (!isInDomRegion(W)) {
      // ...
    }
  }
  // ...

  // Phi nodes in the dominance frontier observe escaping values through
  // their incoming edges.
  for (const auto *UserBlock : DomFrontier) {
    LLVM_DEBUG(dbgs() /* ... */ << Context.print(UserBlock) << "\n");
    for (const auto &Phi : UserBlock->phis()) {
      analyzeTemporalDivergence(Phi, OuterDivCycle);
    }
  }

  // Any instruction in the dominated region can observe escaping values.
  for (const auto *UserBlock : DomRegion) {
    LLVM_DEBUG(dbgs() /* ... */ << Context.print(UserBlock) << "\n");
    for (const auto &I : *UserBlock) {
      analyzeTemporalDivergence(I, OuterDivCycle);
    }
  }
}
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::propagateCycleExitDivergence(
    const BlockT &DivExit, const CycleT &InnerDivCycle) {
  // ...
  auto *DivCycle = &InnerDivCycle;
  auto *OuterDivCycle = DivCycle;
  auto *ExitLevelCycle = CI.getCycle(&DivExit);
  const unsigned CycleExitDepth =
      ExitLevelCycle ? ExitLevelCycle->getDepth() : 0;

  // Find the outermost cycle that does not contain the exit block.
  while (DivCycle && DivCycle->getDepth() > CycleExitDepth) {
    LLVM_DEBUG(dbgs() /* ... */
               << Context.print(DivCycle->getHeader()) << "\n");
    OuterDivCycle = DivCycle;
    DivCycle = DivCycle->getParentCycle();
  }
  LLVM_DEBUG(dbgs() /* ... */
             << Context.print(OuterDivCycle->getHeader()) << "\n");

  if (!DivergentExitCycles.insert(OuterDivCycle).second)
    return;

  // Already assumed divergent as part of an enclosing cycle?
  for (const auto *C : AssumedDivergent) {
    if (C->contains(OuterDivCycle))
      return;
  }

  analyzeCycleExitDivergence(*OuterDivCycle);
}
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::taintAndPushAllDefs(
    const BlockT &BB) {
  // ...
  for (const auto &I : instrs(BB)) {
    // Terminators do not produce values; divergent terminators are handled
    // separately through DivergentTermBlocks.
    if (I.isTerminator())
      break;
    if (markDivergent(I))
      Worklist.push_back(&I);
  }
}
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::taintAndPushPhiNodes(
    const BlockT &JoinBlock) {
  // ...
  for (const auto &Phi : JoinBlock.phis()) {
    // Phis whose incoming values are all identical (or undef) stay uniform
    // even at a divergent join.
    if (ContextT::isConstantOrUndefValuePhi(Phi))
      continue;
    if (markDivergent(Phi))
      Worklist.push_back(&Phi);
  }
}
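// IR-level intuition (added for illustration; not part of the original
// header): at a divergent join, a phi merges values from the disjoint paths,
// so it is divergent even when both incoming values are uniform constants:
//
//   bb:
//     br i1 %divergent.cond, label %a, label %b ; divergent branch
//   a:
//     br label %join
//   b:
//     br label %join
//   join:
//     %phi = phi i32 [ 0, %a ], [ 1, %b ] ; divergent by sync dependence
//
// The isConstantOrUndefValuePhi() check above exempts phis whose incoming
// values are all identical (or undef), since those produce the same value on
// every path.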
/// Add \p Candidate to \p Cycles if it is not already contained in \p Cycles.
template <typename CycleT>
static bool insertIfNotContained(SmallVector<CycleT *> &Cycles,
                                 CycleT *Candidate) {
  if (llvm::any_of(Cycles,
                   [Candidate](CycleT *C) { return C->contains(Candidate); }))
    return false;
  Cycles.push_back(Candidate);
  return true;
}
/// Return the outermost cycle made divergent by branch outside it.
template <typename CycleT, typename BlockT>
static const CycleT *getExtDivCycle(const CycleT *Cycle,
                                    const BlockT *DivTermBlock,
                                    const BlockT *JoinBlock) {
  // ...
  // Walk up the cycle tree as long as the divergent branch remains outside
  // the cycle.
  const auto *Parent = Cycle->getParentCycle();
  while (Parent && !Parent->contains(DivTermBlock)) {
    assert(!Parent->isReducible());
    Cycle = Parent;
    Parent = Cycle->getParentCycle();
  }
  // ...
  LLVM_DEBUG(dbgs() << "cycle made divergent by external branch\n");
  return Cycle;
}
/// Return the outermost cycle made divergent by branch inside it.
///
/// This checks the domination relation: a cycle is divergent if its header
/// does not dominate the join block.
template <typename ContextT, typename CycleT, typename BlockT,
          typename DominatorTreeT>
static const CycleT *getIntDivCycle(const CycleT *Cycle,
                                    const BlockT *DivTermBlock,
                                    const BlockT *JoinBlock,
                                    const DominatorTreeT &DT,
                                    ContextT &Context) {
  LLVM_DEBUG(dbgs() /* ... */
             << "for internal branch " << Context.print(DivTermBlock) << "\n");
  if (DT.properlyDominates(DivTermBlock, JoinBlock))
    return nullptr;

  // ...
  LLVM_DEBUG(dbgs() /* ... */ << " does not dominate join\n");

  const auto *Parent = Cycle->getParentCycle();
  while (Parent && !DT.properlyDominates(Parent->getHeader(), JoinBlock)) {
    LLVM_DEBUG(dbgs() /* ... */ << " does not dominate join\n");
    Cycle = Parent;
    Parent = Parent->getParentCycle();
  }

  LLVM_DEBUG(dbgs() << " cycle made divergent by internal branch\n");
  return Cycle;
}
/// Return the outermost cycle made divergent by the branch in
/// \p DivTermBlock, if any.
template <typename ContextT, typename CycleT, typename BlockT,
          typename DominatorTreeT>
static const CycleT *getOutermostDivergentCycle(const CycleT *Cycle,
                                                const BlockT *DivTermBlock,
                                                const BlockT *JoinBlock,
                                                const DominatorTreeT &DT,
                                                ContextT &Context) {
  // Combines getExtDivCycle() and getIntDivCycle(); the outermost of the
  // two results wins.
  // ...
}
template <typename ContextT>
bool GenericUniformityAnalysisImpl<ContextT>::isTemporalDivergent(
    const BlockT &ObservingBlock, const InstructionT &Def) const {
  // If the definition lives inside a cycle with a divergent exit that does
  // not contain the observing block, the use is temporally divergent.
  const BlockT *DefBlock = Def.getParent();
  for (const CycleT *Cycle = CI.getCycle(DefBlock);
       Cycle && !Cycle->contains(&ObservingBlock);
       Cycle = Cycle->getParentCycle()) {
    if (DivergentExitCycles.contains(Cycle)) {
      return true;
    }
  }
  return false;
}
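// Illustration of temporal divergence (added; not part of the original
// header): a value defined in a cycle with a divergent exit is observed
// outside the cycle by threads that left in different iterations, so uses
// outside the cycle are divergent even if the def is uniform per iteration:
//
//   loop:
//     %v = ...                                  ; uniform in each iteration
//     br i1 %divergent.exit.cond, label %exit, label %loop
//   exit:
//     use(%v)                                   ; temporally divergent use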
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::analyzeControlDivergence(
    const InstructionT &Term) {
  const auto *DivTermBlock = Term.getParent();
  DivergentTermBlocks.insert(DivTermBlock);
  LLVM_DEBUG(dbgs() /* ... */ << Context.print(DivTermBlock) << "\n");

  // Nothing to propagate from an unreachable block.
  if (!DT.isReachableFromEntry(DivTermBlock))
    return;

  const auto &DivDesc = SDA.getJoinBlocks(DivTermBlock);
  SmallVector<const CycleT *> DivCycles;

  // Iterate over all blocks now reachable by a disjoint path join.
  for (const auto *JoinBlock : DivDesc.JoinDivBlocks) {
    const auto *Cycle = CI.getCycle(JoinBlock);
    // ...
    if (const auto *Outermost = getOutermostDivergentCycle(
            Cycle, DivTermBlock, JoinBlock, DT, Context)) {
      // The whole cycle is assumed divergent; handled below.
      DivCycles.push_back(Outermost);
      continue;
    }
    taintAndPushPhiNodes(*JoinBlock);
  }

  // Sort by depth in the cycle tree, outermost first, so that contained
  // cycles are skipped by insertIfNotContained().
  llvm::sort(DivCycles, [](const CycleT *A, const CycleT *B) {
    return A->getDepth() > B->getDepth();
  });

  for (auto *C : DivCycles) {
    if (!insertIfNotContained(AssumedDivergent, C))
      continue;
    // Taint every value defined inside the assumed-divergent cycle.
    for (const BlockT *BB : C->blocks()) {
      taintAndPushAllDefs(*BB);
    }
  }

  const auto *BranchCycle = CI.getCycle(DivTermBlock);
  assert(DivDesc.CycleDivBlocks.empty() || BranchCycle);
  for (const auto *DivExitBlock : DivDesc.CycleDivBlocks) {
    propagateCycleExitDivergence(*DivExitBlock, *BranchCycle);
  }
}
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::compute() {
  // Seed the worklist with the users of values that are divergent up front
  // (e.g. divergent function arguments).
  auto DivValuesCopy = DivergentValues;
  for (const auto DivVal : DivValuesCopy) {
    assert(isDivergent(DivVal) && "Worklist invariant violated!");
    // ...
  }

  // Fixed-point propagation.
  while (!Worklist.empty()) {
    const InstructionT *I = Worklist.back();
    Worklist.pop_back();

    // Maintain divergence of control flow.
    if (I->isTerminator()) {
      analyzeControlDivergence(*I);
      continue;
    }

    // Propagate data divergence to the users of I.
    assert(isDivergent(*I) && "Worklist invariant violated!");
    // ...
  }
}
template <typename ContextT>
bool GenericUniformityAnalysisImpl<ContextT>::isAlwaysUniform(
    const InstructionT &Instr) const {
  return UniformOverrides.contains(&Instr);
}
template <typename ContextT>
GenericUniformityInfo<ContextT>::GenericUniformityInfo(
    FunctionT &Func, const DominatorTreeT &DT, const CycleInfoT &CI,
    const TargetTransformInfo *TTI) {
  DA.reset(new ImplT{Func, DT, CI, TTI});
}
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::print(raw_ostream &OS) const {
  bool haveDivergentArgs = false;

  if (DivergentValues.empty()) {
    assert(DivergentTermBlocks.empty());
    assert(DivergentExitCycles.empty());
    OS << "ALL VALUES UNIFORM\n";
    return;
  }

  for (const auto &entry : DivergentValues) {
    // Divergent values with no defining instruction are arguments.
    // ...
    if (!haveDivergentArgs) {
      OS << "DIVERGENT ARGUMENTS:\n";
      haveDivergentArgs = true;
    }
    // ...
  }

  if (!AssumedDivergent.empty()) {
    OS << "CYCLES ASSUMED DIVERGENT:\n";
    for (const CycleT *cycle : AssumedDivergent) {
      OS << "  " << cycle->print(Context) << '\n';
    }
  }

  if (!DivergentExitCycles.empty()) {
    OS << "CYCLES WITH DIVERGENT EXIT:\n";
    for (const CycleT *cycle : DivergentExitCycles) {
      OS << "  " << cycle->print(Context) << '\n';
    }
  }

  // Per-block dump of definitions and terminators.
  // ...
  OS << "DEFINITIONS\n";
  // ...
  for (auto value : defs) {
    if (isDivergent(value))
      OS << "  DIVERGENT: ";
    // ...
  }
  OS << "TERMINATORS\n";
  // ...
  bool divergentTerminators = hasDivergentTerminator(block);
  for (auto *T : terms) {
    if (divergentTerminators)
      OS << "  DIVERGENT: ";
    // ...
  }
  OS << "END BLOCK\n";
  // ...
}
template <typename ContextT>
bool GenericUniformityInfo<ContextT>::hasDivergence() const {
  return DA->hasDivergence();
}

template <typename ContextT>
bool GenericUniformityInfo<ContextT>::isDivergent(ConstValueRefT V) const {
  return DA->isDivergent(V);
}

template <typename ContextT>
bool GenericUniformityInfo<ContextT>::isDivergent(const InstructionT *I) const {
  return DA->isDivergent(*I);
}

template <typename ContextT>
bool GenericUniformityInfo<ContextT>::isDivergentUse(const UseT &U) const {
  return DA->isDivergentUse(U);
}

template <typename ContextT>
bool GenericUniformityInfo<ContextT>::hasDivergentTerminator(const BlockT &B) {
  return DA->hasDivergentTerminator(B);
}
template <typename ContextT>
void GenericUniformityInfo<ContextT>::print(raw_ostream &OS) const {
  DA->print(OS);
}
template <typename ContextT>
void llvm::ModifiedPostOrder<ContextT>::computeStackPO(
    SmallVectorImpl<const BlockT *> &Stack, const CycleInfoT &CI,
    const CycleT *Cycle, SmallPtrSetImpl<const BlockT *> &Finalized) {
  while (!Stack.empty()) {
    auto *NextBB = Stack.back();
    if (Finalized.count(NextBB)) {
      Stack.pop_back();
      continue;
    }
    LLVM_DEBUG(dbgs() << "  visiting " << CI.getSSAContext().print(NextBB)
                      << "\n");
    auto *NestedCycle = CI.getCycle(NextBB);
    if (Cycle != NestedCycle && (!Cycle || Cycle->contains(NestedCycle))) {
      // NextBB enters a proper nested cycle; find the child that is directly
      // nested in the cycle being traversed.
      while (NestedCycle->getParentCycle() != Cycle)
        NestedCycle = NestedCycle->getParentCycle();

      // Delay the nested cycle: first push its unfinalized exits.
      SmallVector<BlockT *, 3> NestedExits;
      NestedCycle->getExitBlocks(NestedExits);
      bool PushedNodes = false;
      for (auto *NestedExitBB : NestedExits) {
        LLVM_DEBUG(dbgs() /* ... */
                   << CI.getSSAContext().print(NestedExitBB) << "\n");
        // ...
        if (Finalized.count(NestedExitBB))
          continue;
        PushedNodes = true;
        Stack.push_back(NestedExitBB);
        LLVM_DEBUG(dbgs() /* ... */
                   << CI.getSSAContext().print(NestedExitBB) << "\n");
      }
      if (!PushedNodes) {
        // All exits are finalized: traverse the nested cycle as a unit.
        Stack.pop_back();
        computeCyclePO(CI, NestedCycle, Finalized);
      }
      continue;
    }

    // Regular DFS step: push the unfinalized successors.
    bool PushedNodes = false;
    for (auto *SuccBB : successors(NextBB)) {
      LLVM_DEBUG(dbgs() /* ... */ << CI.getSSAContext().print(SuccBB) << "\n");
      // ...
      if (Finalized.count(SuccBB))
        continue;
      PushedNodes = true;
      Stack.push_back(SuccBB);
      LLVM_DEBUG(dbgs() << "  pushed succ: " << CI.getSSAContext().print(SuccBB)
                        << "\n");
    }
    if (!PushedNodes) {
      // All successors finalized: finalize this node and append it to the
      // post order.
      LLVM_DEBUG(dbgs() /* ... */ << CI.getSSAContext().print(NextBB) << "\n");
      Stack.pop_back();
      Finalized.insert(NextBB);
      appendBlock(*NextBB);
    }
  }
}
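// Intuition for the traversal above (added; not part of the original header):
// this is a depth-first post order that treats each nested cycle as a unit.
// A block is finalized only once all of its relevant successors are, and a
// nested cycle is entered (via computeCyclePO below) only once all of its
// exits are finalized. Headers of reducible cycles are tagged through
// appendBlock(BB, /*isReducibleCycleHeader=*/true) so that the join-point
// propagator can treat them specially.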
template <typename ContextT>
void llvm::ModifiedPostOrder<ContextT>::computeCyclePO(
    const CycleInfoT &CI, const CycleT *Cycle,
    SmallPtrSetImpl<const BlockT *> &Finalized) {
  SmallVector<const BlockT *> Stack;
  auto *CycleHeader = Cycle->getHeader();
  LLVM_DEBUG(dbgs() /* ... */
             << CI.getSSAContext().print(CycleHeader) << "\n");
  assert(!Finalized.count(CycleHeader));
  Finalized.insert(CycleHeader);

  // Append the header; it is tagged as a cycle header iff the cycle is
  // reducible.
  LLVM_DEBUG(dbgs() /* ... */
             << CI.getSSAContext().print(CycleHeader) << "\n");
  appendBlock(*CycleHeader, Cycle->isReducible());

  // Seed the stack with the header's successors inside the cycle.
  for (auto *BB : successors(CycleHeader)) {
    LLVM_DEBUG(dbgs() << "  examine succ: " << CI.getSSAContext().print(BB)
                      << "\n");
    // ...
    if (BB == CycleHeader)
      continue;
    if (!Finalized.count(BB)) {
      LLVM_DEBUG(dbgs() << "  pushed succ: " << CI.getSSAContext().print(BB)
                        << "\n");
      Stack.push_back(BB);
    }
  }

  // Compute the post order inside the cycle.
  computeStackPO(Stack, CI, Cycle, Finalized);
}

/// Generically compute the modified post order.
template <typename ContextT>
void llvm::ModifiedPostOrder<ContextT>::compute(const CycleInfoT &CI) {
  SmallPtrSet<const BlockT *, 32> Finalized;
  SmallVector<const BlockT *> Stack;
  // Start the traversal at the function entry; at the top level there is no
  // enclosing cycle.
  // ...
  computeStackPO(Stack, CI, nullptr, Finalized);
}

#undef DEBUG_TYPE

#endif // LLVM_ADT_GENERICUNIFORMITYIMPL_H