LLVM  9.0.0svn
LoopPredication.cpp
Go to the documentation of this file.
1 //===-- LoopPredication.cpp - Guard based loop predication pass -----------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // The LoopPredication pass tries to convert loop variant range checks to loop
11 // invariant by widening checks across loop iterations. For example, it will
12 // convert
13 //
14 // for (i = 0; i < n; i++) {
15 // guard(i < len);
16 // ...
17 // }
18 //
19 // to
20 //
21 // for (i = 0; i < n; i++) {
22 // guard(n - 1 < len);
23 // ...
24 // }
25 //
26 // After this transformation the condition of the guard is loop invariant, so
27 // loop-unswitch can later unswitch the loop by this condition which basically
28 // predicates the loop by the widened condition:
29 //
30 // if (n - 1 < len)
31 // for (i = 0; i < n; i++) {
32 // ...
33 // }
34 // else
35 // deoptimize
36 //
37 // It's tempting to rely on SCEV here, but it has proven to be problematic.
38 // Generally the facts SCEV provides about the increment step of add
39 // recurrences are true if the backedge of the loop is taken, which implicitly
40 // assumes that the guard doesn't fail. Using these facts to optimize the
41 // guard results in a circular logic where the guard is optimized under the
42 // assumption that it never fails.
43 //
44 // For example, in the loop below the induction variable will be marked as nuw
45 // based on the guard. Based on nuw, the guard predicate will be considered
46 // monotonic. Given a monotonic condition, it's tempting to replace the induction
47 // variable in the condition with its value on the last iteration. But this
48 // transformation is not correct, e.g. e = 4, b = 5 breaks the loop.
49 //
50 // for (int i = b; i != e; i++)
51 // guard(i u< len)
52 //
53 // One of the ways to reason about this problem is to use an inductive proof
54 // approach. Given the loop:
55 //
56 // if (B(0)) {
57 // do {
58 // I = PHI(0, I.INC)
59 // I.INC = I + Step
60 // guard(G(I));
61 // } while (B(I));
62 // }
63 //
64 // where B(x) and G(x) are predicates that map integers to booleans, we want a
65 // loop invariant expression M such that the following program has the same semantics
66 // as the above:
67 //
68 // if (B(0)) {
69 // do {
70 // I = PHI(0, I.INC)
71 // I.INC = I + Step
72 // guard(G(0) && M);
73 // } while (B(I));
74 // }
75 //
76 // One solution for M is M = forall X . (G(X) && B(X)) => G(X + Step)
77 //
78 // Informal proof that the transformation above is correct:
79 //
80 // By the definition of guards we can rewrite the guard condition to:
81 // G(I) && G(0) && M
82 //
83 // Let's prove that for each iteration of the loop:
84 // G(0) && M => G(I)
85 // Hence the condition above can be simplified to G(0) && M.
86 //
87 // Induction base.
88 // G(0) && M => G(0)
89 //
90 // Induction step. Assuming G(0) && M => G(I) on the current iteration, we
91 // show that G(0) && M => G(I + Step) on the subsequent iteration:
92 //
93 // B(I) is true because it's the backedge condition.
94 // G(I) is true because the backedge is guarded by this condition.
95 //
96 // So M = forall X . (G(X) && B(X)) => G(X + Step) implies G(I + Step).
97 //
98 // Note that we can use anything stronger than M, i.e. any condition which
99 // implies M.
100 //
101 // When Step = 1 (i.e. forward iterating loop), the transformation is supported
102 // when:
103 // * The loop has a single latch with the condition of the form:
104 // B(X) = latchStart + X <pred> latchLimit,
105 // where <pred> is u<, u<=, s<, or s<=.
106 // * The guard condition is of the form
107 // G(X) = guardStart + X u< guardLimit
108 //
109 // For the ult latch comparison case M is:
110 // forall X . guardStart + X u< guardLimit && latchStart + X u< latchLimit =>
111 // guardStart + X + 1 u< guardLimit
112 //
113 // The only way the antecedent can be true and the consequent can be false is
114 // if
115 // X == guardLimit - 1 - guardStart
116 // (and guardLimit is non-zero, but we won't use this latter fact).
117 // If X == guardLimit - 1 - guardStart then the second half of the antecedent is
118 // latchStart + guardLimit - 1 - guardStart u< latchLimit
119 // and its negation is
120 // latchStart + guardLimit - 1 - guardStart u>= latchLimit
121 //
122 // In other words, if
123 // latchLimit u<= latchStart + guardLimit - 1 - guardStart
124 // then:
125 // (the ranges below are written in ConstantRange notation, where [A, B) is the
126 // set for (I = A; I != B; I++ /*maywrap*/) yield(I);)
127 //
128 // forall X . guardStart + X u< guardLimit &&
129 // latchStart + X u< latchLimit =>
130 // guardStart + X + 1 u< guardLimit
131 // == forall X . guardStart + X u< guardLimit &&
132 // latchStart + X u< latchStart + guardLimit - 1 - guardStart =>
133 // guardStart + X + 1 u< guardLimit
134 // == forall X . (guardStart + X) in [0, guardLimit) &&
135 // (latchStart + X) in [0, latchStart + guardLimit - 1 - guardStart) =>
136 // (guardStart + X + 1) in [0, guardLimit)
137 // == forall X . X in [-guardStart, guardLimit - guardStart) &&
138 // X in [-latchStart, guardLimit - 1 - guardStart) =>
139 // X in [-guardStart - 1, guardLimit - guardStart - 1)
140 // == true
141 //
142 // So the widened condition is:
143 // guardStart u< guardLimit &&
144 // latchStart + guardLimit - 1 - guardStart u>= latchLimit
145 // Similarly for ule condition the widened condition is:
146 // guardStart u< guardLimit &&
147 // latchStart + guardLimit - 1 - guardStart u> latchLimit
148 // For slt condition the widened condition is:
149 // guardStart u< guardLimit &&
150 // latchStart + guardLimit - 1 - guardStart s>= latchLimit
151 // For sle condition the widened condition is:
152 // guardStart u< guardLimit &&
153 // latchStart + guardLimit - 1 - guardStart s> latchLimit
154 //
155 // When Step = -1 (i.e. reverse iterating loop), the transformation is supported
156 // when:
157 // * The loop has a single latch with the condition of the form:
158 // B(X) = X <pred> latchLimit, where <pred> is u>, u>=, s>, or s>=.
159 // * The guard condition is of the form
160 // G(X) = X - 1 u< guardLimit
161 //
162 // For the ugt latch comparison case M is:
163 // forall X. X-1 u< guardLimit and X u> latchLimit => X-2 u< guardLimit
164 //
165 // The only way the antecedent can be true and the consequent can be false is if
166 // X == 1.
167 // If X == 1 then the second half of the antecedent is
168 // 1 u> latchLimit, and its negation is latchLimit u>= 1.
169 //
170 // So the widened condition is:
171 // guardStart u< guardLimit && latchLimit u>= 1.
172 // Similarly for sgt condition the widened condition is:
173 // guardStart u< guardLimit && latchLimit s>= 1.
174 // For uge condition the widened condition is:
175 // guardStart u< guardLimit && latchLimit u> 1.
176 // For sge condition the widened condition is:
177 // guardStart u< guardLimit && latchLimit s> 1.
178 //===----------------------------------------------------------------------===//
179 
181 #include "llvm/ADT/Statistic.h"
184 #include "llvm/Analysis/LoopInfo.h"
185 #include "llvm/Analysis/LoopPass.h"
189 #include "llvm/IR/Function.h"
190 #include "llvm/IR/GlobalValue.h"
191 #include "llvm/IR/IntrinsicInst.h"
192 #include "llvm/IR/Module.h"
193 #include "llvm/IR/PatternMatch.h"
194 #include "llvm/Pass.h"
195 #include "llvm/Support/Debug.h"
196 #include "llvm/Transforms/Scalar.h"
198 
199 #define DEBUG_TYPE "loop-predication"
200 
201 STATISTIC(TotalConsidered, "Number of guards considered");
202 STATISTIC(TotalWidened, "Number of checks widened");
203 
204 using namespace llvm;
205 
206 static cl::opt<bool> EnableIVTruncation("loop-predication-enable-iv-truncation",
207  cl::Hidden, cl::init(true));
208 
209 static cl::opt<bool> EnableCountDownLoop("loop-predication-enable-count-down-loop",
210  cl::Hidden, cl::init(true));
211 
212 static cl::opt<bool>
213  SkipProfitabilityChecks("loop-predication-skip-profitability-checks",
214  cl::Hidden, cl::init(false));
215 
216 // This is the scale factor for the latch probability. We use this during
217 // profitability analysis to find other exiting blocks that have a much higher
218 // probability of exiting the loop instead of loop exiting via latch.
219 // This value should be greater than 1 for a sane profitability check.
221  "loop-predication-latch-probability-scale", cl::Hidden, cl::init(2.0),
222  cl::desc("scale factor for the latch probability. Value should be greater "
223  "than 1. Lower values are ignored"));
224 
225 namespace {
226 class LoopPredication {
227  /// Represents an induction variable check:
228  /// icmp Pred, <induction variable>, <loop invariant limit>
229  struct LoopICmp {
230  ICmpInst::Predicate Pred;
231  const SCEVAddRecExpr *IV;
232  const SCEV *Limit;
233  LoopICmp(ICmpInst::Predicate Pred, const SCEVAddRecExpr *IV,
234  const SCEV *Limit)
235  : Pred(Pred), IV(IV), Limit(Limit) {}
236  LoopICmp() {}
237  void dump() {
238  dbgs() << "LoopICmp Pred = " << Pred << ", IV = " << *IV
239  << ", Limit = " << *Limit << "\n";
240  }
241  };
242 
243  ScalarEvolution *SE;
245 
246  Loop *L;
247  const DataLayout *DL;
248  BasicBlock *Preheader;
249  LoopICmp LatchCheck;
250 
251  bool isSupportedStep(const SCEV* Step);
252  Optional<LoopICmp> parseLoopICmp(ICmpInst *ICI) {
253  return parseLoopICmp(ICI->getPredicate(), ICI->getOperand(0),
254  ICI->getOperand(1));
255  }
256  Optional<LoopICmp> parseLoopICmp(ICmpInst::Predicate Pred, Value *LHS,
257  Value *RHS);
258 
259  Optional<LoopICmp> parseLoopLatchICmp();
260 
261  bool CanExpand(const SCEV* S);
262  Value *expandCheck(SCEVExpander &Expander, IRBuilder<> &Builder,
263  ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
264  Instruction *InsertAt);
265 
266  Optional<Value *> widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander,
267  IRBuilder<> &Builder);
268  Optional<Value *> widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck,
269  LoopICmp RangeCheck,
270  SCEVExpander &Expander,
271  IRBuilder<> &Builder);
272  Optional<Value *> widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck,
273  LoopICmp RangeCheck,
274  SCEVExpander &Expander,
275  IRBuilder<> &Builder);
276  bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander);
277 
278  // If the loop always exits through another block in the loop, we should not
279  // predicate based on the latch check. For example, the latch check can be a
280  // very coarse grained check and there can be more fine grained exit checks
281  // within the loop. We identify such unprofitable loops through BPI.
282  bool isLoopProfitableToPredicate();
283 
284  // When the IV type is wider than the range operand type, we can still do loop
285  // predication, by generating SCEVs for the range and latch that are of the
286  // same type. We achieve this by generating a SCEV truncate expression for the
287  // latch IV. This is done iff truncation of the IV is a safe operation,
288  // without loss of information.
289  // Another way to achieve this is by generating a wider type SCEV for the
290  // range check operand, however, this needs a more involved check that
291  // operands do not overflow. This can lead to loss of information when the
292  // range operand is of the form: add i32 %offset, %iv. We need to prove that
293  // sext(x + y) is same as sext(x) + sext(y).
294  // This function returns true if we can safely represent the IV type in
295  // the RangeCheckType without loss of information.
296  bool isSafeToTruncateWideIVType(Type *RangeCheckType);
297  // Return the loopLatchCheck corresponding to the RangeCheckType if safe to do
298  // so.
299  Optional<LoopICmp> generateLoopLatchCheck(Type *RangeCheckType);
300 
301 public:
302  LoopPredication(ScalarEvolution *SE, BranchProbabilityInfo *BPI)
303  : SE(SE), BPI(BPI){};
304  bool runOnLoop(Loop *L);
305 };
306 
307 class LoopPredicationLegacyPass : public LoopPass {
308 public:
309  static char ID;
310  LoopPredicationLegacyPass() : LoopPass(ID) {
312  }
313 
314  void getAnalysisUsage(AnalysisUsage &AU) const override {
317  }
318 
319  bool runOnLoop(Loop *L, LPPassManager &LPM) override {
320  if (skipLoop(L))
321  return false;
322  auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
323  BranchProbabilityInfo &BPI =
324  getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
325  LoopPredication LP(SE, &BPI);
326  return LP.runOnLoop(L);
327  }
328 };
329 
331 } // end namespace llvm
332 
333 INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication",
334  "Loop predication", false, false)
337 INITIALIZE_PASS_END(LoopPredicationLegacyPass, "loop-predication",
338  "Loop predication", false, false)
339 
341  return new LoopPredicationLegacyPass();
342 }
343 
346  LPMUpdater &U) {
347  const auto &FAM =
348  AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager();
349  Function *F = L.getHeader()->getParent();
350  auto *BPI = FAM.getCachedResult<BranchProbabilityAnalysis>(*F);
351  LoopPredication LP(&AR.SE, BPI);
352  if (!LP.runOnLoop(&L))
353  return PreservedAnalyses::all();
354 
356 }
357 
359 LoopPredication::parseLoopICmp(ICmpInst::Predicate Pred, Value *LHS,
360  Value *RHS) {
361  const SCEV *LHSS = SE->getSCEV(LHS);
362  if (isa<SCEVCouldNotCompute>(LHSS))
363  return None;
364  const SCEV *RHSS = SE->getSCEV(RHS);
365  if (isa<SCEVCouldNotCompute>(RHSS))
366  return None;
367 
368  // Canonicalize RHS to be loop invariant bound, LHS - a loop computable IV
369  if (SE->isLoopInvariant(LHSS, L)) {
370  std::swap(LHS, RHS);
371  std::swap(LHSS, RHSS);
372  Pred = ICmpInst::getSwappedPredicate(Pred);
373  }
374 
375  const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHSS);
376  if (!AR || AR->getLoop() != L)
377  return None;
378 
379  return LoopICmp(Pred, AR, RHSS);
380 }
381 
382 Value *LoopPredication::expandCheck(SCEVExpander &Expander,
383  IRBuilder<> &Builder,
384  ICmpInst::Predicate Pred, const SCEV *LHS,
385  const SCEV *RHS, Instruction *InsertAt) {
386  // TODO: we can check isLoopEntryGuardedByCond before emitting the check
387 
388  Type *Ty = LHS->getType();
389  assert(Ty == RHS->getType() && "expandCheck operands have different types?");
390 
391  if (SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS))
392  return Builder.getTrue();
393 
394  Value *LHSV = Expander.expandCodeFor(LHS, Ty, InsertAt);
395  Value *RHSV = Expander.expandCodeFor(RHS, Ty, InsertAt);
396  return Builder.CreateICmp(Pred, LHSV, RHSV);
397 }
398 
400 LoopPredication::generateLoopLatchCheck(Type *RangeCheckType) {
401 
402  auto *LatchType = LatchCheck.IV->getType();
403  if (RangeCheckType == LatchType)
404  return LatchCheck;
405  // For now, bail out if latch type is narrower than range type.
406  if (DL->getTypeSizeInBits(LatchType) < DL->getTypeSizeInBits(RangeCheckType))
407  return None;
408  if (!isSafeToTruncateWideIVType(RangeCheckType))
409  return None;
410  // We can now safely identify the truncated version of the IV and limit for
411  // RangeCheckType.
412  LoopICmp NewLatchCheck;
413  NewLatchCheck.Pred = LatchCheck.Pred;
414  NewLatchCheck.IV = dyn_cast<SCEVAddRecExpr>(
415  SE->getTruncateExpr(LatchCheck.IV, RangeCheckType));
416  if (!NewLatchCheck.IV)
417  return None;
418  NewLatchCheck.Limit = SE->getTruncateExpr(LatchCheck.Limit, RangeCheckType);
419  LLVM_DEBUG(dbgs() << "IV of type: " << *LatchType
420  << "can be represented as range check type:"
421  << *RangeCheckType << "\n");
422  LLVM_DEBUG(dbgs() << "LatchCheck.IV: " << *NewLatchCheck.IV << "\n");
423  LLVM_DEBUG(dbgs() << "LatchCheck.Limit: " << *NewLatchCheck.Limit << "\n");
424  return NewLatchCheck;
425 }
426 
427 bool LoopPredication::isSupportedStep(const SCEV* Step) {
428  return Step->isOne() || (Step->isAllOnesValue() && EnableCountDownLoop);
429 }
430 
431 bool LoopPredication::CanExpand(const SCEV* S) {
432  return SE->isLoopInvariant(S, L) && isSafeToExpand(S, *SE);
433 }
434 
435 Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
436  LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck,
437  SCEVExpander &Expander, IRBuilder<> &Builder) {
438  auto *Ty = RangeCheck.IV->getType();
439  // Generate the widened condition for the forward loop:
440  // guardStart u< guardLimit &&
441  // latchLimit <pred> guardLimit - 1 - guardStart + latchStart
442  // where <pred> depends on the latch condition predicate. See the file
443  // header comment for the reasoning.
444  // guardLimit - guardStart + latchStart - 1
445  const SCEV *GuardStart = RangeCheck.IV->getStart();
446  const SCEV *GuardLimit = RangeCheck.Limit;
447  const SCEV *LatchStart = LatchCheck.IV->getStart();
448  const SCEV *LatchLimit = LatchCheck.Limit;
449 
450  // guardLimit - guardStart + latchStart - 1
451  const SCEV *RHS =
452  SE->getAddExpr(SE->getMinusSCEV(GuardLimit, GuardStart),
453  SE->getMinusSCEV(LatchStart, SE->getOne(Ty)));
454  if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) ||
455  !CanExpand(LatchLimit) || !CanExpand(RHS)) {
456  LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
457  return None;
458  }
459  auto LimitCheckPred =
461 
462  LLVM_DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n");
463  LLVM_DEBUG(dbgs() << "RHS: " << *RHS << "\n");
464  LLVM_DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n");
465 
466  Instruction *InsertAt = Preheader->getTerminator();
467  auto *LimitCheck =
468  expandCheck(Expander, Builder, LimitCheckPred, LatchLimit, RHS, InsertAt);
469  auto *FirstIterationCheck = expandCheck(Expander, Builder, RangeCheck.Pred,
470  GuardStart, GuardLimit, InsertAt);
471  return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
472 }
473 
474 Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
475  LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck,
476  SCEVExpander &Expander, IRBuilder<> &Builder) {
477  auto *Ty = RangeCheck.IV->getType();
478  const SCEV *GuardStart = RangeCheck.IV->getStart();
479  const SCEV *GuardLimit = RangeCheck.Limit;
480  const SCEV *LatchLimit = LatchCheck.Limit;
481  if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) ||
482  !CanExpand(LatchLimit)) {
483  LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
484  return None;
485  }
486  // The decrement of the latch check IV should be the same as the
487  // rangeCheckIV.
488  auto *PostDecLatchCheckIV = LatchCheck.IV->getPostIncExpr(*SE);
489  if (RangeCheck.IV != PostDecLatchCheckIV) {
490  LLVM_DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: "
491  << *PostDecLatchCheckIV
492  << " and RangeCheckIV: " << *RangeCheck.IV << "\n");
493  return None;
494  }
495 
496  // Generate the widened condition for CountDownLoop:
497  // guardStart u< guardLimit &&
498  // latchLimit <pred> 1.
499  // See the header comment for reasoning of the checks.
500  Instruction *InsertAt = Preheader->getTerminator();
501  auto LimitCheckPred =
503  auto *FirstIterationCheck = expandCheck(Expander, Builder, ICmpInst::ICMP_ULT,
504  GuardStart, GuardLimit, InsertAt);
505  auto *LimitCheck = expandCheck(Expander, Builder, LimitCheckPred, LatchLimit,
506  SE->getOne(Ty), InsertAt);
507  return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
508 }
509 
510 /// If ICI can be widened to a loop invariant condition emits the loop
511 /// invariant condition in the loop preheader and return it, otherwise
512 /// returns None.
513 Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI,
514  SCEVExpander &Expander,
515  IRBuilder<> &Builder) {
516  LLVM_DEBUG(dbgs() << "Analyzing ICmpInst condition:\n");
517  LLVM_DEBUG(ICI->dump());
518 
519  // parseLoopStructure guarantees that the latch condition is:
520  // ++i <pred> latchLimit, where <pred> is u<, u<=, s<, or s<=.
521  // We are looking for the range checks of the form:
522  // i u< guardLimit
523  auto RangeCheck = parseLoopICmp(ICI);
524  if (!RangeCheck) {
525  LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
526  return None;
527  }
528  LLVM_DEBUG(dbgs() << "Guard check:\n");
529  LLVM_DEBUG(RangeCheck->dump());
530  if (RangeCheck->Pred != ICmpInst::ICMP_ULT) {
531  LLVM_DEBUG(dbgs() << "Unsupported range check predicate("
532  << RangeCheck->Pred << ")!\n");
533  return None;
534  }
535  auto *RangeCheckIV = RangeCheck->IV;
536  if (!RangeCheckIV->isAffine()) {
537  LLVM_DEBUG(dbgs() << "Range check IV is not affine!\n");
538  return None;
539  }
540  auto *Step = RangeCheckIV->getStepRecurrence(*SE);
541  // We cannot just compare with latch IV step because the latch and range IVs
542  // may have different types.
543  if (!isSupportedStep(Step)) {
544  LLVM_DEBUG(dbgs() << "Range check and latch have IVs different steps!\n");
545  return None;
546  }
547  auto *Ty = RangeCheckIV->getType();
548  auto CurrLatchCheckOpt = generateLoopLatchCheck(Ty);
549  if (!CurrLatchCheckOpt) {
550  LLVM_DEBUG(dbgs() << "Failed to generate a loop latch check "
551  "corresponding to range type: "
552  << *Ty << "\n");
553  return None;
554  }
555 
556  LoopICmp CurrLatchCheck = *CurrLatchCheckOpt;
557  // At this point, the range and latch step should have the same type, but need
558  // not have the same value (we support both 1 and -1 steps).
559  assert(Step->getType() ==
560  CurrLatchCheck.IV->getStepRecurrence(*SE)->getType() &&
561  "Range and latch steps should be of same type!");
562  if (Step != CurrLatchCheck.IV->getStepRecurrence(*SE)) {
563  LLVM_DEBUG(dbgs() << "Range and latch have different step values!\n");
564  return None;
565  }
566 
567  if (Step->isOne())
568  return widenICmpRangeCheckIncrementingLoop(CurrLatchCheck, *RangeCheck,
569  Expander, Builder);
570  else {
571  assert(Step->isAllOnesValue() && "Step should be -1!");
572  return widenICmpRangeCheckDecrementingLoop(CurrLatchCheck, *RangeCheck,
573  Expander, Builder);
574  }
575 }
576 
577 bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
578  SCEVExpander &Expander) {
579  LLVM_DEBUG(dbgs() << "Processing guard:\n");
580  LLVM_DEBUG(Guard->dump());
581 
582  TotalConsidered++;
583 
584  IRBuilder<> Builder(cast<Instruction>(Preheader->getTerminator()));
585 
586  // The guard condition is expected to be in form of:
587  // cond1 && cond2 && cond3 ...
588  // Iterate over subconditions looking for icmp conditions which can be
589  // widened across loop iterations. Widening these conditions remember the
590  // resulting list of subconditions in Checks vector.
591  SmallVector<Value *, 4> Worklist(1, Guard->getOperand(0));
592  SmallPtrSet<Value *, 4> Visited;
593 
595 
596  unsigned NumWidened = 0;
597  do {
598  Value *Condition = Worklist.pop_back_val();
599  if (!Visited.insert(Condition).second)
600  continue;
601 
602  Value *LHS, *RHS;
603  using namespace llvm::PatternMatch;
604  if (match(Condition, m_And(m_Value(LHS), m_Value(RHS)))) {
605  Worklist.push_back(LHS);
606  Worklist.push_back(RHS);
607  continue;
608  }
609 
610  if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) {
611  if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, Builder)) {
612  Checks.push_back(NewRangeCheck.getValue());
613  NumWidened++;
614  continue;
615  }
616  }
617 
618  // Save the condition as is if we can't widen it
619  Checks.push_back(Condition);
620  } while (Worklist.size() != 0);
621 
622  if (NumWidened == 0)
623  return false;
624 
625  TotalWidened += NumWidened;
626 
627  // Emit the new guard condition
628  Builder.SetInsertPoint(Guard);
629  Value *LastCheck = nullptr;
630  for (auto *Check : Checks)
631  if (!LastCheck)
632  LastCheck = Check;
633  else
634  LastCheck = Builder.CreateAnd(LastCheck, Check);
635  Guard->setOperand(0, LastCheck);
636 
637  LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n");
638  return true;
639 }
640 
641 Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() {
642  using namespace PatternMatch;
643 
644  BasicBlock *LoopLatch = L->getLoopLatch();
645  if (!LoopLatch) {
646  LLVM_DEBUG(dbgs() << "The loop doesn't have a single latch!\n");
647  return None;
648  }
649 
650  ICmpInst::Predicate Pred;
651  Value *LHS, *RHS;
652  BasicBlock *TrueDest, *FalseDest;
653 
654  if (!match(LoopLatch->getTerminator(),
655  m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)), TrueDest,
656  FalseDest))) {
657  LLVM_DEBUG(dbgs() << "Failed to match the latch terminator!\n");
658  return None;
659  }
660  assert((TrueDest == L->getHeader() || FalseDest == L->getHeader()) &&
661  "One of the latch's destinations must be the header");
662  if (TrueDest != L->getHeader())
663  Pred = ICmpInst::getInversePredicate(Pred);
664 
665  auto Result = parseLoopICmp(Pred, LHS, RHS);
666  if (!Result) {
667  LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
668  return None;
669  }
670 
671  // Check affine first, so if it's not we don't try to compute the step
672  // recurrence.
673  if (!Result->IV->isAffine()) {
674  LLVM_DEBUG(dbgs() << "The induction variable is not affine!\n");
675  return None;
676  }
677 
678  auto *Step = Result->IV->getStepRecurrence(*SE);
679  if (!isSupportedStep(Step)) {
680  LLVM_DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n");
681  return None;
682  }
683 
684  auto IsUnsupportedPredicate = [](const SCEV *Step, ICmpInst::Predicate Pred) {
685  if (Step->isOne()) {
686  return Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_SLT &&
687  Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_SLE;
688  } else {
689  assert(Step->isAllOnesValue() && "Step should be -1!");
690  return Pred != ICmpInst::ICMP_UGT && Pred != ICmpInst::ICMP_SGT &&
691  Pred != ICmpInst::ICMP_UGE && Pred != ICmpInst::ICMP_SGE;
692  }
693  };
694 
695  if (IsUnsupportedPredicate(Step, Result->Pred)) {
696  LLVM_DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred
697  << ")!\n");
698  return None;
699  }
700  return Result;
701 }
702 
703 // Returns true if its safe to truncate the IV to RangeCheckType.
704 bool LoopPredication::isSafeToTruncateWideIVType(Type *RangeCheckType) {
705  if (!EnableIVTruncation)
706  return false;
707  assert(DL->getTypeSizeInBits(LatchCheck.IV->getType()) >
708  DL->getTypeSizeInBits(RangeCheckType) &&
709  "Expected latch check IV type to be larger than range check operand "
710  "type!");
711  // The start and end values of the IV should be known. This is to guarantee
712  // that truncating the wide type will not lose information.
713  auto *Limit = dyn_cast<SCEVConstant>(LatchCheck.Limit);
714  auto *Start = dyn_cast<SCEVConstant>(LatchCheck.IV->getStart());
715  if (!Limit || !Start)
716  return false;
717  // This check makes sure that the IV does not change sign during loop
718  // iterations. Consider latchType = i64, LatchStart = 5, Pred = ICMP_SGE,
719  // LatchEnd = 2, rangeCheckType = i32. If it's not a monotonic predicate, the
720  // IV wraps around, and the truncation of the IV would lose the range of
721  // iterations between 2^32 and 2^64.
722  bool Increasing;
723  if (!SE->isMonotonicPredicate(LatchCheck.IV, LatchCheck.Pred, Increasing))
724  return false;
725  // The active bits should be less than the bits in the RangeCheckType. This
726  // guarantees that truncating the latch check to RangeCheckType is a safe
727  // operation.
728  auto RangeCheckTypeBitSize = DL->getTypeSizeInBits(RangeCheckType);
729  return Start->getAPInt().getActiveBits() < RangeCheckTypeBitSize &&
730  Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize;
731 }
732 
733 bool LoopPredication::isLoopProfitableToPredicate() {
734  if (SkipProfitabilityChecks || !BPI)
735  return true;
736 
738  L->getExitEdges(ExitEdges);
739  // If there is only one exiting edge in the loop, it is always profitable to
740  // predicate the loop.
741  if (ExitEdges.size() == 1)
742  return true;
743 
744  // Calculate the exiting probabilities of all exiting edges from the loop,
745  // starting with the LatchExitProbability.
746  // Heuristic for profitability: If any of the exiting blocks' probability of
747  // exiting the loop is larger than exiting through the latch block, it's not
748  // profitable to predicate the loop.
749  auto *LatchBlock = L->getLoopLatch();
750  assert(LatchBlock && "Should have a single latch at this point!");
751  auto *LatchTerm = LatchBlock->getTerminator();
752  assert(LatchTerm->getNumSuccessors() == 2 &&
753  "expected to be an exiting block with 2 succs!");
754  unsigned LatchBrExitIdx =
755  LatchTerm->getSuccessor(0) == L->getHeader() ? 1 : 0;
756  BranchProbability LatchExitProbability =
757  BPI->getEdgeProbability(LatchBlock, LatchBrExitIdx);
758 
759  // Protect against degenerate inputs provided by the user. Providing a value
760  // less than one, can invert the definition of profitable loop predication.
761  float ScaleFactor = LatchExitProbabilityScale;
762  if (ScaleFactor < 1) {
763  LLVM_DEBUG(
764  dbgs()
765  << "Ignored user setting for loop-predication-latch-probability-scale: "
766  << LatchExitProbabilityScale << "\n");
767  LLVM_DEBUG(dbgs() << "The value is set to 1.0\n");
768  ScaleFactor = 1.0;
769  }
770  const auto LatchProbabilityThreshold =
771  LatchExitProbability * ScaleFactor;
772 
773  for (const auto &ExitEdge : ExitEdges) {
774  BranchProbability ExitingBlockProbability =
775  BPI->getEdgeProbability(ExitEdge.first, ExitEdge.second);
776  // Some exiting edge has higher probability than the latch exiting edge.
777  // No longer profitable to predicate.
778  if (ExitingBlockProbability > LatchProbabilityThreshold)
779  return false;
780  }
781  // Using BPI, we have concluded that the most probable way to exit from the
782  // loop is through the latch (or there's no profile information and all
783  // exits are equally likely).
784  return true;
785 }
786 
787 bool LoopPredication::runOnLoop(Loop *Loop) {
788  L = Loop;
789 
790  LLVM_DEBUG(dbgs() << "Analyzing ");
791  LLVM_DEBUG(L->dump());
792 
793  Module *M = L->getHeader()->getModule();
794 
795  // There is nothing to do if the module doesn't use guards
796  auto *GuardDecl =
797  M->getFunction(Intrinsic::getName(Intrinsic::experimental_guard));
798  if (!GuardDecl || GuardDecl->use_empty())
799  return false;
800 
801  DL = &M->getDataLayout();
802 
803  Preheader = L->getLoopPreheader();
804  if (!Preheader)
805  return false;
806 
807  auto LatchCheckOpt = parseLoopLatchICmp();
808  if (!LatchCheckOpt)
809  return false;
810  LatchCheck = *LatchCheckOpt;
811 
812  LLVM_DEBUG(dbgs() << "Latch check:\n");
813  LLVM_DEBUG(LatchCheck.dump());
814 
815  if (!isLoopProfitableToPredicate()) {
816  LLVM_DEBUG(dbgs() << "Loop not profitable to predicate!\n");
817  return false;
818  }
819  // Collect all the guards into a vector and process later, so as not
820  // to invalidate the instruction iterator.
822  for (const auto BB : L->blocks())
823  for (auto &I : *BB)
824  if (isGuard(&I))
825  Guards.push_back(cast<IntrinsicInst>(&I));
826 
827  if (Guards.empty())
828  return false;
829 
830  SCEVExpander Expander(*SE, *DL, "loop-predication");
831 
832  bool Changed = false;
833  for (auto *Guard : Guards)
834  Changed |= widenGuardConditions(Guard, Expander);
835 
836  return Changed;
837 }
Pass interface - Implemented by all &#39;passes&#39;.
Definition: Pass.h:81
static bool Check(DecodeStatus &Out, DecodeStatus In)
BinaryOp_match< LHS, RHS, Instruction::And > m_And(const LHS &L, const RHS &R)
Definition: PatternMatch.h:749
A parsed version of the target data layout string in and methods for querying it. ...
Definition: DataLayout.h:111
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:71
Value * CreateICmp(CmpInst::Predicate P, Value *LHS, Value *RHS, const Twine &Name="")
Definition: IRBuilder.h:1949
BlockT * getLoopLatch() const
If there is a single latch block for this loop, return it.
Definition: LoopInfoImpl.h:225
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
PreservedAnalyses getLoopPassPreservedAnalyses()
Returns the minimum set of Analyses that all loop passes must preserve.
STATISTIC(TotalConsidered, "Number of guards considered")
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:770
This class represents lattice values for constants.
Definition: AllocatorList.h:24
loop predication
A Module instance is used to store all the information related to an LLVM module. ...
Definition: Module.h:65
The main scalar evolution driver.
BlockT * getLoopPreheader() const
If there is a preheader for this loop, return it.
Definition: LoopInfoImpl.h:174
unsigned less or equal
Definition: InstrTypes.h:672
unsigned less than
Definition: InstrTypes.h:671
bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
bool isMonotonicPredicate(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred, bool &Increasing)
Return true if, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or...
The adaptor from a function pass to a loop pass computes these analyses and makes them available to t...
F(f)
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.cpp:138
INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication", "Loop predication", false, false) INITIALIZE_PASS_END(LoopPredicationLegacyPass
void dump() const
Support for debugging, callable in GDB: V->dump()
Definition: AsmWriter.cpp:4298
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:48
AnalysisUsage & addRequired()
const Module * getModule() const
Return the module owning the function this basic block belongs to, or nullptr if the function does no...
Definition: BasicBlock.cpp:134
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:51
StringRef getName(ID id)
Return the LLVM name for an intrinsic, such as "llvm.ppc.altivec.lvx".
Definition: Function.cpp:626
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE, etc.
Definition: InstrTypes.h:745
static cl::opt< float > LatchExitProbabilityScale("loop-predication-latch-probability-scale", cl::Hidden, cl::init(2.0), cl::desc("scale factor for the latch probability. Value should be greater " "than 1. Lower values are ignored"))
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:743
BlockT * getHeader() const
Definition: LoopInfo.h:100
Analysis pass which computes BranchProbabilityInfo.
This node represents a polynomial recurrence on the trip count of the specified loop.
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block...
Definition: IRBuilder.h:127
Legacy analysis pass which computes BranchProbabilityInfo.
Value * getOperand(unsigned i) const
Definition: User.h:170
static cl::opt< bool > EnableIVTruncation("loop-predication-enable-iv-truncation", cl::Hidden, cl::init(true))
bool isLoopEntryGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS. ...
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:423
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
A set of analyses that are preserved following a run of a transformation pass.
Definition: PassManager.h:154
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
void dump() const
Definition: LoopInfo.cpp:401
LLVM Basic Block Representation.
Definition: BasicBlock.h:58
The instances of the Type class are immutable: once they are created, they are never changed...
Definition: Type.h:46
ConstantInt * getTrue()
Get the constant value for i1 true.
Definition: IRBuilder.h:287
const SCEV * getAddExpr(SmallVectorImpl< const SCEV *> &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
brc_match< Cond_t > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
Represent the analysis usage information of a pass.
Pass * createLoopPredicationPass()
This instruction compares its operands according to the predicate given to the constructor.
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:646
Value * expandCodeFor(const SCEV *SH, Type *Ty, Instruction *I)
Insert code to directly compute the specified SCEV expression into the program.
const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS. Minus is represented in SCEV as A+B*-1.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: PassManager.h:160
size_t size() const
Definition: SmallVector.h:53
INITIALIZE_PASS_END(RegBankSelect, DEBUG_TYPE, "Assign register bank of generic virtual registers", false, false) RegBankSelect
signed greater than
Definition: InstrTypes.h:673
BranchProbability getEdgeProbability(const BasicBlock *Src, unsigned IndexInSuccessors) const
Get an edge's probability, relative to other out-edges of the Src.
void getExitEdges(SmallVectorImpl< Edge > &ExitEdges) const
Return all pairs of (inside_block,outside_block).
Definition: LoopInfoImpl.h:155
This class provides an interface for updating the loop pass manager based on mutations to the loop ne...
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements...
Definition: SmallPtrSet.h:418
bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
Type * getType() const
Return the LLVM type of this SCEV expression.
const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty)
An analysis over an "inner" IR unit that provides access to an analysis manager over a "outer" IR uni...
Definition: PassManager.h:1154
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small...
Definition: SmallVector.h:847
Module.h This file contains the declarations for the Module class.
signed less than
Definition: InstrTypes.h:675
Predicate getFlippedStrictnessPredicate() const
For predicate of kind "is X or equal to 0" returns the predicate "is X".
Definition: InstrTypes.h:776
bool isGuard(const User *U)
Returns true iff U has semantics of a guard.
Definition: GuardUtils.cpp:18
void setOperand(unsigned i, Value *Val)
Definition: User.h:175
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:133
Function * getFunction(StringRef Name) const
Look up the specified function in the module symbol table.
Definition: Module.cpp:176
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:941
signed less or equal
Definition: InstrTypes.h:676
static cl::opt< bool > EnableCountDownLoop("loop-predication-enable-count-down-loop", cl::Hidden, cl::init(true))
This class uses information about analyze scalars to rewrite expressions in canonical form...
iterator insert(iterator I, T &&Elt)
Definition: SmallVector.h:478
uint64_t getTypeSizeInBits(Type *Ty) const
Size examples:
Definition: DataLayout.h:568
Predicate getPredicate() const
Return the predicate for this instruction.
Definition: InstrTypes.h:721
Analysis providing branch probability information.
This class represents an analyzed expression in the program.
LLVM_NODISCARD bool empty() const
Definition: SmallVector.h:56
unsigned greater or equal
Definition: InstrTypes.h:670
Represents a single loop in the control flow graph.
Definition: LoopInfo.h:465
#define I(x, y, z)
Definition: MD5.cpp:58
void getLoopAnalysisUsage(AnalysisUsage &AU)
Helper to consistently add the set of standard passes to a loop pass's AnalysisUsage.
Definition: LoopUtils.cpp:132
LLVM_NODISCARD std::enable_if<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
Definition: Casting.h:323
static cl::opt< bool > SkipProfitabilityChecks("loop-predication-skip-profitability-checks", cl::Hidden, cl::init(false))
Value * CreateAnd(Value *LHS, Value *RHS, const Twine &Name="")
Definition: IRBuilder.h:1164
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
void initializeLoopPredicationLegacyPassPass(PassRegistry &)
PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &U)
bool isOne() const
Return true if the expression is a constant one.
LLVM Value Representation.
Definition: Value.h:73
const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
unsigned greater than
Definition: InstrTypes.h:669
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition: InstrTypes.h:761
A container for analyses that lazily runs them and caches their results.
#define LLVM_DEBUG(X)
Definition: Debug.h:123
bool isSafeToExpand(const SCEV *S, ScalarEvolution &SE)
Return true if the given expression is safe to expand in the sense that all materialized values are s...
iterator_range< block_iterator > blocks() const
Definition: LoopInfo.h:156
signed greater or equal
Definition: InstrTypes.h:674
A wrapper class for inspecting calls to intrinsic functions.
Definition: IntrinsicInst.h:44
This class represents a constant integer value.
CmpClass_match< LHS, RHS, ICmpInst, ICmpInst::Predicate > m_ICmp(ICmpInst::Predicate &Pred, const LHS &L, const RHS &R)