84 #ifdef EXPENSIVE_CHECKS
90 #define DEBUG_TYPE "dfa-jump-threading"
92 STATISTIC(NumTransforms,
"Number of transformations done");
93 STATISTIC(NumCloned,
"Number of blocks cloned");
94 STATISTIC(NumPaths,
"Number of individual paths threaded");
98 cl::desc(
"View the CFG before DFA Jump Threading"),
102 "dfa-max-path-length",
103 cl::desc(
"Max number of blocks searched to find a threading path"),
108 cl::desc(
"Maximum cost accepted for the transformation"),
113 class SelectInstToUnfold {
121 PHINode *getUse() {
return SIUse; }
123 explicit operator bool()
const {
return SI && SIUse; }
127 std::vector<SelectInstToUnfold> *NewSIsToUnfold,
128 std::vector<BasicBlock *> *NewBBs);
130 class DFAJumpThreading {
134 : AC(AC), DT(DT),
TTI(
TTI), ORE(ORE) {}
144 for (SelectInstToUnfold SIToUnfold : SelectInsts)
145 Stack.push_back(SIToUnfold);
147 while (!
Stack.empty()) {
148 SelectInstToUnfold SIToUnfold =
Stack.pop_back_val();
150 std::vector<SelectInstToUnfold> NewSIsToUnfold;
151 std::vector<BasicBlock *> NewBBs;
152 unfold(&DTU, SIToUnfold, &NewSIsToUnfold, &NewBBs);
155 for (
const SelectInstToUnfold &NewSIToUnfold : NewSIsToUnfold)
156 Stack.push_back(NewSIToUnfold);
166 class DFAJumpThreadingLegacyPass :
public FunctionPass {
184 &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(
F);
185 DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
187 &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
F);
189 &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
191 return DFAJumpThreading(AC, DT,
TTI, ORE).run(
F);
198 "DFA Jump Threading",
false,
false)
208 return new DFAJumpThreadingLegacyPass();
214 void createBasicBlockAndSinkSelectInst(
217 BranchInst **NewBranch, std::vector<SelectInstToUnfold> *NewSIsToUnfold,
218 std::vector<BasicBlock *> *NewBBs) {
224 NewBBs->push_back(*NewBlock);
227 NewSIsToUnfold->push_back(SelectInstToUnfold(SIToSink, SIUse));
239 std::vector<SelectInstToUnfold> *NewSIsToUnfold,
240 std::vector<BasicBlock *> *NewBBs) {
242 PHINode *SIUse = SIToUnfold.getUse();
259 if (
SelectInst *SIOp = dyn_cast<SelectInst>(
SI->getTrueValue())) {
260 createBasicBlockAndSinkSelectInst(DTU,
SI, SIUse, SIOp, EndBlock,
261 "si.unfold.true", &TrueBlock, &TrueBranch,
262 NewSIsToUnfold, NewBBs);
264 if (
SelectInst *SIOp = dyn_cast<SelectInst>(
SI->getFalseValue())) {
265 createBasicBlockAndSinkSelectInst(DTU,
SI, SIUse, SIOp, EndBlock,
266 "si.unfold.false", &FalseBlock,
267 &FalseBranch, NewSIsToUnfold, NewBBs);
272 if (!TrueBlock && !FalseBlock) {
275 NewBBs->push_back(FalseBlock);
287 if (TrueBlock && FalseBlock) {
300 Phi.addIncoming(Phi.getIncomingValueForBlock(StartBlock), TrueBlock);
301 Phi.addIncoming(Phi.getIncomingValueForBlock(StartBlock), FalseBlock);
306 Value *SIOp1 =
SI->getTrueValue();
307 Value *SIOp2 =
SI->getFalseValue();
311 NewBlock = FalseBlock;
316 NewBlock = TrueBlock;
329 for (
auto II = EndBlock->
begin();
PHINode *Phi = dyn_cast<PHINode>(II);
332 Phi->addIncoming(Phi->getIncomingValueForBlock(StartBlock), NewBlock);
341 SI->eraseFromParent();
349 typedef std::deque<BasicBlock *> PathType;
350 typedef std::vector<PathType> PathsType;
352 typedef std::vector<ClonedBlock> CloneList;
381 struct ThreadingPath {
383 uint64_t getExitValue()
const {
return ExitVal; }
388 bool isExitValueSet()
const {
return IsExitValSet; }
391 const BasicBlock *getDeterminatorBB()
const {
return DBB; }
395 const PathType &getPath()
const {
return Path; }
396 void setPath(
const PathType &NewPath) {
Path = NewPath; }
399 OS <<
Path <<
" [ " << ExitVal <<
", " << DBB->getName() <<
" ]";
406 bool IsExitValSet =
false;
418 if (isPredictable(
SI)) {
423 <<
"Switch instruction is not predictable.";
428 virtual ~MainSwitch() =
default;
430 SwitchInst *getInstr()
const {
return Instr; }
440 std::deque<Instruction *> Q;
444 Value *FirstDef =
SI->getOperand(0);
445 auto *Inst = dyn_cast<Instruction>(FirstDef);
453 if (!isa<PHINode>(Inst))
456 LLVM_DEBUG(
dbgs() <<
"\tisPredictable() FirstDef: " << *Inst <<
"\n");
459 SeenValues.
insert(FirstDef);
465 if (
auto *Phi = dyn_cast<PHINode>(Current)) {
466 for (
Value *Incoming : Phi->incoming_values()) {
467 if (!isPredictableValue(Incoming, SeenValues))
469 addInstToQueue(Incoming, Q, SeenValues);
472 }
else if (
SelectInst *SelI = dyn_cast<SelectInst>(Current)) {
473 if (!isValidSelectInst(SelI))
475 if (!isPredictableValue(SelI->getTrueValue(), SeenValues) ||
476 !isPredictableValue(SelI->getFalseValue(), SeenValues)) {
479 addInstToQueue(SelI->getTrueValue(), Q, SeenValues);
480 addInstToQueue(SelI->getFalseValue(), Q, SeenValues);
481 LLVM_DEBUG(
dbgs() <<
"\tisPredictable() select: " << *SelI <<
"\n");
482 if (
auto *SelIUse = dyn_cast<PHINode>(SelI->user_back()))
483 SelectInsts.push_back(SelectInstToUnfold(SelI, SelIUse));
497 if (isa<ConstantInt>(InpVal))
501 if (!isa<Instruction>(InpVal))
507 void addInstToQueue(
Value *Val, std::deque<Instruction *> &Q,
517 if (!
SI->hasOneUse())
522 if (!SIUse && !(isa<PHINode>(SIUse) || isa<SelectInst>(SIUse)))
533 if (isa<PHINode>(SIUse) &&
539 for (SelectInstToUnfold SIToUnfold : SelectInsts) {
553 struct AllSwitchPaths {
558 std::vector<ThreadingPath> &getThreadingPaths() {
return TPaths; }
559 unsigned getNumThreadingPaths() {
return TPaths.size(); }
561 BasicBlock *getSwitchBlock() {
return SwitchBlock; }
564 VisitedBlocks Visited;
565 PathsType LoopPaths =
paths(SwitchBlock, Visited, 1);
566 StateDefMap StateDef = getStateDefMap();
568 for (PathType Path : LoopPaths) {
573 if (StateDef.count(
BB) != 0) {
574 const PHINode *Phi = dyn_cast<PHINode>(StateDef[
BB]);
575 assert(Phi &&
"Expected a state-defining instr to be a phi node.");
578 if (
const ConstantInt *
C = dyn_cast<const ConstantInt>(V)) {
579 TPath.setExitValue(
C);
580 TPath.setDeterminator(
BB);
586 if (TPath.isExitValueSet() &&
BB ==
Path.front())
592 if (TPath.isExitValueSet() && isSupported(TPath))
593 TPaths.push_back(TPath);
603 unsigned PathDepth)
const {
611 <<
"Exploration stopped after visiting MaxPathLength="
623 if (!Successors.
insert(Succ).second)
627 if (Succ == SwitchBlock) {
633 if (Visited.contains(Succ))
636 PathsType SuccPaths =
paths(Succ, Visited, PathDepth + 1);
637 for (PathType Path : SuccPaths) {
638 PathType NewPath(Path);
639 NewPath.push_front(
BB);
640 Res.push_back(NewPath);
651 StateDefMap getStateDefMap()
const {
656 assert(isa<PHINode>(FirstDef) &&
"After select unfolding, all state "
657 "definitions are expected to be phi "
661 Stack.push_back(dyn_cast<PHINode>(FirstDef));
664 while (!
Stack.empty()) {
668 SeenValues.
insert(CurPhi);
671 if (Incoming == FirstDef || isa<ConstantInt>(Incoming) ||
676 assert(isa<PHINode>(Incoming) &&
"After select unfolding, all state "
677 "definitions are expected to be phi "
680 Stack.push_back(cast<PHINode>(Incoming));
705 bool isSupported(
const ThreadingPath &TPath) {
713 const BasicBlock *DeterminatorBB = TPath.getDeterminatorBB();
716 SwitchCondUseBB == TPath.getPath().front() &&
717 "The first BB in a threading path should have the switch instruction");
718 if (SwitchCondUseBB != TPath.getPath().front())
722 PathType
Path = TPath.getPath();
726 bool IsDetBBSeen =
false;
727 bool IsDefBBSeen =
false;
728 bool IsUseBBSeen =
false;
730 if (
BB == DeterminatorBB)
732 if (
BB == SwitchCondDefBB)
734 if (
BB == SwitchCondUseBB)
736 if (IsDetBBSeen && IsUseBBSeen && !IsDefBBSeen)
746 std::vector<ThreadingPath> TPaths;
749 struct TransformDFA {
754 : SwitchPaths(SwitchPaths), DT(DT), AC(AC),
TTI(
TTI), ORE(ORE),
755 EphValues(EphValues) {}
758 if (isLegalAndProfitableToTransform()) {
759 createAllExitPaths();
769 bool isLegalAndProfitableToTransform() {
775 DuplicateBlockMap DuplicateMap;
777 for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
778 PathType PathBBs = TPath.getPath();
779 uint64_t NextState = TPath.getExitValue();
780 const BasicBlock *Determinator = TPath.getDeterminatorBB();
784 BasicBlock *VisitedBB = getClonedBB(
BB, NextState, DuplicateMap);
787 DuplicateMap[
BB].push_back({
BB, NextState});
792 if (PathBBs.front() == Determinator)
797 auto DetIt =
std::find(PathBBs.begin(), PathBBs.end(), Determinator);
798 for (
auto BBIt = DetIt; BBIt != PathBBs.end(); BBIt++) {
800 VisitedBB = getClonedBB(
BB, NextState, DuplicateMap);
804 DuplicateMap[
BB].push_back({
BB, NextState});
808 LLVM_DEBUG(
dbgs() <<
"DFA Jump Threading: Not jump threading, contains "
809 <<
"non-duplicatable instructions.\n");
813 <<
"Contains non-duplicatable instructions.";
819 LLVM_DEBUG(
dbgs() <<
"DFA Jump Threading: Not jump threading, contains "
820 <<
"convergent instructions.\n");
823 <<
"Contains convergent instructions.";
829 unsigned DuplicationCost = 0;
831 unsigned JumpTableSize = 0;
834 if (JumpTableSize == 0) {
838 unsigned CondBranches =
840 DuplicationCost =
Metrics.NumInsts / CondBranches;
848 DuplicationCost =
Metrics.NumInsts / JumpTableSize;
851 LLVM_DEBUG(
dbgs() <<
"\nDFA Jump Threading: Cost to jump thread block "
852 << SwitchPaths->getSwitchBlock()->getName()
853 <<
" is: " << DuplicationCost <<
"\n\n");
856 LLVM_DEBUG(
dbgs() <<
"Not jump threading, duplication cost exceeds the "
857 <<
"cost threshold.\n");
860 <<
"Duplication cost exceeds the cost threshold (cost="
861 <<
ore::NV(
"Cost", DuplicationCost)
869 <<
"Switch statement jump-threaded.";
876 void createAllExitPaths() {
880 BasicBlock *SwitchBlock = SwitchPaths->getSwitchBlock();
881 for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
883 PathType NewPath(TPath.getPath());
884 NewPath.push_back(SwitchBlock);
885 TPath.setPath(NewPath);
889 DuplicateBlockMap DuplicateMap;
896 for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
897 createExitPath(NewDefs, TPath, DuplicateMap, BlocksToClean, &DTU);
903 for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths())
904 updateLastSuccessor(TPath, DuplicateMap, &DTU);
920 void createExitPath(DefMap &NewDefs, ThreadingPath &Path,
921 DuplicateBlockMap &DuplicateMap,
926 PathType PathBBs =
Path.getPath();
929 if (PathBBs.front() == Determinator)
932 auto DetIt =
std::find(PathBBs.begin(), PathBBs.end(), Determinator);
933 auto Prev = std::prev(DetIt);
935 for (
auto BBIt = DetIt; BBIt != PathBBs.end(); BBIt++) {
941 BasicBlock *NextBB = getClonedBB(
BB, NextState, DuplicateMap);
943 updatePredecessor(PrevBB,
BB, NextBB, DTU);
949 BasicBlock *NewBB = cloneBlockAndUpdatePredecessor(
950 BB, PrevBB, NextState, DuplicateMap, NewDefs, DTU);
951 DuplicateMap[
BB].push_back({NewBB, NextState});
952 BlocksToClean.
insert(NewBB);
963 void updateSSA(DefMap &NewDefs) {
967 for (
auto KV : NewDefs) {
970 std::vector<Instruction *> Cloned = KV.second;
974 for (
Use &U :
I->uses()) {
977 if (UserPN->getIncomingBlock(U) ==
BB)
979 }
else if (
User->getParent() ==
BB) {
983 UsesToRename.push_back(&U);
988 if (UsesToRename.empty())
996 unsigned VarNum = SSAUpdate.
AddVariable(
I->getName(),
I->getType());
1001 while (!UsesToRename.empty())
1017 DuplicateBlockMap &DuplicateMap,
1030 if (isa<PHINode>(&
I))
1038 updateSuccessorPhis(
BB, NewBB, NextState, VMap, DuplicateMap);
1039 updatePredecessor(PrevBB,
BB, NewBB, DTU);
1040 updateDefMap(NewDefs, VMap);
1045 if (SuccSet.
insert(SuccBB).second)
1058 DuplicateBlockMap &DuplicateMap) {
1059 std::vector<BasicBlock *> BlocksToUpdate;
1063 if (
BB == SwitchPaths->getSwitchBlock()) {
1065 BasicBlock *NextCase = getNextCaseSuccessor(Switch, NextState);
1066 BlocksToUpdate.push_back(NextCase);
1067 BasicBlock *ClonedSucc = getClonedBB(NextCase, NextState, DuplicateMap);
1069 BlocksToUpdate.push_back(ClonedSucc);
1074 BlocksToUpdate.push_back(Succ);
1079 BasicBlock *ClonedSucc = getClonedBB(Succ, NextState, DuplicateMap);
1081 BlocksToUpdate.push_back(ClonedSucc);
1089 for (
auto II = Succ->begin();
PHINode *Phi = dyn_cast<PHINode>(II);
1093 if (isa<Constant>(Incoming)) {
1097 Value *ClonedVal = VMap[Incoming];
1113 if (!isPredecessor(OldBB, PrevBB))
1133 for (
auto Entry : VMap) {
1135 dyn_cast<Instruction>(
const_cast<Value *
>(Entry.first));
1136 if (!Inst || !Entry.second || isa<BranchInst>(Inst) ||
1137 isa<SwitchInst>(Inst)) {
1141 Instruction *Cloned = dyn_cast<Instruction>(Entry.second);
1145 NewDefsVector.push_back({Inst, Cloned});
1149 sort(NewDefsVector, [](
const auto &
LHS,
const auto &
RHS) {
1150 if (
LHS.first ==
RHS.first)
1151 return LHS.second->comesBefore(
RHS.second);
1152 return LHS.first->comesBefore(
RHS.first);
1155 for (
const auto &KV : NewDefsVector)
1156 NewDefs[KV.first].push_back(KV.second);
1164 void updateLastSuccessor(ThreadingPath &TPath,
1165 DuplicateBlockMap &DuplicateMap,
1167 uint64_t NextState = TPath.getExitValue();
1169 BasicBlock *LastBlock = getClonedBB(
BB, NextState, DuplicateMap);
1176 BasicBlock *NextCase = getNextCaseSuccessor(Switch, NextState);
1178 std::vector<DominatorTree::UpdateType> DTUpdates;
1181 if (Succ != NextCase && SuccSet.
insert(Succ).second)
1185 Switch->eraseFromParent();
1196 std::vector<PHINode *> PhiToRemove;
1197 for (
auto II =
BB->begin();
PHINode *Phi = dyn_cast<PHINode>(II); ++II) {
1198 PhiToRemove.push_back(Phi);
1200 for (
PHINode *PN : PhiToRemove) {
1202 PN->eraseFromParent();
1208 for (
auto II =
BB->begin();
PHINode *Phi = dyn_cast<PHINode>(II); ++II) {
1209 std::vector<BasicBlock *> BlocksToRemove;
1211 if (!isPredecessor(
BB, IncomingBB))
1212 BlocksToRemove.push_back(IncomingBB);
1222 DuplicateBlockMap &DuplicateMap) {
1223 CloneList ClonedBBs = DuplicateMap[
BB];
1227 auto It =
llvm::find_if(ClonedBBs, [NextState](
const ClonedBlock &
C) {
1228 return C.State == NextState;
1230 return It != ClonedBBs.end() ? (*It).BB :
nullptr;
1237 for (
auto Case :
Switch->cases()) {
1238 if (Case.getCaseValue()->getZExtValue() == NextState) {
1239 NextCase = Case.getCaseSuccessor();
1244 NextCase =
Switch->getDefaultDest();
1253 AllSwitchPaths *SwitchPaths;
1259 std::vector<ThreadingPath> TPaths;
1263 LLVM_DEBUG(
dbgs() <<
"\nDFA Jump threading: " <<
F.getName() <<
"\n");
1265 if (
F.hasOptSize()) {
1266 LLVM_DEBUG(
dbgs() <<
"Skipping due to the 'minsize' attribute\n");
1274 bool MadeChanges =
false;
1277 auto *
SI = dyn_cast<SwitchInst>(
BB.getTerminator());
1282 <<
" is predictable\n");
1289 <<
"candidate for jump threading\n");
1292 unfoldSelectInstrs(DT,
Switch.getSelectInsts());
1293 if (!
Switch.getSelectInsts().empty())
1296 AllSwitchPaths SwitchPaths(&Switch, ORE);
1299 if (SwitchPaths.getNumThreadingPaths() > 0) {
1300 ThreadableLoops.push_back(SwitchPaths);
1312 if (ThreadableLoops.size() > 0)
1315 for (AllSwitchPaths SwitchPaths : ThreadableLoops) {
1316 TransformDFA Transform(&SwitchPaths, DT, AC,
TTI, ORE, EphValues);
1321 #ifdef EXPENSIVE_CHECKS
1339 if (!DFAJumpThreading(&AC, &DT, &
TTI, &ORE).
run(
F))