204 #define DEBUG_TYPE "loop-predication"
206 STATISTIC(TotalConsidered,
"Number of guards considered");
207 STATISTIC(TotalWidened,
"Number of checks widened");
209 using namespace llvm;
227 cl::desc(
"scale factor for the latch probability. Value should be greater "
228 "than 1. Lower values are ignored"));
231 "loop-predication-predicate-widenable-branches-to-deopt",
cl::Hidden,
232 cl::desc(
"Whether or not we should predicate guards "
233 "expressed as widenable branches to deoptimize blocks"),
245 : Pred(Pred),
IV(
IV), Limit(Limit) {}
246 LoopICmp() =
default;
248 dbgs() <<
"LoopICmp Pred = " << Pred <<
", IV = " << *
IV
249 <<
", Limit = " << *Limit <<
"\n";
265 bool isSupportedStep(
const SCEV* Step);
283 bool isLoopInvariantValue(
const SCEV*
S);
307 bool isLoopProfitableToPredicate();
314 :
AA(
AA), DT(DT), SE(SE), LI(LI), MSSAU(MSSAU){};
315 bool runOnLoop(
Loop *L);
318 class LoopPredicationLegacyPass :
public LoopPass {
334 auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
335 auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
336 auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
337 auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>();
338 std::unique_ptr<MemorySSAUpdater> MSSAU;
340 MSSAU = std::make_unique<MemorySSAUpdater>(&MSSAWP->getMSSA());
341 auto *
AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
343 return LP.runOnLoop(L);
351 "Loop predication",
false,
false)
358 return new LoopPredicationLegacyPass();
364 std::unique_ptr<MemorySSAUpdater> MSSAU;
366 MSSAU = std::make_unique<MemorySSAUpdater>(AR.
MSSA);
368 MSSAU ? MSSAU.get() :
nullptr);
369 if (!LP.runOnLoop(&L))
379 LoopPredication::parseLoopICmp(
ICmpInst *ICI) {
384 const SCEV *LHSS = SE->getSCEV(
LHS);
385 if (isa<SCEVCouldNotCompute>(LHSS))
387 const SCEV *RHSS = SE->getSCEV(
RHS);
388 if (isa<SCEVCouldNotCompute>(RHSS))
392 if (SE->isLoopInvariant(LHSS, L)) {
402 return LoopICmp(Pred, AR, RHSS);
410 assert(Ty ==
RHS->
getType() &&
"expandCheck operands have different types?");
412 if (SE->isLoopInvariant(
LHS, L) && SE->isLoopInvariant(
RHS, L)) {
414 if (SE->isLoopEntryGuardedByCond(L, Pred,
LHS,
RHS))
424 return Builder.CreateICmp(Pred, LHSV, RHSV);
443 const LoopICmp LatchCheck,
444 Type *RangeCheckType) {
447 assert(
DL.getTypeSizeInBits(LatchCheck.IV->getType()).getFixedSize() >
448 DL.getTypeSizeInBits(RangeCheckType).getFixedSize() &&
449 "Expected latch check IV type to be larger than range check operand "
453 auto *Limit = dyn_cast<SCEVConstant>(LatchCheck.Limit);
454 auto *Start = dyn_cast<SCEVConstant>(LatchCheck.IV->getStart());
455 if (!Limit || !Start)
467 auto RangeCheckTypeBitSize =
468 DL.getTypeSizeInBits(RangeCheckType).getFixedSize();
469 return Start->getAPInt().getActiveBits() < RangeCheckTypeBitSize &&
470 Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize;
478 const LoopICmp LatchCheck,
479 Type *RangeCheckType) {
481 auto *LatchType = LatchCheck.IV->getType();
482 if (RangeCheckType == LatchType)
485 if (
DL.getTypeSizeInBits(LatchType).getFixedSize() <
486 DL.getTypeSizeInBits(RangeCheckType).getFixedSize())
492 LoopICmp NewLatchCheck;
493 NewLatchCheck.Pred = LatchCheck.Pred;
494 NewLatchCheck.IV = dyn_cast<SCEVAddRecExpr>(
496 if (!NewLatchCheck.IV)
498 NewLatchCheck.Limit = SE.
getTruncateExpr(LatchCheck.Limit, RangeCheckType);
500 <<
"can be represented as range check type:"
501 << *RangeCheckType <<
"\n");
502 LLVM_DEBUG(
dbgs() <<
"LatchCheck.IV: " << *NewLatchCheck.IV <<
"\n");
503 LLVM_DEBUG(
dbgs() <<
"LatchCheck.Limit: " << *NewLatchCheck.Limit <<
"\n");
504 return NewLatchCheck;
507 bool LoopPredication::isSupportedStep(
const SCEV* Step) {
516 return Preheader->getTerminator();
524 for (
const SCEV *
Op : Ops)
525 if (!SE->isLoopInvariant(
Op, L) ||
528 return Preheader->getTerminator();
531 bool LoopPredication::isLoopInvariantValue(
const SCEV*
S) {
551 if (SE->isLoopInvariant(
S, L))
560 if (
const auto *LI = dyn_cast<LoadInst>(U->getValue()))
562 if (
AA->pointsToConstantMemory(LI->getOperand(0)) ||
563 LI->hasMetadata(LLVMContext::MD_invariant_load))
569 LoopICmp LatchCheck, LoopICmp RangeCheck,
571 auto *Ty = RangeCheck.IV->getType();
578 const SCEV *GuardStart = RangeCheck.IV->getStart();
579 const SCEV *GuardLimit = RangeCheck.Limit;
580 const SCEV *LatchStart = LatchCheck.IV->getStart();
581 const SCEV *LatchLimit = LatchCheck.Limit;
585 if (!isLoopInvariantValue(GuardStart) ||
586 !isLoopInvariantValue(GuardLimit) ||
587 !isLoopInvariantValue(LatchStart) ||
588 !isLoopInvariantValue(LatchLimit)) {
600 SE->getAddExpr(SE->getMinusSCEV(GuardLimit, GuardStart),
601 SE->getMinusSCEV(LatchStart, SE->getOne(Ty)));
602 auto LimitCheckPred =
610 expandCheck(Expander, Guard, LimitCheckPred, LatchLimit,
RHS);
611 auto *FirstIterationCheck = expandCheck(Expander, Guard, RangeCheck.Pred,
612 GuardStart, GuardLimit);
614 return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
618 LoopICmp LatchCheck, LoopICmp RangeCheck,
620 auto *Ty = RangeCheck.IV->getType();
621 const SCEV *GuardStart = RangeCheck.IV->getStart();
622 const SCEV *GuardLimit = RangeCheck.Limit;
623 const SCEV *LatchStart = LatchCheck.IV->getStart();
624 const SCEV *LatchLimit = LatchCheck.Limit;
628 if (!isLoopInvariantValue(GuardStart) ||
629 !isLoopInvariantValue(GuardLimit) ||
630 !isLoopInvariantValue(LatchStart) ||
631 !isLoopInvariantValue(LatchLimit)) {
642 auto *PostDecLatchCheckIV = LatchCheck.IV->getPostIncExpr(*SE);
643 if (RangeCheck.IV != PostDecLatchCheckIV) {
645 << *PostDecLatchCheckIV
646 <<
" and RangeCheckIV: " << *RangeCheck.IV <<
"\n");
654 auto LimitCheckPred =
656 auto *FirstIterationCheck = expandCheck(Expander, Guard,
658 GuardStart, GuardLimit);
659 auto *LimitCheck = expandCheck(Expander, Guard, LimitCheckPred, LatchLimit,
662 return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
670 RC.IV->getStepRecurrence(*SE)->isOne() &&
690 auto RangeCheck = parseLoopICmp(ICI);
692 LLVM_DEBUG(
dbgs() <<
"Failed to parse the loop latch condition!\n");
699 << RangeCheck->Pred <<
")!\n");
702 auto *RangeCheckIV = RangeCheck->IV;
703 if (!RangeCheckIV->isAffine()) {
707 auto *Step = RangeCheckIV->getStepRecurrence(*SE);
710 if (!isSupportedStep(Step)) {
711 LLVM_DEBUG(
dbgs() <<
"Range check and latch have IVs different steps!\n");
714 auto *Ty = RangeCheckIV->getType();
716 if (!CurrLatchCheckOpt) {
718 "corresponding to range type: "
723 LoopICmp CurrLatchCheck = *CurrLatchCheckOpt;
727 CurrLatchCheck.IV->getStepRecurrence(*SE)->getType() &&
728 "Range and latch steps should be of same type!");
729 if (Step != CurrLatchCheck.IV->getStepRecurrence(*SE)) {
730 LLVM_DEBUG(
dbgs() <<
"Range and latch have different step values!\n");
735 return widenICmpRangeCheckIncrementingLoop(CurrLatchCheck, *RangeCheck,
739 return widenICmpRangeCheckDecrementingLoop(CurrLatchCheck, *RangeCheck,
748 unsigned NumWidened = 0;
756 Value *WideableCond =
nullptr;
758 Value *Condition = Worklist.pop_back_val();
759 if (!Visited.
insert(Condition).second)
765 Worklist.push_back(
LHS);
766 Worklist.push_back(
RHS);
771 m_Intrinsic<Intrinsic::experimental_widenable_condition>())) {
773 WideableCond = Condition;
777 if (
ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) {
778 if (
auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander,
780 Checks.push_back(NewRangeCheck.getValue());
787 Checks.push_back(Condition);
788 }
while (!Worklist.empty());
794 Checks.push_back(WideableCond);
798 bool LoopPredication::widenGuardConditions(
IntrinsicInst *Guard,
805 unsigned NumWidened = collectChecks(Checks, Guard->
getOperand(0), Expander,
810 TotalWidened += NumWidened;
823 bool LoopPredication::widenWidenableBranchGuardConditions(
831 unsigned NumWidened = collectChecks(Checks, BI->
getCondition(),
836 TotalWidened += NumWidened;
845 "Stopped being a guard after transform?");
852 using namespace PatternMatch;
868 "One of the latch's destinations must be the header");
875 auto Result = parseLoopICmp(ICI);
877 LLVM_DEBUG(
dbgs() <<
"Failed to parse the loop latch condition!\n");
886 if (!
Result->IV->isAffine()) {
891 auto *Step =
Result->IV->getStepRecurrence(*SE);
892 if (!isSupportedStep(Step)) {
893 LLVM_DEBUG(
dbgs() <<
"Unsupported loop stride(" << *Step <<
")!\n");
909 if (IsUnsupportedPredicate(Step,
Result->Pred)) {
919 bool LoopPredication::isLoopProfitableToPredicate() {
927 if (ExitEdges.size() == 1)
936 assert(LatchBlock &&
"Should have a single latch at this point!");
937 auto *LatchTerm = LatchBlock->getTerminator();
938 assert(LatchTerm->getNumSuccessors() == 2 &&
939 "expected to be an exiting block with 2 succs!");
940 unsigned LatchBrExitIdx =
941 LatchTerm->getSuccessor(0) == L->
getHeader() ? 1 : 0;
949 auto *LatchExitBlock = LatchTerm->getSuccessor(LatchBrExitIdx);
950 if (isa<UnreachableInst>(LatchTerm) ||
951 LatchExitBlock->getTerminatingDeoptimizeCall())
955 if (!ProfileData || !ProfileData->
getOperand(0))
958 if (!MDS->getString().equals(
"branch_weights"))
964 MDNode *LatchProfileData = LatchTerm->getMetadata(LLVMContext::MD_prof);
967 if (!IsValidProfileData(LatchProfileData, LatchTerm))
970 auto ComputeBranchProbability =
974 MDNode *ProfileData =
Term->getMetadata(LLVMContext::MD_prof);
975 unsigned NumSucc =
Term->getNumSuccessors();
976 if (IsValidProfileData(ProfileData,
Term)) {
977 uint64_t Numerator = 0, Denominator = 0, ProfVal = 0;
978 for (
unsigned i = 0;
i < NumSucc;
i++) {
980 mdconst::extract<ConstantInt>(ProfileData->
getOperand(
i + 1));
982 if (
Term->getSuccessor(
i) == ExitBlock)
983 Numerator += ProfVal;
984 Denominator += ProfVal;
988 assert(LatchBlock != ExitingBlock &&
989 "Latch term should always have profile data!");
996 ComputeBranchProbability(LatchBlock, LatchExitBlock);
1001 if (ScaleFactor < 1) {
1004 <<
"Ignored user setting for loop-predication-latch-probability-scale: "
1009 const auto LatchProbabilityThreshold = LatchExitProbability * ScaleFactor;
1011 for (
const auto &ExitEdge : ExitEdges) {
1013 ComputeBranchProbability(ExitEdge.first, ExitEdge.second);
1016 if (ExitingBlockProbability > LatchProbabilityThreshold)
1040 if (
BB == Pred->getSingleSuccessor()) {
1048 auto *
Term = Pred->getTerminator();
1054 return cast<BranchInst>(
Term);
1069 for (
BasicBlock *ExitingBB : ExitingBlocks) {
1071 if (isa<SCEVCouldNotCompute>(ExitCount))
1074 "We should only have known counts for exiting blocks that "
1076 ExitCounts.push_back(ExitCount);
1078 if (ExitCounts.size() < 2)
1111 if (ExitingBlocks.empty())
1122 const SCEV *LatchEC = SE->getExitCount(L, Latch);
1123 if (isa<SCEVCouldNotCompute>(LatchEC))
1132 bool ChangedLoop =
false;
1134 for (
auto *ExitingBB : ExitingBlocks) {
1135 if (LI->getLoopFor(ExitingBB) != L)
1138 auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
1161 !SE->isLoopInvariant(MinEC, L) ||
1169 auto *
IP = cast<Instruction>(WidenableBR->getCondition());
1172 IP->moveBefore(WidenableBR);
1174 if (
auto *MUD = MSSAU->getMemorySSA()->getMemoryAccess(
IP))
1175 MSSAU->moveToPlace(MUD, WidenableBR->getParent(),
1180 bool InvalidateLoop =
false;
1181 Value *MinECV =
nullptr;
1182 for (
BasicBlock *ExitingBB : ExitingBlocks) {
1186 if (LI->getLoopFor(ExitingBB) != L)
1190 auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
1198 const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
1199 if (isa<SCEVCouldNotCompute>(ExitCount) ||
1219 MinECV =
Rewriter.expandCodeFor(MinEC);
1223 ECV =
B.CreateZExt(ECV, WiderTy);
1224 RHS =
B.CreateZExt(
RHS, WiderTy);
1226 assert(!Latch || DT->dominates(ExitingBB, Latch));
1231 NewCond =
B.CreateFreeze(NewCond);
1237 InvalidateLoop =
true;
1251 bool LoopPredication::runOnLoop(
Loop *
Loop) {
1262 bool HasIntrinsicGuards = GuardDecl && !GuardDecl->use_empty();
1263 auto *WCDecl =
M->getFunction(
1265 bool HasWidenableConditions =
1267 if (!HasIntrinsicGuards && !HasWidenableConditions)
1270 DL = &
M->getDataLayout();
1276 auto LatchCheckOpt = parseLoopLatchICmp();
1279 LatchCheck = *LatchCheckOpt;
1284 if (!isLoopProfitableToPredicate()) {
1292 for (
const auto BB : L->
blocks()) {
1295 Guards.push_back(cast<IntrinsicInst>(&
I));
1298 GuardsAsWidenableBranches.push_back(
1299 cast<BranchInst>(
BB->getTerminator()));
1303 bool Changed =
false;
1304 for (
auto *Guard : Guards)
1305 Changed |= widenGuardConditions(Guard, Expander);
1306 for (
auto *Guard : GuardsAsWidenableBranches)
1307 Changed |= widenWidenableBranchGuardConditions(Guard, Expander);
1308 Changed |= predicateLoopExits(L, Expander);
1311 MSSAU->getMemorySSA()->verifyMemorySSA();