LLVM 17.0.0git
LoopPredication.cpp
1//===-- LoopPredication.cpp - Guard based loop predication pass -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The LoopPredication pass tries to convert loop variant range checks to loop
10// invariant by widening checks across loop iterations. For example, it will
11// convert
12//
13// for (i = 0; i < n; i++) {
14// guard(i < len);
15// ...
16// }
17//
18// to
19//
20// for (i = 0; i < n; i++) {
21// guard(n - 1 < len);
22// ...
23// }
24//
25// After this transformation the condition of the guard is loop invariant, so
26// loop-unswitch can later unswitch the loop by this condition which basically
27// predicates the loop by the widened condition:
28//
29// if (n - 1 < len)
30// for (i = 0; i < n; i++) {
31// ...
32// }
33// else
34// deoptimize
35//
36// It's tempting to rely on SCEV here, but it has proven to be problematic.
37// Generally the facts SCEV provides about the increment step of add
38// recurrences are true if the backedge of the loop is taken, which implicitly
39// assumes that the guard doesn't fail. Using these facts to optimize the
40// guard results in a circular logic where the guard is optimized under the
41// assumption that it never fails.
42//
43// For example, in the loop below the induction variable will be marked as nuw
44// based on the guard. Based on nuw, the guard predicate will be considered
45// monotonic. Given a monotonic condition it's tempting to replace the induction
46// variable in the condition with its value on the last iteration. But this
47// transformation is not correct, e.g. e = 4, b = 5 breaks the loop.
48//
49// for (int i = b; i != e; i++)
50// guard(i u< len)
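// A minimal sketch of the counterexample above, assuming 8-bit unsigned values
// so that the wraparound is short enough to trace. The helper names and the
// choice len = 10 in the note below are illustrative, not part of this pass.

```cpp
#include <cassert>
#include <cstdint>

// Runs `for (i = b; i != e; i++) guard(i u< len)` on 8-bit unsigned values and
// returns true if every guard passes. uint8_t arithmetic wraps modulo 256.
bool originalGuardsPass(uint8_t b, uint8_t e, uint8_t len) {
  for (uint8_t i = b; i != e; i++)
    if (!(i < len))
      return false; // the guard would deoptimize here
  return true;
}

// The tempting but incorrect rewrite: check the guard only for the value the
// induction variable has on the last iteration, which is e - 1.
bool naiveWidenedGuardPasses(uint8_t e, uint8_t len) {
  return static_cast<uint8_t>(e - 1) < len;
}
```

// With b = 5, e = 4 and len = 10 the original loop fails its guard at i = 10,
// while the naive rewrite only checks 3 u< 10 and would let the loop run.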
51//
52// One of the ways to reason about this problem is to use an inductive proof
53// approach. Given the loop:
54//
55// if (B(0)) {
56// do {
57// I = PHI(0, I.INC)
58// I.INC = I + Step
59// guard(G(I));
60// } while (B(I));
61// }
62//
63// where B(x) and G(x) are predicates that map integers to booleans, we want a
64// loop invariant expression M such that the following program has the same
65// semantics as the above:
66//
67// if (B(0)) {
68// do {
69// I = PHI(0, I.INC)
70// I.INC = I + Step
71// guard(G(0) && M);
72// } while (B(I));
73// }
74//
75// One solution for M is M = forall X . (G(X) && B(X)) => G(X + Step)
76//
77// Informal proof that the transformation above is correct:
78//
79// By the definition of guards we can rewrite the guard condition to:
80// G(I) && G(0) && M
81//
82// Let's prove that for each iteration of the loop:
83// G(0) && M => G(I)
84// And the condition above can be simplified to G(0) && M.
85//
86// Induction base.
87// G(0) && M => G(0)
88//
89// Induction step. Assuming G(0) && M => G(I) on the subsequent
90// iteration:
91//
92// B(I) is true because it's the backedge condition.
93// G(I) is true because the backedge is guarded by this condition.
94//
95// So M = forall X . (G(X) && B(X)) => G(X + Step) implies G(I + Step).
96//
97// Note that we can use anything stronger than M, i.e. any condition which
98// implies M.
99//
100// When Step = 1 (i.e. forward iterating loop), the transformation is supported
101// when:
102// * The loop has a single latch with the condition of the form:
103// B(X) = latchStart + X <pred> latchLimit,
104// where <pred> is u<, u<=, s<, or s<=.
105// * The guard condition is of the form
106// G(X) = guardStart + X u< guardLimit
107//
108// For the ult latch comparison case M is:
109// forall X . guardStart + X u< guardLimit && latchStart + X u< latchLimit =>
110// guardStart + X + 1 u< guardLimit
111//
112// The only way the antecedent can be true and the consequent can be false is
113// if
114// X == guardLimit - 1 - guardStart
115// (and guardLimit is non-zero, but we won't use this latter fact).
116// If X == guardLimit - 1 - guardStart then the second half of the antecedent is
117// latchStart + guardLimit - 1 - guardStart u< latchLimit
118// and its negation is
119// latchStart + guardLimit - 1 - guardStart u>= latchLimit
120//
121// In other words, if
122// latchLimit u<= latchStart + guardLimit - 1 - guardStart
123// then:
124// (the ranges below are written in ConstantRange notation, where [A, B) is the
125// set for (I = A; I != B; I++ /*maywrap*/) yield(I);)
126//
127// forall X . guardStart + X u< guardLimit &&
128// latchStart + X u< latchLimit =>
129// guardStart + X + 1 u< guardLimit
130// == forall X . guardStart + X u< guardLimit &&
131// latchStart + X u< latchStart + guardLimit - 1 - guardStart =>
132// guardStart + X + 1 u< guardLimit
133// == forall X . (guardStart + X) in [0, guardLimit) &&
134// (latchStart + X) in [0, latchStart + guardLimit - 1 - guardStart) =>
135// (guardStart + X + 1) in [0, guardLimit)
136// == forall X . X in [-guardStart, guardLimit - guardStart) &&
137// X in [-latchStart, guardLimit - 1 - guardStart) =>
138// X in [-guardStart - 1, guardLimit - guardStart - 1)
139// == true
140//
141// So the widened condition is:
142// guardStart u< guardLimit &&
143// latchStart + guardLimit - 1 - guardStart u>= latchLimit
144// Similarly for ule condition the widened condition is:
145// guardStart u< guardLimit &&
146// latchStart + guardLimit - 1 - guardStart u> latchLimit
147// For slt condition the widened condition is:
148// guardStart u< guardLimit &&
149// latchStart + guardLimit - 1 - guardStart s>= latchLimit
150// For sle condition the widened condition is:
151// guardStart u< guardLimit &&
152// latchStart + guardLimit - 1 - guardStart s> latchLimit
153//
154// When Step = -1 (i.e. reverse iterating loop), the transformation is supported
155// when:
156// * The loop has a single latch with the condition of the form:
157// B(X) = X <pred> latchLimit, where <pred> is u>, u>=, s>, or s>=.
158// * The guard condition is of the form
159// G(X) = X - 1 u< guardLimit
160//
161// For the ugt latch comparison case M is:
162// forall X. X-1 u< guardLimit and X u> latchLimit => X-2 u< guardLimit
163//
164// The only way the antecedent can be true and the consequent can be false is if
165// X == 1.
166// If X == 1 then the second half of the antecedent is
167// 1 u> latchLimit, and its negation is latchLimit u>= 1.
168//
169// So the widened condition is:
170// guardStart u< guardLimit && latchLimit u>= 1.
171// Similarly for sgt condition the widened condition is:
172// guardStart u< guardLimit && latchLimit s>= 1.
173// For uge condition the widened condition is:
174// guardStart u< guardLimit && latchLimit u> 1.
175// For sge condition the widened condition is:
176// guardStart u< guardLimit && latchLimit s> 1.
177//===----------------------------------------------------------------------===//
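// The widened conditions derived in the header comment can be sanity-checked by
// brute force on a small bit width. The sketch below is a standalone model, not
// code from this pass: it assumes a 4-bit unsigned type, covers only the ult
// and ugt latch cases, assumes the reverse loop has the same do-while shape as
// the forward one, and takes guardStart = latchStart - 1 for the reverse loop
// (the range check IV being the post-decrement of the latch IV). All function
// names are hypothetical.

```cpp
#include <cassert>

// All values are reduced modulo 2^4 to model a 4-bit unsigned integer type.
constexpr unsigned Mask = 0xF;

// Runs the do-while loop from the header comment starting at Start and returns
// true if some guard G fails before the latch condition B ends the loop.
template <typename GuardFn, typename LatchFn>
bool someGuardFails(unsigned Start, unsigned Step, GuardFn G, LatchFn B) {
  if (!B(Start))
    return false; // the loop is never entered
  unsigned X = Start;
  // Sixteen iterations visit every 4-bit value, and both latch forms used
  // below are false for at least one of them, so the loop always returns.
  for (unsigned N = 0; N <= Mask; ++N) {
    if (!G(X))
      return true;
    if (!B(X))
      return false;
    X = (X + Step) & Mask;
  }
  return false;
}

// Forward loop: G(X) = GS + X u< GL, B(X) = LS + X u< LL, Step = 1.
bool forwardGuardFails(unsigned GS, unsigned GL, unsigned LS, unsigned LL) {
  return someGuardFails(
      0, 1, [=](unsigned X) { return ((GS + X) & Mask) < GL; },
      [=](unsigned X) { return ((LS + X) & Mask) < LL; });
}

// Widened condition for the ult latch case.
bool forwardWidened(unsigned GS, unsigned GL, unsigned LS, unsigned LL) {
  return GS < GL && ((LS + GL - 1 - GS) & Mask) >= LL;
}

// Reverse loop: G(X) = X - 1 u< GL, B(X) = X u> LL, Step = -1 (15 mod 16).
bool reverseGuardFails(unsigned GL, unsigned LS, unsigned LL) {
  return someGuardFails(
      LS, Mask, [=](unsigned X) { return ((X - 1) & Mask) < GL; },
      [=](unsigned X) { return X > LL; });
}

// Widened condition for the ugt latch case, with guardStart = latchStart - 1.
bool reverseWidened(unsigned GL, unsigned LS, unsigned LL) {
  return ((LS - 1) & Mask) < GL && LL >= 1;
}

// Counts inputs for which the widened condition holds but a guard still fails.
unsigned countForwardViolations() {
  unsigned Bad = 0;
  for (unsigned GS = 0; GS <= Mask; ++GS)
    for (unsigned GL = 0; GL <= Mask; ++GL)
      for (unsigned LS = 0; LS <= Mask; ++LS)
        for (unsigned LL = 0; LL <= Mask; ++LL)
          if (forwardWidened(GS, GL, LS, LL) &&
              forwardGuardFails(GS, GL, LS, LL))
            ++Bad;
  return Bad;
}

unsigned countReverseViolations() {
  unsigned Bad = 0;
  for (unsigned GL = 0; GL <= Mask; ++GL)
    for (unsigned LS = 0; LS <= Mask; ++LS)
      for (unsigned LL = 0; LL <= Mask; ++LL)
        if (reverseWidened(GL, LS, LL) && reverseGuardFails(GL, LS, LL))
          ++Bad;
  return Bad;
}
```

// Both counters are expected to be zero: whenever the widened condition holds,
// no guard in the modelled loop fails. The converse is not claimed, since the
// widened condition is allowed to be stronger than necessary.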
178
180#include "llvm/ADT/Statistic.h"
190#include "llvm/IR/Function.h"
192#include "llvm/IR/Module.h"
193#include "llvm/IR/PatternMatch.h"
196#include "llvm/Pass.h"
198#include "llvm/Support/Debug.h"
204#include <optional>
205
206#define DEBUG_TYPE "loop-predication"
207
208STATISTIC(TotalConsidered, "Number of guards considered");
209STATISTIC(TotalWidened, "Number of checks widened");
210
211using namespace llvm;
212
213static cl::opt<bool> EnableIVTruncation("loop-predication-enable-iv-truncation",
214 cl::Hidden, cl::init(true));
215
216static cl::opt<bool> EnableCountDownLoop("loop-predication-enable-count-down-loop",
217 cl::Hidden, cl::init(true));
218
219static cl::opt<bool>
220 SkipProfitabilityChecks("loop-predication-skip-profitability-checks",
221 cl::Hidden, cl::init(false));
222
223// This is the scale factor for the latch probability. We use this during
224// profitability analysis to find other exiting blocks that have a much higher
225// probability of exiting the loop instead of loop exiting via latch.
226// This value should be greater than 1 for a sane profitability check.
228 "loop-predication-latch-probability-scale", cl::Hidden, cl::init(2.0),
229 cl::desc("scale factor for the latch probability. Value should be greater "
230 "than 1. Lower values are ignored"));
231
233 "loop-predication-predicate-widenable-branches-to-deopt", cl::Hidden,
234 cl::desc("Whether or not we should predicate guards "
235 "expressed as widenable branches to deoptimize blocks"),
236 cl::init(true));
237
239 "loop-predication-insert-assumes-of-predicated-guards-conditions",
241 cl::desc("Whether or not we should insert assumes of conditions of "
242 "predicated guards"),
243 cl::init(true));
244
245namespace {
246/// Represents an induction variable check:
247/// icmp Pred, <induction variable>, <loop invariant limit>
248struct LoopICmp {
249 ICmpInst::Predicate Pred;
250 const SCEVAddRecExpr *IV;
251 const SCEV *Limit;
252 LoopICmp(ICmpInst::Predicate Pred, const SCEVAddRecExpr *IV,
253 const SCEV *Limit)
254 : Pred(Pred), IV(IV), Limit(Limit) {}
255 LoopICmp() = default;
256 void dump() {
257 dbgs() << "LoopICmp Pred = " << Pred << ", IV = " << *IV
258 << ", Limit = " << *Limit << "\n";
259 }
260};
261
262class LoopPredication {
263 AliasAnalysis *AA;
264 DominatorTree *DT;
265 ScalarEvolution *SE;
266 LoopInfo *LI;
267 MemorySSAUpdater *MSSAU;
268
269 Loop *L;
270 const DataLayout *DL;
271 BasicBlock *Preheader;
272 LoopICmp LatchCheck;
273
274 bool isSupportedStep(const SCEV* Step);
275 std::optional<LoopICmp> parseLoopICmp(ICmpInst *ICI);
276 std::optional<LoopICmp> parseLoopLatchICmp();
277
278 /// Return an insertion point suitable for inserting a safe to speculate
279 /// instruction whose only user will be 'User' which has operands 'Ops'. A
280 /// trivial result would be at the User itself, but we try to return a
281 /// loop invariant location if possible.
282 Instruction *findInsertPt(Instruction *User, ArrayRef<Value*> Ops);
283 /// Same as above, *except* that this uses the SCEV definition of invariant
284 /// which is that an expression *can be made* invariant via SCEVExpander.
285 /// Thus, this version is only suitable for finding an insert point to be
286 /// passed to SCEVExpander!
287 Instruction *findInsertPt(const SCEVExpander &Expander, Instruction *User,
288 ArrayRef<const SCEV *> Ops);
289
290 /// Return true if the value is known to produce a single fixed value across
291 /// all iterations on which it executes. Note that this does not imply
292 /// speculation safety. That must be established separately.
293 bool isLoopInvariantValue(const SCEV* S);
294
295 Value *expandCheck(SCEVExpander &Expander, Instruction *Guard,
296 ICmpInst::Predicate Pred, const SCEV *LHS,
297 const SCEV *RHS);
298
299 std::optional<Value *> widenICmpRangeCheck(ICmpInst *ICI,
300 SCEVExpander &Expander,
301 Instruction *Guard);
302 std::optional<Value *>
303 widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck,
304 SCEVExpander &Expander,
305 Instruction *Guard);
306 std::optional<Value *>
307 widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck,
308 SCEVExpander &Expander,
309 Instruction *Guard);
310 unsigned collectChecks(SmallVectorImpl<Value *> &Checks, Value *Condition,
311 SCEVExpander &Expander, Instruction *Guard);
312 bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander);
313 bool widenWidenableBranchGuardConditions(BranchInst *Guard, SCEVExpander &Expander);
314 // If the loop always exits through another block in the loop, we should not
315 // predicate based on the latch check. For example, the latch check can be a
316 // very coarse grained check and there can be more fine grained exit checks
317 // within the loop.
318 bool isLoopProfitableToPredicate();
319
320 bool predicateLoopExits(Loop *L, SCEVExpander &Rewriter);
321
322public:
323 LoopPredication(AliasAnalysis *AA, DominatorTree *DT, ScalarEvolution *SE,
324 LoopInfo *LI, MemorySSAUpdater *MSSAU)
325 : AA(AA), DT(DT), SE(SE), LI(LI), MSSAU(MSSAU){};
326 bool runOnLoop(Loop *L);
327};
328
329class LoopPredicationLegacyPass : public LoopPass {
330public:
331 static char ID;
332 LoopPredicationLegacyPass() : LoopPass(ID) {
334 }
335
336 void getAnalysisUsage(AnalysisUsage &AU) const override {
340 }
341
342 bool runOnLoop(Loop *L, LPPassManager &LPM) override {
343 if (skipLoop(L))
344 return false;
345 auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
346 auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
347 auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
348 auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>();
349 std::unique_ptr<MemorySSAUpdater> MSSAU;
350 if (MSSAWP)
351 MSSAU = std::make_unique<MemorySSAUpdater>(&MSSAWP->getMSSA());
352 auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
353 LoopPredication LP(AA, DT, SE, LI, MSSAU ? MSSAU.get() : nullptr);
354 return LP.runOnLoop(L);
355 }
356};
357
358char LoopPredicationLegacyPass::ID = 0;
359} // end namespace
360
361INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication",
362 "Loop predication", false, false)
365INITIALIZE_PASS_END(LoopPredicationLegacyPass, "loop-predication",
367
369 return new LoopPredicationLegacyPass();
370}
371
374 LPMUpdater &U) {
375 std::unique_ptr<MemorySSAUpdater> MSSAU;
376 if (AR.MSSA)
377 MSSAU = std::make_unique<MemorySSAUpdater>(AR.MSSA);
378 LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI,
379 MSSAU ? MSSAU.get() : nullptr);
380 if (!LP.runOnLoop(&L))
381 return PreservedAnalyses::all();
382
384 if (AR.MSSA)
385 PA.preserve<MemorySSAAnalysis>();
386 return PA;
387}
388
389std::optional<LoopICmp> LoopPredication::parseLoopICmp(ICmpInst *ICI) {
390 auto Pred = ICI->getPredicate();
391 auto *LHS = ICI->getOperand(0);
392 auto *RHS = ICI->getOperand(1);
393
394 const SCEV *LHSS = SE->getSCEV(LHS);
395 if (isa<SCEVCouldNotCompute>(LHSS))
396 return std::nullopt;
397 const SCEV *RHSS = SE->getSCEV(RHS);
398 if (isa<SCEVCouldNotCompute>(RHSS))
399 return std::nullopt;
400
401 // Canonicalize RHS to be loop invariant bound, LHS - a loop computable IV
402 if (SE->isLoopInvariant(LHSS, L)) {
403 std::swap(LHS, RHS);
404 std::swap(LHSS, RHSS);
406 }
407
408 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHSS);
409 if (!AR || AR->getLoop() != L)
410 return std::nullopt;
411
412 return LoopICmp(Pred, AR, RHSS);
413}
414
415Value *LoopPredication::expandCheck(SCEVExpander &Expander,
416 Instruction *Guard,
417 ICmpInst::Predicate Pred, const SCEV *LHS,
418 const SCEV *RHS) {
419 Type *Ty = LHS->getType();
420 assert(Ty == RHS->getType() && "expandCheck operands have different types?");
421
422 if (SE->isLoopInvariant(LHS, L) && SE->isLoopInvariant(RHS, L)) {
423 IRBuilder<> Builder(Guard);
424 if (SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS))
425 return Builder.getTrue();
426 if (SE->isLoopEntryGuardedByCond(L, ICmpInst::getInversePredicate(Pred),
427 LHS, RHS))
428 return Builder.getFalse();
429 }
430
431 Value *LHSV =
432 Expander.expandCodeFor(LHS, Ty, findInsertPt(Expander, Guard, {LHS}));
433 Value *RHSV =
434 Expander.expandCodeFor(RHS, Ty, findInsertPt(Expander, Guard, {RHS}));
435 IRBuilder<> Builder(findInsertPt(Guard, {LHSV, RHSV}));
436 return Builder.CreateICmp(Pred, LHSV, RHSV);
437}
438
439// Returns true if it's safe to truncate the IV to RangeCheckType.
440// When the IV type is wider than the range operand type, we can still do loop
441// predication, by generating SCEVs for the range and latch that are of the
442// same type. We achieve this by generating a SCEV truncate expression for the
443// latch IV. This is done iff truncation of the IV is a safe operation,
444// without loss of information.
445// Another way to achieve this is by generating a wider type SCEV for the
446// range check operand, however, this needs a more involved check that
447// operands do not overflow. This can lead to loss of information when the
448// range operand is of the form: add i32 %offset, %iv. We need to prove that
449// sext(x + y) is same as sext(x) + sext(y).
450// This function returns true if we can safely represent the IV type in
451// the RangeCheckType without loss of information.
453 ScalarEvolution &SE,
454 const LoopICmp LatchCheck,
455 Type *RangeCheckType) {
457 return false;
458 assert(DL.getTypeSizeInBits(LatchCheck.IV->getType()).getFixedValue() >
459 DL.getTypeSizeInBits(RangeCheckType).getFixedValue() &&
460 "Expected latch check IV type to be larger than range check operand "
461 "type!");
462 // The start and end values of the IV should be known. This is to guarantee
463 // that truncating the wide type will not lose information.
464 auto *Limit = dyn_cast<SCEVConstant>(LatchCheck.Limit);
465 auto *Start = dyn_cast<SCEVConstant>(LatchCheck.IV->getStart());
466 if (!Limit || !Start)
467 return false;
468 // This check makes sure that the IV does not change sign during loop
469 // iterations. Consider latchType = i64, LatchStart = 5, Pred = ICMP_SGE,
470 // LatchEnd = 2, rangeCheckType = i32. If it's not a monotonic predicate, the
471 // IV wraps around, and the truncation of the IV would lose the range of
472 // iterations between 2^32 and 2^64.
473 if (!SE.getMonotonicPredicateType(LatchCheck.IV, LatchCheck.Pred))
474 return false;
475 // The active bits should be less than the bits in the RangeCheckType. This
476 // guarantees that truncating the latch check to RangeCheckType is a safe
477 // operation.
478 auto RangeCheckTypeBitSize =
479 DL.getTypeSizeInBits(RangeCheckType).getFixedValue();
480 return Start->getAPInt().getActiveBits() < RangeCheckTypeBitSize &&
481 Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize;
482}
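// The final active-bits test in the function above can be illustrated on plain
// 64-bit integers. This is a sketch under assumptions: activeBits is a
// hypothetical stand-in for APInt::getActiveBits on a 64-bit value, and
// fitsInNarrowType models only the last return statement, not the constant or
// monotonicity checks that precede it.

```cpp
#include <cassert>
#include <cstdint>

// Number of bits needed to represent V: the index of the highest set bit plus
// one, or 0 when V is 0.
unsigned activeBits(uint64_t V) {
  unsigned Bits = 0;
  while (V != 0) {
    ++Bits;
    V >>= 1;
  }
  return Bits;
}

// Both constants must need strictly fewer bits than the narrow type has. That
// keeps them below 2^(NarrowBits - 1), so they are non-negative and unchanged
// when reinterpreted in the narrow type.
bool fitsInNarrowType(uint64_t Start, uint64_t Limit, unsigned NarrowBits) {
  return activeBits(Start) < NarrowBits && activeBits(Limit) < NarrowBits;
}
```

// For an i64 latch and an i32 range check, start 5 and limit 2 pass, a limit of
// 2^31 is rejected, and so is any constant that is negative as an i64, because
// its top bit gives it 64 active bits.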
483
484
485// Return a LoopICmp describing a latch check equivalent to LatchCheck but with
486// the requested type if safe to do so. May involve the use of a new IV.
487static std::optional<LoopICmp> generateLoopLatchCheck(const DataLayout &DL,
488 ScalarEvolution &SE,
489 const LoopICmp LatchCheck,
490 Type *RangeCheckType) {
491
492 auto *LatchType = LatchCheck.IV->getType();
493 if (RangeCheckType == LatchType)
494 return LatchCheck;
495 // For now, bail out if latch type is narrower than range type.
496 if (DL.getTypeSizeInBits(LatchType).getFixedValue() <
497 DL.getTypeSizeInBits(RangeCheckType).getFixedValue())
498 return std::nullopt;
499 if (!isSafeToTruncateWideIVType(DL, SE, LatchCheck, RangeCheckType))
500 return std::nullopt;
501 // We can now safely identify the truncated version of the IV and limit for
502 // RangeCheckType.
503 LoopICmp NewLatchCheck;
504 NewLatchCheck.Pred = LatchCheck.Pred;
505 NewLatchCheck.IV = dyn_cast<SCEVAddRecExpr>(
506 SE.getTruncateExpr(LatchCheck.IV, RangeCheckType));
507 if (!NewLatchCheck.IV)
508 return std::nullopt;
509 NewLatchCheck.Limit = SE.getTruncateExpr(LatchCheck.Limit, RangeCheckType);
510 LLVM_DEBUG(dbgs() << "IV of type: " << *LatchType
511 << "can be represented as range check type:"
512 << *RangeCheckType << "\n");
513 LLVM_DEBUG(dbgs() << "LatchCheck.IV: " << *NewLatchCheck.IV << "\n");
514 LLVM_DEBUG(dbgs() << "LatchCheck.Limit: " << *NewLatchCheck.Limit << "\n");
515 return NewLatchCheck;
516}
517
518bool LoopPredication::isSupportedStep(const SCEV* Step) {
519 return Step->isOne() || (Step->isAllOnesValue() && EnableCountDownLoop);
520}
521
522Instruction *LoopPredication::findInsertPt(Instruction *Use,
523 ArrayRef<Value*> Ops) {
524 for (Value *Op : Ops)
525 if (!L->isLoopInvariant(Op))
526 return Use;
527 return Preheader->getTerminator();
528}
529
530Instruction *LoopPredication::findInsertPt(const SCEVExpander &Expander,
531 Instruction *Use,
532 ArrayRef<const SCEV *> Ops) {
533 // Subtlety: SCEV considers things to be invariant if the value produced is
534 // the same across iterations. This is not the same as being able to
535 // evaluate outside the loop, which is what we actually need here.
536 for (const SCEV *Op : Ops)
537 if (!SE->isLoopInvariant(Op, L) ||
538 !Expander.isSafeToExpandAt(Op, Preheader->getTerminator()))
539 return Use;
540 return Preheader->getTerminator();
541}
542
543bool LoopPredication::isLoopInvariantValue(const SCEV* S) {
544 // Handling expressions which produce invariant results, but *haven't* yet
545 // been removed from the loop serves two important purposes.
546 // 1) Most importantly, it resolves a pass ordering cycle which would
547 // otherwise need us to iterate licm, loop-predication, and either
548 // loop-unswitch or loop-peeling to make progress on examples with lots of
549 // predicable range checks in a row. (Since, in the general case, we can't
550 // hoist the length checks until the dominating checks have been discharged
551 // as we can't prove doing so is safe.)
552 // 2) As a nice side effect, this exposes the value of peeling or unswitching
553 // much more obviously in the IR. Otherwise, the cost modeling for other
554 // transforms would end up needing to duplicate all of this logic to model a
555 // check which becomes predictable based on a modeled peel or unswitch.
556 //
557 // The cost of doing so in the worst case is an extra fill from the stack in
558 // the loop to materialize the loop invariant test value instead of checking
559 // against the original IV which is presumably in a register inside the loop.
560 // Such cases are presumably rare, and hint at missing opportunities for
561 // other passes.
562
563 if (SE->isLoopInvariant(S, L))
564 // Note: This is the SCEV variant, so the original Value* may be within the
565 // loop even though SCEV has proven it is loop invariant.
566 return true;
567
568 // Handle a particular important case which SCEV doesn't yet know about which
569 // shows up in range checks on arrays with immutable lengths.
570 // TODO: This should be sunk inside SCEV.
571 if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S))
572 if (const auto *LI = dyn_cast<LoadInst>(U->getValue()))
573 if (LI->isUnordered() && L->hasLoopInvariantOperands(LI))
574 if (!isModSet(AA->getModRefInfoMask(LI->getOperand(0))) ||
575 LI->hasMetadata(LLVMContext::MD_invariant_load))
576 return true;
577 return false;
578}
579
580std::optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
581 LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander,
582 Instruction *Guard) {
583 auto *Ty = RangeCheck.IV->getType();
584 // Generate the widened condition for the forward loop:
585 // guardStart u< guardLimit &&
586 // latchLimit <pred> guardLimit - 1 - guardStart + latchStart
587 // where <pred> depends on the latch condition predicate. See the file
588 // header comment for the reasoning.
589 // guardLimit - guardStart + latchStart - 1
590 const SCEV *GuardStart = RangeCheck.IV->getStart();
591 const SCEV *GuardLimit = RangeCheck.Limit;
592 const SCEV *LatchStart = LatchCheck.IV->getStart();
593 const SCEV *LatchLimit = LatchCheck.Limit;
594 // Subtlety: We need all the values to be *invariant* across all iterations,
595 // but we only need to check expansion safety for those which *aren't*
596 // already guaranteed to dominate the guard.
597 if (!isLoopInvariantValue(GuardStart) ||
598 !isLoopInvariantValue(GuardLimit) ||
599 !isLoopInvariantValue(LatchStart) ||
600 !isLoopInvariantValue(LatchLimit)) {
601 LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
602 return std::nullopt;
603 }
604 if (!Expander.isSafeToExpandAt(LatchStart, Guard) ||
605 !Expander.isSafeToExpandAt(LatchLimit, Guard)) {
606 LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
607 return std::nullopt;
608 }
609
610 // guardLimit - guardStart + latchStart - 1
611 const SCEV *RHS =
612 SE->getAddExpr(SE->getMinusSCEV(GuardLimit, GuardStart),
613 SE->getMinusSCEV(LatchStart, SE->getOne(Ty)));
614 auto LimitCheckPred =
616
617 LLVM_DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n");
618 LLVM_DEBUG(dbgs() << "RHS: " << *RHS << "\n");
619 LLVM_DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n");
620
621 auto *LimitCheck =
622 expandCheck(Expander, Guard, LimitCheckPred, LatchLimit, RHS);
623 auto *FirstIterationCheck = expandCheck(Expander, Guard, RangeCheck.Pred,
624 GuardStart, GuardLimit);
625 IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck}));
626 return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
627}
628
629std::optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
630 LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander,
631 Instruction *Guard) {
632 auto *Ty = RangeCheck.IV->getType();
633 const SCEV *GuardStart = RangeCheck.IV->getStart();
634 const SCEV *GuardLimit = RangeCheck.Limit;
635 const SCEV *LatchStart = LatchCheck.IV->getStart();
636 const SCEV *LatchLimit = LatchCheck.Limit;
637 // Subtlety: We need all the values to be *invariant* across all iterations,
638 // but we only need to check expansion safety for those which *aren't*
639 // already guaranteed to dominate the guard.
640 if (!isLoopInvariantValue(GuardStart) ||
641 !isLoopInvariantValue(GuardLimit) ||
642 !isLoopInvariantValue(LatchStart) ||
643 !isLoopInvariantValue(LatchLimit)) {
644 LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
645 return std::nullopt;
646 }
647 if (!Expander.isSafeToExpandAt(LatchStart, Guard) ||
648 !Expander.isSafeToExpandAt(LatchLimit, Guard)) {
649 LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
650 return std::nullopt;
651 }
652 // The decrement of the latch check IV should be the same as the
653 // rangeCheckIV.
654 auto *PostDecLatchCheckIV = LatchCheck.IV->getPostIncExpr(*SE);
655 if (RangeCheck.IV != PostDecLatchCheckIV) {
656 LLVM_DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: "
657 << *PostDecLatchCheckIV
658 << " and RangeCheckIV: " << *RangeCheck.IV << "\n");
659 return std::nullopt;
660 }
661
662 // Generate the widened condition for CountDownLoop:
663 // guardStart u< guardLimit &&
664 // latchLimit <pred> 1.
665 // See the header comment for reasoning of the checks.
666 auto LimitCheckPred =
668 auto *FirstIterationCheck = expandCheck(Expander, Guard,
670 GuardStart, GuardLimit);
671 auto *LimitCheck = expandCheck(Expander, Guard, LimitCheckPred, LatchLimit,
672 SE->getOne(Ty));
673 IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck}));
674 return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
675}
676
678 LoopICmp& RC) {
679 // LFTR canonicalizes checks to the ICMP_NE/EQ form; normalize back to the
680 // ULT/UGE form for ease of handling by our caller.
681 if (ICmpInst::isEquality(RC.Pred) &&
682 RC.IV->getStepRecurrence(*SE)->isOne() &&
683 SE->isKnownPredicate(ICmpInst::ICMP_ULE, RC.IV->getStart(), RC.Limit))
684 RC.Pred = RC.Pred == ICmpInst::ICMP_NE ?
686}
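// Why the start u<= limit precondition matters for this normalization can be
// seen on 8-bit values. The sketch below is a standalone model with
// hypothetical helper names. It assumes the NE form is the one rewritten to
// ULT, as the comment above suggests, and simply counts loop iterations under
// each exit test.

```cpp
#include <cassert>
#include <cstdint>

// Iterations of `for (i = Start; i != Limit; ++i)` on wrapping 8-bit values.
unsigned tripCountNE(uint8_t Start, uint8_t Limit) {
  unsigned N = 0;
  for (uint8_t I = Start; I != Limit; ++I)
    ++N;
  return N;
}

// Iterations of `for (i = Start; i u< Limit; ++i)` on the same values.
unsigned tripCountULT(uint8_t Start, uint8_t Limit) {
  unsigned N = 0;
  for (uint8_t I = Start; I < Limit; ++I)
    ++N;
  return N;
}
```

// When Start u<= Limit the two loops agree, e.g. both run 5 times for Start = 2
// and Limit = 7. When Start u> Limit they do not: for Start = 5 and Limit = 4
// the NE loop wraps around and runs 255 times, while the ULT loop never runs.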
687
688/// If ICI can be widened to a loop invariant condition emits the loop
689/// invariant condition in the loop preheader and return it, otherwise
690/// returns std::nullopt.
691std::optional<Value *>
692LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander,
693 Instruction *Guard) {
694 LLVM_DEBUG(dbgs() << "Analyzing ICmpInst condition:\n");
695 LLVM_DEBUG(ICI->dump());
696
697 // parseLoopStructure guarantees that the latch condition is:
698 // ++i <pred> latchLimit, where <pred> is u<, u<=, s<, or s<=.
699 // We are looking for the range checks of the form:
700 // i u< guardLimit
701 auto RangeCheck = parseLoopICmp(ICI);
702 if (!RangeCheck) {
703 LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
704 return std::nullopt;
705 }
706 LLVM_DEBUG(dbgs() << "Guard check:\n");
707 LLVM_DEBUG(RangeCheck->dump());
708 if (RangeCheck->Pred != ICmpInst::ICMP_ULT) {
709 LLVM_DEBUG(dbgs() << "Unsupported range check predicate("
710 << RangeCheck->Pred << ")!\n");
711 return std::nullopt;
712 }
713 auto *RangeCheckIV = RangeCheck->IV;
714 if (!RangeCheckIV->isAffine()) {
715 LLVM_DEBUG(dbgs() << "Range check IV is not affine!\n");
716 return std::nullopt;
717 }
718 auto *Step = RangeCheckIV->getStepRecurrence(*SE);
719 // We cannot just compare with latch IV step because the latch and range IVs
720 // may have different types.
721 if (!isSupportedStep(Step)) {
722 LLVM_DEBUG(dbgs() << "Range check and latch have IVs different steps!\n");
723 return std::nullopt;
724 }
725 auto *Ty = RangeCheckIV->getType();
726 auto CurrLatchCheckOpt = generateLoopLatchCheck(*DL, *SE, LatchCheck, Ty);
727 if (!CurrLatchCheckOpt) {
728 LLVM_DEBUG(dbgs() << "Failed to generate a loop latch check "
729 "corresponding to range type: "
730 << *Ty << "\n");
731 return std::nullopt;
732 }
733
734 LoopICmp CurrLatchCheck = *CurrLatchCheckOpt;
735 // At this point, the range and latch step should have the same type, but need
736 // not have the same value (we support both 1 and -1 steps).
737 assert(Step->getType() ==
738 CurrLatchCheck.IV->getStepRecurrence(*SE)->getType() &&
739 "Range and latch steps should be of same type!");
740 if (Step != CurrLatchCheck.IV->getStepRecurrence(*SE)) {
741 LLVM_DEBUG(dbgs() << "Range and latch have different step values!\n");
742 return std::nullopt;
743 }
744
745 if (Step->isOne())
746 return widenICmpRangeCheckIncrementingLoop(CurrLatchCheck, *RangeCheck,
747 Expander, Guard);
748 else {
749 assert(Step->isAllOnesValue() && "Step should be -1!");
750 return widenICmpRangeCheckDecrementingLoop(CurrLatchCheck, *RangeCheck,
751 Expander, Guard);
752 }
753}
754
755unsigned LoopPredication::collectChecks(SmallVectorImpl<Value *> &Checks,
756 Value *Condition,
757 SCEVExpander &Expander,
758 Instruction *Guard) {
759 unsigned NumWidened = 0;
760 // The guard condition is expected to be in form of:
761 // cond1 && cond2 && cond3 ...
762 // Iterate over subconditions looking for icmp conditions which can be
763 // widened across loop iterations. While widening these conditions, remember
764 // the resulting list of subconditions in the Checks vector.
765 SmallVector<Value *, 4> Worklist(1, Condition);
767 Visited.insert(Condition);
768 Value *WideableCond = nullptr;
769 do {
770 Value *Condition = Worklist.pop_back_val();
771 Value *LHS, *RHS;
772 using namespace llvm::PatternMatch;
773 if (match(Condition, m_And(m_Value(LHS), m_Value(RHS)))) {
774 if (Visited.insert(LHS).second)
775 Worklist.push_back(LHS);
776 if (Visited.insert(RHS).second)
777 Worklist.push_back(RHS);
778 continue;
779 }
780
781 if (match(Condition,
782 m_Intrinsic<Intrinsic::experimental_widenable_condition>())) {
783 // Pick any, we don't care which
784 WideableCond = Condition;
785 continue;
786 }
787
788 if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) {
789 if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander,
790 Guard)) {
791 Checks.push_back(*NewRangeCheck);
792 NumWidened++;
793 continue;
794 }
795 }
796
797 // Save the condition as is if we can't widen it
798 Checks.push_back(Condition);
799 } while (!Worklist.empty());
800 // At the moment, our matching logic for wideable conditions implicitly
801 // assumes we preserve the form: (br (and Cond, WC())). FIXME
802 // Note that if there were multiple calls to wideable condition in the
803 // traversal, we only need to keep one, and which one is arbitrary.
804 if (WideableCond)
805 Checks.push_back(WideableCond);
806 return NumWidened;
807}
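// The worklist traversal used above to split `cond1 && cond2 && ...` into its
// subconditions can be sketched without any LLVM types. Everything below is a
// toy model: Cond, flattenAnd and demoFlatten are hypothetical names, leaves
// are plain strings, and no widening is attempted.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// A condition is either a named leaf (LHS and RHS null) or an AND of two
// subconditions.
struct Cond {
  std::string Name;
  const Cond *LHS;
  const Cond *RHS;
  bool isAnd() const { return LHS && RHS; }
};

// Collects the leaves of an AND tree with an explicit worklist. The Visited
// set ensures that a subcondition shared by several AND nodes is emitted once.
std::vector<std::string> flattenAnd(const Cond *Root) {
  std::vector<std::string> Leaves;
  std::vector<const Cond *> Worklist{Root};
  std::set<const Cond *> Visited{Root};
  do {
    const Cond *C = Worklist.back();
    Worklist.pop_back();
    if (C->isAnd()) {
      if (Visited.insert(C->LHS).second)
        Worklist.push_back(C->LHS);
      if (Visited.insert(C->RHS).second)
        Worklist.push_back(C->RHS);
      continue;
    }
    Leaves.push_back(C->Name);
  } while (!Worklist.empty());
  return Leaves;
}

// Builds (a && b) && (b && c), where the leaf b is shared, and flattens it.
std::vector<std::string> demoFlatten() {
  Cond A{"a", nullptr, nullptr}, B{"b", nullptr, nullptr},
      C{"c", nullptr, nullptr};
  Cond AB{"", &A, &B}, BC{"", &B, &C};
  Cond Root{"", &AB, &BC};
  return flattenAnd(&Root);
}
```

// Because the worklist is used as a stack, leaves come out in last-in,
// first-out order, and the shared leaf b appears only once.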
808
809bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
810 SCEVExpander &Expander) {
811 LLVM_DEBUG(dbgs() << "Processing guard:\n");
812 LLVM_DEBUG(Guard->dump());
813
814 TotalConsidered++;
816 unsigned NumWidened = collectChecks(Checks, Guard->getOperand(0), Expander,
817 Guard);
818 if (NumWidened == 0)
819 return false;
820
821 TotalWidened += NumWidened;
822
823 // Emit the new guard condition
824 IRBuilder<> Builder(findInsertPt(Guard, Checks));
825 Value *AllChecks = Builder.CreateAnd(Checks);
826 auto *OldCond = Guard->getOperand(0);
827 Guard->setOperand(0, AllChecks);
829 Builder.SetInsertPoint(&*++BasicBlock::iterator(Guard));
830 Builder.CreateAssumption(OldCond);
831 }
832 RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU);
833
834 LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n");
835 return true;
836}
837
838bool LoopPredication::widenWidenableBranchGuardConditions(
839 BranchInst *BI, SCEVExpander &Expander) {
840 assert(isGuardAsWidenableBranch(BI) && "Must be!");
841 LLVM_DEBUG(dbgs() << "Processing guard:\n");
842 LLVM_DEBUG(BI->dump());
843
844 Value *Cond, *WC;
845 BasicBlock *IfTrueBB, *IfFalseBB;
846 bool Parsed = parseWidenableBranch(BI, Cond, WC, IfTrueBB, IfFalseBB);
847 assert(Parsed && "Must be able to parse widenable branch");
848 (void)Parsed;
849
850 TotalConsidered++;
851 SmallVector<Value *, 4> Checks;
852 unsigned NumWidened = collectChecks(Checks, BI->getCondition(),
853 Expander, BI);
854 if (NumWidened == 0)
855 return false;
856
857 TotalWidened += NumWidened;
858
859 // Emit the new guard condition
860 IRBuilder<> Builder(findInsertPt(BI, Checks));
861 Value *AllChecks = Builder.CreateAnd(Checks);
862 auto *OldCond = BI->getCondition();
863 BI->setCondition(AllChecks);
864 if (InsertAssumesOfPredicatedGuardsConditions) {
865 Builder.SetInsertPoint(IfTrueBB, IfTrueBB->getFirstInsertionPt());
866 // If this block has other predecessors, we might not be able to use Cond.
867 // In this case, create a Phi where every other input is `true` and input
868 // from guard block is Cond.
869 Value *AssumeCond = Cond;
870 if (!IfTrueBB->getUniquePredecessor()) {
871 auto *GuardBB = BI->getParent();
872 auto *PN = Builder.CreatePHI(Cond->getType(), pred_size(IfTrueBB),
873 "assume.cond");
874 for (auto *Pred : predecessors(IfTrueBB))
875 PN->addIncoming(Pred == GuardBB ? Cond : Builder.getTrue(), Pred);
876 AssumeCond = PN;
877 }
878 Builder.CreateAssumption(AssumeCond);
879 }
880 RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU);
881 assert(isGuardAsWidenableBranch(BI) &&
882 "Stopped being a guard after transform?");
883
884 LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n");
885 return true;
886}
887
888std::optional<LoopICmp> LoopPredication::parseLoopLatchICmp() {
889 using namespace PatternMatch;
890
891 BasicBlock *LoopLatch = L->getLoopLatch();
892 if (!LoopLatch) {
893 LLVM_DEBUG(dbgs() << "The loop doesn't have a single latch!\n");
894 return std::nullopt;
895 }
896
897 auto *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
898 if (!BI || !BI->isConditional()) {
899 LLVM_DEBUG(dbgs() << "Failed to match the latch terminator!\n");
900 return std::nullopt;
901 }
902 BasicBlock *TrueDest = BI->getSuccessor(0);
903 assert(
904 (TrueDest == L->getHeader() || BI->getSuccessor(1) == L->getHeader()) &&
905 "One of the latch's destinations must be the header");
906
907 auto *ICI = dyn_cast<ICmpInst>(BI->getCondition());
908 if (!ICI) {
909 LLVM_DEBUG(dbgs() << "Failed to match the latch condition!\n");
910 return std::nullopt;
911 }
912 auto Result = parseLoopICmp(ICI);
913 if (!Result) {
914 LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
915 return std::nullopt;
916 }
917
918 if (TrueDest != L->getHeader())
919 Result->Pred = ICmpInst::getInversePredicate(Result->Pred);
920
921 // Check affine first, so if it's not we don't try to compute the step
922 // recurrence.
923 if (!Result->IV->isAffine()) {
924 LLVM_DEBUG(dbgs() << "The induction variable is not affine!\n");
925 return std::nullopt;
926 }
927
928 auto *Step = Result->IV->getStepRecurrence(*SE);
929 if (!isSupportedStep(Step)) {
930 LLVM_DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n");
931 return std::nullopt;
932 }
933
934 auto IsUnsupportedPredicate = [](const SCEV *Step, ICmpInst::Predicate Pred) {
935 if (Step->isOne()) {
936 return Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_SLT &&
937 Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_SLE;
938 } else {
939 assert(Step->isAllOnesValue() && "Step should be -1!");
940 return Pred != ICmpInst::ICMP_UGT && Pred != ICmpInst::ICMP_SGT &&
941 Pred != ICmpInst::ICMP_UGE && Pred != ICmpInst::ICMP_SGE;
942 }
943 };
944
945 normalizePredicate(SE, L, *Result);
946 if (IsUnsupportedPredicate(Step, Result->Pred)) {
947 LLVM_DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred
948 << ")!\n");
949 return std::nullopt;
950 }
951
952 return Result;
953}
954
955bool LoopPredication::isLoopProfitableToPredicate() {
956 if (SkipProfitabilityChecks)
957 return true;
958
959 SmallVector<std::pair<BasicBlock *, BasicBlock *>, 8> ExitEdges;
960 L->getExitEdges(ExitEdges);
961 // If there is only one exiting edge in the loop, it is always profitable to
962 // predicate the loop.
963 if (ExitEdges.size() == 1)
964 return true;
965
966 // Calculate the exiting probabilities of all exiting edges from the loop,
967 // starting with the LatchExitProbability.
968 // Heuristic for profitability: If any of the exiting blocks' probability of
969 // exiting the loop is larger than exiting through the latch block, it's not
970 // profitable to predicate the loop.
971 auto *LatchBlock = L->getLoopLatch();
972 assert(LatchBlock && "Should have a single latch at this point!");
973 auto *LatchTerm = LatchBlock->getTerminator();
974 assert(LatchTerm->getNumSuccessors() == 2 &&
975 "expected to be an exiting block with 2 succs!");
976 unsigned LatchBrExitIdx =
977 LatchTerm->getSuccessor(0) == L->getHeader() ? 1 : 0;
978 // We compute branch probabilities without BPI. We do not rely on BPI since
979 // Loop predication is usually run in an LPM and BPI is only preserved
980 // lossily within loop pass managers, while BPI has an inherent notion of
981 // being complete for an entire function.
982
983 // If the latch exits into a deoptimize or an unreachable block, do not
984 // predicate on that latch check.
985 auto *LatchExitBlock = LatchTerm->getSuccessor(LatchBrExitIdx);
986 if (isa<UnreachableInst>(LatchTerm) ||
987 LatchExitBlock->getTerminatingDeoptimizeCall())
988 return false;
989
990 // Latch terminator has no valid profile data, so nothing to check
991 // profitability on.
992 if (!hasValidBranchWeightMD(*LatchTerm))
993 return true;
994
995 auto ComputeBranchProbability =
996 [&](const BasicBlock *ExitingBlock,
997 const BasicBlock *ExitBlock) -> BranchProbability {
998 auto *Term = ExitingBlock->getTerminator();
999 unsigned NumSucc = Term->getNumSuccessors();
1000 if (MDNode *ProfileData = getValidBranchWeightMDNode(*Term)) {
1001 SmallVector<uint32_t> Weights;
1002 extractBranchWeights(ProfileData, Weights);
1003 uint64_t Numerator = 0, Denominator = 0;
1004 for (auto [i, Weight] : llvm::enumerate(Weights)) {
1005 if (Term->getSuccessor(i) == ExitBlock)
1006 Numerator += Weight;
1007 Denominator += Weight;
1008 }
1009 return BranchProbability::getBranchProbability(Numerator, Denominator);
1010 } else {
1011 assert(LatchBlock != ExitingBlock &&
1012 "Latch term should always have profile data!");
1013 // No profile data, so we choose the weight as 1/num_of_succ(Src)
1014 return BranchProbability::getBranchProbability(1, NumSucc);
1015 }
1016 };
1017
1018 BranchProbability LatchExitProbability =
1019 ComputeBranchProbability(LatchBlock, LatchExitBlock);
1020
1021 // Protect against degenerate inputs provided by the user. Providing a value
1022 // less than one can invert the definition of profitable loop predication.
1023 float ScaleFactor = LatchExitProbabilityScale;
1024 if (ScaleFactor < 1) {
1025 LLVM_DEBUG(
1026 dbgs()
1027 << "Ignored user setting for loop-predication-latch-probability-scale: "
1028 << LatchExitProbabilityScale << "\n");
1029 LLVM_DEBUG(dbgs() << "The value is set to 1.0\n");
1030 ScaleFactor = 1.0;
1031 }
1032 const auto LatchProbabilityThreshold = LatchExitProbability * ScaleFactor;
1033
1034 for (const auto &ExitEdge : ExitEdges) {
1035 BranchProbability ExitingBlockProbability =
1036 ComputeBranchProbability(ExitEdge.first, ExitEdge.second);
1037 // Some exiting edge has higher probability than the latch exiting edge.
1038 // No longer profitable to predicate.
1039 if (ExitingBlockProbability > LatchProbabilityThreshold)
1040 return false;
1041 }
1042
1043 // We have concluded that the most probable way to exit from the
1044 // loop is through the latch (or there's no profile information and all
1045 // exits are equally likely).
1046 return true;
1047}
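
The profitability heuristic above can be sketched without any LLVM types. This is a minimal illustration, not the pass's implementation: `ExitEdgeWeights`, `exitProbability`, and `isProfitableToPredicate` are hypothetical names, probabilities are plain `double` values rather than `BranchProbability`, each exit block is assumed to appear as a single successor, and the zero-denominator guard is an assumption added for the sketch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one exit edge: the branch weights on the exiting
// block's terminator and the index of the successor that leaves the loop.
struct ExitEdgeWeights {
  std::vector<uint32_t> Weights; // empty means "no profile data"
  unsigned NumSuccessors;        // successor count of the terminator
  unsigned ExitIdx;              // index of the successor that exits the loop
};

// Probability of taking the exit edge: weight(exit) / sum(weights), or
// 1 / NumSuccessors when the terminator carries no profile data.
double exitProbability(const ExitEdgeWeights &E) {
  assert(E.NumSuccessors > 0 && E.ExitIdx < E.NumSuccessors);
  if (E.Weights.empty())
    return 1.0 / E.NumSuccessors;
  uint64_t Numerator = 0, Denominator = 0;
  for (unsigned I = 0; I < E.Weights.size(); ++I) {
    if (I == E.ExitIdx)
      Numerator += E.Weights[I];
    Denominator += E.Weights[I];
  }
  // Assumption for this sketch: an all-zero weight list means "never exits".
  return Denominator == 0 ? 0.0
                          : static_cast<double>(Numerator) / Denominator;
}

// Predication is treated as profitable unless some exit edge is more likely
// than the latch exit scaled by Scale. Scale values below 1 are clamped to 1.
bool isProfitableToPredicate(const ExitEdgeWeights &Latch,
                             const std::vector<ExitEdgeWeights> &AllExits,
                             double Scale) {
  if (Scale < 1.0)
    Scale = 1.0;
  const double Threshold = exitProbability(Latch) * Scale;
  for (const ExitEdgeWeights &E : AllExits)
    if (exitProbability(E) > Threshold)
      return false;
  return true;
}
```

With a latch that exits 1% of the time and a scale of 2.0, an exit taken 0.1% of the time leaves the loop profitable to predicate, while an exit taken 50% of the time does not.
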
1048
1049/// If we can (cheaply) find a widenable branch which controls entry into the
1050/// loop, return it.
1051 static BranchInst *FindWidenableTerminatorAboveLoop(Loop *L, LoopInfo &LI) {
1052 // Walk back through any unconditionally executed blocks and see if we can find
1053 // a widenable condition which seems to control execution of this loop. Note
1054 // that we predict that maythrow calls are likely untaken and thus that it's
1055 // profitable to widen a branch before a maythrow call with a condition
1056 // afterwards even though that may cause the slow path to run in a case where
1057 // it wouldn't have otherwise.
1058 BasicBlock *BB = L->getLoopPreheader();
1059 if (!BB)
1060 return nullptr;
1061 do {
1062 if (BasicBlock *Pred = BB->getSinglePredecessor())
1063 if (BB == Pred->getSingleSuccessor()) {
1064 BB = Pred;
1065 continue;
1066 }
1067 break;
1068 } while (true);
1069
1070 if (BasicBlock *Pred = BB->getSinglePredecessor()) {
1071 auto *Term = Pred->getTerminator();
1072
1073 Value *Cond, *WC;
1074 BasicBlock *IfTrueBB, *IfFalseBB;
1075 if (parseWidenableBranch(Term, Cond, WC, IfTrueBB, IfFalseBB) &&
1076 IfTrueBB == BB)
1077 return cast<BranchInst>(Term);
1078 }
1079 return nullptr;
1080}
1081
1082/// Return the minimum of all analyzeable exit counts. This is an upper bound
1083/// on the actual exit count. If there are not at least two analyzeable exits,
1084/// returns SCEVCouldNotCompute.
1085 static const SCEV *getMinAnalyzeableBackedgeTakenCount(ScalarEvolution &SE,
1086 DominatorTree &DT,
1087 Loop *L) {
1088 SmallVector<BasicBlock *, 16> ExitingBlocks;
1089 L->getExitingBlocks(ExitingBlocks);
1090
1091 SmallVector<const SCEV *, 4> ExitCounts;
1092 for (BasicBlock *ExitingBB : ExitingBlocks) {
1093 const SCEV *ExitCount = SE.getExitCount(L, ExitingBB);
1094 if (isa<SCEVCouldNotCompute>(ExitCount))
1095 continue;
1096 assert(DT.dominates(ExitingBB, L->getLoopLatch()) &&
1097 "We should only have known counts for exiting blocks that "
1098 "dominate latch!");
1099 ExitCounts.push_back(ExitCount);
1100 }
1101 if (ExitCounts.size() < 2)
1102 return SE.getCouldNotCompute();
1103 return SE.getUMinFromMismatchedTypes(ExitCounts);
1104}
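
The helper above reduces to "take the unsigned minimum over the exit counts that could be computed, and give up unless at least two are known". A minimal sketch with plain integers follows. `minAnalyzableExitCount` is a hypothetical name, and `std::nullopt` stands in for `SCEVCouldNotCompute`. The real helper builds a symbolic `umin` SCEV expression rather than a concrete number.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Each entry is the exit count of one exiting block; std::nullopt models an
// exit whose count could not be computed.
std::optional<uint64_t>
minAnalyzableExitCount(const std::vector<std::optional<uint64_t>> &ExitCounts) {
  std::vector<uint64_t> Known;
  for (const auto &EC : ExitCounts)
    if (EC)
      Known.push_back(*EC);
  // Mirror the "at least two analyzable exits" requirement.
  if (Known.size() < 2)
    return std::nullopt;
  assert(!Known.empty());
  return *std::min_element(Known.begin(), Known.end());
}
```

For example, exit counts `{10, unknown, 7}` yield 7, while `{5}` or `{unknown, unknown}` yield no result.
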
1105
1106/// This implements an analogous, but entirely distinct transform from the main
1107/// loop predication transform. This one is phrased in terms of using a
1108/// widenable branch *outside* the loop to allow us to simplify loop exits in a
1109/// following loop. This is close in spirit to the IndVarSimplify transform
1110/// of the same name, but is materially different because widening loosens
1111/// legality sharply.
1112bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
1113 // The transformation performed here aims to widen a widenable condition
1114 // above the loop such that all analyzeable exits leading to deopt are dead.
1115 // It assumes that the latch is the dominant exit for profitability and that
1116 // exits branching to deoptimizing blocks are rarely taken. It relies on the
1117 // semantics of widenable expressions for legality. (i.e. being able to fall
1118 // down the widenable path spuriously allows us to ignore exit order,
1119 // unanalyzeable exits, side effects, exceptional exits, and other challenges
1120 // which restrict the applicability of the non-WC based version of this
1121 // transform in IndVarSimplify.)
1122 //
1123 // NOTE ON POISON/UNDEF - We're hoisting an expression above guards which may
1124 // imply flags on the expression being hoisted and inserting new uses (flags
1125 // are only correct for current uses). The result is that we may be
1126 // inserting a branch on the value which can be either poison or undef. In
1127 // this case, the branch can legally go either way; we just need to avoid
1128 // introducing UB. This is achieved through the use of the freeze
1129 // instruction.
1130
1131 SmallVector<BasicBlock *, 16> ExitingBlocks;
1132 L->getExitingBlocks(ExitingBlocks);
1133
1134 if (ExitingBlocks.empty())
1135 return false; // Nothing to do.
1136
1137 auto *Latch = L->getLoopLatch();
1138 if (!Latch)
1139 return false;
1140
1141 auto *WidenableBR = FindWidenableTerminatorAboveLoop(L, *LI);
1142 if (!WidenableBR)
1143 return false;
1144
1145 const SCEV *LatchEC = SE->getExitCount(L, Latch);
1146 if (isa<SCEVCouldNotCompute>(LatchEC))
1147 return false; // profitability - want hot exit in analyzeable set
1148
1149 // At this point, we have found an analyzeable latch, and a widenable
1150 // condition above the loop. If we have a widenable exit within the loop
1151 // (for which we can't compute exit counts), drop the ability to further
1152 // widen so that we gain ability to analyze its exit count and perform this
1153 // transform. TODO: It'd be nice to know for sure the exit became
1154 // analyzeable after dropping widenability.
1155 bool ChangedLoop = false;
1156
1157 for (auto *ExitingBB : ExitingBlocks) {
1158 if (LI->getLoopFor(ExitingBB) != L)
1159 continue;
1160
1161 auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
1162 if (!BI)
1163 continue;
1164
1165 Use *Cond, *WC;
1166 BasicBlock *IfTrueBB, *IfFalseBB;
1167 if (parseWidenableBranch(BI, Cond, WC, IfTrueBB, IfFalseBB) &&
1168 L->contains(IfTrueBB)) {
1169 WC->set(ConstantInt::getTrue(IfTrueBB->getContext()));
1170 ChangedLoop = true;
1171 }
1172 }
1173 if (ChangedLoop)
1174 SE->forgetLoop(L);
1175
1176 // The use of umin(all analyzeable exits) instead of latch is subtle, but
1177 // important for profitability. We may have a loop which hasn't been fully
1178 // canonicalized just yet. If the exit we chose to widen is provably never
1179 // taken, we want the widened form to *also* be provably never taken. We
1180 // can't guarantee this as a current unanalyzeable exit may later become
1181 // analyzeable, but we can at least avoid the obvious cases.
1182 const SCEV *MinEC = getMinAnalyzeableBackedgeTakenCount(*SE, *DT, L);
1183 if (isa<SCEVCouldNotCompute>(MinEC) || MinEC->getType()->isPointerTy() ||
1184 !SE->isLoopInvariant(MinEC, L) ||
1185 !Rewriter.isSafeToExpandAt(MinEC, WidenableBR))
1186 return ChangedLoop;
1187
1188 // Subtlety: We need to avoid inserting additional uses of the WC. We know
1189 // that it can only have one transitive use at the moment, and thus moving
1190 // that use to just before the branch and inserting code before it and then
1191 // modifying the operand is legal.
1192 auto *IP = cast<Instruction>(WidenableBR->getCondition());
1193 // Here we unconditionally modify the IR, so after this point we should return
1194 // only `true`!
1195 IP->moveBefore(WidenableBR);
1196 if (MSSAU)
1197 if (auto *MUD = MSSAU->getMemorySSA()->getMemoryAccess(IP))
1198 MSSAU->moveToPlace(MUD, WidenableBR->getParent(),
1199 MemorySSA::BeforeTerminator);
1200 Rewriter.setInsertPoint(IP);
1201 IRBuilder<> B(IP);
1202
1203 bool InvalidateLoop = false;
1204 Value *MinECV = nullptr; // lazily generated if needed
1205 for (BasicBlock *ExitingBB : ExitingBlocks) {
1206 // If our exiting block exits multiple loops, we can only rewrite the
1207 // innermost one. Otherwise, we're changing how many times the innermost
1208 // loop runs before it exits.
1209 if (LI->getLoopFor(ExitingBB) != L)
1210 continue;
1211
1212 // Can't rewrite non-branch yet.
1213 auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
1214 if (!BI)
1215 continue;
1216
1217 // If already constant, nothing to do.
1218 if (isa<Constant>(BI->getCondition()))
1219 continue;
1220
1221 const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);
1222 if (isa<SCEVCouldNotCompute>(ExitCount) ||
1223 ExitCount->getType()->isPointerTy() ||
1224 !Rewriter.isSafeToExpandAt(ExitCount, WidenableBR))
1225 continue;
1226
1227 const bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB));
1228 BasicBlock *ExitBB = BI->getSuccessor(ExitIfTrue ? 0 : 1);
1229 if (!ExitBB->getPostdominatingDeoptimizeCall())
1230 continue;
1231
1232 /// Here we can be fairly sure that executing this exit will most likely
1233 /// lead to executing llvm.experimental.deoptimize.
1234 /// This is a profitability heuristic, not a legality constraint.
1235
1236 // If we found a widenable exit condition, do two things:
1237 // 1) fold the widened exit test into the widenable condition
1238 // 2) fold the branch to untaken - avoids infinite looping
1239
1240 Value *ECV = Rewriter.expandCodeFor(ExitCount);
1241 if (!MinECV)
1242 MinECV = Rewriter.expandCodeFor(MinEC);
1243 Value *RHS = MinECV;
1244 if (ECV->getType() != RHS->getType()) {
1245 Type *WiderTy = SE->getWiderType(ECV->getType(), RHS->getType());
1246 ECV = B.CreateZExt(ECV, WiderTy);
1247 RHS = B.CreateZExt(RHS, WiderTy);
1248 }
1249 assert(!Latch || DT->dominates(ExitingBB, Latch));
1250 Value *NewCond = B.CreateICmp(ICmpInst::ICMP_UGT, ECV, RHS);
1251 // Freeze poison or undef to an arbitrary bit pattern to ensure we can
1252 // branch without introducing UB. See NOTE ON POISON/UNDEF above for
1253 // context.
1254 NewCond = B.CreateFreeze(NewCond);
1255
1256 widenWidenableBranch(WidenableBR, NewCond);
1257
1258 Value *OldCond = BI->getCondition();
1259 BI->setCondition(ConstantInt::get(OldCond->getType(), !ExitIfTrue));
1260 InvalidateLoop = true;
1261 }
1262
1263 if (InvalidateLoop)
1264 // We just mutated a bunch of loop exits changing their exit counts
1265 // widely. We need to force recomputation of the exit counts given these
1266 // changes. Note that all of the inserted exits are never taken, and
1267 // should be removed next time the CFG is modified.
1268 SE->forgetLoop(L);
1269
1270 // Always return `true` since we have moved the WidenableBR's condition.
1271 return true;
1272}
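
The condition folded into the widenable branch by `predicateLoopExits` can be illustrated with concrete numbers. The sketch below assumes every exit count is already a known integer, and the names `Exit` and `widenedEntryCondition` are hypothetical. It only models the `ExitCount u> MinEC` comparison that is ANDed into the condition above the loop. Freezing, expansion safety, and the widenable-condition call itself are left out.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical model of one analyzable loop exit.
struct Exit {
  uint64_t ExitCount; // backedges taken before this exit would fire
  bool LeadsToDeopt;  // true if the exit block ends in a deoptimize call
};

// AND of (ExitCount > MinEC) over all deopt exits, where MinEC is the minimum
// exit count over all exits. When this holds on loop entry, some other exit
// fires strictly earlier than every deopt exit, so the deopt exits are dead.
bool widenedEntryCondition(const std::vector<Exit> &Exits) {
  assert(!Exits.empty());
  uint64_t MinEC = Exits.front().ExitCount;
  for (const Exit &E : Exits)
    MinEC = std::min(MinEC, E.ExitCount);
  bool Cond = true;
  for (const Exit &E : Exits)
    if (E.LeadsToDeopt)
      Cond = Cond && (E.ExitCount > MinEC);
  return Cond;
}
```

A tie between a deopt exit and the minimum makes the condition false. In this model that means the widened branch takes the deoptimizing path before the loop starts, which is the conservative choice.
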
1273
1274bool LoopPredication::runOnLoop(Loop *Loop) {
1275 L = Loop;
1276
1277 LLVM_DEBUG(dbgs() << "Analyzing ");
1278 LLVM_DEBUG(L->dump());
1279
1280 Module *M = L->getHeader()->getModule();
1281
1282 // There is nothing to do if the module doesn't use guards
1283 auto *GuardDecl =
1284 M->getFunction(Intrinsic::getName(Intrinsic::experimental_guard));
1285 bool HasIntrinsicGuards = GuardDecl && !GuardDecl->use_empty();
1286 auto *WCDecl = M->getFunction(
1287 Intrinsic::getName(Intrinsic::experimental_widenable_condition));
1288 bool HasWidenableConditions =
1289 PredicateWidenableBranchGuards && WCDecl && !WCDecl->use_empty();
1290 if (!HasIntrinsicGuards && !HasWidenableConditions)
1291 return false;
1292
1293 DL = &M->getDataLayout();
1294
1295 Preheader = L->getLoopPreheader();
1296 if (!Preheader)
1297 return false;
1298
1299 auto LatchCheckOpt = parseLoopLatchICmp();
1300 if (!LatchCheckOpt)
1301 return false;
1302 LatchCheck = *LatchCheckOpt;
1303
1304 LLVM_DEBUG(dbgs() << "Latch check:\n");
1305 LLVM_DEBUG(LatchCheck.dump());
1306
1307 if (!isLoopProfitableToPredicate()) {
1308 LLVM_DEBUG(dbgs() << "Loop not profitable to predicate!\n");
1309 return false;
1310 }
1311 // Collect all the guards into a vector and process later, so as not
1312 // to invalidate the instruction iterator.
1313 SmallVector<IntrinsicInst *, 4> Guards;
1314 SmallVector<BranchInst *, 4> GuardsAsWidenableBranches;
1315 for (const auto BB : L->blocks()) {
1316 for (auto &I : *BB)
1317 if (isGuard(&I))
1318 Guards.push_back(cast<IntrinsicInst>(&I));
1319 if (PredicateWidenableBranchGuards &&
1320 isGuardAsWidenableBranch(BB->getTerminator()))
1321 GuardsAsWidenableBranches.push_back(
1322 cast<BranchInst>(BB->getTerminator()));
1323 }
1324
1325 SCEVExpander Expander(*SE, *DL, "loop-predication");
1326 bool Changed = false;
1327 for (auto *Guard : Guards)
1328 Changed |= widenGuardConditions(Guard, Expander);
1329 for (auto *Guard : GuardsAsWidenableBranches)
1330 Changed |= widenWidenableBranchGuardConditions(Guard, Expander);
1331 Changed |= predicateLoopExits(L, Expander);
1332
1333 if (MSSAU && VerifyMemorySSA)
1334 MSSAU->getMemorySSA()->verifyMemorySSA();
1335 return Changed;
1336}
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
assume Assume Builder
SmallVector< MachineOperand, 4 > Cond
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define LLVM_DEBUG(X)
Definition: Debug.h:101
static cl::opt< bool > SkipProfitabilityChecks("irce-skip-profitability-checks", cl::Hidden, cl::init(false))
static cl::opt< float > LatchExitProbabilityScale("loop-predication-latch-probability-scale", cl::Hidden, cl::init(2.0), cl::desc("scale factor for the latch probability. Value should be greater " "than 1. Lower values are ignored"))
static void normalizePredicate(ScalarEvolution *SE, Loop *L, LoopICmp &RC)
static cl::opt< bool > SkipProfitabilityChecks("loop-predication-skip-profitability-checks", cl::Hidden, cl::init(false))
static const SCEV * getMinAnalyzeableBackedgeTakenCount(ScalarEvolution &SE, DominatorTree &DT, Loop *L)
Return the minimum of all analyzeable exit counts.
static cl::opt< bool > EnableCountDownLoop("loop-predication-enable-count-down-loop", cl::Hidden, cl::init(true))
static cl::opt< bool > EnableIVTruncation("loop-predication-enable-iv-truncation", cl::Hidden, cl::init(true))
static std::optional< LoopICmp > generateLoopLatchCheck(const DataLayout &DL, ScalarEvolution &SE, const LoopICmp LatchCheck, Type *RangeCheckType)
loop predication
static cl::opt< bool > PredicateWidenableBranchGuards("loop-predication-predicate-widenable-branches-to-deopt", cl::Hidden, cl::desc("Whether or not we should predicate guards " "expressed as widenable branches to deoptimize blocks"), cl::init(true))
static bool isSafeToTruncateWideIVType(const DataLayout &DL, ScalarEvolution &SE, const LoopICmp LatchCheck, Type *RangeCheckType)
static cl::opt< bool > InsertAssumesOfPredicatedGuardsConditions("loop-predication-insert-assumes-of-predicated-guards-conditions", cl::Hidden, cl::desc("Whether or not we should insert assumes of conditions of " "predicated guards"), cl::init(true))
static BranchInst * FindWidenableTerminatorAboveLoop(Loop *L, LoopInfo &LI)
If we can (cheaply) find a widenable branch which controls entry into the loop, return it.
#define I(x, y, z)
Definition: MD5.cpp:58
This file exposes an interface to building/using memory SSA to walk memory instructions using a use/d...
Module.h This file contains the declarations for the Module class.
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:55
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:59
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
This file contains the declarations for profiling metadata utility functions.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:167
Virtual Register Rewriter
Definition: VirtRegMap.cpp:237
Value * RHS
Value * LHS
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:620
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
LLVM Basic Block Representation.
Definition: BasicBlock.h:56
const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
Definition: BasicBlock.cpp:245
const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
Definition: BasicBlock.cpp:284
const BasicBlock * getUniquePredecessor() const
Return the predecessor of this block if it has a unique predecessor block.
Definition: BasicBlock.cpp:292
const BasicBlock * getSingleSuccessor() const
Return the successor of this block if it has a single successor.
Definition: BasicBlock.cpp:314
InstListType::iterator iterator
Instruction iterators...
Definition: BasicBlock.h:87
LLVMContext & getContext() const
Get the context in which this basic block lives.
Definition: BasicBlock.cpp:35
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:127
const CallInst * getPostdominatingDeoptimizeCall() const
Returns the call instruction calling @llvm.experimental.deoptimize that is present either in current ...
Definition: BasicBlock.cpp:196
Conditional or Unconditional Branch instruction.
void setCondition(Value *V)
bool isConditional() const
BasicBlock * getSuccessor(unsigned i) const
Value * getCondition() const
Legacy analysis pass which computes BranchProbabilityInfo.
static BranchProbability getBranchProbability(uint64_t Numerator, uint64_t Denominator)
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:718
@ ICMP_SLT
signed less than
Definition: InstrTypes.h:747
@ ICMP_SLE
signed less or equal
Definition: InstrTypes.h:748
@ ICMP_UGE
unsigned greater or equal
Definition: InstrTypes.h:742
@ ICMP_UGT
unsigned greater than
Definition: InstrTypes.h:741
@ ICMP_SGT
signed greater than
Definition: InstrTypes.h:745
@ ICMP_ULT
unsigned less than
Definition: InstrTypes.h:743
@ ICMP_NE
not equal
Definition: InstrTypes.h:740
@ ICMP_SGE
signed greater or equal
Definition: InstrTypes.h:746
@ ICMP_ULE
unsigned less or equal
Definition: InstrTypes.h:744
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition: InstrTypes.h:859
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition: InstrTypes.h:832
Predicate getPredicate() const
Return the predicate for this instruction.
Definition: InstrTypes.h:808
Predicate getFlippedStrictnessPredicate() const
For predicate of kind "is X or equal to 0" returns the predicate "is X".
Definition: InstrTypes.h:925
static ConstantInt * getTrue(LLVMContext &Context)
Definition: Constants.cpp:833
static Constant * get(Type *Ty, uint64_t V, bool IsSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
Definition: Constants.cpp:888
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:110
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:166
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Definition: Dominators.cpp:122
This instruction compares its operands according to the predicate given to the constructor.
bool isEquality() const
Return true if this predicate is either EQ or NE.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2558
const BasicBlock * getParent() const
Definition: Instruction.h:90
A wrapper class for inspecting calls to intrinsic functions.
Definition: IntrinsicInst.h:47
This class provides an interface for updating the loop pass manager based on mutations to the loop ne...
virtual bool runOnLoop(Loop *L, LPPassManager &LPM)=0
bool skipLoop(const Loop *L) const
Optional passes call this function to check whether the pass should be skipped.
Definition: LoopPass.cpp:370
PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &U)
Represents a single loop in the control flow graph.
Definition: LoopInfo.h:547
Metadata node.
Definition: Metadata.h:943
An analysis that produces MemorySSA for a function.
Definition: MemorySSA.h:936
Legacy analysis pass which computes MemorySSA.
Definition: MemorySSA.h:986
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
Pass interface - Implemented by all 'passes'.
Definition: Pass.h:91
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
Definition: Pass.cpp:98
A set of analyses that are preserved following a run of a transformation pass.
Definition: PassManager.h:152
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: PassManager.h:158
This node represents a polynomial recurrence on the trip count of the specified loop.
This class uses information about analyze scalars to rewrite expressions in canonical form.
bool isSafeToExpandAt(const SCEV *S, const Instruction *InsertionPoint) const
Return true if the given expression is safe to expand in the sense that all materialized values are d...
Value * expandCodeFor(const SCEV *SH, Type *Ty, Instruction *I)
Insert code to directly compute the specified SCEV expression into the program.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents an analyzed expression in the program.
bool isOne() const
Return true if the expression is a constant one.
bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
Type * getType() const
Return the LLVM type of this SCEV expression.
The main scalar evolution driver.
bool isKnownPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
const SCEV * getCouldNotCompute()
const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:365
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:450
bool empty() const
Definition: SmallVector.h:94
size_t size() const
Definition: SmallVector.h:91
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:577
void push_back(const T &Elt)
Definition: SmallVector.h:416
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1200
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
bool isPointerTy() const
True if this is an instance of PointerType.
Definition: Type.h:258
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
void setOperand(unsigned i, Value *Val)
Definition: User.h:174
Value * getOperand(unsigned i) const
Definition: User.h:169
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
void dump() const
Support for debugging, callable in GDB: V->dump()
Definition: AsmWriter.cpp:4941
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
StringRef getName(ID id)
Return the LLVM name for an intrinsic, such as "llvm.ppc.altivec.lvx".
Definition: Function.cpp:979
BinaryOp_match< LHS, RHS, Instruction::And > m_And(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:76
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:445
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
Definition: Local.cpp:537
auto enumerate(FirstRange &&First, RestRanges &&...Rest)
Given two or more input ranges, returns a new range whose values are tuples (A,...
Definition: STLExtras.h:2424
void widenWidenableBranch(BranchInst *WidenableBR, Value *NewCond)
Given a branch we know is widenable (defined per Analysis/GuardUtils.h), widen it such that condition...
Definition: GuardUtils.cpp:82
Interval::succ_iterator succ_begin(Interval *I)
succ_begin/succ_end - define methods so that Intervals may be used just like BasicBlocks can with the...
Definition: Interval.h:99
bool isGuard(const User *U)
Returns true iff U has semantics of a guard expressed in a form of call of llvm.experimental....
Definition: GuardUtils.cpp:18
Pass * createLoopPredicationPass()
MDNode * getValidBranchWeightMDNode(const Instruction &I)
Get the valid branch weights metadata node.
bool isModSet(const ModRefInfo MRI)
Definition: ModRef.h:48
bool parseWidenableBranch(const User *U, Value *&Condition, Value *&WidenableCondition, BasicBlock *&IfTrueBB, BasicBlock *&IfFalseBB)
If U is a widenable branch looking like: cond = ... wc = call i1 @llvm.experimental....
Definition: GuardUtils.cpp:44
bool hasValidBranchWeightMD(const Instruction &I)
Checks if an instruction has valid Branch Weight Metadata.
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
void getLoopAnalysisUsage(AnalysisUsage &AU)
Helper to consistently add the set of standard passes to a loop pass's AnalysisUsage.
Definition: LoopUtils.cpp:141
bool VerifyMemorySSA
Enables verification of MemorySSA.
Definition: MemorySSA.cpp:89
bool isGuardAsWidenableBranch(const User *U)
Returns true iff U has the semantics of a guard expressed in the form of a widenable conditional branch to ...
Definition: GuardUtils.cpp:29
bool extractBranchWeights(const MDNode *ProfileData, SmallVectorImpl< uint32_t > &Weights)
Extract branch weights from MD_prof metadata.
PreservedAnalyses getLoopPassPreservedAnalyses()
Returns the minimum set of Analyses that all loop passes must preserve.
auto predecessors(const MachineBasicBlock *BB)
void initializeLoopPredicationLegacyPassPass(PassRegistry &)
unsigned pred_size(const MachineBasicBlock *BB)
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:860
The adaptor from a function pass to a loop pass computes these analyses and makes them available to t...